├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── clients ├── __init__.py ├── client_constants.py ├── es_client.py ├── s3_client.py └── ses_client.py ├── common.py ├── competitions ├── __init__.py ├── carvana.py ├── dogscats.py ├── planet.py └── team │ ├── __init__.py │ └── brendan.py ├── config.py ├── constants.py ├── datasets ├── __init__.py ├── data_aug.py ├── data_folds.py ├── data_loaders.py ├── data_utils.py ├── datasets.py └── metadata.py ├── docs ├── email.png ├── kibana1.png ├── kibana2.png ├── visdom.png └── visdom2.png ├── ensembles ├── __init__.py ├── ens_utils.py └── ensemble.py ├── experiments ├── __init__.py ├── exp_builder.py ├── exp_config.py ├── exp_history.py ├── exp_utils.py └── experiment.py ├── explore-carvana.ipynb ├── explore-dogscats.ipynb ├── explore-planet.ipynb ├── init_project.py ├── metrics ├── __init__.py ├── evaluate.py ├── loss_functions.py ├── metric.py ├── metric_builder.py └── metric_utils.py ├── models ├── __init__.py ├── builder.py ├── layers.py ├── resnet.py ├── simplenet.py ├── unet.py └── utils.py ├── notifications ├── __init__.py ├── email_constants.py └── emailer.py ├── predictions ├── __init__.py ├── pred_builder.py ├── pred_constants.py ├── pred_utils.py └── prediction.py ├── requirements.txt ├── submissions ├── __init__.py └── utils.py ├── tests └── unit_tests │ └── training │ └── test_learning_rates.py ├── torchsample ├── __init__.py ├── callbacks.py ├── constraints.py ├── datasets.py ├── functions │ ├── __init__.py │ └── affine.py ├── initializers.py ├── metrics.py ├── modules │ ├── __init__.py │ ├── _utils.py │ └── module_trainer.py ├── regularizers.py ├── samplers.py ├── transforms │ ├── __init__.py │ ├── affine_transforms.py │ ├── distortion_transforms.py │ ├── image_transforms.py │ └── tensor_transforms.py ├── utils.py └── version.py ├── training ├── __init__.py ├── learning_rates.py ├── pseudolabels.py ├── trainers.py └── utils.py ├── utils ├── __init__.py ├── files.py ├── general.py ├── imgs.py ├── 
logger.py ├── multitasking.py └── widgets.py └── visualizers ├── __init__.py ├── kibana.py ├── vis_utils.py └── viz.py /.gitignore: -------------------------------------------------------------------------------- 1 | *gz 2 | *.swp 3 | .ipynb_checkpoints/ 4 | __pycache__ 5 | *~ 6 | *pyc 7 | .cache 8 | .vscode 9 | 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Brendan Fortuner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Summary 2 | Pytorch Kaggle starter is a framework for managing experiments in Kaggle competitions. 
It reduces time to first submission by providing a suite of helper functions for model training, data loading, adjusting learning rates, making predictions, ensembling models, and formatting submissions. 3 | 4 | Inside are example Jupyter notebooks walking through how to get strong scores on popular competitions: 5 | 6 | * [Dogs vs Cats Redux](https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition) - Top 8% 7 | * [Planet Amazon Rainforest](https://www.kaggle.com/c/planet-understanding-the-amazon-from-space) - Top 15% 8 | * [Carvana Image Segmentation](https://www.kaggle.com/c/carvana-image-masking-challenge) - WIP 9 | 10 | These notebooks outline basic, single-model submissions. Scores can be improved significantly by ensembling models and using test-time augmentation. 11 | 12 | ## Features 13 | 14 | 1. **Experiments** - Launch experiments from python dictionaries inside jupyter notebooks or python scripts. Attach Visualizers (Visdom, Kibana), Metrics (Accuracy, F2, Loss), or external datastores (S3, Elasticsearch) 15 | 2. **Monitoring** - Track experiments from your phone or web-browser in real-time with Visdom, a lightweight visualization framework from Facebook 16 | 3. **Notifications** - Receive email notifications when experiments complete or fail 17 | 4. **Sharing** - Upload experiments, predictions and ensembles to S3 for other users to download 18 | 5. **Analysis** - Compare experiments across users with Kibana. Design custom dashboards for specific competitions 19 | 6. **Helpers** - Reduce time to submission with helper code for common tasks--custom datasets, metrics, storing predictions, ensembling models, making submissions, and more. 20 | 7. **Torchsample** - Includes the latest release of ncullen93's [torchsample](https://github.com/ncullen93/torchsample) project for additional trainer helpers and data augmentations. 21 | 22 | ## Requirements 23 | 24 | 1. [Anaconda](https://www.continuum.io/downloads) with Python3 25 | 2. 
[Pytorch](http://pytorch.org/) 26 | 3. Other requirements: ```pip install -r requirements.txt``` 27 | 4. OpenCV: ```conda install -c menpo opencv``` 28 | 5. Server with GPU and CUDA installed 29 | 30 | ## Datasets 31 | To get started you'll need to move all training and test images to the `project_root/datasets/inputs` directory (into either the `trn_jpg` or `tst_jpg` subdirectory). Running the first cell of each notebook creates the directory structure outlined in the `config.py` file. 32 | 33 | There is no need to create separate directories for classes or validation sets. This is handled by the `data_folds.py` module and the FileDataset, which expects a list of filepaths and targets. After trying out a lot of approaches, I found this to be the easiest and most extensible. You'll sometimes need to generate a `metadata.csv` file separately if Kaggle didn't provide one. This sort of competition-specific code can live in the `competitions/` directory. 34 | 35 | ## Visdom 36 | Visualize experiment progress on your phone with Facebook's new [Visdom](https://github.com/facebookresearch/visdom) framework. 37 | 38 | ![Visdom](docs/visdom.png) 39 | 40 | ## Kibana 41 | Spin up an [Elasticsearch](https://www.elastic.co/) cluster locally or on AWS to start visualizing or tracking experiments. Create custom dashboards with [Kibana's](https://www.elastic.co/products/kibana) easy-to-use drag and drop chart creation tools. 42 | 43 | ![Kibana1](docs/kibana1.png) 44 | 45 | Filter and sort experiments, zoom to a specific time period, or aggregate metrics across experiments and see updates in real time. 46 | 47 | ![Kibana2](docs/kibana2.png) 48 | 49 | ## Emails 50 | Receive emails when experiments complete or fail using the AWS SES service. 51 | 52 | ![Emails](docs/email.png) 53 | 54 | ## Kaggle CLI 55 | Quickly download and submit with the kaggle cli tool. 
56 | 57 | ``` 58 | kg download -c dogs-vs-cats-redux-kernels-edition -v -u USERNAME -p PASSWORD 59 | kg submit -m 'my sub' -c dogs-vs-cats-redux-kernels-edition -v -u USERNAME -p PASSWORD my_exp_tst.csv 60 | ``` 61 | 62 | ## Best practices 63 | 64 | * Use systemd for always running Visdom and Jupyter servers 65 | 66 | 67 | ## Unit Tests 68 | 69 | Run tests with: 70 | ``` 71 | python -m pytest tests/ 72 | ``` 73 | 74 | Other run commands: 75 | ``` 76 | python -m pytest tests/ (all tests) 77 | python -m pytest -k filenamekeyword (tests matching keyword) 78 | python -m pytest tests/utils/test_sample.py (single test file) 79 | python -m pytest tests/utils/test_sample.py::test_answer_correct (single test method) 80 | python -m pytest --resultlog=testlog.log tests/ (log output to file) 81 | python -m pytest -s tests/ (print output to console) 82 | ``` 83 | 84 | ## TODO 85 | 86 | * Add TTA (test time augmentation) example 87 | * Add Pseudolabeling example 88 | * Add Knowledge Distillation example 89 | * Add Multi-input/Multi-target examples 90 | * Add stacking helper functions 91 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /clients/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bfortuner/pytorch-federated-learning/99cc406a361f23a34aa4576c8781274fd0312158/clients/__init__.py -------------------------------------------------------------------------------- /clients/client_constants.py: -------------------------------------------------------------------------------- 1 | import config as cfg 2 | 3 | # AWS Config 4 | AWS_REGION = cfg.AWS_REGION 5 | AWS_ACCESS_KEY = cfg.AWS_ACCESS_KEY 6 | AWS_SECRET_KEY = cfg.AWS_SECRET_KEY 7 | TIMEZONE = cfg.TIMEZONE 8 | 9 | # S3 Config 10 | 
S3_BUCKET = 'kaggle{:s}'.format(cfg.PROJECT_NAME) 11 | EXPERIMENT_CONFIG_PREFIX = 'experiment_configs/' 12 | EXPERIMENT_HISTORY_PREFIX = 'experiment_histories/' 13 | EXPERIMENT_PREFIX = 'experiments/' 14 | PREDICTION_PREFIX = 'predictions/' 15 | ENSEMBLE_PREFIX = 'ensembles/' 16 | 17 | # Elasticsearch Config 18 | ES_EXPERIMENT_HISTORY_INDEX = 'kaggle-{:s}-history'.format(cfg.PROJECT_NAME) 19 | ES_EXPERIMENT_CONFIG_INDEX = 'kaggle-{:s}-config'.format(cfg.PROJECT_NAME) 20 | ES_PREDICTIONS_INDEX = 'kaggle-{:s}-predictions'.format(cfg.PROJECT_NAME) 21 | ES_EXPERIMENT_HISTORY_DOC_TYPE = 'history' 22 | ES_EXPERIMENT_CONFIG_DOC_TYPE = 'config' 23 | ES_PREDICTIONS_DOC_TYPE = 'prediction' 24 | ES_ENDPOINT = cfg.ES_ENDPOINT 25 | ES_PORT = cfg.ES_PORT 26 | 27 | # SES Config 28 | AWS_SES_REGION = cfg.AWS_SES_REGION 29 | ADMIN_EMAIL = cfg.ADMIN_EMAIL 30 | USER_EMAIL = cfg.USER_EMAIL 31 | EMAIL_CHARSET = 'UTF-8' 32 | -------------------------------------------------------------------------------- /clients/es_client.py: -------------------------------------------------------------------------------- 1 | from elasticsearch import Elasticsearch 2 | from datetime import datetime 3 | import pytz 4 | 5 | from .client_constants import * 6 | 7 | 8 | def upload_experiment_history(config, history): 9 | index_docs(history.to_doc(config), ES_EXPERIMENT_HISTORY_INDEX, 10 | ES_EXPERIMENT_HISTORY_DOC_TYPE) 11 | 12 | 13 | def upload_experiment_config(config): 14 | index_doc(config.to_doc(), ES_EXPERIMENT_CONFIG_INDEX, 15 | ES_EXPERIMENT_CONFIG_DOC_TYPE) 16 | 17 | 18 | def upload_prediction(pred): 19 | index_doc(pred.to_doc(), ES_PREDICTIONS_INDEX, 20 | ES_PREDICTIONS_DOC_TYPE) 21 | 22 | 23 | def delete_experiment(config): 24 | delete_experiment_by_id(config.get_id()) 25 | 26 | 27 | def delete_experiment_by_id(exp_id): 28 | r1 = delete_by_field(ES_EXPERIMENT_HISTORY_INDEX, 29 | ES_EXPERIMENT_HISTORY_DOC_TYPE, 30 | ES_EXP_KEY_FIELD, exp_id) 31 | r2 = delete_by_field(ES_EXPERIMENT_CONFIG_INDEX, 32 
| ES_EXPERIMENT_CONFIG_DOC_TYPE, 33 | ES_EXP_KEY_FIELD, exp_id) 34 | return r1,r2 35 | 36 | 37 | def delete_experiment_by_field(field, value): 38 | r1 = delete_by_field(ES_EXPERIMENT_HISTORY_INDEX, 39 | ES_EXPERIMENT_HISTORY_DOC_TYPE, 40 | field, value) 41 | r2 = delete_by_field(ES_EXPERIMENT_CONFIG_INDEX, 42 | ES_EXPERIMENT_CONFIG_DOC_TYPE, 43 | field, value) 44 | return r1,r2 45 | 46 | 47 | # API 48 | # http://elasticsearch-py.readthedocs.io/en/master/api.html 49 | 50 | def get_client(): 51 | return Elasticsearch([ 52 | {'host': ES_ENDPOINT, 'port': ES_PORT}, 53 | ]) 54 | 55 | def create_index(name, shards=2, replicas=1): 56 | es = get_client() 57 | ok = es.indices.create(name, body={ 58 | "settings" : { 59 | "index" : { 60 | "number_of_shards" : shards, 61 | "number_of_replicas" : replicas 62 | } 63 | } 64 | })['acknowledged'] 65 | assert ok is True 66 | return ok 67 | 68 | def delete_index(name): 69 | ok = get_client().indices.delete(name)['acknowledged'] 70 | assert ok is True 71 | return ok 72 | 73 | def delete_docs_by_ids(index_name, doc_ids): 74 | pass 75 | 76 | def delete_by_field(index_name, doc_type, field, value): 77 | query = { 78 | "query": { 79 | "match" : { 80 | field : value 81 | } 82 | } 83 | } 84 | es = get_client() 85 | r = es.delete_by_query(index=index_name, doc_type=doc_type, body=query) 86 | return r 87 | 88 | def search_by_field(index_name, doc_type, field, value): 89 | query = { 90 | "term" : { 91 | field : value 92 | } 93 | } 94 | print(query) 95 | es = get_client() 96 | resp = es.search(index=index_name, doc_type=doc_type, body=query) 97 | return resp 98 | 99 | def get_doc(index_name, doc_key): 100 | return get_client().get(index_name, id=doc_key) 101 | 102 | def search(index_name, query, metadata_only=False, n_docs=10): 103 | es = get_client() 104 | filters = [] 105 | if metadata_only: 106 | filters = ['hits.hits._id', 'hits.hits._type'] 107 | return es.search(index=index_name, filter_path=filters, 108 | body=query, size=n_docs) 109 | 
110 | def index_doc(doc, index_name, doc_type): 111 | assert 'key' in doc 112 | es = get_client() 113 | doc['uploaded'] = datetime.now(pytz.timezone(TIMEZONE)) 114 | es.index(index=index_name, doc_type=doc_type, body=doc, id=doc['key']) 115 | 116 | def index_docs(docs, index_name, doc_type): 117 | # There exists a bulk API, but this is fine for now 118 | for doc in docs: 119 | index_doc(doc, index_name, doc_type) 120 | 121 | def get_mappings(index_name): 122 | # Shows the keys and data types in an index 123 | return get_client().indices.get_mapping(index_name) 124 | 125 | def doc_exists(index_name, doc_type, doc_id): 126 | return get_client().exists(index_name, doc_type, doc_id) 127 | 128 | def health(): 129 | return get_client().cluster.health(wait_for_status='yellow', 130 | request_timeout=1) 131 | 132 | def ping(): 133 | return get_client().ping() 134 | 135 | -------------------------------------------------------------------------------- /clients/s3_client.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import constants as c 3 | from .client_constants import * 4 | 5 | 6 | # List Files 7 | 8 | def list_experiment_configs(): 9 | return list_fnames(EXPERIMENT_CONFIG_PREFIX, 10 | c.EXPERIMENT_CONFIG_FILE_EXT) 11 | 12 | def list_experiments(): 13 | return list_fnames(EXPERIMENT_PREFIX, c.EXP_FILE_EXT) 14 | 15 | def list_predictions(): 16 | return list_fnames(PREDICTION_PREFIX, c.PRED_FILE_EXT) 17 | 18 | def list_fnames(prefix, postfix): 19 | keys = get_keys(prefix=prefix) 20 | names = [] 21 | for k in keys: 22 | names.append(k.replace(prefix,'').replace(postfix,'')) 23 | return names 24 | 25 | 26 | # Download 27 | 28 | def download_experiment(dest_fpath, exp_name, bucket=S3_BUCKET): 29 | key = EXPERIMENT_PREFIX+exp_name + c.EXP_FILE_EXT 30 | download_file(dest_fpath, key, bucket=bucket) 31 | 32 | def download_experiment_config(dest_fpath, exp_name, bucket=S3_BUCKET): 33 | key = 
EXPERIMENT_CONFIG_PREFIX+exp_name+c.EXPERIMENT_CONFIG_FILE_EXT 34 | download_file(dest_fpath, key, bucket=bucket) 35 | 36 | def download_prediction(dest_fpath, pred_name, bucket=S3_BUCKET): 37 | key = PREDICTION_PREFIX+pred_name+c.PRED_FILE_EXT 38 | download_file(dest_fpath, key, bucket=bucket) 39 | 40 | 41 | # Read Object directly from S3 42 | 43 | def fetch_experiment_history(exp_name, bucket=S3_BUCKET): 44 | key = EXPERIMENT_HISTORY_PREFIX+exp_name+c.EXPERIMENT_HISTORY_FILE_EXT 45 | return get_object_str(key, bucket) 46 | 47 | def fetch_experiment_config(exp_name, bucket=S3_BUCKET): 48 | key = EXPERIMENT_CONFIG_PREFIX+exp_name+c.EXPERIMENT_CONFIG_FILE_EXT 49 | return get_object_str(key, bucket) 50 | 51 | 52 | # Upload 53 | 54 | def upload_experiment(src_fpath, exp_name, bucket=S3_BUCKET): 55 | key = EXPERIMENT_PREFIX+exp_name+c.EXP_FILE_EXT 56 | upload_file(src_fpath, key, bucket=bucket) 57 | 58 | def upload_experiment_config(src_fpath, exp_name, bucket=S3_BUCKET): 59 | key = EXPERIMENT_CONFIG_PREFIX+exp_name+c.EXPERIMENT_CONFIG_FILE_EXT 60 | upload_file(src_fpath, key, bucket=bucket) 61 | 62 | def upload_experiment_history(src_fpath, exp_name, bucket=S3_BUCKET): 63 | key = EXPERIMENT_HISTORY_PREFIX+exp_name+c.EXPERIMENT_HISTORY_FILE_EXT 64 | upload_file(src_fpath, key, bucket=bucket) 65 | 66 | def upload_prediction(src_fpath, pred_name, bucket=S3_BUCKET): 67 | key = PREDICTION_PREFIX+pred_name+c.PRED_FILE_EXT 68 | upload_file(src_fpath, key, bucket=bucket) 69 | 70 | 71 | # Cleanup 72 | 73 | def delete_experiment(exp_name): 74 | exp_config_key = (EXPERIMENT_CONFIG_PREFIX + exp_name 75 | + c.EXPERIMENT_CONFIG_FILE_EXT) 76 | exp_history_key = (EXPERIMENT_HISTORY_PREFIX + exp_name 77 | + c.EXPERIMENT_HISTORY_FILE_EXT) 78 | delete_object(S3_BUCKET, key=exp_config_key) 79 | delete_object(S3_BUCKET, key=exp_history_key) 80 | 81 | 82 | # Base Helpers 83 | 84 | def get_client(): 85 | return boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY, 86 | 
aws_secret_access_key=AWS_SECRET_KEY) 87 | 88 | def get_resource(): 89 | return boto3.resource('s3', aws_access_key_id=AWS_ACCESS_KEY, 90 | aws_secret_access_key=AWS_SECRET_KEY) 91 | 92 | def get_buckets(): 93 | return get_client().list_buckets() 94 | 95 | def get_object_str(key, bucket=S3_BUCKET): 96 | s3 = get_resource() 97 | obj = s3.Object(bucket, key) 98 | return obj.get()['Body'].read().decode('utf-8') 99 | 100 | def get_keys(prefix, bucket=S3_BUCKET): 101 | objs = get_objects(prefix, bucket) 102 | keys = [] 103 | if 'Contents' not in objs: 104 | return keys 105 | for obj in objs['Contents']: 106 | keys.append(obj['Key']) 107 | return keys 108 | 109 | def download_file(dest_fpath, s3_fpath, bucket=S3_BUCKET): 110 | get_client().download_file(Filename=dest_fpath, 111 | Bucket=bucket, 112 | Key=s3_fpath) 113 | 114 | def upload_file(src_fpath, s3_fpath, bucket=S3_BUCKET): 115 | get_client().upload_file(Filename=src_fpath, 116 | Bucket=bucket, 117 | Key=s3_fpath) 118 | 119 | def get_download_url(s3_path, bucket=S3_BUCKET, expiry=86400): 120 | return get_client().generate_presigned_url( 121 | ClientMethod='get_object', 122 | Params={'Bucket': bucket, 123 | 'Key': s3_path}, 124 | ExpiresIn=expiry 125 | ) 126 | 127 | #key = 'experiment_configs/JeremyCNN-SGD-lr1-wd0001-bs32-id6E878.json' 128 | def delete_object(bucket, key): 129 | return get_client().delete_object( 130 | Bucket=bucket, 131 | Key=key 132 | ) 133 | -------------------------------------------------------------------------------- /clients/ses_client.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from .client_constants import * 3 | 4 | 5 | def get_client(): 6 | return boto3.client('ses', aws_access_key_id=AWS_ACCESS_KEY, 7 | aws_secret_access_key=AWS_SECRET_KEY, 8 | region_name=AWS_SES_REGION) 9 | 10 | 11 | def send_email(subject, body, to_email, from_email=ADMIN_EMAIL): 12 | response = get_client().send_email( 13 | Source=from_email, 14 | 
Destination={ 15 | 'ToAddresses': [ 16 | to_email, 17 | ], 18 | 'CcAddresses': [], 19 | 'BccAddresses': [] 20 | }, 21 | Message={ 22 | 'Subject': { 23 | 'Data': subject, 24 | 'Charset': EMAIL_CHARSET 25 | }, 26 | 'Body': { 27 | 'Text': { 28 | 'Data': body, 29 | 'Charset': EMAIL_CHARSET 30 | }, 31 | 'Html': { 32 | 'Data': body, 33 | 'Charset': EMAIL_CHARSET 34 | } 35 | } 36 | } 37 | ) 38 | return response 39 | 40 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import time 5 | import importlib 6 | import torch 7 | import random 8 | import pickle 9 | import math 10 | import matplotlib.pyplot as plt 11 | import socket 12 | import datetime 13 | from PIL import Image 14 | from collections import Counter 15 | from glob import glob 16 | from IPython.display import FileLink 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | from torchvision import transforms 21 | from torch.backends import cudnn 22 | from torch.autograd import Variable 23 | from torch import optim 24 | from torch import nn 25 | import torchvision 26 | import torchsample 27 | 28 | import config as cfg 29 | import constants as c 30 | 31 | import competitions 32 | 33 | import ensembles 34 | from ensembles import ens_utils 35 | 36 | import datasets 37 | from datasets import data_aug 38 | from datasets import data_folds 39 | from datasets import data_loaders 40 | from datasets import metadata 41 | 42 | from experiments.experiment import Experiment 43 | from experiments import exp_utils, exp_builder 44 | 45 | import models.builder 46 | import models.resnet 47 | import models.unet 48 | import models.utils 49 | 50 | from metrics import evaluate 51 | from metrics import metric_utils 52 | from metrics import metric 53 | from metrics import loss_functions 54 | 55 | import predictions 56 | from predictions import 
pred_utils 57 | 58 | import submissions 59 | 60 | import training 61 | from training import learning_rates 62 | from training import trainers 63 | 64 | import visualizers 65 | from visualizers.viz import Viz 66 | from visualizers.kibana import Kibana 67 | from visualizers import vis_utils 68 | 69 | import utils 70 | -------------------------------------------------------------------------------- /competitions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bfortuner/pytorch-federated-learning/99cc406a361f23a34aa4576c8781274fd0312158/competitions/__init__.py -------------------------------------------------------------------------------- /competitions/carvana.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import shutil 4 | import gzip 5 | 6 | import utils.files 7 | import config as cfg 8 | import constants as c 9 | from datasets import datasets 10 | from datasets import data_loaders 11 | import predictions 12 | import training 13 | import submissions 14 | 15 | 16 | 17 | def get_submission_lines(pred_arr, fnames): 18 | lines = [] 19 | for i in range(len(pred_arr)): 20 | rle = submissions.run_length_encode(pred_arr[i]) 21 | lines.append(fnames[i]+','+rle) 22 | return lines 23 | 24 | 25 | def make_submission(pred, block_size=10000, header=None, compress=True): 26 | meta = pred.attrs['meta'] 27 | print("Preds", pred.shape, meta['name'], meta['dset']) 28 | input_fnames = meta['input_fnames'] 29 | sub_fpath = os.path.join(cfg.PATHS['submissions'], meta['name']+c.SUBMISSION_FILE_EXT) 30 | lines = [] if header is None else [header] 31 | 32 | i = 0 33 | while i < len(pred): 34 | start = time.time() 35 | pred_block = pred[i:i+block_size].squeeze().astype('uint8') 36 | newlines = get_submission_lines(pred_block, input_fnames[i:i+block_size]) 37 | lines.extend(newlines) 38 | i += block_size 39 | 
print(training.get_time_msg(start)) 40 | 41 | sub_fpath = utils.files.write_lines(sub_fpath, lines, compress) 42 | return sub_fpath 43 | 44 | 45 | def get_block_predict_dataloaders(dataset, block_size, batch_size): 46 | loaders = [] 47 | i = 0 48 | while i < len(dataset): 49 | inp_fpaths = dataset.input_fpaths[i:i+block_size] 50 | tar_fpaths = (None if dataset.target_fpaths is None 51 | else dataset.target_fpaths[i:i+block_size]) 52 | block_dset = datasets.ImageTargetDataset(inp_fpaths, tar_fpaths, 53 | 'pil', 'pil', input_transform=dataset.input_transform, 54 | target_transform=dataset.target_transform, 55 | joint_transform=dataset.joint_transform) 56 | block_loader = data_loaders.get_data_loader(block_dset, batch_size, 57 | shuffle=False, n_workers=2, pin_memory=False) 58 | loaders.append(block_loader) 59 | i += block_size 60 | return loaders 61 | 62 | 63 | def predict_binary_mask_blocks(name, dset, model, dataset, block_size, 64 | batch_size, threshold, W=None, H=None): 65 | pred_fpath = os.path.join(cfg.PATHS['predictions'], name + '_' 66 | + dset + c.PRED_FILE_EXT) 67 | if os.path.exists(pred_fpath): 68 | print('Pred file exists. 
Overwriting') 69 | time.sleep(2) 70 | shutil.rmtree(pred_fpath) 71 | 72 | loaders = get_block_predict_dataloaders(dataset, block_size, batch_size) 73 | input_fnames = utils.files.get_fnames_from_fpaths(dataset.input_fpaths) 74 | target_fnames = (None if dataset.target_fpaths is None else 75 | utils.files.get_fnames_from_fpaths(dataset.target_fpaths)) 76 | meta = { 77 | 'name': name, 78 | 'dset': dset, 79 | 'input_fnames': input_fnames, 80 | 'target_fnames': target_fnames 81 | } 82 | 83 | i = 0 84 | for loader in loaders: 85 | print('Predicting block_{:d}, inputs: {:d}'.format(i, len(loader.dataset))) 86 | start = time.time() 87 | pred_block = predictions.get_mask_predictions( 88 | model, loader, threshold, W, H).astype('uint8') 89 | if i == 0: 90 | preds = predictions.save_pred(pred_fpath, pred_block, meta) 91 | else: 92 | preds = predictions.append_to_pred(preds, pred_block) 93 | i += 1 94 | print(training.get_time_msg(start)) 95 | return pred_fpath 96 | 97 | 98 | def upsample_preds(preds, block_size, W, H): 99 | meta = preds[0].attrs['meta'].copy() 100 | n_inputs = preds[0].shape[0] 101 | up_fpath = os.path.join(cfg.PATHS['predictions'], 102 | meta['name']+'_up'+c.PRED_FILE_EXT) 103 | print("inputs", n_inputs, "preds",len(preds), up_fpath) 104 | 105 | if os.path.exists(up_fpath): 106 | print('Ens file exists. 
Overwriting') 107 | time.sleep(2) 108 | shutil.rmtree(up_fpath) 109 | 110 | i = 0 111 | start = time.time() 112 | while i < n_inputs: 113 | up_block = predictions.resize_batch( 114 | preds[i:i+block_size], W, H).astype('uint8') 115 | if i == 0: 116 | up_pred = predictions.save_pred(up_fpath, up_block, meta) 117 | else: 118 | up_pred = predictions.append_to_pred(up_pred, up_block) 119 | i += block_size 120 | 121 | print(utils.logger.get_time_msg(start)) 122 | return up_pred 123 | 124 | 125 | 126 | def get_and_write_probabilities_to_bcolz(): 127 | """If I need extra speed""" 128 | pass -------------------------------------------------------------------------------- /competitions/dogscats.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | 5 | import config as cfg 6 | import constants as c 7 | 8 | import datasets.metadata as meta 9 | import utils 10 | 11 | 12 | LABEL_NAMES = ['cat', 'dog'] 13 | LABEL_TO_IDX = meta.get_labels_to_idxs(LABEL_NAMES) 14 | IDX_TO_LABEL = meta.get_idxs_to_labels(LABEL_NAMES) 15 | SUB_HEADER = 'id,label' 16 | 17 | 18 | def make_metadata_file(): 19 | ''' 20 | First move the cats/dogs data in train.zip 21 | to `catsdogs/datasets/inputs/trn_jpg` 22 | ''' 23 | train_path = cfg.PATHS['datasets']['inputs']['trn_jpg'] 24 | _, fnames = utils.files.get_paths_to_files( 25 | train_path, strip_ext=True) 26 | lines = [] 27 | for name in fnames: 28 | label = name.split('.')[0] 29 | lines.append('{:s},{:s}\n'.format(name, label)) 30 | with open(cfg.METADATA_PATH, 'w') as f: 31 | f.writelines(lines) -------------------------------------------------------------------------------- /competitions/planet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | 5 | import config as cfg 6 | import constants as c 7 | 8 | import datasets.metadata as meta 9 | import utils 10 | 11 
| 12 | LABEL_NAMES = [ 13 | 'clear','partly_cloudy','haze','cloudy','primary','agriculture','road','water', 14 | 'cultivation','habitation','bare_ground','selective_logging','artisinal_mine','blooming', 15 | 'slash_burn','blow_down','conventional_mine'] 16 | LABEL_TO_IDX = meta.get_labels_to_idxs(LABEL_NAMES) 17 | IDX_TO_LABEL = meta.get_idxs_to_labels(LABEL_NAMES) 18 | SUB_HEADER = 'image_name,tags' 19 | -------------------------------------------------------------------------------- /competitions/team/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bfortuner/pytorch-federated-learning/99cc406a361f23a34aa4576c8781274fd0312158/competitions/team/__init__.py -------------------------------------------------------------------------------- /competitions/team/brendan.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bfortuner/pytorch-federated-learning/99cc406a361f23a34aa4576c8781274fd0312158/competitions/team/brendan.py -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | import init_project 4 | import constants as c 5 | 6 | # Main config 7 | HOSTNAME = socket.gethostname() 8 | PROJECT_NAME = 'dogscats' 9 | PROJECT_PATH = '/bigguy/data/' + PROJECT_NAME 10 | PROJECT_TYPE = c.SEGMENTATION 11 | IMG_INPUT_FORMATS = [c.JPG] 12 | IMG_TARGET_FORMATS = [c.BCOLZ] #segmentation or generative 13 | IMG_DATASET_TYPES = [c.TRAIN, c.TEST] 14 | METADATA_PATH = os.path.join(PROJECT_PATH, 'metadata.csv') 15 | PATHS = init_project.init_paths(PROJECT_PATH, IMG_DATASET_TYPES, 16 | IMG_INPUT_FORMATS, IMG_TARGET_FORMATS) 17 | 18 | # AWS Config 19 | AWS_ACCESS_KEY = os.getenv('KAGGLE_AWS_ACCESS_KEY', 'dummy') 20 | AWS_SECRET_KEY = os.getenv('KAGGLE_AWS_SECRET_ACCESS_KEY', 'dummy') 21 | 
# Shared string constants for the project: dataset split names, file
# extensions, metric names, prediction/ensemble types and experiment statuses.

# Project Types
CLASSIFICATION = 'classification'
SEGMENTATION = 'segmentation'
PROJECT_TYPES = [CLASSIFICATION, SEGMENTATION]

# Datasets
TRAIN = 'trn'
VAL = 'val'
TEST = 'tst'
FULL = 'full'
DSETS = [TRAIN, VAL, TEST, FULL]

# Transforms
JOINT = 'joint'
UPSAMPLE = 'upsample'
TARGET = 'target'
TENSOR = 'tensor'
MASK = 'mask'

# File
JPG = 'jpg'
TIF = 'tif'
PNG = 'png'
GIF = 'gif'
BCOLZ = 'bc'
JPG_EXT = '.' + JPG
TIF_EXT = '.' + TIF
PNG_EXT = '.' + PNG
GIF_EXT = '.' + GIF
BCOLZ_EXT = '.' + BCOLZ
IMG_EXTS = [JPG_EXT, TIF_EXT, PNG_EXT, GIF_EXT, BCOLZ_EXT]
CHECKPOINT_EXT = '.th'
EXPERIMENT_CONFIG_FILE_EXT = '.json'
EXPERIMENT_CONFIG_FNAME = 'config.json'
EXPERIMENT_HISTORY_FILE_EXT = '.csv'
EXP_FILE_EXT = '.zip'
PRED_FILE_EXT = '.bc'
SUBMISSION_FILE_EXT = '.csv'
ENSEMBLE_FILE_EXT = '.bc'
DSET_FOLD_FILE_EXT = '.json'
MODEL_EXT = '.mdl'
WEIGHTS_EXT = '.th'
OPTIM_EXT = '.th'

# Postfix
INPUT_POSTFIX = JPG_EXT
TARGET_POSTFIX = '_mask' + GIF_EXT

# Metrics
LOSS = 'Loss'
SCORE = 'Score'
ACCURACY = 'Accuracy'
F2_SCORE = 'F2'
ENSEMBLE_F2 = 'EnsembleF2'
DICE_SCORE = 'Dice'
MEAN = 'mean'
GMEAN = 'gmean'
VOTE = 'vote'
STD_DEV = 'std'
ENSEMBLE_METHODS = [MEAN, GMEAN]

# File Regex
# NOTE(review): these two patterns match a '.pth' suffix while WEIGHTS_EXT /
# OPTIM_EXT (and WEIGHTS_OPTIM_FNAME_REGEX below) use '.th'. The values are
# left untouched because the callers are not visible here -- confirm which
# suffix checkpoints are really written with.
WEIGHTS_FNAME_REGEX = r'weights-(\d+)\.pth$'
OPTIM_FNAME_REGEX = r'optim-(\d+)\.pth$'
WEIGHTS_OPTIM_FNAME_REGEX = r'(weights|optim)-(\d+)\.th$'
# Single definition, derived from the extension constants so the file name
# and the extension cannot drift apart (previously defined twice).
LATEST_WEIGHTS_FNAME = 'latest_weights{:s}'.format(WEIGHTS_EXT)
LATEST_OPTIM_FNAME = 'latest_optim{:s}'.format(OPTIM_EXT)
LATEST = 'latest'


# Predictions
SINGLE_MODEL_PRED = 'single-basic'
SINGLE_MODEL_TTA_PRED = 'single-tta'
PREDICTION_TYPES = [SINGLE_MODEL_PRED, SINGLE_MODEL_TTA_PRED]
SINGLE_MODEL_ENSEMBLE = 'single-ens'
SINGLE_MODEL_TTA_ENSEMBLE = 'single-ens-tta'
ENSEMBLE_TYPES = [SINGLE_MODEL_ENSEMBLE, SINGLE_MODEL_TTA_ENSEMBLE]
DEFAULT_BLOCK_NAME = 'preds'

# Ensembles
MEGA_ENSEMBLE = 'mega-ens'
MEGA_ENSEMBLE_TYPES = [MEGA_ENSEMBLE]


# Experiments
INITIALIZED = 'INITIALIZED'
RESUMED = 'RESUMED'
COMPLETED = 'COMPLETED'
IN_PROGRESS = 'IN_PROGRESS'
FAILED = 'FAILED'
MAX_PATIENCE_EXCEEDED = 'MAX_PATIENCE_EXCEEDED'
EXPERIMENT_STATUSES = [INITIALIZED, RESUMED, COMPLETED,
                       IN_PROGRESS, FAILED, MAX_PATIENCE_EXCEEDED]
EXP_ID_FIELD = 'exp_id'
ES_EXP_KEY_FIELD = 'key'
MODEL_FNAME = 'model{:s}'.format(MODEL_EXT)
OPTIM_FNAME = 'optim{:s}'.format(OPTIM_EXT)


# Data Aug
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def get_data_aug_summary(transforms):
    """Summarize a transform pipeline as ``(class name, attributes)`` pairs.

    Args:
        transforms: object exposing a ``transforms`` iterable of transform
            instances (for example a ``torchvision.transforms.Compose``).

    Returns:
        A list with one ``(name, attrs)`` tuple per transform, in pipeline
        order. ``attrs`` is the transform's live ``__dict__``, not a copy.
    """
    return [(str(t.__class__.__name__), t.__dict__)
            for t in transforms.transforms]


def _resize_transform(scale):
    """Build the resize step, tolerating torchvision versions without ``Scale``."""
    # Prefer ``Scale`` when it exists so behaviour on the torchvision version
    # this project was written against is unchanged; newer releases only
    # provide ``Resize``.
    resize_cls = getattr(torchvision.transforms, 'Scale', None)
    if resize_cls is None:
        resize_cls = torchvision.transforms.Resize
    return resize_cls(scale)


def _compose(steps, normalize):
    """Append the optional ``normalize`` step and wrap ``steps`` in a Compose."""
    if normalize is not None:
        steps.append(normalize)
    return torchvision.transforms.Compose(steps)


def get_basic_transform(scale, normalize=None):
    """Return a pipeline: resize -> ToTensor -> optional ``normalize``.

    Args:
        scale: size argument forwarded to the resize transform.
        normalize: optional transform appended after ``ToTensor``.
    """
    steps = [
        _resize_transform(scale),
        torchvision.transforms.ToTensor(),
    ]
    return _compose(steps, normalize)


def get_single_pil_transform(scale, augmentation, normalize=None):
    """Return a pipeline: resize -> ``augmentation`` -> ToTensor -> optional ``normalize``.

    ``augmentation`` runs before ``ToTensor``, i.e. on whatever the resize
    step returns rather than on a tensor.

    Args:
        scale: size argument forwarded to the resize transform.
        augmentation: transform applied between the resize and ``ToTensor``.
        normalize: optional transform appended after ``ToTensor``.
    """
    steps = [
        _resize_transform(scale),
        augmentation,
        torchvision.transforms.ToTensor(),
    ]
    return _compose(steps, normalize)
def _pack_outputs(outputs):
    """Return the bare result for a single input, otherwise the full list.

    The original check (``idx > 1``) only returned the list for three or more
    inputs, so a call with exactly two inputs (e.g. image + mask) silently
    dropped the second result.
    """
    return outputs[0] if len(outputs) == 1 else outputs


class RandomRotate90(object):
    """Apply ``random_rotate_90`` with probability ``p`` to every input."""

    def __init__(self, p=0.75):
        self.p = p

    def __call__(self, *inputs):
        # Each input draws its own random rotation.
        outputs = [random_rotate_90(input_, self.p) for input_ in inputs]
        return _pack_outputs(outputs)


class BinaryMask(object):
    """Binarize inputs in place: values >= ``thresholds`` become 1.0, values below become 0.0."""

    def __init__(self, thresholds):
        self.thresholds = thresholds

    def __call__(self, *inputs):
        outputs = []
        for input_ in inputs:
            # Compute both masks before writing so the freshly written 1.0
            # values are never re-compared against the threshold (which would
            # zero everything whenever the threshold is above 1).
            above = input_ >= self.thresholds
            below = input_ < self.thresholds
            input_[above] = 1.0
            input_[below] = 0.0
            outputs.append(input_)
        return _pack_outputs(outputs)


class Slice1D(object):
    """Select index ``slice_idx`` along the first axis of a 3-D input, then
    re-insert a singleton axis at position ``dim``."""

    def __init__(self, dim=0, slice_idx=0):
        self.dim = dim
        self.slice_idx = slice_idx

    def __call__(self, *inputs):
        outputs = [
            torch.unsqueeze(input_[self.slice_idx, :, :], dim=self.dim)
            for input_ in inputs
        ]
        return _pack_outputs(outputs)


class RandomHueSaturation(object):
    """Apply ``random_hue_saturation`` with probability ``u`` to every input.

    Args:
        hue_shift: ``(low, high)`` range the hue offset is drawn from.
        sat_shift: ``(low, high)`` range the saturation offset is drawn from.
        val_shift: ``(low, high)`` range the value offset is drawn from.
        u: probability forwarded to ``random_hue_saturation``.
    """

    def __init__(self, hue_shift=(-180, 180), sat_shift=(-255, 255),
                 val_shift=(-255, 255), u=0.5):
        self.hue_shift = hue_shift
        self.sat_shift = sat_shift
        self.val_shift = val_shift
        self.u = u

    def __call__(self, *inputs):
        outputs = [
            random_hue_saturation(input_, self.hue_shift, self.sat_shift,
                                  self.val_shift, self.u)
            for input_ in inputs
        ]
        return _pack_outputs(outputs)
def random_rotate_90(pil_img, p=1.0):
    """Rotate ``pil_img`` by 90, 180 or 270 degrees with probability ``p``.

    Args:
        pil_img: object exposing ``rotate(angle)`` (e.g. a PIL image).
        p: probability of rotating; otherwise the input is returned as is.

    Returns:
        The result of ``pil_img.rotate(angle)`` or the untouched input.
    """
    if random.random() >= p:
        return pil_img
    quarter_turns = random.randint(1, 3)
    return pil_img.rotate(quarter_turns * 90)
/(aspect**0.5) 177 | dx = round(random.uniform(shift_limit[0],shift_limit[1])*width ) 178 | dy = round(random.uniform(shift_limit[0],shift_limit[1])*height) 179 | 180 | cc = math.cos(angle/180*math.pi)*(sx) 181 | ss = math.sin(angle/180*math.pi)*(sy) 182 | rotate_matrix = np.array([ [cc,-ss], [ss,cc] ]) 183 | 184 | box0 = np.array([ [0,0], [width,0], [width,height], [0,height], ]) 185 | box1 = box0 - np.array([width/2,height/2]) 186 | box1 = np.dot(box1,rotate_matrix.T) + np.array([width/2+dx,height/2+dy]) 187 | box0 = box0.astype(np.float32) 188 | box1 = box1.astype(np.float32) 189 | mat = cv2.getPerspectiveTransform(box0,box1) 190 | image = cv2.warpPerspective(image, mat, (width,height), 191 | flags=cv2.INTER_LINEAR,borderMode=borderMode,borderValue=(0,0,0,)) 192 | #cv2.BORDER_CONSTANT, borderValue = (0, 0, 0)) #cv2.BORDER_REFLECT_101 193 | 194 | box0 = np.array([ [0,0], [width,0], [width,height], [0,height], ]) 195 | box1 = box0 - np.array([width/2,height/2]) 196 | box1 = np.dot(box1,rotate_matrix.T) + np.array([width/2+dx,height/2+dy]) 197 | box0 = box0.astype(np.float32) 198 | box1 = box1.astype(np.float32) 199 | mat = cv2.getPerspectiveTransform(box0,box1) 200 | label = cv2.warpPerspective(label, mat, (width,height), 201 | flags=cv2.INTER_LINEAR,borderMode=borderMode,borderValue=(0,0,0,)) 202 | #cv2.BORDER_CONSTANT, borderValue = (0, 0, 0)) #cv2.BORDER_REFLECT_101 203 | image = torch.from_numpy(image.transpose(2,0,1)) 204 | label = np.expand_dims(label, 0) 205 | label = torch.from_numpy(label)#.transpose(2,0,1)) 206 | return image,label 207 | 208 | 209 | blurTransform = torchvision.transforms.Lambda( 210 | lambda img: img.filter(ImageFilter.GaussianBlur(1.5))) -------------------------------------------------------------------------------- /datasets/data_folds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | 5 | import utils.files 6 | import constants as c 7 | 8 | 9 
def make_bag(fpaths, targets):
    """Draw a bootstrap sample (with replacement) of paths and their targets.

    The bag has the same size as the input. Every index, including 0, is
    eligible (the original ``randint(1, n - 1)`` could never pick the first
    item and raised for a single-element input).

    Args:
        fpaths: sequence of file paths.
        targets: sequence of targets aligned with ``fpaths``.

    Returns:
        ``(bag_fpaths, bag_targets)`` where ``bag_targets`` is a numpy array.
    """
    n_items = len(fpaths)
    picks = [random.randrange(n_items) for _ in range(n_items)]
    bag_fpaths = [fpaths[i] for i in picks]
    bag_targets = [targets[i] for i in picks]
    return bag_fpaths, np.array(bag_targets)


def verify_bag(trn_fpaths_bag):
    """Return how many times the most frequent path occurs in the bag (0 if empty)."""
    counts = {}
    for fpath in trn_fpaths_bag:
        counts[fpath] = counts.get(fpath, 0) + 1
    return max(counts.values(), default=0)


def make_fold(name, trn_path, tst_path, folds_dir,
              val_size, shuffle=True):
    """Create a train/val/test split of file names and save it as JSON.

    Args:
        name: fold name; the file is ``<folds_dir>/<name>`` + the fold extension.
        trn_path: directory listed for training file names.
        tst_path: directory listed for test file names.
        folds_dir: directory the fold file is written to.
        val_size: number of training names held out for validation, taken
            from the end of the (optionally shuffled) list.
        shuffle: shuffle the training names before splitting.

    Returns:
        The fold dict keyed by the train/val/test constants.
    """
    _, trn_fnames = utils.files.get_paths_to_files(
        trn_path, file_ext=c.JPG, sort=True, strip_ext=True)
    _, tst_fnames = utils.files.get_paths_to_files(
        tst_path, file_ext=c.JPG, sort=True, strip_ext=True)

    if shuffle:
        random.shuffle(trn_fnames)

    # Split on an explicit index: ``names[:-0]`` is empty, so the negative
    # slice used before produced an empty training set when val_size == 0.
    split_at = max(len(trn_fnames) - val_size, 0)
    fold = {
        c.TRAIN: trn_fnames[:split_at],
        c.VAL: trn_fnames[split_at:],
        c.TEST: tst_fnames
    }

    fpath = os.path.join(folds_dir, name + c.DSET_FOLD_FILE_EXT)
    utils.files.save_json(fpath, fold)
    return fold


def load_data_fold(folds_dir, name):
    """Load the fold JSON written by ``make_fold`` for ``name``."""
    fpath = os.path.join(folds_dir, name + c.DSET_FOLD_FILE_EXT)
    return utils.files.load_json(fpath)


def get_fpaths_from_fold(fold, dset, dset_path, postfix=''):
    """Return ``dset_path/<name><postfix>`` for every name in ``fold[dset]``."""
    return [os.path.join(dset_path, fname + postfix) for fname in fold[dset]]


def get_targets_from_fold(fold, dset, lookup):
    """Return the targets for ``fold[dset]`` as a numpy array.

    Each name is cut at its first ``'.'`` before being looked up in ``lookup``.
    """
    img_names = [fname.split('.')[0] for fname in fold[dset]]
    return np.array([lookup[img] for img in img_names])


def get_fpaths_targets_from_fold(fold, dset, dset_path, lookup):
    """Return ``(fpaths, targets)`` for one split of the fold.

    Targets now come from ``get_targets_from_fold``; the previous code called
    ``get_fpaths_from_fold`` with ``lookup`` as the directory argument.
    """
    fpaths = get_fpaths_from_fold(fold, dset, dset_path)
    targets = get_targets_from_fold(fold, dset, lookup)
    return fpaths, targets
class MixDataLoader():
    """
    Combines batches from two data loaders.
    Useful for pseudolabeling.

    Each step takes one ``(inputs, targets, fnames)`` batch from each loader
    and concatenates them. An epoch has ``len(dl1)`` steps; if ``dl2`` runs
    out first, iteration ends early.
    """
    def __init__(self, dl1, dl2):
        self.dl1 = dl1
        self.dl2 = dl2
        self.dl1_iter = iter(dl1)
        self.dl2_iter = iter(dl2)
        self.n = len(dl1)
        self.cur = 0

    def _reset(self):
        self.cur = 0

    def _cat_lst(self, fn1, fn2):
        # Normalize to lists so a list/tuple mix from the loaders still joins.
        return list(fn1) + list(fn2)

    def _cat_tns(self, t1, t2):
        return torch.cat([t1, t2])

    def __next__(self):
        # Check the bound before touching the underlying iterators; the old
        # ``while ...: return`` fetched first and could fall through and
        # return None instead of signalling the end of the epoch.
        if self.cur >= self.n:
            raise StopIteration
        x1, y1, f1 = next(self.dl1_iter)
        x2, y2, f2 = next(self.dl2_iter)
        self.cur += 1
        return (self._cat_tns(x1, x2), self._cat_tns(y1, y2),
                self._cat_lst(f1, f2))

    def __iter__(self):
        # Restart both loaders so the object can be iterated once per epoch.
        self.cur = 0
        self.dl1_iter = iter(self.dl1)
        self.dl2_iter = iter(self.dl2)
        return self

    def __len__(self):
        return self.n


def get_batch(dataset, batch_size, shuffle=False):
    """Return the first ``(inputs, targets, img_paths)`` batch of ``dataset``."""
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle)
    inputs, targets, img_paths = next(iter(dataloader))
    return inputs, targets, img_paths


def get_data_loader(dset, batch_size, shuffle=False,
                    n_workers=1, pin_memory=False):
    """Build a ``DataLoader`` over ``dset`` with the given batching options."""
    return DataLoader(dset, batch_size, shuffle=shuffle,
                      pin_memory=pin_memory, num_workers=n_workers)
def tif_loader(path, ref_stds=None, ref_means=None):
    """Load a 4-channel image and optionally calibrate it to reference statistics.

    Channels 0 and 2 are swapped while channels 1 and 3 keep their position
    (presumably BGR+NIR -> RGB+NIR; confirm against the data source).

    The previous version always called ``calibrate_image`` without its two
    required arguments and therefore raised ``TypeError`` on every call.

    Args:
        path: image path passed to ``io.imread``.
        ref_stds: per-channel reference standard deviations, or ``None``.
        ref_means: per-channel reference means, or ``None``.

    Returns:
        The reordered image, calibrated only when both reference statistics
        are supplied.
    """
    image = io.imread(path)[:, :, (2, 1, 0, 3)]
    if ref_stds is None or ref_means is None:
        return image
    return calibrate_image(image, ref_stds, ref_means)


def calibrate_image(rgb_image, ref_stds, ref_means):
    """Standardize each channel, then rescale it to the reference statistics.

    Per channel: ``(x - mean) / std * ref_std + ref_mean``, clipped to
    ``[0, 255]`` and cast to ``uint8``. Mean and std are taken over the first
    two axes.

    Args:
        rgb_image: array indexed as ``[row, col, channel]``.
        ref_stds: per-channel target standard deviations.
        ref_means: per-channel target means.

    Returns:
        ``uint8`` array with the same shape as ``rgb_image``.
    """
    res = rgb_image.astype('float32')
    means = np.mean(res, axis=(0, 1))
    stds = np.std(res, axis=(0, 1))
    # A constant channel has std 0; dividing would give NaN and an undefined
    # uint8 cast. Its centered values are all 0, so any non-zero divisor maps
    # the channel onto ``ref_means``.
    stds = np.where(stds == 0, 1.0, stds)
    calibrated = (res - means) / stds * ref_stds + ref_means
    return np.clip(calibrated, 0, 255).astype('uint8')
class FileDataset(torch.utils.data.Dataset):
    """Dataset of files loaded on demand, yielding ``(input, target, path)``.

    Args:
        fpaths: sequence of file paths.
        img_loader: key into the module-level ``loaders`` table, or a
            callable ``path -> object`` used directly.
        targets: optional sequence aligned with ``fpaths``.
        transform: optional callable applied to the loaded input.
        target_transform: optional callable applied to the raw target; when
            set, its result is returned as is.
    """

    def __init__(self, fpaths,
                 img_loader='pil',
                 targets=None,
                 transform=None,
                 target_transform=None):
        self.fpaths = fpaths
        self.loader = self._get_loader(img_loader)
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def _get_loader(self, loader_type):
        # Accept a ready-made loader function as well as a registered name.
        if callable(loader_type):
            return loader_type
        return loaders[loader_type]

    def _get_target(self, index):
        if self.targets is None:
            # Placeholder so batches keep a fixed structure. Zeros instead of
            # ``torch.FloatTensor(1)``, which is uninitialized memory.
            return torch.zeros(1)
        target = self.targets[index]
        if self.target_transform is not None:
            return self.target_transform(target)
        # NOTE(review): ``torch.FloatTensor(n)`` with an int ``n`` allocates
        # an uninitialized tensor of size n, so targets are assumed to be
        # sequences/arrays here -- confirm for scalar class labels.
        return torch.FloatTensor(target)

    def _get_input(self, index):
        img = self.loader(self.fpaths[index])
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        input_ = self._get_input(index)
        target = self._get_target(index)
        img_path = self.fpaths[index]
        return input_, target, img_path

    def __len__(self):
        return len(self.fpaths)


class MultiInputDataset(FileDataset):
    """``FileDataset`` that also yields a second, pre-computed input.

    Items are ``(input, target, other_input, path)``. ``other_inputs`` must be
    indexable by sample index; it is returned untouched.
    """

    def __init__(self, fpaths,
                 img_loader='pil',  # 'tns', 'npy'
                 targets=None,
                 other_inputs=None,
                 transform=None,
                 target_transform=None):
        super().__init__(fpaths, img_loader, targets,
                         transform, target_transform)
        self.other_inputs = other_inputs

    def _get_other_input(self, index):
        return self.other_inputs[index]

    def __getitem__(self, index):
        input_ = self._get_input(index)
        target = self._get_target(index)
        other_input = self._get_other_input(index)
        img_path = self.fpaths[index]
        return input_, target, other_input, img_path


class MultiTargetDataset(FileDataset):
    """``FileDataset`` that also yields a second target.

    Items are ``(input, target, other_target, path)``.
    """

    def __init__(self, fpaths,
                 img_loader='pil',
                 targets=None,
                 other_targets=None,
                 transform=None,
                 target_transform=None):
        super().__init__(fpaths, img_loader, targets,
                         transform, target_transform)
        self.other_targets = other_targets

    def _get_other_target(self, index):
        if self.other_targets is None:
            # Deterministic placeholder (see FileDataset._get_target).
            return torch.zeros(1)
        return torch.FloatTensor(self.other_targets[index])

    def __getitem__(self, index):
        input_ = self._get_input(index)
        target = self._get_target(index)
        other_target = self._get_other_target(index)
        img_path = self.fpaths[index]
        return input_, target, other_target, img_path
def get_metadata_df(fpath):
    """Read the metadata CSV into a DataFrame with columns ``id`` and ``labels``.

    The file's own header row is discarded and replaced by those two names.
    """
    return pd.read_csv(fpath, header=0, names=['id', 'labels'])


def get_labels_to_idxs(label_names):
    """Map each label name to its position in ``label_names``."""
    return {label: idx for idx, label in enumerate(label_names)}


def get_idxs_to_labels(label_names):
    """Map each position in ``label_names`` to its label name."""
    return dict(enumerate(label_names))


def _tag_idxs(tags, label_to_idx, delim=' '):
    """Return the label indices named in a ``delim``-separated tag string."""
    return [label_to_idx[tag] for tag in tags.split(delim)]


def _stack_one_hots(tag_strings, label_names):
    """Build a ``(len(tag_strings), len(label_names))`` float64 one-hot matrix."""
    label_to_idx = get_labels_to_idxs(label_names)
    # Pre-allocate instead of growing with np.append, which copies the whole
    # matrix on every row.
    onehots = np.zeros((len(tag_strings), len(label_names)))
    for row, tags in enumerate(tag_strings):
        onehots[row, _tag_idxs(tags, label_to_idx)] = 1
    return onehots


def convert_tags_to_one_hots(tags, label_names, delim=' '):
    """Encode a ``delim``-separated tag string as a float32 multi-hot vector.

    Raises:
        KeyError: if a tag is not in ``label_names``.
    """
    label_to_idx = get_labels_to_idxs(label_names)
    onehot = np.zeros((len(label_names),), dtype=np.float32)
    onehot[_tag_idxs(tags, label_to_idx, delim)] = 1
    return onehot


def get_one_hots_array(meta_fpath, label_names):
    """Return one multi-hot row per metadata row, in file order (float64)."""
    meta_df = get_metadata_df(meta_fpath)
    # Read the column by name; positional ``row[1]`` on a labelled Series is
    # deprecated in pandas.
    return _stack_one_hots(list(meta_df['labels']), label_names)


def get_one_hots_from_fold(meta_fpath, fold, dset, label_names):
    """Return one multi-hot row per name in ``fold[dset]``, in fold order (float64).

    Raises:
        IndexError: if a name in the fold has no metadata row (same exception
            type as the previous per-name DataFrame filter).
    """
    meta_df = get_metadata_df(meta_fpath)
    # One pass to index tags by id instead of filtering the frame per name.
    # ``setdefault`` keeps the first row for a duplicated id, as before.
    tags_by_id = {}
    for img_id, tags in zip(meta_df['id'], meta_df['labels']):
        tags_by_id.setdefault(img_id, tags)

    tag_strings = []
    for name in fold[dset]:
        if name not in tags_by_id:
            raise IndexError(
                'no metadata row for image id {!r}'.format(name))
        tag_strings.append(tags_by_id[name])
    return _stack_one_hots(tag_strings, label_names)


def get_label_idx_by_name(label, label_names):
    """Return the position of ``label`` in ``label_names`` (``ValueError`` if absent)."""
    return label_names.index(label)


def get_tags_from_preds(preds, label_names):
    """Convert each prediction vector into a space-separated tag string."""
    return [' '.join(convert_one_hot_to_tags(pred, label_names))
            for pred in preds]


def convert_one_hot_to_tags(onehot, label_names):
    """Return the label names whose entry in ``onehot`` equals 1."""
    return [label_names[idx] for idx, val in enumerate(onehot) if val == 1]
https://raw.githubusercontent.com/bfortuner/pytorch-federated-learning/99cc406a361f23a34aa4576c8781274fd0312158/docs/visdom2.png -------------------------------------------------------------------------------- /ensembles/__init__.py: -------------------------------------------------------------------------------- 1 | from .ens_utils import * 2 | import os -------------------------------------------------------------------------------- /ensembles/ens_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import numpy as np 5 | import pandas as pd 6 | import shutil 7 | 8 | import config as cfg 9 | import constants as c 10 | import utils 11 | import predictions 12 | 13 | 14 | 15 | def get_ensemble_fpath(basename, dset): 16 | fname = '{:s}_{:s}_{:s}'.format(basename, 'ens', dset + c.PRED_FILE_EXT) 17 | return os.path.join(cfg.PATHS['predictions'], fname) 18 | 19 | 20 | def get_ensemble_meta(name, fpaths): 21 | preds = [predictions.load_pred(f) for f in fpaths] 22 | meta = preds[0].attrs['meta'].copy() 23 | meta['name'] = name 24 | meta['members'] = {p.attrs['meta']['name']:p.attrs['meta'] for p in preds} 25 | print("members", list(meta['members'].keys())) 26 | return meta 27 | 28 | 29 | def ens_prediction_files(ens_fpath, pred_fpaths, block_size=1, 30 | method=c.MEAN, meta=None): 31 | preds = [predictions.load_pred(f) for f in pred_fpaths] 32 | n_inputs = preds[0].shape[0] 33 | if os.path.exists(ens_fpath): 34 | print('Ens file exists. 
def build_scores(loss, score):
    """Bundle a loss and a score into a dict keyed by the metric-name constants."""
    scores = {}
    scores[c.LOSS] = loss
    scores[c.SCORE] = score
    return scores


def build_metadata(labels, scores, thresholds, pred_type, dset):
    """Assemble the metadata dict stored alongside a prediction.

    The ``created`` field is the current local time formatted as
    ``MM/DD/YYYY HH:MM:SS``; all other fields are passed through unchanged.
    """
    created = time.strftime("%m/%d/%Y %H:%M:%S", time.localtime())
    return dict(
        label_names=labels,
        scores=scores,
        thresholds=thresholds,
        pred_type=pred_type,
        dset=dset,
        created=created,
    )
def _exp_name_from_path(path):
    """Return the last directory component of a glob result such as ``'exps/foo/'``."""
    return path.strip('/').split('/')[-1]


def prune_experiments(exp_dir, exp_names):
    """Run ``Experiment.auto_prune`` for every named experiment in ``exp_dir``."""
    # Delete weights except for best weight in `n_bins`
    # NOTE(review): ``func=max`` together with the 'Loss' metric looks
    # inverted for a lower-is-better metric -- confirm against auto_prune.
    for name in exp_names:
        exp = Experiment(name, exp_dir)
        exp.review(verbose=False)
        exp.auto_prune(n_bins=5, metric_name='Loss', func=max)


def build_exp_summary_dict(exp):
    """Flatten an experiment's config and best metrics into one summary row.

    Args:
        exp: object exposing ``config`` (experiment configuration) and
            ``history`` (with a ``best_metrics`` mapping of
            ``{metric: {'epoch': ..., 'value': ...}}``).

    Returns:
        dict of scalar/string fields plus ``<metric>Epoch`` and
        ``<metric>Val`` entries for every metric in ``config.metrics``.
    """
    config = exp.config
    history = exp.history
    # The original literal repeated the 'model_name' and 'transforms' keys;
    # each key is written once here, with the value that used to win.
    dict_ = {
        'name': config.name,
        'exp_id': config.get_id(),
        'created': config.created,
        'fold': config.data['dset_fold'],
        'model_name': config.model['name'],
        'threshold': config.training['threshold'],
        'optim': config.optimizer['name'],
        'optim_params': str(config.optimizer['params']),
        'lr_adjuster': config.lr_adjuster['name'],
        'lr_adjuster_params': str(config.lr_adjuster['params']),
        'criterion': config.criterion['name'],
        'transforms': ', '.join([t[0] for t in config.transforms]),
        'init_lr': config.training['initial_lr'],
        'wdecay': config.training['weight_decay'],
        'batch': config.training['batch_size'],
        'img_scl': config.data['img_scale'],
        'img_rescl': config.data['img_rescale'],
        'nepochs': config.training['n_epochs'],
    }
    for name in config.metrics:
        best = history.best_metrics[name]
        dict_[name + 'Epoch'] = best['epoch']
        dict_[name + 'Val'] = best['value']
    return dict_


def build_exps_df_from_dir(exps_dir):
    """Return a DataFrame with one summary row per experiment directory in ``exps_dir``."""
    summaries = []
    for path in glob(exps_dir + '/*/'):
        # Pass the directory name, not the whole glob path, as the experiment
        # name -- consistent with upload_experiments below.
        exp = Experiment(_exp_name_from_path(path), exps_dir)
        exp.review(verbose=False)
        summaries.append(build_exp_summary_dict(exp))
    return pd.DataFrame(summaries)


def upload_experiments(exp_dir):
    """Review and ``save()`` every experiment directory found in ``exp_dir``."""
    # Previously defined twice with identical bodies; one definition kept.
    for path in glob(exp_dir + '/*/'):
        exp = Experiment(_exp_name_from_path(path), exp_dir)
        exp.review(verbose=False)
        exp.save()
class ExperimentConfig():
    """Plain-data description of one experiment.

    Every constructor argument is stored as an instance attribute; the
    attribute dictionary is what gets written to disk by `save` and what
    `to_dict` / `to_json` / `to_doc` are derived from.
    """

    def __init__(self, name, parent_dir, created, metrics, aux_metrics,
                 model, optimizer, lr_adjuster, criterion, transforms,
                 visualizers, training, data, hardware, other, progress=None):
        # NOTE: assignment order defines the key order of the saved JSON.
        self.name = name
        self.parent_dir = parent_dir
        self.fpath = os.path.join(parent_dir, name,
                                  c.EXPERIMENT_CONFIG_FNAME)
        self.created = created
        self.metrics = metrics
        self.aux_metrics = aux_metrics
        self.visualizers = visualizers
        self.model = model
        self.optimizer = optimizer
        self.lr_adjuster = lr_adjuster
        self.criterion = criterion
        self.transforms = transforms
        self.data = data
        self.training = training
        self.hardware = hardware
        self.other = other
        # A fresh dict per instance; never share a mutable default.
        self.progress = {} if progress is None else progress
        self.model_name = self.model['name']

    def get_id(self):
        """Return the part of the name after the last '-id' marker."""
        return self.name.split('-id')[-1]

    def get_display_name(self):
        """Return the part of the name before the first '-id' marker."""
        return self.name.split('-id')[0]

    def summary(self, include_model=True):
        """Print the config as JSON (without the model entry).

        When `include_model` is truthy the model's layer listing is
        printed afterwards.
        """
        d = self.to_dict()
        del d['model']
        print(json.dumps(d, indent=4, ensure_ascii=False))
        if include_model:
            print(self.model['layers'])

    def save(self, s3=config.S3_ENABLED, es=config.ES_ENABLED):
        """Write the config JSON to `self.fpath`; optionally upload it."""
        utils.files.save_json(self.fpath, self.to_dict())
        if s3:
            s3_client.upload_experiment_config(self.fpath, self.name)
        if es:
            es_client.upload_experiment_config(self)

    def to_dict(self):
        """Return a shallow copy of the instance attributes.

        `self.__dict__` is the live attribute namespace: adding or
        deleting keys on it adds or deletes instance attributes. Handing
        out a copy keeps callers from altering the config by accident.
        Nested values are still shared; use `to_doc` for a deep copy.
        """
        return dict(self.__dict__)

    def to_json(self):
        """Return the config serialized as an indented JSON string."""
        return json.dumps(self.to_dict(), indent=4, ensure_ascii=False)

    def to_html(self):
        """Return the config rendered by `utils.general.dict_to_html_ul`."""
        dict_ = self.to_dict()
        html = utils.general.dict_to_html_ul(dict_)
        return html

    def to_doc(self):
        """Return a deep-copied document for indexing.

        Adds the id/key/display-name fields, stringifies `transforms`
        and drops the `model` entry. The instance itself is untouched.
        """
        doc = copy.deepcopy(self.to_dict())
        doc[c.EXP_ID_FIELD] = self.get_id()
        doc[c.ES_EXP_KEY_FIELD] = self.get_id()
        doc['display_name'] = self.get_display_name()
        doc['transforms'] = str(doc['transforms'])
        del doc['model']
        return doc
def remove_large_items(dict_, max_len=100):
    """Return a copy of `dict_` without oversized list/dict values.

    Rules, applied per top-level value:
      * list  -> kept unchanged unless it has more than `max_len` items,
                 in which case the key is dropped;
      * dict  -> replaced by ``str(value.items())`` when it has fewer
                 than `max_len` entries, otherwise the key is dropped
                 (note the bound is strict here, unlike for lists);
      * other -> kept unchanged.

    Args:
        dict_: mapping to filter; it is not modified.
        max_len: size limit for list and dict values (default 100).

    Returns:
        A new dict with the surviving entries.
    """
    new_dict = {}
    for k, v in dict_.items():
        if isinstance(v, list):
            if len(v) <= max_len:
                new_dict[k] = v
        elif isinstance(v, dict):
            if len(v) < max_len:
                new_dict[k] = str(v.items())
        else:
            new_dict[k] = v
    return new_dict
def get_metrics_config(metrics):
    """Return the `name` of every metric, in order, as a list."""
    names = []
    for metric in metrics:
        names.append(metric.name)
    return names
def save(self, config, s3=cfg.S3_ENABLED, es=cfg.ES_ENABLED):
    """Write the per-epoch history summary CSV and optionally upload it.

    One row per epoch: an 'Epoch' column (1-based) followed by a
    train/val column pair per metric and one column per aux metric.

    Args:
        config: experiment config, forwarded to the ES uploader.
        s3: upload the CSV through `s3_client` when truthy.
        es: upload the history through `es_client` when truthy.
    """
    df = pd.DataFrame()
    for metric in self.metrics:
        history = self.metrics_history[metric.name]
        df[c.TRAIN+'_'+metric.name] = history[c.TRAIN]
        df[c.VAL+'_'+metric.name] = history[c.VAL]

    for aux_metric in self.aux_metrics:
        df[aux_metric.name] = self.metrics_history[aux_metric.name]

    # Derive the epoch count from the frame itself so an experiment
    # without metrics does not hit an unbound loop variable.
    epochs = pd.Series(range(1, len(df.index) + 1))
    df.insert(0, 'Epoch', epochs)
    df.to_csv(self.summary_fpath, index=False)

    if s3:
        s3_client.upload_experiment_history(self.summary_fpath,
                                            self.exp_name)
    if es:
        es_client.upload_experiment_history(config, self)
def append_history_to_file(self, fpath, values, epoch):
    """Append one CSV row ``epoch,v1,v2,...`` to the file at `fpath`.

    Every value is written with exactly six decimal places, so any
    further precision is lost on disk.
    """
    row = ','.join(format(value, '.6f') for value in values)
    with open(fpath, 'a') as history_file:
        history_file.write('{},{}\n'.format(epoch, row))
def plot(self, save=False):
    """Plot train vs. validation curves, one log-scale figure per metric.

    Args:
        save: when truthy, each figure is written to
            ``<history_dir>/<metric name>.png`` and the images are then
            combined into ``<history_dir>/all_metrics.png`` through the
            external `convert +append` command.
    """
    import shlex  # only needed to quote paths for the shell command

    trn_data = self.get_dset_arr(c.TRAIN)
    val_data = self.get_dset_arr(c.VAL)
    # Column 0 holds the epoch numbers; metric i lives in column i+1.
    trn_epochs = trn_data[:, 0]
    val_epochs = val_data[:, 0]

    metric_fpaths = []
    for i, metric in enumerate(self.metrics):
        fig, ax = plt.subplots(1, 1, figsize=(6, 5))
        plt.plot(trn_epochs, trn_data[:, i+1], label='Train')
        plt.plot(val_epochs, val_data[:, i+1], label='Validation')
        plt.title(metric.name)
        plt.xlabel('Epoch')
        plt.ylabel(metric.name)
        plt.legend()
        ax.set_yscale('log')

        if save:
            metric_fpath = join(self.history_dir, metric.name+'.png')
            metric_fpaths.append(metric_fpath)
            plt.savefig(metric_fpath)

    # Combined View
    if save:
        all_metrics_fpath = join(self.history_dir, 'all_metrics.png')
        metric_fpaths.append(all_metrics_fpath)
        # Quote every path so directories containing spaces or shell
        # metacharacters do not break (or hijack) the command line.
        os.system('convert +append '
                  + ' '.join(shlex.quote(p) for p in metric_fpaths))

    plt.show()
def get_history_summary(self, epoch, early_stop_metric):
    """Return a multi-line text summary of the latest epoch.

    Line 1 is the epoch number, followed by one line per dataset
    (train, then val) listing the most recent value of every metric,
    and a final line with the best validation epoch/value recorded for
    `early_stop_metric`.
    """
    lines = ['Epoch: %d' % epoch]
    for dset in (c.TRAIN, c.VAL):
        latest = [
            '{:s}: {:.3f} '.format(
                metric.name, self.metrics_history[metric.name][dset][-1])
            for metric in self.metrics
        ]
        lines.append(dset.capitalize() + ' - ' + ''.join(latest))

    best = self.best_metrics[early_stop_metric]
    lines.append('Best val {:s}: Epoch {:d} - {:.3f}'.format(
        early_stop_metric, best['epoch'], best['value']))

    return '\n'.join(lines)
def get_best_values_in_bins(arr, n_bins, func):
    """Find the best value inside each of `n_bins` consecutive buckets.

    The sequence is cut into buckets of ``ceil(len(arr) / n_bins)``
    items (the last bucket may be shorter, and fewer than `n_bins`
    buckets result when the sequence is short).

    Args:
        arr: 1-D sequence or array of metric values.
        n_bins: number of buckets to aim for.
        func: pass the builtin `max` to pick the largest value per
            bucket; anything else picks the smallest.

    Returns:
        ``(best_idx, best_vals)``: two lists holding, per bucket, the
        position of the best value in `arr` and the value itself. Both
        are empty for empty input.
    """
    arr = np.asarray(arr)
    if len(arr) == 0:
        # Without this guard bucket_size is 0 and range() raises.
        return [], []

    bucket_size = math.ceil(len(arr) / n_bins)
    pick_best = np.argmax if func is max else np.argmin
    best_idx, best_vals = [], []
    for start in range(0, len(arr), bucket_size):
        bucket = arr[start:start+bucket_size]
        offset = int(pick_best(bucket))
        best_idx.append(start + offset)
        best_vals.append(bucket[offset])
    return best_idx, best_vals
def get_weight_fpaths_by_epoch(weights_dir, epochs):
    """Return the weight file paths in `weights_dir` whose epoch is wanted.

    The epoch is read from capture group 1 of `c.WEIGHTS_FNAME_REGEX`;
    a path is kept when that number is contained in `epochs`.
    """
    matches, fpaths = utils.files.get_matching_files_in_dir(
        weights_dir, c.WEIGHTS_FNAME_REGEX)
    return [
        fpaths[pos]
        for pos, match in enumerate(matches)
        if int(match.group(1)) in epochs
    ]
def get_id_from_name(exp_name):
    """Return whatever follows the last '-id' in `exp_name`.

    The whole name is returned when it contains no '-id' marker.
    """
    _, _, unique_id = exp_name.rpartition('-id')
    return unique_id
def init(self, config_dict):
    """Set up a brand-new experiment from `config_dict`.

    Builds the serializable config, keeps references to the live
    objects (metrics, model, optimizer, visualizers), creates the
    experiment directories and history files, then initializes the
    logger and visualizers and stores the model/optimizer components.
    """
    self.config = exp_config.create_config_from_dict(config_dict)
    self.config.progress['status'] = c.INITIALIZED

    # attribute on self  <-  key in config_dict
    live_objects = (
        ('metrics', 'metrics'),
        ('aux_metrics', 'aux_metrics'),
        ('model', 'model'),
        ('optim', 'optimizer'),
        ('visualizers', 'visualizers'),
    )
    for attr, key in live_objects:
        setattr(self, attr, config_dict[key])

    training_cfg = self.config.training
    self.max_patience = training_cfg['max_patience']
    self.early_stop_metric = training_cfg['early_stop_metric']
    self.history = ExperimentHistory(self.name, self.history_dir,
                                     self.metrics, self.aux_metrics)

    # Directories must exist before history files and the logger
    # are created inside them.
    self.init_dirs()
    self.history.init()
    self.init_logger()
    self.init_visualizers()
    self.save_components()
    self.model.logger = self.logger
def load(self, verbose=False):
    """Rebuild runtime state from `self.config` and the stored history.

    Recreates metrics, aux metrics and visualizers from the config,
    resumes the history, and restores epoch / best-metric bookkeeping
    from `config.progress`.

    Args:
        verbose: print the config summary when truthy.
    """
    self.metrics = metric_builder.get_metrics_from_config(self.config)
    self.aux_metrics = metric_builder.get_aux_metrics_from_config(self.config)
    self.visualizers = vis_utils.get_visualizers_from_config(self.config)
    self.history = ExperimentHistory(self.name, self.history_dir,
                                     self.metrics, self.aux_metrics)
    self.history.resume()
    self.max_patience = self.config.training['max_patience']
    self.early_stop_metric = self.config.training['early_stop_metric']
    self.epoch = self.config.progress['epoch']
    self.best_metrics = self.config.progress['best_metrics']
    best = self.best_metrics[self.early_stop_metric]
    self.best_epoch = best['epoch']
    self.best_epoch_value = best['value']
    if verbose:
        # summary() takes a boolean `include_model`, not a logger; the
        # logger previously passed here only worked because it is truthy.
        self.config.summary(include_model=True)
def save_model_state(self, save_now=False):
    """Save the current weights as 'latest', plus a per-epoch copy.

    The epoch-numbered copy is made when `save_now` is truthy or when
    the current epoch is a multiple of
    ``config.training['save_weights_cadence']``.
    """
    latest_fpath = self.get_weights_fpath()
    models.utils.save_weights(self.model, latest_fpath,
                              epoch=self.epoch, name=self.name)
    keep_copy = (
        save_now
        or self.epoch % self.config.training['save_weights_cadence'] == 0
    )
    if keep_copy:
        shutil.copyfile(latest_fpath, self.get_weights_fpath(self.epoch))
def train(self, trainer, trn_loader, val_loader, n_epochs=None):
    """Run the train/validate/checkpoint loop.

    Args:
        trainer: object providing `train`, `test` and `lr_adjuster`.
        trn_loader: loader passed to `trainer.train`.
        val_loader: loader passed to `trainer.test`.
        n_epochs: number of additional epochs to run; when None, runs
            up to and including ``config.training['n_epochs']``.

    Raises:
        Whatever the training loop raises. The original exception is
        re-raised unchanged after the progress status is set to FAILED.
    """
    start_epoch = self.epoch + 1 # Epochs start at 1
    self.config.progress['status'] = c.IN_PROGRESS
    self.config.progress['status_msg'] = 'Experiment in progress'

    if n_epochs is None:
        end_epoch = self.config.training['n_epochs'] + 1
    else:
        end_epoch = start_epoch + n_epochs
    try:
        for epoch in range(start_epoch, end_epoch):

            ### Adjust Lr ###
            lr_params = {'best_iter' : self.best_epoch}
            if trainer.lr_adjuster.iteration_type == 'epoch':
                trainer.lr_adjuster.adjust(self.optim, epoch, lr_params)
            current_lr = trainer.lr_adjuster.get_learning_rate(self.optim)

            ### Train ###
            trn_start_time = time.time()
            trn_metrics = trainer.train(self.model, trn_loader,
                self.config.training['threshold'], epoch, self.metrics)
            trn_msg = training.log_trn_msg(self.logger, trn_start_time,
                                           trn_metrics, current_lr, epoch)

            ### Test ###
            val_start_time = time.time()
            val_metrics = trainer.test(self.model, val_loader,
                self.config.training['threshold'], self.metrics)
            val_msg = training.log_val_msg(self.logger, val_start_time,
                                           val_metrics, current_lr)

            sys_mem = training.log_memory('')

            ### Save Metrics ###
            aux_metrics = [current_lr, sys_mem]
            self.history.save_metric(c.TRAIN, trn_metrics, epoch)
            self.history.save_metric(c.VAL, val_metrics, epoch)
            self.history.save_aux_metrics(aux_metrics, epoch)
            self.history.update_best_metrics()

            ### Checkpoint ###
            self.epoch = epoch
            self.update_progress()
            self.save()
            self.save_model_state()
            self.save_optim_state()
            self.update_visualizers('\n'.join([trn_msg, val_msg]))

            ### Early Stopping ###
            if training.early_stop(epoch, self.best_epoch, self.max_patience):
                msg = "Early stopping at epoch %d since no better %s found since epoch %d at %.3f" % (
                    epoch, self.early_stop_metric, self.best_epoch, self.best_epoch_value)
                self.config.progress['status'] = c.MAX_PATIENCE_EXCEEDED
                self.config.progress['status_msg'] = msg
                break

    except Exception as e:
        self.config.progress['status'] = c.FAILED
        # Store text, not the exception object: `progress` is part of
        # the config that gets serialized to JSON.
        self.config.progress['status_msg'] = '{}: {}'.format(
            type(e).__name__, e)
        # Bare raise keeps the original exception type and traceback.
        raise
    finally:
        # Still IN_PROGRESS here means neither failure nor early stop.
        if self.config.progress['status'] == c.IN_PROGRESS:
            self.config.progress['status'] = c.COMPLETED
            self.config.progress['status_msg'] = 'Experiment Complete!'
        if cfg.EMAIL_ENABLED:
            emailer.send_experiment_status_email(self, cfg.USER_EMAIL)
        self.log(self.config.progress['status_msg'])
class WeightedBCELoss():
    """Binary cross-entropy with a caller-supplied rescaling weight."""

    def __init__(self, weights):
        # Renames the class itself so that name-based lookups on the
        # criterion report 'WeightedBCE'.
        type(self).__name__ = 'WeightedBCE'
        self.weights = weights

    def __call__(self, output, target):
        """Return `F.binary_cross_entropy` of output vs. target, weighted."""
        loss = F.binary_cross_entropy(output, target, weight=self.weights)
        return loss
def sigmoid(z, c=1.0):
    """Return the logistic function of ``z`` with steepness ``c``.

    Computes ``1 / (1 + exp(-c * z))`` element-wise on a torch tensor.
    """
    return 1.0 / (1.0 + torch.exp(-c*z))


def get_smooth_f2_score(outputs, targets, thresholds, c=10.0):
    """Return a soft-thresholded F2 score averaged over rows.

    Args:
        outputs: tensor of scores; sums are taken over ``dim=1``, so it
            must have at least two dimensions.
        targets: tensor of labels, multiplied element-wise with the
            soft-thresholded outputs.
        thresholds: scalar or tensor broadcastable against ``outputs``.
        c: steepness of the sigmoid used as a soft threshold.

    Returns:
        Scalar tensor: mean over rows of ``5 * P * R / (4 * P + R)``.
    """
    eps = 1e-9
    # NOTE(review): the soft indicator is sigmoid(c * (thresholds - outputs)),
    # so it approaches 1 when an output is BELOW its threshold. Left
    # unchanged here; confirm with the callers that this is intended.
    outputs = sigmoid(thresholds - outputs, c).float()
    tot_out_pos = torch.sum(outputs, dim=1)
    tot_tar_pos = torch.sum(targets, dim=1)
    TP = torch.sum(outputs * targets, dim=1)

    P = TP / (tot_out_pos + eps)
    # eps belongs inside the denominator (as for P); the previous
    # ``TP / tot_tar_pos + eps`` divided by zero for rows with no positives.
    R = TP / (tot_tar_pos + eps)
    # Guard the 0/0 case when both precision and recall are zero.
    F2 = 5.0 * (P*R / (4*P + R + eps))
    return torch.mean(F2)
metric_utils.get_f2_score(preds, targets, average) 68 | 69 | def format(self, value): 70 | return value 71 | 72 | 73 | class DiceScore(Metric): 74 | def __init__(self): 75 | super().__init__(c.DICE_SCORE, minimize=False) 76 | 77 | def evaluate(self, loss, preds, probs, targets): 78 | return metric_utils.get_dice_score(preds, targets) 79 | 80 | def format(self, value): 81 | return value 82 | 83 | 84 | class EnsembleF2(Metric): 85 | def __init__(self, ens_probs, threshold): 86 | super().__init__('EnsembleF2', minimize=False) 87 | self.ens_probs = ens_probs 88 | self.threshold = threshold 89 | 90 | def evaluate(self, loss, preds, probs, targets): 91 | if probs.shape[0] != self.ens_probs.shape[1]: 92 | return .950 93 | average = 'samples' if targets.shape[1] > 1 else 'binary' 94 | probs = np.expand_dims(probs, 0) 95 | joined_probs = np.concatenate([self.ens_probs, probs]) 96 | joined_probs = np.mean(joined_probs, axis=0) 97 | preds = joined_probs > self.threshold 98 | return metric_utils.get_f2_score(preds, targets, average) 99 | 100 | def format(self, value): 101 | return value -------------------------------------------------------------------------------- /metrics/metric_builder.py: -------------------------------------------------------------------------------- 1 | import constants as c 2 | from . 
def get_metric_in_blocks(outputs, targets, block_size, metric):
    """Evaluate ``metric`` block by block and return the size-weighted mean.

    Args:
        outputs: sliceable sequence of model outputs.
        targets: sliceable sequence aligned with ``outputs``.
        block_size: positive number of items per block.
        metric: callable ``metric(out_block, tar_block)`` returning a number.

    Returns:
        Sum of ``len(block) * metric(block)`` divided by the item count.

    Raises:
        ValueError: if ``block_size`` is not positive (the previous manual
            loop never terminated in that case).
        ZeroDivisionError: if ``outputs`` is empty (unchanged behaviour).
    """
    if block_size < 1:
        raise ValueError(
            "block_size must be a positive integer, got %r" % (block_size,))
    weighted_sum = 0
    n_items = 0
    for start in range(0, len(outputs), block_size):
        out_block = outputs[start:start+block_size]
        tar_block = targets[start:start+block_size]
        weighted_sum += len(out_block) * metric(out_block, tar_block)
        n_items += len(out_block)
    return weighted_sum / n_items
def get_metrics_in_batches(model, loader, thresholds, metrics):
    """Average each metric over the batches produced by ``loader``.

    Args:
        model: network called on the CUDA inputs; put in eval mode first.
        loader: iterable with ``len()`` yielding ``(inputs, targets, ...)``.
        thresholds: forwarded to ``predictions.get_predictions``.
        metrics: callables ``m(preds, labels)`` returning a number.

    Returns:
        List with one per-batch mean for every entry of ``metrics``.
    """
    model.eval()
    n_batches = len(loader)
    metric_totals = [0 for _ in metrics]

    for data in loader:
        inputs, targets = data[0], data[1]
        if len(targets.size()) == 1:
            # The float column vector used to be computed and then
            # overwritten by the raw ``data[1]``; keep it this time.
            targets = targets.float().view(-1, 1)
        # ``async`` is a reserved word since Python 3.7 and newer PyTorch
        # renamed the flag to ``non_blocking``; passing it positionally
        # (device=None, flag=True) parses on every Python version.
        # NOTE(review): positional form assumed valid on old and new
        # PyTorch signatures -- confirm against the pinned version.
        inputs = Variable(inputs.cuda(None, True))
        targets = Variable(targets.cuda(None, True))

        output = model(inputs)

        labels = targets.data.cpu().numpy()
        probs = output.data.cpu().numpy()
        preds = predictions.get_predictions(probs, thresholds)

        for i, m in enumerate(metrics):
            metric_totals[i] += m(preds, labels)

    return [total / n_batches for total in metric_totals]


def get_accuracy(preds, targets):
    """Return the fraction of flattened positions where preds == targets."""
    preds = preds.flatten()
    targets = targets.flatten()
    correct = np.sum(preds==targets)
    return correct / len(targets)


def get_cross_entropy_loss(probs, targets):
    """Return the binary cross-entropy of numpy ``probs`` vs ``targets``.

    Both arrays are converted to float tensors, so float64 probabilities
    no longer clash with the float32 targets. Returns a Python float.
    """
    loss = F.binary_cross_entropy(
        Variable(torch.from_numpy(probs).float()),
        Variable(torch.from_numpy(targets).float()))
    # ``.data[0]`` raises on the 0-dim loss tensor of current PyTorch;
    # flattening first works for both 0-dim and 1-element tensors.
    return float(loss.data.view(-1)[0])
def get_f2_score(y_pred, y_true, average='samples'):
    """Return sklearn's F-beta score with ``beta=2`` for the predictions."""
    y_pred, y_true = np.array(y_pred), np.array(y_true)
    return fbeta_score(y_true, y_pred, beta=2, average=average)


def find_f2score_threshold(probs, targets, average='samples',
                           try_all=True, verbose=False, step=.01):
    """Scan thresholds in ``[0.1, 0.9)`` and return the best one for F2.

    Args:
        probs: array-like of predicted probabilities.
        targets: ground-truth labels accepted by ``get_f2_score``.
        average: averaging mode forwarded to ``get_f2_score``.
        try_all: unused; kept for interface compatibility.
        verbose: when exactly ``True``, print the best score and threshold.
        step: spacing between candidate thresholds.

    Returns:
        The best threshold rounded to 6 decimals.
    """
    probs = np.asarray(probs)
    best = 0
    best_score = -1
    for t in np.arange(0.1, 0.9, step):
        # Binarise at the candidate threshold. Previously the raw
        # probabilities were scored and ``t`` was passed as ``average``.
        preds = (probs >= t).astype('uint8')
        score = get_f2_score(preds, targets, average)
        if score > best_score:
            best_score = score
            best = t
    if verbose is True:
        print('Best score: ', round(best_score, 5),
              ' @ threshold =', round(best,4))
    return round(best,6)
def linear_bn_relu_drop(in_channels, out_channels, dropout=0.5, bias=False):
    """Build a Linear -> BatchNorm1d -> ReLU layer list.

    A trailing ``nn.Dropout(dropout)`` is appended only when ``dropout``
    is greater than zero. Returns a plain list of modules.
    """
    block = [
        nn.Linear(in_channels, out_channels, bias=bias),
        nn.BatchNorm1d(out_channels),
        nn.ReLU(inplace=True),
    ]
    if not dropout > 0:
        return block
    return block + [nn.Dropout(dropout)]
self.classifier = classifier 13 | 14 | def forward(self, x): 15 | x = self.resnet(x) 16 | x = x.view(x.size(0), -1) 17 | x = self.classifier(x) 18 | return x 19 | 20 | 21 | class ConcatResnet(nn.Module): 22 | def __init__(self, resnet, classifier): 23 | super().__init__() 24 | self.__class__.__name__ = 'ConcatResnet' 25 | self.resnet = resnet 26 | self.ap = nn.AdaptiveAvgPool2d((1,1)) 27 | self.mp = nn.AdaptiveMaxPool2d((1,1)) 28 | self.classifier = classifier 29 | 30 | def forward(self, x): 31 | x = self.resnet(x) 32 | x = torch.cat([self.mp(x), self.ap(x)], 1) 33 | x = x.view(x.size(0), -1) 34 | x = self.classifier(x) 35 | return x 36 | 37 | 38 | def get_resnet18(pretrained, n_freeze): 39 | resnet = torchvision.models.resnet18(pretrained) 40 | if n_freeze > 0: 41 | models.utils.freeze_layers(resnet, n_freeze) 42 | return resnet 43 | 44 | 45 | def get_resnet34(pretrained, n_freeze): 46 | resnet = torchvision.models.resnet34(pretrained) 47 | if n_freeze > 0: 48 | models.utils.freeze_layers(resnet, n_freeze) 49 | return resnet 50 | 51 | 52 | def get_resnet50(pretrained, n_freeze): 53 | resnet = torchvision.models.resnet50(pretrained) 54 | if n_freeze > 0: 55 | models.utils.freeze_layers(resnet, n_freeze) 56 | return resnet 57 | -------------------------------------------------------------------------------- /models/simplenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import models.layers as layers 3 | 4 | 5 | class SimpleNet(nn.Module): 6 | def __init__(self, in_feat, n_classes): 7 | super().__init__() 8 | self.features = nn.Sequential( 9 | *layers.conv_bn_relu(in_feat, 8, kernel_size=1, stride=1, padding=0, bias=False), 10 | *layers.conv_bn_relu(8, 32, kernel_size=3, stride=1, padding=1, bias=False), 11 | nn.MaxPool2d(kernel_size=2, stride=2), #size/2 12 | *layers.conv_bn_relu(32, 32, kernel_size=3, stride=1, padding=1, bias=False), 13 | nn.MaxPool2d(kernel_size=2, stride=2), #size/2 14 | 
def freeze_layers(model, n_layers):
    """Disable gradients for the first ``n_layers`` direct children.

    Each frozen child is printed with its index; children at index
    ``n_layers`` and beyond are left untouched.
    """
    for idx, child in enumerate(model.children()):
        if idx >= n_layers:
            break
        print(idx, "freezing", child)
        for param in child.parameters():
            param.requires_grad = False
def init_weights(layer, conv_init, fc_init):
    """Apply the matching initializer to ``layer.weight``.

    ``conv_init`` is used for ``Conv2d`` layers and ``fc_init`` for
    ``Linear`` layers; any other layer type is ignored. The chosen
    initializer is announced with a print before it is called.
    """
    if isinstance(layer, torch.nn.Conv2d):
        initializer = conv_init
    elif isinstance(layer, torch.nn.Linear):
        initializer = fc_init
    else:
        return
    print("init", layer, "with", initializer)
    initializer(layer.weight)

Hello,

15 |

Your experiment has ended.

16 |

Name: %s

17 |

Status: %s

18 |

Status Msg: %s

19 |

View Dashboard

20 |

Experiment Results:

21 |

%s

22 |

Experiment Config:

23 |

%s

24 |

Thanks,
25 | Team

26 | """ 27 | 28 | EXPERIMENT_STATUS_EMAIL_BODY = ( 29 | HEADER + EXPERIMENT_STATUS_EMAIL_TEMPLATE + FOOTER 30 | ) 31 | 32 | EXPERIMENT_STATUS_EMAIL ={ 33 | 'subject' : 'New Experiment Results', 34 | 'body' : EXPERIMENT_STATUS_EMAIL_BODY 35 | } 36 | -------------------------------------------------------------------------------- /notifications/emailer.py: -------------------------------------------------------------------------------- 1 | import config 2 | from .email_constants import * 3 | import clients.ses_client as ses 4 | import utils.general 5 | 6 | 7 | def send_experiment_status_email(exp, to_email): 8 | body = get_experiment_status_template(exp) 9 | ses.send_email(EXPERIMENT_STATUS_EMAIL['subject'], body, to_email) 10 | 11 | 12 | def get_experiment_status_template(exp): 13 | status = exp.config.progress['status'] 14 | msg = exp.config.progress['status_msg'] 15 | progress = utils.general.dict_to_html_ul(exp.config.progress) 16 | config = exp.config.to_html() 17 | return EXPERIMENT_STATUS_EMAIL['body'] % (exp.name, status, msg, 18 | WEBSITE_URL, progress, config) 19 | -------------------------------------------------------------------------------- /predictions/__init__.py: -------------------------------------------------------------------------------- 1 | from .pred_utils import * 2 | from .pred_builder import * 3 | from .pred_constants import * 4 | from .prediction import * -------------------------------------------------------------------------------- /predictions/pred_builder.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import utils.files 4 | import constants as c 5 | from .pred_constants import * 6 | from .prediction import Prediction 7 | 8 | 9 | def build_scores(loss, score): 10 | return { 11 | c.LOSS: loss, 12 | c.SCORE: score 13 | } 14 | 15 | 16 | def build_metadata(labels, scores, thresholds, pred_type, dset): 17 | return { 18 | 'label_names': labels, 19 | 'scores': scores, 20 | 
def get_tta_doc(transforms):
    """Describe a transform pipeline as a string.

    Returns ``str()`` of a list of ``(class name, instance __dict__)``
    pairs, one per entry of ``transforms.transforms``.
    """
    described = [
        (str(step.__class__.__name__), step.__dict__)
        for step in transforms.transforms
    ]
    return str(described)
def get_probabilities(model, loader):
    """Run ``model`` over every batch in ``loader`` and stack the outputs.

    The model is switched to eval mode, ``predict_batch`` is called on the
    first element of each batch, and the per-batch arrays are stacked
    vertically into one numpy array.
    """
    model.eval()
    # np.vstack expects a real sequence; handing it a generator has been
    # deprecated since NumPy 1.16, so build the list explicitly.
    return np.vstack([predict_batch(model, data[0]) for data in loader])
def get_targets(loader):
    """Collect the targets (``data[1]``) of every batch into one array.

    1-D target tensors are turned into column vectors first. Batches are
    gathered in a list and stacked once, which also fixes the crash the
    old ``(0,)``-shaped seed array caused for 1-D targets.

    Returns:
        A float64 numpy array stacked along axis 0, or ``None`` when the
        loader yields no batches.
    """
    batches = []
    for data in loader:
        target = data[1]
        if len(target.size()) == 1:
            target = target.view(-1,1)
        batches.append(target.numpy())
    if not batches:
        return None
    # The original seeded the stack with a float64 array, so the result
    # was always promoted to float64; keep that dtype.
    return np.vstack(batches).astype(np.float64)
def ensemble_with_method(arr, method):
    """Combine stacked predictions along axis 0 with the given method.

    Args:
        arr: array-like whose first axis indexes the ensemble members.
        method: one of ``c.MEAN``, ``c.GMEAN`` or ``c.VOTE``.

    Returns:
        The arithmetic mean, geometric mean, or per-position mode.

    Raises:
        Exception: if ``method`` is none of the supported constants.
    """
    # ``import scipy`` alone does not guarantee the ``stats`` subpackage
    # is loaded, so import it explicitly where it is used.
    from scipy import stats

    if method == c.MEAN:
        return np.mean(arr, axis=0)
    elif method == c.GMEAN:
        return stats.mstats.gmean(arr, axis=0)
    elif method == c.VOTE:
        votes = np.asarray(stats.mode(arr, axis=0)[0])
        # Older SciPy keeps the reduced axis (shape (1, ...)); SciPy >= 1.11
        # drops it, where the old ``[0][0]`` returned only the first slice.
        if votes.ndim == np.ndim(arr):
            votes = votes[0]
        return votes
    raise Exception("Operation not found")
class Prediction:
    """A prediction file on disk together with its metadata dict."""

    def __init__(self, fpath, metadata):
        self.fpath = fpath
        self.meta = metadata

    @property
    def name(self):
        """File name of ``fpath`` without the prediction extension."""
        fname = os.path.basename(self.fpath)
        # str.rstrip() removes a *set of characters*, not a suffix, and
        # could eat the end of the real name; strip the extension exactly.
        if c.PRED_FILE_EXT and fname.endswith(c.PRED_FILE_EXT):
            fname = fname[:-len(c.PRED_FILE_EXT)]
        return fname

    @property
    def id(self):
        """Part of ``name`` after the last ``'-id'`` marker."""
        return self.name.split('-id')[-1]

    @property
    def display_name(self):
        """Part of ``name`` before the first ``'-id'`` marker."""
        return self.name.split('-id')[0]

    def to_dict(self):
        """Return a deep copy of the instance attributes."""
        return copy.deepcopy(self.__dict__)

    def to_doc(self):
        """Return ``to_dict()`` extended with ``key`` and ``display_name``."""
        dict_ = self.to_dict()
        dict_['key'] = self.id
        # display_name is a property; calling it raised TypeError.
        dict_['display_name'] = self.display_name
        return dict_

    def save(self, s3=cfg.S3_ENABLED, es=cfg.ES_ENABLED):
        """Upload the prediction to S3 and/or Elasticsearch when enabled."""
        if s3:
            # name is a property; calling it raised TypeError.
            s3_client.upload_prediction(self.fpath, self.name)
        if es:
            es_client.upload_prediction(self)
def get_sub_path_from_pred_path(pred_fpath):
    """Map a prediction file path to its submission file path.

    The prediction extension (``c.PRED_FILE_EXT``) is replaced by
    ``c.SUBMISSION_FILE_EXT`` and the file is placed in
    ``cfg.PATHS['submissions']``.
    """
    base = os.path.basename(pred_fpath)
    # str.rstrip() strips any trailing characters found in its argument,
    # not the suffix itself, so it could truncate the base name.
    if c.PRED_FILE_EXT and base.endswith(c.PRED_FILE_EXT):
        base = base[:-len(c.PRED_FILE_EXT)]
    sub_fname = base + c.SUBMISSION_FILE_EXT
    sub_fpath = os.path.join(cfg.PATHS['submissions'], sub_fname)
    return sub_fpath
def run_length_decode(rel, H, W, fill_value=1):
    """Decode a run-length string into an ``(H, W)`` uint8 mask.

    Args:
        rel: whitespace-separated ``start length`` pairs, with 1-indexed
            starts into the flattened mask (the convention produced by
            ``run_length_encode``). An empty string gives an empty mask.
        H: mask height.
        W: mask width.
        fill_value: value written into the decoded runs.

    Returns:
        ``numpy.uint8`` array of shape ``(H, W)``.
    """
    mask = np.zeros((H*W),np.uint8)
    runs = np.array([int(s) for s in rel.split()], dtype=int).reshape(-1,2)
    for start, length in runs:
        # Starts are 1-indexed; using them directly as array offsets
        # shifted every run one pixel to the right.
        begin = start - 1
        mask[begin:begin+length] = fill_value
    return mask.reshape(H,W)
INITIAL_LR = 1e-3 17 | 18 | @pytest.fixture(scope="module") 19 | def example_fixture(): 20 | return 1e-3 21 | 22 | @pytest.fixture(scope="module") 23 | def lr_schedule(): 24 | return { 25 | 1: 1e-3, 26 | 5: 1e-4, 27 | 10: 1e-5 28 | } 29 | 30 | def sgd(): 31 | model = nn.Sequential(nn.Linear(3, 3)) 32 | return optim.SGD(model.parameters(), lr=INITIAL_LR) 33 | 34 | def adam(): 35 | model = nn.Sequential(nn.Linear(3, 3)) 36 | return optim.Adam(model.parameters(), lr=INITIAL_LR) 37 | 38 | 39 | ## Tests 40 | 41 | def test_get_learning_rate(): 42 | LR = learning_rates.LearningRate(INITIAL_LR, 'epoch') 43 | optim = sgd() 44 | assert LR.initial_lr == INITIAL_LR 45 | assert LR.get_learning_rate(optim) == INITIAL_LR 46 | 47 | def test_set_learning_rate(): 48 | LR = learning_rates.LearningRate(INITIAL_LR, 'epoch') 49 | optim = sgd() 50 | new_lr = INITIAL_LR + 1e-1 51 | LR.set_learning_rate(optim, new_lr) 52 | assert LR.get_learning_rate(optim) == new_lr 53 | 54 | def test_LearningRate_adjust(): 55 | LR = learning_rates.LearningRate(INITIAL_LR, 'epoch') 56 | optim = sgd() 57 | iteration = 5 58 | new_lr_expected = INITIAL_LR + 1e-1 59 | new_lr_output = LR.adjust(optim, new_lr_expected, iteration) 60 | assert new_lr_output == new_lr_expected 61 | assert LR.get_learning_rate(optim) == new_lr_expected 62 | assert LR.lr_history[0] == [iteration, new_lr_expected] 63 | 64 | def test_FixedLR_adjust(): 65 | LR = learning_rates.FixedLR(INITIAL_LR, 'epoch') 66 | optim = sgd() 67 | iteration = 5 68 | new_lr_output = LR.adjust(optim, iteration) 69 | assert new_lr_output == INITIAL_LR 70 | assert LR.get_learning_rate(optim) == INITIAL_LR 71 | assert LR.lr_history[0] == [iteration, INITIAL_LR] 72 | 73 | def test_LinearLR_adjust(): 74 | fixed_delta = 1e-1 75 | LR = learning_rates.LinearLR(INITIAL_LR, 'epoch', fixed_delta) 76 | optim = sgd() 77 | iteration = 5 78 | new_lr_expected = INITIAL_LR + fixed_delta 79 | new_lr_output = LR.adjust(optim, iteration) 80 | assert new_lr_output == 
new_lr_expected 81 | assert LR.get_learning_rate(optim) == new_lr_expected 82 | assert LR.lr_history[0] == [iteration, new_lr_expected] 83 | 84 | def test_ScheduledLR_adjust(lr_schedule): 85 | LR = learning_rates.ScheduledLR(INITIAL_LR, 'epoch', lr_schedule) 86 | optim = sgd() 87 | 88 | assert LR.adjust(optim, 1) == lr_schedule[1] 89 | assert LR.get_learning_rate(optim) == lr_schedule[1] 90 | assert LR.lr_history[0] == [1, lr_schedule[1]] 91 | 92 | assert LR.adjust(optim, 2) == lr_schedule[1] 93 | assert LR.get_learning_rate(optim) == lr_schedule[1] 94 | assert LR.lr_history[1] == [2, lr_schedule[1]] 95 | 96 | assert LR.adjust(optim, 5) == lr_schedule[5] 97 | assert LR.get_learning_rate(optim) == lr_schedule[5] 98 | assert LR.lr_history[2] == [5, lr_schedule[5]] 99 | 100 | def test_SnapshotLR_adjust(): 101 | max_lr = 0.1 102 | n_iters = 100 103 | n_cycles = 5 104 | LR = learning_rates.SnapshotLR(INITIAL_LR, 'mini_batch', 105 | max_lr, n_iters, n_cycles) 106 | optim = sgd() 107 | 108 | assert LR.adjust(optim, 0) == 0.1 109 | assert LR.get_learning_rate(optim) == 0.1 110 | assert LR.lr_history[0] == [0, max_lr] 111 | 112 | assert math.isclose(LR.adjust(optim, 19), 0.0006, abs_tol=0.000016) 113 | assert math.isclose(LR.get_learning_rate(optim), 0.0006, abs_tol=0.000016) 114 | 115 | assert LR.adjust(optim, 20) == 0.1 116 | assert LR.get_learning_rate(optim) == 0.1 117 | 118 | -------------------------------------------------------------------------------- /torchsample/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from .version import __version__ 5 | 6 | from .datasets import * 7 | from .samplers import * 8 | 9 | #from .callbacks import * 10 | #from .constraints import * 11 | #from .regularizers import * 12 | 13 | #from . import functions 14 | #from . import transforms 15 | from . 
import modules 16 | -------------------------------------------------------------------------------- /torchsample/constraints.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | from fnmatch import fnmatch 6 | 7 | import torch as th 8 | from .callbacks import Callback 9 | 10 | 11 | class ConstraintContainer(object): 12 | 13 | def __init__(self, constraints): 14 | self.constraints = constraints 15 | self.batch_constraints = [c for c in self.constraints if c.unit.upper() == 'BATCH'] 16 | self.epoch_constraints = [c for c in self.constraints if c.unit.upper() == 'EPOCH'] 17 | 18 | def register_constraints(self, model): 19 | """ 20 | Grab pointers to the weights which will be modified by constraints so 21 | that we dont have to search through the entire network using `apply` 22 | each time 23 | """ 24 | # get batch constraint pointers 25 | self._batch_c_ptrs = {} 26 | for c_idx, constraint in enumerate(self.batch_constraints): 27 | self._batch_c_ptrs[c_idx] = [] 28 | for name, module in model.named_modules(): 29 | if fnmatch(name, constraint.module_filter) and hasattr(module, 'weight'): 30 | self._batch_c_ptrs[c_idx].append(module) 31 | 32 | # get epoch constraint pointers 33 | self._epoch_c_ptrs = {} 34 | for c_idx, constraint in enumerate(self.epoch_constraints): 35 | self._epoch_c_ptrs[c_idx] = [] 36 | for name, module in model.named_modules(): 37 | if fnmatch(name, constraint.module_filter) and hasattr(module, 'weight'): 38 | self._epoch_c_ptrs[c_idx].append(module) 39 | 40 | def apply_batch_constraints(self, batch_idx): 41 | for c_idx, modules in self._batch_c_ptrs.items(): 42 | if (batch_idx+1) % self.constraints[c_idx].frequency == 0: 43 | for module in modules: 44 | self.constraints[c_idx](module) 45 | 46 | def apply_epoch_constraints(self, epoch_idx): 47 | for c_idx, modules in self._epoch_c_ptrs.items(): 48 | if (epoch_idx+1) % 
self.constraints[c_idx].frequency == 0: 49 | for module in modules: 50 | self.constraints[c_idx](module) 51 | 52 | 53 | class ConstraintCallback(Callback): 54 | 55 | def __init__(self, container): 56 | self.container = container 57 | 58 | def on_batch_end(self, batch_idx, logs): 59 | self.container.apply_batch_constraints(batch_idx) 60 | 61 | def on_epoch_end(self, epoch_idx, logs): 62 | self.container.apply_epoch_constraints(epoch_idx) 63 | 64 | 65 | class Constraint(object): 66 | 67 | def __call__(self): 68 | raise NotImplementedError('Subclass much implement this method') 69 | 70 | 71 | class UnitNorm(Constraint): 72 | """ 73 | UnitNorm constraint. 74 | 75 | Constraints the weights to have column-wise unit norm 76 | """ 77 | def __init__(self, 78 | frequency=1, 79 | unit='batch', 80 | module_filter='*'): 81 | 82 | self.frequency = frequency 83 | self.unit = unit 84 | self.module_filter = module_filter 85 | 86 | def __call__(self, module): 87 | w = module.weight.data 88 | module.weight.data = w.div(th.norm(w,2,0)) 89 | 90 | 91 | class MaxNorm(Constraint): 92 | """ 93 | MaxNorm weight constraint. 94 | 95 | Constrains the weights incident to each hidden unit 96 | to have a norm less than or equal to a desired value. 97 | 98 | Any hidden unit vector with a norm less than the max norm 99 | constaint will not be altered. 100 | """ 101 | 102 | def __init__(self, 103 | value, 104 | axis=0, 105 | frequency=1, 106 | unit='batch', 107 | module_filter='*'): 108 | self.value = float(value) 109 | self.axis = axis 110 | 111 | self.frequency = frequency 112 | self.unit = unit 113 | self.module_filter = module_filter 114 | 115 | def __call__(self, module): 116 | w = module.weight.data 117 | module.weight.data = th.renorm(w, 2, self.axis, self.value) 118 | 119 | 120 | class NonNeg(Constraint): 121 | """ 122 | Constrains the weights to be non-negative. 
123 | """ 124 | def __init__(self, 125 | frequency=1, 126 | unit='batch', 127 | module_filter='*'): 128 | self.frequency = frequency 129 | self.unit = unit 130 | self.module_filter = module_filter 131 | 132 | def __call__(self, module): 133 | w = module.weight.data 134 | module.weight.data = w.gt(0).float().mul(w) 135 | 136 | 137 | 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /torchsample/functions/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .affine import * -------------------------------------------------------------------------------- /torchsample/initializers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes to initialize module weights 3 | """ 4 | 5 | from fnmatch import fnmatch 6 | 7 | import torch.nn.init 8 | 9 | 10 | def _validate_initializer_string(init): 11 | dir_f = dir(torch.nn.init) 12 | loss_fns = [d.lower() for d in dir_f] 13 | if isinstance(init, str): 14 | try: 15 | str_idx = loss_fns.index(init.lower()) 16 | except: 17 | raise ValueError('Invalid loss string input - must match pytorch function.') 18 | return getattr(torch.nn.init, dir(torch.nn.init)[str_idx]) 19 | elif callable(init): 20 | return init 21 | else: 22 | raise ValueError('Invalid loss input') 23 | 24 | 25 | class InitializerContainer(object): 26 | 27 | def __init__(self, initializers): 28 | self._initializers = initializers 29 | 30 | def apply(self, model): 31 | for initializer in self._initializers: 32 | model.apply(initializer) 33 | 34 | 35 | class Initializer(object): 36 | 37 | def __call__(self, module): 38 | raise NotImplementedError('Initializer must implement this method') 39 | 40 | 41 | class GeneralInitializer(Initializer): 42 | 43 | def __init__(self, initializer, bias=False, bias_only=False, **kwargs): 44 | self._initializer = _validate_initializer_string(initializer) 45 | self.kwargs = kwargs 46 | 
47 | def __call__(self, module): 48 | classname = module.__class__.__name__ 49 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 50 | if self.bias_only: 51 | self._initializer(module.bias.data, **self.kwargs) 52 | else: 53 | self._initializer(module.weight.data, **self.kwargs) 54 | if self.bias: 55 | self._initializer(module.bias.data, **self.kwargs) 56 | 57 | 58 | class Normal(Initializer): 59 | 60 | def __init__(self, mean=0.0, std=0.02, bias=False, 61 | bias_only=False, module_filter='*'): 62 | self.mean = mean 63 | self.std = std 64 | 65 | self.bias = bias 66 | self.bias_only = bias_only 67 | self.module_filter = module_filter 68 | 69 | super(Normal, self).__init__() 70 | 71 | def __call__(self, module): 72 | classname = module.__class__.__name__ 73 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 74 | if self.bias_only: 75 | torch.nn.init.normal(module.bias.data, mean=self.mean, std=self.std) 76 | else: 77 | torch.nn.init.normal(module.weight.data, mean=self.mean, std=self.std) 78 | if self.bias: 79 | torch.nn.init.normal(module.bias.data, mean=self.mean, std=self.std) 80 | 81 | 82 | class Uniform(Initializer): 83 | 84 | def __init__(self, a=0, b=1, bias=False, bias_only=False, module_filter='*'): 85 | self.a = a 86 | self.b = b 87 | 88 | self.bias = bias 89 | self.bias_only = bias_only 90 | self.module_filter = module_filter 91 | 92 | super(Uniform, self).__init__() 93 | 94 | def __call__(self, module): 95 | classname = module.__class__.__name__ 96 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 97 | if self.bias_only: 98 | torch.nn.init.uniform(module.bias.data, a=self.a, b=self.b) 99 | else: 100 | torch.nn.init.uniform(module.weight.data, a=self.a, b=self.b) 101 | if self.bias: 102 | torch.nn.init.uniform(module.bias.data, a=self.a, b=self.b) 103 | 104 | 105 | class ConstantInitializer(Initializer): 106 | 107 | def __init__(self, value, bias=False, bias_only=False, 
module_filter='*'): 108 | self.value = value 109 | 110 | self.bias = bias 111 | self.bias_only = bias_only 112 | self.module_filter = module_filter 113 | 114 | super(ConstantInitializer, self).__init__() 115 | 116 | def __call__(self, module, bias=False, bias_only=False, module_filter='*'): 117 | classname = module.__class__.__name__ 118 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 119 | if self.bias_only: 120 | torch.nn.init.constant(module.bias.data, val=self.value) 121 | else: 122 | torch.nn.init.constant(module.weight.data, val=self.value) 123 | if self.bias: 124 | torch.nn.init.constant(module.bias.data, val=self.value) 125 | 126 | 127 | class XavierUniform(Initializer): 128 | 129 | def __init__(self, gain=1, bias=False, bias_only=False, module_filter='*'): 130 | self.gain = gain 131 | 132 | self.bias = bias 133 | self.bias_only = bias_only 134 | self.module_filter = module_filter 135 | 136 | super(XavierUniform, self).__init__() 137 | 138 | def __call__(self, module): 139 | classname = module.__class__.__name__ 140 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 141 | if self.bias_only: 142 | torch.nn.init.xavier_uniform(module.bias.data, gain=self.gain) 143 | else: 144 | torch.nn.init.xavier_uniform(module.weight.data, gain=self.gain) 145 | if self.bias: 146 | torch.nn.init.xavier_uniform(module.bias.data, gain=self.gain) 147 | 148 | 149 | class XavierNormal(Initializer): 150 | 151 | def __init__(self, gain=1, bias=False, bias_only=False, module_filter='*'): 152 | self.gain = gain 153 | 154 | self.bias = bias 155 | self.bias_only = bias_only 156 | self.module_filter = module_filter 157 | 158 | super(XavierNormal, self).__init__() 159 | 160 | def __call__(self, module): 161 | classname = module.__class__.__name__ 162 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 163 | if self.bias_only: 164 | torch.nn.init.xavier_normal(module.bias.data, gain=self.gain) 165 | else: 166 | 
torch.nn.init.xavier_normal(module.weight.data, gain=self.gain) 167 | if self.bias: 168 | torch.nn.init.xavier_normal(module.bias.data, gain=self.gain) 169 | 170 | 171 | class KaimingUniform(Initializer): 172 | 173 | def __init__(self, a=0, mode='fan_in', bias=False, bias_only=False, module_filter='*'): 174 | self.a = a 175 | self.mode = mode 176 | 177 | self.bias = bias 178 | self.bias_only = bias_only 179 | self.module_filter = module_filter 180 | 181 | super(KaimingUniform, self).__init__() 182 | 183 | def __call__(self, module): 184 | classname = module.__class__.__name__ 185 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 186 | if self.bias_only: 187 | torch.nn.init.kaiming_uniform(module.bias.data, a=self.a, mode=self.mode) 188 | else: 189 | torch.nn.init.kaiming_uniform(module.weight.data, a=self.a, mode=self.mode) 190 | if self.bias: 191 | torch.nn.init.kaiming_uniform(module.bias.data, a=self.a, mode=self.mode) 192 | 193 | 194 | class KaimingNormal(Initializer): 195 | 196 | def __init__(self, a=0, mode='fan_in', bias=False, bias_only=False, module_filter='*'): 197 | self.a = a 198 | self.mode = mode 199 | 200 | self.bias = bias 201 | self.bias_only = bias_only 202 | self.module_filter = module_filter 203 | 204 | super(KaimingNormal, self).__init__() 205 | 206 | def __call__(self, module): 207 | classname = module.__class__.__name__ 208 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 209 | if self.bias_only: 210 | torch.nn.init.kaiming_normal(module.bias.data, a=self.a, mode=self.mode) 211 | else: 212 | torch.nn.init.kaiming_normal(module.weight.data, a=self.a, mode=self.mode) 213 | if self.bias: 214 | torch.nn.init.kaiming_normal(module.bias.data, a=self.a, mode=self.mode) 215 | 216 | 217 | class Orthogonal(Initializer): 218 | 219 | def __init__(self, gain=1, bias=False, bias_only=False, module_filter='*'): 220 | self.gain = gain 221 | 222 | self.bias = bias 223 | self.bias_only = bias_only 224 | 
self.module_filter = module_filter 225 | 226 | super(Orthogonal, self).__init__() 227 | 228 | def __call__(self, module): 229 | classname = module.__class__.__name__ 230 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 231 | if self.bias_only: 232 | torch.nn.init.orthogonal(module.bias.data, gain=self.gain) 233 | else: 234 | torch.nn.init.orthogonal(module.weight.data, gain=self.gain) 235 | if self.bias: 236 | torch.nn.init.orthogonal(module.bias.data, gain=self.gain) 237 | 238 | 239 | class Sparse(Initializer): 240 | 241 | def __init__(self, sparsity, std=0.01, bias=False, bias_only=False, module_filter='*'): 242 | self.sparsity = sparsity 243 | self.std = std 244 | 245 | self.bias = bias 246 | self.bias_only = bias_only 247 | self.module_filter = module_filter 248 | 249 | super(Sparse, self).__init__() 250 | 251 | def __call__(self, module): 252 | classname = module.__class__.__name__ 253 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 254 | if self.bias_only: 255 | torch.nn.init.sparse(module.bias.data, sparsity=self.sparsity, std=self.std) 256 | else: 257 | torch.nn.init.sparse(module.weight.data, sparsity=self.sparsity, std=self.std) 258 | if self.bias: 259 | torch.nn.init.sparse(module.bias.data, sparsity=self.sparsity, std=self.std) 260 | 261 | 262 | 263 | -------------------------------------------------------------------------------- /torchsample/metrics.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | 5 | import torch as th 6 | 7 | from .utils import th_matrixcorr 8 | 9 | from .callbacks import Callback 10 | 11 | class MetricContainer(object): 12 | 13 | 14 | def __init__(self, metrics, prefix=''): 15 | self.metrics = metrics 16 | self.helper = None 17 | self.prefix = prefix 18 | 19 | def set_helper(self, helper): 20 | self.helper = helper 21 | 22 | def reset(self): 23 | for 
metric in self.metrics: 24 | metric.reset() 25 | 26 | def __call__(self, output_batch, target_batch): 27 | logs = {} 28 | for metric in self.metrics: 29 | logs[self.prefix+metric._name] = self.helper.calculate_loss(output_batch, 30 | target_batch, 31 | metric) 32 | return logs 33 | 34 | class Metric(object): 35 | 36 | def __call__(self, y_pred, y_true): 37 | raise NotImplementedError('Custom Metrics must implement this function') 38 | 39 | def reset(self): 40 | raise NotImplementedError('Custom Metrics must implement this function') 41 | 42 | 43 | class MetricCallback(Callback): 44 | 45 | def __init__(self, container): 46 | self.container = container 47 | def on_epoch_begin(self, epoch_idx, logs): 48 | self.container.reset() 49 | 50 | class CategoricalAccuracy(Metric): 51 | 52 | def __init__(self, top_k=1): 53 | self.top_k = top_k 54 | self.correct_count = 0 55 | self.total_count = 0 56 | self.accuracy = 0 57 | 58 | self._name = 'acc_metric' 59 | 60 | def reset(self): 61 | self.correct_count = 0 62 | self.total_count = 0 63 | self.accuracy = 0 64 | 65 | def __call__(self, y_pred, y_true): 66 | top_k = y_pred.topk(self.top_k,1)[1] 67 | true_k = y_true.view(len(y_true),1).expand_as(top_k) 68 | self.correct_count += top_k.eq(true_k).float().sum().data[0] 69 | self.total_count += len(y_pred) 70 | accuracy = 100. * float(self.correct_count) / float(self.total_count) 71 | return accuracy 72 | 73 | 74 | class BinaryAccuracy(Metric): 75 | 76 | def __init__(self): 77 | self.correct_count = 0 78 | self.total_count = 0 79 | self.accuracy = 0 80 | 81 | self._name = 'acc_metric' 82 | 83 | def reset(self): 84 | self.correct_count = 0 85 | self.total_count = 0 86 | self.accuracy = 0 87 | 88 | def __call__(self, y_pred, y_true): 89 | y_pred_round = y_pred.round().long() 90 | self.correct_count += y_pred_round.eq(y_true).float().sum().data[0] 91 | self.total_count += len(y_pred) 92 | accuracy = 100. 
* float(self.correct_count) / float(self.total_count) 93 | return accuracy 94 | 95 | 96 | class ProjectionCorrelation(Metric): 97 | 98 | def __init__(self): 99 | self.corr_sum = 0. 100 | self.total_count = 0. 101 | 102 | self._name = 'corr_metric' 103 | 104 | def reset(self): 105 | self.corr_sum = 0. 106 | self.total_count = 0. 107 | self.average = 0. 108 | 109 | def __call__(self, y_pred, y_true=None): 110 | """ 111 | y_pred should be two projections 112 | """ 113 | covar_mat = th.abs(th_matrixcorr(y_pred[0].data, y_pred[1].data)) 114 | self.corr_sum += th.trace(covar_mat) 115 | self.total_count += covar_mat.size(0) 116 | return self.corr_sum / self.total_count 117 | 118 | 119 | class ProjectionAntiCorrelation(Metric): 120 | 121 | def __init__(self): 122 | self.anticorr_sum = 0. 123 | self.total_count = 0. 124 | self.average = 0. 125 | 126 | self._name = 'anticorr_metric' 127 | 128 | def reset(self): 129 | self.anticorr_sum = 0. 130 | self.total_count = 0. 131 | self.average = 0. 132 | 133 | def __call__(self, y_pred, y_true=None): 134 | """ 135 | y_pred should be two projections 136 | """ 137 | covar_mat = th.abs(th_matrixcorr(y_pred[0].data, y_pred[1].data)) 138 | upper_sum = th.sum(th.triu(covar_mat,1)) 139 | lower_sum = th.sum(th.tril(covar_mat,-1)) 140 | self.anticorr_sum += upper_sum 141 | self.anticorr_sum += lower_sum 142 | self.total_count += covar_mat.size(0)*(covar_mat.size(1) - 1) 143 | return self.anticorr_sum / self.total_count 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /torchsample/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .module_trainer import ModuleTrainer 4 | -------------------------------------------------------------------------------- /torchsample/modules/_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import 
datetime 3 | import warnings 4 | 5 | try: 6 | from inspect import signature 7 | except: 8 | warnings.warn('inspect.signature not available... ' 9 | 'you should upgrade to Python 3.x') 10 | 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | 14 | from ..metrics import Metric, CategoricalAccuracy, BinaryAccuracy 15 | from ..initializers import GeneralInitializer 16 | 17 | def _add_regularizer_to_loss_fn(loss_fn, 18 | regularizer_container): 19 | def new_loss_fn(output_batch, target_batch): 20 | return loss_fn(output_batch, target_batch) + regularizer_container.get_value() 21 | return new_loss_fn 22 | 23 | def _is_iterable(x): 24 | return isinstance(x, (tuple, list)) 25 | def _is_tuple_or_list(x): 26 | return isinstance(x, (tuple, list)) 27 | 28 | def _parse_num_inputs_and_targets_from_loader(loader): 29 | """ NOT IMPLEMENTED """ 30 | #batch = next(iter(loader)) 31 | num_inputs = loader.dataset.num_inputs 32 | num_targets = loader.dataset.num_targets 33 | return num_inputs, num_targets 34 | 35 | def _parse_num_inputs_and_targets(inputs, targets=None): 36 | if isinstance(inputs, (list, tuple)): 37 | num_inputs = len(inputs) 38 | else: 39 | num_inputs = 1 40 | if targets is not None: 41 | if isinstance(targets, (list, tuple)): 42 | num_targets = len(targets) 43 | else: 44 | num_targets = 1 45 | else: 46 | num_targets = 0 47 | return num_inputs, num_targets 48 | 49 | def _standardize_user_data(inputs, targets=None): 50 | if not isinstance(inputs, (list,tuple)): 51 | inputs = [inputs] 52 | if targets is not None: 53 | if not isinstance(targets, (list,tuple)): 54 | targets = [targets] 55 | return inputs, targets 56 | else: 57 | return inputs 58 | 59 | def _validate_metric_input(metric): 60 | if isinstance(metric, str): 61 | if metric.upper() == 'CATEGORICAL_ACCURACY' or metric.upper() == 'ACCURACY': 62 | return CategoricalAccuracy() 63 | elif metric.upper() == 'BINARY_ACCURACY': 64 | return BinaryAccuracy() 65 | else: 66 | raise ValueError('Invalid 
metric string input - must match pytorch function.') 67 | elif isinstance(metric, Metric): 68 | return metric 69 | else: 70 | raise ValueError('Invalid metric input') 71 | 72 | def _validate_loss_input(loss): 73 | dir_f = dir(F) 74 | loss_fns = [d.lower() for d in dir_f] 75 | if isinstance(loss, str): 76 | if loss.lower() == 'unconstrained': 77 | return lambda x: x 78 | elif loss.lower() == 'unconstrained_sum': 79 | return lambda x: x.sum() 80 | elif loss.lower() == 'unconstrained_mean': 81 | return lambda x: x.mean() 82 | else: 83 | try: 84 | str_idx = loss_fns.index(loss.lower()) 85 | except: 86 | raise ValueError('Invalid loss string input - must match pytorch function.') 87 | return getattr(F, dir(F)[str_idx]) 88 | elif callable(loss): 89 | return loss 90 | else: 91 | raise ValueError('Invalid loss input') 92 | 93 | def _validate_optimizer_input(optimizer): 94 | dir_optim = dir(optim) 95 | opts = [o.lower() for o in dir_optim] 96 | if isinstance(optimizer, str): 97 | try: 98 | str_idx = opts.index(optimizer.lower()) 99 | except: 100 | raise ValueError('Invalid optimizer string input - must match pytorch function.') 101 | return getattr(optim, dir_optim[str_idx]) 102 | elif hasattr(optimizer, 'step') and hasattr(optimizer, 'zero_grad'): 103 | return optimizer 104 | else: 105 | raise ValueError('Invalid optimizer input') 106 | 107 | def _validate_initializer_input(initializer): 108 | if isinstance(initializer, str): 109 | try: 110 | initializer = GeneralInitializer(initializer) 111 | except: 112 | raise ValueError('Invalid initializer string input - must match pytorch function.') 113 | return initializer 114 | elif callable(initializer): 115 | return initializer 116 | else: 117 | raise ValueError('Invalid optimizer input') 118 | 119 | def _get_current_time(): 120 | return datetime.datetime.now().strftime("%B %d, %Y - %I:%M%p") 121 | 122 | def _nb_function_args(fn): 123 | return len(signature(fn).parameters) 
-------------------------------------------------------------------------------- /torchsample/regularizers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | from fnmatch import fnmatch 4 | 5 | from .callbacks import Callback 6 | 7 | class RegularizerContainer(object): 8 | 9 | def __init__(self, regularizers): 10 | self.regularizers = regularizers 11 | self._forward_hooks = [] 12 | 13 | def register_forward_hooks(self, model): 14 | for regularizer in self.regularizers: 15 | for module_name, module in model.named_modules(): 16 | if fnmatch(module_name, regularizer.module_filter) and hasattr(module, 'weight'): 17 | hook = module.register_forward_hook(regularizer) 18 | self._forward_hooks.append(hook) 19 | 20 | if len(self._forward_hooks) == 0: 21 | raise Exception('Tried to register regularizers but no modules ' 22 | 'were found that matched any module_filter argument.') 23 | 24 | def unregister_forward_hooks(self): 25 | for hook in self._forward_hooks: 26 | hook.remove() 27 | 28 | def reset(self): 29 | for r in self.regularizers: 30 | r.reset() 31 | 32 | def get_value(self): 33 | value = sum([r.value for r in self.regularizers]) 34 | self.current_value = value.data[0] 35 | return value 36 | 37 | def __len__(self): 38 | return len(self.regularizers) 39 | 40 | 41 | class RegularizerCallback(Callback): 42 | 43 | def __init__(self, container): 44 | self.container = container 45 | 46 | def on_batch_end(self, batch, logs=None): 47 | self.container.reset() 48 | 49 | 50 | class Regularizer(object): 51 | 52 | def reset(self): 53 | raise NotImplementedError('subclass must implement this method') 54 | 55 | def __call__(self, module, input=None, output=None): 56 | raise NotImplementedError('subclass must implement this method') 57 | 58 | 59 | class L1Regularizer(Regularizer): 60 | 61 | def __init__(self, scale=1e-3, module_filter='*'): 62 | self.scale = float(scale) 63 | self.module_filter = module_filter 64 | 
self.value = 0. 65 | 66 | def reset(self): 67 | self.value = 0. 68 | 69 | def __call__(self, module, input=None, output=None): 70 | value = th.sum(th.abs(module.weight)) * self.scale 71 | self.value += value 72 | 73 | 74 | class L2Regularizer(Regularizer): 75 | 76 | def __init__(self, scale=1e-3, module_filter='*'): 77 | self.scale = float(scale) 78 | self.module_filter = module_filter 79 | self.value = 0. 80 | 81 | def reset(self): 82 | self.value = 0. 83 | 84 | def __call__(self, module, input=None, output=None): 85 | value = th.sum(th.pow(module.weight,2)) * self.scale 86 | self.value += value 87 | 88 | 89 | class L1L2Regularizer(Regularizer): 90 | 91 | def __init__(self, l1_scale=1e-3, l2_scale=1e-3, module_filter='*'): 92 | self.l1 = L1Regularizer(l1_scale) 93 | self.l2 = L2Regularizer(l2_scale) 94 | self.module_filter = module_filter 95 | self.value = 0. 96 | 97 | def reset(self): 98 | self.value = 0. 99 | 100 | def __call__(self, module, input=None, output=None): 101 | self.l1(module, input, output) 102 | self.l2(module, input, output) 103 | self.value += (self.l1.value + self.l2.value) 104 | 105 | 106 | # ------------------------------------------------------------------ 107 | # ------------------------------------------------------------------ 108 | # ------------------------------------------------------------------ 109 | 110 | class UnitNormRegularizer(Regularizer): 111 | """ 112 | UnitNorm constraint on Weights 113 | 114 | Constraints the weights to have column-wise unit norm 115 | """ 116 | def __init__(self, 117 | scale=1e-3, 118 | module_filter='*'): 119 | 120 | self.scale = scale 121 | self.module_filter = module_filter 122 | self.value = 0. 123 | 124 | def reset(self): 125 | self.value = 0. 126 | 127 | def __call__(self, module, input=None, output=None): 128 | w = module.weight 129 | norm_diff = th.norm(w, 2, 1).sub(1.) 
130 | value = self.scale * th.sum(norm_diff.gt(0).float().mul(norm_diff)) 131 | self.value += value 132 | 133 | 134 | class MaxNormRegularizer(Regularizer): 135 | """ 136 | MaxNorm regularizer on Weights 137 | 138 | Constraints the weights to have column-wise unit norm 139 | """ 140 | def __init__(self, 141 | scale=1e-3, 142 | module_filter='*'): 143 | 144 | self.scale = scale 145 | self.module_filter = module_filter 146 | self.value = 0. 147 | 148 | def reset(self): 149 | self.value = 0. 150 | 151 | def __call__(self, module, input=None, output=None): 152 | w = module.weight 153 | norm_diff = th.norm(w,2,self.axis).sub(self.value) 154 | value = self.scale * th.sum(norm_diff.gt(0).float().mul(norm_diff)) 155 | self.value += value 156 | 157 | 158 | class NonNegRegularizer(Regularizer): 159 | """ 160 | Non-Negativity regularizer on Weights 161 | 162 | Constraints the weights to have column-wise unit norm 163 | """ 164 | def __init__(self, 165 | scale=1e-3, 166 | module_filter='*'): 167 | 168 | self.scale = scale 169 | self.module_filter = module_filter 170 | self.value = 0. 171 | 172 | def reset(self): 173 | self.value = 0. 174 | 175 | def __call__(self, module, input=None, output=None): 176 | w = module.weight 177 | value = -1 * self.scale * th.sum(w.gt(0).float().mul(w)) 178 | self.value += value 179 | 180 | -------------------------------------------------------------------------------- /torchsample/samplers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import math 4 | 5 | class Sampler(object): 6 | """Base class for all Samplers. 7 | 8 | Every Sampler subclass has to provide an __iter__ method, providing a way 9 | to iterate over indices of dataset elements, and a __len__ method that 10 | returns the length of the returned iterators. 
11 | """ 12 | 13 | def __init__(self, data_source): 14 | pass 15 | 16 | def __iter__(self): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | class StratifiedSampler(Sampler): 23 | """Stratified Sampling 24 | 25 | Provides equal representation of target classes in each batch 26 | """ 27 | def __init__(self, class_vector, batch_size): 28 | """ 29 | Arguments 30 | --------- 31 | class_vector : torch tensor 32 | a vector of class labels 33 | batch_size : integer 34 | batch_size 35 | """ 36 | self.n_splits = int(class_vector.size(0) / batch_size) 37 | self.class_vector = class_vector 38 | 39 | def gen_sample_array(self): 40 | try: 41 | from sklearn.model_selection import StratifiedShuffleSplit 42 | except: 43 | print('Need scikit-learn for this functionality') 44 | import numpy as np 45 | 46 | s = StratifiedShuffleSplit(n_splits=self.n_splits, test_size=0.5) 47 | X = th.randn(self.class_vector.size(0),2).numpy() 48 | y = self.class_vector.numpy() 49 | s.get_n_splits(X, y) 50 | 51 | train_index, test_index = next(s.split(X, y)) 52 | return np.hstack([train_index, test_index]) 53 | 54 | def __iter__(self): 55 | return iter(self.gen_sample_array()) 56 | 57 | def __len__(self): 58 | return len(self.class_vector) 59 | 60 | class MultiSampler(Sampler): 61 | """Samples elements more than once in a single pass through the data. 62 | 63 | This allows the number of samples per epoch to be larger than the number 64 | of samples itself, which can be useful when training on 2D slices taken 65 | from 3D images, for instance. 66 | """ 67 | def __init__(self, nb_samples, desired_samples, shuffle=False): 68 | """Initialize MultiSampler 69 | 70 | Arguments 71 | --------- 72 | data_source : the dataset to sample from 73 | 74 | desired_samples : number of samples per batch you want 75 | whatever the difference is between an even division will 76 | be randomly selected from the samples. 77 | e.g. 
class SequentialSampler(Sampler):
    """Yields the indices 0 .. nb_samples-1 in ascending order on every pass.

    Arguments:
        nb_samples (int): how many indices to produce
    """

    def __init__(self, nb_samples):
        self.num_samples = nb_samples

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        ordered_indices = range(self.num_samples)
        return iter(ordered_indices)
class Scramble(object):
    """
    Create blocks of an image and scramble them

    Each input is indexed as (channels, height, width). It is cut into
    `blocksize` x `blocksize` tiles, and the tiles are written back in a
    random order. Rows/columns left over when the height/width is not a
    multiple of `blocksize` are zero in the output.
    """
    def __init__(self, blocksize):
        self.blocksize = blocksize

    def __call__(self, *inputs):
        """
        Scramble every input independently (a fresh permutation per input).

        Returns a list of float tensors when more than one input is
        given, otherwise the single scrambled tensor.
        """
        bs = self.blocksize
        outputs = []
        for _input in inputs:
            size = _input.size()
            n_block_rows = int(size[1] / bs)  # tiles along the height
            n_block_cols = int(size[2] / bs)  # tiles along the width
            order = th.randperm(n_block_rows * n_block_cols).tolist()

            new = th.zeros(size)
            for dst, src in enumerate(order):
                # flat tile index -> (row, column); both directions must
                # use the number of tile COLUMNS, otherwise non-square
                # images address tiles that do not exist
                i, j = divmod(dst, n_block_cols)
                row, column = divmod(src, n_block_cols)
                new[:, i*bs:(i+1)*bs, j*bs:(j+1)*bs] = \
                    _input[:, row*bs:(row+1)*bs, column*bs:(column+1)*bs]
            outputs.append(new)
        # two or more inputs -> list (the old `idx > 1` test dropped the
        # second output when exactly two inputs were passed)
        return outputs if len(outputs) > 1 else outputs[0]
73 | newim[0,:,:] = np.absolute(np.real(np.fft.ifft2(filtered_imager))) 74 | newim[1,:,:] = np.absolute(np.real(np.fft.ifft2(filtered_imageg))) 75 | newim[2,:,:] = np.absolute(np.real(np.fft.ifft2(filtered_imageb))) 76 | 77 | return newim.astype('uint8') 78 | 79 | def _butterworth_filter(rows, cols, thresh, order): 80 | # X and Y matrices with ranges normalised to +/- 0.5 81 | array1 = np.ones(rows) 82 | array2 = np.ones(cols) 83 | array3 = np.arange(1,rows+1) 84 | array4 = np.arange(1,cols+1) 85 | 86 | x = np.outer(array1, array4) 87 | y = np.outer(array3, array2) 88 | 89 | x = x - float(cols/2) - 1 90 | y = y - float(rows/2) - 1 91 | 92 | x = x / cols 93 | y = y / rows 94 | 95 | radius = np.sqrt(np.square(x) + np.square(y)) 96 | 97 | matrix1 = radius/thresh 98 | matrix2 = np.power(matrix1, 2*order) 99 | f = np.reciprocal(1 + matrix2) 100 | 101 | return f 102 | 103 | 104 | class Blur(object): 105 | """ 106 | Blur an image with a Butterworth filter with a frequency 107 | cutoff matching local block size 108 | """ 109 | def __init__(self, threshold, order=5): 110 | """ 111 | scramble blocksize of 128 => filter threshold of 64 112 | scramble blocksize of 64 => filter threshold of 32 113 | scramble blocksize of 32 => filter threshold of 16 114 | scramble blocksize of 16 => filter threshold of 8 115 | scramble blocksize of 8 => filter threshold of 4 116 | """ 117 | self.threshold = threshold 118 | self.order = order 119 | 120 | def __call__(self, *inputs): 121 | """ 122 | inputs should have values between 0 and 255 123 | """ 124 | outputs = [] 125 | for idx, _input in enumerate(inputs): 126 | rows = _input.size(1) 127 | cols = _input.size(2) 128 | fc = self.threshold # threshold 129 | fs = 128.0 # max frequency 130 | n = self.order # filter order 131 | fc_rad = (fc/fs)*0.5 132 | H = _butterworth_filter(rows, cols, fc_rad, n) 133 | _input_blurred = _blur_image(_input.numpy().astype('uint8'), H) 134 | _input_blurred = th.from_numpy(_input_blurred).float() 135 | 
def set_learning_rate(optimizer, lr):
    """Write `lr` into the 'lr' entry of every parameter group of `optimizer`."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
class SnapshotLR(LearningRate):
    '''Cyclic cosine-annealed learning rate.

    https://arxiv.org/abs/1704.00109'''
    def __init__(self, initial_lr, iteration_type,
                 max_lr, total_iters, n_cycles):
        '''
        max_lr = learning rate at the start of every cycle
        total_iters = total number of mini-batch iterations during training
        n_cycles = number of cycles (snapshots) during training'''
        super().__init__(initial_lr, iteration_type)
        self.max_lr = max_lr
        self.total_iters = total_iters
        self.cycles = n_cycles

    def cosine_annealing(self, t):
        '''Return the learning rate for mini-batch iteration t.'''
        # every cycle lasts total_iters // cycles iterations
        cycle_len = self.total_iters//self.cycles
        position = t % cycle_len
        angle = (math.pi * position) / cycle_len
        return self.max_lr/2 * (math.cos(angle) + 1)

    def adjust(self, optimizer, iteration, params=None):
        '''Apply the annealed rate for `iteration` to `optimizer` and return it.'''
        annealed = self.cosine_annealing(iteration)
        self.set_learning_rate(optimizer, annealed)
        return annealed
class DevDecayLR(LearningRate):
    '''Decay the learning rate once the best iteration is too far behind.

    https://arxiv.org/abs/1705.08292'''
    def __init__(self, initial_lr, iteration_type,
                 decay_factor=0.9, decay_patience=1):
        super().__init__(initial_lr, iteration_type)
        self.decay_factor = decay_factor
        self.decay_patience = decay_patience

    def adjust(self, optimizer, iteration, params):
        '''Return the (possibly decayed) learning rate.

        params must provide 'best_iter'. When `iteration` is more than
        `decay_patience` past it, the current rate is multiplied by
        `decay_factor` and written back to the optimizer.'''
        current = super().get_learning_rate(optimizer)
        since_best = iteration - params['best_iter']

        if not (since_best > self.decay_patience):
            # still within patience: leave the optimizer untouched
            return current

        notice = 'Decaying learning rate by factor: {:.5f}'.format(
            self.decay_factor)
        print(notice.rstrip('0'))
        decayed = current * self.decay_factor
        super().set_learning_rate(optimizer, decayed)
        return decayed
def cosine_annealing(lr_max, T, M, t):
    '''
    Cosine-annealed learning rate for mini-batch iteration t.

    lr_max = rate at the start of each cycle
    T = total number of iterations, M = number of cycles
    # lr(t) = lr_max/2 * (cos(pi * (t mod T//M) / (T//M)) + 1)
    '''
    cycle_len = T//M
    phase = t % cycle_len
    return lr_max/2 * (math.cos((math.pi * phase)/cycle_len) + 1)
def load_pseudo_labels(fpath):
    """Load pseudo labels from a pickle file at `fpath`.

    The file must hold a pickled dict with the keys 'img_paths' and
    'preds'.

    Returns the tuple (img_paths, preds).
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files
    with open(fpath, 'rb') as f:  # closed even if unpickling fails
        obj = pickle.load(f)
    return obj['img_paths'], obj['preds']
class QuickTrainer():
    """Minimal train/validate loop that prints one summary per phase.

    Attributes set by __init__:
        metrics: metric objects forwarded to the train/test helpers
        logger: logger handed to the message helpers; defaults to this
            module's logger and may be replaced by the caller
    """
    def __init__(self, metrics):
        self.metrics = metrics
        # the message helpers are given this logger for every epoch, so it
        # must be a usable logger rather than None
        self.logger = logging.getLogger(__name__)

    def train(self, model, optim, lr_adjuster, criterion, trn_loader,
              val_loader, n_classes, threshold, n_epochs):
        """Run `n_epochs` epochs (numbered from 1) of training + validation.

        Per epoch: read the current learning rate, train on `trn_loader`,
        evaluate on `val_loader`, print both summaries, then let
        `lr_adjuster` update the rate when it works per epoch.
        """
        start_epoch = 1
        end_epoch = start_epoch + n_epochs

        for epoch in range(start_epoch, end_epoch):
            current_lr = lr_adjuster.get_learning_rate(optim)

            ### Train ###
            trn_start_time = time.time()
            trn_metrics = trn_utils.train_model(model, trn_loader, threshold,
                optim, criterion, lr_adjuster, epoch, n_epochs,
                self.metrics)
            trn_msg = trn_utils.log_trn_msg(self.logger, trn_start_time,
                trn_metrics, current_lr, epoch)
            print(trn_msg)

            ### Test ###
            val_start_time = time.time()
            val_metrics = trn_utils.test_model(model, val_loader, threshold,
                n_classes, criterion, self.metrics)
            val_msg = trn_utils.log_val_msg(self.logger, val_start_time,
                val_metrics, current_lr)
            print(val_msg)

            ### Adjust Lr ###
            # per-mini-batch adjusters are driven inside train_model
            if lr_adjuster.iteration_type == 'epoch':
                lr_adjuster.adjust(optim, epoch+1)
Forward Pass 74 | output = model(inputs) 75 | 76 | ## Clear Gradients 77 | model.zero_grad() 78 | 79 | # Loss 80 | loss = self.trn_criterion(output, targets) 81 | 82 | ## Backprop 83 | loss.backward() 84 | self.optimizer.step() 85 | 86 | ### Adjust Lr ### 87 | if self.lr_adjuster.iteration_type == 'mini_batch': 88 | self.lr_adjuster.adjust(self.optimizer, cur_iter) 89 | cur_iter += 1 90 | 91 | loss_data += loss.data[0] 92 | probs = np.vstack([probs, output.data.cpu().numpy()]) 93 | labels = np.vstack([labels, targets.data.cpu().numpy()]) 94 | 95 | 96 | loss_data /= len(loader) 97 | preds = pred_utils.get_predictions(probs, thresholds) 98 | 99 | for m in metrics: 100 | score = m.evaluate(loss_data, preds, probs, labels) 101 | metric_totals[m.name] = score 102 | 103 | return metric_totals 104 | 105 | def test(self, model, loader, thresholds, metrics): 106 | model.eval() 107 | 108 | loss = 0 109 | n_classes = loader.dataset.targets.shape[1] 110 | probs = np.empty((0, n_classes)) 111 | labels = np.empty((0, n_classes)) 112 | metric_totals = {m.name:0 for m in metrics} 113 | 114 | for inputs, targets, _ in loader: 115 | if len(targets.size()) == 1: 116 | targets = targets.float().view(-1,1) 117 | inputs = Variable(inputs.cuda(async=True), volatile=True) 118 | targets = Variable(targets.cuda(async=True), volatile=True) 119 | 120 | output = model(inputs) 121 | 122 | loss += self.tst_criterion(output, targets).data[0] 123 | probs = np.vstack([probs, output.data.cpu().numpy()]) 124 | labels = np.vstack([labels, targets.data.cpu().numpy()]) 125 | 126 | loss /= len(loader) 127 | preds = pred_utils.get_predictions(probs, thresholds) 128 | 129 | for m in metrics: 130 | score = m.evaluate(loss, preds, probs, labels) 131 | metric_totals[m.name] = score 132 | 133 | return metric_totals 134 | 135 | 136 | class MultiTargetTrainer(Trainer): 137 | def __init__(self, trn_criterion, tst_criterion, optimizer, lr_adjuster): 138 | super().__init__(trn_criterion, tst_criterion, optimizer, 
lr_adjuster) 139 | 140 | def train(self, model, loader, thresholds, epoch, metrics): 141 | model.train() 142 | n_batches = len(loader) 143 | cur_iter = int((epoch-1) * n_batches)+1 144 | metric_totals = {m.name:0 for m in metrics} 145 | 146 | for inputs, targets, aux_targets, _ in loader: 147 | if len(targets.size()) == 1: 148 | targets = targets.float().view(-1, 1) 149 | inputs = Variable(inputs.cuda(async=True)) 150 | targets = Variable(targets.cuda(async=True)) 151 | aux_targets = Variable(aux_targets.cuda(async=True)) 152 | 153 | output = model(inputs) 154 | 155 | model.zero_grad() 156 | 157 | loss = self.trn_criterion(output, targets, aux_targets) 158 | loss_data = loss.data[0] 159 | labels = targets.data.cpu().numpy() 160 | probs = output.data.cpu().numpy() 161 | preds = pred_utils.get_predictions(probs, thresholds) 162 | 163 | for m in metrics: 164 | score = m.evaluate(loss_data, preds, probs, labels) 165 | metric_totals[m.name] += score 166 | 167 | loss.backward() 168 | self.optimizer.step() 169 | 170 | if self.lr_adjuster.iteration_type == 'mini_batch': 171 | self.lr_adjuster.adjust(self.optimizer, cur_iter) 172 | cur_iter += 1 173 | 174 | for m in metrics: 175 | metric_totals[m.name] /= n_batches 176 | 177 | return metric_totals 178 | 179 | 180 | class MultiInputTrainer(Trainer): 181 | def __init__(self, trn_criterion, tst_criterion, optimizer, lr_adjuster): 182 | super().__init__(trn_criterion, tst_criterion, optimizer, lr_adjuster) 183 | 184 | def train(self, model, loader, thresholds, epoch, metrics): 185 | model.train() 186 | n_batches = len(loader) 187 | cur_iter = int((epoch-1) * n_batches)+1 188 | metric_totals = {m.name:0 for m in metrics} 189 | 190 | for inputs, targets, aux_inputs, _ in loader: 191 | if len(targets.size()) == 1: 192 | targets = targets.float().view(-1, 1) 193 | inputs = Variable(inputs.cuda(async=True)) 194 | aux_inputs = Variable(aux_inputs.cuda(async=True)) 195 | targets = Variable(targets.cuda(async=True)) 196 | 197 | output 
= model(inputs, aux_inputs) 198 | 199 | model.zero_grad() 200 | 201 | loss = self.trn_criterion(output, targets) 202 | loss_data = loss.data[0] 203 | labels = targets.data.cpu().numpy() 204 | probs = output.data.cpu().numpy() 205 | preds = pred_utils.get_predictions(probs, thresholds) 206 | 207 | for m in metrics: 208 | score = m.evaluate(loss_data, preds, probs, labels) 209 | metric_totals[m.name] += score 210 | 211 | loss.backward() 212 | self.optimizer.step() 213 | 214 | if self.lr_adjuster.iteration_type == 'mini_batch': 215 | self.lr_adjuster.adjust(self.optimizer, cur_iter) 216 | cur_iter += 1 217 | 218 | for m in metrics: 219 | metric_totals[m.name] /= n_batches 220 | 221 | return metric_totals 222 | 223 | def test(self, model, loader, thresholds, metrics): 224 | model.eval() 225 | 226 | loss = 0 227 | probs = [] 228 | labels = [] 229 | metric_totals = {m.name:0 for m in metrics} 230 | 231 | for inputs, targets, aux_inputs, _ in loader: 232 | if len(targets.size()) == 1: 233 | targets = targets.float().view(-1,1) 234 | inputs = Variable(inputs.cuda(async=True), volatile=True) 235 | aux_inputs = Variable(aux_inputs.cuda(async=True), volatile=True) 236 | targets = Variable(targets.cuda(async=True), volatile=True) 237 | 238 | output = model(inputs, aux_inputs) 239 | 240 | loss += self.tst_criterion(output, targets).data[0] 241 | probs = np.vstack([probs, output.data.cpu().numpy()]) 242 | labels = np.vstack([labels, targets.data.cpu().numpy()]) 243 | 244 | loss /= len(loader) 245 | preds = pred_utils.get_predictions(probs, thresholds) 246 | for m in metrics: 247 | score = m.evaluate(loss, preds, probs, labels) 248 | metric_totals[m.name] = score 249 | 250 | return metric_totals 251 | 252 | 253 | class ImageTargetTrainer(Trainer): 254 | def __init__(self, trn_criterion, tst_criterion, optimizer, 255 | lr_adjuster, n_classes, n_batches_per_step=1): 256 | super().__init__(trn_criterion, tst_criterion, optimizer, lr_adjuster) 257 | self.n_batches_per_step = 
n_batches_per_step 258 | 259 | def train(self, model, loader, thresholds, epoch, n_epochs, 260 | metrics): 261 | model.train() 262 | n_batches = len(loader) 263 | cur_iter = int((epoch-1) * n_batches)+1 264 | metric_totals = {m.name:0 for m in metrics} 265 | 266 | for inputs, targets, _, _ in loader: 267 | inputs = Variable(inputs.cuda(async=True)) 268 | targets = Variable(targets.cuda(async=True)) 269 | 270 | output = model(inputs) 271 | 272 | loss = self.trn_criterion(output, targets) 273 | loss_data = loss.data[0] 274 | labels = targets.data.cpu().numpy() 275 | probs = output.data.cpu().numpy() 276 | preds = pred_utils.get_predictions(probs, thresholds) 277 | 278 | for m in metrics: 279 | score = m.evaluate(loss_data, preds, probs, labels) 280 | metric_totals[m.name] += score 281 | 282 | ## Backprop (Calculate gradient) 283 | loss.backward() 284 | 285 | ## Update gradient 286 | if cur_iter % self.n_batches_per_step == 0: 287 | self.optimizer.step() 288 | model.zero_grad() 289 | 290 | if self.lr_adjuster.iteration_type == 'mini_batch': 291 | self.lr_adjuster.adjust(self.optimizer, cur_iter) 292 | cur_iter += 1 293 | 294 | for m in metrics: 295 | metric_totals[m.name] /= n_batches 296 | 297 | return metric_totals 298 | 299 | def test(self, model, loader, thresholds, metrics): 300 | model.eval() 301 | n_batches = len(loader) 302 | metric_totals = {m.name:0 for m in metrics} 303 | 304 | for inputs, targets, _, _ in loader: 305 | inputs = Variable(inputs.cuda(async=True), volatile=True) 306 | targets = Variable(targets.cuda(async=True), volatile=True) 307 | 308 | output = model(inputs) 309 | 310 | loss = self.tst_criterion(output, targets) 311 | loss_data = loss.data[0] 312 | labels = targets.data.cpu().numpy() 313 | probs = output.data.cpu().numpy() 314 | preds = pred_utils.get_predictions(probs, thresholds) 315 | 316 | for m in metrics: 317 | score = m.evaluate(loss_data, preds, probs, labels) 318 | metric_totals[m.name] += score 319 | 320 | for m in metrics: 321 | 
def train_model(model, dataloader, thresholds, optimizer, criterion,
                lr_adjuster, epoch, n_epochs, metrics=[]):
    """Train `model` for one pass over `dataloader`.

    Each batch is unpacked as (inputs, targets, img_paths). Every metric
    is evaluated per batch, and the returned dict maps each metric name
    to its mean score over the batches.

    `lr_adjuster.adjust` is called after every batch when its
    `iteration_type` is 'mini_batch'.

    NOTE(review): `metrics=[]` is a shared mutable default. It is only
    iterated here, never mutated.
    """
    model.train()
    n_batches = len(dataloader)
    # global mini-batch counter, 1-based across epochs
    cur_iter = int((epoch-1) * n_batches)+1
    # NOTE(review): computed but not used anywhere below
    total_iter = int(n_batches * n_epochs)
    metric_totals = {m.name:0 for m in metrics}

    for inputs, targets, img_paths in dataloader:
        if len(targets.size()) == 1:
            # 1-D targets become a float column vector
            targets = targets.float().view(-1, 1)
        inputs = Variable(inputs.cuda(async=True))
        targets = Variable(targets.cuda(async=True))

        ## Forward Pass
        output = model(inputs)

        ## Clear Gradients
        model.zero_grad()

        # Metrics
        loss = criterion(output, targets)
        loss_data = loss.data[0]
        labels = targets.data.cpu().numpy()
        probs = output.data.cpu().numpy()
        preds = pred_utils.get_predictions(probs, thresholds)

        for metric in metrics:
            score = metric.evaluate(loss_data, preds, probs, labels)
            metric_totals[metric.name] += score

        ## Backprop
        loss.backward()
        optimizer.step()

        ### Adjust Lr ###
        if lr_adjuster.iteration_type == 'mini_batch':
            lr_adjuster.adjust(optimizer, cur_iter)
        cur_iter += 1

    # turn the per-batch sums into means
    for metric in metrics:
        metric_totals[metric.name] /= n_batches

    return metric_totals
def early_stop(epoch, best_epoch, patience):
    """Return True once more than `patience` epochs separate `epoch` from `best_epoch`."""
    epochs_since_best = epoch - best_epoch
    return epochs_since_best > patience
def get_fnames_from_fpaths(fpaths):
    """Return the base file name of every entry in `fpaths`.

    An entry may be a path string or a tuple whose first element is the
    path.
    """
    return [
        os.path.basename(entry[0] if isinstance(entry, tuple) else entry)
        for entry in fpaths
    ]
match = re.search(regex, fpaths[i]) 29 | if match is not None: 30 | match_objs.append(match) 31 | match_fpaths.append(fpaths[i]) 32 | return match_objs, match_fpaths 33 | 34 | 35 | def zipdir(basedir, archivename): 36 | assert os.path.isdir(basedir) 37 | with closing(ZipFile(archivename, "w", ZIP_DEFLATED)) as z: 38 | for root, dirs, files in os.walk(basedir): 39 | #NOTE: ignore empty directories 40 | for fn in files: 41 | absfn = os.path.join(root, fn) 42 | zfn = absfn[len(basedir)+len(os.sep):] #XXX: relative path 43 | z.write(absfn, zfn) 44 | 45 | 46 | def unzipdir(archive_path, dest_path, remove=True): 47 | ZipFile(archive_path).extractall(dest_path) 48 | if remove: 49 | os.remove(archive_path) 50 | 51 | 52 | def save_json(fpath, dict_): 53 | with open(fpath, 'w') as f: 54 | json.dump(dict_, f, indent=4, ensure_ascii=False) 55 | 56 | 57 | def load_json(fpath): 58 | with open(fpath, 'r') as f: 59 | json_ = json.load(f) 60 | return json_ 61 | 62 | 63 | def pickle_obj(obj, fpath): 64 | with open(fpath, 'wb') as f: 65 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 66 | 67 | 68 | def unpickle_obj(fpath): 69 | with open(fpath, 'rb') as f: 70 | return pickle.load(f) 71 | 72 | 73 | def get_fname_from_fpath(fpath): 74 | return os.path.basename(fpath) 75 | 76 | 77 | def get_paths_to_files(root, file_ext=None, sort=True, strip_ext=False): 78 | filepaths = [] 79 | fnames = [] 80 | for (dirpath, dirnames, filenames) in os.walk(root): 81 | filepaths.extend(os.path.join(dirpath, f) 82 | for f in filenames if file_ext is None or f.endswith(file_ext)) 83 | fnames.extend([f for f in filenames if file_ext is None or f.endswith(file_ext)]) 84 | if strip_ext: 85 | fnames = [os.path.splitext(f)[0] for f in fnames] 86 | if sort: 87 | return sorted(filepaths), sorted(fnames) 88 | return filepaths, fnames 89 | 90 | 91 | def get_random_image_path(dir_path): 92 | filepaths = get_paths_to_files(dir_path)[0] 93 | return filepaths[random.randrange(len(filepaths))] 94 | 95 | 96 | def 
save_obj(obj, out_fpath): 97 | with open(out_fpath, 'wb') as f: 98 | pickle.dump(obj, f) 99 | 100 | 101 | def load_obj(fpath): 102 | return pickle.load(open(fpath, 'rb')) 103 | 104 | 105 | def save_bcolz_array(fpath, arr): 106 | c=bcolz.carray(arr, rootdir=fpath, mode='w') 107 | c.flush() 108 | 109 | 110 | def load_bcolz_array(fpath): 111 | return bcolz.open(fpath)[:] 112 | 113 | 114 | def compress_file(fpath): 115 | gzip_fpath = fpath+'.gz' 116 | with open(fpath, 'rb') as f_in: 117 | with gzip.open(gzip_fpath, 'wb') as f_out: 118 | shutil.copyfileobj(f_in, f_out) 119 | return gzip_fpath 120 | 121 | 122 | def write_lines(fpath, lines, compress=False): 123 | lines_str = '\n'.join(lines) 124 | if compress: 125 | fpath += '.gz' 126 | lines_str = str.encode(lines_str) 127 | f = gzip.open(fpath, 'wb') 128 | else: 129 | f = open(fpath, 'w') 130 | f.write(lines_str) 131 | f.close() 132 | return fpath -------------------------------------------------------------------------------- /utils/general.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | 4 | def gen_unique_id(prefix='', length=5): 5 | return prefix + str(uuid.uuid4()).upper().replace('-','')[:length] 6 | 7 | def get_class_name(obj): 8 | invalid_class_names = ['function'] 9 | classname = obj.__class__.__name__ 10 | if classname is None or classname in invalid_class_names: 11 | classname = obj.__name__ 12 | return classname 13 | 14 | def dict_to_html(dd, level=0): 15 | """ 16 | Convert dict to html using basic html tags 17 | """ 18 | import simplejson 19 | text = '' 20 | for k, v in dd.items(): 21 | text += '
' + ' '*(4*level) + '%s: %s' % (k, dict_to_html(v, level+1) if isinstance(v, dict) else (simplejson.dumps(v) if isinstance(v, list) else v)) 22 | return text 23 | 24 | def dict_to_html_ul(dd, level=0): 25 | """ 26 | Convert dict to html using ul/li tags 27 | """ 28 | import simplejson 29 | text = '' 33 | return text 34 | 35 | 36 | -------------------------------------------------------------------------------- /utils/imgs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from skimage import io 5 | from PIL import Image, ImageFilter 6 | from scipy import ndimage 7 | import cv2 8 | import scipy.misc 9 | import matplotlib.image as mpimg 10 | import matplotlib as mpl 11 | import matplotlib.pyplot as plt 12 | import torchvision.transforms as transforms 13 | import torchsample 14 | from torch.utils.data import DataLoader, TensorDataset 15 | 16 | import config as cfg 17 | import constants as c 18 | from datasets import metadata 19 | from . 
import files 20 | 21 | 22 | CLASS_COLORS = { 23 | 'green': (0, 128, 0), 24 | 'red': (128, 0, 0), 25 | 'blue': (0, 0, 128), 26 | 'black': (0, 0, 0), 27 | 'white': (255, 255, 255), 28 | 'grey':(128, 128, 128), 29 | } 30 | 31 | 32 | def is_image_file(filename): 33 | return any(filename.endswith(extension) for extension in c.IMG_EXTS) 34 | 35 | 36 | def load_rgb_pil(img_path): 37 | return Image.open(img_path).convert('RGB') 38 | 39 | 40 | def load_tif_as_arr(img_path): 41 | return io.imread(img_path) 42 | 43 | 44 | def load_img_as_arr(img_path): 45 | return plt.imread(img_path) 46 | 47 | 48 | def load_img_as_tensor(img_path): 49 | img_arr = load_img_as_arr(img_path) 50 | return transforms.ToTensor()(img_arr) 51 | 52 | 53 | def load_img_as_pil(img_path): 54 | return Image.open(img_path).convert('RGB') 55 | 56 | 57 | def save_pil_img(pil_img, fpath): 58 | pil_img.save(fpath) 59 | 60 | 61 | def save_arr(arr, fpath): 62 | scipy.misc.imsave(fpath, arr) 63 | 64 | 65 | def norm_meanstd(arr, mean, std): 66 | return (arr - mean) / std 67 | 68 | 69 | def denorm_meanstd(arr, mean, std): 70 | return (arr * std) + mean 71 | 72 | 73 | def norm255_tensor(arr): 74 | """Given a color image/where max pixel value in each channel is 255 75 | returns normalized tensor or array with all values between 0 and 1""" 76 | return arr / 255. 77 | 78 | 79 | def denorm255_tensor(arr): 80 | return arr * 255. 
def plot_img_arr(arr, fs=(6, 6), title=None):
    """Show an image array in a new matplotlib figure.

    Args:
        arr: numpy array; it is cast to ``uint8`` before being displayed.
        fs: figure size passed to ``plt.figure``.
        title: figure title passed to ``plt.title``.
    """
    plt.figure(figsize=fs)
    plt.imshow(arr.astype('uint8'))
    plt.title(title)
    plt.show()


def plot_img_tensor(tns, fs=(6, 6), title=None):
    """Convert a tensor with ``tensor_to_arr`` and plot the result."""
    plot_img_arr(tensor_to_arr(tns), fs, title)


def tensor_to_arr(tns):
    """Multiply a tensor by 255 and return it as a transposed numpy array.

    The axes are reordered with ``transpose((1, 2, 0))``.
    """
    # assumes tns is a channel-first (C, H, W) tensor scaled to [0, 1]
    # -- TODO confirm against callers
    tns = denorm255_tensor(tns)
    return tns.numpy().transpose((1, 2, 0))


def plot_img_from_fpath(img_path, fs=(8, 8), title=None):
    """Read an image file with ``plt.imread`` and show it in a new figure."""
    plt.figure(figsize=fs)
    plt.imshow(plt.imread(img_path))
    plt.title(title)
    plt.show()


def plot_meanstd_normed_tensor(tns, mean, std, fs=(6, 6), title=None):
    """Plot a tensor that was normalized with a mean/std pair.

    The tensor is multiplied by 255, transposed with ``(1, 2, 0)`` and then
    de-normalized with ``denorm_meanstd`` before being shown.
    """
    # NOTE(review): the x255 rescale is applied *before* the mean/std
    # de-normalization; kept exactly as the original -- confirm that this is
    # the inverse of the normalization pipeline actually used.
    tns = denorm255_tensor(tns)
    arr = tns.numpy().transpose((1, 2, 0))
    arr = denorm_meanstd(arr, mean, std)
    plt.figure(figsize=fs)
    plt.imshow(arr)
    if title:
        plt.title(title)
    plt.show()


def get_mean_std_of_dataset(dir_path, sample_size=5):
    """Estimate per-channel mean and std from a random sample of images.

    Args:
        dir_path: directory handed to ``files.get_paths_to_files``.
        sample_size: maximum number of files to sample.

    Returns:
        ``(avg_mean, avg_std)``: the per-image channel statistics averaged
        over the files that were actually read. Both are also printed.
    """
    fpaths, _ = files.get_paths_to_files(dir_path)
    random.shuffle(fpaths)
    sampled = fpaths[:sample_size]
    # The accumulators are fixed at three entries, as in the original.
    total_mean = np.array([0., 0., 0.])
    total_std = np.array([0., 0., 0.])
    for fpath in sampled:
        # Decide on the file extension; a plain substring test would also
        # match any path that merely contains "tif".
        if fpath.lower().endswith(('.tif', '.tiff')):
            img_arr = io.imread(fpath)
        else:
            img_arr = load_img_as_arr(fpath)
        total_mean += np.mean(img_arr, axis=(0, 1))
        total_std += np.std(img_arr, axis=(0, 1))
    # Average over the number of images read, not the requested sample size:
    # a directory holding fewer files would otherwise be biased toward zero.
    # An empty sample keeps the all-zero result.
    n_sampled = max(len(sampled), 1)
    avg_mean = total_mean / n_sampled
    avg_std = total_std / n_sampled
    print("mean: {}".format(avg_mean), "stdev: {}".format(avg_std))
    return avg_mean, avg_std


def plot_binary_mask(arr, threshold=0.5, title=None, color=(255, 255, 255)):
    """Threshold a mask and plot it in a single color.

    Args:
        arr: mask array; a copy is expanded to three channels with
            ``format_1D_binary_mask``.
        threshold: values at or above it are painted with ``color``, the
            remaining values are set to 0.
        title: figure title.
        color: three per-channel values used for the foreground.
    """
    arr = format_1D_binary_mask(arr.copy())
    for i in range(3):
        channel = arr[:, :, i]  # a view, so the assignment writes into arr
        channel[channel >= threshold] = color[i]
    arr[arr < threshold] = 0
    plot_img_arr(arr, title=title)
def format_1D_binary_mask(mask):
    """Turn a single-channel mask into a three-channel float32 array.

    A 2-D mask first gets a leading axis. The mask is then stacked three
    times along axis 1, squeezed, and transposed with ``(1, 2, 0)``.
    """
    if len(mask.shape) == 2:
        mask = np.expand_dims(mask, 0)
    mask = np.stack([mask, mask, mask], axis=1).squeeze().transpose(1, 2, 0)
    return mask.astype('float32')


def plot_binary_mask_overlay(mask, img_arr, fs=(18, 18), title=None):
    """Draw a half-transparent mask on top of an image.

    The original text defined this function twice with identical bodies;
    only one definition is kept.
    """
    mask = format_1D_binary_mask(mask.copy())
    fig = plt.figure(figsize=fs)
    a = fig.add_subplot(1, 2, 1)
    a.set_title(title)
    plt.imshow(img_arr.astype('uint8'))
    plt.imshow(mask, cmap='jet', alpha=0.5)  # interpolation='none'
    plt.show()


def _sample_window(n_items, n_samples, shuffle):
    """Return ``(start, stop)`` for a run of at most ``n_samples`` valid indices."""
    last_start = max(n_items - n_samples, 0)
    start = random.randint(0, last_start) if shuffle else 0
    return start, min(start + n_samples, n_items)


def plot_samples_from_dir(dir_path, shuffle=False):
    """Plot up to six consecutive images found under ``dir_path`` in a 2x3 grid.

    With ``shuffle`` the run starts at a random position. The run is clamped
    to the available files, so a start near the end of the list (or a
    directory with fewer than six files) no longer raises ``IndexError``.
    """
    fpaths, fnames = files.get_paths_to_files(dir_path)
    plt.figure(figsize=(16, 12))
    start, stop = _sample_window(len(fpaths), 6, shuffle)
    for j, idx in enumerate(range(start, stop), start=1):
        plt.subplot(2, 3, j)
        plt.imshow(plt.imread(fpaths[idx]))
        plt.title(fnames[idx])
        plt.axis('off')


def plot_sample_preds(fpaths, preds, targs, label_names, shuffle=False):
    """Plot up to six images titled with their predicted (and target) tags.

    Args:
        fpaths: image paths, indexed in parallel with ``preds``.
        preds: per-image predictions handed to
            ``metadata.convert_one_hot_to_tags``.
        targs: per-image targets, or ``None`` to leave the target line empty.
        label_names: label names handed to ``convert_one_hot_to_tags``.
        shuffle: start the run at a random position instead of index 0.
    """
    fnames = files.get_fnames_from_fpaths(fpaths)
    plt.figure(figsize=(16, 12))
    # Clamp the run to the available predictions to avoid an IndexError.
    start, stop = _sample_window(len(preds), 6, shuffle)
    for j, idx in enumerate(range(start, stop), start=1):
        plt.subplot(2, 3, j)
        pred_tags = 'P: ' + ','.join(
            metadata.convert_one_hot_to_tags(preds[idx], label_names))
        if targs is not None:
            targ_tags = 'T: ' + ','.join(
                metadata.convert_one_hot_to_tags(targs[idx], label_names))
        else:
            targ_tags = ''
        title = '\n'.join([fnames[idx], pred_tags, targ_tags])
        plt.imshow(plt.imread(fpaths[idx]))
        plt.title(title)


def plot_sample_preds_masks(fnames, inputs, preds, fs=(9, 9),
                            n_samples=8, shuffle=False):
    """Print each file name and overlay its predicted mask on the input image.

    At most ``n_samples`` consecutive items are shown; the run is clamped to
    ``len(inputs)`` so it cannot index past the end.
    """
    start, stop = _sample_window(len(inputs), n_samples, shuffle)
    for idx in range(start, stop):
        print(fnames[idx])
        img = tensor_to_arr(inputs[idx])
        plot_binary_mask_overlay(preds[idx], img, fs, fnames[idx])

# --------------------------------------------------------------------------------
# /utils/logger.py:
# --------------------------------------------------------------------------------
import os
import logging
import importlib  # replaces `imp`, which was removed in Python 3.12
import time


def get_logger(log_path='',
               logger_name='logger',
               ch_log_level=logging.ERROR,
               fh_log_level=logging.INFO):
    """Build a logger with an optional console handler and file handler.

    Args:
        log_path: directory in which ``<logger_name>.log`` is written.
        logger_name: name of the logger and stem of the log file.
        ch_log_level: console handler level; a falsy value skips the handler.
        fh_log_level: file handler level; a falsy value skips the handler.

    Returns:
        The configured ``logging.Logger``.
    """
    # Shut logging down and reload the module before configuring, exactly as
    # the original did with imp.reload.
    logging.shutdown()
    importlib.reload(logging)
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)

    # Console Handler
    if ch_log_level:
        ch = logging.StreamHandler()
        ch.setLevel(ch_log_level)
        ch.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(ch)

    # File Handler
    if fh_log_level:
        fh = logging.FileHandler(os.path.join(log_path, logger_name + '.log'))
        fh.setLevel(fh_log_level)
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s')
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger


def get_time_msg(start_time):
    """Format the time elapsed since ``start_time`` as minutes and seconds."""
    minutes, seconds = divmod(time.time() - start_time, 60)
    return 'Time {:.1f}m {:.2f}s'.format(minutes, seconds)

# --------------------------------------------------------------------------------
# /utils/multitasking.py:
# --------------------------------------------------------------------------------
import time  # was missing: both helpers call time.time()
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from itertools import repeat


def multithreading(func, args, workers):
    """Run ``func(arg, begin_time)`` for every arg on a thread pool.

    ``begin_time`` is the wall-clock time taken just before the pool starts.
    Returns the results as a list, in the order of ``args``.
    """
    begin_time = time.time()
    with ThreadPoolExecutor(max_workers=workers) as executor:
        # map() stops at the shortest iterable, so the endless repeat() is
        # cut to len(args) without building a throwaway list.
        res = executor.map(func, args, repeat(begin_time))
    return list(res)


def multiprocessing(func, args, workers):
    """Run ``func(arg, begin_time)`` for every arg on a process pool.

    Same contract as ``multithreading``, using ``ProcessPoolExecutor``.
    """
    begin_time = time.time()
    with ProcessPoolExecutor(max_workers=workers) as executor:
        res = executor.map(func, args, repeat(begin_time))
    return list(res)

# --------------------------------------------------------------------------------
# /utils/widgets.py:
# --------------------------------------------------------------------------------
import os
import re
from glob import glob
from config import PATHS
from experiments import experiment
from experiments import exp_utils
import training.utils as train_utils
import constants as c


LATEST = 'latest'


def load_single_weights(exp_name, epoch):
    """Return ``epoch`` unchanged (``exp_name`` is not used)."""
    return epoch


def load_multiple_weights(exp_name, epoch):
    """Return ``epoch`` converted to a list (``exp_name`` is not used)."""
    return list(epoch)


def get_weights_path(exp_name):
    """Return the ``weights`` directory of an experiment."""
    return os.path.join(PATHS['experiments']['root'], exp_name, 'weights')


def load_experiment(name):
    """Create an ``Experiment``, review it and load its validation history."""
    exp = experiment.Experiment(name, PATHS['experiments']['root'])
    exp.review()
    exp.history.load_history_from_file(c.VAL)
    return exp


def get_weights_epochs(exp_name):
    """Return the epochs for which weight files exist for ``exp_name``.

    ``get_weights_epoch_path_dict`` called this name, but it was not defined
    in this module's text.
    """
    # NOTE(review): uses the same exp_utils calls as get_f2_scores_by_epoch
    # and assumes get_weights_path(exp_name) is the directory exposed there
    # as exp.weights_dir -- confirm.
    weight_fpaths = exp_utils.get_weights_fpaths(get_weights_path(exp_name))
    return exp_utils.get_weight_epochs_from_fpaths(weight_fpaths)


def get_f2_scores_by_epoch(exp_name, sort_by_score):
    """Map ``'<epoch> (<f2 score>)'`` labels to weight file paths.

    Args:
        exp_name: experiment name.
        sort_by_score: when truthy, entries are ordered by descending
            validation F2 score; otherwise by epoch with ``'latest'`` last.

    Returns:
        dict of label -> weights path, built by ``append_score_wpaths``.
    """
    exp = load_experiment(exp_name)
    weight_fpaths = exp_utils.get_weights_fpaths(exp.weights_dir)
    epochs = exp_utils.get_weight_epochs_from_fpaths(weight_fpaths)
    f2_scores = exp.history.metrics_history[c.F2_SCORE][c.VAL]

    # The score of epoch N is stored at index N-1 of the history list.
    score_by_epoch = {
        epoch: float('{:4g}'.format(f2_scores[epoch - 1]))
        for epoch in epochs
    }
    score_by_epoch[LATEST] = float('{:4g}'.format(f2_scores[-1]))

    if sort_by_score:
        ordered = sorted(score_by_epoch, key=score_by_epoch.get, reverse=True)
        score_by_epoch = {epoch: score_by_epoch[epoch] for epoch in ordered}
    return append_score_wpaths(exp_name, score_by_epoch)


def append_score_wpaths(exp_name, epoch_dict):
    """Turn ``{epoch: score}`` into ``{'<epoch> (<score>)': weights_path}``."""
    return {
        '{:} ({:4g})'.format(epoch, float(score)):
            get_weights_fpath(epoch, exp_name)
        for epoch, score in epoch_dict.items()
    }


def get_weights_fpath(epoch, exp_name):
    """Return the weights file path for an epoch, or for ``'latest'``."""
    weights_path = get_weights_path(exp_name)
    if epoch == LATEST:
        return os.path.join(weights_path, c.LATEST_WEIGHTS_FNAME)
    return weights_path + '/weights-' + str(epoch) + '.pth'


def get_weights_epoch_path_dict(exp_name):
    """Map every saved epoch of an experiment to its weights file path."""
    return {
        epoch: get_weights_fpath(epoch, exp_name)
        for epoch in get_weights_epochs(exp_name)
    }

# --------------------------------------------------------------------------------
# /visualizers/__init__.py:
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /visualizers/kibana.py:
# --------------------------------------------------------------------------------
import json
import pandas as pd
import config
import constants as c
import clients.client_constants as cc
import clients.es_client as es
import copy


class Kibana():
    """Visualizer that uploads experiment data through the ``es`` client."""

    def __init__(self, exp_name):
        self.name = exp_name
        self.classname = 'Kibana'

    def init(self, exp_config):
        """Check that Elasticsearch is enabled in config and answers a ping.

        Raises:
            AssertionError: if either check fails. The type matches the
                original bare ``assert`` statements, but the checks are no
                longer stripped when Python runs with ``-O``.
        """
        if config.ES_ENABLED is not True:
            raise AssertionError('config.ES_ENABLED is not True')
        if es.ping() is not True:
            raise AssertionError('es.ping() did not return True')

    def update(self, exp_config, exp_history, msg=None):
        """Upload the experiment history and config (``msg`` is not used)."""
        es.upload_experiment_history(exp_config, exp_history)
        es.upload_experiment_config(exp_config)


def load(config):
    """Create a ``Kibana`` visualizer named after ``config.name``."""
    return Kibana(config.name)

# --------------------------------------------------------------------------------
# /visualizers/vis_utils.py:
# --------------------------------------------------------------------------------
from . import kibana
from . import viz


# Visualizer name -> factory taking the experiment config.
VISUALIZERS = {
    'visdom': viz.load,
    'kibana': kibana.load
}


def get_visualizer(config, name):
    """Build the visualizer registered under ``name`` (case-insensitive)."""
    return VISUALIZERS[name.lower()](config)


def get_visualizers_from_config(config):
    """Build one visualizer for every name in ``config.visualizers``."""
    return [get_visualizer(config, v) for v in config.visualizers]

# --------------------------------------------------------------------------------
# /visualizers/viz.py:
# --------------------------------------------------------------------------------
import numpy as np
from visdom import Visdom
import constants as c


class Viz():
    """Visualizer that draws metric plots and a text summary with Visdom."""

    def __init__(self, exp_name):
        self.name = exp_name
        self.classname = 'Visdom'
        self.viz = None     # Visdom client, created in init()
        self.plots = None   # plot name -> Visdom window, created in init()

    def init(self, exp_config):
        """Create the Visdom client and one window per configured plot."""
        self.viz = Visdom()
        self.plots = self.init_visdom_plots(exp_config)

    def update(self, exp_config, exp_history, msg=None):
        """Redraw every metric plot and the summary text.

        Args:
            exp_config: provides ``progress['epoch']``, ``metrics`` and
                ``aux_metrics``.
            exp_history: provides ``metrics_history``.
            msg: text written to the summary window.
        """
        epoch = exp_config.progress['epoch']
        metrics_history = exp_history.metrics_history

        for name in exp_config.metrics:
            trn_arr = np.array(metrics_history[name][c.TRAIN])
            val_arr = np.array(metrics_history[name][c.VAL])
            self.update_metric_plot(name, trn_arr, val_arr, epoch,
                                    ylabel=name)

        for metric in exp_config.aux_metrics:
            name = metric['name']
            data_arr = np.array(metrics_history[name])
            self.update_aux_metric_plot(name, data_arr, epoch,
                                        ylabel=metric['units'])
        self.update_summary_plot(msg)

    def init_visdom_plots(self, exp_config):
        """Create the windows for all metrics, aux metrics and the summary."""
        plots = {}
        for name in exp_config.metrics:
            plots[name] = self.init_train_val_metric_plot(name, name)
        for aux_metric in exp_config.aux_metrics:
            name = aux_metric['name']
            plots[name] = self.init_auxiliary_metric_plot(
                name, aux_metric['units'])
        plots['summary'] = self.init_txt_plot('summary')
        return plots

    def init_train_val_metric_plot(self, title, ylabel, xlabel='epoch'):
        """Create a two-series (Train/Valid) line plot with placeholder data."""
        return self.viz.line(
            X=np.array([1]),
            Y=np.array([[1, 1]]),
            opts=dict(
                xlabel=xlabel,
                ylabel=ylabel,
                title=title,
                legend=['Train', 'Valid']
            ),
            env=self.name
        )

    def init_auxiliary_metric_plot(self, title, ylabel, xlabel='epoch'):
        """Create a single-series line plot with placeholder data."""
        return self.viz.line(
            X=np.array([1]),
            Y=np.array([1]),
            opts=dict(
                xlabel=xlabel,
                ylabel=ylabel,
                title=title,
                legend=[]
            ),
            env=self.name
        )

    def init_txt_plot(self, title):
        """Create a text window showing an initialization message."""
        return self.viz.text(
            "Initializing.. " + title,
            env=self.name
        )

    def viz_epochs(self, cur_epoch):
        """Return the epochs ``1..cur_epoch`` duplicated into two columns."""
        # Epochs start at 1
        epochs = np.arange(1, cur_epoch+1)
        return np.stack([epochs, epochs], 1)

    def update_metric_plot(self, metric, train_arr, val_arr,
                           epoch, ylabel, xlabel='epoch'):
        """Redraw the Train/Valid line plot of ``metric`` up to ``epoch``."""
        data = np.stack([train_arr, val_arr], 1)
        window = self.plots[metric]
        return self.viz.line(
            X=self.viz_epochs(epoch),
            Y=data,
            win=window,
            env=self.name,
            opts=dict(
                xlabel=xlabel,
                ylabel=ylabel,
                title=metric,
                legend=['Train', 'Valid']
            ),
        )

    def update_aux_metric_plot(self, metric, data_arr, epoch, ylabel,
                               xlabel='epoch', legend=None):
        """Redraw the single-series plot of an auxiliary metric.

        ``legend`` defaults to an empty list; ``None`` is used as the
        sentinel so that no list object is shared between calls.
        """
        if legend is None:
            legend = []
        window = self.plots[metric]
        return self.viz.line(
            X=self.viz_epochs(epoch)[:, 0],
            Y=data_arr,
            win=window,
            env=self.name,
            opts=dict(
                xlabel=xlabel,
                ylabel=ylabel,  # metric.units,
                title=metric,
                legend=legend
            ),
        )

    def update_summary_plot(self, msg):
        """Replace the content of the summary text window with ``msg``."""
        window = self.plots['summary']
        return self.viz.text(
            msg,
            win=window,
            env=self.name
        )


def load(config):
    """Create a ``Viz`` visualizer named after ``config.name``."""
    return Viz(config.name)

# --------------------------------------------------------------------------------