├── .gitignore ├── HISTORY.rst ├── HISTORY_experiment.rst ├── HISTORY_oneshot.rst ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.rst ├── README_experiment.rst ├── README_ignite.rst ├── README_oneshot.rst ├── data_for_augmentation ├── 1shotRun10ClassIdxDict.pkl ├── 1shotRun10UsedIndices.pkl ├── 1shotRun1ClassIdxDict.pkl ├── 1shotRun1UsedIndices.pkl ├── 1shotRun2ClassIdxDict.pkl ├── 1shotRun2UsedIndices.pkl ├── 1shotRun3ClassIdxDict.pkl ├── 1shotRun3UsedIndices.pkl ├── 1shotRun4ClassIdxDict.pkl ├── 1shotRun4UsedIndices.pkl ├── 1shotRun5ClassIdxDict.pkl ├── 1shotRun5UsedIndices.pkl ├── 1shotRun6ClassIdxDict.pkl ├── 1shotRun6UsedIndices.pkl ├── 1shotRun7ClassIdxDict.pkl ├── 1shotRun7UsedIndices.pkl ├── 1shotRun8ClassIdxDict.pkl ├── 1shotRun8UsedIndices.pkl ├── 1shotRun9ClassIdxDict.pkl ├── 1shotRun9UsedIndices.pkl ├── 5shotRun10ClassIdxDict.pkl ├── 5shotRun10UsedIndices.pkl ├── 5shotRun1ClassIdxDict.pkl ├── 5shotRun1UsedIndices.pkl ├── 5shotRun2ClassIdxDict.pkl ├── 5shotRun2UsedIndices.pkl ├── 5shotRun3ClassIdxDict.pkl ├── 5shotRun3UsedIndices.pkl ├── 5shotRun4ClassIdxDict.pkl ├── 5shotRun4UsedIndices.pkl ├── 5shotRun5ClassIdxDict.pkl ├── 5shotRun5UsedIndices.pkl ├── 5shotRun6ClassIdxDict.pkl ├── 5shotRun6UsedIndices.pkl ├── 5shotRun7ClassIdxDict.pkl ├── 5shotRun7UsedIndices.pkl ├── 5shotRun8ClassIdxDict.pkl ├── 5shotRun8UsedIndices.pkl ├── 5shotRun9ClassIdxDict.pkl └── 5shotRun9UsedIndices.pkl ├── environment.yml ├── experiment ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── experiment.cpython-36.pyc │ ├── monitor.cpython-36.pyc │ ├── tensorboard_x.cpython-36.pyc │ ├── utils.cpython-36.pyc │ └── visdom.cpython-36.pyc ├── experiment.py ├── monitor.py ├── tensorboard_x.py ├── utils.py └── visdom.py ├── ignite ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── _six.cpython-36.pyc │ ├── _utils.cpython-36.pyc │ ├── exceptions.cpython-36.pyc │ └── utils.cpython-36.pyc ├── _six.py ├── _utils.py ├── contrib │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ ├── engines │ │ ├── __init__.py │ │ └── tbptt.py │ ├── handlers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── mlflow_logger.cpython-36.pyc │ │ │ ├── param_scheduler.cpython-36.pyc │ │ │ ├── tensorboard2_logger.cpython-36.pyc │ │ │ ├── tqdm_logger.cpython-36.pyc │ │ │ └── visdom_logger.cpython-36.pyc │ │ ├── mlflow_logger.py │ │ ├── param_scheduler.py │ │ ├── tensorboard2_logger.py │ │ ├── tqdm_logger.py │ │ └── visdom_logger.py │ └── metrics │ │ ├── __init__.py │ │ ├── average_precision.py │ │ ├── regression │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── canberra_metric.py │ │ ├── fractional_absolute_error.py │ │ ├── fractional_bias.py │ │ ├── geometric_mean_absolute_error.py │ │ ├── manhattan_distance.py │ │ ├── maximum_absolute_error.py │ │ ├── mean_absolute_relative_error.py │ │ ├── mean_error.py │ │ ├── mean_normalized_bias.py │ │ └── wave_hedges_distance.py │ │ └── roc_auc.py ├── engine │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── engine.cpython-36.pyc │ └── engine.py ├── exceptions.py ├── handlers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── checkpoint.cpython-36.pyc │ │ ├── early_stopping.cpython-36.pyc │ │ ├── terminate_on_nan.cpython-36.pyc │ │ └── timing.cpython-36.pyc │ ├── checkpoint.py │ ├── early_stopping.py │ ├── terminate_on_nan.py │ └── timing.py ├── metrics │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── accuracy.cpython-36.pyc │ │ ├── 
binary_accuracy.cpython-36.pyc │ │ ├── categorical_accuracy.cpython-36.pyc │ │ ├── epoch_metric.cpython-36.pyc │ │ ├── loss.cpython-36.pyc │ │ ├── mean_absolute_error.cpython-36.pyc │ │ ├── mean_pairwise_distance.cpython-36.pyc │ │ ├── mean_squared_error.cpython-36.pyc │ │ ├── metric.cpython-36.pyc │ │ ├── metrics_lambda.cpython-36.pyc │ │ ├── precision.cpython-36.pyc │ │ ├── recall.cpython-36.pyc │ │ ├── root_mean_squared_error.cpython-36.pyc │ │ ├── running_average.cpython-36.pyc │ │ └── top_k_categorical_accuracy.cpython-36.pyc │ ├── accuracy.py │ ├── binary_accuracy.py │ ├── categorical_accuracy.py │ ├── epoch_metric.py │ ├── loss.py │ ├── mean_absolute_error.py │ ├── mean_pairwise_distance.py │ ├── mean_squared_error.py │ ├── metric.py │ ├── metrics_lambda.py │ ├── precision.py │ ├── recall.py │ ├── root_mean_squared_error.py │ ├── running_average.py │ └── top_k_categorical_accuracy.py └── utils.py ├── oneshot ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── coco.cpython-36.pyc │ └── utils.cpython-36.pyc ├── alfassy │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── datasets.cpython-36.pyc │ │ ├── img_to_vec.cpython-36.pyc │ │ ├── setops_funcs.cpython-36.pyc │ │ ├── testing_functions.cpython-36.pyc │ │ └── utils.cpython-36.pyc │ ├── datasets.py │ ├── img_to_vec.py │ ├── setops_funcs.py │ └── utils.py ├── celeba.py ├── cnnvisualizer │ ├── __init__.py │ ├── cnnvisualizer.py │ ├── tightcrop.py │ ├── wideresnet.py │ └── wideresnet_utils.py ├── coco.py ├── global_settings.py ├── ignite │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ ├── engine │ │ └── __init__.py │ ├── handlers │ │ ├── __init__.py │ │ ├── find_learning_rate.py │ │ └── param_scheduler.py │ └── metrics │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── metrics.cpython-36.pyc │ │ └── metrics.py ├── mixup.py ├── oneshot.py ├── pytorch │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── datasets.cpython-36.pyc │ │ └── losses.cpython-36.pyc │ ├── datasets.py │ ├── losses.py │ └── monitor.py ├── setops_models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── ae_setops.cpython-36.pyc │ │ ├── discriminators.cpython-36.pyc │ │ ├── inception.cpython-36.pyc │ │ ├── res_setops.cpython-36.pyc │ │ ├── resnet.cpython-36.pyc │ │ └── setops.cpython-36.pyc │ ├── ae_setops.py │ ├── discriminators.py │ ├── inception.py │ ├── res_setops.py │ ├── resnet.py │ ├── resnet_backup.py │ ├── setops.py │ └── vae_setops.py ├── stn.py ├── triplet │ ├── __init__.py │ ├── datasets.py │ ├── losses.py │ ├── trainer.py │ └── utils.py ├── utils.py └── wideresnet_places.py ├── requirements_dev.txt ├── scripts_coco ├── example_use.py ├── test_augmentation.py ├── test_precision.py ├── test_retrieval.py └── train_setops_stripped.py ├── setup.cfg ├── setup.py └── spec-file.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | History 3 | ======= 4 | 5 | 0.1.0 (2018-08-13) 6 | ------------------ 7 | 8 | * First release on PyPI. 
9 | -------------------------------------------------------------------------------- /HISTORY_experiment.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | History 3 | ======= 4 | 5 | 0.2.0 (2019-01-10) 6 | ------------------ 7 | 8 | * First release. 9 | -------------------------------------------------------------------------------- /HISTORY_oneshot.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | History 3 | ======= 4 | 5 | 0.1.0 (2018-08-13) 6 | ------------------ 7 | 8 | * First release on PyPI. 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 IBM Corp. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 5 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 6 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 7 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
8 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include HISTORY.rst 2 | include LICENSE 3 | include README.rst 4 | 5 | recursive-include tests * 6 | recursive-exclude * __pycache__ 7 | recursive-exclude * *.py[co] 8 | 9 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 10 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | try: 8 | from urllib import pathname2url 9 | except: 10 | from urllib.request import pathname2url 11 | 12 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 13 | endef 14 | export BROWSER_PYSCRIPT 15 | 16 | define PRINT_HELP_PYSCRIPT 17 | import re, sys 18 | 19 | for line in sys.stdin: 20 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 21 | if match: 22 | target, help = match.groups() 23 | print("%-20s %s" % (target, help)) 24 | endef 25 | export PRINT_HELP_PYSCRIPT 26 | 27 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 28 | 29 | help: 30 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 31 | 32 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 33 | 34 | clean-build: ## remove build artifacts 35 | rm -fr build/ 36 | rm -fr dist/ 37 | rm -fr .eggs/ 38 | find . -name '*.egg-info' -exec rm -fr {} + 39 | find . -name '*.egg' -exec rm -f {} + 40 | 41 | clean-pyc: ## remove Python file artifacts 42 | find . -name '*.pyc' -exec rm -f {} + 43 | find . -name '*.pyo' -exec rm -f {} + 44 | find . -name '*~' -exec rm -f {} + 45 | find . -name '__pycache__' -exec rm -fr {} + 46 | 47 | clean-test: ## remove test and coverage artifacts 48 | rm -fr .tox/ 49 | rm -f .coverage 50 | rm -fr htmlcov/ 51 | rm -fr .pytest_cache 52 | 53 | lint: ## check style with flake8 54 | flake8 oneshot tests 55 | 56 | test: ## run tests quickly with the default Python 57 | python setup.py test 58 | 59 | test-all: ## run tests on every Python version with tox 60 | tox 61 | 62 | coverage: ## check code coverage quickly with the default Python 63 | coverage run --source oneshot setup.py test 64 | coverage report -m 65 | coverage html 66 | $(BROWSER) htmlcov/index.html 67 | 68 | docs: ## generate Sphinx HTML documentation, including API docs 69 | rm -f docs/oneshot.rst 70 | rm -f docs/modules.rst 71 | sphinx-apidoc -o docs/ oneshot 72 | $(MAKE) -C docs clean 73 | $(MAKE) -C docs html 74 | $(BROWSER) docs/_build/html/index.html 75 | 76 | servedocs: docs ## compile the docs watching for changes 77 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . 78 | 79 | release: dist ## package and upload a release 80 | twine upload dist/* 81 | 82 | dist: clean ## builds source and wheel package 83 | python setup.py sdist 84 | python setup.py bdist_wheel 85 | ls -l dist 86 | 87 | install: clean ## install the package to the active Python's site-packages 88 | python setup.py install 89 | -------------------------------------------------------------------------------- /README_experiment.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | Experiment 3 | ========== 4 | 5 | 6 | .. 
image:: https://img.shields.io/pypi/v/experiment.svg 7 | :target: https://pypi.python.org/pypi/experiment 8 | 9 | .. image:: https://img.shields.io/travis/amitibo/experiment.svg 10 | :target: https://travis-ci.org/amitibo/experiment 11 | 12 | .. image:: https://readthedocs.org/projects/experiment/badge/?version=latest 13 | :target: https://experiment.readthedocs.io/en/latest/?badge=latest 14 | :alt: Documentation Status 15 | 16 | 17 | Framework for running experiments. 18 | 19 | The `experiment` package is meant to simplify conducting experiments by hiding 20 | most of the "boring" boilerplate code, e.g. experiment configuration and logging. 21 | It is based on the Traitlets_ package. 22 | 23 | .. note:: 24 | The `experiment` package is still in beta state and the API might change. 25 | 26 | * Free software: MIT license 27 | 28 | .. * Documentation: https://pages.github.ibm.com/AMITAID/experiment/ 29 | 30 | 31 | TL;DR 32 | ----- 33 | 34 | Copy the following example to a Python file ``hello_experiment.py``:: 35 | 36 | 37 | from experiment import Experiment 38 | import logging 39 | import time 40 | from traitlets import Int, Unicode 41 | 42 | 43 | class Main(Experiment): 44 | description = Unicode("My hello world experiment.") 45 | epochs = Int(10, config=True, help="Number of epochs") 46 | 47 | def run(self): 48 | """Running the experiment""" 49 | 50 | logging.info("Starting experiment") 51 | 52 | loss = 100 53 | for i in range(self.epochs): 54 | logging.info("Running epoch [{}/{}]".format(i, self.epochs)) 55 | time.sleep(.5) 56 | 57 | logging.info("Experiment finished") 58 | 59 | 60 | if __name__ == "__main__": 61 | main = Main() 62 | main.initialize() 63 | main.start() 64 | 65 | Run the script from the command line as follows:: 66 | 67 | $ python hello_experiment.py --epochs 15 68 | 69 | The configuration, logs and results of the script will be stored in a unique folder under ``/tmp/results/...``. 70 | 71 | To check the script documentation, run the following from the command line:: 72 | 73 | $ python hello_experiment.py --help 74 | 75 | See the documentation for more advanced usage. 76 | 77 | Features 78 | -------- 79 | 80 | * Clean and versatile configuration system based on the Traitlets_ package. 81 | * Automatic logging setup. 82 | * Configuration and logging are automatically saved in a unique results folder. 83 | * Run parameters are stored in a configuration file to allow for replaying the same experiment. 84 | * Support for multiple logging frameworks: mlflow_, visdom_, tensorboard_. 85 | * Automatic monitoring of GPU usage. 86 | 87 | The examples_ folder contains multiple examples showcasing the package features. 88 | 89 | Credits 90 | ------- 91 | 92 | This package was created with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template. 93 | 94 | .. _Cookiecutter: https://github.com/audreyr/cookiecutter 95 | .. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage 96 | .. _Traitlets: https://traitlets.readthedocs.io/en/stable/index.html 97 | .. _mlflow: https://mlflow.org/ 98 | .. _visdom: https://github.com/facebookresearch/visdom 99 | .. _tensorboard: https://www.tensorflow.org/guide/summaries_and_tensorboard 100 | .. _examples: https://github.ibm.com/AMITAID/experiment/tree/master/examples 101 | -------------------------------------------------------------------------------- /README_ignite.rst: -------------------------------------------------------------------------------- 1 | Ignite 2 | ====== 3 | 4 | ..
image:: https://travis-ci.org/pytorch/ignite.svg?branch=master 5 | :target: https://travis-ci.org/pytorch/ignite 6 | 7 | .. image:: https://codecov.io/gh/pytorch/ignite/branch/master/graph/badge.svg 8 | :target: https://codecov.io/gh/pytorch/ignite 9 | 10 | .. image:: https://pepy.tech/badge/pytorch-ignite 11 | :target: https://pepy.tech/project/pytorch-ignite 12 | 13 | .. image:: https://img.shields.io/badge/dynamic/json.svg?label=docs&url=https%3A%2F%2Fpypi.org%2Fpypi%2Fpytorch-ignite%2Fjson&query=%24.info.version&colorB=brightgreen&prefix=v 14 | :target: https://pytorch.org/ignite/index.html 15 | 16 | Ignite is a high-level library to help with training neural networks in PyTorch. 17 | 18 | - ignite helps you write compact but full-featured training loops in a few lines of code 19 | - you get a training loop with metrics, early-stopping, model checkpointing and other features without the boilerplate 20 | 21 | Below we show a side-by-side comparison of using pure pytorch and using ignite to create a training loop 22 | to train and validate your model with occasional checkpointing: 23 | 24 | .. image:: assets/ignite_vs_bare_pytorch.png 25 | :target: https://raw.githubusercontent.com/pytorch/ignite/master/assets/ignite_vs_bare_pytorch.png 26 | 27 | As you can see, the code is more concise and readable with ignite. Furthermore, adding additional metrics, or 28 | things like early stopping is a breeze in ignite, but can start to rapidly increase the complexity of 29 | your code when "rolling your own" training loop. 30 | 31 | 32 | Installation 33 | ============ 34 | 35 | From pip: 36 | 37 | .. code:: bash 38 | 39 | pip install pytorch-ignite 40 | 41 | 42 | From conda: 43 | 44 | .. code:: bash 45 | 46 | conda install ignite -c pytorch 47 | 48 | 49 | From source: 50 | 51 | .. code:: bash 52 | 53 | python setup.py install 54 | 55 | 56 | Why Ignite? 57 | =========== 58 | Ignite's high level of abstraction assumes less about the type of network (or networks) that you are training, and we require the user to define the closure to be run in the training and validation loop. This level of abstraction allows for a great deal more flexibility, such as co-training multiple models (e.g. GANs) and computing/tracking multiple losses and metrics in your training loop. 59 | 60 | Ignite also allows for multiple handlers to be attached to events, and a finer granularity of events in the engine loop. 61 | 62 | 63 | Documentation 64 | ============= 65 | API documentation and an overview of the library can be found `here <https://pytorch.org/ignite/index.html>`_. 66 | 67 | 68 | Structure 69 | ========= 70 | - **ignite**: Core of the library, contains an engine for training and evaluating, all of the classic machine learning metrics and a variety of handlers to ease the pain of training and validation of neural networks! 71 | 72 | - **ignite.contrib**: The Contrib directory contains additional modules contributed by Ignite users. Modules vary from TBPTT engine, various optimisation parameter schedulers, logging handlers and a metrics module containing many regression metrics (``ignite.contrib.metrics.regression``)! 73 | 74 | The code in **ignite.contrib** is not as fully maintained as the core part of the library. It may change or be removed at any time without notice. 75 | 76 | 77 | Examples 78 | ======== 79 | Please check out the `examples 80 | <https://github.com/pytorch/ignite/tree/master/examples>`_ to see how to use `ignite` to train various types of networks, as well as how to use `visdom <https://github.com/facebookresearch/visdom>`_ or `tensorboardX <https://github.com/lanpa/tensorboardX>`_ for training visualizations.
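The snippet below is a minimal sketch of the basic trainer/evaluator pattern (it is not one of the shipped examples; ``model``, ``optimizer``, ``train_loader`` and ``val_loader`` are assumed to be defined elsewhere):

.. code:: python

    import torch.nn.functional as F
    from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
    from ignite.metrics import CategoricalAccuracy, Loss

    # Build a trainer and an evaluator around the assumed model and optimizer.
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss)
    evaluator = create_supervised_evaluator(
        model, metrics={"accuracy": CategoricalAccuracy(), "nll": Loss(F.nll_loss)})

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # Run the evaluation engine once per epoch and report its metrics.
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        print("Epoch {}: accuracy={:.3f} nll={:.3f}".format(
            engine.state.epoch, metrics["accuracy"], metrics["nll"]))

    trainer.run(train_loader, max_epochs=10)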
81 | 82 | 83 | Contributing 84 | ============ 85 | We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us. 86 | 87 | Please see the `contribution guidelines <https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md>`_ for more information. 88 | 89 | As always, PRs are welcome :) 90 | -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun10ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun10ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun10UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun10UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun1ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun1ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun1UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun1UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun2ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun2ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun2UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun2UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun3ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun3ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun3UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun3UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun4ClassIdxDict.pkl: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun4ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun4UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun4UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun5ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun5ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun5UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun5UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun6ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun6ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun6UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun6UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun7ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun7ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun7UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun7UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun8ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun8ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun8UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun8UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun9ClassIdxDict.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun9ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/1shotRun9UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/1shotRun9UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun10ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun10ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun10UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun10UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun1ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun1ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun1UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun1UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun2ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun2ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun2UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun2UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun3ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun3ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun3UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun3UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun4ClassIdxDict.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun4ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun4UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun4UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun5ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun5ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun5UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun5UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun6ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun6ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun6UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun6UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun7ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun7ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun7UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun7UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun8ClassIdxDict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun8ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun8UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun8UsedIndices.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun9ClassIdxDict.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun9ClassIdxDict.pkl -------------------------------------------------------------------------------- /data_for_augmentation/5shotRun9UsedIndices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/data_for_augmentation/5shotRun9UsedIndices.pkl -------------------------------------------------------------------------------- /experiment/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for Experiment.""" 4 | 5 | __author__ = """Amit Aides""" 6 | __email__ = 'amiti.bo@gmail.com' 7 | __version__ = '0.2.0' 8 | 9 | 10 | from .experiment import Experiment 11 | from .experiment import MLflowExperiment 12 | from .experiment import TensorboardXExperiment 13 | from .experiment import VisdomExperiment 14 | -------------------------------------------------------------------------------- /experiment/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/experiment/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/experiment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/experiment/__pycache__/experiment.cpython-36.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/monitor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/experiment/__pycache__/monitor.cpython-36.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/tensorboard_x.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/experiment/__pycache__/tensorboard_x.cpython-36.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/experiment/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/visdom.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/experiment/__pycache__/visdom.cpython-36.pyc -------------------------------------------------------------------------------- /experiment/monitor.py: -------------------------------------------------------------------------------- 1 | import time 2 | import threading 3 | from typing import Callable 4 | from typing import Tuple 5 | 6 | import py3nvml.py3nvml as nvml 7 | 8 | def try_get_info( 9 | f : Callable[[int], object], 10 | h : object, 11 | default : 
object='N/A') -> object: 12 | """Safely try to call the pynvml API.""" 13 | 14 | try: 15 | v = f(h) 16 | except nvml.NVMLError_NotSupported: 17 | v = default 18 | 19 | return v 20 | 21 | 22 | def gpu_info(gpu_index : int) -> Tuple[str, int]: 23 | """Returns a description of a GPU. 24 | 25 | Returns the description and memory size of the GPU. 26 | """ 27 | 28 | nvml.nvmlInit() 29 | 30 | handle = nvml.nvmlDeviceGetHandleByIndex(gpu_index) 31 | gpu_desc = nvml.nvmlDeviceGetName(handle) 32 | 33 | # 34 | # Get memory info. 35 | # 36 | mem_info = try_get_info(nvml.nvmlDeviceGetMemoryInfo, handle) 37 | if mem_info != 'N/A': 38 | mem_total = mem_info.total >> 20 39 | else: 40 | mem_total = 0 41 | 42 | return gpu_desc, mem_total 43 | 44 | 45 | def query_gpu(index : int) -> Tuple[int, int, int]: 46 | 47 | h = nvml.nvmlDeviceGetHandleByIndex(index) 48 | 49 | # 50 | # Get memory info. 51 | # 52 | mem_info = try_get_info(nvml.nvmlDeviceGetMemoryInfo, h) 53 | if mem_info != 'N/A': 54 | mem_used = mem_info.used >> 20 55 | mem_total = mem_info.total >> 20 56 | else: 57 | mem_used = 0 58 | mem_total = 0 59 | 60 | # 61 | # Get utilization info 62 | # 63 | util = try_get_info(nvml.nvmlDeviceGetUtilizationRates, h) 64 | if util != 'N/A': 65 | gpu_util = util.gpu 66 | else: 67 | gpu_util = 0 68 | 69 | return mem_used, mem_total, gpu_util 70 | 71 | 72 | class GPUMonitor(threading.Thread): 73 | shutdown = False 74 | daemon = True 75 | 76 | def __init__( 77 | self, 78 | gpu_index : int, 79 | callback : Callable[[int, int, int, int], None], 80 | sampling_period : float=0.5): 81 | """Utility class that monitors a specific GPU. 82 | 83 | Args: 84 | gpu_index (int): Index of GPU to monitor. 85 | callback (callable): Callback to call with GPU info. 86 | sampling_period (float): The monitoring period. 87 | """ 88 | threading.Thread.__init__(self) 89 | self.gpu_index = gpu_index 90 | self.sampling_period = sampling_period 91 | self._monitor_callback = callback 92 | 93 | def run(self): 94 | 95 | # 96 | # Initialize nvml on the thread. 97 | # 98 | nvml.nvmlInit() 99 | 100 | t0 = time.time() 101 | while not self.shutdown: 102 | dt = int(time.time() - t0) 103 | 104 | mem_used, mem_total, gpu_util = query_gpu(self.gpu_index) 105 | self._monitor_callback(dt, mem_used, mem_total, gpu_util) 106 | 107 | time.sleep(self.sampling_period) 108 | 109 | def stop(self): 110 | self.shutdown = True 111 | -------------------------------------------------------------------------------- /experiment/tensorboard_x.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import logging 3 | import numpy as np 4 | import os 5 | import pprint 6 | import threading 7 | from typing import Any, Union, List, Tuple, Dict 8 | 9 | from tensorboardX import SummaryWriter 10 | 11 | 12 | class TensorBoardXLogHandler(logging.Handler): 13 | """Logging handler that logs to a TensorBoardX instance. 14 | 15 | Args: 16 | summary_writer (tensorboard.SummaryWriter): The summary writer to log to. 17 | title (string): Title/tag to write to.
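Example (a minimal sketch, not from the original docs; the log directory
    is an assumption)::

        import logging
        from tensorboardX import SummaryWriter
        from experiment.tensorboard_x import TensorBoardXLogHandler

        writer = SummaryWriter(log_dir="/tmp/tb_logs")  # assumed log dir
        handler = TensorBoardXLogHandler(writer, title="Logging")

        # Route standard logging records to TensorBoard's TEXT tab.
        root = logging.getLogger()
        root.setLevel(logging.INFO)
        root.addHandler(handler)
        logging.info("This message also shows up in TensorBoard.")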
18 | """ 19 | 20 | def __init__( 21 | self, 22 | summary_writer, # type: SummaryWriter 23 | title="Logging", # type: str 24 | *args, **kwds 25 | ): 26 | 27 | super(TensorBoardXLogHandler, self).__init__(*args, **kwds) 28 | 29 | self.summary_writer = summary_writer 30 | self.title = title 31 | self.global_step = 0 32 | self.accumulated_entries = "" 33 | 34 | def emit(self, record): 35 | log_entry = self.format(record) 36 | 37 | # 38 | # There seems to be a bug in writing new lines: 39 | # https://stackoverflow.com/questions/45016458/tensorflow-tf-summary-text-and-linebreaks 40 | # 41 | self.accumulated_entries += " \n" 42 | self.accumulated_entries += log_entry.replace("\n", " \n") 43 | 44 | self.summary_writer.add_text( 45 | tag=self.title, 46 | text_string=self.accumulated_entries, 47 | global_step=self.global_step 48 | ) 49 | 50 | 51 | def write_conf( 52 | summary_writer, # type: SummaryWriter 53 | args=None, text=None): 54 | """Write the configuration to the TensorBoard log. 55 | 56 | Args: 57 | summary_writer (tensorboard.SummaryWriter): The summary writer to log to. 58 | args (Namespace, optional): The argument namespace returned by argparse. 59 | text (string, optional): Configuration as text block. 60 | """ 61 | 62 | conf_text = "" 63 | if args: 64 | conf_text += pprint.pformat(args.__dict__, indent=4) 65 | if text: 66 | conf_text += text 67 | 68 | # 69 | # There seems to be a bug in writing new lines: 70 | # https://stackoverflow.com/questions/45016458/tensorflow-tf-summary-text-and-linebreaks 71 | # 72 | conf_text = conf_text.replace("\n", " \n") 73 | 74 | summary_writer.add_text( 75 | tag="Configuration", 76 | text_string=conf_text 77 | ) 78 | 79 | 80 | def monitor_gpu( 81 | summary_writer, # type: SummaryWriter 82 | gpu_index=None, # type: int 83 | xtick_size=100, # type: int 84 | ): # -> threading.Thread 85 | """Monitor the memory and utilization of a GPU. 86 | 87 | Args: 88 | summary_writer (tensorboard.SummaryWriter): The summary writer to log to. 89 | gpu_index (int): The GPU to monitor.
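Returns:
        GPUMonitor: the started monitoring thread; call its ``stop()``
        method to end the monitoring.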
90 | """ 91 | 92 | from experiment import monitor as mon 93 | 94 | if gpu_index is None: 95 | gpu_index = int(os.environ["CUDA_VISIBLE_DEVICES"]) 96 | 97 | desc, total = mon.gpu_info(gpu_index) 98 | title = desc.replace(' ', '_') 99 | 100 | def cb(dt, mem_used, mem_total, gpu_util): 101 | 102 | summary_writer.add_scalar( 103 | tag='/'.join([title, "mem"]), 104 | scalar_value=int(mem_used / total * 100), 105 | global_step=dt 106 | ) 107 | 108 | summary_writer.add_scalar( 109 | tag='/'.join([title, "util"]), 110 | scalar_value=gpu_util, 111 | global_step=dt 112 | ) 113 | 114 | sm = mon.GPUMonitor(gpu_index, cb) 115 | sm.start() 116 | 117 | return sm 118 | 119 | -------------------------------------------------------------------------------- /ignite/__init__.py: -------------------------------------------------------------------------------- 1 | import ignite.engine 2 | import ignite.handlers 3 | import ignite.metrics 4 | import ignite.exceptions 5 | import ignite.contrib 6 | import ignite.utils 7 | 8 | __version__ = '0.1.2' 9 | -------------------------------------------------------------------------------- /ignite/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/__pycache__/_six.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/__pycache__/_six.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/__pycache__/_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/__pycache__/_utils.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/__pycache__/exceptions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/__pycache__/exceptions.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/_six.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2010-2017 Benjamin Peterson 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the
Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | 22 | def with_metaclass(meta, *bases): 23 | """Create a base class with a metaclass.""" 24 | # This requires a bit of explanation: the basic idea is to make a dummy 25 | # metaclass for one level of class instantiation that replaces itself with 26 | # the actual metaclass. 27 | class metaclass(meta): 28 | 29 | def __new__(cls, name, this_bases, d): 30 | return meta(name, bases, d) 31 | return type.__new__(metaclass, 'temporary_class', (), {}) 32 | -------------------------------------------------------------------------------- /ignite/_utils.py: -------------------------------------------------------------------------------- 1 | 2 | # For compatibility 3 | from ignite.utils import convert_tensor, apply_to_tensor, apply_to_type, to_onehot 4 | 5 | 6 | def _to_hours_mins_secs(time_taken): 7 | """Convert seconds to hours, mins, and seconds.""" 8 | mins, secs = divmod(time_taken, 60) 9 | hours, mins = divmod(mins, 60) 10 | return hours, mins, secs 11 | -------------------------------------------------------------------------------- /ignite/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/contrib/__init__.py -------------------------------------------------------------------------------- /ignite/contrib/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/contrib/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/contrib/engines/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.contrib.engines.tbptt import create_supervised_tbptt_trainer 2 | from ignite.contrib.engines.tbptt import Tbptt_Events 3 | -------------------------------------------------------------------------------- /ignite/contrib/engines/tbptt.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from enum import Enum 4 | 5 | import torch 6 | 7 | from ignite.utils import apply_to_tensor 8 | from ignite.engine import Engine, _prepare_batch 9 | 10 | 11 | class Tbptt_Events(Enum): 12 | """Additional tbptt events. 13 | 14 | Additional events for the dedicated truncated backpropagation through time 15 | trainer. 16 | """ 17 | 18 | TIME_ITERATION_STARTED = "time_iteration_started" 19 | TIME_ITERATION_COMPLETED = "time_iteration_completed" 20 | 21 | 22 | def _detach_hidden(hidden): 23 | """Cut backpropagation graph. 24 | 25 | Auxiliary function to cut the backpropagation graph by detaching the hidden 26 | vector.
27 | """ 28 | return apply_to_tensor(hidden, torch.Tensor.detach) 29 | 30 | 31 | def create_supervised_tbptt_trainer( 32 | model, 33 | optimizer, 34 | loss_fn, 35 | tbtt_step, 36 | dim=0, 37 | device=None, 38 | non_blocking=False, 39 | prepare_batch=_prepare_batch 40 | ): 41 | """Create a trainer for truncated backprop through time supervised models. 42 | 43 | Training a recurrent model on long sequences is computationally intensive, as 44 | it requires processing the whole sequence before getting a gradient. 45 | However, when the training loss is computed over many outputs 46 | (X to many), 47 | there is an opportunity to compute a gradient over a subsequence. This is 48 | known as 49 | truncated backpropagation through time. 51 | This supervised trainer applies a gradient optimization step every `tbtt_step` 52 | time steps of the sequence, while backpropagating through the same 53 | `tbtt_step` time steps. 54 | 55 | Args: 56 | model (`torch.nn.Module`): the model to train. 57 | optimizer (`torch.optim.Optimizer`): the optimizer to use. 58 | loss_fn (torch.nn loss function): the loss function to use. 59 | tbtt_step (int): the length of time chunks (last one may be smaller). 60 | dim (int): axis representing the time dimension. 61 | device (str, optional): device type specification (default: None). 62 | Applies to both model and batches. 63 | non_blocking (bool, optional): if True and this copy is between CPU and GPU, 64 | the copy may occur asynchronously with respect to the host. For other cases, 65 | this argument has no effect. 66 | prepare_batch (callable, optional): function that receives `batch`, `device`, 67 | `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`. 68 | 69 | Returns: 70 | Engine: a trainer engine with supervised update function.
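Example (a minimal sketch; assumes a recurrent ``model`` returning
    ``(output, hidden)``, plus an ``optimizer``, a ``criterion`` and a
    ``loader`` yielding ``(x, y)`` batches with time along dimension 0)::

        from ignite.contrib.engines import create_supervised_tbptt_trainer

        trainer = create_supervised_tbptt_trainer(
            model, optimizer, criterion, tbtt_step=20)
        trainer.run(loader, max_epochs=5)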
71 | 72 | """ 73 | if device: 74 | model.to(device) 75 | 76 | def _update(engine, batch): 77 | loss_list = [] 78 | hidden = None 79 | 80 | x, y = batch 81 | for batch_t in zip(x.split(tbtt_step, dim=dim), y.split(tbtt_step, dim=dim)): 82 | x_t, y_t = prepare_batch(batch_t, device=device, non_blocking=non_blocking) 83 | # Fire event for start of iteration 84 | engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED) 85 | # Forward, backward and optimize 86 | model.train() 87 | optimizer.zero_grad() 88 | if hidden is None: 89 | y_pred_t, hidden = model(x_t) 90 | else: 91 | hidden = _detach_hidden(hidden) 92 | y_pred_t, hidden = model(x_t, hidden) 93 | loss_t = loss_fn(y_pred_t, y_t) 94 | loss_t.backward() 95 | optimizer.step() 96 | 97 | # Setting state of engine for consistent behaviour 98 | engine.state.output = loss_t.item() 99 | loss_list.append(loss_t.item()) 100 | 101 | # Fire event for end of iteration 102 | engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED) 103 | 104 | # return average loss over the time splits 105 | return sum(loss_list) / len(loss_list) 106 | 107 | engine = Engine(_update) 108 | engine.register_events(*Tbptt_Events) 109 | return engine 110 | -------------------------------------------------------------------------------- /ignite/contrib/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler, CosineAnnealingScheduler, \ 3 | ConcatScheduler, LRScheduler, create_lr_scheduler_with_warmup 4 | from ignite.contrib.handlers.param_scheduler import ParamScheduler, CyclicalScheduler, \ 5 | ReduceLROnPlateau 6 | 7 | from ignite.contrib.handlers.tensorboard2_logger import TensorboardLogger 8 | from ignite.contrib.handlers.tqdm_logger import ProgressBar 9 | from ignite.contrib.handlers.visdom_logger import VisdomLogger 10 | from ignite.contrib.handlers.mlflow_logger import MlflowLogger 11 | -------------------------------------------------------------------------------- /ignite/contrib/handlers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/contrib/handlers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/contrib/handlers/__pycache__/mlflow_logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/contrib/handlers/__pycache__/mlflow_logger.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/contrib/handlers/__pycache__/param_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/contrib/handlers/__pycache__/param_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/contrib/handlers/__pycache__/tensorboard2_logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/contrib/handlers/__pycache__/tensorboard2_logger.cpython-36.pyc
-------------------------------------------------------------------------------- /ignite/contrib/handlers/__pycache__/tqdm_logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/contrib/handlers/__pycache__/tqdm_logger.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/contrib/handlers/__pycache__/visdom_logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/contrib/handlers/__pycache__/visdom_logger.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/contrib/handlers/mlflow_logger.py: -------------------------------------------------------------------------------- 1 | from ignite.engine import Engine 2 | from ignite.engine import Events 3 | from typing import Callable, List 4 | 5 | import mlflow 6 | 7 | 8 | class MlflowLogger: 9 | """Handler that logs metrics using the `mlflow tracking` system. 10 | 11 | Examples: 12 | 13 | Plotting of trainer loss. 14 | 15 | .. code-block:: python 16 | 17 | import mlflow 18 | 19 | mlflow.set_tracking_uri(server_url) 20 | experiment_id = mlflow.set_experiment(MLFLOW_EXPERIMENT) 21 | 22 | # 23 | # Run the training under mlflow 24 | # 25 | with mlflow.start_run(experiment_id=experiment_id): 26 | 27 | trainer = create_supervised_trainer(model, optimizer, loss) 28 | 29 | mlflow_plotter = MlflowLogger() 30 | 31 | mlflow_plotter.attach( 32 | engine=trainer, 33 | prefix="Train ", 34 | plot_event=Events.ITERATION_COMPLETED, 35 | output_transform=lambda x: {"loss": x} 36 | ) 37 | 38 | trainer.run(train_loader, max_epochs=epochs_num) 39 | 40 | """ 41 | 42 | def __init__(self): 43 | 44 | self.metrics_step = [] 45 | 46 | def _update( 47 | self, 48 | engine, # type: Engine 49 | attach_id, # type: int 50 | prefix, # type: str 51 | update_period, # type: int 52 | metric_names=None, # type: List 53 | output_transform=None, # type: Callable 54 | param_history=False # type: bool 55 | ): 56 | step = self.metrics_step[attach_id] 57 | self.metrics_step[attach_id] += 1 58 | if step % update_period != 0: 59 | return 60 | 61 | # 62 | # Get all the metrics 63 | # 64 | metrics = [] 65 | if metric_names is not None: 66 | if not all(metric in engine.state.metrics for metric in metric_names): 67 | raise KeyError("metrics not found in engine.state.metrics") 68 | 69 | metrics.extend([(name, engine.state.metrics[name]) for name in metric_names]) 70 | 71 | if output_transform is not None: 72 | output_dict = output_transform(engine.state.output) 73 | 74 | if not isinstance(output_dict, dict): 75 | output_dict = {"output": output_dict} 76 | 77 | metrics.extend([(name, value) for name, value in output_dict.items()]) 78 | 79 | if param_history: 80 | metrics.extend([(name, value[-1][0]) for name, value in engine.state.param_history.items()]) 81 | 82 | if not metrics: 83 | return 84 | 85 | for metric_name, new_value in metrics: 86 | mlflow.log_metric(prefix + metric_name, new_value) 87 | 88 | def attach( 89 | self, 90 | engine, # type: Engine 91 | prefix="", # type: str 92 | plot_event=Events.EPOCH_COMPLETED, # type: Events 93 | update_period=1, # type: int 94 | metric_names=None, # type: List 95 | output_transform=None, # type: Callable 96 | param_history=False, # type: bool 97 | ): 98 | """ 99 | 
Attaches the mlflow plotter to an engine object. 100 | 101 | Args: 102 | engine (Engine): engine object 103 | prefix (str, optional): A prefix to add before the metric name. 104 | plot_event (str, optional): Name of event to handle. 105 | update_period (int, optional): Can be used to limit the number of plot updates. 106 | metric_names (list, optional): list of the metric names to log. 107 | output_transform (Callable, optional): a function to select what you want to plot from the engine's 108 | output. This function may return either a dictionary with entries in the format of ``{name: value}``, 109 | or a single scalar, which will be displayed with the default name `output`. 110 | param_history (bool, optional): If True, will plot all the parameters logged in `param_history`. 111 | """ 112 | if metric_names is not None and not isinstance(metric_names, list): 113 | raise TypeError("metric_names should be a list, got {} instead".format(type(metric_names))) 114 | 115 | if output_transform is not None and not callable(output_transform): 116 | raise TypeError("output_transform should be a function, got {} instead" 117 | .format(type(output_transform))) 118 | 119 | assert plot_event in (Events.ITERATION_COMPLETED, Events.EPOCH_COMPLETED), \ 120 | "The plotting event should be either {} or {}".format(Events.ITERATION_COMPLETED, Events.EPOCH_COMPLETED) 121 | 122 | attach_id = len(self.metrics_step) 123 | self.metrics_step.append(0) 124 | 125 | engine.add_event_handler( 126 | plot_event, 127 | self._update, 128 | attach_id=attach_id, 129 | prefix=prefix, 130 | update_period=update_period, 131 | metric_names=metric_names, 132 | output_transform=output_transform, 133 | param_history=param_history 134 | ) 135 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.contrib.metrics.average_precision import AveragePrecision 2 | from ignite.contrib.metrics.roc_auc import ROC_AUC 3 | import ignite.contrib.metrics.regression 4 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/average_precision.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from ignite.metrics import EpochMetric 3 | 4 | 5 | def average_precision_compute_fn(y_preds, y_targets, activation=None): 6 | try: 7 | from sklearn.metrics import average_precision_score 8 | except ImportError: 9 | raise RuntimeError("This contrib module requires sklearn to be installed.") 10 | 11 | y_true = y_targets.numpy() 12 | if activation is not None: 13 | y_preds = activation(y_preds) 14 | y_pred = y_preds.numpy() 15 | return average_precision_score(y_true, y_pred) 16 | 17 | 18 | class AveragePrecision(EpochMetric): 19 | """Computes Average Precision accumulating predictions and the ground-truth during an epoch 20 | and applying `sklearn.metrics.average_precision_score 21 | <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html>`_ . 22 | 23 | Args: 24 | activation (callable, optional): optional function to apply on prediction tensors, 25 | e.g. `activation=torch.sigmoid` to transform logits. 26 | output_transform (callable, optional): a callable that is used to transform the 27 | :class:`~ignite.engine.Engine`'s `process_function`'s output into the 28 | form expected by the metric. This can be useful if, for example, you have a multi-output model and 29 | you want to compute the metric with respect to one of the outputs.
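Example (editor's sketch, not part of the original docstring; `model` and `val_loader` are assumed to exist): attaching the metric to an evaluator for a model that outputs logits, with `torch.sigmoid` applied before scoring:

.. code-block:: python

    import torch
    from ignite.engine import create_supervised_evaluator
    from ignite.contrib.metrics import AveragePrecision

    # Predictions are accumulated over the epoch; sklearn scores them once
    # at EPOCH_COMPLETED via the EpochMetric machinery.
    evaluator = create_supervised_evaluator(
        model, metrics={"AP": AveragePrecision(activation=torch.sigmoid)})
    state = evaluator.run(val_loader)
    print(state.metrics["AP"])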
30 | 31 | """ 32 | def __init__(self, activation=None, output_transform=lambda x: x): 33 | super(AveragePrecision, self).__init__(partial(average_precision_compute_fn, activation=activation), 34 | output_transform=output_transform) 35 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.contrib.metrics.regression.maximum_absolute_error import MaximumAbsoluteError 2 | from ignite.contrib.metrics.regression.fractional_bias import FractionalBias 3 | from ignite.contrib.metrics.regression.manhattan_distance import ManhattanDistance 4 | from ignite.contrib.metrics.regression.mean_error import MeanError 5 | from ignite.contrib.metrics.regression.mean_normalized_bias import MeanNormalizedBias 6 | from ignite.contrib.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError 7 | from ignite.contrib.metrics.regression.canberra_metric import CanberraMetric 8 | from ignite.contrib.metrics.regression.fractional_absolute_error import FractionalAbsoluteError 9 | from ignite.contrib.metrics.regression.wave_hedges_distance import WaveHedgesDistance 10 | from ignite.contrib.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError 11 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/_base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from ignite.metrics import Metric 3 | 4 | 5 | class _BaseRegression(Metric): 6 | # Base class for all regression metrics 7 | # `update` method check the shapes and call internal overloaded 8 | # method `_update` 9 | 10 | def update(self, output): 11 | y_pred, y = output 12 | if y_pred.shape != y.shape: 13 | raise ValueError("Input data shapes should be the same, but given {} and {}" 14 | .format(y_pred.shape, y.shape)) 15 | 16 | c1 = y_pred.ndimension() == 2 and y_pred.shape[1] == 1 17 | if not (y_pred.ndimension() == 1 or c1): 18 | raise ValueError("Input y_pred should have shape (N,) or (N, 1), but given {}".format(y_pred.shape)) 19 | 20 | c2 = y.ndimension() == 2 and y.shape[1] == 1 21 | if not (y.ndimension() == 1 or c2): 22 | raise ValueError("Input y should have shape (N,) or (N, 1), but given {}".format(y.shape)) 23 | 24 | if c1: 25 | y_pred = y_pred.squeeze(dim=-1) 26 | 27 | if c2: 28 | y = y.squeeze(dim=-1) 29 | 30 | self._update((y_pred, y)) 31 | 32 | @abstractmethod 33 | def _update(self, output): 34 | pass 35 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/canberra_metric.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | from ignite.contrib.metrics.regression._base import _BaseRegression 4 | 5 | 6 | class CanberraMetric(_BaseRegression): 7 | r""" 8 | Calculates the Canberra Metric. 9 | 10 | :math:`\text{CM} = \sum _j^n\frac{|A_j - P_j|}{A_j + P_j}` 11 | 12 | where, :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. 13 | 14 | More details can be found in `Botchkarev 2018`__. 15 | 16 | - `update` must receive output of the form `(y_pred, y)`. 17 | - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. 
18 | 19 | __ https://arxiv.org/abs/1809.03006 20 | """ 21 | 22 | def reset(self): 23 | self._sum_of_errors = 0.0 24 | 25 | def _update(self, output): 26 | y_pred, y = output 27 | errors = torch.abs(y.view_as(y_pred) - y_pred) / (y_pred + y.view_as(y_pred)) 28 | self._sum_of_errors += torch.sum(errors).item() 29 | 30 | def compute(self): 31 | return self._sum_of_errors 32 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/fractional_absolute_error.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | from ignite.exceptions import NotComputableError 4 | from ignite.contrib.metrics.regression._base import _BaseRegression 5 | 6 | 7 | class FractionalAbsoluteError(_BaseRegression): 8 | r""" 9 | Calculates the Fractional Absolute Error. 10 | 11 | :math:`\text{FAE} = \frac{1}{n}\sum _j^n\frac{2 * |A_j - P_j|}{|A_j| + |P_j|}` 12 | 13 | where, :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. 14 | 15 | More details can be found in `Botchkarev 2018`__. 16 | 17 | - `update` must receive output of the form `(y_pred, y)`. 18 | - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. 19 | 20 | __ https://arxiv.org/abs/1809.03006 21 | """ 22 | 23 | def reset(self): 24 | self._sum_of_errors = 0.0 25 | self._num_examples = 0 26 | 27 | def _update(self, output): 28 | y_pred, y = output 29 | errors = 2 * torch.abs(y.view_as(y_pred) - y_pred) / (torch.abs(y_pred) + torch.abs(y.view_as(y_pred))) 30 | self._sum_of_errors += torch.sum(errors).item() 31 | self._num_examples += y.shape[0] 32 | 33 | def compute(self): 34 | if self._num_examples == 0: 35 | raise NotComputableError('FractionalAbsoluteError must have at least ' 36 | 'one example before it can be computed.') 37 | return self._sum_of_errors / self._num_examples 38 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/fractional_bias.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.exceptions import NotComputableError 6 | from ignite.contrib.metrics.regression._base import _BaseRegression 7 | 8 | 9 | class FractionalBias(_BaseRegression): 10 | r""" 11 | Calculates the Fractional Bias: 12 | 13 | :math:`\text{FB} = \frac{1}{n}\sum_{j=1}^n\frac{2 * (A_j - P_j)}{A_j + P_j}`, 14 | 15 | where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. 16 | 17 | More details can be found in `Botchkarev 2018`__. 18 | 19 | - `update` must receive output of the form `(y_pred, y)`. 20 | - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. 
21 | 22 | __ https://arxiv.org/abs/1809.03006 23 | 24 | """ 25 | def reset(self): 26 | self._sum_of_errors = 0.0 27 | self._num_examples = 0 28 | 29 | def _update(self, output): 30 | y_pred, y = output 31 | errors = 2 * (y.view_as(y_pred) - y_pred) / (y_pred + y.view_as(y_pred)) 32 | self._sum_of_errors += torch.sum(errors).item() 33 | self._num_examples += y.shape[0] 34 | 35 | def compute(self): 36 | if self._num_examples == 0: 37 | raise NotComputableError('FractionalBias must have at least one example before it can be computed.') 38 | return self._sum_of_errors / self._num_examples 39 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/geometric_mean_absolute_error.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | from ignite.exceptions import NotComputableError 4 | from ignite.contrib.metrics.regression._base import _BaseRegression 5 | 6 | 7 | class GeometricMeanAbsoluteError(_BaseRegression): 8 | r""" 9 | Calculates the Geometric Mean Absolute Error. 10 | 11 | :math:`\text{GMAE} = \exp(\frac{1}{n}\sum_{j=1}^n \ln(|A_j - P_j|))` 12 | 13 | where, :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. 14 | 15 | More details can be found in `Botchkarev 2018`__. 16 | 17 | - `update` must receive output of the form `(y_pred, y)`. 18 | - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. 19 | 20 | __ https://arxiv.org/abs/1809.03006 21 | """ 22 | 23 | def reset(self): 24 | self._sum_of_errors = 0.0 25 | self._num_examples = 0 26 | 27 | def _update(self, output): 28 | y_pred, y = output 29 | errors = torch.log(torch.abs(y.view_as(y_pred) - y_pred)) 30 | self._sum_of_errors += torch.sum(errors) 31 | self._num_examples += y.shape[0] 32 | 33 | def compute(self): 34 | if self._num_examples == 0: 35 | raise NotComputableError('GeometricMeanAbsoluteError must have at ' 36 | 'least one example before it can be computed.') 37 | return torch.exp(self._sum_of_errors / self._num_examples).item() 38 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/manhattan_distance.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.contrib.metrics.regression._base import _BaseRegression 6 | 7 | 8 | class ManhattanDistance(_BaseRegression): 9 | r""" 10 | Calculates the Manhattan Distance: 11 | 12 | :math:`\text{MD} = \sum_{j=1}^n |A_j - P_j|`, 13 | 14 | where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. 15 | 16 | More details can be found in `Botchkarev 2018`__. 17 | 18 | - `update` must receive output of the form `(y_pred, y)`. 19 | - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
20 | 21 | __ https://arxiv.org/abs/1809.03006 22 | 23 | """ 24 | def reset(self): 25 | self._sum_of_errors = 0.0 26 | 27 | def _update(self, output): 28 | y_pred, y = output 29 | errors = torch.abs(y.view_as(y_pred) - y_pred) 30 | self._sum_of_errors += torch.sum(errors).item() 31 | 32 | def compute(self): 33 | return self._sum_of_errors 34 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/maximum_absolute_error.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ignite.exceptions import NotComputableError 4 | from ignite.contrib.metrics.regression._base import _BaseRegression 5 | 6 | 7 | class MaximumAbsoluteError(_BaseRegression): 8 | r""" 9 | Calculates the Maximum Absolute Error: 10 | 11 | :math:`\text{MaxAE} = \max_{j=1,n} \left( \lvert A_j-P_j \rvert \right)`, 12 | 13 | where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. 14 | 15 | More details can be found in `Botchkarev 2018`__. 16 | 17 | - `update` must receive output of the form `(y_pred, y)`. 18 | - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. 19 | 20 | __ https://arxiv.org/abs/1809.03006 21 | 22 | """ 23 | 24 | def reset(self): 25 | self._max_of_absolute_errors = -1 26 | 27 | def _update(self, output): 28 | y_pred, y = output 29 | mae = torch.abs(y_pred - y.view_as(y_pred)).max().item() 30 | if self._max_of_absolute_errors < mae: 31 | self._max_of_absolute_errors = mae 32 | 33 | def compute(self): 34 | if self._max_of_absolute_errors < 0: 35 | raise NotComputableError('MaximumAbsoluteError must have at least one example before it can be computed.') 36 | return self._max_of_absolute_errors 37 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/mean_absolute_relative_error.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.exceptions import NotComputableError 6 | from ignite.contrib.metrics.regression._base import _BaseRegression 7 | 8 | 9 | class MeanAbsoluteRelativeError(_BaseRegression): 10 | r""" 11 | Calculates the Mean Absolute Relative Error: 12 | 13 | :math:`\text{MARE} = \frac{1}{n}\sum_{j=1}^n\frac{\left|A_j-P_j\right|}{\left|A_j\right|}`, 14 | 15 | where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. 16 | 17 | More details can be found in the reference `Botchkarev 2018`__. 18 | 19 | - `update` must receive output of the form `(y_pred, y)`. 20 | - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
21 | 22 | __ https://arxiv.org/ftp/arxiv/papers/1809/1809.03006.pdf 23 | 24 | """ 25 | 26 | def reset(self): 27 | self._sum_of_absolute_relative_errors = 0.0 28 | self._num_samples = 0 29 | 30 | def _update(self, output): 31 | y_pred, y = output 32 | if (y == 0).any(): 33 | raise NotComputableError('The ground truth has 0.') 34 | absolute_error = torch.abs(y_pred - y.view_as(y_pred)) / torch.abs(y.view_as(y_pred)) 35 | self._sum_of_absolute_relative_errors += torch.sum(absolute_error).item() 36 | self._num_samples += y.size()[0] 37 | 38 | def compute(self): 39 | if self._num_samples == 0: 40 | raise NotComputableError('MeanAbsoluteRelativeError must have at least ' 41 | 'one sample before it can be computed.') 42 | return self._sum_of_absolute_relative_errors / self._num_samples 43 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/mean_error.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.exceptions import NotComputableError 6 | from ignite.contrib.metrics.regression._base import _BaseRegression 7 | 8 | 9 | class MeanError(_BaseRegression): 10 | r""" 11 | Calculates the Mean Error: 12 | 13 | :math:`\text{ME} = \frac{1}{n}\sum_{j=1}^n (A_j - P_j)`, 14 | 15 | where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. 16 | 17 | More details can be found in the reference `Botchkarev 2018`__. 18 | 19 | - `update` must receive output of the form `(y_pred, y)`. 20 | - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. 21 | 22 | __ https://arxiv.org/abs/1809.03006 23 | 24 | """ 25 | def reset(self): 26 | self._sum_of_errors = 0.0 27 | self._num_examples = 0 28 | 29 | def _update(self, output): 30 | y_pred, y = output 31 | errors = (y.view_as(y_pred) - y_pred) 32 | self._sum_of_errors += torch.sum(errors).item() 33 | self._num_examples += y.shape[0] 34 | 35 | def compute(self): 36 | if self._num_examples == 0: 37 | raise NotComputableError('MeanError must have at least one example before it can be computed.') 38 | return self._sum_of_errors / self._num_examples 39 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/mean_normalized_bias.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.exceptions import NotComputableError 6 | from ignite.contrib.metrics.regression._base import _BaseRegression 7 | 8 | 9 | class MeanNormalizedBias(_BaseRegression): 10 | r""" 11 | Calculates the Mean Normalized Bias: 12 | 13 | :math:`\text{MNB} = \frac{1}{n}\sum_{j=1}^n\frac{A_j - P_j}{A_j}`, 14 | 15 | where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. 16 | 17 | More details can be found in the reference `Botchkarev 2018`__. 18 | 19 | - `update` must receive output of the form `(y_pred, y)`. 20 | - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
21 | 22 | __ https://arxiv.org/abs/1809.03006 23 | 24 | """ 25 | def reset(self): 26 | self._sum_of_errors = 0.0 27 | self._num_examples = 0 28 | 29 | def _update(self, output): 30 | y_pred, y = output 31 | 32 | if (y == 0).any(): 33 | raise NotComputableError('The ground truth has 0.') 34 | 35 | errors = (y.view_as(y_pred) - y_pred) / y 36 | self._sum_of_errors += torch.sum(errors).item() 37 | self._num_examples += y.shape[0] 38 | 39 | def compute(self): 40 | if self._num_examples == 0: 41 | raise NotComputableError('MeanNormalizedBias must have at least one example before it can be computed.') 42 | return self._sum_of_errors / self._num_examples 43 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/regression/wave_hedges_distance.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | from ignite.contrib.metrics.regression._base import _BaseRegression 4 | 5 | 6 | class WaveHedgesDistance(_BaseRegression): 7 | r""" 8 | Calculates the Wave Hedges Distance. 9 | 10 | :math:`\text{WHD} = \sum _j^n\frac{|A_j - P_j|}{\max(A_j, P_j)}`, 11 | where, :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. 12 | 13 | More details can be found in `Botchkarev 2018`__. 14 | 15 | - `update` must receive output of the form `(y_pred, y)`. 16 | - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. 17 | 18 | __ https://arxiv.org/abs/1809.03006 19 | """ 20 | 21 | def reset(self): 22 | self._sum_of_errors = 0.0 23 | 24 | def _update(self, output): 25 | y_pred, y = output 26 | errors = torch.abs(y.view_as(y_pred) - y_pred) / torch.max(y_pred, y.view_as(y_pred)) 27 | self._sum_of_errors += torch.sum(errors).item() 28 | 29 | def compute(self): 30 | return self._sum_of_errors 31 | -------------------------------------------------------------------------------- /ignite/contrib/metrics/roc_auc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from ignite.metrics import EpochMetric 3 | 4 | 5 | def roc_auc_compute_fn(y_preds, y_targets, activation=None): 6 | try: 7 | from sklearn.metrics import roc_auc_score 8 | except ImportError: 9 | raise RuntimeError("This contrib module requires sklearn to be installed.") 10 | 11 | y_true = y_targets.numpy() 12 | if activation is not None: 13 | y_preds = activation(y_preds) 14 | y_pred = y_preds.numpy() 15 | return roc_auc_score(y_true, y_pred) 16 | 17 | 18 | class ROC_AUC(EpochMetric): 19 | """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) 20 | accumulating predictions and the ground-truth during an epoch and applying 21 | `sklearn.metrics.roc_auc_score 22 | <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html>`_ . 23 | 24 | Args: 25 | activation (callable, optional): optional function to apply on prediction tensors, 26 | e.g. `activation=torch.sigmoid` to transform logits. 27 | output_transform (callable, optional): a callable that is used to transform the 28 | :class:`~ignite.engine.Engine`'s `process_function`'s output into the 29 | form expected by the metric. This can be useful if, for example, you have a multi-output model and 30 | you want to compute the metric with respect to one of the outputs.
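Example (editor's sketch, not part of the original docstring): a direct `update`/`compute` round trip; the stored logits are passed through the activation only when the score is computed:

.. code-block:: python

    import torch
    from ignite.contrib.metrics import ROC_AUC

    m = ROC_AUC(activation=torch.sigmoid)
    m.update((torch.tensor([2.0, -1.0, 0.5, -2.0]),   # raw logits
              torch.tensor([1, 0, 1, 0])))            # binary targets
    print(m.compute())  # 1.0 -- every positive is ranked above every negative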
31 | 32 | """ 33 | def __init__(self, activation=None, output_transform=lambda x: x): 34 | super(ROC_AUC, self).__init__(partial(roc_auc_compute_fn, activation=activation), 35 | output_transform=output_transform) 36 | -------------------------------------------------------------------------------- /ignite/engine/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ignite.engine.engine import Engine, State, Events 4 | from ignite.utils import convert_tensor 5 | 6 | 7 | def _prepare_batch(batch, device=None, non_blocking=False): 8 | """Prepare batch for training: pass to a device with options. 9 | 10 | """ 11 | x, y = batch 12 | return (convert_tensor(x, device=device, non_blocking=non_blocking), 13 | convert_tensor(y, device=device, non_blocking=non_blocking)) 14 | 15 | 16 | def create_supervised_trainer(model, optimizer, loss_fn, 17 | device=None, non_blocking=False, 18 | prepare_batch=_prepare_batch): 19 | """ 20 | Factory function for creating a trainer for supervised models. 21 | 22 | Args: 23 | model (`torch.nn.Module`): the model to train. 24 | optimizer (`torch.optim.Optimizer`): the optimizer to use. 25 | loss_fn (torch.nn loss function): the loss function to use. 26 | device (str, optional): device type specification (default: None). 27 | Applies to both model and batches. 28 | non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously 29 | with respect to the host. For other cases, this argument has no effect. 30 | prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs 31 | tuple of tensors `(batch_x, batch_y)`. 32 | 33 | Note: `engine.state.output` for this engine is the loss of the processed batch. 34 | 35 | Returns: 36 | Engine: a trainer engine with supervised update function. 37 | """ 38 | if device: 39 | model.to(device) 40 | 41 | def _update(engine, batch): 42 | model.train() 43 | optimizer.zero_grad() 44 | x, y = prepare_batch(batch, device=device, non_blocking=non_blocking) 45 | y_pred = model(x) 46 | loss = loss_fn(y_pred, y) 47 | loss.backward() 48 | optimizer.step() 49 | return loss.item() 50 | 51 | return Engine(_update) 52 | 53 | 54 | def create_supervised_evaluator(model, metrics={}, 55 | device=None, non_blocking=False, 56 | prepare_batch=_prepare_batch): 57 | """ 58 | Factory function for creating an evaluator for supervised models. 59 | 60 | Args: 61 | model (`torch.nn.Module`): the model to train. 62 | metrics (dict of str - :class:`~ignite.metrics.Metric`): a map of metric names to Metrics. 63 | device (str, optional): device type specification (default: None). 64 | Applies to both model and batches. 65 | non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously 66 | with respect to the host. For other cases, this argument has no effect. 67 | prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs 68 | tuple of tensors `(batch_x, batch_y)`. 69 | 70 | Note: `engine.state.output` for this engine is a tuple of `(batch_pred, batch_y)`. 71 | 72 | Returns: 73 | Engine: an evaluator engine with supervised inference function. 
74 | """ 75 | if device: 76 | model.to(device) 77 | 78 | def _inference(engine, batch): 79 | model.eval() 80 | with torch.no_grad(): 81 | x, y = prepare_batch(batch, device=device, non_blocking=non_blocking) 82 | y_pred = model(x) 83 | return y_pred, y 84 | 85 | engine = Engine(_inference) 86 | 87 | for name, metric in metrics.items(): 88 | metric.attach(engine, name) 89 | 90 | return engine 91 | -------------------------------------------------------------------------------- /ignite/engine/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/engine/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/engine/__pycache__/engine.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/engine/__pycache__/engine.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/exceptions.py: -------------------------------------------------------------------------------- 1 | class NotComputableError(RuntimeError): 2 | """ 3 | Exception class to raise if Metric cannot be computed. 4 | """ 5 | -------------------------------------------------------------------------------- /ignite/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.handlers.checkpoint import ModelCheckpoint 2 | from ignite.handlers.timing import Timer 3 | from ignite.handlers.early_stopping import EarlyStopping 4 | from ignite.handlers.terminate_on_nan import TerminateOnNan 5 | -------------------------------------------------------------------------------- /ignite/handlers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/handlers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/handlers/__pycache__/checkpoint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/handlers/__pycache__/checkpoint.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/handlers/__pycache__/early_stopping.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/handlers/__pycache__/early_stopping.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/handlers/__pycache__/terminate_on_nan.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/handlers/__pycache__/terminate_on_nan.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/handlers/__pycache__/timing.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/handlers/__pycache__/timing.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/handlers/early_stopping.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from ignite.engine import Engine 4 | 5 | 6 | class EarlyStopping(object): 7 | """EarlyStopping handler can be used to stop the training if no improvement occurs after a given number of events. 8 | 9 | Args: 10 | patience (int): 11 | Number of events to wait without improvement before stopping the training. 12 | score_function (callable): 13 | It should be a function taking a single argument, an :class:`~ignite.engine.Engine` object, 14 | and returning a score `float`. An improvement is considered if the score is higher. 15 | trainer (Engine): 16 | trainer engine to stop the run if no improvement occurs. 17 | 18 | Examples: 19 | 20 | .. code-block:: python 21 | 22 | from ignite.engine import Engine, Events 23 | from ignite.handlers import EarlyStopping 24 | 25 | def score_function(engine): 26 | val_loss = engine.state.metrics['nll'] 27 | return -val_loss 28 | 29 | handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer) 30 | # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset). 31 | evaluator.add_event_handler(Events.COMPLETED, handler) 32 | 33 | """ 34 | def __init__(self, patience, score_function, trainer): 35 | 36 | if not callable(score_function): 37 | raise TypeError("Argument score_function should be a function.") 38 | 39 | if patience < 1: 40 | raise ValueError("Argument patience should be a positive integer.") 41 | 42 | if not isinstance(trainer, Engine): 43 | raise TypeError("Argument trainer should be an instance of Engine.") 44 | 45 | self.score_function = score_function 46 | self.patience = patience 47 | self.trainer = trainer 48 | self.counter = 0 49 | self.best_score = None 50 | self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__) 51 | self._logger.addHandler(logging.NullHandler()) 52 | 53 | def __call__(self, engine): 54 | score = self.score_function(engine) 55 | 56 | if self.best_score is None: 57 | self.best_score = score 58 | elif score < self.best_score: 59 | self.counter += 1 60 | self._logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience)) 61 | if self.counter >= self.patience: 62 | self._logger.info("EarlyStopping: Stop training") 63 | self.trainer.terminate() 64 | else: 65 | self.best_score = score 66 | self.counter = 0 67 | -------------------------------------------------------------------------------- /ignite/handlers/terminate_on_nan.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numbers 3 | 4 | import torch 5 | 6 | from ignite.utils import apply_to_type 7 | 8 | 9 | class TerminateOnNan(object): 10 | """TerminateOnNan handler can be used to stop the training if the `process_function`'s output 11 | contains a NaN or infinite number, either as a plain number or inside a `torch.tensor`. 12 | The output can be of type: number, tensor or collection of them. The training is stopped if 13 | at least one number/tensor has a NaN or infinite value. For example, if the output is 14 | `[1.23, torch.tensor(...), torch.tensor(float('nan'))]` the handler will stop the training.
15 | 16 | Args: 17 | output_transform (callable, optional): a callable that is used to transform the 18 | :class:`~ignite.engine.Engine`'s `process_function`'s output into a number or `torch.tensor` 19 | or collection of them. This can be useful if, for example, you have a multi-output model and 20 | you want to check one or multiple values of the output. 21 | 22 | 23 | Examples: 24 | 25 | .. code-block:: python 26 | 27 | trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) 28 | 29 | """ 30 | 31 | def __init__(self, output_transform=lambda x: x): 32 | self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__) 33 | self._logger.addHandler(logging.StreamHandler()) 34 | self._output_transform = output_transform 35 | 36 | def __call__(self, engine): 37 | output = self._output_transform(engine.state.output) 38 | 39 | def raise_error(x): 40 | 41 | if isinstance(x, numbers.Number): 42 | x = torch.tensor(x) 43 | 44 | if isinstance(x, torch.Tensor) and not bool(torch.isfinite(x).all()): 45 | raise RuntimeError("Infinite or NaN tensor found.") 46 | 47 | try: 48 | apply_to_type(output, (numbers.Number, torch.Tensor), raise_error) 49 | except RuntimeError: 50 | self._logger.warning("{}: Output '{}' contains NaN or Inf. Stop training" 51 | .format(self.__class__.__name__, output)) 52 | engine.terminate() 53 | -------------------------------------------------------------------------------- /ignite/handlers/timing.py: -------------------------------------------------------------------------------- 1 | from ignite.engine import Events 2 | 3 | try: 4 | from time import perf_counter 5 | except ImportError: 6 | from time import time as perf_counter 7 | 8 | 9 | class Timer: 10 | """ Timer object can be used to measure (average) time between events. 11 | 12 | Args: 13 | average (bool, optional): if True, then when ``.value()`` method is called, the returned value 14 | will be equal to total time measured, divided by the value of internal counter. 15 | 16 | Attributes: 17 | total (float): total time elapsed when the Timer was running (in seconds). 18 | step_count (int): internal counter, useful to measure average time, e.g. of processing a single batch. 19 | Incremented with the ``.step()`` method. 20 | running (bool): flag indicating if timer is measuring time. 21 | 22 | Notes: 23 | When using ``Timer(average=True)`` do not forget to call ``timer.step()`` every time an event occurs. See 24 | the examples below. 25 | 26 | Examples: 27 | 28 | Measuring total time of the epoch: 29 | 30 | >>> from ignite.handlers import Timer 31 | >>> import time 32 | >>> work = lambda : time.sleep(0.1) 33 | >>> idle = lambda : time.sleep(0.1) 34 | >>> t = Timer(average=False) 35 | >>> for _ in range(10): 36 | ... work() 37 | ... idle() 38 | ... 39 | >>> t.value() 40 | 2.003073937026784 41 | 42 | Measuring average time of the epoch: 43 | 44 | >>> t = Timer(average=True) 45 | >>> for _ in range(10): 46 | ... work() 47 | ... idle() 48 | ... t.step() 49 | ... 50 | >>> t.value() 51 | 0.2003182829997968 52 | 53 | Measuring average time it takes to execute a single ``work()`` call: 54 | 55 | >>> t = Timer(average=True) 56 | >>> for _ in range(10): 57 | ... t.resume() 58 | ... work() 59 | ... t.pause() 60 | ... idle() 61 | ... t.step() 62 | ...
63 | >>> t.value() 64 | 0.10016545779653825 65 | 66 | Using the Timer to measure average time it takes to process a single batch of examples: 67 | 68 | >>> from ignite.engine import Engine, Events 69 | >>> from ignite.handlers import Timer 70 | >>> trainer = Engine(training_update_function) 71 | >>> timer = Timer(average=True) 72 | >>> timer.attach(trainer, 73 | ... start=Events.EPOCH_STARTED, 74 | ... resume=Events.ITERATION_STARTED, 75 | ... pause=Events.ITERATION_COMPLETED, 76 | ... step=Events.ITERATION_COMPLETED) 77 | """ 78 | 79 | def __init__(self, average=False): 80 | self._average = average 81 | self._t0 = perf_counter() 82 | 83 | self.total = 0. 84 | self.step_count = 0. 85 | self.running = True 86 | 87 | def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None): 88 | """ Register callbacks to control the timer. 89 | 90 | Args: 91 | engine (Engine): 92 | Engine that this timer will be attached to. 93 | start (Events): 94 | Event which should start (reset) the timer. 95 | pause (Events): 96 | Event which should pause the timer. 97 | resume (Events, optional): 98 | Event which should resume the timer. 99 | step (Events, optional): 100 | Event which should call the `step` method of the counter. 101 | 102 | Returns: 103 | self (Timer) 104 | 105 | """ 106 | 107 | engine.add_event_handler(start, self.reset) 108 | engine.add_event_handler(pause, self.pause) 109 | 110 | if resume is not None: 111 | engine.add_event_handler(resume, self.resume) 112 | 113 | if step is not None: 114 | engine.add_event_handler(step, self.step) 115 | 116 | return self 117 | 118 | def reset(self, *args): 119 | self.__init__(self._average) 120 | return self 121 | 122 | def pause(self, *args): 123 | if self.running: 124 | self.total += self._elapsed() 125 | self.running = False 126 | 127 | def resume(self, *args): 128 | if not self.running: 129 | self.running = True 130 | self._t0 = perf_counter() 131 | 132 | def value(self): 133 | total = self.total 134 | if self.running: 135 | total += self._elapsed() 136 | 137 | if self._average: 138 | denominator = max(self.step_count, 1.) 139 | else: 140 | denominator = 1. 141 | 142 | return total / denominator 143 | 144 | def step(self, *args): 145 | self.step_count += 1. 
146 | 147 | def _elapsed(self): 148 | return perf_counter() - self._t0 149 | -------------------------------------------------------------------------------- /ignite/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics.binary_accuracy import BinaryAccuracy 2 | from ignite.metrics.categorical_accuracy import CategoricalAccuracy 3 | from ignite.metrics.accuracy import Accuracy 4 | from ignite.metrics.loss import Loss 5 | from ignite.metrics.mean_absolute_error import MeanAbsoluteError 6 | from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance 7 | from ignite.metrics.mean_squared_error import MeanSquaredError 8 | from ignite.metrics.metric import Metric 9 | from ignite.metrics.epoch_metric import EpochMetric 10 | from ignite.metrics.precision import Precision 11 | from ignite.metrics.recall import Recall 12 | from ignite.metrics.root_mean_squared_error import RootMeanSquaredError 13 | from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy 14 | from ignite.metrics.running_average import RunningAverage 15 | from ignite.metrics.metrics_lambda import MetricsLambda 16 | -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/accuracy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/accuracy.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/binary_accuracy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/binary_accuracy.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/categorical_accuracy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/categorical_accuracy.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/epoch_metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/epoch_metric.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/mean_absolute_error.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/mean_absolute_error.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/mean_pairwise_distance.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/mean_pairwise_distance.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/mean_squared_error.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/mean_squared_error.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/metric.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/metrics_lambda.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/metrics_lambda.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/precision.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/precision.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/recall.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/recall.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/root_mean_squared_error.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/root_mean_squared_error.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/running_average.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/running_average.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/__pycache__/top_k_categorical_accuracy.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/ignite/metrics/__pycache__/top_k_categorical_accuracy.cpython-36.pyc -------------------------------------------------------------------------------- /ignite/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.metrics.metric import Metric 6 | from ignite.exceptions import NotComputableError 7 | 8 | 9 | class _BaseClassification(Metric): 10 | 11 | def __init__(self, output_transform=lambda x: x, is_multilabel=False): 12 | self._is_multilabel = is_multilabel 13 | self._type = None 14 | super(_BaseClassification, self).__init__(output_transform=output_transform) 15 | 16 | def _check_shape(self, output): 17 | y_pred, y = output 18 | 19 | if y.ndimension() > 1 and y.shape[1] == 1: 20 | # (N, 1, ...) -> (N, ...) 21 | y = y.squeeze(dim=1) 22 | 23 | if y_pred.ndimension() > 1 and y_pred.shape[1] == 1: 24 | # (N, 1, ...) -> (N, ...) 25 | y_pred = y_pred.squeeze(dim=1) 26 | 27 | if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()): 28 | raise ValueError("y must have shape of (batch_size, ...) and y_pred must have " 29 | "shape of (batch_size, num_categories, ...) or (batch_size, ...), " 30 | "but given {} vs {}.".format(y.shape, y_pred.shape)) 31 | 32 | y_shape = y.shape 33 | y_pred_shape = y_pred.shape 34 | 35 | if y.ndimension() + 1 == y_pred.ndimension(): 36 | y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:] 37 | 38 | if not (y_shape == y_pred_shape): 39 | raise ValueError("y and y_pred must have compatible shapes.") 40 | 41 | if self._is_multilabel and not (y.shape == y_pred.shape and y.ndimension() > 1 and y.shape[1] != 1): 42 | raise ValueError("y and y_pred must have same shape of (batch_size, num_categories, ...).") 43 | 44 | return y_pred, y 45 | 46 | def _check_type(self, output): 47 | y_pred, y = output 48 | 49 | if y.ndimension() + 1 == y_pred.ndimension(): 50 | update_type = "multiclass" 51 | elif y.ndimension() == y_pred.ndimension(): 52 | if not torch.equal(y, y ** 2): 53 | raise ValueError("For binary cases, y must be comprised of 0's and 1's.") 54 | 55 | if not torch.equal(y_pred, y_pred ** 2): 56 | raise ValueError("For binary cases, y_pred must be comprised of 0's and 1's.") 57 | 58 | if self._is_multilabel: 59 | update_type = "multilabel" 60 | else: 61 | update_type = "binary" 62 | else: 63 | raise RuntimeError("Invalid shapes of y (shape={}) and y_pred (shape={}), check documentation" 64 | " for expected shapes of y and y_pred.".format(y.shape, y_pred.shape)) 65 | if self._type is None: 66 | self._type = update_type 67 | else: 68 | if self._type != update_type: 69 | raise RuntimeError("Input data type has changed from {} to {}.".format(self._type, update_type)) 70 | 71 | 72 | class Accuracy(_BaseClassification): 73 | """ 74 | Calculates the accuracy for binary, multiclass and multilabel data. 75 | 76 | - `update` must receive output of the form `(y_pred, y)`. 77 | - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...). 78 | - `y` must be in the following shape (batch_size, ...). 79 | - `y` and `y_pred` must be in the following shape of (batch_size, num_categories, ...) for multilabel cases. 80 | 81 | In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of 82 | predictions can be done as below: 83 | 84 | ..
code-block:: python 85 | 86 | def thresholded_output_transform(output): 87 | y_pred, y = output 88 | y_pred = torch.round(y_pred) 89 | return y_pred, y 90 | 91 | binary_accuracy = Accuracy(thresholded_output_transform) 92 | 93 | 94 | Args: 95 | output_transform (callable, optional): a callable that is used to transform the 96 | :class:`~ignite.engine.Engine`'s `process_function`'s output into the 97 | form expected by the metric. This can be useful if, for example, you have a multi-output model and 98 | you want to compute the metric with respect to one of the outputs. 99 | is_multilabel (bool, optional): flag to use in multilabel case. By default, False. 100 | """ 101 | 102 | def __init__(self, output_transform=lambda x: x, is_multilabel=False): 103 | super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel) 104 | 105 | def reset(self): 106 | self._num_correct = 0 107 | self._num_examples = 0 108 | 109 | def update(self, output): 110 | 111 | y_pred, y = self._check_shape(output) 112 | self._check_type((y_pred, y)) 113 | 114 | if self._type == "binary": 115 | correct = torch.eq(y_pred.type(y.type()), y).view(-1) 116 | elif self._type == "multiclass": 117 | indices = torch.max(y_pred, dim=1)[1] 118 | correct = torch.eq(indices, y).view(-1) 119 | elif self._type == "multilabel": 120 | # if y, y_pred shape is (N, C, ...) -> (N x ..., C) 121 | num_classes = y_pred.size(1) 122 | last_dim = y_pred.ndimension() 123 | y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes) 124 | y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes) 125 | correct = torch.all(y == y_pred.type_as(y), dim=-1) 126 | 127 | self._num_correct += torch.sum(correct).item() 128 | self._num_examples += correct.shape[0] 129 | 130 | def compute(self): 131 | if self._num_examples == 0: 132 | raise NotComputableError('Accuracy must have at least one example before it can be computed.') 133 | return self._num_correct / self._num_examples 134 | -------------------------------------------------------------------------------- /ignite/metrics/binary_accuracy.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import warnings 3 | 4 | from ignite.metrics.accuracy import Accuracy 5 | 6 | 7 | class BinaryAccuracy(Accuracy): 8 | """ 9 | Note: This metric is deprecated in favor of :class:`~ignite.metrics.Accuracy`. 10 | """ 11 | def __init__(self, *args, **kwargs): 12 | warnings.warn("The use of ignite.metrics.BinaryAccuracy is deprecated, it will be " 13 | "removed in 0.2.0. Please use ignite.metrics.Accuracy instead.", DeprecationWarning) 14 | super(BinaryAccuracy, self).__init__(*args, **kwargs) 15 | -------------------------------------------------------------------------------- /ignite/metrics/categorical_accuracy.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import warnings 3 | 4 | from ignite.metrics.accuracy import Accuracy 5 | 6 | 7 | class CategoricalAccuracy(Accuracy): 8 | """ 9 | Note: This metric is deprecated in favor of :class:`~ignite.metrics.Accuracy`. 10 | """ 11 | def __init__(self, *args, **kwargs): 12 | warnings.warn("The use of ignite.metrics.CategoricalAccuracy is deprecated, it will be " 13 | "removed in 0.2.0. 
Please use ignite.metrics.Accuracy instead.", DeprecationWarning) 14 | super(CategoricalAccuracy, self).__init__(*args, **kwargs) 15 | -------------------------------------------------------------------------------- /ignite/metrics/epoch_metric.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | 5 | from ignite.metrics.metric import Metric 6 | 7 | 8 | class EpochMetric(Metric): 9 | """Class for metrics that should be computed on the entire output history of a model. 10 | Model's output and targets are restricted to be of shape `(batch_size, n_classes)`. Output 11 | datatype should be `float32`. Target datatype should be `long`. 12 | 13 | .. warning:: 14 | 15 | Current implementation stores all input data (output and target) as tensors before computing a metric. 16 | This can potentially lead to a memory error if the input data is larger than available RAM. 17 | 18 | 19 | - `update` must receive output of the form `(y_pred, y)`. 20 | 21 | If target shape is `(batch_size, n_classes)` and `n_classes > 1` then it should be binary: e.g. `[[0, 1, 0, 1], ]`. 22 | 23 | Args: 24 | compute_fn (callable): a callable with the signature (`torch.tensor`, `torch.tensor`) that takes 25 | `predictions` and `targets` as the input and returns a scalar. 26 | output_transform (callable, optional): a callable that is used to transform the 27 | :class:`~ignite.engine.Engine`'s `process_function`'s output into the 28 | form expected by the metric. This can be useful if, for example, you have a multi-output model and 29 | you want to compute the metric with respect to one of the outputs. 30 | 31 | """ 32 | 33 | def __init__(self, compute_fn, output_transform=lambda x: x): 34 | 35 | if not callable(compute_fn): 36 | raise TypeError("Argument compute_fn should be callable.") 37 | 38 | super(EpochMetric, self).__init__(output_transform=output_transform) 39 | self.compute_fn = compute_fn 40 | 41 | def reset(self): 42 | self._predictions = torch.tensor([], dtype=torch.float32) 43 | self._targets = torch.tensor([], dtype=torch.long) 44 | 45 | def update(self, output): 46 | y_pred, y = output 47 | 48 | if y_pred.ndimension() not in (1, 2): 49 | raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).") 50 | 51 | if y.ndimension() not in (1, 2): 52 | raise ValueError("Targets should be of shape (batch_size, n_classes) or (batch_size, ).") 53 | 54 | if y.ndimension() == 2: 55 | if not torch.equal(y ** 2, y): 56 | raise ValueError("Targets should be binary (0 or 1).") 57 | 58 | if y_pred.ndimension() == 2 and y_pred.shape[1] == 1: 59 | y_pred = y_pred.squeeze(dim=-1) 60 | 61 | if y.ndimension() == 2 and y.shape[1] == 1: 62 | y = y.squeeze(dim=-1) 63 | 64 | y_pred = y_pred.type_as(self._predictions) 65 | y = y.type_as(self._targets) 66 | 67 | self._predictions = torch.cat([self._predictions, y_pred], dim=0) 68 | self._targets = torch.cat([self._targets, y], dim=0) 69 | 70 | # Check once the signature and execution of compute_fn 71 | if self._predictions.shape == y_pred.shape: 72 | try: 73 | self.compute_fn(self._predictions, self._targets) 74 | except Exception as e: 75 | warnings.warn("There may be a problem with `compute_fn`:\n {}.".format(e), 76 | RuntimeWarning) 77 | 78 | def compute(self): 79 | return self.compute_fn(self._predictions, self._targets) 80 | -------------------------------------------------------------------------------- /ignite/metrics/loss.py:
-------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | from ignite.exceptions import NotComputableError 4 | from ignite.metrics.metric import Metric 5 | 6 | 7 | class Loss(Metric): 8 | """ 9 | Calculates the average loss according to the passed loss_fn. 10 | 11 | Args: 12 | loss_fn (callable): a callable taking a prediction tensor, a target 13 | tensor, optionally other arguments, and returning the average loss 14 | over all observations in the batch. 15 | output_transform (callable): a callable that is used to transform the 16 | :class:`~ignite.engine.Engine`'s `process_function`'s output into the 17 | form expected by the metric. 18 | This can be useful if, for example, you have a multi-output model and 19 | you want to compute the metric with respect to one of the outputs. 20 | The output is expected to be a tuple (prediction, target) or 21 | (prediction, target, kwargs) where kwargs is a dictionary of extra 22 | keyword arguments. 23 | batch_size (callable): a callable taking a target tensor that returns the 24 | first dimension size (usually the batch size). 25 | 26 | """ 27 | 28 | def __init__(self, loss_fn, output_transform=lambda x: x, 29 | batch_size=lambda x: x.shape[0]): 30 | super(Loss, self).__init__(output_transform) 31 | self._loss_fn = loss_fn 32 | self._batch_size = batch_size 33 | 34 | def reset(self): 35 | self._sum = 0 36 | self._num_examples = 0 37 | 38 | def update(self, output): 39 | if len(output) == 2: 40 | y_pred, y = output 41 | kwargs = {} 42 | else: 43 | y_pred, y, kwargs = output 44 | average_loss = self._loss_fn(y_pred, y, **kwargs) 45 | 46 | if len(average_loss.shape) != 0: 47 | raise ValueError('loss_fn did not return the average loss.') 48 | 49 | N = self._batch_size(y) 50 | self._sum += average_loss.item() * N 51 | self._num_examples += N 52 | 53 | def compute(self): 54 | if self._num_examples == 0: 55 | raise NotComputableError( 56 | 'Loss must have at least one example before it can be computed.') 57 | return self._sum / self._num_examples 58 | -------------------------------------------------------------------------------- /ignite/metrics/mean_absolute_error.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.exceptions import NotComputableError 6 | from ignite.metrics.metric import Metric 7 | 8 | 9 | class MeanAbsoluteError(Metric): 10 | """ 11 | Calculates the mean absolute error. 12 | 13 | - `update` must receive output of the form `(y_pred, y)`.
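Example (editor's addition; `model` and `val_loader` are assumed to exist): `Loss` and `MeanAbsoluteError` attached side by side to a regression evaluator. `Loss` re-averages the per-batch average losses weighted by `batch_size`, while `MeanAbsoluteError` accumulates raw absolute errors:

.. code-block:: python

    import torch.nn.functional as F
    from ignite.engine import create_supervised_evaluator
    from ignite.metrics import Loss, MeanAbsoluteError

    evaluator = create_supervised_evaluator(
        model,
        metrics={"mse": Loss(F.mse_loss), "mae": MeanAbsoluteError()})
    print(evaluator.run(val_loader).metrics)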
14 | """ 15 | def reset(self): 16 | self._sum_of_absolute_errors = 0.0 17 | self._num_examples = 0 18 | 19 | def update(self, output): 20 | y_pred, y = output 21 | absolute_errors = torch.abs(y_pred - y.view_as(y_pred)) 22 | self._sum_of_absolute_errors += torch.sum(absolute_errors).item() 23 | self._num_examples += y.shape[0] 24 | 25 | def compute(self): 26 | if self._num_examples == 0: 27 | raise NotComputableError('MeanAbsoluteError must have at least one example before it can be computed.') 28 | return self._sum_of_absolute_errors / self._num_examples 29 | -------------------------------------------------------------------------------- /ignite/metrics/mean_pairwise_distance.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | from torch.nn.functional import pairwise_distance 5 | 6 | from ignite.exceptions import NotComputableError 7 | from ignite.metrics.metric import Metric 8 | 9 | 10 | class MeanPairwiseDistance(Metric): 11 | """ 12 | Calculates the mean pairwise distance. 13 | 14 | - `update` must receive output of the form `(y_pred, y)`. 15 | """ 16 | def __init__(self, p=2, eps=1e-6, output_transform=lambda x: x): 17 | super(MeanPairwiseDistance, self).__init__(output_transform) 18 | self._p = p 19 | self._eps = eps 20 | 21 | def reset(self): 22 | self._sum_of_distances = 0.0 23 | self._num_examples = 0 24 | 25 | def update(self, output): 26 | y_pred, y = output 27 | distances = pairwise_distance(y_pred, y, p=self._p, eps=self._eps) 28 | self._sum_of_distances += torch.sum(distances).item() 29 | self._num_examples += y.shape[0] 30 | 31 | def compute(self): 32 | if self._num_examples == 0: 33 | raise NotComputableError('MeanAbsoluteError must have at least one example before it can be computed.') 34 | return self._sum_of_distances / self._num_examples 35 | -------------------------------------------------------------------------------- /ignite/metrics/mean_squared_error.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.exceptions import NotComputableError 6 | from ignite.metrics.metric import Metric 7 | 8 | 9 | class MeanSquaredError(Metric): 10 | """ 11 | Calculates the mean squared error. 12 | 13 | - `update` must receive output of the form `(y_pred, y)`. 14 | """ 15 | def reset(self): 16 | self._sum_of_squared_errors = 0.0 17 | self._num_examples = 0 18 | 19 | def update(self, output): 20 | y_pred, y = output 21 | squared_errors = torch.pow(y_pred - y.view_as(y_pred), 2) 22 | self._sum_of_squared_errors += torch.sum(squared_errors).item() 23 | self._num_examples += y.shape[0] 24 | 25 | def compute(self): 26 | if self._num_examples == 0: 27 | raise NotComputableError('MeanSquaredError must have at least one example before it can be computed.') 28 | return self._sum_of_squared_errors / self._num_examples 29 | -------------------------------------------------------------------------------- /ignite/metrics/metric.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from ignite._six import with_metaclass 3 | from ignite.engine import Events 4 | import torch 5 | 6 | 7 | class Metric(with_metaclass(ABCMeta, object)): 8 | """ 9 | Base class for all Metrics. 
10 | 11 | Args: 12 | output_transform (callable, optional): a callable that is used to transform the 13 | :class:`~ignite.engine.Engine`'s `process_function`'s output into the 14 | form expected by the metric. This can be useful if, for example, you have a multi-output model and 15 | you want to compute the metric with respect to one of the outputs. 16 | 17 | """ 18 | 19 | def __init__(self, output_transform=lambda x: x): 20 | self._output_transform = output_transform 21 | self.reset() 22 | 23 | @abstractmethod 24 | def reset(self): 25 | """ 26 | Resets the metric to its initial state. 27 | 28 | This is called at the start of each epoch. 29 | """ 30 | pass 31 | 32 | @abstractmethod 33 | def update(self, output): 34 | """ 35 | Updates the metric's state using the passed batch output. 36 | 37 | This is called once for each batch. 38 | 39 | Args: 40 | output: the output from the engine's process function. 41 | """ 42 | pass 43 | 44 | @abstractmethod 45 | def compute(self): 46 | """ 47 | Computes the metric based on its accumulated state. 48 | 49 | This is called at the end of each epoch. 50 | 51 | Returns: 52 | Any: the actual quantity of interest. 53 | 54 | Raises: 55 | NotComputableError: raised when the metric cannot be computed. 56 | """ 57 | pass 58 | 59 | def started(self, engine): 60 | self.reset() 61 | 62 | @torch.no_grad() 63 | def iteration_completed(self, engine): 64 | output = self._output_transform(engine.state.output) 65 | self.update(output) 66 | 67 | def completed(self, engine, name): 68 | if '|$^' not in name: 69 | engine.state.metrics[name] = self.compute() 70 | 71 | def attach(self, engine, name): 72 | engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name) 73 | if not engine.has_event_handler(self.started, Events.EPOCH_STARTED): 74 | engine.add_event_handler(Events.EPOCH_STARTED, self.started) 75 | if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): 76 | engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) 77 | 78 | def __add__(self, other): 79 | from ignite.metrics import MetricsLambda 80 | return MetricsLambda(lambda x, y: x + y, self, other) 81 | 82 | def __radd__(self, other): 83 | from ignite.metrics import MetricsLambda 84 | return MetricsLambda(lambda x, y: x + y, other, self) 85 | 86 | def __sub__(self, other): 87 | from ignite.metrics import MetricsLambda 88 | return MetricsLambda(lambda x, y: x - y, self, other) 89 | 90 | def __rsub__(self, other): 91 | from ignite.metrics import MetricsLambda 92 | return MetricsLambda(lambda x, y: x - y, other, self) 93 | 94 | def __mul__(self, other): 95 | from ignite.metrics import MetricsLambda 96 | return MetricsLambda(lambda x, y: x * y, self, other) 97 | 98 | def __rmul__(self, other): 99 | from ignite.metrics import MetricsLambda 100 | return MetricsLambda(lambda x, y: x * y, other, self) 101 | 102 | def __pow__(self, other): 103 | from ignite.metrics import MetricsLambda 104 | return MetricsLambda(lambda x, y: x ** y, self, other) 105 | 106 | def __rpow__(self, other): 107 | from ignite.metrics import MetricsLambda 108 | return MetricsLambda(lambda x, y: x ** y, other, self) 109 | 110 | def __mod__(self, other): 111 | from ignite.metrics import MetricsLambda 112 | return MetricsLambda(lambda x, y: x % y, self, other) 113 | 114 | def __div__(self, other): 115 | from ignite.metrics import MetricsLambda 116 | return MetricsLambda(lambda x, y: x.__div__(y), self, other) 117 | 118 | def __rdiv__(self, other): 119 | from ignite.metrics import
MetricsLambda 120 | return MetricsLambda(lambda x, y: x.__div__(y), other, self) 121 | 122 | def __truediv__(self, other): 123 | from ignite.metrics import MetricsLambda 124 | return MetricsLambda(lambda x, y: x.__truediv__(y), self, other) 125 | 126 | def __rtruediv__(self, other): 127 | from ignite.metrics import MetricsLambda 128 | return MetricsLambda(lambda x, y: x.__truediv__(y), other, self) 129 | 130 | def __floordiv__(self, other): 131 | from ignite.metrics import MetricsLambda 132 | return MetricsLambda(lambda x, y: x // y, self, other) 133 | -------------------------------------------------------------------------------- /ignite/metrics/metrics_lambda.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics.metric import Metric 2 | from ignite.engine import Events 3 | 4 | 5 | class MetricsLambda(Metric): 6 | """ 7 | Apply a function to other metrics to obtain a new metric. 8 | The result of the new metric is defined to be the result 9 | of applying the function to the result of argument metrics. 10 | 11 | On update, this metric does not recursively update the metrics 12 | it depends on. On reset, all of its dependency metrics are 13 | reset. On attach, all of its dependencies are automatically 14 | attached. 15 | 16 | Args: 17 | f (callable): the function that defines the computation 18 | args (sequence): Sequence of other metrics or something 19 | else that will be fed to ``f`` as arguments. 20 | 21 | Example: 22 | 23 | .. code-block:: python 24 | 25 | precision = Precision(average=False) 26 | recall = Recall(average=False) 27 | 28 | def Fbeta(r, p, beta): 29 | return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item() 30 | 31 | F1 = MetricsLambda(Fbeta, recall, precision, 1) 32 | F2 = MetricsLambda(Fbeta, recall, precision, 2) 33 | F3 = MetricsLambda(Fbeta, recall, precision, 3) 34 | F4 = MetricsLambda(Fbeta, recall, precision, 4) 35 | """ 36 | def __init__(self, f, *args): 37 | self.function = f 38 | self.args = args 39 | super(MetricsLambda, self).__init__() 40 | 41 | def reset(self): 42 | for i in self.args: 43 | if isinstance(i, Metric): 44 | i.reset() 45 | 46 | def update(self, output): 47 | # NB: this method does not recursively update dependency metrics, 48 | # which might cause a duplicate-update issue. To update this metric, 49 | # users should manually update its dependencies.
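# (clarifying note, not in the original: once attached, each dependency
        # metric is driven by its own engine handlers, and MetricsLambda.attach
        # registers them under sentinel names containing '|$^' so that
        # Metric.completed does not write their raw values into
        # engine.state.metrics; this is why update() below is a no-op.)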
50 | pass 51 | 52 | def compute(self): 53 | materialized = [i.compute() if isinstance(i, Metric) else i for i in self.args] 54 | return self.function(*materialized) 55 | 56 | def attach(self, engine, name): 57 | # recursively attach all its dependencies 58 | for index, metric in enumerate(self.args): 59 | if isinstance(metric, Metric): 60 | metric.attach(engine, name + '|$^[{}]'.format(index)) 61 | super(MetricsLambda, self).attach(engine, name) 62 | -------------------------------------------------------------------------------- /ignite/metrics/precision.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.metrics.accuracy import _BaseClassification 6 | from ignite.exceptions import NotComputableError 7 | from ignite.utils import to_onehot 8 | 9 | 10 | class _BasePrecisionRecall(_BaseClassification): 11 | 12 | def __init__(self, output_transform=lambda x: x, average=False, is_multilabel=False): 13 | self._average = average 14 | super(_BasePrecisionRecall, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel) 15 | self.eps = 1e-20 16 | 17 | def reset(self): 18 | self._true_positives = torch.DoubleTensor(0) if self._is_multilabel else 0 19 | self._positives = torch.DoubleTensor(0) if self._is_multilabel else 0 20 | 21 | def compute(self): 22 | if not isinstance(self._positives, torch.Tensor): 23 | raise NotComputableError("{} must have at least one example before" 24 | " it can be computed.".format(self.__class__.__name__)) 25 | 26 | result = self._true_positives / (self._positives + self.eps) 27 | 28 | if self._average: 29 | return result.mean().item() 30 | else: 31 | return result 32 | 33 | 34 | class Precision(_BasePrecisionRecall): 35 | """ 36 | Calculates precision for binary and multiclass data. 37 | 38 | - `update` must receive output of the form `(y_pred, y)`. 39 | - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...). 40 | - `y` must be in the following shape (batch_size, ...). 41 | 42 | In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of 43 | predictions can be done as below: 44 | 45 | .. code-block:: python 46 | 47 | def thresholded_output_transform(output): 48 | y_pred, y = output 49 | y_pred = torch.round(y_pred) 50 | return y_pred, y 51 | 52 | binary_accuracy = Precision(output_transform=thresholded_output_transform) 53 | 54 | In multilabel cases, the average parameter should normally be True. However, if the user wants to use this metric to calculate F1, for 55 | example, the average parameter should be False. This can be done as shown below: 56 | 57 | .. warning:: 58 | 59 | If average is False, the current implementation stores all input data (output and target) as tensors before 60 | computing a metric. This can potentially lead to a memory error if the input data is larger than available RAM. 61 | 62 | .. code-block:: python 63 | 64 | precision = Precision(average=False, is_multilabel=True) 65 | recall = Recall(average=False, is_multilabel=True) 66 | F1 = precision * recall * 2 / (precision + recall + 1e-20) 67 | F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1) 68 | 69 | Args: 70 | output_transform (callable, optional): a callable that is used to transform the 71 | :class:`~ignite.engine.Engine`'s `process_function`'s output into the 72 | form expected by the metric.
This can be useful if, for example, you have a multi-output model and 73 | you want to compute the metric with respect to one of the outputs. 74 | average (bool, optional): if True, precision is computed as the unweighted average (across all classes 75 | in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case). 76 | is_multilabel (bool, optional): flag to use in multilabel case. By default, value is False. If True, average 77 | parameter should be True and the average is computed across samples, instead of classes. 78 | """ 79 | 80 | def __init__(self, output_transform=lambda x: x, average=False, is_multilabel=False): 81 | super(Precision, self).__init__(output_transform=output_transform, 82 | average=average, is_multilabel=is_multilabel) 83 | 84 | def update(self, output): 85 | y_pred, y = self._check_shape(output) 86 | self._check_type((y_pred, y)) 87 | 88 | if self._type == "binary": 89 | y_pred = y_pred.view(-1) 90 | y = y.view(-1) 91 | elif self._type == "multiclass": 92 | num_classes = y_pred.size(1) 93 | y = to_onehot(y.view(-1), num_classes=num_classes) 94 | indices = torch.max(y_pred, dim=1)[1].view(-1) 95 | y_pred = to_onehot(indices, num_classes=num_classes) 96 | elif self._type == "multilabel": 97 | # if y, y_pred shape is (N, C, ...) -> (C, N x ...) 98 | num_classes = y_pred.size(1) 99 | y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) 100 | y = torch.transpose(y, 1, 0).reshape(num_classes, -1) 101 | 102 | y = y.type_as(y_pred) 103 | correct = y * y_pred 104 | all_positives = y_pred.sum(dim=0).type(torch.DoubleTensor) # Convert from int cuda/cpu to double cpu 105 | 106 | if correct.sum() == 0: 107 | true_positives = torch.zeros_like(all_positives) 108 | else: 109 | true_positives = correct.sum(dim=0) 110 | # Convert from int cuda/cpu to double cpu 111 | # We need double precision for the division true_positives / all_positives 112 | true_positives = true_positives.type(torch.DoubleTensor) 113 | 114 | if self._type == "multilabel": 115 | self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) 116 | self._positives = torch.cat([self._positives, all_positives], dim=0) 117 | else: 118 | self._true_positives += true_positives 119 | self._positives += all_positives 120 | -------------------------------------------------------------------------------- /ignite/metrics/recall.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.metrics.precision import _BasePrecisionRecall 6 | from ignite.utils import to_onehot 7 | 8 | 9 | class Recall(_BasePrecisionRecall): 10 | """ 11 | Calculates recall for binary and multiclass data. 12 | 13 | - `update` must receive output of the form `(y_pred, y)`. 14 | - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...). 15 | - `y` must be in the following shape (batch_size, ...). 16 | 17 | In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of 18 | predictions can be done as below: 19 | 20 | .. code-block:: python 21 | 22 | def thresholded_output_transform(output): 23 | y_pred, y = output 24 | y_pred = torch.round(y_pred) 25 | return y_pred, y 26 | 27 | binary_accuracy = Recall(output_transform=thresholded_output_transform) 28 | 29 | In multilabel cases, the average parameter should normally be True.
However, if the user wants to use this metric to calculate F1, for 30 | example, the average parameter should be False. This can be done as shown below: 31 | 32 | .. warning:: 33 | 34 | If average is False, the current implementation stores all input data (output and target) as tensors before 35 | computing a metric. This can potentially lead to a memory error if the input data is larger than available RAM. 36 | 37 | .. code-block:: python 38 | 39 | precision = Precision(average=False, is_multilabel=True) 40 | recall = Recall(average=False, is_multilabel=True) 41 | F1 = precision * recall * 2 / (precision + recall + 1e-20) 42 | F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1) 43 | 44 | Args: 45 | output_transform (callable, optional): a callable that is used to transform the 46 | :class:`~ignite.engine.Engine`'s `process_function`'s output into the 47 | form expected by the metric. This can be useful if, for example, you have a multi-output model and 48 | you want to compute the metric with respect to one of the outputs. 49 | average (bool, optional): if True, recall is computed as the unweighted average (across all classes 50 | in multiclass case), otherwise, returns a tensor with the recall (for each class in multiclass case). 51 | is_multilabel (bool, optional): flag to use in multilabel case. By default, value is False. If True, average 52 | parameter should be True and the average is computed across samples, instead of classes. 53 | """ 54 | 55 | def __init__(self, output_transform=lambda x: x, average=False, is_multilabel=False): 56 | super(Recall, self).__init__(output_transform=output_transform, 57 | average=average, is_multilabel=is_multilabel) 58 | 59 | def update(self, output): 60 | y_pred, y = self._check_shape(output) 61 | self._check_type((y_pred, y)) 62 | 63 | if self._type == "binary": 64 | y_pred = y_pred.view(-1) 65 | y = y.view(-1) 66 | elif self._type == "multiclass": 67 | num_classes = y_pred.size(1) 68 | y = to_onehot(y.view(-1), num_classes=num_classes) 69 | indices = torch.max(y_pred, dim=1)[1].view(-1) 70 | y_pred = to_onehot(indices, num_classes=num_classes) 71 | elif self._type == "multilabel": 72 | # if y, y_pred shape is (N, C, ...) -> (C, N x ...)
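# (illustrative note, not in the original: with N=2, C=3 and a 4-element
                # spatial dimension, both tensors become (3, 8); the dim=0 sums below
                # then produce per-sample statistics rather than per-class ones,
                # matching the sample-wise averaging used in the multilabel case.)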
73 | num_classes = y_pred.size(1) 74 | y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) 75 | y = torch.transpose(y, 1, 0).reshape(num_classes, -1) 76 | 77 | y = y.type_as(y_pred) 78 | correct = y * y_pred 79 | actual_positives = y.sum(dim=0).type(torch.DoubleTensor) # Convert from int cuda/cpu to double cpu 80 | 81 | if correct.sum() == 0: 82 | true_positives = torch.zeros_like(actual_positives) 83 | else: 84 | true_positives = correct.sum(dim=0) 85 | 86 | # Convert from int cuda/cpu to double cpu 87 | # We need double precision for the division true_positives / actual_positives 88 | true_positives = true_positives.type(torch.DoubleTensor) 89 | 90 | if self._type == "multilabel": 91 | self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) 92 | self._positives = torch.cat([self._positives, actual_positives], dim=0) 93 | else: 94 | self._true_positives += true_positives 95 | self._positives += actual_positives 96 | -------------------------------------------------------------------------------- /ignite/metrics/root_mean_squared_error.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import math 3 | 4 | from ignite.metrics.mean_squared_error import MeanSquaredError 5 | 6 | 7 | class RootMeanSquaredError(MeanSquaredError): 8 | """ 9 | Calculates the root mean squared error. 10 | 11 | - `update` must receive output of the form (y_pred, y). 12 | """ 13 | def compute(self): 14 | mse = super(RootMeanSquaredError, self).compute() 15 | return math.sqrt(mse) 16 | -------------------------------------------------------------------------------- /ignite/metrics/running_average.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics import Metric 2 | from ignite.engine import Events 3 | 4 | 5 | class RunningAverage(Metric): 6 | """Compute running average of a metric or the output of process function. 7 | 8 | Args: 9 | src (Metric or None): input source: an instance of :class:`~ignite.metrics.Metric` or None. The latter 10 | corresponds to `engine.state.output` which holds the output of the process function. 11 | alpha (float, optional): running average decay factor, default 0.98 12 | output_transform (callable, optional): a function to use to transform the output if `src` is None and 13 | corresponds to the output of the process function. Otherwise it should be None. 14 | 15 | Examples: 16 | 17 | ..
code-block:: python 18 | 19 | alpha = 0.98 20 | acc_metric = RunningAverage(Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha) 21 | acc_metric.attach(trainer, 'running_avg_accuracy') 22 | 23 | avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha) 24 | avg_output.attach(trainer, 'running_avg_loss') 25 | 26 | @trainer.on(Events.ITERATION_COMPLETED) 27 | def log_running_avg_metrics(engine): 28 | print("running avg accuracy:", engine.state.metrics['running_avg_accuracy']) 29 | print("running avg loss:", engine.state.metrics['running_avg_loss']) 30 | 31 | """ 32 | 33 | def __init__(self, src=None, alpha=0.98, output_transform=None): 34 | if not (isinstance(src, Metric) or src is None): 35 | raise TypeError("Argument src should be a Metric or None.") 36 | if not (0.0 < alpha <= 1.0): 37 | raise ValueError("Argument alpha should be a float between 0.0 and 1.0.") 38 | 39 | if isinstance(src, Metric): 40 | if output_transform is not None: 41 | raise ValueError("Argument output_transform should be None if src is a Metric.") 42 | self.src = src 43 | self._get_src_value = self._get_metric_value 44 | self.iteration_completed = self._metric_iteration_completed 45 | else: 46 | if output_transform is None: 47 | raise ValueError("Argument output_transform should not be None if src corresponds " 48 | "to the output of process function.") 49 | self._get_src_value = self._get_output_value 50 | self.update = self._output_update 51 | 52 | self.alpha = alpha 53 | super(RunningAverage, self).__init__(output_transform=output_transform) 54 | 55 | def reset(self): 56 | self._value = None 57 | 58 | def update(self, output): 59 | # Implement abstract method 60 | pass 61 | 62 | def compute(self): 63 | if self._value is None: 64 | self._value = self._get_src_value() 65 | else: 66 | self._value = self._value * self.alpha + (1.0 - self.alpha) * self._get_src_value() 67 | return self._value 68 | 69 | def attach(self, engine, name): 70 | # restart average every epoch 71 | engine.add_event_handler(Events.EPOCH_STARTED, self.started) 72 | # compute metric 73 | engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) 74 | # apply running average 75 | engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name) 76 | 77 | def _get_metric_value(self): 78 | return self.src.compute() 79 | 80 | def _get_output_value(self): 81 | return self.src 82 | 83 | def _metric_iteration_completed(self, engine): 84 | self.src.started(engine) 85 | self.src.iteration_completed(engine) 86 | 87 | def _output_update(self, output): 88 | self.src = output 89 | -------------------------------------------------------------------------------- /ignite/metrics/top_k_categorical_accuracy.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.metrics.metric import Metric 6 | from ignite.exceptions import NotComputableError 7 | 8 | 9 | class TopKCategoricalAccuracy(Metric): 10 | """ 11 | Calculates the top-k categorical accuracy. 12 | 13 | - `update` must receive output of the form `(y_pred, y)`. 
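A minimal usage sketch (illustrative only; these tensors and values are not from the original source):

    .. code-block:: python

        import torch
        from ignite.metrics import TopKCategoricalAccuracy

        top2 = TopKCategoricalAccuracy(k=2)
        top2.reset()
        y_pred = torch.tensor([[0.1, 0.6, 0.3], [0.7, 0.2, 0.1]])
        y = torch.tensor([2, 2])
        top2.update((y_pred, y))
        print(top2.compute())  # 0.5: only the first sample has class 2 in its top-2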
14 | """ 15 | def __init__(self, k=5, output_transform=lambda x: x): 16 | super(TopKCategoricalAccuracy, self).__init__(output_transform) 17 | self._k = k 18 | 19 | def reset(self): 20 | self._num_correct = 0 21 | self._num_examples = 0 22 | 23 | def update(self, output): 24 | y_pred, y = output 25 | sorted_indices = torch.topk(y_pred, self._k, dim=1)[1] 26 | expanded_y = y.view(-1, 1).expand(-1, self._k) 27 | correct = torch.sum(torch.eq(sorted_indices, expanded_y), dim=1) 28 | self._num_correct += torch.sum(correct).item() 29 | self._num_examples += correct.shape[0] 30 | 31 | def compute(self): 32 | if self._num_examples == 0: 33 | raise NotComputableError("TopKCategoricalAccuracy must have at" 34 | "least one example before it can be computed.") 35 | return self._num_correct / self._num_examples 36 | -------------------------------------------------------------------------------- /ignite/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | from torch._six import string_classes 5 | 6 | IS_PYTHON2 = sys.version_info[0] < 3 7 | 8 | if IS_PYTHON2: 9 | import collections 10 | else: 11 | import collections.abc as collections 12 | 13 | 14 | def convert_tensor(input_, device=None, non_blocking=False): 15 | """Move tensors to relevant device.""" 16 | def _func(tensor): 17 | return tensor.to(device=device, non_blocking=non_blocking) if device else tensor 18 | 19 | return apply_to_tensor(input_, _func) 20 | 21 | 22 | def apply_to_tensor(input_, func): 23 | """Apply a function on a tensor or mapping, or sequence of tensors. 24 | """ 25 | return apply_to_type(input_, torch.Tensor, func) 26 | 27 | 28 | def apply_to_type(input_, input_type, func): 29 | """Apply a function on a object of `input_type` or mapping, or sequence of objects of `input_type`. 
30 | """ 31 | if isinstance(input_, input_type): 32 | return func(input_) 33 | elif isinstance(input_, string_classes): 34 | return input_ 35 | elif isinstance(input_, collections.Mapping): 36 | return {k: apply_to_type(sample, input_type, func) for k, sample in input_.items()} 37 | elif isinstance(input_, collections.Sequence): 38 | return [apply_to_type(sample, input_type, func) for sample in input_] 39 | else: 40 | raise TypeError(("input must contain {}, dicts or lists; found {}" 41 | .format(input_type, type(input_)))) 42 | 43 | 44 | def to_onehot(indices, num_classes): 45 | """Convert a tensor of indices to a tensor of one-hot indicators.""" 46 | onehot = torch.zeros(indices.size(0), num_classes, device=indices.device) 47 | return onehot.scatter_(1, indices.unsqueeze(1), 1) 48 | -------------------------------------------------------------------------------- /oneshot/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for oneshot.""" 4 | 5 | __author__ = """Amit Aides""" 6 | __email__ = 'amitaid@il.ibm.com' 7 | __version__ = '0.1.0' 8 | -------------------------------------------------------------------------------- /oneshot/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/__pycache__/coco.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/__pycache__/coco.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/alfassy/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import CocoDataset 2 | from .datasets import CocoDatasetAugmentation 3 | from .datasets import CocoDatasetPairs 4 | from .datasets import CocoDatasetTriplets 5 | from .datasets import labels_list_to_1hot 6 | from .datasets import CocoDatasetPairsSub 7 | 8 | from .img_to_vec import Img2OurVec 9 | 10 | from .setops_funcs import set_intersection_operation 11 | from .setops_funcs import set_subtraction_operation 12 | from .setops_funcs import set_union_operation 13 | 14 | from .utils import IOU_fake_vectors_accuracy 15 | from .utils import IOU_real_vectors_accuracy 16 | from .utils import precision_recall_statistics 17 | from .utils import get_subtraction_exp 18 | from .utils import set_subtraction_operation 19 | from .utils import set_union_operation 20 | from .utils import set_intersection_operation 21 | from .utils import configure_logging 22 | from .utils import save_checkpoint 23 | from .utils import get_learning_rate 24 | 25 | -------------------------------------------------------------------------------- /oneshot/alfassy/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/alfassy/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/alfassy/__pycache__/datasets.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/alfassy/__pycache__/datasets.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/alfassy/__pycache__/img_to_vec.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/alfassy/__pycache__/img_to_vec.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/alfassy/__pycache__/setops_funcs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/alfassy/__pycache__/setops_funcs.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/alfassy/__pycache__/testing_functions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/alfassy/__pycache__/testing_functions.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/alfassy/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/alfassy/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/alfassy/img_to_vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torchvision.transforms as transforms 5 | import os 6 | import random 7 | import torch.nn.functional as F 8 | 9 | use_cuda = True if torch.cuda.is_available() else False 10 | random.seed(5) 11 | torch.manual_seed(5) 12 | if use_cuda: 13 | torch.cuda.manual_seed_all(5) 14 | 15 | 16 | class Img2OurVec(): 17 | #def __init__(self, model='inception_v3', layer='default', layer_output_size=2048): 18 | def __init__(self, model='inception', layer='default', layer_output_size=2048, data="top10", transform=None): 19 | """ Img2Vec 20 | :param model: String name of requested model 21 | :param layer: String or Int depending on model. 
See more docs: https://github.com/christiansafka/img2vec.git 22 | :param layer_output_size: Int depicting the output size of the requested layer 23 | """ 24 | cuda = True if torch.cuda.is_available() else False 25 | 26 | self.device = torch.device("cuda" if cuda else "cpu") 27 | self.layer_output_size = layer_output_size 28 | # self.model_path = '/dccstor/alfassy/saved_models/inception_traincocoInceptionT10Half2018.9.1.9:30epoch:71' 29 | # self.model_path = '/dccstor/alfassy/saved_models/inception_trainCocoIncHalf2018.10.3.13:39best' 30 | # self.model_path = '/dccstor/alfassy/saved_models/inception_trainCocoIncHalf2018.10.8.12:46best' 31 | self.model_path = '/dccstor/alfassy/saved_models/inception_trainCocoIncHalf642018.10.9.13:44epoch:30' 32 | self.model, self.extraction_layer = self._get_model_and_layer(model, layer, data) 33 | self.model = self.model.to(self.device) 34 | self.model.eval() 35 | #self.scaler = transforms.Resize(224, 224) 36 | #self.scaler = transforms.Scale((224, 224)) 37 | self.transform = transform 38 | self.model_name = model 39 | 40 | def get_vec(self, image, tensor=True): 41 | """ Get vector embedding from PIL image 42 | :param image: PIL Image 43 | :param tensor: If True, get_vec will return a FloatTensor instead of Numpy array 44 | :returns: FloatTensor if `tensor` is True, otherwise Numpy ndarray 45 | """ 46 | 47 | if self.transform is not None: 48 | image = self.transform(image).unsqueeze(0).to(self.device) 49 | 50 | batch_size = image.shape[0] 51 | 52 | # print(image.shape) 53 | if self.model_name == "inception": 54 | my_embedding = torch.zeros(batch_size, self.layer_output_size, 8, 8).to(self.device) 55 | 56 | else: 57 | my_embedding = torch.zeros(batch_size, self.layer_output_size, 1, 1).to(self.device) 58 | 59 | def copy_data_resnet(m, i, o): 60 | my_embedding.copy_(o.data) 61 | 62 | # NOTE: copy_data_inception is defined but was never registered in the 63 | # original; both branches of the original if/else registered copy_data_resnet, 64 | # so the registration is collapsed into a single call below. 65 | def copy_data_inception(m, i, o): 66 | my_embedding.copy_(i.data) 67 | 68 | h = self.extraction_layer.register_forward_hook(copy_data_resnet) 69 | self.model(image) 70 | h.remove() 71 | # NOTE: the pooling below assumes the 8x8 inception feature map 72 | my_embedding = F.avg_pool2d(my_embedding, kernel_size=8) 73 | 74 | if tensor: 75 | return my_embedding 76 | else: 77 | return my_embedding.numpy()[0, :, 0, 0] 78 | 79 | def _get_model_and_layer(self, model_name, layer, data): 80 | """ Internal method for getting layer from model 81 | :param model_name: model name such as 'resnet-18' 82 | :param layer: layer as a string for resnet-18 or int for alexnet 83 | :returns: pytorch model, selected layer 84 | """ 85 | if data == "full": 86 | out_size = 200 87 | else: 88 | out_size = 80 89 | 90 | if model_name == 'inception': 91 | model = models.inception_v3(pretrained=True) 92 | num_ftrs = model.fc.in_features 93 | model.fc = nn.Linear(num_ftrs, out_size) 94 | num_ftrs = model.AuxLogits.fc.in_features 95 | model.AuxLogits.fc = nn.Linear(num_ftrs, out_size) 96 | elif model_name == 'resnet18': 97 | model = models.resnet18(pretrained=True) 98 | num_ftrs = model.fc.in_features 99 | model.fc = nn.Linear(num_ftrs, out_size) 100 | elif model_name == 'ourT10Class': 101 | # model = torch.load('/dccstor/alfassy/saved_models/trained_discriminatorfeatureClassifierTrain2018.8.22.12:54epoch:128') 102 | model = torch.load('/dccstor/alfassy/saved_models/inception_trainincT10Half2018.9.4.14:40epoch:26') 103 | else: 104 | raise KeyError('Model %s was not found' % model_name) 105 | 106 | model.eval() 107 | 108 | if use_cuda: 109 | model.cuda() 110 | 111 | if
model_name == 'inception' or model_name == 'resnet18': 112 | # Load checkpoint. 113 | assert os.path.isfile(self.model_path), 'Error: no checkpoint found!' 114 | checkpoint = torch.load(self.model_path) 115 | best_acc = checkpoint['best_acc'] 116 | start_epoch = checkpoint['epoch'] 117 | model.load_state_dict(checkpoint['state_dict']) 118 | if model_name == 'inception': 119 | if layer == 'default': 120 | layer = model._modules.get('Mixed_7c') 121 | self.layer_output_size = 2048 122 | else: 123 | raise Exception('wrong layer name') 124 | return model, layer 125 | elif model_name == 'resnet18': 126 | if layer == 'default': 127 | layer = model._modules.get('avgpool') 128 | self.layer_output_size = 512 129 | else: 130 | raise Exception('wrong layer name') 131 | return model, layer 132 | elif model_name == 'ourT10Class': 133 | layer = model._modules.get('linear_block') 134 | self.layer_output_size = 2048 135 | return model, layer -------------------------------------------------------------------------------- /oneshot/alfassy/setops_funcs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def set_subtraction_operation(labels1, labels2): 6 | batch_size = labels1.shape[0] 7 | classesNum = labels1.shape[1] 8 | # print("labels1: ", labels1) 9 | # print("labels2: ", labels2) 10 | subLabels = [] 11 | for vecNum in range(batch_size): 12 | subLabelPerClass = [] 13 | for classNum in range(classesNum): 14 | if (labels1[vecNum][classNum] == 1) and (labels2[vecNum][classNum] == 0): 15 | subLabelPerClass += [1] 16 | else: 17 | subLabelPerClass += [0] 18 | subLabels += [subLabelPerClass] 19 | # print(subLabels) 20 | npSubLabels = np.asarray(subLabels) 21 | # print(npSubLabels) 22 | torSubLabels = torch.from_numpy(npSubLabels) 23 | # print(torSubLabels) 24 | return torSubLabels 25 | 26 | 27 | def set_union_operation(labels1, labels2): 28 | batch_size = labels1.shape[0] 29 | classesNum = labels1.shape[1] 30 | subLabels = [] 31 | for vecNum in range(batch_size): 32 | subLabelPerClass = [] 33 | for classNum in range(classesNum): 34 | if (labels1[vecNum][classNum] == 1) or (labels2[vecNum][classNum] == 1): 35 | subLabelPerClass += [1] 36 | else: 37 | subLabelPerClass += [0] 38 | subLabels += [subLabelPerClass] 39 | npSubLabels = np.asarray(subLabels) 40 | torSubLabels = torch.from_numpy(npSubLabels) 41 | return torSubLabels 42 | 43 | 44 | def set_intersection_operation(labels1, labels2): 45 | batch_size = labels1.shape[0] 46 | classesNum = labels1.shape[1] 47 | subLabels = [] 48 | for vecNum in range(batch_size): 49 | subLabelPerClass = [] 50 | for classNum in range(classesNum): 51 | if (labels1[vecNum][classNum] == 1) and (labels2[vecNum][classNum] == 1): 52 | subLabelPerClass += [1] 53 | else: 54 | subLabelPerClass += [0] 55 | subLabels += [subLabelPerClass] 56 | npSubLabels = np.asarray(subLabels) 57 | torSubLabels = torch.from_numpy(npSubLabels) 58 | return torSubLabels 59 | 60 | 61 | def set_subtraction_operation_one_sample(labels1, labels2): 62 | classesNum = labels1.shape[0] 63 | # print("labels1: ", labels1) 64 | # print("labels2: ", labels2) 65 | subLabelPerClass = [] 66 | for classNum in range(classesNum): 67 | if (labels1[classNum] == 1) and (labels2[classNum] == 0): 68 | subLabelPerClass += [1] 69 | else: 70 | subLabelPerClass += [0] 71 | # print(subLabels) 72 | npSubLabels = np.asarray(subLabelPerClass) 73 | # print(npSubLabels) 74 | # subLabelPerClass = torch.from_numpy(subLabelPerClass) 75 | # print(torSubLabels) 
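# (illustrative note, not in the original: assuming numpy inputs, a
    # vectorized equivalent of this loop is
    # np.logical_and(labels1 == 1, labels2 == 0).astype(int))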
76 | return npSubLabels 77 | 78 | 79 | def set_union_operation_one_sample(labels1, labels2): 80 | classesNum = labels1.shape[0] 81 | subLabelPerClass = [] 82 | for classNum in range(classesNum): 83 | if (labels1[classNum] == 1) or (labels2[classNum] == 1): 84 | subLabelPerClass += [1] 85 | else: 86 | subLabelPerClass += [0] 87 | npSubLabels = np.asarray(subLabelPerClass) 88 | # torSubLabels = torch.from_numpy(npSubLabels) 89 | return npSubLabels 90 | 91 | 92 | def set_intersection_operation_one_sample(labels1, labels2): 93 | classesNum = labels1.shape[0] 94 | subLabelPerClass = [] 95 | for classNum in range(classesNum): 96 | if (labels1[classNum] == 1) and (labels2[classNum] == 1): 97 | subLabelPerClass += [1] 98 | else: 99 | subLabelPerClass += [0] 100 | npSubLabels = np.asarray(subLabelPerClass) 101 | # torSubLabels = torch.from_numpy(npSubLabels) 102 | return npSubLabels -------------------------------------------------------------------------------- /oneshot/cnnvisualizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/cnnvisualizer/__init__.py -------------------------------------------------------------------------------- /oneshot/cnnvisualizer/cnnvisualizer.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torchvision.transforms as transforms 3 | from PIL import Image 4 | 5 | IMG_EXTENSIONS = ['.png', '.jpg'] 6 | 7 | 8 | class LeNormalize(object): 9 | """Scale a tensor from [0, 1] to [-1, 1]. 10 | 11 | NOTE: this class was referenced below but not defined anywhere in the 12 | original file; the definition here is a minimal sketch of the 13 | TF/Inception-style transform the name conventionally refers to. 14 | """ 15 | def __call__(self, tensor): 16 | return tensor.mul_(2.0).sub_(1.0) 17 | 18 | 19 | def default_inception_transform(img_size): 20 | tf = transforms.Compose([ 21 | transforms.Resize(img_size),  # transforms.Scale in the original; Resize is the non-deprecated equivalent 22 | transforms.CenterCrop(img_size), 23 | transforms.ToTensor(), 24 | LeNormalize(), 25 | ]) 26 | return tf 27 | 28 | 29 | class Dataset(data.Dataset): 30 | 31 | def __init__(self, imglist, transform=None): 32 | 33 | # the original error message referenced an undefined `root` variable 34 | if len(imglist) == 0: 35 | raise RuntimeError("Found 0 images in the given image list.\n" 36 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)) 37 | 38 | self.imgs = imglist 39 | self.transform = transform 40 | 41 | def __getitem__(self, index): 42 | path = self.imgs[index] 43 | target = None 44 | img = Image.open(path).convert('RGB') 45 | if self.transform is not None: 46 | img = self.transform(img) 47 | return img, path 48 | 49 | def __len__(self): 50 | return len(self.imgs) 51 | 52 | -------------------------------------------------------------------------------- /oneshot/cnnvisualizer/wideresnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from oneshot.cnnvisualizer import wideresnet_utils as utils 3 | 4 | 5 | def resnet(depth, width, num_classes): 6 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 7 | n = (depth - 4) // 6 8 | widths = [int(v * width) for v in (16, 32, 64)] 9 | 10 | def gen_block_params(ni, no): 11 | return { 12 | 'conv0': utils.conv_params(ni, no, 3), 13 | 'conv1': utils.conv_params(no, no, 3), 14 | 'bn0': utils.bnparams(ni), 15 | 'bn1': utils.bnparams(no), 16 | 'convdim': utils.conv_params(ni, no, 1) if ni != no else None, 17 | } 18 | 19 | def gen_group_params(ni, no, count): 20 | return {'block%d' % i: gen_block_params(ni if i == 0 else no, no) 21 | for i in range(count)} 22 | 23 | flat_params = utils.cast(utils.flatten({ 24 | 'conv0': utils.conv_params(3, 16, 3), 25 | 'group0': gen_group_params(16, widths[0], n), 26 | 'group1': gen_group_params(widths[0], widths[1], n), 27 | 'group2':
gen_group_params(widths[1], widths[2], n), 28 | 'bn': utils.bnparams(widths[2]), 29 | 'fc': utils.linear_params(widths[2], num_classes), 30 | })) 31 | 32 | utils.set_requires_grad_except_bn_(flat_params) 33 | 34 | def block(x, params, base, mode, stride): 35 | o1 = F.relu(utils.batch_norm(x, params, base + '.bn0', mode), inplace=True) 36 | y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1) 37 | o2 = F.relu(utils.batch_norm(y, params, base + '.bn1', mode), inplace=True) 38 | z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1) 39 | if base + '.convdim' in params: 40 | return z + F.conv2d(o1, params[base + '.convdim'], stride=stride) 41 | else: 42 | return z + x 43 | 44 | def group(o, params, base, mode, stride): 45 | for i in range(n): 46 | o = block(o, params, '%s.block%d' % (base,i), mode, stride if i == 0 else 1) 47 | return o 48 | 49 | def f(input, params, mode): 50 | x = F.conv2d(input, params['conv0'], padding=1) 51 | g0 = group(x, params, 'group0', mode, 1) 52 | g1 = group(g0, params, 'group1', mode, 2) 53 | g2 = group(g1, params, 'group2', mode, 2) 54 | o = F.relu(utils.batch_norm(g2, params, 'bn', mode)) 55 | o = F.avg_pool2d(o, 8, 1, 0) 56 | o = o.view(o.size(0), -1) 57 | o = F.linear(o, params['fc.weight'], params['fc.bias']) 58 | return o 59 | 60 | return f, flat_params -------------------------------------------------------------------------------- /oneshot/cnnvisualizer/wideresnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.init import kaiming_normal_ 3 | import torch.nn.functional as F 4 | from torch.nn.parallel._functions import Broadcast 5 | from torch.nn.parallel import scatter, parallel_apply, gather 6 | from functools import partial 7 | from nested_dict import nested_dict 8 | 9 | 10 | def cast(params, dtype='float'): 11 | if isinstance(params, dict): 12 | return {k: cast(v, dtype) for k,v in params.items()} 13 | else: 14 | return getattr(params.cuda() if torch.cuda.is_available() else params, dtype)() 15 | 16 | 17 | def conv_params(ni, no, k=1): 18 | return kaiming_normal_(torch.Tensor(no, ni, k, k)) 19 | 20 | 21 | def linear_params(ni, no): 22 | return {'weight': kaiming_normal_(torch.Tensor(no, ni)), 'bias': torch.zeros(no)} 23 | 24 | 25 | def bnparams(n): 26 | return {'weight': torch.rand(n), 27 | 'bias': torch.zeros(n), 28 | 'running_mean': torch.zeros(n), 29 | 'running_var': torch.ones(n)} 30 | 31 | 32 | def data_parallel(f, input, params, mode, device_ids, output_device=None): 33 | assert isinstance(device_ids, list) 34 | if output_device is None: 35 | output_device = device_ids[0] 36 | 37 | if len(device_ids) == 1: 38 | return f(input, params, mode) 39 | 40 | params_all = Broadcast.apply(device_ids, *params.values()) 41 | params_replicas = [{k: params_all[i + j*len(params)] for i, k in enumerate(params.keys())} 42 | for j in range(len(device_ids))] 43 | 44 | replicas = [partial(f, params=p, mode=mode) 45 | for p in params_replicas] 46 | inputs = scatter([input], device_ids) 47 | outputs = parallel_apply(replicas, inputs) 48 | return gather(outputs, output_device) 49 | 50 | 51 | def flatten(params): 52 | return {'.'.join(k): v for k, v in nested_dict(params).items_flat() if v is not None} 53 | 54 | 55 | def batch_norm(x, params, base, mode): 56 | return F.batch_norm(x, weight=params[base + '.weight'], 57 | bias=params[base + '.bias'], 58 | running_mean=params[base + '.running_mean'], 59 | running_var=params[base + '.running_var'], 60 | training=mode) 61 | 62 
| 63 | def print_tensor_dict(params): 64 | kmax = max(len(key) for key in params.keys()) 65 | for i, (key, v) in enumerate(params.items()): 66 | print(str(i).ljust(5), key.ljust(kmax + 3), str(tuple(v.shape)).ljust(23), torch.typename(v), v.requires_grad) 67 | 68 | 69 | def set_requires_grad_except_bn_(params): 70 | for k, v in params.items(): 71 | if not k.endswith('running_mean') and not k.endswith('running_var'): 72 | v.requires_grad = True -------------------------------------------------------------------------------- /oneshot/coco.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms 10 | 11 | from pycocotools.coco import COCO 12 | 13 | from PIL import Image 14 | 15 | 16 | COCO_LABELS_NUM = 80 17 | 18 | 19 | def get_coco_labels(coco_ann: COCO) -> Tuple[dict, dict]: 20 | """Get the mapping between category and label. 21 | 22 | The COCO category ids are numbers between 1 and 90. This function maps 23 | them to labels between 0 and 79. 24 | """ 25 | 26 | categories = coco_ann.loadCats(coco_ann.getCatIds()) 27 | categories.sort(key=lambda x: x['id']) 28 | 29 | classes = {} 30 | id2label_map = {} 31 | 32 | for cat in categories: 33 | id2label_map[cat['id']] = len(classes) 34 | classes[len(classes)] = cat['name'] 35 | 36 | return id2label_map, classes 37 | 38 | 39 | def load_paths_multilabels(annotation_file: Path, imgs_base: Path) -> Tuple[list, list, dict]: 40 | """Load the paths to coco images and the multi-labels. 41 | 42 | """ 43 | coco_ann = COCO(annotation_file=annotation_file) 44 | 45 | id2label_map, classes = get_coco_labels(coco_ann) 46 | 47 | imgs_paths, imgs_labels = [], [] 48 | for img_id, img_data in coco_ann.imgs.items(): 49 | img_anns_ids = coco_ann.getAnnIds(imgIds=[img_id]) 50 | img_anns = coco_ann.loadAnns(img_anns_ids) 51 | 52 | # 53 | # Filter annotations with empty bounding box. 54 | # 55 | img_labels = { 56 | id2label_map[ann["category_id"]] for ann in img_anns \ 57 | if ann["bbox"][2] > 0 and ann["bbox"][3] > 0 58 | } 59 | 60 | imgs_paths.append(str(imgs_base / img_data["file_name"])) 61 | imgs_labels.append(list(img_labels)) 62 | 63 | return imgs_paths, imgs_labels, classes 64 | 65 | 66 | class CocoMlDataset(Dataset): 67 | """Dataset class for the Multilabel COCO dataset 68 | """ 69 | 70 | def __init__(self, image_paths: list, labels: list, transform: transforms.Compose = None, 71 | categories_num: int = COCO_LABELS_NUM): 72 | 73 | self.image_paths = image_paths 74 | self.labels = labels 75 | self.categories_num = categories_num 76 | 77 | self.transform = transform 78 | 79 | def __getitem__(self, index: int) -> Tuple[Image.Image, np.array]: 80 | """ 81 | Args: 82 | index (int): Index 83 | Returns: 84 | tuple: (sample, target) where target is a multi-hot label vector of length `categories_num`.
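For example (illustrative, not from the original source): with `categories_num=5` and label indices `[0, 3]`, `target` becomes `array([1., 0., 0., 1., 0.], dtype=float32)`.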
85 | """ 86 | image_path = self.image_paths[index] 87 | target_inds = self.labels[index] 88 | target = np.zeros(self.categories_num, dtype=np.float32) 89 | target[target_inds] = 1 90 | 91 | sample = Image.open(image_path) 92 | if sample.mode != 'RGB': 93 | sample = sample.convert(mode='RGB') 94 | 95 | if self.transform is not None: 96 | sample = self.transform(sample) 97 | 98 | return sample, target 99 | 100 | def __len__(self) -> int: 101 | return len(self.image_paths) 102 | 103 | 104 | def copy_coco_data(force=False): 105 | """Copy the coco data to local machine for speed of reading. 106 | 107 | Args: 108 | force (bool, optional): Force copying coco data even if tmp folder exists. 109 | """ 110 | 111 | if not os.path.exists("/tmp/aa"): 112 | os.system("mkdir /tmp/aa") 113 | 114 | if not os.path.exists("/tmp/aa/coco") or force: 115 | logging.info("Copying data to tmp") 116 | os.system("mkdir /tmp/aa/coco") 117 | os.system("mkdir /tmp/aa/coco/images") 118 | 119 | # 120 | # Copy the train data. 121 | # 122 | os.system("curl -o /tmp/aa/coco/train2014.zip FILE:/dccstor/faceid/data/oneshot/coco/train2014.zip") 123 | os.system( 124 | """unzip -o /tmp/aa/coco/train2014.zip -d /tmp/aa/coco/images | awk 'BEGIN {ORS=" "} {if(NR%100==0)print "."}'""") 125 | os.remove("/tmp/aa/coco/train2014.zip") 126 | 127 | # 128 | # Copy the train data. 129 | # 130 | os.system("curl -o /tmp/aa/coco/val2014.zip FILE:/dccstor/faceid/data/oneshot/coco/val2014.zip") 131 | os.system( 132 | """unzip -o /tmp/aa/coco/val2014.zip -d /tmp/aa/coco/images | awk 'BEGIN {ORS=" "} {if(NR%100==0)print "."}'""") 133 | os.remove("/tmp/aa/coco/val2014.zip") 134 | 135 | # 136 | # Copy the annotation data. 137 | # 138 | os.system( 139 | "curl -o /tmp/aa/coco/annotations_trainval2014.zip FILE:/dccstor/faceid/data/oneshot/coco/annotations_trainval2014.zip") 140 | os.system( 141 | """unzip -o /tmp/aa/coco/annotations_trainval2014.zip -d /tmp/aa/coco | awk 'BEGIN {ORS=" "} {if(NR%100==0)print "."}'""") 142 | os.remove("/tmp/aa/coco/annotations_trainval2014.zip") 143 | 144 | 145 | -------------------------------------------------------------------------------- /oneshot/global_settings.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import os 4 | from pkg_resources import resource_filename 5 | import warnings 6 | 7 | 8 | RESOURCES_BASE = os.path.abspath(resource_filename(__name__, '../resources')) 9 | 10 | try: 11 | FACEID_BASE = os.environ['FACEID_BASE'] 12 | except Exception: 13 | FACEID_BASE = 'C:' 14 | warnings.warn('Failed to find find FACEID_BASE environment variable') 15 | 16 | RESULTS_HOME = os.path.join(FACEID_BASE, 'results') 17 | TENSORBOARD_HOME = os.path.join(FACEID_BASE, 'tensorboard') 18 | -------------------------------------------------------------------------------- /oneshot/ignite/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/ignite/__init__.py -------------------------------------------------------------------------------- /oneshot/ignite/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/ignite/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- 
/oneshot/ignite/engine/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ignite.engine.engine import Engine, State, Events 3 | from ignite._utils import convert_tensor 4 | 5 | 6 | def _prepare_batch(batch, device=None): 7 | xs, ys = batch 8 | return [convert_tensor(x, device=device) for x in xs],\ 9 | [convert_tensor(y, device=device) for y in ys] 10 | 11 | 12 | def create_supervised_trainer(model, optimizer, weights, loss_fns, device=None): 13 | """ 14 | Factory function for creating a trainer for supervised models 15 | 16 | Args: 17 | model (`torch.nn.Module`): the model to train 18 | optimizer (`torch.optim.Optimizer`): the optimizer to use 19 | weights (sequence): per-loss weights; each element exposes its current value through a `value` attribute 20 | loss_fns (sequence of torch.nn loss functions): the loss functions to use, one per model output 21 | device (str, optional): device type specification (default: None). Applies to both model and batches. 22 | 23 | Returns: 24 | Engine: a trainer engine with supervised update function 25 | """ 26 | if device: 27 | model.to(device) 28 | 29 | def _update(engine, batch): 30 | model.train() 31 | optimizer.zero_grad() 32 | 33 | xs, ys = _prepare_batch(batch, device=device) 34 | y_preds = model(xs) 35 | 36 | losses = [loss_fn(y_pred, y) for loss_fn, y_pred, y in zip(loss_fns, y_preds, ys)] 37 | 38 | loss = 0 39 | for w, l in zip(weights, losses): 40 | loss += w.value * l 41 | 42 | loss.backward() 43 | optimizer.step() 44 | 45 | return loss.item() 46 | 47 | return Engine(_update) 48 | 49 | 50 | def create_supervised_evaluator(model, metrics={}, device=None): 51 | """ 52 | Factory function for creating an evaluator for supervised models 53 | 54 | Args: 55 | model (`torch.nn.Module`): the model to train 56 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics 57 | device (str, optional): device type specification (default: None). 58 | Applies to both model and batches. 59 | 60 | Returns: 61 | Engine: an evaluator engine with supervised inference function 62 | """ 63 | if device: 64 | model.to(device) 65 | 66 | def _inference(engine, batch): 67 | model.eval() 68 | with torch.no_grad(): 69 | xs, ys = _prepare_batch(batch, device=device) 70 | y_preds = model(xs) 71 | return y_preds, ys 72 | 73 | engine = Engine(_inference) 74 | 75 | for name, metric in metrics.items(): 76 | metric.attach(engine, name) 77 | 78 | return engine 79 | -------------------------------------------------------------------------------- /oneshot/ignite/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | from .param_scheduler import ManualParamScheduler 2 | from .find_learning_rate import LRFinder -------------------------------------------------------------------------------- /oneshot/ignite/handlers/find_learning_rate.py: -------------------------------------------------------------------------------- 1 | from ignite.engine import Events 2 | from ignite.metrics import Metric 3 | import math 4 | 5 | 6 | class LRFinder(Metric): 7 | """Searches for a good learning rate by exponentially increasing the optimizer's learning rate and tracking the smoothed loss.""" 8 | 9 | def __init__(self, optimizer, init_lr=1e-8, final_lr=10., beta=0.98, output_transform=lambda x: x): 10 | 11 | self.init_lr = init_lr 12 | self.final_lr = final_lr 13 | self.beta = beta 14 | 15 | self.optimizer = optimizer 16 | 17 | super(LRFinder, self).__init__(output_transform=output_transform) 18 | 19 | def attach(self, engine, name): 20 | """ Register callbacks to control the search for the learning rate.
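A hedged usage sketch (assumes an already-configured ignite trainer whose process function returns the loss value; not from the original source):

            lr_finder = LRFinder(optimizer, init_lr=1e-6, final_lr=1.0)
            lr_finder.attach(trainer, 'lr_loss')
            trainer.run(train_loader, max_epochs=1)  # terminates early once the loss explodes
            # the smoothed loss for the current lr is written to trainer.state.metrics['lr_loss']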
21 | 22 | Args: 23 | engine (ignite.engine.Engine): 24 | Engine that this handler will be attached to 25 | 26 | Returns: 27 | self (LRFinder) 28 | 29 | """ 30 | 31 | engine.add_event_handler(Events.EPOCH_STARTED, self.started) 32 | engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) 33 | engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name) 34 | 35 | self.engine = engine 36 | 37 | return self 38 | 39 | def _update_optimizer_lr(self): 40 | self.optimizer.param_groups[0]['lr'] = self.lr 41 | 42 | def reset(self): 43 | """ 44 | Resets the metric to its initial state. 45 | 46 | This is called at the start of each epoch. 47 | """ 48 | self.lr = self.init_lr 49 | self._update_optimizer_lr() 50 | 51 | self.avg_loss = 0 52 | 53 | def update(self, output): 54 | """ 55 | Updates the metric's state using the passed batch output. 56 | 57 | This is called once for each batch. 58 | 59 | Args: 60 | output: the output from the engine's process function 61 | """ 62 | self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * output 63 | 64 | self.lr *= self.mult 65 | self._update_optimizer_lr() 66 | 67 | def compute(self): 68 | """ 69 | Computes the metric based on its accumulated state. 70 | 71 | This is called at the end of each epoch. 72 | 73 | Returns: 74 | Any: the actual quantity of interest 75 | 76 | Raises: 77 | NotComputableError: raised when the metric cannot be computed 78 | """ 79 | # 80 | # Compute the smoothed loss 81 | # 82 | smoothed_loss = self.avg_loss / (1 - self.beta ** (self.engine.state.iteration + 1)) 83 | 84 | # 85 | # Stop if the loss is exploding 86 | # 87 | if self.engine.state.iteration > 1 and smoothed_loss > 4 * self.best_loss: 88 | self.engine.terminate() 89 | 90 | # 91 | # Record the best loss 92 | # 93 | if self.engine.state.iteration == 1 or smoothed_loss < self.best_loss: 94 | self.best_loss = smoothed_loss 95 | 96 | # 97 | # Store the values 98 | # 99 | self.log_lr = math.log10(self.lr) 100 | 101 | return smoothed_loss 102 | 103 | def started(self, engine): 104 | 105 | self.mult = (self.final_lr / self.init_lr) ** (1 / len(engine.state.dataloader)) 106 | -------------------------------------------------------------------------------- /oneshot/ignite/handlers/param_scheduler.py: -------------------------------------------------------------------------------- 1 | from ignite.contrib.handlers.param_scheduler import ParamScheduler 2 | 3 | 4 | class ManualParamScheduler(ParamScheduler): 5 | """A class for updating an optimizer's parameter value manually. 6 | 7 | Args: 8 | optimizer (`torch.optim.Optimizer`): the optimizer to use 9 | param_name (str): name of optimizer's parameter to update 10 | param_callback (callable): A callback that should return the value.
11 |         save_history (bool, optional): whether to log the parameter values
12 |             (default=False)
13 |     """
14 |     def __init__(self, optimizer, param_name, param_callback, save_history=False):
15 |         super(ManualParamScheduler, self).__init__(optimizer, param_name, save_history=save_history)
16 |
17 |         self.param_callback = param_callback
18 |
19 |     def get_param(self):
20 |         """Return the current value of the optimizer's parameter.
21 |         """
22 |         return self.param_callback()
23 |
24 |
25 |
-------------------------------------------------------------------------------- /oneshot/ignite/metrics/__init__.py: --------------------------------------------------------------------------------
1 | from .metrics import EWMeanSquaredError
2 | from .metrics import mAP
3 | from .metrics import MultiLabelSoftMarginAccuracy
4 | from .metrics import MultiLabelSoftMarginIOUaccuracy
5 | from .metrics import ReductionMetric
-------------------------------------------------------------------------------- /oneshot/ignite/metrics/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/ignite/metrics/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/ignite/metrics/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/ignite/metrics/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/mixup.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from ignite.engine import Engine
5 |
6 |
7 | def mixup_data(x, y, alpha=1.0, use_cuda=True):
8 |     '''Returns mixed inputs, pairs of targets, and lambda'''
9 |
10 |     if alpha > 0:
11 |         lam = np.random.beta(alpha, alpha)
12 |     else:
13 |         lam = 1
14 |
15 |     batch_size = x.size()[0]
16 |
17 |     if use_cuda:
18 |         index = torch.randperm(batch_size).cuda()
19 |     else:
20 |         index = torch.randperm(batch_size)
21 |
22 |     mixed_x = lam * x + (1 - lam) * x[index, :]
23 |     y_a, y_b = y, y[index]
24 |
25 |     return mixed_x, y_a, y_b, lam
26 |
27 |
28 | def mixup_criterion(criterion, pred, y_a, y_b, lam):
29 |     return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
30 |
31 |
32 | def create_mixup_trainer(model, optimizer, loss_fn, alpha=1.0, device=None):
33 |     """
34 |     Factory function for creating a trainer for mixup augmented models
35 |
36 |     Args:
37 |         model (`torch.nn.Module`): the model to train
38 |         optimizer (`torch.optim.Optimizer`): the optimizer to use
39 |         loss_fn (torch.nn loss function): the loss function to use
40 |         alpha (float, optional): parameter of the ``Beta(alpha, alpha)`` mixing distribution (default: 1.0)
41 |         device (str, optional): device type specification (default: None).
42 |             Applies to both model and batches.
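43 |
44 |     Example (an illustrative sketch; ``model``, ``optimizer``, ``loss_fn``
45 |     and ``train_loader`` are assumed to already exist)::
46 |
47 |         trainer = create_mixup_trainer(model, optimizer, loss_fn,
48 |                                        alpha=0.2, device="cuda")
49 |         trainer.run(train_loader, max_epochs=10)
50 |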
51 |     Returns:
52 |         Engine: a trainer engine with supervised update function
53 |     """
54 |     if device:
55 |         model.to(device)
56 |
57 |     def _update(engine, batch):
58 |
59 |         from ignite.engine import _prepare_batch
60 |
61 |         model.train()
62 |         optimizer.zero_grad()
63 |
64 |         inputs, targets = _prepare_batch(batch, device=device)
65 |
66 |         inputs, targets_a, targets_b, lam = mixup_data(inputs, targets,
67 |                                                        alpha, use_cuda=(device=="cuda"))
68 |         outputs = model(inputs)
69 |
70 |         loss = mixup_criterion(loss_fn, outputs, targets_a, targets_b, lam)
71 |
72 |         loss.backward()
73 |         optimizer.step()
74 |         return loss.item()
75 |
76 |     return Engine(_update)
77 |
78 |
79 | def create_mixup_trainer_x2(model, optimizer, loss_fn, alpha=1.0, device=None):
80 |     """
81 |     Factory function for creating a trainer for mixup augmented models.
82 |     It expects the model to return two outputs (e.g. the Inception3 logits and aux logits).
83 |
84 |     Args:
85 |         model (`torch.nn.Module`): the model to train
86 |         optimizer (`torch.optim.Optimizer`): the optimizer to use
87 |         loss_fn (torch.nn loss function): the loss function to use
88 |         alpha (float, optional): parameter of the ``Beta(alpha, alpha)`` mixing distribution (default: 1.0)
89 |         device (str, optional): device type specification (default: None).
90 |             Applies to both model and batches.
91 |
92 |     Returns:
93 |         Engine: a trainer engine with supervised update function
94 |     """
95 |     if device:
96 |         model.to(device)
97 |
98 |     def _update(engine, batch):
99 |
100 |         from ignite.engine import _prepare_batch
101 |
102 |         model.train()
103 |         optimizer.zero_grad()
104 |
105 |         inputs, targets = _prepare_batch(batch, device=device)
106 |
107 |         inputs, targets_a, targets_b, lam = mixup_data(inputs, targets,
108 |                                                        alpha, use_cuda=(device=="cuda"))
109 |         outputs1, outputs2 = model(inputs)
110 |
111 |         loss1 = mixup_criterion(loss_fn, outputs1, targets_a, targets_b, lam)
112 |         loss2 = mixup_criterion(loss_fn, outputs2, targets_a, targets_b, lam)
113 |
114 |         loss = loss1 + loss2
115 |
116 |         loss.backward()
117 |         optimizer.step()
118 |         return loss.item()
119 |
120 |     return Engine(_update)
121 |
122 |
123 |
-------------------------------------------------------------------------------- /oneshot/oneshot.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Main module."""
4 |
-------------------------------------------------------------------------------- /oneshot/pytorch/__init__.py: --------------------------------------------------------------------------------
1 | from .datasets import InverseDataset
2 | from .datasets import ZippedDataLoader
3 | from .losses import FocalLoss
-------------------------------------------------------------------------------- /oneshot/pytorch/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/pytorch/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/pytorch/__pycache__/datasets.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/pytorch/__pycache__/datasets.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/pytorch/__pycache__/losses.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/pytorch/__pycache__/losses.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/pytorch/datasets.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class ZippedDataLoader(DataLoader):
9 |     """Wrapper class for zipping together several dataloaders.
10 |     """
11 |     def __init__(self, *data_loaders):
12 |         self.data_loaders = data_loaders
13 |
14 |     def __setattr__(self, attr, val):
15 |         if attr == "data_loaders":
16 |             super(ZippedDataLoader, self).__setattr__(attr, val)
17 |         else:
18 |             for data_loader in self.data_loaders:
19 |                 data_loader.__setattr__(attr, val)
20 |
21 |     def __iter__(self):
22 |         return zip(*[d.__iter__() for d in self.data_loaders])
23 |
24 |     def __len__(self):
25 |         return min(len(d) for d in self.data_loaders)
26 |
27 |
28 | class InverseDataset(Dataset):
29 |     """Wrapper class for inverting a dataset.
30 |     """
31 |
32 |     def __init__(self, dataset: Dataset):
33 |
34 |         self.dataset = dataset
35 |
36 |     def __getitem__(self, index: int):
37 |         """
38 |         Args:
39 |             index (int): Index
40 |         Returns:
41 |             tuple: (sample, target) where target is class_index of the target class.
42 |         """
43 |
44 |         output = self.dataset[len(self.dataset) - index - 1]
45 |
46 |         return output
47 |
48 |     def __len__(self) -> int:
49 |         return len(self.dataset)
50 |
51 |
52 | class FocalLoss(nn.Module):
53 |     """Implement Focal Loss.
54 |
55 |     Args:
56 |         gamma (int, optional): focusing parameter (default: 2).
57 |     """
58 |     def __init__(self, gamma: int=2):
59 |         super(FocalLoss, self).__init__()
60 |         self.gamma = gamma
61 |
62 |     def focal_loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
63 |         """Focal loss
64 |
65 |         Args:
66 |             x: (tensor) sized [N, D].
67 |             y: (tensor) sized [N, D].
68 |
69 |         Return:
70 |             (tensor) focal loss.
71 |         """
72 |
73 |         # Compute the weights from detached values so that no gradient flows
74 |         # through them (assigning ``requires_grad`` on a non-leaf tensor fails).
75 |         p = torch.sigmoid(x.detach())
76 |         pt = p * y + (1 - p) * (1 - y)
77 |         w = (1 - pt).pow(self.gamma)
78 |         w = F.normalize(w, p=1, dim=1)
79 |
80 |         return F.binary_cross_entropy_with_logits(x, y, w, reduction='sum')
81 |
82 |     def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
83 |         """Compute focal loss between preds and targets.
84 |
85 |         Args:
86 |             preds: (tensor) predicted labels, sized [batch_size, classes_num].
87 |             targets: (tensor) target labels, sized [batch_size, classes_num].
88 |         """
89 |
90 |         cls_loss = self.focal_loss(preds, targets)
91 |
92 |         return cls_loss
-------------------------------------------------------------------------------- /oneshot/pytorch/losses.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class ZippedDataLoader(DataLoader):
9 |     """Wrapper class for zipping together several dataloaders.
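10 |
11 |     Example (an illustrative sketch; ``loader_a`` and ``loader_b`` are
12 |     assumed to be existing ``DataLoader`` instances)::
13 |
14 |         for batch_a, batch_b in ZippedDataLoader(loader_a, loader_b):
15 |             ...  # one batch from each wrapped loader, per step
16 |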
17 |     """
18 |     def __init__(self, *data_loaders):
19 |         self.data_loaders = data_loaders
20 |
21 |     def __setattr__(self, attr, val):
22 |         if attr == "data_loaders":
23 |             super(ZippedDataLoader, self).__setattr__(attr, val)
24 |         else:
25 |             for data_loader in self.data_loaders:
26 |                 data_loader.__setattr__(attr, val)
27 |
28 |     def __iter__(self):
29 |         return zip(*[d.__iter__() for d in self.data_loaders])
30 |
31 |     def __len__(self):
32 |         return min(len(d) for d in self.data_loaders)
33 |
34 |
35 | class InverseDataset(Dataset):
36 |     """Wrapper class for inverting a dataset.
37 |     """
38 |
39 |     def __init__(self, dataset: Dataset):
40 |
41 |         self.dataset = dataset
42 |
43 |     def __getitem__(self, index: int):
44 |         """
45 |         Args:
46 |             index (int): Index
47 |         Returns:
48 |             tuple: (sample, target) where target is class_index of the target class.
49 |         """
50 |
51 |         output = self.dataset[len(self.dataset) - index - 1]
52 |
53 |         return output
54 |
55 |     def __len__(self) -> int:
56 |         return len(self.dataset)
57 |
58 |
59 | class FocalLoss(nn.Module):
60 |     """Implement Focal Loss.
61 |
62 |     Args:
63 |         gamma (int, optional): focusing parameter (default: 2).
64 |         alpha (float, optional): balancing factor between positive and negative examples (default: 0.25).
65 |         mask (tensor, optional): optional mask that zeroes out unused classes.
66 |     """
67 |     def __init__(self, gamma: int=2, alpha: float=0.25, mask: torch.Tensor=None):
68 |         super(FocalLoss, self).__init__()
69 |         self.gamma = gamma
70 |         self.alpha = alpha
71 |         if mask is None:
72 |             self.register_buffer("mask", torch.ones(1))
73 |         else:
74 |             self.register_buffer("mask", mask)
75 |
76 |     def focal_loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
77 |         """Focal loss
78 |
79 |         Args:
80 |             x: (tensor) sized [N, D].
81 |             y: (tensor) sized [N, D].
82 |
83 |         Return:
84 |             (tensor) focal loss.
85 |         """
86 |
87 |         var_x = x.detach()
88 |         p = torch.sigmoid(var_x)
89 |         pt = p * y + (1 - p) * (1 - y)
90 |         w = self.alpha * y + (1 - self.alpha) * (1 - y)
91 |         w = w * (1 - pt).pow(self.gamma)
92 |         w = F.normalize(w, p=1, dim=1)
93 |         w = w * self.mask
94 |         # The weights are computed from detached values, so no gradient flows through them.
95 |
96 |         return F.binary_cross_entropy_with_logits(x, y, w, reduction='sum')
97 |
98 |     def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
99 |         """Compute focal loss between preds and targets.
100 |
101 |         Args:
102 |             preds: (tensor) predicted labels, sized [batch_size, classes_num].
103 |             targets: (tensor) target labels, sized [batch_size, classes_num].
104 |         """
105 |
106 |         cls_loss = self.focal_loss(preds, targets)
107 |
108 |         return cls_loss
-------------------------------------------------------------------------------- /oneshot/setops_models/__init__.py: --------------------------------------------------------------------------------
1 | """Basic blocks for set-operations networks (LaSO).
2 | """ 3 | 4 | from .ae_setops import SetOpsAEModule 5 | 6 | from .discriminators import AmitDiscriminator 7 | from .discriminators import Discriminator1Layer 8 | from .discriminators import Discriminator2Layer 9 | 10 | from .inception import Inception3 11 | from .inception import Inception3Classifier 12 | from .inception import Inception3SpatialAdapter 13 | from .inception import Inception3SpatialAdapter_6e 14 | from .inception import inception3_ids 15 | from .inception import Inception3_6e 16 | from .inception import SpatialConvolution 17 | from .inception import SpatialConvolution_v1 18 | from .inception import SetopsSpatialAdapter 19 | from .inception import SetopsSpatialAdapter_v1 20 | from .inception import SetopsSpatialAdapter_6e 21 | 22 | from .res_setops import SetopResBasicBlock 23 | from .res_setops import SetopResBasicBlock_v1 24 | from .res_setops import SetopResBlock 25 | from .res_setops import SetopResBlock_v1 26 | from .res_setops import SetopResBlock_v2 27 | from .res_setops import SetOpsResModule 28 | 29 | from .resnet import ResNet 30 | from .resnet import ResNetClassifier 31 | from .resnet import resnet18 32 | from .resnet import resnet18_ids 33 | from .resnet import resnet18_ids_pre_v2 34 | from .resnet import resnet18_ids_pre_v3 35 | from .resnet import resnet34 36 | from .resnet import resnet34_ids 37 | from .resnet import resnet34_v2 38 | from .resnet import resnet34_ids_pre 39 | from .resnet import resnet34_ids_pre_v2 40 | from .resnet import resnet34_ids_pre_v3 41 | from .resnet import resnet50 42 | from .resnet import resnet50_ids_pre_v3 43 | from .resnet import resnet101 44 | from .resnet import resnet152 45 | 46 | from .setops import AttrsClassifier 47 | from .setops import AttrsClassifier_v2 48 | from .setops import CelebAAttrClassifier 49 | from .setops import IDsEmbedding 50 | from .setops import SetOpBlock 51 | from .setops import SetOpBlock_v2 52 | from .setops import SetOpBlock_v3 53 | from .setops import SetOpBlock_v4 54 | from .setops import SetOpBlock_v5 55 | from .setops import PaperGenerator 56 | from .setops import SetOpsModule 57 | from .setops import SetOpsModule_v2 58 | from .setops import SetOpsModule_v3 59 | from .setops import SetOpsModule_v4 60 | from .setops import SetOpsModule_v5 61 | from .setops import SetOpsModule_v6 62 | from .setops import SetOpsModulePaper 63 | from .setops import TopLayer -------------------------------------------------------------------------------- /oneshot/setops_models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/setops_models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/setops_models/__pycache__/ae_setops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/setops_models/__pycache__/ae_setops.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/setops_models/__pycache__/discriminators.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/setops_models/__pycache__/discriminators.cpython-36.pyc 
-------------------------------------------------------------------------------- /oneshot/setops_models/__pycache__/inception.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/setops_models/__pycache__/inception.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/setops_models/__pycache__/res_setops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/setops_models/__pycache__/res_setops.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/setops_models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/setops_models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/setops_models/__pycache__/setops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/setops_models/__pycache__/setops.cpython-36.pyc -------------------------------------------------------------------------------- /oneshot/setops_models/ae_setops.py: --------------------------------------------------------------------------------
1 | """Auto Encoder set operations.
2 | """
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import sys
8 |
9 |
10 | class BasicLayer(nn.Module):
11 |     """A basic linear++ layer.
12 |
13 |     Applies Linear+BN+leaky-relu on the input.
14 |     """
15 |     def __init__(self, dim, **kwargs):
16 |         super(BasicLayer, self).__init__()
17 |
18 |         self.fc = nn.Linear(dim, dim)
19 |         self.bn = nn.BatchNorm1d(dim)
20 |         self.relu = nn.LeakyReLU(0.2, inplace=True)
21 |
22 |     def forward(self, x):
23 |
24 |         out = self.fc(x)
25 |         out = self.bn(out)
26 |         out = self.relu(out)
27 |
28 |         return out
29 |
30 |
31 | class SetopEncoderDecoder(nn.Module):
32 |     """Basic Set-Operation Encoder Decoder Module.
33 |
34 |     Args:
35 |         input_dim: dimension of the input feature vector.
36 |         latent_dim, output_dim: dimensions of the hidden layers and of the output.
37 |         layers_num, dropout_ratio: number of ``BasicLayer`` blocks and the dropout probability.
38 |     """
39 |
40 |     def __init__(
41 |             self,
42 |             input_dim: int,
43 |             latent_dim: int,
44 |             output_dim: int,
45 |             layers_num: int,
46 |             dropout_ratio: float,
47 |             **kwargs):
48 |
49 |         super(SetopEncoderDecoder, self).__init__()
50 |
51 |         self.in_net = nn.Sequential(
52 |             nn.Linear(input_dim, latent_dim),
53 |             nn.BatchNorm1d(latent_dim),
54 |             nn.LeakyReLU(0.2, inplace=True),
55 |         )
56 |
57 |         #
58 |         # Build the network.
59 | # 60 | self.layers = [] 61 | for i in range(layers_num): 62 | layer_name = "ae_layer{}".format(i) 63 | setattr(self, layer_name, BasicLayer(latent_dim, **kwargs)) 64 | self.layers.append(layer_name) 65 | 66 | self.out_net = nn.Sequential( 67 | nn.Linear(latent_dim, output_dim), 68 | nn.Dropout(dropout_ratio), 69 | nn.ReLU(inplace=True), 70 | ) 71 | 72 | def forward(self, x: torch.Tensor) -> torch.Tensor: 73 | 74 | out = self.in_net(x) 75 | 76 | for layer_name in self.layers: 77 | layer = getattr(self, layer_name) 78 | out = layer(out) 79 | 80 | out = self.out_net(out) 81 | 82 | return out 83 | 84 | 85 | def subrelu(x, y): 86 | return F.relu(x-y) 87 | 88 | 89 | class SetOpsAEModule(nn.Module): 90 | def __init__( 91 | self, 92 | input_dim: int, 93 | latent_dim: int, 94 | encoder_dim: int, 95 | layers_num: int, 96 | encoder_cls_name: str="SetopEncoderDecoder", 97 | decoder_cls_name: str="SetopEncoderDecoder", 98 | dropout_ratio: float=0.5, 99 | **kwargs): 100 | 101 | super(SetOpsAEModule, self).__init__() 102 | 103 | encoder_cls = getattr(sys.modules[__name__], encoder_cls_name) 104 | decoder_cls = getattr(sys.modules[__name__], decoder_cls_name) 105 | 106 | self.encoder = encoder_cls( 107 | input_dim=input_dim, 108 | latent_dim=latent_dim, 109 | output_dim=encoder_dim, 110 | layers_num=layers_num, 111 | dropout_ratio=dropout_ratio, 112 | **kwargs 113 | ) 114 | self.decoder = decoder_cls( 115 | input_dim=encoder_dim, 116 | latent_dim=latent_dim, 117 | output_dim=input_dim, 118 | layers_num=layers_num, 119 | dropout_ratio=dropout_ratio, 120 | **kwargs 121 | ) 122 | 123 | self.subtract_op = subrelu 124 | self.intersect_op = torch.min 125 | self.union_op = torch.add 126 | 127 | def forward(self, a, b): 128 | 129 | a = self.encoder(a) 130 | b = self.encoder(b) 131 | 132 | a_S_b = self.subtract_op(a, b) 133 | b_S_a = self.subtract_op(b, a) 134 | 135 | a_S_b_b = self.subtract_op(a_S_b, b) 136 | b_S_a_a = self.subtract_op(b_S_a, a) 137 | 138 | a_I_b = self.intersect_op(a, b) 139 | b_I_a = self.intersect_op(b, a) 140 | 141 | a_S_b_I_a = self.subtract_op(a, b_I_a) 142 | b_S_a_I_b = self.subtract_op(b, a_I_b) 143 | a_S_a_I_b = self.subtract_op(a, a_I_b) 144 | b_S_b_I_a = self.subtract_op(b, b_I_a) 145 | 146 | a_I_b_b = self.intersect_op(a_I_b, b) 147 | b_I_a_a = self.intersect_op(b_I_a, a) 148 | 149 | a_U_b = self.union_op(a, b) 150 | b_U_a = self.union_op(b, a) 151 | 152 | a_U_b_b = self.union_op(a_U_b, b) 153 | b_U_a_a = self.union_op(b_U_a, a) 154 | 155 | out_a = self.union_op(a_S_b_I_a, a_I_b) 156 | out_b = self.union_op(b_S_a_I_b, b_I_a) 157 | 158 | outputs = [out_a, out_b, a_S_b, b_S_a, a_U_b, b_U_a, a_I_b, b_I_a, 159 | a_S_b_b, b_S_a_a, a_I_b_b, b_I_a_a, a_U_b_b, b_U_a_a, 160 | a_S_b_I_a, b_S_a_I_b, a_S_a_I_b, b_S_b_I_a, a, b] 161 | 162 | outputs = [self.decoder(o) for o in outputs] 163 | 164 | return outputs 165 | -------------------------------------------------------------------------------- /oneshot/setops_models/discriminators.py: -------------------------------------------------------------------------------- 1 | """Models of discriminators, i.e. classifiers of unseen classes. 
2 | """ 3 | 4 | from torch import nn 5 | from oneshot.coco import COCO_LABELS_NUM 6 | 7 | 8 | class AmitDiscriminator(nn.Module): 9 | def __init__(self, input_dim, latent_dim, n_classes=COCO_LABELS_NUM, dropout_ratio=0.5, **kwargs): 10 | 11 | super(AmitDiscriminator, self).__init__() 12 | 13 | self.linear_block = nn.Sequential( 14 | nn.Linear(input_dim, latent_dim), 15 | nn.BatchNorm1d(latent_dim), 16 | nn.LeakyReLU(0.2, inplace=True), 17 | nn.Dropout(p=dropout_ratio), 18 | nn.Linear(latent_dim, latent_dim), 19 | nn.BatchNorm1d(latent_dim), 20 | nn.LeakyReLU(0.2, inplace=True), 21 | nn.Dropout(p=dropout_ratio), 22 | nn.Linear(latent_dim, latent_dim), 23 | nn.BatchNorm1d(latent_dim), 24 | nn.LeakyReLU(0.2, inplace=True) 25 | ) 26 | 27 | # 28 | # Output layers 29 | # 30 | self.aux_layer = nn.Sequential(nn.Linear(latent_dim, n_classes)) 31 | 32 | def forward(self, feature_vec): 33 | 34 | out = self.linear_block(feature_vec) 35 | label = self.aux_layer(out) 36 | return label 37 | 38 | 39 | class Discriminator1Layer(nn.Module): 40 | 41 | def __init__(self, input_dim, n_classes=COCO_LABELS_NUM, **kwargs): 42 | super(Discriminator1Layer, self).__init__() 43 | self.fc = nn.Linear(input_dim, n_classes) 44 | 45 | def forward(self, x): 46 | return self.fc(x) 47 | 48 | 49 | class Discriminator2Layer(nn.Module): 50 | 51 | def __init__(self, input_dim, latent_dim, n_classes=COCO_LABELS_NUM, dropout_ratio=0.5, **kwargs): 52 | 53 | super(Discriminator2Layer, self).__init__() 54 | 55 | self.linear_block = nn.Sequential( 56 | nn.Linear(input_dim, latent_dim), 57 | nn.BatchNorm1d(latent_dim), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | nn.Dropout(p=dropout_ratio), 60 | ) 61 | 62 | # 63 | # Output layers 64 | # 65 | self.aux_layer = nn.Sequential(nn.Linear(latent_dim, n_classes)) 66 | 67 | def forward(self, feature_vec): 68 | 69 | out = self.linear_block(feature_vec) 70 | label = self.aux_layer(out) 71 | 72 | return label 73 | -------------------------------------------------------------------------------- /oneshot/setops_models/vae_setops.py: -------------------------------------------------------------------------------- 1 | """Variational Auto Encoder set operations. 
2 |
3 | Taken from:
4 | https://github.com/pytorch/examples/blob/master/vae/main.py
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import sys
11 |
12 |
13 | class VAE(nn.Module):
14 |     def __init__(self, input_dim=2048, vae_dim=128):
15 |         super(VAE, self).__init__()
16 |
17 |         self.fc0 = nn.Linear(input_dim, 784)
18 |         self.fc1 = nn.Linear(784, 400)
19 |         self.fc21 = nn.Linear(400, vae_dim)
20 |         self.fc22 = nn.Linear(400, vae_dim)
21 |         self.fc3 = nn.Linear(vae_dim, 400)
22 |         self.fc4 = nn.Linear(400, 784)
23 |         self.fc5 = nn.Linear(784, input_dim)
24 |
25 |     def encode(self, x):
26 |         h = F.relu(self.fc0(x))
27 |         h = F.relu(self.fc1(h))
28 |         return self.fc21(h), self.fc22(h)
29 |
30 |     def reparameterize(self, mu, logvar):
31 |
32 |         std = torch.exp(0.5*logvar)
33 |         eps = torch.randn_like(std)
34 |
35 |         return eps.mul(std).add_(mu)
36 |
37 |     def decode(self, z):
38 |         h = F.relu(self.fc3(z))
39 |         h = F.relu(self.fc4(h))
40 |         return F.relu(self.fc5(h))
41 |
42 |     def forward(self, x):
43 |         mu, logvar = self.encode(x.view(x.size(0), -1))  # flatten to (batch, input_dim); the original MNIST example hard-coded 784 here
44 |         z = self.reparameterize(mu, logvar)
45 |         return self.decode(z), mu, logvar
46 |
47 |
48 | def loss_function(recon_loss, recon_x, x, mu, logvar):
49 |     """Reconstruction + KL divergence losses summed over all elements and batch
50 |     """
51 |
52 |     loss = recon_loss(recon_x, x)
53 |
54 |     # see Appendix B from VAE paper:
55 |     # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
56 |     # https://arxiv.org/abs/1312.6114
57 |     # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
58 |     KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
59 |
60 |     return loss + KLD
61 |
62 |
63 | def subrelu(x, y):
64 |     return F.relu(x-y)
65 |
66 |
67 | class SetOpsVAEModule(nn.Module):
68 |     def __init__(
69 |             self,
70 |             input_dim: int,
71 |             vae_dim: int,
72 |             vae_cls_name: str="VAE",
73 |             **kwargs):
74 |
75 |         super(SetOpsVAEModule, self).__init__()
76 |
77 |         vae_cls = getattr(sys.modules[__name__], vae_cls_name)
78 |
79 |         self.vae = vae_cls(
80 |             input_dim=input_dim,
81 |             vae_dim=vae_dim,
82 |             **kwargs
83 |         )
84 |
85 |         self.subtract_op = subrelu
86 |         self.intersect_op = torch.min
87 |         self.union_op = torch.add
88 |
89 |     def forward(self, a, b):
90 |
91 |         a, logvar_a = self.vae.encode(a)
92 |         b, logvar_b = self.vae.encode(b)
93 |
94 |         logvar = (logvar_a + logvar_b) / 2
95 |
96 |         a_S_b = self.subtract_op(a, b)
97 |         b_S_a = self.subtract_op(b, a)
98 |
99 |         a_S_b_b = self.subtract_op(a_S_b, b)
100 |         b_S_a_a = self.subtract_op(b_S_a, a)
101 |
102 |         a_I_b = self.intersect_op(a, b)
103 |         b_I_a = self.intersect_op(b, a)
104 |
105 |         a_S_b_I_a = self.subtract_op(a, b_I_a)
106 |         b_S_a_I_b = self.subtract_op(b, a_I_b)
107 |         a_S_a_I_b = self.subtract_op(a, a_I_b)
108 |         b_S_b_I_a = self.subtract_op(b, b_I_a)
109 |
110 |         a_I_b_b = self.intersect_op(a_I_b, b)
111 |         b_I_a_a = self.intersect_op(b_I_a, a)
112 |
113 |         a_U_b = self.union_op(a, b)
114 |         b_U_a = self.union_op(b, a)
115 |
116 |         a_U_b_b = self.union_op(a_U_b, b)
117 |         b_U_a_a = self.union_op(b_U_a, a)
118 |
119 |         out_a = self.union_op(a_S_b_I_a, a_I_b)
120 |         out_b = self.union_op(b_S_a_I_b, b_I_a)
121 |
122 |         outputs = [out_a, out_b, a_S_b, b_S_a, a_U_b, b_U_a, a_I_b, b_I_a,
123 |                    a_S_b_b, b_S_a_a, a_I_b_b, b_I_a_a, a_U_b_b, b_U_a_a,
124 |                    a_S_b_I_a, b_S_a_I_b, a_S_a_I_b, b_S_b_I_a]
125 |
126 |         outputs = [self.vae.decode(self.vae.reparameterize(o, logvar)) for o in outputs]
127 |
128 |         return outputs, a, logvar_a, b, logvar_b
129 |
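130 |
131 | # Illustrative usage sketch (not part of the original module): encode two
132 | # batches of feature vectors, combine them with latent-space set operations,
133 | # and decode each result back to feature space.
134 | #
135 | # model = SetOpsVAEModule(input_dim=2048, vae_dim=128)
136 | # feats_a = torch.randn(4, 2048)
137 | # feats_b = torch.randn(4, 2048)
138 | # outputs, mu_a, logvar_a, mu_b, logvar_b = model(feats_a, feats_b)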
-------------------------------------------------------------------------------- /oneshot/stn.py: -------------------------------------------------------------------------------- 1 | """Spatial Transformer Network. 2 | 3 | Example taken from: https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html 4 | """ 5 | 6 | import logging 7 | import math 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from .wideresnet_places import BasicBlock 12 | #from .wideresnet_places import ResNet 13 | 14 | 15 | class STNResNet(nn.Module): 16 | 17 | def __init__(self, block, layers): 18 | self.inplanes = 64 19 | super(STNResNet, self).__init__() 20 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 21 | bias=False) 22 | self.bn1 = nn.BatchNorm2d(64) 23 | self.relu = nn.ReLU(inplace=True) 24 | 25 | self.layer1 = self._make_layer(block, 64, layers[0]) 26 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 27 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 28 | self.layer4 = self._make_layer(block, 32, layers[3], stride=2) 29 | self.maxpool = nn.MaxPool2d(2, stride=2) 30 | 31 | for m in self.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 34 | m.weight.data.normal_(0, math.sqrt(2. / n)) 35 | elif isinstance(m, nn.BatchNorm2d): 36 | #m.weight.data.fill_(1) 37 | #m.bias.data.zero_() 38 | nn.init.constant_(m.weight, 1) 39 | nn.init.constant_(m.bias, 0) 40 | 41 | def _make_layer(self, block, planes, blocks, stride=1): 42 | downsample = None 43 | if stride != 1 or self.inplanes != planes * block.expansion: 44 | downsample = nn.Sequential( 45 | nn.Conv2d(self.inplanes, planes * block.expansion, 46 | kernel_size=1, stride=stride, bias=False), 47 | nn.BatchNorm2d(planes * block.expansion), 48 | ) 49 | 50 | layers = [] 51 | layers.append(block(self.inplanes, planes, stride, downsample)) 52 | self.inplanes = planes * block.expansion 53 | for i in range(1, blocks): 54 | layers.append(block(self.inplanes, planes)) 55 | 56 | return nn.Sequential(*layers) 57 | 58 | def forward(self, x): 59 | x = self.conv1(x) 60 | x = self.bn1(x) 61 | x = self.relu(x) 62 | #x = self.maxpool(x) 63 | 64 | x = self.layer1(x) 65 | x = self.layer2(x) 66 | x = self.layer3(x) 67 | x = self.layer4(x) 68 | 69 | x = self.maxpool(x) 70 | x = self.relu(x) 71 | 72 | return x 73 | 74 | 75 | class MultipleOptimizer(object): 76 | def __init__(self, *op): 77 | self.optimizers = op 78 | 79 | def zero_grad(self): 80 | for op in self.optimizers: 81 | op.zero_grad() 82 | 83 | def step(self): 84 | for op in self.optimizers: 85 | op.step() 86 | 87 | 88 | class STN(nn.Module): 89 | def __init__(self, reset_fc_loc=True): 90 | super(STN, self).__init__() 91 | 92 | # Spatial transformer localization-network 93 | self.localization = STNResNet( 94 | BasicBlock, 95 | layers=[1, 1, 1, 1], 96 | ) 97 | 98 | # Regressor for the 3 * 2 affine matrix 99 | self.fc_loc = nn.Sequential( 100 | nn.Linear(32*7*7, 64), 101 | nn.ReLU(True), 102 | nn.Linear(64, 3 * 2) 103 | ) 104 | 105 | if reset_fc_loc: 106 | logging.info("Resetting the fc_loc network") 107 | # Initialize the weights/bias with identity transformation 108 | self.fc_loc[2].weight.data.zero_() 109 | self.fc_loc[2].bias.data.copy_( 110 | torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float) 111 | ) 112 | 113 | def forward(self, x): 114 | # 115 | # Calculate the transform 116 | # 117 | xs = self.localization(x) 118 | xs = xs.view(-1, 32*7*7) 119 | 120 | theta = 
self.fc_loc(xs)
121 |         theta = theta.view(-1, 2, 3)
122 |
123 |         grid = F.affine_grid(theta, x.size())
124 |
125 |         #
126 |         # Transform the input
127 |         #
128 |         x = F.grid_sample(x, grid)
129 |
130 |         return x
131 |
132 |
-------------------------------------------------------------------------------- /oneshot/triplet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leokarlin/LaSO/eea81de6046cc0817fc81a1d1570d967c1bb0aef/oneshot/triplet/__init__.py -------------------------------------------------------------------------------- /oneshot/triplet/losses.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ContrastiveLoss(nn.Module):
7 |     """
8 |     Contrastive loss
9 |     Takes embeddings of two samples and a target label: 1 if the samples are from the same class, 0 otherwise
10 |     """
11 |
12 |     def __init__(self, margin):
13 |         super(ContrastiveLoss, self).__init__()
14 |         self.margin = margin
15 |
16 |     def forward(self, output1, output2, target, size_average=True):
17 |         distances = (output2 - output1).pow(2).sum(1)  # squared distances
18 |         losses = 0.5 * (target.float() * distances +
19 |                         (1 + -1 * target).float() * F.relu(self.margin - distances.sqrt()).pow(2))
20 |         return losses.mean() if size_average else losses.sum()
21 |
22 |
23 | class TripletLoss(nn.Module):
24 |     """
25 |     Triplet loss
26 |     Takes embeddings of an anchor sample, a positive sample and a negative sample
27 |     """
28 |
29 |     def __init__(self, margin):
30 |         super(TripletLoss, self).__init__()
31 |         self.margin = margin
32 |
33 |     def forward(self, anchor, positive, negative, size_average=True):
34 |         distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
35 |         distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
36 |         losses = F.relu(distance_positive - distance_negative + self.margin)
37 |         return losses.mean() if size_average else losses.sum()
38 |
39 |
40 | class OnlineContrastiveLoss(nn.Module):
41 |     """Online Contrastive loss
42 |
43 |     Takes a batch of embeddings and corresponding labels.
44 |     Pairs are generated using a pair_selector object that takes embeddings and
45 |     targets and returns indices of positive and negative pairs.
46 |     """
47 |
48 |     def __init__(self, margin, pair_selector):
49 |         super(OnlineContrastiveLoss, self).__init__()
50 |         self.margin = margin
51 |         self.pair_selector = pair_selector
52 |
53 |     def forward(self, embeddings, target):
54 |         positive_pairs, negative_pairs = self.pair_selector.get_pairs(embeddings, target)
55 |         if embeddings.is_cuda:
56 |             positive_pairs = positive_pairs.cuda()
57 |             negative_pairs = negative_pairs.cuda()
58 |         positive_loss = (embeddings[positive_pairs[:, 0]] - embeddings[positive_pairs[:, 1]]).pow(2).sum(1)
59 |         negative_loss = F.relu(
60 |             self.margin - (embeddings[negative_pairs[:, 0]] - embeddings[negative_pairs[:, 1]]).pow(2).sum(
61 |                 1).sqrt()).pow(2)
62 |         loss = torch.cat([positive_loss, negative_loss], dim=0)
63 |         return loss.mean()
64 |
65 |
66 | class OnlineTripletLoss(nn.Module):
67 |     """Online Triplets loss
68 |
69 |     Takes a batch of embeddings and corresponding labels.
70 |     Triplets are generated using a triplet_selector object that takes embeddings
71 |     and targets and returns indices of triplets.
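72 |
73 |     Example (an illustrative sketch; ``selector`` is an existing triplet
74 |     selector, and ``embeddings``/``labels`` come from the current batch)::
75 |
76 |         criterion = OnlineTripletLoss(margin=1.0, triplet_selector=selector)
77 |         loss = criterion(embeddings, labels)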
78 |     """
79 |
80 |     def __init__(self, margin, triplet_selector, as_metric=False):
81 |         super(OnlineTripletLoss, self).__init__()
82 |         self.margin = margin
83 |         self.triplet_selector = triplet_selector
84 |         self.as_metric = as_metric
85 |
86 |     def forward(self, embeddings, target):
87 |
88 |         triplets = self.triplet_selector.get_triplets(embeddings, target)
89 |
90 |         if embeddings.is_cuda:
91 |             triplets = triplets.cuda()
92 |
93 |         ap_distances = (embeddings[triplets[:, 0]] - embeddings[triplets[:, 1]]).pow(2).sum(1)  # .pow(.5)
94 |         an_distances = (embeddings[triplets[:, 0]] - embeddings[triplets[:, 2]]).pow(2).sum(1)  # .pow(.5)
95 |         losses = F.relu(ap_distances - an_distances + self.margin)
96 |
97 |         if self.as_metric:
98 |             #
99 |             # Return the number of hard negatives found in this batch, which
100 |             # serves as a kind of metric. Note the hack to make it a
101 |             # zero-dim tensor.
102 |             #
103 |             return torch.Tensor((len(triplets),))[0]
104 |
105 |         return losses.mean()
-------------------------------------------------------------------------------- /oneshot/triplet/trainer.py: --------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval, metrics=[],
6 |         start_epoch=0):
7 |     """
8 |     Loaders, model, loss function and metrics should work together for a given task,
9 |     i.e. the model should be able to process data output of loaders,
10 |     loss function should process target output of loaders and outputs from the model
11 |     Examples: Classification: batch loader, classification model, NLL loss, accuracy metric
12 |     Siamese network: Siamese loader, siamese model, contrastive loss
13 |     Online triplet learning: batch loader, embedding model, online triplet loss
14 |     """
15 |     for epoch in range(0, start_epoch):
16 |         scheduler.step()
17 |
18 |     for epoch in range(start_epoch, n_epochs):
19 |         scheduler.step()
20 |
21 |         # Train stage
22 |         train_loss, metrics = train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics)
23 |
24 |         message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(epoch + 1, n_epochs, train_loss)
25 |         for metric in metrics:
26 |             message += '\t{}: {}'.format(metric.name(), metric.value())
27 |
28 |         val_loss, metrics = test_epoch(val_loader, model, loss_fn, cuda, metrics)
29 |         val_loss /= len(val_loader)
30 |
31 |         message += '\nEpoch: {}/{}. 
Validation set: Average loss: {:.4f}'.format(epoch + 1, n_epochs, 32 | val_loss) 33 | for metric in metrics: 34 | message += '\t{}: {}'.format(metric.name(), metric.value()) 35 | 36 | print(message) 37 | 38 | 39 | def train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics): 40 | for metric in metrics: 41 | metric.reset() 42 | 43 | model.train() 44 | losses = [] 45 | total_loss = 0 46 | 47 | for batch_idx, (data, target) in enumerate(train_loader): 48 | target = target if len(target) > 0 else None 49 | if not type(data) in (tuple, list): 50 | data = (data,) 51 | if cuda: 52 | data = tuple(d.cuda() for d in data) 53 | if target is not None: 54 | target = target.cuda() 55 | 56 | 57 | optimizer.zero_grad() 58 | outputs = model(*data) 59 | 60 | if type(outputs) not in (tuple, list): 61 | outputs = (outputs,) 62 | 63 | loss_inputs = outputs 64 | if target is not None: 65 | target = (target,) 66 | loss_inputs += target 67 | 68 | loss_outputs = loss_fn(*loss_inputs) 69 | loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs 70 | losses.append(loss.item()) 71 | total_loss += loss.item() 72 | loss.backward() 73 | optimizer.step() 74 | 75 | for metric in metrics: 76 | metric(outputs, target, loss_outputs) 77 | 78 | if batch_idx % log_interval == 0: 79 | message = 'Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 80 | batch_idx * len(data[0]), len(train_loader.dataset), 81 | 100. * batch_idx / len(train_loader), np.mean(losses)) 82 | for metric in metrics: 83 | message += '\t{}: {}'.format(metric.name(), metric.value()) 84 | 85 | print(message) 86 | losses = [] 87 | 88 | total_loss /= (batch_idx + 1) 89 | return total_loss, metrics 90 | 91 | 92 | def test_epoch(val_loader, model, loss_fn, cuda, metrics): 93 | with torch.no_grad(): 94 | for metric in metrics: 95 | metric.reset() 96 | model.eval() 97 | val_loss = 0 98 | for batch_idx, (data, target) in enumerate(val_loader): 99 | target = target if len(target) > 0 else None 100 | if not type(data) in (tuple, list): 101 | data = (data,) 102 | if cuda: 103 | data = tuple(d.cuda() for d in data) 104 | if target is not None: 105 | target = target.cuda() 106 | 107 | outputs = model(*data) 108 | 109 | if type(outputs) not in (tuple, list): 110 | outputs = (outputs,) 111 | loss_inputs = outputs 112 | if target is not None: 113 | target = (target,) 114 | loss_inputs += target 115 | 116 | loss_outputs = loss_fn(*loss_inputs) 117 | loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs 118 | val_loss += loss.item() 119 | 120 | for metric in metrics: 121 | metric(outputs, target, loss_outputs) 122 | 123 | return val_loss, metrics -------------------------------------------------------------------------------- /oneshot/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | class conditional(object): 3 | """Wrap another context manager and enter it only if condition is true. 4 | """ 5 | 6 | def __init__(self, condition, contextmanager): 7 | self.condition = condition 8 | self.contextmanager = contextmanager 9 | 10 | def __enter__(self): 11 | if self.condition: 12 | return self.contextmanager.__enter__() 13 | 14 | def __exit__(self, *args): 15 | if self.condition: 16 | return self.contextmanager.__exit__(*args) 17 | 18 | 19 | def setupCUDAdevice(cuda_visible_device=None): 20 | """Setup `CUDA_VISIBLE_DEVICES` environment variable. 
21 |
22 |     The `CUDA_VISIBLE_DEVICES` environment variable is used for determining the
23 |     GPU available to a process. It is automatically set by the `jbsub` command.
24 |     In some situations a user would like to connect to a GPU node not through
25 |     a job, e.g. using a remote debugger. In such cases the `CUDA_VISIBLE_DEVICES`
26 |     will not be set, even if the user has secured a GPU through a separate
27 |     interactive session. `setupCUDAdevice` can be used to query the user for the
28 |     available device number. It does nothing when `CUDA_VISIBLE_DEVICES` is
29 |     already set.
30 |
31 |     Args:
32 |         cuda_visible_device (int): Hard code a device number (will bypass
33 |             user input).
34 |
35 |     .. note::
36 |         A hard-coded `cuda_visible_device` should be used carefully.
37 |     """
38 |
39 |     if "CUDA_VISIBLE_DEVICES" not in os.environ or \
40 |             os.environ["CUDA_VISIBLE_DEVICES"] == "":
41 |         if cuda_visible_device is None:
42 |             cuda_visible_device = input("CUDA_VISIBLE_DEVICES: ")
43 |
44 |         os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_visible_device)
45 |
-------------------------------------------------------------------------------- /requirements_dev.txt: --------------------------------------------------------------------------------
1 | pip==9.0.1
2 | bumpversion==0.5.3
3 | wheel==0.30.0
4 | watchdog==0.8.3
5 | flake8==3.5.0
6 | tox==2.9.1
7 | coverage==4.5.1
8 | Sphinx==1.7.1
9 | twine==1.10.0
10 |
11 |
12 |
-------------------------------------------------------------------------------- /scripts_coco/example_use.py: --------------------------------------------------------------------------------
1 | '''
2 | Example of how to use the LaSO model in your code.
3 | '''
4 |
5 | # import main from the appropriate script
6 | from scripts_coco.train_setops_stripped import Main
7 | # from scripts_coco.test_retrieval import Main
8 | # from scripts_coco.test_precision import Main
9 |
10 | # define an instance of the main class
11 | main_instance = Main()
12 | # define model parameters
13 | main_instance.coco_path = '/dccstor/leonidka1/data/coco'
14 | main_instance.epochs = 50
15 | # initialize model
16 | main_instance.initialize()
17 | # run the model - test / train
18 | main_instance.start()
19 |
20 | # test output folder structure
21 | '''
22 | (kef) [amitalfa@dccxc203 initial_layers]$ ll /dccstor/alfassy/results/test_retrieval/0218_8977e29/681989/190428_081951/
23 | total 101137
24 | -rw-r--r-- 1 amitalfa users 113 Apr 28 08:19 cmdline.txt
25 | -rw-r--r-- 1 amitalfa users 4684 Apr 28 08:19 config.py
26 | -rw-r--r-- 1 amitalfa users 192518 Apr 28 08:19 git_diff.txt
27 | -rw-r--r-- 1 amitalfa users 15919239 Apr 28 10:21 results_a_I_b.pkl
28 | -rw-r--r-- 1 amitalfa users 16202039 Apr 28 10:59 results_a.pkl
29 | -rw-r--r-- 1 amitalfa users 7936839 Apr 28 09:01 results_a_S_b.pkl
30 | -rw-r--r-- 1 amitalfa users 11644839 Apr 28 09:41 results_a_U_b.pkl
31 | -rw-r--r-- 1 amitalfa users 15919239 Apr 28 10:41 results_b_I_a.pkl
32 | -rw-r--r-- 1 amitalfa users 16202039 Apr 28 11:18 results_b.pkl
33 | -rw-r--r-- 1 amitalfa users 7840439 Apr 28 09:21 results_b_S_a.pkl
34 | -rw-r--r-- 1 amitalfa users 11644839 Apr 28 10:01 results_b_U_a.pkl
35 | -rw-r--r-- 1 amitalfa users 2001 Apr 28 11:18 script_log
36 | '''
37 |
38 | # test output script_log example
39 | '''
40 | 2019-04-28 08:19:51,766 [MainThread ] [INFO ] Created results path: /dccstor/alfassy/results/test_retrieval/0218_8977e29/681989/190428_081951
41 | 2019-04-28 08:19:51,787 [MainThread ] [INFO ] Setup the models.
42 | 2019-04-28 08:19:51,787 [MainThread ] [INFO ] Inception3 model 43 | 2019-04-28 08:19:53,886 [MainThread ] [INFO ] Initialize inception model using Amit's networks. 44 | 2019-04-28 08:20:08,524 [MainThread ] [INFO ] Resuming the models. 45 | 2019-04-28 08:20:08,524 [MainThread ] [INFO ] using paper models 46 | 2019-04-28 08:20:12,248 [MainThread ] [INFO ] Setting up the datasets. 47 | 2019-04-28 08:20:12,315 [MainThread ] [INFO ] Copying data to tmp 48 | 2019-04-28 08:22:54,842 [MainThread ] [INFO ] Calculating indices. 49 | 2019-04-28 08:23:04,983 [MainThread ] [INFO ] Calculate the validation embeddings. 50 | 2019-04-28 08:29:16,558 [MainThread ] [INFO ] Calculate the embedding NN BallTree. 51 | 2019-04-28 08:29:23,840 [MainThread ] [INFO ] Calculate test set embedding. 52 | 2019-04-28 08:41:26,312 [MainThread ] [INFO ] Calculate scores. 53 | 2019-04-28 09:01:16,456 [MainThread ] [INFO ] Test a_S_b average recall (k=1, 3, 5): [0.15774941 0.30975018 0.39174374] 54 | 2019-04-28 09:21:36,746 [MainThread ] [INFO ] Test b_S_a average recall (k=1, 3, 5): [0.15835367 0.30721486 0.3891798 ] 55 | 2019-04-28 09:41:55,912 [MainThread ] [INFO ] Test a_U_b average recall (k=1, 3, 5): [0.62102854 0.72747891 0.76325716] 56 | 2019-04-28 10:01:54,786 [MainThread ] [INFO ] Test b_U_a average recall (k=1, 3, 5): [0.62122391 0.72644851 0.76149227] 57 | 2019-04-28 10:21:42,917 [MainThread ] [INFO ] Test a_I_b average recall (k=1, 3, 5): [0.67967224 0.78322097 0.81947849] 58 | 2019-04-28 10:41:30,038 [MainThread ] [INFO ] Test b_I_a average recall (k=1, 3, 5): [0.67914967 0.78431402 0.81867714] 59 | 2019-04-28 10:59:48,137 [MainThread ] [INFO ] Test a average recall (k=1, 3, 5): [0.61569753 0.72506626 0.76083887] 60 | 2019-04-28 11:18:05,401 [MainThread ] [INFO ] Test b average recall (k=1, 3, 5): [0.61773566 0.72654131 0.76237568] 61 | ''' 62 | 63 | # train output folder structure 64 | ''' 65 | (kef) [amitalfa@dccxc203 initial_layers]$ ll /dccstor/alfassy/results/train_setops_stripped/0218_8977e29/1650167/190507_101157/ 66 | total 1158881 67 | -rw-r--r-- 1 amitalfa users 427 May 7 10:11 cmdline.txt 68 | -rw-r--r-- 1 amitalfa users 6312 May 7 10:11 config.py 69 | -rw-r--r-- 1 amitalfa users 262038 May 7 10:11 git_diff.txt 70 | -rw------- 1 amitalfa users 94301854 May 7 15:20 networks_base_model_2.pth 71 | -rw------- 1 amitalfa users 94301854 May 7 15:20 networks_base_model_2_val_acc=0.529.pth 72 | -rw------- 1 amitalfa users 94301854 May 7 17:57 networks_base_model_3_val_acc=0.526.pth 73 | -rw------- 1 amitalfa users 94301854 May 7 20:33 networks_base_model_4.pth 74 | -rw------- 1 amitalfa users 656244 May 7 15:20 networks_classifier_2.pth 75 | -rw------- 1 amitalfa users 656244 May 7 15:20 networks_classifier_2_val_acc=0.529.pth 76 | -rw------- 1 amitalfa users 656244 May 7 17:57 networks_classifier_3_val_acc=0.526.pth 77 | -rw------- 1 amitalfa users 656244 May 7 20:33 networks_classifier_4.pth 78 | -rw------- 1 amitalfa users 201606120 May 7 15:20 networks_setops_model_2.pth 79 | -rw------- 1 amitalfa users 201606120 May 7 15:20 networks_setops_model_2_val_acc=0.529.pth 80 | -rw------- 1 amitalfa users 201606120 May 7 17:57 networks_setops_model_3_val_acc=0.526.pth 81 | -rw------- 1 amitalfa users 201606120 May 7 20:33 networks_setops_model_4.pth 82 | -rw-r--r-- 1 amitalfa users 5284 May 7 23:10 script_log 83 | ''' 84 | 85 | 86 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 
| [bumpversion] 2 | current_version = 0.1.0 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = version='{current_version}' 8 | replace = version='{new_version}' 9 | 10 | [bumpversion:file:oneshot/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 14 | [bdist_wheel] 15 | universal = 1 16 | 17 | [flake8] 18 | exclude = docs 19 | 20 | [aliases] 21 | # Define setup.py command aliases here 22 | 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """The setup script.""" 5 | import os 6 | import io 7 | import re 8 | from setuptools import setup, find_packages 9 | 10 | with open('README_oneshot.rst') as readme_file: 11 | readme = readme_file.read() 12 | 13 | with open('HISTORY_oneshot.rst') as history_file: 14 | history = history_file.read() 15 | 16 | requirements = [ ] 17 | 18 | setup_requirements = [ ] 19 | 20 | test_requirements = [ ] 21 | 22 | setup( 23 | author="Amit Aides", 24 | author_email='amitaid@il.ibm.com', 25 | classifiers=[ 26 | 'Development Status :: 2 - Pre-Alpha', 27 | 'Intended Audience :: Developers', 28 | 'License :: OSI Approved :: MIT License', 29 | 'Natural Language :: English', 30 | "Programming Language :: Python :: 2", 31 | 'Programming Language :: Python :: 2.7', 32 | 'Programming Language :: Python :: 3', 33 | 'Programming Language :: Python :: 3.4', 34 | 'Programming Language :: Python :: 3.5', 35 | 'Programming Language :: Python :: 3.6', 36 | 'Programming Language :: Python :: 3.7', 37 | ], 38 | description="Experiments in One-Shot deeplearning", 39 | install_requires=requirements, 40 | license="MIT license", 41 | long_description=readme + '\n\n' + history, 42 | include_package_data=True, 43 | keywords='oneshot', 44 | name='oneshot', 45 | packages=find_packages(include=['oneshot']), 46 | setup_requires=setup_requirements, 47 | test_suite='tests', 48 | tests_require=test_requirements, 49 | url='https://github.com/amitaid/oneshot', 50 | version='0.1.0', 51 | zip_safe=False, 52 | ) 53 | 54 | 55 | with open('README_experiment.rst') as readme_file: 56 | readme = readme_file.read() 57 | 58 | with open('HISTORY_experiment.rst') as history_file: 59 | history = history_file.read() 60 | 61 | requirements = ['py3nvml', 'traitlets'] 62 | 63 | setup_requirements = ['pytest-runner', ] 64 | 65 | test_requirements = ['pytest', ] 66 | 67 | setup( 68 | author="Amit Aides", 69 | author_email='amitaid@il.ibm.com', 70 | classifiers=[ 71 | 'Development Status :: 2 - Pre-Alpha', 72 | 'Intended Audience :: Developers', 73 | 'License :: OSI Approved :: MIT License', 74 | 'Natural Language :: English', 75 | 'Programming Language :: Python :: 3.5', 76 | 'Programming Language :: Python :: 3.6', 77 | 'Programming Language :: Python :: 3.7', 78 | ], 79 | description="Framework for running experiments.", 80 | install_requires=requirements, 81 | license="MIT license", 82 | long_description=readme + '\n\n' + history, 83 | include_package_data=True, 84 | keywords='experiment', 85 | name='experiment', 86 | packages=find_packages(include=['experiment']), 87 | setup_requires=setup_requirements, 88 | test_suite='tests', 89 | tests_require=test_requirements, 90 | url='https://github.ibm.com/AMITAID/experiment', 91 | version='0.2.0', 92 | zip_safe=False, 93 | ) 94 | 95 | def read(*names, **kwargs): 96 | with 
io.open(os.path.join(os.path.dirname(__file__), *names), 97 | encoding=kwargs.get("encoding", "utf8")) as fp: 98 | return fp.read() 99 | 100 | 101 | readme = read('README_ignite.rst') 102 | 103 | requirements = ['enum34;python_version<"3.4"', 'futures; python_version == "2.7"', 'torch'] 104 | 105 | setup( 106 | # Metadata 107 | name='ignite', 108 | version='0.1.2', 109 | author='PyTorch Core Team', 110 | author_email='soumith@pytorch.org', 111 | url='https://github.com/pytorch/ignite', 112 | description='A lightweight library to help with training neural networks in PyTorch.', 113 | long_description=readme, 114 | license='BSD', 115 | 116 | # Package info 117 | packages=find_packages(exclude=('tests', 'tests.*',)), 118 | 119 | zip_safe=True, 120 | install_requires=requirements, 121 | ) 122 | 123 | --------------------------------------------------------------------------------