├── .gitignore ├── LICENSE ├── README.md ├── configs ├── CHINESE_CASIA.yml ├── CHINESE_USER.yml ├── English_CASIA.yml └── Japanese_TUATHANDS.yml ├── data ├── data_loader └── loader.py ├── environment.yml ├── evaluate.py ├── models ├── encoder.py ├── eval_model.py ├── gmm.py ├── loss.py ├── model.py └── transformer.py ├── parse_config.py ├── static ├── Poster_SDT.pdf ├── duo.gif ├── duo_loop.gif ├── mo.gif ├── mo_loop.gif ├── offline_Chinese.jpg ├── online_Chinese.jpg ├── overview_sdt.jpg ├── print.png ├── software.png ├── svg.png ├── tai.gif ├── tai_loop.gif └── various_scripts.jpg ├── test.py ├── train.py ├── trainer └── trainer.py ├── user_generate.py └── utils ├── logger.py ├── metrics.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | data/* 3 | Saved/* 4 | model_zoo/*.pth 5 | auto_* 6 | .vscode 7 | Generated 8 | style_samples -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Gang Dai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![MIT LICENSE](https://shields.io/badge/license-MIT-green) 2 | ![python 3.8](https://img.shields.io/badge/python-3.8-brightgreen) 3 | # 🔥 Disentangling Writer and Character Styles for Handwriting Generation 4 | 5 |

ArXiv | Poster | Video | Project

16 | 17 | ## 📢 Introduction 18 | - The proposed style-disentangled Transformer (SDT) generates online handwritings with conditional content and style. 19 | - Existing RNN-based methods mainly focus on capturing a person’s overall writing style, neglecting subtle style inconsistencies between characters written by the same person. In light of this, SDT disentangles the writer-wise and character-wise style representations from individual handwriting samples for enhancing imitation performance. 20 | - We extend SDT and introduce an offline-to-offline framework for improving the generation quality of offline Chinese handwritings. 21 | 22 |
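At its core, the style encoder runs every reference image of a writer through one shared feature extractor and then splits the resulting tokens with two Transformer heads: one averaged over all of a writer's samples (the writer-wise style) and one kept per reference sample (the character-wise style). The snippet below is a deliberately minimal sketch of that split and nothing more; the actual encoder in `models/model.py` adds a ResNet-18 trunk, positional encodings, contrastive projection heads, and two decoders conditioned on the two styles.
```
import torch
import torch.nn as nn

class DualHeadStyleEncoder(nn.Module):
    """Toy illustration of the writer/character style split (not the full SDT)."""
    def __init__(self, d_model=512, nhead=8):
        super().__init__()
        # stand-in for the ResNet-18 trunk used by the real model
        self.backbone = nn.Sequential(
            nn.Conv2d(1, d_model, kernel_size=7, stride=4, padding=3),
            nn.AdaptiveAvgPool2d((2, 2)),
        )
        self.writer_head = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead), num_layers=1)
        self.glyph_head = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead), num_layers=1)

    def forward(self, style_imgs):                              # [B, N, 1, H, W]
        B, N, C, H, W = style_imgs.shape
        feat = self.backbone(style_imgs.view(B * N, C, H, W))   # [B*N, d, 2, 2]
        tokens = feat.flatten(2).permute(2, 0, 1)               # [4, B*N, d]
        writer_tok = self.writer_head(tokens)                   # same input, two heads
        glyph_tok = self.glyph_head(tokens)
        # writer-wise style: pooled over tokens and over the N reference samples
        writer_style = writer_tok.mean(0).view(B, N, -1).mean(1)              # [B, d]
        # character-wise style: one token sequence kept per reference sample
        glyph_style = glyph_tok.permute(1, 0, 2).reshape(B, N, -1, glyph_tok.size(-1))
        return writer_style, glyph_style

# e.g. writer_style, glyph_style = DualHeadStyleEncoder()(torch.rand(2, 15, 1, 64, 64))
```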
Overview of our SDT (figure: static/overview_sdt.jpg)

Three samples of online characters with writing orders (animated figures under static/)
35 | 36 | ## 📅 News 37 | - [2024/11/26] 🎉🎉🎉 Release of the implementations of Content Score and Style Score. 38 | - [2024/07/01] 🎉🎉🎉 A new state-of-the-art method for handwritten text generation, named [One-DM](https://github.com/dailenson/One-DM), is accepted by ECCV 2024. 39 | - [2024/01/07] Add a tutorial and code for synthesizing handwriting with user-customized styles, more information can be found [here](https://github.com/dailenson/SDT/issues/43). 40 | - [2023/12/15] 🎉🎉🎉 This work is reported by a top [bilibili](https://www.bilibili.com/video/BV19w411t7vD/?buvid=XX73A437799B0DCC93D6D21690FA9CAE696EC&from_spmid=default-value&is_story_h5=false&mid=Xr0IfLrZqLFnTCriRB2HcQ%3D%3D&p=1&plat_id=116&share_from=ugc&share_medium=android&share_plat=android&share_session_id=2f9e186f-d693-4b61-80c6-372942bec32b&share_source=WEIXIN&share_source=weixin&share_tag=s_i&spmid=united.player-video-detail.0.0×tamp=1720580374&unique_k=OqWsKIV&up_id=19319172) video blogger with 2.7 million followers and received nearly one million views. 41 | - [2023/10/10] The [author](https://scholar.google.com.hk/citations?user=a2SwkisAAAAJ&hl=zh-CN) is invited to give a [talk](https://www.bilibili.com/video/BV1kQ4y1W7a7/?spm_id_from=333.999.0.0&vd_source=cbc77ced94dbf77f5ecef4e0afa94a33) (in Chinese) by CSIG (China Society of Image and Graphics). 42 | - [2023/06/14] This work is reported by [Synced](https://mp.weixin.qq.com/s/EX_Loj4PvIztQH5zrl2FNw) (机器之心). 43 | - [2023/04/12] Initial release of the datasets, pre-trained models, training and testing codes. 44 | - [2023/02/28] 🎉🎉🎉 Our SDT is accepted by CVPR 2023. 45 | 46 | ## 📺 Handwriting generation results 47 | - **Online Chinese handwriting generation** 48 | ![online Chinese](static/online_Chinese.jpg) 49 | 50 | - **Applications to various scripts** 51 | ![other scripts](static/various_scripts.jpg) 52 | - **Extension on offline Chinese handwriting generation** 53 | ![offline Chinese](static/offline_Chinese.jpg) 54 | 55 | 56 | ## 🔨 Requirements 57 | ``` 58 | conda create -n sdt python=3.8 -y 59 | conda activate sdt 60 | # install all dependencies 61 | conda env create -f environment.yml 62 | ``` 63 | 64 | ## 📂 Folder Structure 65 | ``` 66 | SDT/ 67 | │ 68 | ├── train.py - main script to start training 69 | ├── test.py - generate characters via trained model 70 | ├── evaluate.py - evaluation of generated samples 71 | │ 72 | ├── configs/*.yml - holds configuration for training 73 | ├── parse_config.py - class to handle config file 74 | │ 75 | ├── data_loader/ - anything about data loading goes here 76 | │ └── loader.py 77 | │ 78 | ├── model_zoo/ - pre-trained content encoder model 79 | │ 80 | ├── data/ - default directory for storing experimental datasets 81 | │ 82 | ├── model/ - networks, models and losses 83 | │ ├── encoder.py 84 | │ ├── gmm.py 85 | │ ├── loss.py 86 | │ ├── model.py 87 | │ └── transformer.py 88 | │ 89 | ├── saved/ 90 | │ ├── models/ - trained models are saved here 91 | │ ├── tborad/ - tensorboard visualization 92 | │ └── samples/ - visualization samples in the training process 93 | │ 94 | ├── trainer/ - trainers 95 | │ └── trainer.py 96 | │ 97 | └── utils/ - small utility functions 98 | ├── util.py 99 | └── logger.py - set log dir for tensorboard and logging output 100 | ``` 101 | 102 | ## 💿 Datasets 103 | 104 | We provide Chinese, Japanese and English datasets in [Google Drive](https://drive.google.com/drive/folders/17Ju2chVwlNvoX7HCKrhJOqySK-Y-hU8K?usp=share_link) | [Baidu Netdisk](https://pan.baidu.com/s/1RNQSRhBAEFPe2kFXsHZfLA) PW:xu9u. 
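Before training or testing, it helps to know what the loader expects on disk. Judging from `data_loader/loader.py`, each unpacked language package should sit under `data/` with roughly the following layout (Chinese shown; Japanese and English use `TUATHANDS_JAPANESE` and `CASIA_ENGLISH` with their own `*_content.pkl`), though the archives may contain additional files:
```
data/
└── CASIA_CHINESE/
    ├── Chinese_content.pkl       # content reference images, one per character
    ├── character_dict.pkl        # the supported character set
    ├── writer_dict.pkl           # train/test writer-id mappings
    ├── train/                    # LMDB of online training characters
    ├── test/                     # LMDB of online test characters
    ├── train_style_samples/      # one .pkl of style images per writer
    └── test_style_samples/
```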
Please download these datasets, uzip them and move the extracted files to /data. 105 | 106 | ## 🍔 Pre-trained model 107 | 108 | | Model|Google Drive|Baidu Netdisk| 109 | |---------------|---------|-----------------------------------------| 110 | |Well-trained SDT|[Google Drive](https://drive.google.com/drive/folders/1LendizOwcNXlyY946ThS8HQ4wJX--YL7?usp=sharing) | [Baidu Netdisk](https://pan.baidu.com/s/1RNQSRhBAEFPe2kFXsHZfLA?pwd=xu9u) 111 | |Content encoder|[Google Drive](https://drive.google.com/drive/folders/1N-MGRnXEZmxAW-98Hz2f-o80oHrNaN_a?usp=share_link) | [Baidu Netdisk](https://pan.baidu.com/s/1RNQSRhBAEFPe2kFXsHZfLA?pwd=xu9u) 112 | |Content Score|[Google Drive](https://drive.google.com/drive/folders/1-2ciY6yfI4l1bVUD661EzEW5PInZb_62?usp=sharing)|[Baidu Netdisk]( https://pan.baidu.com/s/1cs8qWOhwISZz7w1dAYMQ3g?pwd=s8e8) 113 | |Style Score|[Google Drive](https://drive.google.com/drive/folders/1-2ciY6yfI4l1bVUD661EzEW5PInZb_62?usp=sharing) | [Baidu Netdisk]( https://pan.baidu.com/s/1cs8qWOhwISZz7w1dAYMQ3g?pwd=s8e8) 114 | 115 | **Note**: 116 | Please download these weights, and move them to /model_zoo. 117 | 118 | ## 🚀 Training & Test 119 | **Training** 120 | - To train the SDT on the Chinese dataset, run this command: 121 | ``` 122 | python train.py --cfg configs/CHINESE_CASIA.yml --log Chinese_log 123 | ``` 124 | 125 | - To train the SDT on the Japanese dataset, run this command: 126 | ``` 127 | python train.py --cfg configs/Japanese_TUATHANDS.yml --log Japanese_log 128 | ``` 129 | 130 | - To train the SDT on the English dataset, run this command: 131 | ``` 132 | python train.py --cfg configs/English_CASIA.yml --log English_log 133 | ``` 134 | 135 | **Qualitative Test** 136 | - To generate **online Chinese handwritings** with our SDT, run this command: 137 | ``` 138 | python test.py --pretrained_model checkpoint_path --store_type online --sample_size 500 --dir Generated/Chinese 139 | ``` 140 | - To generate **offline Chinese handwriting images** with our SDT, run this command: 141 | ``` 142 | python test.py --pretrained_model checkpoint_path --store_type offline --sample_size 500 --dir Generated_img/Chinese 143 | ``` 144 | 145 | - To generate **online Japanese handwritings** with our SDT, run this command: 146 | ``` 147 | python test.py --pretrained_model checkpoint_path --store_type online --sample_size 500 --dir Generated/Japanese 148 | ``` 149 | - To generate **offline Japanese handwriting images** with our SDT, run this command: 150 | ``` 151 | python test.py --pretrained_model checkpoint_path --store_type offline --sample_size 500 --dir Generated_img/Japanese 152 | ``` 153 | - To generate **online English handwritings** with our SDT, run this command: 154 | ``` 155 | python test.py --pretrained_model checkpoint_path --store_type online --sample_size 500 --dir Generated/English 156 | ``` 157 | - To generate **offline English handwriting images** with our SDT, run this command: 158 | ``` 159 | python test.py --pretrained_model checkpoint_path --store_type offline --sample_size 500 --dir Generated_img/English 160 | ``` 161 | 162 | **Quantitative Evaluation** 163 | - To evaluate the generated handwritings, you need to set `data_path` to the path of the generated handwritings (e.g., Generated/Chinese), and run this command: 164 | ``` 165 | python evaluate.py --data_path Generated/Chinese --metric DTW 166 | ``` 167 | - To calculate the Content Score of generated handwritings, you need to set `data_path` to the path of the generated handwritings (e.g., Generated/Chinese), and run 
this command: 168 | ``` 169 | python evaluate.py --data_path Generated/Chinese --metric Content_score --pretrained_model model_zoo/chinese_content_iter30k_acc95.pth 170 | ``` 171 | - To calculate the Style Score of generated handwritings, you need to set `data_path` to the path of the generated handwriting images (e.g., Generated_img/Chinese), and run this command: 172 | ``` 173 | python evaluate.py --data_path Generated_img/Chinese --metric Style_score --pretrained_model models_zoo/chinese_style_iter60k_acc999.pth 174 | ``` 175 | ## 🏰 Practical Application 176 | We are delighted to discover that **[P0etry-rain](https://github.com/P0etry-rain)** has proposed a pipeline that involves initially converting the generated results by our SDT to TTF format, followed by the development of software to enable flexible adjustments in spacing between paragraphs, lines, and characters. Below, we present TTF files, software interface and the printed results. More details can be seen in [#78](https://github.com/dailenson/SDT/issues/78#issue-2247810028). 177 | - **TTF File** 178 | ![SVG](static/svg.png) 179 | 180 | - **Software Interface** 181 | ![Interface](static/software.png) 182 | 183 | - **Printed Results** 184 | ![Result](static/print.png) 185 | 186 | 187 | 188 | ## ❤️ Citation 189 | If you find our work inspiring or use our codebase in your research, please cite our work: 190 | ``` 191 | @inproceedings{dai2023disentangling, 192 | title={Disentangling Writer and Character Styles for Handwriting Generation}, 193 | author={Dai, Gang and Zhang, Yifan and Wang, Qingfeng and Du, Qing and Yu, Zhuliang and Liu, Zhuoman and Huang, Shuangping}, 194 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 195 | pages={5977--5986}, 196 | year={2023} 197 | } 198 | ``` 199 | 200 | ## ⭐ StarGraph 201 | [![Star History Chart](https://api.star-history.com/svg?repos=dailenson/SDT&type=Timeline)](https://star-history.com/#dailenson/SDT&Timeline) 202 | 203 | 204 | -------------------------------------------------------------------------------- /configs/CHINESE_CASIA.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ENCODER_LAYERS: 2 3 | WRI_DEC_LAYERS: 2 4 | GLY_DEC_LAYERS: 2 5 | NUM_HEAD_LAYERS: 1 6 | NUM_IMGS: 15 7 | NUM_GPUS: 1 # TODO, support multi GPUs 8 | SOLVER: 9 | BASE_LR: 0.0002 10 | MAX_ITER: 200000 11 | WARMUP_ITERS: 20000 12 | TYPE: Adam # TODO, support optional optimizer 13 | GRAD_L2_CLIP: 5.0 14 | TRAIN: 15 | ISTRAIN: True 16 | IMS_PER_BATCH: 64 17 | SNAPSHOT_BEGIN: 2000 18 | SNAPSHOT_ITERS: 4000 19 | VALIDATE_ITERS: 2000 20 | VALIDATE_BEGIN: 2000 21 | SEED: 1001 22 | IMG_H: 64 23 | IMG_W: 64 24 | TEST: 25 | ISTRAIN: False 26 | IMG_H: 64 27 | IMG_W: 64 28 | DATA_LOADER: 29 | NUM_THREADS: 8 30 | CONCAT_GRID: True 31 | TYPE: ScriptDataset 32 | PATH: data 33 | DATASET: CHINESE 34 | -------------------------------------------------------------------------------- /configs/CHINESE_USER.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ENCODER_LAYERS: 2 3 | WRI_DEC_LAYERS: 2 4 | GLY_DEC_LAYERS: 2 5 | NUM_HEAD_LAYERS: 1 6 | NUM_IMGS: 15 7 | NUM_GPUS: 1 # TODO, support multi GPUs 8 | SOLVER: 9 | BASE_LR: 0.0002 10 | MAX_ITER: 200000 11 | WARMUP_ITERS: 20000 12 | TYPE: Adam # TODO, support optional optimizer 13 | GRAD_L2_CLIP: 5.0 14 | TRAIN: 15 | ISTRAIN: True 16 | IMS_PER_BATCH: 64 17 | SNAPSHOT_BEGIN: 2000 18 | SNAPSHOT_ITERS: 4000 19 | VALIDATE_ITERS: 2000 20 | 
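  # VALIDATE_BEGIN / VALIDATE_ITERS: iteration at which validation starts and the interval between validations (an assumption based on the option names)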
VALIDATE_BEGIN: 2000 21 | SEED: 1001 22 | IMG_H: 64 23 | IMG_W: 64 24 | TEST: 25 | ISTRAIN: False 26 | IMG_H: 64 27 | IMG_W: 64 28 | DATA_LOADER: 29 | NUM_THREADS: 8 30 | CONCAT_GRID: True 31 | TYPE: UserDataset 32 | PATH: data 33 | DATASET: CHINESE -------------------------------------------------------------------------------- /configs/English_CASIA.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ENCODER_LAYERS: 2 3 | WRI_DEC_LAYERS: 2 4 | GLY_DEC_LAYERS: 2 5 | NUM_HEAD_LAYERS: 1 6 | NUM_IMGS: 15 7 | NUM_GPUS: 1 # TODO, support multi GPUs 8 | SOLVER: 9 | BASE_LR: 0.0002 10 | MAX_ITER: 200000 11 | WARMUP_ITERS: 20000 12 | TYPE: Adam # TODO, support optional optimizer 13 | GRAD_L2_CLIP: 5.0 14 | TRAIN: 15 | ISTRAIN: True 16 | IMS_PER_BATCH: 64 17 | SNAPSHOT_BEGIN: 2000 18 | SNAPSHOT_ITERS: 4000 19 | VALIDATE_ITERS: 2000 20 | VALIDATE_BEGIN: 2000 21 | SEED: 1001 22 | IMG_H: 64 23 | IMG_W: 64 24 | TEST: 25 | ISTRAIN: False 26 | IMG_H: 64 27 | IMG_W: 64 28 | DATA_LOADER: 29 | NUM_THREADS: 8 30 | CONCAT_GRID: True 31 | TYPE: ScriptDataset 32 | PATH: data 33 | DATASET: ENGLISH 34 | -------------------------------------------------------------------------------- /configs/Japanese_TUATHANDS.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ENCODER_LAYERS: 2 3 | WRI_DEC_LAYERS: 2 4 | GLY_DEC_LAYERS: 2 5 | NUM_HEAD_LAYERS: 1 6 | NUM_IMGS: 15 7 | NUM_GPUS: 1 # TODO, support multi GPUs 8 | SOLVER: 9 | BASE_LR: 0.0002 10 | MAX_ITER: 200000 11 | WARMUP_ITERS: 20000 12 | TYPE: Adam # TODO, support optional optimizer 13 | GRAD_L2_CLIP: 5.0 14 | TRAIN: 15 | ISTRAIN: True 16 | IMS_PER_BATCH: 64 17 | SNAPSHOT_BEGIN: 2000 18 | SNAPSHOT_ITERS: 4000 19 | VALIDATE_ITERS: 2000 20 | VALIDATE_BEGIN: 2000 21 | SEED: 1001 22 | IMG_H: 64 23 | IMG_W: 64 24 | TEST: 25 | ISTRAIN: False 26 | IMG_H: 64 27 | IMG_W: 64 28 | DATA_LOADER: 29 | NUM_THREADS: 8 30 | CONCAT_GRID: True 31 | TYPE: ScriptDataset 32 | PATH: data 33 | DATASET: JAPANESE 34 | -------------------------------------------------------------------------------- /data: -------------------------------------------------------------------------------- 1 | /home/mdisk/daigang/StyleWriting/data -------------------------------------------------------------------------------- /data_loader/loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | from utils.util import normalize_xys 3 | from torch.utils.data import Dataset 4 | import os 5 | import torch 6 | import numpy as np 7 | import pickle 8 | from torchvision import transforms 9 | import lmdb 10 | from utils.util import corrds2xys 11 | import codecs 12 | import glob 13 | import cv2 14 | from PIL import ImageDraw, Image 15 | transform_data = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize(mean = (0.5), std = (0.5)) 18 | ]) 19 | 20 | script={"CHINESE":['CASIA_CHINESE', 'Chinese_content.pkl'], 21 | 'JAPANESE':['TUATHANDS_JAPANESE', 'Japanese_content.pkl'], 22 | "ENGLISH":['CASIA_ENGLISH', 'English_content.pkl'] 23 | } 24 | 25 | class ScriptDataset(Dataset): 26 | def __init__(self, root='data', dataset='CHINESE', is_train=True, num_img = 15): 27 | data_path = os.path.join(root, script[dataset][0]) 28 | self.dataset = dataset 29 | self.content = pickle.load(open(os.path.join(data_path, script[dataset][1]), 'rb')) #content samples 30 | self.char_dict = pickle.load(open(os.path.join(data_path, 'character_dict.pkl'), 'rb')) 31 | 
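        # writer_dict.pkl bundles two mappings, 'train_writer' and 'test_writer', from sample file names to integer writer ids; the split chosen below selects one of them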
self.all_writer = pickle.load(open(os.path.join(data_path, 'writer_dict.pkl'), 'rb')) 32 | self.is_train = is_train 33 | if self.is_train: 34 | lmdb_path = os.path.join(data_path, 'train') # online characters 35 | self.img_path = os.path.join(data_path, 'train_style_samples') # style samples 36 | self.num_img = num_img*2 37 | self.writer_dict = self.all_writer['train_writer'] 38 | else: 39 | lmdb_path = os.path.join(data_path, 'test') # online characters 40 | self.img_path = os.path.join(data_path, 'test_style_samples') # style samples 41 | self.num_img = num_img 42 | self.writer_dict = self.all_writer['test_writer'] 43 | if not os.path.exists(lmdb_path): 44 | raise IOError("input the correct lmdb path") 45 | 46 | self.lmdb = lmdb.open(lmdb_path, max_readers=8, readonly=True, lock=False, readahead=False, meminit=False) 47 | if script[dataset][0] == "CASIA_CHINESE" : 48 | self.max_len = -1 # Do not filter characters with many trajectory points 49 | else: # Japanese, Indic, English 50 | self.max_len = 150 51 | 52 | self.all_path = {} 53 | for pkl in os.listdir(self.img_path): 54 | writer = pkl.split('.')[0] 55 | self.all_path[writer] = os.path.join(self.img_path, pkl) 56 | 57 | with self.lmdb.begin(write=False) as txn: 58 | self.num_sample = int(txn.get('num_sample'.encode('utf-8')).decode()) 59 | if self.max_len <= 0: 60 | self.indexes = list(range(0, self.num_sample)) 61 | else: 62 | print('Filter the characters containing more than max_len points') 63 | self.indexes = [] 64 | for i in range(self.num_sample): 65 | data_id = str(i).encode('utf-8') 66 | data_byte = txn.get(data_id) 67 | coords = pickle.loads(data_byte)['coordinates'] 68 | if len(coords) < self.max_len: 69 | self.indexes.append(i) 70 | else: 71 | pass 72 | 73 | def __getitem__(self, index): 74 | index = self.indexes[index] 75 | with self.lmdb.begin(write=False) as txn: 76 | data = pickle.loads(txn.get(str(index).encode('utf-8'))) 77 | tag_char, coords, fname = data['tag_char'], data['coordinates'], data['fname'] 78 | char_img = self.content[tag_char] # content samples 79 | char_img = char_img/255. # Normalize pixel values between 0.0 and 1.0 80 | writer = data['fname'].split('.')[0] 81 | img_path_list = self.all_path[writer] 82 | with open(img_path_list, 'rb') as f: 83 | style_samples = pickle.load(f) 84 | img_list = [] 85 | img_label = [] 86 | random_indexs = random.sample(range(len(style_samples)), self.num_img) 87 | for idx in random_indexs: 88 | tmp_img = style_samples[idx]['img'] 89 | tmp_img = tmp_img/255. 
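            # besides the image, each style sample stores its character label; Japanese labels are hex-encoded CP932 bytes and are decoded back to characters below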
90 | tmp_label = style_samples[idx]['label'] 91 | img_list.append(tmp_img) 92 | if self.dataset == 'JAPANESE': 93 | tmp_label = bytes.fromhex(tmp_label[5:]) 94 | tmp_label = codecs.decode(tmp_label, "cp932") 95 | img_label.append(tmp_label) 96 | img_list = np.expand_dims(np.array(img_list), 1) # [N, C, H, W], C=1 97 | coords = normalize_xys(coords) # Coordinate Normalization 98 | 99 | #### Convert absolute coordinate values into relative ones 100 | coords[1:, :2] = coords[1:, :2] - coords[:-1, :2] 101 | 102 | writer_id = self.writer_dict[fname] 103 | character_id = self.char_dict.find(tag_char) 104 | label_id = [] 105 | for i in range(self.num_img): 106 | label_id.append(self.char_dict.find(img_label[i])) 107 | return {'coords': torch.Tensor(coords), 108 | 'character_id': torch.Tensor([character_id]), 109 | 'writer_id': torch.Tensor([writer_id]), 110 | 'img_list': torch.Tensor(img_list), 111 | 'char_img': torch.Tensor(char_img), 112 | 'img_label': torch.Tensor([label_id])} 113 | 114 | def __len__(self): 115 | return len(self.indexes) 116 | 117 | def collate_fn_(self, batch_data): 118 | bs = len(batch_data) 119 | max_len = max([s['coords'].shape[0] for s in batch_data]) + 1 120 | output = {'coords': torch.zeros((bs, max_len, 5)), # (x, y, state_1, state_2, state_3) 121 | 'coords_len': torch.zeros((bs, )), 122 | 'character_id': torch.zeros((bs,)), 123 | 'writer_id': torch.zeros((bs,)), 124 | 'img_list': [], 125 | 'char_img': [], 126 | 'img_label': []} 127 | output['coords'][:,:,-1] = 1 # pad to a fixed length with pen-end state 128 | 129 | for i in range(bs): 130 | s = batch_data[i]['coords'].shape[0] 131 | output['coords'][i, :s] = batch_data[i]['coords'] 132 | output['coords'][i, 0, :2] = 0 ### put pen-down state in the first token 133 | output['coords_len'][i] = s 134 | output['character_id'][i] = batch_data[i]['character_id'] 135 | output['writer_id'][i] = batch_data[i]['writer_id'] 136 | output['img_list'].append(batch_data[i]['img_list']) 137 | output['char_img'].append(batch_data[i]['char_img']) 138 | output['img_label'].append(batch_data[i]['img_label']) 139 | output['img_list'] = torch.stack(output['img_list'], 0) # -> (B, num_img, 1, H, W) 140 | temp = torch.stack(output['char_img'], 0) 141 | output['char_img'] = temp.unsqueeze(1) 142 | output['img_label'] = torch.cat(output['img_label'], 0) 143 | output['img_label'] = output['img_label'].view(-1, 1).squeeze() 144 | return output 145 | 146 | """ 147 | loading generated online characters for evaluating the generation quality 148 | """ 149 | class Online_Dataset(Dataset): 150 | def __init__(self, data_path): 151 | lmdb_path = os.path.join(data_path, 'test') 152 | print("loading characters from", lmdb_path) 153 | if not os.path.exists(lmdb_path): 154 | raise IOError("input the correct lmdb path") 155 | 156 | self.char_dict = pickle.load(open(os.path.join(data_path, 'character_dict.pkl'), 'rb')) 157 | self.writer_dict = pickle.load(open(os.path.join(data_path, 'writer_dict.pkl'), 'rb')) 158 | self.lmdb = lmdb.open(lmdb_path, max_readers=8, readonly=True, lock=False, readahead=False, meminit=False) 159 | 160 | with self.lmdb.begin(write=False) as txn: 161 | self.num_sample = int(txn.get('num_sample'.encode('utf-8')).decode()) 162 | self.indexes = list(range(0, self.num_sample)) 163 | 164 | def __getitem__(self, index): 165 | with self.lmdb.begin(write=False) as txn: 166 | data = pickle.loads(txn.get(str(index).encode('utf-8'))) 167 | character_id, coords, writer_id, coords_gt = data['character_id'], \ 168 | data['coordinates'], 
data['writer_id'], data['coords_gt'] 169 | try: 170 | coords, coords_gt = corrds2xys(coords), corrds2xys(coords_gt) 171 | except: 172 | print('Error in character format conversion') 173 | return self[index+1] 174 | return {'coords': torch.Tensor(coords), 175 | 'character_id': torch.Tensor([character_id]), 176 | 'writer_id': torch.Tensor([writer_id]), 177 | 'coords_gt': torch.Tensor(coords_gt)} 178 | 179 | def __len__(self): 180 | return len(self.indexes) 181 | 182 | def collate_fn_(self, batch_data): 183 | bs = len(batch_data) 184 | max_len = max([s['coords'].shape[0] for s in batch_data]) 185 | max_len_gt = max([h['coords_gt'].shape[0] for h in batch_data]) 186 | output = {'coords': torch.zeros((bs, max_len, 5)), # preds -> (x,y,state) 187 | 'coords_gt':torch.zeros((bs, max_len_gt, 5)), # gt -> (x,y,state) 188 | 'coords_len': torch.zeros((bs, )), 189 | 'len_gt': torch.zeros((bs, )), 190 | 'character_id': torch.zeros((bs,)), 191 | 'writer_id': torch.zeros((bs,))} 192 | 193 | for i in range(bs): 194 | s = batch_data[i]['coords'].shape[0] 195 | output['coords'][i, :s] = batch_data[i]['coords'] 196 | h = batch_data[i]['coords_gt'].shape[0] 197 | output['coords_gt'][i, :h] = batch_data[i]['coords_gt'] 198 | output['coords_len'][i], output['len_gt'][i] = s, h 199 | output['character_id'][i] = batch_data[i]['character_id'] 200 | output['writer_id'][i] = batch_data[i]['writer_id'] 201 | return output 202 | 203 | 204 | class UserDataset(Dataset): 205 | def __init__(self, root='data', dataset='CHINESE', style_path='style_samples'): 206 | data_path = os.path.join(root, script[dataset][0]) 207 | self.content = pickle.load(open(os.path.join(data_path, script[dataset][1]), 'rb')) #content samples 208 | self.char_dict = pickle.load(open(os.path.join(data_path, 'character_dict.pkl'), 'rb')) 209 | self.style_path = glob.glob(style_path+'/*.[jp][pn]g') 210 | 211 | def __len__(self): 212 | return len(self.char_dict) 213 | 214 | def __getitem__(self, index): 215 | char = self.char_dict[index] # content samples 216 | char_img = self.content[char] 217 | char_img = char_img/255. # Normalize pixel values between 0.0 and 1.0 218 | img_list = [] 219 | for idx in range(len(self.style_path)): 220 | style_img = cv2.imread(self.style_path[idx], flags=0) 221 | style_img = cv2.resize(style_img, (64, 64)) 222 | style_img = style_img/255. 
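            # each user-provided style image is read as grayscale, resized to 64x64, and scaled to [0, 1] before being collected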
223 | img_list.append(style_img) 224 | img_list = np.expand_dims(np.array(img_list), 1) 225 | 226 | return {'char_img': torch.Tensor(char_img).unsqueeze(0), 227 | 'img_list': torch.Tensor(img_list), 228 | 'char': char} 229 | 230 | """ 231 | loading generated offline characters for calculating the Style Score 232 | takes 15 characters belonging to the same person as one input set 233 | """ 234 | class test_offline_Style_Dataset(Dataset): 235 | def __init__(self, root=None, is_train=True, num_img=15): 236 | self.is_train = is_train 237 | self.train_path = {} 238 | self.test_path = {} 239 | self.all_len = 0 240 | self.num_img = num_img 241 | if os.path.exists(os.path.join(root, 'writer_dict.pkl')): 242 | self.writer_dict = pickle.load(open(os.path.join(root, 'writer_dict.pkl'), 'rb')) 243 | else: 244 | self.writer_dict = [i for i in range(60)] 245 | all_jpg = glob.glob(os.path.join(root,'test/*.jpg')) 246 | all_jpg = sorted(all_jpg) 247 | self.all_len = len(all_jpg) 248 | for path in all_jpg: 249 | pot = os.path.basename(path).split('_')[0] 250 | if pot in self.test_path: 251 | self.test_path[pot].append(path) 252 | else: 253 | self.test_path[pot] = [] 254 | 255 | if self.is_train: 256 | data_path = self.train_path 257 | else: 258 | data_path = self.test_path 259 | self.indexs = data_path 260 | assert len(self.indexs) > 0, "input valid dataset!" 261 | print("loading %d datasets" % (len(self.indexs))) 262 | self.num_class = len(self.writer_dict) 263 | 264 | def __getitem__(self, index): 265 | num_random = self.num_img 266 | img_list = [] 267 | label_list = [] 268 | pot_name = random.choice(list(self.indexs.keys())) 269 | tmp_path = self.indexs[pot_name] 270 | random_indexs = random.sample(tmp_path,num_random) 271 | for path in random_indexs: 272 | img = Image.open(path).convert('L') 273 | data = transform_data(img) 274 | img_list.append(data) 275 | label = int(pot_name) 276 | img_list = torch.cat(img_list,0) 277 | if num_random==1: 278 | char = os.path.basename(path).split('_')[1] 279 | return img_list, label, char 280 | return img_list, label 281 | 282 | def __len__(self): 283 | return self.all_len 284 | 285 | """ 286 | loading generated online characters for calculating the Content Score 287 | """ 288 | class Online_Gen_Dataset(Dataset): 289 | def __init__(self, data_path='lmdb', is_train=True): 290 | self.is_train = is_train 291 | if is_train: 292 | lmdb_path = os.path.join(data_path, 'train') 293 | else: 294 | lmdb_path = os.path.join(data_path, 'test') 295 | if not os.path.exists(lmdb_path): 296 | print("input the correct lmdb path") 297 | raise NotImplementedError 298 | 299 | self.char_dict = pickle.load(open(os.path.join(data_path, 'character_dict.pkl'), 'rb')) 300 | self.writer_dict = pickle.load(open(os.path.join(data_path, 'writer_dict.pkl'), 'rb')) 301 | self.lmdb = lmdb.open(lmdb_path, max_readers=8, readonly=True, lock=False, readahead=False, meminit=False) 302 | self.max_len = -1 303 | self.alphabet = '' 304 | self.cat_xy_grid = True 305 | 306 | with self.lmdb.begin(write=False) as txn: 307 | self.num_sample = int(txn.get('num_sample'.encode('utf-8')).decode()) 308 | if len(self.alphabet) <= 0: 309 | self.indexes = list(range(0, self.num_sample)) 310 | else: 311 | print('filter data out of alphabet') 312 | self.indexes = [] 313 | for i in range(self.num_sample): 314 | data_id = str(i).encode('utf-8') 315 | data_byte = txn.get(data_id) 316 | character_id = pickle.loads(data_byte)['character_id'] 317 | tag_char = self.char_dict[character_id] 318 | if tag_char in self.alphabet: 319 
| self.indexes.append(i) 320 | 321 | def __getitem__(self, index): 322 | if self.is_train: 323 | index = index % (len(self)) 324 | index = self.indexes[index] 325 | 326 | with self.lmdb.begin(write=False) as txn: 327 | data = pickle.loads(txn.get(str(index).encode('utf-8'))) 328 | character_id, coords, writer_id = data['character_id'], data['coordinates'], data['writer_id'] 329 | if self.is_train and self.max_len > 0: 330 | l_seq = sum([len(l)//2 for l in coords]) 331 | if l_seq > self.max_len: 332 | print('skip {},{}'.format(index, self.char_dict[character_id])) 333 | return self[index+1] 334 | try: 335 | coords = corrds2xys(coords) 336 | except: 337 | print('error') 338 | return self[index+1] 339 | 340 | if coords is None: 341 | return self[index+1] 342 | else: 343 | pass 344 | return {'coords': torch.Tensor(coords), 345 | 'character_id': torch.Tensor([character_id]), 346 | 'writer_id': torch.Tensor([writer_id])} 347 | 348 | def __len__(self): 349 | return len(self.indexes) 350 | 351 | def collate_fn_(self, batch_data): 352 | bs = len(batch_data) 353 | max_len = max([s['coords'].shape[0] for s in batch_data]) 354 | output = {'coords': torch.zeros((bs, max_len, 5)), 355 | 'coords_len': torch.zeros((bs, )), 356 | 'character_id': torch.zeros((bs,)), 357 | 'writer_id': torch.zeros((bs,))} 358 | 359 | for i in range(bs): 360 | s = batch_data[i]['coords'].shape[0] 361 | output['coords'][i, :s] = batch_data[i]['coords'] 362 | output['coords_len'][i] = s 363 | output['character_id'][i] = batch_data[i]['character_id'] 364 | output['writer_id'][i] = batch_data[i]['writer_id'] 365 | 366 | return output -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - brotli-python=1.0.9=py38h6a678d5_8 11 | - bzip2=1.0.8=h5eee18b_6 12 | - ca-certificates=2024.3.11=h06a4308_0 13 | - certifi=2024.6.2=py38h06a4308_0 14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 15 | - cuda-cudart=11.7.99=0 16 | - cuda-cupti=11.7.101=0 17 | - cuda-libraries=11.7.1=0 18 | - cuda-nvrtc=11.7.99=0 19 | - cuda-nvtx=11.7.91=0 20 | - cuda-runtime=11.7.1=0 21 | - cuda-version=12.5=3 22 | - ffmpeg=4.3=hf484d3e_0 23 | - freetype=2.12.1=h4a9f257_0 24 | - gmp=6.2.1=h295c915_3 25 | - gnutls=3.6.15=he1e5248_0 26 | - idna=3.7=py38h06a4308_0 27 | - intel-openmp=2023.1.0=hdb19cb5_46306 28 | - jpeg=9e=h5eee18b_1 29 | - lame=3.100=h7b6447c_0 30 | - lcms2=2.12=h3be6417_0 31 | - ld_impl_linux-64=2.38=h1181459_1 32 | - lerc=3.0=h295c915_0 33 | - libcublas=11.10.3.66=0 34 | - libcufft=10.7.2.124=h4fbf590_0 35 | - libcufile=1.10.0.4=0 36 | - libcurand=10.3.6.39=0 37 | - libcusolver=11.4.0.1=0 38 | - libcusparse=11.7.4.91=0 39 | - libdeflate=1.17=h5eee18b_1 40 | - libffi=3.4.4=h6a678d5_1 41 | - libgcc-ng=11.2.0=h1234567_1 42 | - libgomp=11.2.0=h1234567_1 43 | - libiconv=1.16=h5eee18b_3 44 | - libidn2=2.3.4=h5eee18b_0 45 | - libnpp=11.7.4.75=0 46 | - libnvjpeg=11.8.0.2=0 47 | - libpng=1.6.39=h5eee18b_0 48 | - libstdcxx-ng=11.2.0=h1234567_1 49 | - libtasn1=4.19.0=h5eee18b_0 50 | - libtiff=4.5.1=h6a678d5_0 51 | - libunistring=0.9.10=h27cfd23_0 52 | - libwebp-base=1.3.2=h5eee18b_0 53 | - lz4-c=1.9.4=h6a678d5_1 54 | - mkl=2023.1.0=h213fc3f_46344 55 | - mkl-service=2.4.0=py38h5eee18b_1 56 | - mkl_fft=1.3.8=py38h5eee18b_0 57 | - mkl_random=1.2.4=py38hdb19cb5_0 58 | - 
ncurses=6.4=h6a678d5_0 59 | - nettle=3.7.3=hbbd107a_1 60 | - numpy=1.24.3=py38hf6e8229_1 61 | - numpy-base=1.24.3=py38h060ed82_1 62 | - openh264=2.1.1=h4ff587b_0 63 | - openjpeg=2.4.0=h9ca470c_1 64 | - openssl=3.0.14=h5eee18b_0 65 | - pillow=10.3.0=py38h5eee18b_0 66 | - pip=24.0=py38h06a4308_0 67 | - pysocks=1.7.1=py38h06a4308_0 68 | - python=3.8.19=h955ad1f_0 69 | - pytorch=1.13.0=py3.8_cuda11.7_cudnn8.5.0_0 70 | - pytorch-cuda=11.7=h778d358_5 71 | - pytorch-mutex=1.0=cuda 72 | - readline=8.2=h5eee18b_0 73 | - requests=2.32.2=py38h06a4308_0 74 | - setuptools=69.5.1=py38h06a4308_0 75 | - sqlite=3.45.3=h5eee18b_0 76 | - tbb=2021.8.0=hdb19cb5_0 77 | - tk=8.6.14=h39e8969_0 78 | - torchaudio=0.13.0=py38_cu117 79 | - torchvision=0.14.0=py38_cu117 80 | - typing_extensions=4.11.0=py38h06a4308_0 81 | - urllib3=2.2.2=py38h06a4308_0 82 | - wheel=0.43.0=py38h06a4308_0 83 | - xz=5.4.6=h5eee18b_1 84 | - zlib=1.2.13=h5eee18b_1 85 | - zstd=1.5.5=hc292b87_2 86 | - pip: 87 | - easydict==1.13 88 | - einops==0.8.0 89 | - lmdb==1.5.1 90 | - opencv-python==4.5.2.54 91 | prefix: /home/daigang/miniconda3/envs/test 92 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from data_loader.loader import Online_Dataset, test_offline_Style_Dataset, Online_Gen_Dataset 3 | import torch 4 | import numpy as np 5 | import tqdm 6 | from fastdtw import fastdtw 7 | from utils.metrics import * 8 | 9 | def main(opt): 10 | if opt.metric == 'DTW': 11 | """ set dataloader""" 12 | test_dataset = Online_Dataset(opt.data_path) 13 | print('loading generated samples, the total amount of samples is', len(test_dataset)) 14 | test_loader = torch.utils.data.DataLoader(test_dataset, 15 | batch_size=opt.batchsize, 16 | shuffle=True, 17 | sampler=None, 18 | drop_last=False, 19 | collate_fn=test_dataset.collate_fn_, 20 | num_workers=8) 21 | DTW = fast_norm_len_dtw(test_loader) 22 | print(f"the avg fast_norm_len_dtw is {DTW}") 23 | 24 | if opt.metric == 'Style_score': 25 | test_dataset = test_offline_Style_Dataset(opt.data_path,False) 26 | print('num testing images:', len(test_dataset)) 27 | test_loader = torch.utils.data.DataLoader(test_dataset, 28 | batch_size=opt.batchsize, 29 | shuffle=True, 30 | drop_last=False, 31 | pin_memory=True, 32 | num_workers=8) 33 | style_score = get_style_score(test_loader, opt.pretrained_model) 34 | print(f"the style_score is {style_score}") 35 | 36 | if opt.metric == 'Content_score': 37 | test_dataset = Online_Gen_Dataset(opt.data_path, False) 38 | print('num test images: ', len(test_dataset)) 39 | test_loader = torch.utils.data.DataLoader(test_dataset, 40 | batch_size=opt.batchsize, 41 | shuffle=True, 42 | sampler=None, 43 | drop_last=False, 44 | collate_fn=test_dataset.collate_fn_, 45 | num_workers=8) 46 | content_score = get_content_score(test_loader, opt.pretrained_model) 47 | print(f"the content_score is {content_score}") 48 | 49 | if __name__ == '__main__': 50 | """Parse input arguments""" 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument('--data_path', type=str, dest='data_path', default='Generated/Chinese', 53 | help='dataset path for evaluating the metrics') 54 | parser.add_argument('--metric', type=str, default='DTW', help='the metric to evaluate the generated data, DTW, Style_score or Content_score') 55 | parser.add_argument('--batchsize', type=int, default=64) 56 | parser.add_argument('--pretrained_model', type=str, 
default='model_zoo/chinese_style_iter60k_acc999.pth', help='pre-trained model for calculating Style Score or Content Score') 57 | opt = parser.parse_args() 58 | main(opt) -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from models.transformer import * 5 | from einops import rearrange 6 | 7 | ### content encoder 8 | class Content_TR(nn.Module): 9 | def __init__(self, d_model=256, nhead=8, num_encoder_layers=3, 10 | dim_feedforward=2048, dropout=0.1, activation="relu", 11 | normalize_before=True): 12 | super(Content_TR, self).__init__() 13 | self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2])) 14 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 15 | dropout, activation, normalize_before) 16 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 17 | self.add_position = PositionalEncoding(dropout=0.1, dim=d_model) 18 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, norm=encoder_norm) 19 | 20 | def forward(self, x): 21 | x = self.Feat_Encoder(x) 22 | #x = self.recti_channel(x) 23 | x = rearrange(x, 'n c h w -> (h w) n c') 24 | x = self.add_position(x) 25 | x = self.encoder(x) 26 | return x 27 | 28 | ### For the training of Chinese handwriting generation task, 29 | ### we first pre-train the content encoder for character classification. 30 | ### No need to pre-train the encoder in other languages (e.g, Japanese, English and Indic). 31 | 32 | class Content_Cls(nn.Module): 33 | def __init__(self, d_model=512, num_encoder_layers=3, num_classes=6763) -> None: 34 | super(Content_Cls, self).__init__() 35 | self.feature_ext = Content_TR(d_model, num_encoder_layers) 36 | self.cls_head = nn.Linear(d_model, num_classes) 37 | self._reset_parameters() 38 | 39 | def _reset_parameters(self): 40 | for p in self.parameters(): 41 | if p.dim() > 1: 42 | nn.init.xavier_uniform_(p) 43 | 44 | def forward(self, x): 45 | x = self.feature_ext(x) 46 | x = torch.mean(x, 0) 47 | out = self.cls_head(x) 48 | return out -------------------------------------------------------------------------------- /models/eval_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from ptflops import get_model_complexity_info 5 | 6 | ### The model used to calculate Style Score 7 | class offline_style(nn.Module): 8 | def __init__(self, num_class=240, vote=False): 9 | super(offline_style, self).__init__() 10 | self.l1 = nn.Sequential(nn.Conv2d(1,96,5,2),nn.BatchNorm2d(96),nn.ReLU(),nn.MaxPool2d(3,2)) 11 | self.l2 = nn.Sequential(nn.Conv2d(96,256,3,1,1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(3,2)) 12 | self.l3 = nn.Sequential(nn.Conv2d(256,384,3,1,1),nn.BatchNorm2d(384),nn.ReLU(), 13 | nn.Conv2d(384,384,3,1,1),nn.BatchNorm2d(384),nn.ReLU(),nn.MaxPool2d(3,2)) 14 | self.l4 = nn.Sequential(nn.Conv2d(384,256,3,1,1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(3,2)) 15 | self.fc1 = nn.Sequential(nn.Flatten(1),nn.Linear(1024, num_class)) 16 | self.vote = vote 17 | 18 | def forward(self,x): 19 | if self.vote: 20 | n,c,h,w = x.size() 21 | if not self.training: 22 | x = x.view(n*c,1,h,w) 23 | out = self.l1(x) 24 | out = self.l2(out) 25 | out = 
self.l3(out) 26 | out = self.l4(out) 27 | n1,c1,h1,w1 = out.size() 28 | out = self.fc1(out) 29 | if not self.training: 30 | out = out.view(n,c,-1) 31 | return out 32 | else: 33 | n,c,h,w = x.size() 34 | x = x.view(n*c,1,h,w) 35 | out = self.l1(x) 36 | out = self.l2(out) 37 | out = self.l3(out) 38 | out = self.l4(out) 39 | n1,c1,h1,w1 = out.size() 40 | out = torch.mean(out.view(n,c,c1,h1,w1),1) 41 | out = self.fc1(out) 42 | return out 43 | 44 | ### The model used to calculate Content Score 45 | class Character_Net(nn.Module): 46 | def __init__(self, nclass=3755): 47 | super(Character_Net, self).__init__() 48 | self.l1 = nn.Sequential(nn.Conv1d(5, 64, kernel_size=7, stride=1, padding=3), nn.ReLU(), nn.BatchNorm1d(64)) 49 | 50 | self.l2 = nn.MaxPool1d(kernel_size=2, stride=2) 51 | 52 | self.l3 = nn.Sequential(nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.BatchNorm1d(64)) 53 | 54 | self.l4 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.BatchNorm1d(128)) 55 | 56 | self.l5 = nn.MaxPool1d(kernel_size=2, stride=2) 57 | 58 | self.l6 = nn.Sequential(nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.BatchNorm1d(128)) 59 | 60 | self.l7 = nn.Sequential(nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.BatchNorm1d(256)) 61 | 62 | self.l8 = nn.MaxPool1d(kernel_size=2, stride=2) 63 | 64 | self.l9 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU()) 65 | 66 | print('num of character is {}'.format(nclass)) 67 | self.l11 = nn.Linear(256, nclass) 68 | 69 | def forward(self, x, l): 70 | x = self.l1(x) 71 | x = self.l2(x) 72 | x = self.l3(x) 73 | x = self.l4(x) 74 | x = self.l5(x) 75 | x = self.l6(x) 76 | x = self.l7(x) 77 | x = self.l8(x) 78 | x = self.l9(x) 79 | hidden = mask_avr_pooling(x, torch.div(l, 8, rounding_mode='floor')) 80 | x = self.l11(hidden) 81 | return x 82 | 83 | def mask_avr_pooling_rnn(x, l): 84 | N,T,C = x.size() 85 | mask = length_to_mask(l, max_len=T) 86 | mask = mask.unsqueeze(-1) 87 | o = torch.sum(x*mask, dim=-2, keepdim=False) 88 | o = o/(l.unsqueeze(-1)+1e-5) 89 | return o 90 | 91 | def mask_avr_pooling(x, l): 92 | N,C,T = x.size() 93 | mask = length_to_mask(l, max_len=T) 94 | mask = mask.unsqueeze(1) 95 | o = torch.sum(x*mask, dim=-1, keepdim=False) 96 | o = o/(l.unsqueeze(-1)+1e-5) 97 | return o 98 | 99 | def length_to_mask(length, max_len=None, dtype=None): 100 | """length: B. 101 | return B x max_len. 102 | If max_len is None, then max of length will be used. 103 | """ 104 | assert len(length.shape) == 1, 'Length shape should be 1 dimensional.' 
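    # default max_len to the longest length in the batch, then mark positions < length as True (valid) and padded positions as False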
105 | max_len = max_len or length.max().item() 106 | mask = torch.arange(max_len, device=length.device, 107 | dtype=length.dtype).expand(len(length), max_len) < length.unsqueeze(1) 108 | if dtype is not None: 109 | mask = torch.as_tensor(mask, dtype=dtype, device=length.device) 110 | return mask -------------------------------------------------------------------------------- /models/gmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ### split final output of our model into Mixture Density Network (MDN) parameters and pen state 4 | def get_mixture_coef(output): 5 | z = output 6 | z_pen_logits = z[:, 0:3] # pen state 7 | 8 | # MDN parameters are used to predict the pen moving 9 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = torch.split(z[:, 3:], 20, 1) 10 | 11 | # softmax pi weights: 12 | z_pi = torch.softmax(z_pi, -1) 13 | 14 | # exponentiate the sigmas and also make corr between -1 and 1. 15 | z_sigma1 = torch.minimum(torch.exp(z_sigma1), torch.Tensor([500.0]).cuda()) 16 | z_sigma2 = torch.minimum(torch.exp(z_sigma2), torch.Tensor([500.0]).cuda()) 17 | z_corr = torch.tanh(z_corr) 18 | result = [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen_logits] 19 | return result 20 | 21 | ### generate the pen moving and state from the predict output 22 | def get_seq_from_gmm(gmm_pred): 23 | gmm_pred = gmm_pred.reshape(-1, 123) 24 | [pi, mu1, mu2, sigma1, sigma2, corr, pen_logits] = get_mixture_coef(gmm_pred) 25 | max_mixture_idx = torch.stack([torch.arange(pi.shape[0], dtype=torch.int64).cuda(), torch.argmax(pi, 1)], 1) 26 | next_x1 = mu1[list(max_mixture_idx.T)] 27 | next_x2 = mu2[list(max_mixture_idx.T)] 28 | pen_state = torch.argmax(gmm_pred[:, :3], dim=-1) 29 | pen_state = torch.nn.functional.one_hot(pen_state, num_classes=3).to(gmm_pred) 30 | seq_pred = torch.cat([next_x1.unsqueeze(1), next_x2.unsqueeze(1), pen_state],-1) 31 | return seq_pred -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | class SupConLoss(nn.Module): 7 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 8 | It also supports the unsupervised contrastive loss in SimCLR""" 9 | def __init__(self, temperature=0.07, contrast_mode='all', 10 | base_temperature=0.07): 11 | super(SupConLoss, self).__init__() 12 | self.temperature = temperature 13 | self.contrast_mode = contrast_mode 14 | self.base_temperature = base_temperature 15 | 16 | def forward(self, features, labels=None, mask=None): 17 | """Compute loss for model. If both `labels` and `mask` are None, 18 | it degenerates to SimCLR unsupervised loss: 19 | https://arxiv.org/pdf/2002.05709.pdf 20 | Args: 21 | features: hidden vector of shape [bsz, n_views, ...]. 22 | labels: ground truth of shape [bsz]. 23 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 24 | has the same class as sample i. Can be asymmetric. 25 | Returns: 26 | A loss scalar. 
27 | """ 28 | device = (torch.device('cuda') 29 | if features.is_cuda 30 | else torch.device('cpu')) 31 | 32 | if len(features.shape) < 3: 33 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 34 | 'at least 3 dimensions are required') 35 | if len(features.shape) > 3: 36 | features = features.view(features.shape[0], features.shape[1], -1) 37 | 38 | batch_size = features.shape[0] 39 | if labels is not None and mask is not None: 40 | raise ValueError('Cannot define both `labels` and `mask`') 41 | elif labels is None and mask is None: 42 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 43 | elif labels is not None: 44 | labels = labels.contiguous().view(-1, 1) 45 | if labels.shape[0] != batch_size: 46 | raise ValueError('Num of labels does not match num of features') 47 | mask = torch.eq(labels, labels.T).float().to(device) 48 | else: 49 | mask = mask.float().to(device) 50 | 51 | contrast_count = features.shape[1] 52 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 53 | if self.contrast_mode == 'one': 54 | anchor_feature = features[:, 0] 55 | anchor_count = 1 56 | elif self.contrast_mode == 'all': 57 | anchor_feature = contrast_feature 58 | anchor_count = contrast_count 59 | else: 60 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 61 | 62 | # compute logits 63 | anchor_dot_contrast = torch.div( 64 | torch.matmul(anchor_feature, contrast_feature.T), 65 | self.temperature) 66 | # for numerical stability 67 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 68 | logits = anchor_dot_contrast - logits_max.detach() 69 | 70 | # tile mask 71 | mask = mask.repeat(anchor_count, contrast_count) 72 | # mask-out self-contrast cases 73 | logits_mask = torch.scatter( 74 | torch.ones_like(mask), 75 | 1, 76 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 77 | 0 78 | ) 79 | mask = mask * logits_mask 80 | 81 | # compute log_prob 82 | exp_logits = torch.exp(logits) * logits_mask 83 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 84 | 85 | # compute mean of log-likelihood over positive 86 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 87 | 88 | # loss 89 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 90 | loss = loss.view(anchor_count, batch_size).mean() 91 | 92 | return loss 93 | 94 | 95 | """pen moving prediction and pen state classification losses""" 96 | def get_pen_loss(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen_logits, x1_data, x2_data, 97 | pen_data): 98 | result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr) 99 | epsilon = 1e-10 100 | # result1 is the loss wrt pen offset 101 | result1 = torch.multiply(result0, z_pi) 102 | result1 = torch.sum(result1, 1, keepdims=True) 103 | result1 = - torch.log(result1 + epsilon) # avoid log(0) 104 | 105 | fs = 1.0 - pen_data[:, 2] # use training data for this 106 | fs = fs.reshape(-1, 1) 107 | # Zero out loss terms beyond N_s, the last actual stroke 108 | result1 = torch.multiply(result1, fs) 109 | loss_fn = torch.nn.CrossEntropyLoss() 110 | result2 = loss_fn(z_pen_logits, torch.argmax(pen_data, -1)) 111 | return result1, result2 # result1: pen offset loss, result2: category loss 112 | 113 | """Normal distribution""" 114 | def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho): 115 | s1 = torch.clip(s1, 1e-6, 500.0) 116 | s2 = torch.clip(s2, 1e-6, 500.0) 117 | 118 | norm1 = torch.subtract(x1, mu1) # Returns x1-mu1 element-wise 119 | norm2 = torch.subtract(x2, mu2) 120 | s1s2 = 
torch.multiply(s1, s2) 121 | 122 | z = (torch.square(torch.div(norm1, s1)) + torch.square(torch.div(norm2, s2)) - 123 | 2 * torch.div(torch.multiply(rho, torch.multiply(norm1, norm2)), s1s2)) 124 | neg_rho = torch.clip(1 - torch.square(rho), 1e-6, 1.0) 125 | result = torch.exp(torch.div(-z, 2 * neg_rho)) 126 | denom = 2 * np.pi * torch.multiply(s1s2, torch.sqrt(neg_rho)) 127 | result = torch.div(result, denom) 128 | return result -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from models.transformer import * 5 | from models.encoder import Content_TR 6 | from einops import rearrange, repeat 7 | from models.gmm import get_seq_from_gmm 8 | 9 | ''' 10 | the overall architecture of our style-disentangled Transformer (SDT). 11 | the input of our SDT is the gray image with 1 channel. 12 | ''' 13 | class SDT_Generator(nn.Module): 14 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=2, num_head_layers= 1, 15 | wri_dec_layers=2, gly_dec_layers=2, dim_feedforward=2048, dropout=0.1, 16 | activation="relu", normalize_before=True, return_intermediate_dec=True): 17 | super(SDT_Generator, self).__init__() 18 | ### style encoder with dual heads 19 | self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2])) 20 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 21 | dropout, activation, normalize_before) 22 | self.base_encoder = TransformerEncoder(encoder_layer, num_encoder_layers, None) 23 | writer_norm = nn.LayerNorm(d_model) if normalize_before else None 24 | glyph_norm = nn.LayerNorm(d_model) if normalize_before else None 25 | self.writer_head = TransformerEncoder(encoder_layer, num_head_layers, writer_norm) 26 | self.glyph_head = TransformerEncoder(encoder_layer, num_head_layers, glyph_norm) 27 | 28 | ### content ecoder 29 | self.content_encoder = Content_TR(d_model, num_encoder_layers) 30 | 31 | ### decoder for receiving writer-wise and character-wise styles 32 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 33 | dropout, activation, normalize_before) 34 | wri_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None 35 | self.wri_decoder = TransformerDecoder(decoder_layer, wri_dec_layers, wri_decoder_norm, 36 | return_intermediate=return_intermediate_dec) 37 | gly_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None 38 | self.gly_decoder = TransformerDecoder(decoder_layer, gly_dec_layers, gly_decoder_norm, 39 | return_intermediate=return_intermediate_dec) 40 | 41 | ### two mlps that project style features into the space where nce_loss is applied 42 | self.pro_mlp_writer = nn.Sequential( 43 | nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256)) 44 | self.pro_mlp_character = nn.Sequential( 45 | nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256)) 46 | 47 | self.SeqtoEmb = SeqtoEmb(hid_dim=d_model) 48 | self.EmbtoSeq = EmbtoSeq(hid_dim=d_model) 49 | self.add_position = PositionalEncoding(dropout=0.1, dim=d_model) 50 | self._reset_parameters() 51 | 52 | 53 | def _reset_parameters(self): 54 | for p in self.parameters(): 55 | if p.dim() > 1: 56 | nn.init.xavier_uniform_(p) 57 | 58 | def random_double_sampling(self, x, ratio=0.25): 59 | """ 60 | Sample the positive pair (i.e., o and o^+) within a 
character by per-sample shuffling. 61 | Per-sample shuffling is done by argsort random noise. 62 | x: [L, B, N, D], sequence 63 | return o [B, N, 1, D], o^+ [B, N, 1, D] 64 | """ 65 | L, B, N, D = x.shape # length, batch, group_number, dim 66 | x = rearrange(x, "L B N D -> B N L D") 67 | noise = torch.rand(B, N, L, device=x.device) # noise in [0, 1] 68 | # sort noise for each sample 69 | ids_shuffle = torch.argsort(noise, dim=2) 70 | 71 | anchor_tokens, pos_tokens = int(L*ratio), int(L*2*ratio) 72 | ids_keep_anchor, ids_keep_pos = ids_shuffle[:, :, :anchor_tokens], ids_shuffle[:, :, anchor_tokens:pos_tokens] 73 | x_anchor = torch.gather( 74 | x, dim=2, index=ids_keep_anchor.unsqueeze(-1).repeat(1, 1, 1, D)) 75 | x_pos = torch.gather( 76 | x, dim=2, index=ids_keep_pos.unsqueeze(-1).repeat(1, 1, 1, D)) 77 | return x_anchor, x_pos 78 | 79 | # the shape of style_imgs is [B, 2*N, C, H, W] during training 80 | def forward(self, style_imgs, seq, char_img): 81 | batch_size, num_imgs, in_planes, h, w = style_imgs.shape 82 | 83 | # style_imgs: [B, 2*N, C:1, H, W] -> FEAT_ST_ENC: [4*N, B, C:512] 84 | style_imgs = style_imgs.view(-1, in_planes, h, w) # [B*2N, C:1, H, W] 85 | style_embe = self.Feat_Encoder(style_imgs) # [B*2N, C:512, 2, 2] 86 | 87 | anchor_num = num_imgs//2 88 | style_embe = style_embe.view(batch_size*num_imgs, 512, -1).permute(2, 0, 1) # [4, B*2N, C:512] 89 | FEAT_ST_ENC = self.add_position(style_embe) 90 | 91 | memory = self.base_encoder(FEAT_ST_ENC) # [4, B*2N, C] 92 | writer_memory = self.writer_head(memory) 93 | glyph_memory = self.glyph_head(memory) 94 | 95 | writer_memory = rearrange(writer_memory, 't (b p n) c -> t (p b) n c', 96 | b=batch_size, p=2, n=anchor_num) # [4, 2*B, N, C] 97 | glyph_memory = rearrange(glyph_memory, 't (b p n) c -> t (p b) n c', 98 | b=batch_size, p=2, n=anchor_num) # [4, 2*B, N, C] 99 | 100 | # writer-nce 101 | memory_fea = rearrange(writer_memory, 't b n c ->(t n) b c') # [4*N, 2*B, C] 102 | compact_fea = torch.mean(memory_fea, 0) # [2*B, C] 103 | # compact_fea:[2*B, C:512] -> nce_emb: [B, 2, C:128] 104 | pro_emb = self.pro_mlp_writer(compact_fea) 105 | query_emb = pro_emb[:batch_size, :] 106 | pos_emb = pro_emb[batch_size:, :] 107 | nce_emb = torch.stack((query_emb, pos_emb), 1) # [B, 2, C] 108 | nce_emb = nn.functional.normalize(nce_emb, p=2, dim=2) 109 | 110 | # glyph-nce 111 | patch_emb = glyph_memory[:, :batch_size] # [4, B, N, C] 112 | # sample the positive pair 113 | anc, positive = self.random_double_sampling(patch_emb) 114 | n_channels = anc.shape[-1] 115 | anc = anc.reshape(batch_size, -1, n_channels) 116 | anc_compact = torch.mean(anc, 1, keepdim=True) 117 | anc_compact = self.pro_mlp_character(anc_compact) # [B, 1, C] 118 | positive = positive.reshape(batch_size, -1, n_channels) 119 | positive_compact = torch.mean(positive, 1, keepdim=True) 120 | positive_compact = self.pro_mlp_character(positive_compact) # [B, 1, C] 121 | 122 | nce_emb_patch = torch.cat((anc_compact, positive_compact), 1) # [B, 2, C] 123 | nce_emb_patch = nn.functional.normalize(nce_emb_patch, p=2, dim=2) 124 | 125 | # input the writer-wise & character-wise styles into the decoder 126 | writer_style = memory_fea[:, :batch_size, :] # [4*N, B, C] 127 | glyph_style = glyph_memory[:, :batch_size] # [4, B, N, C] 128 | glyph_style = rearrange(glyph_style, 't b n c -> (t n) b c') # [4*N, B, C] 129 | 130 | # QUERY: [char_emb, seq_emb] 131 | seq_emb = self.SeqtoEmb(seq).permute(1, 0, 2) 132 | T, N, C = seq_emb.shape 133 | 134 | char_emb = self.content_encoder(char_img) # [4, N, 
512] 135 | char_emb = torch.mean(char_emb, 0) #[N, 512] 136 | char_emb = repeat(char_emb, 'n c -> t n c', t = 1) 137 | tgt = torch.cat((char_emb, seq_emb), 0) # [1+T], put the content token as the first token 138 | tgt_mask = generate_square_subsequent_mask(sz=(T+1)).to(tgt) 139 | tgt = self.add_position(tgt) 140 | 141 | # [wri_dec_layers, T, B, C] 142 | wri_hs = self.wri_decoder(tgt, writer_style, tgt_mask=tgt_mask) 143 | # [gly_dec_layers, T, B, C] 144 | hs = self.gly_decoder(wri_hs[-1], glyph_style, tgt_mask=tgt_mask) 145 | 146 | h = hs.transpose(1, 2)[-1] # B T C 147 | pred_sequence = self.EmbtoSeq(h) 148 | return pred_sequence, nce_emb, nce_emb_patch 149 | 150 | # style_imgs: [B, N, C, H, W] 151 | def inference(self, style_imgs, char_img, max_len): 152 | batch_size, num_imgs, in_planes, h, w = style_imgs.shape 153 | # [B, N, C, H, W] -> [B*N, C, H, W] 154 | style_imgs = style_imgs.view(-1, in_planes, h, w) 155 | # [B*N, 1, 64, 64] -> [B*N, 512, 2, 2] 156 | style_embe = self.Feat_Encoder(style_imgs) 157 | FEAT_ST = style_embe.reshape(batch_size*num_imgs, 512, -1).permute(2, 0, 1) # [4, B*N, C] 158 | FEAT_ST_ENC = self.add_position(FEAT_ST) # [4, B*N, C:512] 159 | memory = self.base_encoder(FEAT_ST_ENC) # [5, B*N, C] 160 | memory_writer = self.writer_head(memory) # [4, B*N, C] 161 | memory_glyph = self.glyph_head(memory) # [4, B*N, C] 162 | memory_writer = rearrange( 163 | memory_writer, 't (b n) c ->(t n) b c', b=batch_size) # [4*N, B, C] 164 | memory_glyph = rearrange( 165 | memory_glyph, 't (b n) c -> (t n) b c', b=batch_size) # [4*N, B, C] 166 | 167 | char_emb = self.content_encoder(char_img) 168 | char_emb = torch.mean(char_emb, 0) #[N, 256] 169 | src_tensor = torch.zeros(max_len + 1, batch_size, 512).to(char_emb) 170 | pred_sequence = torch.zeros(max_len, batch_size, 5).to(char_emb) 171 | src_tensor[0] = char_emb 172 | tgt_mask = generate_square_subsequent_mask(sz=max_len + 1).to(char_emb) 173 | for i in range(max_len): 174 | src_tensor[i] = self.add_position(src_tensor[i], step=i) 175 | 176 | wri_hs = self.wri_decoder( 177 | src_tensor, memory_writer, tgt_mask=tgt_mask) 178 | hs = self.gly_decoder(wri_hs[-1], memory_glyph, tgt_mask=tgt_mask) 179 | 180 | output_hid = hs[-1][i] 181 | gmm_pred = self.EmbtoSeq(output_hid) 182 | pred_sequence[i] = get_seq_from_gmm(gmm_pred) 183 | pen_state = pred_sequence[i, :, 2:] 184 | seq_emb = self.SeqtoEmb(pred_sequence[i]) 185 | src_tensor[i + 1] = seq_emb 186 | if sum(pen_state[:, -1]) == batch_size: 187 | break 188 | else: 189 | pass 190 | return pred_sequence.transpose(0, 1) # N, T, C 191 | 192 | ''' 193 | project the handwriting sequences to the transformer hidden space 194 | ''' 195 | class SeqtoEmb(nn.Module): 196 | def __init__(self, hid_dim, dropout=0.1): 197 | super().__init__() 198 | self.fc_1 = nn.Linear(5, 256) 199 | self.fc_2 = nn.Linear(256, hid_dim) 200 | self.dropout = nn.Dropout(dropout) 201 | 202 | def forward(self, seq): 203 | x = self.dropout(torch.relu(self.fc_1(seq))) 204 | x = self.fc_2(x) 205 | return x 206 | 207 | ''' 208 | project the transformer hidden space to handwriting sequences 209 | ''' 210 | class EmbtoSeq(nn.Module): 211 | def __init__(self, hid_dim, dropout=0.1): 212 | super().__init__() 213 | self.fc_1 = nn.Linear(hid_dim, 256) 214 | self.fc_2 = nn.Linear(256, 123) 215 | self.dropout = nn.Dropout(dropout) 216 | 217 | def forward(self, seq): 218 | x = self.dropout(torch.relu(self.fc_1(seq))) 219 | x = self.fc_2(x) 220 | return x 221 | 222 | 223 | ''' 224 | generate the attention mask, i.e. 
[[0, inf, inf], 225 | [0, 0, inf], 226 | [0, 0, 0]]. 227 | The masked positions are filled with float('-inf'). 228 | Unmasked positions are filled with float(0.0). 229 | ''' 230 | def generate_square_subsequent_mask(sz: int) -> Tensor: 231 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 232 | mask = mask.float().masked_fill(mask == 0, float( 233 | '-inf')).masked_fill(mask == 1, float(0.0)) 234 | return mask -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copy-paste from torch.nn.Transformer with modifications: 3 | * positional encodings are passed in MHattention 4 | * extra LN at the end of encoder is removed 5 | * decoder returns a stack of activations from all decoding layers 6 | ''' 7 | 8 | import copy 9 | from typing import Optional, List 10 | import math 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn, Tensor 14 | 15 | 16 | class Transformer(nn.Module): 17 | 18 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 19 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 20 | activation="relu", normalize_before=False, 21 | return_intermediate_dec=False): 22 | super().__init__() 23 | 24 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 25 | dropout, activation, normalize_before) 26 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 27 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 28 | 29 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 30 | dropout, activation, normalize_before) 31 | decoder_norm = nn.LayerNorm(d_model) 32 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 33 | return_intermediate=return_intermediate_dec) 34 | 35 | self._reset_parameters() 36 | 37 | self.d_model = d_model 38 | self.nhead = nhead 39 | 40 | def _reset_parameters(self): 41 | for p in self.parameters(): 42 | if p.dim() > 1: 43 | nn.init.xavier_uniform_(p) 44 | 45 | def forward(self, src, query_embed, y_ind): 46 | # flatten NxCxHxW to HWxNxC 47 | bs, c, h, w = src.shape 48 | src = src.flatten(2).permute(2, 0, 1) 49 | 50 | y_emb = query_embed[y_ind].permute(1,0,2) 51 | 52 | tgt = torch.zeros_like(y_emb) 53 | memory = self.encoder(src) 54 | hs = self.decoder(tgt, memory, query_pos=y_emb) 55 | 56 | return torch.cat([hs.transpose(1, 2)[-1], y_emb.permute(1,0,2)], -1) 57 | 58 | 59 | class TransformerEncoder(nn.Module): 60 | 61 | def __init__(self, encoder_layer, num_layers, norm=None): 62 | super().__init__() 63 | self.layers = _get_clones(encoder_layer, num_layers) 64 | self.num_layers = num_layers 65 | self.norm = norm 66 | 67 | def forward(self, src, 68 | mask: Optional[Tensor] = None, 69 | src_key_padding_mask: Optional[Tensor] = None, 70 | pos: Optional[Tensor] = None): 71 | output = src 72 | 73 | for layer in self.layers: 74 | output = layer(output, src_mask=mask, 75 | src_key_padding_mask=src_key_padding_mask, pos=pos) 76 | 77 | if self.norm is not None: 78 | output = self.norm(output) 79 | 80 | return output 81 | 82 | 83 | class TransformerDecoder(nn.Module): 84 | 85 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 86 | super().__init__() 87 | self.layers = _get_clones(decoder_layer, num_layers) 88 | self.num_layers = num_layers 89 | self.norm = norm 90 | self.return_intermediate = return_intermediate 91 | 92 | def 
forward(self, tgt, memory, 93 | tgt_mask: Optional[Tensor] = None, 94 | memory_mask: Optional[Tensor] = None, 95 | tgt_key_padding_mask: Optional[Tensor] = None, 96 | memory_key_padding_mask: Optional[Tensor] = None, 97 | pos: Optional[Tensor] = None, 98 | query_pos: Optional[Tensor] = None): 99 | output = tgt 100 | 101 | intermediate = [] 102 | 103 | for layer in self.layers: 104 | output = layer(output, memory, tgt_mask=tgt_mask, 105 | memory_mask=memory_mask, 106 | tgt_key_padding_mask=tgt_key_padding_mask, 107 | memory_key_padding_mask=memory_key_padding_mask, 108 | pos=pos, query_pos=query_pos) 109 | if self.return_intermediate: 110 | intermediate.append(self.norm(output)) 111 | 112 | if self.norm is not None: 113 | output = self.norm(output) 114 | if self.return_intermediate: 115 | intermediate.pop() 116 | intermediate.append(output) 117 | 118 | if self.return_intermediate: 119 | return torch.stack(intermediate) 120 | 121 | return output.unsqueeze(0) 122 | 123 | 124 | class TransformerEncoderLayer(nn.Module): 125 | 126 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 127 | activation="relu", normalize_before=False): 128 | super().__init__() 129 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 130 | # Implementation of Feedforward model 131 | self.linear1 = nn.Linear(d_model, dim_feedforward) 132 | self.dropout = nn.Dropout(dropout) 133 | self.linear2 = nn.Linear(dim_feedforward, d_model) 134 | 135 | self.norm1 = nn.LayerNorm(d_model) 136 | self.norm2 = nn.LayerNorm(d_model) 137 | self.dropout1 = nn.Dropout(dropout) 138 | self.dropout2 = nn.Dropout(dropout) 139 | 140 | self.activation = _get_activation_fn(activation) 141 | self.normalize_before = normalize_before 142 | 143 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 144 | return tensor if pos is None else tensor + pos 145 | 146 | def forward_post(self, 147 | src, 148 | src_mask: Optional[Tensor] = None, 149 | src_key_padding_mask: Optional[Tensor] = None, 150 | pos: Optional[Tensor] = None): 151 | q = k = self.with_pos_embed(src, pos) 152 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 153 | key_padding_mask=src_key_padding_mask)[0] 154 | src = src + self.dropout1(src2) 155 | src = self.norm1(src) 156 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 157 | src = src + self.dropout2(src2) 158 | src = self.norm2(src) 159 | return src 160 | 161 | def forward_pre(self, src, 162 | src_mask: Optional[Tensor] = None, 163 | src_key_padding_mask: Optional[Tensor] = None, 164 | pos: Optional[Tensor] = None): 165 | src2 = self.norm1(src) 166 | q = k = self.with_pos_embed(src2, pos) 167 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 168 | key_padding_mask=src_key_padding_mask)[0] 169 | src = src + self.dropout1(src2) 170 | src2 = self.norm2(src) 171 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 172 | src = src + self.dropout2(src2) 173 | return src 174 | 175 | def forward(self, src, 176 | src_mask: Optional[Tensor] = None, 177 | src_key_padding_mask: Optional[Tensor] = None, 178 | pos: Optional[Tensor] = None): 179 | if self.normalize_before: 180 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 181 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 182 | 183 | 184 | class TransformerDecoderLayer(nn.Module): 185 | 186 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 187 | activation="relu", normalize_before=False): 188 | super().__init__() 189 | 
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 190 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 191 | # Implementation of Feedforward model 192 | self.linear1 = nn.Linear(d_model, dim_feedforward) 193 | self.dropout = nn.Dropout(dropout) 194 | self.linear2 = nn.Linear(dim_feedforward, d_model) 195 | 196 | self.norm1 = nn.LayerNorm(d_model) 197 | self.norm2 = nn.LayerNorm(d_model) 198 | self.norm3 = nn.LayerNorm(d_model) 199 | self.dropout1 = nn.Dropout(dropout) 200 | self.dropout2 = nn.Dropout(dropout) 201 | self.dropout3 = nn.Dropout(dropout) 202 | 203 | self.activation = _get_activation_fn(activation) 204 | self.normalize_before = normalize_before 205 | 206 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 207 | return tensor if pos is None else tensor + pos 208 | 209 | def forward_post(self, tgt, memory, 210 | tgt_mask: Optional[Tensor] = None, 211 | memory_mask: Optional[Tensor] = None, 212 | tgt_key_padding_mask: Optional[Tensor] = None, 213 | memory_key_padding_mask: Optional[Tensor] = None, 214 | pos: Optional[Tensor] = None, 215 | query_pos: Optional[Tensor] = None): 216 | q = k = self.with_pos_embed(tgt, query_pos) 217 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 218 | key_padding_mask=tgt_key_padding_mask)[0] 219 | tgt = tgt + self.dropout1(tgt2) 220 | tgt = self.norm1(tgt) 221 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 222 | key=self.with_pos_embed(memory, pos), 223 | value=memory, attn_mask=memory_mask, 224 | key_padding_mask=memory_key_padding_mask)[0] 225 | tgt = tgt + self.dropout2(tgt2) 226 | tgt = self.norm2(tgt) 227 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 228 | tgt = tgt + self.dropout3(tgt2) 229 | tgt = self.norm3(tgt) 230 | return tgt 231 | 232 | def forward_pre(self, tgt, memory, 233 | tgt_mask: Optional[Tensor] = None, 234 | memory_mask: Optional[Tensor] = None, 235 | tgt_key_padding_mask: Optional[Tensor] = None, 236 | memory_key_padding_mask: Optional[Tensor] = None, 237 | pos: Optional[Tensor] = None, 238 | query_pos: Optional[Tensor] = None): 239 | tgt2 = self.norm1(tgt) 240 | q = k = self.with_pos_embed(tgt2, query_pos) 241 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 242 | key_padding_mask=tgt_key_padding_mask)[0] 243 | tgt = tgt + self.dropout1(tgt2) 244 | tgt2 = self.norm2(tgt) 245 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 246 | key=self.with_pos_embed(memory, pos), 247 | value=memory, attn_mask=memory_mask, 248 | key_padding_mask=memory_key_padding_mask)[0] 249 | tgt = tgt + self.dropout2(tgt2) 250 | tgt2 = self.norm3(tgt) 251 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 252 | tgt = tgt + self.dropout3(tgt2) 253 | return tgt 254 | 255 | def forward(self, tgt, memory, 256 | tgt_mask: Optional[Tensor] = None, 257 | memory_mask: Optional[Tensor] = None, 258 | tgt_key_padding_mask: Optional[Tensor] = None, 259 | memory_key_padding_mask: Optional[Tensor] = None, 260 | pos: Optional[Tensor] = None, 261 | query_pos: Optional[Tensor] = None): 262 | if self.normalize_before: 263 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 264 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 265 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 266 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 267 | 268 | 269 | def _get_clones(module, N): 270 | return nn.ModuleList([copy.deepcopy(module) for i in 
range(N)]) 271 | 272 | 273 | def build_transformer(args): 274 | return Transformer( 275 | d_model=args.hidden_dim, 276 | dropout=args.dropout, 277 | nhead=args.nheads, 278 | dim_feedforward=args.dim_feedforward, 279 | num_encoder_layers=args.enc_layers, 280 | num_decoder_layers=args.dec_layers, 281 | normalize_before=args.pre_norm, 282 | return_intermediate_dec=True, 283 | ) 284 | 285 | 286 | def _get_activation_fn(activation): 287 | """Return an activation function given a string""" 288 | if activation == "relu": 289 | return F.relu 290 | if activation == "gelu": 291 | return F.gelu 292 | if activation == "glu": 293 | return F.glu 294 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 295 | 296 | class PositionalEncoding(nn.Module): 297 | """Sinusoidal positional encoding for non-recurrent neural networks. 298 | 299 | Implementation based on "Attention Is All You Need" 300 | :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` 301 | 302 | Args: 303 | dropout (float): dropout parameter 304 | dim (int): embedding size 305 | """ 306 | 307 | def __init__(self, dropout, dim, max_len=500): 308 | if dim % 2 != 0: 309 | raise ValueError("Cannot use sin/cos positional encoding with " 310 | "odd dim (got dim={:d})".format(dim)) 311 | pe = torch.zeros(max_len, dim) 312 | position = torch.arange(0, max_len).unsqueeze(1) 313 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * 314 | -(math.log(10000.0) / dim))) 315 | pe[:, 0::2] = torch.sin(position.float() * div_term) 316 | pe[:, 1::2] = torch.cos(position.float() * div_term) 317 | pe = pe.unsqueeze(1) 318 | super(PositionalEncoding, self).__init__() 319 | self.register_buffer('pe', pe) 320 | self.dropout = nn.Dropout(p=dropout) 321 | self.dim = dim 322 | 323 | def forward(self, emb, step=None): 324 | """Embed inputs. 325 | 326 | Args: 327 | emb (FloatTensor): Sequence of word vectors 328 | ``(seq_len, batch_size, self.dim)`` 329 | step (int or NoneType): If stepwise (``seq_len = 1``), use 330 | the encoding for this position. 331 | """ 332 | 333 | emb = emb * math.sqrt(self.dim) 334 | if step is None: 335 | emb = emb + self.pe[:emb.size(0)] 336 | else: 337 | emb = emb + self.pe[step] 338 | emb = self.dropout(emb) 339 | return emb 340 | -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import six 7 | import os 8 | import os.path as osp 9 | import copy 10 | from ast import literal_eval 11 | 12 | import numpy as np 13 | from packaging import version 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import init 17 | import yaml 18 | from easydict import EasyDict 19 | 20 | class AttrDict(EasyDict): 21 | IMMUTABLE = '__immutable__' 22 | 23 | def __init__(self, *args): 24 | super(EasyDict, self).__init__(*args) 25 | 26 | def immutable(self, is_immutable): 27 | """Set immutability to is_immutable and recursively apply the setting 28 | to all nested AttrDicts. 
29 | """ 30 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable 31 | # Recursively set immutable state 32 | for v in self.__dict__.values(): 33 | if isinstance(v, AttrDict): 34 | v.immutable(is_immutable) 35 | for v in self.values(): 36 | if isinstance(v, AttrDict): 37 | v.immutable(is_immutable) 38 | 39 | def is_immutable(self): 40 | return self.__dict__[AttrDict.IMMUTABLE] 41 | 42 | __C = AttrDict() 43 | # Consumers can get config by: 44 | # from parse_config import cfg 45 | cfg = __C 46 | 47 | 48 | # Random note: avoid using '.ON' as a config key since yaml converts it to True; 49 | # prefer 'ENABLED' instead 50 | 51 | # ---------------------------------------------------------------------------- # 52 | # Training options 53 | # ---------------------------------------------------------------------------- # 54 | __C.TRAIN = AttrDict() 55 | 56 | # Datasets to train on 57 | __C.TRAIN.ISTRAIN = True 58 | 59 | # Height Pixel 60 | __C.TRAIN.IMG_H = 64 61 | 62 | # Width Pixel 63 | __C.TRAIN.IMG_W = 64 64 | 65 | # keep image aspect while trainning, didn't support True yet 66 | __C.TRAIN.KEEP_ASPECT = False 67 | 68 | # Images *per GPU* in the training minibatch 69 | # Total images per minibatch = TRAIN.IMS_PER_BATCH * NUM_GPUS 70 | __C.TRAIN.IMS_PER_BATCH = 64 71 | 72 | # Snapshot (model checkpoint) period 73 | # Divide by NUM_GPUS to determine actual period (e.g., 20000/8 => 2500 iters) 74 | # to allow for linear training schedule scaling 75 | __C.TRAIN.SNAPSHOT_ITERS = 3000 76 | 77 | __C.TRAIN.SNAPSHOT_BEGIN = 0 78 | 79 | __C.TRAIN.VALIDATE_ITERS = 0 80 | 81 | __C.TRAIN.VALIDATE_BEGIN = 0 82 | 83 | __C.TRAIN.TEST_ITERS = 6000 84 | 85 | __C.TRAIN.DATASET = '' 86 | 87 | # Dropout probability in dense3 88 | __C.TRAIN.DROPOUT_P = 0. 89 | 90 | # Set the random seed 91 | __C.TRAIN.SEED = 1001 92 | 93 | 94 | # ---------------------------------------------------------------------------- # 95 | # Data loader options 96 | # ---------------------------------------------------------------------------- # 97 | __C.DATA_LOADER = AttrDict() 98 | 99 | # Number of Python threads to use for the data loader (warning: using too many 100 | # threads can cause GIL-based interference with Python Ops leading to *slower* 101 | # training; 4 seems to be the sweet spot in our experience) 102 | __C.DATA_LOADER.NUM_THREADS = 8 103 | 104 | __C.DATA_LOADER.CONCAT_GRID = False 105 | 106 | __C.DATA_LOADER.PATH = 'data' 107 | 108 | __C.DATA_LOADER.TYPE = 'ScriptDataset' 109 | 110 | __C.DATA_LOADER.DATASET = 'CHINESE' 111 | 112 | 113 | # ---------------------------------------------------------------------------- # 114 | # Inference ('test') options 115 | # ---------------------------------------------------------------------------- # 116 | __C.TEST = AttrDict() 117 | 118 | # Datasets to test on 119 | # Available dataset list: datasets.dataset_catalog.DATASETS.keys() 120 | # If multiple datasets are listed, testing is performed on each one sequentially 121 | __C.TEST.ISTRAIN = False 122 | 123 | # Scale to use during testing (can NOT list multiple scales) 124 | # The scale is the pixel size of an image's shortest side 125 | __C.TEST.IMG_H = 64 126 | 127 | # Max pixel size of the longest side of a scaled input image 128 | __C.TEST.IMG_W = 64 129 | 130 | __C.TEST.KEEP_ASPECT = False 131 | 132 | __C.TEST.DATASET = '' 133 | 134 | # ---------------------------------------------------------------------------- # 135 | # Model options 136 | # ---------------------------------------------------------------------------- # 137 | __C.MODEL = 
AttrDict() 138 | 139 | # the number of the encoder layers 140 | __C.MODEL.ENCODER_LAYERS = 3 141 | 142 | # the number of layers for fusing writer-wise styles 143 | __C.MODEL.WRI_DEC_LAYERS = 2 144 | 145 | # the number of layers for fusing character-wise styles 146 | __C.MODEL.GLY_DEC_LAYERS = 2 147 | 148 | # the number of layers for each style head 149 | __C.MODEL.NUM_HEAD_LAYERS = 1 150 | 151 | # the number of style references 152 | __C.MODEL.NUM_IMGS = 15 153 | 154 | # ---------------------------------------------------------------------------- # 155 | # Solver options 156 | # ---------------------------------------------------------------------------- # 157 | __C.SOLVER = AttrDict() 158 | 159 | # support 'SGD', 'Adam', 'Adadelta' and 'Rmsprop' 160 | __C.SOLVER.TYPE = 'Adam' 161 | 162 | # Base learning rate for the specified schedule 163 | __C.SOLVER.BASE_LR = 0.001 164 | 165 | # Maximum number of training iterations 166 | __C.SOLVER.MAX_ITER = 7200000 167 | 168 | __C.SOLVER.WARMUP_ITERS = 0 169 | 170 | # Clip gradient L2 norm 171 | __C.SOLVER.GRAD_L2_CLIP = 1.0 172 | 173 | # ---------------------------------------------------------------------------- # 174 | # MISC options 175 | # ---------------------------------------------------------------------------- # 176 | 177 | # Number of GPUs to use (applies to both training and testing) 178 | __C.NUM_GPUS = 1 179 | 180 | # Root directory of project 181 | __C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__))) 182 | 183 | # Output basedir 184 | __C.OUTPUT_DIR = 'Saved' 185 | 186 | def assert_and_infer_cfg(make_immutable=True): 187 | """Call this function in your script after you have finished setting all cfg 188 | values that are necessary (e.g., merging a config from a file, merging 189 | command line config options, etc.). By default, this function will also 190 | mark the global cfg as immutable to prevent changing the global cfg settings 191 | during script execution (which can lead to hard to debug errors or code 192 | that's harder to understand than is necessary). 193 | """ 194 | if version.parse(torch.__version__) < version.parse('0.4.0'): 195 | __C.PYTORCH_VERSION_LESS_THAN_040 = True 196 | # create alias for PyTorch version less than 0.4.0 197 | init.uniform_ = init.uniform 198 | init.normal_ = init.normal 199 | init.constant_ = init.constant 200 | init.kaiming_normal_ = init.kaiming_normal 201 | torch.nn.utils.clip_grad_norm_ = torch.nn.utils.clip_grad_norm 202 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): 203 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) 204 | tensor.requires_grad = requires_grad 205 | tensor._backward_hooks = backward_hooks 206 | return tensor 207 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 208 | if make_immutable: 209 | cfg.immutable(True) 210 | 211 | 212 | def merge_cfg_from_file(cfg_filename): 213 | """Load a yaml config file and merge it into the global config.""" 214 | with open(cfg_filename, 'r') as f: 215 | yaml_cfg = AttrDict(yaml.full_load(f)) 216 | _merge_a_into_b(yaml_cfg, __C) 217 | 218 | cfg_from_file = merge_cfg_from_file 219 | 220 | 221 | def merge_cfg_from_cfg(cfg_other): 222 | """Merge `cfg_other` into the global config.""" 223 | _merge_a_into_b(cfg_other, __C) 224 | 225 | 226 | def merge_cfg_from_list(cfg_list): 227 | """Merge config keys, values in a list (e.g., from command line) into the 228 | global config. For example, `cfg_list = ['TEST.NMS', 0.5]`.
229 | """ 230 | assert len(cfg_list) % 2 == 0 231 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 232 | # if _key_is_deprecated(full_key): 233 | # continue 234 | # if _key_is_renamed(full_key): 235 | # _raise_key_rename_error(full_key) 236 | key_list = full_key.split('.') 237 | d = __C 238 | for subkey in key_list[:-1]: 239 | assert subkey in d, 'Non-existent key: {}'.format(full_key) 240 | d = d[subkey] 241 | subkey = key_list[-1] 242 | assert subkey in d, 'Non-existent key: {}'.format(full_key) 243 | value = _decode_cfg_value(v) 244 | value = _check_and_coerce_cfg_value_type( 245 | value, d[subkey], subkey, full_key 246 | ) 247 | d[subkey] = value 248 | 249 | cfg_from_list = merge_cfg_from_list 250 | 251 | 252 | def _merge_a_into_b(a, b, stack=None): 253 | """Merge config dictionary a into config dictionary b, clobbering the 254 | options in b whenever they are also specified in a. 255 | """ 256 | assert isinstance(a, AttrDict), 'Argument `a` must be an AttrDict' 257 | assert isinstance(b, AttrDict), 'Argument `b` must be an AttrDict' 258 | 259 | for k, v_ in a.items(): 260 | full_key = '.'.join(stack) + '.' + k if stack is not None else k 261 | # a must specify keys that are in b 262 | if k not in b: 263 | # if _key_is_deprecated(full_key): 264 | # continue 265 | # elif _key_is_renamed(full_key): 266 | # _raise_key_rename_error(full_key) 267 | # else: 268 | raise KeyError('Non-existent config key: {}'.format(full_key)) 269 | 270 | v = copy.deepcopy(v_) 271 | v = _decode_cfg_value(v) 272 | v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) 273 | 274 | # Recursively merge dicts 275 | if isinstance(v, AttrDict): 276 | try: 277 | stack_push = [k] if stack is None else stack + [k] 278 | _merge_a_into_b(v, b[k], stack=stack_push) 279 | except BaseException: 280 | raise 281 | else: 282 | b[k] = v 283 | 284 | 285 | def _decode_cfg_value(v): 286 | """Decodes a raw config value (e.g., from a yaml config files or command 287 | line argument) into a Python object. 288 | """ 289 | # Configs parsed from raw yaml will contain dictionary keys that need to be 290 | # converted to AttrDict objects 291 | if isinstance(v, dict): 292 | return AttrDict(v) 293 | # All remaining processing is only applied to strings 294 | if not isinstance(v, six.string_types): 295 | return v 296 | # Try to interpret `v` as a: 297 | # string, number, tuple, list, dict, boolean, or None 298 | try: 299 | v = literal_eval(v) 300 | # The following two excepts allow v to pass through when it represents a 301 | # string. 302 | # 303 | # Longer explanation: 304 | # The type of v is always a string (before calling literal_eval), but 305 | # sometimes it *represents* a string and other times a data structure, like 306 | # a list. In the case that v represents a string, what we got back from the 307 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 308 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 309 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 310 | # will raise a SyntaxError. 311 | except ValueError: 312 | pass 313 | except SyntaxError: 314 | pass 315 | return v 316 | 317 | 318 | def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key): 319 | """Checks that `value_a`, which is intended to replace `value_b` is of the 320 | right type. The type is correct if it matches exactly or is one of a few 321 | cases in which the type can be easily coerced. 
322 | """ 323 | # The types must match (with some exceptions) 324 | type_b = type(value_b) 325 | type_a = type(value_a) 326 | if type_a is type_b: 327 | return value_a 328 | 329 | # Exceptions: numpy arrays, strings, tuple<->list 330 | if isinstance(value_b, np.ndarray): 331 | value_a = np.array(value_a, dtype=value_b.dtype) 332 | elif isinstance(value_b, six.string_types): 333 | value_a = str(value_a) 334 | elif isinstance(value_a, tuple) and isinstance(value_b, list): 335 | value_a = list(value_a) 336 | elif isinstance(value_a, list) and isinstance(value_b, tuple): 337 | value_a = tuple(value_a) 338 | else: 339 | raise ValueError( 340 | 'Type mismatch ({} vs. {}) with values ({} vs. {}) for config ' 341 | 'key: {}'.format(type_b, type_a, value_b, value_a, full_key) 342 | ) 343 | return value_a -------------------------------------------------------------------------------- /static/Poster_SDT.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/Poster_SDT.pdf -------------------------------------------------------------------------------- /static/duo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/duo.gif -------------------------------------------------------------------------------- /static/duo_loop.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/duo_loop.gif -------------------------------------------------------------------------------- /static/mo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/mo.gif -------------------------------------------------------------------------------- /static/mo_loop.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/mo_loop.gif -------------------------------------------------------------------------------- /static/offline_Chinese.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/offline_Chinese.jpg -------------------------------------------------------------------------------- /static/online_Chinese.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/online_Chinese.jpg -------------------------------------------------------------------------------- /static/overview_sdt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/overview_sdt.jpg -------------------------------------------------------------------------------- /static/print.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/print.png -------------------------------------------------------------------------------- /static/software.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/software.png -------------------------------------------------------------------------------- /static/svg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/svg.png -------------------------------------------------------------------------------- /static/tai.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/tai.gif -------------------------------------------------------------------------------- /static/tai_loop.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/tai_loop.gif -------------------------------------------------------------------------------- /static/various_scripts.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dailenson/SDT/339aa2155d93263bdb7d533830559d011d41995c/static/various_scripts.jpg -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from parse_config import cfg, cfg_from_file, assert_and_infer_cfg 4 | import torch 5 | from data_loader.loader import ScriptDataset 6 | import pickle 7 | from models.model import SDT_Generator 8 | import tqdm 9 | from utils.util import writeCache, dxdynp_to_list, coords_render 10 | import lmdb 11 | 12 | def main(opt): 13 | """ load config file into cfg""" 14 | cfg_from_file(opt.cfg_file) 15 | assert_and_infer_cfg() 16 | 17 | """setup data_loader instances""" 18 | test_dataset = ScriptDataset( 19 | cfg.DATA_LOADER.PATH, cfg.DATA_LOADER.DATASET, cfg.TEST.ISTRAIN, cfg.MODEL.NUM_IMGS) 20 | test_loader = torch.utils.data.DataLoader(test_dataset, 21 | batch_size=cfg.TRAIN.IMS_PER_BATCH, 22 | shuffle=True, 23 | sampler=None, 24 | drop_last=False, 25 | collate_fn=test_dataset.collate_fn_, 26 | num_workers=cfg.DATA_LOADER.NUM_THREADS) 27 | char_dict = test_dataset.char_dict 28 | writer_dict = test_dataset.writer_dict 29 | 30 | os.makedirs(os.path.join(opt.save_dir, 'test'), exist_ok=True) 31 | test_env = lmdb.open(os.path.join(opt.save_dir, 'test'), map_size=1099511627776) 32 | pickle.dump(writer_dict, open(os.path.join(opt.save_dir, 'writer_dict.pkl'), 'wb')) 33 | pickle.dump(char_dict, open(os.path.join(opt.save_dir, 'character_dict.pkl'), 'wb')) 34 | 35 | """build model architecture""" 36 | model = SDT_Generator(num_encoder_layers=cfg.MODEL.ENCODER_LAYERS, 37 | num_head_layers= cfg.MODEL.NUM_HEAD_LAYERS, 38 | wri_dec_layers=cfg.MODEL.WRI_DEC_LAYERS, 39 | gly_dec_layers= cfg.MODEL.GLY_DEC_LAYERS).to('cuda') 40 | if len(opt.pretrained_model) > 0: 41 | model_weight = torch.load(opt.pretrained_model) 42 | model.load_state_dict(model_weight) 43 | print('load pretrained model from {}'.format(opt.pretrained_model)) 44 | else: 45 | raise IOError('input the correct checkpoint path') 46 | model.eval() 47 | 48 | """calculate the total batches of generated samples""" 49 | if opt.sample_size == 'all': 50 | batch_samples = len(test_loader) 51 | else: 52 | batch_samples = 
int(opt.sample_size)*len(writer_dict)//cfg.TRAIN.IMS_PER_BATCH 53 | 54 | batch_num, num_count= 0, 0 55 | data_iter = iter(test_loader) 56 | with torch.no_grad(): 57 | for _ in tqdm.tqdm(range(batch_samples)): 58 | batch_num += 1 59 | if batch_num > batch_samples: 60 | break 61 | else: 62 | data = next(data_iter) 63 | # prepare input 64 | coords, coords_len, character_id, writer_id, img_list, char_img = data['coords'].cuda(), \ 65 | data['coords_len'].cuda(), \ 66 | data['character_id'].long().cuda(), \ 67 | data['writer_id'].long().cuda(), \ 68 | data['img_list'].cuda(), \ 69 | data['char_img'].cuda() 70 | preds = model.inference(img_list, char_img, 120) 71 | bs = character_id.shape[0] 72 | SOS = torch.tensor(bs * [[0, 0, 1, 0, 0]]).unsqueeze(1).to(preds) 73 | preds = torch.cat((SOS, preds), 1) # add the SOS token like GT 74 | preds = preds.detach().cpu().numpy() 75 | 76 | test_cache = {} 77 | coords = coords.detach().cpu().numpy() 78 | if opt.store_type == 'online': 79 | for i, pred in enumerate(preds): 80 | pred, _ = dxdynp_to_list(preds[i]) 81 | coord, _ = dxdynp_to_list(coords[i]) 82 | data = {'coordinates': pred, 'writer_id': writer_id[i].item(), 83 | 'character_id': character_id[i].item(), 'coords_gt':coord} 84 | data_byte = pickle.dumps(data) 85 | data_id = str(num_count).encode('utf-8') 86 | test_cache[data_id] = data_byte 87 | num_count += 1 88 | test_cache['num_sample'.encode('utf-8')] = str(num_count).encode() 89 | writeCache(test_env, test_cache) 90 | elif opt.store_type == 'img': 91 | for i, pred in enumerate(preds): 92 | """intends to blur the boundaries of each sample to fit the actual using situations, 93 | as suggested in 'Deep imitator: Handwriting calligraphy imitation via deep attention networks'""" 94 | sk_pil = coords_render(preds[i], split=True, width=256, height=256, thickness=8, board=0) 95 | character = char_dict[character_id[i].item()] 96 | save_path = os.path.join(opt.save_dir, 'test', 97 | str(writer_id[i].item()) + '_' + character+'.png') 98 | try: 99 | sk_pil.save(save_path) 100 | except: 101 | print('error. 
%s, %s, %s' % (save_path, str(writer_id[i].item()), character)) 102 | else: 103 | raise NotImplementedError('only support online or img format') 104 | 105 | if __name__ == '__main__': 106 | """Parse input arguments""" 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument('--cfg', dest='cfg_file', default='configs/CHINESE_CASIA.yml', 109 | help='Config file for training (and optionally testing)') 110 | parser.add_argument('--dir', dest='save_dir', default='Generated/Chinese', help='target dir for storing the generated characters') 111 | parser.add_argument('--pretrained_model', dest='pretrained_model', default='', required=True, help='continue train model') 112 | parser.add_argument('--store_type', dest='store_type', required=True, default='online', help='online or img') 113 | parser.add_argument('--sample_size', dest='sample_size', default='500', required=True, help='randomly generate a certain number of characters for each writer') 114 | opt = parser.parse_args() 115 | main(opt) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from parse_config import cfg, cfg_from_file, assert_and_infer_cfg 3 | from utils.util import fix_seed, load_specific_dict 4 | from models.loss import SupConLoss, get_pen_loss 5 | from models.model import SDT_Generator 6 | from utils.logger import set_log 7 | from data_loader.loader import ScriptDataset 8 | import torch 9 | from trainer.trainer import Trainer 10 | 11 | def main(opt): 12 | """ load config file into cfg""" 13 | cfg_from_file(opt.cfg_file) 14 | assert_and_infer_cfg() 15 | """fix the random seed""" 16 | fix_seed(cfg.TRAIN.SEED) 17 | """ prepare log file """ 18 | logs = set_log(cfg.OUTPUT_DIR, opt.cfg_file, opt.log_name) 19 | """ set dataset""" 20 | train_dataset = ScriptDataset( 21 | cfg.DATA_LOADER.PATH, cfg.DATA_LOADER.DATASET, cfg.TRAIN.ISTRAIN, cfg.MODEL.NUM_IMGS) 22 | print('number of training images: ', len(train_dataset)) 23 | train_loader = torch.utils.data.DataLoader(train_dataset, 24 | batch_size=cfg.TRAIN.IMS_PER_BATCH, 25 | shuffle=True, 26 | drop_last=False, 27 | collate_fn=train_dataset.collate_fn_, 28 | num_workers=cfg.DATA_LOADER.NUM_THREADS) 29 | test_dataset = ScriptDataset( 30 | cfg.DATA_LOADER.PATH, cfg.DATA_LOADER.DATASET, cfg.TEST.ISTRAIN, cfg.MODEL.NUM_IMGS) 31 | test_loader = torch.utils.data.DataLoader(test_dataset, 32 | batch_size=cfg.TRAIN.IMS_PER_BATCH, 33 | shuffle=True, 34 | sampler=None, 35 | drop_last=False, 36 | collate_fn=test_dataset.collate_fn_, 37 | num_workers=cfg.DATA_LOADER.NUM_THREADS) 38 | char_dict = test_dataset.char_dict 39 | """ build model, criterion and optimizer""" 40 | model = SDT_Generator(num_encoder_layers=cfg.MODEL.ENCODER_LAYERS, 41 | num_head_layers= cfg.MODEL.NUM_HEAD_LAYERS, 42 | wri_dec_layers=cfg.MODEL.WRI_DEC_LAYERS, 43 | gly_dec_layers= cfg.MODEL.GLY_DEC_LAYERS).to('cuda') 44 | ### load checkpoint 45 | if len(opt.pretrained_model) > 0: 46 | model.load_state_dict(torch.load(opt.pretrained_model)) 47 | print('load pretrained model from {}'.format(opt.pretrained_model)) 48 | elif len(opt.content_pretrained) > 0: 49 | model_dict = load_specific_dict(model.content_encoder, opt.content_pretrained, "feature_ext") 50 | model.content_encoder.load_state_dict(model_dict) 51 | print('load content pretrained model from {}'.format(opt.content_pretrained)) 52 | else: 53 | pass 54 | criterion = dict(NCE=SupConLoss(contrast_mode='all'), 
PEN=get_pen_loss) 55 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.SOLVER.BASE_LR) 56 | """start training iterations""" 57 | trainer = Trainer(model, criterion, optimizer, train_loader, logs, char_dict, test_loader) 58 | trainer.train() 59 | 60 | if __name__ == '__main__': 61 | """Parse input arguments""" 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--pretrained_model', default='', 64 | dest='pretrained_model', required=False, help='continue to train model') 65 | parser.add_argument('--content_pretrained', default='model_zoo/position_layer2_dim512_iter138k_test_acc0.9443.pth', 66 | dest='content_pretrained', required=False, help='continue to train content encoder') 67 | parser.add_argument('--cfg', dest='cfg_file', default='configs/CHINESE_CASIA.yml', 68 | help='Config file for training (and optionally testing)') 69 | parser.add_argument('--log', default='debug', 70 | dest='log_name', required=False, help='the filename of log') 71 | opt = parser.parse_args() 72 | main(opt) -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tensorboardX import SummaryWriter 3 | import time 4 | from parse_config import cfg 5 | from models.gmm import get_mixture_coef 6 | import os 7 | import datetime 8 | import sys 9 | from utils.util import coords_render 10 | from PIL import Image 11 | 12 | class Trainer: 13 | def __init__(self, model, criterion, optimizer, data_loader, 14 | logs, char_dict, valid_data_loader=None): 15 | self.model = model 16 | self.criterion = criterion 17 | self.optimizer = optimizer 18 | self.data_loader = data_loader 19 | self.char_dict = char_dict 20 | self.valid_data_loader = valid_data_loader 21 | self.nce_criterion = criterion['NCE'] 22 | self.pen_criterion = criterion['PEN'] 23 | self.tb_summary = SummaryWriter(logs['tboard']) 24 | self.save_model_dir = logs['model'] 25 | self.save_sample_dir = logs['sample'] 26 | 27 | def _train_iter(self, data, step): 28 | self.model.train() 29 | prev_time = time.time() 30 | # prepare input 31 | coords, coords_len, character_id, writer_id, img_list, char_img = data['coords'].cuda(), \ 32 | data['coords_len'].cuda(), \ 33 | data['character_id'].long().cuda(), \ 34 | data['writer_id'].long().cuda(), \ 35 | data['img_list'].cuda(), \ 36 | data['char_img'].cuda() 37 | 38 | # forward 39 | input_seq = coords[:, 1:-1] 40 | preds, nce_emb, nce_emb_patch= self.model(img_list, input_seq, char_img) 41 | 42 | # calculate loss 43 | gt_coords = coords[:, 1:, :] 44 | nce_loss_writer = self.nce_criterion(nce_emb, labels=writer_id) 45 | nce_loss_glyph = self.nce_criterion(nce_emb_patch) 46 | preds = preds.view(-1, 123) 47 | gt_coords = gt_coords.reshape(-1, 5) 48 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen_logits] = get_mixture_coef(preds) 49 | moving_loss_all, state_loss = self.pen_criterion(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, \ 50 | o_corr, o_pen_logits, gt_coords[:,0].unsqueeze(-1), gt_coords[:,1].unsqueeze(-1), gt_coords[:,2:]) 51 | moving_loss = torch.sum(moving_loss_all) / torch.sum(coords_len) 52 | pen_loss = moving_loss + 2*state_loss 53 | loss = pen_loss + nce_loss_writer + nce_loss_glyph 54 | 55 | # backward and update trainable parameters 56 | self.model.zero_grad() 57 | loss.backward() 58 | if cfg.SOLVER.GRAD_L2_CLIP > 0: 59 | torch.nn.utils.clip_grad_norm(self.model.parameters(), cfg.SOLVER.GRAD_L2_CLIP) 60 | self.optimizer.step() 61 | 62 | # log 
file 63 | loss_dict = {"pen_loss": pen_loss.item(), "moving_loss": moving_loss.item(), 64 | "state_loss": state_loss.item(), "nce_loss_writer":nce_loss_writer.item(), 65 | "nce_loss_glyph":nce_loss_glyph.item()} 66 | self.tb_summary.add_scalars("loss", loss_dict, step) 67 | iter_left = cfg.SOLVER.MAX_ITER - step 68 | time_left = datetime.timedelta( 69 | seconds=iter_left * (time.time() - prev_time)) 70 | self._progress(step, loss.item(), time_left) 71 | 72 | del data, preds, loss 73 | torch.cuda.empty_cache() 74 | 75 | 76 | def _valid_iter(self, step): 77 | self.model.eval() 78 | print('loading test dataset, the number is', len(self.valid_data_loader)) 79 | try: 80 | test_loader_iter = iter(self.valid_data_loader) 81 | test_data = next(test_loader_iter) 82 | except StopIteration: 83 | test_loader_iter = iter(self.valid_data_loader) 84 | test_data = next(test_loader_iter) 85 | # prepare input 86 | coords, coords_len, character_id, writer_id, img_list, char_img = test_data['coords'].cuda(), \ 87 | test_data['coords_len'].cuda(), \ 88 | test_data['character_id'].long().cuda(), \ 89 | test_data['writer_id'].long().cuda(), \ 90 | test_data['img_list'].cuda(), \ 91 | test_data['char_img'].cuda() 92 | # forward 93 | with torch.no_grad(): 94 | preds = self.model.inference(img_list, char_img, 120) 95 | bs = character_id.shape[0] 96 | SOS = torch.tensor(bs * [[0, 0, 1, 0, 0]]).unsqueeze(1).to(preds) 97 | preds = torch.cat((SOS, preds), 1) # add the first token 98 | preds = preds.cpu().numpy() 99 | gt_coords = coords.cpu().numpy() # [N, T, C] 100 | self._vis_genarate_samples(gt_coords, preds, character_id, step) 101 | 102 | def train(self): 103 | """start training iterations""" 104 | train_loader_iter = iter(self.data_loader) 105 | for step in range(cfg.SOLVER.MAX_ITER): 106 | try: 107 | data = next(train_loader_iter) 108 | except StopIteration: 109 | train_loader_iter = iter(self.data_loader) 110 | data = next(train_loader_iter) 111 | self._train_iter(data, step) 112 | 113 | if (step+1) > cfg.TRAIN.SNAPSHOT_BEGIN and (step+1) % cfg.TRAIN.SNAPSHOT_ITERS == 0: 114 | self._save_checkpoint(step) 115 | else: 116 | pass 117 | if self.valid_data_loader is not None: 118 | if (step+1) > cfg.TRAIN.VALIDATE_BEGIN and (step+1) % cfg.TRAIN.VALIDATE_ITERS == 0: 119 | self._valid_iter(step) 120 | else: 121 | pass 122 | 123 | 124 | def _progress(self, step, loss, time_left): 125 | terminal_log = 'iter:%d ' % step 126 | terminal_log += '%s:%.3f ' % ('loss', loss) 127 | terminal_log += 'ETA:%s\r\n' % str(time_left) 128 | sys.stdout.write(terminal_log) 129 | 130 | def _save_checkpoint(self, step): 131 | model_path = '{}/checkpoint-iter{}.pth'.format(self.save_model_dir, step) 132 | torch.save(self.model.state_dict(), model_path) 133 | print('save model to {}'.format(model_path)) 134 | 135 | def _vis_genarate_samples(self, gt_coords, preds, character_id, step): 136 | for i, _ in enumerate(gt_coords): 137 | gt_img = coords_render(gt_coords[i], split=True, width=64, height=64, thickness=1) 138 | pred_img = coords_render(preds[i], split=True, width=64, height=64, thickness=1) 139 | example_img = Image.new("RGB", (cfg.TEST.IMG_W * 2, cfg.TEST.IMG_H), 140 | (255, 255, 255)) 141 | example_img.paste(pred_img, (0, 0)) # gererated character 142 | example_img.paste(gt_img, (cfg.TEST.IMG_W, 0)) # gt character 143 | character = self.char_dict[character_id[i].item()] 144 | save_path = os.path.join(self.save_sample_dir, 'ite.' 
+ str(step//100000) 145 | + '-'+ str(step//100000 + 100000), character + '_' + str(step) + '_.jpg') 146 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 147 | try: 148 | example_img.save(save_path) 149 | except: 150 | print('error. %s, %s' % (save_path, character)) -------------------------------------------------------------------------------- /user_generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from parse_config import cfg, cfg_from_file, assert_and_infer_cfg 4 | import torch 5 | from data_loader.loader import UserDataset 6 | import pickle 7 | from models.model import SDT_Generator 8 | import tqdm 9 | from utils.util import writeCache, dxdynp_to_list, coords_render 10 | import lmdb 11 | 12 | def main(opt): 13 | """ load config file into cfg""" 14 | cfg_from_file(opt.cfg_file) 15 | assert_and_infer_cfg() 16 | 17 | """setup data_loader instances""" 18 | test_dataset = UserDataset( 19 | cfg.DATA_LOADER.PATH, cfg.DATA_LOADER.DATASET, opt.style_path) 20 | test_loader = torch.utils.data.DataLoader(test_dataset, 21 | batch_size=cfg.TRAIN.IMS_PER_BATCH, 22 | shuffle=True, 23 | sampler=None, 24 | drop_last=False, 25 | num_workers=cfg.DATA_LOADER.NUM_THREADS) 26 | 27 | os.makedirs(os.path.join(opt.save_dir), exist_ok=True) 28 | 29 | """build model architecture""" 30 | model = SDT_Generator(num_encoder_layers=cfg.MODEL.ENCODER_LAYERS, 31 | num_head_layers= cfg.MODEL.NUM_HEAD_LAYERS, 32 | wri_dec_layers=cfg.MODEL.WRI_DEC_LAYERS, 33 | gly_dec_layers= cfg.MODEL.GLY_DEC_LAYERS).to('cuda') 34 | if len(opt.pretrained_model) > 0: 35 | model_weight = torch.load(opt.pretrained_model) 36 | model.load_state_dict(model_weight) 37 | print('load pretrained model from {}'.format(opt.pretrained_model)) 38 | else: 39 | raise IOError('input the correct checkpoint path') 40 | model.eval() 41 | 42 | """setup the dataloader""" 43 | batch_samples = len(test_loader) 44 | data_iter = iter(test_loader) 45 | with torch.no_grad(): 46 | for _ in tqdm.tqdm(range(batch_samples)): 47 | 48 | data = next(data_iter) 49 | # prepare input 50 | img_list, char_img, char = data['img_list'].cuda(), \ 51 | data['char_img'].cuda(), data['char'] 52 | preds = model.inference(img_list, char_img, 120) 53 | bs = char_img.shape[0] 54 | SOS = torch.tensor(bs * [[0, 0, 1, 0, 0]]).unsqueeze(1).to(preds) 55 | preds = torch.cat((SOS, preds), 1) # add the SOS token like GT 56 | preds = preds.detach().cpu().numpy() 57 | 58 | for i, pred in enumerate(preds): 59 | """Render the character images by connecting the coordinates""" 60 | sk_pil = coords_render(preds[i], split=True, width=256, height=256, thickness=8, board=1) 61 | 62 | save_path = os.path.join(opt.save_dir, char[i] +'.png') 63 | try: 64 | sk_pil.save(save_path) 65 | except: 66 | print('error. 
%s, %s' % (save_path, char[i])) 67 | 68 | 69 | if __name__ == '__main__': 70 | """Parse input arguments""" 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('--cfg', dest='cfg_file', default='configs/CHINESE_USER.yml', 73 | help='Config file for training (and optionally testing)') 74 | parser.add_argument('--dir', dest='save_dir', default='Generated/Chinese_User', help='target dir for storing the generated characters') 75 | parser.add_argument('--pretrained_model', dest='pretrained_model', default='', required=True, help='continue train model') 76 | parser.add_argument('--style_path', dest='style_path', default='style_samples', help='dir of style samples') 77 | opt = parser.parse_args() 78 | main(opt) -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import os 4 | 5 | """ prepare logdir for tensorboard and logging output""" 6 | def set_log(output_dir, cfg_file, log_name): 7 | t = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 8 | base_name = os.path.basename(cfg_file).split('.')[0] 9 | log_dir = os.path.join(output_dir, base_name, log_name + "-" + t) 10 | logs = {} 11 | for temp in ['tboard', 'model', 'sample']: 12 | temp_dir = os.path.join(log_dir, temp) 13 | os.makedirs(temp_dir, exist_ok=True) 14 | logs[temp] = temp_dir 15 | return logs -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from data_loader.loader import Online_Dataset 3 | import torch 4 | import numpy as np 5 | import tqdm 6 | from fastdtw import fastdtw 7 | from models.eval_model import * 8 | 9 | def fast_norm_len_dtw(test_loader): 10 | """start test iterations""" 11 | euclidean = lambda x, y: np.sqrt(sum((x - y) ** 2)) 12 | fast_norm_dtw_len, total_num = 0, 0 13 | 14 | for data in tqdm.tqdm(test_loader): 15 | preds, preds_len, character_id, writer_id, coords_gts, len_gts = data['coords'], \ 16 | data['coords_len'].long(), \ 17 | data['character_id'].long(), \ 18 | data['writer_id'].long(), \ 19 | data['coords_gt'], \ 20 | data['len_gt'].long() 21 | for i, pred in enumerate(preds): 22 | pred_len, gt_len= preds_len[i], len_gts[i] 23 | pred_valid, gt_valid = pred[:pred_len], coords_gts[i][:gt_len] 24 | 25 | # Convert relative coordinates into absolute coordinates 26 | seq_1 = torch.cumsum(gt_valid[:, :2], dim=0) 27 | seq_2 = torch.cumsum(pred_valid[:, :2], dim=0) 28 | 29 | # DTW between paired real and fake online characters 30 | fast_d, _ = fastdtw(seq_1, seq_2, dist= euclidean) 31 | fast_norm_dtw_len += (fast_d/gt_len) 32 | total_num += len(preds) 33 | avg_fast_norm_dtw_len = fast_norm_dtw_len/total_num 34 | return avg_fast_norm_dtw_len 35 | 36 | def get_style_score(test_loader,pretrained_model): 37 | correct = torch.zeros(1).squeeze().cuda() 38 | total = torch.zeros(1).squeeze().cuda() 39 | print('calculate the acc for the testset') 40 | print('loading testset...') 41 | 42 | model = offline_style(num_class=test_loader.dataset.num_class).cuda() 43 | 44 | if len(pretrained_model) > 0: 45 | model.load_state_dict(torch.load(pretrained_model)) 46 | print('load pretrained model from {}'.format(pretrained_model)) 47 | 48 | model.eval() 49 | with torch.no_grad(): 50 | for data, labels in tqdm.tqdm(test_loader): 51 | data, labels = data.cuda(), labels.cuda() 52 | test_preds = model(data) 53 | prediction = 
torch.argmax(test_preds, 1) 54 | correct += (prediction == labels).sum().float() 55 | total += len(labels) 56 | acc_str = (correct/total).cpu().numpy() 57 | return acc_str 58 | 59 | def get_content_score(test_loader,pretrained_model): 60 | """ set model, criterion and optimizer""" 61 | Net = Character_Net(nclass=len(test_loader.dataset.char_dict)).cuda().eval() 62 | if len(pretrained_model) > 0: 63 | Net.load_state_dict(torch.load(pretrained_model)) 64 | print('load pretrained model from {}'.format(pretrained_model)) 65 | 66 | """start test iterations""" 67 | 68 | Net.eval() 69 | correct = torch.zeros(1).squeeze().cuda() 70 | total = torch.zeros(1).squeeze().cuda() 71 | 72 | for data in tqdm.tqdm(test_loader): 73 | coords, coords_len, character_id, writer_id = data['coords'].cuda(), \ 74 | data['coords_len'].cuda(), \ 75 | data['character_id'].long().cuda(), \ 76 | data['writer_id'].long().cuda() 77 | 78 | with torch.no_grad(): 79 | coords = torch.transpose(coords, 1, 2) 80 | logits = Net(coords, coords_len) 81 | prediction = torch.argmax(logits, 1) 82 | correct += (prediction == character_id.long()).sum().float() 83 | total += len(coords) 84 | acc = (correct/total).cpu().numpy() 85 | return acc -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | from PIL import ImageDraw, Image 5 | 6 | 7 | ''' 8 | description: Normalize the xy-coordinates into a standard interval. 9 | Refer to "Drawing and Recognizing Chinese Characters with Recurrent Neural Network". 10 | ''' 11 | def normalize_xys(xys): 12 | stroken_state = np.cumsum(np.concatenate((np.array([0]), xys[:, -2]))[:-1]) 13 | px_sum = py_sum = len_sum = 0 14 | for ptr_idx in range(0, xys.shape[0] - 2): 15 | if stroken_state[ptr_idx] == stroken_state[ptr_idx + 1]: 16 | xy_1, xy = xys[ptr_idx][:2], xys[ptr_idx + 1][:2] 17 | temp_len = np.sqrt(np.sum(np.power(xy - xy_1, 2))) 18 | temp_px, temp_py = temp_len * (xy_1 + xy) / 2 19 | px_sum += temp_px 20 | py_sum += temp_py 21 | len_sum += temp_len 22 | if len_sum==0: 23 | raise Exception("Broken online characters") 24 | else: 25 | pass 26 | 27 | mux, muy = px_sum / len_sum, py_sum / len_sum 28 | dx_sum, dy_sum = 0, 0 29 | for ptr_idx in range(0, xys.shape[0] - 2): 30 | if stroken_state[ptr_idx] == stroken_state[ptr_idx + 1]: 31 | xy_1, xy = xys[ptr_idx][:2], xys[ptr_idx + 1][:2] 32 | temp_len = np.sqrt(np.sum(np.power(xy - xy_1, 2))) 33 | temp_dx = temp_len * ( 34 | np.power(xy_1[0] - mux, 2) + np.power(xy[0] - mux, 2) + (xy_1[0] - mux) * (xy[0] - mux)) / 3 35 | temp_dy = temp_len * ( 36 | np.power(xy_1[1] - muy, 2) + np.power(xy[1] - muy, 2) + (xy_1[1] - muy) * (xy[1] - muy)) / 3 37 | dx_sum += temp_dx 38 | dy_sum += temp_dy 39 | sigma = np.sqrt(dx_sum / len_sum) 40 | if sigma == 0: 41 | sigma = np.sqrt(dy_sum / len_sum) 42 | xys[:, 0], xys[:, 1] = (xys[:, 0] - mux) / sigma, (xys[:, 1] - muy) / sigma 43 | return xys 44 | 45 | ''' 46 | description: Rendering offline character images by connecting coordinate points 47 | ''' 48 | def coords_render(coordinates, split, width, height, thickness, board=5): 49 | canvas_w = width 50 | canvas_h = height 51 | board_w = board 52 | board_h = board 53 | # preprocess canvas size 54 | p_canvas_w = canvas_w - 2*board_w 55 | p_canvas_h = canvas_h - 2*board_h 56 | 57 | # find original character size to fit with canvas 58 | min_x = 635535 59 | min_y = 635535 60 | max_x = -1 61 | 
max_y = -1 62 | 63 | coordinates[:, 0] = np.cumsum(coordinates[:, 0]) 64 | coordinates[:, 1] = np.cumsum(coordinates[:, 1]) 65 | if split: 66 | ids = np.where(coordinates[:, -1] == 1)[0] 67 | if len(ids) < 1: ### if not exist [0, 0, 1] 68 | ids = np.where(coordinates[:, 3] == 1)[0] + 1 69 | if len(ids) < 1: ### if not exist [0, 1, 0] 70 | ids = np.array([len(coordinates)]) 71 | xys_split = np.split(coordinates, ids, axis=0)[:-1] # remove the blank list 72 | else: 73 | xys_split = np.split(coordinates, ids, axis=0) 74 | else: ### if exist [0, 0, 1] 75 | remove_end = np.split(coordinates, ids, axis=0)[0] 76 | ids = np.where(remove_end[:, 3] == 1)[0] + 1 ### break in [0, 1, 0] 77 | xys_split = np.split(remove_end, ids, axis=0) 78 | else: 79 | pass 80 | for stroke in xys_split: 81 | for (x, y) in stroke[:, :2].reshape((-1, 2)): 82 | min_x = min(x, min_x) 83 | max_x = max(x, max_x) 84 | min_y = min(y, min_y) 85 | max_y = max(y, max_y) 86 | original_size = max(max_x-min_x, max_y-min_y) 87 | canvas = Image.new(mode='L', size=(canvas_w, canvas_h), color=255) 88 | draw = ImageDraw.Draw(canvas) 89 | 90 | for stroke in xys_split: 91 | xs, ys = stroke[:, 0], stroke[:, 1] 92 | xys = np.stack([xs, ys], axis=-1).reshape(-1) 93 | xys[::2] = (xys[::2]-min_x) / original_size * p_canvas_w + board_w 94 | xys[1::2] = (xys[1::2] - min_y) / original_size * p_canvas_h + board_h 95 | xys = np.round(xys) 96 | draw.line(xys.tolist(), fill=0, width=thickness) 97 | return canvas 98 | 99 | # fix random seeds for reproducibility 100 | def fix_seed(random_seed): 101 | random.seed(random_seed) 102 | np.random.seed(random_seed) 103 | torch.backends.cudnn.deterministic = True 104 | torch.backends.cudnn.benchmark = False 105 | if torch.cuda.device_count() > 0 and torch.cuda.is_available(): 106 | torch.cuda.manual_seed_all(random_seed) 107 | else: 108 | torch.manual_seed(random_seed) 109 | 110 | ### model loads specific parameters (i.e., par) from pretrained_model 111 | def load_specific_dict(model, pretrained_model, par): 112 | model_dict = model.state_dict() 113 | pretrained_dict = torch.load(pretrained_model) 114 | if par in list(pretrained_dict.keys())[0]: 115 | count = len(par) + 1 116 | pretrained_dict = {k[count:]: v for k, v in pretrained_dict.items() if k[count:] in model_dict} 117 | else: 118 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 119 | if len(pretrained_dict) > 0: 120 | model_dict.update(pretrained_dict) 121 | else: 122 | return ValueError 123 | return model_dict 124 | 125 | 126 | def writeCache(env, cache): 127 | with env.begin(write=True) as txn: 128 | for k, v in cache.items(): 129 | txn.put(k, v) 130 | 131 | 132 | ''' 133 | description: convert the np version of coordinates to the list counterpart 134 | ''' 135 | def dxdynp_to_list(coordinates): 136 | ids = np.where(coordinates[:, -1] == 1)[0] 137 | length = coordinates[:, 2:4].sum() 138 | if len(ids) < 1: ### if not exist [0, 0, 1] 139 | ids = np.where(coordinates[:, 3] == 1)[0] + 1 140 | if len(ids) < 1: ### if not exist [0, 1, 0] 141 | ids = np.array([len(coordinates)]) 142 | xys_split = np.split(coordinates, ids, axis=0)[:-1] # remove the blank list 143 | else: 144 | xys_split = np.split(coordinates, ids, axis=0) 145 | else: ### if exist [0, 0, 1] 146 | remove_end = np.split(coordinates, ids, axis=0)[0] 147 | ids = np.where(remove_end[:, 3] == 1)[0] + 1 ### break in [0, 1, 0] 148 | xys_split = np.split(remove_end, ids, axis=0)[:-1] # split from the remove_end 149 | 150 | coord_list = [] 151 | for stroke in 
xys_split: 152 | xs, ys = stroke[:, 0], stroke[:, 1] 153 | if len(xs) > 0: 154 | xys = np.stack([xs, ys], axis=-1).reshape(-1) 155 | coord_list.append(xys) 156 | else: 157 | pass 158 | return coord_list, length 159 | 160 | ''' 161 | description: 162 | [x, y] --> [x, y, p1, p2, p3] 163 | see 'A NEURAL REPRESENTATION OF SKETCH DRAWINGS' for more details 164 | ''' 165 | def corrds2xys(coordinates): 166 | new_strokes = [] 167 | for stroke in coordinates: 168 | for (x, y) in np.array(stroke).reshape((-1, 2)): 169 | p = np.array([x, y, 1, 0, 0], np.float32) 170 | new_strokes.append(p) 171 | try: 172 | new_strokes[-1][2:] = [0, 1, 0] # set the end of a stroke 173 | except IndexError: 174 | print(stroke) 175 | return None 176 | new_strokes = np.stack(new_strokes, axis=0) 177 | return new_strokes --------------------------------------------------------------------------------
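The following is a minimal usage sketch, not a file from the repository. It illustrates the five-element point format [Δx, Δy, p1, p2, p3] that model.py, test.py and utils/util.py all rely on (p1 marks a pen-down point, p2 the last point of a stroke, p3 the end of a character) and renders a tiny hand-made sequence with `coords_render` from utils/util.py. The coordinate values are invented purely for illustration, and the snippet assumes the repository root is on PYTHONPATH.

```python
# Minimal sketch (not part of the repo): build a tiny [dx, dy, p1, p2, p3]
# sequence by hand and render it with coords_render from utils/util.py.
# The offsets below are made-up values, used only to show the format.
import numpy as np
from utils.util import coords_render

seq = np.array([
    [0.0,  0.0, 1, 0, 0],   # pen-down start point (same layout as the SOS token [0, 0, 1, 0, 0])
    [1.0,  0.0, 1, 0, 0],   # draw to the right, pen stays down
    [0.0,  1.0, 0, 1, 0],   # draw downwards, then lift the pen (end of stroke 1)
    [0.5, -1.0, 1, 0, 0],   # pen down again at a new position (start of stroke 2)
    [0.5,  1.0, 1, 0, 0],   # keep drawing
    [0.0,  0.0, 0, 0, 1],   # end-of-character token
], dtype=np.float32)

# coords_render accumulates the relative offsets, splits strokes where the
# pen-up flag (column 3) is set, stops at the end-of-character token (last
# column), and draws the resulting polylines on a blank canvas.
canvas = coords_render(seq, split=True, width=256, height=256, thickness=8, board=5)
canvas.save('demo_character.png')
```

This is the same layout that model.inference emits and that test.py and user_generate.py prepend the SOS token to before rendering or storing the generated characters.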