├── LICENSE
├── README.md
├── config.yml
├── data
│   ├── __init__.py
│   ├── batching.py
│   ├── dm.py
│   ├── qd.py
│   ├── sketch.py
│   ├── unpack_ndjson.py
│   └── utils.py
├── gifs
│   ├── 0.gif
│   ├── 1.gif
│   ├── 2.gif
│   ├── 3.gif
│   ├── 4.gif
│   ├── 5.gif
│   ├── 6.gif
│   ├── 7.gif
│   └── 8.gif
├── main.py
├── models
│   ├── __init__.py
│   └── score.py
├── requirements.txt
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ayan Das
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ChiroDiff: Modelling chirographic data with Diffusion Models
2 | ### Accepted at International Conference on Learning Representations (ICLR) 2023
3 |
4 | Authors: [Ayan Das](https://ayandas.me/), [Yongxin Yang](https://yang.ac/), [Timothy Hospedales](https://homepages.inf.ed.ac.uk/thospeda/), [Tao Xiang](http://personal.ee.surrey.ac.uk/Personal/T.Xiang/index.html), [Yi-Zhe Song](http://personal.ee.surrey.ac.uk/Personal/Y.Song/)
5 |
6 |
7 | 
8 | 
9 | 
10 | 
11 |
12 |
13 |
14 |
15 |
16 | [OpenReview](https://openreview.net/forum?id=1ROAstc9jv), [arXiv] & [Project Page]
17 |
18 |
19 |
20 | > **Abstract:** Generative modelling over continuous-time geometric constructs, a.k.a. chirographic data such as handwriting, sketches, drawings etc., has been accomplished through autoregressive distributions. Such strictly-ordered discrete factorization, however, falls short of capturing key properties of chirographic data -- it fails to build a holistic understanding of the temporal concept due to one-way visibility (causality). Consequently, temporal data has been modelled as discrete token sequences of fixed sampling rate instead of capturing the true underlying concept. In this paper, we introduce a powerful model-class, namely "Denoising Diffusion Probabilistic Models" or DDPMs, for chirographic data that specifically addresses these flaws. Our model, named "ChiroDiff", being non-autoregressive, learns to capture holistic concepts and therefore remains resilient to higher temporal sampling rates up to a good extent. Moreover, we show that many important downstream utilities (e.g. conditional sampling, creative mixing) can be flexibly implemented using ChiroDiff. We further show that some unique use-cases like stochastic vectorization, de-noising/healing and abstraction are also possible with this model-class. We perform quantitative and qualitative evaluation of our framework on relevant datasets and find it to be better than or on par with competing approaches.
21 |
22 | ---
23 |
24 | ## Running the code
25 |
26 | The instructions below explain how to run the code in this repository.
27 |
28 | #### Table of contents:
29 | 1. Environment and libraries
30 | 2. Data preparation
31 | 3. Training
32 | 4. Inference
33 |
34 | ### Environment & Libraries
35 |
36 | Running the code may require some slightly outdated libraries. The full list is provided as a `requirements.txt` in this repo. Please create a virtual environment with `conda` or `venv` and run
37 |
38 | ~~~bash
39 | (myenv) $ pip install -r requirements.txt
40 | ~~~
41 |
42 | ### Data preparation
43 |
44 | You can feed the data in one of two ways -- "unpacked" and "unpacked and preprocessed". The first dynamically loads data from individual files, whereas the latter packs the preprocessed input into one single `.npz` file -- increasing training speed.
45 |
46 | - To "unpack" the QuickDraw dataset, [download](https://console.cloud.google.com/storage/browser/quickdraw_dataset/full/raw?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))&prefix=&forceOnObjectsSortingFiltering=true) the `.ndjson` file for any categories you like and save them in a folder `/path/to/all/ndjsons/`. Then use the provided utility `data/unpack_ndjson.py` to unpack them
47 | ~~~bash
48 | (myenv) $ python data/unpack_ndjson.py --data_folder /path/to/all/ndjsons/ --category cat --out_folder /the/output/dir/ --max_sketches 100000
49 | # produces a folder /the/output/dir/cat/ with all samples (unpacked)
50 | ~~~
51 | You may use this folder as-is; however, it might be slow for training. We recommend packing the samples, with all preprocessing applied, into one `.npz` file using the `data/qd.py` script
52 | ~~~bash
53 | (myenv) $ python data/qd.py /the/output/dir/cat threeseqdel
54 | # produces a file /the/output/dir/cat_threeseqdel.npz
55 | ~~~
56 | `threeseqdel` is one of several training modes -- more on this later. The produced `.npz` file can now be used for training. Please see the `if __name__ == '__main__'` section of `data/qd.py` for preprocessing options, and the sketch after this list for a quick way to inspect a packed file.
57 |
58 | - For `VMNIST` and `KanjiVG` datasets, the unpacked files are readily available [here](https://drive.google.com/drive/folders/1C6euR9HPLdL_nubqLk8wEao96KRNZJl_?usp=sharing) for download. Follow the same steps above to process and pack them into `.npz`.
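To sanity-check a packed file, here is a minimal sketch (assuming the example output path from above; each `attrN` entry is an object array of variable-length per-sketch sequences, as written by `np.savez` at the bottom of `data/qd.py`):

~~~python
import numpy as np

# Path produced by the packing step above -- substitute your own.
packed = np.load('/the/output/dir/cat_threeseqdel.npz', allow_pickle=True)

print(packed.files)            # e.g. ['attr0'] for unconditional modes
sample = packed['attr0'][0]    # one sketch: an (N, C) float array
print(sample.shape, sample.dtype)
~~~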
59 |
60 |
61 | ### Training & Sampling
62 |
63 | There are multiple training "modes" corresponding to the model type (unconditional, sequence-conditioned, class-conditioned, etc.).
64 |
65 | ```bash
66 | threeseqdel # unconditional model with delta (velocity) sequence
67 | threeseqdel_pointcloudcond # conditioned on pointcloud representation
68 | threeseqdel_classcond # conditioned on class
69 | threeseqdel_threeseqdelcond # conditioned on self
70 |
71 | threeseqabs # unconditional model with absolute (position) sequence
72 | threeseqabs_pointcloudcond # conditioned on pointcloud representation
73 | threeseqabs_classcond # conditioned on class
74 | threeseqabs_threeseqabscond # conditioned on self
75 | ```
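The `*del` and `*abs` families differ only in the point representation. A minimal illustrative sketch of how the two relate, following the delta computation in `data/batching.py` (the real tensors also carry a pen bit and a timestamp column):

```python
import torch

abs_seq = torch.tensor([[0., 0.], [1., 0.], [1., 1.]])  # absolute (x, y) positions

delta_seq = abs_seq[1:] - abs_seq[:-1]  # "velocity" sequence, one point shorter
recovered = torch.cat([abs_seq[:1], abs_seq[:1] + delta_seq.cumsum(0)], 0)

assert torch.allclose(recovered, abs_seq)  # positions recovered by cumulative sum
```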
76 |
77 | - Use one of the modes as the `--model.repr` and `--data.init_args.repr` command-line arguments.
78 | - Point `--data.init_args.root_dir ...` to the processed data file (i.e. `*.npz`). You may also use an unprocessed data folder here.
79 |
80 | **Note:** For simplicity, we provide a `config.yml` file where every command-line option can be altered. Then run the main script as
81 |
82 | ```bash
83 | (myenv) $ python main.py fit --config config.yml --model.arch_layer 3 --model.noise_T 100 ...
84 | ```
85 |
86 | You will also need `wandb` for logging. Please use your own account and fill in the correct values of `--trainer.logger.init_args.{entity, project}` in the `config.yml` file. You may also remove the `wandb` logger entirely and replace it with another logger of your choice; in that case, you might have to modify a few lines of code (see the sketch below).
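A hedged sketch of one possible replacement, using Lightning's built-in `CSVLogger` (with `LightningCLI`, the equivalent change is pointing `trainer.logger.class_path` in `config.yml` at this class; the exact glue needed in this codebase is an assumption):

```python
from pytorch_lightning.loggers import CSVLogger

# Logs metrics to ./logs/test-run/version_*/metrics.csv instead of wandb.
logger = CSVLogger(save_dir="./logs/", name="test-run")
```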
87 |
88 | While training, the script will save the full config of the run, a "best model" and a "last model". Once trained, use the saved model (saved every 300 epochs) and the full configuration via the `--ckpt_path` and `--config` arguments like so
89 |
90 | ```bash
91 | (myenv) $ python main.py test --config ./logs/test-run/config.yaml --ckpt_path ./logs/test-run/.../checkpoints/model.ckpt --limit_test_batches 1
92 | ```
93 |
94 | By default, the testing phase will write some visualizations helpful for inspection -- for example, generation results and a visualization of the diffusion process. Test-time options have the `--test_` prefix. Feel free to play around with them.
95 |
96 | ```bash
97 | (myenv) $ python main.py test --config ... --ckpt_path ... \
98 | --test_sampling_algo ddpm \
99 | --test_variance_strength 0.75 \
100 | --text_viz_process backward \
101 | --test_save_everything 1
102 | ```
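For context, `--test_sampling_algo` switches between the standard DDPM and DDIM samplers. Below is a generic, hedged sketch of a single DDPM reverse step (textbook form; how exactly `models/score.py` applies `--test_variance_strength` is an assumption, shown here as a multiplier on the noise term):

```python
import torch

def ddpm_step(x_t, eps_pred, t, betas, variance_strength=0.75):
    """One reverse-diffusion step: denoise x_t given the predicted noise.

    Standard DDPM update, not a verbatim copy of the one in models/score.py;
    `variance_strength` scaling the noise term is an assumption.
    """
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    mean = (x_t - betas[t] / torch.sqrt(1.0 - alpha_bars[t]) * eps_pred) \
        / torch.sqrt(alphas[t])

    if t == 0:
        return mean  # the final step is deterministic
    return mean + variance_strength * torch.sqrt(betas[t]) * torch.randn_like(x_t)
```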
103 | ---
104 |
105 | You can cite the paper as
106 |
107 | ```bibtex
108 | @inproceedings{das2023chirodiff,
109 | title={ChiroDiff: Modelling chirographic data with Diffusion Models},
110 | author={Ayan Das and Yongxin Yang and Timothy Hospedales and Tao Xiang and Yi-Zhe Song},
111 | booktitle={The Eleventh International Conference on Learning Representations},
112 | year={2023},
113 | url={https://openreview.net/forum?id=1ROAstc9jv}
114 | }
115 | ```
116 |
117 | ---
118 |
119 | **Notes:**
120 |
121 | 1. This repository is part of our research codebase and may therefore contain code/options that are not part of the paper.
122 | 2. This repo may also contain some implementation details that have been upgraded since the submission of the paper.
123 | 3. The README is still incomplete; I will add more info when I get time. You may try different settings yourself.
124 | 4. The default parameters might not match the ones in the paper. Feel free to play with them.
125 |
126 |
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | seed_everything: null
2 | trainer:
3 |   logger:
4 |     class_path: utils.CustomWandbLogger
5 |     init_args:
6 |       entity: # TODO: fill these two ..
7 |       project: # .. entries yourself.
8 |       offline: true
9 |       log_model: false
10 |       save_dir: ./logs/
11 |       name: test-run
12 |       group: test
13 |
14 |   process_position: 0
15 |   num_nodes: 1
16 |   accelerator: gpu
17 |   devices: 1
18 |   auto_select_gpus: true
19 |
20 |   gradient_clip_algorithm: norm
21 |   enable_progress_bar: true
22 |   overfit_batches: 0.0
23 |   track_grad_norm: -1
24 |   check_val_every_n_epoch: 1
25 |   fast_dev_run: false
26 |   accumulate_grad_batches: 1
27 |   max_epochs: 100000
28 |   limit_train_batches: 1.0
29 |   limit_val_batches: 1.0
30 |   limit_test_batches: 1.0
31 |   log_every_n_steps: 8
32 |   strategy: dp
33 |   sync_batchnorm: false
34 |   enable_model_summary: true
35 |   weights_summary: top
36 |   num_sanity_val_steps: 0
37 |   profiler: null
38 |   benchmark: false
39 |   deterministic: false
40 |   detect_anomaly: false
41 |   auto_scale_batch_size: false
42 |   prepare_data_per_node: null
43 |   plugins: null
44 |   amp_backend: native
45 |   amp_level: null
46 |   move_metrics_to_cpu: false
47 |   stochastic_weight_avg: false
48 |
49 |   gradient_clip_val: 0.1
50 |   precision: 16
51 |
52 | model:
53 |   repr: ${data.init_args.repr}
54 |   modeltype: birnn
55 |   time_embedding: randomfourier
56 |
57 |   optim_ema: true
58 |   optim_lr: 1.e-3
59 |   optim_gamma: 0.9995
60 |   optim_warmup: 15000
61 |   optim_sched: steplr
62 |   optim_interval: epoch
63 |   optim_div_factor: 2
64 |   optim_decay: 1.e-2
65 |
66 |   arch_parameterization: eps
67 |   arch_dropout: 0.
68 |   arch_pe_dim: 8
69 |   arch_head: 4
70 |   arch_layer: 3
71 |   arch_internal: 96
72 |
73 |   # conditioning model arch
74 |   arch_layer_cond: 3
75 |   arch_internal_cond: 112
76 |   arch_n_cond_latent: 96
77 |
78 |   noise_T: 35
79 |   noise_low_noise: 1.e-4
80 |   noise_high_noise: 2.e-2
81 |   noise_schedule: linear
82 |
83 |   test_variance_strength: 0.75
84 |   test_sampling_algo: ddpm
85 |   test_n_viz: 10
86 |   test_n_sample_viz: 10
87 |   test_recon: true
88 |   test_interp: false
89 |
90 | data:
91 |   class_path: data.dm.QuickDrawDM
92 |   init_args:
93 |     root_dir: # TODO: path to the _.npz file
94 |     repr: threeseqabs
95 |
96 |     split_fraction: 0.85
97 |     perlin_noise: 0.1
98 |     split_seed: 5555
99 |     num_workers: 4
100 |     batch_size: 128
101 |     max_strokes: 20
102 |     max_sketches: 100000
103 |
104 | ckpt_path: null
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/data/__init__.py
--------------------------------------------------------------------------------
/data/batching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.nn.utils.rnn import pad_sequence
4 | from copy import deepcopy
5 |
6 | from data.sketch import Sketch
7 |
8 | # TODO: reasonable limit. can be made cmd arg later
9 | MAX_SEQ_LENGTH = 300
10 |
11 |
12 | class SketchRepr(object):
13 |
14 | def __init__(self, penbit=True, cache=False):
15 | super().__init__()
16 |
17 | self.penbit = penbit
18 |
19 | def represent(self, sketch):
20 | raise NotImplementedError('Abstract method not callable')
21 |
22 | def collate(self, batch: list):
23 | raise NotImplementedError('Abstract method not callable')
24 |
25 |
26 | class Strokewise(SketchRepr):
27 |
28 | def __init__(self, penbit=True, cache=False):
29 | # Here 'granularity' means stroke-granularity
30 | super().__init__(penbit=penbit, cache=cache)
31 |
32 | def represent(self, sketch: Sketch):
33 | sk_repr = []
34 |
35 | total_seq_len = sum([len(stroke) for stroke in sketch])
36 | if total_seq_len >= MAX_SEQ_LENGTH:
37 | return None
38 |
39 | for stroke in sketch:
40 | seq_stroke, _ = stroke.tensorize()
41 |
42 | # TODO: clean this properly; strokes becomes (2,) sized
43 | if len(stroke) > 1 and len(stroke.stroke.shape) == 2: # sloppy fix
44 | sk_repr.append({
45 | 'start': seq_stroke[0, :],
46 | 'time_range': torch.from_numpy(stroke.timestamps.astype(np.float32)),
47 | 'poly_stroke': seq_stroke
48 | })
49 |
50 | return sk_repr
51 |
52 |
53 | class Pointcloud(Strokewise):
54 |
55 | def construct_sample(sk):
56 | if len(sk) == 0:
57 | return None
58 |
59 | sk_set = torch.cat([s['poly_stroke'] for s in sk], 0)
60 | return sk_set
61 |
62 | def represent(self, sketch: Sketch):
63 | sk = super().represent(sketch)
64 | return sk and Pointcloud.construct_sample(sk)
65 |
66 | def collate(batch: list):
67 | batch = [b for b in batch if b is not None]
68 | lens = torch.tensor([b.shape[0] for b in batch])
69 | pd = pad_sequence(batch, batch_first=True)
70 | return None, (pd, lens)
71 |
72 |
73 | class ThreePointDelta(Strokewise):
74 |
75 | def construct_sample(sk, penbit=True):
76 | sk_list = []
77 | t_list = []
78 | for i, stroke in enumerate(sk):
79 | timestamps = stroke['time_range']
80 | stroke = stroke['poly_stroke']
81 | pen = torch.ones(stroke.shape[0], 1, dtype=stroke.dtype, device=stroke.device) * i
82 | pen[-1, 0] = i + 1
83 | sk_list.append(torch.cat([stroke, pen], -1))
84 | t_list.append(timestamps)
85 |
86 | if len(sk) == 0:
87 | return None
88 |
89 | sk = torch.cat(sk_list, 0)
90 | time = torch.cat(t_list, 0)
91 | sk_delta = sk[1:, :] - sk[:-1, :]
92 | if not penbit:
93 | sk_delta = sk_delta[:, :-1]
94 |
95 | time = time[:-1] # velocity is not available for the last point
96 | return torch.cat([sk_delta, time[:, None]], -1)
97 |
98 | def represent(self, sketch: Sketch):
99 | sk = super().represent(sketch)
100 | return sk and ThreePointDelta.construct_sample(sk, self.penbit)
101 |
102 | def collate(batch: list):
103 | batch = [b for b in batch if b is not None]
104 | lens = torch.tensor([b.shape[0] for b in batch])
105 | pd = pad_sequence(batch, batch_first=True)
106 | return None, (pd, lens)
107 |
108 |
109 | class ThreePointDelta_PointCloudCond(Strokewise):
110 |
111 | def represent(self, sketch: Sketch):
112 | sk = super().represent(sketch)
113 |
114 | if sk is None:
115 | return None, None
116 |
117 | sk_threepointdelta = ThreePointDelta.construct_sample(sk, self.penbit)
118 |
119 | sketch = deepcopy(sketch)
120 | sk = super().represent(sketch)
121 |
122 | if sk is None:
123 | return None, None
124 |
125 | sk_pointcloud = Pointcloud.construct_sample(sk)
126 |
127 | return sk_pointcloud, sk_threepointdelta
128 |
129 | def collate(batch: list):
130 | sk_threepointdeltas = [tpd for _, tpd in batch]
131 | sk_pointclouds = [pc for pc, _ in batch]
132 |
133 | _, pc_batch = Pointcloud.collate(sk_pointclouds)
134 | _, tpd_batch = ThreePointDelta.collate(sk_threepointdeltas)
135 | return pc_batch, tpd_batch
136 |
137 |
138 | class ThreePointAbs(Strokewise):
139 |
140 | def construct_sample(sk, penbit=True):
141 | sk_list = []
142 | t_list = []
143 | for _, stroke in enumerate(sk):
144 | timestamps = stroke['time_range']
145 | stroke = stroke['poly_stroke']
146 | pen = torch.zeros(stroke.shape[0], 1, dtype=stroke.dtype, device=stroke.device)
147 | pen[-1, 0] = 1.
148 | sk_list.append(torch.cat([stroke, pen], -1))
149 | t_list.append(timestamps)
150 |
151 | if len(sk) == 0:
152 | return None
153 |
154 | sk = torch.cat(sk_list, 0)
155 | if not penbit:
156 | sk = sk[:, :-1]
157 |
158 | time = torch.cat(t_list, 0)
159 | return torch.cat([sk[1:, :], time[1:, None]], -1)
160 |
161 | def represent(self, sketch: Sketch):
162 | sk = super().represent(sketch)
163 | return sk and ThreePointAbs.construct_sample(sk, self.penbit)
164 |
165 | def collate(batch: list):
166 | batch = [b for b in batch if b is not None]
167 | lens = torch.tensor([b.shape[0] for b in batch])
168 | pd = pad_sequence(batch, batch_first=True)
169 | return None, (pd, lens)
170 |
171 |
172 | class ThreePointAbs_PointCloudCond(Strokewise):
173 |
174 | def represent(self, sketch: Sketch):
175 | sk = super().represent(sketch)
176 |
177 | if sk is None:
178 | return None, None
179 |
180 | sk_threepointabs = ThreePointAbs.construct_sample(sk, self.penbit)
181 |
182 | sketch = deepcopy(sketch)
183 | sk = super().represent(sketch)
184 |
185 | if sk is None:
186 | return None, None
187 |
188 | sk_pointcloud = Pointcloud.construct_sample(sk)
189 |
190 | return sk_pointcloud, sk_threepointabs
191 |
192 | def collate(batch: list):
193 | sk_threepointabss = [tpa for _, tpa in batch]
194 | sk_pointclouds = [pc for pc, _ in batch]
195 |
196 | _, pc_batch = Pointcloud.collate(sk_pointclouds)
197 | _, tpa_batch = ThreePointAbs.collate(sk_threepointabss)
198 | return pc_batch, tpa_batch
199 |
200 |
201 | class ThreePointAbs_ThreeSeqAbs(Strokewise):
202 |
203 | def __init__(self, penbit=True, cond_rdp=None, cache=False):
204 | super().__init__(penbit, cache)
205 |
206 | self.cond_rdp = cond_rdp
207 |
208 | def represent(self, sketch: Sketch):
209 | cond_sketch = deepcopy(sketch)
210 |
211 | # spatially scaling back to 1. and then 10. is needed because the stuff in the middle
212 | # (resampling rate, RDP parameter) are sensitive to spatial scale of the vector entity.
213 | cond_sketch.scale_spatial(1.)
214 | if self.cond_rdp is not None:
215 | cond_sketch.rdp(self.cond_rdp)
216 | cond_sketch.scale_spatial(10.)
217 |
218 | cond_sk = super().represent(cond_sketch)
219 | sk = super().represent(sketch)
220 |
221 | if sk is None or cond_sk is None:
222 | return None, None
223 |
224 | cond_sk_threepointabs = ThreePointAbs.construct_sample(cond_sk, self.penbit)
225 | sk_threepointabs = ThreePointAbs.construct_sample(sk, self.penbit)
226 |
227 | # timestep not needed for the condition
228 | return cond_sk_threepointabs, \
229 | sk_threepointabs
230 |
231 | def collate(batch: list):
232 | sk_threepointabss = [h_tpa for _, h_tpa in batch]
233 | cond_sk_threepointabss = [l_tpa for l_tpa, _ in batch]
234 |
235 | _, tpa_batch = ThreePointAbs.collate(sk_threepointabss)
236 | _, cond_tpa_batch = ThreePointAbs.collate(cond_sk_threepointabss)
237 | return cond_tpa_batch, tpa_batch
238 |
239 |
240 | class ThreePointDel_ThreeSeqDel(Strokewise):
241 |
242 | def __init__(self, penbit=True, cond_rdp=None, cache=False):
243 | super().__init__(penbit, cache)
244 |
245 | self.cond_rdp = cond_rdp
246 |
247 | def represent(self, sketch: Sketch):
248 | cond_sketch = deepcopy(sketch)
249 |
250 | # spatially scaling back to 1. and then 10. is needed because the stuff in the middle
251 | # (resampling rate, RDP parameter) are sensitive to spatial scale of the vector entity.
252 | cond_sketch.scale_spatial(1.)
253 | if self.cond_rdp is not None:
254 | cond_sketch.rdp(self.cond_rdp)
255 | cond_sketch.scale_spatial(10.)
256 |
257 | cond_sk = super().represent(cond_sketch)
258 | sk = super().represent(sketch)
259 |
260 | if sk is None or cond_sk is None:
261 | return None, None
262 |
263 | cond_sk_threepointdel = ThreePointDelta.construct_sample(cond_sk, self.penbit)
264 | sk_threepointdel = ThreePointDelta.construct_sample(sk, self.penbit)
265 |
266 | # timestep not needed for the condition
267 | return cond_sk_threepointdel, \
268 | sk_threepointdel
269 |
270 | def collate(batch: list):
271 | sk_threepointdels = [h_tpd for _, h_tpd in batch]
272 | cond_sk_threepointdels = [l_tpd for l_tpd, _ in batch]
273 |
274 |         _, tpd_batch = ThreePointDelta.collate(sk_threepointdels)
275 |         _, cond_tpd_batch = ThreePointDelta.collate(cond_sk_threepointdels)
276 | return cond_tpd_batch, tpd_batch
277 |
278 |
279 | class StrokeSet(Strokewise):
280 |
281 | def represent(self, sketch: Sketch):
282 | sk = super().represent(sketch)
283 |
284 | sk_list = []
285 | for stroke in sk:
286 | abs_stroke = stroke['poly_stroke']
287 | del_stroke = abs_stroke[1:, ...] - abs_stroke[:-1, ...]
288 | start_del_stroke = torch.cat([stroke['start'][None, :], del_stroke], 0)
289 | sk_list.append(start_del_stroke.ravel())
290 |
291 | sk = torch.stack(sk_list, 0)
292 | return sk
293 |
294 | def collate(batch: list):
295 | batch = [b for b in batch if b is not None]
296 | lens = torch.tensor([b.shape[0] for b in batch])
297 | pd = pad_sequence(batch, batch_first=True)
298 | return pd, lens
299 |
--------------------------------------------------------------------------------
/data/dm.py:
--------------------------------------------------------------------------------
1 | import os
2 | from enum import Enum
3 | from typing import Optional
4 |
5 | import torch
6 | from torch.utils.data import DataLoader, random_split, Dataset
7 | from pytorch_lightning import LightningDataModule
8 |
9 | from data.qd import (
10 | DS_threeseqdel,
11 | DS_threeseqabs,
12 | DS_threeseqdel_pointcloudcond,
13 | DS_threeseqdel_classcond,
14 | DS_threeseqabs_classcond,
15 | DS_threeseqabs_pointcloudcond,
16 | DS_threeseqabs_threeseqabscond,
17 | DS_threeseqdel_threeseqdelcond
18 | )
19 |
20 |
21 | class ReprType(str, Enum):
22 | threeseqdel = "threeseqdel"
23 | threeseqabs = "threeseqabs"
24 | threeseqabs_threeseqabscond = "threeseqabs_threeseqabscond"
25 | threeseqdel_pointcloudcond = "threeseqdel_pointcloudcond"
26 | threeseqdel_classcond = "threeseqdel_classcond"
27 | threeseqabs_classcond = "threeseqabs_classcond"
28 | threeseqabs_pointcloudcond = "threeseqabs_pointcloudcond"
29 | threeseqdel_threeseqdelcond = "threeseqdel_threeseqdelcond"
30 |
31 |
32 | class GenericDM(LightningDataModule):
33 |
34 | def __init__(self, split_seed, split_fraction, batch_size, num_worker, repr):
35 | super().__init__()
36 |
37 | self.split_seed = split_seed
38 | self.split_fraction = split_fraction
39 | self.batch_size = batch_size
40 | self.num_worker = num_worker
41 | self.repr = repr
42 |
43 | # subclasses need to set this with a 'Dataset' instance
44 | self._dataset = None
45 |
46 | @property
47 | def dataset(self):
48 | if self._dataset is None:
49 | raise ValueError(f'Subclass {self.__class__.__name__} is yet to assign a Dataset')
50 | else:
51 | return self._dataset
52 |
53 | @dataset.setter
54 | def dataset(self, d):
55 | if not isinstance(d, Dataset):
56 | raise ValueError(f'Expected a Dataset, got {d}')
57 | else:
58 | self._dataset = d
59 |
60 | def compute_split_size(self):
61 | self.train_len = int(len(self.dataset) * self.split_fraction)
62 | self.valid_len = len(self.dataset) - self.train_len
63 |
64 | def setup(self, stage: str):
65 | self.train_dataset, self.valid_dataset = \
66 | random_split(self.dataset, [self.train_len, self.valid_len],
67 | torch.Generator().manual_seed(self.split_seed))
68 |
69 | def train_dataloader(self):
70 | return DataLoader(self.train_dataset,
71 | batch_size=self.batch_size, pin_memory=True, drop_last=True, shuffle=True,
72 | num_workers=self.num_worker, collate_fn=self.dataset.__class__.collate)
73 |
74 | def val_dataloader(self):
75 | return DataLoader(self.valid_dataset,
76 | batch_size=self.batch_size, pin_memory=True, drop_last=True, shuffle=True,
77 | num_workers=self.num_worker, collate_fn=self.dataset.__class__.collate)
78 |
79 | def test_dataloader(self):
80 | return self.val_dataloader()
81 |
82 |
83 | class QuickDrawDM(GenericDM):
84 |
85 | def __init__(self,
86 | root_dir: str,
87 | max_sketches: Optional[int] = None,
88 | max_strokes: Optional[int] = None,
89 | split_fraction: float = 0.85,
90 | perlin_noise: float = 0.2,
91 | penbit: bool = True,
92 | split_seed: int = 4321,
93 | batch_size: int = 64,
94 | num_workers: int = os.cpu_count() // 2,
95 | rdp: Optional[float] = None,
96 | cond_rdp: Optional[float] = None,
97 | repr: ReprType = ReprType.threeseqdel,
98 | cache: bool = False
99 | ):
100 | """QuickDraw Datamodule (OneSeq)
101 |
102 | Args:
103 |             root_dir: Root directory of unpacked QD data, or a packed `.npz` file
104 |             max_sketches: Maximum number of sketches to use
105 |             max_strokes: clamp the maximum number of strokes (None for all strokes)
106 |             split_fraction: Train/Validation split fraction
107 |             perlin_noise: Strength of Perlin noise augmentation
108 |             penbit: whether to append the pen-state bit to each point
109 |             split_seed: Data splitting seed
110 |             batch_size: Batch size for training
111 |             num_workers: Number of dataloader workers
112 |             rdp: RDP algorithm parameter ('None' to ignore RDP entirely)
113 |             repr: data representation (one of the `ReprType` values)
114 | """
115 | self.save_hyperparameters()
116 | self.hp = self.hparams # an easier name
117 | super().__init__(self.hp.split_seed,
118 | self.hp.split_fraction,
119 | self.hp.batch_size,
120 | self.hp.num_workers,
121 | self.hp.repr)
122 |
123 | self._construct()
124 |
125 | def _construct(self):
126 | if self.hp.repr == ReprType.threeseqdel:
127 | self.dataset = DS_threeseqdel(self.hp.root_dir,
128 | perlin_noise=self.hp.perlin_noise,
129 | max_sketches=self.hp.max_sketches,
130 | max_strokes=self.hp.max_strokes,
131 | penbit=self.hp.penbit,
132 | rdp=self.hp.rdp)
133 | elif self.hp.repr == ReprType.threeseqabs:
134 | self.dataset = DS_threeseqabs(self.hp.root_dir,
135 | perlin_noise=self.hp.perlin_noise,
136 | max_sketches=self.hp.max_sketches,
137 | max_strokes=self.hp.max_strokes,
138 | penbit=self.hp.penbit,
139 | rdp=self.hp.rdp)
140 | elif self.hp.repr == ReprType.threeseqdel_pointcloudcond:
141 | self.dataset = DS_threeseqdel_pointcloudcond(self.hp.root_dir,
142 | perlin_noise=self.hp.perlin_noise,
143 | max_sketches=self.hp.max_sketches,
144 | max_strokes=self.hp.max_strokes,
145 | penbit=self.hp.penbit,
146 | rdp=self.hp.rdp)
147 | elif self.hp.repr == ReprType.threeseqdel_classcond:
148 | self.dataset = DS_threeseqdel_classcond(self.hp.root_dir,
149 | perlin_noise=self.hp.perlin_noise,
150 | max_sketches=self.hp.max_sketches,
151 | max_strokes=self.hp.max_strokes,
152 | penbit=self.hp.penbit,
153 | rdp=self.hp.rdp)
154 | elif self.hp.repr == ReprType.threeseqabs_classcond:
155 | self.dataset = DS_threeseqabs_classcond(self.hp.root_dir,
156 | perlin_noise=self.hp.perlin_noise,
157 | max_sketches=self.hp.max_sketches,
158 | max_strokes=self.hp.max_strokes,
159 | penbit=self.hp.penbit,
160 | rdp=self.hp.rdp)
161 | elif self.hp.repr == ReprType.threeseqabs_pointcloudcond:
162 | self.dataset = DS_threeseqabs_pointcloudcond(self.hp.root_dir,
163 | perlin_noise=self.hp.perlin_noise,
164 | max_sketches=self.hp.max_sketches,
165 | max_strokes=self.hp.max_strokes,
166 | penbit=self.hp.penbit,
167 | rdp=self.hp.rdp)
168 | elif self.hp.repr == ReprType.threeseqabs_threeseqabscond:
169 | self.dataset = DS_threeseqabs_threeseqabscond(self.hp.root_dir,
170 | perlin_noise=self.hp.perlin_noise,
171 | max_sketches=self.hp.max_sketches,
172 | max_strokes=self.hp.max_strokes,
173 | penbit=self.hp.penbit,
174 | rdp=self.hp.rdp,
175 | cond_rdp=self.hp.cond_rdp)
176 | elif self.hp.repr == ReprType.threeseqdel_threeseqdelcond:
177 | self.dataset = DS_threeseqdel_threeseqdelcond(self.hp.root_dir,
178 | perlin_noise=self.hp.perlin_noise,
179 | max_sketches=self.hp.max_sketches,
180 | max_strokes=self.hp.max_strokes,
181 | penbit=self.hp.penbit,
182 | rdp=self.hp.rdp,
183 | cond_rdp=self.hp.cond_rdp)
184 | else:
185 | pass
186 |
187 | self.compute_split_size()
188 |
--------------------------------------------------------------------------------
/data/qd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import pickle
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 |
11 | from data.sketch import Sketch
12 | from data.batching import Pointcloud, Strokewise, \
13 | ThreePointAbs, ThreePointDelta, ThreePointDelta_PointCloudCond, \
14 | ThreePointAbs_PointCloudCond, ThreePointAbs_ThreeSeqAbs, \
15 | ThreePointDel_ThreeSeqDel
16 |
17 |
18 | class QuickDraw(Dataset):
19 |
20 | def __init__(self, data_root, shuffle=True, perlin_noise=0.2,
21 | max_sketches=10000, max_strokes=None, rdp=None, **kwargs):
22 | super().__init__()
23 |
24 | self.data_root = data_root
25 |
26 | if os.path.isfile(self.data_root) and self.data_root.endswith('.npz'):
27 | self.npz_ptr = np.load(self.data_root, allow_pickle=True)
28 | self.attrs = self.npz_ptr.files
29 | self.data = {attr: self.npz_ptr[attr] for attr in self.attrs}
30 | self.cached = True
31 | return
32 | else:
33 | self.cached = False
34 |
35 | self.max_sketches = max_sketches
36 | self.max_strokes = max_strokes
37 | self.perlin_noise = perlin_noise
38 | self.rdp = rdp
39 |
40 | self.content_list = os.listdir(self.data_root)
41 | if all([os.path.isdir(os.path.join(self.data_root, c_path)) for c_path in self.content_list]):
42 | self.categories = self.content_list
43 | self.n_categories = len(self.categories)
44 | self.file_list = []
45 | for cat in self.categories:
46 | cat_content = os.listdir(os.path.join(self.data_root, cat))
47 |
48 | if self.max_sketches is not None:
49 | max_sketches_per_cat = min(self.max_sketches, len(cat_content))
50 | del cat_content[max_sketches_per_cat:]
51 |
52 | self.file_list.extend([os.path.join(cat, c) for c in cat_content])
53 |
54 | else:
55 | self.categories = None
56 | self.file_list = self.content_list
57 |
58 | if self.max_sketches is not None:
59 | max_sketches = min(self.max_sketches, len(self.file_list))
60 | del self.file_list[max_sketches:]
61 |
62 | if shuffle:
63 | random.shuffle(self.file_list)
64 |
65 |
66 | def __len__(self):
67 | if not self.cached:
68 | return len(self.file_list)
69 | else:
70 | return self.data[self.attrs[0]].shape[0]
71 |
72 | def get_sketch(self, i):
73 | if self.categories is not None:
74 | cat, _ = self.file_list[i].split('/')
75 | assert cat in self.categories, "something wrong with category/folder names"
76 | self.cat_idx = self.categories.index(cat)
77 | else:
78 | self.cat_idx = None
79 |
80 | file_path = os.path.join(self.data_root, self.file_list[i])
81 |
82 | with open(file_path, 'rb') as f:
83 | self.data = pickle.load(f)
84 |
85 | stroke_list = self.data['drawing']
86 |
87 | if self.max_strokes is not None:
88 | max_strokes = min(self.max_strokes, len(stroke_list))
89 | stroke_list = stroke_list[:max_strokes]
90 |
91 | sketch = Sketch(stroke_list, label=self.cat_idx)
92 |
93 | seed = random.randint(0, 10000)
94 | sketch.jitter(seed=seed, noise_level=self.perlin_noise)
95 |
96 | sketch.move()
97 | sketch.shift_time(0)
98 | sketch.scale_spatial(1)
99 | sketch.resample(delta=0.05)
100 | if self.rdp is not None:
101 | sketch.rdp(eps=self.rdp)
102 | sketch.scale_spatial(10)
103 | sketch.scale_time(1)
104 |
105 | return sketch
106 |
107 | def __getitem__(self, i):
108 | if not self.cached:
109 | return self.represent(self.get_sketch(i))
110 | else:
111 | if len(self.attrs) > 1:
112 | return tuple(torch.from_numpy(self.data[attr][i]) for attr in self.attrs)
113 | else:
114 | return torch.from_numpy(self.data[self.attrs[0]][i])
115 |
116 |
117 | class QDSketchStrokewise(QuickDraw, Strokewise):
118 |
119 | def __init__(self, *args, **kwargs):
120 | QuickDraw.__init__(self, *args, **kwargs)
121 | Strokewise.__init__(self)
122 |
123 | def __getitem__(self, i):
124 | return self.represent(super().get_sketch(i))
125 |
126 |
127 | class QDSketchPointcloud(QuickDraw, Pointcloud):
128 |
129 | def __init__(self, *args, **kwargs):
130 | QuickDraw.__init__(self, *args, **kwargs)
131 | Pointcloud.__init__(self)
132 |
133 | def __getitem__(self, i):
134 | return self.represent(super().__getitem__(i))
135 |
136 |
137 | class DS_threeseqdel(QuickDraw, ThreePointDelta):
138 |
139 | def __init__(self, *args, **kwargs):
140 | QuickDraw.__init__(self, *args, **kwargs)
141 | ThreePointDelta.__init__(self, penbit=kwargs.get('penbit', True))
142 |
143 |
144 | class DS_threeseqabs(QuickDraw, ThreePointAbs):
145 |
146 | def __init__(self, *args, **kwargs):
147 | QuickDraw.__init__(self, *args, **kwargs)
148 | ThreePointAbs.__init__(self, penbit=kwargs.get('penbit', True))
149 |
150 |
151 | class DS_threeseqabs_classcond(QuickDraw, ThreePointAbs):
152 |
153 | def __init__(self, *args, **kwargs):
154 | QuickDraw.__init__(self, *args, **kwargs)
155 | ThreePointAbs.__init__(self, penbit=kwargs.get('penbit', True))
156 |
157 | def represent(self, sketch: Sketch):
158 | label = torch.tensor(sketch.label, dtype=torch.int64)
159 | return label, super().represent(sketch)
160 |
161 | def collate(batch: list):
162 | class_batch = torch.stack([c for c, _ in batch], 0)
163 | _, tpa_batch = ThreePointAbs.collate([tpa for _, tpa in batch])
164 | return class_batch, tpa_batch
165 |
166 |
167 | class DS_threeseqdel_classcond(QuickDraw, ThreePointDelta):
168 |
169 | def __init__(self, *args, **kwargs):
170 | QuickDraw.__init__(self, *args, **kwargs)
171 | ThreePointDelta.__init__(self, penbit=kwargs.get('penbit', True))
172 |
173 | def represent(self, sketch: Sketch):
174 | label = torch.tensor(sketch.label, dtype=torch.int64)
175 | return label, super().represent(sketch)
176 |
177 | def collate(batch: list):
178 | class_batch = torch.stack([c for c, _ in batch], 0)
179 | _, tpd_batch = ThreePointDelta.collate([tpd for _, tpd in batch])
180 | return class_batch, tpd_batch
181 |
182 |
183 | class DS_threeseqdel_pointcloudcond(QuickDraw, ThreePointDelta_PointCloudCond):
184 |
185 | def __init__(self, *args, **kwargs):
186 | QuickDraw.__init__(self, *args, **kwargs)
187 | ThreePointDelta_PointCloudCond.__init__(self, penbit=kwargs.get('penbit', True))
188 |
189 |
190 | class DS_threeseqabs_pointcloudcond(QuickDraw, ThreePointAbs_PointCloudCond):
191 |
192 | def __init__(self, *args, **kwargs):
193 | QuickDraw.__init__(self, *args, **kwargs)
194 | ThreePointAbs_PointCloudCond.__init__(self, penbit=kwargs.get('penbit', True))
195 |
196 |
197 | class DS_threeseqabs_threeseqabscond(QuickDraw, ThreePointAbs_ThreeSeqAbs):
198 |
199 | def __init__(self, *args, **kwargs):
200 | QuickDraw.__init__(self, *args, **kwargs)
201 | ThreePointAbs_ThreeSeqAbs.__init__(self, penbit=kwargs.get('penbit', True),
202 | cond_rdp=kwargs.get('cond_rdp', None))
203 |
204 | def __getitem__(self, i):
205 | if not self.cached:
206 | return self.represent(self.get_sketch(i))
207 | else:
208 | if len(self.attrs) > 1:
209 | return tuple(torch.from_numpy(self.data[attr][i]) for attr in self.attrs)
210 | else:
211 | # in case we need the same data as cond
212 | d = self.data[self.attrs[0]][i]
213 | return torch.from_numpy(d), torch.from_numpy(d)
214 |
215 |
216 | class DS_threeseqdel_threeseqdelcond(QuickDraw, ThreePointDel_ThreeSeqDel):
217 |
218 | def __init__(self, *args, **kwargs):
219 | QuickDraw.__init__(self, *args, **kwargs)
220 | ThreePointDel_ThreeSeqDel.__init__(self, penbit=kwargs.get('penbit', True),
221 | cond_rdp=kwargs.get('cond_rdp', None))
222 |
223 | def __getitem__(self, i):
224 | if not self.cached:
225 | return self.represent(self.get_sketch(i))
226 | else:
227 | if len(self.attrs) > 1:
228 | return tuple(torch.from_numpy(self.data[attr][i]) for attr in self.attrs)
229 | else:
230 | # in case we need the same data as cond
231 | d = self.data[self.attrs[0]][i]
232 | return torch.from_numpy(d), torch.from_numpy(d)
233 |
234 |
235 | if __name__ == '__main__':
236 |     ds_class = eval('DS_' + sys.argv[2])
237 |     ds = ds_class(
238 | sys.argv[1],
239 | perlin_noise=0.,
240 | max_sketches=100000,
241 | max_strokes=25,
242 | penbit=True,
243 | rdp=None
244 | )
245 | dummy_sample = ds[0]
246 | if not isinstance(dummy_sample, tuple):
247 | n_attr = 1
248 | else:
249 | n_attr = len(dummy_sample)
250 |
251 | samples = [[] for _ in range(n_attr)]
252 | for sam in tqdm(ds):
253 | if n_attr == 1:
254 | sam = (sam, )
255 | for a in range(n_attr):
256 | if sam[a] is None:
257 | break
258 | samples[a].append(sam[a].numpy())
259 | samples = [np.array(sams, dtype=np.ndarray) for sams in samples]
260 | attrs = [f'attr{a}' for a in range(n_attr)]
261 |
262 | np.savez(sys.argv[1] + f'_{sys.argv[2]}.npz', **dict(zip(attrs, samples)))
--------------------------------------------------------------------------------
/data/sketch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from matplotlib.colors import Colormap
5 | from simplification.cutil import simplify_coords_idx
6 |
7 | from data.utils import continuous_noise, resample
8 |
9 |
10 | class Stroke(object):
11 | def __init__(self, stroke, timestamps=None):
12 | super().__init__()
13 |
14 | self.type = type(stroke)
15 | if self.type in [np.ndarray, torch.Tensor]:
16 | self.stroke = stroke
17 | assert isinstance(timestamps, self.type), \
18 | "stroke & timestamps must have same type"
19 | self.timestamps = timestamps
20 |
21 | if self.timestamps.ndim != 1:
22 | raise AssertionError('timestamps must be 1D array')
23 |
24 | def draw(self, axis=None):
25 | raise NotImplementedError('Use one of the subclasses of Stroke')
26 |
27 | def __len__(self):
28 | return self.stroke.shape[0]
29 |
30 | def tensorize(self):
31 | if self.type is torch.Tensor:
32 | return self.stroke, self.timestamps
33 | else:
34 | return torch.from_numpy(self.stroke.astype(np.float32)), \
35 | torch.from_numpy(self.timestamps.astype(np.float32))
36 |
37 |
38 | class PolylineStroke(Stroke):
39 | def __init__(self, stroke, timestamps=None):
40 |
41 | stroke = np.array(stroke).T if isinstance(stroke, list) else stroke
42 | timestamps = np.array(timestamps) if isinstance(timestamps, list) else timestamps
43 | super().__init__(stroke, timestamps)
44 |
45 | def rdp(self, eps=0.01):
46 | is_tensor = isinstance(self.stroke, torch.Tensor)
47 | stroke = self.stroke.data.cpu().numpy() if is_tensor else self.stroke
48 | stroke = np.ascontiguousarray(stroke)
49 | simpl_idx = simplify_coords_idx(stroke, eps)
50 |
51 | self.stroke = self.stroke[simpl_idx]
52 | self.timestamps = self.timestamps[simpl_idx]
53 |
54 | def resample(self, granularity):
55 | stroke = self.stroke.numpy() if (self.type is torch.Tensor) else self.stroke
56 | timestamps = self.timestamps.numpy() \
57 | if (self.type is torch.Tensor) else self.timestamps
58 |
59 | self.stroke, self.timestamps = resample(stroke, timestamps, granularity)
60 |
61 | def jitter(self, seed, noise_level=0.2):
62 | stroke = self.stroke.numpy() if (self.type is torch.Tensor) else self.stroke
63 | self.stroke = continuous_noise(stroke, seed=seed, noise_level=noise_level)
64 |
65 | def move(self, by=np.zeros((1, 2))):
66 | self.stroke = self.stroke + by
67 |
68 | def shift_time(self, to=0.):
69 | self.timestamps = self.timestamps - self.initial_time + to
70 |
71 | def scale_time(self, factor=1.):
72 | self.timestamps = (self.timestamps / self.terminal_time) * factor
73 |
74 | @property
75 | def initial_time(self):
76 | return self.timestamps[0]
77 |
78 | @property
79 | def terminal_time(self):
80 | return self.timestamps[-1]
81 |
82 | @property
83 | def start(self):
84 | return self.stroke[0, :]
85 |
86 | @property
87 | def end(self):
88 | return self.stroke[-1, :]
89 |
90 | def draw(self, axis=None, color='black', linewidth=1, scatter=True):
91 | if axis is None:
92 | fig = plt.figure()
93 | axis = plt.gca()
94 |
95 | stroke = self.stroke.data.cpu().numpy() if (self.type is torch.Tensor) else self.stroke
96 | if not isinstance(color, list):
97 | axis.plot(stroke[:, 0], stroke[:, 1], color=color, linewidth=linewidth)
98 | else:
99 | for i in range(len(self) - 1):
100 | axis.plot(stroke[i:i+2, 0], stroke[i:i+2, 1], color=color[i], linewidth=linewidth, solid_capstyle='round')
101 |
102 | if scatter:
103 | stroke = self.stroke.data.cpu().numpy() if (self.type is torch.Tensor) else self.stroke
104 | if not isinstance(color, list):
105 | axis.scatter(stroke[:, 0], stroke[:, 1], color=color, s=linewidth*2)
106 | else:
107 | for i in range(len(self)):
108 | axis.scatter(stroke[None, i, 0], stroke[None, i, 1], color=color[i], s=linewidth*2)
109 |
110 | @property
111 | def enclosing_circle_radius(self):
112 | norms = np.linalg.norm(self.stroke, 2, -1)
113 | return norms.max()
114 |
115 | @property
116 | def length(self):
117 | return (((self.stroke[1:, :] - self.stroke[:-1, :])**2).sum(-1)**0.5).sum().item()
118 |
119 |
120 | class Sketch(object):
121 |
122 | def __init__(self, strokes, label=None):
123 | super().__init__()
124 | self.label = label # optional class label
125 |
126 | self.strokes = []
127 | for s in strokes:
128 | stroke = PolylineStroke(s[:2], s[-1])
129 | if len(stroke) > 1:
130 | # one point strokes are not tolerable
131 | self.strokes.append(stroke)
132 |
133 | @property
134 | def nstrokes(self):
135 | return len(self.strokes)
136 |
137 | def __len__(self):
138 | return self.nstrokes
139 |
140 | def rdp(self, eps=0.01):
141 | for stroke in self.strokes:
142 | stroke.rdp(eps)
143 |
144 | def resample(self, delta=0.1):
145 | for stroke in self.strokes:
146 | n = max(2, int(stroke.length / delta))
147 | stroke.resample(n)
148 |
149 | def move(self, to=np.zeros((1, 2))):
150 | move_by = to - self.strokes[0].start
151 | for stroke in self.strokes:
152 | stroke.move(move_by)
153 |
154 | def __getitem__(self, i):
155 | return self.strokes[i]
156 |
157 | def draw(self, axis=None, cla=True, color='black', **kwargs):
158 | if axis is None:
159 | fig = plt.figure()
160 | axis = plt.gca()
161 |
162 | if cla:
163 | axis.cla()
164 |
165 | if not isinstance(color, Colormap):
166 | for stroke in self.strokes:
167 | stroke.draw(axis, color=color, **kwargs)
168 | else:
169 | seg_lens = [len(s) for s in self.strokes]
170 | colors = [color(i / (sum(seg_lens) - 1)) for i in range(sum(seg_lens))]
171 | c = 0
172 | for stroke in self.strokes:
173 | l = len(stroke)
174 | stroke.draw(axis, color=colors[c:c+l], **kwargs)
175 | c += l
176 |
177 | xmin, xmax = axis.get_xlim()
178 | ymin, ymax = axis.get_ylim()
179 | width = xmax - xmin
180 | height = ymax - ymin
181 | xmin, xmax = xmin - 0.1 * width, xmax + 0.1 * width
182 | ymin, ymax = ymin - 0.1 * height, ymax + 0.1 * height
183 | axis.set_xlim([xmin, xmax])
184 | axis.set_ylim([ymin, ymax])
185 |
186 | axis.set_xticks([])
187 | axis.set_yticks([])
188 |         axis.set_xticklabels([])
189 |         axis.set_yticklabels([])
190 |
191 | @property
192 | def terminal_time(self):
193 | return self.strokes[-1].terminal_time
194 |
195 | @property
196 | def initial_time(self):
197 | return self.strokes[0].initial_time
198 |
199 | def shift_time(self, to=0.):
200 | initial_time = self.initial_time
201 | for stroke in self.strokes:
202 | stroke.timestamps = stroke.timestamps - initial_time
203 |
204 | def scale_time(self, factor=1.):
205 | for stroke in self.strokes:
206 | stroke.timestamps = (stroke.timestamps / self.terminal_time) * factor
207 |
208 | def scale_spatial(self, factor=1.):
209 | enclosing_circle_radius = max([stroke.enclosing_circle_radius for stroke in self.strokes])
210 | for stroke in self.strokes:
211 | stroke.stroke = (stroke.stroke / enclosing_circle_radius) * factor
212 |
213 | def jitter(self, seed, noise_level=0.2):
214 | for i, stroke in enumerate(self.strokes):
215 | stroke.jitter(seed + i, noise_level)
216 |
217 | def _fill_penup(start, end, granularity):
218 | start = start.unsqueeze(0).repeat(granularity, 1)
219 | end = end.unsqueeze(0).repeat(granularity, 1)
220 | alpha = torch.linspace(0., 1., granularity).unsqueeze(-1)
221 | stroke = start * (1. - alpha) + end * alpha
222 | return stroke
223 |
224 | def _add_pen_state(stroke, fill_value=0.):
225 | stroke_plus_pen = torch.cat([
226 | stroke,
227 | torch.ones(len(stroke), 1, device=stroke.device) * fill_value
228 | ], dim=-1)
229 | return stroke_plus_pen
230 |
231 | def tensorize(self, joining_granularity=20):
232 | seq_strokes, seq_timestamps = [], []
233 |
234 | current_stroke, current_timestamps = self[0].tensorize()
235 | seq_strokes.append(Sketch._add_pen_state(current_stroke, 0.))
236 | seq_timestamps.append(current_timestamps)
237 |
238 | for i in range(1, self.nstrokes):
239 | next_stroke, next_timestamps = self[i].tensorize()
240 | joining_stroke = Sketch._fill_penup(current_stroke[-1, :], next_stroke[0, :],
241 | granularity=joining_granularity)
242 | joining_stroke_pen = Sketch._add_pen_state(joining_stroke, 1.)
243 | joining_timestamps = torch.linspace(current_timestamps[-1], next_timestamps[0], len(joining_stroke_pen),
244 | device=joining_stroke_pen.device)
245 | # ignore the first and last one to avoid duplication
246 | seq_strokes.append(joining_stroke_pen[1:-1, ...])
247 | seq_timestamps.append(joining_timestamps[1:-1])
248 |
249 | next_stroke_pen = Sketch._add_pen_state(next_stroke, 0.)
250 | seq_strokes.append(next_stroke_pen)
251 | seq_timestamps.append(next_timestamps)
252 |
253 | current_stroke, current_timestamps = next_stroke, next_timestamps
254 |
255 | return torch.cat(seq_strokes, 0), torch.cat(seq_timestamps, 0)
256 |
257 |     def from_threeseqabs(seq, ts=None):
258 |         # `seq` should be an (N x 3) np.ndarray (convert a torch.Tensor first)
259 |         n_points, _ = seq.shape
260 |         seq, penbits = seq[:, :-1], seq[:, -1]
261 |
262 |         dummy_timestamps = ts if ts is not None else np.linspace(0., 1., n_points)
263 | seq = np.concatenate((seq, dummy_timestamps[:, None]), axis=-1)
264 |
265 | split_locations, = penbits.nonzero()
266 | strokes = np.split(seq, split_locations + 1, axis=0)
267 |
268 | return Sketch([strk.T.tolist() for strk in strokes])
--------------------------------------------------------------------------------
/data/unpack_ndjson.py:
--------------------------------------------------------------------------------
1 | '''
2 | Unpack Quick Draw OR DiDi data to make Data loading more efficient.
3 | Otherwise full loading of '.ndjson' takes a while.
4 |
5 | For QD, "python unpack_ndjson.py --data_folder /path/to/QD/raw -c cat -o /path/to/empty/dir"
6 | For DiDi, "python unpack_ndjson.py --data_folder /path/to/DiDi -c diagrams_wo_text_20200131 -o /path/to/empty/dir"
7 | Author: Ayan Das
8 | '''
9 |
10 | import os
11 | import pickle
12 | import argparse
13 | import ndjson as nj
14 | from tqdm import tqdm
15 |
16 |
17 | def main(args):
18 | data_path = os.path.join(args.data_folder, args.category + '.ndjson')
19 | with open(data_path, 'r') as f:
20 | data = nj.load(f)
21 |
22 | out_path = os.path.join(args.out_folder, args.category)
23 |
24 | if not os.path.exists(out_path):
25 | os.makedirs(out_path)
26 |
27 | for i, sample in enumerate(tqdm(data)):
28 | out_file_path = os.path.join(out_path, f'sketch_{i}')
29 | with open(out_file_path, 'wb') as f:
30 | pickle.dump(sample, f)
31 |
32 |         if i + 1 >= args.max_sketches:  # stop after exactly max_sketches samples
33 |             break
34 |
35 |
36 | if __name__ == '__main__':
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument('--data_folder', type=str, required=True,
39 | help='QD folder of raw data (.ndjson)')
40 | parser.add_argument('-c', '--category', type=str, required=True, help='name of a category')
41 | parser.add_argument('-o', '--out_folder', type=str, required=True, help='output folder (empty)')
42 | parser.add_argument('-m', '--max_sketches', type=int, required=False, default=10000)
43 | args = parser.parse_args()
44 |
45 | main(args)
46 |
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.special import comb as choose
4 |
5 | from noise import pnoise2
6 | from shapely.geometry import LineString
7 |
8 |
9 | def resample(seq: np.ndarray, timestamps, granularity):
10 | # seq should be (N x 2) numpy array
11 | seq = LineString(seq)
12 | distances = np.linspace(0, seq.length, granularity)
13 | seq_resampled = LineString([seq.interpolate(d) for d in distances])
14 | seq_resampled = np.array([seq_resampled.xy[0], seq_resampled.xy[1]]).T
15 | ts_resampled = np.linspace(timestamps[0], timestamps[-1], granularity)
16 |
17 | return seq_resampled, ts_resampled
18 |
19 |
20 | def continuous_noise(stroke: np.ndarray, seed=0, noise_level=0.3):
21 | '''
22 | Given stroke is used as seed to generate a continuous noise-stroke
23 | and added to the original stroke; used as a part of augmentation.
24 | Implementation uses Perlin noise.
25 | '''
26 |
27 | if noise_level == 0.:
28 | return stroke
29 |
30 | noise_on_stroke = np.zeros_like(stroke)
31 | stroke_ = stroke + seed
32 | for i in range(len(stroke)):
33 | n1 = pnoise2(*stroke_[i, ...] + 5)
34 | n2 = pnoise2(*stroke_[i, ...] - 5)
35 | noise_on_stroke[i, ...] = [n1, n2]
36 |
37 | noise_on_stroke = noise_on_stroke - noise_on_stroke.mean(0)
38 | return noise_on_stroke * noise_level + stroke
39 |
40 |
41 | def discrete_noise(stroke: np.ndarray, seed=0, noise_level=0.3):
42 | '''Standard random gaussian jittering. Independently applied on each point.'''
43 |
44 | if noise_level == 0.:
45 | return stroke
46 |
47 | old_state = np.random.get_state()
48 | np.random.seed(seed)
49 | stroke = stroke + np.random.rand(*stroke.shape) * noise_level
50 | np.random.set_state(old_state)
51 | return stroke
52 |
53 |
54 | def draw_bezier(ctrlPoints, nPointsCurve=100):
55 | '''
56 | Draws a Bezier curve with given control points.
57 |
58 | ctrlPoints: shape (n+1, 2) matrix containing all control points
59 | nPointsCurve: granularity of the Bezier curve
60 | '''
61 |
62 | def bezier_matrix(degree):
63 | m = degree
64 | Q = np.zeros((degree + 1, degree + 1))
65 | for i in range(degree + 1):
66 | for j in range(degree + 1):
67 | if (0 <= (i+j)) and ((i+j) <= degree):
68 | Q[i, j] = choose(m, j) * choose(m-j, m-i-j) * ((-1)**(m-i-j))
69 | return Q
70 |
71 | def T(ts: np.ndarray, d: int):
72 | # 'ts' is a vector (np.array) of time points
73 | ts = ts[..., np.newaxis]
74 | Q = tuple(ts**n for n in range(d, -1, -1))
75 | return np.concatenate(Q, 1)
76 |
77 | nCtrlPoints, _ = ctrlPoints.shape
78 |
79 | ts = np.linspace(0., 1., num=nPointsCurve)
80 |
81 | curve = np.matmul(T(ts, nCtrlPoints - 1), bezier_matrix(nCtrlPoints-1) @ ctrlPoints)
82 |
83 | return curve
84 |
--------------------------------------------------------------------------------
/gifs/0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/0.gif
--------------------------------------------------------------------------------
/gifs/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/1.gif
--------------------------------------------------------------------------------
/gifs/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/2.gif
--------------------------------------------------------------------------------
/gifs/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/3.gif
--------------------------------------------------------------------------------
/gifs/4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/4.gif
--------------------------------------------------------------------------------
/gifs/5.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/5.gif
--------------------------------------------------------------------------------
/gifs/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/6.gif
--------------------------------------------------------------------------------
/gifs/7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/7.gif
--------------------------------------------------------------------------------
/gifs/8.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/8.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import typing
4 | import contextlib
5 | import numpy as np
6 | import matplotlib
7 | from matplotlib.cm import get_cmap
8 | matplotlib.rcParams['axes.edgecolor'] = '#aaaaaa'
9 | from enum import Enum
10 |
11 | import torch
12 | from torch_ema import ExponentialMovingAverage as EMA
13 | import pytorch_lightning as pl
14 | from pytorch_lightning.utilities.cli import LightningCLI
15 | from pytorch_lightning.callbacks import (
16 | LearningRateMonitor,
17 | TQDMProgressBar,
18 | ModelCheckpoint
19 | )
20 |
21 | from data.dm import ReprType, GenericDM
22 | from data.sketch import Sketch
23 | from models.score import (
24 | ScoreFunc,
25 | TransformerSetFeature,
26 | BiRNNEncoderFeature,
27 | ClassEmbedding
28 | )
29 | from utils import (
30 | positionalencoding1d,
31 | random_fourier_encoding_dyn,
32 | make_pad_mask_for_transformer,
33 | openai_cosine_schedule,
34 | linear_schedule,
35 | CustomViz,
36 | )
37 |
38 |
39 | class SketchDiffusion(pl.LightningModule):
40 |
41 | class ModelType(Enum):
42 | birnn = "birnn"
43 | transformer = "transformer"
44 |
45 | class SamplingAlgo(Enum):
46 | ddpm = "ddpm"
47 | ddim = "ddim"
48 | fddim = "fddim" # only for private use
49 |
50 | class NoiseSchedule(Enum):
51 | linear = "linear"
52 | cosine = "cosine"
53 |
54 | class TimeEmbedding(Enum):
55 | sinusoidal = "sinusoidal"
56 | randomfourier = "randomfourier"
57 |
58 | class VizProcess(Enum):
59 | forward = "forward"
60 | backward = "backward"
61 | both = "both"
62 |
63 | class Parameterization(Enum):
64 | mu = "mu"
65 | eps = "eps"
66 |
67 | def __init__(self,
68 | repr: ReprType = ReprType.threeseqdel,
69 | modeltype: ModelType = ModelType.transformer,
70 | time_embedding: TimeEmbedding = TimeEmbedding.sinusoidal,
71 | vae_weight: float = 0.,
72 | vae_kl_anneal_start: int = 200_000,
73 | vae_kl_anneal_end: int = 400_000,
74 | num_classes: typing.Optional[int] = None,
75 | optim_ema: bool = True,
76 | optim_sched: str = 'steplr',
77 | optim_lr: float = 1.e-4,
78 | optim_decay: float = 1.e-2,
79 | optim_gamma: float = 0.9995,
80 | optim_warmup: int = 3000,
81 | optim_interval: str = 'step',
82 | optim_div_factor: int = 3,
83 | arch_head: int = 4,
84 | arch_layer: int = 4,
85 | arch_internal: int = 64,
86 | arch_layer_cond: typing.Optional[int] = None,
87 | arch_internal_cond: typing.Optional[int] = None,
88 | arch_pe_dim: int = 2,
89 | arch_n_cond_latent: int = 32,
90 | arch_causal: bool = False,
91 | arch_dropout: float = 0.1,
92 | arch_parameterization: Parameterization = Parameterization.eps, # unused
93 | noise_low_noise: float = 1e-4,
94 | noise_high_noise: float = 2e-2,
95 | noise_schedule: NoiseSchedule = NoiseSchedule.linear,
96 | noise_T: int = 1000,
97 | test_variance_strength: float = 0.5,
98 | test_sampling_algo: SamplingAlgo = SamplingAlgo.ddpm,
99 | test_partial_T: typing.Optional[int] = None,
100 | test_recon: bool = True,
101 | test_interp: bool = False,
102 | test_n_viz: int = 10,
103 | test_n_sample_viz: int = 10,
104 | test_viz_fig_compact: bool = True,
105 |                  text_viz_process: VizProcess = VizProcess.both,  # (sic) 'text_' prefix; referenced as-is throughout
106 | test_save_everything: bool = True
107 | ) -> None:
108 |         """
109 |         Diffusion Model for Sketches (both set and sequential representations)
110 | 
111 |         Args:
112 |             repr: data representation; point-cloud variants for sets, three-point variants for sequences
113 |             arch_*: architecture params of the transformer/RNN (heads, layers, internal width, PE dims)
114 |             noise_*: noise-schedule parameters (schedule type, low/high noise variance, T)
115 |             test_*: which tests to run (reconstruction, interpolation, etc.)
116 |         """
117 |
118 | super().__init__()
119 | self.save_hyperparameters()
120 | self.hp = self.hparams
121 |
122 | self.cond = self.hp.repr in [
123 | ReprType.threeseqdel_pointcloudcond,
124 | ReprType.threeseqdel_classcond,
125 | ReprType.threeseqabs_classcond,
126 | ReprType.threeseqabs_pointcloudcond,
127 | ReprType.threeseqabs_threeseqabscond
128 | ]
129 |
130 | if self.hp.vae_weight != 0.:
131 | assert self.hp.repr.value.endswith('pointcloudcond') or self.hp.repr.value.endswith('threeseqabscond'), \
132 | "VAE only allowed in bottlenecked conditional models"
133 |
134 | self.elem_dim = 3
135 |
136 | self.pe_dim = self.hp.arch_pe_dim
137 |
138 | n_cond_dim = 0
139 | if self.cond:
140 | n_cond_dim = self.hp.arch_n_cond_latent
141 |
142 | self.seq_pe_dim = self.pe_dim if self.hp.modeltype == self.ModelType.transformer else 0
143 |
144 | if self.cond:
145 | if self.hp.repr.value.endswith('pointcloudcond'):
146 | self.encoder = TransformerSetFeature(
147 | self.hp.arch_internal_cond or self.hp.arch_internal,
148 | self.hp.arch_layer_cond or self.hp.arch_layer,
149 | self.hp.arch_head,
150 | n_cond_dim,
151 | dropout=self.hp.arch_dropout,
152 | vae_weight=self.hp.vae_weight
153 | )
154 | elif self.hp.repr == ReprType.threeseqabs_threeseqabscond:
155 | self.encoder = BiRNNEncoderFeature(
156 | self.hp.arch_internal_cond or self.hp.arch_internal,
157 | self.hp.arch_layer_cond or self.hp.arch_layer,
158 | n_cond_dim,
159 | dropout=self.hp.arch_dropout,
160 | vae_weight=self.hp.vae_weight
161 | )
162 | elif self.hp.repr == ReprType.threeseqdel_classcond or self.hp.repr == ReprType.threeseqabs_classcond:
163 |                 assert self.hp.num_classes is not None, "class-conditional model but num_classes is None"
164 | self.encoder = ClassEmbedding(self.hp.num_classes, n_cond_dim)
165 | else:
166 | raise NotImplementedError('unknown conditioning type')
167 |
168 | self.scorefn = ScoreFunc(
169 | self.hp.modeltype.value,
170 | # kwargs go here onwards
171 |             inp_n_features=self.elem_dim * 2 - 1, # velocity (3) + complementary position (2) concatenated
172 | out_n_features=self.elem_dim,
173 | time_pe_features=self.pe_dim,
174 | seq_pe_features=self.seq_pe_dim,
175 | n_cond_features=n_cond_dim,
176 | n_internal=self.hp.arch_internal,
177 | n_head=self.hp.arch_head,
178 | n_layer=self.hp.arch_layer,
179 | causal=self.hp.arch_causal,
180 | dropout=self.hp.arch_dropout
181 | )
182 | if self.hp.optim_ema:
183 | self.ema = EMA([
184 | *self.scorefn.parameters(),
185 | *(self.encoder.parameters() if self.cond else [])
186 | ], decay=0.9999)
187 |
188 | self.register_buffer("pe_proj_W",
189 | torch.randn(self.pe_dim // 2, 1, requires_grad=False), persistent=True
190 | )
191 | if self.seq_pe_dim > 0:
192 | self.register_buffer("seq_proj_W",
193 | torch.randn(self.seq_pe_dim // 2, 1, requires_grad=False), persistent=True
194 | )
195 |
196 | # pre-computing all betas and alphas
197 | schedule_generator = {
198 | SketchDiffusion.NoiseSchedule.linear: linear_schedule,
199 | SketchDiffusion.NoiseSchedule.cosine: openai_cosine_schedule
200 | }[self.hp.noise_schedule]
201 | betas, alphas, alpha_bar, sqrt_alpha_bar, sqrt_one_min_alpha_bar, beta_tilde = \
202 | schedule_generator(
203 | self.hp.noise_T,
204 | self.hp.noise_low_noise * 1000 / self.hp.noise_T,
205 | self.hp.noise_high_noise * 1000 / self.hp.noise_T,
206 | )
207 | self.register_buffer("betas", torch.from_numpy(betas), persistent=False)
208 | self.register_buffer("alphas", torch.from_numpy(alphas), persistent=False)
209 | self.register_buffer("alpha_bar", torch.from_numpy(alpha_bar), persistent=False)
210 | self.register_buffer("sqrt_alpha_bar", torch.from_numpy(sqrt_alpha_bar), persistent=False)
211 | self.register_buffer("sqrt_one_min_alpha_bar", torch.from_numpy(sqrt_one_min_alpha_bar), persistent=False)
212 | self.register_buffer("beta_tilde", torch.from_numpy(beta_tilde), persistent=False)
213 |
214 | def to(self, *args, **kwargs):
215 | ret = super().to(*args, **kwargs)
216 | if self.device.index == 0 and self.hp.optim_ema:
217 | self.ema.to(self.device)
218 | return ret
219 |
220 | def on_fit_start(self) -> None:
221 | self.on_test_start() # needed for testing while training
222 |
223 | def on_before_zero_grad(self, optimizer) -> None:
224 | if self.device.index == 0 and self.hp.optim_ema:
225 | self.ema.update([
226 | *self.scorefn.parameters(),
227 | *(self.encoder.parameters() if self.cond else [])
228 | ])
229 |
230 | def on_save_checkpoint(self, checkpoint: dict) -> None:
231 | if self.hp.optim_ema:
232 | checkpoint["ema"] = self.ema.state_dict()
233 |
234 | def on_load_checkpoint(self, checkpoint) -> None:
235 | if self.hp.optim_ema:
236 | self.ema.load_state_dict(checkpoint["ema"])
237 |
238 | @contextlib.contextmanager
239 | def ema_average(self, activate=True):
240 | if activate:
241 | with self.ema.average_parameters() as ctx:
242 | yield ctx
243 | else:
244 | with contextlib.nullcontext() as ctx:
245 | yield ctx
246 |
247 | def stdg_noise_seeded(self, *dims, seed: typing.Optional[int] = None):
248 | if seed is not None:
249 | _rngstate = torch.get_rng_state()
250 | torch.manual_seed(seed)
251 | _tmp = torch.randn(*dims, device=self.device)
252 | if seed is not None:
253 | torch.set_rng_state(_rngstate)
254 | return _tmp
255 |
256 | def create_batch_with_utilities(self, padded_seq, lens, seed=None):
257 | # padded_seq: (BxTxF) shape
258 | # lens: (B,) shaped long tensor to denote original length of each sample
259 | batch_size, = lens.shape
260 | padded_seq, timestamps = padded_seq[..., :self.elem_dim], padded_seq[..., self.elem_dim:]
261 |
262 |         batch = {} # Keys: timestamps, lens, noise_t, t, noisy_points, target
263 |
264 |         # a different 't' for each sample in the batch
265 | t = torch.randint(1, self.hp.noise_T + 1, size=(batch_size, ))
266 |
267 | g_noise = self.stdg_noise_seeded(*padded_seq.shape, seed=seed)
268 |
269 | batch['timestamps'] = timestamps
270 | batch['lens'] = lens
271 | batch['noise_t'] = self.pe[t - 1, :]
272 | batch['t'] = t - 1
273 | batch['noisy_points'] = padded_seq * self.sqrt_alpha_bar[t - 1, None, None] \
274 | + g_noise * self.sqrt_one_min_alpha_bar[t - 1, None, None]
275 | batch['target'] = g_noise
276 |
277 | return batch
278 |
279 | def ncsn_loss(self, score, noise_target, lens, t):
280 | pad_mask = make_pad_mask_for_transformer(lens, total_length=score.shape[1], device=lens.device)
281 | unreduced_loss = (score - noise_target).pow(2).mean(-1)
282 | masked_loss = (unreduced_loss * (~pad_mask).float()) / lens.unsqueeze(-1)
283 | per_sample_loss = masked_loss.sum(-1) # sum along length since already divided by lengths
284 | return per_sample_loss.mean()
285 |
286 | def configure_optimizers(self):
287 | optim = torch.optim.AdamW(self.parameters(),
288 | lr=self.hp.optim_lr,
289 | weight_decay=self.hp.optim_decay)
290 | if self.hp.optim_sched == 'steplr':
291 | sched = torch.optim.lr_scheduler.StepLR(optim,
292 | step_size=1,
293 | gamma=self.hp.optim_gamma
294 | )
295 | elif self.hp.optim_sched == 'onecyclelr':
296 | steps_per_epoch = len(self.trainer.datamodule.train_dataset) \
297 | // self.trainer.datamodule.batch_size
298 | total_epochs = self.trainer.max_epochs
299 | total_steps = steps_per_epoch * total_epochs
300 | total = total_epochs if self.hp.optim_interval == 'epoch' else total_steps
301 | warmup_fraction = self.hp.optim_warmup / total
302 | sched = torch.optim.lr_scheduler.OneCycleLR(optim,
303 | max_lr=self.hp.optim_lr,
304 | total_steps=total,
305 | anneal_strategy='linear',
306 | cycle_momentum=True,
307 | pct_start=warmup_fraction,
308 | div_factor=self.hp.optim_div_factor,
309 | final_div_factor=1000
310 | )
311 | else:
312 | raise NotImplementedError('scheduler not known/implemented')
313 |
314 | return {
315 | 'optimizer': optim,
316 | 'lr_scheduler': {
317 | 'scheduler': sched,
318 | 'frequency': 1,
319 | 'interval': self.hp.optim_interval
320 | }
321 | }
322 |
323 | def create_posvel_aug_input(self, points):
324 | if self.hp.repr.value.startswith('threeseqdel'):
325 | points_vel = points
326 | points_pos = torch.cumsum(points[..., :-1], dim=1)
327 | elif self.hp.repr.value.startswith('threeseqabs'):
328 | points_vel = torch.cat([
329 | points[:, 0, None, :-1],
330 | (points[:, 1:, :-1] - points[:, :-1, :-1])
331 | ], 1)
332 | points_pos = points
333 | else:
334 | raise NotImplementedError('ReprType not implemented')
335 |
336 | return points_pos, points_vel
337 |
338 | def forward(self, noisy_points, seq_pe, lens, noise_t, cond_latent):
339 | noisy_points_pos, noisy_points_vel = self.create_posvel_aug_input(noisy_points)
340 |
341 | if self.hp.modeltype == SketchDiffusion.ModelType.transformer:
342 | origin = torch.zeros(noisy_points.size(0), 1, 3, dtype=self.dtype, device=self.device, requires_grad=False)
343 | noisy_points_pos = torch.cat([origin[..., :noisy_points_pos.shape[-1]], noisy_points_pos], 1)
344 | noisy_points_vel = torch.cat([origin[..., :noisy_points_vel.shape[-1]], noisy_points_vel], 1)
345 | seq_pe = torch.cat([self._create_seq_embeddings(origin[..., :1]), seq_pe], dim=1) # add origin timestamp
346 |             lens = lens + 1 # due to the added origin token
347 |
348 | with self.ema_average(not self.training and self.hp.optim_ema):
349 | out = self.scorefn((noisy_points_pos, noisy_points_vel), seq_pe, lens, noise_t, cond_latent)
350 |
351 | return out
352 |
353 | def _create_seq_embeddings(self, timestamps):
354 | if self.seq_pe_dim > 0:
355 | batch_size, max_len, _ = timestamps.shape
356 | timestamps = timestamps.permute(2, 0, 1)
357 | temb = random_fourier_encoding_dyn(timestamps.view(1, batch_size * max_len), self.seq_proj_W, scale=4.)
358 | return temb.view(batch_size, max_len, self.seq_pe_dim)
359 | else:
360 | return None
361 |
362 | def encode(self, *args):
363 | if self.cond:
364 | with self.ema_average(not self.training and self.hp.optim_ema):
365 | return self.encoder(*args)
366 | else:
367 | return None, 0.
368 |
369 | def training_step(self, batch, batch_idx):
370 | cond_batch, batch = batch
371 |
372 | batch = self.create_batch_with_utilities(*batch)
373 | cond_latent, kl_loss = self.encode(cond_batch)
374 | score = self(batch['noisy_points'], self._create_seq_embeddings(batch['timestamps']),
375 | batch['lens'], batch['noise_t'], cond_latent)
376 | loss = self.ncsn_loss(score, batch['target'], batch['lens'], batch['t'])
377 | self.log('train/loss', loss, prog_bar=True)
378 | if self.hp.vae_weight != 0.:
379 | kl_loss = kl_loss.mean()
380 | self.log('train/kl', kl_loss, prog_bar=False)
381 | kl_annealing_factor = min(max(self.global_step - self.hp.vae_kl_anneal_start, 0.) / \
382 | (self.hp.vae_kl_anneal_end - self.hp.vae_kl_anneal_start), 1.)
383 | self.log('train/kl_factor', kl_annealing_factor, prog_bar=False)
384 | else:
385 | kl_annealing_factor = 0.
386 | return loss + \
387 | self.hp.vae_weight * kl_annealing_factor * kl_loss
388 |
389 | def validation_step(self, batch, batch_idx):
390 | loss = self.training_step(batch, batch_idx)
391 |
392 | # on-the-fly testing while training
393 |         if batch_idx == 0 and self.current_epoch % 300 == 0 and self.device.index == 0:
394 | save_file_path = os.path.join(self.trainer.log_dir,
395 |                 "ddpm1.pdf")
396 | ret_dict = self.reconstruction(batch, SketchDiffusion.SamplingAlgo.ddpm, langevin_strength=1.)
397 | self.fig.savefig(save_file_path, bbox_inches='tight')
398 | self.cache_reverse_process(ret_dict["all"], -1, ret_dict["lens"], idx=batch_idx, prefix='ddpm1')
399 |
400 | save_file_path = os.path.join(self.trainer.log_dir,
401 |                 "ddpm.5.pdf")
402 | ret_dict = self.reconstruction(batch, SketchDiffusion.SamplingAlgo.ddpm, langevin_strength=0.5)
403 | self.fig.savefig(save_file_path, bbox_inches='tight')
404 | self.cache_reverse_process(ret_dict["all"], -1, ret_dict["lens"], idx=batch_idx, prefix='ddpm.5')
405 |
406 | save_file_path = os.path.join(self.trainer.log_dir,
407 |                 "ddim_reco.pdf")
408 | ret_dict = self.reconstruction(batch, SketchDiffusion.SamplingAlgo.ddim, langevin_strength=0.)
409 | self.fig.savefig(save_file_path, bbox_inches='tight')
410 | self.cache_reverse_process(ret_dict["all"], -1, ret_dict["lens"], idx=batch_idx, prefix='ddim_reco')
411 |
412 | save_file_path = os.path.join(self.trainer.log_dir,
413 |                 "ddim_gen.pdf")
414 | ret_dict = self.reconstruction(batch, SketchDiffusion.SamplingAlgo.ddim, langevin_strength=0., generation=True)
415 | self.fig.savefig(save_file_path, bbox_inches='tight')
416 | self.cache_reverse_process(ret_dict["all"], -1, ret_dict["lens"], idx=batch_idx, prefix='ddim_gen')
417 |
418 | return loss
419 |
420 | def validation_epoch_end(self, losses_for_batches) -> None:
421 | valid_loss = sum(losses_for_batches) / len(losses_for_batches)
422 | self.log('valid/loss', valid_loss, prog_bar=True)
423 |
424 | def on_test_start(self) -> None:
425 | ts = torch.linspace(1, self.hp.noise_T, self.hp.noise_T,
426 | dtype=self.dtype, device=self.device) / self.hp.noise_T
427 | self.pe = random_fourier_encoding_dyn(ts[None, ...], self.pe_proj_W, scale=4.) \
428 | if self.hp.time_embedding == SketchDiffusion.TimeEmbedding.randomfourier else \
429 | positionalencoding1d(self.pe_dim, self.hp.noise_T, N=self.hp.noise_T,
430 | dtype=self.dtype, device=self.device)
431 |
432 | n_viz = self.hp.test_n_viz * 2 if self.hp.text_viz_process == SketchDiffusion.VizProcess.both else self.hp.test_n_viz
433 | cviz = CustomViz(self.hp.test_n_sample_viz, n_viz, compact_mode=self.hp.test_viz_fig_compact)
434 | self.fig, self.ax = cviz, cviz
435 |
436 | def cache_reverse_process(self, all_points_t, t, lens, idx, prefix='gen'):
437 | # npz_save_path = os.path.join(self.trainer.log_dir, f'{prefix}_rev_{idx}.npz')
438 | # with open(npz_save_path, 'wb') as f:
439 | # np.savez(f, reverse=all_points_t.cpu().numpy(), lens=lens.cpu().numpy())
440 | samples = all_points_t[t, ...]
441 | samples = torch.split(samples, self.ax.shape[0], dim=0)
442 | lens = torch.split(lens, self.ax.shape[0], dim=0)
443 | for j in range(self.ax.shape[1]):
444 | try:
445 | self.draw_on_seq(samples[j], lens[j], j)
446 |             except Exception:  # blank out the column if drawing fails
447 | for i in range(self.ax.shape[0]):
448 | self.ax[i, j].cla()
449 | self.ax[i, j].axis('off')
450 | save_file_path = os.path.join(self.trainer.log_dir, f'{prefix}_{idx}.svg')
451 | self.fig.savefig(save_file_path, bbox_inches='tight')
452 |
453 | def test_step(self, batch, batch_idx):
454 | if self.hp.test_recon:
455 | save_file_path = os.path.join(self.trainer.log_dir, f'diff_{batch_idx}.svg')
456 | rev_dict = self.reconstruction(batch, self.hp.test_sampling_algo, self.hp.test_variance_strength,
457 | generation=True, partial_t=self.hp.test_partial_T)
458 | self.fig.savefig(save_file_path, bbox_inches='tight')
459 | if self.hp.test_save_everything:
460 | _, (vels, lens) = batch
461 | vels, ts = vels[..., :self.elem_dim], vels[..., self.elem_dim:]
462 | orig, orig_len = self.velocity_to_position(vels, lens)
463 | # self.cache_reverse_process(orig[None, ...], -1, orig_len, idx=batch_idx, prefix='orig')
464 |                 self.cache_reverse_process(rev_dict["all"], -1, rev_dict["lens"], idx=batch_idx, prefix='gen')
465 |
466 | if self.hp.test_interp:
467 | save_file_path = os.path.join(self.trainer.log_dir, f'interp_{batch_idx}.svg')
468 | _ = self.interpolation(batch, self.hp.test_sampling_algo, langevin_strength=0.)
469 | self.fig.savefig(save_file_path, bbox_inches='tight')
470 |
471 | def velocity_to_position(self, points, lens):
472 | B, _, _ = points.shape
473 |
474 | points = torch.cat([
475 | torch.zeros(B, 1, self.elem_dim, dtype=points.dtype, device=points.device),
476 | points
477 | ], dim=1)
478 | lens = lens + 1 # there is the extra initial point along length
479 |
480 | if self.hp.repr.value.startswith('threeseqdel'):
481 | # last one is pen-up bit -- leave it as is
482 | points[..., :-1] = torch.cumsum(points[..., :-1], dim=1)
483 | else:
484 | # this incorporates THREESEQABS
485 | pass
486 |
487 | points[..., -1][points[..., -1] > 0.8] = 1.
488 | points[..., -1][points[..., -1] < 0.8] = 0.
489 |
490 | return points, lens
491 |
492 | def draw_on_seq(self, points, lens, t_):
493 | points = points.detach().cpu().numpy()
494 | lens = lens.cpu().numpy()
495 |
496 | cm = get_cmap('copper') # I like this one
497 | for b in range(self.hp.test_n_sample_viz):
498 | sample_seq: Sketch = Sketch.from_threeseqabs(points[b, :lens[b], :])
499 | sample_seq.draw(self.ax[b, t_], color=cm, cla=True, scatter=False)
500 |
501 | def forward_diffusion(self, velocs, lens, draw=True, end_t=None):
502 | viz_t = np.linspace(0, end_t or self.hp.noise_T, self.hp.test_n_viz, dtype=np.int64)
503 |
504 | if draw: # the original sample
505 | points, points_len = self.velocity_to_position(velocs, lens)
506 | self.draw_on_seq(points, points_len, self.t_)
507 | self.t_ += 1
508 |
509 | for t in viz_t[1:]:
510 | g_noise = self.stdg_noise_seeded(*velocs.shape)
511 |
512 | velocs_t = velocs * self.sqrt_alpha_bar[t - 1, None, None] \
513 | + g_noise * self.sqrt_one_min_alpha_bar[t - 1, None, None]
514 |
515 | if draw:
516 | points_t, points_len = self.velocity_to_position(velocs_t, lens)
517 | self.draw_on_seq(points_t, points_len, self.t_)
518 | self.t_ += 1
519 |
520 | return velocs_t
521 |
522 |     def reverse_perturb_DDPM(self, points, timestamps, t, lens, cond_latent, steps, noise_weight=1.):
523 | now, now_index = steps[t], steps[t] - 1
524 |
525 | score = self(points, timestamps, lens, self.pe[now_index, :].repeat(points.shape[0], 1), cond_latent)
526 | k1 = 1. / torch.sqrt(self.alphas[now_index])
527 | k2 = (1. - self.alphas[now_index]) / self.sqrt_one_min_alpha_bar[now_index]
528 | mean = k1 * (points - k2 * score)
529 |
530 | gen_noise = self.stdg_noise_seeded(*points.shape) * torch.sqrt(self.beta_tilde[now_index]) \
531 | if now > 1 else 0.
532 |
533 | points = mean + gen_noise * noise_weight
534 | return points
535 |
536 |     def reverse_perturb_DDIM(self, points, timestamps, t, lens, cond_latent, steps, noise_weight=0.):
537 | now, now_index = steps[t], steps[t] - 1
538 |
539 | score = self(points, timestamps, lens, self.pe[now_index, :].repeat(points.shape[0], 1), cond_latent)
540 | x0_pred = (points - self.sqrt_one_min_alpha_bar[now_index] * score) \
541 | / self.sqrt_alpha_bar[now_index]
542 |
543 | if now > 1:
544 | prev, prev_index = steps[t + 1], steps[t + 1] - 1
545 |
546 | # generalized version of DDIM sampler, with explicit \sigma_t
547 | s1 = self.sqrt_one_min_alpha_bar[prev_index] / self.sqrt_one_min_alpha_bar[now_index]
548 | s2 = torch.sqrt(1. - self.alpha_bar[now_index] / self.alpha_bar[prev_index])
549 | sigma = (s1 * s2) * noise_weight # additional control for the noise
550 |
551 | gen_noise = self.stdg_noise_seeded(*points.shape)
552 |
553 | points = self.sqrt_alpha_bar[prev_index] * x0_pred \
554 | + torch.sqrt(1. - self.alpha_bar[prev_index] - sigma**2) * score \
555 | + gen_noise * sigma
556 | else:
557 | points = x0_pred
558 |
559 | return points
560 |
561 |     def forward_perturb_DDIM(self, points, timestamps, t, lens, cond_latent, steps, noise_weight=1.):
562 | # DDIM's reverse of the reverse process -- integrating the ODE backwards
563 | now, now_index = steps[t], steps[t] - 1
564 | prev, prev_index = steps[t] - 1, steps[t] - 2
565 |
566 | score = self(points, timestamps, lens, self.pe[prev_index, :].repeat(points.shape[0], 1), cond_latent) \
567 | if prev != 0 else 0.
568 |
569 | xT_pred = (points - self.sqrt_one_min_alpha_bar[prev_index] * score) \
570 | / (self.sqrt_alpha_bar[prev_index] if prev != 0 else 1.)
571 |
572 | points = self.sqrt_alpha_bar[now_index] * xT_pred + self.sqrt_one_min_alpha_bar[now_index] * score
573 | return points
574 |
575 | def reverse_diffusion(self, points, timestamps, lens, cond_latent, sampling_algo, langevin_strength, draw=True, start_t=None):
576 | veloc_t = points
577 |
578 | if start_t is not None:
579 |             assert sampling_algo == SketchDiffusion.SamplingAlgo.ddpm, \
580 |                 'partially stopping diffusion makes sense only for the stochastic (DDPM) sampler'
581 |             assert start_t <= self.hp.noise_T, f"partial stopping time must not exceed T={self.hp.noise_T}"
582 |
583 | inference_steps, sampling_fn = {
584 | SketchDiffusion.SamplingAlgo.ddpm: (
585 | np.linspace(start_t or self.hp.noise_T, 1, start_t or self.hp.noise_T, dtype=np.int64),
586 |                 SketchDiffusion.reverse_perturb_DDPM
587 | ),
588 | SketchDiffusion.SamplingAlgo.ddim: (
589 | np.linspace(self.hp.noise_T, 1, self.hp.noise_T, dtype=np.int64),
590 |                 SketchDiffusion.reverse_perturb_DDIM
591 | ),
592 | SketchDiffusion.SamplingAlgo.fddim: (
593 | np.linspace(1, self.hp.noise_T, self.hp.noise_T, dtype=np.int64),
594 |                 SketchDiffusion.forward_perturb_DDIM
595 | )
596 | }[sampling_algo]
597 |
598 | viz_t = np.linspace(self.hp.noise_T, 1, self.hp.test_n_viz, dtype=np.int64)
599 |
600 | points_t_all_steps = []
601 | for t in range(inference_steps.shape[0]):
602 | veloc_t = sampling_fn(self, veloc_t, timestamps, t, lens, cond_latent,
603 | inference_steps, noise_weight=langevin_strength)
604 | points_t, points_len = self.velocity_to_position(veloc_t, lens)
605 | if inference_steps[t] in viz_t:
606 | if draw:
607 | self.draw_on_seq(points_t, points_len, self.t_)
608 | self.t_ += 1
609 |
610 | if self.hp.test_save_everything:
611 | points_t_all_steps.append(points_t)
612 |
613 | return {
614 | "orig_last": veloc_t,
615 | "last": points_t,
616 |             "all": torch.stack(points_t_all_steps, 0) if self.hp.test_save_everything else [],
617 | "lens": points_len
618 | }
619 |
620 | def reconstruction(self, batch, sampling_algo, langevin_strength, generation=False, partial_t=None):
621 |         assert sampling_algo != SketchDiffusion.SamplingAlgo.fddim, "FDDIM is not to be used via the public API"
622 |
623 | self.t_ = 0
624 | cond_batch, (points, lens) = batch
625 |
626 | cond_latent, _ = self.encode(cond_batch)
627 | points, timestamps = points[..., :self.elem_dim], points[..., self.elem_dim:]
628 |
629 | if sampling_algo != SketchDiffusion.SamplingAlgo.ddim:
630 | diffused = self.forward_diffusion(points, lens,
631 | draw=self.hp.text_viz_process == SketchDiffusion.VizProcess.forward \
632 | or self.hp.text_viz_process == SketchDiffusion.VizProcess.both,
633 | end_t=partial_t)
634 |
635 | if partial_t is None:
636 | perm = torch.randperm(lens.size(0))
637 | lens = lens[perm] # reset lengths
638 | diffused = torch.randn_like(diffused)
639 | else:
640 | # execute forward DDIM (feature extraction)
641 | diffused = self.reverse_diffusion(points, self._create_seq_embeddings(timestamps), lens, cond_latent,
642 | SketchDiffusion.SamplingAlgo.fddim, langevin_strength,
643 | draw=self.hp.text_viz_process == SketchDiffusion.VizProcess.forward \
644 | or self.hp.text_viz_process == SketchDiffusion.VizProcess.both)
645 | diffused = diffused["orig_last"]
646 | if generation:
647 | diffused = torch.randn_like(diffused)
648 |
649 | rev_dict = self.reverse_diffusion(diffused, self._create_seq_embeddings(timestamps), lens, cond_latent,
650 | sampling_algo, langevin_strength,
651 | draw=self.hp.text_viz_process == SketchDiffusion.VizProcess.backward \
652 | or self.hp.text_viz_process == SketchDiffusion.VizProcess.both, start_t=partial_t)
653 | return rev_dict
654 |
655 | def interpolation(self, batch, sampling_algo, langevin_strength=0.):
656 |         assert sampling_algo != SketchDiffusion.SamplingAlgo.fddim, "FDDIM is not to be used via the public API"
657 |
658 | cond_batch1, (points1, lens1) = batch # samples not really needed, only lens
659 |
660 | # random shuffle before executing generation
661 | perm = torch.randperm(points1.shape[0], device=points1.device)
662 | points2, lens2 = points1[perm, ...], lens1[perm]
663 |
664 | cond_latent1, _ = self.encode(cond_batch1)
665 | cond_latent2 = cond_latent1[perm, ...] if self.cond else None
666 |
667 | points1, timestamps1 = points1[..., :self.elem_dim], points1[..., self.elem_dim:]
668 | points2, timestamps2 = points2[..., :self.elem_dim], points2[..., self.elem_dim:]
669 |
670 | prior1 = torch.randn_like(points1)
671 | prior2 = torch.randn_like(points2)
672 |
673 | for a_, alpha in enumerate(np.linspace(0., 1., self.ax.shape[1])):
674 | if not self.cond:
675 | prior = prior1 * (1. - alpha) + prior2 * alpha
676 | lens = lens1
677 | cond_latent = None
678 | else:
679 | prior = prior1
680 | lens = lens1
681 | cond_latent = cond_latent1 * (1. - alpha) + cond_latent2 * alpha
682 |
683 | if self.hp.modeltype == SketchDiffusion.ModelType.transformer:
684 | raise NotImplementedError('interpolation with transformer model not yet implemented')
685 |
686 | recon_dict = self.reverse_diffusion(prior, None, lens, cond_latent,
687 | sampling_algo, langevin_strength=0., draw=False)
688 | self.draw_on_seq(recon_dict["last"], recon_dict["lens"], a_)
689 |
690 |
691 | if __name__ == '__main__':
692 | cli = LightningCLI(SketchDiffusion, GenericDM, run=True,
693 | subclass_mode_data=True,
694 | parser_kwargs={"parser_mode": "omegaconf"},
695 | trainer_defaults={
696 | 'callbacks': [
697 | LearningRateMonitor(logging_interval='step'),
698 | ModelCheckpoint(monitor='valid/loss', filename='model', save_last=True),
699 | TQDMProgressBar(refresh_rate=1 if sys.stdin.isatty() else 0)
700 | ]
701 | })
702 |
--------------------------------------------------------------------------------
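
For orientation: the update implemented in `reverse_perturb_DDPM` above is the standard DDPM ancestral step, x_{t-1} = (x_t - ((1 - alpha_t) / sqrt(1 - alpha_bar_t)) * eps_theta) / sqrt(alpha_t) + sigma_t * z, with sigma_t^2 = beta_tilde_t and the `noise_weight` argument (`langevin_strength` at the call sites) scaling the stochastic term. Below is a minimal self-contained sketch of that loop with toy shapes and a dummy eps-predictor; it is an illustration of the update rule, not the repository's sampling API (which goes through `reverse_diffusion`):

    import numpy as np
    import torch

    # toy linear schedule, mirroring utils.linear_schedule with T=1000
    T = 1000
    betas = np.linspace(1e-4, 2e-2, T, dtype=np.float32)
    alphas = 1. - betas
    alpha_bar = np.cumprod(alphas, 0)
    sqrt_one_min_alpha_bar = np.sqrt(1. - alpha_bar)

    def ddpm_step(x_t, eps_pred, t, noise_weight=1.):
        # posterior mean under the eps-parameterization
        k1 = 1. / np.sqrt(alphas[t])
        k2 = (1. - alphas[t]) / sqrt_one_min_alpha_bar[t]
        mean = k1 * (x_t - k2 * eps_pred)
        # main.py uses beta_tilde for the variance; beta_t is the common simplification
        z = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
        return mean + np.sqrt(betas[t]) * z * noise_weight

    x = torch.randn(4, 16, 3)            # start from the prior: (batch, seq_len, features)
    for t in reversed(range(T)):
        eps_pred = torch.zeros_like(x)   # stand-in for the trained score network
        x = ddpm_step(x, eps_pred, t)

Note also the `pytorch-lightning==1.5.9` pin in requirements.txt: this file relies on `pytorch_lightning.utilities.cli.LightningCLI` and the `validation_epoch_end` hook, both of which were moved or removed in later releases. With `parser_mode="omegaconf"` and the bundled config.yml, training is presumably launched as `python main.py fit --config config.yml`.
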
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | @author : Hyunwoong
3 | @when : 2019-10-22
4 | @homepage : https://github.com/gusdnd852
5 | """
--------------------------------------------------------------------------------
/models/score.py:
--------------------------------------------------------------------------------
1 | import typing
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
6 |
7 | from utils import make_pad_mask_for_transformer
8 |
9 |
10 | class ConditionalTransformerEncoder(nn.Module):
11 |
12 | def __init__(self, n_input, n_internal, n_layers, n_head, causal=False, dropout=0.) -> None:
13 | super().__init__()
14 |
15 | self.n_input = n_input
16 | self.n_internal = n_internal
17 | self.n_layers = n_layers
18 | self.n_head = n_head
19 | self.causal = causal
20 | self.dropout = dropout
21 |
22 | self.embedder = nn.Linear(self.n_input, self.n_internal)
23 |
24 | self.transformer = nn.TransformerEncoder(
25 | nn.TransformerEncoderLayer(
26 | self.n_internal,
27 | self.n_head,
28 | dim_feedforward=self.n_internal * 2,
29 | batch_first=True, dropout=self.dropout, activation=F.silu
30 | ),
31 | num_layers=self.n_layers
32 | )
33 |
34 | def forward(self, noisy, lens):
35 | _, max_len, _ = noisy.shape
36 | len_padd_mask = make_pad_mask_for_transformer(lens, max_len, noisy.device)
37 |
38 | if self.causal:
39 | I = torch.eye(max_len, dtype=noisy.dtype, device=noisy.device)
40 |             attn_mask = (torch.cumsum(I, -1) - I) == 1.  # True strictly above the diagonal: future positions are masked
41 | else:
42 | attn_mask = None
43 |
44 | input_emb = self.embedder(noisy)
45 | output = self.transformer(input_emb, mask=attn_mask, src_key_padding_mask=len_padd_mask)
46 |
47 | return output
48 |
49 |
50 | class ConditionalBiRNN(nn.Module):
51 |
52 | def __init__(self, n_input, n_hidden, n_layers, dropout=0., causal=False) -> None:
53 | super().__init__()
54 |
55 | self.n_input = n_input
56 | self.n_hidden = n_hidden
57 | self.n_layers = n_layers
58 | self.dropout = dropout
59 | self.causal = causal
60 |
61 | self.rnn = nn.GRU(self.n_input, self.n_hidden, self.n_layers,
62 | batch_first=True,
63 | dropout=self.dropout,
64 | bidirectional=not self.causal)
65 |
66 | directionality = 2 if not self.causal else 1
67 | self.out_proj = nn.Linear(
68 | self.n_hidden * directionality,
69 | self.n_hidden
70 | )
71 |
72 | def forward(self, noisy, lens):
73 | noisy_packed = pack_padded_sequence(noisy, lens.cpu(), batch_first=True, enforce_sorted=False)
74 | hid, _ = self.rnn(noisy_packed)
75 | out_unpacked, _ = pad_packed_sequence(hid, batch_first=True)
76 |
77 | return self.out_proj(out_unpacked)
78 |
79 |
80 | class ScoreFunc(nn.Module):
81 |
82 | def __init__(self, modeltype, *, inp_n_features=5, out_n_features=3, time_pe_features=2, seq_pe_features=2,
83 | n_cond_features=0, n_head=4, n_layer=4, n_internal=64, causal=False, dropout=0.) -> None:
84 | super().__init__()
85 |
86 | self.modeltype = modeltype
87 | self.inp_n_features = inp_n_features
88 | self.out_n_features = out_n_features
89 | self.time_pe_features = time_pe_features # for diffusion steps
90 | self.seq_pe_features = seq_pe_features # for sequence time-stamps
91 | self.n_cond_features = n_cond_features # for conditioning
92 | self.n_internal = n_internal
93 | self.n_head = n_head
94 | self.n_layer = n_layer
95 | self.causal = causal
96 | self.dropout = dropout
97 |
98 | self.n_additionals = self.time_pe_features + self.seq_pe_features + self.n_cond_features
99 | self.n_total_features = self.inp_n_features + self.n_additionals
100 |
101 | if self.modeltype == 'birnn':
102 | self.model = ConditionalBiRNN(self.n_total_features, self.n_internal, self.n_layer,
103 | dropout=self.dropout, causal=self.causal)
104 | elif self.modeltype == 'transformer':
105 | self.model = ConditionalTransformerEncoder(self.n_total_features,
106 | self.n_internal, self.n_layer, self.n_head, causal=self.causal, dropout=self.dropout)
107 | else:
108 |             raise NotImplementedError(f"Unknown model type {self.modeltype}")  # modeltype is a plain str here
109 |
110 | self.final_proj = nn.Sequential(
111 | nn.Linear(self.n_internal * (2 if self.modeltype == 'transformer' else 1) \
112 | + self.n_additionals - self.seq_pe_features, self.out_n_features),
113 | )
114 |
115 | def forward(self, noisy, seq_pe, lens, time_pe, cond=None):
116 | noisy_pos, noisy_vel = noisy
117 | noisy = torch.cat([noisy_pos, noisy_vel], -1)
118 |
119 | if isinstance(cond, tuple):
120 | # This is 'threeseqabs_threeseqabseqsampledcond' repr.
121 | # But not a good way to check (TODO: better API)
122 | cond = torch.cat(cond, -1)
123 |
124 | batch_size, max_len, _ = noisy.shape
125 |
126 | time_pe = time_pe.unsqueeze(1).repeat(1, max_len, 1)
127 |
128 | if cond is not None:
129 | assert self.n_cond_features != 0, "conditioning is being done but no dimension allocated"
130 | if len(cond.shape) == 2:
131 | cond = cond.unsqueeze(1).repeat(1, max_len, 1)
132 | time_cond = torch.cat([time_pe, cond], -1)
133 | else:
134 | time_cond = time_pe
135 |
136 | if self.seq_pe_features > 0:
137 | additionals = torch.cat([seq_pe, time_cond], -1)
138 | else:
139 | additionals = time_cond
140 |
141 | output = self.model(
142 | torch.cat([noisy, additionals], -1),
143 | lens
144 | )
145 |
146 | if self.modeltype == 'birnn':
147 | return self.final_proj(torch.cat([output, time_cond], -1))
148 | else:
149 |             conseq_cat_output = torch.cat([output[:, :-1, :], output[:, 1:, :]], -1)
150 | return self.final_proj(torch.cat([conseq_cat_output, time_cond[:, 1:, :]], -1))
151 |
152 |
153 | class TransformerSetFeature(ConditionalTransformerEncoder):
154 |
155 | def __init__(self, n_internal, n_layers, n_head, n_latent, dropout=0., vae_weight=0.) -> None:
156 |         # '+1' is for the extra feature denoting the feature-extractor token
157 | super().__init__(2 + 1, n_internal, n_layers, n_head, causal=False, dropout=dropout)
158 | self.n_latent = n_latent
159 | self.vae_weight = vae_weight
160 |
161 | if self.vae_weight == 0.:
162 | self.latent_proj = nn.Sequential(
163 | nn.Linear(n_internal, self.n_latent),
164 | nn.Tanh()
165 | )
166 | else:
167 | self.latent_proj_mean = nn.Sequential(nn.Linear(n_internal, self.n_latent))
168 | self.latent_proj_logvar = nn.Sequential(nn.Linear(n_internal, self.n_latent))
169 |
170 | def forward(self, cond_batch):
171 | set_input, lens = cond_batch
172 | B, L, _ = set_input.shape
173 | # creating an extra feature extractor token
174 | pad_token = torch.zeros(B, L, 1, device=set_input.device, dtype=set_input.dtype)
175 | feat_token = torch.tensor([0., 0., 1.], device=set_input.device, dtype=set_input.dtype)
176 | feat_token = feat_token[None, None, :].repeat(B, 1, 1)
177 | set_input = torch.cat([set_input, pad_token], -1)
178 | set_input = torch.cat([feat_token, set_input], 1)
179 | lens = lens + 1 # extra token for feature extraction
180 |
181 | trans_out = super().forward(set_input, lens)
182 |
183 | if self.vae_weight == 0.:
184 | return self.latent_proj(trans_out[:, 0]), 0.
185 | else:
186 | mu = self.latent_proj_mean(trans_out[:, 0])
187 | logvar = self.latent_proj_logvar(trans_out[:, 0])
188 | posterior = torch.distributions.Normal(mu, torch.exp(0.5 * logvar))
189 | prior = torch.distributions.Normal(
190 | torch.zeros_like(mu),
191 | torch.ones_like(logvar)
192 | )
193 | return posterior.rsample(), torch.distributions.kl_divergence(posterior, prior)
194 |
195 |
196 | class BiRNNEncoderFeature(ConditionalBiRNN):
197 |
198 | def __init__(self, n_hidden, n_layers, n_latent, dropout=0., vae_weight=0.) -> None:
199 | super().__init__(3, n_hidden, n_layers, dropout)
200 | self.out_proj = nn.Identity()
201 | self.vae_weight = vae_weight
202 |
203 | self.n_latent = n_latent
204 |
205 | if self.vae_weight == 0.:
206 | self.latent_proj = nn.Sequential(
207 | nn.Linear(2 * self.n_hidden, self.n_latent),
208 | nn.Tanh()
209 | )
210 | else:
211 |             self.latent_proj_mean = nn.Sequential(nn.Linear(2 * self.n_hidden, self.n_latent))
212 |             self.latent_proj_logvar = nn.Sequential(nn.Linear(2 * self.n_hidden, self.n_latent))
213 |
214 | def forward(self, cond_batch):
215 | batch, lens = cond_batch
216 | batch_size, max_len, _ = batch.shape
217 | batch = batch[..., :-1] # exclude the timestamps
218 | out = super().forward(batch, lens).view(batch_size, max_len, 2, self.n_hidden)
219 | out_fwd, out_bwd = out[:, :, 0, :], out[:, :, 1, :]
220 | fwd_feat = torch.gather(
221 | out_fwd,
222 | 1,
223 | lens[:, None, None].repeat(1, 1, self.n_hidden) - 1
224 |         ).squeeze(1)  # keep the batch dim even when B == 1
225 | bwd_feat = out_bwd[:, 0, :]
226 |
227 | if self.vae_weight == 0.:
228 | return self.latent_proj(torch.cat([fwd_feat, bwd_feat], -1)), 0.
229 | else:
230 | mu = self.latent_proj_mean(torch.cat([fwd_feat, bwd_feat], -1))
231 | logvar = self.latent_proj_logvar(torch.cat([fwd_feat, bwd_feat], -1))
232 | posterior = torch.distributions.Normal(mu, torch.exp(0.5 * logvar))
233 | prior = torch.distributions.Normal(
234 | torch.zeros_like(mu),
235 | torch.ones_like(logvar)
236 | )
237 | return posterior.rsample(), torch.distributions.kl_divergence(posterior, prior)
238 |
239 |
240 | class Lambda(nn.Module):
241 |
242 | def __init__(self, fn: typing.Callable) -> None:
243 | super().__init__()
244 | self.fn = fn
245 |
246 | def forward(self, x):
247 |         # the extra zero makes it compatible with the other encoders' (latent, kl) output
248 | return self.fn(x), 0.
249 |
250 |
251 | class ClassEmbedding(nn.Module):
252 |
253 | def __init__(self, num_classes, emb_dim) -> None:
254 | super().__init__()
255 |
256 | self.num_classes = num_classes
257 | self.emb_dim = emb_dim
258 | self.emb = nn.Embedding(self.num_classes, self.emb_dim)
259 |
260 | def forward(self, x):
261 | return self.emb(x), 0.
--------------------------------------------------------------------------------
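
To make the expected tensor shapes concrete, here is a minimal smoke test of `ScoreFunc` in its `birnn` configuration (a sketch with made-up shapes and values, assuming the repository root is on `PYTHONPATH`; the `transformer` path additionally expects the origin token that `SketchDiffusion.forward` prepends, plus per-step sequence PEs):

    import torch
    from models.score import ScoreFunc

    B, L = 4, 32                          # batch size and padded sequence length
    scorefn = ScoreFunc(
        'birnn',
        inp_n_features=5,                 # 2-D absolute position + 3-D velocity, concatenated
        out_n_features=3,                 # predicted noise on (dx, dy, pen-state)
        time_pe_features=2,               # diffusion-step embedding width
        seq_pe_features=0,                # sequence PEs are only used by the transformer variant
        n_cond_features=0,                # unconditional
    )

    pos = torch.randn(B, L, 2)            # cumulative positions (see create_posvel_aug_input)
    vel = torch.randn(B, L, 3)            # velocities plus pen-up bit
    lens = torch.tensor([32, 20, 16, 8])  # true lengths before padding
    time_pe = torch.randn(B, 2)           # one diffusion-step embedding per sample

    score = scorefn((pos, vel), None, lens, time_pe, cond=None)
    print(score.shape)                    # torch.Size([4, 32, 3])

In the `transformer` case, `final_proj` instead consumes consecutive outputs concatenated pairwise (hence the doubled input width in its `nn.Linear`), which shortens the sequence by exactly the origin token added earlier.
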
/requirements.txt:
--------------------------------------------------------------------------------
1 | ndjson
2 | scipy
3 | matplotlib
4 | tqdm
5 | Pillow
6 | pytorch-lightning==1.5.9
7 | simplification
8 | noise
9 | torch-ema
10 | omegaconf
11 | jsonargparse[signatures]
12 | shapely
13 | wandb
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import typing
3 | import math
4 | import torch
5 | import numpy as np
6 | import torch.nn.functional as F
7 | from pytorch_lightning.loggers import WandbLogger
8 | from wandb.util import generate_id
9 |
10 | import matplotlib.pyplot as plt
11 |
12 |
13 | def positionalencoding1d(d_model, length, N=10000, dtype=None, device=None):
14 | """
15 |     :param d_model: dimension of the model (must be even)
16 |     :param length: number of positions; N sets the frequency base (default 10000)
17 |     :return: length x d_model position matrix
18 |     """
19 | if d_model % 2 != 0:
20 | raise ValueError("Cannot use sin/cos positional encoding with "
21 | "odd dim (got dim={:d})".format(d_model))
22 | pe = torch.zeros(length, d_model, dtype=dtype, device=device)
23 | position = torch.arange(0, length, dtype=dtype, device=device).unsqueeze(1)
24 | div_term = torch.exp((torch.arange(0, d_model, 2, dtype=dtype, device=device) *
25 | -(math.log(N) / d_model)))
26 | pe[:, 0::2] = torch.sin(position.float() * div_term)
27 | pe[:, 1::2] = torch.cos(position.float() * div_term)
28 |
29 | return pe
30 |
31 |
32 | def random_fourier_encoding_dyn(ts, W, scale=4.):
33 | proj = (W * scale) @ ts
34 | emb = torch.cat([torch.sin(2 * torch.pi * proj), torch.cos(2 * torch.pi * proj)], 0)
35 | return emb.T
36 |
37 |
38 | def make_pad_mask_for_transformer(lens, total_length=None, device=None):
39 | total_length = total_length or max(lens)
40 | pad = torch.zeros(len(lens), total_length + 1, device=device)
41 | for b, l in enumerate(lens):
42 | pad[b, l] = 1.
43 | pad = torch.cumsum(pad, 1)
44 | return (pad[:, :-1] == 1.)
45 |
46 |
47 | def nonunif_timestep_selector(T, infer_T, gamma=2.):
48 | ui = np.linspace(1., 0., infer_T) # uniform index
49 | return np.unique(np.clip(
50 | # sample using gamma curves (y = x^gamma)
51 | np.floor((ui ** gamma) * T), 1., T
52 | ))[::-1].astype(np.int64)
53 |
54 |
55 | def openai_cosine_schedule(T, *args, s=0.008):
56 |     # explicitly defined $\bar{\alpha}_t$ via a cosine function;
57 |     # betas and alphas derived thereafter; suggested by "Improved Denoising
58 |     # Diffusion Probabilistic Models" by OpenAI
59 |
60 | def f(t): return math.cos((t/T + s) / (1 + s) * math.pi / 2) ** 2
61 | alpha_bar = np.array([f(t) / f(0) for t in range(T + 1)], dtype=np.float32)
62 | sqrt_alpha_bar = np.sqrt(alpha_bar)
63 | sqrt_one_min_alpha_bar = np.sqrt(1. - alpha_bar)
64 | betas = np.clip(1. - alpha_bar[1:] / alpha_bar[:-1], 0., 0.999)
65 | alphas = 1. - betas
66 | beta_tilde = (1. - alpha_bar[:-1]) / (1. - alpha_bar[1:]) * betas
67 |
68 | return betas, alphas, alpha_bar[1:], \
69 | sqrt_alpha_bar[1:], sqrt_one_min_alpha_bar[1:], beta_tilde
70 |
71 |
72 | def linear_schedule(T, low_noise, high_noise):
73 | # standard linear schedule defined in terms of $\beta_t$
74 | betas = np.linspace(low_noise, high_noise, T, dtype=np.float32)
75 | alphas = 1. - betas
76 | alpha_bar = np.cumprod(alphas, 0)
77 | sqrt_alpha_bar = np.sqrt(alpha_bar)
78 | sqrt_one_min_alpha_bar = np.sqrt(1. - alpha_bar)
79 | beta_tilde_wo_first_term = ((sqrt_one_min_alpha_bar[:-1] / sqrt_one_min_alpha_bar[1:])**2 * betas[1:])
80 | beta_tilde = np.array([
81 | beta_tilde_wo_first_term[0],
82 | *beta_tilde_wo_first_term
83 | ])
84 |
85 | return betas, alphas, alpha_bar, \
86 | sqrt_alpha_bar, sqrt_one_min_alpha_bar, beta_tilde
87 |
88 |
89 | def cg_subtracted_noise(noise, lens):
90 | mask = torch.cumprod(1. - F.one_hot(lens, num_classes=noise.size(1) + 1)[:, :-1, None].float(), 1)
91 | # make sure the padding doesn't interfere in CoM calculation
92 | com = (mask * noise).sum(1, keepdim=True) / lens[:, None, None]
93 | return noise - com
94 |
95 |
96 | class CustomWandbLogger(WandbLogger):
97 |
98 | def __init__(self,
99 | name: typing.Optional[str],
100 | save_dir: typing.Optional[str] = 'logs',
101 | group: typing.Optional[str] = 'common',
102 | project: typing.Optional[str] = 'diffset',
103 | log_model: typing.Optional[bool] = True,
104 | offline: bool = False,
105 | entity: typing.Optional[str] = 'dasayan05'):
106 | rid = generate_id()
107 | name_rid = '-'.join([name, rid])
108 | super().__init__(name=name_rid, id=rid, offline=offline,
109 | save_dir=os.path.join(save_dir, name_rid), project=project,
110 | log_model=log_model, group=group, entity=entity)
111 |
112 |
113 | class CustomViz(object):
114 |
115 | def __init__(self, test_n_sample_viz: int, n_viz: int, compact_mode: bool = True, subfig_slack: float = 0.) -> None:
116 | super().__init__()
117 |
118 | self.test_n_sample_viz = test_n_sample_viz
119 | self.n_viz = n_viz
120 | self.compact_mode = compact_mode
121 |
122 | if self.compact_mode:
123 | self.fig, self.ax = plt.subplots(
124 | self.test_n_sample_viz,
125 | self.n_viz,
126 | figsize=(self.n_viz, self.test_n_sample_viz),
127 |                 gridspec_kw={'wspace': subfig_slack, 'hspace': subfig_slack})
128 | else:
129 | self.figs = [
130 | [
131 | plt.subplots(1, 1, figsize=(1, 1)) \
132 | for j in range(self.n_viz)
133 | ] for i in range(self.test_n_sample_viz)
134 | ]
135 |
136 | def __getitem__(self, pos: tuple):
137 | i, j = pos
138 | if self.compact_mode:
139 | return self.ax[i, j]
140 | else:
141 | _, ax = self.figs[i][j]
142 | return ax
143 |
144 | @property
145 | def shape(self):
146 | return self.test_n_sample_viz, self.n_viz
147 |
148 | def savefig(self, path: str, **kwargs):
149 | if self.compact_mode:
150 | self.fig.savefig(path, **kwargs)
151 | else:
152 | *rest, ext = path.split('.')
153 | rest = '.'.join(rest)
154 | os.makedirs(rest, exist_ok=False)
155 | for i in range(self.test_n_sample_viz):
156 | for j in range(self.n_viz):
157 | path = os.path.join(rest, f'{i}_{j}.' + ext)
158 | fig, _ = self.figs[i][j]
159 | fig.savefig(path, **kwargs)
--------------------------------------------------------------------------------
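
As a quick sanity check of the two noise schedules (a sketch, assuming the dependencies from requirements.txt are installed and the repository root is on `PYTHONPATH`): both generators return the six arrays that `SketchDiffusion.__init__` registers as buffers, each of shape `(T,)`, with `alpha_bar` decaying towards zero as t approaches T:

    import numpy as np
    from utils import linear_schedule, openai_cosine_schedule

    T = 1000
    for name, sched in [('linear', lambda: linear_schedule(T, 1e-4, 2e-2)),
                        ('cosine', lambda: openai_cosine_schedule(T))]:
        betas, alphas, alpha_bar, sqrt_ab, sqrt_1mab, beta_tilde = sched()
        assert all(a.shape == (T,) for a in (betas, alphas, alpha_bar, sqrt_ab, sqrt_1mab, beta_tilde))
        print(f"{name}: alpha_bar[0] = {alpha_bar[0]:.4f} -> alpha_bar[-1] = {alpha_bar[-1]:.2e}")
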