├── .gitignore ├── .idea ├── .gitignore ├── .name ├── SwinLSTM-D.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── LICENSE ├── Pretrained └── trained_model_state_dict ├── README.md ├── SwinLSTM_B.py ├── SwinLSTM_D.py ├── __pycache__ ├── SwinLSTM_D.cpython-39.pyc ├── configs.cpython-39.pyc ├── dataset.cpython-39.pyc ├── functions.cpython-39.pyc └── utils.cpython-39.pyc ├── architecture.png ├── configs.py ├── data ├── test_data_download_link.md └── train-images-idx3-ubyte.gz ├── dataset.py ├── functions.py ├── requirements.txt ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/.name: -------------------------------------------------------------------------------- 1 | main.py -------------------------------------------------------------------------------- /.idea/SwinLSTM-D.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Song Tang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pretrained/trained_model_state_dict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/Pretrained/trained_model_state_dict -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SwinLSTM: A new recurrent cell for spatiotemporal modeling 2 | 3 | This repository contains the official PyTorch implementation of the following paper: 4 | 5 | SwinLSTM: Improving Spatiotemporal Prediction Accuracy using Swin Transformer and LSTM **(ICCV 2023)** 6 | 7 | Paper: http://arxiv.org/abs/2308.09891v2 8 | 9 | 10 | 11 | 12 | ## Introduction 13 | ![architecture](/architecture.png) 14 | Integrating CNNs and RNNs to capture spatiotemporal dependencies is a prevalent strategy for spatiotemporal prediction tasks. However, the property of CNNs to learn local spatial information decreases their efficiency in capturing spatiotemporal dependencies, thereby limiting their prediction accuracy. In this paper, we propose a new recurrent cell, SwinLSTM, which integrates Swin Transformer blocks and the simplified LSTM, an extension that replaces the convolutional structure in ConvLSTM with the self-attention mechanism. Furthermore, we construct a network with the SwinLSTM cell as its core for spatiotemporal prediction. Without using unique tricks, SwinLSTM outperforms state-of-the-art methods on the Moving MNIST, Human3.6m, TaxiBJ, and KTH datasets. In particular, it exhibits a significant improvement in prediction accuracy compared to ConvLSTM. Our competitive experimental results demonstrate that learning global spatial dependencies is more advantageous for models to capture spatiotemporal dependencies. We hope that SwinLSTM can serve as a solid baseline to promote the advancement of spatiotemporal prediction accuracy. 15 | 16 | ## Overview 17 | - `Pretrained/` contains pretrained weights on Moving MNIST. 18 | - `data/` contains the MNIST dataset and the Moving MNIST test set download link. 19 | - `SwinLSTM_B.py` contains the model with a single SwinLSTM cell. 20 | - `SwinLSTM_D.py` contains the model with multiple SwinLSTM cells. 21 | - `dataset.py` contains the training and validation dataloaders. 22 | - `functions.py` contains the train and test functions. 23 | - `train.py` is the core file of the training pipeline. 24 | - `test.py` is a script for a quick test (a usage sketch follows below).
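The minimal sketch below (not part of the repository) shows how the single-cell model in `SwinLSTM_B.py` is driven recurrently: the hidden states start as `None`, are returned at every step, and the last prediction is fed back as the next input. The hyperparameter values are assumptions taken from the defaults in `configs.py`; `train.py` and `test.py` remain the reference entry points.

```python
# Minimal SwinLSTM-B usage sketch; hyperparameters are assumed from configs.py defaults.
import torch
from SwinLSTM_B import SwinLSTM

model = SwinLSTM(img_size=64, patch_size=2, in_chans=1, embed_dim=128,
                 depths=[12], num_heads=[4], window_size=4,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1)
model.eval()

frames = torch.randn(1, 10, 1, 64, 64)  # (batch, time, channels, height, width)
states = [None]                         # one (h, c) pair per SwinLSTM cell (len(depths) == 1)
predictions = []

with torch.no_grad():
    for t in range(frames.shape[1] - 1):   # warm up on the observed frames
        _, states = model(frames[:, t], states)
    last = frames[:, -1]
    for _ in range(10):                    # roll out 10 future frames autoregressively
        last, states = model(last, states)
        predictions.append(last)

print(predictions[-1].shape)  # torch.Size([1, 1, 64, 64])
```

The multi-cell model in `SwinLSTM_D.py` is driven the same way, except that its `forward` takes and returns two state lists (`states_down` and `states_up`), mirroring `model_forward_multi_layer` in `functions.py`.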
25 | 26 | ## Requirements 27 | - python >= 3.8 28 | - torch == 1.11.0 29 | - torchvision == 0.12.0 30 | - numpy 31 | - matplotlib 32 | - skimage == 0.19.2 33 | - timm == 0.4.12 34 | - einops == 0.4.1 35 | 36 | ## Citation 37 | If you find this work useful in your research, please cite the paper: 38 | ``` 39 | @inproceedings{tang2023swinlstm, 40 | title={SwinLSTM: Improving Spatiotemporal Prediction Accuracy using Swin Transformer and LSTM}, 41 | author={Tang, Song and Li, Chuang and Zhang, Pu and Tang, RongNian}, 42 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 43 | pages={13470--13479}, 44 | year={2023} 45 | } 46 | 47 | ``` 48 | 49 | ## Acknowledgment 50 | These codes are based on [Swin Transformer](https://github.com/microsoft/Swin-Transformer). We extend our sincere appreciation for their valuable contributions. 51 | 52 | -------------------------------------------------------------------------------- /SwinLSTM_B.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 4 | 5 | 6 | class Mlp(nn.Module): 7 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 8 | super(Mlp, self).__init__() 9 | out_features = out_features or in_features 10 | hidden_features = hidden_features or in_features 11 | self.fc1 = nn.Linear(in_features, hidden_features) 12 | self.act = act_layer() 13 | self.fc2 = nn.Linear(hidden_features, out_features) 14 | self.drop = nn.Dropout(drop) 15 | 16 | def forward(self, x): 17 | x = self.fc1(x) 18 | x = self.act(x) 19 | x = self.drop(x) 20 | x = self.fc2(x) 21 | x = self.drop(x) 22 | return x 23 | 24 | 25 | def window_partition(x, window_size): 26 | """ 27 | Args: 28 | x: (B, H, W, C) 29 | window_size (int): window size 30 | 31 | Returns: 32 | windows: (num_windows*B, window_size, window_size, C) 33 | """ 34 | B, H, W, C = x.shape 35 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 36 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | return windows 38 | 39 | 40 | def window_reverse(windows, window_size, H, W): 41 | """ 42 | Args: 43 | windows: (num_windows*B, window_size, window_size, C) 44 | window_size (int): Window size 45 | H (int): Height of image 46 | W (int): Width of image 47 | 48 | Returns: 49 | x: (B, H, W, C) 50 | """ 51 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 52 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 53 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 54 | return x 55 | 56 | 57 | class WindowAttention(nn.Module): 58 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 59 | It supports both of shifted and non-shifted window. 60 | 61 | Args: 62 | dim (int): Number of input channels. 63 | window_size (tuple[int]): The height and width of the window. 64 | num_heads (int): Number of attention heads. 65 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 66 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 67 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 68 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 69 | """ 70 | 71 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 72 | 73 | super(WindowAttention, self).__init__() 74 | self.dim = dim 75 | self.window_size = window_size # Wh, Ww 76 | self.num_heads = num_heads 77 | head_dim = dim // num_heads 78 | self.scale = qk_scale or head_dim ** -0.5 79 | 80 | # define a parameter table of relative position bias 81 | self.relative_position_bias_table = nn.Parameter( 82 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 83 | 84 | # get pair-wise relative position index for each token inside the window 85 | coords_h = torch.arange(self.window_size[0]) 86 | coords_w = torch.arange(self.window_size[1]) 87 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 88 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 89 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 90 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 91 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 92 | relative_coords[:, :, 1] += self.window_size[1] - 1 93 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 94 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 95 | self.register_buffer("relative_position_index", relative_position_index) 96 | 97 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 98 | self.attn_drop = nn.Dropout(attn_drop) 99 | self.proj = nn.Linear(dim, dim) 100 | self.proj_drop = nn.Dropout(proj_drop) 101 | 102 | trunc_normal_(self.relative_position_bias_table, std=.02) 103 | self.softmax = nn.Softmax(dim=-1) 104 | 105 | def forward(self, x, mask=None): 106 | """ 107 | Args: 108 | x: input features with shape of (num_windows*B, N, C) 109 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 110 | """ 111 | 112 | B_, N, C = x.shape 113 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 114 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 115 | 116 | q = q * self.scale 117 | attn = (q @ k.transpose(-2, -1)) 118 | 119 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 120 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 121 | 122 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 123 | attn = attn + relative_position_bias.unsqueeze(0) 124 | 125 | if mask is not None: 126 | nW = mask.shape[0] 127 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 128 | attn = attn.view(-1, self.num_heads, N, N) 129 | attn = self.softmax(attn) 130 | else: 131 | attn = self.softmax(attn) 132 | 133 | attn = self.attn_drop(attn) 134 | 135 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 136 | x = self.proj(x) 137 | x = self.proj_drop(x) 138 | return x 139 | 140 | 141 | class SwinTransformerBlock(nn.Module): 142 | r""" Swin Transformer Block. 143 | 144 | Args: 145 | dim (int): Number of input channels. 146 | input_resolution (tuple[int]): Input resulotion. 147 | num_heads (int): Number of attention heads. 148 | window_size (int): Window size. 149 | shift_size (int): Shift size for SW-MSA. 150 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
151 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 152 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 153 | drop (float, optional): Dropout rate. Default: 0.0 154 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 155 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 156 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 157 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 158 | """ 159 | 160 | def __init__(self, dim, input_resolution, num_heads, window_size=2, shift_size=0, 161 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 162 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 163 | super().__init__() 164 | self.dim = dim 165 | self.input_resolution = input_resolution 166 | self.num_heads = num_heads 167 | self.window_size = window_size 168 | self.shift_size = shift_size 169 | self.mlp_ratio = mlp_ratio 170 | if min(self.input_resolution) <= self.window_size: 171 | # if window size is larger than input resolution, we don't partition windows 172 | self.shift_size = 0 173 | self.window_size = min(self.input_resolution) 174 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 175 | 176 | self.norm1 = norm_layer(dim) 177 | self.attn = WindowAttention( 178 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 179 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 180 | 181 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 182 | self.norm2 = norm_layer(dim) 183 | mlp_hidden_dim = int(dim * mlp_ratio) 184 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 185 | self.red = nn.Linear(2 * dim, dim) 186 | if self.shift_size > 0: 187 | # calculate attention mask for SW-MSA 188 | H, W = self.input_resolution 189 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 190 | h_slices = (slice(0, -self.window_size), 191 | slice(-self.window_size, -self.shift_size), 192 | slice(-self.shift_size, None)) 193 | w_slices = (slice(0, -self.window_size), 194 | slice(-self.window_size, -self.shift_size), 195 | slice(-self.shift_size, None)) 196 | cnt = 0 197 | for h in h_slices: 198 | for w in w_slices: 199 | img_mask[:, h, w, :] = cnt 200 | cnt += 1 201 | 202 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 203 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 204 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 205 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 206 | else: 207 | attn_mask = None 208 | 209 | self.register_buffer("attn_mask", attn_mask) 210 | 211 | def forward(self, x, hx=None): 212 | H, W = self.input_resolution 213 | B, L, C = x.shape 214 | assert L == H * W, "input feature has wrong size" 215 | 216 | shortcut = x 217 | x = self.norm1(x) 218 | if hx is not None: 219 | hx = self.norm1(hx) 220 | x = torch.cat((x, hx), -1) 221 | x = self.red(x) 222 | x = x.view(B, H, W, C) 223 | 224 | # cyclic shift 225 | if self.shift_size > 0: 226 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 227 | else: 228 | shifted_x = x 229 | 230 | # partition windows 231 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 232 | x_windows = 
x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 233 | 234 | # W-MSA/SW-MSA 235 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 236 | 237 | # merge windows 238 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 239 | 240 | # reverse cyclic shift 241 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 242 | 243 | if self.shift_size > 0: 244 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 245 | else: 246 | x = shifted_x 247 | 248 | # FFN 249 | x = x.view(B, H * W, C) 250 | x = shortcut + self.drop_path(x) 251 | x = x + self.drop_path(self.mlp(self.norm2(x))) 252 | 253 | return x 254 | 255 | 256 | class PatchEmbed(nn.Module): 257 | r""" Image to Patch Embedding 258 | 259 | Args: 260 | img_size (int): Image size. 261 | patch_size (int): Patch token size. 262 | in_chans (int): Number of input image channels. 263 | embed_dim (int): Number of linear projection output channels. 264 | """ 265 | 266 | def __init__(self, img_size, patch_size, in_chans, embed_dim): 267 | super(PatchEmbed, self).__init__() 268 | img_size = to_2tuple(img_size) 269 | patch_size = to_2tuple(patch_size) 270 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 271 | self.img_size = img_size 272 | self.patch_size = patch_size 273 | self.patches_resolution = patches_resolution 274 | self.num_patches = patches_resolution[0] * patches_resolution[1] 275 | self.in_chans = in_chans 276 | self.embed_dim = embed_dim 277 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 278 | self.norm = nn.LayerNorm(embed_dim) 279 | 280 | def forward(self, x): 281 | B, C, H, W = x.shape 282 | assert H == self.img_size[0] and W == self.img_size[1], \ 283 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 284 | x = self.proj(x).flatten(2).transpose(1, 2) 285 | x = self.norm(x) 286 | return x 287 | 288 | 289 | class PatchInflated(nn.Module): 290 | r""" Tensor to Patch Inflating 291 | 292 | Args: 293 | in_chans (int): Number of input image channels. 294 | embed_dim (int): Number of linear projection output channels. 295 | input_resolution (tuple[int]): Input resulotion. 296 | """ 297 | 298 | def __init__(self, in_chans, embed_dim, input_resolution, stride=2, padding=1, output_padding=1): 299 | super(PatchInflated, self).__init__() 300 | 301 | stride = to_2tuple(stride) 302 | padding = to_2tuple(padding) 303 | output_padding = to_2tuple(output_padding) 304 | self.input_resolution = input_resolution 305 | 306 | self.ConvT = nn.ConvTranspose2d(in_channels=embed_dim, out_channels=in_chans, kernel_size=(3, 3), 307 | stride=stride, padding=padding, output_padding=output_padding) 308 | 309 | def forward(self, x): 310 | H, W = self.input_resolution 311 | B, L, C = x.shape 312 | assert L == H * W, "input feature has wrong size" 313 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
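        # With kernel_size=3, stride=2, padding=1 and output_padding=1, the transposed
        # convolution maps (H, W) to exactly (2H, 2W):
        #   H_out = (H - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 2 * H
        # so a PatchInflated step doubles the spatial resolution of the patch grid.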
314 | 315 | x = x.view(B, H, W, C) 316 | x = x.permute(0, 3, 1, 2) 317 | x = self.ConvT(x) 318 | 319 | return x 320 | 321 | 322 | class SwinTransformerBlocks(nn.Module): 323 | 324 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, 325 | drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm): 326 | super(SwinTransformerBlocks, self).__init__() 327 | self.layers = nn.ModuleList([ 328 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 329 | num_heads=num_heads, window_size=window_size, 330 | shift_size=0 if (i % 2 == 0) else window_size // 2, 331 | mlp_ratio=mlp_ratio, 332 | qkv_bias=qkv_bias, qk_scale=qk_scale, 333 | drop=drop, attn_drop=attn_drop, 334 | drop_path=drop_path, 335 | norm_layer=norm_layer) 336 | for i in range(depth)]) 337 | 338 | def forward(self, xt, hx): 339 | 340 | outputs = [] 341 | 342 | for index, layer in enumerate(self.layers): 343 | if index == 0: 344 | x = layer(xt, hx) 345 | outputs.append(x) 346 | 347 | else: 348 | if index % 2 == 0: 349 | x = layer(outputs[-1], xt) 350 | outputs.append(x) 351 | 352 | if index % 2 == 1: 353 | x = layer(outputs[-1], None) 354 | outputs.append(x) 355 | 356 | return outputs[-1] 357 | 358 | 359 | class SwinLSTMCell(nn.Module): 360 | 361 | def __init__(self, dim, input_resolution, num_heads, window_size, depth, 362 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 363 | drop_path=0., norm_layer=nn.LayerNorm): 364 | super(SwinLSTMCell, self).__init__() 365 | 366 | self.Swin = SwinTransformerBlocks(dim=dim, input_resolution=input_resolution, depth=depth, 367 | num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, 368 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, 369 | drop_path=drop_path, norm_layer=norm_layer) 370 | 371 | def forward(self, xt, hidden_states): 372 | if hidden_states is None: 373 | B, L, C = xt.shape 374 | hx = torch.zeros(B, L, C).to(xt.device) 375 | cx = torch.zeros(B, L, C).to(xt.device) 376 | 377 | else: 378 | hx, cx = hidden_states 379 | 380 | Ft = self.Swin(xt, hx) 381 | 382 | gate = torch.sigmoid(Ft) 383 | cell = torch.tanh(Ft) 384 | 385 | cy = gate * (cx + cell) 386 | hy = gate * torch.tanh(cy) 387 | hx = hy 388 | cx = cy 389 | 390 | return hx, (hx, cx) 391 | 392 | 393 | class STconvert(nn.Module): 394 | r""" STconvert 395 | 396 | Args: 397 | img_size (int | tuple(int)): Input image size. 398 | patch_size (int | tuple(int)): Patch size. 399 | in_chans (int): Number of input image channels. 400 | embed_dim (int): Patch embedding dimension. 401 | depths (tuple(int)): Depth of Swin Transformer layer. 402 | num_heads (tuple(int)): Number of attention heads in different layers. 403 | window_size (int): Window size. 404 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 405 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 406 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 407 | drop_rate (float): Dropout rate. Default: 0 408 | attn_drop_rate (float): Attention dropout rate. Default: 0 409 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 410 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 
411 | """ 412 | 413 | def __init__(self, img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, 414 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 415 | norm_layer=nn.LayerNorm): 416 | 417 | super(STconvert, self).__init__() 418 | 419 | self.num_layers = len(depths) 420 | self.embed_dim = embed_dim 421 | self.mlp_ratio = mlp_ratio 422 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, 423 | in_chans=in_chans, embed_dim=embed_dim) 424 | patches_resolution = self.patch_embed.patches_resolution 425 | 426 | self.PatchInflated = PatchInflated(in_chans=in_chans, embed_dim=embed_dim, input_resolution=patches_resolution) 427 | self.layers = nn.ModuleList() 428 | 429 | for i_layer in range(self.num_layers): 430 | layer = SwinLSTMCell(dim=embed_dim, 431 | input_resolution=(patches_resolution[0], patches_resolution[1]), 432 | depth=depths[i_layer], 433 | num_heads=num_heads[i_layer], 434 | window_size=window_size, 435 | mlp_ratio=self.mlp_ratio, 436 | qkv_bias=qkv_bias, qk_scale=qk_scale, 437 | drop=drop_rate, attn_drop=attn_drop_rate, 438 | drop_path=drop_path_rate, 439 | norm_layer=norm_layer) 440 | 441 | self.layers.append(layer) 442 | 443 | def forward(self, x, h): 444 | 445 | x = self.patch_embed(x) 446 | 447 | hidden_states = [] 448 | 449 | for index, layer in enumerate(self.layers): 450 | x, hidden_state = layer(x, h[index]) 451 | hidden_states.append(hidden_state) 452 | 453 | x = torch.sigmoid(self.PatchInflated(x)) 454 | 455 | return hidden_states, x 456 | 457 | 458 | class SwinLSTM(nn.Module): 459 | r""" SwinLSTM 460 | 461 | Args: 462 | img_size (int | tuple(int)): Input image size. 463 | patch_size (int | tuple(int)): Patch size. 464 | in_chans (int): Number of input image channels. 465 | embed_dim (int): Patch embedding dimension. 466 | depths (tuple(int)): Depth of Swin Transformer layer. 467 | num_heads (tuple(int)): Number of attention heads in different layers. 468 | window_size (int): Window size. 469 | drop_rate (float): Dropout rate. 470 | attn_drop_rate (float): Attention dropout rate. 471 | drop_path_rate (float): Stochastic depth rate. 
472 | """ 473 | 474 | def __init__(self, img_size, patch_size, in_chans, embed_dim, depths, 475 | num_heads, window_size, drop_rate, attn_drop_rate, drop_path_rate): 476 | super(SwinLSTM, self).__init__() 477 | 478 | self.ST = STconvert(img_size=img_size, patch_size=patch_size, in_chans=in_chans, 479 | embed_dim=embed_dim, depths=depths, 480 | num_heads=num_heads, window_size=window_size, drop_rate=drop_rate, 481 | attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate) 482 | 483 | def forward(self, input, states): 484 | states_next, output = self.ST(input, states) 485 | 486 | return output, states_next 487 | -------------------------------------------------------------------------------- /SwinLSTM_D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | 6 | 7 | class Mlp(nn.Module): 8 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 9 | super(Mlp, self).__init__() 10 | out_features = out_features or in_features 11 | hidden_features = hidden_features or in_features 12 | self.fc1 = nn.Linear(in_features, hidden_features) 13 | self.act = act_layer() 14 | self.fc2 = nn.Linear(hidden_features, out_features) 15 | self.drop = nn.Dropout(drop) 16 | 17 | def forward(self, x): 18 | x = self.fc1(x) 19 | x = self.act(x) 20 | x = self.drop(x) 21 | x = self.fc2(x) 22 | x = self.drop(x) 23 | return x 24 | 25 | 26 | def window_partition(x, window_size): 27 | """ 28 | Args: 29 | x: (B, H, W, C) 30 | window_size (int): window size 31 | 32 | Returns: 33 | windows: (num_windows*B, window_size, window_size, C) 34 | """ 35 | B, H, W, C = x.shape 36 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 37 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 38 | return windows 39 | 40 | 41 | def window_reverse(windows, window_size, H, W): 42 | """ 43 | Args: 44 | windows: (num_windows*B, window_size, window_size, C) 45 | window_size (int): Window size 46 | H (int): Height of image 47 | W (int): Width of image 48 | 49 | Returns: 50 | x: (B, H, W, C) 51 | """ 52 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 53 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 54 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 55 | return x 56 | 57 | 58 | class WindowAttention(nn.Module): 59 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 60 | It supports both of shifted and non-shifted window. 61 | 62 | Args: 63 | dim (int): Number of input channels. 64 | window_size (tuple[int]): The height and width of the window. 65 | num_heads (int): Number of attention heads. 66 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 67 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 68 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 69 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 70 | """ 71 | 72 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 73 | 74 | super(WindowAttention, self).__init__() 75 | self.dim = dim 76 | self.window_size = window_size # Wh, Ww 77 | self.num_heads = num_heads 78 | head_dim = dim // num_heads 79 | self.scale = qk_scale or head_dim ** -0.5 80 | 81 | # define a parameter table of relative position bias 82 | self.relative_position_bias_table = nn.Parameter( 83 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 84 | 85 | # get pair-wise relative position index for each token inside the window 86 | coords_h = torch.arange(self.window_size[0]) 87 | coords_w = torch.arange(self.window_size[1]) 88 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 89 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 90 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 91 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 92 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 93 | relative_coords[:, :, 1] += self.window_size[1] - 1 94 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 95 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 96 | self.register_buffer("relative_position_index", relative_position_index) 97 | 98 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 99 | self.attn_drop = nn.Dropout(attn_drop) 100 | self.proj = nn.Linear(dim, dim) 101 | self.proj_drop = nn.Dropout(proj_drop) 102 | 103 | trunc_normal_(self.relative_position_bias_table, std=.02) 104 | self.softmax = nn.Softmax(dim=-1) 105 | 106 | def forward(self, x, mask=None): 107 | """ 108 | Args: 109 | x: input features with shape of (num_windows*B, N, C) 110 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 111 | """ 112 | 113 | B_, N, C = x.shape 114 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 115 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 116 | 117 | q = q * self.scale 118 | attn = (q @ k.transpose(-2, -1)) 119 | 120 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 121 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 122 | 123 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 124 | attn = attn + relative_position_bias.unsqueeze(0) 125 | 126 | if mask is not None: 127 | nW = mask.shape[0] 128 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 129 | attn = attn.view(-1, self.num_heads, N, N) 130 | attn = self.softmax(attn) 131 | else: 132 | attn = self.softmax(attn) 133 | 134 | attn = self.attn_drop(attn) 135 | 136 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 137 | x = self.proj(x) 138 | x = self.proj_drop(x) 139 | return x 140 | 141 | 142 | class SwinTransformerBlock(nn.Module): 143 | r""" Swin Transformer Block. 144 | 145 | Args: 146 | dim (int): Number of input channels. 147 | input_resolution (tuple[int]): Input resulotion. 148 | num_heads (int): Number of attention heads. 149 | window_size (int): Window size. 150 | shift_size (int): Shift size for SW-MSA. 151 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
152 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 153 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 154 | drop (float, optional): Dropout rate. Default: 0.0 155 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 156 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 157 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 158 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 159 | """ 160 | 161 | def __init__(self, dim, input_resolution, num_heads, window_size=2, shift_size=0, 162 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 163 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 164 | super().__init__() 165 | self.dim = dim 166 | self.input_resolution = input_resolution 167 | self.num_heads = num_heads 168 | self.window_size = window_size 169 | self.shift_size = shift_size 170 | self.mlp_ratio = mlp_ratio 171 | if min(self.input_resolution) <= self.window_size: 172 | # if window size is larger than input resolution, we don't partition windows 173 | self.shift_size = 0 174 | self.window_size = min(self.input_resolution) 175 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 176 | 177 | self.norm1 = norm_layer(dim) 178 | self.attn = WindowAttention( 179 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 180 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 181 | 182 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 183 | self.norm2 = norm_layer(dim) 184 | mlp_hidden_dim = int(dim * mlp_ratio) 185 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 186 | self.red = nn.Linear(2 * dim, dim) 187 | if self.shift_size > 0: 188 | # calculate attention mask for SW-MSA 189 | H, W = self.input_resolution 190 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 191 | h_slices = (slice(0, -self.window_size), 192 | slice(-self.window_size, -self.shift_size), 193 | slice(-self.shift_size, None)) 194 | w_slices = (slice(0, -self.window_size), 195 | slice(-self.window_size, -self.shift_size), 196 | slice(-self.shift_size, None)) 197 | cnt = 0 198 | for h in h_slices: 199 | for w in w_slices: 200 | img_mask[:, h, w, :] = cnt 201 | cnt += 1 202 | 203 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 204 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 205 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 206 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 207 | else: 208 | attn_mask = None 209 | 210 | self.register_buffer("attn_mask", attn_mask) 211 | 212 | def forward(self, x, hx=None): 213 | H, W = self.input_resolution 214 | B, L, C = x.shape 215 | assert L == H * W, "input feature has wrong size" 216 | 217 | shortcut = x 218 | x = self.norm1(x) 219 | if hx is not None: 220 | hx = self.norm1(hx) 221 | x = torch.cat((x, hx), -1) 222 | x = self.red(x) 223 | x = x.view(B, H, W, C) 224 | 225 | # cyclic shift 226 | if self.shift_size > 0: 227 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 228 | else: 229 | shifted_x = x 230 | 231 | # partition windows 232 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 233 | x_windows = 
x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 234 | 235 | # W-MSA/SW-MSA 236 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 237 | 238 | # merge windows 239 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 240 | 241 | # reverse cyclic shift 242 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 243 | 244 | if self.shift_size > 0: 245 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 246 | else: 247 | x = shifted_x 248 | 249 | # FFN 250 | x = x.view(B, H * W, C) 251 | x = shortcut + self.drop_path(x) 252 | x = x + self.drop_path(self.mlp(self.norm2(x))) 253 | 254 | return x 255 | 256 | 257 | class PatchMerging(nn.Module): 258 | r""" Patch Merging Layer. 259 | 260 | Args: 261 | input_resolution (tuple[int]): Resolution of input feature. 262 | dim (int): Number of input channels. 263 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 264 | """ 265 | 266 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 267 | super().__init__() 268 | self.input_resolution = input_resolution 269 | self.dim = dim 270 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 271 | self.norm = norm_layer(4 * dim) 272 | 273 | def forward(self, x): 274 | H, W = self.input_resolution 275 | B, L, C = x.shape 276 | assert L == H * W, "input feature has wrong size" 277 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 278 | 279 | x = x.view(B, H, W, C) 280 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 281 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 282 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 283 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 284 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 285 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 286 | x = self.norm(x) 287 | x = self.reduction(x) 288 | return x 289 | 290 | 291 | # cite the 'PatchExpand' code from Swin-Unet, Thanks! 292 | # https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/ 293 | # swin_transformer_unet_skip_expand_decoder_sys.py, line 333-355 294 | class PatchExpanding(nn.Module): 295 | r""" Patch Expanding Layer. 296 | 297 | Args: 298 | input_resolution (tuple[int]): Resolution of input feature. 299 | dim (int): Number of input channels. 300 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 301 | """ 302 | 303 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 304 | super(PatchExpanding, self).__init__() 305 | self.input_resolution = input_resolution 306 | self.dim = dim 307 | self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity() 308 | self.norm = norm_layer(dim // dim_scale) 309 | 310 | def forward(self, x): 311 | H, W = self.input_resolution 312 | x = self.expand(x) 313 | B, L, C = x.shape 314 | assert L == H * W, "input feature has wrong size" 315 | 316 | x = x.view(B, H, W, C) 317 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) 318 | x = x.view(B, -1, C // 4) 319 | x = self.norm(x) 320 | 321 | return x 322 | 323 | 324 | class PatchEmbed(nn.Module): 325 | r""" Image to Patch Embedding 326 | 327 | Args: 328 | img_size (int): Image size. 329 | patch_size (int): Patch token size. 330 | in_chans (int): Number of input image channels. 331 | embed_dim (int): Number of linear projection output channels. 
332 | """ 333 | 334 | def __init__(self, img_size, patch_size, in_chans, embed_dim): 335 | super(PatchEmbed, self).__init__() 336 | img_size = to_2tuple(img_size) 337 | patch_size = to_2tuple(patch_size) 338 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 339 | self.img_size = img_size 340 | self.patch_size = patch_size 341 | self.patches_resolution = patches_resolution 342 | self.num_patches = patches_resolution[0] * patches_resolution[1] 343 | self.in_chans = in_chans 344 | self.embed_dim = embed_dim 345 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 346 | self.norm = nn.LayerNorm(embed_dim) 347 | 348 | def forward(self, x): 349 | B, C, H, W = x.shape 350 | assert H == self.img_size[0] and W == self.img_size[1], \ 351 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 352 | x = self.proj(x).flatten(2).transpose(1, 2) 353 | x = self.norm(x) 354 | return x 355 | 356 | 357 | class PatchInflated(nn.Module): 358 | r""" Tensor to Patch Inflating 359 | 360 | Args: 361 | in_chans (int): Number of input image channels. 362 | embed_dim (int): Number of linear projection output channels. 363 | input_resolution (tuple[int]): Input resulotion. 364 | """ 365 | 366 | def __init__(self, in_chans, embed_dim, input_resolution, stride=2, padding=1, output_padding=1): 367 | super(PatchInflated, self).__init__() 368 | 369 | stride = to_2tuple(stride) 370 | padding = to_2tuple(padding) 371 | output_padding = to_2tuple(output_padding) 372 | self.input_resolution = input_resolution 373 | 374 | self.Conv = nn.ConvTranspose2d(in_channels=embed_dim, out_channels=in_chans, kernel_size=(3, 3), 375 | stride=stride, padding=padding, output_padding=output_padding) 376 | 377 | def forward(self, x): 378 | H, W = self.input_resolution 379 | B, L, C = x.shape 380 | assert L == H * W, "input feature has wrong size" 381 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
382 | 383 | x = x.view(B, H, W, C) 384 | x = x.permute(0, 3, 1, 2) 385 | x = self.Conv(x) 386 | 387 | return x 388 | 389 | 390 | class SwinLSTMCell(nn.Module): 391 | def __init__(self, dim, input_resolution, num_heads, window_size, depth, 392 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 393 | drop_path=0., norm_layer=nn.LayerNorm, flag=None): 394 | super(SwinLSTMCell, self).__init__() 395 | 396 | self.Swin = SwinTransformer(dim=dim, input_resolution=input_resolution, depth=depth, 397 | num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, 398 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, 399 | drop_path=drop_path, norm_layer=norm_layer, flag=flag) 400 | 401 | def forward(self, xt, hidden_states): 402 | if hidden_states is None: 403 | B, L, C = xt.shape 404 | hx = torch.zeros(B, L, C).to(xt.device) 405 | cx = torch.zeros(B, L, C).to(xt.device) 406 | 407 | else: 408 | hx, cx = hidden_states 409 | 410 | Ft = self.Swin(xt, hx) 411 | 412 | gate = torch.sigmoid(Ft) 413 | cell = torch.tanh(Ft) 414 | 415 | cy = gate * (cx + cell) 416 | hy = gate * torch.tanh(cy) 417 | hx = hy 418 | cx = cy 419 | 420 | return hx, (hx, cx) 421 | 422 | 423 | class SwinTransformer(nn.Module): 424 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, 425 | drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, flag=None): 426 | super(SwinTransformer, self).__init__() 427 | 428 | self.layers = nn.ModuleList([ 429 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 430 | num_heads=num_heads, window_size=window_size, 431 | shift_size=0 if (i % 2 == 0) else window_size // 2, 432 | mlp_ratio=mlp_ratio, 433 | qkv_bias=qkv_bias, qk_scale=qk_scale, 434 | drop=drop, attn_drop=attn_drop, 435 | drop_path=drop_path[depth - i - 1] if (flag == 0) else drop_path[i], 436 | norm_layer=norm_layer) 437 | for i in range(depth)]) 438 | 439 | def forward(self, xt, hx): 440 | 441 | outputs = [] 442 | 443 | for index, layer in enumerate(self.layers): 444 | if index == 0: 445 | x = layer(xt, hx) 446 | outputs.append(x) 447 | 448 | else: 449 | if index % 2 == 0: 450 | x = layer(outputs[-1], xt) 451 | outputs.append(x) 452 | 453 | if index % 2 == 1: 454 | x = layer(outputs[-1], None) 455 | outputs.append(x) 456 | 457 | return outputs[-1] 458 | 459 | 460 | class DownSample(nn.Module): 461 | def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_downsample, num_heads, window_size, 462 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 463 | norm_layer=nn.LayerNorm): 464 | super(DownSample, self).__init__() 465 | 466 | self.num_layers = len(depths_downsample) 467 | self.embed_dim = embed_dim 468 | self.mlp_ratio = mlp_ratio 469 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 470 | patches_resolution = self.patch_embed.patches_resolution 471 | 472 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_downsample))] 473 | 474 | self.layers = nn.ModuleList() 475 | self.downsample = nn.ModuleList() 476 | 477 | for i_layer in range(self.num_layers): 478 | downsample = PatchMerging(input_resolution=(patches_resolution[0] // (2 ** i_layer), 479 | patches_resolution[1] // (2 ** i_layer)), 480 | dim=int(embed_dim * 2 ** i_layer)) 481 | 482 | layer = SwinLSTMCell(dim=int(embed_dim * 2 ** i_layer), 483 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 484 | 
patches_resolution[1] // (2 ** i_layer)), 485 | depth=depths_downsample[i_layer], 486 | num_heads=num_heads[i_layer], 487 | window_size=window_size, 488 | mlp_ratio=self.mlp_ratio, 489 | qkv_bias=qkv_bias, qk_scale=qk_scale, 490 | drop=drop_rate, attn_drop=attn_drop_rate, 491 | drop_path=dpr[sum(depths_downsample[:i_layer]):sum(depths_downsample[:i_layer + 1])], 492 | norm_layer=norm_layer) 493 | 494 | self.layers.append(layer) 495 | self.downsample.append(downsample) 496 | 497 | def forward(self, x, y): 498 | 499 | x = self.patch_embed(x) 500 | 501 | hidden_states_down = [] 502 | 503 | for index, layer in enumerate(self.layers): 504 | x, hidden_state = layer(x, y[index]) 505 | x = self.downsample[index](x) 506 | hidden_states_down.append(hidden_state) 507 | 508 | return hidden_states_down, x 509 | 510 | 511 | class UpSample(nn.Module): 512 | def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_upsample, num_heads, window_size, mlp_ratio=4., 513 | qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 514 | norm_layer=nn.LayerNorm, flag=0): 515 | super(UpSample, self).__init__() 516 | 517 | self.img_size = img_size 518 | self.num_layers = len(depths_upsample) 519 | self.embed_dim = embed_dim 520 | self.mlp_ratio = mlp_ratio 521 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 522 | patches_resolution = self.patch_embed.patches_resolution 523 | self.Unembed = PatchInflated(in_chans=in_chans, embed_dim=embed_dim, input_resolution=patches_resolution) 524 | 525 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_upsample))] 526 | 527 | self.layers = nn.ModuleList() 528 | self.upsample = nn.ModuleList() 529 | 530 | for i_layer in range(self.num_layers): 531 | resolution1 = (patches_resolution[0] // (2 ** (self.num_layers - i_layer))) 532 | resolution2 = (patches_resolution[1] // (2 ** (self.num_layers - i_layer))) 533 | 534 | dimension = int(embed_dim * 2 ** (self.num_layers - i_layer)) 535 | upsample = PatchExpanding(input_resolution=(resolution1, resolution2), dim=dimension) 536 | 537 | layer = SwinLSTMCell(dim=dimension, input_resolution=(resolution1, resolution2), 538 | depth=depths_upsample[(self.num_layers - 1 - i_layer)], 539 | num_heads=num_heads[(self.num_layers - 1 - i_layer)], 540 | window_size=window_size, 541 | mlp_ratio=self.mlp_ratio, 542 | qkv_bias=qkv_bias, qk_scale=qk_scale, 543 | drop=drop_rate, attn_drop=attn_drop_rate, 544 | drop_path=dpr[sum(depths_upsample[:(self.num_layers - 1 - i_layer)]): 545 | sum(depths_upsample[:(self.num_layers - 1 - i_layer) + 1])], 546 | norm_layer=norm_layer, flag=flag) 547 | 548 | self.layers.append(layer) 549 | self.upsample.append(upsample) 550 | 551 | def forward(self, x, y): 552 | hidden_states_up = [] 553 | 554 | for index, layer in enumerate(self.layers): 555 | x, hidden_state = layer(x, y[index]) 556 | x = self.upsample[index](x) 557 | hidden_states_up.append(hidden_state) 558 | 559 | x = torch.sigmoid(self.Unembed(x)) 560 | 561 | return hidden_states_up, x 562 | 563 | 564 | class SwinLSTM(nn.Module): 565 | def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_downsample, depths_upsample, num_heads, 566 | window_size): 567 | super(SwinLSTM, self).__init__() 568 | 569 | self.Downsample = DownSample(img_size=img_size, patch_size=patch_size, in_chans=in_chans, 570 | embed_dim=embed_dim, depths_downsample=depths_downsample, 571 | num_heads=num_heads, window_size=window_size) 572 | 573 | self.Upsample 
= UpSample(img_size=img_size, patch_size=patch_size, in_chans=in_chans, 574 | embed_dim=embed_dim, depths_upsample=depths_upsample, 575 | num_heads=num_heads, window_size=window_size) 576 | 577 | def forward(self, input, states_down, states_up): 578 | states_down, x = self.Downsample(input, states_down) 579 | states_up, output = self.Upsample(x, states_up) 580 | 581 | return output, states_down, states_up 582 | -------------------------------------------------------------------------------- /__pycache__/SwinLSTM_D.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/__pycache__/SwinLSTM_D.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/configs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/__pycache__/configs.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/__pycache__/dataset.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/functions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/__pycache__/functions.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/architecture.png -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser('SwinLSTM training and evaluation script', add_help=False) 6 | 7 | # Setup parameters 8 | parser.add_argument('--device', default='cuda:0', type=str) 9 | parser.add_argument('--res_dir', default='./results', type=str) 10 | parser.add_argument('--seed', default=1234, type=int) 11 | 12 | # Moving_MNIST dataset parameters 13 | parser.add_argument('--num_frames_input', default=10, type=int, help='Input sequence length') 14 | parser.add_argument('--num_frames_output', default=10, type=int, help='Output sequence length') 15 | parser.add_argument('--image_size', default=(28, 28), type=int, help='Original resolution') 16 | parser.add_argument('--input_size', default=(64, 64), help='Input resolution') 17 | parser.add_argument('--step_length', default=0.1, type=float) 18 | parser.add_argument('--num_objects', default=[2], type=int) 19 | parser.add_argument('--train_samples', default=[0, 10000], type=int, help='Number of samples in 
training set') 20 | parser.add_argument('--valid_samples', default=[10000, 13000], type=int, help='Number of samples in validation set') 21 | parser.add_argument('--train_data_dir', default='./data/train-images-idx3-ubyte.gz') 22 | parser.add_argument('--test_data_dir', default='./data/mnist_test_seq.npy') 23 | 24 | # model parameters 25 | parser.add_argument('--model', default='SwinLSTM-D', type=str, choices=['SwinLSTM-B', 'SwinLSTM-D'], 26 | help='Model type') 27 | parser.add_argument('--input_channels', default=1, type=int, help='Number of input image channels') 28 | parser.add_argument('--input_img_size', default=64, type=int, help='Input image size') 29 | parser.add_argument('--patch_size', default=2, type=int, help='Patch size of input images') 30 | parser.add_argument('--embed_dim', default=128, type=int, help='Patch embedding dimension') 31 | parser.add_argument('--depths', default=[12], type=int, help='Depth of Swin Transformer layer for SwinLSTM-B') 32 | parser.add_argument('--depths_down', default=[2, 6], type=int, help='Downsample of SwinLSTM-D') 33 | parser.add_argument('--depths_up', default=[6, 2], type=int, help='Upsample of SwinLSTM-D') 34 | parser.add_argument('--heads_number', default=[4, 8], type=int, 35 | help='Number of attention heads in different layers') 36 | parser.add_argument('--window_size', default=4, type=int, help='Window size of Swin Transformer layer') 37 | parser.add_argument('--drop_rate', default=0., type=float, help='Dropout rate') 38 | parser.add_argument('--attn_drop_rate', default=0., type=float, help='Attention dropout rate') 39 | parser.add_argument('--drop_path_rate', default=0.1, type=float, help='Stochastic depth rate') 40 | 41 | # Training parameters 42 | parser.add_argument('--train_batch_size', default=16, type=int, help='Batch size for training') 43 | parser.add_argument('--valid_batch_size', default=16, type=int, help='Batch size for validation') 44 | parser.add_argument('--test_batch_size', default=16, type=int, help='Batch size for testing') 45 | parser.add_argument('--num_workers', default=8, type=int) 46 | parser.add_argument('--epochs', default=2000, type=int) 47 | parser.add_argument('--epoch_valid', default=10, type=int) 48 | parser.add_argument('--log_train', default=100, type=int) 49 | parser.add_argument('--log_valid', default=60, type=int) 50 | parser.add_argument('--lr', default=0.0001, type=float, help='Learning rate') 51 | 52 | args = parser.parse_args() 53 | 54 | return args 55 | -------------------------------------------------------------------------------- /data/test_data_download_link.md: -------------------------------------------------------------------------------- 1 | http://www.cs.toronto.edu/~nitish/unsupervised_video/ 2 | -------------------------------------------------------------------------------- /data/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongTang-x/SwinLSTM/4425bcfefbebfac85c9fc6c6659361b8154d5ca3/data/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class Moving_MNIST(Dataset): 10 | 11 | def __init__(self, args, split): 12 | 13 | super(Moving_MNIST, self).__init__() 14 | 15 | with gzip.open(args.train_data_dir, 
'rb') as f: 16 | self.datas = np.frombuffer(f.read(), np.uint8, offset=16) 17 | self.datas = self.datas.reshape(-1, *args.image_size) 18 | 19 | if split == 'train': 20 | self.datas = self.datas[args.train_samples[0]: args.train_samples[1]] 21 | else: 22 | self.datas = self.datas[args.valid_samples[0]: args.valid_samples[1]] 23 | 24 | self.image_size = args.image_size 25 | self.input_size = args.input_size 26 | self.step_length = args.step_length 27 | self.num_objects = args.num_objects 28 | 29 | self.num_frames_input = args.num_frames_input 30 | self.num_frames_output = args.num_frames_output 31 | self.num_frames_total = args.num_frames_input + args.num_frames_output 32 | 33 | print('Loaded {} {} samples'.format(self.__len__(), split)) 34 | 35 | def _get_random_trajectory(self, seq_length): 36 | 37 | assert self.input_size[0] == self.input_size[1] 38 | assert self.image_size[0] == self.image_size[1] 39 | 40 | canvas_size = self.input_size[0] - self.image_size[0] 41 | 42 | x = random.random() 43 | y = random.random() 44 | 45 | theta = random.random() * 2 * np.pi 46 | 47 | v_y = np.sin(theta) 48 | v_x = np.cos(theta) 49 | 50 | start_y = np.zeros(seq_length) 51 | start_x = np.zeros(seq_length) 52 | 53 | for i in range(seq_length): 54 | 55 | y += v_y * self.step_length 56 | x += v_x * self.step_length 57 | 58 | if x <= 0.: x = 0.; v_x = -v_x; 59 | if x >= 1.: x = 1.; v_x = -v_x 60 | if y <= 0.: y = 0.; v_y = -v_y; 61 | if y >= 1.: y = 1.; v_y = -v_y 62 | 63 | start_y[i] = y 64 | start_x[i] = x 65 | 66 | start_y = (canvas_size * start_y).astype(np.int32) 67 | start_x = (canvas_size * start_x).astype(np.int32) 68 | 69 | return start_y, start_x 70 | 71 | def _generate_moving_mnist(self, num_digits=2): 72 | 73 | data = np.zeros((self.num_frames_total, *self.input_size), dtype=np.float32) 74 | 75 | for n in range(num_digits): 76 | 77 | start_y, start_x = self._get_random_trajectory(self.num_frames_total) 78 | ind = np.random.randint(0, self.__len__()) 79 | digit_image = self.datas[ind] 80 | 81 | for i in range(self.num_frames_total): 82 | top = start_y[i] 83 | left = start_x[i] 84 | bottom = top + self.image_size[0] 85 | right = left + self.image_size[1] 86 | data[i, top:bottom, left:right] = np.maximum(data[i, top:bottom, left:right], digit_image) 87 | 88 | data = data[..., np.newaxis] 89 | 90 | return data 91 | 92 | def __getitem__(self, item): 93 | 94 | num_digits = random.choice(self.num_objects) 95 | images = self._generate_moving_mnist(num_digits) 96 | 97 | inputs = torch.from_numpy(images[:self.num_frames_input]).permute(0, 3, 1, 2).contiguous() 98 | targets = torch.from_numpy(images[self.num_frames_output:]).permute(0, 3, 1, 2).contiguous() 99 | 100 | return inputs / 255., targets / 255. 101 | 102 | def __len__(self): 103 | return self.datas.shape[0] 104 | 105 | 106 | class Moving_MNIST_Test(Dataset): 107 | def __init__(self, args): 108 | super(Moving_MNIST_Test, self).__init__() 109 | 110 | self.num_frames_input = args.num_frames_input 111 | self.num_frames_output = args.num_frames_output 112 | self.num_frames_total = args.num_frames_input + args.num_frames_output 113 | 114 | self.dataset = np.load(args.test_data_dir) 115 | self.dataset = self.dataset[..., np.newaxis] 116 | 117 | print('Loaded {} {} samples'.format(self.__len__(), 'test')) 118 | 119 | def __getitem__(self, index): 120 | images = self.dataset[:, index, ...] 
121 | 122 | inputs = torch.from_numpy(images[:self.num_frames_input]).permute(0, 3, 1, 2).contiguous() 123 | targets = torch.from_numpy(images[self.num_frames_output:]).permute(0, 3, 1, 2).contiguous() 124 | 125 | return inputs / 255., targets / 255. 126 | 127 | def __len__(self): 128 | return len(self.dataset[1]) 129 | 130 | -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.cuda import amp 4 | from torch.cuda.amp import autocast as autocast 5 | from utils import compute_metrics, visualize 6 | 7 | scaler = amp.GradScaler() 8 | 9 | 10 | def model_forward_single_layer(model, inputs, targets_len, num_layers): 11 | outputs = [] 12 | states = [None] * len(num_layers) 13 | 14 | inputs_len = inputs.shape[1] 15 | 16 | last_input = inputs[:, -1] 17 | 18 | for i in range(inputs_len - 1): 19 | output, states = model(inputs[:, i], states) 20 | outputs.append(output) 21 | 22 | for i in range(targets_len): 23 | output, states = model(last_input, states) 24 | outputs.append(output) 25 | last_input = output 26 | 27 | return outputs 28 | 29 | 30 | def model_forward_multi_layer(model, inputs, targets_len, num_layers): 31 | states_down = [None] * len(num_layers) 32 | states_up = [None] * len(num_layers) 33 | 34 | outputs = [] 35 | 36 | inputs_len = inputs.shape[1] 37 | 38 | last_input = inputs[:, -1] 39 | 40 | for i in range(inputs_len - 1): 41 | output, states_down, states_up = model(inputs[:, i], states_down, states_up) 42 | outputs.append(output) 43 | 44 | for i in range(targets_len): 45 | output, states_down, states_up = model(last_input, states_down, states_up) 46 | outputs.append(output) 47 | last_input = output 48 | 49 | return outputs 50 | 51 | 52 | def train(args, logger, epoch, model, train_loader, criterion, optimizer): 53 | model.train() 54 | num_batches = len(train_loader) 55 | losses = [] 56 | 57 | for batch_idx, (inputs, targets) in enumerate(train_loader): 58 | 59 | optimizer.zero_grad() 60 | 61 | inputs, targets = map(lambda x: x.float().to(args.device), [inputs, targets]) 62 | targets_len = targets.shape[1] 63 | with autocast(): 64 | if args.model == 'SwinLSTM-B': 65 | outputs = model_forward_single_layer(model, inputs, targets_len, args.depths) 66 | 67 | if args.model == 'SwinLSTM-D': 68 | outputs = model_forward_multi_layer(model, inputs, targets_len, args.depths_down) 69 | 70 | outputs = torch.stack(outputs).permute(1, 0, 2, 3, 4).contiguous() 71 | targets_ = torch.cat((inputs[:, 1:], targets), dim=1) 72 | loss = criterion(outputs, targets_) 73 | 74 | scaler.scale(loss).backward() 75 | scaler.step(optimizer) 76 | scaler.update() 77 | 78 | losses.append(loss.item()) 79 | 80 | if batch_idx and batch_idx % args.log_train == 0: 81 | logger.info(f'EP:{epoch:04d} BI:{batch_idx:03d}/{num_batches:03d} Loss:{np.mean(losses):.6f}') 82 | 83 | return np.mean(losses) 84 | 85 | 86 | def test(args, logger, epoch, model, test_loader, criterion, cache_dir): 87 | model.eval() 88 | num_batches = len(test_loader) 89 | losses, mses, ssims = [], [], [] 90 | 91 | for batch_idx, (inputs, targets) in enumerate(test_loader): 92 | 93 | with torch.no_grad(): 94 | inputs, targets = map(lambda x: x.float().to(args.device), [inputs, targets]) 95 | targets_len = targets.shape[1] 96 | 97 | if args.model == 'SwinLSTM-B': 98 | outputs = model_forward_single_layer(model, inputs, targets_len, args.depths) 99 | 100 | if args.model == 
'SwinLSTM-D': 101 | outputs = model_forward_multi_layer(model, inputs, targets_len, args.depths_down) 102 | 103 | outputs = torch.stack(outputs).permute(1, 0, 2, 3, 4).contiguous() 104 | targets_ = torch.cat((inputs[:, 1:], targets), dim=1) 105 | 106 | losses.append(criterion(outputs, targets_).item()) 107 | 108 | inputs_len = inputs.shape[1] 109 | outputs = outputs[:, inputs_len - 1:] 110 | 111 | mse, ssim = compute_metrics(outputs, targets) 112 | 113 | mses.append(mse) 114 | ssims.append(ssim) 115 | 116 | if batch_idx and batch_idx % args.log_valid == 0: 117 | logger.info( 118 | f'EP:{epoch:04d} BI:{batch_idx:03d}/{num_batches:03d} Loss:{np.mean(losses):.6f} MSE:{mse:.4f} SSIM:{ssim:.4f}') 119 | visualize(inputs, targets, outputs, epoch, batch_idx, cache_dir) 120 | 121 | return np.mean(losses), np.mean(mses), np.mean(ssims) 122 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0 2 | torchvision==0.12.0 3 | numpy 4 | matplotlib 5 | scikit-image==0.19.2 6 | timm==0.4.12 7 | einops==0.4.1 8 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torch import nn 4 | from torch.utils.data import DataLoader 5 | from SwinLSTM_D import SwinLSTM 6 | from configs import get_args 7 | from dataset import Moving_MNIST_Test 8 | from functions import test 9 | from utils import set_seed, make_dir, init_logger 10 | 11 | if __name__ == '__main__': 12 | args = get_args() 13 | set_seed(args.seed) 14 | cache_dir, model_dir, log_dir = make_dir(args) 15 | logger = init_logger(log_dir) 16 | 17 | model = SwinLSTM(img_size=args.input_img_size, patch_size=args.patch_size, 18 | in_chans=args.input_channels, embed_dim=args.embed_dim, 19 | depths_downsample=args.depths_down, depths_upsample=args.depths_up, 20 | num_heads=args.heads_number, window_size=args.window_size).to(args.device) 21 | 22 | criterion = nn.MSELoss() 23 | 24 | test_dataset = Moving_MNIST_Test(args) 25 | test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, 26 | num_workers=args.num_workers, shuffle=False, pin_memory=True, drop_last=True) 27 | 28 | model.load_state_dict(torch.load('./Pretrained/trained_model_state_dict')) 29 | 30 | start_time = time.time() 31 | 32 | _, mse, ssim = test(args, logger, 0, model, test_loader, criterion, cache_dir) 33 | 34 | print(f'[Metrics] MSE:{mse:.4f} SSIM:{ssim:.4f}') 35 | print(f'Time usage per epoch: {time.time() - start_time:.0f}s') 36 | 37 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import torch.nn as nn 3 | from configs import get_args 4 | from functions import train, test 5 | from torch.utils.data import DataLoader 6 | from dataset import Moving_MNIST 7 | 8 | 9 | def setup(args): 10 | if args.model == 'SwinLSTM-B': 11 | from SwinLSTM_B import SwinLSTM 12 | model = SwinLSTM(img_size=args.input_img_size, patch_size=args.patch_size, 13 | in_chans=args.input_channels, embed_dim=args.embed_dim, 14 | depths=args.depths, num_heads=args.heads_number, 15 | window_size=args.window_size, drop_rate=args.drop_rate, 16 | attn_drop_rate=args.attn_drop_rate, drop_path_rate=args.drop_path_rate).to(args.device) 17 | 18 | if args.model == 'SwinLSTM-D': 19
| from SwinLSTM_D import SwinLSTM 20 | model = SwinLSTM(img_size=args.input_img_size, patch_size=args.patch_size, 21 | in_chans=args.input_channels, embed_dim=args.embed_dim, 22 | depths_downsample=args.depths_down, depths_upsample=args.depths_up, 23 | num_heads=args.heads_number, window_size=args.window_size).to(args.device) 24 | 25 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 26 | 27 | criterion = nn.MSELoss() 28 | 29 | train_dataset = Moving_MNIST(args, split='train') 30 | valid_dataset = Moving_MNIST(args, split='valid') 31 | 32 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, 33 | num_workers=args.num_workers, shuffle=True, pin_memory=True, drop_last=True) 34 | 35 | valid_loader = DataLoader(valid_dataset, batch_size=args.valid_batch_size, 36 | num_workers=args.num_workers, shuffle=False, pin_memory=True, drop_last=True) 37 | 38 | return model, criterion, optimizer, train_loader, valid_loader 39 | 40 | def main(): 41 | args = get_args() 42 | set_seed(args.seed) 43 | cache_dir, model_dir, log_dir = make_dir(args) 44 | logger = init_logger(log_dir) 45 | 46 | model, criterion, optimizer, train_loader, valid_loader = setup(args) 47 | 48 | train_losses, valid_losses = [], [] 49 | 50 | best_metric = (0, float('inf'), float('inf')) 51 | 52 | for epoch in range(args.epochs): 53 | 54 | start_time = time.time() 55 | train_loss = train(args, logger, epoch, model, train_loader, criterion, optimizer) 56 | train_losses.append(train_loss) 57 | plot_loss(train_losses, 'train', epoch, args.res_dir, 1) 58 | 59 | if (epoch + 1) % args.epoch_valid == 0: 60 | 61 | valid_loss, mse, ssim = test(args, logger, epoch, model, valid_loader, criterion, cache_dir) 62 | 63 | valid_losses.append(valid_loss) 64 | 65 | plot_loss(valid_losses, 'valid', epoch, args.res_dir, args.epoch_valid) 66 | 67 | if mse < best_metric[1]: 68 | torch.save(model.state_dict(), f'{model_dir}/trained_model_state_dict') 69 | best_metric = (epoch, mse, ssim) 70 | 71 | logger.info(f'[Current Best] EP:{best_metric[0]:04d} MSE:{best_metric[1]:.4f} SSIM:{best_metric[2]:.4f}') 72 | 73 | print(f'Time usage per epoch: {time.time() - start_time:.0f}s') 74 | 75 | if __name__ == '__main__': 76 | main() 77 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import random 5 | import logging 6 | import matplotlib 7 | import numpy as np 8 | from matplotlib import pyplot as plt 9 | from skimage.metrics import structural_similarity 10 | 11 | matplotlib.use('agg') 12 | 13 | def set_seed(seed): 14 | torch.manual_seed(seed) 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.backends.cudnn.deterministic = True 18 | 19 | def visualize(inputs, targets, outputs, epoch, idx, cache_dir): 20 | _, axarray = plt.subplots(3, targets.shape[1], figsize=(targets.shape[1] * 5, 10)) 21 | 22 | for t in range(targets.shape[1]): 23 | axarray[0][t].imshow(inputs[0, t, 0].detach().cpu().numpy(), cmap='gray') 24 | axarray[1][t].imshow(targets[0, t, 0].detach().cpu().numpy(), cmap='gray') 25 | axarray[2][t].imshow(outputs[0, t, 0].detach().cpu().numpy(), cmap='gray') 26 | 27 | plt.savefig(os.path.join(cache_dir, '{:03d}-{:03d}.png'.format(epoch, idx))) 28 | plt.close() 29 | 30 | def plot_loss(loss_records, loss_type, epoch, plot_dir, step): 31 | plt.plot(range((epoch + 1) // step), loss_records, label=loss_type) 32 | plt.legend() 33 | 
plt.savefig(os.path.join(plot_dir, '{}_loss_records.png'.format(loss_type))) 34 | plt.close() 35 | 36 | def MAE(pred, true): 37 | return np.mean(np.abs(pred - true), axis=(0, 1)).sum() 38 | 39 | def MSE(pred, true): 40 | return np.mean((pred - true) ** 2, axis=(0, 1)).sum() 41 | 42 | # cite the 'PSNR' code from E3D-LSTM, Thanks! 43 | # https://github.com/google/e3d_lstm/blob/master/src/trainer.py line 39-40 44 | def PSNR(pred, true): 45 | mse = np.mean((np.uint8(pred * 255) - np.uint8(true * 255)) ** 2) 46 | return 20 * np.log10(255) - 10 * np.log10(mse) 47 | 48 | def compute_metrics(predictions, targets): 49 | targets = targets.permute(0, 1, 3, 4, 2).detach().cpu().numpy() 50 | predictions = predictions.permute(0, 1, 3, 4, 2).detach().cpu().numpy() 51 | 52 | batch_size = predictions.shape[0] 53 | Seq_len = predictions.shape[1] 54 | 55 | ssim = 0 56 | 57 | for batch in range(batch_size): 58 | for frame in range(Seq_len): 59 | ssim += structural_similarity(targets[batch, frame].squeeze(), 60 | predictions[batch, frame].squeeze()) 61 | 62 | ssim /= (batch_size * Seq_len) 63 | 64 | mse = MSE(predictions, targets) 65 | 66 | return mse, ssim 67 | 68 | def check_dir(path): 69 | if not os.path.exists(path): 70 | os.makedirs(path) 71 | 72 | def make_dir(args): 73 | 74 | cache_dir = os.path.join(args.res_dir, 'cache') 75 | check_dir(cache_dir) 76 | 77 | model_dir = os.path.join(args.res_dir, 'model') 78 | check_dir(model_dir) 79 | 80 | log_dir = os.path.join(args.res_dir, 'log') 81 | check_dir(log_dir) 82 | 83 | return cache_dir, model_dir, log_dir 84 | 85 | def init_logger(log_dir): 86 | logging.basicConfig(level=logging.INFO, 87 | format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', 88 | datefmt='%m-%d %H:%M', 89 | filename=os.path.join(log_dir, time.strftime("%Y_%m_%d") + '.log'), 90 | filemode='w') 91 | 92 | console = logging.StreamHandler() 93 | console.setLevel(logging.INFO) 94 | formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') 95 | console.setFormatter(formatter) 96 | logging.getLogger('').addHandler(console) 97 | 98 | return logging 99 | --------------------------------------------------------------------------------
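The files above are self-contained: SwinLSTM_D.py defines the SwinLSTM model with down-/up-sampling branches, configs.py supplies the default hyper-parameters, and functions.py implements the autoregressive rollout used for both training and evaluation. The snippet below is a minimal smoke-test sketch, not a file from the repository: it assumes the repository root is on the Python path and reuses the defaults from configs.py unchanged; the batch size of 2, the random input clip, and the CPU fallback are illustrative choices only.

import torch

from SwinLSTM_D import SwinLSTM
from functions import model_forward_multi_layer

# Fall back to CPU when CUDA is unavailable (configs.py defaults to 'cuda:0').
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Default hyper-parameters from configs.py: 64x64 single-channel frames,
# patch size 2, embed_dim 128, depths [2, 6] / [6, 2], heads [4, 8], window 4.
model = SwinLSTM(img_size=64, patch_size=2, in_chans=1, embed_dim=128,
                 depths_downsample=[2, 6], depths_upsample=[6, 2],
                 num_heads=[4, 8], window_size=4).to(device)

# Random stand-in for a Moving MNIST clip: (batch, frames, channels, H, W).
inputs = torch.rand(2, 10, 1, 64, 64, device=device)

with torch.no_grad():
    # num_layers is only used for its length (one hidden state per stage),
    # mirroring how train()/test() pass args.depths_down.
    outputs = model_forward_multi_layer(model, inputs, targets_len=10, num_layers=[2, 6])

# 9 warm-up steps on the observed frames plus 10 autoregressive predictions,
# matching the 19-frame target tensor assembled in functions.train()/test().
predictions = torch.stack(outputs).permute(1, 0, 2, 3, 4)
print(predictions.shape)  # expected: torch.Size([2, 19, 1, 64, 64])

For real runs, train.py and test.py remain the entry points; both parse their settings through get_args() in configs.py, so values such as --epochs, --train_batch_size, --lr, or --test_data_dir can be overridden on the command line, and test.py additionally expects the Moving MNIST test sequences (linked in data/test_data_download_link.md) to be available at the path given by --test_data_dir.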