├── new_metadata.xlsx
├── docs
│   ├── Gen vs Real.jpg
│   ├── GAN App demo.png
│   └── Omni BigGAN - Overview.jpg
├── requirements.txt
├── app
│   ├── requirements.txt
│   ├── gan_app.py
│   └── biggan.py
├── LICENSE
├── README.md
└── omni-loss-biggan.ipynb
/new_metadata.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safi842/Microstructure-GAN/HEAD/new_metadata.xlsx
--------------------------------------------------------------------------------
/docs/Gen vs Real.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safi842/Microstructure-GAN/HEAD/docs/Gen vs Real.jpg
--------------------------------------------------------------------------------
/docs/GAN App demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safi842/Microstructure-GAN/HEAD/docs/GAN App demo.png
--------------------------------------------------------------------------------
/docs/Omni BigGAN - Overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safi842/Microstructure-GAN/HEAD/docs/Omni BigGAN - Overview.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.21.5
2 | pandas==1.4.4
3 | Pillow==9.4.0
4 | requests==2.28.1
5 | stqdm==0.0.5
6 | streamlit==1.17.0
7 | torch==1.13.1
8 | torchvision==0.14.1
9 |
--------------------------------------------------------------------------------
/app/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.21.5
2 | pandas==1.4.4
3 | Pillow==9.4.0
4 | requests==2.28.1
5 | stqdm==0.0.5
6 | streamlit==1.17.0
7 | torch==1.13.1
8 | torchvision==0.14.1
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Mohammad Safiuddin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Microstructure-GAN — PyTorch Implementation
2 | [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://safi842-microstructure-gan-appgan-app-32c049.streamlit.app/)
3 |
4 |
5 | ![Omni BigGAN overview](docs/Omni%20BigGAN%20-%20Overview.jpg)
6 |
7 | ### Establishing process-structure linkages using Generative Adversarial Networks
8 | Mohammad Safiuddin, Ch Likith Reddy, Ganesh Vasantada, CHJNS Harsha, Dr. Srinu Gangolu
9 |
10 | Paper: http://dx.doi.org/10.1007/978-981-97-6367-2_39
11 |
12 | [comment]: <> (Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen, Timo Aila)
13 | [comment]: <> (Paper: http://dx.doi.org/10.1007/978-981-97-6367-2_39)
14 |
15 | Abstract: *The microstructure of a material strongly influences its mechanical properties
16 | and the microstructure itself is influenced by the processing conditions. Thus,
17 | establishing a Process-Structure-Property relationship is a crucial task in material design and is of interest in many engineering applications. In this work,
18 | the processing-structure relationship is modelled as a deep learning-based conditional image synthesis problem. This approach is devoid of feature engineering,
19 | needs little domain awareness, and can be applied to a wide variety of material
20 | systems. We develop a GAN (Generative Adversarial Network) to synthesize
21 | microstructures based on given processing conditions. Results show that our GAN model
22 | can produce high-fidelity multiphase microstructures which have a good correlation with the given processing conditions.*
23 |
24 | ### Results:
25 |
26 | ![Generated vs real micrographs](docs/Gen%20vs%20Real.jpg)
27 |
28 |
29 |
30 | #### File Overview
31 |
32 | The following files are included in this package:
33 |
34 | - `omni-loss-biggan.ipynb`: an IPython notebook containing the code used to train the model.
35 | - `new_metadata.xlsx`: an Excel workbook that holds the training image metadata (readable with `pandas.read_excel`, which requires the `openpyxl` engine).
36 | - `./app`: a directory that contains the source code for the app. Further instructions on running the app can be found below.
37 |
38 | ### Application
39 | [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://safi842-microstructure-gan-appgan-app-32c049.streamlit.app/) \
40 | **If you want to run the app locally, follow the instructions below**
41 |
42 |
43 | ![GAN App demo](docs/GAN%20App%20demo.png)
44 |
45 | To install the app, unzip the `Microstructure_GAN` folder. Next, navigate to the `Microstructure_GAN/app` directory in a terminal and run the following command to install the necessary packages:
46 |
47 | ```
48 | pip install -r requirements.txt
49 | ```
50 |
51 | Once the packages have been installed, run the following command to start the web app (on first launch, the app downloads the pretrained `BigGAN-deep.pth` generator weights from this repository's GitHub releases):
52 |
53 | ```
54 | streamlit run gan_app.py
55 | ```
56 |
57 | **Recreating Results:**
58 |
59 | Generated micrographs can be downloaded by clicking the "Download Micrograph" button. The file name of the saved image encodes the processing conditions and the seed value, for example `800-85H-Quench-864.png`. To recreate the image, the latent vector can be regenerated from the seed as follows; a fuller end-to-end sketch is given after this snippet.
60 | ```
61 | import numpy as np
62 |
63 | seed = 864
64 | rng = np.random.RandomState(seed)
65 | latent_vector = rng.normal(0, 1, (1, 384))
66 | ```
67 |
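68 | Below is a minimal end-to-end regeneration sketch (an illustrative example, not a file shipped with this repository). It assumes you run it from the `app` directory with the downloaded `BigGAN-deep.pth` checkpoint present, and it reuses the label-to-index mappings from `gan_app.py`:
69 |
70 | ```
71 | import numpy as np
72 | import torch
73 | from PIL import Image
74 |
75 | import biggan
76 |
77 | # Label-to-index mappings, copied from app/gan_app.py
78 | temp_dict = {970: 0, 800: 1, 900: 2, 1100: 3, 1000: 4, 700: 5, 750: 6}
79 | time_dict = {'90M': 0, '24H': 1, '3H': 2, '5M': 3, '8H': 4, '85H': 5, '1H': 6, '48H': 7}
80 | cool_dict = {'Quench': 0, 'Air Cool': 1, 'Furnace Cool': 2, '650C-1H': 3}
81 |
82 | # Processing conditions and seed parsed from the file name 800-85H-Quench-864.png
83 | seed = 864
84 | rng = np.random.RandomState(seed)
85 | noise = torch.tensor(rng.normal(0, 1, (1, 384))).float()
86 |
87 | model = biggan.Generator()
88 | model.load_state_dict(torch.load('BigGAN-deep.pth', map_location=torch.device('cpu')))
89 |
90 | # gan_app.py runs the freshly loaded model as-is (it never calls .eval()),
91 | # so we mirror that here to reproduce the app's output exactly.
92 | with torch.no_grad():
93 |     img = model(noise, torch.tensor([temp_dict[800]]),
94 |                 torch.tensor([time_dict['85H']]),
95 |                 torch.tensor([cool_dict['Quench']]))[0]
96 |
97 | img = 0.5 * img + 0.5  # rescale from [-1, 1] to [0, 1]
98 | Image.fromarray((img.numpy().reshape(256, 256) * 255).astype(np.uint8)).save('800-85H-Quench-864.png')
99 | ```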
--------------------------------------------------------------------------------
/app/gan_app.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from io import BytesIO
3 | import os
4 | import json
5 | import requests
6 | #import urllib.request
7 | from stqdm import stqdm
8 | from PIL import Image
9 | import pickle
10 | import streamlit as st
11 | import sys
12 | import urllib
13 | import torch
14 | import random
15 | import biggan
16 | from torchvision.utils import make_grid
17 |
18 | import base64
19 |
20 | class NumpyEncoder(json.JSONEncoder):
21 | def default(self, obj):
22 | if isinstance(obj, np.ndarray):
23 | return obj.tolist()
24 | return json.JSONEncoder.default(self, obj)
25 |
26 |
27 | def main():
28 | first_run = not os.path.exists('state.json')
29 | state = {}
30 | st.title("Microstructure GAN demo")
31 |     """This is a demonstration of conditional image generation of micrographs using a modified architecture based on [BigGAN-deep](https://arxiv.org/abs/1809.11096).
32 |     Images are generated from three conditional inputs: annealing temperature, annealing time, and the type of cooling used.
33 |     The GAN is trained using [Omni Loss](https://arxiv.org/abs/2011.13074) on [UHCSDB](http://uhcsdb.materials.cmu.edu/) images. Details on the methodology can be found in the [paper](https://arxiv.org/abs/2107.09402)."""
34 |
35 |     st.sidebar.title('Processing Conditions')
36 | state['anneal_temp'] = st.sidebar.selectbox('Annealing Temperature °C',[700,750,800,900,970,1000,1100])
37 | state['anneal_time'] = st.sidebar.selectbox('Annealing Time (M: Minutes, H: Hours)',['5M','90M','1H','3H','8H','24H','48H','85H'])
38 | state['cooling'] = st.sidebar.selectbox('Cooling Type',['Quench','Furnace Cool','Air Cool','650C-1H'])
39 | temp_dict = {970: 0, 800: 1, 900: 2, 1100: 3, 1000: 4, 700: 5, 750: 6}
40 | time_dict = {'90M': 0, '24H': 1, '3H': 2, '5M': 3, '8H': 4, '85H': 5, '1H': 6, '48H': 7}
41 | cool_dict = {'Quench': 0, 'Air Cool': 1, 'Furnace Cool': 2, '650C-1H': 3}
42 | model = load_gan()
43 | st.sidebar.subheader('Generate a new latent Vector')
44 | state['seed'] = 7
45 | if st.sidebar.button('New z'):
46 | state['seed'] = random.randint(0,1000)
47 | rng = np.random.RandomState(state['seed'])
48 | noise = torch.tensor(rng.normal(0, 1, (1, 384))).float()
49 | state['noise'] = noise.numpy()
50 | y_temp = temp_dict[state['anneal_temp']]
51 | y_time = time_dict[state['anneal_time']]
52 | y_cool = cool_dict[state['cooling']]
53 |
54 | state['image_out'] = generate_img(model, noise, y_temp, y_time, y_cool)
55 | st.subheader('Generated Microstructure for the given processing conditions')
56 | st.text("")
57 | st.text(f"Random seed: {state['seed']}")
58 | st.image(np.array(state['image_out']), use_column_width=False)
59 | im = Image.fromarray((np.array(state['image_out']).reshape(256,256) * 255).astype(np.uint8))
60 | buf = BytesIO()
61 | im.save(buf, format="PNG")
62 | byte_im = buf.getvalue()
63 |     st.download_button(label = "Download Micrograph", data = byte_im, file_name = f"{state['anneal_temp']}-{state['anneal_time']}-{state['cooling']}-{state['seed']}.png", mime="image/png")
64 |
65 | with open('state.json', 'w') as fp:
66 | json.dump(state, fp, cls=NumpyEncoder)
67 |
68 | def download_model(url, filename):
69 | """
70 | Helper method handling downloading large files from `url` to `filename`. Returns a pointer to `filename`.
71 | """
72 | with st.spinner(text= f"Downloading {filename} ..."):
73 | chunkSize = 1024
74 | r = requests.get(url, stream=True)
75 | with open(filename, 'wb') as f:
76 |             pbar = stqdm(unit="B", total=int(r.headers['Content-Length']))
77 | for chunk in r.iter_content(chunk_size=chunkSize):
78 | if chunk: # filter out keep-alive new chunks
79 |                     pbar.update(len(chunk))
80 | f.write(chunk)
81 |
82 | @st.cache(suppress_st_warning=True)
83 | def load_gan():
84 | if not os.path.isfile("BigGAN-deep.pth"):
85 | url = "https://github.com/safi842/Microstructure-GAN/releases/download/v0/BigGAN-deep.pth"
86 | filename = "BigGAN-deep.pth"
87 | download_model(url, filename)
88 | st.write(f"Downloaded {filename}...")
89 | #filename, headers = urllib.request.urlretrieve(url,"BigGAN-deep.pth",MyProgressBar())
90 | model = biggan.Generator()
91 | model.load_state_dict(torch.load('BigGAN-deep.pth', map_location=torch.device('cpu')))
92 | return model
93 |
94 | @st.cache(suppress_st_warning=True)
95 | def generate_img(model, noise, y_temp, y_time, y_cool):
96 | y_temp = torch.tensor([y_temp])
97 | y_time = torch.tensor([y_time])
98 | y_cool = torch.tensor([y_cool])
99 | with torch.no_grad():
100 | synthetic = model(noise, y_temp, y_time, y_cool)[0]
101 | synthetic = 0.5 * synthetic + 0.5
102 | #synthetic = make_grid(synthetic, normalize=True)
103 | return np.transpose(synthetic.numpy() ,(1,2,0))
104 |
105 | main()
106 | st.markdown('Copyright (c) 2023 Mohammad Safiuddin', unsafe_allow_html=True)
--------------------------------------------------------------------------------
/app/biggan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import random
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel
7 | import torch.backends.cudnn as cudnn
8 | import torch.optim as optim
9 | import torch.utils.data
10 | import torchvision.transforms as transforms
11 | import torchvision.utils as vutils
12 | from torch.utils.data import DataLoader, Dataset
13 | from torch.nn.utils import spectral_norm
14 | import torch.nn.functional as F
15 | import numpy as np
16 | from torchvision.utils import make_grid
17 |
18 | class ClassConditionalBN(nn.Module):
19 | def __init__(self, input_size, output_size, eps=1e-4, momentum=0.1):
20 | super(ClassConditionalBN, self).__init__()
21 | self.output_size, self.input_size = output_size, input_size
22 | # Prepare gain and bias layers
23 | self.gain = spectral_norm(nn.Linear(input_size, output_size, bias = False), eps = 1e-4)
24 | self.bias = spectral_norm(nn.Linear(input_size, output_size, bias = False), eps = 1e-4)
25 | # epsilon to avoid dividing by 0
26 | self.eps = eps
27 | # Momentum
28 | self.momentum = momentum
29 |
30 | self.register_buffer('stored_mean', torch.zeros(output_size))
31 | self.register_buffer('stored_var', torch.ones(output_size))
32 |
33 | def forward(self, x, y):
34 | # Calculate class-conditional gains and biases
35 | gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
36 | bias = self.bias(y).view(y.size(0), -1, 1, 1)
37 | out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
38 | self.training, 0.1, self.eps)
39 | return out * gain + bias
40 |
41 | def extra_repr(self):
42 | s = 'out: {output_size}, in: {input_size},'
43 | return s.format(**self.__dict__)
44 |
45 | class Self_Attn(nn.Module):
46 | """ Self attention Layer"""
47 | def __init__(self,in_dim,activation = nn.ReLU(inplace = False)):
48 | super(Self_Attn,self).__init__()
49 |         self.channel_in = in_dim
50 | self.activation = activation
51 |
52 | self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
53 | self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
54 | self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
55 | self.gamma = nn.Parameter(torch.zeros(1))
56 |
57 | self.softmax = nn.Softmax(dim=-1) #
58 | def forward(self,x):
59 | """
60 | inputs :
61 | x : input feature maps( B X C X W X H)
62 | returns :
63 | out : self attention value + input feature
64 | attention: B X N X N (N is Width*Height)
65 | """
66 | m_batchsize,C,width ,height = x.size()
67 | proj_query = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
68 | proj_key = self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
69 | energy = torch.bmm(proj_query,proj_key) # transpose check
70 | attention = self.softmax(energy) # BX (N) X (N)
71 | proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N
72 |
73 | out = torch.bmm(proj_value,attention.permute(0,2,1) )
74 | out = out.view(m_batchsize,C,width,height)
75 |
76 | out = self.gamma*out + x
77 | return out
78 |
79 | class GeneratorResBlock(nn.Module):
80 | def __init__(self, in_channels, out_channels, upsample = None, embed_dim = 128, dim_z = 384):
81 | super(GeneratorResBlock, self).__init__()
82 | self.in_channels = in_channels
83 | self.out_channels = out_channels
84 | self.hidden_channels = self.in_channels // 4
85 |
86 | self.conv1 = spectral_norm(nn.Conv2d(self.in_channels, self.hidden_channels, kernel_size = 1, padding = 0), eps = 1e-4)
87 | self.conv2 = spectral_norm(nn.Conv2d(self.hidden_channels, self.hidden_channels, kernel_size = 3, padding = 1), eps = 1e-4)
88 | self.conv3 = spectral_norm(nn.Conv2d(self.hidden_channels, self.hidden_channels, kernel_size = 3, padding = 1), eps = 1e-4)
89 | self.conv4 = spectral_norm(nn.Conv2d(self.hidden_channels, self.out_channels, kernel_size = 1, padding = 0), eps = 1e-4)
90 |
91 | self.bn1 = ClassConditionalBN((3 * embed_dim) + dim_z, self.in_channels)
92 | self.bn2 = ClassConditionalBN((3 * embed_dim) + dim_z, self.hidden_channels)
93 | self.bn3 = ClassConditionalBN((3 * embed_dim) + dim_z, self.hidden_channels)
94 | self.bn4 = ClassConditionalBN((3 * embed_dim) + dim_z, self.hidden_channels)
95 |
96 | self.activation = nn.ReLU(inplace=False)
97 |
98 | self.upsample = upsample
99 |
100 | def forward(self,x,y):
101 | # Project down to channel ratio
102 | h = self.conv1(self.activation(self.bn1(x, y)))
103 | # Apply next BN-ReLU
104 | h = self.activation(self.bn2(h, y))
105 | # Drop channels in x if necessary
106 | if self.in_channels != self.out_channels:
107 | x = x[:, :self.out_channels]
108 | # Upsample both h and x at this point
109 | if self.upsample:
110 | h = self.upsample(h)
111 | x = self.upsample(x)
112 | # 3x3 convs
113 | h = self.conv2(h)
114 | h = self.conv3(self.activation(self.bn3(h, y)))
115 | # Final 1x1 conv
116 | h = self.conv4(self.activation(self.bn4(h, y)))
117 | return h + x
118 |
119 | class Generator(nn.Module):
120 | def __init__(self, G_ch = 64, dim_z = 384, bottom_width=4, img_channels = 1,
121 | init = 'N02',n_classes_temp = 7, n_classes_time = 8, n_classes_cool = 4, embed_dim = 128):
122 | super(Generator, self).__init__()
123 | self.ch = G_ch
124 | self.dim_z = dim_z
125 | self.bottom_width = bottom_width
126 | self.init = init
127 | self.img_channels = img_channels
128 |
129 | self.embed_temp = nn.Embedding(n_classes_temp, embed_dim)
130 | self.embed_time = nn.Embedding(n_classes_time, embed_dim)
131 | self.embed_cool = nn.Embedding(n_classes_cool, embed_dim)
132 |
133 | self.linear = spectral_norm(nn.Linear(dim_z + (3 * embed_dim), 16 * self.ch * (self.bottom_width **2)), eps = 1e-4)
134 |
135 | self.blocks = nn.ModuleList([
136 | GeneratorResBlock(16*self.ch, 16*self.ch),
137 | GeneratorResBlock(16*self.ch, 16*self.ch, upsample = nn.Upsample(scale_factor = 2)),
138 | GeneratorResBlock(16*self.ch, 16*self.ch),
139 | GeneratorResBlock(16*self.ch, 8*self.ch, upsample = nn.Upsample(scale_factor = 2)),
140 | GeneratorResBlock(8*self.ch, 8*self.ch),
141 | GeneratorResBlock(8*self.ch, 8*self.ch, upsample = nn.Upsample(scale_factor = 2)),
142 | GeneratorResBlock(8*self.ch, 8*self.ch),
143 | GeneratorResBlock(8*self.ch, 4*self.ch, upsample = nn.Upsample(scale_factor = 2)),
144 | Self_Attn(4*self.ch),
145 | GeneratorResBlock(4*self.ch, 4*self.ch),
146 | GeneratorResBlock(4*self.ch, 2*self.ch, upsample = nn.Upsample(scale_factor = 2)),
147 | GeneratorResBlock(2*self.ch, 2*self.ch),
148 | GeneratorResBlock(2*self.ch, self.ch, upsample = nn.Upsample(scale_factor = 2))
149 | ])
150 |
151 | self.final_layer = nn.Sequential(
152 | nn.BatchNorm2d(self.ch),
153 | nn.ReLU(inplace = False),
154 | spectral_norm(nn.Conv2d(self.ch, self.img_channels, kernel_size = 3, padding = 1)),
155 | nn.Tanh()
156 | )
157 |
158 | self.init_weights()
159 |
160 | def init_weights(self):
161 | print(f"Weight initialization : {self.init}")
162 | self.param_count = 0
163 | for module in self.modules():
164 | if (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.Embedding)):
165 | if self.init == 'ortho':
166 | torch.nn.init.orthogonal_(module.weight)
167 | elif self.init == 'N02':
168 | torch.nn.init.normal_(module.weight, 0, 0.02)
169 | elif self.init in ['glorot', 'xavier']:
170 | torch.nn.init.xavier_uniform_(module.weight)
171 | else:
172 | print('Init style not recognized...')
173 | self.param_count += sum([p.data.nelement() for p in module.parameters()])
174 | #print("Param count for G's initialized parameters: %d Million" % (self.param_count/1000000))
175 |
176 |
177 | def forward(self,z , y_temp, y_time, y_cool):
178 | y_temp = self.embed_temp(y_temp)
179 | y_time = self.embed_time(y_time)
180 | y_cool = self.embed_cool(y_cool)
181 | z = torch.cat([z, y_temp, y_time, y_cool], 1)
182 | # First linear layer
183 | h = self.linear(z)
184 | # Reshape
185 | h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
186 | # Loop over blocks
187 | for i, block in enumerate(self.blocks):
188 | if i != 8:
189 | h = block(h, z)
190 | else:
191 | h = block(h)
192 | # Apply batchnorm-relu-conv-tanh at output
193 | h = self.final_layer(h)
194 | return h
195 |
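196 |
197 | if __name__ == '__main__':
198 |     # Quick smoke test (an illustrative sketch; the app itself never runs this):
199 |     # build a randomly initialized generator and verify the output shape for one
200 |     # latent vector and one arbitrary label per conditional input.
201 |     g = Generator()
202 |     z = torch.randn(1, 384)
203 |     with torch.no_grad():
204 |         out = g(z, torch.tensor([0]), torch.tensor([0]), torch.tensor([0]))
205 |     print(out.shape)  # expected: torch.Size([1, 1, 256, 256])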
--------------------------------------------------------------------------------
/omni-loss-biggan.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "execution": {
8 | "iopub.execute_input": "2021-05-12T05:40:48.623112Z",
9 | "iopub.status.busy": "2021-05-12T05:40:48.621513Z",
10 | "iopub.status.idle": "2021-05-12T05:40:51.352905Z",
11 | "shell.execute_reply": "2021-05-12T05:40:51.352242Z"
12 | },
13 | "papermill": {
14 | "duration": 2.755864,
15 | "end_time": "2021-05-12T05:40:51.353059",
16 | "exception": false,
17 | "start_time": "2021-05-12T05:40:48.597195",
18 | "status": "completed"
19 | },
20 | "tags": []
21 | },
22 | "outputs": [],
23 | "source": [
24 | "import os\n",
25 | "import pandas as pd\n",
26 | "import random\n",
27 | "from collections import OrderedDict\n",
28 | "from PIL import Image\n",
29 | "import torch\n",
30 | "import torch.nn as nn\n",
31 | "import torch.nn.parallel\n",
32 | "import torch.backends.cudnn as cudnn\n",
33 | "import torch.optim as optim\n",
34 | "import torch.utils.data\n",
35 | "import torchvision.transforms as transforms\n",
36 | "import torchvision.utils as vutils\n",
37 | "from torch.utils.data import DataLoader, Dataset\n",
38 | "import pytorch_lightning as pl\n",
39 | "from torch.nn.utils import spectral_norm\n",
40 | "import torch.nn.functional as F\n",
41 | "import numpy as np\n",
42 | "import matplotlib.pyplot as plt\n",
43 | "import matplotlib.animation as animation\n",
44 | "from torchvision.utils import make_grid\n",
45 | "import warnings\n",
46 | "warnings.filterwarnings(\"ignore\")\n",
47 | "#from torchsummaryX import summary"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 2,
53 | "metadata": {
54 | "execution": {
55 | "iopub.execute_input": "2021-05-12T05:40:51.394016Z",
56 | "iopub.status.busy": "2021-05-12T05:40:51.393457Z",
57 | "iopub.status.idle": "2021-05-12T05:40:59.288763Z",
58 | "shell.execute_reply": "2021-05-12T05:40:59.288233Z"
59 | },
60 | "papermill": {
61 | "duration": 7.917559,
62 | "end_time": "2021-05-12T05:40:59.288929",
63 | "exception": false,
64 | "start_time": "2021-05-12T05:40:51.371370",
65 | "status": "completed"
66 | },
67 | "tags": []
68 | },
69 | "outputs": [
70 | {
71 | "name": "stdout",
72 | "output_type": "stream",
73 | "text": [
74 | "Collecting openpyxl\r\n",
75 | " Downloading openpyxl-3.0.7-py2.py3-none-any.whl (243 kB)\r\n",
76 | "\u001b[K |████████████████████████████████| 243 kB 1.2 MB/s \r\n",
77 | "\u001b[?25hCollecting et-xmlfile\r\n",
78 | " Downloading et_xmlfile-1.1.0-py3-none-any.whl (4.7 kB)\r\n",
79 | "Installing collected packages: et-xmlfile, openpyxl\r\n",
80 | "Successfully installed et-xmlfile-1.1.0 openpyxl-3.0.7\r\n"
81 | ]
82 | }
83 | ],
84 | "source": [
85 | "!pip install openpyxl"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {
91 | "papermill": {
92 | "duration": 0.020651,
93 | "end_time": "2021-05-12T05:40:59.330619",
94 | "exception": false,
95 | "start_time": "2021-05-12T05:40:59.309968",
96 | "status": "completed"
97 | },
98 | "tags": []
99 | },
100 | "source": [
101 | "## Class Conditional Batch Normalization"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 3,
107 | "metadata": {
108 | "execution": {
109 | "iopub.execute_input": "2021-05-12T05:40:59.384025Z",
110 | "iopub.status.busy": "2021-05-12T05:40:59.383273Z",
111 | "iopub.status.idle": "2021-05-12T05:40:59.386077Z",
112 | "shell.execute_reply": "2021-05-12T05:40:59.385647Z"
113 | },
114 | "papermill": {
115 | "duration": 0.033907,
116 | "end_time": "2021-05-12T05:40:59.386189",
117 | "exception": false,
118 | "start_time": "2021-05-12T05:40:59.352282",
119 | "status": "completed"
120 | },
121 | "tags": []
122 | },
123 | "outputs": [],
124 | "source": [
125 | "class ClassConditionalBN(nn.Module):\n",
126 | " def __init__(self, input_size, output_size, eps=1e-4, momentum=0.1):\n",
127 | " super(ClassConditionalBN, self).__init__()\n",
128 | " self.output_size, self.input_size = output_size, input_size\n",
129 | " # Prepare gain and bias layers\n",
130 | " self.gain = spectral_norm(nn.Linear(input_size, output_size, bias = False), eps = 1e-4)\n",
131 | " self.bias = spectral_norm(nn.Linear(input_size, output_size, bias = False), eps = 1e-4)\n",
132 | " # epsilon to avoid dividing by 0\n",
133 | " self.eps = eps\n",
134 | " # Momentum\n",
135 | " self.momentum = momentum\n",
136 | " \n",
137 | " self.register_buffer('stored_mean', torch.zeros(output_size))\n",
138 | " self.register_buffer('stored_var', torch.ones(output_size))\n",
139 | " \n",
140 | " def forward(self, x, y):\n",
141 | " # Calculate class-conditional gains and biases\n",
142 | " gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)\n",
143 | " bias = self.bias(y).view(y.size(0), -1, 1, 1)\n",
144 | " out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,\n",
145 | " self.training, 0.1, self.eps)\n",
146 | " return out * gain + bias\n",
147 | " \n",
148 | " def extra_repr(self):\n",
149 | " s = 'out: {output_size}, in: {input_size},'\n",
150 | " return s.format(**self.__dict__)"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {
156 | "papermill": {
157 | "duration": 0.020582,
158 | "end_time": "2021-05-12T05:40:59.427620",
159 | "exception": false,
160 | "start_time": "2021-05-12T05:40:59.407038",
161 | "status": "completed"
162 | },
163 | "tags": []
164 | },
165 | "source": [
166 | "## Self Attention Module"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 4,
172 | "metadata": {
173 | "execution": {
174 | "iopub.execute_input": "2021-05-12T05:40:59.480572Z",
175 | "iopub.status.busy": "2021-05-12T05:40:59.479671Z",
176 | "iopub.status.idle": "2021-05-12T05:40:59.481743Z",
177 | "shell.execute_reply": "2021-05-12T05:40:59.482227Z"
178 | },
179 | "papermill": {
180 | "duration": 0.033732,
181 | "end_time": "2021-05-12T05:40:59.482365",
182 | "exception": false,
183 | "start_time": "2021-05-12T05:40:59.448633",
184 | "status": "completed"
185 | },
186 | "tags": []
187 | },
188 | "outputs": [],
189 | "source": [
190 | "class Self_Attn(nn.Module):\n",
191 | " \"\"\" Self attention Layer\"\"\"\n",
192 | " def __init__(self,in_dim,activation = nn.ReLU(inplace = False)):\n",
193 | " super(Self_Attn,self).__init__()\n",
194 |     "        self.channel_in = in_dim\n",
195 | " self.activation = activation\n",
196 | " \n",
197 | " self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)\n",
198 | " self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)\n",
199 | " self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)\n",
200 | " self.gamma = nn.Parameter(torch.zeros(1))\n",
201 | "\n",
202 | " self.softmax = nn.Softmax(dim=-1) #\n",
203 | " def forward(self,x):\n",
204 | " \"\"\"\n",
205 | " inputs :\n",
206 | " x : input feature maps( B X C X W X H)\n",
207 | " returns :\n",
208 | " out : self attention value + input feature \n",
209 | " attention: B X N X N (N is Width*Height)\n",
210 | " \"\"\"\n",
211 | " m_batchsize,C,width ,height = x.size()\n",
212 | " proj_query = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)\n",
213 | " proj_key = self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)\n",
214 | " energy = torch.bmm(proj_query,proj_key) # transpose check\n",
215 | " attention = self.softmax(energy) # BX (N) X (N) \n",
216 | " proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N\n",
217 | "\n",
218 | " out = torch.bmm(proj_value,attention.permute(0,2,1) )\n",
219 | " out = out.view(m_batchsize,C,width,height)\n",
220 | " \n",
221 | " out = self.gamma*out + x\n",
222 | " return out"
223 | ]
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "metadata": {
228 | "papermill": {
229 | "duration": 0.020504,
230 | "end_time": "2021-05-12T05:40:59.523763",
231 | "exception": false,
232 | "start_time": "2021-05-12T05:40:59.503259",
233 | "status": "completed"
234 | },
235 | "tags": []
236 | },
237 | "source": [
238 | "## Generator Resblock"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 5,
244 | "metadata": {
245 | "execution": {
246 | "iopub.execute_input": "2021-05-12T05:40:59.580342Z",
247 | "iopub.status.busy": "2021-05-12T05:40:59.579379Z",
248 | "iopub.status.idle": "2021-05-12T05:40:59.582130Z",
249 | "shell.execute_reply": "2021-05-12T05:40:59.581725Z"
250 | },
251 | "papermill": {
252 | "duration": 0.037638,
253 | "end_time": "2021-05-12T05:40:59.582245",
254 | "exception": false,
255 | "start_time": "2021-05-12T05:40:59.544607",
256 | "status": "completed"
257 | },
258 | "tags": []
259 | },
260 | "outputs": [],
261 | "source": [
262 | "class GeneratorResBlock(nn.Module):\n",
263 | " def __init__(self, in_channels, out_channels, upsample = None, embed_dim = 128, dim_z = 384):\n",
264 | " super(GeneratorResBlock, self).__init__()\n",
265 | " self.in_channels = in_channels\n",
266 | " self.out_channels = out_channels\n",
267 | " self.hidden_channels = self.in_channels // 4\n",
268 | " \n",
269 | " self.conv1 = spectral_norm(nn.Conv2d(self.in_channels, self.hidden_channels, kernel_size = 1, padding = 0), eps = 1e-4)\n",
270 | " self.conv2 = spectral_norm(nn.Conv2d(self.hidden_channels, self.hidden_channels, kernel_size = 3, padding = 1), eps = 1e-4)\n",
271 | " self.conv3 = spectral_norm(nn.Conv2d(self.hidden_channels, self.hidden_channels, kernel_size = 3, padding = 1), eps = 1e-4)\n",
272 | " self.conv4 = spectral_norm(nn.Conv2d(self.hidden_channels, self.out_channels, kernel_size = 1, padding = 0), eps = 1e-4)\n",
273 | " \n",
274 | " self.bn1 = ClassConditionalBN((3 * embed_dim) + dim_z, self.in_channels)\n",
275 | " self.bn2 = ClassConditionalBN((3 * embed_dim) + dim_z, self.hidden_channels)\n",
276 | " self.bn3 = ClassConditionalBN((3 * embed_dim) + dim_z, self.hidden_channels)\n",
277 | " self.bn4 = ClassConditionalBN((3 * embed_dim) + dim_z, self.hidden_channels)\n",
278 | " \n",
279 | " self.activation = nn.ReLU(inplace=False)\n",
280 | " \n",
281 | " self.upsample = upsample\n",
282 | " \n",
283 | " def forward(self,x,y):\n",
284 | " # Project down to channel ratio\n",
285 | " h = self.conv1(self.activation(self.bn1(x, y)))\n",
286 | " # Apply next BN-ReLU\n",
287 | " h = self.activation(self.bn2(h, y))\n",
288 | " # Drop channels in x if necessary\n",
289 | " if self.in_channels != self.out_channels:\n",
290 | " x = x[:, :self.out_channels] \n",
291 | " # Upsample both h and x at this point \n",
292 | " if self.upsample:\n",
293 | " h = self.upsample(h)\n",
294 | " x = self.upsample(x)\n",
295 | " # 3x3 convs\n",
296 | " h = self.conv2(h)\n",
297 | " h = self.conv3(self.activation(self.bn3(h, y)))\n",
298 | " # Final 1x1 conv\n",
299 | " h = self.conv4(self.activation(self.bn4(h, y)))\n",
300 | " return h + x"
301 | ]
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "metadata": {
306 | "papermill": {
307 | "duration": 0.021846,
308 | "end_time": "2021-05-12T05:40:59.625610",
309 | "exception": false,
310 | "start_time": "2021-05-12T05:40:59.603764",
311 | "status": "completed"
312 | },
313 | "tags": []
314 | },
315 | "source": [
316 | "## BigGAN-deep Generator \n",
317 |     "This version of the generator has a different input mechanism from that of the original version in the paper. This change was made to accommodate multiple conditional inputs: the latent vector $z$ is concatenated with the three label embeddings, and the concatenation conditions every class-conditional BN layer."
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": 6,
323 | "metadata": {
324 | "execution": {
325 | "iopub.execute_input": "2021-05-12T05:40:59.691305Z",
326 | "iopub.status.busy": "2021-05-12T05:40:59.690418Z",
327 | "iopub.status.idle": "2021-05-12T05:40:59.692668Z",
328 | "shell.execute_reply": "2021-05-12T05:40:59.693070Z"
329 | },
330 | "papermill": {
331 | "duration": 0.04512,
332 | "end_time": "2021-05-12T05:40:59.693211",
333 | "exception": false,
334 | "start_time": "2021-05-12T05:40:59.648091",
335 | "status": "completed"
336 | },
337 | "tags": []
338 | },
339 | "outputs": [],
340 | "source": [
341 | "class Generator(nn.Module):\n",
342 | " def __init__(self, G_ch = 64, dim_z = 384, bottom_width=4, img_channels = 1,\n",
343 | " init = 'N02',n_classes_temp = 7, n_classes_time = 8, n_classes_cool = 4, embed_dim = 128):\n",
344 | " super(Generator, self).__init__()\n",
345 | " self.ch = G_ch\n",
346 | " self.dim_z = dim_z\n",
347 | " self.bottom_width = bottom_width\n",
348 | " self.init = init\n",
349 | " self.img_channels = img_channels\n",
350 | "\n",
351 | " self.embed_temp = nn.Embedding(n_classes_temp, embed_dim)\n",
352 | " self.embed_time = nn.Embedding(n_classes_time, embed_dim)\n",
353 | " self.embed_cool = nn.Embedding(n_classes_cool, embed_dim)\n",
354 | " \n",
355 | " self.linear = spectral_norm(nn.Linear(dim_z + (3 * embed_dim), 16 * self.ch * (self.bottom_width **2)), eps = 1e-4)\n",
356 | " \n",
357 | " self.blocks = nn.ModuleList([\n",
358 | " GeneratorResBlock(16*self.ch, 16*self.ch),\n",
359 | " GeneratorResBlock(16*self.ch, 16*self.ch, upsample = nn.Upsample(scale_factor = 2)),\n",
360 | " GeneratorResBlock(16*self.ch, 16*self.ch),\n",
361 | " GeneratorResBlock(16*self.ch, 8*self.ch, upsample = nn.Upsample(scale_factor = 2)),\n",
362 | " GeneratorResBlock(8*self.ch, 8*self.ch),\n",
363 | " GeneratorResBlock(8*self.ch, 8*self.ch, upsample = nn.Upsample(scale_factor = 2)),\n",
364 | " GeneratorResBlock(8*self.ch, 8*self.ch),\n",
365 | " GeneratorResBlock(8*self.ch, 4*self.ch, upsample = nn.Upsample(scale_factor = 2)),\n",
366 | " Self_Attn(4*self.ch),\n",
367 | " GeneratorResBlock(4*self.ch, 4*self.ch),\n",
368 | " GeneratorResBlock(4*self.ch, 2*self.ch, upsample = nn.Upsample(scale_factor = 2)),\n",
369 | " GeneratorResBlock(2*self.ch, 2*self.ch),\n",
370 | " GeneratorResBlock(2*self.ch, self.ch, upsample = nn.Upsample(scale_factor = 2))\n",
371 | " ])\n",
372 | " \n",
373 | " self.final_layer = nn.Sequential(\n",
374 | " nn.BatchNorm2d(self.ch),\n",
375 | " nn.ReLU(inplace = False),\n",
376 | " spectral_norm(nn.Conv2d(self.ch, self.img_channels, kernel_size = 3, padding = 1)),\n",
377 | " nn.Tanh()\n",
378 | " )\n",
379 | " \n",
380 | " self.init_weights()\n",
381 | " \n",
382 | " def init_weights(self):\n",
383 | " print(f\"Weight initialization : {self.init}\")\n",
384 | " self.param_count = 0\n",
385 | " for module in self.modules():\n",
386 | " if (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.Embedding)):\n",
387 | " if self.init == 'ortho':\n",
388 | " torch.nn.init.orthogonal_(module.weight)\n",
389 | " elif self.init == 'N02':\n",
390 | " torch.nn.init.normal_(module.weight, 0, 0.02)\n",
391 | " elif self.init in ['glorot', 'xavier']:\n",
392 | " torch.nn.init.xavier_uniform_(module.weight)\n",
393 | " else:\n",
394 | " print('Init style not recognized...')\n",
395 | " self.param_count += sum([p.data.nelement() for p in module.parameters()])\n",
396 | " print(\"Param count for G's initialized parameters: %d Million\" % (self.param_count/1000000))\n",
397 | " \n",
398 | " \n",
399 | " def forward(self,z , y_temp, y_time, y_cool):\n",
400 | " y_temp = self.embed_temp(y_temp)\n",
401 | " y_time = self.embed_time(y_time)\n",
402 | " y_cool = self.embed_cool(y_cool)\n",
403 | " z = torch.cat([z, y_temp, y_time, y_cool], 1) \n",
404 | " # First linear layer\n",
405 | " h = self.linear(z)\n",
406 | " # Reshape\n",
407 | " h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width) \n",
408 | " # Loop over blocks\n",
409 | " for i, block in enumerate(self.blocks):\n",
410 | " if i != 8:\n",
411 | " h = block(h, z)\n",
412 | " else:\n",
413 | " h = block(h)\n",
414 | " # Apply batchnorm-relu-conv-tanh at output\n",
415 | " h = self.final_layer(h)\n",
416 | " return h"
417 | ]
418 | },
419 | {
420 | "cell_type": "markdown",
421 | "metadata": {
422 | "papermill": {
423 | "duration": 0.021752,
424 | "end_time": "2021-05-12T05:40:59.737083",
425 | "exception": false,
426 | "start_time": "2021-05-12T05:40:59.715331",
427 | "status": "completed"
428 | },
429 | "tags": []
430 | },
431 | "source": [
432 | "## Discriminator Resblock"
433 | ]
434 | },
435 | {
436 | "cell_type": "code",
437 | "execution_count": 7,
438 | "metadata": {
439 | "execution": {
440 | "iopub.execute_input": "2021-05-12T05:40:59.792117Z",
441 | "iopub.status.busy": "2021-05-12T05:40:59.791327Z",
442 | "iopub.status.idle": "2021-05-12T05:40:59.794226Z",
443 | "shell.execute_reply": "2021-05-12T05:40:59.793762Z"
444 | },
445 | "papermill": {
446 | "duration": 0.036436,
447 | "end_time": "2021-05-12T05:40:59.794341",
448 | "exception": false,
449 | "start_time": "2021-05-12T05:40:59.757905",
450 | "status": "completed"
451 | },
452 | "tags": []
453 | },
454 | "outputs": [],
455 | "source": [
456 | "class DiscriminatorResBlock(nn.Module):\n",
457 | " def __init__(self, in_channels, out_channels, preactivation=True, \n",
458 | " downsample=None,channel_ratio=4):\n",
459 | " super(DiscriminatorResBlock, self).__init__()\n",
460 | " self.in_channels, self.out_channels = in_channels, out_channels\n",
461 | " # If using wide D (as in SA-GAN and BigGAN), change the channel pattern\n",
462 | " self.hidden_channels = self.out_channels // channel_ratio\n",
463 | " self.preactivation = preactivation\n",
464 | " self.activation = nn.ReLU(inplace=False)\n",
465 | " self.downsample = downsample\n",
466 | " \n",
467 | " # Conv layers\n",
468 | " self.conv1 = spectral_norm(nn.Conv2d(self.in_channels, self.hidden_channels, \n",
469 | " kernel_size=1, padding=0), eps = 1e-4)\n",
470 | " self.conv2 = spectral_norm(nn.Conv2d(self.hidden_channels, self.hidden_channels, kernel_size = 3, padding = 1), eps = 1e-4)\n",
471 | " self.conv3 = spectral_norm(nn.Conv2d(self.hidden_channels, self.hidden_channels, kernel_size = 3, padding = 1), eps = 1e-4)\n",
472 | " self.conv4 = spectral_norm(nn.Conv2d(self.hidden_channels, self.out_channels, \n",
473 | " kernel_size=1, padding=0), eps = 1e-4)\n",
474 | " \n",
475 | " self.learnable_sc = True if (in_channels != out_channels) else False\n",
476 | " if self.learnable_sc:\n",
477 | " self.conv_sc = spectral_norm(nn.Conv2d(in_channels, out_channels - in_channels, \n",
478 | " kernel_size=1, padding=0), eps = 1e-4)\n",
479 | " \n",
480 | " def shortcut(self, x):\n",
481 | " if self.downsample:\n",
482 | " x = self.downsample(x)\n",
483 | " if self.learnable_sc:\n",
484 | " x = torch.cat([x, self.conv_sc(x)], 1) \n",
485 | " return x\n",
486 | " \n",
487 | " def forward(self, x):\n",
488 | " # 1x1 bottleneck conv\n",
489 | " h = self.conv1(F.relu(x))\n",
490 | " # 3x3 convs\n",
491 | " h = self.conv2(self.activation(h))\n",
492 | " h = self.conv3(self.activation(h))\n",
493 | " # relu before downsample\n",
494 | " h = self.activation(h)\n",
495 | " # downsample\n",
496 | " if self.downsample:\n",
497 | " h = self.downsample(h) \n",
498 | " # final 1x1 conv\n",
499 | " h = self.conv4(h)\n",
500 | " return h + self.shortcut(x)"
501 | ]
502 | },
503 | {
504 | "cell_type": "markdown",
505 | "metadata": {
506 | "papermill": {
507 | "duration": 0.020773,
508 | "end_time": "2021-05-12T05:40:59.836086",
509 | "exception": false,
510 | "start_time": "2021-05-12T05:40:59.815313",
511 | "status": "completed"
512 | },
513 | "tags": []
514 | },
515 | "source": [
516 | "## BigGAN-deep Discriminator\n",
517 |     "The discriminator is modified to output a 21-dimensional vector (7 annealing-temperature classes + 8 annealing-time classes + 4 cooling classes, plus 2 real/fake logits), and no class-conditional information is fed to it as input."
518 | ]
519 | },
520 | {
521 | "cell_type": "code",
522 | "execution_count": 8,
523 | "metadata": {
524 | "execution": {
525 | "iopub.execute_input": "2021-05-12T05:40:59.892421Z",
526 | "iopub.status.busy": "2021-05-12T05:40:59.891385Z",
527 | "iopub.status.idle": "2021-05-12T05:40:59.896694Z",
528 | "shell.execute_reply": "2021-05-12T05:40:59.896271Z"
529 | },
530 | "papermill": {
531 | "duration": 0.039824,
532 | "end_time": "2021-05-12T05:40:59.896806",
533 | "exception": false,
534 | "start_time": "2021-05-12T05:40:59.856982",
535 | "status": "completed"
536 | },
537 | "tags": []
538 | },
539 | "outputs": [],
540 | "source": [
541 | "class Discriminator(nn.Module):\n",
542 | " def __init__(self, D_ch = 64, img_channels = 1, init = 'N02', n_classes_temp = 7, n_classes_time = 8, n_classes_cool = 4):\n",
543 | " super(Discriminator, self).__init__()\n",
544 | " self.ch = D_ch\n",
545 | " self.init = init\n",
546 | " self.img_channels = img_channels\n",
547 | " self.output_dim = n_classes_temp + n_classes_time + n_classes_cool + 2\n",
548 | " \n",
549 | " # Prepare model\n",
550 | " # Stem convolution\n",
551 | " self.input_conv = spectral_norm(nn.Conv2d(self.img_channels, self.ch, kernel_size = 3, padding = 1), eps = 1e-4)\n",
552 | " \n",
553 | " self.blocks = nn.Sequential(\n",
554 | " DiscriminatorResBlock(self.ch, 2*self.ch, downsample = nn.AvgPool2d(2)),\n",
555 | " DiscriminatorResBlock(2*self.ch, 2*self.ch),\n",
556 | " DiscriminatorResBlock(2*self.ch, 4*self.ch, downsample = nn.AvgPool2d(2)),\n",
557 | " DiscriminatorResBlock(4*self.ch, 4*self.ch),\n",
558 | " Self_Attn(4*self.ch),\n",
559 | " DiscriminatorResBlock(4*self.ch, 8*self.ch, downsample = nn.AvgPool2d(2)),\n",
560 | " DiscriminatorResBlock(8*self.ch, 8*self.ch),\n",
561 | " DiscriminatorResBlock(8*self.ch, 8*self.ch, downsample = nn.AvgPool2d(2)),\n",
562 | " DiscriminatorResBlock(8*self.ch, 8*self.ch),\n",
563 | " DiscriminatorResBlock(8*self.ch, 16*self.ch, downsample = nn.AvgPool2d(2)),\n",
564 | " DiscriminatorResBlock(16*self.ch, 16*self.ch),\n",
565 | " DiscriminatorResBlock(16*self.ch, 16*self.ch, downsample = nn.AvgPool2d(2)),\n",
566 | " DiscriminatorResBlock(16*self.ch, 16*self.ch),\n",
567 | " )\n",
568 | " # Linear output layer. The output dimension is typically 1, but may be\n",
569 | " # larger if we're e.g. turning this into a VAE with an inference output\n",
570 | " self.linear = spectral_norm(nn.Linear(16*self.ch, self.output_dim), eps = 1e-4)\n",
571 | " \n",
572 | " self.init_weights()\n",
573 | " \n",
574 | " def init_weights(self):\n",
575 | " print(f\"Weight initialization : {self.init}\")\n",
576 | " self.param_count = 0\n",
577 | " for module in self.modules():\n",
578 | " if (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.Embedding)):\n",
579 | " if self.init == 'ortho':\n",
580 | " torch.nn.init.orthogonal_(module.weight)\n",
581 | " elif self.init == 'N02':\n",
582 | " torch.nn.init.normal_(module.weight, 0, 0.02)\n",
583 | " elif self.init in ['glorot', 'xavier']:\n",
584 | " torch.nn.init.xavier_uniform_(module.weight)\n",
585 | " else:\n",
586 | " print('Init style not recognized...')\n",
587 | " self.param_count += sum([p.data.nelement() for p in module.parameters()])\n",
588 | " print(\"Param count for D's initialized parameters: %d Million\" % (self.param_count/1000000))\n",
589 | " \n",
590 | " def forward(self, x):\n",
591 | " # Run input conv\n",
592 | " h = self.input_conv(x)\n",
593 | " # Blocks\n",
594 | " h = self.blocks(h)\n",
595 | " # Apply global sum pooling as in SN-GAN\n",
596 | " h = torch.sum(nn.ReLU(inplace = False)(h), [2, 3])\n",
597 | " # Get initial class-unconditional output\n",
598 | " out = self.linear(h)\n",
599 | " # Get projection of final featureset onto class vectors and add to evidence\n",
600 | " #out = out + torch.sum(self.embed_temp(y_temp) * h, 1, keepdim=True) + torch.sum(self.embed_time(y_time) * h, 1, keepdim=True) + torch.sum(self.embed_cool(y_cool) * h, 1, keepdim=True)\n",
601 | " return out"
602 | ]
603 | },
604 | {
605 | "cell_type": "markdown",
606 | "metadata": {
607 | "papermill": {
608 | "duration": 0.020852,
609 | "end_time": "2021-05-12T05:40:59.938450",
610 | "exception": false,
611 | "start_time": "2021-05-12T05:40:59.917598",
612 | "status": "completed"
613 | },
614 | "tags": []
615 | },
616 | "source": [
617 | "## Differentiable Augmentation for GAN\n",
618 |     "The implementation follows the Color-Translation-Cutout policy from the DiffAugment [paper](https://arxiv.org/pdf/2006.10738.pdf)."
619 | ]
620 | },
621 | {
622 | "cell_type": "code",
623 | "execution_count": 9,
624 | "metadata": {
625 | "execution": {
626 | "iopub.execute_input": "2021-05-12T05:41:00.003553Z",
627 | "iopub.status.busy": "2021-05-12T05:41:00.002673Z",
628 | "iopub.status.idle": "2021-05-12T05:41:00.005189Z",
629 | "shell.execute_reply": "2021-05-12T05:41:00.004788Z"
630 | },
631 | "papermill": {
632 | "duration": 0.045901,
633 | "end_time": "2021-05-12T05:41:00.005300",
634 | "exception": false,
635 | "start_time": "2021-05-12T05:40:59.959399",
636 | "status": "completed"
637 | },
638 | "tags": []
639 | },
640 | "outputs": [],
641 | "source": [
642 | "class DiffAugment:\n",
643 | " def __init__(self, policy='color,translation,cutout', channels_first=True):\n",
644 | " self.policy = policy\n",
645 | " print(f'Diff. Augment Policy : {policy}')\n",
646 | " self.channels_first = channels_first\n",
647 | " self.AUGMENT_FNS = {'color': [self.rand_brightness, self.rand_saturation, self.rand_contrast],\n",
648 | " 'translation': [self.rand_translation],\n",
649 | " 'cutout': [self.rand_cutout]}\n",
650 | " \n",
651 | " def __call__(self, x):\n",
652 | " if self.policy:\n",
653 | " if not self.channels_first:\n",
654 | " x = x.permute(0, 3, 1, 2)\n",
655 | " for p in self.policy.split(','):\n",
656 | " for f in self.AUGMENT_FNS[p]:\n",
657 | " x = f(x)\n",
658 | " if not self.channels_first:\n",
659 | " x = x.permute(0, 2, 3, 1)\n",
660 | " x = x.contiguous()\n",
661 | " return x\n",
662 | " \n",
663 | " \n",
664 | " def rand_brightness(self, x):\n",
665 | " x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)\n",
666 | " return x\n",
667 | " \n",
668 | " def rand_saturation(self, x):\n",
669 | " x_mean = x.mean(dim=1, keepdim=True)\n",
670 | " x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean\n",
671 | " return x\n",
672 | " \n",
673 | " def rand_contrast(self,x):\n",
674 | " x_mean = x.mean(dim=[1, 2, 3], keepdim=True)\n",
675 | " x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean\n",
676 | " return x\n",
677 | " \n",
678 | " def rand_translation(self, x, ratio = 0.125):\n",
679 | " shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)\n",
680 | " translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)\n",
681 | " translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)\n",
682 | " grid_batch, grid_x, grid_y = torch.meshgrid(\n",
683 | " torch.arange(x.size(0), dtype=torch.long, device=x.device),\n",
684 | " torch.arange(x.size(2), dtype=torch.long, device=x.device),\n",
685 | " torch.arange(x.size(3), dtype=torch.long, device=x.device),\n",
686 | " )\n",
687 | " grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)\n",
688 | " grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)\n",
689 | " x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])\n",
690 | " x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)\n",
691 | " return x\n",
692 | " \n",
693 | " def rand_cutout(self, x, ratio=0.5):\n",
694 | " cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)\n",
695 | " offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)\n",
696 | " offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)\n",
697 | " grid_batch, grid_x, grid_y = torch.meshgrid(\n",
698 | " torch.arange(x.size(0), dtype=torch.long, device=x.device),\n",
699 | " torch.arange(cutout_size[0], dtype=torch.long, device=x.device),\n",
700 | " torch.arange(cutout_size[1], dtype=torch.long, device=x.device),\n",
701 | " )\n",
702 | " grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)\n",
703 | " grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)\n",
704 | " mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)\n",
705 | " mask[grid_batch, grid_x, grid_y] = 0\n",
706 | " x = x * mask.unsqueeze(1)\n",
707 | " return x"
708 | ]
709 | },
710 | {
711 | "cell_type": "markdown",
712 | "metadata": {
713 | "papermill": {
714 | "duration": 0.020992,
715 | "end_time": "2021-05-12T05:41:00.047398",
716 | "exception": false,
717 | "start_time": "2021-05-12T05:41:00.026406",
718 | "status": "completed"
719 | },
720 | "tags": []
721 | },
722 | "source": [
723 | "## Custom sampler to balance the dataset "
724 | ]
725 | },
726 | {
727 | "cell_type": "code",
728 | "execution_count": 10,
729 | "metadata": {
730 | "execution": {
731 | "iopub.execute_input": "2021-05-12T05:41:00.099675Z",
732 | "iopub.status.busy": "2021-05-12T05:41:00.098826Z",
733 | "iopub.status.idle": "2021-05-12T05:41:00.101540Z",
734 | "shell.execute_reply": "2021-05-12T05:41:00.101057Z"
735 | },
736 | "papermill": {
737 | "duration": 0.033026,
738 | "end_time": "2021-05-12T05:41:00.101649",
739 | "exception": false,
740 | "start_time": "2021-05-12T05:41:00.068623",
741 | "status": "completed"
742 | },
743 | "tags": []
744 | },
745 | "outputs": [],
746 | "source": [
747 | "class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):\n",
748 | " \"\"\"Samples elements randomly from a given list of indices for imbalanced dataset\n",
749 | " Arguments:\n",
750 | " indices (list, optional): a list of indices\n",
751 | " num_samples (int, optional): number of samples to draw\n",
752 | " callback_get_label func: a callback-like function which takes two arguments - dataset and index\n",
753 | " \"\"\"\n",
754 | " def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):\n",
755 | " \n",
756 | " # if indices is not provided, \n",
757 | " # all elements in the dataset will be considered\n",
758 | " self.indices = list(range(len(dataset))) \\\n",
759 | " if indices is None else indices\n",
760 | " \n",
761 | " # define custom callback\n",
762 | " self.callback_get_label = callback_get_label\n",
763 | " \n",
764 | " # if num_samples is not provided, \n",
765 | " # draw `len(indices)` samples in each iteration\n",
766 | " self.num_samples = len(self.indices) \\\n",
767 | " if num_samples is None else num_samples\n",
768 | " \n",
769 | " # distribution of classes in the dataset for temp , time and cool \n",
770 | " label_to_count = {}\n",
771 | " for idx in self.indices:\n",
772 | " label = self._get_label(dataset, idx)\n",
773 | " if label in label_to_count:\n",
774 | " label_to_count[label] += 1\n",
775 | " else:\n",
776 | " label_to_count[label] = 1\n",
777 | " \n",
778 | " # weight for each sample\n",
779 | " weights = [1.0 / label_to_count[self._get_label(dataset, idx)]\n",
780 | " for idx in self.indices]\n",
781 | " weights = torch.DoubleTensor(weights)\n",
782 | " self.weights= weights\n",
783 | " \n",
784 | " def _get_label(self, dataset, idx):\n",
785 | " return dataset[idx][4]\n",
786 | " \n",
787 | " def __iter__(self):\n",
788 | " return (self.indices[i] for i in torch.multinomial(\n",
789 | " self.weights, self.num_samples, replacement=True))\n",
790 | " \n",
791 | " def __len__(self):\n",
792 | " return self.num_samples"
793 | ]
794 | },
795 | {
796 | "cell_type": "markdown",
797 | "metadata": {
798 | "papermill": {
799 | "duration": 0.021281,
800 | "end_time": "2021-05-12T05:41:00.143963",
801 | "exception": false,
802 | "start_time": "2021-05-12T05:41:00.122682",
803 | "status": "completed"
804 | },
805 | "tags": []
806 | },
807 | "source": [
808 | "## Custom dataset for UHCSDB"
809 | ]
810 | },
811 | {
812 | "cell_type": "code",
813 | "execution_count": 11,
814 | "metadata": {
815 | "execution": {
816 | "iopub.execute_input": "2021-05-12T05:41:00.197546Z",
817 | "iopub.status.busy": "2021-05-12T05:41:00.196765Z",
818 | "iopub.status.idle": "2021-05-12T05:41:00.199550Z",
819 | "shell.execute_reply": "2021-05-12T05:41:00.199131Z"
820 | },
821 | "papermill": {
822 | "duration": 0.034447,
823 | "end_time": "2021-05-12T05:41:00.199666",
824 | "exception": false,
825 | "start_time": "2021-05-12T05:41:00.165219",
826 | "status": "completed"
827 | },
828 | "tags": []
829 | },
830 | "outputs": [],
831 | "source": [
832 | "class MicrographDataset(Dataset):\n",
833 | " \"\"\"\n",
834 | " A custom Dataset class for Micrograph data which returns the following\n",
835 | " # Micrograph image\n",
836 | " # Inputs : Anneal Temperature , Anneal Time and Type of cooling used\n",
837 | " ------------------------------------------------------------------------------------\n",
838 | " Attributes\n",
839 | " \n",
840 | " df : pandas.core.frame.DataFrame\n",
841 | " A Dataframe that contains the proper entries (i.e. dataframe corresponding to new_metadata.xlsx)\n",
842 | " root_dir : str\n",
843 | " The path of the folder where the images are located\n",
844 | " transform : torchvision.transforms.transforms.Compose\n",
845 | " The transforms that are to be applied to the loaded images\n",
846 | " \"\"\"\n",
847 | " def __init__(self, df, root_dir, transform=None):\n",
848 | " self.df = df\n",
849 | " self.transform = transform\n",
850 | " self.root_dir = root_dir\n",
851 | " \n",
852 | " def __len__(self):\n",
853 | " return len(self.df) \n",
854 | " \n",
855 | " def __getitem__(self, idx):\n",
856 | " temp_dict = {970: 0, 800: 1, 900: 2, 1100: 3, 1000: 4, 700: 5, 750: 6}\n",
857 | " time_dict = {90: 0, 1440: 1, 180: 2, 5: 3, 480: 4, 5100: 5, 60: 6, 2880: 7}\n",
858 | " microconst_dict = {'spheroidite': 0, 'network' : 1,'spheroidite+widmanstatten' : 2, 'pearlite+spheroidite' : 3,\n",
859 | " 'pearlite' : 4,'pearlite+widmanstatten' : 5}\n",
860 | " cooling_dict = {'Q': 0, 'FC': 1, 'AR': 2, '650-1H': 3}\n",
861 | " row = self.df.loc[idx]\n",
862 | " img_name = row['path']\n",
863 | " img_path = self.root_dir + '/' + 'Cropped' + img_name\n",
864 | " anneal_temp = temp_dict[row['anneal_temperature']]\n",
865 | " if row['anneal_time_unit'] == 'H':\n",
866 | " anneal_time = int(row['anneal_time']) * 60\n",
867 | " else:\n",
868 | " anneal_time = row['anneal_time']\n",
869 | " anneal_time = time_dict[anneal_time]\n",
870 | " cooling_type = cooling_dict[row['cool_method']]\n",
871 | " microconst = microconst_dict[row['primary_microconstituent']]\n",
872 | " img = Image.open(img_path)\n",
873 | " img = img.convert('L')\n",
874 | " if self.transform:\n",
875 | " img = self.transform(img)\n",
876 | " return img , anneal_temp, anneal_time, cooling_type, microconst"
877 | ]
878 | },
879 | {
880 | "cell_type": "markdown",
881 | "metadata": {
882 | "papermill": {
883 | "duration": 0.022176,
884 | "end_time": "2021-05-12T05:41:00.244556",
885 | "exception": false,
886 | "start_time": "2021-05-12T05:41:00.222380",
887 | "status": "completed"
888 | },
889 | "tags": []
890 | },
891 | "source": [
892 |     "## Lightning Module for GAN\n\nBoth networks are trained with the omni-loss, a multilabel categorical cross-entropy: for discriminator logits $s$ with positive-label set $\\Omega_{pos}$ and negative-label set $\\Omega_{neg}$, the loss is $\\log\\big(1 + \\sum_{i \\in \\Omega_{neg}} e^{s_i}\\big) + \\log\\big(1 + \\sum_{j \\in \\Omega_{pos}} e^{-s_j}\\big)$ (see `multilabel_categorical_crossentropy` below)."
893 | ]
894 | },
895 | {
896 | "cell_type": "code",
897 | "execution_count": 12,
898 | "metadata": {
899 | "execution": {
900 | "iopub.execute_input": "2021-05-12T05:41:00.291050Z",
901 | "iopub.status.busy": "2021-05-12T05:41:00.290171Z",
902 | "iopub.status.idle": "2021-05-12T05:41:00.320279Z",
903 | "shell.execute_reply": "2021-05-12T05:41:00.319873Z"
904 | },
905 | "papermill": {
906 | "duration": 0.054443,
907 | "end_time": "2021-05-12T05:41:00.320388",
908 | "exception": false,
909 | "start_time": "2021-05-12T05:41:00.265945",
910 | "status": "completed"
911 | },
912 | "tags": []
913 | },
914 | "outputs": [],
915 | "source": [
916 | "class MicrographBigGAN(pl.LightningModule):\n",
917 | " def __init__(self, root_dir, df_dir, batch_size, augment_bool = True, lr = 0.0002,\n",
918 | " n_classes_temp = 7, n_classes_time = 8, n_classes_cool = 4):\n",
919 | " super().__init__()\n",
920 | " self.save_hyperparameters()\n",
921 | " self.root_dir = root_dir\n",
922 | " self.df_dir = df_dir\n",
923 | " self.generator = Generator(G_ch = 64)\n",
924 | " self.discriminator = Discriminator(D_ch = 64)\n",
925 | " self.diffaugment = DiffAugment()\n",
926 | " self.augment_bool = augment_bool\n",
927 | " self.batch_size = batch_size \n",
928 | " self.lr = lr\n",
929 | " self.n_classes_temp = n_classes_temp\n",
930 | " self.n_classes_time = n_classes_time\n",
931 | " self.n_classes_cool = n_classes_cool\n",
932 | " \n",
933 | " def forward(self, z, y_temp, y_time, y_cool):\n",
934 | " return self.generator(z, y_temp, y_time, y_cool)\n",
935 | " \n",
936 | " def multilabel_categorical_crossentropy(self, y_true, y_pred, margin=0., gamma=1.):\n",
937 | " \"\"\" y_true: positive=1, negative=0, ignore=-1\n",
938 | " \"\"\"\n",
939 | " y_true = y_true.clamp(-1, 1)\n",
940 | " if len(y_pred.shape) > 2:\n",
941 | " y_true = y_true.view(y_true.shape[0], 1, 1, -1)\n",
942 | " _, _, h, w = y_pred.shape\n",
943 | " y_true = y_true.expand(-1, h, w, -1)\n",
944 | " y_pred = y_pred.permute(0, 2, 3, 1)\n",
945 | "\n",
946 | " y_pred = y_pred + margin\n",
947 | " y_pred = y_pred * gamma\n",
948 | "\n",
949 | " y_pred[y_true == 1] = -1 * y_pred[y_true == 1]\n",
950 | " y_pred[y_true == -1] = -1e12\n",
951 | "\n",
952 | " y_pred_neg = y_pred.clone()\n",
953 | " y_pred_neg[y_true == 1] = -1e12\n",
954 | "\n",
955 | " y_pred_pos = y_pred.clone()\n",
956 | " y_pred_pos[y_true == 0] = -1e12\n",
957 | "\n",
958 | " zeros = torch.zeros_like(y_pred[..., :1])\n",
959 | " y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)\n",
960 | " y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)\n",
961 | " neg_loss = torch.logsumexp(y_pred_neg, dim=-1)\n",
962 | " pos_loss = torch.logsumexp(y_pred_pos, dim=-1)\n",
963 | " return neg_loss + pos_loss\n",
964 | " \n",
965 | " def Omni_Dloss(self,disc_real, disc_fake, y_temp_real, y_time_real, y_cool_real):\n",
966 | " b = y_temp_real.shape[0]\n",
967 | " y_temp = F.one_hot(y_temp_real, num_classes = self.n_classes_temp).to(self.device)\n",
968 | " y_time = F.one_hot(y_time_real, num_classes = self.n_classes_time).to(self.device)\n",
969 | " y_cool = F.one_hot(y_cool_real, num_classes = self.n_classes_cool).to(self.device)\n",
970 | " y_real = torch.cat([y_temp,y_time,y_cool,torch.tensor([1,0]).repeat(b,1).to(self.device)],1).float()\n",
971 | " y_real.requires_grad = True\n",
972 | " y_fake = torch.cat([torch.zeros((b, self.n_classes_temp + self.n_classes_time + self.n_classes_cool)),torch.tensor([0,1]).repeat(b,1)],1).float().to(self.device)\n",
973 | " y_fake.requires_grad = True\n",
974 | " d_loss_real = self.multilabel_categorical_crossentropy(y_true = y_real, y_pred = disc_real)\n",
975 | " d_loss_fake = self.multilabel_categorical_crossentropy(y_true = y_fake, y_pred = disc_fake)\n",
976 | " d_loss = d_loss_real.mean() + d_loss_fake.mean()\n",
977 | " return d_loss\n",
978 | " \n",
979 | " def Omni_Gloss(self, disc_fake, y_temp_fake, y_time_fake, y_cool_fake):\n",
980 | " b = y_temp_fake.shape[0]\n",
981 | " y_temp = F.one_hot(y_temp_fake, num_classes = self.n_classes_temp).to(self.device)\n",
982 | " y_time = F.one_hot(y_time_fake, num_classes = self.n_classes_time).to(self.device)\n",
983 | " y_cool = F.one_hot(y_cool_fake, num_classes = self.n_classes_cool).to(self.device)\n",
984 | " y_fake_g = torch.cat([y_temp,y_time,y_cool,torch.tensor([1,0]).repeat(b,1).to(self.device)],1).float()\n",
985 | " y_fake_g.requires_grad = True\n",
986 | " g_loss = self.multilabel_categorical_crossentropy(y_true = y_fake_g, y_pred = disc_fake)\n",
987 | " return g_loss.mean()\n",
988 | " \n",
989 | " def training_step(self, batch, batch_idx, optimizer_idx):\n",
990 | " real, y_temp_real, y_time_real, y_cool_real , _ = batch\n",
991 | " z = torch.randn(real.shape[0], 384)\n",
992 | " z = z.type_as(real)\n",
993 | " y_temp_fake = torch.randint(self.n_classes_temp,(real.shape[0],)).to(self.device)\n",
994 | " y_time_fake = torch.randint(self.n_classes_time,(real.shape[0],)).to(self.device)\n",
995 | " y_cool_fake = torch.randint(self.n_classes_cool,(real.shape[0],)).to(self.device)\n",
996 | " \n",
997 | " if optimizer_idx == 0:\n",
998 | " fake = self(z, y_temp_fake, y_time_fake, y_cool_fake)\n",
999 | " \n",
1000 | " if self.augment_bool:\n",
1001 | " disc_real = self.discriminator(self.diffaugment(real))\n",
1002 | " disc_fake = self.discriminator(self.diffaugment(fake))\n",
1003 | " else:\n",
1004 | " disc_real = self.discriminator(real)\n",
1005 | " disc_fake = self.discriminator(fake)\n",
1006 | " d_loss = self.Omni_Dloss(disc_real, disc_fake, y_temp_real, y_time_real, y_cool_real)\n",
1007 | " #print(d_loss.requires_grad)\n",
1008 | " tqdm_dict = {'d_loss': d_loss}\n",
1009 | " output = OrderedDict({\n",
1010 | " 'loss': d_loss,\n",
1011 | " 'progress_bar': tqdm_dict,\n",
1012 | " 'log': tqdm_dict\n",
1013 | " })\n",
1014 | " return output\n",
1015 | " \n",
1016 | " if optimizer_idx == 1:\n",
1017 | " fake = self(z,y_temp_fake, y_time_fake, y_cool_fake)\n",
1018 | " if self.augment_bool:\n",
1019 | " disc_fake = self.discriminator(self.diffaugment(fake))\n",
1020 | " else:\n",
1021 | " disc_fake = self.discriminator(fake)\n",
1022 | " g_loss = self.Omni_Gloss(disc_fake, y_temp_fake, y_time_fake, y_cool_fake)\n",
1023 | " #print(g_loss.requires_grad)\n",
1024 | " tqdm_dict = {'g_loss': g_loss}\n",
1025 | " output = OrderedDict({\n",
1026 | " 'loss': g_loss,\n",
1027 | " 'progress_bar': tqdm_dict,\n",
1028 | " 'log': tqdm_dict\n",
1029 | " })\n",
1030 | " return output\n",
1031 | " \n",
1032 | " def configure_optimizers(self):\n",
1033 | " opt_g = optim.Adam(self.generator.parameters(), lr = self.lr, weight_decay = 0.001)\n",
1034 | " opt_d = optim.Adam(self.discriminator.parameters(), lr = self.lr, weight_decay = 0.0005)\n",
1035 | " return (\n",
1036 | " {'optimizer': opt_d, 'frequency': 4},\n",
1037 | " {'optimizer': opt_g, 'frequency': 1}\n",
1038 | " )\n",
1039 | " \n",
1040 | " def train_dataloader(self):\n",
1041 | " img_transforms = transforms.Compose([\n",
1042 | " transforms.RandomCrop(256),\n",
1043 | " transforms.RandomHorizontalFlip(p=0.5),\n",
1044 | " transforms.RandomVerticalFlip(p=0.5),\n",
1045 | " transforms.Resize([256, 256]),\n",
1046 | " transforms.ToTensor(),\n",
1047 | " transforms.Normalize([0.5 for _ in range(1)],[0.5 for _ in range(1)]),\n",
1048 | " ])\n",
1049 | " df = pd.read_excel(self.df_dir,engine = 'openpyxl')\n",
1050 | " dataset = MicrographDataset(df,self.root_dir,transform = img_transforms)\n",
1051 | " return DataLoader(dataset, sampler = ImbalancedDatasetSampler(dataset), \n",
1052 | " batch_size=self.batch_size,shuffle=False)"
1053 | ]
1054 | },
1055 | {
1056 | "cell_type": "code",
1057 | "execution_count": 13,
1058 | "metadata": {
1059 | "execution": {
1060 | "iopub.execute_input": "2021-05-12T05:41:00.367953Z",
1061 | "iopub.status.busy": "2021-05-12T05:41:00.367205Z",
1062 | "iopub.status.idle": "2021-05-12T05:41:00.369515Z",
1063 | "shell.execute_reply": "2021-05-12T05:41:00.369901Z"
1064 | },
1065 | "papermill": {
1066 | "duration": 0.027684,
1067 | "end_time": "2021-05-12T05:41:00.370032",
1068 | "exception": false,
1069 | "start_time": "2021-05-12T05:41:00.342348",
1070 | "status": "completed"
1071 | },
1072 | "tags": []
1073 | },
1074 | "outputs": [],
1075 | "source": [
1076 | "ROOT_DIR = '../input/highcarbon-micrographs/For Training/Cropped'\n",
1077 | "DF_DIR = '../input/highcarbon-micrographs/new_metadata.xlsx'"
1078 | ]
1079 | },
1080 | {
1081 | "cell_type": "code",
1082 | "execution_count": 14,
1083 | "metadata": {
1084 | "execution": {
1085 | "iopub.execute_input": "2021-05-12T05:41:00.417973Z",
1086 | "iopub.status.busy": "2021-05-12T05:41:00.417450Z",
1087 | "iopub.status.idle": "2021-05-12T05:41:01.386394Z",
1088 | "shell.execute_reply": "2021-05-12T05:41:01.386953Z"
1089 | },
1090 | "papermill": {
1091 | "duration": 0.995368,
1092 | "end_time": "2021-05-12T05:41:01.387148",
1093 | "exception": false,
1094 | "start_time": "2021-05-12T05:41:00.391780",
1095 | "status": "completed"
1096 | },
1097 | "tags": []
1098 | },
1099 | "outputs": [
1100 | {
1101 | "name": "stdout",
1102 | "output_type": "stream",
1103 | "text": [
1104 | "Weight initialization : N02\n",
1105 | "Param count for G's initialized parameters: 39 Million\n",
1106 | "Weight initialization : N02\n",
1107 | "Param count for D's initialized parameters: 9 Million\n",
1108 | "Diff. Augment Policy : color,translation,cutout\n"
1109 | ]
1110 | }
1111 | ],
1112 | "source": [
1113 | "gan = MicrographBigGAN(ROOT_DIR,DF_DIR,batch_size=12)"
1114 | ]
1115 | },
1116 | {
1117 | "cell_type": "code",
1118 | "execution_count": 15,
1119 | "metadata": {
1120 | "execution": {
1121 | "iopub.execute_input": "2021-05-12T05:41:01.507288Z",
1122 | "iopub.status.busy": "2021-05-12T05:41:01.506408Z",
1123 | "iopub.status.idle": "2021-05-12T05:41:01.510356Z",
1124 | "shell.execute_reply": "2021-05-12T05:41:01.510902Z"
1125 | },
1126 | "papermill": {
1127 | "duration": 0.100204,
1128 | "end_time": "2021-05-12T05:41:01.511052",
1129 | "exception": false,
1130 | "start_time": "2021-05-12T05:41:01.410848",
1131 | "status": "completed"
1132 | },
1133 | "tags": []
1134 | },
1135 | "outputs": [],
1136 | "source": [
1137 | "trainer = pl.Trainer(max_epochs=600, gpus=1 if torch.cuda.is_available() else 0, accumulate_grad_batches=8)"
1138 | ]
1139 | },
1140 | {
1141 | "cell_type": "code",
1142 | "execution_count": 16,
1143 | "metadata": {
1144 | "execution": {
1145 | "iopub.execute_input": "2021-05-12T05:41:01.569751Z",
1146 | "iopub.status.busy": "2021-05-12T05:41:01.569230Z",
1147 | "iopub.status.idle": "2021-05-12T13:02:06.697914Z",
1148 | "shell.execute_reply": "2021-05-12T13:02:06.698424Z"
1149 | },
1150 | "papermill": {
1151 | "duration": 26465.164828,
1152 | "end_time": "2021-05-12T13:02:06.698628",
1153 | "exception": false,
1154 | "start_time": "2021-05-12T05:41:01.533800",
1155 | "status": "completed"
1156 | },
1157 | "tags": []
1158 | },
1159 | "outputs": [
1160 | {
1161 | "data": {
1162 | "application/vnd.jupyter.widget-view+json": {
1163 | "model_id": "b42a2f5cf39d48c385105398a2096934",
1164 | "version_major": 2,
1165 | "version_minor": 0
1166 | },
1167 | "text/plain": [
1168 | "Training: 0it [00:00, ?it/s]"
1169 | ]
1170 | },
1171 | "metadata": {},
1172 | "output_type": "display_data"
1173 | },
1174 | {
1175 | "data": {
1176 | "text/plain": [
1177 | "1"
1178 | ]
1179 | },
1180 | "execution_count": 16,
1181 | "metadata": {},
1182 | "output_type": "execute_result"
1183 | }
1184 | ],
1185 | "source": [
1186 | "trainer.fit(gan)"
1187 | ]
1188 | },
1189 | {
1190 | "cell_type": "code",
1191 | "execution_count": 17,
1192 | "metadata": {
1193 | "execution": {
1194 | "iopub.execute_input": "2021-05-12T13:02:06.795159Z",
1195 | "iopub.status.busy": "2021-05-12T13:02:06.750611Z",
1196 | "iopub.status.idle": "2021-05-12T13:02:08.948957Z",
1197 | "shell.execute_reply": "2021-05-12T13:02:08.949777Z"
1198 | },
1199 | "papermill": {
1200 | "duration": 2.226897,
1201 | "end_time": "2021-05-12T13:02:08.950022",
1202 | "exception": false,
1203 | "start_time": "2021-05-12T13:02:06.723125",
1204 | "status": "completed"
1205 | },
1206 | "tags": []
1207 | },
1208 | "outputs": [],
1209 | "source": [
1210 | "trainer.save_checkpoint(\"MicroGAN_checkpoint.ckpt\")"
1211 | ]
1212 | },
1213 | {
1214 | "cell_type": "markdown",
1215 | "metadata": {
1216 | "papermill": {
1217 | "duration": 0.040859,
1218 | "end_time": "2021-05-12T13:02:09.052123",
1219 | "exception": false,
1220 | "start_time": "2021-05-12T13:02:09.011264",
1221 | "status": "completed"
1222 | },
1223 | "tags": []
1224 | },
1225 | "source": [
1226 | "## For resuming the training \n",
1227 | "Add the saved checkpoint to the highcarbon-micrographs directory and then resume training. "
1228 | ]
1229 | },
1230 | {
1231 | "cell_type": "markdown",
1232 | "metadata": {
1233 | "papermill": {
1234 | "duration": 0.031489,
1235 | "end_time": "2021-05-12T13:02:09.132971",
1236 | "exception": false,
1237 | "start_time": "2021-05-12T13:02:09.101482",
1238 | "status": "completed"
1239 | },
1240 | "tags": []
1241 | },
1242 | "source": [
1243 | "new_gan = MicrographBigGAN.load_from_checkpoint(checkpoint_path=\"../input/microstrcuture-biggan-gpu/MicroGAN_checkpoint.ckpt\")"
1244 | ]
1245 | },
1246 | {
1247 | "cell_type": "markdown",
1248 | "metadata": {
1249 | "papermill": {
1250 | "duration": 0.022949,
1251 | "end_time": "2021-05-12T13:02:09.180732",
1252 | "exception": false,
1253 | "start_time": "2021-05-12T13:02:09.157783",
1254 | "status": "completed"
1255 | },
1256 | "tags": []
1257 | },
1258 | "source": [
1259 | "torch.save(new_gan.generator.state_dict(), 'BigGAN-deep.pth')"
1260 | ]
1261 | },
1262 | {
1263 | "cell_type": "markdown",
1264 | "metadata": {
1265 | "papermill": {
1266 | "duration": 0.023118,
1267 | "end_time": "2021-05-12T13:02:09.227021",
1268 | "exception": false,
1269 | "start_time": "2021-05-12T13:02:09.203903",
1270 | "status": "completed"
1271 | },
1272 | "tags": []
1273 | },
1274 | "source": [
1275 | "trainer = Trainer(resume_from_checkpoint='../input/highcarbon-micrographs/MicroGAN_checkpoint.ckpt')"
1276 | ]
1277 | },
1278 | {
1279 | "cell_type": "markdown",
1280 | "metadata": {
1281 | "papermill": {
1282 | "duration": 0.023173,
1283 | "end_time": "2021-05-12T13:02:09.274088",
1284 | "exception": false,
1285 | "start_time": "2021-05-12T13:02:09.250915",
1286 | "status": "completed"
1287 | },
1288 | "tags": []
1289 | },
1290 | "source": [
1291 | "trainer.fit(gan)"
1292 | ]
1293 | },
1294 | {
1295 | "cell_type": "code",
1296 | "execution_count": 18,
1297 | "metadata": {
1298 | "execution": {
1299 | "iopub.execute_input": "2021-05-12T13:02:09.325209Z",
1300 | "iopub.status.busy": "2021-05-12T13:02:09.324445Z",
1301 | "iopub.status.idle": "2021-05-12T13:02:09.593620Z",
1302 | "shell.execute_reply": "2021-05-12T13:02:09.594400Z"
1303 | },
1304 | "papermill": {
1305 | "duration": 0.297594,
1306 | "end_time": "2021-05-12T13:02:09.594673",
1307 | "exception": false,
1308 | "start_time": "2021-05-12T13:02:09.297079",
1309 | "status": "completed"
1310 | },
1311 | "tags": []
1312 | },
1313 | "outputs": [],
1314 | "source": [
1315 | "torch.save(gan.generator.state_dict(), 'BigGAN-deep.pth')"
1316 | ]
1317 | },
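1318 | {
1319 | "cell_type": "markdown",
1320 | "metadata": {},
1321 | "source": [
1322 | "## Sampling from the Trained Generator\n",
1323 | "A minimal sketch of drawing one conditioned micrograph, assuming the 384-dim latent used in `training_step` and the integer label encodings from `MicrographDataset.__getitem__`."
1324 | ]
1325 | },
1326 | {
1327 | "cell_type": "code",
1328 | "execution_count": null,
1329 | "metadata": {},
1330 | "outputs": [],
1331 | "source": [
1332 | "# Minimal sampling sketch: the class indices below follow the dicts defined in\n",
1333 | "# MicrographDataset.__getitem__ (900 C -> 2, 90 min -> 0, 'FC' -> 1).\n",
1334 | "gen = gan.generator.cpu().eval()  # move to CPU so it matches the tensors below\n",
1335 | "with torch.no_grad():\n",
1336 | "    z = torch.randn(1, 384)       # latent, same dimensionality as in training_step\n",
1337 | "    y_temp = torch.tensor([2])    # 900 C\n",
1338 | "    y_time = torch.tensor([0])    # 90 min\n",
1339 | "    y_cool = torch.tensor([1])    # FC\n",
1340 | "    img = gen(z, y_temp, y_time, y_cool)\n",
1341 | "img = (img.clamp(-1, 1) + 1) / 2  # undo Normalize([0.5], [0.5])\n",
1342 | "transforms.ToPILImage()(img[0])"
1343 | ]
1344 | },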
1318 | {
1319 | "cell_type": "code",
1320 | "execution_count": null,
1321 | "metadata": {
1322 | "papermill": {
1323 | "duration": 0.040898,
1324 | "end_time": "2021-05-12T13:02:09.695758",
1325 | "exception": false,
1326 | "start_time": "2021-05-12T13:02:09.654860",
1327 | "status": "completed"
1328 | },
1329 | "tags": []
1330 | },
1331 | "outputs": [],
1332 | "source": []
1333 | }
1334 | ],
1335 | "metadata": {
1336 | "kernelspec": {
1337 | "display_name": "Python 3",
1338 | "language": "python",
1339 | "name": "python3"
1340 | },
1341 | "language_info": {
1342 | "codemirror_mode": {
1343 | "name": "ipython",
1344 | "version": 3
1345 | },
1346 | "file_extension": ".py",
1347 | "mimetype": "text/x-python",
1348 | "name": "python",
1349 | "nbconvert_exporter": "python",
1350 | "pygments_lexer": "ipython3",
1351 | "version": "3.8.3"
1352 | },
1353 | "papermill": {
1354 | "default_parameters": {},
1355 | "duration": 26490.991864,
1356 | "end_time": "2021-05-12T13:02:12.461219",
1357 | "environment_variables": {},
1358 | "exception": null,
1359 | "input_path": "__notebook__.ipynb",
1360 | "output_path": "__notebook__.ipynb",
1361 | "parameters": {},
1362 | "start_time": "2021-05-12T05:40:41.469355",
1363 | "version": "2.3.3"
1364 | },
1365 | "widgets": {
1366 | "application/vnd.jupyter.widget-state+json": {
1367 | "state": {
1368 | "03955165e0e645319087ded469df198a": {
1369 | "model_module": "@jupyter-widgets/controls",
1370 | "model_module_version": "1.5.0",
1371 | "model_name": "ProgressStyleModel",
1372 | "state": {
1373 | "_model_module": "@jupyter-widgets/controls",
1374 | "_model_module_version": "1.5.0",
1375 | "_model_name": "ProgressStyleModel",
1376 | "_view_count": null,
1377 | "_view_module": "@jupyter-widgets/base",
1378 | "_view_module_version": "1.2.0",
1379 | "_view_name": "StyleView",
1380 | "bar_color": null,
1381 | "description_width": ""
1382 | }
1383 | },
1384 | "0a89a9c11093488180b029036a173df3": {
1385 | "model_module": "@jupyter-widgets/base",
1386 | "model_module_version": "1.2.0",
1387 | "model_name": "LayoutModel",
1388 | "state": {
1389 | "_model_module": "@jupyter-widgets/base",
1390 | "_model_module_version": "1.2.0",
1391 | "_model_name": "LayoutModel",
1392 | "_view_count": null,
1393 | "_view_module": "@jupyter-widgets/base",
1394 | "_view_module_version": "1.2.0",
1395 | "_view_name": "LayoutView",
1396 | "align_content": null,
1397 | "align_items": null,
1398 | "align_self": null,
1399 | "border": null,
1400 | "bottom": null,
1401 | "display": null,
1402 | "flex": null,
1403 | "flex_flow": null,
1404 | "grid_area": null,
1405 | "grid_auto_columns": null,
1406 | "grid_auto_flow": null,
1407 | "grid_auto_rows": null,
1408 | "grid_column": null,
1409 | "grid_gap": null,
1410 | "grid_row": null,
1411 | "grid_template_areas": null,
1412 | "grid_template_columns": null,
1413 | "grid_template_rows": null,
1414 | "height": null,
1415 | "justify_content": null,
1416 | "justify_items": null,
1417 | "left": null,
1418 | "margin": null,
1419 | "max_height": null,
1420 | "max_width": null,
1421 | "min_height": null,
1422 | "min_width": null,
1423 | "object_fit": null,
1424 | "object_position": null,
1425 | "order": null,
1426 | "overflow": null,
1427 | "overflow_x": null,
1428 | "overflow_y": null,
1429 | "padding": null,
1430 | "right": null,
1431 | "top": null,
1432 | "visibility": null,
1433 | "width": null
1434 | }
1435 | },
1436 | "207fc2d55d314763a543f4084c038f13": {
1437 | "model_module": "@jupyter-widgets/controls",
1438 | "model_module_version": "1.5.0",
1439 | "model_name": "HTMLModel",
1440 | "state": {
1441 | "_dom_classes": [],
1442 | "_model_module": "@jupyter-widgets/controls",
1443 | "_model_module_version": "1.5.0",
1444 | "_model_name": "HTMLModel",
1445 | "_view_count": null,
1446 | "_view_module": "@jupyter-widgets/controls",
1447 | "_view_module_version": "1.5.0",
1448 | "_view_name": "HTMLView",
1449 | "description": "",
1450 | "description_tooltip": null,
1451 | "layout": "IPY_MODEL_0a89a9c11093488180b029036a173df3",
1452 | "placeholder": "",
1453 | "style": "IPY_MODEL_91beed9008f34128ad4e23f0842758c7",
1454 | "value": " 50/50 [00:43<00:00, 1.14it/s, loss=5.77, v_num=0, d_loss=5.690, g_loss=6.650]"
1455 | }
1456 | },
1457 | "33fc35e21b5c4bacb6fbc4ab0aff8ce1": {
1458 | "model_module": "@jupyter-widgets/base",
1459 | "model_module_version": "1.2.0",
1460 | "model_name": "LayoutModel",
1461 | "state": {
1462 | "_model_module": "@jupyter-widgets/base",
1463 | "_model_module_version": "1.2.0",
1464 | "_model_name": "LayoutModel",
1465 | "_view_count": null,
1466 | "_view_module": "@jupyter-widgets/base",
1467 | "_view_module_version": "1.2.0",
1468 | "_view_name": "LayoutView",
1469 | "align_content": null,
1470 | "align_items": null,
1471 | "align_self": null,
1472 | "border": null,
1473 | "bottom": null,
1474 | "display": "inline-flex",
1475 | "flex": null,
1476 | "flex_flow": "row wrap",
1477 | "grid_area": null,
1478 | "grid_auto_columns": null,
1479 | "grid_auto_flow": null,
1480 | "grid_auto_rows": null,
1481 | "grid_column": null,
1482 | "grid_gap": null,
1483 | "grid_row": null,
1484 | "grid_template_areas": null,
1485 | "grid_template_columns": null,
1486 | "grid_template_rows": null,
1487 | "height": null,
1488 | "justify_content": null,
1489 | "justify_items": null,
1490 | "left": null,
1491 | "margin": null,
1492 | "max_height": null,
1493 | "max_width": null,
1494 | "min_height": null,
1495 | "min_width": null,
1496 | "object_fit": null,
1497 | "object_position": null,
1498 | "order": null,
1499 | "overflow": null,
1500 | "overflow_x": null,
1501 | "overflow_y": null,
1502 | "padding": null,
1503 | "right": null,
1504 | "top": null,
1505 | "visibility": null,
1506 | "width": "100%"
1507 | }
1508 | },
1509 | "4cc96cb6859640239a3f434bf7cc5402": {
1510 | "model_module": "@jupyter-widgets/controls",
1511 | "model_module_version": "1.5.0",
1512 | "model_name": "DescriptionStyleModel",
1513 | "state": {
1514 | "_model_module": "@jupyter-widgets/controls",
1515 | "_model_module_version": "1.5.0",
1516 | "_model_name": "DescriptionStyleModel",
1517 | "_view_count": null,
1518 | "_view_module": "@jupyter-widgets/base",
1519 | "_view_module_version": "1.2.0",
1520 | "_view_name": "StyleView",
1521 | "description_width": ""
1522 | }
1523 | },
1524 | "885f59260c8d4342a6c103b02963fa79": {
1525 | "model_module": "@jupyter-widgets/controls",
1526 | "model_module_version": "1.5.0",
1527 | "model_name": "FloatProgressModel",
1528 | "state": {
1529 | "_dom_classes": [],
1530 | "_model_module": "@jupyter-widgets/controls",
1531 | "_model_module_version": "1.5.0",
1532 | "_model_name": "FloatProgressModel",
1533 | "_view_count": null,
1534 | "_view_module": "@jupyter-widgets/controls",
1535 | "_view_module_version": "1.5.0",
1536 | "_view_name": "ProgressView",
1537 | "bar_style": "success",
1538 | "description": "",
1539 | "description_tooltip": null,
1540 | "layout": "IPY_MODEL_8f7728ba242f48b2b3aadf3a833eb75a",
1541 | "max": 50,
1542 | "min": 0,
1543 | "orientation": "horizontal",
1544 | "style": "IPY_MODEL_03955165e0e645319087ded469df198a",
1545 | "value": 50
1546 | }
1547 | },
1548 | "8f7728ba242f48b2b3aadf3a833eb75a": {
1549 | "model_module": "@jupyter-widgets/base",
1550 | "model_module_version": "1.2.0",
1551 | "model_name": "LayoutModel",
1552 | "state": {
1553 | "_model_module": "@jupyter-widgets/base",
1554 | "_model_module_version": "1.2.0",
1555 | "_model_name": "LayoutModel",
1556 | "_view_count": null,
1557 | "_view_module": "@jupyter-widgets/base",
1558 | "_view_module_version": "1.2.0",
1559 | "_view_name": "LayoutView",
1560 | "align_content": null,
1561 | "align_items": null,
1562 | "align_self": null,
1563 | "border": null,
1564 | "bottom": null,
1565 | "display": null,
1566 | "flex": "2",
1567 | "flex_flow": null,
1568 | "grid_area": null,
1569 | "grid_auto_columns": null,
1570 | "grid_auto_flow": null,
1571 | "grid_auto_rows": null,
1572 | "grid_column": null,
1573 | "grid_gap": null,
1574 | "grid_row": null,
1575 | "grid_template_areas": null,
1576 | "grid_template_columns": null,
1577 | "grid_template_rows": null,
1578 | "height": null,
1579 | "justify_content": null,
1580 | "justify_items": null,
1581 | "left": null,
1582 | "margin": null,
1583 | "max_height": null,
1584 | "max_width": null,
1585 | "min_height": null,
1586 | "min_width": null,
1587 | "object_fit": null,
1588 | "object_position": null,
1589 | "order": null,
1590 | "overflow": null,
1591 | "overflow_x": null,
1592 | "overflow_y": null,
1593 | "padding": null,
1594 | "right": null,
1595 | "top": null,
1596 | "visibility": null,
1597 | "width": null
1598 | }
1599 | },
1600 | "91beed9008f34128ad4e23f0842758c7": {
1601 | "model_module": "@jupyter-widgets/controls",
1602 | "model_module_version": "1.5.0",
1603 | "model_name": "DescriptionStyleModel",
1604 | "state": {
1605 | "_model_module": "@jupyter-widgets/controls",
1606 | "_model_module_version": "1.5.0",
1607 | "_model_name": "DescriptionStyleModel",
1608 | "_view_count": null,
1609 | "_view_module": "@jupyter-widgets/base",
1610 | "_view_module_version": "1.2.0",
1611 | "_view_name": "StyleView",
1612 | "description_width": ""
1613 | }
1614 | },
1615 | "b42a2f5cf39d48c385105398a2096934": {
1616 | "model_module": "@jupyter-widgets/controls",
1617 | "model_module_version": "1.5.0",
1618 | "model_name": "HBoxModel",
1619 | "state": {
1620 | "_dom_classes": [],
1621 | "_model_module": "@jupyter-widgets/controls",
1622 | "_model_module_version": "1.5.0",
1623 | "_model_name": "HBoxModel",
1624 | "_view_count": null,
1625 | "_view_module": "@jupyter-widgets/controls",
1626 | "_view_module_version": "1.5.0",
1627 | "_view_name": "HBoxView",
1628 | "box_style": "",
1629 | "children": [
1630 | "IPY_MODEL_c296a4a6337e482591115a2cae9bcab9",
1631 | "IPY_MODEL_885f59260c8d4342a6c103b02963fa79",
1632 | "IPY_MODEL_207fc2d55d314763a543f4084c038f13"
1633 | ],
1634 | "layout": "IPY_MODEL_33fc35e21b5c4bacb6fbc4ab0aff8ce1"
1635 | }
1636 | },
1637 | "c296a4a6337e482591115a2cae9bcab9": {
1638 | "model_module": "@jupyter-widgets/controls",
1639 | "model_module_version": "1.5.0",
1640 | "model_name": "HTMLModel",
1641 | "state": {
1642 | "_dom_classes": [],
1643 | "_model_module": "@jupyter-widgets/controls",
1644 | "_model_module_version": "1.5.0",
1645 | "_model_name": "HTMLModel",
1646 | "_view_count": null,
1647 | "_view_module": "@jupyter-widgets/controls",
1648 | "_view_module_version": "1.5.0",
1649 | "_view_name": "HTMLView",
1650 | "description": "",
1651 | "description_tooltip": null,
1652 | "layout": "IPY_MODEL_c36574edba584e488690901249d15ddc",
1653 | "placeholder": "",
1654 | "style": "IPY_MODEL_4cc96cb6859640239a3f434bf7cc5402",
1655 | "value": "Epoch 599: 100%"
1656 | }
1657 | },
1658 | "c36574edba584e488690901249d15ddc": {
1659 | "model_module": "@jupyter-widgets/base",
1660 | "model_module_version": "1.2.0",
1661 | "model_name": "LayoutModel",
1662 | "state": {
1663 | "_model_module": "@jupyter-widgets/base",
1664 | "_model_module_version": "1.2.0",
1665 | "_model_name": "LayoutModel",
1666 | "_view_count": null,
1667 | "_view_module": "@jupyter-widgets/base",
1668 | "_view_module_version": "1.2.0",
1669 | "_view_name": "LayoutView",
1670 | "align_content": null,
1671 | "align_items": null,
1672 | "align_self": null,
1673 | "border": null,
1674 | "bottom": null,
1675 | "display": null,
1676 | "flex": null,
1677 | "flex_flow": null,
1678 | "grid_area": null,
1679 | "grid_auto_columns": null,
1680 | "grid_auto_flow": null,
1681 | "grid_auto_rows": null,
1682 | "grid_column": null,
1683 | "grid_gap": null,
1684 | "grid_row": null,
1685 | "grid_template_areas": null,
1686 | "grid_template_columns": null,
1687 | "grid_template_rows": null,
1688 | "height": null,
1689 | "justify_content": null,
1690 | "justify_items": null,
1691 | "left": null,
1692 | "margin": null,
1693 | "max_height": null,
1694 | "max_width": null,
1695 | "min_height": null,
1696 | "min_width": null,
1697 | "object_fit": null,
1698 | "object_position": null,
1699 | "order": null,
1700 | "overflow": null,
1701 | "overflow_x": null,
1702 | "overflow_y": null,
1703 | "padding": null,
1704 | "right": null,
1705 | "top": null,
1706 | "visibility": null,
1707 | "width": null
1708 | }
1709 | }
1710 | },
1711 | "version_major": 2,
1712 | "version_minor": 0
1713 | }
1714 | }
1715 | },
1716 | "nbformat": 4,
1717 | "nbformat_minor": 5
1718 | }
1719 |
--------------------------------------------------------------------------------