├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── documentation.md │ ├── feature-request.md │ ├── questions-help-support.md │ └── user-feedback.md ├── PULL_REQUEST_TEMPLATE.md ├── failed_schedule_issue_template.md ├── pr-labeler-config.yml └── workflows │ ├── binaries-nightly-release.yml │ ├── build_docs.sh │ ├── code-style.yml │ ├── discord_issues.yml │ ├── discord_pull_requests.yaml │ ├── docker-build.yml │ ├── docs.yml │ ├── gpu-hvd-tests.yml │ ├── gpu-tests.yml │ ├── hvd-tests.yml │ ├── install_docs_deps.sh │ ├── mps-tests.yml │ ├── pytorch-version-tests.yml │ ├── stable-release-anaconda.yml │ ├── stable-release-pypi.yml │ ├── tpu-tests.yml │ ├── triage.yml │ └── unit-tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── ignite_vs_bare_pytorch.png ├── logo │ ├── ignite_logo.png │ ├── ignite_logo.svg │ ├── ignite_logo_dark.png │ ├── ignite_logo_dark.svg │ ├── ignite_logo_guidelines.md │ ├── ignite_logo_light.png │ ├── ignite_logo_light.svg │ ├── ignite_logo_mixed.png │ ├── ignite_logo_mixed.svg │ ├── ignite_logo_mixed_light.png │ ├── ignite_logo_mixed_light.svg │ ├── ignite_logomark.png │ ├── ignite_logomark.svg │ ├── ignite_logomark_dark.png │ ├── ignite_logomark_dark.svg │ ├── ignite_logomark_light.png │ └── ignite_logomark_light.svg └── tldr │ ├── pytorch-ignite-teaser.gif │ ├── teaser.ipynb │ └── teaser.py ├── codecov.yml ├── conda.recipe ├── build_and_upload.sh └── meta.yaml ├── docker ├── README.md ├── build.sh ├── docker.cfg ├── hvd │ ├── Dockerfile.hvd-apex │ ├── Dockerfile.hvd-apex-nlp │ ├── Dockerfile.hvd-apex-vision │ ├── Dockerfile.hvd-base │ ├── Dockerfile.hvd-nlp │ └── Dockerfile.hvd-vision ├── main │ ├── Dockerfile.apex │ ├── Dockerfile.apex-nlp │ ├── Dockerfile.apex-vision │ ├── Dockerfile.base │ ├── Dockerfile.nlp │ └── Dockerfile.vision ├── msdp │ ├── Dockerfile.msdp-apex │ ├── Dockerfile.msdp-apex-nlp │ └── 
Dockerfile.msdp-apex-vision ├── push_all.sh └── test_image.py ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _static │ └── img │ │ ├── concepts │ │ └── timeline_and_events.png │ │ └── schedulers │ │ ├── concat_example.png │ │ ├── concat_linear_exp_step_lr.png │ │ ├── cosine_annealing_example.png │ │ ├── linear_cyclical_example.png │ │ ├── lr_scheduler.png │ │ ├── piecewise_linear.png │ │ └── reduce_lr_on_plateau_example.png │ ├── _templates │ ├── _static │ │ ├── css │ │ │ └── ignite_theme.css │ │ └── img │ │ │ ├── ignite_logo.svg │ │ │ ├── ignite_logomark.svg │ │ │ └── netlify-light.svg │ ├── autosummary │ │ ├── class.rst │ │ ├── function.rst │ │ └── module.rst │ ├── classwithcall.rst │ ├── fonts.html │ ├── footer.html │ ├── layout.html │ └── theme_variables.jinja │ ├── conf.py │ ├── contrib │ ├── engines.rst │ ├── handlers.rst │ └── metrics.rst │ ├── defaults.rst │ ├── distributed.rst │ ├── engine.rst │ ├── exceptions.rst │ ├── handlers.rst │ ├── index.rst │ ├── metrics.rst │ └── utils.rst ├── examples ├── cifar10 │ ├── .gitignore │ ├── README.md │ ├── main.py │ ├── requirements.txt │ └── utils.py ├── cifar100_amp_benchmark │ ├── benchmark_fp32.py │ ├── benchmark_nvidia_apex.py │ ├── benchmark_torch_cuda_amp.py │ └── utils.py ├── cifar10_qat │ ├── .gitignore │ ├── README.md │ ├── main.py │ ├── pact.py │ └── utils.py ├── fast_neural_style │ ├── README.md │ ├── handlers.py │ ├── images │ │ ├── content_images │ │ │ └── amber.jpg │ │ ├── output_images │ │ │ └── mosaic_amber.jpg │ │ └── style_images │ │ │ └── mosaic.jpg │ ├── neural_style.py │ ├── transformer_net.py │ ├── utils.py │ └── vgg.py ├── gan │ ├── README.md │ └── dcgan.py ├── mnist │ ├── README.md │ ├── assets │ │ ├── logs-det_run_1_2.png │ │ ├── logs-det_run_3_4.png │ │ ├── logs_run_1_2.png │ │ ├── logs_run_3_4.png │ │ ├── run_vs_resume_run_logs_1_2.png │ │ └── run_vs_resume_run_logs_3_4.png │ ├── mnist.py │ ├── mnist_save_resume_engine.py │ ├── mnist_with_clearml_logger.py │ ├── 
mnist_with_neptune_logger.py │ ├── mnist_with_tensorboard.py │ ├── mnist_with_tensorboard_logger.py │ ├── mnist_with_tensorboard_on_tpu.py │ ├── mnist_with_tqdm_logger.py │ ├── mnist_with_visdom.py │ ├── mnist_with_visdom_logger.py │ └── mnist_with_wandb_logger.py ├── notebooks │ ├── Cifar100_bench_amp.ipynb │ ├── Cifar10_Ax_hyperparam_tuning.ipynb │ ├── CycleGAN_with_nvidia_apex.ipynb │ ├── CycleGAN_with_torch_cuda_amp.ipynb │ ├── EfficientNet_Cifar100_finetuning.ipynb │ ├── FashionMNIST.ipynb │ ├── FastaiLRFinder_MNIST.ipynb │ ├── HandlersTimeProfiler_MNIST.ipynb │ ├── MNIST_on_TPU.ipynb │ ├── TextCNN.ipynb │ ├── VAE.ipynb │ └── assets │ │ ├── ax_hparams.png │ │ ├── efficientnets.png │ │ ├── fashion-mnist.png │ │ ├── fastresnet_v2.svg │ │ └── tb_cyclegan_images.png ├── references │ ├── README.md │ ├── classification │ │ └── imagenet │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── configs │ │ │ ├── baseline_resnet50.py │ │ │ └── check_baseline_resnet50.py │ │ │ ├── dataflow.py │ │ │ ├── main.py │ │ │ ├── requirements.txt │ │ │ ├── utils.py │ │ │ └── vis.py │ └── segmentation │ │ └── pascal_voc2012 │ │ ├── README.md │ │ ├── configs │ │ ├── baseline_dplv3_resnet101.py │ │ ├── baseline_dplv3_resnet101_sbd.py │ │ └── eval_baseline_dplv3_resnet101_sbd.py │ │ ├── dataflow.py │ │ ├── main.py │ │ ├── requirements.txt │ │ ├── utils.py │ │ └── vis.py ├── reinforcement_learning │ ├── README.md │ ├── actor_critic.py │ └── reinforce.py ├── siamese_network │ ├── README.md │ ├── requirements.txt │ └── siamese_network.py ├── super_resolution │ ├── README.md │ ├── images │ │ ├── bicubic_image_cifar.png │ │ ├── input_cifar.png │ │ └── out_cifar.png │ ├── main.py │ ├── model.py │ └── super_resolve.py └── transformers │ ├── README.md │ ├── dataset.py │ ├── main.py │ ├── model.py │ ├── requirements.txt │ └── utils.py ├── ignite ├── __init__.py ├── _utils.py ├── base │ ├── __init__.py │ └── mixins.py ├── contrib │ ├── __init__.py │ ├── engines │ │ ├── __init__.py │ │ ├── common.py │ 
│ └── tbptt.py │ ├── handlers │ │ ├── __init__.py │ │ ├── base_logger.py │ │ ├── clearml_logger.py │ │ ├── lr_finder.py │ │ ├── mlflow_logger.py │ │ ├── neptune_logger.py │ │ ├── param_scheduler.py │ │ ├── polyaxon_logger.py │ │ ├── tensorboard_logger.py │ │ ├── time_profilers.py │ │ ├── tqdm_logger.py │ │ ├── visdom_logger.py │ │ └── wandb_logger.py │ └── metrics │ │ ├── __init__.py │ │ ├── average_precision.py │ │ ├── cohen_kappa.py │ │ ├── gpu_info.py │ │ ├── precision_recall_curve.py │ │ ├── regression │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── canberra_metric.py │ │ ├── fractional_absolute_error.py │ │ ├── fractional_bias.py │ │ ├── geometric_mean_absolute_error.py │ │ ├── geometric_mean_relative_absolute_error.py │ │ ├── manhattan_distance.py │ │ ├── maximum_absolute_error.py │ │ ├── mean_absolute_relative_error.py │ │ ├── mean_error.py │ │ ├── mean_normalized_bias.py │ │ ├── median_absolute_error.py │ │ ├── median_absolute_percentage_error.py │ │ ├── median_relative_absolute_error.py │ │ ├── r2_score.py │ │ └── wave_hedges_distance.py │ │ └── roc_auc.py ├── distributed │ ├── __init__.py │ ├── auto.py │ ├── comp_models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── horovod.py │ │ ├── native.py │ │ └── xla.py │ ├── launcher.py │ └── utils.py ├── engine │ ├── __init__.py │ ├── deterministic.py │ ├── engine.py │ ├── events.py │ └── utils.py ├── exceptions.py ├── handlers │ ├── __init__.py │ ├── base_logger.py │ ├── checkpoint.py │ ├── clearml_logger.py │ ├── early_stopping.py │ ├── ema_handler.py │ ├── fbresearch_logger.py │ ├── lr_finder.py │ ├── mlflow_logger.py │ ├── neptune_logger.py │ ├── param_scheduler.py │ ├── polyaxon_logger.py │ ├── state_param_scheduler.py │ ├── stores.py │ ├── tensorboard_logger.py │ ├── terminate_on_nan.py │ ├── time_limit.py │ ├── time_profilers.py │ ├── timing.py │ ├── tqdm_logger.py │ ├── utils.py │ ├── visdom_logger.py │ └── wandb_logger.py ├── metrics │ ├── __init__.py │ ├── accumulation.py │ ├── accuracy.py │ ├── 
average_precision.py │ ├── classification_report.py │ ├── clustering │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── calinski_harabasz_score.py │ │ ├── davies_bouldin_score.py │ │ └── silhouette_score.py │ ├── cohen_kappa.py │ ├── confusion_matrix.py │ ├── cosine_similarity.py │ ├── entropy.py │ ├── epoch_metric.py │ ├── fbeta.py │ ├── frequency.py │ ├── gan │ │ ├── __init__.py │ │ ├── fid.py │ │ ├── inception_score.py │ │ └── utils.py │ ├── gpu_info.py │ ├── hsic.py │ ├── js_divergence.py │ ├── kl_divergence.py │ ├── loss.py │ ├── maximum_mean_discrepancy.py │ ├── mean_absolute_error.py │ ├── mean_average_precision.py │ ├── mean_pairwise_distance.py │ ├── mean_squared_error.py │ ├── metric.py │ ├── metric_group.py │ ├── metrics_lambda.py │ ├── multilabel_confusion_matrix.py │ ├── mutual_information.py │ ├── nlp │ │ ├── __init__.py │ │ ├── bleu.py │ │ ├── rouge.py │ │ └── utils.py │ ├── precision.py │ ├── precision_recall_curve.py │ ├── psnr.py │ ├── recall.py │ ├── regression │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── canberra_metric.py │ │ ├── fractional_absolute_error.py │ │ ├── fractional_bias.py │ │ ├── geometric_mean_absolute_error.py │ │ ├── geometric_mean_relative_absolute_error.py │ │ ├── kendall_correlation.py │ │ ├── manhattan_distance.py │ │ ├── maximum_absolute_error.py │ │ ├── mean_absolute_relative_error.py │ │ ├── mean_error.py │ │ ├── mean_normalized_bias.py │ │ ├── median_absolute_error.py │ │ ├── median_absolute_percentage_error.py │ │ ├── median_relative_absolute_error.py │ │ ├── pearson_correlation.py │ │ ├── r2_score.py │ │ ├── spearman_correlation.py │ │ └── wave_hedges_distance.py │ ├── roc_auc.py │ ├── root_mean_squared_error.py │ ├── running_average.py │ ├── ssim.py │ ├── top_k_categorical_accuracy.py │ └── vision │ │ ├── __init__.py │ │ └── object_detection_average_precision_recall.py ├── py.typed └── utils.py ├── mypy.ini ├── pyproject.toml ├── requirements-dev.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── 
common_test_functionality.sh ├── ignite ├── __init__.py ├── base │ ├── __init__.py │ └── test_mixins.py ├── conftest.py ├── contrib │ ├── __init__.py │ ├── conftest.py │ ├── engines │ │ ├── __init__.py │ │ ├── test_common.py │ │ └── test_tbptt.py │ ├── handlers │ │ └── test_warnings_of_deprecation_of_handlers.py │ └── metrics │ │ └── test_warnings_of_deprecation_of_metrics.py ├── distributed │ ├── __init__.py │ ├── check_idist_parallel.py │ ├── comp_models │ │ ├── __init__.py │ │ ├── test_base.py │ │ ├── test_horovod.py │ │ ├── test_native.py │ │ └── test_xla.py │ ├── test_auto.py │ ├── test_launcher.py │ └── utils │ │ ├── __init__.py │ │ ├── test_horovod.py │ │ ├── test_native.py │ │ ├── test_serial.py │ │ └── test_xla.py ├── engine │ ├── __init__.py │ ├── test_create_supervised.py │ ├── test_custom_events.py │ ├── test_deterministic.py │ ├── test_engine.py │ ├── test_engine_state_dict.py │ └── test_event_handlers.py ├── handlers │ ├── __init__.py │ ├── conftest.py │ ├── test_base_logger.py │ ├── test_checkpoint.py │ ├── test_clearml_logger.py │ ├── test_early_stopping.py │ ├── test_ema_handler.py │ ├── test_fbresearch_logger.py │ ├── test_handlers.py │ ├── test_lr_finder.py │ ├── test_mlflow_logger.py │ ├── test_neptune_logger.py │ ├── test_param_scheduler.py │ ├── test_polyaxon_logger.py │ ├── test_state_param_scheduler.py │ ├── test_stores.py │ ├── test_tensorboard_logger.py │ ├── test_terminate_on_nan.py │ ├── test_time_limit.py │ ├── test_time_profilers.py │ ├── test_timing.py │ ├── test_tqdm_logger.py │ ├── test_visdom_logger.py │ └── test_wandb_logger.py ├── metrics │ ├── __init__.py │ ├── clustering │ │ ├── __init__.py │ │ ├── test_calinski_harabasz_score.py │ │ ├── test_davies_bouldin_score.py │ │ └── test_silhouette_score.py │ ├── conftest.py │ ├── gan │ │ ├── __init__.py │ │ ├── test_fid.py │ │ ├── test_inception_score.py │ │ └── test_utils.py │ ├── nlp │ │ ├── __init__.py │ │ ├── test_bleu.py │ │ ├── test_rouge.py │ │ └── test_utils.py │ ├── regression 
│ │ ├── __init__.py │ │ ├── test__base.py │ │ ├── test_canberra_metric.py │ │ ├── test_fractional_absolute_error.py │ │ ├── test_fractional_bias.py │ │ ├── test_geometric_mean_absolute_error.py │ │ ├── test_geometric_mean_relative_absolute_error.py │ │ ├── test_kendall_correlation.py │ │ ├── test_manhattan_distance.py │ │ ├── test_maximum_absolute_error.py │ │ ├── test_mean_absolute_relative_error.py │ │ ├── test_mean_error.py │ │ ├── test_mean_normalized_bias.py │ │ ├── test_median_absolute_error.py │ │ ├── test_median_absolute_percentage_error.py │ │ ├── test_median_relative_absolute_error.py │ │ ├── test_pearson_correlation.py │ │ ├── test_r2_score.py │ │ ├── test_spearman_correlation.py │ │ └── test_wave_hedges_distance.py │ ├── test_accumulation.py │ ├── test_accuracy.py │ ├── test_average_precision.py │ ├── test_classification_report.py │ ├── test_cohen_kappa.py │ ├── test_confusion_matrix.py │ ├── test_cosine_similarity.py │ ├── test_dill.py │ ├── test_entropy.py │ ├── test_epoch_metric.py │ ├── test_fbeta.py │ ├── test_frequency.py │ ├── test_gpu_info.py │ ├── test_hsic.py │ ├── test_js_divergence.py │ ├── test_kl_divergence.py │ ├── test_loss.py │ ├── test_maximum_mean_discrepancy.py │ ├── test_mean_absolute_error.py │ ├── test_mean_average_precision.py │ ├── test_mean_pairwise_distance.py │ ├── test_mean_squared_error.py │ ├── test_metric.py │ ├── test_metric_group.py │ ├── test_metrics_lambda.py │ ├── test_multilabel_confusion_matrix.py │ ├── test_mutual_information.py │ ├── test_precision.py │ ├── test_precision_recall_curve.py │ ├── test_psnr.py │ ├── test_recall.py │ ├── test_roc_auc.py │ ├── test_roc_curve.py │ ├── test_root_mean_squared_error.py │ ├── test_running_average.py │ ├── test_ssim.py │ ├── test_top_k_categorical_accuracy.py │ └── vision │ │ ├── __init__.py │ │ └── test_object_detection_map.py └── test_utils.py ├── run_code_style.bat ├── run_code_style.sh ├── run_cpu_tests.sh ├── run_gpu_tests.sh ├── run_multinode_tests_in_docker.sh └── 
run_tpu_tests.sh /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [vfdev-5] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: pytorch-ignite # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug Report" 3 | about: Submit a bug report to help us improve Ignite 4 | --- 5 | 6 | ## 🐛 Bug description 7 | 8 | 9 | 10 | 11 | 12 | 13 | ## Environment 14 | 15 | - PyTorch Version (e.g., 1.4): 16 | - Ignite Version (e.g., 0.3.0): 17 | - OS (e.g., Linux): 18 | - How you installed Ignite (`conda`, `pip`, source): 19 | - Python version: 20 | - Any other relevant information: 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4DA Documentation" 3 | about: Report an issue, comment or suggestion related to project's docs 4 | labels: "docs" 5 | --- 6 | 7 | ## 📚 Documentation 8 | 9 | 10 | -------------------------------------------------------------------------------- 
/.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature Request" 3 | about: Submit a proposal/request for a new Ignite feature 4 | --- 5 | 6 | ## 🚀 Feature 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions-help-support.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "❓Questions/Help/Support" 3 | about: Do you have a question? 4 | labels: "question" 5 | --- 6 | 7 | ## ❓ Questions/Help/Support 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/user-feedback.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F44D User feedback" 3 | about: Say thanks or tell us what you don't like 4 | title: "" 5 | labels: "" 6 | assignees: "" 7 | --- 8 | 9 | > This is a place to leave any feedback on this package. 10 | > If you like the work, feel free to say thanks here. 11 | > If you do not like something, please share it with us so we can see how to improve. 12 | > Thank you!
13 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Fixes #{issue number} 2 | 3 | Description: 4 | 5 | Check list: 6 | 7 | - [ ] New tests are added (if a new feature is added) 8 | - [ ] New doc strings: description and/or example code are in RST format 9 | - [ ] Documentation is updated (if required) 10 | -------------------------------------------------------------------------------- /.github/failed_schedule_issue_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Scheduled workflow failed 3 | labels: 4 | - bug 5 | --- 6 | 7 | Oh no, something went wrong in the scheduled workflow **{{ env.GITHUB_WORKFLOW }} with commit {{ env.GITHUB_SHA }}**. 8 | Please look into it: 9 | 10 | {{ env.GITHUB_SERVER_URL }}/{{ env.GITHUB_REPOSITORY }}/actions/runs/{{ env.GITHUB_RUN_ID }} 11 | 12 | Feel free to close this if this was just a one-off error. 
13 | -------------------------------------------------------------------------------- /.github/pr-labeler-config.yml: -------------------------------------------------------------------------------- 1 | # Add 'docker' to any changes within 'docker' folder or any subfolders 2 | docker: 3 | - changed-files: 4 | - any-glob-to-any-file: docker/** 5 | 6 | # Add 'docs' to any changes within 'docs' folder 7 | docs: 8 | - changed-files: 9 | - any-glob-to-any-file: docs/** 10 | 11 | # Add 'ci' to any changes in '.github' folder 12 | ci: 13 | - changed-files: 14 | - any-glob-to-any-file: .github/** 15 | 16 | # Add 'examples' to any changes within 'examples' folder 17 | examples: 18 | - changed-files: 19 | - any-glob-to-any-file: examples/** 20 | 21 | # Add 'base' to any changes within 'base' folder 22 | "module: base": 23 | - changed-files: 24 | - any-glob-to-any-file: ignite/base/**/* 25 | 26 | # Add 'contrib' to any changes within 'contrib' folder 27 | "module: contrib": 28 | - changed-files: 29 | - any-glob-to-any-file: ignite/contrib/**/* 30 | 31 | # Add 'distributed' to any changes within 'distributed' folder 32 | "module: distributed": 33 | - changed-files: 34 | - any-glob-to-any-file: ignite/distributed/**/* 35 | 36 | # Add 'engine' to any changes within 'engine' folder 37 | "module: engine": 38 | - changed-files: 39 | - any-glob-to-any-file: ignite/engine/**/* 40 | 41 | # Add 'handlers' to any changes within 'handlers' folder 42 | "module: handlers": 43 | - changed-files: 44 | - any-glob-to-any-file: ignite/handlers/**/* 45 | 46 | # Add 'metrics' to any changes within 'metrics' folder 47 | "module: metrics": 48 | - changed-files: 49 | - any-glob-to-any-file: ignite/metrics/**/* 50 | 51 | 52 | # Add 'utils' to any changes within 'utils' module 53 | "module: utils": 54 | - changed-files: 55 | - any-glob-to-any-file: ignite/utils.py 56 | -------------------------------------------------------------------------------- /.github/workflows/binaries-nightly-release.yml:
-------------------------------------------------------------------------------- 1 | name: Nightly Releases 2 | 3 | on: 4 | # https://docs.github.com/en/free-pro-team@latest/actions/reference/workflow-syntax-for-github-actions#onschedule 5 | schedule: 6 | # Run at 00:00 UTC Every Day 7 | - cron: "0 0 * * *" 8 | 9 | jobs: 10 | build-publish: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Setup Miniconda 16 | uses: conda-incubator/setup-miniconda@v2 17 | with: 18 | miniconda-version: "latest" 19 | python-version: "3.10" 20 | 21 | - name: Setup nightly version 22 | run: | 23 | sed -i "s/__version__ = \"\(.*\)\"/__version__ = \"\1.dev$(date -u +%Y%m%d)\"/g" ignite/__init__.py 24 | cat ignite/__init__.py 25 | 26 | - name: Install dependencies 27 | shell: bash -l {0} 28 | run: | 29 | conda install -y pytorch torchvision cpuonly -c pytorch-nightly 30 | pip install -r requirements-dev.txt 31 | 32 | - name: Build and Publish Conda binaries 33 | shell: bash -l {0} 34 | env: 35 | ANACONDA_TOKEN: ${{ secrets.ANACONDA_TOKEN }} 36 | UPLOAD_USER: "pytorch-nightly" 37 | run: | 38 | chmod +x ./conda.recipe/build_and_upload.sh 39 | ./conda.recipe/build_and_upload.sh 40 | 41 | - name: Build and Publish PyPI binaries 42 | shell: bash -l {0} 43 | run: | 44 | # workaround to fix https://github.com/pytorch/ignite/issues/2373 45 | pip uninstall -y twine pkginfo 46 | pip install --upgrade --no-cache-dir twine 'pkginfo>=1.8.2' 47 | python setup.py sdist bdist_wheel 48 | twine --version 49 | twine check dist/* 50 | TWINE_USERNAME="${{ secrets.PYPI_USER }}" TWINE_PASSWORD="${{ secrets.PYPI_TOKEN }}" twine upload --verbose dist/* 51 | 52 | - uses: JasonEtco/create-an-issue@v2 53 | name: Create issue if nightly releases failed 54 | if: failure() 55 | env: 56 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 57 | with: 58 | filename: .github/failed_schedule_issue_template.md 59 | 
-------------------------------------------------------------------------------- /.github/workflows/build_docs.sh: -------------------------------------------------------------------------------- 1 | sphinx-versioning --use-master-conf --use-master-templates build --greatest-tag --whitelist-branches master docs/source docs/build/html 2 | -------------------------------------------------------------------------------- /.github/workflows/code-style.yml: -------------------------------------------------------------------------------- 1 | name: Format python code 2 | 3 | on: 4 | push: 5 | paths: 6 | - "**.py" 7 | - "setup.cfg" 8 | - "requirements-dev.txt" 9 | - "pyproject.toml" 10 | - "tests/run_code_style.sh" 11 | - ".github/workflows/code-style.yml" 12 | - "!assets/**" 13 | - "!docker/**" 14 | - "!docs/**" 15 | - "!conda.recipe" 16 | 17 | 18 | jobs: 19 | code-style: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - if: github.event_name == 'push' 23 | uses: actions/checkout@v4 24 | - uses: actions/setup-python@v4 25 | with: 26 | python-version: "3.11" 27 | - run: | 28 | bash ./tests/run_code_style.sh install 29 | bash ./tests/run_code_style.sh fmt 30 | 31 | - name: Commit and push changes 32 | uses: stefanzweifel/git-auto-commit-action@v4 33 | with: 34 | commit_message: "autopep8 fix" 35 | env: 36 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/discord_issues.yml: -------------------------------------------------------------------------------- 1 | name: Discuss "help-wanted" issue on Discord 2 | 3 | on: 4 | issues: 5 | types: 6 | - labeled 7 | workflow_dispatch: 8 | inputs: 9 | issue_number: 10 | description: 'Issue number' 11 | required: true 12 | 13 | permissions: 14 | issues: write 15 | 16 | jobs: 17 | discord: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - name: "Discuss on Discord-Issues" 21 | if: ${{ github.event.label.name == 'help wanted' }} 22 | uses: 
EndBug/discuss-on-discord@v1.1.0 23 | with: 24 | discord_bot_token: ${{ secrets.DISCORD_BOT_TOKEN }} 25 | destination: ${{ secrets.DISCORD_BOT_DESTINATION }} 26 | issue_number: ${{ github.event.inputs.issue_number || github.event.issue.number }} 27 | issue_comment: Hey 👋, I've just created a [thread]($THREAD_LINK$) for this issue on [PyTorch-Ignite Discord](https://pytorch-ignite.ai/chat) where you can quickly talk to the community on the topic. 28 | discord_message: New issue created in `${{ github.repository }}`: 29 | 30 | 31 | -------------------------------------------------------------------------------- /.github/workflows/discord_pull_requests.yaml: -------------------------------------------------------------------------------- 1 | name: Discuss "help-wanted" PR on Discord 2 | 3 | on: 4 | pull_request: 5 | types: 6 | - labeled 7 | workflow_dispatch: 8 | inputs: 9 | pull_request_number: 10 | description: 'Pull request number' 11 | required: true 12 | 13 | permissions: 14 | pull-requests: write 15 | 16 | jobs: 17 | discord: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - name: "Discuss on Discord-PR (Non-maintainer only)" 21 | if: ${{ github.event.label.name == 'help wanted' }} 22 | uses: EndBug/discuss-on-discord@v1.1.0 23 | with: 24 | discord_bot_token: ${{ secrets.DISCORD_BOT_TOKEN }} 25 | destination: ${{ secrets.DISCORD_BOT_DESTINATION }} 26 | issue_number: ${{ github.event.inputs.pull_request_number || github.event.pull_request.number }} 27 | issue_comment: Hey 👋, I've just created a [thread]($THREAD_LINK$) for this pull request on [PyTorch-Ignite Discord](https://pytorch-ignite.ai/chat) where you can quickly talk to the community on the topic. 
28 | discord_message: New PR created in `${{ github.repository }}`: 29 | 30 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Build docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | paths-ignore: 9 | - "tests/**" 10 | - "docker/**" 11 | release: 12 | types: [published] 13 | workflow_dispatch: 14 | 15 | jobs: 16 | build-deploy: 17 | permissions: 18 | contents: write 19 | if: (github.ref == 'refs/heads/master' && github.event_name == 'push') || github.event_name == 'release' 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: actions/checkout@v4 23 | - uses: actions/setup-python@v4 24 | with: 25 | python-version: "3.10" 26 | 27 | - run: sudo npm install katex -g 28 | - uses: actions/cache@v3 29 | with: 30 | path: ~/.cache/pip 31 | key: pip-${{ hashFiles('requirements-dev.txt') }}-${{ hashFiles('docs/requirements.txt') }} 32 | 33 | - name: Install docs deps 34 | run: bash .github/workflows/install_docs_deps.sh 35 | 36 | - name: Build docs 37 | run: bash .github/workflows/build_docs.sh 38 | 39 | - name: Deploy docs 40 | uses: peaceiris/actions-gh-pages@v3 41 | with: 42 | github_token: ${{ secrets.GITHUB_TOKEN }} 43 | publish_dir: docs/build/html 44 | publish_branch: gh-pages 45 | commit_message: Deploy pytorch/ignite docs 46 | force_orphan: true 47 | 48 | linkcheck: 49 | if: github.event_name == 'pull_request' || github.event_name == 'push' 50 | runs-on: ubuntu-latest 51 | timeout-minutes: 10 52 | steps: 53 | - uses: actions/checkout@v4 54 | - uses: actions/setup-python@v4 55 | with: 56 | python-version: "3.10" 57 | 58 | - uses: actions/cache@v3 59 | with: 60 | path: ~/.cache/pip 61 | key: pip-${{ hashFiles('requirements-dev.txt') }}-${{ hashFiles('docs/requirements.txt') }} 62 | 63 | - name: Install docs deps 64 | run: bash .github/workflows/install_docs_deps.sh 65 | 66 | - name: make 
linkcheck 67 | working-directory: ./docs/ 68 | run: make linkcheck 69 | 70 | doctest: 71 | if: github.event_name == 'pull_request' || github.event_name == 'push' 72 | runs-on: ubuntu-latest 73 | steps: 74 | - uses: actions/checkout@v4 75 | - uses: actions/setup-python@v4 76 | with: 77 | python-version: "3.10" 78 | 79 | - run: sudo npm install katex -g 80 | - uses: actions/cache@v3 81 | with: 82 | path: ~/.cache/pip 83 | key: pip-${{ hashFiles('requirements-dev.txt') }}-${{ hashFiles('docs/requirements.txt') }} 84 | 85 | - name: Install docs deps 86 | run: bash .github/workflows/install_docs_deps.sh 87 | 88 | - name: make doctest 89 | working-directory: ./docs/ 90 | run: | 91 | make html 92 | make doctest 93 | make coverage 94 | -------------------------------------------------------------------------------- /.github/workflows/install_docs_deps.sh: -------------------------------------------------------------------------------- 1 | # remove pkg-resources as it causes failure when installing https://github.com/pytorch-ignite/sphinxcontrib-versioning 2 | pip uninstall -y pkg-resources setuptools && pip install --upgrade setuptools pip wheel 3 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U 4 | pip install -r requirements-dev.txt 5 | pip install -r docs/requirements.txt 6 | pip install git+https://github.com/pytorch-ignite/sphinxcontrib-versioning.git 7 | -------------------------------------------------------------------------------- /.github/workflows/stable-release-anaconda.yml: -------------------------------------------------------------------------------- 1 | name: Anaconda Stable Releases 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | conda-build-publish: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | 13 | - name: Setup Miniconda 14 | uses: conda-incubator/setup-miniconda@v2 15 | with: 16 | miniconda-version: "latest" 17 | python-version: "3.10" 18 | 19 | - name: Install 
dependencies 20 | shell: bash -l {0} 21 | run: | 22 | conda install -y pytorch torchvision cpuonly -c pytorch 23 | python setup.py install 24 | 25 | - name: Build and Publish Conda binaries 26 | shell: bash -l {0} 27 | env: 28 | ANACONDA_TOKEN: ${{ secrets.ANACONDA_TOKEN }} 29 | UPLOAD_USER: "pytorch" 30 | run: | 31 | chmod +x ./conda.recipe/build_and_upload.sh 32 | ./conda.recipe/build_and_upload.sh 33 | -------------------------------------------------------------------------------- /.github/workflows/stable-release-pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Stable Releases 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build-publish: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | 13 | - name: Setup Miniconda 14 | uses: conda-incubator/setup-miniconda@v2 15 | with: 16 | miniconda-version: "latest" 17 | python-version: "3.10" 18 | 19 | - name: Install dependencies 20 | shell: bash -l {0} 21 | run: | 22 | conda install -y pytorch torchvision cpuonly -c pytorch 23 | pip install -r requirements-dev.txt 24 | 25 | - name: Build and Publish PyPI binaries 26 | shell: bash -l {0} 27 | run: | 28 | # workaround to fix https://github.com/pytorch/ignite/issues/2373 29 | pip uninstall -y twine pkginfo 30 | pip install --upgrade --no-cache-dir twine 'pkginfo>=1.8.2' 31 | python setup.py sdist bdist_wheel 32 | twine --version 33 | twine check dist/* 34 | TWINE_USERNAME="${{ secrets.PYPI_USER }}" TWINE_PASSWORD="${{ secrets.PYPI_TOKEN }}" twine upload --verbose dist/* 35 | 36 | -------------------------------------------------------------------------------- /.github/workflows/triage.yml: -------------------------------------------------------------------------------- 1 | name: Triage 2 | 3 | on: 4 | pull_request_target: 5 | # types: [opened] 6 | # issues: 7 | # types: [opened] 8 | 9 | jobs: 10 | triage: 11 | permissions: 12 | contents: read 13 | pull-requests: write 14 
| runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Pull Request Labeler 18 | uses: actions/labeler@v5 19 | with: 20 | configuration-path: .github/pr-labeler-config.yml 21 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 22 | 23 | # Turned off due to unexpected behavior on issue opening+labeling? https://github.com/pytorch/ignite/issues/1836 24 | # - name: Welcome 25 | # uses: actions/first-interaction@v1 26 | # with: 27 | # issue-message: "**Thank you for opening your First Issue!**\n\nWe appreciate a lot user's feedback on what we are doing!\n\nIf you'd like to contribute to the project, please check out our [Contributing Guide](https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md)." 28 | # pr-message: "**Thank you for opening your First Pull Request!**\nWe appreciate a lot community contributions as pull requests!\n\nIf you would like to get more details on the project development, please take a look at our [Contributing Guide](https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md)." 
29 | # repo-token: ${{ secrets.GITHUB_TOKEN }} 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | *.pyc 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # Distribution / packaging 8 | .Python 9 | build/ 10 | develop-eggs/ 11 | dist/ 12 | downloads/ 13 | eggs/ 14 | .eggs/ 15 | lib/ 16 | lib64/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | **/.DS_Store 26 | 27 | # Unit test / coverage reports 28 | htmlcov/ 29 | .tox/ 30 | .coverage 31 | .coverage.* 32 | .cache 33 | nosetests.xml 34 | coverage.xml 35 | *.cover 36 | .hypothesis/ 37 | .pytest_cache/ 38 | 39 | # Documentation 40 | /docs/src/ 41 | /docs/build/ 42 | /docs/source/generated/ 43 | 44 | # Virtualenv 45 | .venv/ 46 | .python-version 47 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: "^conda.recipe" 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v3.4.0 5 | hooks: 6 | - id: check-toml 7 | - id: check-yaml 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | 11 | - repo: https://github.com/pre-commit/mirrors-prettier 12 | rev: v2.2.1 13 | hooks: 14 | - id: prettier 15 | exclude_types: ["python", "jupyter", "shell", "gitignore"] 16 | 17 | - repo: https://github.com/omnilib/ufmt 18 | rev: v2.7.3 19 | hooks: 20 | - id: ufmt 21 | additional_dependencies: 22 | - black == 24.10.0 23 | - usort == 1.0.8.post1 24 | 25 | - repo: https://github.com/pycqa/flake8 26 | rev: 7.1.1 27 | hooks: 28 | - id: flake8 29 | args: ["--config", "setup.cfg"] 30 | -------------------------------------------------------------------------------- /CITATION: 
-------------------------------------------------------------------------------- 1 | @misc{pytorch-ignite, 2 | author = {V. Fomin and J. Anmol and S. Desroziers and J. Kriss and A. Tejani}, 3 | title = {High-level library to help with training neural networks in PyTorch}, 4 | year = {2020}, 5 | publisher = {GitHub}, 6 | journal = {GitHub repository}, 7 | howpublished = {\url{https://github.com/pytorch/ignite}}, 8 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018-present, PyTorch-Ignite team 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /assets/ignite_vs_bare_pytorch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/assets/ignite_vs_bare_pytorch.png -------------------------------------------------------------------------------- /assets/logo/ignite_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/assets/logo/ignite_logo.png -------------------------------------------------------------------------------- /assets/logo/ignite_logo_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/assets/logo/ignite_logo_dark.png -------------------------------------------------------------------------------- /assets/logo/ignite_logo_guidelines.md: -------------------------------------------------------------------------------- 1 | # PyTorch Ignite Logo Guidelines 2 | 3 | These guidelines are meant to help keep the PyTorch Ignite logo (as developed in [#1221](https://github.com/pytorch/ignite/issues/1221)) consistent and recognizable across all its uses. 
They also provide a common language for referring to the logos and their components. 4 | 5 | The primary logo is the combination of the logomark and wordmark next to each other. The logomark is the flame alone (no text) and the wordmark is only the text. It's preferable to use the primary logo whenever possible, and the logomark when a smaller version is needed. 6 | 7 | ## Color 8 | 9 | The full-color options are a combination of PyTorch's main orange (`#ee4c2c`) with yellow details (`#eaa700`). Light options are white (`#FFFFFF`) and dark options are dark grey (`#2a2a2a`). The alternate "mixed" logo uses the full-color logomark with a dark grey wordmark. 10 | 11 | Whenever possible, use the full-color logos. One-color logos (light or dark) are to be used when full color will not have enough contrast, usually when logos must be on colored backgrounds or are being reproduced somewhere that doesn't support color. 12 | 13 | Please note: The orange (`#ee4c2c`) and yellow (`#eaa700`) do not meet WCAG 2.1 color contrast recommendations for text or UI when used with white or other light colors. Make sure to use these colors primarily as decorative elements or with a dark color for text and/or UI. Accessibility should not be overlooked. 14 | 15 | ## Type 16 | 17 | The PyTorch Ignite wordmark is made from Oxygen (by Vernon Adams @vernnobile). 18 | 19 | ## Minimum Size 20 | 21 | For consistent legibility, please do not display the primary logo at less than 60px wide or the logomark at less than 15px wide. 22 | 23 | ## Logo Integrity 24 | 25 | A few other notes to keep in mind when using the logo: 26 | 27 | - Make sure to scale the logo proportionally. 28 | - Maintain a good amount of space around the logo. Don’t let it overlap with text, images, or other elements. 29 | - Do not try to recreate or modify the logo. For example, do not use the logomark and then try to write PyTorch Ignite in another font.
30 | -------------------------------------------------------------------------------- /assets/logo/ignite_logo_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/assets/logo/ignite_logo_light.png -------------------------------------------------------------------------------- /assets/logo/ignite_logo_mixed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/assets/logo/ignite_logo_mixed.png -------------------------------------------------------------------------------- /assets/logo/ignite_logo_mixed_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/assets/logo/ignite_logo_mixed_light.png -------------------------------------------------------------------------------- /assets/logo/ignite_logomark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/assets/logo/ignite_logomark.png -------------------------------------------------------------------------------- /assets/logo/ignite_logomark.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/logo/ignite_logomark_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/assets/logo/ignite_logomark_dark.png -------------------------------------------------------------------------------- /assets/logo/ignite_logomark_dark.svg: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/logo/ignite_logomark_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/assets/logo/ignite_logomark_light.png -------------------------------------------------------------------------------- /assets/logo/ignite_logomark_light.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/tldr/pytorch-ignite-teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/assets/tldr/pytorch-ignite-teaser.gif -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 2 3 | round: down 4 | range: "95...100" 5 | status: 6 | patch: 7 | default: 8 | target: 90 9 | project: 10 | default: 11 | threshold: 1% 12 | changes: false 13 | comment: false 14 | ignore: 15 | - "tests/" 16 | -------------------------------------------------------------------------------- /conda.recipe/build_and_upload.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Build and upload Conda binaries" 4 | 5 | # ANACONDA_TOKEN should be provided 6 | # How to generate ANACONDA_TOKEN: https://docs.anaconda.com/anaconda-cloud/user-guide/tasks/work-with-accounts#creating-access-tokens 7 | # https://conda.io/docs/user-guide/tasks/build-packages/install-conda-build.html 8 | 9 | if [ -z $ANACONDA_TOKEN ]; then 10 | echo "Can not find ANACONDA_TOKEN env variable" 11 | echo 
"Please, export ANACONDA_TOKEN= before calling this script" 12 | exit 1 13 | fi 14 | 15 | if [ -z $UPLOAD_USER ]; then 16 | echo "Can not find UPLOAD_USER env variable" 17 | echo "Please, export UPLOAD_USER= before calling this script" 18 | exit 1 19 | fi 20 | 21 | set -xeu 22 | 23 | conda install -y conda-build conda-verify anaconda-client conda-package-handling 24 | conda config --set anaconda_upload no 25 | 26 | conda build --no-test --output-folder conda_build conda.recipe -c pytorch --package-format 1 27 | cph transmute $(ls conda_build/*/*.tar.bz2) .conda 28 | 29 | # Upload to Anaconda 30 | conda config --set anaconda_upload yes 31 | ls conda_build/*/*.{conda,tar.bz2} | xargs -I {} anaconda -v -t $ANACONDA_TOKEN upload -u $UPLOAD_USER {} 32 | -------------------------------------------------------------------------------- /conda.recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set data = load_setup_py_data() %} 2 | 3 | package: 4 | name: ignite 5 | version: {{ data['version'] }} 6 | 7 | source: 8 | path: .. 
9 | 10 | build: 11 | number: 0 12 | noarch: python 13 | script: python setup.py install --single-version-externally-managed --record=record.txt 14 | 15 | # https://conda.io/docs/user-guide/tasks/build-packages/define-metadata.html#export-runtime-requirements 16 | requirements: 17 | build: 18 | - python>=3.9 19 | - setuptools 20 | - pytorch>=1.10 21 | 22 | run: 23 | - python>=3.9 24 | - pytorch>=1.10 25 | 26 | test: 27 | imports: 28 | - ignite 29 | - ignite.engine 30 | - ignite.handlers 31 | - ignite.metrics 32 | - ignite.contrib 33 | 34 | about: 35 | home: {{ data['url'] }} 36 | license: {{ data['license'] }} 37 | summary: {{ data['description'] }} 38 | -------------------------------------------------------------------------------- /docker/docker.cfg: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | build_docker_image_pytorch_version = 2.6.0-cuda12.4-cudnn9 3 | build_docker_image_hvd_version = v0.28.1 4 | build_docker_image_msdp_version = v0.14.0 5 | -------------------------------------------------------------------------------- /docker/hvd/Dockerfile.hvd-apex-nlp: -------------------------------------------------------------------------------- 1 | # Dockerfile.hvd-apex-nlp 2 | FROM pytorchignite/hvd-apex:latest 3 | 4 | # Ignite NLP dependencies 5 | RUN pip install --upgrade --no-cache-dir transformers \ 6 | spacy \ 7 | nltk 8 | -------------------------------------------------------------------------------- /docker/hvd/Dockerfile.hvd-apex-vision: -------------------------------------------------------------------------------- 1 | # Dockerfile.hvd-apex-vision 2 | FROM pytorchignite/hvd-apex:latest 3 | 4 | # Ignite vision dependencies 5 | RUN pip install --upgrade --no-cache-dir albumentations \ 6 | image-dataset-viz \ 7 | numpy \ 8 | opencv-python-headless \ 9 | py_config_runner \ 10 | clearml 11 | -------------------------------------------------------------------------------- /docker/hvd/Dockerfile.hvd-base: 
-------------------------------------------------------------------------------- 1 | # Multi-stage build 2 | # Dockerfile.hvd-base 3 | 4 | ARG PTH_VERSION 5 | 6 | FROM pytorch/pytorch:${PTH_VERSION}-devel as builder 7 | 8 | ARG HVD_VERSION 9 | 10 | # Build Horovod 11 | RUN apt-get update && apt-get install -y git && \ 12 | git clone --recursive --depth 1 --branch ${HVD_VERSION} https://github.com/horovod/horovod.git /horovod && \ 13 | conda install -y cmake nccl -c conda-forge && \ 14 | cd /horovod && \ 15 | # temporary -std=c++17 fix 16 | sed -i "s/CMAKE_CXX_STANDARD 14/CMAKE_CXX_STANDARD 17/g" CMakeLists.txt && \ 17 | sed -i "s/CMAKE_CXX_STANDARD 14/CMAKE_CXX_STANDARD 17/g" horovod/torch/CMakeLists.txt && \ 18 | HOROVOD_GPU_OPERATIONS=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITHOUT_MPI=1 HOROVOD_WITH_PYTORCH=1 pip wheel --no-cache-dir . && \ 19 | rm -rf /var/lib/apt/lists/* 20 | 21 | # Build runtime image 22 | FROM pytorch/pytorch:${PTH_VERSION}-runtime 23 | 24 | # Install tzdata / git 25 | RUN apt-get update && \ 26 | ln -fs /usr/share/zoneinfo/Europe/Paris /etc/localtime && \ 27 | apt-get -y install --no-install-recommends tzdata git && \ 28 | dpkg-reconfigure --frontend noninteractive tzdata && \ 29 | apt-get autoremove -y && \ 30 | apt-get clean -y && \ 31 | rm -rf /var/lib/apt/lists/* 32 | 33 | # Ignite main dependencies 34 | RUN pip install --upgrade --no-cache-dir pytorch-ignite \ 35 | tensorboard \ 36 | tqdm \ 37 | fire 38 | 39 | # Replace pillow with pillow-simd 40 | RUN apt-get update && apt-get -y install --no-install-recommends g++ && \ 41 | pip uninstall -y pillow && \ 42 | CC="cc -mavx2" pip install --upgrade --no-cache-dir --force-reinstall pillow-simd && \ 43 | apt-get remove -y g++ && \ 44 | apt-get autoremove -y && \ 45 | rm -rf /var/lib/apt/lists/* 46 | 47 | # Checkout Ignite examples only 48 | RUN mkdir -p pytorch-ignite-examples && \ 49 | cd pytorch-ignite-examples && \ 50 | git init && \ 51 | git config core.sparsecheckout true && \ 52 | echo 
examples >> .git/info/sparse-checkout && \ 53 | git remote add -f origin https://github.com/pytorch/ignite.git && \ 54 | git pull origin master && \ 55 | # rm very large .git folder 56 | rm -rf .git 57 | 58 | # Horovod 59 | RUN conda install -y nccl -c conda-forge 60 | 61 | ENV LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH 62 | 63 | COPY --from=builder /horovod/horovod-*.whl /horovod/ 64 | 65 | RUN cd /horovod && \ 66 | pip install --no-cache-dir horovod-*.whl && \ 67 | rm -fr /horovod -------------------------------------------------------------------------------- /docker/hvd/Dockerfile.hvd-nlp: -------------------------------------------------------------------------------- 1 | # Dockerfile.hvd-nlp 2 | FROM pytorchignite/hvd-base:latest 3 | 4 | # Ignite NLP dependencies 5 | RUN pip install --upgrade --no-cache-dir transformers \ 6 | spacy \ 7 | nltk 8 | -------------------------------------------------------------------------------- /docker/hvd/Dockerfile.hvd-vision: -------------------------------------------------------------------------------- 1 | # Dockerfile.hvd-vision 2 | FROM pytorchignite/hvd-base:latest 3 | 4 | # Ignite vision dependencies 5 | RUN pip install --upgrade --no-cache-dir albumentations \ 6 | image-dataset-viz \ 7 | numpy \ 8 | opencv-python-headless \ 9 | py_config_runner \ 10 | clearml 11 | -------------------------------------------------------------------------------- /docker/main/Dockerfile.apex: -------------------------------------------------------------------------------- 1 | # Multi-stage build 2 | # Dockerfile.apex 3 | 4 | ARG PTH_VERSION 5 | 6 | # 1/Building apex with pytorch:*-devel 7 | FROM pytorch/pytorch:${PTH_VERSION}-devel AS apex-builder 8 | 9 | ENV CUDA_HOME=/usr/local/cuda 10 | 11 | # Install git 12 | RUN apt-get update && apt-get install -y --no-install-recommends git && \ 13 | rm -rf /var/lib/apt/lists/* 14 | 15 | # Build apex 16 | RUN echo "Setup NVIDIA Apex" && \ 17 | tmp_apex_path="/tmp/apex" && \ 18 | rm -rf 
$tmp_apex_path && \ 19 | git clone https://github.com/NVIDIA/apex $tmp_apex_path && \ 20 | cd $tmp_apex_path && \ 21 | pip install packaging && \ 22 | pip wheel -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" . 23 | 24 | # 2/ Build the runtime image 25 | FROM pytorch/pytorch:${PTH_VERSION}-runtime 26 | 27 | COPY --from=apex-builder /tmp/apex/apex-*.whl /tmp/apex/ 28 | RUN pip install --no-cache-dir /tmp/apex/apex-*.whl && \ 29 | rm -fr /tmp/apex 30 | 31 | # Install tzdata / git 32 | RUN apt-get update && \ 33 | ln -fs /usr/share/zoneinfo/Europe/Paris /etc/localtime && \ 34 | apt-get -y install --no-install-recommends tzdata git && \ 35 | dpkg-reconfigure --frontend noninteractive tzdata && \ 36 | apt-get autoremove -y && \ 37 | apt-get clean -y && \ 38 | rm -rf /var/lib/apt/lists/* 39 | 40 | # Ignite main dependencies 41 | RUN pip install --upgrade --no-cache-dir pytorch-ignite \ 42 | tensorboard \ 43 | tqdm \ 44 | fire 45 | 46 | # replace pillow with pillow-simd 47 | RUN apt-get update && apt-get -y install --no-install-recommends g++ && \ 48 | pip uninstall -y pillow && \ 49 | CC="cc -mavx2" pip install --upgrade --no-cache-dir --force-reinstall pillow-simd && \ 50 | apt-get remove -y g++ && \ 51 | apt-get autoremove -y && \ 52 | rm -rf /var/lib/apt/lists/* 53 | 54 | # Checkout Ignite examples only 55 | RUN mkdir -p pytorch-ignite-examples && \ 56 | cd pytorch-ignite-examples && \ 57 | git init && \ 58 | git config core.sparsecheckout true && \ 59 | echo examples >> .git/info/sparse-checkout && \ 60 | git remote add -f origin https://github.com/pytorch/ignite.git && \ 61 | git pull origin master && \ 62 | # rm very large .git folder 63 | rm -rf .git -------------------------------------------------------------------------------- /docker/main/Dockerfile.apex-nlp: -------------------------------------------------------------------------------- 1 | # 
Dockerfile.apex-nlp 2 | FROM pytorchignite/apex:latest 3 | 4 | # Ignite NLP dependencies 5 | RUN pip install --upgrade --no-cache-dir transformers \ 6 | spacy \ 7 | nltk 8 | -------------------------------------------------------------------------------- /docker/main/Dockerfile.apex-vision: -------------------------------------------------------------------------------- 1 | # Dockerfile.apex-vision 2 | FROM pytorchignite/apex:latest 3 | 4 | # Ignite vision dependencies 5 | RUN pip install --upgrade --no-cache-dir albumentations \ 6 | image-dataset-viz \ 7 | numpy \ 8 | opencv-python-headless \ 9 | py_config_runner \ 10 | clearml 11 | -------------------------------------------------------------------------------- /docker/main/Dockerfile.base: -------------------------------------------------------------------------------- 1 | # Dockerfile.base 2 | ARG PTH_VERSION 3 | 4 | FROM pytorch/pytorch:${PTH_VERSION}-runtime 5 | 6 | # Install tzdata / git 7 | RUN apt-get update && \ 8 | ln -fs /usr/share/zoneinfo/Europe/Paris /etc/localtime && \ 9 | apt-get -y install --no-install-recommends tzdata git && \ 10 | dpkg-reconfigure --frontend noninteractive tzdata && \ 11 | apt-get autoremove -y && \ 12 | apt-get clean -y && \ 13 | rm -rf /var/lib/apt/lists/* 14 | 15 | # Ignite main dependencies 16 | RUN pip install --upgrade --no-cache-dir pytorch-ignite \ 17 | tensorboard \ 18 | tqdm \ 19 | fire 20 | 21 | # Replace pillow with pillow-simd 22 | RUN apt-get update && apt-get -y install --no-install-recommends g++ && \ 23 | pip uninstall -y pillow && \ 24 | CC="cc -mavx2" pip install --upgrade --no-cache-dir --force-reinstall pillow-simd && \ 25 | apt-get remove -y g++ && \ 26 | apt-get autoremove -y && \ 27 | rm -rf /var/lib/apt/lists/* 28 | 29 | # Checkout Ignite examples only 30 | RUN mkdir -p pytorch-ignite-examples && \ 31 | cd pytorch-ignite-examples && \ 32 | git init && \ 33 | git config core.sparsecheckout true && \ 34 | echo examples >> .git/info/sparse-checkout && \ 35 
| git remote add -f origin https://github.com/pytorch/ignite.git && \ 36 | git pull origin master && \ 37 | # rm very large .git folder 38 | rm -rf .git -------------------------------------------------------------------------------- /docker/main/Dockerfile.nlp: -------------------------------------------------------------------------------- 1 | # Dockerfile.nlp 2 | FROM pytorchignite/base:latest 3 | 4 | # Ignite NLP dependencies 5 | RUN pip install --upgrade --no-cache-dir transformers \ 6 | spacy \ 7 | nltk -------------------------------------------------------------------------------- /docker/main/Dockerfile.vision: -------------------------------------------------------------------------------- 1 | # Dockerfile.vision 2 | FROM pytorchignite/base:latest 3 | 4 | # Ignite vision dependencies 5 | RUN pip install --upgrade --no-cache-dir albumentations \ 6 | image-dataset-viz \ 7 | numpy \ 8 | opencv-python-headless \ 9 | py_config_runner \ 10 | clearml 11 | -------------------------------------------------------------------------------- /docker/msdp/Dockerfile.msdp-apex-nlp: -------------------------------------------------------------------------------- 1 | # Dockerfile.msdp-apex-nlp 2 | FROM pytorchignite/msdp-apex:latest 3 | 4 | # Ignite NLP dependencies 5 | RUN pip install --upgrade --no-cache-dir transformers \ 6 | spacy \ 7 | nltk \ 8 | torchtext 9 | -------------------------------------------------------------------------------- /docker/msdp/Dockerfile.msdp-apex-vision: -------------------------------------------------------------------------------- 1 | # Dockerfile.msdp-apex-vision 2 | FROM pytorchignite/msdp-apex:latest 3 | 4 | # Ignite vision dependencies 5 | RUN pip install --upgrade --no-cache-dir albumentations \ 6 | image-dataset-viz \ 7 | numpy \ 8 | opencv-python-headless \ 9 | py_config_runner \ 10 | clearml 11 | -------------------------------------------------------------------------------- /docker/push_all.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Push all PyTorch-Ignite docker images" 4 | 5 | if [ -z $DOCKER_USER ]; then 6 | echo "Can not find DOCKER_USER env variable" 7 | echo "Please, export DOCKER_USER= before calling this script" 8 | exit 1 9 | fi 10 | 11 | if [ -z $DOCKER_TOKEN ]; then 12 | echo "Can not find DOCKER_TOKEN env variable" 13 | echo "Please, export DOCKER_TOKEN= before calling this script" 14 | exit 1 15 | fi 16 | 17 | if [ -z "$1" ]; then 18 | push_selected_image="all" 19 | else 20 | push_selected_image="$1" 21 | fi 22 | 23 | set -eu 24 | 25 | echo $DOCKER_TOKEN | docker login --username=$DOCKER_USER --password-stdin 26 | 27 | set -xeu 28 | 29 | if [ ${push_selected_image} == "all" ]; then 30 | image_name="base" 31 | image_tag=`docker run --rm -i pytorchignite/${image_name}:latest python -c "import torch; import ignite; print(torch.__version__.split('+')[0] + \"-\" + ignite.__version__, end=\"\")"` 32 | 33 | for image_name in "base" "vision" "nlp" "apex" "apex-vision" "apex-nlp" 34 | do 35 | 36 | docker push pytorchignite/${image_name}:latest 37 | docker push pytorchignite/${image_name}:${image_tag} 38 | 39 | done 40 | 41 | for image_name in "hvd-base" "hvd-vision" "hvd-nlp" "hvd-apex" "hvd-apex-vision" "hvd-apex-nlp" 42 | do 43 | 44 | docker push pytorchignite/${image_name}:latest 45 | docker push pytorchignite/${image_name}:${image_tag} 46 | 47 | done 48 | 49 | # DEPRECATED due to no activity 50 | # for image_name in "msdp-apex" "msdp-apex-vision" "msdp-apex-nlp" 51 | # do 52 | 53 | # docker push pytorchignite/${image_name}:latest 54 | # docker push pytorchignite/${image_name}:${image_tag} 55 | 56 | # done 57 | 58 | else 59 | 60 | image_name=${push_selected_image} 61 | image_tag=`docker run --rm -i pytorchignite/${image_name}:latest python -c "import torch; import ignite; print(torch.__version__.split('+')[0] + \"-\" + ignite.__version__, end=\"\")"` 62 | 63 | docker push 
pytorchignite/${image_name}:latest 64 | docker push pytorchignite/${image_name}:${image_tag} 65 | 66 | fi 67 | 68 | 69 | 70 | # If use locally, mind to clean dangling images 71 | # docker images | grep 'pytorchignite\|' | awk '{print $3}' | xargs docker rmi -f 72 | # or 73 | # docker image prune 74 | -------------------------------------------------------------------------------- /docker/test_image.py: -------------------------------------------------------------------------------- 1 | # 2 | # Tests : 3 | # For all images 4 | # can import torch and its version == required one 5 | # can import ignite and its version == required one 6 | # for all -vision images 7 | # can import opencv without driver issue 8 | # for all horovod images 9 | # can import horovod and its version == required one 10 | # 11 | import argparse 12 | import importlib 13 | import os 14 | 15 | 16 | def check_package(package_name, expected_version=None): 17 | mod = importlib.import_module(package_name) 18 | 19 | if expected_version is not None: 20 | assert hasattr(mod, "__version__"), f"Imported package {package_name} does not have __version__ attribute" 21 | version = mod.__version__ 22 | # Remove all +something from the version name: e.g torch 2.5.1+cu124 23 | if "+" in version: 24 | old_version = version 25 | version = version.split("+")[0] 26 | print(f"Transformed version: {old_version} -> {version}") 27 | assert ( 28 | version == expected_version 29 | ), f"Version mismatch for package {package_name}: got {version} but expected {expected_version}" 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser("Check docker image script") 34 | parser.add_argument("image", type=str, help="Docker image to check") 35 | args = parser.parse_args() 36 | 37 | docker_image_name = args.image 38 | name, version = docker_image_name.split(":") 39 | assert version != "latest", version 40 | torch_version, ignite_version = version.split("-") 41 | _, image_type = name.split("/") 42 | 43 | 
check_package("torch", expected_version=torch_version) 44 | check_package("ignite", expected_version=ignite_version) 45 | 46 | if "hvd" in image_type: 47 | assert "HVD_VERSION" in os.environ 48 | val = os.environ["HVD_VERSION"] 49 | hvd_version = val if val[0] != "v" else val[1:] 50 | check_package("horovod", expected_version=hvd_version) 51 | 52 | if "msdp" in image_type: 53 | assert "MSDP_VERSION" in os.environ 54 | val = os.environ["MSDP_VERSION"] 55 | hvd_version = val if val[0] != "v" else val[1:] 56 | check_package("deepspeed", expected_version=hvd_version) 57 | 58 | if "vision" in image_type: 59 | check_package("cv2") 60 | 61 | if "nlp" in image_type: 62 | check_package("transformers") 63 | 64 | if "apex" in image_type: 65 | check_package("apex") 66 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -j auto -WT --keep-going --color 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = ignite 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | docset: html 16 | doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url https://pytorch.org/ignite/ --force $(BUILDDIR)/html/ 17 | 18 | # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution. 19 | cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png 20 | convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png 21 | 22 | rebuild: 23 | rm -rf source/generated && make clean && make html 24 | 25 | clean: 26 | @echo "Cleaning up..." 
27 | python -c "import shutil; shutil.rmtree('$(BUILDDIR)', ignore_errors=True)" 28 | python -c "import shutil; shutil.rmtree('$(SOURCEDIR)/generated', ignore_errors=True)" 29 | python -c "import os; [os.remove(f) for f in os.listdir('.') if f.endswith('.pyc')]" 30 | python -c "import shutil; import os; [shutil.rmtree(f) for f in os.listdir('.') if f == '__pycache__' and os.path.isdir(f)]" 31 | 32 | .PHONY: help Makefile docset 33 | 34 | # Catch-all target: route all unknown targets to Sphinx using the new 35 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 36 | %: Makefile 37 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 38 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=ignite 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx<6 2 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 3 | sphinxcontrib-katex 4 | sphinx-copybutton==0.4.0 5 | docutils<0.18 6 | sphinx_togglebutton 7 | sphinx_design 8 | -------------------------------------------------------------------------------- /docs/source/_static/img/concepts/timeline_and_events.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/docs/source/_static/img/concepts/timeline_and_events.png -------------------------------------------------------------------------------- /docs/source/_static/img/schedulers/concat_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/docs/source/_static/img/schedulers/concat_example.png -------------------------------------------------------------------------------- /docs/source/_static/img/schedulers/concat_linear_exp_step_lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/docs/source/_static/img/schedulers/concat_linear_exp_step_lr.png -------------------------------------------------------------------------------- /docs/source/_static/img/schedulers/cosine_annealing_example.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/docs/source/_static/img/schedulers/cosine_annealing_example.png -------------------------------------------------------------------------------- /docs/source/_static/img/schedulers/linear_cyclical_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/docs/source/_static/img/schedulers/linear_cyclical_example.png -------------------------------------------------------------------------------- /docs/source/_static/img/schedulers/lr_scheduler.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/docs/source/_static/img/schedulers/lr_scheduler.png -------------------------------------------------------------------------------- /docs/source/_static/img/schedulers/piecewise_linear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/docs/source/_static/img/schedulers/piecewise_linear.png -------------------------------------------------------------------------------- /docs/source/_static/img/schedulers/reduce_lr_on_plateau_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/docs/source/_static/img/schedulers/reduce_lr_on_plateau_example.png -------------------------------------------------------------------------------- /docs/source/_templates/_static/img/ignite_logomark.svg: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ objname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | 8 | {% block methods %} 9 | {% if methods %} 10 | .. rubric:: {{ _('Methods') }} 11 | 12 | .. autosummary:: 13 | :nosignatures: 14 | {% for item in methods %} 15 | {% if item != "__init__" and item not in inherited_members %} 16 | ~{{ name }}.{{ item }} 17 | {% endif %} 18 | {% endfor %} 19 | {% endif %} 20 | {% endblock %} 21 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/function.rst: -------------------------------------------------------------------------------- 1 | {{ objname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autofunction:: {{ objname }} 6 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/module.rst: -------------------------------------------------------------------------------- 1 | {{ name | escape | underline }} 2 | 3 | .. automodule:: {{ fullname }} 4 | :members: 5 | 6 | {% block attributes %} 7 | {% if attributes %} 8 | .. rubric:: {{ _('Module Attributes') }} 9 | 10 | .. autosummary:: 11 | :nosignatures: 12 | {% for item in attributes %} 13 | {{ item }} 14 | {%- endfor %} 15 | {% endif %} 16 | {% endblock %} 17 | 18 | {% block functions %} 19 | {% if functions %} 20 | .. rubric:: {{ _('Functions') }} 21 | 22 | .. autosummary:: 23 | :nosignatures: 24 | {% for item in functions %} 25 | {{ item }} 26 | {%- endfor %} 27 | {% endif %} 28 | {% endblock %} 29 | 30 | {% block classes %} 31 | {% if classes %} 32 | .. rubric:: {{ _('Classes') }} 33 | 34 | .. 
autosummary:: 35 | :nosignatures: 36 | {% for item in classes %} 37 | {{ item }} 38 | {%- endfor %} 39 | {% endif %} 40 | {% endblock %} 41 | 42 | {% block exceptions %} 43 | {% if exceptions %} 44 | .. rubric:: {{ _('Exceptions') }} 45 | 46 | .. autosummary:: 47 | :nosignatures: 48 | {% for item in exceptions %} 49 | {{ item }} 50 | {%- endfor %} 51 | {% endif %} 52 | {% endblock %} 53 | 54 | {% block modules %} 55 | {% if modules %} 56 | .. rubric:: Modules 57 | 58 | .. autosummary:: 59 | :toctree: 60 | :recursive: 61 | {% for item in modules %} 62 | {{ item }} 63 | {%- endfor %} 64 | {% endif %} 65 | {% endblock %} 66 | -------------------------------------------------------------------------------- /docs/source/_templates/classwithcall.rst: -------------------------------------------------------------------------------- 1 | {{ objname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | :special-members: __call__ 8 | 9 | {% block methods %} 10 | {% if methods %} 11 | .. rubric:: {{ _('Methods') }} 12 | 13 | .. 
autosummary:: 14 | :nosignatures: 15 | 16 | ~{{ name }}.__call__ 17 | {% for item in methods %} 18 | {% if item != "__init__" and item not in inherited_members %} 19 | ~{{ name }}.{{ item }} 20 | {% endif %} 21 | {% endfor %} 22 | {% endif %} 23 | {% endblock %} 24 | -------------------------------------------------------------------------------- /docs/source/_templates/fonts.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/source/_templates/theme_variables.jinja: -------------------------------------------------------------------------------- 1 | {%- set external_urls = { 2 | 'github': 'https://github.com/pytorch/ignite', 3 | 'github_issues': 'https://github.com/pytorch/ignite/issues', 4 | 'contributing': 'https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md', 5 | 'docs': 'https://pytorch.org/ignite/index.html', 6 | 7 | 'home': 'https://pytorch-ignite.ai', 8 | 9 | 'concepts': 'concepts/', 10 | 'quickstart': 'tutorials/beginner/01-getting-started/', 11 | 'examples': 'tutorials/beginner/01-getting-started/#complete-code', 12 | 'faq': 'how-to-guides/', 13 | 'about_us': 'about/community/' 14 | } 15 | -%} -------------------------------------------------------------------------------- /docs/source/contrib/engines.rst: -------------------------------------------------------------------------------- 1 | ignite.contrib.engines 2 | ====================== 3 | 4 | Contribution module of engines and helper tools: 5 | 6 | ignite.contrib.engines.tbptt 7 | 8 | .. currentmodule:: ignite.contrib.engines.tbptt 9 | 10 | .. autosummary:: 11 | :nosignatures: 12 | :autolist: 13 | 14 | ignite.contrib.engines.common 15 | 16 | .. currentmodule:: ignite.contrib.engines.common 17 | 18 | .. 
autosummary:: 19 | :nosignatures: 20 | :autolist: 21 | 22 | Truncated Backpropagation Through Time 23 | --------------------------------------- 24 | 25 | .. automodule:: ignite.contrib.engines.tbptt 26 | :members: 27 | 28 | Helper methods to setup trainer/evaluator 29 | ----------------------------------------- 30 | 31 | .. automodule:: ignite.contrib.engines.common 32 | :members: 33 | -------------------------------------------------------------------------------- /docs/source/contrib/handlers.rst: -------------------------------------------------------------------------------- 1 | ignite.contrib.handlers 2 | ======================= 3 | 4 | Contribution module of handlers 5 | 6 | 7 | Parameter scheduler [deprecated] 8 | -------------------------------- 9 | 10 | .. deprecated:: 0.4.4 11 | Use :class:`~ignite.handlers.param_scheduler.ParamScheduler` instead, will be removed in version 0.6.0. 12 | 13 | Was moved to :ref:`param-scheduler-label`. 14 | 15 | LR finder [deprecated] 16 | ---------------------- 17 | 18 | .. deprecated:: 0.4.4 19 | Use :class:`~ignite.handlers.lr_finder.FastaiLRFinder` instead, will be removed in version 0.6.0. 20 | 21 | Time profilers [deprecated] 22 | --------------------------- 23 | 24 | .. deprecated:: 0.4.6 25 | Use :class:`~ignite.handlers.time_profilers.BasicTimeProfiler` instead, will be removed in version 0.6.0. 26 | Use :class:`~ignite.handlers.time_profilers.HandlersTimeProfiler` instead, will be removed in version 0.6.0. 27 | 28 | Loggers [deprecated] 29 | -------------------- 30 | 31 | .. deprecated:: 0.5.0 32 | Loggers moved to :ref:`Loggers`. 33 | -------------------------------------------------------------------------------- /docs/source/contrib/metrics.rst: -------------------------------------------------------------------------------- 1 | ignite.contrib.metrics 2 | ======================= 3 | 4 | Contrib module metrics [deprecated] 5 | ----------------------------------- 6 | 7 | .. 
deprecated:: 0.5.0 8 | All metrics moved to :ref:`Complete list of metrics`. 9 | 10 | 11 | Regression metrics [deprecated] 12 | -------------------------------- 13 | 14 | .. deprecated:: 0.5.0 15 | All metrics moved to :ref:`Complete list of metrics`. 16 | -------------------------------------------------------------------------------- /docs/source/defaults.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. toggle:: 4 | 5 | .. testcode:: default, 1, 2, 3, 4, 5 6 | 7 | from collections import OrderedDict 8 | 9 | import torch 10 | from torch import nn, optim 11 | 12 | from ignite.engine import * 13 | from ignite.handlers import * 14 | from ignite.metrics import * 15 | from ignite.metrics.clustering import * 16 | from ignite.metrics.regression import * 17 | from ignite.utils import * 18 | 19 | # create default evaluator for doctests 20 | 21 | def eval_step(engine, batch): 22 | return batch 23 | 24 | default_evaluator = Engine(eval_step) 25 | 26 | # create default optimizer for doctests 27 | 28 | param_tensor = torch.zeros([1], requires_grad=True) 29 | default_optimizer = torch.optim.SGD([param_tensor], lr=0.1) 30 | 31 | # create default trainer for doctests 32 | # as handlers could be attached to the trainer, 33 | # each test must define its own trainer using `.. testsetup::` 34 | 35 | def get_default_trainer(): 36 | 37 | def train_step(engine, batch): 38 | return batch 39 | 40 | return Engine(train_step) 41 | 42 | # create default model for doctests 43 | 44 | default_model = nn.Sequential(OrderedDict([ 45 | ('base', nn.Linear(4, 2)), 46 | ('fc', nn.Linear(2, 1)) 47 | ])) 48 | 49 | manual_seed(666) 50 | -------------------------------------------------------------------------------- /docs/source/exceptions.rst: -------------------------------------------------------------------------------- 1 | ignite.exceptions 2 | ================= 3 | 4 | .. currentmodule:: ignite.exceptions 5 | 6 | ..
autosummary:: 7 | :nosignatures: 8 | :autolist: 9 | 10 | .. autoclass:: NotComputableError 11 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Ignite Your Networks! 2 | ===================== 3 | 4 | PyTorch-Ignite is a high-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently. 5 | 6 | .. image:: https://raw.githubusercontent.com/pytorch/ignite/master/assets/logo/ignite_logo_mixed.png 7 | :target: https://pytorch-ignite.ai 8 | 9 | All our documentation has moved to `pytorch-ignite.ai <https://pytorch-ignite.ai>`_ 10 | 11 | 12 | .. automodule:: ignite 13 | 14 | .. toctree:: 15 | :maxdepth: 2 16 | :caption: Package Reference 17 | 18 | engine 19 | handlers 20 | metrics 21 | distributed 22 | exceptions 23 | utils 24 | 25 | 26 | .. automodule:: ignite.contrib 27 | 28 | .. toctree:: 29 | :maxdepth: 2 30 | :caption: Contrib Package Reference 31 | 32 | contrib/engines 33 | contrib/metrics 34 | contrib/handlers 35 | -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | ignite.utils 2 | ============ 3 | 4 | Module with helper methods 5 | 6 | .. currentmodule:: ignite.utils 7 | 8 | .. autosummary:: 9 | :nosignatures: 10 | :autolist: 11 | 12 | ..
automodule:: ignite.utils 13 | :members: 14 | -------------------------------------------------------------------------------- /examples/cifar10/.gitignore: -------------------------------------------------------------------------------- 1 | output 2 | cifar10 3 | raw_pytorch 4 | -------------------------------------------------------------------------------- /examples/cifar10/requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-ignite 2 | torchvision 3 | tqdm 4 | tensorboardX 5 | fire 6 | clearml 7 | -------------------------------------------------------------------------------- /examples/cifar10/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from torchvision import datasets, models 5 | from torchvision.transforms import Compose, Normalize, Pad, RandomCrop, RandomHorizontalFlip, ToTensor 6 | 7 | train_transform = Compose( 8 | [ 9 | Pad(4), 10 | RandomCrop(32, fill=128), 11 | RandomHorizontalFlip(), 12 | ToTensor(), 13 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 14 | ] 15 | ) 16 | 17 | test_transform = Compose([ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 18 | 19 | 20 | def get_train_test_datasets(path): 21 | path = Path(path) 22 | if not path.exists(): 23 | path.mkdir(parents=True) 24 | download = True 25 | else: 26 | download = True if len(os.listdir(path)) < 1 else False 27 | 28 | train_ds = datasets.CIFAR10(root=path, train=True, download=download, transform=train_transform) 29 | test_ds = datasets.CIFAR10(root=path, train=False, download=False, transform=test_transform) 30 | 31 | return train_ds, test_ds 32 | 33 | 34 | def get_model(name): 35 | if name in models.__dict__: 36 | fn = models.__dict__[name] 37 | else: 38 | raise RuntimeError(f"Unknown model name {name}") 39 | 40 | return fn(num_classes=10) 41 | 
-------------------------------------------------------------------------------- /examples/cifar100_amp_benchmark/benchmark_fp32.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import torch 3 | from torch.nn import CrossEntropyLoss 4 | from torch.optim import SGD 5 | from torchvision.models import wide_resnet50_2 6 | from utils import get_train_eval_loaders 7 | 8 | from ignite.engine import convert_tensor, create_supervised_evaluator, Engine, Events 9 | 10 | from ignite.handlers import ProgressBar, Timer 11 | from ignite.metrics import Accuracy, Loss 12 | 13 | 14 | def main(dataset_path, batch_size=256, max_epochs=10): 15 | assert torch.cuda.is_available() 16 | assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled." 17 | torch.backends.cudnn.benchmark = True 18 | 19 | device = "cuda" 20 | 21 | train_loader, test_loader, eval_train_loader = get_train_eval_loaders(dataset_path, batch_size=batch_size) 22 | 23 | model = wide_resnet50_2(num_classes=100).to(device) 24 | optimizer = SGD(model.parameters(), lr=0.01) 25 | criterion = CrossEntropyLoss().to(device) 26 | 27 | def train_step(engine, batch): 28 | x = convert_tensor(batch[0], device, non_blocking=True) 29 | y = convert_tensor(batch[1], device, non_blocking=True) 30 | 31 | optimizer.zero_grad() 32 | 33 | y_pred = model(x) 34 | loss = criterion(y_pred, y) 35 | loss.backward() 36 | 37 | optimizer.step() 38 | 39 | return loss.item() 40 | 41 | trainer = Engine(train_step) 42 | timer = Timer(average=True) 43 | timer.attach(trainer, step=Events.EPOCH_COMPLETED) 44 | ProgressBar(persist=True).attach(trainer, output_transform=lambda out: {"batch loss": out}) 45 | 46 | metrics = {"Accuracy": Accuracy(), "Loss": Loss(criterion)} 47 | 48 | evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True) 49 | 50 | def log_metrics(engine, title): 51 | for name in metrics: 52 | print(f"\t{title} {name}: 
{engine.state.metrics[name]:.2f}") 53 | 54 | @trainer.on(Events.COMPLETED) 55 | def run_validation(_): 56 | print(f"- Mean elapsed time for 1 epoch: {timer.value()}") 57 | print("- Metrics:") 58 | with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Train"): 59 | evaluator.run(eval_train_loader) 60 | 61 | with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Test"): 62 | evaluator.run(test_loader) 63 | 64 | trainer.run(train_loader, max_epochs=max_epochs) 65 | 66 | 67 | if __name__ == "__main__": 68 | fire.Fire(main) 69 | -------------------------------------------------------------------------------- /examples/cifar100_amp_benchmark/benchmark_nvidia_apex.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import torch 3 | from apex import amp 4 | from torch.nn import CrossEntropyLoss 5 | from torch.optim import SGD 6 | from torchvision.models import wide_resnet50_2 7 | from utils import get_train_eval_loaders 8 | 9 | from ignite.engine import convert_tensor, create_supervised_evaluator, Engine, Events 10 | 11 | from ignite.handlers import ProgressBar, Timer 12 | from ignite.metrics import Accuracy, Loss 13 | 14 | 15 | def main(dataset_path, batch_size=256, max_epochs=10, opt="O1"): 16 | assert torch.cuda.is_available() 17 | assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled." 
18 | torch.backends.cudnn.benchmark = True 19 | 20 | device = "cuda" 21 | 22 | train_loader, test_loader, eval_train_loader = get_train_eval_loaders(dataset_path, batch_size=batch_size) 23 | 24 | model = wide_resnet50_2(num_classes=100).to(device) 25 | optimizer = SGD(model.parameters(), lr=0.01) 26 | criterion = CrossEntropyLoss().to(device) 27 | 28 | model, optimizer = amp.initialize(model, optimizer, opt_level=opt) 29 | 30 | def train_step(engine, batch): 31 | x = convert_tensor(batch[0], device, non_blocking=True) 32 | y = convert_tensor(batch[1], device, non_blocking=True) 33 | 34 | optimizer.zero_grad() 35 | 36 | y_pred = model(x) 37 | loss = criterion(y_pred, y) 38 | 39 | with amp.scale_loss(loss, optimizer) as scaled_loss: 40 | scaled_loss.backward() 41 | 42 | optimizer.step() 43 | 44 | return loss.item() 45 | 46 | trainer = Engine(train_step) 47 | timer = Timer(average=True) 48 | timer.attach(trainer, step=Events.EPOCH_COMPLETED) 49 | ProgressBar(persist=True).attach(trainer, output_transform=lambda out: {"batch loss": out}) 50 | 51 | metrics = {"Accuracy": Accuracy(), "Loss": Loss(criterion)} 52 | 53 | evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True) 54 | 55 | def log_metrics(engine, title): 56 | for name in metrics: 57 | print(f"\t{title} {name}: {engine.state.metrics[name]:.2f}") 58 | 59 | @trainer.on(Events.COMPLETED) 60 | def run_validation(_): 61 | print(f"- Mean elapsed time for 1 epoch: {timer.value()}") 62 | print("- Metrics:") 63 | with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Train"): 64 | evaluator.run(eval_train_loader) 65 | 66 | with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Test"): 67 | evaluator.run(test_loader) 68 | 69 | trainer.run(train_loader, max_epochs=max_epochs) 70 | 71 | 72 | if __name__ == "__main__": 73 | fire.Fire(main) 74 | -------------------------------------------------------------------------------- 
/examples/cifar100_amp_benchmark/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from torch.utils.data import DataLoader, Subset 4 | from torchvision.datasets.cifar import CIFAR100 5 | from torchvision.transforms import Compose, Normalize, Pad, RandomCrop, RandomErasing, RandomHorizontalFlip, ToTensor 6 | 7 | 8 | def get_train_eval_loaders(path, batch_size=256): 9 | """Setup the dataflow: 10 | - load CIFAR100 train and test datasets 11 | - setup train/test image transforms 12 | - horizontally flipped randomly and augmented using cutout. 13 | - each mini-batch contained 256 examples 14 | - setup train/test data loaders 15 | 16 | Returns: 17 | train_loader, test_loader, eval_train_loader 18 | """ 19 | train_transform = Compose( 20 | [ 21 | Pad(4), 22 | RandomCrop(32), 23 | RandomHorizontalFlip(), 24 | ToTensor(), 25 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 26 | RandomErasing(), 27 | ] 28 | ) 29 | 30 | test_transform = Compose([ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 31 | 32 | train_dataset = CIFAR100(root=path, train=True, transform=train_transform, download=True) 33 | test_dataset = CIFAR100(root=path, train=False, transform=test_transform, download=False) 34 | 35 | train_eval_indices = [random.randint(0, len(train_dataset) - 1) for i in range(len(test_dataset))] 36 | train_eval_dataset = Subset(train_dataset, train_eval_indices) 37 | 38 | train_loader = DataLoader( 39 | train_dataset, batch_size=batch_size, num_workers=12, shuffle=True, drop_last=True, pin_memory=True 40 | ) 41 | 42 | test_loader = DataLoader( 43 | test_dataset, batch_size=batch_size, num_workers=12, shuffle=False, drop_last=False, pin_memory=True 44 | ) 45 | 46 | eval_train_loader = DataLoader( 47 | train_eval_dataset, batch_size=batch_size, num_workers=12, shuffle=False, drop_last=False, pin_memory=True 48 | ) 49 | 50 | return train_loader, test_loader, 
eval_train_loader 51 | -------------------------------------------------------------------------------- /examples/cifar10_qat/.gitignore: -------------------------------------------------------------------------------- 1 | output 2 | cifar10 3 | raw_pytorch 4 | -------------------------------------------------------------------------------- /examples/cifar10_qat/pact.py: -------------------------------------------------------------------------------- 1 | # Implementation taken from https://discuss.pytorch.org/t/evaluator-returns-nan/107972/3 2 | # Ref: https://arxiv.org/abs/1805.06085 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class PACTClip(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, x, alpha): 11 | ctx.save_for_backward(x, alpha) 12 | return torch.clamp(x, 0, alpha.data) 13 | 14 | @staticmethod 15 | def backward(ctx, dy): 16 | x, alpha = ctx.saved_tensors 17 | 18 | dx = dy.clone() 19 | dx[x < 0] = 0 20 | dx[x > alpha] = 0 21 | 22 | dalpha = dy.clone() 23 | dalpha[x <= alpha] = 0 24 | 25 | return dx, torch.sum(dalpha) 26 | 27 | 28 | class PACTReLU(nn.Module): 29 | def __init__(self, alpha=6.0): 30 | super().__init__() 31 | self.alpha = nn.Parameter(torch.tensor(alpha)) 32 | 33 | def forward(self, x): 34 | return PACTClip.apply(x, self.alpha) 35 | -------------------------------------------------------------------------------- /examples/fast_neural_style/handlers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | class Progbar(object): 5 | def __init__(self, loader, metrics): 6 | self.num_iterations = len(loader) 7 | self.output_stream = sys.stdout 8 | self.metrics = metrics 9 | self.alpha = 0.98 10 | 11 | def _calc_running_avg(self, engine): 12 | for k, v in engine.state.output.items(): 13 | old_v = self.metrics.get(k, v) 14 | new_v = self.alpha * old_v + (1 - self.alpha) * v 15 | self.metrics[k] = new_v 16 | 17 | def __call__(self, engine): 18 | 
self._calc_running_avg(engine) 19 | num_seen = engine.state.iteration - self.num_iterations * (engine.state.epoch - 1) 20 | 21 | percent_seen = 100 * float(num_seen) / self.num_iterations 22 | equal_to = int(percent_seen / 10) 23 | done = int(percent_seen) == 100 24 | 25 | bar = "[" + "=" * equal_to + ">" * (not done) + " " * (10 - equal_to) + "]" 26 | message = f"Epoch {engine.state.epoch} | {percent_seen:.2f}% | {bar}" 27 | for key, value in self.metrics.items(): 28 | message += f" | {key}: {value:.2e}" 29 | 30 | message += "\r" 31 | 32 | self.output_stream.write(message) 33 | self.output_stream.flush() 34 | 35 | if done: 36 | self.output_stream.write("\n") 37 | -------------------------------------------------------------------------------- /examples/fast_neural_style/images/content_images/amber.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/fast_neural_style/images/content_images/amber.jpg -------------------------------------------------------------------------------- /examples/fast_neural_style/images/output_images/mosaic_amber.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/fast_neural_style/images/output_images/mosaic_amber.jpg -------------------------------------------------------------------------------- /examples/fast_neural_style/images/style_images/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/fast_neural_style/images/style_images/mosaic.jpg -------------------------------------------------------------------------------- /examples/fast_neural_style/utils.py: -------------------------------------------------------------------------------- 
1 | from PIL import Image 2 | 3 | 4 | def load_image(filename, size=None, scale=None): 5 | img = Image.open(filename) 6 | if size is not None: 7 | img = img.resize((size, size), Image.LANCZOS) 8 | elif scale is not None: 9 | img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.LANCZOS) 10 | return img 11 | 12 | 13 | def save_image(filename, data): 14 | img = data.clone().clamp(0, 255).numpy() 15 | img = img.transpose(1, 2, 0).astype("uint8") 16 | img = Image.fromarray(img) 17 | img.save(filename) 18 | 19 | 20 | def gram_matrix(y): 21 | (b, ch, h, w) = y.size() 22 | features = y.view(b, ch, w * h) 23 | features_t = features.transpose(1, 2) 24 | gram = features.bmm(features_t) / (ch * h * w) 25 | return gram 26 | 27 | 28 | def normalize_batch(batch): 29 | # normalize using imagenet mean and std 30 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) 31 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 32 | batch = batch.div_(255.0) 33 | return (batch - mean) / std 34 | -------------------------------------------------------------------------------- /examples/fast_neural_style/vgg.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | from torchvision import models 5 | from torchvision.models.vgg import VGG16_Weights 6 | 7 | 8 | class Vgg16(torch.nn.Module): 9 | def __init__(self, requires_grad=False): 10 | super(Vgg16, self).__init__() 11 | vgg_pretrained_features = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features 12 | self.slice1 = torch.nn.Sequential() 13 | self.slice2 = torch.nn.Sequential() 14 | self.slice3 = torch.nn.Sequential() 15 | self.slice4 = torch.nn.Sequential() 16 | for x in range(4): 17 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 18 | for x in range(4, 9): 19 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 20 | for x in range(9, 16): 21 | self.slice3.add_module(str(x), 
vgg_pretrained_features[x]) 22 | for x in range(16, 23): 23 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 24 | if not requires_grad: 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def forward(self, X): 29 | h = self.slice1(X) 30 | h_relu1_2 = h 31 | h = self.slice2(h) 32 | h_relu2_2 = h 33 | h = self.slice3(h) 34 | h_relu3_3 = h 35 | h = self.slice4(h) 36 | h_relu4_3 = h 37 | vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3"]) 38 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3) 39 | return out 40 | -------------------------------------------------------------------------------- /examples/gan/README.md: -------------------------------------------------------------------------------- 1 | # Deep Convolution Generative Adversarial Networks with Ignite 2 | 3 | ported from [pytorch-examples](https://github.com/pytorch/examples/tree/master/dcgan) 4 | 5 | Usage: 6 | 7 | For example, run on CIFAR10 dataset: 8 | 9 | ``` 10 | python dcgan.py --dataset cifar10 --dataroot /tmp/cifar10 --output-dir /tmp/outputs-dcgan 11 | ``` 12 | -------------------------------------------------------------------------------- /examples/mnist/assets/logs-det_run_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/mnist/assets/logs-det_run_1_2.png -------------------------------------------------------------------------------- /examples/mnist/assets/logs-det_run_3_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/mnist/assets/logs-det_run_3_4.png -------------------------------------------------------------------------------- /examples/mnist/assets/logs_run_1_2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/mnist/assets/logs_run_1_2.png -------------------------------------------------------------------------------- /examples/mnist/assets/logs_run_3_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/mnist/assets/logs_run_3_4.png -------------------------------------------------------------------------------- /examples/mnist/assets/run_vs_resume_run_logs_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/mnist/assets/run_vs_resume_run_logs_1_2.png -------------------------------------------------------------------------------- /examples/mnist/assets/run_vs_resume_run_logs_3_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/mnist/assets/run_vs_resume_run_logs_3_4.png -------------------------------------------------------------------------------- /examples/notebooks/assets/ax_hparams.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/notebooks/assets/ax_hparams.png -------------------------------------------------------------------------------- /examples/notebooks/assets/efficientnets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/notebooks/assets/efficientnets.png 
-------------------------------------------------------------------------------- /examples/notebooks/assets/fashion-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/notebooks/assets/fashion-mnist.png -------------------------------------------------------------------------------- /examples/notebooks/assets/tb_cyclegan_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/notebooks/assets/tb_cyclegan_images.png -------------------------------------------------------------------------------- /examples/references/README.md: -------------------------------------------------------------------------------- 1 | # Reproducible trainings with Ignite 2 | 3 | Inspired by [torchvision/references](https://github.com/pytorch/vision/tree/master/references), we provide several 4 | reproducible baselines for vision tasks: 5 | 6 | - [x] Classification 7 | 8 | - [x] [ImageNet](classification/imagenet) 9 | 10 | - [x] Segmentation 11 | - [x] [Pascal VOC2012](segmentation/pascal_voc2012) 12 | -------------------------------------------------------------------------------- /examples/references/classification/imagenet/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | input 3 | output 4 | experiments/plx/*.yml 5 | experiments/plx/*.yaml 6 | .polyaxon/ 7 | .polyaxonignore 8 | -------------------------------------------------------------------------------- /examples/references/classification/imagenet/README.md: -------------------------------------------------------------------------------- 1 | # Reproducible ImageNet training with Ignite 2 | 3 | In this example, we provide script and tools to perform reproducible experiments on training neural 
networks on ImageNet 4 | dataset. 5 | 6 | Features: 7 | 8 | - Distributed training with native automatic mixed precision 9 | - Experiments tracking with [ClearML](https://github.com/allegroai/clearml) 10 | 11 | | Model | Training Top-1 Accuracy | Training Top-5 Accuracy | Test Top-1 Accuracy | Test Top-5 Accuracy | 12 | | --------- | ----------------------- | ----------------------- | ------------------- | ------------------- | 13 | | ResNet-50 | 78% | 92% | 77% | 94% | 14 | 15 | Experiment | Model | Training Top-1 Accuracy | Training Top-5 Accuracy | Test Top-1 Accuracy | Test Top-5 Accuracy | ClearML Link 16 | ---|---|---|---|---|---|--- 17 | configs/???.py | 18 | 19 | ## Setup 20 | 21 | ``` 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | ### Docker 26 | 27 | For docker users, you can use the following image to run the example: 28 | ```bash 29 | docker pull pytorchignite/vision:latest 30 | ``` 31 | 32 | and install other requirements as suggested above. 33 | 34 | ## Usage 35 | 36 | Please export the `DATASET_PATH` environment variable for the ImageNet dataset. 37 | 38 | ```bash 39 | export DATASET_PATH=/path/to/imagenet 40 | # e.g. export DATASET_PATH=/data/ where "train", "val", "meta.bin" are located 41 | ``` 42 | 43 | ### Training 44 | 45 | #### Single GPU 46 | 47 | - Adjust batch size for your GPU type in the configuration file: `configs/baseline_resnet50.py` 48 | 49 | Run the following command: 50 | ```bash 51 | CUDA_VISIBLE_DEVICES=0 python -u main.py training configs/baseline_resnet50.py 52 | ``` 53 | #### Multiple GPUs 54 | 55 | - Adjust total batch size for your GPUs in the configuration file: `configs/baseline_resnet50.py` 56 | 57 | ```bash 58 | OMP_NUM_THREADS=1 torchrun --nproc_per_node=2 main.py training configs/baseline_resnet50.py 59 | ``` 60 | 61 | 62 | ## Acknowledgements 63 | 64 | Training was done using credits provided by the [trainml.ai](https://trainml.ai) platform.
65 | -------------------------------------------------------------------------------- /examples/references/classification/imagenet/requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations 2 | numpy 3 | opencv-python-headless 4 | fire 5 | pytorch-ignite 6 | tensorboard 7 | torch 8 | torchvision 9 | tqdm 10 | clearml 11 | image-dataset-viz 12 | py_config_runner>=0.2.0,<1.0.0 -------------------------------------------------------------------------------- /examples/references/classification/imagenet/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import ignite 4 | import ignite.distributed as idist 5 | from ignite.handlers import DiskSaver 6 | 7 | 8 | def initialize(config): 9 | device = idist.device() 10 | 11 | model = config.model.to(device) 12 | optimizer = config.optimizer 13 | 14 | # Adapt model to dist config 15 | model = idist.auto_model(model) 16 | optimizer = idist.auto_optim(optimizer) 17 | criterion = config.criterion.to(device) 18 | 19 | return model, optimizer, criterion 20 | 21 | 22 | def log_basic_info(logger, config): 23 | logger.info(f"- PyTorch version: {torch.__version__}") 24 | logger.info(f"- Ignite version: {ignite.__version__}") 25 | if torch.cuda.is_available(): 26 | # explicitly import cudnn as 27 | # torch.backends.cudnn can not be pickled with hvd spawning procs 28 | from torch.backends import cudnn 29 | 30 | logger.info(f"- GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}") 31 | logger.info(f"- CUDA version: {torch.version.cuda}") 32 | logger.info(f"- CUDNN version: {cudnn.version()}") 33 | 34 | logger.info("\n") 35 | logger.info("Configuration:") 36 | for key, value in config.items(): 37 | logger.info(f"\t{key}: {value}") 38 | logger.info("\n") 39 | 40 | if idist.get_world_size() > 1: 41 | logger.info("\nDistributed setting:") 42 | logger.info(f"\tbackend: {idist.backend()}") 43 | 
logger.info(f"\tworld size: {idist.get_world_size()}") 44 | logger.info("\n") 45 | 46 | 47 | def log_metrics(logger, epoch, elapsed, tag, metrics): 48 | metrics_output = "\n".join([f"\t{k}: {v}" for k, v in metrics.items()]) 49 | logger.info(f"\nEpoch {epoch} - Evaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n {metrics_output}") 50 | 51 | 52 | def get_save_handler(output_path, with_clearml): 53 | if with_clearml: 54 | from ignite.handlers.clearml_logger import ClearMLSaver 55 | 56 | return ClearMLSaver(dirname=output_path) 57 | 58 | return DiskSaver(output_path) 59 | -------------------------------------------------------------------------------- /examples/references/segmentation/pascal_voc2012/configs/eval_baseline_dplv3_resnet101_sbd.py: -------------------------------------------------------------------------------- 1 | # Basic training configuration 2 | import os 3 | 4 | import albumentations as A 5 | import cv2 6 | from albumentations.pytorch import ToTensorV2 as ToTensor 7 | from dataflow import get_inference_dataloader, ignore_mask_boundaries 8 | from torchvision.models.segmentation import deeplabv3_resnet101 9 | 10 | # ############################## 11 | # Global configs 12 | # ############################## 13 | 14 | seed = 21 15 | device = "cuda" 16 | debug = False 17 | # Use AMP with torch native 18 | with_amp = True 19 | 20 | 21 | num_classes = 21 22 | batch_size = 9 # total batch size 23 | num_workers = 8 # total num workers per node 24 | 25 | val_img_size = 513 26 | 27 | # ############################## 28 | # Setup Dataflow 29 | # ############################## 30 | 31 | assert "DATASET_PATH" in os.environ 32 | data_path = os.environ["DATASET_PATH"] 33 | 34 | assert "SBD_DATASET_PATH" in os.environ 35 | sbd_data_path = os.environ["SBD_DATASET_PATH"] 36 | 37 | mean = (0.485, 0.456, 0.406) 38 | std = (0.229, 0.224, 0.225) 39 | 40 | 41 | val_transforms = A.Compose( 42 | [ 43 | A.PadIfNeeded(val_img_size, val_img_size, 
border_mode=cv2.BORDER_CONSTANT), 44 | A.Normalize(mean=mean, std=std), 45 | ignore_mask_boundaries, 46 | ToTensor(), 47 | ] 48 | ) 49 | 50 | 51 | data_loader = get_inference_dataloader( 52 | root_path=data_path, 53 | mode="test", 54 | transforms=val_transforms, 55 | batch_size=batch_size, 56 | num_workers=num_workers, 57 | limit_num_samples=batch_size * 5 if debug else None, 58 | ) 59 | 60 | # ############################## 61 | # Setup model 62 | # ############################## 63 | 64 | num_classes = 21 65 | model = deeplabv3_resnet101(num_classes=num_classes) 66 | 67 | 68 | def model_output_transform(output): 69 | return output["out"] 70 | 71 | 72 | # baseline_dplv3_resnet101_sbd: best_model_78_val_miou_bg=0.6871.pt 73 | weights_path = "d8b4687d86cf445a944853fdd6a6b999" 74 | # or can specify a path 75 | # weights_path = "/path/to/best_model.pt" 76 | -------------------------------------------------------------------------------- /examples/references/segmentation/pascal_voc2012/requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations 2 | numpy 3 | opencv-python-headless 4 | fire 5 | pytorch-ignite 6 | tensorboard 7 | torch 8 | torchvision 9 | tqdm 10 | clearml 11 | image-dataset-viz 12 | py_config_runner>=0.2.0,<1.0.0 -------------------------------------------------------------------------------- /examples/references/segmentation/pascal_voc2012/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import ignite 4 | import ignite.distributed as idist 5 | from ignite.handlers import DiskSaver 6 | 7 | 8 | def initialize(config): 9 | device = idist.device() 10 | 11 | model = config.model.to(device) 12 | optimizer = config.optimizer 13 | 14 | # Adapt model to dist config 15 | model = idist.auto_model(model) 16 | optimizer = idist.auto_optim(optimizer) 17 | criterion = config.criterion.to(device) 18 | 19 | return model, optimizer, criterion 20 | 
21 | 22 | def log_basic_info(logger, config): 23 | logger.info(f"- PyTorch version: {torch.__version__}") 24 | logger.info(f"- Ignite version: {ignite.__version__}") 25 | if torch.cuda.is_available(): 26 | # explicitly import cudnn as 27 | # torch.backends.cudnn can not be pickled with hvd spawning procs 28 | from torch.backends import cudnn 29 | 30 | logger.info(f"- GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}") 31 | logger.info(f"- CUDA version: {torch.version.cuda}") 32 | logger.info(f"- CUDNN version: {cudnn.version()}") 33 | 34 | logger.info("\n") 35 | logger.info("Configuration:") 36 | for key, value in config.items(): 37 | logger.info(f"\t{key}: {value}") 38 | logger.info("\n") 39 | 40 | if idist.get_world_size() > 1: 41 | logger.info("\nDistributed setting:") 42 | logger.info(f"\tbackend: {idist.backend()}") 43 | logger.info(f"\tworld size: {idist.get_world_size()}") 44 | logger.info("\n") 45 | 46 | 47 | def log_metrics(logger, epoch, elapsed, tag, metrics): 48 | metrics_output = "\n".join([f"\t{k}: {v}" for k, v in metrics.items()]) 49 | logger.info(f"\nEpoch {epoch} - Evaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n {metrics_output}") 50 | 51 | 52 | def get_save_handler(output_path, with_clearml): 53 | if with_clearml: 54 | from ignite.handlers.clearml_logger import ClearMLSaver 55 | 56 | return ClearMLSaver(dirname=output_path) 57 | 58 | return DiskSaver(output_path) 59 | -------------------------------------------------------------------------------- /examples/reinforcement_learning/README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement learning training examples with Ignite 2 | 3 | ported from [pytorch-examples](https://github.com/pytorch/examples/tree/master/reinforcement_learning) 4 | 5 | ```bash 6 | pip install gymnasium 7 | # For REINFORCE: 8 | python reinforce.py 9 | # For actor critic: 10 | python actor_critic.py 11 | ``` 12 | 
-------------------------------------------------------------------------------- /examples/siamese_network/README.md: -------------------------------------------------------------------------------- 1 | # Siamese Network example on MNIST dataset 2 | 3 | This example is ported over from [pytorch/examples/siamese_network](https://github.com/pytorch/examples/tree/main/siamese_network) 4 | 5 | Usage: 6 | 7 | ``` 8 | pip install -r requirements.txt 9 | python siamese_network.py 10 | ``` 11 | -------------------------------------------------------------------------------- /examples/siamese_network/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | pytorch-ignite -------------------------------------------------------------------------------- /examples/super_resolution/README.md: -------------------------------------------------------------------------------- 1 | # Super-Resolution using an efficient sub-pixel convolutional neural network 2 | 3 | ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/super_resolution) 4 | 5 | This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al. 2016](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution. 
6 | 7 | ``` 8 | usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--crop_size CROPSIZE] [--batch_size BATCHSIZE] 9 | [--test_batch_size TESTBATCHSIZE] [--n_epochs NEPOCHS] [--lr LR] 10 | [--cuda] [--mps] [--threads THREADS] [--seed SEED] [--debug] 11 | 12 | PyTorch Super Res Example 13 | 14 | optional arguments: 15 | -h, --help show this help message and exit 16 | --upscale_factor super resolution upscale factor 17 | --crop_size cropped size of the images for training 18 | --batch_size training batch size 19 | --test_batch_size testing batch size 20 | --n_epochs number of epochs to train for 21 | --lr Learning Rate. Default=0.01 22 | --cuda use cuda 23 | --mps enable GPU on macOS 24 | --threads number of threads for data loader to use. Default=4 25 | --seed random seed to use. Default=123 26 | --debug debug mode for testing 27 | ``` 28 | 29 | This example trains a super-resolution network on the [Caltech101 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Caltech101.html).
A snapshot of the model is saved after every epoch with the filename `model_epoch_<epoch_number>.pth`. 30 | 31 | ## Example Usage: 32 | 33 | ### Train 34 | 35 | `python main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001` 36 | 37 | ### Super-Resolve 38 | 39 | `python super_resolve.py --input_image <input_image>.jpg --model model_epoch_500.pth --output_filename out.png` 40 | 41 | ### Super-resolve example on a Cifar-10 image 42 | 43 | #### Input Image 44 | ![Cifar input image](./images/input_cifar.png) 45 | 46 | #### Output Images 47 | | Output image from Model | Output from bicubic sampling | 48 | |-------------------------------|------------------------------------| 49 | | ![Cifar output image](./images/out_cifar.png) | ![Cifar output from bicubic sampling](./images/bicubic_image_cifar.png)| 50 | 51 | 52 | -------------------------------------------------------------------------------- /examples/super_resolution/images/bicubic_image_cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/super_resolution/images/bicubic_image_cifar.png -------------------------------------------------------------------------------- /examples/super_resolution/images/input_cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/super_resolution/images/input_cifar.png -------------------------------------------------------------------------------- /examples/super_resolution/images/out_cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/examples/super_resolution/images/out_cifar.png --------------------------------------------------------------------------------
/examples/super_resolution/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | 4 | 5 | class Net(nn.Module): 6 | def __init__(self, upscale_factor): 7 | super(Net, self).__init__() 8 | 9 | self.relu = nn.ReLU() 10 | self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) 11 | self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 12 | self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 13 | self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1)) 14 | self.pixel_shuffle = nn.PixelShuffle(upscale_factor) 15 | 16 | self._initialize_weights() 17 | 18 | def forward(self, x): 19 | x = self.relu(self.conv1(x)) 20 | x = self.relu(self.conv2(x)) 21 | x = self.relu(self.conv3(x)) 22 | x = self.pixel_shuffle(self.conv4(x)) 23 | return x 24 | 25 | def _initialize_weights(self): 26 | init.orthogonal_(self.conv1.weight, init.calculate_gain("relu")) 27 | init.orthogonal_(self.conv2.weight, init.calculate_gain("relu")) 28 | init.orthogonal_(self.conv3.weight, init.calculate_gain("relu")) 29 | init.orthogonal_(self.conv4.weight) 30 | -------------------------------------------------------------------------------- /examples/super_resolution/super_resolve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from torchvision.transforms.functional import to_tensor 7 | 8 | # Training settings 9 | parser = argparse.ArgumentParser(description="PyTorch Super Res Example") 10 | parser.add_argument("--input_image", type=str, required=True, help="input image to use") 11 | parser.add_argument("--model", type=str, required=True, help="model file to use") 12 | parser.add_argument("--output_filename", type=str, help="where to save the output image") 13 | parser.add_argument("--cuda", action="store_true", help="use cuda") 14 | opt = parser.parse_args() 15 | 
16 | print(opt) 17 | img = Image.open(opt.input_image).convert("YCbCr") 18 | y, cb, cr = img.split() 19 | 20 | model = torch.load(opt.model) 21 | input = to_tensor(y).view(1, -1, y.size[1], y.size[0]) 22 | 23 | if opt.cuda: 24 | model = model.cuda() 25 | input = input.cuda() 26 | 27 | model.eval() 28 | with torch.no_grad(): 29 | out = model(input) 30 | out = out.cpu() 31 | out_img_y = out[0].detach().numpy() 32 | out_img_y *= 255.0 33 | out_img_y = out_img_y.clip(0, 255) 34 | out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode="L") 35 | 36 | out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC) 37 | out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC) 38 | out_img = Image.merge("YCbCr", [out_img_y, out_img_cb, out_img_cr]).convert("RGB") 39 | 40 | out_img.save(opt.output_filename) 41 | print("output image saved to ", opt.output_filename) 42 | -------------------------------------------------------------------------------- /examples/transformers/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TransformerDataset(torch.utils.data.Dataset): 5 | def __init__(self, texts, labels, tokenizer, max_length): 6 | self.texts = texts 7 | self.labels = labels 8 | self.tokenizer = tokenizer 9 | self.max_length = max_length 10 | 11 | def __getitem__(self, idx): 12 | text = str(self.texts[idx]) 13 | text = " ".join(text.split()) 14 | inputs = self.tokenizer( 15 | text, 16 | None, 17 | add_special_tokens=True, 18 | max_length=self.max_length, 19 | truncation=True, 20 | padding="max_length", 21 | return_tensors="pt", 22 | ) 23 | inputs = {k: v.type(torch.long).squeeze(0) for k, v in inputs.items()} 24 | 25 | labels_pt = torch.tensor(self.labels[idx], dtype=torch.float) 26 | return inputs, labels_pt 27 | 28 | def __len__(self): 29 | return len(self.labels) 30 | -------------------------------------------------------------------------------- /examples/transformers/model.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import AutoConfig, AutoModelForSequenceClassification 3 | 4 | 5 | class TransformerModel(nn.Module): 6 | def __init__(self, model_name, model_dir, dropout, n_fc, n_classes): 7 | super(TransformerModel, self).__init__() 8 | self.config = AutoConfig.from_pretrained( 9 | model_name, 10 | num_labels=n_classes, 11 | output_hidden_states=n_fc, 12 | classifier_dropout=dropout, 13 | output_attentions=True, 14 | ) 15 | self.transformer = AutoModelForSequenceClassification.from_pretrained( 16 | model_name, cache_dir=model_dir, config=self.config 17 | ) 18 | 19 | def forward(self, inputs): 20 | output = self.transformer(**inputs)["logits"] 21 | 22 | return output 23 | -------------------------------------------------------------------------------- /examples/transformers/requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-ignite 2 | transformers 3 | datasets 4 | tqdm 5 | tensorboardX 6 | fire 7 | clearml -------------------------------------------------------------------------------- /examples/transformers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataset import TransformerDataset 3 | from datasets import load_dataset 4 | from model import TransformerModel 5 | from transformers import AutoTokenizer 6 | 7 | from ignite.handlers import DiskSaver 8 | 9 | 10 | def get_tokenizer(tokenizer_name, tokenizer_dir): 11 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=tokenizer_dir, do_lower_case=True) 12 | return tokenizer 13 | 14 | 15 | def get_model(model_name, model_dir, drop_out, n_fc, num_classes): 16 | model = TransformerModel(model_name, model_dir, drop_out, n_fc, num_classes) 17 | return model 18 | 19 | 20 | def get_dataset(cache_dir, tokenizer_name, tokenizer_dir, max_length): 21 | train_dataset, 
test_dataset = load_dataset("imdb", split=["train", "test"], cache_dir=cache_dir) 22 | tokenizer = get_tokenizer(tokenizer_name, tokenizer_dir) 23 | train_texts, train_labels = train_dataset["text"], train_dataset["label"] 24 | test_texts, test_labels = test_dataset["text"], test_dataset["label"] 25 | train_dataset = TransformerDataset(train_texts, train_labels, tokenizer, max_length) 26 | test_dataset = TransformerDataset(test_texts, test_labels, tokenizer, max_length) 27 | return train_dataset, test_dataset 28 | 29 | 30 | def thresholded_output_transform(output): 31 | y_pred, y = output 32 | return torch.round(torch.sigmoid(y_pred)), y 33 | 34 | 35 | def get_save_handler(config): 36 | if config["with_clearml"]: 37 | from ignite.handlers.clearml_logger import ClearMLSaver 38 | 39 | return ClearMLSaver(dirname=config["output_dir"]) 40 | 41 | return DiskSaver(config["output_dir"], require_empty=False) 42 | -------------------------------------------------------------------------------- /ignite/__init__.py: -------------------------------------------------------------------------------- 1 | import ignite.contrib 2 | import ignite.distributed 3 | import ignite.engine 4 | import ignite.exceptions 5 | import ignite.handlers 6 | import ignite.metrics 7 | import ignite.utils 8 | 9 | __version__ = "0.6.0" 10 | -------------------------------------------------------------------------------- /ignite/_utils.py: -------------------------------------------------------------------------------- 1 | # For compatibility 2 | from ignite.utils import apply_to_tensor, apply_to_type, convert_tensor, to_onehot 3 | 4 | __all__ = ["apply_to_tensor", "apply_to_type", "convert_tensor", "to_onehot"] 5 | -------------------------------------------------------------------------------- /ignite/base/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.base.mixins import Serializable 2 | 
-------------------------------------------------------------------------------- /ignite/base/mixins.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from collections.abc import Mapping 3 | from typing import Tuple 4 | 5 | 6 | class Serializable: 7 | _state_dict_all_req_keys: Tuple = () 8 | _state_dict_one_of_opt_keys: Tuple = () 9 | 10 | def state_dict(self) -> OrderedDict: 11 | raise NotImplementedError 12 | 13 | def load_state_dict(self, state_dict: Mapping) -> None: 14 | if not isinstance(state_dict, Mapping): 15 | raise TypeError(f"Argument state_dict should be a dictionary, but given {type(state_dict)}") 16 | 17 | for k in self._state_dict_all_req_keys: 18 | if k not in state_dict: 19 | raise ValueError( 20 | f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" 21 | ) 22 | opts = [k in state_dict for k in self._state_dict_one_of_opt_keys] 23 | if len(opts) > 0 and ((not any(opts)) or (all(opts))): 24 | raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys") 25 | -------------------------------------------------------------------------------- /ignite/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/ignite/contrib/__init__.py -------------------------------------------------------------------------------- /ignite/contrib/engines/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.contrib.engines.tbptt import create_supervised_tbptt_trainer, Tbptt_Events 2 | -------------------------------------------------------------------------------- /ignite/contrib/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.handlers import ( # ref # ref 2 | 
clearml_logger, 3 | EpochOutputStore, 4 | global_step_from_engine, 5 | mlflow_logger, 6 | neptune_logger, 7 | polyaxon_logger, 8 | tensorboard_logger, 9 | tqdm_logger, 10 | visdom_logger, 11 | wandb_logger, 12 | ) 13 | from ignite.handlers.clearml_logger import ClearMLLogger 14 | from ignite.handlers.lr_finder import FastaiLRFinder 15 | from ignite.handlers.mlflow_logger import MLflowLogger 16 | from ignite.handlers.neptune_logger import NeptuneLogger 17 | from ignite.handlers.param_scheduler import ( 18 | ConcatScheduler, 19 | CosineAnnealingScheduler, 20 | create_lr_scheduler_with_warmup, 21 | LinearCyclicalScheduler, 22 | LRScheduler, 23 | ParamGroupScheduler, 24 | PiecewiseLinear, 25 | ) 26 | from ignite.handlers.polyaxon_logger import PolyaxonLogger 27 | from ignite.handlers.tensorboard_logger import TensorboardLogger 28 | from ignite.handlers.time_profilers import BasicTimeProfiler, HandlersTimeProfiler 29 | from ignite.handlers.tqdm_logger import ProgressBar 30 | 31 | from ignite.handlers.visdom_logger import VisdomLogger 32 | from ignite.handlers.wandb_logger import WandBLogger 33 | -------------------------------------------------------------------------------- /ignite/contrib/handlers/base_logger.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.handlers.base_logger`` was moved to ``ignite.handlers.base_logger``. 2 | Note: 3 | ``ignite.contrib.handlers.base_logger`` was moved to ``ignite.handlers.base_logger``. 4 | Please refer to :mod:`~ignite.handlers.base_logger`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to /ignite/handlers/base_logger.py" 12 | + (f" and will be removed in version {removed_in}" if removed_in else "") 13 | + ".\n Please refer to the documentation for more details." 
14 | ) 15 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 16 | from ignite.handlers.base_logger import ( 17 | BaseHandler, 18 | BaseLogger, 19 | BaseOptimizerParamsHandler, 20 | BaseOutputHandler, 21 | BaseWeightsHandler, 22 | BaseWeightsScalarHandler, 23 | ) 24 | 25 | __all__ = [ 26 | "BaseHandler", 27 | "BaseWeightsHandler", 28 | "BaseOptimizerParamsHandler", 29 | "BaseOutputHandler", 30 | "BaseWeightsScalarHandler", 31 | "BaseLogger", 32 | ] 33 | BaseHandler = BaseHandler 34 | BaseWeightsHandler = BaseWeightsHandler 35 | BaseOptimizerParamsHandler = BaseOptimizerParamsHandler 36 | BaseOutputHandler = BaseOutputHandler 37 | BaseWeightsScalarHandler = BaseWeightsScalarHandler 38 | BaseLogger = BaseLogger 39 | -------------------------------------------------------------------------------- /ignite/contrib/handlers/clearml_logger.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.handlers.clearml_logger`` was moved to ``ignite.handlers.clearml_logger``. 2 | Note: 3 | ``ignite.contrib.handlers.clearml_logger`` was moved to ``ignite.handlers.clearml_logger``. 4 | Please refer to :mod:`~ignite.handlers.clearml_logger`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to /ignite/handlers/clearml_logger.py" 12 | + (f" and will be removed in version {removed_in}" if removed_in else "") 13 | + ".\n Please refer to the documentation for more details." 
14 | ) 15 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 16 | from ignite.handlers.clearml_logger import ( 17 | ClearMLLogger, 18 | ClearMLSaver, 19 | GradsHistHandler, 20 | GradsScalarHandler, 21 | OptimizerParamsHandler, 22 | OutputHandler, 23 | WeightsHistHandler, 24 | WeightsScalarHandler, 25 | ) 26 | from ignite.handlers.utils import global_step_from_engine # noqa 27 | 28 | __all__ = [ 29 | "ClearMLLogger", 30 | "ClearMLSaver", 31 | "OptimizerParamsHandler", 32 | "OutputHandler", 33 | "WeightsScalarHandler", 34 | "WeightsHistHandler", 35 | "GradsScalarHandler", 36 | "GradsHistHandler", 37 | ] 38 | ClearMLLogger = ClearMLLogger 39 | ClearMLSaver = ClearMLSaver 40 | OptimizerParamsHandler = OptimizerParamsHandler 41 | OutputHandler = OutputHandler 42 | WeightsScalarHandler = WeightsScalarHandler 43 | WeightsHistHandler = WeightsHistHandler 44 | GradsScalarHandler = GradsScalarHandler 45 | GradsHistHandler = GradsHistHandler 46 | -------------------------------------------------------------------------------- /ignite/contrib/handlers/lr_finder.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.handlers.lr_finder`` was moved to ``ignite.handlers.lr_finder``. 2 | Note: 3 | ``ignite.contrib.handlers.lr_finder`` was moved to ``ignite.handlers.lr_finder``. 4 | Please refer to :mod:`~ignite.handlers.lr_finder`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to /ignite/handlers/lr_finder.py" 12 | + (f" and will be removed in version {removed_in}" if removed_in else "") 13 | + ".\n Please refer to the documentation for more details." 
14 | ) 15 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 16 | from ignite.handlers.lr_finder import FastaiLRFinder 17 | 18 | __all__ = [ 19 | "FastaiLRFinder", 20 | ] 21 | 22 | FastaiLRFinder = FastaiLRFinder 23 | -------------------------------------------------------------------------------- /ignite/contrib/handlers/mlflow_logger.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.handlers.mlflow_logger`` was moved to ``ignite.handlers.mlflow_logger``. 2 | Note: 3 | ``ignite.contrib.handlers.mlflow_logger`` was moved to ``ignite.handlers.mlflow_logger``. 4 | Please refer to :mod:`~ignite.handlers.mlflow_logger`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to /ignite/handlers/mlflow_logger.py" 12 | + (f" and will be removed in version {removed_in}" if removed_in else "") 13 | + ".\n Please refer to the documentation for more details." 14 | ) 15 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 16 | from ignite.handlers.mlflow_logger import MLflowLogger, OptimizerParamsHandler, OutputHandler 17 | from ignite.handlers.utils import global_step_from_engine # noqa 18 | 19 | __all__ = ["MLflowLogger", "OutputHandler", "OptimizerParamsHandler"] 20 | MLflowLogger = MLflowLogger 21 | OutputHandler = OutputHandler 22 | OptimizerParamsHandler = OptimizerParamsHandler 23 | -------------------------------------------------------------------------------- /ignite/contrib/handlers/neptune_logger.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.handlers.neptune_logger`` was moved to ``ignite.handlers.neptune_logger``. 2 | Note: 3 | ``ignite.contrib.handlers.neptune_logger`` was moved to ``ignite.handlers.neptune_logger``. 4 | Please refer to :mod:`~ignite.handlers.neptune_logger`. 
5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to /ignite/handlers/neptune_logger.py" 12 | + (f" and will be removed in version {removed_in}" if removed_in else "") 13 | + ".\n Please refer to the documentation for more details." 14 | ) 15 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 16 | from ignite.handlers.neptune_logger import ( 17 | _INTEGRATION_VERSION_KEY, 18 | GradsScalarHandler, 19 | NeptuneLogger, 20 | NeptuneSaver, 21 | OptimizerParamsHandler, 22 | OutputHandler, 23 | WeightsScalarHandler, 24 | ) 25 | from ignite.handlers.utils import global_step_from_engine # noqa 26 | 27 | __all__ = [ 28 | "NeptuneLogger", 29 | "NeptuneSaver", 30 | "OptimizerParamsHandler", 31 | "OutputHandler", 32 | "WeightsScalarHandler", 33 | "GradsScalarHandler", 34 | ] 35 | NeptuneLogger = NeptuneLogger 36 | NeptuneSaver = NeptuneSaver 37 | OptimizerParamsHandler = OptimizerParamsHandler 38 | OutputHandler = OutputHandler 39 | WeightsScalarHandler = WeightsScalarHandler 40 | GradsScalarHandler = GradsScalarHandler 41 | _INTEGRATION_VERSION_KEY = _INTEGRATION_VERSION_KEY 42 | -------------------------------------------------------------------------------- /ignite/contrib/handlers/param_scheduler.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.handlers.param_scheduler`` was moved to ``ignite.handlers.param_scheduler``. 2 | Note: 3 | ``ignite.contrib.handlers.param_scheduler`` was moved to ``ignite.handlers.param_scheduler``. 4 | Please refer to :mod:`~ignite.handlers.param_scheduler`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to /ignite/handlers/param_scheduler.py" 12 | + (f" and will be removed in version {removed_in}" if removed_in else "") 13 | + ".\n Please refer to the documentation for more details." 
)
# stacklevel=2 attributes the warning to the code importing this legacy module.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
from ignite.handlers.param_scheduler import (
    ConcatScheduler,
    CosineAnnealingScheduler,
    create_lr_scheduler_with_warmup,
    CyclicalScheduler,
    LinearCyclicalScheduler,
    LRScheduler,
    ParamGroupScheduler,
    ParamScheduler,
    PiecewiseLinear,
)

__all__ = [
    "ConcatScheduler",
    "CosineAnnealingScheduler",
    "LinearCyclicalScheduler",
    "LRScheduler",
    "ParamGroupScheduler",
    "ParamScheduler",
    "PiecewiseLinear",
    "CyclicalScheduler",
    "create_lr_scheduler_with_warmup",
]

# No-op self-assignments marking the re-exported names (presumably for linters).
ConcatScheduler = ConcatScheduler
CosineAnnealingScheduler = CosineAnnealingScheduler
LinearCyclicalScheduler = LinearCyclicalScheduler
LRScheduler = LRScheduler
ParamGroupScheduler = ParamGroupScheduler
ParamScheduler = ParamScheduler
PiecewiseLinear = PiecewiseLinear
CyclicalScheduler = CyclicalScheduler
create_lr_scheduler_with_warmup = create_lr_scheduler_with_warmup
# ---------- /ignite/contrib/handlers/polyaxon_logger.py ----------
""" ``ignite.contrib.handlers.polyaxon_logger`` was moved to ``ignite.handlers.polyaxon_logger``.
Note:
    ``ignite.contrib.handlers.polyaxon_logger`` was moved to ``ignite.handlers.polyaxon_logger``.
    Please refer to :mod:`~ignite.handlers.polyaxon_logger`.
"""

import warnings

removed_in = "0.6.0"
# A falsy ``removed_in`` drops the "will be removed" clause from the message.
deprecation_warning = (
    f"{__file__} has been moved to /ignite/handlers/polyaxon_logger.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
)
# stacklevel=2 attributes the warning to the code importing this legacy module.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
from ignite.handlers.polyaxon_logger import OptimizerParamsHandler, OutputHandler, PolyaxonLogger
# Not listed in ``__all__``, but still importable from this legacy path.
from ignite.handlers.utils import global_step_from_engine  # noqa

__all__ = ["PolyaxonLogger", "OutputHandler", "OptimizerParamsHandler"]
# No-op self-assignments marking the re-exported names (presumably for linters).
PolyaxonLogger = PolyaxonLogger
OutputHandler = OutputHandler
OptimizerParamsHandler = OptimizerParamsHandler
# ---------- /ignite/contrib/handlers/tensorboard_logger.py ----------
""" ``ignite.contrib.handlers.tensorboard_logger`` was moved to ``ignite.handlers.tensorboard_logger``.
Note:
    ``ignite.contrib.handlers.tensorboard_logger`` was moved to ``ignite.handlers.tensorboard_logger``.
    Please refer to :mod:`~ignite.handlers.tensorboard_logger`.
"""

import warnings

removed_in = "0.6.0"
# A falsy ``removed_in`` drops the "will be removed" clause from the message.
deprecation_warning = (
    f"{__file__} has been moved to /ignite/handlers/tensorboard_logger.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
)
# stacklevel=2 attributes the warning to the code importing this legacy module.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
from ignite.handlers.tensorboard_logger import (
    GradsHistHandler,
    GradsScalarHandler,
    OptimizerParamsHandler,
    OutputHandler,
    TensorboardLogger,
    WeightsHistHandler,
    WeightsScalarHandler,
)
# Not listed in ``__all__``, but still importable from this legacy path.
from ignite.handlers.utils import global_step_from_engine  # noqa


__all__ = [
    "TensorboardLogger",
    "OptimizerParamsHandler",
    "OutputHandler",
    "WeightsScalarHandler",
    "WeightsHistHandler",
    "GradsScalarHandler",
    "GradsHistHandler",
]
# No-op self-assignments marking the re-exported names (presumably for linters).
TensorboardLogger = TensorboardLogger
OptimizerParamsHandler = OptimizerParamsHandler
OutputHandler = OutputHandler
WeightsScalarHandler = WeightsScalarHandler
WeightsHistHandler = WeightsHistHandler
GradsScalarHandler = GradsScalarHandler
GradsHistHandler = GradsHistHandler
# ---------- /ignite/contrib/handlers/time_profilers.py ----------
""" ``ignite.contrib.handlers.time_profilers`` was moved to ``ignite.handlers.time_profilers``.
Note:
    ``ignite.contrib.handlers.time_profilers`` was moved to ``ignite.handlers.time_profilers``.
    Please refer to :mod:`~ignite.handlers.time_profilers`.
"""

import warnings

removed_in = "0.6.0"
# A falsy ``removed_in`` drops the "will be removed" clause from the message.
deprecation_warning = (
    f"{__file__} has been moved to /ignite/handlers/time_profilers.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
)
# stacklevel=2 attributes the warning to the code importing this legacy module.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
from ignite.handlers.time_profilers import BasicTimeProfiler, HandlersTimeProfiler

__all__ = [
    "BasicTimeProfiler",
    "HandlersTimeProfiler",
]

# No-op self-assignments marking the re-exported names (presumably for linters).
BasicTimeProfiler = BasicTimeProfiler
HandlersTimeProfiler = HandlersTimeProfiler
# ---------- /ignite/contrib/handlers/tqdm_logger.py ----------
""" ``ignite.contrib.handlers.tqdm_logger`` was moved to ``ignite.handlers.tqdm_logger``.
Note:
    ``ignite.contrib.handlers.tqdm_logger`` was moved to ``ignite.handlers.tqdm_logger``.
    Please refer to :mod:`~ignite.handlers.tqdm_logger`.
"""

import warnings

removed_in = "0.6.0"
# A falsy ``removed_in`` drops the "will be removed" clause from the message.
deprecation_warning = (
    f"{__file__} has been moved to /ignite/handlers/tqdm_logger.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
)
# stacklevel=2 attributes the warning to the code importing this legacy module.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
from ignite.handlers.tqdm_logger import ProgressBar

__all__ = [
    "ProgressBar",
]
# No-op self-assignment marking the re-exported name (presumably for linters).
ProgressBar = ProgressBar
# ---------- /ignite/contrib/handlers/visdom_logger.py ----------
""" ``ignite.contrib.handlers.visdom_logger`` was moved to ``ignite.handlers.visdom_logger``.
Note:
    ``ignite.contrib.handlers.visdom_logger`` was moved to ``ignite.handlers.visdom_logger``.
    Please refer to :mod:`~ignite.handlers.visdom_logger`.
"""

import warnings

removed_in = "0.6.0"
# A falsy ``removed_in`` drops the "will be removed" clause from the message.
deprecation_warning = (
    f"{__file__} has been moved to /ignite/handlers/visdom_logger.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
)
# stacklevel=2 attributes the warning to the code importing this legacy module.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
# Not listed in ``__all__``, but still importable from this legacy path.
from ignite.handlers.utils import global_step_from_engine  # noqa
from ignite.handlers.visdom_logger import (
    _DummyExecutor,
    GradsScalarHandler,
    OptimizerParamsHandler,
    OutputHandler,
    VisdomLogger,
    WeightsScalarHandler,
)

__all__ = [
    "VisdomLogger",
    "OptimizerParamsHandler",
    "OutputHandler",
    "WeightsScalarHandler",
    "GradsScalarHandler",
]
# No-op self-assignments marking the re-exported names (presumably for linters).
VisdomLogger = VisdomLogger
OptimizerParamsHandler = OptimizerParamsHandler
OutputHandler = OutputHandler
WeightsScalarHandler = WeightsScalarHandler
GradsScalarHandler = GradsScalarHandler
# Private name, not in ``__all__``, but still importable from this legacy path.
_DummyExecutor = _DummyExecutor
# ---------- /ignite/contrib/handlers/wandb_logger.py ----------
""" ``ignite.contrib.handlers.wandb_logger`` was moved to ``ignite.handlers.wandb_logger``.
Note:
    ``ignite.contrib.handlers.wandb_logger`` was moved to ``ignite.handlers.wandb_logger``.
    Please refer to :mod:`~ignite.handlers.wandb_logger`.
"""

import warnings

removed_in = "0.6.0"
# A falsy ``removed_in`` drops the "will be removed" clause from the message.
deprecation_warning = (
    f"{__file__} has been moved to /ignite/handlers/wandb_logger.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
)
# stacklevel=2 attributes the warning to the code importing this legacy module.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
# Not listed in ``__all__``, but still importable from this legacy path.
from ignite.handlers.utils import global_step_from_engine  # noqa
from ignite.handlers.wandb_logger import OptimizerParamsHandler, OutputHandler, WandBLogger

__all__ = ["WandBLogger", "OutputHandler", "OptimizerParamsHandler"]
# No-op self-assignments marking the re-exported names (presumably for linters).
WandBLogger = WandBLogger
OutputHandler = OutputHandler
OptimizerParamsHandler = OptimizerParamsHandler
# ---------- /ignite/contrib/metrics/__init__.py ----------
# Legacy package namespace. Every name below is imported straight from
# ``ignite.metrics`` rather than from the sibling deprecated shim modules.
import ignite.metrics.regression
from ignite.metrics import average_precision, cohen_kappa, gpu_info, precision_recall_curve, roc_auc
from ignite.metrics.average_precision import AveragePrecision
from ignite.metrics.cohen_kappa import CohenKappa
from ignite.metrics.gpu_info import GpuInfo
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.metrics.roc_auc import ROC_AUC, RocCurve
# ---------- /ignite/contrib/metrics/average_precision.py ----------
""" ``ignite.contrib.metrics.average_precision`` was moved to ``ignite.metrics.average_precision``.
Note:
    ``ignite.contrib.metrics.average_precision`` was moved to ``ignite.metrics.average_precision``.
    Please refer to :mod:`~ignite.metrics.average_precision`.
"""

import warnings

removed_in = "0.6.0"
# A falsy ``removed_in`` drops the "will be removed" clause from the message.
deprecation_warning = (
    f"{__file__} has been moved to /ignite/metrics/average_precision.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
)
# stacklevel=2 attributes the warning to the code importing this legacy module.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
from ignite.metrics.average_precision import AveragePrecision

__all__ = [
    "AveragePrecision",
]

# No-op self-assignment marking the re-exported name (presumably for linters).
AveragePrecision = AveragePrecision
# ---------- /ignite/contrib/metrics/cohen_kappa.py ----------
""" ``ignite.contrib.metrics.cohen_kappa`` was moved to ``ignite.metrics.cohen_kappa``.
Note:
    ``ignite.contrib.metrics.cohen_kappa`` was moved to ``ignite.metrics.cohen_kappa``.
    Please refer to :mod:`~ignite.metrics.cohen_kappa`.
"""

import warnings

removed_in = "0.6.0"
# A falsy ``removed_in`` drops the "will be removed" clause from the message.
deprecation_warning = (
    f"{__file__} has been moved to ignite/metrics/cohen_kappa.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
)
# stacklevel=2 attributes the warning to the code importing this legacy module.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
from ignite.metrics.cohen_kappa import CohenKappa

__all__ = [
    "CohenKappa",
]

# No-op self-assignment marking the re-exported name (presumably for linters).
CohenKappa = CohenKappa
# ---------- /ignite/contrib/metrics/gpu_info.py ----------
""" ``ignite.contrib.metrics.gpu_info`` was moved to ``ignite.metrics.gpu_info``.
Note:
    ``ignite.contrib.metrics.gpu_info`` was moved to ``ignite.metrics.gpu_info``.
    Please refer to :mod:`~ignite.metrics.gpu_info`.
"""

import warnings

removed_in = "0.6.0"
# A falsy ``removed_in`` drops the "will be removed" clause from the message.
deprecation_warning = (
    f"{__file__} has been moved to ignite/metrics/gpu_info.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
)
# stacklevel=2 attributes the warning to the code importing this legacy module.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
from ignite.metrics.gpu_info import GpuInfo

__all__ = [
    "GpuInfo",
]

# No-op self-assignment marking the re-exported name (presumably for linters).
GpuInfo = GpuInfo
# ---------- /ignite/contrib/metrics/precision_recall_curve.py ----------
""" ``ignite.contrib.metrics.precision_recall_curve`` was moved to ``ignite.metrics.precision_recall_curve``.
Note:
    ``ignite.contrib.metrics.precision_recall_curve`` was moved to ``ignite.metrics.precision_recall_curve``.
    Please refer to :mod:`~ignite.metrics.precision_recall_curve`.
"""

import warnings

removed_in = "0.6.0"
# A falsy ``removed_in`` drops the "will be removed" clause from the message.
deprecation_warning = (
    f"{__file__} has been moved to ignite/metrics/precision_recall_curve.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
14 | ) 15 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 16 | from ignite.metrics.precision_recall_curve import precision_recall_curve_compute_fn, PrecisionRecallCurve 17 | 18 | __all__ = [ 19 | "PrecisionRecallCurve", 20 | "precision_recall_curve_compute_fn", 21 | ] 22 | 23 | 24 | PrecisionRecallCurve = PrecisionRecallCurve 25 | precision_recall_curve_compute_fn = precision_recall_curve_compute_fn 26 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics.regression import ( 2 | canberra_metric, 3 | fractional_absolute_error, 4 | fractional_bias, 5 | geometric_mean_absolute_error, 6 | geometric_mean_relative_absolute_error, 7 | manhattan_distance, 8 | maximum_absolute_error, 9 | mean_absolute_relative_error, 10 | mean_error, 11 | mean_normalized_bias, 12 | median_absolute_error, 13 | median_absolute_percentage_error, 14 | median_relative_absolute_error, 15 | r2_score, 16 | wave_hedges_distance, 17 | ) 18 | from ignite.metrics.regression.canberra_metric import CanberraMetric 19 | from ignite.metrics.regression.fractional_absolute_error import FractionalAbsoluteError 20 | from ignite.metrics.regression.fractional_bias import FractionalBias 21 | from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError 22 | from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError 23 | from ignite.metrics.regression.manhattan_distance import ManhattanDistance 24 | from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError 25 | from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError 26 | from ignite.metrics.regression.mean_error import MeanError 27 | from ignite.metrics.regression.mean_normalized_bias import MeanNormalizedBias 28 | from 
ignite.metrics.regression.median_absolute_error import MedianAbsoluteError 29 | from ignite.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError 30 | from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError 31 | from ignite.metrics.regression.r2_score import R2Score 32 | from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance 33 | 34 | 35 | __all__ = [ 36 | "CanberraMetric", 37 | "FractionalAbsoluteError", 38 | "FractionalBias", 39 | "GeometricMeanAbsoluteError", 40 | "GeometricMeanRelativeAbsoluteError", 41 | "ManhattanDistance", 42 | "MaximumAbsoluteError", 43 | "MeanAbsoluteRelativeError", 44 | "MeanError", 45 | "MeanNormalizedBias", 46 | "MedianAbsoluteError", 47 | "MedianAbsolutePercentageError", 48 | "MedianRelativeAbsoluteError", 49 | "R2Score", 50 | "WaveHedgesDistance", 51 | "canberra_metric", 52 | "fractional_absolute_error", 53 | "fractional_bias", 54 | "geometric_mean_absolute_error", 55 | "geometric_mean_relative_absolute_error", 56 | "manhattan_distance", 57 | "maximum_absolute_error", 58 | "mean_absolute_relative_error", 59 | "mean_error", 60 | "mean_normalized_bias", 61 | "median_absolute_error", 62 | "median_absolute_percentage_error", 63 | "median_relative_absolute_error", 64 | "r2_score", 65 | "wave_hedges_distance", 66 | ] 67 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/_base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Tuple 3 | 4 | import torch 5 | 6 | from ignite.metrics import Metric 7 | from ignite.metrics.metric import reinit__is_reduced 8 | 9 | 10 | def _check_output_shapes(output: Tuple[torch.Tensor, torch.Tensor]) -> None: 11 | y_pred, y = output 12 | c1 = y_pred.ndimension() == 2 and y_pred.shape[1] == 1 13 | if not (y_pred.ndimension() == 1 or c1): 14 | raise 
ValueError(f"Input y_pred should have shape (N,) or (N, 1), but given {y_pred.shape}") 15 | 16 | c2 = y.ndimension() == 2 and y.shape[1] == 1 17 | if not (y.ndimension() == 1 or c2): 18 | raise ValueError(f"Input y should have shape (N,) or (N, 1), but given {y.shape}") 19 | 20 | if y_pred.shape != y.shape: 21 | raise ValueError(f"Input data shapes should be the same, but given {y_pred.shape} and {y.shape}") 22 | 23 | 24 | def _check_output_types(output: Tuple[torch.Tensor, torch.Tensor]) -> None: 25 | y_pred, y = output 26 | if y_pred.dtype not in (torch.float16, torch.float32, torch.float64): 27 | raise TypeError(f"Input y_pred dtype should be float 16, 32 or 64, but given {y_pred.dtype}") 28 | 29 | if y.dtype not in (torch.float16, torch.float32, torch.float64): 30 | raise TypeError(f"Input y dtype should be float 16, 32 or 64, but given {y.dtype}") 31 | 32 | 33 | def _torch_median(output: torch.Tensor) -> float: 34 | output = output.view(-1) 35 | len_ = len(output) 36 | 37 | if len_ % 2 == 0: 38 | return float((torch.kthvalue(output, len_ // 2)[0] + torch.kthvalue(output, len_ // 2 + 1)[0]) / 2) 39 | else: 40 | return float(torch.kthvalue(output, len_ // 2 + 1)[0]) 41 | 42 | 43 | class _BaseRegression(Metric): 44 | # Base class for all regression metrics 45 | # `update` method check the shapes and call internal overloaded 46 | # method `_update`. 
47 | 48 | @reinit__is_reduced 49 | def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: 50 | _check_output_shapes(output) 51 | _check_output_types(output) 52 | y_pred, y = output[0].detach(), output[1].detach() 53 | 54 | if y_pred.ndimension() == 2 and y_pred.shape[1] == 1: 55 | y_pred = y_pred.squeeze(dim=-1) 56 | 57 | if y.ndimension() == 2 and y.shape[1] == 1: 58 | y = y.squeeze(dim=-1) 59 | 60 | self._update((y_pred, y)) 61 | 62 | @abstractmethod 63 | def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: 64 | pass 65 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/canberra_metric.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.canberra_metric`` was moved to ``ignite.metrics.regression.canberra_metric``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.canberra_metric`` was moved to ``ignite.metrics.regression.canberra_metric``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.canberra_metric`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/canberra_metric.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 
15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.canberra_metric import CanberraMetric 18 | 19 | __all__ = ["CanberraMetric"] 20 | 21 | CanberraMetric = CanberraMetric 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/fractional_absolute_error.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.fractional_absolute_error`` was moved to ``ignite.metrics.regression.fractional_absolute_error``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.fractional_absolute_error`` was moved to ``ignite.metrics.regression.fractional_absolute_error``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.fractional_absolute_error`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/fractional_absolute_error.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.fractional_absolute_error import FractionalAbsoluteError 18 | 19 | __all__ = ["FractionalAbsoluteError"] 20 | 21 | FractionalAbsoluteError = FractionalAbsoluteError 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/fractional_bias.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.fractional_bias`` was moved to ``ignite.metrics.regression.fractional_bias``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.fractional_bias`` was moved to ``ignite.metrics.regression.fractional_bias``. 
# noqa 4 | Please refer to :mod:`~ignite.metrics.regression.fractional_bias`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/fractional_bias.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.fractional_bias import FractionalBias 18 | 19 | __all__ = ["FractionalBias"] 20 | 21 | FractionalBias = FractionalBias 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/geometric_mean_absolute_error.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.geometric_mean_absolute_error`` was moved to ``ignite.metrics.regression.geometric_mean_absolute_error``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.geometric_mean_absolute_error`` was moved to ``ignite.metrics.regression.geometric_mean_absolute_error``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.geometric_mean_absolute_error`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_absolute_error.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 
15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError 18 | 19 | __all__ = ["GeometricMeanAbsoluteError"] 20 | 21 | GeometricMeanAbsoluteError = GeometricMeanAbsoluteError 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.geometric_mean_relative_absolute_error`` was moved to ``ignite.metrics.regression.geometric_mean_relative_absolute_error``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.geometric_mean_relative_absolute_error`` was moved to ``ignite.metrics.regression.geometric_mean_relative_absolute_error``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.geometric_mean_relative_absolute_error`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_relative_absolute_error.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 
15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError 18 | 19 | __all__ = ["GeometricMeanRelativeAbsoluteError"] 20 | 21 | GeometricMeanRelativeAbsoluteError = GeometricMeanRelativeAbsoluteError 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/manhattan_distance.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.manhattan_distance`` was moved to ``ignite.metrics.regression.manhattan_distance``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.manhattan_distance`` was moved to ``ignite.metrics.regression.manhattan_distance``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.manhattan_distance`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/manhattan_distance.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.manhattan_distance import ManhattanDistance 18 | 19 | __all__ = ["ManhattanDistance"] 20 | 21 | ManhattanDistance = ManhattanDistance 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/maximum_absolute_error.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.maximum_absolute_error`` was moved to ``ignite.metrics.regression.maximum_absolute_error``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.maximum_absolute_error`` was moved to ``ignite.metrics.regression.maximum_absolute_error``. 
# noqa 4 | Please refer to :mod:`~ignite.metrics.regression.maximum_absolute_error`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/maximum_absolute_error.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError 18 | 19 | __all__ = ["MaximumAbsoluteError"] 20 | 21 | MaximumAbsoluteError = MaximumAbsoluteError 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/mean_absolute_relative_error.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.mean_absolute_relative_error`` was moved to ``ignite.metrics.regression.mean_absolute_relative_error``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.mean_absolute_relative_error`` was moved to ``ignite.metrics.regression.mean_absolute_relative_error``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.mean_absolute_relative_error`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/mean_absolute_relative_error.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 
15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError 18 | 19 | __all__ = ["MeanAbsoluteRelativeError"] 20 | 21 | MeanAbsoluteRelativeError = MeanAbsoluteRelativeError 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/mean_error.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.mean_error`` was moved to ``ignite.metrics.regression.mean_error``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.mean_error`` was moved to ``ignite.metrics.regression.mean_error``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.mean_error`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/mean_error.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.mean_error import MeanError 18 | 19 | __all__ = ["MeanError"] 20 | 21 | MeanError = MeanError 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/mean_normalized_bias.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.mean_normalized_bias`` was moved to ``ignite.metrics.regression.mean_normalized_bias``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.mean_normalized_bias`` was moved to ``ignite.metrics.regression.mean_normalized_bias``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.mean_normalized_bias`. 
5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/mean_normalized_bias.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.mean_normalized_bias import MeanNormalizedBias 18 | 19 | __all__ = ["MeanNormalizedBias"] 20 | 21 | MeanNormalizedBias = MeanNormalizedBias 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/median_absolute_error.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.median_absolute_error`` was moved to ``ignite.metrics.regression.median_absolute_error``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.median_absolute_error`` was moved to ``ignite.metrics.regression.median_absolute_error``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.median_absolute_error`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/median_absolute_error.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 
15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.median_absolute_error import MedianAbsoluteError 18 | 19 | __all__ = ["MedianAbsoluteError"] 20 | 21 | MedianAbsoluteError = MedianAbsoluteError 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/median_absolute_percentage_error.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.median_absolute_percentage_error`` was moved to ``ignite.metrics.regression.median_absolute_percentage_error``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.median_absolute_percentage_error`` was moved to ``ignite.metrics.regression.median_absolute_percentage_error``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.median_absolute_percentage_error`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/median_absolute_percentage_error.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError 18 | 19 | __all__ = ["MedianAbsolutePercentageError"] 20 | 21 | MedianAbsolutePercentageError = MedianAbsolutePercentageError 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/median_relative_absolute_error.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.median_relative_absolute_error`` was moved to ``ignite.metrics.regression.median_relative_absolute_error``. 
# noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.median_relative_absolute_error`` was moved to ``ignite.metrics.regression.median_relative_absolute_error``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.median_relative_absolute_error`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/median_relative_absolute_error.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError 18 | 19 | __all__ = ["MedianRelativeAbsoluteError"] 20 | 21 | MedianRelativeAbsoluteError = MedianRelativeAbsoluteError 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/r2_score.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.r2_score`` was moved to ``ignite.metrics.regression.r2_score``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.r2_score`` was moved to ``ignite.metrics.regression.r2_score``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.r2_score`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/r2_score.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 
15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.r2_score import R2Score 18 | 19 | __all__ = ["R2Score"] 20 | 21 | R2Score = R2Score 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/wave_hedges_distance.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.regression.wave_hedges_distance`` was moved to ``ignite.metrics.regression.wave_hedges_distance``. # noqa 2 | Note: 3 | ``ignite.contrib.metrics.regression.wave_hedges_distance`` was moved to ``ignite.metrics.regression.wave_hedges_distance``. # noqa 4 | Please refer to :mod:`~ignite.metrics.regression.wave_hedges_distance`. 5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/regression/wave_hedges_distance.py" 12 | f" and will be removed in version {removed_in}" 13 | if removed_in 14 | else "" ".\n Please refer to the documentation for more details." 15 | ) 16 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 17 | from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance 18 | 19 | __all__ = ["WaveHedgesDistance"] 20 | 21 | WaveHedgesDistance = WaveHedgesDistance 22 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/roc_auc.py: -------------------------------------------------------------------------------- 1 | """ ``ignite.contrib.metrics.roc_auc`` was moved to ``ignite.metrics.roc_auc``. 2 | Note: 3 | ``ignite.contrib.metrics.roc_auc`` was moved to ``ignite.metrics.roc_auc``. 4 | Please refer to :mod:`~ignite.metrics.roc_auc`. 
5 | """ 6 | 7 | import warnings 8 | 9 | removed_in = "0.6.0" 10 | deprecation_warning = ( 11 | f"{__file__} has been moved to ignite/metrics/roc_auc.py" 12 | + (f" and will be removed in version {removed_in}" if removed_in else "") 13 | + ".\n Please refer to the documentation for more details." 14 | ) 15 | warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) 16 | from ignite.metrics.roc_auc import ROC_AUC, RocCurve 17 | 18 | __all__ = ["RocCurve", "ROC_AUC"] 19 | 20 | RocCurve = RocCurve 21 | ROC_AUC = ROC_AUC 22 | -------------------------------------------------------------------------------- /ignite/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.distributed.auto import * 2 | from ignite.distributed.comp_models import native, xla 3 | from ignite.distributed.launcher import Parallel 4 | from ignite.distributed.utils import * 5 | -------------------------------------------------------------------------------- /ignite/distributed/comp_models/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Type, TYPE_CHECKING, Union 2 | 3 | from ignite.distributed.comp_models.base import _SerialModel 4 | from ignite.distributed.comp_models.horovod import has_hvd_support 5 | from ignite.distributed.comp_models.native import has_native_dist_support 6 | from ignite.distributed.comp_models.xla import has_xla_support 7 | 8 | if TYPE_CHECKING: 9 | from ignite.distributed.comp_models.horovod import _HorovodDistModel 10 | from ignite.distributed.comp_models.native import _NativeDistModel 11 | from ignite.distributed.comp_models.xla import _XlaDistModel 12 | 13 | 14 | def setup_available_computation_models() -> ( 15 | Tuple[Type[Union[_SerialModel, "_NativeDistModel", "_XlaDistModel", "_HorovodDistModel"]], ...] 
16 | ): 17 | models: List[Type[Union[_SerialModel, "_NativeDistModel", "_XlaDistModel", "_HorovodDistModel"]]] = [ 18 | _SerialModel, 19 | ] 20 | if has_native_dist_support: 21 | from ignite.distributed.comp_models.native import _NativeDistModel 22 | 23 | models.append(_NativeDistModel) 24 | if has_xla_support: 25 | from ignite.distributed.comp_models.xla import _XlaDistModel 26 | 27 | models.append(_XlaDistModel) 28 | if has_hvd_support: 29 | from ignite.distributed.comp_models.horovod import _HorovodDistModel 30 | 31 | models.append(_HorovodDistModel) 32 | 33 | return tuple(models) 34 | 35 | 36 | registered_computation_models = setup_available_computation_models() 37 | -------------------------------------------------------------------------------- /ignite/engine/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable, Tuple, Union 3 | 4 | 5 | def _check_signature(fn: Callable, fn_description: str, *args: Any, **kwargs: Any) -> None: 6 | # if handler with filter, check the handler rather than the decorator 7 | if hasattr(fn, "_parent"): 8 | signature = inspect.signature(fn._parent()) 9 | else: 10 | signature = inspect.signature(fn) 11 | try: # try without engine 12 | signature.bind(*args, **kwargs) 13 | except TypeError as exc: 14 | fn_params = list(signature.parameters) 15 | exception_msg = str(exc) 16 | passed_params = list(args) + list(kwargs) 17 | raise ValueError( 18 | f"Error adding {fn} '{fn_description}': " 19 | f"takes parameters {fn_params} but will be called with {passed_params}" 20 | f"({exception_msg})." 
21 | ) 22 | 23 | 24 | def _to_hours_mins_secs(time_taken: Union[float, int]) -> Tuple[int, int, float]: 25 | """Convert seconds to hours, mins, seconds and milliseconds.""" 26 | mins, secs = divmod(time_taken, 60) 27 | hours, mins = divmod(mins, 60) 28 | return round(hours), round(mins), secs 29 | -------------------------------------------------------------------------------- /ignite/exceptions.py: -------------------------------------------------------------------------------- 1 | __all__ = ["NotComputableError"] 2 | 3 | 4 | class NotComputableError(RuntimeError): 5 | """ 6 | Exception class to raise if Metric cannot be computed. 7 | """ 8 | -------------------------------------------------------------------------------- /ignite/handlers/stores.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional 2 | 3 | from ignite.engine import Engine, Events 4 | 5 | 6 | class EpochOutputStore: 7 | """EpochOutputStore handler to save output prediction and target history 8 | after every epoch, could be useful for e.g., visualization purposes. 9 | 10 | Note: 11 | This can potentially lead to a memory error if the output data is 12 | larger than available RAM. 13 | 14 | Args: 15 | output_transform: a callable that is used to 16 | transform the :class:`~ignite.engine.engine.Engine`'s 17 | ``process_function``'s output , e.g., lambda x: x[0] 18 | 19 | Attributes: 20 | data: a list of :class:`~ignite.engine.engine.Engine` outputs, 21 | optionally transformed by `output_transform`. 22 | 23 | Examples: 24 | .. 
code-block:: python 25 | 26 | eos = EpochOutputStore() 27 | trainer = create_supervised_trainer(model, optimizer, loss) 28 | train_evaluator = create_supervised_evaluator(model, metrics) 29 | eos.attach(train_evaluator, 'output') 30 | 31 | @trainer.on(Events.EPOCH_COMPLETED) 32 | def log_training_results(engine): 33 | train_evaluator.run(train_loader) 34 | output = train_evaluator.state.output 35 | # output = [(y_pred0, y0), (y_pred1, y1), ...] 36 | # do something with output, e.g., plotting 37 | 38 | .. versionadded:: 0.4.5 39 | .. versionchanged:: 0.4.5 40 | `attach` now accepts an optional argument `name` 41 | """ 42 | 43 | def __init__(self, output_transform: Callable = lambda x: x): 44 | self.data: List[Any] = [] 45 | self.output_transform = output_transform 46 | 47 | def reset(self) -> None: 48 | """Reset the attribute data to an empty list.""" 49 | self.data = [] 50 | 51 | def update(self, engine: Engine) -> None: 52 | """Append the transformed output of the engine to the attribute data.""" 53 | output = self.output_transform(engine.state.output) 54 | self.data.append(output) 55 | 56 | def store(self, engine: Engine) -> None: 57 | """Store `self.data` on `engine.state.{self.name}`""" 58 | setattr(engine.state, self.name, self.data) 59 | 60 | def attach(self, engine: Engine, name: Optional[str] = None) -> None: 61 | """Attach `reset` at EPOCH_STARTED and 62 | `update` at ITERATION_COMPLETED. 63 | 64 | If `name` is passed, `store` is also attached at EPOCH_COMPLETED and saves `self.data` on `engine.state` 65 | under `name`.
66 | """ 67 | engine.add_event_handler(Events.EPOCH_STARTED, self.reset) 68 | engine.add_event_handler(Events.ITERATION_COMPLETED, self.update) 69 | if name: 70 | self.name = name 71 | engine.add_event_handler(Events.EPOCH_COMPLETED, self.store) 72 | -------------------------------------------------------------------------------- /ignite/handlers/terminate_on_nan.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numbers 3 | from typing import Callable, Union 4 | 5 | import torch 6 | 7 | from ignite.engine import Engine 8 | from ignite.utils import apply_to_type, setup_logger 9 | 10 | __all__ = ["TerminateOnNan"] 11 | 12 | 13 | class TerminateOnNan: 14 | """TerminateOnNan handler can be used to stop the training if the `process_function`'s output 15 | contains a NaN or infinite number or `torch.tensor`. 16 | The output can be of type: number, tensor or collection of them. The training is stopped if 17 | there is at least one number/tensor with a NaN or infinite value. For example, if the output is 18 | `[1.23, torch.tensor(...), torch.tensor(float('nan'))]` the handler will stop the training. 19 | 20 | Args: 21 | output_transform: a callable that is used to transform the 22 | :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into a number or `torch.tensor` 23 | or collection of them. This can be useful if, for example, you have a multi-output model and 24 | you want to check one or multiple values of the output. 25 | 26 | 27 | Examples: 28 | .. code-block:: python 29 | 30 | trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) 31 | 32 | """ 33 | 34 | def __init__(self, output_transform: Callable = lambda x: x): 35 | self.logger = setup_logger(__name__ + "."
+ self.__class__.__name__) 36 | self.logger.addHandler(logging.StreamHandler())  # NOTE(review): added on every instantiation; may duplicate log lines if the logger is shared - confirm 37 | self._output_transform = output_transform 38 | 39 | def __call__(self, engine: Engine) -> None: 40 | output = self._output_transform(engine.state.output) 41 | 42 | def raise_error(x: Union[float, torch.Tensor]) -> None: 43 | if isinstance(x, numbers.Number): 44 | x = torch.tensor(x) 45 | 46 | if isinstance(x, torch.Tensor) and not bool(torch.isfinite(x).all()): 47 | raise RuntimeError("Infinite or NaN tensor found.") 48 | 49 | try: 50 | apply_to_type(output, (numbers.Number, torch.Tensor), raise_error) 51 | except RuntimeError: 52 | self.logger.warning(f"{self.__class__.__name__}: Output '{output}' contains NaN or Inf. Stop training") 53 | engine.terminate() 54 | -------------------------------------------------------------------------------- /ignite/handlers/time_limit.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Optional 3 | 4 | from ignite.engine import Engine 5 | 6 | __all__ = ["TimeLimit"] 7 | 8 | from ignite.utils import setup_logger 9 | 10 | 11 | class TimeLimit: 12 | """TimeLimit handler can be used to control training time for computing environments where session time is limited. 13 | The timer starts when the handler is created, not when training starts. 14 | This handler gracefully terminates the training if time passed in the training exceeds a limit. 15 | 16 | Args: 17 | limit_sec: Maximum time before training terminates (in seconds). Defaults to 28800. 18 | 19 | Examples: 20 | .. code-block:: python 21 | 22 | from ignite.engine import Events 23 | from ignite.handlers import TimeLimit 24 | 25 | handler = TimeLimit() # 8 hours of training 26 | trainer.add_event_handler(Events.ITERATION_COMPLETED, handler) 27 | 28 | ..
versionadded:: 0.4.3 29 | """ 30 | 31 | def __init__(self, limit_sec: Optional[int] = 28800): 32 | if not isinstance(limit_sec, int): 33 | raise TypeError("Argument limit_sec should be an integer.") 34 | if limit_sec <= 0: 35 | raise ValueError("Argument limit_sec should be a positive integer.") 36 | 37 | self.limit_sec = limit_sec 38 | self.start_time = time.time() 39 | self.logger = setup_logger(__name__ + "." + self.__class__.__name__) 40 | 41 | def __call__(self, engine: Engine) -> None: 42 | elapsed_time = time.time() - self.start_time 43 | if elapsed_time > self.limit_sec: 44 | self.logger.info("Reached the time limit: {} sec. Stop training".format(self.limit_sec)) 45 | engine.terminate() 46 | -------------------------------------------------------------------------------- /ignite/handlers/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | from ignite.engine import Engine 4 | from ignite.engine.events import Events 5 | 6 | 7 | def global_step_from_engine(engine: Engine, custom_event_name: Optional[Events] = None) -> Callable: 8 | """Helper method to setup `global_step_transform` function using another engine. 9 | This can be helpful for logging trainer epoch/iteration while output handler is attached to an evaluator. 10 | 11 | Args: 12 | engine: engine which state is used to provide the global step 13 | custom_event_name: registered event name. Optional argument, event name to use.
14 | 15 | Returns: 16 | a callable ``(_, event_name) -> int`` returning ``engine.state.get_event_attrib_value(event_name)``, using ``custom_event_name`` instead of ``event_name`` when it is given 17 | """ 18 | 19 | def wrapper(_: Any, event_name: Events) -> int: 20 | if custom_event_name is not None: 21 | event_name = custom_event_name 22 | return engine.state.get_event_attrib_value(event_name) 23 | 24 | return wrapper 25 | -------------------------------------------------------------------------------- /ignite/metrics/clustering/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore 2 | from ignite.metrics.clustering.davies_bouldin_score import DaviesBouldinScore 3 | from ignite.metrics.clustering.silhouette_score import SilhouetteScore 4 | -------------------------------------------------------------------------------- /ignite/metrics/clustering/_base.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from torch import Tensor 4 | 5 | from ignite.exceptions import NotComputableError 6 | from ignite.metrics.epoch_metric import EpochMetric 7 | 8 | 9 | class _ClusteringMetricBase(EpochMetric): 10 | required_output_keys = ("features", "labels") 11 | 12 | def _check_shape(self, output: Tuple[Tensor, Tensor]) -> None: 13 | features, labels = output 14 | if features.ndimension() != 2: 15 | raise ValueError("Features should be of shape (batch_size, n_targets).") 16 | 17 | if labels.ndimension() != 1: 18 | raise ValueError("Labels should be of shape (batch_size, ).") 19 | 20 | def _check_type(self, output: Tuple[Tensor, Tensor]) -> None: 21 | features, labels = output 22 | if len(self._predictions) < 1:  # nothing stored yet, so there is no dtype to compare against 23 | return 24 | dtype_preds = self._predictions[-1].dtype 25 | if dtype_preds != features.dtype: 26 | raise ValueError( 27 | f"Incoherent types between input features and stored features: {dtype_preds} vs {features.dtype}" 28 | ) 29 | 30 | dtype_targets = self._targets[-1].dtype 31 | if dtype_targets !=
labels.dtype: 32 | raise ValueError( 33 | f"Incoherent types between input labels and stored labels: {dtype_targets} vs {labels.dtype}" 34 | ) 35 | 36 | def compute(self) -> float: 37 | if len(self._predictions) < 1 or len(self._targets) < 1: 38 | raise NotComputableError( 39 | f"{self.__class__.__name__} must have at least one example before it can be computed." 40 | ) 41 | 42 | return super().compute() 43 | -------------------------------------------------------------------------------- /ignite/metrics/gan/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics.gan.fid import FID 2 | from ignite.metrics.gan.inception_score import InceptionScore 3 | 4 | __all__ = [ 5 | "InceptionScore", 6 | "FID", 7 | ] 8 | -------------------------------------------------------------------------------- /ignite/metrics/metric_group.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Sequence, Tuple 2 | 3 | import torch 4 | 5 | from ignite.metrics import Metric 6 | 7 | 8 | class MetricGroup(Metric): 9 | """ 10 | Group several metrics so that they can be attached, reset, updated and computed together. 11 | 12 | Args: 13 | metrics: a dictionary of names to metric instances. 14 | output_transform: a callable that is used to transform the 15 | :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the 16 | form expected by the metric. `output_transform` of each metric in the group is also 17 | called upon its update. 18 | skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be 19 | true for multi-output model, for example, if ``y_pred`` and ``y`` contain multi-output as 20 | ``(y_pred_a, y_pred_b)`` and ``(y_a, y_b)``, in which case the update method is called for 21 | ``(y_pred_a, y_a)`` and ``(y_pred_b, y_b)``. Alternatively, ``output_transform`` can be used to handle 22 | this.
23 | 24 | Examples: 25 | We construct a group of metrics, attach them to the engine at once and retrieve their result. 26 | 27 | .. code-block:: python 28 | 29 | import torch 30 | 31 | metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())}) 32 | metric_group.attach(default_evaluator, "eval_metrics") 33 | y_true = torch.tensor([1, 0, 1, 1, 0, 1]) 34 | y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) 35 | state = default_evaluator.run([[y_pred, y_true]]) 36 | 37 | # Metrics individually available in `state.metrics` 38 | state.metrics["acc"], state.metrics["precision"], state.metrics["loss"] 39 | 40 | # And also altogether 41 | state.metrics["eval_metrics"] 42 | 43 | .. versionchanged:: 0.5.2 44 | ``skip_unrolling`` argument is added. 45 | """ 46 | 47 | _state_dict_all_req_keys: Tuple[str, ...] = ("metrics",) 48 | 49 | def __init__( 50 | self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x, skip_unrolling: bool = False 51 | ): 52 | self.metrics = metrics 53 | super(MetricGroup, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling) 54 | 55 | def reset(self) -> None: 56 | for m in self.metrics.values(): 57 | m.reset() 58 | 59 | def update(self, output: Sequence[torch.Tensor]) -> None: 60 | for m in self.metrics.values(): 61 | m.update(m._output_transform(output))  # each member's own output_transform is applied on top of the group's 62 | 63 | def compute(self) -> Dict[str, Any]: 64 | return {k: m.compute() for k, m in self.metrics.items()} 65 | -------------------------------------------------------------------------------- /ignite/metrics/nlp/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics.nlp.bleu import Bleu 2 | from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN 3 | 4 | __all__ = [ 5 | "Bleu", 6 | "Rouge", 7 | "RougeN", 8 | "RougeL", 9 | ] 10 | -------------------------------------------------------------------------------- /ignite/metrics/nlp/utils.py:
-------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from typing import Any, Sequence, Tuple 3 | 4 | __all__ = ["ngrams", "lcs", "modified_precision"] 5 | 6 | 7 | def ngrams(sequence: Sequence[Any], n: int) -> Counter: 
-------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from typing import Any, Sequence, Tuple 3 | 4 | __all__ = ["ngrams", "lcs", "modified_precision"] 5 | 6 | 7 | def ngrams(sequence: Sequence[Any], n: int) -> Counter: 8 | """ 9 | Generate the ngrams from a sequence of items 10 | 11 | Args: 12 | sequence: sequence of items 13 | n: n-gram order 14 | 15 | Returns: 16 | A counter of ngram objects 17 | 18 | .. versionadded:: 0.4.5 19 | """ 20 | return Counter([tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)]) 21 | 22 | 23 | def lcs(seq_a: Sequence[Any], seq_b: Sequence[Any]) -> int: 24 | """ 25 | Compute the length of the longest common subsequence in two sequence of items 26 | https://en.wikipedia.org/wiki/Longest_common_subsequence_problem 27 | 28 | Args: 29 | seq_a: first sequence of items 30 | seq_b: second sequence of items 31 | 32 | Returns: 33 | The length of the longest common subsequence 34 | 35 | .. versionadded:: 0.4.5 36 | """ 37 | m = len(seq_a) 38 | n = len(seq_b) 39 | 40 | dp = [[0] * (n + 1) for _ in range(m + 1)] 41 | 42 | for i in range(m + 1): 43 | for j in range(n + 1): 44 | if i == 0 or j == 0: 45 | dp[i][j] = 0 46 | elif seq_a[i - 1] == seq_b[j - 1]: 47 | dp[i][j] = dp[i - 1][j - 1] + 1 48 | else: 49 | dp[i][j] = max(dp[i - 1][j], dp[i][j - 1]) 50 | 51 | return dp[m][n] 52 | 53 | 54 | def modified_precision(references: Sequence[Sequence[Any]], candidate: Any, n: int) -> Tuple[int, int]: 55 | """ 56 | Compute the modified precision 57 | 58 | .. math:: 59 | p_{n} = \frac{m_{n}}{l_{n}} 60 | 61 | where m_{n} is the number of matched n-grams between translation T and its reference R, and l_{n} is the 62 | total number of n-grams in the translation T. 63 | 64 | More details can be found in `Papineni et al. 2002`__. 
65 | 66 | __ https://www.aclweb.org/anthology/P02-1040.pdf 67 | 68 | Args: 69 | references: list of references R 70 | candidate: translation T 71 | n: n-gram order 72 | 73 | Returns: 74 | The length of the longest common subsequence 75 | 76 | .. versionadded:: 0.4.5 77 | """ 78 | # ngrams of the candidate 79 | counts = ngrams(candidate, n) 80 | 81 | # union of ngrams of references 82 | max_counts: Counter = Counter() 83 | for reference in references: 84 | max_counts |= ngrams(reference, n) 85 | 86 | # clipped count of the candidate and references 87 | clipped_counts = counts & max_counts 88 | 89 | return sum(clipped_counts.values()), sum(counts.values()) 90 | -------------------------------------------------------------------------------- /ignite/metrics/regression/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics.regression.canberra_metric import CanberraMetric 2 | from ignite.metrics.regression.fractional_absolute_error import FractionalAbsoluteError 3 | from ignite.metrics.regression.fractional_bias import FractionalBias 4 | from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError 5 | from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError 6 | from ignite.metrics.regression.kendall_correlation import KendallRankCorrelation 7 | from ignite.metrics.regression.manhattan_distance import ManhattanDistance 8 | from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError 9 | from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError 10 | from ignite.metrics.regression.mean_error import MeanError 11 | from ignite.metrics.regression.mean_normalized_bias import MeanNormalizedBias 12 | from ignite.metrics.regression.median_absolute_error import MedianAbsoluteError 13 | from ignite.metrics.regression.median_absolute_percentage_error import 
MedianAbsolutePercentageError 14 | from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError 15 | from ignite.metrics.regression.pearson_correlation import PearsonCorrelation 16 | from ignite.metrics.regression.r2_score import R2Score 17 | from ignite.metrics.regression.spearman_correlation import SpearmanRankCorrelation 18 | from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance 19 | -------------------------------------------------------------------------------- /ignite/metrics/regression/_base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Tuple 3 | 4 | import torch 5 | 6 | from ignite.metrics.metric import Metric, reinit__is_reduced 7 | 8 | 9 | def _check_output_shapes(output: Tuple[torch.Tensor, torch.Tensor]) -> None: 10 | y_pred, y = output 11 | c1 = y_pred.ndimension() == 2 and y_pred.shape[1] == 1 12 | if not (y_pred.ndimension() == 1 or c1): 13 | raise ValueError(f"Input y_pred should have shape (N,) or (N, 1), but given {y_pred.shape}") 14 | 15 | c2 = y.ndimension() == 2 and y.shape[1] == 1 16 | if not (y.ndimension() == 1 or c2): 17 | raise ValueError(f"Input y should have shape (N,) or (N, 1), but given {y.shape}") 18 | 19 | if y_pred.shape != y.shape: 20 | raise ValueError(f"Input data shapes should be the same, but given {y_pred.shape} and {y.shape}") 21 | 22 | 23 | def _check_output_types(output: Tuple[torch.Tensor, torch.Tensor]) -> None: 24 | y_pred, y = output 25 | if y_pred.dtype not in (torch.float16, torch.float32, torch.float64): 26 | raise TypeError(f"Input y_pred dtype should be float 16, 32 or 64, but given {y_pred.dtype}") 27 | 28 | if y.dtype not in (torch.float16, torch.float32, torch.float64): 29 | raise TypeError(f"Input y dtype should be float 16, 32 or 64, but given {y.dtype}") 30 | 31 | 32 | def _torch_median(output: torch.Tensor) -> float: 33 | # torch.kthvalue used later is 
not supported on MPS 34 | if output.device.type == "mps": 35 | output = output.cpu() 36 | 37 | output = output.view(-1) 38 | len_ = len(output) 39 | 40 | if len_ % 2 == 0: 41 | return float((torch.kthvalue(output, len_ // 2)[0] + torch.kthvalue(output, len_ // 2 + 1)[0]) / 2) 42 | else: 43 | return float(torch.kthvalue(output, len_ // 2 + 1)[0]) 44 | 45 | 46 | class _BaseRegression(Metric): 47 | # Base class for all regression metrics 48 | # `update` method check the shapes and call internal overloaded 49 | # method `_update`. 50 | 51 | @reinit__is_reduced 52 | def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: 53 | _check_output_shapes(output) 54 | _check_output_types(output) 55 | y_pred, y = output[0].detach(), output[1].detach() 56 | 57 | if y_pred.ndimension() == 2 and y_pred.shape[1] == 1: 58 | y_pred = y_pred.squeeze(dim=-1) 59 | 60 | if y.ndimension() == 2 and y.shape[1] == 1: 61 | y = y.squeeze(dim=-1) 62 | 63 | self._update((y_pred, y)) 64 | 65 | @abstractmethod 66 | def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: 67 | pass 68 | -------------------------------------------------------------------------------- /ignite/metrics/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionAvgPrecisionRecall 2 | 3 | __all__ = ["ObjectDetectionAvgPrecisionRecall"] 4 | -------------------------------------------------------------------------------- /ignite/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/ignite/py.typed -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | files = ignite 3 | pretty = True 4 | show_error_codes = 
True 5 | 6 | check_untyped_defs = True 7 | ; a lot of work needed to fix issues 8 | disallow_any_generics = False 9 | disallow_incomplete_defs = True 10 | disallow_subclassing_any = True 11 | ; due to missing types in pytorch set to False 12 | disallow_untyped_calls = False 13 | disallow_untyped_decorators = True 14 | disallow_untyped_defs = True 15 | no_implicit_optional = True 16 | ; would need a more precise import of pytorch classes and methods, which is not possible, therefore set to False 17 | no_implicit_reexport = False 18 | strict_equality = True 19 | warn_redundant_casts = True 20 | ; due to missing types in multiple libs set to False 21 | warn_return_any = False 22 | ; results in too many false positives, therefore set to False 23 | warn_unreachable = False 24 | warn_unused_configs = True 25 | warn_unused_ignores = True 26 | 27 | [mypy-apex.*] 28 | ignore_missing_imports = True 29 | 30 | [mypy-clearml.*] 31 | ignore_missing_imports = True 32 | 33 | [mypy-horovod.*] 34 | ignore_missing_imports = True 35 | 36 | [mypy-matplotlib.*] 37 | ignore_missing_imports = True 38 | 39 | [mypy-mlflow.*] 40 | ignore_missing_imports = True 41 | 42 | [mypy-neptune.*] 43 | ignore_missing_imports = True 44 | 45 | [mypy-numpy.*] 46 | ignore_missing_imports = True 47 | 48 | [mypy-pandas.*] 49 | ignore_missing_imports = True 50 | 51 | [mypy-sklearn.*] 52 | ignore_missing_imports = True 53 | 54 | [mypy-polyaxon.*] 55 | ignore_missing_imports = True 56 | 57 | [mypy-polyaxon_client.*] 58 | ignore_missing_imports = True 59 | 60 | [mypy-pynvml.*] 61 | ignore_missing_imports = True 62 | 63 | [mypy-tensorboardX.*] 64 | ignore_missing_imports = True 65 | 66 | [mypy-torch_xla.*] 67 | ignore_missing_imports = True 68 | 69 | [mypy-trains.*] 70 | ignore_missing_imports = True 71 | 72 | [mypy-tqdm.*] 73 | ignore_missing_imports = True 74 | 75 | [mypy-scipy.*] 76 | ignore_missing_imports = True 77 | 78 | [mypy-torchvision.*] 79 | ignore_missing_imports = True 80 | 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py39', 'py311'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | 7 | ( 8 | /( 9 | \.eggs # exclude a few common directories in the 10 | | \.git # root of the project 11 | | \.hg 12 | | \.mypy_cache 13 | | \.tox 14 | | \.venv 15 | | _build 16 | | buck-out 17 | | build 18 | | dist 19 | | assets 20 | )/ 21 | | foo.py # also separately exclude a file named foo.py in 22 | # the root of the project 23 | ) 24 | ''' 25 | 26 | [tool.usort.known] 27 | first_party = [ 28 | "ignite", 29 | ] 30 | third_party = [ 31 | "clearml", 32 | "dill", 33 | "matplotlib", 34 | "numpy", 35 | "pkg_resources", 36 | "pytest", 37 | "requests", 38 | "setuptools", 39 | "skimage", 40 | "sklearn", 41 | "torch", 42 | "torchvision", 43 | ] 44 | 45 | [tool.ufmt] 46 | excludes = [ 47 | "assets/", 48 | ] 49 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Tests 2 | numpy 3 | pytest 4 | pytest-cov 5 | pytest-xdist 6 | pytest-timeout 7 | dill 8 | filelock 9 | setuptools 10 | # Test contrib dependencies 11 | scipy 12 | pytorch_fid 13 | tqdm 14 | scikit-learn 15 | matplotlib 16 | tensorboardX 17 | visdom 18 | polyaxon 19 | wandb 20 | mlflow 21 | neptune-client>=0.16.17 22 | tensorboard 23 | torchvision 24 | pynvml<12 # pynvml module was removed in 12.X, is not developed or maintained. We should replace pynvml with something else. 
25 | clearml 26 | scikit-image 27 | py-rouge 28 | pycocotools 29 | # temporary fix for python=3.12 and v3.8.1 30 | # nltk 31 | git+https://github.com/nltk/nltk@aba99c8 32 | # Examples dependencies 33 | pandas 34 | gymnasium 35 | # temporary fix: E AttributeError: module 'mpmath' has no attribute 'rational' 36 | mpmath<1.4 37 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = LICENSE 3 | 4 | [pycodestyle] 5 | exclude = .eggs,*.egg,build,docs/*,.git,versioneer.py,*/conf.py 6 | ignore = E402, E721 7 | max_line_length = 120 8 | 9 | [isort] 10 | known_third_party=clearml,dill,matplotlib,numpy,pkg_resources,pytest,requests,setuptools,skimage,sklearn,torch,torchvision 11 | multi_line_output=3 12 | include_trailing_comma=True 13 | force_grid_wrap=0 14 | use_parentheses=True 15 | line_length=120 16 | skip_glob=docs/** 17 | filter_files=True 18 | profile=black 19 | 20 | [flake8] 21 | max-line-length = 120 22 | ignore = E722,E203,E231,F841,W503,F403,E402 23 | per-file-ignores = __init__.py: F401 24 | 25 | [tool:pytest] 26 | markers = 27 | distributed: mark a test with distributed option 28 | multinode_distributed: mark a test with multi-node distributed option 29 | tpu: mark a test as requiring XLA 30 | addopts = 31 | --color=yes 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | 5 | from setuptools import find_packages, setup 6 | 7 | 8 | def read(*names, **kwargs): 9 | with io.open(os.path.join(os.path.dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")) as fp: 10 | return fp.read() 11 | 12 | 13 | def find_version(*file_paths): 14 | version_file = read(*file_paths) 15 | version_match = re.search(r"^__version__ = 
['\"]([^'\"]*)['\"]", version_file, re.M) 16 | if version_match: 17 | return version_match.group(1) 18 | raise RuntimeError("Unable to find version string.") 19 | 20 | 21 | readme = read("README.md").replace( 22 | 'src="assets/', 'src="https://raw.githubusercontent.com/pytorch/ignite/master/assets/' 23 | ) 24 | 25 | VERSION = find_version("ignite", "__init__.py") 26 | 27 | requirements = ["torch>=1.3,<3", "packaging"] 28 | 29 | setup( 30 | # Metadata 31 | name="pytorch-ignite", 32 | version=VERSION, 33 | author="PyTorch-Ignite Team", 34 | author_email="contact@pytorch-ignite.ai", 35 | url="https://github.com/pytorch/ignite", 36 | description="A lightweight library to help with training neural networks in PyTorch.", 37 | long_description_content_type="text/markdown", 38 | long_description=readme, 39 | license="BSD", 40 | # Package info 41 | packages=find_packages(exclude=("tests", "tests.*")), 42 | package_data={"ignite": ["py.typed"]}, 43 | zip_safe=False, 44 | install_requires=requirements, 45 | ) 46 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Needed to collect coverage data 2 | -------------------------------------------------------------------------------- /tests/ignite/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ignite.distributed.comp_models.base import _torch_version_gt_112 4 | 5 | 6 | def is_mps_available_and_functional(): 7 | if not _torch_version_gt_112: 8 | return False 9 | 10 | if not torch.backends.mps.is_available(): 11 | return False 12 | try: 13 | # Try to allocate a small tensor on the MPS device 14 | torch.tensor([1.0], device="mps") 15 | return True 16 | except RuntimeError: 17 | return False 18 | -------------------------------------------------------------------------------- /tests/ignite/base/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/tests/ignite/base/__init__.py -------------------------------------------------------------------------------- /tests/ignite/base/test_mixins.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ignite.base import Serializable 4 | 5 | 6 | def test_state_dict(): 7 | s = Serializable() 8 | with pytest.raises(NotImplementedError): 9 | s.state_dict() 10 | 11 | 12 | def test_load_state_dict(): 13 | s = Serializable() 14 | s.load_state_dict({}) 15 | -------------------------------------------------------------------------------- /tests/ignite/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | -------------------------------------------------------------------------------- /tests/ignite/contrib/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture() 5 | def visdom_offline_logfile(dirname): 6 | log_file = dirname / "logs.visdom" 7 | yield log_file 8 | -------------------------------------------------------------------------------- /tests/ignite/contrib/engines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/tests/ignite/contrib/engines/__init__.py -------------------------------------------------------------------------------- /tests/ignite/contrib/metrics/test_warnings_of_deprecation_of_metrics.py: -------------------------------------------------------------------------------- 1 | from importlib import __import__ 2 | 3 | import pytest 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "log_module,fromlist", 8 | [ 9 | ("average_precision", 
["AveragePrecision"]), 10 | ("cohen_kappa", ["CohenKappa"]), 11 | ("gpu_info", ["GpuInfo"]), 12 | ("precision_recall_curve", ["PrecisionRecallCurve"]), 13 | ("roc_auc", ["ROC_AUC", "RocCurve"]), 14 | ("regression.canberra_metric", ["CanberraMetric"]), 15 | ("regression.fractional_absolute_error", ["FractionalAbsoluteError"]), 16 | ("regression.fractional_bias", ["FractionalBias"]), 17 | ("regression.geometric_mean_absolute_error", ["GeometricMeanAbsoluteError"]), 18 | ("regression.geometric_mean_relative_absolute_error", ["GeometricMeanRelativeAbsoluteError"]), 19 | ("regression.manhattan_distance", ["ManhattanDistance"]), 20 | ("regression.maximum_absolute_error", ["MaximumAbsoluteError"]), 21 | ("regression.mean_absolute_relative_error", ["MeanAbsoluteRelativeError"]), 22 | ("regression.mean_error", ["MeanError"]), 23 | ("regression.mean_normalized_bias", ["MeanNormalizedBias"]), 24 | ("regression.median_absolute_error", ["MedianAbsoluteError"]), 25 | ("regression.median_absolute_percentage_error", ["MedianAbsolutePercentageError"]), 26 | ("regression.median_relative_absolute_error", ["MedianRelativeAbsoluteError"]), 27 | ("regression.r2_score", ["R2Score"]), 28 | ("regression.wave_hedges_distance", ["WaveHedgesDistance"]), 29 | ], 30 | ) 31 | def test_imports(log_module, fromlist): 32 | with pytest.warns(DeprecationWarning, match="will be removed in version 0.6.0"): 33 | imported = __import__(f"ignite.contrib.metrics.{log_module}", globals(), locals(), fromlist) 34 | for attr in fromlist: 35 | getattr(imported, attr) 36 | -------------------------------------------------------------------------------- /tests/ignite/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/tests/ignite/distributed/__init__.py -------------------------------------------------------------------------------- 
/tests/ignite/distributed/comp_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/tests/ignite/distributed/comp_models/__init__.py -------------------------------------------------------------------------------- /tests/ignite/engine/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | try: 4 | from torch.utils.data import IterableDataset 5 | except ImportError: 6 | 7 | class IterableDataset: 8 | pass 9 | 10 | 11 | class BatchChecker: 12 | def __init__(self, data, init_counter=0): 13 | self.counter = init_counter 14 | self.data = data 15 | self.true_batch = None 16 | 17 | def check(self, batch): 18 | self.true_batch = self.data[self.counter % len(self.data)] 19 | self.counter += 1 20 | res = self.true_batch == batch 21 | return res.all() if not isinstance(res, bool) else res 22 | 23 | 24 | class IterationCounter: 25 | def __init__(self, start_value=1): 26 | self.current_iteration_count = start_value 27 | 28 | def __call__(self, engine): 29 | assert engine.state.iteration == self.current_iteration_count 30 | self.current_iteration_count += 1 31 | 32 | 33 | class EpochCounter: 34 | def __init__(self, start_value=1): 35 | self.current_epoch_count = start_value 36 | 37 | def __call__(self, engine): 38 | assert engine.state.epoch == self.current_epoch_count 39 | self.current_epoch_count += 1 40 | 41 | 42 | def setup_sampler(sampler_type, num_iters, batch_size): 43 | if sampler_type is None: 44 | return None, batch_size 45 | 46 | if sampler_type == "weighted": 47 | from torch.utils.data.sampler import WeightedRandomSampler 48 | 49 | w = torch.ones(num_iters * batch_size, dtype=torch.float) 50 | for i in range(num_iters): 51 | w[batch_size * i : batch_size * (i + 1)] += i * 1.0 52 | return WeightedRandomSampler(w, num_samples=num_iters * batch_size, replacement=True), 
batch_size 53 | 54 | if sampler_type == "distributed": 55 | import torch.distributed as dist 56 | from torch.utils.data.distributed import DistributedSampler 57 | 58 | num_replicas = 1 59 | rank = 0 60 | if dist.is_available() and dist.is_initialized(): 61 | num_replicas = dist.get_world_size() 62 | rank = dist.get_rank() 63 | 64 | dataset = torch.zeros(num_iters * batch_size) 65 | return DistributedSampler(dataset, num_replicas=num_replicas, rank=rank), batch_size // num_replicas 66 | 67 | 68 | class MyIterableDataset(IterableDataset): 69 | def __init__(self, start, end): 70 | super(MyIterableDataset).__init__() 71 | assert end > start, "this example code only works with end >= start" 72 | self.start = start 73 | self.end = end 74 | 75 | def __iter__(self): 76 | return iter(range(self.start, self.end)) 77 | 78 | 79 | def get_iterable_dataset(*args, **kwargs): 80 | return MyIterableDataset(*args, **kwargs) 81 | -------------------------------------------------------------------------------- /tests/ignite/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | # Needed to collect coverage data 2 | class MockFP16DeepSpeedZeroOptimizer: 3 | def __init__(self, optimizer): 4 | self.optimizer = optimizer 5 | 6 | def step(self, closure=None): 7 | self.optimizer.step() 8 | 9 | def _get_param_groups(self): 10 | return self.optimizer.param_groups 11 | 12 | def _set_param_groups(self, value): 13 | self.optimizer.param_groups = value 14 | 15 | param_groups = property(_get_param_groups, _set_param_groups) 16 | -------------------------------------------------------------------------------- /tests/ignite/handlers/test_handlers.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | from ignite.engine import Engine, Events 4 | from ignite.handlers import global_step_from_engine 5 | 6 | 7 | def test_global_step_from_engine(): 8 | iteration = 12 9 | epoch 
= 23 10 | 11 | trainer = Engine(lambda e, b: None) 12 | trainer.state.iteration = iteration 13 | trainer.state.epoch = epoch 14 | 15 | gst = global_step_from_engine(trainer) 16 | assert gst(MagicMock(), Events.EPOCH_COMPLETED) == epoch 17 | 18 | gst = global_step_from_engine(trainer, custom_event_name=Events.ITERATION_COMPLETED) 19 | assert gst(MagicMock(), Events.EPOCH_COMPLETED) == iteration 20 | -------------------------------------------------------------------------------- /tests/ignite/handlers/test_stores.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ignite.engine.engine import Engine, Events 4 | from ignite.handlers import EpochOutputStore 5 | 6 | 7 | @pytest.fixture 8 | def dummy_evaluator(): 9 | def dummy_process_function(engine, batch): 10 | return 1, 0 11 | 12 | dummy_evaluator = Engine(dummy_process_function) 13 | 14 | return dummy_evaluator 15 | 16 | 17 | @pytest.fixture 18 | def eos(): 19 | return EpochOutputStore() 20 | 21 | 22 | def test_no_transform(dummy_evaluator, eos): 23 | eos.attach(dummy_evaluator) 24 | dummy_evaluator.run(range(1)) 25 | assert eos.data == [(1, 0)] 26 | 27 | 28 | def test_transform(dummy_evaluator): 29 | eos = EpochOutputStore(output_transform=lambda x: x[0]) 30 | eos.attach(dummy_evaluator) 31 | 32 | dummy_evaluator.run(range(1)) 33 | assert eos.data == [1] 34 | 35 | 36 | def test_reset(dummy_evaluator, eos): 37 | eos.attach(dummy_evaluator) 38 | dummy_evaluator.run(range(2)) 39 | eos.reset() 40 | assert eos.data == [] 41 | 42 | 43 | def test_update_one_iteration(dummy_evaluator, eos): 44 | eos.attach(dummy_evaluator) 45 | dummy_evaluator.run(range(1)) 46 | assert len(eos.data) == 1 47 | 48 | 49 | def test_update_five_iterations(dummy_evaluator, eos): 50 | eos.attach(dummy_evaluator) 51 | dummy_evaluator.run(range(5)) 52 | assert len(eos.data) == 5 53 | 54 | 55 | def test_attatch(dummy_evaluator, eos): 56 | eos.attach(dummy_evaluator) 57 | assert 
dummy_evaluator.has_event_handler(eos.reset, Events.EPOCH_STARTED) 58 | assert dummy_evaluator.has_event_handler(eos.update, Events.ITERATION_COMPLETED) 59 | 60 | 61 | def test_store_data(dummy_evaluator, eos): 62 | eos.attach(dummy_evaluator, name="eval_data") 63 | dummy_evaluator.run(range(1)) 64 | assert dummy_evaluator.state.eval_data == eos.data 65 | -------------------------------------------------------------------------------- /tests/ignite/handlers/test_time_limit.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | 5 | from ignite.engine import Engine, Events 6 | from ignite.handlers import TimeLimit 7 | 8 | 9 | def test_arg_validation(): 10 | with pytest.raises(ValueError, match=r"Argument limit_sec should be a positive integer."): 11 | TimeLimit(limit_sec=-5) 12 | 13 | with pytest.raises(TypeError, match=r"Argument limit_sec should be an integer."): 14 | TimeLimit(limit_sec="abc") 15 | 16 | 17 | def _train_func(engine, batch): 18 | time.sleep(1) 19 | 20 | 21 | @pytest.mark.parametrize("n_iters, limit", [(20, 10), (5, 10)]) 22 | def test_terminate_on_time_limit(n_iters, limit): 23 | started = time.time() 24 | trainer = Engine(_train_func) 25 | 26 | @trainer.on(Events.TERMINATE) 27 | def _(): 28 | trainer.state.is_terminated = True 29 | 30 | trainer.add_event_handler(Events.ITERATION_COMPLETED, TimeLimit(limit)) 31 | trainer.state.is_terminated = False 32 | 33 | trainer.run(range(n_iters)) 34 | elapsed = round(time.time() - started) 35 | assert elapsed <= limit + 1 36 | assert trainer.state.is_terminated == (n_iters > limit) 37 | -------------------------------------------------------------------------------- /tests/ignite/handlers/test_timing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | 4 | import pytest 5 | 6 | from ignite.engine import Engine, Events 7 | from ignite.handlers import Timer 8 | 9 | if 
sys.platform.startswith("darwin"): 10 | pytest.skip("Skip if on MacOS", allow_module_level=True) 11 | 12 | 13 | def test_timer(): 14 | sleep_t = 0.2 15 | n_iter = 3 16 | 17 | def _train_func(engine, batch): 18 | time.sleep(sleep_t) 19 | 20 | def _test_func(engine, batch): 21 | time.sleep(sleep_t) 22 | 23 | trainer = Engine(_train_func) 24 | tester = Engine(_test_func) 25 | 26 | t_total = Timer() 27 | t_batch = Timer(average=True) 28 | t_train = Timer() 29 | 30 | t_total.attach(trainer) 31 | t_batch.attach( 32 | trainer, pause=Events.ITERATION_COMPLETED, resume=Events.ITERATION_STARTED, step=Events.ITERATION_COMPLETED 33 | ) 34 | t_train.attach(trainer, pause=Events.EPOCH_COMPLETED, resume=Events.EPOCH_STARTED) 35 | 36 | @trainer.on(Events.EPOCH_COMPLETED) 37 | def run_validation(trainer): 38 | tester.run(range(n_iter)) 39 | 40 | # Run "training" 41 | trainer.run(range(n_iter)) 42 | 43 | assert pytest.approx(t_total.value(), abs=1e-1) == 2 * n_iter * sleep_t 44 | assert pytest.approx(t_batch.value(), abs=1e-1) == sleep_t 45 | assert pytest.approx(t_train.value(), abs=1e-1) == n_iter * sleep_t 46 | 47 | t_total.reset() 48 | assert pytest.approx(t_total.value(), abs=1e-1) == 0.0 49 | -------------------------------------------------------------------------------- /tests/ignite/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Needed to collect coverage data 2 | -------------------------------------------------------------------------------- /tests/ignite/metrics/clustering/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/tests/ignite/metrics/clustering/__init__.py -------------------------------------------------------------------------------- /tests/ignite/metrics/gan/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/tests/ignite/metrics/gan/__init__.py -------------------------------------------------------------------------------- /tests/ignite/metrics/nlp/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["CorpusForTest"] 2 | 3 | 4 | class CorpusForTest: 5 | def __init__(self, lower_split=False): 6 | def preproc(text): 7 | if lower_split: 8 | return text.lower().split() 9 | else: 10 | return text 11 | 12 | # BLEU Paper examples 13 | self.cand_1 = preproc("the the the the the the the") 14 | self.ref_1a = preproc("The cat is on the mat") 15 | self.ref_1b = preproc("There is a cat on the mat") 16 | 17 | self.cand_2a = preproc( 18 | "It is a guide to action which ensures that the military always obeys the commands of the party" 19 | ) 20 | self.cand_2b = preproc("It is to insure the troops forever hearing the activity guidebook that " "party direct") 21 | self.ref_2a = preproc( 22 | "It is a guide to action that ensures that the military will forever heed " "Party commands" 23 | ) 24 | self.ref_2b = preproc( 25 | "It is the guiding principle which guarantees the military forces always being under the command of " 26 | "the Party" 27 | ) 28 | self.ref_2c = preproc("It is the practical guide for the army always to heed the directions of the party") 29 | 30 | self.cand_3 = preproc("of the") 31 | 32 | self.references_1 = [self.ref_1a, self.ref_1b] 33 | self.references_2 = [self.ref_2a, self.ref_2b, self.ref_2c] 34 | 35 | self.sample_1 = ([self.cand_1], [self.references_1]) 36 | self.sample_2 = ([self.cand_3], [self.references_2]) 37 | self.sample_3 = ([self.cand_2a], [self.references_2]) 38 | self.sample_4 = ([self.cand_2b], [self.references_2]) 39 | self.sample_5 = ([self.cand_2a, self.cand_2b], [self.references_2, self.references_2]) 40 | 41 | self.references_3 = [self.ref_2a, self.ref_2b] 42 | self.references_4 = [self.ref_2b, 
self.ref_2c] 43 | self.references_5 = [self.ref_2a, self.ref_2c] 44 | 45 | self.chunks = [ 46 | ([self.cand_1], [self.references_1]), 47 | ([self.cand_2a], [self.references_2]), 48 | ([self.cand_2b], [self.references_2]), 49 | ([self.cand_1], [[self.ref_1a]]), 50 | ([self.cand_2a], [self.references_3]), 51 | ([self.cand_2b], [self.references_3]), 52 | ([self.cand_1], [[self.ref_1b]]), 53 | ([self.cand_2a], [self.references_4]), 54 | ([self.cand_2b], [self.references_4]), 55 | ([self.cand_1], [self.references_1]), 56 | ([self.cand_2a], [self.references_5]), 57 | ([self.cand_2b], [self.references_5]), 58 | ([self.cand_1], [[self.ref_1a]]), 59 | ([self.cand_2a], [[self.ref_2a]]), 60 | ([self.cand_2b], [[self.ref_2c]]), 61 | ] 62 | -------------------------------------------------------------------------------- /tests/ignite/metrics/nlp/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ignite.metrics.nlp.utils import lcs, modified_precision, ngrams 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "sequence, n, expected_keys, expected_values", 8 | [ 9 | ([], 1, [], []), 10 | ([0, 1, 2], 1, [(0,), (1,), (2,)], [1, 1, 1]), 11 | ([0, 1, 2], 2, [(0, 1), (1, 2)], [1, 1]), 12 | ([0, 1, 2], 3, [(0, 1, 2)], [1]), 13 | ([0, 0, 0], 1, [(0,)], [3]), 14 | ([0, 0, 0], 2, [(0, 0)], [2]), 15 | ("abcde", 4, [("a", "b", "c", "d"), ("b", "c", "d", "e")], [1, 1]), 16 | ], 17 | ) 18 | def test_ngrams(sequence, n, expected_keys, expected_values): 19 | ngrams_counter = ngrams(sequence=sequence, n=n) 20 | assert list(ngrams_counter.values()) == expected_values 21 | assert list(ngrams_counter.keys()) == expected_keys 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "seq_a, seq_b, expected", 26 | [([], [], 0), ([0, 1, 2], [0, 1, 2], 3), ([0, 1, 2], [0, 3, 2], 2), ("academy", "abracadabra", 4)], 27 | ) 28 | def test_lcs(seq_a, seq_b, expected): 29 | assert lcs(seq_a, seq_b) == expected 30 | 31 | 32 | def 
test_modified_precision_empty(): 33 | for k in range(1, 5): 34 | n, d = modified_precision([[]], [], k) 35 | assert n == 0 and d == 0 36 | n, d = modified_precision([[]], [0], k) 37 | assert n == 0 and d == (k == 1) 38 | n, d = modified_precision([[0]], [], k) 39 | assert n == 0 and d == 0 40 | n, d = modified_precision([[]], list(range(k)), k) 41 | assert n == 0 and d == 1 42 | n, d = modified_precision([list(range(k))], [], k) 43 | assert n == 0 and d == 0 44 | 45 | 46 | @pytest.mark.parametrize( 47 | "references, candidate, expected", 48 | [ 49 | ([[0, 0, 0], [1, 2]], [1, 2, 3, 4], ((2, 4), (1, 3), (0, 2))), 50 | ([[0, 1, 2], [0, 0, 3]], [0, 0, 0, 1, 2], ((4, 5), (3, 4), (1, 3))), 51 | ([[0, 1, 2], [3, 0, 3]], [3, 0, 0, 1, 2], ((4, 5), (3, 4), (1, 3))), 52 | ], 53 | ) 54 | def test_modified_precision(references, candidate, expected): 55 | for n, (e_n, e_d) in enumerate(expected, start=1): 56 | n, d = modified_precision(references, candidate, n) 57 | assert n == e_n and d == e_d 58 | -------------------------------------------------------------------------------- /tests/ignite/metrics/regression/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/tests/ignite/metrics/regression/__init__.py -------------------------------------------------------------------------------- /tests/ignite/metrics/regression/test__base.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | 5 | import pytest 6 | import torch 7 | 8 | import ignite.distributed as idist 9 | 10 | from ignite.metrics.regression._base import _BaseRegression, _torch_median 11 | 12 | 13 | def test_base_regression_shapes(): 14 | class L1(_BaseRegression): 15 | def reset(self): 16 | self._sum_of_errors = 0.0 17 | 18 | def _update(self, output): 19 | y_pred, y = output 20 | errors = 
torch.abs(y.view_as(y_pred) - y_pred) 21 | self._sum_of_errors += torch.sum(errors).item() 22 | 23 | def compute(self): 24 | return self._sum_of_errors 25 | 26 | m = L1() 27 | 28 | with pytest.raises(ValueError, match=r"Input y_pred should have shape \(N,\) or \(N, 1\)"): 29 | y = torch.rand([1, 1, 1]) 30 | m.update((y, y)) 31 | 32 | with pytest.raises(ValueError, match=r"Input y should have shape \(N,\) or \(N, 1\)"): 33 | y = torch.rand([1, 1, 1]) 34 | m.update((torch.rand(1, 1), y)) 35 | 36 | with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"): 37 | m.update((torch.rand(2), torch.rand(2, 1))) 38 | 39 | with pytest.raises(TypeError, match=r"Input y_pred dtype should be float"): 40 | y = torch.tensor([1, 1]) 41 | m.update((y, y)) 42 | 43 | with pytest.raises(TypeError, match=r"Input y dtype should be float"): 44 | y = torch.tensor([1, 1]) 45 | m.update((y.float(), y)) 46 | 47 | 48 | @pytest.mark.parametrize("size", [100, 101, (30, 3), (31, 3)]) 49 | def test_torch_median_numpy(size, device: Optional[str] = None): 50 | data = torch.rand(size).to(device) 51 | assert _torch_median(data) == np.median(data.cpu().numpy()) 52 | 53 | 54 | @pytest.mark.tpu 55 | @pytest.mark.parametrize("size", [100, 101, (30, 3), (31, 3)]) 56 | @pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") 57 | def test_on_even_size_xla(size): 58 | device = "xla" 59 | test_torch_median_numpy(size, device=device) 60 | 61 | 62 | @pytest.mark.parametrize("size", [100, 101, (30, 3), (31, 3)]) 63 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU") 64 | def test_on_even_size_gpu(size): 65 | test_torch_median_numpy(size, device="cuda") 66 | 67 | 68 | @pytest.mark.parametrize("size", [100, 101, (30, 3), (31, 3)]) 69 | def test_create_even_size_cpu(size): 70 | test_torch_median_numpy(size, device="cpu") 71 | -------------------------------------------------------------------------------- 
/tests/ignite/metrics/test_dill.py: -------------------------------------------------------------------------------- 1 | import dill 2 | 3 | from ignite.metrics import Metric 4 | 5 | 6 | class Accumulation(Metric): 7 | def __init__(self): 8 | self.value = 0 9 | super(Accumulation, self).__init__() 10 | 11 | def reset(self): 12 | self.value = 0 13 | 14 | def compute(self): 15 | return self.value 16 | 17 | def update(self, output): 18 | self.value += output 19 | 20 | 21 | def test_metric(): 22 | def _test(m, values, e): 23 | for v in values: 24 | m.update(v) 25 | assert m.compute() == e 26 | 27 | metric = Accumulation() 28 | 29 | m1 = dill.loads(dill.dumps(metric)) 30 | 31 | values = list(range(10)) 32 | expected = sum(values) 33 | 34 | _test(m1, values, expected) 35 | 36 | metric.update(5) 37 | 38 | m2 = dill.loads(dill.dumps(metric)) 39 | 40 | _test(m2, values, expected + 5) 41 | -------------------------------------------------------------------------------- /tests/ignite/metrics/vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ignite/adbb260bfb1756c106fed2e793e1b645918ece27/tests/ignite/metrics/vision/__init__.py -------------------------------------------------------------------------------- /tests/run_code_style.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | if "%1" == "lint" goto lint 4 | if "%1" == "fmt" goto fmt 5 | if "%1" == "mypy" goto mypy 6 | if "%1" == "install" goto install 7 | goto end 8 | 9 | :lint 10 | flake8 ignite tests examples --config setup.cfg 11 | ufmt diff . 12 | goto end 13 | 14 | :fmt 15 | ufmt format . 
16 | goto end 17 | 18 | :mypy 19 | mypy --config-file mypy.ini 20 | goto end 21 | 22 | :install 23 | pip install --upgrade flake8 "black==24.10.0" "usort==1.0.8.post1" "ufmt==2.7.3" "mypy" 24 | goto end 25 | 26 | :end 27 | popd 28 | -------------------------------------------------------------------------------- /tests/run_code_style.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -xeu 4 | 5 | if [ $1 = "lint" ]; then 6 | flake8 ignite tests examples --config setup.cfg 7 | ufmt diff . 8 | elif [ $1 = "fmt" ]; then 9 | ufmt format . 10 | elif [ $1 = "mypy" ]; then 11 | mypy --config-file mypy.ini 12 | elif [ $1 = "install" ]; then 13 | pip install --upgrade flake8 "black==24.10.0" "usort==1.0.8.post1" "ufmt==2.7.3" "mypy" 14 | fi 15 | -------------------------------------------------------------------------------- /tests/run_cpu_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source "$(dirname "$0")/common_test_functionality.sh" 3 | set -xeu 4 | 5 | skip_distrib_tests=${SKIP_DISTRIB_TESTS:-0} 6 | use_last_failed=${USE_LAST_FAILED:-0} 7 | match_tests_expression=${1:-""} 8 | 9 | use_xdist=${USE_XDIST:-1} 10 | core_args="-vvv tests/ignite" 11 | if [ "${use_xdist}" -eq "1" ]; then 12 | core_args="${core_args} --tx 4*popen//python=python" 13 | fi 14 | 15 | CUDA_VISIBLE_DEVICES="" run_tests \ 16 | --core_args "${core_args}" \ 17 | --cache_dir ".cpu-not-distrib" \ 18 | --skip_distrib_tests "${skip_distrib_tests}" \ 19 | --use_coverage 1 \ 20 | --match_tests_expression "${match_tests_expression}" \ 21 | --use_last_failed ${use_last_failed} 22 | 23 | # https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02 24 | if [ "${skip_distrib_tests}" -eq "1" ] || [ "${use_xdist}" -eq "0" ]; then 25 | exit 0 26 | fi 27 | 28 | # Run 2 processes with --dist=each 29 | CUDA_VISIBLE_DEVICES="" run_tests \ 30 | --core_args "-m 
distributed -vvv tests/ignite" \ 31 | --world_size 4 \ 32 | --cache_dir ".cpu-distrib" \ 33 | --skip_distrib_tests 0 \ 34 | --use_coverage 1 \ 35 | --match_tests_expression "${match_tests_expression}" \ 36 | --use_last_failed ${use_last_failed} 37 | -------------------------------------------------------------------------------- /tests/run_gpu_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source "$(dirname "$0")/common_test_functionality.sh" 3 | set -xeu 4 | 5 | # https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02 6 | skip_distrib_tests=${SKIP_DISTRIB_TESTS:-0} 7 | use_last_failed=${USE_LAST_FAILED:-0} 8 | ngpus=${1:-1} 9 | 10 | match_tests_expression=${2:-""} 11 | if [ -z "$match_tests_expression" ]; then 12 | cuda_pattern="cuda or nccl or gloo" 13 | else 14 | cuda_pattern="(cuda or nccl or gloo) and $match_tests_expression" 15 | fi 16 | 17 | run_tests \ 18 | --core_args "-vvv tests/ignite -m 'not distributed'" \ 19 | --cache_dir ".gpu-cuda" \ 20 | --skip_distrib_tests "${skip_distrib_tests}" \ 21 | --use_coverage 1 \ 22 | --match_tests_expression "${cuda_pattern}" \ 23 | --use_last_failed ${use_last_failed} 24 | 25 | if [ "${skip_distrib_tests}" -eq "1" ]; then 26 | exit 0 27 | fi 28 | 29 | run_tests \ 30 | --core_args "-vvv -m distributed tests/ignite" \ 31 | --cache_dir ".gpu-distrib" \ 32 | --skip_distrib_tests 0 \ 33 | --use_coverage 1 \ 34 | --match_tests_expression "${match_tests_expression}" \ 35 | --use_last_failed ${use_last_failed} 36 | 37 | 38 | if [ ${ngpus} -gt 1 ]; then 39 | run_tests \ 40 | --core_args "-vvv -m distributed tests/ignite" \ 41 | --world_size "${ngpus}" \ 42 | --cache_dir ".gpu-distrib-multi" \ 43 | --skip_distrib_tests 0 \ 44 | --use_coverage 1 \ 45 | --match_tests_expression "${match_tests_expression}" \ 46 | --use_last_failed ${use_last_failed} 47 | fi 48 | 
-------------------------------------------------------------------------------- /tests/run_multinode_tests_in_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Tests configuration: 4 | if [[ -z "$1" || "$1" -lt 2 ]]; then 5 | echo "nnodes setting default to 2" 6 | export nnodes=2 7 | else 8 | export nnodes=$1 9 | fi 10 | 11 | if [[ -z "$2" || "$2" -lt 1 ]]; then 12 | echo "nproc_per_node setting default to 4" 13 | export nproc_per_node=4 14 | else 15 | export nproc_per_node=$2 16 | fi 17 | 18 | if [ -z "$3" ]; then 19 | echo "gpu setting default to 0 ( False )" 20 | export gpu=0 21 | else 22 | export gpu=$3 23 | fi 24 | 25 | # Start script from ignite root folder 26 | if [ ! -d tests ]; then 27 | echo "Ignite tests folder is not found. Please run script from ignite's root folder" 28 | exit 1 29 | fi 30 | 31 | docker_image="pytorchignite/tests:latest" 32 | 33 | docker build -t $docker_image -<