├── .gitignore
├── LICENSE.txt
├── README.md
├── examples
│   ├── Transforms with Pytorch and Torchsample.ipynb
│   ├── imgs
│   │   ├── orig1.png
│   │   ├── orig2.png
│   │   ├── orig3.png
│   │   ├── tform1.png
│   │   ├── tform2.png
│   │   └── tform3.png
│   ├── mnist_example.py
│   └── mnist_loader_example.py
├── setup.py
├── tests
│   ├── integration
│   │   ├── fit_complex
│   │   │   └── multi_input_multi_target.py
│   │   ├── fit_loader_simple
│   │   │   ├── single_input_multi_target.py
│   │   │   └── single_input_single_target.py
│   │   └── fit_simple
│   │       ├── simple_multi_input_multi_target.py
│   │       ├── simple_multi_input_no_target.py
│   │       ├── simple_multi_input_single_target.py
│   │       ├── single_input_multi_target.py
│   │       ├── single_input_no_target.py
│   │       └── single_input_single_target.py
│   ├── test_metrics.py
│   ├── transforms
│   │   ├── test_affine_transforms.py
│   │   ├── test_image_transforms.py
│   │   └── test_tensor_transforms.py
│   └── utils.py
└── torchsample
    ├── __init__.py
    ├── callbacks.py
    ├── constraints.py
    ├── datasets.py
    ├── functions
    │   ├── __init__.py
    │   └── affine.py
    ├── initializers.py
    ├── metrics.py
    ├── modules
    │   ├── __init__.py
    │   ├── _utils.py
    │   └── module_trainer.py
    ├── regularizers.py
    ├── samplers.py
    ├── transforms
    │   ├── __init__.py
    │   ├── affine_transforms.py
    │   ├── distortion_transforms.py
    │   ├── image_transforms.py
    │   └── tensor_transforms.py
    ├── utils.py
    └── version.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.git/
sandbox/

*.DS_Store
*__pycache__*
__pycache__
*.pyc
.ipynb_checkpoints/
*.ipynb_checkpoints/
*.bkbn
.spyderworkspace
.spyderproject

# setup.py working directory
build
# sphinx build directory
doc/_build
# setup.py dist directory
dist
# Egg metadata
*.egg-info
.eggs

--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
COPYRIGHT

Some contributions by Nicholas Cullen:
Copyright (c) 2017, Nicholas Cullen.
All rights reserved.

Some contributions by François Chollet:
Copyright (c) 2015, François Chollet.
All rights reserved.

Some contributions by Google:
Copyright (c) 2015, Google, Inc.
All rights reserved.

All other contributions:
Copyright (c) 2015, the respective contributors.
All rights reserved.


LICENSE

The MIT License (MIT)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# High-Level Training, Data Augmentation, and Utilities for Pytorch

[v0.1.3](https://github.com/ncullen93/torchsample/releases) JUST RELEASED - contains significant improvements, bug fixes, and additional
support. Get it from the releases, or pull the master branch.

This package provides a few things:
- A high-level module for Keras-like training with callbacks, constraints, and regularizers.
- Comprehensive data augmentation, transforms, sampling, and loading.
- Utility tensor and variable functions so you don't need numpy as often.

Have any feature requests? Submit an issue! I'll make it happen. Specifically,
any data augmentation, data loading, or sampling functions.

Want to contribute? Check the [issues page](https://github.com/ncullen93/torchsample/issues)
for those tagged with [contributions welcome].

## ModuleTrainer
The `ModuleTrainer` class provides a high-level training interface which abstracts
away the training loop while providing callbacks, constraints, initializers, regularizers,
and more.

Example:
```python
import torch.nn as nn
import torch.nn.functional as F

from torchsample.modules import ModuleTrainer

# Define your model EXACTLY as normal
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 1600)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

model = Network()
trainer = ModuleTrainer(model)

trainer.compile(loss='nll_loss',
                optimizer='adadelta')

trainer.fit(x_train, y_train,
            val_data=(x_test, y_test),
            num_epoch=20,
            batch_size=128,
            verbose=1)
```

You also have access to the standard evaluation and prediction functions through the trainer:

```python
loss = trainer.evaluate(x_train, y_train)
y_pred = trainer.predict(x_train)
```

Torchsample provides a wide range of callbacks, generally mimicking the interface
found in `Keras`:

- `EarlyStopping`
- `ModelCheckpoint`
- `LearningRateScheduler`
- `ReduceLROnPlateau`
- `CSVLogger`

```python
from torchsample.callbacks import EarlyStopping

callbacks = [EarlyStopping(monitor='val_loss', patience=5)]
trainer.set_callbacks(callbacks)
```
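Several callbacks can be combined in one list. A minimal sketch, assuming Keras-style constructor arguments for `ModelCheckpoint` and `CSVLogger` (the `directory` and file-path argument names here are illustrative; check `torchsample/callbacks.py` for the exact signatures):

```python
from torchsample.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger

# stop early on a stalled val_loss, keep the best weights on disk,
# and log per-epoch metrics (argument names are assumptions, not verified API)
callbacks = [EarlyStopping(monitor='val_loss', patience=5),
             ModelCheckpoint(directory='checkpoints', monitor='val_loss'),
             CSVLogger('training_log.csv')]
trainer.set_callbacks(callbacks)
```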

Torchsample also provides regularizers:

- `L1Regularizer`
- `L2Regularizer`
- `L1L2Regularizer`

and constraints:
- `UnitNorm`
- `MaxNorm`
- `NonNeg`

Both regularizers and constraints can be selectively applied to layers using name
patterns (e.g. `*fc*`) via the `module_filter` argument. Constraints can be explicit
(hard) constraints applied at an arbitrary batch or epoch frequency, or they can be
implicit (soft) constraints similar to regularizers, where the constraint deviation
is added as a penalty to the total model loss.

```python
from torchsample.constraints import MaxNorm, NonNeg
from torchsample.regularizers import L1Regularizer

# hard constraint applied every 5 batches
hard_constraint = MaxNorm(value=2., frequency=5, unit='batch', module_filter='*fc*')
# implicit constraint added as a penalty term to model loss
soft_constraint = NonNeg(lagrangian=True, scale=1e-3, module_filter='*fc*')
constraints = [hard_constraint, soft_constraint]
trainer.set_constraints(constraints)

regularizers = [L1Regularizer(scale=1e-4, module_filter='*conv*')]
trainer.set_regularizers(regularizers)
```

You can also fit directly on a `torch.utils.data.DataLoader`, with an optional
validation loader as well:

```python
from torchsample import TensorDataset
from torch.utils.data import DataLoader

train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32)

val_dataset = TensorDataset(x_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32)

trainer.fit_loader(train_loader, val_loader=val_loader, num_epoch=100)
```

## Utility Functions
Finally, torchsample provides a few utility functions not commonly found elsewhere:

### Tensor Functions
- `th_iterproduct` (mimics itertools.product)
- `th_gather_nd` (N-dimensional version of torch.gather)
- `th_random_choice` (mimics np.random.choice)
- `th_pearsonr` (mimics scipy.stats.pearsonr)
- `th_corrcoef` (mimics np.corrcoef)
- `th_affine2d` and `th_affine3d` (affine transforms on torch.Tensors)

### Variable Functions
- `F_affine2d` and `F_affine3d`
- `F_map_coordinates2d` and `F_map_coordinates3d`

## Data Augmentation and Datasets
The torchsample package provides a wide range of data augmentation and transformation
tools which can be applied during data loading. The package also provides the flexible
`TensorDataset` and `FolderDataset` classes to handle most dataset needs.

### Torch Transforms
These transforms work directly on torch tensors (a composition sketch follows the list):

- `Compose()`
- `AddChannel()`
- `SwapDims()`
- `RangeNormalize()`
- `StdNormalize()`
- `Slice2D()`
- `RandomCrop()`
- `SpecialCrop()`
- `Pad()`
- `RandomFlip()`
- `ToTensor()`
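A minimal composition sketch. It assumes `Compose` takes a list of transforms, as in torchvision; the individual constructor calls mirror ones used in this repo's test suite:

```python
from torchsample.transforms import Compose, RangeNormalize, RandomFlip, RandomCrop

# normalize into [0, 1], randomly flip horizontally, then take a random 24x24 crop;
# each transform is applied in order to a (C, H, W) torch tensor
tform = Compose([RangeNormalize(0, 1),
                 RandomFlip(h=True, v=False, p=0.5),
                 RandomCrop((24, 24))])
x_augmented = tform(x_train[0])
```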
### Affine Transforms
![Original](https://github.com/ncullen93/torchsample/blob/master/examples/imgs/orig1.png "Original") ![Transformed](https://github.com/ncullen93/torchsample/blob/master/examples/imgs/tform1.png "Transformed")

The following transforms perform affine (or affine-like) transformations on torch tensors:

- `Rotate()`
- `Translate()`
- `Shear()`
- `Zoom()`

We also provide classes for stringing multiple affine transformations together so that only one interpolation takes place:

- `Affine()`
- `AffineCompose()`

### Datasets and Sampling
We provide the following dataset classes, which give general structure and iterators for sampling from, and applying transforms to, in-memory or out-of-memory data:

- `TensorDataset()`

- `FolderDataset()`


## Acknowledgements
Thank you to the following people and contributors:
- All Keras contributors
- @deallynomore
- @recastrodiaz

--------------------------------------------------------------------------------
/examples/imgs/orig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/orig1.png
--------------------------------------------------------------------------------
/examples/imgs/orig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/orig2.png
--------------------------------------------------------------------------------
/examples/imgs/orig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/orig3.png
--------------------------------------------------------------------------------
/examples/imgs/tform1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/tform1.png
--------------------------------------------------------------------------------
/examples/imgs/tform2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/tform2.png
--------------------------------------------------------------------------------
/examples/imgs/tform3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/tform3.png
--------------------------------------------------------------------------------
/examples/mnist_example.py:
--------------------------------------------------------------------------------
1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torchsample.modules import ModuleTrainer 7 | from torchsample.callbacks import EarlyStopping, ReduceLROnPlateau 8 | from torchsample.regularizers import L1Regularizer, L2Regularizer 9 | from torchsample.constraints import UnitNorm 10 | from torchsample.initializers import XavierUniform 11 | from torchsample.metrics import CategoricalAccuracy 12 | 13 | import os 14 | from torchvision import datasets 15 | ROOT = '/users/ncullen/desktop/data/mnist' 16 | dataset = datasets.MNIST(ROOT, train=True, download=True) 17 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 18 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt')) 19 | 20 | x_train = 
x_train.float() 21 | y_train = y_train.long() 22 | x_test = x_test.float() 23 | y_test = y_test.long() 24 | 25 | x_train = x_train / 255. 26 | x_test = x_test / 255. 27 | x_train = x_train.unsqueeze(1) 28 | x_test = x_test.unsqueeze(1) 29 | 30 | # only train on a subset 31 | x_train = x_train[:10000] 32 | y_train = y_train[:10000] 33 | x_test = x_test[:1000] 34 | y_test = y_test[:1000] 35 | 36 | 37 | # Define your model EXACTLY as if you were using nn.Module 38 | class Network(nn.Module): 39 | def __init__(self): 40 | super(Network, self).__init__() 41 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 42 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 43 | self.fc1 = nn.Linear(1600, 128) 44 | self.fc2 = nn.Linear(128, 10) 45 | 46 | def forward(self, x): 47 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 48 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 49 | x = x.view(-1, 1600) 50 | x = F.relu(self.fc1(x)) 51 | x = F.dropout(x, training=self.training) 52 | x = self.fc2(x) 53 | return F.log_softmax(x) 54 | 55 | 56 | model = Network() 57 | trainer = ModuleTrainer(model) 58 | 59 | 60 | callbacks = [EarlyStopping(patience=10), 61 | ReduceLROnPlateau(factor=0.5, patience=5)] 62 | regularizers = [L1Regularizer(scale=1e-3, module_filter='conv*'), 63 | L2Regularizer(scale=1e-5, module_filter='fc*')] 64 | constraints = [UnitNorm(frequency=3, unit='batch', module_filter='fc*')] 65 | initializers = [XavierUniform(bias=False, module_filter='fc*')] 66 | metrics = [CategoricalAccuracy(top_k=3)] 67 | 68 | trainer.compile(loss='nll_loss', 69 | optimizer='adadelta', 70 | regularizers=regularizers, 71 | constraints=constraints, 72 | initializers=initializers, 73 | metrics=metrics, callbacks=callbacks) 74 | 75 | #summary = trainer.summary([1,28,28]) 76 | #print(summary) 77 | 78 | trainer.fit(x_train, y_train, 79 | val_data=(x_test, y_test), 80 | num_epoch=20, 81 | batch_size=128, 82 | verbose=1) 83 | 84 | 85 | -------------------------------------------------------------------------------- /examples/mnist_loader_example.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.utils.data import DataLoader 7 | 8 | from torchsample.modules import ModuleTrainer 9 | from torchsample.callbacks import EarlyStopping, ReduceLROnPlateau 10 | from torchsample.regularizers import L1Regularizer, L2Regularizer 11 | from torchsample.constraints import UnitNorm 12 | from torchsample.initializers import XavierUniform 13 | from torchsample.metrics import CategoricalAccuracy 14 | from torchsample import TensorDataset 15 | 16 | import os 17 | from torchvision import datasets 18 | ROOT = '/users/ncullen/desktop/data/mnist' 19 | dataset = datasets.MNIST(ROOT, train=True, download=True) 20 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 21 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt')) 22 | 23 | x_train = x_train.float() 24 | y_train = y_train.long() 25 | x_test = x_test.float() 26 | y_test = y_test.long() 27 | 28 | x_train = x_train / 255. 29 | x_test = x_test / 255. 
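# add a channel dimension for Conv2d: (N, 28, 28) -> (N, 1, 28, 28)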
30 | x_train = x_train.unsqueeze(1) 31 | x_test = x_test.unsqueeze(1) 32 | 33 | # only train on a subset 34 | x_train = x_train[:10000] 35 | y_train = y_train[:10000] 36 | x_test = x_test[:1000] 37 | y_test = y_test[:1000] 38 | 39 | train_dataset = TensorDataset(x_train, y_train) 40 | train_loader = DataLoader(train_dataset, batch_size=32) 41 | val_dataset = TensorDataset(x_test, y_test) 42 | val_loader = DataLoader(val_dataset, batch_size=32) 43 | 44 | # Define your model EXACTLY as if you were using nn.Module 45 | class Network(nn.Module): 46 | def __init__(self): 47 | super(Network, self).__init__() 48 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 49 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 50 | self.fc1 = nn.Linear(1600, 128) 51 | self.fc2 = nn.Linear(128, 10) 52 | 53 | def forward(self, x): 54 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 55 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 56 | x = x.view(-1, 1600) 57 | x = F.relu(self.fc1(x)) 58 | x = F.dropout(x, training=self.training) 59 | x = self.fc2(x) 60 | return F.log_softmax(x) 61 | 62 | 63 | model = Network() 64 | trainer = ModuleTrainer(model) 65 | 66 | callbacks = [EarlyStopping(patience=10), 67 | ReduceLROnPlateau(factor=0.5, patience=5)] 68 | regularizers = [L1Regularizer(scale=1e-3, module_filter='conv*'), 69 | L2Regularizer(scale=1e-5, module_filter='fc*')] 70 | constraints = [UnitNorm(frequency=3, unit='batch', module_filter='fc*')] 71 | initializers = [XavierUniform(bias=False, module_filter='fc*')] 72 | metrics = [CategoricalAccuracy(top_k=3)] 73 | 74 | trainer.compile(loss='nll_loss', 75 | optimizer='adadelta', 76 | regularizers=regularizers, 77 | constraints=constraints, 78 | initializers=initializers, 79 | metrics=metrics, 80 | callbacks=callbacks) 81 | 82 | trainer.fit_loader(train_loader, val_loader, num_epoch=20, verbose=1) 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup(name='torchsample', 6 | version='0.1.3', 7 | description='High-Level Training, Augmentation, and Sampling for Pytorch', 8 | author='NC Cullen', 9 | author_email='nickmarch31@yahoo.com', 10 | packages=find_packages() 11 | ) -------------------------------------------------------------------------------- /tests/integration/fit_complex/multi_input_multi_target.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torchsample.modules import ModuleTrainer 7 | from torchsample import regularizers as regs 8 | from torchsample import constraints as cons 9 | from torchsample import initializers as inits 10 | from torchsample import callbacks as cbks 11 | from torchsample import metrics 12 | from torchsample import transforms as tforms 13 | 14 | import os 15 | from torchvision import datasets 16 | 17 | ROOT = '/users/ncullen/desktop/data/mnist' 18 | dataset = datasets.MNIST(ROOT, train=True, download=True) 19 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 20 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt')) 21 | 22 | x_train = x_train.float() 23 | y_train = y_train.long() 24 | x_test = x_test.float() 25 | y_test = y_test.long() 26 | 27 | x_train = x_train / 255. 28 | x_test = x_test / 255. 
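# reshape to (N, 1, 28, 28); further down, this same tensor is fed to all three inputs of Network.forward(x, y, z)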
29 | x_train = x_train.unsqueeze(1) 30 | x_test = x_test.unsqueeze(1) 31 | 32 | # only train on a subset 33 | x_train = x_train[:1000] 34 | y_train = y_train[:1000] 35 | x_test = x_test[:100] 36 | y_test = y_test[:100] 37 | 38 | 39 | # Define your model EXACTLY as if you were using nn.Module 40 | class Network(nn.Module): 41 | def __init__(self): 42 | super(Network, self).__init__() 43 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 44 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 45 | self.fc1 = nn.Linear(1600, 128) 46 | self.fc2 = nn.Linear(128, 10) 47 | 48 | def forward(self, x, y, z): 49 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 50 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 51 | x = x.view(-1, 1600) 52 | x = F.relu(self.fc1(x)) 53 | x = F.dropout(x, training=self.training) 54 | x = self.fc2(x) 55 | return F.log_softmax(x), F.log_softmax(x), F.log_softmax(x) 56 | 57 | # with one loss function given 58 | model = Network() 59 | trainer = ModuleTrainer(model) 60 | 61 | regularizers = [regs.L1Regularizer(1e-4, 'fc*'), regs.L2Regularizer(1e-5, 'conv*')] 62 | constraints = [cons.UnitNorm(5, 'batch', 'fc*'), 63 | cons.MaxNorm(5, 0, 'batch', 'conv*')] 64 | callbacks = [cbks.ReduceLROnPlateau(monitor='loss', verbose=1)] 65 | 66 | trainer.compile(loss='nll_loss', 67 | optimizer='adadelta', 68 | regularizers=regularizers, 69 | constraints=constraints, 70 | callbacks=callbacks) 71 | 72 | trainer.fit([x_train, x_train, x_train], 73 | [y_train, y_train, y_train], 74 | num_epoch=3, 75 | batch_size=128, 76 | verbose=1) 77 | 78 | yp1, yp2, yp3 = trainer.predict([x_train, x_train, x_train]) 79 | print(yp1.size(), yp2.size(), yp3.size()) 80 | 81 | eval_loss = trainer.evaluate([x_train, x_train, x_train], 82 | [y_train, y_train, y_train]) 83 | print(eval_loss) 84 | 85 | # With multiple loss functions given 86 | model = Network() 87 | trainer = ModuleTrainer(model) 88 | 89 | trainer.compile(loss=['nll_loss', 'nll_loss', 'nll_loss'], 90 | optimizer='adadelta', 91 | regularizers=regularizers, 92 | constraints=constraints, 93 | callbacks=callbacks) 94 | 95 | trainer.fit([x_train, x_train, x_train], 96 | [y_train, y_train, y_train], 97 | num_epoch=3, 98 | batch_size=128, 99 | verbose=1) 100 | 101 | # should raise exception for giving multiple loss functions 102 | # but not giving a loss function for every input 103 | try: 104 | model = Network() 105 | trainer = ModuleTrainer(model) 106 | 107 | trainer.compile(loss=['nll_loss', 'nll_loss'], 108 | optimizer='adadelta', 109 | regularizers=regularizers, 110 | constraints=constraints, 111 | callbacks=callbacks) 112 | 113 | trainer.fit([x_train, x_train, x_train], 114 | [y_train, y_train, y_train], 115 | num_epoch=3, 116 | batch_size=128, 117 | verbose=1) 118 | except: 119 | print('Exception correctly caught') 120 | 121 | -------------------------------------------------------------------------------- /tests/integration/fit_loader_simple/single_input_multi_target.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | 7 | from torchsample.modules import ModuleTrainer 8 | from torchsample import TensorDataset 9 | 10 | import os 11 | from torchvision import datasets 12 | ROOT = '/users/ncullen/desktop/data/mnist' 13 | dataset = datasets.MNIST(ROOT, train=True, download=True) 14 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 15 | x_test, y_test = 
th.load(os.path.join(dataset.root, 'processed/test.pt')) 16 | 17 | x_train = x_train.float() 18 | y_train = y_train.long() 19 | x_test = x_test.float() 20 | y_test = y_test.long() 21 | 22 | x_train = x_train / 255. 23 | x_test = x_test / 255. 24 | x_train = x_train.unsqueeze(1) 25 | x_test = x_test.unsqueeze(1) 26 | 27 | # only train on a subset 28 | x_train = x_train[:1000] 29 | y_train = y_train[:1000] 30 | x_test = x_test[:1000] 31 | y_test = y_test[:1000] 32 | 33 | train_data = TensorDataset(x_train, [y_train, y_train]) 34 | train_loader = DataLoader(train_data, batch_size=128) 35 | 36 | # Define your model EXACTLY as if you were using nn.Module 37 | class Network(nn.Module): 38 | def __init__(self): 39 | super(Network, self).__init__() 40 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 41 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 42 | self.fc1 = nn.Linear(1600, 128) 43 | self.fc2 = nn.Linear(128, 10) 44 | 45 | def forward(self, x): 46 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 47 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 48 | x = x.view(-1, 1600) 49 | x = F.relu(self.fc1(x)) 50 | x = F.dropout(x, training=self.training) 51 | x = self.fc2(x) 52 | return F.log_softmax(x), F.log_softmax(x) 53 | 54 | 55 | # one loss function for multiple targets 56 | model = Network() 57 | trainer = ModuleTrainer(model) 58 | trainer.compile(loss='nll_loss', 59 | optimizer='adadelta') 60 | 61 | trainer.fit_loader(train_loader, 62 | num_epoch=3, 63 | verbose=1) 64 | ypred1, ypred2 = trainer.predict(x_train) 65 | print(ypred1.size(), ypred2.size()) 66 | 67 | eval_loss = trainer.evaluate(x_train, [y_train, y_train]) 68 | print(eval_loss) 69 | # multiple loss functions 70 | model = Network() 71 | trainer = ModuleTrainer(model) 72 | trainer.compile(loss=['nll_loss', 'nll_loss'], 73 | optimizer='adadelta') 74 | trainer.fit_loader(train_loader, 75 | num_epoch=3, 76 | verbose=1) 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /tests/integration/fit_loader_simple/single_input_single_target.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | 7 | from torchsample.modules import ModuleTrainer 8 | from torchsample import TensorDataset 9 | 10 | import os 11 | from torchvision import datasets 12 | ROOT = '/users/ncullen/desktop/data/mnist' 13 | dataset = datasets.MNIST(ROOT, train=True, download=True) 14 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 15 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt')) 16 | 17 | x_train = x_train.float() 18 | y_train = y_train.long() 19 | x_test = x_test.float() 20 | y_test = y_test.long() 21 | 22 | x_train = x_train / 255. 23 | x_test = x_test / 255. 
24 | x_train = x_train.unsqueeze(1) 25 | x_test = x_test.unsqueeze(1) 26 | 27 | # only train on a subset 28 | x_train = x_train[:1000] 29 | y_train = y_train[:1000] 30 | x_test = x_test[:1000] 31 | y_test = y_test[:1000] 32 | 33 | train_data = TensorDataset(x_train, y_train) 34 | train_loader = DataLoader(train_data, batch_size=128) 35 | 36 | # Define your model EXACTLY as if you were using nn.Module 37 | class Network(nn.Module): 38 | def __init__(self): 39 | super(Network, self).__init__() 40 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 41 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 42 | self.fc1 = nn.Linear(1600, 128) 43 | self.fc2 = nn.Linear(128, 10) 44 | 45 | def forward(self, x): 46 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 47 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 48 | x = x.view(-1, 1600) 49 | x = F.relu(self.fc1(x)) 50 | #x = F.dropout(x, training=self.training) 51 | x = self.fc2(x) 52 | return F.log_softmax(x) 53 | 54 | 55 | model = Network() 56 | trainer = ModuleTrainer(model) 57 | 58 | trainer.compile(loss='nll_loss', 59 | optimizer='adadelta') 60 | 61 | trainer.fit_loader(train_loader, 62 | num_epoch=3, 63 | verbose=1) 64 | 65 | ypred = trainer.predict(x_train) 66 | print(ypred.size()) 67 | 68 | eval_loss = trainer.evaluate(x_train, y_train) 69 | print(eval_loss) 70 | 71 | print(trainer.history) 72 | #print(trainer.history['loss']) 73 | 74 | -------------------------------------------------------------------------------- /tests/integration/fit_simple/simple_multi_input_multi_target.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torchsample.modules import ModuleTrainer 7 | 8 | import os 9 | from torchvision import datasets 10 | ROOT = '/users/ncullen/desktop/data/mnist' 11 | dataset = datasets.MNIST(ROOT, train=True, download=True) 12 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 13 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt')) 14 | 15 | x_train = x_train.float() 16 | y_train = y_train.long() 17 | x_test = x_test.float() 18 | y_test = y_test.long() 19 | 20 | x_train = x_train / 255. 21 | x_test = x_test / 255. 
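# add the channel dimension; this test first compiles a single shared loss, then one loss per output, then checks that a loss-count mismatch raises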
22 | x_train = x_train.unsqueeze(1) 23 | x_test = x_test.unsqueeze(1) 24 | 25 | # only train on a subset 26 | x_train = x_train[:1000] 27 | y_train = y_train[:1000] 28 | x_test = x_test[:100] 29 | y_test = y_test[:100] 30 | 31 | 32 | # Define your model EXACTLY as if you were using nn.Module 33 | class Network(nn.Module): 34 | def __init__(self): 35 | super(Network, self).__init__() 36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 38 | self.fc1 = nn.Linear(1600, 128) 39 | self.fc2 = nn.Linear(128, 10) 40 | 41 | def forward(self, x, y, z): 42 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 43 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 44 | x = x.view(-1, 1600) 45 | x = F.relu(self.fc1(x)) 46 | x = F.dropout(x, training=self.training) 47 | x = self.fc2(x) 48 | return F.log_softmax(x), F.log_softmax(x), F.log_softmax(x) 49 | 50 | # with one loss function given 51 | model = Network() 52 | trainer = ModuleTrainer(model) 53 | 54 | trainer.compile(loss='nll_loss', 55 | optimizer='adadelta') 56 | 57 | trainer.fit([x_train, x_train, x_train], 58 | [y_train, y_train, y_train], 59 | num_epoch=3, 60 | batch_size=128, 61 | verbose=1) 62 | 63 | yp1, yp2, yp3 = trainer.predict([x_train, x_train, x_train]) 64 | print(yp1.size(), yp2.size(), yp3.size()) 65 | 66 | eval_loss = trainer.evaluate([x_train, x_train, x_train], 67 | [y_train, y_train, y_train]) 68 | print(eval_loss) 69 | 70 | # With multiple loss functions given 71 | model = Network() 72 | trainer = ModuleTrainer(model) 73 | 74 | trainer.compile(loss=['nll_loss', 'nll_loss', 'nll_loss'], 75 | optimizer='adadelta') 76 | 77 | trainer.fit([x_train, x_train, x_train], 78 | [y_train, y_train, y_train], 79 | num_epoch=3, 80 | batch_size=128, 81 | verbose=1) 82 | 83 | # should raise exception for giving multiple loss functions 84 | # but not giving a loss function for every input 85 | try: 86 | model = Network() 87 | trainer = ModuleTrainer(model) 88 | 89 | trainer.compile(loss=['nll_loss', 'nll_loss'], 90 | optimizer='adadelta') 91 | 92 | trainer.fit([x_train, x_train, x_train], 93 | [y_train, y_train, y_train], 94 | num_epoch=3, 95 | batch_size=128, 96 | verbose=1) 97 | except: 98 | print('Exception correctly caught') 99 | 100 | -------------------------------------------------------------------------------- /tests/integration/fit_simple/simple_multi_input_no_target.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torchsample.modules import ModuleTrainer 7 | 8 | import os 9 | from torchvision import datasets 10 | ROOT = '/users/ncullen/desktop/data/mnist' 11 | dataset = datasets.MNIST(ROOT, train=True, download=True) 12 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 13 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt')) 14 | 15 | x_train = x_train.float() 16 | y_train = y_train.long() 17 | x_test = x_test.float() 18 | y_test = y_test.long() 19 | 20 | x_train = x_train / 255. 21 | x_test = x_test / 255. 
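# channel dimension again; note this test has no targets: the output th.abs(10 - x) is minimized directly through the 'unconstrained_sum' loss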
22 | x_train = x_train.unsqueeze(1) 23 | x_test = x_test.unsqueeze(1) 24 | 25 | # only train on a subset 26 | x_train = x_train[:1000] 27 | y_train = y_train[:1000] 28 | x_test = x_test[:1000] 29 | y_test = y_test[:1000] 30 | 31 | 32 | # Define your model EXACTLY as if you were using nn.Module 33 | class Network(nn.Module): 34 | def __init__(self): 35 | super(Network, self).__init__() 36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 38 | self.fc1 = nn.Linear(1600, 128) 39 | self.fc2 = nn.Linear(128, 1) 40 | 41 | def forward(self, x, y, z): 42 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 43 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 44 | x = x.view(-1, 1600) 45 | x = F.relu(self.fc1(x)) 46 | x = F.dropout(x, training=self.training) 47 | x = self.fc2(x) 48 | return th.abs(10 - x) 49 | 50 | 51 | model = Network() 52 | trainer = ModuleTrainer(model) 53 | 54 | trainer.compile(loss='unconstrained_sum', 55 | optimizer='adadelta') 56 | 57 | trainer.fit([x_train, x_train, x_train], 58 | num_epoch=3, 59 | batch_size=128, 60 | verbose=1) 61 | 62 | ypred = trainer.predict([x_train, x_train, x_train]) 63 | print(ypred.size()) 64 | 65 | eval_loss = trainer.evaluate([x_train, x_train, x_train]) 66 | print(eval_loss) 67 | 68 | -------------------------------------------------------------------------------- /tests/integration/fit_simple/simple_multi_input_single_target.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torchsample.modules import ModuleTrainer 7 | 8 | import os 9 | from torchvision import datasets 10 | ROOT = '/users/ncullen/desktop/data/mnist' 11 | dataset = datasets.MNIST(ROOT, train=True, download=True) 12 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 13 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt')) 14 | 15 | x_train = x_train.float() 16 | y_train = y_train.long() 17 | x_test = x_test.float() 18 | y_test = y_test.long() 19 | 20 | x_train = x_train / 255. 21 | x_test = x_test / 255. 
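# three duplicated inputs are passed to forward(x, y, z) below, though only x is actually used; the three inputs share a single target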
22 | x_train = x_train.unsqueeze(1) 23 | x_test = x_test.unsqueeze(1) 24 | 25 | # only train on a subset 26 | x_train = x_train[:1000] 27 | y_train = y_train[:1000] 28 | x_test = x_test[:100] 29 | y_test = y_test[:100] 30 | 31 | 32 | # Define your model EXACTLY as if you were using nn.Module 33 | class Network(nn.Module): 34 | def __init__(self): 35 | super(Network, self).__init__() 36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 38 | self.fc1 = nn.Linear(1600, 128) 39 | self.fc2 = nn.Linear(128, 10) 40 | 41 | def forward(self, x, y, z): 42 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 43 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 44 | x = x.view(-1, 1600) 45 | x = F.relu(self.fc1(x)) 46 | x = F.dropout(x, training=self.training) 47 | x = self.fc2(x) 48 | return F.log_softmax(x) 49 | 50 | 51 | model = Network() 52 | trainer = ModuleTrainer(model) 53 | 54 | trainer.compile(loss='nll_loss', 55 | optimizer='adadelta') 56 | 57 | trainer.fit([x_train, x_train, x_train], y_train, 58 | val_data=([x_test, x_test, x_test], y_test), 59 | num_epoch=3, 60 | batch_size=128, 61 | verbose=1) 62 | 63 | ypred = trainer.predict([x_train, x_train, x_train]) 64 | print(ypred.size()) 65 | 66 | eval_loss = trainer.evaluate([x_train, x_train, x_train], y_train) 67 | print(eval_loss) -------------------------------------------------------------------------------- /tests/integration/fit_simple/single_input_multi_target.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torchsample.modules import ModuleTrainer 7 | 8 | import os 9 | from torchvision import datasets 10 | ROOT = '/users/ncullen/desktop/data/mnist' 11 | dataset = datasets.MNIST(ROOT, train=True, download=True) 12 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 13 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt')) 14 | 15 | x_train = x_train.float() 16 | y_train = y_train.long() 17 | x_test = x_test.float() 18 | y_test = y_test.long() 19 | 20 | x_train = x_train / 255. 21 | x_test = x_test / 255. 
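# one input, two heads: forward() below returns two log-softmax outputs, trained against a duplicated target list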
22 | x_train = x_train.unsqueeze(1) 23 | x_test = x_test.unsqueeze(1) 24 | 25 | # only train on a subset 26 | x_train = x_train[:1000] 27 | y_train = y_train[:1000] 28 | x_test = x_test[:1000] 29 | y_test = y_test[:1000] 30 | 31 | 32 | # Define your model EXACTLY as if you were using nn.Module 33 | class Network(nn.Module): 34 | def __init__(self): 35 | super(Network, self).__init__() 36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 38 | self.fc1 = nn.Linear(1600, 128) 39 | self.fc2 = nn.Linear(128, 10) 40 | 41 | def forward(self, x): 42 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 43 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 44 | x = x.view(-1, 1600) 45 | x = F.relu(self.fc1(x)) 46 | x = F.dropout(x, training=self.training) 47 | x = self.fc2(x) 48 | return F.log_softmax(x), F.log_softmax(x) 49 | 50 | 51 | # one loss function for multiple targets 52 | model = Network() 53 | trainer = ModuleTrainer(model) 54 | trainer.compile(loss='nll_loss', 55 | optimizer='adadelta') 56 | 57 | trainer.fit(x_train, 58 | [y_train, y_train], 59 | num_epoch=3, 60 | batch_size=128, 61 | verbose=1) 62 | ypred1, ypred2 = trainer.predict(x_train) 63 | print(ypred1.size(), ypred2.size()) 64 | 65 | eval_loss = trainer.evaluate(x_train, [y_train, y_train]) 66 | print(eval_loss) 67 | # multiple loss functions 68 | model = Network() 69 | trainer = ModuleTrainer(model) 70 | trainer.compile(loss=['nll_loss', 'nll_loss'], 71 | optimizer='adadelta') 72 | trainer.fit(x_train, 73 | [y_train, y_train], 74 | num_epoch=3, 75 | batch_size=128, 76 | verbose=1) 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /tests/integration/fit_simple/single_input_no_target.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torchsample.modules import ModuleTrainer 7 | 8 | import os 9 | from torchvision import datasets 10 | ROOT = '/users/ncullen/desktop/data/mnist' 11 | dataset = datasets.MNIST(ROOT, train=True, download=True) 12 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 13 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt')) 14 | 15 | x_train = x_train.float() 16 | y_train = y_train.long() 17 | x_test = x_test.float() 18 | y_test = y_test.long() 19 | 20 | x_train = x_train / 255. 21 | x_test = x_test / 255. 
22 | x_train = x_train.unsqueeze(1) 23 | x_test = x_test.unsqueeze(1) 24 | 25 | # only train on a subset 26 | x_train = x_train[:1000] 27 | y_train = y_train[:1000] 28 | x_test = x_test[:1000] 29 | y_test = y_test[:1000] 30 | 31 | 32 | # Define your model EXACTLY as if you were using nn.Module 33 | class Network(nn.Module): 34 | def __init__(self): 35 | super(Network, self).__init__() 36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 38 | self.fc1 = nn.Linear(1600, 128) 39 | self.fc2 = nn.Linear(128, 1) 40 | 41 | def forward(self, x): 42 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 43 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 44 | x = x.view(-1, 1600) 45 | x = F.relu(self.fc1(x)) 46 | x = F.dropout(x, training=self.training) 47 | x = self.fc2(x) 48 | return th.abs(10 - x) 49 | 50 | 51 | model = Network() 52 | trainer = ModuleTrainer(model) 53 | 54 | trainer.compile(loss='unconstrained_sum', 55 | optimizer='adadelta') 56 | 57 | trainer.fit(x_train, 58 | num_epoch=3, 59 | batch_size=128, 60 | verbose=1) 61 | 62 | ypred = trainer.predict(x_train) 63 | print(ypred.size()) 64 | 65 | eval_loss = trainer.evaluate(x_train, None) 66 | print(eval_loss) -------------------------------------------------------------------------------- /tests/integration/fit_simple/single_input_single_target.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torchsample.modules import ModuleTrainer 7 | from torchsample import regularizers as reg 8 | from torchsample import constraints as con 9 | 10 | import os 11 | from torchvision import datasets 12 | ROOT = '/users/ncullen/desktop/data/mnist' 13 | dataset = datasets.MNIST(ROOT, train=True, download=True) 14 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt')) 15 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt')) 16 | 17 | x_train = x_train.float() 18 | y_train = y_train.long() 19 | x_test = x_test.float() 20 | y_test = y_test.long() 21 | 22 | x_train = x_train / 255. 23 | x_test = x_test / 255. 
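# the simplest case: one input, one target; this variant also passes an L1Regularizer directly to trainer.compile() below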
24 | x_train = x_train.unsqueeze(1) 25 | x_test = x_test.unsqueeze(1) 26 | 27 | # only train on a subset 28 | x_train = x_train[:1000] 29 | y_train = y_train[:1000] 30 | x_test = x_test[:1000] 31 | y_test = y_test[:1000] 32 | 33 | 34 | # Define your model EXACTLY as if you were using nn.Module 35 | class Network(nn.Module): 36 | def __init__(self): 37 | super(Network, self).__init__() 38 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 39 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 40 | self.fc1 = nn.Linear(1600, 128) 41 | self.fc2 = nn.Linear(128, 10) 42 | 43 | def forward(self, x): 44 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 45 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 46 | x = x.view(-1, 1600) 47 | x = F.relu(self.fc1(x)) 48 | #x = F.dropout(x, training=self.training) 49 | x = self.fc2(x) 50 | return F.log_softmax(x) 51 | 52 | 53 | model = Network() 54 | trainer = ModuleTrainer(model) 55 | 56 | trainer.compile(loss='nll_loss', 57 | optimizer='adadelta', 58 | regularizers=[reg.L1Regularizer(1e-4)]) 59 | 60 | trainer.fit(x_train, y_train, 61 | val_data=(x_test, y_test), 62 | num_epoch=3, 63 | batch_size=128, 64 | verbose=1) 65 | 66 | ypred = trainer.predict(x_train) 67 | print(ypred.size()) 68 | 69 | eval_loss = trainer.evaluate(x_train, y_train) 70 | print(eval_loss) 71 | 72 | print(trainer.history) 73 | #print(trainer.history['loss']) 74 | 75 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from torch.autograd import Variable 4 | 5 | from torchsample.metrics import CategoricalAccuracy 6 | 7 | class TestMetrics(unittest.TestCase): 8 | 9 | def test_categorical_accuracy(self): 10 | metric = CategoricalAccuracy() 11 | predicted = Variable(torch.eye(10)) 12 | expected = Variable(torch.LongTensor(list(range(10)))) 13 | self.assertEqual(metric(predicted, expected), 100.0) 14 | 15 | # Set 1st column to ones 16 | predicted = Variable(torch.zeros(10, 10)) 17 | predicted.data[:, 0] = torch.ones(10) 18 | self.assertEqual(metric(predicted, expected), 55.0) 19 | 20 | if __name__ == '__main__': 21 | unittest.main() -------------------------------------------------------------------------------- /tests/transforms/test_affine_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test affine transforms 3 | 4 | Transforms: 5 | - Affine + RandomAffine 6 | - AffineCompose 7 | - Rotate + RandomRotate 8 | - Translate + RandomTranslate 9 | - Shear + RandomShear 10 | - Zoom + RandomZoom 11 | """ 12 | 13 | #import pytest 14 | 15 | import torch as th 16 | 17 | from torchsample.transforms import (RandomAffine, Affine, 18 | RandomRotate, RandomChoiceRotate, Rotate, 19 | RandomTranslate, RandomChoiceTranslate, Translate, 20 | RandomShear, RandomChoiceShear, Shear, 21 | RandomZoom, RandomChoiceZoom, Zoom) 22 | 23 | # ---------------------------------------------------- 24 | # ---------------------------------------------------- 25 | 26 | ## DATA SET ## 27 | def gray2d_setup(): 28 | images = {} 29 | 30 | x = th.zeros(1,30,30) 31 | x[:,10:21,10:21] = 1 32 | images['gray_01'] = x 33 | 34 | x = th.zeros(1,30,40) 35 | x[:,10:21,10:21] = 1 36 | images['gray_02'] = x 37 | 38 | return images 39 | 40 | def multi_gray2d_setup(): 41 | old_imgs = gray2d_setup() 42 | images = {} 43 | for k,v in old_imgs.items(): 44 | images[k+'_2imgs'] = [v,v] 45 | images[k+'_3imgs'] = [v,v,v] 46 | 
images[k+'_4imgs'] = [v,v,v,v] 47 | return images 48 | 49 | def color2d_setup(): 50 | images = {} 51 | 52 | x = th.zeros(3,30,30) 53 | x[:,10:21,10:21] = 1 54 | images['color_01'] = x 55 | 56 | x = th.zeros(3,30,40) 57 | x[:,10:21,10:21] = 1 58 | images['color_02'] = x 59 | 60 | return images 61 | 62 | def multi_color2d_setup(): 63 | old_imgs = color2d_setup() 64 | images = {} 65 | for k,v in old_imgs.items(): 66 | images[k+'_2imgs'] = [v,v] 67 | images[k+'_3imgs'] = [v,v,v] 68 | images[k+'_4imgs'] = [v,v,v,v] 69 | return images 70 | 71 | 72 | # ---------------------------------------------------- 73 | # ---------------------------------------------------- 74 | 75 | 76 | def Affine_setup(): 77 | tforms = {} 78 | tforms['random_affine'] = RandomAffine(rotation_range=30, 79 | translation_range=0.1) 80 | tforms['affine'] = Affine(th.FloatTensor([[0.9,0,0],[0,0.9,0]])) 81 | return tforms 82 | 83 | def Rotate_setup(): 84 | tforms = {} 85 | tforms['random_rotate'] = RandomRotate(30) 86 | tforms['random_choice_rotate'] = RandomChoiceRotate([30,40,50]) 87 | tforms['rotate'] = Rotate(30) 88 | return tforms 89 | 90 | def Translate_setup(): 91 | tforms = {} 92 | tforms['random_translate'] = RandomTranslate(0.1) 93 | tforms['random_choice_translate'] = RandomChoiceTranslate([0.1,0.2]) 94 | tforms['translate'] = Translate(0.3) 95 | return tforms 96 | 97 | def Shear_setup(): 98 | tforms = {} 99 | tforms['random_shear'] = RandomShear(30) 100 | tforms['random_choice_shear'] = RandomChoiceShear([20,30,40]) 101 | tforms['shear'] = Shear(25) 102 | return tforms 103 | 104 | def Zoom_setup(): 105 | tforms = {} 106 | tforms['random_zoom'] = RandomZoom((0.8,1.2)) 107 | tforms['random_choice_zoom'] = RandomChoiceZoom([0.8,0.9,1.1,1.2]) 108 | tforms['zoom'] = Zoom(0.9) 109 | return tforms 110 | 111 | # ---------------------------------------------------- 112 | # ---------------------------------------------------- 113 | 114 | def test_affine_transforms_runtime(verbose=1): 115 | """ 116 | Test that there are no runtime errors 117 | """ 118 | ### MAKE TRANSFORMS ### 119 | tforms = {} 120 | tforms.update(Affine_setup()) 121 | tforms.update(Rotate_setup()) 122 | tforms.update(Translate_setup()) 123 | tforms.update(Shear_setup()) 124 | tforms.update(Zoom_setup()) 125 | 126 | ### MAKE DATA 127 | images = {} 128 | images.update(gray2d_setup()) 129 | images.update(multi_gray2d_setup()) 130 | images.update(color2d_setup()) 131 | images.update(multi_color2d_setup()) 132 | 133 | successes = [] 134 | failures = [] 135 | for im_key, im_val in images.items(): 136 | for tf_key, tf_val in tforms.items(): 137 | try: 138 | if isinstance(im_val, (tuple,list)): 139 | tf_val(*im_val) 140 | else: 141 | tf_val(im_val) 142 | successes.append((im_key, tf_key)) 143 | except: 144 | failures.append((im_key, tf_key)) 145 | 146 | if verbose > 0: 147 | for k, v in failures: 148 | print('%s - %s' % (k, v)) 149 | 150 | print('# SUCCESSES: ', len(successes)) 151 | print('# FAILURES: ' , len(failures)) 152 | 153 | 154 | if __name__=='__main__': 155 | test_affine_transforms_runtime() 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /tests/transforms/test_image_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for torchsample/transforms/image_transforms.py 3 | """ 4 | 5 | 6 | import torch as th 7 | 8 | from torchsample.transforms import (Grayscale, RandomGrayscale, 9 | Gamma, RandomGamma, 
RandomChoiceGamma, 10 | Brightness, RandomBrightness, RandomChoiceBrightness, 11 | Saturation, RandomSaturation, RandomChoiceSaturation, 12 | Contrast, RandomContrast, RandomChoiceContrast) 13 | 14 | # ---------------------------------------------------- 15 | # ---------------------------------------------------- 16 | 17 | ## DATA SET ## 18 | def gray2d_setup(): 19 | images = {} 20 | 21 | x = th.zeros(1,30,30) 22 | x[:,10:21,10:21] = 1 23 | images['gray_01'] = x 24 | 25 | x = th.zeros(1,30,40) 26 | x[:,10:21,10:21] = 1 27 | images['gray_02'] = x 28 | 29 | return images 30 | 31 | def multi_gray2d_setup(): 32 | old_imgs = gray2d_setup() 33 | images = {} 34 | for k,v in old_imgs.items(): 35 | images[k+'_2imgs'] = [v,v] 36 | images[k+'_3imgs'] = [v,v,v] 37 | images[k+'_4imgs'] = [v,v,v,v] 38 | return images 39 | 40 | def color2d_setup(): 41 | images = {} 42 | 43 | x = th.zeros(3,30,30) 44 | x[:,10:21,10:21] = 1 45 | images['color_01'] = x 46 | 47 | x = th.zeros(3,30,40) 48 | x[:,10:21,10:21] = 1 49 | images['color_02'] = x 50 | 51 | return images 52 | 53 | def multi_color2d_setup(): 54 | old_imgs = color2d_setup() 55 | images = {} 56 | for k,v in old_imgs.items(): 57 | images[k+'_2imgs'] = [v,v] 58 | images[k+'_3imgs'] = [v,v,v] 59 | images[k+'_4imgs'] = [v,v,v,v] 60 | return images 61 | 62 | # ---------------------------------------------------- 63 | # ---------------------------------------------------- 64 | 65 | ## TFORMS SETUP ### 66 | def Grayscale_setup(): 67 | tforms = {} 68 | tforms['grayscale_keepchannels'] = Grayscale(keep_channels=True) 69 | tforms['grayscale_dontkeepchannels'] = Grayscale(keep_channels=False) 70 | 71 | tforms['random_grayscale_nop'] = RandomGrayscale() 72 | tforms['random_grayscale_p_01'] = RandomGrayscale(0) 73 | tforms['random_grayscale_p_02'] = RandomGrayscale(0.5) 74 | tforms['random_grayscale_p_03'] = RandomGrayscale(1) 75 | 76 | return tforms 77 | 78 | def Gamma_setup(): 79 | tforms = {} 80 | tforms['gamma_<1'] = Gamma(value=0.5) 81 | tforms['gamma_=1'] = Gamma(value=1.0) 82 | tforms['gamma_>1'] = Gamma(value=1.5) 83 | tforms['random_gamma_01'] = RandomGamma(0.5,1.5) 84 | tforms['random_gamma_02'] = RandomGamma(0.5,1.0) 85 | tforms['random_gamma_03'] = RandomGamma(1.0,1.5) 86 | tforms['random_choice_gamma_01'] = RandomChoiceGamma([0.5,1.0]) 87 | tforms['random_choice_gamma_02'] = RandomChoiceGamma([0.5,1.0],p=[0.5,0.5]) 88 | tforms['random_choice_gamma_03'] = RandomChoiceGamma([0.5,1.0],p=[0.2,0.8]) 89 | 90 | return tforms 91 | 92 | def Brightness_setup(): 93 | tforms = {} 94 | tforms['brightness_=-1'] = Brightness(value=-1) 95 | tforms['brightness_<0'] = Brightness(value=-0.5) 96 | tforms['brightness_=0'] = Brightness(value=0) 97 | tforms['brightness_>0'] = Brightness(value=0.5) 98 | tforms['brightness_=1'] = Brightness(value=1) 99 | 100 | tforms['random_brightness_01'] = RandomBrightness(-1,-0.5) 101 | tforms['random_brightness_02'] = RandomBrightness(-0.5,0) 102 | tforms['random_brightness_03'] = RandomBrightness(0,0.5) 103 | tforms['random_brightness_04'] = RandomBrightness(0.5,1) 104 | 105 | tforms['random_choice_brightness_01'] = RandomChoiceBrightness([-1,0,1]) 106 | tforms['random_choice_brightness_02'] = RandomChoiceBrightness([-1,0,1],p=[0.2,0.5,0.3]) 107 | tforms['random_choice_brightness_03'] = RandomChoiceBrightness([0,0,0,0],p=[0.25,0.5,0.25,0.25]) 108 | 109 | return tforms 110 | 111 | def Saturation_setup(): 112 | tforms = {} 113 | tforms['saturation_=-1'] = Saturation(-1) 114 | tforms['saturation_<0'] = Saturation(-0.5) 115 | 
tforms['saturation_=0'] = Saturation(0) 116 | tforms['saturation_>0'] = Saturation(0.5) 117 | tforms['saturation_=1'] = Saturation(1) 118 | 119 | tforms['random_saturation_01'] = RandomSaturation(-1,-0.5) 120 | tforms['random_saturation_02'] = RandomSaturation(-0.5,0) 121 | tforms['random_saturation_03'] = RandomSaturation(0,0.5) 122 | tforms['random_saturation_04'] = RandomSaturation(0.5,1) 123 | 124 | tforms['random_choice_saturation_01'] = RandomChoiceSaturation([-1,0,1]) 125 | tforms['random_choice_saturation_02'] = RandomChoiceSaturation([-1,0,1],p=[0.2,0.5,0.3]) 126 | tforms['random_choice_saturation_03'] = RandomChoiceSaturation([0,0,0,0],p=[0.25,0.5,0.25,0.25]) 127 | 128 | return tforms 129 | 130 | def Contrast_setup(): 131 | tforms = {} 132 | tforms['contrast_<<0'] = Contrast(-10) 133 | tforms['contrast_<0'] = Contrast(-1) 134 | tforms['contrast_=0'] = Contrast(0) 135 | tforms['contrast_>0'] = Contrast(1) 136 | tforms['contrast_>>0'] = Contrast(10) 137 | 138 | tforms['random_contrast_01'] = RandomContrast(-10,-1) 139 | tforms['random_contrast_02'] = RandomContrast(-1,0) 140 | tforms['random_contrast_03'] = RandomContrast(0,1) 141 | tforms['random_contrast_04'] = RandomContrast(1,10) 142 | 143 | tforms['random_choice_contrast_01'] = RandomChoiceContrast([-1,0,1]) 144 | tforms['random_choice_contrast_02'] = RandomChoiceContrast([-10,0,10],p=[0.2,0.5,0.3]) 145 | tforms['random_choice_contrast_03'] = RandomChoiceContrast([1,1],p=[0.5,0.5]) 146 | 147 | return tforms 148 | 149 | # ---------------------------------------------------- 150 | # ---------------------------------------------------- 151 | 152 | def test_image_transforms_runtime(verbose=1): 153 | """ 154 | Test that there are no runtime errors 155 | """ 156 | ### MAKE TRANSFORMS ### 157 | tforms = {} 158 | tforms.update(Gamma_setup()) 159 | tforms.update(Brightness_setup()) 160 | tforms.update(Saturation_setup()) 161 | tforms.update(Contrast_setup()) 162 | 163 | ### MAKE DATA ### 164 | images = {} 165 | images.update(gray2d_setup()) 166 | images.update(multi_gray2d_setup()) 167 | images.update(color2d_setup()) 168 | images.update(multi_color2d_setup()) 169 | 170 | successes = [] 171 | failures = [] 172 | for im_key, im_val in images.items(): 173 | for tf_key, tf_val in tforms.items(): 174 | try: 175 | if isinstance(im_val, (tuple,list)): 176 | tf_val(*im_val) 177 | else: 178 | tf_val(im_val) 179 | successes.append((im_key, tf_key)) 180 | except: 181 | failures.append((im_key, tf_key)) 182 | 183 | if verbose > 0: 184 | for k, v in failures: 185 | print('%s - %s' % (k, v)) 186 | 187 | print('# SUCCESSES: ', len(successes)) 188 | print('# FAILURES: ' , len(failures)) 189 | 190 | 191 | if __name__=='__main__': 192 | test_image_transforms_runtime() 193 | -------------------------------------------------------------------------------- /tests/transforms/test_tensor_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for torchsample/transforms/tensor_transforms.py 3 | """ 4 | 5 | 6 | import torch as th 7 | 8 | from torchsample.transforms import (ToTensor, 9 | ToVariable, 10 | ToCuda, 11 | ToFile, 12 | ChannelsLast, HWC, 13 | ChannelsFirst, CHW, 14 | TypeCast, 15 | AddChannel, 16 | Transpose, 17 | RangeNormalize, 18 | StdNormalize, 19 | RandomCrop, 20 | SpecialCrop, 21 | Pad, 22 | RandomFlip, 23 | RandomOrder) 24 | 25 | # ---------------------------------------------------- 26 | 27 | ## DATA SET ## 28 | def gray2d_setup(): 29 | images = {} 30 | 31 | x = th.zeros(1,30,30) 
32 | x[:,10:21,10:21] = 1 33 | images['gray_01'] = x 34 | 35 | x = th.zeros(1,30,40) 36 | x[:,10:21,10:21] = 1 37 | images['gray_02'] = x 38 | return images 39 | 40 | def multi_gray2d_setup(): 41 | old_imgs = gray2d_setup() 42 | images = {} 43 | for k,v in old_imgs.items(): 44 | images[k+'_2imgs'] = [v,v] 45 | images[k+'_3imgs'] = [v,v,v] 46 | images[k+'_4imgs'] = [v,v,v,v] 47 | return images 48 | 49 | def color2d_setup(): 50 | images = {} 51 | 52 | x = th.zeros(3,30,30) 53 | x[:,10:21,10:21] = 1 54 | images['color_01'] = x 55 | 56 | x = th.zeros(3,30,40) 57 | x[:,10:21,10:21] = 1 58 | images['color_02'] = x 59 | 60 | return images 61 | 62 | def multi_color2d_setup(): 63 | old_imgs = color2d_setup() 64 | images = {} 65 | for k,v in old_imgs.items(): 66 | images[k+'_2imgs'] = [v,v] 67 | images[k+'_3imgs'] = [v,v,v] 68 | images[k+'_4imgs'] = [v,v,v,v] 69 | return images 70 | # ---------------------------------------------------- 71 | # ---------------------------------------------------- 72 | 73 | ## TFORMS SETUP ### 74 | def ToTensor_setup(): 75 | tforms = {} 76 | 77 | tforms['totensor'] = ToTensor() 78 | 79 | return tforms 80 | 81 | def ToVariable_setup(): 82 | tforms = {} 83 | 84 | tforms['tovariable'] = ToVariable() 85 | 86 | return tforms 87 | 88 | def ToCuda_setup(): 89 | tforms = {} 90 | 91 | tforms['tocuda'] = ToCuda() 92 | 93 | return tforms 94 | 95 | def ToFile_setup(): 96 | tforms = {} 97 | 98 | ROOT = '~/desktop/data/' 99 | tforms['tofile_npy'] = ToFile(root=ROOT, fmt='npy') 100 | tforms['tofile_pth'] = ToFile(root=ROOT, fmt='pth') 101 | tforms['tofile_jpg'] = ToFile(root=ROOT, fmt='jpg') 102 | tforms['tofile_png'] = ToFile(root=ROOT, fmt='png') 103 | 104 | return tforms 105 | 106 | def ChannelsLast_setup(): 107 | tforms = {} 108 | 109 | tforms['channels_last'] = ChannelsLast() 110 | tforms['hwc'] = HWC() 111 | 112 | return tforms 113 | 114 | def ChannelsFirst_setup(): 115 | tforms = {} 116 | 117 | tforms['channels_first'] = ChannelsFirst() 118 | tforms['chw'] = CHW() 119 | 120 | return tforms 121 | 122 | def TypeCast_setup(): 123 | tforms = {} 124 | 125 | tforms['byte'] = TypeCast('byte') 126 | tforms['double'] = TypeCast('double') 127 | tforms['float'] = TypeCast('float') 128 | tforms['int'] = TypeCast('int') 129 | tforms['long'] = TypeCast('long') 130 | tforms['short'] = TypeCast('short') 131 | 132 | return tforms 133 | 134 | def AddChannel_setup(): 135 | tforms = {} 136 | 137 | tforms['addchannel_axis0'] = AddChannel(axis=0) 138 | tforms['addchannel_axis1'] = AddChannel(axis=1) 139 | tforms['addchannel_axis2'] = AddChannel(axis=2) 140 | 141 | return tforms 142 | 143 | def Transpose_setup(): 144 | tforms = {} 145 | 146 | tforms['transpose_01'] = Transpose(0, 1) 147 | tforms['transpose_02'] = Transpose(0, 2) 148 | tforms['transpose_10'] = Transpose(1, 0) 149 | tforms['transpose_12'] = Transpose(1, 2) 150 | tforms['transpose_20'] = Transpose(2, 0) 151 | tforms['transpose_21'] = Transpose(2, 1) 152 | 153 | return tforms 154 | 155 | def RangeNormalize_setup(): 156 | tforms = {} 157 | 158 | tforms['rangenorm_01'] = RangeNormalize(0, 1) 159 | tforms['rangenorm_-11'] = RangeNormalize(-1, 1) 160 | tforms['rangenorm_-33'] = RangeNormalize(-3, 3) 161 | tforms['rangenorm_02'] = RangeNormalize(0, 2) 162 | 163 | return tforms 164 | 165 | def StdNormalize_setup(): 166 | tforms = {} 167 | 168 | tforms['stdnorm'] = StdNormalize() 169 | 170 | return tforms 171 | 172 | def RandomCrop_setup(): 173 | tforms = {} 174 | 175 | tforms['randomcrop_1010'] = RandomCrop((10,10)) 176 | 
tforms['randomcrop_510'] = RandomCrop((5,10)) 177 | tforms['randomcrop_105'] = RandomCrop((10,5)) 178 | tforms['randomcrop_99'] = RandomCrop((9,9)) 179 | tforms['randomcrop_79'] = RandomCrop((7,9)) 180 | tforms['randomcrop_97'] = RandomCrop((9,7)) 181 | 182 | return tforms 183 | 184 | def SpecialCrop_setup(): 185 | tforms = {} 186 | 187 | tforms['specialcrop_0_1010'] = SpecialCrop((10,10),0) 188 | tforms['specialcrop_0_510'] = SpecialCrop((5,10),0) 189 | tforms['specialcrop_0_105'] = SpecialCrop((10,5),0) 190 | tforms['specialcrop_0_99'] = SpecialCrop((9,9),0) 191 | tforms['specialcrop_0_79'] = SpecialCrop((7,9),0) 192 | tforms['specialcrop_0_97'] = SpecialCrop((9,7),0) 193 | 194 | tforms['specialcrop_1_1010'] = SpecialCrop((10,10),1) 195 | tforms['specialcrop_1_510'] = SpecialCrop((5,10),1) 196 | tforms['specialcrop_1_105'] = SpecialCrop((10,5),1) 197 | tforms['specialcrop_1_99'] = SpecialCrop((9,9),1) 198 | tforms['specialcrop_1_79'] = SpecialCrop((7,9),1) 199 | tforms['specialcrop_1_97'] = SpecialCrop((9,7),1) 200 | 201 | tforms['specialcrop_2_1010'] = SpecialCrop((10,10),2) 202 | tforms['specialcrop_2_510'] = SpecialCrop((5,10),2) 203 | tforms['specialcrop_2_105'] = SpecialCrop((10,5),2) 204 | tforms['specialcrop_2_99'] = SpecialCrop((9,9),2) 205 | tforms['specialcrop_2_79'] = SpecialCrop((7,9),2) 206 | tforms['specialcrop_2_97'] = SpecialCrop((9,7),2) 207 | 208 | tforms['specialcrop_3_1010'] = SpecialCrop((10,10),3) 209 | tforms['specialcrop_3_510'] = SpecialCrop((5,10),3) 210 | tforms['specialcrop_3_105'] = SpecialCrop((10,5),3) 211 | tforms['specialcrop_3_99'] = SpecialCrop((9,9),3) 212 | tforms['specialcrop_3_79'] = SpecialCrop((7,9),3) 213 | tforms['specialcrop_3_97'] = SpecialCrop((9,7),3) 214 | 215 | tforms['specialcrop_4_1010'] = SpecialCrop((10,10),4) 216 | tforms['specialcrop_4_510'] = SpecialCrop((5,10),4) 217 | tforms['specialcrop_4_105'] = SpecialCrop((10,5),4) 218 | tforms['specialcrop_4_99'] = SpecialCrop((9,9),4) 219 | tforms['specialcrop_4_79'] = SpecialCrop((7,9),4) 220 | tforms['specialcrop_4_97'] = SpecialCrop((9,7),4) 221 | return tforms 222 | 223 | def Pad_setup(): 224 | tforms = {} 225 | 226 | tforms['pad_4040'] = Pad((40,40)) 227 | tforms['pad_3040'] = Pad((30,40)) 228 | tforms['pad_4030'] = Pad((40,30)) 229 | tforms['pad_3939'] = Pad((39,39)) 230 | tforms['pad_3941'] = Pad((39,41)) 231 | tforms['pad_4139'] = Pad((41,39)) 232 | tforms['pad_4138'] = Pad((41,38)) 233 | tforms['pad_3841'] = Pad((38,41)) 234 | 235 | return tforms 236 | 237 | def RandomFlip_setup(): 238 | tforms = {} 239 | 240 | tforms['randomflip_h_01'] = RandomFlip(h=True, v=False) 241 | tforms['randomflip_h_02'] = RandomFlip(h=True, v=False, p=0) 242 | tforms['randomflip_h_03'] = RandomFlip(h=True, v=False, p=1) 243 | tforms['randomflip_h_04'] = RandomFlip(h=True, v=False, p=0.3) 244 | tforms['randomflip_v_01'] = RandomFlip(h=False, v=True) 245 | tforms['randomflip_v_02'] = RandomFlip(h=False, v=True, p=0) 246 | tforms['randomflip_v_03'] = RandomFlip(h=False, v=True, p=1) 247 | tforms['randomflip_v_04'] = RandomFlip(h=False, v=True, p=0.3) 248 | tforms['randomflip_hv_01'] = RandomFlip(h=True, v=True) 249 | tforms['randomflip_hv_02'] = RandomFlip(h=True, v=True, p=0) 250 | tforms['randomflip_hv_03'] = RandomFlip(h=True, v=True, p=1) 251 | tforms['randomflip_hv_04'] = RandomFlip(h=True, v=True, p=0.3) 252 | return tforms 253 | 254 | def RandomOrder_setup(): 255 | tforms = {} 256 | 257 | tforms['randomorder'] = RandomOrder() 258 | 259 | return tforms 260 | 261 | # 
---------------------------------------------------- 262 | # ---------------------------------------------------- 263 | 264 | def test_tensor_transforms_runtime(verbose=1): 265 | ### MAKE TRANSFORMS ### 266 | tforms = {} 267 | tforms.update(ToTensor_setup()) 268 | tforms.update(ToVariable_setup()) 269 | tforms.update(ToCuda_setup()) 270 | #tforms.update(ToFile_setup()) 271 | tforms.update(ChannelsLast_setup()) 272 | tforms.update(ChannelsFirst_setup()) 273 | tforms.update(TypeCast_setup()) 274 | tforms.update(AddChannel_setup()) 275 | tforms.update(Transpose_setup()) 276 | tforms.update(RangeNormalize_setup()) 277 | tforms.update(StdNormalize_setup()) 278 | tforms.update(RandomCrop_setup()) 279 | tforms.update(SpecialCrop_setup()) 280 | tforms.update(Pad_setup()) 281 | tforms.update(RandomFlip_setup()) 282 | tforms.update(RandomOrder_setup()) 283 | 284 | 285 | ### MAKE DATA ### 286 | images = {} 287 | images.update(gray2d_setup()) 288 | images.update(multi_gray2d_setup()) 289 | images.update(color2d_setup()) 290 | images.update(multi_color2d_setup()) 291 | 292 | successes = [] 293 | failures = [] 294 | for im_key, im_val in images.items(): 295 | for tf_key, tf_val in tforms.items(): 296 | try: 297 | if isinstance(im_val, (tuple,list)): 298 | tf_val(*im_val) 299 | else: 300 | tf_val(im_val) 301 | successes.append((im_key, tf_key)) 302 | except Exception: 303 | failures.append((im_key, tf_key)) 304 | 305 | if verbose > 0: 306 | for k, v in failures: 307 | print('%s - %s' % (k, v)) 308 | 309 | print('# SUCCESSES: ', len(successes)) 310 | print('# FAILURES: ' , len(failures)) 311 | 312 | 313 | if __name__=='__main__': 314 | test_tensor_transforms_runtime() 315 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch as th 4 | 5 | def get_test_data(num_train=1000, num_test=500, 6 | input_shape=(10,), output_shape=(2,), 7 | classification=True, num_classes=2): 8 | """Generates test data to train a model on. 9 | 10 | classification=True overrides output_shape 11 | (i.e. output_shape is set to (1,)) and the output 12 | consists of integers in [0, num_classes-1]. 13 | 14 | Otherwise: float output with shape output_shape. 15 | """ 16 | samples = num_train + num_test 17 | if classification: 18 | y = np.random.randint(0, num_classes, size=(samples,)) 19 | X = np.zeros((samples,) + input_shape) 20 | for i in range(samples): 21 | X[i] = np.random.normal(loc=y[i], scale=0.7, size=input_shape) 22 | else: 23 | y_loc = np.random.random((samples,)) 24 | X = np.zeros((samples,) + input_shape) 25 | y = np.zeros((samples,) + output_shape) 26 | for i in range(samples): 27 | X[i] = np.random.normal(loc=y_loc[i], scale=0.7, size=input_shape) 28 | y[i] = np.random.normal(loc=y_loc[i], scale=0.7, size=output_shape) 29 | 30 | return (th.from_numpy(X[:num_train]), th.from_numpy(y[:num_train])), \ 31 | (th.from_numpy(X[num_train:]), th.from_numpy(y[num_train:])) -------------------------------------------------------------------------------- /torchsample/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from .version import __version__ 5 | 6 | from .datasets import * 7 | from .samplers import * 8 | 9 | #from .callbacks import * 10 | #from .constraints import * 11 | #from .regularizers import * 12 | 13 | #from . import functions 14 | #from . import transforms 15 | from . import modules
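The `get_test_data` helper above builds a synthetic, linearly separable classification set; a minimal sketch of wiring it into the `TensorDataset` class defined later in `torchsample/datasets.py` (the `tests.utils` import path and the shapes are illustrative assumptions):

```python
from tests.utils import get_test_data          # assumes the tests package is importable
from torchsample.datasets import TensorDataset

# two-class toy problem: 1000 train / 500 test samples of 10-dim features
(x_train, y_train), (x_test, y_test) = get_test_data(num_train=1000, num_test=500,
                                                     input_shape=(10,), num_classes=2)

train_data = TensorDataset(x_train.float(), y_train.long())
x0, y0 = train_data[0]   # a single (input, target) pair
print(x0.size(), y0)
```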
16 | -------------------------------------------------------------------------------- /torchsample/callbacks.py: -------------------------------------------------------------------------------- 1 | """ 2 | SuperModule Callbacks 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import print_function 7 | 8 | from collections import OrderedDict 9 | from collections.abc import Iterable 10 | import warnings 11 | 12 | import os 13 | import csv 14 | import time 15 | from tempfile import NamedTemporaryFile 16 | import shutil 17 | import datetime 18 | import numpy as np 19 | 20 | from tqdm import tqdm 21 | 22 | import torch as th 23 | 24 | 25 | def _get_current_time(): 26 | return datetime.datetime.now().strftime("%B %d, %Y - %I:%M%p") 27 | 28 | class CallbackContainer(object): 29 | """ 30 | Container holding a list of callbacks. 31 | """ 32 | def __init__(self, callbacks=None, queue_length=10): 33 | callbacks = callbacks or [] 34 | self.callbacks = [c for c in callbacks] 35 | self.queue_length = queue_length 36 | 37 | def append(self, callback): 38 | self.callbacks.append(callback) 39 | 40 | def set_params(self, params): 41 | for callback in self.callbacks: 42 | callback.set_params(params) 43 | 44 | def set_trainer(self, trainer): 45 | self.trainer = trainer 46 | for callback in self.callbacks: 47 | callback.set_trainer(trainer) 48 | 49 | def on_epoch_begin(self, epoch, logs=None): 50 | logs = logs or {} 51 | for callback in self.callbacks: 52 | callback.on_epoch_begin(epoch, logs) 53 | 54 | def on_epoch_end(self, epoch, logs=None): 55 | logs = logs or {} 56 | for callback in self.callbacks: 57 | callback.on_epoch_end(epoch, logs) 58 | 59 | def on_batch_begin(self, batch, logs=None): 60 | logs = logs or {} 61 | for callback in self.callbacks: 62 | callback.on_batch_begin(batch, logs) 63 | 64 | def on_batch_end(self, batch, logs=None): 65 | logs = logs or {} 66 | for callback in self.callbacks: 67 | callback.on_batch_end(batch, logs) 68 | 69 | def on_train_begin(self, logs=None): 70 | logs = logs or {} 71 | logs['start_time'] = _get_current_time() 72 | for callback in self.callbacks: 73 | callback.on_train_begin(logs) 74 | 75 | def on_train_end(self, logs=None): 76 | logs = logs or {} 77 | logs['final_loss'] = self.trainer.history.epoch_losses[-1] 78 | logs['best_loss'] = min(self.trainer.history.epoch_losses) 79 | logs['stop_time'] = _get_current_time() 80 | for callback in self.callbacks: 81 | callback.on_train_end(logs) 82 | 83 |
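# A CallbackContainer fans each training event out to its callbacks in order.
# A minimal sketch of how a trainer might drive it (names illustrative only):
#
#   container = CallbackContainer([EarlyStopping(), CSVLogger('log.csv')])
#   container.set_trainer(trainer)
#   container.on_train_begin({'batch_size': 32})   # stamps 'start_time' into logs
#   ...
#   container.on_train_end()                       # stamps final/best loss + 'stop_time'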
87 | """ 88 | 89 | def __init__(self): 90 | pass 91 | 92 | def set_params(self, params): 93 | self.params = params 94 | 95 | def set_trainer(self, model): 96 | self.trainer = model 97 | 98 | def on_epoch_begin(self, epoch, logs=None): 99 | pass 100 | 101 | def on_epoch_end(self, epoch, logs=None): 102 | pass 103 | 104 | def on_batch_begin(self, batch, logs=None): 105 | pass 106 | 107 | def on_batch_end(self, batch, logs=None): 108 | pass 109 | 110 | def on_train_begin(self, logs=None): 111 | pass 112 | 113 | def on_train_end(self, logs=None): 114 | pass 115 | 116 | 117 | class TQDM(Callback): 118 | 119 | def __init__(self): 120 | """ 121 | TQDM Progress Bar callback 122 | 123 | This callback is automatically applied to 124 | every SuperModule if verbose > 0 125 | """ 126 | self.progbar = None 127 | super(TQDM, self).__init__() 128 | 129 | def __enter__(self): 130 | return self 131 | 132 | def __exit__(self, exc_type, exc_val, exc_tb): 133 | # make sure the dbconnection gets closed 134 | if self.progbar is not None: 135 | self.progbar.close() 136 | 137 | def on_train_begin(self, logs): 138 | self.train_logs = logs 139 | 140 | def on_epoch_begin(self, epoch, logs=None): 141 | try: 142 | self.progbar = tqdm(total=self.train_logs['num_batches'], 143 | unit=' batches') 144 | self.progbar.set_description('Epoch %i/%i' % 145 | (epoch+1, self.train_logs['num_epoch'])) 146 | except: 147 | pass 148 | 149 | def on_epoch_end(self, epoch, logs=None): 150 | log_data = {key: '%.04f' % value for key, value in self.trainer.history.batch_metrics.items()} 151 | for k, v in logs.items(): 152 | if k.endswith('metric'): 153 | log_data[k.split('_metric')[0]] = '%.02f' % v 154 | else: 155 | log_data[k] = v 156 | self.progbar.set_postfix(log_data) 157 | self.progbar.update() 158 | self.progbar.close() 159 | 160 | def on_batch_begin(self, batch, logs=None): 161 | self.progbar.update(1) 162 | 163 | def on_batch_end(self, batch, logs=None): 164 | log_data = {key: '%.04f' % value for key, value in self.trainer.history.batch_metrics.items()} 165 | for k, v in logs.items(): 166 | if k.endswith('metric'): 167 | log_data[k.split('_metric')[0]] = '%.02f' % v 168 | self.progbar.set_postfix(log_data) 169 | 170 | 171 | class History(Callback): 172 | """ 173 | Callback that records events into a `History` object. 174 | 175 | This callback is automatically applied to 176 | every SuperModule. 177 | """ 178 | def __init__(self, model): 179 | super(History, self).__init__() 180 | self.samples_seen = 0. 181 | self.trainer = model 182 | 183 | def on_train_begin(self, logs=None): 184 | self.epoch_metrics = { 185 | 'loss': [] 186 | } 187 | self.batch_size = logs['batch_size'] 188 | self.has_val_data = logs['has_val_data'] 189 | self.has_regularizers = logs['has_regularizers'] 190 | if self.has_val_data: 191 | self.epoch_metrics['val_loss'] = [] 192 | if self.has_regularizers: 193 | self.epoch_metrics['reg_loss'] = [] 194 | 195 | def on_epoch_begin(self, epoch, logs=None): 196 | self.batch_metrics = { 197 | 'loss': 0. 198 | } 199 | if self.has_regularizers: 200 | self.batch_metrics['reg_loss'] = 0. 201 | self.samples_seen = 0. 
202 | 203 | def on_epoch_end(self, epoch, logs=None): 204 | #for k in self.batch_metrics: 205 | # k_log = k.split('_metric')[0] 206 | # self.epoch_metrics.update(self.batch_metrics) 207 | # TODO 208 | pass 209 | 210 | def on_batch_end(self, batch, logs=None): 211 | for k in self.batch_metrics: 212 | self.batch_metrics[k] = (self.samples_seen*self.batch_metrics[k] + logs[k]*self.batch_size) / (self.samples_seen+self.batch_size) 213 | self.samples_seen += self.batch_size 214 | 215 | def __getitem__(self, name): 216 | return self.epoch_metrics[name] 217 | 218 | def __repr__(self): 219 | return str(self.epoch_metrics) 220 | 221 | def __str__(self): 222 | return str(self.epoch_metrics) 223 | 224 | 225 | class ModelCheckpoint(Callback): 226 | """ 227 | Model Checkpoint to save model weights during training 228 | 229 | save_checkpoint({ 230 | 'epoch': epoch + 1, 231 | 'arch': args.arch, 232 | 'state_dict': model.state_dict(), 233 | 'best_prec1': best_prec1, 234 | 'optimizer' : optimizer.state_dict(), 235 | }) 236 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 237 | th.save(state, filename) 238 | if is_best: 239 | shutil.copyfile(filename, 'model_best.pth.tar') 240 | 241 | """ 242 | 243 | def __init__(self, 244 | directory, 245 | filename='ckpt.pth.tar', 246 | monitor='val_loss', 247 | save_best_only=False, 248 | save_weights_only=True, 249 | max_save=-1, 250 | verbose=0): 251 | """ 252 | Model Checkpoint to save model weights during training 253 | 254 | Arguments 255 | --------- 256 | directory : string 257 | directory to which model checkpoints will be saved 258 | filename : string 259 | checkpoint filename; the '{epoch}' and '{loss}' fields are filled in before saving 260 | monitor : string in {'val_loss', 'loss'} 261 | whether to monitor train or val loss 262 | save_best_only : boolean 263 | whether to only save if monitored value has improved 264 | save_weights_only : boolean 265 | whether to save entire model or just weights 266 | NOTE: only `True` is supported at the moment 267 | max_save : integer > 0 or -1 268 | the max number of models to save. Older model checkpoints 269 | will be overwritten if necessary.
Set equal to -1 to have 270 | no limit 271 | verbose : integer in {0, 1} 272 | verbosity 273 | """ 274 | if directory.startswith('~'): 275 | directory = os.path.expanduser(directory) 276 | self.directory = directory 277 | self.filename = filename 278 | self.file = os.path.join(self.directory, self.filename) 279 | self.monitor = monitor 280 | self.save_best_only = save_best_only 281 | self.save_weights_only = save_weights_only 282 | self.max_save = max_save 283 | self.verbose = verbose 284 | 285 | if self.max_save > 0: 286 | self.old_files = [] 287 | 288 | # mode = 'min' only supported 289 | self.best_loss = float('inf') 290 | super(ModelCheckpoint, self).__init__() 291 | 292 | def save_checkpoint(self, epoch, file, is_best=False): 293 | th.save({ 294 | 'epoch': epoch + 1, 295 | #'arch': args.arch, 296 | 'state_dict': self.trainer.model.state_dict(), 297 | #'best_prec1': best_prec1, 298 | 'optimizer' : self.trainer._optimizer.state_dict(), 299 | #'loss':{}, 300 | # #'regularizers':{}, 301 | # #'constraints':{}, 302 | # #'initializers':{}, 303 | # #'metrics':{}, 304 | # #'val_loss':{} 305 | }, file) 306 | if is_best: 307 | shutil.copyfile(file, 'model_best.pth.tar') 308 | 309 | def on_epoch_end(self, epoch, logs=None): 310 | 311 | file = self.file.format(epoch='%03i'%(epoch+1), 312 | loss='%0.4f'%logs[self.monitor]) 313 | if self.save_best_only: 314 | current_loss = logs.get(self.monitor) 315 | if current_loss is None: 316 | pass 317 | else: 318 | if current_loss < self.best_loss: 319 | if self.verbose > 0: 320 | print('\nEpoch %i: improved from %0.4f to %0.4f saving model to %s' % 321 | (epoch+1, self.best_loss, current_loss, file)) 322 | self.best_loss = current_loss 323 | #if self.save_weights_only: 324 | #else: 325 | self.save_checkpoint(epoch, file) 326 | if self.max_save > 0: 327 | if len(self.old_files) == self.max_save: 328 | try: 329 | os.remove(self.old_files[0]) 330 | except: 331 | pass 332 | self.old_files = self.old_files[1:] 333 | self.old_files.append(file) 334 | else: 335 | if self.verbose > 0: 336 | print('\nEpoch %i: saving model to %s' % (epoch+1, file)) 337 | self.save_checkpoint(epoch, file) 338 | if self.max_save > 0: 339 | if len(self.old_files) == self.max_save: 340 | try: 341 | os.remove(self.old_files[0]) 342 | except: 343 | pass 344 | self.old_files = self.old_files[1:] 345 | self.old_files.append(file) 346 | 347 | 348 | class EarlyStopping(Callback): 349 | """ 350 | Early Stopping to terminate training early under certain conditions 351 | """ 352 | 353 | def __init__(self, 354 | monitor='val_loss', 355 | min_delta=0, 356 | patience=5): 357 | """ 358 | EarlyStopping callback to exit the training loop if training or 359 | validation loss does not improve by a certain amount for a certain 360 | number of epochs 361 | 362 | Arguments 363 | --------- 364 | monitor : string in {'val_loss', 'loss'} 365 | whether to monitor train or val loss 366 | min_delta : float 367 | minimum change in monitored value to qualify as improvement. 368 | This number should be positive. 369 | patience : integer 370 | number of epochs to wait for improvment before terminating. 
371 | The counter is reset after each improvement. 372 | """ 373 | self.monitor = monitor 374 | self.min_delta = min_delta 375 | self.patience = patience 376 | self.wait = 0 377 | self.best_loss = 1e15 378 | self.stopped_epoch = 0 379 | super(EarlyStopping, self).__init__() 380 | 381 | def on_train_begin(self, logs=None): 382 | self.wait = 0 383 | self.best_loss = 1e15 384 | 385 | def on_epoch_end(self, epoch, logs=None): 386 | current_loss = logs.get(self.monitor) 387 | if current_loss is None: 388 | pass 389 | else: 390 | if (current_loss - self.best_loss) < -self.min_delta: 391 | self.best_loss = current_loss 392 | self.wait = 1 393 | else: 394 | if self.wait >= self.patience: 395 | self.stopped_epoch = epoch + 1 396 | self.trainer._stop_training = True 397 | self.wait += 1 398 | 399 | def on_train_end(self, logs): 400 | if self.stopped_epoch > 0: 401 | print('\nTerminated Training for Early Stopping at Epoch %04i' % 402 | (self.stopped_epoch)) 403 | 404 | 405 | class LRScheduler(Callback): 406 | """ 407 | Schedule the learning rate according to some function of the 408 | current epoch index, current learning rate, and current train/val loss. 409 | """ 410 | 411 | def __init__(self, schedule): 412 | """ 413 | LearningRateScheduler callback to adapt the learning rate 414 | according to some function 415 | 416 | Arguments 417 | --------- 418 | schedule : callable 419 | should return a number of learning rates equal to the number 420 | of optimizer.param_groups. It should take the epoch index and 421 | **kwargs (or logs) as argument. **kwargs (or logs) will contain 422 | the epoch logs such as mean training and validation loss from 423 | the epoch 424 | """ 425 | if isinstance(schedule, dict): 426 | self.schedule_dict = schedule 427 | if any([k < 1.0 for k in schedule.keys()]): 428 | self.fractional_bounds = True 429 | else: 430 | self.fractional_bounds = False 431 | schedule = self.schedule_from_dict 432 | self.schedule = schedule 433 | super(LRScheduler, self).__init__() 434 | 435 | def schedule_from_dict(self, epoch, current_lrs=None, **logs): 436 | for epoch_bound, learn_rate in self.schedule_dict.items(): 437 | # epoch_bound is in units of "epochs" 438 | if not self.fractional_bounds: 439 | if epoch_bound < epoch: 440 | return learn_rate 441 | # epoch_bound is in units of "cumulative percent of epochs" 442 | else: 443 | if epoch <= epoch_bound*logs['num_epoch']: 444 | return learn_rate 445 | warnings.warn('Check the keys in the schedule dict..
Returning last value') 446 | return learn_rate 447 | 448 | def on_epoch_begin(self, epoch, logs=None): 449 | current_lrs = [p['lr'] for p in self.trainer._optimizer.param_groups] 450 | lr_list = self.schedule(epoch, current_lrs, **logs) 451 | if not isinstance(lr_list, list): 452 | lr_list = [lr_list] 453 | 454 | for param_group, lr_change in zip(self.trainer._optimizer.param_groups, lr_list): 455 | param_group['lr'] = lr_change 456 | 457 | 458 | class ReduceLROnPlateau(Callback): 459 | """ 460 | Reduce the learning rate if the train or validation loss plateaus 461 | """ 462 | 463 | def __init__(self, 464 | monitor='val_loss', 465 | factor=0.1, 466 | patience=10, 467 | epsilon=0, 468 | cooldown=0, 469 | min_lr=0, 470 | verbose=0): 471 | """ 472 | Reduce the learning rate if the train or validation loss plateaus 473 | 474 | Arguments 475 | --------- 476 | monitor : string in {'loss', 'val_loss'} 477 | which metric to monitor 478 | factor : floar 479 | factor to decrease learning rate by 480 | patience : integer 481 | number of epochs to wait for loss improvement before reducing lr 482 | epsilon : float 483 | how much improvement must be made to reset patience 484 | cooldown : integer 485 | number of epochs to cooldown after a lr reduction 486 | min_lr : float 487 | minimum value to ever let the learning rate decrease to 488 | verbose : integer 489 | whether to print reduction to console 490 | """ 491 | self.monitor = monitor 492 | if factor >= 1.0: 493 | raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.') 494 | self.factor = factor 495 | self.min_lr = min_lr 496 | self.epsilon = epsilon 497 | self.patience = patience 498 | self.verbose = verbose 499 | self.cooldown = cooldown 500 | self.cooldown_counter = 0 501 | self.wait = 0 502 | self.best_loss = 1e15 503 | self._reset() 504 | super(ReduceLROnPlateau, self).__init__() 505 | 506 | def _reset(self): 507 | """ 508 | Reset the wait and cooldown counters 509 | """ 510 | self.monitor_op = lambda a, b: (a - b) < -self.epsilon 511 | self.best_loss = 1e15 512 | self.cooldown_counter = 0 513 | self.wait = 0 514 | 515 | def on_train_begin(self, logs=None): 516 | self._reset() 517 | 518 | def on_epoch_end(self, epoch, logs=None): 519 | logs = logs or {} 520 | logs['lr'] = [p['lr'] for p in self.trainer._optimizer.param_groups] 521 | current_loss = logs.get(self.monitor) 522 | if current_loss is None: 523 | pass 524 | else: 525 | # if in cooldown phase 526 | if self.cooldown_counter > 0: 527 | self.cooldown_counter -= 1 528 | self.wait = 0 529 | # if loss improved, grab new loss and reset wait counter 530 | if self.monitor_op(current_loss, self.best_loss): 531 | self.best_loss = current_loss 532 | self.wait = 0 533 | # loss didnt improve, and not in cooldown phase 534 | elif not (self.cooldown_counter > 0): 535 | if self.wait >= self.patience: 536 | for p in self.trainer._optimizer.param_groups: 537 | old_lr = p['lr'] 538 | if old_lr > self.min_lr + 1e-4: 539 | new_lr = old_lr * self.factor 540 | new_lr = max(new_lr, self.min_lr) 541 | if self.verbose > 0: 542 | print('\nEpoch %05d: reducing lr from %0.3f to %0.3f' % 543 | (epoch, old_lr, new_lr)) 544 | p['lr'] = new_lr 545 | self.cooldown_counter = self.cooldown 546 | self.wait = 0 547 | self.wait += 1 548 | 549 | 550 | class CSVLogger(Callback): 551 | """ 552 | Logs epoch-level metrics to a CSV file 553 | """ 554 | 555 | def __init__(self, 556 | file, 557 | separator=',', 558 | append=False): 559 | """ 560 | Logs epoch-level metrics to a CSV file 561 | 562 | Arguments 563 | 
--------- 564 | file : string 565 | path to csv file 566 | separator : string 567 | delimiter for file 568 | apped : boolean 569 | whether to append result to existing file or make new file 570 | """ 571 | self.file = file 572 | self.sep = separator 573 | self.append = append 574 | self.writer = None 575 | self.keys = None 576 | self.append_header = True 577 | super(CSVLogger, self).__init__() 578 | 579 | def on_train_begin(self, logs=None): 580 | if self.append: 581 | if os.path.exists(self.file): 582 | with open(self.file) as f: 583 | self.append_header = not bool(len(f.readline())) 584 | self.csv_file = open(self.file, 'a') 585 | else: 586 | self.csv_file = open(self.file, 'w') 587 | 588 | def on_epoch_end(self, epoch, logs=None): 589 | logs = logs or {} 590 | RK = {'num_batches', 'num_epoch'} 591 | 592 | def handle_value(k): 593 | is_zero_dim_tensor = isinstance(k, th.Tensor) and k.dim() == 0 594 | if isinstance(k, Iterable) and not is_zero_dim_tensor: 595 | return '"[%s]"' % (', '.join(map(str, k))) 596 | else: 597 | return k 598 | 599 | if not self.writer: 600 | self.keys = sorted(logs.keys()) 601 | 602 | class CustomDialect(csv.excel): 603 | delimiter = self.sep 604 | 605 | self.writer = csv.DictWriter(self.csv_file, 606 | fieldnames=['epoch'] + [k for k in self.keys if k not in RK], 607 | dialect=CustomDialect) 608 | if self.append_header: 609 | self.writer.writeheader() 610 | 611 | row_dict = OrderedDict({'epoch': epoch}) 612 | row_dict.update((key, handle_value(logs[key])) for key in self.keys if key not in RK) 613 | self.writer.writerow(row_dict) 614 | self.csv_file.flush() 615 | 616 | def on_train_end(self, logs=None): 617 | self.csv_file.close() 618 | self.writer = None 619 | 620 | 621 | class ExperimentLogger(Callback): 622 | 623 | def __init__(self, 624 | directory, 625 | filename='Experiment_Logger.csv', 626 | save_prefix='Model_', 627 | separator=',', 628 | append=True): 629 | 630 | self.directory = directory 631 | self.filename = filename 632 | self.file = os.path.join(self.directory, self.filename) 633 | self.save_prefix = save_prefix 634 | self.sep = separator 635 | self.append = append 636 | self.writer = None 637 | self.keys = None 638 | self.append_header = True 639 | super(ExperimentLogger, self).__init__() 640 | 641 | def on_train_begin(self, logs=None): 642 | if self.append: 643 | open_type = 'a' 644 | else: 645 | open_type = 'w' 646 | 647 | # if append is True, find whether the file already has header 648 | num_lines = 0 649 | if self.append: 650 | if os.path.exists(self.file): 651 | with open(self.file) as f: 652 | for num_lines, l in enumerate(f): 653 | pass 654 | # if header exists, DONT append header again 655 | with open(self.file) as f: 656 | self.append_header = not bool(len(f.readline())) 657 | 658 | model_idx = num_lines 659 | REJECT_KEYS={'has_validation_data'} 660 | MODEL_NAME = self.save_prefix + str(model_idx) # figure out how to get model name 661 | self.row_dict = OrderedDict({'model': MODEL_NAME}) 662 | self.keys = sorted(logs.keys()) 663 | for k in self.keys: 664 | if k not in REJECT_KEYS: 665 | self.row_dict[k] = logs[k] 666 | 667 | class CustomDialect(csv.excel): 668 | delimiter = self.sep 669 | 670 | with open(self.file, open_type) as csv_file: 671 | writer = csv.DictWriter(csv_file, 672 | fieldnames=['model'] + [k for k in self.keys if k not in REJECT_KEYS], 673 | dialect=CustomDialect) 674 | if self.append_header: 675 | writer.writeheader() 676 | 677 | writer.writerow(self.row_dict) 678 | csv_file.flush() 679 | 680 | def on_train_end(self, 
logs=None): 681 | REJECT_KEYS={'has_validation_data'} 682 | row_dict = self.row_dict 683 | 684 | class CustomDialect(csv.excel): 685 | delimiter = self.sep 686 | self.keys = self.keys 687 | temp_file = NamedTemporaryFile(delete=False, mode='w') 688 | with open(self.file, 'r') as csv_file, temp_file: 689 | reader = csv.DictReader(csv_file, 690 | fieldnames=['model'] + [k for k in self.keys if k not in REJECT_KEYS], 691 | dialect=CustomDialect) 692 | writer = csv.DictWriter(temp_file, 693 | fieldnames=['model'] + [k for k in self.keys if k not in REJECT_KEYS], 694 | dialect=CustomDialect) 695 | for row_idx, row in enumerate(reader): 696 | if row_idx == 0: 697 | # re-write header with on_train_end's metrics 698 | pass 699 | if row['model'] == self.row_dict['model']: 700 | writer.writerow(row_dict) 701 | else: 702 | writer.writerow(row) 703 | shutil.move(temp_file.name, self.file) 704 | 705 | 706 | class LambdaCallback(Callback): 707 | """ 708 | Callback for creating simple, custom callbacks on-the-fly. 709 | """ 710 | def __init__(self, 711 | on_epoch_begin=None, 712 | on_epoch_end=None, 713 | on_batch_begin=None, 714 | on_batch_end=None, 715 | on_train_begin=None, 716 | on_train_end=None, 717 | **kwargs): 718 | super(LambdaCallback, self).__init__() 719 | self.__dict__.update(kwargs) 720 | if on_epoch_begin is not None: 721 | self.on_epoch_begin = on_epoch_begin 722 | else: 723 | self.on_epoch_begin = lambda epoch, logs: None 724 | if on_epoch_end is not None: 725 | self.on_epoch_end = on_epoch_end 726 | else: 727 | self.on_epoch_end = lambda epoch, logs: None 728 | if on_batch_begin is not None: 729 | self.on_batch_begin = on_batch_begin 730 | else: 731 | self.on_batch_begin = lambda batch, logs: None 732 | if on_batch_end is not None: 733 | self.on_batch_end = on_batch_end 734 | else: 735 | self.on_batch_end = lambda batch, logs: None 736 | if on_train_begin is not None: 737 | self.on_train_begin = on_train_begin 738 | else: 739 | self.on_train_begin = lambda logs: None 740 | if on_train_end is not None: 741 | self.on_train_end = on_train_end 742 | else: 743 | self.on_train_end = lambda logs: None 744 | 745 | -------------------------------------------------------------------------------- /torchsample/constraints.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | from fnmatch import fnmatch 6 | 7 | import torch as th 8 | from .callbacks import Callback 9 | 10 | 11 | class ConstraintContainer(object): 12 | 13 | def __init__(self, constraints): 14 | self.constraints = constraints 15 | self.batch_constraints = [c for c in self.constraints if c.unit.upper() == 'BATCH'] 16 | self.epoch_constraints = [c for c in self.constraints if c.unit.upper() == 'EPOCH'] 17 | 18 | def register_constraints(self, model): 19 | """ 20 | Grab pointers to the weights which will be modified by constraints so 21 | that we dont have to search through the entire network using `apply` 22 | each time 23 | """ 24 | # get batch constraint pointers 25 | self._batch_c_ptrs = {} 26 | for c_idx, constraint in enumerate(self.batch_constraints): 27 | self._batch_c_ptrs[c_idx] = [] 28 | for name, module in model.named_modules(): 29 | if fnmatch(name, constraint.module_filter) and hasattr(module, 'weight'): 30 | self._batch_c_ptrs[c_idx].append(module) 31 | 32 | # get epoch constraint pointers 33 | self._epoch_c_ptrs = {} 34 | for c_idx, constraint in enumerate(self.epoch_constraints): 35 | 
self._epoch_c_ptrs[c_idx] = [] 36 | for name, module in model.named_modules(): 37 | if fnmatch(name, constraint.module_filter) and hasattr(module, 'weight'): 38 | self._epoch_c_ptrs[c_idx].append(module) 39 | 40 | def apply_batch_constraints(self, batch_idx): 41 | for c_idx, modules in self._batch_c_ptrs.items(): 42 | if (batch_idx+1) % self.constraints[c_idx].frequency == 0: 43 | for module in modules: 44 | self.constraints[c_idx](module) 45 | 46 | def apply_epoch_constraints(self, epoch_idx): 47 | for c_idx, modules in self._epoch_c_ptrs.items(): 48 | if (epoch_idx+1) % self.constraints[c_idx].frequency == 0: 49 | for module in modules: 50 | self.constraints[c_idx](module) 51 | 52 | 53 | class ConstraintCallback(Callback): 54 | 55 | def __init__(self, container): 56 | self.container = container 57 | 58 | def on_batch_end(self, batch_idx, logs): 59 | self.container.apply_batch_constraints(batch_idx) 60 | 61 | def on_epoch_end(self, epoch_idx, logs): 62 | self.container.apply_epoch_constraints(epoch_idx) 63 | 64 | 65 | class Constraint(object): 66 | 67 | def __call__(self): 68 | raise NotImplementedError('Subclass must implement this method') 69 | 70 | 71 | class UnitNorm(Constraint): 72 | """ 73 | UnitNorm constraint. 74 | 75 | Constrains the weights to have column-wise unit norm 76 | """ 77 | def __init__(self, 78 | frequency=1, 79 | unit='batch', 80 | module_filter='*'): 81 | 82 | self.frequency = frequency 83 | self.unit = unit 84 | self.module_filter = module_filter 85 | 86 | def __call__(self, module): 87 | w = module.weight.data 88 | module.weight.data = w.div(th.norm(w,2,0)) 89 | 90 | 91 | class MaxNorm(Constraint): 92 | """ 93 | MaxNorm weight constraint. 94 | 95 | Constrains the weights incident to each hidden unit 96 | to have a norm less than or equal to a desired value. 97 | 98 | Any hidden unit vector with a norm less than the max norm 99 | constraint will not be altered. 100 | """ 101 | 102 | def __init__(self, 103 | value, 104 | axis=0, 105 | frequency=1, 106 | unit='batch', 107 | module_filter='*'): 108 | self.value = float(value) 109 | self.axis = axis 110 | 111 | self.frequency = frequency 112 | self.unit = unit 113 | self.module_filter = module_filter 114 | 115 | def __call__(self, module): 116 | w = module.weight.data 117 | module.weight.data = th.renorm(w, 2, self.axis, self.value) 118 | 119 | 120 | class NonNeg(Constraint): 121 | """ 122 | Constrains the weights to be non-negative. 123 | """ 124 | def __init__(self, 125 | frequency=1, 126 | unit='batch', 127 | module_filter='*'): 128 | self.frequency = frequency 129 | self.unit = unit 130 | self.module_filter = module_filter 131 | 132 | def __call__(self, module): 133 | w = module.weight.data 134 | module.weight.data = w.gt(0).float().mul(w) 135 | 136 | 137 | 138 | 139 | 140 | 141 |
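All three hard constraints rewrite `module.weight.data` in place; `MaxNorm` in particular delegates to `th.renorm`. A self-contained sketch of what that call does, on a toy tensor with nothing from torchsample assumed:

```python
import torch as th

w = th.randn(5, 3) * 10.0            # weights with large norms
w_clipped = th.renorm(w, 2, 0, 1.0)  # clip the L2 norm of each slice along axis 0 to <= 1.0

print(th.norm(w, 2, 1))              # original row norms, typically well above 1
print(th.norm(w_clipped, 2, 1))      # every row norm is now at most 1.0
```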
-------------------------------------------------------------------------------- /torchsample/datasets.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import os 7 | import fnmatch 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import PIL.Image as Image 12 | import nibabel 13 | 14 | import torch as th 15 | 16 | from . import transforms 17 | 18 | 19 | class BaseDataset(object): 20 | """An abstract class representing a Dataset. 21 | 22 | All other datasets should subclass it. All subclasses should override 23 | ``__len__``, which provides the size of the dataset, and ``__getitem__``, 24 | supporting integer indexing in range from 0 to len(self) exclusive. 25 | """ 26 | 27 | def __len__(self): 28 | return len(self.inputs) if not isinstance(self.inputs, (tuple,list)) else len(self.inputs[0]) 29 | 30 | def add_input_transform(self, transform, add_to_front=True, idx=None): 31 | if idx is None: 32 | idx = np.arange(self.num_inputs) 33 | elif not is_tuple_or_list(idx): 34 | idx = [idx] 35 | 36 | if add_to_front: 37 | for i in idx: 38 | self.input_transform[i] = transforms.Compose([transform, self.input_transform[i]]) 39 | else: 40 | for i in idx: 41 | self.input_transform[i] = transforms.Compose([self.input_transform[i], transform]) 42 | 43 | def add_target_transform(self, transform, add_to_front=True, idx=None): 44 | if idx is None: 45 | idx = np.arange(self.num_targets) 46 | elif not is_tuple_or_list(idx): 47 | idx = [idx] 48 | 49 | if add_to_front: 50 | for i in idx: 51 | self.target_transform[i] = transforms.Compose([transform, self.target_transform[i]]) 52 | else: 53 | for i in idx: 54 | self.target_transform[i] = transforms.Compose([self.target_transform[i], transform]) 55 | 56 | def add_co_transform(self, transform, add_to_front=True, idx=None): 57 | if idx is None: 58 | idx = np.arange(self.min_inputs_or_targets) 59 | elif not is_tuple_or_list(idx): 60 | idx = [idx] 61 | 62 | if add_to_front: 63 | for i in idx: 64 | self.co_transform[i] = transforms.Compose([transform, self.co_transform[i]]) 65 | else: 66 | for i in idx: 67 | self.co_transform[i] = transforms.Compose([self.co_transform[i], transform]) 68 | 69 | def load(self, num_samples=None, load_range=None): 70 | """ 71 | Load all data or a subset of the data into actual memory. 72 | For instance, if the inputs are paths to image files, then this 73 | function will actually load those images. 74 | 75 | Arguments 76 | --------- 77 | num_samples : integer (optional) 78 | number of samples to load. if None, will load all 79 | load_range : numpy array of integers (optional) 80 | the index range of images to load 81 | e.g.
np.arange(4) loads the first 4 inputs+targets 82 | """ 83 | def _parse_shape(x): 84 | if isinstance(x, (list,tuple)): 85 | return (len(x),) 86 | elif isinstance(x, th.Tensor): 87 | return x.size() 88 | else: 89 | return (1,) 90 | 91 | if num_samples is None and load_range is None: 92 | num_samples = len(self) 93 | load_range = np.arange(num_samples) 94 | elif num_samples is None and load_range is not None: 95 | num_samples = len(load_range) 96 | elif num_samples is not None and load_range is None: 97 | load_range = np.arange(num_samples) 98 | 99 | 100 | if self.has_target: 101 | for enum_idx, sample_idx in enumerate(load_range): 102 | input_sample, target_sample = self.__getitem__(sample_idx) 103 | 104 | if enum_idx == 0: 105 | if self.num_inputs == 1: 106 | _shape = [len(load_range)] + list(_parse_shape(input_sample)) 107 | inputs = np.empty(_shape) 108 | else: 109 | inputs = [] 110 | for i in range(self.num_inputs): 111 | _shape = [len(load_range)] + list(_parse_shape(input_sample[i])) 112 | inputs.append(np.empty(_shape)) 113 | #inputs = [np.empty((len(load_range), *_parse_shape(input_sample[i]))) for i in range(self.num_inputs)] 114 | 115 | if self.num_targets == 1: 116 | _shape = [len(load_range)] + list(_parse_shape(target_sample)) 117 | targets = np.empty(_shape) 118 | #targets = np.empty((len(load_range), *_parse_shape(target_sample))) 119 | else: 120 | targets = [] 121 | for i in range(self.num_targets): 122 | _shape = [len(load_range)] + list(_parse_shape(target_sample[i])) 123 | targets.append(np.empty(_shape)) 124 | #targets = [np.empty((len(load_range), *_parse_shape(target_sample[i]))) for i in range(self.num_targets)] 125 | 126 | if self.num_inputs == 1: 127 | inputs[enum_idx] = input_sample 128 | else: 129 | for i in range(self.num_inputs): 130 | inputs[i][enum_idx] = input_sample[i] 131 | 132 | if self.num_targets == 1: 133 | targets[enum_idx] = target_sample 134 | else: 135 | for i in range(self.num_targets): 136 | targets[i][enum_idx] = target_sample[i] 137 | 138 | return inputs, targets 139 | else: 140 | for enum_idx, sample_idx in enumerate(load_range): 141 | input_sample = self.__getitem__(sample_idx) 142 | 143 | if enum_idx == 0: 144 | if self.num_inputs == 1: 145 | _shape = [len(load_range)] + list(_parse_shape(input_sample)) 146 | inputs = np.empty(_shape) 147 | #inputs = np.empty((len(load_range), *_parse_shape(input_sample))) 148 | else: 149 | inputs = [] 150 | for i in range(self.num_inputs): 151 | _shape = [len(load_range)] + list(_parse_shape(input_sample[i])) 152 | inputs.append(np.empty(_shape)) 153 | #inputs = [np.empty((len(load_range), *_parse_shape(input_sample[i]))) for i in range(self.num_inputs)] 154 | 155 | if self.num_inputs == 1: 156 | inputs[enum_idx] = input_sample 157 | else: 158 | for i in range(self.num_inputs): 159 | inputs[i][enum_idx] = input_sample[i] 160 | 161 | return inputs 162 | 163 | def fit_transforms(self): 164 | """ 165 | Make a single pass through the entire dataset in order to fit 166 | any parameters of the transforms which require the entire dataset. 167 | e.g. StandardScaler() requires mean and std for the entire dataset. 168 | 169 | If you dont call this fit function, then transforms which require properties 170 | of the entire dataset will just work at the batch level. 171 | e.g. 
StandardScaler() will normalize each batch by the specific batch mean/std 172 | """ 173 | it_fit = hasattr(self.input_transform, 'update_fit') 174 | tt_fit = hasattr(self.target_transform, 'update_fit') 175 | ct_fit = hasattr(self.co_transform, 'update_fit') 176 | if it_fit or tt_fit or ct_fit: 177 | for sample_idx in range(len(self)): 178 | if hasattr(self, 'input_loader'): 179 | x = self.input_loader(self.inputs[sample_idx]) 180 | else: 181 | x = self.inputs[sample_idx] 182 | if it_fit: 183 | self.input_transform.update_fit(x) 184 | if self.has_target: 185 | if hasattr(self, 'target_loader'): 186 | y = self.target_loader(self.targets[sample_idx]) 187 | else: 188 | y = self.targets[sample_idx] 189 | if tt_fit: 190 | self.target_transform.update_fit(y) 191 | if ct_fit: 192 | self.co_transform.update_fit(x,y) 193 | 194 | 195 | def _process_array_argument(x): 196 | if not is_tuple_or_list(x): 197 | x = [x] 198 | return x 199 | 200 | 201 | class TensorDataset(BaseDataset): 202 | 203 | def __init__(self, 204 | inputs, 205 | targets=None, 206 | input_transform=None, 207 | target_transform=None, 208 | co_transform=None): 209 | """ 210 | Dataset class for loading in-memory data. 211 | 212 | Arguments 213 | --------- 214 | inputs: numpy array 215 | 216 | targets : numpy array 217 | 218 | input_transform : class with __call__ function implemented 219 | transform to apply to input sample individually 220 | 221 | target_transform : class with __call__ function implemented 222 | transform to apply to target sample individually 223 | 224 | co_transform : class with __call__ function implemented 225 | transform to apply to both input and target sample simultaneously 226 | 227 | """ 228 | self.inputs = _process_array_argument(inputs) 229 | self.num_inputs = len(self.inputs) 230 | self.input_return_processor = _return_first_element_of_list if self.num_inputs==1 else _pass_through 231 | 232 | if targets is None: 233 | self.has_target = False 234 | else: 235 | self.targets = _process_array_argument(targets) 236 | self.num_targets = len(self.targets) 237 | self.target_return_processor = _return_first_element_of_list if self.num_targets==1 else _pass_through 238 | self.min_inputs_or_targets = min(self.num_inputs, self.num_targets) 239 | self.has_target = True 240 | 241 | self.input_transform = _process_transform_argument(input_transform, self.num_inputs) 242 | if self.has_target: 243 | self.target_transform = _process_transform_argument(target_transform, self.num_targets) 244 | self.co_transform = _process_co_transform_argument(co_transform, self.num_inputs, self.num_targets) 245 | 246 | def __getitem__(self, index): 247 | """ 248 | Index the dataset and return the input + target 249 | """ 250 | input_sample = [self.input_transform[i](self.inputs[i][index]) for i in range(self.num_inputs)] 251 | 252 | if self.has_target: 253 | target_sample = [self.target_transform[i](self.targets[i][index]) for i in range(self.num_targets)] 254 | #for i in range(self.min_inputs_or_targets): 255 | # input_sample[i], target_sample[i] = self.co_transform[i](input_sample[i], target_sample[i]) 256 | 257 | return self.input_return_processor(input_sample), self.target_return_processor(target_sample) 258 | else: 259 | return self.input_return_processor(input_sample) 260 | 261 | 262 | def default_file_reader(x): 263 | def pil_loader(path): 264 | return Image.open(path).convert('RGB') 265 | def npy_loader(path): 266 | return np.load(path) 267 | def nifti_loader(path): 268 | return nibabel.load(path).get_data() 269 | if isinstance(x, 
str): 270 | if x.endswith('.npy'): 271 | x = npy_loader(x) 272 | elif x.endswith('.nii.gz'): 273 | x = nifti_loader(x) 274 | else: 275 | try: 276 | x = pil_loader(x) 277 | except Exception: 278 | raise ValueError('File Format is not supported') 279 | #else: 280 | #raise ValueError('x should be string, but got %s' % type(x)) 281 | return x 282 | 283 | def is_tuple_or_list(x): 284 | return isinstance(x, (tuple,list)) 285 | 286 | def _process_transform_argument(tform, num_inputs): 287 | tform = tform if tform is not None else _pass_through 288 | if is_tuple_or_list(tform): 289 | if len(tform) != num_inputs: 290 | raise Exception('If transform is list, must provide one transform for each input') 291 | tform = [t if t is not None else _pass_through for t in tform] 292 | else: 293 | tform = [tform] * num_inputs 294 | return tform 295 | 296 | def _process_co_transform_argument(tform, num_inputs, num_targets): 297 | tform = tform if tform is not None else _multi_arg_pass_through 298 | if is_tuple_or_list(tform): 299 | if len(tform) != num_inputs: 300 | raise Exception('If transform is list, must provide one transform for each input') 301 | tform = [t if t is not None else _multi_arg_pass_through for t in tform] 302 | else: 303 | tform = [tform] * min(num_inputs, num_targets) 304 | return tform 305 | 306 | def _process_csv_argument(csv): 307 | if isinstance(csv, str): 308 | df = pd.read_csv(csv) 309 | elif isinstance(csv, pd.DataFrame): 310 | df = csv 311 | else: 312 | raise ValueError('csv argument must be string or dataframe') 313 | return df 314 | 315 | def _select_dataframe_columns(df, cols): 316 | if isinstance(cols[0], str): 317 | inputs = df.loc[:,cols].values 318 | elif isinstance(cols[0], int): 319 | inputs = df.iloc[:,cols].values 320 | else: 321 | raise ValueError('Provided columns should be string column names or integer column indices') 322 | return inputs 323 | 324 | def _process_cols_argument(cols): 325 | if isinstance(cols, tuple): 326 | cols = list(cols) 327 | return cols 328 | 329 | def _return_first_element_of_list(x): 330 | return x[0] 331 | 332 | def _pass_through(x): 333 | return x 334 | 335 | def _multi_arg_pass_through(*x): 336 | return x 337 | 338 | 339 | class CSVDataset(BaseDataset): 340 | 341 | def __init__(self, 342 | csv, 343 | input_cols=[0], 344 | target_cols=[1], 345 | input_transform=None, 346 | target_transform=None, 347 | co_transform=None): 348 | """ 349 | Initialize a Dataset from a CSV file/dataframe. This does NOT 350 | actually load the data into memory if the CSV contains filepaths. 351 | 352 | Arguments 353 | --------- 354 | csv : string or pandas.DataFrame 355 | if string, should be a path to a .csv file which 356 | can be loaded as a pandas dataframe 357 | 358 | input_cols : int/list of ints, or string/list of strings 359 | which columns to use as input arrays. 360 | If int(s), should be column indices 361 | If str(s), should be column names 362 | 363 | target_cols : int/list of ints, or string/list of strings 364 | which columns to use as target arrays.
365 | If int(s), should be column indices 366 | If str(s), should be column names 367 | 368 | input_transform : class which implements a __call__ method 369 | transform(s) to apply to inputs during runtime loading 370 | 371 | target_transform : class which implements a __call__ method 372 | transform(s) to apply to targets during runtime loading 373 | 374 | co_transform : class which implements a __call__ method 375 | transform(s) to apply to both inputs and targets simultaneously 376 | during runtime loading 377 | """ 378 | self.input_cols = _process_cols_argument(input_cols) 379 | self.target_cols = _process_cols_argument(target_cols) 380 | 381 | self.df = _process_csv_argument(csv) 382 | 383 | self.inputs = _select_dataframe_columns(self.df, input_cols) 384 | self.num_inputs = self.inputs.shape[1] 385 | self.input_return_processor = _return_first_element_of_list if self.num_inputs==1 else _pass_through 386 | 387 | if target_cols is None: 388 | self.num_targets = 0 389 | self.has_target = False 390 | else: 391 | self.targets = _select_dataframe_columns(self.df, target_cols) 392 | self.num_targets = self.targets.shape[1] 393 | self.target_return_processor = _return_first_element_of_list if self.num_targets==1 else _pass_through 394 | self.has_target = True 395 | self.min_inputs_or_targets = min(self.num_inputs, self.num_targets) 396 | 397 | self.input_loader = default_file_reader 398 | self.target_loader = default_file_reader 399 | 400 | self.input_transform = _process_transform_argument(input_transform, self.num_inputs) 401 | if self.has_target: 402 | self.target_transform = _process_transform_argument(target_transform, self.num_targets) 403 | self.co_transform = _process_co_transform_argument(co_transform, self.num_inputs, self.num_targets) 404 | 405 | def __getitem__(self, index): 406 | """ 407 | Index the dataset and return the input + target 408 | """ 409 | input_sample = [self.input_transform[i](self.input_loader(self.inputs[index, i])) for i in range(self.num_inputs)] 410 | 411 | if self.has_target: 412 | target_sample = [self.target_transform[i](self.target_loader(self.targets[index, i])) for i in range(self.num_targets)] 413 | for i in range(self.min_inputs_or_targets): 414 | input_sample[i], target_sample[i] = self.co_transform[i](input_sample[i], target_sample[i]) 415 | 416 | return self.input_return_processor(input_sample), self.target_return_processor(target_sample) 417 | else: 418 | return self.input_return_processor(input_sample) 419 | 420 | def split_by_column(self, col): 421 | """ 422 | Split this dataset object into multiple dataset objects based on 423 | the unique factors of the given column. The number of returned 424 | datasets will be equal to the number of unique values in the given 425 | column. The transforms and original dataframe will all be transferred 426 | to the new datasets 427 | 428 | Useful for splitting a dataset into train/val/test datasets. 429 | 430 | Arguments 431 | --------- 432 | col : integer or string 433 | which column to split the data on.
434 | if int, should be column index 435 | if str, should be column name 436 | 437 | Returns 438 | ------- 439 | - list of new datasets with transforms copied 440 | """ 441 | if isinstance(col, int): 442 | split_vals = self.df.iloc[:,col].values.flatten() 443 | 444 | new_df_list = [] 445 | for unique_split_val in np.unique(split_vals): 446 | new_df = self.df[:][self.df.iloc[:,col]==unique_split_val] 447 | new_df_list.append(new_df) 448 | elif isinstance(col, str): 449 | split_vals = self.df.loc[:,col].values.flatten() 450 | 451 | new_df_list = [] 452 | for unique_split_val in np.unique(split_vals): 453 | new_df = self.df[:][self.df.loc[:,col]==unique_split_val] 454 | new_df_list.append(new_df) 455 | else: 456 | raise ValueError('col argument not valid - must be column name or index') 457 | 458 | new_datasets = [] 459 | for new_df in new_df_list: 460 | new_dataset = self.copy(new_df) 461 | new_datasets.append(new_dataset) 462 | 463 | return new_datasets 464 | 465 | def train_test_split(self, train_size): 466 | if train_size < 1: 467 | train_size = int(train_size * len(self)) 468 | 469 | train_indices = np.random.choice(len(self), train_size, replace=False) 470 | test_indices = np.array([i for i in range(len(self)) if i not in train_indices]) 471 | 472 | train_df = self.df.iloc[train_indices,:] 473 | test_df = self.df.iloc[test_indices,:] 474 | 475 | train_dataset = self.copy(train_df) 476 | test_dataset = self.copy(test_df) 477 | 478 | return train_dataset, test_dataset 479 | 480 | def copy(self, df=None): 481 | if df is None: 482 | df = self.df 483 | 484 | return CSVDataset(df, 485 | input_cols=self.input_cols, 486 | target_cols=self.target_cols, 487 | input_transform=self.input_transform, 488 | target_transform=self.target_transform, 489 | co_transform=self.co_transform) 490 | 491 | 492 | class FolderDataset(BaseDataset): 493 | 494 | def __init__(self, 495 | root, 496 | class_mode='label', 497 | input_regex='*', 498 | target_regex=None, 499 | input_transform=None, 500 | target_transform=None, 501 | co_transform=None, 502 | input_loader='npy'): 503 | """ 504 | Dataset class for loading out-of-memory data. 505 | 506 | Arguments 507 | --------- 508 | root : string 509 | path to main directory 510 | 511 | class_mode : string in `{'label', 'image'}` 512 | type of target sample to look for and return 513 | `label` = return class folder as target 514 | `image` = return another image as target as found by 'target_regex' 515 | NOTE: if class_mode == 'image', you must give an 516 | input and target regex and the input/target images should 517 | be in a folder together with no other images in that folder 518 | 519 | input_regex : string (default is any valid image file) 520 | regular expression to find input images 521 | e.g. if all your inputs have the word 'input', 522 | you'd enter something like input_regex='*input*' 523 | 524 | target_regex : string (default is Nothing) 525 | regular expression to find target images if class_mode == 'image' 526 | e.g. if all your targets have the word 'segment', 527 | you'd enter somthing like target_regex='*segment*' 528 | 529 | transform : transform class 530 | transform to apply to input sample individually 531 | 532 | target_transform : transform class 533 | transform to apply to target sample individually 534 | 535 | input_loader : string in `{'npy', 'pil', 'nifti'} or callable 536 | defines how to load samples from file 537 | if a function is provided, it should take in a file path 538 | as input and return the loaded sample. 
539 | 540 | """ 541 | self.input_loader = default_file_reader 542 | self.target_loader = default_file_reader if class_mode == 'image' else lambda x: x 543 | 544 | root = os.path.expanduser(root) 545 | 546 | classes, class_to_idx = _find_classes(root) 547 | inputs, targets = _finds_inputs_and_targets(root, class_mode, 548 | class_to_idx, input_regex, target_regex) 549 | 550 | if len(inputs) == 0: 551 | raise(RuntimeError('Found 0 images in subfolders of: %s' % root)) 552 | else: 553 | print('Found %i images' % len(inputs)) 554 | 555 | self.root = os.path.expanduser(root) 556 | self.inputs = inputs 557 | self.targets = targets 558 | self.classes = classes 559 | self.class_to_idx = class_to_idx 560 | 561 | self.input_transform = input_transform if input_transform is not None else lambda x: x 562 | if isinstance(input_transform, (tuple,list)): 563 | self.input_transform = transforms.Compose(self.input_transform) 564 | self.target_transform = target_transform if target_transform is not None else lambda x: x 565 | if isinstance(target_transform, (tuple,list)): 566 | self.target_transform = transforms.Compose(self.target_transform) 567 | self.co_transform = co_transform if co_transform is not None else lambda x,y: (x,y) 568 | if isinstance(co_transform, (tuple,list)): 569 | self.co_transform = transforms.Compose(self.co_transform) 570 | 571 | self.class_mode = class_mode 572 | 573 | def get_full_paths(self): 574 | return [os.path.join(self.root, i) for i in self.inputs] 575 | 576 | def __getitem__(self, index): 577 | input_sample = self.inputs[index] 578 | input_sample = self.input_loader(input_sample) 579 | input_sample = self.input_transform(input_sample) 580 | 581 | target_sample = self.targets[index] 582 | target_sample = self.target_loader(target_sample) 583 | target_sample = self.target_transform(target_sample) 584 | 585 | input_sample, target_sample = self.co_transform(input_sample, target_sample) 586 | 587 | return input_sample, target_sample 588 | 589 | def __len__(self): 590 | return len(self.inputs) 591 | 592 | 593 | 594 | def _find_classes(dir): 595 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 596 | classes.sort() 597 | class_to_idx = {classes[i]: i for i in range(len(classes))} 598 | return classes, class_to_idx 599 | 600 | def _is_image_file(filename): 601 | IMG_EXTENSIONS = [ 602 | '.jpg', '.JPG', '.jpeg', '.JPEG', 603 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 604 | '.nii.gz', '.npy' 605 | ] 606 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 607 | 608 | def _finds_inputs_and_targets(directory, class_mode, class_to_idx=None, 609 | input_regex=None, target_regex=None, ): 610 | """ 611 | Map a dataset from a root folder 612 | """ 613 | if class_mode == 'image': 614 | if not input_regex and not target_regex: 615 | raise ValueError('must give input_regex and target_regex if'+ 616 | ' class_mode==image') 617 | inputs = [] 618 | targets = [] 619 | for subdir in sorted(os.listdir(directory)): 620 | d = os.path.join(directory, subdir) 621 | if not os.path.isdir(d): 622 | continue 623 | 624 | for root, _, fnames in sorted(os.walk(d)): 625 | for fname in fnames: 626 | if _is_image_file(fname): 627 | if fnmatch.fnmatch(fname, input_regex): 628 | path = os.path.join(root, fname) 629 | inputs.append(path) 630 | if class_mode == 'label': 631 | targets.append(class_to_idx[subdir]) 632 | if class_mode == 'image' and \ 633 | fnmatch.fnmatch(fname, target_regex): 634 | path = os.path.join(root, fname) 635 | targets.append(path) 
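    # at this point, inputs/targets are parallel lists: class indices when
    # class_mode='label', matched file paths when class_mode='image'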
636 | if class_mode is None: 637 | return inputs 638 | else: 639 | return inputs, targets 640 | -------------------------------------------------------------------------------- /torchsample/functions/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .affine import * -------------------------------------------------------------------------------- /torchsample/functions/affine.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | from ..utils import th_iterproduct, th_flatten 7 | 8 | 9 | def F_affine2d(x, matrix, center=True): 10 | """ 11 | 2D Affine image transform on torch.autograd.Variable 12 | """ 13 | if matrix.dim() == 2: 14 | matrix = matrix.view(-1,2,3) 15 | 16 | A_batch = matrix[:,:,:2] 17 | if A_batch.size(0) != x.size(0): 18 | A_batch = A_batch.repeat(x.size(0),1,1) 19 | b_batch = matrix[:,:,2].unsqueeze(1) 20 | 21 | # make a meshgrid of normal coordinates 22 | _coords = th_iterproduct(x.size(1),x.size(2)) 23 | coords = Variable(_coords.unsqueeze(0).repeat(x.size(0),1,1).float(), 24 | requires_grad=False) 25 | if center: 26 | # shift the coordinates so center is the origin 27 | coords[:,:,0] = coords[:,:,0] - (x.size(1) / 2. + 0.5) 28 | coords[:,:,1] = coords[:,:,1] - (x.size(2) / 2. + 0.5) 29 | 30 | # apply the coordinate transformation 31 | new_coords = coords.bmm(A_batch.transpose(1,2)) + b_batch.expand_as(coords) 32 | 33 | if center: 34 | # shift the coordinates back so origin is origin 35 | new_coords[:,:,0] = new_coords[:,:,0] + (x.size(1) / 2. + 0.5) 36 | new_coords[:,:,1] = new_coords[:,:,1] + (x.size(2) / 2. + 0.5) 37 | 38 | # map new coordinates using bilinear interpolation 39 | x_transformed = F_bilinear_interp2d(x, new_coords) 40 | 41 | return x_transformed 42 | 43 | 44 | def F_bilinear_interp2d(input, coords): 45 | """ 46 | bilinear interpolation of 2d torch.autograd.Variable 47 | """ 48 | x = torch.clamp(coords[:,:,0], 0, input.size(1)-2) 49 | x0 = x.floor() 50 | x1 = x0 + 1 51 | y = torch.clamp(coords[:,:,1], 0, input.size(2)-2) 52 | y0 = y.floor() 53 | y1 = y0 + 1 54 | 55 | stride = torch.LongTensor(input.stride()) 56 | x0_ix = x0.mul(stride[1]).long() 57 | x1_ix = x1.mul(stride[1]).long() 58 | y0_ix = y0.mul(stride[2]).long() 59 | y1_ix = y1.mul(stride[2]).long() 60 | 61 | input_flat = input.view(input.size(0),-1).contiguous() 62 | 63 | vals_00 = input_flat.gather(1, x0_ix.add(y0_ix).detach()) 64 | vals_10 = input_flat.gather(1, x1_ix.add(y0_ix).detach()) 65 | vals_01 = input_flat.gather(1, x0_ix.add(y1_ix).detach()) 66 | vals_11 = input_flat.gather(1, x1_ix.add(y1_ix).detach()) 67 | 68 | xd = x - x0 69 | yd = y - y0 70 | xm = 1 - xd 71 | ym = 1 - yd 72 | 73 | x_mapped = (vals_00.mul(xm).mul(ym) + 74 | vals_10.mul(xd).mul(ym) + 75 | vals_01.mul(xm).mul(yd) + 76 | vals_11.mul(xd).mul(yd)) 77 | 78 | return x_mapped.view_as(input) 79 | 80 | 81 | def F_batch_affine2d(x, matrix, center=True): 82 | """ 83 | 84 | x : torch.Tensor 85 | shape = (Samples, C, H, W) 86 | NOTE: Assume C is always equal to 1! 
87 | matrix : torch.Tensor 88 | shape = (Samples, 6) or (Samples, 2, 3) 89 | 90 | Example 91 | ------- 92 | >>> x = Variable(torch.zeros(3,1,10,10)) 93 | >>> x[:,:,3:7,3:7] = 1 94 | >>> m1 = torch.FloatTensor([[1.2,0,0],[0,1.2,0]]) 95 | >>> m2 = torch.FloatTensor([[0.8,0,0],[0,0.8,0]]) 96 | >>> m3 = torch.FloatTensor([[1.0,0,3],[0,1.0,3]]) 97 | >>> matrix = Variable(torch.stack([m1,m2,m3])) 98 | >>> xx = F_batch_affine2d(x,matrix) 99 | """ 100 | if matrix.dim() == 2: 101 | matrix = matrix.view(-1,2,3) 102 | 103 | A_batch = matrix[:,:,:2] 104 | b_batch = matrix[:,:,2].unsqueeze(1) 105 | 106 | # make a meshgrid of normal coordinates 107 | _coords = th_iterproduct(x.size(2),x.size(3)) 108 | coords = Variable(_coords.unsqueeze(0).repeat(x.size(0),1,1).float(), 109 | requires_grad=False) 110 | 111 | if center: 112 | # shift the coordinates so center is the origin 113 | coords[:,:,0] = coords[:,:,0] - (x.size(2) / 2. + 0.5) 114 | coords[:,:,1] = coords[:,:,1] - (x.size(3) / 2. + 0.5) 115 | 116 | # apply the coordinate transformation 117 | new_coords = coords.bmm(A_batch.transpose(1,2)) + b_batch.expand_as(coords) 118 | 119 | if center: 120 | # shift the coordinates back so origin is origin 121 | new_coords[:,:,0] = new_coords[:,:,0] + (x.size(2) / 2. + 0.5) 122 | new_coords[:,:,1] = new_coords[:,:,1] + (x.size(3) / 2. + 0.5) 123 | 124 | # map new coordinates using bilinear interpolation 125 | x_transformed = F_batch_bilinear_interp2d(x, new_coords) 126 | 127 | return x_transformed 128 | 129 | 130 | def F_batch_bilinear_interp2d(input, coords): 131 | """ 132 | input : torch.Tensor 133 | size = (N,H,W,C) 134 | coords : torch.Tensor 135 | size = (N,H*W*C,2) 136 | """ 137 | x = torch.clamp(coords[:,:,0], 0, input.size(2)-2) 138 | x0 = x.floor() 139 | x1 = x0 + 1 140 | y = torch.clamp(coords[:,:,1], 0, input.size(3)-2) 141 | y0 = y.floor() 142 | y1 = y0 + 1 143 | 144 | stride = torch.LongTensor(input.stride()) 145 | x0_ix = x0.mul(stride[2]).long() 146 | x1_ix = x1.mul(stride[2]).long() 147 | y0_ix = y0.mul(stride[3]).long() 148 | y1_ix = y1.mul(stride[3]).long() 149 | 150 | input_flat = input.view(input.size(0),-1).contiguous() 151 | 152 | vals_00 = input_flat.gather(1, x0_ix.add(y0_ix).detach()) 153 | vals_10 = input_flat.gather(1, x1_ix.add(y0_ix).detach()) 154 | vals_01 = input_flat.gather(1, x0_ix.add(y1_ix).detach()) 155 | vals_11 = input_flat.gather(1, x1_ix.add(y1_ix).detach()) 156 | 157 | xd = x - x0 158 | yd = y - y0 159 | xm = 1 - xd 160 | ym = 1 - yd 161 | 162 | x_mapped = (vals_00.mul(xm).mul(ym) + 163 | vals_10.mul(xd).mul(ym) + 164 | vals_01.mul(xm).mul(yd) + 165 | vals_11.mul(xd).mul(yd)) 166 | 167 | return x_mapped.view_as(input) 168 | 169 | 170 | def F_affine3d(x, matrix, center=True): 171 | A = matrix[:3,:3] 172 | b = matrix[:3,3] 173 | 174 | # make a meshgrid of normal coordinates 175 | coords = Variable(th_iterproduct(x.size(1),x.size(2),x.size(3)).float(), 176 | requires_grad=False) 177 | 178 | if center: 179 | # shift the coordinates so center is the origin 180 | coords[:,0] = coords[:,0] - (x.size(1) / 2. + 0.5) 181 | coords[:,1] = coords[:,1] - (x.size(2) / 2. + 0.5) 182 | coords[:,2] = coords[:,2] - (x.size(3) / 2. + 0.5) 183 | 184 | 185 | # apply the coordinate transformation 186 | new_coords = F.linear(coords, A, b) 187 | 188 | if center: 189 | # shift the coordinates back so origin is origin 190 | new_coords[:,0] = new_coords[:,0] + (x.size(1) / 2. + 0.5) 191 | new_coords[:,1] = new_coords[:,1] + (x.size(2) / 2. 
+ 0.5)
192 |     new_coords[:,2] = new_coords[:,2] + (x.size(3) / 2. + 0.5)
193 | 
194 |     # map new coordinates using trilinear interpolation
195 |     x_transformed = F_trilinear_interp3d(x, new_coords)
196 | 
197 |     return x_transformed
198 | 
199 | 
200 | def F_trilinear_interp3d(input, coords):
201 |     """
202 |     trilinear interpolation of 3D image
203 |     """
204 |     # take clamp then floor/ceil of x coords
205 |     x = torch.clamp(coords[:,0], 0, input.size(1)-2)
206 |     x0 = x.floor()
207 |     x1 = x0 + 1
208 |     # take clamp then floor/ceil of y coords
209 |     y = torch.clamp(coords[:,1], 0, input.size(2)-2)
210 |     y0 = y.floor()
211 |     y1 = y0 + 1
212 |     # take clamp then floor/ceil of z coords
213 |     z = torch.clamp(coords[:,2], 0, input.size(3)-2)
214 |     z0 = z.floor()
215 |     z1 = z0 + 1
216 | 
217 |     stride = torch.LongTensor(input.stride())[1:]
218 |     x0_ix = x0.mul(stride[0]).long()
219 |     x1_ix = x1.mul(stride[0]).long()
220 |     y0_ix = y0.mul(stride[1]).long()
221 |     y1_ix = y1.mul(stride[1]).long()
222 |     z0_ix = z0.mul(stride[2]).long()
223 |     z1_ix = z1.mul(stride[2]).long()
224 | 
225 |     input_flat = th_flatten(input)
226 | 
227 |     vals_000 = input_flat[x0_ix.add(y0_ix).add(z0_ix).detach()]
228 |     vals_100 = input_flat[x1_ix.add(y0_ix).add(z0_ix).detach()]
229 |     vals_010 = input_flat[x0_ix.add(y1_ix).add(z0_ix).detach()]
230 |     vals_001 = input_flat[x0_ix.add(y0_ix).add(z1_ix).detach()]
231 |     vals_101 = input_flat[x1_ix.add(y0_ix).add(z1_ix).detach()]
232 |     vals_011 = input_flat[x0_ix.add(y1_ix).add(z1_ix).detach()]
233 |     vals_110 = input_flat[x1_ix.add(y1_ix).add(z0_ix).detach()]
234 |     vals_111 = input_flat[x1_ix.add(y1_ix).add(z1_ix).detach()]
235 | 
236 |     xd = x - x0
237 |     yd = y - y0
238 |     zd = z - z0
239 |     xm = 1 - xd
240 |     ym = 1 - yd
241 |     zm = 1 - zd
242 | 
243 |     x_mapped = (vals_000.mul(xm).mul(ym).mul(zm) +
244 |                 vals_100.mul(xd).mul(ym).mul(zm) +
245 |                 vals_010.mul(xm).mul(yd).mul(zm) +
246 |                 vals_001.mul(xm).mul(ym).mul(zd) +
247 |                 vals_101.mul(xd).mul(ym).mul(zd) +
248 |                 vals_011.mul(xm).mul(yd).mul(zd) +
249 |                 vals_110.mul(xd).mul(yd).mul(zm) +
250 |                 vals_111.mul(xd).mul(yd).mul(zd))
251 | 
252 |     return x_mapped.view_as(input)
253 | 
254 | 
255 | def F_batch_affine3d(x, matrix, center=True):
256 |     """
257 | 
258 |     x : torch.Tensor
259 |         shape = (Samples, C, H, W, D)
260 |         NOTE: Assume C is always equal to 1!
261 |     matrix : torch.Tensor
262 |         shape = (Samples, 12) or (Samples, 3, 4)
263 | 
264 |     Example
265 |     -------
266 |     >>> x = Variable(torch.zeros(3,1,10,10,10))
267 |     >>> x[:,:,3:7,3:7,3:7] = 1
268 |     >>> m1 = torch.FloatTensor([[1.2,0,0,0],[0,1.2,0,0],[0,0,1.2,0]])
269 |     >>> m2 = torch.FloatTensor([[0.8,0,0,0],[0,0.8,0,0],[0,0,0.8,0]])
270 |     >>> m3 = torch.FloatTensor([[1.0,0,0,3],[0,1.0,0,3],[0,0,1.0,3]])
271 |     >>> matrix = Variable(torch.stack([m1,m2,m3]))
272 |     >>> xx = F_batch_affine3d(x,matrix)
273 |     """
274 |     if matrix.dim() == 2:
275 |         matrix = matrix.view(-1,3,4)
276 | 
277 |     A_batch = matrix[:,:3,:3]
278 |     b_batch = matrix[:,:3,3].unsqueeze(1)
279 | 
280 |     # make a meshgrid of normal coordinates
281 |     _coords = th_iterproduct(x.size(2),x.size(3),x.size(4))
282 |     coords = Variable(_coords.unsqueeze(0).repeat(x.size(0),1,1).float(),
283 |                       requires_grad=False)
284 | 
285 |     if center:
286 |         # shift the coordinates so center is the origin
287 |         coords[:,:,0] = coords[:,:,0] - (x.size(2) / 2. + 0.5)
288 |         coords[:,:,1] = coords[:,:,1] - (x.size(3) / 2. + 0.5)
289 |         coords[:,:,2] = coords[:,:,2] - (x.size(4) / 2.
+ 0.5) 290 | 291 | # apply the coordinate transformation 292 | new_coords = coords.bmm(A_batch.transpose(1,2)) + b_batch.expand_as(coords) 293 | 294 | if center: 295 | # shift the coordinates back so origin is origin 296 | new_coords[:,:,0] = new_coords[:,:,0] + (x.size(2) / 2. + 0.5) 297 | new_coords[:,:,1] = new_coords[:,:,1] + (x.size(3) / 2. + 0.5) 298 | new_coords[:,:,2] = new_coords[:,:,2] + (x.size(4) / 2. + 0.5) 299 | 300 | # map new coordinates using bilinear interpolation 301 | x_transformed = F_batch_trilinear_interp3d(x, new_coords) 302 | 303 | return x_transformed 304 | 305 | 306 | def F_batch_trilinear_interp3d(input, coords): 307 | """ 308 | input : torch.Tensor 309 | size = (N,H,W,C) 310 | coords : torch.Tensor 311 | size = (N,H*W*C,2) 312 | """ 313 | x = torch.clamp(coords[:,:,0], 0, input.size(2)-2) 314 | x0 = x.floor() 315 | x1 = x0 + 1 316 | y = torch.clamp(coords[:,:,1], 0, input.size(3)-2) 317 | y0 = y.floor() 318 | y1 = y0 + 1 319 | z = torch.clamp(coords[:,:,2], 0, input.size(4)-2) 320 | z0 = z.floor() 321 | z1 = z0 + 1 322 | 323 | stride = torch.LongTensor(input.stride()) 324 | x0_ix = x0.mul(stride[2]).long() 325 | x1_ix = x1.mul(stride[2]).long() 326 | y0_ix = y0.mul(stride[3]).long() 327 | y1_ix = y1.mul(stride[3]).long() 328 | z0_ix = z0.mul(stride[4]).long() 329 | z1_ix = z1.mul(stride[4]).long() 330 | 331 | input_flat = input.contiguous().view(input.size(0),-1) 332 | 333 | vals_000 = input_flat.gather(1,x0_ix.add(y0_ix).add(z0_ix).detach()) 334 | vals_100 = input_flat.gather(1,x1_ix.add(y0_ix).add(z0_ix).detach()) 335 | vals_010 = input_flat.gather(1,x0_ix.add(y1_ix).add(z0_ix).detach()) 336 | vals_001 = input_flat.gather(1,x0_ix.add(y0_ix).add(z1_ix).detach()) 337 | vals_101 = input_flat.gather(1,x1_ix.add(y0_ix).add(z1_ix).detach()) 338 | vals_011 = input_flat.gather(1,x0_ix.add(y1_ix).add(z1_ix).detach()) 339 | vals_110 = input_flat.gather(1,x1_ix.add(y1_ix).add(z0_ix).detach()) 340 | vals_111 = input_flat.gather(1,x1_ix.add(y1_ix).add(z1_ix).detach()) 341 | 342 | xd = x - x0 343 | yd = y - y0 344 | zd = z - z0 345 | xm = 1 - xd 346 | ym = 1 - yd 347 | zm = 1 - zd 348 | 349 | x_mapped = (vals_000.mul(xm).mul(ym).mul(zm) + 350 | vals_100.mul(xd).mul(ym).mul(zm) + 351 | vals_010.mul(xm).mul(yd).mul(zm) + 352 | vals_001.mul(xm).mul(ym).mul(zd) + 353 | vals_101.mul(xd).mul(ym).mul(zd) + 354 | vals_011.mul(xm).mul(yd).mul(zd) + 355 | vals_110.mul(xd).mul(yd).mul(zm) + 356 | vals_111.mul(xd).mul(yd).mul(zd)) 357 | 358 | return x_mapped.view_as(input) 359 | 360 | 361 | -------------------------------------------------------------------------------- /torchsample/initializers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes to initialize module weights 3 | """ 4 | 5 | from fnmatch import fnmatch 6 | 7 | import torch.nn.init 8 | 9 | 10 | def _validate_initializer_string(init): 11 | dir_f = dir(torch.nn.init) 12 | loss_fns = [d.lower() for d in dir_f] 13 | if isinstance(init, str): 14 | try: 15 | str_idx = loss_fns.index(init.lower()) 16 | except: 17 | raise ValueError('Invalid loss string input - must match pytorch function.') 18 | return getattr(torch.nn.init, dir(torch.nn.init)[str_idx]) 19 | elif callable(init): 20 | return init 21 | else: 22 | raise ValueError('Invalid loss input') 23 | 24 | 25 | class InitializerContainer(object): 26 | 27 | def __init__(self, initializers): 28 | self._initializers = initializers 29 | 30 | def apply(self, model): 31 | for initializer in self._initializers: 32 | 
            model.apply(initializer)
33 | 
34 | 
35 | class Initializer(object):
36 | 
37 |     def __call__(self, module):
38 |         raise NotImplementedError('Initializer must implement this method')
39 | 
40 | 
41 | class GeneralInitializer(Initializer):
42 | 
43 |     def __init__(self, initializer, bias=False, bias_only=False, module_filter='*', **kwargs):
44 |         self._initializer = _validate_initializer_string(initializer)
45 |         self.kwargs = kwargs
46 |         self.bias, self.bias_only, self.module_filter = bias, bias_only, module_filter
47 |     def __call__(self, module):
48 |         classname = module.__class__.__name__
49 |         if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
50 |             if self.bias_only:
51 |                 self._initializer(module.bias.data, **self.kwargs)
52 |             else:
53 |                 self._initializer(module.weight.data, **self.kwargs)
54 |                 if self.bias:
55 |                     self._initializer(module.bias.data, **self.kwargs)
56 | 
57 | 
58 | class Normal(Initializer):
59 | 
60 |     def __init__(self, mean=0.0, std=0.02, bias=False,
61 |                  bias_only=False, module_filter='*'):
62 |         self.mean = mean
63 |         self.std = std
64 | 
65 |         self.bias = bias
66 |         self.bias_only = bias_only
67 |         self.module_filter = module_filter
68 | 
69 |         super(Normal, self).__init__()
70 | 
71 |     def __call__(self, module):
72 |         classname = module.__class__.__name__
73 |         if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
74 |             if self.bias_only:
75 |                 torch.nn.init.normal(module.bias.data, mean=self.mean, std=self.std)
76 |             else:
77 |                 torch.nn.init.normal(module.weight.data, mean=self.mean, std=self.std)
78 |                 if self.bias:
79 |                     torch.nn.init.normal(module.bias.data, mean=self.mean, std=self.std)
80 | 
81 | 
82 | class Uniform(Initializer):
83 | 
84 |     def __init__(self, a=0, b=1, bias=False, bias_only=False, module_filter='*'):
85 |         self.a = a
86 |         self.b = b
87 | 
88 |         self.bias = bias
89 |         self.bias_only = bias_only
90 |         self.module_filter = module_filter
91 | 
92 |         super(Uniform, self).__init__()
93 | 
94 |     def __call__(self, module):
95 |         classname = module.__class__.__name__
96 |         if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
97 |             if self.bias_only:
98 |                 torch.nn.init.uniform(module.bias.data, a=self.a, b=self.b)
99 |             else:
100 |                 torch.nn.init.uniform(module.weight.data, a=self.a, b=self.b)
101 |                 if self.bias:
102 |                     torch.nn.init.uniform(module.bias.data, a=self.a, b=self.b)
103 | 
104 | 
105 | class ConstantInitializer(Initializer):
106 | 
107 |     def __init__(self, value, bias=False, bias_only=False, module_filter='*'):
108 |         self.value = value
109 | 
110 |         self.bias = bias
111 |         self.bias_only = bias_only
112 |         self.module_filter = module_filter
113 | 
114 |         super(ConstantInitializer, self).__init__()
115 | 
116 |     def __call__(self, module):
117 |         classname = module.__class__.__name__
118 |         if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
119 |             if self.bias_only:
120 |                 torch.nn.init.constant(module.bias.data, val=self.value)
121 |             else:
122 |                 torch.nn.init.constant(module.weight.data, val=self.value)
123 |                 if self.bias:
124 |                     torch.nn.init.constant(module.bias.data, val=self.value)
125 | 
126 | 
127 | class XavierUniform(Initializer):
128 | 
129 |     def __init__(self, gain=1, bias=False, bias_only=False, module_filter='*'):
130 |         self.gain = gain
131 | 
132 |         self.bias = bias
133 |         self.bias_only = bias_only
134 |         self.module_filter = module_filter
135 | 
136 |         super(XavierUniform, self).__init__()
137 | 
138 |     def __call__(self, module):
139 |         classname = module.__class__.__name__
140 |         if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
141 |             if
self.bias_only: 142 | torch.nn.init.xavier_uniform(module.bias.data, gain=self.gain) 143 | else: 144 | torch.nn.init.xavier_uniform(module.weight.data, gain=self.gain) 145 | if self.bias: 146 | torch.nn.init.xavier_uniform(module.bias.data, gain=self.gain) 147 | 148 | 149 | class XavierNormal(Initializer): 150 | 151 | def __init__(self, gain=1, bias=False, bias_only=False, module_filter='*'): 152 | self.gain = gain 153 | 154 | self.bias = bias 155 | self.bias_only = bias_only 156 | self.module_filter = module_filter 157 | 158 | super(XavierNormal, self).__init__() 159 | 160 | def __call__(self, module): 161 | classname = module.__class__.__name__ 162 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 163 | if self.bias_only: 164 | torch.nn.init.xavier_normal(module.bias.data, gain=self.gain) 165 | else: 166 | torch.nn.init.xavier_normal(module.weight.data, gain=self.gain) 167 | if self.bias: 168 | torch.nn.init.xavier_normal(module.bias.data, gain=self.gain) 169 | 170 | 171 | class KaimingUniform(Initializer): 172 | 173 | def __init__(self, a=0, mode='fan_in', bias=False, bias_only=False, module_filter='*'): 174 | self.a = a 175 | self.mode = mode 176 | 177 | self.bias = bias 178 | self.bias_only = bias_only 179 | self.module_filter = module_filter 180 | 181 | super(KaimingUniform, self).__init__() 182 | 183 | def __call__(self, module): 184 | classname = module.__class__.__name__ 185 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 186 | if self.bias_only: 187 | torch.nn.init.kaiming_uniform(module.bias.data, a=self.a, mode=self.mode) 188 | else: 189 | torch.nn.init.kaiming_uniform(module.weight.data, a=self.a, mode=self.mode) 190 | if self.bias: 191 | torch.nn.init.kaiming_uniform(module.bias.data, a=self.a, mode=self.mode) 192 | 193 | 194 | class KaimingNormal(Initializer): 195 | 196 | def __init__(self, a=0, mode='fan_in', bias=False, bias_only=False, module_filter='*'): 197 | self.a = a 198 | self.mode = mode 199 | 200 | self.bias = bias 201 | self.bias_only = bias_only 202 | self.module_filter = module_filter 203 | 204 | super(KaimingNormal, self).__init__() 205 | 206 | def __call__(self, module): 207 | classname = module.__class__.__name__ 208 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 209 | if self.bias_only: 210 | torch.nn.init.kaiming_normal(module.bias.data, a=self.a, mode=self.mode) 211 | else: 212 | torch.nn.init.kaiming_normal(module.weight.data, a=self.a, mode=self.mode) 213 | if self.bias: 214 | torch.nn.init.kaiming_normal(module.bias.data, a=self.a, mode=self.mode) 215 | 216 | 217 | class Orthogonal(Initializer): 218 | 219 | def __init__(self, gain=1, bias=False, bias_only=False, module_filter='*'): 220 | self.gain = gain 221 | 222 | self.bias = bias 223 | self.bias_only = bias_only 224 | self.module_filter = module_filter 225 | 226 | super(Orthogonal, self).__init__() 227 | 228 | def __call__(self, module): 229 | classname = module.__class__.__name__ 230 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 231 | if self.bias_only: 232 | torch.nn.init.orthogonal(module.bias.data, gain=self.gain) 233 | else: 234 | torch.nn.init.orthogonal(module.weight.data, gain=self.gain) 235 | if self.bias: 236 | torch.nn.init.orthogonal(module.bias.data, gain=self.gain) 237 | 238 | 239 | class Sparse(Initializer): 240 | 241 | def __init__(self, sparsity, std=0.01, bias=False, bias_only=False, module_filter='*'): 242 | self.sparsity = sparsity 243 | self.std = std 244 | 245 | 
self.bias = bias 246 | self.bias_only = bias_only 247 | self.module_filter = module_filter 248 | 249 | super(Sparse, self).__init__() 250 | 251 | def __call__(self, module): 252 | classname = module.__class__.__name__ 253 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'): 254 | if self.bias_only: 255 | torch.nn.init.sparse(module.bias.data, sparsity=self.sparsity, std=self.std) 256 | else: 257 | torch.nn.init.sparse(module.weight.data, sparsity=self.sparsity, std=self.std) 258 | if self.bias: 259 | torch.nn.init.sparse(module.bias.data, sparsity=self.sparsity, std=self.std) 260 | 261 | 262 | 263 | -------------------------------------------------------------------------------- /torchsample/metrics.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | 5 | import torch as th 6 | 7 | from .utils import th_matrixcorr 8 | 9 | from .callbacks import Callback 10 | 11 | class MetricContainer(object): 12 | 13 | 14 | def __init__(self, metrics, prefix=''): 15 | self.metrics = metrics 16 | self.helper = None 17 | self.prefix = prefix 18 | 19 | def set_helper(self, helper): 20 | self.helper = helper 21 | 22 | def reset(self): 23 | for metric in self.metrics: 24 | metric.reset() 25 | 26 | def __call__(self, output_batch, target_batch): 27 | logs = {} 28 | for metric in self.metrics: 29 | logs[self.prefix+metric._name] = self.helper.calculate_loss(output_batch, 30 | target_batch, 31 | metric) 32 | return logs 33 | 34 | class Metric(object): 35 | 36 | def __call__(self, y_pred, y_true): 37 | raise NotImplementedError('Custom Metrics must implement this function') 38 | 39 | def reset(self): 40 | raise NotImplementedError('Custom Metrics must implement this function') 41 | 42 | 43 | class MetricCallback(Callback): 44 | 45 | def __init__(self, container): 46 | self.container = container 47 | def on_epoch_begin(self, epoch_idx, logs): 48 | self.container.reset() 49 | 50 | class CategoricalAccuracy(Metric): 51 | 52 | def __init__(self, top_k=1): 53 | self.top_k = top_k 54 | self.correct_count = 0 55 | self.total_count = 0 56 | 57 | self._name = 'acc_metric' 58 | 59 | def reset(self): 60 | self.correct_count = 0 61 | self.total_count = 0 62 | 63 | def __call__(self, y_pred, y_true): 64 | top_k = y_pred.topk(self.top_k,1)[1] 65 | true_k = y_true.view(len(y_true),1).expand_as(top_k) 66 | self.correct_count += top_k.eq(true_k).float().sum().data[0] 67 | self.total_count += len(y_pred) 68 | accuracy = 100. * float(self.correct_count) / float(self.total_count) 69 | return accuracy 70 | 71 | 72 | class BinaryAccuracy(Metric): 73 | 74 | def __init__(self): 75 | self.correct_count = 0 76 | self.total_count = 0 77 | 78 | self._name = 'acc_metric' 79 | 80 | def reset(self): 81 | self.correct_count = 0 82 | self.total_count = 0 83 | 84 | def __call__(self, y_pred, y_true): 85 | y_pred_round = y_pred.round().long() 86 | self.correct_count += y_pred_round.eq(y_true).float().sum().data[0] 87 | self.total_count += len(y_pred) 88 | accuracy = 100. * float(self.correct_count) / float(self.total_count) 89 | return accuracy 90 | 91 | 92 | class ProjectionCorrelation(Metric): 93 | 94 | def __init__(self): 95 | self.corr_sum = 0. 96 | self.total_count = 0. 97 | 98 | self._name = 'corr_metric' 99 | 100 | def reset(self): 101 | self.corr_sum = 0. 102 | self.total_count = 0. 
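The accuracy metrics above keep running counts, so each call returns the average over every batch seen since the last `reset()` (which `MetricCallback` triggers at the start of each epoch). A minimal sketch of that behavior, using the old-style `Variable` API this file is written against:

```python
import torch as th
from torch.autograd import Variable
from torchsample.metrics import CategoricalAccuracy

metric = CategoricalAccuracy(top_k=1)

# first batch: one of two predictions is correct
y_pred = Variable(th.Tensor([[0.1, 0.8, 0.1],
                             [0.7, 0.2, 0.1]]))
y_true = Variable(th.LongTensor([1, 2]))
print(metric(y_pred, y_true))    # 50.0

# second batch: counts accumulate, so this is 2 correct out of 3 total
y_pred = Variable(th.Tensor([[0.9, 0.05, 0.05]]))
y_true = Variable(th.LongTensor([0]))
print(metric(y_pred, y_true))    # 66.66...

metric.reset()                   # clear the counts between epochs
```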
103 | 104 | def __call__(self, y_pred, y_true=None): 105 | """ 106 | y_pred should be two projections 107 | """ 108 | covar_mat = th.abs(th_matrixcorr(y_pred[0].data, y_pred[1].data)) 109 | self.corr_sum += th.trace(covar_mat) 110 | self.total_count += covar_mat.size(0) 111 | return self.corr_sum / self.total_count 112 | 113 | 114 | class ProjectionAntiCorrelation(Metric): 115 | 116 | def __init__(self): 117 | self.anticorr_sum = 0. 118 | self.total_count = 0. 119 | 120 | self._name = 'anticorr_metric' 121 | 122 | def reset(self): 123 | self.anticorr_sum = 0. 124 | self.total_count = 0. 125 | 126 | def __call__(self, y_pred, y_true=None): 127 | """ 128 | y_pred should be two projections 129 | """ 130 | covar_mat = th.abs(th_matrixcorr(y_pred[0].data, y_pred[1].data)) 131 | upper_sum = th.sum(th.triu(covar_mat,1)) 132 | lower_sum = th.sum(th.tril(covar_mat,-1)) 133 | self.anticorr_sum += upper_sum 134 | self.anticorr_sum += lower_sum 135 | self.total_count += covar_mat.size(0)*(covar_mat.size(1) - 1) 136 | return self.anticorr_sum / self.total_count 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /torchsample/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .module_trainer import ModuleTrainer 4 | -------------------------------------------------------------------------------- /torchsample/modules/_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import datetime 3 | import warnings 4 | 5 | try: 6 | from inspect import signature 7 | except: 8 | warnings.warn('inspect.signature not available... ' 9 | 'you should upgrade to Python 3.x') 10 | 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | 14 | from ..metrics import Metric, CategoricalAccuracy, BinaryAccuracy 15 | from ..initializers import GeneralInitializer 16 | 17 | def _add_regularizer_to_loss_fn(loss_fn, 18 | regularizer_container): 19 | def new_loss_fn(output_batch, target_batch): 20 | return loss_fn(output_batch, target_batch) + regularizer_container.get_value() 21 | return new_loss_fn 22 | 23 | def _is_iterable(x): 24 | return isinstance(x, (tuple, list)) 25 | def _is_tuple_or_list(x): 26 | return isinstance(x, (tuple, list)) 27 | 28 | def _parse_num_inputs_and_targets_from_loader(loader): 29 | """ NOT IMPLEMENTED """ 30 | #batch = next(iter(loader)) 31 | num_inputs = loader.dataset.num_inputs 32 | num_targets = loader.dataset.num_targets 33 | return num_inputs, num_targets 34 | 35 | def _parse_num_inputs_and_targets(inputs, targets=None): 36 | if isinstance(inputs, (list, tuple)): 37 | num_inputs = len(inputs) 38 | else: 39 | num_inputs = 1 40 | if targets is not None: 41 | if isinstance(targets, (list, tuple)): 42 | num_targets = len(targets) 43 | else: 44 | num_targets = 1 45 | else: 46 | num_targets = 0 47 | return num_inputs, num_targets 48 | 49 | def _standardize_user_data(inputs, targets=None): 50 | if not isinstance(inputs, (list,tuple)): 51 | inputs = [inputs] 52 | if targets is not None: 53 | if not isinstance(targets, (list,tuple)): 54 | targets = [targets] 55 | return inputs, targets 56 | else: 57 | return inputs 58 | 59 | def _validate_metric_input(metric): 60 | if isinstance(metric, str): 61 | if metric.upper() == 'CATEGORICAL_ACCURACY' or metric.upper() == 'ACCURACY': 62 | return CategoricalAccuracy() 63 | elif metric.upper() == 'BINARY_ACCURACY': 64 | return BinaryAccuracy() 65 | 
else: 66 | raise ValueError('Invalid metric string input - must match pytorch function.') 67 | elif isinstance(metric, Metric): 68 | return metric 69 | else: 70 | raise ValueError('Invalid metric input') 71 | 72 | def _validate_loss_input(loss): 73 | dir_f = dir(F) 74 | loss_fns = [d.lower() for d in dir_f] 75 | if isinstance(loss, str): 76 | if loss.lower() == 'unconstrained': 77 | return lambda x: x 78 | elif loss.lower() == 'unconstrained_sum': 79 | return lambda x: x.sum() 80 | elif loss.lower() == 'unconstrained_mean': 81 | return lambda x: x.mean() 82 | else: 83 | try: 84 | str_idx = loss_fns.index(loss.lower()) 85 | except: 86 | raise ValueError('Invalid loss string input - must match pytorch function.') 87 | return getattr(F, dir(F)[str_idx]) 88 | elif callable(loss): 89 | return loss 90 | else: 91 | raise ValueError('Invalid loss input') 92 | 93 | def _validate_optimizer_input(optimizer): 94 | dir_optim = dir(optim) 95 | opts = [o.lower() for o in dir_optim] 96 | if isinstance(optimizer, str): 97 | try: 98 | str_idx = opts.index(optimizer.lower()) 99 | except: 100 | raise ValueError('Invalid optimizer string input - must match pytorch function.') 101 | return getattr(optim, dir_optim[str_idx]) 102 | elif hasattr(optimizer, 'step') and hasattr(optimizer, 'zero_grad'): 103 | return optimizer 104 | else: 105 | raise ValueError('Invalid optimizer input') 106 | 107 | def _validate_initializer_input(initializer): 108 | if isinstance(initializer, str): 109 | try: 110 | initializer = GeneralInitializer(initializer) 111 | except: 112 | raise ValueError('Invalid initializer string input - must match pytorch function.') 113 | return initializer 114 | elif callable(initializer): 115 | return initializer 116 | else: 117 | raise ValueError('Invalid optimizer input') 118 | 119 | def _get_current_time(): 120 | return datetime.datetime.now().strftime("%B %d, %Y - %I:%M%p") 121 | 122 | def _nb_function_args(fn): 123 | return len(signature(fn).parameters) -------------------------------------------------------------------------------- /torchsample/regularizers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as th 3 | from fnmatch import fnmatch 4 | 5 | from .callbacks import Callback 6 | 7 | class RegularizerContainer(object): 8 | 9 | def __init__(self, regularizers): 10 | self.regularizers = regularizers 11 | self._forward_hooks = [] 12 | 13 | def register_forward_hooks(self, model): 14 | for regularizer in self.regularizers: 15 | for module_name, module in model.named_modules(): 16 | if fnmatch(module_name, regularizer.module_filter) and hasattr(module, 'weight'): 17 | hook = module.register_forward_hook(regularizer) 18 | self._forward_hooks.append(hook) 19 | 20 | if len(self._forward_hooks) == 0: 21 | raise Exception('Tried to register regularizers but no modules ' 22 | 'were found that matched any module_filter argument.') 23 | 24 | def unregister_forward_hooks(self): 25 | for hook in self._forward_hooks: 26 | hook.remove() 27 | 28 | def reset(self): 29 | for r in self.regularizers: 30 | r.reset() 31 | 32 | def get_value(self): 33 | value = sum([r.value for r in self.regularizers]) 34 | self.current_value = value.data[0] 35 | return value 36 | 37 | def __len__(self): 38 | return len(self.regularizers) 39 | 40 | 41 | class RegularizerCallback(Callback): 42 | 43 | def __init__(self, container): 44 | self.container = container 45 | 46 | def on_batch_end(self, batch, logs=None): 47 | self.container.reset() 48 | 49 | 50 | class 
Regularizer(object):
51 | 
52 |     def reset(self):
53 |         raise NotImplementedError('subclass must implement this method')
54 | 
55 |     def __call__(self, module, input=None, output=None):
56 |         raise NotImplementedError('subclass must implement this method')
57 | 
58 | 
59 | class L1Regularizer(Regularizer):
60 | 
61 |     def __init__(self, scale=1e-3, module_filter='*'):
62 |         self.scale = float(scale)
63 |         self.module_filter = module_filter
64 |         self.value = 0.
65 | 
66 |     def reset(self):
67 |         self.value = 0.
68 | 
69 |     def __call__(self, module, input=None, output=None):
70 |         value = th.sum(th.abs(module.weight)) * self.scale
71 |         self.value += value
72 | 
73 | 
74 | class L2Regularizer(Regularizer):
75 | 
76 |     def __init__(self, scale=1e-3, module_filter='*'):
77 |         self.scale = float(scale)
78 |         self.module_filter = module_filter
79 |         self.value = 0.
80 | 
81 |     def reset(self):
82 |         self.value = 0.
83 | 
84 |     def __call__(self, module, input=None, output=None):
85 |         value = th.sum(th.pow(module.weight,2)) * self.scale
86 |         self.value += value
87 | 
88 | 
89 | class L1L2Regularizer(Regularizer):
90 | 
91 |     def __init__(self, l1_scale=1e-3, l2_scale=1e-3, module_filter='*'):
92 |         self.l1 = L1Regularizer(l1_scale)
93 |         self.l2 = L2Regularizer(l2_scale)
94 |         self.module_filter = module_filter
95 |         self.value = 0.
96 | 
97 |     def reset(self):
98 |         self.value = 0.
99 | 
100 |     def __call__(self, module, input=None, output=None):
101 |         self.l1(module, input, output)
102 |         self.l2(module, input, output)
103 |         self.value += (self.l1.value + self.l2.value)
104 | 
105 | 
106 | # ------------------------------------------------------------------
107 | # ------------------------------------------------------------------
108 | # ------------------------------------------------------------------
109 | 
110 | class UnitNormRegularizer(Regularizer):
111 |     """
112 |     UnitNorm constraint on Weights
113 | 
114 |     Constrains the weights to have column-wise unit norm
115 |     """
116 |     def __init__(self,
117 |                  scale=1e-3,
118 |                  module_filter='*'):
119 | 
120 |         self.scale = scale
121 |         self.module_filter = module_filter
122 |         self.value = 0.
123 | 
124 |     def reset(self):
125 |         self.value = 0.
126 | 
127 |     def __call__(self, module, input=None, output=None):
128 |         w = module.weight
129 |         norm_diff = th.norm(w, 2, 1).sub(1.)
130 |         value = self.scale * th.sum(norm_diff.gt(0).float().mul(norm_diff))
131 |         self.value += value
132 | 
133 | 
134 | class MaxNormRegularizer(Regularizer):
135 |     """
136 |     MaxNorm regularizer on Weights
137 | 
138 |     Penalizes the amount by which the weight norms exceed a maximum value
139 |     """
140 |     def __init__(self, scale=1e-3, max_value=1.,
141 |                  axis=0, module_filter='*'):
142 |         self.scale = scale
143 |         self.max_value = max_value
144 |         self.axis = axis
145 |         self.module_filter = module_filter
146 |         self.value = 0.
147 | 
148 |     def reset(self):
149 |         self.value = 0.
150 | 
151 |     def __call__(self, module, input=None, output=None):
152 |         w = module.weight
153 |         norm_diff = th.norm(w, 2, self.axis).sub(self.max_value)
154 |         value = self.scale * th.sum(norm_diff.gt(0).float().mul(norm_diff))
155 |         self.value += value
156 | 
157 | 
158 | class NonNegRegularizer(Regularizer):
159 |     """
160 |     Non-Negativity regularizer on Weights
161 | 
162 |     Penalizes the weights for having negative values
163 |     """
164 |     def __init__(self,
165 |                  scale=1e-3,
166 |                  module_filter='*'):
167 | 
168 |         self.scale = scale
169 |         self.module_filter = module_filter
170 |         self.value = 0.
171 | 
172 |     def reset(self):
173 |         self.value = 0.
174 | 
175 |     def __call__(self, module, input=None, output=None):
176 |         w = module.weight
177 |         value = -1 * self.scale * th.sum(w.lt(0).float().mul(w))
178 |         self.value += value
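The regularizers above do their work through forward hooks: `RegularizerContainer` attaches each one to every module matching its `module_filter`, a single forward pass accumulates the penalties, and `get_value()` returns the total to add to the loss. A rough sketch (the `output.sum()` task loss is just a stand-in for a real criterion):

```python
import torch as th
import torch.nn as nn
from torch.autograd import Variable
from torchsample.regularizers import RegularizerContainer, L1Regularizer

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

container = RegularizerContainer([L1Regularizer(scale=1e-4, module_filter='*')])
container.register_forward_hooks(model)    # hooks fire on every forward pass

output = model(Variable(th.randn(8, 10)))
task_loss = output.sum()                   # stand-in for a real criterion
loss = task_loss + container.get_value()   # penalty accumulated by the hooks
container.reset()                          # zero the accumulators for the next batch
```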
179 | 
180 | 
--------------------------------------------------------------------------------
/torchsample/samplers.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch as th
3 | import math
4 | 
5 | class Sampler(object):
6 |     """Base class for all Samplers.
7 | 
8 |     Every Sampler subclass has to provide an __iter__ method, providing a way
9 |     to iterate over indices of dataset elements, and a __len__ method that
10 |     returns the length of the returned iterators.
11 |     """
12 | 
13 |     def __init__(self, data_source):
14 |         pass
15 | 
16 |     def __iter__(self):
17 |         raise NotImplementedError
18 | 
19 |     def __len__(self):
20 |         raise NotImplementedError
21 | 
22 | class StratifiedSampler(Sampler):
23 |     """Stratified Sampling
24 | 
25 |     Provides equal representation of target classes in each batch
26 |     """
27 |     def __init__(self, class_vector, batch_size):
28 |         """
29 |         Arguments
30 |         ---------
31 |         class_vector : torch tensor
32 |             a vector of class labels
33 |         batch_size : integer
34 |             batch_size
35 |         """
36 |         self.n_splits = int(class_vector.size(0) / batch_size)
37 |         self.class_vector = class_vector
38 | 
39 |     def gen_sample_array(self):
40 |         try:
41 |             from sklearn.model_selection import StratifiedShuffleSplit
42 |         except ImportError:
43 |             raise ImportError('Need scikit-learn for this functionality')
44 |         import numpy as np
45 | 
46 |         s = StratifiedShuffleSplit(n_splits=self.n_splits, test_size=0.5)
47 |         X = th.randn(self.class_vector.size(0),2).numpy()
48 |         y = self.class_vector.numpy()
49 |         s.get_n_splits(X, y)
50 | 
51 |         train_index, test_index = next(s.split(X, y))
52 |         return np.hstack([train_index, test_index])
53 | 
54 |     def __iter__(self):
55 |         return iter(self.gen_sample_array())
56 | 
57 |     def __len__(self):
58 |         return len(self.class_vector)
59 | 
60 | class MultiSampler(Sampler):
61 |     """Samples elements more than once in a single pass through the data.
62 | 
63 |     This allows the number of samples per epoch to be larger than the number
64 |     of samples itself, which can be useful when training on 2D slices taken
65 |     from 3D images, for instance.
66 |     """
67 |     def __init__(self, nb_samples, desired_samples, shuffle=False):
68 |         """Initialize MultiSampler
69 | 
70 |         Arguments
71 |         ---------
72 |         nb_samples : number of samples in the dataset to sample from
73 | 
74 |         desired_samples : number of samples per pass you want
75 |             if this is not an even multiple of nb_samples, the
76 |             remainder will be randomly selected from the samples.
77 |             e.g. if nb_samples = 3 and desired_samples = 4, then
78 |             all 3 samples will be included and the last sample will be
79 |             randomly chosen from the 3 original samples.
80 | 81 | shuffle : boolean 82 | whether to shuffle the indices or not 83 | 84 | Example: 85 | >>> m = MultiSampler(2, 6) 86 | >>> x = m.gen_sample_array() 87 | >>> print(x) # [0,1,0,1,0,1] 88 | """ 89 | self.data_samples = nb_samples 90 | self.desired_samples = desired_samples 91 | self.shuffle = shuffle 92 | 93 | def gen_sample_array(self): 94 | from torchsample.utils import th_random_choice 95 | n_repeats = self.desired_samples / self.data_samples 96 | cat_list = [] 97 | for i in range(math.floor(n_repeats)): 98 | cat_list.append(th.arange(0,self.data_samples)) 99 | # add the left over samples 100 | left_over = self.desired_samples % self.data_samples 101 | if left_over > 0: 102 | cat_list.append(th_random_choice(self.data_samples, left_over)) 103 | self.sample_idx_array = th.cat(cat_list).long() 104 | return self.sample_idx_array 105 | 106 | def __iter__(self): 107 | return iter(self.gen_sample_array()) 108 | 109 | def __len__(self): 110 | return self.desired_samples 111 | 112 | 113 | class SequentialSampler(Sampler): 114 | """Samples elements sequentially, always in the same order. 115 | 116 | Arguments: 117 | data_source (Dataset): dataset to sample from 118 | """ 119 | 120 | def __init__(self, nb_samples): 121 | self.num_samples = nb_samples 122 | 123 | def __iter__(self): 124 | return iter(range(self.num_samples)) 125 | 126 | def __len__(self): 127 | return self.num_samples 128 | 129 | 130 | class RandomSampler(Sampler): 131 | """Samples elements randomly, without replacement. 132 | 133 | Arguments: 134 | data_source (Dataset): dataset to sample from 135 | """ 136 | 137 | def __init__(self, nb_samples): 138 | self.num_samples = nb_samples 139 | 140 | def __iter__(self): 141 | return iter(th.randperm(self.num_samples).long()) 142 | 143 | def __len__(self): 144 | return self.num_samples 145 | 146 | 147 | -------------------------------------------------------------------------------- /torchsample/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from .affine_transforms import * 5 | from .image_transforms import * 6 | from .tensor_transforms import * -------------------------------------------------------------------------------- /torchsample/transforms/distortion_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transforms to distort local or global information of an image 3 | """ 4 | 5 | 6 | import torch as th 7 | import numpy as np 8 | import random 9 | 10 | 11 | class Scramble(object): 12 | """ 13 | Create blocks of an image and scramble them 14 | """ 15 | def __init__(self, blocksize): 16 | self.blocksize = blocksize 17 | 18 | def __call__(self, *inputs): 19 | outputs = [] 20 | for idx, _input in enumerate(inputs): 21 | size = _input.size() 22 | img_height = size[1] 23 | img_width = size[2] 24 | 25 | x_blocks = int(img_height/self.blocksize) # number of x blocks 26 | y_blocks = int(img_width/self.blocksize) 27 | ind = th.randperm(x_blocks*y_blocks) 28 | 29 | new = th.zeros(_input.size()) 30 | count = 0 31 | for i in range(x_blocks): 32 | for j in range (y_blocks): 33 | row = int(ind[count] / x_blocks) 34 | column = ind[count] % x_blocks 35 | new[:, i*self.blocksize:(i+1)*self.blocksize, j*self.blocksize:(j+1)*self.blocksize] = \ 36 | _input[:, row*self.blocksize:(row+1)*self.blocksize, column*self.blocksize:(column+1)*self.blocksize] 37 | count += 1 38 | outputs.append(new) 39 | return outputs if idx > 1 else 
outputs[0] 40 | 41 | 42 | class RandomChoiceScramble(object): 43 | 44 | def __init__(self, blocksizes): 45 | self.blocksizes = blocksizes 46 | 47 | def __call__(self, *inputs): 48 | blocksize = random.choice(self.blocksizes) 49 | outputs = Scramble(blocksize=blocksize)(*inputs) 50 | return outputs 51 | 52 | 53 | def _blur_image(image, H): 54 | # break image up into its color components 55 | size = image.shape 56 | imr = image[0,:,:] 57 | img = image[1,:,:] 58 | imb = image[2,:,:] 59 | 60 | # compute Fourier transform and frequqnecy spectrum 61 | Fim1r = np.fft.fftshift(np.fft.fft2(imr)) 62 | Fim1g = np.fft.fftshift(np.fft.fft2(img)) 63 | Fim1b = np.fft.fftshift(np.fft.fft2(imb)) 64 | 65 | # Apply the lowpass filter to the Fourier spectrum of the image 66 | filtered_imager = np.multiply(H, Fim1r) 67 | filtered_imageg = np.multiply(H, Fim1g) 68 | filtered_imageb = np.multiply(H, Fim1b) 69 | 70 | newim = np.zeros(size) 71 | 72 | # convert the result to the spatial domain. 73 | newim[0,:,:] = np.absolute(np.real(np.fft.ifft2(filtered_imager))) 74 | newim[1,:,:] = np.absolute(np.real(np.fft.ifft2(filtered_imageg))) 75 | newim[2,:,:] = np.absolute(np.real(np.fft.ifft2(filtered_imageb))) 76 | 77 | return newim.astype('uint8') 78 | 79 | def _butterworth_filter(rows, cols, thresh, order): 80 | # X and Y matrices with ranges normalised to +/- 0.5 81 | array1 = np.ones(rows) 82 | array2 = np.ones(cols) 83 | array3 = np.arange(1,rows+1) 84 | array4 = np.arange(1,cols+1) 85 | 86 | x = np.outer(array1, array4) 87 | y = np.outer(array3, array2) 88 | 89 | x = x - float(cols/2) - 1 90 | y = y - float(rows/2) - 1 91 | 92 | x = x / cols 93 | y = y / rows 94 | 95 | radius = np.sqrt(np.square(x) + np.square(y)) 96 | 97 | matrix1 = radius/thresh 98 | matrix2 = np.power(matrix1, 2*order) 99 | f = np.reciprocal(1 + matrix2) 100 | 101 | return f 102 | 103 | 104 | class Blur(object): 105 | """ 106 | Blur an image with a Butterworth filter with a frequency 107 | cutoff matching local block size 108 | """ 109 | def __init__(self, threshold, order=5): 110 | """ 111 | scramble blocksize of 128 => filter threshold of 64 112 | scramble blocksize of 64 => filter threshold of 32 113 | scramble blocksize of 32 => filter threshold of 16 114 | scramble blocksize of 16 => filter threshold of 8 115 | scramble blocksize of 8 => filter threshold of 4 116 | """ 117 | self.threshold = threshold 118 | self.order = order 119 | 120 | def __call__(self, *inputs): 121 | """ 122 | inputs should have values between 0 and 255 123 | """ 124 | outputs = [] 125 | for idx, _input in enumerate(inputs): 126 | rows = _input.size(1) 127 | cols = _input.size(2) 128 | fc = self.threshold # threshold 129 | fs = 128.0 # max frequency 130 | n = self.order # filter order 131 | fc_rad = (fc/fs)*0.5 132 | H = _butterworth_filter(rows, cols, fc_rad, n) 133 | _input_blurred = _blur_image(_input.numpy().astype('uint8'), H) 134 | _input_blurred = th.from_numpy(_input_blurred).float() 135 | outputs.append(_input_blurred) 136 | 137 | return outputs if idx > 1 else outputs[0] 138 | 139 | 140 | class RandomChoiceBlur(object): 141 | 142 | def __init__(self, thresholds, order=5): 143 | """ 144 | thresholds = [64.0, 32.0, 16.0, 8.0, 4.0] 145 | """ 146 | self.thresholds = thresholds 147 | self.order = order 148 | 149 | def __call__(self, *inputs): 150 | threshold = random.choice(self.thresholds) 151 | outputs = Blur(threshold=threshold, order=self.order)(*inputs) 152 | return outputs 153 | 154 | 155 | 156 | 157 | 158 | 159 | 
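Note that the `transforms/__init__.py` shown above does not re-export this module, so these classes are imported from `distortion_transforms` directly. A quick sketch of applying them to a single `CHW` tensor:

```python
import torch as th
from torchsample.transforms.distortion_transforms import Scramble, RandomChoiceBlur

img = th.rand(3, 64, 64)                 # CHW image with values in [0, 1]

scrambled = Scramble(blocksize=16)(img)  # shuffle the image's 16x16 blocks

# Blur expects 0-255 intensities; thresholds follow the table in its docstring
blurred = RandomChoiceBlur([32.0, 16.0, 8.0], order=5)(img * 255)
```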
-------------------------------------------------------------------------------- /torchsample/transforms/image_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transforms very specific to images such as 3 | color, lighting, contrast, brightness, etc transforms 4 | 5 | NOTE: Most of these transforms assume your image intensity 6 | is between 0 and 1, and are torch tensors (NOT numpy or PIL) 7 | """ 8 | 9 | import random 10 | 11 | import torch as th 12 | 13 | from ..utils import th_random_choice 14 | 15 | 16 | def _blend(img1, img2, alpha): 17 | """ 18 | Weighted sum of two images 19 | 20 | Arguments 21 | --------- 22 | img1 : torch tensor 23 | img2 : torch tensor 24 | alpha : float between 0 and 1 25 | how much weight to put on img1 and 1-alpha weight 26 | to put on img2 27 | """ 28 | return img1.mul(alpha).add(1 - alpha, img2) 29 | 30 | 31 | class Grayscale(object): 32 | 33 | def __init__(self, keep_channels=False): 34 | """ 35 | Convert RGB image to grayscale 36 | 37 | Arguments 38 | --------- 39 | keep_channels : boolean 40 | If true, will keep all 3 channels and they will be the same 41 | If false, will just return 1 grayscale channel 42 | """ 43 | self.keep_channels = keep_channels 44 | if keep_channels: 45 | self.channels = 3 46 | else: 47 | self.channels = 1 48 | 49 | def __call__(self, *inputs): 50 | outputs = [] 51 | for idx, _input in enumerate(inputs): 52 | _input_dst = _input[0]*0.299 + _input[1]*0.587 + _input[2]*0.114 53 | _input_gs = _input_dst.repeat(self.channels,1,1) 54 | outputs.append(_input_gs) 55 | return outputs if idx > 1 else outputs[0] 56 | 57 | class RandomGrayscale(object): 58 | 59 | def __init__(self, p=0.5): 60 | """ 61 | Randomly convert RGB image(s) to Grayscale w/ some probability, 62 | NOTE: Always retains the 3 channels if image is grayscaled 63 | 64 | p : a float 65 | probability that image will be grayscaled 66 | """ 67 | self.p = p 68 | 69 | def __call__(self, *inputs): 70 | pval = random.random() 71 | if pval < self.p: 72 | outputs = Grayscale(keep_channels=True)(*inputs) 73 | else: 74 | outputs = inputs 75 | return outputs 76 | 77 | # ---------------------------------------------------- 78 | # ---------------------------------------------------- 79 | 80 | class Gamma(object): 81 | 82 | def __init__(self, value): 83 | """ 84 | Performs Gamma Correction on the input image. Also known as 85 | Power Law Transform. This function transforms the input image 86 | pixelwise according 87 | to the equation Out = In**gamma after scaling each 88 | pixel to the range 0 to 1. 89 | 90 | Arguments 91 | --------- 92 | value : float 93 | <1 : image will tend to be lighter 94 | =1 : image will stay the same 95 | >1 : image will tend to be darker 96 | """ 97 | self.value = value 98 | 99 | def __call__(self, *inputs): 100 | outputs = [] 101 | for idx, _input in enumerate(inputs): 102 | _input = th.pow(_input, self.value) 103 | outputs.append(_input) 104 | return outputs if idx > 1 else outputs[0] 105 | 106 | class RandomGamma(object): 107 | 108 | def __init__(self, min_val, max_val): 109 | """ 110 | Performs Gamma Correction on the input image with some 111 | randomly selected gamma value between min_val and max_val. 112 | Also known as Power Law Transform. This function transforms 113 | the input image pixelwise according to the equation 114 | Out = In**gamma after scaling each pixel to the range 0 to 1. 
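For concreteness, here is the `Out = In**gamma` mapping on a tensor already scaled to `[0, 1]`; values below 1 lighten and values above 1 darken:

```python
import torch as th
from torchsample.transforms import Gamma

x = th.Tensor([0.25, 0.5, 1.0])

print(Gamma(0.5)(x))   # 0.5000, 0.7071, 1.0000 -- gamma < 1 lightens
print(Gamma(2.0)(x))   # 0.0625, 0.2500, 1.0000 -- gamma > 1 darkens
```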
115 | 116 | Arguments 117 | --------- 118 | min_val : float 119 | min range 120 | max_val : float 121 | max range 122 | 123 | NOTE: 124 | for values: 125 | <1 : image will tend to be lighter 126 | =1 : image will stay the same 127 | >1 : image will tend to be darker 128 | """ 129 | self.values = (min_val, max_val) 130 | 131 | def __call__(self, *inputs): 132 | value = random.uniform(self.values[0], self.values[1]) 133 | outputs = Gamma(value)(*inputs) 134 | return outputs 135 | 136 | class RandomChoiceGamma(object): 137 | 138 | def __init__(self, values, p=None): 139 | """ 140 | Performs Gamma Correction on the input image with some 141 | gamma value selected in the list of given values. 142 | Also known as Power Law Transform. This function transforms 143 | the input image pixelwise according to the equation 144 | Out = In**gamma after scaling each pixel to the range 0 to 1. 145 | 146 | Arguments 147 | --------- 148 | values : list of floats 149 | gamma values to sampled from 150 | p : list of floats - same length as `values` 151 | if None, values will be sampled uniformly. 152 | Must sum to 1. 153 | 154 | NOTE: 155 | for values: 156 | <1 : image will tend to be lighter 157 | =1 : image will stay the same 158 | >1 : image will tend to be darker 159 | """ 160 | self.values = values 161 | self.p = p 162 | 163 | def __call__(self, *inputs): 164 | value = th_random_choice(self.values, p=self.p) 165 | outputs = Gamma(value)(*inputs) 166 | return outputs 167 | 168 | # ---------------------------------------------------- 169 | # ---------------------------------------------------- 170 | 171 | class Brightness(object): 172 | def __init__(self, value): 173 | """ 174 | Alter the Brightness of an image 175 | 176 | Arguments 177 | --------- 178 | value : brightness factor 179 | =-1 = completely black 180 | <0 = darker 181 | 0 = no change 182 | >0 = brighter 183 | =1 = completely white 184 | """ 185 | self.value = max(min(value,1.0),-1.0) 186 | 187 | def __call__(self, *inputs): 188 | outputs = [] 189 | for idx, _input in enumerate(inputs): 190 | _input = th.clamp(_input.float().add(self.value).type(_input.type()), 0, 1) 191 | outputs.append(_input) 192 | return outputs if idx > 1 else outputs[0] 193 | 194 | class RandomBrightness(object): 195 | 196 | def __init__(self, min_val, max_val): 197 | """ 198 | Alter the Brightness of an image with a value randomly selected 199 | between `min_val` and `max_val` 200 | 201 | Arguments 202 | --------- 203 | min_val : float 204 | min range 205 | max_val : float 206 | max range 207 | """ 208 | self.values = (min_val, max_val) 209 | 210 | def __call__(self, *inputs): 211 | value = random.uniform(self.values[0], self.values[1]) 212 | outputs = Brightness(value)(*inputs) 213 | return outputs 214 | 215 | class RandomChoiceBrightness(object): 216 | 217 | def __init__(self, values, p=None): 218 | """ 219 | Alter the Brightness of an image with a value randomly selected 220 | from the list of given values with given probabilities 221 | 222 | Arguments 223 | --------- 224 | values : list of floats 225 | brightness values to sampled from 226 | p : list of floats - same length as `values` 227 | if None, values will be sampled uniformly. 228 | Must sum to 1. 
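The `RandomChoice*` transforms all share this weighted-choice pattern; `p` biases which value is drawn on each call. A small sketch:

```python
import torch as th
from torchsample.transforms import RandomChoiceBrightness

# drawn per call: no change 40% of the time, mild shifts otherwise
tform = RandomChoiceBrightness(values=[-0.2, -0.1, 0.0, 0.1, 0.2],
                               p=[0.1, 0.2, 0.4, 0.2, 0.1])
x_adjusted = tform(th.rand(3, 32, 32))
```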
229 | """ 230 | self.values = values 231 | self.p = p 232 | 233 | def __call__(self, *inputs): 234 | value = th_random_choice(self.values, p=self.p) 235 | outputs = Brightness(value)(*inputs) 236 | return outputs 237 | 238 | # ---------------------------------------------------- 239 | # ---------------------------------------------------- 240 | 241 | class Saturation(object): 242 | 243 | def __init__(self, value): 244 | """ 245 | Alter the Saturation of image 246 | 247 | Arguments 248 | --------- 249 | value : float 250 | =-1 : gray 251 | <0 : colors are more muted 252 | =0 : image stays the same 253 | >0 : colors are more pure 254 | =1 : most saturated 255 | """ 256 | self.value = max(min(value,1.0),-1.0) 257 | 258 | def __call__(self, *inputs): 259 | outputs = [] 260 | for idx, _input in enumerate(inputs): 261 | _in_gs = Grayscale(keep_channels=True)(_input) 262 | alpha = 1.0 + self.value 263 | _in = th.clamp(_blend(_input, _in_gs, alpha), 0, 1) 264 | outputs.append(_in) 265 | return outputs if idx > 1 else outputs[0] 266 | 267 | class RandomSaturation(object): 268 | 269 | def __init__(self, min_val, max_val): 270 | """ 271 | Alter the Saturation of an image with a value randomly selected 272 | between `min_val` and `max_val` 273 | 274 | Arguments 275 | --------- 276 | min_val : float 277 | min range 278 | max_val : float 279 | max range 280 | """ 281 | self.values = (min_val, max_val) 282 | 283 | def __call__(self, *inputs): 284 | value = random.uniform(self.values[0], self.values[1]) 285 | outputs = Saturation(value)(*inputs) 286 | return outputs 287 | 288 | class RandomChoiceSaturation(object): 289 | 290 | def __init__(self, values, p=None): 291 | """ 292 | Alter the Saturation of an image with a value randomly selected 293 | from the list of given values with given probabilities 294 | 295 | Arguments 296 | --------- 297 | values : list of floats 298 | saturation values to sampled from 299 | p : list of floats - same length as `values` 300 | if None, values will be sampled uniformly. 301 | Must sum to 1. 302 | 303 | """ 304 | self.values = values 305 | self.p = p 306 | 307 | def __call__(self, *inputs): 308 | value = th_random_choice(self.values, p=self.p) 309 | outputs = Saturation(value)(*inputs) 310 | return outputs 311 | 312 | # ---------------------------------------------------- 313 | # ---------------------------------------------------- 314 | 315 | class Contrast(object): 316 | """ 317 | 318 | """ 319 | def __init__(self, value): 320 | """ 321 | Adjust Contrast of image. 322 | 323 | Contrast is adjusted independently for each channel of each image. 324 | 325 | For each channel, this Op computes the mean of the image pixels 326 | in the channel and then adjusts each component x of each pixel to 327 | (x - mean) * contrast_factor + mean. 
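Worked numerically on a single-channel 2x2 image whose channel mean is 0.5 (note this file relies on the old reduction semantics where `mean(dim)` keeps the reduced dimension):

```python
import torch as th
from torchsample.transforms import Contrast

x = th.Tensor([[[0.25, 0.75],
                [0.25, 0.75]]])   # channel mean = 0.5

# (x - 0.5) * value + 0.5, clamped to [0, 1]
print(Contrast(2.0)(x))           # [[0.000, 1.000], [0.000, 1.000]]
print(Contrast(0.5)(x))           # [[0.375, 0.625], [0.375, 0.625]]
```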
328 | 329 | Arguments 330 | --------- 331 | value : float 332 | smaller value: less contrast 333 | ZERO: channel means 334 | larger positive value: greater contrast 335 | larger negative value: greater inverse contrast 336 | """ 337 | self.value = value 338 | 339 | def __call__(self, *inputs): 340 | outputs = [] 341 | for idx, _input in enumerate(inputs): 342 | channel_means = _input.mean(1).mean(2) 343 | channel_means = channel_means.expand_as(_input) 344 | _input = th.clamp((_input - channel_means) * self.value + channel_means,0,1) 345 | outputs.append(_input) 346 | return outputs if idx > 1 else outputs[0] 347 | 348 | class RandomContrast(object): 349 | 350 | def __init__(self, min_val, max_val): 351 | """ 352 | Alter the Contrast of an image with a value randomly selected 353 | between `min_val` and `max_val` 354 | 355 | Arguments 356 | --------- 357 | min_val : float 358 | min range 359 | max_val : float 360 | max range 361 | """ 362 | self.values = (min_val, max_val) 363 | 364 | def __call__(self, *inputs): 365 | value = random.uniform(self.values[0], self.values[1]) 366 | outputs = Contrast(value)(*inputs) 367 | return outputs 368 | 369 | class RandomChoiceContrast(object): 370 | 371 | def __init__(self, values, p=None): 372 | """ 373 | Alter the Contrast of an image with a value randomly selected 374 | from the list of given values with given probabilities 375 | 376 | Arguments 377 | --------- 378 | values : list of floats 379 | contrast values to sampled from 380 | p : list of floats - same length as `values` 381 | if None, values will be sampled uniformly. 382 | Must sum to 1. 383 | 384 | """ 385 | self.values = values 386 | self.p = p 387 | 388 | def __call__(self, *inputs): 389 | value = th_random_choice(self.values, p=None) 390 | outputs = Contrast(value)(*inputs) 391 | return outputs 392 | 393 | # ---------------------------------------------------- 394 | # ---------------------------------------------------- 395 | 396 | def rgb_to_hsv(x): 397 | """ 398 | Convert from RGB to HSV 399 | """ 400 | hsv = th.zeros(*x.size()) 401 | c_min = x.min(0) 402 | c_max = x.max(0) 403 | 404 | delta = c_max[0] - c_min[0] 405 | 406 | # set H 407 | r_idx = c_max[1].eq(0) 408 | hsv[0][r_idx] = ((x[1][r_idx] - x[2][r_idx]) / delta[r_idx]) % 6 409 | g_idx = c_max[1].eq(1) 410 | hsv[0][g_idx] = 2 + ((x[2][g_idx] - x[0][g_idx]) / delta[g_idx]) 411 | b_idx = c_max[1].eq(2) 412 | hsv[0][b_idx] = 4 + ((x[0][b_idx] - x[1][b_idx]) / delta[b_idx]) 413 | hsv[0] = hsv[0].mul(60) 414 | 415 | # set S 416 | hsv[1] = delta / c_max[0] 417 | 418 | # set V - good 419 | hsv[2] = c_max[0] 420 | 421 | return hsv 422 | -------------------------------------------------------------------------------- /torchsample/transforms/tensor_transforms.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import math 5 | import numpy as np 6 | 7 | import torch as th 8 | from torch.autograd import Variable 9 | 10 | from ..utils import th_random_choice 11 | 12 | class Compose(object): 13 | """ 14 | Composes several transforms together. 
15 |     """
16 |     def __init__(self, transforms):
17 |         """
18 |         Composes (chains) several transforms together into
19 |         a single transform
20 | 
21 |         Arguments
22 |         ---------
23 |         transforms : a list of transforms
24 |             transforms will be applied sequentially
25 |         """
26 |         self.transforms = transforms
27 | 
28 |     def __call__(self, *inputs):
29 |         for transform in self.transforms:
30 |             if not isinstance(inputs, (list,tuple)):
31 |                 inputs = [inputs]
32 |             inputs = transform(*inputs)
33 |         return inputs
34 | 
35 | 
36 | class RandomChoiceCompose(object):
37 |     """
38 |     Randomly choose to apply one transform from a collection of transforms
39 | 
40 |     e.g. to randomly apply EITHER 0-1 or -1-1 normalization to an input:
41 |     >>> transform = RandomChoiceCompose([RangeNormalize(0,1),
42 |                                          RangeNormalize(-1,1)])
43 |     >>> x_norm = transform(x) # only one of the two normalizations is applied
44 |     """
45 |     def __init__(self, transforms):
46 |         self.transforms = transforms
47 | 
48 |     def __call__(self, *inputs):
49 |         tform = random.choice(self.transforms)
50 |         outputs = tform(*inputs)
51 |         return outputs
52 | 
53 | 
54 | class ToTensor(object):
55 |     """
56 |     Converts a numpy array to torch.Tensor
57 |     """
58 |     def __call__(self, *inputs):
59 |         outputs = []
60 |         for idx, _input in enumerate(inputs):
61 |             _input = th.from_numpy(_input)
62 |             outputs.append(_input)
63 |         return outputs if idx > 1 else outputs[0]
64 | 
65 | 
66 | class ToVariable(object):
67 |     """
68 |     Converts a torch.Tensor to autograd.Variable
69 |     """
70 |     def __call__(self, *inputs):
71 |         outputs = []
72 |         for idx, _input in enumerate(inputs):
73 |             _input = Variable(_input)
74 |             outputs.append(_input)
75 |         return outputs if idx > 1 else outputs[0]
76 | 
77 | 
78 | class ToCuda(object):
79 |     """
80 |     Moves an autograd.Variable to the GPU
81 |     """
82 |     def __init__(self, device=0):
83 |         """
84 |         Moves an autograd.Variable to the GPU
85 | 
86 |         Arguments
87 |         ---------
88 |         device : integer
89 |             which GPU device to put the input(s) on
90 |         """
91 |         self.device = device
92 | 
93 |     def __call__(self, *inputs):
94 |         outputs = []
95 |         for idx, _input in enumerate(inputs):
96 |             _input = _input.cuda(self.device)
97 |             outputs.append(_input)
98 |         return outputs if idx > 1 else outputs[0]
99 | 
100 | 
101 | class ToFile(object):
102 |     """
103 |     Saves an image to file. Useful as a pass-through transform
104 |     when wanting to observe how augmentation affects the data
105 | 
106 |     NOTE: Only supports saving to Numpy currently
107 |     """
108 |     def __init__(self, root):
109 |         """
110 |         Saves an image to file. Useful as a pass-through transform
111 |         when wanting to observe how augmentation affects the data
112 | 
113 |         NOTE: Only supports saving to Numpy currently
114 | 
115 |         Arguments
116 |         ---------
117 |         root : string
118 |             path to main directory in which images will be saved
119 |         """
120 |         if root.startswith('~'):
121 |             root = os.path.expanduser(root)
122 |         self.root = root
123 |         self.counter = 0
124 | 
125 |     def __call__(self, *inputs):
126 |         for idx, _input in enumerate(inputs):
127 |             fpath = os.path.join(self.root, 'img_%i_%i.npy'%(self.counter, idx))
128 |             np.save(fpath, _input.numpy())
129 |         self.counter += 1
130 |         return inputs
131 | 
132 | 
133 | class ChannelsLast(object):
134 |     """
135 |     Transposes a tensor so that the channel dim is last
136 |     `HWC` and `DHWC` are aliases for this transform.
137 |     """
138 |     def __init__(self, safe_check=False):
139 |         """
140 |         Transposes a tensor so that the channel dim is last
141 |         `HWC` and `DHWC` are aliases for this transform.
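Putting the conversion transforms together, a sketch of a typical preprocessing chain for an `HxW` numpy image (`TypeCast`, `AddChannel`, and `RangeNormalize` are all defined later in this file):

```python
import numpy as np
from torchsample.transforms import (Compose, ToTensor, TypeCast,
                                    AddChannel, RangeNormalize)

tform = Compose([ToTensor(),             # numpy array -> torch.LongTensor
                 TypeCast('float'),      # -> torch.FloatTensor
                 AddChannel(axis=0),     # (28, 28) -> (1, 28, 28)
                 RangeNormalize(0, 1)])  # map observed min/max to [0, 1]

x = tform(np.random.randint(0, 256, (28, 28)))
```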
133 | class ChannelsLast(object):
134 |     """
135 |     Transposes a tensor so that the channel dim is last
136 |     `HWC` and `DHWC` are aliases for this transform.
137 |     """
138 |     def __init__(self, safe_check=False):
139 |         """
140 |         Transposes a tensor so that the channel dim is last
141 |         `HWC` and `DHWC` are aliases for this transform.
142 | 
143 |         Arguments
144 |         ---------
145 |         safe_check : boolean
146 |             if true, will check if channels are already last and, if so,
147 |             will just return the inputs
148 |         """
149 |         self.safe_check = safe_check
150 | 
151 |     def __call__(self, *inputs):
152 |         ndim = inputs[0].dim()
153 |         if self.safe_check:
154 |             # check if channels are already last
155 |             if inputs[0].size(-1) < inputs[0].size(0):
156 |                 return inputs
157 |         plist = list(range(1, ndim)) + [0]
158 | 
159 |         outputs = []
160 |         for idx, _input in enumerate(inputs):
161 |             _input = _input.permute(*plist)
162 |             outputs.append(_input)
163 |         return outputs if len(outputs) > 1 else outputs[0]
164 | 
165 | HWC = ChannelsLast
166 | DHWC = ChannelsLast
167 | 
168 | class ChannelsFirst(object):
169 |     """
170 |     Transposes a tensor so that the channel dim is first.
171 |     `CHW` and `CDHW` are aliases for this transform.
172 |     """
173 |     def __init__(self, safe_check=False):
174 |         """
175 |         Transposes a tensor so that the channel dim is first.
176 |         `CHW` and `CDHW` are aliases for this transform.
177 | 
178 |         Arguments
179 |         ---------
180 |         safe_check : boolean
181 |             if true, will check if channels are already first and, if so,
182 |             will just return the inputs
183 |         """
184 |         self.safe_check = safe_check
185 | 
186 |     def __call__(self, *inputs):
187 |         ndim = inputs[0].dim()
188 |         if self.safe_check:
189 |             # check if channels are already first
190 |             if inputs[0].size(0) < inputs[0].size(-1):
191 |                 return inputs
192 |         plist = [ndim-1] + list(range(0, ndim-1))
193 | 
194 |         outputs = []
195 |         for idx, _input in enumerate(inputs):
196 |             _input = _input.permute(*plist)
197 |             outputs.append(_input)
198 |         return outputs if len(outputs) > 1 else outputs[0]
199 | 
200 | CHW = ChannelsFirst
201 | CDHW = ChannelsFirst
202 | 
203 | class TypeCast(object):
204 |     """
205 |     Cast a torch.Tensor to a different type
206 |     """
207 |     def __init__(self, dtype='float'):
208 |         """
209 |         Cast a torch.Tensor to a different type
210 | 
211 |         Arguments
212 |         ---------
213 |         dtype : string or torch.*Tensor literal or list of such
214 |             data type to which input(s) will be cast.
215 |             If list, it should be the same length as inputs.
216 |         """
217 |         if isinstance(dtype, (list, tuple)):
218 |             dtypes = []
219 |             for dt in dtype:
220 |                 if isinstance(dt, str):
221 |                     if dt == 'byte':
222 |                         dt = th.ByteTensor
223 |                     elif dt == 'double':
224 |                         dt = th.DoubleTensor
225 |                     elif dt == 'float':
226 |                         dt = th.FloatTensor
227 |                     elif dt == 'int':
228 |                         dt = th.IntTensor
229 |                     elif dt == 'long':
230 |                         dt = th.LongTensor
231 |                     elif dt == 'short':
232 |                         dt = th.ShortTensor
233 |                 dtypes.append(dt)
234 |             self.dtype = dtypes
235 |         else:
236 |             if isinstance(dtype, str):
237 |                 if dtype == 'byte':
238 |                     dtype = th.ByteTensor
239 |                 elif dtype == 'double':
240 |                     dtype = th.DoubleTensor
241 |                 elif dtype == 'float':
242 |                     dtype = th.FloatTensor
243 |                 elif dtype == 'int':
244 |                     dtype = th.IntTensor
245 |                 elif dtype == 'long':
246 |                     dtype = th.LongTensor
247 |                 elif dtype == 'short':
248 |                     dtype = th.ShortTensor
249 |             self.dtype = dtype
250 | 
251 |     def __call__(self, *inputs):
252 |         if not isinstance(self.dtype, (tuple, list)):
253 |             dtypes = [self.dtype] * len(inputs)
254 |         else:
255 |             dtypes = self.dtype
256 | 
257 |         outputs = []
258 |         for idx, _input in enumerate(inputs):
259 |             _input = _input.type(dtypes[idx])
260 |             outputs.append(_input)
261 |         return outputs if len(outputs) > 1 else outputs[0]
262 | 
263 | 
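# --- editor's sketch: casting an input/target pair to different dtypes in one
# --- call; an illustrative addition to the dump, not part of the original file.
#   >>> tc = TypeCast(['float', 'long'])
#   >>> x, y = tc(th.rand(3, 4), th.Tensor([0, 1, 2]))
#   >>> # x is now a FloatTensor and y a LongTensor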
264 | class AddChannel(object):
265 |     """
266 |     Adds a dummy channel to an image.
267 |     This turns an image of size (28, 28) into one
268 |     of size (1, 28, 28), for example.
269 |     """
270 |     def __init__(self, axis=0):
271 |         """
272 |         Adds a dummy channel to an image, also known as
273 |         expanding an axis or unsqueezing a dim
274 | 
275 |         Arguments
276 |         ---------
277 |         axis : integer
278 |             dimension to be expanded to singleton
279 |         """
280 |         self.axis = axis
281 | 
282 |     def __call__(self, *inputs):
283 |         outputs = []
284 |         for idx, _input in enumerate(inputs):
285 |             _input = _input.unsqueeze(self.axis)
286 |             outputs.append(_input)
287 |         return outputs if len(outputs) > 1 else outputs[0]
288 | 
289 | ExpandAxis = AddChannel
290 | Unsqueeze = AddChannel
291 | 
292 | class Transpose(object):
293 | 
294 |     def __init__(self, dim1, dim2):
295 |         """
296 |         Swaps two dimensions of a tensor
297 | 
298 |         Arguments
299 |         ---------
300 |         dim1 : integer
301 |             first dim to switch
302 |         dim2 : integer
303 |             second dim to switch
304 |         """
305 |         self.dim1 = dim1
306 |         self.dim2 = dim2
307 | 
308 |     def __call__(self, *inputs):
309 |         outputs = []
310 |         for idx, _input in enumerate(inputs):
311 |             _input = th.transpose(_input, self.dim1, self.dim2)
312 |             outputs.append(_input)
313 |         return outputs if len(outputs) > 1 else outputs[0]
314 | 
315 | 
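# --- editor's sketch: promoting a grayscale (H, W) tensor to (1, H, W) and
# --- swapping two dims; an illustrative addition, not part of the original file.
#   >>> x = th.rand(28, 14)
#   >>> x = AddChannel(axis=0)(x)   # -> (1, 28, 14)
#   >>> x = Transpose(1, 2)(x)      # -> (1, 14, 28)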
316 | class RangeNormalize(object):
317 |     """
318 |     Given min_val and max_val, will rescale a th.*Tensor so
319 |     that its observed minimum maps to min_val and its
320 |     observed maximum maps to max_val.
321 | 
322 |     Works by calculating:
323 |         a = (max'-min')/(max-min)
324 |         b = max' - a * max
325 |         new_value = a * value + b
326 |     where min' & max' are the given values,
327 |     and min & max are the observed min/max of each input
328 | 
329 |     Arguments
330 |     ---------
331 |     min_val : float or integer
332 |         Min value to which tensors will be normalized
333 |     max_val : float or integer
334 |         Max value to which tensors will be normalized
335 | 
336 |     NOTE: the observed min/max are computed over the whole
337 |     tensor (not per channel) for each individual sample at
338 |     call time. If every sample has a known range (e.g. 0 to
339 |     255 for uint8 images), rescaling with a fixed formula
340 |     instead avoids the per-sample min/max computation and
341 |     is therefore faster.
342 | 
343 |     Example:
344 |     >>> x = th.rand(3,5,5)
345 |     >>> rn = RangeNormalize(0,1)
346 |     >>> x_norm = rn(x)
347 |     >>> # x_norm.min() == 0 and x_norm.max() == 1
348 | 
349 |     Normalizing to [-1, 1] instead:
350 |     >>> x = th.rand(3,5,5)
351 |     >>> rn = RangeNormalize(-1,1)
352 |     >>> x_norm = rn(x)
353 | 
354 |     """
355 |     def __init__(self,
356 |                  min_val,
357 |                  max_val):
358 |         """
359 |         Normalize a tensor between a min and max value
360 | 
361 |         Arguments
362 |         ---------
363 |         min_val : float
364 |             lower bound of normalized tensor
365 |         max_val : float
366 |             upper bound of normalized tensor
367 |         """
368 |         self.min_val = min_val
369 |         self.max_val = max_val
370 | 
371 |     def __call__(self, *inputs):
372 |         outputs = []
373 |         for idx, _input in enumerate(inputs):
374 |             _min_val = _input.min()
375 |             _max_val = _input.max()
376 |             a = (self.max_val - self.min_val) / (_max_val - _min_val)
377 |             b = self.max_val - a * _max_val
378 |             _input = _input.mul(a).add(b)
379 |             outputs.append(_input)
380 |         return outputs if len(outputs) > 1 else outputs[0]
381 | 
382 | 
383 | class StdNormalize(object):
384 |     """
385 |     Normalize torch tensor to have zero mean and unit std deviation
386 |     """
387 |     def __call__(self, *inputs):
388 |         outputs = []
389 |         for idx, _input in enumerate(inputs):
390 |             _input = _input.sub(_input.mean()).div(_input.std())
391 |             outputs.append(_input)
392 |         return outputs if len(outputs) > 1 else outputs[0]
393 | 
394 | 
395 | class Slice2D(object):
396 | 
397 |     def __init__(self, axis=0, reject_zeros=False):
398 |         """
399 |         Take a random 2D slice from a 3D image along
400 |         a given axis. The image should not have a 4th channel dim.
401 | 
402 |         Arguments
403 |         ---------
404 |         axis : integer in {0, 1, 2}
405 |             the axis on which to take slices
406 | 
407 |         reject_zeros : boolean
408 |             whether to reject slices that are all zeros
409 |         """
410 |         self.axis = axis
411 |         self.reject_zeros = reject_zeros
412 | 
413 |     def __call__(self, x, y=None):
414 |         while True:
415 |             keep_slice = random.randint(0, x.size(self.axis)-1)
416 |             if self.axis == 0:
417 |                 slice_x = x[keep_slice,:,:]
418 |                 if y is not None:
419 |                     slice_y = y[keep_slice,:,:]
420 |             elif self.axis == 1:
421 |                 slice_x = x[:,keep_slice,:]
422 |                 if y is not None:
423 |                     slice_y = y[:,keep_slice,:]
424 |             elif self.axis == 2:
425 |                 slice_x = x[:,:,keep_slice]
426 |                 if y is not None:
427 |                     slice_y = y[:,:,keep_slice]
428 | 
429 |             if not self.reject_zeros:
430 |                 break
431 |             else:
432 |                 if y is not None and th.sum(slice_y) > 0:
433 |                     break
434 |                 elif th.sum(slice_x) > 0:
435 |                     break
436 |         if y is not None:
437 |             return slice_x, slice_y
438 |         else:
439 |             return slice_x
440 | 
441 | 
442 | class RandomCrop(object):
443 | 
444 |     def __init__(self, size):
445 |         """
446 |         Randomly crop a torch tensor
447 | 
448 |         Arguments
449 |         ---------
450 |         size : tuple or list
451 |             dimensions of the crop
452 |         """
453 |         self.size = size
454 | 
455 |     def __call__(self, *inputs):
456 |         h_idx = random.randint(0, inputs[0].size(1)-self.size[0])
457 |         w_idx = random.randint(0, inputs[0].size(2)-self.size[1])
458 |         outputs = []
459 |         for idx, _input in enumerate(inputs):
460 |             _input = _input[:, h_idx:(h_idx+self.size[0]), w_idx:(w_idx+self.size[1])]
461 |             outputs.append(_input)
462 |         return outputs if len(outputs) > 1 else outputs[0]
463 | 
464 | 
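# --- editor's sketch: cropping an image and its mask with identical offsets;
# --- an illustrative addition to the dump, not part of the original file.
#   >>> img, mask = th.rand(3, 28, 28), th.rand(1, 28, 28)
#   >>> img_c, mask_c = RandomCrop((20, 20))(img, mask)
#   >>> # both outputs have spatial size (20, 20), cut at the same location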
465 | class SpecialCrop(object):
466 | 
467 |     def __init__(self, size, crop_type=0):
468 |         """
469 |         Perform a special crop - one of the four corners or center crop
470 | 
471 |         Arguments
472 |         ---------
473 |         size : tuple or list
474 |             dimensions of the crop
475 | 
476 |         crop_type : integer in {0,1,2,3,4}
477 |             0 = center crop
478 |             1 = top left crop
479 |             2 = top right crop
480 |             3 = bottom right crop
481 |             4 = bottom left crop
482 |         """
483 |         if crop_type not in {0, 1, 2, 3, 4}:
484 |             raise ValueError('crop_type must be in {0, 1, 2, 3, 4}')
485 |         self.size = size
486 |         self.crop_type = crop_type
487 | 
488 |     def __call__(self, x, y=None):
489 |         if self.crop_type == 0:
490 |             # center crop
491 |             x_diff = (x.size(1)-self.size[0])/2.
492 |             y_diff = (x.size(2)-self.size[1])/2.
493 |             ct_x = [int(math.ceil(x_diff)), x.size(1)-int(math.floor(x_diff))]
494 |             ct_y = [int(math.ceil(y_diff)), x.size(2)-int(math.floor(y_diff))]
495 |             indices = [ct_x, ct_y]
496 |         elif self.crop_type == 1:
497 |             # top left crop
498 |             tl_x = [0, self.size[0]]
499 |             tl_y = [0, self.size[1]]
500 |             indices = [tl_x, tl_y]
501 |         elif self.crop_type == 2:
502 |             # top right crop
503 |             tr_x = [0, self.size[0]]
504 |             tr_y = [x.size(2)-self.size[1], x.size(2)]
505 |             indices = [tr_x, tr_y]
506 |         elif self.crop_type == 3:
507 |             # bottom right crop
508 |             br_x = [x.size(1)-self.size[0], x.size(1)]
509 |             br_y = [x.size(2)-self.size[1], x.size(2)]
510 |             indices = [br_x, br_y]
511 |         elif self.crop_type == 4:
512 |             # bottom left crop
513 |             bl_x = [x.size(1)-self.size[0], x.size(1)]
514 |             bl_y = [0, self.size[1]]
515 |             indices = [bl_x, bl_y]
516 | 
517 |         x = x[:, indices[0][0]:indices[0][1], indices[1][0]:indices[1][1]]
518 | 
519 |         if y is not None:
520 |             y = y[:, indices[0][0]:indices[0][1], indices[1][0]:indices[1][1]]
521 |             return x, y
522 |         else:
523 |             return x
524 | 
525 | 
526 | class Pad(object):
527 | 
528 |     def __init__(self, size):
529 |         """
530 |         Pads an image to the given size
531 | 
532 |         Arguments
533 |         ---------
534 |         size : tuple or list
535 |             size to which the image will be padded
536 |         """
537 |         self.size = size
538 | 
539 |     def __call__(self, x, y=None):
540 |         x = x.numpy()
541 |         shape_diffs = [int(np.ceil(i_s - d_s)) for d_s, i_s in zip(x.shape, self.size)]
542 |         shape_diffs = np.maximum(shape_diffs, 0)
543 |         pad_sizes = [(int(np.ceil(s/2.)), int(np.floor(s/2.))) for s in shape_diffs]
544 |         x = np.pad(x, pad_sizes, mode='constant')
545 |         if y is not None:
546 |             y = y.numpy()
547 |             y = np.pad(y, pad_sizes, mode='constant')
548 |             return th.from_numpy(x), th.from_numpy(y)
549 |         else:
550 |             return th.from_numpy(x)
551 | 
552 | 
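# --- editor's sketch: zero-padding a (C, H, W) tensor up to a target size; an
# --- illustrative addition, not part of the original file. Note `size` must
# --- include the channel dim, and dims already large enough are left unchanged.
#   >>> x = th.rand(3, 28, 28)
#   >>> x_pad = Pad((3, 32, 32))(x)   # -> (3, 32, 32)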
553 | class RandomFlip(object):
554 | 
555 |     def __init__(self, h=True, v=False, p=0.5):
556 |         """
557 |         Randomly flip an image horizontally and/or vertically with
558 |         some probability.
559 | 
560 |         Arguments
561 |         ---------
562 |         h : boolean
563 |             whether to horizontally flip w/ probability p
564 | 
565 |         v : boolean
566 |             whether to vertically flip w/ probability p
567 | 
568 |         p : float between [0,1]
569 |             probability with which to apply allowed flipping operations
570 |         """
571 |         self.horizontal = h
572 |         self.vertical = v
573 |         self.p = p
574 | 
575 |     def __call__(self, x, y=None):
576 |         x = x.numpy()
577 |         if y is not None:
578 |             y = y.numpy()
579 |         # horizontal flip with p = self.p
580 |         if self.horizontal:
581 |             if random.random() < self.p:
582 |                 x = x.swapaxes(2, 0)
583 |                 x = x[::-1, ...]
584 |                 x = x.swapaxes(0, 2)
585 |                 if y is not None:
586 |                     y = y.swapaxes(2, 0)
587 |                     y = y[::-1, ...]
588 |                     y = y.swapaxes(0, 2)
589 |         # vertical flip with p = self.p
590 |         if self.vertical:
591 |             if random.random() < self.p:
592 |                 x = x.swapaxes(1, 0)
593 |                 x = x[::-1, ...]
594 |                 x = x.swapaxes(0, 1)
595 |                 if y is not None:
596 |                     y = y.swapaxes(1, 0)
597 |                     y = y[::-1, ...]
598 |                     y = y.swapaxes(0, 1)
599 |         if y is None:
600 |             # must copy because torch doesn't currently support negative strides
601 |             return th.from_numpy(x.copy())
602 |         else:
603 |             return th.from_numpy(x.copy()), th.from_numpy(y.copy())
604 | 
605 | 
606 | class RandomOrder(object):
607 |     """
608 |     Randomly permute the channels of an image
609 |     """
610 |     def __call__(self, *inputs):
611 |         order = th.randperm(inputs[0].size(0))
612 |         outputs = []
613 |         for idx, _input in enumerate(inputs):
614 |             _input = _input.index_select(0, order)
615 |             outputs.append(_input)
616 |         return outputs if len(outputs) > 1 else outputs[0]
617 | 
618 | 
--------------------------------------------------------------------------------
/torchsample/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for th.Tensors
3 | """
4 | 
5 | import pickle
6 | import random
7 | import numpy as np
8 | 
9 | import torch as th
10 | 
11 | 
12 | def th_allclose(x, y):
13 |     """
14 |     Determine whether two torch tensors have the same values.
15 |     Roughly mimics np.allclose via a summed absolute difference.
16 |     """
17 |     return th.sum(th.abs(x-y)) < 1e-5
18 | 
19 | 
20 | def th_flatten(x):
21 |     """Flatten tensor"""
22 |     return x.contiguous().view(-1)
23 | 
24 | def th_c_flatten(x):
25 |     """
26 |     Flatten tensor, leaving channel intact.
27 |     Assumes CHW format.
28 |     """
29 |     return x.contiguous().view(x.size(0), -1)
30 | 
31 | def th_bc_flatten(x):
32 |     """
33 |     Flatten tensor, leaving batch and channel dims intact.
34 |     Assumes BCHW format
35 |     """
36 |     return x.contiguous().view(x.size(0), x.size(1), -1)
37 | 
38 | 
39 | def th_zeros_like(x):
40 |     return x.new().resize_as_(x).zero_()
41 | 
42 | def th_ones_like(x):
43 |     return x.new().resize_as_(x).fill_(1)
44 | 
45 | def th_constant_like(x, val):
46 |     return x.new().resize_as_(x).fill_(val)
47 | 
48 | 
49 | def th_iterproduct(*args):
50 |     return th.from_numpy(np.indices(args).reshape((len(args),-1)).T)
51 | 
52 | def th_iterproduct_like(x):
53 |     return th_iterproduct(*x.size())
54 | 
55 | 
56 | def th_uniform(lower, upper):
57 |     return random.uniform(lower, upper)
58 | 
59 | 
60 | def th_gather_nd(x, coords):
61 |     x = x.contiguous()
62 |     inds = coords.mv(th.LongTensor(x.stride()))
63 |     x_gather = th.index_select(th_flatten(x), 0, inds)
64 |     return x_gather
65 | 
66 | 
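# --- editor's sketch: th_iterproduct builds the full N-D index grid that the
# --- affine transforms below use as a coordinate mesh; an illustrative
# --- addition, not part of the original file.
#   >>> th_iterproduct(2, 3)
#   ... # LongTensor with rows (0,0),(0,1),(0,2),(1,0),(1,1),(1,2)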
67 | def th_affine2d(x, matrix, mode='bilinear', center=True):
68 |     """
69 |     2D Affine image transform on th.Tensor
70 | 
71 |     Arguments
72 |     ---------
73 |     x : th.Tensor of size (C, H, W)
74 |         image tensor to be transformed
75 | 
76 |     matrix : th.Tensor of size (3, 3) or (2, 3)
77 |         transformation matrix
78 | 
79 |     mode : string in {'nearest', 'bilinear'}
80 |         interpolation scheme to use
81 | 
82 |     center : boolean
83 |         whether to alter the bias of the transform
84 |         so the transform is applied about the center
85 |         of the image rather than the origin
86 | 
87 |     Example
88 |     -------
89 |     >>> import torch
90 |     >>> from torchsample.utils import *
91 |     >>> x = th.zeros(2,1000,1000)
92 |     >>> x[:,100:500,100:500] = 10
93 |     >>> matrix = th.FloatTensor([[1.,0,-50],
94 |     ...                          [0,1.,-50]])
95 |     >>> xn = th_affine2d(x, matrix, mode='nearest')
96 |     >>> xb = th_affine2d(x, matrix, mode='bilinear')
97 |     """
98 | 
99 |     if matrix.dim() == 2:
100 |         matrix = matrix[:2,:]
101 |         matrix = matrix.unsqueeze(0)
102 |     elif matrix.dim() == 3:
103 |         if matrix.size()[1:] == (3,3):
104 |             matrix = matrix[:,:2,:]
105 | 
106 |     A_batch = matrix[:,:,:2]
107 |     if A_batch.size(0) != x.size(0):
108 |         A_batch = A_batch.repeat(x.size(0),1,1)
109 |     b_batch = matrix[:,:,2].unsqueeze(1)
110 | 
111 |     # make a meshgrid of normal coordinates
112 |     _coords = th_iterproduct(x.size(1),x.size(2))
113 |     coords = _coords.unsqueeze(0).repeat(x.size(0),1,1).float()
114 | 
115 |     if center:
116 |         # shift the coordinates so center is the origin
117 |         coords[:,:,0] = coords[:,:,0] - (x.size(1) / 2. - 0.5)
118 |         coords[:,:,1] = coords[:,:,1] - (x.size(2) / 2. - 0.5)
119 |     # apply the coordinate transformation
120 |     new_coords = coords.bmm(A_batch.transpose(1,2)) + b_batch.expand_as(coords)
121 | 
122 |     if center:
123 |         # shift the coordinates back so origin is origin
124 |         new_coords[:,:,0] = new_coords[:,:,0] + (x.size(1) / 2. - 0.5)
125 |         new_coords[:,:,1] = new_coords[:,:,1] + (x.size(2) / 2. - 0.5)
126 | 
127 |     # map new coordinates with the requested interpolation scheme
128 |     if mode == 'nearest':
129 |         x_transformed = th_nearest_interp2d(x.contiguous(), new_coords)
130 |     elif mode == 'bilinear':
131 |         x_transformed = th_bilinear_interp2d(x.contiguous(), new_coords)
132 | 
133 |     return x_transformed
134 | 
135 | 
136 | def th_nearest_interp2d(input, coords):
137 |     """
138 |     2D nearest neighbor interpolation on a th.Tensor
139 |     """
140 |     # clamp the coords so they're within the image bounds
141 |     x = th.clamp(coords[:,:,0], 0, input.size(1)-1).round()
142 |     y = th.clamp(coords[:,:,1], 0, input.size(2)-1).round()
143 | 
144 |     stride = th.LongTensor(input.stride())
145 |     x_ix = x.mul(stride[1]).long()
146 |     y_ix = y.mul(stride[2]).long()
147 | 
148 |     input_flat = input.view(input.size(0),-1)
149 | 
150 |     mapped_vals = input_flat.gather(1, x_ix.add(y_ix))
151 | 
152 |     return mapped_vals.view_as(input)
153 | 
154 | 
155 | def th_bilinear_interp2d(input, coords):
156 |     """
157 |     Bilinear interpolation in 2D on a th.Tensor
158 |     """
159 |     x = th.clamp(coords[:,:,0], 0, input.size(1)-2)
160 |     x0 = x.floor()
161 |     x1 = x0 + 1
162 |     y = th.clamp(coords[:,:,1], 0, input.size(2)-2)
163 |     y0 = y.floor()
164 |     y1 = y0 + 1
165 | 
166 |     stride = th.LongTensor(input.stride())
167 |     x0_ix = x0.mul(stride[1]).long()
168 |     x1_ix = x1.mul(stride[1]).long()
169 |     y0_ix = y0.mul(stride[2]).long()
170 |     y1_ix = y1.mul(stride[2]).long()
171 | 
172 |     input_flat = input.view(input.size(0),-1)
173 | 
174 |     vals_00 = input_flat.gather(1, x0_ix.add(y0_ix))
175 |     vals_10 = input_flat.gather(1, x1_ix.add(y0_ix))
176 |     vals_01 = input_flat.gather(1, x0_ix.add(y1_ix))
177 |     vals_11 = input_flat.gather(1, x1_ix.add(y1_ix))
178 | 
179 |     xd = x - x0
180 |     yd = y - y0
181 |     xm = 1 - xd
182 |     ym = 1 - yd
183 | 
184 |     x_mapped = (vals_00.mul(xm).mul(ym) +
185 |                 vals_10.mul(xd).mul(ym) +
186 |                 vals_01.mul(xm).mul(yd) +
187 |                 vals_11.mul(xd).mul(yd))
188 | 
189 |     return x_mapped.view_as(input)
190 | 
191 | 
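# --- editor's sketch: rotating an image about its center with th_affine2d; an
# --- illustrative addition to the dump, not part of the original file.
#   >>> import math
#   >>> theta = math.pi / 8.
#   >>> rot = th.FloatTensor([[math.cos(theta), -math.sin(theta), 0],
#   ...                       [math.sin(theta),  math.cos(theta), 0]])
#   >>> x_rot = th_affine2d(th.rand(3, 64, 64), rot, mode='bilinear')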
192 | def th_affine3d(x, matrix, mode='trilinear', center=True):
193 |     """
194 |     3D Affine image transform on a th.Tensor of size (C, D, H, W)
195 |     """
196 |     A = matrix[:3,:3]
197 |     b = matrix[:3,3]
198 | 
199 |     # make a meshgrid of normal coordinates
200 |     coords = th_iterproduct(x.size(1),x.size(2),x.size(3)).float()
201 | 
202 |     if center:
203 |         # shift the coordinates so center is the origin
204 |         coords[:,0] = coords[:,0] - (x.size(1) / 2. - 0.5)
205 |         coords[:,1] = coords[:,1] - (x.size(2) / 2. - 0.5)
206 |         coords[:,2] = coords[:,2] - (x.size(3) / 2. - 0.5)
207 | 
208 |     # apply the coordinate transformation
209 |     new_coords = coords.mm(A.t().contiguous()) + b.expand_as(coords)
210 | 
211 |     if center:
212 |         # shift the coordinates back so origin is origin
213 |         new_coords[:,0] = new_coords[:,0] + (x.size(1) / 2. - 0.5)
214 |         new_coords[:,1] = new_coords[:,1] + (x.size(2) / 2. - 0.5)
215 |         new_coords[:,2] = new_coords[:,2] + (x.size(3) / 2. - 0.5)
216 | 
217 |     # map new coordinates using trilinear (default) or nearest interpolation
218 |     if mode == 'nearest':
219 |         x_transformed = th_nearest_interp3d(x, new_coords)
220 |     elif mode == 'trilinear':
221 |         x_transformed = th_trilinear_interp3d(x, new_coords)
222 |     else:
223 |         x_transformed = th_trilinear_interp3d(x, new_coords)
224 | 
225 |     return x_transformed
226 | 
227 | 
228 | def th_nearest_interp3d(input, coords):
229 |     """
230 |     3D nearest neighbor interpolation on a th.Tensor
231 |     """
232 |     # clamp the coords so they're within the image bounds
233 |     coords[:,0] = th.clamp(coords[:,0], 0, input.size(1)-1).round()
234 |     coords[:,1] = th.clamp(coords[:,1], 0, input.size(2)-1).round()
235 |     coords[:,2] = th.clamp(coords[:,2], 0, input.size(3)-1).round()
236 | 
237 |     stride = th.LongTensor(input.stride())[1:].float()
238 |     idx = coords.mv(stride).long()
239 | 
240 |     input_flat = th_flatten(input)
241 | 
242 |     mapped_vals = input_flat[idx]
243 | 
244 |     return mapped_vals.view_as(input)
245 | 
246 | 
247 | def th_trilinear_interp3d(input, coords):
248 |     """
249 |     Trilinear interpolation of a 3D th.Tensor image
250 |     """
251 |     # take clamp then floor/ceil of x coords
252 |     x = th.clamp(coords[:,0], 0, input.size(1)-2)
253 |     x0 = x.floor()
254 |     x1 = x0 + 1
255 |     # take clamp then floor/ceil of y coords
256 |     y = th.clamp(coords[:,1], 0, input.size(2)-2)
257 |     y0 = y.floor()
258 |     y1 = y0 + 1
259 |     # take clamp then floor/ceil of z coords
260 |     z = th.clamp(coords[:,2], 0, input.size(3)-2)
261 |     z0 = z.floor()
262 |     z1 = z0 + 1
263 | 
264 |     stride = th.LongTensor(input.stride())[1:]
265 |     x0_ix = x0.mul(stride[0]).long()
266 |     x1_ix = x1.mul(stride[0]).long()
267 |     y0_ix = y0.mul(stride[1]).long()
268 |     y1_ix = y1.mul(stride[1]).long()
269 |     z0_ix = z0.mul(stride[2]).long()
270 |     z1_ix = z1.mul(stride[2]).long()
271 | 
272 |     input_flat = th_flatten(input)
273 | 
274 |     vals_000 = input_flat[x0_ix+y0_ix+z0_ix]
275 |     vals_100 = input_flat[x1_ix+y0_ix+z0_ix]
276 |     vals_010 = input_flat[x0_ix+y1_ix+z0_ix]
277 |     vals_001 = input_flat[x0_ix+y0_ix+z1_ix]
278 |     vals_101 = input_flat[x1_ix+y0_ix+z1_ix]
279 |     vals_011 = input_flat[x0_ix+y1_ix+z1_ix]
280 |     vals_110 = input_flat[x1_ix+y1_ix+z0_ix]
281 |     vals_111 = input_flat[x1_ix+y1_ix+z1_ix]
282 | 
283 |     xd = x - x0
284 |     yd = y - y0
285 |     zd = z - z0
286 |     xm1 = 1 - xd
287 |     ym1 = 1 - yd
288 |     zm1 = 1 - zd
289 | 
290 |     x_mapped = (vals_000.mul(xm1).mul(ym1).mul(zm1) +
291 |                 vals_100.mul(xd).mul(ym1).mul(zm1) +
292 |                 vals_010.mul(xm1).mul(yd).mul(zm1) +
293 |                 vals_001.mul(xm1).mul(ym1).mul(zd) +
294 |                 vals_101.mul(xd).mul(ym1).mul(zd) +
295 |                 vals_011.mul(xm1).mul(yd).mul(zd) +
296 |                 vals_110.mul(xd).mul(yd).mul(zm1) +
297 |                 vals_111.mul(xd).mul(yd).mul(zd))
298 | 
299 |     return x_mapped.view_as(input)
300 | 
301 | 
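# --- editor's sketch: translating a (C, D, H, W) volume by whole voxels with
# --- th_affine3d; an illustrative addition, not part of the original file.
#   >>> M = th.FloatTensor([[1, 0, 0,  2],
#   ...                     [0, 1, 0,  0],
#   ...                     [0, 0, 1, -1]])
#   >>> vol_t = th_affine3d(th.rand(1, 16, 16, 16), M, mode='trilinear')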
302 | def th_pearsonr(x, y):
303 |     """
304 |     mimics scipy.stats.pearsonr (returns the correlation coefficient only, not the p-value)
305 |     """
306 |     mean_x = th.mean(x)
307 |     mean_y = th.mean(y)
308 |     xm = x.sub(mean_x)
309 |     ym = y.sub(mean_y)
310 |     r_num = xm.dot(ym)
311 |     r_den = th.norm(xm, 2) * th.norm(ym, 2)
312 |     r_val = r_num / r_den
313 |     return r_val
314 | 
315 | 
316 | def th_corrcoef(x):
317 |     """
318 |     mimics np.corrcoef
319 |     """
320 |     # calculate covariance matrix of rows
321 |     mean_x = th.mean(x, 1)
322 |     xm = x.sub(mean_x.expand_as(x))
323 |     c = xm.mm(xm.t())
324 |     c = c / (x.size(1) - 1)
325 | 
326 |     # normalize covariance matrix
327 |     d = th.diag(c)
328 |     stddev = th.pow(d, 0.5)
329 |     c = c.div(stddev.expand_as(c))
330 |     c = c.div(stddev.expand_as(c).t())
331 | 
332 |     # clamp between -1 and 1 to guard against floating-point drift
333 |     c = th.clamp(c, -1.0, 1.0)
334 | 
335 |     return c
336 | 
337 | 
338 | def th_matrixcorr(x, y):
339 |     """
340 |     Return a correlation matrix between
341 |     columns of x and columns of y.
342 | 
343 |     So, if X.size() == (1000,4) and Y.size() == (1000,5),
344 |     then the result will be of size (4,5) with the
345 |     (i,j) value equal to the pearsonr correlation coeff
346 |     between column i in X and column j in Y
347 |     """
348 |     mean_x = th.mean(x, 0)
349 |     mean_y = th.mean(y, 0)
350 |     xm = x.sub(mean_x.expand_as(x))
351 |     ym = y.sub(mean_y.expand_as(y))
352 |     r_num = xm.t().mm(ym)
353 |     r_den1 = th.norm(xm,2,0)
354 |     r_den2 = th.norm(ym,2,0)
355 |     r_den = r_den1.t().mm(r_den2)
356 |     r_mat = r_num.div(r_den)
357 |     return r_mat
358 | 
359 | 
360 | def th_random_choice(a, n_samples=1, replace=True, p=None):
361 |     """
362 |     Parameters
363 |     -----------
364 |     a : 1-D array-like
365 |         If a th.Tensor, a random sample is generated from its elements.
366 |         If an int, the random sample is generated as if a were th.arange(0, a)
367 |     n_samples : int, optional
368 |         Number of samples to draw. Default is 1, in which case a
369 |         single value (not a 1-element tensor) is returned.
370 |     replace : boolean, optional
371 |         Whether the sample is with or without replacement
372 |     p : 1-D array-like, optional
373 |         The probabilities associated with each entry in a.
374 |         If not given the sample assumes a uniform distribution over all
375 |         entries in a.
376 | 
377 |     Returns
378 |     --------
379 |     samples : 1-D tensor, shape (n_samples,)
380 |         The generated random samples
381 |     """
382 |     if isinstance(a, int):
383 |         a = th.arange(0, a)
384 | 
385 |     if p is None:
386 |         if replace:
387 |             idx = th.floor(th.rand(n_samples)*a.size(0)).long()
388 |         else:
389 |             idx = th.randperm(len(a))[:n_samples]
390 |     else:
391 |         if abs(1.0-sum(p)) > 1e-3:
392 |             raise ValueError('p must sum to 1.0')
393 |         if not replace:
394 |             raise ValueError('replace must be True when probabilities are given')
395 |         idx_vec = th.cat([th.zeros(round(p[i]*1000))+i for i in range(len(p))])
396 |         idx = (th.floor(th.rand(n_samples)*999)).long()
397 |         idx = idx_vec[idx].long()
398 |     selection = a[idx]
399 |     if n_samples == 1:
400 |         selection = selection[0]
401 |     return selection
402 | 
403 | 
404 | def save_transform(file, transform):
405 |     """
406 |     Save a transform object
407 |     """
408 |     with open(file, 'wb') as output_file:
409 |         pickler = pickle.Pickler(output_file, -1)
410 |         pickler.dump(transform)
411 | 
412 | 
413 | def load_transform(file):
414 |     """
415 |     Load a transform object
416 |     """
417 |     with open(file, 'rb') as input_file:
418 |         transform = pickle.load(input_file)
419 |     return transform
420 | 
--------------------------------------------------------------------------------
/torchsample/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.1.3'
2 | 
--------------------------------------------------------------------------------