├── .gitmodules
├── oil
│   ├── __init__.py
│   ├── logging
│   │   ├── __init__.py
│   │   └── lazyLogger.py
│   ├── tuning
│   │   ├── __init__.py
│   │   ├── tests
│   │   │   └── study_test.py
│   │   ├── slurm_example.py
│   │   ├── args.py
│   │   ├── slurmExecutor.py
│   │   ├── configGenerator.py
│   │   └── study.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── mytqdm.py
│   │   ├── optim.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   └── parallel.py
│   ├── model_trainers
│   │   ├── __init__.py
│   │   ├── piModel.py
│   │   ├── classifier.py
│   │   ├── graphssl.py
│   │   ├── cycleGan.py
│   │   ├── cGan.py
│   │   ├── vat.py
│   │   ├── segmenter.py
│   │   ├── gan.py
│   │   └── trainer.py
│   ├── architectures
│   │   ├── img2img
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   ├── img_gen
│   │   │   ├── __init__.py
│   │   │   ├── ganBase.py
│   │   │   ├── resnetgan.py
│   │   │   └── conditionalgan.py
│   │   ├── pointcloud
│   │   │   └── __init__.py
│   │   ├── img_classifiers
│   │   │   ├── __init__.py
│   │   │   ├── tests
│   │   │   │   └── test_networks.py
│   │   │   ├── vgg.py
│   │   │   ├── smallconv.py
│   │   │   ├── wide_resnet.py
│   │   │   ├── preresnet.py
│   │   │   ├── densenet.py
│   │   │   ├── networkparts.py
│   │   │   └── shake_shake.py
│   │   └── parts
│   │       ├── __init__.py
│   │       ├── attention.py
│   │       ├── blocks.py
│   │       ├── denseblocks.py
│   │       ├── CoordConv.py
│   │       ├── antialiasing.py
│   │       └── deconv.py
│   ├── recipes
│   │   ├── __init__.py
│   │   ├── exampleHyperSearch.py
│   │   ├── simpleTrial.py
│   │   ├── trainGan.py
│   │   ├── trainPi.py
│   │   └── trainCGan.py
│   └── datasetup
│       ├── __init__.py
│       ├── camvid.py
│       ├── celeba.py
│       ├── joint_transforms.py
│       ├── dataloaders.py
│       ├── datasets.py
│       └── augLayers.py
├── .gitignore
├── setup.py
├── lincense.md
└── README.md
/.gitmodules:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/oil/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/oil/logging/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/oil/tuning/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/oil/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | dist
3 | *.pyc
4 | __pycache__
5 | *.np
6 | *.t
7 | .ipynb_checkpoints/
8 | *.pt
9 | *.ckpt
10 | *.pdf
11 | *.lexicog
12 | *.npz
13 | *.npy
14 | *.s
15 | *.S
16 | *mint
17 | .trainer
18 | experiments/runs/
--------------------------------------------------------------------------------
/oil/model_trainers/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | from .trainer import Trainer
5 | from .classifier import Classifier,Regressor
6 | from .piModel import PiModel
7 | from .gan import Gan
8 | from .cGan import cGan, Pix2Pix
9 | 
10 | __all__ = ['Trainer','Classifier','Regressor','PiModel','Gan','cGan','Pix2Pix']
--------------------------------------------------------------------------------
/oil/architectures/img2img/__init__.py:
--------------------------------------------------------------------------------
1 | #from .densenetFC import FCDenseNet57, FCDenseNet67, FCDenseNet103
2 | # from .fcn32s import FCN32s
3 | # from .fcn16s import FCN16s
4 | # 
from .fcn8s import FCN8s 5 | # from .fcn8s import FCN8sAtOnce 6 | # from .vgg import VGG16 7 | #from .deeplab import DeepLab 8 | #'FCDenseNet57', 'FCDenseNet67', 'FCDenseNet103' 9 | #__all__ = ['DeepLab'] -------------------------------------------------------------------------------- /oil/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | __all__ = [] 4 | for loader, module_name, is_pkg in pkgutil.walk_packages(__path__): 5 | module = importlib.import_module('.'+module_name,package=__name__) 6 | try: 7 | globals().update({k: getattr(module, k) for k in module.__all__}) 8 | __all__ += module.__all__ 9 | except AttributeError: continue -------------------------------------------------------------------------------- /oil/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | __all__ = [] 4 | for loader, module_name, is_pkg in pkgutil.walk_packages(__path__): 5 | module = importlib.import_module('.'+module_name,package=__name__) 6 | try: 7 | globals().update({k: getattr(module, k) for k in module.__all__}) 8 | __all__ += module.__all__ 9 | except AttributeError: continue -------------------------------------------------------------------------------- /oil/architectures/img_gen/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | __all__ = [] 4 | for loader, module_name, is_pkg in pkgutil.walk_packages(__path__): 5 | module = importlib.import_module('.'+module_name,package=__name__) 6 | try: 7 | globals().update({k: getattr(module, k) for k in module.__all__}) 8 | __all__ += module.__all__ 9 | except AttributeError: continue -------------------------------------------------------------------------------- /oil/architectures/pointcloud/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | __all__ = [] 4 | for loader, module_name, is_pkg in pkgutil.walk_packages(__path__): 5 | module = importlib.import_module('.'+module_name,package=__name__) 6 | try: 7 | globals().update({k: getattr(module, k) for k in module.__all__}) 8 | __all__ += module.__all__ 9 | except AttributeError: continue -------------------------------------------------------------------------------- /oil/datasetup/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | __all__ = [] 4 | for loader, module_name, is_pkg in pkgutil.walk_packages(__path__): 5 | module = importlib.import_module('.'+module_name,package=__name__) 6 | try: 7 | globals().update({k: getattr(module, k) for k in module.__all__}) 8 | __all__ += module.__all__ 9 | except AttributeError: continue 10 | -------------------------------------------------------------------------------- /oil/architectures/img_classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | __all__ = [] 4 | for loader, module_name, is_pkg in pkgutil.walk_packages(__path__): 5 | module = importlib.import_module('.'+module_name,package=__name__) 6 | try: 7 | globals().update({k: getattr(module, k) for k in module.__all__}) 8 | __all__ += module.__all__ 9 | except AttributeError: continue -------------------------------------------------------------------------------- 
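All of the package `__init__.py` files above share one auto-import idiom: walk every submodule of the package, import it, and re-export whatever names that submodule declares in `__all__`. This is what lets the recipes write `from oil.architectures import layer13s` even though `layer13s` lives in `oil/architectures/img_classifiers/smallconv.py`. A minimal self-contained sketch of the same idiom (the `mypkg`/`thing.py` names below are hypothetical, not part of this repo):

# mypkg/__init__.py -- hypothetical package using the repo's auto-import idiom
import importlib
import pkgutil

__all__ = []
for loader, module_name, is_pkg in pkgutil.walk_packages(__path__):
    # Import mypkg.<module_name> relative to this package.
    module = importlib.import_module('.' + module_name, package=__name__)
    try:
        # Hoist the submodule's declared public names into the package namespace.
        globals().update({k: getattr(module, k) for k in module.__all__})
        __all__ += module.__all__
    except AttributeError:
        continue  # Submodule defines no __all__, so nothing is re-exported.

# If mypkg/thing.py defines __all__ = ['Thing'], then after this runs,
# `from mypkg import Thing` works without importing mypkg.thing explicitly.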
/oil/utils/mytqdm.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | try: #tqdm.autonotebook
3 |     from tqdm.auto import tqdm#_notebook as tqdm
4 |     tqdm.get_lock().locks = []
5 |     # old_print = print
6 |     # if tqdm.tqdm.write raises error, use builtin print
7 |     # def new_print(*args, **kwargs):
8 |     #     try: tqdm.write(*map(lambda x: str(x),args), **kwargs)
9 |     #     except: old_print(*args, ** kwargs)
10 |     # inspect.builtins.print = new_print
11 | except ImportError: tqdm = lambda it,*args,**kwargs:it
12 | 
--------------------------------------------------------------------------------
/oil/utils/optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchcontrib.optim import SWA
3 | 
4 | ## Inverse-LR-weighted automatic SWA: iterates taken at lower learning rates
5 | ## receive proportionally higher weight (1/lr) in the running weight average.
6 | class AutoSWA(SWA):
7 |     def __init__(self,*args,swa_start=0,swa_freq=1000,**kwargs):
8 |         super().__init__(*args,swa_start=swa_start,swa_freq=swa_freq,**kwargs)
9 | 
10 |     def update_swa_group(self,group):
11 |         coeff_new = 1/group["lr"]
12 |         group["n_avg"] += coeff_new
13 |         for p in group['params']:
14 |             param_state = self.state[p]
15 |             if 'swa_buffer' not in param_state:
16 |                 param_state['swa_buffer'] = torch.zeros_like(p.data)
17 |             buf = param_state['swa_buffer']
18 |             # running weighted mean: buf <- buf + (p - buf)*w_new/sum(w)
19 |             diff = (p.data - buf)*coeff_new/group["n_avg"]
20 |             buf.add_(diff)
21 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup,find_packages
2 | import sys, os
3 | 
4 | setup(name="olive-oil-ml",
5 |       description="For slowing down deep learning research",
6 |       version='0.1.1',
7 |       author='Marc Finzi',
8 |       author_email='maf388@cornell.edu',
9 |       license='MIT',
10 |       python_requires='>=3.6',
11 |       install_requires=['torch>=1.2','torchvision','pandas','numpy','matplotlib','dill','tqdm>=4.38','natsort','scikit-learn','torchcontrib'],
12 |       extras_require = {
13 |           'TBX':['tensorboardX']
14 |       },
15 |       packages=find_packages(),#["oil",],#find_packages()
16 |       long_description=open('README.md').read(),
17 |       )
18 | #pathToThisFile = os.path.dirname(os.path.realpath(__file__))
19 | # add to .bashrc
20 | #sys.path.append(pathToThisFile)
21 | 
--------------------------------------------------------------------------------
/oil/architectures/img_classifiers/tests/test_networks.py:
--------------------------------------------------------------------------------
1 | 
2 | import unittest
3 | from oil.datasetup.datasets import CIFAR10
4 | from oil.model_trainers.classifier import Classifier, simpleClassifierTrial
5 | import oil.architectures.img_classifiers as models
6 | from oil.tuning.configGenerator import sample_config
7 | 
8 | 
9 | class ClassifierNetworkArchitecturesTests(unittest.TestCase):
10 |     def test_networks_load_and_train(self):
11 |         """ Takes 30 minutes on 1 gpu """
12 |         configs = [{'dataset': CIFAR10, 'num_epochs':1,
13 |             'network': getattr(models,modelname)} for modelname in models.__all__]
14 |         Trial = simpleClassifierTrial(strict=True)
15 |         for cfg in configs:
16 |             outcome = Trial(cfg)
17 | 
18 | if __name__=="__main__":
19 |     unittest.main()
--------------------------------------------------------------------------------
/oil/recipes/exampleHyperSearch.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from oil.recipes.simpleTrial import simpleTrial, makeTrainer
3 | from oil.tuning.study import Study
4 | 
from oil.architectures import layer13s 5 | from oil.tuning.args import argupdated_config 6 | 7 | if __name__=="__main__": 8 | config_spec = copy.deepcopy(makeTrainer.__kwdefaults__) 9 | config_spec.update( 10 | {'network':layer13s,'bs':[50,32,64],'lr':(lambda cfg: .002*cfg['bs']), 11 | 'num_epochs':2,'net_config':{'k':[64,96]},'study_name':'example'} 12 | ) 13 | config_spec = argupdated_config(config_spec) 14 | name = config_spec.pop('study_name') 15 | thestudy = Study(simpleTrial,config_spec,study_name=name, 16 | base_log_dir=config_spec['trainer_config'].get('log_dir',None)) 17 | thestudy.run(3,ordered=False) 18 | print(thestudy.covariates()) 19 | print(thestudy.outcomes) -------------------------------------------------------------------------------- /oil/architectures/parts/__init__.py: -------------------------------------------------------------------------------- 1 | # from .CoordConv import CoordConv 2 | # from .blocks import conv2d,ConvBNrelu,FcBNrelu,ResBlock, DenseBlock,ConcatResBlock,ODEBlock,RNNBlock 3 | # from .blocks import FiLMResBlock,ConcatBottleBlock 4 | # #from .denseblocks import DenseLayer, DenseBlock, TransitionUp,TransitionDown,Bottleneck 5 | # __all__ = ['CoordConv','conv2d','ConvBNrelu','FcBNrelu','ResBlock','DenseBlock','ConcatResBlock', 6 | # 'ODEBlock','RNNBlock','FiLMResBlock','ConcatBottleBlock'] 7 | # #'DenseLayer', 'DenseBlock', 'TransitionUp','TransitionDown','Bottleneck'] 8 | 9 | import importlib 10 | import pkgutil 11 | __all__ = [] 12 | for loader, module_name, is_pkg in pkgutil.walk_packages(__path__): 13 | module = importlib.import_module('.'+module_name,package=__name__) 14 | try: 15 | globals().update({k: getattr(module, k) for k in module.__all__}) 16 | __all__ += module.__all__ 17 | except AttributeError: continue -------------------------------------------------------------------------------- /lincense.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Marc Finzi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/oil/model_trainers/piModel.py:
--------------------------------------------------------------------------------
1 | from .classifier import Classifier
2 | from ..utils.losses import softmax_mse_loss, softmax_mse_loss_both
3 | from ..utils.utils import Eval, izip, icycle
4 | import torch.nn as nn
5 | 
6 | class PiModel(Classifier):
7 |     def __init__(self, *args, cons_weight=15,
8 |                  **kwargs):
9 |         super().__init__(*args, **kwargs)
10 |         self.hypers.update({'cons_weight':cons_weight})
11 |         self.dataloaders['train'] = izip(icycle(self.dataloaders['train']),self.dataloaders['_unlab'])
12 | 
13 |     def unlabLoss(self, x_unlab):
14 |         # consistency loss between two stochastic forward passes on the same batch
15 |         logits1 = self.model(x_unlab)
16 |         logits2 = self.model(x_unlab)
17 |         cons_loss = softmax_mse_loss(logits1, logits2.detach())
18 |         return cons_loss
19 | 
20 |     def loss(self, minibatch):
21 |         (x_lab, y_lab), x_unlab = minibatch
22 |         unlab_loss = self.unlabLoss(x_unlab)*float(self.hypers['cons_weight'])
23 |         lab_loss = nn.CrossEntropyLoss()(self.model(x_lab),y_lab)
24 |         return lab_loss + unlab_loss
25 | 
26 |     def logStuff(self, step, minibatch=None):
27 |         if minibatch:
28 |             extra_metrics = {'Unlab_loss(batch)':self.unlabLoss(minibatch[1]).cpu().item()}
29 |             self.logger.add_scalars('metrics',extra_metrics,step)
30 |         super().logStuff(step, minibatch)
--------------------------------------------------------------------------------
/oil/architectures/parts/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import numpy as np  # used by np.sqrt in Attention below
5 | from ...utils.utils import Expression,export,Named
6 | 
7 | def Attention(Q,K,V):
8 |     """ Assumes Q,K,V have shape (bs,N,d)"""
9 |     bs,N,d = K.shape
10 |     Kt = K.transpose(-1,-2)
11 |     similarity = Q@Kt/np.sqrt(d)
12 |     return F.softmax(similarity,dim=-1)@V
13 | 
14 | class SelfAttentionHead(nn.Module):
15 | 
16 |     def __init__(self,inp_channels, outp_channels):
17 |         super().__init__()
18 |         self.WQ = nn.Linear(inp_channels,outp_channels)
19 |         self.WK = nn.Linear(inp_channels,outp_channels)
20 |         self.WV = nn.Linear(inp_channels,outp_channels)
21 |     def forward(self,X):
22 |         """ Assumes X has shape (bs,N,d)"""
23 |         return Attention(self.WQ(X),self.WK(X),self.WV(X))
24 | 
25 | class MultiHeadAtt(nn.Module):
26 |     def __init__(self,inp_channels,num_heads):
27 |         super().__init__()
28 |         # integer division: inp_channels must be divisible by num_heads
29 |         self.heads = nn.ModuleList([SelfAttentionHead(inp_channels,inp_channels//num_heads)
30 |                                     for _ in range(num_heads)])
31 |         self.WO = nn.Linear(inp_channels,inp_channels)
32 |     def forward(self,X):
33 |         return self.WO(torch.cat([head(X) for head in self.heads],dim=-1))
34 | 
35 | 
36 | 
37 | 
--------------------------------------------------------------------------------
/oil/tuning/tests/study_test.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | import unittest
5 | from oil.datasetup.datasets import SVHN,CIFAR10,CIFAR100
6 | from oil.model_trainers.classifier import Classifier, simpleClassifierTrial
7 | import oil.architectures.img_classifiers as models
8 | from oil.tuning.configGenerator import sample_config, logUniform
9 | from oil.tuning.study import Study
10 | import torch
11 | import dill
12 | 
13 | class StudyTests(unittest.TestCase):
14 |     def test_study_generates_configs(self):
15 |         """ Takes 30 minutes on 1 gpu """
16 |         config_spec = {'dataset': [CIFAR10,CIFAR100], 'num_epochs':1,
17 |                        'network': models.layer13s,
18 |                        'loader_config': 
{'amnt_dev':5000,'lab_BS':[16,32,64]}, 19 | 'opt_config':{'lr':logUniform(.03,.3), 20 | 'momentum':lambda cfg: 1/(1-cfg['opt_config']['lr']), 21 | 'weight_decay':lambda cfg: cfg['opt_config']['momentum']*1e-4}, 22 | 'trainer_config':{'log_args':{'no_print':True}}, 23 | } 24 | 25 | Trial = simpleClassifierTrial(strict=True) 26 | cutout_study = Study(Trial,config_spec, slurm_cfg={'time':'00:10:00'}) 27 | save_loc = cutout_study.run(num_trials=2,max_workers=1) 28 | study = torch.load(save_loc,pickle_module=dill) 29 | assert isinstance(study,Study) 30 | 31 | if __name__=="__main__": 32 | unittest.main() -------------------------------------------------------------------------------- /oil/tuning/slurm_example.py: -------------------------------------------------------------------------------- 1 | from oil.tuning.slurmExecutor import SlurmExecutor 2 | import subprocess 3 | import concurrent 4 | import time 5 | import multiprocessing 6 | 7 | # "Worker" functions. 8 | def square(n): 9 | return n * n 10 | def hostinfo(a): 11 | return subprocess.check_output('uname -a', shell=True).decode()#.split() 12 | def gpustat(a): 13 | return subprocess.check_output('gpustat', shell=True).decode()#.split() 14 | def cpu_count(): 15 | return multiprocessing.cpu_count() 16 | def example_1(): 17 | """Square some numbers on remote hosts! 18 | """ 19 | with SlurmExecutor(max_workers=5) as executor: 20 | futures = [executor.submit(square, n) for n in range(15)] 21 | for future in concurrent.futures.as_completed(futures): 22 | print((future.result())) 23 | 24 | def example_2(): 25 | """Get host identifying information about the servers running 26 | our jobs. 27 | """ 28 | with SlurmExecutor(max_workers=5) as executor: 29 | futures = [executor.submit(cpu_count) for n in range(5)] 30 | print('Some cluster nodes:') 31 | for future in concurrent.futures.as_completed(futures): 32 | print(future.result()) 33 | 34 | def example_3(): 35 | """Demonstrates the use of the map() convenience function. 36 | """ 37 | start = time.time() 38 | with SlurmExecutor(max_workers=5,clone_session=False) as exc: 39 | print(''.join(list(exc.map(hostinfo,range(10),chunksize=1)))) 40 | print("Taking a total time of:",time.time()-start) 41 | 42 | if __name__ == '__main__': 43 | #example_1() 44 | #example_2() 45 | example_3() 46 | -------------------------------------------------------------------------------- /oil/model_trainers/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from oil.utils.utils import Eval, cosLr, export 4 | from oil.model_trainers.trainer import Trainer 5 | 6 | @export 7 | class Classifier(Trainer): 8 | """ Trainer subclass. 
Implements cross-entropy loss
9 |     and an accuracy metric evaluated over a full dataloader """
10 | 
11 |     def loss(self, minibatch, model = None):
12 |         """ Standard cross-entropy loss """
13 |         x,y = minibatch
14 |         if model is None: model = self.model
15 |         try: class_weights = self.dataloaders['train'].dataset.class_weights
16 |         except AttributeError: class_weights=None
17 |         try: ignored_index = self.dataloaders['train'].dataset.ignored_index
18 |         except AttributeError: ignored_index=-100
19 |         criterion = nn.CrossEntropyLoss(weight=class_weights,ignore_index=ignored_index)
20 |         return criterion(model(x),y)
21 | 
22 |     def metrics(self,loader):
23 |         acc = lambda mb: self.model(mb[0]).max(1)[1].type_as(mb[1]).eq(mb[1]).cpu().data.numpy().mean()
24 |         return {'Acc':self.evalAverageMetrics(loader,acc)}
25 | 
26 | @export
27 | class Regressor(Trainer):
28 |     """ Trainer subclass. Implements MSE loss
29 |     and an MSE metric evaluated over a full dataloader """
30 | 
31 |     def loss(self, minibatch, model = None):
32 |         """ Standard MSE loss """
33 |         x,y = minibatch
34 |         if model is None: model = self.model
35 |         return nn.MSELoss()(model(x),y)
36 | 
37 |     def metrics(self,loader):
38 |         mse = lambda mb: nn.MSELoss()(self.model(mb[0]),mb[1]).cpu().data.numpy()
39 |         return {'MSE':self.evalAverageMetrics(loader,mse)}
40 | 
--------------------------------------------------------------------------------
/oil/model_trainers/graphssl.py:
--------------------------------------------------------------------------------
1 | 
2 | from abc import ABCMeta, abstractmethod
3 | import numpy as np
4 | from sklearn.base import BaseEstimator, ClassifierMixin
5 | from functools import partial
6 | from sklearn.metrics.pairwise import rbf_kernel
7 | 
8 | def sine_kernel(X1,X2,gamma):
9 |     X1n = X1/np.linalg.norm(X1,axis=1)[:,None]
10 |     X2n = X2/np.linalg.norm(X2,axis=1)[:,None]
11 |     return np.exp(-gamma*(1-X1n@X2n.T))
12 | 
13 | def oh(a, num_classes):
14 |     return np.squeeze(np.eye(num_classes)[a.reshape(-1)])
15 | 
16 | class GraphSSL(BaseEstimator,ClassifierMixin,metaclass=ABCMeta):
17 | 
18 |     def __init__(self,gamma=2,reg=1,kernel='sin'):
19 |         super().__init__()
20 |         if kernel=='sin': self.kernel = partial(sine_kernel,gamma=gamma)
21 |         elif kernel=='rbf': self.kernel = partial(rbf_kernel,gamma=gamma)
22 |         elif callable(kernel): self.kernel = kernel
23 |         else: raise NotImplementedError(f"Unknown kernel {kernel}")
24 |         self.reg=reg
25 | 
26 |     def fit(self,X,y):
27 |         """ Assumes y is -1 for unlabeled """
28 |         n,d = X.shape
29 |         c = max(y)+1
30 |         Wxx = self.kernel(X,X)
31 |         Wxx -= np.diag(np.diag(Wxx))
32 |         D = np.diag(np.sum(Wxx,axis=-1))
33 |         self.dm2 = dm2 = np.sum(Wxx,axis=-1)[:,None]**-.5
34 |         L = np.eye(n) - dm2*Wxx*dm2.T
35 |         Y = np.zeros((n,c))
36 |         Y[y!=-1] = oh(y[y!=-1],c)
37 |         self.X = X
38 |         # label spreading: solve (L + reg*I) Ys = Y with L the symmetric normalized Laplacian
39 |         self.Ys = np.linalg.solve(L+self.reg*np.eye(n),Y)
40 | 
41 |     def predict(self,X_test):
42 |         Wtx = self.kernel(X_test,self.X)
43 |         dm2t = np.sum(Wtx,axis=1)**-.5
44 |         Stx = dm2t[:,None]*Wtx*self.dm2.T
45 |         return (Stx@self.Ys).argmax(-1)
46 | 
--------------------------------------------------------------------------------
/oil/recipes/simpleTrial.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.optim import SGD
4 | from oil.utils.utils import LoaderTo, cosLr, islice, export
5 | from oil.tuning.study import train_trial
6 | from oil.datasetup.datasets import CIFAR10, split_dataset
7 | from oil.architectures.img_classifiers import 
layer13 8 | from oil.utils.parallel import try_multigpu_parallelize 9 | from oil.tuning.args import argupdated_config 10 | from oil.model_trainers.classifier import Classifier 11 | from functools import partial 12 | 13 | def makeTrainer(*,dataset=CIFAR10,network=layer13,num_epochs=100, 14 | bs=50,lr=.1,aug=True,optim=SGD,device='cuda',trainer=Classifier, 15 | split={'train':-1,'val':.1},net_config={},opt_config={}, 16 | trainer_config={'log_dir':None},save=False): 17 | 18 | # Prep the datasets splits, model, and dataloaders 19 | datasets = split_dataset(dataset(f'~/datasets/{dataset}/'),splits=split) 20 | datasets['test'] = dataset(f'~/datasets/{dataset}/', train=False) 21 | 22 | device = torch.device(device) 23 | model = network(num_targets=datasets['train'].num_targets,**net_config).to(device) 24 | if aug: model = torch.nn.Sequential(datasets['train'].default_aug_layers(),model) 25 | model,bs = try_multigpu_parallelize(model,bs) 26 | 27 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=bs,shuffle=(k=='train'), 28 | num_workers=0,pin_memory=False),device) for k,v in datasets.items()} 29 | dataloaders['Train'] = islice(dataloaders['train'],1+len(dataloaders['train'])//10) 30 | # Add some extra defaults if SGD is chosen 31 | if optim==SGD: opt_config={**{'momentum':.9,'weight_decay':1e-4,'nesterov':True},**opt_config} 32 | opt_constr = partial(optim, lr=lr, **opt_config) 33 | lr_sched = cosLr(num_epochs) 34 | return trainer(model,dataloaders,opt_constr,lr_sched,**trainer_config) 35 | 36 | simpleTrial = train_trial(makeTrainer) 37 | if __name__=='__main__': 38 | simpleTrial(argupdated_config(makeTrainer.__kwdefaults__)) 39 | -------------------------------------------------------------------------------- /oil/recipes/trainGan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.optim import Adam 4 | from oil.model_trainers.gan import Gan 5 | from oil.architectures.img_gen import resnetgan 6 | from oil.utils.utils import LoaderTo, cosLr, islice, dmap,export 7 | from oil.tuning.study import train_trial 8 | from oil.datasetup.datasets import CIFAR10 9 | from oil.utils.parallel import try_multigpu_parallelize 10 | from oil.tuning.args import argupdated_config 11 | from functools import partial 12 | from torchvision import transforms 13 | 14 | def makeTrainer(*,gen=resnetgan.Generator,disc=resnetgan.Discriminator, 15 | num_epochs=500,dataset=CIFAR10,bs=64,lr=2e-4, 16 | device='cuda',net_config={},opt_config={'betas':(.5,.999)}, 17 | trainer_config={'n_disc':5,'log_dir':None},save=False): 18 | 19 | transform = transforms.Compose([transforms.RandomHorizontalFlip(), 20 | transforms.ToTensor(), 21 | transforms.Normalize((.5,.5,.5),(.5,.5,.5))]) 22 | # Prep the datasets splits, model, and dataloaders 23 | datasets = {} 24 | datasets['train'] = dmap(lambda mb: mb[0],dataset(f'~/datasets/{dataset}/',transform=transform)) 25 | datasets['test'] = dmap(lambda mb: mb[0],dataset(f'~/datasets/{dataset}/',train=False,transform=transform)) 26 | 27 | device = torch.device(device) 28 | G = gen(**net_config).to(device) 29 | D = disc(**net_config).to(device) 30 | G,_ = try_multigpu_parallelize(G,bs) 31 | D,bs = try_multigpu_parallelize(D,bs) 32 | 33 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=bs,shuffle=(k=='train'), 34 | num_workers=0,pin_memory=False),device) for k,v in datasets.items()} 35 | opt_constr = partial(Adam, lr=lr, **opt_config) 36 | lr_sched = lambda e:1 37 | return 
Gan(G,dataloaders,opt_constr,lr_sched,D=D,**trainer_config)
38 | 
39 | GanTrial = train_trial(makeTrainer)
40 | if __name__=='__main__':
41 |     GanTrial(argupdated_config(makeTrainer.__kwdefaults__))
--------------------------------------------------------------------------------
/oil/model_trainers/cycleGan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from ..utils.utils import Eval, join_opts, stateful_zip
4 | from .gan import Gan, hinge_loss_D, hinge_loss_G, cGan
5 | from torch.nn.functional import l1_loss as L1
6 | 
7 | class CycleGan(Gan):
8 |     def __init__(self,*args,cycle_strength=10,**kwargs):
9 |         super().__init__(*args,**kwargs)
10 |         self.gan1 = cGan(*args,**kwargs)
11 |         self.gan2 = cGan(*args,**kwargs)
12 |         # The join needs to be stateful so that gopt and dopt loads/saves work
13 |         # or we just add the state in the state dict explicitly
14 |         self.g_optimizer = join_opts(self.gan1.g_optimizer,self.gan2.g_optimizer)
15 |         self.d_optimizer = join_opts(self.gan1.d_optimizer,self.gan2.d_optimizer)
16 |         self.dataloaders['train'] = stateful_zip(self.dataloaders['A'],self.dataloaders['B'])
17 | 
18 | 
19 |     def discLoss(self, data):
20 |         return self.gan1.discLoss(data) + self.gan2.discLoss(data[::-1])
21 | 
22 |     def genLoss(self,data):
23 |         """ Adversarial and cycle loss"""
24 |         xa,xb = data
25 |         adversarial_loss = self.gan1.genLoss(data[::-1]) + self.gan2.genLoss(data)
26 |         G1,G2 = self.gan1.G, self.gan2.G
27 |         cycle_loss = L1(G2(G1(xa)),xa) + L1(G1(G2(xb)),xb)
28 |         return adversarial_loss + self.hypers['cycle_strength']*cycle_loss
29 | 
30 |     def logStuff(self, i, minibatch=None):
31 |         raise NotImplementedError
32 | 
33 |     def state_dict(self):
34 |         extra_state = {
35 |             'gan1_state':self.gan1.state_dict(),
36 |             'gan2_state':self.gan2.state_dict(),
37 |             'AB_loader_state':self.dataloaders['train'].state_dict(),
38 |         }
39 |         return {**super().state_dict(),**extra_state}
40 | 
41 |     def load_state_dict(self,state):
42 |         super().load_state_dict(state)
43 |         self.gan1.load_state_dict(state['gan1_state'])
44 |         self.gan2.load_state_dict(state['gan2_state'])
45 |         self.dataloaders['train'].load_state_dict(state['AB_loader_state'])
--------------------------------------------------------------------------------
/oil/recipes/trainPi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.optim import SGD
4 | from oil.utils.utils import LoaderTo, cosLr, islice, dmap
5 | from oil.tuning.study import train_trial
6 | from oil.datasetup.datasets import CIFAR10, split_dataset
7 | from oil.architectures.img_classifiers import layer13
8 | from oil.utils.parallel import try_multigpu_parallelize
9 | from oil.tuning.args import argupdated_config
10 | from oil.model_trainers.piModel import PiModel
11 | from functools import partial
12 | 
13 | def makeTrainer(*,dataset=CIFAR10,network=layer13,num_epochs=200,
14 |                 bs=50,lr=.1,optim=SGD,device='cuda',
15 |                 split={'train':4000,'val':.1},net_config={},opt_config={},
16 |                 trainer_config={'cons_weight':.3,'log_dir':None},save=False):
17 | 
18 |     # Prep the datasets splits, model, and dataloaders
19 |     datasets = split_dataset(dataset(f'~/datasets/{dataset}/'),splits=split)
20 |     datasets['_unlab'] = dmap(lambda mb: mb[0],dataset(f'~/datasets/{dataset}/'))
21 |     datasets['test'] = dataset(f'~/datasets/{dataset}/', train=False)
22 | 
23 |     device = torch.device(device)
24 |     net = 
network(num_targets=datasets['train'].num_targets,**net_config) 25 | model = torch.nn.Sequential(datasets['train'].default_aug_layers(),net).to(device) 26 | model,bs = try_multigpu_parallelize(model,bs) 27 | 28 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=bs,shuffle=(k=='train'), 29 | num_workers=0,pin_memory=False),device) for k,v in datasets.items()} 30 | dataloaders['Train'] = islice(dataloaders['train'],1+len(dataloaders['train'])//10) 31 | # Add some extra defaults if SGD is chosen 32 | if optim==SGD: opt_config={**{'momentum':.9,'weight_decay':1e-4,'nesterov':True},**opt_config} 33 | opt_constr = partial(optim, lr=lr, **opt_config) 34 | lr_sched = cosLr(num_epochs) 35 | return PiModel(model,dataloaders,opt_constr,lr_sched,**trainer_config) 36 | 37 | piTrial = train_trial(makeTrainer) 38 | 39 | if __name__=='__main__': 40 | piTrial(argupdated_config(makeTrainer.__kwdefaults__)) 41 | -------------------------------------------------------------------------------- /oil/architectures/img_gen/ganBase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils import spectral_norm 4 | import numpy as np 5 | from ...utils.utils import Expression,export,Named 6 | 7 | @export 8 | class GanBase(nn.Module,metaclass=Named): 9 | 10 | def __init__(self,z_dim,img_channels,num_classes=None): 11 | self.z_dim = z_dim 12 | self.img_channels = img_channels 13 | super().__init__() 14 | 15 | @property 16 | def device(self): 17 | try: return self._device 18 | except AttributeError: 19 | self._device = next(self.parameters()).device 20 | return self._device 21 | 22 | def sample_z(self, n=1): 23 | return torch.randn(n, self.z_dim).to(self.device) 24 | 25 | def sample(self, n=1): 26 | return self(self.sample_z(n)) 27 | 28 | 29 | def add_spectral_norm(module): 30 | if isinstance(module, (nn.ConvTranspose1d, 31 | nn.ConvTranspose2d, 32 | nn.ConvTranspose3d, 33 | )): 34 | spectral_norm(module,dim = 1) 35 | #print("SN on conv layer: ",module) 36 | elif isinstance(module, (nn.Linear, 37 | nn.Conv1d, 38 | nn.Conv2d, 39 | nn.Conv3d)): 40 | spectral_norm(module,dim = 0) 41 | #print("SN on linear layer: ",module) 42 | 43 | def xavier_uniform_init(module): 44 | if isinstance(module, (nn.ConvTranspose1d, 45 | nn.ConvTranspose2d, 46 | nn.ConvTranspose3d, 47 | nn.Conv1d, 48 | nn.Conv2d, 49 | nn.Conv3d)): 50 | if module.kernel_size==(1,1): 51 | nn.init.xavier_uniform_(module.weight.data,np.sqrt(2)) 52 | else: 53 | nn.init.xavier_uniform_(module.weight.data,1) 54 | #print("Xavier init on conv layer: ",module) 55 | elif isinstance(module, nn.Linear): 56 | nn.init.xavier_uniform_(module.weight.data,1) 57 | #print("Xavier init on linear layer: ",module) -------------------------------------------------------------------------------- /oil/recipes/trainCGan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.optim import Adam 4 | from oil.model_trainers.cGan import cGan 5 | from oil.architectures.img_gen import conditionalgan as cgan 6 | from oil.utils.utils import LoaderTo, cosLr, islice, dmap,export 7 | from oil.tuning.study import train_trial 8 | from oil.datasetup.datasets import CIFAR10 9 | from oil.datasetup.augLayers import RandomHorizontalFlip 10 | from oil.utils.parallel import try_multigpu_parallelize 11 | from oil.tuning.args import argupdated_config 12 | from functools import partial 13 | from torchvision import 
transforms 14 | 15 | 16 | def makeTrainer(*,gen=cgan.Generator,disc=cgan.Discriminator, 17 | num_epochs=500,dataset=CIFAR10,bs=64,lr=2e-4, 18 | device='cuda',net_config={},opt_config={'betas':(.5,.999)}, 19 | trainer_config={'n_disc':5,'log_dir':None},save=False): 20 | 21 | transform = transforms.Compose([transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | transforms.Normalize((.5,.5,.5),(.5,.5,.5))]) 24 | # Prep the datasets splits, model, and dataloaders 25 | datasets = {} 26 | datasets['train'] = dataset(f'~/datasets/{dataset}/',transform =transform) 27 | datasets['test'] = dmap(lambda mb: mb[0],dataset(f'~/datasets/{dataset}/',train=False,transform=transform)) 28 | 29 | device = torch.device(device) 30 | G = gen(num_classes=datasets['train'].num_targets).to(device) 31 | D = disc(num_classes=datasets['train'].num_targets).to(device) 32 | G,_ = try_multigpu_parallelize(G,bs) 33 | D,bs = try_multigpu_parallelize(D,bs) 34 | 35 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=bs,shuffle=(k=='train'), 36 | num_workers=0,pin_memory=False),device) for k,v in datasets.items()} 37 | opt_constr = partial(Adam, lr=lr, **opt_config) 38 | lr_sched = lambda e:1 39 | return cGan(G,dataloaders,opt_constr,lr_sched,D=D,**trainer_config) 40 | 41 | cGanTrial = train_trial(makeTrainer) 42 | if __name__=='__main__': 43 | cGanTrial(argupdated_config(makeTrainer.__kwdefaults__)) -------------------------------------------------------------------------------- /oil/utils/losses.py: -------------------------------------------------------------------------------- 1 | """Custom loss functions""" 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Variable 6 | 7 | 8 | def softmax_mse_loss(input_logits, target_logits): 9 | """Takes softmax on both sides and returns MSE loss 10 | Note: 11 | - Returns the sum over all examples. Divide by the batch size afterwards 12 | if you want the mean. (not true anymore) 13 | - Sends gradients to inputs but not the targets. 14 | """ 15 | assert input_logits.size() == target_logits.size() 16 | input_softmax = F.softmax(input_logits, dim=1) 17 | target_softmax = F.softmax(target_logits, dim=1) 18 | num_classes = input_logits.size()[1] 19 | return F.mse_loss(input_softmax, target_softmax) / num_classes 20 | 21 | 22 | def softmax_kl_loss(input_logits, target_logits): 23 | """Takes softmax on both sides and returns KL divergence 24 | Note: 25 | - Returns the sum over all examples. Divide by the batch size afterwards 26 | if you want the mean. 27 | - Sends gradients to inputs but not the targets. 28 | """ 29 | assert input_logits.size() == target_logits.size() 30 | input_log_softmax = F.log_softmax(input_logits, dim=1) 31 | target_softmax = F.softmax(target_logits, dim=1) 32 | return F.kl_div(input_log_softmax, target_softmax, size_average=False) 33 | 34 | 35 | def symmetric_mse_loss(input1, input2): 36 | """Like F.mse_loss but sends gradients to both directions 37 | Note: 38 | - Returns the sum over all examples. Divide by the batch size afterwards 39 | if you want the mean. 40 | - Sends gradients to both input1 and input2. 41 | """ 42 | assert input1.size() == input2.size() 43 | num_classes = input1.size()[1] 44 | return torch.sum((input1 - input2)**2) / num_classes 45 | 46 | def softmax_mse_loss_both(input1, input2): 47 | """Takes softmax on both sides and returns MSE loss 48 | Note: 49 | - Returns the sum over all examples. Divide by the batch size afterwards 50 | if you want the mean. 
51 | - Sends gradients to both input1 and input2. 52 | """ 53 | assert input1.size() == input2.size() 54 | input1_sm = F.softmax(input1, dim=1) 55 | input2_sm = F.softmax(input2, dim=1) 56 | num_classes = input1.size()[1] 57 | return torch.sum((input1_sm - input2_sm)**2) / num_classes -------------------------------------------------------------------------------- /oil/architectures/img_classifiers/vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | VGG model definition 3 | ported from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 4 | """ 5 | 6 | import math 7 | import torch.nn as nn 8 | import torchvision.transforms as transforms 9 | from ...utils.utils import Named 10 | 11 | __all__ = ['VGG16', 'VGG16BN'] 12 | 13 | 14 | def make_layers(cfg, batch_norm=False): 15 | layers = list() 16 | in_channels = 3 17 | for v in cfg: 18 | if v == 'M': 19 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 20 | else: 21 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 22 | if batch_norm: 23 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 24 | else: 25 | layers += [conv2d, nn.ReLU(inplace=True)] 26 | in_channels = v 27 | return nn.Sequential(*layers) 28 | 29 | 30 | cfg = { 31 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 32 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 33 | 512, 512, 512, 512, 'M'], 34 | } 35 | 36 | 37 | class VGG(nn.Module,metaclass=Named): 38 | def __init__(self, num_targets=10, depth=16, batch_norm=False): 39 | super(VGG, self).__init__() 40 | self.features = make_layers(cfg[depth], batch_norm) 41 | self.classifier = nn.Sequential( 42 | nn.Dropout(), 43 | nn.Linear(512, 512), 44 | nn.ReLU(True), 45 | nn.Dropout(), 46 | nn.Linear(512, 512), 47 | nn.ReLU(True), 48 | nn.Linear(512, num_targets), 49 | ) 50 | 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 54 | m.weight.data.normal_(0, math.sqrt(2. / n)) 55 | m.bias.data.zero_() 56 | 57 | def forward(self, x): 58 | x = self.features(x) 59 | x = x.view(x.size(0), -1) 60 | x = self.classifier(x) 61 | return x 62 | 63 | class VGG16(VGG): 64 | def __init__(self,num_targets=10): 65 | super().__init__(num_targets=num_targets,depth=16,batch_norm=False) 66 | class VGG16BN(VGG): 67 | def __init__(self,num_targets=10): 68 | super().__init__(num_targets=num_targets,depth=16,batch_norm=True) -------------------------------------------------------------------------------- /oil/tuning/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import numpy as np 4 | from .configGenerator import flatten, unflatten 5 | #from oil.datasetup import * 6 | #from oil.architectures import * 7 | #from torch.optim import * 8 | 9 | def argupdated_config(cfg,parser=None, namespace=None): 10 | """ Uses the cfg to generate a parser spec which parses the command line arguments 11 | and outputs the updated config. 
An existing argparser can be specified.""" 12 | # TODO: throw error for clobbered names 13 | flat_cfg = flatten(cfg) 14 | if parser is None: 15 | fmt = lambda prog: argparse.HelpFormatter(prog, max_help_position=80) 16 | parser = argparse.ArgumentParser(formatter_class=fmt) 17 | #formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | clobbered_name_mapping = {} 19 | for field, value in flat_cfg.items(): 20 | fields = field.split('/') 21 | short_field_name = fields[-1] 22 | parser.add_argument('--'+short_field_name,default=value,help="(default: %(default)s)") 23 | clobbered_name_mapping[short_field_name] = field 24 | if len(fields)>1 and fields[0] not in clobbered_name_mapping: 25 | parser.add_argument('--'+fields[0],default={},help="Additional Kwargs") 26 | clobbered_name_mapping[fields[0]] = fields[0] 27 | 28 | #parser.add_argument("--local_rank",type=int) # so that distributed will work #TODO: sort this out properly 29 | args = parser.parse_args() 30 | add_to_namespace(namespace) 31 | for short_argname, argvalue in vars(args).items(): 32 | argvalue = tryeval(argvalue) 33 | if short_argname in clobbered_name_mapping: # There may be additional args from argparse 34 | flat_cfg[clobbered_name_mapping[short_argname]] = argvalue 35 | else: 36 | flat_cfg[short_argname] = argvalue # Never actually called? 37 | 38 | extra_flat_cfg = flatten(flat_cfg) 39 | updated_full_cfg = unflatten(extra_flat_cfg) # Flatten again 40 | return updated_full_cfg 41 | 42 | def add_to_namespace(namespace): 43 | if namespace is not None: 44 | if not isinstance(namespace,tuple): 45 | namespace = (namespace,) 46 | for ns in namespace: 47 | globals().update({k: getattr(ns, k) for k in ns.__all__}) 48 | 49 | def tryeval(value): 50 | if isinstance(value,dict): 51 | return {k:tryeval(v) for k,v in value.items()} 52 | elif isinstance(value,str): 53 | try: 54 | return eval(value) # Try to evaluate the strings 55 | except (NameError, SyntaxError): 56 | return value # Interpret just as string 57 | else: 58 | return value 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Code Climate maintainability](https://img.shields.io/codeclimate/maintainability-percentage/mfinzi/pristine-ml)](https://codeclimate.com/github/mfinzi/pristine-ml) 2 | [![CodeClimate](http://img.shields.io/codeclimate/mfinzi/pristine-ml.svg?style=flat)](https://codeclimate.com/github/mfinzi/pristine-ml 3 | "CodeClimate") 4 | 5 | # Olive-Oil-ML 6 | 7 | Nuts and bolts deep learning library to make training neural networks easier. 8 | Features: 9 | * Logging functionality not at set time intervals but as a percentage of the total training time 10 | * Convenient specification for Random and Grid hyperparameter search: queued with a single GPU, split over multiple local GPUs, or over a Slurm Cluster 11 | * Clean implementation of popular methods/problems in CV such as Vanilla Image Classification, Regression, the PiModel for Semi-Supervised learning, and SN-GAN; all through the common Trainer abstraction 12 | 13 | # Installation 14 | To install, run `pip install git+https://github.com/mfinzi/olive-oil-ml`. Dependencies will be checked and installed from the setup.py file. 
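15 | For development, an editable install from a clone should work as well (a sketch, assuming a standard setuptools environment; it runs the same setup.py):
16 | * `git clone https://github.com/mfinzi/olive-oil-ml && cd olive-oil-ml && pip install -e .`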
17 | 
18 | # Dependencies
19 | * Python 3.7+
20 | * [PyTorch](http://pytorch.org/) 1.3.0+
21 | * [torchvision](https://github.com/pytorch/vision/)
22 | * [tqdm](https://tqdm.github.io/) 4.40+
23 | * [natsort](https://github.com/SethMMorton/natsort)
24 | * (optional) [tensorboardX](https://github.com/lanpa/tensorboardX)
25 | 
26 | # Jump into training a single model
27 | 
28 | To get a feel for the library, try training a (Classifier, Regressor, PiModel, GAN) model from our recipes.
29 | For classification, try running
30 | * `python oil/recipes/simpleTrial.py --dataset CIFAR100 --num_epochs 10`
31 | 
32 | Or, to train a conditional GAN model:
33 | * `python oil/recipes/trainCGan.py --dataset SVHN --lr 2e-4`
34 | 
35 | Or train a PiModel semi-supervised on CIFAR10 using only 1k labels:
36 | * `python oil/recipes/trainPi.py --dataset CIFAR10 --train 1000`
37 | 
38 | 
39 | You can use `-h` to see the full range of arguments available. Command line arguments and defaults are automatically inferred
40 | from the code used to construct the trial, so you can make a new trial (one that uses some exotic data augmentation strategy, for example) and the command line parser will be generated for you; see the example recipes for how this works.
41 | 
42 | # Perform a hyperparameter search
43 | Example: search over hyperparameters for a CNN classifier on CIFAR100
44 | * `python oil/recipes/exampleHyperSearch.py --dataset CIFAR100 --bs [50,32,64] --k [64,96] --num_epochs 100`
45 | See the example code for a programmatic way of specifying the hyperparameter search.
46 | The search is automatically parallelized over multiple GPUs if available.
47 | # Logging Support
48 | 
49 | # Interfacing with external libraries
50 | 
--------------------------------------------------------------------------------
/oil/architectures/parts/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torch.nn import Parameter
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | import numpy as np
7 | from torch.nn.utils import weight_norm
8 | from ...utils.utils import Expression,export,Named
9 | from .CoordConv import CoordConv
10 | 
11 | @export
12 | def conv2d(in_channels,out_channels,kernel_size=3,coords=False,dilation=1,**kwargs):
13 |     """ Wraps nn.Conv2d and CoordConv, padding is set to same
14 |     and coords=True can be specified to get additional coordinate in_channels"""
15 |     assert 'padding' not in kwargs, "assumed to be padding = same "
16 |     same = (kernel_size//2)*dilation
17 |     if coords:
18 |         return CoordConv(in_channels,out_channels,kernel_size,padding=same,dilation=dilation,**kwargs)
19 |     else:
20 |         return nn.Conv2d(in_channels,out_channels,kernel_size,padding=same,dilation=dilation,**kwargs)
21 | @export
22 | class ResBlock(nn.Module):
23 |     def __init__(self,in_channels,out_channels,ksize=3,drop_rate=0,stride=1,gn=False,**kwargs):
24 |         super().__init__()
25 |         norm_layer = (lambda c: nn.GroupNorm(c//16,c)) if gn else nn.BatchNorm2d
26 |         self.net = nn.Sequential(
27 |             norm_layer(in_channels),
28 |             nn.ReLU(),
29 |             conv2d(in_channels,out_channels,ksize,**kwargs),
30 |             norm_layer(out_channels),
31 |             nn.ReLU(),
32 |             conv2d(out_channels,out_channels,ksize,stride=stride,**kwargs),
33 |             nn.Dropout(p=drop_rate)
34 |         )
35 |         if in_channels != out_channels:
36 |             self.shortcut = conv2d(in_channels,out_channels,1,stride=stride,**kwargs)
37 |         elif stride!=1:
38 |             self.shortcut = Expression(lambda x: F.interpolate(x,scale_factor=1/stride))
39 |         else:
40 |             self.shortcut = 
nn.Sequential() 41 | 42 | def forward(self,x): 43 | return self.shortcut(x) + self.net(x) 44 | 45 | @export 46 | def ConvBNrelu(in_channels,out_channels,**kwargs): 47 | return nn.Sequential( 48 | conv2d(in_channels,out_channels,**kwargs), 49 | nn.BatchNorm2d(out_channels), 50 | nn.ReLU() 51 | ) 52 | @export 53 | def FcBNrelu(in_channels,out_channels): 54 | return nn.Sequential( 55 | nn.Linear(in_channels,out_channels), 56 | nn.BatchNorm1d(out_channels), 57 | nn.ReLU() 58 | ) 59 | 60 | @export 61 | class DenseLayer(nn.Module): 62 | def __init__(self, inplanes, k=12, drop_rate=0,coords=True): 63 | super().__init__() 64 | self.net = nn.Sequential( 65 | ConvBNrelu(inplanes,4*k,kernel_size=1,coords=coords), 66 | ConvBNrelu(4*k,k,kernel_size=3,coords=coords), 67 | nn.Dropout(p=drop_rate), 68 | ) 69 | def forward(self, x): 70 | return torch.cat((x, self.net(x)), 1) 71 | @export 72 | class DenseBlock(nn.Module): 73 | def __init__(self, inplanes,k=16,N=20,drop_rate=0,coords=True): 74 | super().__init__() 75 | layers = [] 76 | for i in range(N): 77 | layers.append(DenseLayer(inplanes,k,drop_rate,coords)) 78 | inplanes += k 79 | layers.append(ConvBNrelu(inplanes,inplanes//2)) 80 | self.net = nn.Sequential(*layers) 81 | 82 | def forward(self,x): 83 | return self.net(x) -------------------------------------------------------------------------------- /oil/model_trainers/cGan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import l1_loss as L1 4 | from ..utils.utils import Eval 5 | from .gan import Gan, hinge_loss_D, hinge_loss_G 6 | 7 | class cGan(Gan): 8 | def __init__(self,*args,**kwargs): 9 | super().__init__(*args,**kwargs) 10 | self.fixed_input = (self.G.sample_y(32),self.G.sample_z(32)) 11 | 12 | def discLoss(self, data): 13 | """ Hinge loss for discriminator""" 14 | x,y = data 15 | fake_logits = self.D(self.G(y),y) 16 | real_logits = self.D(x,y) 17 | return hinge_loss_D(real_logits,fake_logits) 18 | 19 | def genLoss(self,data): 20 | """ Hinge based generator loss -E[D(G(z))] """ 21 | x,y = data 22 | fake_logits = self.D(self.G(y),y) 23 | return hinge_loss_G(fake_logits) 24 | 25 | 26 | class Pix2Pix(cGan): 27 | def __init__(self,*args,l1=10,**kwargs): 28 | super().__init__(*args,**kwargs) 29 | self.hypers['l1'] = l1 30 | 31 | def genLoss(self,data): 32 | y,x = data # Here y is the output image and x is input 33 | fake_y = self.G(y) 34 | adversarial = hinge_loss_G(self.D(fake_y,x)) 35 | return adversarial + self.hypers['l1']*L1(fake_y,y) 36 | 37 | def logStuff(self, i, minibatch=None): 38 | raise NotImplementedError 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | from torch.utils.data import DataLoader 52 | from oil.utils.utils import LoaderTo, cosLr, recursively_update,imap,islice 53 | from oil.tuning.study import train_trial 54 | from oil.datasetup.dataloaders import getLabLoader 55 | from oil.datasetup.datasets import CIFAR10 56 | from oil.architectures.img_classifiers import layer13s 57 | from oil.architectures.img_gen import conditionalgan 58 | 59 | def simpleGanTrial(strict=False): 60 | def makeTrainer(config): 61 | cfg = { 62 | 'loader_config':{'amnt_dev':0,'lab_BS':64,'dataseed':0,'num_workers':1}, 63 | 'gen':conditionalgan.Generator,'disc':conditionalgan.Discriminator, 64 | 'trainer':cGan, 65 | 'trainer_config':{'n_disc':2}, 66 | 'opt_config':{'lr':2e-4,'betas':(.5,.999)}, 67 | 'num_epochs':400, 68 | } 69 | recursively_update(cfg,config) 70 | trainset = 
cfg['dataset']('~/datasets/{}/'.format(cfg['dataset']),gan_normalize=True) 71 | device = torch.device('cuda') 72 | G = cfg['gen'](num_classes=10).to(device) 73 | D = cfg['disc'](num_classes=10).to(device) 74 | 75 | dataloaders = {} 76 | dataloaders['train'], _ = getLabLoader(trainset,**cfg['loader_config']) 77 | imgs_only = imap(lambda z: z[0], dataloaders['train']) 78 | dataloaders['dev'] = islice(imgs_only,5000//cfg['loader_config']['lab_BS']) 79 | dataloaders = {k: LoaderTo(v,device) for k,v in dataloaders.items()} 80 | opt_constr = lambda params: torch.optim.Adam(params,**cfg['opt_config']) 81 | lr_sched = cosLr(cfg['num_epochs']) 82 | trainer = cfg['trainer'](G,dataloaders,opt_constr=opt_constr, 83 | lr_sched=lr_sched,D=D,**cfg['trainer_config']) 84 | return trainer 85 | return train_trial(makeTrainer,strict) 86 | 87 | if __name__=='__main__': 88 | Trial = simpleGanTrial(strict=True) 89 | Trial({'num_epochs':4*(100,)}) -------------------------------------------------------------------------------- /oil/architectures/parts/denseblocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ...utils.utils import Expression,export,Named 4 | # Densenet blocks ported from 5 | # https://github.com/bfortuner/pytorch_tiramisu/blob/master/models/layers.py 6 | 7 | class DenseLayer(nn.Sequential): 8 | def __init__(self, in_channels, growth_rate): 9 | super().__init__() 10 | self.add_module('norm', nn.BatchNorm2d(in_channels)) 11 | self.add_module('relu', nn.ReLU(True)) 12 | self.add_module('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3, 13 | stride=1, padding=1, bias=True)) 14 | self.add_module('drop', nn.Dropout2d(0.2)) 15 | 16 | def forward(self, x): 17 | return super().forward(x) 18 | 19 | @export 20 | class DenseBlock(nn.Module): 21 | def __init__(self, in_channels, growth_rate, n_layers, upsample=False): 22 | super().__init__() 23 | self.upsample = upsample 24 | self.layers = nn.ModuleList([DenseLayer( 25 | in_channels + i*growth_rate, growth_rate) 26 | for i in range(n_layers)]) 27 | 28 | def forward(self, x): 29 | if self.upsample: 30 | new_features = [] 31 | #we pass all previous activations into each dense layer normally 32 | #But we only store each dense layer's output in the new_features array 33 | for layer in self.layers: 34 | out = layer(x) 35 | x = torch.cat([x, out], 1) 36 | new_features.append(out) 37 | return torch.cat(new_features,1) 38 | else: 39 | for layer in self.layers: 40 | out = layer(x) 41 | x = torch.cat([x, out], 1) # 1 = channel axis 42 | return x 43 | 44 | 45 | 46 | class TransitionDown(nn.Sequential): 47 | def __init__(self, in_channels): 48 | super().__init__() 49 | self.add_module('norm', nn.BatchNorm2d(num_features=in_channels)) 50 | self.add_module('relu', nn.ReLU(inplace=True)) 51 | self.add_module('conv', nn.Conv2d(in_channels, in_channels, 52 | kernel_size=1, stride=1, 53 | padding=0, bias=True)) 54 | self.add_module('drop', nn.Dropout2d(0.2)) 55 | self.add_module('maxpool', nn.MaxPool2d(2)) 56 | 57 | def forward(self, x): 58 | return super().forward(x) 59 | 60 | 61 | class TransitionUp(nn.Module): 62 | def __init__(self, in_channels, out_channels): 63 | super().__init__() 64 | self.convTrans = nn.ConvTranspose2d( 65 | in_channels=in_channels, out_channels=out_channels, 66 | kernel_size=3, stride=2, padding=0, bias=True) 67 | 68 | def forward(self, x, skip): 69 | out = self.convTrans(x) 70 | out = center_crop(out, skip.size(2), skip.size(3)) 71 | out = 
torch.cat([out, skip], 1) 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Sequential): 76 | def __init__(self, in_channels, growth_rate, n_layers): 77 | super().__init__() 78 | self.add_module('bottleneck', DenseBlock( 79 | in_channels, growth_rate, n_layers, upsample=True)) 80 | 81 | def forward(self, x): 82 | return super().forward(x) 83 | 84 | 85 | def center_crop(layer, max_height, max_width): 86 | _, _, h, w = layer.size() 87 | xy1 = (w - max_width) // 2 88 | xy2 = (h - max_height) // 2 89 | return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)] -------------------------------------------------------------------------------- /oil/model_trainers/vat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.autograd import Variable, grad 6 | import torch.nn.functional as F 7 | import torchvision.utils as vutils 8 | import itertools 9 | import torch 10 | import numpy as np 11 | import torch.nn as nn 12 | 13 | from .piModel import PiModel 14 | from ..utils.losses import softmax_mse_loss, softmax_mse_loss_both 15 | from ..utils.utils import Eval, izip, icycle 16 | 17 | def _l2_normalize(d): 18 | d /= (1e-12 + torch.max(torch.abs(d))) 19 | d /= norm(d) 20 | return d 21 | 22 | def norm(d,keepdim=True): 23 | if len(d.shape)==4: 24 | norm = torch.sqrt(1e-6+(d**2).sum(3).sum(2).sum(1)) 25 | return norm[:,None,None,None] if keepdim else norm 26 | elif len(d.shape)==2: 27 | norm = torch.sqrt(1e-6+(d**2).sum(1)) 28 | return norm[:,None] if keepdim else norm 29 | else: 30 | assert False, "only supports 0d and 2d now" 31 | 32 | def kl_div_withlogits(p_logits, q_logits): 33 | kl_div = nn.KLDivLoss(size_average=True).cuda() 34 | LSM = nn.LogSoftmax(dim=1) 35 | SM = nn.Softmax(dim=1) 36 | return kl_div(LSM(q_logits), SM(p_logits)) 37 | 38 | def cross_ent_withlogits(p_logits,q_logits): 39 | LSM = nn.LogSoftmax(dim=1).cuda() 40 | SM = nn.Softmax(dim=1) 41 | return -1*(SM(p_logits)*LSM(q_logits)).sum(dim=1).mean(dim=0) 42 | 43 | class Vat(PiModel): 44 | def __init__(self, *args, cons_weight=.3, advEps=32, entMin=True, **kwargs): 45 | super().__init__(*args,**kwargs) 46 | self.hypers.update({'cons_weight':cons_weight,'advEps':advEps, 'entMin':entMin}) 47 | 48 | def unlabLoss(self, x_unlab): 49 | """ Calculates LDS loss according to https://arxiv.org/abs/1704.03976 """ 50 | wasTraining = self.model.training; self.model.train(False) 51 | 52 | r_adv = self.hypers['advEps'] * self.getAdvPert(self.model, x_unlab) 53 | perturbed_logits = self.model(x_unlab + r_adv) 54 | logits = self.model(x_unlab).detach() 55 | unlabLoss = kl_div_withlogits(logits, perturbed_logits)/(self.hypers['advEps'])**2 56 | self.model.train(wasTraining) 57 | return unlabLoss 58 | 59 | @staticmethod 60 | def getAdvPert(model, X, powerIts=1, xi=1e-6): 61 | wasTraining = model.training; model.train(False) 62 | 63 | ddata = torch.randn(X.size()).to(X.device) 64 | # calc adversarial direction 65 | d = Variable(xi*_l2_normalize(ddata), requires_grad=True) 66 | logit_p = model(X).detach() 67 | perturbed_logits = model(X + d) 68 | adv_distance = kl_div_withlogits(logit_p, perturbed_logits) 69 | d_grad = torch.autograd.grad(adv_distance,d)[0] 70 | #model.zero_grad() 71 | #ddata = d.grad.data 72 | #print("2 max = %.4E, min = %.4E"%(torch.max(norm(ddata)),torch.min(norm(ddata)))) 73 | #model.zero_grad() 74 | #print("3 max = %.4E, min = %.4E"%(torch.max(norm(ddata)),torch.min(norm(ddata)))) 75 | 
model.train(wasTraining) 76 | return _l2_normalize(d_grad)#.detach() 77 | 78 | # def logStuff(self, step, minibatch=None): 79 | # if minibatch: 80 | # x_unlab = trainData[1]; someX = x_unlab[:16] 81 | # r_adv = self.hypers['advEps'] * self.getAdvPert(self.model, someX) 82 | # adversarialImages = (someX + r_adv).cpu().data 83 | # imgComparison = torch.cat((adversarialImages, someX.cpu().data)) 84 | # self.writer.add_image('adversarialInputs', 85 | # vutils.make_grid(imgComparison,normalize=True,range=(-2.5,2.5)), step) 86 | # self.metricLog.update({'Unlab_loss(batch)': 87 | # self.unlabLoss(x_unlab).cpu().data[0]}) 88 | # super().logStuff(step, minibatch) -------------------------------------------------------------------------------- /oil/architectures/img_classifiers/smallconv.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.nn import Parameter 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | import numpy as np 8 | from torch.nn.utils import weight_norm 9 | from ...utils.utils import Expression,export,Named 10 | from ..parts import ResBlock, DeConv2d, FastDeconv, MaxBlurPool,BlurPool 11 | 12 | 13 | def ConvBNrelu(in_channels,out_channels,stride=1): 14 | return nn.Sequential( 15 | nn.Conv2d(in_channels,out_channels,3,padding=1,stride=stride), 16 | nn.BatchNorm2d(out_channels), 17 | nn.ReLU() 18 | ) 19 | 20 | def ConvDrelu(in_channels,out_channels,stride=1): 21 | return nn.Sequential( 22 | FastDeconv(in_channels,out_channels,3,padding=1,stride=stride), 23 | nn.ReLU() 24 | ) 25 | 26 | def ConvGNrelu(in_channels,out_channels,stride=1): 27 | return nn.Sequential( 28 | nn.Conv2d(in_channels,out_channels,3,padding=1,stride=stride), 29 | nn.GroupNorm(out_channels//16,out_channels), 30 | nn.ReLU() 31 | ) 32 | @export 33 | class smallCNN(nn.Module,metaclass=Named): 34 | """ 35 | Very small CNN 36 | """ 37 | def __init__(self, num_targets=10,in_channels=3,k=16): 38 | super().__init__() 39 | self.num_targets = num_targets 40 | self.net = nn.Sequential( 41 | ConvBNrelu(in_channels,k), 42 | ConvBNrelu(k,k), 43 | ConvBNrelu(k,2*k), 44 | nn.MaxPool2d(2), 45 | ConvBNrelu(2*k,2*k), 46 | ConvBNrelu(2*k,2*k), 47 | ConvBNrelu(2*k,2*k), 48 | nn.MaxPool2d(2), 49 | ConvBNrelu(2*k,2*k), 50 | ConvBNrelu(2*k,2*k), 51 | ConvBNrelu(2*k,2*k), 52 | Expression(lambda u:u.mean(-1).mean(-1)), 53 | nn.Linear(2*k,num_targets) 54 | ) 55 | def forward(self,x): 56 | return self.net(x) 57 | 58 | @export 59 | class layer13s(nn.Module,metaclass=Named): 60 | """ 61 | Very small CNN 62 | """ 63 | def __init__(self, num_targets=10,in_channels=3,k=128): 64 | super().__init__() 65 | self.num_targets = num_targets 66 | self.net = nn.Sequential( 67 | ConvBNrelu(in_channels,k), 68 | ConvBNrelu(k,k), 69 | ConvBNrelu(k,2*k), 70 | nn.MaxPool2d(2),#MaxBlurPool(2*k), 71 | #nn.Dropout2d(), 72 | ConvBNrelu(2*k,2*k), 73 | ConvBNrelu(2*k,2*k), 74 | ConvBNrelu(2*k,2*k), 75 | nn.MaxPool2d(2),#MaxBlurPool(2*k), 76 | #nn.Dropout2d(), 77 | ConvBNrelu(2*k,2*k), 78 | ConvBNrelu(2*k,2*k), 79 | ConvBNrelu(2*k,2*k), 80 | Expression(lambda u:u.mean(-1).mean(-1)), 81 | nn.Linear(2*k,num_targets) 82 | ) 83 | def forward(self,x): 84 | return self.net(x) 85 | 86 | @export 87 | class layer13d(nn.Module,metaclass=Named): 88 | """ 89 | Very small CNN 90 | """ 91 | def __init__(self, num_targets=10,in_channels=3,k=128): 92 | super().__init__() 93 | self.num_targets = num_targets 94 | self.net = nn.Sequential( 95 | 
ConvDrelu(in_channels,k), 96 | ConvDrelu(k,k), 97 | ConvDrelu(k,2*k), 98 | nn.MaxPool2d(2), 99 | nn.Dropout2d(), 100 | ConvDrelu(2*k,2*k), 101 | ConvDrelu(2*k,2*k), 102 | ConvDrelu(2*k,2*k), 103 | nn.MaxPool2d(2), 104 | nn.Dropout2d(), 105 | ConvDrelu(2*k,2*k), 106 | ConvDrelu(2*k,2*k), 107 | ConvDrelu(2*k,2*k), 108 | Expression(lambda u:u.mean(-1).mean(-1)), 109 | nn.Linear(2*k,num_targets) 110 | ) 111 | def forward(self,x): 112 | return self.net(x) -------------------------------------------------------------------------------- /oil/architectures/img_classifiers/wide_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | WideResNet model definition 3 | ported from https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py 4 | """ 5 | 6 | import torchvision.transforms as transforms 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | import math 11 | from ...utils.utils import export, Named 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 16 | 17 | 18 | def conv_init(m): 19 | classname = m.__class__.__name__ 20 | if classname.find('Conv') != -1: 21 | init.xavier_uniform(m.weight, gain=math.sqrt(2)) 22 | init.constant(m.bias, 0) 23 | elif classname.find('BatchNorm') != -1: 24 | init.constant(m.weight, 1) 25 | init.constant(m.bias, 0) 26 | 27 | 28 | class WideBasic(nn.Module): 29 | def __init__(self, in_planes, planes, drop_rate, stride=1): 30 | super(WideBasic, self).__init__() 31 | self.bn1 = nn.BatchNorm2d(in_planes) 32 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 33 | self.dropout = nn.Dropout(p=drop_rate) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride != 1 or in_planes != planes: 39 | self.shortcut = nn.Sequential( 40 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 41 | ) 42 | 43 | def forward(self, x): 44 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 45 | out = self.conv2(F.relu(self.bn2(out))) 46 | out += self.shortcut(x) 47 | 48 | return out 49 | 50 | @export 51 | class WideResNet(nn.Module,metaclass=Named): 52 | def __init__(self, num_targets=10, depth=28, widen_factor=10, drop_rate=0.3,in_channels=3,initial_stride=1): 53 | super(WideResNet, self).__init__() 54 | self.in_planes = 16 55 | 56 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 57 | n = (depth - 4) / 6 58 | k = widen_factor 59 | 60 | nstages = [16, 16 * k, 32 * k, 64 * k] 61 | 62 | self.conv1 = conv3x3(in_channels, nstages[0]) 63 | self.layer1 = self._wide_layer(WideBasic, nstages[1], n, drop_rate, stride=initial_stride) 64 | self.layer2 = self._wide_layer(WideBasic, nstages[2], n, drop_rate, stride=2) 65 | self.layer3 = self._wide_layer(WideBasic, nstages[3], n, drop_rate, stride=2) 66 | self.bn1 = nn.BatchNorm2d(nstages[3])#, momentum=0.9) 67 | self.linear = nn.Linear(nstages[3], num_targets) 68 | 69 | def _wide_layer(self, block, planes, num_blocks, drop_rate, stride): 70 | strides = [stride] + [1] * int(num_blocks - 1) 71 | layers = [] 72 | 73 | for stride in strides: 74 | layers.append(block(self.in_planes, planes, drop_rate, stride)) 75 | self.in_planes = planes 76 | 77 | return nn.Sequential(*layers) 78 | 79 | def forward(self, x): 80 | out = self.conv1(x) 81 
| out = self.layer1(out) 82 | out = self.layer2(out) 83 | out = self.layer3(out) 84 | out = F.relu(self.bn1(out)) 85 | out = self.linear(out.mean(-1).mean(-1)) 86 | return out 87 | 88 | @export 89 | class WideResNet28x10(WideResNet): 90 | def __init__(self,num_targets=10,drop_rate=.3,in_channels=3): 91 | super().__init__(num_targets,depth=28, widen_factor=10,drop_rate=drop_rate,in_channels=in_channels) 92 | 93 | @export 94 | class WideResNet28x10stl(WideResNet): 95 | def __init__(self,num_targets=10,drop_rate=.3,in_channels=3): 96 | super().__init__(num_targets,depth=28, widen_factor=10,drop_rate=drop_rate,in_channels=in_channels,initial_stride=2) -------------------------------------------------------------------------------- /oil/datasetup/camvid.py: -------------------------------------------------------------------------------- 1 | 2 | # Adapted from https://github.com/felixgwu/vision/blob/cf491d301f62ae9c77ff7250fb7def5cd55ec963/torchvision/datasets/camvid.py 3 | import os 4 | import torch 5 | import torch.utils.data as data 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision.datasets.folder import default_loader 9 | import torchvision.transforms as transforms 10 | from .joint_transforms import JointRandomCrop, JointRandomHorizontalFlip 11 | 12 | def make_dataset(dir): 13 | images = [] 14 | for root, _, fnames in sorted(os.walk(dir)): 15 | for fname in fnames: 16 | #if is_image_file(fname): 17 | path = os.path.join(root, fname) 18 | item = path 19 | images.append(item) 20 | return images 21 | 22 | def LabelToLongTensor(pic): 23 | label = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 24 | label = label.view(pic.size[1], pic.size[0], 1) 25 | label = label.transpose(0, 1).transpose(0, 2).squeeze().contiguous().long() 26 | return label 27 | 28 | class CamVid(data.Dataset): 29 | # weights when using median frequency balancing used in SegNet paper 30 | # https://arxiv.org/pdf/1511.00561.pdf 31 | # The numbers were generated by https://github.com/yandex/segnet-torch/blob/master/datasets/camvid-gen.lua 32 | classes = ['Sky', 'Building', 'Column-Pole', 'Road', 33 | 'Sidewalk', 'Tree', 'Sign-Symbol', 'Fence', 'Car', 'Pedestrian', 34 | 'Bicyclist', 'Void'] 35 | class_weights = [0.58872014284134, 0.51052379608154, 2.6966278553009, 0.45021694898605, 1.1785038709641, 36 | 0.77028578519821, 2.4782588481903, 2.5273461341858, 1.0122526884079, 3.2375309467316, 37 | 4.1312313079834, 0] 38 | class_color = [(128, 128, 128),(128, 0, 0),(192, 192, 128),(128, 64, 128),(0, 0, 192),(128, 128, 0), 39 | (192, 128, 128),(64, 64, 128),(64, 0, 128),(64, 64, 0),(0, 128, 192),(0, 0, 0),] 40 | num_classes=11 41 | means = [0.41189489566336, 0.4251328133025, 0.4326707089857] 42 | stds = [0.27413549931506, 0.28506257482912, 0.28284674400252] 43 | 44 | 45 | def __init__(self, root, split='train', joint_transform=None, 46 | transform=None, download=False, 47 | loader=default_loader): 48 | self.root = root 49 | assert split in ('train', 'val', 'test') 50 | self.split = split 51 | self.transform = transform 52 | self.joint_transform = joint_transform 53 | self.loader = loader # store the image loader; __getitem__ calls self.loader(path) 54 | if download: 55 | self.download() 56 | 57 | self.imgs = make_dataset(os.path.join(self.root, self.split)) 58 | 59 | 60 | def __getitem__(self, index): 61 | path = self.imgs[index] 62 | img = self.loader(path) 63 | target = Image.open(path.replace(self.split, self.split + 'annot')) 64 | 65 | if self.joint_transform is not None: 66 | img, target = self.joint_transform([img, target]) 67 | 68 | if self.transform is not None: 69 | img = self.transform(img) 70 | 71 | target = LabelToLongTensor(target) 72 | return img, target 73 | 74 | def __len__(self): 75 | return len(self.imgs) 76 | 77 | def download(self): 78 | # TODO: please download the dataset from 79 | # https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid 80 | raise NotImplementedError 81 | 82 | @classmethod 83 | def LabelToPILImage(cls, label): 84 | label = label.unsqueeze(0) 85 | colored_label = torch.zeros(3, label.size(1), label.size(2)).byte() 86 | for i, color in enumerate(cls.class_color): 87 | mask = label.eq(i) 88 | for j in range(3): 89 | colored_label[j].masked_fill_(mask, color[j]) 90 | npimg = colored_label.numpy() 91 | npimg = np.transpose(npimg, (1, 2, 0)) 92 | mode = None 93 | if npimg.shape[2] == 1: 94 | npimg = npimg[:, :, 0] 95 | mode = "L" 96 | return Image.fromarray(npimg, mode=mode)
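# A minimal usage sketch for the CamVid dataset above (illustrative, with placeholder
# paths, not part of the original file). It assumes the directory layout that
# make_dataset and __getitem__ expect: <root>/train for images and <root>/trainannot
# for the corresponding label maps.
import torchvision.transforms as transforms
from oil.datasetup.camvid import CamVid
from oil.datasetup.joint_transforms import JointRandomCrop, JointRandomHorizontalFlip

# joint transforms crop/flip the image and its label map identically
joint = transforms.Compose([JointRandomCrop(224), JointRandomHorizontalFlip()])
# image-only transforms run after the joint ones
to_normed_tensor = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize(CamVid.means, CamVid.stds)])
trainset = CamVid('/data/camvid/', split='train',
                  joint_transform=joint, transform=to_normed_tensor)
img, target = trainset[0]  # img: (3,224,224) float tensor; target: (224,224) long tensor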
-------------------------------------------------------------------------------- /oil/datasetup/celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | from oil.utils.utils import Named, export 8 | 9 | IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".bmp"] 10 | ATTR_ANNO = "list_attr_celeba.csv" 11 | 12 | def _is_image(fname): 13 | _, ext = os.path.splitext(fname) 14 | return ext.lower() in IMAGE_EXTENSIONS 15 | 16 | 17 | def _find_images_and_annotation(root_dir): 18 | images = {} 19 | attr = None 20 | assert os.path.exists(root_dir), "{} does not exist".format(root_dir) 21 | for root, _, fnames in sorted(os.walk(root_dir)): 22 | for fname in sorted(fnames): 23 | if _is_image(fname): 24 | path = os.path.join(root, fname) 25 | images[os.path.splitext(fname)[0]] = path 26 | elif fname.lower() == ATTR_ANNO: 27 | attr = os.path.join(root, fname) 28 | 29 | assert attr is not None, "Failed to find `list_attr_celeba.csv`" 30 | 31 | # begin to parse all image attrs 32 | print("Begin to parse all image attrs") 33 | final = [] 34 | with open(attr, "r") as fin: 35 | image_total = 0 36 | attrs = [] 37 | for i_line, line in enumerate(fin): 38 | line = line.strip() 39 | if i_line == 0: 40 | image_total = int(line) 41 | elif i_line == 1: 42 | attrs = line.split(" ") 43 | else: 44 | line = re.sub("[ ]+", " ", line) 45 | line = line.split(" ") 46 | fname = os.path.splitext(line[0])[0] 47 | onehot = [int(int(d) > 0) for d in line[1:]] 48 | assert len(onehot) == len(attrs), "{} only has {} attrs < {}".format( 49 | fname, len(onehot), len(attrs)) 50 | final.append({ 51 | "path": images[fname], 52 | "attr": onehot 53 | }) 54 | print("Found {} images, with {} attrs".format(len(final), len(attrs))) 55 | return final, attrs 56 | 57 | 58 | def find_imgs_only(root_dir): 59 | images = [] 60 | attr = None 61 | assert os.path.exists(root_dir), "{} does not exist".format(root_dir) 62 | for root, _, fnames in sorted(os.walk(root_dir)): 63 | for fname in sorted(fnames): 64 | if _is_image(fname): 65 | path = os.path.join(root, fname) 66 | images.append({'path':path,'attr':1}) 67 | return images,None 68 | 69 | @export 70 | class CelebA(Dataset): 71 | def __init__(self, root_dir, transform=None,size=64,flow=False): 72 | super().__init__() 73 | if transform is None: transform = transforms.Compose([ 74 | transforms.CenterCrop(160), 75 | transforms.Resize(size), 76 | transforms.ToTensor()]) 77 | full_dir = os.path.join(os.path.expanduser(root_dir),'celeba-dataset/img_align_celeba/img_align_celeba') 78 | 
#print(full_dir) 79 | dicts, attrs = find_imgs_only(full_dir) 80 | self.data = dicts 81 | self.attrs = attrs 82 | self.transform = transform 83 | 84 | def __getitem__(self, index): 85 | data = self.data[index] 86 | path = data["path"] 87 | attr = data["attr"] 88 | image= Image.open(path).convert("RGB") 89 | if self.transform is not None: 90 | image = self.transform(image) 91 | return image,attr 92 | 93 | def __len__(self): 94 | return len(self.data) 95 | 96 | 97 | if __name__ == "__main__": 98 | import cv2 99 | celeba = CelebA(os.path.expanduser("~/datasets/CelebA/")) 100 | d = celeba[0] 101 | print(d[0].size()) 102 | img = d[0].permute(1, 2, 0).contiguous().numpy() 103 | print(np.min(img), np.max(img)) 104 | cv2.imshow("img", img) 105 | cv2.waitKey() 106 | -------------------------------------------------------------------------------- /oil/tuning/slurmExecutor.py: -------------------------------------------------------------------------------- 1 | 2 | import dill 3 | from ast import literal_eval 4 | import sys,os,stat 5 | import time 6 | import tempfile 7 | import atexit 8 | import subprocess 9 | from concurrent import futures 10 | from functools import partial 11 | import itertools 12 | import torch 13 | from oil.tuning.localGpuExecutor import LocalGpuExecutor 14 | 15 | def kwargs_to_list(kwargs): 16 | return ["%s%s"%(('--'+k+'=',v) if len(k)>1 17 | else ('-'+k+' ',v)) for k,v in kwargs.items()] 18 | 19 | def tmp_file_name(suffix=".sh"): 20 | t = tempfile.mktemp(dir='.',suffix=suffix) 21 | atexit.register(os.unlink, t) 22 | return t 23 | 24 | class SlurmExecutor(futures.ThreadPoolExecutor): 25 | def __init__(self,*args,slurm_cfg={},clone_session=True,**kwargs): 26 | self.slurm_cfg = slurm_cfg 27 | # Dump the python session 28 | if clone_session: 29 | self.session_file_name = tmp_file_name(".pkl") 30 | dill.dump_session(self.session_file_name) 31 | else: 32 | self.session_file_name = 'no_session' 33 | super().__init__(*args,**kwargs) 34 | 35 | def submit(self,fn,*args,**kwargs): 36 | def todo(): 37 | with open(tmp_file_name(), 'wb+') as funcfile: 38 | dill.dump((fn,args,kwargs),funcfile) 39 | with open(tmp_file_name(), "wb+") as sh_script: 40 | sh_script.write(os.fsencode('#!/bin/sh\n{} {} {} {}'\ 41 | .format(sys.executable,os.path.realpath(__file__), 42 | funcfile.name,self.session_file_name))) 43 | os.fchmod(sh_script.fileno(),stat.S_IRWXU|stat.S_IRWXG|stat.S_IROTH|stat.S_IXOTH) 44 | cfg_args = kwargs_to_list(self.slurm_cfg) 45 | subprocess.call(['srun',*cfg_args,sh_script.name]) 46 | with open(funcfile.name, 'rb') as funcfile: 47 | function_output = dill.load(funcfile) 48 | return function_output 49 | return super().submit(todo) 50 | 51 | def map(self, fn, *iterables, timeout=None, chunksize=1): 52 | """ Identical to the chunky ProcessPoolExecutor implementation, 53 | but underlying parts aren't exposed """ 54 | if chunksize < 1: 55 | raise ValueError("chunksize must be >= 1.") 56 | results = super().map(partial(_process_chunk, fn), 57 | _get_chunks(*iterables, chunksize=chunksize), 58 | timeout=timeout) 59 | return _chain_from_iterable_of_lists(results) 60 | 61 | def _process_chunk(fn, chunk): 62 | return [fn(*args) for args in chunk] 63 | def _get_chunks(*iterables, chunksize): 64 | it = zip(*iterables) 65 | while True: 66 | chunk = tuple(itertools.islice(it, chunksize)) 67 | if not chunk: 68 | return 69 | yield chunk 70 | def _chain_from_iterable_of_lists(iterable): 71 | for element in iterable: 72 | element.reverse() 73 | while element: 74 | yield element.pop() 75 | 76 | def 
LocalExecutor(max_workers=None): 77 | if max_workers==1 or torch.cuda.device_count()<=1 or os.environ.copy().get("WORLD_SIZE",0)!=0: 78 | print("local") 79 | return futures.ThreadPoolExecutor(max_workers=1) 80 | else: 81 | return LocalGpuExecutor(max_workers) 82 | 83 | # #LocalExecutor = LocalGpuExecutor 84 | # class LocalExecutor(futures.ThreadPoolExecutor): 85 | # """Wraps ProcessPoolExecutor but distributes local gpus to the 86 | # processes #TODO: restrict gpu allocation. At the moment restricts 87 | # to sequential (single core and gpu) execution.""" 88 | # def __init__(self,max_workers,*args,**kwargs): 89 | # super().__init__(max_workers=1,*args,**kwargs) 90 | # #os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 91 | # #os.environ["CUDA_VISIBLE_DEVICES"]="0" 92 | 93 | 94 | if __name__=='__main__': 95 | if sys.argv[2]!='no_session': 96 | dill.load_session(sys.argv[2]) 97 | with open(sys.argv[1], 'rb') as funcfile: 98 | (fn,args,kwargs) = dill.load(funcfile) 99 | out = fn(*args,**kwargs) 100 | with open(sys.argv[1], 'wb+') as funcfile: 101 | dill.dump(out,funcfile) 102 | -------------------------------------------------------------------------------- /oil/architectures/img_gen/resnetgan.py: -------------------------------------------------------------------------------- 1 | # ResNet generator and discriminator 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | #from .spectral_normalization import SpectralNorm 6 | #from torch.nn.utils import spectral_norm 7 | from ...utils.utils import Expression, Named 8 | from .ganBase import GanBase, add_spectral_norm, xavier_uniform_init 9 | 10 | # Resnet GAN and Discriminator with Spectral normalization 11 | # Implementation of architectures used in SNGAN (https://arxiv.org/abs/1802.05957) 12 | 13 | class Generator(GanBase): 14 | def __init__(self, z_dim=128,img_channels=3,k=256,**kwargs): 15 | super().__init__(z_dim,img_channels,**kwargs) 16 | self.k = k 17 | self.model = nn.Sequential( 18 | nn.Linear(z_dim, 4 * 4 * k), 19 | Expression(lambda z: z.view(-1,k,4,4)), 20 | ResBlockGenerator(k, k, stride=2), 21 | ResBlockGenerator(k, k, stride=2), 22 | ResBlockGenerator(k, k, stride=2), 23 | nn.BatchNorm2d(k), 24 | nn.ReLU(), 25 | nn.Conv2d(k, img_channels, 3, stride=1, padding=1), 26 | nn.Tanh()) 27 | 28 | self.apply(xavier_uniform_init) 29 | def forward(self, z): 30 | return self.model(z) 31 | 32 | 33 | class Discriminator(nn.Module,metaclass=Named): 34 | def __init__(self,img_channels=3,k=128,out_size=1): 35 | super().__init__() 36 | self.img_channels = img_channels 37 | self.k = k 38 | self.model = nn.Sequential( 39 | FirstResBlockDiscriminator(img_channels, k, stride=2), 40 | ResBlockDiscriminator(k, k, stride=2), 41 | ResBlockDiscriminator(k, k), 42 | ResBlockDiscriminator(k, k), 43 | nn.ReLU(), 44 | nn.AvgPool2d(8), 45 | Expression(lambda u: u.view(-1,k)), 46 | nn.Linear(k, out_size) 47 | ) 48 | self.apply(xavier_uniform_init) 49 | self.apply(add_spectral_norm) 50 | # Spectral norm on discriminator but not generator 51 | def forward(self, x): 52 | return self.model(x) 53 | 54 | class ResBlockGenerator(nn.Module): 55 | 56 | def __init__(self, in_ch, out_ch, stride=1): 57 | super().__init__() 58 | self.upsample = nn.Upsample(scale_factor=stride,mode='bilinear') if stride!=1 else nn.Sequential() 59 | self.model = nn.Sequential( 60 | nn.BatchNorm2d(in_ch), 61 | nn.ReLU(), 62 | self.upsample, 63 | nn.Conv2d(in_ch, out_ch, 3, 1, padding=1), 64 | nn.BatchNorm2d(out_ch), 65 | nn.ReLU(), 66 | 
nn.Conv2d(out_ch, out_ch, 3, 1, padding=1) 67 | ) 68 | self.bypass = nn.Conv2d(in_ch,out_ch,1,1,padding=0) if in_ch!=out_ch else nn.Sequential() 69 | 70 | def forward(self, x): 71 | return self.model(x) + self.bypass(self.upsample(x)) 72 | 73 | 74 | class ResBlockDiscriminator(nn.Module): 75 | 76 | def __init__(self, in_ch, out_ch, stride=1): 77 | super().__init__() 78 | self.model = nn.Sequential( 79 | nn.ReLU(), 80 | nn.Conv2d(in_ch, out_ch, 3, 1, padding=1), 81 | nn.ReLU(), 82 | nn.Conv2d(out_ch, out_ch, 3, 1, padding=1) 83 | ) 84 | self.downsample = nn.AvgPool2d(2, stride=stride, padding=0) if stride!=1 else nn.Sequential() 85 | self.bypass = nn.Conv2d(in_ch,out_ch,1,1,padding=0) if in_ch!=out_ch else nn.Sequential() 86 | 87 | def forward(self, x): 88 | return self.downsample(self.model(x)) + self.downsample(self.bypass(x)) 89 | 90 | # special ResBlock just for the first layer of the discriminator 91 | class FirstResBlockDiscriminator(nn.Module): 92 | def __init__(self, in_ch, out_ch, stride=1): 93 | super().__init__() 94 | # we don't want to apply ReLU activation to raw image before convolution transformation. 95 | self.model = nn.Sequential( 96 | #nn.ReLU(), 97 | nn.Conv2d(in_ch, out_ch, 3, 1, padding=1), 98 | nn.ReLU(), 99 | nn.Conv2d(out_ch, out_ch, 3, 1, padding=1) 100 | ) 101 | self.downsample = nn.AvgPool2d(2, stride=stride, padding=0) if stride!=1 else nn.Sequential() 102 | self.bypass = nn.Conv2d(in_ch,out_ch,1,1,padding=0) if in_ch!=out_ch else nn.Sequential() 103 | 104 | def forward(self, x): 105 | return self.downsample(self.model(x)) + self.downsample(self.bypass(x)) 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /oil/architectures/parts/CoordConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ...utils.utils import Expression,export,Named 4 | # Copied from https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py repo 5 | # original paper https://arxiv.org/pdf/1807.03247.pdf 6 | 7 | class AddCoordsTh(nn.Module): 8 | def __init__(self, x_dim=64, y_dim=64, with_r=False): 9 | super(AddCoordsTh, self).__init__() 10 | self.x_dim = x_dim 11 | self.y_dim = y_dim 12 | self.with_r = with_r 13 | 14 | def forward(self, input_tensor): 15 | """ 16 | input_tensor: (batch, c, x_dim, y_dim) 17 | """ 18 | batch_size_tensor = input_tensor.shape[0] 19 | 20 | xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32) 21 | xx_ones = xx_ones.unsqueeze(-1) 22 | 23 | xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0) 24 | xx_range = xx_range.unsqueeze(1) 25 | 26 | xx_channel = torch.matmul(xx_ones, xx_range) 27 | xx_channel = xx_channel.unsqueeze(-1) 28 | 29 | yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32) 30 | yy_ones = yy_ones.unsqueeze(1) 31 | 32 | yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0) 33 | yy_range = yy_range.unsqueeze(-1) 34 | 35 | yy_channel = torch.matmul(yy_range, yy_ones) 36 | yy_channel = yy_channel.unsqueeze(-1) 37 | 38 | xx_channel = xx_channel.permute(0, 3, 2, 1) 39 | yy_channel = yy_channel.permute(0, 3, 2, 1) 40 | 41 | xx_channel = xx_channel.float() / (self.x_dim - 1) 42 | yy_channel = yy_channel.float() / (self.y_dim - 1) 43 | 44 | xx_channel = xx_channel * 2 - 1 45 | yy_channel = yy_channel * 2 - 1 46 | 47 | xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1) 48 | yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1) 49 | 50 | ret = 
torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 51 | 52 | if self.with_r: 53 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) 54 | ret = torch.cat([ret, rr], dim=1) 55 | 56 | return ret 57 | 58 | 59 | class CoordConvTh(nn.Module): 60 | """CoordConv layer as in the paper.""" 61 | def __init__(self, x_dim, y_dim, with_r, *args, **kwargs): 62 | super(CoordConvTh, self).__init__() 63 | self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r) 64 | self.conv = nn.Conv2d(*args, **kwargs) 65 | 66 | def forward(self, input_tensor): 67 | ret = self.addcoords(input_tensor) 68 | ret = self.conv(ret) 69 | return ret 70 | 71 | 72 | ''' 73 | An alternative implementation for PyTorch with auto-infering the x-y dimensions. 74 | ''' 75 | class AddCoords(nn.Module): 76 | 77 | def __init__(self, with_r=False): 78 | super().__init__() 79 | self.with_r = with_r 80 | 81 | def forward(self, input_tensor): 82 | """ 83 | Args: 84 | input_tensor: shape(batch, channel, x_dim, y_dim) 85 | """ 86 | batch_size, _, x_dim, y_dim = input_tensor.size() 87 | 88 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 89 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 90 | 91 | xx_channel = xx_channel.float() / (x_dim - 1) 92 | yy_channel = yy_channel.float() / (y_dim - 1) 93 | 94 | xx_channel = xx_channel * 2 - 1 95 | yy_channel = yy_channel * 2 - 1 96 | 97 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 98 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 99 | 100 | ret = torch.cat([ 101 | input_tensor, 102 | xx_channel.type_as(input_tensor), 103 | yy_channel.type_as(input_tensor)], dim=1) 104 | 105 | if self.with_r: 106 | rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2)) 107 | ret = torch.cat([ret, rr], dim=1) 108 | 109 | return ret 110 | 111 | @export 112 | class CoordConv(nn.Module): 113 | 114 | def __init__(self, in_channels, out_channels, kernel_size, with_r=False, **kwargs): 115 | super().__init__() 116 | self.addcoords = AddCoords(with_r=with_r) 117 | in_size = in_channels+2 118 | if with_r: 119 | in_size += 1 120 | self.conv = nn.Conv2d(in_size, out_channels, kernel_size, **kwargs) 121 | 122 | def forward(self, x): 123 | ret = self.addcoords(x) 124 | ret = self.conv(ret) 125 | return ret -------------------------------------------------------------------------------- /oil/model_trainers/segmenter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from oil.utils.utils import Eval, cosLr, islice 4 | from oil.utils.metrics import confusion_from_logits, meanIoU, freqIoU, pixelAcc,meanAcc 5 | from oil.model_trainers.trainer import Trainer 6 | from oil.model_trainers.classifier import Classifier, Regressor 7 | import torchvision.utils as vutils 8 | import itertools 9 | import math 10 | 11 | class Segmenter(Classifier): 12 | """ Trainer subclass. 
Implements segmentation metrics (pixelAcc, meanAcc, 13 | mIoU, fwIoU) and logs example predicted segmentation maps """ 14 | def __init__(self,*args,**kwargs): 15 | super().__init__(*args,**kwargs) 16 | viz_loader = (self.dataloaders.get('val') or self.dataloaders['train']) 17 | self.viz_loader = islice(viz_loader,math.ceil(4/viz_loader.batch_size)) 18 | self.overlay = False 19 | 20 | def metrics(self,loader): 21 | mb_confusion = lambda mb: confusion_from_logits(self.model(mb[0]),mb[1]) 22 | full_confusion = self.evalAverageMetrics(loader,mb_confusion) 23 | metrics = { 24 | 'pixelAcc':pixelAcc(full_confusion), 25 | 'meanAcc':meanAcc(full_confusion), 26 | 'mIoU':meanIoU(full_confusion), 27 | 'fwIoU':freqIoU(full_confusion), 28 | } 29 | return metrics 30 | def logStuff(self, step, minibatch=None): 31 | seg2img = self.viz_loader.dataset.decode_segmap 32 | means = torch.tensor(self.viz_loader.dataset.means)[None,:,None,None] 33 | stds = torch.tensor(self.viz_loader.dataset.stds)[None,:,None,None] 34 | with torch.no_grad(): 35 | xs,ys,gts = zip(*[(mb[0].cpu().data, 36 | torch.argmax(self.model(mb[0]),1).cpu().data, 37 | mb[1].cpu().data.squeeze(1)) for mb in self.viz_loader]) 38 | imgs = [torch.cat(xs)*stds+means,seg2img(torch.cat(ys)),seg2img(torch.cat(gts))] 39 | if self.overlay: imgs = [overlay(imgs[0],imgs[1]),overlay(imgs[0],imgs[2])] 40 | img_grid = vutils.make_grid(torch.cat(imgs),nrow=len(imgs[0]),range=(0,1)) 41 | self.logger.add_image('Segmentations', img_grid, step) 42 | super().logStuff(step,minibatch) 43 | 44 | def overlay(rgb,segmap,alpha=.5): 45 | return rgb + alpha*(segmap-rgb)*(segmap.sum(1)!=0)[:,None].float() 46 | 47 | class ImgRegressor(Regressor): 48 | pass 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | # Convenience function that covers the common use case of training the model with 57 | # the cosLr schedule, logging the outcome, and returning the results 58 | from torch.utils.data import DataLoader 59 | from oil.utils.utils import LoaderTo, cosLr, recursively_update,islice 60 | from oil.tuning.study import train_trial 61 | from oil.datasetup.segmentation.voc import VOCSegmentation 62 | from oil.architectures.img2img import DeepLab 63 | import collections 64 | 65 | def simpleSegmenterTrial(strict=False): 66 | def makeTrainer(config): 67 | cfg = { 68 | 'dataset': VOCSegmentation,'network':DeepLab,'net_config': {}, 69 | 'loader_config': {'batch_size':2,'pin_memory':True,'num_workers':2,'shuffle':True}, 70 | 'opt_config':{'lr':.01, 'momentum':.9, 'weight_decay':1e-4,'nesterov':True}, 71 | 'num_epochs':100,'trainer_config':{'log_args':{'timeFrac':1/2}}, 72 | } 73 | recursively_update(cfg,config) 74 | 75 | 76 | trainset = cfg['dataset']('~/datasets/{}/'.format(cfg['dataset']),image_set='train') 77 | valset = cfg['dataset']('~/datasets/{}/'.format(cfg['dataset']),image_set='val') 78 | device = torch.device('cuda') 79 | fullCNN = cfg['network'](num_classes=trainset.num_classes,**cfg['net_config']) 80 | # vgg16_backbone = VGG16(pretrained=True) 81 | # fullCNN.copy_params_from_vgg16(vgg16_backbone) 82 | fullCNN.to(device).train() 83 | #fullCNN.load_state_dict(torch.load(cfg['network'].download())) 84 | dataloaders = {} 85 | dataloaders['train'] = DataLoader(trainset,**cfg['loader_config']) 86 | dataloaders['val'] = islice(DataLoader(valset,**cfg['loader_config']),30) 87 | dataloaders = {k:LoaderTo(v,device) for k,v in dataloaders.items()} 88 | 89 | #opt_constr = lambda params: torch.optim.SGD(params, **cfg['opt_config']) 90 | opt_constr = lambda params: torch.optim.Adam(params, lr=3e-3) 91 | lr_sched = cosLr() 92 | 
return Segmenter(fullCNN,dataloaders,opt_constr,lr_sched,**cfg['trainer_config']) 93 | return train_trial(makeTrainer,strict) 94 | 95 | 96 | if __name__=='__main__': 97 | Trial = simpleSegmenterTrial(strict=True) 98 | Trial({'num_epochs':5}) -------------------------------------------------------------------------------- /oil/model_trainers/gan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torchvision.utils as vutils 6 | import itertools 7 | import os 8 | import torch.nn.functional as F 9 | from .trainer import Trainer 10 | import scipy.misc 11 | from ..utils.metrics import FID_and_IS 12 | from ..utils.utils import Eval 13 | from itertools import islice 14 | 15 | def hinge_loss_G(fake_logits): 16 | return -torch.mean(fake_logits) 17 | 18 | def hinge_loss_D(real_logits,fake_logits): 19 | return F.relu(1-real_logits).mean() + F.relu(1 + fake_logits).mean() 20 | 21 | class Gan(Trainer): 22 | 23 | def __init__(self, *args,D=None,opt_constr=None,lr_sched=lambda e:1, 24 | n_disc = 2, **kwargs): 25 | super().__init__(*args, **kwargs) 26 | self.hypers.update({'n_disc':n_disc}) 27 | 28 | self.G = self.model 29 | self.D = D 30 | if not opt_constr: 31 | g_opt=d_opt = lambda params: optim.Adam(params,2e-4,betas=(.5,.999)) 32 | self.hypers.update({'lr':2e-4,'betas':(.5,.999)}) 33 | elif isinstance(opt_constr,(list,tuple)): 34 | g_opt,d_opt = opt_constr 35 | else: 36 | g_opt=d_opt = opt_constr 37 | self.g_optimizer = g_opt(self.G.parameters()) 38 | self.d_optimizer = d_opt(self.D.parameters()) 39 | g_sched = optim.lr_scheduler.LambdaLR(self.g_optimizer,lr_sched) 40 | d_sched = optim.lr_scheduler.LambdaLR(self.d_optimizer,lr_sched) 41 | self.lr_schedulers = [g_sched,d_sched] 42 | self.fixed_input = (self.G.sample_z(32),) 43 | 44 | def step(self, data): 45 | # Step the Generator 46 | self.g_optimizer.zero_grad() 47 | G_loss = self.genLoss(data) 48 | G_loss.backward() 49 | self.g_optimizer.step() 50 | # Step the Discriminator 51 | for _ in range(self.hypers['n_disc']): 52 | self.d_optimizer.zero_grad() 53 | D_loss = self.discLoss(data) 54 | D_loss.backward() 55 | self.d_optimizer.step() 56 | 57 | def genLoss(self, x): 58 | """ hinge based generator loss -E[D(G(z))] """ 59 | z = self.G.sample_z(x.shape[0]) 60 | fake_logits = self.D(self.G(z)) 61 | return hinge_loss_G(fake_logits) 62 | 63 | def discLoss(self, x): 64 | z = self.G.sample_z(x.shape[0]) 65 | fake_logits = self.D(self.G(z)) 66 | real_logits = self.D(x) 67 | return hinge_loss_D(real_logits,fake_logits) 68 | 69 | def logStuff(self, step, minibatch=None): 70 | """ Handles Logging and any additional needs for subclasses, 71 | should have no impact on the training """ 72 | 73 | metrics = {} 74 | if minibatch is not None: 75 | metrics['G_loss'] = self.genLoss(minibatch).cpu().data.numpy() 76 | metrics['D_loss'] = self.discLoss(minibatch).cpu().data.numpy() 77 | try: metrics['FID'],metrics['IS'] = FID_and_IS(self.as_dataloader(),self.dataloaders['test']) 78 | except KeyError: pass 79 | self.logger.add_scalars('metrics', metrics, step) 80 | # what if (in case of cycleGAN, there is no G?) 
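# Note on the lines below (added annotation): self.fixed_input is sampled once in
# __init__, so the same latent batch is rendered at every logging step, which makes
# successive image grids directly comparable over training. A subclass without a
# standalone generator (the cycleGAN case flagged in the comment above) would need to
# guard or override this block, e.g. with hasattr(self, 'G').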
81 | fake_images = self.G(*self.fixed_input).cpu().data 82 | img_grid = vutils.make_grid(fake_images[:,:3], normalize=True) 83 | self.logger.add_image('fake_samples', img_grid, step) 84 | super().logStuff(step,minibatch) 85 | 86 | def as_dataloader(self,N=5000,bs=64): 87 | return GanLoader(self.G,N,bs) 88 | 89 | def state_dict(self): 90 | extra_state = { 91 | 'G_state':self.G.state_dict(), 92 | 'G_optim_state':self.g_optimizer.state_dict(), 93 | 'D_state':self.D.state_dict(), 94 | 'D_optim_state':self.d_optimizer.state_dict(), 95 | } 96 | return {**super().state_dict(),**extra_state} 97 | 98 | def load_state_dict(self,state): 99 | super().load_state_dict(state) 100 | self.G.load_state_dict(state['G_state']) 101 | self.g_optimizer.load_state_dict(state['G_optim_state']) 102 | self.D.load_state_dict(state['D_state']) 103 | self.d_optimizer.load_state_dict(state['D_optim_state']) 104 | 105 | # TODO: ???? 106 | class GanLoader(object): 107 | """ Dataloader class for the generator""" 108 | def __init__(self,G,N=10**10,bs=64): 109 | self.G, self.N, self.bs = G,N,bs 110 | def __len__(self): 111 | return self.N 112 | def __iter__(self): 113 | with torch.no_grad(),Eval(self.G): 114 | for i in range(self.N//self.bs): 115 | yield self.G.sample(self.bs) 116 | if self.N%self.bs!=0: 117 | yield self.G.sample(self.N%self.bs) 118 | 119 | def write_imgs(self,path): 120 | np_images = np.concatenate([img.cpu().numpy() for img in self],axis=0) 121 | for i,img in enumerate(np_images): 122 | scipy.misc.imsave(path+'img{}.jpg'.format(i), img) -------------------------------------------------------------------------------- /oil/architectures/img_classifiers/preresnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | PreResNet model definition 3 | ported from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/preresnet.py 4 | """ 5 | 6 | import torch.nn as nn 7 | import torchvision.transforms as transforms 8 | import math 9 | from ...utils.utils import Named 10 | 11 | __all__ = ['PreResNet110', 'PreResNet56','PreResNet'] 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.bn1 = nn.BatchNorm2d(inplanes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv1 = conv3x3(inplanes, planes, stride) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.bn1(x) 36 | out = self.relu(out) 37 | out = self.conv1(out) 38 | 39 | out = self.bn2(out) 40 | out = self.relu(out) 41 | out = self.conv2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.bn1 = nn.BatchNorm2d(inplanes) 57 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 60 | padding=1, bias=False) 61 | self.bn3 = nn.BatchNorm2d(planes) 62 | self.conv3 = 
nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.bn1(x) 71 | out = self.relu(out) 72 | out = self.conv1(out) 73 | 74 | out = self.bn2(out) 75 | out = self.relu(out) 76 | out = self.conv2(out) 77 | 78 | out = self.bn3(out) 79 | out = self.relu(out) 80 | out = self.conv3(out) 81 | 82 | if self.downsample is not None: 83 | residual = self.downsample(x) 84 | 85 | out += residual 86 | 87 | return out 88 | 89 | 90 | class PreResNet(nn.Module,metaclass=Named): 91 | 92 | def __init__(self, num_targets=10, depth=110): 93 | super(PreResNet, self).__init__() 94 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 95 | n = (depth - 2) // 6 96 | 97 | block = Bottleneck if depth >= 44 else BasicBlock 98 | 99 | self.inplanes = 16 100 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 101 | bias=False) 102 | self.layer1 = self._make_layer(block, 16, n) 103 | self.layer2 = self._make_layer(block, 32, n, stride=2) 104 | self.layer3 = self._make_layer(block, 64, n, stride=2) 105 | self.bn = nn.BatchNorm2d(64 * block.expansion) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.avgpool = nn.AvgPool2d(8) 108 | self.fc = nn.Linear(64 * block.expansion, num_targets) 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 113 | m.weight.data.normal_(0, math.sqrt(2. / n)) 114 | elif isinstance(m, nn.BatchNorm2d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | 118 | def _make_layer(self, block, planes, blocks, stride=1): 119 | downsample = None 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | nn.Conv2d(self.inplanes, planes * block.expansion, 123 | kernel_size=1, stride=stride, bias=False), 124 | ) 125 | 126 | layers = list() 127 | layers.append(block(self.inplanes, planes, stride, downsample)) 128 | self.inplanes = planes * block.expansion 129 | for i in range(1, blocks): 130 | layers.append(block(self.inplanes, planes)) 131 | 132 | return nn.Sequential(*layers) 133 | 134 | def forward(self, x): 135 | x = self.conv1(x) 136 | 137 | x = self.layer1(x) # 32x32 138 | x = self.layer2(x) # 16x16 139 | x = self.layer3(x) # 8x8 140 | x = self.bn(x) 141 | x = self.relu(x) 142 | 143 | x = self.avgpool(x) 144 | x = x.view(x.size(0), -1) 145 | x = self.fc(x) 146 | 147 | return x 148 | 149 | 150 | class PreResNet56(PreResNet): 151 | def __init__(self,num_targets=10): 152 | super().__init__(num_targets=num_targets,depth=56) 153 | 154 | class PreResNet110(PreResNet): 155 | def __init__(self,num_targets=10): 156 | super().__init__(num_targets=num_targets,depth=110) -------------------------------------------------------------------------------- /oil/architectures/parts/antialiasing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, Adobe Inc. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4 | # 4.0 International Public License. To view a copy of this license, visit 5 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 
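# This file ports the anti-aliased downsampling of R. Zhang, "Making Convolutional
# Networks Shift-Invariant Again" (ICML 2019, https://arxiv.org/abs/1904.11486):
# convolve with a fixed binomial low-pass filter before subsampling, so that strided
# layers become approximately shift-invariant. A minimal drop-in sketch using the
# layers defined below (the surrounding Sequential block is hypothetical, not taken
# from this repo):
import torch.nn as nn
from oil.architectures.parts.antialiasing import MaxBlurPool

antialiased_block = nn.Sequential(
    nn.Conv2d(64, 128, 3, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(),
    MaxBlurPool(128),  # dense max (stride 1) followed by blurred stride-2 subsampling
)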
6 | 7 | import torch 8 | import torch.nn.parallel 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from ...utils.utils import Expression,export,Named 13 | 14 | @export 15 | class Downsample(nn.Module): 16 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): 17 | super(Downsample, self).__init__() 18 | self.filt_size = filt_size 19 | self.pad_off = pad_off 20 | self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))] 21 | self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes] 22 | self.stride = stride 23 | self.off = int((self.stride-1)/2.) 24 | self.channels = channels 25 | 26 | # print('Filter size [%i]'%filt_size) 27 | if(self.filt_size==1): 28 | a = np.array([1.,]) 29 | elif(self.filt_size==2): 30 | a = np.array([1., 1.]) 31 | elif(self.filt_size==3): 32 | a = np.array([1., 2., 1.]) 33 | elif(self.filt_size==4): 34 | a = np.array([1., 3., 3., 1.]) 35 | elif(self.filt_size==5): 36 | a = np.array([1., 4., 6., 4., 1.]) 37 | elif(self.filt_size==6): 38 | a = np.array([1., 5., 10., 10., 5., 1.]) 39 | elif(self.filt_size==7): 40 | a = np.array([1., 6., 15., 20., 15., 6., 1.]) 41 | 42 | filt = torch.Tensor(a[:,None]*a[None,:]) 43 | filt = filt/torch.sum(filt) 44 | self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1))) 45 | 46 | self.pad = get_pad_layer(pad_type)(self.pad_sizes) 47 | 48 | def forward(self, inp): 49 | if(self.filt_size==1): 50 | if(self.pad_off==0): 51 | return inp[:,:,::self.stride,::self.stride] 52 | else: 53 | return self.pad(inp)[:,:,::self.stride,::self.stride] 54 | else: 55 | return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) 56 | 57 | @export 58 | def MaxBlurPool(channels,M=3): 59 | return nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=1), 60 | Downsample(channels=channels, filt_size=M, stride=2)) 61 | @export 62 | def BlurPool(channels,M=3): 63 | return Downsample(channels=channels, filt_size=M, stride=2) 64 | 65 | def get_pad_layer(pad_type): 66 | if(pad_type in ['refl','reflect']): 67 | PadLayer = nn.ReflectionPad2d 68 | elif(pad_type in ['repl','replicate']): 69 | PadLayer = nn.ReplicationPad2d 70 | elif(pad_type=='zero'): 71 | PadLayer = nn.ZeroPad2d 72 | else: 73 | print('Pad type [%s] not recognized'%pad_type) 74 | return PadLayer 75 | 76 | 77 | class Downsample1D(nn.Module): 78 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): 79 | super(Downsample1D, self).__init__() 80 | self.filt_size = filt_size 81 | self.pad_off = pad_off 82 | self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] 83 | self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] 84 | self.stride = stride 85 | self.off = int((self.stride - 1) / 2.) 
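# (added annotation) As in the 2-D Downsample above, filt_size selects a row of
# Pascal's triangle below; once normalized, these binomial taps approximate a
# Gaussian low-pass filter that is applied before the strided subsampling.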
86 | self.channels = channels 87 | 88 | # print('Filter size [%i]' % filt_size) 89 | if(self.filt_size == 1): 90 | a = np.array([1., ]) 91 | elif(self.filt_size == 2): 92 | a = np.array([1., 1.]) 93 | elif(self.filt_size == 3): 94 | a = np.array([1., 2., 1.]) 95 | elif(self.filt_size == 4): 96 | a = np.array([1., 3., 3., 1.]) 97 | elif(self.filt_size == 5): 98 | a = np.array([1., 4., 6., 4., 1.]) 99 | elif(self.filt_size == 6): 100 | a = np.array([1., 5., 10., 10., 5., 1.]) 101 | elif(self.filt_size == 7): 102 | a = np.array([1., 6., 15., 20., 15., 6., 1.]) 103 | 104 | filt = torch.Tensor(a) 105 | filt = filt / torch.sum(filt) 106 | self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1))) 107 | 108 | self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes) 109 | 110 | def forward(self, inp): 111 | if(self.filt_size == 1): 112 | if(self.pad_off == 0): 113 | return inp[:, :, ::self.stride] 114 | else: 115 | return self.pad(inp)[:, :, ::self.stride] 116 | else: 117 | return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) 118 | 119 | 120 | def get_pad_layer_1d(pad_type): 121 | if(pad_type in ['refl', 'reflect']): 122 | PadLayer = nn.ReflectionPad1d 123 | elif(pad_type in ['repl', 'replicate']): 124 | PadLayer = nn.ReplicationPad1d 125 | elif(pad_type == 'zero'): 126 | PadLayer = nn.ZeroPad1d 127 | else: 128 | print('Pad type [%s] not recognized' % pad_type) 129 | return PadLayer -------------------------------------------------------------------------------- /oil/architectures/img_classifiers/densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from ...utils.utils import Named 6 | """ 7 | Densenet model definition 8 | ported from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/densenet.py 9 | """ 10 | 11 | __all__ = ['DenseNet','DenseNetBC12','DenseNetBC40'] 12 | 13 | 14 | from torch.autograd import Variable 15 | 16 | class Bottleneck(nn.Module): 17 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 18 | super(Bottleneck, self).__init__() 19 | planes = expansion * growthRate 20 | self.bn1 = nn.BatchNorm2d(inplanes) 21 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 24 | padding=1, bias=False) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.dropRate = dropRate 27 | 28 | def forward(self, x): 29 | out = self.bn1(x) 30 | out = self.relu(out) 31 | out = self.conv1(out) 32 | out = self.bn2(out) 33 | out = self.relu(out) 34 | out = self.conv2(out) 35 | if self.dropRate > 0: 36 | out = F.dropout(out, p=self.dropRate, training=self.training) 37 | 38 | out = torch.cat((x, out), 1) 39 | 40 | return out 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): 45 | super(BasicBlock, self).__init__() 46 | planes = expansion * growthRate 47 | self.bn1 = nn.BatchNorm2d(inplanes) 48 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, 49 | padding=1, bias=False) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.dropRate = dropRate 52 | 53 | def forward(self, x): 54 | out = self.bn1(x) 55 | out = self.relu(out) 56 | out = self.conv1(out) 57 | if self.dropRate > 0: 58 | out = F.dropout(out, p=self.dropRate, training=self.training) 59 | 60 | out = torch.cat((x, out), 1) 61 | 62 | return 
out 63 | 64 | 65 | class Transition(nn.Module): 66 | def __init__(self, inplanes, outplanes): 67 | super(Transition, self).__init__() 68 | self.bn1 = nn.BatchNorm2d(inplanes) 69 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, 70 | bias=False) 71 | self.relu = nn.ReLU(inplace=True) 72 | 73 | def forward(self, x): 74 | out = self.bn1(x) 75 | out = self.relu(out) 76 | out = self.conv1(out) 77 | out = F.avg_pool2d(out, 2) 78 | return out 79 | 80 | 81 | class DenseNet(nn.Module,metaclass=Named): 82 | 83 | def __init__(self, depth=22, block=Bottleneck, 84 | drop_rate=0, num_targets=10, k=12, compressionRate=2): 85 | super(DenseNet, self).__init__() 86 | 87 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 88 | n = (depth - 4) // 3 if block == BasicBlock else (depth - 4) // 6 89 | 90 | self.growthRate = k 91 | self.dropRate = drop_rate 92 | 93 | # self.inplanes is a global variable used across multiple 94 | # helper functions 95 | self.inplanes = k * 2 96 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 97 | bias=False) 98 | self.dense1 = self._make_denseblock(block, n) 99 | self.trans1 = self._make_transition(compressionRate) 100 | self.dense2 = self._make_denseblock(block, n) 101 | self.trans2 = self._make_transition(compressionRate) 102 | self.dense3 = self._make_denseblock(block, n) 103 | self.bn = nn.BatchNorm2d(self.inplanes) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.avgpool = nn.AvgPool2d(8) 106 | self.fc = nn.Linear(self.inplanes, num_targets) 107 | 108 | # Weight initialization 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 112 | m.weight.data.normal_(0, math.sqrt(2. / n)) 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | 117 | def _make_denseblock(self, block, blocks): 118 | layers = [] 119 | for i in range(blocks): 120 | # Currently we fix the expansion ratio as the default value 121 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 122 | self.inplanes += self.growthRate 123 | 124 | return nn.Sequential(*layers) 125 | 126 | def _make_transition(self, compressionRate): 127 | inplanes = self.inplanes 128 | outplanes = int(math.floor(self.inplanes // compressionRate)) 129 | self.inplanes = outplanes 130 | return Transition(inplanes, outplanes) 131 | 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | 136 | x = self.trans1(self.dense1(x)) 137 | x = self.trans2(self.dense2(x)) 138 | x = self.dense3(x) 139 | x = self.bn(x) 140 | x = self.relu(x) 141 | 142 | x = self.avgpool(x) 143 | x = x.view(x.size(0), -1) 144 | x = self.fc(x) 145 | 146 | return x 147 | 148 | class DenseNetBC12(DenseNet): 149 | def __init__(self,num_targets=10,drop_rate=0): 150 | super().__init__(depth=100,k=12,drop_rate=drop_rate,num_targets=num_targets) 151 | 152 | class DenseNetBC40(DenseNet): 153 | def __init__(self,num_targets=10,drop_rate=0): 154 | super().__init__(depth=190,k=40,drop_rate=drop_rate,num_targets=num_targets) 155 | -------------------------------------------------------------------------------- /oil/model_trainers/trainer.py: -------------------------------------------------------------------------------- 1 | import torch, dill 2 | from torch import optim 3 | from ..logging.lazyLogger import LazyLogger 4 | from ..utils.utils import Eval, Named 5 | from ..utils.mytqdm import tqdm 6 | from ..tuning.study import guess_metric_sign 7 | import copy, os, random 8 | import glob 9 | 
import numpy as np 10 | from natsort import natsorted 11 | 12 | class Trainer(object,metaclass=Named): 13 | """ Base trainer 14 | """ 15 | def __init__(self, model, dataloaders, opt_constr=optim.Adam, lr_sched = lambda e: 1, 16 | log_dir=None, log_suffix='',log_args={},early_stop_metric=None): 17 | 18 | # Setup model, optimizer, and dataloaders 19 | self.model = model 20 | 21 | self.optimizer = opt_constr(self.model.parameters()) 22 | try: self.lr_schedulers = [lr_sched(optimizer=self.optimizer)] 23 | except TypeError: self.lr_schedulers = [optim.lr_scheduler.LambdaLR(self.optimizer,lr_sched)] 24 | self.dataloaders = dataloaders # A dictionary of dataloaders 25 | self.epoch = 0 26 | 27 | self.logger = LazyLogger(log_dir, log_suffix, **log_args) 28 | #self.logger.add_text('ModelSpec','model: {}'.format(model)) 29 | self.hypers = {} 30 | self.ckpt = None#copy.deepcopy(self.state_dict()) 31 | self.early_stop_metric = early_stop_metric 32 | 33 | def metrics(self,loader): 34 | return {} 35 | 36 | def loss(self,minibatch): 37 | raise NotImplementedError 38 | 39 | def train_to(self, final_epoch=100): 40 | assert final_epoch>=self.epoch, "trying to train less than already trained" 41 | self.train(final_epoch-self.epoch) 42 | 43 | def train(self, num_epochs=100): 44 | """ The main training loop""" 45 | start_epoch = self.epoch 46 | steps_per_epoch = len(self.dataloaders['train']); step=0 47 | for self.epoch in tqdm(range(start_epoch+1, start_epoch + num_epochs+1),desc='train'): 48 | for i, minibatch in enumerate(self.dataloaders['train']): 49 | step = i + (self.epoch-1)*steps_per_epoch 50 | with self.logger as do_log: 51 | if do_log: self.logStuff(step, minibatch) 52 | self.step(minibatch) 53 | [sched.step(step/steps_per_epoch) for sched in self.lr_schedulers] 54 | self.logStuff(step) 55 | 56 | def step(self, minibatch): 57 | self.optimizer.zero_grad() 58 | loss = self.loss(minibatch) 59 | loss.backward() 60 | self.optimizer.step() 61 | return loss 62 | 63 | def logStuff(self, step, minibatch=None): 64 | metrics = {} 65 | if minibatch is not None and hasattr(self,'loss'): 66 | try: metrics['Minibatch_Loss'] = self.loss(minibatch).cpu().data.numpy() 67 | except (NotImplementedError, TypeError): pass 68 | for loader_name,dloader in self.dataloaders.items(): # Ignore metrics on train 69 | if loader_name=='train' or len(dloader)==0 or loader_name[0]=='_': continue 70 | for metric_name, metric_value in self.metrics(dloader).items(): 71 | metrics[loader_name+'_'+metric_name] = metric_value 72 | self.logger.add_scalars('metrics', metrics, step) 73 | schedules = {} 74 | for i, sched in enumerate(self.lr_schedulers): 75 | schedules['lr{}'.format(i)] = sched.get_lr()[0] 76 | self.logger.add_scalars('schedules', schedules, step) 77 | 78 | for name,m in self.model.named_modules(): 79 | if hasattr(m, 'log_data'): 80 | m.log_data(self.logger,step,name) 81 | self.logger.report() 82 | # update the best checkpoint 83 | if self.early_stop_metric is not None: 84 | maximize = guess_metric_sign(self.early_stop_metric) 85 | sign = 2*maximize-1 86 | best = (sign*self.logger.scalar_frame[self.early_stop_metric].values).max() 87 | current = sign*self.logger.scalar_frame[self.early_stop_metric].iloc[-1] 88 | if current >= best: self.ckpt = copy.deepcopy(self.state_dict()) 89 | else: self.ckpt = copy.deepcopy(self.state_dict()) 90 | 91 | def evalAverageMetrics(self,loader,metrics): 92 | num_total, loss_totals = 0, 0 93 | with Eval(self.model), torch.no_grad(): 94 | for minibatch in loader: 95 | try: mb_size = 
loader.batch_size 96 | except AttributeError: mb_size=1 97 | loss_totals += mb_size*metrics(minibatch) 98 | num_total += mb_size 99 | if num_total==0: raise KeyError("dataloader is empty") 100 | return loss_totals/num_total 101 | 102 | def state_dict(self): 103 | state = { 104 | 'outcome':self.logger.scalar_frame[-1:], 105 | 'epoch':self.epoch, 106 | 'model_state':self.model.state_dict(), 107 | 'optim_state':self.optimizer.state_dict(), 108 | 'logger_state':self.logger.state_dict(), 109 | } 110 | return state 111 | 112 | def load_state_dict(self,state): 113 | self.epoch = state['epoch'] 114 | self.model.load_state_dict(state['model_state']) 115 | self.optimizer.load_state_dict(state['optim_state']) 116 | self.logger.load_state_dict(state['logger_state']) 117 | 118 | def load_checkpoint(self,path=None): 119 | """ Loads the checkpoint from path, if None gets the highest epoch checkpoint""" 120 | if not path: 121 | chkpts = glob.glob(os.path.join(self.logger.log_dirr,'checkpoints/c*.state')) 122 | path = natsorted(chkpts)[-1] # get most recent checkpoint 123 | print(f"loading checkpoint {path}") 124 | with open(path,'rb') as f: 125 | self.load_state_dict(dill.load(f)) 126 | 127 | def save_checkpoint(self): 128 | return self.logger.save_object(self.ckpt,suffix=f'checkpoints/c{self.epoch}.state') 129 | 130 | -------------------------------------------------------------------------------- /oil/datasetup/joint_transforms.py: -------------------------------------------------------------------------------- 1 | # Taken from https://github.com/bfortuner/pytorch_tiramisu/blob/master/datasets/joint_transforms.py 2 | 3 | from __future__ import division 4 | import torch 5 | import math 6 | import random 7 | from PIL import Image, ImageOps 8 | import numpy as np 9 | import numbers 10 | import types 11 | 12 | 13 | class JointScale(object): 14 | """Rescales the input PIL.Image to the given 'size'. 15 | 'size' will be the size of the smaller edge. 16 | For example, if height > width, then image will be 17 | rescaled to (size * height / width, size) 18 | size: size of the smaller edge 19 | interpolation: Default: PIL.Image.BILINEAR 20 | """ 21 | 22 | def __init__(self, size, interpolation=Image.BILINEAR): 23 | self.size = size 24 | self.interpolation = interpolation 25 | 26 | def __call__(self, imgs): 27 | w, h = imgs[0].size 28 | if (w <= h and w == self.size) or (h <= w and h == self.size): 29 | return imgs 30 | if w < h: 31 | ow = self.size 32 | oh = int(self.size * h / w) 33 | return [img.resize((ow, oh), self.interpolation) for img in imgs] 34 | else: 35 | oh = self.size 36 | ow = int(self.size * w / h) 37 | return [img.resize((ow, oh), self.interpolation) for img in imgs] 38 | 39 | 40 | class JointCenterCrop(object): 41 | """Crops the given PIL.Image at the center to have a region of 42 | the given size. 
size can be a tuple (target_height, target_width)
43 |     or an integer, in which case the target will be of a square shape (size, size)
44 |     """
45 | 
46 |     def __init__(self, size):
47 |         if isinstance(size, numbers.Number):
48 |             self.size = (int(size), int(size))
49 |         else:
50 |             self.size = size
51 | 
52 |     def __call__(self, imgs):
53 |         w, h = imgs[0].size
54 |         th, tw = self.size
55 |         x1 = int(round((w - tw) / 2.))
56 |         y1 = int(round((h - th) / 2.))
57 |         return [img.crop((x1, y1, x1 + tw, y1 + th)) for img in imgs]
58 | 
59 | 
60 | class JointPad(object):
61 |     """Pads the given PIL.Image on all sides with the given "pad" value"""
62 | 
63 |     def __init__(self, padding, fill=0):
64 |         assert isinstance(padding, numbers.Number)
65 |         assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
66 |         self.padding = padding
67 |         self.fill = fill
68 | 
69 |     def __call__(self, imgs):
70 |         return [ImageOps.expand(img, border=self.padding, fill=self.fill) for img in imgs]
71 | 
72 | 
73 | class JointLambda(object):
74 |     """Applies a lambda as a transform."""
75 | 
76 |     def __init__(self, lambd):
77 |         assert isinstance(lambd, types.LambdaType)
78 |         self.lambd = lambd
79 | 
80 |     def __call__(self, imgs):
81 |         return [self.lambd(img) for img in imgs]
82 | 
83 | 
84 | class JointRandomCrop(object):
85 |     """Crops the given list of PIL.Image at a random location to have a region of
86 |     the given size. size can be a tuple (target_height, target_width)
87 |     or an integer, in which case the target will be of a square shape (size, size)
88 |     """
89 | 
90 |     def __init__(self, size, padding=0):
91 |         if isinstance(size, numbers.Number):
92 |             self.size = (int(size), int(size))
93 |         else:
94 |             self.size = size
95 |         self.padding = padding
96 | 
97 |     def __call__(self, imgs):
98 |         if self.padding > 0:
99 |             imgs = [ImageOps.expand(img, border=self.padding, fill=0) for img in imgs]
100 | 
101 |         w, h = imgs[0].size
102 |         th, tw = self.size
103 |         if w == tw and h == th:
104 |             return imgs
105 | 
106 |         x1 = random.randint(0, w - tw)
107 |         y1 = random.randint(0, h - th)
108 |         return [img.crop((x1, y1, x1 + tw, y1 + th)) for img in imgs]
109 | 
110 | 
111 | class JointRandomHorizontalFlip(object):
112 |     """Randomly horizontally flips the given list of PIL.Image with a probability of 0.5
113 |     """
114 | 
115 |     def __call__(self, imgs):
116 |         if random.random() < 0.5:
117 |             return [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]
118 |         return imgs
119 | 
120 | 
121 | class JointRandomSizedCrop(object):
122 |     """Randomly crops the given list of PIL.Image to a random size of (0.08 to 1.0) of the original size
123 |     and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
124 |     This is popularly used to train the Inception networks
125 |     size: size of the smaller edge
126 |     interpolation: Default: PIL.Image.BILINEAR
127 |     """
128 | 
129 |     def __init__(self, size, interpolation=Image.BILINEAR):
130 |         self.size = size
131 |         self.interpolation = interpolation
132 | 
133 |     def __call__(self, imgs):
134 |         for attempt in range(10):
135 |             area = imgs[0].size[0] * imgs[0].size[1]
136 |             target_area = random.uniform(0.08, 1.0) * area
137 |             aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 138 | 139 | w = int(round(math.sqrt(target_area * aspect_ratio))) 140 | h = int(round(math.sqrt(target_area / aspect_ratio))) 141 | 142 | if random.random() < 0.5: 143 | w, h = h, w 144 | 145 | if w <= imgs[0].size[0] and h <= imgs[0].size[1]: 146 | x1 = random.randint(0, imgs[0].size[0] - w) 147 | y1 = random.randint(0, imgs[0].size[1] - h) 148 | 149 | imgs = [img.crop((x1, y1, x1 + w, y1 + h)) for img in imgs] 150 | assert(imgs[0].size == (w, h)) 151 | 152 | return [img.resize((self.size, self.size), self.interpolation) for img in imgs] 153 | 154 | # Fallback 155 | scale = JointScale(self.size, interpolation=self.interpolation) 156 | crop = JointCenterCrop(self.size) 157 | return crop(scale(imgs)) -------------------------------------------------------------------------------- /oil/architectures/img_gen/conditionalgan.py: -------------------------------------------------------------------------------- 1 | # ResNet generator and discriminator 2 | from torch import nn 3 | import torch 4 | import torch.nn.functional as F 5 | import torchcontrib 6 | import torchcontrib.nn.functional as contrib 7 | import numpy as np 8 | from ...utils.utils import Expression,export,Named 9 | from .ganBase import GanBase, add_spectral_norm, xavier_uniform_init 10 | 11 | # Conditional Resnet GAN and Discriminator with Spectral normalization and Projection Discriminator 12 | # Implementation of architectures used in SNGAN (https://arxiv.org/abs/1802.05957) 13 | # With class conditional enhancement (FiLM and Projection Discriminator) from (https://arxiv.org/abs/1802.05637) 14 | 15 | class CategoricalFiLM(nn.Module): 16 | def __init__(self,num_classes,channels): 17 | super().__init__() 18 | self.gammas = nn.Embedding(num_classes,channels) 19 | self.betas = nn.Embedding(num_classes,channels) 20 | def forward(self,x,y): 21 | return contrib.film(x,self.gammas(y),self.betas(y)) 22 | 23 | class Generator(GanBase): 24 | def __init__(self, num_classes,z_dim=128,img_channels=3,k=256): 25 | super().__init__(z_dim,img_channels) 26 | self.num_classes = num_classes 27 | self.k = k 28 | self.linear1 = nn.Linear(z_dim, 4 * 4 * k) 29 | self.res1 = cResBlockGenerator(k,k,num_classes,stride=2) 30 | self.res2 = cResBlockGenerator(k,k,num_classes,stride=2) 31 | self.res3 = cResBlockGenerator(k,k,num_classes,stride=2) 32 | self.final = nn.Sequential( 33 | nn.BatchNorm2d(k), 34 | nn.ReLU(), 35 | nn.Conv2d(k, img_channels, 3, stride=1, padding=1), 36 | nn.Tanh()) 37 | self.apply(xavier_uniform_init) 38 | 39 | def forward(self,y,z=None): 40 | if z is None: z = self.sample_z(y.shape[0]) 41 | z = self.linear1(z).view(-1,self.k,4,4) 42 | z = self.res1(z,y) 43 | z = self.res2(z,y) 44 | z = self.res3(z,y) 45 | return self.final(z) 46 | 47 | def sample_y(self,n=1): 48 | return (torch.LongTensor(n).random_()%self.num_classes).to(self.device) 49 | 50 | def sample(self, n=1): 51 | return self(self.sample_y(n),self.sample_z(n))[:,:3] 52 | 53 | 54 | class Discriminator(nn.Module,metaclass=Named): 55 | def __init__(self,num_classes,img_channels=3,k=128,out_size=1): 56 | super().__init__() 57 | self.num_classes = num_classes 58 | self.img_channels = img_channels 59 | self.k = k 60 | self.phi = nn.Sequential( 61 | FirstResBlockDiscriminator(img_channels, k, stride=2), 62 | ResBlockDiscriminator(k, k, stride=2), 63 | ResBlockDiscriminator(k, k), 64 | ResBlockDiscriminator(k, k), 65 | nn.ReLU(), 66 | Expression(lambda u: u.mean(-1).mean(-1)), 67 | ) 68 | self.psi = nn.Linear(k, out_size) 69 | self.apply(add_spectral_norm) 70 | self.label_embedding 
= nn.Embedding(num_classes,k) 71 | self.apply(xavier_uniform_init) 72 | 73 | def forward(self, x, y): 74 | embedded_labels = self.label_embedding(y) 75 | phi = self.phi(x) 76 | return self.psi(phi) + (embedded_labels*phi).sum(-1) 77 | 78 | class cResBlockGenerator(nn.Module): 79 | 80 | def __init__(self, in_channels, out_channels, num_classes, stride=1): 81 | super().__init__() 82 | self.stride = stride 83 | self.bn1 = nn.BatchNorm2d(in_channels) 84 | self.film1 = CategoricalFiLM(num_classes,in_channels) # should it be shared? 85 | self.relu1 = nn.ReLU() 86 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 87 | self.bn2 = nn.BatchNorm2d(out_channels) 88 | self.film2 = CategoricalFiLM(num_classes,out_channels) 89 | self.relu2 = nn.ReLU() 90 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 91 | if stride!=1 or in_channels!=out_channels: 92 | self.shortcut = nn.Sequential( 93 | nn.Upsample(scale_factor=stride,mode='bilinear'), 94 | nn.Conv2d(in_channels,out_channels,3, 1, padding=1)) 95 | else: 96 | self.shortcut = nn.Sequential() 97 | 98 | def forward(self, x,y): 99 | z = x 100 | z = self.relu1(self.film1(self.bn1(z),y)) 101 | z = self.conv1(F.interpolate(z,scale_factor=self.stride)) 102 | z = self.conv2(self.relu2(self.film2(self.bn2(z),y))) 103 | return z + self.shortcut(x) 104 | 105 | 106 | class ResBlockDiscriminator(nn.Module): 107 | 108 | def __init__(self, in_channels, out_channels, stride=1): 109 | super(ResBlockDiscriminator, self).__init__() 110 | modules = [nn.ReLU(), 111 | nn.Conv2d(in_channels, out_channels, 3, 1, padding=1), 112 | nn.ReLU(), 113 | nn.Conv2d(out_channels, out_channels, 3, 1, padding=1)] 114 | bypass = [] 115 | if stride!=1: 116 | modules += [nn.AvgPool2d(2, stride=stride, padding=0)] 117 | bypass += [nn.Conv2d(in_channels,out_channels, 1, 1, padding=0), 118 | nn.AvgPool2d(2, stride=stride, padding=0)] 119 | self.model = nn.Sequential(*modules) 120 | self.bypass = nn.Sequential(*bypass) 121 | def forward(self, x): 122 | return self.model(x) + self.bypass(x) 123 | 124 | # special ResBlock just for the first layer of the discriminator 125 | class FirstResBlockDiscriminator(nn.Module): 126 | 127 | def __init__(self, in_channels, out_channels, stride=1): 128 | super(FirstResBlockDiscriminator, self).__init__() 129 | 130 | # we don't want to apply ReLU activation to raw image before convolution transformation. 
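        # (the inputs are roughly zero-centered -- tanh generator outputs and
        #  gan-normalized real images both lie in [-1,1] -- so a leading ReLU
        #  would zero out about half of the raw signal before the first conv)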
131 | self.model = nn.Sequential( 132 | nn.Conv2d(in_channels, out_channels, 3, 1, padding=1), 133 | nn.ReLU(), 134 | nn.Conv2d(out_channels, out_channels, 3, 1, padding=1), 135 | nn.AvgPool2d(2)) 136 | self.bypass = nn.Sequential( 137 | nn.AvgPool2d(2), 138 | nn.Conv2d(in_channels, out_channels, 1, 1, padding=0)) 139 | 140 | def forward(self, x): 141 | return self.model(x) + self.bypass(x) 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /oil/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | import subprocess 6 | from torchvision.models.inception import inception_v3 7 | from scipy.stats import entropy 8 | from scipy.linalg import norm,sqrtm 9 | from ..utils.utils import Eval, Expression 10 | #from torch.nn.functional import adaptive_avg_pool2d 11 | #from .pytorch-fid.fid_score import calculate_frechet_distance 12 | #from .pytorch-fid.inception import InceptionV3 13 | 14 | # GAN Metrics 15 | 16 | # TODO: cache logits for existing datasets 17 | # should be possible if we can serialize dataloaders 18 | def get_inception(): 19 | """ grabs the pytorch pretrained inception_v3 with resized inputs """ 20 | inception = inception_v3(pretrained=True,transform_input=False) 21 | upsample = Expression(lambda x: nn.functional.interpolate(x,size=(299,299),mode='bilinear')) 22 | model = nn.Sequential(upsample,inception).cuda().eval() 23 | return model 24 | 25 | def get_logits(model,loader): 26 | """ Extracts logits from a model, dataloader returns a numpy array of size (N, K) 27 | where K is the number of classes """ 28 | with torch.no_grad(), Eval(model): 29 | model_logits = lambda mb: model(mb).cpu().data.numpy() 30 | logits = np.concatenate([model_logits(minibatch) for minibatch in loader],axis=0) 31 | return logits 32 | 33 | def FID_from_logits(logits1,logits2): 34 | """Computes the FID between logits1 and logits2 35 | Inputs: [logits1 (N,C)] [logits2 (N,C)] """ 36 | mu1 = np.mean(logits1,axis=0) 37 | mu2 = np.mean(logits2,axis=0) 38 | sigma1 = np.cov(logits1, rowvar=False) 39 | sigma2 = np.cov(logits2, rowvar=False) 40 | 41 | tr = np.trace(sigma1 + sigma2 - 2*sqrtm(sigma1@sigma2)) 42 | distance = norm(mu1-mu2)**2 + tr 43 | return distance 44 | 45 | def IS_from_logits(logits): 46 | """ Computes the Inception score (IS) from logits of the dataset of size N with C classes. 47 | Inputs: [logits (N,C)], Outputs: [IS (scalar)]""" 48 | # E_z[KL(Pyz||Py)] = \mean_z [\sum_y (Pyz log(Pyz) - Pyz log(Py))] 49 | Pyz = np.exp(logits).transpose() # Take softmax (up to a normalization constant) 50 | Pyz /= Pyz.sum(0)[None,:] # divide by normalization constant 51 | Py = np.broadcast_to(Pyz.mean(-1)[:,None],Pyz.shape) # Average over z 52 | logIS = entropy(Pyz,Py).mean() # Average over z 53 | return np.exp(logIS) 54 | 55 | cachedLogits = {} 56 | def FID(loader1,loader2): 57 | """ Computes the Frechet Inception Distance (FID) between the two image dataloaders 58 | using pytorch pretrained inception_v3. 
Requires >2048 imgs for comparison 59 | Dataloader should be an iterable of minibatched images, assumed to already 60 | be normalized with mean 0, std 1 (per color) 61 | """ 62 | model = get_inception() 63 | logits1 = get_logits(model,loader1) 64 | if loader2 not in cachedLogits: 65 | cachedLogits[loader2] = get_logits(model,loader2) 66 | logits2 = cachedLogits[loader2] 67 | return FID_from_logits(logits1,logits2) 68 | 69 | def IS(loader): 70 | """Computes the Inception score of a dataloader using pytorch pretrained inception_v3""" 71 | model = get_inception() 72 | logits = get_logits(model,loader) 73 | return IS_from_logits(logits) 74 | 75 | 76 | def FID_and_IS(loader1,loader2): 77 | """Computes FID and IS score for loader1 against target loader2 """ 78 | model = get_inception() 79 | logits1 = get_logits(model,loader1) 80 | if loader2 not in cachedLogits: 81 | cachedLogits[loader2] = get_logits(model,loader2) 82 | logits2 = cachedLogits[loader2] 83 | return FID_from_logits(logits1,logits2),IS_from_logits(logits1) 84 | 85 | #TODO: Implement Kernel Inception Distance (KID) from (https://openreview.net/pdf?id=r1lUOzWCW) 86 | 87 | 88 | def get_official_FID(loader,dataset='cifar10'): 89 | #TODO: make function not ass and check that it still works 90 | dir = os.path.expanduser("~/olive-oil-ml/oil/utils/") 91 | path = dir+"temp" 92 | loader.write_imgs(path) 93 | if dataset not in ('cifar10',): 94 | raise NotImplementedError 95 | score = subprocess.check_output(dir+"TTUR/fid.py "+dir+"fid_stats_{}_train.npz {}.npy" 96 | .format(dataset,path),shell=True) 97 | return score 98 | 99 | 100 | # Semantic Segmentation Metrics 101 | # Adapted from https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py 102 | def confusion_from_logits(logits,y_gt): 103 | bs, num_classes, _, _ = logits.shape 104 | pred_image = logits.max(1)[1].type_as(y_gt).cpu().data.numpy() 105 | #print(pred_image) 106 | gt_image = y_gt.cpu().data.numpy() 107 | return confusion_matrix(pred_image,gt_image,num_classes) 108 | 109 | def confusion_matrix(pred_image,gt_image,num_classes): 110 | """Computes the confusion matrix from two numpy class images (integer values) 111 | ignoring classes that are negative""" 112 | mask = (gt_image >= 0) & (gt_image < num_classes) 113 | #print(gt_image[mask]) 114 | label = num_classes * gt_image[mask].astype(int) + pred_image[mask] 115 | count = np.bincount(label, minlength=num_classes**2) 116 | confusion_matrix = count.reshape(num_classes, num_classes) 117 | return confusion_matrix # confusing shape, maybe transpose 118 | 119 | def meanIoU(confusion_matrix): 120 | MIoU = np.diag(confusion_matrix) / ( 121 | np.sum(confusion_matrix, axis=1) + np.sum(confusion_matrix, axis=0) - 122 | np.diag(confusion_matrix)) 123 | MIoU = np.nanmean(MIoU) 124 | return MIoU 125 | 126 | def freqIoU(confusion_matrix): 127 | freq = np.sum(confusion_matrix, axis=1) / np.sum(confusion_matrix) 128 | iu = np.diag(confusion_matrix) / ( 129 | np.sum(confusion_matrix, axis=1) + np.sum(confusion_matrix, axis=0) - 130 | np.diag(confusion_matrix)) 131 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 132 | return FWIoU 133 | 134 | def pixelAcc(confusion_matrix): 135 | return np.diag(confusion_matrix).sum() / confusion_matrix.sum() 136 | 137 | def meanAcc(confusion_matrix): 138 | return np.nanmean(np.diag(confusion_matrix) / np.sum(confusion_matrix, axis=1)) 139 | 140 | 141 | # def boundary_mIoU(confusion_matrix,epsilon) -------------------------------------------------------------------------------- 
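A minimal usage sketch for the GAN metrics above (not part of the repo; `g_loader` and
`real_loader` are hypothetical dataloaders yielding minibatches of images already
normalized to mean 0, std 1, as the FID docstring assumes):

    from oil.utils.metrics import FID, IS, FID_and_IS

    fid = FID(g_loader, real_loader)        # logits of real_loader are cached across calls
    inception_score = IS(g_loader)          # exp(E[KL(p(y|x) || p(y))]) under inception_v3
    fid2, is2 = FID_and_IS(g_loader, real_loader)  # one inception pass over g_loader for both scores

--------------------------------------------------------------------------------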
/oil/datasetup/dataloaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ..utils.utils import FixedNumpySeed 3 | from torch.utils.data import DataLoader 4 | from torch.utils.data.sampler import Sampler, SubsetRandomSampler 5 | 6 | 7 | # def getUnlabLoader(trainset, ul_BS, **kwargs): 8 | # """ Returns a dataloader for the full dataset, with cyclic reshuffling """ 9 | # indices = np.arange(len(trainset)) 10 | # unlabSampler = ShuffleCycleSubsetSampler(indices) 11 | # unlabLoader = DataLoader(trainset,sampler=unlabSampler,batch_size=ul_BS,**kwargs) 12 | # return unlabLoader 13 | 14 | def getLabLoader(trainset, lab_BS, amnt_lab=1, amnt_dev=0, dataseed=0, balanced=True,**kwargs): 15 | """ returns a dataloader of class balanced subset of the full dataset, 16 | and a (possibly empty) dataloader reserved for devset 17 | amntLabeled and amntDev can be a fraction or an integer. 18 | If fraction amntLabeled specifies fraction of entire dataset to 19 | use as labeled, whereas fraction amntDev is fraction of labeled 20 | dataset to reserve as a devset """ 21 | numLabeled = amnt_lab 22 | if amnt_lab <= 1: 23 | numLabeled *= len(trainset) 24 | numDev = amnt_dev 25 | if amnt_dev <= 1: 26 | numDev *= numLabeled 27 | with FixedNumpySeed(dataseed): 28 | get_indices = classBalancedSampleIndices if (balanced and trainset.balanced) else randomSampleIndices 29 | labIndices, devIndices = get_indices(trainset, int(numLabeled), int(numDev)) 30 | 31 | labSampler = SubsetRandomSampler(labIndices) 32 | labLoader = DataLoader(trainset,sampler=labSampler,batch_size=lab_BS,**kwargs) 33 | if numLabeled == 0: labLoader = EmptyLoader() 34 | 35 | devSampler = SequentialSubsetSampler(devIndices) # No shuffling on dev 36 | devLoader = DataLoader(trainset,sampler=devSampler,batch_size=lab_BS) 37 | return labLoader, devLoader 38 | 39 | def classBalancedSampleIndices(trainset, numLabeled, numDev): 40 | """ Generates a subset of indices of y (of size numLabeled) so that 41 | each class is equally represented """ 42 | y = np.array([target for img,target in trainset]) 43 | uniqueVals = np.unique(y) 44 | numDev = (numDev // len(uniqueVals))*len(uniqueVals) 45 | numLabeled = ((numLabeled-numDev)// len(uniqueVals))*len(uniqueVals) 46 | 47 | classIndices = [np.where(y==val) for val in uniqueVals] 48 | devIndices = np.empty(numDev, dtype=np.int64) 49 | labIndices = np.empty(numLabeled, dtype=np.int64) 50 | 51 | dev_m = numDev // len(uniqueVals) 52 | lab_m = numLabeled // len(uniqueVals); assert lab_m>0, "Note: dev is subtracted from train" 53 | total_m = lab_m + dev_m 54 | for i in range(len(uniqueVals)): 55 | sampledclassIndices = np.random.choice(classIndices[i][0],total_m,replace=False) 56 | labIndices[i*lab_m:i*lab_m+lab_m] = sampledclassIndices[:lab_m] 57 | devIndices[i*dev_m:i*dev_m+dev_m] = sampledclassIndices[lab_m:] 58 | 59 | print("Creating Train, Dev split \ 60 | with {} Train and {} Dev".format(numLabeled, numDev)) 61 | return labIndices, devIndices 62 | 63 | def randomSampleIndices(trainset,numLabeled,numDev): 64 | numLabeled = (numLabeled-numDev) 65 | indices = np.random.permutation(len(trainset)) 66 | return indices[:numLabeled],indices[numLabeled:numLabeled+numDev] 67 | 68 | #TODO: Needs some rework to function properly in the semisupervised case 69 | 70 | # class ShuffleCycleSubsetSampler(Sampler): 71 | # """A cycle version of SubsetRandomSampler with 72 | # reordering on calls to __iter__ 73 | # contains current permutation & index as a state """ 74 
| # def __init__(self, indices): 75 | # self.indices = indices 76 | 77 | # def __iter__(self): 78 | # return self._gen() 79 | 80 | # def _gen(self): 81 | # i = len(self.indices) 82 | # while True: 83 | # if i >= len(self.indices): 84 | # perm = np.random.permutation(self.indices) 85 | # i=0 86 | # yield perm[i] 87 | # i+=1 88 | 89 | # def __len__(self): 90 | # return len(self.indices) 91 | 92 | class SequentialSubsetSampler(Sampler): 93 | """Samples sequentially from specified indices, does not cycle """ 94 | def __init__(self, indices): 95 | self.indices = indices 96 | def __iter__(self): 97 | return iter(self.indices) 98 | def __len__(self): 99 | return len(self.indices) 100 | 101 | class EmptyLoader(object): 102 | """A dataloader that loads None tuples, with zero length for convenience""" 103 | def __next__(self): 104 | return (None,None) 105 | def __len__(self): 106 | return 0 107 | def __iter__(self): 108 | return self 109 | 110 | 111 | # def getUandLloaders(trainset, amntLabeled, lab_BS, ul_BS, **kwargs): 112 | # """ Returns two cycling dataloaders where the first one only operates on a subset 113 | # of the dataset. AmntLabeled can either be a fraction or an integer """ 114 | # numLabeled = amntLabeled 115 | # if amntLabeled <= 1: 116 | # numLabeled *= len(trainset) 117 | 118 | # indices = np.random.permutation(len(trainset)) 119 | # labIndices = indices[:numLabeled] 120 | 121 | # labSampler = ShuffleCycleSubsetSampler(labIndices) 122 | # labLoader = DataLoader(trainset,sampler=labSampler,batch_size=lab_BS,**kwargs) 123 | # if amntLabeled == 0: labLoader = EmptyLoader() 124 | 125 | # # Includes the labeled samples in the unlabeled data 126 | # unlabSampler = ShuffleCycleSubsetSampler(indices) 127 | # unlabLoader = DataLoader(trainset,sampler=unlabSampler,batch_size=ul_BS,**kwargs) 128 | 129 | # return unlabLoader, labLoader 130 | 131 | # def getLoadersBalanced(trainset, amntLabeled, lab_BS, ul_BS, **kwargs): 132 | # """ Variant of getUandLloaders""" 133 | # numLabeled = amntLabeled 134 | # if amntLabeled <= 1: 135 | # numLabeled *= len(trainset) 136 | 137 | # indices = np.random.permutation(len(trainset)) 138 | # labIndices = classBalancedSampleIndices(trainset, numLabeled) 139 | 140 | # labSampler = ShuffleCycleSubsetSampler(labIndices) 141 | # labLoader = DataLoader(trainset,sampler=labSampler,batch_size=lab_BS,**kwargs) 142 | # if amntLabeled == 0: labLoader = EmptyLoader() 143 | 144 | # # Includes the labeled samples in the unlabeled data 145 | # unlabSampler = ShuffleCycleSubsetSampler(indices) 146 | # unlabLoader = DataLoader(trainset,sampler=unlabSampler,batch_size=ul_BS,**kwargs) 147 | 148 | # return unlabLoader, labLoader 149 | -------------------------------------------------------------------------------- /oil/tuning/configGenerator.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import numbers 4 | import random 5 | from ..utils.utils import log_uniform,ReadOnlyDict,FixedNumpySeed 6 | from collections import defaultdict 7 | from collections.abc import Iterable 8 | 9 | import itertools,operator,functools 10 | # class SearchVariation(object): 11 | # def __init__(self,sample_func): 12 | # self.sample_func = sample_func 13 | # def sample(self,config): 14 | # out = self.sample_func(config) 15 | # if isinstance(out,SearchVariation): raise KeyError 16 | # return out 17 | 18 | 19 | # def sampleFrom(func): 20 | # return SearchVariation(func) 21 | def logUniform(low,high): 22 | return lambda _: 
log_uniform(low,high)
23 | def uniform(low,high):
24 |     return lambda _: np.random.uniform(low,high)
25 | 
26 | class NoGetItLambdaDict(dict):
27 |     """ Regular dict, but refuses to __getitem__, pretending
28 |     the element is not there, and throws a LookupError
29 |     if the value is a non-string iterable or a lambda """
30 |     def __init__(self,d={}):
31 |         super().__init__()
32 |         for k,v in d.items():
33 |             if isinstance(v,dict):
34 |                 self[k] = NoGetItLambdaDict(v)
35 |             else:
36 |                 self[k] = v
37 |     def __getitem__(self, key):
38 |         value = super().__getitem__(key)
39 |         if callable(value) and value.__name__ == "<lambda>":
40 |             raise LookupError("You shouldn't try to retrieve lambda {} from this dict".format(value))
41 |         if isinstance(value,Iterable) and not isinstance(value,(str,bytes,dict,tuple)):
42 |             raise LookupError("You shouldn't try to retrieve iterable {} from this dict".format(value))
43 |         return value
44 | 
45 |     # pop = __readonly__
46 |     # popitem = __readonly__
47 | 
48 | def sample_config(config_spec):
49 |     """ Generates configs from the config spec.
50 |     It will apply lambdas that depend on the config and sample from any
51 |     iterables, making sure that no elements in the generated config remain
52 |     iterables or lambdas; strings are allowed."""
53 |     cfg_all = config_spec
54 |     more_work=True
55 |     i=0
56 |     while more_work:
57 |         cfg_all, more_work = _sample_config(cfg_all,NoGetItLambdaDict(cfg_all))
58 |         i+=1
59 |         if i>10:
60 |             raise RecursionError("config dependency unresolvable with {}".format(cfg_all))
61 |     out = defaultdict(dict)
62 |     out.update(cfg_all)
63 |     return out
64 | 
65 | def _sample_config(config_spec,cfg_all):
66 |     cfg = {}
67 |     more_work = False
68 |     for k,v in config_spec.items():
69 |         if isinstance(v,dict):
70 |             new_dict,extra_work = _sample_config(v,cfg_all)
71 |             cfg[k] = new_dict
72 |             more_work |= extra_work
73 |         elif isinstance(v,Iterable) and not isinstance(v,(str,bytes,dict,tuple)):
74 |             cfg[k] = random.choice(v)
75 |         elif callable(v) and v.__name__ == "<lambda>":
76 |             try: cfg[k] = v(cfg_all)
77 |             except (KeyError, LookupError, Exception):
78 |                 cfg[k] = v # the lambda is kept instead of the value it returns, resolved on a later pass
79 |                 more_work = True
80 |         else: cfg[k] = v
81 |     return cfg, more_work
82 | 
83 | def flatten(d, parent_key='', sep='/'):
84 |     """An invertible dictionary flattening operation that does not clobber objs"""
85 |     items = []
86 |     for k, v in d.items():
87 |         new_key = parent_key + sep + k if parent_key else k
88 |         if isinstance(v, dict) and v: # non-empty dict
89 |             items.extend(flatten(v, new_key, sep=sep).items())
90 |         else:
91 |             items.append((new_key, v))
92 |     return dict(items)
93 | 
94 | def unflatten(d,sep='/'):
95 |     """Take a dictionary with keys {'k1/k2/k3':v} to {'k1':{'k2':{'k3':v}}}
96 |     as outputted by flatten """
97 |     out_dict={}
98 |     for k,v in d.items():
99 |         if isinstance(k,str):
100 |             keys = k.split(sep)
101 |             dict_to_modify = out_dict
102 |             for partial_key in keys[:-1]:
103 |                 try: dict_to_modify = dict_to_modify[partial_key]
104 |                 except KeyError:
105 |                     dict_to_modify[partial_key] = {}
106 |                     dict_to_modify = dict_to_modify[partial_key]
107 |             # Base level reached
108 |             if keys[-1] in dict_to_modify:
109 |                 dict_to_modify[keys[-1]].update(v)
110 |             else:
111 |                 dict_to_modify[keys[-1]] = v
112 |         else: out_dict[k]=v
113 |     return out_dict
114 | 
115 | class grid_iter(object):
116 |     """ Defines a length which corresponds to one full pass through the grid
117 |     defined by grid variables in config_spec, but the iterator will continue iterating
118 |     past that by repeating over the grid 
variables""" 119 | def __init__(self,config_spec,num_elements=-1,shuffle=True): 120 | self.cfg_flat = flatten(config_spec) 121 | is_grid_iterable = lambda v: (isinstance(v,Iterable) and not isinstance(v,(str,bytes,dict,tuple))) 122 | iterables = sorted({k:v for k,v in self.cfg_flat.items() if is_grid_iterable(v)}.items()) 123 | if iterables: self.iter_keys,self.iter_vals = zip(*iterables) 124 | else: self.iter_keys,self.iter_vals = [],[[]] 125 | self.vals = list(itertools.product(*self.iter_vals)) 126 | if shuffle: 127 | with FixedNumpySeed(0): random.shuffle(self.vals) 128 | self.num_elements = num_elements if num_elements>=0 else (-1*num_elements)*len(self) 129 | 130 | def __iter__(self): 131 | self.i=0 132 | self.vals_iter = iter(self.vals) 133 | return self 134 | def __next__(self): 135 | self.i+=1 136 | if self.i > self.num_elements: raise StopIteration 137 | if not self.vals: v = [] 138 | else: 139 | try: v = next(self.vals_iter) 140 | except StopIteration: 141 | self.vals_iter = iter(self.vals) 142 | v = next(self.vals_iter) 143 | chosen_iter_params = dict(zip(self.iter_keys,v)) 144 | self.cfg_flat.update(chosen_iter_params) 145 | return sample_config(unflatten(self.cfg_flat)) 146 | def __len__(self): 147 | product = functools.partial(functools.reduce, operator.mul) 148 | return product(len(v) for v in self.iter_vals) if self.vals else 1 149 | 150 | def flatten_dict(d): 151 | """ Flattens a dictionary, ignoring outer keys. Only 152 | numbers and strings allowed, others will be converted 153 | to a string. """ 154 | out = {} 155 | for k,v in d.items(): 156 | if isinstance(v,dict): 157 | out.update(flatten_dict(v)) 158 | elif isinstance(v,(numbers.Number,str,bytes)): 159 | out[k] = v 160 | else: 161 | out[k] = str(v) 162 | return out 163 | -------------------------------------------------------------------------------- /oil/utils/parallel.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import torch 3 | import warnings 4 | import os 5 | from itertools import chain 6 | import torch.nn as nn 7 | from torch.nn.modules import Module 8 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 9 | from torch.nn.parallel.replicate import replicate 10 | from torch.nn.parallel.parallel_apply import parallel_apply 11 | from torch.cuda._utils import _get_device_index 12 | 13 | 14 | def try_multigpu_parallelize(model,bs,lr=None): 15 | scalelr = (lr is not None) 16 | if os.environ.copy().get("WORLD_SIZE",0)!=0: 17 | assert torch.cuda.is_available(), "No GPUs found" 18 | ngpus = torch.cuda.device_count() # For Adam, only the bs is scaled up 19 | print(f"Discovered and training with {ngpus} GPUs, bs ->\ 20 | {ngpus}*bs{f', lr -> {ngpus}*lr' if scalelr else ''}.") 21 | torch.distributed.init_process_group(backend="nccl") 22 | DDP_model = nn.parallel.DistributedDataParallel(model)#,find_unused_parameters=True) #for 1.0.0 23 | return (DDP_model, bs*ngpus, lr*ngpus) if scalelr else (DDP_model, bs*ngpus) 24 | else: 25 | return (model, bs, lr) if scalelr else (model,bs) 26 | 27 | def _check_balance(device_ids): 28 | imbalance_warn = """ 29 | There is an imbalance between your GPUs. You may want to exclude GPU {} which 30 | has less than 75% of the memory or cores of GPU {}. 
You can do so by setting 31 | the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES 32 | environment variable.""" 33 | device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) 34 | dev_props = [torch.cuda.get_device_properties(i) for i in device_ids] 35 | 36 | def warn_imbalance(get_prop): 37 | values = [get_prop(props) for props in dev_props] 38 | min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) 39 | max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) 40 | if min_val / max_val < 0.75: 41 | warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos])) 42 | return True 43 | return False 44 | 45 | if warn_imbalance(lambda props: props.total_memory): 46 | return 47 | if warn_imbalance(lambda props: props.multi_processor_count): 48 | return 49 | 50 | class MyDataParallel(nn.DataParallel): 51 | 52 | def __getattr__(self, name): 53 | try: 54 | return super().__getattr__(name) 55 | except AttributeError: 56 | attr = getattr(self.module, name) 57 | if callable(attr): 58 | funcname = name 59 | def parallel_closure(*inputs,**kwargs): 60 | if not self.device_ids: 61 | return self.module(*inputs, **kwargs) 62 | 63 | for t in chain(self.module.parameters(), self.module.buffers()): 64 | if t.device != self.src_device_obj: 65 | raise RuntimeError("module must have its parameters and buffers " 66 | "on device {} (device_ids[0]) but found one of " 67 | "them on device: {}".format(self.src_device_obj, t.device)) 68 | 69 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 70 | if len(self.device_ids) == 1: 71 | return getattr(self.module,funcname)(*inputs[0], **kwargs[0]) 72 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 73 | outputs = self.parallel_apply([getattr(module,funcname) for module in replicas], inputs, kwargs) 74 | return self.gather(outputs, self.output_device) 75 | return parallel_closure 76 | else: 77 | return attr 78 | 79 | import torch.cuda.comm 80 | import torch.distributed as dist 81 | 82 | if dist.is_available(): 83 | from torch.distributed.distributed_c10d import _get_default_group 84 | 85 | from torch.cuda._utils import _get_device_index 86 | 87 | 88 | def _find_tensors(obj): 89 | r""" 90 | Recursively find all tensors contained in the specified object. 
91 | """ 92 | if isinstance(obj, torch.Tensor): 93 | return [obj] 94 | if isinstance(obj, (list, tuple)): 95 | return itertools.chain(*map(_find_tensors, obj)) 96 | if isinstance(obj, dict): 97 | return itertools.chain(*map(_find_tensors, obj.values())) 98 | return [] 99 | class MyDistributedDataParallel(nn.parallel.DistributedDataParallel): 100 | 101 | def __getattr__(self, name): 102 | try: 103 | return super().__getattr__(name) 104 | except AttributeError: 105 | attr = getattr(self.module, name) 106 | if callable(attr): 107 | #print("got to callable function") 108 | funcname = name 109 | def parallel_closure(*inputs,**kwargs): 110 | if self.require_forward_param_sync: 111 | self._sync_params() 112 | 113 | if self.device_ids: 114 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 115 | if len(self.device_ids) == 1: 116 | output = getattr(self.module,funcname)(*inputs[0], **kwargs[0]) 117 | else: 118 | outputs = self.parallel_apply([getattr(module,funcname) for module in self._module_copies[:len(inputs)]], inputs, kwargs) 119 | output = self.gather(outputs, self.output_device) 120 | else: 121 | output = getattr(self.module,funcname)(*inputs, **kwargs) 122 | 123 | if torch.is_grad_enabled() and self.require_backward_grad_sync: 124 | #print("grad enabled, got here") 125 | self.require_forward_param_sync = True 126 | # We'll return the output object verbatim since it is a freeform 127 | # object. We need to find any tensors in this object, though, 128 | # because we need to figure out which parameters were used during 129 | # this forward pass, to ensure we short circuit reduction for any 130 | # unused parameters. Only if `find_unused_parameters` is set. 131 | if self.find_unused_parameters: 132 | self.reducer.prepare_for_backward(list(_find_tensors(output))) 133 | else: 134 | self.reducer.prepare_for_backward([]) 135 | else: 136 | self.require_forward_param_sync = False 137 | 138 | return output 139 | return parallel_closure 140 | else: 141 | return attr 142 | -------------------------------------------------------------------------------- /oil/datasetup/datasets.py: -------------------------------------------------------------------------------- 1 | import torch, torchvision 2 | import torchvision.transforms as transforms 3 | from torch.utils.data import DataLoader, Dataset 4 | import torchvision.datasets as ds 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from sklearn.model_selection import train_test_split 9 | from . 
import augLayers 10 | from ..utils.utils import Named, export, Wrapper 11 | 12 | class EasyIMGDataset(Dataset): 13 | ignored_index = -100 14 | class_weights = None 15 | balanced = True 16 | stratify = True 17 | def __init__(self,*args,gan_normalize=False,download=True,**kwargs): 18 | transform = kwargs.pop('transform',None) 19 | if not transform: transform = self.default_transform(gan_normalize) 20 | super().__init__(*args,transform=transform,download=download,**kwargs) 21 | 22 | def default_transform(self,gan_normalize=False): 23 | if gan_normalize: 24 | normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 25 | else: 26 | normalize = transforms.Normalize(self.means, self.stds) 27 | transform = transforms.Compose([transforms.ToTensor(),normalize]) 28 | return transform 29 | # def compute_default_transform(self): 30 | # raise NotImplementedError 31 | def default_aug_layers(self): 32 | return nn.Sequential() 33 | 34 | # class InMemoryDataset(EasyIMGDataset): 35 | # def __init__(self,*args,**kwargs): 36 | # super().__init__(*args,**kwargs) 37 | # self.data = F.to_tensor(self.data) 38 | # def to(self,device): 39 | # self.data.to(device) 40 | # self.targets.to(device) 41 | # return self 42 | 43 | @export 44 | class CIFAR10(EasyIMGDataset,ds.CIFAR10): 45 | means = (0.4914, 0.4822, 0.4465) 46 | stds = (.247,.243,.261) 47 | num_targets=10 48 | def default_aug_layers(self): 49 | return nn.Sequential( 50 | augLayers.RandomTranslate(4), 51 | augLayers.RandomHorizontalFlip(), 52 | ) 53 | @export 54 | class CIFAR100(EasyIMGDataset,ds.CIFAR100): 55 | means = (0.5071, 0.4867, 0.4408) 56 | stds = (0.2675, 0.2565, 0.2761) 57 | num_targets=100 58 | def default_aug_layers(self): 59 | return nn.Sequential( 60 | augLayers.RandomTranslate(4), 61 | augLayers.RandomHorizontalFlip(), 62 | ) 63 | @export 64 | class SVHN(EasyIMGDataset,ds.SVHN): 65 | #TODO: Find real mean and std 66 | means = (0.5, 0.5, 0.5) 67 | stds = (0.25, 0.25, 0.25) 68 | num_targets=10 69 | def default_aug_layers(self): 70 | return nn.Sequential( 71 | augLayers.RandomTranslate(4), 72 | augLayers.RandomHorizontalFlip(), 73 | ) 74 | 75 | class IndexedDataset(Wrapper): 76 | def __init__(self,dataset,ids): 77 | super().__init__(dataset) 78 | self._ids = ids 79 | def __len__(self): 80 | return len(self._ids) 81 | def __getitem__(self,i): 82 | return super().__getitem__(self._ids[i]) 83 | 84 | @export 85 | def split_dataset(dataset,splits): 86 | """ Inputs: A torchvision.dataset DATASET and a dictionary SPLITS 87 | containing fractions or number of elements for each of the new datasets. 88 | Allows values (0,1] or (1,N] or -1 to fill with remaining. 89 | Example {'train':-1,'val':.1} will create a (.9, .1) split of the dataset. 90 | {'train':10000,'val':.2,'test':-1} will create a (10000, .2N, .8N-10000) split 91 | {'train':.5} will simply subsample the dataset by half.""" 92 | # Check that split values are valid 93 | N = len(dataset) 94 | int_splits = {k:(int(np.round(v*N)) if ((v<=1) and (v>0)) else v) for k,v in splits.items()} 95 | assert sum(int_splits.values())<=N, "sum of split values exceed training set size, \ 96 | make sure that they sum to <=1 or the dataset size." 
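    # (If the dataset defines `stratify` -- True to use each sample's target as its label,
    #  or a callable mapping a sample to a label -- labels are gathered here so that
    #  train_test_split below preserves class proportions in every generated split.)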
97 | if hasattr(dataset,'stratify') and dataset.stratify!=False: 98 | if dataset.stratify==True: 99 | y = np.array([mb[-1] for mb in dataset]) 100 | else: 101 | y = np.array([dataset.stratify(mb) for mb in dataset]) 102 | else: 103 | y = None 104 | indices = np.arange(len(dataset)) 105 | split_datasets = {} 106 | for split_name, split_count in sorted(int_splits.items(),reverse=True, key=lambda kv: kv[1]): 107 | if split_count == len(indices) or split_count==-1: 108 | new_split_ids = indices 109 | indices = indices[:0] 110 | else: 111 | strat = None if y is None else y[indices] 112 | indices, new_split_ids = train_test_split(indices,test_size=split_count,stratify=strat) 113 | split_datasets[split_name] = IndexedDataset(dataset,new_split_ids) 114 | return split_datasets 115 | 116 | 117 | 118 | 119 | 120 | # class SegmentationDataset(EasyIMGDataset): 121 | # def __init__(self,*args,joint_transform=True,split='train',**kwargs): 122 | # if joint_transform is True: 123 | # joint_transform = self.default_joint_transform() if \ 124 | # split=='train' else None 125 | # super().__init__(*args,joint_transform=joint_transform, 126 | # split=split,**kwargs) 127 | 128 | # def default_joint_transform(self): 129 | # """ Currently translating x and y is more easily 130 | # expressed as a joint transformation rather than layer """ 131 | # raise NotImplementedError 132 | 133 | # class CamVid(camvid.CamVid): 134 | # @classmethod 135 | # def default_joint_transform(self): 136 | # return transforms.Compose([ 137 | # JointRandomCrop(224), 138 | # JointRandomHorizontalFlip() 139 | # ]) 140 | 141 | 142 | 143 | 144 | # def CIFAR10ZCA(): 145 | # """ Note, currently broken and doesn't support data aug """ 146 | # transform_dev = transforms.Compose( 147 | # [transforms.ToTensor(), 148 | # transforms.Normalize((.0904,.0868,.0468), (1,1,1))]) 149 | # transform_train = transform_dev 150 | # pathToDataset = '/scratch/datasets/cifar10/' 151 | # trainset = ds.CIFAR10(pathToDataset, download=True, transform=transform_train) 152 | # testset = ds.CIFAR10(pathToDataset, train=False, download=True, transform=transform_dev) 153 | # try: ZCAt_mat = torch.load("ZCAtranspose.np") 154 | # except: ZCAt_mat = constructCifar10ZCA(trainset) 155 | # trainset.train_data = np.dot(trainset.train_data.reshape(-1,32*32*3), ZCAt_mat).reshape(-1,32,32,3) 156 | # testset.test_data = np.dot(testset.test_data.reshape(-1,32*32*3), ZCAt_mat).reshape(-1,32,32,3) 157 | 158 | # def constructCifar10ZCA(trainset): 159 | # print("Constructing ZCA matrix for Cifar10") 160 | # X = trainset.train_data.reshape(-1,32*32*3) 161 | # cov = np.cov(X, rowvar=False) 162 | # # Singular Value Decomposition. X = U * np.diag(S) * V 163 | # U,S,V = np.linalg.svd(cov) 164 | # # U: [M x M] eigenvectors of sigma. 165 | # # S: [M x 1] eigenvalues of sigma. 
166 | # # V: [M x M] transpose of U 167 | # # Whitening constant: prevents division by zero 168 | # epsilon = 1e-6 169 | # # ZCA Whitening matrix: U * Lambda * U' 170 | # ZCAMatrix = np.dot(U, np.dot(np.diag(1.0/np.sqrt(S + epsilon)), U.T)) # [M x M] 171 | # torch.save(ZCAMatrix.T, "ZCAtranspose.np") 172 | # return ZCAMatrix.T 173 | 174 | -------------------------------------------------------------------------------- /oil/logging/lazyLogger.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import time 4 | import os 5 | import dill 6 | 7 | class LogTimer(object): 8 | """ Timer to automatically control time spent on expensive logs 9 | by not logging when computations performed in __enter__ 10 | exceed the specified fraction of total time outside. 11 | """ 12 | def __init__(self, minPeriod = 1, timeFrac = 1/10, **kwargs): 13 | """ minPeriod: minimum time between logs. 14 | timeFrac: max fraction of total time spent inside __enter__ block.""" 15 | self.avgLogTime = 0 16 | self.numLogs = 0 17 | self.lastLogTime = 0 18 | self.minPeriod = minPeriod #(measured in minutes) 19 | self.timeFrac = timeFrac 20 | self.performedLog = False 21 | super().__init__(**kwargs) 22 | 23 | def __enter__(self): 24 | """ returns yes iff the number of minutes have elapsed > minPeriod 25 | and > (1/timeFrac) * average time it takes to log """ 26 | timeSinceLog = time.time() - self.lastLogTime 27 | self.performedLog = (timeSinceLog > 60*self.minPeriod) \ 28 | and (timeSinceLog > self.avgLogTime/self.timeFrac) 29 | if self.performedLog: self.lastLogTime = time.time() 30 | return self.performedLog 31 | 32 | def __exit__(self, *args): 33 | if self.performedLog: 34 | timeSpentLogging = time.time()-self.lastLogTime 35 | n = self.numLogs 36 | self.avgLogTime = timeSpentLogging/(n+1) + self.avgLogTime*n/(n+1) 37 | self.numLogs += 1 38 | self.lastLogTime = time.time() 39 | 40 | 41 | # If tensorboardX fails to load, we replace it with a writer 42 | # that does nothing 43 | class NothingWriter(object): 44 | add_scalar = add_scalars = add_scalars_to_json = add_image \ 45 | = add_image_with_boxes = add_figure = add_video = add_audio \ 46 | = add_text = add_onnx_graph = add_graph = add_embedding \ 47 | = add_pr_curve_raw = close = lambda *args,**kwargs:None 48 | def __init__(self,*args,**kwargs): 49 | return super().__init__() 50 | try: 51 | import tensorboardX 52 | MaybeTbWriter = tensorboardX.SummaryWriter 53 | except ModuleNotFoundError: 54 | MaybeTbWriter = NothingWriter 55 | 56 | class MaybeTbWriterWSerial(MaybeTbWriter): 57 | """ Wraps summary writer but allows pickling with set and getstate """ 58 | def __getstate__(self): 59 | return dict((k, v) for k, v in self.__dict__.items() 60 | if not k in ['file_writer','all_writers']) 61 | def __setstate__(self,state): 62 | self.__init__(log_dir = state['_log_dir']) 63 | self.__dict__.update(state) 64 | def add_scalars(self,main_tag,tag_scalar_dict,global_step=None,walltime=None): 65 | for tag,scalar in tag_scalar_dict.items(): 66 | full_tag = f"{main_tag}/{tag}" 67 | self.add_scalar(full_tag,scalar,global_step,walltime) 68 | def tb_default_logdir(comment=''): 69 | import socket 70 | from datetime import datetime 71 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 72 | extra = current_time + '_' + socket.gethostname() 73 | log_dir = os.path.join('runs', comment or extra) 74 | return log_dir 75 | 76 | class LazyLogger(LogTimer, MaybeTbWriterWSerial): 77 | """ Tensorboard logging to log_dir, 
logged scalars are also stored to
78 |     a pandas dataframe called scalar_frame. Logged text is additionally
79 |     stored in a dictionary called text.
80 |     Lazy context manager functionality allows controlling time spent on
81 |     expensive logging operations to a fixed fraction. See LogTimer for
82 |     more details.
83 |     """
84 |     def __init__(self, log_dir = None,comment='', no_print=False, ema_com=0, **kwargs):
85 |         """ log_dir: Where tensorboardX logs are saved, tb default
86 |             no_print: if True printing is disabled
87 |             ema_com: if nonzero, emas and report show the exponential moving
88 |                 average of tracked scalars
89 |         """
90 |         self.text = {}
91 |         self.constants = {}
92 |         self.scalar_frame = pd.DataFrame()
93 |         self.no_print = no_print
94 |         self._com = ema_com
95 |         self._unreported = {}
96 |         self._log_dirr = tb_default_logdir(comment) if not log_dir else os.path.join(log_dir,comment)
97 |         super().__init__(log_dir=self._log_dirr, **kwargs)
98 | 
99 |     def report(self):
100 |         """ prints all unreported text and constants, prints scalar emas"""
101 |         if self.no_print: return
102 |         for unreported_info in self._unreported.values():
103 |             print(unreported_info)#+'\n')
104 |         self._unreported = {}
105 |         emas = self.emas()
106 |         print(emas)
107 |         return emas
108 | 
109 |     @property # Needs to be read only
110 |     def log_dirr(self): # Whatever was assigned by the tbwriter
111 |         return self._log_dirr
112 | 
113 |     def emas(self):
114 |         """ Returns the exponential moving average of the logged
115 |             scalars (not consts) """
116 |         return self.scalar_frame.iloc[-1:]#.ewm(com=self._com).mean()
117 | 
118 |     def add_text(self, tag, text_string):
119 |         """ text_string is logged (into text and tensorboard)
120 |             tag can be specified to allow overwrites so that
121 |             a frequently logged text under a tag will only show
122 |             the most recent after a report """
123 |         try: self.text[tag].add(text_string)
124 |         except KeyError: self.text[tag] = {text_string}
125 |         self._unreported[tag] = text_string
126 |         super().add_text(tag, text_string)
127 | 
128 |     def _add_constants(self, tag, dic):
129 |         try: self.constants[tag].update(dic)
130 |         except KeyError: self.constants[tag] = dic
131 |         with pd.option_context('display.expand_frame_repr',False):
132 |             self.add_text('Constants/{}'.format(tag),str(pd.Series(dic).to_frame(tag).T))
133 | 
134 |     def add_scalars(self, tag, dic, step=None, walltime=None):
135 |         """ Like tensorboard add_scalars, but if step and walltime
136 |             are not specified, the dic is assumed to hold constants
137 |             which are logged as text using add_text"""
138 |         if step is None and walltime is None:
139 |             self._add_constants(tag,dic)
140 |         else:
141 |             i = step if step is not None else walltime
142 |             newRow = pd.DataFrame(dic, index = [i])
143 |             self.scalar_frame = self.scalar_frame.combine_first(newRow)
144 |             super().add_scalars(tag, dic, step)#, walltime=walltime) #TODO: update tensorboardX?
145 | 
146 |     def save_object(self,obj,suffix):
147 |         final_path = os.path.join(self.log_dirr,suffix)
148 |         os.makedirs(os.path.dirname(final_path),exist_ok=True)
149 |         with open(final_path,'wb') as file:
150 |             dill.dump(obj,file)
151 |         #torch.save(obj,final_path,pickle_module=dill)
152 |         return os.path.abspath(final_path)
153 | 
154 |     def state_dict(self):
155 |         # Will there be a problem with pickling the log_timer here? 
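        # (It pickles fine: the LogTimer fields are plain floats/ints, and the
        #  unpicklable tensorboardX file writers are already excluded by __getstate__ above.)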
156 |         return {'text':self.text,'constants':self.constants,
157 |                 'scalar_frame':self.scalar_frame}
158 | 
159 |     def load_state_dict(self, state):
160 |         self.text = state['text']
161 |         self.constants = state['constants']
162 |         self.scalar_frame = state['scalar_frame']
163 |         #self.log_timer = state['log_timer']
164 | 
165 |     def __str__(self):
166 |         return "{} object with text: {}, constants: {}, scalar_frame: {}.\n\
167 |         logging in directory: {}".format(
168 |             self.__class__,self.text,self.constants,self.scalar_frame,self.log_dirr)
169 | 
--------------------------------------------------------------------------------
/oil/datasetup/augLayers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import torchvision.transforms as transforms
4 | import torch.nn as nn
5 | import numbers
6 | import torch.nn.functional as F
7 | import numpy as np
8 | from math import ceil, floor
9 | from ..utils.utils import log_uniform
10 | 
11 | 
12 | class RandomErasing(nn.Module):
13 |     '''
14 |     Augmentation module performing Random Erasing ("Random Erasing Data Augmentation", Zhong et al.)
15 |     -------------------------------------------------------------------------------------
16 |     p: the probability that the operation will be performed
17 |     af: average fraction of the img area that is erased
18 |     '''
19 |     def __init__(self, p = 1, af=1/4, ar=3,max_scale=3):
20 |         self.p = p
21 |         self.area_frac = af
22 |         self.max_ratio = ar
23 |         self.max_scale=max_scale
24 |         super().__init__()
25 | 
26 |     def forward(self, x):
27 |         if self.training:
28 |             return self.random_erase(x)
29 |         else:
30 |             return x
31 | 
32 |     def random_erase(self, img):
33 |         bs,c,h,w = img.shape
34 |         area = h*w
35 |         target_areas = log_uniform(1/self.max_scale, self.max_scale,size=bs)*self.area_frac*area
36 |         aspect_ratios = log_uniform(1/self.max_ratio, self.max_ratio,size=bs)
37 | 
38 |         do_erase = np.random.random(bs) < self.p
…
--------------------------------------------------------------------------------
/oil/tuning/study.py:
--------------------------------------------------------------------------------
…
… |             epochs = [e for e in epochs if e > trainer.epoch]
199 |             for epoch in epochs:
200 |                 trainer.train_to(epoch)
201 |                 if save: cfg['saved_at']=trainer.save_checkpoint()
202 |             outcome = trainer.ckpt['outcome']
203 |         except Exception as e:
204 |             if self.strict: raise
205 |             outcome = e
206 |         cleanup_cuda()
207 |         del trainer
208 |         return cfg, outcome
209 | 
--------------------------------------------------------------------------------
/oil/architectures/img_classifiers/shake_shake.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved.
4 | #
5 | # This work is licensed under the Creative Commons Attribution-NonCommercial
6 | # 4.0 International License. To view a copy of this license, visit
7 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | ShakeShake model definition 11 | ported from https://github.com/CuriousAI/mean-teacher/blob/master/pytorch/mean_teacher/architectures.py 12 | """ 13 | import math 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torch.autograd import Function 18 | from ...utils.utils import export,Named 19 | import itertools 20 | 21 | __all__ = ['ShakeShake26','ResNext152'] 22 | 23 | 24 | 25 | class ResNet224x224(nn.Module,metaclass=Named): 26 | def __init__(self, block, layers, channels, groups=1, num_targets=1000, downsample='basic'): 27 | super().__init__() 28 | assert len(layers) == 4 29 | self.downsample_mode = downsample 30 | self.inplanes = 64 31 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 32 | bias=False) 33 | self.bn1 = nn.BatchNorm2d(self.inplanes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 36 | self.layer1 = self._make_layer(block, channels, groups, layers[0]) 37 | self.layer2 = self._make_layer( 38 | block, channels * 2, groups, layers[1], stride=2) 39 | self.layer3 = self._make_layer( 40 | block, channels * 4, groups, layers[2], stride=2) 41 | self.layer4 = self._make_layer( 42 | block, channels * 8, groups, layers[3], stride=2) 43 | self.avgpool = nn.AvgPool2d(7) 44 | self.fc1 = nn.Linear(block.out_channels( 45 | channels * 8, groups), num_targets) 46 | 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 50 | m.weight.data.normal_(0, math.sqrt(2. / n)) 51 | elif isinstance(m, nn.BatchNorm2d): 52 | m.weight.data.fill_(1) 53 | m.bias.data.zero_() 54 | 55 | def _make_layer(self, block, planes, groups, blocks, stride=1): 56 | downsample = None 57 | if stride != 1 or self.inplanes != block.out_channels(planes, groups): 58 | if self.downsample_mode == 'basic' or stride == 1: 59 | downsample = nn.Sequential( 60 | nn.Conv2d(self.inplanes, block.out_channels(planes, groups), 61 | kernel_size=1, stride=stride, bias=False), 62 | nn.BatchNorm2d(block.out_channels(planes, groups)), 63 | ) 64 | elif self.downsample_mode == 'shift_conv': 65 | downsample = ShiftConvDownsample(in_channels=self.inplanes, 66 | out_channels=block.out_channels(planes, groups)) 67 | else: 68 | assert False 69 | 70 | layers = [] 71 | layers.append(block(self.inplanes, planes, groups, stride, downsample)) 72 | self.inplanes = block.out_channels(planes, groups) 73 | for i in range(1, blocks): 74 | layers.append(block(self.inplanes, planes, groups)) 75 | 76 | return nn.Sequential(*layers) 77 | 78 | def forward(self, x): 79 | x = self.conv1(x) 80 | x = self.bn1(x) 81 | x = self.relu(x) 82 | x = self.maxpool(x) 83 | x = self.layer1(x) 84 | x = self.layer2(x) 85 | x = self.layer3(x) 86 | x = self.layer4(x) 87 | x = self.avgpool(x) 88 | x = x.view(x.size(0), -1) 89 | return self.fc1(x) 90 | 91 | 92 | class ResNet32x32(nn.Module,metaclass=Named): 93 | def __init__(self, block, layers, channels, groups=1, num_targets=1000, downsample='basic'): 94 | super().__init__() 95 | assert len(layers) == 3 96 | self.downsample_mode = downsample 97 | self.inplanes = 16 98 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, 99 | padding=1, bias=False) 100 | self.layer1 = self._make_layer(block, channels, groups, layers[0]) 101 | self.layer2 = self._make_layer( 102 | block, channels * 2, groups, layers[1], stride=2) 103 | self.layer3 = self._make_layer( 104 | block, channels * 4, groups, layers[2], stride=2) 105 | 
self.avgpool = nn.AvgPool2d(8) 106 | self.fc1 = nn.Linear(block.out_channels( 107 | channels * 4, groups), num_targets) 108 | 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 112 | m.weight.data.normal_(0, math.sqrt(2. / n)) 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | 117 | def _make_layer(self, block, planes, groups, blocks, stride=1): 118 | downsample = None 119 | if stride != 1 or self.inplanes != block.out_channels(planes, groups): 120 | if self.downsample_mode == 'basic' or stride == 1: 121 | downsample = nn.Sequential( 122 | nn.Conv2d(self.inplanes, block.out_channels(planes, groups), 123 | kernel_size=1, stride=stride, bias=False), 124 | nn.BatchNorm2d(block.out_channels(planes, groups)), 125 | ) 126 | elif self.downsample_mode == 'shift_conv': 127 | downsample = ShiftConvDownsample(in_channels=self.inplanes, 128 | out_channels=block.out_channels(planes, groups)) 129 | else: 130 | assert False 131 | 132 | layers = [] 133 | layers.append(block(self.inplanes, planes, groups, stride, downsample)) 134 | self.inplanes = block.out_channels(planes, groups) 135 | for i in range(1, blocks): 136 | layers.append(block(self.inplanes, planes, groups)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | x = self.conv1(x) 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.avgpool(x) 146 | x = x.view(x.size(0), -1) 147 | return self.fc1(x) 148 | 149 | 150 | def conv3x3(in_planes, out_planes, stride=1): 151 | "3x3 convolution with padding" 152 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 153 | padding=1, bias=False) 154 | 155 | 156 | class BottleneckBlock(nn.Module): 157 | @classmethod 158 | def out_channels(cls, planes, groups): 159 | if groups > 1: 160 | return 2 * planes 161 | else: 162 | return 4 * planes 163 | 164 | def __init__(self, inplanes, planes, groups, stride=1, downsample=None): 165 | super().__init__() 166 | self.relu = nn.ReLU(inplace=True) 167 | 168 | self.conv_a1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 169 | self.bn_a1 = nn.BatchNorm2d(planes) 170 | self.conv_a2 = nn.Conv2d( 171 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, groups=groups) 172 | self.bn_a2 = nn.BatchNorm2d(planes) 173 | self.conv_a3 = nn.Conv2d(planes, self.out_channels( 174 | planes, groups), kernel_size=1, bias=False) 175 | self.bn_a3 = nn.BatchNorm2d(self.out_channels(planes, groups)) 176 | 177 | self.downsample = downsample 178 | self.stride = stride 179 | 180 | def forward(self, x): 181 | a, residual = x, x 182 | 183 | a = self.conv_a1(a) 184 | a = self.bn_a1(a) 185 | a = self.relu(a) 186 | a = self.conv_a2(a) 187 | a = self.bn_a2(a) 188 | a = self.relu(a) 189 | a = self.conv_a3(a) 190 | a = self.bn_a3(a) 191 | 192 | if self.downsample is not None: 193 | residual = self.downsample(residual) 194 | 195 | return self.relu(residual + a) 196 | 197 | 198 | class ShakeShakeBlock(nn.Module): 199 | @classmethod 200 | def out_channels(cls, planes, groups): 201 | assert groups == 1 202 | return planes 203 | 204 | def __init__(self, inplanes, planes, groups, stride=1, downsample=None): 205 | super().__init__() 206 | assert groups == 1 207 | self.conv_a1 = conv3x3(inplanes, planes, stride) 208 | self.bn_a1 = nn.BatchNorm2d(planes) 209 | self.conv_a2 = conv3x3(planes, planes) 210 | self.bn_a2 = nn.BatchNorm2d(planes) 211 | 212 | self.conv_b1 = 
conv3x3(inplanes, planes, stride) 213 | self.bn_b1 = nn.BatchNorm2d(planes) 214 | self.conv_b2 = conv3x3(planes, planes) 215 | self.bn_b2 = nn.BatchNorm2d(planes) 216 | 217 | self.downsample = downsample 218 | self.stride = stride 219 | 220 | def forward(self, x): 221 | a, b, residual = x, x, x 222 | 223 | a = F.relu(a, inplace=False) 224 | a = self.conv_a1(a) 225 | a = self.bn_a1(a) 226 | a = F.relu(a, inplace=True) 227 | a = self.conv_a2(a) 228 | a = self.bn_a2(a) 229 | 230 | b = F.relu(b, inplace=False) 231 | b = self.conv_b1(b) 232 | b = self.bn_b1(b) 233 | b = F.relu(b, inplace=True) 234 | b = self.conv_b2(b) 235 | b = self.bn_b2(b) 236 | 237 | ab = shake(a, b, training=self.training) 238 | 239 | if self.downsample is not None: 240 | residual = self.downsample(x) 241 | 242 | return residual + ab 243 | 244 | 245 | class Shake(Function): #convex combination inp1*gate + inp2*(1-gate) with a random per-sample gate; forward and backward draw independent gates (shake-shake regularization) 246 | @classmethod 247 | def forward(cls, ctx, inp1, inp2, training): 248 | assert inp1.size() == inp2.size() 249 | gate_size = [inp1.size()[0], *itertools.repeat(1, inp1.dim() - 1)] 250 | gate = inp1.new(*gate_size) 251 | if training: 252 | gate.uniform_(0, 1) 253 | else: 254 | gate.fill_(0.5) 255 | return inp1 * gate + inp2 * (1. - gate) 256 | 257 | @classmethod 258 | def backward(cls, ctx, grad_output): 259 | grad_inp1 = grad_inp2 = grad_training = None 260 | gate_size = [grad_output.size()[0], *itertools.repeat(1, 261 | grad_output.dim() - 1)] 262 | gate = grad_output.new(*gate_size).uniform_(0, 1) 263 | if ctx.needs_input_grad[0]: 264 | grad_inp1 = grad_output * gate 265 | if ctx.needs_input_grad[1]: 266 | grad_inp2 = grad_output * (1 - gate) 267 | assert not ctx.needs_input_grad[2] 268 | return grad_inp1, grad_inp2, grad_training 269 | 270 | 271 | def shake(inp1, inp2, training=False): 272 | return Shake.apply(inp1, inp2, training) 273 | 274 | 275 | class ShiftConvDownsample(nn.Module): 276 | def __init__(self, in_channels, out_channels): 277 | super().__init__() 278 | self.relu = nn.ReLU(inplace=True) 279 | self.conv = nn.Conv2d(in_channels=2 * in_channels, 280 | out_channels=out_channels, 281 | kernel_size=1, 282 | groups=2) 283 | self.bn = nn.BatchNorm2d(out_channels) 284 | 285 | def forward(self, x): 286 | x = torch.cat((x[:, :, 0::2, 0::2], 287 | x[:, :, 1::2, 1::2]), dim=1) 288 | x = self.relu(x) 289 | x = self.conv(x) 290 | x = self.bn(x) 291 | return x 292 | 293 | class ShakeShake26(ResNet32x32): 294 | def __init__(self,num_targets=10): 295 | super().__init__(ShakeShakeBlock, 296 | layers=[4, 4, 4], 297 | channels=96, 298 | downsample='shift_conv', num_targets=num_targets) 299 | 300 | class ResNext152(ResNet224x224): 301 | def __init__(self,num_targets=10): 302 | super().__init__(BottleneckBlock, 303 | layers=[3, 8, 36, 3], 304 | channels=32 * 4, 305 | groups=32, 306 | downsample='basic', num_targets=num_targets) -------------------------------------------------------------------------------- /oil/architectures/parts/deconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | from torch.nn.modules import conv 7 | from torch.nn.modules.utils import _pair 8 | from ...utils.utils import Expression,export,Named 9 | #import cv2 10 | 11 | #This is a reference implementation using im2col, and is not used anywhere else 12 | class Conv2d(conv._ConvNd): 13 | 14 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,bias=False): 15 | 16 | 17 | kernel_size = _pair(kernel_size)
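#_pair (from torch.nn.modules.utils) broadcasts scalar hyperparameters to (h, w) tuples,
#so the output-size arithmetic in forward can index kernel_size[0]/[1], padding[0]/[1], etc.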
18 | stride = _pair(stride) 19 | padding = _pair(padding) 20 | dilation = _pair(dilation) 21 | super(Conv2d, self).__init__( 22 | in_channels, out_channels, kernel_size, stride, padding, dilation, 23 | False, _pair(0), 1, False, padding_mode='zeros') 24 | 25 | self.kernel_size=kernel_size 26 | self.dilation=dilation 27 | self.padding=padding 28 | self.stride=stride 29 | 30 | 31 | def forward(self, x): 32 | N,C,H,W=x.shape 33 | out_h=(H+2*self.padding[0]-self.dilation[0]*(self.kernel_size[0]-1)-1)//self.stride[0]+1 34 | out_w=(W+2*self.padding[1]-self.dilation[1]*(self.kernel_size[1]-1)-1)//self.stride[1]+1 35 | w=self.weight 36 | #im2col 37 | inp_unf = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.stride) 38 | #matrix multiplication, reshape 39 | out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2).view(N,-1,out_h,out_w) 40 | 41 | return out_unf 42 | 43 | 44 | #iteratively solve for inverse sqrt of a matrix via Newton-Schulz: with Y=A/||A|| and Z=I, Z converges to (A/||A||)^(-1/2), so Z/sqrt(||A||) -> A^(-1/2) 45 | def isqrt_newton_schulz_autograd(A, numIters): 46 | dim = A.shape[0] 47 | normA=A.norm() 48 | Y = A.div(normA) 49 | I = torch.eye(dim,dtype=A.dtype,device=A.device) 50 | Z = torch.eye(dim,dtype=A.dtype,device=A.device) 51 | 52 | for i in range(numIters): 53 | T = 0.5*(3.0*I - Z@Y) 54 | Y = Y@T 55 | Z = T@Z 56 | #A_sqrt = Y*torch.sqrt(normA) 57 | A_isqrt = Z / torch.sqrt(normA) 58 | return A_isqrt 59 | 60 | 61 | #deconvolve channels 62 | class ChannelDeconv(nn.Module): 63 | def __init__(self, num_groups, eps=1e-2,n_iter=5,momentum=0.1,sampling_stride=3,debug=False): 64 | super(ChannelDeconv, self).__init__() 65 | 66 | self.eps = eps 67 | self.n_iter=n_iter 68 | self.momentum=momentum 69 | self.num_groups = num_groups 70 | self.debug=debug 71 | 72 | self.register_buffer('running_mean1', torch.zeros(num_groups, 1)) 73 | #self.register_buffer('running_cov', torch.eye(num_groups)) 74 | self.register_buffer('running_deconv', torch.eye(num_groups)) 75 | self.register_buffer('running_mean2', torch.zeros(1, 1)) 76 | self.register_buffer('running_var', torch.ones(1, 1)) 77 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 78 | self.sampling_stride=sampling_stride 79 | def forward(self, x): 80 | x_shape = x.shape 81 | if len(x.shape)==2: 82 | x=x.view(x.shape[0],x.shape[1],1,1) 83 | if len(x.shape)==3: 84 | raise ValueError('Unsupported tensor shape.') 85 | 86 | N, C, H, W = x.size() 87 | G = self.num_groups 88 | 89 | #take the first c channels out for deconv 90 | c=int(C/G)*G 91 | if c==0: 92 | raise ValueError('num_groups should be set smaller.') 93 | 94 | #step 1. remove mean 95 | if c!=C: 96 | x1=x[:,:c].permute(1,0,2,3).contiguous().view(G,-1) 97 | else: 98 | x1=x.permute(1,0,2,3).contiguous().view(G,-1) 99 | 100 | if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride: 101 | x1_s = x1[:,::self.sampling_stride**2] 102 | else: 103 | x1_s=x1 104 | 105 | mean1 = x1_s.mean(-1, keepdim=True) 106 | 107 | if self.num_batches_tracked==0: 108 | self.running_mean1.copy_(mean1.detach()) 109 | if self.training: 110 | self.running_mean1.mul_(1-self.momentum) 111 | self.running_mean1.add_(mean1.detach()*self.momentum) 112 | else: 113 | mean1 = self.running_mean1 114 | 115 | x1=x1-mean1 116 | 
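#the whitening matrix deconv = cov^(-1/2) below is recomputed from the current
#batch during training and tracked as an exponential moving average
#(running_deconv) for use at eval time, mirroring batch-norm's running statistics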
117 | #step 2. calculate deconv@x1 = cov^(-0.5)@x1 118 | if self.training: 119 | cov = x1_s @ x1_s.t() / x1_s.shape[1] + self.eps * torch.eye(G, dtype=x.dtype, device=x.device) 120 | deconv = isqrt_newton_schulz_autograd(cov, self.n_iter) 121 | 122 | if self.num_batches_tracked==0: 123 | #self.running_cov.copy_(cov.detach()) 124 | self.running_deconv.copy_(deconv.detach()) 125 | 126 | if self.training: 127 | #self.running_cov.mul_(1-self.momentum) 128 | #self.running_cov.add_(cov.detach()*self.momentum) 129 | self.running_deconv.mul_(1 - self.momentum) 130 | self.running_deconv.add_(deconv.detach() * self.momentum) 131 | else: 132 | # cov = self.running_cov 133 | deconv = self.running_deconv 134 | 135 | x1 = deconv@x1 136 | 137 | #reshape to N,c,H,W 138 | x1 = x1.view(c, N, H, W).contiguous().permute(1,0,2,3) 139 | 140 | # normalize the remaining channels 141 | if c!=C: 142 | x_tmp=x[:, c:].view(N,-1) 143 | if self.sampling_stride > 1 and H>=self.sampling_stride and W>=self.sampling_stride: 144 | x_s = x_tmp[:, ::self.sampling_stride ** 2] 145 | else: 146 | x_s = x_tmp 147 | 148 | mean2=x_s.mean() 149 | var=x_s.var() 150 | 151 | if self.num_batches_tracked == 0: 152 | self.running_mean2.copy_(mean2.detach()) 153 | self.running_var.copy_(var.detach()) 154 | 155 | if self.training: 156 | self.running_mean2.mul_(1 - self.momentum) 157 | self.running_mean2.add_(mean2.detach() * self.momentum) 158 | self.running_var.mul_(1 - self.momentum) 159 | self.running_var.add_(var.detach() * self.momentum) 160 | else: 161 | mean2 = self.running_mean2 162 | var = self.running_var 163 | 164 | x_tmp = (x[:, c:] - mean2) / (var + self.eps).sqrt() 165 | x1 = torch.cat([x1, x_tmp], dim=1) 166 | 167 | 168 | if self.training: 169 | self.num_batches_tracked.add_(1) 170 | 171 | if len(x_shape)==2: 172 | x1=x1.view(x_shape) 173 | return x1 174 | 175 | 176 | @export 177 | class DeConv2d(conv._ConvNd): 178 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,bias=True, eps=1e-2, n_iter=5, momentum=0.1, mode=4, num_groups=16,debug=False): 179 | # mode 1: remove channel correlation then pixel correlation 180 | # mode 2: only remove pixel correlation 181 | # mode 3: only remove channel correlation 182 | # mode 4: remove channel correlation and pixel correlation together 183 | kernel_size = _pair(kernel_size) 184 | stride = _pair(stride) 185 | padding = _pair(padding) 186 | dilation = _pair(dilation) 187 | self.kernel_size=kernel_size 188 | self.dilation=dilation 189 | self.padding=padding 190 | self.stride=stride 191 | super(DeConv2d, self).__init__( 192 | in_channels, out_channels, kernel_size, stride, padding, dilation, 193 | False, _pair(0), 1, bias, padding_mode='zeros') 194 | #add padding_mode='zeros' for pytorch 1.1 195 | 196 | self.momentum = momentum 197 | self.mode=mode 198 | self.n_iter = n_iter 199 | self.eps = eps 200 | 201 | num_features = self.weight.shape[2] * self.weight.shape[3]#k*k 202 | if self.mode!=2: 203 | if num_groups>self.weight.shape[1]: 204 | num_groups=self.weight.shape[1] 205 | self.num_groups=num_groups 206 | if self.mode!=4: 207 | self.channel_deconv=ChannelDeconv(num_groups,eps=eps,n_iter=n_iter,momentum=momentum,debug=False) 208 | else: 209 | num_features*=num_groups 210 | 211 | self.num_features = num_features 212 | 213 | if self.mode!=3: 214 | self.register_buffer('running_mean', torch.zeros(num_features,1)) 215 | #self.register_buffer('running_cov', torch.eye(num_features)) 216 | self.register_buffer('running_deconv', torch.eye(num_features)) 217 | 
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 218 | 219 | def forward(self, x): 220 | 221 | N,C,H,W=x.shape 222 | out_h=(H+2*self.padding[0]-self.dilation[0]*(self.kernel_size[0]-1)-1)//self.stride[0]+1 223 | out_w=(W+2*self.padding[1]-self.dilation[1]*(self.kernel_size[1]-1)-1)//self.stride[1]+1 224 | 225 | 226 | if self.mode == 1: 227 | x = self.channel_deconv(x) 228 | 229 | if self.mode==3: 230 | x=self.channel_deconv(x) 231 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, 1) 232 | 233 | 234 | if self.mode!=3: 235 | #1. im2col, reshape 236 | 237 | # N * cols * pixels 238 | inp_unf = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.stride) 239 | 240 | #(k*k, C*N*H*W) for pixel deconv 241 | #(k*k*G, C//G*N*H*W) for grouped pixel deconv 242 | X=inp_unf.permute(1,0,2).contiguous().view(self.num_features,-1) 243 | 244 | #2.subtract mean 245 | X_mean = X.mean(-1, keepdim=True) 246 | 247 | #track stats for evaluation 248 | if self.num_batches_tracked==0: 249 | self.running_mean.copy_(X_mean.detach()) 250 | if self.training: 251 | self.running_mean.mul_(1-self.momentum) 252 | self.running_mean.add_(X_mean.detach()*self.momentum) 253 | else: 254 | X_mean = self.running_mean 255 | 256 | X = X - X_mean 257 | 258 | #3. calculate COV, COV^(-0.5), then deconv 259 | if self.training: 260 | Cov = X / X.shape[1] @ X.t() + self.eps * torch.eye(X.shape[0], dtype=X.dtype, device=X.device) 261 | deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter) 262 | 263 | #track stats for evaluation 264 | if self.num_batches_tracked==0: 265 | #self.running_cov.copy_(Cov.detach()) 266 | self.running_deconv.copy_(deconv.detach()) 267 | if self.training: 268 | #self.running_cov.mul_(1-self.momentum) 269 | #self.running_cov.add_(Cov.detach()*self.momentum) 270 | self.running_deconv.mul_(1 - self.momentum) 271 | self.running_deconv.add_(deconv.detach() * self.momentum) 272 | else: 273 | #Cov = self.running_cov 274 | deconv = self.running_deconv 275 | 276 | #deconv 277 | X_deconv = deconv@X 278 | 279 | #reshape 280 | X_deconv=X_deconv.view(-1,N,out_h*out_w).contiguous().permute(1,2,0) 281 | 282 | #4. convolve 283 | 284 | w = self.weight 285 | out_unf = X_deconv.matmul(w.view(w.size(0), -1).t()).transpose(1, 2).view(N,-1,out_h,out_w) 286 | if self.bias is not None: 287 | out_unf=out_unf+self.bias.view(1,-1,1,1) 288 | 289 | if self.training: 290 | self.num_batches_tracked.add_(1) 291 | 292 | return out_unf#.contiguous() 293 | 294 | 295 | #this version is faster but slightly weaker. We approximately remove the mean.
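#Relative to DeConv2d above: the mean is removed per channel group on the raw
#input instead of per im2col row, the covariance is estimated from a patch grid
#subsampled by sampling_stride, and the resulting whitening matrix is folded
#into the weights so a single F.conv2d processes the full-resolution input.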
296 | @export 297 | class FastDeconv(conv._ConvNd): 298 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,bias=True, eps=1e-2, n_iter=5, momentum=0.1, num_groups=16,sampling_stride=3): 299 | 300 | kernel_size = _pair(kernel_size) 301 | stride = _pair(stride) 302 | padding = _pair(padding) 303 | dilation = _pair(dilation) 304 | self.kernel_size=kernel_size 305 | self.dilation=dilation 306 | self.padding=padding 307 | self.stride=stride 308 | self.momentum = momentum 309 | self.n_iter = n_iter 310 | self.eps = eps 311 | super(FastDeconv, self).__init__( 312 | in_channels, out_channels, kernel_size, stride, padding, dilation, 313 | False, _pair(0), 1, bias, padding_mode='zeros') 314 | 315 | if num_groups>in_channels: 316 | num_groups=in_channels 317 | self.num_groups=num_groups 318 | 319 | self.num_features = self.kernel_size[0] * self.kernel_size[1]*num_groups 320 | 321 | self.register_buffer('running_mean', torch.zeros(1,self.num_groups, 1, 1)) 322 | self.register_buffer('running_deconv', torch.eye(self.num_features)) 323 | self.sampling_stride=[sampling_stride*s for s in stride] 324 | 325 | def forward(self, x): 326 | 327 | N,C,H,W=x.shape 328 | N1,C1=N*C//self.num_groups,self.num_groups 329 | 330 | # 1.subtract mean (this is a fast approximation) 331 | x=x.view(N1,C1,H,W) 332 | 333 | # track stats for evaluation 334 | if self.training: 335 | x_mean = x.mean((0,2,3), keepdim=True) 336 | self.running_mean.mul_(1 - self.momentum) 337 | self.running_mean.add_(x_mean.detach() * self.momentum) 338 | else: 339 | x_mean = self.running_mean 340 | 341 | x = x - x_mean 342 | 343 | x=x.view(N,C,H,W) 344 | 345 | #2. im2col: N x cols x pixels 346 | inp_unf = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.sampling_stride) 347 | 348 | #(k*k*G, C//G*N*H*W) for grouped pixel deconv 349 | X = inp_unf.transpose(0,1).contiguous().view(self.num_features, -1) 350 | 351 | #3. calculate COV, COV^(-0.5), then deconv 352 | if self.training: 353 | Cov = X / X.shape[1] @ X.t() + self.eps * torch.eye(X.shape[0], dtype=X.dtype, device=X.device) 354 | deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter) 355 | 356 | #track stats for evaluation 357 | if self.training: 358 | self.running_deconv.mul_(1 - self.momentum) 359 | self.running_deconv.add_(deconv.detach() * self.momentum) 360 | else: 361 | deconv = self.running_deconv 362 | 363 | #deconv + conv 364 | w=self.weight.view(self.weight.shape[0],-1).t().contiguous().view(self.num_features,-1) 365 | w=deconv@w 366 | w=w.view(-1,self.weight.shape[0]).t().view(self.weight.shape) 367 | return F.conv2d(x, w.view(self.weight.shape), self.bias, self.stride, self.padding, self.dilation, 1) --------------------------------------------------------------------------------
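Minimal usage sketch (not part of the repo): it exercises ShakeShake26 as a CIFAR-sized classifier and FastDeconv as a drop-in replacement for nn.Conv2d; the import paths are assumed from the tree above and may need adjusting.

import torch
from oil.architectures.img_classifiers.shake_shake import ShakeShake26
from oil.architectures.parts.deconv import FastDeconv

# two stride-2 stages reduce 32x32 inputs to 8x8, AvgPool2d(8) collapses them,
# and fc1 maps 4*96 = 384 features to num_targets
model = ShakeShake26(num_targets=10).eval()
x = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    logits = model(x)
assert logits.shape == (4, 10)

# FastDeconv keeps nn.Conv2d's constructor signature; in eval mode it whitens
# with the tracked running_deconv buffer (identity before any training)
layer = FastDeconv(3, 32, kernel_size=3, padding=1).eval()
with torch.no_grad():
    y = layer(x)
assert y.shape == (4, 32, 32, 32)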