├── .gitignore
├── README.md
├── setup.py
├── test
│   └── test_extras.py
└── torch_extras
    ├── __init__.py
    └── extras.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.egg-info

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-extras

`pip install pytorch-extras`

## Usage

### [expand_along](#expand_along)

`expand_along(var, mask)` - Useful for selecting a dynamic number of items from different indexes using a byte mask. This is a bit like [numpy.repeat](https://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html).

    import torch
    import torch_extras
    setattr(torch, 'expand_along', torch_extras.expand_along)

    var = torch.Tensor([1, 0, 2])
    mask = torch.ByteTensor([[True, True], [False, True], [False, False]])
    torch.expand_along(var, mask)
    # (1, 1, 0)


### [expand_dims](#expand_dims)

`expand_dims(var, dim)` - Is similar to [numpy.expand_dims](https://docs.scipy.org/doc/numpy/reference/generated/numpy.expand_dims.html).

    import torch
    import torch_extras
    setattr(torch, 'expand_dims', torch_extras.expand_dims)

    var = torch.range(0, 9).view(-1, 2)
    torch.expand_dims(var, 0).size()
    # (1, 5, 2)

Note: PyTorch's built-in [torch.unsqueeze](http://pytorch.org/docs/tensors.html?highlight=unsqueeze#torch.Tensor.unsqueeze) has the same API and is probably a more effective method for expanding dimensions.


### [select_item](#select_item)

`select_item(var, index)` - Is similar to `[var[row, col] for row, col in enumerate(index)]`.

    import torch
    import torch_extras
    setattr(torch, 'select_item', torch_extras.select_item)

    var = torch.range(0, 9).view(-1, 2)
    index = torch.LongTensor([0, 0, 0, 1, 1])
    torch.select_item(var, index)
    # [0, 2, 4, 7, 9]
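
For comparison, the built-in `torch.gather` can express the same row-wise selection, though it returns an `(N, 1)` column rather than a flat vector. A quick sketch (not part of this package):

    import torch

    var = torch.range(0, 9).view(-1, 2)
    index = torch.LongTensor([0, 0, 0, 1, 1])
    # picks var[row, index[row]] for each row; result has size (5, 1)
    torch.gather(var, 1, index.view(-1, 1))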


### [cast](#cast)

`cast(var, type)` - Cast a Tensor to the given type.

    import torch
    import torch_extras
    setattr(torch, 'cast', torch_extras.cast)

    input = torch.FloatTensor(1)
    target_type = type(torch.LongTensor(1))
    type(torch.cast(input, target_type))
    # <class 'torch.LongTensor'>
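
The built-in `Tensor.type` method performs a similar conversion for plain tensors when handed a tensor type. A minimal sketch (not part of this package):

    import torch

    input = torch.FloatTensor(1)
    # returns a copy of `input` converted to the given type
    input.type(torch.LongTensor)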


### [one_hot](#one_hot)

`one_hot(size, index)` - Creates a matrix of one hot vectors.

    import torch
    import torch_extras
    setattr(torch, 'one_hot', torch_extras.one_hot)

    size = (3, 3)
    index = torch.LongTensor([2, 0, 1]).view(-1, 1)
    torch.one_hot(size, index)
    # [[0, 0, 1], [1, 0, 0], [0, 1, 0]]


### [nll](#nll)

`nll(log_prob, label)` - Is similar to [`nll_loss`](http://pytorch.org/docs/nn.html?highlight=nll#torch.nn.functional.nll_loss), except it does not return an aggregate.

    import torch
    from torch.autograd import Variable
    import torch.nn.functional as F
    import torch_extras
    setattr(torch, 'nll', torch_extras.nll)

    input = Variable(torch.FloatTensor([[0.5, 0.2, 0.3], [0.1, 0.8, 0.1]]))
    target = Variable(torch.LongTensor([1, 2]).view(-1, 1))
    output = torch.nll(torch.log(input), target)
    output.size()
    # (2, 1)
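
For reference, the built-in `F.nll_loss` collapses the same quantity into a single value, averaged over the batch by default, while `torch.nll` keeps one entry per example. A sketch continuing the snippet above (assuming the `size_average=True` default of this era of PyTorch):

    # one aggregate value for the whole batch
    aggregate = F.nll_loss(torch.log(input), target.view(-1))

    # one value per example, with size (2, 1); its mean matches the aggregate
    per_example = torch.nll(torch.log(input), target)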

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(name='pytorch-extras',
      version='0.1.3',
      description='A handful of extensions to the pytorch library.',
      url='http://github.com/mrdrozdov/pytorch-extras',
      author='Andrew Drozdov',
      author_email='andrew@mrdrozdov.com',
      license='MIT',
      packages=['torch_extras'],
      )
--------------------------------------------------------------------------------
/test/test_extras.py:
--------------------------------------------------------------------------------
import unittest

import torch
import torch.nn.functional as F
from torch.autograd import Variable

import torch_extras

setattr(torch, 'expand_along', torch_extras.expand_along)
setattr(torch, 'expand_dims', torch_extras.expand_dims)
setattr(torch, 'select_item', torch_extras.select_item)
setattr(torch, 'cast', torch_extras.cast)
setattr(torch, 'one_hot', torch_extras.one_hot)
setattr(torch, 'nll', torch_extras.nll)


class PytorchExtrasTestCase(unittest.TestCase):

    def test_expand_along_variable(self):
        var = Variable(torch.Tensor([1, 0, 2]))
        mask = torch.ByteTensor([[True, True], [False, True], [False, False]])
        ret = torch.expand_along(var, mask)
        expected = torch.Tensor([1, 1, 0])
        assert len(ret) == len(expected)
        assert all(torch.eq(ret.data, expected))

    def test_expand_along_tensor(self):
        var = torch.Tensor([1, 0, 2])
        mask = torch.ByteTensor([[True, True], [False, True], [False, False]])
        ret = torch.expand_along(var, mask)
        expected = torch.Tensor([1, 1, 0])
        assert len(ret) == len(expected)
        assert all(torch.eq(ret, expected))

    def test_expand_dims_variable(self):
        var = Variable(torch.range(0, 9).view(-1, 2))
        ret = torch.expand_dims(var, 0)
        expected = var.view(1, -1, 2)
        assert ret.size() == expected.size()

    def test_expand_dims_tensor(self):
        var = torch.range(0, 9).view(-1, 2)
        ret = torch.expand_dims(var, 0)
        expected = var.view(1, -1, 2)
        assert ret.size() == expected.size()

    def test_select_item_variable(self):
        var = Variable(torch.range(0, 9).view(-1, 2))
        index = torch.LongTensor([0, 0, 0, 1, 1])
        ret = torch.select_item(var, index)
        expected = torch.Tensor([0, 2, 4, 7, 9])
        assert all(torch.eq(ret.data, expected))

    def test_select_item_tensor(self):
        var = torch.range(0, 9).view(-1, 2)
        index = torch.LongTensor([0, 0, 0, 1, 1])
        ret = torch.select_item(var, index)
        expected = torch.Tensor([0, 2, 4, 7, 9])
        assert all(torch.eq(ret, expected))

    def test_cast(self):
        inputs = [
            torch.ByteTensor(1),
            torch.CharTensor(1),
            torch.DoubleTensor(1),
            torch.FloatTensor(1),
            torch.IntTensor(1),
            torch.LongTensor(1),
            torch.ShortTensor(1),
        ]

        for inp in inputs:
            assert type(inp) == type(torch.cast(inp, type(inp)))

    def test_cast_variable(self):
        inputs = [
            torch.ByteTensor(1),
            torch.CharTensor(1),
            torch.DoubleTensor(1),
            torch.FloatTensor(1),
            torch.IntTensor(1),
            torch.LongTensor(1),
            torch.ShortTensor(1),
        ]

        for inp in inputs:
            assert type(inp) == type(torch.cast(Variable(inp), type(inp)).data)

    def test_one_hot(self):
        size = (3, 3)
        index = torch.LongTensor([2, 0, 1]).view(-1, 1)
        ret = torch.one_hot(size, index)
        expected = torch.LongTensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
        assert ret.size() == expected.size()
        assert all(torch.eq(ret.view(-1), expected.view(-1)))

    def test_one_hot_variable(self):
        size = (3, 3)
        index = torch.LongTensor([2, 0, 1]).view(-1, 1)
        ret = torch.one_hot(size, Variable(index))
        expected = torch.LongTensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
        assert ret.size() == expected.size()
        assert all(torch.eq(ret.view(-1).data, expected.view(-1)))

    def test_nll(self):
        input = torch.FloatTensor([[0.5, 0.2, 0.3], [0.1, 0.8, 0.1]])
        target = torch.LongTensor([1, 2]).view(-1, 1)
        output = torch.nll(torch.log(input), target)
        assert output.size() == (target.size(0), 1)
        assert all(o == -1 * row[t] for row, o, t in zip(
            torch.log(input), output.view(-1), target.view(-1)))

    def test_nll_variable(self):
        input = Variable(torch.FloatTensor([[0.5, 0.2, 0.3], [0.1, 0.8, 0.1]]))
        target = Variable(torch.LongTensor([1, 2]).view(-1, 1))
        output = torch.nll(torch.log(input), target)
        assert output.size() == (target.size(0), 1)
        assert all(o == -1 * row[t] for row, o, t in zip(
            torch.log(input).data, output.view(-1).data, target.view(-1).data))


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/torch_extras/__init__.py:
--------------------------------------------------------------------------------
from .extras import expand_along, expand_dims, select_item, cast, one_hot, nll

__all__ = ['expand_along', 'expand_dims', 'select_item',
           'cast', 'one_hot', 'nll',
           ]

--------------------------------------------------------------------------------
/torch_extras/extras.py:
--------------------------------------------------------------------------------
import six
import torch
from torch.autograd import Variable


def expand_along(var, mask):
    """ Useful for selecting a dynamic number of items from different
    indexes using a byte mask.

    ```
    import torch
    import torch_extras
    setattr(torch, 'expand_along', torch_extras.expand_along)

    var = torch.Tensor([1, 0, 2])
    mask = torch.ByteTensor([[True, True], [False, True], [False, False]])
    torch.expand_along(var, mask)
    # (1, 1, 0)
    ```
    """
    # Repeat each row index across the mask's columns, then use the mask to
    # keep one copy of the index for every True entry in that row.
    indexes = torch.range(0, var.size(0) - 1).view(-1, 1).repeat(1, mask.size(1))
    _mask = indexes[mask].long()
    if isinstance(var, Variable):
        _mask = Variable(_mask, volatile=var.volatile)
    return torch.index_select(var, 0, _mask)


def expand_dims(var, dim=0):
    """ Is similar to [numpy.expand_dims](https://docs.scipy.org/doc/numpy/reference/generated/numpy.expand_dims.html).

    ```
    import torch
    import torch_extras
    setattr(torch, 'expand_dims', torch_extras.expand_dims)

    var = torch.range(0, 9).view(-1, 2)
    torch.expand_dims(var, 0).size()
    # (1, 5, 2)
    ```
    """
    # Insert a singleton dimension at the requested position using a view.
    sizes = list(var.size())
    sizes.insert(dim, 1)
    return var.view(*sizes)


def select_item(var, index):
    """ Is similar to `[var[row, col] for row, col in enumerate(index)]`.

    ```
    import torch
    import torch_extras
    setattr(torch, 'select_item', torch_extras.select_item)

    var = torch.range(0, 9).view(-1, 2)
    index = torch.LongTensor([0, 0, 0, 1, 1])
    torch.select_item(var, index)
    # [0, 2, 4, 7, 9]
    ```
    """
    # Build a matrix of column indexes and compare it against each row's
    # target index, yielding a mask with exactly one True entry per row.
    index_mask = index.view(-1, 1).repeat(1, var.size(1))
    mask = torch.range(0, var.size(1) - 1).long()
    mask = mask.repeat(var.size(0), 1)
    mask = mask.eq(index_mask)
    if isinstance(var, Variable):
        mask = Variable(mask, volatile=var.volatile)
    return torch.masked_select(var, mask)


def cast(var, type):
    """ Cast a Tensor to the given type.

    ```
    import torch
    import torch_extras
    setattr(torch, 'cast', torch_extras.cast)

    input = torch.FloatTensor(1)
    target_type = type(torch.LongTensor(1))
    type(torch.cast(input, target_type))
    # <class 'torch.LongTensor'>
    ```
    """
    if type == torch.ByteTensor:
        return var.byte()
    elif type == torch.CharTensor:
        return var.char()
    elif type == torch.DoubleTensor:
        return var.double()
    elif type == torch.FloatTensor:
        return var.float()
    elif type == torch.IntTensor:
        return var.int()
    elif type == torch.LongTensor:
        return var.long()
    elif type == torch.ShortTensor:
        return var.short()
    else:
        raise ValueError("Not a Tensor type.")


def one_hot(size, index):
    """ Creates a matrix of one hot vectors.

    ```
    import torch
    import torch_extras
    setattr(torch, 'one_hot', torch_extras.one_hot)

    size = (3, 3)
    index = torch.LongTensor([2, 0, 1]).view(-1, 1)
    torch.one_hot(size, index)
    # [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
    ```
    """
    # Start from an all-zero matrix and scatter ones into the column given
    # by each row's index.
    mask = torch.LongTensor(*size).fill_(0)
    ones = 1
    if isinstance(index, Variable):
        ones = Variable(torch.LongTensor(index.size()).fill_(1))
        mask = Variable(mask, volatile=index.volatile)
    ret = mask.scatter_(1, index, ones)
    return ret


def nll(log_prob, label):
    """ Is similar to [`nll_loss`](http://pytorch.org/docs/nn.html?highlight=nll#torch.nn.functional.nll_loss), except it does not return an aggregate.

    ```
    import torch
    from torch.autograd import Variable
    import torch.nn.functional as F
    import torch_extras
    setattr(torch, 'nll', torch_extras.nll)

    input = Variable(torch.FloatTensor([[0.5, 0.2, 0.3], [0.1, 0.8, 0.1]]))
    target = Variable(torch.LongTensor([1, 2]).view(-1, 1))
    output = torch.nll(torch.log(input), target)
    output.size()
    # (2, 1)
    ```
    """
    if isinstance(log_prob, Variable):
        _type = type(log_prob.data)
    else:
        _type = type(log_prob)

    # Zero out everything except the log probability of the target label,
    # then sum across each row to get one (negated) value per example.
    mask = one_hot(log_prob.size(), label)
    mask = cast(mask, _type)
    return -1 * (log_prob * mask).sum(1)
--------------------------------------------------------------------------------