├── LICENSE
├── README.md
├── config.yml
├── data
│   ├── __init__.py
│   ├── batching.py
│   ├── dm.py
│   ├── qd.py
│   ├── sketch.py
│   ├── unpack_ndjson.py
│   └── utils.py
├── gifs
│   ├── 0.gif
│   ├── 1.gif
│   ├── 2.gif
│   ├── 3.gif
│   ├── 4.gif
│   ├── 5.gif
│   ├── 6.gif
│   ├── 7.gif
│   └── 8.gif
├── main.py
├── models
│   ├── __init__.py
│   └── score.py
├── requirements.txt
└── utils.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Ayan Das

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ChiroDiff: Modelling chirographic data with Diffusion Models
### Accepted at International Conference on Learning Representations (ICLR) 2023

Authors: [Ayan Das](https://ayandas.me/), [Yongxin Yang](https://yang.ac/), [Timothy Hospedales](https://homepages.inf.ed.ac.uk/thospeda/), [Tao Xiang](http://personal.ee.surrey.ac.uk/Personal/T.Xiang/index.html), [Yi-Zhe Song](http://personal.ee.surrey.ac.uk/Personal/Y.Song/)

[OpenReview], [arXiv] & [Project Page]

> **Abstract:** Generative modelling over continuous-time geometric constructs, a.k.a chirographic data such as handwriting, sketches, drawings etc., have been accomplished through autoregressive distributions. Such strictly-ordered discrete factorization however falls short of capturing key properties of chirographic data -- it fails to build holistic understanding of the temporal concept due to one-way visibility (causality). Consequently, temporal data has been modelled as discrete token sequences of fixed sampling rate instead of capturing the true underlying concept. In this paper, we introduce a powerful model-class namely "Denoising Diffusion Probabilistic Models" or DDPMs for chirographic data that specifically addresses these flaws. Our model named "ChiroDiff", being non-autoregressive, learns to capture holistic concepts and therefore remains resilient to higher temporal sampling rate up to a good extent. Moreover, we show that many important downstream utilities (e.g. conditional sampling, creative mixing) can be flexibly implemented using ChiroDiff. We further show some unique use-cases like stochastic vectorization, de-noising/healing, abstraction are also possible with this model-class. We perform quantitative and qualitative evaluation of our framework on relevant datasets and found it to be better or on par with competing approaches.

---

## Running the code

The instructions below explain how to run the code in this repository.

#### Table of contents:
1. Environment and libraries
2. Data preparation
3. Training
4. Inference

### Environment & Libraries

Running the code may require slightly outdated versions of some libraries. The full list is provided in the `requirements.txt` file in this repo.
Please create a virtual environment with `conda` or `venv` and run

~~~bash
(myenv) $ pip install -r requirements.txt
~~~

### Data preparation

You can feed the data in one of two ways: "unpacked", or "unpacked and preprocessed". The former loads data dynamically from individual files, whereas the latter packs the preprocessed input into a single `.npz` file, which speeds up training.

- To "unpack" the QuickDraw dataset, [download](https://console.cloud.google.com/storage/browser/quickdraw_dataset/full/raw?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))&prefix=&forceOnObjectsSortingFiltering=true) the `.ndjson` file for any category (or categories) you like and save it in a folder `/path/to/all/ndjsons/`. Then use the provided utility `data/unpack_ndjson.py` to unpack them
  ~~~bash
  (myenv) $ python data/unpack_ndjson.py --data_folder /path/to/all/ndjsons/ --category cat --out_folder /the/output/dir/ --max_sketches 100000
  # produces a folder /the/output/dir/cat/ with all samples (unpacked)
  ~~~
  You may use this folder as is; however, training from it may be slow. We recommend packing the samples, with all preprocessing applied, into one `.npz` file using the `data/qd.py` script
  ~~~bash
  (myenv) $ python data/qd.py /the/output/dir/cat threeseqdel
  # produces a file /the/output/dir/cat_threeseqdel.npz
  ~~~
  `threeseqdel` is one of many training modes -- more on this later. The produced `.npz` file can now be used for training. Please see the `if __name__ == '__main__'` section of `data/qd.py` for preprocessing options.

- For the `VMNIST` and `KanjiVG` datasets, the unpacked files are readily available [here](https://drive.google.com/drive/folders/1C6euR9HPLdL_nubqLk8wEao96KRNZJl_?usp=sharing) for download. Follow the same steps above to process and pack them into `.npz`.
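
Before unpacking, it can help to sanity-check a downloaded `.ndjson` file. The sketch below is not the repository's `data/unpack_ndjson.py`; it only assumes the publicly documented raw QuickDraw layout, in which each line is a JSON record whose `drawing` field is a list of strokes and each stroke is a triple of `[x, y, t]` lists. The function name `summarize_ndjson_line` and the sample record are made up for illustration.

```python
import json


def summarize_ndjson_line(line):
    """Summarize one raw QuickDraw record (assumed layout: strokes of [xs, ys, ts])."""
    record = json.loads(line)
    strokes = record["drawing"]
    return {
        "word": record.get("word"),
        "n_strokes": len(strokes),
        # each stroke is [xs, ys, ts]; the x-list length is the point count
        "n_points": sum(len(stroke[0]) for stroke in strokes),
    }


# A hand-written record standing in for one line of a real `.ndjson` file.
sample_line = json.dumps({
    "word": "cat",
    "drawing": [
        [[0, 5, 10], [0, 5, 0], [0, 40, 80]],
        [[10, 20], [0, 10], [200, 260]],
    ],
})

summary = summarize_ndjson_line(sample_line)
print(summary)
```

To inspect a real file, iterate over its lines and call the function on each one.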


### Training & Sampling

There are multiple training "modes" corresponding to the model type (unconditional, sequence-conditioned etc.).

```bash
threeseqdel                   # unconditional model with delta (velocity) sequence
threeseqdel_pointcloudcond    # conditioned on pointcloud representation
threeseqdel_classcond         # conditioned on class
threeseqdel_threeseqdelcond   # conditioned on self

threeseqabs                   # unconditional model with absolute (position) sequence
threeseqabs_pointcloudcond    # conditioned on pointcloud representation
threeseqabs_classcond         # conditioned on class
threeseqabs_threeseqabscond   # conditioned on self
```

- Use one of the modes in the `--model.repr` and `--data.init_args.repr` command line arguments.
- Use the processed data file (i.e. `*.npz`) with `--data.init_args.root_dir ...`. You may also use an un-processed data folder here.

**Note:** For simplicity, we provide a `config.yml` file in which all command line options can be altered. Then run the main script as

```bash
(myenv) $ python main.py fit --config config.yml --model.arch_layer 3 --model.noise_T 100 ...
```

You will also need `wandb` for logging. Please use your own account and fill in the correct values of `--trainer.logger.init_args.{entity, project}` in the `config.yml` file. You may also remove the `wandb` logger entirely and replace it with another logger of your choice. In that case, you might have to modify a few lines of code.

While training, the script will save the full config of the run, a "best model" and a "last model".
Once trained, use the saved model (saved every 300 epochs) and the full configuration via the `--ckpt_path` and `--config` arguments like so

```bash
(myenv) $ python main.py test --config ./logs/test-run/config.yaml --ckpt_path ./logs/test-run/.../checkpoints/model.ckpt --limit_test_batches 1
```

By default, the testing phase writes some visualizations that are helpful for inspection, for example generation results and a visualization of the diffusion process. Test-time options have the `--test_` prefix. Feel free to play around with them.

```bash
(myenv) $ python main.py test --config ... --ckpt_path ... \
         --test_sampling_algo ddpm \
         --test_variance_strength 0.75 \
         --test_viz_process backward \
         --test_save_everything 1
```
---

You can cite the paper as

```bibtex
@inproceedings{das2023chirodiff,
    title={ChiroDiff: Modelling chirographic data with Diffusion Models},
    author={Ayan Das and Yongxin Yang and Timothy Hospedales and Tao Xiang and Yi-Zhe Song},
    booktitle={The Eleventh International Conference on Learning Representations},
    year={2023},
    url={https://openreview.net/forum?id=1ROAstc9jv}
}
```

---

**Notes:**

1. This repository is a part of our research codebase and may therefore contain code/options that are not part of the paper.
2. This repo may also contain some implementation details that have been upgraded since the submission of the paper.
3. The README is still incomplete and I will add more info when I get time. You may try different settings yourself.
4. The default parameters might not match the ones in the paper. Feel free to change and play with them.
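
As a rough illustration of how the `threeseqabs` and `threeseqdel` modes listed under Training & Sampling differ, the sketch below builds both sequences from a toy two-stroke drawing in plain Python. It is a simplification and not the code in `data/batching.py`: timestamps, resampling and scaling are left out, the pen value simply marks the last point of each stroke, and the helper names `to_abs`, `to_delta` and `from_delta` are hypothetical.

```python
def to_abs(strokes):
    """Flatten strokes into (x, y, pen) triples; pen is 1.0 on a stroke's last point."""
    seq = []
    for stroke in strokes:
        for k, (x, y) in enumerate(stroke):
            pen = 1.0 if k == len(stroke) - 1 else 0.0
            seq.append((x, y, pen))
    return seq


def to_delta(abs_seq):
    """Successive position differences; each delta carries the pen value of its target point."""
    return [
        (x1 - x0, y1 - y0, p1)
        for (x0, y0, _), (x1, y1, p1) in zip(abs_seq, abs_seq[1:])
    ]


def from_delta(start, delta_seq):
    """Recover absolute positions by cumulative summation from a known first point."""
    x, y, _ = start
    out = [start]
    for dx, dy, pen in delta_seq:
        x, y = x + dx, y + dy
        out.append((x, y, pen))
    return out


# Toy drawing (assumed data): an L-shaped stroke followed by a short horizontal one.
strokes = [
    [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)],
    [(2.0, 2.0), (3.0, 2.0)],
]

abs_seq = to_abs(strokes)
delta_seq = to_delta(abs_seq)
print(abs_seq)
print(delta_seq)
```

The delta sequence is one element shorter than the absolute one because the first point has no predecessor; given that first point, `from_delta` recovers the positions.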

--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
seed_everything: null
trainer:
  logger:
    class_path: utils.CustomWandbLogger
    init_args:
      entity:  # TODO: fill these two ..
      project: # .. entries yourself.
      offline: true
      log_model: false
      save_dir: ./logs/
      name: test-run
      group: test

  process_position: 0
  num_nodes: 1
  accelerator: gpu
  devices: 1
  auto_select_gpus: true

  gradient_clip_algorithm: norm
  enable_progress_bar: true
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: 1
  max_epochs: 100000
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0
  log_every_n_steps: 8
  strategy: dp
  sync_batchnorm: false
  enable_model_summary: true
  weights_summary: top
  num_sanity_val_steps: 0
  profiler: null
  benchmark: false
  deterministic: false
  detect_anomaly: false
  auto_scale_batch_size: false
  prepare_data_per_node: null
  plugins: null
  amp_backend: native
  amp_level: null
  move_metrics_to_cpu: false
  stochastic_weight_avg: false

  gradient_clip_val: 0.1
  precision: 16

model:
  repr: ${data.init_args.repr}
  modeltype: birnn
  time_embedding: randomfourier

  optim_ema: true
  optim_lr: 1.e-3
  optim_gamma: 0.9995
  optim_warmup: 15000
  optim_sched: steplr
  optim_interval: epoch
  optim_div_factor: 2
  optim_decay: 1.e-2

  arch_parameterization: eps
  arch_dropout: 0.
  arch_pe_dim: 8
  arch_head: 4
  arch_layer: 3
  arch_internal: 96

  # conditioning model arch
  arch_layer_cond: 3
  arch_internal_cond: 112
  arch_n_cond_latent: 96

  noise_T: 35
  noise_low_noise: 1.e-4
  noise_high_noise: 2.e-2
  noise_schedule: linear

  test_variance_strength: 0.75
  test_sampling_algo: ddpm
  test_n_viz: 10
  test_n_sample_viz: 10
  test_recon: true
  test_interp: false

data:
  class_path: data.dm.QuickDrawDM
  init_args:
    root_dir: # TODO: path to the _.npz file
    repr: threeseqabs

    split_fraction: 0.85
    perlin_noise: 0.1
    split_seed: 5555
    num_workers: 4
    batch_size: 128
    max_strokes: 20
    max_sketches: 100000

ckpt_path: null

--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/data/__init__.py


--------------------------------------------------------------------------------
/data/batching.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from copy import deepcopy

from data.sketch import Sketch

# TODO: reasonable limit.
# can be made cmd arg later
MAX_SEQ_LENGTH = 300


class SketchRepr(object):

    def __init__(self, penbit=True, cache=False):
        super().__init__()

        self.penbit = penbit

    def represent(self, sketch):
        raise NotImplementedError('Abstract method not callable')

    def collate(self, batch: list):
        raise NotImplementedError('Abstract method not callable')


class Strokewise(SketchRepr):

    def __init__(self, penbit=True, cache=False):
        # Here 'granularity' means stroke-granularity
        super().__init__(penbit=penbit, cache=cache)

    def represent(self, sketch: Sketch):
        sk_repr = []

        total_seq_len = sum([len(stroke) for stroke in sketch])
        if total_seq_len >= MAX_SEQ_LENGTH:
            return None

        for stroke in sketch:
            seq_stroke, _ = stroke.tensorize()

            # TODO: clean this properly; strokes becomes (2,) sized
            if len(stroke) > 1 and len(stroke.stroke.shape) == 2:  # sloppy fix
                sk_repr.append({
                    'start': seq_stroke[0, :],
                    'time_range': torch.from_numpy(stroke.timestamps.astype(np.float32)),
                    'poly_stroke': seq_stroke
                })

        return sk_repr


class Pointcloud(Strokewise):

    def construct_sample(sk):
        if len(sk) == 0:
            return None

        sk_set = torch.cat([s['poly_stroke'] for s in sk], 0)
        return sk_set

    def represent(self, sketch: Sketch):
        sk = super().represent(sketch)
        return sk and Pointcloud.construct_sample(sk)

    def collate(batch: list):
        batch = [b for b in batch if b is not None]
        lens = torch.tensor([b.shape[0] for b in batch])
        pd = pad_sequence(batch, batch_first=True)
        return None, (pd, lens)


class ThreePointDelta(Strokewise):

    def construct_sample(sk, penbit=True):
        sk_list = []
        t_list = []
        for i, stroke in enumerate(sk):
            timestamps = stroke['time_range']
            stroke = stroke['poly_stroke']
            pen = torch.ones(stroke.shape[0], 1, dtype=stroke.dtype, device=stroke.device) * i
            pen[-1, 0] = i + 1
            sk_list.append(torch.cat([stroke, pen], -1))
            t_list.append(timestamps)

        if len(sk) == 0:
            return None

        sk = torch.cat(sk_list, 0)
        time = torch.cat(t_list, 0)
        sk_delta = sk[1:, :] - sk[:-1, :]
        if not penbit:
            sk_delta = sk_delta[:, :-1]

        time = time[:-1]  # velocity is not available for the last point
        return torch.cat([sk_delta, time[:, None]], -1)

    def represent(self, sketch: Sketch):
        sk = super().represent(sketch)
        return sk and ThreePointDelta.construct_sample(sk, self.penbit)

    def collate(batch: list):
        batch = [b for b in batch if b is not None]
        lens = torch.tensor([b.shape[0] for b in batch])
        pd = pad_sequence(batch, batch_first=True)
        return None, (pd, lens)


class ThreePointDelta_PointCloudCond(Strokewise):

    def represent(self, sketch: Sketch):
        sk = super().represent(sketch)

        if sk is None:
            return None, None

        sk_threepointdelta = ThreePointDelta.construct_sample(sk, self.penbit)

        sketch = deepcopy(sketch)
        sk = super().represent(sketch)

        if sk is None:
            return None, None

        sk_pointcloud = Pointcloud.construct_sample(sk)

        return sk_pointcloud, sk_threepointdelta

    def collate(batch: list):
        sk_threepointdeltas = [tpd for _, tpd in batch]
        sk_pointclouds = [pc for pc, _ in batch]

        _, pc_batch = Pointcloud.collate(sk_pointclouds)
        _, tpd_batch = ThreePointDelta.collate(sk_threepointdeltas)
        return pc_batch, tpd_batch


class ThreePointAbs(Strokewise):

    def construct_sample(sk, penbit=True):
        sk_list = []
        t_list = []
        for _, stroke in enumerate(sk):
            timestamps = stroke['time_range']
            stroke = stroke['poly_stroke']
            pen = torch.zeros(stroke.shape[0], 1, dtype=stroke.dtype, device=stroke.device)
            pen[-1, 0] = 1.
            sk_list.append(torch.cat([stroke, pen], -1))
            t_list.append(timestamps)

        if len(sk) == 0:
            return None

        sk = torch.cat(sk_list, 0)
        if not penbit:
            sk = sk[:, :-1]

        time = torch.cat(t_list, 0)
        return torch.cat([sk[1:, :], time[1:, None]], -1)

    def represent(self, sketch: Sketch):
        sk = super().represent(sketch)
        return sk and ThreePointAbs.construct_sample(sk, self.penbit)

    def collate(batch: list):
        batch = [b for b in batch if b is not None]
        lens = torch.tensor([b.shape[0] for b in batch])
        pd = pad_sequence(batch, batch_first=True)
        return None, (pd, lens)


class ThreePointAbs_PointCloudCond(Strokewise):

    def represent(self, sketch: Sketch):
        sk = super().represent(sketch)

        if sk is None:
            return None, None

        sk_threepointabs = ThreePointAbs.construct_sample(sk, self.penbit)

        sketch = deepcopy(sketch)
        sk = super().represent(sketch)

        if sk is None:
            return None, None

        sk_pointcloud = Pointcloud.construct_sample(sk)

        return sk_pointcloud, sk_threepointabs

    def collate(batch: list):
        sk_threepointabss = [tpa for _, tpa in batch]
        sk_pointclouds = [pc for pc, _ in batch]

        _, pc_batch = Pointcloud.collate(sk_pointclouds)
        _, tpa_batch = ThreePointAbs.collate(sk_threepointabss)
        return pc_batch, tpa_batch


class ThreePointAbs_ThreeSeqAbs(Strokewise):

    def __init__(self, penbit=True, cond_rdp=None, cache=False):
        super().__init__(penbit, cache)

        self.cond_rdp = cond_rdp

    def represent(self, sketch: Sketch):
        cond_sketch = deepcopy(sketch)

        # spatially scaling back to 1. and then 10.
        # is needed because the stuff in the middle
        # (resampling rate, RDP parameter) are sensitive to spatial scale of the vector entity.
        cond_sketch.scale_spatial(1.)
        if self.cond_rdp is not None:
            cond_sketch.rdp(self.cond_rdp)
        cond_sketch.scale_spatial(10.)

        cond_sk = super().represent(cond_sketch)
        sk = super().represent(sketch)

        if sk is None or cond_sk is None:
            return None, None

        cond_sk_threepointabs = ThreePointAbs.construct_sample(cond_sk, self.penbit)
        sk_threepointabs = ThreePointAbs.construct_sample(sk, self.penbit)

        # timestep not needed for the condition
        return cond_sk_threepointabs, \
            sk_threepointabs

    def collate(batch: list):
        sk_threepointabss = [h_tpa for _, h_tpa in batch]
        cond_sk_threepointabss = [l_tpa for l_tpa, _ in batch]

        _, tpa_batch = ThreePointAbs.collate(sk_threepointabss)
        _, cond_tpa_batch = ThreePointAbs.collate(cond_sk_threepointabss)
        return cond_tpa_batch, tpa_batch


class ThreePointDel_ThreeSeqDel(Strokewise):

    def __init__(self, penbit=True, cond_rdp=None, cache=False):
        super().__init__(penbit, cache)

        self.cond_rdp = cond_rdp

    def represent(self, sketch: Sketch):
        cond_sketch = deepcopy(sketch)

        # spatially scaling back to 1. and then 10. is needed because the stuff in the middle
        # (resampling rate, RDP parameter) are sensitive to spatial scale of the vector entity.
        cond_sketch.scale_spatial(1.)
        if self.cond_rdp is not None:
            cond_sketch.rdp(self.cond_rdp)
        cond_sketch.scale_spatial(10.)

        cond_sk = super().represent(cond_sketch)
        sk = super().represent(sketch)

        if sk is None or cond_sk is None:
            return None, None

        cond_sk_threepointdel = ThreePointDelta.construct_sample(cond_sk, self.penbit)
        sk_threepointdel = ThreePointDelta.construct_sample(sk, self.penbit)

        # timestep not needed for the condition
        return cond_sk_threepointdel, \
            sk_threepointdel

    def collate(batch: list):
        sk_threepointdels = [h_tpd for _, h_tpd in batch]
        cond_sk_threepointdels = [l_tpd for l_tpd, _ in batch]

        _, tpd_batch = ThreePointAbs.collate(sk_threepointdels)
        _, cond_tpd_batch = ThreePointAbs.collate(cond_sk_threepointdels)
        return cond_tpd_batch, tpd_batch


class StrokeSet(Strokewise):

    def represent(self, sketch: Sketch):
        sk = super().represent(sketch)

        sk_list = []
        for stroke in sk:
            abs_stroke = stroke['poly_stroke']
            del_stroke = abs_stroke[1:, ...] - abs_stroke[:-1, ...]
            start_del_stroke = torch.cat([stroke['start'][None, :], del_stroke], 0)
            sk_list.append(start_del_stroke.ravel())

        sk = torch.stack(sk_list, 0)
        return sk

    def collate(batch: list):
        batch = [b for b in batch if b is not None]
        lens = torch.tensor([b.shape[0] for b in batch])
        pd = pad_sequence(batch, batch_first=True)
        return pd, lens

--------------------------------------------------------------------------------
/data/dm.py:
--------------------------------------------------------------------------------
import os
from enum import Enum
from typing import Optional

import torch
from torch.utils.data import DataLoader, random_split, Dataset
from pytorch_lightning import LightningDataModule

from data.qd import (
    DS_threeseqdel,
    DS_threeseqabs,
    DS_threeseqdel_pointcloudcond,
    DS_threeseqdel_classcond,
    DS_threeseqabs_classcond,
    DS_threeseqabs_pointcloudcond,
    DS_threeseqabs_threeseqabscond,
    DS_threeseqdel_threeseqdelcond
)


class ReprType(str, Enum):
    threeseqdel = "threeseqdel"
    threeseqabs = "threeseqabs"
    threeseqabs_threeseqabscond = "threeseqabs_threeseqabscond"
    threeseqdel_pointcloudcond = "threeseqdel_pointcloudcond"
    threeseqdel_classcond = "threeseqdel_classcond"
    threeseqabs_classcond = "threeseqabs_classcond"
    threeseqabs_pointcloudcond = "threeseqabs_pointcloudcond"
    threeseqdel_threeseqdelcond = "threeseqdel_threeseqdelcond"


class GenericDM(LightningDataModule):

    def __init__(self, split_seed, split_fraction, batch_size, num_worker, repr):
        super().__init__()

        self.split_seed = split_seed
        self.split_fraction = split_fraction
        self.batch_size = batch_size
        self.num_worker = num_worker
        self.repr = repr

        # subclasses need to set this with a 'Dataset' instance
        self._dataset = None

    @property
    def dataset(self):
        if self._dataset is None:
            raise ValueError(f'Subclass {self.__class__.__name__} is yet to assign a Dataset')
        else:
            return self._dataset

    @dataset.setter
    def dataset(self, d):
        if not isinstance(d, Dataset):
            raise ValueError(f'Expected a Dataset, got {d}')
        else:
            self._dataset = d

    def compute_split_size(self):
        self.train_len = int(len(self.dataset) * self.split_fraction)
        self.valid_len = len(self.dataset) - self.train_len

    def setup(self, stage: str):
        self.train_dataset, self.valid_dataset = \
            random_split(self.dataset, [self.train_len, self.valid_len],
                         torch.Generator().manual_seed(self.split_seed))

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size, pin_memory=True, drop_last=True, shuffle=True,
                          num_workers=self.num_worker, collate_fn=self.dataset.__class__.collate)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset,
                          batch_size=self.batch_size, pin_memory=True, drop_last=True, shuffle=True,
                          num_workers=self.num_worker, collate_fn=self.dataset.__class__.collate)

    def test_dataloader(self):
        return self.val_dataloader()


class QuickDrawDM(GenericDM):

    def __init__(self,
                 root_dir: str,
                 max_sketches: Optional[int] = None,
                 max_strokes: Optional[int] = None,
                 split_fraction: float = 0.85,
                 perlin_noise: float = 0.2,
                 penbit: bool = True,
                 split_seed: int = 4321,
                 batch_size: int = 64,
                 num_workers: int = os.cpu_count() // 2,
                 rdp: Optional[float] = None,
                 cond_rdp: Optional[float] = None,
                 repr: ReprType = ReprType.threeseqdel,
                 cache: bool = False
                 ):
        """QuickDraw Datamodule (OneSeq)

        Args:
            root_dir: Root directory of QD data (unpacked by `unpack_ndjson.py` utility)
            category: QD category name
            max_sketches: Maximum number of sketches to use
            max_strokes: clamp the maximum number of strokes (None for all strokes)
            split_fraction: Train/Validation split fraction
            perlin_noise: Strength of Perlin noise (YET TO BE IMPL)
            granularity: Number of points in each sample
            split_seed: Data splitting seed
            batch_size: Batch size for training
            rdp: RDP algorithm parameter ('None' to ignore RDP entirely)
            repr: data representation (oneseq or strokewise)
        """
        self.save_hyperparameters()
        self.hp = self.hparams  # an easier name
        super().__init__(self.hp.split_seed,
                         self.hp.split_fraction,
                         self.hp.batch_size,
                         self.hp.num_workers,
                         self.hp.repr)

        self._construct()

    def _construct(self):
        if self.hp.repr == ReprType.threeseqdel:
            self.dataset = DS_threeseqdel(
                self.hp.root_dir,
                perlin_noise=self.hp.perlin_noise,
                max_sketches=self.hp.max_sketches,
                max_strokes=self.hp.max_strokes,
                penbit=self.hp.penbit,
                rdp=self.hp.rdp)
        elif self.hp.repr == ReprType.threeseqabs:
            self.dataset = DS_threeseqabs(
                self.hp.root_dir,
                perlin_noise=self.hp.perlin_noise,
                max_sketches=self.hp.max_sketches,
                max_strokes=self.hp.max_strokes,
                penbit=self.hp.penbit,
                rdp=self.hp.rdp)
        elif self.hp.repr == ReprType.threeseqdel_pointcloudcond:
            self.dataset = DS_threeseqdel_pointcloudcond(
                self.hp.root_dir,
                perlin_noise=self.hp.perlin_noise,
                max_sketches=self.hp.max_sketches,
                max_strokes=self.hp.max_strokes,
                penbit=self.hp.penbit,
                rdp=self.hp.rdp)
        elif self.hp.repr == ReprType.threeseqdel_classcond:
            self.dataset = DS_threeseqdel_classcond(
                self.hp.root_dir,
                perlin_noise=self.hp.perlin_noise,
                max_sketches=self.hp.max_sketches,
                max_strokes=self.hp.max_strokes,
                penbit=self.hp.penbit,
                rdp=self.hp.rdp)
        elif self.hp.repr == ReprType.threeseqabs_classcond:
            self.dataset = DS_threeseqabs_classcond(
                self.hp.root_dir,
                perlin_noise=self.hp.perlin_noise,
                max_sketches=self.hp.max_sketches,
                max_strokes=self.hp.max_strokes,
                penbit=self.hp.penbit,
                rdp=self.hp.rdp)
        elif self.hp.repr == ReprType.threeseqabs_pointcloudcond:
            self.dataset = DS_threeseqabs_pointcloudcond(
                self.hp.root_dir,
                perlin_noise=self.hp.perlin_noise,
                max_sketches=self.hp.max_sketches,
                max_strokes=self.hp.max_strokes,
                penbit=self.hp.penbit,
                rdp=self.hp.rdp)
        elif self.hp.repr == ReprType.threeseqabs_threeseqabscond:
            self.dataset = DS_threeseqabs_threeseqabscond(
                self.hp.root_dir,
                perlin_noise=self.hp.perlin_noise,
                max_sketches=self.hp.max_sketches,
                max_strokes=self.hp.max_strokes,
                penbit=self.hp.penbit,
                rdp=self.hp.rdp,
                cond_rdp=self.hp.cond_rdp)
        elif self.hp.repr == ReprType.threeseqdel_threeseqdelcond:
            self.dataset = DS_threeseqdel_threeseqdelcond(
                self.hp.root_dir,
                perlin_noise=self.hp.perlin_noise,
                max_sketches=self.hp.max_sketches,
                max_strokes=self.hp.max_strokes,
                penbit=self.hp.penbit,
                rdp=self.hp.rdp,
                cond_rdp=self.hp.cond_rdp)
        else:
            pass

        self.compute_split_size()

--------------------------------------------------------------------------------
/data/qd.py:
--------------------------------------------------------------------------------
import os
import sys
import random
import pickle
import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import Dataset

from data.sketch import Sketch
from data.batching import Pointcloud, Strokewise, \
    ThreePointAbs, ThreePointDelta, ThreePointDelta_PointCloudCond, \
    ThreePointAbs_PointCloudCond, ThreePointAbs_ThreeSeqAbs, \
    ThreePointDel_ThreeSeqDel


class QuickDraw(Dataset):

    def __init__(self, data_root, shuffle=True, perlin_noise=0.2,
                 max_sketches=10000, max_strokes=None, rdp=None, **kwargs):
        super().__init__()

        self.data_root = data_root

        if os.path.isfile(self.data_root) and self.data_root.endswith('.npz'):
            self.npz_ptr = np.load(self.data_root, allow_pickle=True)
            self.attrs = self.npz_ptr.files
            self.data = {attr: self.npz_ptr[attr] for attr in self.attrs}
            self.cached = True
            return
        else:
            self.cached = False

        self.max_sketches = max_sketches
        self.max_strokes = max_strokes
        self.perlin_noise = perlin_noise
        self.rdp = rdp

        self.content_list = os.listdir(self.data_root)
        if all([os.path.isdir(os.path.join(self.data_root, c_path)) for c_path in self.content_list]):
            self.categories = self.content_list
            self.n_categories = len(self.categories)
            self.file_list = []
            for cat in self.categories:
                cat_content = os.listdir(os.path.join(self.data_root, cat))

                if self.max_sketches is not None:
                    max_sketches_per_cat = min(self.max_sketches, len(cat_content))
                    del cat_content[max_sketches_per_cat:]

                self.file_list.extend([os.path.join(cat, c) for c in cat_content])

        else:
            self.categories = None
            self.file_list = self.content_list

            if self.max_sketches is not None:
                max_sketches = min(self.max_sketches, len(self.file_list))
                del self.file_list[max_sketches:]

        if shuffle:
            random.shuffle(self.file_list)


    def __len__(self):
        if not self.cached:
            return len(self.file_list)
        else:
            return self.data[self.attrs[0]].shape[0]

    def get_sketch(self, i):
        if self.categories is not None:
            cat, _ = self.file_list[i].split('/')
            assert cat in self.categories, "something wrong with category/folder names"
            self.cat_idx = self.categories.index(cat)
        else:
            self.cat_idx = None

        file_path = os.path.join(self.data_root, self.file_list[i])

        with open(file_path, 'rb') as f:
            self.data = pickle.load(f)

        stroke_list = self.data['drawing']

        if self.max_strokes is not None:
            max_strokes = min(self.max_strokes, len(stroke_list))
            stroke_list = stroke_list[:max_strokes]

        sketch = Sketch(stroke_list, label=self.cat_idx)

        seed = random.randint(0, 10000)
        sketch.jitter(seed=seed, noise_level=self.perlin_noise)

        sketch.move()
        sketch.shift_time(0)
        sketch.scale_spatial(1)
        sketch.resample(delta=0.05)
        if self.rdp is not None:
            sketch.rdp(eps=self.rdp)
        sketch.scale_spatial(10)
        sketch.scale_time(1)

        return sketch

    def __getitem__(self, i):
        if not self.cached:
            return self.represent(self.get_sketch(i))
        else:
            if len(self.attrs) > 1:
                return tuple(torch.from_numpy(self.data[attr][i]) for attr in self.attrs)
            else:
                return torch.from_numpy(self.data[self.attrs[0]][i])


class QDSketchStrokewise(QuickDraw, Strokewise):

    def __init__(self, *args, **kwargs):
        QuickDraw.__init__(self, *args, **kwargs)
        Strokewise.__init__(self)

    def __getitem__(self, i):
        return self.represent(super().get_sketch(i))


class QDSketchPointcloud(QuickDraw, Pointcloud):

    def __init__(self, *args, **kwargs):
        QuickDraw.__init__(self, *args, **kwargs)
        Pointcloud.__init__(self)

    def __getitem__(self, i):
        return self.represent(super().__getitem__(i))


class DS_threeseqdel(QuickDraw, ThreePointDelta):

    def __init__(self, *args, **kwargs):
        QuickDraw.__init__(self, *args, **kwargs)
        ThreePointDelta.__init__(self, penbit=kwargs.get('penbit', True))


class DS_threeseqabs(QuickDraw, ThreePointAbs):

    def __init__(self, *args, **kwargs):
        QuickDraw.__init__(self, *args, **kwargs)
        ThreePointAbs.__init__(self, penbit=kwargs.get('penbit', True))


class DS_threeseqabs_classcond(QuickDraw, ThreePointAbs):

    def __init__(self, *args, **kwargs):
        QuickDraw.__init__(self, *args, **kwargs)
        ThreePointAbs.__init__(self, penbit=kwargs.get('penbit', True))

    def represent(self, sketch: Sketch):
        label = torch.tensor(sketch.label, dtype=torch.int64)
        return label, super().represent(sketch)

    def collate(batch: list):
        class_batch = torch.stack([c for c, _ in batch], 0)
        _, tpa_batch = ThreePointAbs.collate([tpa for _, tpa in batch])
        return class_batch, tpa_batch


class DS_threeseqdel_classcond(QuickDraw, ThreePointDelta):

    def __init__(self, *args, **kwargs):
        QuickDraw.__init__(self, *args, **kwargs)
        ThreePointDelta.__init__(self, penbit=kwargs.get('penbit', True))

    def represent(self, sketch: Sketch):
        label = torch.tensor(sketch.label, dtype=torch.int64)
        return label, super().represent(sketch)

    def collate(batch: list):
        class_batch = torch.stack([c for c, _ in batch], 0)
        _, tpd_batch = ThreePointDelta.collate([tpd for _, tpd in batch])
        return class_batch, tpd_batch


class DS_threeseqdel_pointcloudcond(QuickDraw, ThreePointDelta_PointCloudCond):

    def __init__(self, *args, **kwargs):
        QuickDraw.__init__(self, *args, **kwargs)
        ThreePointDelta_PointCloudCond.__init__(self, penbit=kwargs.get('penbit', True))


class DS_threeseqabs_pointcloudcond(QuickDraw, ThreePointAbs_PointCloudCond):

    def __init__(self, *args, **kwargs):
        QuickDraw.__init__(self, *args, **kwargs)
        ThreePointAbs_PointCloudCond.__init__(self, penbit=kwargs.get('penbit', True))


class DS_threeseqabs_threeseqabscond(QuickDraw, ThreePointAbs_ThreeSeqAbs):

    def __init__(self, *args, **kwargs):
        QuickDraw.__init__(self, *args, **kwargs)
ThreePointAbs_ThreeSeqAbs.__init__(self, penbit=kwargs.get('penbit', True), 202 | cond_rdp=kwargs.get('cond_rdp', None)) 203 | 204 | def __getitem__(self, i): 205 | if not self.cached: 206 | return self.represent(self.get_sketch(i)) 207 | else: 208 | if len(self.attrs) > 1: 209 | return tuple(torch.from_numpy(self.data[attr][i]) for attr in self.attrs) 210 | else: 211 | # in case we need the same data as cond 212 | d = self.data[self.attrs[0]][i] 213 | return torch.from_numpy(d), torch.from_numpy(d) 214 | 215 | 216 | class DS_threeseqdel_threeseqdelcond(QuickDraw, ThreePointDel_ThreeSeqDel): 217 | 218 | def __init__(self, *args, **kwargs): 219 | QuickDraw.__init__(self, *args, **kwargs) 220 | ThreePointDel_ThreeSeqDel.__init__(self, penbit=kwargs.get('penbit', True), 221 | cond_rdp=kwargs.get('cond_rdp', None)) 222 | 223 | def __getitem__(self, i): 224 | if not self.cached: 225 | return self.represent(self.get_sketch(i)) 226 | else: 227 | if len(self.attrs) > 1: 228 | return tuple(torch.from_numpy(self.data[attr][i]) for attr in self.attrs) 229 | else: 230 | # in case we need the same data as cond 231 | d = self.data[self.attrs[0]][i] 232 | return torch.from_numpy(d), torch.from_numpy(d) 233 | 234 | 235 | if __name__ == '__main__': 236 | class_name_str = eval('DS_' + sys.argv[2]) 237 | ds = class_name_str( 238 | sys.argv[1], 239 | perlin_noise=0., 240 | max_sketches=100000, 241 | max_strokes=25, 242 | penbit=True, 243 | rdp=None 244 | ) 245 | dummy_sample = ds[0] 246 | if not isinstance(dummy_sample, tuple): 247 | n_attr = 1 248 | else: 249 | n_attr = len(dummy_sample) 250 | 251 | samples = [[] for _ in range(n_attr)] 252 | for sam in tqdm(ds): 253 | if n_attr == 1: 254 | sam = (sam, ) 255 | for a in range(n_attr): 256 | if sam[a] is None: 257 | break 258 | samples[a].append(sam[a].numpy()) 259 | samples = [np.array(sams, dtype=np.ndarray) for sams in samples] 260 | attrs = [f'attr{a}' for a in range(n_attr)] 261 | 262 | np.savez(sys.argv[1] + 
f'_{sys.argv[2]}.npz', **dict(zip(attrs, samples))) -------------------------------------------------------------------------------- /data/sketch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from matplotlib.colors import Colormap 5 | from simplification.cutil import simplify_coords_idx 6 | 7 | from data.utils import continuous_noise, resample 8 | 9 | 10 | class Stroke(object): 11 | def __init__(self, stroke, timestamps=None): 12 | super().__init__() 13 | 14 | self.type = type(stroke) 15 | if self.type in [np.ndarray, torch.Tensor]: 16 | self.stroke = stroke 17 | assert isinstance(timestamps, self.type), \ 18 | "stroke & timestamps must have same type" 19 | self.timestamps = timestamps 20 | 21 | if self.timestamps.ndim != 1: 22 | raise AssertionError('timestamps must be 1D array') 23 | 24 | def draw(self, axis=None): 25 | raise NotImplementedError('Use one of the subclasses of Stroke') 26 | 27 | def __len__(self): 28 | return self.stroke.shape[0] 29 | 30 | def tensorize(self): 31 | if self.type is torch.Tensor: 32 | return self.stroke, self.timestamps 33 | else: 34 | return torch.from_numpy(self.stroke.astype(np.float32)), \ 35 | torch.from_numpy(self.timestamps.astype(np.float32)) 36 | 37 | 38 | class PolylineStroke(Stroke): 39 | def __init__(self, stroke, timestamps=None): 40 | 41 | stroke = np.array(stroke).T if isinstance(stroke, list) else stroke 42 | timestamps = np.array(timestamps) if isinstance(timestamps, list) else timestamps 43 | super().__init__(stroke, timestamps) 44 | 45 | def rdp(self, eps=0.01): 46 | is_tensor = isinstance(self.stroke, torch.Tensor) 47 | stroke = self.stroke.data.cpu().numpy() if is_tensor else self.stroke 48 | stroke = np.ascontiguousarray(stroke) 49 | simpl_idx = simplify_coords_idx(stroke, eps) 50 | 51 | self.stroke = self.stroke[simpl_idx] 52 | self.timestamps = self.timestamps[simpl_idx] 53 | 54 | def 
resample(self, granularity): 55 | stroke = self.stroke.numpy() if (self.type is torch.Tensor) else self.stroke 56 | timestamps = self.timestamps.numpy() \ 57 | if (self.type is torch.Tensor) else self.timestamps 58 | 59 | self.stroke, self.timestamps = resample(stroke, timestamps, granularity) 60 | 61 | def jitter(self, seed, noise_level=0.2): 62 | stroke = self.stroke.numpy() if (self.type is torch.Tensor) else self.stroke 63 | self.stroke = continuous_noise(stroke, seed=seed, noise_level=noise_level) 64 | 65 | def move(self, by=np.zeros((1, 2))): 66 | self.stroke = self.stroke + by 67 | 68 | def shift_time(self, to=0.): 69 | self.timestamps = self.timestamps - self.initial_time + to 70 | 71 | def scale_time(self, factor=1.): 72 | self.timestamps = (self.timestamps / self.terminal_time) * factor 73 | 74 | @property 75 | def initial_time(self): 76 | return self.timestamps[0] 77 | 78 | @property 79 | def terminal_time(self): 80 | return self.timestamps[-1] 81 | 82 | @property 83 | def start(self): 84 | return self.stroke[0, :] 85 | 86 | @property 87 | def end(self): 88 | return self.stroke[-1, :] 89 | 90 | def draw(self, axis=None, color='black', linewidth=1, scatter=True): 91 | if axis is None: 92 | fig = plt.figure() 93 | axis = plt.gca() 94 | 95 | stroke = self.stroke.data.cpu().numpy() if (self.type is torch.Tensor) else self.stroke 96 | if not isinstance(color, list): 97 | axis.plot(stroke[:, 0], stroke[:, 1], color=color, linewidth=linewidth) 98 | else: 99 | for i in range(len(self) - 1): 100 | axis.plot(stroke[i:i+2, 0], stroke[i:i+2, 1], color=color[i], linewidth=linewidth, solid_capstyle='round') 101 | 102 | if scatter: 103 | stroke = self.stroke.data.cpu().numpy() if (self.type is torch.Tensor) else self.stroke 104 | if not isinstance(color, list): 105 | axis.scatter(stroke[:, 0], stroke[:, 1], color=color, s=linewidth*2) 106 | else: 107 | for i in range(len(self)): 108 | axis.scatter(stroke[None, i, 0], stroke[None, i, 1], color=color[i], s=linewidth*2) 
109 | 110 | @property 111 | def enclosing_circle_radius(self): 112 | norms = np.linalg.norm(self.stroke, 2, -1) 113 | return norms.max() 114 | 115 | @property 116 | def length(self): 117 | return (((self.stroke[1:, :] - self.stroke[:-1, :])**2).sum(-1)**0.5).sum().item() 118 | 119 | 120 | class Sketch(object): 121 | 122 | def __init__(self, strokes, label=None): 123 | super().__init__() 124 | self.label = label # optional class label 125 | 126 | self.strokes = [] 127 | for s in strokes: 128 | stroke = PolylineStroke(s[:2], s[-1]) 129 | if len(stroke) > 1: 130 | # one point strokes are not tolerable 131 | self.strokes.append(stroke) 132 | 133 | @property 134 | def nstrokes(self): 135 | return len(self.strokes) 136 | 137 | def __len__(self): 138 | return self.nstrokes 139 | 140 | def rdp(self, eps=0.01): 141 | for stroke in self.strokes: 142 | stroke.rdp(eps) 143 | 144 | def resample(self, delta=0.1): 145 | for stroke in self.strokes: 146 | n = max(2, int(stroke.length / delta)) 147 | stroke.resample(n) 148 | 149 | def move(self, to=np.zeros((1, 2))): 150 | move_by = to - self.strokes[0].start 151 | for stroke in self.strokes: 152 | stroke.move(move_by) 153 | 154 | def __getitem__(self, i): 155 | return self.strokes[i] 156 | 157 | def draw(self, axis=None, cla=True, color='black', **kwargs): 158 | if axis is None: 159 | fig = plt.figure() 160 | axis = plt.gca() 161 | 162 | if cla: 163 | axis.cla() 164 | 165 | if not isinstance(color, Colormap): 166 | for stroke in self.strokes: 167 | stroke.draw(axis, color=color, **kwargs) 168 | else: 169 | seg_lens = [len(s) for s in self.strokes] 170 | colors = [color(i / (sum(seg_lens) - 1)) for i in range(sum(seg_lens))] 171 | c = 0 172 | for stroke in self.strokes: 173 | l = len(stroke) 174 | stroke.draw(axis, color=colors[c:c+l], **kwargs) 175 | c += l 176 | 177 | xmin, xmax = axis.get_xlim() 178 | ymin, ymax = axis.get_ylim() 179 | width = xmax - xmin 180 | height = ymax - ymin 181 | xmin, xmax = xmin - 0.1 * width, xmax + 
0.1 * width 182 | ymin, ymax = ymin - 0.1 * height, ymax + 0.1 * height 183 | axis.set_xlim([xmin, xmax]) 184 | axis.set_ylim([ymin, ymax]) 185 | 186 | axis.set_xticks([]) 187 | axis.set_yticks([]) 188 | axis.set_xticklabels([]) 189 | axis.set_yticklabels([]) 190 | 191 | @property 192 | def terminal_time(self): 193 | return self.strokes[-1].terminal_time 194 | 195 | @property 196 | def initial_time(self): 197 | return self.strokes[0].initial_time 198 | 199 | def shift_time(self, to=0.): 200 | initial_time = self.initial_time 201 | for stroke in self.strokes: 202 | stroke.timestamps = stroke.timestamps - initial_time + to 203 | 204 | def scale_time(self, factor=1.): 205 | for stroke in self.strokes: 206 | stroke.timestamps = (stroke.timestamps / self.terminal_time) * factor 207 | 208 | def scale_spatial(self, factor=1.): 209 | enclosing_circle_radius = max([stroke.enclosing_circle_radius for stroke in self.strokes]) 210 | for stroke in self.strokes: 211 | stroke.stroke = (stroke.stroke / enclosing_circle_radius) * factor 212 | 213 | def jitter(self, seed, noise_level=0.2): 214 | for i, stroke in enumerate(self.strokes): 215 | stroke.jitter(seed + i, noise_level) 216 | 217 | def _fill_penup(start, end, granularity): 218 | start = start.unsqueeze(0).repeat(granularity, 1) 219 | end = end.unsqueeze(0).repeat(granularity, 1) 220 | alpha = torch.linspace(0., 1., granularity).unsqueeze(-1) 221 | stroke = start * (1. 
- alpha) + end * alpha 222 | return stroke 223 | 224 | def _add_pen_state(stroke, fill_value=0.): 225 | stroke_plus_pen = torch.cat([ 226 | stroke, 227 | torch.ones(len(stroke), 1, device=stroke.device) * fill_value 228 | ], dim=-1) 229 | return stroke_plus_pen 230 | 231 | def tensorize(self, joining_granularity=20): 232 | seq_strokes, seq_timestamps = [], [] 233 | 234 | current_stroke, current_timestamps = self[0].tensorize() 235 | seq_strokes.append(Sketch._add_pen_state(current_stroke, 0.)) 236 | seq_timestamps.append(current_timestamps) 237 | 238 | for i in range(1, self.nstrokes): 239 | next_stroke, next_timestamps = self[i].tensorize() 240 | joining_stroke = Sketch._fill_penup(current_stroke[-1, :], next_stroke[0, :], 241 | granularity=joining_granularity) 242 | joining_stroke_pen = Sketch._add_pen_state(joining_stroke, 1.) 243 | joining_timestamps = torch.linspace(current_timestamps[-1], next_timestamps[0], len(joining_stroke_pen), 244 | device=joining_stroke_pen.device) 245 | # ignore the first and last one to avoid duplication 246 | seq_strokes.append(joining_stroke_pen[1:-1, ...]) 247 | seq_timestamps.append(joining_timestamps[1:-1]) 248 | 249 | next_stroke_pen = Sketch._add_pen_state(next_stroke, 0.) 
250 | seq_strokes.append(next_stroke_pen) 251 | seq_timestamps.append(next_timestamps) 252 | 253 | current_stroke, current_timestamps = next_stroke, next_timestamps 254 | 255 | return torch.cat(seq_strokes, 0), torch.cat(seq_timestamps, 0) 256 | 257 | def from_threeseqabs(seq, ts=None): 258 | # `seq` can be (N x 3) array, either np.ndarray or torch.Tensor 259 | n_points, _ = seq.shape 260 | seq, penbits = seq[:, :-1], seq[:, -1] 261 | 262 | dummy_timestamps = ts if ts is not None else np.linspace(0., 1., n_points) 263 | seq = np.concatenate((seq, dummy_timestamps[:, None]), axis=-1) 264 | 265 | split_locations, = penbits.nonzero() 266 | strokes = np.split(seq, split_locations + 1, axis=0) 267 | 268 | return Sketch([strk.T.tolist() for strk in strokes]) -------------------------------------------------------------------------------- /data/unpack_ndjson.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Unpack Quick Draw OR DiDi data to make data loading more efficient. 3 | Otherwise, loading a full '.ndjson' file takes a while. 
4 | 5 | For QD, "python unpack_ndjson.py --data_folder /path/to/QD/raw -c cat -o /path/to/empty/dir" 6 | For DiDi, "python unpack_ndjson.py --data_folder /path/to/DiDi -c diagrams_wo_text_20200131 -o /path/to/empty/dir" 7 | Author: Ayan Das 8 | ''' 9 | 10 | import os 11 | import pickle 12 | import argparse 13 | import ndjson as nj 14 | from tqdm import tqdm 15 | 16 | 17 | def main(args): 18 | data_path = os.path.join(args.data_folder, args.category + '.ndjson') 19 | with open(data_path, 'r') as f: 20 | data = nj.load(f) 21 | 22 | out_path = os.path.join(args.out_folder, args.category) 23 | 24 | if not os.path.exists(out_path): 25 | os.makedirs(out_path) 26 | 27 | for i, sample in enumerate(tqdm(data)): 28 | out_file_path = os.path.join(out_path, f'sketch_{i}') 29 | with open(out_file_path, 'wb') as f: 30 | pickle.dump(sample, f) 31 | 32 | if i + 1 >= args.max_sketches: # stop once max_sketches files are written 33 | break 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--data_folder', type=str, required=True, 39 | help='QD folder of raw data (.ndjson)') 40 | parser.add_argument('-c', '--category', type=str, required=True, help='name of a category') 41 | parser.add_argument('-o', '--out_folder', type=str, required=True, help='output folder (empty)') 42 | parser.add_argument('-m', '--max_sketches', type=int, required=False, default=10000) 43 | args = parser.parse_args() 44 | 45 | main(args) 46 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.special import comb as choose 4 | 5 | from noise import pnoise2 6 | from shapely.geometry import LineString 7 | 8 | 9 | def resample(seq: np.ndarray, timestamps, granularity): 10 | # seq should be (N x 2) numpy array 11 | seq = LineString(seq) 12 | distances = np.linspace(0, seq.length, granularity) 13 | seq_resampled = 
LineString([seq.interpolate(d) for d in distances]) 14 | seq_resampled = np.array([seq_resampled.xy[0], seq_resampled.xy[1]]).T 15 | ts_resampled = np.linspace(timestamps[0], timestamps[-1], granularity) 16 | 17 | return seq_resampled, ts_resampled 18 | 19 | 20 | def continuous_noise(stroke: np.ndarray, seed=0, noise_level=0.3): 21 | ''' 22 | The given stroke is used as a seed to generate a continuous noise stroke, 23 | which is added to the original stroke; used as part of augmentation. 24 | Implementation uses Perlin noise. 25 | ''' 26 | 27 | if noise_level == 0.: 28 | return stroke 29 | 30 | noise_on_stroke = np.zeros_like(stroke) 31 | stroke_ = stroke + seed 32 | for i in range(len(stroke)): 33 | n1 = pnoise2(*stroke_[i, ...] + 5) 34 | n2 = pnoise2(*stroke_[i, ...] - 5) 35 | noise_on_stroke[i, ...] = [n1, n2] 36 | 37 | noise_on_stroke = noise_on_stroke - noise_on_stroke.mean(0) 38 | return noise_on_stroke * noise_level + stroke 39 | 40 | 41 | def discrete_noise(stroke: np.ndarray, seed=0, noise_level=0.3): 42 | '''Standard random Gaussian jittering, applied independently to each point.''' 43 | 44 | if noise_level == 0.: 45 | return stroke 46 | 47 | old_state = np.random.get_state() 48 | np.random.seed(seed) 49 | stroke = stroke + np.random.randn(*stroke.shape) * noise_level 50 | np.random.set_state(old_state) 51 | return stroke 52 | 53 | 54 | def draw_bezier(ctrlPoints, nPointsCurve=100): 55 | ''' 56 | Draws a Bezier curve with given control points. 
57 | 58 | ctrlPoints: shape (n+1, 2) matrix containing all control points 59 | nPointsCurve: granularity of the Bezier curve 60 | ''' 61 | 62 | def bezier_matrix(degree): 63 | m = degree 64 | Q = np.zeros((degree + 1, degree + 1)) 65 | for i in range(degree + 1): 66 | for j in range(degree + 1): 67 | if (0 <= (i+j)) and ((i+j) <= degree): 68 | Q[i, j] = choose(m, j) * choose(m-j, m-i-j) * ((-1)**(m-i-j)) 69 | return Q 70 | 71 | def T(ts: np.ndarray, d: int): 72 | # 'ts' is a vector (np.array) of time points 73 | ts = ts[..., np.newaxis] 74 | Q = tuple(ts**n for n in range(d, -1, -1)) 75 | return np.concatenate(Q, 1) 76 | 77 | nCtrlPoints, _ = ctrlPoints.shape 78 | 79 | ts = np.linspace(0., 1., num=nPointsCurve) 80 | 81 | curve = np.matmul(T(ts, nCtrlPoints - 1), bezier_matrix(nCtrlPoints-1) @ ctrlPoints) 82 | 83 | return curve 84 | -------------------------------------------------------------------------------- /gifs/0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/0.gif -------------------------------------------------------------------------------- /gifs/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/1.gif -------------------------------------------------------------------------------- /gifs/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/2.gif -------------------------------------------------------------------------------- /gifs/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/3.gif 
-------------------------------------------------------------------------------- /gifs/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/4.gif -------------------------------------------------------------------------------- /gifs/5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/5.gif -------------------------------------------------------------------------------- /gifs/6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/6.gif -------------------------------------------------------------------------------- /gifs/7.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/7.gif -------------------------------------------------------------------------------- /gifs/8.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dasayan05/chirodiff/e9e2ecc88e746f0d99e2008da31895548bfd5d3c/gifs/8.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import typing 4 | import contextlib 5 | import numpy as np 6 | import matplotlib 7 | from matplotlib.cm import get_cmap 8 | matplotlib.rcParams['axes.edgecolor'] = '#aaaaaa' 9 | from enum import Enum 10 | 11 | import torch 12 | from torch_ema import ExponentialMovingAverage as EMA 13 | import pytorch_lightning as pl 14 | from pytorch_lightning.utilities.cli import LightningCLI 15 | from 
pytorch_lightning.callbacks import ( 16 | LearningRateMonitor, 17 | TQDMProgressBar, 18 | ModelCheckpoint 19 | ) 20 | 21 | from data.dm import ReprType, GenericDM 22 | from data.sketch import Sketch 23 | from models.score import ( 24 | ScoreFunc, 25 | TransformerSetFeature, 26 | BiRNNEncoderFeature, 27 | ClassEmbedding 28 | ) 29 | from utils import ( 30 | positionalencoding1d, 31 | random_fourier_encoding_dyn, 32 | make_pad_mask_for_transformer, 33 | openai_cosine_schedule, 34 | linear_schedule, 35 | CustomViz, 36 | ) 37 | 38 | 39 | class SketchDiffusion(pl.LightningModule): 40 | 41 | class ModelType(Enum): 42 | birnn = "birnn" 43 | transformer = "transformer" 44 | 45 | class SamplingAlgo(Enum): 46 | ddpm = "ddpm" 47 | ddim = "ddim" 48 | fddim = "fddim" # only for private use 49 | 50 | class NoiseSchedule(Enum): 51 | linear = "linear" 52 | cosine = "cosine" 53 | 54 | class TimeEmbedding(Enum): 55 | sinusoidal = "sinusoidal" 56 | randomfourier = "randomfourier" 57 | 58 | class VizProcess(Enum): 59 | forward = "forward" 60 | backward = "backward" 61 | both = "both" 62 | 63 | class Parameterization(Enum): 64 | mu = "mu" 65 | eps = "eps" 66 | 67 | def __init__(self, 68 | repr: ReprType = ReprType.threeseqdel, 69 | modeltype: ModelType = ModelType.transformer, 70 | time_embedding: TimeEmbedding = TimeEmbedding.sinusoidal, 71 | vae_weight: float = 0., 72 | vae_kl_anneal_start: int = 200_000, 73 | vae_kl_anneal_end: int = 400_000, 74 | num_classes: typing.Optional[int] = None, 75 | optim_ema: bool = True, 76 | optim_sched: str = 'steplr', 77 | optim_lr: float = 1.e-4, 78 | optim_decay: float = 1.e-2, 79 | optim_gamma: float = 0.9995, 80 | optim_warmup: int = 3000, 81 | optim_interval: str = 'step', 82 | optim_div_factor: int = 3, 83 | arch_head: int = 4, 84 | arch_layer: int = 4, 85 | arch_internal: int = 64, 86 | arch_layer_cond: typing.Optional[int] = None, 87 | arch_internal_cond: typing.Optional[int] = None, 88 | arch_pe_dim: int = 2, 89 | arch_n_cond_latent: int = 
32, 90 | arch_causal: bool = False, 91 | arch_dropout: float = 0.1, 92 | arch_parameterization: Parameterization = Parameterization.eps, # unused 93 | noise_low_noise: float = 1e-4, 94 | noise_high_noise: float = 2e-2, 95 | noise_schedule: NoiseSchedule = NoiseSchedule.linear, 96 | noise_T: int = 1000, 97 | test_variance_strength: float = 0.5, 98 | test_sampling_algo: SamplingAlgo = SamplingAlgo.ddpm, 99 | test_partial_T: typing.Optional[int] = None, 100 | test_recon: bool = True, 101 | test_interp: bool = False, 102 | test_n_viz: int = 10, 103 | test_n_sample_viz: int = 10, 104 | test_viz_fig_compact: bool = True, 105 | text_viz_process: VizProcess = VizProcess.both, 106 | test_save_everything: bool = True 107 | ) -> None: 108 | """ 109 | Diffusion Model for Sketches (both set and sequential representation) 110 | 111 | Args: 112 | repr: POINTCLOUD for sets and THREEPOINT for sequence 113 | arch: architecture params of transformer/RNN (head, layer, inp_n_emb, ff_dim, pe_dim) 114 | noise: noise parameters (number of scales, low and high noise variance, T) 115 | test: which test to do (reconstruction, interpolation etc) 116 | """ 117 | 118 | super().__init__() 119 | self.save_hyperparameters() 120 | self.hp = self.hparams 121 | 122 | self.cond = self.hp.repr in [ 123 | ReprType.threeseqdel_pointcloudcond, 124 | ReprType.threeseqdel_classcond, 125 | ReprType.threeseqabs_classcond, 126 | ReprType.threeseqabs_pointcloudcond, 127 | ReprType.threeseqabs_threeseqabscond 128 | ] 129 | 130 | if self.hp.vae_weight != 0.: 131 | assert self.hp.repr.value.endswith('pointcloudcond') or self.hp.repr.value.endswith('threeseqabscond'), \ 132 | "VAE only allowed in bottlenecked conditional models" 133 | 134 | self.elem_dim = 3 135 | 136 | self.pe_dim = self.hp.arch_pe_dim 137 | 138 | n_cond_dim = 0 139 | if self.cond: 140 | n_cond_dim = self.hp.arch_n_cond_latent 141 | 142 | self.seq_pe_dim = self.pe_dim if self.hp.modeltype == self.ModelType.transformer else 0 143 | 144 | if 
self.cond: 145 | if self.hp.repr.value.endswith('pointcloudcond'): 146 | self.encoder = TransformerSetFeature( 147 | self.hp.arch_internal_cond or self.hp.arch_internal, 148 | self.hp.arch_layer_cond or self.hp.arch_layer, 149 | self.hp.arch_head, 150 | n_cond_dim, 151 | dropout=self.hp.arch_dropout, 152 | vae_weight=self.hp.vae_weight 153 | ) 154 | elif self.hp.repr == ReprType.threeseqabs_threeseqabscond: 155 | self.encoder = BiRNNEncoderFeature( 156 | self.hp.arch_internal_cond or self.hp.arch_internal, 157 | self.hp.arch_layer_cond or self.hp.arch_layer, 158 | n_cond_dim, 159 | dropout=self.hp.arch_dropout, 160 | vae_weight=self.hp.vae_weight 161 | ) 162 | elif self.hp.repr == ReprType.threeseqdel_classcond or self.hp.repr == ReprType.threeseqabs_classcond: 163 | assert self.hp.num_classes is not None, "class conditional model but num_classes == 0" 164 | self.encoder = ClassEmbedding(self.hp.num_classes, n_cond_dim) 165 | else: 166 | raise NotImplementedError('unknown conditioning type') 167 | 168 | self.scorefn = ScoreFunc( 169 | self.hp.modeltype.value, 170 | # kwargs go here onwards 171 | inp_n_features=self.elem_dim * 2 - 1, # concat complementary repr too 172 | out_n_features=self.elem_dim, 173 | time_pe_features=self.pe_dim, 174 | seq_pe_features=self.seq_pe_dim, 175 | n_cond_features=n_cond_dim, 176 | n_internal=self.hp.arch_internal, 177 | n_head=self.hp.arch_head, 178 | n_layer=self.hp.arch_layer, 179 | causal=self.hp.arch_causal, 180 | dropout=self.hp.arch_dropout 181 | ) 182 | if self.hp.optim_ema: 183 | self.ema = EMA([ 184 | *self.scorefn.parameters(), 185 | *(self.encoder.parameters() if self.cond else []) 186 | ], decay=0.9999) 187 | 188 | self.register_buffer("pe_proj_W", 189 | torch.randn(self.pe_dim // 2, 1, requires_grad=False), persistent=True 190 | ) 191 | if self.seq_pe_dim > 0: 192 | self.register_buffer("seq_proj_W", 193 | torch.randn(self.seq_pe_dim // 2, 1, requires_grad=False), persistent=True 194 | ) 195 | 196 | # pre-computing all 
betas and alphas 197 | schedule_generator = { 198 | SketchDiffusion.NoiseSchedule.linear: linear_schedule, 199 | SketchDiffusion.NoiseSchedule.cosine: openai_cosine_schedule 200 | }[self.hp.noise_schedule] 201 | betas, alphas, alpha_bar, sqrt_alpha_bar, sqrt_one_min_alpha_bar, beta_tilde = \ 202 | schedule_generator( 203 | self.hp.noise_T, 204 | self.hp.noise_low_noise * 1000 / self.hp.noise_T, 205 | self.hp.noise_high_noise * 1000 / self.hp.noise_T, 206 | ) 207 | self.register_buffer("betas", torch.from_numpy(betas), persistent=False) 208 | self.register_buffer("alphas", torch.from_numpy(alphas), persistent=False) 209 | self.register_buffer("alpha_bar", torch.from_numpy(alpha_bar), persistent=False) 210 | self.register_buffer("sqrt_alpha_bar", torch.from_numpy(sqrt_alpha_bar), persistent=False) 211 | self.register_buffer("sqrt_one_min_alpha_bar", torch.from_numpy(sqrt_one_min_alpha_bar), persistent=False) 212 | self.register_buffer("beta_tilde", torch.from_numpy(beta_tilde), persistent=False) 213 | 214 | def to(self, *args, **kwargs): 215 | ret = super().to(*args, **kwargs) 216 | if self.device.index == 0 and self.hp.optim_ema: 217 | self.ema.to(self.device) 218 | return ret 219 | 220 | def on_fit_start(self) -> None: 221 | self.on_test_start() # needed for testing while training 222 | 223 | def on_before_zero_grad(self, optimizer) -> None: 224 | if self.device.index == 0 and self.hp.optim_ema: 225 | self.ema.update([ 226 | *self.scorefn.parameters(), 227 | *(self.encoder.parameters() if self.cond else []) 228 | ]) 229 | 230 | def on_save_checkpoint(self, checkpoint: dict) -> None: 231 | if self.hp.optim_ema: 232 | checkpoint["ema"] = self.ema.state_dict() 233 | 234 | def on_load_checkpoint(self, checkpoint) -> None: 235 | if self.hp.optim_ema: 236 | self.ema.load_state_dict(checkpoint["ema"]) 237 | 238 | @contextlib.contextmanager 239 | def ema_average(self, activate=True): 240 | if activate: 241 | with self.ema.average_parameters() as ctx: 242 | yield ctx 243 | 
else: 244 | with contextlib.nullcontext() as ctx: 245 | yield ctx 246 | 247 | def stdg_noise_seeded(self, *dims, seed: typing.Optional[int] = None): 248 | if seed is not None: 249 | _rngstate = torch.get_rng_state() 250 | torch.manual_seed(seed) 251 | _tmp = torch.randn(*dims, device=self.device) 252 | if seed is not None: 253 | torch.set_rng_state(_rngstate) 254 | return _tmp 255 | 256 | def create_batch_with_utilities(self, padded_seq, lens, seed=None): 257 | # padded_seq: (BxTxF) shape 258 | # lens: (B,) shaped long tensor to denote original length of each sample 259 | batch_size, = lens.shape 260 | padded_seq, timestamps = padded_seq[..., :self.elem_dim], padded_seq[..., self.elem_dim:] 261 | 262 | batch = {} # Keys: noise_target, timestamps, lens, noise_t, noisy_points, t 263 | 264 | # different 't's for different sample in the batch 265 | t = torch.randint(1, self.hp.noise_T + 1, size=(batch_size, )) 266 | 267 | g_noise = self.stdg_noise_seeded(*padded_seq.shape, seed=seed) 268 | 269 | batch['timestamps'] = timestamps 270 | batch['lens'] = lens 271 | batch['noise_t'] = self.pe[t - 1, :] 272 | batch['t'] = t - 1 273 | batch['noisy_points'] = padded_seq * self.sqrt_alpha_bar[t - 1, None, None] \ 274 | + g_noise * self.sqrt_one_min_alpha_bar[t - 1, None, None] 275 | batch['target'] = g_noise 276 | 277 | return batch 278 | 279 | def ncsn_loss(self, score, noise_target, lens, t): 280 | pad_mask = make_pad_mask_for_transformer(lens, total_length=score.shape[1], device=lens.device) 281 | unreduced_loss = (score - noise_target).pow(2).mean(-1) 282 | masked_loss = (unreduced_loss * (~pad_mask).float()) / lens.unsqueeze(-1) 283 | per_sample_loss = masked_loss.sum(-1) # sum along length since already divided by lengths 284 | return per_sample_loss.mean() 285 | 286 | def configure_optimizers(self): 287 | optim = torch.optim.AdamW(self.parameters(), 288 | lr=self.hp.optim_lr, 289 | weight_decay=self.hp.optim_decay) 290 | if self.hp.optim_sched == 'steplr': 291 | sched = 
torch.optim.lr_scheduler.StepLR(optim, 292 | step_size=1, 293 | gamma=self.hp.optim_gamma 294 | ) 295 | elif self.hp.optim_sched == 'onecyclelr': 296 | steps_per_epoch = len(self.trainer.datamodule.train_dataset) \ 297 | // self.trainer.datamodule.batch_size 298 | total_epochs = self.trainer.max_epochs 299 | total_steps = steps_per_epoch * total_epochs 300 | total = total_epochs if self.hp.optim_interval == 'epoch' else total_steps 301 | warmup_fraction = self.hp.optim_warmup / total 302 | sched = torch.optim.lr_scheduler.OneCycleLR(optim, 303 | max_lr=self.hp.optim_lr, 304 | total_steps=total, 305 | anneal_strategy='linear', 306 | cycle_momentum=True, 307 | pct_start=warmup_fraction, 308 | div_factor=self.hp.optim_div_factor, 309 | final_div_factor=1000 310 | ) 311 | else: 312 | raise NotImplementedError('scheduler not known/implemented') 313 | 314 | return { 315 | 'optimizer': optim, 316 | 'lr_scheduler': { 317 | 'scheduler': sched, 318 | 'frequency': 1, 319 | 'interval': self.hp.optim_interval 320 | } 321 | } 322 | 323 | def create_posvel_aug_input(self, points): 324 | if self.hp.repr.value.startswith('threeseqdel'): 325 | points_vel = points 326 | points_pos = torch.cumsum(points[..., :-1], dim=1) 327 | elif self.hp.repr.value.startswith('threeseqabs'): 328 | points_vel = torch.cat([ 329 | points[:, 0, None, :-1], 330 | (points[:, 1:, :-1] - points[:, :-1, :-1]) 331 | ], 1) 332 | points_pos = points 333 | else: 334 | raise NotImplementedError('ReprType not implemented') 335 | 336 | return points_pos, points_vel 337 | 338 | def forward(self, noisy_points, seq_pe, lens, noise_t, cond_latent): 339 | noisy_points_pos, noisy_points_vel = self.create_posvel_aug_input(noisy_points) 340 | 341 | if self.hp.modeltype == SketchDiffusion.ModelType.transformer: 342 | origin = torch.zeros(noisy_points.size(0), 1, 3, dtype=self.dtype, device=self.device, requires_grad=False) 343 | noisy_points_pos = torch.cat([origin[..., :noisy_points_pos.shape[-1]], noisy_points_pos], 1) 
344 | noisy_points_vel = torch.cat([origin[..., :noisy_points_vel.shape[-1]], noisy_points_vel], 1) 345 | seq_pe = torch.cat([self._create_seq_embeddings(origin[..., :1]), seq_pe], dim=1) # add origin timestamp 346 | lens = lens + 1 # due to the added origin 347 | 348 | with self.ema_average(not self.training and self.hp.optim_ema): 349 | out = self.scorefn((noisy_points_pos, noisy_points_vel), seq_pe, lens, noise_t, cond_latent) 350 | 351 | return out 352 | 353 | def _create_seq_embeddings(self, timestamps): 354 | if self.seq_pe_dim > 0: 355 | batch_size, max_len, _ = timestamps.shape 356 | timestamps = timestamps.permute(2, 0, 1) 357 | temb = random_fourier_encoding_dyn(timestamps.view(1, batch_size * max_len), self.seq_proj_W, scale=4.) 358 | return temb.view(batch_size, max_len, self.seq_pe_dim) 359 | else: 360 | return None 361 | 362 | def encode(self, *args): 363 | if self.cond: 364 | with self.ema_average(not self.training and self.hp.optim_ema): 365 | return self.encoder(*args) 366 | else: 367 | return None, 0. 368 | 369 | def training_step(self, batch, batch_idx): 370 | cond_batch, batch = batch 371 | 372 | batch = self.create_batch_with_utilities(*batch) 373 | cond_latent, kl_loss = self.encode(cond_batch) 374 | score = self(batch['noisy_points'], self._create_seq_embeddings(batch['timestamps']), 375 | batch['lens'], batch['noise_t'], cond_latent) 376 | loss = self.ncsn_loss(score, batch['target'], batch['lens'], batch['t']) 377 | self.log('train/loss', loss, prog_bar=True) 378 | if self.hp.vae_weight != 0.: 379 | kl_loss = kl_loss.mean() 380 | self.log('train/kl', kl_loss, prog_bar=False) 381 | kl_annealing_factor = min(max(self.global_step - self.hp.vae_kl_anneal_start, 0.) / \ 382 | (self.hp.vae_kl_anneal_end - self.hp.vae_kl_anneal_start), 1.) 383 | self.log('train/kl_factor', kl_annealing_factor, prog_bar=False) 384 | else: 385 | kl_annealing_factor = 0. 
386 | return loss + \ 387 | self.hp.vae_weight * kl_annealing_factor * kl_loss 388 | 389 | def validation_step(self, batch, batch_idx): 390 | loss = self.training_step(batch, batch_idx) 391 | 392 | # on-the-fly testing while training 393 | if batch_idx == 0 and (self.current_epoch + 0) % 300 == 0 and self.device.index == 0: 394 | save_file_path = os.path.join(self.trainer.log_dir, 395 | f"ddpm1.pdf") 396 | ret_dict = self.reconstruction(batch, SketchDiffusion.SamplingAlgo.ddpm, langevin_strength=1.) 397 | self.fig.savefig(save_file_path, bbox_inches='tight') 398 | self.cache_reverse_process(ret_dict["all"], -1, ret_dict["lens"], idx=batch_idx, prefix='ddpm1') 399 | 400 | save_file_path = os.path.join(self.trainer.log_dir, 401 | f"ddpm.5.pdf") 402 | ret_dict = self.reconstruction(batch, SketchDiffusion.SamplingAlgo.ddpm, langevin_strength=0.5) 403 | self.fig.savefig(save_file_path, bbox_inches='tight') 404 | self.cache_reverse_process(ret_dict["all"], -1, ret_dict["lens"], idx=batch_idx, prefix='ddpm.5') 405 | 406 | save_file_path = os.path.join(self.trainer.log_dir, 407 | f"ddim_reco.pdf") 408 | ret_dict = self.reconstruction(batch, SketchDiffusion.SamplingAlgo.ddim, langevin_strength=0.) 
409 | self.fig.savefig(save_file_path, bbox_inches='tight') 410 | self.cache_reverse_process(ret_dict["all"], -1, ret_dict["lens"], idx=batch_idx, prefix='ddim_reco') 411 | 412 | save_file_path = os.path.join(self.trainer.log_dir, 413 | f"ddim_gen.pdf") 414 | ret_dict = self.reconstruction(batch, SketchDiffusion.SamplingAlgo.ddim, langevin_strength=0., generation=True) 415 | self.fig.savefig(save_file_path, bbox_inches='tight') 416 | self.cache_reverse_process(ret_dict["all"], -1, ret_dict["lens"], idx=batch_idx, prefix='ddim_gen') 417 | 418 | return loss 419 | 420 | def validation_epoch_end(self, losses_for_batches) -> None: 421 | valid_loss = sum(losses_for_batches) / len(losses_for_batches) 422 | self.log('valid/loss', valid_loss, prog_bar=True) 423 | 424 | def on_test_start(self) -> None: 425 | ts = torch.linspace(1, self.hp.noise_T, self.hp.noise_T, 426 | dtype=self.dtype, device=self.device) / self.hp.noise_T 427 | self.pe = random_fourier_encoding_dyn(ts[None, ...], self.pe_proj_W, scale=4.) \ 428 | if self.hp.time_embedding == SketchDiffusion.TimeEmbedding.randomfourier else \ 429 | positionalencoding1d(self.pe_dim, self.hp.noise_T, N=self.hp.noise_T, 430 | dtype=self.dtype, device=self.device) 431 | 432 | n_viz = self.hp.test_n_viz * 2 if self.hp.text_viz_process == SketchDiffusion.VizProcess.both else self.hp.test_n_viz 433 | cviz = CustomViz(self.hp.test_n_sample_viz, n_viz, compact_mode=self.hp.test_viz_fig_compact) 434 | self.fig, self.ax = cviz, cviz 435 | 436 | def cache_reverse_process(self, all_points_t, t, lens, idx, prefix='gen'): 437 | # npz_save_path = os.path.join(self.trainer.log_dir, f'{prefix}_rev_{idx}.npz') 438 | # with open(npz_save_path, 'wb') as f: 439 | # np.savez(f, reverse=all_points_t.cpu().numpy(), lens=lens.cpu().numpy()) 440 | samples = all_points_t[t, ...] 
441 | samples = torch.split(samples, self.ax.shape[0], dim=0) 442 | lens = torch.split(lens, self.ax.shape[0], dim=0) 443 | for j in range(self.ax.shape[1]): 444 | try: 445 | self.draw_on_seq(samples[j], lens[j], j) 446 | except Exception: # blank out the column if there is nothing to draw or drawing fails 447 | for i in range(self.ax.shape[0]): 448 | self.ax[i, j].cla() 449 | self.ax[i, j].axis('off') 450 | save_file_path = os.path.join(self.trainer.log_dir, f'{prefix}_{idx}.svg') 451 | self.fig.savefig(save_file_path, bbox_inches='tight') 452 | 453 | def test_step(self, batch, batch_idx): 454 | if self.hp.test_recon: 455 | save_file_path = os.path.join(self.trainer.log_dir, f'diff_{batch_idx}.svg') 456 | rev_dict = self.reconstruction(batch, self.hp.test_sampling_algo, self.hp.test_variance_strength, 457 | generation=True, partial_t=self.hp.test_partial_T) 458 | self.fig.savefig(save_file_path, bbox_inches='tight') 459 | if self.hp.test_save_everything: 460 | _, (vels, lens) = batch 461 | vels, ts = vels[..., :self.elem_dim], vels[..., self.elem_dim:] 462 | orig, orig_len = self.velocity_to_position(vels, lens) 463 | # self.cache_reverse_process(orig[None, ...], -1, orig_len, idx=batch_idx, prefix='orig') 464 | self.cache_reverse_process(rev_dict["all"], -1, rev_dict["lens"], idx=batch_idx, prefix='gen') 465 | 466 | if self.hp.test_interp: 467 | save_file_path = os.path.join(self.trainer.log_dir, f'interp_{batch_idx}.svg') 468 | _ = self.interpolation(batch, self.hp.test_sampling_algo, langevin_strength=0.) 
469 | self.fig.savefig(save_file_path, bbox_inches='tight') 470 | 471 | def velocity_to_position(self, points, lens): 472 | B, _, _ = points.shape 473 | 474 | points = torch.cat([ 475 | torch.zeros(B, 1, self.elem_dim, dtype=points.dtype, device=points.device), 476 | points 477 | ], dim=1) 478 | lens = lens + 1 # there is the extra initial point along length 479 | 480 | if self.hp.repr.value.startswith('threeseqdel'): 481 | # last one is pen-up bit -- leave it as is 482 | points[..., :-1] = torch.cumsum(points[..., :-1], dim=1) 483 | else: 484 | # this incorporates THREESEQABS 485 | pass 486 | 487 | points[..., -1][points[..., -1] > 0.8] = 1. 488 | points[..., -1][points[..., -1] < 0.8] = 0. 489 | 490 | return points, lens 491 | 492 | def draw_on_seq(self, points, lens, t_): 493 | points = points.detach().cpu().numpy() 494 | lens = lens.cpu().numpy() 495 | 496 | cm = get_cmap('copper') # I like this one 497 | for b in range(self.hp.test_n_sample_viz): 498 | sample_seq: Sketch = Sketch.from_threeseqabs(points[b, :lens[b], :]) 499 | sample_seq.draw(self.ax[b, t_], color=cm, cla=True, scatter=False) 500 | 501 | def forward_diffusion(self, velocs, lens, draw=True, end_t=None): 502 | viz_t = np.linspace(0, end_t or self.hp.noise_T, self.hp.test_n_viz, dtype=np.int64) 503 | 504 | if draw: # the original sample 505 | points, points_len = self.velocity_to_position(velocs, lens) 506 | self.draw_on_seq(points, points_len, self.t_) 507 | self.t_ += 1 508 | 509 | for t in viz_t[1:]: 510 | g_noise = self.stdg_noise_seeded(*velocs.shape) 511 | 512 | velocs_t = velocs * self.sqrt_alpha_bar[t - 1, None, None] \ 513 | + g_noise * self.sqrt_one_min_alpha_bar[t - 1, None, None] 514 | 515 | if draw: 516 | points_t, points_len = self.velocity_to_position(velocs_t, lens) 517 | self.draw_on_seq(points_t, points_len, self.t_) 518 | self.t_ += 1 519 | 520 | return velocs_t 521 | 522 | def reverse_purturb_DDPM(self, points, timestamps, t, lens, cond_latent, steps, noise_weight=1.): 523 | 
now, now_index = steps[t], steps[t] - 1 524 | 525 | score = self(points, timestamps, lens, self.pe[now_index, :].repeat(points.shape[0], 1), cond_latent) 526 | k1 = 1. / torch.sqrt(self.alphas[now_index]) 527 | k2 = (1. - self.alphas[now_index]) / self.sqrt_one_min_alpha_bar[now_index] 528 | mean = k1 * (points - k2 * score) 529 | 530 | gen_noise = self.stdg_noise_seeded(*points.shape) * torch.sqrt(self.beta_tilde[now_index]) \ 531 | if now > 1 else 0. 532 | 533 | points = mean + gen_noise * noise_weight 534 | return points 535 | 536 | def reverse_purturb_DDIM(self, points, timestamps, t, lens, cond_latent, steps, noise_weight=0.): 537 | now, now_index = steps[t], steps[t] - 1 538 | 539 | score = self(points, timestamps, lens, self.pe[now_index, :].repeat(points.shape[0], 1), cond_latent) 540 | x0_pred = (points - self.sqrt_one_min_alpha_bar[now_index] * score) \ 541 | / self.sqrt_alpha_bar[now_index] 542 | 543 | if now > 1: 544 | prev, prev_index = steps[t + 1], steps[t + 1] - 1 545 | 546 | # generalized version of DDIM sampler, with explicit \sigma_t 547 | s1 = self.sqrt_one_min_alpha_bar[prev_index] / self.sqrt_one_min_alpha_bar[now_index] 548 | s2 = torch.sqrt(1. - self.alpha_bar[now_index] / self.alpha_bar[prev_index]) 549 | sigma = (s1 * s2) * noise_weight # additional control for the noise 550 | 551 | gen_noise = self.stdg_noise_seeded(*points.shape) 552 | 553 | points = self.sqrt_alpha_bar[prev_index] * x0_pred \ 554 | + torch.sqrt(1. 
- self.alpha_bar[prev_index] - sigma**2) * score \ 555 | + gen_noise * sigma 556 | else: 557 | points = x0_pred 558 | 559 | return points 560 | 561 | def forward_purturb_DDIM(self, points, timestamps, t, lens, cond_latent, steps, noise_weight=1.): 562 | # DDIM's reverse of the reverse process -- integrating the ODE backwards 563 | now, now_index = steps[t], steps[t] - 1 564 | prev, prev_index = steps[t] - 1, steps[t] - 2 565 | 566 | score = self(points, timestamps, lens, self.pe[prev_index, :].repeat(points.shape[0], 1), cond_latent) \ 567 | if prev != 0 else 0. 568 | 569 | xT_pred = (points - self.sqrt_one_min_alpha_bar[prev_index] * score) \ 570 | / (self.sqrt_alpha_bar[prev_index] if prev != 0 else 1.) 571 | 572 | points = self.sqrt_alpha_bar[now_index] * xT_pred + self.sqrt_one_min_alpha_bar[now_index] * score 573 | return points 574 | 575 | def reverse_diffusion(self, points, timestamps, lens, cond_latent, sampling_algo, langevin_strength, draw=True, start_t=None): 576 | veloc_t = points 577 | 578 | if start_t is not None: 579 | assert sampling_algo == SketchDiffusion.SamplingAlgo.ddpm, \ 580 | 'partially stopping diffusion makes sense only for stochastic sampler' 581 | assert start_t <= self.hp.noise_T, f"partial stopping time must not exceed T={self.hp.noise_T}" 582 | 583 | inference_steps, sampling_fn = { 584 | SketchDiffusion.SamplingAlgo.ddpm: ( 585 | np.linspace(start_t or self.hp.noise_T, 1, start_t or self.hp.noise_T, dtype=np.int64), 586 | SketchDiffusion.reverse_purturb_DDPM 587 | ), 588 | SketchDiffusion.SamplingAlgo.ddim: ( 589 | np.linspace(self.hp.noise_T, 1, self.hp.noise_T, dtype=np.int64), 590 | SketchDiffusion.reverse_purturb_DDIM 591 | ), 592 | SketchDiffusion.SamplingAlgo.fddim: ( 593 | np.linspace(1, self.hp.noise_T, self.hp.noise_T, dtype=np.int64), 594 | SketchDiffusion.forward_purturb_DDIM 595 | ) 596 | }[sampling_algo] 597 | 598 | viz_t = np.linspace(self.hp.noise_T, 1, self.hp.test_n_viz, dtype=np.int64) 599 | 600 | 
points_t_all_steps = [] 601 | for t in range(inference_steps.shape[0]): 602 | veloc_t = sampling_fn(self, veloc_t, timestamps, t, lens, cond_latent, 603 | inference_steps, noise_weight=langevin_strength) 604 | points_t, points_len = self.velocity_to_position(veloc_t, lens) 605 | if inference_steps[t] in viz_t: 606 | if draw: 607 | self.draw_on_seq(points_t, points_len, self.t_) 608 | self.t_ += 1 609 | 610 | if self.hp.test_save_everything: 611 | points_t_all_steps.append(points_t) 612 | 613 | return { 614 | "orig_last": veloc_t, 615 | "last": points_t, 616 | "all": torch.stack(points_t_all_steps, 0) if self.hp.test_save_everything else [ ], 617 | "lens": points_len 618 | } 619 | 620 | def reconstruction(self, batch, sampling_algo, langevin_strength, generation=False, partial_t=None): 621 | assert sampling_algo != SketchDiffusion.SamplingAlgo.fddim, "FDDIM is not to be used by public API" 622 | 623 | self.t_ = 0 624 | cond_batch, (points, lens) = batch 625 | 626 | cond_latent, _ = self.encode(cond_batch) 627 | points, timestamps = points[..., :self.elem_dim], points[..., self.elem_dim:] 628 | 629 | if sampling_algo != SketchDiffusion.SamplingAlgo.ddim: 630 | diffused = self.forward_diffusion(points, lens, 631 | draw=self.hp.text_viz_process == SketchDiffusion.VizProcess.forward \ 632 | or self.hp.text_viz_process == SketchDiffusion.VizProcess.both, 633 | end_t=partial_t) 634 | 635 | if partial_t is None: 636 | perm = torch.randperm(lens.size(0)) 637 | lens = lens[perm] # reset lengths 638 | diffused = torch.randn_like(diffused) 639 | else: 640 | # execute forward DDIM (feature extraction) 641 | diffused = self.reverse_diffusion(points, self._create_seq_embeddings(timestamps), lens, cond_latent, 642 | SketchDiffusion.SamplingAlgo.fddim, langevin_strength, 643 | draw=self.hp.text_viz_process == SketchDiffusion.VizProcess.forward \ 644 | or self.hp.text_viz_process == SketchDiffusion.VizProcess.both) 645 | diffused = diffused["orig_last"] 646 | if generation: 647 | 
diffused = torch.randn_like(diffused) 648 | 649 | rev_dict = self.reverse_diffusion(diffused, self._create_seq_embeddings(timestamps), lens, cond_latent, 650 | sampling_algo, langevin_strength, 651 | draw=self.hp.text_viz_process == SketchDiffusion.VizProcess.backward \ 652 | or self.hp.text_viz_process == SketchDiffusion.VizProcess.both, start_t=partial_t) 653 | return rev_dict 654 | 655 | def interpolation(self, batch, sampling_algo, langevin_strength=0.): 656 | assert sampling_algo != SketchDiffusion.SamplingAlgo.fddim, "FDDIM is not to be used by public API" 657 | 658 | cond_batch1, (points1, lens1) = batch # samples not really needed, only lens 659 | 660 | # random shuffle before executing generation 661 | perm = torch.randperm(points1.shape[0], device=points1.device) 662 | points2, lens2 = points1[perm, ...], lens1[perm] 663 | 664 | cond_latent1, _ = self.encode(cond_batch1) 665 | cond_latent2 = cond_latent1[perm, ...] if self.cond else None 666 | 667 | points1, timestamps1 = points1[..., :self.elem_dim], points1[..., self.elem_dim:] 668 | points2, timestamps2 = points2[..., :self.elem_dim], points2[..., self.elem_dim:] 669 | 670 | prior1 = torch.randn_like(points1) 671 | prior2 = torch.randn_like(points2) 672 | 673 | for a_, alpha in enumerate(np.linspace(0., 1., self.ax.shape[1])): 674 | if not self.cond: 675 | prior = prior1 * (1. - alpha) + prior2 * alpha 676 | lens = lens1 677 | cond_latent = None 678 | else: 679 | prior = prior1 680 | lens = lens1 681 | cond_latent = cond_latent1 * (1. 
- alpha) + cond_latent2 * alpha 682 | 683 | if self.hp.modeltype == SketchDiffusion.ModelType.transformer: 684 | raise NotImplementedError('interpolation with transformer model not yet implemented') 685 | 686 | recon_dict = self.reverse_diffusion(prior, None, lens, cond_latent, 687 | sampling_algo, langevin_strength=0., draw=False) 688 | self.draw_on_seq(recon_dict["last"], recon_dict["lens"], a_) 689 | 690 | 691 | if __name__ == '__main__': 692 | cli = LightningCLI(SketchDiffusion, GenericDM, run=True, 693 | subclass_mode_data=True, 694 | parser_kwargs={"parser_mode": "omegaconf"}, 695 | trainer_defaults={ 696 | 'callbacks': [ 697 | LearningRateMonitor(logging_interval='step'), 698 | ModelCheckpoint(monitor='valid/loss', filename='model', save_last=True), 699 | TQDMProgressBar(refresh_rate=1 if sys.stdin.isatty() else 0) 700 | ] 701 | }) 702 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hyunwoong 3 | @when : 2019-10-22 4 | @homepage : https://github.com/gusdnd852 5 | """ -------------------------------------------------------------------------------- /models/score.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 6 | 7 | from utils import make_pad_mask_for_transformer 8 | 9 | 10 | class ConditionalTransformerEncoder(nn.Module): 11 | 12 | def __init__(self, n_input, n_internal, n_layers, n_head, causal=False, dropout=0.) 
-> None: 13 | super().__init__() 14 | 15 | self.n_input = n_input 16 | self.n_internal = n_internal 17 | self.n_layers = n_layers 18 | self.n_head = n_head 19 | self.causal = causal 20 | self.dropout = dropout 21 | 22 | self.embedder = nn.Linear(self.n_input, self.n_internal) 23 | 24 | self.transformer = nn.TransformerEncoder( 25 | nn.TransformerEncoderLayer( 26 | self.n_internal, 27 | self.n_head, 28 | dim_feedforward=self.n_internal * 2, 29 | batch_first=True, dropout=self.dropout, activation=F.silu 30 | ), 31 | num_layers=self.n_layers 32 | ) 33 | 34 | def forward(self, noisy, lens): 35 | _, max_len, _ = noisy.shape 36 | len_padd_mask = make_pad_mask_for_transformer(lens, max_len, noisy.device) 37 | 38 | if self.causal: 39 | I = torch.eye(max_len, dtype=noisy.dtype, device=noisy.device) 40 | attn_mask = (torch.cumsum(I, -1) - I) == 1. 41 | else: 42 | attn_mask = None 43 | 44 | input_emb = self.embedder(noisy) 45 | output = self.transformer(input_emb, mask=attn_mask, src_key_padding_mask=len_padd_mask) 46 | 47 | return output 48 | 49 | 50 | class ConditionalBiRNN(nn.Module): 51 | 52 | def __init__(self, n_input, n_hidden, n_layers, dropout=0., causal=False) -> None: 53 | super().__init__() 54 | 55 | self.n_input = n_input 56 | self.n_hidden = n_hidden 57 | self.n_layers = n_layers 58 | self.dropout = dropout 59 | self.causal = causal 60 | 61 | self.rnn = nn.GRU(self.n_input, self.n_hidden, self.n_layers, 62 | batch_first=True, 63 | dropout=self.dropout, 64 | bidirectional=not self.causal) 65 | 66 | directionality = 2 if not self.causal else 1 67 | self.out_proj = nn.Linear( 68 | self.n_hidden * directionality, 69 | self.n_hidden 70 | ) 71 | 72 | def forward(self, noisy, lens): 73 | noisy_packed = pack_padded_sequence(noisy, lens.cpu(), batch_first=True, enforce_sorted=False) 74 | hid, _ = self.rnn(noisy_packed) 75 | out_unpacked, _ = pad_packed_sequence(hid, batch_first=True) 76 | 77 | return self.out_proj(out_unpacked) 78 | 79 | 80 | class ScoreFunc(nn.Module): 
81 | 82 | def __init__(self, modeltype, *, inp_n_features=5, out_n_features=3, time_pe_features=2, seq_pe_features=2, 83 | n_cond_features=0, n_head=4, n_layer=4, n_internal=64, causal=False, dropout=0.) -> None: 84 | super().__init__() 85 | 86 | self.modeltype = modeltype 87 | self.inp_n_features = inp_n_features 88 | self.out_n_features = out_n_features 89 | self.time_pe_features = time_pe_features # for diffusion steps 90 | self.seq_pe_features = seq_pe_features # for sequence time-stamps 91 | self.n_cond_features = n_cond_features # for conditioning 92 | self.n_internal = n_internal 93 | self.n_head = n_head 94 | self.n_layer = n_layer 95 | self.causal = causal 96 | self.dropout = dropout 97 | 98 | self.n_additionals = self.time_pe_features + self.seq_pe_features + self.n_cond_features 99 | self.n_total_features = self.inp_n_features + self.n_additionals 100 | 101 | if self.modeltype == 'birnn': 102 | self.model = ConditionalBiRNN(self.n_total_features, self.n_internal, self.n_layer, 103 | dropout=self.dropout, causal=self.causal) 104 | elif self.modeltype == 'transformer': 105 | self.model = ConditionalTransformerEncoder(self.n_total_features, 106 | self.n_internal, self.n_layer, self.n_head, causal=self.causal, dropout=self.dropout) 107 | else: 108 | raise NotImplementedError(f"Unknown model type {self.modeltype.value}") 109 | 110 | self.final_proj = nn.Sequential( 111 | nn.Linear(self.n_internal * (2 if self.modeltype == 'transformer' else 1) \ 112 | + self.n_additionals - self.seq_pe_features, self.out_n_features), 113 | ) 114 | 115 | def forward(self, noisy, seq_pe, lens, time_pe, cond=None): 116 | noisy_pos, noisy_vel = noisy 117 | noisy = torch.cat([noisy_pos, noisy_vel], -1) 118 | 119 | if isinstance(cond, tuple): 120 | # This is 'threeseqabs_threeseqabseqsampledcond' repr. 
121 | # But not a good way to check (TODO: better API) 122 | cond = torch.cat(cond, -1) 123 | 124 | batch_size, max_len, _ = noisy.shape 125 | 126 | time_pe = time_pe.unsqueeze(1).repeat(1, max_len, 1) 127 | 128 | if cond is not None: 129 | assert self.n_cond_features != 0, "conditioning is being done but no dimension allocated" 130 | if len(cond.shape) == 2: 131 | cond = cond.unsqueeze(1).repeat(1, max_len, 1) 132 | time_cond = torch.cat([time_pe, cond], -1) 133 | else: 134 | time_cond = time_pe 135 | 136 | if self.seq_pe_features > 0: 137 | additionals = torch.cat([seq_pe, time_cond], -1) 138 | else: 139 | additionals = time_cond 140 | 141 | output = self.model( 142 | torch.cat([noisy, additionals], -1), 143 | lens 144 | ) 145 | 146 | if self.modeltype == 'birnn': 147 | return self.final_proj(torch.cat([output, time_cond], -1)) 148 | else: 149 | conseq_cat_output = torch.cat([output[:, :-1, :], output[:, 1:, ]], -1) 150 | return self.final_proj(torch.cat([conseq_cat_output, time_cond[:, 1:, :]], -1)) 151 | 152 | 153 | class TransformerSetFeature(ConditionalTransformerEncoder): 154 | 155 | def __init__(self, n_internal, n_layers, n_head, n_latent, dropout=0., vae_weight=0.) 
-> None: 156 | # '+1' is for the extra feature for denoting feature extractor token 157 | super().__init__(2 + 1, n_internal, n_layers, n_head, causal=False, dropout=dropout) 158 | self.n_latent = n_latent 159 | self.vae_weight = vae_weight 160 | 161 | if self.vae_weight == 0.: 162 | self.latent_proj = nn.Sequential( 163 | nn.Linear(n_internal, self.n_latent), 164 | nn.Tanh() 165 | ) 166 | else: 167 | self.latent_proj_mean = nn.Sequential(nn.Linear(n_internal, self.n_latent)) 168 | self.latent_proj_logvar = nn.Sequential(nn.Linear(n_internal, self.n_latent)) 169 | 170 | def forward(self, cond_batch): 171 | set_input, lens = cond_batch 172 | B, L, _ = set_input.shape 173 | # creating an extra feature extractor token 174 | pad_token = torch.zeros(B, L, 1, device=set_input.device, dtype=set_input.dtype) 175 | feat_token = torch.tensor([0., 0., 1.], device=set_input.device, dtype=set_input.dtype) 176 | feat_token = feat_token[None, None, :].repeat(B, 1, 1) 177 | set_input = torch.cat([set_input, pad_token], -1) 178 | set_input = torch.cat([feat_token, set_input], 1) 179 | lens = lens + 1 # extra token for feature extraction 180 | 181 | trans_out = super().forward(set_input, lens) 182 | 183 | if self.vae_weight == 0.: 184 | return self.latent_proj(trans_out[:, 0]), 0. 185 | else: 186 | mu = self.latent_proj_mean(trans_out[:, 0]) 187 | logvar = self.latent_proj_logvar(trans_out[:, 0]) 188 | posterior = torch.distributions.Normal(mu, torch.exp(0.5 * logvar)) 189 | prior = torch.distributions.Normal( 190 | torch.zeros_like(mu), 191 | torch.ones_like(logvar) 192 | ) 193 | return posterior.rsample(), torch.distributions.kl_divergence(posterior, prior) 194 | 195 | 196 | class BiRNNEncoderFeature(ConditionalBiRNN): 197 | 198 | def __init__(self, n_hidden, n_layers, n_latent, dropout=0., vae_weight=0.) 
-> None: 199 | super().__init__(3, n_hidden, n_layers, dropout) 200 | self.out_proj = nn.Identity() 201 | self.vae_weight = vae_weight 202 | 203 | self.n_latent = n_latent 204 | 205 | if self.vae_weight == 0.: 206 | self.latent_proj = nn.Sequential( 207 | nn.Linear(2 * self.n_hidden, self.n_latent), 208 | nn.Tanh() 209 | ) 210 | else: 211 | # input is the concatenated forward and backward features, hence 2 * n_hidden 212 | self.latent_proj_mean = nn.Sequential(nn.Linear(2 * self.n_hidden, self.n_latent)) 212 | self.latent_proj_logvar = nn.Sequential(nn.Linear(2 * self.n_hidden, self.n_latent)) 213 | 214 | def forward(self, cond_batch): 215 | batch, lens = cond_batch 216 | batch_size, max_len, _ = batch.shape 217 | batch = batch[..., :-1] # exclude the timestamps 218 | out = super().forward(batch, lens).view(batch_size, max_len, 2, self.n_hidden) 219 | out_fwd, out_bwd = out[:, :, 0, :], out[:, :, 1, :] 220 | fwd_feat = torch.gather( 221 | out_fwd, 222 | 1, 223 | lens[:, None, None].repeat(1, 1, self.n_hidden) - 1 224 | ).squeeze(1) # drop only the length axis so a batch of size 1 keeps its batch axis 225 | bwd_feat = out_bwd[:, 0, :] 226 | 227 | if self.vae_weight == 0.: 228 | return self.latent_proj(torch.cat([fwd_feat, bwd_feat], -1)), 0. 229 | else: 230 | mu = self.latent_proj_mean(torch.cat([fwd_feat, bwd_feat], -1)) 231 | logvar = self.latent_proj_logvar(torch.cat([fwd_feat, bwd_feat], -1)) 232 | posterior = torch.distributions.Normal(mu, torch.exp(0.5 * logvar)) 233 | prior = torch.distributions.Normal( 234 | torch.zeros_like(mu), 235 | torch.ones_like(logvar) 236 | ) 237 | return posterior.rsample(), torch.distributions.kl_divergence(posterior, prior) 238 | 239 | 240 | class Lambda(nn.Module): 241 | 242 | def __init__(self, fn: typing.Callable) -> None: 243 | super().__init__() 244 | self.fn = fn 245 | 246 | def forward(self, x): 247 | # the extra zero is to make it compatible with other encoder 248 | return self.fn(x), 0. 
249 | 250 | 251 | class ClassEmbedding(nn.Module): 252 | 253 | def __init__(self, num_classes, emb_dim) -> None: 254 | super().__init__() 255 | 256 | self.num_classes = num_classes 257 | self.emb_dim = emb_dim 258 | self.emb = nn.Embedding(self.num_classes, self.emb_dim) 259 | 260 | def forward(self, x): 261 | return self.emb(x), 0. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ndjson 2 | scipy 3 | matplotlib 4 | tqdm 5 | Pillow 6 | pytorch-lightning==1.5.9 7 | simplification 8 | noise 9 | torch-ema 10 | omegaconf 11 | jsonargparse[signatures] 12 | shapely 13 | wandb -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing 3 | import math 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from pytorch_lightning.loggers import WandbLogger 8 | from wandb.util import generate_id 9 | 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def positionalencoding1d(d_model, length, N=10000, dtype=None, device=None): 14 | """ 15 | :param d_model: dimension of the model 16 | :param length: length of positions 17 | :return: length*d_model position matrix 18 | """ 19 | if d_model % 2 != 0: 20 | raise ValueError("Cannot use sin/cos positional encoding with " 21 | "odd dim (got dim={:d})".format(d_model)) 22 | pe = torch.zeros(length, d_model, dtype=dtype, device=device) 23 | position = torch.arange(0, length, dtype=dtype, device=device).unsqueeze(1) 24 | div_term = torch.exp((torch.arange(0, d_model, 2, dtype=dtype, device=device) * 25 | -(math.log(N) / d_model))) 26 | pe[:, 0::2] = torch.sin(position.float() * div_term) 27 | pe[:, 1::2] = torch.cos(position.float() * div_term) 28 | 29 | return pe 30 | 31 | 32 | def random_fourier_encoding_dyn(ts, W, scale=4.): 33 | 
proj = (W * scale) @ ts 34 | emb = torch.cat([torch.sin(2 * torch.pi * proj), torch.cos(2 * torch.pi * proj)], 0) 35 | return emb.T 36 | 37 | 38 | def make_pad_mask_for_transformer(lens, total_length=None, device=None): 39 | total_length = total_length or max(lens) 40 | pad = torch.zeros(len(lens), total_length + 1, device=device) 41 | for b, l in enumerate(lens): 42 | pad[b, l] = 1. 43 | pad = torch.cumsum(pad, 1) 44 | return (pad[:, :-1] == 1.) 45 | 46 | 47 | def nonunif_timestep_selector(T, infer_T, gamma=2.): 48 | ui = np.linspace(1., 0., infer_T) # uniform index 49 | return np.unique(np.clip( 50 | # sample using gamma curves (y = x^gamma) 51 | np.floor((ui ** gamma) * T), 1., T 52 | ))[::-1].astype(np.int64) 53 | 54 | 55 | def openai_cosine_schedule(T, *args, s=0.008): 56 | # # explicitly defined $\bar{\alpha}_t$ via a cosine function; 57 | # # beta and alpha derived thereafter; suggested by "Improved Denoising .. 58 | # # .. Diffusion Probabilistic Models" by OpenAI 59 | 60 | def f(t): return math.cos((t/T + s) / (1 + s) * math.pi / 2) ** 2 61 | alpha_bar = np.array([f(t) / f(0) for t in range(T + 1)], dtype=np.float32) 62 | sqrt_alpha_bar = np.sqrt(alpha_bar) 63 | sqrt_one_min_alpha_bar = np.sqrt(1. - alpha_bar) 64 | betas = np.clip(1. - alpha_bar[1:] / alpha_bar[:-1], 0., 0.999) 65 | alphas = 1. - betas 66 | beta_tilde = (1. - alpha_bar[:-1]) / (1. - alpha_bar[1:]) * betas 67 | 68 | return betas, alphas, alpha_bar[1:], \ 69 | sqrt_alpha_bar[1:], sqrt_one_min_alpha_bar[1:], beta_tilde 70 | 71 | 72 | def linear_schedule(T, low_noise, high_noise): 73 | # standard linear schedule defined in terms of $\beta_t$ 74 | betas = np.linspace(low_noise, high_noise, T, dtype=np.float32) 75 | alphas = 1. - betas 76 | alpha_bar = np.cumprod(alphas, 0) 77 | sqrt_alpha_bar = np.sqrt(alpha_bar) 78 | sqrt_one_min_alpha_bar = np.sqrt(1. 
- alpha_bar) 79 | beta_tilde_wo_first_term = ((sqrt_one_min_alpha_bar[:-1] / sqrt_one_min_alpha_bar[1:])**2 * betas[1:]) 80 | beta_tilde = np.array([ 81 | beta_tilde_wo_first_term[0], 82 | *beta_tilde_wo_first_term 83 | ]) 84 | 85 | return betas, alphas, alpha_bar, \ 86 | sqrt_alpha_bar, sqrt_one_min_alpha_bar, beta_tilde 87 | 88 | 89 | def cg_subtracted_noise(noise, lens): 90 | mask = torch.cumprod(1. - F.one_hot(lens, num_classes=noise.size(1) + 1)[:, :-1, None].float(), 1) 91 | # make sure the padding doesn't interfere in CoM calculation 92 | com = (mask * noise).sum(1, keepdim=True) / lens[:, None, None] 93 | return noise - com 94 | 95 | 96 | class CustomWandbLogger(WandbLogger): 97 | 98 | def __init__(self, 99 | name: typing.Optional[str], 100 | save_dir: typing.Optional[str] = 'logs', 101 | group: typing.Optional[str] = 'common', 102 | project: typing.Optional[str] = 'diffset', 103 | log_model: typing.Optional[bool] = True, 104 | offline: bool = False, 105 | entity: typing.Optional[str] = 'dasayan05'): 106 | rid = generate_id() 107 | name_rid = '-'.join([name, rid]) 108 | super().__init__(name=name_rid, id=rid, offline=offline, 109 | save_dir=os.path.join(save_dir, name_rid), project=project, 110 | log_model=log_model, group=group, entity=entity) 111 | 112 | 113 | class CustomViz(object): 114 | 115 | def __init__(self, test_n_sample_viz: int, n_viz: int, compact_mode: bool = True, subfig_slack: float = 0.) 
-> None: 116 | super().__init__() 117 | 118 | self.test_n_sample_viz = test_n_sample_viz 119 | self.n_viz = n_viz 120 | self.compact_mode = compact_mode 121 | 122 | if self.compact_mode: 123 | self.fig, self.ax = plt.subplots( 124 | self.test_n_sample_viz, 125 | self.n_viz, 126 | figsize=(self.n_viz, self.test_n_sample_viz), 127 | gridspec_kw = {'wspace': subfig_slack, 'hspace': subfig_slack}) 128 | else: 129 | self.figs = [ 130 | [ 131 | plt.subplots(1, 1, figsize=(1, 1)) \ 132 | for j in range(self.n_viz) 133 | ] for i in range(self.test_n_sample_viz) 134 | ] 135 | 136 | def __getitem__(self, pos: tuple): 137 | i, j = pos 138 | if self.compact_mode: 139 | return self.ax[i, j] 140 | else: 141 | _, ax = self.figs[i][j] 142 | return ax 143 | 144 | @property 145 | def shape(self): 146 | return self.test_n_sample_viz, self.n_viz 147 | 148 | def savefig(self, path: str, **kwargs): 149 | if self.compact_mode: 150 | self.fig.savefig(path, **kwargs) 151 | else: 152 | *rest, ext = path.split('.') 153 | rest = '.'.join(rest) 154 | os.makedirs(rest, exist_ok=False) 155 | for i in range(self.test_n_sample_viz): 156 | for j in range(self.n_viz): 157 | path = os.path.join(rest, f'{i}_{j}.' + ext) 158 | fig, _ = self.figs[i][j] 159 | fig.savefig(path, **kwargs) --------------------------------------------------------------------------------
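The two schedule helpers in `utils.py` both return the quantities used by the samplers in `main.py`: `betas`, `alpha_bar`, and the posterior variance `beta_tilde`. As a rough illustration of how these relate, here is a dependency-free sketch. It is not the repository's code: the function names `linear_betas`, `posterior_variance` and `q_sample` are hypothetical, the schedule values in the usage lines are arbitrary, and the first posterior-variance entry is computed with `alpha_bar_0 = 1` (giving exactly zero), whereas the repository's `linear_schedule` copies the second entry into the first slot.

```python
import math


def linear_betas(T, low_noise, high_noise):
    """Linearly spaced betas and their cumulative product alpha_bar (assumes T >= 2)."""
    betas = [low_noise + (high_noise - low_noise) * i / (T - 1) for i in range(T)]
    alpha_bar, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta  # alpha_bar_t = prod_{s <= t} (1 - beta_s)
        alpha_bar.append(running)
    return betas, alpha_bar


def posterior_variance(betas, alpha_bar):
    """beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t, with alpha_bar_0 = 1."""
    out = []
    for t, beta in enumerate(betas):
        prev = alpha_bar[t - 1] if t > 0 else 1.0
        out.append((1.0 - prev) / (1.0 - alpha_bar[t]) * beta)
    return out


def q_sample(x0, noise, alpha_bar_t):
    """Closed-form forward noising: sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a, b = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [a * x + b * e for x, e in zip(x0, noise)]


# Illustrative values only; the repository reads its schedule from the config.
betas, alpha_bar = linear_betas(10, 1e-4, 2e-2)
beta_tilde = posterior_variance(betas, alpha_bar)
noisy = q_sample([1.0, -2.0], [0.3, -0.1], alpha_bar[-1])
```

The `q_sample` expression mirrors the `noisy_points` computation in `create_batch_with_utilities`, and `beta_tilde` plays the role of the variance scaled by `noise_weight` in `reverse_purturb_DDPM`.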