├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── .name
│   ├── SwinLSTM-D.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── Pretrained
│   └── trained_model_state_dict
├── README.md
├── SwinLSTM_B.py
├── SwinLSTM_D.py
├── __pycache__
│   ├── SwinLSTM_D.cpython-39.pyc
│   ├── configs.cpython-39.pyc
│   ├── dataset.cpython-39.pyc
│   ├── functions.cpython-39.pyc
│   └── utils.cpython-39.pyc
├── architecture.png
├── configs.py
├── data
│   ├── test_data_download_link.md
│   └── train-images-idx3-ubyte.gz
├── dataset.py
├── functions.py
├── requirements.txt
├── test.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/.name:
--------------------------------------------------------------------------------
1 | main.py
--------------------------------------------------------------------------------
/.idea/SwinLSTM-D.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Song Tang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Pretrained/trained_model_state_dict:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/Pretrained/trained_model_state_dict
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SwinLSTM: A new recurrent cell for spatiotemporal modeling
2 |
3 | This repository contains the official PyTorch implementation of the following paper:
4 |
5 | SwinLSTM: Improving Spatiotemporal Prediction Accuracy using Swin Transformer and LSTM **(ICCV 2023)**
6 |
7 | Paper: http://arxiv.org/abs/2308.09891v2
8 |
9 |
10 |
11 |
12 | ## Introduction
13 | ![architecture](architecture.png)
14 | Integrating CNNs and RNNs to capture spatiotemporal dependencies is a prevalent strategy for spatiotemporal prediction tasks. However, the property of CNNs to learn local spatial information decreases their efficiency in capturing spatiotemporal dependencies, thereby limiting their prediction accuracy. In this paper, we propose a new recurrent cell, SwinLSTM, which integrates Swin Transformer blocks and the simplified LSTM, an extension that replaces the convolutional structure in ConvLSTM with the self-attention mechanism. Furthermore, we construct a network with SwinLSTM cell as the core for spatiotemporal prediction. Without using unique tricks, SwinLSTM outperforms state-of-the-art methods on Moving MNIST, Human3.6m, TaxiBJ, and KTH datasets. In particular, it exhibits a significant improvement in prediction accuracy compared to ConvLSTM. Our competitive experimental results demonstrate that learning global spatial dependencies is more advantageous for models to capture spatiotemporal dependencies. We hope that SwinLSTM can serve as a solid baseline to promote the advancement of spatiotemporal prediction accuracy.
15 |
16 | ## Overview
17 | - `Pretrained/` contains weights pretrained on Moving MNIST.
18 | - `data/` contains the MNIST training images and the download link for the Moving MNIST test set.
19 | - `SwinLSTM_B.py` contains the model with a single SwinLSTM cell.
20 | - `SwinLSTM_D.py` contains the model with multiple SwinLSTM cells.
21 | - `dataset.py` contains the training and validation dataloaders.
22 | - `functions.py` contains the train and test functions.
23 | - `train.py` is the core file for the training pipeline.
24 | - `test.py` is a script for a quick test (see the usage sketch below).
25 |
26 | ## Requirements
27 | - python >= 3.8
28 | - torch == 1.11.0
29 | - torchvision == 0.12.0
30 | - numpy
31 | - matplotlib
32 | - scikit-image == 0.19.2
33 | - timm == 0.4.12
34 | - einops == 0.4.1
35 |
36 | ## Citation
37 | If you find this work useful in your research, please cite the paper:
38 | ```
39 | @inproceedings{tang2023swinlstm,
40 | title={SwinLSTM: Improving Spatiotemporal Prediction Accuracy using Swin Transformer and LSTM},
41 | author={Tang, Song and Li, Chuang and Zhang, Pu and Tang, RongNian},
42 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
43 | pages={13470--13479},
44 | year={2023}
45 | }
46 |
47 | ```
48 |
49 | ## Acknowledgment
50 | This code is based on [Swin Transformer](https://github.com/microsoft/Swin-Transformer). We extend our sincere appreciation for their valuable contributions.
51 |
52 |
--------------------------------------------------------------------------------
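A minimal usage sketch to accompany the README above (editorial addition, not part of the repository). It assumes the default values from `configs.py` (64x64 inputs, patch size 2, embedding dimension 128, window size 4) and that `Pretrained/trained_model_state_dict` holds a plain PyTorch `state_dict` for the SwinLSTM-D model; the actual training and testing entry points are `train.py` and `test.py`.

```
import torch
from SwinLSTM_D import SwinLSTM

# Build the model with the defaults from configs.py (assumed to match the checkpoint).
model = SwinLSTM(img_size=64, patch_size=2, in_chans=1, embed_dim=128,
                 depths_downsample=[2, 6], depths_upsample=[6, 2],
                 num_heads=[4, 8], window_size=4)

# Assumes the checkpoint file is a plain state_dict saved with torch.save.
state_dict = torch.load('Pretrained/trained_model_state_dict', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```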
/SwinLSTM_B.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
4 |
5 |
6 | class Mlp(nn.Module):
7 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
8 | super(Mlp, self).__init__()
9 | out_features = out_features or in_features
10 | hidden_features = hidden_features or in_features
11 | self.fc1 = nn.Linear(in_features, hidden_features)
12 | self.act = act_layer()
13 | self.fc2 = nn.Linear(hidden_features, out_features)
14 | self.drop = nn.Dropout(drop)
15 |
16 | def forward(self, x):
17 | x = self.fc1(x)
18 | x = self.act(x)
19 | x = self.drop(x)
20 | x = self.fc2(x)
21 | x = self.drop(x)
22 | return x
23 |
24 |
25 | def window_partition(x, window_size):
26 | """
27 | Args:
28 | x: (B, H, W, C)
29 | window_size (int): window size
30 |
31 | Returns:
32 | windows: (num_windows*B, window_size, window_size, C)
33 | """
34 | B, H, W, C = x.shape
35 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
36 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37 | return windows
38 |
39 |
40 | def window_reverse(windows, window_size, H, W):
41 | """
42 | Args:
43 | windows: (num_windows*B, window_size, window_size, C)
44 | window_size (int): Window size
45 | H (int): Height of image
46 | W (int): Width of image
47 |
48 | Returns:
49 | x: (B, H, W, C)
50 | """
51 | B = int(windows.shape[0] / (H * W / window_size / window_size))
52 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
53 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
54 | return x
55 |
56 |
57 | class WindowAttention(nn.Module):
58 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
59 | It supports both shifted and non-shifted windows.
60 |
61 | Args:
62 | dim (int): Number of input channels.
63 | window_size (tuple[int]): The height and width of the window.
64 | num_heads (int): Number of attention heads.
65 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
66 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
67 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
68 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
69 | """
70 |
71 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
72 |
73 | super(WindowAttention, self).__init__()
74 | self.dim = dim
75 | self.window_size = window_size # Wh, Ww
76 | self.num_heads = num_heads
77 | head_dim = dim // num_heads
78 | self.scale = qk_scale or head_dim ** -0.5
79 |
80 | # define a parameter table of relative position bias
81 | self.relative_position_bias_table = nn.Parameter(
82 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
83 |
84 | # get pair-wise relative position index for each token inside the window
85 | coords_h = torch.arange(self.window_size[0])
86 | coords_w = torch.arange(self.window_size[1])
87 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
88 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
89 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
90 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
91 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
92 | relative_coords[:, :, 1] += self.window_size[1] - 1
93 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
94 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
95 | self.register_buffer("relative_position_index", relative_position_index)
96 |
97 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
98 | self.attn_drop = nn.Dropout(attn_drop)
99 | self.proj = nn.Linear(dim, dim)
100 | self.proj_drop = nn.Dropout(proj_drop)
101 |
102 | trunc_normal_(self.relative_position_bias_table, std=.02)
103 | self.softmax = nn.Softmax(dim=-1)
104 |
105 | def forward(self, x, mask=None):
106 | """
107 | Args:
108 | x: input features with shape of (num_windows*B, N, C)
109 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
110 | """
111 |
112 | B_, N, C = x.shape
113 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
114 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
115 |
116 | q = q * self.scale
117 | attn = (q @ k.transpose(-2, -1))
118 |
119 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
120 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
121 |
122 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
123 | attn = attn + relative_position_bias.unsqueeze(0)
124 |
125 | if mask is not None:
126 | nW = mask.shape[0]
127 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
128 | attn = attn.view(-1, self.num_heads, N, N)
129 | attn = self.softmax(attn)
130 | else:
131 | attn = self.softmax(attn)
132 |
133 | attn = self.attn_drop(attn)
134 |
135 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
136 | x = self.proj(x)
137 | x = self.proj_drop(x)
138 | return x
139 |
140 |
141 | class SwinTransformerBlock(nn.Module):
142 | r""" Swin Transformer Block.
143 |
144 | Args:
145 | dim (int): Number of input channels.
146 | input_resolution (tuple[int]): Input resolution.
147 | num_heads (int): Number of attention heads.
148 | window_size (int): Window size.
149 | shift_size (int): Shift size for SW-MSA.
150 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
151 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
152 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
153 | drop (float, optional): Dropout rate. Default: 0.0
154 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
155 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
156 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
157 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
158 | """
159 |
160 | def __init__(self, dim, input_resolution, num_heads, window_size=2, shift_size=0,
161 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
162 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
163 | super().__init__()
164 | self.dim = dim
165 | self.input_resolution = input_resolution
166 | self.num_heads = num_heads
167 | self.window_size = window_size
168 | self.shift_size = shift_size
169 | self.mlp_ratio = mlp_ratio
170 | if min(self.input_resolution) <= self.window_size:
171 | # if window size is larger than input resolution, we don't partition windows
172 | self.shift_size = 0
173 | self.window_size = min(self.input_resolution)
174 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
175 |
176 | self.norm1 = norm_layer(dim)
177 | self.attn = WindowAttention(
178 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
179 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
180 |
181 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
182 | self.norm2 = norm_layer(dim)
183 | mlp_hidden_dim = int(dim * mlp_ratio)
184 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
185 | self.red = nn.Linear(2 * dim, dim)
186 | if self.shift_size > 0:
187 | # calculate attention mask for SW-MSA
188 | H, W = self.input_resolution
189 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
190 | h_slices = (slice(0, -self.window_size),
191 | slice(-self.window_size, -self.shift_size),
192 | slice(-self.shift_size, None))
193 | w_slices = (slice(0, -self.window_size),
194 | slice(-self.window_size, -self.shift_size),
195 | slice(-self.shift_size, None))
196 | cnt = 0
197 | for h in h_slices:
198 | for w in w_slices:
199 | img_mask[:, h, w, :] = cnt
200 | cnt += 1
201 |
202 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
203 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
204 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
205 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
206 | else:
207 | attn_mask = None
208 |
209 | self.register_buffer("attn_mask", attn_mask)
210 |
211 | def forward(self, x, hx=None):
212 | H, W = self.input_resolution
213 | B, L, C = x.shape
214 | assert L == H * W, "input feature has wrong size"
215 |
216 | shortcut = x
217 | x = self.norm1(x)
218 | if hx is not None:
219 | hx = self.norm1(hx)
220 | x = torch.cat((x, hx), -1)
221 | x = self.red(x)
222 | x = x.view(B, H, W, C)
223 |
224 | # cyclic shift
225 | if self.shift_size > 0:
226 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
227 | else:
228 | shifted_x = x
229 |
230 | # partition windows
231 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
232 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
233 |
234 | # W-MSA/SW-MSA
235 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
236 |
237 | # merge windows
238 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
239 |
240 | # reverse cyclic shift
241 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
242 |
243 | if self.shift_size > 0:
244 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
245 | else:
246 | x = shifted_x
247 |
248 | # FFN
249 | x = x.view(B, H * W, C)
250 | x = shortcut + self.drop_path(x)
251 | x = x + self.drop_path(self.mlp(self.norm2(x)))
252 |
253 | return x
254 |
255 |
256 | class PatchEmbed(nn.Module):
257 | r""" Image to Patch Embedding
258 |
259 | Args:
260 | img_size (int): Image size.
261 | patch_size (int): Patch token size.
262 | in_chans (int): Number of input image channels.
263 | embed_dim (int): Number of linear projection output channels.
264 | """
265 |
266 | def __init__(self, img_size, patch_size, in_chans, embed_dim):
267 | super(PatchEmbed, self).__init__()
268 | img_size = to_2tuple(img_size)
269 | patch_size = to_2tuple(patch_size)
270 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
271 | self.img_size = img_size
272 | self.patch_size = patch_size
273 | self.patches_resolution = patches_resolution
274 | self.num_patches = patches_resolution[0] * patches_resolution[1]
275 | self.in_chans = in_chans
276 | self.embed_dim = embed_dim
277 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
278 | self.norm = nn.LayerNorm(embed_dim)
279 |
280 | def forward(self, x):
281 | B, C, H, W = x.shape
282 | assert H == self.img_size[0] and W == self.img_size[1], \
283 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
284 | x = self.proj(x).flatten(2).transpose(1, 2)
285 | x = self.norm(x)
286 | return x
287 |
288 |
289 | class PatchInflated(nn.Module):
290 | r""" Tensor to Patch Inflating
291 |
292 | Args:
293 | in_chans (int): Number of input image channels.
294 | embed_dim (int): Number of linear projection output channels.
295 | input_resolution (tuple[int]): Input resolution.
296 | """
297 |
298 | def __init__(self, in_chans, embed_dim, input_resolution, stride=2, padding=1, output_padding=1):
299 | super(PatchInflated, self).__init__()
300 |
301 | stride = to_2tuple(stride)
302 | padding = to_2tuple(padding)
303 | output_padding = to_2tuple(output_padding)
304 | self.input_resolution = input_resolution
305 |
306 | self.ConvT = nn.ConvTranspose2d(in_channels=embed_dim, out_channels=in_chans, kernel_size=(3, 3),
307 | stride=stride, padding=padding, output_padding=output_padding)
308 |
309 | def forward(self, x):
310 | H, W = self.input_resolution
311 | B, L, C = x.shape
312 | assert L == H * W, "input feature has wrong size"
313 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
314 |
315 | x = x.view(B, H, W, C)
316 | x = x.permute(0, 3, 1, 2)
317 | x = self.ConvT(x)
318 |
319 | return x
320 |
321 |
322 | class SwinTransformerBlocks(nn.Module):
323 |
324 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None,
325 | drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm):
326 | super(SwinTransformerBlocks, self).__init__()
327 | self.layers = nn.ModuleList([
328 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
329 | num_heads=num_heads, window_size=window_size,
330 | shift_size=0 if (i % 2 == 0) else window_size // 2,
331 | mlp_ratio=mlp_ratio,
332 | qkv_bias=qkv_bias, qk_scale=qk_scale,
333 | drop=drop, attn_drop=attn_drop,
334 | drop_path=drop_path,
335 | norm_layer=norm_layer)
336 | for i in range(depth)])
337 |
338 | def forward(self, xt, hx):
339 |
340 | outputs = []
341 |
342 | for index, layer in enumerate(self.layers):
343 | if index == 0:
344 | x = layer(xt, hx)
345 | outputs.append(x)
346 |
347 | else:
348 | if index % 2 == 0:
349 | x = layer(outputs[-1], xt)
350 | outputs.append(x)
351 |
352 | if index % 2 == 1:
353 | x = layer(outputs[-1], None)
354 | outputs.append(x)
355 |
356 | return outputs[-1]
357 |
358 |
359 | class SwinLSTMCell(nn.Module):
360 |
361 | def __init__(self, dim, input_resolution, num_heads, window_size, depth,
362 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
363 | drop_path=0., norm_layer=nn.LayerNorm):
364 | super(SwinLSTMCell, self).__init__()
365 |
366 | self.Swin = SwinTransformerBlocks(dim=dim, input_resolution=input_resolution, depth=depth,
367 | num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
368 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
369 | drop_path=drop_path, norm_layer=norm_layer)
370 |
371 | def forward(self, xt, hidden_states):
372 | if hidden_states is None:
373 | B, L, C = xt.shape
374 | hx = torch.zeros(B, L, C).to(xt.device)
375 | cx = torch.zeros(B, L, C).to(xt.device)
376 |
377 | else:
378 | hx, cx = hidden_states
379 |
380 | Ft = self.Swin(xt, hx)
381 |
382 | gate = torch.sigmoid(Ft)
383 | cell = torch.tanh(Ft)
384 |
385 | cy = gate * (cx + cell)
386 | hy = gate * torch.tanh(cy)
387 | hx = hy
388 | cx = cy
389 |
390 | return hx, (hx, cx)
391 |
392 |
393 | class STconvert(nn.Module):
394 | r""" STconvert
395 |
396 | Args:
397 | img_size (int | tuple(int)): Input image size.
398 | patch_size (int | tuple(int)): Patch size.
399 | in_chans (int): Number of input image channels.
400 | embed_dim (int): Patch embedding dimension.
401 | depths (tuple(int)): Depth of Swin Transformer layer.
402 | num_heads (tuple(int)): Number of attention heads in different layers.
403 | window_size (int): Window size.
404 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
405 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
406 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
407 | drop_rate (float): Dropout rate. Default: 0
408 | attn_drop_rate (float): Attention dropout rate. Default: 0
409 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
410 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
411 | """
412 |
413 | def __init__(self, img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size,
414 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
415 | norm_layer=nn.LayerNorm):
416 |
417 | super(STconvert, self).__init__()
418 |
419 | self.num_layers = len(depths)
420 | self.embed_dim = embed_dim
421 | self.mlp_ratio = mlp_ratio
422 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
423 | in_chans=in_chans, embed_dim=embed_dim)
424 | patches_resolution = self.patch_embed.patches_resolution
425 |
426 | self.PatchInflated = PatchInflated(in_chans=in_chans, embed_dim=embed_dim, input_resolution=patches_resolution)
427 | self.layers = nn.ModuleList()
428 |
429 | for i_layer in range(self.num_layers):
430 | layer = SwinLSTMCell(dim=embed_dim,
431 | input_resolution=(patches_resolution[0], patches_resolution[1]),
432 | depth=depths[i_layer],
433 | num_heads=num_heads[i_layer],
434 | window_size=window_size,
435 | mlp_ratio=self.mlp_ratio,
436 | qkv_bias=qkv_bias, qk_scale=qk_scale,
437 | drop=drop_rate, attn_drop=attn_drop_rate,
438 | drop_path=drop_path_rate,
439 | norm_layer=norm_layer)
440 |
441 | self.layers.append(layer)
442 |
443 | def forward(self, x, h):
444 |
445 | x = self.patch_embed(x)
446 |
447 | hidden_states = []
448 |
449 | for index, layer in enumerate(self.layers):
450 | x, hidden_state = layer(x, h[index])
451 | hidden_states.append(hidden_state)
452 |
453 | x = torch.sigmoid(self.PatchInflated(x))
454 |
455 | return hidden_states, x
456 |
457 |
458 | class SwinLSTM(nn.Module):
459 | r""" SwinLSTM
460 |
461 | Args:
462 | img_size (int | tuple(int)): Input image size.
463 | patch_size (int | tuple(int)): Patch size.
464 | in_chans (int): Number of input image channels.
465 | embed_dim (int): Patch embedding dimension.
466 | depths (tuple(int)): Depth of Swin Transformer layer.
467 | num_heads (tuple(int)): Number of attention heads in different layers.
468 | window_size (int): Window size.
469 | drop_rate (float): Dropout rate.
470 | attn_drop_rate (float): Attention dropout rate.
471 | drop_path_rate (float): Stochastic depth rate.
472 | """
473 |
474 | def __init__(self, img_size, patch_size, in_chans, embed_dim, depths,
475 | num_heads, window_size, drop_rate, attn_drop_rate, drop_path_rate):
476 | super(SwinLSTM, self).__init__()
477 |
478 | self.ST = STconvert(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
479 | embed_dim=embed_dim, depths=depths,
480 | num_heads=num_heads, window_size=window_size, drop_rate=drop_rate,
481 | attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate)
482 |
483 | def forward(self, input, states):
484 | states_next, output = self.ST(input, states)
485 |
486 | return output, states_next
487 |
--------------------------------------------------------------------------------
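A smoke-test sketch for `SwinLSTM_B.py` (editorial addition, not part of the file). `SwinLSTMCell` implements the simplified recurrence F_t = Swin(x_t, h_{t-1}); a single gate sigmoid(F_t) both updates the cell state and gates the hidden state, exactly as in the `forward` above. The hidden state passed to `STconvert.forward` must be a list with one entry per `SwinLSTMCell`; `None` entries are expanded to zero states inside the cell. Values below mirror the `configs.py` defaults.

```
import torch
from SwinLSTM_B import SwinLSTM

model = SwinLSTM(img_size=64, patch_size=2, in_chans=1, embed_dim=128,
                 depths=[12], num_heads=[4, 8], window_size=4,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1)

x = torch.randn(2, 1, 64, 64)      # (batch, channels, height, width)
states = [None]                    # one entry per SwinLSTMCell; only num_heads[0] is used since len(depths) == 1

output, states = model(x, states)  # feed states back in at the next time step
print(output.shape)                # torch.Size([2, 1, 64, 64])
```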
/SwinLSTM_D.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5 |
6 |
7 | class Mlp(nn.Module):
8 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
9 | super(Mlp, self).__init__()
10 | out_features = out_features or in_features
11 | hidden_features = hidden_features or in_features
12 | self.fc1 = nn.Linear(in_features, hidden_features)
13 | self.act = act_layer()
14 | self.fc2 = nn.Linear(hidden_features, out_features)
15 | self.drop = nn.Dropout(drop)
16 |
17 | def forward(self, x):
18 | x = self.fc1(x)
19 | x = self.act(x)
20 | x = self.drop(x)
21 | x = self.fc2(x)
22 | x = self.drop(x)
23 | return x
24 |
25 |
26 | def window_partition(x, window_size):
27 | """
28 | Args:
29 | x: (B, H, W, C)
30 | window_size (int): window size
31 |
32 | Returns:
33 | windows: (num_windows*B, window_size, window_size, C)
34 | """
35 | B, H, W, C = x.shape
36 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
37 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
38 | return windows
39 |
40 |
41 | def window_reverse(windows, window_size, H, W):
42 | """
43 | Args:
44 | windows: (num_windows*B, window_size, window_size, C)
45 | window_size (int): Window size
46 | H (int): Height of image
47 | W (int): Width of image
48 |
49 | Returns:
50 | x: (B, H, W, C)
51 | """
52 | B = int(windows.shape[0] / (H * W / window_size / window_size))
53 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
54 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
55 | return x
56 |
57 |
58 | class WindowAttention(nn.Module):
59 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
60 | It supports both shifted and non-shifted windows.
61 |
62 | Args:
63 | dim (int): Number of input channels.
64 | window_size (tuple[int]): The height and width of the window.
65 | num_heads (int): Number of attention heads.
66 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
67 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
68 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
69 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
70 | """
71 |
72 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
73 |
74 | super(WindowAttention, self).__init__()
75 | self.dim = dim
76 | self.window_size = window_size # Wh, Ww
77 | self.num_heads = num_heads
78 | head_dim = dim // num_heads
79 | self.scale = qk_scale or head_dim ** -0.5
80 |
81 | # define a parameter table of relative position bias
82 | self.relative_position_bias_table = nn.Parameter(
83 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
84 |
85 | # get pair-wise relative position index for each token inside the window
86 | coords_h = torch.arange(self.window_size[0])
87 | coords_w = torch.arange(self.window_size[1])
88 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
89 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
90 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
91 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
92 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
93 | relative_coords[:, :, 1] += self.window_size[1] - 1
94 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
95 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
96 | self.register_buffer("relative_position_index", relative_position_index)
97 |
98 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
99 | self.attn_drop = nn.Dropout(attn_drop)
100 | self.proj = nn.Linear(dim, dim)
101 | self.proj_drop = nn.Dropout(proj_drop)
102 |
103 | trunc_normal_(self.relative_position_bias_table, std=.02)
104 | self.softmax = nn.Softmax(dim=-1)
105 |
106 | def forward(self, x, mask=None):
107 | """
108 | Args:
109 | x: input features with shape of (num_windows*B, N, C)
110 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
111 | """
112 |
113 | B_, N, C = x.shape
114 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
115 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
116 |
117 | q = q * self.scale
118 | attn = (q @ k.transpose(-2, -1))
119 |
120 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
121 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
122 |
123 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
124 | attn = attn + relative_position_bias.unsqueeze(0)
125 |
126 | if mask is not None:
127 | nW = mask.shape[0]
128 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
129 | attn = attn.view(-1, self.num_heads, N, N)
130 | attn = self.softmax(attn)
131 | else:
132 | attn = self.softmax(attn)
133 |
134 | attn = self.attn_drop(attn)
135 |
136 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
137 | x = self.proj(x)
138 | x = self.proj_drop(x)
139 | return x
140 |
141 |
142 | class SwinTransformerBlock(nn.Module):
143 | r""" Swin Transformer Block.
144 |
145 | Args:
146 | dim (int): Number of input channels.
147 | input_resolution (tuple[int]): Input resolution.
148 | num_heads (int): Number of attention heads.
149 | window_size (int): Window size.
150 | shift_size (int): Shift size for SW-MSA.
151 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
152 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
153 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
154 | drop (float, optional): Dropout rate. Default: 0.0
155 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
156 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
157 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
158 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
159 | """
160 |
161 | def __init__(self, dim, input_resolution, num_heads, window_size=2, shift_size=0,
162 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
163 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
164 | super().__init__()
165 | self.dim = dim
166 | self.input_resolution = input_resolution
167 | self.num_heads = num_heads
168 | self.window_size = window_size
169 | self.shift_size = shift_size
170 | self.mlp_ratio = mlp_ratio
171 | if min(self.input_resolution) <= self.window_size:
172 | # if window size is larger than input resolution, we don't partition windows
173 | self.shift_size = 0
174 | self.window_size = min(self.input_resolution)
175 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
176 |
177 | self.norm1 = norm_layer(dim)
178 | self.attn = WindowAttention(
179 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
180 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
181 |
182 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
183 | self.norm2 = norm_layer(dim)
184 | mlp_hidden_dim = int(dim * mlp_ratio)
185 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
186 | self.red = nn.Linear(2 * dim, dim)
187 | if self.shift_size > 0:
188 | # calculate attention mask for SW-MSA
189 | H, W = self.input_resolution
190 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
191 | h_slices = (slice(0, -self.window_size),
192 | slice(-self.window_size, -self.shift_size),
193 | slice(-self.shift_size, None))
194 | w_slices = (slice(0, -self.window_size),
195 | slice(-self.window_size, -self.shift_size),
196 | slice(-self.shift_size, None))
197 | cnt = 0
198 | for h in h_slices:
199 | for w in w_slices:
200 | img_mask[:, h, w, :] = cnt
201 | cnt += 1
202 |
203 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
204 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
205 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
206 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
207 | else:
208 | attn_mask = None
209 |
210 | self.register_buffer("attn_mask", attn_mask)
211 |
212 | def forward(self, x, hx=None):
213 | H, W = self.input_resolution
214 | B, L, C = x.shape
215 | assert L == H * W, "input feature has wrong size"
216 |
217 | shortcut = x
218 | x = self.norm1(x)
219 | if hx is not None:
220 | hx = self.norm1(hx)
221 | x = torch.cat((x, hx), -1)
222 | x = self.red(x)
223 | x = x.view(B, H, W, C)
224 |
225 | # cyclic shift
226 | if self.shift_size > 0:
227 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
228 | else:
229 | shifted_x = x
230 |
231 | # partition windows
232 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
233 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
234 |
235 | # W-MSA/SW-MSA
236 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
237 |
238 | # merge windows
239 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
240 |
241 | # reverse cyclic shift
242 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
243 |
244 | if self.shift_size > 0:
245 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
246 | else:
247 | x = shifted_x
248 |
249 | # FFN
250 | x = x.view(B, H * W, C)
251 | x = shortcut + self.drop_path(x)
252 | x = x + self.drop_path(self.mlp(self.norm2(x)))
253 |
254 | return x
255 |
256 |
257 | class PatchMerging(nn.Module):
258 | r""" Patch Merging Layer.
259 |
260 | Args:
261 | input_resolution (tuple[int]): Resolution of input feature.
262 | dim (int): Number of input channels.
263 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
264 | """
265 |
266 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
267 | super().__init__()
268 | self.input_resolution = input_resolution
269 | self.dim = dim
270 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
271 | self.norm = norm_layer(4 * dim)
272 |
273 | def forward(self, x):
274 | H, W = self.input_resolution
275 | B, L, C = x.shape
276 | assert L == H * W, "input feature has wrong size"
277 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
278 |
279 | x = x.view(B, H, W, C)
280 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
281 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
282 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
283 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
284 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
285 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
286 | x = self.norm(x)
287 | x = self.reduction(x)
288 | return x
289 |
290 |
291 | # The 'PatchExpanding' class below is adapted from the 'PatchExpand' code in Swin-Unet. Thanks!
292 | # https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/
293 | # swin_transformer_unet_skip_expand_decoder_sys.py, line 333-355
294 | class PatchExpanding(nn.Module):
295 | r""" Patch Expanding Layer.
296 |
297 | Args:
298 | input_resolution (tuple[int]): Resolution of input feature.
299 | dim (int): Number of input channels.
300 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
301 | """
302 |
303 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
304 | super(PatchExpanding, self).__init__()
305 | self.input_resolution = input_resolution
306 | self.dim = dim
307 | self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
308 | self.norm = norm_layer(dim // dim_scale)
309 |
310 | def forward(self, x):
311 | H, W = self.input_resolution
312 | x = self.expand(x)
313 | B, L, C = x.shape
314 | assert L == H * W, "input feature has wrong size"
315 |
316 | x = x.view(B, H, W, C)
317 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
318 | x = x.view(B, -1, C // 4)
319 | x = self.norm(x)
320 |
321 | return x
322 |
323 |
324 | class PatchEmbed(nn.Module):
325 | r""" Image to Patch Embedding
326 |
327 | Args:
328 | img_size (int): Image size.
329 | patch_size (int): Patch token size.
330 | in_chans (int): Number of input image channels.
331 | embed_dim (int): Number of linear projection output channels.
332 | """
333 |
334 | def __init__(self, img_size, patch_size, in_chans, embed_dim):
335 | super(PatchEmbed, self).__init__()
336 | img_size = to_2tuple(img_size)
337 | patch_size = to_2tuple(patch_size)
338 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
339 | self.img_size = img_size
340 | self.patch_size = patch_size
341 | self.patches_resolution = patches_resolution
342 | self.num_patches = patches_resolution[0] * patches_resolution[1]
343 | self.in_chans = in_chans
344 | self.embed_dim = embed_dim
345 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
346 | self.norm = nn.LayerNorm(embed_dim)
347 |
348 | def forward(self, x):
349 | B, C, H, W = x.shape
350 | assert H == self.img_size[0] and W == self.img_size[1], \
351 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
352 | x = self.proj(x).flatten(2).transpose(1, 2)
353 | x = self.norm(x)
354 | return x
355 |
356 |
357 | class PatchInflated(nn.Module):
358 | r""" Tensor to Patch Inflating
359 |
360 | Args:
361 | in_chans (int): Number of input image channels.
362 | embed_dim (int): Number of linear projection output channels.
363 | input_resolution (tuple[int]): Input resolution.
364 | """
365 |
366 | def __init__(self, in_chans, embed_dim, input_resolution, stride=2, padding=1, output_padding=1):
367 | super(PatchInflated, self).__init__()
368 |
369 | stride = to_2tuple(stride)
370 | padding = to_2tuple(padding)
371 | output_padding = to_2tuple(output_padding)
372 | self.input_resolution = input_resolution
373 |
374 | self.Conv = nn.ConvTranspose2d(in_channels=embed_dim, out_channels=in_chans, kernel_size=(3, 3),
375 | stride=stride, padding=padding, output_padding=output_padding)
376 |
377 | def forward(self, x):
378 | H, W = self.input_resolution
379 | B, L, C = x.shape
380 | assert L == H * W, "input feature has wrong size"
381 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
382 |
383 | x = x.view(B, H, W, C)
384 | x = x.permute(0, 3, 1, 2)
385 | x = self.Conv(x)
386 |
387 | return x
388 |
389 |
390 | class SwinLSTMCell(nn.Module):
391 | def __init__(self, dim, input_resolution, num_heads, window_size, depth,
392 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
393 | drop_path=0., norm_layer=nn.LayerNorm, flag=None):
394 | super(SwinLSTMCell, self).__init__()
395 |
396 | self.Swin = SwinTransformer(dim=dim, input_resolution=input_resolution, depth=depth,
397 | num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
398 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
399 | drop_path=drop_path, norm_layer=norm_layer, flag=flag)
400 |
401 | def forward(self, xt, hidden_states):
402 | if hidden_states is None:
403 | B, L, C = xt.shape
404 | hx = torch.zeros(B, L, C).to(xt.device)
405 | cx = torch.zeros(B, L, C).to(xt.device)
406 |
407 | else:
408 | hx, cx = hidden_states
409 |
410 | Ft = self.Swin(xt, hx)
411 |
412 | gate = torch.sigmoid(Ft)
413 | cell = torch.tanh(Ft)
414 |
415 | cy = gate * (cx + cell)
416 | hy = gate * torch.tanh(cy)
417 | hx = hy
418 | cx = cy
419 |
420 | return hx, (hx, cx)
421 |
422 |
423 | class SwinTransformer(nn.Module):
424 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None,
425 | drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, flag=None):
426 | super(SwinTransformer, self).__init__()
427 |
428 | self.layers = nn.ModuleList([
429 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
430 | num_heads=num_heads, window_size=window_size,
431 | shift_size=0 if (i % 2 == 0) else window_size // 2,
432 | mlp_ratio=mlp_ratio,
433 | qkv_bias=qkv_bias, qk_scale=qk_scale,
434 | drop=drop, attn_drop=attn_drop,
435 | drop_path=drop_path[depth - i - 1] if (flag == 0) else drop_path[i],
436 | norm_layer=norm_layer)
437 | for i in range(depth)])
438 |
439 | def forward(self, xt, hx):
440 |
441 | outputs = []
442 |
443 | for index, layer in enumerate(self.layers):
444 | if index == 0:
445 | x = layer(xt, hx)
446 | outputs.append(x)
447 |
448 | else:
449 | if index % 2 == 0:
450 | x = layer(outputs[-1], xt)
451 | outputs.append(x)
452 |
453 | if index % 2 == 1:
454 | x = layer(outputs[-1], None)
455 | outputs.append(x)
456 |
457 | return outputs[-1]
458 |
459 |
460 | class DownSample(nn.Module):
461 | def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_downsample, num_heads, window_size,
462 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
463 | norm_layer=nn.LayerNorm):
464 | super(DownSample, self).__init__()
465 |
466 | self.num_layers = len(depths_downsample)
467 | self.embed_dim = embed_dim
468 | self.mlp_ratio = mlp_ratio
469 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
470 | patches_resolution = self.patch_embed.patches_resolution
471 |
472 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_downsample))]
473 |
474 | self.layers = nn.ModuleList()
475 | self.downsample = nn.ModuleList()
476 |
477 | for i_layer in range(self.num_layers):
478 | downsample = PatchMerging(input_resolution=(patches_resolution[0] // (2 ** i_layer),
479 | patches_resolution[1] // (2 ** i_layer)),
480 | dim=int(embed_dim * 2 ** i_layer))
481 |
482 | layer = SwinLSTMCell(dim=int(embed_dim * 2 ** i_layer),
483 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
484 | patches_resolution[1] // (2 ** i_layer)),
485 | depth=depths_downsample[i_layer],
486 | num_heads=num_heads[i_layer],
487 | window_size=window_size,
488 | mlp_ratio=self.mlp_ratio,
489 | qkv_bias=qkv_bias, qk_scale=qk_scale,
490 | drop=drop_rate, attn_drop=attn_drop_rate,
491 | drop_path=dpr[sum(depths_downsample[:i_layer]):sum(depths_downsample[:i_layer + 1])],
492 | norm_layer=norm_layer)
493 |
494 | self.layers.append(layer)
495 | self.downsample.append(downsample)
496 |
497 | def forward(self, x, y):
498 |
499 | x = self.patch_embed(x)
500 |
501 | hidden_states_down = []
502 |
503 | for index, layer in enumerate(self.layers):
504 | x, hidden_state = layer(x, y[index])
505 | x = self.downsample[index](x)
506 | hidden_states_down.append(hidden_state)
507 |
508 | return hidden_states_down, x
509 |
510 |
511 | class UpSample(nn.Module):
512 | def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_upsample, num_heads, window_size, mlp_ratio=4.,
513 | qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
514 | norm_layer=nn.LayerNorm, flag=0):
515 | super(UpSample, self).__init__()
516 |
517 | self.img_size = img_size
518 | self.num_layers = len(depths_upsample)
519 | self.embed_dim = embed_dim
520 | self.mlp_ratio = mlp_ratio
521 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
522 | patches_resolution = self.patch_embed.patches_resolution
523 | self.Unembed = PatchInflated(in_chans=in_chans, embed_dim=embed_dim, input_resolution=patches_resolution)
524 |
525 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_upsample))]
526 |
527 | self.layers = nn.ModuleList()
528 | self.upsample = nn.ModuleList()
529 |
530 | for i_layer in range(self.num_layers):
531 | resolution1 = (patches_resolution[0] // (2 ** (self.num_layers - i_layer)))
532 | resolution2 = (patches_resolution[1] // (2 ** (self.num_layers - i_layer)))
533 |
534 | dimension = int(embed_dim * 2 ** (self.num_layers - i_layer))
535 | upsample = PatchExpanding(input_resolution=(resolution1, resolution2), dim=dimension)
536 |
537 | layer = SwinLSTMCell(dim=dimension, input_resolution=(resolution1, resolution2),
538 | depth=depths_upsample[(self.num_layers - 1 - i_layer)],
539 | num_heads=num_heads[(self.num_layers - 1 - i_layer)],
540 | window_size=window_size,
541 | mlp_ratio=self.mlp_ratio,
542 | qkv_bias=qkv_bias, qk_scale=qk_scale,
543 | drop=drop_rate, attn_drop=attn_drop_rate,
544 | drop_path=dpr[sum(depths_upsample[:(self.num_layers - 1 - i_layer)]):
545 | sum(depths_upsample[:(self.num_layers - 1 - i_layer) + 1])],
546 | norm_layer=norm_layer, flag=flag)
547 |
548 | self.layers.append(layer)
549 | self.upsample.append(upsample)
550 |
551 | def forward(self, x, y):
552 | hidden_states_up = []
553 |
554 | for index, layer in enumerate(self.layers):
555 | x, hidden_state = layer(x, y[index])
556 | x = self.upsample[index](x)
557 | hidden_states_up.append(hidden_state)
558 |
559 | x = torch.sigmoid(self.Unembed(x))
560 |
561 | return hidden_states_up, x
562 |
563 |
564 | class SwinLSTM(nn.Module):
565 | def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_downsample, depths_upsample, num_heads,
566 | window_size):
567 | super(SwinLSTM, self).__init__()
568 |
569 | self.Downsample = DownSample(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
570 | embed_dim=embed_dim, depths_downsample=depths_downsample,
571 | num_heads=num_heads, window_size=window_size)
572 |
573 | self.Upsample = UpSample(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
574 | embed_dim=embed_dim, depths_upsample=depths_upsample,
575 | num_heads=num_heads, window_size=window_size)
576 |
577 | def forward(self, input, states_down, states_up):
578 | states_down, x = self.Downsample(input, states_down)
579 | states_up, output = self.Upsample(x, states_up)
580 |
581 | return output, states_down, states_up
582 |
--------------------------------------------------------------------------------
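The analogous smoke test for the encoder-decoder predictor in `SwinLSTM_D.py` (editorial addition). Here `DownSample` and `UpSample` each keep their own list of per-stage hidden states, which the caller carries across time steps; `None` entries again initialize to zero states. Values mirror the `configs.py` defaults.

```
import torch
from SwinLSTM_D import SwinLSTM

model = SwinLSTM(img_size=64, patch_size=2, in_chans=1, embed_dim=128,
                 depths_downsample=[2, 6], depths_upsample=[6, 2],
                 num_heads=[4, 8], window_size=4)

x = torch.randn(2, 1, 64, 64)
states_down = [None, None]   # one entry per downsampling stage
states_up = [None, None]     # one entry per upsampling stage

output, states_down, states_up = model(x, states_down, states_up)
print(output.shape)          # torch.Size([2, 1, 64, 64])
```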
/__pycache__/SwinLSTM_D.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/__pycache__/SwinLSTM_D.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/configs.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/__pycache__/configs.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/__pycache__/dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/functions.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/__pycache__/functions.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/architecture.png
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args():
5 | parser = argparse.ArgumentParser('SwinLSTM training and evaluation script', add_help=False)
6 |
7 | # Setup parameters
8 | parser.add_argument('--device', default='cuda:0', type=str)
9 | parser.add_argument('--res_dir', default='./results', type=str)
10 | parser.add_argument('--seed', default=1234, type=int)
11 |
12 | # Moving_MNIST dataset parameters
13 | parser.add_argument('--num_frames_input', default=10, type=int, help='Input sequence length')
14 | parser.add_argument('--num_frames_output', default=10, type=int, help='Output sequence length')
15 | parser.add_argument('--image_size', default=(28, 28), type=int, help='Original resolution')
16 | parser.add_argument('--input_size', default=(64, 64), help='Input resolution')
17 | parser.add_argument('--step_length', default=0.1, type=float)
18 | parser.add_argument('--num_objects', default=[2], type=int)
19 | parser.add_argument('--train_samples', default=[0, 10000], type=int, help='Index range of the training samples')
20 | parser.add_argument('--valid_samples', default=[10000, 13000], type=int, help='Index range of the validation samples')
21 | parser.add_argument('--train_data_dir', default='./data/train-images-idx3-ubyte.gz')
22 | parser.add_argument('--test_data_dir', default='./data/mnist_test_seq.npy')
23 |
24 | # model parameters
25 | parser.add_argument('--model', default='SwinLSTM-D', type=str, choices=['SwinLSTM-B', 'SwinLSTM-D'],
26 | help='Model type')
27 | parser.add_argument('--input_channels', default=1, type=int, help='Number of input image channels')
28 | parser.add_argument('--input_img_size', default=64, type=int, help='Input image size')
29 | parser.add_argument('--patch_size', default=2, type=int, help='Patch size of input images')
30 | parser.add_argument('--embed_dim', default=128, type=int, help='Patch embedding dimension')
31 | parser.add_argument('--depths', default=[12], type=int, help='Depth of Swin Transformer layer for SwinLSTM-B')
32 | parser.add_argument('--depths_down', default=[2, 6], type=int, help='Depths of the downsampling layers of SwinLSTM-D')
33 | parser.add_argument('--depths_up', default=[6, 2], type=int, help='Depths of the upsampling layers of SwinLSTM-D')
34 | parser.add_argument('--heads_number', default=[4, 8], type=int,
35 | help='Number of attention heads in different layers')
36 | parser.add_argument('--window_size', default=4, type=int, help='Window size of Swin Transformer layer')
37 | parser.add_argument('--drop_rate', default=0., type=float, help='Dropout rate')
38 | parser.add_argument('--attn_drop_rate', default=0., type=float, help='Attention dropout rate')
39 | parser.add_argument('--drop_path_rate', default=0.1, type=float, help='Stochastic depth rate')
40 |
41 | # Training parameters
42 | parser.add_argument('--train_batch_size', default=16, type=int, help='Batch size for training')
43 | parser.add_argument('--valid_batch_size', default=16, type=int, help='Batch size for validation')
44 | parser.add_argument('--test_batch_size', default=16, type=int, help='Batch size for testing')
45 | parser.add_argument('--num_workers', default=8, type=int)
46 | parser.add_argument('--epochs', default=2000, type=int)
47 | parser.add_argument('--epoch_valid', default=10, type=int)
48 | parser.add_argument('--log_train', default=100, type=int)
49 | parser.add_argument('--log_valid', default=60, type=int)
50 | parser.add_argument('--lr', default=0.0001, type=float, help='Learning rate')
51 |
52 | args = parser.parse_args()
53 |
54 | return args
55 |
--------------------------------------------------------------------------------
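A note on the config above: the tuple/list defaults (--image_size, --input_size, --num_objects, --train_samples, --valid_samples) are declared without nargs, so they are effectively fixed defaults rather than values that can be overridden from the command line. A minimal usage sketch (hypothetical, not a file in the repo; assumes configs.py is importable and no conflicting CLI flags are passed):

# Hypothetical usage sketch: parse the defaults and inspect the fields that
# train.py / dataset.py rely on.
from configs import get_args

args = get_args()                         # parses sys.argv; with no flags, the defaults above apply
print(args.model)                         # 'SwinLSTM-D'
print(args.image_size, args.input_size)   # (28, 28) digit size, (64, 64) canvas size
print(args.depths_down, args.depths_up)   # [2, 6] and [6, 2] Swin blocks per stage
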
/data/test_data_download_link.md:
--------------------------------------------------------------------------------
1 | http://www.cs.toronto.edu/~nitish/unsupervised_video/
2 |
--------------------------------------------------------------------------------
/data/train-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/data/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class Moving_MNIST(Dataset):
10 |
11 | def __init__(self, args, split):
12 |
13 | super(Moving_MNIST, self).__init__()
14 |
15 | with gzip.open(args.train_data_dir, 'rb') as f:
16 | self.datas = np.frombuffer(f.read(), np.uint8, offset=16)
17 | self.datas = self.datas.reshape(-1, *args.image_size)
18 |
19 | if split == 'train':
20 | self.datas = self.datas[args.train_samples[0]: args.train_samples[1]]
21 | else:
22 | self.datas = self.datas[args.valid_samples[0]: args.valid_samples[1]]
23 |
24 | self.image_size = args.image_size
25 | self.input_size = args.input_size
26 | self.step_length = args.step_length
27 | self.num_objects = args.num_objects
28 |
29 | self.num_frames_input = args.num_frames_input
30 | self.num_frames_output = args.num_frames_output
31 | self.num_frames_total = args.num_frames_input + args.num_frames_output
32 |
33 | print('Loaded {} {} samples'.format(self.__len__(), split))
34 |
35 | def _get_random_trajectory(self, seq_length):
36 |
37 | assert self.input_size[0] == self.input_size[1]
38 | assert self.image_size[0] == self.image_size[1]
39 |
40 | canvas_size = self.input_size[0] - self.image_size[0]
41 |
42 | x = random.random()
43 | y = random.random()
44 |
45 | theta = random.random() * 2 * np.pi
46 |
47 | v_y = np.sin(theta)
48 | v_x = np.cos(theta)
49 |
50 | start_y = np.zeros(seq_length)
51 | start_x = np.zeros(seq_length)
52 |
53 | for i in range(seq_length):
54 |
55 | y += v_y * self.step_length
56 | x += v_x * self.step_length
57 |
58 | if x <= 0.: x = 0.; v_x = -v_x;
59 | if x >= 1.: x = 1.; v_x = -v_x
60 | if y <= 0.: y = 0.; v_y = -v_y;
61 | if y >= 1.: y = 1.; v_y = -v_y
62 |
63 | start_y[i] = y
64 | start_x[i] = x
65 |
66 | start_y = (canvas_size * start_y).astype(np.int32)
67 | start_x = (canvas_size * start_x).astype(np.int32)
68 |
69 | return start_y, start_x
70 |
71 | def _generate_moving_mnist(self, num_digits=2):
72 |
73 | data = np.zeros((self.num_frames_total, *self.input_size), dtype=np.float32)
74 |
75 | for n in range(num_digits):
76 |
77 | start_y, start_x = self._get_random_trajectory(self.num_frames_total)
78 | ind = np.random.randint(0, self.__len__())
79 | digit_image = self.datas[ind]
80 |
81 | for i in range(self.num_frames_total):
82 | top = start_y[i]
83 | left = start_x[i]
84 | bottom = top + self.image_size[0]
85 | right = left + self.image_size[1]
86 | data[i, top:bottom, left:right] = np.maximum(data[i, top:bottom, left:right], digit_image)
87 |
88 | data = data[..., np.newaxis]
89 |
90 | return data
91 |
92 | def __getitem__(self, item):
93 |
94 | num_digits = random.choice(self.num_objects)
95 | images = self._generate_moving_mnist(num_digits)
96 |
97 | inputs = torch.from_numpy(images[:self.num_frames_input]).permute(0, 3, 1, 2).contiguous()
98 |         targets = torch.from_numpy(images[self.num_frames_input:]).permute(0, 3, 1, 2).contiguous()  # frames after the input window
99 |
100 | return inputs / 255., targets / 255.
101 |
102 | def __len__(self):
103 | return self.datas.shape[0]
104 |
105 |
106 | class Moving_MNIST_Test(Dataset):
107 | def __init__(self, args):
108 | super(Moving_MNIST_Test, self).__init__()
109 |
110 | self.num_frames_input = args.num_frames_input
111 | self.num_frames_output = args.num_frames_output
112 | self.num_frames_total = args.num_frames_input + args.num_frames_output
113 |
114 | self.dataset = np.load(args.test_data_dir)
115 | self.dataset = self.dataset[..., np.newaxis]
116 |
117 | print('Loaded {} {} samples'.format(self.__len__(), 'test'))
118 |
119 | def __getitem__(self, index):
120 | images = self.dataset[:, index, ...]
121 |
122 | inputs = torch.from_numpy(images[:self.num_frames_input]).permute(0, 3, 1, 2).contiguous()
123 |         targets = torch.from_numpy(images[self.num_frames_input:]).permute(0, 3, 1, 2).contiguous()  # frames after the input window
124 |
125 | return inputs / 255., targets / 255.
126 |
127 | def __len__(self):
128 |         return self.dataset.shape[1]  # number of test sequences; dataset layout is (frames, sequences, H, W, 1)
129 |
130 |
--------------------------------------------------------------------------------
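For orientation, a hypothetical shape check for the dataset above (not part of the repo; assumes ./data/train-images-idx3-ubyte.gz is present and the defaults from configs.get_args()): each sample is a pair of float tensors in [0, 1] shaped (T, C, H, W).

# Hypothetical sanity check of Moving_MNIST sample shapes.
from configs import get_args
from dataset import Moving_MNIST

args = get_args()
train_set = Moving_MNIST(args, split='train')       # prints 'Loaded 10000 train samples'
inputs, targets = train_set[0]
print(inputs.shape, targets.shape)                  # torch.Size([10, 1, 64, 64]) for both
print(float(inputs.min()), float(inputs.max()))     # pixel values normalized to [0, 1]
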
/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.cuda import amp
4 | from torch.cuda.amp import autocast
5 | from utils import compute_metrics, visualize
6 |
7 | scaler = amp.GradScaler()
8 |
9 |
10 | def model_forward_single_layer(model, inputs, targets_len, num_layers):
11 | outputs = []
12 | states = [None] * len(num_layers)
13 |
14 | inputs_len = inputs.shape[1]
15 |
16 | last_input = inputs[:, -1]
17 |
18 | for i in range(inputs_len - 1):
19 | output, states = model(inputs[:, i], states)
20 | outputs.append(output)
21 |
22 | for i in range(targets_len):
23 | output, states = model(last_input, states)
24 | outputs.append(output)
25 | last_input = output
26 |
27 | return outputs
28 |
29 |
30 | def model_forward_multi_layer(model, inputs, targets_len, num_layers):
31 | states_down = [None] * len(num_layers)
32 | states_up = [None] * len(num_layers)
33 |
34 | outputs = []
35 |
36 | inputs_len = inputs.shape[1]
37 |
38 | last_input = inputs[:, -1]
39 |
40 | for i in range(inputs_len - 1):
41 | output, states_down, states_up = model(inputs[:, i], states_down, states_up)
42 | outputs.append(output)
43 |
44 | for i in range(targets_len):
45 | output, states_down, states_up = model(last_input, states_down, states_up)
46 | outputs.append(output)
47 | last_input = output
48 |
49 | return outputs
50 |
51 |
52 | def train(args, logger, epoch, model, train_loader, criterion, optimizer):
53 | model.train()
54 | num_batches = len(train_loader)
55 | losses = []
56 |
57 | for batch_idx, (inputs, targets) in enumerate(train_loader):
58 |
59 | optimizer.zero_grad()
60 |
61 | inputs, targets = map(lambda x: x.float().to(args.device), [inputs, targets])
62 | targets_len = targets.shape[1]
63 | with autocast():
64 | if args.model == 'SwinLSTM-B':
65 | outputs = model_forward_single_layer(model, inputs, targets_len, args.depths)
66 |
67 | if args.model == 'SwinLSTM-D':
68 | outputs = model_forward_multi_layer(model, inputs, targets_len, args.depths_down)
69 |
70 | outputs = torch.stack(outputs).permute(1, 0, 2, 3, 4).contiguous()
71 | targets_ = torch.cat((inputs[:, 1:], targets), dim=1)
72 | loss = criterion(outputs, targets_)
73 |
74 | scaler.scale(loss).backward()
75 | scaler.step(optimizer)
76 | scaler.update()
77 |
78 | losses.append(loss.item())
79 |
80 | if batch_idx and batch_idx % args.log_train == 0:
81 | logger.info(f'EP:{epoch:04d} BI:{batch_idx:03d}/{num_batches:03d} Loss:{np.mean(losses):.6f}')
82 |
83 | return np.mean(losses)
84 |
85 |
86 | def test(args, logger, epoch, model, test_loader, criterion, cache_dir):
87 | model.eval()
88 | num_batches = len(test_loader)
89 | losses, mses, ssims = [], [], []
90 |
91 | for batch_idx, (inputs, targets) in enumerate(test_loader):
92 |
93 | with torch.no_grad():
94 | inputs, targets = map(lambda x: x.float().to(args.device), [inputs, targets])
95 | targets_len = targets.shape[1]
96 |
97 | if args.model == 'SwinLSTM-B':
98 | outputs = model_forward_single_layer(model, inputs, targets_len, args.depths)
99 |
100 | if args.model == 'SwinLSTM-D':
101 | outputs = model_forward_multi_layer(model, inputs, targets_len, args.depths_down)
102 |
103 | outputs = torch.stack(outputs).permute(1, 0, 2, 3, 4).contiguous()
104 | targets_ = torch.cat((inputs[:, 1:], targets), dim=1)
105 |
106 | losses.append(criterion(outputs, targets_).item())
107 |
108 | inputs_len = inputs.shape[1]
109 | outputs = outputs[:, inputs_len - 1:]
110 |
111 | mse, ssim = compute_metrics(outputs, targets)
112 |
113 | mses.append(mse)
114 | ssims.append(ssim)
115 |
116 | if batch_idx and batch_idx % args.log_valid == 0:
117 | logger.info(
118 | f'EP:{epoch:04d} BI:{batch_idx:03d}/{num_batches:03d} Loss:{np.mean(losses):.6f} MSE:{mse:.4f} SSIM:{ssim:.4f}')
119 | visualize(inputs, targets, outputs, epoch, batch_idx, cache_dir)
120 |
121 | return np.mean(losses), np.mean(mses), np.mean(ssims)
122 |
--------------------------------------------------------------------------------
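The two forward helpers above implement a warm-up-then-rollout scheme: the first inputs_len - 1 frames are teacher-forced, then the model's own output is fed back for targets_len steps, which is why train()/test() compare against torch.cat((inputs[:, 1:], targets), dim=1). A toy illustration with a stand-in model (DummyModel is hypothetical and only mimics SwinLSTM_D's (output, states_down, states_up) interface):

# Toy rollout illustration; DummyModel is a placeholder, not the real SwinLSTM.
import torch
from functions import model_forward_multi_layer

class DummyModel(torch.nn.Module):
    def forward(self, x, states_down, states_up):
        # Identity "prediction"; the real SwinLSTM updates and returns recurrent states here.
        return x, states_down, states_up

inputs = torch.rand(2, 10, 1, 64, 64)    # (B, T, C, H, W)
outputs = model_forward_multi_layer(DummyModel(), inputs, targets_len=10, num_layers=[2, 6])
print(len(outputs))                      # 19 = (10 - 1) teacher-forced steps + 10 autoregressive steps
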
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.11.0
2 | torchvision==0.12.0
3 | numpy
4 | matplotlib
5 | scikit-image==0.19.2
6 | timm==0.4.12
7 | einops==0.4.1
8 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from torch import nn
4 | from torch.utils.data import DataLoader
5 | from SwinLSTM_D import SwinLSTM
6 | from configs import get_args
7 | from dataset import Moving_MNIST_Test
8 | from functions import test
9 | from utils import set_seed, make_dir, init_logger
10 |
11 | if __name__ == '__main__':
12 | args = get_args()
13 | set_seed(args.seed)
14 | cache_dir, model_dir, log_dir = make_dir(args)
15 | logger = init_logger(log_dir)
16 |
17 | model = SwinLSTM(img_size=args.input_img_size, patch_size=args.patch_size,
18 | in_chans=args.input_channels, embed_dim=args.embed_dim,
19 | depths_downsample=args.depths_down, depths_upsample=args.depths_up,
20 | num_heads=args.heads_number, window_size=args.window_size).to(args.device)
21 |
22 | criterion = nn.MSELoss()
23 |
24 | test_dataset = Moving_MNIST_Test(args)
25 | test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size,
26 | num_workers=args.num_workers, shuffle=False, pin_memory=True, drop_last=True)
27 |
28 |     model.load_state_dict(torch.load('./Pretrained/trained_model_state_dict', map_location=args.device))
29 |
30 | start_time = time.time()
31 |
32 | _, mse, ssim = test(args, logger, 0, model, test_loader, criterion, cache_dir)
33 |
34 | print(f'[Metrics] MSE:{mse:.4f} SSIM:{ssim:.4f}')
35 | print(f'Time usage per epoch: {time.time() - start_time:.0f}s')
36 |
37 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from utils import *  # re-exposes torch and time (imported in utils), which setup()/main() rely on below
2 | import torch.nn as nn
3 | from configs import get_args
4 | from functions import train, test
5 | from torch.utils.data import DataLoader
6 | from dataset import Moving_MNIST
7 |
8 |
9 | def setup(args):
10 | if args.model == 'SwinLSTM-B':
11 | from SwinLSTM_B import SwinLSTM
12 | model = SwinLSTM(img_size=args.input_img_size, patch_size=args.patch_size,
13 | in_chans=args.input_channels, embed_dim=args.embed_dim,
14 | depths=args.depths, num_heads=args.heads_number,
15 | window_size=args.window_size, drop_rate=args.drop_rate,
16 | attn_drop_rate=args.attn_drop_rate, drop_path_rate=args.drop_path_rate).to(args.device)
17 |
18 | if args.model == 'SwinLSTM-D':
19 | from SwinLSTM_D import SwinLSTM
20 | model = SwinLSTM(img_size=args.input_img_size, patch_size=args.patch_size,
21 | in_chans=args.input_channels, embed_dim=args.embed_dim,
22 | depths_downsample=args.depths_down, depths_upsample=args.depths_up,
23 | num_heads=args.heads_number, window_size=args.window_size).to(args.device)
24 |
25 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
26 |
27 | criterion = nn.MSELoss()
28 |
29 | train_dataset = Moving_MNIST(args, split='train')
30 | valid_dataset = Moving_MNIST(args, split='valid')
31 |
32 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size,
33 | num_workers=args.num_workers, shuffle=True, pin_memory=True, drop_last=True)
34 |
35 | valid_loader = DataLoader(valid_dataset, batch_size=args.valid_batch_size,
36 | num_workers=args.num_workers, shuffle=False, pin_memory=True, drop_last=True)
37 |
38 | return model, criterion, optimizer, train_loader, valid_loader
39 |
40 | def main():
41 | args = get_args()
42 | set_seed(args.seed)
43 | cache_dir, model_dir, log_dir = make_dir(args)
44 | logger = init_logger(log_dir)
45 |
46 | model, criterion, optimizer, train_loader, valid_loader = setup(args)
47 |
48 | train_losses, valid_losses = [], []
49 |
50 | best_metric = (0, float('inf'), float('inf'))
51 |
52 | for epoch in range(args.epochs):
53 |
54 | start_time = time.time()
55 | train_loss = train(args, logger, epoch, model, train_loader, criterion, optimizer)
56 | train_losses.append(train_loss)
57 | plot_loss(train_losses, 'train', epoch, args.res_dir, 1)
58 |
59 | if (epoch + 1) % args.epoch_valid == 0:
60 |
61 | valid_loss, mse, ssim = test(args, logger, epoch, model, valid_loader, criterion, cache_dir)
62 |
63 | valid_losses.append(valid_loss)
64 |
65 | plot_loss(valid_losses, 'valid', epoch, args.res_dir, args.epoch_valid)
66 |
67 | if mse < best_metric[1]:
68 | torch.save(model.state_dict(), f'{model_dir}/trained_model_state_dict')
69 | best_metric = (epoch, mse, ssim)
70 |
71 | logger.info(f'[Current Best] EP:{best_metric[0]:04d} MSE:{best_metric[1]:.4f} SSIM:{best_metric[2]:.4f}')
72 |
73 | print(f'Time usage per epoch: {time.time() - start_time:.0f}s')
74 |
75 | if __name__ == '__main__':
76 | main()
77 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import random
5 | import logging
6 | import matplotlib
7 | import numpy as np
8 | from matplotlib import pyplot as plt
9 | from skimage.metrics import structural_similarity
10 |
11 | matplotlib.use('agg')
12 |
13 | def set_seed(seed):
14 | torch.manual_seed(seed)
15 | random.seed(seed)
16 | np.random.seed(seed)
17 | torch.backends.cudnn.deterministic = True
18 |
19 | def visualize(inputs, targets, outputs, epoch, idx, cache_dir):
20 | _, axarray = plt.subplots(3, targets.shape[1], figsize=(targets.shape[1] * 5, 10))
21 |
22 | for t in range(targets.shape[1]):
23 | axarray[0][t].imshow(inputs[0, t, 0].detach().cpu().numpy(), cmap='gray')
24 | axarray[1][t].imshow(targets[0, t, 0].detach().cpu().numpy(), cmap='gray')
25 | axarray[2][t].imshow(outputs[0, t, 0].detach().cpu().numpy(), cmap='gray')
26 |
27 | plt.savefig(os.path.join(cache_dir, '{:03d}-{:03d}.png'.format(epoch, idx)))
28 | plt.close()
29 |
30 | def plot_loss(loss_records, loss_type, epoch, plot_dir, step):
31 | plt.plot(range((epoch + 1) // step), loss_records, label=loss_type)
32 | plt.legend()
33 | plt.savefig(os.path.join(plot_dir, '{}_loss_records.png'.format(loss_type)))
34 | plt.close()
35 |
36 | def MAE(pred, true):
37 | return np.mean(np.abs(pred - true), axis=(0, 1)).sum()
38 |
39 | def MSE(pred, true):
40 | return np.mean((pred - true) ** 2, axis=(0, 1)).sum()
41 |
42 | # cite the 'PSNR' code from E3D-LSTM, Thanks!
43 | # https://github.com/google/e3d_lstm/blob/master/src/trainer.py line 39-40
44 | def PSNR(pred, true):
45 |     mse = np.mean((np.uint8(pred * 255).astype(np.int32) - np.uint8(true * 255).astype(np.int32)) ** 2)  # cast before subtracting to avoid uint8 wrap-around
46 | return 20 * np.log10(255) - 10 * np.log10(mse)
47 |
48 | def compute_metrics(predictions, targets):
49 | targets = targets.permute(0, 1, 3, 4, 2).detach().cpu().numpy()
50 | predictions = predictions.permute(0, 1, 3, 4, 2).detach().cpu().numpy()
51 |
52 | batch_size = predictions.shape[0]
53 | Seq_len = predictions.shape[1]
54 |
55 | ssim = 0
56 |
57 | for batch in range(batch_size):
58 | for frame in range(Seq_len):
59 | ssim += structural_similarity(targets[batch, frame].squeeze(),
60 | predictions[batch, frame].squeeze())
61 |
62 | ssim /= (batch_size * Seq_len)
63 |
64 | mse = MSE(predictions, targets)
65 |
66 | return mse, ssim
67 |
68 | def check_dir(path):
69 | if not os.path.exists(path):
70 | os.makedirs(path)
71 |
72 | def make_dir(args):
73 |
74 | cache_dir = os.path.join(args.res_dir, 'cache')
75 | check_dir(cache_dir)
76 |
77 | model_dir = os.path.join(args.res_dir, 'model')
78 | check_dir(model_dir)
79 |
80 | log_dir = os.path.join(args.res_dir, 'log')
81 | check_dir(log_dir)
82 |
83 | return cache_dir, model_dir, log_dir
84 |
85 | def init_logger(log_dir):
86 | logging.basicConfig(level=logging.INFO,
87 | format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
88 | datefmt='%m-%d %H:%M',
89 | filename=os.path.join(log_dir, time.strftime("%Y_%m_%d") + '.log'),
90 | filemode='w')
91 |
92 | console = logging.StreamHandler()
93 | console.setLevel(logging.INFO)
94 | formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
95 | console.setFormatter(formatter)
96 | logging.getLogger('').addHandler(console)
97 |
98 | return logging
99 |
--------------------------------------------------------------------------------
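A quick sanity check for compute_metrics above (hypothetical, not part of the repo): it expects (B, T, C, H, W) tensors in [0, 1] and returns the MSE averaged over batch/time and summed over pixels, alongside the mean per-frame SSIM.

# Hypothetical check: identical sequences should give 0 MSE and SSIM of 1.
import torch
from utils import compute_metrics

pred = torch.rand(2, 10, 1, 64, 64)
mse, ssim = compute_metrics(pred, pred.clone())
print(mse, ssim)    # 0.0 and 1.0
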