Most of the codebase comes from Fiery.

Install

pip install eclipse_pytorch

How to use

Let's simulate some input images:
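
A minimal sketch, assuming the `Eclipse` model exported from `nbs/00_model.ipynb` with its default arguments (3-channel inputs, 4 mask classes, 128x128 images, horizon of 5 future steps):

import torch

from eclipse_pytorch.model import Eclipse

# a batch of 2 sequences of 4 simulated sky images, 3 x 128 x 128 each
images = [torch.rand(2, 3, 128, 128) for _ in range(4)]

# horizon = number of future steps to forecast
eclipse = Eclipse(horizon=5)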
Calling the model on the list of images, you get a dict with forecasted masks and irradiances:
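
Continuing the sketch above; the tuple below is the output recorded in the original notebook:

preds = eclipse(images)

len(preds['masks']), preds['masks'][0].shape, preds['irradiances'].shape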
(6, torch.Size([2, 4, 128, 128]), torch.Size([2, 6]))
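
With a horizon of 5 the model concatenates the present state with the 5 predicted future states, so you get 6 masks of shape (batch, n_out, h, w) = (2, 4, 128, 128) and one irradiance value per step, i.e. an irradiance tensor of shape (2, 6).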

Citation

@article{paletta2021eclipse,
  title = {{ECLIPSE} : Envisioning Cloud Induced Perturbations in Solar Energy},
  author = {Quentin Paletta and Anthony Hu and Guillaume Arbod and Joan Lasenby},
  year = {2021},
  eprinttype = {arXiv},
  eprint = {2104.12419}
}
Contribute

This repo is made with nbdev; please read the nbdev documentation to learn how to contribute.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/nbs/00_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# default_exp model"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# The model\n",
17 | "\n",
18 | "> from the paper from [Paletta et al](https://arxiv.org/pdf/2104.12419v1.pdf)"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
    25 |     "We will try to implement the architecture from the paper `ECLIPSE : Envisioning Cloud Induced Perturbations in Solar Energy` as closely as possible."
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | ""
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "#export\n",
42 | "from eclipse_pytorch.imports import *\n",
43 | "from eclipse_pytorch.layers import *"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "## 1. Spatial Downsampler\n",
51 | "> A resnet encoder to get image features"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {},
57 | "source": [
    58 |     "You can use any spatial downsampler you want, but the paper specifies a simple resnet architecture..."
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "#export\n",
68 | "class SpatialDownsampler(nn.Module):\n",
69 | " \n",
70 | " def __init__(self, in_channels=3):\n",
71 | " super().__init__()\n",
72 | " self.conv1 = ConvBlock(in_channels, 64, kernel_size=7, stride=1)\n",
73 | " self.blocks = nn.Sequential(ResBlock(64, 64, kernel_size=3, stride=2), \n",
74 | " ResBlock(64, 128, kernel_size=3, stride=2), \n",
75 | " ResBlock(128,256, kernel_size=3, stride=2))\n",
76 | " \n",
77 | " def forward(self, x):\n",
78 | " return self.blocks(self.conv1(x))"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "sd = SpatialDownsampler()"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [
95 | {
96 | "data": {
97 | "text/plain": [
98 | "torch.Size([1, 4, 256, 8, 8])"
99 | ]
100 | },
101 | "execution_count": null,
102 | "metadata": {},
103 | "output_type": "execute_result"
104 | }
105 | ],
106 | "source": [
107 | "images = [torch.rand(1, 3, 64, 64) for _ in range(4)]\n",
108 | "features = torch.stack([sd(image) for image in images], dim=1)\n",
109 | "features.shape"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "## 2. Temporal Encoder"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "#export\n",
126 | "class TemporalModel(nn.Module):\n",
127 | " def __init__(\n",
128 | " self, in_channels, receptive_field, input_shape, start_out_channels=64, extra_in_channels=0,\n",
129 | " n_spatial_layers_between_temporal_layers=0, use_pyramid_pooling=True):\n",
130 | " super().__init__()\n",
131 | " self.receptive_field = receptive_field\n",
132 | " n_temporal_layers = receptive_field - 1\n",
133 | "\n",
134 | " h, w = input_shape\n",
135 | " modules = []\n",
136 | "\n",
137 | " block_in_channels = in_channels\n",
138 | " block_out_channels = start_out_channels\n",
139 | "\n",
140 | " for _ in range(n_temporal_layers):\n",
   141 |     "            pool_sizes = [(2, h, w)] if use_pyramid_pooling else None\n",
147 | " temporal = TemporalBlock(\n",
148 | " block_in_channels,\n",
149 | " block_out_channels,\n",
150 | " use_pyramid_pooling=use_pyramid_pooling,\n",
151 | " pool_sizes=pool_sizes,\n",
152 | " )\n",
153 | " spatial = [\n",
154 | " Bottleneck3D(block_out_channels, block_out_channels, kernel_size=(1, 3, 3))\n",
155 | " for _ in range(n_spatial_layers_between_temporal_layers)\n",
156 | " ]\n",
157 | " temporal_spatial_layers = nn.Sequential(temporal, *spatial)\n",
158 | " modules.extend(temporal_spatial_layers)\n",
159 | "\n",
160 | " block_in_channels = block_out_channels\n",
161 | " block_out_channels += extra_in_channels\n",
162 | "\n",
163 | " self.out_channels = block_in_channels\n",
164 | "\n",
165 | " self.model = nn.Sequential(*modules)\n",
166 | "\n",
167 | " def forward(self, x):\n",
168 | " # Reshape input tensor to (batch, C, time, H, W)\n",
169 | " x = x.permute(0, 2, 1, 3, 4)\n",
170 | " x = self.model(x)\n",
171 | " x = x.permute(0, 2, 1, 3, 4).contiguous()\n",
172 | " return x[:, (self.receptive_field - 1):]"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "metadata": {},
179 | "outputs": [],
180 | "source": [
181 | "tm = TemporalModel(256, 3, (8,8), 128)"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "metadata": {},
188 | "outputs": [
189 | {
190 | "data": {
191 | "text/plain": [
192 | "torch.Size([1, 2, 128, 8, 8])"
193 | ]
194 | },
195 | "execution_count": null,
196 | "metadata": {},
197 | "output_type": "execute_result"
198 | }
199 | ],
200 | "source": [
201 | "temp_encoded = tm(features)\n",
202 | "temp_encoded.shape"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {},
208 | "source": [
209 | "## 3. Future State Predictions"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "fp = FuturePrediction(128, 128, n_gru_blocks=4, n_res_layers=4)"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "metadata": {},
225 | "outputs": [
226 | {
227 | "data": {
228 | "text/plain": [
229 | "torch.Size([1, 4, 128, 8, 8])"
230 | ]
231 | },
232 | "execution_count": null,
233 | "metadata": {},
234 | "output_type": "execute_result"
235 | }
236 | ],
237 | "source": [
238 | "hidden = torch.rand(1, 128, 8, 8)\n",
239 | "x = torch.rand(1, 4, 128, 8, 8)\n",
240 | "fp(x, hidden).shape"
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "metadata": {},
246 | "source": [
247 | "## 4A. Segmentation Decoder"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": null,
253 | "metadata": {},
254 | "outputs": [
255 | {
256 | "data": {
257 | "text/plain": [
258 | "Bottleneck(\n",
259 | " (layers): Sequential(\n",
260 | " (conv_down_project): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
261 | " (abn_down_project): Sequential(\n",
262 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
263 | " (1): ReLU(inplace=True)\n",
264 | " )\n",
265 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)\n",
266 | " (abn): Sequential(\n",
267 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
268 | " (1): ReLU(inplace=True)\n",
269 | " )\n",
270 | " (conv_up_project): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
271 | " (abn_up_project): Sequential(\n",
272 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
273 | " (1): ReLU(inplace=True)\n",
274 | " )\n",
275 | " (dropout): Dropout2d(p=0.0, inplace=False)\n",
276 | " )\n",
277 | " (projection): Sequential(\n",
278 | " (upsample_skip_proj): Interpolate()\n",
279 | " (conv_skip_proj): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
280 | " (bn_skip_proj): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
281 | " )\n",
282 | ")"
283 | ]
284 | },
285 | "execution_count": null,
286 | "metadata": {},
287 | "output_type": "execute_result"
288 | }
289 | ],
290 | "source": [
291 | "bn = Bottleneck(256, 128, upsample=True)\n",
292 | "bn"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": null,
298 | "metadata": {},
299 | "outputs": [
300 | {
301 | "data": {
302 | "text/plain": [
303 | "torch.Size([4, 256, 8, 8])"
304 | ]
305 | },
306 | "execution_count": null,
307 | "metadata": {},
308 | "output_type": "execute_result"
309 | }
310 | ],
311 | "source": [
312 | "features[0].shape"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "metadata": {},
319 | "outputs": [
320 | {
321 | "data": {
322 | "text/plain": [
323 | "torch.Size([1, 128, 64, 64])"
324 | ]
325 | },
326 | "execution_count": null,
327 | "metadata": {},
328 | "output_type": "execute_result"
329 | }
330 | ],
331 | "source": [
332 | "x = torch.rand(1,256,32,32)\n",
333 | "bn(x).shape"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": null,
339 | "metadata": {},
340 | "outputs": [],
341 | "source": [
342 | "#export\n",
343 | "class Upsampler(nn.Module):\n",
344 | " def __init__(self, sizes=[128,128,64], n_out=3):\n",
345 | " super().__init__()\n",
346 | " zsizes = zip(sizes[:-1], sizes[1:])\n",
347 | " self.convs = nn.Sequential(*[Bottleneck(si, sf, upsample=True) for si,sf in zsizes], \n",
348 | " Bottleneck(sizes[-1], sizes[-1], upsample=True), \n",
349 | " ConvBlock(sizes[-1], n_out, kernel_size=1, activation=None))\n",
350 | " \n",
351 | " def forward(self, x):\n",
352 | " return self.convs(x)"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": null,
358 | "metadata": {},
359 | "outputs": [
360 | {
361 | "data": {
362 | "text/plain": [
363 | "torch.Size([1, 3, 256, 256])"
364 | ]
365 | },
366 | "execution_count": null,
367 | "metadata": {},
368 | "output_type": "execute_result"
369 | }
370 | ],
371 | "source": [
372 | "us = Upsampler()\n",
373 | "\n",
374 | "x = torch.rand(1,128,32,32)\n",
375 | "us(x).shape"
376 | ]
377 | },
378 | {
379 | "cell_type": "markdown",
380 | "metadata": {},
381 | "source": [
382 | "## 4B. Irradiance Module"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": null,
388 | "metadata": {},
389 | "outputs": [],
390 | "source": [
391 | "#export\n",
392 | "class IrradianceModule(nn.Module):\n",
393 | " def __init__(self):\n",
394 | " super().__init__()\n",
395 | " self.convs = nn.Sequential(ConvBlock(128, 64, stride=2), \n",
396 | " ConvBlock(64, 64),\n",
397 | " nn.AdaptiveMaxPool2d(1)\n",
398 | " )\n",
399 | " self.linear = nn.Sequential(nn.Flatten(), \n",
400 | " nn.Linear(64, 1)\n",
401 | " )\n",
402 | " def forward(self, x):\n",
403 | " return self.linear(self.convs(x))"
404 | ]
405 | },
406 | {
407 | "cell_type": "code",
408 | "execution_count": null,
409 | "metadata": {},
410 | "outputs": [],
411 | "source": [
412 | "im = IrradianceModule()"
413 | ]
414 | },
415 | {
416 | "cell_type": "code",
417 | "execution_count": null,
418 | "metadata": {},
419 | "outputs": [
420 | {
421 | "data": {
422 | "text/plain": [
423 | "IrradianceModule(\n",
424 | " (convs): Sequential(\n",
425 | " (0): ConvBlock(\n",
426 | " (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
427 | " (1): ReLU(inplace=True)\n",
428 | " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
429 | " )\n",
430 | " (1): ConvBlock(\n",
431 | " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
432 | " (1): ReLU(inplace=True)\n",
433 | " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
434 | " )\n",
435 | " (2): AdaptiveMaxPool2d(output_size=1)\n",
436 | " )\n",
437 | " (linear): Sequential(\n",
438 | " (0): Flatten(start_dim=1, end_dim=-1)\n",
439 | " (1): Linear(in_features=64, out_features=1, bias=True)\n",
440 | " )\n",
441 | ")"
442 | ]
443 | },
444 | "execution_count": null,
445 | "metadata": {},
446 | "output_type": "execute_result"
447 | }
448 | ],
449 | "source": [
450 | "im"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": null,
456 | "metadata": {},
457 | "outputs": [
458 | {
459 | "data": {
460 | "text/plain": [
461 | "torch.Size([2, 1])"
462 | ]
463 | },
464 | "execution_count": null,
465 | "metadata": {},
466 | "output_type": "execute_result"
467 | }
468 | ],
469 | "source": [
470 | "x = torch.rand(2, 128, 32, 32)\n",
471 | "im(x).shape"
472 | ]
473 | },
474 | {
475 | "cell_type": "markdown",
476 | "metadata": {},
477 | "source": [
478 | "## Everything Together..."
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": null,
484 | "metadata": {},
485 | "outputs": [],
486 | "source": [
487 | "#export\n",
488 | "class Eclipse(nn.Module):\n",
489 | " \"\"\"Not very parametric\"\"\"\n",
490 | " def __init__(self, n_in=3, n_out=4, horizon=5, img_size=(128, 128), n_gru_layers=4, n_res_layers=4, debug=False):\n",
491 | " super().__init__()\n",
492 | " store_attr()\n",
493 | " self.spatial_downsampler = SpatialDownsampler(n_in)\n",
494 | " self.temporal_model = TemporalModel(256, 3, input_shape=(img_size[0]//8, img_size[1]//8), start_out_channels=128)\n",
495 | " self.future_prediction = FuturePrediction(128, 128, n_gru_blocks=n_gru_layers, n_res_layers=n_res_layers)\n",
496 | " self.upsampler = Upsampler(n_out=n_out)\n",
497 | " self.irradiance = IrradianceModule()\n",
498 | " \n",
499 | " def zero_hidden(self, x, horizon):\n",
500 | " bs, ch, h, w = x.shape\n",
501 | " return x.new_zeros(bs, horizon, ch, h, w)\n",
502 | " \n",
503 | " def forward(self, imgs):\n",
504 | " x = torch.stack([self.spatial_downsampler(img) for img in imgs], dim=1)\n",
505 | " \n",
506 | " #encode temporal model\n",
507 | " states = self.temporal_model(x)\n",
508 | " if self.debug: print(f'{states.shape=}')\n",
509 | " \n",
510 | " #get hidden state\n",
511 | " present_state = states[:, -1:]\n",
512 | " if self.debug: print(f'{present_state.shape=}')\n",
513 | " \n",
514 | " \n",
515 | " # Prepare future prediction input\n",
   516 |     "        hidden_state = present_state.squeeze(1)  # squeeze only the time dim (robust to batch size 1)\n",
517 | " if self.debug: print(f'{hidden_state.shape=}')\n",
518 | " \n",
519 | " future_prediction_input = self.zero_hidden(hidden_state, self.horizon)\n",
520 | " \n",
521 | " # Recursively predict future states\n",
522 | " future_states = self.future_prediction(future_prediction_input, hidden_state)\n",
523 | "\n",
524 | " # Concatenate present state\n",
525 | " future_states = torch.cat([present_state, future_states], dim=1)\n",
526 | " if self.debug: print(f'{future_states.shape=}')\n",
527 | " \n",
528 | " #decode outputs\n",
529 | " masks, irradiances = [], []\n",
530 | "\n",
531 | " for state in future_states.unbind(dim=1):\n",
532 | " masks.append(self.upsampler(state))\n",
533 | " irradiances.append(self.irradiance(state))\n",
534 | " return {'masks': masks, 'irradiances': torch.cat(irradiances, dim=-1)}\n",
535 | " "
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "execution_count": null,
541 | "metadata": {},
542 | "outputs": [],
543 | "source": [
544 | "eclipse = Eclipse(img_size=(256, 192), debug=True)"
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": null,
550 | "metadata": {},
551 | "outputs": [
552 | {
553 | "name": "stdout",
554 | "output_type": "stream",
555 | "text": [
556 | "states.shape=torch.Size([2, 2, 128, 32, 24])\n",
557 | "present_state.shape=torch.Size([2, 1, 128, 32, 24])\n",
558 | "hidden_state.shape=torch.Size([2, 128, 32, 24])\n",
559 | "future_states.shape=torch.Size([2, 6, 128, 32, 24])\n"
560 | ]
561 | },
562 | {
563 | "data": {
564 | "text/plain": [
565 | "(torch.Size([2, 4, 256, 192]), torch.Size([2, 6]))"
566 | ]
567 | },
568 | "execution_count": null,
569 | "metadata": {},
570 | "output_type": "execute_result"
571 | }
572 | ],
573 | "source": [
574 | "preds = eclipse([torch.rand(2, 3, 256, 192) for _ in range(4)])\n",
575 | "\n",
576 | "preds['masks'][0].shape, preds['irradiances'].shape"
577 | ]
578 | },
579 | {
580 | "cell_type": "markdown",
581 | "metadata": {},
582 | "source": [
583 | "## Export"
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": null,
589 | "metadata": {},
590 | "outputs": [
591 | {
592 | "name": "stdout",
593 | "output_type": "stream",
594 | "text": [
595 | "Converted 00_model.ipynb.\n",
596 | "Converted 01_layers.ipynb.\n",
597 | "Converted index.ipynb.\n"
598 | ]
599 | }
600 | ],
601 | "source": [
602 | "# hide\n",
603 | "from nbdev.export import *\n",
604 | "notebook2script()"
605 | ]
606 | }
607 | ],
608 | "metadata": {
609 | "kernelspec": {
610 | "display_name": "Python 3",
611 | "language": "python",
612 | "name": "python3"
613 | }
614 | },
615 | "nbformat": 4,
616 | "nbformat_minor": 4
617 | }
618 |
--------------------------------------------------------------------------------
/eclipse_pytorch/layers.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_layers.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['get_activation', 'get_norm', 'init_linear', 'ConvBlock', 'Bottleneck', 'Interpolate', 'UpsamplingConcat',
4 | 'UpsamplingAdd', 'ResBlock', 'SpatialGRU', 'CausalConv3d', 'conv_1x1x1_norm_activated', 'Bottleneck3D',
5 | 'PyramidSpatioTemporalPooling', 'TemporalBlock', 'FuturePrediction']
6 |
7 | # Cell
8 | from .imports import *
9 |
10 | # Cell
11 | def get_activation(activation):
12 | if activation == 'relu':
13 | return nn.ReLU(inplace=True)
14 | elif activation == 'lrelu':
15 | return nn.LeakyReLU(0.1, inplace=True)
16 | elif activation == 'elu':
17 | return nn.ELU(inplace=True)
18 | elif activation == 'tanh':
    19 |         return nn.Tanh()  # nn.Tanh takes no inplace argument
20 | else:
21 | raise ValueError('Invalid activation {}'.format(activation))
22 |
23 | def get_norm(norm, out_channels):
24 | if norm == 'bn':
25 | return nn.BatchNorm2d(out_channels)
26 | elif norm == 'in':
27 | return nn.InstanceNorm2d(out_channels)
28 | else:
29 | raise ValueError('Invalid norm {}'.format(norm))
30 |
31 | # Cell
32 |
33 | def init_linear(m, act_func=None, init='auto', bias_std=0.01):
34 | if getattr(m,'bias',None) is not None and bias_std is not None:
35 | if bias_std != 0: normal_(m.bias, 0, bias_std)
36 | else: m.bias.data.zero_()
37 | if init=='auto':
38 | if act_func in (F.relu_,F.leaky_relu_): init = kaiming_uniform_
39 | else: init = getattr(act_func.__class__, '__default_init__', None)
40 | if init is None: init = getattr(act_func, '__default_init__', None)
41 | if init is not None: init(m.weight)
42 |
43 |
44 | class ConvBlock(nn.Sequential):
45 | """2D convolution followed by
46 | - an optional normalisation (batch norm or instance norm)
47 | - an optional activation (ReLU, LeakyReLU, or tanh)
48 | """
49 |
50 | def __init__(
51 | self,
52 | in_channels,
53 | out_channels=None,
54 | kernel_size=3,
55 | stride=1,
56 | norm='bn',
57 | activation='relu',
58 | bias=False,
59 | transpose=False,
60 | init='auto'
61 | ):
62 |
63 | out_channels = out_channels or in_channels
64 | padding = (kernel_size-1)//2
65 | conv_cls = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1)
66 | conv = conv_cls(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)
67 | if activation is not None: activation = get_activation(activation)
68 | init_linear(conv, activation, init=init)
69 | layers = [conv]
70 | if activation is not None: layers.append(activation)
71 | if norm is not None: layers.append(get_norm(norm, out_channels))
72 | super().__init__(*layers)
73 |
74 | # Cell
75 | class Bottleneck(nn.Module):
76 | """
77 | Defines a bottleneck module with a residual connection
78 | """
79 |
80 | def __init__(
81 | self,
82 | in_channels,
83 | out_channels=None,
84 | kernel_size=3,
85 | dilation=1,
86 | groups=1,
87 | upsample=False,
88 | downsample=False,
89 | dropout=0.0,
90 | ):
91 | super().__init__()
92 | self._downsample = downsample
93 | bottleneck_channels = int(in_channels / 2)
94 | out_channels = out_channels or in_channels
95 | padding_size = ((kernel_size - 1) * dilation + 1) // 2
96 |
97 | # Define the main conv operation
98 | assert dilation == 1
99 | if upsample:
100 | assert not downsample, 'downsample and upsample not possible simultaneously.'
101 | bottleneck_conv = nn.ConvTranspose2d(
102 | bottleneck_channels,
103 | bottleneck_channels,
104 | kernel_size=kernel_size,
105 | bias=False,
106 | dilation=1,
107 | stride=2,
108 | output_padding=padding_size,
109 | padding=padding_size,
110 | groups=groups,
111 | )
112 | elif downsample:
113 | bottleneck_conv = nn.Conv2d(
114 | bottleneck_channels,
115 | bottleneck_channels,
116 | kernel_size=kernel_size,
117 | bias=False,
118 | dilation=dilation,
119 | stride=2,
120 | padding=padding_size,
121 | groups=groups,
122 | )
123 | else:
124 | bottleneck_conv = nn.Conv2d(
125 | bottleneck_channels,
126 | bottleneck_channels,
127 | kernel_size=kernel_size,
128 | bias=False,
129 | dilation=dilation,
130 | padding=padding_size,
131 | groups=groups,
132 | )
133 |
134 | self.layers = nn.Sequential(
135 | OrderedDict(
136 | [
137 | # First projection with 1x1 kernel
138 | ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)),
139 | ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels),
140 | nn.ReLU(inplace=True))),
141 | # Second conv block
142 | ('conv', bottleneck_conv),
143 | ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))),
144 | # Final projection with 1x1 kernel
145 | ('conv_up_project', nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)),
146 | ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels),
147 | nn.ReLU(inplace=True))),
148 | # Regulariser
149 | ('dropout', nn.Dropout2d(p=dropout)),
150 | ]
151 | )
152 | )
153 |
154 | if out_channels == in_channels and not downsample and not upsample:
155 | self.projection = None
156 | else:
157 | projection = OrderedDict()
158 | if upsample:
159 | projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)})
160 | elif downsample:
161 | projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)})
162 | projection.update(
163 | {
164 | 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
165 | 'bn_skip_proj': nn.BatchNorm2d(out_channels),
166 | }
167 | )
168 | self.projection = nn.Sequential(projection)
169 |
170 | # pylint: disable=arguments-differ
171 | def forward(self, *args):
172 | (x,) = args
173 | x_residual = self.layers(x)
174 | if self.projection is not None:
175 | if self._downsample:
176 | # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer
177 | x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0)
178 | return x_residual + self.projection(x)
179 | return x_residual + x
180 |
181 | # Cell
182 | class Interpolate(nn.Module):
183 | def __init__(self, scale_factor: int = 2):
184 | super().__init__()
185 | self._interpolate = nn.functional.interpolate
186 | self._scale_factor = scale_factor
187 |
188 | # pylint: disable=arguments-differ
189 | def forward(self, x):
190 | return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False)
191 |
192 |
193 | class UpsamplingConcat(nn.Module):
194 | def __init__(self, in_channels, out_channels, scale_factor=2):
195 | super().__init__()
196 |
197 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
198 |
199 | self.conv = nn.Sequential(
200 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
201 | nn.BatchNorm2d(out_channels),
202 | nn.ReLU(inplace=True),
203 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
204 | nn.BatchNorm2d(out_channels),
205 | nn.ReLU(inplace=True),
206 | )
207 |
208 | def forward(self, x_to_upsample, x):
209 | x_to_upsample = self.upsample(x_to_upsample)
210 | x_to_upsample = torch.cat([x, x_to_upsample], dim=1)
211 | return self.conv(x_to_upsample)
212 |
213 |
214 | class UpsamplingAdd(nn.Module):
215 | def __init__(self, in_channels, out_channels, scale_factor=2):
216 | super().__init__()
217 | self.upsample_layer = nn.Sequential(
218 | nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
219 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),
220 | nn.BatchNorm2d(out_channels),
221 | )
222 |
223 | def forward(self, x, x_skip):
224 | x = self.upsample_layer(x)
225 | return x + x_skip
226 |
227 | # Cell
228 | class ResBlock(nn.Module):
   229 |     "A simple resnet block"
230 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, norm='bn', activation='relu'):
231 | super().__init__()
232 | self.convs = nn.Sequential(ConvBlock(in_channels, out_channels, kernel_size, stride, norm=norm, activation=activation),
233 | ConvBlock(out_channels, out_channels, norm=norm, activation=activation)
234 | )
235 | id_path = [ConvBlock(in_channels, out_channels, kernel_size=1, activation=None, norm=None)]
236 | self.activation = get_activation(activation)
237 | if stride!=1: id_path.insert(1, nn.AvgPool2d(2, stride, ceil_mode=True))
238 | self.id_path = nn.Sequential(*id_path)
239 |
240 | def forward(self, x):
241 | return self.activation(self.convs(x) + self.id_path(x))
242 |
243 | # Cell
244 | class SpatialGRU(nn.Module):
245 | """A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a
246 | convolutional gated recurrent unit over the data"""
247 |
248 | def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'):
249 | super().__init__()
250 | self.input_size = input_size
251 | self.hidden_size = hidden_size
252 | self.gru_bias_init = gru_bias_init
253 |
254 | self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
255 | self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
256 |
257 | self.conv_state_tilde = ConvBlock(
258 | input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation
259 | )
260 |
261 | def forward(self, x, state=None, flow=None, mode='bilinear'):
262 | # pylint: disable=unused-argument, arguments-differ
263 | # Check size
264 | assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.'
265 | b, timesteps, c, h, w = x.size()
266 | assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}'
267 |
268 | # recurrent layers
269 | rnn_output = []
270 | rnn_state = torch.zeros(b, self.hidden_size, h, w, device=x.device) if state is None else state
271 | for t in range(timesteps):
272 | x_t = x[:, t]
273 | if flow is not None:
   274 |                 rnn_state = warp_features(rnn_state, flow[:, t], mode=mode)  # warp_features is not defined in this module; flow is always None in this repo
275 |
276 | # propagate rnn state
277 | rnn_state = self.gru_cell(x_t, rnn_state)
278 | rnn_output.append(rnn_state)
279 |
280 | # reshape rnn output to batch tensor
281 | return torch.stack(rnn_output, dim=1)
282 |
283 | def gru_cell(self, x, state):
284 | # Compute gates
285 | x_and_state = torch.cat([x, state], dim=1)
286 | update_gate = self.conv_update(x_and_state)
287 | reset_gate = self.conv_reset(x_and_state)
288 | # Add bias to initialise gate as close to identity function
289 | update_gate = torch.sigmoid(update_gate + self.gru_bias_init)
290 | reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init)
291 |
292 | # Compute proposal state, activation is defined in norm_act_config (can be tanh, ReLU etc)
293 | state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1))
294 |
295 | output = (1.0 - update_gate) * state + update_gate * state_tilde
296 | return output
297 |
298 | # Cell
299 | class CausalConv3d(nn.Module):
300 | def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False):
301 | super().__init__()
302 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'
303 | time_pad = (kernel_size[0] - 1) * dilation[0]
304 | height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2
305 | width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2
306 |
307 | # Pad temporally on the left
308 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)
309 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias)
310 | self.norm = nn.BatchNorm3d(out_channels)
311 | self.activation = nn.ReLU(inplace=True)
312 |
313 | def forward(self, *inputs):
314 | (x,) = inputs
315 | x = self.pad(x)
316 | x = self.conv(x)
317 | x = self.norm(x)
318 | x = self.activation(x)
319 | return x
320 |
321 | # Cell
322 | def conv_1x1x1_norm_activated(in_channels, out_channels):
323 | """1x1x1 3D convolution, normalization and activation layer."""
324 | return nn.Sequential(
325 | OrderedDict(
326 | [
327 | ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)),
328 | ('norm', nn.BatchNorm3d(out_channels)),
329 | ('activation', nn.ReLU(inplace=True)),
330 | ]
331 | )
332 | )
333 |
334 | # Cell
335 | class Bottleneck3D(nn.Module):
336 | """
337 | Defines a bottleneck module with a residual connection
338 | """
339 |
340 | def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)):
341 | super().__init__()
342 | bottleneck_channels = in_channels // 2
343 | out_channels = out_channels or in_channels
344 |
345 | self.layers = nn.Sequential(
346 | OrderedDict(
347 | [
348 | # First projection with 1x1 kernel
349 | ('conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)),
350 | # Second conv block
351 | (
352 | 'conv',
353 | CausalConv3d(
354 | bottleneck_channels,
355 | bottleneck_channels,
356 | kernel_size=kernel_size,
357 | dilation=dilation,
358 | bias=False,
359 | ),
360 | ),
361 | # Final projection with 1x1 kernel
362 | ('conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)),
363 | ]
364 | )
365 | )
366 |
367 | if out_channels != in_channels:
368 | self.projection = nn.Sequential(
369 | nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),
370 | nn.BatchNorm3d(out_channels),
371 | )
372 | else:
373 | self.projection = None
374 |
375 | def forward(self, *args):
376 | (x,) = args
377 | x_residual = self.layers(x)
378 | x_features = self.projection(x) if self.projection is not None else x
379 | return x_residual + x_features
380 |
381 | # Cell
382 | class PyramidSpatioTemporalPooling(nn.Module):
383 | """ Spatio-temporal pyramid pooling.
384 | Performs 3D average pooling followed by 1x1x1 convolution to reduce the number of channels and upsampling.
385 | Setting contains a list of kernel_size: usually it is [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)]
386 | """
387 |
388 | def __init__(self, in_channels, reduction_channels, pool_sizes):
389 | super().__init__()
390 | self.features = []
391 | for pool_size in pool_sizes:
392 | assert pool_size[0] == 2, (
   393 |                 "Time kernel should be 2 as PyTorch raises an error when padding with more than half the kernel size"
394 | )
395 | stride = (1, *pool_size[1:])
396 | padding = (pool_size[0] - 1, 0, 0)
397 | self.features.append(
398 | nn.Sequential(
399 | OrderedDict(
400 | [
401 | # Pad the input tensor but do not take into account zero padding into the average.
402 | (
403 | 'avgpool',
404 | torch.nn.AvgPool3d(
405 | kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False
406 | ),
407 | ),
408 | ('conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels)),
409 | ]
410 | )
411 | )
412 | )
413 | self.features = nn.ModuleList(self.features)
414 |
415 | def forward(self, *inputs):
416 | (x,) = inputs
417 | b, _, t, h, w = x.shape
418 | # Do not include current tensor when concatenating
419 | out = []
420 | for f in self.features:
421 | # Remove unnecessary padded values (time dimension) on the right
422 | x_pool = f(x)[:, :, :-1].contiguous()
423 | c = x_pool.shape[1]
424 | x_pool = nn.functional.interpolate(
425 | x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False
426 | )
427 | x_pool = x_pool.view(b, c, t, h, w)
428 | out.append(x_pool)
429 | out = torch.cat(out, 1)
430 | return out
431 |
432 | # Cell
433 | class TemporalBlock(nn.Module):
434 | """ Temporal block with the following layers:
435 | - 2x3x3, 1x3x3, spatio-temporal pyramid pooling
436 | - dropout
437 | - skip connection.
438 | """
439 |
440 | def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None):
441 | super().__init__()
442 | self.in_channels = in_channels
443 | self.half_channels = in_channels // 2
444 | self.out_channels = out_channels or self.in_channels
445 | self.kernels = [(2, 3, 3), (1, 3, 3)]
446 |
447 | # Flag for spatio-temporal pyramid pooling
448 | self.use_pyramid_pooling = use_pyramid_pooling
449 |
450 | # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1
451 | self.convolution_paths = []
452 | for kernel_size in self.kernels:
453 | self.convolution_paths.append(
454 | nn.Sequential(
455 | conv_1x1x1_norm_activated(self.in_channels, self.half_channels),
456 | CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size),
457 | )
458 | )
459 | self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels))
460 | self.convolution_paths = nn.ModuleList(self.convolution_paths)
461 |
462 | agg_in_channels = len(self.convolution_paths) * self.half_channels
463 |
464 | if self.use_pyramid_pooling:
465 | assert pool_sizes is not None, "setting must contain the list of kernel_size, but is None."
466 | reduction_channels = self.in_channels // 3
467 | self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes)
468 | agg_in_channels += len(pool_sizes) * reduction_channels
469 |
470 | # Feature aggregation
471 | self.aggregation = nn.Sequential(
472 | conv_1x1x1_norm_activated(agg_in_channels, self.out_channels),)
473 |
474 | if self.out_channels != self.in_channels:
475 | self.projection = nn.Sequential(
476 | nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1, bias=False),
477 | nn.BatchNorm3d(self.out_channels),
478 | )
479 | else:
480 | self.projection = None
481 |
482 | def forward(self, *inputs):
483 | (x,) = inputs
484 | x_paths = []
485 | for conv in self.convolution_paths:
486 | x_paths.append(conv(x))
487 | x_residual = torch.cat(x_paths, dim=1)
488 | if self.use_pyramid_pooling:
489 | x_pool = self.pyramid_pooling(x)
490 | x_residual = torch.cat([x_residual, x_pool], dim=1)
491 | x_residual = self.aggregation(x_residual)
492 |
493 | if self.out_channels != self.in_channels:
494 | x = self.projection(x)
495 | x = x + x_residual
496 | return x
497 |
498 | # Cell
499 | class FuturePrediction(torch.nn.Module):
500 | def __init__(self, in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3):
501 | super().__init__()
502 | self.n_gru_blocks = n_gru_blocks
503 |
504 | # Convolutional recurrent model with z_t as an initial hidden state and inputs the sample
505 | # from the probabilistic model. The architecture of the model is:
506 | # [Spatial GRU - [Bottleneck] x n_res_layers] x n_gru_blocks
507 | self.spatial_grus = []
508 | self.res_blocks = []
509 |
510 | for i in range(self.n_gru_blocks):
511 | gru_in_channels = latent_dim if i == 0 else in_channels
512 | self.spatial_grus.append(SpatialGRU(gru_in_channels, in_channels))
513 | self.res_blocks.append(torch.nn.Sequential(*[Bottleneck(in_channels)
514 | for _ in range(n_res_layers)]))
515 |
516 | self.spatial_grus = torch.nn.ModuleList(self.spatial_grus)
517 | self.res_blocks = torch.nn.ModuleList(self.res_blocks)
518 |
519 | def forward(self, x, hidden_state):
520 | # x has shape (b, n_future, c, h, w), hidden_state (b, c, h, w)
521 | for i in range(self.n_gru_blocks):
522 | x = self.spatial_grus[i](x, hidden_state, flow=None)
523 | b, n_future, c, h, w = x.shape
524 |
525 | x = self.res_blocks[i](x.view(b * n_future, c, h, w))
526 | x = x.view(b, n_future, c, h, w)
527 |
528 | return x
--------------------------------------------------------------------------------
/docs/model.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: The model
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 | summary: "from the paper from
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 |
33 |
34 |
We will try to implement as close as possible the architecture from the paper ECLIPSE : Envisioning Cloud Induced Perturbations in Solar Energy
35 |
36 |
37 |
38 |
39 |
40 |
41 |

42 |
43 |
44 |
45 |
46 | {% raw %}
47 |
48 |
49 |
50 |
51 | {% endraw %}
52 |
53 |
54 |
55 |
1. Spatial Downsampler
A resnet encoder to get image features
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
You could use any spatial downsampler as you want, but the paper states a simple resnet arch...
64 |
65 |
66 |
67 |
68 | {% raw %}
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
SpatialDownsampler(in_channels=3) :: Module
80 |
81 |
Base class for all neural network modules.
82 |
Your models should also subclass this class.
83 |
Modules can also contain other Modules, allowing to nest them in
84 | a tree structure. You can assign the submodules as regular attributes::
85 |
86 |
import torch.nn as nn
87 | import torch.nn.functional as F
88 |
89 | class Model(nn.Module):
90 | def __init__(self):
91 | super(Model, self).__init__()
92 | self.conv1 = nn.Conv2d(1, 20, 5)
93 | self.conv2 = nn.Conv2d(20, 20, 5)
94 |
95 | def forward(self, x):
96 | x = F.relu(self.conv1(x))
97 | return F.relu(self.conv2(x))
98 |
99 |
100 |
Submodules assigned in this way will be registered, and will have their
101 | parameters converted too when you call :meth:to, etc.
102 |
:ivar training: Boolean represents whether this module is in training or
103 | evaluation mode.
104 | :vartype training: bool
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 | {% endraw %}
115 |
116 | {% raw %}
117 |
118 |
119 |
120 |
121 | {% endraw %}
122 |
123 | {% raw %}
124 |
125 |
138 | {% endraw %}
139 |
140 | {% raw %}
141 |
142 |
143 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
torch.Size([1, 256, 4, 8, 8])
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 | {% endraw %}
174 |
175 |
176 |
177 |
2. Temporal Encoder
178 |
179 |
180 |
181 | {% raw %}
182 |
183 |
196 | {% endraw %}
197 |
198 | {% raw %}
199 |
200 |
201 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
torch.Size([1, 128, 4, 8, 8])
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 | {% endraw %}
231 |
232 |
233 |
234 |
3. Future State Predictions
235 |
236 |
237 |
238 | {% raw %}
239 |
240 |
253 | {% endraw %}
254 |
255 | {% raw %}
256 |
257 |
258 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
torch.Size([1, 4, 128, 8, 8])
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 | {% endraw %}
289 |
290 |
291 |
292 |
4A. Segmentation Decoder
293 |
294 |
295 |
296 | {% raw %}
297 |
298 |
299 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
Bottleneck(
320 | (layers): Sequential(
321 | (conv_down_project): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
322 | (abn_down_project): Sequential(
323 | (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
324 | (1): ReLU(inplace=True)
325 | )
326 | (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
327 | (abn): Sequential(
328 | (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
329 | (1): ReLU(inplace=True)
330 | )
331 | (conv_up_project): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
332 | (abn_up_project): Sequential(
333 | (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
334 | (1): ReLU(inplace=True)
335 | )
336 | (dropout): Dropout2d(p=0.0, inplace=False)
337 | )
338 | (projection): Sequential(
339 | (upsample_skip_proj): Interpolate()
340 | (conv_skip_proj): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
341 | (bn_skip_proj): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
342 | )
343 | )
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 | {% endraw %}
353 |
354 | {% raw %}
355 |
356 |
357 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
torch.Size([256, 4, 8, 8])
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 | {% endraw %}
386 |
387 | {% raw %}
388 |
389 |
390 |
401 |
402 |
403 |
404 |
405 |
406 |
407 |
408 |
409 |
410 |
torch.Size([1, 128, 64, 64])
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 | {% endraw %}
420 |
421 | {% raw %}
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
Upsampler(sizes=[128, 128, 64], n_out=3) :: Module
433 |
434 |
Base class for all neural network modules.
435 |
Your models should also subclass this class.
436 |
Modules can also contain other Modules, allowing to nest them in
437 | a tree structure. You can assign the submodules as regular attributes::
438 |
439 |
import torch.nn as nn
440 | import torch.nn.functional as F
441 |
442 | class Model(nn.Module):
443 | def __init__(self):
444 | super(Model, self).__init__()
445 | self.conv1 = nn.Conv2d(1, 20, 5)
446 | self.conv2 = nn.Conv2d(20, 20, 5)
447 |
448 | def forward(self, x):
449 | x = F.relu(self.conv1(x))
450 | return F.relu(self.conv2(x))
451 |
452 |
453 |
Submodules assigned in this way will be registered, and will have their
454 | parameters converted too when you call :meth:to, etc.
455 |
:ivar training: Boolean represents whether this module is in training or
456 | evaluation mode.
457 | :vartype training: bool
458 |
459 |
460 |
461 |
462 |
463 |
464 |
465 |
466 |
467 | {% endraw %}
468 |
469 | {% raw %}
470 |
471 |
472 |
473 |
474 | {% endraw %}
475 |
476 | {% raw %}
477 |
478 |
479 |
492 |
493 |
494 |
495 |
496 |
497 |
498 |
499 |
500 |
501 |
torch.Size([1, 3, 256, 256])
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 | {% endraw %}
511 |
512 |
513 |
514 |
4B. Irradiance Module
515 |
516 |
517 |
518 | {% raw %}
519 |
520 |
521 |
522 |
523 |
524 |
525 |
526 |
527 |
528 |
529 |
IrradianceModule() :: Module
530 |
531 |
Base class for all neural network modules.
532 |
Your models should also subclass this class.
533 |
Modules can also contain other Modules, allowing to nest them in
534 | a tree structure. You can assign the submodules as regular attributes::
535 |
536 |
import torch.nn as nn
537 | import torch.nn.functional as F
538 |
539 | class Model(nn.Module):
540 | def __init__(self):
541 | super(Model, self).__init__()
542 | self.conv1 = nn.Conv2d(1, 20, 5)
543 | self.conv2 = nn.Conv2d(20, 20, 5)
544 |
545 | def forward(self, x):
546 | x = F.relu(self.conv1(x))
547 | return F.relu(self.conv2(x))
548 |
549 |
550 |
Submodules assigned in this way will be registered, and will have their
551 | parameters converted too when you call :meth:to, etc.
552 |
:ivar training: Boolean represents whether this module is in training or
553 | evaluation mode.
554 | :vartype training: bool
555 |
556 |
557 |
558 |
559 |
560 |
561 |
562 |
563 |
564 | {% endraw %}
565 |
566 | {% raw %}
567 |
568 |
569 |
570 |
571 | {% endraw %}
572 |
573 | {% raw %}
574 |
575 |
588 | {% endraw %}
589 |
590 | {% raw %}
591 |
592 |
593 |
604 |
605 |
606 |
607 |
608 |
609 |
610 |
611 |
612 |
613 |
torch.Size([2, 1])
614 |
615 |
616 |
617 |
618 |
619 |
620 |
621 |
622 | {% endraw %}
623 |
Everything Together...

Eclipse(n_in=3, n_out=4, horizon=5, debug=False) :: Module

Not very parametric
654 |
(code cells stripped in extraction; debug output below)
states.shape=torch.Size([2, 4, 128, 16, 16])
702 | present_state.shape=torch.Size([2, 1, 128, 16, 16])
703 | hidden_state.shape=torch.Size([2, 128, 16, 16])
704 | future_states.shape=torch.Size([2, 6, 128, 16, 16])

(torch.Size([2, 4, 128, 128]), torch.Size([2, 6]))
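A hedged sketch of the stripped cells: the import path and the output dict keys are assumptions, while the horizon, batch and frame sizes are read off the shapes above.

```python
import torch
from eclipse_pytorch.model import Eclipse  # assumed import path

model = Eclipse(horizon=6, debug=True)   # horizon=6 matches future_states above; debug prints the state shapes
imgs = [torch.rand(2, 3, 128, 128) for _ in range(4)]   # 4 past frames, batch of 2
preds = model(imgs)
preds['masks'][0].shape, preds['irradiances'].shape      # assumed keys
# -> (torch.Size([2, 4, 128, 128]), torch.Size([2, 6]))
```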
--------------------------------------------------------------------------------
/docs/layers.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Layers
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 | summary: "most of them come from
get_activation(activation)

get_norm(norm, out_channels)
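Both helpers are small dispatch functions (their source is in `01_layers.ipynb` at the end of this dump); a quick sketch, assuming the exported module path `eclipse_pytorch.layers`:

```python
from eclipse_pytorch.layers import get_activation, get_norm  # assumed import path

get_activation('lrelu')   # -> nn.LeakyReLU(0.1, inplace=True); also accepts 'relu', 'elu', 'tanh'
get_norm('bn', 64)        # -> nn.BatchNorm2d(64); 'in' gives nn.InstanceNorm2d(64)
# any other string raises ValueError
```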
86 |
init_linear(m, act_func=None, init='auto', bias_std=0.01)
117 |
ConvBlock(in_channels, out_channels=None, kernel_size=3, stride=1, norm='bn', activation='relu', bias=False, transpose=False, init='auto') :: Sequential

2D convolution followed by
- an optional normalisation (batch norm or instance norm)
- an optional activation (ReLU, LeakyReLU, ELU, or tanh)
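A short sketch (import path assumed). Note the layer order is conv -> activation -> norm:

```python
import torch
from eclipse_pytorch.layers import ConvBlock  # assumed import path

block = ConvBlock(3, 16, stride=2)      # Conv2d -> ReLU -> BatchNorm2d
block(torch.rand(1, 3, 64, 64)).shape   # -> torch.Size([1, 16, 32, 32])
```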
146 |
Bottleneck(in_channels, out_channels=None, kernel_size=3, dilation=1, groups=1, upsample=False, downsample=False, dropout=0.0) :: Module

Defines a bottleneck module with a residual connection
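A sketch (import path assumed): with `downsample=True` the spatial dims are halved, and a max-pool plus 1x1-conv projection keeps the skip path compatible with the residual:

```python
import torch
from eclipse_pytorch.layers import Bottleneck  # assumed import path

bneck = Bottleneck(64, 128, downsample=True)
bneck(torch.rand(1, 64, 32, 32)).shape   # -> torch.Size([1, 128, 16, 16])
```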
178 |
Interpolate(scale_factor:int=2) :: Module
233 |
UpsamplingConcat(in_channels, out_channels, scale_factor=2) :: Module
281 |
UpsamplingAdd(in_channels, out_channels, scale_factor=2) :: Module
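A sketch of the two skip-connection flavours (import path assumed). `UpsamplingConcat` concatenates on the channel dim, so its `in_channels` must equal the summed channel count; `UpsamplingAdd` projects to the skip's channels and adds:

```python
import torch
from eclipse_pytorch.layers import UpsamplingConcat, UpsamplingAdd  # assumed import path

x, skip = torch.rand(1, 128, 16, 16), torch.rand(1, 64, 32, 32)

up_cat = UpsamplingConcat(128 + 64, 64)
up_cat(x, skip).shape   # -> torch.Size([1, 64, 32, 32])

up_add = UpsamplingAdd(128, 64)
up_add(x, skip).shape   # -> torch.Size([1, 64, 32, 32])
```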
329 |
ResBlock(in_channels, out_channels, kernel_size=3, stride=1, norm='bn', activation='relu') :: Module

A simple ResNet block
361 |
(code cell stripped in extraction; output below)
torch.Size([1, 128, 10, 10])
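The stripped cell is recoverable from the matching cell in `01_layers.ipynb` below: a stride-2 `ResBlock` halves the 20x20 input.

```python
import torch
from eclipse_pytorch.layers import ResBlock  # assumed import path

res_block = ResBlock(64, 128, stride=2)

x = torch.rand(1, 64, 20, 20)
res_block(x).shape   # -> torch.Size([1, 128, 10, 10])
```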
419 |
SpatialGRU(input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu') :: Module

A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a convolutional gated recurrent unit over the data
445 |
(code cells stripped in extraction)

without hidden state

with hidden
hidden.shape = (bs, hidden_size, h, w)

torch.Size([1, 3, 64, 8, 8])
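The stripped cells are recoverable from `01_layers.ipynb` below; the GRU runs step by step over the time axis, so only the channel dim changes:

```python
import torch
from eclipse_pytorch.layers import SpatialGRU  # assumed import path

sgru = SpatialGRU(input_size=32, hidden_size=64)

x = torch.rand(1, 3, 32, 8, 8)     # BxTxCxHxW
sgru(x).shape                       # without hidden state -> torch.Size([1, 3, 64, 8, 8])

hidden = torch.rand(1, 64, 8, 8)   # (bs, hidden_size, h, w)
sgru(x, hidden).shape               # with hidden state -> torch.Size([1, 3, 64, 8, 8])
```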
540 |
CausalConv3d(in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False) :: Module
588 |
(code cell stripped in extraction; output below)
torch.Size([1, 128, 4, 8, 8])
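The stripped cell, recoverable from `01_layers.ipynb` below; padding only on the left of the time axis keeps the convolution causal, so T is preserved:

```python
import torch
from eclipse_pytorch.layers import CausalConv3d  # assumed import path

cc3d = CausalConv3d(64, 128)

x = torch.rand(1, 64, 4, 8, 8)   # BxCxTxHxW
cc3d(x).shape                     # -> torch.Size([1, 128, 4, 8, 8])
```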
646 |
conv_1x1x1_norm_activated(in_channels, out_channels)

1x1x1 3D convolution, normalization and activation layer.
671 |
(stripped cell calls conv_1x1x1_norm_activated(2, 4); output below)
Sequential(
  (conv): Conv3d(2, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (norm): BatchNorm3d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation): ReLU(inplace=True)
)
715 |
Bottleneck3D(in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)) :: Module

Defines a bottleneck module with a residual connection
740 |
(code cell stripped in extraction; output below)
torch.Size([1, 12, 4, 8, 8])
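The stripped cell, recoverable from `01_layers.ipynb` below:

```python
import torch
from eclipse_pytorch.layers import Bottleneck3D  # assumed import path

bn3d = Bottleneck3D(8, 12)

x = torch.rand(1, 8, 4, 8, 8)
bn3d(x).shape   # -> torch.Size([1, 12, 4, 8, 8])
```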
798 |
PyramidSpatioTemporalPooling(in_channels, reduction_channels, pool_sizes) :: Module

Spatio-temporal pyramid pooling.
Performs 3D average pooling followed by a 1x1x1 convolution to reduce the number of channels, then upsampling.
pool_sizes is a list of kernel sizes, usually [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)]
825 |
(code cells stripped in extraction; output below)
torch.Size([1, 36, 4, 64, 64])
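The stripped cells, recoverable from `01_layers.ipynb` below; each of the three pyramid levels contributes `reduction_channels` channels, hence 3 x 12 = 36:

```python
import torch
from eclipse_pytorch.layers import PyramidSpatioTemporalPooling  # assumed import path

h, w = 64, 64
ptp = PyramidSpatioTemporalPooling(12, 12, [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)])

x = torch.rand(1, 12, 4, 64, 64)
ptp(x).shape   # pooled pyramids concatenated on channels -> torch.Size([1, 36, 4, 64, 64])
```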
886 |
TemporalBlock(in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None) :: Module

Temporal block with the following layers:
- 2x3x3, 1x3x3, spatio-temporal pyramid pooling
- dropout
- skip connection.
916 |
(code cell stripped in extraction; output below)
torch.Size([1, 8, 4, 64, 64])
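The stripped cell, recoverable from `01_layers.ipynb` below:

```python
import torch
from eclipse_pytorch.layers import TemporalBlock  # assumed import path

tb = TemporalBlock(4, 8)

x = torch.rand(1, 4, 4, 64, 64)
tb(x).shape   # -> torch.Size([1, 8, 4, 64, 64])
```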
960 |
FuturePrediction(in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3) :: Module
(code cells stripped in extraction; they exercise FuturePrediction — see the matching cells in 01_layers.ipynb below)
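The stripped cell, recoverable from `01_layers.ipynb` below: the module is shape-preserving, returning a tensor the same size as its `(b, n_future, c, h, w)` input:

```python
import torch
from eclipse_pytorch.layers import FuturePrediction  # assumed import path

fp = FuturePrediction(32, 32, n_gru_blocks=4, n_res_layers=4)

x = torch.rand(1, 4, 32, 64, 64)     # (b, n_future, c, h, w)
hidden = torch.rand(1, 32, 64, 64)   # (b, c, h, w)
fp(x, hidden).shape                   # -> torch.Size([1, 4, 32, 64, 64]), same as x
```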
--------------------------------------------------------------------------------
/nbs/01_layers.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "94f61f73-8683-4638-b730-6d9174eb212d",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# default_exp layers"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "id": "0ea22c98-3207-4673-8ee8-7396637f45cd",
16 | "metadata": {},
17 | "source": [
18 | "# Layers\n",
19 | "> most of them come from [fiery](https://github.com/wayveai/fiery)"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "id": "113f4b7d-ace4-4416-a067-9d7025ef7ae6",
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "#export\n",
30 | "from eclipse_pytorch.imports import *"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "id": "e8f80af7-bcc4-43a8-a6bd-283991737d8f",
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "#export\n",
41 | "def get_activation(activation):\n",
42 | " if activation == 'relu':\n",
43 | " return nn.ReLU(inplace=True)\n",
44 | " elif activation == 'lrelu':\n",
45 | " return nn.LeakyReLU(0.1, inplace=True)\n",
46 | " elif activation == 'elu':\n",
47 | " return nn.ELU(inplace=True)\n",
48 | " elif activation == 'tanh':\n",
49 |         "        return nn.Tanh()\n",
50 | " else:\n",
51 | " raise ValueError('Invalid activation {}'.format(activation))\n",
52 | " \n",
53 | "def get_norm(norm, out_channels):\n",
54 | " if norm == 'bn':\n",
55 | " return nn.BatchNorm2d(out_channels)\n",
56 | " elif norm == 'in':\n",
57 | " return nn.InstanceNorm2d(out_channels)\n",
58 | " else:\n",
59 | " raise ValueError('Invalid norm {}'.format(norm))"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "f158184b-c6d0-4a22-8625-f8514b1b524f",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "#export\n",
70 | "\n",
71 | "def init_linear(m, act_func=None, init='auto', bias_std=0.01):\n",
72 | " if getattr(m,'bias',None) is not None and bias_std is not None:\n",
73 | " if bias_std != 0: normal_(m.bias, 0, bias_std)\n",
74 | " else: m.bias.data.zero_()\n",
75 | " if init=='auto':\n",
76 | " if act_func in (F.relu_,F.leaky_relu_): init = kaiming_uniform_\n",
77 | " else: init = getattr(act_func.__class__, '__default_init__', None)\n",
78 | " if init is None: init = getattr(act_func, '__default_init__', None)\n",
79 | " if init is not None: init(m.weight)\n",
80 | "\n",
81 | "\n",
82 | "class ConvBlock(nn.Sequential):\n",
83 | " \"\"\"2D convolution followed by\n",
84 | " - an optional normalisation (batch norm or instance norm)\n",
85 |         "    - an optional activation (ReLU, LeakyReLU, ELU, or tanh)\n",
86 | " \"\"\"\n",
87 | "\n",
88 | " def __init__(\n",
89 | " self,\n",
90 | " in_channels,\n",
91 | " out_channels=None,\n",
92 | " kernel_size=3,\n",
93 | " stride=1,\n",
94 | " norm='bn',\n",
95 | " activation='relu',\n",
96 | " bias=False,\n",
97 | " transpose=False,\n",
98 | " init='auto'\n",
99 | " ):\n",
100 | " \n",
101 | " out_channels = out_channels or in_channels\n",
102 | " padding = (kernel_size-1)//2\n",
103 | " conv_cls = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1)\n",
104 | " conv = conv_cls(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)\n",
105 | " if activation is not None: activation = get_activation(activation)\n",
106 | " init_linear(conv, activation, init=init)\n",
107 | " layers = [conv]\n",
108 | " if activation is not None: layers.append(activation)\n",
109 | " if norm is not None: layers.append(get_norm(norm, out_channels))\n",
110 | " super().__init__(*layers)"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "id": "cd4f1349-6d9d-4a3c-8be3-92f2a9d5c173",
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "#export\n",
121 | "class Bottleneck(nn.Module):\n",
122 | " \"\"\"\n",
123 | " Defines a bottleneck module with a residual connection\n",
124 | " \"\"\"\n",
125 | "\n",
126 | " def __init__(\n",
127 | " self,\n",
128 | " in_channels,\n",
129 | " out_channels=None,\n",
130 | " kernel_size=3,\n",
131 | " dilation=1,\n",
132 | " groups=1,\n",
133 | " upsample=False,\n",
134 | " downsample=False,\n",
135 | " dropout=0.0,\n",
136 | " ):\n",
137 | " super().__init__()\n",
138 | " self._downsample = downsample\n",
139 | " bottleneck_channels = int(in_channels / 2)\n",
140 | " out_channels = out_channels or in_channels\n",
141 | " padding_size = ((kernel_size - 1) * dilation + 1) // 2\n",
142 | "\n",
143 | " # Define the main conv operation\n",
144 | " assert dilation == 1\n",
145 | " if upsample:\n",
146 | " assert not downsample, 'downsample and upsample not possible simultaneously.'\n",
147 | " bottleneck_conv = nn.ConvTranspose2d(\n",
148 | " bottleneck_channels,\n",
149 | " bottleneck_channels,\n",
150 | " kernel_size=kernel_size,\n",
151 | " bias=False,\n",
152 | " dilation=1,\n",
153 | " stride=2,\n",
154 | " output_padding=padding_size,\n",
155 | " padding=padding_size,\n",
156 | " groups=groups,\n",
157 | " )\n",
158 | " elif downsample:\n",
159 | " bottleneck_conv = nn.Conv2d(\n",
160 | " bottleneck_channels,\n",
161 | " bottleneck_channels,\n",
162 | " kernel_size=kernel_size,\n",
163 | " bias=False,\n",
164 | " dilation=dilation,\n",
165 | " stride=2,\n",
166 | " padding=padding_size,\n",
167 | " groups=groups,\n",
168 | " )\n",
169 | " else:\n",
170 | " bottleneck_conv = nn.Conv2d(\n",
171 | " bottleneck_channels,\n",
172 | " bottleneck_channels,\n",
173 | " kernel_size=kernel_size,\n",
174 | " bias=False,\n",
175 | " dilation=dilation,\n",
176 | " padding=padding_size,\n",
177 | " groups=groups,\n",
178 | " )\n",
179 | "\n",
180 | " self.layers = nn.Sequential(\n",
181 | " OrderedDict(\n",
182 | " [\n",
183 | " # First projection with 1x1 kernel\n",
184 | " ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)),\n",
185 | " ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels),\n",
186 | " nn.ReLU(inplace=True))),\n",
187 | " # Second conv block\n",
188 | " ('conv', bottleneck_conv),\n",
189 | " ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))),\n",
190 | " # Final projection with 1x1 kernel\n",
191 | " ('conv_up_project', nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)),\n",
192 | " ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels),\n",
193 | " nn.ReLU(inplace=True))),\n",
194 | " # Regulariser\n",
195 | " ('dropout', nn.Dropout2d(p=dropout)),\n",
196 | " ]\n",
197 | " )\n",
198 | " )\n",
199 | "\n",
200 | " if out_channels == in_channels and not downsample and not upsample:\n",
201 | " self.projection = None\n",
202 | " else:\n",
203 | " projection = OrderedDict()\n",
204 | " if upsample:\n",
205 | " projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)})\n",
206 | " elif downsample:\n",
207 | " projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)})\n",
208 | " projection.update(\n",
209 | " {\n",
210 | " 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),\n",
211 | " 'bn_skip_proj': nn.BatchNorm2d(out_channels),\n",
212 | " }\n",
213 | " )\n",
214 | " self.projection = nn.Sequential(projection)\n",
215 | "\n",
216 | " # pylint: disable=arguments-differ\n",
217 | " def forward(self, *args):\n",
218 | " (x,) = args\n",
219 | " x_residual = self.layers(x)\n",
220 | " if self.projection is not None:\n",
221 | " if self._downsample:\n",
222 | " # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer\n",
223 | " x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0)\n",
224 | " return x_residual + self.projection(x)\n",
225 | " return x_residual + x"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": null,
231 | "id": "1bd71c90-9228-4a9f-9821-20328c57ba4d",
232 | "metadata": {},
233 | "outputs": [],
234 | "source": [
235 | "#export\n",
236 | "class Interpolate(nn.Module):\n",
237 | " def __init__(self, scale_factor: int = 2):\n",
238 | " super().__init__()\n",
239 | " self._interpolate = nn.functional.interpolate\n",
240 | " self._scale_factor = scale_factor\n",
241 | "\n",
242 | " # pylint: disable=arguments-differ\n",
243 | " def forward(self, x):\n",
244 | " return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False)\n",
245 | "\n",
246 | "\n",
247 | "class UpsamplingConcat(nn.Module):\n",
248 | " def __init__(self, in_channels, out_channels, scale_factor=2):\n",
249 | " super().__init__()\n",
250 | "\n",
251 | " self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)\n",
252 | "\n",
253 | " self.conv = nn.Sequential(\n",
254 | " nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),\n",
255 | " nn.BatchNorm2d(out_channels),\n",
256 | " nn.ReLU(inplace=True),\n",
257 | " nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),\n",
258 | " nn.BatchNorm2d(out_channels),\n",
259 | " nn.ReLU(inplace=True),\n",
260 | " )\n",
261 | "\n",
262 | " def forward(self, x_to_upsample, x):\n",
263 | " x_to_upsample = self.upsample(x_to_upsample)\n",
264 | " x_to_upsample = torch.cat([x, x_to_upsample], dim=1)\n",
265 | " return self.conv(x_to_upsample)\n",
266 | "\n",
267 | "\n",
268 | "class UpsamplingAdd(nn.Module):\n",
269 | " def __init__(self, in_channels, out_channels, scale_factor=2):\n",
270 | " super().__init__()\n",
271 | " self.upsample_layer = nn.Sequential(\n",
272 | " nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),\n",
273 | " nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),\n",
274 | " nn.BatchNorm2d(out_channels),\n",
275 | " )\n",
276 | "\n",
277 | " def forward(self, x, x_skip):\n",
278 | " x = self.upsample_layer(x)\n",
279 | " return x + x_skip"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": null,
285 | "id": "b46ad863-f4b8-400b-a62e-b1d593ff28ca",
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "#export\n",
290 | "class ResBlock(nn.Module):\n",
291 |         "    \" A simple ResNet block\"\n",
292 | " def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, norm='bn', activation='relu'):\n",
293 | " super().__init__()\n",
294 | " self.convs = nn.Sequential(ConvBlock(in_channels, out_channels, kernel_size, stride, norm=norm, activation=activation),\n",
295 | " ConvBlock(out_channels, out_channels, norm=norm, activation=activation)\n",
296 | " )\n",
297 | " id_path = [ConvBlock(in_channels, out_channels, kernel_size=1, activation=None, norm=None)]\n",
298 | " self.activation = get_activation(activation)\n",
299 | " if stride!=1: id_path.insert(1, nn.AvgPool2d(2, stride, ceil_mode=True))\n",
300 | " self.id_path = nn.Sequential(*id_path)\n",
301 | " \n",
302 | " def forward(self, x):\n",
303 | " return self.activation(self.convs(x) + self.id_path(x))"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": null,
309 | "id": "87cd945e-757e-4d81-abee-8ba5fce5e02b",
310 | "metadata": {},
311 | "outputs": [],
312 | "source": [
313 | "res_block = ResBlock(64, 128, stride=2)"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": null,
319 | "id": "cc730d95-89a2-454b-8949-fa45879ccf41",
320 | "metadata": {},
321 | "outputs": [
322 | {
323 | "data": {
324 | "text/plain": [
325 | "torch.Size([1, 128, 10, 10])"
326 | ]
327 | },
328 | "execution_count": null,
329 | "metadata": {},
330 | "output_type": "execute_result"
331 | }
332 | ],
333 | "source": [
334 | "x = torch.rand(1,64, 20,20)\n",
335 | "res_block(x).shape"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "id": "4568af38-4494-4fbe-aacb-6d3c5fa79e51",
342 | "metadata": {},
343 | "outputs": [],
344 | "source": [
345 | "#export\n",
346 | "class SpatialGRU(nn.Module):\n",
347 | " \"\"\"A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a\n",
348 | " convolutional gated recurrent unit over the data\"\"\"\n",
349 | "\n",
350 | " def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'):\n",
351 | " super().__init__()\n",
352 | " self.input_size = input_size\n",
353 | " self.hidden_size = hidden_size\n",
354 | " self.gru_bias_init = gru_bias_init\n",
355 | "\n",
356 | " self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)\n",
357 | " self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)\n",
358 | "\n",
359 | " self.conv_state_tilde = ConvBlock(\n",
360 | " input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation\n",
361 | " )\n",
362 | "\n",
363 | " def forward(self, x, state=None, flow=None, mode='bilinear'):\n",
364 | " # pylint: disable=unused-argument, arguments-differ\n",
365 | " # Check size\n",
366 | " assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.'\n",
367 | " b, timesteps, c, h, w = x.size()\n",
368 | " assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}'\n",
369 | "\n",
370 | " # recurrent layers\n",
371 | " rnn_output = []\n",
372 | " rnn_state = torch.zeros(b, self.hidden_size, h, w, device=x.device) if state is None else state\n",
373 | " for t in range(timesteps):\n",
374 | " x_t = x[:, t]\n",
375 | " if flow is not None:\n",
376 | " rnn_state = warp_features(rnn_state, flow[:, t], mode=mode)\n",
377 | "\n",
378 | " # propagate rnn state\n",
379 | " rnn_state = self.gru_cell(x_t, rnn_state)\n",
380 | " rnn_output.append(rnn_state)\n",
381 | "\n",
382 | " # reshape rnn output to batch tensor\n",
383 | " return torch.stack(rnn_output, dim=1)\n",
384 | "\n",
385 | " def gru_cell(self, x, state):\n",
386 | " # Compute gates\n",
387 | " x_and_state = torch.cat([x, state], dim=1)\n",
388 | " update_gate = self.conv_update(x_and_state)\n",
389 | " reset_gate = self.conv_reset(x_and_state)\n",
390 | " # Add bias to initialise gate as close to identity function\n",
391 | " update_gate = torch.sigmoid(update_gate + self.gru_bias_init)\n",
392 | " reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init)\n",
393 | "\n",
394 | " # Compute proposal state, activation is defined in norm_act_config (can be tanh, ReLU etc)\n",
395 | " state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1))\n",
396 | "\n",
397 | " output = (1.0 - update_gate) * state + update_gate * state_tilde\n",
398 | " return output"
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": null,
404 | "id": "25fdb367-7794-48bc-94aa-1ba421a7c8b0",
405 | "metadata": {},
406 | "outputs": [],
407 | "source": [
408 | "sgru = SpatialGRU(input_size=32, hidden_size=64)"
409 | ]
410 | },
411 | {
412 | "cell_type": "markdown",
413 | "id": "a55cd120-755b-4783-b356-03044cb5abe3",
414 | "metadata": {},
415 | "source": [
416 | "without hidden state"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": null,
422 | "id": "8207186f-761b-4744-8e37-48f47101525d",
423 | "metadata": {},
424 | "outputs": [],
425 | "source": [
426 | "x = torch.rand(1,3,32,8, 8)\n",
427 | "test_eq(sgru(x).shape, (1,3,64,8,8))"
428 | ]
429 | },
430 | {
431 | "cell_type": "markdown",
432 | "id": "bdcaa79f-6d57-4b74-b956-f9dcc3df0589",
433 | "metadata": {},
434 | "source": [
435 | "with hidden\n",
436 | "> hidden.shape = `(bs, hidden_size, h, w)`"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": null,
442 | "id": "53e47adc-2ce8-4cd9-86b4-433cbd47dbd1",
443 | "metadata": {},
444 | "outputs": [
445 | {
446 | "data": {
447 | "text/plain": [
448 | "torch.Size([1, 3, 64, 8, 8])"
449 | ]
450 | },
451 | "execution_count": null,
452 | "metadata": {},
453 | "output_type": "execute_result"
454 | }
455 | ],
456 | "source": [
457 | "x = torch.rand(1,3,32,8, 8)\n",
458 | "hidden = torch.rand(1,64,8,8)\n",
459 | "# test_eq(sgru(x).shape, (1,3,64,8,8))\n",
460 | "\n",
461 | "sgru(x, hidden).shape"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": null,
467 | "id": "4f4a2fe0-535a-4d51-ab22-81bcefac5c7c",
468 | "metadata": {},
469 | "outputs": [],
470 | "source": [
471 | "#export\n",
472 | "class CausalConv3d(nn.Module):\n",
473 | " def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False):\n",
474 | " super().__init__()\n",
475 | " assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'\n",
476 | " time_pad = (kernel_size[0] - 1) * dilation[0]\n",
477 | " height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2\n",
478 | " width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2\n",
479 | "\n",
480 | " # Pad temporally on the left\n",
481 | " self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)\n",
482 | " self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias)\n",
483 | " self.norm = nn.BatchNorm3d(out_channels)\n",
484 | " self.activation = nn.ReLU(inplace=True)\n",
485 | "\n",
486 | " def forward(self, *inputs):\n",
487 | " (x,) = inputs\n",
488 | " x = self.pad(x)\n",
489 | " x = self.conv(x)\n",
490 | " x = self.norm(x)\n",
491 | " x = self.activation(x)\n",
492 | " return x"
493 | ]
494 | },
495 | {
496 | "cell_type": "code",
497 | "execution_count": null,
498 | "id": "f97c1b07-dcb1-444a-8372-bdd3ff9dd95a",
499 | "metadata": {},
500 | "outputs": [],
501 | "source": [
502 | "cc3d = CausalConv3d(64, 128)"
503 | ]
504 | },
505 | {
506 | "cell_type": "code",
507 | "execution_count": null,
508 | "id": "c00cb8c1-7513-40e2-85ef-646ddd20d585",
509 | "metadata": {},
510 | "outputs": [
511 | {
512 | "data": {
513 | "text/plain": [
514 | "torch.Size([1, 128, 4, 8, 8])"
515 | ]
516 | },
517 | "execution_count": null,
518 | "metadata": {},
519 | "output_type": "execute_result"
520 | }
521 | ],
522 | "source": [
523 | "x = torch.rand(1,64,4,8,8)\n",
524 | "cc3d(x).shape"
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": null,
530 | "id": "6ac267c7-a4f3-47dd-a15a-04c61c82d434",
531 | "metadata": {},
532 | "outputs": [],
533 | "source": [
534 | "#export\n",
535 | "def conv_1x1x1_norm_activated(in_channels, out_channels):\n",
536 | " \"\"\"1x1x1 3D convolution, normalization and activation layer.\"\"\"\n",
537 | " return nn.Sequential(\n",
538 | " OrderedDict(\n",
539 | " [\n",
540 | " ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)),\n",
541 | " ('norm', nn.BatchNorm3d(out_channels)),\n",
542 | " ('activation', nn.ReLU(inplace=True)),\n",
543 | " ]\n",
544 | " )\n",
545 | " )"
546 | ]
547 | },
548 | {
549 | "cell_type": "code",
550 | "execution_count": null,
551 | "id": "e318b3fd-d944-4d98-8965-88a1dd2d8d8e",
552 | "metadata": {},
553 | "outputs": [
554 | {
555 | "data": {
556 | "text/plain": [
557 | "Sequential(\n",
558 | " (conv): Conv3d(2, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
559 | " (norm): BatchNorm3d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
560 | " (activation): ReLU(inplace=True)\n",
561 | ")"
562 | ]
563 | },
564 | "execution_count": null,
565 | "metadata": {},
566 | "output_type": "execute_result"
567 | }
568 | ],
569 | "source": [
570 | "conv_1x1x1_norm_activated(2, 4)"
571 | ]
572 | },
573 | {
574 | "cell_type": "code",
575 | "execution_count": null,
576 | "id": "5574dfd4-3891-464d-b0d6-a67466fed549",
577 | "metadata": {},
578 | "outputs": [],
579 | "source": [
580 | "#export\n",
581 | "class Bottleneck3D(nn.Module):\n",
582 | " \"\"\"\n",
583 | " Defines a bottleneck module with a residual connection\n",
584 | " \"\"\"\n",
585 | "\n",
586 | " def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)):\n",
587 | " super().__init__()\n",
588 | " bottleneck_channels = in_channels // 2\n",
589 | " out_channels = out_channels or in_channels\n",
590 | "\n",
591 | " self.layers = nn.Sequential(\n",
592 | " OrderedDict(\n",
593 | " [\n",
594 | " # First projection with 1x1 kernel\n",
595 | " ('conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)),\n",
596 | " # Second conv block\n",
597 | " (\n",
598 | " 'conv',\n",
599 | " CausalConv3d(\n",
600 | " bottleneck_channels,\n",
601 | " bottleneck_channels,\n",
602 | " kernel_size=kernel_size,\n",
603 | " dilation=dilation,\n",
604 | " bias=False,\n",
605 | " ),\n",
606 | " ),\n",
607 | " # Final projection with 1x1 kernel\n",
608 | " ('conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)),\n",
609 | " ]\n",
610 | " )\n",
611 | " )\n",
612 | "\n",
613 | " if out_channels != in_channels:\n",
614 | " self.projection = nn.Sequential(\n",
615 | " nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),\n",
616 | " nn.BatchNorm3d(out_channels),\n",
617 | " )\n",
618 | " else:\n",
619 | " self.projection = None\n",
620 | "\n",
621 | " def forward(self, *args):\n",
622 | " (x,) = args\n",
623 | " x_residual = self.layers(x)\n",
624 | " x_features = self.projection(x) if self.projection is not None else x\n",
625 | " return x_residual + x_features"
626 | ]
627 | },
628 | {
629 | "cell_type": "code",
630 | "execution_count": null,
631 | "id": "3dbcf0d0-1cc2-4914-9ba0-30b99f03a4c3",
632 | "metadata": {},
633 | "outputs": [],
634 | "source": [
635 | "bn3d = Bottleneck3D(8, 12)"
636 | ]
637 | },
638 | {
639 | "cell_type": "code",
640 | "execution_count": null,
641 | "id": "12917923-c26f-4c40-b2ee-2836d2ffb4f6",
642 | "metadata": {},
643 | "outputs": [
644 | {
645 | "data": {
646 | "text/plain": [
647 | "torch.Size([1, 12, 4, 8, 8])"
648 | ]
649 | },
650 | "execution_count": null,
651 | "metadata": {},
652 | "output_type": "execute_result"
653 | }
654 | ],
655 | "source": [
656 | "x = torch.rand(1,8,4,8,8)\n",
657 | "bn3d(x).shape"
658 | ]
659 | },
660 | {
661 | "cell_type": "code",
662 | "execution_count": null,
663 | "id": "1bd3b9ac-7faa-4da6-b59b-86792e649590",
664 | "metadata": {},
665 | "outputs": [],
666 | "source": [
667 | "#export\n",
668 | "class PyramidSpatioTemporalPooling(nn.Module):\n",
669 | " \"\"\" Spatio-temporal pyramid pooling.\n",
670 |         "    Performs 3D average pooling followed by a 1x1x1 convolution to reduce the number of channels, then upsampling.\n",
671 |         "    pool_sizes is a list of kernel sizes, usually [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)]\n",
672 | " \"\"\"\n",
673 | "\n",
674 | " def __init__(self, in_channels, reduction_channels, pool_sizes):\n",
675 | " super().__init__()\n",
676 | " self.features = []\n",
677 | " for pool_size in pool_sizes:\n",
678 | " assert pool_size[0] == 2, (\n",
679 |         "                \"Time kernel should be 2 as PyTorch raises an error when \" \"padding with more than half the kernel size\"\n",
680 | " )\n",
681 | " stride = (1, *pool_size[1:])\n",
682 | " padding = (pool_size[0] - 1, 0, 0)\n",
683 | " self.features.append(\n",
684 | " nn.Sequential(\n",
685 | " OrderedDict(\n",
686 | " [\n",
687 | " # Pad the input tensor but do not take into account zero padding into the average.\n",
688 | " (\n",
689 | " 'avgpool',\n",
690 | " torch.nn.AvgPool3d(\n",
691 | " kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False\n",
692 | " ),\n",
693 | " ),\n",
694 | " ('conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels)),\n",
695 | " ]\n",
696 | " )\n",
697 | " )\n",
698 | " )\n",
699 | " self.features = nn.ModuleList(self.features)\n",
700 | "\n",
701 | " def forward(self, *inputs):\n",
702 | " (x,) = inputs\n",
703 | " b, _, t, h, w = x.shape\n",
704 | " # Do not include current tensor when concatenating\n",
705 | " out = []\n",
706 | " for f in self.features:\n",
707 | " # Remove unnecessary padded values (time dimension) on the right\n",
708 | " x_pool = f(x)[:, :, :-1].contiguous()\n",
709 | " c = x_pool.shape[1]\n",
710 | " x_pool = nn.functional.interpolate(\n",
711 | " x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False\n",
712 | " )\n",
713 | " x_pool = x_pool.view(b, c, t, h, w)\n",
714 | " out.append(x_pool)\n",
715 | " out = torch.cat(out, 1)\n",
716 | " return out"
717 | ]
718 | },
719 | {
720 | "cell_type": "code",
721 | "execution_count": null,
722 | "id": "45902f68-7d7f-4935-ac97-00e332ef9d48",
723 | "metadata": {},
724 | "outputs": [],
725 | "source": [
726 | "h,w = (64,64)\n",
727 | "ptp = PyramidSpatioTemporalPooling(12, 12, [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)])"
728 | ]
729 | },
730 | {
731 | "cell_type": "code",
732 | "execution_count": null,
733 | "id": "0e0db71a-dfef-4389-bc9c-483adc060008",
734 | "metadata": {},
735 | "outputs": [
736 | {
737 | "data": {
738 | "text/plain": [
739 | "torch.Size([1, 36, 4, 64, 64])"
740 | ]
741 | },
742 | "execution_count": null,
743 | "metadata": {},
744 | "output_type": "execute_result"
745 | }
746 | ],
747 | "source": [
748 | "x = torch.rand(1,12,4,64,64)\n",
749 | "\n",
750 | "#the output is concatenated...\n",
751 | "ptp(x).shape"
752 | ]
753 | },
754 | {
755 | "cell_type": "code",
756 | "execution_count": null,
757 | "id": "7575bf79-6718-4a70-90ed-985b79fea483",
758 | "metadata": {},
759 | "outputs": [],
760 | "source": [
761 | "#export\n",
762 | "class TemporalBlock(nn.Module):\n",
763 | " \"\"\" Temporal block with the following layers:\n",
764 | " - 2x3x3, 1x3x3, spatio-temporal pyramid pooling\n",
765 | " - dropout\n",
766 | " - skip connection.\n",
767 | " \"\"\"\n",
768 | "\n",
769 | " def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None):\n",
770 | " super().__init__()\n",
771 | " self.in_channels = in_channels\n",
772 | " self.half_channels = in_channels // 2\n",
773 | " self.out_channels = out_channels or self.in_channels\n",
774 | " self.kernels = [(2, 3, 3), (1, 3, 3)]\n",
775 | "\n",
776 | " # Flag for spatio-temporal pyramid pooling\n",
777 | " self.use_pyramid_pooling = use_pyramid_pooling\n",
778 | "\n",
779 | " # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1\n",
780 | " self.convolution_paths = []\n",
781 | " for kernel_size in self.kernels:\n",
782 | " self.convolution_paths.append(\n",
783 | " nn.Sequential(\n",
784 | " conv_1x1x1_norm_activated(self.in_channels, self.half_channels),\n",
785 | " CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size),\n",
786 | " )\n",
787 | " )\n",
788 | " self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels))\n",
789 | " self.convolution_paths = nn.ModuleList(self.convolution_paths)\n",
790 | "\n",
791 | " agg_in_channels = len(self.convolution_paths) * self.half_channels\n",
792 | "\n",
793 | " if self.use_pyramid_pooling:\n",
794 | " assert pool_sizes is not None, \"setting must contain the list of kernel_size, but is None.\"\n",
795 | " reduction_channels = self.in_channels // 3\n",
796 | " self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes)\n",
797 | " agg_in_channels += len(pool_sizes) * reduction_channels\n",
798 | "\n",
799 | " # Feature aggregation\n",
800 | " self.aggregation = nn.Sequential(\n",
801 | " conv_1x1x1_norm_activated(agg_in_channels, self.out_channels),)\n",
802 | "\n",
803 | " if self.out_channels != self.in_channels:\n",
804 | " self.projection = nn.Sequential(\n",
805 | " nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1, bias=False),\n",
806 | " nn.BatchNorm3d(self.out_channels),\n",
807 | " )\n",
808 | " else:\n",
809 | " self.projection = None\n",
810 | "\n",
811 | " def forward(self, *inputs):\n",
812 | " (x,) = inputs\n",
813 | " x_paths = []\n",
814 | " for conv in self.convolution_paths:\n",
815 | " x_paths.append(conv(x))\n",
816 | " x_residual = torch.cat(x_paths, dim=1)\n",
817 | " if self.use_pyramid_pooling:\n",
818 | " x_pool = self.pyramid_pooling(x)\n",
819 | " x_residual = torch.cat([x_residual, x_pool], dim=1)\n",
820 | " x_residual = self.aggregation(x_residual)\n",
821 | "\n",
822 | " if self.out_channels != self.in_channels:\n",
823 | " x = self.projection(x)\n",
824 | " x = x + x_residual\n",
825 | " return x"
826 | ]
827 | },
828 | {
829 | "cell_type": "code",
830 | "execution_count": null,
831 | "id": "385f0475-0c44-4b90-83bf-6869c24e105e",
832 | "metadata": {},
833 | "outputs": [
834 | {
835 | "data": {
836 | "text/plain": [
837 | "torch.Size([1, 8, 4, 64, 64])"
838 | ]
839 | },
840 | "execution_count": null,
841 | "metadata": {},
842 | "output_type": "execute_result"
843 | }
844 | ],
845 | "source": [
846 | "tb = TemporalBlock(4,8)\n",
847 | "\n",
848 | "x = torch.rand(1,4,4,64,64)\n",
849 | "\n",
850 | "tb(x).shape"
851 | ]
852 | },
853 | {
854 | "cell_type": "code",
855 | "execution_count": null,
856 | "id": "65e50f4d-40f1-48fc-b390-66d0fe7a4f16",
857 | "metadata": {},
858 | "outputs": [
859 | {
860 | "data": {
861 | "text/plain": [
862 | "TemporalBlock(\n",
863 | " (convolution_paths): ModuleList(\n",
864 | " (0): Sequential(\n",
865 | " (0): Sequential(\n",
866 | " (conv): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
867 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
868 | " (activation): ReLU(inplace=True)\n",
869 | " )\n",
870 | " (1): CausalConv3d(\n",
871 | " (pad): ConstantPad3d(padding=(1, 1, 1, 1, 1, 0), value=0)\n",
872 | " (conv): Conv3d(2, 2, kernel_size=(2, 3, 3), stride=(1, 1, 1), bias=False)\n",
873 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
874 | " (activation): ReLU(inplace=True)\n",
875 | " )\n",
876 | " )\n",
877 | " (1): Sequential(\n",
878 | " (0): Sequential(\n",
879 | " (conv): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
880 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
881 | " (activation): ReLU(inplace=True)\n",
882 | " )\n",
883 | " (1): CausalConv3d(\n",
884 | " (pad): ConstantPad3d(padding=(1, 1, 1, 1, 0, 0), value=0)\n",
885 | " (conv): Conv3d(2, 2, kernel_size=(1, 3, 3), stride=(1, 1, 1), bias=False)\n",
886 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
887 | " (activation): ReLU(inplace=True)\n",
888 | " )\n",
889 | " )\n",
890 | " (2): Sequential(\n",
891 | " (conv): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
892 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
893 | " (activation): ReLU(inplace=True)\n",
894 | " )\n",
895 | " )\n",
896 | " (aggregation): Sequential(\n",
897 | " (0): Sequential(\n",
898 | " (conv): Conv3d(6, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
899 | " (norm): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
900 | " (activation): ReLU(inplace=True)\n",
901 | " )\n",
902 | " )\n",
903 | " (projection): Sequential(\n",
904 | " (0): Conv3d(4, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
905 | " (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
906 | " )\n",
907 | ")"
908 | ]
909 | },
910 | "execution_count": null,
911 | "metadata": {},
912 | "output_type": "execute_result"
913 | }
914 | ],
915 | "source": [
916 | "tb"
917 | ]
918 | },
919 | {
920 | "cell_type": "code",
921 | "execution_count": null,
922 | "id": "5f671904-4256-48d5-8787-72556d225456",
923 | "metadata": {},
924 | "outputs": [],
925 | "source": [
926 | "#export\n",
927 | "class FuturePrediction(torch.nn.Module):\n",
928 | " def __init__(self, in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3):\n",
929 | " super().__init__()\n",
930 | " self.n_gru_blocks = n_gru_blocks\n",
931 | "\n",
932 | " # Convolutional recurrent model with z_t as an initial hidden state and inputs the sample\n",
933 | " # from the probabilistic model. The architecture of the model is:\n",
934 | " # [Spatial GRU - [Bottleneck] x n_res_layers] x n_gru_blocks\n",
935 | " self.spatial_grus = []\n",
936 | " self.res_blocks = []\n",
937 | "\n",
938 | " for i in range(self.n_gru_blocks):\n",
939 | " gru_in_channels = latent_dim if i == 0 else in_channels\n",
940 | " self.spatial_grus.append(SpatialGRU(gru_in_channels, in_channels))\n",
941 | " self.res_blocks.append(torch.nn.Sequential(*[Bottleneck(in_channels)\n",
942 | " for _ in range(n_res_layers)]))\n",
943 | "\n",
944 | " self.spatial_grus = torch.nn.ModuleList(self.spatial_grus)\n",
945 | " self.res_blocks = torch.nn.ModuleList(self.res_blocks)\n",
946 | "\n",
947 | " def forward(self, x, hidden_state):\n",
948 | " # x has shape (b, n_future, c, h, w), hidden_state (b, c, h, w)\n",
949 | " for i in range(self.n_gru_blocks):\n",
950 | " x = self.spatial_grus[i](x, hidden_state, flow=None)\n",
951 | " b, n_future, c, h, w = x.shape\n",
952 | "\n",
953 | " x = self.res_blocks[i](x.view(b * n_future, c, h, w))\n",
954 | " x = x.view(b, n_future, c, h, w)\n",
955 | "\n",
956 | " return x"
957 | ]
958 | },
959 | {
960 | "cell_type": "code",
961 | "execution_count": null,
962 | "id": "7ce0e2ce-a862-4631-b5f7-a6b2c9928415",
963 | "metadata": {},
964 | "outputs": [],
965 | "source": [
966 | "fp = FuturePrediction(32, 32, 4, 4)\n",
967 | "\n",
968 | "x = torch.rand(1,4, 32, 64, 64)\n",
969 | "hidden = torch.rand(1,32,64,64)\n",
970 | "test_eq(fp(x, hidden).shape, x.shape)"
971 | ]
972 | },
973 | {
974 | "cell_type": "markdown",
975 | "id": "542aa109-21e9-48b8-a87a-e4bbd5f59b8d",
976 | "metadata": {},
977 | "source": [
978 | "## Export"
979 | ]
980 | },
981 | {
982 | "cell_type": "code",
983 | "execution_count": null,
984 | "id": "d4c520b5-d666-4ebb-a9c0-02992b822b8d",
985 | "metadata": {},
986 | "outputs": [
987 | {
988 | "name": "stdout",
989 | "output_type": "stream",
990 | "text": [
991 | "Converted 00_model.ipynb.\n",
992 | "Converted 01_layers.ipynb.\n",
993 | "Converted index.ipynb.\n"
994 | ]
995 | }
996 | ],
997 | "source": [
998 | "# hide\n",
999 | "from nbdev.export import *\n",
1000 | "notebook2script()"
1001 | ]
1002 | }
1003 | ],
1004 | "metadata": {
1005 | "kernelspec": {
1006 | "display_name": "Python 3",
1007 | "language": "python",
1008 | "name": "python3"
1009 | }
1010 | },
1011 | "nbformat": 4,
1012 | "nbformat_minor": 5
1013 | }
1014 |
--------------------------------------------------------------------------------