├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── __init__.py ├── assets ├── img │ ├── icifar.png │ ├── icifar_expected_stdout.png │ ├── pull.png │ ├── robot.png │ ├── sidetuning_lifelong.png │ └── taskonomy.png └── pytorch │ ├── distillation │ └── fcn4-cifar.pth │ └── resnet44-cifar.pth ├── configs ├── core.py ├── doom.py ├── gibson.py ├── habitat.py ├── habitat_eval.py ├── icifar_cfg.py ├── imitation_learning.py ├── rl.py ├── rl_extra.py ├── seq_taskonomy_cfg.py ├── seq_taskonomy_cfg_extra.py ├── shared.py ├── vision_lifelong.py └── vision_transfer.py ├── evkit ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── env │ ├── __init__.py │ ├── base_embodied_env.py │ ├── distributed_factory.py │ ├── envs.py │ ├── habitat │ │ ├── __init__.py │ │ ├── config │ │ │ ├── __init__.py │ │ │ └── default.py │ │ ├── habitatenv.py │ │ ├── habitatexpenv.py │ │ ├── habitatnavenv.py │ │ ├── utils.py │ │ └── wrapperenv.py │ ├── util │ │ ├── __init__.py │ │ ├── make_env.py │ │ ├── occupancy_map.py │ │ ├── tile_images.py │ │ └── vec_env │ │ │ ├── __init__.py │ │ │ ├── dummy_vec_env.py │ │ │ └── subproc_vec_embodied_env.py │ └── wrappers │ │ ├── __init__.py │ │ ├── preprocessingwrapper.py │ │ ├── sensorenv.py │ │ ├── skip_wrapper.py │ │ ├── tensorboardmonitor.py │ │ └── visdommonitor.py ├── models │ ├── .gitkeep │ ├── __init__.py │ ├── actor_critic_module.py │ ├── actor_critic_module_curiosity.py │ ├── alexnet.py │ ├── architectures.py │ ├── chain.py │ ├── expert.py │ ├── forward_inverse.py │ ├── shortest_path_follower.py │ ├── sidetuning.py │ ├── sparsely_gated_mixture_of_experts.py │ ├── srl_architectures.py │ ├── taskonomy_network.py │ ├── triangle.py │ └── unet.py ├── preprocess │ ├── __init__.py │ ├── baseline_transforms.py │ ├── filters.py │ ├── gaussian.py │ ├── transform_factory.py │ └── transforms.py ├── rl │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── algo │ │ ├── __init__.py │ │ ├── a2c_acktr.py │ │ ├── deepq.py │ │ ├── kfac.py │ │ ├── ppo.py │ │ ├── ppo_curiosity.py │ │ └── ppo_replay.py │ ├── distributions.py │ ├── main.py │ ├── model.py │ ├── policy.py │ ├── preprocessing.py │ ├── requirements.txt │ ├── storage │ │ ├── __init__.py │ │ ├── memory.py │ │ ├── rollout.py │ │ ├── rollout_curiosity.py │ │ ├── segment_tree.py │ │ └── stackedobservation.py │ ├── utils.py │ └── visualize.py ├── saving │ ├── __init__.py │ ├── checkpoints.py │ ├── monitor.py │ ├── naming.py │ ├── observers.py │ └── video.py ├── sensors │ ├── __init__.py │ └── sensorpack.py └── utils │ ├── __init__.py │ ├── logging.py │ ├── losses.py │ ├── misc.py │ ├── parallel.py │ ├── profiler.py │ ├── radam.py │ ├── random.py │ └── viz │ ├── __init__.py │ ├── core.py │ └── rl.py ├── feature_selector ├── __init__.py └── models │ ├── __init__.py │ ├── student_models.py │ └── vision_transfer_architectures.py ├── requirements.txt ├── scripts ├── __init__.py ├── calculate_blind_transfer.py ├── demo_icifar.py ├── demo_taskonomy.py ├── prep │ ├── collect_expert_trajs.py │ ├── copy.sh │ ├── count_num_frames.py │ ├── csv_read.py │ ├── distill.py │ ├── download.sh │ ├── get_reprs.py │ ├── make_fewshot_datasets.py │ ├── make_fewshot_squad.py │ ├── make_masks.py │ ├── move_models.py │ ├── rm_empty_exp.py │ ├── run_distillation.sh │ ├── shrink_images.py │ ├── shrink_rgb.sh │ ├── store_weights.py │ ├── subsample_expert_trajs.py │ ├── subsample_squad.py │ └── subset_tars.py ├── run_hps.py ├── run_il_exps.sh ├── run_lifelong_cifar.sh ├── run_lifelong_taskonomy.sh ├── run_nlp.sh ├── run_rl_eval.sh ├── run_rl_exps.sh ├── 
run_vision_transfer.sh ├── train_lifelong.py ├── train_rl.py └── train_transfer.py ├── tlkit ├── __init__.py ├── data │ ├── __init__.py │ ├── datasets │ │ ├── expert_dataset.py │ │ ├── fashion_mnist_dataset.py │ │ ├── icifar_dataset.py │ │ ├── imagenet_dataset.py │ │ └── taskonomy_dataset.py │ ├── img_transforms.py │ ├── links │ │ ├── all_links_taskonomydata.txt │ │ └── rgb_links_taskonomydata.txt │ ├── sequential_tasks_dataloaders.py │ ├── splits.py │ ├── splits_taskonomy │ │ ├── splits.txt │ │ ├── train_val_test_debug.csv │ │ ├── train_val_test_debug2.csv │ │ ├── train_val_test_few100.csv │ │ ├── train_val_test_few1000.csv │ │ ├── train_val_test_few5.csv │ │ ├── train_val_test_few500.csv │ │ ├── train_val_test_full.csv │ │ ├── train_val_test_fullplus.csv │ │ ├── train_val_test_medium.csv │ │ ├── train_val_test_supersmall.csv │ │ └── train_val_test_tiny.csv │ └── synset.py ├── logging_helpers.py ├── models │ ├── __init__.py │ ├── basic_models.py │ ├── ewc.py │ ├── feedback.py │ ├── fusion.py │ ├── lifelong_framework.py │ ├── merge_operators.py │ ├── model_utils.py │ ├── models_additional.py │ ├── resnet_cifar.py │ ├── sidetune_architecture.py │ ├── student_models.py │ ├── superposition.py │ └── vision_transfer_architectures.py └── utils.py └── tnt ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _static │ ├── css │ │ └── pytorch_theme.css │ └── img │ │ ├── dynamic_graph.gif │ │ ├── pytorch-logo-dark.svg │ │ └── pytorch-logo-flame.svg ├── _templates │ └── layout.html ├── conf.py ├── index.rst ├── make.bat ├── requirements.txt └── source │ ├── modules.rst │ ├── torchnet.dataset.rst │ ├── torchnet.engine.rst │ ├── torchnet.logger.rst │ ├── torchnet.meter.rst │ ├── torchnet.rst │ └── torchnet.utils.rst ├── example ├── README.md ├── mnist.py ├── mnist_with_meterlogger.py └── mnist_with_visdom.py ├── requirements.txt ├── setup.py ├── test ├── run_test.sh ├── test_datasets.py ├── test_meters.py └── test_transforms.py ├── torchnet ├── __init__.py ├── dataset │ ├── __init__.py │ ├── batchdataset.py │ ├── concatdataset.py │ ├── dataset.py │ ├── listdataset.py │ ├── resampledataset.py │ ├── shuffledataset.py │ ├── splitdataset.py │ ├── tensordataset.py │ └── transformdataset.py ├── engine │ ├── __init__.py │ └── engine.py ├── logger │ ├── __init__.py │ ├── filelogger.py │ ├── logger.py │ ├── meterlogger.py │ ├── tensorboardmeterlogger.py │ ├── visdomlogger.py │ └── visdommeterlogger.py ├── meter │ ├── __init__.py │ ├── apmeter.py │ ├── aucmeter.py │ ├── averagevaluemeter.py │ ├── classerrormeter.py │ ├── confusionmeter.py │ ├── mapmeter.py │ ├── medianimagemeter.py │ ├── meter.py │ ├── movingaveragevaluemeter.py │ ├── msemeter.py │ ├── multivaluesummarymeter.py │ ├── singletonmeter.py │ ├── timemeter.py │ └── valuesummarymeter.py ├── transform.py └── utils │ ├── __init__.py │ ├── multitaskdataloader.py │ └── table.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .ipynb* 3 | __pycache__ 4 | *.pyc 5 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # This Docker was originally set up for Habitat Challenge 2 | 3 | #FROM nvidia/cudagl:9.0-base-ubuntu16.04 4 | FROM fairembodied/habitat-challenge:latest 5 | 6 | RUN apt-get update && apt-get install -y --no-install-recommends \ 7 | cuda-samples-$CUDA_PKG_VERSION && \ 8 | rm -rf /var/lib/apt/lists/* 9 | WORKDIR 
/usr/local/cuda/samples/5_Simulations/nbody 10 | RUN make 11 | 12 | #CMD ./nbody 13 | 14 | RUN apt-get update && apt-get install -y curl && apt-get install -y apt-utils && apt-get install -y ffmpeg 15 | RUN conda update -y conda 16 | 17 | RUN . activate habitat 18 | ENV PATH /opt/conda/envs/habitat/bin:$PATH 19 | 20 | 21 | ############################### 22 | # set up habitat 23 | ############################### 24 | WORKDIR /root/side-tuning 25 | RUN git clone https://github.com/facebookresearch/habitat-sim.git 26 | RUN git clone https://github.com/facebookresearch/habitat-api.git 27 | 28 | WORKDIR /root/side-tuning/habitat-sim 29 | RUN conda install -y cmake 30 | RUN pip install numpy 31 | RUN python setup.py install --headless 32 | 33 | WORKDIR /root/side-tuning/habitat-api 34 | RUN git checkout 05dbf7220e8386eb2337502c4d4851fc8dce30cd 35 | RUN pip install --upgrade -e . 36 | ADD habitat_data /root/side-tuning/habitat-api/data 37 | RUN rm -r /root/side-tuning/habitat-api/configs 38 | ADD habitat_configs /root/side-tuning/habitat-api/configs 39 | RUN rm -r baselines 40 | 41 | 42 | ############################### 43 | # set up side-tuning 44 | ############################### 45 | ADD requirements.txt /root/side-tuning/requirements.txt 46 | ADD __init__.py /root/side-tuning/__init__.py 47 | ADD assets /root/side-tuning/assets 48 | ADD configs /root/side-tuning/configs 49 | ADD evkit /root/side-tuning/evkit 50 | ADD feature_selector /root/side-tuning/feature_selector 51 | ADD scripts /root/side-tuning/scripts 52 | ADD tlkit /root/side-tuning/tlkit 53 | WORKDIR /root/side-tuning 54 | RUN pip install -r requirements.txt 55 | RUN ln -s habitat-api/data . 56 | 57 | 58 | ############################### 59 | # set up baselines 60 | ############################### 61 | WORKDIR /root 62 | RUN apt-get update && apt-get install -y cmake libopenmpi-dev python3-dev zlib1g-dev 63 | RUN git clone https://github.com/openai/baselines.git; cd baselines; pip install -e . 64 | 65 | 66 | ###################################### 67 | # install tnt 68 | ###################################### 69 | ADD tnt /root/side-tuning/tnt 70 | WORKDIR /root/side-tuning/tnt 71 | RUN pip install -e . 72 | 73 | 74 | ###################################### 75 | # and... we are ready! 76 | ###################################### 77 | WORKDIR /root/side-tuning 78 | RUN pip install gym==0.10.9 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Jeffrey O. Zhang, Alexander Sax, Amir Zamir, Leonidas Guibas, Jitendra Malik. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/__init__.py -------------------------------------------------------------------------------- /assets/img/icifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/assets/img/icifar.png -------------------------------------------------------------------------------- /assets/img/icifar_expected_stdout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/assets/img/icifar_expected_stdout.png -------------------------------------------------------------------------------- /assets/img/pull.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/assets/img/pull.png -------------------------------------------------------------------------------- /assets/img/robot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/assets/img/robot.png -------------------------------------------------------------------------------- /assets/img/sidetuning_lifelong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/assets/img/sidetuning_lifelong.png -------------------------------------------------------------------------------- /assets/img/taskonomy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/assets/img/taskonomy.png -------------------------------------------------------------------------------- /assets/pytorch/distillation/fcn4-cifar.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/assets/pytorch/distillation/fcn4-cifar.pth -------------------------------------------------------------------------------- /assets/pytorch/resnet44-cifar.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/assets/pytorch/resnet44-cifar.pth -------------------------------------------------------------------------------- /configs/shared.py: -------------------------------------------------------------------------------- 1 | @ex.named_config 2 | def radam(): 3 | cfg = {} 4 | cfg['learner'] = { 5 | 'optimizer_class': 'RAdam' 6 | } 7 | 8 | @ex.named_config 9 | def reckless(): 10 | cfg = {} 11 | cfg['training'] = { 12 | 'resume_training': False, 13 | } 14 | 
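# Note (illustrative): these sacred named_configs are meant to be mixed into a run from the
# command line, e.g. `python scripts/train_lifelong.py with reckless radam` (hypothetical
# invocation); `reckless` starts training from scratch and obliterates any existing logs.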
cfg['saving'] = { 15 | 'obliterate_logs': True, 16 | } 17 | -------------------------------------------------------------------------------- /evkit/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | 25 | # PyInstaller 26 | # Usually these files are written by a python script from a template 27 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 28 | *.manifest 29 | *.spec 30 | 31 | # Installer logs 32 | pip-log.txt 33 | pip-delete-this-directory.txt 34 | 35 | # Unit test / coverage reports 36 | htmlcov/ 37 | .tox/ 38 | .coverage 39 | .coverage.* 40 | .cache 41 | nosetests.xml 42 | coverage.xml 43 | *,cover 44 | 45 | # Translations 46 | *.mo 47 | *.pot 48 | 49 | # Django stuff: 50 | *.log 51 | 52 | # Sphinx documentation 53 | docs/_build/ 54 | 55 | # PyBuilder 56 | target/ 57 | 58 | # DotEnv configuration 59 | .env 60 | 61 | # Database 62 | *.db 63 | *.rdb 64 | 65 | # Pycharm 66 | .idea 67 | 68 | # VS Code 69 | .vscode/ 70 | 71 | # Spyder 72 | .spyproject/ 73 | 74 | # Jupyter NB Checkpoints 75 | .ipynb_checkpoints/ 76 | 77 | # exclude data from source control by default 78 | /data/ 79 | 80 | # ViZDoom files 81 | _vizdoom* 82 | *.lmp 83 | 84 | # Mac OS-specific storage files 85 | .DS_Store 86 | .nfs* 87 | -------------------------------------------------------------------------------- /evkit/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Alexander Sax, Bradley Emi, Jeffrey Zhang, Amir R. Zamir, Silvio Savarese, Leonidas Guibas, Jitendra Malik. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /evkit/README.md: -------------------------------------------------------------------------------- 1 | # Embodied Vision Toolkit 2 | 3 | This folder implements the research platform and code for the papers: 4 | 5 | **Mid-Level Visual Representations Improve Generalization and Sample Efficiency for Learning Visuomotor Policies**. _Arxiv preprint 2018_. Alexander Sax, Bradley Emi, Amir R. 
Zamir, Silvio Savarese, Leonidas Guibas, Jitendra Malik. 6 | 7 | **Mid-Level Visual Representations Improve Generalization and Sample Efficiency for Learning Habitat Challenge Policies**. _Arxiv preprint 2019_. Alexander Sax*, Jeffrey Zhang*, Bradley Emi, Amir R. Zamir, Silvio Savarese, Leonidas Guibas, Jitendra Malik. 8 | 9 | More information, as well as [online demos](http://perceptual.actor/policy_explorer/), [pretrained models](https://github.com/alexsax/midlevel-reps/tree/master#using-mid-level-perception-in-your-code-), and the [full paper](http://perceptual.actor/#paper) is available on the [website](http://perceptual.actor) or in the [main repository](https://github.com/alexsax/midlevel-reps/tree/master). 10 | 11 | If you find this code useful, please cite: 12 | 13 | ``` 14 | @inproceedings{midLevelReps2018, 15 |  title={Mid-Level Visual Representations Improve Generalization and Sample Efficiency for Learning Visuomotor Policies.}, 16 |  author={Alexander Sax and Bradley Emi and Amir R. Zamir and Leonidas J. Guibas and Silvio Savarese and Jitendra Malik}, 17 |  year={2018}, 18 | } 19 | ``` 20 | 21 | Or if you use the Habitat experiments, then please cite: 22 | 23 | **Mid-Level Visual Representations Improve Generalization and Sample Efficiency for Learning Habitat Challenge Policies**. _Arxiv preprint 2019_. Alexander Sax*, Jeffrey Zhang*, Bradley Emi, Amir R. Zamir, Silvio Savarese, Leonidas Guibas, Jitendra Malik. 24 | 25 | -------------------------------------------------------------------------------- /evkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/evkit/__init__.py -------------------------------------------------------------------------------- /evkit/env/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | from gym.error import Error as GymError 3 | from .distributed_factory import DistributedEnv 4 | from .envs import EnvFactory 5 | 6 | ''' 7 | ViZDoom Environments 8 | ''' 9 | try: 10 | 11 | register( 12 | id='VizdoomRoom-v0', 13 | entry_point='evkit.env.vizdoom:VizdoomPointGoalEnv' 14 | ) 15 | 16 | register( 17 | id='VizdoomBasic-v0', 18 | entry_point='evkit.env.vizdoom:VizdoomBasic' 19 | ) 20 | 21 | register( 22 | id='VizdoomCorridor-v0', 23 | entry_point='evkit.env.vizdoom:VizdoomCorridor' 24 | ) 25 | 26 | register( 27 | id='VizdoomDefendCenter-v0', 28 | entry_point='evkit.env.vizdoom:VizdoomDefendCenter' 29 | ) 30 | 31 | register( 32 | id='VizdoomDefendLine-v0', 33 | entry_point='evkit.env.vizdoom:VizdoomDefendLine' 34 | ) 35 | 36 | register( 37 | id='VizdoomHealthGathering-v0', 38 | entry_point='evkit.env.vizdoom:VizdoomHealthGathering' 39 | ) 40 | 41 | register( 42 | id='VizdoomMyWayHome-v0', 43 | entry_point='evkit.env.vizdoom:VizdoomMyWayHome' 44 | ) 45 | 46 | register( 47 | id='VizdoomPredictPosition-v0', 48 | entry_point='evkit.env.vizdoom:VizdoomPredictPosition' 49 | ) 50 | 51 | register( 52 | id='VizdoomTakeCover-v0', 53 | entry_point='evkit.env.vizdoom:VizdoomTakeCover' 54 | ) 55 | 56 | register( 57 | id='VizdoomDeathmatch-v0', 58 | entry_point='evkit.env.vizdoom:VizdoomDeathmatch' 59 | ) 60 | 61 | register( 62 | id='VizdoomHealthGatheringSupreme-v0', 63 | entry_point='evkit.env.vizdoom:VizdoomHealthGatheringSupreme' 64 | ) 65 | except GymError: 66 | pass 67 | 
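A minimal usage sketch for the registrations above (illustrative only; `gym.make` resolves the `evkit.env.vizdoom` entry point lazily, so this runs only if ViZDoom is installed):

```python
import gym
import evkit.env  # importing the package executes the register() calls above

env = gym.make('VizdoomCorridor-v0')
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())  # classic 4-tuple gym API
env.close()
```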
-------------------------------------------------- /evkit/env/base_embodied_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | class BaseEmbodiedEnv(gym.Env): 4 | ''' Abstract class for all embodied environments. ''' 5 | 6 | is_embodied = True 7 | -------------------------------------------------------------------------------- /evkit/env/distributed_factory.py: -------------------------------------------------------------------------------- 1 | from evkit.env.util.vec_env.subproc_vec_embodied_env import SubprocVecEmbodiedEnv 2 | from baselines.common.vec_env.vec_normalize import VecNormalize 3 | from evkit.utils.misc import Bunch 4 | 5 | _distribution_schemes = Bunch( 6 | {'vectorize': 'VECTORIZE', 7 | 'independent': 'INDEPENDENT'}) 8 | 9 | class DistributedEnv(object): 10 | 11 | distribution_schemes = _distribution_schemes 12 | 13 | @classmethod 14 | def new(cls, envs, gae_gamma=None, distribution_method=_distribution_schemes.vectorize): 15 | if distribution_method == cls.distribution_schemes.vectorize: 16 | return cls.vectorized(envs, gae_gamma) 17 | elif distribution_method == cls.distribution_schemes.independent: 18 | return cls.independent(envs, gae_gamma) 19 | else: 20 | raise NotImplementedError 21 | 22 | def vectorized(envs, gae_gamma=None): 23 | ''' Vectorizes an iterable of environments 24 | Params: 25 | envs: an iterable of environments 26 | gae_gamma: if not None and the observation space is one-dimensional, wrap with VecNormalize using the gamma parameter from GAE 27 | ''' 28 | envs = SubprocVecEmbodiedEnv(envs) 29 | 30 | # if len(envs) > 1: 31 | # envs = SubprocVecEmbodiedEnv(envs) 32 | # else: 33 | # envs = DummyVecEnv(envs) # TODO: Update this to work with sensordict 34 | 35 | if gae_gamma is not None: 36 | if hasattr(envs.observation_space, "spaces") \ 37 | and len(envs.observation_space.spaces) == 1 \ 38 | and len(list(envs.observation_space.spaces.values())[0].shape) == 1: 39 | envs = VecNormalize(envs, gamma=gae_gamma) 40 | elif not hasattr(envs.observation_space, "spaces") and len(envs.observation_space.shape) == 1: 41 | envs = VecNormalize(envs, gamma=gae_gamma) 42 | 43 | return envs 44 | 45 | def independent(envs, gae_gamma=None): 46 | if gae_gamma is not None: 47 | raise NotImplementedError('gae_gamma not supported for "independent" distributed environments') 48 | 49 | envs = [e() for e in envs] 50 | return envs -------------------------------------------------------------------------------- /evkit/env/habitat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/evkit/env/habitat/__init__.py -------------------------------------------------------------------------------- /evkit/env/habitat/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/evkit/env/habitat/config/__init__.py -------------------------------------------------------------------------------- /evkit/env/habitat/config/default.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
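# Usage sketch (illustrative; "my_nav_task.yaml" is a hypothetical file under configs/):
#   config = cfg(config_file="my_nav_task.yaml")
#   config.BASELINE.RL.SUCCESS_REWARD  # -> 10.0 unless the yaml file overrides it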
6 | 7 | import os 8 | import numpy as np 9 | from typing import Optional 10 | from habitat import get_config 11 | from habitat.config import Config as CN 12 | 13 | DEFAULT_CONFIG_DIR = "configs/" 14 | 15 | # ----------------------------------------------------------------------------- 16 | # Config definition 17 | # ----------------------------------------------------------------------------- 18 | _C = CN() 19 | _C.SEED = 100 20 | # ----------------------------------------------------------------------------- 21 | # BASELINE 22 | # ----------------------------------------------------------------------------- 23 | _C.BASELINE = CN() 24 | # ----------------------------------------------------------------------------- 25 | # REINFORCEMENT LEARNING (RL) 26 | # ----------------------------------------------------------------------------- 27 | _C.BASELINE.RL = CN() 28 | _C.BASELINE.RL.SUCCESS_REWARD = 10.0 29 | _C.BASELINE.RL.SLACK_REWARD = -0.01 30 | # ----------------------------------------------------------------------------- 31 | # ORBSLAM2 BASELINE 32 | # ----------------------------------------------------------------------------- 33 | _C.BASELINE.ORBSLAM2 = CN() 34 | _C.BASELINE.ORBSLAM2.SLAM_VOCAB_PATH = "baselines/slambased/data/ORBvoc.txt" 35 | _C.BASELINE.ORBSLAM2.SLAM_SETTINGS_PATH = "baselines/slambased/data/mp3d3_small1k.yaml" 36 | _C.BASELINE.ORBSLAM2.MAP_CELL_SIZE = 0.1 37 | _C.BASELINE.ORBSLAM2.MAP_SIZE = 40 38 | _C.BASELINE.ORBSLAM2.CAMERA_HEIGHT = get_config().SIMULATOR.DEPTH_SENSOR.POSITION[1] 39 | _C.BASELINE.ORBSLAM2.BETA = 100 40 | _C.BASELINE.ORBSLAM2.H_OBSTACLE_MIN = 0.3 * _C.BASELINE.ORBSLAM2.CAMERA_HEIGHT 41 | _C.BASELINE.ORBSLAM2.H_OBSTACLE_MAX = 1.0 * _C.BASELINE.ORBSLAM2.CAMERA_HEIGHT 42 | _C.BASELINE.ORBSLAM2.D_OBSTACLE_MIN = 0.1 43 | _C.BASELINE.ORBSLAM2.D_OBSTACLE_MAX = 4.0 44 | _C.BASELINE.ORBSLAM2.PREPROCESS_MAP = True 45 | _C.BASELINE.ORBSLAM2.MIN_PTS_IN_OBSTACLE = get_config().SIMULATOR.DEPTH_SENSOR.WIDTH/2.0 46 | _C.BASELINE.ORBSLAM2.ANGLE_TH = float(np.deg2rad(15)) 47 | _C.BASELINE.ORBSLAM2.DIST_REACHED_TH = 0.15 48 | _C.BASELINE.ORBSLAM2.NEXT_WAYPOINT_TH = 0.5 49 | _C.BASELINE.ORBSLAM2.NUM_ACTIONS = 3 50 | _C.BASELINE.ORBSLAM2.DIST_TO_STOP = 0.05 51 | _C.BASELINE.ORBSLAM2.PLANNER_MAX_STEPS = 500 52 | _C.BASELINE.ORBSLAM2.DEPTH_DENORM = get_config().SIMULATOR.DEPTH_SENSOR.MAX_DEPTH 53 | 54 | 55 | def cfg( 56 | config_file: Optional[str] = None, config_dir: str = DEFAULT_CONFIG_DIR 57 | ) -> CN: 58 | config = _C.clone() 59 | if config_file: 60 | config.merge_from_file(os.path.join(config_dir, config_file)) 61 | return config 62 | -------------------------------------------------------------------------------- /evkit/env/habitat/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import random 5 | from habitat.sims.habitat_simulator import SimulatorActions 6 | from habitat.utils.visualizations import maps 7 | 8 | try: 9 | from habitat.sims.habitat_simulator import SIM_NAME_TO_ACTION # backwards support 10 | except: 11 | pass 12 | 13 | 14 | # TODO these are action values. 
Make sure to add the word "action" into the name 15 | FORWARD_VALUE = SimulatorActions.FORWARD.value 16 | FORWARD_VALUE = FORWARD_VALUE if isinstance(FORWARD_VALUE, int) else SIM_NAME_TO_ACTION[FORWARD_VALUE] 17 | 18 | STOP_VALUE = SimulatorActions.STOP.value 19 | STOP_VALUE = STOP_VALUE if isinstance(STOP_VALUE, int) else SIM_NAME_TO_ACTION[STOP_VALUE] 20 | 21 | LEFT_VALUE = SimulatorActions.LEFT.value 22 | LEFT_VALUE = LEFT_VALUE if isinstance(LEFT_VALUE, int) else SIM_NAME_TO_ACTION[LEFT_VALUE] 23 | 24 | RIGHT_VALUE = SimulatorActions.RIGHT.value 25 | RIGHT_VALUE = RIGHT_VALUE if isinstance(RIGHT_VALUE, int) else SIM_NAME_TO_ACTION[RIGHT_VALUE] 26 | 27 | 28 | TAKEOVER1 = [LEFT_VALUE] * 4 + [FORWARD_VALUE] * 4 29 | TAKEOVER2 = [RIGHT_VALUE] * 4 + [FORWARD_VALUE] * 4 30 | TAKEOVER3 = [LEFT_VALUE] * 6 + [FORWARD_VALUE] * 2 31 | TAKEOVER4 = [RIGHT_VALUE] * 6 + [FORWARD_VALUE] * 2 32 | # TAKEOVER5 = [LEFT_VALUE] * 8 # rotation only seems not to break out of bad behavior 33 | # TAKEOVER6 = [RIGHT_VALUE] * 8 34 | TAKEOVER_ACTION_SEQUENCES = [TAKEOVER1, TAKEOVER2, TAKEOVER3, TAKEOVER4] 35 | TAKEOVER_ACTION_SEQUENCES = [torch.Tensor(t).long() for t in TAKEOVER_ACTION_SEQUENCES] 36 | 37 | DEFAULT_TAKEOVER_ACTIONS = torch.Tensor([LEFT_VALUE, LEFT_VALUE, LEFT_VALUE, LEFT_VALUE, FORWARD_VALUE, FORWARD_VALUE]).long() 38 | NON_STOP_VALUES = torch.Tensor([FORWARD_VALUE, LEFT_VALUE, RIGHT_VALUE]).long() 39 | 40 | 41 | flatten = lambda l: [item for sublist in l for item in sublist] 42 | 43 | def chunks(l, n): 44 | """Yield successive n-sized chunks from l.""" 45 | for i in range(0, len(l), n): 46 | yield l[i:i + n] 47 | 48 | def shuffle_episodes(env, swap_every_k=10): 49 | episodes = env.episodes 50 | # buildings_for_epidodes = [e.scene_id for e in episodes] 51 | episodes = env.episodes = random.sample([c for c in chunks(episodes, swap_every_k)], len(episodes) // swap_every_k) 52 | env.episodes = flatten(episodes) 53 | return env.episodes 54 | 55 | 56 | def draw_top_down_map(info, heading, output_size): 57 | if info is None: 58 | return 59 | top_down_map = maps.colorize_topdown_map(info["top_down_map"]["map"]) 60 | original_map_size = top_down_map.shape[:2] 61 | map_scale = np.array( 62 | (1, original_map_size[1] * 1.0 / original_map_size[0]) 63 | ) 64 | new_map_size = np.round(output_size * map_scale).astype(np.int32) 65 | # OpenCV expects w, h but map size is in h, w 66 | top_down_map = cv2.resize(top_down_map, (new_map_size[1], new_map_size[0])) 67 | 68 | map_agent_pos = info["top_down_map"]["agent_map_coord"] 69 | map_agent_pos = np.round( 70 | map_agent_pos * new_map_size / original_map_size 71 | ).astype(np.int32) 72 | top_down_map = maps.draw_agent( 73 | top_down_map, 74 | map_agent_pos, 75 | heading - np.pi / 2, 76 | agent_radius_px=top_down_map.shape[0] / 40, 77 | ) 78 | return top_down_map 79 | 80 | def gray_to_rgb(img_arr): 81 | # Input: (H,W,1) or (H,W) or (H,W,3) 82 | # Output: (H,W,3) 83 | if len(img_arr.shape) == 3 and img_arr.shape[2] == 3: # (H,W,3) 84 | return img_arr 85 | if len(img_arr.shape) == 3: 86 | img_arr = img_arr.squeeze(2) # (H,W,1) 87 | return np.dstack((img_arr, img_arr, img_arr)) -------------------------------------------------------------------------------- /evkit/env/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/evkit/env/util/__init__.py -------------------------------------------------------------------------------- 
/evkit/env/util/make_env.py: -------------------------------------------------------------------------------- 1 | ''' A version of the Gym registry that allows more flexible use of kwargs''' 2 | 3 | from gym import error, logger 4 | import gym.envs.registration as registration # register, EnvSpec, spec 5 | 6 | 7 | 8 | def make(id, **kwargs): 9 | logger.info('Making new env: %s', id) 10 | spec = registration.spec(id) 11 | env = spec.make(**kwargs) # pass kwargs through to the env constructor -- the "more flexible use of kwargs" this module exists for; assumes a gym version whose spec.make accepts them 12 | # We used to have people override _reset/_step rather than 13 | # reset/step. Set _gym_disable_underscore_compat = True on 14 | # your environment if you use these methods and don't want 15 | # compatibility code to be invoked. 16 | if hasattr(env, "_reset") and hasattr(env, "_step") and not getattr(env, "_gym_disable_underscore_compat", False): 17 | registration.patch_deprecated_methods(env) 18 | if (env.spec.timestep_limit is not None) and not spec.tags.get('vnc'): 19 | from gym.wrappers.time_limit import TimeLimit 20 | env = TimeLimit(env, 21 | max_episode_steps=env.spec.max_episode_steps, 22 | max_episode_seconds=env.spec.max_episode_seconds) 23 | return env -------------------------------------------------------------------------------- /evkit/env/util/tile_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def tile_images(img_nhwc): 4 | """ 5 | Tile N images into one big PxQ image 6 | (P,Q) are chosen to be as close as possible, and if N 7 | is square, then P=Q. 8 | 9 | input: img_nhwc, list or array of images, ndim=4 once turned into array 10 | n = batch index, h = height, w = width, c = channel 11 | returns: 12 | bigim_HWc, ndarray with ndim=3 13 | """ 14 | img_nhwc = np.asarray(img_nhwc) 15 | N, h, w, c = img_nhwc.shape 16 | H = int(np.ceil(np.sqrt(N))) 17 | W = int(np.ceil(float(N)/H)) 18 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 19 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 20 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 21 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 22 | return img_Hh_Ww_c 23 | 24 | -------------------------------------------------------------------------------- /evkit/env/util/vec_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/evkit/env/util/vec_env/__init__.py -------------------------------------------------------------------------------- /evkit/env/util/vec_env/dummy_vec_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import spaces 3 | from collections import OrderedDict 4 | from baselines.common.vec_env import VecEnv 5 | 6 | class DummyVecEnv(VecEnv): 7 | def __init__(self, env_fns): 8 | self.envs = [fn() for fn in env_fns] 9 | env = self.envs[0] 10 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 11 | shapes, dtypes = {}, {} 12 | self.keys = [] 13 | obs_space = env.observation_space 14 | 15 | if isinstance(obs_space, spaces.Dict): 16 | assert isinstance(obs_space.spaces, OrderedDict) 17 | subspaces = obs_space.spaces 18 | else: 19 | subspaces = {None: obs_space} 20 | 21 | for key, box in subspaces.items(): 22 | shapes[key] = box.shape 23 | dtypes[key] = box.dtype 24 | self.keys.append(key) 25 | 26 | self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys } 27 | self.buf_dones = np.zeros((self.num_envs,), 
dtype=np.bool) 28 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) 29 | self.buf_infos = [{} for _ in range(self.num_envs)] 30 | self.actions = None 31 | 32 | def step_async(self, actions): 33 | self.actions = actions 34 | 35 | def step_wait(self): 36 | for e in range(self.num_envs): 37 | obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(self.actions[e]) 38 | if self.buf_dones[e]: 39 | obs = self.envs[e].reset() 40 | self._save_obs(e, obs) 41 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), 42 | self.buf_infos.copy()) 43 | 44 | def reset(self): 45 | for e in range(self.num_envs): 46 | obs = self.envs[e].reset() 47 | self._save_obs(e, obs) 48 | return self._obs_from_buf() 49 | 50 | def close(self): 51 | return 52 | 53 | def render(self, mode='human'): 54 | return [e.render(mode=mode) for e in self.envs] 55 | 56 | def _save_obs(self, e, obs): 57 | for k in self.keys: 58 | if k is None: 59 | self.buf_obs[k][e] = obs 60 | else: 61 | self.buf_obs[k][e] = obs[k] 62 | 63 | def _obs_from_buf(self): 64 | if self.keys==[None]: 65 | return self.buf_obs[None] 66 | else: 67 | return self.buf_obs 68 | -------------------------------------------------------------------------------- /evkit/env/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .visdommonitor import VisdomMonitor 2 | from .preprocessingwrapper import ProcessObservationWrapper 3 | from .skip_wrapper import SkipWrapper 4 | from .sensorenv import SensorEnvWrapper 5 | 6 | -------------------------------------------------------------------------------- /evkit/env/wrappers/preprocessingwrapper.py: -------------------------------------------------------------------------------- 1 | from gym.spaces.box import Box 2 | import gym 3 | import torch 4 | 5 | class ProcessObservationWrapper(gym.ObservationWrapper): 6 | ''' Wraps an environment so that instead of 7 | obs = env.step(), 8 | obs = transform(env.step()) 9 | 10 | Args: 11 | transform: a function that transforms obs 12 | obs_shape: the final obs_shape is needed to set the observation space of the env 13 | ''' 14 | def __init__(self, env, transform, obs_space): 15 | super().__init__(env) 16 | self.observation_space = obs_space 17 | self.transform = transform 18 | 19 | def observation(self, observation): 20 | return self.transform(observation) 21 | -------------------------------------------------------------------------------- /evkit/env/wrappers/sensorenv.py: -------------------------------------------------------------------------------- 1 | from gym import spaces 2 | import gym 3 | import torch 4 | from evkit.sensors import SensorDict 5 | 6 | class SensorEnvWrapper(gym.ObservationWrapper): 7 | ''' Wraps a typical gym environment so to work with our package 8 | obs = env.step(), 9 | obs = {sensor_name: env.step()} 10 | 11 | Parameters: 12 | name: what to name the sensor 13 | ''' 14 | def __init__(self, env, name='obs'): 15 | super().__init__(env) 16 | self.name = name 17 | self.observation_space = spaces.Dict({self.name: self.observation_space}) 18 | 19 | def observation(self, observation): 20 | return SensorDict({self.name: observation}) 21 | -------------------------------------------------------------------------------- /evkit/env/wrappers/skip_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | __all__ = ['SkipWrapper'] 4 | 5 | def SkipWrapper(repeat_count): 6 | class 
SkipWrapper(gym.Wrapper): 7 | """ 8 | Generic common frame skipping wrapper 9 | Will repeat the chosen action for `repeat_count` additional steps 10 | """ 11 | def __init__(self, env): 12 | super(SkipWrapper, self).__init__(env) 13 | self.repeat_count = repeat_count 14 | self.stepcount = 0 15 | 16 | def step(self, action): 17 | done = False 18 | total_reward = 0 19 | current_step = 0 20 | while current_step < (self.repeat_count + 1) and not done: 21 | self.stepcount += 1 22 | if (current_step < self.repeat_count): 23 | _, reward, done, info = self.env.step_physics(action) 24 | else: 25 | self.obs, reward, done, info = self.env.step(action) 26 | total_reward += reward 27 | current_step += 1 28 | if 'skip.stepcount' in info: 29 | raise gym.error.Error('Key "skip.stepcount" already in info. Make sure you are not stacking ' \ 30 | 'the SkipWrapper wrappers.') 31 | info['skip.stepcount'] = self.stepcount 32 | info['skip.repeat_count'] = self.repeat_count 33 | return self.obs, total_reward, done, info 34 | 35 | def reset(self): 36 | self.stepcount = 0 37 | self.obs = self.env.reset() 38 | return self.obs 39 | 40 | return SkipWrapper -------------------------------------------------------------------------------- /evkit/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/evkit/models/.gitkeep -------------------------------------------------------------------------------- /evkit/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class SingleSensorModule(nn.Module): 4 | def __init__(self, module, sensor_name): 5 | super().__init__() 6 | self.module = module 7 | self.sensor_name = sensor_name 8 | 9 | def __call__(self, obs): 10 | # return {self.sensor_name: self.module(obs[self.sensor_name])} 11 | return self.module(obs[self.sensor_name]) 12 | -------------------------------------------------------------------------------- /evkit/models/actor_critic_module_curiosity.py: -------------------------------------------------------------------------------- 1 | 2 | from .actor_critic_module import NaivelyRecurrentACModule 3 | 4 | 5 | class ForwardInverseACModule(NaivelyRecurrentACModule): 6 | ''' 7 | This Module adds a forward-inverse model on top of the perception unit. 
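Shape sketch (illustrative; shapes are not enforced by this class):
    forward_model : (phi_t, a_t) -> predicted phi_{t+1}
    inverse_model : (phi_t, phi_{t+1}) -> logits over actions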
8 | ''' 9 | def __init__(self, perception_unit, forward_model, inverse_model, use_recurrency=False, internal_state_size=512): 10 | super().__init__(perception_unit, use_recurrency, internal_state_size) 11 | 12 | self.forward_model = forward_model 13 | self.inverse_model = inverse_model -------------------------------------------------------------------------------- /evkit/models/forward_inverse.py: -------------------------------------------------------------------------------- 1 | from gym import spaces 2 | import multiprocessing.dummy as mp 3 | import multiprocessing 4 | import numpy as np 5 | import os 6 | import torch 7 | 8 | import torch.nn as nn 9 | from torch.nn import Parameter, ModuleList 10 | import torch.nn.functional as F 11 | 12 | from evkit.rl.utils import init, init_normc_ 13 | from evkit.utils.misc import is_cuda 14 | from evkit.preprocess import transforms 15 | 16 | import pickle as pkl 17 | 18 | init_ = lambda m: init(m, 19 | nn.init.orthogonal_, 20 | lambda x: nn.init.constant_(x, 0), 21 | nn.init.calculate_gain('relu')) 22 | 23 | ################################ 24 | # Forward Models 25 | # Predict s_{t+1} | s_t, a_t 26 | ################################ 27 | class ForwardModel(nn.Module): 28 | 29 | def __init__(self, state_shape, action_shape, hidden_size): 30 | super().__init__() 31 | self.fc1 = init_(nn.Linear(state_shape + action_shape[1], hidden_size)) 32 | self.fc2 = init_(nn.Linear(hidden_size, state_shape)) 33 | 34 | def forward(self, state, action): 35 | x = torch.cat([state, action], 1) 36 | x = F.relu(self.fc1(x)) 37 | x = self.fc2(x) 38 | return x 39 | 40 | ################################ 41 | # Inverse Models 42 | # Predict a_t | s_t, s_{t+1} 43 | ################################ 44 | class InverseModel(nn.Module): 45 | 46 | def __init__(self, input_size, hidden_size, output_size): 47 | super().__init__() 48 | self.fc1 = init_(nn.Linear(input_size * 2, hidden_size)) 49 | # Note: stop gradient here 50 | self.fc2 = init_(nn.Linear(hidden_size, output_size)) 51 | 52 | def forward(self, phi_t, phi_t_plus_1): 53 | x = torch.cat([phi_t, phi_t_plus_1], 1) 54 | x = F.relu(self.fc1(x)) 55 | logits = self.fc2(x) 56 | return logits 57 | # ainvprobs = nn.softmax(logits, dim=-1) -------------------------------------------------------------------------------- /evkit/models/triangle.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch.nn as nn 3 | import torch 4 | from torchsummary import summary 5 | 6 | from teas.models.unet import UNet, UNetHeteroscedasticFull, UNetHeteroscedasticIndep, UNetHeteroscedasticPooled 7 | 8 | 9 | class TriangleModel(nn.Module): 10 | 11 | def __init__(self, network_constructors, n_channels_lists, universal_kwargses=[{}]): 12 | super().__init__() 13 | self.chains = nn.ModuleList() 14 | for network_constructor, n_channels_list, universal_kwargs in zip(network_constructors, n_channels_lists, 15 | universal_kwargses): 16 | print(network_constructor) 17 | chain = network_constructor(n_channels_list=n_channels_list, 18 | universal_kwargs=universal_kwargs) 19 | # chain.append(net) 20 | self.chains.append(chain) 21 | 22 | def initialize_from_checkpoints(self, checkpoint_paths, logger=None): 23 | for i, (chain, ckpt_fpath) in enumerate(zip(self.chains, checkpoint_paths)): 24 | if logger is not None: 25 | logger.info(f"Loading step {i} from {ckpt_fpath}") 26 | checkpoint = torch.load(ckpt_fpath) 27 | sd = {k.replace("module.", ""): v for k, v in 
checkpoint['state_dict'].items()} 28 | chain.load_state_dict(sd) 29 | # initialize_from_checkpoints(ckpt_fpaths, logger) 30 | return self 31 | 32 | def forward(self, x): 33 | chain_outputs = [] 34 | for chain in self.chains: 35 | outputs = chain(x) 36 | # chain_x = x 37 | # for net in chain: 38 | # chain_x = net(chain_x) 39 | # outputs.append(chain_x) 40 | # if isinstance(chain_x, tuple): 41 | # chain_x = torch.cat(chain_x, dim=1) 42 | chain_outputs.append(outputs) 43 | return chain_outputs 44 | -------------------------------------------------------------------------------- /evkit/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform_factory import TransformFactory 2 | 3 | -------------------------------------------------------------------------------- /evkit/preprocess/baseline_transforms.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import numpy as np 3 | import skimage 4 | import torchvision as vision 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import multiprocessing.dummy as mp 9 | import multiprocessing 10 | from gym import spaces 11 | 12 | from evkit.sensors import SensorPack 13 | from evkit.preprocess.transforms import RESCALE_0_1_NEG1_POS1 # maps [0, 1] -> [-1, 1]; assumed to be defined with the other transforms 14 | 15 | def blind(output_size, dtype=np.float32): 16 | ''' blind: ignore the observation and always return zeros 17 | 18 | Args: 19 | output_size: A tuple CxWxH 20 | dtype: of the output (must be np, not torch) 21 | 22 | Returns: 23 | a thunk which takes an observation space and returns (transform, output_observation_space) 24 | ''' 25 | def _thunk(obs_space): 26 | pipeline = lambda x: torch.zeros(output_size) 27 | return pipeline, spaces.Box(-1, 1, output_size, dtype) 28 | return _thunk 29 | 30 | 31 | 32 | def pixels_as_state(output_size, dtype=np.float32): 33 | ''' pixels_as_state: center-crop, resize, and stack raw pixels as the state 34 | 35 | Args: 36 | output_size: A tuple CxWxH 37 | dtype: of the output (must be np, not torch) 38 | 39 | Returns: 40 | a thunk which takes an observation space and returns (transform, output_observation_space) 41 | ''' 42 | def _thunk(obs_space): 43 | obs_shape = obs_space.shape 44 | obs_min_wh = min(obs_shape[:2]) 45 | output_wh = output_size[-2:] # the output (W, H) 46 | processed_env_shape = output_size 47 | 48 | base_pipeline = vision.transforms.Compose([ 49 | vision.transforms.ToPILImage(), 50 | vision.transforms.CenterCrop([obs_min_wh, obs_min_wh]), 51 | vision.transforms.Resize(output_wh)]) 52 | 53 | grayscale_pipeline = vision.transforms.Compose([ 54 | vision.transforms.Grayscale(), 55 | vision.transforms.ToTensor(), 56 | RESCALE_0_1_NEG1_POS1, 57 | ]) 58 | 59 | rgb_pipeline = vision.transforms.Compose([ 60 | vision.transforms.ToTensor(), 61 | RESCALE_0_1_NEG1_POS1, 62 | ]) 63 | 64 | def pipeline(x): 65 | base = base_pipeline(x) 66 | rgb = rgb_pipeline(base) 67 | gray = grayscale_pipeline(base) 68 | 69 | n_rgb = output_size[0] // 3 70 | n_gray = output_size[0] % 3 71 | return torch.cat([rgb] * n_rgb + [gray] * n_gray) 72 | return pipeline, spaces.Box(-1, 1, output_size, dtype) 73 | return _thunk 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /evkit/preprocess/filters.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class GaussianSmoothing(nn.Module): 9 | """ 10 | Apply gaussian smoothing on a 11 | 1d, 2d or 3d tensor. 
Filtering is performed separately for each channel 12 | in the input using a depthwise convolution. 13 | Arguments: 14 | channels (int, sequence): Number of channels of the input tensors. Output will 15 | have this number of channels as well. 16 | kernel_size (int, sequence): Size of the gaussian kernel. 17 | sigma (float, sequence): Standard deviation of the gaussian kernel. 18 | dim (int, optional): The number of dimensions of the data. 19 | Default value is 2 (spatial). 20 | """ 21 | 22 | def __init__(self, channels, kernel_size, sigma, dim=2): 23 | super(GaussianSmoothing, self).__init__() 24 | if isinstance(kernel_size, numbers.Number): 25 | kernel_size = [kernel_size] * dim 26 | if isinstance(sigma, numbers.Number): 27 | sigma = [sigma] * dim 28 | 29 | self.kernel_size = kernel_size 30 | # The gaussian kernel is the product of the 31 | # gaussian function of each dimension. 32 | kernel = 1 33 | meshgrids = torch.meshgrid( 34 | [ 35 | torch.arange(size, dtype=torch.float32) 36 | for size in kernel_size 37 | ] 38 | ) 39 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 40 | mean = (size - 1) / 2 41 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 42 | torch.exp(-((mgrid - mean) / std) ** 2 / 2) 43 | 44 | # Make sure sum of values in gaussian kernel equals 1. 45 | kernel = kernel / torch.sum(kernel) 46 | 47 | # Reshape to depthwise convolutional weight 48 | kernel = kernel.view(1, 1, *kernel.size()) 49 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 50 | 51 | self.register_buffer('weight', kernel) 52 | self.groups = channels 53 | 54 | if dim == 1: 55 | self.conv = F.conv1d 56 | elif dim == 2: 57 | self.conv = F.conv2d 58 | elif dim == 3: 59 | self.conv = F.conv3d 60 | else: 61 | raise RuntimeError( 62 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 63 | ) 64 | 65 | def forward(self, input): 66 | """ 67 | Apply gaussian filter to input. 68 | Arguments: 69 | input (torch.Tensor): Input to apply gaussian filter on. 70 | Returns: 71 | filtered (torch.Tensor): Filtered output. 72 | """ 73 | input_was_3 = (input.dim() == 3) # an unbatched (C, H, W) input 74 | if input_was_3: 75 | input = input.unsqueeze(0) 76 | input = F.pad(input, [self.kernel_size[0] // 2] * 4, mode='reflect') 77 | res = self.conv(input, weight=self.weight, groups=self.groups) 78 | return res.squeeze(0) if input_was_3 else res 79 | -------------------------------------------------------------------------------- /evkit/preprocess/gaussian.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | class GaussianSmoothing(nn.Module): 8 | """ 9 | Apply gaussian smoothing on a 10 | 1d, 2d or 3d tensor. Filtering is performed separately for each channel 11 | in the input using a depthwise convolution. 12 | Arguments: 13 | channels (int, sequence): Number of channels of the input tensors. Output will 14 | have this number of channels as well. 15 | kernel_size (int, sequence): Size of the gaussian kernel. 16 | sigma (float, sequence): Standard deviation of the gaussian kernel. 17 | dim (int, optional): The number of dimensions of the data. 18 | Default value is 2 (spatial). 
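Example (a sketch; any (N, C, H, W) float tensor works for the default dim=2):
    >>> smoothing = GaussianSmoothing(channels=3, kernel_size=5, sigma=1.0)
    >>> x = torch.rand(1, 3, 100, 100)
    >>> smoothing(x).shape  # reflect padding preserves the spatial size
    torch.Size([1, 3, 100, 100])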
19 | """ 20 | def __init__(self, channels, kernel_size, sigma, dim=2): 21 | super(GaussianSmoothing, self).__init__() 22 | if isinstance(kernel_size, numbers.Number): 23 | kernel_size = [kernel_size] * dim 24 | self.kernel_size = kernel_size[0] 25 | self.dim = dim 26 | 27 | if isinstance(sigma, numbers.Number): 28 | sigma = [sigma] * dim 29 | 30 | # The gaussian kernel is the product of the 31 | # gaussian function of each dimension. 32 | kernel = 1 33 | meshgrids = torch.meshgrid( 34 | [ 35 | torch.arange(size, dtype=torch.float32) 36 | for size in kernel_size 37 | ] 38 | ) 39 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 40 | mean = (size - 1) / 2 41 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 42 | torch.exp(-((mgrid - mean) / std) ** 2 / 2) 43 | 44 | # Make sure sum of values in gaussian kernel equals 1. 45 | kernel = kernel / torch.sum(kernel) 46 | 47 | # Reshape to depthwise convolutional weight 48 | kernel = kernel.view(1, 1, *kernel.size()) 49 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 50 | 51 | self.register_buffer('weight', kernel) 52 | self.groups = channels 53 | 54 | if dim == 1: 55 | self.conv = F.conv1d 56 | elif dim == 2: 57 | self.conv = F.conv2d 58 | elif dim == 3: 59 | self.conv = F.conv3d 60 | else: 61 | raise RuntimeError( 62 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 63 | ) 64 | 65 | def forward(self, input): 66 | """ 67 | Apply gaussian filter to input. 68 | Arguments: 69 | input (torch.Tensor): Input to apply gaussian filter on. 70 | Returns: 71 | filtered (torch.Tensor): Filtered output. 72 | """ 73 | input = F.pad(input, [self.kernel_size//2] * 2 * self.dim, mode='reflect') 74 | return self.conv(input, weight=self.weight, groups=self.groups) -------------------------------------------------------------------------------- /evkit/rl/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ilya Kostrikov, Bradley Emi, Alexander Sax, Jeffrey Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /evkit/rl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/evkit/rl/__init__.py -------------------------------------------------------------------------------- /evkit/rl/algo/__init__.py: -------------------------------------------------------------------------------- 1 | from .a2c_acktr import A2C_ACKTR 2 | from .ppo import PPO 3 | from .ppo_replay import PPOReplay 4 | from .deepq import QLearner 5 | from .ppo_curiosity import PPOCuriosity, PPOReplayCuriosity 6 | 7 | -------------------------------------------------------------------------------- /evkit/rl/algo/a2c_acktr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | from .kfac import KFACOptimizer 6 | 7 | 8 | class A2C_ACKTR(object): 9 | def __init__(self, 10 | actor_critic, 11 | value_loss_coef, 12 | entropy_coef, 13 | lr=None, 14 | eps=None, 15 | alpha=None, 16 | max_grad_norm=None, 17 | acktr=False): 18 | 19 | self.actor_critic = actor_critic 20 | self.acktr = acktr 21 | 22 | self.value_loss_coef = value_loss_coef 23 | self.entropy_coef = entropy_coef 24 | 25 | self.max_grad_norm = max_grad_norm 26 | 27 | if acktr: 28 | self.optimizer = KFACOptimizer(actor_critic) 29 | else: 30 | self.optimizer = optim.RMSprop( 31 | actor_critic.parameters(), lr, eps=eps, alpha=alpha) 32 | 33 | def update(self, rollouts): 34 | obs_shape = rollouts.observations.size()[2:] 35 | action_shape = rollouts.actions.size()[-1] 36 | num_steps, num_processes, _ = rollouts.rewards.size() 37 | 38 | values, action_log_probs, dist_entropy, states = self.actor_critic.evaluate_actions( 39 | rollouts.observations[:-1].view(-1, *obs_shape), 40 | rollouts.states[0].view(-1, self.actor_critic.state_size), 41 | rollouts.masks[:-1].view(-1, 1), 42 | rollouts.actions.view(-1, action_shape)) 43 | 44 | values = values.view(num_steps, num_processes, 1) 45 | action_log_probs = action_log_probs.view(num_steps, num_processes, 1) 46 | 47 | advantages = rollouts.returns[:-1] - values 48 | value_loss = advantages.pow(2).mean() 49 | 50 | action_loss = -(advantages.detach() * action_log_probs).mean() 51 | 52 | if self.acktr and self.optimizer.steps % self.optimizer.Ts == 0: 53 | # Sampled fisher, see Martens 2014 54 | self.actor_critic.zero_grad() 55 | pg_fisher_loss = -action_log_probs.mean() 56 | 57 | value_noise = torch.randn(values.size()) 58 | if values.is_cuda: 59 | value_noise = value_noise.cuda() 60 | 61 | sample_values = values + value_noise 62 | vf_fisher_loss = -(values - sample_values.detach()).pow(2).mean() 63 | 64 | fisher_loss = pg_fisher_loss + vf_fisher_loss 65 | self.optimizer.acc_stats = True 66 | fisher_loss.backward(retain_graph=True) 67 | self.optimizer.acc_stats = False 68 | 69 | self.optimizer.zero_grad() 70 | (value_loss * self.value_loss_coef + action_loss - 71 | dist_entropy * self.entropy_coef).backward() 72 | 73 | if self.acktr == False: 74 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 75 | self.max_grad_norm) 76 | 77 | self.optimizer.step() 78 | 79 | return value_loss.item(), action_loss.item(), dist_entropy.item() 80 | -------------------------------------------------------------------------------- /evkit/rl/algo/deepq.py: -------------------------------------------------------------------------------- 1 
| import torch.nn as nn 2 | import torch.optim as optim 3 | import torch.nn.functional as F 4 | import random 5 | import torch 6 | import time 7 | import numpy as np 8 | 9 | class QLearner(nn.Module): 10 | def __init__(self, actor_network, target_network, 11 | action_dim, batch_size, lr, eps, gamma, 12 | copy_frequency, 13 | start_schedule, schedule_timesteps, 14 | initial_p, final_p): 15 | super(QLearner, self).__init__() 16 | self.actor_network = actor_network 17 | self.target_network = target_network 18 | self.learning_schedule = LearningSchedule(start_schedule, schedule_timesteps, initial_p, final_p) 19 | self.beta_schedule = LearningSchedule(start_schedule, schedule_timesteps, 0.4, 1.0) 20 | self.action_dim = action_dim 21 | self.copy_frequency = copy_frequency 22 | self.batch_size = batch_size 23 | self.gamma = gamma 24 | 25 | self.optimizer = optim.Adam(actor_network.parameters(), 26 | lr=lr, 27 | eps=eps) 28 | 29 | self.step = 0 30 | 31 | def cuda(self): 32 | self.actor_network = self.actor_network.cuda() 33 | self.target_network = self.target_network.cuda() 34 | 35 | def act(self, observation, greedy=False): 36 | self.step += 1 37 | if self.step % self.copy_frequency == 1: 38 | self.target_network.load_state_dict(self.actor_network.state_dict()) 39 | if random.random() > self.learning_schedule.value(self.step) or greedy: 40 | with torch.no_grad(): 41 | return self.actor_network(observation).max(1)[1].view(1,1) 42 | else: 43 | return torch.tensor([[random.randrange(self.action_dim)]]) 44 | 45 | 46 | def update(self, rollouts): 47 | loss_epoch = 0 48 | observations, actions, rewards, masks, next_observations, weights, indices = rollouts.sample(self.batch_size, 49 | beta=self.beta_schedule.value(self.step)) 50 | next_state_values = self.target_network(next_observations).detach().max(1)[0].unsqueeze(1) 51 | 52 | state_action_values = self.actor_network(observations).gather(1, actions) 53 | targets = rewards + self.gamma * masks * next_state_values 54 | if rollouts.use_priority: 55 | with torch.no_grad(): 56 | td_errors = torch.abs(targets - state_action_values).detach() + 1e-6 57 | 58 | rollouts.update_priorities(indices, td_errors) 59 | loss = torch.sum(weights * (targets - state_action_values) ** 2) 60 | self.optimizer.zero_grad() 61 | loss.backward() 62 | self.optimizer.step() 63 | loss_epoch += loss.item() 64 | return loss_epoch 65 | 66 | def get_epsilon(self): 67 | return self.learning_schedule.value(self.step) 68 | 69 | 70 | class LearningSchedule(object): 71 | def __init__(self, start_schedule, schedule_timesteps, initial_p=1.0, final_p=0.05): 72 | self.initial_p = initial_p 73 | self.final_p = final_p 74 | self.schedule_timesteps = schedule_timesteps 75 | self.start_schedule = start_schedule 76 | 77 | def value(self, t): 78 | fraction = min(max(0.0, float(t - self.start_schedule)) / self.schedule_timesteps, 1.0) 79 | return self.initial_p + fraction * (self.final_p - self.initial_p) -------------------------------------------------------------------------------- /evkit/rl/algo/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import os 6 | 7 | class PPO(object): 8 | def __init__(self, 9 | actor_critic, 10 | clip_param, 11 | ppo_epoch, 12 | num_mini_batch, 13 | value_loss_coef, 14 | entropy_coef, 15 | lr=None, 16 | eps=None, 17 | max_grad_norm=None, 18 | amsgrad=True, 19 | weight_decay=0.0): 20 | 21 | self.actor_critic = 
actor_critic 22 | 23 | self.clip_param = clip_param 24 | self.ppo_epoch = ppo_epoch 25 | self.num_mini_batch = num_mini_batch 26 | 27 | self.value_loss_coef = value_loss_coef 28 | self.entropy_coef = entropy_coef 29 | 30 | self.max_grad_norm = max_grad_norm 31 | 32 | self.optimizer = optim.Adam(actor_critic.parameters(), 33 | lr=lr, 34 | eps=eps, 35 | weight_decay=weight_decay, 36 | amsgrad=amsgrad) 37 | self.last_grad_norm = None 38 | 39 | def update(self, rollouts): 40 | advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1] 41 | advantages = (advantages - advantages.mean()) / ( 42 | advantages.std() + 1e-5) 43 | 44 | 45 | value_loss_epoch = 0 46 | action_loss_epoch = 0 47 | dist_entropy_epoch = 0 48 | max_importance_weight_epoch = 0 49 | 50 | for e in range(self.ppo_epoch): 51 | if hasattr(self.actor_critic.base, 'gru'): 52 | data_generator = rollouts.recurrent_generator( 53 | advantages, self.num_mini_batch) 54 | else: 55 | data_generator = rollouts.feed_forward_generator( 56 | advantages, self.num_mini_batch) 57 | 58 | for sample in data_generator: 59 | observations_batch, states_batch, actions_batch, \ 60 | return_batch, masks_batch, old_action_log_probs_batch, \ 61 | adv_targ = sample 62 | 63 | # Reshape to do in a single forward pass for all steps 64 | values, action_log_probs, dist_entropy, states = self.actor_critic.evaluate_actions( 65 | observations_batch, states_batch, 66 | masks_batch, actions_batch) 67 | 68 | ratio = torch.exp(action_log_probs - old_action_log_probs_batch) 69 | surr1 = ratio * adv_targ 70 | surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 71 | 1.0 + self.clip_param) * adv_targ 72 | action_loss = -torch.min(surr1, surr2).mean() 73 | value_loss = F.mse_loss(values, return_batch) 74 | self.optimizer.zero_grad() 75 | (value_loss * self.value_loss_coef + 76 | action_loss - 77 | dist_entropy * self.entropy_coef).backward() 78 | self.last_grad_norm = nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 79 | self.max_grad_norm) 80 | self.optimizer.step() 81 | 82 | value_loss_epoch += value_loss.item() 83 | action_loss_epoch += action_loss.item() 84 | dist_entropy_epoch += dist_entropy.item() 85 | max_importance_weight_epoch = max(torch.max(ratio).item(), max_importance_weight_epoch) 86 | 87 | num_updates = self.ppo_epoch * self.num_mini_batch 88 | value_loss_epoch /= num_updates 89 | action_loss_epoch /= num_updates 90 | dist_entropy_epoch /= num_updates 91 | 92 | return value_loss_epoch, action_loss_epoch, dist_entropy_epoch, max_importance_weight_epoch, {} 93 | -------------------------------------------------------------------------------- /evkit/rl/distributions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .utils import init, init_normc_, AddBias 7 | 8 | """ 9 | Modify standard PyTorch distributions so they are compatible with this code. 
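A minimal sketch of the patched shapes, assuming a batch of 4 states and 6
discrete actions:

    dist = FixedCategorical(logits=torch.randn(4, 6))
    a = dist.sample()       # (4, 1): the trailing action dim is kept explicit
    lp = dist.log_probs(a)  # (4, 1), aligned with the sampled actions
    m = dist.mode()         # (4, 1), argmax of the class probabilities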
10 | """ 11 | 12 | FixedCategorical = torch.distributions.Categorical 13 | 14 | old_sample = FixedCategorical.sample 15 | FixedCategorical.sample = lambda self: old_sample(self).unsqueeze(-1) 16 | 17 | log_prob_cat = FixedCategorical.log_prob 18 | FixedCategorical.log_probs = lambda self, actions: log_prob_cat(self, actions.squeeze(-1)).unsqueeze(-1) 19 | 20 | FixedCategorical.mode = lambda self: self.probs.argmax(dim=1, keepdim=True) 21 | 22 | FixedNormal = torch.distributions.Normal 23 | log_prob_normal = FixedNormal.log_prob 24 | FixedNormal.log_probs = lambda self, actions: log_prob_normal(self, actions).sum(-1, keepdim=True) 25 | 26 | entropy = FixedNormal.entropy 27 | FixedNormal.entropy = lambda self: entropy(self).sum(-1) 28 | 29 | FixedNormal.mode = lambda self: self.mean 30 | 31 | 32 | class Categorical(nn.Module): 33 | def __init__(self, num_inputs, num_outputs): 34 | super(Categorical, self).__init__() 35 | self.num_outputs = num_outputs 36 | 37 | 38 | init_ = lambda m: init(m, 39 | nn.init.orthogonal_, 40 | lambda x: nn.init.constant_(x, 0), 41 | gain=0.01) 42 | 43 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 44 | 45 | def forward(self, x): 46 | x = self.linear(x) 47 | return FixedCategorical(logits=x) 48 | 49 | 50 | class DiagGaussian(nn.Module): 51 | def __init__(self, num_inputs, num_outputs): 52 | super(DiagGaussian, self).__init__() 53 | self.num_outputs = num_outputs 54 | 55 | init_ = lambda m: init(m, 56 | init_normc_, 57 | lambda x: nn.init.constant_(x, 0)) 58 | 59 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 60 | self.logstd = AddBias(torch.zeros(num_outputs)) 61 | 62 | def forward(self, x): 63 | action_mean = self.fc_mean(x) 64 | 65 | # An ugly hack for my KFAC implementation. 66 | zeros = torch.zeros(action_mean.size()) 67 | if x.is_cuda: 68 | zeros = zeros.cuda() 69 | 70 | action_logstd = self.logstd(zeros) 71 | return FixedNormal(action_mean, action_logstd.exp()) -------------------------------------------------------------------------------- /evkit/rl/preprocessing.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | 3 | class PreprocessingTranforms(object): 4 | def __init__(self, input_dims): 5 | pass 6 | 7 | def forward(self, batch): 8 | pass 9 | 10 | -------------------------------------------------------------------------------- /evkit/rl/requirements.txt: -------------------------------------------------------------------------------- 1 | gym 2 | matplotlib 3 | pybullet 4 | -------------------------------------------------------------------------------- /evkit/rl/storage/__init__.py: -------------------------------------------------------------------------------- 1 | from .stackedobservation import StackedSensorDictStorage, StackedTensorStorage 2 | from .rollout import RolloutSensorDictStorage, RolloutTensorStorage, RolloutSensorDictReplayBuffer, RolloutSensorDictDQNReplayBuffer 3 | from .memory import ReplayMemory 4 | 5 | # Curiosity 6 | from .rollout_curiosity import RolloutSensorDictCuriosityReplayBuffer 7 | 8 | -------------------------------------------------------------------------------- /evkit/rl/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Necessary for my KFAC implementation. 
6 | class AddBias(nn.Module): 7 | def __init__(self, bias): 8 | super(AddBias, self).__init__() 9 | self._bias = nn.Parameter(bias.unsqueeze(1)) 10 | 11 | def forward(self, x): 12 | if x.dim() == 2: 13 | bias = self._bias.t().view(1, -1) 14 | else: 15 | bias = self._bias.t().view(1, -1, 1, 1) 16 | 17 | return x + bias 18 | 19 | 20 | def init(module, weight_init, bias_init, gain=1): 21 | weight_init(module.weight.data, gain=gain) 22 | bias_init(module.bias.data) 23 | return module 24 | 25 | 26 | # https://github.com/openai/baselines/blob/master/baselines/common/tf_util.py#L87 27 | def init_normc_(weight, gain=1): 28 | weight.normal_(0, 1) 29 | weight *= gain / torch.sqrt(weight.pow(2).sum(1, keepdim=True)) 30 | -------------------------------------------------------------------------------- /evkit/saving/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/evkit/saving/__init__.py -------------------------------------------------------------------------------- /evkit/saving/checkpoints.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import shutil 5 | import subprocess 6 | import torch 7 | import dill as pickle 8 | import warnings 9 | 10 | def load_experiment_configs(log_dir, uuid=None): 11 | ''' 12 | Loads all experiments in a given directory 13 | Optionally, may be restricted to those with a given uuid 14 | ''' 15 | dirs = [f for f in os.listdir(log_dir) if os.path.isdir(os.path.join(log_dir, f))] 16 | results = [] 17 | for d in dirs: 18 | cfg_path = os.path.join(log_dir, d, 'config.json') 19 | if not os.path.exists(cfg_path): 20 | continue 21 | with open(os.path.join(log_dir, d, 'config.json'), 'r') as f: 22 | results.append(json.load(f)) 23 | if uuid is not None and results[-1]['uuid'] != uuid: 24 | results.pop() 25 | return results 26 | 27 | def load_experiment_config_paths(log_dir, uuid=None): 28 | dirs = [f for f in os.listdir(log_dir) if os.path.isdir(os.path.join(log_dir, f))] 29 | results = [] 30 | for d in dirs: 31 | cfg_path = os.path.join(log_dir, d, 'config.json') 32 | if not os.path.exists(cfg_path): 33 | continue 34 | with open(cfg_path, 'r') as f: 35 | cfg = json.load(f) 36 | results.append(cfg_path) 37 | if uuid is not None and cfg['uuid'] != uuid: 38 | results.pop() 39 | return results 40 | 41 | 42 | 43 | 44 | 45 | def checkpoint_name(checkpoint_dir, epoch='latest'): 46 | return os.path.join(checkpoint_dir, 'ckpt-{}.dat'.format(epoch)) 47 | 48 | def last_archived_run(base_dir, uuid): 49 | ''' Returns the name of the last archived run. Of the form: 50 | 'UUID_run_K' 51 | ''' 52 | archive_dir = os.path.join(base_dir, 'archive') 53 | existing_runs = glob.glob(os.path.join(archive_dir, uuid + "_run_*")) 54 | print(os.path.join(archive_dir, uuid + "_run_*")) 55 | if len(existing_runs) == 0: 56 | return None 57 | run_numbers = [int(run.split("_")[-1]) for run in existing_runs] 58 | current_run_number = max(run_numbers) if len(existing_runs) > 0 else 0 59 | current_run_archive_dir = os.path.join(archive_dir, "{}_run_{}".format(uuid, current_run_number)) 60 | return current_run_archive_dir 61 | 62 | def archive_current_run(base_dir, uuid): 63 | ''' Archives the current run. That is, it moves everything 64 | base_dir/*uuid* -> base_dir/archive/uuid_run_K/ 65 | where K is determined automatically. 
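E.g. with uuid 'abc123' and runs 0 and 1 already archived, everything matching
        base_dir/*abc123* is moved into base_dir/archive/abc123_run_2/.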
66 | ''' 67 | matching_files = glob.glob(os.path.join(base_dir, "*" + uuid + "*")) 68 | if len(matching_files) == 0: 69 | return 70 | 71 | archive_dir = os.path.join(base_dir, 'archive') 72 | os.makedirs(archive_dir, exist_ok=True) 73 | existing_runs = glob.glob(os.path.join(archive_dir, uuid + "_run_*")) 74 | run_numbers = [int(run.split("_")[-1]) for run in existing_runs] 75 | current_run_number = max(run_numbers) + 1 if len(existing_runs) > 0 else 0 76 | current_run_archive_dir = os.path.join(archive_dir, "{}_run_{}".format(uuid, current_run_number)) 77 | os.makedirs(current_run_archive_dir) 78 | for f in matching_files: 79 | shutil.move(f, current_run_archive_dir) 80 | return 81 | 82 | 83 | def save_checkpoint(obj, directory, step_num, use_thread=False): 84 | if use_thread: 85 | warnings.warn('use_thread=True, but saving is still done synchronously') 86 | os.makedirs(directory, exist_ok=True) 87 | torch.save(obj, checkpoint_name(directory), pickle_module=pickle) 88 | torch.save(obj, checkpoint_name(directory, step_num), pickle_module=pickle) # using `cp` leads to OSError for RL 89 | -------------------------------------------------------------------------------- /evkit/saving/monitor.py: -------------------------------------------------------------------------------- 1 | from gym.wrappers import Monitor 2 | import gym 3 | 4 | 5 | class VisdomMonitor(Monitor): 6 | 7 | def __init__(self, env, directory, 8 | video_callable=None, force=False, resume=False, 9 | write_upon_reset=False, uid=None, mode=None, 10 | server="localhost", visdom_env='main', port=8097): # visdom_env names the visdom environment; `env` is the wrapped gym env 11 | super(VisdomMonitor, self).__init__(env, directory, 12 | video_callable=video_callable, force=force, 13 | resume=resume, write_upon_reset=write_upon_reset, 14 | uid=uid, mode=mode) 15 | 16 | 17 | 18 | def _close_video_recorder(self): 19 | super(VisdomMonitor, self)._close_video_recorder() -------------------------------------------------------------------------------- /evkit/saving/naming.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def checkpoint_name(checkpoint_dir, epoch='latest'): 4 | return os.path.join(checkpoint_dir, 'ckpt-{}.dat'.format(epoch)) 5 | -------------------------------------------------------------------------------- /evkit/saving/observers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from sacred.observers import FileStorageObserver 4 | 5 | class FileStorageObserverWithExUuid(FileStorageObserver): 6 | ''' Wraps the FileStorageObserver so that we can pass in the Id. 7 | This allows us to save experiments into subdirectories with 8 | meaningful names.
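For example, a run configured with uuid 'abc123' is saved under 'abc123_metadata'.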
The standard FileStorageObserver just increments 9 | a counter.''' 10 | 11 | UNUSED_VALUE = -1 12 | 13 | def started_event(self, ex_info, command, host_info, start_time, config, 14 | meta_info, _id): 15 | _id = config['uuid'] + "_metadata" 16 | super().started_event(ex_info, command, host_info, start_time, config, 17 | meta_info, _id=_id) 18 | 19 | def queued_event(self, ex_info, command, host_info, queue_time, config, 20 | meta_info, _id): 21 | assert 'uuid' in config, "The config must contain a key 'uuid'" 22 | _id = config['uuid'] + "_metadata" 23 | super().queued_event(ex_info, command, host_info, queue_time, config, 24 | meta_info, _id=_id) -------------------------------------------------------------------------------- /evkit/saving/video.py: -------------------------------------------------------------------------------- 1 | import skvideo.io 2 | 3 | 4 | class VideoLogger(object): 5 | ''' Logs a video to a file, frame-by-frame 6 | 7 | All frames must be the same height. 8 | 9 | Example: 10 | >>> logger = VideoLogger("output.mp4") 11 | >>> for i in range(30): 12 | >>> logger.log(color_transitions_(i, n_frames, width, height)) 13 | >>> del logger # or, just let the logger go out of scope 14 | ''' 15 | 16 | def __init__(self, save_path, fps=30): 17 | fps = str(fps) 18 | self.writer = skvideo.io.FFmpegWriter(save_path, 19 | inputdict={'-r': fps}, 20 | outputdict={ 21 | '-vcodec': 'libx264', 22 | '-r': fps, 23 | }) 24 | self.f_open = False 25 | 26 | def log(self, frame): 27 | ''' Adds a frame to the file 28 | Parameters: 29 | frame: A HxWxC numpy array (uint8). All frames must be the same height 30 | ''' 31 | self.writer.writeFrame(frame) 32 | 33 | def close(self): 34 | try: 35 | self.writer.close() 36 | except AttributeError: 37 | pass 38 | 39 | def __del__(self): 40 | self.close() 41 | 42 | 43 | # def color_video_logger(logger_path, fps=30): 44 | # n_frames = 900 45 | # width, height = 640, 480 46 | # rate = '30' 47 | # writer = skvideo.io.FFmpegWriter("writer_test.mp4", inputdict={ 48 | # '-r': rate, 49 | # }, 50 | # outputdict={ 51 | # '-vcodec': 'libx264', 52 | # '-r': rate, 53 | # }) 54 | # for i in range(n_frames): 55 | # writer.writeFrame((fade_to_white(i, n_frames, width, height) * 255).astype(np.uint8)) 56 | # writer.close() 57 | 58 | 59 | 60 | ###################################### 61 | # TESTING 62 | ###################################### 63 | import numpy as np 64 | def color_transitions_(i, k, width, height): 65 | x = np.linspace(0, 1.0, width) 66 | y = np.linspace(0, 1.0, height) 67 | bg = np.array(np.meshgrid(x, y)) 68 | bg = (1.0 - (i / k)) * bg + (i / k) * (1 - bg) 69 | r = np.ones_like(bg[0][np.newaxis, ...]) * i / k 70 | return np.uint8(np.rollaxis(np.concatenate([bg, r], axis=0), 0, 3) * 255) 71 | -------------------------------------------------------------------------------- /evkit/sensors/__init__.py: -------------------------------------------------------------------------------- 1 | from .sensorpack import SensorPack as SensorDict 2 | from .sensorpack import SensorPack 3 | -------------------------------------------------------------------------------- /evkit/sensors/sensorpack.py: -------------------------------------------------------------------------------- 1 | class SensorPack(dict): 2 | ''' Fun fact, you can slice using np.s_. E.g.
3 | sensors.at(np.s_[:2]) 4 | ''' 5 | 6 | def at(self, val): 7 | return SensorPack({k: v[val] for k, v in self.items()}) 8 | 9 | def apply(self, lambda_fn): 10 | return SensorPack({k: lambda_fn(k, v) for k, v in self.items()}) 11 | 12 | def size(self, idx, key=None): 13 | assert idx == 0, 'can only get batch size for SensorPack' 14 | if key is None: 15 | key = list(self.keys())[0] 16 | return self[key].size(idx) 17 | -------------------------------------------------------------------------------- /evkit/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/evkit/utils/__init__.py -------------------------------------------------------------------------------- /evkit/utils/misc.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import pprint 4 | import re, string  # re is needed by get_number() below 5 | from evkit.preprocess.transforms import rescale_centercrop_resize, rescale, grayscale_rescale, cross_modal_transform, \ 6 | identity_transform, rescale_centercrop_resize_collated, map_pool_collated, map_pool, taskonomy_features_transform, \ 7 | image_to_input_collated, taskonomy_multi_features_transform 8 | from evkit.models.alexnet import alexnet_transform, alexnet_features_transform 9 | from evkit.preprocess.baseline_transforms import blind, pixels_as_state 10 | from evkit.models.srl_architectures import srl_features_transform 11 | import warnings 12 | remove_whitespace = str.maketrans('', '', string.whitespace) 13 | 14 | 15 | def cfg_to_md(cfg, uuid): 16 | ''' Because tensorboard uses markdown''' 17 | return uuid + "\n\n " + pprint.pformat((cfg)).replace("\n", " \n").replace("\n \'", "\n \'") + "" 18 | 19 | def count_trainable_parameters(model): 20 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 21 | 22 | def count_total_parameters(model): 23 | return sum(p.numel() for p in model.parameters()) 24 | 25 | def is_interactive(): 26 | try: 27 | ip = get_ipython() 28 | return ip.has_trait('kernel') 29 | except: 30 | return False 31 | 32 | def is_cuda(model): 33 | return next(model.parameters()).is_cuda 34 | 35 | 36 | class Bunch(object): 37 | def __init__(self, adict): 38 | self.__dict__.update(adict) 39 | self._keys, self._vals = zip(*adict.items()) 40 | self._keys, self._vals = list(self._keys), list(self._vals) 41 | 42 | def keys(self): 43 | return self._keys 44 | 45 | def vals(self): 46 | return self._vals 47 | 48 | 49 | def compute_weight_norm(parameters): 50 | ''' no grads!
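(returns the mean of the squared parameter values; operates on p.data, so nothing is tracked by autograd)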
''' 51 | total = 0.0 52 | count = 0 53 | for p in parameters: 54 | total += torch.sum(p.data**2) 55 | # total += p.numel() 56 | count += p.numel() 57 | return (total / count) 58 | 59 | def get_number(name): 60 | """ 61 | use regex to get the first integer in the name 62 | if none exists, return -1 63 | """ 64 | try: 65 | num = int(re.findall("[0-9]+", name)[0]) 66 | except: 67 | num = -1 68 | return num 69 | 70 | def append_dict(d, u, stop_recurse_keys=[]): 71 | for k, v in u.items(): 72 | if isinstance(v, collections.Mapping) and k not in stop_recurse_keys: 73 | d[k] = append_dict(d.get(k, {}), v, stop_recurse_keys=stop_recurse_keys) 74 | else: 75 | if k not in d: 76 | d[k] = [] 77 | d[k].append(v) 78 | return d 79 | 80 | def update_dict_deepcopy(d, u): # we need a deep dictionary update 81 | for k, v in u.items(): 82 | if isinstance(v, collections.Mapping): 83 | d[k] = update_dict_deepcopy(d.get(k, {}), v) 84 | else: 85 | d[k] = v 86 | return d 87 | 88 | 89 | def eval_dict_values(d): 90 | for k in d.keys(): 91 | if isinstance(d[k], collections.Mapping): 92 | d[k] = eval_dict_values(d[k]) 93 | elif isinstance(d[k], str): 94 | d[k] = eval(d[k].replace("---", "'")) 95 | return d 96 | 97 | 98 | def search_and_replace_dict(model_kwargs, task_initial): 99 | for k, v in model_kwargs.items(): 100 | if isinstance(v, collections.Mapping): 101 | search_and_replace_dict(v, task_initial) 102 | else: 103 | if isinstance(v, str) and 'encoder' in v and task_initial not in v: 104 | new_pth = v.replace('curvature', task_initial) # TODO make this the string between / and encoder 105 | warnings.warn(f'BE CAREFUL - CHANGING ENCODER PATH: {v} is being replaced for {new_pth}') 106 | model_kwargs[k] = new_pth 107 | return 108 | -------------------------------------------------------------------------------- /evkit/utils/parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | class _CustomDataParallel(nn.Module): 6 | def __init__(self, model, device_ids): 7 | super(_CustomDataParallel, self).__init__() 8 | self.model = nn.DataParallel(model, device_ids=device_ids) 9 | self.model.to(device) 10 | num_devices = torch.cuda.device_count() if device_ids is None else len(device_ids) 11 | print(f"{type(model)} using {num_devices} GPUs!") 12 | 13 | def forward(self, *input, **kwargs): 14 | return self.model(*input, **kwargs) 15 | 16 | def __getattr__(self, name): 17 | try: 18 | return super().__getattr__(name) 19 | except AttributeError: 20 | return getattr(self.model.module, name) -------------------------------------------------------------------------------- /evkit/utils/profiler.py: -------------------------------------------------------------------------------- 1 | ''' A simple profiler for logging ''' 2 | import logging 3 | import time 4 | 5 | class Profiler(object): 6 | def __init__(self, name, logger=None, level=logging.INFO): 7 | self.name = name 8 | self.logger = logger 9 | self.level = level 10 | 11 | def step( self, name ): 12 | """ Returns the duration and stepname since last step/start """ 13 | duration = self.summarize_step( start=self.step_start, step_name=name, level=self.level ) 14 | now = time.time() 15 | self.step_start = now 16 | return duration 17 | 18 | def __enter__( self ): 19 | self.start = time.time() 20 | self.step_start = time.time() 21 | return self 22 | 23 | def __exit__( self, exception_type, exception_value, traceback ): 24 | 
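        # On exiting a `with Profiler("epoch", logger) as p:` block, the total
        # elapsed time is logged as "epoch:complete: <secs> seconds"; calling
        # p.step("name") inside the block logs per-step durations the same way.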
self.summarize_step( self.start, step_name="complete" ) 25 | 26 | def summarize_step( self, start, step_name="", level=None ): 27 | duration = time.time() - start 28 | step_sep = ':' if step_name else "" 29 | if self.logger: 30 | level = level or self.level 31 | self.logger.log( level, "{name}{step}: {secs} seconds".format( name=self.name, step=step_sep + step_name, secs=duration) ) 32 | else: 33 | print("{name}{step}: {secs} seconds".format( name=self.name, step=step_sep + step_name, secs=duration)) 34 | return duration 35 | -------------------------------------------------------------------------------- /evkit/utils/random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | def set_seed(seed): 6 | torch.manual_seed(seed) 7 | torch.cuda.manual_seed(seed) 8 | torch.cuda.manual_seed_all(seed) 9 | np.random.seed(seed) 10 | random.seed(seed) -------------------------------------------------------------------------------- /evkit/utils/viz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/evkit/utils/viz/__init__.py -------------------------------------------------------------------------------- /feature_selector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/feature_selector/__init__.py -------------------------------------------------------------------------------- /feature_selector/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/feature_selector/models/__init__.py -------------------------------------------------------------------------------- /feature_selector/models/student_models.py: -------------------------------------------------------------------------------- 1 | from tlkit.models.student_models import FCN5 2 | FCN5Skip = FCN5 -------------------------------------------------------------------------------- /feature_selector/models/vision_transfer_architectures.py: -------------------------------------------------------------------------------- 1 | # This is here for loading old models 2 | from tlkit.models.student_models import FCN4Reshaped 3 | FCN5SkipCifar = FCN4Reshaped -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym==0.10.9 2 | numba 3 | scikit-image 4 | numpy 5 | torch==1.0.1.post2 6 | torchvision 7 | torchsummary 8 | visdom 9 | ipdb 10 | sacred==0.7.4 11 | matplotlib 12 | tensorflow==1.5.0 # not sure if really needed 13 | psutil 14 | sklearn 15 | tensorboard==1.12 16 | tensorboardX 17 | watchdog 18 | pytest==3.4.2 19 | jupyterlab 20 | ipykernel 21 | gitpython 22 | moviepy 23 | GPUtil 24 | scikit-fmm 25 | dill 26 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/scripts/__init__.py --------------------------------------------------------------------------------
/scripts/calculate_blind_transfer.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | import time 6 | from tqdm import tqdm 7 | import warnings 8 | 9 | from tlkit.data.splits import taskonomy_no_midlevel as split_taskonomy_no_midlevel 10 | from tlkit.utils import np_to_pil, pil_to_np 11 | from tnt.torchnet.meter.valuesummarymeter import ValueSummaryMeter 12 | from tnt.torchnet.meter.medianimagemeter import MedianImageMeter 13 | 14 | BASE_DIR = '/mnt/data' 15 | task = 'normal' # TODO only works for pix tasks, impl for logits (e.g. class_object) 16 | #building_dir = os.path.join(BASE_DIR, task, 'allensville') 17 | 18 | # Get image paths for all images for all splits 19 | split_to_images = {} 20 | for split, split_data in split_taskonomy_no_midlevel.items(): 21 | data_paths = [os.path.join(BASE_DIR, task, building) for building in split_data['train']] 22 | img_paths = [] 23 | for data_dir in data_paths: 24 | if os.path.exists(data_dir): 25 | img_paths.extend([os.path.join(data_dir, fn) for fn in os.listdir(data_dir)]) 26 | else: 27 | warnings.warn(f'{data_dir} is missing.') 28 | split_to_images[split] = img_paths 29 | print(split, len(split_data['train']), len(img_paths)) 30 | 31 | 32 | chunk = lambda l, n: [l[i:i+n] for i in range(0, len(l), max(1,n))] 33 | def compute_optimal_imgs(img_paths, use_pool=False): 34 | median_time, mean_time, pil_time = 0, 0, 0 35 | img_paths = [path for path in img_paths if '.png' in path] 36 | mean_meter = ValueSummaryMeter() 37 | median_meter = MedianImageMeter(bit_depth=8, im_shape=(256, 256, 3), device='cuda') 38 | p = Pool(6) 39 | for img_paths_chunk in tqdm(chunk(img_paths, 64)): 40 | t0 = time.time() 41 | if use_pool: 42 | imgs = p.map(Image.open, img_paths_chunk) 43 | else: 44 | imgs = [Image.open(img_path) for img_path in img_paths_chunk] 45 | t1 = time.time() 46 | for img in imgs: 47 | median_meter.add(pil_to_np(img)) # keep at uint8 - median wants discrete numbers 48 | t2 = time.time() 49 | for img in imgs: 50 | mean_meter.add(pil_to_np(img).astype(np.float32)) # convert to float - mean requires compute 51 | img.close() 52 | t3 = time.time() 53 | median_time += t2 - t1 54 | mean_time += t3 - t2 55 | pil_time += t1 - t0 56 | p.close() 57 | print('median', median_time, 'mean', mean_time, 'pil', pil_time) 58 | return np_to_pil(mean_meter.value()[0]), np_to_pil(median_meter.value()) 59 | 60 | for split, img_paths in split_to_images.items(): 61 | print(f'starting {split}') 62 | mean, median = compute_optimal_imgs(img_paths, use_pool=True) 63 | mean.save(os.path.join(BASE_DIR, task, f'mean_{split}.png')) 64 | median.save(os.path.join(BASE_DIR, task, f'median_{split}.png')) 65 | -------------------------------------------------------------------------------- /scripts/prep/copy.sh: -------------------------------------------------------------------------------- 1 | watch -n 60 "rsync --recursive /mnt/logdir2/$1 /mnt/logdir/" &>/dev/null &  # double quotes so $1 expands when the script runs 2 | -------------------------------------------------------------------------------- /scripts/prep/count_num_frames.py: -------------------------------------------------------------------------------- 1 | # count_num_frames in an expert trajectory directory (must point at a .../train directory) 2 | import os 3 | import shutil 4 | import sys 5 | from tqdm import tqdm 6 | import re 7 | 8 | try: 9 | buildings_dir = sys.argv[1] 10 | except: 11 | buildings_dir = '/mnt/data/expert_trajs/debug/train' 12 | 13 |
total_frames = 0 14 | n_episodes = 0 15 | for building in tqdm(os.listdir(buildings_dir)): 16 | episodes_dir = os.path.join(buildings_dir, building) 17 | for episode in sorted(os.listdir(episodes_dir)): 18 | episode_pth = os.path.join(episodes_dir, episode) 19 | last = sorted(os.listdir(episode_pth))[-1] 20 | num_frames = int(re.findall("[0-9]+", last)[0]) + 1 21 | total_frames += num_frames 22 | n_episodes += 1 23 | print(f'In {buildings_dir}, we have a total of {total_frames} frames in {n_episodes} episodes, averaging {total_frames//n_episodes} frames per episode') -------------------------------------------------------------------------------- /scripts/prep/csv_read.py: -------------------------------------------------------------------------------- 1 | # Reads the document that maps building to train/val/test split 2 | # Used in dataloader 3 | 4 | import csv 5 | mid_level_train_buildings = ['beechwood'] 6 | mid_level_test_buildings = ['aloha', 'ancor', 'corder', 'duarte', 'eagan', 'globe', 'hanson', 'hatfield', 'kemblesville', 'martinville', 'sweatman', 'vails', 'wiconisco'] 7 | 8 | forbidden_buildings = ['mosquito', 'castroville', 'goodyear'] 9 | 10 | with open('train_val_test_fullplus.csv') as csvfile: 11 | readCSV = csv.reader(csvfile, delimiter=',') 12 | 13 | train_list = [] 14 | val_list = [] 15 | test_list = [] 16 | 17 | for row in readCSV: 18 | name, is_train, is_val, is_test = row 19 | 20 | if is_train == '1': 21 | train_list.append(name) 22 | if is_val == '1': 23 | val_list.append(name) 24 | if is_test == '1': 25 | test_list.append(name) 26 | 27 | print(train_list) 28 | print(val_list) 29 | print(test_list) 30 | 31 | print(len(train_list)) 32 | print(len(val_list)) 33 | print(len(test_list)) 34 | 35 | 36 | l = [b for b in mid_level_test_buildings if b not in train_list] 37 | print(l) 38 | 39 | train_list = [b for b in train_list if b not in mid_level_test_buildings] 40 | print(len(train_list)) 41 | print(len(mid_level_test_buildings)) 42 | 43 | -------------------------------------------------------------------------------- /scripts/prep/download.sh: -------------------------------------------------------------------------------- 1 | # wget -i rgb_links_taskonomydata.txt 2 | wget http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/beechwood_rgb.tar 3 | -------------------------------------------------------------------------------- /scripts/prep/make_fewshot_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | ns = [5, 10, 500, 1000] 5 | classes = ['rgb', 'normal', 'curvature_encoding', 'mask_valid'] 6 | BASEDIR = '/mnt/hdd1/taskonomy/small/' 7 | DRY = False 8 | n = 5 9 | get_point_no = lambda x: re.search(r'\d+', x).group() 10 | get_images_from_dir = lambda wd: sorted([x for x in os.listdir(wd) if '.png' in x or '.npy' in x]) 11 | 12 | for n in ns: 13 | for cl in classes: 14 | print(f'---copying {n} from {cl} ---') 15 | old_dir = os.path.join(BASEDIR, cl, 'collierville') 16 | new_dir = os.path.join(BASEDIR, cl, f'collierville{n}') 17 | mkdir_cmd = f'mkdir -p {new_dir}' 18 | if DRY: 19 | print(mkdir_cmd) 20 | else: 21 | subprocess.Popen(mkdir_cmd, shell=True) 22 | images = get_images_from_dir(old_dir) 23 | seen_point_nos = set() 24 | for img in images: 25 | if len(seen_point_nos) >= n: # stop once n distinct points have been copied 26 | break 27 | point_no = get_point_no(img) 28 | if point_no not in seen_point_nos: 29 | seen_point_nos.add(point_no) 30 | src = os.path.join(old_dir, img) 31 | target =
os.path.join(new_dir, img) 32 | cp_cmd = f'cp {src} {new_dir}' 33 | if DRY: 34 | print(cp_cmd) 35 | else: 36 | subprocess.Popen(cp_cmd, shell=True) 37 | 38 | 39 | -------------------------------------------------------------------------------- /scripts/prep/make_fewshot_squad.py: -------------------------------------------------------------------------------- 1 | import sys 2 | N = eval(sys.argv[1]) if len(sys.argv) >= 2 else 10 3 | N = int(N) 4 | 5 | import json 6 | data = json.load(open('/mnt/data/squad2/train-v2.0.json', 'r')) 7 | data['data'] = data['data'][:1] 8 | data['data'][-1]['paragraphs'] = data['data'][-1]['paragraphs'][:N] 9 | num_examples = sum([len(x['qas']) for x in data['data'][-1]['paragraphs']]) 10 | 11 | data = json.dump(data, open(f'/mnt/data/squad2/trainfew{num_examples}-v2.0.json', 'w')) 12 | print(f'Created dataset with {num_examples} question answer pairs') 13 | -------------------------------------------------------------------------------- /scripts/prep/move_models.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | task_mapping = { 4 | 'autoencoder': 'autoencoding', 5 | 'colorization': 'colorization', 6 | 'curvature': 'curvature', 7 | 'denoise': 'denoising', 8 | 'edge2d':'edge_texture', 9 | 'edge3d': 'edge_occlusion', 10 | 'ego_motion': 'egomotion', 11 | 'fix_pose': 'fixated_pose', 12 | 'jigsaw': 'jigsaw', 13 | 'keypoint2d': 'keypoints2d', 14 | 'keypoint3d': 'keypoints3d', 15 | 'non_fixated_pose': 'nonfixated_pose', 16 | 'point_match': 'point_matching', 17 | 'reshade': 'reshading', 18 | 'rgb2depth': 'depth_zbuffer', 19 | 'rgb2mist': 'depth_euclidean', 20 | 'rgb2sfnorm': 'normal', 21 | 'room_layout': 'room_layout', 22 | 'segment25d': 'segment_unsup25d', 23 | 'segment2d': 'segment_unsup2d', 24 | 'segmentsemantic': 'segment_semantic', 25 | 'class_1000': 'class_object', 26 | 'class_places': 'class_scene', 27 | 'inpainting_whole': 'inpainting', 28 | 'vanishing_point': 'vanishing_point' 29 | } 30 | 31 | 32 | if __name__ == '__main__': 33 | for old, new in task_mapping.items(): 34 | subprocess.call(f'mv taskonomy_data/{old}_encoder.dat taskonomy_data/{new}_encoder.dat', shell=True) 35 | subprocess.call(f'mv taskonomy_data/{old}_decoder.dat taskonomy_data/{new}_decoder.dat', shell=True) 36 | -------------------------------------------------------------------------------- /scripts/prep/rm_empty_exp.py: -------------------------------------------------------------------------------- 1 | # used to remove extra distillation experiments 2 | import os 3 | import subprocess 4 | import sys 5 | 6 | exp_dir = sys.argv[1] 7 | exps = os.listdir(exp_dir) 8 | exp_paths = [os.path.join(exp_dir, exp) for exp in exps] 9 | 10 | num_empty_exp = 0 11 | for exp_path in exp_paths: 12 | ckpt_path = os.path.join(exp_path, 'checkpoints') 13 | if not os.path.exists(ckpt_path): # no checkpoint folder 14 | print('no ckpt dir') 15 | print(os.listdir(exp_path)) 16 | subprocess.call("rm -rf {}".format(exp_path), shell=True) 17 | num_empty_exp += 1 18 | continue 19 | 20 | ckpts = [f for f in os.listdir(ckpt_path) if 'ckpt' in f] 21 | if len(ckpts) < 4: # small checkpoint folder 22 | subprocess.call("rm -rf {}".format(exp_path), shell=True) 23 | num_empty_exp += 1 24 | else: # big checkpoint folder 25 | print('real exp') 26 | print(exp_path) 27 | print(ckpts) 28 | 29 | print(f'killed {num_empty_exp} experiments') 30 | -------------------------------------------------------------------------------- /scripts/prep/run_distillation.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ./scripts/run_distillation.sh zero tiny curvature fcn8 3 | 4 | run() { 5 | ALGO=$1 6 | SIZE=$2 7 | TASK=$3 8 | ARCH=$4 9 | echo "Distillation experiment:" $ALGO $SIZE $TASK $ARCH 10 | 11 | if [ "$TASK" == "curvature" ]; then 12 | EXTRA="loss_perceptual_l2" 13 | elif [ "$TASK" == "denoising" ]; then 14 | EXTRA="loss_perceptual" 15 | elif [ "$TASK" == "class_object" ]; then 16 | EXTRA="loss_perceptual_cross_entropy" 17 | else 18 | echo "BAD TASK" 19 | exit 20 | fi 21 | 22 | CMD="python -m tlkit.transfer \ 23 | /mnt/logdir/distillation/${ALGO}/${TASK}_${SIZE}_${ARCH} train with \ 24 | taskonomy_hp model_${ARCH} scheduler_reduce_on_plateau ${EXTRA} \ 25 | uuid=distil \ 26 | cfg.training.data_dir=/mnt/data \ 27 | cfg.training.split_to_use=splits.taskonomy_no_midlevel\[\'${SIZE}\'\] \ 28 | cfg.training.sources=\[\'rgb\'\] \ 29 | cfg.training.targets=\[\'${TASK}_encoding\'\] \ 30 | cfg.training.loss_kwargs.decoder_path='/mnt/models/${TASK}_decoder.dat' \ 31 | cfg.training.annotator_weights_path='/mnt/models/${TASK}_encoder.dat' \ 32 | cfg.training.seed=42 \ 33 | cfg.training.num_epochs=10 \ 34 | cfg.training.loss_kwargs.bake_decodings=False \ 35 | cfg.training.suppress_target_and_use_annotator=True \ 36 | cfg.training.resume_from_checkpoint_path=/mnt/logdir/distillation/${ALGO}/${TASK}_${SIZE}/checkpoints/ckpt-latest.dat \ 37 | cfg.training.resume_training=True \ 38 | cfg.training.algo=${ALGO}" 39 | 40 | echo $CMD 41 | bash -c "$CMD" 42 | } 43 | export -f run 44 | 45 | ALGOS='student zero' 46 | SIZES='debug tiny small medium large full fullplus' 47 | TASKS='curvature denoising class_object' 48 | ARCHS='fcn5_skip fcn8' 49 | 50 | run ${1} ${2} ${3} ${4} 51 | -------------------------------------------------------------------------------- /scripts/prep/shrink_rgb.sh: -------------------------------------------------------------------------------- 1 | 2 | model=$1 3 | task=rgb 4 | mkdir /mnt/hdd1/taskonomy/small/$task/$model 5 | mkdir /tmp/${task}/$model 6 | tar -xf /mnt/hdd2/taskonomy/${task}/${model}_${task}.tar -C /tmp/${task}/$model/ 7 | cd /root/feature_selector 8 | python -m feature_selector.shrink_images with data_dir=/tmp/ folders_to_convert=\[\"$model\"\] save_dir=/mnt/hdd1/taskonomy/small 9 | rm -rf /tmp/${task}/$model/${task} 10 | rmdir /tmp/${task}/$model 11 | -------------------------------------------------------------------------------- /scripts/prep/store_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | exp_name = sys.argv[1] 5 | in_file = f'/mnt/logdir/{exp_name}/checkpoints/ckpt-latest.dat' 6 | out_file = f'/mnt/logdir/{exp_name}/checkpoints/weights_and_more-latest.dat' 7 | 8 | ckpt = torch.load(in_file) 9 | agent = ckpt['agent'] 10 | epoch = ckpt['epoch'] 11 | new_ckpt = {'optimizer': agent.optimizer, 'state_dict': agent.actor_critic.state_dict(), 'epoch': epoch } 12 | 13 | torch.save(new_ckpt, out_file) 14 | 15 | 16 | ckpt = torch.load(out_file) 17 | agent.actor_critic.load_state_dict(ckpt['state_dict']) 18 | agent.optimizer = ckpt['optimizer'] 19 | 20 | print('done') -------------------------------------------------------------------------------- /scripts/prep/subsample_expert_trajs.py: -------------------------------------------------------------------------------- 1 | # sample from DATA_SOURCE every SUBSAMPLE_RATE episode and store under /mnt/data/expert_trajs/DATA_SIZE 2 | # take all 
val episodes 3 | 4 | import os 5 | import shutil 6 | import sys 7 | from tqdm import tqdm 8 | 9 | DATA_SIZE = sys.argv[1] if len(sys.argv) >= 4 else 'small' 10 | SUBSAMPLE_RATE = eval(sys.argv[2]) if len(sys.argv) >= 4 else 100 11 | DATA_SOURCE = sys.argv[3] if len(sys.argv) >= 4 else 'large' 12 | 13 | BASE_DIR = '/mnt/data/expert_trajs' 14 | SOURCE_DIR = os.path.join(BASE_DIR, DATA_SOURCE) 15 | TARGET_DIR = os.path.join(BASE_DIR, DATA_SIZE) 16 | buildings_dir = os.path.join(SOURCE_DIR, 'train') 17 | 18 | os.makedirs(os.path.join(BASE_DIR, DATA_SIZE, 'train'), exist_ok=True) 19 | 20 | # handle train split 21 | counter = 0 22 | for building in tqdm(os.listdir(buildings_dir)): 23 | episodes_dir = os.path.join(buildings_dir, building) 24 | for episode in sorted(os.listdir(episodes_dir)): 25 | episode_pth = os.path.join(episodes_dir, episode) 26 | counter += 1 27 | if counter % SUBSAMPLE_RATE == 0: 28 | copy_loc = episode_pth.replace(DATA_SOURCE, DATA_SIZE) 29 | shutil.copytree(episode_pth, copy_loc) 30 | 31 | print(f'Copied over {counter//SUBSAMPLE_RATE} training episodes from {DATA_SOURCE} to {DATA_SIZE}') 32 | 33 | # handle val/test splits 34 | os.symlink(os.path.join(SOURCE_DIR, 'val'), os.path.join(TARGET_DIR, 'val')) 35 | os.symlink(os.path.join(SOURCE_DIR, 'test'), os.path.join(TARGET_DIR, 'test')) 36 | print(f'Symlinked val/test splits from {DATA_SOURCE} to {DATA_SIZE}') 37 | -------------------------------------------------------------------------------- /scripts/prep/subsample_squad.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | subsample_rate = eval(sys.argv[1]) if len(sys.argv) > 1 else 2 5 | data_path = "/mnt/data/squad2/train-v2.0.json" 6 | 7 | data = json.load(open(data_path, 'r')) 8 | data['data'] = data['data'][::subsample_rate] 9 | with open(data_path.replace('train', f'train{subsample_rate}'), 'w') as f: 10 | json.dump(data, f) 11 | -------------------------------------------------------------------------------- /scripts/prep/subset_tars.py: -------------------------------------------------------------------------------- 1 | # Puts only models (buildings) from the specified split and task into a new directory 2 | 3 | import csv 4 | import os 5 | import subprocess 6 | import sys 7 | 8 | tar_dir = '/mnt/barn/data/taskonomy_small' 9 | 10 | task = '' # empty string does all 11 | SPLIT = 'tiny' 12 | new_dir = f'/mnt/jeff_data/data/{SPLIT}' 13 | dry_run = False 14 | split_csv = f'/root/perception_module/tlkit/data/splits_taskonomy/train_val_test_{SPLIT}.csv' 15 | 16 | with open(split_csv) as csvfile: 17 | readCSV = csv.reader(csvfile, delimiter=',') 18 | models = [row[0] for row in readCSV][1:] 19 | 20 | for model in models: 21 | source = f'{tar_dir}/{model}_{task}*' 22 | os.makedirs(new_dir, exist_ok=True) 23 | cmd = f'rsync -chavP --ignore-existing {source} {new_dir}' 24 | if dry_run: 25 | print(cmd) 26 | else: 27 | subprocess.call(cmd, shell=True) 28 | -------------------------------------------------------------------------------- /scripts/run_hps.py: -------------------------------------------------------------------------------- 1 | import sys 2 | LOG_DIR = sys.argv[1] # must be run before import transfer because transfer will pop the argv 3 | from scripts.train_transfer import ex 4 | from tlkit.utils import flatten 5 | import numpy as np 6 | import os 7 | import subprocess 8 | from evkit.saving.observers import FileStorageObserverWithExUuid 9 | 10 | @ex.command 11
| def run_hps(cfg, uuid): 12 | print(cfg) 13 | 14 | # Get argv 15 | argv_plus_hps = sys.argv 16 | script_name = argv_plus_hps[0] 17 | script_name = script_name.replace('.py','').replace('/','.') 18 | script_name = script_name[1:] if script_name.startswith('.') else script_name 19 | 20 | # Sample and load HPS into argv 21 | for hp, hp_range in flatten(cfg['hps_kwargs']['hp']).items(): 22 | hp_val = np.power(10, np.random.uniform(*hp_range)) 23 | argv_plus_hps.append(f'cfg.{hp}={hp_val}') 24 | 25 | # Update argv script name and uuid 26 | argv_plus_hps = [a.replace('run_hps', cfg['hps_kwargs']['script_name']) for a in argv_plus_hps] 27 | argv_plus_hps.append(f'uuid={uuid}_hps_run') 28 | 29 | # Run real experiment 30 | print(f'python -m {script_name} {LOG_DIR} {" ".join(argv_plus_hps[1:])}') 31 | ex.run_commandline(argv=argv_plus_hps) 32 | 33 | 34 | @ex.named_config 35 | def cfg_hps(): 36 | uuid='hps' 37 | cfg = {} 38 | cfg['hps_kwargs'] = { 39 | 'hp': { 40 | # pass in hp like you would for regular run. but instead of a number, pass in a log exp range 41 | # (if not log range, we will need to update, maybe with explicit dictionaries) 42 | 'learner': { 43 | 'lr': (-5, -3), 44 | 'optimizer_kwargs' : { 45 | 'weight_decay': (-6,-4) 46 | }, 47 | }, 48 | }, 49 | 'script_name': 'train', 50 | 'add_time_to_logdir': True, # TODO Should I edit logdir for uniqueness or do it manually? time, make it a parameter to do it or not 51 | } 52 | 53 | if __name__ == '__main__': 54 | assert LOG_DIR, 'log dir cannot be empty' 55 | os.makedirs(LOG_DIR, exist_ok=True) 56 | subprocess.call("rm -rf {}/*".format(LOG_DIR), shell=True) 57 | ex.observers.append(FileStorageObserverWithExUuid.create(LOG_DIR)) 58 | ex.run_commandline() 59 | else: 60 | print(__name__) 61 | 62 | 63 | -------------------------------------------------------------------------------- /scripts/run_lifelong_cifar.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ./scripts/run_lifelong_cifar.sh sidetune model_boosted_cifar 3 | 4 | MODELS='sidetune_reverse finetune independent features' 5 | EXTRAS='ewc bsp_norecurse_cifar init_xavier init_lowenergy_cifar' 6 | 7 | MODEL=$1 8 | EXTRA=${2:-} 9 | HP=${3:-} 10 | 11 | if [ "$MODEL" == "all" ]; then 12 | for MODEL in $MODELS; do 13 | export MODEL 14 | python -m feature_selector.lifelong \ 15 | /mnt/logdir/lifelong/icifar/${MODEL} \ 16 | train with cifar_hp icifar_data \ 17 | model_lifelong_${MODEL}_cifar 18 | done 19 | 20 | HP='4' 21 | for EXTRA in $EXTRAS; do 22 | export EXTRA 23 | python -m feature_selector.lifelong \ 24 | /mnt/logdir/lifelong/icifar/finetune_${EXTRA}${HP} \ 25 | train with cifar_hp icifar_data \ 26 | model_lifelong_finetune_cifar ${EXTRA}\ 27 | cfg.training.regularizer_kwargs.coef=1e${HP} 28 | done 29 | else 30 | if [ "$HP" == "" ]; then 31 | HP_TXT="" 32 | else 33 | HP_TXT="cfg.training.regularizer_kwargs.coef=1e${HP}" 34 | fi 35 | CMD="python -m scripts.train_lifelong \ 36 | /mnt/logdir/lifelong/icifar/${MODEL}_${EXTRA}${HP} \ 37 | train with cifar_hp icifar_data \ 38 | model_lifelong_${MODEL}_cifar ${EXTRA}\ 39 | ${HP_TXT}" 40 | 41 | echo $CMD 42 | bash -c "$CMD" 43 | fi 44 | 45 | -------------------------------------------------------------------------------- /scripts/run_lifelong_taskonomy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # MODEL='finetune' && SIZE='std' && SPLIT='12' && EXTRA='bsp' && HP='' 3 | # ./scripts/run_lifelong_taskonomy.sh finetune 
std ewc 2 4 | 5 | run() { 6 | MODEL=$1 7 | SIZE=$2 8 | SPLIT=$3 9 | EXTRA=$4 10 | HP=$5 11 | if [ "$MODEL" == "independent" ]; then 12 | BATCHSIZE=16 13 | else 14 | BATCHSIZE=32 15 | fi 16 | if [ "$HP" == "" ]; then 17 | HPTXT="" 18 | else 19 | HPTXT="cfg.training.regularizer_kwargs.coef=1e${HP}" 20 | fi 21 | CMD="python -m scripts.train_lifelong \ 22 | /mnt/logdir/lifelong/taskonomy/arxiv/${SPLIT}_${MODEL}_${SIZE}_${EXTRA}${HP} train with \ 23 | taskonomy_base_data taskonomy_${SPLIT}_data ${EXTRA} \ 24 | model_lifelong_${MODEL}_${SIZE}_taskonomy ${EXTRA} \ 25 | model_learned_decoder \ 26 | ${HPTXT} cfg.training.dataloader_fn_kwargs.batch_size=${BATCHSIZE}" 27 | echo $CMD 28 | bash -c "$CMD" 29 | } 30 | export -f run 31 | 32 | #SIZES="std resnet50 fcn5s" 33 | #SPLITS="3 12 shuffle12" 34 | #EXTRAS="bsp ewc init_xavier init_lowenergy debug pnn_v4 " 35 | MODELS="finetune sidetune independent" 36 | 37 | if [ "$1" == "all" ]; then 38 | for MODEL in $MODELS; do 39 | export MODEL 40 | bash -c "run ${MODEL} std 12" 41 | done 42 | 43 | bash -c "run finetune std 12 ewc 2" 44 | bash -c "run finetune std 12 bsp" 45 | else 46 | run ${1} ${2} ${3} ${4} ${5} 47 | fi 48 | 49 | #OLD COMMAND 50 | #MODEL='finetune' && SIZE='std' && SPLIT='big' && EXTRA='bsp' && HP='' && \ 51 | #python -m tlkit.lifelong \ 52 | # /mnt/logdir/lifelong/taskonomy/${SPLIT}_${MODEL}_${SIZE}_${EXTRA}${HP} train with \ 53 | # taskonomy_${SPLIT}_data ${EXTRA} \ 54 | # model_lifelong_${MODEL}_${SIZE}_taskonomy \ 55 | # model_learned_decoder \ 56 | # cfg.training.regularizer_kwargs.coef=1e${HP} 57 | -------------------------------------------------------------------------------- /scripts/run_nlp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ./scripts/run_nlp.sh train features 3 | 4 | METHODS='features sidetune finetune scratch' 5 | 6 | PHASE=$1 7 | METHOD=$2 8 | SIZE=${3:-1} 9 | 10 | if [ "$SIZE" == "1" ]; then 11 | SIZE="" 12 | NUM_EPOCHS=2 13 | CACHE_DIR="/mnt/data/squad2" 14 | elif [ "$SIZE" == "few125" ]; then 15 | NUM_EPOCHS=100 16 | CACHE_DIR="/mnt/data/squad2${SIZE}" 17 | else 18 | NUM_EPOCHS=5 19 | CACHE_DIR="/mnt/data/squad2_${SIZE}" 20 | fi 21 | 22 | if [ "$PHASE" == "train" ]; then 23 | CMD="python -m torch.distributed.launch --master_port=6011 --nproc_per_node=3 ./nlp/run_squad.py \ 24 | --model_type bert \ 25 | --model_name_or_path bert-large-uncased-whole-word-masking \ 26 | --do_train \ 27 | --do_eval \ 28 | --do_lower_case \ 29 | --train_file /mnt/data/squad2_${SIZE}/train${SIZE}-v2.0.json \ 30 | --predict_file /mnt/data/squad2_${SIZE}/dev-v2.0.json \ 31 | --version_2_with_negative \ 32 | --learning_rate 3e-5 \ 33 | --num_train_epochs ${NUM_EPOCHS} \ 34 | --max_seq_length 384 \ 35 | --doc_stride 128 \ 36 | --per_gpu_eval_batch_size=1 \ 37 | --per_gpu_train_batch_size=1 \ 38 | --save_steps 10000 \ 39 | --cache_dir ${CACHE_DIR} \ 40 | --output_dir /mnt/models/wwm_uncased_${METHOD}_squad${SIZE}/ \ 41 | --${METHOD}" 42 | elif [ "$PHASE" == "eval" ] || [ "$PHASE" == "test" ]; then 43 | CMD="python -m nlp.run_squad \ 44 | --model_type bert \ 45 | --model_name_or_path /mnt/models/wwm_uncased_${METHOD}_squad${SIZE}/ \ 46 | --config_name /mnt/models/wwm_uncased_${METHOD}_squad${SIZE}/config.json \ 47 | --do_eval \ 48 | --do_lower_case \ 49 | --train_file /mnt/data/squad2_${SIZE}/train${SIZE}-v2.0.json \ 50 | --predict_file /mnt/data/squad2_${SIZE}/dev-v2.0.json \ 51 | --version_2_with_negative \ 52 | --learning_rate 3e-5 \ 53 | --num_train_epochs 2 \ 54 | 
--max_seq_length 384 \ 55 | --doc_stride 128 \ 56 | --output_dir /mnt/models/wwm_uncased_${METHOD}_squad${SIZE}/ \ 57 | --per_gpu_eval_batch_size=2 \ 58 | --cache_dir ${CACHE_DIR} \ 59 | --${METHOD}" 60 | else 61 | echo BAD PHASE 62 | exit 63 | fi 64 | 65 | echo $CMD 66 | bash -c "$CMD" 67 | -------------------------------------------------------------------------------- /scripts/run_rl_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ./scripts/run_rl_eval.sh sidetune curvature 3 | # ./scripts/run_rl_eval.sh sidetune curvature 3700 4 | 5 | MODELS='sidetune finetune feat scratch' 6 | 7 | run() { 8 | MODEL=$1 9 | TASK=$2 10 | CKPTNUM=${3:-None} 11 | if [ "$MODEL" == "feat" ]; then 12 | MODELTXT="taskonomy_features \ 13 | cfg.learner.perception_network_kwargs.extra_kwargs.sidetune_kwargs.encoder_weights_path=/mnt/models/${TASK}_encoder_student.dat \ 14 | cfg.env.transform_fn_post_aggregation_kwargs.names_to_transforms.taskonomy='taskonomy_features_transform(---/mnt/models/${TASK}_encoder_student.dat---,model=---FCN5Skip---)'" 15 | elif [ "$MODEL" == "sidetune" ]; then 16 | MODELTXT="taskonomy_features sidetune \ 17 | cfg.learner.perception_network_kwargs.extra_kwargs.sidetune_kwargs.encoder_weights_path=/mnt/models/${TASK}_encoder_student.dat \ 18 | cfg.learner.perception_network_kwargs.extra_kwargs.sidetune_kwargs.sidetuner_network_weights_path=/mnt/models/${TASK}_encoder_student.dat \ 19 | cfg.env.transform_fn_post_aggregation_kwargs.names_to_transforms.taskonomy='taskonomy_features_transform(---/mnt/models/${TASK}_encoder_student.dat---,model=---FCN5Skip---)'" 20 | elif [ "$MODEL" == "finetune" ]; then 21 | MODELTXT="finetune rlgsn_encoder_learned \ 22 | cfg.learner.perception_network_kwargs.extra_kwargs.sidetune_kwargs.encoder_weights_path=/mnt/models/${TASK}_encoder_student.dat \ 23 | cfg.learner.perception_network_kwargs.extra_kwargs.sidetune_kwargs.sidetuner_network_weights_path=/mnt/models/${TASK}_encoder_student.dat" 24 | elif [ "$MODEL" == "scratch" ]; then 25 | MODELTXT="finetune rlgsn_encoder_learned" 26 | TASK='' 27 | else 28 | echo 'BAD MODEL' 29 | exit 30 | fi 31 | 32 | CMD="python -m scripts.train_rl /mnt/logdir/rl/planning/iclr_eval/${TASK}_${MODEL}_2fcn5s \ 33 | run_training with cfg_habitat planning cfg_test \ 34 | uuid=iclr_${TASK}_${MODEL}_eval \ 35 | rlgsn_encoder_fcn5s rlgsn_side_fcn5s ${MODELTXT} \ 36 | cfg.env.env_specific_kwargs.gpu_devices=\[1\] cfg.training.gpu_devices=\[0\] \ 37 | cfg.saving.checkpoint=/mnt/logdir/rl/planning/iclr/${TASK}_${MODEL}_2fcn5s cfg.training.resumable=True \ 38 | cfg.saving.checkpoint_num=${CKPTNUM}" 39 | # cfg.saving.checkpoint_configs=False 40 | echo $CMD 41 | bash -c "$CMD" 42 | } 43 | export -f run 44 | 45 | run ${1} ${2} ${3} -------------------------------------------------------------------------------- /scripts/run_rl_exps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ./scripts/run_rl_exps.sh small_settings5 sidetune curvature 3 | 4 | MODELS='sidetune finetune feat scratch' 5 | SETTING='small_settings5 corl_settings' 6 | #MODELSIZES='std 2fcn5s' 7 | #N_GPUS=4 8 | #EXTRAS='radam dreg' 9 | 10 | run() { 11 | SETTING=$1 12 | MODEL=$2 13 | TASK=$3 14 | if [ "$SETTING" != "small_settings5" ]; then 15 | echo 'BAD SETTING' 16 | exit 17 | fi 18 | if [ "$MODEL" == "feat" ]; then 19 | MODELTXT="taskonomy_features \ 20 | 
cfg.learner.perception_network_kwargs.extra_kwargs.sidetune_kwargs.base_weights_path=/mnt/models/${TASK}_encoder_student.dat \ 21 | cfg.env.transform_fn_post_aggregation_kwargs.names_to_transforms.taskonomy='taskonomy_features_transform(---/mnt/models/${TASK}_encoder_student.dat---,model=---FCN5---)'" 22 | elif [ "$MODEL" == "sidetune" ]; then 23 | MODELTXT="taskonomy_features sidetune \ 24 | cfg.learner.perception_network_kwargs.extra_kwargs.sidetune_kwargs.base_weights_path=/mnt/models/${TASK}_encoder_student.dat \ 25 | cfg.learner.perception_network_kwargs.extra_kwargs.sidetune_kwargs.side_weights_path=/mnt/models/${TASK}_encoder_student.dat \ 26 | cfg.env.transform_fn_post_aggregation_kwargs.names_to_transforms.taskonomy='taskonomy_features_transform(---/mnt/models/${TASK}_encoder_student.dat---,model=---FCN5---)'" 27 | elif [ "$MODEL" == "finetune" ]; then 28 | MODELTXT="finetune rlgsn_base_learned \ 29 | cfg.learner.perception_network_kwargs.extra_kwargs.sidetune_kwargs.base_weights_path=/mnt/models/${TASK}_encoder_student.dat \ 30 | cfg.learner.perception_network_kwargs.extra_kwargs.sidetune_kwargs.side_weights_path=/mnt/models/${TASK}_encoder_student.dat" 31 | elif [ "$MODEL" == "scratch" ]; then 32 | MODELTXT="finetune rlgsn_base_learned" 33 | TASK='' 34 | else 35 | echo 'BAD MODEL' 36 | exit 37 | fi 38 | 39 | CMD="python -m scripts.train_rl /mnt/logdir/rl/planning/iclr/${TASK}_${MODEL}_2fcn5s \ 40 | run_training with cfg_habitat planning ${SETTING} \ 41 | uuid=iclr_${TASK}_${MODEL} \ 42 | ${MODELTXT} rlgsn_base_fcn5s rlgsn_side_fcn5s \ 43 | cfg.env.env_specific_kwargs.gpu_devices=\[1\] cfg.training.gpu_devices=\[0\]" 44 | echo $CMD 45 | bash -c "$CMD" 46 | } 47 | export -f run 48 | 49 | run ${1} ${2} ${3} -------------------------------------------------------------------------------- /scripts/run_vision_transfer.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ./scripts/run_vision_transfer.sh curvature normal few100 sidetune 3 | 4 | # This script only supports the following settings (others can be set up): 5 | SOURCE_TASK='normal curvature class_object' 6 | TARGET_TASK='normal curvature class_object' 7 | DATA_SIZE='few100 fullplus' 8 | MODEL='sidetune' 9 | 10 | run() { 11 | SOURCE=$1 12 | TARGET=$2 13 | DATA_SIZE=$3 14 | MODEL=$4 15 | 16 | if [ "$TARGET" == "normal" ]; then 17 | TARGET_TXT="cfg.training.loss_fn='weighted_l1_loss'" 18 | elif [ "$TARGET" == "curvature" ]; then 19 | TARGET_TXT="cfg.training.loss_fn='weighted_l2_loss'" 20 | elif [ "$TARGET" == "class_object" ]; then 21 | TARGET_TXT="cfg.training.loss_fn='softmax_cross_entropy'" 22 | else 23 | echo 'Not set up for current target task' 24 | exit 25 | fi 26 | 27 | CMD=" 28 | python -m scripts.train_transfer \ 29 | /mnt/logdir/vision_transfers/arxiv_code/${SOURCE}_to_${TARGET}_${DATA_SIZE}_${MODEL} \ 30 | train with \ 31 | gsn_transfer_residual_prenorm \ 32 | taskonomy_hp \ 33 | cfg.training.data_dir=/mnt/data \ 34 | cfg.learner.max_grad_norm=1 \ 35 | model_sidetune_encoding gsn_side_resnet50 \ 36 | cfg.learner.model_kwargs.base_weights_path='/mnt/models/${SOURCE}_encoder.dat' \ 37 | cfg.learner.model_kwargs.side_weights_path='/mnt/models/${SOURCE}_encoder.dat' \ 38 | cfg.learner.model_kwargs.use_baked_encoding=False \ 39 | cfg.training.sources=\[\'rgb\'\] \ 40 | cfg.training.targets=\[\'${TARGET}\'\] \ 41 | ${TARGET_TXT} \ 42 | data_size_${DATA_SIZE} 43 | " 44 | echo $CMD 45 | bash -c "$CMD" 46 | } 47 | export -f run 48 | 49 | 50 | run ${1} ${2} ${3}
${4} 51 | 52 | -------------------------------------------------------------------------------- /tlkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/tlkit/__init__.py -------------------------------------------------------------------------------- /tlkit/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/tlkit/data/__init__.py -------------------------------------------------------------------------------- /tlkit/data/datasets/fashion_mnist_dataset.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch.utils.data import DataLoader 3 | 4 | def get_dataloaders(data_path, 5 | inputs_and_outputs, 6 | batch_size=64, 7 | batch_size_val=4, 8 | transform=None, 9 | num_workers=0, 10 | load_to_mem=False, 11 | pin_memory=False): 12 | 13 | dataloaders = {} 14 | dataset = torchvision.datasets.FashionMNIST(data_path, train=True, transform=transform, target_transform=None, download=True) 15 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) 16 | dataloaders['train'] = dataloader 17 | 18 | dataset = torchvision.datasets.FashionMNIST(data_path, train=False, transform=transform, target_transform=None, download=False) 19 | dataloader = DataLoader(dataset, batch_size=batch_size_val, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) 20 | dataloaders['val'] = dataloader 21 | 22 | dataset = torchvision.datasets.FashionMNIST(data_path, train=False, transform=transform, target_transform=None, download=False) 23 | dataloader = DataLoader(dataset, batch_size=batch_size_val, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) 24 | dataloaders['test'] = dataloader 25 | return dataloaders -------------------------------------------------------------------------------- /tlkit/data/splits.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | forbidden_buildings = ['mosquito', 'tansboro'] 5 | # forbidden_buildings = ['mosquito', 'tansboro', 'tomkins', 'darnestown', 'brinnon'] 6 | # We do not have the rgb data for tomkins, darnestown, brinnon 7 | 8 | SPLIT_TO_NUM_IMAGES = { 9 | 'few100': 100, 10 | 'debug2': 100, 11 | 'debug': 2863, 12 | 'supersmall': 14575, 13 | 'tiny': 262745, 14 | 'fullplus': 3349691, 15 | } 16 | 17 | def get_splits(split_path): 18 | with open(split_path) as csvfile: 19 | readCSV = csv.reader(csvfile, delimiter=',') 20 | 21 | train_list = [] 22 | val_list = [] 23 | test_list = [] 24 | 25 | for row in readCSV: 26 | name, is_train, is_val, is_test = row 27 | if name in forbidden_buildings: 28 | continue 29 | if is_train == '1': 30 | train_list.append(name) 31 | if is_val == '1': 32 | val_list.append(name) 33 | if is_test == '1': 34 | test_list.append(name) 35 | return { 36 | 'train': sorted(train_list), 37 | 'val': sorted(val_list), 38 | 'test': sorted(test_list) 39 | } 40 | 41 | 42 | subsets = ['debug', 'tiny', 'medium', 'full', 'fullplus', 'supersmall', 'few5', 'few100', 'few500', 'few1000', 'debug2'] 43 | split_files = {s: os.path.join(os.path.dirname(__file__), 44 | 'splits_taskonomy', 45 | 'train_val_test_{}.csv'.format(s.lower())) 46 | for s in subsets} 47 | 48 | taskonomy = {s: get_splits(split_files[s])
for s in subsets} 49 | 50 | midlevel = { 51 | 'train': ['beechwood'], 52 | 'test': ['aloha', 'ancor', 'corder', 'duarte', 'eagan', 'globe', 'hanson', 'hatfield', 'kemblesville', 'martinville', 'sweatman', 'vails', 'wiconisco'] 53 | } 54 | taskonomy_no_midlevel = {subset: {split: sorted(buildings) for split, buildings in taskonomy[subset].items()} 55 | for subset in taskonomy.keys()} 56 | for subset, splits in taskonomy_no_midlevel.items(): 57 | taskonomy_no_midlevel[subset]['train'] = [b for b in splits['train'] if b not in midlevel['test']] 58 | -------------------------------------------------------------------------------- /tlkit/data/splits_taskonomy/train_val_test_debug.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test allensville,1,1,1 -------------------------------------------------------------------------------- /tlkit/data/splits_taskonomy/train_val_test_debug2.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test collierville100,1,1,1 -------------------------------------------------------------------------------- /tlkit/data/splits_taskonomy/train_val_test_few100.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test 2 | collierville100,1,0,0 3 | ihlen,0,1,0 4 | mcdade,0,1,0 5 | muleshoe,0,1,0 6 | noxapater,0,1,0 7 | uvalda,0,1,0 8 | allensville,0,0,1 -------------------------------------------------------------------------------- /tlkit/data/splits_taskonomy/train_val_test_few1000.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test 2 | collierville1000,1,0,0 3 | ihlen,0,1,0 4 | mcdade,0,1,0 5 | muleshoe,0,1,0 6 | noxapater,0,1,0 7 | uvalda,0,1,0 8 | allensville,0,0,1 -------------------------------------------------------------------------------- /tlkit/data/splits_taskonomy/train_val_test_few5.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test 2 | collierville5,1,0,0 3 | ihlen,0,1,0 4 | mcdade,0,1,0 5 | muleshoe,0,1,0 6 | noxapater,0,1,0 7 | uvalda,0,1,0 8 | allensville,0,0,1 -------------------------------------------------------------------------------- /tlkit/data/splits_taskonomy/train_val_test_few500.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test 2 | collierville500,1,0,0 3 | ihlen,0,1,0 4 | mcdade,0,1,0 5 | muleshoe,0,1,0 6 | noxapater,0,1,0 7 | uvalda,0,1,0 8 | allensville,0,0,1 -------------------------------------------------------------------------------- /tlkit/data/splits_taskonomy/train_val_test_medium.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test hanson,1,0,0 merom,1,0,0 goodfield,1,0,0 eagan,1,0,0 adairsville,1,0,0 castor,1,0,0 klickitat,1,0,0 cottonport,1,0,0 tyler,1,0,0 sugarville,1,0,0 martinville,1,0,0 chilhowie,1,0,0 silas,1,0,0 lynchburg,1,0,0 tokeland,1,0,0 onaga,1,0,0 frankfort,1,0,0 goodyear,1,0,0 albertville,1,0,0 andover,1,0,0 airport,1,0,0 rogue,1,0,0 ancor,1,0,0 leonardo,1,0,0 maida,1,0,0 marstons,1,0,0 athens,1,0,0 newfields,1,0,0 broseley,1,0,0 irvine,1,0,0 pinesdale,1,0,0 tilghmanton,1,0,0 goodwine,1,0,0 hildebran,1,0,0 winooski,1,0,0 lakeville,1,0,0 cosmos,1,0,0 goffs,1,0,0 sunshine,1,0,0 globe,1,0,0 benevolence,1,0,0 emmaus,1,0,0 pomaria,1,0,0 neibert,1,0,0 parole,1,0,0 tolstoy,1,0,0 shelbyville,1,0,0 potterville,1,0,0 rosser,1,0,0 allensville,1,0,0 
springerville,1,0,0 nuevo,1,0,0 stilwell,1,0,0 browntown,1,0,0 readsboro,1,0,0 shelbiana,1,0,0 wainscott,1,0,0 arkansaw,1,0,0 bonnie,1,0,0 beechwood,1,0,0 hominy,1,0,0 churchton,1,0,0 coffeen,1,0,0 willow,1,0,0 timberon,1,0,0 bohemia,1,0,0 micanopy,1,0,0 hillsdale,1,0,0 wilseyville,1,0,0 kemblesville,1,0,0 thrall,1,0,0 bonesteel,1,0,0 annona,1,0,0 stockman,1,0,0 soldier,1,0,0 neshkoro,1,0,0 newcomb,1,0,0 byers,1,0,0 oyens,1,0,0 victorville,1,0,0 pamelia,1,0,0 marland,1,0,0 hiteman,1,0,0 sussex,1,0,0 bautista,1,0,0 highspire,1,0,0 woodbine,1,0,0 sweatman,1,0,0 clairton,1,0,0 touhy,1,0,0 lindenwood,1,0,0 anaheim,1,0,0 duarte,1,0,0 musicks,1,0,0 forkland,1,0,0 mifflinburg,1,0,0 hainesburg,1,0,0 maugansville,1,0,0 ranchester,1,0,0 hortense,0,1,0 southfield,0,1,0 wiconisco,0,1,0 gravelly,0,1,0 hordville,0,1,0 corozal,0,1,0 swormville,0,1,0 collierville,0,1,0 pearce,0,1,0 pablo,0,1,0 pittsburg,0,1,0 markleeville,0,1,0 sands,0,1,0 kobuk,0,1,0 westfield,0,1,0 wyldwood,0,1,0 swisshome,0,1,0 scioto,0,1,0 waipahu,0,1,0 darden,0,1,0 brinnon,0,0,1 ihlen,0,0,1 darrtown,0,0,1 cousins,0,0,1 muleshoe,0,0,1 uvalda,0,0,1 donaldson,0,0,1 poipu,0,0,1 rockport,0,0,1 cauthron,0,0,1 german,0,0,1 edson,0,0,1 wando,0,0,1 noxapater,0,0,1 mcdade,0,0,1 helton,0,0,1 natural,0,0,1 cochranton,0,0,1 losantville,0,0,1 -------------------------------------------------------------------------------- /tlkit/data/splits_taskonomy/train_val_test_supersmall.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test 2 | collierville,1,0,0 3 | corozal,1,0,0 4 | ihlen,0,1,0 5 | mcdade,0,1,0 6 | muleshoe,0,1,0 7 | noxapater,0,1,0 8 | uvalda,0,1,0 9 | allensville,0,0,1 -------------------------------------------------------------------------------- /tlkit/data/splits_taskonomy/train_val_test_tiny.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test 2 | hanson,1,0,0 3 | merom,1,0,0 4 | klickitat,1,0,0 5 | onaga,1,0,0 6 | leonardo,1,0,0 7 | marstons,1,0,0 8 | newfields,1,0,0 9 | pinesdale,1,0,0 10 | lakeville,1,0,0 11 | cosmos,1,0,0 12 | benevolence,1,0,0 13 | pomaria,1,0,0 14 | tolstoy,1,0,0 15 | shelbyville,1,0,0 16 | allensville,1,0,0 17 | wainscott,1,0,0 18 | beechwood,1,0,0 19 | coffeen,1,0,0 20 | stockman,1,0,0 21 | hiteman,1,0,0 22 | woodbine,1,0,0 23 | lindenwood,1,0,0 24 | forkland,1,0,0 25 | mifflinburg,1,0,0 26 | ranchester,1,0,0 27 | wiconisco,0,1,0 28 | corozal,0,1,0 29 | collierville,0,1,0 30 | markleeville,0,1,0 31 | darden,0,1,0 32 | ihlen,0,0,1 33 | muleshoe,0,0,1 34 | uvalda,0,0,1 35 | noxapater,0,0,1 36 | mcdade,0,0,1 -------------------------------------------------------------------------------- /tlkit/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/tlkit/models/__init__.py -------------------------------------------------------------------------------- /tlkit/models/basic_models.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | # Because we cannot pickle lambda functions 7 | class IdentityFn(nn.Module): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__() 10 | 11 | def forward(self, x, **kwargs): 12 | return x 13 | 14 | def requires_grad_(self, *args, **kwargs): 15 | pass 16 | 17 | def identity_fn(x): 18 | return x 19 | 20 | class 
ZeroFn(nn.Module): 21 | def forward(self, *args, **kwargs): 22 | return 0.0 23 | 24 | def requires_grad_(self, *args, **kwargs): 25 | pass 26 | 27 | def zero_fn(x): 28 | return 0.0 29 | 30 | class ScaleLayer(nn.Module): 31 | def __init__(self, init_value=1e-3): 32 | super().__init__() 33 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 34 | 35 | def forward(self, input): 36 | return input * self.scale 37 | 38 | class LambdaLayer(nn.Module): 39 | def __init__(self, lambd): 40 | super(LambdaLayer, self).__init__() 41 | self.lambd = lambd 42 | 43 | def forward(self, x): 44 | return self.lambd(x) 45 | 46 | class ResidualLayer(nn.Module): 47 | def __init__(self, net: nn.Module): 48 | super().__init__() 49 | self.net = net 50 | 51 | def forward(self, x): 52 | return x + self.net(x) 53 | 54 | class EvalOnlyModel(nn.Module): 55 | def __init__(self, eval_only=None, train=False, **kwargs): 56 | super().__init__() 57 | if eval_only is None: 58 | warnings.warn(f'Model eval_only flag is not set for {type(self)}. Defaulting to True') 59 | eval_only = True 60 | 61 | if train: 62 | warnings.warn('Model train flag is deprecated') 63 | 64 | self.eval_only = eval_only 65 | 66 | 67 | def forward(self, x, cache={}, time_idx:int=-1): 68 | pass 69 | 70 | def train(self, train): 71 | if self.eval_only: 72 | super().train(False) 73 | for p in self.parameters(): # This must be done after parameters are initialized 74 | p.requires_grad = False 75 | 76 | if train and self.eval_only: 77 | warnings.warn(f"Ignoring 'train()' in {type(self).__name__} since 'eval_only' was set during initialization.", RuntimeWarning) 78 | else: 79 | return super().train(train) -------------------------------------------------------------------------------- /tlkit/models/feedback.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .student_models import FCN5 4 | from .model_utils import _make_layer, upsampler 5 | 6 | class FCN5MidFeedback(FCN5): 7 | def __init__(self, kernel_size=3, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | if kernel_size == 3: 10 | net_kwargs = { 'kernel_size': 3, 'stride': 1, 'padding': 1 } 11 | elif kernel_size == 1: 12 | net_kwargs = { 'kernel_size': 1, 'stride': 1, 'padding': 0 } 13 | else: 14 | assert False, f'kernel size not recognized ({kernel_size})' 15 | self.fb_conv1 = _make_layer(8, 64, **net_kwargs) 16 | self.fb_conv2 = _make_layer(64, 8, **net_kwargs) 17 | self.feedback_net = nn.Sequential(self.fb_conv1, self.fb_conv2) 18 | 19 | def forward(self, x, task_idx:int=-1, cache={}): 20 | # Prepare feedback input 21 | last_repr = cache['last_repr'] 22 | last_repr_tweaked = self.feedback_net(last_repr) 23 | last_repr_tweaked = upsampler(last_repr_tweaked) 24 | last_repr_tweaked = last_repr_tweaked.repeat(1,256//8,1,1) 25 | 26 | # Run forward 27 | x = self.conv1(x) 28 | x = self.conv2(x) 29 | x = x + last_repr_tweaked 30 | x2 = x 31 | x = self.conv3(x) 32 | x = self.conv4(x) 33 | x = self.conv5(x) 34 | x = x + self.skip(x2) 35 | 36 | if self.normalize_outputs: 37 | x = self.groupnorm(x) 38 | return last_repr + x 39 | 40 | class FCN5LateFeedback(FCN5): 41 | # Late Feedback because the cache information is not incorporated in the base FCN5Skip 42 | # Instead, it is used later to augment the output 43 | # This is not true feedback:
the output is not fed back into the input. 44 | def __init__(self, kernel_size=3, *args, **kwargs): 45 | super().__init__(*args, **kwargs) 46 | if kernel_size == 3: 47 | net_kwargs = { 'kernel_size': 3, 'stride': 1, 'padding': 1 } 48 | elif kernel_size == 1: 49 | net_kwargs = { 'kernel_size': 1, 'stride': 1, 'padding': 0 } 50 | else: 51 | assert False, f'kernel size not recognized ({kernel_size})' 52 | self.fb_conv1 = _make_layer(8, 64, **net_kwargs) 53 | self.fb_conv2 = _make_layer(64, 8, **net_kwargs) 54 | self.feedback_net = nn.Sequential(self.fb_conv1, self.fb_conv2) 55 | 56 | def forward(self, x, task_idx:int=-1, cache={}): 57 | last_repr = cache['last_repr'] 58 | 59 | ret_input_only = super().forward(x, task_idx) 60 | ret_output_only = self.feedback_net(last_repr) 61 | return last_repr + ret_input_only + ret_output_only 62 | -------------------------------------------------------------------------------- /tlkit/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.modules.upsampling import Upsample 3 | 4 | from tlkit.utils import load_state_dict_from_path 5 | from .superposition import HashConv2d, ProjectedConv2d 6 | from .basic_models import zero_fn, ScaleLayer 7 | 8 | upsampler = Upsample(scale_factor=2, mode='nearest') 9 | 10 | def load_submodule(model_class, model_weights_path, model_kwargs, backup_fn=zero_fn): 11 | # If there is a model, use it! If there is initialization, use it! If neither, use backup_fn 12 | if model_class is not None: 13 | model = model_class(**model_kwargs) 14 | if model_weights_path is not None: 15 | model, _ = load_state_dict_from_path(model, model_weights_path) 16 | else: 17 | model = backup_fn 18 | assert model_weights_path is None, 'cannot have weights without model' 19 | return model 20 | 21 | def _make_layer(in_channels, out_channels, num_groups=2, kernel_size=3, stride=1, padding=0, dilation=1, normalize=True, 22 | bsp=False, period=None, debug=False, projected=False, scaling=False, postlinear=False, linear=False): 23 | assert not (bsp and projected), 'cannot do bsp and projectedconv' 24 | if linear: 25 | conv = nn.Linear(in_channels, out_channels, bias=False) 26 | elif bsp: 27 | assert dilation == 1, 'Dilation is not implemented for binary superposition' 28 | assert period is not None, 'Need to specify period' 29 | conv = HashConv2d(in_channels, out_channels, kernel_size=kernel_size, period=period, stride=stride, padding=padding, bias=False, debug=debug) 30 | elif projected: 31 | assert dilation == 1, 'Dilation is not implemented for projected conv' 32 | conv = ProjectedConv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) 33 | else: 34 | conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False, dilation=dilation) 35 | gn = nn.GroupNorm(num_groups, out_channels) 36 | relu = nn.ReLU() 37 | 38 | layers = [conv, relu] 39 | if normalize: 40 | layers = [conv, gn, relu] 41 | if scaling: 42 | layers = [ScaleLayer(.9)] + layers 43 | if postlinear: 44 | if linear: 45 | layers = layers + [nn.Linear(out_channels, out_channels)] # out -> out, mirroring the conv branch below 46 | else: 47 | layers = layers + [nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)] 48 | return nn.Sequential(*layers) 49 | 50 | -------------------------------------------------------------------------------- /tlkit/models/vision_transfer_architectures.py:
-------------------------------------------------------------------------------- 1 | from tlkit.models.student_models import FCN4Reshaped 2 | FCN5SkipCifar = FCN4Reshaped -------------------------------------------------------------------------------- /tnt/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | build/ 3 | dist/ 4 | docs/_build/ 5 | torchnet.egg-info/ 6 | */**/__pycache__ 7 | */**/*.pyc 8 | *.ipynb_checkpoints 9 | -------------------------------------------------------------------------------- /tnt/.travis.yml: -------------------------------------------------------------------------------- 1 | # https://travis-ci.org/pytorch/pytorch 2 | # https://github.com/uber/pyro/blob/dev/.travis.yml 3 | 4 | language: python 5 | env: 6 | global: 7 | - PYTHONPATH=$PWD:$PYTHONPATH 8 | install: 9 | - pip install -U pip 10 | - pip install --progress-bar off -e .[test] 11 | - pip freeze 12 | jobs: 13 | fast_finish: true 14 | include: 15 | - stage: lint 16 | python: 2.7 17 | before_install: pip install -U pip; pip install --progress-bar off flake8 nbstripout 18 | nbformat torch 19 | install: 20 | script: flake8 21 | - stage: unit test 22 | python: 2.7 23 | script: "./test/run_test.sh" 24 | - python: 3.5 25 | script: "./test/run_test.sh" 26 | - stage: deploy 27 | script: echo "Deploying to PyPI..." 28 | deploy: 29 | provider: pypi 30 | user: pytorch 31 | password: 32 | secure: C8v2M7QaeN4xsM05N9MonP+/twtFsxTd0QWsLGSFrm59jT/qur63sCxP7IAmpKYm72zd7F8sLiadyBHsRAgzHwz2zVqGSjMNon5+44aEOOZV7SVqqXudWU5Pr74Wrn/ZQ2ezMf99Tg2pcTgTpmOAtEyc+hOz91IT857tzMR6jy13jYVQN1cVGwtcOrxAGngUqlaNegp5s2Ja9+XH9dyzwpDkgTrllg6r7mCiC7Xy4hKViTmrA0RMD13X/5UFq7t+181RosZDbxjtv2elrpTeWt0CCSel+B9DQQZOQeY5XM+GevcoYwM96IxLPt8aoAFGgR3JlYeiy5NJjR+xXbsjPBwdnkcQTyMsmGYb4vrTqWfsijYc033auPSxvPqpnh4ql3wqCciiQ79Rfhxc8q/AZMayoPdxma2JhPEScfsx0AEMxwWOnLlm6NiHyPEkuZQTU7YqlBSuk9sIAiQDoPI/GISGNAZbfjHWb9DBa/8AYPNOfnE7vBAhgeZnFpenedfUD6mIKrY7tA5QIl9Pnm0Lj5iL0739yKXGHaNZbEFOLe5XIcrvA8ueR35kCgr9zdy+hLvpkc+U4cvp9ELqDDhWV64tBiy6VOP4IuQqfZ/VyXBa7SrM2ITPD6b1/sRPp9c/fUvzZ9BQZRkYW+1y0KOyDICgPf0kfnoy+UnAdwsSZPU= 33 | skip_cleanup: true 34 | on: 35 | tags: true 36 | branch: master 37 | -------------------------------------------------------------------------------- /tnt/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017- Sergey Zagoruyko, 4 | Copyright (c) 2017- Sasank Chilamkurthy, 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /tnt/README.md: -------------------------------------------------------------------------------- 1 | TNT 2 | ========== 3 | 4 | **TNT** is a library providing powerful dataloading, logging and visualization utilities for Python. 5 | It is closely integrated with [PyTorch](http://pytorch.org) and is designed to enable rapid iteration with any model 6 | or training regimen. 7 | 8 | ![travis](https://travis-ci.org/pytorch/tnt.svg?branch=master) 9 | [![Documentation Status](https://readthedocs.org/projects/tnt/badge/?version=latest)](http://tnt.readthedocs.io/en/latest/?badge=latest) 10 | 11 | - [About](#about) 12 | - [Installation](#installation) 13 | - [Documentation](http://tnt.readthedocs.io) 14 | - [Getting Started](#getting-started) 15 | 16 | 17 | ## Installation 18 | 19 | TNT can be installed with pip. To do so, run: 20 | 21 | ```buildoutcfg 22 | pip install torchnet 23 | ``` 24 | 25 | If you run into issues, make sure that PyTorch is installed first. 26 | 27 | You can also install the latest version from master. Just run: 28 | 29 | ```buildoutcfg 30 | pip install git+https://github.com/pytorch/tnt.git@master 31 | ``` 32 | 33 | To update to the latest version from master: 34 | 35 | ```buildoutcfg 36 | pip install --upgrade git+https://github.com/pytorch/tnt.git@master 37 | ``` 38 | 39 | ## About 40 | TNT (imported as _torchnet_) is a framework for PyTorch which provides a set of abstractions aiming 41 | to encourage code re-use and modular programming. It provides powerful dataloading, logging, 42 | and visualization utilities. 43 | 44 | The project was inspired by [TorchNet](https://github.com/torchnet/torchnet), and legend says that it stood for “TorchNetTwo”. 45 | Since the deprecation of TorchNet, TNT has developed on its own. 46 | 47 | For example, TNT provides simple methods to record model performance in the `torchnet.meter` module and to log it to Visdom 48 | (or in the future, TensorboardX) with the `torchnet.logging` module. 49 | 50 | In the future, TNT will also provide strong support for multi-task learning and transfer learning applications. It 51 | currently supports joint training data loading through torchnet.utils.MultiTaskDataLoader. 52 | 53 | Most of the modules support NumPy arrays as well as PyTorch tensors on input, and so could potentially be used with 54 | other frameworks. 55 | 56 | 57 | ## Getting Started 58 | See some of the examples in https://github.com/pytorch/examples. We would like to include some walkthroughs in the 59 | [docs](https://tnt.readthedocs.io) (contributions welcome!).
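As a quick, illustrative sketch (not taken from the examples in this repo; it assumes only that `torchnet` is installed), tracking a running loss with a meter looks like this:

```python
import torchnet as tnt

meter = tnt.meter.AverageValueMeter()
for loss in [0.9, 0.7, 0.5]:   # stand-in values for per-batch losses
    meter.add(loss)
mean, std = meter.value()      # running mean and std of everything added so far
meter.reset()                  # clear state, e.g. between epochs
```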
60 | 61 | 62 | ## [LEGACY] Differences with lua version 63 | 64 | What's been ported so far: 65 | 66 | * Datasets: 67 | * BatchDataset 68 | * ListDataset 69 | * ResampleDataset 70 | * ShuffleDataset 71 | * TensorDataset [new] 72 | * TransformDataset 73 | * Meters: 74 | * APMeter 75 | * mAPMeter 76 | * AverageValueMeter 77 | * AUCMeter 78 | * ClassErrorMeter 79 | * ConfusionMeter 80 | * MovingAverageValueMeter 81 | * MSEMeter 82 | * TimeMeter 83 | * Engines: 84 | * Engine 85 | * Logger 86 | * Logger 87 | * VisdomLogger 88 | * MeterLogger [new, easy to plot multi-meter via Visdom] 89 | 90 | Any dataset can now be plugged into `torch.utils.DataLoader`, or called 91 | `.parallel(num_workers=8)` to utilize multiprocessing. 92 | -------------------------------------------------------------------------------- /tnt/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = TNT 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /tnt/docs/_static/css/pytorch_theme.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 3 | } 4 | 5 | /* Default header fonts are ugly */ 6 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { 7 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 8 | } 9 | 10 | /* Use white for docs background */ 11 | .wy-side-nav-search { 12 | background-color: #fff; 13 | } 14 | 15 | .wy-nav-content-wrap, .wy-menu li.current > a { 16 | background-color: #fff; 17 | } 18 | 19 | @media screen and (min-width: 1400px) { 20 | .wy-nav-content-wrap { 21 | background-color: rgba(0, 0, 0, 0.0470588); 22 | } 23 | 24 | .wy-nav-content { 25 | background-color: #fff; 26 | } 27 | } 28 | 29 | /* Fixes for mobile */ 30 | .wy-nav-top { 31 | background-color: #fff; 32 | background-image: url('../img/pytorch-logo-dark.svg'); 33 | background-repeat: no-repeat; 34 | background-position: center; 35 | padding: 0; 36 | margin: 0.4045em 0.809em; 37 | color: #333; 38 | } 39 | 40 | .wy-nav-top > a { 41 | display: none; 42 | } 43 | 44 | @media screen and (max-width: 768px) { 45 | .wy-side-nav-search>a img.logo { 46 | height: 60px; 47 | } 48 | } 49 | 50 | /* This is needed to ensure that logo above search scales properly */ 51 | .wy-side-nav-search a { 52 | display: block; 53 | } 54 | 55 | /* This ensures that multiple constructors will remain in separate lines. 
*/ 56 | .rst-content dl:not(.docutils) dt { 57 | display: table; 58 | } 59 | 60 | /* Use our red for literals (it's very similar to the original color) */ 61 | .rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { 62 | color: #F05732; 63 | } 64 | 65 | .rst-content tt.xref, a .rst-content tt, .rst-content tt.xref, 66 | .rst-content code.xref, a .rst-content tt, a .rst-content code { 67 | color: #404040; 68 | } 69 | 70 | /* Change link colors (except for the menu) */ 71 | 72 | a { 73 | color: #F05732; 74 | 75 | } 76 | 77 | a:hover { 78 | color: #F05732; 79 | } 80 | 81 | 82 | a:visited { 83 | color: #D44D2C; 84 | } 85 | 86 | .wy-side-nav-search a { 87 | color: #F05732; 88 | font-size: 150% 89 | } 90 | 91 | .wy-side-nav-search a:hover { 92 | color: rgb(240, 112, 81); 93 | } 94 | 95 | .wy-menu a { 96 | color: #b3b3b3; 97 | } 98 | 99 | .wy-menu a:hover { 100 | color: #b3b3b3; 101 | } 102 | 103 | /* Default footer text is quite big */ 104 | footer { 105 | font-size: 80%; 106 | } 107 | 108 | footer .rst-footer-buttons { 109 | font-size: 125%; /* revert footer settings - 1/80% = 125% */ 110 | } 111 | 112 | footer p { 113 | font-size: 100%; 114 | } 115 | 116 | /* For hidden headers that appear in TOC tree */ 117 | /* see http://stackoverflow.com/a/32363545/3343043 */ 118 | .rst-content .hidden-section { 119 | display: none; 120 | } 121 | 122 | nav .hidden-section { 123 | display: inherit; 124 | } 125 | 126 | .wy-side-nav-search>div.version { 127 | color: #000; 128 | } 129 | -------------------------------------------------------------------------------- /tnt/docs/_static/img/dynamic_graph.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/side-tuning/dea345691fb7ee0230150fe56ddd644efdffa6ac/tnt/docs/_static/img/dynamic_graph.gif -------------------------------------------------------------------------------- /tnt/docs/_static/img/pytorch-logo-dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 10 | 13 | 14 | 16 | 17 | 18 | 20 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /tnt/docs/_static/img/pytorch-logo-flame.svg: -------------------------------------------------------------------------------- 1 | 2 | image/svg+xml -------------------------------------------------------------------------------- /tnt/docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {% block menu %} 4 |
5 | 9 |
10 | {{ super() }} 11 | {% endblock %} 12 | 13 | {% block footer %} 14 | {{ super() }} 15 | 25 | {% endblock %} 26 | -------------------------------------------------------------------------------- /tnt/docs/index.rst: -------------------------------------------------------------------------------- 1 | .. TNT documentation master file, created by 2 | sphinx-quickstart on Tue May 1 11:04:29 2018. 3 | 4 | 5 | TNT Documentation 6 | ================================= 7 | 8 | TNT is a library providing powerful dataloading, logging and visualization utilities for Python. 9 | It is closely integrated with `PyTorch <http://pytorch.org>`_ and is designed to enable rapid iteration with any model or training regimen. 10 | 11 | 12 | 13 | .. toctree:: 14 | :maxdepth: 1 15 | :caption: Notes 16 | 17 | Examples 18 | 19 | 20 | .. toctree:: 21 | :maxdepth: 1 22 | :caption: Package Reference 23 | 24 | torchnet.dataset 25 | torchnet.engine 26 | torchnet.logger 27 | torchnet.meter 28 | torchnet.utils 29 | 30 | 31 | TNT was inspired by TorchNet, and legend says that it stood for "TorchNetTwo". Since then, TNT has developed 32 | on its own. 33 | 34 | TNT provides simple methods to record model performance in the `torchnet.meter `_ module 35 | and to log it to Visdom (or in the future, TensorboardX) with the `torchnet.logging `_ 36 | module. 37 | 38 | In the future, TNT will also provide strong support for multi-task learning and transfer learning applications. It 39 | currently supports joint training data loading through 40 | `torchnet.utils.MultiTaskDataLoader `_. 41 | 42 | -------------------------------------------------------------------------------- /tnt/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=TNT 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /tnt/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx_rtd_theme 3 | -------------------------------------------------------------------------------- /tnt/docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | torchnet 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | torchnet.utils 8 | -------------------------------------------------------------------------------- /tnt/docs/source/torchnet.dataset.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchnet.dataset 5 | ======================== 6 | 7 | .. automodule:: torchnet.dataset 8 | ..
currentmodule:: torchnet.dataset 9 | 10 | Provides a :class:`Dataset` interface, similar to vanilla PyTorch. 11 | 12 | .. autoclass:: torchnet.dataset.dataset.Dataset 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: 16 | 17 | 18 | :hidden:`BatchDataset` 19 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 20 | .. autoclass:: BatchDataset 21 | :members: 22 | :show-inheritance: 23 | 24 | 25 | :hidden:`ConcatDataset` 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | .. autoclass:: ConcatDataset 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | 32 | 33 | :hidden:`ListDataset` 34 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 35 | .. autoclass:: ListDataset 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | 40 | :hidden:`ResampleDataset` 41 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 42 | .. autoclass:: ResampleDataset 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | :hidden:`ShuffleDataset` 48 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 49 | .. autoclass:: ShuffleDataset 50 | :members: 51 | :undoc-members: 52 | :show-inheritance: 53 | 54 | :hidden:`SplitDataset` 55 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 56 | .. autoclass:: SplitDataset 57 | :members: 58 | :undoc-members: 59 | :show-inheritance: 60 | 61 | :hidden:`TensorDataset` 62 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 63 | .. autoclass:: TensorDataset 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | 68 | :hidden:`TransformDataset` 69 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 70 | .. autoclass:: TransformDataset 71 | :members: 72 | :undoc-members: 73 | :show-inheritance: 74 | 75 | -------------------------------------------------------------------------------- /tnt/docs/source/torchnet.engine.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchnet.engine 5 | ====================== 6 | 7 | .. automodule:: torchnet.engine 8 | .. currentmodule:: torchnet.engine 9 | 10 | Engines are a utility to wrap a training loop. They provide several hooks which 11 | allow users to define their own functions to run at specified points during the 12 | train/val loop. 13 | 14 | Some people like engines, others do not. TNT is built modularly, so you can use 15 | the other modules with/without using an engine. 16 | 17 | torchnet.engine.Engine 18 | ----------------------------- 19 | 20 | .. autoclass:: Engine 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | -------------------------------------------------------------------------------- /tnt/docs/source/torchnet.logger.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchnet.logger 5 | ======================== 6 | 7 | .. automodule:: torchnet.logger 8 | .. currentmodule:: torchnet.logger 9 | 10 | Loggers provide a way to monitor your models. For example, the :class:`MeterLogger` class 11 | provides easy meter visualization with `Visdom `_, 12 | as well as the ability to print and save meters with the :class:`ResultsWriter` class. 13 | 14 | For visualization libraries, the current loggers support ``Visdom``, although ``TensorboardX`` 15 | would also be simple to implement. 16 | 17 | 18 | MeterLogger 19 | ~~~~~~~~~~~~~~~~~ 20 | 21 | .. autoclass:: MeterLogger 22 | :members: 23 | :undoc-members: 24 | :show-inheritance: 25 | 26 | VisdomLogger 27 | ~~~~~~~~~~~~~~~ 28 | ..
automodule:: torchnet.logger.visdomlogger 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | -------------------------------------------------------------------------------- /tnt/docs/source/torchnet.meter.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchnet.meter 5 | ====================== 6 | 7 | .. automodule:: torchnet.meter 8 | .. currentmodule:: torchnet.meter 9 | 10 | Meters provide a way to keep track of important statistics in an online manner. 11 | TNT also provides convenient ways to visualize and manage meters via the :class:`torchnet.logger.MeterLogger` class. 12 | 13 | .. autoclass:: torchnet.meter.meter.Meter 14 | :members: 15 | 16 | Classification Meters 17 | ------------------------------ 18 | 19 | 20 | :hidden:`APMeter` 21 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 22 | .. autoclass:: APMeter 23 | :members: 24 | 25 | :hidden:`mAPMeter` 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | .. autoclass:: mAPMeter 28 | :members: 29 | 30 | :hidden:`ClassErrorMeter` 31 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 32 | .. autoclass:: ClassErrorMeter 33 | :members: 34 | 35 | :hidden:`ConfusionMeter` 36 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 37 | .. autoclass:: ConfusionMeter 38 | :members: 39 | 40 | 41 | Regression/Loss Meters 42 | ------------------------------ 43 | 44 | :hidden:`AverageValueMeter` 45 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 46 | .. autoclass:: AverageValueMeter 47 | :members: 48 | 49 | :hidden:`AUCMeter` 50 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 51 | .. autoclass:: AUCMeter 52 | :members: 53 | 54 | :hidden:`MovingAverageValueMeter` 55 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 56 | .. autoclass:: MovingAverageValueMeter 57 | :members: 58 | 59 | :hidden:`MSEMeter` 60 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 61 | .. autoclass:: MSEMeter 62 | :members: 63 | 64 | 65 | 66 | 67 | 68 | Miscellaneous Meters 69 | ------------------------------ 70 | 71 | :hidden:`TimeMeter` 72 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 73 | .. autoclass:: TimeMeter 74 | :members: 75 | -------------------------------------------------------------------------------- /tnt/docs/source/torchnet.rst: -------------------------------------------------------------------------------- 1 | torchnet package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | torchnet.dataset 10 | torchnet.engine 11 | torchnet.logger 12 | torchnet.meter 13 | torchnet.utils 14 | 15 | Submodules 16 | ---------- 17 | 18 | torchnet.transform module 19 | ------------------------- 20 | 21 | .. automodule:: torchnet.transform 22 | :members: 23 | :undoc-members: 24 | :show-inheritance: 25 | 26 | 27 | Module contents 28 | --------------- 29 | 30 | .. automodule:: torchnet 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: 34 | -------------------------------------------------------------------------------- /tnt/docs/source/torchnet.utils.rst: -------------------------------------------------------------------------------- 1 | torchnet.utils 2 | ====================== 3 | 4 | 5 | MultiTaskDataLoader 6 | ----------------------------------------- 7 | 8 | .. autoclass:: torchnet.utils.MultiTaskDataLoader 9 | :members: 10 | :exclude-members: zip_batches 11 | :undoc-members: 12 | :show-inheritance: 13 | 14 | ResultsWriter 15 | ----------------------------------- 16 | 17 | .. 
autoclass:: torchnet.utils.ResultsWriter 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | Table module 23 | --------------------------- 24 | 25 | .. automodule:: torchnet.utils.table 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | -------------------------------------------------------------------------------- /tnt/example/README.md: -------------------------------------------------------------------------------- 1 | # mnist_with_meterlogger 2 | 3 | 4 | 5 | ## Start Visdom on a server 6 | 7 | ```bash 8 | python -m visdom.server 9 | # python -m visdom.server -port 9999 # e.g. to use port 9999 10 | ``` 11 | 12 | 13 | ## Run Example 14 | 15 | ```bash 16 | python mnist_with_meterlogger.py 17 | # CUDA_VISIBLE_DEVICES=1 python mnist_with_meterlogger.py # e.g. to run on GPU 1 18 | ``` 19 | 20 | ## Multi-meter 21 | 22 | Plotting multiple meters takes just one line of code: 23 | 24 | ### Plotting Accuracy, mAP 25 | 26 | ```python 27 | mlog.update_meter(output, target, meters={'accuracy', 'map'}) 28 | ``` 29 | 30 | ### Plotting Loss Curve 31 | 32 | ```python 33 | # NLL Loss 34 | nll_loss = F.nll_loss(output, target) 35 | mlog.update_loss(nll_loss, meter='nll_loss') 36 | 37 | # Cross Entropy Loss 38 | ce_loss = F.cross_entropy(output, target) 39 | mlog.update_loss(ce_loss, meter='ce_loss') 40 | ``` 41 | 42 | ## Remote Plotting 43 | 44 | ```python 45 | mlog = MeterLogger(server="Server's IP", nclass=10, title="mnist") 46 | ``` 47 | 48 | 49 | ## Figure 50 | 51 | ![visdom.png](meterlogger.png) 52 | -------------------------------------------------------------------------------- /tnt/example/mnist.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | import torch.optim 4 | import torchnet as tnt 5 | from torchvision.datasets.mnist import MNIST 6 | from torchnet.engine import Engine 7 | from torch.autograd import Variable 8 | import torch.nn.functional as F 9 | from torch.nn.init import kaiming_normal 10 | 11 | 12 | def get_iterator(mode): 13 | ds = MNIST(root='./', download=True, train=mode) 14 | data = getattr(ds, 'train_data' if mode else 'test_data') 15 | labels = getattr(ds, 'train_labels' if mode else 'test_labels') 16 | tds = tnt.dataset.TensorDataset([data, labels]) 17 | return tds.parallel(batch_size=128, num_workers=4, shuffle=mode) 18 | 19 | 20 | def conv_init(ni, no, k): 21 | return kaiming_normal(torch.Tensor(no, ni, k, k)) 22 | 23 | 24 | def linear_init(ni, no): 25 | return kaiming_normal(torch.Tensor(no, ni)) 26 | 27 | 28 | def f(params, inputs, mode): 29 | o = inputs.view(inputs.size(0), 1, 28, 28) 30 | o = F.conv2d(o, params['conv0.weight'], params['conv0.bias'], stride=2) 31 | o = F.relu(o) 32 | o = F.conv2d(o, params['conv1.weight'], params['conv1.bias'], stride=2) 33 | o = F.relu(o) 34 | o = o.view(o.size(0), -1) 35 | o = F.linear(o, params['linear2.weight'], params['linear2.bias']) 36 | o = F.relu(o) 37 | o = F.linear(o, params['linear3.weight'], params['linear3.bias']) 38 | return o 39 | 40 | 41 | def main(): 42 | params = { 43 | 'conv0.weight': conv_init(1, 50, 5), 'conv0.bias': torch.zeros(50), 44 | 'conv1.weight': conv_init(50, 50, 5), 'conv1.bias': torch.zeros(50), 45 | 'linear2.weight': linear_init(800, 512), 'linear2.bias': torch.zeros(512), 46 | 'linear3.weight': linear_init(512, 10), 'linear3.bias': torch.zeros(10), 47 | } 48 | params = {k: Variable(v, requires_grad=True) for k, v in params.items()} 49 | 50 | optimizer = torch.optim.SGD(
params.values(), lr=0.01, momentum=0.9, weight_decay=0.0005) 52 | 53 | engine = Engine() 54 | meter_loss = tnt.meter.AverageValueMeter() 55 | classerr = tnt.meter.ClassErrorMeter(accuracy=True) 56 | 57 | def h(sample): 58 | inputs = Variable(sample[0].float() / 255.0) 59 | targets = Variable(torch.LongTensor(sample[1])) 60 | o = f(params, inputs, sample[2]) 61 | return F.cross_entropy(o, targets), o 62 | 63 | def reset_meters(): 64 | classerr.reset() 65 | meter_loss.reset() 66 | 67 | def on_sample(state): 68 | state['sample'].append(state['train']) 69 | 70 | def on_forward(state): 71 | classerr.add(state['output'].data, 72 | torch.LongTensor(state['sample'][1])) 73 | meter_loss.add(state['loss'].data[0]) 74 | 75 | def on_start_epoch(state): 76 | reset_meters() 77 | state['iterator'] = tqdm(state['iterator']) 78 | 79 | def on_end_epoch(state): 80 | print('Training loss: %.4f, accuracy: %.2f%%' % (meter_loss.value()[0], classerr.value()[0])) 81 | # do validation at the end of each epoch 82 | reset_meters() 83 | engine.test(h, get_iterator(False)) 84 | print('Testing loss: %.4f, accuracy: %.2f%%' % (meter_loss.value()[0], classerr.value()[0])) 85 | 86 | engine.hooks['on_sample'] = on_sample 87 | engine.hooks['on_forward'] = on_forward 88 | engine.hooks['on_start_epoch'] = on_start_epoch 89 | engine.hooks['on_end_epoch'] = on_end_epoch 90 | engine.train(h, get_iterator(True), maxepoch=10, optimizer=optimizer) 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /tnt/example/mnist_with_meterlogger.py: -------------------------------------------------------------------------------- 1 | """ Run MNIST example and log to visdom 2 | Notes: 3 | - Visdom must be installed (pip works) 4 | - the Visdom server must be running at start! 
5 | 6 | Example: 7 | $ python -m visdom.server -port 8097 & 8 | $ python mnist_with_meterlogger.py 9 | """ 10 | from tqdm import tqdm 11 | import torch 12 | import torch.optim 13 | import torchnet as tnt 14 | from torch.autograd import Variable 15 | import torch.nn.functional as F 16 | from torch.nn.init import kaiming_normal 17 | from torchnet.engine import Engine 18 | from torchnet.logger import MeterLogger 19 | from torchvision.datasets.mnist import MNIST 20 | 21 | 22 | def get_iterator(mode): 23 | ds = MNIST(root='./', download=True, train=mode) 24 | data = getattr(ds, 'train_data' if mode else 'test_data') 25 | labels = getattr(ds, 'train_labels' if mode else 'test_labels') 26 | tds = tnt.dataset.TensorDataset([data, labels]) 27 | return tds.parallel(batch_size=128, num_workers=4, shuffle=mode) 28 | 29 | 30 | def conv_init(ni, no, k): 31 | return kaiming_normal(torch.Tensor(no, ni, k, k)) 32 | 33 | 34 | def linear_init(ni, no): 35 | return kaiming_normal(torch.Tensor(no, ni)) 36 | 37 | 38 | def f(params, inputs, mode): 39 | o = inputs.view(inputs.size(0), 1, 28, 28) 40 | o = F.conv2d(o, params['conv0.weight'], params['conv0.bias'], stride=2) 41 | o = F.relu(o) 42 | o = F.conv2d(o, params['conv1.weight'], params['conv1.bias'], stride=2) 43 | o = F.relu(o) 44 | o = o.view(o.size(0), -1) 45 | o = F.linear(o, params['linear2.weight'], params['linear2.bias']) 46 | o = F.relu(o) 47 | o = F.linear(o, params['linear3.weight'], params['linear3.bias']) 48 | return o 49 | 50 | 51 | def main(): 52 | params = { 53 | 'conv0.weight': conv_init(1, 50, 5), 'conv0.bias': torch.zeros(50), 54 | 'conv1.weight': conv_init(50, 50, 5), 'conv1.bias': torch.zeros(50), 55 | 'linear2.weight': linear_init(800, 512), 'linear2.bias': torch.zeros(512), 56 | 'linear3.weight': linear_init(512, 10), 'linear3.bias': torch.zeros(10), 57 | } 58 | params = {k: Variable(v, requires_grad=True) for k, v in params.items()} 59 | 60 | optimizer = torch.optim.SGD( 61 | params.values(), lr=0.01, momentum=0.9, weight_decay=0.0005) 62 | 63 | engine = Engine() 64 | 65 | mlog = MeterLogger(server='10.10.30.91', port=9917, nclass=10, title="mnist_meterlogger") 66 | 67 | def h(sample): 68 | inputs = Variable(sample[0].float() / 255.0) 69 | targets = Variable(torch.LongTensor(sample[1])) 70 | o = f(params, inputs, sample[2]) 71 | return F.cross_entropy(o, targets), o 72 | 73 | def on_sample(state): 74 | state['sample'].append(state['train']) 75 | 76 | def on_forward(state): 77 | loss = state['loss'] 78 | output = state['output'] 79 | target = state['sample'][1] 80 | # online plotter 81 | mlog.update_loss(loss, meter='loss') 82 | mlog.update_meter(output, target, meters={'accuracy', 'map', 'confusion'}) 83 | 84 | def on_start_epoch(state): 85 | mlog.timer.reset() 86 | state['iterator'] = tqdm(state['iterator']) 87 | 88 | def on_end_epoch(state): 89 | mlog.print_meter(mode="Train", iepoch=state['epoch']) 90 | mlog.reset_meter(mode="Train", iepoch=state['epoch']) 91 | 92 | # do validation at the end of each epoch 93 | engine.test(h, get_iterator(False)) 94 | mlog.print_meter(mode="Test", iepoch=state['epoch']) 95 | mlog.reset_meter(mode="Test", iepoch=state['epoch']) 96 | 97 | engine.hooks['on_sample'] = on_sample 98 | engine.hooks['on_forward'] = on_forward 99 | engine.hooks['on_start_epoch'] = on_start_epoch 100 | engine.hooks['on_end_epoch'] = on_end_epoch 101 | engine.train(h, get_iterator(True), maxepoch=10, optimizer=optimizer) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 |
-------------------------------------------------------------------------------- /tnt/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | visdom 3 | -------------------------------------------------------------------------------- /tnt/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import shutil 4 | import sys 5 | from setuptools import setup, find_packages 6 | 7 | VERSION = '0.0.4' 8 | 9 | long_description = "Simple tools for logging and visualizing, loading and training" 10 | 11 | setup_info = dict( 12 | # Metadata 13 | name='torchnet', 14 | version=VERSION, 15 | author='PyTorch', 16 | author_email='sergey.zagoruyko@enpc.fr', 17 | url='https://github.com/pytorch/tnt/', 18 | description='an abstraction to train neural networks', 19 | long_description=long_description, 20 | license='BSD', 21 | 22 | # Package info 23 | packages=find_packages(exclude=('test', 'docs')), 24 | 25 | zip_safe=True, 26 | 27 | install_requires=[ 28 | 'torch', 29 | 'six', 30 | 'visdom' 31 | ] 32 | ) 33 | 34 | setup(**setup_info) 35 | -------------------------------------------------------------------------------- /tnt/test/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | PYCMD=${PYCMD:="python"} 5 | if [ "$1" == "coverage" ]; 6 | then 7 | coverage erase 8 | PYCMD="coverage run --parallel-mode --source torch " 9 | echo "coverage flag found. Setting python command to: \"$PYCMD\"" 10 | fi 11 | 12 | pushd "$(dirname "$0")" 13 | 14 | $PYCMD test_datasets.py 15 | $PYCMD test_meters.py 16 | $PYCMD test_transforms.py 17 | -------------------------------------------------------------------------------- /tnt/test/test_transforms.py: -------------------------------------------------------------------------------- 1 | import torchnet.transform as transform 2 | import unittest 3 | import torch 4 | 5 | 6 | class TestTransforms(unittest.TestCase): 7 | def testCompose(self): 8 | self.assertEqual(transform.compose([lambda x: x + 1, lambda x: x + 2, lambda x: x / 2])(1), 2) 9 | 10 | def testTableMergeKeys(self): 11 | x = { 12 | 'sample1': {'input': 1, 'target': "a"}, 13 | 'sample2': {'input': 2, 'target': "b", 'flag': "hard"} 14 | } 15 | 16 | y = transform.tablemergekeys()(x) 17 | 18 | self.assertEqual(y['input'], {'sample1': 1, 'sample2': 2}) 19 | self.assertEqual(y['target'], {'sample1': "a", 'sample2': "b"}) 20 | self.assertEqual(y['flag'], {'sample2': "hard"}) 21 | 22 | def testTableApply(self): 23 | x = {1: 1, 2: 2} 24 | y = transform.tableapply(lambda x: x + 1)(x) 25 | self.assertEqual(y, {1: 2, 2: 3}) 26 | 27 | def testMakeBatch(self): 28 | x = [ 29 | {'input': torch.randn(4), 'target': "a"}, 30 | {'input': torch.randn(4), 'target': "b"}, 31 | ] 32 | y = transform.makebatch()(x) 33 | self.assertEqual(y['input'].size(), torch.Size([2, 4])) 34 | self.assertEqual(y['target'], ["a", "b"]) 35 | 36 | 37 | if __name__ == '__main__': 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /tnt/torchnet/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import dataset, meter, engine, transform, logger 2 | __all__ = ['dataset', 'meter', 'engine', 'transform', 'logger'] 3 | -------------------------------------------------------------------------------- /tnt/torchnet/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .batchdataset import BatchDataset 2 | from .listdataset import ListDataset 3 | from .tensordataset import TensorDataset 4 | from .transformdataset import TransformDataset 5 | from .resampledataset import ResampleDataset 6 | from .shuffledataset import ShuffleDataset 7 | from .concatdataset import ConcatDataset 8 | from .splitdataset import SplitDataset 9 | -------------------------------------------------------------------------------- /tnt/torchnet/dataset/concatdataset.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | import numpy as np 3 | 4 | 5 | class ConcatDataset(Dataset): 6 | """ 7 | Dataset to concatenate multiple datasets. 8 | 9 | Purpose: useful to assemble different existing datasets, possibly 10 | large-scale datasets as the concatenation operation is done in an 11 | on-the-fly manner. 12 | 13 | Args: 14 | datasets (iterable): List of datasets to be concatenated 15 | """ 16 | 17 | def __init__(self, datasets): 18 | super(ConcatDataset, self).__init__() 19 | 20 | self.datasets = list(datasets) 21 | assert len(self.datasets) > 0, 'datasets should not be an empty iterable' 22 | self.cum_sizes = np.cumsum([len(x) for x in self.datasets]) 23 | 24 | def __len__(self): 25 | return self.cum_sizes[-1] 26 | 27 | def __getitem__(self, idx): 28 | super(ConcatDataset, self).__getitem__(idx) 29 | dataset_index = self.cum_sizes.searchsorted(idx, 'right') 30 | 31 | if dataset_index == 0: 32 | dataset_idx = idx 33 | else: 34 | dataset_idx = idx - self.cum_sizes[dataset_index - 1] 35 | 36 | return self.datasets[dataset_index][dataset_idx] 37 | -------------------------------------------------------------------------------- /tnt/torchnet/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import tnt.torchnet as torchnet  # alias the vendored package so the torchnet.* references below resolve 2 | from torch.utils.data import DataLoader 3 | 4 | 5 | class Dataset(object): 6 | def __init__(self): 7 | pass 8 | 9 | def __len__(self): 10 | pass 11 | 12 | def __getitem__(self, idx): 13 | if idx >= len(self): 14 | raise IndexError("CustomRange index out of range") 15 | pass 16 | 17 | def batch(self, *args, **kwargs): 18 | return torchnet.dataset.BatchDataset(self, *args, **kwargs) 19 | 20 | def transform(self, *args, **kwargs): 21 | return torchnet.dataset.TransformDataset(self, *args, **kwargs) 22 | 23 | def shuffle(self, *args, **kwargs): 24 | return torchnet.dataset.ShuffleDataset(self, *args, **kwargs) 25 | 26 | def parallel(self, *args, **kwargs): 27 | return DataLoader(self, *args, **kwargs) 28 | 29 | def split(self, *args, **kwargs): 30 | return torchnet.dataset.SplitDataset(self, *args, **kwargs) 31 | -------------------------------------------------------------------------------- /tnt/torchnet/dataset/listdataset.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | 3 | 4 | class ListDataset(Dataset): 5 | """ 6 | Dataset which loads data from a list using a given function. 7 | 8 | Given an `elem_list` (which can be an iterable or a `string`), the i-th sample 9 | of the dataset will be returned by `load(elem_list[i])`, where `load()` 10 | is a function provided by the user.
11 | 12 | If `path` is provided, `elem_list` is assumed to be a list of strings, and 13 | each element `elem_list[i]` will be prefixed by `path/` when fed to `load()`. 14 | 15 | Purpose: many low or medium-scale datasets can be seen as a list of files 16 | (for example representing input samples). For such a list of files, a target 17 | can often be inferred in a simple manner. 18 | 19 | Args: 20 | elem_list (iterable/str): List of arguments which will be passed to the 21 | `load` function. It can also be a path to a file with each line 22 | containing the arguments to `load` 23 | load (function, optional): Function which loads the data. 24 | i-th sample is returned by `load(elem_list[i])`. By default `load` 25 | is the identity, i.e., `lambda x: x` 26 | path (str, optional): Defaults to None. If a string is provided, 27 | `elem_list` is assumed to be a list of strings, and each element 28 | `elem_list[i]` will be prefixed by this string when fed to `load()`. 29 | 30 | """ 31 | 32 | def __init__(self, elem_list, load=lambda x: x, path=None): 33 | super(ListDataset, self).__init__() 34 | 35 | if isinstance(elem_list, str): 36 | with open(elem_list) as f: 37 | self.list = [line.replace('\n', '') for line in f] 38 | else: 39 | # just assume iterable 40 | self.list = elem_list 41 | 42 | self.path = path 43 | self.load = load 44 | 45 | def __len__(self): 46 | return len(self.list) 47 | 48 | def __getitem__(self, idx): 49 | super(ListDataset, self).__getitem__(idx) 50 | 51 | if self.path is not None: 52 | return self.load("%s/%s" % (self.path, self.list[idx])) 53 | else: 54 | return self.load(self.list[idx]) 55 | -------------------------------------------------------------------------------- /tnt/torchnet/dataset/resampledataset.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | 3 | 4 | class ResampleDataset(Dataset): 5 | """ 6 | Dataset which resamples a given dataset. 7 | 8 | Given a `dataset`, creates a new dataset which will (re-)sample from this 9 | underlying dataset using the provided `sampler(dataset, idx)` function. 10 | 11 | If `size` is provided, then the newly created dataset will have the 12 | specified `size`, which might be different from the underlying dataset 13 | size. If `size` is not provided, then the new dataset will have the same 14 | size as the underlying one. 15 | 16 | Purpose: shuffling data, re-weighting samples, getting a subset of the 17 | data. Note that an important sub-class `ShuffleDataset` is provided for 18 | convenience. 19 | 20 | Args: 21 | dataset (Dataset): Dataset to be resampled. 22 | sampler (function, optional): Function used for sampling. `idx`th 23 | sample is returned by `dataset[sampler(dataset, idx)]`. By default 24 | `sampler(dataset, idx)` is the identity, simply returning `idx`. 25 | `sampler(dataset, idx)` must return an index in the range 26 | acceptable for the underlying `dataset`. 27 | size (int, optional): Desired size of the dataset after resampling. By 28 | default, the new dataset will have the same size as the underlying 29 | one.
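Example (a minimal sketch; the import path assumes the vendored `tnt.torchnet` package, and the toy data is illustrative only): >>> from tnt.torchnet.dataset import ListDataset, ResampleDataset >>> ds = ListDataset([1, 2, 3], load=lambda x: x * 10)  # toy dataset: 10, 20, 30 >>> rev = ResampleDataset(ds, sampler=lambda d, i: len(d) - 1 - i)  # reverse order >>> [rev[i] for i in range(len(rev))] [30, 20, 10]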
30 | 31 | """ 32 | 33 | def __init__(self, dataset, sampler=lambda ds, idx: idx, size=None): 34 | super(ResampleDataset, self).__init__() 35 | self.dataset = dataset 36 | self.sampler = sampler 37 | self.size = size 38 | 39 | def __len__(self): 40 | return self.size if (self.size and self.size > 0) else len(self.dataset) 41 | 42 | def __getitem__(self, idx): 43 | super(ResampleDataset, self).__getitem__(idx) 44 | idx = self.sampler(self.dataset, idx) 45 | 46 | if idx < 0 or idx >= len(self.dataset): 47 | raise IndexError('out of range') 48 | 49 | return self.dataset[idx] -------------------------------------------------------------------------------- /tnt/torchnet/dataset/shuffledataset.py: -------------------------------------------------------------------------------- 1 | from .resampledataset import ResampleDataset 2 | import torch 3 | 4 | 5 | class ShuffleDataset(ResampleDataset): 6 | """ 7 | Dataset which shuffles a given dataset. 8 | 9 | `ShuffleDataset` is a sub-class of `ResampleDataset` provided for 10 | convenience. It samples uniformly from the given `dataset`, with or without 11 | `replacement`. The chosen permutation can be redrawn by calling `resample()`. 12 | 13 | If `replacement` is `true`, then the specified `size` may be larger than 14 | the underlying `dataset`. 15 | If `size` is not provided, then the new dataset size will be equal to the 16 | underlying `dataset` size. 17 | 18 | Purpose: the easiest way to shuffle a dataset! 19 | 20 | Args: 21 | dataset (Dataset): Dataset to be shuffled. 22 | size (int, optional): Desired size of the shuffled dataset. If 23 | `replacement` is `true`, then it can be larger than `len(dataset)`. 24 | By default, the new dataset will have the same size as `dataset`. 25 | replacement (bool, optional): True if uniform sampling is to be done 26 | with replacement. False otherwise. Defaults to false. 27 | 28 | Raises: 29 | ValueError: If `size` is larger than the size of the underlying dataset 30 | and `replacement` is False. 31 | """ 32 | 33 | def __init__(self, dataset, size=None, replacement=False): 34 | if size and not replacement and size > len(dataset): 35 | raise ValueError('size cannot be larger than underlying dataset \ 36 | size when sampling without replacement') 37 | 38 | super(ShuffleDataset, self).__init__(dataset, 39 | lambda dataset, idx: self.perm[idx], 40 | size) 41 | self.replacement = replacement 42 | self.resample() 43 | 44 | def resample(self, seed=None): 45 | """Resample the dataset. 46 | 47 | Args: 48 | seed (int, optional): Seed for resampling. By default no seed is 49 | used. 50 | """ 51 | if seed is not None: 52 | gen = torch.manual_seed(seed) 53 | else: 54 | gen = torch.default_generator 55 | 56 | if self.replacement: 57 | self.perm = torch.LongTensor(len(self)).random_( 58 | len(self.dataset), generator=gen) 59 | else: 60 | self.perm = torch.randperm( 61 | len(self.dataset), generator=gen).narrow(0, 0, len(self)) -------------------------------------------------------------------------------- /tnt/torchnet/dataset/splitdataset.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | import numpy as np 3 | 4 | 5 | class SplitDataset(Dataset): 6 | """ 7 | Dataset to partition a given dataset. 8 | 9 | Partition a given `dataset`, according to the specified `partitions`. Use 10 | the method `select()` to select the current partition in use.
11 | 12 | `partitions` is a dictionary where each key is a user-chosen string 13 | naming the partition, and the value is a number representing the weight (as a 14 | number between 0 and 1) or the size (in number of samples) of the 15 | corresponding partition. 16 | 17 | Partitioning is achieved linearly (no shuffling). See `ShuffleDataset` if you 18 | want to shuffle the dataset before partitioning. 19 | 20 | Args: 21 | dataset (Dataset): Dataset to be split. 22 | partitions (dict): Dictionary where each key is a user-chosen string 23 | naming the partition, and the value is a number representing the weight 24 | (as a number between 0 and 1) or the size (in number of samples) 25 | of the corresponding partition. 26 | initial_partition (str, optional): Initial partition to be selected. 27 | 28 | """ 29 | 30 | def __init__(self, dataset, partitions, initial_partition=None): 31 | super(SplitDataset, self).__init__() 32 | 33 | self.dataset = dataset 34 | self.partitions = partitions 35 | 36 | # A few assertions 37 | assert isinstance(partitions, dict), 'partitions must be a dict' 38 | assert len(partitions) >= 2, \ 39 | 'SplitDataset should have at least two partitions' 40 | assert min(partitions.values()) >= 0, \ 41 | 'partition sizes cannot be negative' 42 | assert max(partitions.values()) > 0, 'all partitions cannot be empty' 43 | 44 | self.partition_names = sorted(list(self.partitions.keys())) 45 | self.partition_index = {partition: i for i, partition in 46 | enumerate(self.partition_names)} 47 | 48 | self.partition_sizes = [self.partitions[partition] for partition in 49 | self.partition_names] 50 | # if partition sizes are fractions, convert to sizes: 51 | if sum(self.partition_sizes) <= 1: 52 | self.partition_sizes = [round(x * len(dataset)) for x in 53 | self.partition_sizes] 54 | else: 55 | for x in self.partition_sizes: 56 | assert x == int(x), ('partition sizes should be integer' 57 | ' numbers, or sum up to <= 1 ') 58 | 59 | self.partition_cum_sizes = np.cumsum(self.partition_sizes) 60 | 61 | if initial_partition is not None: 62 | self.select(initial_partition) 63 | 64 | def select(self, partition): 65 | """ 66 | Select the partition. 67 | 68 | Args: 69 | partition (str): Partition to be selected. 70 | """ 71 | self.current_partition_idx = self.partition_index[partition] 72 | 73 | def __len__(self): 74 | try: 75 | return self.partition_sizes[self.current_partition_idx] 76 | except AttributeError: 77 | raise ValueError("Select a partition before accessing data.") 78 | 79 | def __getitem__(self, idx): 80 | super(SplitDataset, self).__getitem__(idx) 81 | try: 82 | if self.current_partition_idx == 0: 83 | return self.dataset[idx] 84 | else: 85 | offset = self.partition_cum_sizes[self.current_partition_idx - 1] 86 | return self.dataset[int(offset) + idx] 87 | except AttributeError: 88 | raise ValueError("Select a partition before accessing data.") 89 |
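A quick usage sketch for `SplitDataset` with fractional weights (illustrative only; the import path assumes the vendored `tnt.torchnet` package): >>> from tnt.torchnet.dataset import ListDataset, SplitDataset >>> ds = SplitDataset(ListDataset(list(range(10))), {'train': 0.8, 'val': 0.2}) >>> ds.select('train'); len(ds) 8 >>> ds.select('val'); ds[0]  # 'val' starts right after the 8 'train' samples 8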
-------------------------------------------------------------------------------- /tnt/torchnet/dataset/tensordataset.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class TensorDataset(Dataset): 7 | """ 8 | Dataset from a tensor, array, list, or dict. 9 | 10 | `TensorDataset` provides a way to create a dataset out of data that is 11 | already loaded into memory. It accepts data in the following forms: 12 | 13 | tensor or numpy array 14 | `idx`th sample is `data[idx]` 15 | 16 | dict of tensors or numpy arrays 17 | `idx`th sample is `{k: v[idx] for k, v in data.items()}` 18 | 19 | list of tensors or numpy arrays 20 | `idx`th sample is `[v[idx] for v in data]` 21 | 22 | Purpose: Easy way to create a dataset out of standard data structures. 23 | 24 | Args: 25 | data (dict/list/tensor/ndarray): Data for the dataset. 26 | """ 27 | 28 | def __init__(self, data): 29 | super(TensorDataset, self).__init__() 30 | 31 | if isinstance(data, dict): 32 | assert len(data) > 0, "Should have at least one element" 33 | # check that all fields have the same size 34 | n_elem = len(list(data.values())[0]) 35 | for v in data.values(): 36 | assert len(v) == n_elem, "All values must have the same size" 37 | elif isinstance(data, list): 38 | assert len(data) > 0, "Should have at least one element" 39 | n_elem = len(data[0]) 40 | for v in data: 41 | assert len(v) == n_elem, "All elements must have the same size" 42 | 43 | self.data = data 44 | 45 | def __len__(self): 46 | if isinstance(self.data, dict): 47 | return len(list(self.data.values())[0]) 48 | elif isinstance(self.data, list): 49 | return len(self.data[0]) 50 | elif torch.is_tensor(self.data) or isinstance(self.data, np.ndarray): 51 | return len(self.data) 52 | 53 | def __getitem__(self, idx): 54 | super(TensorDataset, self).__getitem__(idx) 55 | if isinstance(self.data, dict): 56 | return {k: v[idx] for k, v in self.data.items()} 57 | elif isinstance(self.data, list): 58 | return [v[idx] for v in self.data] 59 | elif torch.is_tensor(self.data) or isinstance(self.data, np.ndarray): 60 | return self.data[idx] 61 | -------------------------------------------------------------------------------- /tnt/torchnet/dataset/transformdataset.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | 3 | 4 | class TransformDataset(Dataset): 5 | """ 6 | Dataset which transforms a given dataset with a given function. 7 | 8 | Given a function `transform`, and a `dataset`, `TransformDataset` applies 9 | the function in an on-the-fly manner when querying a sample with 10 | `__getitem__(idx)`, therefore returning `transform(dataset[idx])`. 11 | 12 | `transform` can also be a dict with functions as values. In this case, it 13 | is assumed that `dataset[idx]` is a dict which has all the keys in 14 | `transform`. Then, `transform[key]` is applied to `dataset[idx][key]` for 15 | each key in `transform`. 16 | 17 | The size of the new dataset is equal to the size of the underlying 18 | `dataset`. 19 | 20 | Purpose: when performing pre-processing operations, it is convenient to be 21 | able to perform on-the-fly transformations to a dataset. 22 | 23 | Args: 24 | dataset (Dataset): Dataset which has to be transformed. 25 | transforms (function/dict): Function or dict with functions as values. 26 | These functions will be applied to data.
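Example (a small sketch; the dict form assumes each sample is itself a dict, and the import path assumes the vendored `tnt.torchnet` package): >>> from tnt.torchnet.dataset import ListDataset, TransformDataset >>> ds = ListDataset([{'input': 1, 'target': 2}])  # one toy dict sample >>> tds = TransformDataset(ds, {'input': lambda x: x * 10})  # transform only 'input' >>> tds[0] {'input': 10, 'target': 2}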
27 | """ 28 | 29 | def __init__(self, dataset, transforms): 30 | super(TransformDataset, self).__init__() 31 | 32 | assert isinstance(transforms, dict) or callable(transforms), \ 33 | 'expected a dict of transforms or a function' 34 | if isinstance(transforms, dict): 35 | for k, v in transforms.items(): 36 | assert callable(v), str(k) + ' is not a function' 37 | 38 | self.dataset = dataset 39 | self.transforms = transforms 40 | 41 | def __len__(self): 42 | return len(self.dataset) 43 | 44 | def __getitem__(self, idx): 45 | super(TransformDataset, self).__getitem__(idx) 46 | z = self.dataset[idx] 47 | 48 | if isinstance(self.transforms, dict): 49 | for k, transform in self.transforms.items(): 50 | z[k] = transform(z[k]) 51 | else: 52 | z = self.transforms(z) 53 | 54 | return z -------------------------------------------------------------------------------- /tnt/torchnet/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .engine import Engine -------------------------------------------------------------------------------- /tnt/torchnet/engine/engine.py: -------------------------------------------------------------------------------- 1 | class Engine(object): 2 | def __init__(self): 3 | self.hooks = {} 4 | 5 | def hook(self, name, state): 6 | r"""Invokes the hook registered under ``name``, if there is one. 7 | 8 | ``self.hooks`` maps hook names to callables with the signature:: 9 | 10 | hook(state) -> None 11 | 12 | where ``state`` is the mutable state dict maintained by :meth:`train` 13 | or :meth:`test`. The names fired during a run are ``'on_start'``, 14 | ``'on_start_epoch'``, ``'on_sample'``, ``'on_forward'``, 15 | ``'on_update'``, ``'on_end_epoch'``, and ``'on_end'``; unregistered 16 | names are silently skipped.
17 | 18 | Example: 19 | >>> engine = Engine() 20 | >>> engine.hooks['on_forward'] = lambda state: print(state['loss'].item()) 21 | 22 | Hooks receive the full ``state`` dict, so a hook can read or mutate 23 | anything the loop produces: the current ``sample``, the network 24 | ``output`` and ``loss``, the step counter ``t``, and, during 25 | training, the ``epoch`` counter and the ``optimizer``. 26 | 27 | """ 28 | if name in self.hooks: 29 | self.hooks[name](state) 30 | 31 | def train(self, network, iterator, maxepoch, optimizer): 32 | state = { 33 | 'network': network, 34 | 'iterator': iterator, 35 | 'maxepoch': maxepoch, 36 | 'optimizer': optimizer, 37 | 'epoch': 0, 38 | 't': 0, 39 | 'train': True, 40 | } 41 | 42 | self.hook('on_start', state) 43 | while state['epoch'] < state['maxepoch']: 44 | self.hook('on_start_epoch', state) 45 | for sample in state['iterator']: 46 | state['sample'] = sample 47 | self.hook('on_sample', state) 48 | 49 | def closure(): 50 | loss, output = state['network'](state['sample']) 51 | state['output'] = output 52 | state['loss'] = loss 53 | loss.backward() 54 | self.hook('on_forward', state) 55 | # to free memory in save_for_backward 56 | state['output'] = None 57 | state['loss'] = None 58 | return loss 59 | 60 | state['optimizer'].zero_grad() 61 | state['optimizer'].step(closure) 62 | self.hook('on_update', state) 63 | state['t'] += 1 64 | state['epoch'] += 1 65 | self.hook('on_end_epoch', state) 66 | self.hook('on_end', state) 67 | return state 68 | 69 | def test(self, network, iterator): 70 | state = { 71 | 'network': network, 72 | 'iterator': iterator, 73 | 't': 0, 74 | 'train': False, 75 | } 76 | 77 | self.hook('on_start', state) 78 | for sample in state['iterator']: 79 | state['sample'] = sample 80 | self.hook('on_sample', state) 81 | 82 | def closure(): 83 | loss, output = state['network'](state['sample']) 84 | state['output'] = output 85 | state['loss'] = loss 86 | self.hook('on_forward', state) 87 | # to free memory in save_for_backward 88 | state['output'] = None 89 | state['loss'] = None 90 | 91 | closure() 92 | state['t'] += 1 93 | self.hook('on_end', state) 94 | return state 95 |
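A minimal end-to-end sketch of driving `Engine.train` (the model, data, and hook here are illustrative; the only contract `Engine` assumes is `network(sample) -> (loss, output)`): >>> import torch >>> from tnt.torchnet.engine import Engine >>> model = torch.nn.Linear(4, 1) >>> def network(sample): ...     x, y = sample ...     out = model(x) ...     return torch.nn.functional.mse_loss(out, y), out >>> engine = Engine() >>> engine.hooks['on_end_epoch'] = lambda state: print('epoch', state['epoch']) >>> data = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(3)] >>> state = engine.train(network, data, maxepoch=2, optimizer=torch.optim.SGD(model.parameters(), lr=0.1)) epoch 1 epoch 2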
-------------------------------------------------------------------------------- /tnt/torchnet/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .meterlogger import MeterLogger 2 | from .filelogger import FileLogger 3 | from .tensorboardmeterlogger import TensorboardMeterLogger 4 | from .visdomlogger import VisdomLogger, VisdomPlotLogger, VisdomSaver, VisdomTextLogger 5 | from .visdommeterlogger import VisdomMeterLogger 6 | 7 | -------------------------------------------------------------------------------- /tnt/torchnet/logger/filelogger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | 5 | class FileLogger(object): 6 | '''Logs results to a file. 7 | 8 | The FileLogger provides a convenient interface for periodically writing 9 | results to a file. It is designed to capture all information for a given 10 | experiment, which may have a sequence of distinct tasks. Therefore, it writes 11 | results in the format:: 12 | 13 | { 14 | 'tasks': [...] 15 | 'results': [...] 16 | } 17 | 18 | The FileLogger class uses a top-level list instead of a dictionary 19 | to preserve the temporal order of tasks. 20 | 21 | Args: 22 | filepath (str): Path to write results to 23 | overwrite (bool): whether to clobber a file if it exists 24 | 25 | Example: 26 | >>> result_writer = FileLogger(filepath) 27 | >>> for task in ['CIFAR-10', 'SVHN']: 28 | >>> train_results = train_model() 29 | >>> test_results = test_model() 30 | >>> result_writer.log(task, {'Train': train_results, 'Test': test_results}) 31 | 32 | ''' 33 | 34 | def __init__(self, filepath, overwrite=False): 35 | if not overwrite: 36 | assert not os.path.exists(filepath), 'Cannot write results to "{}". Already exists!'.format(filepath) 37 | with open(filepath, 'wb') as f: 38 | pickle.dump({ 39 | 'tasks': [], 40 | 'results': [] 41 | }, f) 42 | 43 | self.filepath = filepath 44 | self.tasks = set() 45 | 46 | def _add_task(self, task_name): 47 | assert task_name not in self.tasks, "Task already added! Use a different name." 48 | self.tasks.add(task_name) 49 | 50 | def log(self, task_name, result): 51 | ''' Update the results file with new information. 52 | 53 | Args: 54 | task_name (str): Name of the currently running task. A previously unseen 55 | ``task_name`` will create a new entry in both :attr:`tasks` 56 | and :attr:`results`. 57 | result: This will be appended to the list in :attr:`results` which 58 | corresponds to ``task_name`` in :attr:`tasks`. 59 | 60 | ''' 61 | with open(self.filepath, 'rb') as f: 62 | existing_results = pickle.load(f) 63 | if task_name not in self.tasks: 64 | self._add_task(task_name) 65 | existing_results['tasks'].append(task_name) 66 | existing_results['results'].append([]) 67 | task_name_idx = existing_results['tasks'].index(task_name) 68 | results = existing_results['results'][task_name_idx] 69 | results.append(result) 70 | with open(self.filepath, 'wb') as f: 71 | pickle.dump(existing_results, f) 72 | -------------------------------------------------------------------------------- /tnt/torchnet/logger/logger.py: -------------------------------------------------------------------------------- 1 | """ Logging values to various sinks """ 2 | 3 | 4 | class Logger(object): 5 | _fields = None 6 | 7 | @property 8 | def fields(self): 9 | assert self._fields is not None, "self.fields is not set!" 10 | return self._fields 11 | 12 | @fields.setter 13 | def fields(self, value): 14 | self._fields = value 15 | 16 | def __init__(self, fields=None): 17 | """ Automatically logs the variables in 'fields' """ 18 | self.fields = fields 19 | 20 | def log(self, *args, **kwargs): 21 | pass 22 | 23 | def log_state(self, state_dict): 24 | pass -------------------------------------------------------------------------------- /tnt/torchnet/meter/__init__.py: -------------------------------------------------------------------------------- 1 | from .averagevaluemeter import AverageValueMeter 2 | from .valuesummarymeter import ValueSummaryMeter 3 | from .multivaluesummarymeter import MultiValueSummaryMeter 4 | from .classerrormeter import ClassErrorMeter 5 | from .confusionmeter import ConfusionMeter 6 | from .timemeter import TimeMeter 7 | from .msemeter import MSEMeter 8 | from .movingaveragevaluemeter import MovingAverageValueMeter 9 | from .aucmeter import AUCMeter 10 | from .apmeter import APMeter 11 | from .mapmeter import mAPMeter 12 | from .singletonmeter import SingletonMeter 13 | 14 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/aucmeter.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from .
import meter 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class AUCMeter(meter.Meter): 8 | """ 9 | The AUCMeter measures the area under the receiver-operating characteristic 10 | (ROC) curve for binary classification problems. The area under the curve (AUC) 11 | can be interpreted as the probability that, given a randomly selected positive 12 | example and a randomly selected negative example, the positive example is 13 | assigned a higher score by the classification model than the negative example. 14 | 15 | The AUCMeter is designed to operate on one-dimensional Tensors `output` 16 | and `target`, where (1) the `output` contains model output scores that ought to 17 | be higher when the model is more convinced that the example should be positively 18 | labeled, and smaller when the model believes the example should be negatively 19 | labeled (for instance, the output of a sigmoid function); and (2) the `target` 20 | contains only values 0 (for negative examples) and 1 (for positive examples). 21 | """ 22 | 23 | def __init__(self): 24 | super(AUCMeter, self).__init__() 25 | self.reset() 26 | 27 | def reset(self): 28 | self.scores = torch.DoubleTensor(torch.DoubleStorage()).numpy() 29 | self.targets = torch.LongTensor(torch.LongStorage()).numpy() 30 | 31 | def add(self, output, target): 32 | if torch.is_tensor(output): 33 | output = output.cpu().squeeze().numpy() 34 | if torch.is_tensor(target): 35 | target = target.cpu().squeeze().numpy() 36 | elif isinstance(target, numbers.Number): 37 | target = np.asarray([target]) 38 | assert np.ndim(output) == 1, \ 39 | 'wrong output size (1D expected)' 40 | assert np.ndim(target) == 1, \ 41 | 'wrong target size (1D expected)' 42 | assert output.shape[0] == target.shape[0], \ 43 | 'number of outputs and targets does not match' 44 | assert np.all(np.add(np.equal(target, 1), np.equal(target, 0))), \ 45 | 'targets should be binary (0, 1)' 46 | 47 | self.scores = np.append(self.scores, output) 48 | self.targets = np.append(self.targets, target) 49 | 50 | def value(self): 51 | # case when the number of elements added is 0 52 | if self.scores.shape[0] == 0: 53 | return 0.5 54 | 55 | # sorting the arrays 56 | scores, sortind = torch.sort(torch.from_numpy( 57 | self.scores), dim=0, descending=True) 58 | scores = scores.numpy() 59 | sortind = sortind.numpy() 60 | 61 | # creating the roc curve 62 | tpr = np.zeros(shape=(scores.size + 1), dtype=np.float64) 63 | fpr = np.zeros(shape=(scores.size + 1), dtype=np.float64) 64 | 65 | for i in range(1, scores.size + 1): 66 | if self.targets[sortind[i - 1]] == 1: 67 | tpr[i] = tpr[i - 1] + 1 68 | fpr[i] = fpr[i - 1] 69 | else: 70 | tpr[i] = tpr[i - 1] 71 | fpr[i] = fpr[i - 1] + 1 72 | 73 | tpr /= (self.targets.sum() * 1.0) 74 | fpr /= ((self.targets - 1.0).sum() * -1.0) 75 | 76 | # calculating area under curve using trapezoidal rule 77 | n = tpr.shape[0] 78 | h = fpr[1:n] - fpr[0:n - 1] 79 | sum_h = np.zeros(fpr.shape) 80 | sum_h[0:n - 1] = h 81 | sum_h[1:n] += h 82 | area = (sum_h * tpr).sum() / 2.0 83 | 84 | return (area, tpr, fpr) 85 |
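A quick sanity-check sketch for `AUCMeter` (toy scores; the import path assumes the vendored `tnt.torchnet` package): >>> import torch >>> from tnt.torchnet.meter import AUCMeter >>> m = AUCMeter() >>> m.add(torch.tensor([0.9, 0.8, 0.3, 0.1]), torch.tensor([1, 0, 1, 0])) >>> area, tpr, fpr = m.value() >>> round(float(area), 2) 0.75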
-------------------------------------------------------------------------------- /tnt/torchnet/meter/averagevaluemeter.py: -------------------------------------------------------------------------------- 1 | import math 2 | from . import meter 3 | from .valuesummarymeter import ValueSummaryMeter 4 | import numpy as np 5 | import warnings 6 | 7 | class AverageValueMeter(ValueSummaryMeter): 8 | def __init__(self): 9 | warnings.warn('AverageValueMeter is deprecated in favor of ValueSummaryMeter and will be removed in a future version', FutureWarning) 10 | super(AverageValueMeter, self).__init__() 11 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/classerrormeter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import numbers 4 | from . import meter 5 | 6 | 7 | class ClassErrorMeter(meter.Meter): 8 | def __init__(self, topk=[1], accuracy=False): 9 | super(ClassErrorMeter, self).__init__() 10 | self.topk = np.sort(topk) 11 | self.accuracy = accuracy 12 | self.reset() 13 | 14 | def reset(self): 15 | self.sum = {v: 0 for v in self.topk} 16 | self.n = 0 17 | 18 | def add(self, output, target): 19 | if torch.is_tensor(output): 20 | output = output.cpu().squeeze().numpy() 21 | if torch.is_tensor(target): 22 | target = np.atleast_1d(target.cpu().squeeze().numpy()) 23 | elif isinstance(target, numbers.Number): 24 | target = np.asarray([target]) 25 | if np.ndim(output) == 1: 26 | output = output[np.newaxis] 27 | else: 28 | assert np.ndim(output) == 2, \ 29 | 'wrong output size (1D or 2D expected)' 30 | assert np.ndim(target) == 1, \ 31 | 'wrong target size (1D expected)' 32 | assert target.shape[0] == output.shape[0], \ 33 | 'target and output do not match' 34 | topk = self.topk 35 | maxk = int(topk[-1]) # seems like Python3 wants int and not np.int64 36 | no = output.shape[0] 37 | 38 | pred = torch.from_numpy(output).topk(maxk, 1, True, True)[1].numpy() 39 | correct = pred == target[:, np.newaxis].repeat(pred.shape[1], 1) 40 | 41 | for k in topk: 42 | self.sum[k] += no - correct[:, 0:k].sum() 43 | self.n += no 44 | 45 | def value(self, k=-1): 46 | if self.n == 0: 47 | return [-1 for _ in self.topk] 48 | if k != -1: 49 | assert k in self.sum.keys(), \ 50 | 'invalid k (this k was not provided at construction time)' 51 | if self.accuracy: 52 | return (1. - float(self.sum[k]) / self.n) * 100.0 53 | else: 54 | return float(self.sum[k]) / self.n * 100.0 55 | else: 56 | return [self.value(k_) for k_ in self.topk] 57 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/confusionmeter.py: -------------------------------------------------------------------------------- 1 | from . import meter 2 | import numpy as np 3 | 4 | 5 | class ConfusionMeter(meter.Meter): 6 | """Maintains a confusion matrix for a given classification problem. 7 | 8 | The ConfusionMeter constructs a confusion matrix for multi-class 9 | classification problems. It does not support multi-label, multi-class problems: 10 | for such problems, please use MultiLabelConfusionMeter. 11 | 12 | Args: 13 | k (int): number of classes in the classification problem 14 | normalized (boolean): Determines whether the confusion matrix 15 | is normalized
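Example (an illustrative sketch with k=2; in `value()`, rows are ground-truth classes and columns are predictions): >>> import torch >>> m = ConfusionMeter(2) >>> m.add(torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 0])) >>> m.value() array([[1, 1], [0, 0]], dtype=int32)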
16 | 17 | """ 18 | 19 | def __init__(self, k, normalized=False): 20 | super(ConfusionMeter, self).__init__() 21 | self.conf = np.ndarray((k, k), dtype=np.int32) 22 | self.normalized = normalized 23 | self.k = k 24 | self.reset() 25 | 26 | def reset(self): 27 | self.conf.fill(0) 28 | 29 | def add(self, predicted, target): 30 | """Computes the confusion matrix of K x K size where K is no of classes 31 | 32 | Args: 33 | predicted (tensor): Can be an N x K tensor of predicted scores obtained from 34 | the model for N examples and K classes or an N-tensor of 35 | integer values between 0 and K-1. 36 | target (tensor): Can be an N-tensor of integer values assumed to be integer 37 | values between 0 and K-1 or an N x K tensor, where targets are 38 | assumed to be provided as one-hot vectors 39 | 40 | """ 41 | predicted = predicted.cpu().numpy() 42 | target = target.cpu().numpy() 43 | 44 | assert predicted.shape[0] == target.shape[0], \ 45 | 'number of targets and predicted outputs do not match' 46 | 47 | if np.ndim(predicted) != 1: 48 | assert predicted.shape[1] == self.k, \ 49 | 'number of predictions does not match size of confusion matrix' 50 | predicted = np.argmax(predicted, 1) 51 | else: 52 | assert (predicted.max() < self.k) and (predicted.min() >= 0), \ 53 | 'predicted values are not between 0 and k-1' 54 | 55 | onehot_target = np.ndim(target) != 1 56 | if onehot_target: 57 | assert target.shape[1] == self.k, \ 58 | 'Onehot target does not match size of confusion matrix' 59 | assert (target >= 0).all() and (target <= 1).all(), \ 60 | 'in one-hot encoding, target values should be 0 or 1' 61 | assert (target.sum(1) == 1).all(), \ 62 | 'multi-label setting is not supported' 63 | target = np.argmax(target, 1) 64 | else: 65 | assert (target.max() < self.k) and (target.min() >= 0), \ 66 | 'target values are not between 0 and k-1' 67 | 68 | # hack for bincounting 2 arrays together 69 | x = predicted + self.k * target 70 | bincount_2d = np.bincount(x.astype(np.int32), 71 | minlength=self.k ** 2) 72 | assert bincount_2d.size == self.k ** 2 73 | conf = bincount_2d.reshape((self.k, self.k)) 74 | 75 | self.conf += conf 76 | 77 | def value(self): 78 | """ 79 | Returns: 80 | Confusion matrix of K rows and K columns, where rows correspond 81 | to ground-truth targets and columns correspond to predicted 82 | targets. 83 | """ 84 | if self.normalized: 85 | conf = self.conf.astype(np.float32) 86 | return conf / conf.sum(1).clip(min=1e-12)[:, None] 87 | else: 88 | return self.conf 89 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/mapmeter.py: -------------------------------------------------------------------------------- 1 | from . import meter, APMeter 2 | 3 | 4 | class mAPMeter(meter.Meter): 5 | """ 6 | The mAPMeter measures the mean average precision over all classes.
7 | 8 | The mAPMeter is designed to operate on `NxK` Tensors `output` and 9 | `target`, and optionally an `Nx1` Tensor `weight`, where (1) the `output` 10 | contains model output scores for `N` examples and `K` classes that ought to 11 | be higher when the model is more convinced that the example should be 12 | positively labeled, and smaller when the model believes the example should 13 | be negatively labeled (for instance, the output of a sigmoid function); (2) 14 | the `target` contains only values 0 (for negative examples) and 1 15 | (for positive examples); and (3) the `weight` (> 0) represents the weight for 16 | each sample. 17 | """ 18 | 19 | def __init__(self): 20 | super(mAPMeter, self).__init__() 21 | self.apmeter = APMeter() 22 | 23 | def reset(self): 24 | self.apmeter.reset() 25 | 26 | def add(self, output, target, weight=None): 27 | self.apmeter.add(output, target, weight) 28 | 29 | def value(self): 30 | return self.apmeter.value().mean() 31 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/medianimagemeter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class MedianImageMeter(object): 5 | def __init__(self, bit_depth, im_shape, device='cpu'): 6 | self.bit_depth = bit_depth 7 | self.im_shape = list(im_shape) 8 | self.device = device 9 | if bit_depth == 8: 10 | self.dtype = np.uint8 11 | elif bit_depth == 16: 12 | self.dtype = np.uint16 13 | else: 14 | raise NotImplementedError( 15 | "MedianMeter cannot find the median of non 8/16 bit-depth images.") 16 | self.reset() 17 | 18 | def reset(self): 19 | self.freqs = self.make_freqs_array() 20 | 21 | def add(self, val, mask=1): 22 | self.val = torch.LongTensor(val.astype(np.int64).flatten()[np.newaxis, :]).to(self.device) 23 | 24 | if type(mask) == int: 25 | mask = torch.IntTensor(self.val.size()).fill_(mask).to(self.device) 26 | else: 27 | mask = torch.IntTensor(mask.astype(np.int32).flatten()[np.newaxis, :]).to(self.device) 28 | 29 | self.freqs.scatter_add_(0, self.val, mask) 30 | self.saved_val = val 31 | 32 | def value(self): 33 | self._avg = np.cumsum( 34 | self.freqs.cpu().numpy(), 35 | axis=0) 36 | self._avg = np.apply_along_axis( 37 | lambda a: a.searchsorted(a[-1] / 2.), 38 | axis=0, 39 | arr=self._avg)\ 40 | .reshape(tuple([-1] + self.im_shape)) 41 | return np.squeeze(self._avg, 0) 42 | 43 | def make_freqs_array(self): 44 | # freqs has shape N_categories x W x H x N_channels 45 | shape = tuple([2**self.bit_depth] + self.im_shape) 46 | freqs = torch.IntTensor(shape[0], int(np.prod(shape[1:]))).zero_() 47 | return freqs.to(self.device) 48 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/meter.py: -------------------------------------------------------------------------------- 1 | 2 | class Meter(object): 3 | '''Meters provide a way to keep track of important statistics in an online manner. 4 | 5 | This class is abstract, but provides a standard interface for all meters to follow. 6 | 7 | ''' 8 | 9 | def reset(self): 10 | '''Resets the meter to default settings.''' 11 | pass 12 | 13 | def add(self, value): 14 | '''Log a new value to the meter 15 | 16 | Args: 17 | value: Next result to include.
18 | 19 | ''' 20 | pass 21 | 22 | def value(self): 23 | '''Get the value of the meter in the current state.''' 24 | pass 25 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/movingaveragevaluemeter.py: -------------------------------------------------------------------------------- 1 | import math 2 | from . import meter 3 | import torch 4 | 5 | 6 | class MovingAverageValueMeter(meter.Meter): 7 | def __init__(self, windowsize): 8 | super(MovingAverageValueMeter, self).__init__() 9 | self.windowsize = windowsize 10 | self.valuequeue = torch.Tensor(windowsize) 11 | self.reset() 12 | 13 | def reset(self): 14 | self.sum = 0.0 15 | self.n = 0 16 | self.var = 0.0 17 | self.valuequeue.fill_(0) 18 | 19 | def add(self, value): 20 | queueid = (self.n % self.windowsize) 21 | oldvalue = self.valuequeue[queueid] 22 | self.sum += value - oldvalue 23 | self.var += value * value - oldvalue * oldvalue 24 | self.valuequeue[queueid] = value 25 | self.n += 1 26 | 27 | def value(self): 28 | n = min(self.n, self.windowsize) 29 | mean = self.sum / max(1, n) 30 | std = math.sqrt(max((self.var - n * mean * mean) / max(1, n - 1), 0)) 31 | return mean, std 32 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/msemeter.py: -------------------------------------------------------------------------------- 1 | import math 2 | from . import meter 3 | import torch 4 | 5 | 6 | class MSEMeter(meter.Meter): 7 | def __init__(self, root=False): 8 | super(MSEMeter, self).__init__() 9 | self.reset() 10 | self.root = root 11 | 12 | def reset(self): 13 | self.n = 0 14 | self.sesum = 0.0 15 | 16 | def add(self, output, target): 17 | # convert each argument independently so mixed tensor/ndarray input works 18 | if not torch.is_tensor(output): 19 | output = torch.from_numpy(output) 20 | if not torch.is_tensor(target): 21 | target = torch.from_numpy(target) 22 | self.n += output.numel() 23 | self.sesum += torch.sum((output - target) ** 2) 24 | 25 | def value(self): 26 | mse = self.sesum / max(1, self.n) 27 | return math.sqrt(mse) if self.root else mse 28 | 29 | def __str__(self): 30 | res = "RMSE" if self.root else "MSE" 31 | res += " %.3f\t" 32 | tval = [self.value()] 33 | return res % tuple(tval) 34 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/multivaluesummarymeter.py: -------------------------------------------------------------------------------- 1 | import math 2 | from . import meter, ValueSummaryMeter 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class MultiValueSummaryMeter(ValueSummaryMeter): 8 | def __init__(self, keys): 9 | ''' 10 | Args: 11 | keys: An iterable of keys 12 | ''' 13 | super(MultiValueSummaryMeter, self).__init__() 14 | self.keys = list(keys) 15 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/singletonmeter.py: -------------------------------------------------------------------------------- 1 | from . import meter 2 | 3 | 4 | class SingletonMeter(meter.Meter): 5 | '''Stores exactly one value which can be regurgitated''' 6 | 7 | def __init__(self, maxlen=1): 8 | # maxlen is currently unused; kept for API compatibility 9 | super(SingletonMeter, self).__init__() 10 | self.__val = None 11 | 12 | def reset(self): 13 | '''Resets the meter to default settings.''' 14 | old_val = self.__val 15 | self.__val = None 16 | return old_val 17 | 18 | def add(self, value): 19 | '''Log a new value to the meter 20 | 21 | Args: 22 | value: Next result to include.
23 | ''' 24 | self.__val = value 25 | 26 | def value(self): 27 | '''Get the value of the meter in the current state.''' 28 | return self.__val 29 | 30 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/timemeter.py: -------------------------------------------------------------------------------- 1 | import time 2 | from . import meter 3 | 4 | 5 | class TimeMeter(meter.Meter): 6 | """ 7 | The `tnt.TimeMeter` is designed to measure the time between events and can be 8 | used to measure, for instance, the average processing time per batch of data. 9 | It is different from most other meters in terms of the methods it provides: 10 | 11 | * `reset()` resets the timer, setting the timer and unit counter to zero. 12 | * `value()` returns the time passed since the last `reset()`. 13 | 14 | Note: the `unit` flag is stored, but this implementation does not divide 15 | the elapsed time by an event counter; `value()` always returns wall-clock 16 | seconds since `reset()`. 17 | """ 18 | 19 | def __init__(self, unit): 20 | super(TimeMeter, self).__init__() 21 | self.unit = unit 22 | self.reset() 23 | 24 | def reset(self): 25 | self.n = 0 26 | self.time = time.time() 27 | 28 | def value(self): 29 | return time.time() - self.time 30 | -------------------------------------------------------------------------------- /tnt/torchnet/meter/valuesummarymeter.py: -------------------------------------------------------------------------------- 1 | import math 2 | from . import meter 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class ValueSummaryMeter(meter.Meter): 8 | def __init__(self): 9 | super(ValueSummaryMeter, self).__init__() 10 | self.reset() 11 | self.val = 0 12 | 13 | def add(self, value, n=1): 14 | self.val = value 15 | self.sum += value 16 | self.var += value * value 17 | self.n += n 18 | 19 | if self.n == 0: 20 | self.mean, self.std = np.nan, np.nan 21 | elif self.n == 1: 22 | self.mean = self.sum + 0.0  # This is to force a copy in torch/numpy 23 | self.min = self.mean + 0.0 24 | self.max = self.mean + 0.0 25 | self.std = np.inf 26 | self.mean_old = self.mean 27 | self.m_s = 0.0 28 | else: 29 | self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n) 30 | self.m_s += (value - self.mean_old) * (value - self.mean) 31 | self.mean_old = self.mean 32 | self.std = np.sqrt(self.m_s / (self.n - 1.0)) 33 | self.min = np.minimum(self.min, value) 34 | self.max = np.maximum(self.max, value) 35 | 36 | def value(self): 37 | return self.mean, self.std 38 | 39 | def reset(self): 40 | self.n = 0 41 | self.sum = 0.0 42 | self.var = 0.0 43 | self.val = 0.0 44 | self.mean = np.nan 45 | self.mean_old = 0.0 46 | self.m_s = 0.0 47 | self.std = np.nan 48 | self.min = np.nan 49 | self.max = np.nan 50 | 51 | def __str__(self): 52 | old_po = np.get_printoptions() 53 | np.set_printoptions(precision=3) 54 | res = "mean(std) {} ({}) \tmin/max {}/{}\t".format( 55 | *[np.array(v) for v in [self.mean, self.std, self.min, self.max]]) 56 | np.set_printoptions(**old_po) 57 | return res 58 |
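A small sketch of the online mean/std computation in `ValueSummaryMeter` (toy values; the import path assumes the vendored `tnt.torchnet` package): >>> from tnt.torchnet.meter import ValueSummaryMeter >>> m = ValueSummaryMeter() >>> for v in [1.0, 2.0, 3.0]: ...     m.add(v) >>> mean, std = m.value() >>> float(mean), float(std) (2.0, 1.0)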
-------------------------------------------------------------------------------- /tnt/torchnet/transform.py: -------------------------------------------------------------------------------- 1 | from six import iteritems 2 | from .utils.table import canmergetensor as canmerge 3 | from .utils.table import mergetensor as mergetensor 4 | 5 | 6 | def compose(transforms): 7 | assert isinstance(transforms, list) 8 | for tr in transforms: 9 | assert callable(tr), 'list of functions expected' 10 | 11 | def composition(z): 12 | for tr in transforms: 13 | z = tr(z) 14 | return z 15 | return composition 16 | 17 | 18 | def tablemergekeys(): 19 | def mergekeys(tbl): 20 | mergetbl = {} 21 | if isinstance(tbl, dict): 22 | for idx, elem in tbl.items(): 23 | for key, value in elem.items(): 24 | if key not in mergetbl: 25 | mergetbl[key] = {} 26 | mergetbl[key][idx] = value 27 | elif isinstance(tbl, list): 28 | for elem in tbl: 29 | for key, value in elem.items(): 30 | if key not in mergetbl: 31 | mergetbl[key] = [] 32 | mergetbl[key].append(value) 33 | return mergetbl 34 | return mergekeys 35 | 36 | 37 | def tableapply(f): 38 | return lambda d: dict(map(lambda kv: (kv[0], f(kv[1])), iteritems(d))) 39 | 40 | 41 | def makebatch(merge=None): 42 | if merge: 43 | makebatch = compose([tablemergekeys(), merge]) 44 | else: 45 | makebatch = compose([ 46 | tablemergekeys(), 47 | tableapply(lambda field: mergetensor(field) 48 | if canmerge(field) else field) 49 | ]) 50 | 51 | return lambda samples: makebatch(samples) 52 | -------------------------------------------------------------------------------- /tnt/torchnet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .multitaskdataloader import MultiTaskDataLoader 2 | -------------------------------------------------------------------------------- /tnt/torchnet/utils/multitaskdataloader.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import torch.utils.data 3 | 4 | 5 | class MultiTaskDataLoader(object): 6 | '''Loads batches simultaneously from multiple datasets. 7 | 8 | The MultiTaskDataLoader is designed to make multi-task learning simpler. It is 9 | ideal for jointly training a model for multiple tasks or multiple datasets. 10 | MultiTaskDataLoader is initialized with an iterable of :class:`Dataset` objects, 11 | and provides an iterator which will return one batch that contains an equal number 12 | of samples from each of the :class:`Dataset` s. 13 | 14 | Specifically, it returns batches of ``[(B_0, 0), (B_1, 1), ..., (B_k, k)]`` 15 | from datasets ``(D_0, ..., D_k)``, where each `B_i` has :attr:`batch_size` samples 16 | 17 | 18 | Args: 19 | datasets: A list of :class:`Dataset` objects to serve batches from 20 | batch_size: Each batch from each :class:`Dataset` will have this many samples 21 | use_all (bool): If True, then the iterator will return batches until all 22 | datasets are exhausted. If False, then iteration stops as soon as one dataset 23 | runs out 24 | loading_kwargs: These are passed to the children dataloaders 25 | 26 | 27 | Example: 28 | >>> train_loader = MultiTaskDataLoader([dataset1, dataset2], batch_size=3) 29 | >>> for ((datas1, labels1), task1), ((datas2, labels2), task2) in train_loader: 30 | >>> print(task1, task2) 31 | 0 1 32 | 0 1 33 | ... 34 | 0 1 35 | 36 | ''' 37 | 38 | def __init__(self, datasets, batch_size=1, use_all=False, **loading_kwargs): 39 | self.loaders = [] 40 | self.batch_size = batch_size 41 | self.use_all = use_all 42 | self.loading_kwargs = loading_kwargs 43 | for dataset in datasets: 44 | loader = torch.utils.data.DataLoader( 45 | dataset, 46 | batch_size=self.batch_size, 47 | **self.loading_kwargs) 48 | self.loaders.append(loader) 49 | self.min_loader_size = min([len(l) for l in self.loaders]) 50 | self.current_loader = 0 51 | 52 | def __iter__(self): 53 | '''Returns an iterator that simultaneously returns batches from each dataset.
54 | 55 | Specifically, it returns batches of 56 | ``[(B_0, 0), (B_1, 1), ..., (B_k, k)]`` 57 | from datasets 58 | ``(D_0, ..., D_k)``. 59 | 60 | ''' 61 | return zip_batches(*[zip(iter(l), repeat(loader_num)) for loader_num, l in enumerate(self.loaders)], 62 | use_all=self.use_all) 63 | 64 | def __len__(self): 65 | if self.use_all: 66 | return max([len(loader) for loader in self.loaders]) 67 | else: 68 | return self.min_loader_size 69 | 70 | 71 | def zip_batches(*iterables, **kwargs): 72 | use_all = kwargs.pop('use_all', False) 73 | if use_all: 74 | try: 75 | from itertools import izip_longest as zip_longest 76 | except ImportError: 77 | from itertools import zip_longest 78 | return zip_longest(fillvalue=None, *iterables) 79 | else: 80 | return zip(*iterables) 81 | -------------------------------------------------------------------------------- /tnt/torchnet/utils/table.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def canmergetensor(tbl): 5 | if not isinstance(tbl, list): 6 | return False 7 | 8 | if torch.is_tensor(tbl[0]): 9 | sz = tbl[0].numel() 10 | for v in tbl: 11 | if v.numel() != sz: 12 | return False 13 | return True 14 | return False 15 | 16 | 17 | def mergetensor(tbl): 18 | sz = [len(tbl)] + list(tbl[0].size()) 19 | res = tbl[0].new(torch.Size(sz)) 20 | for i, v in enumerate(tbl): 21 | res[i].copy_(v) 22 | return res 23 | -------------------------------------------------------------------------------- /tnt/tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E305,E402,E721,E741,F401,F403,F405,F821,F841,F999 4 | exclude = build 5 | --------------------------------------------------------------------------------