├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   ├── example_data.npy
│   └── source_data.xlsx
├── data_ranges.npy
├── dataloader.py
├── desired_results.xlsx
├── infer.py
├── model.py
├── pics
│   └── framework.png
├── pytorch_wavelets
│   ├── __init__.py
│   ├── _version.py
│   ├── dtcwt
│   │   ├── __init__.py
│   │   ├── coeffs.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── antonini.npz
│   │   │   ├── farras.npz
│   │   │   ├── legall.npz
│   │   │   ├── near_sym_a.npz
│   │   │   ├── near_sym_a2.npz
│   │   │   ├── near_sym_b.npz
│   │   │   ├── near_sym_b_bp.npz
│   │   │   ├── qshift_06.npz
│   │   │   ├── qshift_32.npz
│   │   │   ├── qshift_a.npz
│   │   │   ├── qshift_b.npz
│   │   │   ├── qshift_b_bp.npz
│   │   │   ├── qshift_c.npz
│   │   │   └── qshift_d.npz
│   │   ├── lowlevel.py
│   │   ├── lowlevel2.py
│   │   ├── transform2d.py
│   │   └── transform_funcs.py
│   ├── dwt
│   │   ├── __init__.py
│   │   ├── lowlevel.py
│   │   ├── swt_inverse.py
│   │   ├── transform1d.py
│   │   └── transform2d.py
│   ├── scatternet
│   │   ├── __init__.py
│   │   ├── layers.py
│   │   └── lowlevel.py
│   └── utils.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | converted.pt
2 | model_mod1121.py
3 | prepareData.py
4 | transfer.py
5 | WTFTP_his_tgt_pre_score.pt
6 | .idea
7 | __pycache__
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Flight trajectory prediction enabled by time-frequency wavelet transform
2 |
3 | # Introduction
4 |
5 | This repository provides the source code of the proposed flight trajectory prediction framework, called WTFTP,
6 | and example samples for the paper Flight trajectory
7 | prediction enabled by time-frequency wavelet transform.
8 | WTFTP is a novel wavelet-based flight trajectory prediction framework that models global flight trends and local motion
9 | details.
10 |
11 |
![The WTFTP framework](./pics/framework.png)
12 |
13 | ## Repository Structure
14 | ```
15 | wtftp-model
16 | │ dataloader.py (Load trajectory data from ./data)
17 | │ data_ranges.npy (Provide data ranges for normalization and denormalization)
18 | │  desired_results.xlsx (Expected prediction results of example samples)
19 | │ infer.py (Perform the prediction procedure)
20 | │ LICENSE (LICENSE file)
21 | │ model.py (The neural architecture corresponding to the WTFTP framework)
22 | │ README.md (The current README file)
23 | │ train.py (The training code)
24 | │ utils.py (Tools for the project)
25 | │
26 | ├─data
27 | │ │ README.md (README file for the dataset.)
28 | │ │ example_data.npy (Example data file)
29 | │ │ source_data.xlsx (Source data of figures in the paper)
30 | │ ├─dev (Archive for the validation data)
31 | │ ├─test (Archive for the test data)
32 | │ └─train (Archive for the training data)
33 | └─pics
34 | ```
35 |
36 | ## Package Requirements
37 |
38 | + Python == 3.7.1
39 | + torch == 1.4.0+cu100
40 | + numpy == 1.18.5
41 | + tensorboard == 2.3.0
42 | + tensorboardX == 2.1
43 | + PyWavelets == 1.2.0
44 | + matplotlib == 3.2.1
45 |
46 | ## System Requirements
47 | + Ubuntu 16.04 operating system
48 | + Intel(R) Xeon(TM) E5-2690@2.90GHz
49 | + 128 GB of memory
50 | + 8 TB of hard disk storage
51 | + 8 $\times$ NVIDIA(R) GeForce RTX(TM) 2080 Ti 11G GPUs.
52 |
53 |
54 | # Instructions
55 | ## Installation
56 |
57 | ### Clone this repository
58 |
59 | ```
60 | git clone https://github.com/MusDev7/wtftp-model.git
61 | ```
62 |
63 | ### Create proper software and hardware environment
64 |
65 | We recommend creating a conda environment with the package requirements listed above, and conducting the
66 | training and testing on the suggested system configuration.
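
For example, a minimal conda setup consistent with the package list above might look like the following (the CUDA-specific build `torch == 1.4.0+cu100` is not on PyPI and may need to be installed from the PyTorch wheel archive instead):

```
conda create -n wtftp python=3.7.1
conda activate wtftp
pip install numpy==1.18.5 tensorboard==2.3.0 tensorboardX==2.1 PyWavelets==1.2.0 matplotlib==3.2.1
pip install torch==1.4.0  # CPU build; use the matching cu100 wheel for GPU training
```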
67 |
68 | ## Training
69 |
70 | The training script for flight trajectory prediction is provided in `train.py`. The arguments for the training
71 | process are defined below:
72 |
73 | `--minibatch_len`: Integer. The sliding-window length for constructing the samples. `default=10`
74 |
75 | `--interval`: Integer. The sampling period. `default=1`
76 |
77 | `--batch_size`: Integer. The number of samples in a single training batch. `default=2048`
78 |
79 | `--epoch`: Integer. The maximum number of epochs for the training process. `default=150`
80 |
81 | `--lr`: Float. The learning rate of the Adam optimizer. `default=0.001`
82 |
83 | `--dpot`: Float. The dropout probability. `default=0.0`
84 |
85 | `--cpu`: Optional. Use the CPU for the training process.
86 |
87 | `--nolongging`: Optional. Disable logging; no logs will be recorded.
88 |
89 | `--logdir`: String. The path for logs. `default='./log'`
90 |
91 | `--datadir`: String. The path for dataset. `default='./data'`
92 |
93 | `--saving_model_num`: Integer. The number of models to be saved during the training process. `default=0`
94 |
95 | `--debug`: Optional. For debugging the scripts.
96 |
97 | `--bidirectional`: Optional. Use the bidirectional LSTM block.
98 |
99 | `--maxlevel`: Integer. The level of wavelet analysis. `default=1`
100 |
101 | `--wavelet`: String. The wavelet basis. `default=haar`
102 |
103 | `--wt_mode`: String. The signal extension mode for wavelet transform. `default=symmetric`
104 |
105 | `--w_lo`: Float. The weight for the low-frequency wavelet component in the loss function. `default=1.0`
106 |
107 | `--w_hi`: Float. The weight for the high-frequency wavelet components in the loss function. `default=1.0`
108 |
109 | `--enlayer`: Integer. The number of layers of the LSTM block in the encoder. `default=4`
110 |
111 | `--delayer`: Integer. The number of layers of the LSTM block in the decoder. `default=1`
112 |
113 | `--embding`: Integer. The dimension of trajectory embeddings, enhanced trajectory embeddings and contextual
114 | embeddings. `default=64`
115 |
116 | `--attn`: Optional. Use the wavelet attention module in the decoder.
117 |
118 | `--cuda`: Integer. The GPU index for the training process.
119 |
120 | To train the WTFTP, use the following command.
121 |
122 | ```
123 | python train.py --saving_model_num 10 --attn
124 | ```
125 |
126 | To train the WTFTP without the wavelet attention module, use the following command.
127 |
128 | ```
129 | python train.py --saving_model_num 10
130 | ```
131 |
132 | To train the WTFTP with 2 or 3 levels of wavelet analysis, use one of the following commands.
133 |
134 | ```
135 | python train.py --saving_model_num 10 --attn --maxlevel 2
136 | ```
137 | ```
138 | python train.py --saving_model_num 10 --attn --maxlevel 3
139 | ```
140 |
141 | To train the WTFTP without the wavelet attention module with 2 or 3 levels of wavelet analysis, use one of the following commands.
142 |
143 | ```
144 | python train.py --saving_model_num 10 --maxlevel 2
145 | ```
146 |
147 | ```
148 | python train.py --saving_model_num 10 --maxlevel 3
149 | ```
150 |
151 |
152 | ## Test
153 |
154 | The test script for evaluation is provided in `infer.py`. The arguments for the test process are defined below:
155 |
156 | `--minibatch_len`: Integer. The sliding-window length for constructing the samples. `default=10`
157 |
158 | `--interval`: Integer. The sampling period. `default=1`
159 |
160 | `--pre_len`: Integer. The prediction horizon (number of future steps) for evaluation. `default=1`
161 |
162 | `--batch_size`: Integer. The number of samples in a single test batch. `default=2048`
163 |
164 | `--cpu`: Optional. Use the CPU for the test process.
165 |
166 | `--logdir`: String. The path for logs. `default='./log'`
167 |
168 | `--datadir`: String. The path for dataset. `default='./data'`
169 |
170 | `--netdir`: String. The path for the model.
171 |
172 | To test the model, use the following command.
173 |
174 | ```
175 | python infer.py --netdir ./xxx.pt
176 | ```
177 |
178 | # Dataset
179 |
180 | In this repository, example samples are provided for evaluation. They can be accessed in `/data/example_data.npy`.
181 | Guidance on the example data can be found in `/data/README.md`.
182 |
183 |
184 | # Acknowledgement
185 |
186 | The PyTorch implementation of the wavelet transform is utilized to support the DWT and IDWT procedures
187 | in this work. Its repository can be accessed [here](https://github.com/fbcotter/pytorch_wavelets).
188 | We thank all contributors to that project.
189 |
190 | # Citation
191 |
192 | Zhang, Z., Guo, D., Zhou, S. et al. Flight trajectory prediction enabled by time-frequency wavelet transform. Nat Commun 14, 5258 (2023). https://doi.org/10.1038/s41467-023-40903-9
193 |
194 | # Contact
195 |
196 | Zheng Zhang (zhaeng@stu.scu.edu.cn, musevr.ae@gmail.com)
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Note
2 |
3 | We are not authorized to publicly release the whole dataset used in this study owing to
4 | safety-critical concerns. Nonetheless, the processed example samples are
5 | available in `example_data.npy`. Source data for figures are also provided in `source_data.xlsx`.
6 |
7 | ## Guidance of example data
8 |
9 | The example data has been stored as a binary file in NumPy format. It can be accessed by using:
10 |
11 | ```python
12 | import numpy as np
13 | example_data = np.load("example_data.npy")
14 | ```
15 |
16 | The first dimension of `example_data` is the number of samples (500). The second is the sliding-window size of 10 (the first 9
17 | rows represent the input trajectory sequence, whereas the final row serves as the target trajectory point). The
18 | last dimension of 6 indicates the six attributes: longitude (degrees), latitude (degrees), altitude (tens of meters), and
19 | the velocities (kilometers per hour) along the three position components above.
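
As a minimal sketch of how this layout can be split into model inputs and targets (the variable names are illustrative):

```python
import numpy as np

example_data = np.load("example_data.npy")  # shape: (500, 10, 6)

inputs = example_data[:, :9, :]   # input trajectory sequences, shape (500, 9, 6)
targets = example_data[:, 9, :]   # target trajectory points, shape (500, 6)

# Attribute order along the last axis
attrs = ["lon", "lat", "alt", "spdx", "spdy", "spdz"]
```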
20 |
21 | The predicted trajectory of example data can be found in `desired_results.xlsx`.
22 |
--------------------------------------------------------------------------------
/data/example_data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/data/example_data.npy
--------------------------------------------------------------------------------
/data/source_data.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/data/source_data.xlsx
--------------------------------------------------------------------------------
/data_ranges.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/data_ranges.npy
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import random
5 | # import coordinate_conversion as cc
6 | import numpy as np
7 | import torch
8 | import torch.utils.data as tu_data
9 |
10 | class DataGenerator:
11 | def __init__(self, data_path, minibatch_len, interval=1, use_preset_data_ranges=False,
12 | train=True, test=True, dev=True, train_shuffle=True, test_shuffle=False, dev_shuffle=True):
13 | assert os.path.exists(data_path)
14 | self.attr_names = ['lon', 'lat', 'alt', 'spdx', 'spdy', 'spdz']
15 | self.data_path = data_path
16 | self.interval = interval
17 | self.minibatch_len = minibatch_len
18 | self.data_status = np.load('data_ranges.npy', allow_pickle=True).item()
19 | assert type(self.data_status) is dict
20 | self.preset_data_ranges = {"lon": {"max": 113.689, "min": 93.883}, "lat": {"max": 37.585, "min": 19.305},
21 | "alt": {"max": 1500, "min": 0}, "spdx": {"max": 878, "min": -945},
22 | "spdy": {"max": 925, "min": -963}, "spdz": {"max": 43, "min": -48}}
23 | self.use_preset_data_ranges = use_preset_data_ranges
24 | if train:
25 | self.train_set = mini_DataGenerator(self.readtxt(os.path.join(self.data_path, 'train'), shuffle=train_shuffle))
26 | if dev:
27 | self.dev_set = mini_DataGenerator(self.readtxt(os.path.join(self.data_path, 'dev'), shuffle=dev_shuffle))
28 | if test:
29 | self.test_set = mini_DataGenerator(self.readtxt(os.path.join(self.data_path, 'test'), shuffle=test_shuffle))
30 | if use_preset_data_ranges:
31 | assert self.preset_data_ranges is not None
32 | print('data range:', self.data_status)
33 |
34 | def readtxt(self, data_path, shuffle=True):
35 | assert os.path.exists(data_path)
36 | data = []
37 | for root, dirs, file_names in os.walk(data_path):
38 | for file_name in file_names:
39 | if not file_name.endswith('txt'):
40 | continue
41 | with open(os.path.join(root, file_name)) as file:
42 | lines = file.readlines()
43 | lines = lines[::self.interval]
44 | if len(lines) == self.minibatch_len:
45 | data.append(lines)
46 | elif len(lines) < self.minibatch_len:
47 | continue
48 | else:
49 | for i in range(len(lines)-self.minibatch_len+1):
50 | data.append(lines[i:i+self.minibatch_len])
51 | print(f'{len(data)} items loaded from \'{data_path}\'')
52 | if shuffle:
53 | random.shuffle(data)
54 | return data
55 |
56 | def scale(self, inp, attr):
57 | assert type(attr) is str and attr in self.attr_names
58 | data_status = self.data_status if not self.use_preset_data_ranges else self.preset_data_ranges
59 | inp = (inp-data_status[attr]['min'])/(data_status[attr]['max']-data_status[attr]['min'])
60 | return inp
61 |
62 | def unscale(self, inp, attr):
63 | assert type(attr) is str and attr in self.attr_names
64 | data_status = self.data_status if not self.use_preset_data_ranges else self.preset_data_ranges
65 | inp = inp*(data_status[attr]['max']-data_status[attr]['min'])+data_status[attr]['min']
66 | return inp
67 |
68 | def collate(self, inp):
69 | '''
70 |         :param inp: a batch of raw trajectory text lines, batch * n_sequence
71 |         :return: scaled array of shape batch * n_sequence * n_attr
72 | '''
73 | oup = []
74 | for minibatch in inp:
75 | tmp = []
76 | for line in minibatch:
77 | items = line.strip().split("|")
78 | lon, lat, alt, spdx, spdy, spdz = float(items[4]), float(items[5]), int(float(items[6]) / 10), \
79 | float(items[7]), float(items[8]), float(items[9])
80 | tmp.append([lon, lat, alt, spdx, spdy, spdz])
81 | minibatch = np.array(tmp)
82 | for i in range(minibatch.shape[-1]):
83 | minibatch[:, i] = self.scale(minibatch[:, i], self.attr_names[i])
84 | oup.append(minibatch)
85 | return np.array(oup)
86 |
87 |
88 | class mini_DataGenerator(tu_data.Dataset):
89 | def __init__(self, data):
90 | self.data = data
91 |
92 | def __getitem__(self, item):
93 | return self.data[item]
94 |
95 | def __len__(self):
96 | return len(self.data)
97 |
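# Illustrative usage sketch (paths and arguments here are examples, not a fixed API):
#
#   from torch.utils.data import DataLoader
#   gen = DataGenerator('./data', minibatch_len=10, interval=1)
#   loader = DataLoader(gen.test_set, batch_size=2048, shuffle=False,
#                       collate_fn=gen.collate)
#   for batch in loader:
#       # batch: ndarray of shape (batch, minibatch_len, 6), scaled to [0, 1]
#       ...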
--------------------------------------------------------------------------------
/desired_results.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/desired_results.xlsx
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import time
5 | import numpy as np
6 | import math
7 | import torch
8 | from torch.utils.data import DataLoader
9 | from dataloader import DataGenerator
10 | import logging
11 | import datetime
12 | import matplotlib.pyplot as plt
13 | from mpl_toolkits.mplot3d import Axes3D
14 | from pytorch_wavelets import DWT1DForward, DWT1DInverse
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--minibatch_len', default=10, type=int)
18 | parser.add_argument('--pre_len', default=1, type=int)
19 | parser.add_argument('--interval', default=1, type=int)
20 | parser.add_argument('--batch_size', default=2048, type=int)
21 | parser.add_argument('--cpu', action='store_true')
22 | parser.add_argument('--logdir', default='./log', type=str)
23 | parser.add_argument('--datadir', default='./data', type=str)
24 | parser.add_argument('--netdir', default=None, type=str)
25 |
26 |
27 | class Test:
28 | def __init__(self, opt, net=None):
29 | self.opt = opt
30 | self.iscuda = torch.cuda.is_available()
31 | self.device = f'cuda:{torch.cuda.current_device()}' if self.iscuda and not opt.cpu else 'cpu'
32 | self.data_set = DataGenerator(data_path=self.opt.datadir,
33 | minibatch_len=opt.minibatch_len, interval=opt.interval,
34 | use_preset_data_ranges=False, train=False, dev=False, test_shuffle=True)
35 | self.net = net
36 | self.model_path = None
37 | self.MSE = torch.nn.MSELoss(reduction='mean')
38 | self.MAE = torch.nn.L1Loss(reduction='mean')
39 | if net is not None:
40 |             assert str(next(self.net.parameters()).device) == self.device
41 |
42 | def load_model(self, model_path):
43 | self.model_path = model_path
44 | self.net = torch.load(model_path, map_location=self.device)
45 |
46 | def test(self):
47 | print_str = f'model details:\n{self.net.args}'
48 | print(print_str)
49 | self.log_path = self.opt.logdir + f'/{datetime.datetime.now().strftime("%y-%m-%d")}'
50 | if not os.path.exists(self.log_path):
51 | os.makedirs(self.log_path)
52 | logging.basicConfig(filename=os.path.join(self.log_path, f'test_{self.net.args["train_opt"].comments}.log'),
53 | filemode='w', format='%(asctime)s %(message)s', level=logging.DEBUG)
54 | logging.debug(print_str)
55 | logging.debug(self.model_path)
56 | test_data = DataLoader(dataset=self.data_set.test_set, batch_size=self.opt.batch_size, shuffle=False,
57 | collate_fn=self.data_set.collate)
58 | idwt = DWT1DInverse(wave=self.net.args['train_opt'].wavelet, mode=self.net.args['train_opt'].wt_mode).to(
59 | self.device)
60 | self.net.eval()
61 | tgt_set = []
62 | pre_set = []
63 |
64 | with torch.no_grad():
65 | his_batch_set = []
66 | all_score_set = []
67 | for i, batch in enumerate(test_data):
68 | batch = torch.FloatTensor(batch).to(self.device)
69 | n_batch, _, n_attr = batch.shape
70 | inp_batch = batch[:, :self.opt.minibatch_len-self.opt.pre_len, :] # shape: batch * his_len * n_attr
71 | his_batch_set.append(inp_batch)
72 | tgt_set.append(batch[:, -self.opt.pre_len:, :]) # shape: batch * pre_len * n_attr
73 | pre_batch_set = []
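                # Autoregressive multi-horizon prediction: after the first step, the
                # input window slides forward by appending the point just predicted,
                # so step j is conditioned on the model's own previous outputs.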
74 | for j in range(self.opt.pre_len):
75 | if j > 0:
76 | new_batch = pre_batch[:, self.opt.minibatch_len-self.opt.pre_len, :].unsqueeze(1)
77 | inp_batch = torch.cat((inp_batch[:, 1:, :], new_batch), dim=1) # shape: batch * his_len * n_attr
78 | if self.net.__class__.__name__ == 'WTFTP':
79 | wt_pre_batch, score_set = self.net(inp_batch)
80 | else:
81 | wt_pre_batch = self.net(inp_batch)
82 | pre_batch = idwt((wt_pre_batch[-1].transpose(1, 2).contiguous(),
83 | [comp.transpose(1, 2).contiguous() for comp in
84 | wt_pre_batch[:-1]])).contiguous()
85 | pre_batch = pre_batch.transpose(1, 2) # shape: batch * n_sequence * n_attr
86 | pre_batch_set.append(pre_batch[:, self.opt.minibatch_len-self.opt.pre_len, :])
87 | if self.net.__class__.__name__ == 'WTFTP' and j == 0:
88 | all_score_set.append(score_set)
89 | pre_batch_set = torch.stack(pre_batch_set, dim=1) # shape: batch * pre_len * n_attr
90 | pre_set.append(pre_batch_set)
91 |
92 | tgt_set = torch.cat(tgt_set, dim=0)
93 | pre_set = torch.cat(pre_set, dim=0)
94 | # try:
95 | # all_score_set = torch.cat(all_score_set, dim=0)
96 | # except:
97 | # all_score_set = ['no scores']
98 | # his_batch_set = torch.cat(his_batch_set, dim=0)
99 | # torch.save([his_batch_set, tgt_set, pre_set, all_score_set],
100 | # f'{self.net.args["train_opt"].comments}_his_tgt_pre_score.pt', _use_new_zipfile_serialization=False)
101 | for i in range(self.opt.pre_len):
102 | avemse = float(self.MSE(tgt_set[:, :i + 1, :], pre_set[:, :i + 1, :]).cpu())
103 | avemae = float(self.MAE(tgt_set[:, :i + 1, :], pre_set[:, :i + 1, :]).cpu())
104 | rmse = {}
105 | mae = {}
106 | mre = {}
107 | for j, name in enumerate(self.data_set.attr_names):
108 | rmse[name] = float(self.MSE(self.data_set.unscale(tgt_set[:, :i + 1, j], name),
109 | self.data_set.unscale(pre_set[:, :i + 1, j], name)).sqrt().cpu())
110 | mae[name] = float(self.MAE(self.data_set.unscale(tgt_set[:, :i + 1, j], name),
111 | self.data_set.unscale(pre_set[:, :i + 1, j], name)).cpu())
112 | logit = self.data_set.unscale(tgt_set[:, :i + 1, j], name) != 0
113 | mre[name] = float(torch.mean(torch.abs(self.data_set.unscale(tgt_set[:, :i + 1, j], name)-
114 | self.data_set.unscale(pre_set[:, :i + 1, j], name))[logit]/
115 | self.data_set.unscale(tgt_set[:, :i + 1, j], name)[logit]).cpu()) * 100 if name in 'lonlatalt' \
116 | else "N/A"
117 | lon = self.data_set.unscale(pre_set[:, :i + 1, 0], 'lon').cpu().numpy()
118 | lat = self.data_set.unscale(pre_set[:, :i + 1, 1], 'lat').cpu().numpy()
119 | alt = self.data_set.unscale(pre_set[:, :i + 1, 2], 'alt').cpu().numpy() / 100 # km
120 | X, Y, Z = self.gc2ecef(lon, lat, alt)
121 | lon_t = self.data_set.unscale(tgt_set[:, :i + 1, 0], 'lon').cpu().numpy()
122 | lat_t = self.data_set.unscale(tgt_set[:, :i + 1, 1], 'lat').cpu().numpy()
123 | alt_t = self.data_set.unscale(tgt_set[:, :i + 1, 2], 'alt').cpu().numpy() / 100 # km
124 | X_t, Y_t, Z_t = self.gc2ecef(lon_t, lat_t, alt_t)
125 | MDE = np.mean(np.sqrt((X - X_t) ** 2 + (Y - Y_t) ** 2 + (Z - Z_t) ** 2))
126 | print_str = f'\nStep {i + 1}: \naveMSE(scaled): {avemse:.8f}, in each attr(RMSE, unscaled): {rmse}\n' \
127 | f'aveMAE(scaled): {avemae:.8f}, in each attr(MAE, unscaled): {mae}\n' \
128 | f'In each attr(MRE, %): {mre}\n' \
129 | f'MDE(unscaled): {MDE:.8f}\n'
130 | print(print_str)
131 | logging.debug(print_str)
132 |
133 |     def gc2ecef(self, lon, lat, alt):  # geodetic (deg, deg, km) -> ECEF coordinates (km), WGS-84 ellipsoid
134 |         a = 6378.137  # semi-major axis, km
135 |         b = 6356.752  # semi-minor axis, km
136 | lat = np.radians(lat)
137 | lon = np.radians(lon)
138 | e_square = 1 - (b ** 2) / (a ** 2)
139 | N = a / np.sqrt(1 - e_square * (np.sin(lat) ** 2))
140 | X = (N + alt) * np.cos(lat) * np.cos(lon)
141 | Y = (N + alt) * np.cos(lat) * np.sin(lon)
142 | Z = ((b ** 2) / (a ** 2) * N + alt) * np.sin(lat)
143 | return X, Y, Z
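    # Sanity check (illustrative): gc2ecef(0.0, 0.0, 0.0) should return roughly
    # (6378.137, 0.0, 0.0) km, i.e. the WGS-84 equatorial radius along the X axis.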
144 |
145 | def draw_demo(self, items=None, realtime=False):
146 | plt.rcParams['font.family'] = 'arial'
147 | plt.rcParams['font.size'] = 16
148 | total_num = len(self.data_set.test_set)
149 | test_data = DataLoader(dataset=self.data_set.test_set, batch_size=self.opt.batch_size, shuffle=False,
150 | collate_fn=self.data_set.collate)
151 | idwt = DWT1DInverse(wave=self.net.args['train_opt'].wavelet, mode=self.net.args['train_opt'].wt_mode).to(
152 | self.device)
153 | if items is None:
154 |             items = [random.randint(0, total_num - 1)]
155 | elif type(items) is int:
156 | items = [items % total_num]
157 | elif type(items) is list and len(items) > 0 and type(items[0]) is int:
158 | pass
159 | else:
160 |             raise TypeError(type(items))
161 | if realtime:
162 | items = [int(input("item: "))]
163 | while len(items) > 0 and items[0] > 0:
164 | item = items[0]
165 | del items[0]
166 | n_batch = item // self.opt.batch_size
167 | n_minibatch = item % self.opt.batch_size
168 | sel_batch = None
169 | for i, batch in enumerate(test_data):
170 | if i == n_batch:
171 | sel_batch = batch
172 | break
173 | traj = sel_batch[n_minibatch:n_minibatch + 1, ...]
174 | with torch.no_grad():
175 | self.net.eval()
176 | self.net.to(self.device)
177 | inp_batch = torch.FloatTensor(traj[:, :self.opt.minibatch_len-self.opt.pre_len, :]).to(
178 | self.device) # shape: 1 * his_len * n_attr
179 | pre_batch_set = []
180 | full_pre_set = []
181 | for j in range(self.opt.pre_len):
182 | if j > 0:
183 | new_batch = pre_batch[:, self.opt.minibatch_len-self.opt.pre_len, :].unsqueeze(1)
184 | inp_batch = torch.cat((inp_batch[:, 1:, :], new_batch), dim=1) # shape: batch * his_len * n_attr
185 | if self.net.__class__.__name__ == 'WTFTP':
186 | wt_pre_batch, score_set = self.net(inp_batch)
187 | else:
188 | wt_pre_batch = self.net(inp_batch)
189 | if j == 0:
190 | first_wt_pre = wt_pre_batch
191 | pre_batch = idwt((wt_pre_batch[-1].transpose(1, 2).contiguous(),
192 | [comp.transpose(1, 2).contiguous() for comp in
193 | wt_pre_batch[:-1]])).contiguous()
194 | pre_batch = pre_batch.transpose(1, 2) # shape: 1 * n_sequence * n_attr
195 | pre_batch_set.append(pre_batch[:, self.opt.minibatch_len-self.opt.pre_len, :])
196 | full_pre_set.append(pre_batch[:, :self.opt.minibatch_len-self.opt.pre_len + 1, :].clone())
197 | pre_batch_set = torch.stack(pre_batch_set, dim=1) # shape: 1 * pre_len * n_attr
198 |
199 | lla_his = np.array(traj[0, :self.opt.minibatch_len-self.opt.pre_len, 0:3]) # shape: his_len * n_attr
200 | lla_trg = np.array(traj[0, -self.opt.pre_len:, 0:3]) # shape: pre_len * n_attr
201 | lla_pre = np.array(pre_batch_set[0, :, 0:3].cpu().numpy()) # shape: pre_len * n_attr
202 | for i, name in enumerate(self.data_set.attr_names):
203 | if i > 2:
204 | break
205 | lla_his[:, i] = self.data_set.unscale(lla_his[:, i], name)
206 | lla_trg[:, i] = self.data_set.unscale(lla_trg[:, i], name)
207 | lla_pre[:, i] = self.data_set.unscale(lla_pre[:, i], name)
208 |
209 | fig = plt.figure(figsize=(9,9))
210 | elev_azim_set = [[90, 0], [0, 0], [0, 90], [None, None]] # represent top view, lateral view(lat), lateral view(lon) and default, respectively
211 | for i, elev_azim in enumerate(elev_azim_set):
212 | ax = fig.add_subplot(2, 2, i + 1, projection='3d')
213 | ax.view_init(elev=elev_azim[0], azim=elev_azim[1])
214 | ax.plot3D(lla_his[:, 0], lla_his[:, 1], lla_his[:, 2], marker='o', markeredgecolor='dodgerblue',
215 | label='his')
216 | ax.plot3D(lla_trg[:, 0], lla_trg[:, 1], lla_trg[:, 2], marker='*', markeredgecolor='blueviolet',
217 | label='tgt')
218 | ax.plot3D(lla_pre[:, 0], lla_pre[:, 1], lla_pre[:, 2], marker='p', markeredgecolor='orangered',
219 | label='pre')
220 | ax.set_xlabel('lon')
221 | ax.set_ylabel('lat')
222 | ax.set_zlabel('alt')
223 | ax.set_zlim(min(lla_his[:, 2]) - 20, max(lla_his[:, 2]) + 20)
224 | plt.suptitle(f'item_{item}')
225 | ax.legend()
226 | plt.tight_layout()
227 | plt.show()
228 | if realtime:
229 | items.append(int(input("item: ")))
230 |
231 | if __name__ == '__main__':
232 | opt = parser.parse_args()
233 | test = Test(opt)
234 | test.load_model(opt.netdir)
235 | test.test()
236 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: zheng zhang
3 | Email: musevr.ae@gmail.com
4 | """
5 | import torch
6 | from torch import nn
7 |
8 | class WTFTP_AttnRemoved(nn.Module):
9 | def __init__(self, n_inp, n_oup, n_embding=128, n_encoderLayers=4, n_decoderLayers=4, proj='linear', activation='relu', maxlevel=1, en_dropout=0,
10 | de_dropout=0, steps=None):
11 | super(WTFTP_AttnRemoved, self).__init__()
12 | self.args = locals()
13 | self.n_oup = n_oup
14 | self.maxlevel = maxlevel
15 | self.n_decoderLayers = n_decoderLayers
16 | self.steps = steps if steps is not None else [5,3,2]
17 | if activation == 'relu':
18 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding//2, bias=False, proj=proj), nn.ReLU(), Embed(n_embding//2, n_embding, bias=False, proj=proj), nn.ReLU())
19 | elif activation == 'sigmoid':
20 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding//2, bias=False, proj=proj), nn.Sigmoid(), Embed(n_embding//2, n_embding, bias=False, proj=proj), nn.Sigmoid())
21 | else:
22 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding//2, bias=False, proj=proj), nn.ReLU(), Embed(n_embding//2, n_embding, bias=False, proj=proj), nn.ReLU())
23 |
24 | self.encoder = nn.LSTM(input_size=n_embding,
25 | hidden_size=n_embding,
26 | num_layers=n_encoderLayers,
27 | bidirectional=False,
28 | batch_first=True,
29 | dropout=en_dropout,
30 | bias=False)
31 | self.decoders = nn.ModuleList([nn.LSTM(input_size=n_embding,
32 | hidden_size=n_embding,
33 | num_layers=n_decoderLayers,
34 | bidirectional=False,
35 | batch_first=True,
36 | dropout=de_dropout,
37 | bias=False) for _ in range(1+maxlevel)])
38 | self.LNs = nn.ModuleList([nn.LayerNorm(n_embding) for _ in range(1+maxlevel)])
39 | self.oup_embdings = nn.ModuleList([Embed(n_embding, n_oup, bias=True, proj=proj) for _ in range(1+maxlevel)])
40 |
41 | def forward(self, inp):
42 | """
43 | :param inp: shape: batch * n_sequence * n_attr
44 |         :return: list of (1 + maxlevel) tensors, each of shape batch * n_desiredLength * n_attr
45 | """
46 | inp = self.inp_embding(inp)
47 | _, (h_0, c_0) = self.encoder(inp)
48 | oup = []
49 | n_batch = inp.shape[0]
50 | i = 0
51 | for decoder, LN, oup_embding in zip(self.decoders, self.LNs, self.oup_embdings):
52 | this_decoder_oup = []
53 | decoder_first_inp = torch.zeros((n_batch, 1, inp.shape[-1]), device=inp.device)
54 | de_H = torch.zeros(self.n_decoderLayers, inp.shape[0], inp.shape[-1], dtype=torch.float,
55 | device=inp.device)
56 | de_C = torch.zeros(self.n_decoderLayers, inp.shape[0], inp.shape[-1], dtype=torch.float,
57 | device=inp.device)
58 | de_C[0, :, :] = c_0[-1, :, :].clone()
59 | decoder_hidden = (de_H, de_C)
60 | if i == self.maxlevel:
61 | this_step = self.steps[self.maxlevel-1]
62 | else:
63 | this_step = self.steps[i]
64 | for _ in range(this_step):
65 | decoder_first_inp, decoder_hidden = decoder(decoder_first_inp, decoder_hidden)
66 | de_oup = LN(decoder_first_inp)
67 | de_oup = oup_embding(de_oup)
68 | this_decoder_oup.append(de_oup)
69 | this_decoder_oup = torch.cat(this_decoder_oup, dim=1) # shape: batch * n_desiredLength * n_attr
70 | oup.append(this_decoder_oup)
71 | i += 1
72 | return oup
73 |
74 |
75 | class WTFTP(nn.Module):
76 | def __init__(self, n_inp, n_oup, his_step, n_embding=64, en_layers=4, de_layers=1, proj='linear',
77 | activation='relu', maxlevel=1, en_dropout=0.,
78 | de_dropout=0., out_split=False, decoder_init_zero=False, bias=False, se_skip=True,
79 | attn_conv_params=None):
80 |         """
81 |         :param n_inp: input attribute dimension
82 |         :param n_oup: output attribute dimension
83 |         :param his_step: length of the input (historical) sequence
84 |         :param n_embding: embedding dimension
85 |         :param en_layers / de_layers: number of LSTM layers in the encoder / decoder
86 |         :param maxlevel: level of wavelet analysis
87 |         :param en_dropout / de_dropout: encoder / decoder dropout probability
88 |         :param out_split: whether to split the output embedding per attribute
89 |         :param bias: whether to use bias in the output projections
90 |         """
91 | super(WTFTP, self).__init__()
92 | self.args = locals()
93 | self.n_embding = n_embding
94 | self.n_oup = n_oup
95 | self.maxlevel = maxlevel
96 | self.decoder_init_zero = decoder_init_zero
97 | self.de_layers = de_layers
98 | attn_conv_params = {0: {'stride': 2, 'kernel': 2, 'pad': 1}, 1: {'stride': 3, 'kernel': 3, 'pad': 0},
99 | 2: {'stride': 5, 'kernel': 5, 'pad': 1}} if attn_conv_params == None else attn_conv_params
100 | if activation == 'relu':
101 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding // 2, bias=False, proj=proj), nn.ReLU(),
102 | Embed(n_embding // 2, n_embding, bias=False, proj=proj), nn.ReLU())
103 | elif activation == 'sigmoid':
104 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding // 2, bias=False, proj=proj), nn.Sigmoid(),
105 | Embed(n_embding // 2, n_embding, bias=False, proj=proj), nn.Sigmoid())
106 | else:
107 | self.inp_embding = nn.Sequential(Embed(n_inp, n_embding // 2, bias=False, proj=proj), nn.ReLU(),
108 | Embed(n_embding // 2, n_embding, bias=False, proj=proj), nn.ReLU())
109 | self.encoder = nn.LSTM(input_size=n_embding,
110 | hidden_size=n_embding,
111 | num_layers=en_layers,
112 | bidirectional=False,
113 | batch_first=True,
114 | dropout=en_dropout,
115 | bias=False)
116 | self.Wt_attns = nn.ModuleList([WaveletAttention(hid_dim=n_embding, init_steps=his_step, skip=se_skip,
117 | kernel=attn_conv_params[i]['kernel'],
118 | stride=attn_conv_params[i]['stride'],
119 | padding=attn_conv_params[i]['pad'])
120 | for i in range(maxlevel)] +
121 | [WaveletAttention(hid_dim=n_embding, init_steps=his_step, skip=se_skip,
122 | kernel=attn_conv_params[maxlevel - 1]['kernel'],
123 | stride=attn_conv_params[maxlevel - 1]['stride'],
124 | padding=attn_conv_params[maxlevel - 1]['pad'])])
125 | self.decoders = nn.ModuleList([nn.LSTM(input_size=n_embding,
126 | hidden_size=n_embding,
127 | num_layers=de_layers,
128 | bidirectional=False,
129 | batch_first=True,
130 | dropout=de_dropout,
131 | bias=False) for _ in range(1 + maxlevel)])
132 | self.LNs = nn.ModuleList([nn.LayerNorm(n_embding) for _ in range(1 + maxlevel)])
133 | if out_split:
134 | assert n_embding % n_oup == 0
135 | self.out_split = out_split
136 | self.oup_embdings = nn.ModuleList(
137 | [nn.Linear(n_embding // n_oup, 1, bias=bias) for _ in range(n_oup * 2 ** maxlevel)])
138 | else:
139 | self.oup_embdings = nn.ModuleList(
140 | [Embed(n_embding, n_oup, bias=True, proj=proj) for _ in range(1 + maxlevel)])
141 |
142 | def forward(self, inp):
143 | """
144 | :param inp: shape: batch * n_sequence * n_attr
145 |         :return: coef_set: shape: batch * steps * levels * n_oup, where steps is the
146 |                  coefficient length of the wavelet decomposition;
147 |                  all_scores_set: shape: batch * steps * levels * n_sequence, n_sequence here is timeSteps of inp
148 | """
149 | embdings = self.inp_embding(inp) # batch * n_sequence * n_embding
150 | en_oup, (h_en, c_en) = self.encoder(embdings)
151 | all_de_oup_set = []
152 | all_scores_set = []
153 | for i, (attn, decoder, LN) in enumerate(zip(self.Wt_attns, self.decoders, self.LNs)):
154 | if self.decoder_init_zero:
155 | de_HC = None
156 | else:
157 | de_H = torch.zeros(self.de_layers, embdings.shape[0], embdings.shape[-1], dtype=torch.float,
158 | device=embdings.device)
159 | de_C = torch.zeros(self.de_layers, embdings.shape[0], embdings.shape[-1], dtype=torch.float,
160 | device=embdings.device)
161 | de_C[0, :, :] = c_en[-1, :, :].clone()
162 | de_HC = (de_H, de_C)
163 | de_inp, weight = attn(en_oup)
164 | de_oup_set, _ = decoder(de_inp, de_HC) # shape: batch * steps * n_embding
165 | all_de_oup_set.append(LN(de_oup_set))
166 | all_scores_set.append(weight)
167 | all_scores_set = torch.cat(all_scores_set, dim=1) # shape: batch * steps * levels * n_sequence
168 | if hasattr(self, 'out_split'):
169 | all_de_oup_set = torch.cat(all_de_oup_set, dim=2) # shape: batch * steps * (levels*n_embding)
170 | split_all_de_oup_set = torch.split(all_de_oup_set, split_size_or_sections=self.n_embding // self.n_oup,
171 | dim=-1)
172 | coef_set = []
173 | for i, linear in enumerate(self.oup_embdings):
174 | coef_set.append(
175 | linear(split_all_de_oup_set[i])
176 | ) # shape: batch * steps * 1
177 | coef_set = torch.cat(coef_set, dim=-1).reshape(inp.shape[0], -1, 2 ** self.maxlevel,
178 | self.n_oup) # shape: batch * steps * levels * n_oup
179 | else:
180 | coef_set = []
181 | for i, linear in enumerate(self.oup_embdings):
182 | coef_set.append(
183 | linear(all_de_oup_set[i])
184 | ) # shape: batch * steps * n_oup
185 | return coef_set, all_scores_set
186 |
187 | class Embed(nn.Module):
188 | def __init__(self, in_features, out_features, bias=True, proj='linear'):
189 | super(Embed, self).__init__()
190 | self.proj = proj
191 | if proj == 'linear':
192 | self.embed = nn.Linear(in_features, out_features, bias)
193 | else:
194 | self.embed = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=3, stride=1,
195 | padding=1,
196 | padding_mode='replicate', bias=bias)
197 |
198 | def forward(self, inp):
199 | """
200 | inp: B * T * D
201 | """
202 | if self.proj == 'linear':
203 | inp = self.embed(inp)
204 | else:
205 | inp = self.embed(inp.transpose(1, 2)).transpose(1, 2)
206 | return inp
207 |
208 | class WaveletAttention(nn.Module):
209 | def __init__(self, hid_dim, init_steps, skip=True, kernel=3, stride=3, padding=0):
210 | super(WaveletAttention, self).__init__()
211 | self.enhance = EnhancedBlock(hid_size=hid_dim, channel=init_steps, skip=skip)
212 | self.convs = nn.Sequential(nn.Conv1d(in_channels=hid_dim,
213 | out_channels=hid_dim,
214 | kernel_size=kernel,
215 | stride=stride, padding=padding,
216 | padding_mode='zeros',
217 | bias=False),
218 | nn.ReLU())
219 |
220 | def forward(self, hid_set: torch.Tensor):
221 | """
222 |         Reweight the hidden-state sequence with the SE-style enhanced block, then
223 |         align it to the target coefficient length with a strided convolution.
224 |         :param hid_set: encoder hidden states, shape: batch * n_sequence * hid_dim
225 |         :return: align: shape: batch * DestStep * hid_dim,
226 |                  weight: attention scores, shape: batch * 1 * n_sequence
227 |         """
228 | reweighted, weight = self.enhance(hid_set)
229 | align = self.convs(reweighted.transpose(1, 2)).transpose(1, 2) # B * DestStep * D
230 | return align, weight
231 |
232 |
233 | class EnhancedBlock(nn.Module):
234 | def __init__(self, hid_size, channel, skip=True):
235 | super(EnhancedBlock, self).__init__()
236 | self.skip = skip
237 | self.comp = nn.Sequential(nn.Linear(hid_size, hid_size // 2, bias=False),
238 | nn.ReLU(),
239 | nn.Linear(hid_size // 2, 1, bias=False))
240 | self.activate = nn.Sequential(nn.Linear(channel, channel // 2, bias=False),
241 | nn.ReLU(),
242 | nn.Linear(channel // 2, channel, bias=False),
243 | nn.Sigmoid())
244 |
245 | def forward(self, inp):
246 | S = self.comp(inp) # B * T * 1
247 | E = self.activate(S.transpose(1, 2)) # B * 1 * T
248 | out = inp * E.transpose(1, 2).expand_as(inp)
249 | if self.skip:
250 | out += inp
251 | return out, E # B * T * D
252 |
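# EnhancedBlock usage sketch (illustrative, using the shapes WTFTP feeds it):
#
#   blk = EnhancedBlock(hid_size=64, channel=9)
#   h = torch.rand(32, 9, 64)    # batch * n_sequence * hid_dim
#   out, gates = blk(h)          # out: 32 * 9 * 64, gates: 32 * 1 * 9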
253 | if __name__ == '__main__':
254 |     model = WTFTP(6, 6, 9)           # attention-based variant (overwritten below)
255 |     model = WTFTP_AttnRemoved(6, 6)  # the demo runs the attention-free variant
256 |     inp = torch.rand((100, 9, 6))    # batch * n_sequence * n_attr
257 |     out, _ = model(inp)              # unpack the detail and approximation outputs (maxlevel=1)
258 | print(model.__class__.__name__)
259 | print(out)
260 | print(model.args)
261 |
--------------------------------------------------------------------------------
/pics/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pics/framework.png
--------------------------------------------------------------------------------
/pytorch_wavelets/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | '__version__',
3 | 'DTCWTForward',
4 | 'DTCWTInverse',
5 | 'DWTForward',
6 | 'DWTInverse',
7 | 'DWT1DForward',
8 | 'DWT1DInverse',
9 | 'DTCWT',
10 | 'IDTCWT',
11 | 'DWT',
12 | 'IDWT',
13 | 'DWT1D',
14 | 'DWT2D',
15 | 'IDWT1D',
16 | 'IDWT2D',
17 | 'ScatLayer',
18 | 'ScatLayerj2'
19 | ]
20 |
21 | from pytorch_wavelets._version import __version__
22 | from pytorch_wavelets.dtcwt.transform2d import DTCWTForward, DTCWTInverse
23 | from pytorch_wavelets.dwt.transform2d import DWTForward, DWTInverse
24 | from pytorch_wavelets.dwt.transform1d import DWT1DForward, DWT1DInverse
25 | from pytorch_wavelets.scatternet import ScatLayer, ScatLayerj2
26 |
27 | # Some aliases
28 | DTCWT = DTCWTForward
29 | IDTCWT = DTCWTInverse
30 | DWT = DWTForward
31 | IDWT = DWTInverse
32 | DWT2D = DWT
33 | IDWT2D = IDWT
34 |
35 | DWT1D = DWT1DForward
36 | IDWT1D = DWT1DInverse
37 |
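# Minimal 1-D DWT round-trip sketch (illustrative; the (batch, channels, length)
# layout and keyword names follow the classes imported above):
#
#   import torch
#   xfm = DWT1DForward(J=1, wave='haar', mode='symmetric')
#   ifm = DWT1DInverse(wave='haar', mode='symmetric')
#   x = torch.randn(8, 6, 10)     # batch * channels * length
#   yl, yh = xfm(x)               # yl: approximation coeffs, yh: list of detail coeffs
#   x_rec = ifm((yl, yh))         # reconstructs x (up to boundary effects)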
--------------------------------------------------------------------------------
/pytorch_wavelets/_version.py:
--------------------------------------------------------------------------------
1 | # IMPORTANT: before release, remove the 'devN' tag from the release name
2 | __version__ = '1.3.0'
3 |
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Provide low-level torch accelerated operations. This backend requires that
3 | torch be installed. Works best with a GPU but still offers good
4 | improvements with a CPU.
5 |
6 | """
7 |
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/coeffs.py:
--------------------------------------------------------------------------------
1 | """Functions to load standard wavelet coefficients.
2 |
3 | """
4 | from __future__ import absolute_import
5 |
6 | from numpy import load
7 | from pkg_resources import resource_stream
8 | try:
9 | import pywt
10 | _HAVE_PYWT = True
11 | except ImportError:
12 | _HAVE_PYWT = False
13 |
14 | COEFF_CACHE = {}
15 |
16 |
17 | def _load_from_file(basename, varnames):
18 |
19 | try:
20 | mat = COEFF_CACHE[basename]
21 | except KeyError:
22 | with resource_stream('pytorch_wavelets.dtcwt.data', basename + '.npz') as f:
23 | mat = dict(load(f))
24 | COEFF_CACHE[basename] = mat
25 |
26 | try:
27 | return tuple(mat[k] for k in varnames)
28 | except KeyError:
29 | raise ValueError(
30 | 'Wavelet does not define ({0}) coefficients'.format(
31 | ', '.join(varnames)))
32 |
33 |
34 | def biort(name):
35 |     """ Deprecated. Use :py:func:`pytorch_wavelets.dtcwt.coeffs.level1`
36 |     instead.
37 |     """
38 | return level1(name, compact=True)
39 |
40 |
41 | def level1(name, compact=False):
42 | """Load level 1 wavelet by name.
43 |
44 | :param name: a string specifying the wavelet family name
45 | :returns: a tuple of vectors giving filter coefficients
46 |
47 |     ============= ============================================
48 |     Name          Wavelet
49 |     ============= ============================================
50 |     antonini      Antonini 9,7 tap filters.
51 |     farras        Farras 8,8 tap filters
52 |     legall        LeGall 5,3 tap filters.
53 |     near_sym_a    Near-Symmetric 5,7 tap filters.
54 |     near_sym_b    Near-Symmetric 13,19 tap filters.
55 |     near_sym_b_bp Near-Symmetric 13,19 tap filters + BP filter
56 |     ============= ============================================
57 |
58 | Return a tuple whose elements are a vector specifying the h0o, g0o, h1o and
59 | g1o coefficients.
60 |
61 | See :ref:`rot-symm-wavelets` for an explanation of the ``near_sym_b_bp``
62 | wavelet filters.
63 |
64 | :raises IOError: if name does not correspond to a set of wavelets known to
65 | the library.
66 | :raises ValueError: if name doesn't specify
67 | :py:func:`pytorch_wavelets.dtcwt.coeffs.qshift` wavelet.
68 |
69 | """
70 | if compact:
71 | if name == 'near_sym_b_bp':
72 | return _load_from_file(name, ('h0o', 'g0o', 'h1o', 'g1o', 'h2o', 'g2o'))
73 | else:
74 | return _load_from_file(name, ('h0o', 'g0o', 'h1o', 'g1o'))
75 | else:
76 | return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b',
77 | 'g1a', 'g1b'))
78 |
79 |
80 | def qshift(name):
81 |     """Load level >=2 wavelet by name.
82 |
83 | :param name: a string specifying the wavelet family name
84 | :returns: a tuple of vectors giving filter coefficients
85 |
86 |     ============ ============================================
87 |     Name         Wavelet
88 |     ============ ============================================
89 |     qshift_06    Quarter Sample Shift Orthogonal (Q-Shift)
90 |                  10,10 tap filters, (only 6,6 non-zero taps).
91 |     qshift_a     Q-shift 10,10 tap filters,
92 |                  (with 10,10 non-zero taps, unlike qshift_06).
93 |     qshift_b     Q-Shift 14,14 tap filters.
94 |     qshift_c     Q-Shift 16,16 tap filters.
95 |     qshift_d     Q-Shift 18,18 tap filters.
96 |     qshift_b_bp  Q-Shift 18,18 tap filters + BP
97 |     ============ ============================================
98 |
99 | Return a tuple whose elements are a vector specifying the h0a, h0b, g0a,
100 | g0b, h1a, h1b, g1a and g1b coefficients.
101 |
102 | See :ref:`rot-symm-wavelets` for an explanation of the ``qshift_b_bp``
103 | wavelet filters.
104 |
105 | :raises IOError: if name does not correspond to a set of wavelets known to
106 | the library.
107 | :raises ValueError: if name doesn't specify a
108 | :py:func:`pytorch_wavelets.dtcwt.coeffs.biort` wavelet.
109 |
110 | """
111 | if name == 'qshift_b_bp':
112 | return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b',
113 | 'g1a', 'g1b', 'h2a', 'h2b', 'g2a','g2b'))
114 | else:
115 | return _load_from_file(name, ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b',
116 | 'g1a', 'g1b'))
117 |
118 |
119 | def pywt_coeffs(name):
120 | """ Wraps pywt Wavelet function. """
121 | if not _HAVE_PYWT:
122 | raise ImportError("Could not find PyWavelets module")
123 | return pywt.Wavelet(name)
124 |
125 | # vim:sw=4:sts=4:et
126 |
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/__init__.py
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/antonini.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/antonini.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/farras.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/farras.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/legall.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/legall.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/near_sym_a.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/near_sym_a.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/near_sym_a2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/near_sym_a2.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/near_sym_b.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/near_sym_b.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/near_sym_b_bp.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/near_sym_b_bp.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/qshift_06.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_06.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/qshift_32.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_32.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/qshift_a.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_a.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/qshift_b.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_b.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/qshift_b_bp.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_b_bp.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/qshift_c.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_c.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/data/qshift_d.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dtcwt/data/qshift_d.npz
--------------------------------------------------------------------------------
/pytorch_wavelets/dtcwt/lowlevel.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from pytorch_wavelets.utils import symm_pad_1d as symm_pad
7 |
8 |
9 | def as_column_vector(v):
10 | """Return *v* as a column vector with shape (N,1).
11 |
12 | """
13 | v = np.atleast_2d(v)
14 | if v.shape[0] == 1:
15 | return v.T
16 | else:
17 | return v
18 |
19 |
20 | def _as_row_vector(v):
21 | """Return *v* as a row vector with shape (1, N).
22 | """
23 | v = np.atleast_2d(v)
24 | if v.shape[0] == 1:
25 | return v
26 | else:
27 | return v.T
28 |
29 |
30 | def _as_row_tensor(h):
31 | if isinstance(h, torch.Tensor):
32 | h = torch.reshape(h, [1, -1])
33 | else:
34 | h = as_column_vector(h).T
35 | h = torch.tensor(h, dtype=torch.get_default_dtype())
36 | return h
37 |
38 |
39 | def _as_col_vector(v):
40 | """Return *v* as a column vector with shape (N,1).
41 | """
42 | v = np.atleast_2d(v)
43 | if v.shape[0] == 1:
44 | return v.T
45 | else:
46 | return v
47 |
48 |
49 | def _as_col_tensor(h):
50 | if isinstance(h, torch.Tensor):
51 | h = torch.reshape(h, [-1, 1])
52 | else:
53 | h = as_column_vector(h)
54 | h = torch.tensor(h, dtype=torch.get_default_dtype())
55 | return h
56 |
57 |
58 | def prep_filt(h, c, transpose=False):
59 | """ Prepares an array to be of the correct format for pytorch.
60 |     Can also specify whether to make it a row filter (set transpose=True)."""
61 | h = _as_col_vector(h)[::-1]
62 | h = h[None, None, :]
63 | h = np.repeat(h, repeats=c, axis=0)
64 | if transpose:
65 | h = h.transpose((0,1,3,2))
66 | h = np.copy(h)
67 | return torch.tensor(h, dtype=torch.get_default_dtype())
68 |
69 |
70 | def colfilter(X, h, mode='symmetric'):
71 |     if X is None or X.shape == torch.Size([]):  # empty placeholder band
72 |         return torch.zeros(1, 1, 1, 1, device=None if X is None else X.device)
73 | b, ch, row, col = X.shape
74 | m = h.shape[2] // 2
75 | if mode == 'symmetric':
76 | xe = symm_pad(row, m)
77 | X = F.conv2d(X[:,:,xe], h.repeat(ch,1,1,1), groups=ch)
78 | else:
79 | X = F.conv2d(X, h.repeat(ch, 1, 1, 1), groups=ch, padding=(m, 0))
80 | return X
81 |
82 |
83 | def rowfilter(X, h, mode='symmetric'):
84 |     if X is None or X.shape == torch.Size([]):  # empty placeholder band
85 |         return torch.zeros(1, 1, 1, 1, device=None if X is None else X.device)
86 | b, ch, row, col = X.shape
87 | m = h.shape[2] // 2
88 | h = h.transpose(2,3).contiguous()
89 | if mode == 'symmetric':
90 | xe = symm_pad(col, m)
91 | X = F.conv2d(X[:,:,:,xe], h.repeat(ch,1,1,1), groups=ch)
92 | else:
93 | X = F.conv2d(X, h.repeat(ch,1,1,1), groups=ch, padding=(0, m))
94 | return X
95 |
96 |
97 | def coldfilt(X, ha, hb, highpass=False, mode='symmetric'):
98 |     if X is None or X.shape == torch.Size([]):  # empty placeholder band
99 |         return torch.zeros(1, 1, 1, 1, device=None if X is None else X.device)
100 | batch, ch, r, c = X.shape
101 | r2 = r // 2
102 | if r % 4 != 0:
103 | raise ValueError('No. of rows in X must be a multiple of 4\n' +
104 | 'X was {}'.format(X.shape))
105 |
106 | if mode == 'symmetric':
107 | m = ha.shape[2]
108 | xe = symm_pad(r, m)
109 | X = torch.cat((X[:,:,xe[2::2]], X[:,:,xe[3::2]]), dim=1)
110 | h = torch.cat((ha.repeat(ch, 1, 1, 1), hb.repeat(ch, 1, 1, 1)), dim=0)
111 | X = F.conv2d(X, h, stride=(2,1), groups=ch*2)
112 | else:
113 | raise NotImplementedError()
114 |
115 | # Reshape result to be shape [Batch, ch, r/2, c]. This reshaping
116 |     # interleaves the rows from the two trees
117 | if highpass:
118 | X = torch.stack((X[:, ch:], X[:, :ch]), dim=-2).view(batch, ch, r2, c)
119 | else:
120 | X = torch.stack((X[:, :ch], X[:, ch:]), dim=-2).view(batch, ch, r2, c)
121 |
122 | return X
123 |
124 |
125 | def rowdfilt(X, ha, hb, highpass=False, mode='symmetric'):
126 |     if X is None or X.shape == torch.Size([]):  # empty placeholder band
127 |         return torch.zeros(1, 1, 1, 1, device=None if X is None else X.device)
128 | batch, ch, r, c = X.shape
129 | c2 = c // 2
130 | if c % 4 != 0:
131 | raise ValueError('No. of cols in X must be a multiple of 4\n' +
132 | 'X was {}'.format(X.shape))
133 |
134 | if mode == 'symmetric':
135 | m = ha.shape[2]
136 | xe = symm_pad(c, m)
137 | X = torch.cat((X[:,:,:,xe[2::2]], X[:,:,:,xe[3::2]]), dim=1)
138 | h = torch.cat((ha.reshape(1,1,1,m).repeat(ch, 1, 1, 1),
139 | hb.reshape(1,1,1,m).repeat(ch, 1, 1, 1)), dim=0)
140 | X = F.conv2d(X, h, stride=(1,2), groups=ch*2)
141 | else:
142 | raise NotImplementedError()
143 |
144 |     # Reshape result to be shape [Batch, ch, r, c/2]. This reshaping
145 |     # interleaves the columns from the two trees
146 | if highpass:
147 | Y = torch.stack((X[:, ch:], X[:, :ch]), dim=-1).view(batch, ch, r, c2)
148 | else:
149 | Y = torch.stack((X[:, :ch], X[:, ch:]), dim=-1).view(batch, ch, r, c2)
150 |
151 | return Y
152 |
153 |
154 | def colifilt(X, ha, hb, highpass=False, mode='symmetric'):
155 |     if X is None or X.shape == torch.Size([]):  # empty placeholder band
156 |         return torch.zeros(1, 1, 1, 1, device=None if X is None else X.device)
157 | m = ha.shape[2]
158 | m2 = m // 2
159 | hao = ha[:,:,1::2]
160 | hae = ha[:,:,::2]
161 | hbo = hb[:,:,1::2]
162 | hbe = hb[:,:,::2]
163 | batch, ch, r, c = X.shape
164 | if r % 2 != 0:
165 | raise ValueError('No. of rows in X must be a multiple of 2.\n' +
166 | 'X was {}'.format(X.shape))
167 | xe = symm_pad(r, m2)
168 |
169 | if m2 % 2 == 0:
170 | h1 = hae
171 | h2 = hbe
172 | h3 = hao
173 | h4 = hbo
174 | if highpass:
175 | X = torch.cat((X[:,:,xe[1:-2:2]], X[:,:,xe[:-2:2]], X[:,:,xe[3::2]], X[:,:,xe[2::2]]), dim=1)
176 | else:
177 | X = torch.cat((X[:,:,xe[:-2:2]], X[:,:,xe[1:-2:2]], X[:,:,xe[2::2]], X[:,:,xe[3::2]]), dim=1)
178 | else:
179 | h1 = hao
180 | h2 = hbo
181 | h3 = hae
182 | h4 = hbe
183 | if highpass:
184 | X = torch.cat((X[:,:,xe[2:-1:2]], X[:,:,xe[1:-1:2]], X[:,:,xe[2:-1:2]], X[:,:,xe[1:-1:2]]), dim=1)
185 | else:
186 | X = torch.cat((X[:,:,xe[1:-1:2]], X[:,:,xe[2:-1:2]], X[:,:,xe[1:-1:2]], X[:,:,xe[2:-1:2]]), dim=1)
187 | h = torch.cat((h1.repeat(ch, 1, 1, 1), h2.repeat(ch, 1, 1, 1),
188 | h3.repeat(ch, 1, 1, 1), h4.repeat(ch, 1, 1, 1)), dim=0)
189 |
190 | X = F.conv2d(X, h, groups=4*ch)
191 | # Stack 4 tensors of shape [batch, ch, r2, c] into one tensor
192 | # [batch, ch, r2, 4, c]
193 | X = torch.stack([X[:,:ch], X[:,ch:2*ch], X[:,2*ch:3*ch], X[:,3*ch:]], dim=3).view(batch, ch, r*2, c)
194 |
195 | return X
196 |
197 |
198 | def rowifilt(X, ha, hb, highpass=False, mode='symmetric'):
199 |     if X is None or X.shape == torch.Size([]):  # empty placeholder band
200 |         return torch.zeros(1, 1, 1, 1, device=None if X is None else X.device)
201 | m = ha.shape[2]
202 | m2 = m // 2
203 | hao = ha[:,:,1::2]
204 | hae = ha[:,:,::2]
205 | hbo = hb[:,:,1::2]
206 | hbe = hb[:,:,::2]
207 | batch, ch, r, c = X.shape
208 | if c % 2 != 0:
209 | raise ValueError('No. of cols in X must be a multiple of 2.\n' +
210 | 'X was {}'.format(X.shape))
211 | xe = symm_pad(c, m2)
212 |
213 | if m2 % 2 == 0:
214 | h1 = hae
215 | h2 = hbe
216 | h3 = hao
217 | h4 = hbo
218 | if highpass:
219 | X = torch.cat((X[:,:,:,xe[1:-2:2]], X[:,:,:,xe[:-2:2]], X[:,:,:,xe[3::2]], X[:,:,:,xe[2::2]]), dim=1)
220 | else:
221 | X = torch.cat((X[:,:,:,xe[:-2:2]], X[:,:,:,xe[1:-2:2]], X[:,:,:,xe[2::2]], X[:,:,:,xe[3::2]]), dim=1)
222 | else:
223 | h1 = hao
224 | h2 = hbo
225 | h3 = hae
226 | h4 = hbe
227 | if highpass:
228 | X = torch.cat((X[:,:,:,xe[2:-1:2]], X[:,:,:,xe[1:-1:2]], X[:,:,:,xe[2:-1:2]], X[:,:,:,xe[1:-1:2]]), dim=1)
229 | else:
230 | X = torch.cat((X[:,:,:,xe[1:-1:2]], X[:,:,:,xe[2:-1:2]], X[:,:,:,xe[1:-1:2]], X[:,:,:,xe[2:-1:2]]), dim=1)
231 | h = torch.cat((h1.repeat(ch, 1, 1, 1), h2.repeat(ch, 1, 1, 1),
232 | h3.repeat(ch, 1, 1, 1), h4.repeat(ch, 1, 1, 1)),
233 | dim=0).reshape(4*ch, 1, 1, m2)
234 |
235 | X = F.conv2d(X, h, groups=4*ch)
236 |     # Stack 4 tensors of shape [batch, ch, r, c2] into one tensor
237 |     # [batch, ch, r, c2, 4]
238 | X = torch.stack([X[:,:ch], X[:,ch:2*ch], X[:,2*ch:3*ch], X[:,3*ch:]], dim=4).view(batch, ch, r, c*2)
239 | return X
240 |
241 |
243 | def q2c(y, dim=-1):
244 | """
245 | Convert from quads in y to complex numbers in z.
246 | """
247 |
248 | # Arrange pixels from the corners of the quads into
249 | # 2 subimages of alternate real and imag pixels.
250 | # a----b
251 | # | |
252 | # | |
253 | # c----d
254 | # Combine (a,b) and (d,c) to form two complex subimages.
255 | y = y/np.sqrt(2)
256 | a, b = y[:,:, 0::2, 0::2], y[:,:, 0::2, 1::2]
257 | c, d = y[:,:, 1::2, 0::2], y[:,:, 1::2, 1::2]
258 |
259 | # return torch.stack((a-d, b+c), dim=dim), torch.stack((a+d, b-c), dim=dim)
260 | return ((a-d, b+c), (a+d, b-c))
261 |
262 |
263 | def c2q(w1, w2):
264 | """
265 |     Scale by gain and convert the two complex subimages w1 and w2 back to
266 |     real quad pixels (the inverse of q2c).
267 |
268 |     Arrange pixels from the real and imag parts of the 2 highpasses
269 |     into 4 separate subimages.
270 |     A----B Re Im of w1
271 |     | |
272 |     | |
273 |     C----D Re Im of w2
274 |
275 | """
276 | w1r, w1i = w1
277 | w2r, w2i = w2
278 |
279 | x1 = w1r + w2r
280 | x2 = w1i + w2i
281 | x3 = w1i - w2i
282 | x4 = -w1r + w2r
283 |
284 |     # Get the shape of the tensor excluding the real/imaginary part
285 | b, ch, r, c = w1r.shape
286 |
287 | # Create new empty tensor and fill it
288 | y = w1r.new_zeros((b, ch, r*2, c*2), requires_grad=w1r.requires_grad)
289 | y[:, :, ::2,::2] = x1
290 | y[:, :, ::2, 1::2] = x2
291 | y[:, :, 1::2, ::2] = x3
292 | y[:, :, 1::2, 1::2] = x4
293 | y /= np.sqrt(2)
294 |
295 | return y
296 |
--------------------------------------------------------------------------------
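A small sketch of how the primitives above compose (assuming torch and the
bundled coefficients are available). With 'symmetric' mode and the odd-length
level-1 filters, colfilter/rowfilter preserve spatial size; the level-1
bandpass outputs are only halved later, when q2c splits the quads into two
complex subimages:

    import torch
    from pytorch_wavelets.dtcwt.coeffs import biort as _biort
    from pytorch_wavelets.dtcwt.lowlevel import prep_filt, colfilter, rowfilter, q2c

    h0o, _, h1o, _ = _biort('near_sym_a')
    h0 = prep_filt(h0o, 1)   # (1, 1, L, 1) column filter, taps reversed for conv2d
    h1 = prep_filt(h1o, 1)

    x = torch.randn(1, 3, 64, 64)
    lo = rowfilter(x, h0)        # same spatial size under symmetric padding
    lh = colfilter(lo, h1)       # a bandpass subband, still 64x64
    (a, b), (c, d) = q2c(lh)     # two complex subimages, each 32x32
    print(lh.shape, a.shape)

--------------------------------------------------------------------------------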
/pytorch_wavelets/dtcwt/lowlevel2.py:
--------------------------------------------------------------------------------
1 | """ This module was part of an attempt to speed up the DTCWT. The code was
2 | ultimately slower than the original implementation, but it is a nice
3 | reference point for doing a DTCWT directly as 4 separate DWTs.
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import numpy as np
9 | from pytorch_wavelets.dwt.lowlevel import roll, mypad
10 | import pywt
11 | from pytorch_wavelets.dwt.transform2d import DWTForward, DWTInverse
12 | from pytorch_wavelets.dwt.lowlevel import afb2d, sfb2d_nonsep as sfb2d
13 | from pytorch_wavelets.dwt.lowlevel import prep_filt_afb2d, prep_filt_sfb2d_nonsep as prep_filt_sfb2d
14 | from pytorch_wavelets.dtcwt.coeffs import level1 as _level1, qshift as _qshift, biort as _biort
15 |
16 |
17 | class DTCWTForward2(nn.Module):
18 | """ DTCWT based on 4 DWTs. Still works, but the above implementation is
19 | faster """
20 | def __init__(self, biort='farras', qshift='qshift_a', J=3,
21 | mode='symmetric'):
22 | super().__init__()
23 | self.biort = biort
24 | self.qshift = qshift
25 | self.J = J
26 |
27 | if isinstance(biort, str):
28 | biort = _level1(biort)
29 | assert len(biort) == 8
30 | h0a1, h0b1, _, _, h1a1, h1b1, _, _ = biort
31 | DWTaa1 = DWTForward(J=1, wave=(h0a1, h1a1, h0a1, h1a1), mode=mode)
32 | DWTab1 = DWTForward(J=1, wave=(h0a1, h1a1, h0b1, h1b1), mode=mode)
33 | DWTba1 = DWTForward(J=1, wave=(h0b1, h1b1, h0a1, h1a1), mode=mode)
34 | DWTbb1 = DWTForward(J=1, wave=(h0b1, h1b1, h0b1, h1b1), mode=mode)
35 | self.level1 = nn.ModuleList([DWTaa1, DWTab1, DWTba1, DWTbb1])
36 |
37 | if J > 1:
38 | if isinstance(qshift, str):
39 | qshift = _qshift(qshift)
40 | assert len(qshift) == 8
41 | h0a, h0b, _, _, h1a, h1b, _, _ = qshift
42 | DWTaa = DWTForward(J-1, (h0a, h1a, h0a, h1a), mode=mode)
43 | DWTab = DWTForward(J-1, (h0a, h1a, h0b, h1b), mode=mode)
44 | DWTba = DWTForward(J-1, (h0b, h1b, h0a, h1a), mode=mode)
45 | DWTbb = DWTForward(J-1, (h0b, h1b, h0b, h1b), mode=mode)
46 | self.level2 = nn.ModuleList([DWTaa, DWTab, DWTba, DWTbb])
47 |
48 | def forward(self, x):
49 | x = x/2
50 | J = self.J
51 | w = [[[None for _ in range(2)] for _ in range(2)] for j in range(J)]
52 | lows = [[None for _ in range(2)] for _ in range(2)]
53 | for m in range(2):
54 | for n in range(2):
55 | # Do the first level transform
56 | ll, (w[0][m][n],) = self.level1[m*2+n](x)
57 | # w[0][m][n] = [bands[:,:,2], bands[:,:,1], bands[:,:,3]]
58 |
59 | # Do the second+ level transform with the second level filters
60 | if J > 1:
61 | ll, bands = self.level2[m*2+n](ll)
62 | for j in range(1,J):
63 | w[j][m][n] = bands[j-1]
64 | lows[m][n] = ll
65 |
66 | # Convert the quads into real and imaginary parts
67 | yh = [None,] * J
68 | for j in range(J):
69 | deg75r, deg105i = pm(w[j][0][0][:,:,1], w[j][1][1][:,:,1])
70 | deg105r, deg75i = pm(w[j][0][1][:,:,1], w[j][1][0][:,:,1])
71 | deg15r, deg165i = pm(w[j][0][0][:,:,0], w[j][1][1][:,:,0])
72 | deg165r, deg15i = pm(w[j][0][1][:,:,0], w[j][1][0][:,:,0])
73 | deg135r, deg45i = pm(w[j][0][0][:,:,2], w[j][1][1][:,:,2])
74 | deg45r, deg135i = pm(w[j][0][1][:,:,2], w[j][1][0][:,:,2])
75 | w[j] = None
76 | yhr = torch.stack((deg15r, deg45r, deg75r,
77 | deg105r, deg135r, deg165r), dim=1)
78 | yhi = torch.stack((deg15i, deg45i, deg75i,
79 | deg105i, deg135i, deg165i), dim=1)
80 | yh[j] = torch.stack((yhr, yhi), dim=-1)
81 |
82 | return lows, yh
83 |
84 |
85 | class DTCWTInverse2(nn.Module):
86 | def __init__(self, biort='farras', qshift='qshift_a',
87 | mode='symmetric'):
88 | super().__init__()
89 | self.biort = biort
90 | self.qshift = qshift
91 |
92 | if isinstance(biort, str):
93 | biort = _level1(biort)
94 | assert len(biort) == 8
95 | _, _, g0a1, g0b1, _, _, g1a1, g1b1 = biort
96 | IWTaa1 = DWTInverse(wave=(g0a1, g1a1, g0a1, g1a1), mode=mode)
97 | IWTab1 = DWTInverse(wave=(g0a1, g1a1, g0b1, g1b1), mode=mode)
98 | IWTba1 = DWTInverse(wave=(g0b1, g1b1, g0a1, g1a1), mode=mode)
99 | IWTbb1 = DWTInverse(wave=(g0b1, g1b1, g0b1, g1b1), mode=mode)
100 | self.level1 = nn.ModuleList([IWTaa1, IWTab1, IWTba1, IWTbb1])
101 |
102 | if isinstance(qshift, str):
103 | qshift = _qshift(qshift)
104 | assert len(qshift) == 8
105 | _, _, g0a, g0b, _, _, g1a, g1b = qshift
106 | IWTaa = DWTInverse(wave=(g0a, g1a, g0a, g1a), mode=mode)
107 | IWTab = DWTInverse(wave=(g0a, g1a, g0b, g1b), mode=mode)
108 | IWTba = DWTInverse(wave=(g0b, g1b, g0a, g1a), mode=mode)
109 | IWTbb = DWTInverse(wave=(g0b, g1b, g0b, g1b), mode=mode)
110 | self.level2 = nn.ModuleList([IWTaa, IWTab, IWTba, IWTbb])
111 |
112 | def forward(self, x):
113 | # Convert the highs back to subbands
114 | yl, yh = x
115 | J = len(yh)
116 | # w = [[[[None for i in range(3)] for j in range(2)]
117 | # for k in range(2)] for l in range(J)]
118 | w = [[[[None for band in range(3)] for j in range(J)]
119 | for m in range(2)] for n in range(2)]
120 | for j in range(J):
121 | w[0][0][j][0], w[1][1][j][0] = pm(
122 | yh[j][:,2,:,:,:,0], yh[j][:,3,:,:,:,1])
123 | w[0][1][j][0], w[1][0][j][0] = pm(
124 | yh[j][:,3,:,:,:,0], yh[j][:,2,:,:,:,1])
125 | w[0][0][j][1], w[1][1][j][1] = pm(
126 | yh[j][:,0,:,:,:,0], yh[j][:,5,:,:,:,1])
127 | w[0][1][j][1], w[1][0][j][1] = pm(
128 | yh[j][:,5,:,:,:,0], yh[j][:,0,:,:,:,1])
129 | w[0][0][j][2], w[1][1][j][2] = pm(
130 | yh[j][:,1,:,:,:,0], yh[j][:,4,:,:,:,1])
131 | w[0][1][j][2], w[1][0][j][2] = pm(
132 | yh[j][:,4,:,:,:,0], yh[j][:,1,:,:,:,1])
133 | w[0][0][j] = torch.stack(w[0][0][j], dim=2)
134 | w[0][1][j] = torch.stack(w[0][1][j], dim=2)
135 | w[1][0][j] = torch.stack(w[1][0][j], dim=2)
136 | w[1][1][j] = torch.stack(w[1][1][j], dim=2)
137 |
138 | y = None
139 | for m in range(2):
140 | for n in range(2):
141 | lo = yl[m][n]
142 | if J > 1:
143 | lo = self.level2[m*2+n]((lo, w[m][n][1:]))
144 | lo = self.level1[m*2+n]((lo, (w[m][n][0],)))
145 |
146 | # Add to the output
147 | if y is None:
148 | y = lo
149 | else:
150 | y = y + lo
151 |
152 | # Normalize
153 | y = y/2
154 | return y
155 |
156 |
157 | def prep_filt_quad_afb2d_nonsep(
158 | h0a_col, h1a_col, h0a_row, h1a_row,
159 | h0b_col, h1b_col, h0b_row, h1b_row,
160 | h0c_col, h1c_col, h0c_row, h1c_row,
161 | h0d_col, h1d_col, h0d_row, h1d_row, device=None):
162 | """
163 | Prepares the filters to be of the right form for the afb2d_nonsep function.
164 | In particular, makes 2d point spread functions, and mirror images them in
165 | preparation to do torch.conv2d.
166 |
167 |     Inputs:
168 |         h0{a,b,c,d}_col (array-like): low pass column filters, one per tree
169 |         h1{a,b,c,d}_col (array-like): high pass column filters, one per tree
170 |         h0{a,b,c,d}_row (array-like): low pass row filters, one per tree
171 |         h1{a,b,c,d}_row (array-like): high pass row filters, one per tree
172 |         device: which device to put the tensors on to
173 |
174 |     Returns:
175 |         filts: (16, 1, h, w) tensor of mirror-imaged 2d point spread
176 |             functions, ready to compute the four subbands of each of the
177 |             four trees
178 |     """
179 | lla = np.outer(h0a_col, h0a_row)
180 | lha = np.outer(h1a_col, h0a_row)
181 | hla = np.outer(h0a_col, h1a_row)
182 | hha = np.outer(h1a_col, h1a_row)
183 | llb = np.outer(h0b_col, h0b_row)
184 | lhb = np.outer(h1b_col, h0b_row)
185 | hlb = np.outer(h0b_col, h1b_row)
186 | hhb = np.outer(h1b_col, h1b_row)
187 | llc = np.outer(h0c_col, h0c_row)
188 | lhc = np.outer(h1c_col, h0c_row)
189 | hlc = np.outer(h0c_col, h1c_row)
190 | hhc = np.outer(h1c_col, h1c_row)
191 | lld = np.outer(h0d_col, h0d_row)
192 | lhd = np.outer(h1d_col, h0d_row)
193 | hld = np.outer(h0d_col, h1d_row)
194 | hhd = np.outer(h1d_col, h1d_row)
195 | filts = np.stack([lla[None,::-1,::-1], llb[None,::-1,::-1],
196 | llc[None,::-1,::-1], lld[None,::-1,::-1],
197 | lha[None,::-1,::-1], lhb[None,::-1,::-1],
198 | lhc[None,::-1,::-1], lhd[None,::-1,::-1],
199 | hla[None,::-1,::-1], hlb[None,::-1,::-1],
200 | hlc[None,::-1,::-1], hld[None,::-1,::-1],
201 | hha[None,::-1,::-1], hhb[None,::-1,::-1],
202 | hhc[None,::-1,::-1], hhd[None,::-1,::-1]],
203 | axis=0)
204 | filts = torch.tensor(filts, dtype=torch.get_default_dtype(), device=device)
205 | return filts
206 |
207 |
208 | def prep_filt_quad_afb2d(h0a, h1a, h0b, h1b, device=None):
209 | """
210 | Prepares the filters to be of the right form for the quad_afb2d function.
211 |
212 |     Inputs:
213 |         h0a (array-like): tree a low pass filter
214 |         h1a (array-like): tree a high pass filter
215 |         h0b (array-like): tree b low pass filter
216 |         h1b (array-like): tree b high pass filter
217 |         device: which device to put the tensors on to
218 |
219 |     Returns:
220 |         (cols, rows): stacked column filter tensor of shape (8, 1, L, 1)
221 |             and row filter tensor of shape (16, 1, 1, L), ready for the
222 |             grouped convolutions in quad_afb2d
223 |     """
224 | h0a_col = np.array(h0a).ravel()[::-1][None, :, None]
225 | h1a_col = np.array(h1a).ravel()[::-1][None, :, None]
226 | h0b_col = np.array(h0a).ravel()[::-1][None, :, None]
227 | h1b_col = np.array(h1a).ravel()[::-1][None, :, None]
228 | h0c_col = np.array(h0b).ravel()[::-1][None, :, None]
229 | h1c_col = np.array(h1b).ravel()[::-1][None, :, None]
230 | h0d_col = np.array(h0b).ravel()[::-1][None, :, None]
231 | h1d_col = np.array(h1b).ravel()[::-1][None, :, None]
232 | h0a_row = np.array(h0a).ravel()[::-1][None, None, :]
233 | h1a_row = np.array(h1a).ravel()[::-1][None, None, :]
234 | h0b_row = np.array(h0b).ravel()[::-1][None, None, :]
235 | h1b_row = np.array(h1b).ravel()[::-1][None, None, :]
236 | h0c_row = np.array(h0a).ravel()[::-1][None, None, :]
237 | h1c_row = np.array(h1a).ravel()[::-1][None, None, :]
238 | h0d_row = np.array(h0b).ravel()[::-1][None, None, :]
239 | h1d_row = np.array(h1b).ravel()[::-1][None, None, :]
240 | cols = np.stack((h0a_col, h1a_col,
241 | h0b_col, h1b_col,
242 | h0c_col, h1c_col,
243 | h0d_col, h1d_col), axis=0)
244 | rows = np.stack((h0a_row, h1a_row,
245 | h0a_row, h1a_row,
246 | h0b_row, h1b_row,
247 | h0b_row, h1b_row,
248 | h0c_row, h1c_row,
249 | h0c_row, h1c_row,
250 | h0d_row, h1d_row,
251 | h0d_row, h1d_row), axis=0)
252 | cols = torch.tensor(np.copy(cols), dtype=torch.get_default_dtype(),
253 | device=device)
254 | rows = torch.tensor(np.copy(rows), dtype=torch.get_default_dtype(),
255 | device=device)
256 | return cols, rows
257 |
258 |
259 | def quad_afb2d(x, cols, rows, mode='zero', split=True, stride=2):
260 | """ Does a single level 2d wavelet decomposition of an input. Does separate
261 | row and column filtering by two calls to
262 | :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d`
263 |
264 | Inputs:
265 | x (torch.Tensor): Input to decompose
266 |         cols (torch.Tensor): stacked column filters for the four trees,
267 |             in the form returned by
268 |             :py:func:`prep_filt_quad_afb2d`
269 |         rows (torch.Tensor): stacked row filters for the four trees,
270 |             in the form returned by
271 |             :py:func:`prep_filt_quad_afb2d`
272 |         stride (int): decimation factor, normally 2
273 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
274 | padding to use. If periodization, the output size will be half the
275 | input size. Otherwise, the output size will be slightly larger than
276 | half.
277 | """
278 | x = x/2
279 | C = x.shape[1]
280 | cols = torch.cat([cols]*C, dim=0)
281 | rows = torch.cat([rows]*C, dim=0)
282 |
283 | if mode == 'per' or mode == 'periodization':
284 | # Do column filtering
285 | L = cols.shape[2]
286 | L2 = L // 2
287 | if x.shape[2] % 2 == 1:
288 | x = torch.cat((x, x[:,:,-1:]), dim=2)
289 | N2 = x.shape[2] // 2
290 | x = roll(x, -L2, dim=2)
291 | pad = (L-1, 0)
292 | lohi = F.conv2d(x, cols, padding=pad, stride=(stride,1), groups=C)
293 | lohi[:,:,:L2] = lohi[:,:,:L2] + lohi[:,:,N2:N2+L2]
294 | lohi = lohi[:,:,:N2]
295 |
296 | # Do row filtering
297 | L = rows.shape[3]
298 | L2 = L // 2
299 | if lohi.shape[3] % 2 == 1:
300 | lohi = torch.cat((lohi, lohi[:,:,:,-1:]), dim=3)
301 | N2 = x.shape[3] // 2
302 | lohi = roll(lohi, -L2, dim=3)
303 | pad = (0, L-1)
304 | w = F.conv2d(lohi, rows, padding=pad, stride=(1,stride), groups=8*C)
305 | w[:,:,:,:L2] = w[:,:,:,:L2] + w[:,:,:,N2:N2+L2]
306 | w = w[:,:,:,:N2]
307 | elif mode == 'zero':
308 | # Do column filtering
309 | N = x.shape[2]
310 | L = cols.shape[2]
311 | outsize = pywt.dwt_coeff_len(N, L, mode='zero')
312 | p = 2 * (outsize - 1) - N + L
313 |
314 | # Sadly, pytorch only allows for same padding before and after, if
315 | # we need to do more padding after for odd length signals, have to
316 | # prepad
317 | if p % 2 == 1:
318 | x = F.pad(x, (0, 0, 0, 1))
319 | pad = (p//2, 0)
320 | # Calculate the high and lowpass
321 | lohi = F.conv2d(x, cols, padding=pad, stride=(stride,1), groups=C)
322 |
323 | # Do row filtering
324 | N = lohi.shape[3]
325 | L = rows.shape[3]
326 | outsize = pywt.dwt_coeff_len(N, L, mode='zero')
327 | p = 2 * (outsize - 1) - N + L
328 | if p % 2 == 1:
329 | lohi = F.pad(lohi, (0, 1, 0, 0))
330 | pad = (0, p//2)
331 | w = F.conv2d(lohi, rows, padding=pad, stride=(1,stride), groups=8*C)
332 | elif mode == 'symmetric' or mode == 'reflect':
333 | # Do column filtering
334 | N = x.shape[2]
335 | L = cols.shape[2]
336 | outsize = pywt.dwt_coeff_len(N, L, mode=mode)
337 | p = 2 * (outsize - 1) - N + L
338 | x = mypad(x, pad=(0, 0, p//2, (p+1)//2), mode=mode)
339 | lohi = F.conv2d(x, cols, stride=(stride,1), groups=C)
340 |
341 | # Do row filtering
342 | N = lohi.shape[3]
343 | L = rows.shape[3]
344 | outsize = pywt.dwt_coeff_len(N, L, mode=mode)
345 | p = 2 * (outsize - 1) - N + L
346 | lohi = mypad(lohi, pad=(p//2, (p+1)//2, 0, 0), mode=mode)
347 | w = F.conv2d(lohi, rows, stride=(1,stride), groups=8*C)
348 | else:
349 |         raise ValueError("Unknown pad type: {}".format(mode))
350 |
351 | y = w.view((w.shape[0], C, 4, 4, w.shape[-2], w.shape[-1]))
352 | yl = y[:,:,:,0]
353 | yh = y[:,:,:,1:]
354 | deg75r, deg105i = pm(yh[:,:,0,0], yh[:,:,3,0])
355 | deg105r, deg75i = pm(yh[:,:,1,0], yh[:,:,2,0])
356 | deg15r, deg165i = pm(yh[:,:,0,1], yh[:,:,3,1])
357 | deg165r, deg15i = pm(yh[:,:,1,1], yh[:,:,2,1])
358 | deg135r, deg45i = pm(yh[:,:,0,2], yh[:,:,3,2])
359 | deg45r, deg135i = pm(yh[:,:,1,2], yh[:,:,2,2])
360 | yhr = torch.stack((deg15r, deg45r, deg75r, deg105r, deg135r, deg165r), dim=1)
361 | yhi = torch.stack((deg15i, deg45i, deg75i, deg105i, deg135i, deg165i), dim=1)
362 | yh = torch.stack((yhr, yhi), dim=-1)
363 |
364 | yl_rowa = torch.stack((yl[:,:,1], yl[:,:,0]), dim=-1)
365 | yl_rowb = torch.stack((yl[:,:,3], yl[:,:,2]), dim=-1)
366 | yl_rowa = yl_rowa.view(yl.shape[0], C, yl.shape[-2], yl.shape[-1]*2)
367 | yl_rowb = yl_rowb.view(yl.shape[0], C, yl.shape[-2], yl.shape[-1]*2)
368 | z = torch.stack((yl_rowb, yl_rowa), dim=-2)
369 | yl = z.view(yl.shape[0], C, yl.shape[-2]*2, yl.shape[-1]*2)
370 |
371 | return yl.contiguous(), yh
372 |
373 |
374 | def quad_afb2d_nonsep(x, filts, mode='zero'):
375 | """ Does a 1 level 2d wavelet decomposition of an input. Doesn't do separate
376 | row and column filtering.
377 |
378 | Inputs:
379 | x (torch.Tensor): Input to decompose
380 | filts (list or torch.Tensor): If a list is given, should be the low and
381 | highpass filter banks. If a tensor is given, it should be of the
382 | form created by
383 | :py:func:`pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d_nonsep`
384 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
385 | padding to use. If periodization, the output size will be half the
386 | input size. Otherwise, the output size will be slightly larger than
387 | half.
388 | """
389 | C = x.shape[1]
390 | Ny = x.shape[2]
391 | Nx = x.shape[3]
392 |
393 | # Check the filter inputs
394 | f = torch.cat([filts]*C, dim=0)
395 | Ly = f.shape[2]
396 | Lx = f.shape[3]
397 |
398 | if mode == 'periodization' or mode == 'per':
399 | if x.shape[2] % 2 == 1:
400 | x = torch.cat((x, x[:,:,-1:]), dim=2)
401 | Ny += 1
402 | if x.shape[3] % 2 == 1:
403 | x = torch.cat((x, x[:,:,:,-1:]), dim=3)
404 | Nx += 1
405 | pad = (Ly-1, Lx-1)
406 | stride = (2, 2)
407 | x = roll(roll(x, -Ly//2, dim=2), -Lx//2, dim=3)
408 | y = F.conv2d(x, f, padding=pad, stride=stride, groups=C)
409 | y[:,:,:Ly//2] += y[:,:,Ny//2:Ny//2+Ly//2]
410 | y[:,:,:,:Lx//2] += y[:,:,:,Nx//2:Nx//2+Lx//2]
411 | y = y[:,:,:Ny//2, :Nx//2]
412 | elif mode == 'zero' or mode == 'symmetric' or mode == 'reflect':
413 | # Calculate the pad size
414 | out1 = pywt.dwt_coeff_len(Ny, Ly, mode=mode)
415 | out2 = pywt.dwt_coeff_len(Nx, Lx, mode=mode)
416 | p1 = 2 * (out1 - 1) - Ny + Ly
417 | p2 = 2 * (out2 - 1) - Nx + Lx
418 | if mode == 'zero':
419 | # Sadly, pytorch only allows for same padding before and after, if
420 | # we need to do more padding after for odd length signals, have to
421 | # prepad
422 | if p1 % 2 == 1 and p2 % 2 == 1:
423 | x = F.pad(x, (0, 1, 0, 1))
424 | elif p1 % 2 == 1:
425 | x = F.pad(x, (0, 0, 0, 1))
426 | elif p2 % 2 == 1:
427 | x = F.pad(x, (0, 1, 0, 0))
428 | # Calculate the high and lowpass
429 | y = F.conv2d(
430 | x, f, padding=(p1//2, p2//2), stride=2, groups=C)
431 | elif mode == 'symmetric' or mode == 'reflect':
432 | pad = (p2//2, (p2+1)//2, p1//2, (p1+1)//2)
433 | x = mypad(x, pad=pad, mode=mode)
434 | y = F.conv2d(x, f, stride=2, groups=C)
435 | else:
436 |             raise ValueError("Unknown pad type: {}".format(mode))
437 |
438 | y = y.reshape((y.shape[0], C, 4, y.shape[-2], y.shape[-1]))
439 | yl = y[:,:,0].contiguous()
440 | yh = y[:,:,1:].contiguous()
441 | return yl, yh
442 |
443 |
444 | def cplxdual2D(x, J, level1='farras', qshift='qshift_a', mode='periodization',
445 | mag=False):
446 | """ Do a complex dtcwt
447 |
448 | Returns:
449 | lows: lowpass outputs from each of the 4 trees. Is a 2x2 list of lists
450 | w: bandpass outputs from each of the 4 trees. Is a list of lists, with
451 | shape [J][2][2][3]. Initially the 3 outputs are the lh, hl and hh from
452 | each of the 4 trees. After doing sums and differences though, they
453 | become the real and imaginary parts for the 6 orientations. In
454 | particular:
455 | first index - indexes over scales
456 | second index - 0 = real, 1 = imaginary
457 | third and fourth indices:
458 | 0,1 = 15 degrees
459 | 1,2 = 45 degrees
460 | 0,0 = 75 degrees
461 | 1,0 = 105 degrees
462 | 0,2 = 135 degrees
463 | 1,1 = 165 degrees
464 | """
465 | x = x/2
466 | # Get the filters
467 | h0a1, h0b1, _, _, h1a1, h1b1, _, _ = _level1(level1)
468 | h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
469 |
470 | Faf = ((prep_filt_afb2d(h0a1, h1a1, h0a1, h1a1, device=x.device),
471 | prep_filt_afb2d(h0a1, h1a1, h0b1, h1b1, device=x.device)),
472 | (prep_filt_afb2d(h0b1, h1b1, h0a1, h1a1, device=x.device),
473 | prep_filt_afb2d(h0b1, h1b1, h0b1, h1b1, device=x.device)))
474 | af = ((prep_filt_afb2d(h0a, h1a, h0a, h1a, device=x.device),
475 | prep_filt_afb2d(h0a, h1a, h0b, h1b, device=x.device)),
476 | (prep_filt_afb2d(h0b, h1b, h0a, h1a, device=x.device),
477 | prep_filt_afb2d(h0b, h1b, h0b, h1b, device=x.device)))
478 |
479 | # Do 4 fully decimated dwts
480 | w = [[[None for _ in range(2)] for _ in range(2)] for j in range(J)]
481 | lows = [[None for _ in range(2)] for _ in range(2)]
482 | for m in range(2):
483 | for n in range(2):
484 | # Do the first level transform with the first level filters
485 | # ll, bands = afb2d(x, (Faf[m][0], Faf[m][1], Faf[n][0], Faf[n][1]), mode=mode)
486 | bands = afb2d(x, Faf[m][n], mode=mode)
487 | # Separate the low and bandpasses
488 | s = bands.shape
489 | bands = bands.reshape(s[0], -1, 4, s[-2], s[-1])
490 | ll = bands[:,:,0].contiguous()
491 | w[0][m][n] = [bands[:,:,2], bands[:,:,1], bands[:,:,3]]
492 |
493 | # Do the second+ level transform with the second level filters
494 | for j in range(1,J):
495 | # ll, bands = afb2d(ll, (af[m][0], af[m][1], af[n][0], af[n][1]), mode=mode)
496 | bands = afb2d(ll, af[m][n], mode=mode)
497 | # Separate the low and bandpasses
498 | s = bands.shape
499 | bands = bands.reshape(s[0], -1, 4, s[-2], s[-1])
500 | ll = bands[:,:,0].contiguous()
501 | w[j][m][n] = [bands[:,:,2], bands[:,:,1], bands[:,:,3]]
502 | lows[m][n] = ll
503 |
504 | # Convert the quads into real and imaginary parts
505 | yh = [None,] * J
506 | for j in range(J):
507 | deg75r, deg105i = pm(w[j][0][0][0], w[j][1][1][0])
508 | deg105r, deg75i = pm(w[j][0][1][0], w[j][1][0][0])
509 | deg15r, deg165i = pm(w[j][0][0][1], w[j][1][1][1])
510 | deg165r, deg15i = pm(w[j][0][1][1], w[j][1][0][1])
511 | deg135r, deg45i = pm(w[j][0][0][2], w[j][1][1][2])
512 | deg45r, deg135i = pm(w[j][0][1][2], w[j][1][0][2])
513 | yhr = torch.stack((deg15r, deg45r, deg75r, deg105r, deg135r, deg165r), dim=1)
514 | yhi = torch.stack((deg15i, deg45i, deg75i, deg105i, deg135i, deg165i), dim=1)
515 | if mag:
516 | yh[j] = torch.sqrt(yhr**2 + yhi**2 + 0.01) - np.sqrt(0.01)
517 | else:
518 | yh[j] = torch.stack((yhr, yhi), dim=-1)
519 |
520 | return lows, yh
521 |
522 |
523 | def icplxdual2D(yl, yh, level1='farras', qshift='qshift_a', mode='periodization'):
524 | # Get the filters
525 | _, _, g0a1, g0b1, _, _, g1a1, g1b1 = _level1(level1)
526 | _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
527 |
528 | dev = yl[0][0].device
529 | Faf = ((prep_filt_sfb2d(g0a1, g1a1, g0a1, g1a1, device=dev),
530 | prep_filt_sfb2d(g0a1, g1a1, g0b1, g1b1, device=dev)),
531 | (prep_filt_sfb2d(g0b1, g1b1, g0a1, g1a1, device=dev),
532 | prep_filt_sfb2d(g0b1, g1b1, g0b1, g1b1, device=dev)))
533 | af = ((prep_filt_sfb2d(g0a, g1a, g0a, g1a, device=dev),
534 | prep_filt_sfb2d(g0a, g1a, g0b, g1b, device=dev)),
535 | (prep_filt_sfb2d(g0b, g1b, g0a, g1a, device=dev),
536 | prep_filt_sfb2d(g0b, g1b, g0b, g1b, device=dev)))
537 |
538 | # Convert the highs back to subbands
539 | J = len(yh)
540 | w = [[[[None for i in range(3)] for j in range(2)] for k in range(2)] for l in range(J)]
541 | for j in range(J):
542 | w[j][0][0][0], w[j][1][1][0] = pm(yh[j][:,2,:,:,:,0],
543 | yh[j][:,3,:,:,:,1])
544 | w[j][0][1][0], w[j][1][0][0] = pm(yh[j][:,3,:,:,:,0],
545 | yh[j][:,2,:,:,:,1])
546 | w[j][0][0][1], w[j][1][1][1] = pm(yh[j][:,0,:,:,:,0],
547 | yh[j][:,5,:,:,:,1])
548 | w[j][0][1][1], w[j][1][0][1] = pm(yh[j][:,5,:,:,:,0],
549 | yh[j][:,0,:,:,:,1])
550 | w[j][0][0][2], w[j][1][1][2] = pm(yh[j][:,1,:,:,:,0],
551 | yh[j][:,4,:,:,:,1])
552 | w[j][0][1][2], w[j][1][0][2] = pm(yh[j][:,4,:,:,:,0],
553 | yh[j][:,1,:,:,:,1])
554 | w[j][0][0] = torch.stack(w[j][0][0], dim=2)
555 | w[j][0][1] = torch.stack(w[j][0][1], dim=2)
556 | w[j][1][0] = torch.stack(w[j][1][0], dim=2)
557 | w[j][1][1] = torch.stack(w[j][1][1], dim=2)
558 |
559 | y = None
560 | for m in range(2):
561 | for n in range(2):
562 | lo = yl[m][n]
563 | for j in range(J-1, 0, -1):
564 | lo = sfb2d(lo, w[j][m][n], af[m][n], mode=mode)
565 | lo = sfb2d(lo, w[0][m][n], Faf[m][n], mode=mode)
566 |
567 | # Add to the output
568 | if y is None:
569 | y = lo
570 | else:
571 | y = y + lo
572 |
573 | # Normalize
574 | y = y/2
575 | return y
576 |
577 |
578 | def pm(a, b):
579 | u = (a + b)/np.sqrt(2)
580 | v = (a - b)/np.sqrt(2)
581 | return u, v
582 |
--------------------------------------------------------------------------------
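The pm helper above is the orthonormal sum/difference used throughout this
file to pair the four DWT trees into complex coefficients. It is its own
inverse, which is what lets icplxdual2D undo cplxdual2D band by band; a quick
numerical check (assuming torch and this module import cleanly):

    import torch
    from pytorch_wavelets.dtcwt.lowlevel2 import pm

    a, b = torch.randn(2, 8, 8), torch.randn(2, 8, 8)
    u, v = pm(a, b)        # u = (a+b)/sqrt(2), v = (a-b)/sqrt(2)
    a2, b2 = pm(u, v)      # applying pm twice recovers the originals
    print(torch.allclose(a2, a, atol=1e-6), torch.allclose(b2, b, atol=1e-6))

--------------------------------------------------------------------------------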
/pytorch_wavelets/dtcwt/transform2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from numpy import ndarray, sqrt
4 |
5 | from pytorch_wavelets.dtcwt.coeffs import qshift as _qshift, biort as _biort, level1
6 | from pytorch_wavelets.dtcwt.lowlevel import prep_filt
7 | from pytorch_wavelets.dtcwt.transform_funcs import FWD_J1, FWD_J2PLUS
8 | from pytorch_wavelets.dtcwt.transform_funcs import INV_J1, INV_J2PLUS
9 | from pytorch_wavelets.dtcwt.transform_funcs import get_dimensions6
10 | from pytorch_wavelets.dwt.lowlevel import mode_to_int
11 | from pytorch_wavelets.dwt.transform2d import DWTForward, DWTInverse
12 |
13 |
14 | def pm(a, b):
15 | u = (a + b)/sqrt(2)
16 | v = (a - b)/sqrt(2)
17 | return u, v
18 |
19 |
20 | class DTCWTForward(nn.Module):
21 | """ Performs a 2d DTCWT Forward decomposition of an image
22 |
23 | Args:
24 | biort (str): One of 'antonini', 'legall', 'near_sym_a', 'near_sym_b'.
25 | Specifies the first level biorthogonal wavelet filters. Can also
26 | give a two tuple for the low and highpass filters directly.
27 | qshift (str): One of 'qshift_06', 'qshift_a', 'qshift_b', 'qshift_c',
28 | 'qshift_d'. Specifies the second level quarter shift filters. Can
29 | also give a 4-tuple for the low tree a, low tree b, high tree a and
30 | high tree b filters directly.
31 | J (int): Number of levels of decomposition
32 | skip_hps (bools): List of bools of length J which specify whether or
33 | not to calculate the bandpass outputs at the given scale.
34 | skip_hps[0] is for the first scale. Can be a single bool in which
35 | case that is applied to all scales.
36 |         include_scale (bool): If true, return the lowpass outputs at every
37 |             scale. Can also be a list of length J specifying which lowpasses
38 |             to return. I.e. if [False, True, True], the forward call will
39 |             return the second and third lowpass outputs, but discard the
40 |             lowpass from the first level transform.
41 | o_dim (int): Which dimension to put the orientations in
42 | ri_dim (int): which dimension to put the real and imaginary parts
43 | """
44 | def __init__(self, biort='near_sym_a', qshift='qshift_a',
45 | J=3, skip_hps=False, include_scale=False,
46 | o_dim=2, ri_dim=-1, mode='symmetric'):
47 | super().__init__()
48 | if o_dim == ri_dim:
49 | raise ValueError("Orientations and real/imaginary parts must be "
50 | "in different dimensions.")
51 |
52 | self.biort = biort
53 | self.qshift = qshift
54 | self.J = J
55 | self.o_dim = o_dim
56 | self.ri_dim = ri_dim
57 | self.mode = mode
58 | if isinstance(biort, str):
59 | h0o, _, h1o, _ = _biort(biort)
60 | self.register_buffer('h0o', prep_filt(h0o, 1))
61 | self.register_buffer('h1o', prep_filt(h1o, 1))
62 | else:
63 | self.register_buffer('h0o', prep_filt(biort[0], 1))
64 | self.register_buffer('h1o', prep_filt(biort[1], 1))
65 | if isinstance(qshift, str):
66 | h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
67 | self.register_buffer('h0a', prep_filt(h0a, 1))
68 | self.register_buffer('h0b', prep_filt(h0b, 1))
69 | self.register_buffer('h1a', prep_filt(h1a, 1))
70 | self.register_buffer('h1b', prep_filt(h1b, 1))
71 | else:
72 | self.register_buffer('h0a', prep_filt(qshift[0], 1))
73 | self.register_buffer('h0b', prep_filt(qshift[1], 1))
74 | self.register_buffer('h1a', prep_filt(qshift[2], 1))
75 | self.register_buffer('h1b', prep_filt(qshift[3], 1))
76 |
77 | # Get the function to do the DTCWT
78 | if isinstance(skip_hps, (list, tuple, ndarray)):
79 | self.skip_hps = skip_hps
80 | else:
81 | self.skip_hps = [skip_hps,] * self.J
82 | if isinstance(include_scale, (list, tuple, ndarray)):
83 | self.include_scale = include_scale
84 | else:
85 | self.include_scale = [include_scale,] * self.J
86 |
87 | def forward(self, x):
88 | """ Forward Dual Tree Complex Wavelet Transform
89 |
90 | Args:
91 | x (tensor): Input to transform. Should be of shape
92 | :math:`(N, C_{in}, H_{in}, W_{in})`.
93 |
94 | Returns:
95 | (yl, yh)
96 | tuple of lowpass (yl) and bandpass (yh) coefficients.
97 | If include_scale was true, yl will be a list of lowpass
98 | coefficients, otherwise will be just the final lowpass
99 | coefficient of shape :math:`(N, C_{in}, H_{in}', W_{in}')`. Yh
100 | will be a list of the complex bandpass coefficients of shape
101 | :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar
102 | shape depending on o_dim and ri_dim
103 |
104 | Note:
105 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a
106 | DTCWT pyramid.
107 | """
108 | scales = [x.new_zeros([]),] * self.J
109 | highs = [x.new_zeros([]),] * self.J
110 | mode = mode_to_int(self.mode)
111 | if self.J == 0:
112 | return x, None
113 |
114 | # If the row/col count of X is not divisible by 2 then we need to
115 | # extend X
116 | r, c = x.shape[2:]
117 | if r % 2 != 0:
118 | x = torch.cat((x, x[:,:,-1:]), dim=2)
119 | if c % 2 != 0:
120 | x = torch.cat((x, x[:,:,:,-1:]), dim=3)
121 |
122 | # Do the level 1 transform
123 | low, h = FWD_J1.apply(x, self.h0o, self.h1o, self.skip_hps[0],
124 | self.o_dim, self.ri_dim, mode)
125 | highs[0] = h
126 | if self.include_scale[0]:
127 | scales[0] = low
128 |
129 | for j in range(1, self.J):
130 | # Ensure the lowpass is divisible by 4
131 | r, c = low.shape[2:]
132 | if r % 4 != 0:
133 | low = torch.cat((low[:,:,0:1], low, low[:,:,-1:]), dim=2)
134 | if c % 4 != 0:
135 | low = torch.cat((low[:,:,:,0:1], low, low[:,:,:,-1:]), dim=3)
136 |
137 | low, h = FWD_J2PLUS.apply(low, self.h0a, self.h1a, self.h0b,
138 | self.h1b, self.skip_hps[j], self.o_dim,
139 | self.ri_dim, mode)
140 | highs[j] = h
141 | if self.include_scale[j]:
142 | scales[j] = low
143 |
144 | if True in self.include_scale:
145 | return scales, highs
146 | else:
147 | return low, highs
148 |
149 |
150 | class DTCWTInverse(nn.Module):
151 | """ 2d DTCWT Inverse
152 |
153 | Args:
154 | biort (str): One of 'antonini', 'legall', 'near_sym_a', 'near_sym_b'.
155 | Specifies the first level biorthogonal wavelet filters. Can also
156 | give a two tuple for the low and highpass filters directly.
157 | qshift (str): One of 'qshift_06', 'qshift_a', 'qshift_b', 'qshift_c',
158 | 'qshift_d'. Specifies the second level quarter shift filters. Can
159 | also give a 4-tuple for the low tree a, low tree b, high tree a and
160 | high tree b filters directly.
161 |         o_dim (int): which dimension the orientations are in
162 |         ri_dim (int): which dimension to put the real and imaginary parts in
163 |             (the number of levels J is inferred from the bandpass input)
164 | """
165 |
166 | def __init__(self, biort='near_sym_a', qshift='qshift_a', o_dim=2,
167 | ri_dim=-1, mode='symmetric'):
168 | super().__init__()
169 | self.biort = biort
170 | self.qshift = qshift
171 | self.o_dim = o_dim
172 | self.ri_dim = ri_dim
173 | self.mode = mode
174 | if isinstance(biort, str):
175 | _, g0o, _, g1o = _biort(biort)
176 | self.register_buffer('g0o', prep_filt(g0o, 1))
177 | self.register_buffer('g1o', prep_filt(g1o, 1))
178 | else:
179 | self.register_buffer('g0o', prep_filt(biort[0], 1))
180 | self.register_buffer('g1o', prep_filt(biort[1], 1))
181 | if isinstance(qshift, str):
182 | _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
183 | self.register_buffer('g0a', prep_filt(g0a, 1))
184 | self.register_buffer('g0b', prep_filt(g0b, 1))
185 | self.register_buffer('g1a', prep_filt(g1a, 1))
186 | self.register_buffer('g1b', prep_filt(g1b, 1))
187 | else:
188 | self.register_buffer('g0a', prep_filt(qshift[0], 1))
189 | self.register_buffer('g0b', prep_filt(qshift[1], 1))
190 | self.register_buffer('g1a', prep_filt(qshift[2], 1))
191 | self.register_buffer('g1b', prep_filt(qshift[3], 1))
192 |
193 | def forward(self, coeffs):
194 | """
195 | Args:
196 | coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
197 | yl is a tensor of shape :math:`(N, C_{in}, H_{in}', W_{in}')`
198 | and yh is a list of the complex bandpass coefficients of shape
199 | :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar
200 | depending on o_dim and ri_dim
201 |
202 | Returns:
203 | Reconstructed output
204 |
205 | Note:
206 |             Can accept None or an empty tensor (torch.tensor([])) for the
207 |             lowpass or bandpass inputs. In these cases, an array of zeros
208 |             replaces that input.
209 |
210 | Note:
211 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a
212 | DTCWT pyramid.
213 |
214 | Note:
215 | If include_scale was true for the forward pass, you should provide
216 | only the final lowpass output here, as normal for an inverse wavelet
217 | transform.
218 | """
219 | low, highs = coeffs
220 | J = len(highs)
221 | mode = mode_to_int(self.mode)
222 | _, _, h_dim, w_dim = get_dimensions6(
223 | self.o_dim, self.ri_dim)
224 | for j, s in zip(range(J-1, 0, -1), highs[1:][::-1]):
225 | if s is not None and s.shape != torch.Size([]):
226 | assert s.shape[self.o_dim] == 6, "Inverse transform must " \
227 | "have input with 6 orientations"
228 | assert len(s.shape) == 6, "Bandpass inputs must have " \
229 | "6 dimensions"
230 | assert s.shape[self.ri_dim] == 2, "Inputs must be complex " \
231 | "with real and imaginary parts in the ri dimension"
232 | # Ensure the low and highpass are the right size
233 | r, c = low.shape[2:]
234 | r1, c1 = s.shape[h_dim], s.shape[w_dim]
235 | if r != r1 * 2:
236 | low = low[:,:,1:-1]
237 | if c != c1 * 2:
238 | low = low[:,:,:,1:-1]
239 |
240 | low = INV_J2PLUS.apply(low, s, self.g0a, self.g1a, self.g0b,
241 | self.g1b, self.o_dim, self.ri_dim, mode)
242 |
243 | # Ensure the low and highpass are the right size
244 | if highs[0] is not None and highs[0].shape != torch.Size([]):
245 | r, c = low.shape[2:]
246 | r1, c1 = highs[0].shape[h_dim], highs[0].shape[w_dim]
247 | if r != r1 * 2:
248 | low = low[:,:,1:-1]
249 | if c != c1 * 2:
250 | low = low[:,:,:,1:-1]
251 |
252 | low = INV_J1.apply(low, highs[0], self.g0o, self.g1o, self.o_dim,
253 | self.ri_dim, mode)
254 | return low
255 |
256 |
257 |
--------------------------------------------------------------------------------
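A usage sketch for the two modules above (assuming the package exposes them at
the top level, as the upstream pytorch_wavelets README does). A forward/inverse
round trip reconstructs the input to within numerical precision, and with the
default o_dim=2, ri_dim=-1 each bandpass tensor holds 6 orientations with the
real and imaginary parts last:

    import torch
    from pytorch_wavelets import DTCWTForward, DTCWTInverse

    xfm = DTCWTForward(J=3, biort='near_sym_a', qshift='qshift_a')
    ifm = DTCWTInverse(biort='near_sym_a', qshift='qshift_a')

    x = torch.randn(1, 1, 64, 64)
    yl, yh = xfm(x)
    print(yl.shape)     # final lowpass, (1, 1, 16, 16) for this input
    print(yh[0].shape)  # (1, 1, 6, 32, 32, 2)
    x_hat = ifm((yl, yh))
    print((x - x_hat).abs().max())  # small reconstruction error

--------------------------------------------------------------------------------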
/pytorch_wavelets/dtcwt/transform_funcs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import tensor
3 | from torch.autograd import Function
4 | from pytorch_wavelets.dtcwt.lowlevel import colfilter, rowfilter
5 | from pytorch_wavelets.dtcwt.lowlevel import coldfilt, rowdfilt
6 | from pytorch_wavelets.dtcwt.lowlevel import colifilt, rowifilt, q2c, c2q
7 | from pytorch_wavelets.dwt.lowlevel import int_to_mode
8 |
9 |
10 | def get_dimensions5(o_dim, ri_dim):
11 | """ Get the orientation, height and width dimensions after the real and
12 | imaginary parts have been popped off (5 dimensional tensor)."""
13 | o_dim = (o_dim % 6)
14 | ri_dim = (ri_dim % 6)
15 |
16 | if ri_dim < o_dim:
17 | o_dim -= 1
18 |
19 | if o_dim == 4:
20 | h_dim = 2
21 | w_dim = 3
22 | elif o_dim == 3:
23 | h_dim = 2
24 | w_dim = 4
25 | else:
26 | h_dim = 3
27 | w_dim = 4
28 |
29 | return o_dim, ri_dim, h_dim, w_dim
30 |
31 |
32 | def get_dimensions6(o_dim, ri_dim):
33 | """ Get the orientation, real/imag, height and width dimensions
34 | for the full tensor (6 dimensions)."""
35 | # Calculate which dimension to put the real and imaginary parts and the
36 | # orientations. Also work out where the rows and columns in the original
37 | # image were
38 | o_dim = (o_dim % 6)
39 | ri_dim = (ri_dim % 6)
40 |
41 | if ri_dim < o_dim:
42 | o_dim -= 1
43 |
44 | if o_dim >= 3 and ri_dim >= 3:
45 | h_dim = 2
46 | elif o_dim >= 4 or ri_dim >= 4:
47 | h_dim = 3
48 | else:
49 | h_dim = 4
50 |
51 | if o_dim >= 4 and ri_dim >= 4:
52 | w_dim = 3
53 | elif o_dim >= 4 or ri_dim >= 4:
54 | w_dim = 4
55 | else:
56 | w_dim = 5
57 |
58 | return o_dim, ri_dim, h_dim, w_dim
59 |
60 |
61 | def highs_to_orientations(lh, hl, hh, o_dim):
62 | (deg15r, deg15i), (deg165r, deg165i) = q2c(lh)
63 | (deg45r, deg45i), (deg135r, deg135i) = q2c(hh)
64 | (deg75r, deg75i), (deg105r, deg105i) = q2c(hl)
65 |
66 | # Convert real and imaginary to magnitude
67 | reals = torch.stack(
68 | [deg15r, deg45r, deg75r, deg105r, deg135r, deg165r], dim=o_dim)
69 | imags = torch.stack(
70 | [deg15i, deg45i, deg75i, deg105i, deg135i, deg165i], dim=o_dim)
71 |
72 | return reals, imags
73 |
74 |
75 | def orientations_to_highs(reals, imags, o_dim):
76 | dev = reals.device
77 | horiz = torch.index_select(reals, o_dim, tensor([0, 5], device=dev))
78 | diag = torch.index_select(reals, o_dim, tensor([1, 4], device=dev))
79 | vertic = torch.index_select(reals, o_dim, tensor([2, 3], device=dev))
80 | deg15r, deg165r = torch.unbind(horiz, dim=o_dim)
81 | deg45r, deg135r = torch.unbind(diag, dim=o_dim)
82 | deg75r, deg105r = torch.unbind(vertic, dim=o_dim)
83 | dev = imags.device
84 | horiz = torch.index_select(imags, o_dim, tensor([0, 5], device=dev))
85 | diag = torch.index_select(imags, o_dim, tensor([1, 4], device=dev))
86 | vertic = torch.index_select(imags, o_dim, tensor([2, 3], device=dev))
87 | deg15i, deg165i = torch.unbind(horiz, dim=o_dim)
88 | deg45i, deg135i = torch.unbind(diag, dim=o_dim)
89 | deg75i, deg105i = torch.unbind(vertic, dim=o_dim)
90 |
91 | lh = c2q((deg15r, deg15i), (deg165r, deg165i))
92 | hl = c2q((deg75r, deg75i), (deg105r, deg105i))
93 | hh = c2q((deg45r, deg45i), (deg135r, deg135i))
94 |
95 | return lh, hl, hh
96 |
97 |
98 | def fwd_j1(x, h0, h1, skip_hps, o_dim, mode):
99 | """ Level 1 forward dtcwt.
100 |
101 |     Kept as a separate function since it is used by both
102 | the forward pass of the forward transform and the backward pass of the
103 | inverse transform.
104 | """
105 | # Level 1 forward (biorthogonal analysis filters)
106 | if not skip_hps:
107 | lo = rowfilter(x, h0, mode)
108 | hi = rowfilter(x, h1, mode)
109 | ll = colfilter(lo, h0, mode)
110 | lh = colfilter(lo, h1, mode)
111 | del lo
112 | hl = colfilter(hi, h0, mode)
113 | hh = colfilter(hi, h1, mode)
114 | del hi
115 | highr, highi = highs_to_orientations(lh, hl, hh, o_dim)
116 | else:
117 | ll = rowfilter(x, h0, mode)
118 | ll = colfilter(ll, h0, mode)
119 | highr = x.new_zeros([])
120 | highi = x.new_zeros([])
121 | return ll, highr, highi
122 |
123 |
124 | def fwd_j1_rot(x, h0, h1, h2, skip_hps, o_dim, mode):
125 | """ Level 1 forward dtcwt.
126 |
127 |     Kept as a separate function since it is used by both
128 | the forward pass of the forward transform and the backward pass of the
129 | inverse transform.
130 | """
131 | # Level 1 forward (biorthogonal analysis filters)
132 | if not skip_hps:
133 | lo = rowfilter(x, h0, mode)
134 | hi = rowfilter(x, h1, mode)
135 | ba = rowfilter(x, h2, mode)
136 |
137 | lh = colfilter(lo, h1, mode)
138 | hl = colfilter(hi, h0, mode)
139 | hh = colfilter(ba, h2, mode)
140 | ll = colfilter(lo, h0, mode)
141 |
142 | del lo, hi, ba
143 | highr, highi = highs_to_orientations(lh, hl, hh, o_dim)
144 | else:
145 | ll = rowfilter(x, h0, mode)
146 | ll = colfilter(ll, h0, mode)
147 | highr = x.new_zeros([])
148 | highi = x.new_zeros([])
149 | return ll, highr, highi
150 |
151 |
152 | def inv_j1(ll, highr, highi, g0, g1, o_dim, h_dim, w_dim, mode):
153 | """ Level1 inverse dtcwt.
154 |
155 |     Kept as a separate function since it is used by both the forward pass of the
156 | inverse transform and the backward pass of the forward transform.
157 | """
158 | if highr is None or highr.shape == torch.Size([]):
159 | y = rowfilter(colfilter(ll, g0), g0)
160 | else:
161 | # Get the double sampled bandpass coefficients
162 | lh, hl, hh = orientations_to_highs(highr, highi, o_dim)
163 |
164 | if ll is None or ll.shape == torch.Size([]):
165 | # Interpolate
166 | hi = colfilter(hh, g1, mode) + colfilter(hl, g0, mode)
167 | lo = colfilter(lh, g1, mode)
168 | del lh, hh, hl
169 | else:
170 | # Possibly cut back some rows to make the ll match the highs
171 | r, c = ll.shape[2:]
172 | r1, c1 = highr.shape[h_dim], highr.shape[w_dim]
173 | if r != r1 * 2:
174 | ll = ll[:,:,1:-1]
175 | if c != c1 * 2:
176 | ll = ll[:,:,:,1:-1]
177 | # Interpolate
178 | hi = colfilter(hh, g1, mode) + colfilter(hl, g0, mode)
179 | lo = colfilter(lh, g1, mode) + colfilter(ll, g0, mode)
180 | del lh, hl, hh
181 |
182 | y = rowfilter(hi, g1, mode) + rowfilter(lo, g0, mode)
183 |
184 | return y
185 |
186 |
187 | def inv_j1_rot(ll, highr, highi, g0, g1, g2, o_dim, h_dim, w_dim, mode):
188 | """ Level1 inverse dtcwt.
189 |
190 |     Kept as a separate function since it is used by both the forward pass of the
191 | inverse transform and the backward pass of the forward transform.
192 | """
193 | if highr is None or highr.shape == torch.Size([]):
194 | y = rowfilter(colfilter(ll, g0), g0)
195 | else:
196 | # Get the double sampled bandpass coefficients
197 | lh, hl, hh = orientations_to_highs(highr, highi, o_dim)
198 |
199 | if ll is None or ll.shape == torch.Size([]):
200 | # Interpolate
201 | lo = colfilter(lh, g1, mode)
202 | hi = colfilter(hl, g0, mode)
203 | ba = colfilter(hh, g2, mode)
204 | del lh, hh, hl
205 | else:
206 | # Possibly cut back some rows to make the ll match the highs
207 | r, c = ll.shape[2:]
208 | r1, c1 = highr.shape[h_dim], highr.shape[w_dim]
209 | if r != r1 * 2:
210 | ll = ll[:,:,1:-1]
211 | if c != c1 * 2:
212 | ll = ll[:,:,:,1:-1]
213 |
214 | # Interpolate
215 | lo = colfilter(lh, g1, mode) + colfilter(ll, g0, mode)
216 | hi = colfilter(hl, g0, mode)
217 | ba = colfilter(hh, g2, mode)
218 | del lh, hl, hh
219 |
220 | y = rowfilter(hi, g1, mode) + rowfilter(lo, g0, mode) + \
221 | rowfilter(ba, g2, mode)
222 |
223 | return y
224 |
225 |
226 | def fwd_j2plus(x, h0a, h1a, h0b, h1b, skip_hps, o_dim, mode):
227 | """ Level 2 plus forward dtcwt.
228 |
229 |     Kept as a separate function since it is used by both
230 | the forward pass of the forward transform and the backward pass of the
231 | inverse transform.
232 | """
233 | if not skip_hps:
234 | lo = rowdfilt(x, h0b, h0a, False, mode)
235 | hi = rowdfilt(x, h1b, h1a, True, mode)
236 |
237 | ll = coldfilt(lo, h0b, h0a, False, mode)
238 | lh = coldfilt(lo, h1b, h1a, True, mode)
239 | hl = coldfilt(hi, h0b, h0a, False, mode)
240 | hh = coldfilt(hi, h1b, h1a, True, mode)
241 | del lo, hi
242 | highr, highi = highs_to_orientations(lh, hl, hh, o_dim)
243 | else:
244 | ll = rowdfilt(x, h0b, h0a, False, mode)
245 | ll = coldfilt(ll, h0b, h0a, False, mode)
246 | highr = None
247 | highi = None
248 |
249 | return ll, highr, highi
250 |
251 |
252 | def fwd_j2plus_rot(x, h0a, h1a, h0b, h1b, h2a, h2b, skip_hps, o_dim, mode):
253 | """ Level 2 plus forward dtcwt.
254 |
255 |     Kept as a separate function since it is used by both
256 | the forward pass of the forward transform and the backward pass of the
257 | inverse transform.
258 | """
259 | if not skip_hps:
260 | lo = rowdfilt(x, h0b, h0a, False, mode)
261 | hi = rowdfilt(x, h1b, h1a, True, mode)
262 | ba = rowdfilt(x, h2b, h2a, True, mode)
263 |
264 | lh = coldfilt(lo, h1b, h1a, True, mode)
265 | hl = coldfilt(hi, h0b, h0a, False, mode)
266 | hh = coldfilt(ba, h2b, h2a, True, mode)
267 | ll = coldfilt(lo, h0b, h0a, False, mode)
268 | del lo, hi, ba
269 | highr, highi = highs_to_orientations(lh, hl, hh, o_dim)
270 | else:
271 | ll = rowdfilt(x, h0b, h0a, False, mode)
272 | ll = coldfilt(ll, h0b, h0a, False, mode)
273 | highr = None
274 | highi = None
275 |
276 | return ll, highr, highi
277 |
278 |
279 | def inv_j2plus(ll, highr, highi, g0a, g1a, g0b, g1b, o_dim, h_dim, w_dim, mode):
280 | """ Level2+ inverse dtcwt.
281 |
282 |     Kept as a separate function since it is used by both the forward pass of the
283 | inverse transform and the backward pass of the forward transform.
284 | """
285 | if highr is None or highr.shape == torch.Size([]):
286 | y = rowifilt(colifilt(ll, g0b, g0a, False, mode), g0b, g0a, False, mode)
287 | else:
288 | # Get the double sampled bandpass coefficients
289 | lh, hl, hh = orientations_to_highs(highr, highi, o_dim)
290 |
291 | if ll is None or ll.shape == torch.Size([]):
292 | # Interpolate
293 | hi = colifilt(hh, g1b, g1a, True, mode) + \
294 | colifilt(hl, g0b, g0a, False, mode)
295 | lo = colifilt(lh, g1b, g1a, True, mode)
296 | del lh, hh, hl
297 | else:
298 | # Interpolate
299 | hi = colifilt(hh, g1b, g1a, True, mode) + \
300 | colifilt(hl, g0b, g0a, False, mode)
301 | lo = colifilt(lh, g1b, g1a, True, mode) + \
302 | colifilt(ll, g0b, g0a, False, mode)
303 | del lh, hl, hh
304 |
305 | y = rowifilt(hi, g1b, g1a, True, mode) + \
306 | rowifilt(lo, g0b, g0a, False, mode)
307 | return y
308 |
309 |
310 | def inv_j2plus_rot(ll, highr, highi, g0a, g1a, g0b, g1b, g2a, g2b,
311 | o_dim, h_dim, w_dim, mode):
312 | """ Level2+ inverse dtcwt.
313 |
314 |     Kept as a separate function since it is used by both the forward pass of the
315 | inverse transform and the backward pass of the forward transform.
316 | """
317 | if highr is None or highr.shape == torch.Size([]):
318 | y = rowifilt(colifilt(ll, g0b, g0a, False, mode), g0b, g0a, False, mode)
319 | else:
320 | # Get the double sampled bandpass coefficients
321 | lh, hl, hh = orientations_to_highs(highr, highi, o_dim)
322 |
323 | if ll is None or ll.shape == torch.Size([]):
324 | # Interpolate
325 | lo = colifilt(lh, g1b, g1a, True, mode)
326 | hi = colifilt(hl, g0b, g0a, False, mode)
327 | ba = colifilt(hh, g2b, g2a, True, mode)
328 | del lh, hh, hl
329 | else:
330 | # Interpolate
331 | lo = colifilt(lh, g1b, g1a, True, mode) + \
332 | colifilt(ll, g0b, g0a, False, mode)
333 | hi = colifilt(hl, g0b, g0a, False, mode)
334 | ba = colifilt(hh, g2b, g2a, True, mode)
335 | del lh, hl, hh
336 |
337 | y = rowifilt(hi, g1b, g1a, True, mode) + \
338 | rowifilt(lo, g0b, g0a, False, mode) + \
339 | rowifilt(ba, g2b, g2a, True, mode)
340 | return y
341 |
342 |
343 | class FWD_J1(Function):
344 | """ Differentiable function doing 1 level forward DTCWT """
345 | @staticmethod
346 | def forward(ctx, x, h0, h1, skip_hps, o_dim, ri_dim, mode):
347 | mode = int_to_mode(mode)
348 | ctx.mode = mode
349 | ctx.save_for_backward(h0, h1)
350 | ctx.dims = get_dimensions5(o_dim, ri_dim)
351 | o_dim, ri_dim = ctx.dims[0], ctx.dims[1]
352 |
353 | ll, highr, highi = fwd_j1(x, h0, h1, skip_hps, o_dim, mode)
354 | if not skip_hps:
355 | highs = torch.stack((highr, highi), dim=ri_dim)
356 | else:
357 | highs = ll.new_zeros([])
358 | return ll, highs
359 |
360 | @staticmethod
361 | def backward(ctx, dl, dh):
362 | h0, h1 = ctx.saved_tensors
363 | mode = ctx.mode
364 | dx = None
365 | if ctx.needs_input_grad[0]:
366 | o_dim, ri_dim, h_dim, w_dim = ctx.dims
367 | if dh is not None and dh.shape != torch.Size([]):
368 | dhr, dhi = torch.unbind(dh, dim=ri_dim)
369 | else:
370 | dhr = dl.new_zeros([])
371 | dhi = dl.new_zeros([])
372 | dx = inv_j1(dl, dhr, dhi, h0, h1, o_dim, h_dim, w_dim, mode)
373 |
374 | return dx, None, None, None, None, None, None
375 |
376 |
377 | class FWD_J2PLUS(Function):
378 | """ Differentiable function doing second level forward DTCWT """
379 | @staticmethod
380 | def forward(ctx, x, h0a, h1a, h0b, h1b, skip_hps, o_dim, ri_dim, mode):
381 | mode = 'symmetric'  # q-shift levels always use symmetric padding; the passed mode is ignored
382 | ctx.mode = mode
383 | ctx.save_for_backward(h0a, h1a, h0b, h1b)
384 | ctx.dims = get_dimensions5(o_dim, ri_dim)
385 | o_dim, ri_dim = ctx.dims[0], ctx.dims[1]
386 |
387 | ll, highr, highi = fwd_j2plus(x, h0a, h1a, h0b, h1b, skip_hps, o_dim, mode)
388 | if not skip_hps:
389 | highs = torch.stack((highr, highi), dim=ri_dim)
390 | else:
391 | highs = ll.new_zeros([])
392 | return ll, highs
393 |
394 | @staticmethod
395 | def backward(ctx, dl, dh):
396 | h0a, h1a, h0b, h1b = ctx.saved_tensors
397 | mode = ctx.mode
398 | # The colifilt and rowifilt functions use conv2d not conv2d_transpose,
399 | # so need to reverse the filters
400 | h0a, h0b = h0b, h0a
401 | h1a, h1b = h1b, h1a
402 | dx = None
403 | if ctx.needs_input_grad[0]:
404 | o_dim, ri_dim, h_dim, w_dim = ctx.dims
405 | if dh is not None and dh.shape != torch.Size([]):
406 | dhr, dhi = torch.unbind(dh, dim=ri_dim)
407 | else:
408 | dhr = dl.new_zeros([])
409 | dhi = dl.new_zeros([])
410 | dx = inv_j2plus(dl, dhr, dhi, h0a, h1a, h0b, h1b,
411 | o_dim, h_dim, w_dim, mode)
412 |
413 | return dx, None, None, None, None, None, None, None, None
414 |
415 |
416 | class INV_J1(Function):
417 | """ Differentiable function doing 1 level inverse DTCWT """
418 | @staticmethod
419 | def forward(ctx, lows, highs, g0, g1, o_dim, ri_dim, mode):
420 | mode = int_to_mode(mode)
421 | ctx.mode = mode
422 | ctx.save_for_backward(g0, g1)
423 | ctx.dims = get_dimensions5(o_dim, ri_dim)
424 | o_dim, ri_dim, h_dim, w_dim = ctx.dims
425 | if highs is not None and highs.shape != torch.Size([]):
426 | highr, highi = torch.unbind(highs, dim=ri_dim)
427 | else:
428 | highr = lows.new_zeros([])
429 | highi = lows.new_zeros([])
430 | y = inv_j1(lows, highr, highi, g0, g1, o_dim, h_dim, w_dim, mode)
431 | return y
432 |
433 | @staticmethod
434 | def backward(ctx, dy):
435 | g0, g1 = ctx.saved_tensors
436 | dl = None
437 | dh = None
438 | o_dim, ri_dim = ctx.dims[0], ctx.dims[1]
439 | mode = ctx.mode
440 | if ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]:
441 | dl, _, _ = fwd_j1(dy, g0, g1, True, o_dim, mode)
442 | elif ctx.needs_input_grad[1] and not ctx.needs_input_grad[0]:
443 | _, dhr, dhi = fwd_j1(dy, g0, g1, False, o_dim, mode)
444 | dh = torch.stack((dhr, dhi), dim=ri_dim)
445 | elif ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
446 | dl, dhr, dhi = fwd_j1(dy, g0, g1, False, o_dim, mode)
447 | dh = torch.stack((dhr, dhi), dim=ri_dim)
448 |
449 | return dl, dh, None, None, None, None, None
450 |
451 |
452 | class INV_J2PLUS(Function):
453 | """ Differentiable function doing level 2 onwards inverse DTCWT """
454 | @staticmethod
455 | def forward(ctx, lows, highs, g0a, g1a, g0b, g1b, o_dim, ri_dim, mode):
456 | mode = 'symmetric'  # q-shift levels always use symmetric padding; the passed mode is ignored
457 | ctx.mode = mode
458 | ctx.save_for_backward(g0a, g1a, g0b, g1b)
459 | ctx.dims = get_dimensions5(o_dim, ri_dim)
460 | o_dim, ri_dim, h_dim, w_dim = ctx.dims
461 | if highs is not None and highs.shape != torch.Size([]):
462 | highr, highi = torch.unbind(highs, dim=ri_dim)
463 | else:
464 | highr = lows.new_zeros([])
465 | highi = lows.new_zeros([])
466 | y = inv_j2plus(lows, highr, highi, g0a, g1a, g0b, g1b,
467 | o_dim, h_dim, w_dim, mode)
468 | return y
469 |
470 | @staticmethod
471 | def backward(ctx, dy):
472 | g0a, g1a, g0b, g1b = ctx.saved_tensors
473 | g0a, g0b = g0b, g0a  # swap the trees to time-reverse the filters
474 | g1a, g1b = g1b, g1a  # (colifilt/rowifilt use conv2d, not conv2d_transpose)
475 | o_dim, ri_dim = ctx.dims[0], ctx.dims[1]
476 | mode = ctx.mode
477 | dl = None
478 | dh = None
479 | if ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]:
480 | dl, _, _ = fwd_j2plus(dy, g0a, g1a, g0b, g1b, True, o_dim, mode)
481 | elif ctx.needs_input_grad[1] and not ctx.needs_input_grad[0]:
482 | _, dhr, dhi = fwd_j2plus(dy, g0a, g1a, g0b, g1b, False, o_dim, mode)
483 | dh = torch.stack((dhr, dhi), dim=ri_dim)
484 | elif ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
485 | dl, dhr, dhi = fwd_j2plus(dy, g0a, g1a, g0b, g1b, False, o_dim, mode)
486 | dh = torch.stack((dhr, dhi), dim=ri_dim)
487 |
488 | return dl, dh, None, None, None, None, None, None, None
489 |
--------------------------------------------------------------------------------
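
A minimal usage sketch of the autograd Function pairs defined above (a sketch,
not library code: it assumes the 'near_sym_a' filters shipped with the package,
the default dimension layout o_dim=2, ri_dim=-1, and mode code 1, which
int_to_mode maps to 'symmetric'):

import torch
from pytorch_wavelets.dtcwt.coeffs import biort
from pytorch_wavelets.dtcwt.lowlevel import prep_filt
from pytorch_wavelets.dtcwt.transform_funcs import FWD_J1, INV_J1

# Level 1 analysis (h) and synthesis (g) filters, shaped for conv2d
h0o, g0o, h1o, g1o = biort('near_sym_a')
h0, h1 = prep_filt(h0o, 1), prep_filt(h1o, 1)
g0, g1 = prep_filt(g0o, 1), prep_filt(g1o, 1)

x = torch.randn(1, 3, 32, 32, requires_grad=True)
ll, highs = FWD_J1.apply(x, h0, h1, False, 2, -1, 1)   # skip_hps=False
y = INV_J1.apply(ll, highs, g0, g1, 2, -1, 1)
y.sum().backward()   # exercises the hand-written backward passes above
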
/pytorch_wavelets/dwt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MusDev7/wtftp-model/d6293e23ad83b29861b781f84737aa7a59cd20c1/pytorch_wavelets/dwt/__init__.py
--------------------------------------------------------------------------------
/pytorch_wavelets/dwt/swt_inverse.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pywt
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import pytorch_wavelets.dwt.lowlevel as lowlevel
7 | from pytorch_wavelets.dwt.lowlevel import mypad, prep_filt_sfb2d
2 | def sfb1d_atrous(lo, hi, g0, g1, mode='periodization', dim=-1, dilation=1,
3 | pad1=None, pad=None):
4 | """ 1D synthesis filter bank of an image tensor with no upsampling. Used for
5 | the stationary wavelet transform.
6 | """
7 | C = lo.shape[1]
8 | d = dim % 4
9 | # If g0, g1 are not tensors, make them. If they are, then assume that they
10 | # are in the right order
11 | if not isinstance(g0, torch.Tensor):
12 | g0 = torch.tensor(np.copy(np.array(g0).ravel()),
13 | dtype=torch.float, device=lo.device)
14 | if not isinstance(g1, torch.Tensor):
15 | g1 = torch.tensor(np.copy(np.array(g1).ravel()),
16 | dtype=torch.float, device=lo.device)
17 | L = g0.numel()
18 | shape = [1,1,1,1]
19 | shape[d] = L
20 | # If g aren't in the right shape, make them so
21 | if g0.shape != tuple(shape):
22 | g0 = g0.reshape(*shape)
23 | if g1.shape != tuple(shape):
24 | g1 = g1.reshape(*shape)
25 | g0 = torch.cat([g0]*C,dim=0)
26 | g1 = torch.cat([g1]*C,dim=0)
27 |
28 | # Calculate the padding size.
29 | # With dilation, zeros are inserted between the filter taps but not after,
30 | # so a filter that is [a b c d] becomes [a 0 b 0 c 0 d].
32 | fsz = (L-1)*dilation + 1
35 |
36 | # When conv_transpose2d is done, a filter with k taps expands an input with
37 | # N samples to be N + k - 1 samples. The 'padding' is really the opposite of
38 | # that, and is how many samples on the edges you want to cut out.
39 | # In addition to this, we want the input to be extended before convolving.
40 | # This means the final output size without the padding option will be
41 | # N + k - 1 + k - 1
42 | # The final thing to worry about is making sure that the output is centred.
45 | a = fsz//2
46 | b = fsz//2 + (fsz + 1) % 2
49 | pad = (0, 0, a, b) if d == 2 else (a, b, 0, 0)
50 | lo = mypad(lo, pad=pad, mode=mode)
51 | hi = mypad(hi, pad=pad, mode=mode)
53 | unpad = (0, 0)
54 | y = F.conv_transpose2d(lo, g0, padding=unpad, groups=C, dilation=dilation) + \
55 | F.conv_transpose2d(hi, g1, padding=unpad, groups=C, dilation=dilation)
75 |
76 | return y/(2*dilation)
77 |
78 |
79 | def sfb2d_atrous(ll, lh, hl, hh, filts, mode='zero'):
80 | """ Does a single level 2d wavelet reconstruction of wavelet coefficients.
81 | Does separate row and column filtering by three calls to
82 | :py:func:`sfb1d_atrous`
83 |
84 | Inputs:
85 | ll (torch.Tensor): lowpass coefficients
86 | lh (torch.Tensor): horizontal coefficients
87 | hl (torch.Tensor): vertical coefficients
88 | hh (torch.Tensor): diagonal coefficients
89 | filts (list of ndarray or torch.Tensor): If a list of tensors has been
90 | given, this function assumes they are in the right form (the form
91 | returned by
92 | :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`).
93 | Otherwise, this function will prepare the filters to be of the right
94 | form by calling
95 | :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`.
96 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
97 | padding to use. If periodization, the output size will be half the
98 | input size. Otherwise, the output size will be slightly larger than
99 | half.
100 | """
101 | tensorize = [not isinstance(x, torch.Tensor) for x in filts]
102 | if len(filts) == 2:
103 | g0, g1 = filts
104 | if True in tensorize:
105 | g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(g0, g1)
106 | else:
107 | g0_col = g0
108 | g0_row = g0.transpose(2,3)
109 | g1_col = g1
110 | g1_row = g1.transpose(2,3)
111 | elif len(filts) == 4:
112 | if True in tensorize:
113 | g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(*filts)
114 | else:
115 | g0_col, g1_col, g0_row, g1_row = filts
116 | else:
117 | raise ValueError("Unknown form for input filts")
118 |
119 | lo = sfb1d_atrous(ll, lh, g0_col, g1_col, mode=mode, dim=2)
120 | hi = sfb1d_atrous(hl, hh, g0_col, g1_col, mode=mode, dim=2)
121 | y = sfb1d_atrous(lo, hi, g0_row, g1_row, mode=mode, dim=3)
122 |
123 | return y
124 |
125 |
126 | class SWTInverse(nn.Module):
127 | """ Performs a 2d DWT Inverse reconstruction of an image
128 |
129 | Args:
130 | wave (str or pywt.Wavelet): Which wavelet to use
131 | C: deprecated, will be removed in future
132 | """
133 | def __init__(self, wave='db1', mode='zero', separable=True):
134 | super().__init__()
135 | if isinstance(wave, str):
136 | wave = pywt.Wavelet(wave)
137 | if isinstance(wave, pywt.Wavelet):
138 | g0_col, g1_col = wave.rec_lo, wave.rec_hi
139 | g0_row, g1_row = g0_col, g1_col
140 | else:
141 | if len(wave) == 2:
142 | g0_col, g1_col = wave[0], wave[1]
143 | g0_row, g1_row = g0_col, g1_col
144 | elif len(wave) == 4:
145 | g0_col, g1_col = wave[0], wave[1]
146 | g0_row, g1_row = wave[2], wave[3]
147 | # Prepare the filters
148 | if separable:
149 | filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
150 | self.register_buffer('g0_col', filts[0])
151 | self.register_buffer('g1_col', filts[1])
152 | self.register_buffer('g0_row', filts[2])
153 | self.register_buffer('g1_row', filts[3])
154 | else:
155 | filts = lowlevel.prep_filt_sfb2d_nonsep(
156 | g0_col, g1_col, g0_row, g1_row)
157 | self.register_buffer('h', filts)
158 | self.mode = mode
159 | self.separable = separable
160 |
161 | def forward(self, coeffs):
162 | """
163 | Args:
164 | coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
165 | yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
166 | W_{in}')` and yh is a list of bandpass tensors of shape
167 | :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
168 | the format returned by DWTForward
169 |
170 | Returns:
171 | Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
172 |
173 | Note:
174 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
175 | downsampled shapes of the DWT pyramid.
176 |
177 | Note:
178 | Can have None for any of the highpass scales and will treat the
179 | values as zeros (not in an efficient way though).
180 | """
181 | yl, yh = coeffs
182 | ll = yl
183 |
184 | # Do a multilevel inverse transform
185 | for h in yh[::-1]:
186 | if h is None:
187 | h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
188 | ll.shape[-1], device=ll.device)
189 |
190 | # 'Unpad' added dimensions
191 | if ll.shape[-2] > h.shape[-2]:
192 | ll = ll[...,:-1,:]
193 | if ll.shape[-1] > h.shape[-1]:
194 | ll = ll[...,:-1]
195 |
196 | # Do the synthesis filter banks
197 | if self.separable:
198 | lh, hl, hh = torch.unbind(h, dim=2)
199 | filts = (self.g0_col, self.g1_col, self.g0_row, self.g1_row)
200 | ll = lowlevel.sfb2d(ll, lh, hl, hh, filts, mode=self.mode)
201 | else:
202 | c = torch.cat((ll[:,:,None], h), dim=2)
203 | ll = lowlevel.sfb2d_nonsep(c, self.h, mode=self.mode)
204 | return ll
205 |
--------------------------------------------------------------------------------
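
The padding arithmetic in sfb1d_atrous above is easiest to sanity-check
numerically. A small sketch (the names L, dilation, fsz, a and b mirror the
variables in that function):

# Effective length of a dilated (a trous) filter and the pad that
# sfb1d_atrous applies before conv_transpose2d.
for L, dilation in [(2, 1), (4, 2), (8, 4)]:
    fsz = (L - 1) * dilation + 1        # [a b c d] -> [a 0 b 0 c 0 d]
    a = fsz // 2                        # samples added before
    b = fsz // 2 + (fsz + 1) % 2        # samples added after
    print(L, dilation, fsz, (a, b))     # e.g. L=4, dilation=2 -> fsz=7, (3, 3)
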
/pytorch_wavelets/dwt/transform1d.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import pywt
3 | import pytorch_wavelets.dwt.lowlevel as lowlevel
4 | import torch
5 |
6 |
7 | class DWT1DForward(nn.Module):
8 | """ Performs a 1d DWT Forward decomposition of an image
9 |
10 | Args:
11 | J (int): Number of levels of decomposition
12 | wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use.
13 | Can be:
14 | 1) a string to pass to pywt.Wavelet constructor
15 | 2) a pywt.Wavelet class
16 | 3) a tuple of numpy arrays (h0, h1)
17 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
18 | padding scheme
19 | """
20 | def __init__(self, J=1, wave='db1', mode='zero'):
21 | super().__init__()
22 | if isinstance(wave, str):
23 | wave = pywt.Wavelet(wave)
24 | if isinstance(wave, pywt.Wavelet):
25 | h0, h1 = wave.dec_lo, wave.dec_hi
26 | else:
27 | assert len(wave) == 2
28 | h0, h1 = wave[0], wave[1]
29 |
30 | # Prepare the filters - this makes them into column filters
31 | filts = lowlevel.prep_filt_afb1d(h0, h1)
32 | self.register_buffer('h0', filts[0])
33 | self.register_buffer('h1', filts[1])
34 | self.J = J
35 | self.mode = mode
36 |
37 | def forward(self, x):
38 | """ Forward pass of the DWT.
39 |
40 | Args:
41 | x (tensor): Input of shape :math:`(N, C_{in}, L_{in})`
42 |
43 | Returns:
44 | (yl, yh)
45 | tuple of lowpass (yl) and bandpass (yh) coefficients.
46 | yh is a list of length J with the first entry
47 | being the finest scale coefficients.
48 | """
49 | assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)"
50 | highs = []
51 | x0 = x
52 | mode = lowlevel.mode_to_int(self.mode)
53 |
54 | # Do a multilevel transform
55 | for j in range(self.J):
56 | x0, x1 = lowlevel.AFB1D.apply(x0, self.h0, self.h1, mode)
57 | highs.append(x1)
58 |
59 | return x0, highs
60 |
61 |
62 | class DWT1DInverse(nn.Module):
63 | """ Performs a 1d DWT Inverse reconstruction of an image
64 |
65 | Args:
66 | wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use.
67 | Can be:
68 | 1) a string to pass to pywt.Wavelet constructor
69 | 2) a pywt.Wavelet class
70 | 3) a tuple of numpy arrays (h0, h1)
71 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
72 | padding scheme
73 | """
74 | def __init__(self, wave='db1', mode='zero'):
75 | super().__init__()
76 | if isinstance(wave, str):
77 | wave = pywt.Wavelet(wave)
78 | if isinstance(wave, pywt.Wavelet):
79 | g0, g1 = wave.rec_lo, wave.rec_hi
80 | else:
81 | assert len(wave) == 2
82 | g0, g1 = wave[0], wave[1]
83 |
84 | # Prepare the filters
85 | filts = lowlevel.prep_filt_sfb1d(g0, g1)
86 | self.register_buffer('g0', filts[0])
87 | self.register_buffer('g1', filts[1])
88 | self.mode = mode
89 |
90 | def forward(self, coeffs):
91 | """
92 | Args:
93 | coeffs (yl, yh): tuple of lowpass and bandpass coefficients, should
94 | match the format returned by DWT1DForward.
95 |
96 | Returns:
97 | Reconstructed input of shape :math:`(N, C_{in}, L_{in})`
98 |
99 | Note:
100 | Can have None for any of the highpass scales and will treat the
101 | values as zeros (not in an efficient way though).
102 | """
103 | x0, highs = coeffs
104 | assert x0.ndim == 3, "Can only handle 3d inputs (N, C, L)"
105 | mode = lowlevel.mode_to_int(self.mode)
106 | # Do a multilevel inverse transform
107 | for x1 in highs[::-1]:
108 | if x1 is None:
109 | x1 = torch.zeros_like(x0)
110 |
111 | # 'Unpad' added signal
112 | if x0.shape[-1] > x1.shape[-1]:
113 | x0 = x0[..., :-1]
114 | x0 = lowlevel.SFB1D.apply(x0, x1, self.g0, self.g1, mode)
115 | return x0
116 |
--------------------------------------------------------------------------------
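
A minimal round trip through the two modules above (a sketch assuming this
vendored pytorch_wavelets package is importable); with matching wavelet and
mode the decomposition is invertible to within float32 precision:

import torch
from pytorch_wavelets.dwt.transform1d import DWT1DForward, DWT1DInverse

xfm = DWT1DForward(J=3, wave='db4', mode='symmetric')
ifm = DWT1DInverse(wave='db4', mode='symmetric')

x = torch.randn(8, 2, 128)                  # (N, C, L)
yl, yh = xfm(x)                             # yh: list of 3 tensors, finest first
x_hat = ifm((yl, yh))
print(torch.allclose(x_hat, x, atol=1e-5))  # expect True
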
/pytorch_wavelets/dwt/transform2d.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import pywt
3 | import pytorch_wavelets.dwt.lowlevel as lowlevel
4 | import torch
5 |
6 |
7 | class DWTForward(nn.Module):
8 | """ Performs a 2d DWT Forward decomposition of an image
9 |
10 | Args:
11 | J (int): Number of levels of decomposition
12 | wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use.
13 | Can be:
14 | 1) a string to pass to pywt.Wavelet constructor
15 | 2) a pywt.Wavelet class
16 | 3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row)
17 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
18 | padding scheme
19 | """
20 | def __init__(self, J=1, wave='db1', mode='zero'):
21 | super().__init__()
22 | if isinstance(wave, str):
23 | wave = pywt.Wavelet(wave)
24 | if isinstance(wave, pywt.Wavelet):
25 | h0_col, h1_col = wave.dec_lo, wave.dec_hi
26 | h0_row, h1_row = h0_col, h1_col
27 | else:
28 | if len(wave) == 2:
29 | h0_col, h1_col = wave[0], wave[1]
30 | h0_row, h1_row = h0_col, h1_col
31 | elif len(wave) == 4:
32 | h0_col, h1_col = wave[0], wave[1]
33 | h0_row, h1_row = wave[2], wave[3]
34 |
35 | # Prepare the filters
36 | filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
37 | self.register_buffer('h0_col', filts[0])
38 | self.register_buffer('h1_col', filts[1])
39 | self.register_buffer('h0_row', filts[2])
40 | self.register_buffer('h1_row', filts[3])
41 | self.J = J
42 | self.mode = mode
43 |
44 | def forward(self, x):
45 | """ Forward pass of the DWT.
46 |
47 | Args:
48 | x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
49 |
50 | Returns:
51 | (yl, yh)
52 | tuple of lowpass (yl) and bandpass (yh) coefficients.
53 | yh is a list of length J with the first entry
54 | being the finest scale coefficients. yl has shape
55 | :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
56 | :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new
57 | dimension in yh iterates over the LH, HL and HH coefficients.
58 |
59 | Note:
60 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
61 | downsampled shapes of the DWT pyramid.
62 | """
63 | yh = []
64 | ll = x
65 | mode = lowlevel.mode_to_int(self.mode)
66 |
67 | # Do a multilevel transform
68 | for j in range(self.J):
69 | # Do 1 level of the transform
70 | ll, high = lowlevel.AFB2D.apply(
71 | ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, mode)
72 | yh.append(high)
73 |
74 | return ll, yh
75 |
76 |
77 | class DWTInverse(nn.Module):
78 | """ Performs a 2d DWT Inverse reconstruction of an image
79 |
80 | Args:
81 | wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use.
82 | Can be:
83 | 1) a string to pass to pywt.Wavelet constructor
84 | 2) a pywt.Wavelet class
85 | 3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row)
86 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
87 | padding scheme
88 | """
89 | def __init__(self, wave='db1', mode='zero'):
90 | super().__init__()
91 | if isinstance(wave, str):
92 | wave = pywt.Wavelet(wave)
93 | if isinstance(wave, pywt.Wavelet):
94 | g0_col, g1_col = wave.rec_lo, wave.rec_hi
95 | g0_row, g1_row = g0_col, g1_col
96 | else:
97 | if len(wave) == 2:
98 | g0_col, g1_col = wave[0], wave[1]
99 | g0_row, g1_row = g0_col, g1_col
100 | elif len(wave) == 4:
101 | g0_col, g1_col = wave[0], wave[1]
102 | g0_row, g1_row = wave[2], wave[3]
103 | # Prepare the filters
104 | filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
105 | self.register_buffer('g0_col', filts[0])
106 | self.register_buffer('g1_col', filts[1])
107 | self.register_buffer('g0_row', filts[2])
108 | self.register_buffer('g1_row', filts[3])
109 | self.mode = mode
110 |
111 | def forward(self, coeffs):
112 | """
113 | Args:
114 | coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
115 | yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
116 | W_{in}')` and yh is a list of bandpass tensors of shape
117 | :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
118 | the format returned by DWTForward
119 |
120 | Returns:
121 | Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
122 |
123 | Note:
124 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
125 | downsampled shapes of the DWT pyramid.
126 |
127 | Note:
128 | Can have None for any of the highpass scales and will treat the
129 | values as zeros (not in an efficient way though).
130 | """
131 | yl, yh = coeffs
132 | ll = yl
133 | mode = lowlevel.mode_to_int(self.mode)
134 |
135 | # Do a multilevel inverse transform
136 | for h in yh[::-1]:
137 | if h is None:
138 | h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
139 | ll.shape[-1], device=ll.device)
140 |
141 | # 'Unpad' added dimensions
142 | if ll.shape[-2] > h.shape[-2]:
143 | ll = ll[...,:-1,:]
144 | if ll.shape[-1] > h.shape[-1]:
145 | ll = ll[...,:-1]
146 | ll = lowlevel.SFB2D.apply(
147 | ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
148 | return ll
149 |
150 |
151 | class SWTForward(nn.Module):
152 | """ Performs a 2d Stationary wavelet transform (or undecimated wavelet
153 | transform) of an image
154 |
155 | Args:
156 | J (int): Number of levels of decomposition
157 | wave (str or pywt.Wavelet): Which wavelet to use. Can be a string to
158 | pass to pywt.Wavelet constructor, can also be a pywt.Wavelet class,
159 | or can be a two tuple of array-like objects for the analysis low and
160 | high pass filters.
161 | mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
162 | padding scheme. PyWavelets uses only periodization so we use this
163 | as our default scheme.
164 | """
165 | def __init__(self, J=1, wave='db1', mode='periodization'):
166 | super().__init__()
167 | if isinstance(wave, str):
168 | wave = pywt.Wavelet(wave)
169 | if isinstance(wave, pywt.Wavelet):
170 | h0_col, h1_col = wave.dec_lo, wave.dec_hi
171 | h0_row, h1_row = h0_col, h1_col
172 | else:
173 | if len(wave) == 2:
174 | h0_col, h1_col = wave[0], wave[1]
175 | h0_row, h1_row = h0_col, h1_col
176 | elif len(wave) == 4:
177 | h0_col, h1_col = wave[0], wave[1]
178 | h0_row, h1_row = wave[2], wave[3]
179 |
180 | # Prepare the filters
181 | filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
182 | self.register_buffer('h0_col', filts[0])
183 | self.register_buffer('h1_col', filts[1])
184 | self.register_buffer('h0_row', filts[2])
185 | self.register_buffer('h1_row', filts[3])
186 |
187 | self.J = J
188 | self.mode = mode
189 |
190 | def forward(self, x):
191 | """ Forward pass of the SWT.
192 |
193 | Args:
194 | x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
195 |
196 | Returns:
197 | List of coefficients for each scale. Each coefficient has
198 | shape :math:`(N, C_{in}, 4, H_{in}, W_{in})` where the extra
199 | dimension stores the 4 subbands for each scale. The ordering in
200 | these 4 coefficients is: (A, H, V, D) or (ll, lh, hl, hh).
201 | """
202 | ll = x
203 | coeffs = []
204 | # Do a multilevel transform
205 | filts = (self.h0_col, self.h1_col, self.h0_row, self.h1_row)
206 | for j in range(self.J):
207 | # Do 1 level of the transform
208 | y = lowlevel.afb2d_atrous(ll, filts, self.mode, 2**j)
209 | coeffs.append(y)
210 | ll = y[:,:,0]
211 |
212 | return coeffs
213 |
--------------------------------------------------------------------------------
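
The same round trip in 2d, plus the undecimated SWTForward whose per-scale
output stacks the four subbands (again a sketch, not library code):

import torch
from pytorch_wavelets.dwt.transform2d import DWTForward, DWTInverse, SWTForward

xfm = DWTForward(J=2, wave='db2', mode='zero')
ifm = DWTInverse(wave='db2', mode='zero')

x = torch.randn(1, 3, 64, 64)
yl, yh = xfm(x)                             # yl: (1,3,H',W'); yh[j]: (1,3,3,Hj,Wj)
x_hat = ifm((yl, yh))
print(torch.allclose(x_hat, x, atol=1e-5))  # expect True

swt = SWTForward(J=2, wave='db1')
coeffs = swt(x)                             # list of (1, 3, 4, 64, 64) tensors,
                                            # subbands ordered (ll, lh, hl, hh)
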
/pytorch_wavelets/scatternet/__init__.py:
--------------------------------------------------------------------------------
1 | from .layers import ScatLayer, ScatLayerj2
2 |
3 | __all__ = ['ScatLayer', 'ScatLayerj2']
4 |
--------------------------------------------------------------------------------
/pytorch_wavelets/scatternet/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pytorch_wavelets.dtcwt.coeffs import biort as _biort, qshift as _qshift
4 | from pytorch_wavelets.dtcwt.lowlevel import prep_filt
5 |
6 | from .lowlevel import mode_to_int
7 | from .lowlevel import ScatLayerj1_f, ScatLayerj1_rot_f
8 | from .lowlevel import ScatLayerj2_f, ScatLayerj2_rot_f
9 |
10 |
11 | class ScatLayer(nn.Module):
12 | """ Does one order of scattering at a single scale. Can be made into a
13 | second order scatternet by stacking two of these layers.
14 | Inputs:
15 | biort (str): the biorthogonal filters to use. If 'near_sym_b_bp', will
16 | use the rotationally symmetric filters. These have 13 and 19 taps
17 | so are quite long. They also require 7 1D convolutions instead of 6.
18 | x (torch.tensor): Input of shape (N, C, H, W)
19 | mode (str): padding mode. Can be 'symmetric' or 'zero'
20 | magbias (float): the magnitude bias to use for smoothing
21 | combine_colour (bool): if true, will only have colour lowpass and have
22 | greyscale bandpass
23 | Returns:
24 | y (torch.tensor): y has the lowpass and invariant U terms stacked along
25 | the channel dimension, and so has shape (N, 7*C, H/2, W/2), where
26 | the first C channels are the lowpass outputs, and the next 6C are
27 | the magnitude highpass outputs.
28 | """
29 | def __init__(self, biort='near_sym_a', mode='symmetric', magbias=1e-2,
30 | combine_colour=False):
31 | super().__init__()
32 | self.biort = biort
33 | # Have to convert the string to an int as the grad checks don't work
34 | # with string inputs
35 | self.mode_str = mode
36 | self.mode = mode_to_int(mode)
37 | self.magbias = magbias
38 | self.combine_colour = combine_colour
39 | if biort == 'near_sym_b_bp':
40 | self.bandpass_diag = True
41 | h0o, _, h1o, _, h2o, _ = _biort(biort)
42 | self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False)
43 | self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False)
44 | self.h2o = torch.nn.Parameter(prep_filt(h2o, 1), False)
45 | else:
46 | self.bandpass_diag = False
47 | h0o, _, h1o, _ = _biort(biort)
48 | self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False)
49 | self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False)
50 |
51 | def forward(self, x):
52 | # Do the single scale DTCWT
53 | # If the row/col count of X is not divisible by 2 then we need to
54 | # extend X
55 | _, ch, r, c = x.shape
56 | if r % 2 != 0:
57 | x = torch.cat((x, x[:,:,-1:]), dim=2)
58 | if c % 2 != 0:
59 | x = torch.cat((x, x[:,:,:,-1:]), dim=3)
60 |
61 | if self.combine_colour:
62 | assert ch == 3
63 |
64 | if self.bandpass_diag:
65 | Z = ScatLayerj1_rot_f.apply(
66 | x, self.h0o, self.h1o, self.h2o, self.mode, self.magbias,
67 | self.combine_colour)
68 | else:
69 | Z = ScatLayerj1_f.apply(
70 | x, self.h0o, self.h1o, self.mode, self.magbias,
71 | self.combine_colour)
72 | if not self.combine_colour:
73 | b, _, c, h, w = Z.shape
74 | Z = Z.view(b, 7*c, h, w)
75 | return Z
76 |
77 | def extra_repr(self):
78 | return "biort='{}', mode='{}', magbias={}".format(
79 | self.biort, self.mode_str, self.magbias)
80 |
81 |
82 | class ScatLayerj2(nn.Module):
83 | """ Does second order scattering for two scales. Uses correct dtcwt first
84 | and second level filters compared to ScatLayer which only uses biorthogonal
85 | filters.
86 |
87 | Inputs:
88 | biort (str): the level 1 biorthogonal filters. If 'near_sym_b_bp', uses
89 | the rotationally symmetric filters (13 and 19 taps, so quite long;
90 | they also require 7 1D convolutions instead of 6).
91 | qshift (str): the level 2 and above quarter-shift filters to use.
92 | x (torch.tensor): Input of shape (N, C, H, W)
93 | mode (str): padding mode. Can be 'symmetric' or 'zero'
94 | Returns:
95 | y (torch.tensor): the lowpass and invariant U terms stacked along the
96 | channel dimension, of shape (N, 49*C, H/4, W/4): C lowpass channels,
97 | then 6C + 6C first order outputs at scales 1 and 2, then 36C second
98 | order outputs. """
99 | def __init__(self, biort='near_sym_a', qshift='qshift_a', mode='symmetric',
100 | magbias=1e-2, combine_colour=False):
101 | super().__init__()
102 | self.biort = biort
103 | self.qshift = qshift
104 | # Have to convert the string to an int as the grad checks don't work
105 | # with string inputs
106 | self.mode_str = mode
107 | self.mode = mode_to_int(mode)
108 | self.magbias = magbias
109 | self.combine_colour = combine_colour
110 | if biort == 'near_sym_b_bp':
111 | assert qshift == 'qshift_b_bp'
112 | self.bandpass_diag = True
113 | h0o, _, h1o, _, h2o, _ = _biort(biort)
114 | self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False)
115 | self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False)
116 | self.h2o = torch.nn.Parameter(prep_filt(h2o, 1), False)
117 | h0a, h0b, _, _, h1a, h1b, _, _, h2a, h2b, _, _ = _qshift('qshift_b_bp')
118 | self.h0a = torch.nn.Parameter(prep_filt(h0a, 1), False)
119 | self.h0b = torch.nn.Parameter(prep_filt(h0b, 1), False)
120 | self.h1a = torch.nn.Parameter(prep_filt(h1a, 1), False)
121 | self.h1b = torch.nn.Parameter(prep_filt(h1b, 1), False)
122 | self.h2a = torch.nn.Parameter(prep_filt(h2a, 1), False)
123 | self.h2b = torch.nn.Parameter(prep_filt(h2b, 1), False)
124 | else:
125 | self.bandpass_diag = False
126 | h0o, _, h1o, _ = _biort(biort)
127 | self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False)
128 | self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False)
129 | h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
130 | self.h0a = torch.nn.Parameter(prep_filt(h0a, 1), False)
131 | self.h0b = torch.nn.Parameter(prep_filt(h0b, 1), False)
132 | self.h1a = torch.nn.Parameter(prep_filt(h1a, 1), False)
133 | self.h1b = torch.nn.Parameter(prep_filt(h1b, 1), False)
134 |
135 | def forward(self, x):
136 | # Ensure the input size is divisible by 8
137 | ch, r, c = x.shape[1:]
138 | rem = r % 8
139 | if rem != 0:
140 | rows_after = (9-rem)//2
141 | rows_before = (8-rem) // 2
142 | x = torch.cat((x[:,:,:rows_before], x,
143 | x[:,:,-rows_after:]), dim=2)
144 | rem = c % 8
145 | if rem != 0:
146 | cols_after = (9-rem)//2
147 | cols_before = (8-rem) // 2
148 | x = torch.cat((x[:,:,:,:cols_before], x,
149 | x[:,:,:,-cols_after:]), dim=3)
150 |
151 | if self.combine_colour:
152 | assert ch == 3
153 |
154 | if self.bandpass_diag:
156 | Z = ScatLayerj2_rot_f.apply(
157 | x, self.h0o, self.h1o, self.h2o, self.h0a, self.h0b, self.h1a,
158 | self.h1b, self.h2a, self.h2b, self.mode, self.magbias,
159 | self.combine_colour)
160 | else:
161 | Z = ScatLayerj2_f.apply(
162 | x, self.h0o, self.h1o, self.h0a, self.h0b, self.h1a,
163 | self.h1b, self.mode, self.magbias, self.combine_colour)
164 |
165 | if not self.combine_colour:
166 | b, _, c, h, w = Z.shape
167 | Z = Z.view(b, 49*c, h, w)
168 | return Z
169 |
170 | def extra_repr(self):
171 | return "biort='{}', mode='{}', magbias={}".format(
172 | self.biort, self.mode_str, self.magbias)
173 |
--------------------------------------------------------------------------------
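
A shape-level sketch of the two scattering layers (the channel and resolution
bookkeeping follows the docstrings above; input sides must be divisible by 8
for ScatLayerj2):

import torch
from pytorch_wavelets.scatternet import ScatLayer, ScatLayerj2

x = torch.randn(4, 3, 64, 64)

z1 = ScatLayer(biort='near_sym_a', mode='symmetric')(x)
print(z1.shape)   # expect (4, 21, 32, 32): 7*C channels at half resolution

z2 = ScatLayerj2(biort='near_sym_a', qshift='qshift_a')(x)
print(z2.shape)   # expect (4, 147, 16, 16): 49*C channels at quarter resolution
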
/pytorch_wavelets/scatternet/lowlevel.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from pytorch_wavelets.dtcwt.transform_funcs import fwd_j1, inv_j1
6 | from pytorch_wavelets.dtcwt.transform_funcs import fwd_j1_rot, inv_j1_rot
7 | from pytorch_wavelets.dtcwt.transform_funcs import fwd_j2plus, inv_j2plus
8 | from pytorch_wavelets.dtcwt.transform_funcs import fwd_j2plus_rot, inv_j2plus_rot
9 |
10 |
11 | def mode_to_int(mode):
12 | if mode == 'zero':
13 | return 0
14 | elif mode == 'symmetric':
15 | return 1
16 | elif mode == 'per' or mode == 'periodization':
17 | return 2
18 | elif mode == 'constant':
19 | return 3
20 | elif mode == 'reflect':
21 | return 4
22 | elif mode == 'replicate':
23 | return 5
24 | elif mode == 'periodic':
25 | return 6
26 | else:
27 | raise ValueError("Unknown pad type: {}".format(mode))
28 |
29 |
30 | def int_to_mode(mode):
31 | if mode == 0:
32 | return 'zero'
33 | elif mode == 1:
34 | return 'symmetric'
35 | elif mode == 2:
36 | return 'periodization'
37 | elif mode == 3:
38 | return 'constant'
39 | elif mode == 4:
40 | return 'reflect'
41 | elif mode == 5:
42 | return 'replicate'
43 | elif mode == 6:
44 | return 'periodic'
45 | else:
46 | raise ValueError("Unknown pad type: {}".format(mode))
47 |
48 |
49 | class SmoothMagFn(torch.autograd.Function):
50 | """ Class to do complex magnitude """
51 | @staticmethod
52 | def forward(ctx, x, y, b):
53 | r = torch.sqrt(x**2 + y**2 + b**2)
54 | if x.requires_grad:
55 | dx = x/r
56 | dy = y/r
57 | ctx.save_for_backward(dx, dy)
58 |
59 | return r - b
60 |
61 | @staticmethod
62 | def backward(ctx, dr):
63 | dx = dy = None
64 | if ctx.needs_input_grad[0]:
65 | drdx, drdy = ctx.saved_tensors
66 | dx = drdx * dr
67 | dy = drdy * dr
68 | return dx, dy, None
69 |
70 |
71 | class ScatLayerj1_f(torch.autograd.Function):
72 | """ Function to do forward and backward passes of a single scattering
73 | layer with the DTCWT biorthogonal filters. """
74 |
75 | @staticmethod
76 | def forward(ctx, x, h0o, h1o, mode, bias, combine_colour):
79 | ctx.in_shape = x.shape
80 | batch, ch, r, c = x.shape
81 | assert r % 2 == c % 2 == 0
82 | mode = int_to_mode(mode)
83 | ctx.mode = mode
84 | ctx.combine_colour = combine_colour
85 |
86 | ll, reals, imags = fwd_j1(x, h0o, h1o, False, 1, mode)
87 | ll = F.avg_pool2d(ll, 2)
88 | if combine_colour:
89 | r = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 +
90 | reals[:,:,1]**2 + imags[:,:,1]**2 +
91 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2)
92 | r = r[:, :, None]
93 | else:
94 | r = torch.sqrt(reals**2 + imags**2 + bias**2)
95 |
96 | if x.requires_grad:
97 | drdx = reals/r
98 | drdy = imags/r
99 | ctx.save_for_backward(h0o, h1o, drdx, drdy)
100 | else:
101 | z = x.new_zeros(1)
102 | ctx.save_for_backward(h0o, h1o, z, z)
103 |
104 | r = r - bias
105 | del reals, imags
106 | if combine_colour:
107 | Z = torch.cat((ll, r[:, :, 0]), dim=1)
108 | else:
109 | Z = torch.cat((ll[:, None], r), dim=1)
110 |
111 | return Z
112 |
113 | @staticmethod
114 | def backward(ctx, dZ):
115 | dX = None
116 | mode = ctx.mode
117 |
118 | if ctx.needs_input_grad[0]:
120 | h0o, h1o, drdx, drdy = ctx.saved_tensors
121 | # Use the special properties of the filters to get the time reverse
122 | h0o_t = h0o
123 | h1o_t = h1o
124 |
125 | # Level 1 backward (time reversed biorthogonal analysis filters)
126 | if ctx.combine_colour:
127 | dYl, dr = dZ[:,:3], dZ[:,3:]
128 | dr = dr[:, :, None]
129 | else:
130 | dYl, dr = dZ[:,0], dZ[:,1:]
131 | ll = 1/4 * F.interpolate(dYl, scale_factor=2, mode="nearest")
132 | reals = dr * drdx
133 | imags = dr * drdy
134 |
135 | dX = inv_j1(ll, reals, imags, h0o_t, h1o_t, 1, 3, 4, mode)
136 |
137 | return (dX,) + (None,) * 5
138 |
139 |
140 | class ScatLayerj1_rot_f(torch.autograd.Function):
141 | """ Function to do forward and backward passes of a single scattering
142 | layer with the DTCWT biorthogonal filters. Uses the rotationally symmetric
143 | filters, i.e. a slightly more expensive operation."""
144 |
145 | @staticmethod
146 | def forward(ctx, x, h0o, h1o, h2o, mode, bias, combine_colour):
147 | mode = int_to_mode(mode)
148 | ctx.mode = mode
150 | ctx.in_shape = x.shape
151 | ctx.combine_colour = combine_colour
152 | batch, ch, r, c = x.shape
153 | assert r % 2 == c % 2 == 0
154 |
155 | # Level 1 forward (biorthogonal analysis filters)
156 | ll, reals, imags = fwd_j1_rot(x, h0o, h1o, h2o, False, 1, mode)
157 | ll = F.avg_pool2d(ll, 2)
158 | if combine_colour:
159 | r = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 +
160 | reals[:,:,1]**2 + imags[:,:,1]**2 +
161 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2)
162 | r = r[:, :, None]
163 | else:
164 | r = torch.sqrt(reals**2 + imags**2 + bias**2)
165 | if x.requires_grad:
166 | drdx = reals/r
167 | drdy = imags/r
168 | ctx.save_for_backward(h0o, h1o, h2o, drdx, drdy)
169 | else:
170 | z = x.new_zeros(1)
171 | ctx.save_for_backward(h0o, h1o, h2o, z, z)
172 | r = r - bias
173 | del reals, imags
174 | if combine_colour:
175 | Z = torch.cat((ll, r[:, :, 0]), dim=1)
176 | else:
177 | Z = torch.cat((ll[:, None], r), dim=1)
178 |
179 | return Z
180 |
181 | @staticmethod
182 | def backward(ctx, dZ):
183 | dX = None
184 | mode = ctx.mode
185 |
186 | if ctx.needs_input_grad[0]:
187 | # Don't need to do time reverse as these filters are symmetric
189 | h0o, h1o, h2o, drdx, drdy = ctx.saved_tensors
190 |
191 | # Level 1 backward (time reversed biorthogonal analysis filters)
192 | if ctx.combine_colour:
193 | dYl, dr = dZ[:,:3], dZ[:,3:]
194 | dr = dr[:, :, None]
195 | else:
196 | dYl, dr = dZ[:,0], dZ[:,1:]
197 | ll = 1/4 * F.interpolate(dYl, scale_factor=2, mode="nearest")
198 |
199 | reals = dr * drdx
200 | imags = dr * drdy
201 | dX = inv_j1_rot(ll, reals, imags, h0o, h1o, h2o, 1, 3, 4, mode)
202 |
203 | return (dX,) + (None,) * 6
204 |
205 |
206 | class ScatLayerj2_f(torch.autograd.Function):
207 | """ Function to do forward and backward passes of a single scattering
208 | layer with the DTCWT biorthogonal filters. """
209 |
210 | @staticmethod
211 | def forward(ctx, x, h0o, h1o, h0a, h0b, h1a, h1b, mode, bias, combine_colour):
214 | ctx.in_shape = x.shape
215 | batch, ch, r, c = x.shape
216 | assert r % 8 == c % 8 == 0
217 | mode = int_to_mode(mode)
218 | ctx.mode = mode
219 | ctx.combine_colour = combine_colour
220 |
221 | # First order scattering
222 | s0, reals, imags = fwd_j1(x, h0o, h1o, False, 1, mode)
223 | if combine_colour:
224 | s1_j1 = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 +
225 | reals[:,:,1]**2 + imags[:,:,1]**2 +
226 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2)
227 | s1_j1 = s1_j1[:, :, None]
228 | if x.requires_grad:
229 | dsdx1 = reals/s1_j1
230 | dsdy1 = imags/s1_j1
231 | s1_j1 = s1_j1 - bias
232 |
233 | s0, reals, imags = fwd_j2plus(s0, h0a, h1a, h0b, h1b, False, 1, mode)
234 | s1_j2 = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 +
235 | reals[:,:,1]**2 + imags[:,:,1]**2 +
236 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2)
237 | s1_j2 = s1_j2[:, :, None]
238 | if x.requires_grad:
239 | dsdx2 = reals/s1_j2
240 | dsdy2 = imags/s1_j2
241 | s1_j2 = s1_j2 - bias
242 | s0 = F.avg_pool2d(s0, 2)
243 |
244 | # Second order scattering
245 | s1_j1 = s1_j1[:, :, 0]
246 | s1_j1, reals, imags = fwd_j1(s1_j1, h0o, h1o, False, 1, mode)
247 | s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2)
248 | if x.requires_grad:
249 | dsdx2_1 = reals/s2_j1
250 | dsdy2_1 = imags/s2_j1
251 | q = s2_j1.shape
252 | s2_j1 = s2_j1.view(q[0], 36, q[3], q[4])
253 | s2_j1 = s2_j1 - bias
254 | s1_j1 = F.avg_pool2d(s1_j1, 2)
255 | if x.requires_grad:
256 | ctx.save_for_backward(h0o, h1o, h0a, h0b, h1a, h1b,
257 | dsdx1, dsdy1, dsdx2, dsdy2,
258 | dsdx2_1, dsdy2_1)
259 | else:
260 | z = x.new_zeros(1)
261 | ctx.save_for_backward(h0o, h1o, h0a, h0b, h1a, h1b,
262 | z, z, z, z, z, z)
263 |
264 | del reals, imags
265 | Z = torch.cat((s0, s1_j1, s1_j2[:,:,0], s2_j1), dim=1)
266 |
267 | else:
268 | s1_j1 = torch.sqrt(reals**2 + imags**2 + bias**2)
269 | if x.requires_grad:
270 | dsdx1 = reals/s1_j1
271 | dsdy1 = imags/s1_j1
272 | s1_j1 = s1_j1 - bias
273 |
274 | s0, reals, imags = fwd_j2plus(s0, h0a, h1a, h0b, h1b, False, 1, mode)
275 | s1_j2 = torch.sqrt(reals**2 + imags**2 + bias**2)
276 | if x.requires_grad:
277 | dsdx2 = reals/s1_j2
278 | dsdy2 = imags/s1_j2
279 | s1_j2 = s1_j2 - bias
280 | s0 = F.avg_pool2d(s0, 2)
281 |
282 | # Second order scattering
283 | p = s1_j1.shape
284 | s1_j1 = s1_j1.view(p[0], 6*p[2], p[3], p[4])
285 |
286 | s1_j1, reals, imags = fwd_j1(s1_j1, h0o, h1o, False, 1, mode)
287 | s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2)
288 | if x.requires_grad:
289 | dsdx2_1 = reals/s2_j1
290 | dsdy2_1 = imags/s2_j1
291 | q = s2_j1.shape
292 | s2_j1 = s2_j1.view(q[0], 36, q[2]//6, q[3], q[4])
293 | s2_j1 = s2_j1 - bias
294 | s1_j1 = F.avg_pool2d(s1_j1, 2)
295 | s1_j1 = s1_j1.view(p[0], 6, p[2], p[3]//2, p[4]//2)
296 |
297 | if x.requires_grad:
298 | ctx.save_for_backward(h0o, h1o, h0a, h0b, h1a, h1b,
299 | dsdx1, dsdy1, dsdx2, dsdy2,
300 | dsdx2_1, dsdy2_1)
301 | else:
302 | z = x.new_zeros(1)
303 | ctx.save_for_backward(h0o, h1o, h0a, h0b, h1a, h1b,
304 | z, z, z, z, z, z)
305 |
306 | del reals, imags
307 | Z = torch.cat((s0[:, None], s1_j1, s1_j2, s2_j1), dim=1)
308 |
309 | return Z
310 |
311 | @staticmethod
312 | def backward(ctx, dZ):
313 | dX = None
314 | mode = ctx.mode
315 |
316 | if ctx.needs_input_grad[0]:
317 | # Input has shape N, L, C, H, W
318 | o_dim = 1
319 | h_dim = 3
320 | w_dim = 4
321 |
322 | # Retrieve phase info
323 | (h0o, h1o, h0a, h0b, h1a, h1b, dsdx1, dsdy1, dsdx2, dsdy2, dsdx2_1,
324 | dsdy2_1) = ctx.saved_tensors
325 |
326 | # Use the special properties of the filters to get the time reverse
327 | h0o_t = h0o
328 | h1o_t = h1o
329 | h0a_t = h0b
330 | h0b_t = h0a
331 | h1a_t = h1b
332 | h1b_t = h1a
333 |
334 | # Level 1 backward (time reversed biorthogonal analysis filters)
335 | if ctx.combine_colour:
336 | ds0, ds1_j1, ds1_j2, ds2_j1 = \
337 | dZ[:,:3], dZ[:,3:9], dZ[:,9:15], dZ[:,15:]
338 | ds1_j2 = ds1_j2[:, :, None]
339 |
340 | ds1_j1 = 1/4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest")
341 | q = ds2_j1.shape
342 | ds2_j1 = ds2_j1.view(q[0], 6, 6, q[2], q[3])
343 |
344 | # Inverse second order scattering
345 | reals = ds2_j1 * dsdx2_1
346 | imags = ds2_j1 * dsdy2_1
347 | ds1_j1 = inv_j1(
348 | ds1_j1, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode)
349 | ds1_j1 = ds1_j1[:, :, None]
350 |
351 | # Inverse first order scattering j=2
352 | ds0 = 1/4 * F.interpolate(ds0, scale_factor=2, mode="nearest")
355 | reals = ds1_j2 * dsdx2
356 | imags = ds1_j2 * dsdy2
357 | ds0 = inv_j2plus(
358 | ds0, reals, imags, h0a_t, h1a_t, h0b_t, h1b_t,
359 | o_dim, h_dim, w_dim, mode)
360 |
361 | # Inverse first order scattering j=1
362 | reals = ds1_j1 * dsdx1
363 | imags = ds1_j1 * dsdy1
364 | dX = inv_j1(
365 | ds0, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode)
366 | else:
367 | ds0, ds1_j1, ds1_j2, ds2_j1 = \
368 | dZ[:,0], dZ[:,1:7], dZ[:,7:13], dZ[:,13:]
369 | p = ds1_j1.shape
370 | ds1_j1 = ds1_j1.view(p[0], p[2]*6, p[3], p[4])
371 | ds1_j1 = 1/4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest")
372 | q = ds2_j1.shape
373 | ds2_j1 = ds2_j1.view(q[0], 6, q[2]*6, q[3], q[4])
374 |
375 | # Inverse second order scattering
376 | reals = ds2_j1 * dsdx2_1
377 | imags = ds2_j1 * dsdy2_1
378 | ds1_j1 = inv_j1(
379 | ds1_j1, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode)
380 | ds1_j1 = ds1_j1.view(p[0], 6, p[2], p[3]*2, p[4]*2)
381 |
382 | # Inverse first order scattering j=2
383 | ds0 = 1/4 * F.interpolate(ds0, scale_factor=2, mode="nearest")
386 | reals = ds1_j2 * dsdx2
387 | imags = ds1_j2 * dsdy2
388 | ds0 = inv_j2plus(
389 | ds0, reals, imags, h0a_t, h1a_t, h0b_t, h1b_t,
390 | o_dim, h_dim, w_dim, mode)
391 |
392 | # Inverse first order scattering j=1
393 | reals = ds1_j1 * dsdx1
394 | imags = ds1_j1 * dsdy1
395 | dX = inv_j1(
396 | ds0, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode)
397 |
398 | return (dX,) + (None,) * 9
399 |
400 |
401 | class ScatLayerj2_rot_f(torch.autograd.Function):
402 | """ Function to do forward and backward passes of a single scattering
403 | layer with the DTCWT bandpass biorthogonal and qshift filters . """
404 |
405 | @staticmethod
406 | def forward(ctx, x, h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, mode, bias, combine_colour):
409 | ctx.in_shape = x.shape
410 | batch, ch, r, c = x.shape
411 | assert r % 8 == c % 8 == 0
412 | mode = int_to_mode(mode)
413 | ctx.mode = mode
414 | ctx.combine_colour = combine_colour
415 |
416 | # First order scattering
417 | s0, reals, imags = fwd_j1_rot(x, h0o, h1o, h2o, False, 1, mode)
418 | if combine_colour:
419 | s1_j1 = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 +
420 | reals[:,:,1]**2 + imags[:,:,1]**2 +
421 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2)
422 | s1_j1 = s1_j1[:, :, None]
423 | if x.requires_grad:
424 | dsdx1 = reals/s1_j1
425 | dsdy1 = imags/s1_j1
426 | s1_j1 = s1_j1 - bias
427 |
428 | s0, reals, imags = fwd_j2plus_rot(s0, h0a, h1a, h0b, h1b, h2a, h2b, False, 1, mode)
429 | s1_j2 = torch.sqrt(reals[:,:,0]**2 + imags[:,:,0]**2 +
430 | reals[:,:,1]**2 + imags[:,:,1]**2 +
431 | reals[:,:,2]**2 + imags[:,:,2]**2 + bias**2)
432 | s1_j2 = s1_j2[:, :, None]
433 | if x.requires_grad:
434 | dsdx2 = reals/s1_j2
435 | dsdy2 = imags/s1_j2
436 | s1_j2 = s1_j2 - bias
437 | s0 = F.avg_pool2d(s0, 2)
438 |
439 | # Second order scattering
440 | s1_j1 = s1_j1[:, :, 0]
441 | s1_j1, reals, imags = fwd_j1_rot(s1_j1, h0o, h1o, h2o, False, 1, mode)
442 | s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2)
443 | if x.requires_grad:
444 | dsdx2_1 = reals/s2_j1
445 | dsdy2_1 = imags/s2_j1
446 | q = s2_j1.shape
447 | s2_j1 = s2_j1.view(q[0], 36, q[3], q[4])
448 | s2_j1 = s2_j1 - bias
449 | s1_j1 = F.avg_pool2d(s1_j1, 2)
450 | if x.requires_grad:
451 | ctx.save_for_backward(h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b,
452 | dsdx1, dsdy1, dsdx2, dsdy2, dsdx2_1,
453 | dsdy2_1)
454 | else:
455 | z = x.new_zeros(1)
456 | ctx.save_for_backward(h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b,
457 | z, z, z, z, z, z)
458 |
459 | del reals, imags
460 | Z = torch.cat((s0, s1_j1, s1_j2[:, :, 0], s2_j1), dim=1)
461 | else:
462 | s1_j1 = torch.sqrt(reals**2 + imags**2 + bias**2)
463 | if x.requires_grad:
464 | dsdx1 = reals/s1_j1
465 | dsdy1 = imags/s1_j1
466 | s1_j1 = s1_j1 - bias
467 |
468 | s0, reals, imags = fwd_j2plus_rot(s0, h0a, h1a, h0b, h1b, h2a, h2b, False, 1, mode)
469 | s1_j2 = torch.sqrt(reals**2 + imags**2 + bias**2)
470 | if x.requires_grad:
471 | dsdx2 = reals/s1_j2
472 | dsdy2 = imags/s1_j2
473 | s1_j2 = s1_j2 - bias
474 | s0 = F.avg_pool2d(s0, 2)
475 |
476 | # Second order scattering
477 | p = s1_j1.shape
478 | s1_j1 = s1_j1.view(p[0], 6*p[2], p[3], p[4])
479 | s1_j1, reals, imags = fwd_j1_rot(s1_j1, h0o, h1o, h2o, False, 1, mode)
480 | s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2)
481 | if x.requires_grad:
482 | dsdx2_1 = reals/s2_j1
483 | dsdy2_1 = imags/s2_j1
484 | q = s2_j1.shape
485 | s2_j1 = s2_j1.view(q[0], 36, q[2]//6, q[3], q[4])
486 | s2_j1 = s2_j1 - bias
487 | s1_j1 = F.avg_pool2d(s1_j1, 2)
488 | s1_j1 = s1_j1.view(p[0], 6, p[2], p[3]//2, p[4]//2)
489 |
490 | if x.requires_grad:
491 | ctx.save_for_backward(h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b,
492 | dsdx1, dsdy1, dsdx2, dsdy2, dsdx2_1,
493 | dsdy2_1)
494 | else:
495 | z = x.new_zeros(1)
496 | ctx.save_for_backward(h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b,
497 | z, z, z, z, z, z)
498 |
499 | del reals, imags
500 | Z = torch.cat((s0[:, None], s1_j1, s1_j2, s2_j1), dim=1)
501 |
502 | return Z
503 |
504 | @staticmethod
505 | def backward(ctx, dZ):
506 | dX = None
507 | mode = ctx.mode
508 |
509 | if ctx.needs_input_grad[0]:
510 | # Input has shape N, L, C, H, W
511 | o_dim = 1
512 | h_dim = 3
513 | w_dim = 4
514 |
515 | # Retrieve phase info
516 | (h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, dsdx1, dsdy1, dsdx2,
517 | dsdy2, dsdx2_1, dsdy2_1) = ctx.saved_tensors
518 |
519 | # Use the special properties of the filters to get the time reverse
520 | h0o_t = h0o
521 | h1o_t = h1o
522 | h2o_t = h2o
523 | h0a_t = h0b
524 | h0b_t = h0a
525 | h1a_t = h1b
526 | h1b_t = h1a
527 | h2a_t = h2b
528 | h2b_t = h2a
529 |
530 | # Level 1 backward (time reversed biorthogonal analysis filters)
531 | if ctx.combine_colour:
532 | ds0, ds1_j1, ds1_j2, ds2_j1 = \
533 | dZ[:,:3], dZ[:,3:9], dZ[:,9:15], dZ[:,15:]
534 | ds1_j2 = ds1_j2[:, :, None]
535 |
536 | # Inverse second order scattering
537 | ds1_j1 = 1/4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest")
538 | q = ds2_j1.shape
539 | ds2_j1 = ds2_j1.view(q[0], 6, 6, q[2], q[3])
540 |
541 | # Inverse second order scattering
542 | reals = ds2_j1 * dsdx2_1
543 | imags = ds2_j1 * dsdy2_1
544 | ds1_j1 = inv_j1_rot(
545 | ds1_j1, reals, imags, h0o_t, h1o_t, h2o_t,
546 | o_dim, h_dim, w_dim, mode)
547 | ds1_j1 = ds1_j1[:, :, None]
548 |
549 | # Inverse first order scattering j=2
550 | ds0 = 1/4 * F.interpolate(ds0, scale_factor=2, mode="nearest")
553 | reals = ds1_j2 * dsdx2
554 | imags = ds1_j2 * dsdy2
555 | ds0 = inv_j2plus_rot(
556 | ds0, reals, imags, h0a_t, h1a_t, h0b_t, h1b_t, h2a_t, h2b_t,
557 | o_dim, h_dim, w_dim, mode)
558 |
559 | # Inverse first order scattering j=1
560 | reals = ds1_j1 * dsdx1
561 | imags = ds1_j1 * dsdy1
562 | dX = inv_j1_rot(
563 | ds0, reals, imags, h0o_t, h1o_t, h2o_t,
564 | o_dim, h_dim, w_dim, mode)
565 | else:
566 | ds0, ds1_j1, ds1_j2, ds2_j1 = \
567 | dZ[:,0], dZ[:,1:7], dZ[:,7:13], dZ[:,13:]
568 |
569 | # Inverse second order scattering
570 | p = ds1_j1.shape
571 | ds1_j1 = ds1_j1.view(p[0], p[2]*6, p[3], p[4])
572 | ds1_j1 = 1/4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest")
573 | q = ds2_j1.shape
574 | ds2_j1 = ds2_j1.view(q[0], 6, q[2]*6, q[3], q[4])
575 | reals = ds2_j1 * dsdx2_1
576 | imags = ds2_j1 * dsdy2_1
577 | ds1_j1 = inv_j1_rot(
578 | ds1_j1, reals, imags, h0o_t, h1o_t, h2o_t,
579 | o_dim, h_dim, w_dim, mode)
580 | ds1_j1 = ds1_j1.view(p[0], 6, p[2], p[3]*2, p[4]*2)
581 |
582 | # Inverse first order scattering j=2
583 | ds0 = 1/4 * F.interpolate(ds0, scale_factor=2, mode="nearest")
586 | reals = ds1_j2 * dsdx2
587 | imags = ds1_j2 * dsdy2
588 | ds0 = inv_j2plus_rot(
589 | ds0, reals, imags, h0a_t, h1a_t, h0b_t, h1b_t, h2a_t, h2b_t,
590 | o_dim, h_dim, w_dim, mode)
591 |
592 | # Inverse first order scattering j=1
593 | reals = ds1_j1 * dsdx1
594 | imags = ds1_j1 * dsdy1
595 | dX = inv_j1_rot(
596 | ds0, reals, imags, h0o_t, h1o_t, h2o_t,
597 | o_dim, h_dim, w_dim, mode)
598 |
599 | return (dX,) + (None,) * 12
600 |
--------------------------------------------------------------------------------
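
SmoothMagFn saves its own exact derivatives, so it passes a finite-difference
gradient check; a quick sketch:

import torch
from pytorch_wavelets.scatternet.lowlevel import SmoothMagFn

x = torch.randn(6, dtype=torch.double, requires_grad=True)
y = torch.randn(6, dtype=torch.double, requires_grad=True)
b = 1e-2   # magnitude bias

# Checks the analytic backward (x/r, y/r) against numerical gradients
print(torch.autograd.gradcheck(lambda u, v: SmoothMagFn.apply(u, v, b), (x, y)))
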
/pytorch_wavelets/utils.py:
--------------------------------------------------------------------------------
1 | """ Useful utilities for testing the 2-D DTCWT with synthetic images"""
2 |
3 | from __future__ import absolute_import
4 |
5 | import functools
6 | import numpy as np
7 |
8 |
9 | def unpack(pyramid, backend='numpy'):
10 | """ Unpacks a pyramid give back the constituent parts.
11 |
12 | :param pyramid: The Pyramid of DTCWT transforms you wish to unpack
13 | :param str backend: A string from 'numpy', 'opencl', or 'tf' indicating
14 | which attributes you want to unpack from the pyramid.
15 |
16 | :returns: returns a generator which can be unpacked into the Yl, Yh and
17 | Yscale components of the pyramid. The generator will only return 2
18 | values if the pyramid was created with the include_scale parameter set
19 | to false.
20 |
21 | .. note::
22 |
23 | You can still unpack a tf or opencl pyramid as if it were created by
24 | the numpy backend. In this case it will return a numpy array, rather
25 | than the backend specific array type.
26 | """
27 | backend = backend.lower()
28 | if backend == 'numpy':
29 | yield pyramid.lowpass
30 | yield pyramid.highpasses
31 | if pyramid.scales is not None:
32 | yield pyramid.scales
33 | elif backend == 'opencl':
34 | yield pyramid.cl_lowpass
35 | yield pyramid.cl_highpasses
36 | if pyramid.cl_scales is not None:
37 | yield pyramid.cl_scales
38 | elif backend == 'tf':
39 | yield pyramid.lowpass_op
40 | yield pyramid.highpasses_ops
41 | if pyramid.scales_ops is not None:
42 | yield pyramid.scales_ops
43 |
44 |
45 | def drawedge(theta,r,w,N):
46 | """Generate an image of size N * N pels, of an edge going from 0 to 1 in
47 | height at theta degrees to the horizontal (top of image = 1 if angle = 0).
48 | r is a two-element vector giving a coordinate, in ij coords, through
49 | which the step should pass.
50 | The shape of the intensity step is half a raised cosine w pels wide (w>=1).
51 |
52 | T. E. Gale's enhancement to drawedge() for MATLAB, transliterated
53 | to Python by S. C. Forshaw, Nov. 2013. """
54 |
55 | # convert theta from degrees to radians
56 | thetar = np.array(theta * np.pi / 180)
57 |
58 | # Calculate image centre from given width
59 | imCentre = (np.array([N,N]).T - 1) / 2 + 1
60 |
61 | # Calculate values to subtract from the plane
62 | r = np.array([np.cos(thetar), np.sin(thetar)])*(-1) * (r - imCentre)
63 |
64 | # check width of raised cosine section
65 | w = np.maximum(1,w)
66 |
67 | ramp = np.arange(0,N) - (N+1)/2
68 | hgrad = np.sin(thetar)*(-1) * np.ones([N,1])
69 | vgrad = np.cos(thetar)*(-1) * np.ones([1,N])
70 | plane = ((hgrad * ramp) - r[0]) + ((ramp * vgrad).T - r[1])
71 | x = 0.5 + 0.5 * np.sin(np.minimum(np.maximum(
72 | plane*(np.pi/w), np.pi/(-2)), np.pi/2))
73 |
74 | return x
75 |
76 |
77 | def drawcirc(r,w,du,dv,N):
78 |
79 | """Generate an image of size N*N pels, containing a circle
80 | radius r pels and centred at du,dv relative
81 | to the centre of the image. The edge of the circle is a cosine shaped
82 | edge of width w (from 10 to 90% points).
83 |
84 | Python implementation by S. C. Forshaw, November 2013."""
85 |
86 | # check value of w to avoid dividing by zero
87 | w = np.maximum(w,1)
88 |
89 | # x plane
90 | x = np.ones([N,1]) * ((np.arange(0,N,1, dtype='float') -
91 | (N+1) / 2 - dv) / r)
92 |
93 | # y vector
94 | y = (((np.arange(0,N,1, dtype='float') - (N+1) / 2 - du) / r) *
95 | np.ones([1,N])).T
96 |
97 | # Final circle image plane
98 | p = 0.5 + 0.5 * np.sin(np.minimum(np.maximum((
99 | np.exp(np.array([-0.5]) * (x**2 + y**2)).T - np.exp((-0.5))) * (r * 3 / w), # noqa
100 | np.pi/(-2)), np.pi/2))
101 | return p
102 |
103 |
104 | def asfarray(X):
105 | """Similar to :py:func:`numpy.asfarray` except that this function tries to
106 | preserve the original datatype of X if it is already a floating point type
107 | and will pass floating point arrays through directly without copying.
108 |
109 | """
110 | X = np.asanyarray(X)
111 | return np.asfarray(X, dtype=X.dtype)
112 |
113 |
114 | def appropriate_complex_type_for(X):
115 | """Return an appropriate complex data type depending on the type of X. If X
116 | is already complex, return that, if it is floating point return a complex
117 | type of the appropriate size and if it is integer, choose an complex
118 | floating point type depending on the result of :py:func:`numpy.asfarray`.
119 |
120 | """
121 | X = asfarray(X)
122 |
123 | if np.issubsctype(X.dtype, np.complex64) or \
124 | np.issubsctype(X.dtype, np.complex128):
125 | return X.dtype
126 | elif np.issubsctype(X.dtype, np.float32):
127 | return np.complex64
128 | elif np.issubsctype(X.dtype, np.float64):
129 | return np.complex128
130 |
131 | # God knows, err on the side of caution
132 | return np.complex128
133 |
134 |
135 | def as_column_vector(v):
136 | """Return *v* as a column vector with shape (N,1).
137 |
138 | """
139 | v = np.atleast_2d(v)
140 | if v.shape[0] == 1:
141 | return v.T
142 | else:
143 | return v
144 |
145 |
146 | def reflect(x, minx, maxx):
147 | """Reflect the values in matrix *x* about the scalar values *minx* and
148 | *maxx*. Hence a vector *x* containing a long linearly increasing series is
149 | converted into a waveform which ramps linearly up and down between *minx*
150 | and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers +
151 | 0.5), the ramps will have repeated max and min samples.
152 |
153 | .. codeauthor:: Rich Wareham , Aug 2013
154 | .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
155 |
156 | """
157 | x = np.asanyarray(x)
158 | rng = maxx - minx
159 | rng_by_2 = 2 * rng
160 | mod = np.fmod(x - minx, rng_by_2)
161 | normed_mod = np.where(mod < 0, mod + rng_by_2, mod)
162 | out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
163 | return np.array(out, dtype=x.dtype)
164 |
165 |
166 | def symm_pad_1d(l, m):
167 | """ Creates indices for symmetric padding. Works for 1-D.
168 |
169 | Inputs:
170 | l (int): size of input
171 | m (int): size of filter
172 | """
173 | xe = reflect(np.arange(-m, l+m, dtype='int32'), -0.5, l-0.5)
174 | return xe
175 |
176 |
177 | # note: the cache key uses positional args only, so **kwargs are ignored
178 | # From https://wiki.python.org/moin/PythonDecoratorLibrary#Alternate_memoize_as_nested_functions # noqa
179 | def memoize(obj):
180 | cache = obj.cache = {}
181 |
182 | @functools.wraps(obj)
183 | def memoizer(*args, **kwargs):
184 | if args not in cache:
185 | cache[args] = obj(*args, **kwargs)
186 | return cache[args]
187 | return memoizer
188 |
189 |
190 | def stacked_2d_matrix_vector_prod(mats, vecs):
191 | """
192 | Interpret *mats* and *vecs* as arrays of 2D matrices and vectors. I.e.
193 | *mats* has shape PxQxNxM and *vecs* has shape PxQxM. The result
194 | is a PxQxN array equivalent to:
195 |
196 | .. code::
197 |
198 | result[i,j,:] = mats[i,j,:,:].dot(vecs[i,j,:])
199 |
200 | for all valid row and column indices *i* and *j*.
201 | """
202 | return np.einsum('...ij,...j->...i', mats, vecs)
203 |
204 |
205 | def stacked_2d_vector_matrix_prod(vecs, mats):
206 | """
207 | Interpret *mats* and *vecs* as arrays of 2D matrices and vectors. I.e.
208 | *mats* has shape PxQxNxM and *vecs* has shape PxQxN. The result
209 | is a PxQxM array equivalent to:
210 |
211 | .. code::
212 |
213 | result[i,j,:] = mats[i,j,:,:].T.dot(vecs[i,j,:])
214 |
215 | for all valid row and column indices *i* and *j*.
216 | """
217 | vecshape = np.array(vecs.shape + (1,))
218 | vecshape[-1:-3:-1] = vecshape[-2:]
219 | outshape = mats.shape[:-2] + (mats.shape[-1],)
220 | return stacked_2d_matrix_matrix_prod(vecs.reshape(vecshape), mats).reshape(outshape) # noqa
221 |
222 |
223 | def stacked_2d_matrix_matrix_prod(mats1, mats2):
224 | """
225 | Interpret *mats1* and *mats2* as arrays of 2D matrices. I.e.
226 | *mats1* has shape PxQxNxM and *mats2* has shape PxQxMxR. The result
227 | is a PxQxNxR array equivalent to:
228 |
229 | .. code::
230 |
231 | result[i,j,:,:] = mats1[i,j,:,:].dot(mats2[i,j,:,:])
232 |
233 | for all valid row and column indices *i* and *j*.
234 | """
235 | return np.einsum('...ij,...jk->...ik', mats1, mats2)
236 |
--------------------------------------------------------------------------------
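
Note: the helpers above are easiest to see by example. The following is a minimal
sketch (not part of the repository; it assumes the vendored pytorch_wavelets
package is importable) showing reflect()-based symmetric extension, the einsum
contraction behind the stacked_2d_* products, and drawcirc() as a test-image
generator:

    import numpy as np
    from pytorch_wavelets.utils import reflect, symm_pad_1d, drawcirc

    # reflect() folds a linear ramp back between minx and maxx; with
    # half-integer bounds the min/max samples repeat (symmetric padding)
    x = np.arange(-3, 8)
    print(reflect(x, -0.5, 4.5))  # [2 1 0 0 1 2 3 4 4 3 2]

    # symm_pad_1d(l, m) builds gather indices that symmetrically extend
    # a length-l signal by m samples on each side
    sig = np.array([10, 20, 30, 40, 50])
    print(sig[symm_pad_1d(5, 2)])  # [20 10 10 20 30 40 50 50 40]

    # the stacked_2d_* products are batched einsum contractions
    mats = np.random.randn(3, 4, 2, 5)  # P x Q x N x M
    vecs = np.random.randn(3, 4, 5)     # P x Q x M
    out = np.einsum('...ij,...j->...i', mats, vecs)
    assert out.shape == (3, 4, 2)       # P x Q x N

    # drawcirc() makes a soft-edged disc, handy as a test image
    img = drawcirc(r=20, w=4, du=0, dv=0, N=64)
    print(img.shape, img.min(), img.max())  # (64, 64), ~0.0, ~1.0
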
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import math
6 | import shutil
7 | import torch
8 | from model import WTFTP, WTFTP_AttnRemoved
9 | from torch.utils.data import DataLoader
10 | from dataloader import DataGenerator
11 | from utils import progress_bar, wt_coef_len, recorder
12 | import argparse
13 | import logging
14 | import datetime
15 | from tensorboardX import SummaryWriter
16 | from pytorch_wavelets import DWT1DForward, DWT1DInverse
17 |
18 |
19 | class Train:
20 | def __init__(self, opt, SEED=None, tracking=False):
21 | if SEED is not None:
22 | torch.manual_seed(SEED)
23 | if tracking:
24 | self.dev_recorder = recorder('dev_loss')
25 | self.train_recorder = recorder('train_loss', 'temporal_loss', 'freq_loss')
26 | self.SEED = SEED
27 | self.opt = opt
28 | self.iscuda = torch.cuda.is_available()
29 | self.device = torch.device(f'cuda:{self.opt.cuda}' if self.iscuda and not opt.cpu else 'cpu')
30 | self.data_set = DataGenerator(data_path=self.opt.datadir,
31 | minibatch_len=opt.minibatch_len, interval=opt.interval,
32 | use_preset_data_ranges=False)
33 |
34 | if self.opt.attn:
35 | self.net = WTFTP(
36 | n_inp=6,
37 | n_oup=6,
38 | his_step=opt.minibatch_len - 1,
39 | n_embding=opt.embding,
40 | en_layers=opt.enlayer,
41 | de_layers=opt.delayer,
42 | activation='relu',
43 | proj='linear',
44 | maxlevel=opt.maxlevel,
45 | en_dropout=opt.dpot,
46 | de_dropout=opt.dpot,
47 | bias=True
48 | )
49 | else:
50 | self.net = WTFTP_AttnRemoved(n_inp=6,
51 | n_oup=6,
52 | n_embding=opt.embding,
53 | n_encoderLayers=opt.enlayer,
54 | n_decoderLayers=opt.delayer,
55 | activation='relu',
56 | proj='linear',
57 | maxlevel=opt.maxlevel,
58 | en_dropout=opt.dpot,
59 | de_dropout=opt.dpot)
60 | self.MSE = torch.nn.MSELoss(reduction='mean')
61 | self.MAE = torch.nn.L1Loss(reduction='mean')
62 | self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.opt.lr)
63 | self.opt_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.5)
64 | if not self.opt.nologging:
65 | self.log_path = self.opt.logdir + f'/{datetime.datetime.now().strftime("%y-%m-%d")}'
66 | if self.opt.debug:
67 | self.log_path = self.opt.logdir + f'/DEBUG-{datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")}{"-" + self.opt.comments if len(self.opt.comments) > 0 else ""}'
68 | self.opt.saving_model_num = 10
69 | self.TX_log_path = self.log_path + '/Tensorboard'
70 | if not os.path.exists(self.log_path):
71 | os.makedirs(self.log_path)
72 | if os.path.exists(self.TX_log_path):
73 | shutil.rmtree(self.TX_log_path)
74 | os.mkdir(self.TX_log_path)
75 | else:
76 | os.mkdir(self.TX_log_path)
77 | logging.basicConfig(filename=os.path.join(self.log_path, 'train.log'),
78 | filemode='a', format='%(asctime)s %(message)s', level=logging.DEBUG)
79 | self.TX_logger = SummaryWriter(log_dir=self.TX_log_path)
80 |
81 | if self.opt.saving_model_num > 0:
82 | self.model_names = []
83 |
84 | def train(self):
85 | self.net.to(self.device)
86 | self.net.args['train_opt'] = self.opt
87 | if not self.opt.nologging:
88 | log_str = f'TRAIN DETAILS'.center(50, '-') + \
89 | f'\ntraining on device {self.device}...\n'
90 | for arg in vars(self.opt):
91 | log_str += f'{arg}: {getattr(self.opt, arg)}\n'
92 | log_str += f'MODEL DETAILS'.center(50, '-') + '\n'
93 | for key in self.net.args:
94 | log_str += f'{key}: {self.net.args[key]}\n'
95 | logging.debug(10 * '\n' + f'Beginning of training'.center(50, '-'))
96 | logging.debug(log_str)
97 | print(log_str)
98 |
99 | for i in range(self.opt.epoch):
100 | print('\n' + f'Epoch {i+1}/{self.opt.epoch}'.center(50, '-'))
101 | if not self.opt.nologging:
102 | logging.debug('\n' + f'Epoch {i+1}/{self.opt.epoch}'.center(50, '-'))
103 | print(f'lr: {float(self.opt_lr_scheduler.get_last_lr()[0])}')
104 | self.train_each_epoch(i+1)
105 | self.saving_model(self.log_path, f'epoch_{i+1}.pt')
106 | self.opt_lr_scheduler.step()
107 | self.dev_each_epoch(i+1)
108 | if not self.opt.nologging:
109 | logging.debug(10 * '\n' + f'End of training'.center(50, '-'))
110 | if hasattr(self, 'dev_recorder'):
111 | self.dev_recorder.push(self.log_path, 'dev_recorder.pt')
112 | if hasattr(self, 'train_recorder'):
113 | self.train_recorder.push(self.log_path, 'train_recorder.pt')
114 |
115 | def train_each_epoch(self, epoch):
116 | train_data = DataLoader(dataset=self.data_set.train_set, batch_size=self.opt.batch_size, shuffle=True,
117 | collate_fn=self.data_set.collate, num_workers=0, pin_memory=self.iscuda)
118 | self.net.train()
119 | dwt = DWT1DForward(wave=self.opt.wavelet, J=self.opt.maxlevel, mode=self.opt.wt_mode).to(self.device)
120 | idwt = DWT1DInverse(wave=self.opt.wavelet, mode=self.opt.wt_mode).to(self.device)
121 | start_time = time.perf_counter()
122 | batchs_len = len(train_data)
123 | train_loss_set = []
124 | temporal_loss_set = []
125 | freq_loss_set = []
126 | for i, batch in enumerate(train_data):
127 | batch = torch.FloatTensor(batch).to(self.device)
128 | inp_batch = batch[:, :-1, :] # shape: batch * n_sequence * n_attr
129 | train_wt_loss = torch.tensor(0.0).to(self.device)
130 | tgt_batch = batch # use the full sequence to compute wavelet coefficients, shape: batch * n_sequence * n_attr
131 | wt_tgt_batch = dwt(
132 | batch.transpose(1, 2).contiguous()) # tuple(lo, hi), shape: batch * n_attr * n_sequence
133 | if self.opt.attn:
134 | wt_pre_batch, score_set = self.net(inp_batch)
135 | else:
136 | wt_pre_batch = self.net(inp_batch)
137 | pre_batch = idwt((wt_pre_batch[-1].transpose(1, 2).contiguous(),
138 | [comp.transpose(1, 2).contiguous() for comp in
139 | wt_pre_batch[:-1]])).contiguous() # shape: batch * n_attr * n_sequence
140 | pre_batch = pre_batch.transpose(1, 2) # shape: batch * n_sequence * n_attr
141 | train_temporal_loss, loss_details = self.cal_spatial_loss(tgt_batch, pre_batch)
142 | train_wt_loss = self.cal_freq_loss(wt_tgt_batch, wt_pre_batch)
143 | loss_backward = torch.tensor(self.opt.w_spatial, dtype=torch.float,
144 | device=self.device) * train_temporal_loss + train_wt_loss
145 | self.optimizer.zero_grad()
146 | loss_backward.backward()
147 | self.optimizer.step()
148 | with torch.no_grad():
149 | train_loss_set.append(float(loss_backward.detach().cpu()))
150 | temporal_loss_set.append(float(train_temporal_loss.detach().cpu()))
151 | freq_loss_set.append(float(train_wt_loss.detach().cpu()))
152 | # print loss
153 | print_str = f'train_loss(scaled): {loss_backward.item():.8f}, temporal_loss: {train_temporal_loss.item():.8f}, ' \
154 | f'freq_loss: {train_wt_loss.item():.8f}'
155 | progress_bar(i, batchs_len, print_str, start_time)
156 | record_freq = 20  # at most this many in-epoch log records
157 | if not self.opt.nologging and (i % max(1, (batchs_len - 1) // record_freq) == 0 or i == batchs_len - 1):
158 | logging.debug(f'{i}/{batchs_len - 1} ' + print_str)
159 | print_str = f'ave_train_loss: {np.mean(train_loss_set):.8f}, ave_temporal_loss: {np.mean(temporal_loss_set):.8f}' \
160 | + f', ave_freq_loss: {np.mean(freq_loss_set):.8f}'
161 | print(print_str)
162 | if not self.opt.nologging:
163 | logging.debug(print_str)
164 | self.TX_logger.add_scalar('train_loss', np.mean(train_loss_set), global_step=epoch)
165 | self.TX_logger.add_scalar('train_temporal_loss', np.mean(temporal_loss_set), global_step=epoch)
166 | self.TX_logger.add_scalar('train_freq_loss', np.mean(freq_loss_set), global_step=epoch)
167 | if hasattr(self, 'train_recorder'):
168 | self.train_recorder.add('train_loss', np.mean(train_loss_set))
169 | self.train_recorder.add('temporal_loss', np.mean(temporal_loss_set))
170 | self.train_recorder.add('freq_loss', np.mean(freq_loss_set))
171 | if not self.opt.nologging and self.opt.attn:
172 | project = plt.cm.get_cmap('YlGnBu')
173 | with torch.no_grad():
174 | self.TX_logger.add_image(f'score',
175 | project(torch.mean(score_set.clone().cpu(), dim=0).numpy())[:, :, :-1],
176 | dataformats='HWC',
177 | global_step=epoch)
178 |
179 | def dev_each_epoch(self, epoch):
180 | dev_data = DataLoader(dataset=self.data_set.dev_set, batch_size=self.opt.batch_size, shuffle=False,
181 | collate_fn=self.data_set.collate, num_workers=0, pin_memory=self.iscuda)
182 | self.net.eval()
183 | idwt = DWT1DInverse(wave=self.opt.wavelet, mode=self.opt.wt_mode).to(self.device)
184 | tgt_set = []
185 | pre_set = []
186 | wt_coef_length = wt_coef_len(self.opt.minibatch_len, wavelet=self.opt.wavelet, mode=self.opt.wt_mode,
187 | maxlevel=self.opt.maxlevel)
188 | with torch.no_grad():
189 | for i, batch in enumerate(dev_data):
190 | batch = torch.FloatTensor(batch).to(self.device)
191 | n_batch, _, n_attr = batch.shape
192 | inp_batch = batch[:, :-1, :] # shape: batch * n_sequence * n_attr
193 | tgt_batch = batch
194 | if self.opt.attn:
195 | wt_pre_batch, score_set = self.net(inp_batch)
196 | else:
197 | wt_pre_batch = self.net(inp_batch)
198 | pre_batch = idwt((wt_pre_batch[-1].transpose(1, 2).contiguous(),
199 | [comp.transpose(1, 2).contiguous() for comp in
200 | wt_pre_batch[:-1]])).contiguous() # shape: batch * n_attr * n_sequence
201 | pre_batch = pre_batch.transpose(1, 2) # shape: batch * n_sequence * n_attr
202 | tgt_set.append(tgt_batch)
203 | pre_set.append(pre_batch)
204 | tgt_set = torch.cat(tgt_set, dim=0)
205 | pre_set = torch.cat(pre_set, dim=0)
206 | dev_loss, loss_details = self.cal_spatial_loss(tgt_set, pre_set, is_training=False)
207 | print_str = f'Evaluation-Stage:\n' \
208 | f'aveMSE(scaled): {dev_loss:.8f}, in each attr(RMSE, unscaled): {loss_details["rmse"]}\n' \
209 | f'aveMAE(scaled): n/a, in each attr(MAE, unscaled): {loss_details["mae"]}'
210 | print(print_str)
211 | if not self.opt.nologging:
212 | logging.debug(print_str)
213 | self.TX_logger.add_scalar('eval_aveMSE', dev_loss, global_step=epoch)
214 | if hasattr(self, 'dev_recorder'):
215 | self.dev_recorder.add('dev_loss', dev_loss)
216 |
217 | def cal_spatial_loss(self, tgt, pre, is_training=True):
218 | """
219 | :param tgt: shape: batch * n_sequence * n_attr
220 | :param pre: shape: batch * n_sequence+? * n_attr
221 | :return:
222 | """
223 | n_sequence = tgt.shape[1]
224 | last_node_weight = 1.0
225 | if is_training:
226 | weights = torch.ones(n_sequence, dtype=torch.float, device=tgt.device)
227 | else:
228 | weights = torch.zeros(n_sequence, dtype=torch.float, device=tgt.device)
229 | weights[-1] = last_node_weight
230 | weighted_loss = torch.tensor(0.0).to(self.device)
231 | for i in range(n_sequence):
232 | weighted_loss += weights[i] * self.MSE(pre[:, i, :], tgt[:, i, :])
233 | try:
234 | weighted_loss /= weights.count_nonzero()
235 | except AttributeError:  # older torch versions lack Tensor.count_nonzero
236 | if is_training:
237 | weighted_loss /= torch.tensor(n_sequence, dtype=torch.float).to(self.device)
238 | else:
239 | weighted_loss /= torch.tensor(1.0, dtype=torch.float).to(self.device)
240 | loss_unscaled = {'rmse': {}, 'mae': {}}
241 | tgt_cloned = tgt.detach()
242 | pre_cloned = pre.detach()
243 | for i, name in enumerate(self.data_set.attr_names):
244 | loss_unscaled['rmse'][name] = math.sqrt(
245 | float(self.MSE(self.data_set.unscale(tgt_cloned[:, n_sequence - 1, i], name),
246 | self.data_set.unscale(pre_cloned[:, n_sequence - 1, i], name))))
247 | loss_unscaled['mae'][name] = float(self.MAE(self.data_set.unscale(tgt_cloned[:, n_sequence - 1, i], name),
248 | self.data_set.unscale(pre_cloned[:, n_sequence - 1, i], name)))
249 | return weighted_loss, loss_unscaled
250 |
251 | def cal_freq_loss(self, tgt, pre) -> torch.Tensor:
252 | """
253 | :param tgt: tuple:(lo, hi) shape: batch * n_attr * n_sequence
254 | :param pre: list:[hi's lo] shape: batch * n_attr * n_sequence
255 | :return:
256 | """
257 | wt_loss = torch.tensor(0.0).to(self.device)
258 | wt_loss += self.opt.w_lo * self.MSE(tgt[0], pre[-1].transpose(1, 2))
259 | for i in range(self.opt.maxlevel):
260 | wt_loss += self.opt.w_hi * self.MSE(pre[i].transpose(1, 2), tgt[1][i])
261 | return wt_loss
262 |
263 | def saving_model(self, model_path, this_model_name):
264 | if self.opt.saving_model_num > 0:
265 | self.model_names.append(this_model_name)
266 | if len(self.model_names) > self.opt.saving_model_num:
267 | removed_model_name = self.model_names[0]
268 | del self.model_names[0]
269 | os.remove(os.path.join(model_path, removed_model_name)) # remove the oldest model
270 | torch.save(self.net, os.path.join(model_path, this_model_name)) # save the latest model
271 |
272 | if __name__ == '__main__':
273 | parser = argparse.ArgumentParser()
274 | parser.add_argument('--minibatch_len', default=10, type=int)
275 | parser.add_argument('--interval', default=1, type=int)
276 | parser.add_argument('--batch_size', default=2048, type=int)
277 | parser.add_argument('--epoch', default=150, type=int)
278 | parser.add_argument('--lr', default=0.001, type=float)
279 | parser.add_argument('--dpot', default=0.0, type=float)
280 | parser.add_argument('--cpu', action='store_true')
281 | parser.add_argument('--nologging', action='store_true')
282 | parser.add_argument('--logdir', default='./log', type=str)
283 | parser.add_argument('--datadir', default='./data', type=str)
284 | parser.add_argument('--bidirectional', action='store_true')
285 | parser.add_argument('--saving_model_num', default=0, type=int)
286 | parser.add_argument('--debug', action='store_true',
287 | help='save logs to a separate per-run dir and keep the 10 most recent checkpoints')
288 | parser.add_argument('--maxlevel', default=1, type=int)
289 | parser.add_argument('--wavelet', default='haar', type=str)
290 | parser.add_argument('--L2details', action='store_true',
291 | help='L2 regularization on detail coefficients to keep them from collapsing to zero')
292 | parser.add_argument('--wt_mode', default='symmetric', type=str)
293 | parser.add_argument('--w_spatial', default=0.0, type=float)
294 | parser.add_argument('--w_lo', default=1.0, type=float)
295 | parser.add_argument('--w_hi', default=1.0, type=float)
296 | parser.add_argument('--enlayer', default=4, type=int)
297 | parser.add_argument('--delayer', default=1, type=int)
298 | parser.add_argument('--embding', default=64, type=int)
299 | parser.add_argument('--attn', action='store_true')
300 | parser.add_argument('--comments', default="", type=str,
301 | help='comments in the dir name for identification')
302 | parser.add_argument('--cuda', default=0, type=int)
303 | args = parser.parse_args()
304 | train = Train(args, tracking=True)
305 | train.train()
306 |
--------------------------------------------------------------------------------
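
For orientation, the tensor plumbing at the heart of train_each_epoch() is a DWT
round trip. The sketch below is illustrative only (the shapes and hyperparameters
are assumptions: haar, J=1, an even-length sequence); it mirrors how trajectories
are transposed to batch * n_attr * n_sequence before dwt(), and how the predicted
coefficients are inverted with idwt():

    import torch
    from pytorch_wavelets import DWT1DForward, DWT1DInverse

    batch, n_seq, n_attr, maxlevel = 8, 10, 6, 1
    dwt = DWT1DForward(wave='haar', J=maxlevel, mode='symmetric')
    idwt = DWT1DInverse(wave='haar', mode='symmetric')

    traj = torch.randn(batch, n_seq, n_attr)         # batch * n_sequence * n_attr
    lo, hi = dwt(traj.transpose(1, 2).contiguous())  # lo: batch * n_attr * n_coef
    print(lo.shape, hi[0].shape)                     # (8, 6, 5) (8, 6, 5)
    rec = idwt((lo, hi)).transpose(1, 2)             # back to batch * n_sequence * n_attr
    print((rec - traj).abs().max())                  # ~0 for haar on even lengths
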
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pytorch_wavelets import DWT1DForward, DWT1DInverse
4 | import torch
5 | import pywt
6 | import time
7 | from torch.utils.data import DataLoader
8 | import os
9 |
10 |
11 | def wt_packet(inp, wavelet='db2', mode='symmetric', maxlevel=1):
12 | """
13 | the last-dim vector to be decomposed
14 | :param inp: shape: batch * n_attr * n_sequence
15 | :param wavelet:
16 | :param mode:
17 | :param maxlevel:
18 | :return: oup: shape : batch * n_attr * level * n_sequence
19 | """
20 | if maxlevel == 0:
21 | return inp.unsqueeze(2)
22 | dwt = DWT1DForward(wave=wavelet, J=1, mode=mode).to(inp.device)
23 | oup = [inp]
24 | for _ in range(maxlevel):
25 | tmp = []
26 | for item in oup:
27 | lo, hi = dwt(item)
28 | tmp += [lo, hi[0]]
29 | oup = tmp
30 | oup = torch.stack(oup, dim=2)
31 | return oup
32 |
33 |
34 | def wt_packet_inverse(inp, wavelet='db2', mode='symmetric', maxlevel=1):
35 | """
36 | the last-dim vector to be composed
37 | :param inp: shape : batch * n_attr * level * n_sequence
38 | :param wavelet:
39 | :param mode:
40 | :param maxlevel:
41 | :return: oup: batch * n_attr * n_sequence
42 | """
43 | assert inp.shape[2] == 2**maxlevel
44 | if maxlevel == 0:
45 | return inp.squeeze(2)
46 | idwt = DWT1DInverse(wave=wavelet, mode=mode).to(inp.device)
47 | oup = inp
48 | for level in range(maxlevel, 0, -1):
49 | tmp = []
50 | for i in range(2**(level-1)):
51 | lo, hi = oup[:, :, 2 * i, :], oup[:, :, 2 * i + 1, :]
52 | hi = [hi]
53 | tmp.append(idwt((lo, hi)))
54 | oup = torch.stack(tmp, dim=2)
55 | oup = oup.squeeze(2)
56 | return oup
57 |
58 |
59 | def progress_bar(step, n_step, info, start_time=None, bar_len=20):
60 | '''
61 | :param step: from 0 to n_step-1
62 | :param n_step: number of steps
63 | :param info: message printed after the bar
64 | :param start_time: when progress began; defaults to call time
65 | :param bar_len: length of the bar
66 | '''
67 | if start_time is None:  # a perf_counter() default would be frozen at import time
68 | start_time = time.perf_counter()
69 | step = step + 1
70 | a = "*" * int(step * bar_len / n_step)
71 | b = " " * (bar_len - int(step * bar_len / n_step))
72 | c = step / n_step * 100
73 | dur = time.perf_counter() - start_time
74 | print("\r{:^3.0f}%[{}{}]{:.2f}s {}".format(c, a, b, dur, info), end="")
75 | if step == n_step:
76 | print('')
77 |
78 | def wt_coef_len(in_length, wavelet, mode, maxlevel):
79 | test_inp = torch.ones((1, 1, in_length))
80 | test_oup = wt_packet(test_inp, wavelet=wavelet, mode=mode, maxlevel=maxlevel)
81 | return test_oup.shape[-1]
82 |
83 |
84 | class recorder:
85 | def __init__(self, *attrs):
86 | # *attrs is always a tuple (possibly empty), so no None check
87 | # is needed: keep one list of recorded values per attribute
88 | self.attrs = attrs
89 | self.saver = {attr: [] for attr in attrs}
94 |
95 | def add(self, attr, val):
96 | self.saver[attr].append(val)
97 |
98 | def __getitem__(self, item):
99 | return self.saver[item]
100 |
101 | def __len__(self):
102 | return len(self.attrs)
103 |
104 | def push(self, path, filename='record.pt'):
105 | torch.save(self.saver, os.path.join(path, filename))
106 |
107 | def pull(self, filepath):
108 | self.saver = torch.load(filepath)
109 | self.attrs = self.saver.keys()
110 |
--------------------------------------------------------------------------------
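
As a quick sanity check of the packet helpers above, the following sketch is
illustrative only (it assumes this file is importable as utils and uses haar on
an even-length signal); it shows the 2**maxlevel sub-band layout and the round
trip through wt_packet() and wt_packet_inverse():

    import torch
    from utils import wt_packet, wt_packet_inverse, wt_coef_len

    x = torch.randn(4, 6, 16)                        # batch * n_attr * n_sequence
    coefs = wt_packet(x, wavelet='haar', mode='symmetric', maxlevel=2)
    print(coefs.shape)                               # (4, 6, 4, 4): 2**2 sub-bands on dim 2
    print(wt_coef_len(16, 'haar', 'symmetric', 2))   # 4, the per-band coefficient length
    rec = wt_packet_inverse(coefs, wavelet='haar', mode='symmetric', maxlevel=2)
    print((rec - x).abs().max())                     # ~0: the packet transform round-trips
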