├── .github ├── CODEOWNERS ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── documentation.md │ ├── feature_request.md │ └── question.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── stale.yml └── workflows │ ├── codestyle.yml │ ├── deploy_publish.yml │ ├── deploy_push.yml │ ├── dl_cpu.yml │ ├── dl_cpu_minimal.yml │ ├── greetings.yml │ └── integrations.yml ├── .gitignore ├── .mergify.yml ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CITATION ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.md ├── Makefile ├── README.md ├── bin ├── scripts │ ├── catalyst-parallel-run │ ├── download-gdrive │ ├── extract-archive │ └── install-apex └── workflows │ ├── check_config_api.sh │ ├── check_docs.sh │ ├── check_pipelines.sh │ ├── check_projector.sh │ └── check_settings.sh ├── catalyst ├── __init__.py ├── __version__.py ├── callbacks │ ├── __init__.py │ ├── backward.py │ ├── batch_overfit.py │ ├── batch_transform.py │ ├── checkpoint.py │ ├── control_flow.py │ ├── criterion.py │ ├── metric.py │ ├── metric_aggregation.py │ ├── metrics │ │ ├── __init__.py │ │ ├── accuracy.py │ │ ├── auc.py │ │ ├── classification.py │ │ ├── cmc_score.py │ │ ├── confusion_matrix.py │ │ ├── functional_metric.py │ │ ├── r2_squared.py │ │ ├── recsys.py │ │ ├── scikit_learn.py │ │ └── segmentation.py │ ├── misc.py │ ├── mixup.py │ ├── onnx.py │ ├── optimizer.py │ ├── optuna.py │ ├── periodic_loader.py │ ├── profiler.py │ ├── pruning.py │ ├── quantization.py │ ├── scheduler.py │ ├── sklearn_model.py │ ├── soft_update.py │ └── tracing.py ├── contrib │ ├── README.md │ ├── __init__.py │ ├── __main__.py │ ├── data │ │ ├── __init__.py │ │ ├── _misc.py │ │ ├── collate_fn.py │ │ ├── dataset.py │ │ ├── dataset_cv.py │ │ ├── dataset_ml.py │ │ ├── reader.py │ │ ├── reader_cv.py │ │ ├── sampler.py │ │ ├── sampler_inbatch.py │ │ └── transforms.py │ ├── datasets │ │ ├── README.md │ │ ├── __init__.py │ │ ├── cifar.py │ │ ├── imagecar.py │ │ ├── imagenette.py │ │ ├── imagewang.py │ │ ├── imagewoof.py │ │ ├── market1501.py │ │ ├── misc.py │ │ ├── misc_cv.py │ │ ├── mnist.py │ │ └── movielens.py │ ├── layers │ │ ├── __init__.py │ │ ├── amsoftmax.py │ │ ├── arcface.py │ │ ├── arcmargin.py │ │ ├── common.py │ │ ├── cosface.py │ │ ├── curricularface.py │ │ ├── factorized.py │ │ ├── lama.py │ │ ├── pooling.py │ │ ├── rms_norm.py │ │ ├── se.py │ │ └── softmax.py │ ├── losses │ │ ├── __init__.py │ │ ├── ce.py │ │ ├── circle.py │ │ ├── contrastive.py │ │ ├── dice.py │ │ ├── focal.py │ │ ├── functional.py │ │ ├── gan.py │ │ ├── iou.py │ │ ├── lovasz.py │ │ ├── margin.py │ │ ├── ntxent.py │ │ ├── recsys.py │ │ ├── regression.py │ │ ├── smoothing_dice.py │ │ ├── supervised_contrastive.py │ │ ├── trevsky.py │ │ ├── triplet.py │ │ └── wing.py │ ├── models │ │ ├── __init__.py │ │ ├── mnist.py │ │ └── resnet_encoder.py │ ├── optimizers │ │ ├── __init__.py │ │ ├── adamp.py │ │ ├── lamb.py │ │ ├── lookahead.py │ │ ├── qhadamw.py │ │ ├── radam.py │ │ ├── ralamb.py │ │ └── sgdp.py │ ├── schedulers │ │ ├── __init__.py │ │ ├── base.py │ │ └── onecycle.py │ ├── scripts │ │ ├── __init__.py │ │ ├── collect_env.py │ │ ├── project_embeddings.py │ │ ├── run.py │ │ └── tune.py │ └── utils │ │ ├── __init__.py │ │ ├── compression.py │ │ ├── image.py │ │ ├── numpy.py │ │ ├── report.py │ │ ├── serialization.py │ │ ├── swa.py │ │ ├── thresholds.py │ │ ├── torch.py │ │ └── visualization.py ├── core │ ├── __init__.py │ ├── callback.py │ ├── engine.py │ ├── logger.py │ ├── misc.py │ └── runner.py ├── data │ ├── __init__.py │ ├── dataset.py │ 
├── loader.py │ └── sampler.py ├── dl │ └── __init__.py ├── engines │ ├── __init__.py │ └── torch.py ├── extras │ ├── __init__.py │ ├── forward_wrapper.py │ ├── frozen_class.py │ ├── metric_handler.py │ └── time_manager.py ├── loggers │ ├── __init__.py │ ├── comet.py │ ├── console.py │ ├── csv.py │ ├── mlflow.py │ ├── neptune.py │ ├── tensorboard.py │ └── wandb.py ├── metrics │ ├── __init__.py │ ├── _accumulative.py │ ├── _accuracy.py │ ├── _additive.py │ ├── _auc.py │ ├── _classification.py │ ├── _cmc_score.py │ ├── _confusion_matrix.py │ ├── _functional_metric.py │ ├── _hitrate.py │ ├── _map.py │ ├── _metric.py │ ├── _mrr.py │ ├── _ndcg.py │ ├── _r2_squared.py │ ├── _segmentation.py │ ├── _topk_metric.py │ └── functional │ │ ├── __init__.py │ │ ├── _accuracy.py │ │ ├── _auc.py │ │ ├── _average_precision.py │ │ ├── _classification.py │ │ ├── _cmc_score.py │ │ ├── _f1_score.py │ │ ├── _focal.py │ │ ├── _hitrate.py │ │ ├── _misc.py │ │ ├── _mrr.py │ │ ├── _ndcg.py │ │ ├── _precision.py │ │ ├── _r2_squared.py │ │ ├── _recall.py │ │ └── _segmentation.py ├── registry.py ├── runners │ ├── __init__.py │ ├── runner.py │ └── supervised.py ├── settings.py ├── typing.py └── utils │ ├── __init__.py │ ├── config.py │ ├── distributed.py │ ├── misc.py │ ├── onnx.py │ ├── pruning.py │ ├── quantization.py │ ├── torch.py │ └── tracing.py ├── docker ├── Dockerfile ├── Dockerfile-dev ├── README.md └── collect_dependencies_hash.py ├── docs ├── Makefile ├── README.md ├── api │ ├── callbacks.rst │ ├── contrib.rst │ ├── core.rst │ ├── data.rst │ ├── engines.rst │ ├── loggers.rst │ ├── metrics.rst │ ├── runners.rst │ └── utils.rst ├── conf.py ├── core │ ├── callback.rst │ ├── engine.rst │ ├── logger.rst │ ├── metric.rst │ └── runner.rst ├── create_api.sh ├── faq │ ├── architecture.rst │ ├── checkpointing.rst │ ├── dataflow.rst │ ├── ddp.rst │ ├── debugging.rst │ ├── dp.rst │ ├── early_stopping.rst │ ├── engines.rst │ ├── inference.rst │ ├── intro.rst │ ├── logging.rst │ ├── mixed_precision.rst │ ├── multi_components.rst │ ├── multi_keys.rst │ ├── optuna.rst │ └── settings.rst ├── getting_started │ └── quickstart.rst ├── index.rst ├── key.enc ├── requirements.txt └── tutorials │ └── ddp.rst ├── examples ├── README.md ├── catalyst_rl │ ├── README.md │ ├── buffer.py │ ├── db.py │ ├── ddpg.py │ ├── dqn.py │ ├── misc.py │ └── sampler.py ├── detection │ ├── README.md │ ├── __init__.py │ ├── callbacks.py │ ├── configs │ │ ├── centernet-config.yaml │ │ ├── ssd-config.yaml │ │ └── yolo-x-config.yaml │ ├── criterion.py │ ├── custom_runner.py │ ├── dataset.py │ ├── models │ │ ├── __init__.py │ │ ├── centernet.py │ │ ├── ssd.py │ │ └── yolo_x.py │ ├── requirements.txt │ ├── to_coco.py │ └── utils.py ├── engines │ ├── README.md │ ├── src.py │ ├── train_albert.py │ └── train_resnet.py ├── notebooks │ ├── DALI.ipynb │ ├── XLA.ipynb │ ├── XLA_ddp.ipynb │ ├── colab_ci_cd.ipynb │ ├── customization_tutorial.ipynb │ ├── customizing_what_happens_in_train.ipynb │ ├── neptune.ipynb │ └── reinforcement_learning.ipynb ├── recsys │ ├── README.md │ ├── macridvae.py │ ├── multidae.py │ └── multivae.py ├── reinforcement_learning │ ├── README.md │ ├── ddpg.py │ ├── dqn.py │ └── reinforce.py └── self_supervised │ ├── .dockerignore │ ├── Dockerfile │ ├── README.md │ ├── barlow_twins.py │ ├── byol.py │ ├── check.sh │ ├── run.sh │ ├── simCLR.py │ ├── src │ ├── __init__.py │ ├── common.py │ ├── datasets.py │ └── runner.py │ └── supervised_contrastive.py ├── pyproject.toml ├── requirements ├── requirements-comet.txt ├── requirements-cv.txt ├── 
requirements-deepspeed.txt ├── requirements-dev.txt ├── requirements-ml.txt ├── requirements-mlflow.txt ├── requirements-neptune.txt ├── requirements-onnx-gpu.txt ├── requirements-onnx.txt ├── requirements-optuna.txt ├── requirements-profiler.txt ├── requirements-wandb.txt ├── requirements-xla.txt └── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── benchmarks ├── __init__.py └── test_benchmark.py ├── catalyst ├── __init__.py ├── callbacks │ ├── __init__.py │ ├── test_batch_overfit.py │ ├── test_batch_transform.py │ ├── test_checkpoint.py │ ├── test_control_flow.py │ ├── test_metric_aggregation.py │ ├── test_misc.py │ ├── test_mixup.py │ ├── test_optimizer.py │ ├── test_optuna.py │ ├── test_periodic_loader.py │ ├── test_profiler.py │ ├── test_pruning.py │ ├── test_quantization.py │ ├── test_sklearn_model.py │ ├── test_soft_update.py │ ├── test_tracing.py │ └── test_wrapper.py ├── data │ ├── __init__.py │ ├── test_data.py │ ├── test_loader.py │ └── test_sampler.py ├── dl │ └── __init__.py ├── metrics │ ├── __init__.py │ ├── functional │ │ ├── __init__.py │ │ ├── test_accuracy.py │ │ ├── test_auc.py │ │ ├── test_average_precision.py │ │ ├── test_classification.py │ │ ├── test_cmc_metric.py │ │ ├── test_dice.py │ │ ├── test_fbeta_precision_recall.py │ │ ├── test_hitrate.py │ │ ├── test_iou.py │ │ ├── test_misc.py │ │ ├── test_mrr.py │ │ ├── test_ndcg.py │ │ ├── test_r2_squared.py │ │ └── test_region_statictics.py │ ├── test_accuracy.py │ ├── test_additive.py │ ├── test_classification.py │ ├── test_cmc_score.py │ ├── test_functional_metric.py │ ├── test_r2squared.py │ └── test_segmentation.py ├── runners │ ├── __init__.py │ ├── test_profiler.py │ ├── test_reid.py │ ├── test_runner.py │ ├── test_sklearn_classifier.py │ └── test_sklearn_classifier_mnist.py └── utils │ ├── __init__.py │ ├── test_config.py │ ├── test_hash.py │ ├── test_onnx.py │ └── test_quantization.py ├── contrib ├── __init__.py ├── data │ ├── __init__.py │ ├── test_dataset.py │ ├── test_nifti.py │ └── test_sampler_inbatch.py ├── datasets │ ├── __init__.py │ ├── test_movielens.py │ └── test_movielens_20m.py ├── nn │ ├── __init__.py │ ├── test_criterion.py │ ├── test_modules.py │ ├── test_onecyle.py │ └── test_optimizer.py ├── scripts │ ├── __init__.py │ ├── test_main.py │ ├── test_tag2label.py │ ├── test_tune.py │ └── test_tune.yml └── utils │ ├── __init__.py │ ├── test_image.py │ ├── test_swa.py │ └── test_torch.py ├── misc.py └── pipelines ├── __init__.py ├── configs ├── engine_core.yml ├── engine_cpu.yml ├── engine_ddp.yml ├── engine_ddp_amp.yml ├── engine_ddp_xla.yml ├── engine_dp.yml ├── engine_dp_amp.yml ├── engine_gpu.yml ├── engine_gpu_amp.yml ├── test_classification.yml ├── test_gan.yml ├── test_metric_learning.yml ├── test_mnist.yml ├── test_mnist_custom.yml ├── test_mnist_multicriterion.yml ├── test_mnist_multimodel.yml ├── test_mnist_multioptimizer.yml ├── test_multihead_classification.yml ├── test_multilabel_classification.yml ├── test_recsys.yml ├── test_regression.yml └── test_vae.yml ├── test_classification.py ├── test_contrastive.py ├── test_distillation.py ├── test_gan.py ├── test_metric_learning.py ├── test_mnist.py ├── test_mnist_custom.py ├── test_mnist_multicriterion.py ├── test_mnist_multimodel.py ├── test_mnist_multioptimizer.py ├── test_multihead_classification.py ├── test_multilabel_classification.py ├── test_optuna.py ├── test_recsys.py ├── test_regression.py └── test_vae.py /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 
| # Lines starting with '#' are comments. 2 | # Each line is a file pattern followed by one or more owners. 3 | 4 | # These owners will be the default owners for everything in the repo. 5 | * @Scitator @bagxi @Ditwoo 6 | 7 | # Order is important. The last matching pattern has the most precedence. 8 | # So if a pull request only touches javascript files, only these owners 9 | # will be requested to review. 10 | 11 | 12 | # You can also use email addresses if you prefer. 13 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: catalyst_team 5 | open_collective: catalyst 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Typos and documentation fixes 3 | about: Make awesome documentation and tutorials 4 | title: '' 5 | labels: '' 6 | assignees: Scitator, bagxi, Ditwoo 7 | 8 | --- 9 | 10 | ## 📚 Documentation 11 | 12 | For typos and documentation fixes, please go ahead and: 13 | 14 | 1. Create an issue. 15 | 2. Fix the typo. 16 | 3. Submit a PR. 17 | 18 | Thank you! 
19 | 20 | 21 | ### Checklist 22 | - [ ] proposal / bug description 23 | - [ ] motivation / expected 24 | 25 | 26 | ### FAQ 27 | Please review the FAQ before submitting an issue: 28 | - [ ] I have read the [documentation and FAQ](https://catalyst-team.github.io/catalyst/) 29 | - [ ] I have reviewed the [minimal examples section](https://github.com/catalyst-team/catalyst#minimal-examples) 30 | - [ ] I have checked the [changelog](https://github.com/catalyst-team/catalyst/blob/master/CHANGELOG.md) for main framework updates 31 | - [ ] I have read the [contribution guide](https://github.com/catalyst-team/catalyst/blob/master/CONTRIBUTING.md) 32 | - [ ] I have joined [Catalyst slack (#__questions channel)](https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw) for issue discussion 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement, help wanted 6 | assignees: Scitator, bagxi, Ditwoo 7 | 8 | --- 9 | 10 | ## 🚀 Feature Request 11 | 12 | 13 | 14 | ### Motivation 15 | 16 | 17 | 18 | ### Proposal 19 | 23 | 24 | 25 | ### Alternatives 26 | 27 | 28 | 29 | ### Additional context 30 | 31 | 32 | 33 | ### Checklist 34 | - [ ] feature proposal description 35 | - [ ] motivation 36 | - [ ] extra proposal context / proposal alternatives review 37 | 38 | 39 | ### FAQ 40 | Please review the FAQ before submitting an issue: 41 | - [ ] I have read the [documentation and FAQ](https://catalyst-team.github.io/catalyst/) 42 | - [ ] I have reviewed the [minimal examples section](https://github.com/catalyst-team/catalyst#minimal-examples) 43 | - [ ] I have checked the [changelog](https://github.com/catalyst-team/catalyst/blob/master/CHANGELOG.md) for main framework updates 44 | - [ ] I have read the [contribution guide](https://github.com/catalyst-team/catalyst/blob/master/CONTRIBUTING.md) 45 | - [ ] I have joined [Catalyst slack (#__questions channel)](https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw) for issue discussion 46 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: How to question 3 | about: Asking how-to questions 4 | title: '' 5 | labels: help wanted, question 6 | assignees: Scitator, bagxi, Ditwoo 7 | 8 | --- 9 | 10 | ## ❓ Questions and Help 11 | 12 | ### Before asking: 13 | 1. search the issues. 14 | 2. search the docs. 15 | 16 | 17 | 18 | 19 | #### What is your question? 20 | 21 | 22 | #### Code 23 | 24 | 25 | 26 | #### What have you tried? 27 | 28 | 29 | #### What's your environment? 30 | Please copy and paste the output from our environment collection script 31 | ```bash 32 | catalyst-contrib --collect-env 33 | # or manually 34 | wget https://raw.githubusercontent.com/catalyst-team/catalyst/master/catalyst/contrib/scripts/collect_env.py 35 | python collect_env.py 36 | ``` 37 | (or fill out the checklist below manually). 
38 | 39 | ```bash 40 | # example checklist, fill with your info 41 | Catalyst version: 20.04.2 42 | PyTorch version: 1.4.0 43 | TensorFlow version: N/A 44 | TensorBoard version: 2.2.1 45 | OS: Mac OSX 10.15.4 46 | Python version: 3.7 47 | CUDA runtime version: No CUDA 48 | Nvidia driver version: No CUDA 49 | cuDNN version: No CUDA 50 | ``` 51 | 52 | 53 | ### Additional context 54 | 55 | 56 | 57 | ### FAQ 58 | Please review the FAQ before submitting an issue: 59 | - [ ] I have read the [documentation and FAQ](https://catalyst-team.github.io/catalyst/) 60 | - [ ] I have reviewed the [minimal examples section](https://github.com/catalyst-team/catalyst#minimal-examples) 61 | - [ ] I have checked the [changelog](https://github.com/catalyst-team/catalyst/blob/master/CHANGELOG.md) for main framework updates 62 | - [ ] I have read the [contribution guide](https://github.com/catalyst-team/catalyst/blob/master/CONTRIBUTING.md) 63 | - [ ] I have joined [Catalyst slack (#__questions channel)](https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw) for issue discussion 64 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: docker 4 | directory: "/docker" 5 | schedule: 6 | interval: daily 7 | open-pull-requests-limit: 10 8 | labels: 9 | - dependencies 10 | ignore: 11 | - dependency-name: nvidia/cuda 12 | versions: 13 | - ">= 0" 14 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # https://github.com/marketplace/stale 2 | # Number of days of inactivity before an issue becomes stale 3 | daysUntilStale: 60 4 | # Number of days of inactivity before a stale issue is closed 5 | daysUntilClose: 14 6 | # Issues with these labels will never be considered stale 7 | exemptLabels: 8 | - pinned 9 | - security 10 | # Label to use when marking an issue as stale 11 | staleLabel: wontfix 12 | # Comment to post when marking an issue as stale. Set to `false` to disable 13 | markComment: > 14 | This issue has been automatically marked as stale because it has not had 15 | recent activity. It will be closed if no further activity occurs. Thank you 16 | for your contributions. 17 | # Comment to post when closing a stale issue. Set to `false` to disable 18 | closeComment: false -------------------------------------------------------------------------------- /.github/workflows/greetings.yml: -------------------------------------------------------------------------------- 1 | # https://github.com/marketplace/actions/first-interaction 2 | name: Greetings 3 | on: [issues] # pull_request 4 | 5 | jobs: 6 | greeting: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/first-interaction@v1 10 | with: 11 | repo-token: ${{ secrets.GITHUB_TOKEN }} 12 | issue-message: 'Hi! Thank you for your contribution! Please re-check all issue template checklists - unfilled issues would be closed automatically. And do not forget to [join our slack](https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw) for collaboration.' 13 | pr-message: 'Hey thanks for the pull request! Please re-check all PR template checklists - unfilled PRs would be closed automatically. 
And do not forget to [join our slack](https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw) for collaboration.' 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *pytest_cache* 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | builds/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | 105 | 106 | 107 | .DS_Store 108 | .idea 109 | .code 110 | 111 | *.bak 112 | *.csv 113 | *.tsv 114 | *.ipynb 115 | *.pt 116 | *.pth 117 | 118 | tmp/ 119 | logs/ 120 | # Examples - mock data 121 | !examples/distilbert_text_classification/input/*.csv 122 | !tests/_tests_nlp_classification/input/*.csv 123 | tests/logs/ 124 | notebooks/ 125 | 126 | _nogit* 127 | 128 | ### VisualStudioCode ### 129 | .vscode/* 130 | .vscode/settings.json 131 | !.vscode/tasks.json 132 | !.vscode/launch.json 133 | !.vscode/extensions.json 134 | 135 | ### VisualStudioCode Patch ### 136 | # Ignore all local history of files 137 | .history 138 | 139 | # End of https://www.gitignore.io/api/visualstudiocode 140 | 141 | presets/ 142 | _todo/ 143 | tests/MNIST/ 144 | codestyle.txt -------------------------------------------------------------------------------- /.mergify.yml: -------------------------------------------------------------------------------- 1 | pull_request_rules: 2 | # removes reviews done by collaborators when the pull request is updated 3 | - name: remove outdated reviews 4 | conditions: 5 | - base=master 6 | actions: 7 | dismiss_reviews: 8 | 9 | # automatic merge for master when required CI passes 10 | - name: automatic merge for master when CI passes and 2 review 11 | conditions: 12 | - "#approved-reviews-by>=2" 13 | - label!=WIP 14 | actions: 15 | merge: 16 | method: squash 17 | # - name: automatic merge for master when CI passes and 2 review 18 | # conditions: 19 | # - "#approved-reviews-by>=3" 20 | # - 
approved-reviews-by=github-actions 21 | # - label!=WIP 22 | # actions: 23 | # merge: 24 | # method: squash 25 | 26 | # deletes the head branch of the pull request, that is the branch which hosts the commits 27 | - name: delete head branch after merge 28 | conditions: 29 | - merged 30 | actions: 31 | delete_head_branch: {} 32 | 33 | # ask author of PR to resolve conflict 34 | - name: ask to resolve conflict 35 | conditions: 36 | - conflict 37 | actions: 38 | comment: 39 | message: "This pull request is now in conflicts. @{{ author }}, could you fix it? 🙏" 40 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/catalyst-team/codestyle 3 | rev: 'v21.09.2' 4 | hooks: 5 | - id: catalyst-make-codestyle 6 | args: [--line-length=89] 7 | - repo: https://github.com/catalyst-team/codestyle 8 | rev: 'v21.09.2' 9 | hooks: 10 | - id: catalyst-check-codestyle 11 | args: [--line-length=89] 12 | exclude: __init__.py 13 | -------------------------------------------------------------------------------- /CITATION: -------------------------------------------------------------------------------- 1 | @misc{catalyst, 2 | author = {Kolesnikov, Sergey}, 3 | title = {Catalyst - Accelerated deep learning R&D}, 4 | year = {2018}, 5 | publisher = {GitHub}, 6 | journal = {GitHub repository}, 7 | howpublished = {\url{https://github.com/catalyst-team/catalyst}}, 8 | } -------------------------------------------------------------------------------- /MANIFEST.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | [![Catalyst logo](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/catalyst_logo.png)](https://github.com/catalyst-team/catalyst) 4 | 5 | ### Manifest | Value | Mission 6 |
7 | 8 | ----- 9 | 10 | Writing an Open Source Ecosystem for DL/RL research is challenging.
11 | Let's explain why we are doing this. 12 | 13 | ### Why? 14 | For real breakthroughs in the DL/RL area, we need a Foundation.
15 | For a DL/RL Foundation, we need an Open Source R&D Ecosystem.
16 | Catalyst was designed by researchers, for researchers, 17 | - to get a well-tested and verified foundation for DL & RL research 18 | - to generalize and develop best practices that work in all areas of DL 19 | - to think less about the tech side and focus on research hypotheses and their value 20 | - to accelerate your research 21 | 22 | ### What? 23 | We are creating a research Ecosystem, which 24 | - combines research best practices and helps knowledge sharing 25 | - connects DL/RL with Software Engineering for maximum performance 26 | - allows you to quickly and efficiently validate your hypotheses 27 | 28 | ### How? 29 | - open – it's an entirely Open Source Ecosystem 30 | - equality – everyone can contribute, propose new ideas, and have their own opinions 31 | - expertise – we are gathering top knowledge into one place 32 | - high-performance – we are developing the ecosystem with SE best practices 33 | 34 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: check-docs docker docker-dev install-from-source clean 2 | 3 | PYTHON ?= python 4 | TAG=$(shell ${PYTHON} docker/collect_dependencies_hash.py) 5 | 6 | check-docs: 7 | bash ./bin/workflows/check_docs.sh 8 | 9 | docker: 10 | echo building $${REPO_NAME:-catalyst}:$${TAG:-latest} ... 11 | docker build \ 12 | -t $${REPO_NAME:-catalyst}:$${TAG:-latest} . \ 13 | -f ./docker/Dockerfile --no-cache \ 14 | --build-arg PYTORCH_TAG=$$PYTORCH_TAG \ 15 | --build-arg CATALYST_DEV=$$CATALYST_DEV \ 16 | --build-arg CATALYST_CV=$$CATALYST_CV \ 17 | --build-arg CATALYST_ML=$$CATALYST_ML \ 18 | --build-arg CATALYST_OPTUNA=$$CATALYST_OPTUNA \ 19 | --build-arg CATALYST_ONNX=$$CATALYST_ONNX \ 20 | --build-arg CATALYST_ONNX_GPU=$$CATALYST_ONNX_GPU 21 | 22 | docker-dev: 23 | echo building $${REPO_NAME:-catalyst-dev}:$${TAG:-latest} ... 24 | docker build \ 25 | -t $${REPO_NAME:-catalyst-dev}:$${TAG:-latest} . 
\ 26 | -f ./docker/Dockerfile --no-cache \ 27 | --build-arg PYTORCH_TAG=$$PYTORCH_TAG \ 28 | --build-arg CATALYST_DEV=$$CATALYST_DEV \ 29 | --build-arg CATALYST_CV=$$CATALYST_CV \ 30 | --build-arg CATALYST_ML=$$CATALYST_ML \ 31 | --build-arg CATALYST_OPTUNA=$$CATALYST_OPTUNA \ 32 | --build-arg CATALYST_ONNX=$$CATALYST_ONNX \ 33 | --build-arg CATALYST_ONNX_GPU=$$CATALYST_ONNX_GPU 34 | 35 | install-from-source: 36 | pip uninstall catalyst -y && pip install -e ./ 37 | 38 | clean: 39 | rm -rf build/ 40 | docker rmi -f catalyst:latest 41 | docker rmi -f catalyst-dev:latest 42 | 43 | check: 44 | catalyst-make-codestyle -l 89 45 | catalyst-check-codestyle -l 89 46 | 47 | -------------------------------------------------------------------------------- /bin/scripts/download-gdrive: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # usage: 4 | # download-gdrive {FILE_ID} {FILENAME} 5 | 6 | if [[ "$(uname)" == "Darwin" ]]; then 7 | CONFIRM=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | gsed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p') 8 | elif [[ "$(expr substr $(uname -s) 1 5)" == "Linux" ]]; then 9 | CONFIRM=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p') 10 | fi 11 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$CONFIRM&id=$1" -O $2 12 | rm -rf /tmp/cookies.txt 13 | -------------------------------------------------------------------------------- /bin/scripts/extract-archive: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [[ -z "$1" ]]; then 4 | # display usage if no parameters given 5 | echo "Usage: extract-archive <path/file_name>.<ext>" 
6 | echo " extract-archive <path/file_name_1.ext> [path/file_name_2.ext] [path/file_name_3.ext]" 7 | else 8 | for n in "$@" 9 | do 10 | if [ -f "$n" ] ; then 11 | case "${n%,}" in 12 | *.tar.bz2|*.tar.gz|*.tar.xz|*.tbz2|*.tgz|*.txz|*.tar) 13 | tar xvf "$n" ;; 14 | *.lzma) unlzma ./"$n" ;; 15 | *.bz2) bunzip2 ./"$n" ;; 16 | *.rar) unrar x -ad ./"$n" ;; 17 | *.gz) gunzip ./"$n" ;; 18 | *.zip) unzip ./"$n" ;; 19 | *.z) uncompress ./"$n" ;; 20 | *.7z|*.arj|*.cab|*.chm|*.deb|*.dmg|*.iso|*.lzh|*.msi|*.rpm|*.udf|*.wim|*.xar) 21 | 7z x ./"$n" ;; 22 | *.xz) unxz ./"$n" ;; 23 | *.exe) cabextract ./"$n" ;; 24 | *.cpio) cpio -id < ./"$n" ;; 25 | *) 26 | echo "extract-archive: '$n' - unknown archive method" 27 | exit 1 28 | ;; 29 | esac 30 | else 31 | echo "'$n' - file does not exist" 32 | exit 1 33 | fi 34 | done 35 | fi 36 | -------------------------------------------------------------------------------- /bin/scripts/install-apex: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Cause the script to exit if a single command fails 4 | set -eo pipefail -v 5 | 6 | 7 | ################################### APEX #################################### 8 | pip install -v --no-cache-dir \ 9 | --global-option="--cpp_ext" --global-option="--cuda_ext" \ 10 | git+https://github.com/NVIDIA/apex 11 | -------------------------------------------------------------------------------- /bin/workflows/check_config_api.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Cause the script to exit if a single command fails 4 | set -eo pipefail -v 5 | 6 | DEVICE=${DEVICE:="cpu"} 7 | 8 | pip install -e . --quiet --no-deps --upgrade-strategy only-if-needed 9 | 10 | ################################ pipeline 00 ################################ 11 | # checking `catalyst-run` console script entry point 12 | 13 | PYTHONPATH="${PYTHONPATH}:." catalyst-run \ 14 | -C "tests/pipelines/configs/test_mnist.yml" \ 15 | "tests/pipelines/configs/engine_${DEVICE}.yml" 16 | 17 | rm -rf tests/logs 18 | 19 | ################################ pipeline 01 ################################ 20 | # checking `catalyst-tune` console script entry point 21 | 22 | pip install -r requirements/requirements-optuna.txt --quiet \ 23 | --find-links https://download.pytorch.org/whl/cpu/torch_stable.html \ 24 | --upgrade-strategy only-if-needed 25 | 26 | PYTHONPATH="${PYTHONPATH}:." catalyst-tune \ 27 | -C "tests/contrib/scripts/test_tune.yml" \ 28 | "tests/pipelines/configs/engine_${DEVICE}.yml" \ 29 | --n-trials 2 30 | 31 | rm -rf tests/logs 32 | -------------------------------------------------------------------------------- /bin/workflows/check_docs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | mkdir -p builds/ 4 | 5 | sphinx-build -b html ./docs/ builds/ -W --keep-going 6 | CODE=$?
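# -W promotes Sphinx warnings to errors and --keep-going reports all of them instead of stopping at the first one; the exit status is captured above so builds/ can be cleaned up before the script re-raises it below.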
7 | 8 | if [[ -z $REMOVE_BUILDS || $REMOVE_BUILDS -eq 1 ]]; then 9 | rm -rf builds/ 10 | fi 11 | 12 | echo "#### CODE: ${CODE} ####" 13 | exit ${CODE} 14 | -------------------------------------------------------------------------------- /bin/workflows/check_pipelines.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Cause the script to exit if a single command fails 4 | set -eo pipefail -v 5 | 6 | CPU_REQUIRED=1 MASTER_ADDR="127.0.0.1" MASTER_PORT=2112 CATALYST_COMPUTE_PER_CLASS_METRICS="1" OMP_NUM_THREADS="1" MKL_NUM_THREADS="1" pytest ./tests/pipelines 7 | 8 | GPU_REQUIRED=1 MASTER_ADDR="127.0.0.1" MASTER_PORT=2112 CATALYST_COMPUTE_PER_CLASS_METRICS="1" OMP_NUM_THREADS="1" MKL_NUM_THREADS="1" pytest ./tests/pipelines 9 | GPU_AMP_REQUIRED=1 MASTER_ADDR="127.0.0.1" MASTER_PORT=2112 CATALYST_COMPUTE_PER_CLASS_METRICS="1" OMP_NUM_THREADS="1" MKL_NUM_THREADS="1" pytest ./tests/pipelines 10 | 11 | DP_REQUIRED=1 MASTER_ADDR="127.0.0.1" MASTER_PORT=2112 CATALYST_COMPUTE_PER_CLASS_METRICS="1" OMP_NUM_THREADS="1" MKL_NUM_THREADS="1" pytest ./tests/pipelines 12 | DP_AMP_REQUIRED=1 MASTER_ADDR="127.0.0.1" MASTER_PORT=2112 CATALYST_COMPUTE_PER_CLASS_METRICS="1" OMP_NUM_THREADS="1" MKL_NUM_THREADS="1" pytest ./tests/pipelines 13 | 14 | DDP_REQUIRED=1 MASTER_ADDR="127.0.0.1" MASTER_PORT=2112 CATALYST_COMPUTE_PER_CLASS_METRICS="1" OMP_NUM_THREADS="1" MKL_NUM_THREADS="1" pytest ./tests/pipelines 15 | DDP_AMP_REQUIRED=1 MASTER_ADDR="127.0.0.1" MASTER_PORT=2112 CATALYST_COMPUTE_PER_CLASS_METRICS="1" OMP_NUM_THREADS="1" MKL_NUM_THREADS="1" pytest ./tests/pipelines 16 | -------------------------------------------------------------------------------- /bin/workflows/check_projector.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Cause the script to exit if a single command fails 4 | set -eo pipefail -v 5 | 6 | pip install -r requirements/requirements.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade-strategy only-if-needed 7 | 8 | ################################ pipeline 01 ################################ 9 | TEST_DIR='./logs/tb_check' 10 | 11 | python -c """ 12 | import os 13 | import csv 14 | import cv2 15 | import numpy as np 16 | 17 | num_images = 100 18 | dir = '$TEST_DIR' 19 | images_dir = os.path.join(dir, 'images') 20 | 21 | os.makedirs(images_dir) 22 | 23 | # images 24 | images = [] 25 | meta_info = [] 26 | for i in range(num_images): 27 | random_img = np.random.randn(64, 64, 3) 28 | img_path = os.path.join(images_dir, 'image_{}.jpeg'.format(i)) 29 | cv2.imwrite(img_path, random_img) 30 | images.append(img_path) 31 | meta_info.append(str(i)) 32 | 33 | # csv file 34 | with open(os.path.join(dir, 'images.csv'), 'w', newline='') as csv_file: 35 | fields = ['images', 'meta'] 36 | writer = csv.DictWriter(csv_file, fieldnames=fields) 37 | writer.writeheader() 38 | for img, meta in zip(images, meta_info): 39 | writer.writerow({'images': img, 'meta': meta}) 40 | 41 | # embedding 42 | vectors = np.random.randn(len(images), 128) 43 | np.save(os.path.join(dir, 'embeddings.npy'), vectors) 44 | """ 45 | 46 | PYTHONPATH=.:${PYTHONPATH} \ 47 | python catalyst/contrib/scripts/project_embeddings.py \ 48 | --in-npy="${TEST_DIR}/embeddings.npy" \ 49 | --in-csv="${TEST_DIR}/images.csv" \ 50 | --out-dir=${TEST_DIR} \ 51 | --img-size=64 \ 52 | --img-col='images' \ 53 | --meta-cols='meta' \ 54 | --img-rootpath='.' 
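# The projector call above consumes the throwaway fixtures produced by the inline Python block (random embeddings in .npy, image paths and metadata in .csv); TEST_DIR is removed next since these artifacts exist only for this smoke test.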
55 | 56 | rm -rf ${TEST_DIR} 57 | -------------------------------------------------------------------------------- /catalyst/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import warnings 3 | 4 | warnings.filterwarnings("ignore", message="numpy.dtype size changed", append=True) 5 | warnings.filterwarnings("ignore", module="tqdm", append=True) 6 | warnings.filterwarnings("once", append=True) 7 | warnings.filterwarnings( 8 | "ignore", message="This overload of add_ is deprecated", append=True 9 | ) 10 | 11 | from catalyst.__version__ import __version__ 12 | from catalyst.settings import SETTINGS 13 | -------------------------------------------------------------------------------- /catalyst/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "22.04" 2 | -------------------------------------------------------------------------------- /catalyst/callbacks/backward.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, TYPE_CHECKING, Union 2 | from functools import partial 3 | 4 | from catalyst.core.callback import IBackwardCallback 5 | from catalyst.registry import REGISTRY 6 | 7 | if TYPE_CHECKING: 8 | from catalyst.core.runner import IRunner 9 | 10 | 11 | class BackwardCallback(IBackwardCallback): 12 | """Optimizer callback, abstraction over backward step. 13 | 14 | Args: 15 | metric_key: a key to get loss from ``runner.batch_metrics`` 16 | grad_clip_fn: callable gradient clipping function or its name 17 | grad_clip_params: key-value parameters for grad_clip_fn 18 | log_gradient: boolean flag to log gradient norm to ``runner.batch_metrics`` 19 | 20 | .. note:: 21 | Please follow the `minimal examples`_ section for more use cases. 22 | 23 | .. 
_`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples # noqa: E501, W505 24 | """ 25 | 26 | def __init__( 27 | self, 28 | metric_key: str, 29 | grad_clip_fn: Union[str, Callable] = None, 30 | grad_clip_params: Dict = None, 31 | log_gradient: bool = False, 32 | ): 33 | """Init.""" 34 | super().__init__() 35 | self.metric_key = metric_key 36 | 37 | if isinstance(grad_clip_fn, str): 38 | self.grad_clip_fn = REGISTRY.get(grad_clip_fn) 39 | else: 40 | self.grad_clip_fn = grad_clip_fn 41 | if grad_clip_params is not None: 42 | self.grad_clip_fn = partial(self.grad_clip_fn, **grad_clip_params) 43 | 44 | self._prefix_gradient = f"gradient/{metric_key}" 45 | self._log_gradient = log_gradient 46 | 47 | def on_batch_end(self, runner: "IRunner"): 48 | """Event handler.""" 49 | if runner.is_train_loader: 50 | loss = runner.batch_metrics[self.metric_key] 51 | runner.engine.backward(loss) 52 | 53 | if self.grad_clip_fn is not None: 54 | runner.engine.unscale_gradients() 55 | norm = self.grad_clip_fn(runner.model.parameters()) # the model lives on the runner, not on the callback 56 | if self._log_gradient: 57 | runner.batch_metrics[f"{self._prefix_gradient}/norm"] = norm 58 | 59 | 60 | __all__ = ["BackwardCallback"] 61 | -------------------------------------------------------------------------------- /catalyst/callbacks/criterion.py: -------------------------------------------------------------------------------- 1 | from catalyst.callbacks.metrics.functional_metric import FunctionalMetricCallback 2 | from catalyst.core.callback import ICriterionCallback 3 | from catalyst.core.runner import IRunner 4 | from catalyst.utils.misc import get_attr 5 | 6 | 7 | class CriterionCallback(FunctionalMetricCallback, ICriterionCallback): 8 | """Criterion callback, abstraction over criterion step. 9 | 10 | Args: 11 | input_key: input key to use for metric calculation, specifies our `y_pred` 12 | target_key: output key to use for metric calculation, specifies our `y_true` 13 | metric_key: key to store computed metric in ``runner.batch_metrics`` dictionary 14 | criterion_key: A key to take a criterion in case 15 | there are several of them, and they are in a dictionary format. 16 | 17 | .. note:: 18 | Please follow the `minimal examples`_ section for more use cases. 19 | 20 | .. 
_`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples # noqa: E501, W505 21 | """ 22 | 23 | def __init__( 24 | self, 25 | input_key: str, 26 | target_key: str, 27 | metric_key: str, 28 | criterion_key: str = None, 29 | prefix: str = None, 30 | suffix: str = None, 31 | ): 32 | """Init.""" 33 | super().__init__( 34 | input_key=input_key, 35 | target_key=target_key, 36 | metric_fn=self._metric_fn, 37 | metric_key=metric_key, 38 | compute_on_call=True, 39 | log_on_batch=True, 40 | prefix=prefix, 41 | suffix=suffix, 42 | ) 43 | self.criterion_key = criterion_key 44 | self.criterion = None 45 | 46 | def _metric_fn(self, *args, **kwargs): 47 | return self.criterion(*args, **kwargs) 48 | 49 | def on_experiment_start(self, runner: "IRunner"): 50 | """Event handler.""" 51 | self.criterion = get_attr(runner, key="criterion", inner_key=self.criterion_key) 52 | assert self.criterion is not None 53 | 54 | 55 | __all__ = ["CriterionCallback"] 56 | -------------------------------------------------------------------------------- /catalyst/callbacks/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from catalyst.settings import SETTINGS 4 | 5 | from catalyst.callbacks.metrics.accuracy import ( 6 | AccuracyCallback, 7 | MultilabelAccuracyCallback, 8 | ) 9 | from catalyst.callbacks.metrics.auc import AUCCallback 10 | 11 | from catalyst.callbacks.metrics.classification import ( 12 | PrecisionRecallF1SupportCallback, 13 | MultilabelPrecisionRecallF1SupportCallback, 14 | ) 15 | 16 | from catalyst.callbacks.metrics.cmc_score import CMCScoreCallback, ReidCMCScoreCallback 17 | 18 | if SETTINGS.ml_required: 19 | from catalyst.callbacks.metrics.confusion_matrix import ConfusionMatrixCallback 20 | 21 | from catalyst.callbacks.metrics.functional_metric import FunctionalMetricCallback 22 | 23 | from catalyst.callbacks.metrics.r2_squared import R2SquaredCallback 24 | 25 | from catalyst.callbacks.metrics.recsys import ( 26 | HitrateCallback, 27 | MAPCallback, 28 | MRRCallback, 29 | NDCGCallback, 30 | ) 31 | 32 | if SETTINGS.ml_required: 33 | from catalyst.callbacks.metrics.scikit_learn import ( 34 | SklearnBatchCallback, 35 | SklearnLoaderCallback, 36 | ) 37 | 38 | from catalyst.callbacks.metrics.segmentation import ( 39 | DiceCallback, 40 | IOUCallback, 41 | TrevskyCallback, 42 | ) 43 | -------------------------------------------------------------------------------- /catalyst/callbacks/metrics/functional_metric.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Iterable, Union 2 | 3 | from catalyst.callbacks.metric import FunctionalBatchMetricCallback 4 | from catalyst.metrics._functional_metric import FunctionalBatchMetric 5 | 6 | 7 | class FunctionalMetricCallback(FunctionalBatchMetricCallback): 8 | """ 9 | 10 | Args: 11 | input_key: input key to use for metric calculation, specifies our `y_pred` 12 | target_key: output key to use for metric calculation, specifies our `y_true` 13 | metric_fn: metric function that takes outputs and targets 14 | and returns the score as a torch.Tensor 15 | metric_key: key to store computed metric in ``runner.batch_metrics`` dictionary 16 | compute_on_call: Computes and returns metric value during metric call. 17 | Used for per-batch logging. 
default: True 18 | log_on_batch: boolean flag to log computed metrics every batch 19 | prefix: metric prefix 20 | suffix: metric suffix 21 | """ 22 | 23 | def __init__( 24 | self, 25 | input_key: Union[str, Iterable[str], Dict[str, str]], 26 | target_key: Union[str, Iterable[str], Dict[str, str]], 27 | metric_fn: Callable, 28 | metric_key: str, 29 | compute_on_call: bool = True, 30 | log_on_batch: bool = True, 31 | prefix: str = None, 32 | suffix: str = None, 33 | ): 34 | """Init.""" 35 | super().__init__( 36 | metric=FunctionalBatchMetric( 37 | metric_fn=metric_fn, 38 | metric_key=metric_key, 39 | compute_on_call=compute_on_call, 40 | prefix=prefix, 41 | suffix=suffix, 42 | ), 43 | input_key=input_key, 44 | target_key=target_key, 45 | log_on_batch=log_on_batch, 46 | ) 47 | 48 | 49 | __all__ = ["FunctionalMetricCallback"] 50 | -------------------------------------------------------------------------------- /catalyst/contrib/README.md: -------------------------------------------------------------------------------- 1 | # Catalyst.contrib 2 | 3 | Catalyst contrib modules are supported in the code-as-a-documentation format. If you are interested in the details - please follow the code of the implementation. If you are interested in contributing to the library - feel free to open a pull request. 4 | -------------------------------------------------------------------------------- /catalyst/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Since contrib is **not** very stable - do not import it all :) 4 | 5 | # from catalyst.contrib.data import * 6 | # from catalyst.contrib.datasets import * 7 | # from catalyst.contrib.layers import * 8 | # from catalyst.contrib.losses import * 9 | # from catalyst.contrib.models import * 10 | # from catalyst.contrib.optimizers import * 11 | # from catalyst.contrib.schedulers import * 12 | -------------------------------------------------------------------------------- /catalyst/contrib/data/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from catalyst.settings import SETTINGS 4 | 5 | from catalyst.contrib.data.collate_fn import FilteringCollateFn 6 | 7 | from catalyst.contrib.data.dataset import ( 8 | ListDataset, 9 | MergeDataset, 10 | NumpyDataset, 11 | PathsDataset, 12 | ) 13 | from catalyst.contrib.data.dataset_ml import ( 14 | MetricLearningTrainDataset, 15 | QueryGalleryDataset, 16 | ) 17 | 18 | from catalyst.contrib.data.reader import ( 19 | IReader, 20 | ScalarReader, 21 | LambdaReader, 22 | ReaderCompose, 23 | ) 24 | 25 | from catalyst.contrib.data.sampler_inbatch import ( 26 | IInbatchTripletSampler, 27 | InBatchTripletsSampler, 28 | AllTripletsSampler, 29 | HardTripletsSampler, 30 | HardClusterSampler, 31 | ) 32 | from catalyst.contrib.data.sampler import BalanceBatchSampler, DynamicBalanceClassSampler 33 | 34 | from catalyst.contrib.data.transforms import ( 35 | image_to_tensor, 36 | normalize_image, 37 | Compose, 38 | ImageToTensor, 39 | NormalizeImage, 40 | ) 41 | 42 | if SETTINGS.cv_required: 43 | from catalyst.contrib.data.dataset_cv import ImageFolderDataset 44 | from catalyst.contrib.data.reader_cv import ImageReader, MaskReader 45 | 46 | 47 | # if SETTINGS.nifti_required: 48 | # from catalyst.contrib.data.reader_nifti import NiftiReader 49 | -------------------------------------------------------------------------------- /catalyst/contrib/data/_misc.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, List, Union 2 | 3 | import numpy as np 4 | 5 | from torch import int as tint, long, short, Tensor 6 | 7 | 8 | def find_value_ids(it: Iterable[Any], value: Any) -> List[int]: 9 | """ 10 | Args: 11 | it: iterable of any values 12 | value: query element 13 | 14 | Returns: 15 | indices of all the elements equal to ``value`` 16 | """ 17 | if isinstance(it, np.ndarray): 18 | inds = list(np.where(it == value)[0]) 19 | else: # could be very slow 20 | inds = [i for i, el in enumerate(it) if el == value] 21 | return inds 22 | 23 | 24 | def convert_labels2list(labels: Union[Tensor, List[int]]) -> List[int]: 25 | """ 26 | This function allows working with two types of indexing: 27 | an integer tensor or a list of indices. 28 | 29 | Args: 30 | labels: labels of batch samples 31 | 32 | Returns: 33 | labels of batch samples in the aligned format 34 | 35 | Raises: 36 | TypeError: if the input labels are neither a tensor nor a list 37 | """ 38 | if isinstance(labels, Tensor): 39 | labels = labels.squeeze() 40 | assert (len(labels.shape) == 1) and ( 41 | labels.dtype in [short, tint, long] 42 | ), "Labels cannot be interpreted as indices." 43 | labels_list = labels.tolist() 44 | elif isinstance(labels, list): 45 | labels_list = labels.copy() 46 | else: 47 | raise TypeError(f"Unexpected type of labels: {type(labels)}).") 48 | 49 | return labels_list 50 | 51 | 52 | __all__ = ["find_value_ids", "convert_labels2list"] 53 | -------------------------------------------------------------------------------- /catalyst/contrib/data/collate_fn.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | 3 | from torch.utils.data.dataloader import default_collate 4 | 5 | 6 | class FilteringCollateFn: 7 | """ 8 | Callable object doing the job of ``collate_fn`` like ``default_collate``, 9 | but without casting batch items with the specified keys to `torch.Tensor`. 10 | 11 | It only adds them to a list. 12 | Supports only key-value format batches. 13 | """ 14 | 15 | def __init__(self, *keys): 16 | """ 17 | Args: 18 | keys: Keys for values that will not be converted to tensor and stacked 19 | """ 20 | self.keys = keys 21 | 22 | def __call__(self, batch): 23 | """ 24 | Args: 25 | batch: current batch 26 | 27 | Returns: 28 | batch values filtered by `keys` 29 | """ 30 | if isinstance(batch[0], collections.abc.Mapping): 31 | result = {} 32 | for key in batch[0]: 33 | items = [d[key] for d in batch] 34 | if key not in self.keys: 35 | items = default_collate(items) 36 | result[key] = items 37 | return result 38 | else: 39 | return default_collate(batch) 40 | 41 | 42 | __all__ = ["FilteringCollateFn"] 43 | -------------------------------------------------------------------------------- /catalyst/contrib/data/dataset_ml.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | from abc import ABC, abstractmethod 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class MetricLearningTrainDataset(Dataset, ABC): 9 | """ 10 | Base class for datasets adapted for 11 | the metric learning train stage. 12 | """ 13 | 14 | @abstractmethod 15 | def get_labels(self) -> List[int]: 16 | """ 17 | A dataset for metric learning must provide 18 | the label of each sample so that positive 19 | and negative pairs can be formed during training 20 | based on these labels. 
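A minimal hedged sketch of a conforming subclass (the in-memory data layout here is purely illustrative):

>>> class ToyMLDataset(MetricLearningTrainDataset):
...     def __init__(self, features, labels):
...         self.features, self.labels = features, labels
...     def __getitem__(self, index):
...         return self.features[index], self.labels[index]
...     def __len__(self):
...         return len(self.features)
...     def get_labels(self) -> List[int]:
...         return list(self.labels)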
21 | 22 | Raises: 23 | NotImplementedError: You should implement it # noqa: DAR402 24 | """ 25 | raise NotImplementedError() 26 | 27 | 28 | class QueryGalleryDataset(Dataset, ABC): 29 | """ 30 | QueryGalleryDataset for CMCScoreCallback 31 | """ 32 | 33 | @abstractmethod 34 | def __getitem__(self, item) -> Dict[str, torch.Tensor]: 35 | """ 36 | Dataset for query/gallery split should 37 | return a dict with `feature`, `targets` and 38 | `is_query` keys. The value for the `is_query` key should 39 | be boolean and indicate whether the current object 40 | is in the query or in the gallery. 41 | 42 | Args: 43 | item: Item 44 | 45 | Raises: 46 | NotImplementedError: You should implement it # noqa: DAR402 47 | """ 48 | raise NotImplementedError() 49 | 50 | @property 51 | @abstractmethod 52 | def query_size(self) -> int: 53 | """ 54 | Query/Gallery dataset should have a 55 | query size property. 56 | 57 | Returns: 58 | query size # noqa: DAR202 59 | 60 | Raises: 61 | NotImplementedError: You should implement it # noqa: DAR402 62 | """ 63 | raise NotImplementedError() 64 | 65 | @property 66 | @abstractmethod 67 | def gallery_size(self) -> int: 68 | """ 69 | Query/Gallery dataset should have a 70 | gallery size property. 71 | 72 | Returns: 73 | gallery size # noqa: DAR202 74 | 75 | Raises: 76 | NotImplementedError: You should implement it # noqa: DAR402 77 | """ 78 | raise NotImplementedError() 79 | 80 | 81 | __all__ = ["MetricLearningTrainDataset", "QueryGalleryDataset"] 82 | -------------------------------------------------------------------------------- /catalyst/contrib/datasets/README.md: -------------------------------------------------------------------------------- 1 | Some parts of this subpackage were borrowed from [torchvision](https://github.com/pytorch/vision) and [fastai/imagenette](https://github.com/fastai/imagenette). 2 | -------------------------------------------------------------------------------- /catalyst/contrib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from catalyst.settings import SETTINGS 4 | 5 | from catalyst.contrib.datasets.cifar import CIFAR10, CIFAR100 6 | 7 | if SETTINGS.cv_required: 8 | from catalyst.contrib.datasets.imagecar import CarvanaOneCarDataset 9 | 10 | from catalyst.contrib.datasets.imagenette import ( 11 | Imagenette, 12 | Imagenette160, 13 | Imagenette320, 14 | ) 15 | from catalyst.contrib.datasets.imagewang import ( 16 | Imagewang, 17 | Imagewang160, 18 | Imagewang320, 19 | ) 20 | from catalyst.contrib.datasets.imagewoof import ( 21 | Imagewoof, 22 | Imagewoof160, 23 | Imagewoof320, 24 | ) 25 | 26 | from catalyst.contrib.datasets.market1501 import ( 27 | Market1501MLDataset, 28 | Market1501QGDataset, 29 | ) 30 | 31 | from catalyst.contrib.datasets.mnist import ( 32 | MnistMLDataset, 33 | MnistQGDataset, 34 | MNIST, 35 | PartialMNIST, 36 | ) 37 | 38 | if SETTINGS.ml_required: 39 | from catalyst.contrib.datasets.movielens import MovieLens 40 | 41 | if SETTINGS.is_torch_1_7_0: 42 | from catalyst.contrib.datasets.movielens import MovieLens20M 43 | -------------------------------------------------------------------------------- /catalyst/contrib/datasets/imagenette.py: -------------------------------------------------------------------------------- 1 | from catalyst.contrib.datasets.misc_cv import ImageClassificationDataset 2 | 3 | 4 | class Imagenette(ImageClassificationDataset): 5 | """ 6 | `Imagenette <https://github.com/fastai/imagenette>`_ Dataset. 7 | 8 | .. note:: 9 | catalyst[cv] required for this dataset. 
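Example — a hedged sketch; the ``root``/``download`` constructor arguments are assumed from the shared ``ImageClassificationDataset`` base rather than spelled out in this file:

>>> from catalyst.contrib.datasets import Imagenette
>>> dataset = Imagenette(root=".", download=True)  # assumed signature; fetches the archive on first use
>>> len(dataset)  # number of samples in the loaded split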
10 | """ 11 | 12 | name = "imagenette2" 13 | resources = [ 14 | ( 15 | "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz", 16 | "fe2fc210e6bb7c5664d602c3cd71e612", 17 | ) 18 | ] 19 | 20 | 21 | class Imagenette160(ImageClassificationDataset): 22 | """ 23 | `Imagenette `_ Dataset 24 | with images resized so that the shortest size is 160 px. 25 | 26 | .. note:: 27 | catalyst[cv] required for this dataset. 28 | """ 29 | 30 | name = "imagenette2-160" 31 | resources = [ 32 | ( 33 | "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz", 34 | "e793b78cc4c9e9a4ccc0c1155377a412", 35 | ) 36 | ] 37 | 38 | 39 | class Imagenette320(ImageClassificationDataset): 40 | """ 41 | `Imagenette `_ Dataset 42 | with images resized so that the shortest size is 320 px. 43 | 44 | .. note:: 45 | catalyst[cv] required for this dataset. 46 | """ 47 | 48 | name = "imagenette2-320" 49 | resources = [ 50 | ( 51 | "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz", 52 | "3df6f0d01a2c9592104656642f5e78a3", 53 | ) 54 | ] 55 | 56 | 57 | __all__ = ["Imagenette", "Imagenette160", "Imagenette320"] 58 | -------------------------------------------------------------------------------- /catalyst/contrib/datasets/imagewang.py: -------------------------------------------------------------------------------- 1 | from catalyst.contrib.datasets.misc_cv import ImageClassificationDataset 2 | 3 | 4 | class Imagewang(ImageClassificationDataset): 5 | """ 6 | `Imagewang `_ Dataset. 7 | 8 | .. note:: 9 | catalyst[cv] required for this dataset. 10 | """ 11 | 12 | name = "imagewang" 13 | resources = [ 14 | ( 15 | "https://s3.amazonaws.com/fast-ai-imageclas/imagewang.tgz", 16 | "46f9749616a29837e7cd67b103396f6e", 17 | ) 18 | ] 19 | 20 | 21 | class Imagewang160(ImageClassificationDataset): 22 | """ 23 | `Imagewang `_ Dataset 24 | with images resized so that the shortest size is 160 px. 25 | 26 | .. note:: 27 | catalyst[cv] required for this dataset. 28 | """ 29 | 30 | name = "imagewang-160" 31 | resources = [ 32 | ( 33 | "https://s3.amazonaws.com/fast-ai-imageclas/imagewang-160.tgz", 34 | "1dc388d37d1dc52836c06749e14e37bc", 35 | ) 36 | ] 37 | 38 | 39 | class Imagewang320(ImageClassificationDataset): 40 | """ 41 | `Imagewang `_ Dataset 42 | with images resized so that the shortest size is 320 px. 43 | 44 | .. note:: 45 | catalyst[cv] required for this dataset. 46 | """ 47 | 48 | name = "imagewang-320" 49 | resources = [ 50 | ( 51 | "https://s3.amazonaws.com/fast-ai-imageclas/imagewang-320.tgz", 52 | "ff01d7c126230afce776bdf72bda87e6", 53 | ) 54 | ] 55 | 56 | 57 | __all__ = ["Imagewang", "Imagewang160", "Imagewang320"] 58 | -------------------------------------------------------------------------------- /catalyst/contrib/datasets/imagewoof.py: -------------------------------------------------------------------------------- 1 | from catalyst.contrib.datasets.misc_cv import ImageClassificationDataset 2 | 3 | 4 | class Imagewoof(ImageClassificationDataset): 5 | """ 6 | `Imagewoof `_ Dataset. 7 | 8 | .. note:: 9 | catalyst[cv] required for this dataset. 10 | """ 11 | 12 | name = "imagewoof2" 13 | resources = [ 14 | ( 15 | "https://s3.amazonaws.com/fast-ai-imageclas/imagewoof2.tgz", 16 | "9aafe18bcdb1632c4249a76c458465ba", 17 | ) 18 | ] 19 | 20 | 21 | class Imagewoof160(ImageClassificationDataset): 22 | """ 23 | `Imagewoof `_ Dataset 24 | with images resized so that the shortest size is 160 px. 25 | 26 | .. note:: 27 | catalyst[cv] required for this dataset. 
28 | """ 29 | 30 | name = "imagewoof2-160" 31 | resources = [ 32 | ( 33 | "https://s3.amazonaws.com/fast-ai-imageclas/imagewoof2-160.tgz", 34 | "3d200a7be99704a0d7509be2a9fbfe15", 35 | ) 36 | ] 37 | 38 | 39 | class Imagewoof320(ImageClassificationDataset): 40 | """ 41 | `Imagewoof `_ Dataset 42 | with images resized so that the shortest size is 320 px. 43 | 44 | .. note:: 45 | catalyst[cv] required for this dataset. 46 | """ 47 | 48 | name = "imagewoof2-320" 49 | resources = [ 50 | ( 51 | "https://s3.amazonaws.com/fast-ai-imageclas/imagewoof2-320.tgz", 52 | "0f46d997ec2264e97609196c95897a44", 53 | ) 54 | ] 55 | 56 | 57 | __all__ = ["Imagewoof", "Imagewoof160", "Imagewoof320"] 58 | -------------------------------------------------------------------------------- /catalyst/contrib/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from torch.nn.modules import * 3 | 4 | from catalyst.contrib.layers.amsoftmax import AMSoftmax 5 | from catalyst.contrib.layers.arcface import ArcFace, SubCenterArcFace 6 | from catalyst.contrib.layers.arcmargin import ArcMarginProduct 7 | from catalyst.contrib.layers.common import ( 8 | Flatten, 9 | GaussianNoise, 10 | Lambda, 11 | Normalize, 12 | ResidualBlock, 13 | ) 14 | from catalyst.contrib.layers.cosface import CosFace, AdaCos 15 | from catalyst.contrib.layers.curricularface import CurricularFace 16 | from catalyst.contrib.layers.factorized import FactorizedLinear 17 | from catalyst.contrib.layers.lama import ( 18 | LamaPooling, 19 | TemporalLastPooling, 20 | TemporalAvgPooling, 21 | TemporalMaxPooling, 22 | TemporalDropLastWrapper, 23 | TemporalAttentionPooling, 24 | TemporalConcatPooling, 25 | ) 26 | from catalyst.contrib.layers.pooling import ( 27 | GlobalAttnPool2d, 28 | GlobalAvgAttnPool2d, 29 | GlobalAvgPool2d, 30 | GlobalConcatAttnPool2d, 31 | GlobalConcatPool2d, 32 | GlobalMaxAttnPool2d, 33 | GlobalMaxPool2d, 34 | GeM2d, 35 | ) 36 | from catalyst.contrib.layers.rms_norm import RMSNorm 37 | from catalyst.contrib.layers.se import ( 38 | sSE, 39 | scSE, 40 | cSE, 41 | ) 42 | from catalyst.contrib.layers.softmax import SoftMax 43 | -------------------------------------------------------------------------------- /catalyst/contrib/layers/arcmargin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ArcMarginProduct(nn.Module): 7 | """Implementation of Arc Margin Product. 8 | 9 | Args: 10 | in_features: size of each input sample. 11 | out_features: size of each output sample. 12 | 13 | Shape: 14 | - Input: :math:`(batch, H_{in})` where 15 | :math:`H_{in} = in\_features`. 16 | - Output: :math:`(batch, H_{out})` where 17 | :math:`H_{out} = out\_features`. 
18 | 19 | Example: 20 | >>> layer = ArcMarginProduct(5, 10) 21 | >>> loss_fn = nn.CrossEntropyLoss() 22 | >>> embedding = torch.randn(3, 5, requires_grad=True) 23 | >>> target = torch.empty(3, dtype=torch.long).random_(10) 24 | >>> output = layer(embedding) 25 | >>> loss = loss_fn(output, target) 26 | >>> loss.backward() 27 | 28 | """ 29 | 30 | def __init__(self, in_features: int, out_features: int): # noqa: D107 31 | super(ArcMarginProduct, self).__init__() 32 | self.in_features = in_features 33 | self.out_features = out_features 34 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 35 | nn.init.xavier_uniform_(self.weight) 36 | 37 | def __repr__(self) -> str: 38 | """Object representation.""" 39 | rep = ( 40 | "ArcMarginProduct(" 41 | f"in_features={self.in_features}," 42 | f"out_features={self.out_features}" 43 | ")" 44 | ) 45 | return rep 46 | 47 | def forward(self, input: torch.Tensor) -> torch.Tensor: 48 | """ 49 | Args: 50 | input: input features, 51 | expected shapes ``BxF`` where ``B`` 52 | is batch dimension and ``F`` is an 53 | input feature dimension. 54 | 55 | Returns: 56 | tensor (logits) with shapes ``BxC`` 57 | where ``C`` is a number of classes 58 | (out_features). 59 | """ 60 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 61 | return cosine 62 | 63 | 64 | __all__ = ["ArcMarginProduct"] 65 | -------------------------------------------------------------------------------- /catalyst/contrib/layers/factorized.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class FactorizedLinear(nn.Module): 8 | """Factorized wrapper for ``nn.Linear`` 9 | 10 | Args: 11 | nn_linear: torch ``nn.Linear`` module 12 | dim_ratio: dimension ratio to use after the weights' SVD 13 | """ 14 | 15 | def __init__(self, nn_linear: nn.Linear, dim_ratio: Union[int, float] = 1.0): 16 | super().__init__() 17 | self.bias = nn.parameter.Parameter(nn_linear.bias.data, requires_grad=True) 18 | u, vh = self._spectral_init(nn_linear.weight.data, dim_ratio=dim_ratio) 19 | # print(f"Doing SVD of tensor {nn_linear.weight.shape}, 20 | # U: {u.shape}, Vh: {vh.shape}") 21 | self.u = nn.parameter.Parameter(u, requires_grad=True) 22 | self.vh = nn.parameter.Parameter(vh, requires_grad=True) 23 | self.dim_ratio = dim_ratio 24 | self.in_features = u.size(0) 25 | self.out_features = vh.size(1) 26 | 27 | @staticmethod 28 | def _spectral_init(m, dim_ratio: Union[int, float] = 1): 29 | u, s, vh = torch.linalg.svd(m, full_matrices=False) 30 | u = u @ torch.diag(torch.sqrt(s)) 31 | vh = torch.diag(torch.sqrt(s)) @ vh 32 | if dim_ratio < 1: 33 | dims = int(u.size(1) * dim_ratio) 34 | u = u[:, :dims] 35 | vh = vh[:dims, :] 36 | # s_share = s[:dims].sum() / s.sum() * 100 37 | # print(f"SVD eigenvalue share {s_share:.2f}%") 38 | return u, vh 39 | 40 | def extra_repr(self) -> str: 41 | """Extra representation log.""" 42 | return ( 43 | f"in_features={self.in_features}, " 44 | f"out_features={self.out_features}, " 45 | f"bias=True, dim_ratio={self.dim_ratio}" 46 | ) 47 | 48 | def forward(self, x: torch.Tensor): 49 | """Forward call.""" 50 | return x @ (self.u @ self.vh).transpose(0, 1) + self.bias 51 | 52 | 53 | __all__ = ["FactorizedLinear"] 54 | -------------------------------------------------------------------------------- /catalyst/contrib/layers/rms_norm.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # 
@TODO: code formatting issue for 20.07 release 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class RMSNorm(nn.Module): 8 | """An implementation of RMS Normalization. 9 | 10 | @TODO: Docs (link to paper). Contribution is welcome. 11 | """ 12 | 13 | def __init__(self, dimension: int, epsilon: float = 1e-8, is_bias: bool = False): 14 | """ 15 | Args: 16 | dimension: the dimension of the layer output to normalize 17 | epsilon: an epsilon to prevent dividing by zero 18 | in case the layer has zero variance. (default = 1e-8) 19 | is_bias: whether to include a bias term 20 | in the normalization 21 | """ 22 | super().__init__() 23 | self.dimension = dimension 24 | self.epsilon = epsilon 25 | self.is_bias = is_bias 26 | self.scale = nn.Parameter(torch.ones(self.dimension)) 27 | if self.is_bias: 28 | self.bias = nn.Parameter(torch.zeros(self.dimension)) 29 | 30 | def forward(self, x: torch.Tensor) -> torch.Tensor: 31 | """@TODO: Docs. Contribution is welcome.""" 32 | x_std = torch.sqrt(torch.mean(x ** 2, -1, keepdim=True)) 33 | x_norm = x / (x_std + self.epsilon) 34 | if self.is_bias: 35 | return self.scale * x_norm + self.bias 36 | return self.scale * x_norm 37 | 38 | 39 | __all__ = ["RMSNorm"] 40 | -------------------------------------------------------------------------------- /catalyst/contrib/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | import torch 4 | from torch.nn.modules.loss import * 5 | 6 | from catalyst.contrib.losses.ce import ( 7 | MaskCrossEntropyLoss, 8 | NaiveCrossEntropyLoss, 9 | SymmetricCrossEntropyLoss, 10 | ) 11 | from catalyst.contrib.losses.circle import CircleLoss 12 | from catalyst.contrib.losses.contrastive import ( 13 | BarlowTwinsLoss, 14 | ContrastiveDistanceLoss, 15 | ContrastiveEmbeddingLoss, 16 | ContrastivePairwiseEmbeddingLoss, 17 | ) 18 | from catalyst.contrib.losses.dice import DiceLoss 19 | from catalyst.contrib.losses.focal import FocalLossBinary, FocalLossMultiClass 20 | from catalyst.contrib.losses.gan import GradientPenaltyLoss, MeanOutputLoss 21 | 22 | from catalyst.contrib.losses.iou import IoULoss 23 | from catalyst.contrib.losses.lovasz import ( 24 | LovaszLossBinary, 25 | LovaszLossMultiClass, 26 | LovaszLossMultiLabel, 27 | ) 28 | from catalyst.contrib.losses.margin import MarginLoss 29 | from catalyst.contrib.losses.ntxent import NTXentLoss 30 | from catalyst.contrib.losses.recsys import ( 31 | AdaptiveHingeLoss, 32 | BPRLoss, 33 | HingeLoss, 34 | LogisticLoss, 35 | RocStarLoss, 36 | WARPLoss, 37 | ) 38 | from catalyst.contrib.losses.regression import ( 39 | HuberLossV0, 40 | CategoricalRegressionLoss, 41 | QuantileRegressionLoss, 42 | RSquareLoss, 43 | ) 44 | from catalyst.contrib.losses.supervised_contrastive import SupervisedContrastiveLoss 45 | from catalyst.contrib.losses.smoothing_dice import SmoothingDiceLoss 46 | from catalyst.contrib.losses.trevsky import FocalTrevskyLoss, TrevskyLoss 47 | from catalyst.contrib.losses.triplet import ( 48 | TripletLoss, 49 | TripletLossV2, 50 | TripletMarginLossWithSampler, 51 | TripletPairwiseEmbeddingLoss, 52 | ) 53 | from catalyst.contrib.losses.wing import WingLoss 54 | -------------------------------------------------------------------------------- /catalyst/contrib/losses/dice.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from functools import partial 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from 
catalyst.metrics.functional import dice 8 | 9 | 10 | class DiceLoss(nn.Module): 11 | """The Dice loss. 12 | DiceLoss = 1 - dice score 13 | dice score = 2 * intersection / (intersection + union) = \ 14 | 2 * tp / (2 * tp + fp + fn) 15 | """ 16 | 17 | def __init__( 18 | self, 19 | class_dim: int = 1, 20 | mode: str = "macro", 21 | weights: List[float] = None, 22 | eps: float = 1e-7, 23 | ): 24 | """ 25 | Args: 26 | class_dim: indicates the class dimension (K) for 27 | ``outputs`` and ``targets`` tensors (default = 1) 28 | mode: class summation strategy. Must be one of ['micro', 'macro', 29 | 'weighted']. If mode='micro', classes are ignored, and the metric 30 | is calculated generally. If mode='macro', the metric is 31 | calculated per class and then averaged over all classes. 32 | If mode='weighted', the metric is calculated per class and then 33 | summed over all classes with weights. 34 | weights: class weights (for mode="weighted") 35 | eps: epsilon to avoid zero division 36 | """ 37 | super().__init__() 38 | assert mode in ["micro", "macro", "weighted"] 39 | self.loss_fn = partial( 40 | dice, 41 | eps=eps, 42 | class_dim=class_dim, 43 | threshold=None, 44 | mode=mode, 45 | weights=weights, 46 | ) 47 | 48 | def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 49 | """Calculates the loss between ``outputs`` and ``targets`` tensors.""" 50 | dice_score = self.loss_fn(outputs, targets) 51 | return 1 - dice_score 52 | 53 | 54 | __all__ = ["DiceLoss"] 55 | -------------------------------------------------------------------------------- /catalyst/contrib/losses/gan.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class MeanOutputLoss(nn.Module): 7 | """ 8 | Criterion to compute a simple mean of the output, completely ignoring the target 9 | (maybe useful e.g. for WGAN real/fake validity averaging). 10 | """ 11 | 12 | def forward(self, output, target): 13 | """Compute criterion. 14 | @TODO: Docs (add typing). Contribution is welcome. 15 | """ 16 | return output.mean() 17 | 18 | 19 | class GradientPenaltyLoss(nn.Module): 20 | """Criterion to compute gradient penalty. 21 | 22 | WARN: SHOULD NOT BE RUN WITH CriterionCallback, 23 | use special GradientPenaltyCallback instead 24 | """ 25 | 26 | def forward(self, fake_data, real_data, critic, critic_condition_args): 27 | """Compute gradient penalty. 28 | @TODO: Docs. Contribution is welcome. 29 | """ 30 | device = real_data.device 31 | # Random weight term for interpolation between real and fake samples 32 | alpha = torch.rand((real_data.size(0), 1, 1, 1), device=device) 33 | # Get random interpolation between real and fake samples 34 | interpolates = (alpha * real_data + ((1 - alpha) * fake_data)).detach() 35 | interpolates.requires_grad_(True) 36 | with torch.set_grad_enabled(True): # to compute in validation mode 37 | d_interpolates = critic(interpolates, *critic_condition_args) 38 | 39 | fake = torch.ones((real_data.size(0), 1), device=device, requires_grad=False) 40 | # Get gradient w.r.t. 
interpolates 41 | gradients = torch.autograd.grad( 42 | outputs=d_interpolates, 43 | inputs=interpolates, 44 | grad_outputs=fake, 45 | create_graph=True, 46 | retain_graph=True, 47 | only_inputs=True, 48 | )[0] 49 | gradients = gradients.view(gradients.size(0), -1) 50 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() 51 | return gradient_penalty 52 | 53 | 54 | __all__ = ["MeanOutputLoss", "GradientPenaltyLoss"] 55 | -------------------------------------------------------------------------------- /catalyst/contrib/losses/iou.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from functools import partial 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from catalyst.metrics.functional import iou 8 | 9 | 10 | class IoULoss(nn.Module): 11 | """The intersection over union (Jaccard) loss. 12 | IOULoss = 1 - iou score 13 | iou score = intersection / union = tp / (tp + fp + fn) 14 | """ 15 | 16 | def __init__( 17 | self, 18 | class_dim: int = 1, 19 | mode: str = "macro", 20 | weights: List[float] = None, 21 | eps: float = 1e-7, 22 | ): 23 | """ 24 | Args: 25 | class_dim: indicates the class dimension (K) for 26 | ``outputs`` and ``targets`` tensors (default = 1) 27 | mode: class summation strategy. Must be one of ['micro', 'macro', 28 | 'weighted']. If mode='micro', classes are ignored, and the metric 29 | is calculated generally. If mode='macro', the metric is 30 | calculated per class and then averaged over all classes. 31 | If mode='weighted', the metric is calculated per class and then 32 | summed over all classes with weights. 33 | weights: class weights (for mode="weighted") 34 | eps: epsilon to avoid zero division 35 | """ 36 | super().__init__() 37 | assert mode in ["micro", "macro", "weighted"] 38 | self.loss_fn = partial( 39 | iou, eps=eps, class_dim=class_dim, threshold=None, mode=mode, weights=weights 40 | ) 41 | 42 | def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 43 | """Calculates the loss between ``outputs`` and ``targets`` tensors.""" 44 | iou_score = self.loss_fn(outputs, targets) 45 | return 1 - iou_score 46 | 47 | 48 | __all__ = ["IoULoss"] 49 | -------------------------------------------------------------------------------- /catalyst/contrib/losses/margin.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from catalyst.contrib.losses.functional import margin_loss 7 | 8 | 9 | class MarginLoss(nn.Module): 10 | """Margin loss criterion""" 11 | 12 | def __init__( 13 | self, 14 | alpha: float = 0.2, 15 | beta: float = 1.0, 16 | skip_labels: Union[int, List[int]] = -1, 17 | ): 18 | """ 19 | Margin loss constructor. 20 | 21 | Args: 22 | alpha: alpha 23 | beta: beta 24 | skip_labels (int or List[int]): labels to skip 25 | """ 26 | super().__init__() 27 | self.alpha = alpha 28 | self.beta = beta 29 | self.skip_labels = skip_labels 30 | 31 | def forward(self, embeddings: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 32 | """ 33 | Forward method for the margin loss. 
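Delegates to the functional ``margin_loss`` with the configured ``alpha``, ``beta`` and ``skip_labels``. 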
34 | 35 | Args: 36 | embeddings: tensor with embeddings 37 | targets: tensor with target labels 38 | 39 | Returns: 40 | computed loss 41 | """ 42 | return margin_loss( 43 | embeddings, 44 | targets, 45 | alpha=self.alpha, 46 | beta=self.beta, 47 | skip_labels=self.skip_labels, 48 | ) 49 | 50 | 51 | __all__ = ["MarginLoss"] 52 | -------------------------------------------------------------------------------- /catalyst/contrib/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from catalyst.settings import SETTINGS 4 | 5 | from catalyst.contrib.models.mnist import MnistBatchNormNet, MnistSimpleNet 6 | 7 | if SETTINGS.cv_required: 8 | from catalyst.contrib.models.resnet_encoder import ResnetEncoder 9 | -------------------------------------------------------------------------------- /catalyst/contrib/models/mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from catalyst.contrib.layers import Flatten, Normalize 5 | 6 | 7 | class MnistSimpleNet(nn.Module): 8 | """Simple MNIST convolutional network for test purposes.""" 9 | 10 | def __init__(self, out_features: int, normalize: bool = True): 11 | """ 12 | Args: 13 | out_features: size of the output tensor 14 | normalize: boolean flag to add normalize layer 15 | """ 16 | super().__init__() 17 | layers = [ 18 | nn.Conv2d(1, 32, 3, 1), 19 | nn.ReLU(), 20 | nn.Conv2d(32, 64, 3, 1), 21 | nn.ReLU(), 22 | nn.MaxPool2d(2), 23 | Flatten(), 24 | nn.Linear(9216, 128), 25 | nn.ReLU(), 26 | nn.Linear(128, out_features), 27 | ] 28 | if normalize: 29 | layers.append(Normalize()) 30 | self._net = nn.Sequential(*layers) 31 | 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | """ 34 | Args: 35 | x: input tensor of one-channel images with the size of [28 x 28] 36 | 37 | Returns: 38 | extracted features 39 | """ 40 | return self._net(x) 41 | 42 | 43 | class MnistBatchNormNet(nn.Module): 44 | """Simple MNIST convolutional network with batch norm layers for test purposes.""" 45 | 46 | def __init__(self, out_features: int): 47 | """ 48 | Args: 49 | out_features: size of the output tensor 50 | """ 51 | super().__init__() 52 | layers = [ 53 | nn.Conv2d(1, 32, 3, 1), 54 | nn.LeakyReLU(), 55 | nn.BatchNorm2d(32), 56 | nn.Conv2d(32, 64, 3, 1), 57 | nn.LeakyReLU(), 58 | nn.MaxPool2d(2), 59 | Flatten(), 60 | nn.BatchNorm1d(9216), 61 | nn.Linear(9216, 128), 62 | nn.LeakyReLU(), 63 | nn.Linear(128, out_features), 64 | nn.BatchNorm1d(out_features), 65 | ] 66 | 67 | self._net = nn.Sequential(*layers) 68 | 69 | def forward(self, x: torch.Tensor) -> torch.Tensor: 70 | """ 71 | Args: 72 | x: input tensor of one-channel images with the size of [28 x 28] 73 | 74 | Returns: 75 | extracted features 76 | """ 77 | return self._net(x) 78 | 79 | 80 | __all__ = ["MnistSimpleNet", "MnistBatchNormNet"] 81 | -------------------------------------------------------------------------------- /catalyst/contrib/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from torch.optim import * 3 | 4 | from catalyst.contrib.optimizers.adamp import AdamP 5 | from catalyst.contrib.optimizers.lamb import Lamb 6 | from catalyst.contrib.optimizers.lookahead import Lookahead 7 | from catalyst.contrib.optimizers.qhadamw import QHAdamW 8 | from catalyst.contrib.optimizers.radam import RAdam 9 | from catalyst.contrib.optimizers.ralamb import Ralamb 10 | from catalyst.contrib.optimizers.sgdp 
import SGDP 11 | -------------------------------------------------------------------------------- /catalyst/contrib/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from torch.optim.lr_scheduler import * 3 | 4 | from catalyst.contrib.schedulers.base import BaseScheduler, BatchScheduler 5 | from catalyst.contrib.schedulers.onecycle import OneCycleLRWithWarmup 6 | -------------------------------------------------------------------------------- /catalyst/contrib/schedulers/base.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from abc import ABC, abstractmethod 3 | 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | from catalyst.utils.torch import set_optimizer_momentum 7 | 8 | 9 | class BaseScheduler(_LRScheduler, ABC): 10 | """Base class for all schedulers with momentum update.""" 11 | 12 | @abstractmethod 13 | def get_momentum(self) -> List[float]: 14 | """Function that returns the new momentum for optimizer.""" 15 | pass 16 | 17 | def step(self, epoch: Optional[int] = None) -> None: 18 | """Make one scheduler step. 19 | 20 | Args: 21 | epoch (int, optional): current epoch num 22 | """ 23 | super().step(epoch) 24 | momentums = self.get_momentum() 25 | for i, momentum in enumerate(momentums): 26 | set_optimizer_momentum(self.optimizer, momentum, index=i) 27 | 28 | 29 | class BatchScheduler(BaseScheduler, ABC): 30 | """@TODO: Docs. Contribution is welcome.""" 31 | 32 | 33 | __all__ = ["BaseScheduler", "BatchScheduler"] 34 | -------------------------------------------------------------------------------- /catalyst/contrib/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/catalyst/contrib/scripts/__init__.py -------------------------------------------------------------------------------- /catalyst/contrib/scripts/tune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Config API and Optuna integration for AutoML hyperparameters tuning.""" 3 | from typing import Iterable 4 | import argparse 5 | import copy 6 | 7 | import optuna 8 | 9 | from catalyst import utils 10 | from catalyst.contrib.scripts import run 11 | from catalyst.registry import REGISTRY 12 | from hydra_slayer import functional as F 13 | 14 | 15 | def parse_args(args: Iterable = None, namespace: argparse.Namespace = None): 16 | """Parses the command line arguments and returns arguments and config.""" 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--n-trials", type=int, default=None) 19 | parser.add_argument("--timeout", type=int, default=None) 20 | parser.add_argument("--n-jobs", type=int, default=1) 21 | utils.boolean_flag(parser, "gc-after-trial", default=False) 22 | utils.boolean_flag(parser, "show-progress-bar", default=False) 23 | 24 | args, unknown_args = parser.parse_known_args(args=args, namespace=namespace) 25 | return vars(args), unknown_args 26 | 27 | 28 | def main(): 29 | """Runs the ``catalyst-tune`` script.""" 30 | kwargs_run, unknown_args = run.parse_args() 31 | kwargs_tune, _ = parse_args(args=unknown_args) 32 | 33 | config_full = run.process_configs(**kwargs_run) 34 | config = copy.copy(config_full) 35 | config_study = config.pop("study") 36 | 37 | # optuna objective 38 | def objective(trial: optuna.trial): 39 | # 
workaround for `REGISTRY.get_from_params` - redefine `trial` var 40 | experiment_params, _ = F._recursive_get_from_params( 41 | factory_key=REGISTRY.name_key, 42 | get_factory_func=REGISTRY.get, 43 | params=config, 44 | shared_params={}, 45 | var_key=REGISTRY.var_key, 46 | attrs_delimiter=REGISTRY.attrs_delimiter, 47 | vars_dict={**REGISTRY._vars_dict, "trial": trial}, 48 | ) 49 | runner = experiment_params["runner"] 50 | runner._trial = trial 51 | 52 | run.run_from_params(experiment_params) 53 | score = runner.epoch_metrics[runner._valid_loader][runner._valid_metric] 54 | 55 | return score 56 | 57 | study = REGISTRY.get_from_params(**config_study) 58 | study.optimize(objective, **kwargs_tune) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /catalyst/contrib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from catalyst.settings import SETTINGS 4 | 5 | from catalyst.contrib.utils.compression import ( 6 | pack, 7 | pack_if_needed, 8 | unpack, 9 | unpack_if_needed, 10 | ) 11 | 12 | if SETTINGS.cv_required: 13 | from catalyst.contrib.utils.image import ( 14 | has_image_extension, 15 | imread, 16 | imwrite, 17 | imsave, 18 | mimread, 19 | ) 20 | 21 | from catalyst.contrib.utils.numpy import get_one_hot 22 | 23 | if SETTINGS.ml_required: 24 | from catalyst.contrib.utils.report import get_classification_report 25 | 26 | from catalyst.contrib.utils.serialization import deserialize, serialize 27 | 28 | from catalyst.contrib.utils.swa import average_weights, get_averaged_weights_by_path_mask 29 | 30 | from catalyst.contrib.utils.torch import ( 31 | get_optimal_inner_init, 32 | outer_init, 33 | get_network_output, 34 | trim_tensors, 35 | ) 36 | 37 | if SETTINGS.ml_required: 38 | from catalyst.contrib.utils.thresholds import ( 39 | get_baseline_thresholds, 40 | get_binary_threshold, 41 | get_multiclass_thresholds, 42 | get_multilabel_thresholds, 43 | get_binary_threshold_cv, 44 | get_multilabel_thresholds_cv, 45 | get_thresholds_greedy, 46 | get_multilabel_thresholds_greedy, 47 | get_multiclass_thresholds_greedy, 48 | get_best_multilabel_thresholds, 49 | get_best_multiclass_thresholds, 50 | ) 51 | 52 | if SETTINGS.ml_required: 53 | from catalyst.contrib.utils.visualization import ( 54 | plot_confusion_matrix, 55 | render_figure_to_array, 56 | ) 57 | -------------------------------------------------------------------------------- /catalyst/contrib/utils/compression.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # @TODO: code formatting issue for 20.07 release 3 | import base64 4 | import logging 5 | 6 | import numpy as np 7 | from six import string_types 8 | 9 | from catalyst.contrib.utils.serialization import deserialize, serialize 10 | from catalyst.settings import SETTINGS 11 | 12 | LOGGER = logging.getLogger(__name__) 13 | 14 | if SETTINGS.use_lz4: 15 | try: 16 | import lz4.frame 17 | except ImportError as ex: 18 | LOGGER.warning("lz4 not available, to install lz4, run `pip install lz4`.") 19 | raise ex 20 | 21 | 22 | def is_compressed(data): 23 | """@TODO: Docs. Contribution is welcome.""" 24 | return isinstance(data, (bytes, string_types)) 25 | 26 | 27 | def compress(data): 28 | """@TODO: Docs. 
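Serializes ``data`` and, when lz4 is enabled in ``SETTINGS``, compresses it and encodes the result as an ascii base64 string. 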
Contribution is welcome.""" 29 | if SETTINGS.use_lz4: 30 | data = serialize(data) 31 | data = lz4.frame.compress(data) 32 | data = base64.b64encode(data).decode("ascii") 33 | return data 34 | 35 | 36 | def compress_if_needed(data): 37 | """@TODO: Docs. Contribution is welcome.""" 38 | if isinstance(data, np.ndarray): 39 | data = compress(data) 40 | return data 41 | 42 | 43 | def decompress(data): 44 | """@TODO: Docs. Contribution is welcome.""" 45 | if SETTINGS.use_lz4: 46 | data = base64.b64decode(data) 47 | data = lz4.frame.decompress(data) 48 | data = deserialize(data) 49 | return data 50 | 51 | 52 | def decompress_if_needed(data): 53 | """@TODO: Docs. Contribution is welcome.""" 54 | if is_compressed(data): 55 | data = decompress(data) 56 | return data 57 | 58 | 59 | if SETTINGS.use_lz4: 60 | pack = compress 61 | pack_if_needed = compress_if_needed 62 | unpack = decompress 63 | unpack_if_needed = decompress_if_needed 64 | else: 65 | pack = serialize 66 | pack_if_needed = serialize 67 | unpack = deserialize 68 | unpack_if_needed = deserialize 69 | 70 | __all__ = ["pack", "pack_if_needed", "unpack", "unpack_if_needed"] 71 | -------------------------------------------------------------------------------- /catalyst/contrib/utils/numpy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_one_hot(label: int, num_classes: int, smoothing: float = None) -> np.ndarray: 5 | """ 6 | Applies OneHot vectorization to a given scalar, optionally with 7 | label smoothing as described in `Bag of Tricks for Image Classification 8 | with Convolutional Neural Networks`_. 9 | 10 | Args: 11 | label: scalar value to be vectorized 12 | num_classes: total number of classes 13 | smoothing (float, optional): if specified applies label smoothing 14 | from ``Bag of Tricks for Image Classification 15 | with Convolutional Neural Networks`` paper 16 | 17 | Returns: 18 | np.ndarray: a one-hot vector with shape ``(num_classes)`` 19 | 20 | .. _Bag of Tricks for Image Classification with 21 | Convolutional Neural Networks: https://arxiv.org/abs/1812.01187 22 | """ 23 | assert ( 24 | num_classes is not None and num_classes > 0 25 | ), f"Expect num_classes to be > 0, got {num_classes}" 26 | 27 | assert ( 28 | label is not None and 0 <= label < num_classes 29 | ), f"Expect label to be in [0; {num_classes}), got {label}" 30 | 31 | if smoothing is not None: 32 | assert ( 33 | 0.0 < smoothing < 1.0 34 | ), f"If smoothing is specified it must be in (0; 1), got {smoothing}" 35 | 36 | smoothed = smoothing / float(num_classes - 1) 37 | result = np.full((num_classes,), smoothed, dtype=np.float32) 38 | result[label] = 1.0 - smoothing 39 | 40 | return result 41 | 42 | result = np.zeros(num_classes, dtype=np.float32) 43 | result[label] = 1.0 44 | 45 | return result 46 | 47 | 48 | __all__ = ["get_one_hot"] 49 | -------------------------------------------------------------------------------- /catalyst/contrib/utils/serialization.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | 4 | from catalyst.settings import SETTINGS 5 | 6 | LOGGER = logging.getLogger(__name__) 7 | 8 | if SETTINGS.use_pyarrow: 9 | try: 10 | import pyarrow 11 | except ImportError as ex: 12 | LOGGER.warning( 13 | "pyarrow not available, switching to pickle. " 14 | "To install pyarrow, run `pip install pyarrow`." 
15 | ) 16 | raise ex 17 | 18 | 19 | def pyarrow_serialize(data): 20 | """Serialize the data into bytes using pyarrow. 21 | 22 | Args: 23 | data: a value 24 | 25 | Returns: 26 | Returns a bytes object with the data serialized by pyarrow. 27 | """ 28 | return pyarrow.serialize(data).to_buffer().to_pybytes() 29 | 30 | 31 | def pyarrow_deserialize(bytes): 32 | """Deserialize bytes into an object using pyarrow. 33 | 34 | Args: 35 | bytes: a bytes object containing data serialized with pyarrow. 36 | 37 | Returns: 38 | Returns a value deserialized from the bytes-like object. 39 | """ 40 | return pyarrow.deserialize(bytes) 41 | 42 | 43 | def pickle_serialize(data): 44 | """Serialize the data into bytes using pickle. 45 | 46 | Args: 47 | data: a value 48 | 49 | Returns: 50 | Returns a bytes object with the data serialized by pickle. 51 | """ 52 | return pickle.dumps(data) 53 | 54 | 55 | def pickle_deserialize(bytes): 56 | """Deserialize bytes into an object using pickle. 57 | 58 | Args: 59 | bytes: a bytes object containing data serialized with pickle. 60 | 61 | Returns: 62 | Returns a value deserialized from the bytes-like object. 63 | """ 64 | return pickle.loads(bytes) 65 | 66 | 67 | if SETTINGS.use_pyarrow: 68 | serialize = pyarrow_serialize 69 | deserialize = pyarrow_deserialize 70 | else: 71 | serialize = pickle_serialize 72 | deserialize = pickle_deserialize 73 | 74 | __all__ = ["serialize", "deserialize"] 75 | -------------------------------------------------------------------------------- /catalyst/contrib/utils/swa.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from collections import OrderedDict 3 | import glob 4 | import os 5 | from pathlib import Path 6 | 7 | import torch 8 | 9 | 10 | def _load_weights(path: str) -> dict: 11 | """ 12 | Load weights of a model. 13 | 14 | Args: 15 | path: Path to model weights 16 | 17 | Returns: 18 | Weights 19 | """ 20 | weights = torch.load(path, map_location=lambda storage, loc: storage) 21 | if "model_state_dict" in weights: 22 | weights = weights["model_state_dict"] 23 | return weights 24 | 25 | 26 | def average_weights(state_dicts: List[dict]) -> OrderedDict: 27 | """ 28 | Averaging of input weights. 29 | 30 | Args: 31 | state_dicts: Weights to average 32 | 33 | Raises: 34 | KeyError: If states do not match 35 | 36 | Returns: 37 | Averaged weights 38 | """ 39 | # source https://gist.github.com/qubvel/70c3d5e4cddcde731408f478e12ef87b 40 | params_keys = None 41 | for i, state_dict in enumerate(state_dicts): 42 | model_params_keys = list(state_dict.keys()) 43 | if params_keys is None: 44 | params_keys = model_params_keys 45 | elif params_keys != model_params_keys: 46 | raise KeyError( 47 | "For checkpoint {}, expected list of params: {}, " 48 | "but found: {}".format(i, params_keys, model_params_keys) 49 | ) 50 | 51 | average_dict = OrderedDict() 52 | for k in state_dicts[0].keys(): 53 | average_dict[k] = torch.div( 54 | sum(state_dict[k] for state_dict in state_dicts), len(state_dicts) 55 | ) 56 | 57 | return average_dict 58 | 59 | 60 | def get_averaged_weights_by_path_mask( 61 | path_mask: str, logdir: Union[str, Path] = None 62 | ) -> OrderedDict: 63 | """ 64 | Average the weights of the models matched by a path mask. 
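A minimal usage sketch (the mask and logdir are illustrative): ``averaged = get_averaged_weights_by_path_mask("train.*.pth", logdir="logs")``, then ``model.load_state_dict(averaged)``. 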
65 | 66 | Args: 67 | path_mask: glob-like pattern for models to average 68 | logdir: Path to logs directory 69 | 70 | Returns: 71 | Averaged weights 72 | """ 73 | if logdir is None: 74 | model_paths = glob.glob(path_mask) 75 | else: 76 | model_paths = glob.glob(os.path.join(logdir, "checkpoints", path_mask)) 77 | 78 | all_weights = [_load_weights(path) for path in model_paths] 79 | averaged_dict = average_weights(all_weights) 80 | 81 | return averaged_dict 82 | 83 | 84 | __all__ = ["average_weights", "get_averaged_weights_by_path_mask"] 85 | -------------------------------------------------------------------------------- /catalyst/core/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # import order: 3 | # engine 4 | # runner 5 | # callback 6 | # logger 7 | 8 | from catalyst.core.engine import Engine 9 | from catalyst.core.runner import IRunner, IRunnerError 10 | from catalyst.core.callback import ( 11 | ICallback, 12 | Callback, 13 | CallbackOrder, 14 | ICriterionCallback, 15 | IBackwardCallback, 16 | IOptimizerCallback, 17 | ISchedulerCallback, 18 | CallbackWrapper, 19 | ) 20 | from catalyst.core.logger import ILogger 21 | -------------------------------------------------------------------------------- /catalyst/data/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from catalyst.data.dataset import DatasetFromSampler, SelfSupervisedDatasetWrapper 3 | from catalyst.data.loader import ( 4 | ILoaderWrapper, 5 | BatchLimitLoaderWrapper, 6 | BatchPrefetchLoaderWrapper, 7 | ) 8 | from catalyst.data.sampler import ( 9 | BalanceClassSampler, 10 | BatchBalanceClassSampler, 11 | DistributedSamplerWrapper, 12 | DynamicBalanceClassSampler, 13 | MiniEpochSampler, 14 | ) 15 | 16 | # from catalyst.contrib.data import * 17 | -------------------------------------------------------------------------------- /catalyst/dl/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from catalyst.core import * 4 | from catalyst.engines import * 5 | from catalyst.loggers import * 6 | from catalyst.runners import * 7 | from catalyst.callbacks import * 8 | -------------------------------------------------------------------------------- /catalyst/engines/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from catalyst.core.engine import Engine 4 | 5 | from catalyst.engines.torch import ( 6 | CPUEngine, 7 | GPUEngine, 8 | Engine, 9 | DataParallelEngine, 10 | DistributedDataParallelEngine, 11 | ) 12 | 13 | __all__ = [ 14 | "Engine", 15 | "CPUEngine", 16 | "GPUEngine", 17 | "DataParallelEngine", 18 | "DistributedDataParallelEngine", 19 | ] 20 | 21 | from catalyst.settings import SETTINGS 22 | 23 | if SETTINGS.xla_required: 24 | from catalyst.engines.torch import DistributedXLAEngine 25 | 26 | __all__ += ["DistributedXLAEngine"] 27 | -------------------------------------------------------------------------------- /catalyst/extras/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from catalyst.extras.frozen_class import FrozenClass 3 | from catalyst.extras.forward_wrapper import ModelForwardWrapper 4 | from catalyst.extras.metric_handler import MetricHandler 5 | from catalyst.extras.time_manager import TimeManager 6 | 
-------------------------------------------------------------------------------- /catalyst/extras/forward_wrapper.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class ModelForwardWrapper(nn.Module): 5 | """Model that calls specified method instead of forward. 6 | 7 | Args: 8 | model: @TODO: docs 9 | method_name: @TODO: docs 10 | 11 | (Workaround, single method tracing is not supported) 12 | """ 13 | 14 | def __init__(self, model, method_name): 15 | """Init""" 16 | super().__init__() 17 | self.model = model 18 | self.method_name = method_name 19 | 20 | def forward(self, *args, **kwargs): 21 | """Forward pass. 22 | 23 | Args: 24 | *args: some args 25 | **kwargs: some kwargs 26 | 27 | Returns: 28 | output: specified method output 29 | """ 30 | return getattr(self.model, self.method_name)(*args, **kwargs) 31 | 32 | 33 | __all__ = ["ModelForwardWrapper"] 34 | -------------------------------------------------------------------------------- /catalyst/extras/frozen_class.py: -------------------------------------------------------------------------------- 1 | class FrozenClass: 2 | """Class which prohibits ``__setattr__`` on existing attributes. 3 | 4 | Examples: 5 | >>> class IRunner(FrozenClass): 6 | """ 7 | 8 | __is_frozen = False 9 | 10 | def __setattr__(self, key, value): 11 | """@TODO: Docs. Contribution is welcome.""" 12 | if self.__is_frozen and not hasattr(self, key): 13 | raise TypeError("%r is a frozen class for key %s" % (self, key)) 14 | object.__setattr__(self, key, value) 15 | 16 | def _freeze(self): 17 | self.__is_frozen = True 18 | 19 | def _unfreeze(self): 20 | self.__is_frozen = False 21 | 22 | 23 | __all__ = ["FrozenClass"] 24 | -------------------------------------------------------------------------------- /catalyst/extras/metric_handler.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from functools import partial 3 | 4 | 5 | def _is_better_min(score, best, min_delta): 6 | return score <= (best - min_delta) 7 | 8 | 9 | def _is_better_max(score, best, min_delta): 10 | return score >= (best + min_delta) 11 | 12 | 13 | class MetricHandler: 14 | """Docs. 15 | 16 | Args: 17 | minimize: Docs 18 | min_delta: Docs 19 | """ 20 | 21 | def __init__(self, minimize: bool = True, min_delta: float = 1e-6): 22 | """Init.""" 23 | self.minimize = minimize 24 | self.min_delta = min_delta 25 | if self.minimize: 26 | self.is_better = partial(_is_better_min, min_delta=min_delta) 27 | else: 28 | self.is_better = partial(_is_better_max, min_delta=min_delta) 29 | 30 | def __call__(self, score, best_score): 31 | """Docs.""" 32 | return self.is_better(score, best_score) 33 | 34 | 35 | __all__ = ["MetricHandler"] 36 | -------------------------------------------------------------------------------- /catalyst/extras/time_manager.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | 4 | class TimeManager(object): 5 | """Docs.""" 6 | 7 | def __init__(self): 8 | """Initialization""" 9 | self._starts = {} 10 | self.elapsed = {} 11 | 12 | def start(self, name: str) -> None: 13 | """Starts timer ``name``. 14 | 15 | Args: 16 | name: name of a timer 17 | """ 18 | self._starts[name] = time() 19 | 20 | def stop(self, name: str) -> None: 21 | """Stops timer ``name``. 
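Raises an ``AssertionError`` if the timer was never started. 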
22 | 23 | Args: 24 | name: name of a timer 25 | """ 26 | assert name in self._starts, f"Timer '{name}' wasn't started" 27 | 28 | self.elapsed[name] = time() - self._starts[name] 29 | del self._starts[name] 30 | 31 | def reset(self) -> None: 32 | """Reset all previous timers.""" 33 | self.elapsed = {} 34 | self._starts = {} 35 | 36 | 37 | __all__ = ["TimeManager"] 38 | -------------------------------------------------------------------------------- /catalyst/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from catalyst.loggers.console import ConsoleLogger 4 | from catalyst.loggers.csv import CSVLogger 5 | from catalyst.loggers.tensorboard import TensorboardLogger 6 | from catalyst.settings import SETTINGS 7 | 8 | if SETTINGS.mlflow_required: 9 | from catalyst.loggers.mlflow import MLflowLogger 10 | 11 | if SETTINGS.wandb_required: 12 | from catalyst.loggers.wandb import WandbLogger 13 | 14 | if SETTINGS.neptune_required: 15 | from catalyst.loggers.neptune import NeptuneLogger 16 | 17 | if SETTINGS.comet_required: 18 | from catalyst.loggers.comet import CometLogger 19 | 20 | __all__ = ["ConsoleLogger", "CSVLogger", "TensorboardLogger"] 21 | 22 | if SETTINGS.mlflow_required: 23 | __all__ += ["MLflowLogger"] 24 | 25 | if SETTINGS.wandb_required: 26 | __all__ += ["WandbLogger"] 27 | 28 | if SETTINGS.neptune_required: 29 | __all__ += ["NeptuneLogger"] 30 | 31 | if SETTINGS.comet_required: 32 | __all__ += ["CometLogger"] 33 | -------------------------------------------------------------------------------- /catalyst/loggers/console.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, TYPE_CHECKING 2 | 3 | from catalyst.core.logger import ILogger 4 | 5 | if TYPE_CHECKING: 6 | from catalyst.core.runner import IRunner 7 | 8 | 9 | def _format_metrics(dct: Dict): 10 | return " | ".join([f"{k}: {float(dct[k])}" for k in sorted(dct.keys())]) 11 | 12 | 13 | class ConsoleLogger(ILogger): 14 | """Console logger for parameters and metrics. 15 | Outputs the metrics to the console during the experiment. 16 | 17 | Args: 18 | log_hparams: boolean flag to print all hparams to the console (default: False) 19 | 20 | .. note:: 21 | This logger is used by default by all Runners. 
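Example (a sketch of passing the logger explicitly, e.g. to enable hparams logging; the dict key is illustrative): ``loggers={"console": ConsoleLogger(log_hparams=True)}`` inside ``runner.train(...)``. 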
22 | """ 23 | 24 | def __init__(self, log_hparams: bool = False): 25 | super().__init__(log_batch_metrics=False, log_epoch_metrics=True) 26 | self._log_hparams = log_hparams 27 | 28 | def log_hparams(self, hparams: Dict, runner: "IRunner" = None) -> None: 29 | """Logs hyperparameters to the console.""" 30 | if self._log_hparams: 31 | print(f"Hparams: {hparams}") 32 | 33 | def log_metrics( 34 | self, 35 | metrics: Dict[str, float], 36 | scope: str, 37 | runner: "IRunner", 38 | ) -> None: 39 | """Logs loader and epoch metrics to stdout.""" 40 | if scope == "loader": 41 | prefix = f"{runner.loader_key} ({runner.epoch_step}/{runner.num_epochs}) " 42 | msg = prefix + _format_metrics(metrics) 43 | print(msg) 44 | elif scope == "epoch": 45 | # @TODO: remove trick to save pure epoch-based metrics, like lr/momentum 46 | prefix = f"* Epoch ({runner.epoch_step}/{runner.num_epochs}) " 47 | msg = prefix + _format_metrics(metrics["_epoch_"]) 48 | print(msg) 49 | 50 | 51 | __all__ = ["ConsoleLogger"] 52 | -------------------------------------------------------------------------------- /catalyst/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # import order: 3 | # functional 4 | # core metrics 5 | # metrics 6 | 7 | from catalyst.metrics.functional import * 8 | 9 | from catalyst.metrics._metric import ( 10 | ICallbackBatchMetric, 11 | ICallbackLoaderMetric, 12 | IMetric, 13 | ) 14 | from catalyst.metrics._accumulative import AccumulativeMetric 15 | from catalyst.metrics._additive import AdditiveMetric, AdditiveValueMetric 16 | from catalyst.metrics._confusion_matrix import ConfusionMatrixMetric 17 | from catalyst.metrics._functional_metric import ( 18 | FunctionalBatchMetric, 19 | FunctionalLoaderMetric, 20 | ) 21 | from catalyst.metrics._topk_metric import TopKMetric 22 | 23 | from catalyst.metrics._accuracy import AccuracyMetric, MultilabelAccuracyMetric 24 | from catalyst.metrics._auc import AUCMetric 25 | from catalyst.metrics._classification import ( 26 | BinaryPrecisionRecallF1Metric, 27 | MulticlassPrecisionRecallF1SupportMetric, 28 | MultilabelPrecisionRecallF1SupportMetric, 29 | ) 30 | from catalyst.metrics._cmc_score import CMCMetric, ReidCMCMetric 31 | from catalyst.metrics._hitrate import HitrateMetric 32 | from catalyst.metrics._map import MAPMetric 33 | from catalyst.metrics._mrr import MRRMetric 34 | from catalyst.metrics._ndcg import NDCGMetric 35 | from catalyst.metrics._r2_squared import R2Squared 36 | from catalyst.metrics._segmentation import ( 37 | RegionBasedMetric, 38 | IOUMetric, 39 | DiceMetric, 40 | TrevskyMetric, 41 | ) 42 | -------------------------------------------------------------------------------- /catalyst/metrics/_r2_squared.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from catalyst.metrics._metric import ICallbackLoaderMetric 6 | 7 | 8 | class R2Squared(ICallbackLoaderMetric): 9 | """This metric accumulates r2 score along loader 10 | 11 | Args: 12 | compute_on_call: if True, allows compute metric's value on call 13 | prefix: metric prefix 14 | suffix: metric suffix 15 | """ 16 | 17 | def __init__( 18 | self, 19 | compute_on_call: bool = True, 20 | prefix: Optional[str] = None, 21 | suffix: Optional[str] = None, 22 | ) -> None: 23 | """Init R2Squared""" 24 | super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix) 25 | self.metric_name = 
f"{self.prefix}r2squared{self.suffix}" 26 | self.num_examples = 0 27 | self.delta_sum = 0 28 | self.y_sum = 0 29 | self.y_sq_sum = 0 30 | 31 | def reset(self, num_batches: int, num_samples: int) -> None: 32 | """ 33 | Reset metrics fields 34 | """ 35 | self.num_examples = 0 36 | self.delta_sum = 0 37 | self.y_sum = 0 38 | self.y_sq_sum = 0 39 | 40 | def update(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> None: 41 | """ 42 | Update accumulated data with new batch 43 | """ 44 | self.num_examples += len(y_true) 45 | self.delta_sum += torch.sum(torch.pow(y_pred - y_true, 2)) 46 | self.y_sum += torch.sum(y_true) 47 | self.y_sq_sum += torch.sum(torch.pow(y_true, 2)) 48 | 49 | def compute(self) -> torch.Tensor: 50 | """ 51 | Return accumulated metric 52 | """ 53 | return 1 - self.delta_sum / ( 54 | self.y_sq_sum - (self.y_sum ** 2) / self.num_examples 55 | ) 56 | 57 | def compute_key_value(self) -> torch.Tensor: 58 | """ 59 | Return key-value 60 | """ 61 | r2squared = self.compute() 62 | output = {self.metric_name: r2squared} 63 | return output 64 | 65 | 66 | __all__ = ["R2Squared"] 67 | -------------------------------------------------------------------------------- /catalyst/metrics/functional/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from catalyst.metrics.functional._misc import ( 3 | check_consistent_length, 4 | process_multilabel_components, 5 | process_recsys_components, 6 | get_binary_statistics, 7 | get_multiclass_statistics, 8 | get_multilabel_statistics, 9 | get_default_topk, 10 | ) 11 | 12 | from catalyst.metrics.functional._accuracy import accuracy, multilabel_accuracy 13 | from catalyst.metrics.functional._auc import auc, binary_auc 14 | from catalyst.metrics.functional._average_precision import ( 15 | average_precision, 16 | mean_average_precision, 17 | binary_average_precision, 18 | ) 19 | from catalyst.metrics.functional._classification import precision_recall_fbeta_support 20 | from catalyst.metrics.functional._cmc_score import ( 21 | cmc_score, 22 | cmc_score_count, 23 | masked_cmc_score, 24 | ) 25 | from catalyst.metrics.functional._f1_score import f1_score, fbeta_score 26 | from catalyst.metrics.functional._focal import ( 27 | sigmoid_focal_loss, 28 | reduced_focal_loss, 29 | ) 30 | from catalyst.metrics.functional._hitrate import hitrate 31 | from catalyst.metrics.functional._mrr import reciprocal_rank, mrr 32 | from catalyst.metrics.functional._ndcg import dcg, ndcg 33 | from catalyst.metrics.functional._precision import precision 34 | from catalyst.metrics.functional._r2_squared import r2_squared 35 | from catalyst.metrics.functional._recall import recall 36 | from catalyst.metrics.functional._segmentation import ( 37 | iou, 38 | dice, 39 | trevsky, 40 | get_segmentation_statistics, 41 | ) 42 | -------------------------------------------------------------------------------- /catalyst/metrics/functional/_precision.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | 5 | from catalyst.metrics.functional._classification import precision_recall_fbeta_support 6 | 7 | 8 | def precision( 9 | outputs: torch.Tensor, 10 | targets: torch.Tensor, 11 | argmax_dim: int = -1, 12 | eps: float = 1e-7, 13 | num_classes: Optional[int] = None, 14 | ) -> Union[float, torch.Tensor]: 15 | """ 16 | Multiclass precision score. 
17 | 18 | Args: 19 | outputs: estimated targets as predicted by a model 20 | with shape [bs; ..., (num_classes or 1)] 21 | targets: ground truth (correct) target values 22 | with shape [bs; ..., 1] 23 | argmax_dim: int that specifies the dimension for the argmax transformation 24 | in case of scores/probabilities in ``outputs`` 25 | eps: float. Epsilon to avoid zero division. 26 | num_classes: int that specifies the number of classes if it is known 27 | 28 | Returns: 29 | Tensor: precision for every class 30 | 31 | Examples: 32 | 33 | .. code-block:: python 34 | 35 | import torch 36 | from catalyst import metrics 37 | metrics.precision( 38 | outputs=torch.tensor([ 39 | [1, 0, 0], 40 | [0, 1, 0], 41 | [0, 0, 1], 42 | ]), 43 | targets=torch.tensor([0, 1, 2]), 44 | ) 45 | # tensor([1., 1., 1.]) 46 | 47 | 48 | .. code-block:: python 49 | 50 | import torch 51 | from catalyst import metrics 52 | metrics.precision( 53 | outputs=torch.tensor([[0, 0, 1, 1, 0, 1, 0, 1]]), 54 | targets=torch.tensor([[0, 1, 0, 1, 0, 0, 1, 1]]), 55 | ) 56 | # tensor([0.5000, 0.5000]) 57 | """ 58 | precision_score, _, _, _ = precision_recall_fbeta_support( 59 | outputs=outputs, 60 | targets=targets, 61 | argmax_dim=argmax_dim, 62 | eps=eps, 63 | num_classes=num_classes, 64 | ) 65 | return precision_score 66 | 67 | 68 | __all__ = ["precision"] 69 | -------------------------------------------------------------------------------- /catalyst/metrics/functional/_r2_squared.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import torch 4 | 5 | 6 | def r2_squared(outputs: torch.Tensor, targets: torch.Tensor) -> Sequence[torch.Tensor]: 7 | """ 8 | Computes regression r2 squared. 9 | 10 | Args: 11 | outputs: model outputs 12 | with shape [bs; 1] 13 | targets: ground truth 14 | with shape [bs; 1] 15 | 16 | Returns: 17 | float of computed r2 squared 18 | 19 | Examples: 20 | 21 | .. code-block:: python 22 | 23 | import torch 24 | from catalyst import metrics 25 | metrics.r2_squared( 26 | outputs=torch.tensor([0, 1, 2]), 27 | targets=torch.tensor([0, 1, 2]), 28 | ) 29 | # tensor([1.]) 30 | 31 | 32 | .. code-block:: python 33 | 34 | import torch 35 | from catalyst import metrics 36 | metrics.r2_squared( 37 | outputs=torch.tensor([2.5, 0.0, 2, 8]), 38 | targets=torch.tensor([3, -0.5, 2, 7]), 39 | ) 40 | # tensor([0.9486]) 41 | """ 42 | total_sum_of_squares = torch.sum( 43 | torch.pow(targets.float() - torch.mean(targets.float()), 2) 44 | ).view(-1) 45 | residual_sum_of_squares = torch.sum( 46 | torch.pow(targets.float() - outputs.float(), 2) 47 | ).view(-1) 48 | output = 1 - residual_sum_of_squares / total_sum_of_squares 49 | return output 50 | 51 | 52 | __all__ = ["r2_squared"] 53 | -------------------------------------------------------------------------------- /catalyst/metrics/functional/_recall.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | 5 | from catalyst.metrics.functional._classification import precision_recall_fbeta_support 6 | 7 | 8 | def recall( 9 | outputs: torch.Tensor, 10 | targets: torch.Tensor, 11 | argmax_dim: int = -1, 12 | eps: float = 1e-7, 13 | num_classes: Optional[int] = None, 14 | ) -> Union[float, torch.Tensor]: 15 | """ 16 | Multiclass recall score. 
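Per class, recall is the ratio of true positives to all actual positives, ``tp / (tp + fn)``, with ``eps`` guarding the zero-division case. 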
17 | 18 | Args: 19 | outputs: estimated targets as predicted by a model 20 | with shape [bs; ..., (num_classes or 1)] 21 | targets: ground truth (correct) target values 22 | with shape [bs; ..., 1] 23 | argmax_dim: int that specifies the dimension for the argmax transformation 24 | in case of scores/probabilities in ``outputs`` 25 | eps: float. Epsilon to avoid zero division. 26 | num_classes: int that specifies the number of classes if it is known 27 | 28 | Returns: 29 | Tensor: recall for every class 30 | 31 | Examples: 32 | 33 | .. code-block:: python 34 | 35 | import torch 36 | from catalyst import metrics 37 | metrics.recall( 38 | outputs=torch.tensor([ 39 | [1, 0, 0], 40 | [0, 1, 0], 41 | [0, 0, 1], 42 | ]), 43 | targets=torch.tensor([0, 1, 2]), 44 | ) 45 | # tensor([1., 1., 1.]) 46 | 47 | 48 | .. code-block:: python 49 | 50 | import torch 51 | from catalyst import metrics 52 | metrics.recall( 53 | outputs=torch.tensor([[0, 0, 1, 1, 0, 1, 0, 1]]), 54 | targets=torch.tensor([[0, 1, 0, 1, 0, 0, 1, 1]]), 55 | ) 56 | # tensor([0.5000, 0.5000]) 57 | """ 58 | _, recall_score, _, _ = precision_recall_fbeta_support( 59 | outputs=outputs, 60 | targets=targets, 61 | argmax_dim=argmax_dim, 62 | eps=eps, 63 | num_classes=num_classes, 64 | ) 65 | 66 | return recall_score 67 | 68 | 69 | __all__ = ["recall"] 70 | -------------------------------------------------------------------------------- /catalyst/registry.py: -------------------------------------------------------------------------------- 1 | import hydra_slayer 2 | 3 | REGISTRY = hydra_slayer.Registry() 4 | # Registry = REGISTRY.add 5 | 6 | 7 | __all__ = ["REGISTRY"] 8 | -------------------------------------------------------------------------------- /catalyst/runners/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from catalyst.runners.runner import Runner 4 | from catalyst.runners.supervised import ISupervisedRunner, SupervisedRunner 5 | 6 | 7 | __all__ = [ 8 | "Runner", 9 | "ISupervisedRunner", 10 | "SupervisedRunner", 11 | ] 12 | -------------------------------------------------------------------------------- /catalyst/typing.py: -------------------------------------------------------------------------------- 1 | """ 2 | All Catalyst custom types are defined in this module. 
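They are used for annotations across the codebase, e.g. ``model: TorchModel`` in ``catalyst.utils.quantization`` and ``catalyst.utils.tracing``. 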
3 | """ 4 | from typing import Dict, Union 5 | from numbers import Number 6 | from pathlib import Path 7 | 8 | import torch 9 | from torch import nn, optim 10 | from torch.optim import lr_scheduler 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | from torch.utils import data 13 | from torch.utils.data import sampler 14 | 15 | Directory = Path 16 | File = Path 17 | 18 | TorchModel = nn.Module 19 | TorchCriterion = nn.Module 20 | TorchOptimizer = optim.Optimizer 21 | TorchScheduler = (lr_scheduler._LRScheduler, ReduceLROnPlateau) 22 | TorchDataset = data.Dataset 23 | TorchSampler = sampler.Sampler 24 | 25 | RunnerDevice = Union[str, torch.device] 26 | RunnerModel = Union[TorchModel, Dict[str, TorchModel]] 27 | RunnerCriterion = Union[TorchCriterion, Dict[str, TorchCriterion]] 28 | RunnerOptimizer = Union[TorchOptimizer, Dict[str, TorchOptimizer]] 29 | _torch_scheduler = Union[lr_scheduler._LRScheduler, ReduceLROnPlateau] 30 | RunnerScheduler = Union[_torch_scheduler, Dict[str, _torch_scheduler]] 31 | 32 | __all__ = [ 33 | "Number", 34 | "Directory", 35 | "File", 36 | "TorchModel", 37 | "TorchCriterion", 38 | "TorchOptimizer", 39 | "TorchScheduler", 40 | "TorchDataset", 41 | "TorchSampler", 42 | "RunnerDevice", 43 | "RunnerModel", 44 | "RunnerCriterion", 45 | "RunnerOptimizer", 46 | "RunnerScheduler", 47 | ] 48 | -------------------------------------------------------------------------------- /catalyst/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from catalyst.settings import SETTINGS 3 | 4 | if SETTINGS.yaml_required: 5 | from catalyst.utils.config import save_config, load_config 6 | from catalyst.utils.distributed import ( 7 | get_backend, 8 | get_world_size, 9 | get_rank, 10 | get_nn_from_ddp_module, 11 | sum_reduce, 12 | mean_reduce, 13 | all_gather, 14 | ddp_reduce, 15 | ) 16 | 17 | from catalyst.utils.misc import ( 18 | get_utcnow_time, 19 | maybe_recursive_call, 20 | get_attr, 21 | set_global_seed, 22 | boolean_flag, 23 | merge_dicts, 24 | flatten_dict, 25 | get_hash, 26 | get_short_hash, 27 | make_tuple, 28 | pairwise, 29 | get_by_keys, 30 | ) 31 | 32 | from catalyst.utils.onnx import onnx_export 33 | 34 | if SETTINGS.onnx_required: 35 | from catalyst.utils.onnx import quantize_onnx_model 36 | 37 | if SETTINGS.pruning_required: 38 | from catalyst.utils.pruning import prune_model, remove_reparametrization 39 | 40 | if SETTINGS.quantization_required: 41 | from catalyst.utils.quantization import quantize_model 42 | 43 | from catalyst.utils.torch import ( 44 | get_optimizer_momentum, 45 | get_optimizer_momentum_list, 46 | set_optimizer_momentum, 47 | get_device, 48 | get_available_gpus, 49 | get_available_engine, 50 | any2device, 51 | prepare_cudnn, 52 | get_requires_grad, 53 | set_requires_grad, 54 | pack_checkpoint, 55 | unpack_checkpoint, 56 | save_checkpoint, 57 | load_checkpoint, 58 | soft_update, 59 | mixup_batch, 60 | ) 61 | 62 | from catalyst.utils.tracing import trace_model 63 | 64 | 65 | # from catalyst.contrib.utils import * 66 | -------------------------------------------------------------------------------- /catalyst/utils/quantization.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union 2 | 3 | import torch 4 | from torch import quantization 5 | 6 | from catalyst.typing import TorchModel 7 | from catalyst.utils.distributed import get_nn_from_ddp_module 8 | 9 | 10 | def quantize_model( 11 | model: 
TorchModel,
12 | qconfig_spec: Dict = None,
13 | dtype: Union[str, Optional[torch.dtype]] = "qint8",
14 | ) -> TorchModel:
15 | """Function to quantize model weights.
16 | 
17 | Args:
18 | model: model to be quantized
19 | qconfig_spec (Dict, optional): quantization config in PyTorch format.
20 | Defaults to None.
21 | dtype: type of the weights after quantization.
22 | Defaults to "qint8".
23 | 
24 | Returns:
25 | TorchModel: quantized model
26 | """
27 | nn_model = get_nn_from_ddp_module(model)
28 | if isinstance(dtype, str):
29 | type_mapping = {"qint8": torch.qint8, "quint8": torch.quint8}
30 | dtype = type_mapping[dtype]  # map the string alias to the torch dtype
31 | try:
32 | quantized_model = quantization.quantize_dynamic(
33 | nn_model.cpu(), qconfig_spec=qconfig_spec, dtype=dtype
34 | )
35 | except RuntimeError:
36 | # fall back to the qnnpack engine if the default quantized backend is unavailable
37 | torch.backends.quantized.engine = "qnnpack"
38 | quantized_model = quantization.quantize_dynamic(
39 | nn_model.cpu(), qconfig_spec=qconfig_spec, dtype=dtype
40 | )
41 | 
42 | return quantized_model
43 | 
44 | 
45 | __all__ = ["quantize_model"]
46 | 
-------------------------------------------------------------------------------- /catalyst/utils/tracing.py: --------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 | 
3 | import torch
4 | from torch import jit
5 | 
6 | from catalyst.extras.forward_wrapper import ModelForwardWrapper
7 | from catalyst.typing import TorchModel
8 | from catalyst.utils.distributed import get_nn_from_ddp_module
9 | 
10 | 
11 | def trace_model(
12 | model: TorchModel,
13 | batch: Union[Tuple[torch.Tensor], torch.Tensor],
14 | method_name: str = "forward",
15 | ) -> jit.ScriptModule:
16 | """Traces a model with the given batch.
17 | 
18 | Args:
19 | model: Model to trace
20 | batch: Batch used to trace the model
21 | method_name: Model's method name that will be
22 | used as entrypoint during tracing
23 | 
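
Note: ``jit.trace`` records only the operations executed for the given
``batch``, so the traced module is specialized to that method and those
input shapes; data-dependent control flow is not captured.
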
24 | Example:
25 | .. code-block:: python
26 | 
27 | import torch
28 | 
29 | from catalyst.utils import trace_model
30 | 
31 | class LinModel(torch.nn.Module):
32 | def __init__(self):
33 | super().__init__()
34 | self.lin1 = torch.nn.Linear(10, 10)
35 | self.lin2 = torch.nn.Linear(2, 10)
36 | 
37 | def forward(self, inp_1, inp_2):
38 | return self.lin1(inp_1), self.lin2(inp_2)
39 | 
40 | def first_only(self, inp_1):
41 | return self.lin1(inp_1)
42 | 
43 | lin_model = LinModel()
44 | traced_model = trace_model(
45 | lin_model, batch=torch.randn(1, 10), method_name="first_only"
46 | )
47 | 
48 | Returns:
49 | jit.ScriptModule: Traced model
50 | """
51 | nn_model = get_nn_from_ddp_module(model)
52 | wrapped_model = ModelForwardWrapper(model=nn_model, method_name=method_name)
53 | traced = jit.trace(wrapped_model, example_inputs=batch)
54 | return traced
55 | 
56 | 
57 | __all__ = ["trace_model"]
58 | 
-------------------------------------------------------------------------------- /docker/README.md: --------------------------------------------------------------------------------
1 | ## Catalyst Docker
2 | [![Docker Pulls](https://img.shields.io/docker/pulls/catalystteam/catalyst)](https://hub.docker.com/r/catalystteam/catalyst/tags)
3 | 
4 | - `catalystteam/catalyst:{CATALYST_VERSION}` – simple image with Catalyst
5 | - `catalystteam/catalyst:{CATALYST_VERSION}-fp16` – Catalyst with FP16
6 | - `catalystteam/catalyst:{CATALYST_VERSION}-dev` – Catalyst for development with all the requirements
7 | - `catalystteam/catalyst:{CATALYST_VERSION}-dev-fp16` – Catalyst for development with FP16
8 | 
9 | ### Base version
10 | The base image ships Catalyst and all required dependencies.
11 | ```bash
12 | make docker
13 | ```
14 | 
15 | With FP16:
16 | ```bash
17 | make docker-fp16
18 | ```
19 | 
20 | ### Developer version
21 | 
22 | The developer version contains [packages](/requirements/requirements-dev.txt) for building docs and checking the code style.
23 | It does not contain Catalyst itself.
24 | ```bash
25 | make docker-dev
26 | ```
27 | 
28 | With FP16:
29 | ```bash
30 | make docker-dev-fp16
31 | ```
32 | 
33 | ## How to use
34 | 
35 | ```bash
36 | export GPUS=...
37 | export LOGDIR=...
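
# the run below mounts the project into /workspace and logs into /logdir,
# and exposes only the GPUs listed in $GPUS to the container
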
38 | docker run -it --rm --runtime=nvidia \ 39 | -v $(pwd):/workspace -v $LOGDIR:/logdir/ \ 40 | -e "CUDA_VISIBLE_DEVICES=${GPUS}" \ 41 | catalyst-base catalyst-dl run \ 42 | --config=./configs/train.yml --logdir=/logdir 43 | ``` 44 | -------------------------------------------------------------------------------- /docker/collect_dependencies_hash.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | sha = hashlib.sha1( 5 | "".join( 6 | [ 7 | open(x).read() if y else "" 8 | for x, y in [ 9 | ("docker/Dockerfile", True), 10 | ("requirements/requirements.txt", True), 11 | ( 12 | "requirements/requirements-dev.txt", 13 | os.environ.get("CATALYST_DEV", "0") == "1", 14 | ), 15 | ( 16 | "requirements/requirements-cv.txt", 17 | os.environ.get("CATALYST_CV", "0") == "1", 18 | ), 19 | ( 20 | "requirements/requirements-ml.txt", 21 | os.environ.get("CATALYST_ML", "0") == "1", 22 | ), 23 | ( 24 | "requirements/requirements-optuna.txt", 25 | os.environ.get("CATALYST_OPTUNA", "0") == "1", 26 | ), 27 | ( 28 | "requirements/requirements-onnx.txt", 29 | os.environ.get("CATALYST_ONNX", "0") == "1", 30 | ), 31 | ( 32 | "requirements/requirements-onnx-gpu.txt", 33 | os.environ.get("CATALYST_ONNX_GPU", "0") == "1", 34 | ), 35 | ] 36 | ] 37 | ).encode() 38 | ) 39 | print(sha.hexdigest()) 40 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | Docs are available [here](https://catalyst-team.github.io/catalyst/index.html) -------------------------------------------------------------------------------- /docs/api/core.rst: -------------------------------------------------------------------------------- 1 | Core 2 | ================================================ 3 | 4 | .. toctree:: 5 | :titlesonly: 6 | 7 | .. contents:: 8 | :local: 9 | 10 | 11 | .. automodule:: catalyst.core 12 | :members: 13 | :show-inheritance: 14 | 15 | 16 | Runner 17 | ---------------------- 18 | 19 | IRunner 20 | ~~~~~~~~~~~~~~~~~~~~~~ 21 | .. autoclass:: catalyst.core.runner.IRunner 22 | :members: seed, hparams, num_epochs, get_engine, get_loggers, get_loaders, get_model, get_criterion, get_optimizer, get_scheduler, get_callbacks, log_metrics, log_image, log_hparams, handle_batch, run 23 | :exclude-members: __init__, on_experiment_start, on_stage_start, on_epoch_start, on_loader_start, on_batch_start, on_batch_end, on_loader_end, on_epoch_end, on_stage_end, on_experiment_end 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | IRunnerError 28 | ~~~~~~~~~~~~~~~~~~~~~~ 29 | .. 
autoclass:: catalyst.core.runner.IRunnerError
30 | :members:
31 | :undoc-members:
32 | :show-inheritance:
33 | 
34 | Engine
35 | ----------------------
36 | .. autoclass:: catalyst.core.engine.Engine
37 | :members:
38 | :exclude-members: __init__
39 | :undoc-members:
40 | :show-inheritance:
41 | 
42 | Callback
43 | ----------------------
44 | 
45 | ICallback
46 | ~~~~~~~~~~~~~~~~~~~~~~
47 | .. autoclass:: catalyst.core.callback.ICallback
48 | :members:
49 | :exclude-members: __init__
50 | :undoc-members:
51 | :show-inheritance:
52 | 
53 | CallbackOrder
54 | ~~~~~~~~~~~~~~~~~~~~~~
55 | .. autoclass:: catalyst.core.callback.CallbackOrder
56 | :members:
57 | :undoc-members:
58 | :show-inheritance:
59 | 
60 | Callback
61 | ~~~~~~~~~~~~~~~~~~~~~~
62 | .. autoclass:: catalyst.core.callback.Callback
63 | :members:
64 | :exclude-members: __init__
65 | :undoc-members:
66 | :show-inheritance:
67 | 
68 | CallbackWrapper
69 | ~~~~~~~~~~~~~~~~~~~~~~
70 | .. autoclass:: catalyst.core.callback.CallbackWrapper
71 | :members:
72 | :exclude-members: __init__, on_experiment_start, on_stage_start, on_epoch_start, on_loader_start, on_batch_start, on_batch_end, on_loader_end, on_epoch_end, on_stage_end, on_experiment_end
73 | :undoc-members:
74 | :show-inheritance:
75 | 
76 | ILogger
77 | ----------------------
78 | .. autoclass:: catalyst.core.logger.ILogger
79 | :members:
80 | :undoc-members:
81 | :show-inheritance:
82 | 
-------------------------------------------------------------------------------- /docs/api/data.rst: --------------------------------------------------------------------------------
1 | Data
2 | ================================================
3 | 
4 | The Data subpackage provides data preprocessors and dataloader abstractions.
5 | 
6 | .. toctree::
7 | :titlesonly:
8 | 
9 | .. contents::
10 | :local:
11 | 
12 | Dataset
13 | --------------------------------------
14 | 
15 | DatasetFromSampler
16 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
17 | .. autoclass:: catalyst.data.dataset.DatasetFromSampler
18 | :exclude-members: __init__
19 | :show-inheritance:
20 | 
21 | SelfSupervisedDatasetWrapper
22 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
23 | .. autoclass:: catalyst.data.dataset.SelfSupervisedDatasetWrapper
24 | :exclude-members: __init__
25 | :show-inheritance:
26 | 
27 | Loader
28 | --------------------------------------
29 | 
30 | BatchLimitLoaderWrapper
31 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
32 | .. autoclass:: catalyst.data.loader.BatchLimitLoaderWrapper
33 | :exclude-members: __init__
34 | :special-members:
35 | 
36 | BatchPrefetchLoaderWrapper
37 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
38 | .. autoclass:: catalyst.data.loader.BatchPrefetchLoaderWrapper
39 | :exclude-members: __init__
40 | :special-members:
41 | 
42 | 
43 | Samplers
44 | --------------------------------------
45 | 
46 | BalanceClassSampler
47 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
48 | .. autoclass:: catalyst.data.sampler.BalanceClassSampler
49 | :exclude-members: __init__
50 | :special-members:
51 | 
52 | BatchBalanceClassSampler
53 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
54 | .. autoclass:: catalyst.data.sampler.BatchBalanceClassSampler
55 | :exclude-members: __init__
56 | :special-members:
57 | 
58 | DistributedSamplerWrapper
59 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
60 | .. autoclass:: catalyst.data.sampler.DistributedSamplerWrapper
61 | :exclude-members: __init__
62 | :special-members:
63 | 
64 | DynamicBalanceClassSampler
65 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
66 | .. 
autoclass:: catalyst.data.sampler.DynamicBalanceClassSampler
67 | :exclude-members: __init__
68 | :special-members:
69 | 
70 | MiniEpochSampler
71 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
72 | .. autoclass:: catalyst.data.sampler.MiniEpochSampler
73 | :exclude-members: __init__
74 | :special-members:
75 | 
-------------------------------------------------------------------------------- /docs/api/engines.rst: --------------------------------------------------------------------------------
1 | Engines
2 | ================================================
3 | 
4 | .. toctree::
5 | :titlesonly:
6 | 
7 | .. contents::
8 | :local:
9 | 
10 | You could check the engines overview under the `examples/engines`_ section.
11 | 
12 | .. _`examples/engines`: https://github.com/catalyst-team/catalyst/tree/master/examples/engines
13 | 
14 | 
15 | CPUEngine
16 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
17 | .. autoclass:: catalyst.engines.torch.CPUEngine
18 | :exclude-members: __init__
19 | :undoc-members:
20 | :show-inheritance:
21 | 
22 | GPUEngine
23 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
24 | .. autoclass:: catalyst.engines.torch.GPUEngine
25 | :exclude-members: __init__
26 | :undoc-members:
27 | :show-inheritance:
28 | 
29 | DataParallelEngine
30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
31 | .. autoclass:: catalyst.engines.torch.DataParallelEngine
32 | :exclude-members: __init__
33 | :undoc-members:
34 | :show-inheritance:
35 | 
36 | DistributedDataParallelEngine
37 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
38 | .. autoclass:: catalyst.engines.torch.DistributedDataParallelEngine
39 | :exclude-members: __init__
40 | :undoc-members:
41 | :show-inheritance:
42 | 
43 | DistributedXLAEngine
44 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
45 | .. autoclass:: catalyst.engines.torch.DistributedXLAEngine
46 | :exclude-members: __init__
47 | :undoc-members:
48 | :show-inheritance:
49 | 
-------------------------------------------------------------------------------- /docs/api/loggers.rst: --------------------------------------------------------------------------------
1 | Loggers
2 | ================================================
3 | 
4 | .. toctree::
5 | :titlesonly:
6 | 
7 | .. contents::
8 | :local:
9 | 
10 | 
11 | CometLogger
12 | --------------------------------
13 | .. autoclass:: catalyst.loggers.comet.CometLogger
14 | :exclude-members: __init__
15 | :undoc-members:
16 | :show-inheritance:
17 | 
18 | ConsoleLogger
19 | --------------------------------
20 | .. autoclass:: catalyst.loggers.console.ConsoleLogger
21 | :exclude-members: __init__
22 | :undoc-members:
23 | :show-inheritance:
24 | 
25 | CSVLogger
26 | --------------------------------
27 | .. autoclass:: catalyst.loggers.csv.CSVLogger
28 | :exclude-members: __init__
29 | :undoc-members:
30 | :show-inheritance:
31 | 
32 | MLflowLogger
33 | --------------------------------
34 | .. autoclass:: catalyst.loggers.mlflow.MLflowLogger
35 | :exclude-members: __init__
36 | :undoc-members:
37 | :show-inheritance:
38 | 
39 | NeptuneLogger
40 | --------------------------------
41 | .. autoclass:: catalyst.loggers.neptune.NeptuneLogger
42 | :exclude-members: __init__
43 | :undoc-members:
44 | :show-inheritance:
45 | 
46 | TensorboardLogger
47 | --------------------------------
48 | .. autoclass:: catalyst.loggers.tensorboard.TensorboardLogger
49 | :exclude-members: __init__
50 | :undoc-members:
51 | :show-inheritance:
52 | 
53 | WandbLogger
54 | --------------------------------
55 | .. 
autoclass:: catalyst.loggers.wandb.WandbLogger 56 | :exclude-members: __init__ 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | -------------------------------------------------------------------------------- /docs/api/runners.rst: -------------------------------------------------------------------------------- 1 | Runners 2 | ================================================ 3 | 4 | .. toctree:: 5 | :titlesonly: 6 | 7 | .. contents:: 8 | :local: 9 | 10 | 11 | ISupervisedRunner 12 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 13 | .. autoclass:: catalyst.runners.supervised.ISupervisedRunner 14 | :members: 15 | :exclude-members: __init__, on_experiment_start, on_stage_start, on_epoch_start, on_loader_start, on_batch_start, on_batch_end, on_loader_end, on_epoch_end, on_stage_end, on_experiment_end 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | 20 | Runner 21 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 22 | .. autoclass:: catalyst.runners.runner.Runner 23 | :members: 24 | :exclude-members: __init__, on_experiment_start, on_stage_start, on_epoch_start, on_loader_start, on_batch_start, on_batch_end, on_loader_end, on_epoch_end, on_stage_end, on_experiment_end 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | 29 | SupervisedRunner 30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 31 | .. autoclass:: catalyst.runners.supervised.SupervisedRunner 32 | :members: 33 | :exclude-members: __init__, on_experiment_start, on_stage_start, on_epoch_start, on_loader_start, on_batch_start, on_batch_end, on_loader_end, on_epoch_end, on_stage_end, on_experiment_end 34 | :undoc-members: 35 | :show-inheritance: 36 | 37 | -------------------------------------------------------------------------------- /docs/api/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ================================================ 3 | 4 | .. toctree:: 5 | :titlesonly: 6 | 7 | .. contents:: 8 | :local: 9 | 10 | 11 | Config 12 | ~~~~~~~~~~~~~~~~~~~~~~ 13 | .. automodule:: catalyst.utils.config 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | 18 | Distributed 19 | ~~~~~~~~~~~~~~~~~~~~~~ 20 | .. automodule:: catalyst.utils.distributed 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | Misc 26 | ~~~~~~~~~~~~~~~~~~~~~~ 27 | .. automodule:: catalyst.utils.misc 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | 32 | Onnx 33 | ~~~~~~~~~~~~~~~~~~~~~~ 34 | .. automodule:: catalyst.utils.onnx 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Pruning 40 | ~~~~~~~~~~~~~~~~~~~~~~ 41 | .. automodule:: catalyst.utils.pruning 42 | :members: 43 | :undoc-members: 44 | :show-inheritance: 45 | 46 | Quantization 47 | ~~~~~~~~~~~~~~~~~~~~~~ 48 | .. automodule:: catalyst.utils.quantization 49 | :members: 50 | :undoc-members: 51 | :show-inheritance: 52 | 53 | Torch 54 | ~~~~~~~~~~~~~~~~~~~~~~ 55 | .. automodule:: catalyst.utils.torch 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | Tracing 61 | ~~~~~~~~~~~~~~~~~~~~~~ 62 | .. automodule:: catalyst.utils.tracing 63 | :members: 64 | :undoc-members: 65 | :show-inheritance: 66 | 67 | 68 | Image (contrib) 69 | ~~~~~~~~~~~~~~~~~~~~~~ 70 | .. automodule:: catalyst.contrib.utils.image 71 | :members: 72 | :undoc-members: 73 | :show-inheritance: 74 | 75 | Report (contrib) 76 | ~~~~~~~~~~~~~~~~~~~~~~ 77 | .. 
automodule:: catalyst.contrib.utils.report
78 | :members:
79 | :undoc-members:
80 | :show-inheritance:
81 | 
82 | Thresholds (contrib)
83 | ~~~~~~~~~~~~~~~~~~~~~~
84 | .. automodule:: catalyst.contrib.utils.thresholds
85 | :members:
86 | :undoc-members:
87 | :show-inheritance:
88 | 
89 | Visualization (contrib)
90 | ~~~~~~~~~~~~~~~~~~~~~~~
91 | .. automodule:: catalyst.contrib.utils.visualization
92 | :members:
93 | :undoc-members:
94 | :show-inheritance:
95 | 
-------------------------------------------------------------------------------- /docs/core/callback.rst: --------------------------------------------------------------------------------
1 | Callback
2 | ==============================================================================
3 | 
4 | The ``Callback`` is an abstraction that helps you customize the logic of your run.
5 | Once again, you could do anything natively with PyTorch and Catalyst as a for-loop wrapper.
6 | However, thanks to the callbacks, it's much easier to reuse typical deep learning extensions
7 | like metrics or augmentation tricks.
8 | For example, it's much more convenient to define the required metrics:
9 | 
10 | - `ML - multiclass classification`_.
11 | - `ML – RecSys`_.
12 | 
13 | The Callback API is very straightforward and repeats the main for-loops of our train-loop abstraction:
14 | 
15 | .. image:: https://raw.githubusercontent.com/Scitator/catalyst22-post-pics/main/callback.png
16 | :alt: Callback
17 | 
18 | You could find a large variety of supported callbacks under the `Callback API section`_.
19 | 
20 | 
21 | If you haven't found the answer to your question, feel free to `join our slack`_ for the discussion.
22 | 
23 | .. _`ML - multiclass classification`: https://github.com/catalyst-team/catalyst#minimal-examples
24 | .. _`ML – RecSys`: https://github.com/catalyst-team/catalyst#minimal-examples
25 | .. _`Callback API section`: https://catalyst-team.github.io/catalyst/api/callbacks.html
26 | .. _`join our slack`: https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw
27 | 
-------------------------------------------------------------------------------- /docs/core/engine.rst: --------------------------------------------------------------------------------
1 | Engine
2 | ==============================================================================
3 | 
4 | ``Engine`` is the main force of the ``Runner``.
5 | It defines the logic of hardware communication and the usage of different deep learning techniques,
6 | such as distributed or half-precision training:
7 | 
8 | .. image:: https://raw.githubusercontent.com/Scitator/catalyst22-post-pics/main/engine.png
9 | :alt: Engine
10 | 
11 | 
12 | Thanks to the ``Engine`` design,
13 | it’s straightforward to adapt your pipeline for different hardware accelerators.
14 | For example, you could easily support:
15 | 
16 | - `PyTorch distributed setup`_.
17 | - `AMP distributed setup`_.
18 | - `DeepSpeed distributed setup`_.
19 | - `TPU distributed setup`_.
20 | 
21 | If you haven't found the answer to your question, feel free to `join our slack`_ for the discussion.
22 | 
23 | .. _`PyTorch distributed setup`: https://catalyst-team.github.io/catalyst/api/engines.html
24 | .. _`AMP distributed setup`: https://catalyst-team.github.io/catalyst/api/engines.html
25 | .. _`DeepSpeed distributed setup`: https://catalyst-team.github.io/catalyst/api/engines.html
26 | .. _`TPU distributed setup`: https://catalyst-team.github.io/catalyst/api/engines.html
27 | .. 
_`join our slack`: https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw
28 | 
-------------------------------------------------------------------------------- /docs/core/logger.rst: --------------------------------------------------------------------------------
1 | Logger
2 | ==============================================================================
3 | 
4 | As for logging, Catalyst unifies monitoring-system API support into one abstraction:
5 | 
6 | .. image:: https://raw.githubusercontent.com/Scitator/catalyst22-post-pics/main/logger.png
7 | :alt: Logger
8 | 
9 | 
10 | With such a simple API,
11 | we already provide integrations for the `Tensorboard`_, `MLFlow`_, Comet, Neptune, and Wandb monitoring systems.
12 | All currently supported loggers can be found under the `Logger API section`_.
13 | 
14 | 
15 | If you haven't found the answer to your question, feel free to `join our slack`_ for the discussion.
16 | 
17 | .. _`Tensorboard`: https://catalyst-team.github.io/catalyst/api/loggers.html#tensorboardlogger
18 | .. _`MLFlow`: https://catalyst-team.github.io/catalyst/api/loggers.html#mlflowlogger
19 | .. _`Logger API section`: https://catalyst-team.github.io/catalyst/api/loggers.html
20 | .. _`join our slack`: https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw
21 | 
-------------------------------------------------------------------------------- /docs/core/metric.rst: --------------------------------------------------------------------------------
1 | Metric
2 | ==============================================================================
3 | 
4 | The Catalyst API also provides a ``Metric`` abstraction
5 | for convenient metric computation during an experiment run. Its API is quite simple:
6 | 
7 | .. image:: https://raw.githubusercontent.com/Scitator/catalyst22-post-pics/main/metric.png
8 | :alt: Metric
9 | 
10 | 
11 | You could find all the supported metrics under the `Metric API section`_.
12 | 
13 | The Catalyst Metric API provides default `update` and `compute` methods
14 | to support per-batch statistics accumulation and final computation during training.
15 | All metrics also support `update` and `compute` key-value extensions
16 | for convenient usage during the run: this gives you the flexibility
17 | to store any number of metrics or aggregations you want,
18 | with a simple communication protocol to use for their logging.
19 | 
20 | 
21 | If you haven't found the answer to your question, feel free to `join our slack`_ for the discussion.
22 | 
23 | 
24 | .. _`Metric API section`: https://catalyst-team.github.io/catalyst/api/metrics.html#runner-metrics
25 | .. _`join our slack`: https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw
26 | 
-------------------------------------------------------------------------------- /docs/core/runner.rst: --------------------------------------------------------------------------------
1 | Runner
2 | ==============================================================================
3 | 
4 | ``Runner`` is an abstraction that encapsulates all the logic of your deep learning experiment:
5 | the data you are using, the model you are training,
6 | the batch handling logic, and everything about the selected metrics and monitoring systems:
7 | 
8 | .. 
image:: https://raw.githubusercontent.com/Scitator/catalyst22-post-pics/main/runner.png
9 | :alt: Runner
10 | 
11 | 
12 | The ``Runner`` has the most crucial role
13 | in connecting all the other abstractions and bringing the whole experiment logic into one place.
14 | Most importantly, it does not force you to use Catalyst-only primitives.
15 | It gives you a flexible way to choose
16 | the level of high-level API you want to get from the framework.
17 | 
18 | For example, you could:
19 | 
20 | - Define everything in a Catalyst-way with Runner and Callbacks: `ML — multiclass classification`_ example.
21 | - Write forward-backward on your own, using Catalyst as a for-loop wrapper: `custom batch-metrics logging`_ example.
22 | - Mix these approaches: `CV — MNIST GAN`_, `CV — MNIST VAE`_ examples.
23 | 
24 | Finally, the ``Runner`` architecture does not depend on PyTorch in any way, which provides a direction for adopting TensorFlow2 or JAX support.
25 | If you are interested in this, please check out the `Animus`_ project.
26 | 
27 | Supported Runners are listed under the `Runner API section`_.
28 | 
29 | 
30 | If you haven't found the answer to your question, feel free to `join our slack`_ for the discussion.
31 | 
32 | .. _`ML — multiclass classification`: https://github.com/catalyst-team/catalyst#minimal-examples
33 | .. _`custom batch-metrics logging`: https://github.com/catalyst-team/catalyst#minimal-examples
34 | .. _`CV — MNIST GAN`: https://github.com/catalyst-team/catalyst#minimal-examples
35 | .. _`CV — MNIST VAE`: https://github.com/catalyst-team/catalyst#minimal-examples
36 | .. _`Runner API section`: https://catalyst-team.github.io/catalyst/api/runners.html
37 | .. _`Animus`: https://github.com/Scitator/animus
38 | .. _`join our slack`: https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw
39 | 
-------------------------------------------------------------------------------- /docs/create_api.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | sphinx-apidoc -o ./_modules ../ -M -f
3 | 
4 | 
-------------------------------------------------------------------------------- /docs/faq/engines.rst: --------------------------------------------------------------------------------
1 | Engines
2 | ==============================================================================
3 | 
4 | Catalyst provides engines: an abstraction over different hardware like CPU,
5 | GPUs, or TPU cores. An engine manages all hardware-dependent operations,
6 | like initialization, checkpoint loading, saving experiment components in a DDP setup,
7 | value casting and loss scaling in AMP, etc.
8 | 
9 | Based on device availability, there are three groups:
10 | 
11 | - ``CPUEngine``, ``GPUEngine`` - run experiments using one device, such as a CPU or a single GPU.
12 | 
13 | 
14 | - ``DataParallelEngine`` -
15 | run experiments using multiple GPUs within one process.
16 | 
17 | 
18 | - ``DistributedDataParallelEngine``, ``DistributedXLAEngine`` -
19 | run experiments using multiple GPUs/TPUs within several processes.
20 | 
-------------------------------------------------------------------------------- /docs/faq/intro.rst: --------------------------------------------------------------------------------
1 | How to make X?
2 | ==============================================================================
3 | 
4 | If you haven't found the answer to your question, feel free to `join our slack`_ for the discussion.
5 | 
6 | .. 
_`join our slack`: https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw
-------------------------------------------------------------------------------- /docs/faq/settings.rst: --------------------------------------------------------------------------------
1 | Settings
2 | ==============================================================================
3 | 
4 | Catalyst extensions
5 | ----------------------------------------------------
6 | 
7 | Catalyst has a wide variety of framework extensions available if you need them.
8 | The base Catalyst package is kept as light as possible and extends only PyTorch.
9 | Nevertheless, much more is available:
10 | 
11 | .. code-block:: bash
12 | 
13 | pip install catalyst[comet] # + comet_ml
14 | pip install catalyst[cv] # + imageio, opencv, scikit-image, torchvision, Pillow
15 | pip install catalyst[deepspeed] # + deepspeed
16 | pip install catalyst[dev] # used for Catalyst development and documentation rendering
17 | pip install catalyst[ml] # + scipy, matplotlib, pandas, scikit-learn
18 | pip install catalyst[mlflow] # + mlflow
19 | pip install catalyst[neptune] # + neptune-client
20 | pip install catalyst[onnx-gpu] # + onnx, onnxruntime-gpu
21 | pip install catalyst[onnx] # + onnx, onnxruntime
22 | pip install catalyst[optuna] # + optuna
23 | pip install catalyst[profiler] # + profiler
24 | pip install catalyst[wandb] # + wandb
25 | pip install catalyst[all] # + catalyst[cv], catalyst[ml], catalyst[optuna]
26 | 
27 | 
28 | As you can see, Catalyst has a lot of extensions, and not all of them are strictly required for everyday PyTorch use, so they are extras.
29 | Please see the documentation for notes about extra requirements.
30 | 
31 | 
32 | Settings
33 | ----------------------------------------------------
34 | 
35 | To make your workflow reproducible, you could create a ``.catalyst`` file under your project or root directory so that Catalyst can resolve all required extras during framework initialization.
36 | For example:
37 | 
38 | 
39 | .. code-block:: bash
40 | 
41 | [catalyst]
42 | cv_required = false
43 | mlflow_required = false
44 | ml_required = true
45 | neptune_required = false
46 | optuna_required = false
47 | 
48 | With such a configuration file, Catalyst will raise an error if the required `catalyst[ml]` dependencies are not found.
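
For reference, the resolved flags are exposed at runtime through ``catalyst.settings.SETTINGS``.
A minimal sketch of the pattern (mirroring how ``catalyst.utils`` gates its own optional
imports; the callback import below is just one illustration):

.. code-block:: python

    from catalyst.settings import SETTINGS

    if SETTINGS.ml_required:  # True only when the catalyst[ml] extras are installed
        from catalyst.callbacks import SklearnModelCallback
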
49 | 
50 | 
51 | If you haven't found the answer to your question, feel free to `join our slack`_ for the discussion.
52 | 
53 | .. _`join our slack`: https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw
54 | 
-------------------------------------------------------------------------------- /docs/key.enc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/docs/key.enc
-------------------------------------------------------------------------------- /docs/requirements.txt: --------------------------------------------------------------------------------
1 | sphinx
2 | catalyst-sphinx-theme==1.1.2
3 | 
-------------------------------------------------------------------------------- /examples/README.md: --------------------------------------------------------------------------------
1 | # Catalyst examples
2 | 
3 | ### Notebooks
4 | 
5 | - Introduction tutorial "[Customizing what happens in `train`](./notebooks/customizing_what_happens_in_train.ipynb)" [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/customizing_what_happens_in_train.ipynb)
6 | - Demo with [customization examples](./notebooks/customization_tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/customization_tutorial.ipynb)
7 | - [Reinforcement Learning examples with Catalyst](./notebooks/reinforcement_learning.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/reinforcement_learning.ipynb)
8 | 
9 | Please follow the subfolders for more examples.
10 | 
-------------------------------------------------------------------------------- /examples/catalyst_rl/README.md: --------------------------------------------------------------------------------
1 | # Catalyst.RL 2.0 concept [experimental API]
2 | 
3 | ## Requirements
4 | ```bash
5 | sudo apt install redis-server
6 | pip install catalyst==22.02 gym==0.18.0 redis==3.5.3
7 | ```
8 | 
9 | ## Run
10 | ```bash
11 | # Redis
12 | redis-server --port 12000
13 | 
14 | # DQN
15 | python dqn.py
16 | 
17 | # DDPG
18 | python ddpg.py
19 | ```
20 | 
-------------------------------------------------------------------------------- /examples/detection/__init__.py: --------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | warnings.filterwarnings("ignore", category=RuntimeWarning)
4 | 
5 | 
6 | from catalyst.registry import Registry # noqa: E402
7 | 
8 | from .callbacks import DetectionMeanAveragePrecision # noqa: E402
9 | from .criterion import CenterNetCriterion, SSDCriterion # noqa: E402
10 | from .custom_runner import ( # noqa: E402
11 | CenterNetDetectionRunner,
12 | SSDDetectionRunner,
13 | YOLOXDetectionRunner,
14 | )
15 | from .dataset import CenterNetDataset, SSDDataset, YOLOXDataset # noqa: E402
16 | from .models import ( # noqa: E402
17 | CenterNet,
18 | SingleShotDetector,
19 | yolo_x_tiny,
20 | yolo_x_small,
21 | yolo_x_medium,
22 | yolo_x_large,
23 | yolo_x_big,
24 | )
25 | 
26 | # runners
27 | Registry(SSDDetectionRunner)
28 | Registry(CenterNetDetectionRunner)
29 | Registry(YOLOXDetectionRunner)
30 | 
31 | # models
32 | Registry(SingleShotDetector)
33 | Registry(CenterNet)
34 | Registry(yolo_x_tiny)
35 | Registry(yolo_x_small)
36 | Registry(yolo_x_medium)
37 | Registry(yolo_x_large)
38 | Registry(yolo_x_big)
39 | 
40 | # criterions
41 | Registry(SSDCriterion)
42 | Registry(CenterNetCriterion)
43 | 
44 | # callbacks
45 | Registry(DetectionMeanAveragePrecision)
46 | 
47 | # datasets
48 | Registry(SSDDataset)
49 | Registry(CenterNetDataset)
50 | Registry(YOLOXDataset)
51 | 
-------------------------------------------------------------------------------- /examples/detection/models/__init__.py: --------------------------------------------------------------------------------
1 | # flake8: noqa
2 | 
3 | from .centernet import CenterNet
4 | from .ssd import SingleShotDetector
5 | from .yolo_x import (
6 | yolo_x_tiny,
7 | yolo_x_small,
8 | yolo_x_medium,
9 | yolo_x_large,
10 | yolo_x_big,
11 | )
12 | 
-------------------------------------------------------------------------------- /examples/detection/requirements.txt: --------------------------------------------------------------------------------
1 | xmltodict
2 | mean-average-precision
3 | catalyst==21.12
-------------------------------------------------------------------------------- /examples/recsys/README.md: --------------------------------------------------------------------------------
1 | # RecSys Learning Examples
2 | 
3 | 
4 | ## Requirements
5 | 
6 | To run the examples you need `catalyst[ml]>=21.11`
7 | ```bash
8 | pip install catalyst[ml]
9 | ```
10 | 
11 | 
12 | ## Datasets
13 | 
14 | ### MovieLens 100K Dataset
15 | 
16 | [Click here for more information](https://files.grouplens.org/datasets/movielens/ml-100k-README.txt)
17 | 
18 | To make an implicit feedback dataset from an explicit one, all `ratings > 0` are replaced with `1`.
19 | The train part covers 80% of the users, the valid part the remaining 20%.
20 | 
21 | The ``collate_fn`` functions split each user's interaction tensor into two tensors.
22 | In ``collate_fn_train``, both `inputs` and `targets` contain all 100% of the interactions; in ``collate_fn_valid``, `inputs` contains only 80% of each user's interactions, while `targets` contains all 100% (see the sketch below).
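
A minimal sketch of this protocol (illustrative only, assuming each batch element is a per-user 0/1 interaction vector; the names and the 80% constant mirror the description above, not the example's actual code):

```python
import torch

def collate_fn_train(batch):
    interactions = torch.stack(batch)  # all 100% of the interactions
    return {"inputs": interactions, "targets": interactions}

def collate_fn_valid(batch, keep_ratio=0.8):
    targets = torch.stack(batch)                       # all 100% of the interactions
    inputs = torch.zeros_like(targets)
    for i, row in enumerate(targets):
        idx = row.nonzero(as_tuple=True)[0]            # indices of observed items
        keep = idx[torch.randperm(len(idx))[: int(keep_ratio * len(idx))]]
        inputs[i, keep] = 1.0                          # only 80% are fed to the model
    return {"inputs": inputs, "targets": targets}
```
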
23 | 
24 | 
25 | ## Models
26 | 
27 | ### MultiVAE
28 | 
29 | The implementation is based on the article [Variational Autoencoders for Collaborative Filtering](https://arxiv.org/pdf/1802.05814.pdf).
30 | The model is trained on the `MovieLens 100K Dataset`.
31 | 
32 | Run the example:
33 | ```bash
34 | python multivae.py
35 | ```
36 | 
37 | ### MultiDAE
38 | 
39 | The implementation is based on the article [Variational Autoencoders for Collaborative Filtering](https://arxiv.org/pdf/1802.05814.pdf).
40 | The model is trained on the `MovieLens 100K Dataset`.
41 | 
42 | Run the example:
43 | ```bash
44 | python multidae.py
45 | ```
46 | 
47 | ### MacridVAE
48 | 
49 | The implementation is based on the article [Learning Disentangled Representations for Recommendation](https://arxiv.org/pdf/1910.14238.pdf).
50 | The model is trained on the `MovieLens 100K Dataset`.
51 | 
52 | Run the example:
53 | ```bash
54 | python macridvae.py
55 | ```
-------------------------------------------------------------------------------- /examples/reinforcement_learning/README.md: --------------------------------------------------------------------------------
1 | # Requirements
2 | ```bash
3 | pip install catalyst==22.02 gym==0.18.0
4 | ```
5 | 
6 | # Run
7 | ```bash
8 | # DQN
9 | python dqn.py
10 | 
11 | # DDPG
12 | python ddpg.py
13 | 
14 | # REINFORCE
15 | python reinforce.py
16 | ```
-------------------------------------------------------------------------------- /examples/self_supervised/.dockerignore: --------------------------------------------------------------------------------
1 | # Exclude everything:
2 | *
3 | # Include useful:
4 | !barlow_twins.py
5 | !byol.py
6 | !common.py
7 | !datasets.py
8 | !simCLR.py
9 | !supervised_contrastive.py
-------------------------------------------------------------------------------- /examples/self_supervised/Dockerfile: --------------------------------------------------------------------------------
1 | FROM python:3.8
2 | 
3 | RUN pip install catalyst[cv]==22.02
4 | RUN pip install catalyst[ml]==22.02
5 | 
6 | COPY . .
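
# assumed usage (image tag and flags below are illustrative; flags follow run.sh):
#   docker build -t catalyst-ssl .
#   docker run --rm catalyst-ssl python simCLR.py --dataset CIFAR-10 --epochs 20 --verbose
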
-------------------------------------------------------------------------------- /examples/self_supervised/check.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | # pip install catalyst[cv]==21.11
4 | # pip install catalyst[ml]==21.11
5 | 
6 | export NUM_EPOCHS=20
7 | export BATCH_SIZE=256
8 | export LEARNING_RATE=0.001
9 | 
10 | for DATASET in "MNIST"; do
11 | for METHOD in "barlow_twins" "byol" "simCLR" "supervised_contrastive"; do
12 | python $METHOD.py \
13 | --dataset $DATASET \
14 | --logdir="./logs/$DATASET/$METHOD" \
15 | --batch-size=$BATCH_SIZE \
16 | --epochs=$NUM_EPOCHS \
17 | --learning-rate=$LEARNING_RATE \
18 | --verbose
19 | done
20 | done
21 | 
-------------------------------------------------------------------------------- /examples/self_supervised/run.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | # pip install catalyst[cv]==22.02
4 | # pip install catalyst[ml]==22.02
5 | 
6 | export NUM_EPOCHS=20
7 | export BATCH_SIZE=256
8 | export LEARNING_RATE=0.001
9 | 
10 | for DATASET in "CIFAR-10" "CIFAR-100" "STL10"; do
11 | for METHOD in "barlow_twins" "byol" "simCLR" "supervised_contrastive"; do
12 | python $METHOD.py \
13 | --dataset $DATASET \
14 | --logdir="./logs/$DATASET/$METHOD" \
15 | --batch-size=$BATCH_SIZE \
16 | --epochs=$NUM_EPOCHS \
17 | --learning-rate=$LEARNING_RATE \
18 | --verbose
19 | done
20 | done
21 | 
-------------------------------------------------------------------------------- /examples/self_supervised/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/examples/self_supervised/src/__init__.py
-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
1 | [tool.nitpick]
2 | style = "https://raw.githubusercontent.com/catalyst-team/codestyle/v21.09.2/styles/nitpick-style-catalyst.toml"
3 | 
4 | [tool.black]
5 | line-length = 89
6 | 
-------------------------------------------------------------------------------- /requirements/requirements-comet.txt: --------------------------------------------------------------------------------
1 | comet_ml
-------------------------------------------------------------------------------- /requirements/requirements-cv.txt: --------------------------------------------------------------------------------
1 | imageio>=2.5.0
2 | opencv-python-headless>=4.2.0.32
3 | scikit-image>=0.16.1,<0.19.0
4 | torchvision>=0.5.0
5 | Pillow>=6.1 # torchvision fix (https://github.com/python-pillow/Pillow/issues/4130)
6 | requests
7 | 
-------------------------------------------------------------------------------- /requirements/requirements-deepspeed.txt: --------------------------------------------------------------------------------
1 | deepspeed>=0.4.0
-------------------------------------------------------------------------------- /requirements/requirements-dev.txt: --------------------------------------------------------------------------------
1 | pytest
2 | sphinx==2.2.1
3 | Jinja2<=3.0.3
4 | docutils==0.17.1
5 | # sphinx==4.2.0
6 | # git+https://github.com/bitprophet/releases/#egg=releases
7 | # git+https://github.com/readthedocs/sphinx_rtd_theme
8 | mock==3.0.5
9 | catalyst-codestyle==21.09.2
10 | black==21.8b0
11 | click<=8.0.4
12 | 
catalyst-sphinx-theme==1.2.0 13 | tomlkit==0.7.2 14 | pre-commit==2.13.0 15 | path 16 | -------------------------------------------------------------------------------- /requirements/requirements-ml.txt: -------------------------------------------------------------------------------- 1 | scipy>=1.4.1 2 | matplotlib>=3.1.0 3 | pandas>=1.0.0 4 | scikit-learn>=1.0 5 | -------------------------------------------------------------------------------- /requirements/requirements-mlflow.txt: -------------------------------------------------------------------------------- 1 | mlflow -------------------------------------------------------------------------------- /requirements/requirements-neptune.txt: -------------------------------------------------------------------------------- 1 | neptune-client>=0.9.8 -------------------------------------------------------------------------------- /requirements/requirements-onnx-gpu.txt: -------------------------------------------------------------------------------- 1 | onnx 2 | onnxruntime-gpu 3 | -------------------------------------------------------------------------------- /requirements/requirements-onnx.txt: -------------------------------------------------------------------------------- 1 | onnx 2 | onnxruntime 3 | -------------------------------------------------------------------------------- /requirements/requirements-optuna.txt: -------------------------------------------------------------------------------- 1 | optuna>=2.0.0 -------------------------------------------------------------------------------- /requirements/requirements-profiler.txt: -------------------------------------------------------------------------------- 1 | torch_tb_profiler -------------------------------------------------------------------------------- /requirements/requirements-wandb.txt: -------------------------------------------------------------------------------- 1 | wandb -------------------------------------------------------------------------------- /requirements/requirements-xla.txt: -------------------------------------------------------------------------------- 1 | https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | # framework must-have 2 | numpy>=1.18 3 | torch>=1.4.0 4 | 5 | # hardware backend 6 | accelerate>=0.5.1 7 | 8 | # registry 9 | hydra-slayer>=0.4.0 10 | 11 | # progress bar 12 | tqdm>=4.33.0 13 | 14 | # for future development: 15 | # tensorboardX provides tensorboard support for any framework 16 | # it's much easier to maintain one general tb provider 17 | # rather than a bunch of framework-adapted alternatives 18 | # and it's better than `tensorboard` ;) 19 | tensorboardX>=2.1.0 20 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = D100,D104,D107,D200,D204,D205,D212,D214,D301,D400,D401,D402,D412,D413,D415,DAR003,DAR103,DAR203,RST201,RST203,RST210,RST213,RST301,RST304,E203,E731,N812,P101,W503,W504,W605 3 | max-line-length = 89 4 | max-doc-length = 89 5 | inline-quotes = double 6 | multiline-quotes = double 7 | docstring-quotes = double 8 | convention = google 9 | docstring_style = google 10 | strictness = short 11 | 12 | [darglint] 13 | ignore_regex=^_(.*) 14 | 15 | 
[isort] 16 | combine_as_imports = true 17 | order_by_type = false 18 | force_grid_wrap = 0 19 | force_sort_within_sections = true 20 | line_length = 89 21 | lines_between_types = 0 22 | multi_line_output = 3 23 | no_lines_before = STDLIB,LOCALFOLDER 24 | reverse_relative = true 25 | default_section = THIRDPARTY 26 | known_first_party = catalyst,hydra_slayer 27 | known_src = src 28 | skip_glob = **/__init__.py 29 | force_to_top = typing 30 | include_trailing_comma = true 31 | use_parentheses = true 32 | 33 | # catalyst imports order: 34 | # - typing 35 | # - core python libs 36 | # - python libs (known_third_party) 37 | # - dl libs (known_dl) 38 | # - catalyst imports 39 | sections = FUTURE,STDLIB,THIRDPARTY,DL,FIRSTPARTY,SRC,LOCALFOLDER 40 | known_dl = accelerate,tensorboardX,torch -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # Needed to collect coverage data 3 | import os 4 | import path 5 | 6 | DATA_ROOT = path.Path(os.path.abspath(os.path.dirname(__file__))).abspath() 7 | MOVIELENS_ROOT = ( 8 | path.Path(os.path.abspath(os.path.dirname(__file__))) 9 | .joinpath("./movielens") 10 | .abspath() 11 | ) 12 | MOVIELENS20M_ROOT = ( 13 | path.Path(os.path.abspath(os.path.dirname(__file__))) 14 | .joinpath("./movielens20m") 15 | .abspath() 16 | ) 17 | 18 | 19 | IS_CPU_REQUIRED = os.environ.get("CPU_REQUIRED", "0") == "1" 20 | IS_GPU_REQUIRED = os.environ.get("GPU_REQUIRED", "0") == "1" 21 | IS_GPU_AMP_REQUIRED = os.environ.get("GPU_AMP_REQUIRED", "0") == "1" 22 | IS_DP_REQUIRED = os.environ.get("DP_REQUIRED", "0") == "1" 23 | IS_DP_AMP_REQUIRED = os.environ.get("DP_AMP_REQUIRED", "0") == "1" 24 | IS_DDP_REQUIRED = os.environ.get("DDP_REQUIRED", "0") == "1" 25 | IS_DDP_AMP_REQUIRED = os.environ.get("DDP_AMP_REQUIRED", "0") == "1" 26 | IS_CONFIGS_REQUIRED = os.environ.get("CONFIGS_REQUIRED", "0") == "1" 27 | -------------------------------------------------------------------------------- /tests/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | # Needed to collect coverage data 2 | -------------------------------------------------------------------------------- /tests/catalyst/__init__.py: -------------------------------------------------------------------------------- 1 | # Needed to collect coverage data 2 | -------------------------------------------------------------------------------- /tests/catalyst/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/catalyst/callbacks/__init__.py -------------------------------------------------------------------------------- /tests/catalyst/callbacks/test_batch_overfit.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import torch 3 | from torch.utils.data import DataLoader, TensorDataset 4 | 5 | from catalyst import dl, utils 6 | 7 | 8 | class BatchOverfitCallbackCheck(dl.Callback): 9 | def __init__(self): 10 | super().__init__(order=dl.CallbackOrder.external) 11 | 12 | def on_loader_start(self, runner): 13 | # 320 samples with 32 batch size 14 | # -> 1 batch size = 32 15 | # -> 0.1 portion = 32 16 | assert len(runner.loaders[runner.loader_key]) == 32 17 | 18 | 19 | def _prepare_experiment(): 20 | # data 21 | 
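
# fix the global seed so the synthetic data (and the loss thresholds below) stay reproducible
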
utils.set_global_seed(42)
22 | num_samples, num_features = int(32e1), int(1e1)
23 | X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
24 | dataset = TensorDataset(X, y)
25 | loader = DataLoader(dataset, batch_size=32, num_workers=0)
26 | loaders = {"train": loader, "valid": loader}
27 | 
28 | # model, criterion, optimizer, scheduler
29 | model = torch.nn.Linear(num_features, 1)
30 | criterion = torch.nn.MSELoss()
31 | optimizer = torch.optim.Adam(model.parameters())
32 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])
33 | return loaders, model, criterion, optimizer, scheduler
34 | 
35 | 
36 | def test_batch_overfit():
37 | loaders, model, criterion, optimizer, scheduler = _prepare_experiment()
38 | runner = dl.SupervisedRunner()
39 | runner.train(
40 | model=model,
41 | criterion=criterion,
42 | optimizer=optimizer,
43 | scheduler=scheduler,
44 | loaders=loaders,
45 | logdir="./logs/batch_overfit",
46 | num_epochs=1,
47 | verbose=False,
48 | callbacks=[dl.BatchOverfitCallback(train=1, valid=0.1)],
49 | )
50 | assert runner.epoch_metrics["train"]["loss"] < 1.4
51 | assert runner.epoch_metrics["valid"]["loss"] < 1.3
52 | 
-------------------------------------------------------------------------------- /tests/catalyst/callbacks/test_misc.py: --------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from unittest.mock import MagicMock, PropertyMock
3 | 
4 | from catalyst.callbacks import EarlyStoppingCallback
5 | 
6 | 
7 | def test_patience1():
8 | """Tests EarlyStoppingCallback."""
9 | early_stop = EarlyStoppingCallback(
10 | patience=1, loader_key="valid", metric_key="loss", minimize=True
11 | )
12 | runner = MagicMock()
13 | type(runner).epoch_metrics = PropertyMock(return_value={"valid": {"loss": 0.001}})
14 | stop_mock = PropertyMock(return_value=False)
15 | type(runner).need_early_stop = stop_mock
16 | 
17 | early_stop.on_epoch_end(runner)
18 | 
19 | assert stop_mock.mock_calls == []
20 | 
-------------------------------------------------------------------------------- /tests/catalyst/callbacks/test_mixup.py: --------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from typing import Tuple
3 | 
4 | import torch
5 | from torch.utils.data import DataLoader, TensorDataset
6 | 
7 | from catalyst import dl, utils
8 | from catalyst.callbacks import MixupCallback
9 | 
10 | 
11 | class DummyRunner(dl.Runner):
12 | def handle_batch(self, batch: Tuple[torch.Tensor]):
13 | self.batch = {"image": batch[0], "clf_targets_one_hot": batch[1]}
14 | 
15 | 
16 | def test_mixup_1():
17 | utils.set_global_seed(42)
18 | num_samples, num_features, num_classes = int(1e4), int(1e1), 4
19 | X = torch.rand(num_samples, num_features)
20 | y = (torch.rand(num_samples) * num_classes).to(torch.int64)
21 | y = torch.nn.functional.one_hot(y, num_classes).double()
22 | dataset = TensorDataset(X, y)
23 | loader = DataLoader(dataset, batch_size=32, num_workers=1)
24 | loaders = {"train": loader, "valid": loader}
25 | runner = DummyRunner()
26 | callback = MixupCallback(keys=["image", "clf_targets_one_hot"])
27 | for loader_name in ["train", "valid"]:
28 | for batch in loaders[loader_name]:
29 | runner.handle_batch(batch)
30 | callback.on_batch_start(runner)
31 | assert runner.batch["clf_targets_one_hot"].max(1)[0].mean() < 1
32 | 
-------------------------------------------------------------------------------- /tests/catalyst/callbacks/test_optimizer.py: 
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | import torch
3 | import torch.nn as nn
4 | 
5 | from catalyst.callbacks import backward
6 | from catalyst.callbacks.backward import BackwardCallback
7 | from catalyst.callbacks.optimizer import OptimizerCallback
8 | from catalyst.engines.torch import CPUEngine
9 | 
10 | 
11 | class DummyRunner:
12 | def __init__(
13 | self,
14 | loss_value: torch.Tensor,
15 | model: nn.Module,
16 | optimizer: torch.optim.Optimizer,
17 | ):
18 | self.batch_metrics = {"loss": loss_value}
19 | self.is_train_loader = True
20 | self.model = model
21 | self.optimizer = optimizer
22 | # self.device = torch.device("cpu")
23 | self.engine = CPUEngine()
24 | 
25 | 
26 | def test_zero_grad():
27 | model = nn.Linear(10, 2)
28 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
29 | criterion = nn.BCEWithLogitsLoss()
30 | 
31 | batch_size = 3
32 | inp = torch.randn(batch_size, 10)
33 | target = torch.FloatTensor(batch_size, 2).uniform_()
34 | 
35 | backward_callback = BackwardCallback(metric_key="loss")
36 | optimizer_callback = OptimizerCallback(metric_key="loss")
37 | 
38 | loss1 = criterion(model(inp), target)
39 | loss1_value = loss1.detach().item()
40 | 
41 | runner = DummyRunner(loss1, model, optimizer)
42 | 
43 | for clb in [backward_callback, optimizer_callback]:
44 | clb.on_experiment_start(runner)
45 | clb.on_epoch_start(runner)
46 | clb.on_batch_end(runner)
47 | 
48 | loss2 = criterion(model(inp), target)
49 | loss2_value = loss2.detach().item()
50 | 
51 | runner.batch_metrics = {"loss": loss2}
52 | for clb in [backward_callback, optimizer_callback]:
53 | clb.on_epoch_start(runner)
54 | clb.on_batch_end(runner)
55 | 
56 | assert loss1_value > loss2_value
57 | 
-------------------------------------------------------------------------------- /tests/catalyst/callbacks/test_optuna.py: --------------------------------------------------------------------------------
1 | # flake8: noqa
2 | 
3 | import pytest
4 | 
5 | import torch
6 | from torch import nn
7 | from torch.utils.data import DataLoader
8 | 
9 | from catalyst import dl
10 | from catalyst.contrib.datasets import MNIST
11 | from catalyst.settings import SETTINGS
12 | from tests import DATA_ROOT
13 | 
14 | if SETTINGS.optuna_required:
15 | import optuna
16 | 
17 | 
18 | @pytest.mark.skipif(not (SETTINGS.optuna_required), reason="No optuna required")
19 | def test_optuna():
20 | trainset = MNIST(DATA_ROOT, train=False)
21 | testset = MNIST(DATA_ROOT, train=False)
22 | loaders = {
23 | "train": DataLoader(trainset, batch_size=32),
24 | "valid": DataLoader(testset, batch_size=64),
25 | }
26 | model = nn.Sequential(
27 | nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)
28 | )
29 | 
30 | def objective(trial):
31 | lr = trial.suggest_loguniform("lr", 1e-3, 1e-1)
32 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
33 | criterion = nn.CrossEntropyLoss()
34 | runner = dl.SupervisedRunner()
35 | runner.train(
36 | model=model,
37 | loaders=loaders,
38 | criterion=criterion,
39 | optimizer=optimizer,
40 | callbacks={
41 | "optuna": dl.OptunaPruningCallback(
42 | loader_key="valid", metric_key="loss", minimize=True, trial=trial
43 | ),
44 | "accuracy": dl.AccuracyCallback(
45 | num_classes=10, input_key="logits", target_key="targets"
46 | ),
47 | },
48 | num_epochs=2,
49 | valid_metric="accuracy01",
50 | minimize_valid_metric=False,
51 | )
52 | return trial.best_score
53 | 
54 | study = optuna.create_study(
55 | direction="maximize",
56 | 
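
# the median pruner stops a trial whose best intermediate result is worse than
# the median of previous trials' results at the same step
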
pruner=optuna.pruners.MedianPruner(
57 | n_startup_trials=1, n_warmup_steps=0, interval_steps=1
58 | ),
59 | )
60 | study.optimize(objective, n_trials=2, timeout=300)
61 | assert True
62 | 
-------------------------------------------------------------------------------- /tests/catalyst/callbacks/test_quantization.py: --------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # import os
3 | 
4 | # import torch
5 | # from torch import nn
6 | # from torch.utils.data import DataLoader
7 | 
8 | # from catalyst import dl
9 | # from catalyst.contrib.datasets import MNIST
10 | # from catalyst.contrib.layers import Flatten
11 | 
12 | 
13 | # def test_pruning_callback() -> None:
14 | # """Quantize model"""
15 | # loaders = {
16 | # "train": DataLoader(MNIST(os.getcwd(), train=False), batch_size=32,),
17 | # "valid": DataLoader(MNIST(os.getcwd(), train=False), batch_size=32,),
18 | # }
19 | # model = nn.Sequential(Flatten(), nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10))
20 | # criterion = nn.CrossEntropyLoss()
21 | # optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
22 | # runner = dl.SupervisedRunner()
23 | # runner.train(
24 | # model=model,
25 | # callbacks=[dl.QuantizationCallback(logdir="./logs")],
26 | # loaders=loaders,
27 | # criterion=criterion,
28 | # optimizer=optimizer,
29 | # num_epochs=1,
30 | # logdir="./logs",
31 | # check=True,
32 | # )
33 | # assert os.path.isfile("./logs/quantized.pth")
34 | 
-------------------------------------------------------------------------------- /tests/catalyst/callbacks/test_sklearn_model.py: --------------------------------------------------------------------------------
1 | # flake8: noqa
2 | 
3 | import pytest
4 | 
5 | from catalyst.settings import SETTINGS
6 | 
7 | if SETTINGS.ml_required:
8 | from catalyst.callbacks import SklearnModelCallback
9 | 
10 | 
11 | @pytest.mark.skipif(not (SETTINGS.ml_required), reason="catalyst[ml] required")
12 | def test_init_from_str():
13 | 
14 | paths = [
15 | "sklearn.ensemble.RandomForestClassifier",
16 | "sklearn.linear_model.LogisticRegression",
17 | "sklearn.cluster.KMeans",
18 | ]
19 | 
20 | for fn in paths:
21 | SklearnModelCallback(
22 | feature_key="feature_key",
23 | target_key="target_key",
24 | train_loader="train",
25 | valid_loaders="valid_loader",
26 | model_fn=fn,
27 | )
28 | 
29 | paths_with_transform = ["sklearn.cluster.KMeans", "sklearn.decomposition.PCA"]
30 | 
31 | for fn in paths_with_transform:
32 | SklearnModelCallback(
33 | feature_key="feature_key",
34 | target_key="target_key",
35 | train_loader="train",
36 | valid_loaders="valid_loader",
37 | model_fn=fn,
38 | predict_method="transform",
39 | )
40 | 
-------------------------------------------------------------------------------- /tests/catalyst/callbacks/test_soft_update.py: --------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from typing import Tuple
3 | 
4 | import torch
5 | from torch import nn
6 | 
7 | from catalyst import dl
8 | 
9 | 
10 | class DummyRunner(dl.Runner):
11 | def handle_batch(self, batch: Tuple[torch.Tensor]):
12 | pass
13 | 
14 | 
15 | def set_requires_grad(model, val):
16 | for p in model.parameters():
17 | p.requires_grad = val
18 | 
19 | 
20 | def test_soft_update():
21 | 
22 | model = nn.ModuleDict(
23 | {
24 | "target": nn.Linear(10, 10, bias=False),
25 | "source": nn.Linear(10, 10, bias=False),
26 | }
27 | )
28 | set_requires_grad(model, False)
29 | model["target"].weight.data.fill_(0)
30 | 
31 | runner = 
DummyRunner(model=model) 32 | runner.is_train_loader = True 33 | 34 | soft_update = dl.SoftUpdateCallaback( 35 | target_model="target", 36 | source_model="source", 37 | tau=0.1, 38 | scope="on_batch_end", 39 | ) 40 | soft_update.on_batch_end(runner) 41 | 42 | checks = ( 43 | ( 44 | (0.1 * runner.model["source"].weight.data) 45 | == runner.model["target"].weight.data 46 | ) 47 | .flatten() 48 | .tolist() 49 | ) 50 | 51 | assert all(checks) 52 | 53 | 54 | def test_soft_update_not_work(): 55 | 56 | model = nn.ModuleDict( 57 | { 58 | "target": nn.Linear(10, 10, bias=False), 59 | "source": nn.Linear(10, 10, bias=False), 60 | } 61 | ) 62 | set_requires_grad(model, False) 63 | model["target"].weight.data.fill_(0) 64 | 65 | runner = DummyRunner(model=model) 66 | runner.is_train_loader = True 67 | 68 | soft_update = dl.SoftUpdateCallaback( 69 | target_model="target", 70 | source_model="source", 71 | tau=0.1, 72 | scope="on_batch_start", 73 | ) 74 | soft_update.on_batch_end(runner) 75 | 76 | checks = (runner.model["target"].weight.data == 0).flatten().tolist() 77 | 78 | assert all(checks) 79 | -------------------------------------------------------------------------------- /tests/catalyst/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/catalyst/data/__init__.py -------------------------------------------------------------------------------- /tests/catalyst/data/test_loader.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import torch 3 | from torch.utils.data import DataLoader, TensorDataset 4 | 5 | from catalyst.data.loader import BatchLimitLoaderWrapper 6 | 7 | 8 | def test_batch_limit1() -> None: 9 | for shuffle in (False, True): 10 | num_samples, num_features = int(1e2), int(1e1) 11 | X, y = torch.rand(num_samples, num_features), torch.rand(num_samples) 12 | dataset = TensorDataset(X, y) 13 | loader = DataLoader(dataset, batch_size=4, num_workers=1, shuffle=shuffle) 14 | loader = BatchLimitLoaderWrapper(loader, num_batches=1) 15 | 16 | batch1 = next(iter(loader))[0] 17 | batch2 = next(iter(loader))[0] 18 | batch3 = next(iter(loader))[0] 19 | assert all(torch.isclose(x, y).all() for x, y in zip(batch1, batch2)) 20 | assert all(torch.isclose(x, y).all() for x, y in zip(batch2, batch3)) 21 | 22 | 23 | def test_batch_limit2() -> None: 24 | for shuffle in (False, True): 25 | num_samples, num_features = int(1e2), int(1e1) 26 | X, y = torch.rand(num_samples, num_features), torch.rand(num_samples) 27 | dataset = TensorDataset(X, y) 28 | loader = DataLoader(dataset, batch_size=4, num_workers=1, shuffle=shuffle) 29 | loader = BatchLimitLoaderWrapper(loader, num_batches=2) 30 | 31 | batch1 = next(iter(loader))[0] 32 | batch2 = next(iter(loader))[0] 33 | batch3 = next(iter(loader))[0] 34 | batch4 = next(iter(loader))[0] 35 | assert all(torch.isclose(x, y).all() for x, y in zip(batch1, batch3)) 36 | assert all(torch.isclose(x, y).all() for x, y in zip(batch2, batch4)) 37 | -------------------------------------------------------------------------------- /tests/catalyst/dl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/catalyst/dl/__init__.py -------------------------------------------------------------------------------- 
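Note on the `test_batch_limit*` tests above: `BatchLimitLoaderWrapper` restricts a loader to `num_batches` batches, and the assertions verify that fresh iterations keep yielding those same batches even when the underlying loader shuffles. A minimal self-contained sketch of that usage (variable names here are illustrative, not taken from the test suite):

import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data.loader import BatchLimitLoaderWrapper

# wrap a shuffled loader so each epoch reuses only its first batch --
# handy for quick overfitting checks and smoke tests
X, y = torch.rand(100, 10), torch.rand(100)
base_loader = DataLoader(TensorDataset(X, y), batch_size=4, shuffle=True)
limited_loader = BatchLimitLoaderWrapper(base_loader, num_batches=1)

batch_a = next(iter(limited_loader))[0]
batch_b = next(iter(limited_loader))[0]
# mirrors test_batch_limit1: the same batch comes back on every iteration
assert torch.isclose(batch_a, batch_b).all()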
/tests/catalyst/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/catalyst/metrics/__init__.py -------------------------------------------------------------------------------- /tests/catalyst/metrics/functional/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/catalyst/metrics/functional/__init__.py -------------------------------------------------------------------------------- /tests/catalyst/metrics/functional/test_auc.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import math 3 | 4 | import torch 5 | 6 | from catalyst.metrics.functional._auc import auc 7 | 8 | 9 | def test_auc(): 10 | """ 11 | Tests for catalyst.metrics.auc metric. 12 | """ 13 | test_size = 1000 14 | scores = torch.cat((torch.rand(test_size), torch.rand(test_size))) 15 | targets = torch.cat((torch.zeros(test_size), torch.ones(test_size))) 16 | 17 | val = auc(scores, targets) 18 | assert math.fabs(val - 0.5) < 0.1, "AUC test1 failed" 19 | 20 | scores = torch.cat( 21 | ( 22 | torch.Tensor(test_size).fill_(0), 23 | torch.Tensor(test_size).fill_(0.1), 24 | torch.Tensor(test_size).fill_(0.2), 25 | torch.Tensor(test_size).fill_(0.3), 26 | torch.Tensor(test_size).fill_(0.4), 27 | torch.ones(test_size), 28 | ) 29 | ) 30 | targets = torch.cat( 31 | ( 32 | torch.zeros(test_size), 33 | torch.zeros(test_size), 34 | torch.zeros(test_size), 35 | torch.zeros(test_size), 36 | torch.zeros(test_size), 37 | torch.ones(test_size), 38 | ) 39 | ) 40 | 41 | val = auc(scores, targets) 42 | assert math.fabs(val - 1.0) < 0.0001, "AUC test2 failed" 43 | -------------------------------------------------------------------------------- /tests/catalyst/metrics/functional/test_dice.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import torch 3 | 4 | from catalyst.metrics.functional import dice 5 | 6 | 7 | def test_dice(): 8 | """ 9 | Tests for catalyst.metrics.dice metric. 
10 | """ 11 | size = 4 12 | half_size = size // 2 13 | shape = (1, 1, size, size) 14 | 15 | # check 0: one empty 16 | empty = torch.zeros(shape) 17 | full = torch.ones(shape) 18 | assert dice(empty, full, class_dim=1, mode="per-class").item() == 0 19 | 20 | # check 0: no overlap 21 | left = torch.ones(shape) 22 | left[:, :, :, half_size:] = 0 23 | right = torch.ones(shape) 24 | right[:, :, :, :half_size] = 0 25 | assert dice(left, right, class_dim=1, mode="per-class").item() == 0 26 | 27 | # check 1: both empty, both full, complete overlap 28 | assert dice(empty, empty, class_dim=1, mode="per-class").item() == 1 29 | assert dice(full, full, class_dim=1, mode="per-class").item() == 1 30 | assert dice(left, left, class_dim=1, mode="per-class").item() == 1 31 | 32 | # check 0.5: half overlap 33 | top_left = torch.zeros(shape) 34 | top_left[:, :, :half_size, :half_size] = 1 35 | assert torch.isclose( 36 | dice(top_left, left, class_dim=1, mode="per-class"), torch.Tensor([[0.6666666]]) 37 | ) 38 | assert torch.isclose( 39 | dice(top_left, left, class_dim=1, mode="micro"), torch.Tensor([[0.6666666]]) 40 | ) 41 | assert torch.isclose( 42 | dice(top_left, left, class_dim=1, mode="macro"), torch.Tensor([[0.6666666]]) 43 | ) 44 | 45 | # check multiclass: 0, 0, 1, 1, 1, 0.66667 46 | a = torch.cat([empty, left, empty, full, left, top_left], dim=1) 47 | b = torch.cat([full, right, empty, full, left, left], dim=1) 48 | ans = torch.Tensor([0, 0, 1, 1, 1, 0.6666666]) 49 | ans_micro = torch.tensor(0.6087) 50 | assert torch.allclose(dice(a, b, class_dim=1, mode="per-class"), ans) 51 | assert torch.allclose(dice(a, b, class_dim=1, mode="micro"), ans_micro) 52 | 53 | aaa = torch.cat([a, a, a], dim=0) 54 | bbb = torch.cat([b, b, b], dim=0) 55 | assert torch.allclose(dice(aaa, bbb, class_dim=1, mode="per-class"), ans) 56 | assert torch.allclose(dice(aaa, bbb, class_dim=1, mode="micro"), ans_micro) 57 | -------------------------------------------------------------------------------- /tests/catalyst/metrics/functional/test_fbeta_precision_recall.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import pytest # noqa: F401 3 | 4 | import torch 5 | 6 | from catalyst.metrics.functional import ( 7 | f1_score, 8 | fbeta_score, 9 | precision, 10 | precision_recall_fbeta_support, 11 | recall, 12 | ) 13 | 14 | 15 | def test_precision_recall_f_binary_single_class() -> None: 16 | """Metrics test""" 17 | # Test precision, recall and F-scores behave with a single positive 18 | assert 1.0 == precision([1, 1], [1, 1])[1] 19 | assert 1.0 == recall([1, 1], [1, 1])[1] 20 | assert 1.0 == f1_score([1, 1], [1, 1])[1] 21 | assert 1.0 == fbeta_score([1, 1], [1, 1], 0)[1] 22 | # test with several classes 23 | assert 3.0 == f1_score([0, 1, 2], [0, 1, 2]).sum().item() 24 | assert 3.0 == precision([0, 1, 2], [0, 1, 2]).sum().item() 25 | assert 3.0 == recall([0, 1, 2], [0, 1, 2]).sum().item() 26 | 27 | 28 | @pytest.mark.parametrize( 29 | [ 30 | "outputs", 31 | "targets", 32 | "precision_true", 33 | "recall_true", 34 | "fbeta_true", 35 | "support_true", 36 | ], 37 | [ 38 | pytest.param( 39 | torch.tensor([[0, 0, 1, 1, 0, 1, 0, 1]]), 40 | torch.tensor([[0, 1, 0, 1, 0, 0, 1, 1]]), 41 | 0.5, 42 | 0.5, 43 | 0.5, 44 | 4, 45 | ), 46 | ], 47 | ) 48 | def test_precision_recall_fbeta_support_binary( 49 | outputs, targets, precision_true, recall_true, fbeta_true, support_true 50 | ) -> None: 51 | """ 52 | Test for precision_recall_fbeta_support. 
53 | 54 | Args: 55 | outputs: test arg 56 | targets: test arg 57 | precision_true: test arg 58 | recall_true: test arg 59 | fbeta_true: test arg 60 | support_true: test arg 61 | """ 62 | ( 63 | precision_score, 64 | recall_score, 65 | fbeta_score_value, 66 | support, 67 | ) = precision_recall_fbeta_support(outputs=outputs, targets=targets) 68 | 69 | assert torch.isclose(precision_score[1], torch.tensor(precision_true)) 70 | assert torch.isclose(recall_score[1], torch.tensor(recall_true)) 71 | assert torch.isclose(fbeta_score_value[1], torch.tensor(fbeta_true)) 72 | assert support[1] == support_true 73 | -------------------------------------------------------------------------------- /tests/catalyst/metrics/functional/test_hitrate.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import numpy as np 3 | 4 | import torch 5 | 6 | from catalyst.metrics.functional._hitrate import hitrate 7 | 8 | 9 | def test_hitrate(): 10 | """ 11 | Tests for catalyst.metrics.hitrate metric. 12 | """ 13 | y_pred = [0.5, 0.2, 0.1] 14 | y_true = [1.0, 0.0, 1.0] 15 | k = [1, 2, 3] 16 | 17 | hitrate_at1, hitrate_at2, hitrate_at3 = hitrate( 18 | torch.Tensor([y_pred]), torch.Tensor([y_true]), k 19 | ) 20 | assert hitrate_at1 == 0.5 21 | assert hitrate_at2 == 0.5 22 | assert hitrate_at3 == 1.0 23 | 24 | # check 1 simple case 25 | y_pred = [0.5, 0.2] 26 | y_true = [0.0, 0.0] 27 | k = [2] 28 | 29 | hitrate_at2 = hitrate(torch.Tensor([y_pred]), torch.Tensor([y_true]), k)[0] 30 | assert hitrate_at2 == 0.0 31 | 32 | # check batch case 33 | y_pred1 = [4.0, 2.0, 3.0, 1.0] 34 | y_pred2 = [1.0, 2.0, 3.0, 4.0] 35 | y_true1 = [0, 0, 1.0, 1.0] 36 | y_true2 = [0, 0, 0.0, 0.0] 37 | k = [1, 2, 3, 4] 38 | 39 | y_pred_torch = torch.Tensor([y_pred1, y_pred2]) 40 | y_true_torch = torch.Tensor([y_true1, y_true2]) 41 | 42 | hitrate_at1, hitrate_at2, hitrate_at3, hitrate_at4 = hitrate( 43 | y_pred_torch, y_true_torch, k 44 | ) 45 | 46 | assert hitrate_at1 == 0.0 47 | assert hitrate_at2 == 0.25 48 | assert hitrate_at3 == 0.25 49 | assert hitrate_at4 == 0.5 50 | -------------------------------------------------------------------------------- /tests/catalyst/metrics/functional/test_iou.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import torch 3 | 4 | from catalyst.metrics.functional import iou 5 | 6 | 7 | def test_iou(): 8 | """ 9 | Tests for catalyst.metrics.iou metric. 
10 | """ 11 | size = 4 12 | half_size = size // 2 13 | shape = (1, 1, size, size) 14 | 15 | # check 0: one empty 16 | empty = torch.zeros(shape) 17 | full = torch.ones(shape) 18 | assert iou(empty, full, class_dim=1, mode="per-class").item() == 0 19 | 20 | # check 0: no overlap 21 | left = torch.ones(shape) 22 | left[:, :, :, half_size:] = 0 23 | right = torch.ones(shape) 24 | right[:, :, :, :half_size] = 0 25 | assert iou(left, right, class_dim=1, mode="per-class").item() == 0 26 | 27 | # check 1: both empty, both full, complete overlap 28 | assert iou(empty, empty, class_dim=1, mode="per-class").item() == 1 29 | assert iou(full, full, class_dim=1, mode="per-class").item() == 1 30 | assert iou(left, left, class_dim=1, mode="per-class").item() == 1 31 | 32 | # check 0.5: half overlap 33 | top_left = torch.zeros(shape) 34 | top_left[:, :, :half_size, :half_size] = 1 35 | assert torch.isclose( 36 | iou(top_left, left, class_dim=1, mode="per-class"), torch.Tensor([[0.5]]) 37 | ) 38 | assert torch.isclose( 39 | iou(top_left, left, class_dim=1, mode="micro"), torch.Tensor([[0.5]]) 40 | ) 41 | assert torch.isclose( 42 | iou(top_left, left, class_dim=1, mode="macro"), torch.Tensor([[0.5]]) 43 | ) 44 | 45 | # check multiclass: 0, 0, 1, 1, 1, 0.5 46 | a = torch.cat([empty, left, empty, full, left, top_left], dim=1) 47 | b = torch.cat([full, right, empty, full, left, left], dim=1) 48 | ans = torch.Tensor([0, 0, 1, 1, 1, 0.5]) 49 | ans_micro = torch.tensor(0.4375) 50 | assert torch.allclose(iou(a, b, class_dim=1, mode="per-class"), ans) 51 | assert torch.allclose(iou(a, b, class_dim=1, mode="micro"), ans_micro) 52 | 53 | aaa = torch.cat([a, a, a], dim=0) 54 | bbb = torch.cat([b, b, b], dim=0) 55 | assert torch.allclose(iou(aaa, bbb, class_dim=1, mode="per-class"), ans) 56 | assert torch.allclose(iou(aaa, bbb, class_dim=1, mode="micro"), ans_micro) 57 | -------------------------------------------------------------------------------- /tests/catalyst/metrics/functional/test_r2_squared.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import numpy as np 3 | 4 | import torch 5 | 6 | from catalyst.metrics.functional._r2_squared import r2_squared 7 | 8 | 9 | def test_r2_squared(): 10 | """ 11 | Tests for catalyst.metrics.r2_squared metric. 
12 | """ 13 | y_true = torch.tensor([3, -0.5, 2, 7]) 14 | y_pred = torch.tensor([2.5, 0.0, 2, 8]) 15 | val = r2_squared(y_pred, y_true) 16 | assert torch.isclose(val, torch.Tensor([0.9486])) 17 | -------------------------------------------------------------------------------- /tests/catalyst/metrics/test_r2squared.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from typing import Dict, Iterable, Union 3 | 4 | import pytest 5 | 6 | import torch 7 | 8 | from catalyst.metrics._r2_squared import R2Squared 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "outputs,targets,true_values", 13 | ( 14 | ( 15 | torch.Tensor([2.5, 0.0, 2, 8]), 16 | torch.Tensor([3, -0.5, 2, 7]), 17 | {"r2squared": torch.Tensor([0.9486])}, 18 | ), 19 | ), 20 | ) 21 | def test_r2_squared( 22 | outputs: torch.Tensor, 23 | targets: torch.Tensor, 24 | true_values: Dict[str, torch.Tensor], 25 | ) -> None: 26 | """ 27 | Test r2 squared metric 28 | 29 | Args: 30 | outputs: tensor of outputs 31 | targets: tensor of targets 32 | true_values: true metric values 33 | """ 34 | metric = R2Squared() 35 | metric.update(y_pred=outputs, y_true=targets) 36 | metrics = metric.compute_key_value() 37 | for key in true_values.keys(): 38 | assert torch.isclose(true_values[key], metrics[key]) 39 | 40 | 41 | @pytest.mark.parametrize( 42 | "outputs_list,targets_list,true_values", 43 | ( 44 | ( 45 | ( 46 | torch.Tensor([2.5, 0.0, 2, 8]), 47 | torch.Tensor([2.5, 0.0, 2, 8]), 48 | torch.Tensor([2.5, 0.0, 2, 8]), 49 | torch.Tensor([2.5, 0.0, 2, 8]), 50 | ), 51 | ( 52 | torch.Tensor([3, -0.5, 2, 7]), 53 | torch.Tensor([3, -0.5, 2, 7]), 54 | torch.Tensor([3, -0.5, 2, 7]), 55 | torch.Tensor([3, -0.5, 2, 7]), 56 | ), 57 | {"r2squared": torch.Tensor([0.9486])}, 58 | ), 59 | ), 60 | ) 61 | def test_r2_squared_update( 62 | outputs_list: Iterable[torch.Tensor], 63 | targets_list: Iterable[torch.Tensor], 64 | true_values: Dict[str, torch.Tensor], 65 | ): 66 | """ 67 | Test r2 squared metric computation 68 | 69 | Args: 70 | outputs_list: list of outputs 71 | targets_list: list of targets 72 | true_values: true metric values 73 | """ 74 | metric = R2Squared() 75 | for outputs, targets in zip(outputs_list, targets_list): 76 | metric.update(y_pred=outputs, y_true=targets) 77 | metrics = metric.compute_key_value() 78 | for key in true_values.keys(): 79 | assert torch.isclose(true_values[key], metrics[key]) 80 | -------------------------------------------------------------------------------- /tests/catalyst/runners/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/catalyst/runners/__init__.py -------------------------------------------------------------------------------- /tests/catalyst/runners/test_profiler.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | import os 4 | 5 | from pytest import mark 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | 11 | from catalyst import dl 12 | from catalyst.contrib.datasets import MNIST 13 | from tests import DATA_ROOT 14 | 15 | 16 | def _is_profile_available(): 17 | try: 18 | from torch import profiler # noqa: F401 19 | 20 | return True 21 | except ImportError: 22 | return False 23 | 24 | 25 | def train_experiment(): 26 | loaders = { 27 | "train": DataLoader( 28 | MNIST(DATA_ROOT, train=False), 29 | batch_size=32, 30 | 
), 31 | "valid": DataLoader( 32 | MNIST(DATA_ROOT, train=False), 33 | batch_size=32, 34 | ), 35 | } 36 | model = nn.Sequential( 37 | nn.Flatten(), nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10) 38 | ) 39 | criterion = nn.CrossEntropyLoss() 40 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) 41 | runner = dl.SupervisedRunner() 42 | runner.train( 43 | model=model, 44 | loaders=loaders, 45 | criterion=criterion, 46 | optimizer=optimizer, 47 | num_epochs=5, 48 | logdir="./logs", 49 | profile=True, 50 | ) 51 | 52 | 53 | @mark.skipif(not _is_profile_available(), reason="Torch profiler is not available") 54 | def test_profiler(): 55 | train_experiment() 56 | assert ( 57 | os.path.isdir("./logs/tb_profile") 58 | and not len(os.listdir("./logs/tb_profile")) == 0 59 | ) 60 | -------------------------------------------------------------------------------- /tests/catalyst/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/catalyst/utils/__init__.py -------------------------------------------------------------------------------- /tests/catalyst/utils/test_config.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import io 3 | import json 4 | 5 | import numpy as np 6 | 7 | # from catalyst.dl.scripts._misc import parse_config_args 8 | from catalyst.utils.config import _load_ordered_yaml 9 | 10 | # def test_parse_config_args(): 11 | # configuration = { 12 | # "stages": {"one": "uno", "two": "dos", "three": "tres"}, 13 | # "key": {"value": "key2"}, 14 | # } 15 | 16 | # parser = argparse.ArgumentParser() 17 | # parser.add_argument("--command") 18 | 19 | # args, uargs = parser.parse_known_args( 20 | # [ 21 | # "--command", 22 | # "run", 23 | # "--path=test.yml:str", 24 | # "--stages/zero=cero:str", 25 | # "-C=like:str", 26 | # ] 27 | # ) 28 | 29 | # configuration, args = parse_config_args( 30 | # config=configuration, args=args, unknown_args=uargs 31 | # ) 32 | 33 | # assert args.command == "run" 34 | # assert args.path == "test.yml" 35 | # assert configuration.get("stages") is not None 36 | # assert "zero" in configuration["stages"] 37 | # assert configuration["stages"]["zero"] == "cero" 38 | # assert configuration.get("args") is not None 39 | # assert configuration["args"]["path"] == "test.yml" 40 | # assert configuration["args"]["C"] == "like" 41 | # assert configuration["args"]["command"] == "run" 42 | 43 | # for key, value in args._get_kwargs(): 44 | # v = configuration["args"].get(key) 45 | # assert v is not None 46 | # assert v == value 47 | 48 | 49 | def test_parse_numbers(): 50 | configuration = { 51 | "a": 1, 52 | "b": 20, 53 | "c": 303e5, 54 | "d": -4, 55 | "e": -50, 56 | "f": -666e7, 57 | "g": 0.35, 58 | "h": 7.35e-5, 59 | "k": 8e-10, 60 | } 61 | 62 | buffer = io.StringIO() 63 | json.dump(configuration, buffer) 64 | buffer.seek(0) 65 | yaml_config = _load_ordered_yaml(buffer) 66 | 67 | for key, item in configuration.items(): 68 | assert np.isclose(yaml_config[key], item) 69 | -------------------------------------------------------------------------------- /tests/catalyst/utils/test_hash.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from catalyst.utils.misc import get_hash 3 | 4 | 5 | def test_hash(): 6 | """Docs? 
Contribution is welcome.""" 7 | a = get_hash({"a": "foo"}) 8 | print(a) 9 | -------------------------------------------------------------------------------- /tests/catalyst/utils/test_onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from catalyst.utils.onnx import onnx_export 7 | 8 | 9 | def test_api(): 10 | """Tests if API is working. Minimal example.""" 11 | model = nn.Sequential(nn.Linear(768, 128), nn.ReLU(), nn.Linear(128, 10)) 12 | onnx_export(model, batch=torch.randn((1, 768)), file="model.onnx") 13 | assert os.path.isfile("model.onnx") 14 | os.remove("model.onnx") 15 | -------------------------------------------------------------------------------- /tests/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/contrib/__init__.py -------------------------------------------------------------------------------- /tests/contrib/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/contrib/data/__init__.py -------------------------------------------------------------------------------- /tests/contrib/data/test_dataset.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from typing import Union 3 | from pathlib import Path 4 | 5 | from catalyst.contrib.data.dataset import PathsDataset 6 | 7 | 8 | def test_PathsDataset() -> None: 9 | def get_target(path: Union[str, Path]) -> int: 10 | result = str(path).split(".")[0].split("_")[1] 11 | result = int(result) 12 | 13 | return result 14 | 15 | def identity(x): 16 | return x 17 | 18 | filenames = ["path1_1.jpg", "path2_1.jpg", "path_0.jpg"] 19 | targets = [1, 1, 0] 20 | 21 | dataset = PathsDataset(filenames, open_fn=identity, label_fn=get_target) 22 | 23 | result = True 24 | for data, target in zip(dataset.data, targets): 25 | result &= data["targets"] == target 26 | assert result 27 | -------------------------------------------------------------------------------- /tests/contrib/data/test_nifti.py: -------------------------------------------------------------------------------- 1 | # import os 2 | 3 | # import pytest 4 | 5 | # from catalyst.contrib.data.reader_nifti import NiftiReader 6 | # from catalyst.settings import SETTINGS 7 | 8 | # if SETTINGS.nifti_required: 9 | # from nibabel.testing import data_path 10 | 11 | # example_filename = os.path.join(data_path, "example4d.nii.gz") 12 | 13 | 14 | # @pytest.mark.skipif(not SETTINGS.nifti_required, reason="Niftly is not available") 15 | # @pytest.mark.parametrize( 16 | # "input_key, output_key", [("images", None), ("images", "outputs")] 17 | # ) 18 | # def test_nifti_reader(input_key, output_key): 19 | # """Minimal test for getting images from a nifti file""" 20 | # test_annotations_dict = {input_key: example_filename} 21 | # if output_key: 22 | # nifti_reader = NiftiReader(input_key, output_key) 23 | # else: 24 | # nifti_reader = NiftiReader(input_key) 25 | # output = nifti_reader(test_annotations_dict) 26 | # assert isinstance(output, dict) 27 | # if output_key is None: 28 | # output_key = input_key 29 | # assert output[output_key].shape == (128, 96, 24, 2) 30 | 
-------------------------------------------------------------------------------- /tests/contrib/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/contrib/datasets/__init__.py -------------------------------------------------------------------------------- /tests/contrib/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/contrib/nn/__init__.py -------------------------------------------------------------------------------- /tests/contrib/nn/test_onecyle.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import torch 3 | from torch.utils.data import DataLoader, TensorDataset 4 | 5 | from catalyst.contrib.schedulers.onecycle import OneCycleLRWithWarmup 6 | from catalyst.dl import Callback, CallbackOrder, SupervisedRunner 7 | 8 | 9 | class LRCheckerCallback(Callback): 10 | def __init__(self, init_lr_value: float, final_lr_value: float): 11 | super().__init__(CallbackOrder.Internal) 12 | self.init_lr = init_lr_value 13 | self.final_lr = final_lr_value 14 | 15 | # Check initial LR 16 | def on_batch_start(self, runner): 17 | step = getattr(runner, "batch_step") 18 | if step == 1: 19 | assert self.init_lr == runner.scheduler.get_lr()[0] 20 | 21 | # Check final LR 22 | def on_experiment_end(self, runner): 23 | assert self.final_lr == runner.scheduler.get_lr()[0] 24 | 25 | 26 | def test_onecyle(): 27 | # experiment_setup 28 | logdir = "./logs/core_runner" 29 | 30 | # data 31 | num_samples, num_features = int(1e4), int(1e1) 32 | X = torch.rand(num_samples, num_features) 33 | y = torch.randint(0, 5, size=[num_samples]) 34 | dataset = TensorDataset(X, y) 35 | loader = DataLoader(dataset, batch_size=32, num_workers=1) 36 | loaders = { 37 | "train": loader, 38 | "valid": loader, 39 | } 40 | 41 | # number of steps, epochs, LR range, initial LR and warmup_fraction 42 | num_steps = 6 43 | epochs = 8 44 | min_lr = 1e-4 45 | max_lr = 2e-3 46 | init_lr = 1e-3 47 | warmup_fraction = 0.5 48 | 49 | # model, criterion, optimizer, scheduler 50 | model = torch.nn.Linear(num_features, 5) 51 | criterion = torch.nn.CrossEntropyLoss() 52 | optimizer = torch.optim.Adam(model.parameters()) 53 | scheduler = OneCycleLRWithWarmup( 54 | optimizer, 55 | num_steps=num_steps, 56 | lr_range=(max_lr, min_lr), 57 | init_lr=init_lr, 58 | warmup_fraction=warmup_fraction, 59 | ) 60 | 61 | runner = SupervisedRunner() 62 | 63 | callbacks = [LRCheckerCallback(init_lr, min_lr)] 64 | 65 | runner.train( 66 | model=model, 67 | criterion=criterion, 68 | optimizer=optimizer, 69 | scheduler=scheduler, 70 | loaders=loaders, 71 | logdir=logdir, 72 | num_epochs=epochs, 73 | verbose=False, 74 | callbacks=callbacks, 75 | ) 76 | -------------------------------------------------------------------------------- /tests/contrib/nn/test_optimizer.py: -------------------------------------------------------------------------------- 1 | from torch import nn, optim 2 | 3 | from catalyst.contrib import optimizers as module 4 | from catalyst.contrib.optimizers import Lookahead 5 | 6 | 7 | def test_optimizer_init(): 8 | """@TODO: Docs. 
Contribution is welcome.""" 9 | model = nn.Linear(10, 10) 10 | for name, module_class in module.__dict__.items(): 11 | if isinstance(module_class, type): 12 | if name in ["Optimizer", "SparseAdam"]: 13 | # @TODO: add test for SparseAdam 14 | continue 15 | elif module_class == Lookahead: 16 | instance = optim.SGD(model.parameters(), lr=1e-3) 17 | instance = module_class(instance) 18 | else: 19 | instance = module_class(model.parameters(), lr=1e-3) 20 | assert instance is not None 21 | -------------------------------------------------------------------------------- /tests/contrib/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/contrib/scripts/__init__.py -------------------------------------------------------------------------------- /tests/contrib/scripts/test_main.py: -------------------------------------------------------------------------------- 1 | # import pytest 2 | 3 | # from catalyst.dl import __main__ as main 4 | 5 | 6 | # def test_arg_parser_run(): 7 | # """Docs? Contribution is welcome.""" 8 | # parser = main.build_parser() 9 | 10 | # args, uargs = parser.parse_known_args(["run", "--config", "test.yml", "--unknown"]) 11 | 12 | # assert args.command == "run" 13 | # assert args.configs == ["test.yml"] 14 | # assert "--unknown" in uargs 15 | 16 | 17 | # def test_run_multiple_configs(): 18 | # """Docs? Contribution is welcome.""" 19 | # parser = main.build_parser() 20 | 21 | # args, uargs = parser.parse_known_args(["run", "--config", "test.yml", "test1.yml"]) 22 | 23 | # assert args.configs == ["test.yml", "test1.yml"] 24 | 25 | 26 | # def test_arg_parser_fail_on_none(): 27 | # """Docs? Contribution is welcome.""" 28 | # parser = main.build_parser() 29 | 30 | # with pytest.raises(SystemExit): 31 | # # Raises SystemExit when args are not ok 32 | # parser.parse_known_args(["--config", "test.yml", "--unknown"]) 33 | -------------------------------------------------------------------------------- /tests/contrib/scripts/test_tag2label.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # from catalyst.contrib.scripts.tag2label import _prepare_df_from_dirs 3 | # 4 | # 5 | # def _setup_dataset_fs(tmp_path): 6 | # def create_children(parent, children): 7 | # for child, sub_children in children.items(): 8 | # sub_parent = parent / child 9 | # sub_parent.mkdir() 10 | # if type(sub_children) == list: 11 | # [ 12 | # (sub_parent / sub_child).touch() 13 | # for sub_child in sub_children 14 | # ] 15 | # else: 16 | # create_children(sub_parent, sub_children) 17 | # 18 | # fs_structure = { 19 | # "datasets": { 20 | # "root1": {"act1": ["a.txt", "b.txt"], "act2": ["c.txt"]}, 21 | # "root2": {"act1": ["d.txt"], "act2": ["e.txt", "f.txt"]}, 22 | # } 23 | # } 24 | # 25 | # create_children(tmp_path, fs_structure) 26 | # 27 | # 28 | # def test_prepare_df_from_dirs_one(tmp_path): 29 | # def check_filepath(filepath): 30 | # return filepath.startswith("act1") or filepath.startswith("act2") 31 | # 32 | # _setup_dataset_fs(tmp_path) 33 | # root_path = tmp_path / "datasets/root1" 34 | # df = _prepare_df_from_dirs(str(root_path), "label") 35 | # 36 | # assert df.shape[0] == 3 37 | # assert df.filepath.apply(check_filepath).sum().all() 38 | # assert df.label.isin(["act1", "act2"]).all() 39 | # 40 | # 41 | # def test_prepare_df_from_dirs_multi(tmp_path): 42 | # def check_filepath(filepath): 43 | # return 
( 44 | # filepath.startswith("root1/act1") 45 | # or filepath.startswith("root1/act2") 46 | # or filepath.startswith("root2/act1") 47 | # or filepath.startswith("root2/act2") 48 | # ) 49 | # 50 | # _setup_dataset_fs(tmp_path) 51 | # ds_path = tmp_path / "datasets" 52 | # root_paths = ",".join([str(ds_path / "root1"), str(ds_path / "root2")]) 53 | # df = _prepare_df_from_dirs(root_paths, "label") 54 | # 55 | # assert df.shape[0] == 6 56 | # assert df.filepath.apply(check_filepath).sum().all() 57 | # assert df.label.isin(["act1", "act2"]).all() 58 | -------------------------------------------------------------------------------- /tests/contrib/scripts/test_tune.yml: -------------------------------------------------------------------------------- 1 | study: 2 | _target_: optuna.create_study 3 | storage: null 4 | sampler: null 5 | pruner: null 6 | study_name: null 7 | direction: minimize 8 | load_if_exists: false 9 | _mode_: call 10 | 11 | runner: 12 | _target_: catalyst.runners.SupervisedRunner 13 | model: 14 | _var_: model 15 | _target_: tests.contrib.scripts.test_tune.CustomModule 16 | in_features: 784 # 28 * 28 17 | num_hidden: 18 | _var_: trial.suggest_int 19 | name: num_hidden 20 | low: 32 21 | high: 128 22 | out_features: 10 23 | input_key: features 24 | output_key: logits 25 | target_key: targets 26 | loss_key: loss 27 | 28 | run: 29 | - _call_: train 30 | 31 | criterion: 32 | _target_: torch.nn.CrossEntropyLoss 33 | 34 | optimizer: 35 | _target_: torch.optim.Adam 36 | params: 37 | _var_: model.parameters 38 | lr: 0.02 39 | 40 | loaders: 41 | train: 42 | _target_: torch.utils.data.DataLoader 43 | dataset: 44 | _target_: catalyst.contrib.datasets.MNIST 45 | root: tests 46 | train: y 47 | batch_size: 32 48 | valid: 49 | _target_: torch.utils.data.DataLoader 50 | dataset: 51 | _target_: catalyst.contrib.datasets.MNIST 52 | root: tests 53 | train: n 54 | batch_size: 32 55 | 56 | callbacks: 57 | - _target_: catalyst.callbacks.AccuracyCallback 58 | input_key: logits 59 | target_key: targets 60 | topk: [1,3,5] 61 | - _target_: catalyst.callbacks.PrecisionRecallF1SupportCallback 62 | input_key: logits 63 | target_key: targets 64 | num_classes: 10 65 | 66 | num_epochs: 1 67 | logdir: tests/logs # TODO: use `tempfile.TemporaryDirectory` 68 | valid_loader: valid 69 | valid_metric: loss 70 | minimize_valid_metric: y 71 | verbose: n 72 | load_best_on_end: y 73 | timeit: n 74 | check: n 75 | overfit: n 76 | -------------------------------------------------------------------------------- /tests/contrib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalyst-team/catalyst/e99f90655d0efcf22559a46e928f0f98c9807ebf/tests/contrib/utils/__init__.py -------------------------------------------------------------------------------- /tests/contrib/utils/test_image.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import pytest 3 | 4 | from catalyst.settings import SETTINGS 5 | 6 | if SETTINGS.cv_required: 7 | from catalyst.contrib.utils.image import imread 8 | 9 | 10 | @pytest.mark.skipif(not (SETTINGS.cv_required), reason="No catalyst[cv] required") 11 | def test_imread(): 12 | """Tests ``imread`` functionality.""" 13 | jpg_rgb_uri = ( 14 | "https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master" 15 | "/test_images/catalyst_icon.jpg" 16 | ) 17 | jpg_grs_uri = ( 18 | "https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master" 19 | 
"/test_images/catalyst_icon_grayscale.jpg" 20 | ) 21 | png_rgb_uri = ( 22 | "https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master" 23 | "/test_images/catalyst_icon.png" 24 | ) 25 | png_grs_uri = ( 26 | "https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master" 27 | "/test_images/catalyst_icon_grayscale.png" 28 | ) 29 | 30 | for uri in [jpg_rgb_uri, jpg_grs_uri, png_rgb_uri, png_grs_uri]: 31 | img = imread(uri) 32 | assert img.shape == (400, 400, 3) 33 | img = imread(uri, grayscale=True) 34 | assert img.shape == (400, 400, 1) 35 | 36 | 37 | # def test_tensor_to_ndimage(): 38 | # """Tests ``tensor_to_ndimage`` functionality.""" 39 | # orig_images = np.random.randint(0, 255, (2, 20, 10, 3), np.uint8) 40 | # 41 | # torch_images = torch.stack( 42 | # [normalize(to_tensor(im), _IMAGENET_MEAN, _IMAGENET_STD) for im in orig_images], dim=0, 43 | # ) 44 | # 45 | # byte_images = tensor_to_ndimage(torch_images, dtype=np.uint8) 46 | # float_images = tensor_to_ndimage(torch_images, dtype=np.float32) 47 | # 48 | # assert np.allclose(byte_images, orig_images) 49 | # assert np.allclose(float_images, orig_images / 255, atol=1e-3, rtol=1e-3) 50 | # 51 | # assert np.allclose( 52 | # tensor_to_ndimage(torch_images[0]), orig_images[0] / 255, atol=1e-3, rtol=1e-3, 53 | # ) 54 | -------------------------------------------------------------------------------- /tests/contrib/utils/test_swa.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os 3 | from pathlib import Path 4 | import shutil 5 | import unittest 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from catalyst.contrib.utils.swa import get_averaged_weights_by_path_mask 11 | 12 | 13 | class Net(nn.Module): 14 | """Dummy network class.""" 15 | 16 | def __init__(self, init_weight=4): 17 | """Initialization of network and filling it with given numbers.""" 18 | super(Net, self).__init__() 19 | self.fc = nn.Linear(2, 1) 20 | self.fc.weight.data.fill_(init_weight) 21 | self.fc.bias.data.fill_(init_weight) 22 | 23 | 24 | class TestSwa(unittest.TestCase): 25 | """Test SWA class.""" 26 | 27 | def setUp(self): 28 | """Test set up.""" 29 | net1 = Net(init_weight=2.0) 30 | net2 = Net(init_weight=5.0) 31 | os.mkdir("./checkpoints") 32 | torch.save(net1.state_dict(), "./checkpoints/net1.pth") 33 | torch.save(net2.state_dict(), "./checkpoints/net2.pth") 34 | 35 | def tearDown(self): 36 | """Test tear down.""" 37 | shutil.rmtree("./checkpoints") 38 | 39 | def test_averaging(self): 40 | """Test SWA method.""" 41 | weights = get_averaged_weights_by_path_mask(logdir=Path("./"), path_mask="net*") 42 | torch.save(weights, str("./checkpoints/swa_weights.pth")) 43 | model = Net() 44 | model.load_state_dict( 45 | torch.load( 46 | "./checkpoints/swa_weights.pth", 47 | map_location=lambda storage, loc: storage, 48 | ) 49 | ) 50 | self.assertEqual(float(model.fc.weight.data[0][0]), 3.5) 51 | self.assertEqual(float(model.fc.weight.data[0][1]), 3.5) 52 | self.assertEqual(float(model.fc.bias.data[0]), 3.5) 53 | 54 | 55 | if __name__ == "__main__": 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /tests/contrib/utils/test_torch.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from typing import Dict 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from catalyst.contrib.utils.torch import get_network_output 8 | 9 | 10 | def test_network_output(): 11 | """Test for 
``catalyst.utils.torch.get_network_output``.""" 12 | # case #1 - test net with one input variable 13 | net = nn.Identity() 14 | assert get_network_output(net, (1, 20)).shape == (1, 1, 20) 15 | 16 | net = nn.Linear(20, 10) 17 | assert get_network_output(net, (1, 20)).shape == (1, 1, 10) 18 | 19 | # case #2 - test net with several input variables 20 | class Net(nn.Module): 21 | def __init__(self): 22 | super().__init__() 23 | self.net1 = nn.Linear(20, 10) 24 | self.net2 = nn.Linear(10, 5) 25 | 26 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 27 | z = torch.cat((self.net1(x), self.net2(y)), dim=-1) 28 | return z 29 | 30 | net = Net() 31 | assert get_network_output(net, (20,), (10,)).shape == (1, 15) 32 | 33 | # case #3 - test net with key-value input 34 | class Net(nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | self.net = nn.Linear(20, 10) 38 | 39 | def forward(self, *, x: torch.Tensor = None) -> torch.Tensor: 40 | y = self.net(x) 41 | return y 42 | 43 | net = Net() 44 | input_shapes = {"x": (20,)} 45 | assert get_network_output(net, **input_shapes).shape == (1, 10) 46 | 47 | # case #4 - test net with dict of variables input 48 | class Net(nn.Module): 49 | def __init__(self): 50 | super().__init__() 51 | self.net = nn.Linear(20, 10) 52 | 53 | def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor: 54 | y = self.net(x["x"]) 55 | return y 56 | 57 | net = Net() 58 | input_shapes = {"x": (20,)} 59 | assert get_network_output(net, input_shapes).shape == (1, 10) 60 | -------------------------------------------------------------------------------- /tests/misc.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import subprocess 3 | 4 | import torch 5 | from torch.utils import data 6 | 7 | 8 | class TensorDataset(data.TensorDataset): 9 | """Dataset wrapping tensors. 10 | 11 | Each sample will be retrieved by indexing tensors along the first dimension. 12 | 13 | Args: 14 | *args: tensors that have the same size of the first dimension. 
15 | """ 16 | 17 | def __init__(self, *args: torch.Tensor, **kwargs) -> None: 18 | super().__init__(*args) 19 | 20 | 21 | def run_experiment_from_configs( 22 | config_dir: Path, main_config: str, *auxiliary_configs: str 23 | ) -> None: 24 | """Runs experiment from config (for Config API tests).""" 25 | main_config = str(config_dir / main_config) 26 | auxiliary_configs = " ".join(str(config_dir / c) for c in auxiliary_configs) 27 | 28 | script = Path("catalyst", "contrib", "scripts", "run.py") 29 | cmd = f"python {script} -C {main_config} {auxiliary_configs}" 30 | subprocess.run(cmd.split(), check=True) 31 | -------------------------------------------------------------------------------- /tests/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Needed to collect coverage data 2 | -------------------------------------------------------------------------------- /tests/pipelines/configs/engine_core.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | engine: 3 | _target_: catalyst.core.engine.Engine 4 | -------------------------------------------------------------------------------- /tests/pipelines/configs/engine_cpu.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | engine: 3 | _target_: catalyst.engines.CPUEngine 4 | -------------------------------------------------------------------------------- /tests/pipelines/configs/engine_ddp.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | engine: 3 | _target_: catalyst.engines.DistributedDataParallelEngine 4 | -------------------------------------------------------------------------------- /tests/pipelines/configs/engine_ddp_amp.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | engine: 3 | _target_: catalyst.engines.DistributedDataParallelEngine 4 | fp16: y 5 | -------------------------------------------------------------------------------- /tests/pipelines/configs/engine_ddp_xla.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | engine: 3 | _target_: catalyst.engines.DistributedXLAEngine 4 | -------------------------------------------------------------------------------- /tests/pipelines/configs/engine_dp.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | engine: 3 | _target_: catalyst.engines.DataParallelEngine 4 | -------------------------------------------------------------------------------- /tests/pipelines/configs/engine_dp_amp.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | engine: 3 | _target_: catalyst.engines.DataParallelEngine 4 | fp16: y 5 | -------------------------------------------------------------------------------- /tests/pipelines/configs/engine_gpu.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | engine: 3 | _target_: catalyst.engines.GPUEngine 4 | -------------------------------------------------------------------------------- /tests/pipelines/configs/engine_gpu_amp.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | engine: 3 | _target_: catalyst.engines.GPUEngine 4 | fp16: y 5 | -------------------------------------------------------------------------------- /tests/pipelines/configs/test_classification.yml: 
-------------------------------------------------------------------------------- 1 | shared: 2 | num_samples: &num_samples 10000 3 | num_features: &num_features 10 4 | num_classes: &num_classes 4 5 | loader: &loader 6 | _var_: loader 7 | _target_: torch.utils.data.DataLoader 8 | dataset: 9 | _target_: tests.misc.TensorDataset 10 | args: 11 | - _target_: torch.load 12 | f: tests/X.pt 13 | _mode_: call 14 | - _target_: torch.load 15 | f: tests/y.pt 16 | _mode_: call 17 | batch_size: 32 18 | num_workers: 1 19 | 20 | runner: 21 | _target_: catalyst.runners.SupervisedRunner 22 | model: 23 | _var_: model 24 | _target_: torch.nn.Linear 25 | in_features: *num_features 26 | out_features: *num_classes 27 | input_key: features 28 | output_key: &output_key logits 29 | target_key: &target_key targets 30 | loss_key: &loss_key loss 31 | 32 | run: 33 | - _call_: train 34 | 35 | criterion: 36 | _target_: torch.nn.CrossEntropyLoss 37 | 38 | optimizer: 39 | _var_: optimizer 40 | _target_: torch.optim.Adam 41 | params: 42 | _var_: model.parameters 43 | 44 | scheduler: 45 | _target_: torch.optim.lr_scheduler.MultiStepLR 46 | optimizer: 47 | _var_: optimizer 48 | milestones: [2] 49 | 50 | loaders: 51 | train: 52 | _var_: loader 53 | valid: 54 | _var_: loader 55 | 56 | callbacks: 57 | - _target_: catalyst.callbacks.AccuracyCallback 58 | input_key: *output_key 59 | target_key: *target_key 60 | num_classes: *num_classes 61 | - _target_: catalyst.callbacks.PrecisionRecallF1SupportCallback 62 | input_key: *output_key 63 | target_key: *target_key 64 | num_classes: *num_classes 65 | 66 | num_epochs: 1 67 | logdir: tests/logs # TODO: use `tempfile.TemporaryDirectory` 68 | valid_loader: valid 69 | valid_metric: accuracy03 70 | minimize_valid_metric: n 71 | verbose: n 72 | -------------------------------------------------------------------------------- /tests/pipelines/configs/test_mnist.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | _target_: catalyst.runners.SupervisedRunner 3 | model: 4 | _var_: model 5 | _target_: torch.nn.Sequential 6 | args: 7 | - _target_: torch.nn.Flatten 8 | - _target_: torch.nn.Linear 9 | in_features: 784 # 28 * 28 10 | out_features: 10 11 | input_key: features 12 | output_key: logits 13 | target_key: targets 14 | loss_key: loss 15 | 16 | run: 17 | - _call_: train 18 | 19 | criterion: 20 | _target_: torch.nn.CrossEntropyLoss 21 | 22 | optimizer: 23 | _target_: torch.optim.Adam 24 | params: 25 | _var_: model.parameters 26 | lr: 0.02 27 | 28 | loaders: 29 | train: 30 | _target_: torch.utils.data.DataLoader 31 | dataset: 32 | _target_: catalyst.contrib.datasets.MNIST 33 | root: tests 34 | train: y 35 | batch_size: 32 36 | valid: 37 | _target_: torch.utils.data.DataLoader 38 | dataset: 39 | _target_: catalyst.contrib.datasets.MNIST 40 | root: tests 41 | train: n 42 | batch_size: 32 43 | 44 | callbacks: 45 | - _target_: catalyst.callbacks.AccuracyCallback 46 | input_key: logits 47 | target_key: targets 48 | topk: [1,3,5] 49 | - _target_: catalyst.callbacks.PrecisionRecallF1SupportCallback 50 | input_key: logits 51 | target_key: targets 52 | num_classes: 10 53 | 54 | num_epochs: 1 55 | logdir: tests/logs # TODO: use `tempfile.TemporaryDirectory` 56 | valid_loader: valid 57 | valid_metric: loss 58 | minimize_valid_metric: y 59 | verbose: n 60 | load_best_on_end: y 61 | timeit: n 62 | check: n 63 | overfit: n 64 | -------------------------------------------------------------------------------- /tests/pipelines/configs/test_mnist_custom.yml: 
-------------------------------------------------------------------------------- 1 | runner: 2 | _target_: tests.pipelines.test_mnist_custom.CustomRunner 3 | model: 4 | _var_: model 5 | _target_: torch.nn.Sequential 6 | args: 7 | - _target_: torch.nn.Flatten 8 | - _target_: torch.nn.Linear 9 | in_features: 784 # 28 * 28 10 | out_features: 10 11 | 12 | run: 13 | - _call_: train 14 | 15 | optimizer: 16 | _var_: optimizer 17 | _target_: torch.optim.Adam 18 | params: 19 | _var_: model.parameters 20 | lr: 0.02 21 | 22 | loaders: 23 | train: 24 | _target_: torch.utils.data.DataLoader 25 | dataset: 26 | _target_: catalyst.contrib.datasets.MNIST 27 | root: tests 28 | train: y 29 | batch_size: 32 30 | valid: 31 | _target_: torch.utils.data.DataLoader 32 | dataset: 33 | _target_: catalyst.contrib.datasets.MNIST 34 | root: tests 35 | train: n 36 | batch_size: 32 37 | 38 | logdir: tests/logs # TODO: use `tempfile.TemporaryDirectory` 39 | num_epochs: 1 40 | verbose: n 41 | valid_loader: valid 42 | valid_metric: loss 43 | minimize_valid_metric: y 44 | -------------------------------------------------------------------------------- /tests/pipelines/configs/test_mnist_multicriterion.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | _target_: tests.pipelines.test_mnist_multicriterion.CustomRunner 3 | model: 4 | _var_: model 5 | _target_: torch.nn.Sequential 6 | args: 7 | - _target_: torch.nn.Flatten 8 | - _target_: torch.nn.Linear 9 | in_features: 784 # 28 * 28 10 | out_features: 10 11 | 12 | run: 13 | - _call_: train 14 | 15 | criterion: 16 | multiclass: 17 | _target_: torch.nn.CrossEntropyLoss 18 | multilabel: 19 | _target_: torch.nn.BCEWithLogitsLoss 20 | 21 | optimizer: 22 | _var_: optimizer 23 | _target_: torch.optim.Adam 24 | params: 25 | _var_: model.parameters 26 | lr: 0.02 27 | 28 | loaders: 29 | train: 30 | _target_: torch.utils.data.DataLoader 31 | dataset: 32 | _target_: catalyst.contrib.datasets.MNIST 33 | root: tests 34 | train: y 35 | batch_size: 32 36 | valid: 37 | _target_: torch.utils.data.DataLoader 38 | dataset: 39 | _target_: catalyst.contrib.datasets.MNIST 40 | root: tests 41 | train: n 42 | batch_size: 32 43 | 44 | logdir: tests/logs # TODO: use `tempfile.TemporaryDirectory` 45 | num_epochs: 1 46 | verbose: n 47 | valid_loader: valid 48 | valid_metric: loss 49 | minimize_valid_metric: y 50 | -------------------------------------------------------------------------------- /tests/pipelines/configs/test_mnist_multimodel.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | _target_: tests.pipelines.test_mnist_multimodel.CustomRunner 3 | model: 4 | encoder: 5 | _var_: model_encoder 6 | _target_: torch.nn.Sequential 7 | args: 8 | - _target_: torch.nn.Flatten 9 | - _target_: torch.nn.Linear 10 | in_features: 784 # 28 * 28 11 | out_features: 128 12 | head: 13 | _var_: model_head 14 | _target_: torch.nn.Linear 15 | in_features: 128 16 | out_features: 10 17 | 18 | run: 19 | - _call_: train 20 | 21 | criterion: 22 | _target_: torch.nn.CrossEntropyLoss 23 | 24 | optimizer: 25 | _var_: optimizer 26 | _target_: torch.optim.Adam 27 | params: 28 | - params: 29 | _var_: model_encoder.parameters 30 | - params: 31 | _var_: model_head.parameters 32 | lr: 0.02 33 | 34 | loaders: 35 | train: 36 | _target_: torch.utils.data.DataLoader 37 | dataset: 38 | _target_: catalyst.contrib.datasets.MNIST 39 | root: tests 40 | train: y 41 | batch_size: 32 42 | valid: 43 | _target_: torch.utils.data.DataLoader 44 | 
dataset: 45 | _target_: catalyst.contrib.datasets.MNIST 46 | root: tests 47 | train: n 48 | batch_size: 32 49 | 50 | logdir: tests/logs # TODO: use `tempfile.TemporaryDirectory` 51 | num_epochs: 1 52 | verbose: n 53 | valid_loader: valid 54 | valid_metric: loss 55 | minimize_valid_metric: y 56 | -------------------------------------------------------------------------------- /tests/pipelines/configs/test_mnist_multioptimizer.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | _target_: tests.pipelines.test_mnist_multioptimizer.CustomRunner 3 | model: 4 | encoder: 5 | _var_: model_encoder 6 | _target_: torch.nn.Sequential 7 | args: 8 | - _target_: torch.nn.Flatten 9 | - _target_: torch.nn.Linear 10 | in_features: 784 # 28 * 28 11 | out_features: 128 12 | head: 13 | _var_: model_head 14 | _target_: torch.nn.Linear 15 | in_features: 128 16 | out_features: 10 17 | 18 | run: 19 | - _call_: train 20 | 21 | criterion: 22 | _target_: torch.nn.CrossEntropyLoss 23 | 24 | optimizer: 25 | encoder: 26 | _target_: torch.optim.Adam 27 | params: 28 | _var_: model_encoder.parameters 29 | lr: 0.02 30 | head: 31 | _target_: torch.optim.Adam 32 | params: 33 | _var_: model_head.parameters 34 | lr: 0.001 35 | 36 | loaders: 37 | train: 38 | _target_: torch.utils.data.DataLoader 39 | dataset: 40 | _target_: catalyst.contrib.datasets.MNIST 41 | root: tests 42 | train: y 43 | batch_size: 32 44 | valid: 45 | _target_: torch.utils.data.DataLoader 46 | dataset: 47 | _target_: catalyst.contrib.datasets.MNIST 48 | root: tests 49 | train: n 50 | batch_size: 32 51 | 52 | logdir: tests/logs # TODO: use `tempfile.TemporaryDirectory` 53 | num_epochs: 1 54 | verbose: n 55 | valid_loader: valid 56 | valid_metric: loss 57 | minimize_valid_metric: y 58 | -------------------------------------------------------------------------------- /tests/pipelines/configs/test_multilabel_classification.yml: -------------------------------------------------------------------------------- 1 | shared: 2 | num_samples: &num_samples 10000 3 | num_features: &num_features 10 4 | num_classes: &num_classes 4 5 | loader: &loader 6 | _var_: loader 7 | _target_: torch.utils.data.DataLoader 8 | dataset: 9 | _target_: tests.misc.TensorDataset 10 | args: 11 | - _target_: torch.load 12 | f: tests/X.pt 13 | _mode_: call 14 | - _target_: torch.load 15 | f: tests/y.pt 16 | _mode_: call 17 | batch_size: 32 18 | num_workers: 1 19 | 20 | runner: 21 | _target_: catalyst.runners.SupervisedRunner 22 | model: 23 | _var_: model 24 | _target_: torch.nn.Linear 25 | in_features: *num_features 26 | out_features: *num_classes 27 | 28 | run: 29 | - _call_: train 30 | 31 | criterion: 32 | _target_: torch.nn.BCEWithLogitsLoss 33 | 34 | optimizer: 35 | _var_: optimizer 36 | _target_: torch.optim.Adam 37 | params: 38 | _var_: model.parameters 39 | 40 | scheduler: 41 | _target_: torch.optim.lr_scheduler.MultiStepLR 42 | optimizer: 43 | _var_: optimizer 44 | milestones: [2] 45 | 46 | loaders: 47 | train: 48 | _var_: loader 49 | valid: 50 | _var_: loader 51 | 52 | callbacks: 53 | - _target_: catalyst.callbacks.BatchTransformCallback 54 | transform: torch.sigmoid 55 | scope: on_batch_end 56 | input_key: logits 57 | output_key: scores 58 | - _target_: catalyst.callbacks.MultilabelAccuracyCallback 59 | input_key: scores 60 | target_key: targets 61 | threshold: 0.5 62 | - _target_: catalyst.callbacks.MultilabelPrecisionRecallF1SupportCallback 63 | input_key: scores 64 | target_key: targets 65 | num_classes: *num_classes 66 | 67 | 
logdir: tests/logs 68 | num_epochs: 1 69 | valid_loader: valid 70 | valid_metric: accuracy 71 | minimize_valid_metric: n 72 | verbose: n 73 | -------------------------------------------------------------------------------- /tests/pipelines/configs/test_regression.yml: -------------------------------------------------------------------------------- 1 | shared: 2 | num_samples: &num_samples 10000 3 | num_features: &num_features 10 4 | loader: &loader 5 | _var_: loader 6 | _target_: torch.utils.data.DataLoader 7 | dataset: 8 | _target_: tests.misc.TensorDataset 9 | args: 10 | - _target_: torch.load 11 | f: tests/X.pt 12 | _mode_: call 13 | - _target_: torch.load 14 | f: tests/y.pt 15 | _mode_: call 16 | batch_size: 32 17 | num_workers: 1 18 | 19 | runner: 20 | _target_: catalyst.runners.SupervisedRunner 21 | model: 22 | _var_: model 23 | _target_: torch.nn.Linear 24 | in_features: *num_features 25 | out_features: 1 26 | 27 | run: 28 | - _call_: train 29 | 30 | criterion: 31 | _target_: torch.nn.MSELoss 32 | 33 | optimizer: 34 | _var_: optimizer 35 | _target_: torch.optim.Adam 36 | params: 37 | _var_: model.parameters 38 | 39 | scheduler: 40 | _target_: torch.optim.lr_scheduler.MultiStepLR 41 | optimizer: 42 | _var_: optimizer 43 | milestones: [3,6] 44 | 45 | loaders: 46 | train: 47 | _var_: loader 48 | valid: 49 | _var_: loader 50 | 51 | callbacks: 52 | - _target_: catalyst.callbacks.R2SquaredCallback 53 | input_key: logits 54 | target_key: targets 55 | 56 | logdir: tests/logs 57 | valid_loader: valid 58 | valid_metric: loss 59 | minimize_valid_metric: y 60 | num_epochs: 1 61 | verbose: n 62 | -------------------------------------------------------------------------------- /tests/pipelines/configs/test_vae.yml: -------------------------------------------------------------------------------- 1 | runner: 2 | _target_: tests.pipelines.test_vae.CustomRunner 3 | hid_features: 64 4 | logdir: tests/logs 5 | 6 | run: 7 | - _call_: run 8 | --------------------------------------------------------------------------------
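The pipeline configs above are consumed through the Config API entry point: tests/misc.py (shown earlier in this tree) shells out to catalyst/contrib/scripts/run.py with a main config plus any auxiliary configs, so the engine_*.yml snippets can presumably be layered onto a pipeline config as overrides. A hypothetical invocation from the repo root (this specific config pairing is an assumption for illustration, not taken from the tests):

from pathlib import Path

from tests.misc import run_experiment_from_configs

# equivalent to: python catalyst/contrib/scripts/run.py -C <main.yml> <aux.yml>
# hypothetical pairing: the MNIST pipeline composed with the CPU engine override
run_experiment_from_configs(
    Path("tests/pipelines/configs"),
    "test_mnist.yml",
    "engine_cpu.yml",
)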