├── LocallyConnected2d.py ├── README.md ├── accumulate_gradients.py ├── adaptive_batchnorm.py ├── adaptive_pooling_torchvision.py ├── batch_norm_manual.py ├── change_crop_in_dataset.py ├── channel_to_patches.py ├── conv_rnn.py ├── csv_chunk_read.py ├── densenet_forwardhook.py ├── edge_weighting_segmentation.py ├── image_rotation_with_matrix.py ├── mnist_autoencoder.py ├── mnist_permuted.py ├── model_sharding_data_parallel.py ├── momentum_update_nograd.py ├── pytorch_redis.py ├── shared_array.py ├── shared_dict.py ├── unet_demo.py └── weighted_sampling.py /LocallyConnected2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test implementation of locally connected 2d layer 3 | The first part of the script was used for debugging 4 | 5 | @author: ptrblck 6 | """ 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn.modules.utils import _pair 12 | 13 | 14 | ## DEBUG 15 | batch_size = 5 16 | in_channels = 3 17 | h, w = 24, 24 18 | x = torch.ones(batch_size, in_channels, h, w) 19 | kh, kw = 3, 3 # kernel_size 20 | dh, dw = 1, 1 # stride 21 | x_windows = x.unfold(2, kh, dh).unfold(3, kw, dw) 22 | x_windows = x_windows.contiguous().view(*x_windows.size()[:-2], -1) 23 | 24 | out_channels = 2 25 | weights = torch.randn(1, out_channels, in_channels, *x_windows.size()[2:]) 26 | output = (x_windows.unsqueeze(1) * weights).sum([2, -1]) 27 | ## DEBUG 28 | 29 | 30 | class LocallyConnected2d(nn.Module): 31 | def __init__(self, in_channels, out_channels, output_size, kernel_size, stride, bias=False): 32 | super(LocallyConnected2d, self).__init__() 33 | output_size = _pair(output_size) 34 | self.weight = nn.Parameter( 35 | torch.randn(1, out_channels, in_channels, output_size[0], output_size[1], kernel_size**2) 36 | ) 37 | if bias: 38 | self.bias = nn.Parameter( 39 | torch.randn(1, out_channels, output_size[0], output_size[1]) 40 | ) 41 | else: 42 | self.register_parameter('bias', None) 43 | self.kernel_size = 
_pair(kernel_size) 44 | self.stride = _pair(stride) 45 | 46 | def forward(self, x): 47 | _, c, h, w = x.size() 48 | kh, kw = self.kernel_size 49 | dh, dw = self.stride 50 | x = x.unfold(2, kh, dh).unfold(3, kw, dw) 51 | x = x.contiguous().view(*x.size()[:-2], -1) 52 | # Sum in in_channel and kernel_size dims 53 | out = (x.unsqueeze(1) * self.weight).sum([2, -1]) 54 | if self.bias is not None: 55 | out += self.bias 56 | return out 57 | 58 | 59 | # Create input 60 | batch_size = 5 61 | in_channels = 3 62 | h, w = 24, 24 63 | x = torch.randn(batch_size, in_channels, h, w) 64 | 65 | # Create layer and test if backpropagation works 66 | out_channels = 2 67 | output_size = 22 68 | kernel_size = 3 69 | stride = 1 70 | conv = LocallyConnected2d( 71 | in_channels, out_channels, output_size, kernel_size, stride, bias=True) 72 | 73 | out = conv(x) 74 | out.mean().backward() 75 | print(conv.weight.grad) 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch misc 2 | Collection of code snippets I've written for the [PyTorch discussion board](https://discuss.pytorch.org/). 3 | 4 | All scripts were tested using the PyTorch 1.0 preview and torchvision `0.2.1`. 5 | 6 | Additional libraries, e.g. `numpy` or `pandas`, are used in a few scripts. 7 | 8 | Some scripts might be a good starting point for a tutorial. 9 | 10 | ## Overview 11 | 12 | * [accumulate_gradients](https://github.com/ptrblck/pytorch_misc/blob/master/accumulate_gradients.py) - Comparison of accumulated gradients/losses to vanilla batch update. 13 | * [adaptive_batchnorm](https://github.com/ptrblck/pytorch_misc/blob/master/adaptive_batchnorm.py) - Adaptive BN implementation using two additional parameters: `out = a * x + b * bn(x)`. 
14 | * [adaptive_pooling_torchvision](https://github.com/ptrblck/pytorch_misc/blob/master/adaptive_pooling_torchvision.py) - Example of using adaptive pooling layers in pretrained models to use different spatial input shapes. 15 | * [batch_norm_manual](https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py) - Comparison of PyTorch BatchNorm layers and a manual calculation. 16 | * [change_crop_in_dataset](https://github.com/ptrblck/pytorch_misc/blob/master/change_crop_in_dataset.py) - Change the image crop size on the fly using a Dataset. 17 | * [channel_to_patches](https://github.com/ptrblck/pytorch_misc/blob/master/channel_to_patches.py) - Permute image data so that channel values of each pixel are flattened to an image patch around the pixel. 18 | * [conv_rnn](https://github.com/ptrblck/pytorch_misc/blob/master/conv_rnn.py) - Combines a 3DCNN with an RNN; uses windowed frames as inputs. 19 | * [csv_chunk_read](https://github.com/ptrblck/pytorch_misc/blob/master/csv_chunk_read.py) - Provide data chunks from continuous .csv file. 20 | * [densenet_forwardhook](https://github.com/ptrblck/pytorch_misc/blob/master/densenet_forwardhook.py) - Use forward hooks to get intermediate activations from `densenet121`. Uses separate modules to process these activations further. 21 | * [edge_weighting_segmentation](https://github.com/ptrblck/pytorch_misc/blob/master/edge_weighting_segmentation.py) - Apply weighting to edges for a segmentation task. 22 | * [image_rotation_with_matrix](https://github.com/ptrblck/pytorch_misc/blob/master/image_rotation_with_matrix.py) - Rotate an image given an angle using 1.) a nested loop and 2.) a rotation matrix and mesh grid. 23 | * [LocallyConnected2d](https://github.com/ptrblck/pytorch_misc/blob/master/LocallyConnected2d.py) - Implementation of a locally connected 2d layer. 24 | * [mnist_autoencoder](https://github.com/ptrblck/pytorch_misc/blob/master/mnist_autoencoder.py) - Simple autoencoder for MNIST data. 
Includes visualizations of output images, intermediate activations and conv kernels. 25 | * [mnist_permuted](https://github.com/ptrblck/pytorch_misc/blob/master/mnist_permuted.py) - MNIST training using permuted pixel locations. 26 | * [model_sharding_data_parallel](https://github.com/ptrblck/pytorch_misc/blob/master/model_sharding_data_parallel.py) - Model sharding with `DataParallel` using 2 pairs of 2 GPUs. 27 | * [momentum_update_nograd](https://github.com/ptrblck/pytorch_misc/blob/master/momentum_update_nograd.py) - Script to see how parameters are updated when an optimizer is used with momentum/running estimates, even if gradients are zero. 28 | * [pytorch_redis](https://github.com/ptrblck/pytorch_misc/blob/master/pytorch_redis.py) - Script to demonstrate loading data from Redis using a PyTorch Dataset and DataLoader. 29 | * [shared_array](https://github.com/ptrblck/pytorch_misc/blob/master/shared_array.py) - Script to demonstrate the usage of shared arrays using multiple workers. 30 | * [shared_dict](https://github.com/ptrblck/pytorch_misc/blob/master/shared_dict.py) - Script to demonstrate the usage of shared dicts using multiple workers. 31 | * [unet_demo](https://github.com/ptrblck/pytorch_misc/blob/master/unet_demo.py) - Simple UNet demo. 32 | * [weighted_sampling](https://github.com/ptrblck/pytorch_misc/blob/master/weighted_sampling.py) - Usage of `WeightedRandomSampler` on an imbalanced dataset with a class imbalance of 99 to 1. 33 | 34 | 35 | Feedback is very welcome! 36 | -------------------------------------------------------------------------------- /accumulate_gradients.py: -------------------------------------------------------------------------------- 1 | """ 2 | Comparison of accumulated gradients/losses to vanilla batch update. 
3 | Comments from @albanD: 4 | https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/20 5 | 6 | @author: ptrblck 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | # Accumulate loss for each sample 14 | # more runtime, more memory 15 | x1 = torch.ones(2, 1) 16 | w1 = torch.ones(1, 1, requires_grad=True) 17 | y1 = torch.ones(2, 1) * 2 18 | 19 | criterion = nn.MSELoss() 20 | 21 | loss1 = 0 22 | for i in range(10): 23 | output1 = torch.matmul(x1, w1) 24 | loss1 += criterion(output1, y1) 25 | loss1 /= 10 # scale loss to match batch gradient 26 | loss1.backward() 27 | 28 | print('Accumulated losses: {}'.format(w1.grad)) 29 | 30 | # Use whole batch to calculate gradient 31 | # least runtime, more memory 32 | x2 = torch.ones(20, 1) 33 | w2 = torch.ones(1, 1, requires_grad=True) 34 | y2 = torch.ones(20, 1) * 2 35 | 36 | output2 = torch.matmul(x2, w2) 37 | loss2 = criterion(output2, y2) 38 | loss2.backward() 39 | print('Batch gradient: {}'.format(w2.grad)) 40 | 41 | # Accumulate scaled gradient 42 | # more runtime, least memory 43 | x3 = torch.ones(2, 1) 44 | w3 = torch.ones(1, 1, requires_grad=True) 45 | y3 = torch.ones(2, 1) * 2 46 | 47 | for i in range(10): 48 | output3 = torch.matmul(x3, w3) 49 | loss3 = criterion(output3, y3) 50 | loss3 /= 10 51 | loss3.backward() 52 | 53 | print('Accumulated gradient: {}'.format(w3.grad)) 54 | -------------------------------------------------------------------------------- /adaptive_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Adaptive BatchNorm 3 | 4 | @author: ptrblck 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torchvision import datasets, transforms 12 | 13 | 14 | # Globals 15 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 16 | seed = 2809 17 | batch_size = 10 18 | lr = 0.01 19 | 
log_interval = 10 20 | epochs = 10 21 | torch.manual_seed(seed) 22 | 23 | 24 | class AdaptiveBatchNorm2d(nn.Module): 25 | ''' 26 | Adaptive BN implementation using two additional parameters: 27 | out = a * x + b * bn(x) 28 | ''' 29 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 30 | super(AdaptiveBatchNorm2d, self).__init__() 31 | self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine) 32 | self.a = nn.Parameter(torch.FloatTensor(1, 1, 1, 1)) 33 | self.b = nn.Parameter(torch.FloatTensor(1, 1, 1, 1)) 34 | 35 | def forward(self, x): 36 | return self.a * x + self.b * self.bn(x) 37 | 38 | 39 | class MyNet(nn.Module): 40 | def __init__(self): 41 | super(MyNet, self).__init__() 42 | self.conv1 = nn.Conv2d(in_channels=1, 43 | out_channels=10, 44 | kernel_size=5) 45 | self.conv1_bn = AdaptiveBatchNorm2d(10) 46 | self.conv2 = nn.Conv2d(in_channels=10, 47 | out_channels=20, 48 | kernel_size=5) 49 | self.conv2_bn = AdaptiveBatchNorm2d(20) 50 | self.fc1 = nn.Linear(320, 50) 51 | self.fc2 = nn.Linear(50, 10) 52 | 53 | def forward(self, x): 54 | x = F.relu(F.max_pool2d(self.conv1_bn(self.conv1(x)), 2)) 55 | x = F.relu(F.max_pool2d(self.conv2_bn(self.conv2(x)), 2)) 56 | x = x.view(-1, 320) 57 | x = F.relu(self.fc1(x)) 58 | x = self.fc2(x) 59 | return F.log_softmax(x, dim=1) 60 | 61 | 62 | def train(epoch): 63 | model.train() 64 | for batch_idx, (data, target) in enumerate(train_loader): 65 | data, target = data.to(device), target.to(device) 66 | optimizer.zero_grad() 67 | output = model(data) 68 | loss = F.nll_loss(output, target) 69 | loss.backward() 70 | optimizer.step() 71 | if batch_idx % log_interval == 0: 72 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 73 | epoch, batch_idx * len(data), len(train_loader.dataset), 74 | 100. 
* batch_idx / len(train_loader), loss.item())) 75 | 76 | 77 | def test(): 78 | model.eval() 79 | test_loss = 0 80 | correct = 0 81 | with torch.no_grad(): 82 | for data, target in test_loader: 83 | data, target = data.to(device), target.to(device) 84 | output = model(data) 85 | # sum up batch loss (reduction='sum' replaces the deprecated size_average=False) 86 | test_loss += F.nll_loss(output, target, reduction='sum').item() 87 | # get the index of the max log-probability 88 | pred = output.data.max(1, keepdim=True)[1] 89 | correct += pred.eq(target.data.view_as(pred)).cpu().sum().item() 90 | 91 | test_loss /= len(test_loader.dataset) 92 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 93 | test_loss, correct, len(test_loader.dataset), 94 | 100. * correct / len(test_loader.dataset))) 95 | 96 | 97 | train_loader = torch.utils.data.DataLoader( 98 | datasets.MNIST('./data', train=True, download=True, 99 | transform=transforms.Compose([ 100 | transforms.ToTensor(), 101 | transforms.Normalize((0.1307,), (0.3081,)) 102 | ])), 103 | batch_size=batch_size, 104 | shuffle=True) 105 | 106 | test_loader = torch.utils.data.DataLoader( 107 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 108 | transforms.ToTensor(), 109 | transforms.Normalize((0.1307,), (0.3081,)) 110 | ])), 111 | batch_size=batch_size, 112 | shuffle=True) 113 | 114 | model = MyNet().to(device) 115 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.5) 116 | 117 | for epoch in range(1, epochs + 1): 118 | train(epoch) 119 | test() 120 | -------------------------------------------------------------------------------- /adaptive_pooling_torchvision.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adaptive pooling layer examples 3 | 4 | @author: ptrblck 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import torchvision.models as models 11 | 12 | 13 | # Use standard model with [batch_size, 3, 224, 224] input 14 | model = models.vgg16(pretrained=False) 15 | 
batch_size = 1 16 | x = torch.randn(batch_size, 3, 224, 224) 17 | output = model(x) 18 | 19 | # Try bigger input 20 | x_big = torch.randn(batch_size, 3, 299, 299) 21 | try: 22 | output = model(x_big) 23 | except RuntimeError as e: 24 | print(e) 25 | 26 | # Try smaller input 27 | x_small = torch.randn(batch_size, 3, 128, 128) 28 | try: 29 | output = model(x_small) 30 | except RuntimeError as e: 31 | print(e) 32 | # Both don't work, since we get a size mismatch for these sizes 33 | 34 | # Get the size of the last activation map before the classifier 35 | def size_hook(module, input, output): 36 | print(output.shape) 37 | 38 | model.features[-1].register_forward_hook(size_hook) 39 | output = model(x) 40 | 41 | # We see that the last pooling layer returns an activation of 42 | # [batch_size, 512, 7, 7]. So let's replace it with an adaptive layer with an 43 | # output shape of 7x7. 44 | model.features[-1] = nn.AdaptiveMaxPool2d(output_size=7) 45 | 46 | # Now let's try the other shapes again 47 | output = model(x_big) 48 | output = model(x_small) 49 | 50 | x_tiny = torch.randn(batch_size, 3, 16, 16) 51 | output = model(x_tiny) 52 | 53 | # Now these inputs are working! 
54 | # There is however a minimal size as we need a spatial size of at least 1x1 55 | # to pass into the adaptive pooling layer 56 | x_too_small = torch.randn(batch_size, 3, 15, 15) 57 | try: 58 | output = model(x_too_small) 59 | except RuntimeError as e: 60 | print(e) 61 | -------------------------------------------------------------------------------- /batch_norm_manual.py: -------------------------------------------------------------------------------- 1 | """ 2 | Comparison of manual BatchNorm2d layer implementation in Python and 3 | nn.BatchNorm2d 4 | 5 | @author: ptrblck 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | def compare_bn(bn1, bn2): 13 | err = False 14 | if not torch.allclose(bn1.running_mean, bn2.running_mean): 15 | print('Diff in running_mean: {} vs {}'.format( 16 | bn1.running_mean, bn2.running_mean)) 17 | err = True 18 | 19 | if not torch.allclose(bn1.running_var, bn2.running_var): 20 | print('Diff in running_var: {} vs {}'.format( 21 | bn1.running_var, bn2.running_var)) 22 | err = True 23 | 24 | if bn1.affine and bn2.affine: 25 | if not torch.allclose(bn1.weight, bn2.weight): 26 | print('Diff in weight: {} vs {}'.format( 27 | bn1.weight, bn2.weight)) 28 | err = True 29 | 30 | if not torch.allclose(bn1.bias, bn2.bias): 31 | print('Diff in bias: {} vs {}'.format( 32 | bn1.bias, bn2.bias)) 33 | err = True 34 | 35 | if not err: 36 | print('All parameters are equal!') 37 | 38 | 39 | class MyBatchNorm2d(nn.BatchNorm2d): 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, 41 | affine=True, track_running_stats=True): 42 | super(MyBatchNorm2d, self).__init__( 43 | num_features, eps, momentum, affine, track_running_stats) 44 | 45 | def forward(self, input): 46 | self._check_input_dim(input) 47 | 48 | exponential_average_factor = 0.0 49 | 50 | if self.training and self.track_running_stats: 51 | if self.num_batches_tracked is not None: 52 | self.num_batches_tracked += 1 53 | if self.momentum is None: # use cumulative moving 
average 54 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 55 | else: # use exponential moving average 56 | exponential_average_factor = self.momentum 57 | 58 | # calculate running estimates 59 | if self.training: 60 | mean = input.mean([0, 2, 3]) 61 | # use biased var in train 62 | var = input.var([0, 2, 3], unbiased=False) 63 | n = input.numel() / input.size(1) 64 | with torch.no_grad(): 65 | self.running_mean = exponential_average_factor * mean\ 66 | + (1 - exponential_average_factor) * self.running_mean 67 | # update running_var with unbiased var 68 | self.running_var = exponential_average_factor * var * n / (n - 1)\ 69 | + (1 - exponential_average_factor) * self.running_var 70 | else: 71 | mean = self.running_mean 72 | var = self.running_var 73 | 74 | input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps)) 75 | if self.affine: 76 | input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None] 77 | 78 | return input 79 | 80 | 81 | # Init BatchNorm layers 82 | my_bn = MyBatchNorm2d(3, affine=True) 83 | bn = nn.BatchNorm2d(3, affine=True) 84 | 85 | compare_bn(my_bn, bn) # weight and bias should be different 86 | # Load weight and bias 87 | my_bn.load_state_dict(bn.state_dict()) 88 | compare_bn(my_bn, bn) 89 | 90 | # Run train 91 | for _ in range(10): 92 | scale = torch.randint(1, 10, (1,)).float() 93 | bias = torch.randint(-10, 10, (1,)).float() 94 | x = torch.randn(10, 3, 100, 100) * scale + bias 95 | out1 = my_bn(x) 96 | out2 = bn(x) 97 | compare_bn(my_bn, bn) 98 | 99 | torch.allclose(out1, out2) 100 | print('Max diff: ', (out1 - out2).abs().max()) 101 | 102 | # Run eval 103 | my_bn.eval() 104 | bn.eval() 105 | for _ in range(10): 106 | scale = torch.randint(1, 10, (1,)).float() 107 | bias = torch.randint(-10, 10, (1,)).float() 108 | x = torch.randn(10, 3, 100, 100) * scale + bias 109 | out1 = my_bn(x) 110 | out2 = bn(x) 111 | compare_bn(my_bn, bn) 112 | 113 | torch.allclose(out1, 
out2) 114 | print('Max diff: ', (out1 - out2).abs().max()) 115 | -------------------------------------------------------------------------------- /change_crop_in_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Change the crop size on the fly using a Dataset. 3 | MyDataset.set_stage(stage) switches between crop sizes. 4 | Alternatively, the crop size could be passed in directly. 5 | 6 | @author: ptrblck 7 | """ 8 | 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader 11 | from torchvision import transforms 12 | import torchvision.transforms.functional as TF 13 | 14 | 15 | class MyDataset(Dataset): 16 | def __init__(self): 17 | self.images = [TF.to_pil_image(x) for x in torch.ByteTensor(10, 3, 48, 48)] 18 | self.set_stage(0) 19 | 20 | def __getitem__(self, index): 21 | image = self.images[index] 22 | 23 | # Switch your behavior depending on stage 24 | image = self.crop(image) 25 | x = TF.to_tensor(image) 26 | return x 27 | 28 | def set_stage(self, stage): 29 | if stage == 0: 30 | print('Using (32, 32) crops') 31 | self.crop = transforms.RandomCrop((32, 32)) 32 | elif stage == 1: 33 | print('Using (28, 28) crops') 34 | self.crop = transforms.RandomCrop((28, 28)) 35 | 36 | def __len__(self): 37 | return len(self.images) 38 | 39 | 40 | dataset = MyDataset() 41 | loader = DataLoader(dataset, 42 | batch_size=2, 43 | num_workers=2, 44 | shuffle=True) 45 | 46 | # Use standard crop size 47 | for batch_idx, data in enumerate(loader): 48 | print('Batch idx {}, data shape {}'.format( 49 | batch_idx, data.shape)) 50 | 51 | # Switch to stage1 crop size 52 | loader.dataset.set_stage(1) 53 | 54 | # Check the shape again 55 | for batch_idx, data in enumerate(loader): 56 | print('Batch idx {}, data shape {}'.format( 57 | batch_idx, data.shape)) 58 | -------------------------------------------------------------------------------- /channel_to_patches.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Permute image data so that channel values of each pixel are flattened to an image patch around the pixel. 3 | 4 | @author: ptrblck 5 | """ 6 | 7 | import torch 8 | 9 | B, C, H, W = 2, 16, 4, 4 10 | # Create dummy input with same values in each channel 11 | x = torch.arange(C)[None, :, None, None].repeat(B, 1, H, W) 12 | print(x) 13 | # Permute channel dimension to last position and view as 4x4 windows 14 | x = x.permute(0, 2, 3, 1).view(B, H, W, 4, 4) 15 | print(x) 16 | # Permute "window dims" with spatial dims, view as desired output 17 | x = x.permute(0, 1, 3, 2, 4).contiguous().view(B, 1, 4*H, 4*W) 18 | print(x) 19 | -------------------------------------------------------------------------------- /conv_rnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Combine Conv3d with an RNN Module. 3 | Use windowed frames as inputs. 4 | 5 | @author: ptrblck 6 | """ 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.data import Dataset 12 | 13 | 14 | class MyModel(nn.Module): 15 | def __init__(self, window=16): 16 | super(MyModel, self).__init__() 17 | self.conv_model = nn.Sequential( 18 | nn.Conv3d( 19 | in_channels=3, 20 | out_channels=6, 21 | kernel_size=3, 22 | stride=1, 23 | padding=1), 24 | nn.MaxPool3d((1, 2, 2)), 25 | nn.ReLU() 26 | ) 27 | 28 | self.rnn = nn.RNN( 29 | input_size=6*16*12*12, 30 | hidden_size=1, 31 | num_layers=1, 32 | batch_first=True 33 | ) 34 | self.hidden = torch.zeros(1, 1, 1) 35 | self.window = window 36 | 37 | def forward(self, x): 38 | self.hidden = torch.zeros(1, 1, 1) # reset hidden 39 | 40 | activations = [] 41 | for idx in range(0, x.size(2), self.window): 42 | x_ = x[:, :, idx:idx+self.window] 43 | x_ = self.conv_model(x_) 44 | x_ = x_.view(x_.size(0), 1, -1) 45 | activations.append(x_) 46 | x = torch.cat(activations, 1) 47 | out, hidden = self.rnn(x, self.hidden) 48 | 49 | 
return out, hidden 50 | 51 | 52 | class MyDataset(Dataset): 53 | ''' 54 | Returns windowed frames from sequential data. 55 | ''' 56 | def __init__(self, frames=512): 57 | self.data = torch.randn(3, 2048, 24, 24) 58 | self.frames = frames 59 | 60 | def __getitem__(self, index): 61 | index = index * self.frames 62 | x = self.data[:, index:index+self.frames] 63 | return x 64 | 65 | def __len__(self): 66 | return self.data.size(1) // self.frames # __len__ must return an int 67 | 68 | 69 | model = MyModel() 70 | dataset = MyDataset() 71 | x = dataset[0] 72 | output, hidden = model(x.unsqueeze(0)) 73 | -------------------------------------------------------------------------------- /csv_chunk_read.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provide data chunks from continuous .csv file. 3 | 4 | @author: ptrblck 5 | """ 6 | 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | 10 | import numpy as np 11 | import pandas as pd 12 | 13 | # Create dummy csv data 14 | nb_samples = 110 15 | a = np.arange(nb_samples) 16 | df = pd.DataFrame(a, columns=['data']) 17 | df.to_csv('data.csv', index=False) 18 | 19 | 20 | # Create Dataset 21 | class CSVDataset(Dataset): 22 | def __init__(self, path, chunksize, nb_samples): 23 | self.path = path 24 | self.chunksize = chunksize 25 | self.len = nb_samples // self.chunksize 26 | 27 | def __getitem__(self, index): 28 | ''' 29 | Get next chunk of data 30 | ''' 31 | x = next( 32 | pd.read_csv( 33 | self.path, 34 | skiprows=index * self.chunksize + 1, # +1, since we skip the header 35 | chunksize=self.chunksize, 36 | names=['data'])) 37 | x = torch.from_numpy(x.data.values) 38 | return x 39 | 40 | def __len__(self): 41 | return self.len 42 | 43 | 44 | dataset = CSVDataset('data.csv', chunksize=10, nb_samples=nb_samples) 45 | loader = DataLoader(dataset, batch_size=10, num_workers=1, shuffle=False) 46 | 47 | for batch_idx, data in enumerate(loader): 48 | print('batch: {}\tdata: {}'.format(batch_idx, 
data)) 49 | -------------------------------------------------------------------------------- /densenet_forwardhook.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use forward hooks to get intermediate activations from densenet121. 3 | Create additional conv layers to process these activations to get a 4 | desired number of output channels 5 | 6 | @author: ptrblck 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from torchvision import models 13 | 14 | 15 | activations = {} 16 | def get_activation(name): 17 | def hook(model, input, output): 18 | activations[name] = output 19 | return hook 20 | 21 | # Create Model 22 | model = models.densenet121(pretrained=False) 23 | 24 | # Register forward hooks with name 25 | for name, child in model.features.named_children(): 26 | if 'denseblock' in name: 27 | print(name) 28 | child.register_forward_hook(get_activation(name)) 29 | 30 | # Forward pass 31 | x = torch.randn(1, 3, 224, 224) 32 | output = model(x) 33 | 34 | # Create convs to get desired out_channels 35 | out_channels = 1 36 | convs = {'denseblock1': nn.Conv2d(256, out_channels, 1,), 37 | 'denseblock2': nn.Conv2d(512, out_channels, 1), 38 | 'denseblock3': nn.Conv2d(1024, out_channels, 1), 39 | 'denseblock4': nn.Conv2d(1024, out_channels, 1)} 40 | 41 | # Apply conv on each activation 42 | for key in activations: 43 | act = activations[key] 44 | act = convs[key](act) 45 | print(key, act.shape) 46 | -------------------------------------------------------------------------------- /edge_weighting_segmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Apply weighting to edges for a segmentation task 3 | 4 | @author: ptrblck 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | # Create dummy input and target with two squares 15 | output = 
F.log_softmax(torch.randn(1, 3, 24, 24), 1) 16 | target = torch.zeros(1, 24, 24, dtype=torch.long) 17 | target[0, 4:12, 4:12] = 1 18 | target[0, 14:20, 14:20] = 2 19 | plt.imshow(target[0]) 20 | 21 | # Edge calculation 22 | # Get binary target 23 | bin_target = torch.where(target > 0, torch.tensor(1), torch.tensor(0)) 24 | plt.imshow(bin_target[0]) 25 | 26 | # Use average pooling to get edge 27 | o = F.avg_pool2d(bin_target.float(), kernel_size=3, padding=1, stride=1) 28 | plt.imshow(o[0]) 29 | 30 | edge_idx = (o.ge(0.01) * o.le(0.99)).float() 31 | plt.imshow(edge_idx[0]) 32 | 33 | # Create weight mask 34 | weights = torch.ones_like(edge_idx, dtype=torch.float32) 35 | weights_sum0 = weights.sum() # Save initial sum for later rescaling 36 | weights = weights + edge_idx * 2. # Add an extra weight of 2 to edge pixels 37 | weights_sum1 = weights.sum() 38 | weights = weights / weights_sum1 * weights_sum0 # Rescale weights 39 | plt.imshow(weights[0]) 40 | 41 | # Calculate loss 42 | criterion = nn.NLLLoss(reduction='none') 43 | loss = criterion(output, target) 44 | loss = loss * weights # Apply weighting 45 | loss = loss.sum() / weights.sum() # Scale loss 46 | -------------------------------------------------------------------------------- /image_rotation_with_matrix.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rotate image given an angle. 3 | 1. Calculate rotated position for each input pixel 4 | 2. Use meshgrid and rotation matrix to achieve the same 5 | 6 | @author: ptrblck 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | 12 | 13 | # Create dummy image 14 | batch_size = 1 15 | im = torch.zeros(batch_size, 1, 10, 10) 16 | im[:, :, :, 2] = 1. 17 | 18 | # Set angle 19 | angle = torch.tensor([72 * np.pi / 180.]) 20 | 21 | # Calculate rotation for each target pixel 22 | x_mid = (im.size(2) + 1) / 2. 23 | y_mid = (im.size(3) + 1) / 2. 
24 | im_rot = torch.zeros_like(im) 25 | for r in range(im.size(2)): 26 | for c in range(im.size(3)): 27 | x = (r - x_mid) * torch.cos(angle) + (c - y_mid) * torch.sin(angle) 28 | y = -1.0 * (r - x_mid) * torch.sin(angle) + (c - y_mid) * torch.cos(angle) 29 | x = torch.round(x) + x_mid 30 | y = torch.round(y) + y_mid 31 | 32 | if (x >= 0 and y >= 0 and x < im.size(2) and y < im.size(3)): 33 | im_rot[:, :, r, c] = im[:, :, x.long().item(), y.long().item()] 34 | 35 | 36 | # Calculate rotation with inverse rotation matrix 37 | rot_matrix = torch.tensor([[torch.cos(angle), torch.sin(angle)], 38 | [-1.0*torch.sin(angle), torch.cos(angle)]]) 39 | 40 | # Use meshgrid for pixel coords 41 | xv, yv = torch.meshgrid(torch.arange(im.size(2)), torch.arange(im.size(3))) 42 | xv = xv.contiguous() 43 | yv = yv.contiguous() 44 | src_ind = torch.cat(( 45 | (xv.float() - x_mid).view(-1, 1), 46 | (yv.float() - y_mid).view(-1, 1)), 47 | dim=1 48 | ) 49 | 50 | # Calculate indices using rotation matrix 51 | src_ind = torch.matmul(src_ind, rot_matrix.t()) 52 | src_ind = torch.round(src_ind) 53 | src_ind += torch.tensor([[x_mid, y_mid]]) 54 | 55 | # Set out of bounds indices to limits 56 | src_ind[src_ind < 0] = 0. 57 | src_ind[:, 0][src_ind[:, 0] >= im.size(2)] = float(im.size(2)) - 1 58 | src_ind[:, 1][src_ind[:, 1] >= im.size(3)] = float(im.size(3)) - 1 59 | 60 | # Create new rotated image 61 | im_rot2 = torch.zeros_like(im) 62 | src_ind = src_ind.long() 63 | im_rot2[:, :, xv.view(-1), yv.view(-1)] = im[:, :, src_ind[:, 0], src_ind[:, 1]] 64 | im_rot2 = im_rot2.view(batch_size, 1, 10, 10) 65 | 66 | print('Using method 1: {}'.format(im_rot)) 67 | print('Using method 2: {}'.format(im_rot2)) 68 | -------------------------------------------------------------------------------- /mnist_autoencoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple autoencoder for MNIST data. 
3 | Visualizes some output images, intermediate activations as well as some conv 4 | kernels. 5 | 6 | @author: ptrblck 7 | """ 8 | 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from torch.utils.data import DataLoader 15 | 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | class MyModel(nn.Module): 23 | def __init__(self): 24 | super(MyModel, self).__init__() 25 | self.conv1 = nn.Conv2d(1, 3, 3, 1, 1) 26 | self.pool1 = nn.MaxPool2d(2) 27 | self.conv2 = nn.Conv2d(3, 6, 3, 1, 1) 28 | self.pool2 = nn.MaxPool2d(2) 29 | 30 | self.conv_trans1 = nn.ConvTranspose2d(6, 3, 4, 2, 1) 31 | self.conv_trans2 = nn.ConvTranspose2d(3, 1, 4, 2, 1) 32 | 33 | def forward(self, x): 34 | x = F.relu(self.pool1(self.conv1(x))) 35 | x = F.relu(self.pool2(self.conv2(x))) 36 | x = F.relu(self.conv_trans1(x)) 37 | x = self.conv_trans2(x) 38 | return x 39 | 40 | dataset = datasets.MNIST( 41 | root='./data', 42 | transform=transforms.ToTensor() 43 | ) 44 | loader = DataLoader( 45 | dataset, 46 | num_workers=2, 47 | batch_size=8, 48 | shuffle=True 49 | ) 50 | 51 | model = MyModel() 52 | criterion = nn.BCEWithLogitsLoss() 53 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 54 | 55 | epochs = 1 56 | for epoch in range(epochs): 57 | for batch_idx, (data, target) in enumerate(loader): 58 | optimizer.zero_grad() 59 | output = model(data) 60 | loss = criterion(output, data) 61 | loss.backward() 62 | optimizer.step() 63 | 64 | print('Epoch {}, Batch idx {}, loss {}'.format( 65 | epoch, batch_idx, loss.item())) 66 | 67 | 68 | def normalize_output(img): 69 | img = img - img.min() 70 | img = img / img.max() 71 | return img 72 | 73 | # Plot some images 74 | idx = torch.randint(0, output.size(0), ()) 75 | pred = normalize_output(output[idx, 0]) 76 | img = data[idx, 0] 77 | 78 | fig, axarr = plt.subplots(1, 2) 79 | 
axarr[0].imshow(img.detach().numpy())
axarr[1].imshow(pred.detach().numpy())

# Visualize feature maps
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

model.conv1.register_forward_hook(get_activation('conv1'))
data, _ = dataset[0]
data.unsqueeze_(0)
output = model(data)

act = activation['conv1'].squeeze()
fig, axarr = plt.subplots(act.size(0))
for idx in range(act.size(0)):
    axarr[idx].imshow(act[idx])

# Visualize conv filter
kernels = model.conv1.weight.detach()
fig, axarr = plt.subplots(kernels.size(0))
for idx in range(kernels.size(0)):
    axarr[idx].imshow(kernels[idx].squeeze())

--------------------------------------------------------------------------------
/mnist_permuted.py:
--------------------------------------------------------------------------------
"""
Permute all pixels of MNIST data and try to learn it using a simple model.

@author: ptrblck
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

from torchvision import datasets
from torchvision import transforms

import numpy as np

# Create random indices to permute images
indices = np.arange(28*28)
np.random.shuffle(indices)


def shuffle_image(tensor):
    tensor = tensor.view(-1)[indices].view(1, 28, 28)
    return tensor


# Apply permutation using transforms.Lambda
train_dataset = datasets.MNIST(root='./data',
                               download=False,
                               train=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.1307,), (0.3081,)),
                                   transforms.Lambda(shuffle_image)
                               ]))

test_dataset = datasets.MNIST(root='./data',
                              download=False,
                              train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,)),
                                  transforms.Lambda(shuffle_image)
                              ]))

train_loader = DataLoader(train_dataset,
                          batch_size=1,
                          shuffle=True)

test_loader = DataLoader(test_dataset,
                         batch_size=1,
                         shuffle=True)


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.act = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 4, 3, 1, 1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(4, 8, 3, 1, 1)
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(7*7*8, 10)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.pool1(x)
        x = self.act(self.conv2(x))
        x = self.pool2(x)
        x = x.view(x.size(0), -1)
        x = F.log_softmax(self.fc1(x), dim=1)
        return x


def train():
    acc = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        _, pred = torch.max(output, dim=1)
        accuracy = (pred == target).sum() / float(pred.size(0))
        acc += accuracy.data.float()

        if (batch_idx + 1) % 10 == 0:
            print('batch idx {}, loss {}'.format(
                batch_idx, loss.item()))

    acc /= len(train_loader)
    print('Train accuracy {}'.format(acc))


def test():
    acc = 0.0
    losses = 0.0
    for batch_idx, (data, target) in enumerate(test_loader):
        with torch.no_grad():
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output, dim=1)

            accuracy = (pred == target).sum() / float(pred.size(0))
            acc += accuracy.data.float()
            losses += loss.item()
    acc /= len(test_loader)
    losses /= len(test_loader)
    print('Acc {}, loss {}'.format(
        acc, losses))


model = MyModel()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

train()
test()

# Visualize filters
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
filts1 = model.conv1.weight.data
grid = make_grid(filts1)
grid = grid.permute(1, 2, 0)
plt.imshow(grid)

--------------------------------------------------------------------------------
/model_sharding_data_parallel.py:
--------------------------------------------------------------------------------
"""
Model sharding with DataParallel using 2 pairs of 2 GPUs.

@author: ptrblck
"""

import torch
import torch.nn as nn


class SubModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SubModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        print('SubModule, device: {}, shape: {}\n'.format(x.device, x.shape))
        x = self.conv1(x)
        return x


class MyModel(nn.Module):
    def __init__(self, split_gpus, parallel):
        super(MyModel, self).__init__()
        self.module1 = SubModule(3, 6)
        self.module2 = SubModule(6, 1)

        self.split_gpus = split_gpus
        self.parallel = parallel
        if self.split_gpus and self.parallel:
            self.module1 = nn.DataParallel(self.module1, device_ids=[0, 1]).to('cuda:0')
            self.module2 = nn.DataParallel(self.module2, device_ids=[2, 3]).to('cuda:2')

    def forward(self, x):
        print('Input: device {}, shape {}\n'.format(x.device, x.shape))
        x = self.module1(x)
        print('After module1: device {}, shape {}\n'.format(x.device, x.shape))
        x = self.module2(x)
        print('After module2: device {}, shape {}\n'.format(x.device, x.shape))
        return x


model = MyModel(split_gpus=True, parallel=True)
x = torch.randn(16, 3, 24, 24).to('cuda:0')
output = model(x)

--------------------------------------------------------------------------------
/momentum_update_nograd.py:
--------------------------------------------------------------------------------
"""
Script to see how parameters are updated when an optimizer is used with
momentum/running estimates, even if gradients are zero.

Set use_adam=True to see the effect. Otherwise plain SGD will be used.

The model consists of two "decoder" parts, dec1 and dec2.
In the first part of the script, you'll see that dec1 will be updated twice,
even though this module is not used in the second forward pass.
This effect is observed, if one optimizer is used for all parameters.

In the second part of the script, two separate optimizers are used and
we cannot observe this effect anymore.

@author: ptrblck
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


use_adam = True


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.enc = nn.Linear(64, 10)
        self.dec1 = nn.Linear(10, 64)
        self.dec2 = nn.Linear(10, 64)

    def forward(self, x, decoder_idx):
        x = F.relu(self.enc(x))
        if decoder_idx == 1:
            print('Using dec1')
            x = self.dec1(x)
        elif decoder_idx == 2:
            print('Using dec2')
            x = self.dec2(x)
        else:
            print('Unknown decoder_idx')

        return x


# Create input and model
x = torch.randn(1, 64)
y = x.clone()
model = MyModel()
criterion = nn.MSELoss()
# Create optimizer using all model parameters
if use_adam:
    optimizer = optim.Adam(model.parameters(), lr=1.)
else:
    optimizer = optim.SGD(model.parameters(), lr=1.)

# Save init values
old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()

# Training procedure
optimizer.zero_grad()
output = model(x, 1)
loss = criterion(output, y)
loss.backward()

# Check for gradients in dec1, dec2
print('Dec1 grad: {}\nDec2 grad: {}'.format(
    model.dec1.weight.grad, model.dec2.weight.grad))

optimizer.step()

# Save new params
new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()

# Compare params
for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))

# Update
old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()

# Pass through dec2
optimizer.zero_grad()
output = model(x, 2)
loss = criterion(output, y)
loss.backward()

print('Dec1 grad: {}\nDec2 grad: {}'.format(
    model.dec1.weight.grad, model.dec2.weight.grad))

optimizer.step()

# Save new params
new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()

# Compare params
for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))

## Create separate optimizers
model = MyModel()
dec1_params = list(model.enc.parameters()) + list(model.dec1.parameters())
optimizer1 = optim.Adam(dec1_params, lr=1.)
dec2_params = list(model.enc.parameters()) + list(model.dec2.parameters())
optimizer2 = optim.Adam(dec2_params, lr=1.)

# Save init values
old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()

# Training procedure
optimizer1.zero_grad()
output = model(x, 1)
loss = criterion(output, y)
loss.backward()

# Check for gradients in dec1, dec2
print('Dec1 grad: {}\nDec2 grad: {}'.format(
    model.dec1.weight.grad, model.dec2.weight.grad))

optimizer1.step()

# Save new params
new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()

# Compare params
for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))

# Update
old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()

# Pass through dec2
optimizer1.zero_grad()
output = model(x, 2)
loss = criterion(output, y)
loss.backward()

print('Dec1 grad: {}\nDec2 grad: {}'.format(
    model.dec1.weight.grad, model.dec2.weight.grad))

optimizer2.step()

# Save new params
new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()

# Compare params
for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))

--------------------------------------------------------------------------------
/pytorch_redis.py:
--------------------------------------------------------------------------------
"""
Shows how to store and load data from redis using a PyTorch
Dataset and DataLoader (with multiple workers).

@author: ptrblck
"""

import redis

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import numpy as np


# Create random data and push to redis
r = redis.Redis(host='localhost', port=6379, db=0)

nb_images = 100
for idx in range(nb_images):
    # Use int64 for the fake images, as it's easier to store the target with it.
    # The np.long alias is missing from some NumPy releases; a fixed-width
    # dtype also keeps the byte layout identical for writer and reader.
    data = np.random.randint(0, 256, (3, 24, 24), dtype=np.int64).tobytes()
    target = bytes(np.random.randint(0, 10, (1,)).astype(np.int64))
    r.set(idx, data + target)


# Create RedisDataset
class RedisDataset(Dataset):
    def __init__(self,
                 redis_host='localhost',
                 redis_port=6379,
                 redis_db=0,
                 length=0,
                 transform=None):

        self.db = redis.Redis(host=redis_host, port=redis_port, db=redis_db)
        self.length = length
        self.transform = transform

    def __getitem__(self, index):
        data = self.db.get(index)
        data = np.frombuffer(data, dtype=np.int64)
        x = data[:-1].reshape(3, 24, 24).astype(np.uint8)
        y = torch.tensor(data[-1]).long()
        if self.transform:
            x = self.transform(x)

        return x, y

    def __len__(self):
        return self.length


# Load samples from redis using multiprocessing
dataset = RedisDataset(length=100, transform=transforms.ToTensor())
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=2,
    shuffle=True
)

for data, target in loader:
    print(data.shape)
    print(target.shape)

--------------------------------------------------------------------------------
/shared_array.py:
--------------------------------------------------------------------------------
"""
Script to demonstrate the usage of shared arrays using multiple workers.

In the first epoch the shared arrays in the dataset will be filled with
random values. After setting set_use_cache(True), the shared values will be
loaded from multiple processes.

@author: ptrblck
"""

import torch
from torch.utils.data import Dataset, DataLoader

import ctypes
import multiprocessing as mp

import numpy as np


class MyDataset(Dataset):
    def __init__(self):
        shared_array_base = mp.Array(ctypes.c_float, nb_samples*c*h*w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        shared_array = shared_array.reshape(nb_samples, c, h, w)
        self.shared_array = torch.from_numpy(shared_array)
        self.use_cache = False

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def __getitem__(self, index):
        if not self.use_cache:
            print('Filling cache for index {}'.format(index))
            # Add your loading logic here
            self.shared_array[index] = torch.randn(c, h, w)
        x = self.shared_array[index]
        return x

    def __len__(self):
        return nb_samples


nb_samples, c, h, w = 10, 3, 24, 24

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=2,
    shuffle=False
)

for epoch in range(2):
    for idx, data in enumerate(loader):
        print('Epoch {}, idx {}, data.shape {}'.format(epoch, idx, data.shape))

    if epoch == 0:
        loader.dataset.set_use_cache(True)

--------------------------------------------------------------------------------
/shared_dict.py:
--------------------------------------------------------------------------------
"""
Script to demonstrate the usage of shared dicts using multiple workers.

In the first epoch the shared dict in the dataset will be filled with
random values. The next epochs will just use the dict without "loading" the
data again.

@author: ptrblck
"""

from multiprocessing import Manager

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, shared_dict, length):
        self.shared_dict = shared_dict
        self.length = length

    def __getitem__(self, index):
        if index not in self.shared_dict:
            print('Adding {} to shared_dict'.format(index))
            self.shared_dict[index] = torch.tensor(index)
        return self.shared_dict[index]

    def __len__(self):
        return self.length


# Init
manager = Manager()
shared_dict = manager.dict()
dataset = MyDataset(shared_dict, length=100)

loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=6,
    shuffle=True,
    pin_memory=True
)

# First loop will add data to the shared_dict
for x in loader:
    print(x)

# The second loop will just get the data
for x in loader:
    print(x)

--------------------------------------------------------------------------------
/unet_demo.py:
--------------------------------------------------------------------------------
"""
Simple UNet demo

@author: ptrblck
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class BaseConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding,
                 stride):
        super(BaseConv, self).__init__()

        self.act = nn.ReLU()

        # nn.Conv2d expects stride before padding positionally,
        # so both are passed by keyword here
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size,
                               stride=stride, padding=padding)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,
                               stride=stride, padding=padding)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        return x


class DownConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding,
                 stride):
        super(DownConv, self).__init__()

        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv_block = BaseConv(in_channels, out_channels, kernel_size,
                                   padding, stride)

    def forward(self, x):
        x = self.pool1(x)
        x = self.conv_block(x)
        return x


class UpConv(nn.Module):
    def __init__(self, in_channels, in_channels_skip, out_channels,
                 kernel_size, padding, stride):
        super(UpConv, self).__init__()

        self.conv_trans1 = nn.ConvTranspose2d(
            in_channels, in_channels, kernel_size=2, padding=0, stride=2)
        self.conv_block = BaseConv(
            in_channels=in_channels + in_channels_skip,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride)

    def forward(self, x, x_skip):
        x = self.conv_trans1(x)
        x = torch.cat((x, x_skip), dim=1)
        x = self.conv_block(x)
        return x


class UNet(nn.Module):
    def __init__(self, in_channels, out_channels, n_class, kernel_size,
                 padding, stride):
        super(UNet, self).__init__()

        self.init_conv = BaseConv(in_channels, out_channels, kernel_size,
                                  padding, stride)

        self.down1 = DownConv(out_channels, 2 * out_channels, kernel_size,
                              padding, stride)

        self.down2 = DownConv(2 * out_channels, 4 * out_channels, kernel_size,
                              padding, stride)

        self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size,
                              padding, stride)

        self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels,
                          kernel_size, padding, stride)

        self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels,
                          kernel_size, padding, stride)

        self.up1 = UpConv(2 * out_channels, out_channels, out_channels,
                          kernel_size, padding, stride)

        # nn.Conv2d expects stride before padding positionally,
        # so both are passed by keyword here
        self.out = nn.Conv2d(out_channels, n_class, kernel_size,
                             stride=stride, padding=padding)

    def forward(self, x):
        # Encoder
        x = self.init_conv(x)
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        # Decoder
        x_up = self.up3(x3, x2)
        x_up = self.up2(x_up, x1)
        x_up = self.up1(x_up, x)
        x_out = F.log_softmax(self.out(x_up), 1)
        return x_out


# Create 10-class segmentation dummy image and target
nb_classes = 10
x = torch.randn(1, 3, 96, 96)
y = torch.randint(0, nb_classes, (1, 96, 96))

model = UNet(in_channels=3,
             out_channels=64,
             n_class=10,
             kernel_size=3,
             padding=1,
             stride=1)

if torch.cuda.is_available():
    model = model.to('cuda')
    x = x.to('cuda')
    y = y.to('cuda')

criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(1):
    optimizer.zero_grad()

    output = model(x)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

    print('Epoch {}, Loss {}'.format(epoch, loss.item()))

--------------------------------------------------------------------------------
/weighted_sampling.py:
--------------------------------------------------------------------------------
"""
Usage of WeightedRandomSampler using an imbalanced dataset with
class imbalance 99 to 1.

@author: ptrblck
"""

import torch
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data.dataloader import DataLoader


# Create dummy data with class imbalance 99 to 1
numDataPoints = 1000
data_dim = 5
bs = 100
data = torch.randn(numDataPoints, data_dim)
target = torch.cat((torch.zeros(int(numDataPoints * 0.99), dtype=torch.long),
                    torch.ones(int(numDataPoints * 0.01), dtype=torch.long)))

print('target train 0/1: {}/{}'.format(
    (target == 0).sum(), (target == 1).sum()))

# Compute samples weight (each sample should get its own weight)
class_sample_count = torch.tensor(
    [(target == t).sum() for t in torch.unique(target, sorted=True)])
weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in target])

# Create sampler, dataset, loader
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
train_dataset = torch.utils.data.TensorDataset(data, target)
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

# Iterate DataLoader and check class balance for each batch
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1: {}/{}".format(
        i, (y == 0).sum(), (y == 1).sum()))

--------------------------------------------------------------------------------
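The inverse-class-frequency weighting used in `weighted_sampling.py` can be checked without PyTorch. The sketch below is not part of the repository: `inverse_frequency_weights` and `draw_balanced_indices` are hypothetical helper names, and the standard library's `random.choices` stands in for `WeightedRandomSampler` on the assumption that sampling is done with replacement. Each sample gets the weight `1 / (size of its class)`, so the total weight of every class is 1 and both classes should be drawn about equally often.

```python
import random
from collections import Counter


def inverse_frequency_weights(targets):
    """Return one weight per sample: 1 / (number of samples in its class)."""
    class_counts = Counter(targets)
    return [1.0 / class_counts[t] for t in targets]


def draw_balanced_indices(targets, num_samples, seed=0):
    """Draw sample indices with replacement, proportionally to the weights.

    Hypothetical stand-in for WeightedRandomSampler, for illustration only.
    """
    rng = random.Random(seed)
    weights = inverse_frequency_weights(targets)
    return rng.choices(range(len(targets)), weights=weights, k=num_samples)


# Same 99-to-1 class imbalance as in the script above
targets = [0] * 990 + [1] * 10
weights = inverse_frequency_weights(targets)
indices = draw_balanced_indices(targets, num_samples=len(targets))

# Count how often each class was drawn
drawn = Counter(targets[i] for i in indices)
print('class 0 drawn: {}, class 1 drawn: {}'.format(drawn[0], drawn[1]))
```

Because only 10 distinct samples belong to class 1, a roughly balanced draw repeats each of them many times. That is the expected trade-off of oversampling a minority class with replacement.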