├── keras_project_template ├── src │ ├── configs │ │ ├── __init__.py │ │ └── simple_CNN.py │ ├── __init__.py │ ├── scripts │ │ └── train_simple_CNN.py │ ├── models.py │ ├── data.py │ └── training_loop.py ├── etc │ └── README.md ├── requirements.txt ├── notebooks │ └── README.md ├── env.sh ├── .gitignore └── README.md ├── pytorch_project_template ├── results │ └── example_run │ │ ├── FINISHED │ │ ├── init_weights.pt │ │ ├── model_best_val.pt │ │ ├── model_last_epoch.pt │ │ ├── events.out.tfevents.1559763675.JASTRS01-MBP15T.local │ │ ├── history.csv │ │ ├── meta.json │ │ ├── cnn.gin │ │ ├── config.txt │ │ └── train.py ├── src │ ├── models │ │ ├── __init__.py │ │ ├── lenet.py │ │ └── simple_cnn.py │ ├── callbacks │ │ └── __init__.py │ ├── data │ │ ├── streams.py │ │ ├── __init__.py │ │ └── datasets.py │ ├── __init__.py │ └── utils.py ├── configs │ ├── cnn.gin │ └── cnn_full.gin ├── experiments │ └── tune_lr │ │ ├── template_config.gin │ │ ├── README.md │ │ └── main.py ├── e.sh ├── bin │ └── train.py ├── e.yml └── README.md ├── pytorch_lightning_project_template ├── src │ ├── modules │ │ ├── __init__.py │ │ └── supervised_training.py │ ├── models │ │ ├── __init__.py │ │ ├── simple_cnn.py │ │ └── lenet.py │ ├── data │ │ ├── __init__.py │ │ ├── utils.py │ │ └── datasets.py │ ├── callbacks │ │ ├── __init__.py │ │ └── base.py │ ├── __init__.py │ ├── training_loop.py │ └── utils.py ├── bin │ ├── utils │ │ ├── sync.sh │ │ ├── exclude.rsync │ │ ├── watch_changes.sh │ │ ├── slurm_template.sh │ │ ├── update_plots.py │ │ ├── run_on_a_gpu.py │ │ ├── run_on_free_gpus.py │ │ └── run_slurm.py │ ├── evaluate_supervised.py │ └── train_supervised.py ├── experiments │ └── tune_lr │ │ ├── large │ │ ├── configs │ │ │ ├── 0.gin │ │ │ ├── 1.gin │ │ │ ├── 2.gin │ │ │ ├── 3.gin │ │ │ ├── 4.gin │ │ │ └── 5.gin │ │ ├── batch.sh │ │ └── run.sh │ │ ├── template_config.gin │ │ ├── README.md │ │ └── main.py ├── configs │ ├── cnn.gin │ └── cnn_full.gin ├── e.sh ├── e.yml └── README.md ├── .gitignore ├── tf2_project_template ├── bin │ ├── utils │ │ ├── sync.sh │ │ ├── exclude.rsync │ │ ├── watch_changes.sh │ │ ├── slurm_template.sh │ │ ├── update_plots.py │ │ ├── run_on_a_gpu.py │ │ ├── run_on_free_gpus.py │ │ └── run_slurm.py │ ├── evaluate.py │ └── train.py ├── src │ ├── models │ │ ├── __init__.py │ │ └── simple_cnn.py │ ├── data │ │ ├── __init__.py │ │ └── streams.py │ ├── callbacks │ │ └── __init__.py │ ├── __init__.py │ ├── plotting.py │ └── utils.py ├── experiments │ ├── README.md │ └── tune_lr │ │ ├── large │ │ ├── batch.sh │ │ └── run.sh │ │ ├── template_config.gin │ │ ├── README.md │ │ └── main.py ├── configs │ ├── scnn.gin │ └── scnn_neptune.gin ├── env.sh ├── tf2_project_template.yml └── README.md ├── LICENSE └── README.md /keras_project_template/src/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_project_template/results/example_run/FINISHED: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_project_template/etc/README.md: -------------------------------------------------------------------------------- 1 | Any relevant files 
to the project. -------------------------------------------------------------------------------- /keras_project_template/requirements.txt: -------------------------------------------------------------------------------- 1 | theano >= 0.9 2 | keras >= 2.0.0 -------------------------------------------------------------------------------- /keras_project_template/notebooks/README.md: -------------------------------------------------------------------------------- 1 | Any relevant notebooks following convention `number_name.ipynb`, for instance `1.0_exploratory_analysis.ipynb`. -------------------------------------------------------------------------------- /keras_project_template/src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Relevant constants 3 | """ 4 | 5 | import os 6 | 7 | DATA_DIR = os.environ.get("DATA_DIR", "data") 8 | -------------------------------------------------------------------------------- /pytorch_project_template/results/example_run/init_weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/toolkit/HEAD/pytorch_project_template/results/example_run/init_weights.pt -------------------------------------------------------------------------------- /pytorch_project_template/results/example_run/model_best_val.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/toolkit/HEAD/pytorch_project_template/results/example_run/model_best_val.pt -------------------------------------------------------------------------------- /pytorch_project_template/results/example_run/model_last_epoch.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/toolkit/HEAD/pytorch_project_template/results/example_run/model_last_epoch.pt -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # python 2 | *.pyc 3 | 4 | # ipynb 5 | notebooks/.ipynb_checkpoints 6 | 7 | # pycharm 8 | .idea 9 | 10 | # mandala 11 | .backends 12 | .graphs 13 | .mandala_cache 14 | *.pkl 15 | -------------------------------------------------------------------------------- /pytorch_project_template/src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Models used in the project 4 | """ 5 | 6 | from .lenet import LeNet 7 | from .simple_cnn import SimpleCNN 8 | 9 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Models used in the project 4 | """ 5 | 6 | from .lenet import LeNet 7 | from .simple_cnn import SimpleCNN 8 | 9 | -------------------------------------------------------------------------------- /keras_project_template/env.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export PYTHONPATH=$PYTHONPATH:$HOME/example_project 4 | export DATA_DIR=$SCRATCH/example_project/data 5 | export RESULTS_DIR=$SCRATCH/example_project/results -------------------------------------------------------------------------------- 
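Note: a typical session with the Keras template first sources `env.sh` above, so that `DATA_DIR` (read in `src/__init__.py`) and `RESULTS_DIR` point at the intended locations, and then launches a script, for instance (run command taken from the training script's own docstring; paths are examples):

source env.sh
python src/scripts/train_simple_CNN.py cifar10 results/test_run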
/pytorch_project_template/results/example_run/events.out.tfevents.1559763675.JASTRS01-MBP15T.local: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/toolkit/HEAD/pytorch_project_template/results/example_run/events.out.tfevents.1559763675.JASTRS01-MBP15T.local -------------------------------------------------------------------------------- /tf2_project_template/bin/utils/sync.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ "$1" = "uj" ]; then 4 | WHERE=jastrzebski@access.capdnet.ii.uj.edu.pl:/home/jastrzebski/$PNAME 5 | rsync -vrpa * --exclude-from=bin/utils/exclude.rsync $WHERE 6 | fi -------------------------------------------------------------------------------- /pytorch_lightning_project_template/bin/utils/sync.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ "$1" = "uj" ]; then 4 | WHERE=jastrzebski@access.capdnet.ii.uj.edu.pl:/home/jastrzebski/$PNAME 5 | rsync -vrpa * --exclude-from=bin/utils/exclude.rsync $WHERE 6 | fi -------------------------------------------------------------------------------- /tf2_project_template/src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Models used in the project 4 | """ 5 | 6 | import logging 7 | logger = logging.getLogger() 8 | 9 | from .simple_cnn import SimpleCNN 10 | 11 | custom_tf_layers = {} 12 | -------------------------------------------------------------------------------- /pytorch_project_template/results/example_run/history.csv: -------------------------------------------------------------------------------- 1 | epoch,loss,time,acc,val_loss,val_acc 2 | 0,5.935613933563232,1.7076696040000003,9.100000061035157,5.924050872802734,10.4 3 | 1,5.903020370483398,1.9569860650000006,8.400000061035156,5.882277899169922,9.14 4 | -------------------------------------------------------------------------------- /pytorch_project_template/results/example_run/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "cmd": "python bin/train.py results/example_run configs/cnn.gin -b=training_loop.n_epochs=2", 3 | "save_path": "results/example_run", 4 | "most_recent_train_start_date": "2019_06_05", 5 | "execution_time": 3.706902027130127 6 | } -------------------------------------------------------------------------------- /tf2_project_template/experiments/README.md: -------------------------------------------------------------------------------- 1 | # Experiments 2 | 3 | This folder contains scripts to reproduce experiments in Section 5 of the paper. 4 | 5 | ## Flat trajectories generalize well 6 | 7 | To run the experiment generate jobs using `prepare.sh`, execute them, and finally plot results using `plot.sh`. 
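The `tune_lr` subfolder in this directory shows the concrete shape of such an experiment; a typical pass looks roughly like this (see its README for details):

```bash
python experiments/tune_lr/main.py prepare   # generates the per-run .gin configs under large/configs
bash experiments/tune_lr/large/batch.sh      # executes the prepared jobs
python experiments/tune_lr/main.py report    # processes and plots the results
```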
-------------------------------------------------------------------------------- /tf2_project_template/bin/utils/exclude.rsync: -------------------------------------------------------------------------------- 1 | subprojects* 2 | *svg 3 | *png 4 | *h5 5 | old/* 6 | notebooks/* 7 | *txt 8 | *png 9 | *pdf 10 | *zip 11 | *csv 12 | results/* 13 | __MACOSX/* 14 | ./data/* 15 | *.pkl 16 | reports/* 17 | *.pyc 18 | *.ipynb_checkpoints 19 | exclude.rsync~ 20 | *.so 21 | *.info 22 | *.log 23 | paper/* 24 | .git/* -------------------------------------------------------------------------------- /pytorch_lightning_project_template/bin/utils/exclude.rsync: -------------------------------------------------------------------------------- 1 | subprojects* 2 | *svg 3 | *png 4 | *h5 5 | old/* 6 | notebooks/* 7 | *txt 8 | *png 9 | *pdf 10 | *zip 11 | *csv 12 | results/* 13 | __MACOSX/* 14 | ./data/* 15 | *.pkl 16 | reports/* 17 | *.pyc 18 | *.ipynb_checkpoints 19 | exclude.rsync~ 20 | *.so 21 | *.info 22 | *.log 23 | paper/* 24 | .git/* -------------------------------------------------------------------------------- /tf2_project_template/bin/utils/watch_changes.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # CONFIGURATION 4 | MACHINE=uj 5 | 6 | watchman watch-del-all 7 | pkill watchman 8 | watchman-make -p 'experiments/*json' 'experiments/*.py' 'experiments/**/*.py' 'experiments/**/*.sh' 'experiments/**/*.gin' 'src/**/*.py' 'bin/**/*.py' 'bin/*.py' 'src/*.py' 'configs/*gin' --run "bash `pwd`/bin/utils/sync.sh $MACHINE" & 9 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/bin/utils/watch_changes.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # CONFIGURATION 4 | MACHINE=uj 5 | 6 | watchman watch-del-all 7 | pkill watchman 8 | watchman-make -p 'experiments/*json' 'experiments/*.py' 'experiments/**/*.py' 'experiments/**/*.sh' 'experiments/**/*.gin' 'src/**/*.py' 'bin/**/*.py' 'bin/*.py' 'src/*.py' 'configs/*gin' --run "bash `pwd`/bin/utils/sync.sh $MACHINE" & 9 | -------------------------------------------------------------------------------- /tf2_project_template/bin/utils/slurm_template.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH --output={save_path}/out 3 | #SBATCH --error={save_path}/err 4 | #SBATCH --gres=gpu 5 | #SBATCH -J {batch_name} 6 | #SBATCH --time='8:00:00' 7 | #SBATCH -p'gpu4_medium,gpu4_long,gpu4_short,gpu8_short,gpu8_medium,gpu8_long' 8 | #SBATCH --mem=10000 9 | cd /gpfs/home/jastrs01/cooperative_optimization 10 | source /gpfs/home/jastrs01/cooperative_optimization/e_bigpurple.sh 11 | export CUDA_VISIBLE_DEVICES=0 12 | {job} -------------------------------------------------------------------------------- /tf2_project_template/experiments/tune_lr/large/batch.sh: -------------------------------------------------------------------------------- 1 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/tf2_project_template/results/tune_lr/large/0 experiments/tune_lr/large/configs/0.gin 2 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/tf2_project_template/results/tune_lr/large/1 experiments/tune_lr/large/configs/1.gin 3 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/tf2_project_template/results/tune_lr/large/2 experiments/tune_lr/large/configs/2.gin 4 | 
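# (One training run per prepared config -- see experiments/tune_lr/README.md;
# each run saves into the matching results/tune_lr/large/<n> directory.)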
-------------------------------------------------------------------------------- /pytorch_lightning_project_template/bin/utils/slurm_template.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH --output={save_path}/out 3 | #SBATCH --error={save_path}/err 4 | #SBATCH --gres=gpu 5 | #SBATCH -J {batch_name} 6 | #SBATCH --time='8:00:00' 7 | #SBATCH -p'gpu4_medium,gpu4_long,gpu4_short,gpu8_short,gpu8_medium,gpu8_long' 8 | #SBATCH --mem=10000 9 | cd /gpfs/home/jastrs01/cooperative_optimization 10 | source /gpfs/home/jastrs01/cooperative_optimization/e_bigpurple.sh 11 | export CUDA_VISIBLE_DEVICES=0 12 | {job} -------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/large/configs/0.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule'] 9 | LRSchedule.schedule=[[2, 1.0],[10, 0.1]] 10 | LRSchedule.base_lr=0.001 11 | 12 | # Training loop 13 | train.n_epochs=1 14 | 15 | # Dataset 16 | get_dataset.dataset='cifar' 17 | get_dataset.seed=777 18 | cifar.variant='10' 19 | cifar.use_valid=True -------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/large/configs/1.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule'] 9 | LRSchedule.schedule=[[2, 1.0],[10, 0.1]] 10 | LRSchedule.base_lr=0.01 11 | 12 | # Training loop 13 | train.n_epochs=1 14 | 15 | # Dataset 16 | get_dataset.dataset='cifar' 17 | get_dataset.seed=777 18 | cifar.variant='10' 19 | cifar.use_valid=True -------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/large/configs/2.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule'] 9 | LRSchedule.schedule=[[2, 1.0],[10, 0.1]] 10 | LRSchedule.base_lr=0.1 11 | 12 | # Training loop 13 | train.n_epochs=1 14 | 15 | # Dataset 16 | get_dataset.dataset='cifar' 17 | get_dataset.seed=777 18 | cifar.variant='10' 19 | cifar.use_valid=True -------------------------------------------------------------------------------- /pytorch_lightning_project_template/configs/cnn.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule'] 9 | LRSchedule.schedule=[[2, 1.0],[10, 0.1]] 10 | LRSchedule.base_lr=0.01 11 | 12 | # Training loop 13 | training_loop.n_epochs=5 14 | training_loop.limit_train_batches=10 15 | 16 | # Dataset 17 | get_dataset.dataset='cifar' 18 | get_dataset.seed=777 19 | cifar.variant='10' 20 | cifar.use_valid=True -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/data/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Datasets available for use in the project. 4 | """ 5 | import logging 6 | import gin 7 | 8 | from .datasets import cifar, stl10 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | @gin.configurable 13 | def get_dataset(dataset, seed, **kwargs): 14 | bundle = globals()[dataset](seed=seed, **kwargs) 15 | logger.info("Loaded dataset of name {} with x_train.shape={}".format(dataset, bundle[-1]['input_shape'])) 16 | return bundle 17 | -------------------------------------------------------------------------------- /tf2_project_template/src/data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Simple data getters. Each returns iterator for train and dataset for test/valid. 4 | """ 5 | import logging 6 | from .datasets import cifar 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | def get_dataset(dataset, seed, **kwargs): 11 | bundle = globals()[dataset](seed=seed, **kwargs) 12 | logger.info("Loaded dataset of name {} with x_train.shape={} and num_classes={}".format(dataset, bundle[-1]['input_shape'], bundle[-1]['num_classes'])) 13 | return bundle 14 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/large/configs/3.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule', 'meta_saver'] 9 | LRSchedule.schedule=[[1.0,2],[0.01,10]] 10 | LRSchedule.base_lr=16 11 | 12 | # Training loop 13 | training_loop.n_epochs=2 14 | 15 | # Dataset 16 | get_dataset.dataset='cifar' 17 | get_dataset.n_examples=1000 18 | get_dataset.data_seed=777 19 | cifar.which=10 20 | cifar.preprocessing='center' -------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/large/configs/4.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule', 'meta_saver'] 9 | LRSchedule.schedule=[[1.0,2],[0.01,10]] 10 | LRSchedule.base_lr=32 11 | 12 | # Training loop 13 | training_loop.n_epochs=2 14 | 15 | # Dataset 16 | get_dataset.dataset='cifar' 17 | get_dataset.n_examples=1000 18 | get_dataset.data_seed=777 19 | cifar.which=10 20 | cifar.preprocessing='center' -------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/large/configs/5.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule', 'meta_saver'] 9 | LRSchedule.schedule=[[1.0,2],[0.01,10]] 10 | LRSchedule.base_lr=64 11 | 12 | # Training loop 13 | training_loop.n_epochs=2 14 | 15 | # Dataset 16 | get_dataset.dataset='cifar' 17 | get_dataset.n_examples=1000 18 | get_dataset.data_seed=777 19 | cifar.which=10 20 | cifar.preprocessing='center' 
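# (Configs 0-5 in this folder are the per-run files prepared by `python experiments/tune_lr/main.py prepare`,
# one per learning rate in the sweep; compare template_config.gin below.)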
-------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/template_config.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule'] 9 | LRSchedule.schedule=[[2, 1.0],[10, 0.1]] 10 | LRSchedule.base_lr=$learning_rate$ 11 | 12 | # Training loop 13 | training_loop.n_epochs=1 14 | training_loop.limit_train_batches=50 15 | 16 | # Dataset 17 | get_dataset.dataset='cifar' 18 | get_dataset.seed=777 19 | cifar.variant='10' 20 | cifar.use_valid=True -------------------------------------------------------------------------------- /pytorch_project_template/configs/cnn.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule', 'meta_saver'] 9 | LRSchedule.schedule=[[5, 1.0],[10, 0.1]] 10 | LRSchedule.base_lr=0.01 11 | 12 | # Training loop 13 | training_loop.n_epochs=2 14 | training_loop.reload=False 15 | 16 | # Dataset 17 | get_dataset.dataset='cifar' 18 | get_dataset.n_examples=1000 19 | get_dataset.data_seed=777 20 | cifar.which=10 21 | cifar.preprocessing='center' -------------------------------------------------------------------------------- /pytorch_project_template/configs/cnn_full.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=128 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule', 'meta_saver'] 9 | LRSchedule.schedule=[[2, 1.0],[10, 0.1]] 10 | LRSchedule.base_lr=0.1 11 | 12 | # Training loop 13 | training_loop.n_epochs=10 14 | training_loop.reload=False 15 | 16 | # Dataset 17 | get_dataset.dataset='cifar' 18 | get_dataset.n_examples=-1 19 | get_dataset.data_seed=777 20 | cifar.which=10 21 | cifar.preprocessing='center' -------------------------------------------------------------------------------- /pytorch_project_template/experiments/tune_lr/template_config.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule', 'meta_saver'] 9 | LRSchedule.schedule=[[1.0,2],[0.01,10]] 10 | LRSchedule.base_lr=$learning_rate$ 11 | 12 | # Training loop 13 | training_loop.n_epochs=2 14 | 15 | # Dataset 16 | get_dataset.dataset='cifar' 17 | get_dataset.n_examples=1000 18 | get_dataset.data_seed=777 19 | cifar.which=10 20 | cifar.preprocessing='center' -------------------------------------------------------------------------------- /pytorch_lightning_project_template/configs/cnn_full.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=128 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule', 'meta_saver'] 9 | LRSchedule.schedule=[[2, 1.0],[10, 0.1]] 10 | LRSchedule.base_lr=0.1 11 | 12 | # Training loop 13 | training_loop.n_epochs=5 14 | training_loop.resume=False 15 | # 
training_loop.use_cpu=True 16 | 17 | # Dataset 18 | get_dataset.dataset='cifar' 19 | get_dataset.seed=777 20 | cifar.variant='10' 21 | cifar.use_valid = True -------------------------------------------------------------------------------- /pytorch_project_template/results/example_run/cnn.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.callbacks=['lr_schedule', 'meta_saver'] 9 | LRSchedule.schedule=[[1.0,2],[0.01,10]] 10 | LRSchedule.base_lr=0.01 11 | 12 | # Training loop 13 | training_loop.n_epochs=2 14 | training_loop.reload=False 15 | 16 | # Dataset 17 | get_dataset.dataset='cifar' 18 | get_dataset.n_examples=1000 19 | get_dataset.data_seed=777 20 | cifar.which=10 21 | cifar.preprocessing='center' -------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/large/batch.sh: -------------------------------------------------------------------------------- 1 | python3 bin/train_supervised.py /Users/kudkudak/Dropbox/Projekty/toolkit/pytorch_lightning_project_template/results/tune_lr/large/0 experiments/tune_lr/large/configs/0.gin 2 | python3 bin/train_supervised.py /Users/kudkudak/Dropbox/Projekty/toolkit/pytorch_lightning_project_template/results/tune_lr/large/1 experiments/tune_lr/large/configs/1.gin 3 | python3 bin/train_supervised.py /Users/kudkudak/Dropbox/Projekty/toolkit/pytorch_lightning_project_template/results/tune_lr/large/2 experiments/tune_lr/large/configs/2.gin 4 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Callbacks available in the project 4 | """ 5 | import logging 6 | from src.callbacks.base import LRSchedule 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | # Add your callbacks here 11 | _ALIASES = { 12 | "lr_schedule": LRSchedule 13 | } 14 | 15 | def get_callback(clb_name, verbose=1, **kwargs): 16 | if clb_name in _ALIASES: 17 | return _ALIASES[clb_name](**kwargs) 18 | else: 19 | if verbose: 20 | logger.warning("Couldn't find {} callback. 
Skipping.".format(clb_name)) 21 | return None 22 | -------------------------------------------------------------------------------- /tf2_project_template/experiments/tune_lr/template_config.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.n_filters=30 4 | SimpleCNN.n_dense=10 5 | 6 | # Train configuration 7 | train.batch_size=128 8 | train.steps_per_epoch=1 9 | train.callbacks=['lr_schedule', 'meta_saver'] 10 | LRSchedule.schedule=[[1.0,2],[0.01,10]] 11 | LRSchedule.base_lr=$learning_rate$ 12 | 13 | # Training loop 14 | training_loop.n_epochs=2 15 | training_loop.evaluation_freq=1 16 | training_loop.save_freq=1 17 | training_loop.reload=False 18 | 19 | # Dataset 20 | train.datasets=['cifar'] 21 | cifar.stream_seed=1 22 | cifar.n_examples=-1 23 | cifar.use_valid=False 24 | train.data_seed=777 25 | -------------------------------------------------------------------------------- /keras_project_template/src/configs/simple_CNN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Configs used in the project 4 | """ 5 | 6 | from src.vegab import ConfigRegistry 7 | 8 | simple_CNN_configs = ConfigRegistry() 9 | 10 | simple_CNN_configs.set_root_config({ 11 | "n_layers": 1, 12 | "batch_size": 128, 13 | "augmented": True, 14 | "n_epochs": 2, 15 | "lr_schedule": [[10, 0.1], [20, 0.01]], 16 | "dim_dense": 100, 17 | "n_filters": 100 18 | }) 19 | 20 | c = simple_CNN_configs['root'] 21 | c['dataset'] = 'cifar10' 22 | simple_CNN_configs['cifar10'] = c 23 | 24 | 25 | c = simple_CNN_configs['root'] 26 | c['dataset'] = 'cifar100' 27 | simple_CNN_configs['cifar100'] = c -------------------------------------------------------------------------------- /pytorch_project_template/src/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Callback module (inspired by Keras). 4 | """ 5 | from .callbacks import * 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | _ALIASES = { 10 | "lr_schedule": LRSchedule, 11 | "meta_saver": MetaSaver 12 | } 13 | 14 | def get_callback(clb_name, verbose=1, **kwargs): 15 | if clb_name in callbacks.__dict__: 16 | return callbacks.__dict__[clb_name](**kwargs) 17 | elif clb_name in _ALIASES: 18 | return _ALIASES[clb_name](**kwargs) 19 | else: 20 | if verbose: 21 | logger.warning("Couldn't find {} callback. 
Skipping.".format(clb_name)) 22 | return None 23 | -------------------------------------------------------------------------------- /pytorch_project_template/results/example_run/config.txt: -------------------------------------------------------------------------------- 1 | {('', '__main__.train'): {'lr': 0.1, 'batch_size': 128, 'callbacks': ['lr_schedule', 'meta_saver'], 'model': 'SimpleCNN'}, ('', 'src.data.get_dataset'): {'dataset': 'cifar', 'n_examples': 1000, 'data_seed': 777}, ('', 'src.data.datasets.cifar'): {'which': 10, 'preprocessing': 'center', 'use_valid': True}, ('', 'src.models.simple_cnn.SimpleCNN'): {'n_filters': 30, 'n_dense': 10}, ('', 'src.callbacks.callbacks.LRSchedule'): {'schedule': [[1.0, 2], [0.01, 10]], 'base_lr': 0.01}, ('', 'src.callbacks.callbacks.MetaSaver'): {}, ('', 'src.training_loop.training_loop'): {'checkpoint_monitor': 'val_acc', 'reload': False, 'n_epochs': 2, 'save_freq': 1, 'save_history_every_k_examples': -1}} -------------------------------------------------------------------------------- /tf2_project_template/src/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Callback module (inspired by Keras). 4 | """ 5 | 6 | from .base import * 7 | 8 | import logging 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | _ALIASES = { 13 | # Base 14 | "batch_lr_schedule": BatchLRSchedule, 15 | "neptune_monitor": NeptuneMonitor, 16 | "lr_schedule": LRSchedule, 17 | "meta_saver": MetaSaver, 18 | "save_weights": SaveWeights, 19 | "weight_norm": WeightNorm, 20 | } 21 | 22 | def get_callback(clb_name, verbose=1, **kwargs): 23 | if clb_name in _ALIASES: 24 | return _ALIASES[clb_name](**kwargs) 25 | else: 26 | if verbose: 27 | logger.warning("Couldn't find {} callback. Skipping.".format(clb_name)) 28 | return None 29 | -------------------------------------------------------------------------------- /pytorch_project_template/e.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export PNAME="pytorch_project_template" 3 | export ROOT="$( cd "$(dirname "$0")" ; pwd -P )" 4 | echo "Welcome to $PNAME rooted at $ROOT" 5 | echo "-" 6 | 7 | # Activates conda environment 8 | source activate ${PNAME} 9 | 10 | # Configures paths. Adapt to your needs! 11 | export PYTHONPATH=$ROOT:$PYTHONPATH 12 | export DATA_DIR=$ROOT/data 13 | export RESULTS_DIR=$ROOT/results 14 | 15 | # Optional: enables plotting in iterm 16 | # export MPLBACKEND="module://itermplot" 17 | 18 | # Switches off importing out of environment packages 19 | export PYTHONNOUSERSITE=1 20 | 21 | if [ ! -d "${DATA_DIR}" ]; then 22 | echo "Creating ${DATA_DIR}" 23 | mkdir -p ${DATA_DIR} 24 | fi 25 | 26 | if [ ! -d "${RESULTS_DIR}" ]; then 27 | echo "Creating ${RESULTS_DIR}" 28 | mkdir -p ${RESULTS_DIR} 29 | fi 30 | -------------------------------------------------------------------------------- /tf2_project_template/configs/scnn.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.bn=False # Stateless 4 | SimpleCNN.kernel_size=5 5 | SimpleCNN.n_filters=32 6 | 7 | # Train configuration 8 | train.batch_size=128 9 | train.callbacks=['lr_schedule'] 10 | train.momentum=0.0 11 | train.wd=0.0 12 | train.seed=777 13 | train.steps_per_epoch=1 # For speed purposes. Change to -1. 
14 | train.data_seed=777 15 | LRSchedule.base_lr=0.03 16 | LRSchedule.schedule=[[150, 1.0],[225, 0.1],[30000, 0.01]] # A pretty standard schedule, for some reason 17 | 18 | # Training loop 19 | training_loop.n_epochs=2 20 | training_loop.evaluation_freq=1 21 | training_loop.save_freq=1 22 | training_loop.reload=False 23 | 24 | # Dataset 25 | train.datasets=['cifar'] 26 | cifar.stream_seed=1 27 | cifar.n_examples=-1 28 | cifar.one_hot=True 29 | cifar.use_valid=False 30 | train.data_seed=777 31 | -------------------------------------------------------------------------------- /tf2_project_template/configs/scnn_neptune.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | train.model='SimpleCNN' 3 | SimpleCNN.bn=False # Stateless 4 | SimpleCNN.kernel_size=5 5 | SimpleCNN.n_filters=32 6 | 7 | # Train configuration 8 | train.batch_size=128 9 | train.callbacks=['lr_schedule', 'neptune_monitor'] 10 | train.momentum=0.0 11 | train.wd=0.0 12 | train.seed=777 13 | train.steps_per_epoch=1 # For speed purposes. Change to -1. 14 | train.data_seed=777 15 | LRSchedule.base_lr=0.03 16 | LRSchedule.schedule=[[150, 1.0],[225, 0.1],[30000, 0.01]] # A pretty standard schedule, for some reason 17 | 18 | # Training loop 19 | training_loop.n_epochs=2 20 | training_loop.evaluation_freq=1 21 | training_loop.save_freq=1 22 | training_loop.reload=False 23 | 24 | # Dataset 25 | train.datasets=['cifar'] 26 | cifar.stream_seed=1 27 | cifar.n_examples=-1 28 | cifar.one_hot=True 29 | cifar.use_valid=False 30 | train.data_seed=777 31 | -------------------------------------------------------------------------------- /pytorch_project_template/src/models/lenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | LeNet model 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import gin 10 | 11 | @gin.configurable 12 | class LeNet(nn.Module): 13 | def __init__(self): 14 | super(LeNet, self).__init__() 15 | self.conv1 = nn.Conv2d(3, 6, 5) 16 | self.conv2 = nn.Conv2d(6, 16, 5) 17 | self.fc1 = nn.Linear(16*5*5, 120) 18 | self.fc2 = nn.Linear(120, 84) 19 | self.fc3 = nn.Linear(84, 10) 20 | 21 | def forward(self, x): 22 | out = F.relu(self.conv1(x)) 23 | out = F.max_pool2d(out, 2) 24 | out = F.relu(self.conv2(out)) 25 | out = F.max_pool2d(out, 2) 26 | out = out.view(out.size(0), -1) 27 | out = F.relu(self.fc1(out)) 28 | out = F.relu(self.fc2(out)) 29 | out = self.fc3(out) 30 | return F.log_softmax(out, dim=1) -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/models/simple_cnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | SimpleCNN model 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import gin 11 | 12 | 13 | @gin.configurable 14 | class SimpleCNN(nn.Module): 15 | def __init__(self, n_filters, n_dense): 16 | super(SimpleCNN, self).__init__() 17 | self.conv1 = nn.Conv2d(3, int(n_filters), kernel_size=5) 18 | self.conv2 = nn.Conv2d(int(n_filters), int(n_filters), kernel_size=5) 19 | self.fc1 = nn.Linear(int(n_filters) * 25, int(n_dense)) # Oh well.. 
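# (For the 32x32 CIFAR inputs used here, two 5x5 convs and two 2x2 max-pools leave a 5x5 feature map, hence n_filters * 5 * 5 = n_filters * 25.)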
20 | self.fc2 = nn.Linear(int(n_dense), 10) 21 | 22 | def forward(self, x): 23 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 24 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 25 | x = x.view(x.size(0), -1) 26 | x = F.relu(self.fc1(x)) 27 | x = self.fc2(x) 28 | return x 29 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/models/lenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | LeNet model 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import gin 10 | 11 | @gin.configurable 12 | class LeNet(nn.Module): 13 | def __init__(self): 14 | super(LeNet, self).__init__() 15 | self.conv1 = nn.Conv2d(3, 6, 5) 16 | self.conv2 = nn.Conv2d(6, 16, 5) 17 | self.fc1 = nn.Linear(16*5*5, 120) 18 | self.fc2 = nn.Linear(120, 84) 19 | self.fc3 = nn.Linear(84, 10) 20 | 21 | def forward(self, x): 22 | out = F.relu(self.conv1(x)) 23 | out = F.max_pool2d(out, 2) 24 | out = F.relu(self.conv2(out)) 25 | out = F.max_pool2d(out, 2) 26 | out = out.view(out.size(0), -1) 27 | out = F.relu(self.fc1(out)) 28 | out = F.relu(self.fc2(out)) 29 | out = self.fc3(out) 30 | return F.log_softmax(out, dim=1) -------------------------------------------------------------------------------- /pytorch_project_template/src/models/simple_cnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | SimpleCNN model 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import gin 11 | 12 | 13 | @gin.configurable 14 | class SimpleCNN(nn.Module): 15 | def __init__(self, n_filters, n_dense): 16 | super(SimpleCNN, self).__init__() 17 | self.conv1 = nn.Conv2d(3, int(n_filters), kernel_size=5) 18 | self.conv2 = nn.Conv2d(int(n_filters), int(n_filters), kernel_size=5) 19 | self.fc1 = nn.Linear(int(n_filters) * 25, int(n_dense)) # Oh well.. 20 | self.fc2 = nn.Linear(int(n_dense), 10) 21 | 22 | def forward(self, x): 23 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 24 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 25 | x = x.view(x.size(0), -1) 26 | x = F.relu(self.fc1(x)) 27 | x = self.fc2(x) 28 | return F.log_softmax(x, dim=1) 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 gmum.net 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
8 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/e.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export PNAME="pytorch_project_template" 3 | export ROOT="$( cd "$(dirname "$0")" ; pwd -P )" 4 | echo "Welcome to $PNAME rooted at $ROOT" 5 | echo "-" 6 | 7 | # Activates conda environment 8 | source activate ${PNAME} 9 | 10 | # Configures paths. Adapt to your needs! 11 | export PYTHONPATH=$ROOT:$PYTHONPATH 12 | export DATA_DIR=$ROOT/data 13 | export RESULTS_DIR=$ROOT/results 14 | 15 | # Optional: enables plotting in iterm 16 | # export MPLBACKEND="module://itermplot" 17 | 18 | # Switches off importing out of environment packages 19 | export PYTHONNOUSERSITE=1 20 | 21 | # Optional: enables integration with Neptune 22 | export NEPTUNE_TOKEN= 23 | export NEPTUNE_PROJECT= 24 | export NEPTUNE_USER= 25 | # export NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE=1 # Uncomment if you have issues with SSL 26 | 27 | if [ ! -d "${DATA_DIR}" ]; then 28 | echo "Creating ${DATA_DIR}" 29 | mkdir -p ${DATA_DIR} 30 | fi 31 | 32 | if [ ! -d "${RESULTS_DIR}" ]; then 33 | echo "Creating ${RESULTS_DIR}" 34 | mkdir -p ${RESULTS_DIR} 35 | fi 36 | -------------------------------------------------------------------------------- /tf2_project_template/experiments/tune_lr/README.md: -------------------------------------------------------------------------------- 1 | # Tune learning rate experiment 2 | 3 | Experiment conceptually is a list of shell jobs. For convenience this can be wrapped using a python script that prepares jobs, analyses the runs, stores configs, etc. 4 | 5 | We ship an example experiment, where we tune LR for the small CNN on Cifar10. Here is the typical workflow: 6 | 7 | 1. Prepare experiments: `python experiments/tune_lr/main.py prepare` 8 | 9 | 2. See prepare configs: `ls experiments/tune_lr/large/configs` 10 | 11 | 3. Run experiments: `bash experiments/tune_lr/large/batch.sh` 12 | 13 | 4. See runs: `ls $RESULTS_DIR/tune_lr/large` 14 | 15 | 5. Process experiment results: `python experiments/tune_lr/main.py report`. Bonus for OSX users: To enable plotting in iterm install ``pip install itermplot``, and uncomment the appropriate line in ``e.sh```. 16 | 17 | 6. Take a look at the main.py source code to understand better the logic. 18 | 19 | Note that running a list of shell jobs can be done using a scheduler. This is best if you develop your own 20 | solution for runnning efficiently such a list. -------------------------------------------------------------------------------- /pytorch_project_template/experiments/tune_lr/README.md: -------------------------------------------------------------------------------- 1 | # Tune learning rate experiment 2 | 3 | Experiment conceptually is a list of shell jobs. For convenience this can be wrapped using a python script that prepares jobs, analyses the runs, stores configs, etc. 4 | 5 | We ship an example experiment, where we tune LR for the small CNN on Cifar10. Here is the typical workflow: 6 | 7 | 1. Prepare experiments: `python experiments/tune_lr/main.py prepare` 8 | 9 | 2. See prepare configs: `ls experiments/tune_lr/large/configs` 10 | 11 | 3. Run experiments: `bash experiments/tune_lr/large/batch.sh` 12 | 13 | 4. See runs: `ls $RESULTS_DIR/tune_lr/large` 14 | 15 | 5. Process experiment results: `python experiments/tune_lr/main.py report`. 
Bonus for OSX users: To enable plotting in iterm install ``pip install itermplot``, and uncomment the appropriate line in ``e.sh```. 16 | 17 | 6. Take a look at the main.py source code to understand better the logic. 18 | 19 | Note that running a list of shell jobs can be done using a scheduler. This is best if you develop your own 20 | solution for runnning efficiently such a list. -------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/README.md: -------------------------------------------------------------------------------- 1 | # Tune learning rate experiment 2 | 3 | Experiment conceptually is a list of shell jobs. For convenience this can be wrapped using a python script that prepares jobs, analyses the runs, stores configs, etc. 4 | 5 | We ship an example experiment, where we tune LR for the small CNN on Cifar10. Here is the typical workflow: 6 | 7 | 1. Prepare experiments: `python experiments/tune_lr/main.py prepare` 8 | 9 | 2. See prepare configs: `ls experiments/tune_lr/large/configs` 10 | 11 | 3. Run experiments: `bash experiments/tune_lr/large/batch.sh` 12 | 13 | 4. See runs: `ls $RESULTS_DIR/tune_lr/large` 14 | 15 | 5. Process experiment results: `python experiments/tune_lr/main.py report`. Bonus for OSX users: To enable plotting in iterm install ``pip install itermplot``, and uncomment the appropriate line in ``e.sh```. 16 | 17 | 6. Take a look at the main.py source code to understand better the logic. 18 | 19 | Note that running a list of shell jobs can be done using a scheduler. This is best if you develop your own 20 | solution for runnning efficiently such a list. -------------------------------------------------------------------------------- /tf2_project_template/src/data/streams.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Streams used in the project (e.g. augmentation) 4 | """ 5 | import numpy as np 6 | 7 | class DatasetGenerator(object): 8 | def __init__(self, dataset, seed, batch_size, shuffle=True): 9 | self.dataset = dataset 10 | self.seed = seed 11 | self.batch_size = batch_size 12 | self.shuffle = shuffle 13 | self.rng = np.random.RandomState(seed) 14 | 15 | def __iter__(self): 16 | if self.shuffle: 17 | ids = self.rng.choice(len(self.dataset[0]), len(self.dataset[0]), replace=False) 18 | else: 19 | ids = np.arange(len(self.dataset[0])) 20 | self.dataset = [self.dataset[0][ids], self.dataset[1][ids]] 21 | def _iter(): 22 | for id in range((len(self.dataset[0]) + self.batch_size - 1) // self.batch_size): 23 | yield self.dataset[0][id * self.batch_size:(id + 1) * self.batch_size], \ 24 | self.dataset[1][id * self.batch_size:(id + 1) * self.batch_size] 25 | return _iter() 26 | 27 | def __len__(self): 28 | return len(self.dataset[0]) -------------------------------------------------------------------------------- /pytorch_project_template/src/data/streams.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Streams used in the project (e.g. 
augmentation) 4 | """ 5 | import numpy as np 6 | 7 | class DatasetGenerator(object): 8 | def __init__(self, dataset, seed, batch_size, shuffle=True): 9 | self.dataset = dataset 10 | self.seed = seed 11 | self.batch_size = batch_size 12 | self.shuffle = shuffle 13 | self.rng = np.random.RandomState(seed) 14 | 15 | def __iter__(self): 16 | if self.shuffle: 17 | ids = self.rng.choice(len(self.dataset[0]), len(self.dataset[0]), replace=False) 18 | else: 19 | ids = np.arange(len(self.dataset[0])) 20 | self.dataset = [self.dataset[0][ids], self.dataset[1][ids]] 21 | def _iter(): 22 | for id in range((len(self.dataset[0]) + self.batch_size - 1) // self.batch_size): 23 | yield self.dataset[0][id * self.batch_size:(id + 1) * self.batch_size], \ 24 | self.dataset[1][id * self.batch_size:(id + 1) * self.batch_size] 25 | return _iter() 26 | 27 | def __len__(self): 28 | return len(self.dataset[0]) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Toolkit 2 | 3 | Even using frameworks such as Keras or Pytorch Lightning leaves you with making many choices on how to organize your machine learning project. 4 | The main ambition of this repository is to make easy following the state-of-the-art good practices for a machine learning project, 5 | for the most popular packages/frameworks. 6 | 7 | Maintained and distributed by GMUM (https://gmum.net/). 8 | 9 | The following templates are provided: 10 | 11 | * **pytorch_lightning_project_template** - template for a ML project based on PyTorch and Pytorch Lightning. Advocates using modular code, gin configuration, and neptune. 12 | Last updated 10.2020. 13 | 14 | * **tf2_project_template** - template for a ML project based on vanilla TF2. Last updated 6.2020. Includes generic improvements compared 15 | to the PyTorch template such as automatic syncing of files, using watchman or integration with Neptune, or standalone training loop. 16 | 17 | * **pytorch_project_template** - template for a ML project based on vanilla pytorch. Last updated 12.2019. 18 | 19 | * **keras_project_template** - template for a ML project based on keras. Last updated 04.2018. 
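All of the newer templates share the same gin-based configuration pattern: functions are decorated with `@gin.configurable` and a `.gin` file (optionally plus command-line bindings) supplies their arguments. A minimal illustrative sketch of that pattern (not taken verbatim from any template) looks like this:

```python
import gin

@gin.configurable
def train(model='SimpleCNN', batch_size=128, n_epochs=2):
    # In the templates the model name is resolved dynamically from src.models
    # and everything is handed over to a reusable training loop.
    print(f"Training {model} for {n_epochs} epochs with batch_size={batch_size}")

if __name__ == "__main__":
    # configs/cnn.gin contains bindings such as: train.batch_size=128
    gin.parse_config_file("configs/cnn.gin")
    train()
```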
20 | -------------------------------------------------------------------------------- /tf2_project_template/experiments/tune_lr/large/run.sh: -------------------------------------------------------------------------------- 1 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/0 experiments/tune_lr/large/configs/0.gin -b training_loop.reload=True 2 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/1 experiments/tune_lr/large/configs/1.gin -b training_loop.reload=True 3 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/2 experiments/tune_lr/large/configs/2.gin -b training_loop.reload=True 4 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/3 experiments/tune_lr/large/configs/3.gin -b training_loop.reload=True 5 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/4 experiments/tune_lr/large/configs/4.gin -b training_loop.reload=True 6 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/5 experiments/tune_lr/large/configs/5.gin -b training_loop.reload=True 7 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/large/run.sh: -------------------------------------------------------------------------------- 1 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/0 experiments/tune_lr/large/configs/0.gin -b training_loop.reload=True 2 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/1 experiments/tune_lr/large/configs/1.gin -b training_loop.reload=True 3 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/2 experiments/tune_lr/large/configs/2.gin -b training_loop.reload=True 4 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/3 experiments/tune_lr/large/configs/3.gin -b training_loop.reload=True 5 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/4 experiments/tune_lr/large/configs/4.gin -b training_loop.reload=True 6 | python3 bin/train.py /Users/jastrs01/Dropbox/Projekty/toolkit/vanilla_pytorch_project_template/results/tune_lr/large/5 experiments/tune_lr/large/configs/5.gin -b training_loop.reload=True 7 | -------------------------------------------------------------------------------- /tf2_project_template/env.sh: -------------------------------------------------------------------------------- 1 | export ENVCALLED=1 2 | export PNAME=tf2_project_template 3 | 4 | export ROOT="$( cd "$(dirname "$0")" ; pwd -P )" 5 | echo "Welcome to $PNAME rooted at $ROOT" 6 | echo "-" 7 | 8 | # Activates conda environment 9 | source activate $PNAME 10 | 11 | # Configures paths. Adapt to your needs! 
12 | export PYTHONPATH=$ROOT:$PYTHONPATH 13 | export DATA_DIR=$ROOT/data 14 | export RESULTS_DIR=$ROOT/results 15 | 16 | # Optional: enables plotting in iterm 17 | export MPLBACKEND="module://itermplot" 18 | 19 | # Optional: enables copying figures to Dropbox (useful for managing figures for paper) 20 | export DROPBOXTOKEN=FILLME 21 | 22 | # Optional: enables integration with Neptune 23 | export NEPTUNE_TOKEN=FILLME 24 | export NEPTUNE_PROJECT=tf2projecttemplate 25 | export NEPTUNE_USER=FILLME 26 | # export NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE=1 # Uncomment if you have issues with SSL 27 | 28 | # Switches off importing out of environment packages 29 | export PYTHONNOUSERSITE=1 30 | 31 | if [ ! -d "${DATA_DIR}" ]; then 32 | echo "Creating ${DATA_DIR}" 33 | mkdir -p ${DATA_DIR} 34 | fi 35 | 36 | if [ ! -d "${RESULTS_DIR}" ]; then 37 | echo "Creating ${RESULTS_DIR}" 38 | mkdir -p ${RESULTS_DIR} 39 | fi -------------------------------------------------------------------------------- /pytorch_project_template/src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Relevant constants and configurations 3 | """ 4 | import os 5 | import matplotlib.style 6 | import matplotlib as mpl 7 | from .utils import configure_logger 8 | os.environ['KERAS_BACKEND'] = 'tensorflow' 9 | 10 | # Configure paths 11 | DATA_DIR = os.environ.get("DATA_DIR", os.path.join(os.path.dirname(__file__), "data")) 12 | RESULTS_DIR = os.environ.get("RESULTS_DIR", os.path.join(os.path.dirname(__file__), "results")) 13 | 14 | # Configure logger 15 | configure_logger('') 16 | 17 | # Some useful plotting styles 18 | mpl.style.use('seaborn-colorblind') 19 | mpl.rcParams.update({'font.size': 14, 'lines.linewidth': 2, 'figure.figsize': (6, 6/1.61)}) 20 | mpl.rcParams['grid.color'] = 'k' 21 | mpl.rcParams['grid.linestyle'] = ':' 22 | mpl.rcParams['grid.linewidth'] = 0.5 23 | mpl.rcParams['lines.markersize'] = 6 24 | mpl.rcParams['lines.marker'] = None 25 | mpl.rcParams['axes.grid'] = True 26 | DEFAULT_FONTSIZE = 13 27 | mpl.rcParams.update({'font.size': DEFAULT_FONTSIZE, 'lines.linewidth': 2, 28 | 'legend.fontsize': DEFAULT_FONTSIZE, 'axes.labelsize': DEFAULT_FONTSIZE, 29 | 'xtick.labelsize': DEFAULT_FONTSIZE, 'ytick.labelsize': DEFAULT_FONTSIZE, 30 | 'figure.figsize': (7, 7.0/1.4)}) 31 | -------------------------------------------------------------------------------- /tf2_project_template/bin/utils/update_plots.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | A simple self-contained script template (you need to update paths) to update and remove unused figures in a tex file 5 | """ 6 | import glob 7 | import os 8 | import tqdm 9 | 10 | # Configure this 11 | PATH_TO_TEX = "papers/entanglement/main.tex" 12 | SRC1 = "experiments/04_2020_sweeps" 13 | SRC2 = "experiments/05_2020_understand_grads" 14 | SRCS = [SRC1, SRC2] 15 | REMOVE_UNUSED_FIGURES = True 16 | PAPER_SRCS = ["papers/entanglement/figs/*pdf"] 17 | 18 | # Go through all sources and update if there is a newer pdf 19 | TEX = open(PATH_TO_TEX).read() 20 | for SRC in SRCS: 21 | for SOURCE in PAPER_SRCS: 22 | files = list(glob.glob(SOURCE)) 23 | for f in tqdm.tqdm(files, total=len(files)): 24 | if os.path.basename(f) not in TEX: 25 | print("WARNING! 
{} not found in main.tex".format(f)) 26 | if REMOVE_UNUSED_FIGURES: 27 | os.system("rm " + f) 28 | else: 29 | print("OK") 30 | a = os.path.join(SRC, os.path.basename(f)) 31 | 32 | if "@" in SRC1: 33 | os.system("scp {} {}".format(a, f)) 34 | else: 35 | os.system("cp {} {}".format(a, f)) 36 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/bin/utils/update_plots.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | A simple self-contained script template (you need to update paths) to update and remove unused figures in a tex file 5 | """ 6 | import glob 7 | import os 8 | import tqdm 9 | 10 | # Configure this 11 | PATH_TO_TEX = "papers/entanglement/main.tex" 12 | SRC1 = "experiments/04_2020_sweeps" 13 | SRC2 = "experiments/05_2020_understand_grads" 14 | SRCS = [SRC1, SRC2] 15 | REMOVE_UNUSED_FIGURES = True 16 | PAPER_SRCS = ["papers/entanglement/figs/*pdf"] 17 | 18 | # Go through all sources and update if there is a newer pdf 19 | TEX = open(PATH_TO_TEX).read() 20 | for SRC in SRCS: 21 | for SOURCE in PAPER_SRCS: 22 | files = list(glob.glob(SOURCE)) 23 | for f in tqdm.tqdm(files, total=len(files)): 24 | if os.path.basename(f) not in TEX: 25 | print("WARNING! {} not found in main.tex".format(f)) 26 | if REMOVE_UNUSED_FIGURES: 27 | os.system("rm " + f) 28 | else: 29 | print("OK") 30 | a = os.path.join(SRC, os.path.basename(f)) 31 | 32 | if "@" in SRC1: 33 | os.system("scp {} {}".format(a, f)) 34 | else: 35 | os.system("cp {} {}".format(a, f)) 36 | -------------------------------------------------------------------------------- /keras_project_template/src/scripts/train_simple_CNN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Trains simple CNN on cifar10/cifar100 5 | 6 | Run like: python src/scripts/train_simple_CNN.py cifar10 results/test_run 7 | """ 8 | 9 | from keras.optimizers import SGD 10 | 11 | from src.configs.simple_CNN import simple_CNN_configs 12 | from src.data import datasets 13 | from src.models import build_simple_model 14 | from src.training_loop import cifar_training_loop 15 | from src.vegab import main, MetaSaver, AutomaticNamer 16 | 17 | def train(config, save_path): 18 | # Load data 19 | train, test, _ = datasets(dataset=config['dataset'], batch_size=config['batch_size'], 20 | augmented=config['augmented'], preprocessing='center') 21 | 22 | # Load model 23 | 24 | model = build_simple_model(config) 25 | optimizer = SGD(lr=config['lr_schedule'][0][0], momentum=0.9) 26 | model.compile(optimizer=optimizer, 27 | loss='categorical_crossentropy', 28 | metrics=['accuracy']) 29 | 30 | # Call training loop (warning: using test as valid. 
Please don't do this) 31 | cifar_training_loop(model=model, train=train, valid=test, learning_rate_schedule=config['lr_schedule'], 32 | save_path=save_path, n_epochs=config['n_epochs']) 33 | 34 | 35 | if __name__ == "__main__": 36 | main(simple_CNN_configs, train, 37 | plugins=[MetaSaver()]) 38 | -------------------------------------------------------------------------------- /keras_project_template/src/models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Simple model definitions 5 | """ 6 | 7 | from keras.datasets import cifar10 8 | from keras.preprocessing.image import ImageDataGenerator 9 | from keras.models import Sequential 10 | from keras.layers import Dense, Dropout, Activation, Flatten 11 | from keras.layers import Conv2D, MaxPooling2D 12 | 13 | def build_simple_model(config): 14 | model = Sequential() 15 | 16 | model.add(Conv2D(config['n_filters'], (3, 3), padding='same', data_format='channels_first', 17 | input_shape=(3, 32, 32))) 18 | model.add(Activation('relu')) 19 | model.add(Conv2D(32, (3, 3), data_format='channels_first')) 20 | model.add(Activation('relu')) 21 | model.add(MaxPooling2D(pool_size=(2, 2), data_format='channels_first')) 22 | model.add(Dropout(0.25)) 23 | 24 | model.add(Conv2D(64, (3, 3), padding='same', data_format='channels_first')) 25 | model.add(Activation('relu')) 26 | model.add(Conv2D(64, (3, 3), data_format='channels_first')) 27 | model.add(Activation('relu')) 28 | model.add(MaxPooling2D(pool_size=(2, 2), data_format='channels_first')) 29 | model.add(Dropout(0.25)) 30 | 31 | model.add(Flatten()) 32 | model.add(Dense(config['dim_dense'])) 33 | model.add(Activation('relu')) 34 | model.add(Dropout(0.5)) 35 | model.add(Dense(10)) 36 | model.add(Activation('softmax')) 37 | 38 | return model -------------------------------------------------------------------------------- /pytorch_project_template/bin/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Trainer script. Example run command: bin/train.py save_to_folder configs/cnn.gin. 
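Gin bindings can additionally be overridden from the command line with -b (the same pattern experiments/tune_lr/main.py generates), e.g.: bin/train.py save_to_folder configs/cnn.gin -b training_loop.reload=True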
5 | """ 6 | 7 | import gin 8 | from gin.config import _CONFIG 9 | import torch 10 | import logging 11 | logger = logging.getLogger(__name__) 12 | 13 | from src.data import get_dataset 14 | from src import models 15 | from src.training_loop import training_loop 16 | from src.callbacks import get_callback 17 | from src.utils import summary, acc, gin_wrap 18 | 19 | @gin.configurable 20 | def train(save_path, model, lr=0.1, batch_size=128, callbacks=[]): 21 | # Create dynamically dataset generators 22 | train, valid, test, meta_data = get_dataset(batch_size=batch_size) 23 | 24 | # Create dynamically model 25 | model = models.__dict__[model]() 26 | summary(model) 27 | loss_function = torch.nn.CrossEntropyLoss() 28 | optimizer = torch.optim.SGD(model.parameters(), lr=lr) 29 | 30 | # Create dynamically callbacks 31 | callbacks_constructed = [] 32 | for name in callbacks: 33 | clbk = get_callback(name, verbose=0) 34 | if clbk is not None: 35 | callbacks_constructed.append(clbk) 36 | 37 | # Pass everything to the training loop 38 | training_loop(model=model, optimizer=optimizer, loss_function=loss_function, metrics=[acc], 39 | train=train, valid=test, meta_data=meta_data, save_path=save_path, config=_CONFIG, 40 | use_tb=True, custom_callbacks=callbacks_constructed) 41 | 42 | 43 | if __name__ == "__main__": 44 | gin_wrap(train) 45 | -------------------------------------------------------------------------------- /keras_project_template/.gitignore: -------------------------------------------------------------------------------- 1 | .env.sh 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # Environments 84 | .env 85 | .venv 86 | env/ 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Relevant constants and configurations 3 | """ 4 | import os 5 | import matplotlib.style 6 | from .utils import configure_logger 7 | 8 | # Data 9 | DATA_FORMAT = os.environ.get("DATA_FORMAT", "channels_first") 10 | DATA_NUM_WORKERS = int(os.environ.get("DATA_NUM_WORKERS", "2")) 11 | 12 | # Configure paths 13 | DATA_DIR = os.environ.get("DATA_DIR", os.path.join(os.path.dirname(__file__), "data")) 14 | RESULTS_DIR = os.environ.get("RESULTS_DIR", os.path.join(os.path.dirname(__file__), "results")) 15 | 16 | # Configure logger 17 | configure_logger('') 18 | 19 | # Opinionated default plotting styles 20 | import matplotlib as mpl 21 | DEFAULT_FIGSIZE = 8 22 | MARKERS="oxP.X" # TODO: Update 23 | DEFAULT_LINEWIDTH = 3 24 | DEFAULT_FONTSIZE = 22 25 | mpl.style.use('seaborn-colorblind') 26 | mpl.rcParams['figure.facecolor'] = 'w' 27 | mpl.rcParams.update({'font.size': 14, 'lines.linewidth': 4, 'figure.figsize': (DEFAULT_FIGSIZE, DEFAULT_FIGSIZE / 1.61)}) 28 | mpl.rcParams['grid.color'] = 'k' 29 | mpl.rcParams['grid.linestyle'] = ':' 30 | mpl.rcParams['errorbar.capsize'] = 2 31 | mpl.rcParams['image.cmap'] = 'cividis' 32 | mpl.rcParams['grid.linewidth'] = 0.5 33 | mpl.rcParams['lines.markersize'] = 6 34 | mpl.rcParams['lines.marker'] = None 35 | mpl.rcParams['axes.grid'] = True 36 | COLORS = mpl.rcParams["axes.prop_cycle"].by_key()["color"] 37 | mpl.rcParams.update({'font.size': DEFAULT_FONTSIZE, 'lines.linewidth': DEFAULT_LINEWIDTH, 38 | 'legend.fontsize': DEFAULT_FONTSIZE, 'axes.labelsize': DEFAULT_FONTSIZE, 39 | 'xtick.labelsize': DEFAULT_FONTSIZE, 'ytick.labelsize': DEFAULT_FONTSIZE, 40 | 'figure.figsize': (7, 7.0 / 1.4)}) 41 | 42 | # Configure Neptun (optinal) 43 | NEPTUNE_TOKEN = os.environ["NEPTUNE_TOKEN"] 44 | NEPTUNE_USER = os.environ["NEPTUNE_USER"] 45 | NEPTUNE_PROJECT = os.environ["NEPTUNE_PROJECT"] -------------------------------------------------------------------------------- /pytorch_project_template/results/example_run/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Trains simple CNN on cifar10/cifar100 5 | 6 | Run like: 7 | * python bin/train.py cifar10 results/test_run 8 | * python bin/train.py cifar10 
results/test_run --model.n_filters=20 9 | * python bin/train.py cifar10_lenet results/test_run 10 | """ 11 | 12 | import gin 13 | from gin.config import _OPERATIVE_CONFIG 14 | import torch 15 | import logging 16 | logger = logging.getLogger(__name__) 17 | 18 | from src.data import get_dataset 19 | from src import models 20 | from src.training_loop import training_loop 21 | from src.callbacks import get_callback 22 | from src.utils import summary, acc, gin_wrap 23 | 24 | @gin.configurable 25 | def train(save_path, model, lr=0.1, batch_size=128, callbacks=[]): 26 | # Create dynamically dataset generators 27 | train, valid, test, meta_data = get_dataset(batch_size=batch_size) 28 | 29 | # Create dynamically model 30 | model = models.__dict__[model]() 31 | summary(model) 32 | loss_function = torch.nn.MSELoss() 33 | optimizer = torch.optim.SGD(model.parameters(), lr=lr) 34 | 35 | # Create dynamically callbacks 36 | callbacks_constructed = [] 37 | for name in callbacks: 38 | clbk = get_callback(name, verbose=0) 39 | if clbk is not None: 40 | callbacks_constructed.append(clbk) 41 | 42 | # Pass everything to the training loop 43 | steps_per_epoch = (len(meta_data['x_train']) - 1) // batch_size + 1 44 | training_loop(model=model, optimizer=optimizer, loss_function=loss_function, metrics=[acc], 45 | train=train, valid=test, meta_data=meta_data, steps_per_epoch=steps_per_epoch, 46 | save_path=save_path, config=_OPERATIVE_CONFIG, 47 | use_tb=True, custom_callbacks=callbacks_constructed) 48 | 49 | 50 | if __name__ == "__main__": 51 | gin_wrap(train) 52 | -------------------------------------------------------------------------------- /pytorch_project_template/src/data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Simple data getters. Each returns iterator for train and dataset for test/valid. 4 | """ 5 | import os 6 | import gin 7 | from functools import partial 8 | import logging 9 | import numpy as np 10 | 11 | from .datasets import cifar, mnist 12 | from .streams import DatasetGenerator 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | @gin.configurable 17 | def get_dataset(dataset, n_examples, data_seed, batch_size): 18 | train, valid, test, meta_data = globals()[dataset](seed=data_seed) 19 | 20 | if n_examples > 0: 21 | assert len(train[0]) >= n_examples 22 | train = [train[0][0:n_examples], train[1][0:n_examples]] 23 | 24 | # Configure stream + optionally augmentation 25 | meta_data['x_train'] = train[0] 26 | meta_data['y_train'] = train[1] 27 | 28 | train_stream = DatasetGenerator(train, seed=data_seed, batch_size=batch_size, shuffle=True) 29 | 30 | # Save some extra versions of the dataset. Just a pattern that is useful. 
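# (A second generator instance over the same arrays is drained for roughly one epoch and materialized
# into x_train_aug / y_train_aug, so callbacks and analysis code can inspect shuffled training batches
# without consuming batches from the main train_stream.)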
31 | train_stream_duplicated = DatasetGenerator(train, seed=data_seed, batch_size=batch_size, shuffle=True) 32 | x_train_aug, y_train_aug = [], [] 33 | n = 0 34 | for x, y in train_stream_duplicated: 35 | x_train_aug.append(x) 36 | y_train_aug.append(y) 37 | n += len(x) 38 | if n >= len(meta_data['x_train']): 39 | break 40 | meta_data['x_train_aug'] = np.concatenate(x_train_aug, axis=0)[0:len(meta_data['x_train'])] 41 | meta_data['y_train_aug'] = np.concatenate(y_train_aug, axis=0)[0:len(meta_data['x_train'])] 42 | meta_data['train_stream_duplicated'] = train_stream_duplicated 43 | 44 | # Return 45 | valid = DatasetGenerator(valid, seed=data_seed, batch_size=batch_size, shuffle=False) 46 | test = DatasetGenerator(test, seed=data_seed, batch_size=batch_size, shuffle=False) 47 | return train_stream, valid, test, meta_data 48 | -------------------------------------------------------------------------------- /tf2_project_template/bin/utils/run_on_a_gpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | A simple script to run on a list of jobs on a GPU 5 | 6 | See bin/utils/run_on_a_gpu.py -h 7 | 8 | There are the following requirements for the batch: 9 | * Each command has saving dir as 2nd argument 10 | * Each script saves to the save_dir HEARTBEAT 11 | * Each script saves to the save_dir FINISHED when done 12 | """ 13 | 14 | import os 15 | import time 16 | import numpy as np 17 | from os import path 18 | import pandas as pd 19 | import argh 20 | 21 | from src import configure_logger # Will actually configure already 22 | 23 | import logging 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def get_last_modification(save_path): 29 | f_path = os.path.join(save_path, "HEARTBEAT") 30 | if os.path.exists(f_path): 31 | return time.time() - os.path.getmtime(f_path) 32 | else: 33 | return np.inf 34 | 35 | 36 | def get_save_path(job): 37 | if job.startswith("python"): 38 | return job.split(" ")[3] 39 | else: 40 | return job.split(" ")[2] 41 | 42 | 43 | def has_finished(save_path): 44 | # Hacky but works, usually 45 | return path.exists(path.join(save_path, "FINISHED")) 46 | 47 | 48 | def get_jobs(batch): 49 | jobs = list(open(batch, "r").read().splitlines()) 50 | jobs = [j for j in jobs if not has_finished(get_save_path(j))] 51 | # take only at least 10min old jobs 52 | jobs = [j for j in jobs if get_last_modification(get_save_path(j)) > 600] 53 | np.random.shuffle(jobs) 54 | return jobs 55 | 56 | def shell_single(batch, gpu=-1): 57 | # Assumes gpu is configured 58 | while True: 59 | jobs = get_jobs(batch) 60 | logger.info("Found {}".format(len(jobs))) 61 | job = jobs[0] 62 | logger.info("Running " + job) 63 | if gpu==-1: 64 | os.system(job) 65 | else: 66 | os.system("CUDA_VISIBLE_DEVICES={} {}".format(gpu, job)) 67 | # Allow to kill easilyq 68 | time.sleep(5) 69 | 70 | 71 | if __name__ == "__main__": 72 | argh.dispatch_command(shell_single) 73 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/bin/utils/run_on_a_gpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | A simple script to run on a list of jobs on a GPU 5 | 6 | See bin/utils/run_on_a_gpu.py -h 7 | 8 | There are the following requirements for the batch: 9 | * Each command has saving dir as 2nd argument 10 | * Each script saves to the save_dir HEARTBEAT 11 | * 
Each script saves to the save_dir FINISHED when done 12 | """ 13 | 14 | import os 15 | import time 16 | import numpy as np 17 | from os import path 18 | import pandas as pd 19 | import argh 20 | 21 | from src import configure_logger # Will actually configure already 22 | 23 | import logging 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def get_last_modification(save_path): 29 | f_path = os.path.join(save_path, "HEARTBEAT") 30 | if os.path.exists(f_path): 31 | return time.time() - os.path.getmtime(f_path) 32 | else: 33 | return np.inf 34 | 35 | 36 | def get_save_path(job): 37 | if job.startswith("python"): 38 | return job.split(" ")[3] 39 | else: 40 | return job.split(" ")[2] 41 | 42 | 43 | def has_finished(save_path): 44 | # Hacky but works, usually 45 | return path.exists(path.join(save_path, "FINISHED")) 46 | 47 | 48 | def get_jobs(batch): 49 | jobs = list(open(batch, "r").read().splitlines()) 50 | jobs = [j for j in jobs if not has_finished(get_save_path(j))] 51 | # take only at least 10min old jobs 52 | jobs = [j for j in jobs if get_last_modification(get_save_path(j)) > 600] 53 | np.random.shuffle(jobs) 54 | return jobs 55 | 56 | def shell_single(batch, gpu=-1): 57 | # Assumes gpu is configured 58 | while True: 59 | jobs = get_jobs(batch) 60 | logger.info("Found {}".format(len(jobs))) 61 | job = jobs[0] 62 | logger.info("Running " + job) 63 | if gpu==-1: 64 | os.system(job) 65 | else: 66 | os.system("CUDA_VISIBLE_DEVICES={} {}".format(gpu, job)) 67 | # Allow to kill easilyq 68 | time.sleep(5) 69 | 70 | 71 | if __name__ == "__main__": 72 | argh.dispatch_command(shell_single) 73 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/bin/evaluate_supervised.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Trainer script for src.pl_modules.supervised_learning. Example run command: bin/train.py save_to_folder configs/cnn.gin. 5 | """ 6 | 7 | import gin 8 | from gin.config import _CONFIG 9 | import torch 10 | import logging 11 | import os 12 | import json 13 | 14 | import torch 15 | import pytorch_lightning as pl 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | from src.data import get_dataset 20 | from src.utils import summary, acc, gin_wrap, parse_gin_config 21 | from src.modules.supervised_training import SupervisedLearning 22 | # Ensure gin seens all classes 23 | from bin.train_supervised import * 24 | 25 | import argh 26 | 27 | def evaluate(save_path, checkpoint_name="weights.ckpt"): 28 | # Load config 29 | config = parse_gin_config(os.path.join(save_path, "config.gin")) 30 | gin.parse_config_files_and_bindings([os.path.join(os.path.join(save_path, "config.gin"))], bindings=[""]) 31 | 32 | # Create dynamically dataset generators 33 | train, valid, test, meta_data = get_dataset(batch_size=config['train.batch_size'], seed=config['train.seed']) 34 | 35 | # Load model (a bit hacky, but necessary because load_from_checkpoint seems to fail) 36 | ckpt_path = os.path.join(save_path, checkpoint_name) 37 | ckpt = torch.load(ckpt_path) 38 | model = models.__dict__[config['train.model']]() 39 | summary(model) 40 | pl_module = SupervisedLearning(model, lr=0.0) 41 | pl_module.load_state_dict(ckpt['state_dict']) 42 | 43 | # NOTE: This fails, probably due to a bug in Pytorch Lightning. 
The above is manually doing something similar 44 | # ckpt_path = os.path.join(save_path, checkpoint_name) 45 | # pl_module = SupervisedLearning.load_from_checkpoint(ckpt_path) 46 | 47 | trainer = pl.Trainer() 48 | results, = trainer.test(model=pl_module, test_dataloaders=test, ckpt_path=ckpt_path) 49 | logger.info(results) 50 | with open(os.path.join(save_path, "eval_results_{}.json".format(checkpoint_name)), "w") as f: 51 | json.dump(results, f) 52 | 53 | if __name__ == "__main__": 54 | argh.dispatch_command(evaluate) 55 | -------------------------------------------------------------------------------- /tf2_project_template/src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Relevant constants and configurations 3 | """ 4 | # Env check 5 | import os 6 | if not int(os.environ.get("ENVCALLED", 0)): 7 | raise RuntimeError("Please source env before working on the project") 8 | 9 | # Tensorflow and Keras specific constants 10 | import tensorflow 11 | assert tensorflow.__version__[0] == '2' 12 | DATA_FORMAT = os.environ.get("DATA_FORMAT", "channels_first") 13 | tensorflow.keras.backend.set_image_data_format(DATA_FORMAT) 14 | os.environ['KERAS_BACKEND'] = 'tensorflow' 15 | 16 | # Configure paths 17 | PROJECT_NAME = os.environ['PNAME'] 18 | PROJECT_DIR = os.path.join(os.path.dirname(__file__), "..") 19 | DATA_DIR = os.environ.get("DATA_DIR", os.path.join(os.path.dirname(__file__), "data")) 20 | RESULTS_DIR = os.environ.get("RESULTS_DIR", os.path.join(os.path.dirname(__file__), "results")) 21 | 22 | # Configure logger 23 | import logging 24 | from .utils import configure_logger 25 | configure_logger('') 26 | logger = logging.getLogger() 27 | 28 | # Opinionated default plotting styles 29 | import matplotlib.style 30 | import matplotlib as mpl 31 | DEFAULT_FIGSIZE = 8 32 | DEFAULT_LINEWIDTH = 3 33 | DEFAULT_FONTSIZE = 22 34 | mpl.style.use('seaborn-colorblind') 35 | mpl.rcParams['figure.facecolor'] = 'w' 36 | mpl.rcParams.update({'font.size': 14, 'lines.linewidth': 2, 'figure.figsize': (DEFAULT_FIGSIZE, DEFAULT_FIGSIZE / 1.61)}) 37 | mpl.rcParams['grid.color'] = 'k' 38 | mpl.rcParams['grid.linestyle'] = ':' 39 | mpl.rcParams['errorbar.capsize'] = 2 40 | mpl.rcParams['image.cmap'] = 'cividis' 41 | mpl.rcParams['grid.linewidth'] = 0.5 42 | mpl.rcParams['lines.markersize'] = 6 43 | mpl.rcParams['lines.marker'] = None 44 | mpl.rcParams['axes.grid'] = True 45 | COLORS = mpl.rcParams["axes.prop_cycle"].by_key()["color"] 46 | mpl.rcParams.update({'font.size': DEFAULT_FONTSIZE, 'lines.linewidth': DEFAULT_LINEWIDTH, 47 | 'legend.fontsize': DEFAULT_FONTSIZE, 'axes.labelsize': DEFAULT_FONTSIZE, 48 | 'xtick.labelsize': DEFAULT_FONTSIZE, 'ytick.labelsize': DEFAULT_FONTSIZE, 49 | 'figure.figsize': (7, 7.0 / 1.4)}) 50 | 51 | # Misc 52 | logger.info("GPU Available to Tensorflow:") 53 | logger.info(tensorflow.test.is_gpu_available()) 54 | logger.info("TF can use Eager") 55 | logger.info(tensorflow.executing_eagerly()) 56 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/bin/train_supervised.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Trainer script for src.pl_modules.supervised_learning. Example run command: bin/train.py save_to_folder configs/cnn.gin. 
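Note: in this template the script itself lives at bin/train_supervised.py; experiments/tune_lr/main.py generates commands of the form: python3 bin/train_supervised.py <save_path> <config.gin>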
5 | """ 6 | 7 | import gin 8 | import logging 9 | import os 10 | import json 11 | 12 | from src.data import get_dataset 13 | from src.callbacks import get_callback 14 | from src.utils import summary, acc, gin_wrap, parse_gin_config 15 | from src.modules import supervised_training 16 | from src import models 17 | from src.training_loop import training_loop 18 | 19 | from pytorch_lightning.callbacks import ModelCheckpoint 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | @gin.configurable 24 | def train(save_path, model, batch_size=128, seed=777, callbacks=[], resume=True, evaluate=True): 25 | # Create dynamically dataset generators 26 | train, valid, test, meta_data = get_dataset(batch_size=batch_size, seed=seed) 27 | 28 | # Create dynamically model 29 | model = models.__dict__[model]() 30 | summary(model) 31 | 32 | # Create dynamically callbacks 33 | callbacks_constructed = [] 34 | for name in callbacks: 35 | clbk = get_callback(name, verbose=0) 36 | if clbk is not None: 37 | callbacks_constructed.append(clbk) 38 | 39 | if not resume and os.path.exists(os.path.join(save_path, "last.ckpt")): 40 | raise IOError("Please clear folder before running or pass train.resume=True") 41 | 42 | # Create module and pass to trianing 43 | checkpoint_callback = ModelCheckpoint( 44 | filepath=os.path.join(save_path, "weights"), 45 | verbose=True, 46 | save_last=True, # For resumability 47 | monitor='valid_acc', 48 | mode='max' 49 | ) 50 | pl_module = supervised_training.SupervisedLearning(model, meta_data=meta_data) 51 | trainer = training_loop(train, valid, pl_module=pl_module, checkpoint_callback=checkpoint_callback, 52 | callbacks=callbacks_constructed, save_path=save_path) 53 | 54 | # Evaluate 55 | if evaluate: 56 | results, = trainer.test(test_dataloaders=test) 57 | logger.info(results) 58 | with open(os.path.join(save_path, "eval_results.json"), "w") as f: 59 | json.dump(results, f) 60 | 61 | 62 | if __name__ == "__main__": 63 | gin_wrap(train) 64 | -------------------------------------------------------------------------------- /tf2_project_template/bin/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Evaluation script for the project. 
Example run command: bin/evaluate_tf.py path_to_experiment 5 | """ 6 | 7 | import gin 8 | import logging 9 | import sys 10 | import json 11 | import os 12 | 13 | import tensorflow as tf 14 | from tensorflow.keras.metrics import categorical_accuracy 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | from bin.train import * 19 | from src.utils import restore_model 20 | from src.plotting import load_HC 21 | from src.training_loop import evaluate 22 | 23 | 24 | if __name__ == "__main__": 25 | E = sys.argv[1] 26 | if len(sys.argv) == 3: 27 | checkpoint = sys.argv[2] 28 | suffix = checkpoint 29 | else: 30 | checkpoint = 'model_best_val.h5' 31 | suffix = "" 32 | 33 | H, C = load_HC(E) 34 | # Note: Load_HC doesn't load gin properly 35 | gin.parse_config_files_and_bindings([os.path.join(os.path.join(E, "config.gin"))], bindings=[""]) 36 | logger.info(C['train']['datasets']) 37 | datasets = [get_dataset(d, seed=C['train']['data_seed'], batch_size=C['train']['batch_size']) for d in C['train']['datasets']] 38 | 39 | model = models.__dict__[C['train']['model']](input_shape=datasets[0][-1]['input_shape'], 40 | n_classes=datasets[0][-1]['num_classes']) 41 | logger.info("# of parameters " + str(sum([np.prod(p.shape) for p in model.trainable_weights]))) 42 | model.summary() 43 | 44 | if C['train']['loss'] == 'ce': 45 | loss_function = tf.keras.losses.categorical_crossentropy 46 | else: 47 | raise NotImplementedError() 48 | metrics = [categorical_accuracy] 49 | if C['train'].get("f1", False): 50 | metrics.append("f1") 51 | 52 | model = restore_model(model, os.path.join(E, checkpoint)) 53 | 54 | m_test = evaluate(model, [datasets[0][2]], loss_function, metrics) 55 | m_val = evaluate(model, [datasets[0][1]], loss_function, metrics) 56 | 57 | logger.info("Saving") 58 | eval_results = {} 59 | for k in m_test: 60 | eval_results['test_' + k] = float(m_test[k]) 61 | for k in m_val: 62 | eval_results['val_' + k] = float(m_val[k]) 63 | logger.info(eval_results) 64 | 65 | json.dump(eval_results, open(os.path.join(E, f"eval_results{suffix}.json"), "w")) -------------------------------------------------------------------------------- /pytorch_lightning_project_template/e.yml: -------------------------------------------------------------------------------- 1 | name: pytorch_project_template 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - torchvision 7 | - _tflow_select 8 | - absl-py 9 | - astor 10 | - attrs 11 | - backcall 12 | - blas 13 | - bleach 14 | - c-ares 15 | - ca-certificates 16 | - certifi 17 | - cffi 18 | - cycler 19 | - dbus 20 | - decorator 21 | - defusedxml 22 | - entrypoints 23 | - expat 24 | - freetype 25 | - gast 26 | - gettext 27 | - glib 28 | - grpcio 29 | - h5py 30 | - hdf5 31 | - icu 32 | - intel-openmp 33 | - ipykernel 34 | - ipython 35 | - ipython_genutils 36 | - ipywidgets 37 | - jedi 38 | - jinja2 39 | - joblib 40 | - jpeg 41 | - jsonschema 42 | - jupyter 43 | - jupyter_client 44 | - jupyter_console 45 | - jupyter_core 46 | - keras 47 | - keras-applications 48 | - keras-base 49 | - keras-preprocessing 50 | - kiwisolver 51 | - libedit 52 | - libffi 53 | - libiconv 54 | - libpng 55 | - libprotobuf 56 | - libsodium 57 | - markdown 58 | - markupsafe 59 | - matplotlib 60 | - mistune 61 | - mkl 62 | - mkl_fft 63 | - mkl_random 64 | - mock 65 | - nbconvert 66 | - nbformat 67 | - ncurses 68 | - ninja 69 | - notebook 70 | - numpy 71 | - numpy-base 72 | - openssl 73 | - pandas 74 | - pandoc 75 | - pandocfilters 76 | - parso 77 | - patsy 78 | - pcre 79 | - pexpect 80 | - 
pickleshare 81 | - pip 82 | - prometheus_client 83 | - prompt_toolkit 84 | - protobuf 85 | - ptyprocess 86 | - pycparser 87 | - pygments 88 | - pyparsing 89 | - pyqt 90 | - pyrsistent 91 | - python 92 | - python-dateutil 93 | - pytorch 94 | - pytz 95 | - pyyaml 96 | - pyzmq 97 | - qt 98 | - qtconsole 99 | - readline 100 | - scikit-learn 101 | - scipy 102 | - seaborn 103 | - send2trash 104 | - setuptools 105 | - sip 106 | - six 107 | - sqlite 108 | - statsmodels 109 | - tensorboard 110 | - tensorflow 111 | - tensorflow-base 112 | - tensorflow-estimator 113 | - termcolor 114 | - terminado 115 | - testpath 116 | - tk 117 | - tornado 118 | - tqdm 119 | - traitlets 120 | - wcwidth 121 | - webencodings 122 | - werkzeug 123 | - wheel 124 | - widgetsnbextension 125 | - xz 126 | - yaml 127 | - zeromq 128 | - zlib 129 | - pip: 130 | - argh 131 | - enum34 132 | - pytorch-lightning 133 | - gin-config 134 | - itermplot 135 | 136 | -------------------------------------------------------------------------------- /tf2_project_template/src/models/simple_cnn.py: -------------------------------------------------------------------------------- 1 | # Simple CNN model in TensorFlow roughly based on https://keras.io/examples/cifar10_cnn/ 2 | import gin 3 | 4 | from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, MaxPooling2D, \ 5 | Input, Activation, Flatten, Dropout 6 | from tensorflow.keras.models import Model 7 | 8 | from src import DATA_FORMAT 9 | 10 | 11 | @gin.configurable 12 | def SimpleCNN(input_shape=(3, 32, 32), dropout=0.0, n_filters=32, activation="relu", 13 | n_dense=128, kernel_size=3, n1=1, n2=1, n_classes=10, bn=False): 14 | inputs = Input(shape=input_shape) 15 | x = inputs 16 | 17 | for id in range(n1): 18 | prefix_column = str(id) if id > 0 else "" 19 | x = Conv2D(n_filters, (kernel_size, kernel_size), padding='same', data_format=DATA_FORMAT, 20 | name=prefix_column + "conv1")(x) 21 | if bn: 22 | x = BatchNormalization(axis=1, name=prefix_column + "bn1")(x) 23 | x = Activation(activation, name=prefix_column + "act_1")(x) 24 | x = Conv2D(n_filters, (kernel_size, kernel_size), data_format=DATA_FORMAT, padding='same', 25 | name=prefix_column + "conv2")(x) 26 | if bn: 27 | x = BatchNormalization(axis=1, name=prefix_column + "bn2")(x) 28 | x = Activation(activation, name=prefix_column + "act_2")(x) 29 | x = MaxPooling2D(pool_size=(2, 2), data_format=DATA_FORMAT)(x) 30 | 31 | for id in range(n2): 32 | prefix_column = str(id) if id > 0 else "" 33 | x = Conv2D(n_filters * 2, (kernel_size, kernel_size), data_format=DATA_FORMAT, 34 | padding='same', name=prefix_column + "conv3")(x) 35 | if bn: 36 | x = BatchNormalization(axis=1, name=prefix_column + "bn3")(x) 37 | x = Activation(activation, name=prefix_column + "act_3")(x) 38 | x = Conv2D(n_filters * 2, (kernel_size, kernel_size), data_format=DATA_FORMAT, padding='same', 39 | name=prefix_column + "conv4")(x) 40 | if bn: 41 | x = BatchNormalization(axis=1, name=prefix_column + "bn4")(x) 42 | x = Activation(activation, name=prefix_column + "act_4")(x) 43 | 44 | x = MaxPooling2D(pool_size=(2, 2), data_format=DATA_FORMAT)(x) 45 | x = Flatten()(x) 46 | 47 | x = Dense(n_dense, name="dense2")(x) 48 | if bn: 49 | x = BatchNormalization(name="bn5")(x) 50 | x = Activation(activation, name="act_5")(x) # Post act 51 | if dropout > 0: 52 | x = Dropout(dropout)(x) 53 | x = Dense(n_classes, activation="linear", name="pre_softmax")(x) 54 | x = Activation(activation="softmax", name="post_softmax")(x) 55 | 56 | model = Model(inputs=[inputs], outputs=[x]) 
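# Note: the network ends in an explicit softmax ("post_softmax"), so it is meant to be paired with a
# probability-based loss such as tf.keras.losses.categorical_crossentropy (as bin/evaluate.py uses),
# rather than a from_logits loss.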
57 | 58 | return model 59 | 60 | 61 | if __name__ == "__main__": 62 | model = SimpleCNN_tf() 63 | -------------------------------------------------------------------------------- /keras_project_template/src/data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Simple data getters. Each returns iterator for train and dataset for test/valid 4 | """ 5 | from keras.datasets import cifar10, cifar100 6 | from keras.utils import np_utils 7 | from keras.preprocessing.image import ImageDataGenerator 8 | 9 | import numpy as np 10 | 11 | import logging 12 | logging.getLogger(__name__) 13 | 14 | def get_cifar(dataset="cifar10", data_format="channels_first", augmented=False, batch_size=128, preprocessing="center"): 15 | """ 16 | Returns train iterator and test X, y. 17 | """ 18 | 19 | # the data, shuffled and split between train and test sets 20 | if dataset == 'cifar10': 21 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 22 | elif dataset == 'cifar100': 23 | (x_train, y_train), (x_test, y_test) = cifar100.load_data() 24 | else: 25 | raise NotImplementedError() 26 | 27 | if x_train.shape[3] == 3: 28 | logging.info("Transposing") 29 | x_train = x_train.transpose((0, 3, 1, 2))[0:100] # For speed 30 | x_test = x_test.transpose((0, 3, 1, 2))[0:100] # For speed 31 | assert x_train.shape[1] == 3 32 | 33 | if preprocessing == "center": 34 | mean = np.mean(x_train, axis=0, keepdims=True) 35 | std = np.std(x_train) 36 | x_train = (x_train - mean) / std 37 | x_test = (x_test - mean) / std 38 | elif preprocessing == "01": # Required by scatnet 39 | x_train = x_train / 255.0 40 | x_test = x_test / 255.0 41 | else: 42 | raise NotImplementedError("Not implemented preprocessing " + preprocessing) 43 | 44 | logging.info('x_train shape:' + str(x_train.shape)) 45 | logging.info(str(x_train.shape[0]) + 'train samples') 46 | logging.info(str(x_test.shape[0]) + 'test samples') 47 | 48 | # convert class vectors to binary class matrices 49 | y_train = np_utils.to_categorical(y_train)[0:100] 50 | y_test = np_utils.to_categorical(y_test)[0:100] 51 | 52 | train, test = None, [x_test, y_test] 53 | if augmented: 54 | datagen_train = ImageDataGenerator( 55 | featurewise_center=False, 56 | samplewise_center=False, 57 | featurewise_std_normalization=False, 58 | samplewise_std_normalization=False, 59 | zca_whitening=False, 60 | rotation_range=0, 61 | data_format=data_format, 62 | width_shift_range=0.125, 63 | height_shift_range=0.125, 64 | horizontal_flip=True, 65 | vertical_flip=False) 66 | datagen_train.fit(x_train) 67 | train = datagen_train.flow(x_train, y_train, batch_size=batch_size, shuffle=True) 68 | else: 69 | raise NotImplementedError() 70 | 71 | return train, test, {"x_train": x_train, "y_train": y_train, "x_test": x_test, "y_test": y_test} 72 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/callbacks/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Basic callbacks available in the project 4 | """ 5 | 6 | import datetime 7 | import json 8 | import logging 9 | import os 10 | import sys 11 | import time 12 | 13 | import gin 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | # from src.training_loop import training_loop 18 | from src.utils import parse_gin_config 19 | 20 | from pytorch_lightning.callbacks import Callback 21 | 22 | 23 | @gin.configurable 24 | class 
MetaSaver(Callback): 25 | def __init__(self): 26 | super(MetaSaver, self).__init__() 27 | 28 | def on_train_start(self, trainer, pl_module): 29 | logger.info("Saving meta data information from the beginning of training") 30 | 31 | assert os.system( 32 | "cp {} {}".format(sys.argv[0], trainer.default_root_dir)) == 0, "Failed to execute cp of source script" 33 | 34 | utc_date = datetime.datetime.utcnow().strftime("%Y_%m_%d") 35 | 36 | time_start = time.time() 37 | cmd = "python " + " ".join(sys.argv) 38 | self.meta = {"cmd": cmd, 39 | "save_path": trainer.default_root_dir, 40 | "most_recent_train_start_date": utc_date, 41 | "execution_time": -time_start} 42 | 43 | json.dump(self.meta, open(os.path.join(trainer.default_root_dir, "meta.json"), "w"), indent=4) 44 | 45 | def on_train_end(self, trainer, pl_module): 46 | self.meta['execution_time'] += time.time() 47 | json.dump(self.meta, open(os.path.join(trainer.default_root_dir, "meta.json"), "w"), indent=4) 48 | os.system("touch " + os.path.join(trainer.default_root_dir, "FINISHED")) 49 | 50 | 51 | class Heartbeat(Callback): 52 | def __init__(self, interval=10): 53 | self.last_time = time.time() 54 | self.interval = interval 55 | 56 | def on_train_start(self, trainer, pl_module): 57 | logger.info("HEARTBEAT - train begin") 58 | os.system("touch " + os.path.join(trainer.default_root_dir, "HEARTBEAT")) 59 | 60 | def on_batch_start(self, trainer, pl_module): 61 | if time.time() - self.last_time > self.interval: 62 | logger.info("HEARTBEAT") 63 | os.system("touch " + os.path.join(trainer.default_root_dir, "HEARTBEAT")) 64 | self.last_time = time.time() 65 | 66 | 67 | 68 | @gin.configurable 69 | class LRSchedule(Callback): 70 | def __init__(self, base_lr, schedule): 71 | self.schedule = schedule 72 | self.base_lr = base_lr 73 | super(LRSchedule, self).__init__() 74 | 75 | def on_epoch_start(self, trainer, pl_module): 76 | # Epochs starts from 0 77 | for e, v in self.schedule: 78 | if trainer.current_epoch < e: 79 | break 80 | for group in trainer.optimizers[0].param_groups: 81 | group['lr'] = v * self.base_lr 82 | logger.info("Set learning rate to {}".format(v * self.base_lr)) 83 | 84 | -------------------------------------------------------------------------------- /tf2_project_template/experiments/tune_lr/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | See main.py -h for help. 
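Typical flow: python experiments/tune_lr/main.py prepare writes per-learning-rate gin configs and a batch.sh of bin/train.py commands; once the runs finish, python experiments/tune_lr/main.py report plots maximum accuracy against learning rate.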
5 | """ 6 | import numpy as np 7 | import argh 8 | import glob 9 | import matplotlib.pylab as plt 10 | from os.path import dirname, basename, join 11 | 12 | from bin.train import * 13 | from src import * 14 | from src.plotting import load_HC 15 | 16 | EXPERIMENT_DIR = dirname(__file__) 17 | logger = logging.getLogger(__name__) 18 | 19 | def prepare(experiment="large"): 20 | os.system("mkdir -p {}/configs".format(experiment)) 21 | if experiment == "large": 22 | lrs = [0.001, 0.01, 0.1] 23 | elif experiment == "small": 24 | lrs = [0.001, 0.1] 25 | else: 26 | raise NotImplementedError() 27 | 28 | # Write configs by modifying template 29 | with open(join(EXPERIMENT_DIR, "template_config.gin"), "r") as f: 30 | TEMPLATE = f.read() 31 | for id, bs in enumerate(lrs): 32 | exp_config_path = join(EXPERIMENT_DIR, experiment, "configs", "{}.gin".format(id)) 33 | os.system("mkdir -p " + dirname(exp_config_path)) 34 | with open(exp_config_path, "w") as c: 35 | c.write(TEMPLATE.replace("$learning_rate$", str(bs))) 36 | 37 | # Prepare batch of runs 38 | with open(join(EXPERIMENT_DIR, experiment, "batch.sh"), "w") as f: 39 | for id in range(len(lrs)): 40 | exp_save_path = join(RESULTS_DIR, "tune_lr", experiment, str(id)) 41 | exp_config_path = join(EXPERIMENT_DIR, experiment, "configs", "{}.gin".format(id)) 42 | if not os.path.exists(join(exp_save_path, "FINISHED")): 43 | os.system("mkdir -p " + exp_save_path) 44 | f.write("python3 bin/train.py {save_path} {config_path}\n".format( 45 | save_path=exp_save_path, 46 | config_path=exp_config_path 47 | )) 48 | else: 49 | logger.info("Finished experiment #{}, checking if configs match.".format(id)) 50 | # Ensures that run experiment matches expectations 51 | with open(exp_config_path, "r") as f_c: 52 | with open(join(exp_save_path, "{}.gin".format(id))) as f_e: 53 | c = f_c.read() 54 | c_run = f_e.read() 55 | assert c == c_run, "Finished experiment with a mismatching config. Aborting." 
56 | 57 | 58 | def report(experiment="large"): 59 | Es = glob.glob(join(RESULTS_DIR, "tune_lr", experiment, "*")) 60 | Es = sorted(Es, key=lambda e: int(basename(e))) 61 | 62 | x, y = [], [] 63 | for E in Es: 64 | H, C = load_HC(E) 65 | lr = C['LRSchedule']['base_lr'] 66 | x.append(lr) 67 | y.append(max(H['categorical_accuracy:0'])) 68 | 69 | logger.info("Maximum accuracy reached for learning_rate={}.".format(x[np.argmax(y)])) 70 | 71 | plt.plot(x, y) 72 | plt.xlabel("Learning rate") 73 | plt.ylabel("Maximum accuracy") 74 | plt.show() 75 | 76 | if __name__ == "__main__": 77 | argh.dispatch_commands([prepare, report]) 78 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/training_loop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Training loop based on PytorchLightning 4 | """ 5 | 6 | import logging 7 | import os 8 | import sys 9 | import tqdm 10 | import copy 11 | import pickle 12 | import numpy as np 13 | import pandas as pd 14 | import torch 15 | import gin 16 | import pytorch_lightning as pl 17 | 18 | from functools import partial 19 | from collections import defaultdict 20 | from contextlib import contextmanager 21 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor 22 | from pytorch_lightning import loggers as pl_loggers 23 | 24 | from src.utils import save_weights, parse_gin_config 25 | from src.callbacks.base import MetaSaver, Heartbeat 26 | from src import models, NEPTUNE_TOKEN, NEPTUNE_USER, NEPTUNE_PROJECT 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | @gin.configurable 32 | def training_loop(train, valid, save_path, pl_module, callbacks, 33 | n_epochs, checkpoint_callback, use_cpu=False, use_neptune=False, resume=True, limit_train_batches=None, neptune_tags="", neptune_name=""): 34 | """ 35 | Largely model/application agnostic training code. 36 | """ 37 | # Train with proper resuming 38 | # Copy gin configs used, for reference, to the save folder 39 | 40 | if use_cpu: 41 | gpus = 0 42 | else: 43 | gpus = 1 44 | 45 | if not limit_train_batches: 46 | limit_train_batches = len(train) 47 | 48 | os.system("rm " + os.path.join(save_path, "*gin")) 49 | for gin_config in sys.argv[2].split(";"): 50 | os.system("cp {} {}/base_config.gin".format(gin_config, save_path)) 51 | with open(os.path.join(save_path, "config.gin"), "w") as f: 52 | f.write(gin.operative_config_str()) 53 | hparams = parse_gin_config(os.path.join(save_path, 'config.gin')) 54 | if 'train.callbacks' in hparams: 55 | del hparams['train.callbacks'] 56 | # TODO: What is a less messy way to pass hparams? 
This is only that logging is aware of hyperparameters 57 | pl_module._set_hparams(hparams) 58 | pl_module._hparams_initial = copy.deepcopy(hparams) 59 | loggers = [] 60 | loggers.append(pl_loggers.CSVLogger(save_path)) 61 | if use_neptune: 62 | from pytorch_lightning.loggers import NeptuneLogger 63 | loggers.append(NeptuneLogger( 64 | api_key=NEPTUNE_TOKEN, 65 | project_name=NEPTUNE_USER + "/" + NEPTUNE_PROJECT, 66 | experiment_name=neptune_name if len(neptune_name) else os.path.basename(save_path), 67 | tags=neptune_tags.split(',') if len(neptune_tags) else None, 68 | )) 69 | callbacks += [MetaSaver(), Heartbeat(), LearningRateMonitor()] 70 | trainer = pl.Trainer( 71 | default_root_dir=save_path, 72 | limit_train_batches=limit_train_batches, 73 | max_epochs=n_epochs, 74 | logger=loggers, 75 | callbacks=callbacks, 76 | log_every_n_steps=1, 77 | gpus=gpus, 78 | checkpoint_callback=checkpoint_callback, 79 | resume_from_checkpoint=os.path.join(save_path, 'last.ckpt') 80 | if resume and os.path.exists(os.path.join(save_path, 'last.ckpt')) else None) 81 | trainer.fit(pl_module, train, valid) 82 | return trainer 83 | -------------------------------------------------------------------------------- /pytorch_project_template/e.yml: -------------------------------------------------------------------------------- 1 | name: pytorch_project_template 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _tflow_select=2.3.0 7 | - absl-py=0.7.1 8 | - astor=0.7.1 9 | - attrs=19.1.0 10 | - backcall=0.1.0 11 | - blas=1.0 12 | - bleach=3.1.0 13 | - c-ares=1.15.0 14 | - ca-certificates=2019.1.23 15 | - certifi=2019.3.9 16 | - cffi=1.12.3 17 | - cycler=0.10.0 18 | - dbus=1.13.6 19 | - decorator=4.4.0 20 | - defusedxml=0.6.0 21 | - entrypoints=0.3 22 | - expat=2.2.6 23 | - freetype=2.9.1 24 | - gast=0.2.2 25 | - gettext=0.19.8.1 26 | - glib=2.56.2 27 | - grpcio=1.16.1 28 | - h5py=2.9.0 29 | - hdf5=1.10.4 30 | - icu=58.2 31 | - intel-openmp=2019.3 32 | - ipykernel=5.1.0 33 | - ipython=7.5.0 34 | - ipython_genutils=0.2.0 35 | - ipywidgets=7.4.2 36 | - jedi=0.13.3 37 | - jinja2=2.10.1 38 | - joblib=0.13.2 39 | - jpeg=9b 40 | - jsonschema=3.0.1 41 | - jupyter=1.0.0 42 | - jupyter_client=5.2.4 43 | - jupyter_console=6.0.0 44 | - jupyter_core=4.4.0 45 | - keras=2.2.4 46 | - keras-applications=1.0.7 47 | - keras-base=2.2.4 48 | - keras-preprocessing=1.0.9 49 | - kiwisolver=1.1.0 50 | - libedit=3.1.20181209 51 | - libffi=3.2.1 52 | - libiconv=1.15 53 | - libpng=1.6.37 54 | - libprotobuf=3.7.1 55 | - libsodium=1.0.16 56 | - markdown=3.1 57 | - markupsafe=1.1.1 58 | - matplotlib=3.1.0 59 | - mistune=0.8.4 60 | - mkl=2019.3 61 | - mkl_fft=1.0.12 62 | - mkl_random=1.0.2 63 | - mock=3.0.5 64 | - nbconvert=5.5.0 65 | - nbformat=4.4.0 66 | - ncurses=6.1 67 | - ninja=1.9.0 68 | - notebook=5.7.8 69 | - numpy=1.16.4 70 | - numpy-base=1.16.4 71 | - openssl=1.1.1c 72 | - pandas=0.24.2 73 | - pandoc=2.2.3.2 74 | - pandocfilters=1.4.2 75 | - parso=0.4.0 76 | - patsy=0.5.1 77 | - pcre=8.43 78 | - pexpect=4.7.0 79 | - pickleshare=0.7.5 80 | - pip=19.1.1 81 | - prometheus_client=0.6.0 82 | - prompt_toolkit=2.0.9 83 | - protobuf=3.7.1 84 | - ptyprocess=0.6.0 85 | - pycparser=2.19 86 | - pygments=2.4.0 87 | - pyparsing=2.4.0 88 | - pyqt=5.9.2 89 | - pyrsistent=0.14.11 90 | - python=3.7.3 91 | - python-dateutil=2.8.0 92 | - pytorch=1.1.0 93 | - pytz=2019.1 94 | - pyyaml=5.1 95 | - pyzmq=18.0.0 96 | - qt=5.9.7 97 | - qtconsole=4.5.1 98 | - readline=7.0 99 | - scikit-learn=0.21.1 100 | - scipy=1.2.1 101 | - seaborn=0.9.0 102 | - 
send2trash=1.5.0 103 | - setuptools=41.0.1 104 | - sip=4.19.8 105 | - six=1.12.0 106 | - sqlite=3.28.0 107 | - statsmodels=0.9.0 108 | - tensorboard=1.13.1 109 | - tensorflow=1.13.1 110 | - tensorflow-base=1.13.1 111 | - tensorflow-estimator=1.13.0 112 | - termcolor=1.1.0 113 | - terminado=0.8.2 114 | - testpath=0.4.2 115 | - tk=8.6.8 116 | - tornado=6.0.2 117 | - tqdm=4.31.1 118 | - traitlets=4.3.2 119 | - wcwidth=0.1.7 120 | - webencodings=0.5.1 121 | - werkzeug=0.15.2 122 | - wheel=0.33.4 123 | - widgetsnbextension=3.4.2 124 | - xz=5.2.4 125 | - yaml=0.1.7 126 | - zeromq=4.3.1 127 | - zlib=1.2.11 128 | - pip: 129 | - argh==0.26.2 130 | - enum34==1.1.6 131 | - gin-config==0.1.4 132 | - itermplot==0.331 133 | prefix: /Users/jastrs01/anaconda3/envs/vanilla_pytorch_project_template 134 | 135 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/modules/supervised_training.py: -------------------------------------------------------------------------------- 1 | """ 2 | A PyTorch Lighting module defines a *system* that is constructed from modules (data loaders, models, optimizer) 3 | that trains and evaluates *on some task*. 4 | 5 | The benefit is that such module can be then easily trained using various PL modules. 6 | 7 | The most standard task is supervised training. 8 | """ 9 | import os 10 | from argparse import ArgumentParser, Namespace 11 | from collections import OrderedDict 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.nn.parallel 16 | import torch.optim as optim 17 | import torch.optim.lr_scheduler as lr_scheduler 18 | import torch.utils.data 19 | import torch.utils.data.distributed 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | import torchvision.transforms as transforms 23 | 24 | import pytorch_lightning as pl 25 | from pytorch_lightning.core import LightningModule 26 | 27 | import gin 28 | 29 | @gin.configurable 30 | class SupervisedLearning(LightningModule): 31 | """ 32 | Module defining a supervised learning system. 
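    Bundles the wrapped model with a cross-entropy loss, a plain SGD optimizer (see configure_optimizers)
    and top-1 accuracy logging for the training, validation and test steps.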
33 | """ 34 | 35 | def __init__(self, model, meta_data, lr=0.1): 36 | super().__init__() 37 | self.model = model 38 | self.meta_data = meta_data 39 | self.lr = lr 40 | 41 | def forward(self, x): 42 | return self.model(x) 43 | 44 | def training_step(self, batch, batch_idx): 45 | images, target = batch 46 | output = self(images) 47 | loss_val = F.cross_entropy(output, target) 48 | acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) 49 | self.log('train_loss', loss_val, on_epoch=True, on_step=True) 50 | self.log('train_acc', acc1, on_epoch=True, on_step=True) 51 | self.log('train_loss_step', loss_val, on_epoch=False, on_step=True) 52 | self.log('train_acc_step', acc1, on_epoch=False, on_step=True) 53 | return loss_val 54 | 55 | def validation_step(self, batch, batch_idx): 56 | images, target = batch 57 | output = self(images) 58 | loss_val = F.cross_entropy(output, target) 59 | acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) 60 | self.log('valid_loss', loss_val, on_epoch=True) 61 | self.log('valid_acc', acc1, on_epoch=True) 62 | 63 | def test_step(self, *args, **kwargs): 64 | return self.validation_step(*args, **kwargs) 65 | 66 | @staticmethod 67 | def __accuracy(output, target, topk=(1,)): 68 | """Computes the accuracy over the k top predictions for the specified values of k""" 69 | with torch.no_grad(): 70 | maxk = max(topk) 71 | batch_size = target.size(0) 72 | 73 | _, pred = output.topk(maxk, 1, True, True) 74 | pred = pred.t() 75 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 76 | res = [] 77 | for k in topk: 78 | correct_k = correct[:k].flatten().float().sum(0, keepdim=True) 79 | res.append(correct_k.mul_(100.0 / batch_size)) 80 | return res 81 | 82 | def configure_optimizers(self): 83 | # TODO: I am not sure where and how pytorch lightning uses it. Maybe not necessary to havet his logic here 84 | optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr) 85 | return optimizer 86 | 87 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/experiments/tune_lr/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | See main.py -h for help. 
5 | """ 6 | import numpy as np 7 | import argh 8 | import glob 9 | import os 10 | import matplotlib.pylab as plt 11 | from os.path import dirname, basename, join 12 | 13 | import pandas as pd 14 | 15 | from gin.config import _CONFIG 16 | from src import RESULTS_DIR 17 | from src.utils import load_C, load_H, load_HC 18 | # This will be important for passing gin configs 19 | from bin.train_supervised import * 20 | from src import * 21 | 22 | EXPERIMENT_DIR = dirname(__file__) 23 | logger = logging.getLogger(__name__) 24 | 25 | def prepare(experiment="large"): 26 | os.system("mkdir -p {}/configs".format(experiment)) 27 | if experiment == "large": 28 | lrs = [0.001, 0.01, 0.1] 29 | elif experiment == "small": 30 | lrs = [0.001, 0.1] 31 | else: 32 | raise NotImplementedError() 33 | 34 | # Write configs by modifying template 35 | with open(join(EXPERIMENT_DIR, "template_config.gin"), "r") as f: 36 | TEMPLATE = f.read() 37 | for id, bs in enumerate(lrs): 38 | exp_config_path = join(EXPERIMENT_DIR, experiment, "configs", "{}.gin".format(id)) 39 | os.system("mkdir -p " + dirname(exp_config_path)) 40 | with open(exp_config_path, "w") as c: 41 | c.write(TEMPLATE.replace("$learning_rate$", str(bs))) 42 | 43 | # Prepare batch of runs 44 | with open(join(EXPERIMENT_DIR, experiment, "batch.sh"), "w") as f: 45 | for id in range(len(lrs)): 46 | exp_save_path = join(RESULTS_DIR, "tune_lr", experiment, str(id)) 47 | exp_config_path = join(EXPERIMENT_DIR, experiment, "configs", "{}.gin".format(id)) 48 | if not os.path.exists(join(exp_save_path, "FINISHED")): 49 | os.system("mkdir -p " + exp_save_path) 50 | f.write("python3 bin/train_supervised.py {save_path} {config_path} \n".format( 51 | save_path=exp_save_path, 52 | config_path=exp_config_path 53 | )) 54 | else: 55 | logger.info("Finished experiment #{}, checking if configs match.".format(id)) 56 | # Ensures that run experiment matches expectations 57 | with open(exp_config_path, "r") as f_c: 58 | with open(join(exp_save_path, "{}.gin".format(id))) as f_e: 59 | c = f_c.read() 60 | c_run = f_e.read() 61 | assert c == c_run, "Finished experiment with a mismatching config. Aborting." 62 | 63 | 64 | def report(experiment="large"): 65 | SAVE_TO = join(EXPERIMENT_DIR, experiment) 66 | 67 | Es = glob.glob(join(RESULTS_DIR, "tune_lr", experiment, "*")) 68 | Es = sorted(Es, key=lambda e: int(basename(e))) 69 | Es = [E for E in Es if os.path.exists(join(E, 'FINISHED'))] # Use only finished 70 | 71 | x, y = [], [] 72 | for E in Es: 73 | H, C = load_HC(E) 74 | lr = C["LRSchedule.base_lr"] 75 | x.append(lr) 76 | y.append(max(H['valid_acc'])) 77 | 78 | logger.info("Maximum accuracy reached for learning_rate={}.".format(x[np.argmax(y)])) 79 | 80 | plt.plot(x, y) 81 | plt.xlabel("Learning rate") 82 | plt.ylabel("Maximum accuracy") 83 | plt.savefig(join(SAVE_TO, 'plot.pdf')) 84 | plt.show() 85 | 86 | if __name__ == "__main__": 87 | argh.dispatch_commands([prepare, report]) 88 | -------------------------------------------------------------------------------- /keras_project_template/README.md: -------------------------------------------------------------------------------- 1 | # Example project 2 | 3 | Warning: slightly outdate, but 4 | 5 | Simple exemplary code training small CNN on CIFAR10/CIFAR100. 6 | 7 | The main idea is that there is a set of "base config" and you can run scripts as: 8 | 9 | ``` 10 | python src/scripts/train_simple_CNN.py cifar10 results/test_run --n_filters=10 11 | ``` 12 | 13 | , which slighty modifies base config of simple_CNN by changing n_filters. 
After running, you can find the following files inside 14 | your results directory: 15 | 16 | ``` 17 | -rw-r--r--@ 1 kudkudak staff 144 Jun 10 18:04 history.csv 18 | -rw-r--r--@ 1 kudkudak staff 35 Jun 10 18:04 loop_state.pkl 19 | -rw-r--r--@ 1 kudkudak staff 200 Jun 10 18:04 meta.json 20 | -rw-r--r--@ 1 kudkudak staff 2593544 Jun 10 18:04 model.h5 21 | -rw-r--r--@ 1 kudkudak staff 1578 Jun 10 18:04 stderr.txt 22 | -rw-r--r--@ 1 kudkudak staff 2507 Jun 10 18:04 stdout.txt 23 | -rw-r--r--@ 1 kudkudak staff 300 Jun 10 18:04 config.json 24 | -rw-r--r--@ 1 kudkudak staff 1211 Jun 10 18:04 train_simple_CNN.py 25 | ``` 26 | 27 | 28 | ## Project structure 29 | 30 | * ./env.sh 31 | 32 | Any relevant environment variables (including `THEANO_FLAGS`). Shouldn't be committed. Often machine specific! 33 | 34 | * src/data.py 35 | 36 | Returns the training dataset as an iterator and the test dataset as plain arrays. Usually coded with the help of Fuel. 37 | 38 | ```{python} 39 | 40 | def get_cifar(which, augment): 41 | return train, test 42 | 43 | ``` 44 | 45 | * src/models.py 46 | 47 | Model definitions. 48 | 49 | ```{python} 50 | 51 | def build_simple_cnn(config): 52 | pass 53 | 54 | ``` 55 | 56 | Note: if your model is complicated (for instance, it has a custom inference procedure), it is a good 57 | idea to wrap the model in a class supporting these methods. For instance, you can construct a new block in Blocks 58 | or a new Model in keras. 59 | 60 | * src/training_loop.py 61 | 62 | Resumable training loops (sometimes shipped with the framework, e.g. Blocks). For instance, in keras follow this convention: 63 | 64 | ```{python} 65 | 66 | def cifar_training_loop(model, train, valid, [other params]): 67 | pass 68 | 69 | ``` 70 | 71 | Note that the training loop does not accept the test set. The test set should never be looked at during training; it is 72 | very easy to use it (even implicitly) and thus overfit. 73 | 74 | * src/scripts 75 | 76 | Runners (usually using vegab/argh or another command-line wrapper), following this convention: 77 | 78 | ```{bash} 79 | 80 | ./src/scripts/train_simple_CNN.py root results/simple_cnn/my_save_location --n_layers=10 81 | 82 | ``` 83 | 84 | Any DL code should be resumable by default. 85 | 86 | * configs 87 | 88 | Stores the configs used in the project, as jsons or via config_registry. This project uses config_registry. 89 | 90 | * etc 91 | 92 | All misc. files relevant to the project (includes meeting notes, paper sources, etc.). 93 | 94 | * requirements.txt 95 | 96 | Usually there is no need to specify versions of all packages, but it is useful to pin the ones that are less mature, 97 | like keras or Theano/tensorflow. 98 | 99 | ## FAQ 100 | 101 | ### Why no environment.yml 102 | 103 | Not everyone uses conda's environment.yml; a plain requirements.txt seems more generic. 104 | 105 | ### Why keras specific? 106 | 107 | It is not really keras-specific. The same project structure works for other frameworks. For instance, in Blocks one 
-------------------------------------------------------------------------------- /tf2_project_template/tf2_project_template.yml: -------------------------------------------------------------------------------- 1 | name: tf2_project_template 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _tflow_select=2.3.0 6 | - absl-py=0.9.0 7 | - appnope=0.1.0 8 | - astor=0.8.0 9 | - attrs=19.3.0 10 | - backcall=0.1.0 11 | - blas=1.0 12 | - bleach=3.1.4 13 | - c-ares=1.15.0 14 | - ca-certificates=2020.1.1 15 | - certifi=2020.4.5.1 16 | - cycler=0.10.0 17 | - dbus=1.13.14 18 | - decorator=4.4.2 19 | - defusedxml=0.6.0 20 | - entrypoints=0.3 21 | - expat=2.2.6 22 | - freetype=2.9.1 23 | - gast=0.2.2 24 | - gettext=0.19.8.1 25 | - glib=2.63.1 26 | - google-pasta=0.2.0 27 | - grpcio=1.16.1 28 | - h5py=2.10.0 29 | - hdf5=1.10.4 30 | - icu=58.2 31 | - importlib-metadata=1.6.0 32 | - importlib_metadata=1.6.0 33 | - intel-openmp=2019.4 34 | - ipykernel=5.1.4 35 | - ipython=7.13.0 36 | - ipython_genutils=0.2.0 37 | - ipywidgets=7.5.1 38 | - jedi=0.17.0 39 | - jinja2=2.11.2 40 | - joblib=0.15.1 41 | - jpeg=9b 42 | - jsonschema=3.2.0 43 | - jupyter=1.0.0 44 | - jupyter_client=6.1.3 45 | - jupyter_console=6.1.0 46 | - jupyter_core=4.6.3 47 | - keras=2.3.1 48 | - keras-applications=1.0.8 49 | - keras-base=2.3.1 50 | - keras-preprocessing=1.1.0 51 | - kiwisolver=1.2.0 52 | - libcxx=10.0.0 53 | - libedit=3.1.20181209 54 | - libffi=3.3 55 | - libgfortran=3.0.1 56 | - libiconv=1.16 57 | - libpng=1.6.37 58 | - libprotobuf=3.12.3 59 | - libsodium=1.0.16 60 | - llvm-openmp=10.0.0 61 | - markdown=3.1.1 62 | - markupsafe=1.1.1 63 | - matplotlib=3.1.3 64 | - matplotlib-base=3.1.3 65 | - mistune=0.8.4 66 | - mkl=2019.4 67 | - mkl-service=2.3.0 68 | - mkl_fft=1.0.15 69 | - mkl_random=1.1.1 70 | - nbconvert=5.6.1 71 | - nbformat=5.0.6 72 | - ncurses=6.2 73 | - notebook=6.0.3 74 | - numpy=1.18.1 75 | - numpy-base=1.18.1 76 | - openssl=1.1.1g 77 | - opt_einsum=3.1.0 78 | - pandas=1.0.3 79 | - pandoc=2.2.3.2 80 | - pandocfilters=1.4.2 81 | - parso=0.7.0 82 | - pcre=8.43 83 | - pexpect=4.8.0 84 | - pickleshare=0.7.5 85 | - pip=20.0.2 86 | - prometheus_client=0.7.1 87 | - prompt-toolkit=3.0.5 88 | - prompt_toolkit=3.0.5 89 | - protobuf=3.12.3 90 | - ptyprocess=0.6.0 91 | - pygments=2.6.1 92 | - pyparsing=2.4.7 93 | - pyqt=5.9.2 94 | - pyrsistent=0.16.0 95 | - python=3.7.7 96 | - python-dateutil=2.8.1 97 | - pytz=2020.1 98 | - pyyaml=5.3.1 99 | - pyzmq=18.1.1 100 | - qt=5.9.7 101 | - qtconsole=4.7.4 102 | - qtpy=1.9.0 103 | - readline=8.0 104 | - scikit-learn=0.22.1 105 | - scipy=1.4.1 106 | - seaborn=0.10.1 107 | - send2trash=1.5.0 108 | - setuptools=47.1.1 109 | - sip=4.19.8 110 | - six=1.15.0 111 | - sqlite=3.31.1 112 | - tensorboard=2.0.0 113 | - tensorflow=2.0.0 114 | - tensorflow-base=2.0.0 115 | - tensorflow-estimator=2.0.0 116 | - termcolor=1.1.0 117 | - terminado=0.8.3 118 | - testpath=0.4.4 119 | - tk=8.6.8 120 | - tornado=6.0.4 121 | - tqdm=4.46.0 122 | - traitlets=4.3.3 123 | - wcwidth=0.1.9 124 | - webencodings=0.5.1 125 | - werkzeug=1.0.1 126 | - wheel=0.34.2 127 | - widgetsnbextension=3.5.1 128 | - wrapt=1.12.1 129 | - xz=5.2.5 130 | - yaml=0.1.7 131 | - zeromq=4.3.1 132 | - zipp=3.1.0 133 | - zlib=1.2.11 134 | - pip: 135 | - chardet==3.0.4 136 | - dropbox==10.2.0 137 | - neptune-client== 138 | - gin-config==0.3.0 139 | - idna==2.9 140 | - itermplot==0.331 141 | - requests==2.23.0 142 | - urllib3==1.25.9 143 | prefix: /Users/jastrs01/anaconda3/envs/tf2_project_template 144 | 145 | 
-------------------------------------------------------------------------------- /pytorch_project_template/experiments/tune_lr/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | See main.py -h for help. 5 | """ 6 | import numpy as np 7 | import argh 8 | import glob 9 | import os 10 | import matplotlib.pylab as plt 11 | from os.path import dirname, basename, join 12 | 13 | import pandas as pd 14 | 15 | from gin.config import _CONFIG 16 | from src import RESULTS_DIR 17 | # This will be important for passing gin configs 18 | from bin.train import * 19 | from src import * 20 | 21 | EXPERIMENT_DIR = dirname(__file__) 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def get_gin_value(C, key): 26 | for keys in C: 27 | if keys[1] == key: 28 | return C[keys] 29 | raise IndexError() 30 | 31 | 32 | def load_HC(e): 33 | H = pd.read_csv(join(e, "history.csv")) 34 | # WARNING: Uses naming convention. Dangerous otherwise. 35 | gin.parse_config_file(join(e, join(e, basename(e) + ".gin"))) 36 | C = dict(_CONFIG) 37 | return H, C 38 | 39 | 40 | def prepare(experiment="large"): 41 | os.system("mkdir -p {}/configs".format(experiment)) 42 | if experiment == "large": 43 | lrs = [0.001, 0.01, 0.1] 44 | elif experiment == "small": 45 | lrs = [0.001, 0.1] 46 | else: 47 | raise NotImplementedError() 48 | 49 | # Write configs by modifying template 50 | with open(join(EXPERIMENT_DIR, "template_config.gin"), "r") as f: 51 | TEMPLATE = f.read() 52 | for id, bs in enumerate(lrs): 53 | exp_config_path = join(EXPERIMENT_DIR, experiment, "configs", "{}.gin".format(id)) 54 | os.system("mkdir -p " + dirname(exp_config_path)) 55 | with open(exp_config_path, "w") as c: 56 | c.write(TEMPLATE.replace("$learning_rate$", str(bs))) 57 | 58 | # Prepare batch of runs 59 | with open(join(EXPERIMENT_DIR, experiment, "batch.sh"), "w") as f: 60 | for id in range(len(lrs)): 61 | exp_save_path = join(RESULTS_DIR, "tune_lr", experiment, str(id)) 62 | exp_config_path = join(EXPERIMENT_DIR, experiment, "configs", "{}.gin".format(id)) 63 | if not os.path.exists(join(exp_save_path, "FINISHED")): 64 | os.system("mkdir -p " + exp_save_path) 65 | f.write("python3 bin/train.py {save_path} {config_path} -b training_loop.reload=True \n".format( 66 | save_path=exp_save_path, 67 | config_path=exp_config_path 68 | )) 69 | else: 70 | logger.info("Finished experiment #{}, checking if configs match.".format(id)) 71 | # Ensures that run experiment matches expectations 72 | with open(exp_config_path, "r") as f_c: 73 | with open(join(exp_save_path, "{}.gin".format(id))) as f_e: 74 | c = f_c.read() 75 | c_run = f_e.read() 76 | assert c == c_run, "Finished experiment with a mismatching config. Aborting." 
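# The report() command below aggregates the finished runs: for each run directory it
# loads history.csv and the run's saved gin config, reads the base learning rate from
# the config, and plots the best accuracy reached as a function of the learning rate.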
77 | 78 | 79 | def report(experiment="large"): 80 | Es = glob.glob(join(RESULTS_DIR, "tune_lr", experiment, "*")) 81 | Es = sorted(Es, key=lambda e: int(basename(e))) 82 | 83 | x, y = [], [] 84 | for E in Es: 85 | H, C = load_HC(E) 86 | lr = get_gin_value(C, "src.callbacks.callbacks.LRSchedule")['base_lr'] 87 | x.append(lr) 88 | y.append(max(H['acc'])) 89 | 90 | logger.info("Maximum accuracy reached for learning_rate={}.".format(x[np.argmax(y)])) 91 | 92 | plt.plot(x, y) 93 | plt.xlabel("Learning rate") 94 | plt.ylabel("Maximum accuracy") 95 | plt.show() 96 | 97 | if __name__ == "__main__": 98 | argh.dispatch_commands([prepare, report]) 99 | -------------------------------------------------------------------------------- /keras_project_template/src/training_loop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Simple data getters. Each returns iterator for train and dataset for test/valid 4 | """ 5 | 6 | from keras.callbacks import ModelCheckpoint, LambdaCallback, Callback 7 | 8 | # Might misbehave with tensorflow-gpu, make sure u use tensorflow-cpu if using Theano for keras 9 | try: 10 | import tensorflow 11 | except: 12 | pass 13 | 14 | import pandas as pd 15 | import os 16 | import cPickle as pickle 17 | 18 | import logging 19 | logger = logging.getLogger(__name__) 20 | 21 | class DumpTensorflowSummaries(Callback): 22 | def __init__(self, save_path): 23 | self._save_path = save_path 24 | super(DumpTensorflowSummaries, self).__init__() 25 | 26 | @property 27 | def file_writer(self): 28 | if not hasattr(self, '_file_writer'): 29 | self._file_writer = tensorflow.summary.FileWriter( 30 | self._save_path, flush_secs=10.) 31 | return self._file_writer 32 | 33 | def on_epoch_end(self, epoch, logs=None): 34 | summary = tensorflow.Summary() 35 | for key, value in logs.items(): 36 | try: 37 | float_value = float(value) 38 | value = summary.value.add() 39 | value.tag = key 40 | value.simple_value = float_value 41 | except: 42 | pass 43 | self.file_writer.add_summary( 44 | summary, epoch) 45 | 46 | def cifar_training_loop(model, train, valid, 47 | n_epochs, learning_rate_schedule, save_path): 48 | 49 | if os.path.exists(os.path.join(save_path, "loop_state.pkl")): 50 | logger.info("Reloading loop state") 51 | loop_state = pickle.load(open(os.path.join(save_path, "loop_state.pkl"))) 52 | else: 53 | loop_state = {'last_epoch_done_id': -1} 54 | 55 | if os.path.exists(os.path.join(save_path, "model.h5")): 56 | model.load_weights(os.path.join(save_path, "model.h5")) 57 | 58 | samples_per_epoch = 1000 59 | 60 | callbacks = [] 61 | 62 | def lr_schedule(epoch, logs): 63 | for e, v in learning_rate_schedule: 64 | if epoch >= e: 65 | model.optimizer.lr.set_value(v) 66 | break 67 | logger.info("Fix learning rate to {}".format(v)) 68 | 69 | callbacks.append(LambdaCallback(on_epoch_end=lr_schedule)) 70 | 71 | def save_history(epoch, logs): 72 | history_path = os.path.join(save_path, "history.csv") 73 | if os.path.exists(history_path): 74 | H = pd.read_csv(history_path) 75 | H = {col: list(H[col].values) for col in H.columns} 76 | else: 77 | H = {} 78 | 79 | for key, value in logs.items(): 80 | if key not in H: 81 | H[key] = [value] 82 | else: 83 | H[key].append(value) 84 | 85 | pd.DataFrame(H).to_csv(os.path.join(save_path, "history.csv"), index=False) 86 | 87 | callbacks.append(LambdaCallback(on_epoch_end=save_history)) 88 | # Uncomment if you have tensorflow installed correctly 89 | # 
callbacks.append(DumpTensorflowSummaries(save_path=save_path)) 90 | callbacks.append(ModelCheckpoint(monitor='val_acc', 91 | save_weights_only=False, filepath=os.path.join(save_path, "model.h5"))) 92 | 93 | def save_loop_state(epoch, logs): 94 | loop_state = {"last_epoch_done_id": epoch} 95 | pickle.dump(loop_state, open(os.path.join(save_path, "loop_state.pkl"), "w")) 96 | callbacks.append(LambdaCallback(on_epoch_end=save_loop_state)) 97 | 98 | _ = model.fit_generator(train, 99 | initial_epoch=loop_state['last_epoch_done_id'] + 1, 100 | samples_per_epoch=samples_per_epoch, 101 | nb_epoch=n_epochs, verbose=1, 102 | validation_data=valid, 103 | callbacks=callbacks) 104 | -------------------------------------------------------------------------------- /tf2_project_template/bin/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Trainer script for the project. 5 | 6 | Example run commands: 7 | 8 | * python bin/train.py tst configs/scnn.gin - Trains SimpleCNN on the CIFAR-10 dataset 9 | """ 10 | 11 | import gin 12 | import os 13 | import logging 14 | import json 15 | import tensorflow as tf 16 | import numpy as np 17 | 18 | from gin.config import _CONFIG 19 | from tensorflow.keras.optimizers import SGD, Adam 20 | from tensorflow.keras.metrics import categorical_accuracy 21 | 22 | from src.data import get_dataset 23 | from src import models 24 | from src.training_loop import training_loop, evaluate, restore_model 25 | from src.callbacks import get_callback 26 | from src.utils import gin_wrap 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | @gin.configurable 31 | def train(save_path, 32 | model, 33 | datasets=['cifar10'], 34 | optimizer="SGD", 35 | data_seed=777, 36 | seed=777, 37 | batch_size=128, 38 | lr=0.0, 39 | wd=0.0, 40 | nesterov=False, 41 | checkpoint_monitor='val_categorical_accuracy:0', 42 | loss='ce', 43 | steps_per_epoch=-1, 44 | momentum=0.9, 45 | testing=False, 46 | testing_reload_best_val=True, 47 | callbacks=[]): 48 | np.random.seed(seed) 49 | 50 | # Create dataset generators (seeded) 51 | datasets = [get_dataset(d, seed=data_seed, batch_size=batch_size) for d in datasets] 52 | 53 | # Create model 54 | model = models.__dict__[model](input_shape=datasets[0][-1]['input_shape'], n_classes=datasets[0][-1]['num_classes']) 55 | logger.info("# of parameters " + str(sum([np.prod(p.shape) for p in model.trainable_weights]))) 56 | model.summary() 57 | if loss == 'ce': 58 | loss_function = tf.keras.losses.categorical_crossentropy 59 | else: 60 | raise NotImplementedError() 61 | 62 | if optimizer == "SGD": 63 | optimizer = SGD(learning_rate=lr, momentum=momentum, nesterov=nesterov) 64 | elif optimizer == "Adam": 65 | optimizer = Adam(learning_rate=lr) 66 | else: 67 | raise NotImplementedError() 68 | 69 | # Create callbacks 70 | callbacks_constructed = [] 71 | for name in callbacks: 72 | clbk = get_callback(name, verbose=0) 73 | if clbk is not None: 74 | callbacks_constructed.append(clbk) 75 | else: 76 | raise NotImplementedError(f"Did not find callback {name}") 77 | 78 | # Pass everything to the training loop 79 | metrics = [categorical_accuracy] 80 | 81 | if steps_per_epoch == -1: 82 | steps_per_epoch = (datasets[0][-1]['n_examples_train'] + batch_size - 1) // batch_size 83 | 84 | training_loop(model=model, optimizer=optimizer, loss_function=loss_function, metrics=metrics, datasets=datasets, 85 | weight_decay=wd, save_path=save_path, config=_CONFIG, 
steps_per_epoch=steps_per_epoch, 86 | use_tb=True, checkpoint_monitor=checkpoint_monitor, custom_callbacks=callbacks_constructed, seed=seed) 87 | 88 | if testing: 89 | if testing_reload_best_val: 90 | model = restore_model(model, os.path.join(save_path, "model_best_val.h5")) 91 | 92 | m_val = evaluate(model, [datasets[0][1]], loss_function, metrics) 93 | m_test = evaluate(model, [datasets[0][2]], loss_function, metrics) 94 | 95 | logger.info("Saving") 96 | eval_results = {} 97 | for k in m_test: 98 | eval_results['test_' + k] = float(m_test[k]) 99 | for k in m_val: 100 | eval_results['val_' + k] = float(m_val[k]) 101 | logger.info(eval_results) 102 | json.dump(eval_results, open(os.path.join(save_path, "eval_results.json"), "w")) 103 | 104 | 105 | if __name__ == "__main__": 106 | gin_wrap(train) 107 | -------------------------------------------------------------------------------- /pytorch_project_template/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch project template 2 | 3 | Simple machine learning project template based on PyTorch. 4 | 5 | If you are impatient just jump to the tutorial at the end of this README. 6 | 7 | ## Introduction 8 | 9 | The main goal of this template is to make easy following the state-of-the-art good practices for a machine learning project. This includes reducing boilerplate, or keeping config handling simple and consistent. 10 | 11 | First, this template includes a minimal trainer ``bin/train.py`` that has: 12 | 13 | * A gorgeous training loop (generic training script, checkpointing, callbacks, etc.) 14 | * Beautiful config handling 15 | - We use gin for this 16 | * Amazing automatic saving of logs and other auxiliary files 17 | 18 | This repo also ships with: 19 | 20 | * An example experiment in `experiments/tune_lr` 21 | * Environment configuration 22 | 23 | In the rest of this document we walk through all the key design principles and how we implement them. Finally, there is a quick tutorial. 24 | 25 | ## Good practices 26 | 27 | Here is a non-exhaustive list of good practices that this project tries to implement. They are based on 28 | state-of-the-art ideas in the community about how to organise a machine learning project. 29 | 30 | Do you have other ideas? Please open an issue and let's discuss. Here are ours: 31 | 32 | * Use a consistent environment 33 | 34 | - We use conda. See `e.sh`. You should source it each time you start working on the repo. You shouldn't commi it. 35 | 36 | * Use a common training loops for all trainings. Use callbacks. Everything should be resumable. 37 | 38 | - See `src/training_loop.py` 39 | 40 | * Structure code so that you can plug-and-play different models, datasets, callbacks, etc. 41 | 42 | - See `data`, `models` and `callbacks` modules. See how `bin/train.py` allows for easy gin-based configuration. 43 | 44 | * Separate functionalities as binaries. Motivation is similar to tooling like `ls` in unix systems. 45 | 46 | - See `bin` folder 47 | 48 | * Store common configs. 49 | 50 | - See `configs` folder. 51 | 52 | * Each run should have a dedicated folder with everything in one place, always in the same format (including logs, used command). 53 | 54 | - See `results/example_run` folder for example files that are produced. 55 | 56 | * Each experiment should be as self-contained as possible, e.g. include runner, plotting utilities, a README file, etc. 57 | 58 | - See `experiments/tune_lr` for an example. 
* Test everything easily testable

  - We have asserts sprinkled across the code, but probably not as many as we should.


## Tutorial: single training

Take the following steps:

1. Install a minimal conda environment: ``conda env create --file e.yml``.

2. Activate the environment: ``source e.sh``.

3. Train a CNN on Cifar10 for a few batches: ``bin/train.py save_to_folder configs/cnn.gin``.

4. Run ``tensorboard --logdir=save_to_folder`` to visualize the learning curves.

Configuration is done using gin. This allows for a flexible configuration of training. For instance, to continue training for more epochs you can run: ``bin/train.py save_to_folder configs/cnn.gin -b="training_loop.n_epochs=5#training_loop.reload=True"``.

Note: training won't reach sensible accuracies. This is on purpose so that the demonstration works on small machines. For a slightly more realistic training configuration see `configs/cnn_full.gin`.

## Tutorial: experiment example

Conceptually, an experiment is a list of shell jobs. For convenience this can be wrapped in a python script that prepares the jobs, analyses the runs, stores configs, etc.

We ship an example experiment, where we tune the LR for the small CNN on Cifar10. Here is the typical workflow:

1. Prepare experiments: `python experiments/tune_lr/main.py prepare`

2. See the prepared configs: `ls experiments/tune_lr/large/configs`

3. Run experiments: `bash experiments/tune_lr/large/batch.sh`

4. See runs: `ls $RESULTS_DIR/tune_lr/large`

5. Process experiment results: `python experiments/tune_lr/main.py report`. Bonus for OSX users: to enable plotting in iTerm, install itermplot (``pip install itermplot``) and uncomment the appropriate line in ``e.sh``.

6. Take a look at the main.py source code to better understand the logic.

Note that running a list of shell jobs can be done using a scheduler. This is best if you develop your own solution for running such a list efficiently.

--------------------------------------------------------------------------------
/tf2_project_template/README.md:
--------------------------------------------------------------------------------

# TensorFlow 2.0 project template

Simple machine learning project template based on TensorFlow 2.0 (which is extremely similar to PyTorch or JAX).

## Installation and Setup

This template uses conda. To set up the environment, run:

```
conda env create --file tf2_project_template.yml
```

Then, to start working, run:

```
source env.sh
```

## Introduction

The main goal of this template is to make it easy to follow state-of-the-art good practices for a machine learning project. This includes reducing boilerplate and keeping config handling simple and consistent.

First, this template includes a minimal trainer ``bin/train.py`` that has:

* A gorgeous self-contained training loop (generic training script, checkpointing, callbacks, etc.)
* Beautiful config handling
  - We use gin for this
* Amazing automatic saving of logs and other auxiliary files

This repo also ships with:

* An example experiment in `experiments/tune_lr`
* Environment configuration
* Integration with Neptune
* Various utilities, such as a tool for running and managing a list of jobs on free GPUs or on SLURM

In the rest of this document we walk through all the key design principles and how we implement them. Finally, there is a quick tutorial.

## Tutorial: single training

Take the following steps:

1. Install a minimal conda environment: ``conda env create --file tf2_project_template.yml``.

2. Activate the environment: ``source env.sh``.

3. Train a CNN on Cifar10 for a few batches: ``python bin/train.py save_to_folder configs/scnn.gin``.

4. Run ``tensorboard --logdir=save_to_folder`` to visualize the learning curves.

Configuration is done using gin. This allows for a flexible configuration of training. For instance, to continue training for more epochs you can run: ``bin/train.py save_to_folder configs/scnn.gin -b="training_loop.n_epochs=5#training_loop.reload=True"``.

Note: training won't reach sensible accuracies. This is on purpose so that the demonstration works on small machines. For a slightly more realistic training configuration see `configs/cnn_full.gin`.

## Tutorial: experiment example

Conceptually, an experiment is a list of shell jobs. For convenience this can be wrapped in a python script that prepares the jobs, analyses the runs, stores configs, etc.

We ship an example experiment, where we tune the LR for the small CNN on Cifar10. Here is the typical workflow:

1. Prepare experiments: `python experiments/tune_lr/main.py prepare`

2. See the prepared configs: `ls experiments/tune_lr/large/configs`

3. Run experiments: `bash experiments/tune_lr/large/batch.sh`

4. See runs: `ls $RESULTS_DIR/tune_lr/large`

5. Process the results: `python experiments/tune_lr/main.py report`. Bonus for OSX users: to enable plotting in iTerm, install itermplot (``pip install itermplot``) and uncomment the appropriate line in ``env.sh``.

Note that running a list of shell jobs can be done using a scheduler. We provide a reference implementation for two such schedulers in `bin/utils`.

## Extras

* For a Neptune integration example, run `python bin/train.py tst configs/scnn_neptune.gin`.

* See `bin/utils/run_slurm.py` for a simple SLURM job manager.

* See `bin/utils/run_on_a_gpu.py`/`bin/utils/run_on_free_gpus.py` for a simple GPU job manager on a single machine.

* See `bin/utils/watch_changes.sh` for a simple script that automatically syncs changes to a machine.

* See `bin/utils/update_plots.py` for a script that updates all figures in .tex files with new versions.

* See `bin/evaluate.py` for an evaluation script.

## Good practices

Here is a non-exhaustive list of good practices that this project tries to implement. They are based on state-of-the-art ideas in the community about how to organise a machine learning project.

Do you have other ideas? Please open an issue and let's discuss. Here are ours:

* Use a consistent environment

  - We use conda. See `env.sh`. You should source it each time you start working on the repo. You shouldn't commit it.
97 | 98 | * Use a common training loops for all trainings. Use callbacks. Everything should be resumable. 99 | 100 | - See `src/training_loop.py` 101 | 102 | * Structure code so that you can plug-and-play different models, datasets, callbacks, etc. 103 | 104 | - See `data`, `models` and `callbacks` modules. See how `bin/train.py` allows for easy gin-based configuration. 105 | 106 | * Separate functionalities as binaries. Motivation is similar to tooling like `ls` in unix systems. 107 | 108 | - See `bin` folder 109 | 110 | * Store common configs. 111 | 112 | - See `configs` folder. 113 | 114 | * Each run should have a dedicated folder with everything in one place, always in the same format (including logs, used command). 115 | 116 | - See `results/example_run` folder for example files that are produced. 117 | 118 | * Each experiment should be as self-contained as possible, e.g. include runner, plotting utilities, a README file, etc. 119 | 120 | - See `experiments/tune_lr` for an example. 121 | 122 | * Test everything easily testable 123 | 124 | - We have asserts sprinkled across the code, but probably not as many as we should. 125 | 126 | -------------------------------------------------------------------------------- /tf2_project_template/src/plotting.py: -------------------------------------------------------------------------------- 1 | # Utils for plotting. 2 | import json 3 | import logging 4 | import os 5 | import pickle 6 | 7 | import matplotlib as mpl 8 | import matplotlib.style 9 | import numpy as np 10 | import pandas as pd 11 | from matplotlib import pyplot as plt 12 | 13 | import copy 14 | import gin 15 | import time 16 | 17 | from gin.config import _CONFIG 18 | 19 | from src.utils import configure_neptune_exp, get_neptune_exp, configure_logger 20 | from src import PROJECT_NAME 21 | from os.path import join 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def load_C(E): 27 | return load_HC(E)[1] 28 | 29 | 30 | def load_HbC(e): 31 | # A utility function to load H (per batch) and C 32 | if len(_CONFIG): 33 | logger.warning("Erasing global gin config") 34 | gin.clear_config() 35 | H = None 36 | for i in range(3): 37 | try: 38 | H = pickle.load(open(join(e, "history_batch.pkl"), "rb")) 39 | H_epoch = pickle.load(open(join(e, "history.pkl"), "rb")) 40 | except: 41 | print("Warning. Faied to read :" + e) 42 | time.sleep(1) 43 | continue 44 | assert H is not None, "Failed to read " + e 45 | for k in H_epoch: 46 | H['epoch_' + k] = H_epoch[k] 47 | for k in H: 48 | H[k] = np.array(H[k]) 49 | gin.parse_config(open(join(e, "config.gin"))) 50 | C = copy.deepcopy(_CONFIG) 51 | C = {k[1].split(".")[-1]: v for k, v in C.items()} # Hacky way to simplify config 52 | return H, C 53 | 54 | 55 | def load_HC(e, force_pkl=False): 56 | # A utility function to load H and C 57 | if len(_CONFIG): 58 | logger.warning("Erasing global gin config") 59 | gin.clear_config() 60 | H = None 61 | for i in range(3): 62 | try: 63 | if os.path.exists(join(e, 'history.csv')) and not force_pkl: 64 | H = pd.read_csv(join(e, "history.csv")) 65 | else: 66 | H = pickle.load(open(join(e, "history.pkl"), "rb")) 67 | except: 68 | print("Warning. 
Faied to read :" + e) 69 | time.sleep(1) 70 | continue 71 | assert H is not None, "Failed to read" + e 72 | gin.parse_config(open(join(e, "config.gin"))) 73 | C = copy.deepcopy(_CONFIG) 74 | C = {k[1].split(".")[-1]: v for k, v in C.items()} # Hacky way to simplify config 75 | for k in list(_CONFIG): 76 | del _CONFIG[k] 77 | return H, C 78 | 79 | 80 | def construct_colorer(sorted_vals, cmap="coolwarm"): 81 | cm = plt.get_cmap(cmap, len(sorted_vals)) 82 | N = float(len(sorted_vals)) 83 | 84 | def _get_cm(val): 85 | return cm(sorted_vals.index(val) / N) 86 | 87 | return _get_cm 88 | 89 | 90 | def construct_marker(sorted_vals): 91 | cm = ['o', 'x', 'd'] 92 | N = float(len(sorted_vals)) 93 | 94 | def _get_cm(val): 95 | return cm[sorted_vals.index(val)] 96 | 97 | return _get_cm 98 | 99 | 100 | def construct_colorer_lin_scale(vmin, vmax, ticks=20, cmap="coolwarm"): 101 | assert vmax > vmin 102 | 103 | cm = plt.get_cmap(cmap, ticks) 104 | 105 | def _get_cm(val): 106 | alpha = (val - vmin) / float((vmax - vmin)) 107 | tick = int(alpha * ticks) 108 | tick = min(tick, ticks - 1) 109 | return cm(tick) 110 | 111 | return _get_cm 112 | 113 | 114 | try: 115 | import dropbox 116 | access_token = os.environ['DROPBOXTOKEN'] 117 | 118 | class TransferData: 119 | def __init__(self, access_token): 120 | self.access_token = access_token 121 | 122 | def upload_file(self, file_from, file_to): 123 | """upload a file to Dropbox using API v2 124 | """ 125 | dbx = dropbox.Dropbox(self.access_token) 126 | with open(file_from, 'rb') as f: 127 | try: 128 | dbx.files_upload(f.read(), file_to, mode=dropbox.files.WriteMode.overwrite) 129 | except: 130 | logger.error("Failed uploading " + file_from + " " + file_to) 131 | 132 | 133 | transferData = TransferData(access_token) 134 | except: 135 | transferData = None 136 | 137 | 138 | def save_fig(dir_name, figure_name, copy_to_dropbox=False, copy_to_neptune=False): 139 | if figure_name.endswith("pdf"): 140 | figure_name = figure_name[0:-4] 141 | figure_name = figure_name.replace("=", "").replace(":", "").replace("_", "").replace(".", "") \ 142 | .replace("/", "_").replace("$", "").replace("\\", "") 143 | path = os.path.join(dir_name, figure_name + ".pdf") 144 | 145 | if not os.path.exists(os.path.dirname(path)): 146 | os.system("mkdir -p " + os.path.dirname(path)) 147 | 148 | logger.info('Figure saved to: ' + path) 149 | 150 | fig = plt.gcf() 151 | fig.savefig(path, bbox_inches='tight', 152 | transparent=True, 153 | pad_inches=0) 154 | 155 | if copy_to_dropbox: 156 | # / because it is a dropbox app with own folder 157 | transferData.upload_file(path, os.path.join("/", PROJECT_NAME, path)) 158 | 159 | if copy_to_neptune: 160 | neptune_exp = get_neptune_exp() 161 | neptune_exp.send_image(figure_name, fig) 162 | 163 | 164 | if __name__ == "__main__": 165 | configure_logger('') 166 | configure_neptune_exp('tst') 167 | plt.plot([1,2,3], [1,2,4]) 168 | save_fig("examples", "qudratic.pdf", copy_to_dropbox=True, copy_to_neptune=True) 169 | plt.show() 170 | plt.close() -------------------------------------------------------------------------------- /pytorch_lightning_project_template/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Lightning project template 2 | 3 | Simple machine learning project template based on PyTorch and Pytorch Lightning and other key tools (gin, neptune, and more). The main ambition of this template is to make easy following the state-of-the-art good practices for a machine learning project. 
If you are impatient, just jump to the tutorial at the end of this README.

## Why not just use Pytorch Lightning?

While a huge step forward, Pytorch Lightning (PL) still leaves you with many choices to be made. This repository makes opinionated choices that integrate the best ideas from PL, Keras, and other frameworks:

* Modular code structure that separates out models/data/callbacks (like in Keras) and avoids monolithic, hard-to-read code

  - We advocate against putting too much logic in a PL module, as it becomes monolithic and hard to read and work with

* Configuration using gin config from Google

  - PL uses argparse, which requires a lot of boilerplate code

* Dynamic loading of callbacks, models, etc., by name (like in Keras)

* A template for running a grid search (see ``experiments``)

* A template for environment configuration (a sourced `e.sh`)

* Unified loading and processing of experimental results (the ``load_C`` and ``load_H`` functions, evaluation scripts)

  - This is a good design principle that makes it easier to repurpose plotting code

* Utility script templates (automatic file syncing using watchman, running on SLURM, etc.)

* We advocate against processing experiments with ad-hoc scripts, and instead recommend structured code (see `experiments/tune_lr/main.py`)

  - In this way you will reuse the code more often and write fewer bugs. No more IPython notebooks that no one understands and that are full of bugs.

## Tutorial: single training

Take the following steps:

1. Install a minimal conda environment: ``conda env create --file e.yml``.

2. Activate the environment: ``source e.sh``.

3. Train a CNN on Cifar10 for a few batches: ``bin/train_supervised.py save_to_folder configs/cnn.gin``.

4. Run ``tensorboard --logdir=save_to_folder`` to visualize the learning curves.

Configuration is done using gin. This allows for a flexible configuration of training. For instance, to continue training for more epochs you can run: ``bin/train_supervised.py save_to_folder configs/cnn.gin -b="training_loop.n_epochs=5#training_loop.resume=True"``.

Note: training won't reach sensible accuracies. This is on purpose so that the demonstration works on small machines. For a slightly more realistic training configuration see `configs/cnn_full.gin`.

## Tutorial: experiment example

Conceptually, an experiment is a list of shell jobs. For convenience this can be wrapped in a python script that prepares the jobs, analyses the runs, stores configs, etc.

We ship an example experiment, where we tune the LR for the small CNN on Cifar10. Here is the typical workflow:

1. Prepare experiments: `python experiments/tune_lr/main.py prepare`

2. See the prepared configs: `ls experiments/tune_lr/large/configs`

3. Run experiments: `bash experiments/tune_lr/large/batch.sh`

4. See runs: `ls $RESULTS_DIR/tune_lr/large`

5. Process experiment results: `python experiments/tune_lr/main.py report`. Bonus for OSX users: to enable plotting in iTerm, install itermplot (``pip install itermplot``) and uncomment the appropriate line in ``e.sh``.

6. Take a look at the main.py source code to better understand the logic.

Note that running a list of shell jobs can be done using a scheduler.
This is best if you develop your own 73 | solution for runnning efficiently such a list. 74 | 75 | ## Appendix 76 | 77 | ### Good practices 78 | 79 | Here is a non-exhaustive list of good practices that this project tries to implement. They are based on 80 | state-of-the-art ideas in the community about how to organise a machine learning project. 81 | 82 | Do you have other ideas? Please open an issue and let's discuss. Here are ours: 83 | 84 | * Use a consistent environment 85 | 86 | - We use conda. See `e.sh`. You should source it each time you start working on the repo. You shouldn't commi it. 87 | 88 | * Use a common training loops for all trainings. Use callbacks. Everything should be resumable. 89 | 90 | - See `src/training_loop.py` 91 | 92 | * Structure code so that you can plug-and-play different models, datasets, callbacks, etc. 93 | 94 | - See `data`, `models` and `callbacks` modules. See how `bin/train.py` allows for easy gin-based configuration. 95 | 96 | * Separate functionalities as binaries. Motivation is similar to tooling like `ls` in unix systems. 97 | 98 | - See `bin` folder 99 | 100 | * Store common configs. 101 | 102 | - See `configs` folder. 103 | 104 | * Each run should have a dedicated folder with everything in one place, always in the same format (including logs, used command). 105 | 106 | - See `results/example_run` folder for example files that are produced. 107 | 108 | * Each experiment should be as self-contained as possible, e.g. include runner, plotting utilities, a README file, etc. 109 | 110 | - See `experiments/tune_lr` for an example. 111 | 112 | * Test everything easily testable 113 | 114 | - We have asserts sprinkled across the code, but probably not as many as we should. 115 | -------------------------------------------------------------------------------- /tf2_project_template/bin/utils/run_on_free_gpus.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | A simple self-contained script to run on a list of jobs on all free GPUs on a machine 5 | 6 | See bin/utils/run_on_free_gpus.py -h 7 | 8 | There are the following requirements for the batch: 9 | * Each command has saving dir as 2nd argument 10 | * Each script saves to the save_dir HEARTBEAT 11 | * Each script saves to the save_dir FINISHED when done 12 | """ 13 | 14 | print("Remember to multiple your number of jobs by factor of threads per cpu") 15 | 16 | import os 17 | import time 18 | import numpy as np 19 | from os import path 20 | import pandas as pd 21 | import argh 22 | import subprocess 23 | import logging 24 | 25 | from src.utils import configure_logger 26 | configure_logger('', log_file=None) 27 | RESULTS_DIR = os.environ.get("RESULTS_DIR", os.path.join(os.path.dirname(__file__), "results")) 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | def get_next_free_gpu(): 33 | for i in range(10): 34 | try: 35 | output = subprocess.check_output(['nvidia-smi', '-i', str(i)]).decode("utf-8") 36 | except: 37 | logger.warning("Failed nvidia-smi {}".format(i)) 38 | output = "" 39 | 40 | if output == "No devices were found": 41 | return None 42 | elif "No running processes found" in output: 43 | return i 44 | else: 45 | continue 46 | 47 | import stat 48 | 49 | def get_last_modification(save_path): 50 | f_path = os.path.join(save_path, "HEARTBEAT") 51 | stderr_path = os.path.join(save_path, "stderr.txt") 52 | if os.path.exists(f_path): 53 | return time.time() - os.stat(f_path)[stat.ST_MTIME] 54 | 
else: 55 | if os.path.exists(stderr_path): 56 | return time.time() - os.stat(stderr_path)[stat.ST_MTIME] 57 | else: 58 | return 10000000000000 59 | 60 | 61 | def get_n_jobs(script_name="train"): 62 | try: 63 | output = subprocess.check_output('ps -f | grep {}'.format(script_name), shell=True).decode("utf-8").strip() 64 | except: 65 | print("No jobs") 66 | output = "" 67 | # print("===") 68 | # print(output) 69 | # print("===") 70 | return len(output.split("\n")) - 2 71 | 72 | def get_save_path(job): 73 | assert "$" not in job, "Seems like there in env variable in the command" 74 | if job.startswith("python"): 75 | return job.split(" ")[2] 76 | else: 77 | return job.split(" ")[1] 78 | 79 | 80 | def get_script_name(job): 81 | if job.startswith("python"): 82 | return job.split(" ")[1] 83 | else: 84 | return job.split(" ")[0] 85 | 86 | 87 | def has_finished(save_path): 88 | # Hacky but works, usually 89 | return path.exists(path.join(save_path, "FINISHED")) 90 | 91 | 92 | def get_jobs(batch): 93 | jobs = list(open(batch, "r").read().splitlines()) 94 | jobs = [j for j in jobs if not has_finished(get_save_path(j))] 95 | # take only at least 10min old jobs 96 | jobs = [j for j in jobs if get_last_modification(get_save_path(j)) > 600] 97 | jobs = [("python " + j) if "python" not in j else j for j in jobs] 98 | np.random.shuffle(jobs) 99 | return jobs 100 | 101 | def tensorboard_running(): 102 | output = subprocess.check_output('ps | grep tensorboard', shell=True).decode("utf-8").strip() 103 | return len(output.split("\n")) > 1 104 | 105 | def run(batch, max_jobs=1): 106 | try: 107 | if len(get_jobs(batch)) == 0: 108 | logger.error("No untouched (>10min old) jobs found. Exiting.") 109 | exit(1) 110 | 111 | tb_dir = os.path.join(RESULTS_DIR, "running_experiments") 112 | os.system("mkdir -p " +tb_dir) 113 | 114 | save_path = get_save_path(get_jobs(batch)[0]) 115 | script_name = get_script_name(get_jobs(batch)[0]) 116 | 117 | root_save_path = path.dirname(save_path) 118 | os.system("rm " + root_save_path + " " + tb_dir) 119 | os.system("ln -s " + root_save_path + " " + tb_dir) 120 | logger.info("Running tensorboard in " + tb_dir) 121 | os.system("tensorboard --port=7777 --logdir=" + tb_dir+ " &") 122 | 123 | while True: 124 | print("next_free_gpu={},n_jobs running={},batch={}".format(get_next_free_gpu(), get_n_jobs(script_name), batch)) 125 | jobs = get_jobs(batch) 126 | logger.info("Found {}".format(len(jobs))) 127 | if len(jobs): 128 | job = jobs[0] 129 | gpu = get_next_free_gpu() 130 | n_jobs = get_n_jobs(script_name) 131 | if gpu is not None and max_jobs > n_jobs: 132 | os.system("mkdir -p " + get_save_path(job)) 133 | # Run and redirect all output to a file in the save folder of the job 134 | logger.info("Running " + job) 135 | os.system("CUDA_VISIBLE_DEVICES={} {}".format(gpu, job) + "> {} 2>&1".format(os.path.join(get_save_path(job), "last_run.out")) + " &") 136 | while get_last_modification(get_save_path(job)) > 600 or get_next_free_gpu() == gpu: 137 | print("Waiting for bootup (no HEARTBEAT or not occupied GPU)... 
last_mod={},next_free_gpu={},n_jobs={}".format( 138 | get_last_modification(get_save_path(job)), get_next_free_gpu(), get_n_jobs(script_name))) 139 | time.sleep(1) 140 | elif gpu is None: 141 | logger.warning("No free GPUs") 142 | elif max_jobs <= n_jobs: 143 | logger.warning("Have {} jobs running but can run max max_jobs={}".format(get_n_jobs(script_name), max_jobs)) 144 | else: 145 | raise NotImplementedError() 146 | else: 147 | logger.info("No jobs found") 148 | if get_n_jobs(script_name)==0: 149 | exit(0) 150 | # Allow to kill easilyq 151 | time.sleep(5) 152 | except KeyboardInterrupt: 153 | logger.warning("Interrupt. Killing all python & tensorboard jobs.") 154 | os.system("ps | grep python |awk '{print $1}' | xargs kill -9") 155 | os.system("ps | grep tensorboard |awk '{print $1}' | xargs kill -9") 156 | os.system("rm " + os.path.join(tb_dir, path.basename(root_save_path))) 157 | 158 | if __name__ == "__main__": 159 | argh.dispatch_command(run) 160 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/bin/utils/run_on_free_gpus.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | A simple self-contained script to run on a list of jobs on all free GPUs on a machine 5 | 6 | See bin/utils/run_on_free_gpus.py -h 7 | 8 | There are the following requirements for the batch: 9 | * Each command has saving dir as 2nd argument 10 | * Each script saves to the save_dir HEARTBEAT 11 | * Each script saves to the save_dir FINISHED when done 12 | """ 13 | 14 | print("Remember to multiple your number of jobs by factor of threads per cpu") 15 | 16 | import os 17 | import time 18 | import numpy as np 19 | from os import path 20 | import pandas as pd 21 | import argh 22 | import subprocess 23 | import logging 24 | 25 | from src.utils import configure_logger 26 | configure_logger('', log_file=None) 27 | RESULTS_DIR = os.environ.get("RESULTS_DIR", os.path.join(os.path.dirname(__file__), "results")) 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | def get_next_free_gpu(): 33 | for i in range(10): 34 | try: 35 | output = subprocess.check_output(['nvidia-smi', '-i', str(i)]).decode("utf-8") 36 | except: 37 | logger.warning("Failed nvidia-smi {}".format(i)) 38 | output = "" 39 | 40 | if output == "No devices were found": 41 | return None 42 | elif "No running processes found" in output: 43 | return i 44 | else: 45 | continue 46 | 47 | import stat 48 | 49 | def get_last_modification(save_path): 50 | f_path = os.path.join(save_path, "HEARTBEAT") 51 | stderr_path = os.path.join(save_path, "stderr.txt") 52 | if os.path.exists(f_path): 53 | return time.time() - os.stat(f_path)[stat.ST_MTIME] 54 | else: 55 | if os.path.exists(stderr_path): 56 | return time.time() - os.stat(stderr_path)[stat.ST_MTIME] 57 | else: 58 | return 10000000000000 59 | 60 | 61 | def get_n_jobs(script_name="train"): 62 | try: 63 | output = subprocess.check_output('ps -f | grep {}'.format(script_name), shell=True).decode("utf-8").strip() 64 | except: 65 | print("No jobs") 66 | output = "" 67 | # print("===") 68 | # print(output) 69 | # print("===") 70 | return len(output.split("\n")) - 2 71 | 72 | def get_save_path(job): 73 | assert "$" not in job, "Seems like there in env variable in the command" 74 | if job.startswith("python"): 75 | return job.split(" ")[2] 76 | else: 77 | return job.split(" ")[1] 78 | 79 | 80 | def get_script_name(job): 81 | if 
job.startswith("python"): 82 | return job.split(" ")[1] 83 | else: 84 | return job.split(" ")[0] 85 | 86 | 87 | def has_finished(save_path): 88 | # Hacky but works, usually 89 | return path.exists(path.join(save_path, "FINISHED")) 90 | 91 | 92 | def get_jobs(batch): 93 | jobs = list(open(batch, "r").read().splitlines()) 94 | jobs = [j for j in jobs if not has_finished(get_save_path(j))] 95 | # take only at least 10min old jobs 96 | jobs = [j for j in jobs if get_last_modification(get_save_path(j)) > 600] 97 | jobs = [("python " + j) if "python" not in j else j for j in jobs] 98 | np.random.shuffle(jobs) 99 | return jobs 100 | 101 | def tensorboard_running(): 102 | output = subprocess.check_output('ps | grep tensorboard', shell=True).decode("utf-8").strip() 103 | return len(output.split("\n")) > 1 104 | 105 | def run(batch, max_jobs=1): 106 | try: 107 | if len(get_jobs(batch)) == 0: 108 | logger.error("No untouched (>10min old) jobs found. Exiting.") 109 | exit(1) 110 | 111 | tb_dir = os.path.join(RESULTS_DIR, "running_experiments") 112 | os.system("mkdir -p " +tb_dir) 113 | 114 | save_path = get_save_path(get_jobs(batch)[0]) 115 | script_name = get_script_name(get_jobs(batch)[0]) 116 | 117 | root_save_path = path.dirname(save_path) 118 | os.system("rm " + root_save_path + " " + tb_dir) 119 | os.system("ln -s " + root_save_path + " " + tb_dir) 120 | logger.info("Running tensorboard in " + tb_dir) 121 | os.system("tensorboard --port=7777 --logdir=" + tb_dir+ " &") 122 | 123 | while True: 124 | print("next_free_gpu={},n_jobs running={},batch={}".format(get_next_free_gpu(), get_n_jobs(script_name), batch)) 125 | jobs = get_jobs(batch) 126 | logger.info("Found {}".format(len(jobs))) 127 | if len(jobs): 128 | job = jobs[0] 129 | gpu = get_next_free_gpu() 130 | n_jobs = get_n_jobs(script_name) 131 | if gpu is not None and max_jobs > n_jobs: 132 | os.system("mkdir -p " + get_save_path(job)) 133 | # Run and redirect all output to a file in the save folder of the job 134 | logger.info("Running " + job) 135 | os.system("CUDA_VISIBLE_DEVICES={} {}".format(gpu, job) + "> {} 2>&1".format(os.path.join(get_save_path(job), "last_run.out")) + " &") 136 | while get_last_modification(get_save_path(job)) > 600 or get_next_free_gpu() == gpu: 137 | print("Waiting for bootup (no HEARTBEAT or not occupied GPU)... last_mod={},next_free_gpu={},n_jobs={}".format( 138 | get_last_modification(get_save_path(job)), get_next_free_gpu(), get_n_jobs(script_name))) 139 | time.sleep(1) 140 | elif gpu is None: 141 | logger.warning("No free GPUs") 142 | elif max_jobs <= n_jobs: 143 | logger.warning("Have {} jobs running but can run max max_jobs={}".format(get_n_jobs(script_name), max_jobs)) 144 | else: 145 | raise NotImplementedError() 146 | else: 147 | logger.info("No jobs found") 148 | if get_n_jobs(script_name)==0: 149 | exit(0) 150 | # Allow to kill easilyq 151 | time.sleep(5) 152 | except KeyboardInterrupt: 153 | logger.warning("Interrupt. 
Killing all python & tensorboard jobs.") 154 | os.system("ps | grep python |awk '{print $1}' | xargs kill -9") 155 | os.system("ps | grep tensorboard |awk '{print $1}' | xargs kill -9") 156 | os.system("rm " + os.path.join(tb_dir, path.basename(root_save_path))) 157 | 158 | if __name__ == "__main__": 159 | argh.dispatch_command(run) 160 | -------------------------------------------------------------------------------- /pytorch_project_template/src/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minor utilities 3 | """ 4 | 5 | import sys 6 | from functools import reduce 7 | 8 | import traceback 9 | import logging 10 | import argparse 11 | import optparse 12 | import datetime 13 | import sys 14 | import pprint 15 | import types 16 | import time 17 | import copy 18 | import subprocess 19 | import glob 20 | from collections import OrderedDict 21 | import os 22 | import signal 23 | import atexit 24 | import json 25 | import inspect 26 | 27 | from logging import handlers 28 | 29 | import argh 30 | import gin 31 | from gin.config import _OPERATIVE_CONFIG 32 | 33 | import torch 34 | from torch.nn.modules.module import _addindent 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | def acc(y_pred, y_true): 39 | _, y_pred = y_pred.max(1) 40 | # _, y_true = y_true.max(1) 41 | acc_pred = (y_pred == y_true).float().mean() 42 | return acc_pred * 100 43 | 44 | def save_weights(model, optimizer, filename): 45 | """ 46 | Save all weights necessary to resume training 47 | """ 48 | state = { 49 | 'model': model.state_dict(), 50 | 'optimizer': optimizer.state_dict(), 51 | } 52 | torch.save(state, filename) 53 | 54 | from contextlib import contextmanager 55 | 56 | 57 | class Fork(object): 58 | def __init__(self, file1, file2): 59 | self.file1 = file1 60 | self.file2 = file2 61 | 62 | def write(self, data): 63 | self.file1.write(data) 64 | self.file2.write(data) 65 | 66 | def flush(self): 67 | self.file1.flush() 68 | self.file2.flush() 69 | 70 | 71 | @contextmanager 72 | def replace_logging_stream(file_): 73 | root = logging.getLogger() 74 | if len(root.handlers) != 1: 75 | print(root.handlers) 76 | raise ValueError("Don't know what to do with many handlers") 77 | if not isinstance(root.handlers[0], logging.StreamHandler): 78 | raise ValueError 79 | stream = root.handlers[0].stream 80 | root.handlers[0].stream = file_ 81 | try: 82 | yield 83 | finally: 84 | root.handlers[0].stream = stream 85 | 86 | 87 | @contextmanager 88 | def replace_standard_stream(stream_name, file_): 89 | stream = getattr(sys, stream_name) 90 | setattr(sys, stream_name, file_) 91 | try: 92 | yield 93 | finally: 94 | setattr(sys, stream_name, stream) 95 | 96 | def gin_wrap(fnc): 97 | def main(save_path, config, bindings=""): 98 | # You can pass many configs (think of them as mixins), and many bindings. Both ";" separated. 
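        # NOTE: the separator actually used below is "#": e.g. pass configs as
        # "a.gin#b.gin" and bindings as "x=1#y=2" (see the split/replace on the next line).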
99 | gin.parse_config_files_and_bindings(config.split("#"), bindings.replace("#", "\n")) 100 | if not os.path.exists(save_path): 101 | logger.info("Creating folder " + save_path) 102 | os.system("mkdir -p " + save_path) 103 | 104 | run_with_redirection(os.path.join(save_path, "stdout.txt"), 105 | os.path.join(save_path, "stderr.txt"), 106 | fnc)(save_path) 107 | 108 | argh.dispatch_command(main) 109 | 110 | 111 | def run_with_redirection(stdout_path, stderr_path, func): 112 | def func_wrapper(*args, **kwargs): 113 | with open(stdout_path, 'a', 1) as out_dst: 114 | with open(stderr_path, 'a', 1) as err_dst: 115 | out_fork = Fork(sys.stdout, out_dst) 116 | err_fork = Fork(sys.stderr, err_dst) 117 | with replace_standard_stream('stderr', err_fork): 118 | with replace_standard_stream('stdout', out_fork): 119 | with replace_logging_stream(err_fork): 120 | func(*args, **kwargs) 121 | 122 | return func_wrapper 123 | 124 | def configure_logger(name='', 125 | console_logging_level=logging.INFO, 126 | file_logging_level=None, 127 | log_file=None): 128 | """ 129 | Configures logger 130 | :param name: logger name (default=module name, __name__) 131 | :param console_logging_level: level of logging to console (stdout), None = no logging 132 | :param file_logging_level: level of logging to log file, None = no logging 133 | :param log_file: path to log file (required if file_logging_level not None) 134 | :return instance of Logger class 135 | """ 136 | 137 | if file_logging_level is None and log_file is not None: 138 | print("Didnt you want to pass file_logging_level?") 139 | 140 | if len(logging.getLogger(name).handlers) != 0: 141 | print("Already configured logger '{}'".format(name)) 142 | return 143 | 144 | if console_logging_level is None and file_logging_level is None: 145 | return # no logging 146 | 147 | logger = logging.getLogger(name) 148 | logger.handlers = [] 149 | logger.setLevel(logging.DEBUG) 150 | format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 151 | 152 | if console_logging_level is not None: 153 | ch = logging.StreamHandler(sys.stdout) 154 | ch.setFormatter(format) 155 | ch.setLevel(console_logging_level) 156 | logger.addHandler(ch) 157 | 158 | if file_logging_level is not None: 159 | if log_file is None: 160 | raise ValueError("If file logging enabled, log_file path is required") 161 | fh = handlers.RotatingFileHandler(log_file, maxBytes=(1048576 * 5), backupCount=7) 162 | fh.setFormatter(format) 163 | logger.addHandler(fh) 164 | 165 | logger.info("Logging configured!") 166 | 167 | return logger 168 | 169 | 170 | def summary(model, file=sys.stderr): 171 | def repr(model): 172 | # We treat the extra repr like the sub-module, one item per line 173 | extra_lines = [] 174 | extra_repr = model.extra_repr() 175 | # empty string will be split into list [''] 176 | if extra_repr: 177 | extra_lines = extra_repr.split('\n') 178 | child_lines = [] 179 | total_params = 0 180 | for key, module in model._modules.items(): 181 | mod_str, num_params = repr(module) 182 | mod_str = _addindent(mod_str, 2) 183 | child_lines.append('(' + key + '): ' + mod_str) 184 | total_params += num_params 185 | lines = extra_lines + child_lines 186 | 187 | for name, p in model._parameters.items(): 188 | total_params += reduce(lambda x, y: x * y, p.shape) 189 | 190 | main_str = model._get_name() + '(' 191 | if lines: 192 | # simple one-liner info, which most builtin Modules will use 193 | if len(extra_lines) == 1 and not child_lines: 194 | main_str += extra_lines[0] 195 | else: 196 | 
main_str += '\n ' + '\n '.join(lines) + '\n' 197 | 198 | main_str += ')' 199 | if file is sys.stderr: 200 | main_str += ', \033[92m{:,}\033[0m params'.format(total_params) 201 | else: 202 | main_str += ', {:,} params'.format(total_params) 203 | return main_str, total_params 204 | 205 | string, count = repr(model) 206 | if file is not None: 207 | print(string, file=file) 208 | return count -------------------------------------------------------------------------------- /pytorch_project_template/src/data/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Example datasets: cifar and mnist 4 | """ 5 | import gin 6 | import logging 7 | import os 8 | import numpy as np 9 | import h5py 10 | from keras import datasets 11 | from keras.datasets import mnist, fashion_mnist 12 | from keras.utils import np_utils 13 | 14 | from keras.preprocessing import sequence 15 | from keras.datasets import imdb as load_imdb 16 | 17 | from src import DATA_DIR 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | ROOT_DIR = os.path.join(os.path.dirname(__file__), "..") 22 | FMNIST_DIR = os.path.join(ROOT_DIR, "data/fmnist") 23 | 24 | @gin.configurable 25 | def cifar(which=10, preprocessing="center", seed=777, use_valid=True): 26 | rng = np.random.RandomState(seed) 27 | meta_data = {} 28 | 29 | if which == 10: 30 | (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data() 31 | elif which == 100: 32 | (x_train, y_train), (x_test, y_test) = datasets.cifar100.load_data() 33 | else: 34 | raise NotImplementedError(which) 35 | 36 | # Minor conversions 37 | x_train = x_train.astype("float32") 38 | x_test = x_test.astype("float32") 39 | y_train = y_train.astype("long").reshape(-1,) 40 | y_test = y_test.astype("long").reshape(-1,) 41 | 42 | # Always outputs channels first 43 | if x_train.shape[-1] == 3: 44 | logging.info("Transposing") 45 | x_train = x_train.transpose((0, 3, 1, 2)) 46 | x_test = x_test.transpose((0, 3, 1, 2)) 47 | 48 | if use_valid: 49 | # Some randomization to make sure 50 | ids = rng.choice(len(x_train), len(x_train), replace=False) 51 | assert len(set(ids)) == len(ids) == len(x_train) 52 | x_train = x_train[ids] 53 | y_train = y_train[ids] 54 | 55 | N_valid = int(len(x_train) * 0.1) 56 | 57 | assert len(x_train) == 50000, len(x_train) 58 | assert N_valid == 5000 59 | 60 | (x_train, y_train), (x_valid, y_valid) = (x_train[0:-N_valid], y_train[0:-N_valid]), \ 61 | (x_train[-N_valid:], y_train[-N_valid:]) 62 | 63 | meta_preprocessing = {"type": preprocessing} 64 | if preprocessing == "center": 65 | # This (I think) follows the original resnet paper. 
Per-pixel mean 66 | # and the global std, computed using the train set 67 | mean = np.mean(x_train, axis=0, keepdims=True) # Pixel mean 68 | std = np.std(x_train) 69 | meta_preprocessing['mean'] = mean 70 | meta_preprocessing['std'] = std 71 | x_train = (x_train - mean) / std 72 | x_test = (x_test - mean) / std 73 | if use_valid: 74 | x_valid = (x_valid - mean) / std 75 | elif preprocessing == "01": # Required by scatnet 76 | x_train = x_train / 255.0 77 | x_test = x_test / 255.0 78 | if use_valid: 79 | x_valid = x_valid / 255.0 80 | else: 81 | raise NotImplementedError("Not implemented preprocessing " + preprocessing) 82 | 83 | logging.info('x_train shape:' + str(x_train.shape)) 84 | logging.info(str(x_train.shape[0]) + 'train samples') 85 | logging.info(str(x_test.shape[0]) + 'test samples') 86 | if use_valid: 87 | logging.info(str(x_valid.shape[0]) + 'valid samples') 88 | logging.info('y_train shape:' + str(y_train.shape)) 89 | 90 | # Prepare test 91 | train = [x_train, y_train] 92 | test = [x_test, y_test] 93 | if use_valid: 94 | valid = [x_valid, y_valid] 95 | 96 | w, h, c = train[0].shape[1:4] 97 | meta_data['input_dim'] = (w, h, c) 98 | meta_data['preprocessing'] = meta_preprocessing 99 | 100 | if use_valid: 101 | return train, valid, test, meta_data 102 | else: 103 | return train, test, test, meta_data 104 | 105 | 106 | @gin.configurable 107 | def mnist(which="fmnist", preprocessing="01", seed=777, use_valid=True): 108 | """ 109 | Returns 110 | ------- 111 | (x_train, y_train), (x_valid, y_valid), (x_test, y_test) 112 | """ 113 | rng = np.random.RandomState(seed) 114 | meta_data = {} 115 | 116 | if use_valid: 117 | logger.info("Using valid") 118 | else: 119 | logger.info("Using as valid test") 120 | 121 | if which == "mnist": 122 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 123 | x_train = x_train.reshape(-1, 28, 28, 1) 124 | x_test = x_test.reshape(-1, 28, 28, 1) 125 | elif which == "fmnist": 126 | (X_train, y_train), (X_test, y_test) = fashion_mnist.load_data() 127 | X_train, y_train = np.array(X_train).astype("float32"), np.array(y_train) 128 | X_test, y_test = np.array(X_test).astype("float32"), np.array(y_test) 129 | x_train = X_train.reshape(-1, 1, 28, 28) 130 | x_test = X_test.reshape(-1, 1, 28, 28) 131 | else: 132 | raise NotImplementedError() 133 | 134 | y_train = y_train.astype("long").reshape(-1,) 135 | y_test = y_test.astype("long").reshape(-1,) 136 | 137 | # Permute 138 | ids_train = rng.choice(len(x_train), len(x_train), replace=False) 139 | ids_test = rng.choice(len(x_test), len(x_test), replace=False) 140 | x_train, y_train = x_train[ids_train], y_train[ids_train] 141 | x_test, y_test = x_test[ids_test], y_test[ids_test] 142 | 143 | logger.info("Loaded dataset using eval") 144 | 145 | if use_valid: 146 | assert len(x_train) == 60000, len(x_train) 147 | (x_train, y_train), (x_valid, y_valid) = (x_train[0:50000], y_train[0:50000]), \ 148 | (x_train[-10000:], y_train[-10000:]) 149 | 150 | if preprocessing == "center": 151 | mean = np.mean(x_train, axis=0, keepdims=True) # Pixel mean 152 | # Complete std as in https://github.com/gcr/torch-residual-networks/blob/master/data/cifar-dataset.lua#L3 153 | std = np.std(x_train) 154 | x_train = (x_train - mean) / std 155 | x_test = (x_test - mean) / std 156 | if use_valid: 157 | x_valid = (x_valid - mean) / std 158 | elif preprocessing == "01": # Required by scatnet 159 | x_train = x_train / 255.0 160 | x_test = x_test / 255.0 161 | if use_valid: 162 | x_valid = x_valid / 255.0 163 | else: 164 | raise 
NotImplementedError("Not implemented preprocessing " + preprocessing) 165 | 166 | logger.info('x_train shape:' + str(x_train.shape)) 167 | logger.info(str(x_train.shape[0]) + 'train samples') 168 | logger.info(str(x_test.shape[0]) + 'test samples') 169 | if use_valid: 170 | logger.info(str(x_valid.shape[0]) + 'valid samples') 171 | logger.info('y_train shape:' + str(y_train.shape)) 172 | 173 | # Prepare test 174 | train = [x_train, y_train] 175 | test = [x_test, y_test] 176 | if use_valid: 177 | valid = [x_valid, y_valid] 178 | 179 | w, h = meta_data['x_train'].shape[1:3] 180 | n_channels = meta_data['x_train'].shape[3] 181 | logger.info((w, h, n_channels)) 182 | 183 | if use_valid: 184 | return train, valid, test, meta_data 185 | else: 186 | # Using as valid test 187 | return train, test, test, meta_data 188 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/data/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Utils for datasets. 4 | """ 5 | import gin 6 | import logging 7 | import numpy as np 8 | 9 | from src import DATA_FORMAT 10 | # from src.data.streams import DatasetGenerator 11 | from src import DATA_DIR, DATA_FORMAT, DATA_NUM_WORKERS 12 | 13 | from os.path import join 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | import torchvision.transforms as T 18 | import torch 19 | from torch.utils.data import Sampler 20 | from torch.utils.data import DataLoader, Dataset 21 | from torchvision.datasets import ImageFolder 22 | 23 | import numpy as np 24 | 25 | 26 | class DatasetFromNumpy(Dataset): 27 | """TensorDataset with support of transforms. 28 | """ 29 | 30 | def __init__(self, ds): 31 | self.ds = ds 32 | 33 | def __getitem__(self, index): 34 | x, y = self.ds[0][index], self.ds[1][index] 35 | 36 | return x, y 37 | 38 | def __len__(self): 39 | assert len(self.ds[0]) == len(self.ds[1]) 40 | return len(self.ds[0]) 41 | 42 | 43 | class TransformedDataset(Dataset): 44 | """TensorDataset with support of transforms. 45 | """ 46 | 47 | def __init__(self, ds, transform=None): 48 | self.ds = ds 49 | self.transform = transform 50 | 51 | def __getitem__(self, index): 52 | x, y = self.ds[index] 53 | 54 | if self.transform: 55 | x = self.transform(x) 56 | 57 | return x, y 58 | 59 | def __len__(self): 60 | return len(self.ds) 61 | 62 | 63 | class ShuffledDataset(TransformedDataset): 64 | """TensorDataset with support of transforms. 
65 | """ 66 | 67 | def __init__(self, ds, rand_frac, n_classes, transform=None, seed=0): 68 | super().__init__(ds, transform=transform) 69 | self.create_shuffled(rand_frac, n_classes, seed) 70 | 71 | def create_shuffled(self, rand_frac, n_classes, seed): 72 | # TOD: use the same random state 73 | rng = np.random.RandomState(seed) 74 | 75 | inds = self.ds.indices 76 | targets = np.array(self.ds.dataset.targets) 77 | n = len(inds) 78 | n_rand = int(n * rand_frac) 79 | 80 | # TODO: make it work for arbitrarily many classes 81 | rand_inds = rng.choice(inds, n_rand, replace=False) 82 | rand_labels = rng.randint(0, n_classes, (n_rand,)) 83 | targets[rand_inds] = rand_labels 84 | self.ds.dataset.targets = targets.tolist() 85 | 86 | real_inds = np.array(list(set(inds) - set(rand_inds))) 87 | 88 | self.real_ds = torch.utils.data.Subset(self.ds.dataset, real_inds) 89 | self.real_data = TransformedDataset(self.real_ds, self.transform) 90 | self.noisy_ds = torch.utils.data.Subset(self.ds.dataset, rand_inds) 91 | self.noisy_data = TransformedDataset(self.noisy_ds, self.transform) 92 | 93 | 94 | def construct_generators_and_meta(train, valid, test, seed, batch_size, stream_seed, workers=DATA_NUM_WORKERS, 95 | rand_frac=0, pin_memory=True): 96 | """ 97 | A helper function that converts data sources into generators and meta data usable with the rest of the code 98 | 99 | Returns: a tuple of (train_generator, valid_generator, test_generator, meta_data) 100 | 101 | Note 102 | ---- 103 | Assumes labels are one hot encoded. 104 | """ 105 | rng_stream = np.random.RandomState(stream_seed) 106 | 107 | train_generator = construct_generator(train, workers=workers, shuffle=True, batch_size=batch_size, rng=rng_stream, 108 | pin_memory=pin_memory) 109 | train_duplicated_generator = construct_generator(train, workers=1, shuffle=True, batch_size=batch_size, 110 | rng=rng_stream, pin_memory=pin_memory) 111 | valid_generator = construct_generator(valid, workers=workers, shuffle=False, batch_size=batch_size, 112 | pin_memory=pin_memory) 113 | valid_duplicated_generator = construct_generator(valid, workers=workers, shuffle=False, batch_size=batch_size, 114 | pin_memory=pin_memory) 115 | test_generator = construct_generator(test, 1, shuffle=False, batch_size=batch_size, pin_memory=pin_memory) 116 | 117 | meta_data = construct_meta_data(train=train, valid=valid, test=test, 118 | train_duplicated_generator=train_duplicated_generator, 119 | valid_duplicated_generator=valid_duplicated_generator) 120 | 121 | return train_generator, valid_generator, test_generator, meta_data 122 | 123 | 124 | def construct_meta_data(train, valid, test, 125 | train_duplicated_generator, 126 | valid_duplicated_generator): 127 | meta_data = {} 128 | meta_data['train_ds'] = train 129 | meta_data['valid_ds'] = valid 130 | meta_data['test_ds'] = test 131 | meta_data['n_examples_train'] = len(train) 132 | meta_data['n_examples_valid'] = len(valid) 133 | meta_data['n_examples_test'] = len(test) 134 | meta_data['train_stream_duplicated'] = train_duplicated_generator 135 | meta_data['valid_stream_duplicated'] = valid_duplicated_generator 136 | meta_data['input_dim'] = meta_data['input_shape'] = train[0][0].shape # Take first element and the image of it 137 | if isinstance(train[0][1], int) or isinstance(train[0][1], float): 138 | max_label = 0 139 | # TODO: This might be slow for large datasets 140 | for x, y in train: 141 | max_label = max(max_label, y) 142 | meta_data['num_classes'] = max_label 143 | meta_data['one_hot'] = False 144 | else: 145 | 
meta_data['num_classes'] = len(train[0][1]) # We assume labels are one hot encoded 146 | meta_data['one_hot'] = True 147 | return meta_data 148 | 149 | 150 | def construct_generator(ds, workers, shuffle, batch_size, rng=None, pin_memory=True): 151 | # TODO: Is this collate function the optimal way to do it? 152 | # TODO: Generalize this. It works only for specific shapes of x and y 153 | def collate_fn(xy): 154 | if hasattr(xy[0][0], 'numpy'): 155 | return (torch.stack([x for x, y in xy]), torch.tensor([y for x, y in xy])) 156 | else: 157 | return (np.stack([x for x, y in xy]), np.array([y for x, y in xy])) 158 | 159 | 160 | class SeededRandomSampler(Sampler): 161 | r"""Default PyTorch sampler is not seeded. 162 | """ 163 | 164 | def __init__(self, data_source, rng, num_samples=None): 165 | self.data_source = data_source 166 | self.rng = rng 167 | self._num_samples = num_samples 168 | 169 | @property 170 | def num_samples(self): 171 | if self._num_samples is None: 172 | return len(self.data_source) 173 | return self._num_samples 174 | 175 | def __iter__(self): 176 | n = len(self.data_source) 177 | return iter(self.rng.choice(n, n, replace=False)) 178 | 179 | def __len__(self): 180 | return self.num_samples 181 | 182 | if shuffle: 183 | sampler = SeededRandomSampler(data_source=ds, rng=rng) 184 | else: 185 | sampler = None 186 | # pin_memory should just speedup 187 | return DataLoader(ds, batch_size=batch_size, collate_fn=collate_fn, num_workers=workers, sampler=sampler, 188 | pin_memory=pin_memory) 189 | -------------------------------------------------------------------------------- /tf2_project_template/src/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minor utilities 3 | """ 4 | 5 | import csv 6 | import h5py 7 | import logging 8 | import os 9 | import sys 10 | import argh 11 | import gin 12 | import copy 13 | import tensorflow 14 | import os 15 | 16 | from contextlib import contextmanager 17 | from functools import reduce 18 | from logging import handlers 19 | from gin.config import _CONFIG 20 | from tensorflow import keras 21 | 22 | 23 | custom_tf_objects = {} 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | _NEPTUNE = {} 28 | 29 | try: 30 | import torch 31 | from torch import nn 32 | from torch.nn import Module 33 | except ImportError: 34 | class Module(): 35 | pass 36 | 37 | try: 38 | import neptune 39 | except ImportError: 40 | pass 41 | 42 | def save_model(model, optimizer, filename): 43 | """ 44 | Save all weights and state necessary to resume training 45 | """ 46 | if isinstance(model, keras.Model): 47 | if not filename.endswith(".h5"): 48 | filename = filename + ".h5" 49 | tensorflow.keras.models.save_model(model, h5py.File(filename, 'w'), 50 | overwrite=True, include_optimizer=True, save_format="h5") 51 | else: 52 | raise NotImplementedError() 53 | 54 | 55 | def restore_model(model, filename): 56 | if isinstance(model, keras.Model): 57 | if not filename.endswith(".h5"): 58 | filename += ".h5" 59 | model.load_weights(filename, by_name=True) 60 | return model 61 | else: 62 | raise NotImplementedError() 63 | 64 | 65 | def restore_model_and_optimizer(model, optimizer, filename): 66 | if isinstance(model, keras.Model): 67 | if not filename.endswith(".h5"): 68 | filename += ".h5" 69 | model = keras.models.load_model(h5py.File(filename), custom_objects=custom_tf_objects) 70 | optimizer = model.optimizer 71 | return model, optimizer 72 | else: 73 | raise NotImplementedError() 74 | 75 | 76 | class Fork(object): 77 | def 
__init__(self, file1, file2): 78 | self.file1 = file1 79 | self.file2 = file2 80 | 81 | def write(self, data): 82 | self.file1.write(data) 83 | self.file2.write(data) 84 | 85 | def flush(self): 86 | self.file1.flush() 87 | self.file2.flush() 88 | 89 | 90 | @contextmanager 91 | def replace_logging_stream(file_): 92 | root = logging.getLogger() 93 | if len(root.handlers) != 1: 94 | print(root.handlers) 95 | raise ValueError("Don't know what to do with many handlers") 96 | if not isinstance(root.handlers[0], logging.StreamHandler): 97 | raise ValueError 98 | stream = root.handlers[0].stream 99 | root.handlers[0].stream = file_ 100 | try: 101 | yield 102 | finally: 103 | root.handlers[0].stream = stream 104 | 105 | 106 | @contextmanager 107 | def replace_standard_stream(stream_name, file_): 108 | stream = getattr(sys, stream_name) 109 | setattr(sys, stream_name, file_) 110 | try: 111 | yield 112 | finally: 113 | setattr(sys, stream_name, stream) 114 | 115 | 116 | def gin_wrap(fnc): 117 | def main(save_path, config, bindings=""): 118 | # You can pass many configs (think of them as mixins), and many bindings. Both ";" separated. 119 | gin.parse_config_files_and_bindings(config.split("#"), bindings.replace("#", "\n")) 120 | if not os.path.exists(save_path): 121 | logger.info("Creating folder " + save_path) 122 | os.system("mkdir -p " + save_path) 123 | 124 | run_with_redirection(os.path.join(save_path, "stdout.txt"), 125 | os.path.join(save_path, "stderr.txt"), 126 | fnc)(save_path) 127 | 128 | argh.dispatch_command(main) 129 | 130 | 131 | def run_with_redirection(stdout_path, stderr_path, func): 132 | def func_wrapper(*args, **kwargs): 133 | with open(stdout_path, 'a', 1) as out_dst: 134 | with open(stderr_path, 'a', 1) as err_dst: 135 | out_fork = Fork(sys.stdout, out_dst) 136 | err_fork = Fork(sys.stderr, err_dst) 137 | with replace_standard_stream('stderr', err_fork): 138 | with replace_standard_stream('stdout', out_fork): 139 | with replace_logging_stream(err_fork): 140 | func(*args, **kwargs) 141 | 142 | return func_wrapper 143 | 144 | 145 | def configure_neptune_exp(name): 146 | global _NEPTUNE 147 | if 'NEPTUNE_TOKEN' not in os.environ: 148 | logger.warning("Neptune couldn't be configured. Couldn't find NEPTUNE_TOKEN. 
") 149 | return 150 | 151 | NEPTUNE_TOKEN = os.environ['NEPTUNE_TOKEN'] 152 | NEPTUNE_USER = os.environ['NEPTUNE_USER'] 153 | NEPTUNE_PROJECT = os.environ['NEPTUNE_PROJECT'] 154 | C = copy.deepcopy(_CONFIG) 155 | C = {k[1].split(".")[-1]: v for k, v in C.items()} # Hacky way to simplify config 156 | logger.info("Initializing neptune to name " + name) 157 | project = neptune.Session(NEPTUNE_TOKEN).get_project(f'{NEPTUNE_USER}/{NEPTUNE_PROJECT}') 158 | exp = project.create_experiment(name=name, params=C) 159 | _NEPTUNE['default'] = exp 160 | logger.info("Initialized neptune") 161 | 162 | def get_neptune_exp(name=None): 163 | global _NEPTUNE 164 | if name is not None: 165 | return _NEPTUNE[name] 166 | else: 167 | return _NEPTUNE['default'] 168 | 169 | 170 | def configure_logger(name='', 171 | console_logging_level=logging.INFO, 172 | file_logging_level=None, 173 | log_file=None): 174 | """ 175 | Configures logger 176 | :param name: logger name (default=module name, __name__) 177 | :param console_logging_level: level of logging to console (stdout), None = no logging 178 | :param file_logging_level: level of logging to log file, None = no logging 179 | :param log_file: path to log file (required if file_logging_level not None) 180 | :return instance of Logger class 181 | """ 182 | 183 | if file_logging_level is None and log_file is not None: 184 | print("Didnt you want to pass file_logging_level?") 185 | 186 | if len(logging.getLogger(name).handlers) != 0: 187 | print("Already configured logger '{}'".format(name)) 188 | return 189 | 190 | if console_logging_level is None and file_logging_level is None: 191 | return # no logging 192 | 193 | logger = logging.getLogger(name) 194 | logger.handlers = [] 195 | logger.setLevel(logging.DEBUG) 196 | format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 197 | 198 | if console_logging_level is not None: 199 | ch = logging.StreamHandler(sys.stdout) 200 | ch.setFormatter(format) 201 | ch.setLevel(console_logging_level) 202 | logger.addHandler(ch) 203 | 204 | if file_logging_level is not None: 205 | if log_file is None: 206 | raise ValueError("If file logging enabled, log_file path is required") 207 | fh = handlers.RotatingFileHandler(log_file, maxBytes=(1048576 * 5), backupCount=7) 208 | fh.setFormatter(format) 209 | logger.addHandler(fh) 210 | 211 | logger.info("Logging configured!") 212 | 213 | return logger 214 | 215 | 216 | def dict_to_csv(path, d): 217 | with open(path, 'w') as f: 218 | for k, v in d.items(): 219 | f.write(f'{k}, {v}\n') 220 | 221 | 222 | def csv_to_dict(path): 223 | reader = csv.reader(open(path, "r")) 224 | d = {} 225 | for k, v in reader: 226 | d[k] = v 227 | return d 228 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/data/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Datasets used in the project. 4 | 5 | Each dataset is returned as a tuple (train_loader, valid_loader, test_loader, meta_data) 6 | 7 | Uses PyTorch due to the widespread code for dataset handling in PyTorch and standardized augmentations 8 | for image datasets in particular. 
9 | """ 10 | import logging 11 | import gin 12 | import numpy as np 13 | import torch 14 | import torchvision.transforms as T 15 | import torchvision 16 | from PIL import Image 17 | import PIL 18 | 19 | from os.path import join 20 | from torch.utils.data import Subset 21 | 22 | from src import DATA_DIR, DATA_NUM_WORKERS 23 | from src.data.utils import construct_generators_and_meta, TransformedDataset, ShuffledDataset, DatasetFromNumpy 24 | 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | TINY_IMAGENET_PATH = join(DATA_DIR, 'tiny-imagenet-200') 29 | 30 | @gin.configurable 31 | def cifar(use_valid, seed=777, stream_seed=777, variant="10", augment=True, batch_size=128, rand_frac=0): 32 | rng = np.random.RandomState(seed) 33 | cifar_mean = (0.4914, 0.4822, 0.4465) 34 | cifar_std = (0.2023, 0.1994, 0.2010) 35 | 36 | if augment is True: 37 | transform_train = T.Compose([ 38 | T.RandomCrop(32, padding=4), 39 | T.RandomHorizontalFlip(), 40 | T.ToTensor(), 41 | T.Normalize(cifar_mean, cifar_std), 42 | ]) 43 | else: 44 | transform_train = T.Compose([ 45 | T.ToTensor(), 46 | T.Normalize(cifar_mean, cifar_std), 47 | ]) 48 | transform_test = T.Compose([ 49 | T.ToTensor(), 50 | T.Normalize(cifar_mean, cifar_std), 51 | ]) 52 | 53 | if variant == '10': 54 | n_classes = 10 55 | def to_one_hot(y): 56 | # CHANGE N CLASSES 57 | one_hot = np.zeros(shape=(10,)) 58 | one_hot[y] = 1 59 | one_hot = torch.tensor(one_hot) 60 | return one_hot 61 | trainval = torchvision.datasets.CIFAR10( 62 | root=join(DATA_DIR, 'cifar100'), train=True, download=True, transform=None, target_transform=None) 63 | test = torchvision.datasets.CIFAR10( 64 | root=join(DATA_DIR, 'cifar100'), train=False, download=True, transform=None, target_transform=None) 65 | assert len(trainval) == 50000 66 | elif variant in {'100', '100a', '100b'}: 67 | n_classes = 100 68 | def to_one_hot(y): 69 | # CHANGE N CLASSES 70 | one_hot = np.zeros(shape=(100,)) 71 | one_hot[y] = 1 72 | one_hot = torch.tensor(one_hot) 73 | return one_hot 74 | trainval = torchvision.datasets.CIFAR100( 75 | root=join(DATA_DIR, 'cifar100'), train=True, download=True, transform=None, target_transform=None) 76 | test = torchvision.datasets.CIFAR100( 77 | root=join(DATA_DIR, 'cifar100'), train=False, download=True, transform=None, target_transform=None) 78 | assert len(trainval) == 50000 79 | elif variant in {'100c', '100d'}: 80 | n_classes = 100 81 | def to_one_hot(y): 82 | # CHANGE N CLASSES 83 | one_hot = np.zeros(shape=(100,)) 84 | one_hot[y] = 1 85 | one_hot = torch.tensor(one_hot) 86 | return one_hot 87 | trainval = torchvision.datasets.CIFAR100( 88 | root=join(DATA_DIR, 'cifar100'), train=True, download=True, transform=None, target_transform=None) 89 | test = torchvision.datasets.CIFAR100( 90 | root=join(DATA_DIR, 'cifar100'), train=False, download=True, transform=None, target_transform=None) 91 | assert len(trainval) == 50000 92 | else: 93 | raise NotImplementedError() 94 | 95 | # Split here to ensure different validation set 96 | ids = np.arange(len(trainval)) 97 | rng.shuffle(ids) 98 | if variant.endswith("b"): 99 | trainval = Subset(trainval, ids[len(trainval)//2:]) 100 | elif variant.endswith("a"): 101 | trainval = Subset(trainval, ids[0:len(trainval)//2]) 102 | 103 | if use_valid: 104 | ids = rng.choice(len(trainval), len(trainval), replace=False) 105 | N_valid = int(len(trainval) * 0.1) 106 | ids_train, ids_val = ids[0:-N_valid], ids[-N_valid:] 107 | train, valid = Subset(trainval, ids_train), Subset(trainval, ids_val) 108 | assert len(valid) == int(0.1 * 
len(trainval)) and len(train) == int(0.9 * len(trainval)) 109 | else: 110 | train, valid = trainval, test 111 | 112 | # Same valid, weirdly but OK. 113 | ids = np.arange(len(train)) 114 | rng.shuffle(ids) 115 | if variant.endswith("c"): 116 | train = Subset(train, ids[len(train)//2:]) 117 | elif variant.endswith("d"): 118 | train = Subset(train, ids[0:len(train)//2]) 119 | 120 | if rand_frac > 0: 121 | train = ShuffledDataset(train, rand_frac=rand_frac, n_classes=n_classes, transform=transform_train, seed=seed) 122 | else: 123 | train = TransformedDataset(train, transform=transform_train) 124 | test = TransformedDataset(test, transform=transform_test) 125 | valid = TransformedDataset(valid, transform=transform_test) 126 | 127 | return construct_generators_and_meta(train, valid, test, batch_size=batch_size, seed=seed, stream_seed=stream_seed, 128 | workers=DATA_NUM_WORKERS) 129 | 130 | @gin.configurable 131 | def stl10(use_valid, seed=777, stream_seed=777, augment=True, batch_size=128): 132 | rng = np.random.RandomState(seed) 133 | resize = T.Lambda(lambda x: x.resize((32, 32), resample=PIL.Image.BOX)) 134 | 135 | def to_one_hot(y): 136 | one_hot = np.zeros(shape=(10,)) 137 | one_hot[y] = 1 138 | one_hot = torch.tensor(one_hot) 139 | return one_hot 140 | if augment is True: 141 | transform_train = [ 142 | T.RandomCrop(32, padding=4), 143 | T.RandomHorizontalFlip(), 144 | T.ToTensor() 145 | ] 146 | else: 147 | transform_train = [ 148 | T.ToTensor() 149 | ] 150 | transform_test = [ 151 | T.ToTensor(), 152 | ] 153 | 154 | trainval = torchvision.datasets.STL10( 155 | root=join(DATA_DIR, 'stl10'), split="train", download=True, transform=resize, target_transform=None) 156 | test = torchvision.datasets.STL10( 157 | root=join(DATA_DIR, 'stl10'), split="test", download=True, transform=resize, target_transform=None) 158 | 159 | assert len(trainval) == 5000 160 | 161 | if use_valid: 162 | ids = rng.choice(len(trainval), len(trainval), replace=False) 163 | N_valid = int(len(trainval) * 0.1) 164 | ids_train, ids_val = ids[0:-N_valid], ids[-N_valid:] 165 | train, valid = Subset(trainval, ids_train), Subset(trainval, ids_val) 166 | else: 167 | train, valid = trainval, test 168 | 169 | # Compute the standard channel-wise normalization by quickly loading dataset to memory 170 | X = [T.ToTensor()(train[i][0]).numpy() for i in range(len(train))] 171 | X = np.array(X) 172 | assert X.shape[1] == 3, X.shape 173 | stl10_mean = np.mean(X, axis=(0, 2, 3)) 174 | stl10_std = np.std(X, axis=(0, 2, 3)) 175 | transform_test.append(T.Normalize(stl10_mean, stl10_std)) 176 | transform_train.append(T.Normalize(stl10_mean, stl10_std)) 177 | 178 | train = TransformedDataset(train, transform=T.Compose(transform_train)) 179 | test = TransformedDataset(test, transform=T.Compose(transform_test)) 180 | valid = TransformedDataset(valid, transform=T.Compose(transform_test)) 181 | 182 | return construct_generators_and_meta(train, valid, test, batch_size=batch_size, seed=seed, stream_seed=stream_seed, 183 | workers=DATA_NUM_WORKERS) 184 | 185 | 186 | 187 | if __name__ == "__main__": 188 | train, valid, test, meta_data = cifar(seed=1, stream_seed=1, batch_size=128, augment=False, use_valid=True) 189 | for x, y in train: 190 | break 191 | 192 | x = x.numpy() 193 | 194 | import matplotlib.pylab as plt 195 | for xx in x[0:4]: 196 | plt.imshow(xx.transpose(1, 2, 0)) 197 | plt.show() 198 | plt.close() 199 | xx = [] 200 | for x, y in train: 201 | xx.append(x) 202 | if len(xx) > 40: 203 | break 204 | print("BATCH SHAPE") 205 | print(x.shape) 
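# A minimal sanity-check note (added as an illustration; `xx` and the normalization constants come from the code above):
# because the loaders apply T.Normalize(cifar_mean, cifar_std), the per-channel statistics printed below should come out
# close to mean 0 and std 1 when augment=False; values far from that usually indicate a wrong transform or channel order.
#   batch = np.concatenate(xx, axis=0)                      # stacked (N, C, H, W) batches
#   assert np.allclose(np.mean(batch, axis=(0, 2, 3)), 0.0, atol=0.1)
#   assert np.allclose(np.std(batch, axis=(0, 2, 3)), 1.0, atol=0.1)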
206 | print("MEAN") 207 | assert np.concatenate(xx, axis=0).shape[1] == 3 208 | print(np.mean(np.concatenate(xx, axis=0), axis=(0, 2, 3))) 209 | print("STD") 210 | print(np.std(np.concatenate(xx, axis=0), axis=(0, 2, 3))) 211 | print("MAX") 212 | print(np.max(np.max(np.concatenate(xx, axis=0), axis=0))) 213 | print("MIN") 214 | print(np.min(np.min(np.concatenate(xx, axis=0), axis=0))) 215 | -------------------------------------------------------------------------------- /tf2_project_template/bin/utils/run_slurm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | A simple self-contained script to run on a list of jobs on a SLURM system 5 | 6 | You will need to update bin/utils/slurm_template.sh 7 | 8 | See bin/utils/run_slurm.py -h 9 | 10 | There are the following requirements for the batch: 11 | * Each command has saving dir as 2nd argument 12 | * Each script saves to the save_dir HEARTBEAT 13 | * Each script saves to the save_dir FINISHED when done 14 | 15 | When using: figure out which queue is better. Bsub doesn't like multiple queues. 16 | 17 | When adapting to a new project make sure that get_save_path is OK, and that you replace the project name. 18 | """ 19 | 20 | import os 21 | import time 22 | import stat 23 | from os import path 24 | import subprocess 25 | import logging 26 | import random 27 | import string 28 | from logging import handlers 29 | import sys 30 | 31 | def configure_logger(name='', 32 | console_logging_level=logging.INFO, 33 | file_logging_level=None, 34 | log_file=None): 35 | """ 36 | Configures logger 37 | :param name: logger name (default=module name, __name__) 38 | :param console_logging_level: level of logging to console (stdout), None = no logging 39 | :param file_logging_level: level of logging to log file, None = no logging 40 | :param log_file: path to log file (required if file_logging_level not None) 41 | :return instance of Logger class 42 | """ 43 | 44 | if file_logging_level is None and log_file is not None: 45 | print("Didnt you want to pass file_logging_level?") 46 | 47 | if len(logging.getLogger(name).handlers) != 0: 48 | print("Already configured logger '{}'".format(name)) 49 | return 50 | 51 | if console_logging_level is None and file_logging_level is None: 52 | return # no logging 53 | 54 | logger = logging.getLogger(name) 55 | logger.handlers = [] 56 | logger.setLevel(logging.DEBUG) 57 | format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 58 | 59 | if console_logging_level is not None: 60 | ch = logging.StreamHandler(sys.stdout) 61 | ch.setFormatter(format) 62 | ch.setLevel(console_logging_level) 63 | logger.addHandler(ch) 64 | 65 | if file_logging_level is not None: 66 | if log_file is None: 67 | raise ValueError("If file logging enabled, log_file path is required") 68 | fh = handlers.RotatingFileHandler(log_file, maxBytes=(1048576 * 5), backupCount=7) 69 | fh.setFormatter(format) 70 | logger.addHandler(fh) 71 | 72 | logger.info("Logging configured!") 73 | 74 | return logger 75 | 76 | 77 | 78 | configure_logger('', log_file=None) 79 | 80 | RESULTS_DIR = os.environ.get("RESULTS_DIR", os.path.join(os.path.dirname(__file__), "results")) 81 | TIMEOUT = 180 # 3 minutes. 
Might be short for some 82 | 83 | logger = logging.getLogger(__name__) 84 | 85 | 86 | def random_string(strlen=10): 87 | """Generate a random string of fixed length """ 88 | letters = string.ascii_lowercase 89 | return ''.join(random.choice(letters) for i in range(strlen)) 90 | 91 | 92 | def get_last_modification(save_path): 93 | f_path = os.path.join(save_path, "HEARTBEAT") 94 | stderr_path = os.path.join(save_path, "stderr.txt") 95 | if os.path.exists(f_path): 96 | return time.time() - os.stat(f_path)[stat.ST_MTIME] 97 | else: 98 | if os.path.exists(stderr_path): 99 | return time.time() - os.stat(stderr_path)[stat.ST_MTIME] 100 | else: 101 | return 10000000000000 102 | 103 | def is_running(job_id): 104 | # A hacky way to chekc if a job is running 105 | output = subprocess.check_output("squeue -u jastrs01", shell=True).decode("utf-8").strip() 106 | if output.find(str(job_id)) == -1: 107 | return False 108 | else: 109 | return True 110 | 111 | def get_n_jobs(batch_name): 112 | try: 113 | output = subprocess.check_output('squeue -u jastrs01 | grep {}'.format(batch_name), shell=True).decode("utf-8").strip() 114 | except: 115 | print("No jobs") 116 | output = "" 117 | if len(output) == 0: 118 | return 0 119 | else: 120 | return len(output.split("\n")) - 1 # -1 Because header 121 | 122 | def get_save_path(job): 123 | if job.startswith("python"): 124 | return job.split(" ")[2] 125 | else: 126 | return job.split(" ")[1] 127 | 128 | 129 | def get_script_name(job): 130 | if job.startswith("python"): 131 | return job.split(" ")[1] 132 | else: 133 | return job.split(" ")[0] 134 | 135 | 136 | def has_finished(save_path): 137 | # Hacky but works, usually 138 | return path.exists(path.join(save_path, "FINISHED")) 139 | 140 | 141 | def get_jobs(batch): 142 | if ";" in batch: 143 | batches = batch.split(";") 144 | else: 145 | batches = [batch] 146 | 147 | all_jobs = [] 148 | for batch in batches: 149 | jobs = list(open(batch, "r").read().splitlines()) 150 | jobs = [j for j in jobs if j[0] != "#"] 151 | jobs = [j for j in jobs if not has_finished(get_save_path(j))] 152 | # take only at least 10min old jobs 153 | jobs = [j for j in jobs if get_last_modification(get_save_path(j)) > 600] 154 | jobs = [("python " + j) if "python" not in j else j for j in jobs] 155 | all_jobs += jobs 156 | 157 | random.shuffle(all_jobs) 158 | return all_jobs 159 | 160 | def tensorboard_running(): 161 | output = subprocess.check_output('ps | grep tensorboard', shell=True).decode("utf-8").strip() 162 | return len(output.split("\n")) > 1 163 | 164 | 165 | def run_job(job, batch_name, exclude_hosts=[], wait=1): 166 | # Shorter jobs ! 167 | slurm_cmd=open("bin/utils/slurm_template.sh", "r").read() 168 | slurm_cmd=slurm_cmd.format(job=job, batch_name=batch_name, save_path=get_save_path(job)) 169 | with open(os.path.join(get_save_path(job), "run.sh"), "w") as f: 170 | logger.info("Writing runner to " + os.path.join(get_save_path(job), "run.sh")) 171 | f.write(slurm_cmd) 172 | logger.info("Submitting job to bsub") 173 | 174 | # A heuristic way to submit the job 175 | cmd = "sbatch {}".format(os.path.join(get_save_path(job), "run.sh")) 176 | output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip() 177 | assert output.startswith("Submitted") 178 | assert len(output.split(" ")) == 4 179 | job_id = int(output.split(" ")[-1]) 180 | logger.info("Job id is " + str(job_id)) 181 | 182 | # A hacky way to get hostname 183 | hostname = "" 184 | while wait: 185 | logger.info("Waiting to find a machine for {}.. 
last hostname is {}".format(cmd, hostname)) 186 | 187 | output = subprocess.check_output("squeue -u jastrs01 | grep " + str(job_id), shell=True).decode("utf-8").strip() 188 | parsed_output = output.strip().split() 189 | assert parsed_output[0] == str(job_id) 190 | if parsed_output[-1][0] != "(": 191 | hostname = parsed_output[-1] 192 | break 193 | 194 | time.sleep(1) 195 | 196 | if hostname != "": 197 | logger.info("Hostname is " + str(hostname)) 198 | 199 | return job_id, hostname 200 | 201 | 202 | def run(batch, max_jobs=1, wait=1): 203 | exclude_hosts = [] 204 | 205 | batch_name = random_string(5) 206 | 207 | n_jobs_start = len(get_jobs(batch)) 208 | 209 | print("Starting") 210 | print("==") 211 | 212 | try: 213 | if len(get_jobs(batch)) == 0: 214 | logger.error("No untouched (>10min old) jobs found. Exiting.") 215 | exit(1) 216 | 217 | tb_dir = os.path.join(RESULTS_DIR, "running_experiments") 218 | os.system("mkdir -p " +tb_dir) 219 | 220 | save_path = get_save_path(get_jobs(batch)[0]) 221 | root_save_path = path.dirname(save_path) 222 | os.system("rm " + root_save_path + " " + tb_dir) 223 | os.system("ln -s " + root_save_path + " " + tb_dir) 224 | 225 | while True: 226 | print("n_jobs={}/{},batch={},\nexclude_hosts={},name={}".format(get_n_jobs(batch_name), n_jobs_start, batch, exclude_hosts,batch_name)) 227 | jobs = get_jobs(batch) 228 | logger.info("Found {} jobs to run in the batch script.".format(len(jobs))) 229 | if len(jobs): 230 | job = jobs[0] 231 | n_jobs = get_n_jobs(batch_name) 232 | if max_jobs > n_jobs: 233 | os.system("mkdir -p " + get_save_path(job)) 234 | # Run and redirect all output to a file in the save folder of the job 235 | logger.info("Running " + job) 236 | job_id, hostname = run_job(job, batch_name, exclude_hosts, wait) 237 | start_wait = time.time() 238 | while get_last_modification(get_save_path(job)) > 600 and wait: 239 | print("Waiting for bootup (no HEARTBEAT)... last_mod={},n_jobs={}".format( 240 | get_last_modification(get_save_path(job)), get_n_jobs(batch_name))) 241 | 242 | if not is_running(job_id): 243 | print("Job died. Probably a faulty machine or a bug in code. Rerunning") 244 | print("Job cmd was: " + job) 245 | break 246 | 247 | time.sleep(1) 248 | 249 | if time.time() - start_wait > TIMEOUT: 250 | logger.info("Couldn't start the job in {}s. 
Killing and maybe excluding host (WARNING: excluding doesnt work for slurm).".format(TIMEOUT)) 251 | os.system("scancel " + str(job_id)) 252 | break 253 | elif max_jobs <= n_jobs: 254 | logger.warning("Have {} jobs running but can run max max_jobs={}".format(get_n_jobs(batch_name), max_jobs)) 255 | else: 256 | raise NotImplementedError() 257 | else: 258 | logger.info("No jobs found") 259 | # Allow to kill easilyq 260 | time.sleep(5) 261 | except KeyboardInterrupt: 262 | os.system("scancel -n " + batch_name) 263 | 264 | if __name__ == "__main__": 265 | _, batch, n_jobs = sys.argv 266 | run(batch, int(n_jobs)) 267 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/bin/utils/run_slurm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | A simple self-contained script to run on a list of jobs on a SLURM system 5 | 6 | You will need to update bin/utils/slurm_template.sh 7 | 8 | See bin/utils/run_slurm.py -h 9 | 10 | There are the following requirements for the batch: 11 | * Each command has saving dir as 2nd argument 12 | * Each script saves to the save_dir HEARTBEAT 13 | * Each script saves to the save_dir FINISHED when done 14 | 15 | When using: figure out which queue is better. Bsub doesn't like multiple queues. 16 | 17 | When adapting to a new project make sure that get_save_path is OK, and that you replace the project name. 18 | """ 19 | 20 | import os 21 | import time 22 | import stat 23 | from os import path 24 | import subprocess 25 | import logging 26 | import random 27 | import string 28 | from logging import handlers 29 | import sys 30 | 31 | def configure_logger(name='', 32 | console_logging_level=logging.INFO, 33 | file_logging_level=None, 34 | log_file=None): 35 | """ 36 | Configures logger 37 | :param name: logger name (default=module name, __name__) 38 | :param console_logging_level: level of logging to console (stdout), None = no logging 39 | :param file_logging_level: level of logging to log file, None = no logging 40 | :param log_file: path to log file (required if file_logging_level not None) 41 | :return instance of Logger class 42 | """ 43 | 44 | if file_logging_level is None and log_file is not None: 45 | print("Didnt you want to pass file_logging_level?") 46 | 47 | if len(logging.getLogger(name).handlers) != 0: 48 | print("Already configured logger '{}'".format(name)) 49 | return 50 | 51 | if console_logging_level is None and file_logging_level is None: 52 | return # no logging 53 | 54 | logger = logging.getLogger(name) 55 | logger.handlers = [] 56 | logger.setLevel(logging.DEBUG) 57 | format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 58 | 59 | if console_logging_level is not None: 60 | ch = logging.StreamHandler(sys.stdout) 61 | ch.setFormatter(format) 62 | ch.setLevel(console_logging_level) 63 | logger.addHandler(ch) 64 | 65 | if file_logging_level is not None: 66 | if log_file is None: 67 | raise ValueError("If file logging enabled, log_file path is required") 68 | fh = handlers.RotatingFileHandler(log_file, maxBytes=(1048576 * 5), backupCount=7) 69 | fh.setFormatter(format) 70 | logger.addHandler(fh) 71 | 72 | logger.info("Logging configured!") 73 | 74 | return logger 75 | 76 | 77 | 78 | configure_logger('', log_file=None) 79 | 80 | RESULTS_DIR = os.environ.get("RESULTS_DIR", os.path.join(os.path.dirname(__file__), "results")) 81 | TIMEOUT = 180 # 3 minutes. 
Might be short for some 82 | 83 | logger = logging.getLogger(__name__) 84 | 85 | 86 | def random_string(strlen=10): 87 | """Generate a random string of fixed length """ 88 | letters = string.ascii_lowercase 89 | return ''.join(random.choice(letters) for i in range(strlen)) 90 | 91 | 92 | def get_last_modification(save_path): 93 | f_path = os.path.join(save_path, "HEARTBEAT") 94 | stderr_path = os.path.join(save_path, "stderr.txt") 95 | if os.path.exists(f_path): 96 | return time.time() - os.stat(f_path)[stat.ST_MTIME] 97 | else: 98 | if os.path.exists(stderr_path): 99 | return time.time() - os.stat(stderr_path)[stat.ST_MTIME] 100 | else: 101 | return 10000000000000 102 | 103 | def is_running(job_id): 104 | # A hacky way to chekc if a job is running 105 | output = subprocess.check_output("squeue -u jastrs01", shell=True).decode("utf-8").strip() 106 | if output.find(str(job_id)) == -1: 107 | return False 108 | else: 109 | return True 110 | 111 | def get_n_jobs(batch_name): 112 | try: 113 | output = subprocess.check_output('squeue -u jastrs01 | grep {}'.format(batch_name), shell=True).decode("utf-8").strip() 114 | except: 115 | print("No jobs") 116 | output = "" 117 | if len(output) == 0: 118 | return 0 119 | else: 120 | return len(output.split("\n")) - 1 # -1 Because header 121 | 122 | def get_save_path(job): 123 | if job.startswith("python"): 124 | return job.split(" ")[2] 125 | else: 126 | return job.split(" ")[1] 127 | 128 | 129 | def get_script_name(job): 130 | if job.startswith("python"): 131 | return job.split(" ")[1] 132 | else: 133 | return job.split(" ")[0] 134 | 135 | 136 | def has_finished(save_path): 137 | # Hacky but works, usually 138 | return path.exists(path.join(save_path, "FINISHED")) 139 | 140 | 141 | def get_jobs(batch): 142 | if ";" in batch: 143 | batches = batch.split(";") 144 | else: 145 | batches = [batch] 146 | 147 | all_jobs = [] 148 | for batch in batches: 149 | jobs = list(open(batch, "r").read().splitlines()) 150 | jobs = [j for j in jobs if j[0] != "#"] 151 | jobs = [j for j in jobs if not has_finished(get_save_path(j))] 152 | # take only at least 10min old jobs 153 | jobs = [j for j in jobs if get_last_modification(get_save_path(j)) > 600] 154 | jobs = [("python " + j) if "python" not in j else j for j in jobs] 155 | all_jobs += jobs 156 | 157 | random.shuffle(all_jobs) 158 | return all_jobs 159 | 160 | def tensorboard_running(): 161 | output = subprocess.check_output('ps | grep tensorboard', shell=True).decode("utf-8").strip() 162 | return len(output.split("\n")) > 1 163 | 164 | 165 | def run_job(job, batch_name, exclude_hosts=[], wait=1): 166 | # Shorter jobs ! 167 | slurm_cmd=open("bin/utils/slurm_template.sh", "r").read() 168 | slurm_cmd=slurm_cmd.format(job=job, batch_name=batch_name, save_path=get_save_path(job)) 169 | with open(os.path.join(get_save_path(job), "run.sh"), "w") as f: 170 | logger.info("Writing runner to " + os.path.join(get_save_path(job), "run.sh")) 171 | f.write(slurm_cmd) 172 | logger.info("Submitting job to bsub") 173 | 174 | # A heuristic way to submit the job 175 | cmd = "sbatch {}".format(os.path.join(get_save_path(job), "run.sh")) 176 | output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip() 177 | assert output.startswith("Submitted") 178 | assert len(output.split(" ")) == 4 179 | job_id = int(output.split(" ")[-1]) 180 | logger.info("Job id is " + str(job_id)) 181 | 182 | # A hacky way to get hostname 183 | hostname = "" 184 | while wait: 185 | logger.info("Waiting to find a machine for {}.. 
last hostname is {}".format(cmd, hostname)) 186 | 187 | output = subprocess.check_output("squeue -u jastrs01 | grep " + str(job_id), shell=True).decode("utf-8").strip() 188 | parsed_output = output.strip().split() 189 | assert parsed_output[0] == str(job_id) 190 | if parsed_output[-1][0] != "(": 191 | hostname = parsed_output[-1] 192 | break 193 | 194 | time.sleep(1) 195 | 196 | if hostname != "": 197 | logger.info("Hostname is " + str(hostname)) 198 | 199 | return job_id, hostname 200 | 201 | 202 | def run(batch, max_jobs=1, wait=1): 203 | exclude_hosts = [] 204 | 205 | batch_name = random_string(5) 206 | 207 | n_jobs_start = len(get_jobs(batch)) 208 | 209 | print("Starting") 210 | print("==") 211 | 212 | try: 213 | if len(get_jobs(batch)) == 0: 214 | logger.error("No untouched (>10min old) jobs found. Exiting.") 215 | exit(1) 216 | 217 | tb_dir = os.path.join(RESULTS_DIR, "running_experiments") 218 | os.system("mkdir -p " +tb_dir) 219 | 220 | save_path = get_save_path(get_jobs(batch)[0]) 221 | root_save_path = path.dirname(save_path) 222 | os.system("rm " + root_save_path + " " + tb_dir) 223 | os.system("ln -s " + root_save_path + " " + tb_dir) 224 | 225 | while True: 226 | print("n_jobs={}/{},batch={},\nexclude_hosts={},name={}".format(get_n_jobs(batch_name), n_jobs_start, batch, exclude_hosts,batch_name)) 227 | jobs = get_jobs(batch) 228 | logger.info("Found {} jobs to run in the batch script.".format(len(jobs))) 229 | if len(jobs): 230 | job = jobs[0] 231 | n_jobs = get_n_jobs(batch_name) 232 | if max_jobs > n_jobs: 233 | os.system("mkdir -p " + get_save_path(job)) 234 | # Run and redirect all output to a file in the save folder of the job 235 | logger.info("Running " + job) 236 | job_id, hostname = run_job(job, batch_name, exclude_hosts, wait) 237 | start_wait = time.time() 238 | while get_last_modification(get_save_path(job)) > 600 and wait: 239 | print("Waiting for bootup (no HEARTBEAT)... last_mod={},n_jobs={}".format( 240 | get_last_modification(get_save_path(job)), get_n_jobs(batch_name))) 241 | 242 | if not is_running(job_id): 243 | print("Job died. Probably a faulty machine or a bug in code. Rerunning") 244 | print("Job cmd was: " + job) 245 | break 246 | 247 | time.sleep(1) 248 | 249 | if time.time() - start_wait > TIMEOUT: 250 | logger.info("Couldn't start the job in {}s. 
Killing and maybe excluding host (WARNING: excluding doesnt work for slurm).".format(TIMEOUT)) 251 | os.system("scancel " + str(job_id)) 252 | break 253 | elif max_jobs <= n_jobs: 254 | logger.warning("Have {} jobs running but can run max max_jobs={}".format(get_n_jobs(batch_name), max_jobs)) 255 | else: 256 | raise NotImplementedError() 257 | else: 258 | logger.info("No jobs found") 259 | # Allow to kill easilyq 260 | time.sleep(5) 261 | except KeyboardInterrupt: 262 | os.system("scancel -n " + batch_name) 263 | 264 | if __name__ == "__main__": 265 | _, batch, n_jobs = sys.argv 266 | run(batch, int(n_jobs)) 267 | -------------------------------------------------------------------------------- /pytorch_lightning_project_template/src/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minor utilities 3 | """ 4 | 5 | import sys 6 | from functools import reduce 7 | 8 | import traceback 9 | import logging 10 | import argparse 11 | import optparse 12 | import datetime 13 | import sys 14 | import pprint 15 | import types 16 | import time 17 | import copy 18 | import subprocess 19 | import glob 20 | from collections import OrderedDict 21 | import os 22 | import signal 23 | import atexit 24 | import json 25 | import inspect 26 | import pandas as pd 27 | import pickle 28 | 29 | from logging import handlers 30 | 31 | import argh 32 | import gin 33 | 34 | from os.path import join, exists 35 | from gin.config import _OPERATIVE_CONFIG 36 | 37 | import torch 38 | from torch.nn.modules.module import _addindent 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | def parse_gin_config(path): 44 | # A hacky parser for gin config without loading gin config into _CONFIG. Useful for parsin gin config with $ in it. 45 | C = {} 46 | for line in open(path, "r").readlines(): 47 | 48 | if len(line.strip()) == 0 or line[0] == "#": 49 | continue 50 | 51 | k, v = line.split("=") 52 | k, v = k.strip(), v.strip() 53 | k1, k2 = k.split(".") 54 | 55 | v = eval(v) 56 | 57 | C[k1 + "." + k2] = v 58 | return C 59 | 60 | def load_C(e): 61 | return parse_gin_config(join(e, 'config.gin')) 62 | 63 | 64 | def load_H(e): 65 | """ 66 | Loads a unified view of the experiment as a dict. 67 | 68 | Notes 69 | ----- 70 | Assumes logs are generated using CSVLogger and that name is train_[k]_step/train_[k]_epoch/valid_[k] for a metric 71 | of key logged in a given step, train epoch, valid epoch, respectively. 72 | """ 73 | Hs = [] 74 | for version in glob.glob(join(e, 'default', '*')): 75 | if os.path.exists(join(version, "metrics.csv")): 76 | H = pd.read_csv(join(version, "metrics.csv")) 77 | Hs.append(H) 78 | 79 | if len(Hs) == 0: 80 | logger.warning("No found metrics in " + e) 81 | return {} 82 | 83 | H = pd.concat(Hs) #.to_dict('list') 84 | 85 | # Boilerplate to split Pytorch Lightning's metrics into 3 parts 86 | # TODO: Refactor this. 
I think the best way would be to have a custom CSVLogger that doesn't log everything together 87 | valid_keys = [k for k in H.columns if k.startswith("valid")] 88 | train_epoch_keys = [k for k in H.columns if k.startswith("train") and k.endswith("epoch")] 89 | train_step_keys = [k for k in H.columns if k.startswith("train") and k.endswith("step")] 90 | assert len(valid_keys) > 0, "Make sure to prefix your validation metrics with 'valid'" 91 | H_valid = H[~H[valid_keys[0]].isna()] 92 | H_train_epoch = H[~H[train_epoch_keys[0]].isna()] 93 | H_train_step = H[~H[train_step_keys[0]].isna()] 94 | assert len(H_valid) + len(H_train_epoch) + len(H_train_step) == len(H), "Added or removed logs" 95 | H_valid['epoch'].values[:] = H_train_epoch['epoch'].values[0:len(H_valid['epoch'])] # Populate missing value 96 | del H_valid['step'] 97 | H_train_epoch['epoch_at_step'] = H_train_epoch['step'] 98 | del H_train_epoch['step'] 99 | H_valid = H_valid.dropna(axis='columns') 100 | H_train_epoch = H_train_epoch.dropna(axis='columns') 101 | H_train_step = H_train_step.dropna(axis='columns') 102 | H_processed = H_train_step.to_dict('list') 103 | H_processed.update(H_valid.to_dict('list')) 104 | H_processed.update(H_train_epoch.to_dict('list')) 105 | 106 | # Add evaluation results 107 | eval_results = {} 108 | for f_name in glob.glob(os.path.join(e, 'eval_results*json')): 109 | ev = json.load(open(f_name)) 110 | for k in ev: 111 | eval_results[os.path.basename(f_name) + "_" + k] = [ev[k]] 112 | for k in eval_results: 113 | H_processed['eval_' + k] = eval_results[k] 114 | 115 | return H_processed 116 | 117 | 118 | def load_HC(e): 119 | return load_H(e), load_C(e) 120 | 121 | 122 | def acc(y_pred, y_true): 123 | _, y_pred = y_pred.max(1) 124 | # _, y_true = y_true.max(1) 125 | acc_pred = (y_pred == y_true).float().mean() 126 | return acc_pred * 100 127 | 128 | def save_weights(model, optimizer, filename): 129 | """ 130 | Save all weights necessary to resume training 131 | """ 132 | state = { 133 | 'model': model.state_dict(), 134 | 'optimizer': optimizer.state_dict(), 135 | } 136 | torch.save(state, filename) 137 | 138 | from contextlib import contextmanager 139 | 140 | 141 | class Fork(object): 142 | def __init__(self, file1, file2): 143 | self.file1 = file1 144 | self.file2 = file2 145 | 146 | def write(self, data): 147 | self.file1.write(data) 148 | self.file2.write(data) 149 | 150 | def flush(self): 151 | self.file1.flush() 152 | self.file2.flush() 153 | 154 | 155 | @contextmanager 156 | def replace_logging_stream(file_): 157 | root = logging.getLogger() 158 | if len(root.handlers) != 1: 159 | print(root.handlers) 160 | raise ValueError("Don't know what to do with many handlers") 161 | if not isinstance(root.handlers[0], logging.StreamHandler): 162 | raise ValueError 163 | stream = root.handlers[0].stream 164 | root.handlers[0].stream = file_ 165 | try: 166 | yield 167 | finally: 168 | root.handlers[0].stream = stream 169 | 170 | 171 | @contextmanager 172 | def replace_standard_stream(stream_name, file_): 173 | stream = getattr(sys, stream_name) 174 | setattr(sys, stream_name, file_) 175 | try: 176 | yield 177 | finally: 178 | setattr(sys, stream_name, stream) 179 | 180 | def gin_wrap(fnc): 181 | def main(save_path, config, bindings=""): 182 | # You can pass many configs (think of them as mixins), and many bindings. Both ";" separated. 
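# In the call below both the config list and the bindings are "#"-separated: `config` is split on "#"
# into a list of gin files, and every "#" inside `bindings` is turned into a newline before gin parses it.
# Hypothetical invocation (the binding keys are illustrative only, not defined by this repo):
#   python bin/train_supervised.py results/my_run configs/cnn.gin#configs/cnn_full.gin --bindings "train.lr=0.01#train.batch_size=64"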
183 | gin.parse_config_files_and_bindings(config.split("#"), bindings.replace("#", "\n")) 184 | if not os.path.exists(save_path): 185 | logger.info("Creating folder " + save_path) 186 | os.system("mkdir -p " + save_path) 187 | 188 | run_with_redirection(os.path.join(save_path, "stdout.txt"), 189 | os.path.join(save_path, "stderr.txt"), 190 | fnc)(save_path) 191 | 192 | argh.dispatch_command(main) 193 | 194 | 195 | def run_with_redirection(stdout_path, stderr_path, func): 196 | def func_wrapper(*args, **kwargs): 197 | with open(stdout_path, 'a', 1) as out_dst: 198 | with open(stderr_path, 'a', 1) as err_dst: 199 | out_fork = Fork(sys.stdout, out_dst) 200 | err_fork = Fork(sys.stderr, err_dst) 201 | with replace_standard_stream('stderr', err_fork): 202 | with replace_standard_stream('stdout', out_fork): 203 | with replace_logging_stream(err_fork): 204 | func(*args, **kwargs) 205 | 206 | return func_wrapper 207 | 208 | def configure_logger(name='', 209 | console_logging_level=logging.INFO, 210 | file_logging_level=None, 211 | log_file=None): 212 | """ 213 | Configures logger 214 | :param name: logger name (default=module name, __name__) 215 | :param console_logging_level: level of logging to console (stdout), None = no logging 216 | :param file_logging_level: level of logging to log file, None = no logging 217 | :param log_file: path to log file (required if file_logging_level not None) 218 | :return instance of Logger class 219 | """ 220 | 221 | if file_logging_level is None and log_file is not None: 222 | print("Didnt you want to pass file_logging_level?") 223 | 224 | if len(logging.getLogger(name).handlers) != 0: 225 | print("Already configured logger '{}'".format(name)) 226 | return 227 | 228 | if console_logging_level is None and file_logging_level is None: 229 | return # no logging 230 | 231 | logger = logging.getLogger(name) 232 | logger.handlers = [] 233 | logger.setLevel(logging.DEBUG) 234 | format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 235 | 236 | if console_logging_level is not None: 237 | ch = logging.StreamHandler(sys.stdout) 238 | ch.setFormatter(format) 239 | ch.setLevel(console_logging_level) 240 | logger.addHandler(ch) 241 | 242 | if file_logging_level is not None: 243 | if log_file is None: 244 | raise ValueError("If file logging enabled, log_file path is required") 245 | fh = handlers.RotatingFileHandler(log_file, maxBytes=(1048576 * 5), backupCount=7) 246 | fh.setFormatter(format) 247 | logger.addHandler(fh) 248 | 249 | logger.info("Logging configured!") 250 | 251 | return logger 252 | 253 | 254 | def summary(model, file=sys.stderr): 255 | def repr(model): 256 | # We treat the extra repr like the sub-module, one item per line 257 | extra_lines = [] 258 | extra_repr = model.extra_repr() 259 | # empty string will be split into list [''] 260 | if extra_repr: 261 | extra_lines = extra_repr.split('\n') 262 | child_lines = [] 263 | total_params = 0 264 | for key, module in model._modules.items(): 265 | mod_str, num_params = repr(module) 266 | mod_str = _addindent(mod_str, 2) 267 | child_lines.append('(' + key + '): ' + mod_str) 268 | total_params += num_params 269 | lines = extra_lines + child_lines 270 | 271 | for name, p in model._parameters.items(): 272 | total_params += reduce(lambda x, y: x * y, p.shape) 273 | 274 | main_str = model._get_name() + '(' 275 | if lines: 276 | # simple one-liner info, which most builtin Modules will use 277 | if len(extra_lines) == 1 and not child_lines: 278 | main_str += extra_lines[0] 279 | else: 280 | 
main_str += '\n ' + '\n '.join(lines) + '\n' 281 | 282 | main_str += ')' 283 | if file is sys.stderr: 284 | main_str += ', \033[92m{:,}\033[0m params'.format(total_params) 285 | else: 286 | main_str += ', {:,} params'.format(total_params) 287 | return main_str, total_params 288 | 289 | string, count = repr(model) 290 | if file is not None: 291 | print(string, file=file) 292 | return count 293 | 294 | 295 | if __name__ == "__main__": 296 | H,C = load_HC("save_to_folder8") 297 | print(H) 298 | print(C) --------------------------------------------------------------------------------