├── LICENSE ├── LICENSE.md ├── train_gans.py ├── models ├── Vanilla_GAN.py ├── DCGAN.py └── GAN3D.py ├── README.md └── notebooks ├── DCGAN.ipynb └── 3D_GAN_pytorch.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Adaloglou Nikolaos 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 3D-GAN-pytorch Nikolas Adaloglou (black0017) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# --------------------------------------------------------------------------------
# /train_gans.py
# --------------------------------------------------------------------------------
import torch
from torch.autograd.variable import Variable


def ones_target(size, device=None, dtype=None):
    """Return a ``(size, 1)`` tensor of ones: the "real" label for ``size`` samples.

    Args:
        size: number of samples in the batch.
        device: device to allocate the labels on (``None`` = PyTorch's default device).
        dtype: dtype of the labels (``None`` = PyTorch's default floating dtype).
    """
    # Variable is a legacy wrapper; on PyTorch >= 0.4 it simply returns the tensor.
    return Variable(torch.ones(size, 1, device=device, dtype=dtype))


def zeros_target(size, device=None, dtype=None):
    """Return a ``(size, 1)`` tensor of zeros: the "fake" label for ``size`` samples.

    Args:
        size: number of samples in the batch.
        device: device to allocate the labels on (``None`` = PyTorch's default device).
        dtype: dtype of the labels (``None`` = PyTorch's default floating dtype).
    """
    return Variable(torch.zeros(size, 1, device=device, dtype=dtype))


def train_discriminator(discriminator, optimizer, real_data, fake_data, loss):
    """Run one optimisation step of the discriminator.

    Real samples are labelled 1 and fake samples 0. Both losses are backpropagated
    (gradients accumulate) before a single ``optimizer.step()``.

    Args:
        discriminator: module mapping a batch to ``(N, 1)`` scores.
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real samples.
        fake_data: batch of generated samples. If it is still attached to the
            generator's graph, ``error_fake.backward()`` also writes gradients into
            the generator's parameters; callers usually pass ``fake.detach()``.
        loss: criterion called as ``loss(prediction, target)``, e.g. ``nn.BCELoss()``.

    Returns:
        ``(error_real + error_fake, prediction_real, prediction_fake)``.
    """
    # Reset gradients left over from the previous step.
    optimizer.zero_grad()

    # 1.1 Train on real data (label 1).
    prediction_real = discriminator(real_data)
    # Tensor.cuda() returns a copy rather than moving in place, so build the labels
    # directly on the prediction's device and dtype (works on CPU and GPU alike).
    target_real = ones_target(
        prediction_real.size(0), device=prediction_real.device, dtype=prediction_real.dtype
    )
    error_real = loss(prediction_real, target_real)
    error_real.backward()

    # 1.2 Train on fake data (label 0). Size the labels from this batch, which may
    # differ from the real batch (e.g. a short final batch from a data loader).
    prediction_fake = discriminator(fake_data)
    target_fake = zeros_target(
        prediction_fake.size(0), device=prediction_fake.device, dtype=prediction_fake.dtype
    )
    error_fake = loss(prediction_fake, target_fake)
    error_fake.backward()

    # 1.3 Update weights with the accumulated gradients.
    optimizer.step()

    # Return error and predictions for real and fake inputs.
    return error_real + error_fake, prediction_real, prediction_fake


def train_generator(discriminator, optimizer, fake_data, loss):
    """Run one optimisation step of the generator.

    The discriminator scores ``fake_data`` against "real" (ones) labels, i.e. the
    generator is trained to make its samples look real (the usual non-saturating
    objective when ``loss`` is BCE).

    Args:
        discriminator: module mapping a batch to ``(N, 1)`` scores.
        optimizer: optimizer over the generator's parameters.
        fake_data: generated batch; it must still be attached to the generator's
            graph so that gradients reach the generator.
        loss: criterion called as ``loss(prediction, target)``.

    Returns:
        The generator loss tensor.
    """
    optimizer.zero_grad()
    prediction = discriminator(fake_data)
    # Labels live on the prediction's device/dtype (see train_discriminator).
    target = ones_target(prediction.size(0), device=prediction.device, dtype=prediction.dtype)
    error = loss(prediction, target)
    error.backward()
    optimizer.step()
    return error


# --------------------------------------------------------------------------------
# /models/Vanilla_GAN.py
# --------------------------------------------------------------------------------
import torch
from torch import nn


class DiscriminatorNet(torch.nn.Module):
    """
    A three hidden-layer discriminative neural network.

    Flat inputs of size ``n_features`` pass through Linear layers of width
    1024 -> 512 -> 256 (each followed by LeakyReLU(0.2) and Dropout(0.3)) and a
    Linear + Sigmoid head producing ``n_out`` scores per sample.

    Args:
        n_features: size of each flattened input sample (default 784 = 28 * 28).
        n_out: number of output scores per sample.
    """

    def __init__(self, n_features=784, n_out=1):
        super(DiscriminatorNet, self).__init__()

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        self.out = nn.Sequential(
            nn.Linear(256, n_out),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x


class GeneratorNet(torch.nn.Module):
    """
    A three hidden-layer generative neural network.

    Noise vectors of size ``n_features`` pass through Linear layers of width
    256 -> 512 -> 1024 (each followed by LeakyReLU(0.2)) and a Linear + Tanh head,
    so every one of the ``n_out`` outputs lies in [-1, 1].

    Args:
        n_features: size of the input noise vector.
        n_out: size of each generated (flattened) sample (default 784 = 28 * 28).
    """

    def __init__(self, n_features=100, n_out=784):
        super(GeneratorNet, self).__init__()

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
        )
        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x
# --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # 3D GAN Pytorch (Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling) 2 | **Responsible** implementation of 3D-GAN NIPS 2016 paper that can be found [here](https://papers.nips.cc/paper/6096-learning-a-probabilistic-latent-space-of-object-shapes-via-3d-generative-adversarial-modeling.pdf "Link to paper") 3 | 4 | We did our best to follow the guidelines of the original papers. However, it is always good to try to reproduce the publication results from the original work. We also included our **DCGAN** implementation since **3D-GAN** is the natural extension of DCGAN in 3D space. For completeness, a Vanilla GAN is also included. The DCGAN and 3D-GAN models are also available as Google Colab notebooks (see the notebooks folder). You can train all models with the same training script, train_gans.py. 5 | 6 | Data loaders to be updated soon. 7 | 8 | ## Google Colab instructions and usage 9 | 1. Go to https://colab.research.google.com 10 | 2. **```File```** > **```Upload notebook...```** > **```GitHub```** > **```Paste this link:``` https://github.com/black0017/3D-GAN-pytorch/blob/master/notebooks/3D_GAN_pytorch.ipynb** 11 | 3. Ensure that **```Runtime```** > **```Change runtime type```** is ```Python 3``` with ```GPU``` 12 | 4. Run the code-blocks and enjoy :) 13 | 14 | 15 | 16 | ## Detailed Info 17 | 18 | #### 3D-GAN Generator/Discriminator summary for a batch size of 1 19 | Trainable params: 17,601,408/11,048,833 20 | 21 | Forward/backward pass size (MB): 67.25/63.75 22 | 23 | Params size (MB): 67.14/42.15 24 | 25 | Estimated Total Size (MB): 134.39/106.90 26 | 27 | ## References 28 | 29 | [1] Wu, J., Zhang, C., Xue, T., Freeman, B., & Tenenbaum, J. (2016). Learning a probabilistic latent space of object shapes via 3d generative-adversarial modeling. In Advances in neural information processing systems (pp. 82-90). 30 | 31 | [2] Radford, A., Metz, L., & Chintala, S. (2015). 
# (/README.md, continued)
# Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.
#
# [3] Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., ... & Bengio, Y. (2014). Generative adversarial nets. In Advances in neural information processing systems (pp. 2672-2680).
#
#
# ## Support
# If you **really** like this repository and find it useful, please consider (★) **starring** it, so that it can reach a broader audience of like-minded people.

# --------------------------------------------------------------------------------
# /models/DCGAN.py
# --------------------------------------------------------------------------------
# DCGAN pytorch implementation based on https://arxiv.org/abs/1511.06434
import torch
import torch.nn as nn

try:
    from torchsummary import summary
except ImportError:  # torchsummary is optional: only the smoke test below uses it.
    summary = None

__all__ = ["Discriminator", "Generator"]


class Discriminator(torch.nn.Module):
    """DCGAN discriminator for square images.

    Four Conv2d(kernel 4, stride 2, padding 1) + BatchNorm2d + LeakyReLU(0.2) blocks,
    each halving height and width (rounding down), followed by a Linear + Sigmoid head.

    Args:
        in_channels: channels of the input images.
        out_conv_channels: channels of the last conv block; the first three blocks
            use 1/8, 1/4 and 1/2 of this value.
        dim: height and width of the input images; the head expects a
            ``dim // 16 x dim // 16`` feature map, so ``dim`` must be at least 16.

    ``forward(x)`` takes ``x`` of shape ``(N, in_channels, dim, dim)`` and returns
    ``(N, 1)`` scores in [0, 1].
    """

    def __init__(self, in_channels=3, out_conv_channels=1024, dim=64):
        super(Discriminator, self).__init__()
        conv1_channels = int(out_conv_channels / 8)
        conv2_channels = int(out_conv_channels / 4)
        conv3_channels = int(out_conv_channels / 2)
        self.out_conv_channels = out_conv_channels
        self.out_dim = int(dim / 16)

        # NOTE(review): the DCGAN paper leaves batchnorm off the discriminator's input
        # layer; conv1 keeps it here to preserve the existing architecture.
        self.conv1 = self._down_block(in_channels, conv1_channels)
        self.conv2 = self._down_block(conv1_channels, conv2_channels)
        self.conv3 = self._down_block(conv2_channels, conv3_channels)
        self.conv4 = self._down_block(conv3_channels, out_conv_channels)
        self.out = nn.Sequential(
            nn.Linear(out_conv_channels * self.out_dim * self.out_dim, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _down_block(in_ch, out_ch):
        """Conv2d(k=4, s=2, p=1) + BatchNorm2d + LeakyReLU(0.2): halves height and width."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_ch, out_channels=out_ch, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten per sample; unlike view(-1, ...) this cannot silently fold a
        # wrongly-sized feature map into the batch dimension.
        x = x.flatten(start_dim=1)
        return self.out(x)


class Generator(torch.nn.Module):
    """DCGAN generator for square images.

    A Linear layer projects the noise to an ``in_channels x (out_dim // 16) x (out_dim // 16)``
    map. Three ConvTranspose2d(k=4, s=2, p=1) + BatchNorm2d + ReLU blocks (each halving
    the channels) and a final ConvTranspose2d each double height and width, and Tanh
    maps the result to [-1, 1]. The output side is ``16 * (out_dim // 16)``, i.e.
    ``out_dim`` when it is a multiple of 16.

    Args:
        in_channels: channels of the projected noise map.
        out_dim: requested height/width of the generated images.
        out_channels: channels of the generated images.
        noise_dim: size of the input noise vector.

    ``forward(z)`` accepts any tensor whose last dimension is ``noise_dim`` (e.g.
    ``(N, noise_dim)``) and returns ``(N, out_channels, H, W)``.
    """

    def __init__(self, in_channels=1024, out_dim=64, out_channels=3, noise_dim=200):
        super(Generator, self).__init__()
        self.in_channels = in_channels
        self.out_dim = out_dim
        self.in_dim = int(out_dim / 16)
        conv1_out_channels = int(self.in_channels / 2.0)
        conv2_out_channels = int(conv1_out_channels / 2)
        conv3_out_channels = int(conv2_out_channels / 2)

        self.linear = torch.nn.Linear(noise_dim, in_channels * self.in_dim * self.in_dim)

        self.conv1 = self._up_block(self.in_channels, conv1_out_channels)
        self.conv2 = self._up_block(conv1_out_channels, conv2_out_channels)
        self.conv3 = self._up_block(conv2_out_channels, conv3_out_channels)
        # Output layer: no batchnorm/ReLU; the Tanh below produces the final values.
        self.conv4 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=conv3_out_channels, out_channels=out_channels, kernel_size=4,
                stride=2, padding=1, bias=False
            )
        )
        self.out = torch.nn.Tanh()

    @staticmethod
    def _up_block(in_ch, out_ch):
        """ConvTranspose2d(k=4, s=2, p=1) + BatchNorm2d + ReLU: doubles height and width."""
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_ch, out_channels=out_ch, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.linear(x)
        # view(-1, ...) folds any leading dims (e.g. torchsummary's (N, 1, noise_dim)) into N.
        x = x.view(-1, self.in_channels, self.in_dim, self.in_dim)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return self.out(x)


def test_dcgan():
    """Smoke-test DCGAN: generate one image, score it, and print layer summaries."""
    noise_dim = 100
    in_conv_channels = 512
    dim = 64  # image height and width
    model_generator = Generator(in_channels=in_conv_channels, out_dim=dim, out_channels=3, noise_dim=noise_dim)
    noise = torch.rand(1, noise_dim)
    generated_volume = model_generator(noise)
    print("Generator output shape", generated_volume.shape)
    model_discriminator = Discriminator(in_channels=3, dim=dim, out_conv_channels=in_conv_channels)
    out = model_discriminator(generated_volume)
    print("Discriminator output", out)
    if summary is None:
        print("torchsummary is not installed; skipping layer summaries")
        return
    # torchsummary's summary() defaults to device="cuda"; both models live on the CPU.
    print("Generator summary")
    summary(model_generator, (1, noise_dim), device="cpu")
    print("Discriminator summary")
    summary(model_discriminator, (3, dim, dim), device="cpu")


# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    test_dcgan()

# --------------------------------------------------------------------------------
# /models/GAN3D.py
# --------------------------------------------------------------------------------
# Implementation based on original paper NeurIPS 2016
# https://papers.nips.cc/paper/6096-learning-a-probabilistic-latent-space-of-object-shapes-via-3d-generative-adversarial-modeling.pdf
import torch
import torch.nn as nn

try:
    from torchsummary import summary
except ImportError:  # torchsummary is optional: only the smoke test below uses it.
    summary = None

__all__ = ["Discriminator", "Generator"]


class Discriminator(torch.nn.Module):
    """3D-GAN discriminator for cubic volumes.

    Four Conv3d(kernel 4, stride 2, padding 1) + BatchNorm3d + LeakyReLU(0.2) blocks,
    each halving depth, height and width (rounding down), followed by a
    Linear + Sigmoid head.

    Args:
        in_channels: channels of the input volumes.
        dim: side length of the input volumes; the head expects a ``(dim // 16)^3``
            feature map, so ``dim`` must be at least 16.
        out_conv_channels: channels of the last conv block; the first three blocks
            use 1/8, 1/4 and 1/2 of this value.

    ``forward(x)`` takes ``x`` of shape ``(N, in_channels, dim, dim, dim)`` and returns
    ``(N, 1)`` scores in [0, 1].
    """

    def __init__(self, in_channels=1, dim=64, out_conv_channels=512):
        super(Discriminator, self).__init__()
        conv1_channels = int(out_conv_channels / 8)
        conv2_channels = int(out_conv_channels / 4)
        conv3_channels = int(out_conv_channels / 2)
        self.out_conv_channels = out_conv_channels
        self.out_dim = int(dim / 16)

        self.conv1 = self._down_block(in_channels, conv1_channels)
        self.conv2 = self._down_block(conv1_channels, conv2_channels)
        self.conv3 = self._down_block(conv2_channels, conv3_channels)
        self.conv4 = self._down_block(conv3_channels, out_conv_channels)
        self.out = nn.Sequential(
            nn.Linear(out_conv_channels * self.out_dim * self.out_dim * self.out_dim, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _down_block(in_ch, out_ch):
        """Conv3d(k=4, s=2, p=1) + BatchNorm3d + LeakyReLU(0.2): halves every spatial side."""
        return nn.Sequential(
            nn.Conv3d(
                in_channels=in_ch, out_channels=out_ch, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten per sample and apply linear + sigmoid.
        x = x.flatten(start_dim=1)
        return self.out(x)


class Generator(torch.nn.Module):
    """3D-GAN generator for cubic volumes.

    A Linear layer projects the noise to an ``in_channels x (out_dim // 16)^3`` volume.
    Three ConvTranspose3d(k=4, s=2, p=1) + BatchNorm3d + ReLU blocks (each halving the
    channels) and a final ConvTranspose3d each double every spatial side. The output
    side is ``16 * (out_dim // 16)``.

    Args:
        in_channels: channels of the projected noise volume.
        out_dim: requested side length of the generated volumes.
        out_channels: channels of the generated volumes.
        noise_dim: size of the input noise vector.
        activation: ``"sigmoid"`` for a Sigmoid output in [0, 1]; any other value
            selects Tanh (output in [-1, 1]).
    """

    def __init__(self, in_channels=512, out_dim=64, out_channels=1, noise_dim=200, activation="sigmoid"):
        super(Generator, self).__init__()
        self.in_channels = in_channels
        self.out_dim = out_dim
        self.in_dim = int(out_dim / 16)
        conv1_out_channels = int(self.in_channels / 2.0)
        conv2_out_channels = int(conv1_out_channels / 2)
        conv3_out_channels = int(conv2_out_channels / 2)

        self.linear = torch.nn.Linear(noise_dim, in_channels * self.in_dim * self.in_dim * self.in_dim)

        self.conv1 = self._up_block(in_channels, conv1_out_channels)
        self.conv2 = self._up_block(conv1_out_channels, conv2_out_channels)
        self.conv3 = self._up_block(conv2_out_channels, conv3_out_channels)
        # Output layer: no batchnorm/ReLU; self.out produces the final values.
        self.conv4 = nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=conv3_out_channels, out_channels=out_channels, kernel_size=(4, 4, 4),
                stride=2, padding=1, bias=False
            )
        )
        if activation == "sigmoid":
            self.out = torch.nn.Sigmoid()
        else:
            self.out = torch.nn.Tanh()

    @staticmethod
    def _up_block(in_ch, out_ch):
        """ConvTranspose3d(k=4, s=2, p=1) + BatchNorm3d + ReLU: doubles every spatial side."""
        return nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=in_ch, out_channels=out_ch, kernel_size=(4, 4, 4),
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
        )

    def project(self, x):
        """
        projects and reshapes latent vector to starting volume
        :param x: output of self.linear (any leading dims are folded into the batch)
        :return: starting volume of shape (N, in_channels, in_dim, in_dim, in_dim)
        """
        return x.view(-1, self.in_channels, self.in_dim, self.in_dim, self.in_dim)

    def forward(self, x):
        x = self.linear(x)
        x = self.project(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return self.out(x)


def test_gan3d():
    """Smoke-test the 3D-GAN: generate one volume, score it, and print layer summaries."""
    noise_dim = 200
    in_channels = 512
    dim = 64  # cube volume side length
    model_generator = Generator(in_channels=in_channels, out_dim=dim, out_channels=1, noise_dim=noise_dim)
    noise = torch.rand(1, noise_dim)
    generated_volume = model_generator(noise)
    print("Generator output shape", generated_volume.shape)
    model_discriminator = Discriminator(in_channels=1, dim=dim, out_conv_channels=in_channels)
    out = model_discriminator(generated_volume)
    print("Discriminator output", out)
    if summary is None:
        print("torchsummary is not installed; skipping layer summaries")
        return
    # torchsummary's summary() defaults to device="cuda"; both models live on the CPU.
    summary(model_generator, (1, noise_dim), device="cpu")
    summary(model_discriminator, (1, dim, dim, dim), device="cpu")


# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    test_gan3d()

# --------------------------------------------------------------------------------
# /notebooks/DCGAN.ipynb (opening lines; the notebook continues below)
# --------------------------------------------------------------------------------
# 1 | {
# 2 | "nbformat": 4,
# 3 | "nbformat_minor": 0,
# 4 | "metadata": {
# 5 | "colab": {
# 6 | "name": "DCGAN.ipynb",
# 7 | "provenance": []
# 8 | },
# 9 | "kernelspec": {
# 10 | "name": "python3",
# 11 | "display_name": "Python 3"
# 12 | }
# 13 | },
# 14 | "cells": [
# 15 | {
# 16 | "cell_type": "markdown",
# 17 | "metadata": {
# 18 | "id": "iF22iJa0FgmD",
# 19 | "colab_type": "text"
# 20 | },
# 21 | "source": [
# 22 | ""
# 23 | ]
# 24 | },
# 25 | {
# 26 | "cell_type": "markdown",
# 27 | "metadata": {
# 28 | "id": "JX3QqxIxFh25",
# 29 | "colab_type": "text"
# 30 | },
# 31 | "source": [
# 32 | "# DCGAN implementation in pytorch"
# 33 | ]
# 34 | },
# 35 | {
# 36 | "cell_type": "code",
# 37 | "metadata": {
# 38 | "id": "am3Jnb-GF46B",
# 39 | "colab_type": "code",
# 40 | "colab": {
# 41 | "base_uri": "https://localhost:8080/",
# 42 | "height": 52
# 43 | },
# 44 | "outputId": "8563f2e1-f119-45b9-e479-053d43dff46e"
# 45 | },
# 46 | "source": [
# 47 | "!pip install torch\n",
# 48 | "!pip install torchsummary"
# 49 | ],
# 50 | "execution_count": 2,
# 51 | "outputs": [
# 52 | {
# 53 | "output_type": "stream",
# 54 | "text": [
# 55 | "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.4.0)\n",
# 56 | "Requirement already satisfied: torchsummary in /usr/local/lib/python3.6/dist-packages (1.5.1)\n"
# 57 | ],
# 58 | "name": "stdout"
# 59 | }
# 60 | ]
# 61 | },
# 62 | {
# 63 | "cell_type": "code",
# 64 | "metadata": {
# 65 | "id": "qAVbbhcdFYdP",
# 66 |
"colab_type": "code", 67 | "colab": { 68 | "base_uri": "https://localhost:8080/", 69 | "height": 34 70 | }, 71 | "outputId": "ed992831-d8ca-493d-8162-fbe33f8a6172" 72 | }, 73 | "source": [ 74 | "import torch\n", 75 | "import torch.nn as nn\n", 76 | "from torchsummary import summary\n", 77 | "\n", 78 | "\"\"\"\n", 79 | "DCGAN pytorch implementation based on https://arxiv.org/abs/1511.06434\n", 80 | "\"\"\"\n" 81 | ], 82 | "execution_count": 3, 83 | "outputs": [ 84 | { 85 | "output_type": "execute_result", 86 | "data": { 87 | "text/plain": [ 88 | "'\\nDCGAN pytorch implementation based on https://arxiv.org/abs/1511.06434\\n'" 89 | ] 90 | }, 91 | "metadata": { 92 | "tags": [] 93 | }, 94 | "execution_count": 3 95 | } 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": { 101 | "id": "a24T9e6oFrUE", 102 | "colab_type": "text" 103 | }, 104 | "source": [ 105 | "## Discriminator" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "metadata": { 111 | "id": "DMsIKh59Fpgi", 112 | "colab_type": "code", 113 | "colab": {} 114 | }, 115 | "source": [ 116 | "class Discriminator(torch.nn.Module):\n", 117 | "\n", 118 | " def __init__(self, in_channels=3, out_conv_channels=1024, dim=64):\n", 119 | " super(Discriminator, self).__init__()\n", 120 | " conv1_channels = int(out_conv_channels / 8)\n", 121 | " conv2_channels = int(out_conv_channels / 4)\n", 122 | " conv3_channels = int(out_conv_channels / 2)\n", 123 | " self.out_conv_channels = out_conv_channels\n", 124 | " self.out_dim = int(dim / 16)\n", 125 | "\n", 126 | " self.conv1 = nn.Sequential(\n", 127 | " nn.Conv2d(\n", 128 | " in_channels=in_channels, out_channels=conv1_channels, kernel_size=4,\n", 129 | " stride=2, padding=1, bias=False\n", 130 | " ),\n", 131 | " nn.BatchNorm2d(conv1_channels),\n", 132 | " nn.LeakyReLU(0.2, inplace=True)\n", 133 | " )\n", 134 | " self.conv2 = nn.Sequential(\n", 135 | " nn.Conv2d(\n", 136 | " in_channels=conv1_channels, out_channels=conv2_channels, kernel_size=4,\n", 137 | 
" stride=2, padding=1, bias=False\n", 138 | " ),\n", 139 | " nn.BatchNorm2d(conv2_channels),\n", 140 | " nn.LeakyReLU(0.2, inplace=True)\n", 141 | " )\n", 142 | " self.conv3 = nn.Sequential(\n", 143 | " nn.Conv2d(\n", 144 | " in_channels=conv2_channels, out_channels=conv3_channels, kernel_size=4,\n", 145 | " stride=2, padding=1, bias=False\n", 146 | " ),\n", 147 | " nn.BatchNorm2d(conv3_channels),\n", 148 | " nn.LeakyReLU(0.2, inplace=True)\n", 149 | " )\n", 150 | " self.conv4 = nn.Sequential(\n", 151 | " nn.Conv2d(\n", 152 | " in_channels=conv3_channels, out_channels=out_conv_channels, kernel_size=4,\n", 153 | " stride=2, padding=1, bias=False\n", 154 | " ),\n", 155 | " nn.BatchNorm2d(out_conv_channels),\n", 156 | " nn.LeakyReLU(0.2, inplace=True)\n", 157 | " )\n", 158 | " self.out = nn.Sequential(\n", 159 | " nn.Linear(out_conv_channels * self.out_dim * self.out_dim, 1),\n", 160 | " nn.Sigmoid(),\n", 161 | " )\n", 162 | "\n", 163 | " def forward(self, x):\n", 164 | " x = self.conv1(x)\n", 165 | " x = self.conv2(x)\n", 166 | " x = self.conv3(x)\n", 167 | " x = self.conv4(x)\n", 168 | " x = x.view(-1, self.out_conv_channels * self.out_dim * self.out_dim)\n", 169 | " x = self.out(x)\n", 170 | " return x\n" 171 | ], 172 | "execution_count": 0, 173 | "outputs": [] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": { 178 | "id": "oby_xHhqFmfw", 179 | "colab_type": "text" 180 | }, 181 | "source": [ 182 | "## Generator" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "metadata": { 188 | "id": "tIw8VMAUFvJ1", 189 | "colab_type": "code", 190 | "colab": {} 191 | }, 192 | "source": [ 193 | "\n", 194 | "class Generator(torch.nn.Module):\n", 195 | "\n", 196 | " def __init__(self, in_channels=1024, out_dim=64, out_channels=3, noise_dim=200):\n", 197 | " super(Generator, self).__init__()\n", 198 | " self.in_channels = in_channels\n", 199 | " self.out_dim = out_dim\n", 200 | " self.in_dim = int(out_dim / 16)\n", 201 | " conv1_out_channels = 
int(self.in_channels / 2.0)\n", 202 | " conv2_out_channels = int(conv1_out_channels / 2)\n", 203 | " conv3_out_channels = int(conv2_out_channels / 2)\n", 204 | "\n", 205 | " self.linear = torch.nn.Linear(noise_dim, in_channels * self.in_dim * self.in_dim)\n", 206 | "\n", 207 | " self.conv1 = nn.Sequential(\n", 208 | " nn.ConvTranspose2d(\n", 209 | " in_channels=self.in_channels, out_channels=conv1_out_channels, kernel_size=4,\n", 210 | " stride=2, padding=1, bias=False\n", 211 | " ),\n", 212 | " nn.BatchNorm2d(conv1_out_channels),\n", 213 | " nn.ReLU(inplace=True)\n", 214 | " )\n", 215 | " self.conv2 = nn.Sequential(\n", 216 | " nn.ConvTranspose2d(\n", 217 | " in_channels=conv1_out_channels, out_channels=conv2_out_channels, kernel_size=4,\n", 218 | " stride=2, padding=1, bias=False\n", 219 | " ),\n", 220 | " nn.BatchNorm2d(conv2_out_channels),\n", 221 | " nn.ReLU(inplace=True)\n", 222 | " )\n", 223 | " self.conv3 = nn.Sequential(\n", 224 | " nn.ConvTranspose2d(\n", 225 | " in_channels=conv2_out_channels, out_channels=conv3_out_channels, kernel_size=4,\n", 226 | " stride=2, padding=1, bias=False\n", 227 | " ),\n", 228 | " nn.BatchNorm2d(conv3_out_channels),\n", 229 | " nn.ReLU(inplace=True)\n", 230 | " )\n", 231 | " self.conv4 = nn.Sequential(\n", 232 | " nn.ConvTranspose2d(\n", 233 | " in_channels=conv3_out_channels, out_channels=out_channels, kernel_size=4,\n", 234 | " stride=2, padding=1, bias=False\n", 235 | " )\n", 236 | " )\n", 237 | " self.out = torch.nn.Tanh()\n", 238 | "\n", 239 | " def forward(self, x):\n", 240 | " x = self.linear(x)\n", 241 | " x = x.view(-1, self.in_channels, self.in_dim, self.in_dim)\n", 242 | " x = self.conv1(x)\n", 243 | " x = self.conv2(x)\n", 244 | " x = self.conv3(x)\n", 245 | " x = self.conv4(x)\n", 246 | " return self.out(x)" 247 | ], 248 | "execution_count": 0, 249 | "outputs": [] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": { 254 | "id": "UGmkMNYOF1EN", 255 | "colab_type": "text" 256 | }, 257 | "source": [ 
258 | "## Test" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "metadata": { 264 | "id": "s5Xs5zyCF2Ms", 265 | "colab_type": "code", 266 | "colab": { 267 | "base_uri": "https://localhost:8080/", 268 | "height": 990 269 | }, 270 | "outputId": "160788d3-5aa9-4c44-9850-7455dd2cdad9" 271 | }, 272 | "source": [ 273 | "def test_dcgan():\n", 274 | " noise_dim = 100\n", 275 | " in_conv_channels = 512\n", 276 | " dim = 64 # cube volume\n", 277 | " model_generator = Generator(in_channels=in_conv_channels, out_dim=dim, out_channels=3, noise_dim=noise_dim)\n", 278 | " noise = torch.rand(1, noise_dim)\n", 279 | " generated_volume = model_generator(noise)\n", 280 | " print(\"Generator output shape\", generated_volume.shape)\n", 281 | " model_discriminator = Discriminator(in_channels=3, dim=dim, out_conv_channels=in_conv_channels)\n", 282 | " out = model_discriminator(generated_volume)\n", 283 | " print(\"Discriminator output\", out.item())\n", 284 | " print(\"Generator summary\")\n", 285 | " summary(model_generator, (1, noise_dim))\n", 286 | " print(\"Discriminator summary\")\n", 287 | " summary(model_discriminator, (3,64,64))\n", 288 | "\n", 289 | "test_dcgan()" 290 | ], 291 | "execution_count": 7, 292 | "outputs": [ 293 | { 294 | "output_type": "stream", 295 | "text": [ 296 | "Generator output shape torch.Size([1, 3, 64, 64])\n", 297 | "Discriminator output 0.5283796787261963\n", 298 | "Generator summary\n", 299 | "----------------------------------------------------------------\n", 300 | " Layer (type) Output Shape Param #\n", 301 | "================================================================\n", 302 | " Linear-1 [-1, 1, 8192] 827,392\n", 303 | " ConvTranspose2d-2 [-1, 256, 8, 8] 2,097,152\n", 304 | " BatchNorm2d-3 [-1, 256, 8, 8] 512\n", 305 | " ReLU-4 [-1, 256, 8, 8] 0\n", 306 | " ConvTranspose2d-5 [-1, 128, 16, 16] 524,288\n", 307 | " BatchNorm2d-6 [-1, 128, 16, 16] 256\n", 308 | " ReLU-7 [-1, 128, 16, 16] 0\n", 309 | " ConvTranspose2d-8 [-1, 64, 32, 32] 
131,072\n", 310 | " BatchNorm2d-9 [-1, 64, 32, 32] 128\n", 311 | " ReLU-10 [-1, 64, 32, 32] 0\n", 312 | " ConvTranspose2d-11 [-1, 3, 64, 64] 3,072\n", 313 | " Tanh-12 [-1, 3, 64, 64] 0\n", 314 | "================================================================\n", 315 | "Total params: 3,583,872\n", 316 | "Trainable params: 3,583,872\n", 317 | "Non-trainable params: 0\n", 318 | "----------------------------------------------------------------\n", 319 | "Input size (MB): 0.00\n", 320 | "Forward/backward pass size (MB): 2.88\n", 321 | "Params size (MB): 13.67\n", 322 | "Estimated Total Size (MB): 16.55\n", 323 | "----------------------------------------------------------------\n", 324 | "Discriminator summary\n", 325 | "----------------------------------------------------------------\n", 326 | " Layer (type) Output Shape Param #\n", 327 | "================================================================\n", 328 | " Conv2d-1 [-1, 64, 32, 32] 3,072\n", 329 | " BatchNorm2d-2 [-1, 64, 32, 32] 128\n", 330 | " LeakyReLU-3 [-1, 64, 32, 32] 0\n", 331 | " Conv2d-4 [-1, 128, 16, 16] 131,072\n", 332 | " BatchNorm2d-5 [-1, 128, 16, 16] 256\n", 333 | " LeakyReLU-6 [-1, 128, 16, 16] 0\n", 334 | " Conv2d-7 [-1, 256, 8, 8] 524,288\n", 335 | " BatchNorm2d-8 [-1, 256, 8, 8] 512\n", 336 | " LeakyReLU-9 [-1, 256, 8, 8] 0\n", 337 | " Conv2d-10 [-1, 512, 4, 4] 2,097,152\n", 338 | " BatchNorm2d-11 [-1, 512, 4, 4] 1,024\n", 339 | " LeakyReLU-12 [-1, 512, 4, 4] 0\n", 340 | " Linear-13 [-1, 1] 8,193\n", 341 | " Sigmoid-14 [-1, 1] 0\n", 342 | "================================================================\n", 343 | "Total params: 2,765,697\n", 344 | "Trainable params: 2,765,697\n", 345 | "Non-trainable params: 0\n", 346 | "----------------------------------------------------------------\n", 347 | "Input size (MB): 0.05\n", 348 | "Forward/backward pass size (MB): 2.81\n", 349 | "Params size (MB): 10.55\n", 350 | "Estimated Total Size (MB): 13.41\n", 351 | 
"----------------------------------------------------------------\n" 352 | ], 353 | "name": "stdout" 354 | } 355 | ] 356 | } 357 | ] 358 | } -------------------------------------------------------------------------------- /notebooks/3D_GAN_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "3D-GAN-pytorch.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "code", 18 | "metadata": { 19 | "id": "RmxEe0AEtRiO", 20 | "colab_type": "code", 21 | "colab": { 22 | "base_uri": "https://localhost:8080/", 23 | "height": 52 24 | }, 25 | "outputId": "f8e8c2ad-11c6-4fc4-a5dc-33b8e5431808" 26 | }, 27 | "source": [ 28 | "!pip install torch\n", 29 | "!pip install torchsummary\n" 30 | ], 31 | "execution_count": 2, 32 | "outputs": [ 33 | { 34 | "output_type": "stream", 35 | "text": [ 36 | "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.4.0)\n", 37 | "Requirement already satisfied: torchsummary in /usr/local/lib/python3.6/dist-packages (1.5.1)\n" 38 | ], 39 | "name": "stdout" 40 | } 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "metadata": { 46 | "id": "_ue1BlbIxDsH", 47 | "colab_type": "code", 48 | "colab": { 49 | "base_uri": "https://localhost:8080/", 50 | "height": 54 51 | }, 52 | "outputId": "ac224250-82c9-4e4a-f702-a180dbc8d053" 53 | }, 54 | "source": [ 55 | "import torch\n", 56 | "import torch.nn as nn\n", 57 | "from torchsummary import summary\n", 58 | "\n", 59 | "\"\"\"\n", 60 | "Implementation based on original paper NeurIPS 2016 https://papers.nips.cc/paper/6096-learning-a-probabilistic-latent-space-of-object-shapes-via-3d-generative-adversarial-modeling.pdf\n", 61 | "\"\"\"\n" 62 | ], 63 | "execution_count": 3, 64 | "outputs": [ 65 | { 66 | 
"output_type": "execute_result", 67 | "data": { 68 | "text/plain": [ 69 | "'\\nImplementation based on original paper NeurIPS 2016 https://papers.nips.cc/paper/6096-learning-a-probabilistic-latent-space-of-object-shapes-via-3d-generative-adversarial-modeling.pdf\\n'" 70 | ] 71 | }, 72 | "metadata": { 73 | "tags": [] 74 | }, 75 | "execution_count": 3 76 | } 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": { 82 | "id": "FdWpJZhRxNVC", 83 | "colab_type": "text" 84 | }, 85 | "source": [ 86 | "## Discriminator" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "metadata": { 92 | "id": "0p3oPTztxLqj", 93 | "colab_type": "code", 94 | "colab": {} 95 | }, 96 | "source": [ 97 | "class Discriminator(torch.nn.Module):\n", 98 | " def __init__(self, in_channels=3, dim=64, out_conv_channels=512):\n", 99 | " super(Discriminator, self).__init__()\n", 100 | " conv1_channels = int(out_conv_channels / 8)\n", 101 | " conv2_channels = int(out_conv_channels / 4)\n", 102 | " conv3_channels = int(out_conv_channels / 2)\n", 103 | " self.out_conv_channels = out_conv_channels\n", 104 | " self.out_dim = int(dim / 16)\n", 105 | "\n", 106 | " self.conv1 = nn.Sequential(\n", 107 | " nn.Conv3d(\n", 108 | " in_channels=in_channels, out_channels=conv1_channels, kernel_size=4,\n", 109 | " stride=2, padding=1, bias=False\n", 110 | " ),\n", 111 | " nn.BatchNorm3d(conv1_channels),\n", 112 | " nn.LeakyReLU(0.2, inplace=True)\n", 113 | " )\n", 114 | " self.conv2 = nn.Sequential(\n", 115 | " nn.Conv3d(\n", 116 | " in_channels=conv1_channels, out_channels=conv2_channels, kernel_size=4,\n", 117 | " stride=2, padding=1, bias=False\n", 118 | " ),\n", 119 | " nn.BatchNorm3d(conv2_channels),\n", 120 | " nn.LeakyReLU(0.2, inplace=True)\n", 121 | " )\n", 122 | " self.conv3 = nn.Sequential(\n", 123 | " nn.Conv3d(\n", 124 | " in_channels=conv2_channels, out_channels=conv3_channels, kernel_size=4,\n", 125 | " stride=2, padding=1, bias=False\n", 126 | " ),\n", 127 | " 
nn.BatchNorm3d(conv3_channels),\n", 128 | " nn.LeakyReLU(0.2, inplace=True)\n", 129 | " )\n", 130 | " self.conv4 = nn.Sequential(\n", 131 | " nn.Conv3d(\n", 132 | " in_channels=conv3_channels, out_channels=out_conv_channels, kernel_size=4,\n", 133 | " stride=2, padding=1, bias=False\n", 134 | " ),\n", 135 | " nn.BatchNorm3d(out_conv_channels),\n", 136 | " nn.LeakyReLU(0.2, inplace=True)\n", 137 | " )\n", 138 | " self.out = nn.Sequential(\n", 139 | " nn.Linear(out_conv_channels * self.out_dim * self.out_dim * self.out_dim, 1),\n", 140 | " nn.Sigmoid(),\n", 141 | " )\n", 142 | "\n", 143 | " def forward(self, x):\n", 144 | " x = self.conv1(x)\n", 145 | " x = self.conv2(x)\n", 146 | " x = self.conv3(x)\n", 147 | " x = self.conv4(x)\n", 148 | " # Flatten and apply linear + sigmoid\n", 149 | " x = x.view(-1, self.out_conv_channels * self.out_dim * self.out_dim * self.out_dim)\n", 150 | " x = self.out(x)\n", 151 | " return x\n" 152 | ], 153 | "execution_count": 0, 154 | "outputs": [] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": { 159 | "id": "xDCCLjoaxW5B", 160 | "colab_type": "text" 161 | }, 162 | "source": [ 163 | "## Generator" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "metadata": { 169 | "id": "mwYhui5IxY6u", 170 | "colab_type": "code", 171 | "colab": {} 172 | }, 173 | "source": [ 174 | "class Generator(torch.nn.Module):\n", 175 | " def __init__(self, in_channels=512, out_dim=64, out_channels=1, noise_dim=200, activation=\"sigmoid\"):\n", 176 | " super(Generator, self).__init__()\n", 177 | " self.in_channels = in_channels\n", 178 | " self.out_dim = out_dim\n", 179 | " self.in_dim = int(out_dim / 16)\n", 180 | " conv1_out_channels = int(self.in_channels / 2.0)\n", 181 | " conv2_out_channels = int(conv1_out_channels / 2)\n", 182 | " conv3_out_channels = int(conv2_out_channels / 2)\n", 183 | "\n", 184 | " self.linear = torch.nn.Linear(noise_dim, in_channels * self.in_dim * self.in_dim * self.in_dim)\n", 185 | "\n", 186 | " self.conv1 
= nn.Sequential(\n", 187 | " nn.ConvTranspose3d(\n", 188 | " in_channels=in_channels, out_channels=conv1_out_channels, kernel_size=(4, 4, 4),\n", 189 | " stride=2, padding=1, bias=False\n", 190 | " ),\n", 191 | " nn.BatchNorm3d(conv1_out_channels),\n", 192 | " nn.ReLU(inplace=True)\n", 193 | " )\n", 194 | " self.conv2 = nn.Sequential(\n", 195 | " nn.ConvTranspose3d(\n", 196 | " in_channels=conv1_out_channels, out_channels=conv2_out_channels, kernel_size=(4, 4, 4),\n", 197 | " stride=2, padding=1, bias=False\n", 198 | " ),\n", 199 | " nn.BatchNorm3d(conv2_out_channels),\n", 200 | " nn.ReLU(inplace=True)\n", 201 | " )\n", 202 | " self.conv3 = nn.Sequential(\n", 203 | " nn.ConvTranspose3d(\n", 204 | " in_channels=conv2_out_channels, out_channels=conv3_out_channels, kernel_size=(4, 4, 4),\n", 205 | " stride=2, padding=1, bias=False\n", 206 | " ),\n", 207 | " nn.BatchNorm3d(conv3_out_channels),\n", 208 | " nn.ReLU(inplace=True)\n", 209 | " )\n", 210 | " self.conv4 = nn.Sequential(\n", 211 | " nn.ConvTranspose3d(\n", 212 | " in_channels=conv3_out_channels, out_channels=out_channels, kernel_size=(4, 4, 4),\n", 213 | " stride=2, padding=1, bias=False\n", 214 | " )\n", 215 | " )\n", 216 | " if activation == \"sigmoid\":\n", 217 | " self.out = torch.nn.Sigmoid()\n", 218 | " else:\n", 219 | " self.out = torch.nn.Tanh()\n", 220 | "\n", 221 | " def project(self, x):\n", 222 | " \"\"\"\n", 223 | " reshapes the linearly projected latent vector (output of self.linear) into the starting volume\n", 224 | " :param x: latent vector already projected by self.linear\n", 225 | " :return: starting volume\n", 226 | " \"\"\"\n", 227 | " return x.view(-1, self.in_channels, self.in_dim, self.in_dim, self.in_dim)\n", 228 | "\n", 229 | " def forward(self, x):\n", 230 | " x = self.linear(x)\n", 231 | " x = self.project(x)\n", 232 | " x = self.conv1(x)\n", 233 | " x = self.conv2(x)\n", 234 | " x = self.conv3(x)\n", 235 | " x = self.conv4(x)\n", 236 | " return self.out(x)\n" 237 | ], 238 | "execution_count": 0, 239 | "outputs": [] 240 | }, 241 | { 242 | "cell_type":
"markdown", 243 | "metadata": { 244 | "id": "i2GyllAQxc-R", 245 | "colab_type": "text" 246 | }, 247 | "source": [ 248 | "## Test" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "metadata": { 254 | "id": "OmJJv8VwxfCC", 255 | "colab_type": "code", 256 | "colab": { 257 | "base_uri": "https://localhost:8080/", 258 | "height": 1000 259 | }, 260 | "outputId": "22d77bc2-ea30-4298-d2f2-b2ef203494ba" 261 | }, 262 | "source": [ 263 | "def test_gan3d(print_summary=True):\n", 264 | " noise_dim = 200 # latent space vector dim\n", 265 | " in_channels = 512 # convolutional channels\n", 266 | " dim = 64 # cube volume\n", 267 | " model_generator = Generator(in_channels=512, out_dim=dim, out_channels=1, noise_dim=noise_dim)\n", 268 | " noise = torch.rand(1, noise_dim)\n", 269 | " generated_volume = model_generator(noise)\n", 270 | " print(\"Generator output shape\", generated_volume.shape)\n", 271 | " model_discriminator = Discriminator(in_channels=1, dim=dim, out_conv_channels=in_channels)\n", 272 | " out = model_discriminator(generated_volume)\n", 273 | " print(\"Discriminator output\", out.item())\n", 274 | " if print_summary:\n", 275 | " print(\"\\n\\nGenerator summary\\n\\n\")\n", 276 | " summary(model_generator, (1, noise_dim))\n", 277 | " print(\"\\n\\nDiscriminator summary\\n\\n\")\n", 278 | " summary(model_discriminator, (1,dim,dim,dim))\n", 279 | "\n", 280 | "test_gan3d()" 281 | ], 282 | "execution_count": 12, 283 | "outputs": [ 284 | { 285 | "output_type": "stream", 286 | "text": [ 287 | "Generator output shape torch.Size([1, 1, 64, 64, 64])\n", 288 | "Discriminator output 0.47117894887924194\n", 289 | "\n", 290 | "\n", 291 | "Generator summary\n", 292 | "\n", 293 | "\n", 294 | "----------------------------------------------------------------\n", 295 | " Layer (type) Output Shape Param #\n", 296 | "================================================================\n", 297 | " Linear-1 [-1, 1, 32768] 6,586,368\n", 298 | " ConvTranspose3d-2 [-1, 256, 8, 8, 8] 
8,388,608\n", 299 | " BatchNorm3d-3 [-1, 256, 8, 8, 8] 512\n", 300 | " ReLU-4 [-1, 256, 8, 8, 8] 0\n", 301 | " ConvTranspose3d-5 [-1, 128, 16, 16, 16] 2,097,152\n", 302 | " BatchNorm3d-6 [-1, 128, 16, 16, 16] 256\n", 303 | " ReLU-7 [-1, 128, 16, 16, 16] 0\n", 304 | " ConvTranspose3d-8 [-1, 64, 32, 32, 32] 524,288\n", 305 | " BatchNorm3d-9 [-1, 64, 32, 32, 32] 128\n", 306 | " ReLU-10 [-1, 64, 32, 32, 32] 0\n", 307 | " ConvTranspose3d-11 [-1, 1, 64, 64, 64] 4,096\n", 308 | " Sigmoid-12 [-1, 1, 64, 64, 64] 0\n", 309 | "================================================================\n", 310 | "Total params: 17,601,408\n", 311 | "Trainable params: 17,601,408\n", 312 | "Non-trainable params: 0\n", 313 | "----------------------------------------------------------------\n", 314 | "Input size (MB): 0.00\n", 315 | "Forward/backward pass size (MB): 67.25\n", 316 | "Params size (MB): 67.14\n", 317 | "Estimated Total Size (MB): 134.39\n", 318 | "----------------------------------------------------------------\n", 319 | "\n", 320 | "\n", 321 | "Discriminator summary\n", 322 | "\n", 323 | "\n", 324 | "----------------------------------------------------------------\n", 325 | " Layer (type) Output Shape Param #\n", 326 | "================================================================\n", 327 | " Conv3d-1 [-1, 64, 32, 32, 32] 4,096\n", 328 | " BatchNorm3d-2 [-1, 64, 32, 32, 32] 128\n", 329 | " LeakyReLU-3 [-1, 64, 32, 32, 32] 0\n", 330 | " Conv3d-4 [-1, 128, 16, 16, 16] 524,288\n", 331 | " BatchNorm3d-5 [-1, 128, 16, 16, 16] 256\n", 332 | " LeakyReLU-6 [-1, 128, 16, 16, 16] 0\n", 333 | " Conv3d-7 [-1, 256, 8, 8, 8] 2,097,152\n", 334 | " BatchNorm3d-8 [-1, 256, 8, 8, 8] 512\n", 335 | " LeakyReLU-9 [-1, 256, 8, 8, 8] 0\n", 336 | " Conv3d-10 [-1, 512, 4, 4, 4] 8,388,608\n", 337 | " BatchNorm3d-11 [-1, 512, 4, 4, 4] 1,024\n", 338 | " LeakyReLU-12 [-1, 512, 4, 4, 4] 0\n", 339 | " Linear-13 [-1, 1] 32,769\n", 340 | " Sigmoid-14 [-1, 1] 0\n", 341 | 
"================================================================\n", 342 | "Total params: 11,048,833\n", 343 | "Trainable params: 11,048,833\n", 344 | "Non-trainable params: 0\n", 345 | "----------------------------------------------------------------\n", 346 | "Input size (MB): 1.00\n", 347 | "Forward/backward pass size (MB): 63.75\n", 348 | "Params size (MB): 42.15\n", 349 | "Estimated Total Size (MB): 106.90\n", 350 | "----------------------------------------------------------------\n" 351 | ], 352 | "name": "stdout" 353 | } 354 | ] 355 | } 356 | ] 357 | } --------------------------------------------------------------------------------