├── .gitignore ├── LICENSE ├── README.md ├── data ├── README.md ├── example_data.npy └── source_data.xlsx ├── data_ranges.npy ├── dataloader.py ├── desired_results.xlsx ├── infer.py ├── model.py ├── pics └── framework.png ├── pytorch_wavelets ├── __init__.py ├── _version.py ├── dtcwt │ ├── __init__.py │ ├── coeffs.py │ ├── data │ │ ├── __init__.py │ │ ├── antonini.npz │ │ ├── farras.npz │ │ ├── legall.npz │ │ ├── near_sym_a.npz │ │ ├── near_sym_a2.npz │ │ ├── near_sym_b.npz │ │ ├── near_sym_b_bp.npz │ │ ├── qshift_06.npz │ │ ├── qshift_32.npz │ │ ├── qshift_a.npz │ │ ├── qshift_b.npz │ │ ├── qshift_b_bp.npz │ │ ├── qshift_c.npz │ │ └── qshift_d.npz │ ├── lowlevel.py │ ├── lowlevel2.py │ ├── transform2d.py │ └── transform_funcs.py ├── dwt │ ├── __init__.py │ ├── lowlevel.py │ ├── swt_inverse.py │ ├── transform1d.py │ └── transform2d.py ├── scatternet │ ├── __init__.py │ ├── layers.py │ └── lowlevel.py └── utils.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | converted.pt 2 | model_mod1121.py 3 | prepareData.py 4 | transfer.py 5 | WTFTP_his_tgt_pre_score.pt 6 | .idea 7 | __pycache__ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flight trajectory prediction enabled by time-frequency wavelet transform 2 | 3 | # Introduction 4 | 5 | This repository provides source codes of the proposed flight trajectory prediction framework, called WTFTP, 6 | and example samples for the paper Flight trajectory 7 | prediction enabled by time-frequency wavelet transform. 8 | This is a novel wavelet-based flight trajectory prediction framework by modeling global flight trends and local motion 9 | details. 10 | 11 |

12 | 13 | ## Repository Structure 14 | ``` 15 | wtftp-model 16 | │ dataloader.py (Load trajectory data from ./data) 17 | │ data_ranges.npy (Provide data ranges for normalization and denormalization) 18 | │ desired_results.xlsx (Excepted prediction result of example samples) 19 | │ infer.py (Perform the prediction procedure) 20 | │ LICENSE (LICENSE file) 21 | │ model.py (The neural architecture corresponding to the WTFTP framework) 22 | │ README.md (The current README file) 23 | │ train.py (The training code) 24 | │ utils.py (Tools for the project) 25 | │ 26 | ├─data 27 | │ │ README.md (README file for the dataset.) 28 | │ │ example_data.npy (Example data file) 29 | │ │ source_data.xlsx (Source data of figures in the paper) 30 | │ ├─dev (Archive for the validation data) 31 | │ ├─test (Archive for the test data) 32 | │ └─train (Archive for the training data) 33 | └─pics 34 | ``` 35 | 36 | ## Package Requirements 37 | 38 | + Python == 3.7.1 39 | + torch == 1.4.0 + cu100 40 | + numpy == 1.18.5 41 | + tensorboard == 2.3.0 42 | + tensorboardX == 2.1 43 | + PyWavelets == 1.2.0 44 | + matplotlib == 3.2.1 45 | 46 | ## System Requirements 47 | + Ubuntu 16.04 operating system 48 | + Intel(R) Xeon(TM) E5-2690@2.90GHz 49 | + 128G of memory 50 | + 8TB of hard disks 51 | + 8 $\times$ NVIDIA(R) GeForce RTX(TM) 2080 Ti 11G GPUs. 52 | 53 | 54 | # Instructions 55 | ## Installation 56 | 57 | ### Clone this repository 58 | 59 | ``` 60 | git clone https://github.com/MusDev7/wtftp-model.git 61 | ``` 62 | 63 | ### Create proper software and hardware environment 64 | 65 | You are recommended to create a conda environment with the package requirements mentioned above, and conduct the 66 | training and test on the suggested system configurations. 67 | 68 | ## Training 69 | 70 | The training script is provided by `train.py` for the flight trajectory prediction. The arguments for the training 71 | process are defined bellow: 72 | 73 | `--minibatch_len`: Integer. The sliding-window length for constructing the samples. `default=10` 74 | 75 | `--interval`: Integer. The sampling period. `default=1` 76 | 77 | `--batch_size`: Integer. The number of samples in a single training batch. `default=2048` 78 | 79 | `--epoch`: Integer. The maximum epoch for training process. `default=150` 80 | 81 | `--lr`: Float. The learning rate of the Adam optimizer. `default=0.001` 82 | 83 | `--dpot`: Float. The dropout probability. `default=0.0` 84 | 85 | `--cpu`: Optional. Use the CPU for training process. 86 | 87 | `--nolongging`: Optional. The logs will not be recorded. 88 | 89 | `--logdir`: String. The path for logs. `default='./log'` 90 | 91 | `--datadir`: String. The path for dataset. `default='./data'` 92 | 93 | `--saving_model_num`: Integer. The number of models to be saved during the training process. `default=0` 94 | 95 | `--debug`: Optional. For debugging the scripts. 96 | 97 | `--bidirectional`: Optional. Use the bidirectional LSTM block. 98 | 99 | `--maxlevel`: Integer. The level of wavelet analysis. `default=1` 100 | 101 | `--wavelet`: String. The wavelet basis. `default=haar` 102 | 103 | `--wt_mode`: String. The signal extension mode for wavelet transform. `default=symmetric` 104 | 105 | `--w_lo`: Float. The weight for the low-frequency wavelet component in the loss function. `default=1.0` 106 | 107 | `--w_hi`: Float. The weight for the high-frequency wavelet components in the loss function. `default=1.0` 108 | 109 | `--enlayer`: Integer. The layer number of the LSTM block in the encoder. `default=4` 110 | 111 | `--delayer`: Integer. The layer number of the LSTM block in the decoder. `default=1` 112 | 113 | `--embding`: Integer. The dimension of trajectory embeddings, enhanced trajectory embeddings and contextual 114 | embeddings. `default=64` 115 | 116 | `--attn`: Optional. Use the wavelet attention module in the decoder. 117 | 118 | `--cuda`: Integer. The GPU index for training process. 119 | 120 | To train the WTFTP, use the following command. 121 | 122 | ``` 123 | python train.py --saving_model_num 10 --attn 124 | ``` 125 | 126 | To train the WTFTP without the wavelet attention module, use the following command. 127 | 128 | ``` 129 | python train.py --saving_model_num 10 130 | ``` 131 | 132 | To train the WTFTP of 2 or 3 level of wavelet analysis, use the following command. 133 | 134 | ``` 135 | python train.py --saving_model_num 10 --attn --maxlevel 2 136 | ``` 137 | ``` 138 | python train.py --saving_model_num 10 --attn --maxlevel 3 139 | ``` 140 | 141 | To train the WTFTP without the wavelet attention module of 2 or 3 level of wavelet analysis, use the following command. 142 | 143 | ``` 144 | python train.py --saving_model_num 10 --maxlevel 2 145 | ``` 146 | 147 | ``` 148 | python train.py --saving_model_num 10 --maxlevel 3 149 | ``` 150 | 151 | 152 | ## Test 153 | 154 | The test script is provided by `infer.py` for the evaluation. The arguments for the test process are defined bellow: 155 | 156 | `--minibatch_len`: Integer. The sliding-window length for constructing the samples. `default=10` 157 | 158 | `--interval`: Integer. The sampling period. `default=1` 159 | 160 | `--pre_len`: Integer. The prediction horizons for evaluation. `default=1` 161 | 162 | `--batch_size`: Integer. The number of samples in a single test batch. `default=2048` 163 | 164 | `--cpu`: Optional. Use the CPU for test process. 165 | 166 | `--logdir`: String. The path for logs. `default='./log'` 167 | 168 | `--datadir`: String. The path for dataset. `default='./data'` 169 | 170 | `--netdir`: String. The path for the model. 171 | 172 | To test the model, use the following command. 173 | 174 | ``` 175 | python infer.py --netdir ./xxx.pt 176 | ``` 177 | 178 | # Dataset 179 | 180 | In this repository, the example samples are provided for evaluation. They can be accessed in the `/data/example_data.npy`. 181 | The guidance about the example data can be found in `/data/README`. 182 | 183 | 184 | # Acknowledgement 185 | 186 | The PyTorch implementation of wavelet transform is utilized to support the procedure of the DWT and IDWT procedures 187 | in this work. Its repository can be accessed [here](https://github.com/fbcotter/pytorch_wavelets). 188 | Thank all contributors to this project. 189 | 190 | # Citation 191 | 192 | Zhang, Z., Guo, D., Zhou, S. et al. Flight trajectory prediction enabled by time-frequency wavelet transform. Nat Commun 14, 5258 (2023). https://doi.org/10.1038/s41467-023-40903-9 193 | 194 | # Contact 195 | 196 | Zheng Zhang (zhaeng@stu.scu.edu.cn, musevr.ae@gmail.com) -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## Note 2 | 3 | We are not authorized to publicly release the whole dataset used during the current study concerning 4 | safety-critical issues. Nonetheless, the processed example samples are 5 | available in `example_data.npy`. Source data for figures are also provided in `source_data.xlsx`. 6 | 7 | ## Guidance of example data 8 | 9 | The example data has been stored as a binary file in NumPy format. It can be accessed by using: 10 | 11 | ```python 12 | import numpy as np 13 | example_data = np.load("example_data.npy") 14 | ``` 15 | 16 | The first dimension of `example_data` is the sample number of 500. The second is the sliding-window size of 10 (the first 9 17 | lines represent the input trajectory sequence, whereas the final line serves as the target trajectory point). And the 18 | last dimension of 6 indicates the six attributes: longitude (degree), latitude (degree), altitude (10 meters), and 19 | velocities (kilometers per hour) along previous three position components. 20 | 21 | The predicted trajectory of example data can be found in `desired_results.xlsx`. 22 | -------------------------------------------------------------------------------- /data/example_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/data/example_data.npy -------------------------------------------------------------------------------- /data/source_data.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/data/source_data.xlsx -------------------------------------------------------------------------------- /data_ranges.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/data_ranges.npy -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import random 5 | # import coordinate_conversion as cc 6 | import numpy as np 7 | import torch 8 | import torch.utils.data as tu_data 9 | 10 | class DataGenerator: 11 | def __init__(self, data_path, minibatch_len, interval=1, use_preset_data_ranges=False, 12 | train=True, test=True, dev=True, train_shuffle=True, test_shuffle=False, dev_shuffle=True): 13 | assert os.path.exists(data_path) 14 | self.attr_names = ['lon', 'lat', 'alt', 'spdx', 'spdy', 'spdz'] 15 | self.data_path = data_path 16 | self.interval = interval 17 | self.minibatch_len = minibatch_len 18 | self.data_status = np.load('data_ranges.npy', allow_pickle=True).item() 19 | assert type(self.data_status) is dict 20 | self.preset_data_ranges = {"lon": {"max": 113.689, "min": 93.883}, "lat": {"max": 37.585, "min": 19.305}, 21 | "alt": {"max": 1500, "min": 0}, "spdx": {"max": 878, "min": -945}, 22 | "spdy": {"max": 925, "min": -963}, "spdz": {"max": 43, "min": -48}} 23 | self.use_preset_data_ranges = use_preset_data_ranges 24 | if train: 25 | self.train_set = mini_DataGenerator(self.readtxt(os.path.join(self.data_path, 'train'), shuffle=train_shuffle)) 26 | if dev: 27 | self.dev_set = mini_DataGenerator(self.readtxt(os.path.join(self.data_path, 'dev'), shuffle=dev_shuffle)) 28 | if test: 29 | self.test_set = mini_DataGenerator(self.readtxt(os.path.join(self.data_path, 'test'), shuffle=test_shuffle)) 30 | if use_preset_data_ranges: 31 | assert self.preset_data_ranges is not None 32 | print('data range:', self.data_status) 33 | 34 | def readtxt(self, data_path, shuffle=True): 35 | assert os.path.exists(data_path) 36 | data = [] 37 | for root, dirs, file_names in os.walk(data_path): 38 | for file_name in file_names: 39 | if not file_name.endswith('txt'): 40 | continue 41 | with open(os.path.join(root, file_name)) as file: 42 | lines = file.readlines() 43 | lines = lines[::self.interval] 44 | if len(lines) == self.minibatch_len: 45 | data.append(lines) 46 | elif len(lines) < self.minibatch_len: 47 | continue 48 | else: 49 | for i in range(len(lines)-self.minibatch_len+1): 50 | data.append(lines[i:i+self.minibatch_len]) 51 | print(f'{len(data)} items loaded from \'{data_path}\'') 52 | if shuffle: 53 | random.shuffle(data) 54 | return data 55 | 56 | def scale(self, inp, attr): 57 | assert type(attr) is str and attr in self.attr_names 58 | data_status = self.data_status if not self.use_preset_data_ranges else self.preset_data_ranges 59 | inp = (inp-data_status[attr]['min'])/(data_status[attr]['max']-data_status[attr]['min']) 60 | return inp 61 | 62 | def unscale(self, inp, attr): 63 | assert type(attr) is str and attr in self.attr_names 64 | data_status = self.data_status if not self.use_preset_data_ranges else self.preset_data_ranges 65 | inp = inp*(data_status[attr]['max']-data_status[attr]['min'])+data_status[attr]['min'] 66 | return inp 67 | 68 | def collate(self, inp): 69 | ''' 70 | :param inp: batch * n_sequence * n_attr 71 | :return: 72 | ''' 73 | oup = [] 74 | for minibatch in inp: 75 | tmp = [] 76 | for line in minibatch: 77 | items = line.strip().split("|") 78 | lon, lat, alt, spdx, spdy, spdz = float(items[4]), float(items[5]), int(float(items[6]) / 10), \ 79 | float(items[7]), float(items[8]), float(items[9]) 80 | tmp.append([lon, lat, alt, spdx, spdy, spdz]) 81 | minibatch = np.array(tmp) 82 | for i in range(minibatch.shape[-1]): 83 | minibatch[:, i] = self.scale(minibatch[:, i], self.attr_names[i]) 84 | oup.append(minibatch) 85 | return np.array(oup) 86 | 87 | 88 | class mini_DataGenerator(tu_data.Dataset): 89 | def __init__(self, data): 90 | self.data = data 91 | 92 | def __getitem__(self, item): 93 | return self.data[item] 94 | 95 | def __len__(self): 96 | return len(self.data) 97 | -------------------------------------------------------------------------------- /desired_results.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/desired_results.xlsx -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from dataloader import DataGenerator 10 | import logging 11 | import datetime 12 | import matplotlib.pyplot as plt 13 | from mpl_toolkits.mplot3d import Axes3D 14 | from pytorch_wavelets import DWT1DForward, DWT1DInverse 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--minibatch_len', default=10, type=int) 18 | parser.add_argument('--pre_len', default=1, type=int) 19 | parser.add_argument('--interval', default=1, type=int) 20 | parser.add_argument('--batch_size', default=2048, type=int) 21 | parser.add_argument('--cpu', action='store_true') 22 | parser.add_argument('--logdir', default='./log', type=str) 23 | parser.add_argument('--datadir', default='dataaaaaa', type=str) 24 | parser.add_argument('--netdir', default=None, type=str) 25 | 26 | 27 | class Test: 28 | def __init__(self, opt, net=None): 29 | self.opt = opt 30 | self.iscuda = torch.cuda.is_available() 31 | self.device = f'cuda:{torch.cuda.current_device()}' if self.iscuda and not opt.cpu else 'cpu' 32 | self.data_set = DataGenerator(data_path=self.opt.datadir, 33 | minibatch_len=opt.minibatch_len, interval=opt.interval, 34 | use_preset_data_ranges=False, train=False, dev=False, test_shuffle=True) 35 | self.net = net 36 | self.model_path = None 37 | self.MSE = torch.nn.MSELoss(reduction='mean') 38 | self.MAE = torch.nn.L1Loss(reduction='mean') 39 | if net is not None: 40 | assert next(self.net.parameters()).device == self.device 41 | 42 | def load_model(self, model_path): 43 | self.model_path = model_path 44 | self.net = torch.load(model_path, map_location=self.device) 45 | 46 | def test(self): 47 | print_str = f'model details:\n{self.net.args}' 48 | print(print_str) 49 | self.log_path = self.opt.logdir + f'/{datetime.datetime.now().strftime("%y-%m-%d")}' 50 | if not os.path.exists(self.log_path): 51 | os.makedirs(self.log_path) 52 | logging.basicConfig(filename=os.path.join(self.log_path, f'test_{self.net.args["train_opt"].comments}.log'), 53 | filemode='w', format='%(asctime)s %(message)s', level=logging.DEBUG) 54 | logging.debug(print_str) 55 | logging.debug(self.model_path) 56 | test_data = DataLoader(dataset=self.data_set.test_set, batch_size=self.opt.batch_size, shuffle=False, 57 | collate_fn=self.data_set.collate) 58 | idwt = DWT1DInverse(wave=self.net.args['train_opt'].wavelet, mode=self.net.args['train_opt'].wt_mode).to( 59 | self.device) 60 | self.net.eval() 61 | tgt_set = [] 62 | pre_set = [] 63 | 64 | with torch.no_grad(): 65 | his_batch_set = [] 66 | all_score_set = [] 67 | for i, batch in enumerate(test_data): 68 | batch = torch.FloatTensor(batch).to(self.device) 69 | n_batch, _, n_attr = batch.shape 70 | inp_batch = batch[:, :self.opt.minibatch_len-self.opt.pre_len, :] # shape: batch * his_len * n_attr 71 | his_batch_set.append(inp_batch) 72 | tgt_set.append(batch[:, -self.opt.pre_len:, :]) # shape: batch * pre_len * n_attr 73 | pre_batch_set = [] 74 | for j in range(self.opt.pre_len): 75 | if j > 0: 76 | new_batch = pre_batch[:, self.opt.minibatch_len-self.opt.pre_len, :].unsqueeze(1) 77 | inp_batch = torch.cat((inp_batch[:, 1:, :], new_batch), dim=1) # shape: batch * his_len * n_attr 78 | if self.net.__class__.__name__ == 'WTFTP': 79 | wt_pre_batch, score_set = self.net(inp_batch) 80 | else: 81 | wt_pre_batch = self.net(inp_batch) 82 | pre_batch = idwt((wt_pre_batch[-1].transpose(1, 2).contiguous(), 83 | [comp.transpose(1, 2).contiguous() for comp in 84 | wt_pre_batch[:-1]])).contiguous() 85 | pre_batch = pre_batch.transpose(1, 2) # shape: batch * n_sequence * n_attr 86 | pre_batch_set.append(pre_batch[:, self.opt.minibatch_len-self.opt.pre_len, :]) 87 | if self.net.__class__.__name__ == 'WTFTP' and j == 0: 88 | all_score_set.append(score_set) 89 | pre_batch_set = torch.stack(pre_batch_set, dim=1) # shape: batch * pre_len * n_attr 90 | pre_set.append(pre_batch_set) 91 | 92 | tgt_set = torch.cat(tgt_set, dim=0) 93 | pre_set = torch.cat(pre_set, dim=0) 94 | # try: 95 | # all_score_set = torch.cat(all_score_set, dim=0) 96 | # except: 97 | # all_score_set = ['no scores'] 98 | # his_batch_set = torch.cat(his_batch_set, dim=0) 99 | # torch.save([his_batch_set, tgt_set, pre_set, all_score_set], 100 | # f'{self.net.args["train_opt"].comments}_his_tgt_pre_score.pt', _use_new_zipfile_serialization=False) 101 | for i in range(self.opt.pre_len): 102 | avemse = float(self.MSE(tgt_set[:, :i + 1, :], pre_set[:, :i + 1, :]).cpu()) 103 | avemae = float(self.MAE(tgt_set[:, :i + 1, :], pre_set[:, :i + 1, :]).cpu()) 104 | rmse = {} 105 | mae = {} 106 | mre = {} 107 | for j, name in enumerate(self.data_set.attr_names): 108 | rmse[name] = float(self.MSE(self.data_set.unscale(tgt_set[:, :i + 1, j], name), 109 | self.data_set.unscale(pre_set[:, :i + 1, j], name)).sqrt().cpu()) 110 | mae[name] = float(self.MAE(self.data_set.unscale(tgt_set[:, :i + 1, j], name), 111 | self.data_set.unscale(pre_set[:, :i + 1, j], name)).cpu()) 112 | logit = self.data_set.unscale(tgt_set[:, :i + 1, j], name) != 0 113 | mre[name] = float(torch.mean(torch.abs(self.data_set.unscale(tgt_set[:, :i + 1, j], name)- 114 | self.data_set.unscale(pre_set[:, :i + 1, j], name))[logit]/ 115 | self.data_set.unscale(tgt_set[:, :i + 1, j], name)[logit]).cpu()) * 100 if name in 'lonlatalt' \ 116 | else "N/A" 117 | lon = self.data_set.unscale(pre_set[:, :i + 1, 0], 'lon').cpu().numpy() 118 | lat = self.data_set.unscale(pre_set[:, :i + 1, 1], 'lat').cpu().numpy() 119 | alt = self.data_set.unscale(pre_set[:, :i + 1, 2], 'alt').cpu().numpy() / 100 # km 120 | X, Y, Z = self.gc2ecef(lon, lat, alt) 121 | lon_t = self.data_set.unscale(tgt_set[:, :i + 1, 0], 'lon').cpu().numpy() 122 | lat_t = self.data_set.unscale(tgt_set[:, :i + 1, 1], 'lat').cpu().numpy() 123 | alt_t = self.data_set.unscale(tgt_set[:, :i + 1, 2], 'alt').cpu().numpy() / 100 # km 124 | X_t, Y_t, Z_t = self.gc2ecef(lon_t, lat_t, alt_t) 125 | MDE = np.mean(np.sqrt((X - X_t) ** 2 + (Y - Y_t) ** 2 + (Z - Z_t) ** 2)) 126 | print_str = f'\nStep {i + 1}: \naveMSE(scaled): {avemse:.8f}, in each attr(RMSE, unscaled): {rmse}\n' \ 127 | f'aveMAE(scaled): {avemae:.8f}, in each attr(MAE, unscaled): {mae}\n' \ 128 | f'In each attr(MRE, %): {mre}\n' \ 129 | f'MDE(unscaled): {MDE:.8f}\n' 130 | print(print_str) 131 | logging.debug(print_str) 132 | 133 | def gc2ecef(self, lon, lat, alt): 134 | a = 6378.137 # km 135 | b = 6356.752 136 | lat = np.radians(lat) 137 | lon = np.radians(lon) 138 | e_square = 1 - (b ** 2) / (a ** 2) 139 | N = a / np.sqrt(1 - e_square * (np.sin(lat) ** 2)) 140 | X = (N + alt) * np.cos(lat) * np.cos(lon) 141 | Y = (N + alt) * np.cos(lat) * np.sin(lon) 142 | Z = ((b ** 2) / (a ** 2) * N + alt) * np.sin(lat) 143 | return X, Y, Z 144 | 145 | def draw_demo(self, items=None, realtime=False): 146 | plt.rcParams['font.family'] = 'arial' 147 | plt.rcParams['font.size'] = 16 148 | total_num = len(self.data_set.test_set) 149 | test_data = DataLoader(dataset=self.data_set.test_set, batch_size=self.opt.batch_size, shuffle=False, 150 | collate_fn=self.data_set.collate) 151 | idwt = DWT1DInverse(wave=self.net.args['train_opt'].wavelet, mode=self.net.args['train_opt'].wt_mode).to( 152 | self.device) 153 | if items is None: 154 | items = [random.randint(0, total_num)] 155 | elif type(items) is int: 156 | items = [items % total_num] 157 | elif type(items) is list and len(items) > 0 and type(items[0]) is int: 158 | pass 159 | else: 160 | TypeError(type(items)) 161 | if realtime: 162 | items = [int(input("item: "))] 163 | while len(items) > 0 and items[0] > 0: 164 | item = items[0] 165 | del items[0] 166 | n_batch = item // self.opt.batch_size 167 | n_minibatch = item % self.opt.batch_size 168 | sel_batch = None 169 | for i, batch in enumerate(test_data): 170 | if i == n_batch: 171 | sel_batch = batch 172 | break 173 | traj = sel_batch[n_minibatch:n_minibatch + 1, ...] 174 | with torch.no_grad(): 175 | self.net.eval() 176 | self.net.to(self.device) 177 | inp_batch = torch.FloatTensor(traj[:, :self.opt.minibatch_len-self.opt.pre_len, :]).to( 178 | self.device) # shape: 1 * his_len * n_attr 179 | pre_batch_set = [] 180 | full_pre_set = [] 181 | for j in range(self.opt.pre_len): 182 | if j > 0: 183 | new_batch = pre_batch[:, self.opt.minibatch_len-self.opt.pre_len, :].unsqueeze(1) 184 | inp_batch = torch.cat((inp_batch[:, 1:, :], new_batch), dim=1) # shape: batch * his_len * n_attr 185 | if self.net.__class__.__name__ == 'WTFTP': 186 | wt_pre_batch, score_set = self.net(inp_batch) 187 | else: 188 | wt_pre_batch = self.net(inp_batch) 189 | if j == 0: 190 | first_wt_pre = wt_pre_batch 191 | pre_batch = idwt((wt_pre_batch[-1].transpose(1, 2).contiguous(), 192 | [comp.transpose(1, 2).contiguous() for comp in 193 | wt_pre_batch[:-1]])).contiguous() 194 | pre_batch = pre_batch.transpose(1, 2) # shape: 1 * n_sequence * n_attr 195 | pre_batch_set.append(pre_batch[:, self.opt.minibatch_len-self.opt.pre_len, :]) 196 | full_pre_set.append(pre_batch[:, :self.opt.minibatch_len-self.opt.pre_len + 1, :].clone()) 197 | pre_batch_set = torch.stack(pre_batch_set, dim=1) # shape: 1 * pre_len * n_attr 198 | 199 | lla_his = np.array(traj[0, :self.opt.minibatch_len-self.opt.pre_len, 0:3]) # shape: his_len * n_attr 200 | lla_trg = np.array(traj[0, -self.opt.pre_len:, 0:3]) # shape: pre_len * n_attr 201 | lla_pre = np.array(pre_batch_set[0, :, 0:3].cpu().numpy()) # shape: pre_len * n_attr 202 | for i, name in enumerate(self.data_set.attr_names): 203 | if i > 2: 204 | break 205 | lla_his[:, i] = self.data_set.unscale(lla_his[:, i], name) 206 | lla_trg[:, i] = self.data_set.unscale(lla_trg[:, i], name) 207 | lla_pre[:, i] = self.data_set.unscale(lla_pre[:, i], name) 208 | 209 | fig = plt.figure(figsize=(9,9)) 210 | elev_azim_set = [[90, 0], [0, 0], [0, 90], [None, None]] # represent top view, lateral view(lat), lateral view(lon) and default, respectively 211 | for i, elev_azim in enumerate(elev_azim_set): 212 | ax = fig.add_subplot(2, 2, i + 1, projection='3d') 213 | ax.view_init(elev=elev_azim[0], azim=elev_azim[1]) 214 | ax.plot3D(lla_his[:, 0], lla_his[:, 1], lla_his[:, 2], marker='o', markeredgecolor='dodgerblue', 215 | label='his') 216 | ax.plot3D(lla_trg[:, 0], lla_trg[:, 1], lla_trg[:, 2], marker='*', markeredgecolor='blueviolet', 217 | label='tgt') 218 | ax.plot3D(lla_pre[:, 0], lla_pre[:, 1], lla_pre[:, 2], marker='p', markeredgecolor='orangered', 219 | label='pre') 220 | ax.set_xlabel('lon') 221 | ax.set_ylabel('lat') 222 | ax.set_zlabel('alt') 223 | ax.set_zlim(min(lla_his[:, 2]) - 20, max(lla_his[:, 2]) + 20) 224 | plt.suptitle(f'item_{item}') 225 | ax.legend() 226 | plt.tight_layout() 227 | plt.show() 228 | if realtime: 229 | items.append(int(input("item: "))) 230 | 231 | if __name__ == '__main__': 232 | opt = parser.parse_args() 233 | test = Test(opt) 234 | test.load_model(opt.netdir) 235 | test.test() 236 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: zheng zhang 3 | Email: musevr.ae@gmail.com 4 | """ 5 | import torch 6 | from torch import nn 7 | 8 | class WTFTP_AttnRemoved(nn.Module): 9 | def __init__(self, n_inp, n_oup, n_embding=128, n_encoderLayers=4, n_decoderLayers=4, proj='linear', activation='relu', maxlevel=1, en_dropout=0, 10 | de_dropout=0, steps=None): 11 | super(WTFTP_AttnRemoved, self).__init__() 12 | self.args = locals() 13 | self.n_oup = n_oup 14 | self.maxlevel = maxlevel 15 | self.n_decoderLayers = n_decoderLayers 16 | self.steps = steps if steps is not None else [5,3,2] 17 | if activation == 'relu': 18 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding//2, bias=False, proj=proj), nn.ReLU(), Embed(n_embding//2, n_embding, bias=False, proj=proj), nn.ReLU()) 19 | elif activation == 'sigmoid': 20 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding//2, bias=False, proj=proj), nn.Sigmoid(), Embed(n_embding//2, n_embding, bias=False, proj=proj), nn.Sigmoid()) 21 | else: 22 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding//2, bias=False, proj=proj), nn.ReLU(), Embed(n_embding//2, n_embding, bias=False, proj=proj), nn.ReLU()) 23 | 24 | self.encoder = nn.LSTM(input_size=n_embding, 25 | hidden_size=n_embding, 26 | num_layers=n_encoderLayers, 27 | bidirectional=False, 28 | batch_first=True, 29 | dropout=en_dropout, 30 | bias=False) 31 | self.decoders = nn.ModuleList([nn.LSTM(input_size=n_embding, 32 | hidden_size=n_embding, 33 | num_layers=n_decoderLayers, 34 | bidirectional=False, 35 | batch_first=True, 36 | dropout=de_dropout, 37 | bias=False) for _ in range(1+maxlevel)]) 38 | self.LNs = nn.ModuleList([nn.LayerNorm(n_embding) for _ in range(1+maxlevel)]) 39 | self.oup_embdings = nn.ModuleList([Embed(n_embding, n_oup, bias=True, proj=proj) for _ in range(1+maxlevel)]) 40 | 41 | def forward(self, inp): 42 | """ 43 | :param inp: shape: batch * n_sequence * n_attr 44 | :return: shape: batch * n_desiredLength * level * n_attr 45 | """ 46 | inp = self.inp_embding(inp) 47 | _, (h_0, c_0) = self.encoder(inp) 48 | oup = [] 49 | n_batch = inp.shape[0] 50 | i = 0 51 | for decoder, LN, oup_embding in zip(self.decoders, self.LNs, self.oup_embdings): 52 | this_decoder_oup = [] 53 | decoder_first_inp = torch.zeros((n_batch, 1, inp.shape[-1]), device=inp.device) 54 | de_H = torch.zeros(self.n_decoderLayers, inp.shape[0], inp.shape[-1], dtype=torch.float, 55 | device=inp.device) 56 | de_C = torch.zeros(self.n_decoderLayers, inp.shape[0], inp.shape[-1], dtype=torch.float, 57 | device=inp.device) 58 | de_C[0, :, :] = c_0[-1, :, :].clone() 59 | decoder_hidden = (de_H, de_C) 60 | if i == self.maxlevel: 61 | this_step = self.steps[self.maxlevel-1] 62 | else: 63 | this_step = self.steps[i] 64 | for _ in range(this_step): 65 | decoder_first_inp, decoder_hidden = decoder(decoder_first_inp, decoder_hidden) 66 | de_oup = LN(decoder_first_inp) 67 | de_oup = oup_embding(de_oup) 68 | this_decoder_oup.append(de_oup) 69 | this_decoder_oup = torch.cat(this_decoder_oup, dim=1) # shape: batch * n_desiredLength * n_attr 70 | oup.append(this_decoder_oup) 71 | i += 1 72 | return oup 73 | 74 | 75 | class WTFTP(nn.Module): 76 | def __init__(self, n_inp, n_oup, his_step, n_embding=64, en_layers=4, de_layers=1, proj='linear', 77 | activation='relu', maxlevel=1, en_dropout=0., 78 | de_dropout=0., out_split=False, decoder_init_zero=False, bias=False, se_skip=True, 79 | attn_conv_params=None): 80 | """ 81 | :param n_inp: 82 | :param n_oup: 83 | :param n_embding: 84 | :param layers: 85 | :param maxlevel: 86 | :param en_dropout: 87 | :param de_dropout: 88 | :param out_split: 89 | :param bias: 90 | """ 91 | super(WTFTP, self).__init__() 92 | self.args = locals() 93 | self.n_embding = n_embding 94 | self.n_oup = n_oup 95 | self.maxlevel = maxlevel 96 | self.decoder_init_zero = decoder_init_zero 97 | self.de_layers = de_layers 98 | attn_conv_params = {0: {'stride': 2, 'kernel': 2, 'pad': 1}, 1: {'stride': 3, 'kernel': 3, 'pad': 0}, 99 | 2: {'stride': 5, 'kernel': 5, 'pad': 1}} if attn_conv_params == None else attn_conv_params 100 | if activation == 'relu': 101 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding // 2, bias=False, proj=proj), nn.ReLU(), 102 | Embed(n_embding // 2, n_embding, bias=False, proj=proj), nn.ReLU()) 103 | elif activation == 'sigmoid': 104 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding // 2, bias=False, proj=proj), nn.Sigmoid(), 105 | Embed(n_embding // 2, n_embding, bias=False, proj=proj), nn.Sigmoid()) 106 | else: 107 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding // 2, bias=False, proj=proj), nn.ReLU(), 108 | Embed(n_embding // 2, n_embding, bias=False, proj=proj), nn.ReLU()) 109 | self.encoder = nn.LSTM(input_size=n_embding, 110 | hidden_size=n_embding, 111 | num_layers=en_layers, 112 | bidirectional=False, 113 | batch_first=True, 114 | dropout=en_dropout, 115 | bias=False) 116 | self.Wt_attns = nn.ModuleList([WaveletAttention(hid_dim=n_embding, init_steps=his_step, skip=se_skip, 117 | kernel=attn_conv_params[i]['kernel'], 118 | stride=attn_conv_params[i]['stride'], 119 | padding=attn_conv_params[i]['pad']) 120 | for i in range(maxlevel)] + 121 | [WaveletAttention(hid_dim=n_embding, init_steps=his_step, skip=se_skip, 122 | kernel=attn_conv_params[maxlevel - 1]['kernel'], 123 | stride=attn_conv_params[maxlevel - 1]['stride'], 124 | padding=attn_conv_params[maxlevel - 1]['pad'])]) 125 | self.decoders = nn.ModuleList([nn.LSTM(input_size=n_embding, 126 | hidden_size=n_embding, 127 | num_layers=de_layers, 128 | bidirectional=False, 129 | batch_first=True, 130 | dropout=de_dropout, 131 | bias=False) for _ in range(1 + maxlevel)]) 132 | self.LNs = nn.ModuleList([nn.LayerNorm(n_embding) for _ in range(1 + maxlevel)]) 133 | if out_split: 134 | assert n_embding % n_oup == 0 135 | self.out_split = out_split 136 | self.oup_embdings = nn.ModuleList( 137 | [nn.Linear(n_embding // n_oup, 1, bias=bias) for _ in range(n_oup * 2 ** maxlevel)]) 138 | else: 139 | self.oup_embdings = nn.ModuleList( 140 | [Embed(n_embding, n_oup, bias=True, proj=proj) for _ in range(1 + maxlevel)]) 141 | 142 | def forward(self, inp): 143 | """ 144 | :param inp: shape: batch * n_sequence * n_attr 145 | :param steps: coeff length of wavelet 146 | :return: coef_set: shape: batch * steps * levels * n_oup, 147 | all_scores_set: shape: batch * steps * levels * n_sequence, n_sequence here is timeSteps of inp 148 | """ 149 | embdings = self.inp_embding(inp) # batch * n_sequence * n_embding 150 | en_oup, (h_en, c_en) = self.encoder(embdings) 151 | all_de_oup_set = [] 152 | all_scores_set = [] 153 | for i, (attn, decoder, LN) in enumerate(zip(self.Wt_attns, self.decoders, self.LNs)): 154 | if self.decoder_init_zero: 155 | de_HC = None 156 | else: 157 | de_H = torch.zeros(self.de_layers, embdings.shape[0], embdings.shape[-1], dtype=torch.float, 158 | device=embdings.device) 159 | de_C = torch.zeros(self.de_layers, embdings.shape[0], embdings.shape[-1], dtype=torch.float, 160 | device=embdings.device) 161 | de_C[0, :, :] = c_en[-1, :, :].clone() 162 | de_HC = (de_H, de_C) 163 | de_inp, weight = attn(en_oup) 164 | de_oup_set, _ = decoder(de_inp, de_HC) # shape: batch * steps * n_embding 165 | all_de_oup_set.append(LN(de_oup_set)) 166 | all_scores_set.append(weight) 167 | all_scores_set = torch.cat(all_scores_set, dim=1) # shape: batch * steps * levels * n_sequence 168 | if hasattr(self, 'out_split'): 169 | all_de_oup_set = torch.cat(all_de_oup_set, dim=2) # shape: batch * steps * (levels*n_embding) 170 | split_all_de_oup_set = torch.split(all_de_oup_set, split_size_or_sections=self.n_embding // self.n_oup, 171 | dim=-1) 172 | coef_set = [] 173 | for i, linear in enumerate(self.oup_embdings): 174 | coef_set.append( 175 | linear(split_all_de_oup_set[i]) 176 | ) # shape: batch * steps * 1 177 | coef_set = torch.cat(coef_set, dim=-1).reshape(inp.shape[0], -1, 2 ** self.maxlevel, 178 | self.n_oup) # shape: batch * steps * levels * n_oup 179 | else: 180 | coef_set = [] 181 | for i, linear in enumerate(self.oup_embdings): 182 | coef_set.append( 183 | linear(all_de_oup_set[i]) 184 | ) # shape: batch * steps * n_oup 185 | return coef_set, all_scores_set 186 | 187 | class Embed(nn.Module): 188 | def __init__(self, in_features, out_features, bias=True, proj='linear'): 189 | super(Embed, self).__init__() 190 | self.proj = proj 191 | if proj == 'linear': 192 | self.embed = nn.Linear(in_features, out_features, bias) 193 | else: 194 | self.embed = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=3, stride=1, 195 | padding=1, 196 | padding_mode='replicate', bias=bias) 197 | 198 | def forward(self, inp): 199 | """ 200 | inp: B * T * D 201 | """ 202 | if self.proj == 'linear': 203 | inp = self.embed(inp) 204 | else: 205 | inp = self.embed(inp.transpose(1, 2)).transpose(1, 2) 206 | return inp 207 | 208 | class WaveletAttention(nn.Module): 209 | def __init__(self, hid_dim, init_steps, skip=True, kernel=3, stride=3, padding=0): 210 | super(WaveletAttention, self).__init__() 211 | self.enhance = EnhancedBlock(hid_size=hid_dim, channel=init_steps, skip=skip) 212 | self.convs = nn.Sequential(nn.Conv1d(in_channels=hid_dim, 213 | out_channels=hid_dim, 214 | kernel_size=kernel, 215 | stride=stride, padding=padding, 216 | padding_mode='zeros', 217 | bias=False), 218 | nn.ReLU()) 219 | 220 | def forward(self, hid_set: torch.Tensor): 221 | """ 222 | q = W1 * hid_0, K = W2 * hid_set 223 | softmax(q * K.T) * hid_set 224 | :param hid_set: Key set, shape: batch * n_sequence * hid_dim 225 | :param hid_0: Query hidden state, shape: batch * 1 * hid_dim 226 | :return: out: shape: shape: batch * 1 * hid_dim, scores: shape: batch * 1 * sequence 227 | """ 228 | reweighted, weight = self.enhance(hid_set) 229 | align = self.convs(reweighted.transpose(1, 2)).transpose(1, 2) # B * DestStep * D 230 | return align, weight 231 | 232 | 233 | class EnhancedBlock(nn.Module): 234 | def __init__(self, hid_size, channel, skip=True): 235 | super(EnhancedBlock, self).__init__() 236 | self.skip = skip 237 | self.comp = nn.Sequential(nn.Linear(hid_size, hid_size // 2, bias=False), 238 | nn.ReLU(), 239 | nn.Linear(hid_size // 2, 1, bias=False)) 240 | self.activate = nn.Sequential(nn.Linear(channel, channel // 2, bias=False), 241 | nn.ReLU(), 242 | nn.Linear(channel // 2, channel, bias=False), 243 | nn.Sigmoid()) 244 | 245 | def forward(self, inp): 246 | S = self.comp(inp) # B * T * 1 247 | E = self.activate(S.transpose(1, 2)) # B * 1 * T 248 | out = inp * E.transpose(1, 2).expand_as(inp) 249 | if self.skip: 250 | out += inp 251 | return out, E # B * T * D 252 | 253 | if __name__ == '__main__': 254 | model = WTFTP(6, 6, 9) 255 | model = WTFTP_AttnRemoved(6, 6) 256 | inp = torch.rand((100, 9, 6)) 257 | out, _ = model(inp) 258 | print(model.__class__.__name__) 259 | print(out) 260 | print(model.args) 261 | -------------------------------------------------------------------------------- /pics/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pics/framework.png -------------------------------------------------------------------------------- /pytorch_wavelets/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | '__version__', 3 | 'DTCWTForward', 4 | 'DTCWTInverse', 5 | 'DWTForward', 6 | 'DWTInverse', 7 | 'DWT1DForward', 8 | 'DWT1DInverse', 9 | 'DTCWT', 10 | 'IDTCWT', 11 | 'DWT', 12 | 'IDWT', 13 | 'DWT1D', 14 | 'DWT2D', 15 | 'IDWT1D', 16 | 'IDWT2D', 17 | 'ScatLayer', 18 | 'ScatLayerj2' 19 | ] 20 | 21 | from pytorch_wavelets._version import __version__ 22 | from pytorch_wavelets.dtcwt.transform2d import DTCWTForward, DTCWTInverse 23 | from pytorch_wavelets.dwt.transform2d import DWTForward, DWTInverse 24 | from pytorch_wavelets.dwt.transform1d import DWT1DForward, DWT1DInverse 25 | from pytorch_wavelets.scatternet import ScatLayer, ScatLayerj2 26 | 27 | # Some aliases 28 | DTCWT = DTCWTForward 29 | IDTCWT = DTCWTInverse 30 | DWT = DWTForward 31 | IDWT = DWTInverse 32 | DWT2D = DWT 33 | IDWT2D = IDWT 34 | 35 | DWT1D = DWT1DForward 36 | IDWT1D = DWT1DInverse 37 | -------------------------------------------------------------------------------- /pytorch_wavelets/_version.py: -------------------------------------------------------------------------------- 1 | # IMPORTANT: before release, remove the 'devN' tag from the release name 2 | __version__ = '1.3.0' 3 | -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provide low-level torch accelerated operations. This backend requires that 3 | torch be installed. Works best with a GPU but still offers good 4 | improvements with a CPU. 5 | 6 | """ 7 | -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/coeffs.py: -------------------------------------------------------------------------------- 1 | """Functions to load standard wavelet coefficients. 2 | 3 | """ 4 | from __future__ import absolute_import 5 | 6 | from numpy import load 7 | from pkg_resources import resource_stream 8 | try: 9 | import pywt 10 | _HAVE_PYWT = True 11 | except ImportError: 12 | _HAVE_PYWT = False 13 | 14 | COEFF_CACHE = {} 15 | 16 | 17 | def _load_from_file(basename, varnames): 18 | 19 | try: 20 | mat = COEFF_CACHE[basename] 21 | except KeyError: 22 | with resource_stream('pytorch_wavelets.dtcwt.data', basename + '.npz') as f: 23 | mat = dict(load(f)) 24 | COEFF_CACHE[basename] = mat 25 | 26 | try: 27 | return tuple(mat[k] for k in varnames) 28 | except KeyError: 29 | raise ValueError( 30 | 'Wavelet does not define ({0}) coefficients'.format( 31 | ', '.join(varnames))) 32 | 33 | 34 | def biort(name): 35 | """ Deprecated. Use :py::func:`pytorch_wavelets.dtcwt.coeffs.level1` 36 | Instead 37 | """ 38 | return level1(name, compact=True) 39 | 40 | 41 | def level1(name, compact=False): 42 | """Load level 1 wavelet by name. 43 | 44 | :param name: a string specifying the wavelet family name 45 | :returns: a tuple of vectors giving filter coefficients 46 | 47 | ============= ============================================ 48 | Name Wavelet 49 | ============= ============================================ 50 | antonini Antonini 9,7 tap filters. 51 | farras Farras 8,8 tap filters 52 | legall LeGall 5,3 tap filters. 53 | near_sym_a Near-Symmetric 5,7 tap filters. 54 | near_sym_b Near-Symmetric 13,19 tap filters. 55 | near_sym_b_bp Near-Symmetric 13,19 tap filters + BP filter 56 | ============= ============================================ 57 | 58 | Return a tuple whose elements are a vector specifying the h0o, g0o, h1o and 59 | g1o coefficients. 60 | 61 | See :ref:`rot-symm-wavelets` for an explanation of the ``near_sym_b_bp`` 62 | wavelet filters. 63 | 64 | :raises IOError: if name does not correspond to a set of wavelets known to 65 | the library. 66 | :raises ValueError: if name doesn't specify 67 | :py:func:`pytorch_wavelets.dtcwt.coeffs.qshift` wavelet. 68 | 69 | """ 70 | if compact: 71 | if name == 'near_sym_b_bp': 72 | return _load_from_file(name, ('h0o', 'g0o', 'h1o', 'g1o', 'h2o', 'g2o')) 73 | else: 74 | return _load_from_file(name, ('h0o', 'g0o', 'h1o', 'g1o')) 75 | else: 76 | return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b', 77 | 'g1a', 'g1b')) 78 | 79 | 80 | def qshift(name): 81 | """Load level >=2 wavelet by name, 82 | 83 | :param name: a string specifying the wavelet family name 84 | :returns: a tuple of vectors giving filter coefficients 85 | 86 | ============ ============================================ 87 | Name Wavelet 88 | ============ ============================================ 89 | qshift_06 Quarter Sample Shift Orthogonal (Q-Shift) 10,10 tap filters, 90 | (only 6,6 non-zero taps). 91 | qshift_a Q-shift 10,10 tap filters, 92 | (with 10,10 non-zero taps, unlike qshift_06). 93 | qshift_b Q-Shift 14,14 tap filters. 94 | qshift_c Q-Shift 16,16 tap filters. 95 | qshift_d Q-Shift 18,18 tap filters. 96 | qshift_b_bp Q-Shift 18,18 tap filters + BP 97 | ============ ============================================ 98 | 99 | Return a tuple whose elements are a vector specifying the h0a, h0b, g0a, 100 | g0b, h1a, h1b, g1a and g1b coefficients. 101 | 102 | See :ref:`rot-symm-wavelets` for an explanation of the ``qshift_b_bp`` 103 | wavelet filters. 104 | 105 | :raises IOError: if name does not correspond to a set of wavelets known to 106 | the library. 107 | :raises ValueError: if name doesn't specify a 108 | :py:func:`pytorch_wavelets.dtcwt.coeffs.biort` wavelet. 109 | 110 | """ 111 | if name == 'qshift_b_bp': 112 | return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b', 113 | 'g1a', 'g1b', 'h2a', 'h2b', 'g2a','g2b')) 114 | else: 115 | return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b', 116 | 'g1a', 'g1b')) 117 | 118 | 119 | def pywt_coeffs(name): 120 | """ Wraps pywt Wavelet function. """ 121 | if not _HAVE_PYWT: 122 | raise ImportError("Could not find PyWavelets module") 123 | return pywt.Wavelet(name) 124 | 125 | # vim:sw=4:sts=4:et 126 | -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/__init__.py -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/antonini.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/antonini.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/farras.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/farras.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/legall.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/legall.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/near_sym_a.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/near_sym_a.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/near_sym_a2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/near_sym_a2.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/near_sym_b.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/near_sym_b.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/near_sym_b_bp.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/near_sym_b_bp.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/qshift_06.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_06.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/qshift_32.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_32.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/qshift_a.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_a.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/qshift_b.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_b.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/qshift_b_bp.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_b_bp.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/qshift_c.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_c.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/data/qshift_d.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_d.npz -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/lowlevel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from pytorch_wavelets.utils import symm_pad_1d as symm_pad 7 | 8 | 9 | def as_column_vector(v): 10 | """Return *v* as a column vector with shape (N,1). 11 | 12 | """ 13 | v = np.atleast_2d(v) 14 | if v.shape[0] == 1: 15 | return v.T 16 | else: 17 | return v 18 | 19 | 20 | def _as_row_vector(v): 21 | """Return *v* as a row vector with shape (1, N). 22 | """ 23 | v = np.atleast_2d(v) 24 | if v.shape[0] == 1: 25 | return v 26 | else: 27 | return v.T 28 | 29 | 30 | def _as_row_tensor(h): 31 | if isinstance(h, torch.Tensor): 32 | h = torch.reshape(h, [1, -1]) 33 | else: 34 | h = as_column_vector(h).T 35 | h = torch.tensor(h, dtype=torch.get_default_dtype()) 36 | return h 37 | 38 | 39 | def _as_col_vector(v): 40 | """Return *v* as a column vector with shape (N,1). 41 | """ 42 | v = np.atleast_2d(v) 43 | if v.shape[0] == 1: 44 | return v.T 45 | else: 46 | return v 47 | 48 | 49 | def _as_col_tensor(h): 50 | if isinstance(h, torch.Tensor): 51 | h = torch.reshape(h, [-1, 1]) 52 | else: 53 | h = as_column_vector(h) 54 | h = torch.tensor(h, dtype=torch.get_default_dtype()) 55 | return h 56 | 57 | 58 | def prep_filt(h, c, transpose=False): 59 | """ Prepares an array to be of the correct format for pytorch. 60 | Can also specify whether to make it a row filter (set tranpose=True)""" 61 | h = _as_col_vector(h)[::-1] 62 | h = h[None, None, :] 63 | h = np.repeat(h, repeats=c, axis=0) 64 | if transpose: 65 | h = h.transpose((0,1,3,2)) 66 | h = np.copy(h) 67 | return torch.tensor(h, dtype=torch.get_default_dtype()) 68 | 69 | 70 | def colfilter(X, h, mode='symmetric'): 71 | if X is None or X.shape == torch.Size([]): 72 | return torch.zeros(1,1,1,1, device=X.device) 73 | b, ch, row, col = X.shape 74 | m = h.shape[2] // 2 75 | if mode == 'symmetric': 76 | xe = symm_pad(row, m) 77 | X = F.conv2d(X[:,:,xe], h.repeat(ch,1,1,1), groups=ch) 78 | else: 79 | X = F.conv2d(X, h.repeat(ch, 1, 1, 1), groups=ch, padding=(m, 0)) 80 | return X 81 | 82 | 83 | def rowfilter(X, h, mode='symmetric'): 84 | if X is None or X.shape == torch.Size([]): 85 | return torch.zeros(1,1,1,1, device=X.device) 86 | b, ch, row, col = X.shape 87 | m = h.shape[2] // 2 88 | h = h.transpose(2,3).contiguous() 89 | if mode == 'symmetric': 90 | xe = symm_pad(col, m) 91 | X = F.conv2d(X[:,:,:,xe], h.repeat(ch,1,1,1), groups=ch) 92 | else: 93 | X = F.conv2d(X, h.repeat(ch,1,1,1), groups=ch, padding=(0, m)) 94 | return X 95 | 96 | 97 | def coldfilt(X, ha, hb, highpass=False, mode='symmetric'): 98 | if X is None or X.shape == torch.Size([]): 99 | return torch.zeros(1,1,1,1, device=X.device) 100 | batch, ch, r, c = X.shape 101 | r2 = r // 2 102 | if r % 4 != 0: 103 | raise ValueError('No. of rows in X must be a multiple of 4\n' + 104 | 'X was {}'.format(X.shape)) 105 | 106 | if mode == 'symmetric': 107 | m = ha.shape[2] 108 | xe = symm_pad(r, m) 109 | X = torch.cat((X[:,:,xe[2::2]], X[:,:,xe[3::2]]), dim=1) 110 | h = torch.cat((ha.repeat(ch, 1, 1, 1), hb.repeat(ch, 1, 1, 1)), dim=0) 111 | X = F.conv2d(X, h, stride=(2,1), groups=ch*2) 112 | else: 113 | raise NotImplementedError() 114 | 115 | # Reshape result to be shape [Batch, ch, r/2, c]. This reshaping 116 | # interleaves the columns 117 | if highpass: 118 | X = torch.stack((X[:, ch:], X[:, :ch]), dim=-2).view(batch, ch, r2, c) 119 | else: 120 | X = torch.stack((X[:, :ch], X[:, ch:]), dim=-2).view(batch, ch, r2, c) 121 | 122 | return X 123 | 124 | 125 | def rowdfilt(X, ha, hb, highpass=False, mode='symmetric'): 126 | if X is None or X.shape == torch.Size([]): 127 | return torch.zeros(1,1,1,1, device=X.device) 128 | batch, ch, r, c = X.shape 129 | c2 = c // 2 130 | if c % 4 != 0: 131 | raise ValueError('No. of cols in X must be a multiple of 4\n' + 132 | 'X was {}'.format(X.shape)) 133 | 134 | if mode == 'symmetric': 135 | m = ha.shape[2] 136 | xe = symm_pad(c, m) 137 | X = torch.cat((X[:,:,:,xe[2::2]], X[:,:,:,xe[3::2]]), dim=1) 138 | h = torch.cat((ha.reshape(1,1,1,m).repeat(ch, 1, 1, 1), 139 | hb.reshape(1,1,1,m).repeat(ch, 1, 1, 1)), dim=0) 140 | X = F.conv2d(X, h, stride=(1,2), groups=ch*2) 141 | else: 142 | raise NotImplementedError() 143 | 144 | # Reshape result to be shape [Batch, ch, r/2, c]. This reshaping 145 | # interleaves the columns 146 | if highpass: 147 | Y = torch.stack((X[:, ch:], X[:, :ch]), dim=-1).view(batch, ch, r, c2) 148 | else: 149 | Y = torch.stack((X[:, :ch], X[:, ch:]), dim=-1).view(batch, ch, r, c2) 150 | 151 | return Y 152 | 153 | 154 | def colifilt(X, ha, hb, highpass=False, mode='symmetric'): 155 | if X is None or X.shape == torch.Size([]): 156 | return torch.zeros(1,1,1,1, device=X.device) 157 | m = ha.shape[2] 158 | m2 = m // 2 159 | hao = ha[:,:,1::2] 160 | hae = ha[:,:,::2] 161 | hbo = hb[:,:,1::2] 162 | hbe = hb[:,:,::2] 163 | batch, ch, r, c = X.shape 164 | if r % 2 != 0: 165 | raise ValueError('No. of rows in X must be a multiple of 2.\n' + 166 | 'X was {}'.format(X.shape)) 167 | xe = symm_pad(r, m2) 168 | 169 | if m2 % 2 == 0: 170 | h1 = hae 171 | h2 = hbe 172 | h3 = hao 173 | h4 = hbo 174 | if highpass: 175 | X = torch.cat((X[:,:,xe[1:-2:2]], X[:,:,xe[:-2:2]], X[:,:,xe[3::2]], X[:,:,xe[2::2]]), dim=1) 176 | else: 177 | X = torch.cat((X[:,:,xe[:-2:2]], X[:,:,xe[1:-2:2]], X[:,:,xe[2::2]], X[:,:,xe[3::2]]), dim=1) 178 | else: 179 | h1 = hao 180 | h2 = hbo 181 | h3 = hae 182 | h4 = hbe 183 | if highpass: 184 | X = torch.cat((X[:,:,xe[2:-1:2]], X[:,:,xe[1:-1:2]], X[:,:,xe[2:-1:2]], X[:,:,xe[1:-1:2]]), dim=1) 185 | else: 186 | X = torch.cat((X[:,:,xe[1:-1:2]], X[:,:,xe[2:-1:2]], X[:,:,xe[1:-1:2]], X[:,:,xe[2:-1:2]]), dim=1) 187 | h = torch.cat((h1.repeat(ch, 1, 1, 1), h2.repeat(ch, 1, 1, 1), 188 | h3.repeat(ch, 1, 1, 1), h4.repeat(ch, 1, 1, 1)), dim=0) 189 | 190 | X = F.conv2d(X, h, groups=4*ch) 191 | # Stack 4 tensors of shape [batch, ch, r2, c] into one tensor 192 | # [batch, ch, r2, 4, c] 193 | X = torch.stack([X[:,:ch], X[:,ch:2*ch], X[:,2*ch:3*ch], X[:,3*ch:]], dim=3).view(batch, ch, r*2, c) 194 | 195 | return X 196 | 197 | 198 | def rowifilt(X, ha, hb, highpass=False, mode='symmetric'): 199 | if X is None or X.shape == torch.Size([]): 200 | return torch.zeros(1,1,1,1, device=X.device) 201 | m = ha.shape[2] 202 | m2 = m // 2 203 | hao = ha[:,:,1::2] 204 | hae = ha[:,:,::2] 205 | hbo = hb[:,:,1::2] 206 | hbe = hb[:,:,::2] 207 | batch, ch, r, c = X.shape 208 | if c % 2 != 0: 209 | raise ValueError('No. of cols in X must be a multiple of 2.\n' + 210 | 'X was {}'.format(X.shape)) 211 | xe = symm_pad(c, m2) 212 | 213 | if m2 % 2 == 0: 214 | h1 = hae 215 | h2 = hbe 216 | h3 = hao 217 | h4 = hbo 218 | if highpass: 219 | X = torch.cat((X[:,:,:,xe[1:-2:2]], X[:,:,:,xe[:-2:2]], X[:,:,:,xe[3::2]], X[:,:,:,xe[2::2]]), dim=1) 220 | else: 221 | X = torch.cat((X[:,:,:,xe[:-2:2]], X[:,:,:,xe[1:-2:2]], X[:,:,:,xe[2::2]], X[:,:,:,xe[3::2]]), dim=1) 222 | else: 223 | h1 = hao 224 | h2 = hbo 225 | h3 = hae 226 | h4 = hbe 227 | if highpass: 228 | X = torch.cat((X[:,:,:,xe[2:-1:2]], X[:,:,:,xe[1:-1:2]], X[:,:,:,xe[2:-1:2]], X[:,:,:,xe[1:-1:2]]), dim=1) 229 | else: 230 | X = torch.cat((X[:,:,:,xe[1:-1:2]], X[:,:,:,xe[2:-1:2]], X[:,:,:,xe[1:-1:2]], X[:,:,:,xe[2:-1:2]]), dim=1) 231 | h = torch.cat((h1.repeat(ch, 1, 1, 1), h2.repeat(ch, 1, 1, 1), 232 | h3.repeat(ch, 1, 1, 1), h4.repeat(ch, 1, 1, 1)), 233 | dim=0).reshape(4*ch, 1, 1, m2) 234 | 235 | X = F.conv2d(X, h, groups=4*ch) 236 | # Stack 4 tensors of shape [batch, ch, r2, c] into one tensor 237 | # [batch, ch, r2, 4, c] 238 | X = torch.stack([X[:,:ch], X[:,ch:2*ch], X[:,2*ch:3*ch], X[:,3*ch:]], dim=4).view(batch, ch, r, c*2) 239 | return X 240 | 241 | 242 | # def q2c(y, dim=-1): 243 | def q2c(y, dim=-1): 244 | """ 245 | Convert from quads in y to complex numbers in z. 246 | """ 247 | 248 | # Arrange pixels from the corners of the quads into 249 | # 2 subimages of alternate real and imag pixels. 250 | # a----b 251 | # | | 252 | # | | 253 | # c----d 254 | # Combine (a,b) and (d,c) to form two complex subimages. 255 | y = y/np.sqrt(2) 256 | a, b = y[:,:, 0::2, 0::2], y[:,:, 0::2, 1::2] 257 | c, d = y[:,:, 1::2, 0::2], y[:,:, 1::2, 1::2] 258 | 259 | # return torch.stack((a-d, b+c), dim=dim), torch.stack((a+d, b-c), dim=dim) 260 | return ((a-d, b+c), (a+d, b-c)) 261 | 262 | 263 | def c2q(w1, w2): 264 | """ 265 | Scale by gain and convert from complex w(:,:,1:2) to real quad-numbers 266 | in z. 267 | 268 | Arrange pixels from the real and imag parts of the 2 highpasses 269 | into 4 separate subimages . 270 | A----B Re Im of w(:,:,1) 271 | | | 272 | | | 273 | C----D Re Im of w(:,:,2) 274 | 275 | """ 276 | w1r, w1i = w1 277 | w2r, w2i = w2 278 | 279 | x1 = w1r + w2r 280 | x2 = w1i + w2i 281 | x3 = w1i - w2i 282 | x4 = -w1r + w2r 283 | 284 | # Get the shape of the tensor excluding the real/imagniary part 285 | b, ch, r, c = w1r.shape 286 | 287 | # Create new empty tensor and fill it 288 | y = w1r.new_zeros((b, ch, r*2, c*2), requires_grad=w1r.requires_grad) 289 | y[:, :, ::2,::2] = x1 290 | y[:, :, ::2, 1::2] = x2 291 | y[:, :, 1::2, ::2] = x3 292 | y[:, :, 1::2, 1::2] = x4 293 | y /= np.sqrt(2) 294 | 295 | return y 296 | -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/lowlevel2.py: -------------------------------------------------------------------------------- 1 | """ This module was part of an attempt to speed up the DTCWT. The code was 2 | ultimately slower than the original implementation, but it is a nice 3 | reference point for doing a DTCWT directly as 4 separate DWTs. 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from pytorch_wavelets.dwt.lowlevel import roll, mypad 10 | import pywt 11 | from pytorch_wavelets.dwt.transform2d import DWTForward, DWTInverse 12 | from pytorch_wavelets.dwt.lowlevel import afb2d, sfb2d_nonsep as sfb2d 13 | from pytorch_wavelets.dwt.lowlevel import prep_filt_afb2d, prep_filt_sfb2d_nonsep as prep_filt_sfb2d 14 | from pytorch_wavelets.dtcwt.coeffs import level1 as _level1, qshift as _qshift, biort as _biort 15 | 16 | 17 | class DTCWTForward2(nn.Module): 18 | """ DTCWT based on 4 DWTs. Still works, but the above implementation is 19 | faster """ 20 | def __init__(self, biort='farras', qshift='qshift_a', J=3, 21 | mode='symmetric'): 22 | super().__init__() 23 | self.biort = biort 24 | self.qshift = qshift 25 | self.J = J 26 | 27 | if isinstance(biort, str): 28 | biort = _level1(biort) 29 | assert len(biort) == 8 30 | h0a1, h0b1, _, _, h1a1, h1b1, _, _ = biort 31 | DWTaa1 = DWTForward(J=1, wave=(h0a1, h1a1, h0a1, h1a1), mode=mode) 32 | DWTab1 = DWTForward(J=1, wave=(h0a1, h1a1, h0b1, h1b1), mode=mode) 33 | DWTba1 = DWTForward(J=1, wave=(h0b1, h1b1, h0a1, h1a1), mode=mode) 34 | DWTbb1 = DWTForward(J=1, wave=(h0b1, h1b1, h0b1, h1b1), mode=mode) 35 | self.level1 = nn.ModuleList([DWTaa1, DWTab1, DWTba1, DWTbb1]) 36 | 37 | if J > 1: 38 | if isinstance(qshift, str): 39 | qshift = _qshift(qshift) 40 | assert len(qshift) == 8 41 | h0a, h0b, _, _, h1a, h1b, _, _ = qshift 42 | DWTaa = DWTForward(J-1, (h0a, h1a, h0a, h1a), mode=mode) 43 | DWTab = DWTForward(J-1, (h0a, h1a, h0b, h1b), mode=mode) 44 | DWTba = DWTForward(J-1, (h0b, h1b, h0a, h1a), mode=mode) 45 | DWTbb = DWTForward(J-1, (h0b, h1b, h0b, h1b), mode=mode) 46 | self.level2 = nn.ModuleList([DWTaa, DWTab, DWTba, DWTbb]) 47 | 48 | def forward(self, x): 49 | x = x/2 50 | J = self.J 51 | w = [[[None for _ in range(2)] for _ in range(2)] for j in range(J)] 52 | lows = [[None for _ in range(2)] for _ in range(2)] 53 | for m in range(2): 54 | for n in range(2): 55 | # Do the first level transform 56 | ll, (w[0][m][n],) = self.level1[m*2+n](x) 57 | # w[0][m][n] = [bands[:,:,2], bands[:,:,1], bands[:,:,3]] 58 | 59 | # Do the second+ level transform with the second level filters 60 | if J > 1: 61 | ll, bands = self.level2[m*2+n](ll) 62 | for j in range(1,J): 63 | w[j][m][n] = bands[j-1] 64 | lows[m][n] = ll 65 | 66 | # Convert the quads into real and imaginary parts 67 | yh = [None,] * J 68 | for j in range(J): 69 | deg75r, deg105i = pm(w[j][0][0][:,:,1], w[j][1][1][:,:,1]) 70 | deg105r, deg75i = pm(w[j][0][1][:,:,1], w[j][1][0][:,:,1]) 71 | deg15r, deg165i = pm(w[j][0][0][:,:,0], w[j][1][1][:,:,0]) 72 | deg165r, deg15i = pm(w[j][0][1][:,:,0], w[j][1][0][:,:,0]) 73 | deg135r, deg45i = pm(w[j][0][0][:,:,2], w[j][1][1][:,:,2]) 74 | deg45r, deg135i = pm(w[j][0][1][:,:,2], w[j][1][0][:,:,2]) 75 | w[j] = None 76 | yhr = torch.stack((deg15r, deg45r, deg75r, 77 | deg105r, deg135r, deg165r), dim=1) 78 | yhi = torch.stack((deg15i, deg45i, deg75i, 79 | deg105i, deg135i, deg165i), dim=1) 80 | yh[j] = torch.stack((yhr, yhi), dim=-1) 81 | 82 | return lows, yh 83 | 84 | 85 | class DTCWTInverse2(nn.Module): 86 | def __init__(self, biort='farras', qshift='qshift_a', 87 | mode='symmetric'): 88 | super().__init__() 89 | self.biort = biort 90 | self.qshift = qshift 91 | 92 | if isinstance(biort, str): 93 | biort = _level1(biort) 94 | assert len(biort) == 8 95 | _, _, g0a1, g0b1, _, _, g1a1, g1b1 = biort 96 | IWTaa1 = DWTInverse(wave=(g0a1, g1a1, g0a1, g1a1), mode=mode) 97 | IWTab1 = DWTInverse(wave=(g0a1, g1a1, g0b1, g1b1), mode=mode) 98 | IWTba1 = DWTInverse(wave=(g0b1, g1b1, g0a1, g1a1), mode=mode) 99 | IWTbb1 = DWTInverse(wave=(g0b1, g1b1, g0b1, g1b1), mode=mode) 100 | self.level1 = nn.ModuleList([IWTaa1, IWTab1, IWTba1, IWTbb1]) 101 | 102 | if isinstance(qshift, str): 103 | qshift = _qshift(qshift) 104 | assert len(qshift) == 8 105 | _, _, g0a, g0b, _, _, g1a, g1b = qshift 106 | IWTaa = DWTInverse(wave=(g0a, g1a, g0a, g1a), mode=mode) 107 | IWTab = DWTInverse(wave=(g0a, g1a, g0b, g1b), mode=mode) 108 | IWTba = DWTInverse(wave=(g0b, g1b, g0a, g1a), mode=mode) 109 | IWTbb = DWTInverse(wave=(g0b, g1b, g0b, g1b), mode=mode) 110 | self.level2 = nn.ModuleList([IWTaa, IWTab, IWTba, IWTbb]) 111 | 112 | def forward(self, x): 113 | # Convert the highs back to subbands 114 | yl, yh = x 115 | J = len(yh) 116 | # w = [[[[None for i in range(3)] for j in range(2)] 117 | # for k in range(2)] for l in range(J)] 118 | w = [[[[None for band in range(3)] for j in range(J)] 119 | for m in range(2)] for n in range(2)] 120 | for j in range(J): 121 | w[0][0][j][0], w[1][1][j][0] = pm( 122 | yh[j][:,2,:,:,:,0], yh[j][:,3,:,:,:,1]) 123 | w[0][1][j][0], w[1][0][j][0] = pm( 124 | yh[j][:,3,:,:,:,0], yh[j][:,2,:,:,:,1]) 125 | w[0][0][j][1], w[1][1][j][1] = pm( 126 | yh[j][:,0,:,:,:,0], yh[j][:,5,:,:,:,1]) 127 | w[0][1][j][1], w[1][0][j][1] = pm( 128 | yh[j][:,5,:,:,:,0], yh[j][:,0,:,:,:,1]) 129 | w[0][0][j][2], w[1][1][j][2] = pm( 130 | yh[j][:,1,:,:,:,0], yh[j][:,4,:,:,:,1]) 131 | w[0][1][j][2], w[1][0][j][2] = pm( 132 | yh[j][:,4,:,:,:,0], yh[j][:,1,:,:,:,1]) 133 | w[0][0][j] = torch.stack(w[0][0][j], dim=2) 134 | w[0][1][j] = torch.stack(w[0][1][j], dim=2) 135 | w[1][0][j] = torch.stack(w[1][0][j], dim=2) 136 | w[1][1][j] = torch.stack(w[1][1][j], dim=2) 137 | 138 | y = None 139 | for m in range(2): 140 | for n in range(2): 141 | lo = yl[m][n] 142 | if J > 1: 143 | lo = self.level2[m*2+n]((lo, w[m][n][1:])) 144 | lo = self.level1[m*2+n]((lo, (w[m][n][0],))) 145 | 146 | # Add to the output 147 | if y is None: 148 | y = lo 149 | else: 150 | y = y + lo 151 | 152 | # Normalize 153 | y = y/2 154 | return y 155 | 156 | 157 | def prep_filt_quad_afb2d_nonsep( 158 | h0a_col, h1a_col, h0a_row, h1a_row, 159 | h0b_col, h1b_col, h0b_row, h1b_row, 160 | h0c_col, h1c_col, h0c_row, h1c_row, 161 | h0d_col, h1d_col, h0d_row, h1d_row, device=None): 162 | """ 163 | Prepares the filters to be of the right form for the afb2d_nonsep function. 164 | In particular, makes 2d point spread functions, and mirror images them in 165 | preparation to do torch.conv2d. 166 | 167 | Inputs: 168 | h0_col (array-like): low pass column filter bank 169 | h1_col (array-like): high pass column filter bank 170 | h0_row (array-like): low pass row filter bank. If none, will assume the 171 | same as column filter 172 | h1_row (array-like): high pass row filter bank. If none, will assume the 173 | same as column filter 174 | device: which device to put the tensors on to 175 | 176 | Returns: 177 | filts: (4, 1, h, w) tensor ready to get the four subbands 178 | """ 179 | lla = np.outer(h0a_col, h0a_row) 180 | lha = np.outer(h1a_col, h0a_row) 181 | hla = np.outer(h0a_col, h1a_row) 182 | hha = np.outer(h1a_col, h1a_row) 183 | llb = np.outer(h0b_col, h0b_row) 184 | lhb = np.outer(h1b_col, h0b_row) 185 | hlb = np.outer(h0b_col, h1b_row) 186 | hhb = np.outer(h1b_col, h1b_row) 187 | llc = np.outer(h0c_col, h0c_row) 188 | lhc = np.outer(h1c_col, h0c_row) 189 | hlc = np.outer(h0c_col, h1c_row) 190 | hhc = np.outer(h1c_col, h1c_row) 191 | lld = np.outer(h0d_col, h0d_row) 192 | lhd = np.outer(h1d_col, h0d_row) 193 | hld = np.outer(h0d_col, h1d_row) 194 | hhd = np.outer(h1d_col, h1d_row) 195 | filts = np.stack([lla[None,::-1,::-1], llb[None,::-1,::-1], 196 | llc[None,::-1,::-1], lld[None,::-1,::-1], 197 | lha[None,::-1,::-1], lhb[None,::-1,::-1], 198 | lhc[None,::-1,::-1], lhd[None,::-1,::-1], 199 | hla[None,::-1,::-1], hlb[None,::-1,::-1], 200 | hlc[None,::-1,::-1], hld[None,::-1,::-1], 201 | hha[None,::-1,::-1], hhb[None,::-1,::-1], 202 | hhc[None,::-1,::-1], hhd[None,::-1,::-1]], 203 | axis=0) 204 | filts = torch.tensor(filts, dtype=torch.get_default_dtype(), device=device) 205 | return filts 206 | 207 | 208 | def prep_filt_quad_afb2d(h0a, h1a, h0b, h1b, device=None): 209 | """ 210 | Prepares the filters to be of the right form for the quad_afb2d function. 211 | 212 | Inputs: 213 | h0_col (array-like): low pass column filter bank 214 | h1_col (array-like): high pass column filter bank 215 | h0_row (array-like): low pass row filter bank. If none, will assume the 216 | same as column filter 217 | h1_row (array-like): high pass row filter bank. If none, will assume the 218 | same as column filter 219 | device: which device to put the tensors on to 220 | 221 | Returns: 222 | filts: (4, 1, h, w) tensor ready to get the four subbands 223 | """ 224 | h0a_col = np.array(h0a).ravel()[::-1][None, :, None] 225 | h1a_col = np.array(h1a).ravel()[::-1][None, :, None] 226 | h0b_col = np.array(h0a).ravel()[::-1][None, :, None] 227 | h1b_col = np.array(h1a).ravel()[::-1][None, :, None] 228 | h0c_col = np.array(h0b).ravel()[::-1][None, :, None] 229 | h1c_col = np.array(h1b).ravel()[::-1][None, :, None] 230 | h0d_col = np.array(h0b).ravel()[::-1][None, :, None] 231 | h1d_col = np.array(h1b).ravel()[::-1][None, :, None] 232 | h0a_row = np.array(h0a).ravel()[::-1][None, None, :] 233 | h1a_row = np.array(h1a).ravel()[::-1][None, None, :] 234 | h0b_row = np.array(h0b).ravel()[::-1][None, None, :] 235 | h1b_row = np.array(h1b).ravel()[::-1][None, None, :] 236 | h0c_row = np.array(h0a).ravel()[::-1][None, None, :] 237 | h1c_row = np.array(h1a).ravel()[::-1][None, None, :] 238 | h0d_row = np.array(h0b).ravel()[::-1][None, None, :] 239 | h1d_row = np.array(h1b).ravel()[::-1][None, None, :] 240 | cols = np.stack((h0a_col, h1a_col, 241 | h0b_col, h1b_col, 242 | h0c_col, h1c_col, 243 | h0d_col, h1d_col), axis=0) 244 | rows = np.stack((h0a_row, h1a_row, 245 | h0a_row, h1a_row, 246 | h0b_row, h1b_row, 247 | h0b_row, h1b_row, 248 | h0c_row, h1c_row, 249 | h0c_row, h1c_row, 250 | h0d_row, h1d_row, 251 | h0d_row, h1d_row), axis=0) 252 | cols = torch.tensor(np.copy(cols), dtype=torch.get_default_dtype(), 253 | device=device) 254 | rows = torch.tensor(np.copy(rows), dtype=torch.get_default_dtype(), 255 | device=device) 256 | return cols, rows 257 | 258 | 259 | def quad_afb2d(x, cols, rows, mode='zero', split=True, stride=2): 260 | """ Does a single level 2d wavelet decomposition of an input. Does separate 261 | row and column filtering by two calls to 262 | :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d` 263 | 264 | Inputs: 265 | x (torch.Tensor): Input to decompose 266 | filts (list of ndarray or torch.Tensor): If a list of tensors has been 267 | given, this function assumes they are in the right form (the form 268 | returned by 269 | :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`). 270 | Otherwise, this function will prepare the filters to be of the right 271 | form by calling 272 | :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`. 273 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which 274 | padding to use. If periodization, the output size will be half the 275 | input size. Otherwise, the output size will be slightly larger than 276 | half. 277 | """ 278 | x = x/2 279 | C = x.shape[1] 280 | cols = torch.cat([cols]*C, dim=0) 281 | rows = torch.cat([rows]*C, dim=0) 282 | 283 | if mode == 'per' or mode == 'periodization': 284 | # Do column filtering 285 | L = cols.shape[2] 286 | L2 = L // 2 287 | if x.shape[2] % 2 == 1: 288 | x = torch.cat((x, x[:,:,-1:]), dim=2) 289 | N2 = x.shape[2] // 2 290 | x = roll(x, -L2, dim=2) 291 | pad = (L-1, 0) 292 | lohi = F.conv2d(x, cols, padding=pad, stride=(stride,1), groups=C) 293 | lohi[:,:,:L2] = lohi[:,:,:L2] + lohi[:,:,N2:N2+L2] 294 | lohi = lohi[:,:,:N2] 295 | 296 | # Do row filtering 297 | L = rows.shape[3] 298 | L2 = L // 2 299 | if lohi.shape[3] % 2 == 1: 300 | lohi = torch.cat((lohi, lohi[:,:,:,-1:]), dim=3) 301 | N2 = x.shape[3] // 2 302 | lohi = roll(lohi, -L2, dim=3) 303 | pad = (0, L-1) 304 | w = F.conv2d(lohi, rows, padding=pad, stride=(1,stride), groups=8*C) 305 | w[:,:,:,:L2] = w[:,:,:,:L2] + w[:,:,:,N2:N2+L2] 306 | w = w[:,:,:,:N2] 307 | elif mode == 'zero': 308 | # Do column filtering 309 | N = x.shape[2] 310 | L = cols.shape[2] 311 | outsize = pywt.dwt_coeff_len(N, L, mode='zero') 312 | p = 2 * (outsize - 1) - N + L 313 | 314 | # Sadly, pytorch only allows for same padding before and after, if 315 | # we need to do more padding after for odd length signals, have to 316 | # prepad 317 | if p % 2 == 1: 318 | x = F.pad(x, (0, 0, 0, 1)) 319 | pad = (p//2, 0) 320 | # Calculate the high and lowpass 321 | lohi = F.conv2d(x, cols, padding=pad, stride=(stride,1), groups=C) 322 | 323 | # Do row filtering 324 | N = lohi.shape[3] 325 | L = rows.shape[3] 326 | outsize = pywt.dwt_coeff_len(N, L, mode='zero') 327 | p = 2 * (outsize - 1) - N + L 328 | if p % 2 == 1: 329 | lohi = F.pad(lohi, (0, 1, 0, 0)) 330 | pad = (0, p//2) 331 | w = F.conv2d(lohi, rows, padding=pad, stride=(1,stride), groups=8*C) 332 | elif mode == 'symmetric' or mode == 'reflect': 333 | # Do column filtering 334 | N = x.shape[2] 335 | L = cols.shape[2] 336 | outsize = pywt.dwt_coeff_len(N, L, mode=mode) 337 | p = 2 * (outsize - 1) - N + L 338 | x = mypad(x, pad=(0, 0, p//2, (p+1)//2), mode=mode) 339 | lohi = F.conv2d(x, cols, stride=(stride,1), groups=C) 340 | 341 | # Do row filtering 342 | N = lohi.shape[3] 343 | L = rows.shape[3] 344 | outsize = pywt.dwt_coeff_len(N, L, mode=mode) 345 | p = 2 * (outsize - 1) - N + L 346 | lohi = mypad(lohi, pad=(p//2, (p+1)//2, 0, 0), mode=mode) 347 | w = F.conv2d(lohi, rows, stride=(1,stride), groups=8*C) 348 | else: 349 | raise ValueError("Unkown pad type: {}".format(mode)) 350 | 351 | y = w.view((w.shape[0], C, 4, 4, w.shape[-2], w.shape[-1])) 352 | yl = y[:,:,:,0] 353 | yh = y[:,:,:,1:] 354 | deg75r, deg105i = pm(yh[:,:,0,0], yh[:,:,3,0]) 355 | deg105r, deg75i = pm(yh[:,:,1,0], yh[:,:,2,0]) 356 | deg15r, deg165i = pm(yh[:,:,0,1], yh[:,:,3,1]) 357 | deg165r, deg15i = pm(yh[:,:,1,1], yh[:,:,2,1]) 358 | deg135r, deg45i = pm(yh[:,:,0,2], yh[:,:,3,2]) 359 | deg45r, deg135i = pm(yh[:,:,1,2], yh[:,:,2,2]) 360 | yhr = torch.stack((deg15r, deg45r, deg75r, deg105r, deg135r, deg165r), dim=1) 361 | yhi = torch.stack((deg15i, deg45i, deg75i, deg105i, deg135i, deg165i), dim=1) 362 | yh = torch.stack((yhr, yhi), dim=-1) 363 | 364 | yl_rowa = torch.stack((yl[:,:,1], yl[:,:,0]), dim=-1) 365 | yl_rowb = torch.stack((yl[:,:,3], yl[:,:,2]), dim=-1) 366 | yl_rowa = yl_rowa.view(yl.shape[0], C, yl.shape[-2], yl.shape[-1]*2) 367 | yl_rowb = yl_rowb.view(yl.shape[0], C, yl.shape[-2], yl.shape[-1]*2) 368 | z = torch.stack((yl_rowb, yl_rowa), dim=-2) 369 | yl = z.view(yl.shape[0], C, yl.shape[-2]*2, yl.shape[-1]*2) 370 | 371 | return yl.contiguous(), yh 372 | 373 | 374 | def quad_afb2d_nonsep(x, filts, mode='zero'): 375 | """ Does a 1 level 2d wavelet decomposition of an input. Doesn't do separate 376 | row and column filtering. 377 | 378 | Inputs: 379 | x (torch.Tensor): Input to decompose 380 | filts (list or torch.Tensor): If a list is given, should be the low and 381 | highpass filter banks. If a tensor is given, it should be of the 382 | form created by 383 | :py:func:`pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d_nonsep` 384 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which 385 | padding to use. If periodization, the output size will be half the 386 | input size. Otherwise, the output size will be slightly larger than 387 | half. 388 | """ 389 | C = x.shape[1] 390 | Ny = x.shape[2] 391 | Nx = x.shape[3] 392 | 393 | # Check the filter inputs 394 | f = torch.cat([filts]*C, dim=0) 395 | Ly = f.shape[2] 396 | Lx = f.shape[3] 397 | 398 | if mode == 'periodization' or mode == 'per': 399 | if x.shape[2] % 2 == 1: 400 | x = torch.cat((x, x[:,:,-1:]), dim=2) 401 | Ny += 1 402 | if x.shape[3] % 2 == 1: 403 | x = torch.cat((x, x[:,:,:,-1:]), dim=3) 404 | Nx += 1 405 | pad = (Ly-1, Lx-1) 406 | stride = (2, 2) 407 | x = roll(roll(x, -Ly//2, dim=2), -Lx//2, dim=3) 408 | y = F.conv2d(x, f, padding=pad, stride=stride, groups=C) 409 | y[:,:,:Ly//2] += y[:,:,Ny//2:Ny//2+Ly//2] 410 | y[:,:,:,:Lx//2] += y[:,:,:,Nx//2:Nx//2+Lx//2] 411 | y = y[:,:,:Ny//2, :Nx//2] 412 | elif mode == 'zero' or mode == 'symmetric' or mode == 'reflect': 413 | # Calculate the pad size 414 | out1 = pywt.dwt_coeff_len(Ny, Ly, mode=mode) 415 | out2 = pywt.dwt_coeff_len(Nx, Lx, mode=mode) 416 | p1 = 2 * (out1 - 1) - Ny + Ly 417 | p2 = 2 * (out2 - 1) - Nx + Lx 418 | if mode == 'zero': 419 | # Sadly, pytorch only allows for same padding before and after, if 420 | # we need to do more padding after for odd length signals, have to 421 | # prepad 422 | if p1 % 2 == 1 and p2 % 2 == 1: 423 | x = F.pad(x, (0, 1, 0, 1)) 424 | elif p1 % 2 == 1: 425 | x = F.pad(x, (0, 0, 0, 1)) 426 | elif p2 % 2 == 1: 427 | x = F.pad(x, (0, 1, 0, 0)) 428 | # Calculate the high and lowpass 429 | y = F.conv2d( 430 | x, f, padding=(p1//2, p2//2), stride=2, groups=C) 431 | elif mode == 'symmetric' or mode == 'reflect': 432 | pad = (p2//2, (p2+1)//2, p1//2, (p1+1)//2) 433 | x = mypad(x, pad=pad, mode=mode) 434 | y = F.conv2d(x, f, stride=2, groups=C) 435 | else: 436 | raise ValueError("Unkown pad type: {}".format(mode)) 437 | 438 | y = y.reshape((y.shape[0], C, 4, y.shape[-2], y.shape[-1])) 439 | yl = y[:,:,0].contiguous() 440 | yh = y[:,:,1:].contiguous() 441 | return yl, yh 442 | 443 | 444 | def cplxdual2D(x, J, level1='farras', qshift='qshift_a', mode='periodization', 445 | mag=False): 446 | """ Do a complex dtcwt 447 | 448 | Returns: 449 | lows: lowpass outputs from each of the 4 trees. Is a 2x2 list of lists 450 | w: bandpass outputs from each of the 4 trees. Is a list of lists, with 451 | shape [J][2][2][3]. Initially the 3 outputs are the lh, hl and hh from 452 | each of the 4 trees. After doing sums and differences though, they 453 | become the real and imaginary parts for the 6 orientations. In 454 | particular: 455 | first index - indexes over scales 456 | second index - 0 = real, 1 = imaginary 457 | third and fourth indices: 458 | 0,1 = 15 degrees 459 | 1,2 = 45 degrees 460 | 0,0 = 75 degrees 461 | 1,0 = 105 degrees 462 | 0,2 = 135 degrees 463 | 1,1 = 165 degrees 464 | """ 465 | x = x/2 466 | # Get the filters 467 | h0a1, h0b1, _, _, h1a1, h1b1, _, _ = _level1(level1) 468 | h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift) 469 | 470 | Faf = ((prep_filt_afb2d(h0a1, h1a1, h0a1, h1a1, device=x.device), 471 | prep_filt_afb2d(h0a1, h1a1, h0b1, h1b1, device=x.device)), 472 | (prep_filt_afb2d(h0b1, h1b1, h0a1, h1a1, device=x.device), 473 | prep_filt_afb2d(h0b1, h1b1, h0b1, h1b1, device=x.device))) 474 | af = ((prep_filt_afb2d(h0a, h1a, h0a, h1a, device=x.device), 475 | prep_filt_afb2d(h0a, h1a, h0b, h1b, device=x.device)), 476 | (prep_filt_afb2d(h0b, h1b, h0a, h1a, device=x.device), 477 | prep_filt_afb2d(h0b, h1b, h0b, h1b, device=x.device))) 478 | 479 | # Do 4 fully decimated dwts 480 | w = [[[None for _ in range(2)] for _ in range(2)] for j in range(J)] 481 | lows = [[None for _ in range(2)] for _ in range(2)] 482 | for m in range(2): 483 | for n in range(2): 484 | # Do the first level transform with the first level filters 485 | # ll, bands = afb2d(x, (Faf[m][0], Faf[m][1], Faf[n][0], Faf[n][1]), mode=mode) 486 | bands = afb2d(x, Faf[m][n], mode=mode) 487 | # Separate the low and bandpasses 488 | s = bands.shape 489 | bands = bands.reshape(s[0], -1, 4, s[-2], s[-1]) 490 | ll = bands[:,:,0].contiguous() 491 | w[0][m][n] = [bands[:,:,2], bands[:,:,1], bands[:,:,3]] 492 | 493 | # Do the second+ level transform with the second level filters 494 | for j in range(1,J): 495 | # ll, bands = afb2d(ll, (af[m][0], af[m][1], af[n][0], af[n][1]), mode=mode) 496 | bands = afb2d(ll, af[m][n], mode=mode) 497 | # Separate the low and bandpasses 498 | s = bands.shape 499 | bands = bands.reshape(s[0], -1, 4, s[-2], s[-1]) 500 | ll = bands[:,:,0].contiguous() 501 | w[j][m][n] = [bands[:,:,2], bands[:,:,1], bands[:,:,3]] 502 | lows[m][n] = ll 503 | 504 | # Convert the quads into real and imaginary parts 505 | yh = [None,] * J 506 | for j in range(J): 507 | deg75r, deg105i = pm(w[j][0][0][0], w[j][1][1][0]) 508 | deg105r, deg75i = pm(w[j][0][1][0], w[j][1][0][0]) 509 | deg15r, deg165i = pm(w[j][0][0][1], w[j][1][1][1]) 510 | deg165r, deg15i = pm(w[j][0][1][1], w[j][1][0][1]) 511 | deg135r, deg45i = pm(w[j][0][0][2], w[j][1][1][2]) 512 | deg45r, deg135i = pm(w[j][0][1][2], w[j][1][0][2]) 513 | yhr = torch.stack((deg15r, deg45r, deg75r, deg105r, deg135r, deg165r), dim=1) 514 | yhi = torch.stack((deg15i, deg45i, deg75i, deg105i, deg135i, deg165i), dim=1) 515 | if mag: 516 | yh[j] = torch.sqrt(yhr**2 + yhi**2 + 0.01) - np.sqrt(0.01) 517 | else: 518 | yh[j] = torch.stack((yhr, yhi), dim=-1) 519 | 520 | return lows, yh 521 | 522 | 523 | def icplxdual2D(yl, yh, level1='farras', qshift='qshift_a', mode='periodization'): 524 | # Get the filters 525 | _, _, g0a1, g0b1, _, _, g1a1, g1b1 = _level1(level1) 526 | _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift) 527 | 528 | dev = yl[0][0].device 529 | Faf = ((prep_filt_sfb2d(g0a1, g1a1, g0a1, g1a1, device=dev), 530 | prep_filt_sfb2d(g0a1, g1a1, g0b1, g1b1, device=dev)), 531 | (prep_filt_sfb2d(g0b1, g1b1, g0a1, g1a1, device=dev), 532 | prep_filt_sfb2d(g0b1, g1b1, g0b1, g1b1, device=dev))) 533 | af = ((prep_filt_sfb2d(g0a, g1a, g0a, g1a, device=dev), 534 | prep_filt_sfb2d(g0a, g1a, g0b, g1b, device=dev)), 535 | (prep_filt_sfb2d(g0b, g1b, g0a, g1a, device=dev), 536 | prep_filt_sfb2d(g0b, g1b, g0b, g1b, device=dev))) 537 | 538 | # Convert the highs back to subbands 539 | J = len(yh) 540 | w = [[[[None for i in range(3)] for j in range(2)] for k in range(2)] for l in range(J)] 541 | for j in range(J): 542 | w[j][0][0][0], w[j][1][1][0] = pm(yh[j][:,2,:,:,:,0], 543 | yh[j][:,3,:,:,:,1]) 544 | w[j][0][1][0], w[j][1][0][0] = pm(yh[j][:,3,:,:,:,0], 545 | yh[j][:,2,:,:,:,1]) 546 | w[j][0][0][1], w[j][1][1][1] = pm(yh[j][:,0,:,:,:,0], 547 | yh[j][:,5,:,:,:,1]) 548 | w[j][0][1][1], w[j][1][0][1] = pm(yh[j][:,5,:,:,:,0], 549 | yh[j][:,0,:,:,:,1]) 550 | w[j][0][0][2], w[j][1][1][2] = pm(yh[j][:,1,:,:,:,0], 551 | yh[j][:,4,:,:,:,1]) 552 | w[j][0][1][2], w[j][1][0][2] = pm(yh[j][:,4,:,:,:,0], 553 | yh[j][:,1,:,:,:,1]) 554 | w[j][0][0] = torch.stack(w[j][0][0], dim=2) 555 | w[j][0][1] = torch.stack(w[j][0][1], dim=2) 556 | w[j][1][0] = torch.stack(w[j][1][0], dim=2) 557 | w[j][1][1] = torch.stack(w[j][1][1], dim=2) 558 | 559 | y = None 560 | for m in range(2): 561 | for n in range(2): 562 | lo = yl[m][n] 563 | for j in range(J-1, 0, -1): 564 | lo = sfb2d(lo, w[j][m][n], af[m][n], mode=mode) 565 | lo = sfb2d(lo, w[0][m][n], Faf[m][n], mode=mode) 566 | 567 | # Add to the output 568 | if y is None: 569 | y = lo 570 | else: 571 | y = y + lo 572 | 573 | # Normalize 574 | y = y/2 575 | return y 576 | 577 | 578 | def pm(a, b): 579 | u = (a + b)/np.sqrt(2) 580 | v = (a - b)/np.sqrt(2) 581 | return u, v 582 | -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/transform2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from numpy import ndarray, sqrt 4 | 5 | from pytorch_wavelets.dtcwt.coeffs import qshift as _qshift, biort as _biort, level1 6 | from pytorch_wavelets.dtcwt.lowlevel import prep_filt 7 | from pytorch_wavelets.dtcwt.transform_funcs import FWD_J1, FWD_J2PLUS 8 | from pytorch_wavelets.dtcwt.transform_funcs import INV_J1, INV_J2PLUS 9 | from pytorch_wavelets.dtcwt.transform_funcs import get_dimensions6 10 | from pytorch_wavelets.dwt.lowlevel import mode_to_int 11 | from pytorch_wavelets.dwt.transform2d import DWTForward, DWTInverse 12 | 13 | 14 | def pm(a, b): 15 | u = (a + b)/sqrt(2) 16 | v = (a - b)/sqrt(2) 17 | return u, v 18 | 19 | 20 | class DTCWTForward(nn.Module): 21 | """ Performs a 2d DTCWT Forward decomposition of an image 22 | 23 | Args: 24 | biort (str): One of 'antonini', 'legall', 'near_sym_a', 'near_sym_b'. 25 | Specifies the first level biorthogonal wavelet filters. Can also 26 | give a two tuple for the low and highpass filters directly. 27 | qshift (str): One of 'qshift_06', 'qshift_a', 'qshift_b', 'qshift_c', 28 | 'qshift_d'. Specifies the second level quarter shift filters. Can 29 | also give a 4-tuple for the low tree a, low tree b, high tree a and 30 | high tree b filters directly. 31 | J (int): Number of levels of decomposition 32 | skip_hps (bools): List of bools of length J which specify whether or 33 | not to calculate the bandpass outputs at the given scale. 34 | skip_hps[0] is for the first scale. Can be a single bool in which 35 | case that is applied to all scales. 36 | include_scale (bool): If true, return the bandpass outputs. Can also be 37 | a list of length J specifying which lowpasses to return. I.e. if 38 | [False, True, True], the forward call will return the second and 39 | third lowpass outputs, but discard the lowpass from the first level 40 | transform. 41 | o_dim (int): Which dimension to put the orientations in 42 | ri_dim (int): which dimension to put the real and imaginary parts 43 | """ 44 | def __init__(self, biort='near_sym_a', qshift='qshift_a', 45 | J=3, skip_hps=False, include_scale=False, 46 | o_dim=2, ri_dim=-1, mode='symmetric'): 47 | super().__init__() 48 | if o_dim == ri_dim: 49 | raise ValueError("Orientations and real/imaginary parts must be " 50 | "in different dimensions.") 51 | 52 | self.biort = biort 53 | self.qshift = qshift 54 | self.J = J 55 | self.o_dim = o_dim 56 | self.ri_dim = ri_dim 57 | self.mode = mode 58 | if isinstance(biort, str): 59 | h0o, _, h1o, _ = _biort(biort) 60 | self.register_buffer('h0o', prep_filt(h0o, 1)) 61 | self.register_buffer('h1o', prep_filt(h1o, 1)) 62 | else: 63 | self.register_buffer('h0o', prep_filt(biort[0], 1)) 64 | self.register_buffer('h1o', prep_filt(biort[1], 1)) 65 | if isinstance(qshift, str): 66 | h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift) 67 | self.register_buffer('h0a', prep_filt(h0a, 1)) 68 | self.register_buffer('h0b', prep_filt(h0b, 1)) 69 | self.register_buffer('h1a', prep_filt(h1a, 1)) 70 | self.register_buffer('h1b', prep_filt(h1b, 1)) 71 | else: 72 | self.register_buffer('h0a', prep_filt(qshift[0], 1)) 73 | self.register_buffer('h0b', prep_filt(qshift[1], 1)) 74 | self.register_buffer('h1a', prep_filt(qshift[2], 1)) 75 | self.register_buffer('h1b', prep_filt(qshift[3], 1)) 76 | 77 | # Get the function to do the DTCWT 78 | if isinstance(skip_hps, (list, tuple, ndarray)): 79 | self.skip_hps = skip_hps 80 | else: 81 | self.skip_hps = [skip_hps,] * self.J 82 | if isinstance(include_scale, (list, tuple, ndarray)): 83 | self.include_scale = include_scale 84 | else: 85 | self.include_scale = [include_scale,] * self.J 86 | 87 | def forward(self, x): 88 | """ Forward Dual Tree Complex Wavelet Transform 89 | 90 | Args: 91 | x (tensor): Input to transform. Should be of shape 92 | :math:`(N, C_{in}, H_{in}, W_{in})`. 93 | 94 | Returns: 95 | (yl, yh) 96 | tuple of lowpass (yl) and bandpass (yh) coefficients. 97 | If include_scale was true, yl will be a list of lowpass 98 | coefficients, otherwise will be just the final lowpass 99 | coefficient of shape :math:`(N, C_{in}, H_{in}', W_{in}')`. Yh 100 | will be a list of the complex bandpass coefficients of shape 101 | :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar 102 | shape depending on o_dim and ri_dim 103 | 104 | Note: 105 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a 106 | DTCWT pyramid. 107 | """ 108 | scales = [x.new_zeros([]),] * self.J 109 | highs = [x.new_zeros([]),] * self.J 110 | mode = mode_to_int(self.mode) 111 | if self.J == 0: 112 | return x, None 113 | 114 | # If the row/col count of X is not divisible by 2 then we need to 115 | # extend X 116 | r, c = x.shape[2:] 117 | if r % 2 != 0: 118 | x = torch.cat((x, x[:,:,-1:]), dim=2) 119 | if c % 2 != 0: 120 | x = torch.cat((x, x[:,:,:,-1:]), dim=3) 121 | 122 | # Do the level 1 transform 123 | low, h = FWD_J1.apply(x, self.h0o, self.h1o, self.skip_hps[0], 124 | self.o_dim, self.ri_dim, mode) 125 | highs[0] = h 126 | if self.include_scale[0]: 127 | scales[0] = low 128 | 129 | for j in range(1, self.J): 130 | # Ensure the lowpass is divisible by 4 131 | r, c = low.shape[2:] 132 | if r % 4 != 0: 133 | low = torch.cat((low[:,:,0:1], low, low[:,:,-1:]), dim=2) 134 | if c % 4 != 0: 135 | low = torch.cat((low[:,:,:,0:1], low, low[:,:,:,-1:]), dim=3) 136 | 137 | low, h = FWD_J2PLUS.apply(low, self.h0a, self.h1a, self.h0b, 138 | self.h1b, self.skip_hps[j], self.o_dim, 139 | self.ri_dim, mode) 140 | highs[j] = h 141 | if self.include_scale[j]: 142 | scales[j] = low 143 | 144 | if True in self.include_scale: 145 | return scales, highs 146 | else: 147 | return low, highs 148 | 149 | 150 | class DTCWTInverse(nn.Module): 151 | """ 2d DTCWT Inverse 152 | 153 | Args: 154 | biort (str): One of 'antonini', 'legall', 'near_sym_a', 'near_sym_b'. 155 | Specifies the first level biorthogonal wavelet filters. Can also 156 | give a two tuple for the low and highpass filters directly. 157 | qshift (str): One of 'qshift_06', 'qshift_a', 'qshift_b', 'qshift_c', 158 | 'qshift_d'. Specifies the second level quarter shift filters. Can 159 | also give a 4-tuple for the low tree a, low tree b, high tree a and 160 | high tree b filters directly. 161 | J (int): Number of levels of decomposition. 162 | o_dim (int):which dimension the orientations are in 163 | ri_dim (int): which dimension to put th real and imaginary parts in 164 | """ 165 | 166 | def __init__(self, biort='near_sym_a', qshift='qshift_a', o_dim=2, 167 | ri_dim=-1, mode='symmetric'): 168 | super().__init__() 169 | self.biort = biort 170 | self.qshift = qshift 171 | self.o_dim = o_dim 172 | self.ri_dim = ri_dim 173 | self.mode = mode 174 | if isinstance(biort, str): 175 | _, g0o, _, g1o = _biort(biort) 176 | self.register_buffer('g0o', prep_filt(g0o, 1)) 177 | self.register_buffer('g1o', prep_filt(g1o, 1)) 178 | else: 179 | self.register_buffer('g0o', prep_filt(biort[0], 1)) 180 | self.register_buffer('g1o', prep_filt(biort[1], 1)) 181 | if isinstance(qshift, str): 182 | _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift) 183 | self.register_buffer('g0a', prep_filt(g0a, 1)) 184 | self.register_buffer('g0b', prep_filt(g0b, 1)) 185 | self.register_buffer('g1a', prep_filt(g1a, 1)) 186 | self.register_buffer('g1b', prep_filt(g1b, 1)) 187 | else: 188 | self.register_buffer('g0a', prep_filt(qshift[0], 1)) 189 | self.register_buffer('g0b', prep_filt(qshift[1], 1)) 190 | self.register_buffer('g1a', prep_filt(qshift[2], 1)) 191 | self.register_buffer('g1b', prep_filt(qshift[3], 1)) 192 | 193 | def forward(self, coeffs): 194 | """ 195 | Args: 196 | coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: 197 | yl is a tensor of shape :math:`(N, C_{in}, H_{in}', W_{in}')` 198 | and yh is a list of the complex bandpass coefficients of shape 199 | :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar 200 | depending on o_dim and ri_dim 201 | 202 | Returns: 203 | Reconstructed output 204 | 205 | Note: 206 | Can accept Nones or an empty tensor (torch.tensor([])) for the 207 | lowpass or bandpass inputs. In this cases, an array of zeros 208 | replaces that input. 209 | 210 | Note: 211 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a 212 | DTCWT pyramid. 213 | 214 | Note: 215 | If include_scale was true for the forward pass, you should provide 216 | only the final lowpass output here, as normal for an inverse wavelet 217 | transform. 218 | """ 219 | low, highs = coeffs 220 | J = len(highs) 221 | mode = mode_to_int(self.mode) 222 | _, _, h_dim, w_dim = get_dimensions6( 223 | self.o_dim, self.ri_dim) 224 | for j, s in zip(range(J-1, 0, -1), highs[1:][::-1]): 225 | if s is not None and s.shape != torch.Size([]): 226 | assert s.shape[self.o_dim] == 6, "Inverse transform must " \ 227 | "have input with 6 orientations" 228 | assert len(s.shape) == 6, "Bandpass inputs must have " \ 229 | "6 dimensions" 230 | assert s.shape[self.ri_dim] == 2, "Inputs must be complex " \ 231 | "with real and imaginary parts in the ri dimension" 232 | # Ensure the low and highpass are the right size 233 | r, c = low.shape[2:] 234 | r1, c1 = s.shape[h_dim], s.shape[w_dim] 235 | if r != r1 * 2: 236 | low = low[:,:,1:-1] 237 | if c != c1 * 2: 238 | low = low[:,:,:,1:-1] 239 | 240 | low = INV_J2PLUS.apply(low, s, self.g0a, self.g1a, self.g0b, 241 | self.g1b, self.o_dim, self.ri_dim, mode) 242 | 243 | # Ensure the low and highpass are the right size 244 | if highs[0] is not None and highs[0].shape != torch.Size([]): 245 | r, c = low.shape[2:] 246 | r1, c1 = highs[0].shape[h_dim], highs[0].shape[w_dim] 247 | if r != r1 * 2: 248 | low = low[:,:,1:-1] 249 | if c != c1 * 2: 250 | low = low[:,:,:,1:-1] 251 | 252 | low = INV_J1.apply(low, highs[0], self.g0o, self.g1o, self.o_dim, 253 | self.ri_dim, mode) 254 | return low 255 | 256 | 257 | -------------------------------------------------------------------------------- /pytorch_wavelets/dtcwt/transform_funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import tensor 3 | from torch.autograd import Function 4 | from pytorch_wavelets.dtcwt.lowlevel import colfilter, rowfilter 5 | from pytorch_wavelets.dtcwt.lowlevel import coldfilt, rowdfilt 6 | from pytorch_wavelets.dtcwt.lowlevel import colifilt, rowifilt, q2c, c2q 7 | from pytorch_wavelets.dwt.lowlevel import int_to_mode 8 | 9 | 10 | def get_dimensions5(o_dim, ri_dim): 11 | """ Get the orientation, height and width dimensions after the real and 12 | imaginary parts have been popped off (5 dimensional tensor).""" 13 | o_dim = (o_dim % 6) 14 | ri_dim = (ri_dim % 6) 15 | 16 | if ri_dim < o_dim: 17 | o_dim -= 1 18 | 19 | if o_dim == 4: 20 | h_dim = 2 21 | w_dim = 3 22 | elif o_dim == 3: 23 | h_dim = 2 24 | w_dim = 4 25 | else: 26 | h_dim = 3 27 | w_dim = 4 28 | 29 | return o_dim, ri_dim, h_dim, w_dim 30 | 31 | 32 | def get_dimensions6(o_dim, ri_dim): 33 | """ Get the orientation, real/imag, height and width dimensions 34 | for the full tensor (6 dimensions).""" 35 | # Calculate which dimension to put the real and imaginary parts and the 36 | # orientations. Also work out where the rows and columns in the original 37 | # image were 38 | o_dim = (o_dim % 6) 39 | ri_dim = (ri_dim % 6) 40 | 41 | if ri_dim < o_dim: 42 | o_dim -= 1 43 | 44 | if o_dim >= 3 and ri_dim >= 3: 45 | h_dim = 2 46 | elif o_dim >= 4 or ri_dim >= 4: 47 | h_dim = 3 48 | else: 49 | h_dim = 4 50 | 51 | if o_dim >= 4 and ri_dim >= 4: 52 | w_dim = 3 53 | elif o_dim >= 4 or ri_dim >= 4: 54 | w_dim = 4 55 | else: 56 | w_dim = 5 57 | 58 | return o_dim, ri_dim, h_dim, w_dim 59 | 60 | 61 | def highs_to_orientations(lh, hl, hh, o_dim): 62 | (deg15r, deg15i), (deg165r, deg165i) = q2c(lh) 63 | (deg45r, deg45i), (deg135r, deg135i) = q2c(hh) 64 | (deg75r, deg75i), (deg105r, deg105i) = q2c(hl) 65 | 66 | # Convert real and imaginary to magnitude 67 | reals = torch.stack( 68 | [deg15r, deg45r, deg75r, deg105r, deg135r, deg165r], dim=o_dim) 69 | imags = torch.stack( 70 | [deg15i, deg45i, deg75i, deg105i, deg135i, deg165i], dim=o_dim) 71 | 72 | return reals, imags 73 | 74 | 75 | def orientations_to_highs(reals, imags, o_dim): 76 | dev = reals.device 77 | horiz = torch.index_select(reals, o_dim, tensor([0, 5], device=dev)) 78 | diag = torch.index_select(reals, o_dim, tensor([1, 4], device=dev)) 79 | vertic = torch.index_select(reals, o_dim, tensor([2, 3], device=dev)) 80 | deg15r, deg165r = torch.unbind(horiz, dim=o_dim) 81 | deg45r, deg135r = torch.unbind(diag, dim=o_dim) 82 | deg75r, deg105r = torch.unbind(vertic, dim=o_dim) 83 | dev = imags.device 84 | horiz = torch.index_select(imags, o_dim, tensor([0, 5], device=dev)) 85 | diag = torch.index_select(imags, o_dim, tensor([1, 4], device=dev)) 86 | vertic = torch.index_select(imags, o_dim, tensor([2, 3], device=dev)) 87 | deg15i, deg165i = torch.unbind(horiz, dim=o_dim) 88 | deg45i, deg135i = torch.unbind(diag, dim=o_dim) 89 | deg75i, deg105i = torch.unbind(vertic, dim=o_dim) 90 | 91 | lh = c2q((deg15r, deg15i), (deg165r, deg165i)) 92 | hl = c2q((deg75r, deg75i), (deg105r, deg105i)) 93 | hh = c2q((deg45r, deg45i), (deg135r, deg135i)) 94 | 95 | return lh, hl, hh 96 | 97 | 98 | def fwd_j1(x, h0, h1, skip_hps, o_dim, mode): 99 | """ Level 1 forward dtcwt. 100 | 101 | Have it as a separate function as can be used by 102 | the forward pass of the forward transform and the backward pass of the 103 | inverse transform. 104 | """ 105 | # Level 1 forward (biorthogonal analysis filters) 106 | if not skip_hps: 107 | lo = rowfilter(x, h0, mode) 108 | hi = rowfilter(x, h1, mode) 109 | ll = colfilter(lo, h0, mode) 110 | lh = colfilter(lo, h1, mode) 111 | del lo 112 | hl = colfilter(hi, h0, mode) 113 | hh = colfilter(hi, h1, mode) 114 | del hi 115 | highr, highi = highs_to_orientations(lh, hl, hh, o_dim) 116 | else: 117 | ll = rowfilter(x, h0, mode) 118 | ll = colfilter(ll, h0, mode) 119 | highr = x.new_zeros([]) 120 | highi = x.new_zeros([]) 121 | return ll, highr, highi 122 | 123 | 124 | def fwd_j1_rot(x, h0, h1, h2, skip_hps, o_dim, mode): 125 | """ Level 1 forward dtcwt. 126 | 127 | Have it as a separate function as can be used by 128 | the forward pass of the forward transform and the backward pass of the 129 | inverse transform. 130 | """ 131 | # Level 1 forward (biorthogonal analysis filters) 132 | if not skip_hps: 133 | lo = rowfilter(x, h0, mode) 134 | hi = rowfilter(x, h1, mode) 135 | ba = rowfilter(x, h2, mode) 136 | 137 | lh = colfilter(lo, h1, mode) 138 | hl = colfilter(hi, h0, mode) 139 | hh = colfilter(ba, h2, mode) 140 | ll = colfilter(lo, h0, mode) 141 | 142 | del lo, hi, ba 143 | highr, highi = highs_to_orientations(lh, hl, hh, o_dim) 144 | else: 145 | ll = rowfilter(x, h0, mode) 146 | ll = colfilter(ll, h0, mode) 147 | highr = x.new_zeros([]) 148 | highi = x.new_zeros([]) 149 | return ll, highr, highi 150 | 151 | 152 | def inv_j1(ll, highr, highi, g0, g1, o_dim, h_dim, w_dim, mode): 153 | """ Level1 inverse dtcwt. 154 | 155 | Have it as a separate function as can be used by the forward pass of the 156 | inverse transform and the backward pass of the forward transform. 157 | """ 158 | if highr is None or highr.shape == torch.Size([]): 159 | y = rowfilter(colfilter(ll, g0), g0) 160 | else: 161 | # Get the double sampled bandpass coefficients 162 | lh, hl, hh = orientations_to_highs(highr, highi, o_dim) 163 | 164 | if ll is None or ll.shape == torch.Size([]): 165 | # Interpolate 166 | hi = colfilter(hh, g1, mode) + colfilter(hl, g0, mode) 167 | lo = colfilter(lh, g1, mode) 168 | del lh, hh, hl 169 | else: 170 | # Possibly cut back some rows to make the ll match the highs 171 | r, c = ll.shape[2:] 172 | r1, c1 = highr.shape[h_dim], highr.shape[w_dim] 173 | if r != r1 * 2: 174 | ll = ll[:,:,1:-1] 175 | if c != c1 * 2: 176 | ll = ll[:,:,:,1:-1] 177 | # Interpolate 178 | hi = colfilter(hh, g1, mode) + colfilter(hl, g0, mode) 179 | lo = colfilter(lh, g1, mode) + colfilter(ll, g0, mode) 180 | del lh, hl, hh 181 | 182 | y = rowfilter(hi, g1, mode) + rowfilter(lo, g0, mode) 183 | 184 | return y 185 | 186 | 187 | def inv_j1_rot(ll, highr, highi, g0, g1, g2, o_dim, h_dim, w_dim, mode): 188 | """ Level1 inverse dtcwt. 189 | 190 | Have it as a separate function as can be used by the forward pass of the 191 | inverse transform and the backward pass of the forward transform. 192 | """ 193 | if highr is None or highr.shape == torch.Size([]): 194 | y = rowfilter(colfilter(ll, g0), g0) 195 | else: 196 | # Get the double sampled bandpass coefficients 197 | lh, hl, hh = orientations_to_highs(highr, highi, o_dim) 198 | 199 | if ll is None or ll.shape == torch.Size([]): 200 | # Interpolate 201 | lo = colfilter(lh, g1, mode) 202 | hi = colfilter(hl, g0, mode) 203 | ba = colfilter(hh, g2, mode) 204 | del lh, hh, hl 205 | else: 206 | # Possibly cut back some rows to make the ll match the highs 207 | r, c = ll.shape[2:] 208 | r1, c1 = highr.shape[h_dim], highr.shape[w_dim] 209 | if r != r1 * 2: 210 | ll = ll[:,:,1:-1] 211 | if c != c1 * 2: 212 | ll = ll[:,:,:,1:-1] 213 | 214 | # Interpolate 215 | lo = colfilter(lh, g1, mode) + colfilter(ll, g0, mode) 216 | hi = colfilter(hl, g0, mode) 217 | ba = colfilter(hh, g2, mode) 218 | del lh, hl, hh 219 | 220 | y = rowfilter(hi, g1, mode) + rowfilter(lo, g0, mode) + \ 221 | rowfilter(ba, g2, mode) 222 | 223 | return y 224 | 225 | 226 | def fwd_j2plus(x, h0a, h1a, h0b, h1b, skip_hps, o_dim, mode): 227 | """ Level 2 plus forward dtcwt. 228 | 229 | Have it as a separate function as can be used by 230 | the forward pass of the forward transform and the backward pass of the 231 | inverse transform. 232 | """ 233 | if not skip_hps: 234 | lo = rowdfilt(x, h0b, h0a, False, mode) 235 | hi = rowdfilt(x, h1b, h1a, True, mode) 236 | 237 | ll = coldfilt(lo, h0b, h0a, False, mode) 238 | lh = coldfilt(lo, h1b, h1a, True, mode) 239 | hl = coldfilt(hi, h0b, h0a, False, mode) 240 | hh = coldfilt(hi, h1b, h1a, True, mode) 241 | del lo, hi 242 | highr, highi = highs_to_orientations(lh, hl, hh, o_dim) 243 | else: 244 | ll = rowdfilt(x, h0b, h0a, False, mode) 245 | ll = coldfilt(ll, h0b, h0a, False, mode) 246 | highr = None 247 | highi = None 248 | 249 | return ll, highr, highi 250 | 251 | 252 | def fwd_j2plus_rot(x, h0a, h1a, h0b, h1b, h2a, h2b, skip_hps, o_dim, mode): 253 | """ Level 2 plus forward dtcwt. 254 | 255 | Have it as a separate function as can be used by 256 | the forward pass of the forward transform and the backward pass of the 257 | inverse transform. 258 | """ 259 | if not skip_hps: 260 | lo = rowdfilt(x, h0b, h0a, False, mode) 261 | hi = rowdfilt(x, h1b, h1a, True, mode) 262 | ba = rowdfilt(x, h2b, h2a, True, mode) 263 | 264 | lh = coldfilt(lo, h1b, h1a, True, mode) 265 | hl = coldfilt(hi, h0b, h0a, False, mode) 266 | hh = coldfilt(ba, h2b, h2a, True, mode) 267 | ll = coldfilt(lo, h0b, h0a, False, mode) 268 | del lo, hi, ba 269 | highr, highi = highs_to_orientations(lh, hl, hh, o_dim) 270 | else: 271 | ll = rowdfilt(x, h0b, h0a, False, mode) 272 | ll = coldfilt(ll, h0b, h0a, False, mode) 273 | highr = None 274 | highi = None 275 | 276 | return ll, highr, highi 277 | 278 | 279 | def inv_j2plus(ll, highr, highi, g0a, g1a, g0b, g1b, o_dim, h_dim, w_dim, mode): 280 | """ Level2+ inverse dtcwt. 281 | 282 | Have it as a separate function as can be used by the forward pass of the 283 | inverse transform and the backward pass of the forward transform. 284 | """ 285 | if highr is None or highr.shape == torch.Size([]): 286 | y = rowifilt(colifilt(ll, g0b, g0a, False, mode), g0b, g0a, False, mode) 287 | else: 288 | # Get the double sampled bandpass coefficients 289 | lh, hl, hh = orientations_to_highs(highr, highi, o_dim) 290 | 291 | if ll is None or ll.shape == torch.Size([]): 292 | # Interpolate 293 | hi = colifilt(hh, g1b, g1a, True, mode) + \ 294 | colifilt(hl, g0b, g0a, False, mode) 295 | lo = colifilt(lh, g1b, g1a, True, mode) 296 | del lh, hh, hl 297 | else: 298 | # Interpolate 299 | hi = colifilt(hh, g1b, g1a, True, mode) + \ 300 | colifilt(hl, g0b, g0a, False, mode) 301 | lo = colifilt(lh, g1b, g1a, True, mode) + \ 302 | colifilt(ll, g0b, g0a, False, mode) 303 | del lh, hl, hh 304 | 305 | y = rowifilt(hi, g1b, g1a, True, mode) + \ 306 | rowifilt(lo, g0b, g0a, False, mode) 307 | return y 308 | 309 | 310 | def inv_j2plus_rot(ll, highr, highi, g0a, g1a, g0b, g1b, g2a, g2b, 311 | o_dim, h_dim, w_dim, mode): 312 | """ Level2+ inverse dtcwt. 313 | 314 | Have it as a separate function as can be used by the forward pass of the 315 | inverse transform and the backward pass of the forward transform. 316 | """ 317 | if highr is None or highr.shape == torch.Size([]): 318 | y = rowifilt(colifilt(ll, g0b, g0a, False, mode), g0b, g0a, False, mode) 319 | else: 320 | # Get the double sampled bandpass coefficients 321 | lh, hl, hh = orientations_to_highs(highr, highi, o_dim) 322 | 323 | if ll is None or ll.shape == torch.Size([]): 324 | # Interpolate 325 | lo = colifilt(lh, g1b, g1a, True, mode) 326 | hi = colifilt(hl, g0b, g0a, False, mode) 327 | ba = colifilt(hh, g2b, g2a, True, mode) 328 | del lh, hh, hl 329 | else: 330 | # Interpolate 331 | lo = colifilt(lh, g1b, g1a, True, mode) + \ 332 | colifilt(ll, g0b, g0a, False, mode) 333 | hi = colifilt(hl, g0b, g0a, False, mode) 334 | ba = colifilt(hh, g2b, g2a, True, mode) 335 | del lh, hl, hh 336 | 337 | y = rowifilt(hi, g1b, g1a, True, mode) + \ 338 | rowifilt(lo, g0b, g0a, False, mode) + \ 339 | rowifilt(ba, g2b, g2a, True, mode) 340 | return y 341 | 342 | 343 | class FWD_J1(Function): 344 | """ Differentiable function doing 1 level forward DTCWT """ 345 | @staticmethod 346 | def forward(ctx, x, h0, h1, skip_hps, o_dim, ri_dim, mode): 347 | mode = int_to_mode(mode) 348 | ctx.mode = mode 349 | ctx.save_for_backward(h0, h1) 350 | ctx.dims = get_dimensions5(o_dim, ri_dim) 351 | o_dim, ri_dim = ctx.dims[0], ctx.dims[1] 352 | 353 | ll, highr, highi = fwd_j1(x, h0, h1, skip_hps, o_dim, mode) 354 | if not skip_hps: 355 | highs = torch.stack((highr, highi), dim=ri_dim) 356 | else: 357 | highs = ll.new_zeros([]) 358 | return ll, highs 359 | 360 | @staticmethod 361 | def backward(ctx, dl, dh): 362 | h0, h1 = ctx.saved_tensors 363 | mode = ctx.mode 364 | dx = None 365 | if ctx.needs_input_grad[0]: 366 | o_dim, ri_dim, h_dim, w_dim = ctx.dims 367 | if dh is not None and dh.shape != torch.Size([]): 368 | dhr, dhi = torch.unbind(dh, dim=ri_dim) 369 | else: 370 | dhr = dl.new_zeros([]) 371 | dhi = dl.new_zeros([]) 372 | dx = inv_j1(dl, dhr, dhi, h0, h1, o_dim, h_dim, w_dim, mode) 373 | 374 | return dx, None, None, None, None, None, None 375 | 376 | 377 | class FWD_J2PLUS(Function): 378 | """ Differentiable function doing second level forward DTCWT """ 379 | @staticmethod 380 | def forward(ctx, x, h0a, h1a, h0b, h1b, skip_hps, o_dim, ri_dim, mode): 381 | mode = 'symmetric' 382 | ctx.mode = mode 383 | ctx.save_for_backward(h0a, h1a, h0b, h1b) 384 | ctx.dims = get_dimensions5(o_dim, ri_dim) 385 | o_dim, ri_dim = ctx.dims[0], ctx.dims[1] 386 | 387 | ll, highr, highi = fwd_j2plus(x, h0a, h1a, h0b, h1b, skip_hps, o_dim, mode) 388 | if not skip_hps: 389 | highs = torch.stack((highr, highi), dim=ri_dim) 390 | else: 391 | highs = ll.new_zeros([]) 392 | return ll, highs 393 | 394 | @staticmethod 395 | def backward(ctx, dl, dh): 396 | h0a, h1a, h0b, h1b = ctx.saved_tensors 397 | mode = ctx.mode 398 | # The colifilt and rowifilt functions use conv2d not conv2d_transpose, 399 | # so need to reverse the filters 400 | h0a, h0b = h0b, h0a 401 | h1a, h1b = h1b, h1a 402 | dx = None 403 | if ctx.needs_input_grad[0]: 404 | o_dim, ri_dim, h_dim, w_dim = ctx.dims 405 | if dh is not None and dh.shape != torch.Size([]): 406 | dhr, dhi = torch.unbind(dh, dim=ri_dim) 407 | else: 408 | dhr = dl.new_zeros([]) 409 | dhi = dl.new_zeros([]) 410 | dx = inv_j2plus(dl, dhr, dhi, h0a, h1a, h0b, h1b, 411 | o_dim, h_dim, w_dim, mode) 412 | 413 | return dx, None, None, None, None, None, None, None, None 414 | 415 | 416 | class INV_J1(Function): 417 | """ Differentiable function doing 1 level inverse DTCWT """ 418 | @staticmethod 419 | def forward(ctx, lows, highs, g0, g1, o_dim, ri_dim, mode): 420 | mode = int_to_mode(mode) 421 | ctx.mode = mode 422 | ctx.save_for_backward(g0, g1) 423 | ctx.dims = get_dimensions5(o_dim, ri_dim) 424 | o_dim, ri_dim, h_dim, w_dim = ctx.dims 425 | if highs is not None and highs.shape != torch.Size([]): 426 | highr, highi = torch.unbind(highs, dim=ri_dim) 427 | else: 428 | highr = lows.new_zeros([]) 429 | highi = lows.new_zeros([]) 430 | y = inv_j1(lows, highr, highi, g0, g1, o_dim, h_dim, w_dim, mode) 431 | return y 432 | 433 | @staticmethod 434 | def backward(ctx, dy): 435 | g0, g1 = ctx.saved_tensors 436 | dl = None 437 | dh = None 438 | o_dim, ri_dim = ctx.dims[0], ctx.dims[1] 439 | mode = ctx.mode 440 | if ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]: 441 | dl, _, _ = fwd_j1(dy, g0, g1, True, o_dim, mode) 442 | elif ctx.needs_input_grad[1] and not ctx.needs_input_grad[0]: 443 | _, dhr, dhi = fwd_j1(dy, g0, g1, False, o_dim, mode) 444 | dh = torch.stack((dhr, dhi), dim=ri_dim) 445 | elif ctx.needs_input_grad[0] and ctx.needs_input_grad[1]: 446 | dl, dhr, dhi = fwd_j1(dy, g0, g1, False, o_dim, mode) 447 | dh = torch.stack((dhr, dhi), dim=ri_dim) 448 | 449 | return dl, dh, None, None, None, None, None 450 | 451 | 452 | class INV_J2PLUS(Function): 453 | """ Differentiable function doing level 2 onwards inverse DTCWT """ 454 | @staticmethod 455 | def forward(ctx, lows, highs, g0a, g1a, g0b, g1b, o_dim, ri_dim, mode): 456 | mode = 'symmetric' 457 | ctx.mode = mode 458 | ctx.save_for_backward(g0a, g1a, g0b, g1b) 459 | ctx.dims = get_dimensions5(o_dim, ri_dim) 460 | o_dim, ri_dim, h_dim, w_dim = ctx.dims 461 | if highs is not None and highs.shape != torch.Size([]): 462 | highr, highi = torch.unbind(highs, dim=ri_dim) 463 | else: 464 | highr = lows.new_zeros([]) 465 | highi = lows.new_zeros([]) 466 | y = inv_j2plus(lows, highr, highi, g0a, g1a, g0b, g1b, 467 | o_dim, h_dim, w_dim, mode) 468 | return y 469 | 470 | @staticmethod 471 | def backward(ctx, dy): 472 | g0a, g1a, g0b, g1b = ctx.saved_tensors 473 | g0a, g0b = g0b, g0a 474 | g1a, g1b = g1b, g1a 475 | o_dim, ri_dim = ctx.dims[0], ctx.dims[1] 476 | mode = ctx.mode 477 | dl = None 478 | dh = None 479 | if ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]: 480 | dl, _, _ = fwd_j2plus(dy, g0a, g1a, g0b, g1b, True, o_dim, mode) 481 | elif ctx.needs_input_grad[1] and not ctx.needs_input_grad[0]: 482 | _, dhr, dhi = fwd_j2plus(dy, g0a, g1a, g0b, g1b, False, o_dim, mode) 483 | dh = torch.stack((dhr, dhi), dim=ri_dim) 484 | elif ctx.needs_input_grad[0] and ctx.needs_input_grad[1]: 485 | dl, dhr, dhi = fwd_j2plus(dy, g0a, g1a, g0b, g1b, False, o_dim, mode) 486 | dh = torch.stack((dhr, dhi), dim=ri_dim) 487 | 488 | return dl, dh, None, None, None, None, None, None, None 489 | -------------------------------------------------------------------------------- /pytorch_wavelets/dwt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dwt/__init__.py -------------------------------------------------------------------------------- /pytorch_wavelets/dwt/swt_inverse.py: -------------------------------------------------------------------------------- 1 | 2 | def sfb1d_atrous(lo, hi, g0, g1, mode='periodization', dim=-1, dilation=1, 3 | pad1=None, pad=None): 4 | """ 1D synthesis filter bank of an image tensor with no upsampling. Used for 5 | the stationary wavelet transform. 6 | """ 7 | C = lo.shape[1] 8 | d = dim % 4 9 | # If g0, g1 are not tensors, make them. If they are, then assume that they 10 | # are in the right order 11 | if not isinstance(g0, torch.Tensor): 12 | g0 = torch.tensor(np.copy(np.array(g0).ravel()), 13 | dtype=torch.float, device=lo.device) 14 | if not isinstance(g1, torch.Tensor): 15 | g1 = torch.tensor(np.copy(np.array(g1).ravel()), 16 | dtype=torch.float, device=lo.device) 17 | L = g0.numel() 18 | shape = [1,1,1,1] 19 | shape[d] = L 20 | # If g aren't in the right shape, make them so 21 | if g0.shape != tuple(shape): 22 | g0 = g0.reshape(*shape) 23 | if g1.shape != tuple(shape): 24 | g1 = g1.reshape(*shape) 25 | g0 = torch.cat([g0]*C,dim=0) 26 | g1 = torch.cat([g1]*C,dim=0) 27 | 28 | # Calculate the padding size. 29 | # With dilation, zeros are inserted between the filter taps but not after. 30 | # that means a filter that is [a b c d] becomes [a 0 b 0 c 0 d]. 31 | centre = L / 2 32 | fsz = (L-1)*dilation + 1 33 | newcentre = fsz / 2 34 | before = newcentre - dilation*centre 35 | 36 | # When conv_transpose2d is done, a filter with k taps expands an input with 37 | # N samples to be N + k - 1 samples. The 'padding' is really the opposite of 38 | # that, and is how many samples on the edges you want to cut out. 39 | # In addition to this, we want the input to be extended before convolving. 40 | # This means the final output size without the padding option will be 41 | # N + k - 1 + k - 1 42 | # The final thing to worry about is making sure that the output is centred. 43 | short_offset = dilation - 1 44 | centre_offset = fsz % 2 45 | a = fsz//2 46 | b = fsz//2 + (fsz + 1) % 2 47 | # a = 0 48 | # b = 0 49 | pad = (0, 0, a, b) if d == 2 else (a, b, 0, 0) 50 | lo = mypad(lo, pad=pad, mode=mode) 51 | hi = mypad(hi, pad=pad, mode=mode) 52 | unpad = (fsz - 1, 0) if d == 2 else (0, fsz - 1) 53 | unpad = (0, 0) 54 | y = F.conv_transpose2d(lo, g0, padding=unpad, groups=C, dilation=dilation) + \ 55 | F.conv_transpose2d(hi, g1, padding=unpad, groups=C, dilation=dilation) 56 | # pad = (L-1, 0) if d == 2 else (0, L-1) 57 | # y = F.conv_transpose2d(lo, g0, padding=pad, groups=C, dilation=dilation) + \ 58 | # F.conv_transpose2d(hi, g1, padding=pad, groups=C, dilation=dilation) 59 | # 60 | # 61 | # Calculate the pad size 62 | # L2 = (L * dilation)//2 63 | # # pad = (0, 0, L2, L2+dilation) if d == 2 else (L2, L2+dilation, 0, 0) 64 | # a = dilation*2 65 | # b = dilation*(L-2) 66 | # if pad1 is None: 67 | # pad1 = (0, 0, a, b) if d == 2 else (a, b, 0, 0) 68 | # print(pad1) 69 | # lo = mypad(lo, pad=pad1, mode=mode) 70 | # hi = mypad(hi, pad=pad1, mode=mode) 71 | # if pad is None: 72 | # p = (a + b + (L - 1)*dilation)//2 73 | # pad = (p, 0) if d == 2 else (0, p) 74 | # print(pad) 75 | 76 | return y/(2*dilation) 77 | 78 | 79 | def sfb2d_atrous(ll, lh, hl, hh, filts, mode='zero'): 80 | """ Does a single level 2d wavelet reconstruction of wavelet coefficients. 81 | Does separate row and column filtering by two calls to 82 | :py:func:`pytorch_wavelets.dwt.lowlevel.sfb1d` 83 | 84 | Inputs: 85 | ll (torch.Tensor): lowpass coefficients 86 | lh (torch.Tensor): horizontal coefficients 87 | hl (torch.Tensor): vertical coefficients 88 | hh (torch.Tensor): diagonal coefficients 89 | filts (list of ndarray or torch.Tensor): If a list of tensors has been 90 | given, this function assumes they are in the right form (the form 91 | returned by 92 | :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`). 93 | Otherwise, this function will prepare the filters to be of the right 94 | form by calling 95 | :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`. 96 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which 97 | padding to use. If periodization, the output size will be half the 98 | input size. Otherwise, the output size will be slightly larger than 99 | half. 100 | """ 101 | tensorize = [not isinstance(x, torch.Tensor) for x in filts] 102 | if len(filts) == 2: 103 | g0, g1 = filts 104 | if True in tensorize: 105 | g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(g0, g1) 106 | else: 107 | g0_col = g0 108 | g0_row = g0.transpose(2,3) 109 | g1_col = g1 110 | g1_row = g1.transpose(2,3) 111 | elif len(filts) == 4: 112 | if True in tensorize: 113 | g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(*filts) 114 | else: 115 | g0_col, g1_col, g0_row, g1_row = filts 116 | else: 117 | raise ValueError("Unknown form for input filts") 118 | 119 | lo = sfb1d_atrous(ll, lh, g0_col, g1_col, mode=mode, dim=2) 120 | hi = sfb1d_atrous(hl, hh, g0_col, g1_col, mode=mode, dim=2) 121 | y = sfb1d_atrous(lo, hi, g0_row, g1_row, mode=mode, dim=3) 122 | 123 | return y 124 | 125 | 126 | class SWTInverse(nn.Module): 127 | """ Performs a 2d DWT Inverse reconstruction of an image 128 | 129 | Args: 130 | wave (str or pywt.Wavelet): Which wavelet to use 131 | C: deprecated, will be removed in future 132 | """ 133 | def __init__(self, wave='db1', mode='zero', separable=True): 134 | super().__init__() 135 | if isinstance(wave, str): 136 | wave = pywt.Wavelet(wave) 137 | if isinstance(wave, pywt.Wavelet): 138 | g0_col, g1_col = wave.rec_lo, wave.rec_hi 139 | g0_row, g1_row = g0_col, g1_col 140 | else: 141 | if len(wave) == 2: 142 | g0_col, g1_col = wave[0], wave[1] 143 | g0_row, g1_row = g0_col, g1_col 144 | elif len(wave) == 4: 145 | g0_col, g1_col = wave[0], wave[1] 146 | g0_row, g1_row = wave[2], wave[3] 147 | # Prepare the filters 148 | if separable: 149 | filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row) 150 | self.register_buffer('g0_col', filts[0]) 151 | self.register_buffer('g1_col', filts[1]) 152 | self.register_buffer('g0_row', filts[2]) 153 | self.register_buffer('g1_row', filts[3]) 154 | else: 155 | filts = lowlevel.prep_filt_sfb2d_nonsep( 156 | g0_col, g1_col, g0_row, g1_row) 157 | self.register_buffer('h', filts) 158 | self.mode = mode 159 | self.separable = separable 160 | 161 | def forward(self, coeffs): 162 | """ 163 | Args: 164 | coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: 165 | yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}', 166 | W_{in}')` and yh is a list of bandpass tensors of shape 167 | :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match 168 | the format returned by DWTForward 169 | 170 | Returns: 171 | Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})` 172 | 173 | Note: 174 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly 175 | downsampled shapes of the DWT pyramid. 176 | 177 | Note: 178 | Can have None for any of the highpass scales and will treat the 179 | values as zeros (not in an efficient way though). 180 | """ 181 | yl, yh = coeffs 182 | ll = yl 183 | 184 | # Do a multilevel inverse transform 185 | for h in yh[::-1]: 186 | if h is None: 187 | h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2], 188 | ll.shape[-1], device=ll.device) 189 | 190 | # 'Unpad' added dimensions 191 | if ll.shape[-2] > h.shape[-2]: 192 | ll = ll[...,:-1,:] 193 | if ll.shape[-1] > h.shape[-1]: 194 | ll = ll[...,:-1] 195 | 196 | # Do the synthesis filter banks 197 | if self.separable: 198 | lh, hl, hh = torch.unbind(h, dim=2) 199 | filts = (self.g0_col, self.g1_col, self.g0_row, self.g1_row) 200 | ll = lowlevel.sfb2d(ll, lh, hl, hh, filts, mode=self.mode) 201 | else: 202 | c = torch.cat((ll[:,:,None], h), dim=2) 203 | ll = lowlevel.sfb2d_nonsep(c, self.h, mode=self.mode) 204 | return ll 205 | -------------------------------------------------------------------------------- /pytorch_wavelets/dwt/transform1d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import pywt 3 | import pytorch_wavelets.dwt.lowlevel as lowlevel 4 | import torch 5 | 6 | 7 | class DWT1DForward(nn.Module): 8 | """ Performs a 1d DWT Forward decomposition of an image 9 | 10 | Args: 11 | J (int): Number of levels of decomposition 12 | wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use. 13 | Can be: 14 | 1) a string to pass to pywt.Wavelet constructor 15 | 2) a pywt.Wavelet class 16 | 3) a tuple of numpy arrays (h0, h1) 17 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The 18 | padding scheme 19 | """ 20 | def __init__(self, J=1, wave='db1', mode='zero'): 21 | super().__init__() 22 | if isinstance(wave, str): 23 | wave = pywt.Wavelet(wave) 24 | if isinstance(wave, pywt.Wavelet): 25 | h0, h1 = wave.dec_lo, wave.dec_hi 26 | else: 27 | assert len(wave) == 2 28 | h0, h1 = wave[0], wave[1] 29 | 30 | # Prepare the filters - this makes them into column filters 31 | filts = lowlevel.prep_filt_afb1d(h0, h1) 32 | self.register_buffer('h0', filts[0]) 33 | self.register_buffer('h1', filts[1]) 34 | self.J = J 35 | self.mode = mode 36 | 37 | def forward(self, x): 38 | """ Forward pass of the DWT. 39 | 40 | Args: 41 | x (tensor): Input of shape :math:`(N, C_{in}, L_{in})` 42 | 43 | Returns: 44 | (yl, yh) 45 | tuple of lowpass (yl) and bandpass (yh) coefficients. 46 | yh is a list of length J with the first entry 47 | being the finest scale coefficients. 48 | """ 49 | assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)" 50 | highs = [] 51 | x0 = x 52 | mode = lowlevel.mode_to_int(self.mode) 53 | 54 | # Do a multilevel transform 55 | for j in range(self.J): 56 | x0, x1 = lowlevel.AFB1D.apply(x0, self.h0, self.h1, mode) 57 | highs.append(x1) 58 | 59 | return x0, highs 60 | 61 | 62 | class DWT1DInverse(nn.Module): 63 | """ Performs a 1d DWT Inverse reconstruction of an image 64 | 65 | Args: 66 | wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use. 67 | Can be: 68 | 1) a string to pass to pywt.Wavelet constructor 69 | 2) a pywt.Wavelet class 70 | 3) a tuple of numpy arrays (h0, h1) 71 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The 72 | padding scheme 73 | """ 74 | def __init__(self, wave='db1', mode='zero'): 75 | super().__init__() 76 | if isinstance(wave, str): 77 | wave = pywt.Wavelet(wave) 78 | if isinstance(wave, pywt.Wavelet): 79 | g0, g1 = wave.rec_lo, wave.rec_hi 80 | else: 81 | assert len(wave) == 2 82 | g0, g1 = wave[0], wave[1] 83 | 84 | # Prepare the filters 85 | filts = lowlevel.prep_filt_sfb1d(g0, g1) 86 | self.register_buffer('g0', filts[0]) 87 | self.register_buffer('g1', filts[1]) 88 | self.mode = mode 89 | 90 | def forward(self, coeffs): 91 | """ 92 | Args: 93 | coeffs (yl, yh): tuple of lowpass and bandpass coefficients, should 94 | match the format returned by DWT1DForward. 95 | 96 | Returns: 97 | Reconstructed input of shape :math:`(N, C_{in}, L_{in})` 98 | 99 | Note: 100 | Can have None for any of the highpass scales and will treat the 101 | values as zeros (not in an efficient way though). 102 | """ 103 | x0, highs = coeffs 104 | assert x0.ndim == 3, "Can only handle 3d inputs (N, C, L)" 105 | mode = lowlevel.mode_to_int(self.mode) 106 | # Do a multilevel inverse transform 107 | for x1 in highs[::-1]: 108 | if x1 is None: 109 | x1 = torch.zeros_like(x0) 110 | 111 | # 'Unpad' added signal 112 | if x0.shape[-1] > x1.shape[-1]: 113 | x0 = x0[..., :-1] 114 | x0 = lowlevel.SFB1D.apply(x0, x1, self.g0, self.g1, mode) 115 | return x0 116 | -------------------------------------------------------------------------------- /pytorch_wavelets/dwt/transform2d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import pywt 3 | import pytorch_wavelets.dwt.lowlevel as lowlevel 4 | import torch 5 | 6 | 7 | class DWTForward(nn.Module): 8 | """ Performs a 2d DWT Forward decomposition of an image 9 | 10 | Args: 11 | J (int): Number of levels of decomposition 12 | wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use. 13 | Can be: 14 | 1) a string to pass to pywt.Wavelet constructor 15 | 2) a pywt.Wavelet class 16 | 3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row) 17 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The 18 | padding scheme 19 | """ 20 | def __init__(self, J=1, wave='db1', mode='zero'): 21 | super().__init__() 22 | if isinstance(wave, str): 23 | wave = pywt.Wavelet(wave) 24 | if isinstance(wave, pywt.Wavelet): 25 | h0_col, h1_col = wave.dec_lo, wave.dec_hi 26 | h0_row, h1_row = h0_col, h1_col 27 | else: 28 | if len(wave) == 2: 29 | h0_col, h1_col = wave[0], wave[1] 30 | h0_row, h1_row = h0_col, h1_col 31 | elif len(wave) == 4: 32 | h0_col, h1_col = wave[0], wave[1] 33 | h0_row, h1_row = wave[2], wave[3] 34 | 35 | # Prepare the filters 36 | filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row) 37 | self.register_buffer('h0_col', filts[0]) 38 | self.register_buffer('h1_col', filts[1]) 39 | self.register_buffer('h0_row', filts[2]) 40 | self.register_buffer('h1_row', filts[3]) 41 | self.J = J 42 | self.mode = mode 43 | 44 | def forward(self, x): 45 | """ Forward pass of the DWT. 46 | 47 | Args: 48 | x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})` 49 | 50 | Returns: 51 | (yl, yh) 52 | tuple of lowpass (yl) and bandpass (yh) coefficients. 53 | yh is a list of length J with the first entry 54 | being the finest scale coefficients. yl has shape 55 | :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape 56 | :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new 57 | dimension in yh iterates over the LH, HL and HH coefficients. 58 | 59 | Note: 60 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly 61 | downsampled shapes of the DWT pyramid. 62 | """ 63 | yh = [] 64 | ll = x 65 | mode = lowlevel.mode_to_int(self.mode) 66 | 67 | # Do a multilevel transform 68 | for j in range(self.J): 69 | # Do 1 level of the transform 70 | ll, high = lowlevel.AFB2D.apply( 71 | ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, mode) 72 | yh.append(high) 73 | 74 | return ll, yh 75 | 76 | 77 | class DWTInverse(nn.Module): 78 | """ Performs a 2d DWT Inverse reconstruction of an image 79 | 80 | Args: 81 | wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use. 82 | Can be: 83 | 1) a string to pass to pywt.Wavelet constructor 84 | 2) a pywt.Wavelet class 85 | 3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row) 86 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The 87 | padding scheme 88 | """ 89 | def __init__(self, wave='db1', mode='zero'): 90 | super().__init__() 91 | if isinstance(wave, str): 92 | wave = pywt.Wavelet(wave) 93 | if isinstance(wave, pywt.Wavelet): 94 | g0_col, g1_col = wave.rec_lo, wave.rec_hi 95 | g0_row, g1_row = g0_col, g1_col 96 | else: 97 | if len(wave) == 2: 98 | g0_col, g1_col = wave[0], wave[1] 99 | g0_row, g1_row = g0_col, g1_col 100 | elif len(wave) == 4: 101 | g0_col, g1_col = wave[0], wave[1] 102 | g0_row, g1_row = wave[2], wave[3] 103 | # Prepare the filters 104 | filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row) 105 | self.register_buffer('g0_col', filts[0]) 106 | self.register_buffer('g1_col', filts[1]) 107 | self.register_buffer('g0_row', filts[2]) 108 | self.register_buffer('g1_row', filts[3]) 109 | self.mode = mode 110 | 111 | def forward(self, coeffs): 112 | """ 113 | Args: 114 | coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: 115 | yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}', 116 | W_{in}')` and yh is a list of bandpass tensors of shape 117 | :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match 118 | the format returned by DWTForward 119 | 120 | Returns: 121 | Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})` 122 | 123 | Note: 124 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly 125 | downsampled shapes of the DWT pyramid. 126 | 127 | Note: 128 | Can have None for any of the highpass scales and will treat the 129 | values as zeros (not in an efficient way though). 130 | """ 131 | yl, yh = coeffs 132 | ll = yl 133 | mode = lowlevel.mode_to_int(self.mode) 134 | 135 | # Do a multilevel inverse transform 136 | for h in yh[::-1]: 137 | if h is None: 138 | h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2], 139 | ll.shape[-1], device=ll.device) 140 | 141 | # 'Unpad' added dimensions 142 | if ll.shape[-2] > h.shape[-2]: 143 | ll = ll[...,:-1,:] 144 | if ll.shape[-1] > h.shape[-1]: 145 | ll = ll[...,:-1] 146 | ll = lowlevel.SFB2D.apply( 147 | ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode) 148 | return ll 149 | 150 | 151 | class SWTForward(nn.Module): 152 | """ Performs a 2d Stationary wavelet transform (or undecimated wavelet 153 | transform) of an image 154 | 155 | Args: 156 | J (int): Number of levels of decomposition 157 | wave (str or pywt.Wavelet): Which wavelet to use. Can be a string to 158 | pass to pywt.Wavelet constructor, can also be a pywt.Wavelet class, 159 | or can be a two tuple of array-like objects for the analysis low and 160 | high pass filters. 161 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The 162 | padding scheme. PyWavelets uses only periodization so we use this 163 | as our default scheme. 164 | """ 165 | def __init__(self, J=1, wave='db1', mode='periodization'): 166 | super().__init__() 167 | if isinstance(wave, str): 168 | wave = pywt.Wavelet(wave) 169 | if isinstance(wave, pywt.Wavelet): 170 | h0_col, h1_col = wave.dec_lo, wave.dec_hi 171 | h0_row, h1_row = h0_col, h1_col 172 | else: 173 | if len(wave) == 2: 174 | h0_col, h1_col = wave[0], wave[1] 175 | h0_row, h1_row = h0_col, h1_col 176 | elif len(wave) == 4: 177 | h0_col, h1_col = wave[0], wave[1] 178 | h0_row, h1_row = wave[2], wave[3] 179 | 180 | # Prepare the filters 181 | filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row) 182 | self.register_buffer('h0_col', filts[0]) 183 | self.register_buffer('h1_col', filts[1]) 184 | self.register_buffer('h0_row', filts[2]) 185 | self.register_buffer('h1_row', filts[3]) 186 | 187 | self.J = J 188 | self.mode = mode 189 | 190 | def forward(self, x): 191 | """ Forward pass of the SWT. 192 | 193 | Args: 194 | x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})` 195 | 196 | Returns: 197 | List of coefficients for each scale. Each coefficient has 198 | shape :math:`(N, C_{in}, 4, H_{in}, W_{in})` where the extra 199 | dimension stores the 4 subbands for each scale. The ordering in 200 | these 4 coefficients is: (A, H, V, D) or (ll, lh, hl, hh). 201 | """ 202 | ll = x 203 | coeffs = [] 204 | # Do a multilevel transform 205 | filts = (self.h0_col, self.h1_col, self.h0_row, self.h1_row) 206 | for j in range(self.J): 207 | # Do 1 level of the transform 208 | y = lowlevel.afb2d_atrous(ll, filts, self.mode, 2**j) 209 | coeffs.append(y) 210 | ll = y[:,:,0] 211 | 212 | return coeffs 213 | -------------------------------------------------------------------------------- /pytorch_wavelets/scatternet/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import ScatLayer, ScatLayerj2 2 | 3 | __all__ = ['ScatLayer', 'ScatLayerj2'] 4 | -------------------------------------------------------------------------------- /pytorch_wavelets/scatternet/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_wavelets.dtcwt.coeffs import biort as _biort, qshift as _qshift 4 | from pytorch_wavelets.dtcwt.lowlevel import prep_filt 5 | 6 | from .lowlevel import mode_to_int 7 | from .lowlevel import ScatLayerj1_f, ScatLayerj1_rot_f 8 | from .lowlevel import ScatLayerj2_f, ScatLayerj2_rot_f 9 | 10 | 11 | class ScatLayer(nn.Module): 12 | """ Does one order of scattering at a single scale. Can be made into a 13 | second order scatternet by stacking two of these layers. 14 | Inputs: 15 | biort (str): the biorthogonal filters to use. if 'near_sym_b_bp' will 16 | use the rotationally symmetric filters. These have 13 and 19 taps 17 | so are quite long. They also require 7 1D convolutions instead of 6. 18 | x (torch.tensor): Input of shape (N, C, H, W) 19 | mode (str): padding mode. Can be 'symmetric' or 'zero' 20 | magbias (float): the magnitude bias to use for smoothing 21 | combine_colour (bool): if true, will only have colour lowpass and have 22 | greyscale bandpass 23 | Returns: 24 | y (torch.tensor): y has the lowpass and invariant U terms stacked along 25 | the channel dimension, and so has shape (N, 7*C, H/2, W/2). Where 26 | the first C channels are the lowpass outputs, and the next 6C are 27 | the magnitude highpass outputs. 28 | """ 29 | def __init__(self, biort='near_sym_a', mode='symmetric', magbias=1e-2, 30 | combine_colour=False): 31 | super().__init__() 32 | self.biort = biort 33 | # Have to convert the string to an int as the grad checks don't work 34 | # with string inputs 35 | self.mode_str = mode 36 | self.mode = mode_to_int(mode) 37 | self.magbias = magbias 38 | self.combine_colour = combine_colour 39 | if biort == 'near_sym_b_bp': 40 | self.bandpass_diag = True 41 | h0o, _, h1o, _, h2o, _ = _biort(biort) 42 | self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False) 43 | self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False) 44 | self.h2o = torch.nn.Parameter(prep_filt(h2o, 1), False) 45 | else: 46 | self.bandpass_diag = False 47 | h0o, _, h1o, _ = _biort(biort) 48 | self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False) 49 | self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False) 50 | 51 | def forward(self, x): 52 | # Do the single scale DTCWT 53 | # If the row/col count of X is not divisible by 2 then we need to 54 | # extend X 55 | _, ch, r, c = x.shape 56 | if r % 2 != 0: 57 | x = torch.cat((x, x[:,:,-1:]), dim=2) 58 | if c % 2 != 0: 59 | x = torch.cat((x, x[:,:,:,-1:]), dim=3) 60 | 61 | if self.combine_colour: 62 | assert ch == 3 63 | 64 | if self.bandpass_diag: 65 | Z = ScatLayerj1_rot_f.apply( 66 | x, self.h0o, self.h1o, self.h2o, self.mode, self.magbias, 67 | self.combine_colour) 68 | else: 69 | Z = ScatLayerj1_f.apply( 70 | x, self.h0o, self.h1o, self.mode, self.magbias, 71 | self.combine_colour) 72 | if not self.combine_colour: 73 | b, _, c, h, w = Z.shape 74 | Z = Z.view(b, 7*c, h, w) 75 | return Z 76 | 77 | def extra_repr(self): 78 | return "biort='{}', mode='{}', magbias={}".format( 79 | self.biort, self.mode_str, self.magbias) 80 | 81 | 82 | class ScatLayerj2(nn.Module): 83 | """ Does second order scattering for two scales. Uses correct dtcwt first 84 | and second level filters compared to ScatLayer which only uses biorthogonal 85 | filters. 86 | 87 | Inputs: 88 | biort (str): the biorthogonal filters to use. if 'near_sym_b_bp' will 89 | use the rotationally symmetric filters. These have 13 and 19 taps 90 | so are quite long. They also require 7 1D convolutions instead of 6. 91 | x (torch.tensor): Input of shape (N, C, H, W) 92 | mode (str): padding mode. Can be 'symmetric' or 'zero' 93 | Returns: 94 | y (torch.tensor): y has the lowpass and invariant U terms stacked along 95 | the channel dimension, and so has shape (N, 7*C, H/2, W/2). Where 96 | the first C channels are the lowpass outputs, and the next 6C are 97 | the magnitude highpass outputs. 98 | """ 99 | def __init__(self, biort='near_sym_a', qshift='qshift_a', mode='symmetric', 100 | magbias=1e-2, combine_colour=False): 101 | super().__init__() 102 | self.biort = biort 103 | self.qshift = biort 104 | # Have to convert the string to an int as the grad checks don't work 105 | # with string inputs 106 | self.mode_str = mode 107 | self.mode = mode_to_int(mode) 108 | self.magbias = magbias 109 | self.combine_colour = combine_colour 110 | if biort == 'near_sym_b_bp': 111 | assert qshift == 'qshift_b_bp' 112 | self.bandpass_diag = True 113 | h0o, _, h1o, _, h2o, _ = _biort(biort) 114 | self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False) 115 | self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False) 116 | self.h2o = torch.nn.Parameter(prep_filt(h2o, 1), False) 117 | h0a, h0b, _, _, h1a, h1b, _, _, h2a, h2b, _, _ = _qshift('qshift_b_bp') 118 | self.h0a = torch.nn.Parameter(prep_filt(h0a, 1), False) 119 | self.h0b = torch.nn.Parameter(prep_filt(h0b, 1), False) 120 | self.h1a = torch.nn.Parameter(prep_filt(h1a, 1), False) 121 | self.h1b = torch.nn.Parameter(prep_filt(h1b, 1), False) 122 | self.h2a = torch.nn.Parameter(prep_filt(h2a, 1), False) 123 | self.h2b = torch.nn.Parameter(prep_filt(h2b, 1), False) 124 | else: 125 | self.bandpass_diag = False 126 | h0o, _, h1o, _ = _biort(biort) 127 | self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False) 128 | self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False) 129 | h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift) 130 | self.h0a = torch.nn.Parameter(prep_filt(h0a, 1), False) 131 | self.h0b = torch.nn.Parameter(prep_filt(h0b, 1), False) 132 | self.h1a = torch.nn.Parameter(prep_filt(h1a, 1), False) 133 | self.h1b = torch.nn.Parameter(prep_filt(h1b, 1), False) 134 | 135 | def forward(self, x): 136 | # Ensure the input size is divisible by 8 137 | ch, r, c = x.shape[1:] 138 | rem = r % 8 139 | if rem != 0: 140 | rows_after = (9-rem)//2 141 | rows_before = (8-rem) // 2 142 | x = torch.cat((x[:,:,:rows_before], x, 143 | x[:,:,-rows_after:]), dim=2) 144 | rem = c % 8 145 | if rem != 0: 146 | cols_after = (9-rem)//2 147 | cols_before = (8-rem) // 2 148 | x = torch.cat((x[:,:,:,:cols_before], x, 149 | x[:,:,:,-cols_after:]), dim=3) 150 | 151 | if self.combine_colour: 152 | assert ch == 3 153 | 154 | if self.bandpass_diag: 155 | pass 156 | Z = ScatLayerj2_rot_f.apply( 157 | x, self.h0o, self.h1o, self.h2o, self.h0a, self.h0b, self.h1a, 158 | self.h1b, self.h2a, self.h2b, self.mode, self.magbias, 159 | self.combine_colour) 160 | else: 161 | Z = ScatLayerj2_f.apply( 162 | x, self.h0o, self.h1o, self.h0a, self.h0b, self.h1a, 163 | self.h1b, self.mode, self.magbias, self.combine_colour) 164 | 165 | if not self.combine_colour: 166 | b, _, c, h, w = Z.shape 167 | Z = Z.view(b, 49*c, h, w) 168 | return Z 169 | 170 | def extra_repr(self): 171 | return "biort='{}', mode='{}', magbias={}".format( 172 | self.biort, self.mode_str, self.magbias) 173 | -------------------------------------------------------------------------------- /pytorch_wavelets/scatternet/lowlevel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from pytorch_wavelets.dtcwt.transform_funcs import fwd_j1, inv_j1 6 | from pytorch_wavelets.dtcwt.transform_funcs import fwd_j1_rot, inv_j1_rot 7 | from pytorch_wavelets.dtcwt.transform_funcs import fwd_j2plus, inv_j2plus 8 | from pytorch_wavelets.dtcwt.transform_funcs import fwd_j2plus_rot, inv_j2plus_rot 9 | 10 | 11 | def mode_to_int(mode): 12 | if mode == 'zero': 13 | return 0 14 | elif mode == 'symmetric': 15 | return 1 16 | elif mode == 'per' or mode == 'periodization': 17 | return 2 18 | elif mode == 'constant': 19 | return 3 20 | elif mode == 'reflect': 21 | return 4 22 | elif mode == 'replicate': 23 | return 5 24 | elif mode == 'periodic': 25 | return 6 26 | else: 27 | raise ValueError("Unkown pad type: {}".format(mode)) 28 | 29 | 30 | def int_to_mode(mode): 31 | if mode == 0: 32 | return 'zero' 33 | elif mode == 1: 34 | return 'symmetric' 35 | elif mode == 2: 36 | return 'periodization' 37 | elif mode == 3: 38 | return 'constant' 39 | elif mode == 4: 40 | return 'reflect' 41 | elif mode == 5: 42 | return 'replicate' 43 | elif mode == 6: 44 | return 'periodic' 45 | else: 46 | raise ValueError("Unkown pad type: {}".format(mode)) 47 | 48 | 49 | class SmoothMagFn(torch.autograd.Function): 50 | """ Class to do complex magnitude """ 51 | @staticmethod 52 | def forward(ctx, x, y, b): 53 | r = torch.sqrt(x**2 + y**2 + b**2) 54 | if x.requires_grad: 55 | dx = x/r 56 | dy = y/r 57 | ctx.save_for_backward(dx, dy) 58 | 59 | return r - b 60 | 61 | @staticmethod 62 | def backward(ctx, dr): 63 | dx = None 64 | if ctx.needs_input_grad[0]: 65 | drdx, drdy = ctx.saved_tensors 66 | dx = drdx * dr 67 | dy = drdy * dr 68 | return dx, dy, None 69 | 70 | 71 | class ScatLayerj1_f(torch.autograd.Function): 72 | """ Function to do forward and backward passes of a single scattering 73 | layer with the DTCWT biorthogonal filters. """ 74 | 75 | @staticmethod 76 | def forward(ctx, x, h0o, h1o, mode, bias, combine_colour): 77 | # bias = 1e-2 78 | # bias = 0 79 | ctx.in_shape = x.shape 80 | batch, ch, r, c = x.shape 81 | assert r % 2 == c % 2 == 0 82 | mode = int_to_mode(mode) 83 | ctx.mode = mode 84 | ctx.combine_colour = combine_colour 85 | 86 | ll, reals, imags = fwd_j1(x, h0o, h1o, False, 1, mode) 87 | ll = F.avg_pool2d(ll, 2) 88 | if combine_colour: 89 | r = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 + 90 | reals[:,:,1]**2 + imags[:,:,1]**2 + 91 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2) 92 | r = r[:, :, None] 93 | else: 94 | r = torch.sqrt(reals**2 + imags**2 + bias**2) 95 | 96 | if x.requires_grad: 97 | drdx = reals/r 98 | drdy = imags/r 99 | ctx.save_for_backward(h0o, h1o, drdx, drdy) 100 | else: 101 | z = x.new_zeros(1) 102 | ctx.save_for_backward(h0o, h1o, z, z) 103 | 104 | r = r - bias 105 | del reals, imags 106 | if combine_colour: 107 | Z = torch.cat((ll, r[:, :, 0]), dim=1) 108 | else: 109 | Z = torch.cat((ll[:, None], r), dim=1) 110 | 111 | return Z 112 | 113 | @staticmethod 114 | def backward(ctx, dZ): 115 | dX = None 116 | mode = ctx.mode 117 | 118 | if ctx.needs_input_grad[0]: 119 | # h0o, h1o, θ = ctx.saved_tensors 120 | h0o, h1o, drdx, drdy = ctx.saved_tensors 121 | # Use the special properties of the filters to get the time reverse 122 | h0o_t = h0o 123 | h1o_t = h1o 124 | 125 | # Level 1 backward (time reversed biorthogonal analysis filters) 126 | if ctx.combine_colour: 127 | dYl, dr = dZ[:,:3], dZ[:,3:] 128 | dr = dr[:, :, None] 129 | else: 130 | dYl, dr = dZ[:,0], dZ[:,1:] 131 | ll = 1/4 * F.interpolate(dYl, scale_factor=2, mode="nearest") 132 | reals = dr * drdx 133 | imags = dr * drdy 134 | 135 | dX = inv_j1(ll, reals, imags, h0o_t, h1o_t, 1, 3, 4, mode) 136 | 137 | return (dX,) + (None,) * 5 138 | 139 | 140 | class ScatLayerj1_rot_f(torch.autograd.Function): 141 | """ Function to do forward and backward passes of a single scattering 142 | layer with the DTCWT biorthogonal filters. Uses the rotationally symmetric 143 | filters, i.e. a slightly more expensive operation.""" 144 | 145 | @staticmethod 146 | def forward(ctx, x, h0o, h1o, h2o, mode, bias, combine_colour): 147 | mode = int_to_mode(mode) 148 | ctx.mode = mode 149 | # bias = 0 150 | ctx.in_shape = x.shape 151 | ctx.combine_colour = combine_colour 152 | batch, ch, r, c = x.shape 153 | assert r % 2 == c % 2 == 0 154 | 155 | # Level 1 forward (biorthogonal analysis filters) 156 | ll, reals, imags = fwd_j1_rot(x, h0o, h1o, h2o, False, 1, mode) 157 | ll = F.avg_pool2d(ll, 2) 158 | if combine_colour: 159 | r = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 + 160 | reals[:,:,1]**2 + imags[:,:,1]**2 + 161 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2) 162 | r = r[:, :, None] 163 | else: 164 | r = torch.sqrt(reals**2 + imags**2 + bias**2) 165 | if x.requires_grad: 166 | drdx = reals/r 167 | drdy = imags/r 168 | ctx.save_for_backward(h0o, h1o, h2o, drdx, drdy) 169 | else: 170 | z = x.new_zeros(1) 171 | ctx.save_for_backward(h0o, h1o, h2o, z, z) 172 | r = r - bias 173 | del reals, imags 174 | if combine_colour: 175 | Z = torch.cat((ll, r[:, :, 0]), dim=1) 176 | else: 177 | Z = torch.cat((ll[:, None], r), dim=1) 178 | 179 | return Z 180 | 181 | @staticmethod 182 | def backward(ctx, dZ): 183 | dX = None 184 | mode = ctx.mode 185 | 186 | if ctx.needs_input_grad[0]: 187 | # Don't need to do time reverse as these filters are symmetric 188 | # h0o, h1o, h2o, θ = ctx.saved_tensors 189 | h0o, h1o, h2o, drdx, drdy = ctx.saved_tensors 190 | 191 | # Level 1 backward (time reversed biorthogonal analysis filters) 192 | if ctx.combine_colour: 193 | dYl, dr = dZ[:,:3], dZ[:,3:] 194 | dr = dr[:, :, None] 195 | else: 196 | dYl, dr = dZ[:,0], dZ[:,1:] 197 | ll = 1/4 * F.interpolate(dYl, scale_factor=2, mode="nearest") 198 | 199 | reals = dr * drdx 200 | imags = dr * drdy 201 | dX = inv_j1_rot(ll, reals, imags, h0o, h1o, h2o, 1, 3, 4, mode) 202 | 203 | return (dX,) + (None,) * 6 204 | 205 | 206 | class ScatLayerj2_f(torch.autograd.Function): 207 | """ Function to do forward and backward passes of a single scattering 208 | layer with the DTCWT biorthogonal filters. """ 209 | 210 | @staticmethod 211 | def forward(ctx, x, h0o, h1o, h0a, h0b, h1a, h1b, mode, bias, combine_colour): 212 | # bias = 1e-2 213 | # bias = 0 214 | ctx.in_shape = x.shape 215 | batch, ch, r, c = x.shape 216 | assert r % 8 == c % 8 == 0 217 | mode = int_to_mode(mode) 218 | ctx.mode = mode 219 | ctx.combine_colour = combine_colour 220 | 221 | # First order scattering 222 | s0, reals, imags = fwd_j1(x, h0o, h1o, False, 1, mode) 223 | if combine_colour: 224 | s1_j1 = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 + 225 | reals[:,:,1]**2 + imags[:,:,1]**2 + 226 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2) 227 | s1_j1 = s1_j1[:, :, None] 228 | if x.requires_grad: 229 | dsdx1 = reals/s1_j1 230 | dsdy1 = imags/s1_j1 231 | s1_j1 = s1_j1 - bias 232 | 233 | s0, reals, imags = fwd_j2plus(s0, h0a, h1a, h0b, h1b, False, 1, mode) 234 | s1_j2 = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 + 235 | reals[:,:,1]**2 + imags[:,:,1]**2 + 236 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2) 237 | s1_j2 = s1_j2[:, :, None] 238 | if x.requires_grad: 239 | dsdx2 = reals/s1_j2 240 | dsdy2 = imags/s1_j2 241 | s1_j2 = s1_j2 - bias 242 | s0 = F.avg_pool2d(s0, 2) 243 | 244 | # Second order scattering 245 | s1_j1 = s1_j1[:, :, 0] 246 | s1_j1, reals, imags = fwd_j1(s1_j1, h0o, h1o, False, 1, mode) 247 | s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) 248 | if x.requires_grad: 249 | dsdx2_1 = reals/s2_j1 250 | dsdy2_1 = imags/s2_j1 251 | q = s2_j1.shape 252 | s2_j1 = s2_j1.view(q[0], 36, q[3], q[4]) 253 | s2_j1 = s2_j1 - bias 254 | s1_j1 = F.avg_pool2d(s1_j1, 2) 255 | if x.requires_grad: 256 | ctx.save_for_backward(h0o, h1o, h0a, h0b, h1a, h1b, 257 | dsdx1, dsdy1, dsdx2, dsdy2, 258 | dsdx2_1, dsdy2_1) 259 | else: 260 | z = x.new_zeros(1) 261 | ctx.save_for_backward(h0o, h1o, h0a, h0b, h1a, h1b, 262 | z, z, z, z, z, z) 263 | 264 | del reals, imags 265 | Z = torch.cat((s0, s1_j1, s1_j2[:,:,0], s2_j1), dim=1) 266 | 267 | else: 268 | s1_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) 269 | if x.requires_grad: 270 | dsdx1 = reals/s1_j1 271 | dsdy1 = imags/s1_j1 272 | s1_j1 = s1_j1 - bias 273 | 274 | s0, reals, imags = fwd_j2plus(s0, h0a, h1a, h0b, h1b, False, 1, mode) 275 | s1_j2 = torch.sqrt(reals**2 + imags**2 + bias**2) 276 | if x.requires_grad: 277 | dsdx2 = reals/s1_j2 278 | dsdy2 = imags/s1_j2 279 | s1_j2 = s1_j2 - bias 280 | s0 = F.avg_pool2d(s0, 2) 281 | 282 | # Second order scattering 283 | p = s1_j1.shape 284 | s1_j1 = s1_j1.view(p[0], 6*p[2], p[3], p[4]) 285 | 286 | s1_j1, reals, imags = fwd_j1(s1_j1, h0o, h1o, False, 1, mode) 287 | s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) 288 | if x.requires_grad: 289 | dsdx2_1 = reals/s2_j1 290 | dsdy2_1 = imags/s2_j1 291 | q = s2_j1.shape 292 | s2_j1 = s2_j1.view(q[0], 36, q[2]//6, q[3], q[4]) 293 | s2_j1 = s2_j1 - bias 294 | s1_j1 = F.avg_pool2d(s1_j1, 2) 295 | s1_j1 = s1_j1.view(p[0], 6, p[2], p[3]//2, p[4]//2) 296 | 297 | if x.requires_grad: 298 | ctx.save_for_backward(h0o, h1o, h0a, h0b, h1a, h1b, 299 | dsdx1, dsdy1, dsdx2, dsdy2, 300 | dsdx2_1, dsdy2_1) 301 | else: 302 | z = x.new_zeros(1) 303 | ctx.save_for_backward(h0o, h1o, h0a, h0b, h1a, h1b, 304 | z, z, z, z, z, z) 305 | 306 | del reals, imags 307 | Z = torch.cat((s0[:, None], s1_j1, s1_j2, s2_j1), dim=1) 308 | 309 | return Z 310 | 311 | @staticmethod 312 | def backward(ctx, dZ): 313 | dX = None 314 | mode = ctx.mode 315 | 316 | if ctx.needs_input_grad[0]: 317 | # Input has shape N, L, C, H, W 318 | o_dim = 1 319 | h_dim = 3 320 | w_dim = 4 321 | 322 | # Retrieve phase info 323 | (h0o, h1o, h0a, h0b, h1a, h1b, dsdx1, dsdy1, dsdx2, dsdy2, dsdx2_1, 324 | dsdy2_1) = ctx.saved_tensors 325 | 326 | # Use the special properties of the filters to get the time reverse 327 | h0o_t = h0o 328 | h1o_t = h1o 329 | h0a_t = h0b 330 | h0b_t = h0a 331 | h1a_t = h1b 332 | h1b_t = h1a 333 | 334 | # Level 1 backward (time reversed biorthogonal analysis filters) 335 | if ctx.combine_colour: 336 | ds0, ds1_j1, ds1_j2, ds2_j1 = \ 337 | dZ[:,:3], dZ[:,3:9], dZ[:,9:15], dZ[:,15:] 338 | ds1_j2 = ds1_j2[:, :, None] 339 | 340 | ds1_j1 = 1/4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest") 341 | q = ds2_j1.shape 342 | ds2_j1 = ds2_j1.view(q[0], 6, 6, q[2], q[3]) 343 | 344 | # Inverse second order scattering 345 | reals = ds2_j1 * dsdx2_1 346 | imags = ds2_j1 * dsdy2_1 347 | ds1_j1 = inv_j1( 348 | ds1_j1, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode) 349 | ds1_j1 = ds1_j1[:, :, None] 350 | 351 | # Inverse first order scattering j=2 352 | ds0 = 1/4 * F.interpolate(ds0, scale_factor=2, mode="nearest") 353 | # s = ds1_j2.shape 354 | # ds1_j2 = ds1_j2.view(s[0], 6, s[1]//6, s[2], s[3]) 355 | reals = ds1_j2 * dsdx2 356 | imags = ds1_j2 * dsdy2 357 | ds0 = inv_j2plus( 358 | ds0, reals, imags, h0a_t, h1a_t, h0b_t, h1b_t, 359 | o_dim, h_dim, w_dim, mode) 360 | 361 | # Inverse first order scattering j=1 362 | reals = ds1_j1 * dsdx1 363 | imags = ds1_j1 * dsdy1 364 | dX = inv_j1( 365 | ds0, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode) 366 | else: 367 | ds0, ds1_j1, ds1_j2, ds2_j1 = \ 368 | dZ[:,0], dZ[:,1:7], dZ[:,7:13], dZ[:,13:] 369 | p = ds1_j1.shape 370 | ds1_j1 = ds1_j1.view(p[0], p[2]*6, p[3], p[4]) 371 | ds1_j1 = 1/4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest") 372 | q = ds2_j1.shape 373 | ds2_j1 = ds2_j1.view(q[0], 6, q[2]*6, q[3], q[4]) 374 | 375 | # Inverse second order scattering 376 | reals = ds2_j1 * dsdx2_1 377 | imags = ds2_j1 * dsdy2_1 378 | ds1_j1 = inv_j1( 379 | ds1_j1, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode) 380 | ds1_j1 = ds1_j1.view(p[0], 6, p[2], p[3]*2, p[4]*2) 381 | 382 | # Inverse first order scattering j=2 383 | ds0 = 1/4 * F.interpolate(ds0, scale_factor=2, mode="nearest") 384 | # s = ds1_j2.shape 385 | # ds1_j2 = ds1_j2.view(s[0], 6, s[1]//6, s[2], s[3]) 386 | reals = ds1_j2 * dsdx2 387 | imags = ds1_j2 * dsdy2 388 | ds0 = inv_j2plus( 389 | ds0, reals, imags, h0a_t, h1a_t, h0b_t, h1b_t, 390 | o_dim, h_dim, w_dim, mode) 391 | 392 | # Inverse first order scattering j=1 393 | reals = ds1_j1 * dsdx1 394 | imags = ds1_j1 * dsdy1 395 | dX = inv_j1( 396 | ds0, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode) 397 | 398 | return (dX,) + (None,) * 9 399 | 400 | 401 | class ScatLayerj2_rot_f(torch.autograd.Function): 402 | """ Function to do forward and backward passes of a single scattering 403 | layer with the DTCWT bandpass biorthogonal and qshift filters . """ 404 | 405 | @staticmethod 406 | def forward(ctx, x, h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, mode, bias, combine_colour): 407 | # bias = 1e-2 408 | # bias = 0 409 | ctx.in_shape = x.shape 410 | batch, ch, r, c = x.shape 411 | assert r % 8 == c % 8 == 0 412 | mode = int_to_mode(mode) 413 | ctx.mode = mode 414 | ctx.combine_colour = combine_colour 415 | 416 | # First order scattering 417 | s0, reals, imags = fwd_j1_rot(x, h0o, h1o, h2o, False, 1, mode) 418 | if combine_colour: 419 | s1_j1 = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 + 420 | reals[:,:,1]**2 + imags[:,:,1]**2 + 421 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2) 422 | s1_j1 = s1_j1[:, :, None] 423 | if x.requires_grad: 424 | dsdx1 = reals/s1_j1 425 | dsdy1 = imags/s1_j1 426 | s1_j1 = s1_j1 - bias 427 | 428 | s0, reals, imags = fwd_j2plus_rot(s0, h0a, h1a, h0b, h1b, h2a, h2b, False, 1, mode) 429 | s1_j2 = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 + 430 | reals[:,:,1]**2 + imags[:,:,1]**2 + 431 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2) 432 | s1_j2 = s1_j2[:, :, None] 433 | if x.requires_grad: 434 | dsdx2 = reals/s1_j2 435 | dsdy2 = imags/s1_j2 436 | s1_j2 = s1_j2 - bias 437 | s0 = F.avg_pool2d(s0, 2) 438 | 439 | # Second order scattering 440 | s1_j1 = s1_j1[:, :, 0] 441 | s1_j1, reals, imags = fwd_j1_rot(s1_j1, h0o, h1o, h2o, False, 1, mode) 442 | s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) 443 | if x.requires_grad: 444 | dsdx2_1 = reals/s2_j1 445 | dsdy2_1 = imags/s2_j1 446 | q = s2_j1.shape 447 | s2_j1 = s2_j1.view(q[0], 36, q[3], q[4]) 448 | s2_j1 = s2_j1 - bias 449 | s1_j1 = F.avg_pool2d(s1_j1, 2) 450 | if x.requires_grad: 451 | ctx.save_for_backward(h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, 452 | dsdx1, dsdy1, dsdx2, dsdy2, dsdx2_1, 453 | dsdy2_1) 454 | else: 455 | z = x.new_zeros(1) 456 | ctx.save_for_backward(h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, 457 | z, z, z, z, z, z) 458 | 459 | del reals, imags 460 | Z = torch.cat((s0, s1_j1, s1_j2[:, :, 0], s2_j1), dim=1) 461 | else: 462 | s1_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) 463 | if x.requires_grad: 464 | dsdx1 = reals/s1_j1 465 | dsdy1 = imags/s1_j1 466 | s1_j1 = s1_j1 - bias 467 | 468 | s0, reals, imags = fwd_j2plus_rot(s0, h0a, h1a, h0b, h1b, h2a, h2b, False, 1, mode) 469 | s1_j2 = torch.sqrt(reals**2 + imags**2 + bias**2) 470 | if x.requires_grad: 471 | dsdx2 = reals/s1_j2 472 | dsdy2 = imags/s1_j2 473 | s1_j2 = s1_j2 - bias 474 | s0 = F.avg_pool2d(s0, 2) 475 | 476 | # Second order scattering 477 | p = s1_j1.shape 478 | s1_j1 = s1_j1.view(p[0], 6*p[2], p[3], p[4]) 479 | s1_j1, reals, imags = fwd_j1_rot(s1_j1, h0o, h1o, h2o, False, 1, mode) 480 | s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) 481 | if x.requires_grad: 482 | dsdx2_1 = reals/s2_j1 483 | dsdy2_1 = imags/s2_j1 484 | q = s2_j1.shape 485 | s2_j1 = s2_j1.view(q[0], 36, q[2]//6, q[3], q[4]) 486 | s2_j1 = s2_j1 - bias 487 | s1_j1 = F.avg_pool2d(s1_j1, 2) 488 | s1_j1 = s1_j1.view(p[0], 6, p[2], p[3]//2, p[4]//2) 489 | 490 | if x.requires_grad: 491 | ctx.save_for_backward(h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, 492 | dsdx1, dsdy1, dsdx2, dsdy2, dsdx2_1, 493 | dsdy2_1) 494 | else: 495 | z = x.new_zeros(1) 496 | ctx.save_for_backward(h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, 497 | z, z, z, z, z, z) 498 | 499 | del reals, imags 500 | Z = torch.cat((s0[:, None], s1_j1, s1_j2, s2_j1), dim=1) 501 | 502 | return Z 503 | 504 | @staticmethod 505 | def backward(ctx, dZ): 506 | dX = None 507 | mode = ctx.mode 508 | 509 | if ctx.needs_input_grad[0]: 510 | # Input has shape N, L, C, H, W 511 | o_dim = 1 512 | h_dim = 3 513 | w_dim = 4 514 | 515 | # Retrieve phase info 516 | (h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, dsdx1, dsdy1, dsdx2, 517 | dsdy2, dsdx2_1, dsdy2_1) = ctx.saved_tensors 518 | 519 | # Use the special properties of the filters to get the time reverse 520 | h0o_t = h0o 521 | h1o_t = h1o 522 | h2o_t = h2o 523 | h0a_t = h0b 524 | h0b_t = h0a 525 | h1a_t = h1b 526 | h1b_t = h1a 527 | h2a_t = h2b 528 | h2b_t = h2a 529 | 530 | # Level 1 backward (time reversed biorthogonal analysis filters) 531 | if ctx.combine_colour: 532 | ds0, ds1_j1, ds1_j2, ds2_j1 = \ 533 | dZ[:,:3], dZ[:,3:9], dZ[:,9:15], dZ[:,15:] 534 | ds1_j2 = ds1_j2[:, :, None] 535 | 536 | # Inverse second order scattering 537 | ds1_j1 = 1/4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest") 538 | q = ds2_j1.shape 539 | ds2_j1 = ds2_j1.view(q[0], 6, 6, q[2], q[3]) 540 | 541 | # Inverse second order scattering 542 | reals = ds2_j1 * dsdx2_1 543 | imags = ds2_j1 * dsdy2_1 544 | ds1_j1 = inv_j1_rot( 545 | ds1_j1, reals, imags, h0o_t, h1o_t, h2o_t, 546 | o_dim, h_dim, w_dim, mode) 547 | ds1_j1 = ds1_j1[:, :, None] 548 | 549 | # Inverse first order scattering j=2 550 | ds0 = 1/4 * F.interpolate(ds0, scale_factor=2, mode="nearest") 551 | # s = ds1_j2.shape 552 | # ds1_j2 = ds1_j2.view(s[0], 6, s[1]//6, s[2], s[3]) 553 | reals = ds1_j2 * dsdx2 554 | imags = ds1_j2 * dsdy2 555 | ds0 = inv_j2plus_rot( 556 | ds0, reals, imags, h0a_t, h1a_t, h0b_t, h1b_t, h2a_t, h2b_t, 557 | o_dim, h_dim, w_dim, mode) 558 | 559 | # Inverse first order scattering j=1 560 | reals = ds1_j1 * dsdx1 561 | imags = ds1_j1 * dsdy1 562 | dX = inv_j1_rot( 563 | ds0, reals, imags, h0o_t, h1o_t, h2o_t, 564 | o_dim, h_dim, w_dim, mode) 565 | else: 566 | ds0, ds1_j1, ds1_j2, ds2_j1 = \ 567 | dZ[:,0], dZ[:,1:7], dZ[:,7:13], dZ[:,13:] 568 | 569 | # Inverse second order scattering 570 | p = ds1_j1.shape 571 | ds1_j1 = ds1_j1.view(p[0], p[2]*6, p[3], p[4]) 572 | ds1_j1 = 1/4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest") 573 | q = ds2_j1.shape 574 | ds2_j1 = ds2_j1.view(q[0], 6, q[2]*6, q[3], q[4]) 575 | reals = ds2_j1 * dsdx2_1 576 | imags = ds2_j1 * dsdy2_1 577 | ds1_j1 = inv_j1_rot( 578 | ds1_j1, reals, imags, h0o_t, h1o_t, h2o_t, 579 | o_dim, h_dim, w_dim, mode) 580 | ds1_j1 = ds1_j1.view(p[0], 6, p[2], p[3]*2, p[4]*2) 581 | 582 | # Inverse first order scattering j=2 583 | ds0 = 1/4 * F.interpolate(ds0, scale_factor=2, mode="nearest") 584 | # s = ds1_j2.shape 585 | # ds1_j2 = ds1_j2.view(s[0], 6, s[1]//6, s[2], s[3]) 586 | reals = ds1_j2 * dsdx2 587 | imags = ds1_j2 * dsdy2 588 | ds0 = inv_j2plus_rot( 589 | ds0, reals, imags, h0a_t, h1a_t, h0b_t, h1b_t, h2a_t, h2b_t, 590 | o_dim, h_dim, w_dim, mode) 591 | 592 | # Inverse first order scattering j=1 593 | reals = ds1_j1 * dsdx1 594 | imags = ds1_j1 * dsdy1 595 | dX = inv_j1_rot( 596 | ds0, reals, imags, h0o_t, h1o_t, h2o_t, 597 | o_dim, h_dim, w_dim, mode) 598 | 599 | return (dX,) + (None,) * 12 600 | -------------------------------------------------------------------------------- /pytorch_wavelets/utils.py: -------------------------------------------------------------------------------- 1 | """ Useful utilities for testing the 2-D DTCWT with synthetic images""" 2 | 3 | from __future__ import absolute_import 4 | 5 | import functools 6 | import numpy as np 7 | 8 | 9 | def unpack(pyramid, backend='numpy'): 10 | """ Unpacks a pyramid give back the constituent parts. 11 | 12 | :param pyramid: The Pyramid of DTCWT transforms you wish to unpack 13 | :param str backend: A string from 'numpy', 'opencl', or 'tf' indicating 14 | which attributes you want to unpack from the pyramid. 15 | 16 | :returns: returns a generator which can be unpacked into the Yl, Yh and 17 | Yscale components of the pyramid. The generator will only return 2 18 | values if the pyramid was created with the include_scale parameter set 19 | to false. 20 | 21 | .. note:: 22 | 23 | You can still unpack a tf or opencl pyramid as if it were created by a 24 | numpy. In this case it will return a numpy array, rather than the 25 | backend specific array type. 26 | """ 27 | backend = backend.lower() 28 | if backend == 'numpy': 29 | yield pyramid.lowpass 30 | yield pyramid.highpasses 31 | if pyramid.scales is not None: 32 | yield pyramid.scales 33 | elif backend == 'opencl': 34 | yield pyramid.cl_lowpass 35 | yield pyramid.cl_highpasses 36 | if pyramid.cl_scales is not None: 37 | yield pyramid.cl_scales 38 | elif backend == 'tf': 39 | yield pyramid.lowpass_op 40 | yield pyramid.highpasses_ops 41 | if pyramid.scales_ops is not None: 42 | yield pyramid.scales_ops 43 | 44 | 45 | def drawedge(theta,r,w,N): 46 | """Generate an image of size N * N pels, of an edge going from 0 to 1 in 47 | height at theta degrees to the horizontal (top of image = 1 if angle = 0). 48 | r is a two-element vector, it is a coordinate in ij coords through which the 49 | step should pass. 50 | The shape of the intensity step is half a raised cosine w pels wide (w>=1). 51 | 52 | T. E . Gale's enhancement to drawedge() for MATLAB, transliterated 53 | to Python by S. C. Forshaw, Nov. 2013. """ 54 | 55 | # convert theta from degrees to radians 56 | thetar = np.array(theta * np.pi / 180) 57 | 58 | # Calculate image centre from given width 59 | imCentre = (np.array([N,N]).T - 1) / 2 + 1 60 | 61 | # Calculate values to subtract from the plane 62 | r = np.array([np.cos(thetar), np.sin(thetar)])*(-1) * (r - imCentre) 63 | 64 | # check width of raised cosine section 65 | w = np.maximum(1,w) 66 | 67 | ramp = np.arange(0,N) - (N+1)/2 68 | hgrad = np.sin(thetar)*(-1) * np.ones([N,1]) 69 | vgrad = np.cos(thetar)*(-1) * np.ones([1,N]) 70 | plane = ((hgrad * ramp) - r[0]) + ((ramp * vgrad).T - r[1]) 71 | x = 0.5 + 0.5 * np.sin(np.minimum(np.maximum( 72 | plane*(np.pi/w), np.pi/(-2)), np.pi/2)) 73 | 74 | return x 75 | 76 | 77 | def drawcirc(r,w,du,dv,N): 78 | 79 | """Generate an image of size N*N pels, containing a circle 80 | radius r pels and centred at du,dv relative 81 | to the centre of the image. The edge of the circle is a cosine shaped 82 | edge of width w (from 10 to 90% points). 83 | 84 | Python implementation by S. C. Forshaw, November 2013.""" 85 | 86 | # check value of w to avoid dividing by zero 87 | w = np.maximum(w,1) 88 | 89 | # x plane 90 | x = np.ones([N,1]) * ((np.arange(0,N,1, dtype='float') - 91 | (N+1) / 2 - dv) / r) 92 | 93 | # y vector 94 | y = (((np.arange(0,N,1, dtype='float') - (N+1) / 2 - du) / r) * 95 | np.ones([1,N])).T 96 | 97 | # Final circle image plane 98 | p = 0.5 + 0.5 * np.sin(np.minimum(np.maximum(( 99 | np.exp(np.array([-0.5]) * (x**2 + y**2)).T - np.exp((-0.5))) * (r * 3 / w), # noqa 100 | np.pi/(-2)), np.pi/2)) 101 | return p 102 | 103 | 104 | def asfarray(X): 105 | """Similar to :py:func:`numpy.asfarray` except that this function tries to 106 | preserve the original datatype of X if it is already a floating point type 107 | and will pass floating point arrays through directly without copying. 108 | 109 | """ 110 | X = np.asanyarray(X) 111 | return np.asfarray(X, dtype=X.dtype) 112 | 113 | 114 | def appropriate_complex_type_for(X): 115 | """Return an appropriate complex data type depending on the type of X. If X 116 | is already complex, return that, if it is floating point return a complex 117 | type of the appropriate size and if it is integer, choose an complex 118 | floating point type depending on the result of :py:func:`numpy.asfarray`. 119 | 120 | """ 121 | X = asfarray(X) 122 | 123 | if np.issubsctype(X.dtype, np.complex64) or \ 124 | np.issubsctype(X.dtype, np.complex128): 125 | return X.dtype 126 | elif np.issubsctype(X.dtype, np.float32): 127 | return np.complex64 128 | elif np.issubsctype(X.dtype, np.float64): 129 | return np.complex128 130 | 131 | # God knows, err on the side of caution 132 | return np.complex128 133 | 134 | 135 | def as_column_vector(v): 136 | """Return *v* as a column vector with shape (N,1). 137 | 138 | """ 139 | v = np.atleast_2d(v) 140 | if v.shape[0] == 1: 141 | return v.T 142 | else: 143 | return v 144 | 145 | 146 | def reflect(x, minx, maxx): 147 | """Reflect the values in matrix *x* about the scalar values *minx* and 148 | *maxx*. Hence a vector *x* containing a long linearly increasing series is 149 | converted into a waveform which ramps linearly up and down between *minx* 150 | and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers + 151 | 0.5), the ramps will have repeated max and min samples. 152 | 153 | .. codeauthor:: Rich Wareham , Aug 2013 154 | .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999. 155 | 156 | """ 157 | x = np.asanyarray(x) 158 | rng = maxx - minx 159 | rng_by_2 = 2 * rng 160 | mod = np.fmod(x - minx, rng_by_2) 161 | normed_mod = np.where(mod < 0, mod + rng_by_2, mod) 162 | out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx 163 | return np.array(out, dtype=x.dtype) 164 | 165 | 166 | def symm_pad_1d(l, m): 167 | """ Creates indices for symmetric padding. Works for 1-D. 168 | 169 | Inptus: 170 | l (int): size of input 171 | m (int): size of filter 172 | """ 173 | xe = reflect(np.arange(-m, l+m, dtype='int32'), -0.5, l-0.5) 174 | return xe 175 | 176 | 177 | # note that this decorator ignores **kwargs 178 | # From https://wiki.python.org/moin/PythonDecoratorLibrary#Alternate_memoize_as_nested_functions # noqa 179 | def memoize(obj): 180 | cache = obj.cache = {} 181 | 182 | @functools.wraps(obj) 183 | def memoizer(*args, **kwargs): 184 | if args not in cache: 185 | cache[args] = obj(*args, **kwargs) 186 | return cache[args] 187 | return memoizer 188 | 189 | 190 | def stacked_2d_matrix_vector_prod(mats, vecs): 191 | """ 192 | Interpret *mats* and *vecs* as arrays of 2D matrices and vectors. I.e. 193 | *mats* has shape PxQxNxM and *vecs* has shape PxQxM. The result 194 | is a PxQxN array equivalent to: 195 | 196 | .. code:: 197 | 198 | result[i,j,:] = mats[i,j,:,:].dot(vecs[i,j,:]) 199 | 200 | for all valid row and column indices *i* and *j*. 201 | """ 202 | return np.einsum('...ij,...j->...i', mats, vecs) 203 | 204 | 205 | def stacked_2d_vector_matrix_prod(vecs, mats): 206 | """ 207 | Interpret *mats* and *vecs* as arrays of 2D matrices and vectors. I.e. 208 | *mats* has shape PxQxNxM and *vecs* has shape PxQxN. The result 209 | is a PxQxM array equivalent to: 210 | 211 | .. code:: 212 | 213 | result[i,j,:] = mats[i,j,:,:].T.dot(vecs[i,j,:]) 214 | 215 | for all valid row and column indices *i* and *j*. 216 | """ 217 | vecshape = np.array(vecs.shape + (1,)) 218 | vecshape[-1:-3:-1] = vecshape[-2:] 219 | outshape = mats.shape[:-2] + (mats.shape[-1],) 220 | return stacked_2d_matrix_matrix_prod(vecs.reshape(vecshape), mats).reshape(outshape) # noqa 221 | 222 | 223 | def stacked_2d_matrix_matrix_prod(mats1, mats2): 224 | """ 225 | Interpret *mats1* and *mats2* as arrays of 2D matrices. I.e. 226 | *mats1* has shape PxQxNxM and *mats2* has shape PxQxMxR. The result 227 | is a PxQxNxR array equivalent to: 228 | 229 | .. code:: 230 | 231 | result[i,j,:,:] = mats1[i,j,:,:].dot(mats2[i,j,:,:]) 232 | 233 | for all valid row and column indices *i* and *j*. 234 | """ 235 | return np.einsum('...ij,...jk->...ik', mats1, mats2) 236 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import math 6 | import shutil 7 | import torch 8 | from model import WTFTP, WTFTP_AttnRemoved 9 | from torch.utils.data import DataLoader 10 | from dataloader import DataGenerator 11 | from utils import progress_bar, wt_coef_len, recorder 12 | import argparse 13 | import logging 14 | import datetime 15 | from tensorboardX import SummaryWriter 16 | from pytorch_wavelets import DWT1DForward, DWT1DInverse 17 | 18 | 19 | class Train: 20 | def __init__(self, opt, SEED=None, tracking=False): 21 | if SEED is not None: 22 | torch.manual_seed(SEED) 23 | if tracking: 24 | self.dev_recorder = recorder('dev_loss') 25 | self.train_recorder = recorder('train_loss', 'temporal_loss', 'freq_loss') 26 | self.SEED = SEED 27 | self.opt = opt 28 | self.iscuda = torch.cuda.is_available() 29 | self.device = torch.device(f'cuda:{self.opt.cuda}' if self.iscuda and not opt.cpu else 'cpu') 30 | self.data_set = DataGenerator(data_path=self.opt.datadir, 31 | minibatch_len=opt.minibatch_len, interval=opt.interval, 32 | use_preset_data_ranges=False) 33 | 34 | if self.opt.attn: 35 | self.net = WTFTP( 36 | n_inp=6, 37 | n_oup=6, 38 | his_step=opt.minibatch_len - 1, 39 | n_embding=opt.embding, 40 | en_layers=opt.enlayer, 41 | de_layers=opt.delayer, 42 | activation='relu', 43 | proj='linear', 44 | maxlevel=opt.maxlevel, 45 | en_dropout=opt.dpot, 46 | de_dropout=opt.dpot, 47 | bias=True 48 | ) 49 | else: 50 | self.net = WTFTP_AttnRemoved(n_inp=6, 51 | n_oup=6, 52 | n_embding=opt.embding, 53 | n_encoderLayers=opt.enlayer, 54 | n_decoderLayers=opt.delayer, 55 | activation='relu', 56 | proj='linear', 57 | maxlevel=opt.maxlevel, 58 | en_dropout=opt.dpot, 59 | de_dropout=opt.dpot) 60 | self.MSE = torch.nn.MSELoss(reduction='mean') 61 | self.MAE = torch.nn.L1Loss(reduction='mean') 62 | self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.opt.lr) 63 | self.opt_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.5) 64 | if not self.opt.nologging: 65 | self.log_path = self.opt.logdir + f'/{datetime.datetime.now().strftime("%y-%m-%d")}' 66 | if self.opt.debug: 67 | self.log_path = self.opt.logdir + f'/DEBUG-{datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")}{"-" + self.opt.comments if len(self.opt.comments) > 0 else ""}' 68 | self.opt.saving_model_num = 10 69 | self.TX_log_path = self.log_path + '/Tensorboard' 70 | if not os.path.exists(self.log_path): 71 | os.makedirs(self.log_path) 72 | if os.path.exists(self.TX_log_path): 73 | shutil.rmtree(self.TX_log_path) 74 | os.mkdir(self.TX_log_path) 75 | else: 76 | os.mkdir(self.TX_log_path) 77 | logging.basicConfig(filename=os.path.join(self.log_path, 'train.log'), 78 | filemode='a', format='%(asctime)s %(message)s', level=logging.DEBUG) 79 | self.TX_logger = SummaryWriter(log_dir=self.TX_log_path) 80 | 81 | if self.opt.saving_model_num > 0: 82 | self.model_names = [] 83 | 84 | def train(self): 85 | self.net.to(self.device) 86 | self.net.args['train_opt'] = self.opt 87 | if not self.opt.nologging: 88 | log_str = f'TRAIN DETAILS'.center(50, '-') + \ 89 | f'\ntraining on device {self.device}...\n' 90 | for arg in vars(self.opt): 91 | log_str += f'{arg}: {getattr(self.opt, arg)}\n' 92 | log_str += f'MODEL DETAILS'.center(50, '-') + '\n' 93 | for key in self.net.args: 94 | log_str += f'{key}: {self.net.args[key]}\n' 95 | logging.debug(10 * '\n' + f'Beginning of the training epoch'.center(50, '-')) 96 | logging.debug(log_str) 97 | print(log_str) 98 | 99 | for i in range(self.opt.epoch): 100 | print('\n' + f'Epoch {i+1}/{self.opt.epoch}'.center(50, '-')) 101 | if not self.opt.nologging: 102 | logging.debug('\n' + f'Epoch {i+1}/{self.opt.epoch}'.center(50, '-')) 103 | print(f'lr: {float(self.opt_lr_scheduler.get_last_lr()[0])}') 104 | self.train_each_epoch(i+1) 105 | self.saving_model(self.log_path, f'epoch_{i+1}.pt') 106 | self.opt_lr_scheduler.step() 107 | self.dev_each_epoch(i+1) 108 | if not self.opt.nologging: 109 | logging.debug(10 * '\n' + f'End of the training epoch'.center(50, '-')) 110 | if hasattr(self, 'dev_recorder'): 111 | self.dev_recorder.push(self.log_path, 'dev_recorder.pt') 112 | if hasattr(self, 'train_recorder'): 113 | self.train_recorder.push(self.log_path, 'train_recorder.pt') 114 | 115 | def train_each_epoch(self, epoch): 116 | train_data = DataLoader(dataset=self.data_set.train_set, batch_size=self.opt.batch_size, shuffle=True, 117 | collate_fn=self.data_set.collate, num_workers=0, pin_memory=self.iscuda) 118 | self.net.train() 119 | dwt = DWT1DForward(wave=self.opt.wavelet, J=self.opt.maxlevel, mode=self.opt.wt_mode).to(self.device) 120 | idwt = DWT1DInverse(wave=self.opt.wavelet, mode=self.opt.wt_mode).to(self.device) 121 | start_time = time.perf_counter() 122 | batchs_len = len(train_data) 123 | train_loss_set = [] 124 | temporal_loss_set = [] 125 | freq_loss_set = [] 126 | for i, batch in enumerate(train_data): 127 | batch = torch.FloatTensor(batch).to(self.device) 128 | inp_batch = batch[:, :-1, :] # shape: batch * n_sequence * n_attr 129 | train_wt_loss = torch.tensor(0.0).to(self.device) 130 | tgt_batch = batch # use the all to calculate wavelet coefficients, shape: batch * n_sequence * n_attr 131 | wt_tgt_batch = dwt( 132 | batch.transpose(1, 2).contiguous()) # tuple(lo, hi), shape: batch * n_attr * n_sequence 133 | if self.opt.attn: 134 | wt_pre_batch, score_set = self.net(inp_batch) 135 | else: 136 | wt_pre_batch = self.net(inp_batch) 137 | pre_batch = idwt((wt_pre_batch[-1].transpose(1, 2).contiguous(), 138 | [comp.transpose(1, 2).contiguous() for comp in 139 | wt_pre_batch[:-1]])).contiguous() # shape: batch * n_attr * n_sequence 140 | pre_batch = pre_batch.transpose(1, 2) # shape: batch * n_sequence * n_attr 141 | train_temporal_loss, loss_details = self.cal_spatial_loss(tgt_batch, pre_batch) 142 | train_wt_loss = self.cal_freq_loss(wt_tgt_batch, wt_pre_batch) 143 | loss_backward = torch.tensor(self.opt.w_spatial, dtype=torch.float, 144 | device=self.device) * train_temporal_loss + train_wt_loss 145 | self.optimizer.zero_grad() 146 | loss_backward.backward() 147 | self.optimizer.step() 148 | with torch.no_grad(): 149 | train_loss_set.append(float(loss_backward.detach().cpu())) 150 | temporal_loss_set.append(float(train_temporal_loss.detach().cpu())) 151 | freq_loss_set.append(float(train_wt_loss.detach().cpu())) 152 | # print loss 153 | print_str = f'train_loss(scaled): {loss_backward.item():.8f}, temporal_loss: {train_temporal_loss.item():.8f}, ' \ 154 | f'freq_loss: {train_wt_loss.item():.8f}' 155 | progress_bar(i, batchs_len, print_str, start_time) 156 | record_freq = 20 157 | if not self.opt.nologging and (i % ((batchs_len - 1) // record_freq) == 0 or i == batchs_len - 1): 158 | logging.debug(f'{i}/{batchs_len - 1} ' + print_str) 159 | print_str = f'ave_train_loss: {np.mean(train_loss_set):.8f}, ave_temporal_loss: {np.mean(temporal_loss_set):.8f}' \ 160 | + f', ave_freq_loss: {np.mean(freq_loss_set):.8f}' 161 | print(print_str) 162 | if not self.opt.nologging: 163 | logging.debug(print_str) 164 | self.TX_logger.add_scalar('train_loss', np.mean(train_loss_set), global_step=epoch) 165 | self.TX_logger.add_scalar('train_temporal_loss', np.mean(temporal_loss_set), global_step=epoch) 166 | self.TX_logger.add_scalar('train_freq_loss', np.mean(freq_loss_set), global_step=epoch) 167 | if hasattr(self, 'train_recorder'): 168 | self.train_recorder.add('train_loss', np.mean(train_loss_set)) 169 | self.train_recorder.add('temporal_loss', np.mean(temporal_loss_set)) 170 | self.train_recorder.add('freq_loss', np.mean(freq_loss_set)) 171 | if not self.opt.nologging and self.opt.attn: 172 | project = plt.cm.get_cmap('YlGnBu') 173 | with torch.no_grad(): 174 | self.TX_logger.add_image(f'score', 175 | project(torch.mean(score_set.clone().cpu(), dim=0).numpy())[:, :, :-1], 176 | dataformats='HWC', 177 | global_step=epoch) 178 | 179 | def dev_each_epoch(self, epoch): 180 | dev_data = DataLoader(dataset=self.data_set.dev_set, batch_size=self.opt.batch_size, shuffle=False, 181 | collate_fn=self.data_set.collate, num_workers=0, pin_memory=self.iscuda) 182 | self.net.eval() 183 | idwt = DWT1DInverse(wave=self.opt.wavelet, mode=self.opt.wt_mode).to(self.device) 184 | tgt_set = [] 185 | pre_set = [] 186 | wt_coef_length = wt_coef_len(self.opt.minibatch_len, wavelet=self.opt.wavelet, mode=self.opt.wt_mode, 187 | maxlevel=self.opt.maxlevel) 188 | with torch.no_grad(): 189 | for i, batch in enumerate(dev_data): 190 | batch = torch.FloatTensor(batch).to(self.device) 191 | n_batch, _, n_attr = batch.shape 192 | inp_batch = batch[:, :-1, :] # shape: batch * n_sequence * n_attr 193 | tgt_batch = batch 194 | if self.opt.attn: 195 | wt_pre_batch, score_set = self.net(inp_batch) 196 | else: 197 | wt_pre_batch = self.net(inp_batch) 198 | pre_batch = idwt((wt_pre_batch[-1].transpose(1, 2).contiguous(), 199 | [comp.transpose(1, 2).contiguous() for comp in 200 | wt_pre_batch[:-1]])).contiguous() # shape: batch * n_attr * n_sequence 201 | pre_batch = pre_batch.transpose(1, 2) # shape: batch * n_sequence * n_attr 202 | tgt_set.append(tgt_batch) 203 | pre_set.append(pre_batch) 204 | tgt_set = torch.cat(tgt_set, dim=0) 205 | pre_set = torch.cat(pre_set, dim=0) 206 | dev_loss, loss_details = self.cal_spatial_loss(tgt_set, pre_set, is_training=False) 207 | print_str = f'Evaluation-Stage:\n' \ 208 | f'aveMSE(scaled): {dev_loss:.8f}, in each attr(RMSE, unscaled): {loss_details["rmse"]}\n' \ 209 | f'aveMAE(scaled): None, in each attr(MAE, unscaled): {loss_details["mae"]}' 210 | print(print_str) 211 | if not self.opt.nologging: 212 | logging.debug(print_str) 213 | self.TX_logger.add_scalar('eval_aveMSE', dev_loss, global_step=epoch) 214 | if hasattr(self, 'dev_recorder'): 215 | self.dev_recorder.add('dev_loss', dev_loss) 216 | 217 | def cal_spatial_loss(self, tgt, pre, is_training=True): 218 | """ 219 | :param tgt: shape: batch * n_sequence * n_attr 220 | :param pre: shape: batch * n_sequence+? * n_attr 221 | :return: 222 | """ 223 | n_sequence = tgt.shape[1] 224 | last_node_weight = 1.0 225 | if is_training: 226 | weights = torch.ones(n_sequence, dtype=torch.float, device=tgt.device) 227 | else: 228 | weights = torch.zeros(n_sequence, dtype=torch.float, device=tgt.device) 229 | weights[-1] = last_node_weight 230 | weighted_loss = torch.tensor(0.0).to(self.device) 231 | for i in range(n_sequence): 232 | weighted_loss += weights[i] * self.MSE(pre[:, i, :], tgt[:, i, :]) 233 | try: 234 | weighted_loss /= weights.count_nonzero() 235 | except: 236 | if is_training: 237 | weighted_loss /= torch.tensor(n_sequence, dtype=torch.float).to(self.device) 238 | else: 239 | weighted_loss /= torch.tensor(1.0, dtype=torch.float).to(self.device) 240 | loss_unscaled = {'rmse': {}, 'mae': {}} 241 | tgt_cloned = tgt.detach() 242 | pre_cloned = pre.detach() 243 | for i, name in enumerate(self.data_set.attr_names): 244 | loss_unscaled['rmse'][name] = math.sqrt( 245 | float(self.MSE(self.data_set.unscale(tgt_cloned[:, n_sequence - 1, i], name), 246 | self.data_set.unscale(pre_cloned[:, n_sequence - 1, i], name)))) 247 | loss_unscaled['mae'][name] = float(self.MAE(self.data_set.unscale(tgt_cloned[:, n_sequence - 1, i], name), 248 | self.data_set.unscale(pre_cloned[:, n_sequence - 1, i], name))) 249 | return weighted_loss, loss_unscaled 250 | 251 | def cal_freq_loss(self, tgt, pre) -> torch.Tensor: 252 | """ 253 | :param tgt: tuple:(lo, hi) shape: batch * n_attr * n_sequence 254 | :param pre: list:[hi's lo] shape: batch * n_attr * n_sequence 255 | :return: 256 | """ 257 | wt_loss = torch.tensor(0.0).to(self.device) 258 | wt_loss += self.opt.w_lo * self.MSE(tgt[0], pre[-1].transpose(1, 2)) 259 | for i in range(self.opt.maxlevel): 260 | wt_loss += self.opt.w_hi * self.MSE(pre[i].transpose(1, 2), tgt[1][i]) 261 | return wt_loss 262 | 263 | def saving_model(self, model_path, this_model_name): 264 | if self.opt.saving_model_num > 0: 265 | self.model_names.append(this_model_name) 266 | if len(self.model_names) > self.opt.saving_model_num: 267 | removed_model_name = self.model_names[0] 268 | del self.model_names[0] 269 | os.remove(os.path.join(model_path, removed_model_name)) # remove the oldest model 270 | torch.save(self.net, os.path.join(model_path, this_model_name)) # save the latest model 271 | 272 | if __name__ == '__main__': 273 | parser = argparse.ArgumentParser() 274 | parser.add_argument('--minibatch_len', default=10, type=int) 275 | parser.add_argument('--interval', default=1, type=int) 276 | parser.add_argument('--batch_size', default=2048, type=int) 277 | parser.add_argument('--epoch', default=150, type=int) 278 | parser.add_argument('--lr', default=0.001, type=float) 279 | parser.add_argument('--dpot', default=0.0, type=float) 280 | parser.add_argument('--cpu', action='store_true') 281 | parser.add_argument('--nologging', action='store_true') 282 | parser.add_argument('--logdir', default='./log', type=str) 283 | parser.add_argument('--datadir', default='./data', type=str) 284 | parser.add_argument('--bidirectional', action='store_true') 285 | parser.add_argument('--saving_model_num', default=0, type=int) 286 | parser.add_argument('--debug', action='store_true', 287 | help='logs saving in an independent dir and 10 models being saved') 288 | parser.add_argument('--maxlevel', default=1, type=int) 289 | parser.add_argument('--wavelet', default='haar', type=str) 290 | parser.add_argument('--L2details', action='store_true', 291 | help='L2 regularization for detail coefficients to avoid them to converge to zeros') 292 | parser.add_argument('--wt_mode', default='symmetric', type=str) 293 | parser.add_argument('--w_spatial', default='0.0', type=float) 294 | parser.add_argument('--w_lo', default='1.0', type=float) 295 | parser.add_argument('--w_hi', default='1.0', type=float) 296 | parser.add_argument('--enlayer', default='4', type=int) 297 | parser.add_argument('--delayer', default='1', type=int) 298 | parser.add_argument('--embding', default='64', type=int) 299 | parser.add_argument('--attn', action='store_true') 300 | parser.add_argument('--comments', default="", type=str, 301 | help='comments in the dir name for identification') 302 | parser.add_argument('--cuda', default=0, type=int) 303 | args = parser.parse_args() 304 | train = Train(args, tracking=True) 305 | train.train() 306 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pytorch_wavelets import DWT1DForward, DWT1DInverse 4 | import torch 5 | import pywt 6 | import time 7 | from torch.utils.data import DataLoader 8 | import os 9 | 10 | 11 | def wt_packet(inp, wavelet='db2', mode='symmetric', maxlevel=1): 12 | """ 13 | the last-dim vector to be decomposed 14 | :param inp: shape: batch * n_attr * n_sequence 15 | :param wavelet: 16 | :param mode: 17 | :param maxlevel: 18 | :return: oup: shape : batch * n_attr * level * n_sequence 19 | """ 20 | if maxlevel == 0: 21 | return inp.unsqueeze(2) 22 | dwt = DWT1DForward(wave=wavelet, J=1, mode=mode).to(inp.device) 23 | oup = [inp] 24 | for _ in range(maxlevel): 25 | tmp = [] 26 | for item in oup: 27 | lo, hi = dwt(item) 28 | tmp += [lo, hi[0]] 29 | oup = tmp 30 | oup = torch.stack(oup, dim=2) 31 | return oup 32 | 33 | 34 | def wt_packet_inverse(inp, wavelet='db2', mode='symmetric', maxlevel=1): 35 | """ 36 | the last-dim vector to be composed 37 | :param inp: shape : batch * n_attr * level * n_sequence 38 | :param wavelet: 39 | :param mode: 40 | :param maxlevel: 41 | :return: oup: batch * n_attr * n_sequence 42 | """ 43 | assert inp.shape[2] == 2**maxlevel 44 | if maxlevel == 0: 45 | return inp.squeeze(2) 46 | idwt = DWT1DInverse(wave=wavelet, mode=mode).to(inp.device) 47 | oup = inp 48 | for level in range(maxlevel, 0, -1): 49 | tmp = [] 50 | for i in range(2**(level-1)): 51 | lo, hi = oup[:, :, 2 * i, :], oup[:, :, 2 * i + 1, :] 52 | hi = [hi] 53 | tmp.append(idwt((lo, hi))) 54 | oup = torch.stack(tmp, dim=2) 55 | oup = oup.squeeze(2) 56 | return oup 57 | 58 | 59 | def progress_bar(step, n_step, str, start_time=time.perf_counter(), bar_len=20): 60 | ''' 61 | :param bar_len: length of the bar 62 | :param step: from 0 to n_step-1 63 | :param n_step: number of steps 64 | :param str: info to be printed 65 | :param start_time: time to begin the progress_bar 66 | :return: 67 | ''' 68 | step = step+1 69 | a = "*" * int(step * bar_len / n_step) 70 | b = " " * (bar_len - int(step * bar_len / n_step)) 71 | c = step / n_step * 100 72 | dur = time.perf_counter() - start_time 73 | print("\r{:^3.0f}%[{}{}]{:.2f}s {}".format(c, a, b, dur, str), end="") 74 | if step == n_step: 75 | print('') 76 | 77 | 78 | def wt_coef_len(in_length, wavelet, mode, maxlevel): 79 | test_inp = torch.ones((1, 1, in_length)) 80 | test_oup = wt_packet(test_inp, wavelet=wavelet, mode=mode, maxlevel=maxlevel) 81 | return test_oup.shape[-1] 82 | 83 | 84 | class recorder: 85 | def __init__(self, *attrs): 86 | if attrs is not None: 87 | self.saver = {} 88 | self.attrs = attrs 89 | for attr in attrs: 90 | self.saver[attr] = [] 91 | else: 92 | self.saver = {} 93 | self.attrs = attrs 94 | 95 | def add(self, attr, val): 96 | self.saver[attr].append(val) 97 | 98 | def __getitem__(self, item): 99 | return self.saver[item] 100 | 101 | def __len__(self): 102 | return len(self.attrs) 103 | 104 | def push(self, path, filename='recorde.pt'): 105 | torch.save(self.saver, os.path.join(path, filename)) 106 | 107 | def pull(self, filepath): 108 | self.saver = torch.load(filepath) 109 | self.attrs = self.saver.keys() 110 | --------------------------------------------------------------------------------