├── .gitignore
├── LICENSE.txt
├── README.md
├── examples
│   ├── Transforms with Pytorch and Torchsample.ipynb
│   ├── imgs
│   │   ├── orig1.png
│   │   ├── orig2.png
│   │   ├── orig3.png
│   │   ├── tform1.png
│   │   ├── tform2.png
│   │   └── tform3.png
│   ├── mnist_example.py
│   └── mnist_loader_example.py
├── setup.py
├── tests
│   ├── integration
│   │   ├── fit_complex
│   │   │   └── multi_input_multi_target.py
│   │   ├── fit_loader_simple
│   │   │   ├── single_input_multi_target.py
│   │   │   └── single_input_single_target.py
│   │   └── fit_simple
│   │       ├── simple_multi_input_multi_target.py
│   │       ├── simple_multi_input_no_target.py
│   │       ├── simple_multi_input_single_target.py
│   │       ├── single_input_multi_target.py
│   │       ├── single_input_no_target.py
│   │       └── single_input_single_target.py
│   ├── test_metrics.py
│   ├── transforms
│   │   ├── test_affine_transforms.py
│   │   ├── test_image_transforms.py
│   │   └── test_tensor_transforms.py
│   └── utils.py
└── torchsample
    ├── __init__.py
    ├── callbacks.py
    ├── constraints.py
    ├── datasets.py
    ├── functions
    │   ├── __init__.py
    │   └── affine.py
    ├── initializers.py
    ├── metrics.py
    ├── modules
    │   ├── __init__.py
    │   ├── _utils.py
    │   └── module_trainer.py
    ├── regularizers.py
    ├── samplers.py
    ├── transforms
    │   ├── __init__.py
    │   ├── affine_transforms.py
    │   ├── distortion_transforms.py
    │   ├── image_transforms.py
    │   └── tensor_transforms.py
    ├── utils.py
    └── version.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .git/
2 | sandbox/
3 |
4 | *.DS_Store
5 | *__pycache__*
6 | __pycache__
7 | *.pyc
8 | .ipynb_checkpoints/
9 | *.ipynb_checkpoints/
10 | *.bkbn
11 | .spyderworkspace
12 | .spyderproject
13 |
14 | # setup.py working directory
15 | build
16 | # sphinx build directory
17 | doc/_build
18 | # setup.py dist directory
19 | dist
20 | # Egg metadata
21 | *.egg-info
22 | .eggs
23 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | COPYRIGHT
2 |
3 | Some contributions by Nicholas Cullen:
4 | Copyright (c) 2017, Nicholas Cullen:
5 | All rights reserved.
6 |
7 | Some contributions by François Chollet:
8 | Copyright (c) 2015, François Chollet.
9 | All rights reserved.
10 |
11 | Some contributions by Google:
12 | Copyright (c) 2015, Google, Inc.
13 | All rights reserved.
14 |
15 | All other contributions:
16 | Copyright (c) 2015, the respective contributors.
17 | All rights reserved.
18 |
19 |
20 | LICENSE
21 |
22 | The MIT License (MIT)
23 |
24 | Permission is hereby granted, free of charge, to any person obtaining a copy
25 | of this software and associated documentation files (the "Software"), to deal
26 | in the Software without restriction, including without limitation the rights
27 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
28 | copies of the Software, and to permit persons to whom the Software is
29 | furnished to do so, subject to the following conditions:
30 |
31 | The above copyright notice and this permission notice shall be included in all
32 | copies or substantial portions of the Software.
33 |
34 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
35 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
36 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
37 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
38 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
39 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40 | SOFTWARE.
41 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # High-Level Training, Data Augmentation, and Utilities for Pytorch
2 |
3 | [v0.1.3](https://github.com/ncullen93/torchsample/releases) JUST RELEASED - contains significant improvements, bug fixes, and additional
4 | support. Get it from the releases, or pull the master branch.
5 |
6 | This package provides a few things:
7 | - A high-level module for Keras-like training with callbacks, constraints, and regularizers.
8 | - Comprehensive data augmentation, transforms, sampling, and loading
9 | - Utility tensor and variable functions so you don't need numpy as often
10 |
11 | Have any feature requests? Submit an issue! I'll make it happen, especially
12 | requests for data augmentation, data loading, or sampling functions.
13 |
14 | Want to contribute? Check the [issues page](https://github.com/ncullen93/torchsample/issues)
15 | for those tagged with [contributions welcome].
16 |
17 | ## ModuleTrainer
18 | The `ModuleTrainer` class provides a high-level training interface which abstracts
19 | away the training loop while providing callbacks, constraints, initializers, regularizers,
20 | and more.
21 |
22 | Example:
23 | ```python
24 | import torch.nn as nn, torch.nn.functional as F
25 | from torchsample.modules import ModuleTrainer
26 | # Define your model EXACTLY as normal
27 | class Network(nn.Module):
28 | def __init__(self):
29 | super(Network, self).__init__()
30 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
31 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
32 | self.fc1 = nn.Linear(1600, 128)
33 | self.fc2 = nn.Linear(128, 10)
34 |
35 | def forward(self, x):
36 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
37 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
38 | x = x.view(-1, 1600)
39 | x = F.relu(self.fc1(x))
40 | x = F.dropout(x, training=self.training)
41 | x = self.fc2(x)
42 | return F.log_softmax(x)
43 |
44 | model = Network()
45 | trainer = ModuleTrainer(model)
46 |
47 | trainer.compile(loss='nll_loss',
48 | optimizer='adadelta')
49 |
50 | trainer.fit(x_train, y_train,
51 | val_data=(x_test, y_test),
52 | num_epoch=20,
53 | batch_size=128,
54 | verbose=1)
55 | ```
56 | You also have access to the standard evaluation and prediction functions:
57 |
58 | ```python
59 | loss = trainer.evaluate(x_train, y_train)
60 | y_pred = trainer.predict(x_train)
61 | ```
62 | Torchsample provides a wide range of callbacks, generally mimicking the interface
63 | found in `Keras`:
64 |
65 | - `EarlyStopping`
66 | - `ModelCheckpoint`
67 | - `LearningRateScheduler`
68 | - `ReduceLROnPlateau`
69 | - `CSVLogger`
70 |
71 | ```python
72 | from torchsample.callbacks import EarlyStopping
73 |
74 | callbacks = [EarlyStopping(monitor='val_loss', patience=5)]
75 | trainer.set_callbacks(callbacks)
76 | ```
77 |
78 | Torchsample also provides regularizers:
79 |
80 | - `L1Regularizer`
81 | - `L2Regularizer`
82 | - `L1L2Regularizer`
83 |
84 |
85 | and constraints:
86 | - `UnitNorm`
87 | - `MaxNorm`
88 | - `NonNeg`
89 |
90 | Both regularizers and constraints can be selectively applied on layers using regular expressions and the `module_filter`
91 | argument. Constraints can be explicit (hard) constraints applied at an arbitrary batch or
92 | epoch frequency, or they can be implicit (soft) constraints similar to regularizers
93 | where the constraint deviation is added as a penalty to the total model loss.
94 |
95 | ```python
96 | from torchsample.constraints import MaxNorm, NonNeg
97 | from torchsample.regularizers import L1Regularizer
98 |
99 | # hard constraint applied every 5 batches
100 | hard_constraint = MaxNorm(value=2., frequency=5, unit='batch', module_filter='*fc*')
101 | # implicit constraint added as a penalty term to model loss
102 | soft_constraint = NonNeg(lagrangian=True, scale=1e-3, module_filter='*fc*')
103 | constraints = [hard_constraint, soft_constraint]
104 | trainer.set_constraints(constraints)
105 |
106 | regularizers = [L1Regularizer(scale=1e-4, module_filter='*conv*')]
107 | trainer.set_regularizers(regularizers)
108 | ```
109 |
110 | You can also fit directly on a `torch.utils.data.DataLoader` and can have
111 | a validation set as well:
112 |
113 | ```python
114 | from torchsample import TensorDataset
115 | from torch.utils.data import DataLoader
116 |
117 | train_dataset = TensorDataset(x_train, y_train)
118 | train_loader = DataLoader(train_dataset, batch_size=32)
119 |
120 | val_dataset = TensorDataset(x_val, y_val)
121 | val_loader = DataLoader(val_dataset, batch_size=32)
122 |
123 | trainer.fit_loader(train_loader, val_loader=val_loader, num_epoch=100)
124 | ```
125 |
126 | ## Utility Functions
127 | Finally, torchsample provides a few utility functions not commonly found elsewhere; a short usage sketch follows the lists below:
128 |
129 | ### Tensor Functions
130 | - `th_iterproduct` (mimics itertools.product)
131 | - `th_gather_nd` (N-dimensional version of torch.gather)
132 | - `th_random_choice` (mimics np.random.choice)
133 | - `th_pearsonr` (mimics scipy.stats.pearsonr)
134 | - `th_corrcoef` (mimics np.corrcoef)
135 | - `th_affine2d` and `th_affine3d` (affine transforms on torch.Tensors)
136 |
137 | ### Variable Functions
138 | - `F_affine2d` and `F_affine3d`
139 | - `F_map_coordinates2d` and `F_map_coordinates3d`
140 |
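A minimal sketch of two of these, under the assumption that `th_random_choice` takes an `n_samples` keyword (the exact argument names live in `torchsample/utils.py`):

```python
import torch as th
from torchsample.utils import th_iterproduct, th_random_choice

# all (i, j) index pairs of a 3x2 grid, like itertools.product(range(3), range(2))
idx = th_iterproduct(3, 2)

# draw 5 samples from a tensor, mimicking np.random.choice
# (the `n_samples` keyword is an assumption -- check torchsample/utils.py)
samples = th_random_choice(th.arange(0, 10), n_samples=5)
```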
141 | ## Data Augmentation and Datasets
142 | The torchsample package provides a ton of good data augmentation and transformation
143 | tools which can be applied during data loading. The package also provides the flexible
144 | `TensorDataset` and `FolderDataset` classes to handle most dataset needs.
145 |
146 | ### Torch Transforms
147 | These transforms work directly on torch tensors; a usage sketch follows the list.
148 |
149 | - `Compose()`
150 | - `AddChannel()`
151 | - `SwapDims()`
152 | - `RangeNormalize()`
153 | - `StdNormalize()`
154 | - `Slice2D()`
155 | - `RandomCrop()`
156 | - `SpecialCrop()`
157 | - `Pad()`
158 | - `RandomFlip()`
159 | - `ToTensor()`
160 |
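A minimal sketch of chaining several of these with `Compose()`; the constructor arguments shown here are illustrative guesses, so check each transform's docstring for the exact signature:

```python
from torchsample.transforms import Compose, RangeNormalize, RandomCrop, RandomFlip

# normalize into [0, 1], randomly crop to 24x24, then maybe flip horizontally
transform = Compose([RangeNormalize(0, 1),
                     RandomCrop((24, 24)),
                     RandomFlip(h=True, v=False)])

x_transformed = transform(x)  # x: a torch tensor of shape (C, H, W)
```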
161 | ### Affine Transforms
163 |
164 | The following perform affine (or affine-like) transformations on torch tensors.
165 |
166 | - `Rotate()`
167 | - `Translate()`
168 | - `Shear()`
169 | - `Zoom()`
170 |
171 | We also provide classes for stringing multiple affine transformations together so that only one interpolation takes place (see the sketch after this list):
172 |
173 | - `Affine()`
174 | - `AffineCompose()`
175 |
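For intuition: applying `Rotate(30)` and then `Zoom(0.9)` separately would interpolate the image twice, while `AffineCompose` multiplies the two affine matrices first and interpolates once. A minimal sketch, assuming `AffineCompose` accepts a list of transforms as the tests suggest:

```python
from torchsample.transforms import AffineCompose, Rotate, Zoom

# one combined affine matrix -> one interpolation
tform = AffineCompose([Rotate(30), Zoom(0.9)])
x_transformed = tform(x)  # x: a torch tensor of shape (C, H, W)
```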
176 | ### Datasets and Sampling
177 | We provide the following dataset classes, which give general structure and iterators for sampling from in-memory or out-of-memory data and applying transforms during loading; a usage sketch follows the list:
178 |
179 | - `TensorDataset()`
180 |
181 | - `FolderDataset()`
182 |
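A minimal sketch of attaching a transform to a `TensorDataset` so that it runs during loading; the `input_transform` keyword is an assumption here, so check `torchsample/datasets.py` for the exact name:

```python
from torch.utils.data import DataLoader
from torchsample import TensorDataset
from torchsample.transforms import RangeNormalize

# the transform is applied sample-by-sample as the loader iterates
# (keyword name `input_transform` is an assumption -- see torchsample/datasets.py)
train_dataset = TensorDataset(x_train, y_train,
                              input_transform=RangeNormalize(0, 1))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
```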
183 |
184 | ## Acknowledgements
185 | Thank you to the following people and contributors:
186 | - All Keras contributors
187 | - @deallynomore
188 | - @recastrodiaz
189 |
190 |
--------------------------------------------------------------------------------
/examples/imgs/orig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/orig1.png
--------------------------------------------------------------------------------
/examples/imgs/orig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/orig2.png
--------------------------------------------------------------------------------
/examples/imgs/orig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/orig3.png
--------------------------------------------------------------------------------
/examples/imgs/tform1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/tform1.png
--------------------------------------------------------------------------------
/examples/imgs/tform2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/tform2.png
--------------------------------------------------------------------------------
/examples/imgs/tform3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/achaiah/torchsample/21507feb258a25bf6924e4844e578624cda72140/examples/imgs/tform3.png
--------------------------------------------------------------------------------
/examples/mnist_example.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torchsample.modules import ModuleTrainer
7 | from torchsample.callbacks import EarlyStopping, ReduceLROnPlateau
8 | from torchsample.regularizers import L1Regularizer, L2Regularizer
9 | from torchsample.constraints import UnitNorm
10 | from torchsample.initializers import XavierUniform
11 | from torchsample.metrics import CategoricalAccuracy
12 |
13 | import os
14 | from torchvision import datasets
15 | ROOT = '/users/ncullen/desktop/data/mnist'
16 | dataset = datasets.MNIST(ROOT, train=True, download=True)
17 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
18 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
19 |
20 | x_train = x_train.float()
21 | y_train = y_train.long()
22 | x_test = x_test.float()
23 | y_test = y_test.long()
24 |
25 | x_train = x_train / 255.
26 | x_test = x_test / 255.
27 | x_train = x_train.unsqueeze(1)
28 | x_test = x_test.unsqueeze(1)
29 |
30 | # only train on a subset
31 | x_train = x_train[:10000]
32 | y_train = y_train[:10000]
33 | x_test = x_test[:1000]
34 | y_test = y_test[:1000]
35 |
36 |
37 | # Define your model EXACTLY as if you were using nn.Module
38 | class Network(nn.Module):
39 | def __init__(self):
40 | super(Network, self).__init__()
41 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
42 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
43 | self.fc1 = nn.Linear(1600, 128)
44 | self.fc2 = nn.Linear(128, 10)
45 |
46 | def forward(self, x):
47 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
48 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
49 | x = x.view(-1, 1600)
50 | x = F.relu(self.fc1(x))
51 | x = F.dropout(x, training=self.training)
52 | x = self.fc2(x)
53 | return F.log_softmax(x)
54 |
55 |
56 | model = Network()
57 | trainer = ModuleTrainer(model)
58 |
59 |
60 | callbacks = [EarlyStopping(patience=10),
61 | ReduceLROnPlateau(factor=0.5, patience=5)]
62 | regularizers = [L1Regularizer(scale=1e-3, module_filter='conv*'),
63 | L2Regularizer(scale=1e-5, module_filter='fc*')]
64 | constraints = [UnitNorm(frequency=3, unit='batch', module_filter='fc*')]
65 | initializers = [XavierUniform(bias=False, module_filter='fc*')]
66 | metrics = [CategoricalAccuracy(top_k=3)]
67 |
68 | trainer.compile(loss='nll_loss',
69 | optimizer='adadelta',
70 | regularizers=regularizers,
71 | constraints=constraints,
72 | initializers=initializers,
73 |                 metrics=metrics, callbacks=callbacks)
74 |
75 | #summary = trainer.summary([1,28,28])
76 | #print(summary)
77 |
78 | trainer.fit(x_train, y_train,
79 | val_data=(x_test, y_test),
80 | num_epoch=20,
81 | batch_size=128,
82 | verbose=1)
83 |
84 |
85 |
--------------------------------------------------------------------------------
/examples/mnist_loader_example.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch.utils.data import DataLoader
7 |
8 | from torchsample.modules import ModuleTrainer
9 | from torchsample.callbacks import EarlyStopping, ReduceLROnPlateau
10 | from torchsample.regularizers import L1Regularizer, L2Regularizer
11 | from torchsample.constraints import UnitNorm
12 | from torchsample.initializers import XavierUniform
13 | from torchsample.metrics import CategoricalAccuracy
14 | from torchsample import TensorDataset
15 |
16 | import os
17 | from torchvision import datasets
18 | ROOT = '/users/ncullen/desktop/data/mnist'
19 | dataset = datasets.MNIST(ROOT, train=True, download=True)
20 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
21 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
22 |
23 | x_train = x_train.float()
24 | y_train = y_train.long()
25 | x_test = x_test.float()
26 | y_test = y_test.long()
27 |
28 | x_train = x_train / 255.
29 | x_test = x_test / 255.
30 | x_train = x_train.unsqueeze(1)
31 | x_test = x_test.unsqueeze(1)
32 |
33 | # only train on a subset
34 | x_train = x_train[:10000]
35 | y_train = y_train[:10000]
36 | x_test = x_test[:1000]
37 | y_test = y_test[:1000]
38 |
39 | train_dataset = TensorDataset(x_train, y_train)
40 | train_loader = DataLoader(train_dataset, batch_size=32)
41 | val_dataset = TensorDataset(x_test, y_test)
42 | val_loader = DataLoader(val_dataset, batch_size=32)
43 |
44 | # Define your model EXACTLY as if you were using nn.Module
45 | class Network(nn.Module):
46 | def __init__(self):
47 | super(Network, self).__init__()
48 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
49 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
50 | self.fc1 = nn.Linear(1600, 128)
51 | self.fc2 = nn.Linear(128, 10)
52 |
53 | def forward(self, x):
54 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
55 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
56 | x = x.view(-1, 1600)
57 | x = F.relu(self.fc1(x))
58 | x = F.dropout(x, training=self.training)
59 | x = self.fc2(x)
60 | return F.log_softmax(x)
61 |
62 |
63 | model = Network()
64 | trainer = ModuleTrainer(model)
65 |
66 | callbacks = [EarlyStopping(patience=10),
67 | ReduceLROnPlateau(factor=0.5, patience=5)]
68 | regularizers = [L1Regularizer(scale=1e-3, module_filter='conv*'),
69 | L2Regularizer(scale=1e-5, module_filter='fc*')]
70 | constraints = [UnitNorm(frequency=3, unit='batch', module_filter='fc*')]
71 | initializers = [XavierUniform(bias=False, module_filter='fc*')]
72 | metrics = [CategoricalAccuracy(top_k=3)]
73 |
74 | trainer.compile(loss='nll_loss',
75 | optimizer='adadelta',
76 | regularizers=regularizers,
77 | constraints=constraints,
78 | initializers=initializers,
79 | metrics=metrics,
80 | callbacks=callbacks)
81 |
82 | trainer.fit_loader(train_loader, val_loader, num_epoch=20, verbose=1)
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup, find_packages
4 |
5 | setup(name='torchsample',
6 | version='0.1.3',
7 | description='High-Level Training, Augmentation, and Sampling for Pytorch',
8 | author='NC Cullen',
9 | author_email='nickmarch31@yahoo.com',
10 | packages=find_packages()
11 | )
--------------------------------------------------------------------------------
/tests/integration/fit_complex/multi_input_multi_target.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torchsample.modules import ModuleTrainer
7 | from torchsample import regularizers as regs
8 | from torchsample import constraints as cons
9 | from torchsample import initializers as inits
10 | from torchsample import callbacks as cbks
11 | from torchsample import metrics
12 | from torchsample import transforms as tforms
13 |
14 | import os
15 | from torchvision import datasets
16 |
17 | ROOT = '/users/ncullen/desktop/data/mnist'
18 | dataset = datasets.MNIST(ROOT, train=True, download=True)
19 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
20 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
21 |
22 | x_train = x_train.float()
23 | y_train = y_train.long()
24 | x_test = x_test.float()
25 | y_test = y_test.long()
26 |
27 | x_train = x_train / 255.
28 | x_test = x_test / 255.
29 | x_train = x_train.unsqueeze(1)
30 | x_test = x_test.unsqueeze(1)
31 |
32 | # only train on a subset
33 | x_train = x_train[:1000]
34 | y_train = y_train[:1000]
35 | x_test = x_test[:100]
36 | y_test = y_test[:100]
37 |
38 |
39 | # Define your model EXACTLY as if you were using nn.Module
40 | class Network(nn.Module):
41 | def __init__(self):
42 | super(Network, self).__init__()
43 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
44 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
45 | self.fc1 = nn.Linear(1600, 128)
46 | self.fc2 = nn.Linear(128, 10)
47 |
48 | def forward(self, x, y, z):
49 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
50 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
51 | x = x.view(-1, 1600)
52 | x = F.relu(self.fc1(x))
53 | x = F.dropout(x, training=self.training)
54 | x = self.fc2(x)
55 | return F.log_softmax(x), F.log_softmax(x), F.log_softmax(x)
56 |
57 | # with one loss function given
58 | model = Network()
59 | trainer = ModuleTrainer(model)
60 |
61 | regularizers = [regs.L1Regularizer(1e-4, 'fc*'), regs.L2Regularizer(1e-5, 'conv*')]
62 | constraints = [cons.UnitNorm(5, 'batch', 'fc*'),
63 | cons.MaxNorm(5, 0, 'batch', 'conv*')]
64 | callbacks = [cbks.ReduceLROnPlateau(monitor='loss', verbose=1)]
65 |
66 | trainer.compile(loss='nll_loss',
67 | optimizer='adadelta',
68 | regularizers=regularizers,
69 | constraints=constraints,
70 | callbacks=callbacks)
71 |
72 | trainer.fit([x_train, x_train, x_train],
73 | [y_train, y_train, y_train],
74 | num_epoch=3,
75 | batch_size=128,
76 | verbose=1)
77 |
78 | yp1, yp2, yp3 = trainer.predict([x_train, x_train, x_train])
79 | print(yp1.size(), yp2.size(), yp3.size())
80 |
81 | eval_loss = trainer.evaluate([x_train, x_train, x_train],
82 | [y_train, y_train, y_train])
83 | print(eval_loss)
84 |
85 | # With multiple loss functions given
86 | model = Network()
87 | trainer = ModuleTrainer(model)
88 |
89 | trainer.compile(loss=['nll_loss', 'nll_loss', 'nll_loss'],
90 | optimizer='adadelta',
91 | regularizers=regularizers,
92 | constraints=constraints,
93 | callbacks=callbacks)
94 |
95 | trainer.fit([x_train, x_train, x_train],
96 | [y_train, y_train, y_train],
97 | num_epoch=3,
98 | batch_size=128,
99 | verbose=1)
100 |
101 | # should raise an exception when multiple loss functions are given
102 | # but there is not one loss function for every target
103 | try:
104 | model = Network()
105 | trainer = ModuleTrainer(model)
106 |
107 | trainer.compile(loss=['nll_loss', 'nll_loss'],
108 | optimizer='adadelta',
109 | regularizers=regularizers,
110 | constraints=constraints,
111 | callbacks=callbacks)
112 |
113 | trainer.fit([x_train, x_train, x_train],
114 | [y_train, y_train, y_train],
115 | num_epoch=3,
116 | batch_size=128,
117 | verbose=1)
118 | except Exception:
119 | print('Exception correctly caught')
120 |
121 |
--------------------------------------------------------------------------------
/tests/integration/fit_loader_simple/single_input_multi_target.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.utils.data import DataLoader
6 |
7 | from torchsample.modules import ModuleTrainer
8 | from torchsample import TensorDataset
9 |
10 | import os
11 | from torchvision import datasets
12 | ROOT = '/users/ncullen/desktop/data/mnist'
13 | dataset = datasets.MNIST(ROOT, train=True, download=True)
14 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
15 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
16 |
17 | x_train = x_train.float()
18 | y_train = y_train.long()
19 | x_test = x_test.float()
20 | y_test = y_test.long()
21 |
22 | x_train = x_train / 255.
23 | x_test = x_test / 255.
24 | x_train = x_train.unsqueeze(1)
25 | x_test = x_test.unsqueeze(1)
26 |
27 | # only train on a subset
28 | x_train = x_train[:1000]
29 | y_train = y_train[:1000]
30 | x_test = x_test[:1000]
31 | y_test = y_test[:1000]
32 |
33 | train_data = TensorDataset(x_train, [y_train, y_train])
34 | train_loader = DataLoader(train_data, batch_size=128)
35 |
36 | # Define your model EXACTLY as if you were using nn.Module
37 | class Network(nn.Module):
38 | def __init__(self):
39 | super(Network, self).__init__()
40 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
41 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
42 | self.fc1 = nn.Linear(1600, 128)
43 | self.fc2 = nn.Linear(128, 10)
44 |
45 | def forward(self, x):
46 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
47 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
48 | x = x.view(-1, 1600)
49 | x = F.relu(self.fc1(x))
50 | x = F.dropout(x, training=self.training)
51 | x = self.fc2(x)
52 | return F.log_softmax(x), F.log_softmax(x)
53 |
54 |
55 | # one loss function for multiple targets
56 | model = Network()
57 | trainer = ModuleTrainer(model)
58 | trainer.compile(loss='nll_loss',
59 | optimizer='adadelta')
60 |
61 | trainer.fit_loader(train_loader,
62 | num_epoch=3,
63 | verbose=1)
64 | ypred1, ypred2 = trainer.predict(x_train)
65 | print(ypred1.size(), ypred2.size())
66 |
67 | eval_loss = trainer.evaluate(x_train, [y_train, y_train])
68 | print(eval_loss)
69 | # multiple loss functions
70 | model = Network()
71 | trainer = ModuleTrainer(model)
72 | trainer.compile(loss=['nll_loss', 'nll_loss'],
73 | optimizer='adadelta')
74 | trainer.fit_loader(train_loader,
75 | num_epoch=3,
76 | verbose=1)
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/tests/integration/fit_loader_simple/single_input_single_target.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.utils.data import DataLoader
6 |
7 | from torchsample.modules import ModuleTrainer
8 | from torchsample import TensorDataset
9 |
10 | import os
11 | from torchvision import datasets
12 | ROOT = '/users/ncullen/desktop/data/mnist'
13 | dataset = datasets.MNIST(ROOT, train=True, download=True)
14 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
15 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
16 |
17 | x_train = x_train.float()
18 | y_train = y_train.long()
19 | x_test = x_test.float()
20 | y_test = y_test.long()
21 |
22 | x_train = x_train / 255.
23 | x_test = x_test / 255.
24 | x_train = x_train.unsqueeze(1)
25 | x_test = x_test.unsqueeze(1)
26 |
27 | # only train on a subset
28 | x_train = x_train[:1000]
29 | y_train = y_train[:1000]
30 | x_test = x_test[:1000]
31 | y_test = y_test[:1000]
32 |
33 | train_data = TensorDataset(x_train, y_train)
34 | train_loader = DataLoader(train_data, batch_size=128)
35 |
36 | # Define your model EXACTLY as if you were using nn.Module
37 | class Network(nn.Module):
38 | def __init__(self):
39 | super(Network, self).__init__()
40 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
41 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
42 | self.fc1 = nn.Linear(1600, 128)
43 | self.fc2 = nn.Linear(128, 10)
44 |
45 | def forward(self, x):
46 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
47 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
48 | x = x.view(-1, 1600)
49 | x = F.relu(self.fc1(x))
50 | #x = F.dropout(x, training=self.training)
51 | x = self.fc2(x)
52 | return F.log_softmax(x)
53 |
54 |
55 | model = Network()
56 | trainer = ModuleTrainer(model)
57 |
58 | trainer.compile(loss='nll_loss',
59 | optimizer='adadelta')
60 |
61 | trainer.fit_loader(train_loader,
62 | num_epoch=3,
63 | verbose=1)
64 |
65 | ypred = trainer.predict(x_train)
66 | print(ypred.size())
67 |
68 | eval_loss = trainer.evaluate(x_train, y_train)
69 | print(eval_loss)
70 |
71 | print(trainer.history)
72 | #print(trainer.history['loss'])
73 |
74 |
--------------------------------------------------------------------------------
/tests/integration/fit_simple/simple_multi_input_multi_target.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torchsample.modules import ModuleTrainer
7 |
8 | import os
9 | from torchvision import datasets
10 | ROOT = '/users/ncullen/desktop/data/mnist'
11 | dataset = datasets.MNIST(ROOT, train=True, download=True)
12 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
13 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
14 |
15 | x_train = x_train.float()
16 | y_train = y_train.long()
17 | x_test = x_test.float()
18 | y_test = y_test.long()
19 |
20 | x_train = x_train / 255.
21 | x_test = x_test / 255.
22 | x_train = x_train.unsqueeze(1)
23 | x_test = x_test.unsqueeze(1)
24 |
25 | # only train on a subset
26 | x_train = x_train[:1000]
27 | y_train = y_train[:1000]
28 | x_test = x_test[:100]
29 | y_test = y_test[:100]
30 |
31 |
32 | # Define your model EXACTLY as if you were using nn.Module
33 | class Network(nn.Module):
34 | def __init__(self):
35 | super(Network, self).__init__()
36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
38 | self.fc1 = nn.Linear(1600, 128)
39 | self.fc2 = nn.Linear(128, 10)
40 |
41 | def forward(self, x, y, z):
42 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
43 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
44 | x = x.view(-1, 1600)
45 | x = F.relu(self.fc1(x))
46 | x = F.dropout(x, training=self.training)
47 | x = self.fc2(x)
48 | return F.log_softmax(x), F.log_softmax(x), F.log_softmax(x)
49 |
50 | # with one loss function given
51 | model = Network()
52 | trainer = ModuleTrainer(model)
53 |
54 | trainer.compile(loss='nll_loss',
55 | optimizer='adadelta')
56 |
57 | trainer.fit([x_train, x_train, x_train],
58 | [y_train, y_train, y_train],
59 | num_epoch=3,
60 | batch_size=128,
61 | verbose=1)
62 |
63 | yp1, yp2, yp3 = trainer.predict([x_train, x_train, x_train])
64 | print(yp1.size(), yp2.size(), yp3.size())
65 |
66 | eval_loss = trainer.evaluate([x_train, x_train, x_train],
67 | [y_train, y_train, y_train])
68 | print(eval_loss)
69 |
70 | # With multiple loss functions given
71 | model = Network()
72 | trainer = ModuleTrainer(model)
73 |
74 | trainer.compile(loss=['nll_loss', 'nll_loss', 'nll_loss'],
75 | optimizer='adadelta')
76 |
77 | trainer.fit([x_train, x_train, x_train],
78 | [y_train, y_train, y_train],
79 | num_epoch=3,
80 | batch_size=128,
81 | verbose=1)
82 |
83 | # should raise an exception when multiple loss functions are given
84 | # but there is not one loss function for every target
85 | try:
86 | model = Network()
87 | trainer = ModuleTrainer(model)
88 |
89 | trainer.compile(loss=['nll_loss', 'nll_loss'],
90 | optimizer='adadelta')
91 |
92 | trainer.fit([x_train, x_train, x_train],
93 | [y_train, y_train, y_train],
94 | num_epoch=3,
95 | batch_size=128,
96 | verbose=1)
97 | except Exception:
98 | print('Exception correctly caught')
99 |
100 |
--------------------------------------------------------------------------------
/tests/integration/fit_simple/simple_multi_input_no_target.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torchsample.modules import ModuleTrainer
7 |
8 | import os
9 | from torchvision import datasets
10 | ROOT = '/users/ncullen/desktop/data/mnist'
11 | dataset = datasets.MNIST(ROOT, train=True, download=True)
12 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
13 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
14 |
15 | x_train = x_train.float()
16 | y_train = y_train.long()
17 | x_test = x_test.float()
18 | y_test = y_test.long()
19 |
20 | x_train = x_train / 255.
21 | x_test = x_test / 255.
22 | x_train = x_train.unsqueeze(1)
23 | x_test = x_test.unsqueeze(1)
24 |
25 | # only train on a subset
26 | x_train = x_train[:1000]
27 | y_train = y_train[:1000]
28 | x_test = x_test[:1000]
29 | y_test = y_test[:1000]
30 |
31 |
32 | # Define your model EXACTLY as if you were using nn.Module
33 | class Network(nn.Module):
34 | def __init__(self):
35 | super(Network, self).__init__()
36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
38 | self.fc1 = nn.Linear(1600, 128)
39 | self.fc2 = nn.Linear(128, 1)
40 |
41 | def forward(self, x, y, z):
42 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
43 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
44 | x = x.view(-1, 1600)
45 | x = F.relu(self.fc1(x))
46 | x = F.dropout(x, training=self.training)
47 | x = self.fc2(x)
48 | return th.abs(10 - x)
49 |
50 |
51 | model = Network()
52 | trainer = ModuleTrainer(model)
53 |
54 | trainer.compile(loss='unconstrained_sum',
55 | optimizer='adadelta')
56 |
57 | trainer.fit([x_train, x_train, x_train],
58 | num_epoch=3,
59 | batch_size=128,
60 | verbose=1)
61 |
62 | ypred = trainer.predict([x_train, x_train, x_train])
63 | print(ypred.size())
64 |
65 | eval_loss = trainer.evaluate([x_train, x_train, x_train])
66 | print(eval_loss)
67 |
68 |
--------------------------------------------------------------------------------
/tests/integration/fit_simple/simple_multi_input_single_target.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torchsample.modules import ModuleTrainer
7 |
8 | import os
9 | from torchvision import datasets
10 | ROOT = '/users/ncullen/desktop/data/mnist'
11 | dataset = datasets.MNIST(ROOT, train=True, download=True)
12 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
13 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
14 |
15 | x_train = x_train.float()
16 | y_train = y_train.long()
17 | x_test = x_test.float()
18 | y_test = y_test.long()
19 |
20 | x_train = x_train / 255.
21 | x_test = x_test / 255.
22 | x_train = x_train.unsqueeze(1)
23 | x_test = x_test.unsqueeze(1)
24 |
25 | # only train on a subset
26 | x_train = x_train[:1000]
27 | y_train = y_train[:1000]
28 | x_test = x_test[:100]
29 | y_test = y_test[:100]
30 |
31 |
32 | # Define your model EXACTLY as if you were using nn.Module
33 | class Network(nn.Module):
34 | def __init__(self):
35 | super(Network, self).__init__()
36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
38 | self.fc1 = nn.Linear(1600, 128)
39 | self.fc2 = nn.Linear(128, 10)
40 |
41 | def forward(self, x, y, z):
42 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
43 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
44 | x = x.view(-1, 1600)
45 | x = F.relu(self.fc1(x))
46 | x = F.dropout(x, training=self.training)
47 | x = self.fc2(x)
48 | return F.log_softmax(x)
49 |
50 |
51 | model = Network()
52 | trainer = ModuleTrainer(model)
53 |
54 | trainer.compile(loss='nll_loss',
55 | optimizer='adadelta')
56 |
57 | trainer.fit([x_train, x_train, x_train], y_train,
58 | val_data=([x_test, x_test, x_test], y_test),
59 | num_epoch=3,
60 | batch_size=128,
61 | verbose=1)
62 |
63 | ypred = trainer.predict([x_train, x_train, x_train])
64 | print(ypred.size())
65 |
66 | eval_loss = trainer.evaluate([x_train, x_train, x_train], y_train)
67 | print(eval_loss)
--------------------------------------------------------------------------------
/tests/integration/fit_simple/single_input_multi_target.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torchsample.modules import ModuleTrainer
7 |
8 | import os
9 | from torchvision import datasets
10 | ROOT = '/users/ncullen/desktop/data/mnist'
11 | dataset = datasets.MNIST(ROOT, train=True, download=True)
12 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
13 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
14 |
15 | x_train = x_train.float()
16 | y_train = y_train.long()
17 | x_test = x_test.float()
18 | y_test = y_test.long()
19 |
20 | x_train = x_train / 255.
21 | x_test = x_test / 255.
22 | x_train = x_train.unsqueeze(1)
23 | x_test = x_test.unsqueeze(1)
24 |
25 | # only train on a subset
26 | x_train = x_train[:1000]
27 | y_train = y_train[:1000]
28 | x_test = x_test[:1000]
29 | y_test = y_test[:1000]
30 |
31 |
32 | # Define your model EXACTLY as if you were using nn.Module
33 | class Network(nn.Module):
34 | def __init__(self):
35 | super(Network, self).__init__()
36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
38 | self.fc1 = nn.Linear(1600, 128)
39 | self.fc2 = nn.Linear(128, 10)
40 |
41 | def forward(self, x):
42 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
43 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
44 | x = x.view(-1, 1600)
45 | x = F.relu(self.fc1(x))
46 | x = F.dropout(x, training=self.training)
47 | x = self.fc2(x)
48 | return F.log_softmax(x), F.log_softmax(x)
49 |
50 |
51 | # one loss function for multiple targets
52 | model = Network()
53 | trainer = ModuleTrainer(model)
54 | trainer.compile(loss='nll_loss',
55 | optimizer='adadelta')
56 |
57 | trainer.fit(x_train,
58 | [y_train, y_train],
59 | num_epoch=3,
60 | batch_size=128,
61 | verbose=1)
62 | ypred1, ypred2 = trainer.predict(x_train)
63 | print(ypred1.size(), ypred2.size())
64 |
65 | eval_loss = trainer.evaluate(x_train, [y_train, y_train])
66 | print(eval_loss)
67 | # multiple loss functions
68 | model = Network()
69 | trainer = ModuleTrainer(model)
70 | trainer.compile(loss=['nll_loss', 'nll_loss'],
71 | optimizer='adadelta')
72 | trainer.fit(x_train,
73 | [y_train, y_train],
74 | num_epoch=3,
75 | batch_size=128,
76 | verbose=1)
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/tests/integration/fit_simple/single_input_no_target.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torchsample.modules import ModuleTrainer
7 |
8 | import os
9 | from torchvision import datasets
10 | ROOT = '/users/ncullen/desktop/data/mnist'
11 | dataset = datasets.MNIST(ROOT, train=True, download=True)
12 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
13 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
14 |
15 | x_train = x_train.float()
16 | y_train = y_train.long()
17 | x_test = x_test.float()
18 | y_test = y_test.long()
19 |
20 | x_train = x_train / 255.
21 | x_test = x_test / 255.
22 | x_train = x_train.unsqueeze(1)
23 | x_test = x_test.unsqueeze(1)
24 |
25 | # only train on a subset
26 | x_train = x_train[:1000]
27 | y_train = y_train[:1000]
28 | x_test = x_test[:1000]
29 | y_test = y_test[:1000]
30 |
31 |
32 | # Define your model EXACTLY as if you were using nn.Module
33 | class Network(nn.Module):
34 | def __init__(self):
35 | super(Network, self).__init__()
36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
38 | self.fc1 = nn.Linear(1600, 128)
39 | self.fc2 = nn.Linear(128, 1)
40 |
41 | def forward(self, x):
42 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
43 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
44 | x = x.view(-1, 1600)
45 | x = F.relu(self.fc1(x))
46 | x = F.dropout(x, training=self.training)
47 | x = self.fc2(x)
48 | return th.abs(10 - x)
49 |
50 |
51 | model = Network()
52 | trainer = ModuleTrainer(model)
53 |
54 | trainer.compile(loss='unconstrained_sum',
55 | optimizer='adadelta')
56 |
57 | trainer.fit(x_train,
58 | num_epoch=3,
59 | batch_size=128,
60 | verbose=1)
61 |
62 | ypred = trainer.predict(x_train)
63 | print(ypred.size())
64 |
65 | eval_loss = trainer.evaluate(x_train, None)
66 | print(eval_loss)
--------------------------------------------------------------------------------
/tests/integration/fit_simple/single_input_single_target.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torchsample.modules import ModuleTrainer
7 | from torchsample import regularizers as reg
8 | from torchsample import constraints as con
9 |
10 | import os
11 | from torchvision import datasets
12 | ROOT = '/users/ncullen/desktop/data/mnist'
13 | dataset = datasets.MNIST(ROOT, train=True, download=True)
14 | x_train, y_train = th.load(os.path.join(dataset.root, 'processed/training.pt'))
15 | x_test, y_test = th.load(os.path.join(dataset.root, 'processed/test.pt'))
16 |
17 | x_train = x_train.float()
18 | y_train = y_train.long()
19 | x_test = x_test.float()
20 | y_test = y_test.long()
21 |
22 | x_train = x_train / 255.
23 | x_test = x_test / 255.
24 | x_train = x_train.unsqueeze(1)
25 | x_test = x_test.unsqueeze(1)
26 |
27 | # only train on a subset
28 | x_train = x_train[:1000]
29 | y_train = y_train[:1000]
30 | x_test = x_test[:1000]
31 | y_test = y_test[:1000]
32 |
33 |
34 | # Define your model EXACTLY as if you were using nn.Module
35 | class Network(nn.Module):
36 | def __init__(self):
37 | super(Network, self).__init__()
38 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
39 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
40 | self.fc1 = nn.Linear(1600, 128)
41 | self.fc2 = nn.Linear(128, 10)
42 |
43 | def forward(self, x):
44 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
45 | x = F.relu(F.max_pool2d(self.conv2(x), 2))
46 | x = x.view(-1, 1600)
47 | x = F.relu(self.fc1(x))
48 | #x = F.dropout(x, training=self.training)
49 | x = self.fc2(x)
50 | return F.log_softmax(x)
51 |
52 |
53 | model = Network()
54 | trainer = ModuleTrainer(model)
55 |
56 | trainer.compile(loss='nll_loss',
57 | optimizer='adadelta',
58 | regularizers=[reg.L1Regularizer(1e-4)])
59 |
60 | trainer.fit(x_train, y_train,
61 | val_data=(x_test, y_test),
62 | num_epoch=3,
63 | batch_size=128,
64 | verbose=1)
65 |
66 | ypred = trainer.predict(x_train)
67 | print(ypred.size())
68 |
69 | eval_loss = trainer.evaluate(x_train, y_train)
70 | print(eval_loss)
71 |
72 | print(trainer.history)
73 | #print(trainer.history['loss'])
74 |
75 |
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | from torch.autograd import Variable
4 |
5 | from torchsample.metrics import CategoricalAccuracy
6 |
7 | class TestMetrics(unittest.TestCase):
8 |
9 | def test_categorical_accuracy(self):
10 | metric = CategoricalAccuracy()
11 | predicted = Variable(torch.eye(10))
12 | expected = Variable(torch.LongTensor(list(range(10))))
13 | self.assertEqual(metric(predicted, expected), 100.0)
14 |
15 |         # Set 1st column to ones: now only row 0 is predicted correctly, and the
16 |         # metric accumulates across calls, so (10 + 1) / (10 + 10) = 55%
16 | predicted = Variable(torch.zeros(10, 10))
17 | predicted.data[:, 0] = torch.ones(10)
18 | self.assertEqual(metric(predicted, expected), 55.0)
19 |
20 | if __name__ == '__main__':
21 | unittest.main()
--------------------------------------------------------------------------------
/tests/transforms/test_affine_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Test affine transforms
3 |
4 | Transforms:
5 | - Affine + RandomAffine
6 | - AffineCompose
7 | - Rotate + RandomRotate
8 | - Translate + RandomTranslate
9 | - Shear + RandomShear
10 | - Zoom + RandomZoom
11 | """
12 |
13 | #import pytest
14 |
15 | import torch as th
16 |
17 | from torchsample.transforms import (RandomAffine, Affine,
18 | RandomRotate, RandomChoiceRotate, Rotate,
19 | RandomTranslate, RandomChoiceTranslate, Translate,
20 | RandomShear, RandomChoiceShear, Shear,
21 | RandomZoom, RandomChoiceZoom, Zoom)
22 |
23 | # ----------------------------------------------------
24 | # ----------------------------------------------------
25 |
26 | ## DATA SET ##
27 | def gray2d_setup():
28 | images = {}
29 |
30 | x = th.zeros(1,30,30)
31 | x[:,10:21,10:21] = 1
32 | images['gray_01'] = x
33 |
34 | x = th.zeros(1,30,40)
35 | x[:,10:21,10:21] = 1
36 | images['gray_02'] = x
37 |
38 | return images
39 |
40 | def multi_gray2d_setup():
41 | old_imgs = gray2d_setup()
42 | images = {}
43 | for k,v in old_imgs.items():
44 | images[k+'_2imgs'] = [v,v]
45 | images[k+'_3imgs'] = [v,v,v]
46 | images[k+'_4imgs'] = [v,v,v,v]
47 | return images
48 |
49 | def color2d_setup():
50 | images = {}
51 |
52 | x = th.zeros(3,30,30)
53 | x[:,10:21,10:21] = 1
54 | images['color_01'] = x
55 |
56 | x = th.zeros(3,30,40)
57 | x[:,10:21,10:21] = 1
58 | images['color_02'] = x
59 |
60 | return images
61 |
62 | def multi_color2d_setup():
63 | old_imgs = color2d_setup()
64 | images = {}
65 | for k,v in old_imgs.items():
66 | images[k+'_2imgs'] = [v,v]
67 | images[k+'_3imgs'] = [v,v,v]
68 | images[k+'_4imgs'] = [v,v,v,v]
69 | return images
70 |
71 |
72 | # ----------------------------------------------------
73 | # ----------------------------------------------------
74 |
75 |
76 | def Affine_setup():
77 | tforms = {}
78 | tforms['random_affine'] = RandomAffine(rotation_range=30,
79 | translation_range=0.1)
80 | tforms['affine'] = Affine(th.FloatTensor([[0.9,0,0],[0,0.9,0]]))
81 | return tforms
82 |
83 | def Rotate_setup():
84 | tforms = {}
85 | tforms['random_rotate'] = RandomRotate(30)
86 | tforms['random_choice_rotate'] = RandomChoiceRotate([30,40,50])
87 | tforms['rotate'] = Rotate(30)
88 | return tforms
89 |
90 | def Translate_setup():
91 | tforms = {}
92 | tforms['random_translate'] = RandomTranslate(0.1)
93 | tforms['random_choice_translate'] = RandomChoiceTranslate([0.1,0.2])
94 | tforms['translate'] = Translate(0.3)
95 | return tforms
96 |
97 | def Shear_setup():
98 | tforms = {}
99 | tforms['random_shear'] = RandomShear(30)
100 | tforms['random_choice_shear'] = RandomChoiceShear([20,30,40])
101 | tforms['shear'] = Shear(25)
102 | return tforms
103 |
104 | def Zoom_setup():
105 | tforms = {}
106 | tforms['random_zoom'] = RandomZoom((0.8,1.2))
107 | tforms['random_choice_zoom'] = RandomChoiceZoom([0.8,0.9,1.1,1.2])
108 | tforms['zoom'] = Zoom(0.9)
109 | return tforms
110 |
111 | # ----------------------------------------------------
112 | # ----------------------------------------------------
113 |
114 | def test_affine_transforms_runtime(verbose=1):
115 | """
116 | Test that there are no runtime errors
117 | """
118 | ### MAKE TRANSFORMS ###
119 | tforms = {}
120 | tforms.update(Affine_setup())
121 | tforms.update(Rotate_setup())
122 | tforms.update(Translate_setup())
123 | tforms.update(Shear_setup())
124 | tforms.update(Zoom_setup())
125 |
126 | ### MAKE DATA
127 | images = {}
128 | images.update(gray2d_setup())
129 | images.update(multi_gray2d_setup())
130 | images.update(color2d_setup())
131 | images.update(multi_color2d_setup())
132 |
133 | successes = []
134 | failures = []
135 | for im_key, im_val in images.items():
136 | for tf_key, tf_val in tforms.items():
137 | try:
138 | if isinstance(im_val, (tuple,list)):
139 | tf_val(*im_val)
140 | else:
141 | tf_val(im_val)
142 | successes.append((im_key, tf_key))
143 |             except Exception:
144 | failures.append((im_key, tf_key))
145 |
146 | if verbose > 0:
147 | for k, v in failures:
148 | print('%s - %s' % (k, v))
149 |
150 | print('# SUCCESSES: ', len(successes))
151 | print('# FAILURES: ' , len(failures))
152 |
153 |
154 | if __name__=='__main__':
155 | test_affine_transforms_runtime()
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
--------------------------------------------------------------------------------
/tests/transforms/test_image_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for torchsample/transforms/image_transforms.py
3 | """
4 |
5 |
6 | import torch as th
7 |
8 | from torchsample.transforms import (Grayscale, RandomGrayscale,
9 | Gamma, RandomGamma, RandomChoiceGamma,
10 | Brightness, RandomBrightness, RandomChoiceBrightness,
11 | Saturation, RandomSaturation, RandomChoiceSaturation,
12 | Contrast, RandomContrast, RandomChoiceContrast)
13 |
14 | # ----------------------------------------------------
15 | # ----------------------------------------------------
16 |
17 | ## DATA SET ##
18 | def gray2d_setup():
19 | images = {}
20 |
21 | x = th.zeros(1,30,30)
22 | x[:,10:21,10:21] = 1
23 | images['gray_01'] = x
24 |
25 | x = th.zeros(1,30,40)
26 | x[:,10:21,10:21] = 1
27 | images['gray_02'] = x
28 |
29 | return images
30 |
31 | def multi_gray2d_setup():
32 | old_imgs = gray2d_setup()
33 | images = {}
34 | for k,v in old_imgs.items():
35 | images[k+'_2imgs'] = [v,v]
36 | images[k+'_3imgs'] = [v,v,v]
37 | images[k+'_4imgs'] = [v,v,v,v]
38 | return images
39 |
40 | def color2d_setup():
41 | images = {}
42 |
43 | x = th.zeros(3,30,30)
44 | x[:,10:21,10:21] = 1
45 | images['color_01'] = x
46 |
47 | x = th.zeros(3,30,40)
48 | x[:,10:21,10:21] = 1
49 | images['color_02'] = x
50 |
51 | return images
52 |
53 | def multi_color2d_setup():
54 | old_imgs = color2d_setup()
55 | images = {}
56 | for k,v in old_imgs.items():
57 | images[k+'_2imgs'] = [v,v]
58 | images[k+'_3imgs'] = [v,v,v]
59 | images[k+'_4imgs'] = [v,v,v,v]
60 | return images
61 |
62 | # ----------------------------------------------------
63 | # ----------------------------------------------------
64 |
65 | ## TFORMS SETUP ###
66 | def Grayscale_setup():
67 | tforms = {}
68 | tforms['grayscale_keepchannels'] = Grayscale(keep_channels=True)
69 | tforms['grayscale_dontkeepchannels'] = Grayscale(keep_channels=False)
70 |
71 | tforms['random_grayscale_nop'] = RandomGrayscale()
72 | tforms['random_grayscale_p_01'] = RandomGrayscale(0)
73 | tforms['random_grayscale_p_02'] = RandomGrayscale(0.5)
74 | tforms['random_grayscale_p_03'] = RandomGrayscale(1)
75 |
76 | return tforms
77 |
78 | def Gamma_setup():
79 | tforms = {}
80 | tforms['gamma_<1'] = Gamma(value=0.5)
81 | tforms['gamma_=1'] = Gamma(value=1.0)
82 | tforms['gamma_>1'] = Gamma(value=1.5)
83 | tforms['random_gamma_01'] = RandomGamma(0.5,1.5)
84 | tforms['random_gamma_02'] = RandomGamma(0.5,1.0)
85 | tforms['random_gamma_03'] = RandomGamma(1.0,1.5)
86 | tforms['random_choice_gamma_01'] = RandomChoiceGamma([0.5,1.0])
87 | tforms['random_choice_gamma_02'] = RandomChoiceGamma([0.5,1.0],p=[0.5,0.5])
88 | tforms['random_choice_gamma_03'] = RandomChoiceGamma([0.5,1.0],p=[0.2,0.8])
89 |
90 | return tforms
91 |
92 | def Brightness_setup():
93 | tforms = {}
94 | tforms['brightness_=-1'] = Brightness(value=-1)
95 | tforms['brightness_<0'] = Brightness(value=-0.5)
96 | tforms['brightness_=0'] = Brightness(value=0)
97 | tforms['brightness_>0'] = Brightness(value=0.5)
98 | tforms['brightness_=1'] = Brightness(value=1)
99 |
100 | tforms['random_brightness_01'] = RandomBrightness(-1,-0.5)
101 | tforms['random_brightness_02'] = RandomBrightness(-0.5,0)
102 | tforms['random_brightness_03'] = RandomBrightness(0,0.5)
103 | tforms['random_brightness_04'] = RandomBrightness(0.5,1)
104 |
105 | tforms['random_choice_brightness_01'] = RandomChoiceBrightness([-1,0,1])
106 | tforms['random_choice_brightness_02'] = RandomChoiceBrightness([-1,0,1],p=[0.2,0.5,0.3])
107 | tforms['random_choice_brightness_03'] = RandomChoiceBrightness([0,0,0,0],p=[0.25,0.5,0.25,0.25])
108 |
109 | return tforms
110 |
111 | def Saturation_setup():
112 | tforms = {}
113 | tforms['saturation_=-1'] = Saturation(-1)
114 | tforms['saturation_<0'] = Saturation(-0.5)
115 | tforms['saturation_=0'] = Saturation(0)
116 | tforms['saturation_>0'] = Saturation(0.5)
117 | tforms['saturation_=1'] = Saturation(1)
118 |
119 | tforms['random_saturation_01'] = RandomSaturation(-1,-0.5)
120 | tforms['random_saturation_02'] = RandomSaturation(-0.5,0)
121 | tforms['random_saturation_03'] = RandomSaturation(0,0.5)
122 | tforms['random_saturation_04'] = RandomSaturation(0.5,1)
123 |
124 | tforms['random_choice_saturation_01'] = RandomChoiceSaturation([-1,0,1])
125 | tforms['random_choice_saturation_02'] = RandomChoiceSaturation([-1,0,1],p=[0.2,0.5,0.3])
126 | tforms['random_choice_saturation_03'] = RandomChoiceSaturation([0,0,0,0],p=[0.25,0.5,0.25,0.25])
127 |
128 | return tforms
129 |
130 | def Contrast_setup():
131 | tforms = {}
132 | tforms['contrast_<<0'] = Contrast(-10)
133 | tforms['contrast_<0'] = Contrast(-1)
134 | tforms['contrast_=0'] = Contrast(0)
135 | tforms['contrast_>0'] = Contrast(1)
136 | tforms['contrast_>>0'] = Contrast(10)
137 |
138 | tforms['random_contrast_01'] = RandomContrast(-10,-1)
139 | tforms['random_contrast_02'] = RandomContrast(-1,0)
140 | tforms['random_contrast_03'] = RandomContrast(0,1)
141 | tforms['random_contrast_04'] = RandomContrast(1,10)
142 |
143 |     tforms['random_choice_contrast_01'] = RandomChoiceContrast([-1,0,1])
144 |     tforms['random_choice_contrast_02'] = RandomChoiceContrast([-10,0,10],p=[0.2,0.5,0.3])
145 |     tforms['random_choice_contrast_03'] = RandomChoiceContrast([1,1],p=[0.5,0.5])
146 |
147 | return tforms
148 |
149 | # ----------------------------------------------------
150 | # ----------------------------------------------------
151 |
152 | def test_image_transforms_runtime(verbose=1):
153 | """
154 | Test that there are no runtime errors
155 | """
156 | ### MAKE TRANSFORMS ###
157 | tforms = {}
158 | tforms.update(Gamma_setup())
159 | tforms.update(Brightness_setup())
160 | tforms.update(Saturation_setup())
161 | tforms.update(Contrast_setup())
162 |
163 | ### MAKE DATA ###
164 | images = {}
165 | images.update(gray2d_setup())
166 | images.update(multi_gray2d_setup())
167 | images.update(color2d_setup())
168 | images.update(multi_color2d_setup())
169 |
170 | successes = []
171 | failures = []
172 | for im_key, im_val in images.items():
173 | for tf_key, tf_val in tforms.items():
174 | try:
175 | if isinstance(im_val, (tuple,list)):
176 | tf_val(*im_val)
177 | else:
178 | tf_val(im_val)
179 | successes.append((im_key, tf_key))
180 |             except Exception:
181 | failures.append((im_key, tf_key))
182 |
183 | if verbose > 0:
184 | for k, v in failures:
185 | print('%s - %s' % (k, v))
186 |
187 | print('# SUCCESSES: ', len(successes))
188 | print('# FAILURES: ' , len(failures))
189 |
190 |
191 | if __name__=='__main__':
192 | test_image_transforms_runtime()
193 |
--------------------------------------------------------------------------------
/tests/transforms/test_tensor_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for torchsample/transforms/tensor_transforms.py
3 | """
4 |
5 |
6 | import torch as th
7 |
8 | from torchsample.transforms import (ToTensor,
9 | ToVariable,
10 | ToCuda,
11 | ToFile,
12 | ChannelsLast, HWC,
13 | ChannelsFirst, CHW,
14 | TypeCast,
15 | AddChannel,
16 | Transpose,
17 | RangeNormalize,
18 | StdNormalize,
19 | RandomCrop,
20 | SpecialCrop,
21 | Pad,
22 | RandomFlip,
23 | RandomOrder)
24 |
25 | # ----------------------------------------------------
26 |
27 | ## DATA SET ##
28 | def gray2d_setup():
29 | images = {}
30 |
31 | x = th.zeros(1,30,30)
32 | x[:,10:21,10:21] = 1
33 | images['gray_01'] = x
34 |
35 | x = th.zeros(1,30,40)
36 | x[:,10:21,10:21] = 1
37 | images['gray_02'] = x
38 | return images
39 |
40 | def multi_gray2d_setup():
41 | old_imgs = gray2d_setup()
42 | images = {}
43 | for k,v in old_imgs.items():
44 | images[k+'_2imgs'] = [v,v]
45 | images[k+'_3imgs'] = [v,v,v]
46 | images[k+'_4imgs'] = [v,v,v,v]
47 | return images
48 |
49 | def color2d_setup():
50 | images = {}
51 |
52 | x = th.zeros(3,30,30)
53 | x[:,10:21,10:21] = 1
54 | images['color_01'] = x
55 |
56 | x = th.zeros(3,30,40)
57 | x[:,10:21,10:21] = 1
58 | images['color_02'] = x
59 |
60 | return images
61 |
62 | def multi_color2d_setup():
63 | old_imgs = color2d_setup()
64 | images = {}
65 | for k,v in old_imgs.items():
66 | images[k+'_2imgs'] = [v,v]
67 | images[k+'_3imgs'] = [v,v,v]
68 | images[k+'_4imgs'] = [v,v,v,v]
69 | return images
70 | # ----------------------------------------------------
71 | # ----------------------------------------------------
72 |
73 | ## TFORMS SETUP ###
74 | def ToTensor_setup():
75 | tforms = {}
76 |
77 | tforms['totensor'] = ToTensor()
78 |
79 | return tforms
80 |
81 | def ToVariable_setup():
82 | tforms = {}
83 |
84 | tforms['tovariable'] = ToVariable()
85 |
86 | return tforms
87 |
88 | def ToCuda_setup():
89 | tforms = {}
90 |
91 | tforms['tocuda'] = ToCuda()
92 |
93 | return tforms
94 |
95 | def ToFile_setup():
96 | tforms = {}
97 |
98 | ROOT = '~/desktop/data/'
99 | tforms['tofile_npy'] = ToFile(root=ROOT, fmt='npy')
100 | tforms['tofile_pth'] = ToFile(root=ROOT, fmt='pth')
101 | tforms['tofile_jpg'] = ToFile(root=ROOT, fmt='jpg')
102 | tforms['tofile_png'] = ToFile(root=ROOT, fmt='png')
103 |
104 | return tforms
105 |
106 | def ChannelsLast_setup():
107 | tforms = {}
108 |
109 | tforms['channels_last'] = ChannelsLast()
110 | tforms['hwc'] = HWC()
111 |
112 | return tforms
113 |
114 | def ChannelsFirst_setup():
115 | tforms = {}
116 |
117 | tforms['channels_first'] = ChannelsFirst()
118 | tforms['chw'] = CHW()
119 |
120 | return tforms
121 |
122 | def TypeCast_setup():
123 | tforms = {}
124 |
125 | tforms['byte'] = TypeCast('byte')
126 | tforms['double'] = TypeCast('double')
127 | tforms['float'] = TypeCast('float')
128 | tforms['int'] = TypeCast('int')
129 | tforms['long'] = TypeCast('long')
130 | tforms['short'] = TypeCast('short')
131 |
132 | return tforms
133 |
134 | def AddChannel_setup():
135 | tforms = {}
136 |
137 | tforms['addchannel_axis0'] = AddChannel(axis=0)
138 | tforms['addchannel_axis1'] = AddChannel(axis=1)
139 | tforms['addchannel_axis2'] = AddChannel(axis=2)
140 |
141 | return tforms
142 |
143 | def Transpose_setup():
144 | tforms = {}
145 |
146 | tforms['transpose_01'] = Transpose(0, 1)
147 | tforms['transpose_02'] = Transpose(0, 2)
148 | tforms['transpose_10'] = Transpose(1, 0)
149 | tforms['transpose_12'] = Transpose(1, 2)
150 | tforms['transpose_20'] = Transpose(2, 0)
151 | tforms['transpose_21'] = Transpose(2, 1)
152 |
153 | return tforms
154 |
155 | def RangeNormalize_setup():
156 | tforms = {}
157 |
158 | tforms['rangenorm_01'] = RangeNormalize(0, 1)
159 | tforms['rangenorm_-11'] = RangeNormalize(-1, 1)
160 | tforms['rangenorm_-33'] = RangeNormalize(-3, 3)
161 | tforms['rangenorm_02'] = RangeNormalize(0, 2)
162 |
163 | return tforms
164 |
165 | def StdNormalize_setup():
166 | tforms = {}
167 |
168 | tforms['stdnorm'] = StdNormalize()
169 |
170 | return tforms
171 |
172 | def RandomCrop_setup():
173 | tforms = {}
174 |
175 | tforms['randomcrop_1010'] = RandomCrop((10,10))
176 | tforms['randomcrop_510'] = RandomCrop((5,10))
177 | tforms['randomcrop_105'] = RandomCrop((10,5))
178 | tforms['randomcrop_99'] = RandomCrop((9,9))
179 | tforms['randomcrop_79'] = RandomCrop((7,9))
180 | tforms['randomcrop_97'] = RandomCrop((9,7))
181 |
182 | return tforms
183 |
184 | def SpecialCrop_setup():
185 | tforms = {}
186 |
187 | tforms['specialcrop_0_1010'] = SpecialCrop((10,10),0)
188 | tforms['specialcrop_0_510'] = SpecialCrop((5,10),0)
189 | tforms['specialcrop_0_105'] = SpecialCrop((10,5),0)
190 | tforms['specialcrop_0_99'] = SpecialCrop((9,9),0)
191 | tforms['specialcrop_0_79'] = SpecialCrop((7,9),0)
192 | tforms['specialcrop_0_97'] = SpecialCrop((9,7),0)
193 |
194 | tforms['specialcrop_1_1010'] = SpecialCrop((10,10),1)
195 | tforms['specialcrop_1_510'] = SpecialCrop((5,10),1)
196 | tforms['specialcrop_1_105'] = SpecialCrop((10,5),1)
197 | tforms['specialcrop_1_99'] = SpecialCrop((9,9),1)
198 | tforms['specialcrop_1_79'] = SpecialCrop((7,9),1)
199 | tforms['specialcrop_1_97'] = SpecialCrop((9,7),1)
200 |
201 | tforms['specialcrop_2_1010'] = SpecialCrop((10,10),2)
202 | tforms['specialcrop_2_510'] = SpecialCrop((5,10),2)
203 | tforms['specialcrop_2_105'] = SpecialCrop((10,5),2)
204 | tforms['specialcrop_2_99'] = SpecialCrop((9,9),2)
205 | tforms['specialcrop_2_79'] = SpecialCrop((7,9),2)
206 | tforms['specialcrop_2_97'] = SpecialCrop((9,7),2)
207 |
208 | tforms['specialcrop_3_1010'] = SpecialCrop((10,10),3)
209 | tforms['specialcrop_3_510'] = SpecialCrop((5,10),3)
210 | tforms['specialcrop_3_105'] = SpecialCrop((10,5),3)
211 | tforms['specialcrop_3_99'] = SpecialCrop((9,9),3)
212 | tforms['specialcrop_3_79'] = SpecialCrop((7,9),3)
213 | tforms['specialcrop_3_97'] = SpecialCrop((9,7),3)
214 |
215 | tforms['specialcrop_4_1010'] = SpecialCrop((10,10),4)
216 | tforms['specialcrop_4_510'] = SpecialCrop((5,10),4)
217 | tforms['specialcrop_4_105'] = SpecialCrop((10,5),4)
218 | tforms['specialcrop_4_99'] = SpecialCrop((9,9),4)
219 | tforms['specialcrop_4_79'] = SpecialCrop((7,9),4)
220 | tforms['specialcrop_4_97'] = SpecialCrop((9,7),4)
221 | return tforms
222 |
223 | def Pad_setup():
224 | tforms = {}
225 |
226 | tforms['pad_4040'] = Pad((40,40))
227 | tforms['pad_3040'] = Pad((30,40))
228 | tforms['pad_4030'] = Pad((40,30))
229 | tforms['pad_3939'] = Pad((39,39))
230 | tforms['pad_3941'] = Pad((39,41))
231 | tforms['pad_4139'] = Pad((41,39))
232 | tforms['pad_4138'] = Pad((41,38))
233 | tforms['pad_3841'] = Pad((38,41))
234 |
235 | return tforms
236 |
237 | def RandomFlip_setup():
238 | tforms = {}
239 |
240 | tforms['randomflip_h_01'] = RandomFlip(h=True, v=False)
241 | tforms['randomflip_h_02'] = RandomFlip(h=True, v=False, p=0)
242 | tforms['randomflip_h_03'] = RandomFlip(h=True, v=False, p=1)
243 | tforms['randomflip_h_04'] = RandomFlip(h=True, v=False, p=0.3)
244 | tforms['randomflip_v_01'] = RandomFlip(h=False, v=True)
245 | tforms['randomflip_v_02'] = RandomFlip(h=False, v=True, p=0)
246 | tforms['randomflip_v_03'] = RandomFlip(h=False, v=True, p=1)
247 | tforms['randomflip_v_04'] = RandomFlip(h=False, v=True, p=0.3)
248 | tforms['randomflip_hv_01'] = RandomFlip(h=True, v=True)
249 | tforms['randomflip_hv_02'] = RandomFlip(h=True, v=True, p=0)
250 | tforms['randomflip_hv_03'] = RandomFlip(h=True, v=True, p=1)
251 | tforms['randomflip_hv_04'] = RandomFlip(h=True, v=True, p=0.3)
252 | return tforms
253 |
254 | def RandomOrder_setup():
255 | tforms = {}
256 |
257 | tforms['randomorder'] = RandomOrder()
258 |
259 | return tforms
260 |
261 | # ----------------------------------------------------
262 | # ----------------------------------------------------
263 |
264 | def test_tensor_transforms_runtime(verbose=1):
265 | ### MAKE TRANSFORMS ###
266 | tforms = {}
267 | tforms.update(ToTensor_setup())
268 | tforms.update(ToVariable_setup())
269 | tforms.update(ToCuda_setup())
270 | #tforms.update(ToFile_setup())
271 | tforms.update(ChannelsLast_setup())
272 | tforms.update(ChannelsFirst_setup())
273 | tforms.update(TypeCast_setup())
274 | tforms.update(AddChannel_setup())
275 | tforms.update(Transpose_setup())
276 | tforms.update(RangeNormalize_setup())
277 | tforms.update(StdNormalize_setup())
278 | tforms.update(RandomCrop_setup())
279 | tforms.update(SpecialCrop_setup())
280 | tforms.update(Pad_setup())
281 | tforms.update(RandomFlip_setup())
282 | tforms.update(RandomOrder_setup())
283 |
284 |
285 | ### MAKE DATA
286 | images = {}
287 | images.update(gray2d_setup())
288 | images.update(multi_gray2d_setup())
289 | images.update(color2d_setup())
290 | images.update(multi_color2d_setup())
291 |
292 |     successes = []
293 | failures = []
294 | for im_key, im_val in images.items():
295 | for tf_key, tf_val in tforms.items():
296 | try:
297 | if isinstance(im_val, (tuple,list)):
298 | tf_val(*im_val)
299 | else:
300 | tf_val(im_val)
301 | successes.append((im_key, tf_key))
302 |             except Exception:
303 | failures.append((im_key, tf_key))
304 |
305 | if verbose > 0:
306 | for k, v in failures:
307 | print('%s - %s' % (k, v))
308 |
309 | print('# SUCCESSES: ', len(successes))
310 | print('# FAILURES: ' , len(failures))
311 |
312 |
313 | if __name__ == '__main__':
314 |     test_tensor_transforms_runtime()
315 |
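316 |
317 | # The loop above is a pure smoke test: it only checks that nothing raises.
318 | # Below is a sketch of a stronger, value-level check, assuming RangeNormalize
319 | # rescales each input to the target range using the tensor's observed
320 | # min/max; drop this check if the transform's semantics differ.
321 | def test_range_normalize_bounds():
322 |     x = th.rand(3, 30, 30) * 10
323 |     out = RangeNormalize(0, 1)(x)
324 |     assert abs(out.min() - 0.0) < 1e-5
325 |     assert abs(out.max() - 1.0) < 1e-5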
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch as th
4 |
5 | def get_test_data(num_train=1000, num_test=500,
6 | input_shape=(10,), output_shape=(2,),
7 | classification=True, num_classes=2):
8 | """Generates test data to train a model on.
9 |
10 |     classification=True overrides output_shape
11 |     (i.e. output_shape is set to (1,)) and the output
12 |     consists of integers in [0, num_classes-1].
13 |
14 | Otherwise: float output with shape output_shape.
15 | """
16 | samples = num_train + num_test
17 | if classification:
18 | y = np.random.randint(0, num_classes, size=(samples,))
19 | X = np.zeros((samples,) + input_shape)
20 | for i in range(samples):
21 | X[i] = np.random.normal(loc=y[i], scale=0.7, size=input_shape)
22 | else:
23 | y_loc = np.random.random((samples,))
24 | X = np.zeros((samples,) + input_shape)
25 | y = np.zeros((samples,) + output_shape)
26 | for i in range(samples):
27 | X[i] = np.random.normal(loc=y_loc[i], scale=0.7, size=input_shape)
28 | y[i] = np.random.normal(loc=y_loc[i], scale=0.7, size=output_shape)
29 |
30 | return (th.from_numpy(X[:num_train]), th.from_numpy(y[:num_train])), \
31 | (th.from_numpy(X[num_train:]), th.from_numpy(y[num_train:]))
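32 |
33 |
34 | # Minimal usage sketch (nothing assumed beyond the function above): build a
35 | # toy classification split and inspect the shapes of the returned tensors.
36 | if __name__ == '__main__':
37 |     (x_train, y_train), (x_test, y_test) = get_test_data()
38 |     print(x_train.size(), y_train.size())  # (1000, 10) and (1000,)
39 |     print(x_test.size(), y_test.size())    # (500, 10) and (500,)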
--------------------------------------------------------------------------------
/torchsample/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from .version import __version__
5 |
6 | from .datasets import *
7 | from .samplers import *
8 |
9 | #from .callbacks import *
10 | #from .constraints import *
11 | #from .regularizers import *
12 |
13 | #from . import functions
14 | #from . import transforms
15 | from . import modules
16 |
--------------------------------------------------------------------------------
/torchsample/callbacks.py:
--------------------------------------------------------------------------------
1 | """
2 | SuperModule Callbacks
3 | """
4 |
5 | from __future__ import absolute_import
6 | from __future__ import print_function
7 |
8 | from collections import OrderedDict
9 | from collections.abc import Iterable  # `collections.Iterable` was removed in Python 3.10
10 | import warnings
11 |
12 | import os
13 | import csv
14 | import time
15 | from tempfile import NamedTemporaryFile
16 | import shutil
17 | import datetime
18 | import numpy as np
19 |
20 | from tqdm import tqdm
21 |
22 | import torch as th
23 |
24 |
25 | def _get_current_time():
26 | return datetime.datetime.now().strftime("%B %d, %Y - %I:%M%p")
27 |
28 | class CallbackContainer(object):
29 | """
30 | Container holding a list of callbacks.
31 | """
32 | def __init__(self, callbacks=None, queue_length=10):
33 | callbacks = callbacks or []
34 | self.callbacks = [c for c in callbacks]
35 | self.queue_length = queue_length
36 |
37 | def append(self, callback):
38 | self.callbacks.append(callback)
39 |
40 | def set_params(self, params):
41 | for callback in self.callbacks:
42 | callback.set_params(params)
43 |
44 | def set_trainer(self, trainer):
45 | self.trainer = trainer
46 | for callback in self.callbacks:
47 | callback.set_trainer(trainer)
48 |
49 | def on_epoch_begin(self, epoch, logs=None):
50 | logs = logs or {}
51 | for callback in self.callbacks:
52 | callback.on_epoch_begin(epoch, logs)
53 |
54 | def on_epoch_end(self, epoch, logs=None):
55 | logs = logs or {}
56 | for callback in self.callbacks:
57 | callback.on_epoch_end(epoch, logs)
58 |
59 | def on_batch_begin(self, batch, logs=None):
60 | logs = logs or {}
61 | for callback in self.callbacks:
62 | callback.on_batch_begin(batch, logs)
63 |
64 | def on_batch_end(self, batch, logs=None):
65 | logs = logs or {}
66 | for callback in self.callbacks:
67 | callback.on_batch_end(batch, logs)
68 |
69 | def on_train_begin(self, logs=None):
70 | logs = logs or {}
71 | logs['start_time'] = _get_current_time()
72 | for callback in self.callbacks:
73 | callback.on_train_begin(logs)
74 |
75 | def on_train_end(self, logs=None):
76 | logs = logs or {}
77 |         logs['final_loss'] = self.trainer.history.epoch_losses[-1]
78 |         logs['best_loss'] = min(self.trainer.history.epoch_losses)
79 | logs['stop_time'] = _get_current_time()
80 | for callback in self.callbacks:
81 | callback.on_train_end(logs)
82 |
83 |
84 | class Callback(object):
85 | """
86 | Abstract base class used to build new callbacks.
87 | """
88 |
89 | def __init__(self):
90 | pass
91 |
92 | def set_params(self, params):
93 | self.params = params
94 |
95 | def set_trainer(self, model):
96 | self.trainer = model
97 |
98 | def on_epoch_begin(self, epoch, logs=None):
99 | pass
100 |
101 | def on_epoch_end(self, epoch, logs=None):
102 | pass
103 |
104 | def on_batch_begin(self, batch, logs=None):
105 | pass
106 |
107 | def on_batch_end(self, batch, logs=None):
108 | pass
109 |
110 | def on_train_begin(self, logs=None):
111 | pass
112 |
113 | def on_train_end(self, logs=None):
114 | pass
115 |
116 |
117 | class TQDM(Callback):
118 |
119 | def __init__(self):
120 | """
121 | TQDM Progress Bar callback
122 |
123 | This callback is automatically applied to
124 | every SuperModule if verbose > 0
125 | """
126 | self.progbar = None
127 | super(TQDM, self).__init__()
128 |
129 | def __enter__(self):
130 | return self
131 |
132 | def __exit__(self, exc_type, exc_val, exc_tb):
133 |         # make sure the progress bar gets closed
134 | if self.progbar is not None:
135 | self.progbar.close()
136 |
137 | def on_train_begin(self, logs):
138 | self.train_logs = logs
139 |
140 | def on_epoch_begin(self, epoch, logs=None):
141 | try:
142 | self.progbar = tqdm(total=self.train_logs['num_batches'],
143 | unit=' batches')
144 | self.progbar.set_description('Epoch %i/%i' %
145 | (epoch+1, self.train_logs['num_epoch']))
146 |         except Exception:
147 | pass
148 |
149 | def on_epoch_end(self, epoch, logs=None):
150 | log_data = {key: '%.04f' % value for key, value in self.trainer.history.batch_metrics.items()}
151 | for k, v in logs.items():
152 | if k.endswith('metric'):
153 | log_data[k.split('_metric')[0]] = '%.02f' % v
154 | else:
155 | log_data[k] = v
156 | self.progbar.set_postfix(log_data)
157 | self.progbar.update()
158 | self.progbar.close()
159 |
160 | def on_batch_begin(self, batch, logs=None):
161 | self.progbar.update(1)
162 |
163 | def on_batch_end(self, batch, logs=None):
164 | log_data = {key: '%.04f' % value for key, value in self.trainer.history.batch_metrics.items()}
165 | for k, v in logs.items():
166 | if k.endswith('metric'):
167 | log_data[k.split('_metric')[0]] = '%.02f' % v
168 | self.progbar.set_postfix(log_data)
169 |
170 |
171 | class History(Callback):
172 | """
173 | Callback that records events into a `History` object.
174 |
175 | This callback is automatically applied to
176 | every SuperModule.
177 | """
178 | def __init__(self, model):
179 | super(History, self).__init__()
180 | self.samples_seen = 0.
181 | self.trainer = model
182 |
183 | def on_train_begin(self, logs=None):
184 | self.epoch_metrics = {
185 | 'loss': []
186 | }
187 | self.batch_size = logs['batch_size']
188 | self.has_val_data = logs['has_val_data']
189 | self.has_regularizers = logs['has_regularizers']
190 | if self.has_val_data:
191 | self.epoch_metrics['val_loss'] = []
192 | if self.has_regularizers:
193 | self.epoch_metrics['reg_loss'] = []
194 |
195 | def on_epoch_begin(self, epoch, logs=None):
196 | self.batch_metrics = {
197 | 'loss': 0.
198 | }
199 | if self.has_regularizers:
200 | self.batch_metrics['reg_loss'] = 0.
201 | self.samples_seen = 0.
202 |
203 | def on_epoch_end(self, epoch, logs=None):
204 | #for k in self.batch_metrics:
205 | # k_log = k.split('_metric')[0]
206 | # self.epoch_metrics.update(self.batch_metrics)
207 | # TODO
208 | pass
209 |
210 | def on_batch_end(self, batch, logs=None):
211 | for k in self.batch_metrics:
212 | self.batch_metrics[k] = (self.samples_seen*self.batch_metrics[k] + logs[k]*self.batch_size) / (self.samples_seen+self.batch_size)
213 | self.samples_seen += self.batch_size
214 |
215 | def __getitem__(self, name):
216 | return self.epoch_metrics[name]
217 |
218 | def __repr__(self):
219 | return str(self.epoch_metrics)
220 |
221 | def __str__(self):
222 | return str(self.epoch_metrics)
223 |
224 |
225 | class ModelCheckpoint(Callback):
226 | """
227 | Model Checkpoint to save model weights during training
228 |
229 | save_checkpoint({
230 | 'epoch': epoch + 1,
231 | 'arch': args.arch,
232 | 'state_dict': model.state_dict(),
233 | 'best_prec1': best_prec1,
234 | 'optimizer' : optimizer.state_dict(),
235 |         })
236 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
237 | th.save(state, filename)
238 | if is_best:
239 | shutil.copyfile(filename, 'model_best.pth.tar')
240 |
241 | """
242 |
243 | def __init__(self,
244 | directory,
245 | filename='ckpt.pth.tar',
246 | monitor='val_loss',
247 | save_best_only=False,
248 | save_weights_only=True,
249 | max_save=-1,
250 | verbose=0):
251 | """
252 | Model Checkpoint to save model weights during training
253 |
254 | Arguments
255 | ---------
256 |         directory : string
257 |             directory in which the checkpoint file will be saved
258 |         filename : string
259 |             name of the checkpoint file; may contain '{epoch}' and '{loss}' placeholders which are filled in before saving
260 | monitor : string in {'val_loss', 'loss'}
261 | whether to monitor train or val loss
262 | save_best_only : boolean
263 | whether to only save if monitored value has improved
264 |         save_weights_only : boolean
265 | whether to save entire model or just weights
266 | NOTE: only `True` is supported at the moment
267 | max_save : integer > 0 or -1
268 | the max number of models to save. Older model checkpoints
269 | will be overwritten if necessary. Set equal to -1 to have
270 | no limit
271 | verbose : integer in {0, 1}
272 | verbosity
273 | """
274 | if directory.startswith('~'):
275 | directory = os.path.expanduser(directory)
276 | self.directory = directory
277 | self.filename = filename
278 | self.file = os.path.join(self.directory, self.filename)
279 | self.monitor = monitor
280 | self.save_best_only = save_best_only
281 | self.save_weights_only = save_weights_only
282 | self.max_save = max_save
283 | self.verbose = verbose
284 |
285 | if self.max_save > 0:
286 | self.old_files = []
287 |
288 | # mode = 'min' only supported
289 | self.best_loss = float('inf')
290 | super(ModelCheckpoint, self).__init__()
291 |
292 | def save_checkpoint(self, epoch, file, is_best=False):
293 | th.save({
294 | 'epoch': epoch + 1,
295 | #'arch': args.arch,
296 | 'state_dict': self.trainer.model.state_dict(),
297 | #'best_prec1': best_prec1,
298 | 'optimizer' : self.trainer._optimizer.state_dict(),
299 | #'loss':{},
300 | # #'regularizers':{},
301 | # #'constraints':{},
302 | # #'initializers':{},
303 | # #'metrics':{},
304 | # #'val_loss':{}
305 | }, file)
306 | if is_best:
307 | shutil.copyfile(file, 'model_best.pth.tar')
308 |
309 | def on_epoch_end(self, epoch, logs=None):
310 |
311 | file = self.file.format(epoch='%03i'%(epoch+1),
312 | loss='%0.4f'%logs[self.monitor])
313 | if self.save_best_only:
314 | current_loss = logs.get(self.monitor)
315 | if current_loss is None:
316 | pass
317 | else:
318 | if current_loss < self.best_loss:
319 | if self.verbose > 0:
320 |                     print('\nEpoch %i: improved from %0.4f to %0.4f; saving model to %s' %
321 | (epoch+1, self.best_loss, current_loss, file))
322 | self.best_loss = current_loss
323 | #if self.save_weights_only:
324 | #else:
325 | self.save_checkpoint(epoch, file)
326 | if self.max_save > 0:
327 | if len(self.old_files) == self.max_save:
328 | try:
329 | os.remove(self.old_files[0])
330 |                         except Exception:
331 | pass
332 | self.old_files = self.old_files[1:]
333 | self.old_files.append(file)
334 | else:
335 | if self.verbose > 0:
336 | print('\nEpoch %i: saving model to %s' % (epoch+1, file))
337 | self.save_checkpoint(epoch, file)
338 | if self.max_save > 0:
339 | if len(self.old_files) == self.max_save:
340 | try:
341 | os.remove(self.old_files[0])
342 |                         except Exception:
343 | pass
344 | self.old_files = self.old_files[1:]
345 | self.old_files.append(file)
346 |
347 |
348 | class EarlyStopping(Callback):
349 | """
350 | Early Stopping to terminate training early under certain conditions
351 | """
352 |
353 | def __init__(self,
354 | monitor='val_loss',
355 | min_delta=0,
356 | patience=5):
357 | """
358 | EarlyStopping callback to exit the training loop if training or
359 | validation loss does not improve by a certain amount for a certain
360 | number of epochs
361 |
362 | Arguments
363 | ---------
364 | monitor : string in {'val_loss', 'loss'}
365 | whether to monitor train or val loss
366 | min_delta : float
367 | minimum change in monitored value to qualify as improvement.
368 | This number should be positive.
369 | patience : integer
370 |             number of epochs to wait for improvement before terminating;
371 |             the counter is reset after each improvement
372 | """
373 | self.monitor = monitor
374 | self.min_delta = min_delta
375 | self.patience = patience
376 | self.wait = 0
377 |         self.best_loss = 1e15
378 | self.stopped_epoch = 0
379 | super(EarlyStopping, self).__init__()
380 |
381 | def on_train_begin(self, logs=None):
382 | self.wait = 0
383 | self.best_loss = 1e15
384 |
385 | def on_epoch_end(self, epoch, logs=None):
386 | current_loss = logs.get(self.monitor)
387 | if current_loss is None:
388 | pass
389 | else:
390 | if (current_loss - self.best_loss) < -self.min_delta:
391 | self.best_loss = current_loss
392 | self.wait = 1
393 | else:
394 | if self.wait >= self.patience:
395 | self.stopped_epoch = epoch + 1
396 | self.trainer._stop_training = True
397 | self.wait += 1
398 |
399 | def on_train_end(self, logs):
400 | if self.stopped_epoch > 0:
401 | print('\nTerminated Training for Early Stopping at Epoch %04i' %
402 | (self.stopped_epoch))
403 |
404 |
405 | class LRScheduler(Callback):
406 | """
407 | Schedule the learning rate according to some function of the
408 | current epoch index, current learning rate, and current train/val loss.
409 | """
410 |
411 | def __init__(self, schedule):
412 | """
413 | LearningRateScheduler callback to adapt the learning rate
414 | according to some function
415 |
416 | Arguments
417 | ---------
418 |         schedule : callable or dict
419 |             if callable, it is called as schedule(epoch, current_lrs, **logs)
420 |             and should return a list with one learning rate per
421 |             optimizer.param_group (a bare number is treated as a one-element
422 |             list). The logs hold epoch stats such as mean training and validation loss.
423 |             if dict, it maps epoch bounds to learning rates; keys < 1.0 are read as cumulative fractions of num_epoch, larger keys as absolute epoch indices.
424 | """
425 | if isinstance(schedule, dict):
426 | schedule = self.schedule_from_dict
427 | self.schedule_dict = schedule
428 |             if any([k < 1.0 for k in schedule.keys()]):
429 |                 self.fractional_bounds = True
430 |             else:
431 |                 self.fractional_bounds = False
432 | self.schedule = schedule
433 | super(LRScheduler, self).__init__()
434 |
435 |     def schedule_from_dict(self, epoch, current_lrs=None, **logs):
436 | for epoch_bound, learn_rate in self.schedule_dict.items():
437 | # epoch_bound is in units of "epochs"
438 | if not self.fractional_bounds:
439 |                 if epoch <= epoch_bound:
440 | return learn_rate
441 | # epoch_bound is in units of "cumulative percent of epochs"
442 | else:
443 | if epoch <= epoch_bound*logs['num_epoch']:
444 | return learn_rate
445 |         warnings.warn('Check the keys in the schedule dict; returning the last value')
446 | return learn_rate
447 |
448 | def on_epoch_begin(self, epoch, logs=None):
449 | current_lrs = [p['lr'] for p in self.trainer._optimizer.param_groups]
450 | lr_list = self.schedule(epoch, current_lrs, **logs)
451 | if not isinstance(lr_list, list):
452 | lr_list = [lr_list]
453 |
454 | for param_group, lr_change in zip(self.trainer._optimizer.param_groups, lr_list):
455 | param_group['lr'] = lr_change
456 |
457 |
458 | class ReduceLROnPlateau(Callback):
459 | """
460 | Reduce the learning rate if the train or validation loss plateaus
461 | """
462 |
463 | def __init__(self,
464 | monitor='val_loss',
465 | factor=0.1,
466 | patience=10,
467 | epsilon=0,
468 | cooldown=0,
469 | min_lr=0,
470 | verbose=0):
471 | """
472 | Reduce the learning rate if the train or validation loss plateaus
473 |
474 | Arguments
475 | ---------
476 | monitor : string in {'loss', 'val_loss'}
477 | which metric to monitor
478 |         factor : float
479 | factor to decrease learning rate by
480 | patience : integer
481 | number of epochs to wait for loss improvement before reducing lr
482 | epsilon : float
483 | how much improvement must be made to reset patience
484 | cooldown : integer
485 | number of epochs to cooldown after a lr reduction
486 | min_lr : float
487 | minimum value to ever let the learning rate decrease to
488 | verbose : integer
489 | whether to print reduction to console
490 | """
491 | self.monitor = monitor
492 | if factor >= 1.0:
493 | raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.')
494 | self.factor = factor
495 | self.min_lr = min_lr
496 | self.epsilon = epsilon
497 | self.patience = patience
498 | self.verbose = verbose
499 | self.cooldown = cooldown
500 | self.cooldown_counter = 0
501 | self.wait = 0
502 | self.best_loss = 1e15
503 | self._reset()
504 | super(ReduceLROnPlateau, self).__init__()
505 |
506 | def _reset(self):
507 | """
508 | Reset the wait and cooldown counters
509 | """
510 | self.monitor_op = lambda a, b: (a - b) < -self.epsilon
511 | self.best_loss = 1e15
512 | self.cooldown_counter = 0
513 | self.wait = 0
514 |
515 | def on_train_begin(self, logs=None):
516 | self._reset()
517 |
518 | def on_epoch_end(self, epoch, logs=None):
519 | logs = logs or {}
520 | logs['lr'] = [p['lr'] for p in self.trainer._optimizer.param_groups]
521 | current_loss = logs.get(self.monitor)
522 | if current_loss is None:
523 | pass
524 | else:
525 | # if in cooldown phase
526 | if self.cooldown_counter > 0:
527 | self.cooldown_counter -= 1
528 | self.wait = 0
529 | # if loss improved, grab new loss and reset wait counter
530 | if self.monitor_op(current_loss, self.best_loss):
531 | self.best_loss = current_loss
532 | self.wait = 0
533 |             # loss didn't improve, and not in cooldown phase
534 | elif not (self.cooldown_counter > 0):
535 | if self.wait >= self.patience:
536 | for p in self.trainer._optimizer.param_groups:
537 | old_lr = p['lr']
538 | if old_lr > self.min_lr + 1e-4:
539 | new_lr = old_lr * self.factor
540 | new_lr = max(new_lr, self.min_lr)
541 | if self.verbose > 0:
542 | print('\nEpoch %05d: reducing lr from %0.3f to %0.3f' %
543 | (epoch, old_lr, new_lr))
544 | p['lr'] = new_lr
545 | self.cooldown_counter = self.cooldown
546 | self.wait = 0
547 | self.wait += 1
548 |
549 |
550 | class CSVLogger(Callback):
551 | """
552 | Logs epoch-level metrics to a CSV file
553 | """
554 |
555 | def __init__(self,
556 | file,
557 | separator=',',
558 | append=False):
559 | """
560 | Logs epoch-level metrics to a CSV file
561 |
562 | Arguments
563 | ---------
564 | file : string
565 | path to csv file
566 | separator : string
567 | delimiter for file
568 |         append : boolean
569 | whether to append result to existing file or make new file
570 | """
571 | self.file = file
572 | self.sep = separator
573 | self.append = append
574 | self.writer = None
575 | self.keys = None
576 | self.append_header = True
577 | super(CSVLogger, self).__init__()
578 |
579 | def on_train_begin(self, logs=None):
580 | if self.append:
581 | if os.path.exists(self.file):
582 | with open(self.file) as f:
583 | self.append_header = not bool(len(f.readline()))
584 | self.csv_file = open(self.file, 'a')
585 | else:
586 | self.csv_file = open(self.file, 'w')
587 |
588 | def on_epoch_end(self, epoch, logs=None):
589 | logs = logs or {}
590 | RK = {'num_batches', 'num_epoch'}
591 |
592 | def handle_value(k):
593 | is_zero_dim_tensor = isinstance(k, th.Tensor) and k.dim() == 0
594 | if isinstance(k, Iterable) and not is_zero_dim_tensor:
595 | return '"[%s]"' % (', '.join(map(str, k)))
596 | else:
597 | return k
598 |
599 | if not self.writer:
600 | self.keys = sorted(logs.keys())
601 |
602 | class CustomDialect(csv.excel):
603 | delimiter = self.sep
604 |
605 | self.writer = csv.DictWriter(self.csv_file,
606 | fieldnames=['epoch'] + [k for k in self.keys if k not in RK],
607 | dialect=CustomDialect)
608 | if self.append_header:
609 | self.writer.writeheader()
610 |
611 | row_dict = OrderedDict({'epoch': epoch})
612 | row_dict.update((key, handle_value(logs[key])) for key in self.keys if key not in RK)
613 | self.writer.writerow(row_dict)
614 | self.csv_file.flush()
615 |
616 | def on_train_end(self, logs=None):
617 | self.csv_file.close()
618 | self.writer = None
619 |
620 |
621 | class ExperimentLogger(Callback):
622 |
623 | def __init__(self,
624 | directory,
625 | filename='Experiment_Logger.csv',
626 | save_prefix='Model_',
627 | separator=',',
628 | append=True):
629 |
630 | self.directory = directory
631 | self.filename = filename
632 | self.file = os.path.join(self.directory, self.filename)
633 | self.save_prefix = save_prefix
634 | self.sep = separator
635 | self.append = append
636 | self.writer = None
637 | self.keys = None
638 | self.append_header = True
639 | super(ExperimentLogger, self).__init__()
640 |
641 | def on_train_begin(self, logs=None):
642 | if self.append:
643 | open_type = 'a'
644 | else:
645 | open_type = 'w'
646 |
647 | # if append is True, find whether the file already has header
648 | num_lines = 0
649 | if self.append:
650 | if os.path.exists(self.file):
651 | with open(self.file) as f:
652 | for num_lines, l in enumerate(f):
653 | pass
654 | # if header exists, DONT append header again
655 | with open(self.file) as f:
656 | self.append_header = not bool(len(f.readline()))
657 |
658 | model_idx = num_lines
659 | REJECT_KEYS={'has_validation_data'}
660 | MODEL_NAME = self.save_prefix + str(model_idx) # figure out how to get model name
661 | self.row_dict = OrderedDict({'model': MODEL_NAME})
662 | self.keys = sorted(logs.keys())
663 | for k in self.keys:
664 | if k not in REJECT_KEYS:
665 | self.row_dict[k] = logs[k]
666 |
667 | class CustomDialect(csv.excel):
668 | delimiter = self.sep
669 |
670 | with open(self.file, open_type) as csv_file:
671 | writer = csv.DictWriter(csv_file,
672 | fieldnames=['model'] + [k for k in self.keys if k not in REJECT_KEYS],
673 | dialect=CustomDialect)
674 | if self.append_header:
675 | writer.writeheader()
676 |
677 | writer.writerow(self.row_dict)
678 | csv_file.flush()
679 |
680 | def on_train_end(self, logs=None):
681 | REJECT_KEYS={'has_validation_data'}
682 | row_dict = self.row_dict
683 |
684 | class CustomDialect(csv.excel):
685 | delimiter = self.sep
686 |
687 | temp_file = NamedTemporaryFile(delete=False, mode='w')
688 | with open(self.file, 'r') as csv_file, temp_file:
689 | reader = csv.DictReader(csv_file,
690 | fieldnames=['model'] + [k for k in self.keys if k not in REJECT_KEYS],
691 | dialect=CustomDialect)
692 | writer = csv.DictWriter(temp_file,
693 | fieldnames=['model'] + [k for k in self.keys if k not in REJECT_KEYS],
694 | dialect=CustomDialect)
695 | for row_idx, row in enumerate(reader):
696 | if row_idx == 0:
697 | # re-write header with on_train_end's metrics
698 | pass
699 | if row['model'] == self.row_dict['model']:
700 | writer.writerow(row_dict)
701 | else:
702 | writer.writerow(row)
703 | shutil.move(temp_file.name, self.file)
704 |
705 |
706 | class LambdaCallback(Callback):
707 | """
708 | Callback for creating simple, custom callbacks on-the-fly.
709 | """
710 | def __init__(self,
711 | on_epoch_begin=None,
712 | on_epoch_end=None,
713 | on_batch_begin=None,
714 | on_batch_end=None,
715 | on_train_begin=None,
716 | on_train_end=None,
717 | **kwargs):
718 | super(LambdaCallback, self).__init__()
719 | self.__dict__.update(kwargs)
720 | if on_epoch_begin is not None:
721 | self.on_epoch_begin = on_epoch_begin
722 | else:
723 | self.on_epoch_begin = lambda epoch, logs: None
724 | if on_epoch_end is not None:
725 | self.on_epoch_end = on_epoch_end
726 | else:
727 | self.on_epoch_end = lambda epoch, logs: None
728 | if on_batch_begin is not None:
729 | self.on_batch_begin = on_batch_begin
730 | else:
731 | self.on_batch_begin = lambda batch, logs: None
732 | if on_batch_end is not None:
733 | self.on_batch_end = on_batch_end
734 | else:
735 | self.on_batch_end = lambda batch, logs: None
736 | if on_train_begin is not None:
737 | self.on_train_begin = on_train_begin
738 | else:
739 | self.on_train_begin = lambda logs: None
740 | if on_train_end is not None:
741 | self.on_train_end = on_train_end
742 | else:
743 | self.on_train_end = lambda logs: None
744 |
745 |
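746 | # A minimal sketch of how the pieces above fit together, using only names
747 | # defined in this file: a LambdaCallback wrapped in a CallbackContainer,
748 | # with the hooks driven by hand the way a trainer drives them during fit().
749 | if __name__ == '__main__':
750 |     printer = LambdaCallback(
751 |         on_epoch_end=lambda epoch, logs: print('epoch %i: %s' % (epoch, logs)))
752 |     container = CallbackContainer(callbacks=[printer])
753 |     for epoch in range(3):
754 |         container.on_epoch_begin(epoch, {})
755 |         container.on_epoch_end(epoch, {'loss': 0.5 / (epoch + 1)})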
--------------------------------------------------------------------------------
/torchsample/constraints.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | from fnmatch import fnmatch
6 |
7 | import torch as th
8 | from .callbacks import Callback
9 |
10 |
11 | class ConstraintContainer(object):
12 |
13 | def __init__(self, constraints):
14 | self.constraints = constraints
15 | self.batch_constraints = [c for c in self.constraints if c.unit.upper() == 'BATCH']
16 | self.epoch_constraints = [c for c in self.constraints if c.unit.upper() == 'EPOCH']
17 |
18 | def register_constraints(self, model):
19 | """
20 | Grab pointers to the weights which will be modified by constraints so
21 |         that we don't have to search through the entire network using `apply`
22 | each time
23 | """
24 | # get batch constraint pointers
25 | self._batch_c_ptrs = {}
26 | for c_idx, constraint in enumerate(self.batch_constraints):
27 | self._batch_c_ptrs[c_idx] = []
28 | for name, module in model.named_modules():
29 | if fnmatch(name, constraint.module_filter) and hasattr(module, 'weight'):
30 | self._batch_c_ptrs[c_idx].append(module)
31 |
32 | # get epoch constraint pointers
33 | self._epoch_c_ptrs = {}
34 | for c_idx, constraint in enumerate(self.epoch_constraints):
35 | self._epoch_c_ptrs[c_idx] = []
36 | for name, module in model.named_modules():
37 | if fnmatch(name, constraint.module_filter) and hasattr(module, 'weight'):
38 | self._epoch_c_ptrs[c_idx].append(module)
39 |
40 | def apply_batch_constraints(self, batch_idx):
41 | for c_idx, modules in self._batch_c_ptrs.items():
42 | if (batch_idx+1) % self.constraints[c_idx].frequency == 0:
43 | for module in modules:
44 | self.constraints[c_idx](module)
45 |
46 | def apply_epoch_constraints(self, epoch_idx):
47 | for c_idx, modules in self._epoch_c_ptrs.items():
48 | if (epoch_idx+1) % self.constraints[c_idx].frequency == 0:
49 | for module in modules:
50 | self.constraints[c_idx](module)
51 |
52 |
53 | class ConstraintCallback(Callback):
54 |
55 | def __init__(self, container):
56 | self.container = container
57 |
58 | def on_batch_end(self, batch_idx, logs):
59 | self.container.apply_batch_constraints(batch_idx)
60 |
61 | def on_epoch_end(self, epoch_idx, logs):
62 | self.container.apply_epoch_constraints(epoch_idx)
63 |
64 |
65 | class Constraint(object):
66 |
67 | def __call__(self):
68 |         raise NotImplementedError('Subclasses must implement this method')
69 |
70 |
71 | class UnitNorm(Constraint):
72 | """
73 | UnitNorm constraint.
74 |
75 |     Constrains the weights to have column-wise unit norm
76 | """
77 | def __init__(self,
78 | frequency=1,
79 | unit='batch',
80 | module_filter='*'):
81 |
82 | self.frequency = frequency
83 | self.unit = unit
84 | self.module_filter = module_filter
85 |
86 | def __call__(self, module):
87 | w = module.weight.data
88 | module.weight.data = w.div(th.norm(w,2,0))
89 |
90 |
91 | class MaxNorm(Constraint):
92 | """
93 | MaxNorm weight constraint.
94 |
95 | Constrains the weights incident to each hidden unit
96 | to have a norm less than or equal to a desired value.
97 |
98 | Any hidden unit vector with a norm less than the max norm
99 |     constraint will not be altered.
100 | """
101 |
102 | def __init__(self,
103 | value,
104 | axis=0,
105 | frequency=1,
106 | unit='batch',
107 | module_filter='*'):
108 | self.value = float(value)
109 | self.axis = axis
110 |
111 | self.frequency = frequency
112 | self.unit = unit
113 | self.module_filter = module_filter
114 |
115 | def __call__(self, module):
116 | w = module.weight.data
117 | module.weight.data = th.renorm(w, 2, self.axis, self.value)
118 |
119 |
120 | class NonNeg(Constraint):
121 | """
122 | Constrains the weights to be non-negative.
123 | """
124 | def __init__(self,
125 | frequency=1,
126 | unit='batch',
127 | module_filter='*'):
128 | self.frequency = frequency
129 | self.unit = unit
130 | self.module_filter = module_filter
131 |
132 | def __call__(self, module):
133 | w = module.weight.data
134 | module.weight.data = w.gt(0).float().mul(w)
135 |
136 |
137 |
138 |
139 |
140 |
141 |
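142 | # Minimal usage sketch (illustrative; `nn.Linear` is standard PyTorch and
143 | # everything else is defined above): apply a MaxNorm constraint directly to
144 | # a single layer, outside of any ConstraintContainer/trainer machinery.
145 | if __name__ == '__main__':
146 |     import torch.nn as nn
147 |     layer = nn.Linear(10, 5)
148 |     MaxNorm(value=2.0, axis=0)(layer)
149 |     print(th.norm(layer.weight.data, 2, 1))  # each row's 2-norm is now <= 2.0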
--------------------------------------------------------------------------------
/torchsample/datasets.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import os
7 | import fnmatch
8 |
9 | import numpy as np
10 | import pandas as pd
11 | import PIL.Image as Image
12 | import nibabel
13 |
14 | import torch as th
15 |
16 | from . import transforms
17 |
18 |
19 | class BaseDataset(object):
20 | """An abstract class representing a Dataset.
21 |
22 | All other datasets should subclass it. All subclasses should override
23 | ``__len__``, that provides the size of the dataset, and ``__getitem__``,
24 | supporting integer indexing in range from 0 to len(self) exclusive.
25 | """
26 |
27 | def __len__(self):
28 | return len(self.inputs) if not isinstance(self.inputs, (tuple,list)) else len(self.inputs[0])
29 |
30 | def add_input_transform(self, transform, add_to_front=True, idx=None):
31 | if idx is None:
32 |             idx = np.arange(self.num_inputs)
33 | elif not is_tuple_or_list(idx):
34 | idx = [idx]
35 |
36 | if add_to_front:
37 | for i in idx:
38 | self.input_transform[i] = transforms.Compose([transform, self.input_transform[i]])
39 | else:
40 | for i in idx:
41 | self.input_transform[i] = transforms.Compose([self.input_transform[i], transform])
42 |
43 | def add_target_transform(self, transform, add_to_front=True, idx=None):
44 | if idx is None:
45 |             idx = np.arange(self.num_targets)
46 | elif not is_tuple_or_list(idx):
47 | idx = [idx]
48 |
49 | if add_to_front:
50 | for i in idx:
51 | self.target_transform[i] = transforms.Compose([transform, self.target_transform[i]])
52 | else:
53 | for i in idx:
54 | self.target_transform[i] = transforms.Compose([self.target_transform[i], transform])
55 |
56 | def add_co_transform(self, transform, add_to_front=True, idx=None):
57 | if idx is None:
58 |             idx = np.arange(self.min_inputs_or_targets)
59 | elif not is_tuple_or_list(idx):
60 | idx = [idx]
61 |
62 | if add_to_front:
63 | for i in idx:
64 | self.co_transform[i] = transforms.Compose([transform, self.co_transform[i]])
65 | else:
66 | for i in idx:
67 | self.co_transform[i] = transforms.Compose([self.co_transform[i], transform])
68 |
69 | def load(self, num_samples=None, load_range=None):
70 | """
71 | Load all data or a subset of the data into actual memory.
72 | For instance, if the inputs are paths to image files, then this
73 | function will actually load those images.
74 |
75 | Arguments
76 | ---------
77 | num_samples : integer (optional)
78 | number of samples to load. if None, will load all
79 | load_range : numpy array of integers (optional)
80 | the index range of images to load
81 | e.g. np.arange(4) loads the first 4 inputs+targets
82 | """
83 | def _parse_shape(x):
84 | if isinstance(x, (list,tuple)):
85 | return (len(x),)
86 | elif isinstance(x, th.Tensor):
87 | return x.size()
88 | else:
89 | return (1,)
90 |
91 | if num_samples is None and load_range is None:
92 | num_samples = len(self)
93 | load_range = np.arange(num_samples)
94 | elif num_samples is None and load_range is not None:
95 | num_samples = len(load_range)
96 | elif num_samples is not None and load_range is None:
97 | load_range = np.arange(num_samples)
98 |
99 |
100 | if self.has_target:
101 | for enum_idx, sample_idx in enumerate(load_range):
102 | input_sample, target_sample = self.__getitem__(sample_idx)
103 |
104 | if enum_idx == 0:
105 | if self.num_inputs == 1:
106 | _shape = [len(load_range)] + list(_parse_shape(input_sample))
107 | inputs = np.empty(_shape)
108 | else:
109 | inputs = []
110 | for i in range(self.num_inputs):
111 | _shape = [len(load_range)] + list(_parse_shape(input_sample[i]))
112 | inputs.append(np.empty(_shape))
113 | #inputs = [np.empty((len(load_range), *_parse_shape(input_sample[i]))) for i in range(self.num_inputs)]
114 |
115 | if self.num_targets == 1:
116 | _shape = [len(load_range)] + list(_parse_shape(target_sample))
117 | targets = np.empty(_shape)
118 | #targets = np.empty((len(load_range), *_parse_shape(target_sample)))
119 | else:
120 | targets = []
121 | for i in range(self.num_targets):
122 | _shape = [len(load_range)] + list(_parse_shape(target_sample[i]))
123 | targets.append(np.empty(_shape))
124 | #targets = [np.empty((len(load_range), *_parse_shape(target_sample[i]))) for i in range(self.num_targets)]
125 |
126 | if self.num_inputs == 1:
127 | inputs[enum_idx] = input_sample
128 | else:
129 | for i in range(self.num_inputs):
130 | inputs[i][enum_idx] = input_sample[i]
131 |
132 | if self.num_targets == 1:
133 | targets[enum_idx] = target_sample
134 | else:
135 | for i in range(self.num_targets):
136 | targets[i][enum_idx] = target_sample[i]
137 |
138 | return inputs, targets
139 | else:
140 | for enum_idx, sample_idx in enumerate(load_range):
141 | input_sample = self.__getitem__(sample_idx)
142 |
143 | if enum_idx == 0:
144 | if self.num_inputs == 1:
145 | _shape = [len(load_range)] + list(_parse_shape(input_sample))
146 | inputs = np.empty(_shape)
147 | #inputs = np.empty((len(load_range), *_parse_shape(input_sample)))
148 | else:
149 | inputs = []
150 | for i in range(self.num_inputs):
151 | _shape = [len(load_range)] + list(_parse_shape(input_sample[i]))
152 | inputs.append(np.empty(_shape))
153 | #inputs = [np.empty((len(load_range), *_parse_shape(input_sample[i]))) for i in range(self.num_inputs)]
154 |
155 | if self.num_inputs == 1:
156 | inputs[enum_idx] = input_sample
157 | else:
158 | for i in range(self.num_inputs):
159 | inputs[i][enum_idx] = input_sample[i]
160 |
161 | return inputs
162 |
163 | def fit_transforms(self):
164 | """
165 | Make a single pass through the entire dataset in order to fit
166 | any parameters of the transforms which require the entire dataset.
167 | e.g. StandardScaler() requires mean and std for the entire dataset.
168 |
169 |         If you don't call this fit function, then transforms which require properties
170 | of the entire dataset will just work at the batch level.
171 | e.g. StandardScaler() will normalize each batch by the specific batch mean/std
172 | """
173 | it_fit = hasattr(self.input_transform, 'update_fit')
174 | tt_fit = hasattr(self.target_transform, 'update_fit')
175 | ct_fit = hasattr(self.co_transform, 'update_fit')
176 | if it_fit or tt_fit or ct_fit:
177 | for sample_idx in range(len(self)):
178 | if hasattr(self, 'input_loader'):
179 | x = self.input_loader(self.inputs[sample_idx])
180 | else:
181 | x = self.inputs[sample_idx]
182 | if it_fit:
183 | self.input_transform.update_fit(x)
184 | if self.has_target:
185 | if hasattr(self, 'target_loader'):
186 | y = self.target_loader(self.targets[sample_idx])
187 | else:
188 | y = self.targets[sample_idx]
189 | if tt_fit:
190 | self.target_transform.update_fit(y)
191 | if ct_fit:
192 | self.co_transform.update_fit(x,y)
193 |
194 |
195 | def _process_array_argument(x):
196 | if not is_tuple_or_list(x):
197 | x = [x]
198 | return x
199 |
200 |
201 | class TensorDataset(BaseDataset):
202 |
203 | def __init__(self,
204 | inputs,
205 | targets=None,
206 | input_transform=None,
207 | target_transform=None,
208 | co_transform=None):
209 | """
210 | Dataset class for loading in-memory data.
211 |
212 | Arguments
213 | ---------
214 | inputs: numpy array
215 |
216 | targets : numpy array
217 |
218 | input_transform : class with __call__ function implemented
219 | transform to apply to input sample individually
220 |
221 | target_transform : class with __call__ function implemented
222 | transform to apply to target sample individually
223 |
224 | co_transform : class with __call__ function implemented
225 | transform to apply to both input and target sample simultaneously
226 |
227 | """
228 | self.inputs = _process_array_argument(inputs)
229 | self.num_inputs = len(self.inputs)
230 | self.input_return_processor = _return_first_element_of_list if self.num_inputs==1 else _pass_through
231 |
232 | if targets is None:
233 | self.has_target = False
234 | else:
235 | self.targets = _process_array_argument(targets)
236 | self.num_targets = len(self.targets)
237 | self.target_return_processor = _return_first_element_of_list if self.num_targets==1 else _pass_through
238 | self.min_inputs_or_targets = min(self.num_inputs, self.num_targets)
239 | self.has_target = True
240 |
241 | self.input_transform = _process_transform_argument(input_transform, self.num_inputs)
242 | if self.has_target:
243 | self.target_transform = _process_transform_argument(target_transform, self.num_targets)
244 | self.co_transform = _process_co_transform_argument(co_transform, self.num_inputs, self.num_targets)
245 |
246 | def __getitem__(self, index):
247 | """
248 | Index the dataset and return the input + target
249 | """
250 | input_sample = [self.input_transform[i](self.inputs[i][index]) for i in range(self.num_inputs)]
251 |
252 | if self.has_target:
253 | target_sample = [self.target_transform[i](self.targets[i][index]) for i in range(self.num_targets)]
254 | #for i in range(self.min_inputs_or_targets):
255 | # input_sample[i], target_sample[i] = self.co_transform[i](input_sample[i], target_sample[i])
256 |
257 | return self.input_return_processor(input_sample), self.target_return_processor(target_sample)
258 | else:
259 | return self.input_return_processor(input_sample)
260 |
261 |
262 | def default_file_reader(x):
263 | def pil_loader(path):
264 | return Image.open(path).convert('RGB')
265 | def npy_loader(path):
266 | return np.load(path)
267 | def nifti_loader(path):
268 | return nibabel.load(path).get_data()
269 | if isinstance(x, str):
270 | if x.endswith('.npy'):
271 | x = npy_loader(x)
272 |         elif x.endswith('.nii.gz'):
273 | x = nifti_loader(x)
274 | else:
275 | try:
276 | x = pil_loader(x)
277 |             except Exception:
278 |                 raise ValueError('File format is not supported')
279 | #else:
280 | #raise ValueError('x should be string, but got %s' % type(x))
281 | return x
282 |
283 | def is_tuple_or_list(x):
284 | return isinstance(x, (tuple,list))
285 |
286 | def _process_transform_argument(tform, num_inputs):
287 | tform = tform if tform is not None else _pass_through
288 | if is_tuple_or_list(tform):
289 | if len(tform) != num_inputs:
290 | raise Exception('If transform is list, must provide one transform for each input')
291 | tform = [t if t is not None else _pass_through for t in tform]
292 | else:
293 | tform = [tform] * num_inputs
294 | return tform
295 |
296 | def _process_co_transform_argument(tform, num_inputs, num_targets):
297 | tform = tform if tform is not None else _multi_arg_pass_through
298 | if is_tuple_or_list(tform):
299 | if len(tform) != num_inputs:
300 | raise Exception('If transform is list, must provide one transform for each input')
301 | tform = [t if t is not None else _multi_arg_pass_through for t in tform]
302 | else:
303 | tform = [tform] * min(num_inputs, num_targets)
304 | return tform
305 |
306 | def _process_csv_argument(csv):
307 | if isinstance(csv, str):
308 | df = pd.read_csv(csv)
309 | elif isinstance(csv, pd.DataFrame):
310 | df = csv
311 | else:
312 | raise ValueError('csv argument must be string or dataframe')
313 | return df
314 |
315 | def _select_dataframe_columns(df, cols):
316 | if isinstance(cols[0], str):
317 | inputs = df.loc[:,cols].values
318 | elif isinstance(cols[0], int):
319 | inputs = df.iloc[:,cols].values
320 | else:
321 | raise ValueError('Provided columns should be string column names or integer column indices')
322 | return inputs
323 |
324 | def _process_cols_argument(cols):
325 | if isinstance(cols, tuple):
326 | cols = list(cols)
327 | return cols
328 |
329 | def _return_first_element_of_list(x):
330 | return x[0]
331 |
332 | def _pass_through(x):
333 | return x
334 |
335 | def _multi_arg_pass_through(*x):
336 | return x
337 |
338 |
339 | class CSVDataset(BaseDataset):
340 |
341 | def __init__(self,
342 | csv,
343 | input_cols=[0],
344 | target_cols=[1],
345 | input_transform=None,
346 | target_transform=None,
347 | co_transform=None):
348 | """
349 | Initialize a Dataset from a CSV file/dataframe. This does NOT
350 | actually load the data into memory if the CSV contains filepaths.
351 |
352 | Arguments
353 | ---------
354 | csv : string or pandas.DataFrame
355 | if string, should be a path to a .csv file which
356 | can be loaded as a pandas dataframe
357 |
358 | input_cols : int/list of ints, or string/list of strings
359 | which columns to use as input arrays.
360 |             If int(s), should be column indices
361 | If str(s), should be column names
362 |
363 | target_cols : int/list of ints, or string/list of strings
364 |             which columns to use as target arrays.
365 |             If int(s), should be column indices
366 | If str(s), should be column names
367 |
368 | input_transform : class which implements a __call__ method
369 |             transform(s) to apply to inputs during runtime loading
370 |
371 |         target_transform : class which implements a __call__ method
372 | transform(s) to apply to targets during runtime loading
373 |
374 | co_transform : class which implements a __call__ method
375 | transform(s) to apply to both inputs and targets simultaneously
376 | during runtime loading
377 | """
378 | self.input_cols = _process_cols_argument(input_cols)
379 | self.target_cols = _process_cols_argument(target_cols)
380 |
381 | self.df = _process_csv_argument(csv)
382 |
383 | self.inputs = _select_dataframe_columns(self.df, input_cols)
384 | self.num_inputs = self.inputs.shape[1]
385 | self.input_return_processor = _return_first_element_of_list if self.num_inputs==1 else _pass_through
386 |
387 | if target_cols is None:
388 | self.num_targets = 0
389 | self.has_target = False
390 | else:
391 | self.targets = _select_dataframe_columns(self.df, target_cols)
392 | self.num_targets = self.targets.shape[1]
393 | self.target_return_processor = _return_first_element_of_list if self.num_targets==1 else _pass_through
394 | self.has_target = True
395 | self.min_inputs_or_targets = min(self.num_inputs, self.num_targets)
396 |
397 | self.input_loader = default_file_reader
398 | self.target_loader = default_file_reader
399 |
400 | self.input_transform = _process_transform_argument(input_transform, self.num_inputs)
401 | if self.has_target:
402 | self.target_transform = _process_transform_argument(target_transform, self.num_targets)
403 | self.co_transform = _process_co_transform_argument(co_transform, self.num_inputs, self.num_targets)
404 |
405 | def __getitem__(self, index):
406 | """
407 | Index the dataset and return the input + target
408 | """
409 | input_sample = [self.input_transform[i](self.input_loader(self.inputs[index, i])) for i in range(self.num_inputs)]
410 |
411 | if self.has_target:
412 | target_sample = [self.target_transform[i](self.target_loader(self.targets[index, i])) for i in range(self.num_targets)]
413 | for i in range(self.min_inputs_or_targets):
414 |                 input_sample[i], target_sample[i] = self.co_transform[i](input_sample[i], target_sample[i])
415 |
416 | return self.input_return_processor(input_sample), self.target_return_processor(target_sample)
417 | else:
418 | return self.input_return_processor(input_sample)
419 |
420 | def split_by_column(self, col):
421 | """
422 | Split this dataset object into multiple dataset objects based on
423 | the unique factors of the given column. The number of returned
424 | datasets will be equal to the number of unique values in the given
425 | column. The transforms and original dataframe will all be transferred
426 | to the new datasets
427 |
428 | Useful for splitting a dataset into train/val/test datasets.
429 |
430 | Arguments
431 | ---------
432 | col : integer or string
433 | which column to split the data on.
434 | if int, should be column index
435 | if str, should be column name
436 |
437 | Returns
438 | -------
439 | - list of new datasets with transforms copied
440 | """
441 | if isinstance(col, int):
442 | split_vals = self.df.iloc[:,col].values.flatten()
443 |
444 | new_df_list = []
445 | for unique_split_val in np.unique(split_vals):
446 | new_df = self.df[:][self.df.iloc[:,col]==unique_split_val]
447 | new_df_list.append(new_df)
448 | elif isinstance(col, str):
449 | split_vals = self.df.loc[:,col].values.flatten()
450 |
451 | new_df_list = []
452 | for unique_split_val in np.unique(split_vals):
453 | new_df = self.df[:][self.df.loc[:,col]==unique_split_val]
454 | new_df_list.append(new_df)
455 | else:
456 | raise ValueError('col argument not valid - must be column name or index')
457 |
458 | new_datasets = []
459 | for new_df in new_df_list:
460 | new_dataset = self.copy(new_df)
461 | new_datasets.append(new_dataset)
462 |
463 | return new_datasets
464 |
465 | def train_test_split(self, train_size):
466 | if train_size < 1:
467 | train_size = int(train_size * len(self))
468 |
469 | train_indices = np.random.choice(len(self), train_size, replace=False)
470 | test_indices = np.array([i for i in range(len(self)) if i not in train_indices])
471 |
472 | train_df = self.df.iloc[train_indices,:]
473 | test_df = self.df.iloc[test_indices,:]
474 |
475 | train_dataset = self.copy(train_df)
476 | test_dataset = self.copy(test_df)
477 |
478 | return train_dataset, test_dataset
479 |
480 | def copy(self, df=None):
481 | if df is None:
482 | df = self.df
483 |
484 | return CSVDataset(df,
485 | input_cols=self.input_cols,
486 | target_cols=self.target_cols,
487 | input_transform=self.input_transform,
488 | target_transform=self.target_transform,
489 | co_transform=self.co_transform)
490 |
491 |
492 | class FolderDataset(BaseDataset):
493 |
494 | def __init__(self,
495 | root,
496 | class_mode='label',
497 | input_regex='*',
498 | target_regex=None,
499 | input_transform=None,
500 | target_transform=None,
501 | co_transform=None,
502 | input_loader='npy'):
503 | """
504 | Dataset class for loading out-of-memory data.
505 |
506 | Arguments
507 | ---------
508 | root : string
509 | path to main directory
510 |
511 | class_mode : string in `{'label', 'image'}`
512 | type of target sample to look for and return
513 | `label` = return class folder as target
514 | `image` = return another image as target as found by 'target_regex'
515 | NOTE: if class_mode == 'image', you must give an
516 | input and target regex and the input/target images should
517 | be in a folder together with no other images in that folder
518 |
519 | input_regex : string (default is any valid image file)
520 | regular expression to find input images
521 | e.g. if all your inputs have the word 'input',
522 | you'd enter something like input_regex='*input*'
523 |
524 |         target_regex : string (default is None)
525 | regular expression to find target images if class_mode == 'image'
526 | e.g. if all your targets have the word 'segment',
527 |             you'd enter something like target_regex='*segment*'
528 |
529 |         input_transform : transform class
530 | transform to apply to input sample individually
531 |
532 | target_transform : transform class
533 | transform to apply to target sample individually
534 |
535 |         input_loader : string in `{'npy', 'pil', 'nifti'}` or callable
536 | defines how to load samples from file
537 | if a function is provided, it should take in a file path
538 | as input and return the loaded sample.
539 |
540 | """
541 | self.input_loader = default_file_reader
542 | self.target_loader = default_file_reader if class_mode == 'image' else lambda x: x
543 |
544 | root = os.path.expanduser(root)
545 |
546 | classes, class_to_idx = _find_classes(root)
547 | inputs, targets = _finds_inputs_and_targets(root, class_mode,
548 | class_to_idx, input_regex, target_regex)
549 |
550 | if len(inputs) == 0:
551 | raise(RuntimeError('Found 0 images in subfolders of: %s' % root))
552 | else:
553 | print('Found %i images' % len(inputs))
554 |
555 | self.root = os.path.expanduser(root)
556 | self.inputs = inputs
557 | self.targets = targets
558 | self.classes = classes
559 | self.class_to_idx = class_to_idx
560 |
561 | self.input_transform = input_transform if input_transform is not None else lambda x: x
562 | if isinstance(input_transform, (tuple,list)):
563 | self.input_transform = transforms.Compose(self.input_transform)
564 | self.target_transform = target_transform if target_transform is not None else lambda x: x
565 | if isinstance(target_transform, (tuple,list)):
566 | self.target_transform = transforms.Compose(self.target_transform)
567 | self.co_transform = co_transform if co_transform is not None else lambda x,y: (x,y)
568 | if isinstance(co_transform, (tuple,list)):
569 | self.co_transform = transforms.Compose(self.co_transform)
570 |
571 | self.class_mode = class_mode
572 |
573 | def get_full_paths(self):
574 | return [os.path.join(self.root, i) for i in self.inputs]
575 |
576 | def __getitem__(self, index):
577 | input_sample = self.inputs[index]
578 | input_sample = self.input_loader(input_sample)
579 | input_sample = self.input_transform(input_sample)
580 |
581 | target_sample = self.targets[index]
582 | target_sample = self.target_loader(target_sample)
583 | target_sample = self.target_transform(target_sample)
584 |
585 | input_sample, target_sample = self.co_transform(input_sample, target_sample)
586 |
587 | return input_sample, target_sample
588 |
589 | def __len__(self):
590 | return len(self.inputs)
591 |
592 |
593 |
594 | def _find_classes(directory):
595 |     classes = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
596 | classes.sort()
597 | class_to_idx = {classes[i]: i for i in range(len(classes))}
598 | return classes, class_to_idx
599 |
600 | def _is_image_file(filename):
601 | IMG_EXTENSIONS = [
602 | '.jpg', '.JPG', '.jpeg', '.JPEG',
603 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
604 | '.nii.gz', '.npy'
605 | ]
606 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
607 |
608 | def _finds_inputs_and_targets(directory, class_mode, class_to_idx=None,
609 |                                 input_regex=None, target_regex=None):
610 | """
611 | Map a dataset from a root folder
612 | """
613 | if class_mode == 'image':
614 |         if not input_regex or not target_regex:
615 |             raise ValueError('must give input_regex and target_regex if'
616 |                              ' class_mode==image')
617 | inputs = []
618 | targets = []
619 | for subdir in sorted(os.listdir(directory)):
620 | d = os.path.join(directory, subdir)
621 | if not os.path.isdir(d):
622 | continue
623 |
624 | for root, _, fnames in sorted(os.walk(d)):
625 | for fname in fnames:
626 | if _is_image_file(fname):
627 | if fnmatch.fnmatch(fname, input_regex):
628 | path = os.path.join(root, fname)
629 | inputs.append(path)
630 | if class_mode == 'label':
631 | targets.append(class_to_idx[subdir])
632 | if class_mode == 'image' and \
633 | fnmatch.fnmatch(fname, target_regex):
634 | path = os.path.join(root, fname)
635 | targets.append(path)
636 | if class_mode is None:
637 | return inputs
638 | else:
639 | return inputs, targets
640 |
--------------------------------------------------------------------------------
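A minimal usage sketch for the dataset class above. The class name is cut off in this listing and is assumed here to be exported as `torchsample.datasets.FolderDataset`; the directory layout and paths are illustrative only:

    from torchsample.datasets import FolderDataset   # assumed export name

    # layout: /data/scans/<class_name>/xxx_input.png alongside xxx_segment.png
    dataset = FolderDataset(root='/data/scans',
                            class_mode='image',
                            input_regex='*input*',
                            target_regex='*segment*')
    x, y = dataset[0]                 # an input image and its matching target
    print(len(dataset), dataset.classes)
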
/torchsample/functions/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .affine import *
--------------------------------------------------------------------------------
/torchsample/functions/affine.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 | from ..utils import th_iterproduct, th_flatten
7 |
8 |
9 | def F_affine2d(x, matrix, center=True):
10 | """
11 | 2D Affine image transform on torch.autograd.Variable
12 | """
13 | if matrix.dim() == 2:
14 | matrix = matrix.view(-1,2,3)
15 |
16 | A_batch = matrix[:,:,:2]
17 | if A_batch.size(0) != x.size(0):
18 | A_batch = A_batch.repeat(x.size(0),1,1)
19 | b_batch = matrix[:,:,2].unsqueeze(1)
20 |
21 | # make a meshgrid of normal coordinates
22 | _coords = th_iterproduct(x.size(1),x.size(2))
23 | coords = Variable(_coords.unsqueeze(0).repeat(x.size(0),1,1).float(),
24 | requires_grad=False)
25 | if center:
26 | # shift the coordinates so center is the origin
27 | coords[:,:,0] = coords[:,:,0] - (x.size(1) / 2. + 0.5)
28 | coords[:,:,1] = coords[:,:,1] - (x.size(2) / 2. + 0.5)
29 |
30 | # apply the coordinate transformation
31 | new_coords = coords.bmm(A_batch.transpose(1,2)) + b_batch.expand_as(coords)
32 |
33 | if center:
34 | # shift the coordinates back so origin is origin
35 | new_coords[:,:,0] = new_coords[:,:,0] + (x.size(1) / 2. + 0.5)
36 | new_coords[:,:,1] = new_coords[:,:,1] + (x.size(2) / 2. + 0.5)
37 |
38 | # map new coordinates using bilinear interpolation
39 | x_transformed = F_bilinear_interp2d(x, new_coords)
40 |
41 | return x_transformed
42 |
43 |
44 | def F_bilinear_interp2d(input, coords):
45 | """
46 | bilinear interpolation of 2d torch.autograd.Variable
47 | """
48 | x = torch.clamp(coords[:,:,0], 0, input.size(1)-2)
49 | x0 = x.floor()
50 | x1 = x0 + 1
51 | y = torch.clamp(coords[:,:,1], 0, input.size(2)-2)
52 | y0 = y.floor()
53 | y1 = y0 + 1
54 |
55 | stride = torch.LongTensor(input.stride())
56 | x0_ix = x0.mul(stride[1]).long()
57 | x1_ix = x1.mul(stride[1]).long()
58 | y0_ix = y0.mul(stride[2]).long()
59 | y1_ix = y1.mul(stride[2]).long()
60 |
61 | input_flat = input.view(input.size(0),-1).contiguous()
62 |
63 | vals_00 = input_flat.gather(1, x0_ix.add(y0_ix).detach())
64 | vals_10 = input_flat.gather(1, x1_ix.add(y0_ix).detach())
65 | vals_01 = input_flat.gather(1, x0_ix.add(y1_ix).detach())
66 | vals_11 = input_flat.gather(1, x1_ix.add(y1_ix).detach())
67 |
68 | xd = x - x0
69 | yd = y - y0
70 | xm = 1 - xd
71 | ym = 1 - yd
72 |
73 | x_mapped = (vals_00.mul(xm).mul(ym) +
74 | vals_10.mul(xd).mul(ym) +
75 | vals_01.mul(xm).mul(yd) +
76 | vals_11.mul(xd).mul(yd))
77 |
78 | return x_mapped.view_as(input)
79 |
80 |
81 | def F_batch_affine2d(x, matrix, center=True):
82 | """
83 |
84 | x : torch.Tensor
85 | shape = (Samples, C, H, W)
86 | NOTE: Assume C is always equal to 1!
87 | matrix : torch.Tensor
88 | shape = (Samples, 6) or (Samples, 2, 3)
89 |
90 | Example
91 | -------
92 | >>> x = Variable(torch.zeros(3,1,10,10))
93 | >>> x[:,:,3:7,3:7] = 1
94 | >>> m1 = torch.FloatTensor([[1.2,0,0],[0,1.2,0]])
95 | >>> m2 = torch.FloatTensor([[0.8,0,0],[0,0.8,0]])
96 | >>> m3 = torch.FloatTensor([[1.0,0,3],[0,1.0,3]])
97 | >>> matrix = Variable(torch.stack([m1,m2,m3]))
98 | >>> xx = F_batch_affine2d(x,matrix)
99 | """
100 | if matrix.dim() == 2:
101 | matrix = matrix.view(-1,2,3)
102 |
103 | A_batch = matrix[:,:,:2]
104 | b_batch = matrix[:,:,2].unsqueeze(1)
105 |
106 | # make a meshgrid of normal coordinates
107 | _coords = th_iterproduct(x.size(2),x.size(3))
108 | coords = Variable(_coords.unsqueeze(0).repeat(x.size(0),1,1).float(),
109 | requires_grad=False)
110 |
111 | if center:
112 | # shift the coordinates so center is the origin
113 | coords[:,:,0] = coords[:,:,0] - (x.size(2) / 2. + 0.5)
114 | coords[:,:,1] = coords[:,:,1] - (x.size(3) / 2. + 0.5)
115 |
116 | # apply the coordinate transformation
117 | new_coords = coords.bmm(A_batch.transpose(1,2)) + b_batch.expand_as(coords)
118 |
119 | if center:
120 | # shift the coordinates back so origin is origin
121 | new_coords[:,:,0] = new_coords[:,:,0] + (x.size(2) / 2. + 0.5)
122 | new_coords[:,:,1] = new_coords[:,:,1] + (x.size(3) / 2. + 0.5)
123 |
124 | # map new coordinates using bilinear interpolation
125 | x_transformed = F_batch_bilinear_interp2d(x, new_coords)
126 |
127 | return x_transformed
128 |
129 |
130 | def F_batch_bilinear_interp2d(input, coords):
131 | """
132 |     input : torch.Tensor
133 |         size = (N,C,H,W)
134 |     coords : torch.Tensor
135 |         size = (N,H*W,2)
136 | """
137 | x = torch.clamp(coords[:,:,0], 0, input.size(2)-2)
138 | x0 = x.floor()
139 | x1 = x0 + 1
140 | y = torch.clamp(coords[:,:,1], 0, input.size(3)-2)
141 | y0 = y.floor()
142 | y1 = y0 + 1
143 |
144 | stride = torch.LongTensor(input.stride())
145 | x0_ix = x0.mul(stride[2]).long()
146 | x1_ix = x1.mul(stride[2]).long()
147 | y0_ix = y0.mul(stride[3]).long()
148 | y1_ix = y1.mul(stride[3]).long()
149 |
150 | input_flat = input.view(input.size(0),-1).contiguous()
151 |
152 | vals_00 = input_flat.gather(1, x0_ix.add(y0_ix).detach())
153 | vals_10 = input_flat.gather(1, x1_ix.add(y0_ix).detach())
154 | vals_01 = input_flat.gather(1, x0_ix.add(y1_ix).detach())
155 | vals_11 = input_flat.gather(1, x1_ix.add(y1_ix).detach())
156 |
157 | xd = x - x0
158 | yd = y - y0
159 | xm = 1 - xd
160 | ym = 1 - yd
161 |
162 | x_mapped = (vals_00.mul(xm).mul(ym) +
163 | vals_10.mul(xd).mul(ym) +
164 | vals_01.mul(xm).mul(yd) +
165 | vals_11.mul(xd).mul(yd))
166 |
167 | return x_mapped.view_as(input)
168 |
169 |
170 | def F_affine3d(x, matrix, center=True):
171 | A = matrix[:3,:3]
172 | b = matrix[:3,3]
173 |
174 | # make a meshgrid of normal coordinates
175 | coords = Variable(th_iterproduct(x.size(1),x.size(2),x.size(3)).float(),
176 | requires_grad=False)
177 |
178 | if center:
179 | # shift the coordinates so center is the origin
180 | coords[:,0] = coords[:,0] - (x.size(1) / 2. + 0.5)
181 | coords[:,1] = coords[:,1] - (x.size(2) / 2. + 0.5)
182 | coords[:,2] = coords[:,2] - (x.size(3) / 2. + 0.5)
183 |
184 |
185 | # apply the coordinate transformation
186 | new_coords = F.linear(coords, A, b)
187 |
188 | if center:
189 | # shift the coordinates back so origin is origin
190 | new_coords[:,0] = new_coords[:,0] + (x.size(1) / 2. + 0.5)
191 | new_coords[:,1] = new_coords[:,1] + (x.size(2) / 2. + 0.5)
192 | new_coords[:,2] = new_coords[:,2] + (x.size(3) / 2. + 0.5)
193 |
194 |     # map new coordinates using trilinear interpolation
195 | x_transformed = F_trilinear_interp3d(x, new_coords)
196 |
197 | return x_transformed
198 |
199 |
200 | def F_trilinear_interp3d(input, coords):
201 | """
202 | trilinear interpolation of 3D image
203 | """
204 | # take clamp then floor/ceil of x coords
205 | x = torch.clamp(coords[:,0], 0, input.size(1)-2)
206 | x0 = x.floor()
207 | x1 = x0 + 1
208 | # take clamp then floor/ceil of y coords
209 | y = torch.clamp(coords[:,1], 0, input.size(2)-2)
210 | y0 = y.floor()
211 | y1 = y0 + 1
212 | # take clamp then floor/ceil of z coords
213 | z = torch.clamp(coords[:,2], 0, input.size(3)-2)
214 | z0 = z.floor()
215 | z1 = z0 + 1
216 |
217 | stride = torch.LongTensor(input.stride())[1:]
218 | x0_ix = x0.mul(stride[0]).long()
219 | x1_ix = x1.mul(stride[0]).long()
220 | y0_ix = y0.mul(stride[1]).long()
221 | y1_ix = y1.mul(stride[1]).long()
222 | z0_ix = z0.mul(stride[2]).long()
223 | z1_ix = z1.mul(stride[2]).long()
224 |
225 | input_flat = th_flatten(input)
226 |
227 | vals_000 = input_flat[x0_ix.add(y0_ix).add(z0_ix).detach()]
228 | vals_100 = input_flat[x1_ix.add(y0_ix).add(z0_ix).detach()]
229 | vals_010 = input_flat[x0_ix.add(y1_ix).add(z0_ix).detach()]
230 | vals_001 = input_flat[x0_ix.add(y0_ix).add(z1_ix).detach()]
231 | vals_101 = input_flat[x1_ix.add(y0_ix).add(z1_ix).detach()]
232 | vals_011 = input_flat[x0_ix.add(y1_ix).add(z1_ix).detach()]
233 | vals_110 = input_flat[x1_ix.add(y1_ix).add(z0_ix).detach()]
234 | vals_111 = input_flat[x1_ix.add(y1_ix).add(z1_ix).detach()]
235 |
236 | xd = x - x0
237 | yd = y - y0
238 | zd = z - z0
239 | xm = 1 - xd
240 | ym = 1 - yd
241 | zm = 1 - zd
242 |
243 | x_mapped = (vals_000.mul(xm).mul(ym).mul(zm) +
244 | vals_100.mul(xd).mul(ym).mul(zm) +
245 | vals_010.mul(xm).mul(yd).mul(zm) +
246 | vals_001.mul(xm).mul(ym).mul(zd) +
247 | vals_101.mul(xd).mul(ym).mul(zd) +
248 | vals_011.mul(xm).mul(yd).mul(zd) +
249 | vals_110.mul(xd).mul(yd).mul(zm) +
250 | vals_111.mul(xd).mul(yd).mul(zd))
251 |
252 | return x_mapped.view_as(input)
253 |
254 |
255 | def F_batch_affine3d(x, matrix, center=True):
256 | """
257 |
258 |     x : torch.Tensor
259 |         shape = (Samples, C, D, H, W)
260 |         NOTE: Assume C is always equal to 1!
261 |     matrix : torch.Tensor
262 |         shape = (Samples, 12) or (Samples, 3, 4)
263 |
264 | Example
265 | -------
266 | >>> x = Variable(torch.zeros(3,1,10,10,10))
267 | >>> x[:,:,3:7,3:7,3:7] = 1
268 | >>> m1 = torch.FloatTensor([[1.2,0,0,0],[0,1.2,0,0],[0,0,1.2,0]])
269 | >>> m2 = torch.FloatTensor([[0.8,0,0,0],[0,0.8,0,0],[0,0,0.8,0]])
270 | >>> m3 = torch.FloatTensor([[1.0,0,0,3],[0,1.0,0,3],[0,0,1.0,3]])
271 | >>> matrix = Variable(torch.stack([m1,m2,m3]))
272 | >>> xx = F_batch_affine3d(x,matrix)
273 | """
274 | if matrix.dim() == 2:
275 | matrix = matrix.view(-1,3,4)
276 |
277 | A_batch = matrix[:,:3,:3]
278 | b_batch = matrix[:,:3,3].unsqueeze(1)
279 |
280 | # make a meshgrid of normal coordinates
281 | _coords = th_iterproduct(x.size(2),x.size(3),x.size(4))
282 | coords = Variable(_coords.unsqueeze(0).repeat(x.size(0),1,1).float(),
283 | requires_grad=False)
284 |
285 | if center:
286 | # shift the coordinates so center is the origin
287 | coords[:,:,0] = coords[:,:,0] - (x.size(2) / 2. + 0.5)
288 | coords[:,:,1] = coords[:,:,1] - (x.size(3) / 2. + 0.5)
289 | coords[:,:,2] = coords[:,:,2] - (x.size(4) / 2. + 0.5)
290 |
291 | # apply the coordinate transformation
292 | new_coords = coords.bmm(A_batch.transpose(1,2)) + b_batch.expand_as(coords)
293 |
294 | if center:
295 | # shift the coordinates back so origin is origin
296 | new_coords[:,:,0] = new_coords[:,:,0] + (x.size(2) / 2. + 0.5)
297 | new_coords[:,:,1] = new_coords[:,:,1] + (x.size(3) / 2. + 0.5)
298 | new_coords[:,:,2] = new_coords[:,:,2] + (x.size(4) / 2. + 0.5)
299 |
300 |     # map new coordinates using trilinear interpolation
301 | x_transformed = F_batch_trilinear_interp3d(x, new_coords)
302 |
303 | return x_transformed
304 |
305 |
306 | def F_batch_trilinear_interp3d(input, coords):
307 | """
308 |     input : torch.Tensor
309 |         size = (N,C,D,H,W)
310 |     coords : torch.Tensor
311 |         size = (N,D*H*W,3)
312 | """
313 | x = torch.clamp(coords[:,:,0], 0, input.size(2)-2)
314 | x0 = x.floor()
315 | x1 = x0 + 1
316 | y = torch.clamp(coords[:,:,1], 0, input.size(3)-2)
317 | y0 = y.floor()
318 | y1 = y0 + 1
319 | z = torch.clamp(coords[:,:,2], 0, input.size(4)-2)
320 | z0 = z.floor()
321 | z1 = z0 + 1
322 |
323 | stride = torch.LongTensor(input.stride())
324 | x0_ix = x0.mul(stride[2]).long()
325 | x1_ix = x1.mul(stride[2]).long()
326 | y0_ix = y0.mul(stride[3]).long()
327 | y1_ix = y1.mul(stride[3]).long()
328 | z0_ix = z0.mul(stride[4]).long()
329 | z1_ix = z1.mul(stride[4]).long()
330 |
331 | input_flat = input.contiguous().view(input.size(0),-1)
332 |
333 | vals_000 = input_flat.gather(1,x0_ix.add(y0_ix).add(z0_ix).detach())
334 | vals_100 = input_flat.gather(1,x1_ix.add(y0_ix).add(z0_ix).detach())
335 | vals_010 = input_flat.gather(1,x0_ix.add(y1_ix).add(z0_ix).detach())
336 | vals_001 = input_flat.gather(1,x0_ix.add(y0_ix).add(z1_ix).detach())
337 | vals_101 = input_flat.gather(1,x1_ix.add(y0_ix).add(z1_ix).detach())
338 | vals_011 = input_flat.gather(1,x0_ix.add(y1_ix).add(z1_ix).detach())
339 | vals_110 = input_flat.gather(1,x1_ix.add(y1_ix).add(z0_ix).detach())
340 | vals_111 = input_flat.gather(1,x1_ix.add(y1_ix).add(z1_ix).detach())
341 |
342 | xd = x - x0
343 | yd = y - y0
344 | zd = z - z0
345 | xm = 1 - xd
346 | ym = 1 - yd
347 | zm = 1 - zd
348 |
349 | x_mapped = (vals_000.mul(xm).mul(ym).mul(zm) +
350 | vals_100.mul(xd).mul(ym).mul(zm) +
351 | vals_010.mul(xm).mul(yd).mul(zm) +
352 | vals_001.mul(xm).mul(ym).mul(zd) +
353 | vals_101.mul(xd).mul(ym).mul(zd) +
354 | vals_011.mul(xm).mul(yd).mul(zd) +
355 | vals_110.mul(xd).mul(yd).mul(zm) +
356 | vals_111.mul(xd).mul(yd).mul(zd))
357 |
358 | return x_mapped.view_as(input)
359 |
360 |
361 |
--------------------------------------------------------------------------------
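A short sketch of calling `F_affine2d` directly, mirroring the doctest style used in this file (assumes the Variable-era torch API the file targets): build a 2x3 affine matrix and the image is resampled with bilinear interpolation.

    import math
    import torch
    from torch.autograd import Variable
    from torchsample.functions.affine import F_affine2d

    x = Variable(torch.zeros(1, 10, 10))     # (C, H, W), single channel
    x[:, 3:7, 3:7] = 1

    theta = math.pi / 4.                     # rotate 45 degrees about the center
    matrix = torch.FloatTensor([[math.cos(theta), -math.sin(theta), 0.],
                                [math.sin(theta),  math.cos(theta), 0.]])
    x_rot = F_affine2d(x, Variable(matrix))  # same shape, bilinearly resampled
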
/torchsample/initializers.py:
--------------------------------------------------------------------------------
1 | """
2 | Classes to initialize module weights
3 | """
4 |
5 | from fnmatch import fnmatch
6 |
7 | import torch.nn.init
8 |
9 |
10 | def _validate_initializer_string(init):
11 |     dir_f = dir(torch.nn.init)
12 |     init_fns = [d.lower() for d in dir_f]
13 |     if isinstance(init, str):
14 |         try:
15 |             str_idx = init_fns.index(init.lower())
16 |         except ValueError:
17 |             raise ValueError('Invalid initializer string input - must match a pytorch init function.')
18 |         return getattr(torch.nn.init, dir_f[str_idx])
19 |     elif callable(init):
20 |         return init
21 |     else:
22 |         raise ValueError('Invalid initializer input')
23 |
24 |
25 | class InitializerContainer(object):
26 |
27 | def __init__(self, initializers):
28 | self._initializers = initializers
29 |
30 | def apply(self, model):
31 | for initializer in self._initializers:
32 | model.apply(initializer)
33 |
34 |
35 | class Initializer(object):
36 |
37 | def __call__(self, module):
38 | raise NotImplementedError('Initializer must implement this method')
39 |
40 |
41 | class GeneralInitializer(Initializer):
42 |
43 |     def __init__(self, initializer, bias=False, bias_only=False, module_filter='*', **kwargs):
44 |         self._initializer = _validate_initializer_string(initializer)
45 |         self.bias, self.bias_only, self.module_filter, self.kwargs = bias, bias_only, module_filter, kwargs
46 |
47 | def __call__(self, module):
48 | classname = module.__class__.__name__
49 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
50 | if self.bias_only:
51 | self._initializer(module.bias.data, **self.kwargs)
52 | else:
53 | self._initializer(module.weight.data, **self.kwargs)
54 | if self.bias:
55 | self._initializer(module.bias.data, **self.kwargs)
56 |
57 |
58 | class Normal(Initializer):
59 |
60 | def __init__(self, mean=0.0, std=0.02, bias=False,
61 | bias_only=False, module_filter='*'):
62 | self.mean = mean
63 | self.std = std
64 |
65 | self.bias = bias
66 | self.bias_only = bias_only
67 | self.module_filter = module_filter
68 |
69 | super(Normal, self).__init__()
70 |
71 | def __call__(self, module):
72 | classname = module.__class__.__name__
73 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
74 | if self.bias_only:
75 | torch.nn.init.normal(module.bias.data, mean=self.mean, std=self.std)
76 | else:
77 | torch.nn.init.normal(module.weight.data, mean=self.mean, std=self.std)
78 | if self.bias:
79 | torch.nn.init.normal(module.bias.data, mean=self.mean, std=self.std)
80 |
81 |
82 | class Uniform(Initializer):
83 |
84 | def __init__(self, a=0, b=1, bias=False, bias_only=False, module_filter='*'):
85 | self.a = a
86 | self.b = b
87 |
88 | self.bias = bias
89 | self.bias_only = bias_only
90 | self.module_filter = module_filter
91 |
92 | super(Uniform, self).__init__()
93 |
94 | def __call__(self, module):
95 | classname = module.__class__.__name__
96 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
97 | if self.bias_only:
98 | torch.nn.init.uniform(module.bias.data, a=self.a, b=self.b)
99 | else:
100 | torch.nn.init.uniform(module.weight.data, a=self.a, b=self.b)
101 | if self.bias:
102 | torch.nn.init.uniform(module.bias.data, a=self.a, b=self.b)
103 |
104 |
105 | class ConstantInitializer(Initializer):
106 |
107 | def __init__(self, value, bias=False, bias_only=False, module_filter='*'):
108 | self.value = value
109 |
110 | self.bias = bias
111 | self.bias_only = bias_only
112 | self.module_filter = module_filter
113 |
114 | super(ConstantInitializer, self).__init__()
115 |
116 |     def __call__(self, module):
117 | classname = module.__class__.__name__
118 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
119 | if self.bias_only:
120 | torch.nn.init.constant(module.bias.data, val=self.value)
121 | else:
122 | torch.nn.init.constant(module.weight.data, val=self.value)
123 | if self.bias:
124 | torch.nn.init.constant(module.bias.data, val=self.value)
125 |
126 |
127 | class XavierUniform(Initializer):
128 |
129 | def __init__(self, gain=1, bias=False, bias_only=False, module_filter='*'):
130 | self.gain = gain
131 |
132 | self.bias = bias
133 | self.bias_only = bias_only
134 | self.module_filter = module_filter
135 |
136 | super(XavierUniform, self).__init__()
137 |
138 | def __call__(self, module):
139 | classname = module.__class__.__name__
140 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
141 | if self.bias_only:
142 | torch.nn.init.xavier_uniform(module.bias.data, gain=self.gain)
143 | else:
144 | torch.nn.init.xavier_uniform(module.weight.data, gain=self.gain)
145 | if self.bias:
146 | torch.nn.init.xavier_uniform(module.bias.data, gain=self.gain)
147 |
148 |
149 | class XavierNormal(Initializer):
150 |
151 | def __init__(self, gain=1, bias=False, bias_only=False, module_filter='*'):
152 | self.gain = gain
153 |
154 | self.bias = bias
155 | self.bias_only = bias_only
156 | self.module_filter = module_filter
157 |
158 | super(XavierNormal, self).__init__()
159 |
160 | def __call__(self, module):
161 | classname = module.__class__.__name__
162 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
163 | if self.bias_only:
164 | torch.nn.init.xavier_normal(module.bias.data, gain=self.gain)
165 | else:
166 | torch.nn.init.xavier_normal(module.weight.data, gain=self.gain)
167 | if self.bias:
168 | torch.nn.init.xavier_normal(module.bias.data, gain=self.gain)
169 |
170 |
171 | class KaimingUniform(Initializer):
172 |
173 | def __init__(self, a=0, mode='fan_in', bias=False, bias_only=False, module_filter='*'):
174 | self.a = a
175 | self.mode = mode
176 |
177 | self.bias = bias
178 | self.bias_only = bias_only
179 | self.module_filter = module_filter
180 |
181 | super(KaimingUniform, self).__init__()
182 |
183 | def __call__(self, module):
184 | classname = module.__class__.__name__
185 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
186 | if self.bias_only:
187 | torch.nn.init.kaiming_uniform(module.bias.data, a=self.a, mode=self.mode)
188 | else:
189 | torch.nn.init.kaiming_uniform(module.weight.data, a=self.a, mode=self.mode)
190 | if self.bias:
191 | torch.nn.init.kaiming_uniform(module.bias.data, a=self.a, mode=self.mode)
192 |
193 |
194 | class KaimingNormal(Initializer):
195 |
196 | def __init__(self, a=0, mode='fan_in', bias=False, bias_only=False, module_filter='*'):
197 | self.a = a
198 | self.mode = mode
199 |
200 | self.bias = bias
201 | self.bias_only = bias_only
202 | self.module_filter = module_filter
203 |
204 | super(KaimingNormal, self).__init__()
205 |
206 | def __call__(self, module):
207 | classname = module.__class__.__name__
208 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
209 | if self.bias_only:
210 | torch.nn.init.kaiming_normal(module.bias.data, a=self.a, mode=self.mode)
211 | else:
212 | torch.nn.init.kaiming_normal(module.weight.data, a=self.a, mode=self.mode)
213 | if self.bias:
214 | torch.nn.init.kaiming_normal(module.bias.data, a=self.a, mode=self.mode)
215 |
216 |
217 | class Orthogonal(Initializer):
218 |
219 | def __init__(self, gain=1, bias=False, bias_only=False, module_filter='*'):
220 | self.gain = gain
221 |
222 | self.bias = bias
223 | self.bias_only = bias_only
224 | self.module_filter = module_filter
225 |
226 | super(Orthogonal, self).__init__()
227 |
228 | def __call__(self, module):
229 | classname = module.__class__.__name__
230 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
231 | if self.bias_only:
232 | torch.nn.init.orthogonal(module.bias.data, gain=self.gain)
233 | else:
234 | torch.nn.init.orthogonal(module.weight.data, gain=self.gain)
235 | if self.bias:
236 | torch.nn.init.orthogonal(module.bias.data, gain=self.gain)
237 |
238 |
239 | class Sparse(Initializer):
240 |
241 | def __init__(self, sparsity, std=0.01, bias=False, bias_only=False, module_filter='*'):
242 | self.sparsity = sparsity
243 | self.std = std
244 |
245 | self.bias = bias
246 | self.bias_only = bias_only
247 | self.module_filter = module_filter
248 |
249 | super(Sparse, self).__init__()
250 |
251 | def __call__(self, module):
252 | classname = module.__class__.__name__
253 | if fnmatch(classname, self.module_filter) and hasattr(module, 'weight'):
254 | if self.bias_only:
255 | torch.nn.init.sparse(module.bias.data, sparsity=self.sparsity, std=self.std)
256 | else:
257 | torch.nn.init.sparse(module.weight.data, sparsity=self.sparsity, std=self.std)
258 | if self.bias:
259 | torch.nn.init.sparse(module.bias.data, sparsity=self.sparsity, std=self.std)
260 |
261 |
262 |
263 |
--------------------------------------------------------------------------------
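A sketch of how these initializers appear intended to compose; note that `module_filter` is matched with `fnmatch` against the module's class name, so 'Conv*' covers Conv1d/Conv2d/Conv3d:

    import torch.nn as nn
    from torchsample.initializers import (InitializerContainer, XavierUniform,
                                          ConstantInitializer)

    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 8, 3))
    initializers = InitializerContainer([
        XavierUniform(gain=1.0, module_filter='Conv*'),             # conv weights
        ConstantInitializer(0.0, bias_only=True, module_filter='Conv*'),
    ])
    initializers.apply(model)    # runs model.apply(...) once per initializer
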
/torchsample/metrics.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import print_function
4 |
5 | import torch as th
6 |
7 | from .utils import th_matrixcorr
8 |
9 | from .callbacks import Callback
10 |
11 | class MetricContainer(object):
12 |
13 |
14 | def __init__(self, metrics, prefix=''):
15 | self.metrics = metrics
16 | self.helper = None
17 | self.prefix = prefix
18 |
19 | def set_helper(self, helper):
20 | self.helper = helper
21 |
22 | def reset(self):
23 | for metric in self.metrics:
24 | metric.reset()
25 |
26 | def __call__(self, output_batch, target_batch):
27 | logs = {}
28 | for metric in self.metrics:
29 | logs[self.prefix+metric._name] = self.helper.calculate_loss(output_batch,
30 | target_batch,
31 | metric)
32 | return logs
33 |
34 | class Metric(object):
35 |
36 | def __call__(self, y_pred, y_true):
37 | raise NotImplementedError('Custom Metrics must implement this function')
38 |
39 | def reset(self):
40 | raise NotImplementedError('Custom Metrics must implement this function')
41 |
42 |
43 | class MetricCallback(Callback):
44 |
45 | def __init__(self, container):
46 | self.container = container
47 | def on_epoch_begin(self, epoch_idx, logs):
48 | self.container.reset()
49 |
50 | class CategoricalAccuracy(Metric):
51 |
52 | def __init__(self, top_k=1):
53 | self.top_k = top_k
54 | self.correct_count = 0
55 | self.total_count = 0
56 |
57 | self._name = 'acc_metric'
58 |
59 | def reset(self):
60 | self.correct_count = 0
61 | self.total_count = 0
62 |
63 | def __call__(self, y_pred, y_true):
64 | top_k = y_pred.topk(self.top_k,1)[1]
65 | true_k = y_true.view(len(y_true),1).expand_as(top_k)
66 | self.correct_count += top_k.eq(true_k).float().sum().data[0]
67 | self.total_count += len(y_pred)
68 | accuracy = 100. * float(self.correct_count) / float(self.total_count)
69 | return accuracy
70 |
71 |
72 | class BinaryAccuracy(Metric):
73 |
74 | def __init__(self):
75 | self.correct_count = 0
76 | self.total_count = 0
77 |
78 | self._name = 'acc_metric'
79 |
80 | def reset(self):
81 | self.correct_count = 0
82 | self.total_count = 0
83 |
84 | def __call__(self, y_pred, y_true):
85 | y_pred_round = y_pred.round().long()
86 | self.correct_count += y_pred_round.eq(y_true).float().sum().data[0]
87 | self.total_count += len(y_pred)
88 | accuracy = 100. * float(self.correct_count) / float(self.total_count)
89 | return accuracy
90 |
91 |
92 | class ProjectionCorrelation(Metric):
93 |
94 | def __init__(self):
95 | self.corr_sum = 0.
96 | self.total_count = 0.
97 |
98 | self._name = 'corr_metric'
99 |
100 | def reset(self):
101 | self.corr_sum = 0.
102 | self.total_count = 0.
103 |
104 | def __call__(self, y_pred, y_true=None):
105 | """
106 | y_pred should be two projections
107 | """
108 | covar_mat = th.abs(th_matrixcorr(y_pred[0].data, y_pred[1].data))
109 | self.corr_sum += th.trace(covar_mat)
110 | self.total_count += covar_mat.size(0)
111 | return self.corr_sum / self.total_count
112 |
113 |
114 | class ProjectionAntiCorrelation(Metric):
115 |
116 | def __init__(self):
117 | self.anticorr_sum = 0.
118 | self.total_count = 0.
119 |
120 | self._name = 'anticorr_metric'
121 |
122 | def reset(self):
123 | self.anticorr_sum = 0.
124 | self.total_count = 0.
125 |
126 | def __call__(self, y_pred, y_true=None):
127 | """
128 | y_pred should be two projections
129 | """
130 | covar_mat = th.abs(th_matrixcorr(y_pred[0].data, y_pred[1].data))
131 | upper_sum = th.sum(th.triu(covar_mat,1))
132 | lower_sum = th.sum(th.tril(covar_mat,-1))
133 | self.anticorr_sum += upper_sum
134 | self.anticorr_sum += lower_sum
135 | self.total_count += covar_mat.size(0)*(covar_mat.size(1) - 1)
136 | return self.anticorr_sum / self.total_count
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
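A small worked example of the accuracy metric above (Variable-era API, matching the `.data[0]` usage in this file). Note the counts are cumulative across calls until `reset()`:

    import torch
    from torch.autograd import Variable
    from torchsample.metrics import CategoricalAccuracy

    metric = CategoricalAccuracy(top_k=1)
    y_pred = Variable(torch.FloatTensor([[0.1, 0.9],
                                         [0.8, 0.2]]))
    y_true = Variable(torch.LongTensor([1, 0]))
    print(metric(y_pred, y_true))   # 100.0 -- both argmax predictions correct
    metric.reset()                  # clear counts, e.g. at an epoch boundary
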
/torchsample/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .module_trainer import ModuleTrainer
4 |
--------------------------------------------------------------------------------
/torchsample/modules/_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import datetime
3 | import warnings
4 |
5 | try:
6 | from inspect import signature
7 | except ImportError:
8 | warnings.warn('inspect.signature not available... '
9 | 'you should upgrade to Python 3.x')
10 |
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 |
14 | from ..metrics import Metric, CategoricalAccuracy, BinaryAccuracy
15 | from ..initializers import GeneralInitializer
16 |
17 | def _add_regularizer_to_loss_fn(loss_fn,
18 | regularizer_container):
19 | def new_loss_fn(output_batch, target_batch):
20 | return loss_fn(output_batch, target_batch) + regularizer_container.get_value()
21 | return new_loss_fn
22 |
23 | def _is_iterable(x):
24 | return isinstance(x, (tuple, list))
25 | def _is_tuple_or_list(x):
26 | return isinstance(x, (tuple, list))
27 |
28 | def _parse_num_inputs_and_targets_from_loader(loader):
29 |     """ Infer the number of inputs and targets from a loader's dataset """
30 |     # assumes the dataset exposes `num_inputs` and `num_targets` attributes
31 | num_inputs = loader.dataset.num_inputs
32 | num_targets = loader.dataset.num_targets
33 | return num_inputs, num_targets
34 |
35 | def _parse_num_inputs_and_targets(inputs, targets=None):
36 | if isinstance(inputs, (list, tuple)):
37 | num_inputs = len(inputs)
38 | else:
39 | num_inputs = 1
40 | if targets is not None:
41 | if isinstance(targets, (list, tuple)):
42 | num_targets = len(targets)
43 | else:
44 | num_targets = 1
45 | else:
46 | num_targets = 0
47 | return num_inputs, num_targets
48 |
49 | def _standardize_user_data(inputs, targets=None):
50 | if not isinstance(inputs, (list,tuple)):
51 | inputs = [inputs]
52 | if targets is not None:
53 | if not isinstance(targets, (list,tuple)):
54 | targets = [targets]
55 | return inputs, targets
56 | else:
57 | return inputs
58 |
59 | def _validate_metric_input(metric):
60 | if isinstance(metric, str):
61 | if metric.upper() == 'CATEGORICAL_ACCURACY' or metric.upper() == 'ACCURACY':
62 | return CategoricalAccuracy()
63 | elif metric.upper() == 'BINARY_ACCURACY':
64 | return BinaryAccuracy()
65 | else:
66 | raise ValueError('Invalid metric string input - must match pytorch function.')
67 | elif isinstance(metric, Metric):
68 | return metric
69 | else:
70 | raise ValueError('Invalid metric input')
71 |
72 | def _validate_loss_input(loss):
73 | dir_f = dir(F)
74 | loss_fns = [d.lower() for d in dir_f]
75 | if isinstance(loss, str):
76 | if loss.lower() == 'unconstrained':
77 | return lambda x: x
78 | elif loss.lower() == 'unconstrained_sum':
79 | return lambda x: x.sum()
80 | elif loss.lower() == 'unconstrained_mean':
81 | return lambda x: x.mean()
82 | else:
83 | try:
84 | str_idx = loss_fns.index(loss.lower())
85 |             except ValueError:
86 | raise ValueError('Invalid loss string input - must match pytorch function.')
87 | return getattr(F, dir(F)[str_idx])
88 | elif callable(loss):
89 | return loss
90 | else:
91 | raise ValueError('Invalid loss input')
92 |
93 | def _validate_optimizer_input(optimizer):
94 | dir_optim = dir(optim)
95 | opts = [o.lower() for o in dir_optim]
96 | if isinstance(optimizer, str):
97 | try:
98 | str_idx = opts.index(optimizer.lower())
99 |         except ValueError:
100 | raise ValueError('Invalid optimizer string input - must match pytorch function.')
101 | return getattr(optim, dir_optim[str_idx])
102 | elif hasattr(optimizer, 'step') and hasattr(optimizer, 'zero_grad'):
103 | return optimizer
104 | else:
105 | raise ValueError('Invalid optimizer input')
106 |
107 | def _validate_initializer_input(initializer):
108 | if isinstance(initializer, str):
109 | try:
110 | initializer = GeneralInitializer(initializer)
111 |         except ValueError:
112 | raise ValueError('Invalid initializer string input - must match pytorch function.')
113 | return initializer
114 | elif callable(initializer):
115 | return initializer
116 | else:
117 |         raise ValueError('Invalid initializer input')
118 |
119 | def _get_current_time():
120 | return datetime.datetime.now().strftime("%B %d, %Y - %I:%M%p")
121 |
122 | def _nb_function_args(fn):
123 | return len(signature(fn).parameters)
--------------------------------------------------------------------------------
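Illustrative behaviour of the string validators above: loss and optimizer names are matched case-insensitively against torch's own namespaces.

    import torch.nn.functional as F
    import torch.optim as optim
    from torchsample.modules._utils import (_validate_loss_input,
                                            _validate_optimizer_input)

    assert _validate_loss_input('NLL_Loss') is F.nll_loss   # case-insensitive
    assert _validate_optimizer_input('adam') is optim.Adam
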
/torchsample/regularizers.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | from fnmatch import fnmatch
4 |
5 | from .callbacks import Callback
6 |
7 | class RegularizerContainer(object):
8 |
9 | def __init__(self, regularizers):
10 | self.regularizers = regularizers
11 | self._forward_hooks = []
12 |
13 | def register_forward_hooks(self, model):
14 | for regularizer in self.regularizers:
15 | for module_name, module in model.named_modules():
16 | if fnmatch(module_name, regularizer.module_filter) and hasattr(module, 'weight'):
17 | hook = module.register_forward_hook(regularizer)
18 | self._forward_hooks.append(hook)
19 |
20 | if len(self._forward_hooks) == 0:
21 | raise Exception('Tried to register regularizers but no modules '
22 | 'were found that matched any module_filter argument.')
23 |
24 | def unregister_forward_hooks(self):
25 | for hook in self._forward_hooks:
26 | hook.remove()
27 |
28 | def reset(self):
29 | for r in self.regularizers:
30 | r.reset()
31 |
32 | def get_value(self):
33 | value = sum([r.value for r in self.regularizers])
34 | self.current_value = value.data[0]
35 | return value
36 |
37 | def __len__(self):
38 | return len(self.regularizers)
39 |
40 |
41 | class RegularizerCallback(Callback):
42 |
43 | def __init__(self, container):
44 | self.container = container
45 |
46 | def on_batch_end(self, batch, logs=None):
47 | self.container.reset()
48 |
49 |
50 | class Regularizer(object):
51 |
52 | def reset(self):
53 | raise NotImplementedError('subclass must implement this method')
54 |
55 | def __call__(self, module, input=None, output=None):
56 | raise NotImplementedError('subclass must implement this method')
57 |
58 |
59 | class L1Regularizer(Regularizer):
60 |
61 | def __init__(self, scale=1e-3, module_filter='*'):
62 | self.scale = float(scale)
63 | self.module_filter = module_filter
64 | self.value = 0.
65 |
66 | def reset(self):
67 | self.value = 0.
68 |
69 | def __call__(self, module, input=None, output=None):
70 | value = th.sum(th.abs(module.weight)) * self.scale
71 | self.value += value
72 |
73 |
74 | class L2Regularizer(Regularizer):
75 |
76 | def __init__(self, scale=1e-3, module_filter='*'):
77 | self.scale = float(scale)
78 | self.module_filter = module_filter
79 | self.value = 0.
80 |
81 | def reset(self):
82 | self.value = 0.
83 |
84 | def __call__(self, module, input=None, output=None):
85 | value = th.sum(th.pow(module.weight,2)) * self.scale
86 | self.value += value
87 |
88 |
89 | class L1L2Regularizer(Regularizer):
90 |
91 | def __init__(self, l1_scale=1e-3, l2_scale=1e-3, module_filter='*'):
92 | self.l1 = L1Regularizer(l1_scale)
93 | self.l2 = L2Regularizer(l2_scale)
94 | self.module_filter = module_filter
95 | self.value = 0.
96 |
97 | def reset(self):
98 | self.value = 0.
99 |
100 | def __call__(self, module, input=None, output=None):
101 | self.l1(module, input, output)
102 | self.l2(module, input, output)
103 | self.value += (self.l1.value + self.l2.value)
104 |
105 |
106 | # ------------------------------------------------------------------
107 | # ------------------------------------------------------------------
108 | # ------------------------------------------------------------------
109 |
110 | class UnitNormRegularizer(Regularizer):
111 | """
112 |     UnitNorm regularizer on Weights
113 | 
114 |     Penalizes column-wise weight norms that rise above one
115 | """
116 | def __init__(self,
117 | scale=1e-3,
118 | module_filter='*'):
119 |
120 | self.scale = scale
121 | self.module_filter = module_filter
122 | self.value = 0.
123 |
124 | def reset(self):
125 | self.value = 0.
126 |
127 | def __call__(self, module, input=None, output=None):
128 | w = module.weight
129 | norm_diff = th.norm(w, 2, 1).sub(1.)
130 | value = self.scale * th.sum(norm_diff.gt(0).float().mul(norm_diff))
131 | self.value += value
132 |
133 |
134 | class MaxNormRegularizer(Regularizer):
135 | """
136 |     MaxNorm regularizer on Weights
137 | 
138 |     Penalizes weight norms along `axis` that exceed `max_value`
139 | """
140 |     def __init__(self, scale=1e-3, max_value=1.0, axis=0, module_filter='*'):
141 |         self.scale = scale
142 |         self.max_value = max_value
143 |         self.axis = axis
144 |         self.module_filter = module_filter
145 |         self.value = 0.
146 | 
147 |     def reset(self):
148 |         self.value = 0.
149 | 
150 |     def __call__(self, module, input=None, output=None):
151 |         w = module.weight
152 |         # penalize any weight norm along `axis` that exceeds `max_value`
153 |         norm_diff = th.norm(w, 2, self.axis).sub(self.max_value)
154 |         value = self.scale * th.sum(norm_diff.gt(0).float().mul(norm_diff))
155 |         self.value += value
156 | 
157 |
158 | class NonNegRegularizer(Regularizer):
159 | """
160 | Non-Negativity regularizer on Weights
161 |
162 |     Penalizes negative weight values
163 | """
164 | def __init__(self,
165 | scale=1e-3,
166 | module_filter='*'):
167 |
168 | self.scale = scale
169 | self.module_filter = module_filter
170 | self.value = 0.
171 |
172 | def reset(self):
173 | self.value = 0.
174 |
175 | def __call__(self, module, input=None, output=None):
176 | w = module.weight
177 |         value = -1 * self.scale * th.sum(w.lt(0).float().mul(w))
178 | self.value += value
179 |
180 |
--------------------------------------------------------------------------------
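A sketch of the intended hook-based flow for the regularizers above; the criterion line is illustrative:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    from torchsample.regularizers import RegularizerContainer, L1Regularizer

    model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
    reg = RegularizerContainer([L1Regularizer(scale=1e-4)])
    reg.register_forward_hooks(model)   # hooks every module with a `weight`

    output = model(Variable(torch.randn(4, 10)))
    penalty = reg.get_value()           # penalties accumulated by this pass
    # total_loss = criterion(output, target) + penalty
    reg.reset()                         # clear accumulators for the next batch
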
/torchsample/samplers.py:
--------------------------------------------------------------------------------
1 |
2 | import torch as th
3 | import math
4 |
5 | class Sampler(object):
6 | """Base class for all Samplers.
7 |
8 | Every Sampler subclass has to provide an __iter__ method, providing a way
9 | to iterate over indices of dataset elements, and a __len__ method that
10 | returns the length of the returned iterators.
11 | """
12 |
13 | def __init__(self, data_source):
14 | pass
15 |
16 | def __iter__(self):
17 | raise NotImplementedError
18 |
19 | def __len__(self):
20 | raise NotImplementedError
21 |
22 | class StratifiedSampler(Sampler):
23 | """Stratified Sampling
24 |
25 | Provides equal representation of target classes in each batch
26 | """
27 | def __init__(self, class_vector, batch_size):
28 | """
29 | Arguments
30 | ---------
31 | class_vector : torch tensor
32 | a vector of class labels
33 | batch_size : integer
34 | batch_size
35 | """
36 | self.n_splits = int(class_vector.size(0) / batch_size)
37 | self.class_vector = class_vector
38 |
39 | def gen_sample_array(self):
40 |         try:
41 |             from sklearn.model_selection import StratifiedShuffleSplit
42 |         except ImportError:
43 |             raise ImportError('Need scikit-learn for this functionality')
44 | import numpy as np
45 |
46 | s = StratifiedShuffleSplit(n_splits=self.n_splits, test_size=0.5)
47 | X = th.randn(self.class_vector.size(0),2).numpy()
48 | y = self.class_vector.numpy()
49 | s.get_n_splits(X, y)
50 |
51 | train_index, test_index = next(s.split(X, y))
52 | return np.hstack([train_index, test_index])
53 |
54 | def __iter__(self):
55 | return iter(self.gen_sample_array())
56 |
57 | def __len__(self):
58 | return len(self.class_vector)
59 |
60 | class MultiSampler(Sampler):
61 | """Samples elements more than once in a single pass through the data.
62 |
63 | This allows the number of samples per epoch to be larger than the number
64 | of samples itself, which can be useful when training on 2D slices taken
65 | from 3D images, for instance.
66 | """
67 | def __init__(self, nb_samples, desired_samples, shuffle=False):
68 | """Initialize MultiSampler
69 |
70 | Arguments
71 | ---------
72 |         nb_samples : int
73 |             the number of samples in the dataset
74 |         desired_samples : int
75 |             number of samples you want per pass through the data;
76 |             any remainder beyond an even multiple of `nb_samples`
77 |             is filled by sampling randomly from the data, e.g. if
78 |             nb_samples = 3 and desired_samples = 4, then all 3 samples
79 |             are included and the 4th is drawn randomly from them.
80 |
81 | shuffle : boolean
82 | whether to shuffle the indices or not
83 |
84 | Example:
85 | >>> m = MultiSampler(2, 6)
86 | >>> x = m.gen_sample_array()
87 | >>> print(x) # [0,1,0,1,0,1]
88 | """
89 | self.data_samples = nb_samples
90 | self.desired_samples = desired_samples
91 | self.shuffle = shuffle
92 |
93 | def gen_sample_array(self):
94 | from torchsample.utils import th_random_choice
95 | n_repeats = self.desired_samples / self.data_samples
96 | cat_list = []
97 | for i in range(math.floor(n_repeats)):
98 | cat_list.append(th.arange(0,self.data_samples))
99 | # add the left over samples
100 | left_over = self.desired_samples % self.data_samples
101 | if left_over > 0:
102 | cat_list.append(th_random_choice(self.data_samples, left_over))
103 |         self.sample_idx_array = th.cat(cat_list).long()
104 |         if self.shuffle:
105 |             self.sample_idx_array = self.sample_idx_array[th.randperm(len(self.sample_idx_array))]
106 |         return self.sample_idx_array
107 | 
106 | def __iter__(self):
107 | return iter(self.gen_sample_array())
108 |
109 | def __len__(self):
110 | return self.desired_samples
111 |
112 |
113 | class SequentialSampler(Sampler):
114 | """Samples elements sequentially, always in the same order.
115 |
116 | Arguments:
117 | data_source (Dataset): dataset to sample from
118 | """
119 |
120 | def __init__(self, nb_samples):
121 | self.num_samples = nb_samples
122 |
123 | def __iter__(self):
124 | return iter(range(self.num_samples))
125 |
126 | def __len__(self):
127 | return self.num_samples
128 |
129 |
130 | class RandomSampler(Sampler):
131 | """Samples elements randomly, without replacement.
132 |
133 | Arguments:
134 | data_source (Dataset): dataset to sample from
135 | """
136 |
137 | def __init__(self, nb_samples):
138 | self.num_samples = nb_samples
139 |
140 | def __iter__(self):
141 | return iter(th.randperm(self.num_samples).long())
142 |
143 | def __len__(self):
144 | return self.num_samples
145 |
146 |
147 |
--------------------------------------------------------------------------------
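A usage sketch for `StratifiedSampler` with a standard DataLoader (requires scikit-learn, per the import above):

    import torch as th
    from torch.utils.data import DataLoader, TensorDataset
    from torchsample.samplers import StratifiedSampler

    X = th.randn(100, 3, 32, 32)
    y = th.cat([th.zeros(50), th.ones(50)]).long()

    sampler = StratifiedSampler(class_vector=y, batch_size=10)
    loader = DataLoader(TensorDataset(X, y), batch_size=10, sampler=sampler)
    # each batch now draws a roughly equal number of samples per class
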
/torchsample/transforms/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from .affine_transforms import *
5 | from .image_transforms import *
6 | from .tensor_transforms import *
--------------------------------------------------------------------------------
/torchsample/transforms/distortion_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Transforms to distort local or global information of an image
3 | """
4 |
5 |
6 | import torch as th
7 | import numpy as np
8 | import random
9 |
10 |
11 | class Scramble(object):
12 | """
13 | Create blocks of an image and scramble them
14 | """
15 | def __init__(self, blocksize):
16 | self.blocksize = blocksize
17 |
18 | def __call__(self, *inputs):
19 | outputs = []
20 | for idx, _input in enumerate(inputs):
21 | size = _input.size()
22 | img_height = size[1]
23 | img_width = size[2]
24 |
25 | x_blocks = int(img_height/self.blocksize) # number of x blocks
26 | y_blocks = int(img_width/self.blocksize)
27 | ind = th.randperm(x_blocks*y_blocks)
28 |
29 | new = th.zeros(_input.size())
30 | count = 0
31 | for i in range(x_blocks):
32 | for j in range (y_blocks):
33 | row = int(ind[count] / x_blocks)
34 | column = ind[count] % x_blocks
35 | new[:, i*self.blocksize:(i+1)*self.blocksize, j*self.blocksize:(j+1)*self.blocksize] = \
36 | _input[:, row*self.blocksize:(row+1)*self.blocksize, column*self.blocksize:(column+1)*self.blocksize]
37 | count += 1
38 | outputs.append(new)
39 |         return outputs if idx >= 1 else outputs[0]
40 |
41 |
42 | class RandomChoiceScramble(object):
43 |
44 | def __init__(self, blocksizes):
45 | self.blocksizes = blocksizes
46 |
47 | def __call__(self, *inputs):
48 | blocksize = random.choice(self.blocksizes)
49 | outputs = Scramble(blocksize=blocksize)(*inputs)
50 | return outputs
51 |
52 |
53 | def _blur_image(image, H):
54 | # break image up into its color components
55 | size = image.shape
56 | imr = image[0,:,:]
57 | img = image[1,:,:]
58 | imb = image[2,:,:]
59 |
60 |     # compute Fourier transform and frequency spectrum
61 | Fim1r = np.fft.fftshift(np.fft.fft2(imr))
62 | Fim1g = np.fft.fftshift(np.fft.fft2(img))
63 | Fim1b = np.fft.fftshift(np.fft.fft2(imb))
64 |
65 | # Apply the lowpass filter to the Fourier spectrum of the image
66 | filtered_imager = np.multiply(H, Fim1r)
67 | filtered_imageg = np.multiply(H, Fim1g)
68 | filtered_imageb = np.multiply(H, Fim1b)
69 |
70 | newim = np.zeros(size)
71 |
72 | # convert the result to the spatial domain.
73 | newim[0,:,:] = np.absolute(np.real(np.fft.ifft2(filtered_imager)))
74 | newim[1,:,:] = np.absolute(np.real(np.fft.ifft2(filtered_imageg)))
75 | newim[2,:,:] = np.absolute(np.real(np.fft.ifft2(filtered_imageb)))
76 |
77 | return newim.astype('uint8')
78 |
79 | def _butterworth_filter(rows, cols, thresh, order):
80 | # X and Y matrices with ranges normalised to +/- 0.5
81 | array1 = np.ones(rows)
82 | array2 = np.ones(cols)
83 | array3 = np.arange(1,rows+1)
84 | array4 = np.arange(1,cols+1)
85 |
86 | x = np.outer(array1, array4)
87 | y = np.outer(array3, array2)
88 |
89 | x = x - float(cols/2) - 1
90 | y = y - float(rows/2) - 1
91 |
92 | x = x / cols
93 | y = y / rows
94 |
95 | radius = np.sqrt(np.square(x) + np.square(y))
96 |
97 | matrix1 = radius/thresh
98 | matrix2 = np.power(matrix1, 2*order)
99 | f = np.reciprocal(1 + matrix2)
100 |
101 | return f
102 |
103 |
104 | class Blur(object):
105 | """
106 | Blur an image with a Butterworth filter with a frequency
107 | cutoff matching local block size
108 | """
109 | def __init__(self, threshold, order=5):
110 | """
111 | scramble blocksize of 128 => filter threshold of 64
112 | scramble blocksize of 64 => filter threshold of 32
113 | scramble blocksize of 32 => filter threshold of 16
114 | scramble blocksize of 16 => filter threshold of 8
115 | scramble blocksize of 8 => filter threshold of 4
116 | """
117 | self.threshold = threshold
118 | self.order = order
119 |
120 | def __call__(self, *inputs):
121 | """
122 | inputs should have values between 0 and 255
123 | """
124 | outputs = []
125 | for idx, _input in enumerate(inputs):
126 | rows = _input.size(1)
127 | cols = _input.size(2)
128 | fc = self.threshold # threshold
129 | fs = 128.0 # max frequency
130 | n = self.order # filter order
131 | fc_rad = (fc/fs)*0.5
132 | H = _butterworth_filter(rows, cols, fc_rad, n)
133 | _input_blurred = _blur_image(_input.numpy().astype('uint8'), H)
134 | _input_blurred = th.from_numpy(_input_blurred).float()
135 | outputs.append(_input_blurred)
136 |
137 |         return outputs if idx >= 1 else outputs[0]
138 |
139 |
140 | class RandomChoiceBlur(object):
141 |
142 | def __init__(self, thresholds, order=5):
143 | """
144 | thresholds = [64.0, 32.0, 16.0, 8.0, 4.0]
145 | """
146 | self.thresholds = thresholds
147 | self.order = order
148 |
149 | def __call__(self, *inputs):
150 | threshold = random.choice(self.thresholds)
151 | outputs = Blur(threshold=threshold, order=self.order)(*inputs)
152 | return outputs
153 |
154 |
155 |
156 |
157 |
158 |
159 |
--------------------------------------------------------------------------------
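A brief sketch of the distortion transforms above, operating on CHW torch tensors; `Blur` expects 0-255 intensities per its docstring:

    import torch as th
    from torchsample.transforms.distortion_transforms import Scramble, RandomChoiceBlur

    img = th.rand(3, 64, 64)                  # CHW image in [0, 1]
    scrambled = Scramble(blocksize=16)(img)   # shuffle a 4x4 grid of 16px blocks
    blurred = RandomChoiceBlur([64., 32., 16.])(img.mul(255))
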
/torchsample/transforms/image_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Transforms very specific to images such as
3 | color, lighting, contrast, brightness, etc transforms
4 |
5 | NOTE: Most of these transforms assume your image intensity
6 | is between 0 and 1, and are torch tensors (NOT numpy or PIL)
7 | """
8 |
9 | import random
10 |
11 | import torch as th
12 |
13 | from ..utils import th_random_choice
14 |
15 |
16 | def _blend(img1, img2, alpha):
17 | """
18 | Weighted sum of two images
19 |
20 | Arguments
21 | ---------
22 | img1 : torch tensor
23 | img2 : torch tensor
24 | alpha : float between 0 and 1
25 | how much weight to put on img1 and 1-alpha weight
26 | to put on img2
27 | """
28 | return img1.mul(alpha).add(1 - alpha, img2)
29 |
30 |
31 | class Grayscale(object):
32 |
33 | def __init__(self, keep_channels=False):
34 | """
35 | Convert RGB image to grayscale
36 |
37 | Arguments
38 | ---------
39 | keep_channels : boolean
40 | If true, will keep all 3 channels and they will be the same
41 | If false, will just return 1 grayscale channel
42 | """
43 | self.keep_channels = keep_channels
44 | if keep_channels:
45 | self.channels = 3
46 | else:
47 | self.channels = 1
48 |
49 | def __call__(self, *inputs):
50 | outputs = []
51 | for idx, _input in enumerate(inputs):
52 | _input_dst = _input[0]*0.299 + _input[1]*0.587 + _input[2]*0.114
53 | _input_gs = _input_dst.repeat(self.channels,1,1)
54 | outputs.append(_input_gs)
55 |         return outputs if idx >= 1 else outputs[0]
56 |
57 | class RandomGrayscale(object):
58 |
59 | def __init__(self, p=0.5):
60 | """
61 | Randomly convert RGB image(s) to Grayscale w/ some probability,
62 | NOTE: Always retains the 3 channels if image is grayscaled
63 |
64 | p : a float
65 | probability that image will be grayscaled
66 | """
67 | self.p = p
68 |
69 | def __call__(self, *inputs):
70 | pval = random.random()
71 | if pval < self.p:
72 | outputs = Grayscale(keep_channels=True)(*inputs)
73 | else:
74 | outputs = inputs
75 | return outputs
76 |
77 | # ----------------------------------------------------
78 | # ----------------------------------------------------
79 |
80 | class Gamma(object):
81 |
82 | def __init__(self, value):
83 | """
84 | Performs Gamma Correction on the input image. Also known as
85 | Power Law Transform. This function transforms the input image
86 | pixelwise according
87 | to the equation Out = In**gamma after scaling each
88 | pixel to the range 0 to 1.
89 |
90 | Arguments
91 | ---------
92 | value : float
93 | <1 : image will tend to be lighter
94 | =1 : image will stay the same
95 | >1 : image will tend to be darker
96 | """
97 | self.value = value
98 |
99 | def __call__(self, *inputs):
100 | outputs = []
101 | for idx, _input in enumerate(inputs):
102 | _input = th.pow(_input, self.value)
103 | outputs.append(_input)
104 |         return outputs if idx >= 1 else outputs[0]
105 |
106 | class RandomGamma(object):
107 |
108 | def __init__(self, min_val, max_val):
109 | """
110 | Performs Gamma Correction on the input image with some
111 | randomly selected gamma value between min_val and max_val.
112 | Also known as Power Law Transform. This function transforms
113 | the input image pixelwise according to the equation
114 | Out = In**gamma after scaling each pixel to the range 0 to 1.
115 |
116 | Arguments
117 | ---------
118 | min_val : float
119 | min range
120 | max_val : float
121 | max range
122 |
123 | NOTE:
124 | for values:
125 | <1 : image will tend to be lighter
126 | =1 : image will stay the same
127 | >1 : image will tend to be darker
128 | """
129 | self.values = (min_val, max_val)
130 |
131 | def __call__(self, *inputs):
132 | value = random.uniform(self.values[0], self.values[1])
133 | outputs = Gamma(value)(*inputs)
134 | return outputs
135 |
136 | class RandomChoiceGamma(object):
137 |
138 | def __init__(self, values, p=None):
139 | """
140 | Performs Gamma Correction on the input image with some
141 | gamma value selected in the list of given values.
142 | Also known as Power Law Transform. This function transforms
143 | the input image pixelwise according to the equation
144 | Out = In**gamma after scaling each pixel to the range 0 to 1.
145 |
146 | Arguments
147 | ---------
148 | values : list of floats
149 |             gamma values to sample from
150 | p : list of floats - same length as `values`
151 | if None, values will be sampled uniformly.
152 | Must sum to 1.
153 |
154 | NOTE:
155 | for values:
156 | <1 : image will tend to be lighter
157 | =1 : image will stay the same
158 | >1 : image will tend to be darker
159 | """
160 | self.values = values
161 | self.p = p
162 |
163 | def __call__(self, *inputs):
164 | value = th_random_choice(self.values, p=self.p)
165 | outputs = Gamma(value)(*inputs)
166 | return outputs
167 |
168 | # ----------------------------------------------------
169 | # ----------------------------------------------------
170 |
171 | class Brightness(object):
172 | def __init__(self, value):
173 | """
174 | Alter the Brightness of an image
175 |
176 | Arguments
177 | ---------
178 | value : brightness factor
179 | =-1 = completely black
180 | <0 = darker
181 | 0 = no change
182 | >0 = brighter
183 | =1 = completely white
184 | """
185 | self.value = max(min(value,1.0),-1.0)
186 |
187 | def __call__(self, *inputs):
188 | outputs = []
189 | for idx, _input in enumerate(inputs):
190 | _input = th.clamp(_input.float().add(self.value).type(_input.type()), 0, 1)
191 | outputs.append(_input)
192 |         return outputs if idx >= 1 else outputs[0]
193 |
194 | class RandomBrightness(object):
195 |
196 | def __init__(self, min_val, max_val):
197 | """
198 | Alter the Brightness of an image with a value randomly selected
199 | between `min_val` and `max_val`
200 |
201 | Arguments
202 | ---------
203 | min_val : float
204 | min range
205 | max_val : float
206 | max range
207 | """
208 | self.values = (min_val, max_val)
209 |
210 | def __call__(self, *inputs):
211 | value = random.uniform(self.values[0], self.values[1])
212 | outputs = Brightness(value)(*inputs)
213 | return outputs
214 |
215 | class RandomChoiceBrightness(object):
216 |
217 | def __init__(self, values, p=None):
218 | """
219 | Alter the Brightness of an image with a value randomly selected
220 | from the list of given values with given probabilities
221 |
222 | Arguments
223 | ---------
224 | values : list of floats
225 |             brightness values to sample from
226 | p : list of floats - same length as `values`
227 | if None, values will be sampled uniformly.
228 | Must sum to 1.
229 | """
230 | self.values = values
231 | self.p = p
232 |
233 | def __call__(self, *inputs):
234 | value = th_random_choice(self.values, p=self.p)
235 | outputs = Brightness(value)(*inputs)
236 | return outputs
237 |
238 | # ----------------------------------------------------
239 | # ----------------------------------------------------
240 |
241 | class Saturation(object):
242 |
243 | def __init__(self, value):
244 | """
245 | Alter the Saturation of image
246 |
247 | Arguments
248 | ---------
249 | value : float
250 | =-1 : gray
251 | <0 : colors are more muted
252 | =0 : image stays the same
253 | >0 : colors are more pure
254 | =1 : most saturated
255 | """
256 | self.value = max(min(value,1.0),-1.0)
257 |
258 | def __call__(self, *inputs):
259 | outputs = []
260 | for idx, _input in enumerate(inputs):
261 | _in_gs = Grayscale(keep_channels=True)(_input)
262 | alpha = 1.0 + self.value
263 | _in = th.clamp(_blend(_input, _in_gs, alpha), 0, 1)
264 | outputs.append(_in)
265 |         return outputs if idx >= 1 else outputs[0]
266 |
267 | class RandomSaturation(object):
268 |
269 | def __init__(self, min_val, max_val):
270 | """
271 | Alter the Saturation of an image with a value randomly selected
272 | between `min_val` and `max_val`
273 |
274 | Arguments
275 | ---------
276 | min_val : float
277 | min range
278 | max_val : float
279 | max range
280 | """
281 | self.values = (min_val, max_val)
282 |
283 | def __call__(self, *inputs):
284 | value = random.uniform(self.values[0], self.values[1])
285 | outputs = Saturation(value)(*inputs)
286 | return outputs
287 |
288 | class RandomChoiceSaturation(object):
289 |
290 | def __init__(self, values, p=None):
291 | """
292 | Alter the Saturation of an image with a value randomly selected
293 | from the list of given values with given probabilities
294 |
295 | Arguments
296 | ---------
297 | values : list of floats
298 |             saturation values to sample from
299 | p : list of floats - same length as `values`
300 | if None, values will be sampled uniformly.
301 | Must sum to 1.
302 |
303 | """
304 | self.values = values
305 | self.p = p
306 |
307 | def __call__(self, *inputs):
308 | value = th_random_choice(self.values, p=self.p)
309 | outputs = Saturation(value)(*inputs)
310 | return outputs
311 |
312 | # ----------------------------------------------------
313 | # ----------------------------------------------------
314 |
315 | class Contrast(object):
316 |     """
317 |     Adjust the contrast of an image, channel-wise
318 |     """
319 | def __init__(self, value):
320 | """
321 | Adjust Contrast of image.
322 |
323 | Contrast is adjusted independently for each channel of each image.
324 |
325 | For each channel, this Op computes the mean of the image pixels
326 | in the channel and then adjusts each component x of each pixel to
327 | (x - mean) * contrast_factor + mean.
328 |
329 | Arguments
330 | ---------
331 | value : float
332 | smaller value: less contrast
333 | ZERO: channel means
334 | larger positive value: greater contrast
335 | larger negative value: greater inverse contrast
336 | """
337 | self.value = value
338 |
339 | def __call__(self, *inputs):
340 | outputs = []
341 | for idx, _input in enumerate(inputs):
342 | channel_means = _input.mean(1).mean(2)
343 | channel_means = channel_means.expand_as(_input)
344 | _input = th.clamp((_input - channel_means) * self.value + channel_means,0,1)
345 | outputs.append(_input)
346 |         return outputs if idx >= 1 else outputs[0]
347 |
348 | class RandomContrast(object):
349 |
350 | def __init__(self, min_val, max_val):
351 | """
352 | Alter the Contrast of an image with a value randomly selected
353 | between `min_val` and `max_val`
354 |
355 | Arguments
356 | ---------
357 | min_val : float
358 | min range
359 | max_val : float
360 | max range
361 | """
362 | self.values = (min_val, max_val)
363 |
364 | def __call__(self, *inputs):
365 | value = random.uniform(self.values[0], self.values[1])
366 | outputs = Contrast(value)(*inputs)
367 | return outputs
368 |
369 | class RandomChoiceContrast(object):
370 |
371 | def __init__(self, values, p=None):
372 | """
373 | Alter the Contrast of an image with a value randomly selected
374 | from the list of given values with given probabilities
375 |
376 | Arguments
377 | ---------
378 | values : list of floats
379 |             contrast values to sample from
380 | p : list of floats - same length as `values`
381 |             if None, values will be sampled uniformly;
382 |             otherwise, must sum to 1.
383 |
384 | """
385 | self.values = values
386 | self.p = p
387 |
388 | def __call__(self, *inputs):
389 |         value = th_random_choice(self.values, p=self.p)
390 | outputs = Contrast(value)(*inputs)
391 | return outputs
392 |
393 | # ----------------------------------------------------
394 | # ----------------------------------------------------
395 |
396 | def rgb_to_hsv(x):
397 | """
398 | Convert from RGB to HSV
399 | """
400 | hsv = th.zeros(*x.size())
401 | c_min = x.min(0)
402 | c_max = x.max(0)
403 |
404 | delta = c_max[0] - c_min[0]
405 |
406 | # set H
407 | r_idx = c_max[1].eq(0)
408 | hsv[0][r_idx] = ((x[1][r_idx] - x[2][r_idx]) / delta[r_idx]) % 6
409 | g_idx = c_max[1].eq(1)
410 | hsv[0][g_idx] = 2 + ((x[2][g_idx] - x[0][g_idx]) / delta[g_idx])
411 | b_idx = c_max[1].eq(2)
412 | hsv[0][b_idx] = 4 + ((x[0][b_idx] - x[1][b_idx]) / delta[b_idx])
413 | hsv[0] = hsv[0].mul(60)
414 |
415 | # set S
416 | hsv[1] = delta / c_max[0]
417 |
418 |     # set V
419 | hsv[2] = c_max[0]
420 |
421 | return hsv
422 |
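A usage sketch for rgb_to_hsv (hedged, and reliant on the pre-0.4 PyTorch API this library targets): the input is a (3, H, W) RGB tensor in [0, 1]; channel 0 of the output holds hue in degrees [0, 360), channels 1 and 2 hold saturation and value in [0, 1]. Note that for gray pixels delta is zero, so the hue entries divide by zero:

    >>> import torch as th
    >>> x = th.rand(3, 4, 4)   # RGB, channels first, values in [0, 1]
    >>> hsv = rgb_to_hsv(x)    # hsv[0] = H (degrees), hsv[1] = S, hsv[2] = V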
--------------------------------------------------------------------------------
/torchsample/transforms/tensor_transforms.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import random
4 | import math
5 | import numpy as np
6 |
7 | import torch as th
8 | from torch.autograd import Variable
9 |
10 | from ..utils import th_random_choice
11 |
12 | class Compose(object):
13 | """
14 | Composes several transforms together.
15 | """
16 | def __init__(self, transforms):
17 | """
18 | Composes (chains) several transforms together into
19 | a single transform
20 |
21 | Arguments
22 | ---------
23 | transforms : a list of transforms
24 | transforms will be applied sequentially
25 | """
26 | self.transforms = transforms
27 |
28 | def __call__(self, *inputs):
29 | for transform in self.transforms:
30 | if not isinstance(inputs, (list,tuple)):
31 | inputs = [inputs]
32 | inputs = transform(*inputs)
33 | return inputs
34 |
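A minimal sketch of chaining transforms (the classes referenced here are all defined later in this file):

    >>> import numpy as np
    >>> tform = Compose([ToTensor(),           # numpy -> torch
    ...                  TypeCast('float'),    # -> th.FloatTensor
    ...                  AddChannel(axis=0),   # (H, W) -> (1, H, W)
    ...                  RangeNormalize(0, 1)])
    >>> x_t = tform(np.random.rand(28, 28))    # normalized (1, 28, 28) FloatTensor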
35 |
36 | class RandomChoiceCompose(object):
37 | """
38 | Randomly choose to apply one transform from a collection of transforms
39 |
40 | e.g. to randomly apply EITHER 0-1 or -1-1 normalization to an input:
41 | >>> transform = RandomChoiceCompose([RangeNormalize(0,1),
42 | RangeNormalize(-1,1)])
43 | >>> x_norm = transform(x) # only one of the two normalizations is applied
44 | """
45 | def __init__(self, transforms):
46 | self.transforms = transforms
47 |
48 | def __call__(self, *inputs):
49 | tform = random.choice(self.transforms)
50 | outputs = tform(*inputs)
51 | return outputs
52 |
53 |
54 | class ToTensor(object):
55 | """
56 | Converts a numpy array to torch.Tensor
57 | """
58 | def __call__(self, *inputs):
59 | outputs = []
60 | for idx, _input in enumerate(inputs):
61 | _input = th.from_numpy(_input)
62 | outputs.append(_input)
63 |         return outputs if idx >= 1 else outputs[0]
64 |
65 |
66 | class ToVariable(object):
67 | """
68 | Converts a torch.Tensor to autograd.Variable
69 | """
70 | def __call__(self, *inputs):
71 | outputs = []
72 | for idx, _input in enumerate(inputs):
73 | _input = Variable(_input)
74 | outputs.append(_input)
75 |         return outputs if idx >= 1 else outputs[0]
76 |
77 |
78 | class ToCuda(object):
79 | """
80 | Moves an autograd.Variable to the GPU
81 | """
82 | def __init__(self, device=0):
83 | """
84 | Moves an autograd.Variable to the GPU
85 |
86 | Arguments
87 | ---------
88 | device : integer
89 | which GPU device to put the input(s) on
90 | """
91 | self.device = device
92 |
93 | def __call__(self, *inputs):
94 | outputs = []
95 | for idx, _input in enumerate(inputs):
96 | _input = _input.cuda(self.device)
97 | outputs.append(_input)
98 |         return outputs if idx >= 1 else outputs[0]
99 |
100 |
101 | class ToFile(object):
102 | """
103 |     Saves an image to file. Useful as a pass-through transform
104 |     for observing how augmentation affects the data
105 |
106 |     NOTE: Only supports saving to NumPy (.npy) files currently
107 | """
108 | def __init__(self, root):
109 | """
110 |         Saves an image to file. Useful as a pass-through transform
111 |         for observing how augmentation affects the data
112 |
113 |         NOTE: Only supports saving to NumPy (.npy) files currently
114 |
115 | Arguments
116 | ---------
117 | root : string
118 | path to main directory in which images will be saved
119 | """
120 | if root.startswith('~'):
121 | root = os.path.expanduser(root)
122 | self.root = root
123 | self.counter = 0
124 |
125 | def __call__(self, *inputs):
126 |         for idx, _input in enumerate(inputs):
127 | fpath = os.path.join(self.root, 'img_%i_%i.npy'%(self.counter, idx))
128 | np.save(fpath, _input.numpy())
129 | self.counter += 1
130 | return inputs
131 |
132 |
133 | class ChannelsLast(object):
134 | """
135 | Transposes a tensor so that the channel dim is last
136 | `HWC` and `DHWC` are aliases for this transform.
137 | """
138 | def __init__(self, safe_check=False):
139 | """
140 | Transposes a tensor so that the channel dim is last
141 | `HWC` and `DHWC` are aliases for this transform.
142 |
143 | Arguments
144 | ---------
145 | safe_check : boolean
146 | if true, will check if channels are already last and, if so,
147 | will just return the inputs
148 | """
149 | self.safe_check = safe_check
150 |
151 | def __call__(self, *inputs):
152 | ndim = inputs[0].dim()
153 | if self.safe_check:
154 | # check if channels are already last
155 | if inputs[0].size(-1) < inputs[0].size(0):
156 |                 return inputs if len(inputs) > 1 else inputs[0]
157 | plist = list(range(1,ndim))+[0]
158 |
159 | outputs = []
160 | for idx, _input in enumerate(inputs):
161 | _input = _input.permute(*plist)
162 | outputs.append(_input)
163 |         return outputs if idx >= 1 else outputs[0]
164 |
165 | HWC = ChannelsLast
166 | DHWC = ChannelsLast
167 |
168 | class ChannelsFirst(object):
169 | """
170 | Transposes a tensor so that the channel dim is first.
171 | `CHW` and `CDHW` are aliases for this transform.
172 | """
173 | def __init__(self, safe_check=False):
174 | """
175 | Transposes a tensor so that the channel dim is first.
176 | `CHW` and `CDHW` are aliases for this transform.
177 |
178 | Arguments
179 | ---------
180 | safe_check : boolean
181 |             if true, will check if channels are already first and, if so,
182 | will just return the inputs
183 | """
184 | self.safe_check = safe_check
185 |
186 | def __call__(self, *inputs):
187 | ndim = inputs[0].dim()
188 | if self.safe_check:
189 | # check if channels are already first
190 | if inputs[0].size(0) < inputs[0].size(-1):
191 |                 return inputs if len(inputs) > 1 else inputs[0]
192 | plist = [ndim-1] + list(range(0,ndim-1))
193 |
194 | outputs = []
195 | for idx, _input in enumerate(inputs):
196 | _input = _input.permute(*plist)
197 | outputs.append(_input)
198 |         return outputs if idx >= 1 else outputs[0]
199 |
200 | CHW = ChannelsFirst
201 | CDHW = ChannelsFirst
202 |
203 | class TypeCast(object):
204 | """
205 | Cast a torch.Tensor to a different type
206 | """
207 | def __init__(self, dtype='float'):
208 | """
209 | Cast a torch.Tensor to a different type
210 |
211 | Arguments
212 | ---------
213 | dtype : string or torch.*Tensor literal or list of such
214 | data type to which input(s) will be cast.
215 | If list, it should be the same length as inputs.
216 | """
217 | if isinstance(dtype, (list,tuple)):
218 | dtypes = []
219 | for dt in dtype:
220 | if isinstance(dt, str):
221 | if dt == 'byte':
222 | dt = th.ByteTensor
223 | elif dt == 'double':
224 | dt = th.DoubleTensor
225 | elif dt == 'float':
226 | dt = th.FloatTensor
227 | elif dt == 'int':
228 | dt = th.IntTensor
229 | elif dt == 'long':
230 | dt = th.LongTensor
231 | elif dt == 'short':
232 | dt = th.ShortTensor
233 | dtypes.append(dt)
234 | self.dtype = dtypes
235 | else:
236 | if isinstance(dtype, str):
237 | if dtype == 'byte':
238 | dtype = th.ByteTensor
239 | elif dtype == 'double':
240 | dtype = th.DoubleTensor
241 | elif dtype == 'float':
242 | dtype = th.FloatTensor
243 | elif dtype == 'int':
244 | dtype = th.IntTensor
245 | elif dtype == 'long':
246 | dtype = th.LongTensor
247 | elif dtype == 'short':
248 | dtype = th.ShortTensor
249 | self.dtype = dtype
250 |
251 | def __call__(self, *inputs):
252 | if not isinstance(self.dtype, (tuple,list)):
253 | dtypes = [self.dtype]*len(inputs)
254 | else:
255 | dtypes = self.dtype
256 |
257 | outputs = []
258 | for idx, _input in enumerate(inputs):
259 | _input = _input.type(dtypes[idx])
260 | outputs.append(_input)
261 |         return outputs if idx >= 1 else outputs[0]
262 |
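A short usage sketch (assuming the multi-input return convention used throughout this file):

    >>> import torch as th
    >>> x = TypeCast('float')(th.LongTensor([1, 2, 3]))   # -> th.FloatTensor
    >>> x, y = TypeCast(['float', 'long'])(th.DoubleTensor(2, 2),
    ...                                    th.IntTensor(2, 2))   # one dtype per input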
263 |
264 | class AddChannel(object):
265 | """
266 | Adds a dummy channel to an image.
267 |     For example, this turns an image of size (28, 28)
268 |     into one of size (1, 28, 28).
269 | """
270 | def __init__(self, axis=0):
271 | """
272 | Adds a dummy channel to an image, also known as
273 | expanding an axis or unsqueezing a dim
274 |
275 | Arguments
276 | ---------
277 | axis : integer
278 | dimension to be expanded to singleton
279 | """
280 | self.axis = axis
281 |
282 | def __call__(self, *inputs):
283 | outputs = []
284 | for idx, _input in enumerate(inputs):
285 | _input = _input.unsqueeze(self.axis)
286 | outputs.append(_input)
287 |         return outputs if idx >= 1 else outputs[0]
288 |
289 | ExpandAxis = AddChannel
290 | Unsqueeze = AddChannel
291 |
292 | class Transpose(object):
293 |
294 | def __init__(self, dim1, dim2):
295 | """
296 | Swaps two dimensions of a tensor
297 |
298 | Arguments
299 | ---------
300 | dim1 : integer
301 | first dim to switch
302 | dim2 : integer
303 | second dim to switch
304 | """
305 | self.dim1 = dim1
306 | self.dim2 = dim2
307 |
308 | def __call__(self, *inputs):
309 | outputs = []
310 | for idx, _input in enumerate(inputs):
311 | _input = th.transpose(_input, self.dim1, self.dim2)
312 | outputs.append(_input)
313 |         return outputs if idx >= 1 else outputs[0]
314 |
315 |
316 | class RangeNormalize(object):
317 | """
318 |     Given a min_val and max_val, linearly rescales each input
319 |     th.*Tensor so that its observed min and max map onto
320 |     the provided min and max values.
321 |
322 | Works by calculating :
323 | a = (max'-min')/(max-min)
324 | b = max' - a * max
325 | new_value = a * value + b
326 | where min' & max' are given values,
327 |     and min & max are the observed min/max of each input
328 |
329 | Arguments
330 | ---------
331 |     min_val : float or integer
332 |         Min value to which tensors will be normalized
333 |     max_val : float or integer
334 |         Max value to which tensors will be normalized
335 |
336 |     Note: the observed min and max are re-computed for each
337 |     individual sample on every call, which decreases speed. If
338 |     every sample shares the same known min and max (for instance,
339 |     0 and 255 for an 8-bit image), it is cheaper to rescale with
340 |     those fixed constants beforehand rather than relying on the
341 |     per-sample statistics computed here.
342 |
343 |     The observed min/max of each input map exactly onto min_val/max_val.
344 |
345 |     Example:
346 |     >>> x = th.rand(3,5,5)
347 |     >>> rn = RangeNormalize(0, 10)
348 |     >>> x_norm = rn(x)
349 |
350 |     Also works with any other scalar range:
351 |     >>> x = th.rand(3,5,5)
352 |     >>> rn = RangeNormalize(-1, 1)
353 |     >>> x_norm = rn(x)
354 |     """
355 | def __init__(self,
356 | min_val,
357 | max_val):
358 | """
359 | Normalize a tensor between a min and max value
360 |
361 | Arguments
362 | ---------
363 | min_val : float
364 | lower bound of normalized tensor
365 | max_val : float
366 | upper bound of normalized tensor
367 | """
368 | self.min_val = min_val
369 | self.max_val = max_val
370 |
371 | def __call__(self, *inputs):
372 | outputs = []
373 | for idx, _input in enumerate(inputs):
374 | _min_val = _input.min()
375 | _max_val = _input.max()
376 | a = (self.max_val - self.min_val) / (_max_val - _min_val)
377 |             b = self.max_val - a * _max_val
378 | _input = _input.mul(a).add(b)
379 | outputs.append(_input)
380 |         return outputs if idx >= 1 else outputs[0]
381 |
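To make the formula concrete (a worked sketch): for an input with observed min 0.0 and max 0.5 normalized to [0, 1], a = (1 - 0) / (0.5 - 0) = 2 and b = 1 - 2 * 0.5 = 0, so every value is simply doubled:

    >>> import torch as th
    >>> rn = RangeNormalize(0, 1)
    >>> rn(th.FloatTensor([0.0, 0.25, 0.5]))   # -> [0.0, 0.5, 1.0]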
382 |
383 | class StdNormalize(object):
384 | """
385 | Normalize torch tensor to have zero mean and unit std deviation
386 | """
387 | def __call__(self, *inputs):
388 | outputs = []
389 | for idx, _input in enumerate(inputs):
390 | _input = _input.sub(_input.mean()).div(_input.std())
391 | outputs.append(_input)
392 |         return outputs if idx >= 1 else outputs[0]
393 |
394 |
395 | class Slice2D(object):
396 |
397 | def __init__(self, axis=0, reject_zeros=False):
398 | """
399 | Take a random 2D slice from a 3D image along
400 |         a given axis. The input should be 3D, without a channel dim.
401 |
402 | Arguments
403 | ---------
404 | axis : integer in {0, 1, 2}
405 | the axis on which to take slices
406 |
407 | reject_zeros : boolean
408 | whether to reject slices that are all zeros
409 | """
410 | self.axis = axis
411 | self.reject_zeros = reject_zeros
412 |
413 | def __call__(self, x, y=None):
414 | while True:
415 | keep_slice = random.randint(0,x.size(self.axis)-1)
416 | if self.axis == 0:
417 | slice_x = x[keep_slice,:,:]
418 | if y is not None:
419 | slice_y = y[keep_slice,:,:]
420 | elif self.axis == 1:
421 | slice_x = x[:,keep_slice,:]
422 | if y is not None:
423 | slice_y = y[:,keep_slice,:]
424 | elif self.axis == 2:
425 | slice_x = x[:,:,keep_slice]
426 | if y is not None:
427 | slice_y = y[:,:,keep_slice]
428 |
429 | if not self.reject_zeros:
430 | break
431 | else:
432 | if y is not None and th.sum(slice_y) > 0:
433 | break
434 | elif th.sum(slice_x) > 0:
435 | break
436 | if y is not None:
437 | return slice_x, slice_y
438 | else:
439 | return slice_x
440 |
441 |
442 | class RandomCrop(object):
443 |
444 | def __init__(self, size):
445 | """
446 | Randomly crop a torch tensor
447 |
448 | Arguments
449 | --------
450 | size : tuple or list
451 | dimensions of the crop
452 | """
453 | self.size = size
454 |
455 | def __call__(self, *inputs):
456 | h_idx = random.randint(0,inputs[0].size(1)-self.size[0])
457 |         w_idx = random.randint(0,inputs[0].size(2)-self.size[1])
458 | outputs = []
459 | for idx, _input in enumerate(inputs):
460 | _input = _input[:, h_idx:(h_idx+self.size[0]),w_idx:(w_idx+self.size[1])]
461 | outputs.append(_input)
462 |         return outputs if idx >= 1 else outputs[0]
463 |
464 |
465 | class SpecialCrop(object):
466 |
467 | def __init__(self, size, crop_type=0):
468 | """
469 | Perform a special crop - one of the four corners or center crop
470 |
471 | Arguments
472 | ---------
473 | size : tuple or list
474 | dimensions of the crop
475 |
476 | crop_type : integer in {0,1,2,3,4}
477 | 0 = center crop
478 | 1 = top left crop
479 | 2 = top right crop
480 | 3 = bottom right crop
481 | 4 = bottom left crop
482 | """
483 | if crop_type not in {0, 1, 2, 3, 4}:
484 | raise ValueError('crop_type must be in {0, 1, 2, 3, 4}')
485 | self.size = size
486 | self.crop_type = crop_type
487 |
488 | def __call__(self, x, y=None):
489 | if self.crop_type == 0:
490 | # center crop
491 | x_diff = (x.size(1)-self.size[0])/2.
492 | y_diff = (x.size(2)-self.size[1])/2.
493 | ct_x = [int(math.ceil(x_diff)),x.size(1)-int(math.floor(x_diff))]
494 | ct_y = [int(math.ceil(y_diff)),x.size(2)-int(math.floor(y_diff))]
495 | indices = [ct_x,ct_y]
496 | elif self.crop_type == 1:
497 | # top left crop
498 | tl_x = [0, self.size[0]]
499 | tl_y = [0, self.size[1]]
500 | indices = [tl_x,tl_y]
501 | elif self.crop_type == 2:
502 | # top right crop
503 | tr_x = [0, self.size[0]]
504 | tr_y = [x.size(2)-self.size[1], x.size(2)]
505 | indices = [tr_x,tr_y]
506 | elif self.crop_type == 3:
507 | # bottom right crop
508 | br_x = [x.size(1)-self.size[0],x.size(1)]
509 | br_y = [x.size(2)-self.size[1],x.size(2)]
510 | indices = [br_x,br_y]
511 | elif self.crop_type == 4:
512 | # bottom left crop
513 | bl_x = [x.size(1)-self.size[0], x.size(1)]
514 | bl_y = [0, self.size[1]]
515 | indices = [bl_x,bl_y]
516 |
517 | x = x[:,indices[0][0]:indices[0][1],indices[1][0]:indices[1][1]]
518 |
519 | if y is not None:
520 | y = y[:,indices[0][0]:indices[0][1],indices[1][0]:indices[1][1]]
521 | return x, y
522 | else:
523 | return x
524 |
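A worked example of the index arithmetic above (a sketch): for a (1, 28, 28) input and size (24, 24), the center crop computes x_diff = y_diff = (28 - 24) / 2 = 2.0, keeping rows and columns 2 through 25 inclusive:

    >>> import torch as th
    >>> x = th.rand(1, 28, 28)
    >>> center = SpecialCrop((24, 24), crop_type=0)(x)   # x[:, 2:26, 2:26]
    >>> corner = SpecialCrop((24, 24), crop_type=1)(x)   # x[:, 0:24, 0:24]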
525 |
526 | class Pad(object):
527 |
528 | def __init__(self, size):
529 | """
530 | Pads an image to the given size
531 |
532 | Arguments
533 | ---------
534 | size : tuple or list
535 | size of crop
536 | """
537 | self.size = size
538 |
539 | def __call__(self, x, y=None):
540 | x = x.numpy()
541 | shape_diffs = [int(np.ceil((i_s - d_s))) for d_s,i_s in zip(x.shape,self.size)]
542 | shape_diffs = np.maximum(shape_diffs,0)
543 | pad_sizes = [(int(np.ceil(s/2.)),int(np.floor(s/2.))) for s in shape_diffs]
544 | x = np.pad(x, pad_sizes, mode='constant')
545 | if y is not None:
546 | y = y.numpy()
547 | y = np.pad(y, pad_sizes, mode='constant')
548 | return th.from_numpy(x), th.from_numpy(y)
549 | else:
550 | return th.from_numpy(x)
551 |
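A worked example of the padding arithmetic (a sketch; note that `size` is zipped against the full shape, so it must include the channel dim): padding a (1, 28, 28) tensor to (1, 32, 32) gives shape_diffs = [0, 4, 4] and pad_sizes = [(0, 0), (2, 2), (2, 2)], i.e. two zero pixels on every side:

    >>> import torch as th
    >>> x = th.rand(1, 28, 28)
    >>> x_pad = Pad((1, 32, 32))(x)   # -> size (1, 32, 32)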
552 |
553 | class RandomFlip(object):
554 |
555 | def __init__(self, h=True, v=False, p=0.5):
556 | """
557 | Randomly flip an image horizontally and/or vertically with
558 | some probability.
559 |
560 | Arguments
561 | ---------
562 | h : boolean
563 | whether to horizontally flip w/ probability p
564 |
565 | v : boolean
566 | whether to vertically flip w/ probability p
567 |
568 | p : float between [0,1]
569 | probability with which to apply allowed flipping operations
570 | """
571 | self.horizontal = h
572 | self.vertical = v
573 | self.p = p
574 |
575 | def __call__(self, x, y=None):
576 | x = x.numpy()
577 | if y is not None:
578 | y = y.numpy()
579 | # horizontal flip with p = self.p
580 | if self.horizontal:
581 | if random.random() < self.p:
582 | x = x.swapaxes(2, 0)
583 | x = x[::-1, ...]
584 | x = x.swapaxes(0, 2)
585 | if y is not None:
586 | y = y.swapaxes(2, 0)
587 | y = y[::-1, ...]
588 | y = y.swapaxes(0, 2)
589 | # vertical flip with p = self.p
590 | if self.vertical:
591 | if random.random() < self.p:
592 | x = x.swapaxes(1, 0)
593 | x = x[::-1, ...]
594 | x = x.swapaxes(0, 1)
595 | if y is not None:
596 | y = y.swapaxes(1, 0)
597 | y = y[::-1, ...]
598 | y = y.swapaxes(0, 1)
599 | if y is None:
600 |             # must copy because torch doesn't currently support negative strides
601 | return th.from_numpy(x.copy())
602 | else:
603 | return th.from_numpy(x.copy()),th.from_numpy(y.copy())
604 |
605 |
606 | class RandomOrder(object):
607 | """
608 | Randomly permute the channels of an image
609 | """
610 | def __call__(self, *inputs):
611 |         order = th.randperm(inputs[0].size(0))
612 | outputs = []
613 | for idx, _input in enumerate(inputs):
614 | _input = _input.index_select(0, order)
615 | outputs.append(_input)
616 |         return outputs if idx >= 1 else outputs[0]
617 |
618 |
--------------------------------------------------------------------------------
/torchsample/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for th.Tensors
3 | """
4 |
5 | import pickle
6 | import random
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def th_allclose(x, y):
13 | """
14 | Determine whether two torch tensors have same values
15 | Mimics np.allclose
16 | """
17 | return th.sum(th.abs(x-y)) < 1e-5
18 |
19 |
20 | def th_flatten(x):
21 | """Flatten tensor"""
22 | return x.contiguous().view(-1)
23 |
24 | def th_c_flatten(x):
25 | """
26 | Flatten tensor, leaving channel intact.
27 | Assumes CHW format.
28 | """
29 | return x.contiguous().view(x.size(0), -1)
30 |
31 | def th_bc_flatten(x):
32 | """
33 | Flatten tensor, leaving batch and channel dims intact.
34 | Assumes BCHW format
35 | """
36 | return x.contiguous().view(x.size(0), x.size(1), -1)
37 |
38 |
39 | def th_zeros_like(x):
40 | return x.new().resize_as_(x).zero_()
41 |
42 | def th_ones_like(x):
43 | return x.new().resize_as_(x).fill_(1)
44 |
45 | def th_constant_like(x, val):
46 | return x.new().resize_as_(x).fill_(val)
47 |
48 |
49 | def th_iterproduct(*args):
50 | return th.from_numpy(np.indices(args).reshape((len(args),-1)).T)
51 |
52 | def th_iterproduct_like(x):
53 | return th_iterproduct(*x.size())
54 |
55 |
56 | def th_uniform(lower, upper):
57 | return random.uniform(lower, upper)
58 |
59 |
60 | def th_gather_nd(x, coords):
61 | x = x.contiguous()
62 | inds = coords.mv(th.LongTensor(x.stride()))
63 | x_gather = th.index_select(th_flatten(x), 0, inds)
64 | return x_gather
65 |
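th_gather_nd converts N-dimensional coordinates into flat indices by dotting them with the tensor's strides, then gathers from the flattened tensor. A small sketch:

    >>> import torch as th
    >>> x = th.arange(0, 6).view(2, 3)
    >>> coords = th.LongTensor([[0, 0], [1, 2]])
    >>> th_gather_nd(x, coords)   # -> [x[0, 0], x[1, 2]] == [0, 5]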
66 |
67 | def th_affine2d(x, matrix, mode='bilinear', center=True):
68 | """
69 | 2D Affine image transform on th.Tensor
70 |
71 | Arguments
72 | ---------
73 | x : th.Tensor of size (C, H, W)
74 | image tensor to be transformed
75 |
76 | matrix : th.Tensor of size (3, 3) or (2, 3)
77 | transformation matrix
78 |
79 | mode : string in {'nearest', 'bilinear'}
80 | interpolation scheme to use
81 |
82 | center : boolean
83 | whether to alter the bias of the transform
84 | so the transform is applied about the center
85 | of the image rather than the origin
86 |
87 | Example
88 | -------
89 | >>> import torch
90 | >>> from torchsample.utils import *
91 | >>> x = th.zeros(2,1000,1000)
92 |     >>> x[:,100:500,100:500] = 10
93 | >>> matrix = th.FloatTensor([[1.,0,-50],
94 | ... [0,1.,-50]])
95 | >>> xn = th_affine2d(x, matrix, mode='nearest')
96 | >>> xb = th_affine2d(x, matrix, mode='bilinear')
97 | """
98 |
99 | if matrix.dim() == 2:
100 | matrix = matrix[:2,:]
101 | matrix = matrix.unsqueeze(0)
102 | elif matrix.dim() == 3:
103 | if matrix.size()[1:] == (3,3):
104 | matrix = matrix[:,:2,:]
105 |
106 | A_batch = matrix[:,:,:2]
107 | if A_batch.size(0) != x.size(0):
108 | A_batch = A_batch.repeat(x.size(0),1,1)
109 | b_batch = matrix[:,:,2].unsqueeze(1)
110 |
111 | # make a meshgrid of normal coordinates
112 | _coords = th_iterproduct(x.size(1),x.size(2))
113 | coords = _coords.unsqueeze(0).repeat(x.size(0),1,1).float()
114 |
115 | if center:
116 | # shift the coordinates so center is the origin
117 | coords[:,:,0] = coords[:,:,0] - (x.size(1) / 2. - 0.5)
118 | coords[:,:,1] = coords[:,:,1] - (x.size(2) / 2. - 0.5)
119 | # apply the coordinate transformation
120 | new_coords = coords.bmm(A_batch.transpose(1,2)) + b_batch.expand_as(coords)
121 |
122 | if center:
123 | # shift the coordinates back so origin is origin
124 | new_coords[:,:,0] = new_coords[:,:,0] + (x.size(1) / 2. - 0.5)
125 | new_coords[:,:,1] = new_coords[:,:,1] + (x.size(2) / 2. - 0.5)
126 |
127 |     # map new coordinates using the requested interpolation
128 | if mode == 'nearest':
129 | x_transformed = th_nearest_interp2d(x.contiguous(), new_coords)
130 | elif mode == 'bilinear':
131 | x_transformed = th_bilinear_interp2d(x.contiguous(), new_coords)
132 |
133 | return x_transformed
134 |
135 |
136 | def th_nearest_interp2d(input, coords):
137 | """
138 |     2d nearest neighbor interpolation on a th.Tensor
139 | """
140 | # take clamp of coords so they're in the image bounds
141 | x = th.clamp(coords[:,:,0], 0, input.size(1)-1).round()
142 | y = th.clamp(coords[:,:,1], 0, input.size(2)-1).round()
143 |
144 | stride = th.LongTensor(input.stride())
145 | x_ix = x.mul(stride[1]).long()
146 | y_ix = y.mul(stride[2]).long()
147 |
148 | input_flat = input.view(input.size(0),-1)
149 |
150 | mapped_vals = input_flat.gather(1, x_ix.add(y_ix))
151 |
152 | return mapped_vals.view_as(input)
153 |
154 |
155 | def th_bilinear_interp2d(input, coords):
156 | """
157 | bilinear interpolation in 2d
158 | """
159 | x = th.clamp(coords[:,:,0], 0, input.size(1)-2)
160 | x0 = x.floor()
161 | x1 = x0 + 1
162 | y = th.clamp(coords[:,:,1], 0, input.size(2)-2)
163 | y0 = y.floor()
164 | y1 = y0 + 1
165 |
166 | stride = th.LongTensor(input.stride())
167 | x0_ix = x0.mul(stride[1]).long()
168 | x1_ix = x1.mul(stride[1]).long()
169 | y0_ix = y0.mul(stride[2]).long()
170 | y1_ix = y1.mul(stride[2]).long()
171 |
172 | input_flat = input.view(input.size(0),-1)
173 |
174 | vals_00 = input_flat.gather(1, x0_ix.add(y0_ix))
175 | vals_10 = input_flat.gather(1, x1_ix.add(y0_ix))
176 | vals_01 = input_flat.gather(1, x0_ix.add(y1_ix))
177 | vals_11 = input_flat.gather(1, x1_ix.add(y1_ix))
178 |
179 | xd = x - x0
180 | yd = y - y0
181 | xm = 1 - xd
182 | ym = 1 - yd
183 |
184 | x_mapped = (vals_00.mul(xm).mul(ym) +
185 | vals_10.mul(xd).mul(ym) +
186 | vals_01.mul(xm).mul(yd) +
187 | vals_11.mul(xd).mul(yd))
188 |
189 | return x_mapped.view_as(input)
190 |
191 |
192 | def th_affine3d(x, matrix, mode='trilinear', center=True):
193 | """
194 | 3D Affine image transform on th.Tensor
195 | """
196 | A = matrix[:3,:3]
197 | b = matrix[:3,3]
198 |
199 | # make a meshgrid of normal coordinates
200 | coords = th_iterproduct(x.size(1),x.size(2),x.size(3)).float()
201 |
202 |
203 | if center:
204 | # shift the coordinates so center is the origin
205 | coords[:,0] = coords[:,0] - (x.size(1) / 2. - 0.5)
206 | coords[:,1] = coords[:,1] - (x.size(2) / 2. - 0.5)
207 | coords[:,2] = coords[:,2] - (x.size(3) / 2. - 0.5)
208 |
209 |
210 | # apply the coordinate transformation
211 | new_coords = coords.mm(A.t().contiguous()) + b.expand_as(coords)
212 |
213 | if center:
214 | # shift the coordinates back so origin is origin
215 | new_coords[:,0] = new_coords[:,0] + (x.size(1) / 2. - 0.5)
216 | new_coords[:,1] = new_coords[:,1] + (x.size(2) / 2. - 0.5)
217 | new_coords[:,2] = new_coords[:,2] + (x.size(3) / 2. - 0.5)
218 |
219 |     # map new coordinates using the requested interpolation
220 | if mode == 'nearest':
221 | x_transformed = th_nearest_interp3d(x, new_coords)
222 | elif mode == 'trilinear':
223 | x_transformed = th_trilinear_interp3d(x, new_coords)
224 | else:
225 | x_transformed = th_trilinear_interp3d(x, new_coords)
226 |
227 | return x_transformed
228 |
229 |
230 | def th_nearest_interp3d(input, coords):
231 | """
232 |     3d nearest neighbor interpolation on a th.Tensor
233 | """
234 | # take clamp of coords so they're in the image bounds
235 | coords[:,0] = th.clamp(coords[:,0], 0, input.size(1)-1).round()
236 | coords[:,1] = th.clamp(coords[:,1], 0, input.size(2)-1).round()
237 | coords[:,2] = th.clamp(coords[:,2], 0, input.size(3)-1).round()
238 |
239 | stride = th.LongTensor(input.stride())[1:].float()
240 | idx = coords.mv(stride).long()
241 |
242 | input_flat = th_flatten(input)
243 |
244 | mapped_vals = input_flat[idx]
245 |
246 | return mapped_vals.view_as(input)
247 |
248 |
249 | def th_trilinear_interp3d(input, coords):
250 | """
251 | trilinear interpolation of 3D th.Tensor image
252 | """
253 | # take clamp then floor/ceil of x coords
254 | x = th.clamp(coords[:,0], 0, input.size(1)-2)
255 | x0 = x.floor()
256 | x1 = x0 + 1
257 | # take clamp then floor/ceil of y coords
258 | y = th.clamp(coords[:,1], 0, input.size(2)-2)
259 | y0 = y.floor()
260 | y1 = y0 + 1
261 | # take clamp then floor/ceil of z coords
262 | z = th.clamp(coords[:,2], 0, input.size(3)-2)
263 | z0 = z.floor()
264 | z1 = z0 + 1
265 |
266 | stride = th.LongTensor(input.stride())[1:]
267 | x0_ix = x0.mul(stride[0]).long()
268 | x1_ix = x1.mul(stride[0]).long()
269 | y0_ix = y0.mul(stride[1]).long()
270 | y1_ix = y1.mul(stride[1]).long()
271 | z0_ix = z0.mul(stride[2]).long()
272 | z1_ix = z1.mul(stride[2]).long()
273 |
274 | input_flat = th_flatten(input)
275 |
276 | vals_000 = input_flat[x0_ix+y0_ix+z0_ix]
277 | vals_100 = input_flat[x1_ix+y0_ix+z0_ix]
278 | vals_010 = input_flat[x0_ix+y1_ix+z0_ix]
279 | vals_001 = input_flat[x0_ix+y0_ix+z1_ix]
280 | vals_101 = input_flat[x1_ix+y0_ix+z1_ix]
281 | vals_011 = input_flat[x0_ix+y1_ix+z1_ix]
282 | vals_110 = input_flat[x1_ix+y1_ix+z0_ix]
283 | vals_111 = input_flat[x1_ix+y1_ix+z1_ix]
284 |
285 | xd = x - x0
286 | yd = y - y0
287 | zd = z - z0
288 | xm1 = 1 - xd
289 | ym1 = 1 - yd
290 | zm1 = 1 - zd
291 |
292 | x_mapped = (vals_000.mul(xm1).mul(ym1).mul(zm1) +
293 | vals_100.mul(xd).mul(ym1).mul(zm1) +
294 | vals_010.mul(xm1).mul(yd).mul(zm1) +
295 | vals_001.mul(xm1).mul(ym1).mul(zd) +
296 | vals_101.mul(xd).mul(ym1).mul(zd) +
297 | vals_011.mul(xm1).mul(yd).mul(zd) +
298 | vals_110.mul(xd).mul(yd).mul(zm1) +
299 | vals_111.mul(xd).mul(yd).mul(zd))
300 |
301 | return x_mapped.view_as(input)
302 |
303 |
304 | def th_pearsonr(x, y):
305 | """
306 | mimics scipy.stats.pearsonr
307 | """
308 | mean_x = th.mean(x)
309 | mean_y = th.mean(y)
310 | xm = x.sub(mean_x)
311 | ym = y.sub(mean_y)
312 | r_num = xm.dot(ym)
313 | r_den = th.norm(xm, 2) * th.norm(ym, 2)
314 | r_val = r_num / r_den
315 | return r_val
316 |
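A quick sanity check (a sketch): a perfect linear relationship yields a coefficient of 1 up to floating-point error:

    >>> import torch as th
    >>> x = th.FloatTensor([1, 2, 3, 4])
    >>> th_pearsonr(x, x * 2 + 1)   # -> 1.0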
317 |
318 | def th_corrcoef(x):
319 | """
320 | mimics np.corrcoef
321 | """
322 | # calculate covariance matrix of rows
323 | mean_x = th.mean(x, 1)
324 | xm = x.sub(mean_x.expand_as(x))
325 | c = xm.mm(xm.t())
326 | c = c / (x.size(1) - 1)
327 |
328 | # normalize covariance matrix
329 | d = th.diag(c)
330 | stddev = th.pow(d, 0.5)
331 | c = c.div(stddev.expand_as(c))
332 | c = c.div(stddev.expand_as(c).t())
333 |
334 | # clamp between -1 and 1
335 | c = th.clamp(c, -1.0, 1.0)
336 |
337 | return c
338 |
339 |
340 | def th_matrixcorr(x, y):
341 | """
342 | return a correlation matrix between
343 | columns of x and columns of y.
344 |
345 | So, if X.size() == (1000,4) and Y.size() == (1000,5),
346 | then the result will be of size (4,5) with the
347 | (i,j) value equal to the pearsonr correlation coeff
348 | between column i in X and column j in Y
349 | """
350 | mean_x = th.mean(x, 0)
351 | mean_y = th.mean(y, 0)
352 | xm = x.sub(mean_x.expand_as(x))
353 | ym = y.sub(mean_y.expand_as(y))
354 | r_num = xm.t().mm(ym)
355 | r_den1 = th.norm(xm,2,0)
356 | r_den2 = th.norm(ym,2,0)
357 | r_den = r_den1.t().mm(r_den2)
358 | r_mat = r_num.div(r_den)
359 | return r_mat
360 |
361 |
362 | def th_random_choice(a, n_samples=1, replace=True, p=None):
363 | """
364 | Parameters
365 | -----------
366 | a : 1-D array-like
367 | If a th.Tensor, a random sample is generated from its elements.
368 |         If an int, the random sample is generated as if a were th.arange(0, a)
369 |     n_samples : int, optional
370 |         Number of samples to draw. Default is 1, in which case a
371 |         single value is returned.
372 | replace : boolean, optional
373 | Whether the sample is with or without replacement
374 | p : 1-D array-like, optional
375 | The probabilities associated with each entry in a.
376 | If not given the sample assumes a uniform distribution over all
377 | entries in a.
378 |
379 | Returns
380 | --------
381 |     samples : th.Tensor of size (n_samples,), or a scalar when n_samples == 1
382 | The generated random samples
383 | """
384 | if isinstance(a, int):
385 | a = th.arange(0, a)
386 |
387 | if p is None:
388 | if replace:
389 | idx = th.floor(th.rand(n_samples)*a.size(0)).long()
390 | else:
391 | idx = th.randperm(len(a))[:n_samples]
392 | else:
393 | if abs(1.0-sum(p)) > 1e-3:
394 | raise ValueError('p must sum to 1.0')
395 | if not replace:
396 |             raise ValueError('replace must equal True if probabilities are given')
397 | idx_vec = th.cat([th.zeros(round(p[i]*1000))+i for i in range(len(p))])
398 | idx = (th.floor(th.rand(n_samples)*999)).long()
399 | idx = idx_vec[idx].long()
400 | selection = a[idx]
401 | if n_samples == 1:
402 | selection = selection[0]
403 | return selection
404 |
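A usage sketch with hypothetical values:

    >>> import torch as th
    >>> v = th_random_choice(10)                      # one draw from th.arange(0, 10)
    >>> vs = th_random_choice(th.FloatTensor([0.1, 0.5, 1.0]),
    ...                       n_samples=3, p=[0.2, 0.3, 0.5])   # weighted, with replacement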
405 |
406 | def save_transform(file, transform):
407 | """
408 | Save a transform object
409 | """
410 | with open(file, 'wb') as output_file:
411 | pickler = pickle.Pickler(output_file, -1)
412 | pickler.dump(transform)
413 |
414 |
415 | def load_transform(file):
416 | """
417 | Load a transform object
418 | """
419 | with open(file, 'rb') as input_file:
420 | transform = pickle.load(input_file)
421 | return transform
422 |
423 |
424 |
425 |
426 |
--------------------------------------------------------------------------------
/torchsample/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.1.3'
2 |
--------------------------------------------------------------------------------