├── .gitignore ├── Groundtruth ├── gan.iam.test.gt.filter27 └── gan.iam.tr_va.gt.filter27 ├── README.md ├── corpora_english ├── brown-azAZ.tr ├── in_vocab.subset.tro.37 ├── oov.common_words └── oov_words.txt ├── create_style_sample.py ├── data ├── create_data.py ├── dataset.py ├── iam_test.py └── show_dataset.py ├── files └── english_words.txt ├── generate.py ├── generate ├── __init__.py ├── authors.py ├── fid.py ├── ocr.py ├── page.py ├── text.py ├── util.py └── writer.py ├── models ├── BigGAN_layers.py ├── BigGAN_networks.py ├── OCR_network.py ├── __init__.py ├── blocks.py ├── config.py ├── inception.py ├── model.py ├── networks.py ├── positional_encodings.py ├── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── batchnorm_reimpl.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── transformer.py └── unifont_module.py ├── mytext.txt ├── requirements.txt ├── train.py └── util ├── __init__.py ├── augmentations.py ├── loading.py ├── misc.py ├── text.py ├── util.py └── vision.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | files/*.pickle 3 | saved_models/IAM-339-15-E3D3-LR5e-05-bs8-finetune/vatr.pth 4 | files/style_samples/* 5 | .idea/ 6 | files/tired_* 7 | *.pyc 8 | logs/ 9 | saved_images/ 10 | files/*.pth 11 | *.pth 12 | saved_models/* 13 | files/stats/* 14 | files/*.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Handwritten Text Generation from Visual Archetypes ++ 2 | 3 | This repository includes the code for training the VATr++ Styled Handwritten Text Generation model. 4 | 5 | ## Installation 6 | 7 | ```bash 8 | conda create --name vatr python=3.9 9 | conda activate vatr 10 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia 11 | git clone https://github.com/aimagelab/VATr.git && cd VATr 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | [This folder](https://drive.google.com/drive/folders/13rJhjl7VsyiXlPTBvnp1EKkKEhckLalr?usp=sharing) contains the regular IAM dataset `IAM-32.pickle` and the modified version with attached punctuation marks `IAM-32-pa.pickle`. 16 | The folder also contains the synthetically pretrained weights for the encoder `resnet_18_pretrained.pth`. 17 | Please download these files and place them into the `files` folder. 18 | 19 | ## Training 20 | 21 | To train the regular VATr model, use the following command. This uses the default settings from the paper. 22 | 23 | ```bash 24 | python train.py 25 | ``` 26 | 27 | Useful arguments: 28 | ```bash 29 | python train.py 30 | --feat_model_path PATH # path to the pretrained resnet 18 checkpoint. By default this is the synthetically pretrained model 31 | --is_cycle # use style cycle loss for training 32 | --dataset DATASET # dataset to use. 
Default IAM 33 | --resume # resume training from the last checkpoint with the same name 34 | --wandb # use wandb for logging 35 | ``` 36 | 37 | Use the following arguments to apply the full VATr++ training setup: 38 | ```bash 39 | python train.py 40 | --d-crop-size 64 128 # Randomly crop the discriminator input to a width between 64 and 128 41 | --text-augment-strength 0.4 # Text augmentation strength for adding more rare characters 42 | --file-suffix pa # Use the punctuation-attached version of IAM 43 | --augment-ocr # Augment the real images used to train the OCR model 44 | ``` 45 | 46 | ### Pretraining dataset 47 | The model `resnet_18_pretrained.pth` was pretrained using this dataset: [Font Square](https://github.com/aimagelab/font_square) 48 | 49 | 50 | ## Generate Styled Handwritten Text Images 51 | 52 | We added some utilities to generate handwritten text images using the trained model. They are used as follows: 53 | 54 | ```bash 55 | python generate.py [ACTION] --checkpoint files/vatrpp.pth 56 | ``` 57 | 58 | The following actions are available with their respective arguments. 59 | 60 | ### Custom Author 61 | 62 | Generate the given text for a custom author. 63 | 64 | ```bash 65 | text --text STRING # String to generate 66 | --text-path PATH # Optional path to text file 67 | --output PATH # Optional output location, default: files/output.png 68 | --style-folder PATH # Optional style folder containing writer samples, default: 'files/style_samples/00' 69 | ``` 70 | Style samples for the author are needed. These can be automatically generated from an image of a page using `create_style_sample.py`. 71 | ```bash 72 | python create_style_sample.py --input-image PATH # Path of the image to extract the style samples from. 73 | --output-folder PATH # Folder where the style samples should be saved 74 | ``` 75 | 76 | ### All Authors 77 | 78 | Generate some text for all authors of IAM. 
The output is saved to `saved_images/author_samples/` 79 | 80 | ```bash 81 | authors --test-set # Generate authors of test set, otherwise training set is generated 82 | --checkpoint PATH # Checkpoint used to generate text, files/vatr.pth by default 83 | --align # Detect the bottom lines for each word and align them 84 | --at-once # Generate the whole sentence at once instead of word-by-word 85 | --output-style # Also save the style images used to generate the words 86 | ``` 87 | 88 | ### Evaluation Images 89 | 90 | ```bash 91 | fid --target_dataset_path PATH # dataset file for which the test set will be generated 92 | --dataset-path PATH # dataset file from which style samples will be taken, for example the attached punctuation 93 | --output PATH # where to save the images, default is saved_images/fid 94 | --checkpoint PATH # Checkpoint used to generate text, files/vatr.pth by default 95 | --all-epochs # Generate evaluation images for all saved epochs available (checkpoint has to be a folder) 96 | --fake-only # Only output fake images, no ground truth 97 | --test-only # Only generate test set, not train set 98 | --long-tail # Only generate words containing long tail characters 99 | ``` -------------------------------------------------------------------------------- /corpora_english/in_vocab.subset.tro.37: -------------------------------------------------------------------------------- 1 | accents 2 | fifty 3 | gross 4 | Tea 5 | whom 6 | renamed 7 | Heaven 8 | Harry 9 | arrange 10 | captain 11 | why 12 | Father 13 | beaten 14 | Bar 15 | base 16 | creamy 17 | About 18 | Allies 19 | sound 20 | farmers 21 | anyone 22 | steel 23 | Mary 24 | used 25 | fever 26 | looking 27 | lately 28 | returns 29 | humans 30 | finals 31 | beyond 32 | lots 33 | waiting 34 | cited 35 | measure 36 | posse 37 | blow 38 | blonde 39 | twice 40 | Having 41 | compels 42 | rooms 43 | cocked 44 | virtual 45 | dying 46 | tons 47 | Travel 48 | idea 49 | gripped 50 | Act 51 | reign 52 | moods 53 | altered 54 | sample 55 | Soviet 56 | thick 57 | enigma 58 | here 59 | egghead 60 | Public 61 | Bryan 62 | porous 63 | estate 64 | guilty 65 | Caught 66 | Lucas 67 | observe 68 | mouth 69 | pricked 70 | obscure 71 | casual 72 | take 73 | home 74 | amber 75 | weekend 76 | forming 77 | aid 78 | outlook 79 | uniting 80 | But 81 | earnest 82 | bear 83 | news 84 | sparked 85 | merrily 86 | extreme 87 | North 88 | damned 89 | big 90 | bosses 91 | context 92 | easily 93 | took 94 | hurried 95 | Gene 96 | due 97 | deserve 98 | cult 99 | leisure 100 | critics 101 | parish 102 | Music 103 | charge 104 | grey 105 | Privy 106 | Fred 107 | massive 108 | others 109 | shirt 110 | average 111 | warning 112 | Tuesday 113 | locked 114 | possess 115 | -------------------------------------------------------------------------------- /corpora_english/oov.common_words: -------------------------------------------------------------------------------- 1 | planets 2 | lips 3 | varies 4 | impact 5 | skips 6 | Gold 7 | maple 8 | voyager 9 | noisy 10 | stick 11 | forums 12 | drafts 13 | crimson 14 | sever 15 | rackets 16 | sexy 17 | humming 18 | cheated 19 | lick 20 | grades 21 | heroic 22 | Clever 23 | foul 24 | mood 25 | warrior 26 | Morning 27 | poetic 28 | nodding 29 | certify 30 | reviews 31 | mosaics 32 | senders 33 | Isle 34 | Lied 35 | sand 36 | Weight 37 | writer 38 | trusts 39 | slot 40 | eaten 41 | squares 42 | lists 43 | vary 44 | witches 45 | compose 46 | demons 47 | therapy 48 | focus 49 | sticks 50 | Whose 51 | bumped 52 | visibly 53 | redeem 54 
| arsenal 55 | lunatic 56 | Similar 57 | Bug 58 | adheres 59 | trail 60 | robbing 61 | Whisky 62 | super 63 | screwed 64 | Flower 65 | salads 66 | Glow 67 | Vapor 68 | Married 69 | recieve 70 | handle 71 | push 72 | card 73 | skiing 74 | lotus 75 | cloud 76 | windy 77 | monkey 78 | virus 79 | thunder 80 | -------------------------------------------------------------------------------- /corpora_english/oov_words.txt: -------------------------------------------------------------------------------- 1 | planets 2 | lips 3 | varies 4 | impact 5 | skips 6 | Gold 7 | maple 8 | voyager 9 | noisy 10 | stick 11 | forums 12 | drafts 13 | crimson 14 | sever 15 | rackets 16 | sexy 17 | humming 18 | cheated 19 | lick 20 | grades 21 | heroic 22 | Clever 23 | foul 24 | mood 25 | warrior 26 | Morning 27 | poetic 28 | nodding 29 | certify 30 | reviews 31 | mosaics 32 | senders 33 | Isle 34 | Lied 35 | sand 36 | Weight 37 | writer 38 | trusts 39 | slot 40 | eaten 41 | squares 42 | lists 43 | vary 44 | witches 45 | compose 46 | demons 47 | therapy 48 | focus 49 | sticks 50 | Whose 51 | bumped 52 | visibly 53 | redeem 54 | arsenal 55 | lunatic 56 | Similar 57 | Bug 58 | adheres 59 | trail 60 | robbing 61 | Whisky 62 | super 63 | screwed 64 | Flower 65 | salads 66 | Glow 67 | Vapor 68 | Married 69 | recieve 70 | handle 71 | push 72 | card 73 | skiing 74 | lotus 75 | cloud 76 | windy 77 | monkey 78 | virus 79 | thunder 80 | Keegan 81 | purling 82 | Orpheus 83 | Prence 84 | Yin 85 | Kansas 86 | jowls 87 | Alabama 88 | Szold 89 | Chou 90 | Orange 91 | suspend 92 | barred 93 | deceit 94 | reward 95 | soy 96 | Vail 97 | lad 98 | Loesser 99 | Hutton 100 | jerks 101 | yelling 102 | Heywood 103 | sacker 104 | comest 105 | tense 106 | par 107 | fiend 108 | Soiree 109 | voted 110 | Putting 111 | pansy 112 | doormen 113 | mayor 114 | Owens 115 | noting 116 | pauses 117 | USP 118 | crudely 119 | grooved 120 | furor 121 | ignited 122 | kittens 123 | broader 124 | slang 125 | ballets 126 | quacked 127 | Paulus 128 | Castles 129 | upswing 130 | dabbled 131 | Animals 132 | Kidder 133 | Writers 134 | laces 135 | bled 136 | scoped 137 | yield 138 | scoured 139 | Schenk 140 | Wratten 141 | Menfolk 142 | foamy 143 | scratch 144 | minced 145 | nudged 146 | Seats 147 | Judging 148 | Turbine 149 | Strict 150 | whined 151 | crupper 152 | Dussa 153 | finned 154 | voter 155 | Jacobs 156 | calmly 157 | hip 158 | clubs 159 | quintet 160 | blunts 161 | Grazie 162 | Barton 163 | NAB 164 | specie 165 | Fonta 166 | narrow 167 | Swan 168 | denials 169 | Rawson 170 | potato 171 | Choral 172 | diverse 173 | Educate 174 | unities 175 | Ferry 176 | Bonner 177 | manuals 178 | NAIR 179 | imputed 180 | initial 181 | wallet 182 | Sesame 183 | maroon 184 | Related 185 | Quiney 186 | Monster 187 | brainy 188 | Nolan 189 | Thrifty 190 | Tel 191 | Ye 192 | Sumter 193 | Bonnet 194 | sheepe 195 | nagged 196 | ribbing 197 | hunt 198 | AA 199 | Pohly 200 | triol 201 | saws 202 | popped 203 | aloof 204 | Ceramic 205 | thong 206 | typed 207 | broadly 208 | Figures 209 | riddle 210 | Otis 211 | Sainted 212 | upbeat 213 | Getting 214 | hisself 215 | junta 216 | Labans 217 | starter 218 | coward 219 | Anthea 220 | hurlers 221 | Dervish 222 | Turin 223 | oud 224 | tyranny 225 | Rotary 226 | Veneto 227 | pulls 228 | bowl 229 | utopias 230 | auburn 231 | osmotic 232 | myrtle 233 | furrow 234 | laws 235 | Uh 236 | Hodges 237 | Wilde 238 | Neck 239 | snaked 240 | decorum 241 | edema 242 | Dunston 243 | clinics 244 | Abide 245 | Dover 246 | voltaic 247 | Modern 248 | 
Farr 249 | thaw 250 | moi 251 | leaning 252 | wedlock 253 | Carson 254 | star 255 | Hymn 256 | Stack 257 | genes 258 | Shayne 259 | Moune 260 | slipped 261 | legatee 262 | coerced 263 | Gates 264 | pulse 265 | Granny 266 | bat 267 | Fruit 268 | Cadesi 269 | Tee 270 | Dreiser 271 | Getz 272 | Ways 273 | cogs 274 | hydrous 275 | sweep 276 | quarrel 277 | mobcaps 278 | slash 279 | throats 280 | Royaux 281 | cafes 282 | crusher 283 | rusted 284 | Eskimo 285 | slatted 286 | pallet 287 | yelps 288 | slanted 289 | confide 290 | Gomez 291 | untidy 292 | Sigmund 293 | Marine 294 | roll 295 | NRL 296 | Dukes 297 | tumours 298 | LP 299 | turtles 300 | audible 301 | Woodrow 302 | retreat 303 | Orders 304 | Conlow 305 | hobby 306 | skin 307 | tally 308 | frosted 309 | drowned 310 | wedged 311 | queen 312 | poised 313 | eluded 314 | Letter 315 | ticking 316 | kill 317 | rancor 318 | Plant 319 | Brandel 320 | Willows 321 | riddles 322 | carven 323 | Spiller 324 | yen 325 | jerky 326 | tenure 327 | daubed 328 | Serves 329 | pimpled 330 | ACTH 331 | ruh 332 | afield 333 | suffuse 334 | muffins 335 | Miners 336 | Cabrini 337 | weakly 338 | upriver 339 | Newsom 340 | Meeker 341 | weed 342 | fiscal 343 | Diane 344 | Errors 345 | Mig 346 | biz 347 | Drink 348 | chop 349 | Bumbry 350 | Babin 351 | optimum 352 | Leyden 353 | enrage 354 | induces 355 | newel 356 | trim 357 | bolts 358 | frog 359 | cinder 360 | Lo 361 | clobber 362 | Mennen 363 | Othon 364 | Ocean 365 | jerking 366 | engine 367 | Belasco 368 | hero 369 | flora 370 | Injuns 371 | Rico 372 | Gary 373 | snake 374 | hating 375 | Suggs 376 | booze 377 | Lescaut 378 | Molard 379 | startle 380 | Aggie 381 | lengthy 382 | Shoals 383 | ideals 384 | Zen 385 | stem 386 | noon 387 | hoes 388 | Seafood 389 | yuh 390 | Mostly 391 | seeds 392 | bestow 393 | acetate 394 | jokers 395 | waning 396 | volumes 397 | ein 398 | Rich 399 | Galt 400 | pasted -------------------------------------------------------------------------------- /create_style_sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import cv2 5 | from util.vision import get_page, get_words 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument("--input-image", type=str, required=True) 12 | parser.add_argument("--output-folder", type=str, required=True, default='files/style_samples/00') 13 | 14 | args = parser.parse_args() 15 | 16 | image = cv2.imread(args.input_image) 17 | image = cv2.resize(image, (image.shape[1], image.shape[0])) 18 | result = get_page(image) 19 | words, _ = get_words(result) 20 | 21 | output_path = args.output_folder 22 | if not os.path.exists(output_path): 23 | os.mkdir(output_path) 24 | for i, word in enumerate(words): 25 | cv2.imwrite(os.path.join(output_path, f"word{i}.png"), word) 26 | -------------------------------------------------------------------------------- /data/create_data.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import os 4 | import pickle 5 | import random 6 | from collections import defaultdict 7 | 8 | import PIL 9 | import cv2 10 | import numpy as np 11 | from PIL import Image 12 | 13 | 14 | TO_MERGE = { 15 | '.': 'left', 16 | ',': 'left', 17 | '!': 'left', 18 | '?': 'left', 19 | '(': 'right', 20 | ')': 'left', 21 | '\"': 'random', 22 | "\'": 'random', 23 | ":": 'left', 24 | ";": 'left', 25 | "-": 'random' 26 | } 27 | 28 | FILTER_ERR = False 29 | 30 | 
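# NOTE: TO_MERGE (above) maps each punctuation mark to the side of the neighbouring
# word it is attached to when building the punctuation-attached ('pa') dataset variant:
# 'left' merges it into the preceding word, 'right' into the following word, and
# 'random' picks either side. FILTER_ERR controls whether IAM samples whose
# segmentation was not marked 'ok' are skipped.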
31 | def resize(image, size): 32 | image_pil = Image.fromarray(image.astype('uint8'), 'L') 33 | image_pil = image_pil.resize(size) 34 | return np.array(image_pil) 35 | 36 | 37 | def get_author_ids(base_folder: str): 38 | with open(os.path.join(base_folder, "gan.iam.tr_va.gt.filter27"), 'r') as f: 39 | training_authors = [line.split(",")[0] for line in f] 40 | training_authors = set(training_authors) 41 | 42 | with open(os.path.join(base_folder, "gan.iam.test.gt.filter27"), 'r') as f: 43 | test_authors = [line.split(",")[0] for line in f] 44 | test_authors = set(test_authors) 45 | 46 | assert len(training_authors.intersection(test_authors)) == 0 47 | 48 | return training_authors, test_authors 49 | 50 | 51 | class IAMImage: 52 | def __init__(self, image: np.array, label: str, image_id: int, line_id: str, bbox: list = None, iam_image_id: str = None): 53 | self.image = image 54 | self.label = label 55 | self.image_id = image_id 56 | self.line_id = line_id 57 | self.iam_image_id = iam_image_id 58 | self.has_bbox = False 59 | if bbox is not None: 60 | self.has_bbox = True 61 | self.x, self.y, self.w, self.h = bbox 62 | 63 | def merge(self, other: 'IAMImage'): 64 | global MERGER_COUNT 65 | assert self.has_bbox, "IAM image has no bounding box information" 66 | y = min(self.y, other.y) 67 | h = max(other.y + other.h, self.y + self.h) - y 68 | 69 | x = min(self.x, other.x) 70 | w = max(self.x + self.w, other.x + other.w) - x 71 | 72 | new_image = np.ones((h, w), dtype=self.image.dtype) * 255 73 | 74 | anchor_x = self.x - x 75 | anchor_y = self.y - y 76 | new_image[anchor_y:anchor_y + self.h, anchor_x:anchor_x + self.w] = self.image 77 | 78 | anchor_x = other.x - x 79 | anchor_y = other.y - y 80 | new_image[anchor_y:anchor_y + other.h, anchor_x:anchor_x + other.w] = other.image 81 | 82 | if other.x - (self.x + self.w) > 50: 83 | new_label = self.label + " " + other.label 84 | else: 85 | new_label = self.label + other.label 86 | new_id = self.image_id 87 | new_bbox = [x, y, w, h] 88 | 89 | new_iam_image_id = self.iam_image_id if len(self.label) > len(other.label) else other.iam_image_id 90 | return IAMImage(new_image, new_label, new_id, self.line_id, new_bbox, iam_image_id=new_iam_image_id) 91 | 92 | 93 | def read_iam_lines(base_folder: str) -> dict: 94 | form_to_author = {} 95 | with open(os.path.join(base_folder, "forms.txt"), 'r') as f: 96 | for line in f: 97 | if not line.startswith("#"): 98 | form, author, *_ = line.split(" ") 99 | form_to_author[form] = author 100 | 101 | training_authors, test_authors = get_author_ids(base_folder) 102 | 103 | dataset_dict = { 104 | 'train': defaultdict(list), 105 | 'test': defaultdict(list), 106 | 'other': defaultdict(list) 107 | } 108 | 109 | image_count = 0 110 | 111 | with open(os.path.join(base_folder, "sentences.txt"), 'r') as f: 112 | for line in f: 113 | if not line.startswith("#"): 114 | line_id, _, ok, *_, label = line.rstrip().split(" ") 115 | form_id = "-".join(line_id.split("-")[:2]) 116 | author_id = form_to_author[form_id] 117 | 118 | if ok != 'ok' and FILTER_ERR: 119 | continue 120 | 121 | line_label = "" 122 | for word in label.split("|"): 123 | if not(len(line_label) == 0 or word in [".", ","]): 124 | line_label += " " 125 | line_label += word 126 | 127 | image_path = os.path.join(base_folder, "sentences", form_id.split("-")[0], form_id, f"{line_id}.png") 128 | 129 | subset = 'other' 130 | if author_id in training_authors: 131 | subset = 'train' 132 | elif author_id in test_authors: 133 | subset = 'test' 134 | 135 | im = cv2.imread(image_path, 
cv2.IMREAD_GRAYSCALE) 136 | if im is not None and im.size > 1: 137 | dataset_dict[subset][author_id].append(IAMImage( 138 | im, line_label, image_count, line_id, None 139 | )) 140 | image_count += 1 141 | 142 | return dataset_dict 143 | 144 | 145 | def read_iam(base_folder: str) -> dict: 146 | with open(os.path.join(base_folder, "forms.txt"), 'r') as f: 147 | forms = [line.rstrip() for line in f if not line.startswith("#")] 148 | 149 | training_authors, test_authors = get_author_ids(base_folder) 150 | 151 | image_info = {} 152 | with open(os.path.join(base_folder, "words.txt"), 'r') as f: 153 | for line in f: 154 | if not line.startswith("#"): 155 | image_id, ok, threshold, x, y, w, h, tag, *content = line.rstrip().split(" ") 156 | image_info[image_id] = { 157 | 'ok': ok == 'ok', 158 | 'threshold': threshold, 159 | 'content': " ".join(content) if isinstance(content, list) else content, 160 | 'bbox': [int(x), int(y), int(w), int(h)] 161 | } 162 | 163 | dataset_dict = { 164 | 'train': defaultdict(list), 165 | 'test': defaultdict(list), 166 | 'other': defaultdict(list) 167 | } 168 | 169 | image_count = 0 170 | err_count = 0 171 | 172 | for form in forms: 173 | form_id, writer_id, *_ = form.split(" ") 174 | base_form = form_id.split("-")[0] 175 | 176 | form_path = os.path.join(base_folder, "words", base_form, form_id) 177 | 178 | for image_name in os.listdir(form_path): 179 | image_id = image_name.split(".")[0] 180 | info = image_info[image_id] 181 | 182 | subset = 'other' 183 | if writer_id in training_authors: 184 | subset = 'train' 185 | elif writer_id in test_authors: 186 | subset = 'test' 187 | 188 | if info['ok'] or not FILTER_ERR: 189 | im = cv2.imread(os.path.join(form_path, image_name), cv2.IMREAD_GRAYSCALE) 190 | if not info['ok'] and False: 191 | cv2.destroyAllWindows() 192 | print(info['content']) 193 | cv2.imshow("image", im) 194 | cv2.waitKey(0) 195 | 196 | if im is not None and im.size > 1: 197 | dataset_dict[subset][writer_id].append(IAMImage( 198 | im, info['content'], image_count, "-".join(image_id.split("-")[:3]), info['bbox'], iam_image_id=image_id 199 | )) 200 | image_count += 1 201 | else: 202 | err_count += 1 203 | print(f"Could not read image {image_name}, skipping") 204 | else: 205 | err_count += 1 206 | 207 | assert not dataset_dict['train'].keys() & dataset_dict['test'].keys(), "Training and Testing set have common authors" 208 | 209 | print(f"Skipped images: {err_count}") 210 | 211 | return dataset_dict 212 | 213 | 214 | def read_cvl_set(set_folder: str): 215 | set_images = defaultdict(list) 216 | words_path = os.path.join(set_folder, "words") 217 | 218 | image_id = 0 219 | 220 | for author_id in os.listdir(words_path): 221 | author_path = os.path.join(words_path, author_id) 222 | 223 | for image_file in os.listdir(author_path): 224 | label = image_file.split("-")[-1].split(".")[0] 225 | line_id = "-".join(image_file.split("-")[:-2]) 226 | 227 | stream = open(os.path.join(author_path, image_file), "rb") 228 | bytes = bytearray(stream.read()) 229 | numpyarray = np.asarray(bytes, dtype=np.uint8) 230 | image = cv2.imdecode(numpyarray, cv2.IMREAD_UNCHANGED) 231 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 232 | if image is not None and image.size > 1: 233 | set_images[int(author_id)].append(IAMImage(image, label, image_id, line_id)) 234 | image_id += 1 235 | 236 | return set_images 237 | 238 | 239 | def read_cvl(base_folder: str): 240 | dataset_dict = { 241 | 'test': read_cvl_set(os.path.join(base_folder, 'testset')), 242 | 'train': 
read_cvl_set(os.path.join(base_folder, 'trainset')) 243 | } 244 | 245 | assert not dataset_dict['train'].keys() & dataset_dict[ 246 | 'test'].keys(), "Training and Testing set have common authors" 247 | 248 | return dataset_dict 249 | 250 | def pad_top(image: np.array, height: int) -> np.array: 251 | result = np.ones((height, image.shape[1]), dtype=np.uint8) * 255 252 | result[height - image.shape[0]:, :image.shape[1]] = image 253 | 254 | return result 255 | 256 | 257 | def scale_per_writer(writer_dict: dict, target_height: int, char_width: int = None) -> dict: 258 | for author_id in writer_dict.keys(): 259 | max_height = max([image_dict.image.shape[0] for image_dict in writer_dict[author_id]]) 260 | scale_y = target_height / max_height 261 | 262 | for image_dict in writer_dict[author_id]: 263 | image = image_dict.image 264 | scale_x = scale_y if char_width is None else len(image_dict.label) * char_width / image_dict.image.shape[1] 265 | #image = cv2.resize(image, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_CUBIC) 266 | image = resize(image, (int(image.shape[1] * scale_x), int(image.shape[0] * scale_y))) 267 | image_dict.image = pad_top(image, target_height) 268 | 269 | return writer_dict 270 | 271 | 272 | def scale_images(writer_dict: dict, target_height: int, char_width: int = None) -> dict: 273 | for author_id in writer_dict.keys(): 274 | for image_dict in writer_dict[author_id]: 275 | scale_y = target_height / image_dict.image.shape[0] 276 | scale_x = scale_y if char_width is None else len(image_dict.label) * char_width / image_dict.image.shape[1] 277 | #image_dict.image = cv2.resize(image_dict.image, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_CUBIC) 278 | image_dict.image = resize(image_dict.image, (int(image_dict.image.shape[1] * scale_x), target_height)) 279 | return writer_dict 280 | 281 | 282 | def scale_word_width(writer_dict: dict): 283 | for author_id in writer_dict.keys(): 284 | for image_dict in writer_dict[author_id]: 285 | width = len(image_dict.label) * (image_dict.image.shape[0] / 2.0) 286 | image_dict.image = resize(image_dict.image, (int(width), image_dict.image.shape[0])) 287 | return writer_dict 288 | 289 | 290 | def get_sentences(author_dict: dict): 291 | collected = defaultdict(list) 292 | for image in author_dict: 293 | collected[image.line_id].append(image) 294 | 295 | return [v for k, v in collected.items()] 296 | 297 | 298 | def merge_author_words(author_words): 299 | def try_left_merge(index: int): 300 | if index > 0 and author_words[index - 1].line_id == author_words[index].line_id and not to_remove[index - 1] and not author_words[index - 1].label in TO_MERGE.keys(): 301 | merged = author_words[index - 1].merge(author_words[index]) 302 | author_words[index - 1] = merged 303 | to_remove[index] = True 304 | return True 305 | return False 306 | 307 | def try_right_merge(index: int): 308 | if index < len(author_words) - 1 and author_words[index].line_id == author_words[index + 1].line_id and not to_remove[index + 1] and not author_words[index + 1].label in TO_MERGE.keys(): 309 | merged = iam_image.merge(author_words[index + 1]) 310 | author_words[index + 1] = merged 311 | to_remove[index] = True 312 | return True 313 | return False 314 | 315 | to_remove = [False for _ in range(len(author_words))] 316 | for i in range(len(author_words)): 317 | iam_image = author_words[i] 318 | if iam_image.label in TO_MERGE.keys(): 319 | merge_type = TO_MERGE[iam_image.label] if TO_MERGE[iam_image.label] != 'random' else random.choice(['left', 'right']) 320 | 
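# Try the preferred side first; if that neighbour is unavailable (start of the line,
# a different line, already merged away, or itself a punctuation token), fall back
# to the other side.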
if merge_type == 'left': 321 | if not try_left_merge(i): 322 | if not try_right_merge(i): 323 | print(f"Could not merge char: {iam_image.label}") 324 | else: 325 | if not try_right_merge(i): 326 | if not try_left_merge(i): 327 | print(f"Could not merge char: {iam_image.label}") 328 | 329 | return [image for image, remove in zip(author_words, to_remove) if not remove], sum(to_remove) 330 | 331 | 332 | def merge_punctuation(writer_dict: dict) -> dict: 333 | for author_id in writer_dict.keys(): 334 | author_dict = writer_dict[author_id] 335 | 336 | merged = 1 337 | while merged > 0: 338 | author_dict, merged = merge_author_words(author_dict) 339 | 340 | writer_dict[author_id] = author_dict 341 | 342 | return writer_dict 343 | 344 | 345 | def filter_punctuation(writer_dict: dict) -> dict: 346 | for author_id in writer_dict.keys(): 347 | author_list = [im for im in writer_dict[author_id] if im.label not in TO_MERGE.keys()] 348 | 349 | writer_dict[author_id] = author_list 350 | 351 | return writer_dict 352 | 353 | 354 | def filter_by_width(writer_dict: dict, target_height: int = 32, min_width: int = 16, max_width: int = 17) -> dict: 355 | def is_valid(iam_image: IAMImage) -> bool: 356 | target_width = (target_height / iam_image.image.shape[0]) * iam_image.image.shape[1] 357 | if len(iam_image.label) * min_width / 3 <= target_width <= len(iam_image.label) * max_width * 3: 358 | return True 359 | else: 360 | return False 361 | 362 | for author_id in writer_dict.keys(): 363 | author_list = [im for im in writer_dict[author_id] if is_valid(im)] 364 | 365 | writer_dict[author_id] = author_list 366 | 367 | return writer_dict 368 | 369 | 370 | def write_data(dataset_dict: dict, location: str, height, punct_mode: str = 'none', author_scale: bool = False, uniform_char_width: bool = False): 371 | assert punct_mode in ['none', 'filter', 'merge'] 372 | result = {} 373 | for key in dataset_dict.keys(): 374 | result[key] = {} 375 | 376 | subset_dict = dataset_dict[key] 377 | 378 | subset_dict = filter_by_width(subset_dict) 379 | 380 | if punct_mode == 'merge': 381 | subset_dict = merge_punctuation(subset_dict) 382 | elif punct_mode == 'filter': 383 | subset_dict = filter_punctuation(subset_dict) 384 | 385 | char_width = 16 if uniform_char_width else None 386 | 387 | if author_scale: 388 | subset_dict = scale_per_writer(subset_dict, height, char_width) 389 | else: 390 | subset_dict = scale_images(subset_dict, height, char_width) 391 | 392 | for author_id in subset_dict: 393 | author_images = [] 394 | for image_dict in subset_dict[author_id]: 395 | author_images.append({ 396 | 'img': PIL.Image.fromarray(image_dict.image), 397 | 'label': image_dict.label, 398 | 'image_id': image_dict.image_id, 399 | 'original_image_id': image_dict.iam_image_id 400 | }) 401 | result[key][author_id] = author_images 402 | 403 | with open(location, 'wb') as f: 404 | pickle.dump(result, f) 405 | 406 | 407 | def write_fid(dataset_dict: dict, location: str): 408 | data = dataset_dict['test'] 409 | data = scale_images(data, 64, None) 410 | for author in data.keys(): 411 | author_folder = os.path.join(location, author) 412 | os.mkdir(author_folder) 413 | count = 0 414 | for image in data[author]: 415 | img = image.image 416 | cv2.imwrite(os.path.join(author_folder, f"{count}.png"), img.squeeze().astype(np.uint8)) 417 | count += 1 418 | 419 | 420 | def write_images_per_author(dataset_dict: dict, output_file: str): 421 | data = dataset_dict["test"] 422 | 423 | result = {} 424 | 425 | for author in data.keys(): 426 | author_images = 
[image.iam_image_id for image in data[author]] 427 | result[author] = author_images 428 | 429 | with open(output_file, 'w') as f: 430 | json.dump(result, f) 431 | 432 | 433 | def write_words(dataset_dict: dict, output_file): 434 | data = dataset_dict['train'] 435 | 436 | all_words = [] 437 | 438 | for author in data.keys(): 439 | all_words.extend([image.label for image in data[author]]) 440 | 441 | with open(output_file, 'w') as f: 442 | for word in all_words: 443 | f.write(f"{word}\n") 444 | 445 | 446 | if __name__ == "__main__": 447 | data_path = r"D:\Datasets\IAM" 448 | fid_location = r"E:/projects/evaluation/shtg_interface/data/reference_imgs/h64/iam" 449 | height = 32 450 | data_collection = {} 451 | 452 | output_location = r"E:\projects\evaluation\shtg_interface\data\datasets" 453 | 454 | data = read_iam(data_path) 455 | test_data = dict(scale_word_width(data['test'])) 456 | train_data = dict(scale_word_width(data['train'])) 457 | test_data.update(train_data) 458 | for key, value in test_data.items(): 459 | for image_object in value: 460 | if len(image_object.label) <= 0 or image_object.image.size == 0: 461 | continue 462 | data_collection[image_object.iam_image_id] = { 463 | 'img': image_object.image, 464 | 'lbl': image_object.label, 465 | 'author_id': key 466 | } 467 | 468 | with gzip.open(os.path.join(output_location, f"iam_w16_words_data.pkl.gz"), 'wb') as f: 469 | pickle.dump(data_collection, f) 470 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | import os 9 | import pickle 10 | import numpy as np 11 | from PIL import Image 12 | from pathlib import Path 13 | 14 | 15 | def get_dataset_path(dataset_name, height, file_suffix, datasets_path): 16 | if file_suffix is not None: 17 | filename = f'{dataset_name}-{height}-{file_suffix}.pickle' 18 | else: 19 | filename = f'{dataset_name}-{height}.pickle' 20 | 21 | return os.path.join(datasets_path, filename) 22 | 23 | 24 | def get_transform(grayscale=False, convert=True): 25 | transform_list = [] 26 | if grayscale: 27 | transform_list.append(transforms.Grayscale(1)) 28 | 29 | if convert: 30 | transform_list += [transforms.ToTensor()] 31 | if grayscale: 32 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 33 | else: 34 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 35 | 36 | return transforms.Compose(transform_list) 37 | 38 | 39 | class TextDataset: 40 | 41 | def __init__(self, base_path, collator_resolution, num_examples=15, target_transform=None, min_virtual_size=0, validation=False, debug=False): 42 | self.NUM_EXAMPLES = num_examples 43 | self.debug = debug 44 | self.min_virtual_size = min_virtual_size 45 | 46 | subset = 'test' if validation else 'train' 47 | 48 | # base_path=DATASET_PATHS 49 | file_to_store = open(base_path, "rb") 50 | self.IMG_DATA = pickle.load(file_to_store)[subset] 51 | self.IMG_DATA = dict(list(self.IMG_DATA.items())) # [:NUM_WRITERS]) 52 | if 'None' in self.IMG_DATA.keys(): 53 | del self.IMG_DATA['None'] 54 | 55 | self.alphabet = ''.join(sorted(set(''.join(d['label'] for d in sum(self.IMG_DATA.values(), []))))) 56 | self.author_id = list(self.IMG_DATA.keys()) 57 | 58 | self.transform = get_transform(grayscale=True) 59 | 
self.target_transform = target_transform 60 | 61 | self.collate_fn = TextCollator(collator_resolution) 62 | 63 | def __len__(self): 64 | if self.debug: 65 | return 16 66 | return max(len(self.author_id), self.min_virtual_size) 67 | 68 | @property 69 | def num_writers(self): 70 | return len(self.author_id) 71 | 72 | def __getitem__(self, index): 73 | index = index % len(self.author_id) 74 | 75 | author_id = self.author_id[index] 76 | 77 | self.IMG_DATA_AUTHOR = self.IMG_DATA[author_id] 78 | random_idxs = random.choices([i for i in range(len(self.IMG_DATA_AUTHOR))], k=self.NUM_EXAMPLES) 79 | 80 | word_data = random.choice(self.IMG_DATA_AUTHOR) 81 | real_img = self.transform(word_data['img'].convert('L')) 82 | real_labels = word_data['label'].encode() 83 | 84 | imgs = [np.array(self.IMG_DATA_AUTHOR[idx]['img'].convert('L')) for idx in random_idxs] 85 | slabels = [self.IMG_DATA_AUTHOR[idx]['label'].encode() for idx in random_idxs] 86 | 87 | max_width = 192 # [img.shape[1] for img in imgs] 88 | 89 | imgs_pad = [] 90 | imgs_wids = [] 91 | 92 | for img in imgs: 93 | img_height, img_width = img.shape[0], img.shape[1] 94 | output_img = np.ones((img_height, max_width), dtype='float32') * 255.0 95 | output_img[:, :img_width] = img[:, :max_width] 96 | 97 | imgs_pad.append(self.transform(Image.fromarray(output_img.astype(np.uint8)))) 98 | imgs_wids.append(img_width) 99 | 100 | imgs_pad = torch.cat(imgs_pad, 0) 101 | 102 | item = { 103 | 'simg': imgs_pad, # N images (15) that come from the same author [N (15), H (32), MAX_W (192)] 104 | 'swids': imgs_wids, # widths of the N images [list(N)] 105 | 'img': real_img, # the input image [1, H (32), W] 106 | 'label': real_labels, # the label of the input image [byte] 107 | 'img_path': 'img_path', 108 | 'idx': 'indexes', 109 | 'wcl': index, # id of the author [int], 110 | 'slabels': slabels, 111 | 'author_id': author_id 112 | } 113 | return item 114 | 115 | def get_stats(self): 116 | char_counts = defaultdict(lambda: 0) 117 | total = 0 118 | 119 | for author in self.IMG_DATA.keys(): 120 | for data in self.IMG_DATA[author]: 121 | for char in data['label']: 122 | char_counts[char] += 1 123 | total += 1 124 | 125 | char_counts = {k: 1.0 / (v / total) for k, v in char_counts.items()} 126 | 127 | return char_counts 128 | 129 | 130 | class TextCollator(object): 131 | def __init__(self, resolution): 132 | self.resolution = resolution 133 | 134 | def __call__(self, batch): 135 | if isinstance(batch[0], list): 136 | batch = sum(batch, []) 137 | img_path = [item['img_path'] for item in batch] 138 | width = [item['img'].shape[2] for item in batch] 139 | indexes = [item['idx'] for item in batch] 140 | simgs = torch.stack([item['simg'] for item in batch], 0) 141 | wcls = torch.Tensor([item['wcl'] for item in batch]) 142 | swids = torch.Tensor([item['swids'] for item in batch]) 143 | imgs = torch.ones([len(batch), batch[0]['img'].shape[0], batch[0]['img'].shape[1], max(width)], 144 | dtype=torch.float32) 145 | for idx, item in enumerate(batch): 146 | try: 147 | imgs[idx, :, :, 0:item['img'].shape[2]] = item['img'] 148 | except: 149 | print(imgs.shape) 150 | item = {'img': imgs, 'img_path': img_path, 'idx': indexes, 'simg': simgs, 'swids': swids, 'wcl': wcls} 151 | if 'label' in batch[0].keys(): 152 | labels = [item['label'] for item in batch] 153 | item['label'] = labels 154 | if 'slabels' in batch[0].keys(): 155 | slabels = [item['slabels'] for item in batch] 156 | item['slabels'] = np.array(slabels) 157 | if 'z' in batch[0].keys(): 158 | z = torch.stack([item['z'] for item 
in batch]) 159 | item['z'] = z 160 | return item 161 | 162 | 163 | class CollectionTextDataset(Dataset): 164 | def __init__(self, datasets, datasets_path, dataset_class, file_suffix=None, height=32, **kwargs): 165 | self.datasets = {} 166 | for dataset_name in sorted(datasets.split(',')): 167 | dataset_file = get_dataset_path(dataset_name, height, file_suffix, datasets_path) 168 | dataset = dataset_class(dataset_file, **kwargs) 169 | self.datasets[dataset_name] = dataset 170 | self.alphabet = ''.join(sorted(set(''.join(d.alphabet for d in self.datasets.values())))) 171 | 172 | def __len__(self): 173 | return sum(len(d) for d in self.datasets.values()) 174 | 175 | @property 176 | def num_writers(self): 177 | return sum(d.num_writers for d in self.datasets.values()) 178 | 179 | def __getitem__(self, index): 180 | for dataset in self.datasets.values(): 181 | if index < len(dataset): 182 | return dataset[index] 183 | index -= len(dataset) 184 | raise IndexError 185 | 186 | def get_dataset(self, index): 187 | for dataset_name, dataset in self.datasets.items(): 188 | if index < len(dataset): 189 | return dataset_name 190 | index -= len(dataset) 191 | raise IndexError 192 | 193 | def collate_fn(self, batch): 194 | return self.datasets[self.get_dataset(0)].collate_fn(batch) 195 | 196 | 197 | class FidDataset(Dataset): 198 | def __init__(self, base_path, collator_resolution, num_examples=15, target_transform=None, mode='train', style_dataset=None): 199 | self.NUM_EXAMPLES = num_examples 200 | 201 | # base_path=DATASET_PATHS 202 | with open(base_path, "rb") as f: 203 | self.IMG_DATA = pickle.load(f) 204 | 205 | self.IMG_DATA = self.IMG_DATA[mode] 206 | if 'None' in self.IMG_DATA.keys(): 207 | del self.IMG_DATA['None'] 208 | 209 | self.STYLE_IMG_DATA = None 210 | if style_dataset is not None: 211 | with open(style_dataset, "rb") as f: 212 | self.STYLE_IMG_DATA = pickle.load(f) 213 | 214 | self.STYLE_IMG_DATA = self.STYLE_IMG_DATA[mode] 215 | if 'None' in self.STYLE_IMG_DATA.keys(): 216 | del self.STYLE_IMG_DATA['None'] 217 | 218 | self.alphabet = ''.join(sorted(set(''.join(d['label'] for d in sum(self.IMG_DATA.values(), []))))) 219 | self.author_id = sorted(self.IMG_DATA.keys()) 220 | 221 | self.transform = get_transform(grayscale=True) 222 | self.target_transform = target_transform 223 | self.dataset_size = sum(len(samples) for samples in self.IMG_DATA.values()) 224 | self.collate_fn = TextCollator(collator_resolution) 225 | 226 | def __len__(self): 227 | return self.dataset_size 228 | 229 | @property 230 | def num_writers(self): 231 | return len(self.author_id) 232 | 233 | def __getitem__(self, index): 234 | NUM_SAMPLES = self.NUM_EXAMPLES 235 | sample, author_id = None, None 236 | for author_id, samples in self.IMG_DATA.items(): 237 | if index < len(samples): 238 | sample, author_id = samples[index], author_id 239 | break 240 | index -= len(samples) 241 | 242 | real_image = self.transform(sample['img'].convert('L')) 243 | real_label = sample['label'].encode() 244 | 245 | style_dataset = self.STYLE_IMG_DATA if self.STYLE_IMG_DATA is not None else self.IMG_DATA 246 | 247 | author_style_images = style_dataset[author_id] 248 | random_idxs = np.random.choice(len(author_style_images), NUM_SAMPLES, replace=True) 249 | style_images = [np.array(author_style_images[idx]['img'].convert('L')) for idx in random_idxs] 250 | 251 | max_width = 192 252 | 253 | imgs_pad = [] 254 | imgs_wids = [] 255 | 256 | for img in style_images: 257 | img = 255 - img 258 | img_height, img_width = img.shape[0], img.shape[1] 259 | 
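# The style image was inverted above (ink is now bright), so zero-padding up to
# max_width adds empty background; it is inverted back to black-on-white below.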
outImg = np.zeros((img_height, max_width), dtype='float32') 260 | outImg[:, :img_width] = img[:, :max_width] 261 | 262 | img = 255 - outImg 263 | 264 | imgs_pad.append(self.transform(Image.fromarray(img.astype(np.uint8)))) 265 | imgs_wids.append(img_width) 266 | 267 | imgs_pad = torch.cat(imgs_pad, 0) 268 | 269 | item = { 270 | 'simg': imgs_pad, # widths of the N images [list(N)] 271 | 'swids': imgs_wids, # N images (15) that come from the same author [N (15), H (32), MAX_W (192)] 272 | 'img': real_image, # the input image [1, H (32), W] 273 | 'label': real_label, # the label of the input image [byte] 274 | 'img_path': 'img_path', 275 | 'idx': sample['img_id'] if 'img_id' in sample.keys() else sample['image_id'], 276 | 'wcl': int(author_id) # id of the author [int] 277 | } 278 | return item 279 | 280 | 281 | class FolderDataset: 282 | def __init__(self, folder_path, num_examples=15, word_lengths=None): 283 | folder_path = Path(folder_path) 284 | self.imgs = list([p for p in folder_path.iterdir() if not p.suffix == '.txt']) 285 | self.transform = get_transform(grayscale=True) 286 | self.num_examples = num_examples 287 | self.word_lengths = word_lengths 288 | 289 | def __len__(self): 290 | return len(self.imgs) 291 | 292 | def sample_style(self): 293 | random_idxs = np.random.choice(len(self.imgs), self.num_examples, replace=False) 294 | image_names = [self.imgs[idx].stem for idx in random_idxs] 295 | imgs = [Image.open(self.imgs[idx]).convert('L') for idx in random_idxs] 296 | if self.word_lengths is None: 297 | imgs = [img.resize((img.size[0] * 32 // img.size[1], 32), Image.BILINEAR) for img in imgs] 298 | else: 299 | imgs = [img.resize((self.word_lengths[name] * 16, 32), Image.BILINEAR) for img, name in zip(imgs, image_names)] 300 | imgs = [np.array(img) for img in imgs] 301 | 302 | max_width = 192 # [img.shape[1] for img in imgs] 303 | 304 | imgs_pad = [] 305 | imgs_wids = [] 306 | 307 | for img in imgs: 308 | img = 255 - img 309 | img_height, img_width = img.shape[0], img.shape[1] 310 | outImg = np.zeros((img_height, max_width), dtype='float32') 311 | outImg[:, :img_width] = img[:, :max_width] 312 | 313 | img = 255 - outImg 314 | 315 | imgs_pad.append(self.transform(Image.fromarray(img.astype(np.uint8)))) 316 | imgs_wids.append(img_width) 317 | 318 | imgs_pad = torch.cat(imgs_pad, 0) 319 | 320 | item = { 321 | 'simg': imgs_pad, # widths of the N images [list(N)] 322 | 'swids': imgs_wids, # N images (15) that come from the same author [N (15), H (32), MAX_W (192)] 323 | } 324 | return item 325 | -------------------------------------------------------------------------------- /data/iam_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def test_split(): 5 | iam_path = r"C:\Users\bramv\Documents\Werk\Research\Unimore\datasets\IAM" 6 | 7 | original_set_names = ["trainset.txt", "validationset1.txt", "validationset2.txt", "testset.txt"] 8 | original_set_ids = [] 9 | 10 | print("ORIGINAL IAM") 11 | print("---------------------") 12 | 13 | for set_name in original_set_names: 14 | with open(os.path.join(iam_path, set_name), 'r') as f: 15 | set_form_ids = ["-".join(l.rstrip().split("-")[:-1]) for l in f] 16 | 17 | form_to_id = {} 18 | with open(os.path.join(iam_path, "forms.txt"), 'r') as f: 19 | for line in f: 20 | if line.startswith("#"): 21 | continue 22 | form, id, *_ = line.split(" ") 23 | assert form not in form_to_id.keys() or form_to_id[form] == id 24 | form_to_id[form] = int(id) 25 | 26 | set_authors = [form_to_id[form] for 
form in set_form_ids] 27 | 28 | set_authors = set(sorted(set_authors)) 29 | original_set_ids.append(set_authors) 30 | print(f"{set_name} count: {len(set_authors)}") 31 | 32 | htg_set_names = ["gan.iam.tr_va.gt.filter27", "gan.iam.test.gt.filter27"] 33 | 34 | print("\n\nHTG IAM") 35 | print("---------------------") 36 | 37 | for set_name in htg_set_names: 38 | with open(os.path.join(iam_path, set_name), 'r') as f: 39 | set_authors = [int(l.split(",")[0]) for l in f] 40 | 41 | set_authors = set(set_authors) 42 | 43 | print(f"{set_name} count: {len(set_authors)}") 44 | for name, original_set in zip(original_set_names, original_set_ids): 45 | intr = set_authors.intersection(original_set) 46 | print(f"\t intersection with {name}: {len(intr)}") 47 | 48 | 49 | 50 | if __name__ == "__main__": 51 | test_split() 52 | -------------------------------------------------------------------------------- /data/show_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | import shutil 5 | 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | from data.dataset import get_transform 11 | 12 | 13 | def summarize_dataset(data: dict): 14 | print(f"Training authors: {len(data['train'].keys())} \t Testing authors: {len(data['test'].keys())}") 15 | training_images = sum([len(data['train'][k]) for k in data['train'].keys()]) 16 | testing_images = sum([len(data['test'][k]) for k in data['test'].keys()]) 17 | print(f"Training images: {training_images} \t Testing images: {testing_images}") 18 | 19 | 20 | def compare_data(path_a: str, path_b: str): 21 | with open(path_a, 'rb') as f: 22 | data_a = pickle.load(f) 23 | summarize_dataset(data_a) 24 | 25 | with open(path_b, 'rb') as f: 26 | data_b = pickle.load(f) 27 | summarize_dataset(data_b) 28 | 29 | training_a = data_a['train'] 30 | training_b = data_b['train'] 31 | 32 | training_a = {int(k): v for k, v in training_a.items()} 33 | training_b = {int(k): v for k, v in training_b.items()} 34 | 35 | while True: 36 | author = random.choice(list(training_a.keys())) 37 | 38 | if author in training_b.keys(): 39 | author_images_a = [np.array(im_dict["img"]) for im_dict in training_a[author]] 40 | author_images_b = [np.array(im_dict["img"]) for im_dict in training_b[author]] 41 | 42 | labels_a = [str(im_dict["label"]) for im_dict in training_a[author]] 43 | labels_b = [str(im_dict["label"]) for im_dict in training_b[author]] 44 | 45 | vis_a = np.hstack(author_images_a[:10]) 46 | vis_b = np.hstack(author_images_b[:10]) 47 | 48 | cv2.imshow("Author a", vis_a) 49 | cv2.imshow("Author b", vis_b) 50 | 51 | cv2.waitKey(0) 52 | 53 | else: 54 | print(f"Author: {author} not found in second dataset") 55 | 56 | 57 | def show_dataset(path: str, samples: int = 10): 58 | with open(path, 'rb') as f: 59 | data = pickle.load(f) 60 | summarize_dataset(data) 61 | 62 | training = data['train'] 63 | 64 | author = training['013'] 65 | author_images = [np.array(im_dict["img"]).astype(np.uint8) for im_dict in author] 66 | 67 | for img in author_images: 68 | cv2.imshow('image', img) 69 | cv2.waitKey(0) 70 | 71 | for author in list(training.keys()): 72 | 73 | author_images = [np.array(im_dict["img"]).astype(np.uint8) for im_dict in training[author]] 74 | labels = [str(im_dict["label"]) for im_dict in training[author]] 75 | 76 | vis = np.hstack(author_images[:samples]) 77 | print(f"Author: {author}") 78 | cv2.destroyAllWindows() 79 | cv2.imshow("vis", vis) 80 | cv2.waitKey(0) 81 | 82 | 83 | 
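# Round-trips dataset images through get_transform() (normalization to [-1, 1]) and
# back to uint8, then visualizes the pixel-wise differences to check that the
# transform is (close to) lossless.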
def test_transform(path: str): 84 | with open(path, 'rb') as f: 85 | data = pickle.load(f) 86 | summarize_dataset(data) 87 | 88 | training = data['train'] 89 | transform = get_transform(grayscale=True) 90 | 91 | for author_id in training.keys(): 92 | author = training[author_id] 93 | for image_dict in author: 94 | original_image = image_dict['img'].convert('L') 95 | transformed_image = transform(original_image).detach().numpy() 96 | restored_image = (((transformed_image + 1) / 2) * 255).astype(np.uint8) 97 | restored_image = np.squeeze(restored_image) 98 | original_image = np.array(original_image) 99 | 100 | wrong_pixels = (original_image != restored_image).astype(np.uint8) * 255 101 | 102 | combined = np.hstack((restored_image, original_image, wrong_pixels)) 103 | 104 | cv2.imshow("original", original_image) 105 | cv2.imshow("restored", restored_image) 106 | cv2.imshow("combined", combined) 107 | 108 | f, ax = plt.subplots(1, 2) 109 | ax[0].hist(original_image.flatten()) 110 | ax[1].hist(restored_image.flatten()) 111 | plt.show() 112 | 113 | cv2.waitKey(0) 114 | 115 | def dump_words(): 116 | data_path = r"..\files\IAM-32.pickle" 117 | 118 | p_mark = 'point' 119 | p = '.' 120 | 121 | with open(data_path, 'rb') as f: 122 | data = pickle.load(f) 123 | 124 | training = data['train'] 125 | 126 | target_folder = f"../saved_images/debug/{p_mark}" 127 | 128 | if os.path.exists(target_folder): 129 | shutil.rmtree(target_folder) 130 | 131 | os.mkdir(target_folder) 132 | 133 | count = 0 134 | 135 | for author in list(training.keys()): 136 | 137 | author_images = [np.array(im_dict["img"]).astype(np.uint8) for im_dict in training[author]] 138 | labels = [str(im_dict["label"]) for im_dict in training[author]] 139 | 140 | for img, label in zip(author_images, labels): 141 | if p in label: 142 | cv2.imwrite(os.path.join(target_folder, f"{count}.png"), img) 143 | count += 1 144 | 145 | 146 | if __name__ == "__main__": 147 | test_transform("../files/IAM-32.pickle") 148 | #show_dataset("../files/IAM-32.pickle") 149 | #compare_data(r"../files/IAM-32.pickle", r"../files/_IAM-32.pickle") 150 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from generate import generate_text, generate_authors, generate_fid, generate_page, generate_ocr, generate_ocr_msgpack 3 | from generate.ocr import generate_ocr_reference 4 | from util.misc import add_vatr_args 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("action", choices=['text', 'fid', 'page', 'authors', 'ocr']) 9 | 10 | parser.add_argument("-s", "--style-folder", default='files/style_samples/00', type=str) 11 | parser.add_argument("-t", "--text", default='That\'s one small step for man, one giant leap for mankind ΑαΒβΓγΔδ', type=str) 12 | parser.add_argument("--text-path", default=None, type=str, help='Path to text file with texts to generate') 13 | parser.add_argument("-c", "--checkpoint", default='files/vatr.pth', type=str) 14 | parser.add_argument("-o", "--output", default=None, type=str) 15 | parser.add_argument("--count", default=1000, type=int) 16 | parser.add_argument("-a", "--align", action='store_true') 17 | parser.add_argument("--at-once", action='store_true') 18 | parser.add_argument("--output-style", action='store_true') 19 | parser.add_argument("-d", "--dataset-path", type=str) 20 | parser.add_argument("--target-dataset-path", type=str, default=None) 21 | 
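# Optional charset file; when given, generate_ocr_msgpack rebuilds args.alphabet
# from the msgpack charset instead of using the default alphabet.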
parser.add_argument("--charset-file", type=str, default=None) 22 | parser.add_argument("--interp-styles", action='store_true') 23 | 24 | parser.add_argument("--test-only", action='store_true') 25 | parser.add_argument("--fake-only", action='store_true') 26 | parser.add_argument("--all-epochs", action='store_true') 27 | parser.add_argument("--long-tail", action='store_true') 28 | parser.add_argument("--msgpack", action='store_true') 29 | parser.add_argument("--reference", action='store_true') 30 | parser.add_argument("--test-set", action='store_true') 31 | 32 | parser = add_vatr_args(parser) 33 | args = parser.parse_args() 34 | 35 | if args.action == 'text': 36 | generate_text(args) 37 | elif args.action == 'authors': 38 | generate_authors(args) 39 | elif args.action == 'fid': 40 | generate_fid(args) 41 | elif args.action == 'page': 42 | generate_page(args) 43 | elif args.action == 'ocr': 44 | if args.msgpack: 45 | generate_ocr_msgpack(args) 46 | elif args.reference: 47 | generate_ocr_reference(args) 48 | else: 49 | generate_ocr(args) 50 | -------------------------------------------------------------------------------- /generate/__init__.py: -------------------------------------------------------------------------------- 1 | from generate.text import generate_text 2 | from generate.fid import generate_fid 3 | from generate.authors import generate_authors 4 | from generate.page import generate_page 5 | from generate.ocr import generate_ocr, generate_ocr_msgpack 6 | -------------------------------------------------------------------------------- /generate/authors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from data.dataset import CollectionTextDataset, TextDataset 8 | from generate.util import stack_lines 9 | from generate.writer import Writer 10 | 11 | 12 | def generate_authors(args): 13 | dataset = CollectionTextDataset( 14 | args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples, 15 | collator_resolution=args.resolution, validation=args.test_set 16 | ) 17 | 18 | args.num_writers = dataset.num_writers 19 | 20 | writer = Writer(args.checkpoint, args, only_generator=True) 21 | 22 | if args.text.endswith(".txt"): 23 | with open(args.text, 'r') as f: 24 | lines = [l.rstrip() for l in f] 25 | else: 26 | lines = [args.text] 27 | 28 | output_dir = "saved_images/author_samples/" 29 | if os.path.exists(output_dir): 30 | shutil.rmtree(output_dir) 31 | os.mkdir(output_dir) 32 | 33 | fakes, author_ids, style_images = writer.generate_authors(lines, dataset, args.align, args.at_once) 34 | 35 | for fake, author_id, style in zip(fakes, author_ids, style_images): 36 | author_dir = os.path.join(output_dir, str(author_id)) 37 | os.mkdir(author_dir) 38 | 39 | for i, line in enumerate(fake): 40 | cv2.imwrite(os.path.join(author_dir, f"line_{i}.png"), line) 41 | 42 | total = stack_lines(fake) 43 | cv2.imwrite(os.path.join(author_dir, "total.png"), total) 44 | 45 | if args.output_style: 46 | for i, image in enumerate(style): 47 | cv2.imwrite(os.path.join(author_dir, f"style_{i}.png"), image) 48 | 49 | -------------------------------------------------------------------------------- /generate/fid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | import torch.utils.data 6 | 7 | from data.dataset import FidDataset 8 | from generate.writer import Writer 
9 | 10 | 11 | def generate_fid(args): 12 | if 'iam' in args.target_dataset_path.lower(): 13 | args.num_writers = 339 14 | elif 'cvl' in args.target_dataset_path.lower(): 15 | args.num_writers = 283 16 | else: 17 | raise ValueError 18 | 19 | args.vocab_size = len(args.alphabet) 20 | 21 | dataset_train = FidDataset(base_path=args.target_dataset_path, num_examples=args.num_examples, collator_resolution=args.resolution, mode='train', style_dataset=args.dataset_path) 22 | train_loader = torch.utils.data.DataLoader( 23 | dataset_train, 24 | batch_size=args.batch_size, 25 | shuffle=False, 26 | num_workers=args.num_workers, 27 | pin_memory=True, drop_last=False, 28 | collate_fn=dataset_train.collate_fn 29 | ) 30 | 31 | dataset_test = FidDataset(base_path=args.target_dataset_path, num_examples=args.num_examples, collator_resolution=args.resolution, mode='test', style_dataset=args.dataset_path) 32 | test_loader = torch.utils.data.DataLoader( 33 | dataset_test, 34 | batch_size=args.batch_size, 35 | shuffle=False, 36 | num_workers=0, 37 | pin_memory=True, drop_last=False, 38 | collate_fn=dataset_test.collate_fn 39 | ) 40 | 41 | args.output = 'saved_images' if args.output is None else args.output 42 | args.output = Path(args.output) / 'fid' / args.target_dataset_path.split("/")[-1].replace(".pickle", "").replace("-", "") 43 | 44 | model_folder = args.checkpoint.split("/")[-2] if args.checkpoint.endswith(".pth") else args.checkpoint.split("/")[-1] 45 | model_tag = model_folder.split("-")[-1] if "-" in model_folder else "vatr" 46 | model_tag += "_" + args.dataset_path.split("/")[-1].replace(".pickle", "").replace("-", "") 47 | 48 | if not args.all_epochs: 49 | writer = Writer(args.checkpoint, args, only_generator=True) 50 | if not args.test_only: 51 | writer.generate_fid(args.output, train_loader, model_tag=model_tag, split='train', fake_only=args.fake_only, long_tail_only=args.long_tail) 52 | writer.generate_fid(args.output, test_loader, model_tag=model_tag, split='test', fake_only=args.fake_only, long_tail_only=args.long_tail) 53 | else: 54 | epochs = sorted([int(f.split("_")[0]) for f in os.listdir(args.checkpoint) if "_" in f]) 55 | generate_real = True 56 | 57 | for epoch in epochs: 58 | checkpoint_path = os.path.join(args.checkpoint, f"{str(epoch).zfill(4)}_model.pth") 59 | writer = Writer(checkpoint_path, args, only_generator=True) 60 | writer.generate_fid(args.output, test_loader, model_tag=f"{model_tag}_{epoch}", split='test', fake_only=not generate_real, long_tail_only=args.long_tail) 61 | generate_real = False 62 | 63 | print('Done') 64 | -------------------------------------------------------------------------------- /generate/ocr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import cv2 5 | import msgpack 6 | import torch 7 | 8 | from data.dataset import CollectionTextDataset, TextDataset, FolderDataset, FidDataset, get_dataset_path 9 | from generate.writer import Writer 10 | from util.text import get_generator 11 | 12 | 13 | def generate_ocr(args): 14 | """ 15 | Generate OCR training data. Words generated are from given text generator. 
16 | """ 17 | dataset = CollectionTextDataset( 18 | args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples, 19 | collator_resolution=args.resolution, validation=True 20 | ) 21 | args.num_writers = dataset.num_writers 22 | 23 | writer = Writer(args.checkpoint, args, only_generator=True) 24 | 25 | generator = get_generator(args) 26 | 27 | writer.generate_ocr(dataset, args.count, interpolate_style=args.interp_styles, output_folder=args.output, text_generator=generator) 28 | 29 | 30 | def generate_ocr_reference(args): 31 | """ 32 | Generate OCR training data. Words generated are words from given dataset. Reference words are also saved. 33 | """ 34 | dataset = CollectionTextDataset( 35 | args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples, 36 | collator_resolution=args.resolution, validation=True 37 | ) 38 | 39 | #dataset = FidDataset(get_dataset_path(args.dataset, 32, args.file_suffix, 'files'), mode='test', collator_resolution=args.resolution) 40 | 41 | args.num_writers = dataset.num_writers 42 | 43 | writer = Writer(args.checkpoint, args, only_generator=True) 44 | 45 | writer.generate_ocr(dataset, args.count, interpolate_style=args.interp_styles, output_folder=args.output, long_tail=args.long_tail) 46 | 47 | 48 | def generate_ocr_msgpack(args): 49 | """ 50 | Generate OCR dataset. Words generated are specified in given msgpack file 51 | """ 52 | dataset = FolderDataset(args.dataset_path) 53 | args.num_writers = 339 54 | 55 | if args.charset_file: 56 | charset = msgpack.load(open(args.charset_file, 'rb'), use_list=False, strict_map_key=False) 57 | args.alphabet = "".join(charset['char2idx'].keys()) 58 | 59 | writer = Writer(args.checkpoint, args, only_generator=True) 60 | 61 | lines = msgpack.load(open(args.text_path, 'rb'), use_list=False) 62 | 63 | print(f"Generating {len(lines)} to {args.output}") 64 | 65 | for i, (filename, target) in enumerate(lines): 66 | if not os.path.exists(os.path.join(args.output, filename)): 67 | style = torch.unsqueeze(dataset.sample_style()['simg'], dim=0).to(args.device) 68 | fake = writer.create_fake_sentence(style, target, at_once=True) 69 | 70 | cv2.imwrite(os.path.join(args.output, filename), fake) 71 | 72 | print(f"Done") 73 | -------------------------------------------------------------------------------- /generate/page.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from data.dataset import CollectionTextDataset, TextDataset 8 | from models.model import VATr 9 | from util.loading import load_checkpoint, load_generator 10 | 11 | 12 | def generate_page(args): 13 | args.output = 'vatr' if args.output is None else args.output 14 | 15 | args.vocab_size = len(args.alphabet) 16 | 17 | dataset = CollectionTextDataset( 18 | args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples, 19 | collator_resolution=args.resolution 20 | ) 21 | datasetval = CollectionTextDataset( 22 | args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples, 23 | collator_resolution=args.resolution, validation=True 24 | ) 25 | 26 | args.num_writers = dataset.num_writers 27 | 28 | model = VATr(args) 29 | checkpoint = torch.load(args.checkpoint, map_location=args.device) 30 | model = load_generator(model, checkpoint) 31 | 32 | train_loader = torch.utils.data.DataLoader( 33 | dataset, 34 | batch_size=8, 35 | 
shuffle=True, 36 | num_workers=0, 37 | pin_memory=True, drop_last=True, 38 | collate_fn=dataset.collate_fn) 39 | 40 | val_loader = torch.utils.data.DataLoader( 41 | datasetval, 42 | batch_size=8, 43 | shuffle=True, 44 | num_workers=0, 45 | pin_memory=True, drop_last=True, 46 | collate_fn=datasetval.collate_fn) 47 | 48 | data_train = next(iter(train_loader)) 49 | data_val = next(iter(val_loader)) 50 | 51 | model.eval() 52 | with torch.no_grad(): 53 | page = model._generate_page(data_train['simg'].to(args.device), data_val['swids']) 54 | page_val = model._generate_page(data_val['simg'].to(args.device), data_val['swids']) 55 | 56 | cv2.imwrite(os.path.join("saved_images", "pages", f"{args.output}_train.png"), (page * 255).astype(np.uint8)) 57 | cv2.imwrite(os.path.join("saved_images", "pages", f"{args.output}_val.png"), (page_val * 255).astype(np.uint8)) 58 | -------------------------------------------------------------------------------- /generate/text.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | 5 | from generate.writer import Writer 6 | 7 | 8 | def generate_text(args): 9 | if args.text_path is not None: 10 | with open(args.text_path, 'r') as f: 11 | args.text = f.read() 12 | args.text = args.text.splitlines() 13 | args.output = 'files/output.png' if args.output is None else args.output 14 | args.output = Path(args.output) 15 | args.output.parent.mkdir(parents=True, exist_ok=True) 16 | args.num_writers = 0 17 | 18 | writer = Writer(args.checkpoint, args, only_generator=True) 19 | writer.set_style_folder(args.style_folder) 20 | fakes = writer.generate(args.text, args.align) 21 | for i, fake in enumerate(fakes): 22 | dst_path = args.output.parent / (args.output.stem + f'_{i:03d}' + args.output.suffix) 23 | cv2.imwrite(str(dst_path), fake) 24 | print('Done') 25 | -------------------------------------------------------------------------------- /generate/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def stack_lines(lines: list, h_gap: int = 6): 5 | width = max([im.shape[1] for im in lines]) 6 | height = (lines[0].shape[0] + h_gap) * len(lines) 7 | 8 | result = np.ones((height, width)) * 255 9 | 10 | y_pos = 0 11 | for line in lines: 12 | result[y_pos:y_pos + line.shape[0], 0:line.shape[1]] = line 13 | y_pos += line.shape[0] + h_gap 14 | 15 | return result -------------------------------------------------------------------------------- /generate/writer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import shutil 5 | from collections import defaultdict 6 | import time 7 | from datetime import timedelta 8 | from pathlib import Path 9 | 10 | import cv2 11 | import numpy as np 12 | import torch 13 | 14 | from data.dataset import FolderDataset 15 | from models.model import VATr 16 | from util.loading import load_checkpoint, load_generator 17 | from util.misc import FakeArgs 18 | from util.text import TextGenerator 19 | from util.vision import detect_text_bounds 20 | 21 | 22 | def get_long_tail_chars(): 23 | with open(f"files/longtail.txt", 'r') as f: 24 | chars = [c.rstrip() for c in f] 25 | 26 | chars.remove('') 27 | 28 | return chars 29 | 30 | 31 | class Writer: 32 | def __init__(self, checkpoint_path, args, only_generator: bool = False): 33 | self.model = VATr(args) 34 | checkpoint = torch.load(checkpoint_path, map_location=args.device) 35 | 
load_checkpoint(self.model, checkpoint) if not only_generator else load_generator(self.model, checkpoint) 36 | self.model.eval() 37 | self.style_dataset = None 38 | 39 | def set_style_folder(self, style_folder, num_examples=15): 40 | word_lengths = None 41 | if os.path.exists(os.path.join(style_folder, "word_lengths.txt")): 42 | word_lengths = {} 43 | with open(os.path.join(style_folder, "word_lengths.txt"), 'r') as f: 44 | for line in f: 45 | word, length = line.rstrip().split(",") 46 | word_lengths[word] = int(length) 47 | 48 | self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths) 49 | 50 | @torch.no_grad() 51 | def generate(self, texts, align_words: bool = False, at_once: bool = False): 52 | if isinstance(texts, str): 53 | texts = [texts] 54 | if self.style_dataset is None: 55 | raise Exception('Style is not set') 56 | 57 | fakes = [] 58 | for i, text in enumerate(texts, 1): 59 | print(f'[{i}/{len(texts)}] Generating for text: {text}') 60 | style = self.style_dataset.sample_style() 61 | style_images = style['simg'].unsqueeze(0).to(self.model.args.device) 62 | 63 | fake = self.create_fake_sentence(style_images, text, align_words, at_once) 64 | 65 | fakes.append(fake) 66 | return fakes 67 | 68 | @torch.no_grad() 69 | def create_fake_sentence(self, style_images, text, align_words=False, at_once=False): 70 | text = "".join([c for c in text if c in self.model.args.alphabet]) 71 | 72 | text = text.split() if not at_once else [text] 73 | gap = np.ones((32, 16)) 74 | 75 | text_encode, len_text, encode_pos = self.model.netconverter.encode(text) 76 | text_encode = text_encode.to(self.model.args.device).unsqueeze(0) 77 | 78 | fake = self.model._generate_fakes(style_images, text_encode, len_text) 79 | if not at_once: 80 | if align_words: 81 | fake = self.stitch_words(fake, show_lines=False) 82 | else: 83 | fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16] 84 | else: 85 | fake = fake[0] 86 | fake = (fake * 255).astype(np.uint8) 87 | 88 | return fake 89 | 90 | @torch.no_grad() 91 | def generate_authors(self, text, dataset, align_words: bool = False, at_once: bool = False): 92 | fakes = [] 93 | author_ids = [] 94 | style = [] 95 | 96 | for item in dataset: 97 | print(f"Generating author {item['wcl']}") 98 | style_images = item['simg'].to(self.model.args.device).unsqueeze(0) 99 | 100 | generated_lines = [self.create_fake_sentence(style_images, line, align_words, at_once) for line in text] 101 | 102 | fakes.append(generated_lines) 103 | author_ids.append(item['author_id']) 104 | style.append((((item['simg'].numpy() + 1.0) / 2.0) * 255).astype(np.uint8)) 105 | 106 | return fakes, author_ids, style 107 | 108 | @torch.no_grad() 109 | def generate_characters(self, dataset, characters: str): 110 | """ 111 | Generate each of the given characters for each of the authors in the dataset. 
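        Returns a list with one entry per author in the dataset, each entry being the
        raw netG.evaluate output for the requested characters (sketch of the contract
        implied by the loop below).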
112 | """ 113 | fakes = [] 114 | 115 | text_encode, len_text, encode_pos = self.model.netconverter.encode([c for c in characters]) 116 | text_encode = text_encode.to(self.model.args.device).unsqueeze(0) 117 | 118 | for item in dataset: 119 | print(f"Generating author {item['wcl']}") 120 | style_images = item['simg'].to(self.model.args.device).unsqueeze(0) 121 | fake = self.model.netG.evaluate(style_images, text_encode) 122 | 123 | fakes.append(fake) 124 | 125 | return fakes 126 | 127 | @torch.no_grad() 128 | def generate_batch(self, style_imgs, text): 129 | """ 130 | Given a batch of style images and text, generate images using the model 131 | """ 132 | device = self.model.args.device 133 | text_encode, _, _ = self.model.netconverter.encode(text) 134 | fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device)) 135 | return fakes 136 | 137 | @torch.no_grad() 138 | def generate_ocr(self, dataset, number: int, output_folder: str = 'saved_images/ocr', interpolate_style: bool = False, text_generator: TextGenerator = None, long_tail: bool = False): 139 | def create_and_write(style, text, interpolated=False): 140 | nonlocal image_counter, annotations 141 | 142 | text_encode, len_text, encode_pos = self.model.netconverter.encode([text]) 143 | text_encode = text_encode.to(self.model.args.device) 144 | 145 | fake = self.model.netG.generate(style, text_encode) 146 | 147 | fake = (fake + 1) / 2 148 | fake = fake.cpu().numpy() 149 | fake = np.squeeze((fake * 255).astype(np.uint8)) 150 | 151 | image_filename = f"{image_counter}.png" if not interpolated else f"{image_counter}_i.png" 152 | 153 | cv2.imwrite(os.path.join(output_folder, "generated", image_filename), fake) 154 | 155 | annotations.append((image_filename, text)) 156 | 157 | image_counter += 1 158 | 159 | image_counter = 0 160 | annotations = [] 161 | previous_style = None 162 | long_tail_chars = get_long_tail_chars() 163 | 164 | os.mkdir(os.path.join(output_folder, "generated")) 165 | if text_generator is None: 166 | os.mkdir(os.path.join(output_folder, "reference")) 167 | 168 | while image_counter < number: 169 | author_index = random.randint(0, len(dataset) - 1) 170 | item = dataset[author_index] 171 | 172 | style_images = item['simg'].to(self.model.args.device).unsqueeze(0) 173 | style = self.model.netG.compute_style(style_images) 174 | 175 | if interpolate_style and previous_style is not None: 176 | factor = float(np.clip(random.gauss(0.5, 0.15), 0.0, 1.0)) 177 | intermediate_style = torch.lerp(previous_style, style, factor) 178 | text = text_generator.generate() 179 | 180 | create_and_write(intermediate_style, text, interpolated=True) 181 | 182 | if text_generator is not None: 183 | text = text_generator.generate() 184 | else: 185 | text = str(item['label'].decode()) 186 | 187 | if long_tail and not any(c in long_tail_chars for c in text): 188 | continue 189 | 190 | fake = (item['img'] + 1) / 2 191 | fake = fake.cpu().numpy() 192 | fake = np.squeeze((fake * 255).astype(np.uint8)) 193 | 194 | image_filename = f"{image_counter}.png" 195 | 196 | cv2.imwrite(os.path.join(output_folder, "reference", image_filename), fake) 197 | 198 | create_and_write(style, text) 199 | 200 | previous_style = style 201 | 202 | if text_generator is None: 203 | with open(os.path.join(output_folder, "reference", "labels.csv"), 'w') as fr: 204 | fr.write(f"filename,words\n") 205 | for annotation in annotations: 206 | fr.write(f"{annotation[0]},{annotation[1]}\n") 207 | 208 | with open(os.path.join(output_folder, "generated", "labels.csv"), 'w') as 
fg: 209 | fg.write(f"filename,words\n") 210 | for annotation in annotations: 211 | fg.write(f"{annotation[0]},{annotation[1]}\n") 212 | 213 | 214 | @staticmethod 215 | def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False): 216 | gap_width = 16 217 | 218 | bottom_lines = [] 219 | top_lines = [] 220 | for i in range(len(words)): 221 | b, t = detect_text_bounds(words[i]) 222 | bottom_lines.append(b) 223 | top_lines.append(t) 224 | if show_lines: 225 | words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0)) 226 | words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0)) 227 | 228 | bottom_lines = np.array(bottom_lines, dtype=float) 229 | 230 | if scale_words: 231 | top_lines = np.array(top_lines, dtype=float) 232 | gaps = bottom_lines - top_lines 233 | target_gap = np.mean(gaps) 234 | scales = target_gap / gaps 235 | 236 | bottom_lines *= scales 237 | top_lines *= scales 238 | words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)] 239 | 240 | highest = np.max(bottom_lines) 241 | offsets = highest - bottom_lines 242 | height = np.max(offsets + [word.shape[0] for word in words]) 243 | 244 | result = np.ones((int(height), gap_width * len(words) + sum([w.shape[1] for w in words]))) 245 | 246 | x_pos = 0 247 | for bottom_line, word in zip(bottom_lines, words): 248 | offset = int(highest - bottom_line) 249 | 250 | result[offset:offset + word.shape[0], x_pos:x_pos+word.shape[1]] = word 251 | 252 | x_pos += word.shape[1] + gap_width 253 | 254 | return result 255 | 256 | @torch.no_grad() 257 | def generate_fid(self, path, loader, model_tag, split='train', fake_only=False, long_tail_only=False): 258 | if not isinstance(path, Path): 259 | path = Path(path) 260 | 261 | path.mkdir(exist_ok=True, parents=True) 262 | 263 | appendix = f"{split}" if not long_tail_only else f"{split}_lt" 264 | 265 | real_base = path / f'real_{appendix}' 266 | fake_base = path / model_tag / f'fake_{appendix}' 267 | 268 | if real_base.exists() and not fake_only: 269 | shutil.rmtree(real_base) 270 | 271 | if fake_base.exists(): 272 | shutil.rmtree(fake_base) 273 | 274 | real_base.mkdir(exist_ok=True) 275 | fake_base.mkdir(exist_ok=True, parents=True) 276 | 277 | print('Saving images...') 278 | 279 | print(' Saving images on {}'.format(str(real_base))) 280 | print(' Saving images on {}'.format(str(fake_base))) 281 | 282 | long_tail_chars = get_long_tail_chars() 283 | counter = 0 284 | ann = defaultdict(lambda: {}) 285 | start_time = time.time() 286 | for step, data in enumerate(loader): 287 | style_images = data['simg'].to(self.model.args.device) 288 | 289 | texts = [l.decode('utf-8') for l in data['label']] 290 | texts = [t.encode('utf-8') for t in texts] 291 | eval_text_encode, eval_len_text, _ = self.model.netconverter.encode(texts) 292 | eval_text_encode = eval_text_encode.to(self.model.args.device).unsqueeze(1) 293 | 294 | vis_style = np.vstack(style_images[0].detach().cpu().numpy()) 295 | vis_style = ((vis_style + 1) / 2) * 255 296 | 297 | fakes = self.model.netG.evaluate(style_images, eval_text_encode) 298 | fake_images = torch.cat(fakes, 1).detach().cpu().numpy() 299 | real_images = data['img'].detach().cpu().numpy() 300 | writer_ids = data['wcl'].int().tolist() 301 | 302 | for i, (fake, real, wid, lb, img_id) in enumerate(zip(fake_images, real_images, writer_ids, data['label'], data['idx'])): 303 | lb = lb.decode() 304 | ann[f"{wid:03d}"][f'{img_id:05d}'] = lb 305 | img_id = f'{img_id:05d}.png' 306 | 307 | is_long_tail = 
any(c in long_tail_chars for c in lb) 308 | 309 | if long_tail_only and not is_long_tail: 310 | continue 311 | 312 | fake_img_path = fake_base / f"{wid:03d}" / img_id 313 | fake_img_path.parent.mkdir(exist_ok=True, parents=True) 314 | cv2.imwrite(str(fake_img_path), 255 * ((fake.squeeze() + 1) / 2)) 315 | 316 | if not fake_only: 317 | real_img_path = real_base / f"{wid:03d}" / img_id 318 | real_img_path.parent.mkdir(exist_ok=True, parents=True) 319 | cv2.imwrite(str(real_img_path), 255 * ((real.squeeze() + 1) / 2)) 320 | 321 | counter += 1 322 | 323 | eta = (time.time() - start_time) / (step + 1) * (len(loader) - step - 1) 324 | eta = str(timedelta(seconds=eta)) 325 | if step % 100 == 0: 326 | print(f'[{(step + 1) / len(loader) * 100:.02f}%][{counter:05d}] ETA {eta}') 327 | 328 | with open(path / 'ann.json', 'w') as f: 329 | json.dump(ann, f) 330 | -------------------------------------------------------------------------------- /models/OCR_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .networks import * 3 | 4 | 5 | class BidirectionalLSTM(nn.Module): 6 | 7 | def __init__(self, nIn, nHidden, nOut): 8 | super(BidirectionalLSTM, self).__init__() 9 | 10 | self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) 11 | self.embedding = nn.Linear(nHidden * 2, nOut) 12 | 13 | 14 | def forward(self, input): 15 | recurrent, _ = self.rnn(input) 16 | T, b, h = recurrent.size() 17 | t_rec = recurrent.view(T * b, h) 18 | 19 | output = self.embedding(t_rec) # [T * b, nOut] 20 | output = output.view(T, b, -1) 21 | 22 | return output 23 | 24 | 25 | class CRNN(nn.Module): 26 | 27 | def __init__(self, args, leakyRelu=False): 28 | super(CRNN, self).__init__() 29 | self.args = args 30 | self.name = 'OCR' 31 | self.add_noise = False 32 | self.noise_fac = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([0.2])) 33 | #assert opt.imgH % 16 == 0, 'imgH has to be a multiple of 16' 34 | 35 | ks = [3, 3, 3, 3, 3, 3, 2] 36 | ps = [1, 1, 1, 1, 1, 1, 0] 37 | ss = [1, 1, 1, 1, 1, 1, 1] 38 | nm = [64, 128, 256, 256, 512, 512, 512] 39 | 40 | cnn = nn.Sequential() 41 | nh = 256 42 | dealwith_lossnone=False # whether to replace all nan/inf in gradients to zero 43 | 44 | def convRelu(i, batchNormalization=False): 45 | nIn = 1 if i == 0 else nm[i - 1] 46 | nOut = nm[i] 47 | cnn.add_module('conv{0}'.format(i), 48 | nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])) 49 | if batchNormalization: 50 | cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut)) 51 | if leakyRelu: 52 | cnn.add_module('relu{0}'.format(i), 53 | nn.LeakyReLU(0.2, inplace=True)) 54 | else: 55 | cnn.add_module('relu{0}'.format(i), nn.ReLU(True)) 56 | 57 | convRelu(0) 58 | cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64 59 | convRelu(1) 60 | cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 61 | convRelu(2, True) 62 | convRelu(3) 63 | cnn.add_module('pooling{0}'.format(2), 64 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 65 | convRelu(4, True) 66 | if self.args.resolution==63: 67 | cnn.add_module('pooling{0}'.format(3), 68 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 69 | convRelu(5) 70 | cnn.add_module('pooling{0}'.format(4), 71 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 72 | convRelu(6, True) # 512x1x16 73 | 74 | self.cnn = cnn 75 | self.use_rnn = False 76 | if self.use_rnn: 77 | self.rnn = nn.Sequential( 78 | BidirectionalLSTM(512, nh, nh), 79 | BidirectionalLSTM(nh, nh, )) 80 | else: 81 | self.linear = 
nn.Linear(512, self.args.vocab_size) 82 | 83 | # replace all nan/inf in gradients to zero 84 | if dealwith_lossnone: 85 | self.register_backward_hook(self.backward_hook) 86 | 87 | self.device = torch.device('cuda:{}'.format(0)) 88 | self.init = 'N02' 89 | # Initialize weights 90 | 91 | self = init_weights(self, self.init) 92 | 93 | def forward(self, input): 94 | # conv features 95 | if self.add_noise: 96 | input = input + self.noise_fac.sample(input.size()).squeeze(-1).to(self.args.device) 97 | conv = self.cnn(input) 98 | b, c, h, w = conv.size() 99 | if h!=1: 100 | print('a') 101 | assert h == 1, "the height of conv must be 1" 102 | conv = conv.squeeze(2) 103 | conv = conv.permute(2, 0, 1) # [w, b, c] 104 | 105 | if self.use_rnn: 106 | # rnn features 107 | output = self.rnn(conv) 108 | else: 109 | output = self.linear(conv) 110 | return output 111 | 112 | def backward_hook(self, module, grad_input, grad_output): 113 | for g in grad_input: 114 | g[g != g] = 0 # replace all nan/inf in gradients to zero 115 | 116 | 117 | class strLabelConverter(object): 118 | """Convert between str and label. 119 | NOTE: 120 | Insert `blank` to the alphabet for CTC. 121 | Args: 122 | alphabet (str): set of the possible characters. 123 | ignore_case (bool, default=True): whether or not to ignore all of the case. 124 | """ 125 | 126 | def __init__(self, alphabet, ignore_case=False): 127 | self._ignore_case = ignore_case 128 | if self._ignore_case: 129 | alphabet = alphabet.lower() 130 | self.alphabet = alphabet + '-' # for `-1` index 131 | 132 | self.dict = {} 133 | for i, char in enumerate(alphabet): 134 | # NOTE: 0 is reserved for 'blank' required by wrap_ctc 135 | self.dict[char] = i + 1 136 | 137 | def encode(self, text): 138 | """Support batch or single str. 139 | Args: 140 | text (str or list of str): texts to convert. 141 | Returns: 142 | torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. 143 | torch.IntTensor [n]: length of each text. 144 | """ 145 | length = [] 146 | result = [] 147 | results = [] 148 | for item in text: 149 | if isinstance(item, bytes): item = item.decode('utf-8', 'strict') 150 | length.append(len(item)) 151 | for char in item: 152 | index = self.dict[char] 153 | result.append(index) 154 | results.append(result) 155 | result = [] 156 | 157 | return torch.nn.utils.rnn.pad_sequence([torch.LongTensor(text) for text in results], batch_first=True), torch.IntTensor(length), None 158 | 159 | def decode(self, t, length, raw=False): 160 | """Decode encoded texts back into strs. 161 | Args: 162 | torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. 163 | torch.IntTensor [n]: length of each text. 164 | Raises: 165 | AssertionError: when the texts and its length does not match. 166 | Returns: 167 | text (str or list of str): texts to convert. 
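        Example (illustrative, with alphabet 'abc'):
            >>> conv = strLabelConverter('abc')
            >>> enc, lengths, _ = conv.encode(['ab'])
            >>> conv.decode(enc[0], lengths)
            'ab'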
168 | """ 169 | if length.numel() == 1: 170 | length = length[0] 171 | assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), 172 | length) 173 | if raw: 174 | return ''.join([self.alphabet[i - 1] for i in t]) 175 | else: 176 | char_list = [] 177 | for i in range(length): 178 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): 179 | char_list.append(self.alphabet[t[i] - 1]) 180 | return ''.join(char_list) 181 | else: 182 | # batch mode 183 | assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format( 184 | t.numel(), length.sum()) 185 | texts = [] 186 | index = 0 187 | for i in range(length.numel()): 188 | l = length[i] 189 | texts.append( 190 | self.decode( 191 | t[index:index + l], torch.IntTensor([l]), raw=raw)) 192 | index += l 193 | return texts 194 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | """ 19 | 20 | import importlib 21 | 22 | 23 | def find_model_using_name(model_name): 24 | """Import the module "models/[model_name]_model.py". 25 | 26 | In the file, the class called DatasetNameModel() will 27 | be instantiated. It has to be a subclass of BaseModel, 28 | and it is case-insensitive. 29 | """ 30 | model_filename = "models." + model_name + "_model" 31 | modellib = importlib.import_module(model_filename) 32 | model = None 33 | target_model_name = model_name.replace('_', '') + 'model' 34 | for name, cls in modellib.__dict__.items(): 35 | if name.lower() == target_model_name.lower() \ 36 | and issubclass(cls, BaseModel): 37 | model = cls 38 | 39 | if model is None: 40 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 41 | exit(0) 42 | 43 | return model 44 | 45 | 46 | def get_option_setter(model_name): 47 | """Return the static method of the model class.""" 48 | model_class = find_model_using_name(model_name) 49 | return model_class.modify_commandline_options 50 | 51 | 52 | def create_model(opt): 53 | """Create a model given the option. 
54 | 55 | This function warps the class CustomDatasetDataLoader. 56 | This is the main interface between this package and 'train.py'/'test.py' 57 | 58 | Example: 59 | >>> from models import create_model 60 | >>> model = create_model(opt) 61 | """ 62 | model = find_model_using_name(opt.model) 63 | instance = model(opt) 64 | print("model [%s] was created" % type(instance).__name__) 65 | return instance 66 | -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class ResBlocks(nn.Module): 7 | def __init__(self, num_blocks, dim, norm, activation, pad_type): 8 | super(ResBlocks, self).__init__() 9 | self.model = [] 10 | for i in range(num_blocks): 11 | self.model += [ResBlock(dim, 12 | norm=norm, 13 | activation=activation, 14 | pad_type=pad_type)] 15 | self.model = nn.Sequential(*self.model) 16 | 17 | def forward(self, x): 18 | return self.model(x) 19 | 20 | 21 | class ResBlock(nn.Module): 22 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): 23 | super(ResBlock, self).__init__() 24 | model = [] 25 | model += [Conv2dBlock(dim, dim, 3, 1, 1, 26 | norm=norm, 27 | activation=activation, 28 | pad_type=pad_type)] 29 | model += [Conv2dBlock(dim, dim, 3, 1, 1, 30 | norm=norm, 31 | activation='none', 32 | pad_type=pad_type)] 33 | self.model = nn.Sequential(*model) 34 | 35 | def forward(self, x): 36 | residual = x 37 | out = self.model(x) 38 | out += residual 39 | return out 40 | 41 | 42 | class ActFirstResBlock(nn.Module): 43 | def __init__(self, fin, fout, fhid=None, 44 | activation='lrelu', norm='none'): 45 | super().__init__() 46 | self.learned_shortcut = (fin != fout) 47 | self.fin = fin 48 | self.fout = fout 49 | self.fhid = min(fin, fout) if fhid is None else fhid 50 | self.conv_0 = Conv2dBlock(self.fin, self.fhid, 3, 1, 51 | padding=1, pad_type='reflect', norm=norm, 52 | activation=activation, activation_first=True) 53 | self.conv_1 = Conv2dBlock(self.fhid, self.fout, 3, 1, 54 | padding=1, pad_type='reflect', norm=norm, 55 | activation=activation, activation_first=True) 56 | if self.learned_shortcut: 57 | self.conv_s = Conv2dBlock(self.fin, self.fout, 1, 1, 58 | activation='none', use_bias=False) 59 | 60 | def forward(self, x): 61 | x_s = self.conv_s(x) if self.learned_shortcut else x 62 | dx = self.conv_0(x) 63 | dx = self.conv_1(dx) 64 | out = x_s + dx 65 | return out 66 | 67 | 68 | class LinearBlock(nn.Module): 69 | def __init__(self, in_dim, out_dim, norm='none', activation='relu'): 70 | super(LinearBlock, self).__init__() 71 | use_bias = True 72 | self.fc = nn.Linear(in_dim, out_dim, bias=use_bias) 73 | 74 | # initialize normalization 75 | norm_dim = out_dim 76 | if norm == 'bn': 77 | self.norm = nn.BatchNorm1d(norm_dim) 78 | elif norm == 'in': 79 | self.norm = nn.InstanceNorm1d(norm_dim) 80 | elif norm == 'none': 81 | self.norm = None 82 | else: 83 | assert 0, "Unsupported normalization: {}".format(norm) 84 | 85 | # initialize activation 86 | if activation == 'relu': 87 | self.activation = nn.ReLU(inplace=False) 88 | elif activation == 'lrelu': 89 | self.activation = nn.LeakyReLU(0.2, inplace=False) 90 | elif activation == 'tanh': 91 | self.activation = nn.Tanh() 92 | elif activation == 'none': 93 | self.activation = None 94 | else: 95 | assert 0, "Unsupported activation: {}".format(activation) 96 | 97 | def forward(self, x): 98 | out = self.fc(x) 99 | 
if self.norm: 100 | out = self.norm(out) 101 | if self.activation: 102 | out = self.activation(out) 103 | return out 104 | 105 | 106 | class Conv2dBlock(nn.Module): 107 | def __init__(self, in_dim, out_dim, ks, st, padding=0, 108 | norm='none', activation='relu', pad_type='zero', 109 | use_bias=True, activation_first=False): 110 | super(Conv2dBlock, self).__init__() 111 | self.use_bias = use_bias 112 | self.activation_first = activation_first 113 | # initialize padding 114 | if pad_type == 'reflect': 115 | self.pad = nn.ReflectionPad2d(padding) 116 | elif pad_type == 'replicate': 117 | self.pad = nn.ReplicationPad2d(padding) 118 | elif pad_type == 'zero': 119 | self.pad = nn.ZeroPad2d(padding) 120 | else: 121 | assert 0, "Unsupported padding type: {}".format(pad_type) 122 | 123 | # initialize normalization 124 | norm_dim = out_dim 125 | if norm == 'bn': 126 | self.norm = nn.BatchNorm2d(norm_dim) 127 | elif norm == 'in': 128 | self.norm = nn.InstanceNorm2d(norm_dim) 129 | elif norm == 'adain': 130 | self.norm = AdaptiveInstanceNorm2d(norm_dim) 131 | elif norm == 'none': 132 | self.norm = None 133 | else: 134 | assert 0, "Unsupported normalization: {}".format(norm) 135 | 136 | # initialize activation 137 | if activation == 'relu': 138 | self.activation = nn.ReLU(inplace=False) 139 | elif activation == 'lrelu': 140 | self.activation = nn.LeakyReLU(0.2, inplace=False) 141 | elif activation == 'tanh': 142 | self.activation = nn.Tanh() 143 | elif activation == 'none': 144 | self.activation = None 145 | else: 146 | assert 0, "Unsupported activation: {}".format(activation) 147 | 148 | self.conv = nn.Conv2d(in_dim, out_dim, ks, st, bias=self.use_bias) 149 | 150 | def forward(self, x): 151 | if self.activation_first: 152 | if self.activation: 153 | x = self.activation(x) 154 | x = self.conv(self.pad(x)) 155 | if self.norm: 156 | x = self.norm(x) 157 | else: 158 | x = self.conv(self.pad(x)) 159 | if self.norm: 160 | x = self.norm(x) 161 | if self.activation: 162 | x = self.activation(x) 163 | return x 164 | 165 | 166 | class AdaptiveInstanceNorm2d(nn.Module): 167 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 168 | super(AdaptiveInstanceNorm2d, self).__init__() 169 | self.num_features = num_features 170 | self.eps = eps 171 | self.momentum = momentum 172 | self.weight = None 173 | self.bias = None 174 | self.register_buffer('running_mean', torch.zeros(num_features)) 175 | self.register_buffer('running_var', torch.ones(num_features)) 176 | 177 | def forward(self, x): 178 | assert self.weight is not None and \ 179 | self.bias is not None, "Please assign AdaIN weight first" 180 | b, c = x.size(0), x.size(1) 181 | running_mean = self.running_mean.repeat(b) 182 | running_var = self.running_var.repeat(b) 183 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) 184 | out = F.batch_norm( 185 | x_reshaped, running_mean, running_var, self.weight, self.bias, 186 | True, self.momentum, self.eps) 187 | return out.view(b, c, *x.size()[2:]) 188 | 189 | def __repr__(self): 190 | return self.__class__.__name__ + '(' + str(self.num_features) + ')' 191 | -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | tn_hidden_dim = 512 2 | tn_dropout = 0.1 3 | tn_nheads = 8 4 | tn_dim_feedforward = 512 5 | tn_enc_layers = 3 6 | tn_dec_layers = 3 -------------------------------------------------------------------------------- /models/inception.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 
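        Example (sketch)
        ----------------
        >>> block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        >>> model = InceptionV3([block_idx])
        >>> # model(x)[0] has shape (N, 2048, 1, 1): the pooled features used for FID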
66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = models.inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def fid_inception_v3(): 167 | """Build pretrained Inception model for FID computation 168 | 169 | The Inception model for FID computation uses a different set of weights 170 | and has a slightly different structure than torchvision's Inception. 171 | 172 | This method first constructs torchvision's Inception and then patches the 173 | necessary parts that are different in the FID Inception model. 
174 | """ 175 | inception = models.inception_v3(num_classes=1008, 176 | aux_logits=False, 177 | pretrained=False) 178 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 179 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 180 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 181 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 182 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 183 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 184 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 185 | inception.Mixed_7b = FIDInceptionE_1(1280) 186 | inception.Mixed_7c = FIDInceptionE_2(2048) 187 | 188 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 189 | inception.load_state_dict(state_dict) 190 | return inception 191 | 192 | 193 | class FIDInceptionA(models.inception.InceptionA): 194 | """InceptionA block patched for FID computation""" 195 | def __init__(self, in_channels, pool_features): 196 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 197 | 198 | def forward(self, x): 199 | branch1x1 = self.branch1x1(x) 200 | 201 | branch5x5 = self.branch5x5_1(x) 202 | branch5x5 = self.branch5x5_2(branch5x5) 203 | 204 | branch3x3dbl = self.branch3x3dbl_1(x) 205 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 206 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 207 | 208 | # Patch: Tensorflow's average pool does not use the padded zero's in 209 | # its average calculation 210 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 211 | count_include_pad=False) 212 | branch_pool = self.branch_pool(branch_pool) 213 | 214 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 215 | return torch.cat(outputs, 1) 216 | 217 | 218 | class FIDInceptionC(models.inception.InceptionC): 219 | """InceptionC block patched for FID computation""" 220 | def __init__(self, in_channels, channels_7x7): 221 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 222 | 223 | def forward(self, x): 224 | branch1x1 = self.branch1x1(x) 225 | 226 | branch7x7 = self.branch7x7_1(x) 227 | branch7x7 = self.branch7x7_2(branch7x7) 228 | branch7x7 = self.branch7x7_3(branch7x7) 229 | 230 | branch7x7dbl = self.branch7x7dbl_1(x) 231 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 232 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 233 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 234 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 235 | 236 | # Patch: Tensorflow's average pool does not use the padded zero's in 237 | # its average calculation 238 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 239 | count_include_pad=False) 240 | branch_pool = self.branch_pool(branch_pool) 241 | 242 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 243 | return torch.cat(outputs, 1) 244 | 245 | 246 | class FIDInceptionE_1(models.inception.InceptionE): 247 | """First InceptionE block patched for FID computation""" 248 | def __init__(self, in_channels): 249 | super(FIDInceptionE_1, self).__init__(in_channels) 250 | 251 | def forward(self, x): 252 | branch1x1 = self.branch1x1(x) 253 | 254 | branch3x3 = self.branch3x3_1(x) 255 | branch3x3 = [ 256 | self.branch3x3_2a(branch3x3), 257 | self.branch3x3_2b(branch3x3), 258 | ] 259 | branch3x3 = torch.cat(branch3x3, 1) 260 | 261 | branch3x3dbl = self.branch3x3dbl_1(x) 262 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 263 | branch3x3dbl = [ 264 | self.branch3x3dbl_3a(branch3x3dbl), 265 | self.branch3x3dbl_3b(branch3x3dbl), 266 | ] 
267 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 268 | 269 | # Patch: Tensorflow's average pool does not use the padded zero's in 270 | # its average calculation 271 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 272 | count_include_pad=False) 273 | branch_pool = self.branch_pool(branch_pool) 274 | 275 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 276 | return torch.cat(outputs, 1) 277 | 278 | 279 | class FIDInceptionE_2(models.inception.InceptionE): 280 | """Second InceptionE block patched for FID computation""" 281 | def __init__(self, in_channels): 282 | super(FIDInceptionE_2, self).__init__(in_channels) 283 | 284 | def forward(self, x): 285 | branch1x1 = self.branch1x1(x) 286 | 287 | branch3x3 = self.branch3x3_1(x) 288 | branch3x3 = [ 289 | self.branch3x3_2a(branch3x3), 290 | self.branch3x3_2b(branch3x3), 291 | ] 292 | branch3x3 = torch.cat(branch3x3, 1) 293 | 294 | branch3x3dbl = self.branch3x3dbl_1(x) 295 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 296 | branch3x3dbl = [ 297 | self.branch3x3dbl_3a(branch3x3dbl), 298 | self.branch3x3dbl_3b(branch3x3dbl), 299 | ] 300 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 301 | 302 | # Patch: The FID Inception model uses max pooling instead of average 303 | # pooling. This is likely an error in this specific Inception 304 | # implementation, as other Inception models use average pooling here 305 | # (which matches the description in the paper). 306 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 307 | branch_pool = self.branch_pool(branch_pool) 308 | 309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 310 | return torch.cat(outputs, 1) 311 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | from util.util import to_device, load_network 7 | 8 | ############################################################################### 9 | # Helper Functions 10 | ############################################################################### 11 | 12 | 13 | def init_weights(net, init_type='normal', init_gain=0.02): 14 | """Initialize network weights. 15 | 16 | Parameters: 17 | net (network) -- network to be initialized 18 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 19 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 20 | 21 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 22 | work better for some applications. Feel free to try yourself. 
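    For example (sketch), init_weights(net, 'N02', 0.02) draws Conv/Linear/Embedding
    weights from N(0, 0.02); passing anything other than a known init_type is treated
    as a checkpoint name and the network is loaded via load_network instead.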
23 | """ 24 | def init_func(m): # define the initialization function 25 | classname = m.__class__.__name__ 26 | if (isinstance(m, nn.Conv2d) 27 | or isinstance(m, nn.Linear) 28 | or isinstance(m, nn.Embedding)): 29 | # if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 30 | if init_type == 'N02': 31 | init.normal_(m.weight.data, 0.0, init_gain) 32 | elif init_type in ['glorot', 'xavier']: 33 | init.xavier_normal_(m.weight.data, gain=init_gain) 34 | elif init_type == 'kaiming': 35 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 36 | elif init_type == 'ortho': 37 | init.orthogonal_(m.weight.data, gain=init_gain) 38 | else: 39 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 40 | # if hasattr(m, 'bias') and m.bias is not None: 41 | # init.constant_(m.bias.data, 0.0) 42 | # elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 43 | # init.normal_(m.weight.data, 1.0, init_gain) 44 | # init.constant_(m.bias.data, 0.0) 45 | if init_type in ['N02', 'glorot', 'xavier', 'kaiming', 'ortho']: 46 | # print('initialize network with %s' % init_type) 47 | net.apply(init_func) # apply the initialization function 48 | else: 49 | # print('loading the model from %s' % init_type) 50 | net = load_network(net, init_type, 'latest') 51 | return net 52 | 53 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 54 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 55 | Parameters: 56 | net (network) -- the network to be initialized 57 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 58 | gain (float) -- scaling factor for normal, xavier and orthogonal. 59 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 60 | 61 | Return an initialized network. 62 | """ 63 | if len(gpu_ids) > 0: 64 | assert(torch.cuda.is_available()) 65 | net.to(gpu_ids[0]) 66 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 67 | init_weights(net, init_type, init_gain=init_gain) 68 | return net 69 | 70 | 71 | def get_scheduler(optimizer, opt): 72 | """Return a learning rate scheduler 73 | 74 | Parameters: 75 | optimizer -- the optimizer of the network 76 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  77 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 78 | 79 | For 'linear', we keep the same learning rate for the first epochs 80 | and linearly decay the rate to zero over the next epochs. 81 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 82 | See https://pytorch.org/docs/stable/optim.html for more details. 
83 | """ 84 | if opt.lr_policy == 'linear': 85 | def lambda_rule(epoch): 86 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 87 | return lr_l 88 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 89 | elif opt.lr_policy == 'step': 90 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 91 | elif opt.lr_policy == 'plateau': 92 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 93 | elif opt.lr_policy == 'cosine': 94 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 95 | else: 96 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 97 | return scheduler 98 | 99 | -------------------------------------------------------------------------------- /models/positional_encodings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def get_emb(sin_inp): 7 | """ 8 | Gets a base embedding for one dimension with sin and cos intertwined 9 | """ 10 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) 11 | return torch.flatten(emb, -2, -1) 12 | 13 | 14 | class PositionalEncoding1D(nn.Module): 15 | def __init__(self, channels): 16 | """ 17 | :param channels: The last dimension of the tensor you want to apply pos emb to. 18 | """ 19 | super(PositionalEncoding1D, self).__init__() 20 | self.org_channels = channels 21 | channels = int(np.ceil(channels / 2) * 2) 22 | self.channels = channels 23 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 24 | self.register_buffer("inv_freq", inv_freq, persistent=False) 25 | self.cached_penc = None 26 | 27 | def forward(self, tensor): 28 | """ 29 | :param tensor: A 3d tensor of size (batch_size, x, ch) 30 | :return: Positional Encoding Matrix of size (batch_size, x, ch) 31 | """ 32 | if len(tensor.shape) != 3: 33 | raise RuntimeError("The input tensor has to be 3d!") 34 | 35 | if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: 36 | return self.cached_penc 37 | 38 | self.cached_penc = None 39 | batch_size, x, orig_ch = tensor.shape 40 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 41 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 42 | emb_x = get_emb(sin_inp_x) 43 | emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type()) 44 | emb[:, : self.channels] = emb_x 45 | 46 | self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1) 47 | return self.cached_penc 48 | 49 | 50 | class PositionalEncodingPermute1D(nn.Module): 51 | def __init__(self, channels): 52 | """ 53 | Accepts (batchsize, ch, x) instead of (batchsize, x, ch) 54 | """ 55 | super(PositionalEncodingPermute1D, self).__init__() 56 | self.penc = PositionalEncoding1D(channels) 57 | 58 | def forward(self, tensor): 59 | tensor = tensor.permute(0, 2, 1) 60 | enc = self.penc(tensor) 61 | return enc.permute(0, 2, 1) 62 | 63 | @property 64 | def org_channels(self): 65 | return self.penc.org_channels 66 | 67 | 68 | class PositionalEncoding2D(nn.Module): 69 | def __init__(self, channels): 70 | """ 71 | :param channels: The last dimension of the tensor you want to apply pos emb to. 
72 | """ 73 | super(PositionalEncoding2D, self).__init__() 74 | self.org_channels = channels 75 | channels = int(np.ceil(channels / 4) * 2) 76 | self.channels = channels 77 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 78 | self.register_buffer("inv_freq", inv_freq) 79 | self.cached_penc = None 80 | 81 | def forward(self, tensor): 82 | """ 83 | :param tensor: A 4d tensor of size (batch_size, x, y, ch) 84 | :return: Positional Encoding Matrix of size (batch_size, x, y, ch) 85 | """ 86 | if len(tensor.shape) != 4: 87 | raise RuntimeError("The input tensor has to be 4d!") 88 | 89 | if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: 90 | return self.cached_penc 91 | 92 | self.cached_penc = None 93 | batch_size, x, y, orig_ch = tensor.shape 94 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 95 | pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) 96 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 97 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 98 | emb_x = get_emb(sin_inp_x).unsqueeze(1) 99 | emb_y = get_emb(sin_inp_y) 100 | emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type( 101 | tensor.type() 102 | ) 103 | emb[:, :, : self.channels] = emb_x 104 | emb[:, :, self.channels : 2 * self.channels] = emb_y 105 | 106 | self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1) 107 | return self.cached_penc 108 | 109 | 110 | class PositionalEncodingPermute2D(nn.Module): 111 | def __init__(self, channels): 112 | """ 113 | Accepts (batchsize, ch, x, y) instead of (batchsize, x, y, ch) 114 | """ 115 | super(PositionalEncodingPermute2D, self).__init__() 116 | self.penc = PositionalEncoding2D(channels) 117 | 118 | def forward(self, tensor): 119 | tensor = tensor.permute(0, 2, 3, 1) 120 | enc = self.penc(tensor) 121 | return enc.permute(0, 3, 1, 2) 122 | 123 | @property 124 | def org_channels(self): 125 | return self.penc.org_channels 126 | 127 | 128 | class PositionalEncoding3D(nn.Module): 129 | def __init__(self, channels): 130 | """ 131 | :param channels: The last dimension of the tensor you want to apply pos emb to. 
132 | """ 133 | super(PositionalEncoding3D, self).__init__() 134 | self.org_channels = channels 135 | channels = int(np.ceil(channels / 6) * 2) 136 | if channels % 2: 137 | channels += 1 138 | self.channels = channels 139 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 140 | self.register_buffer("inv_freq", inv_freq) 141 | self.cached_penc = None 142 | 143 | def forward(self, tensor): 144 | """ 145 | :param tensor: A 5d tensor of size (batch_size, x, y, z, ch) 146 | :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch) 147 | """ 148 | if len(tensor.shape) != 5: 149 | raise RuntimeError("The input tensor has to be 5d!") 150 | 151 | if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: 152 | return self.cached_penc 153 | 154 | self.cached_penc = None 155 | batch_size, x, y, z, orig_ch = tensor.shape 156 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 157 | pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) 158 | pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type()) 159 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 160 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 161 | sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq) 162 | emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1) 163 | emb_y = get_emb(sin_inp_y).unsqueeze(1) 164 | emb_z = get_emb(sin_inp_z) 165 | emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type( 166 | tensor.type() 167 | ) 168 | emb[:, :, :, : self.channels] = emb_x 169 | emb[:, :, :, self.channels : 2 * self.channels] = emb_y 170 | emb[:, :, :, 2 * self.channels :] = emb_z 171 | 172 | self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1) 173 | return self.cached_penc 174 | 175 | 176 | class PositionalEncodingPermute3D(nn.Module): 177 | def __init__(self, channels): 178 | """ 179 | Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch) 180 | """ 181 | super(PositionalEncodingPermute3D, self).__init__() 182 | self.penc = PositionalEncoding3D(channels) 183 | 184 | def forward(self, tensor): 185 | tensor = tensor.permute(0, 2, 3, 4, 1) 186 | enc = self.penc(tensor) 187 | return enc.permute(0, 4, 1, 2, 3) 188 | 189 | @property 190 | def org_channels(self): 191 | return self.penc.org_channels 192 | 193 | 194 | class Summer(nn.Module): 195 | def __init__(self, penc): 196 | """ 197 | :param model: The type of positional encoding to run the summer on. 198 | """ 199 | super(Summer, self).__init__() 200 | self.penc = penc 201 | 202 | def forward(self, tensor): 203 | """ 204 | :param tensor: A 3, 4 or 5d tensor that matches the model output size 205 | :return: Positional Encoding Matrix summed to the original tensor 206 | """ 207 | penc = self.penc(tensor) 208 | assert ( 209 | tensor.size() == penc.size() 210 | ), "The original tensor size {} and the positional encoding tensor size {} must match!".format( 211 | tensor.size(), penc.size() 212 | ) 213 | return tensor + penc 214 | 215 | 216 | class SparsePositionalEncoding2D(PositionalEncoding2D): 217 | def __init__(self, channels, x, y, device='cuda'): 218 | super(SparsePositionalEncoding2D, self).__init__(channels) 219 | self.y, self.x = y, x 220 | self.fake_tensor = torch.zeros((1, x, y, channels), device=device) 221 | 222 | def forward(self, coords): 223 | """ 224 | :param coords: A list of list of coordinates (((x1, y1), (x2, y22), ... ), ... 
) 225 | :return: Positional Encoding Matrix summed to the original tensor 226 | """ 227 | encodings = super().forward(self.fake_tensor) 228 | encodings = encodings.permute(0, 3, 1, 2) 229 | indices = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(c) for c in coords], batch_first=True, padding_value=-1) 230 | indices = indices.unsqueeze(0).to(self.fake_tensor.device) 231 | assert self.x == self.y 232 | indices = (indices + 0.5) / self.x * 2 - 1 233 | indices = torch.flip(indices, (-1, )) 234 | return torch.nn.functional.grid_sample(encodings, indices).squeeze().permute(2, 1, 0) 235 | 236 | # all_encodings = [] 237 | # for coords_row in coords: 238 | # res_encodings = [] 239 | # for xy in coords_row: 240 | # if xy is None: 241 | # res_encodings.append(padding) 242 | # else: 243 | # x, y = xy 244 | # res_encodings.append(encodings[x, y, :]) 245 | # all_encodings.append(res_encodings) 246 | # return torch.stack(res_encodings).to(self.fake_tensor.device) 247 | 248 | # coords = torch.Tensor(coords).to(self.fake_tensor.device).long() 249 | # assert torch.all(coords[:, 0] < self.x) 250 | # assert torch.all(coords[:, 1] < self.y) 251 | # coords = coords[:, 0] + (coords[:, 1] * self.x) 252 | # encodings = super().forward(self.fake_tensor).reshape((-1, self.org_channels)) 253 | # return encodings[coords] 254 | 255 | if __name__ == '__main__': 256 | pos = SparsePositionalEncoding2D(10, 10, 20) 257 | pos([[0, 0], [0, 9], [1, 0], [9, 15]]) 258 | -------------------------------------------------------------------------------- /models/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /models/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | # _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size']) 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input, gain=None, bias=None): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | out = F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | if gain is not None: 55 | out = out + gain 56 | if bias is not None: 57 | out = out + bias 58 | return out 59 | 60 | # Resize the input to (B, C, -1). 61 | input_shape = input.size() 62 | # print(input_shape) 63 | input = input.view(input.size(0), input.size(1), -1) 64 | 65 | # Compute the sum and square-sum. 66 | sum_size = input.size(0) * input.size(2) 67 | input_sum = _sum_ft(input) 68 | input_ssum = _sum_ft(input ** 2) 69 | # Reduce-and-broadcast the statistics. 70 | # print('it begins') 71 | if self._parallel_id == 0: 72 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 73 | else: 74 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 75 | # if self._parallel_id == 0: 76 | # # print('here') 77 | # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 78 | # else: 79 | # # print('there') 80 | # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 81 | 82 | # print('how2') 83 | # num = sum_size 84 | # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu()))) 85 | # Fix the graph 86 | # sum = (sum.detach() - input_sum.detach()) + input_sum 87 | # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum 88 | 89 | # mean = sum / num 90 | # var = ssum / num - mean ** 2 91 | # # var = (ssum - mean * sum) / num 92 | # inv_std = torch.rsqrt(var + self.eps) 93 | 94 | # Compute the output. 
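# The three branches below compute, respectively:
#   gain/bias provided:  out = (x - mean) * (inv_std * gain) + bias   (externally supplied scale/shift)
#   self.affine:         out = (x - mean) * (inv_std * weight) + bias (standard affine batch norm)
#   otherwise:           out = (x - mean) * inv_std                   (plain normalization)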
95 | if gain is not None: 96 | # print('gaining') 97 | # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1) 98 | # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1) 99 | # output = input * scale - shift 100 | output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1) 101 | elif self.affine: 102 | # MJY:: Fuse the multiplication for speed. 103 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 104 | else: 105 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 106 | 107 | # Reshape it. 108 | return output.view(input_shape) 109 | 110 | def __data_parallel_replicate__(self, ctx, copy_id): 111 | self._is_parallel = True 112 | self._parallel_id = copy_id 113 | 114 | # parallel_id == 0 means master device. 115 | if self._parallel_id == 0: 116 | ctx.sync_master = self._sync_master 117 | else: 118 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 119 | 120 | def _data_parallel_master(self, intermediates): 121 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 122 | 123 | # Always using same "device order" makes the ReduceAdd operation faster. 124 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 125 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 126 | 127 | to_reduce = [i[1][:2] for i in intermediates] 128 | to_reduce = [j for i in to_reduce for j in i] # flatten 129 | target_gpus = [i[1].sum.get_device() for i in intermediates] 130 | 131 | sum_size = sum([i[1].sum_size for i in intermediates]) 132 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 133 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 134 | 135 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 136 | # print('a') 137 | # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size) 138 | # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device)) 139 | # print('b') 140 | outputs = [] 141 | for i, rec in enumerate(intermediates): 142 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 143 | # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3]))) 144 | 145 | return outputs 146 | 147 | def _compute_mean_std(self, sum_, ssum, size): 148 | """Compute the mean and standard-deviation with sum and square-sum. This method 149 | also maintains the moving average on the master device.""" 150 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 151 | mean = sum_ / size 152 | sumvar = ssum - sum_ * mean 153 | unbias_var = sumvar / (size - 1) 154 | bias_var = sumvar / size 155 | 156 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 157 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 158 | return mean, torch.rsqrt(bias_var + self.eps) 159 | # return mean, bias_var.clamp(self.eps) ** -0.5 160 | 161 | 162 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 163 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 164 | mini-batch. 165 | 166 | .. math:: 167 | 168 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 169 | 170 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 171 | standard-deviation are reduced across all devices during training. 
172 | 173 | For example, when one uses `nn.DataParallel` to wrap the network during 174 | training, PyTorch's implementation normalize the tensor on each device using 175 | the statistics only on that device, which accelerated the computation and 176 | is also easy to implement, but the statistics might be inaccurate. 177 | Instead, in this synchronized version, the statistics will be computed 178 | over all training samples distributed on multiple devices. 179 | 180 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 181 | as the built-in PyTorch implementation. 182 | 183 | The mean and standard-deviation are calculated per-dimension over 184 | the mini-batches and gamma and beta are learnable parameter vectors 185 | of size C (where C is the input size). 186 | 187 | During training, this layer keeps a running estimate of its computed mean 188 | and variance. The running sum is kept with a default momentum of 0.1. 189 | 190 | During evaluation, this running mean/variance is used for normalization. 191 | 192 | Because the BatchNorm is done over the `C` dimension, computing statistics 193 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 194 | 195 | Args: 196 | num_features: num_features from an expected input of size 197 | `batch_size x num_features [x width]` 198 | eps: a value added to the denominator for numerical stability. 199 | Default: 1e-5 200 | momentum: the value used for the running_mean and running_var 201 | computation. Default: 0.1 202 | affine: a boolean value that when set to ``True``, gives the layer learnable 203 | affine parameters. Default: ``True`` 204 | 205 | Shape: 206 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 207 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 208 | 209 | Examples: 210 | >>> # With Learnable Parameters 211 | >>> m = SynchronizedBatchNorm1d(100) 212 | >>> # Without Learnable Parameters 213 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 214 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 215 | >>> output = m(input) 216 | """ 217 | 218 | def _check_input_dim(self, input): 219 | if input.dim() != 2 and input.dim() != 3: 220 | raise ValueError('expected 2D or 3D input (got {}D input)' 221 | .format(input.dim())) 222 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 223 | 224 | 225 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 226 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 227 | of 3d inputs 228 | 229 | .. math:: 230 | 231 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 232 | 233 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 234 | standard-deviation are reduced across all devices during training. 235 | 236 | For example, when one uses `nn.DataParallel` to wrap the network during 237 | training, PyTorch's implementation normalize the tensor on each device using 238 | the statistics only on that device, which accelerated the computation and 239 | is also easy to implement, but the statistics might be inaccurate. 240 | Instead, in this synchronized version, the statistics will be computed 241 | over all training samples distributed on multiple devices. 242 | 243 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 244 | as the built-in PyTorch implementation. 
245 | 246 | The mean and standard-deviation are calculated per-dimension over 247 | the mini-batches and gamma and beta are learnable parameter vectors 248 | of size C (where C is the input size). 249 | 250 | During training, this layer keeps a running estimate of its computed mean 251 | and variance. The running sum is kept with a default momentum of 0.1. 252 | 253 | During evaluation, this running mean/variance is used for normalization. 254 | 255 | Because the BatchNorm is done over the `C` dimension, computing statistics 256 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 257 | 258 | Args: 259 | num_features: num_features from an expected input of 260 | size batch_size x num_features x height x width 261 | eps: a value added to the denominator for numerical stability. 262 | Default: 1e-5 263 | momentum: the value used for the running_mean and running_var 264 | computation. Default: 0.1 265 | affine: a boolean value that when set to ``True``, gives the layer learnable 266 | affine parameters. Default: ``True`` 267 | 268 | Shape: 269 | - Input: :math:`(N, C, H, W)` 270 | - Output: :math:`(N, C, H, W)` (same shape as input) 271 | 272 | Examples: 273 | >>> # With Learnable Parameters 274 | >>> m = SynchronizedBatchNorm2d(100) 275 | >>> # Without Learnable Parameters 276 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 277 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 278 | >>> output = m(input) 279 | """ 280 | 281 | def _check_input_dim(self, input): 282 | if input.dim() != 4: 283 | raise ValueError('expected 4D input (got {}D input)' 284 | .format(input.dim())) 285 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 286 | 287 | 288 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 289 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 290 | of 4d inputs 291 | 292 | .. math:: 293 | 294 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 295 | 296 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 297 | standard-deviation are reduced across all devices during training. 298 | 299 | For example, when one uses `nn.DataParallel` to wrap the network during 300 | training, PyTorch's implementation normalize the tensor on each device using 301 | the statistics only on that device, which accelerated the computation and 302 | is also easy to implement, but the statistics might be inaccurate. 303 | Instead, in this synchronized version, the statistics will be computed 304 | over all training samples distributed on multiple devices. 305 | 306 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 307 | as the built-in PyTorch implementation. 308 | 309 | The mean and standard-deviation are calculated per-dimension over 310 | the mini-batches and gamma and beta are learnable parameter vectors 311 | of size C (where C is the input size). 312 | 313 | During training, this layer keeps a running estimate of its computed mean 314 | and variance. The running sum is kept with a default momentum of 0.1. 315 | 316 | During evaluation, this running mean/variance is used for normalization. 
317 | 318 | Because the BatchNorm is done over the `C` dimension, computing statistics 319 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 320 | or Spatio-temporal BatchNorm 321 | 322 | Args: 323 | num_features: num_features from an expected input of 324 | size batch_size x num_features x depth x height x width 325 | eps: a value added to the denominator for numerical stability. 326 | Default: 1e-5 327 | momentum: the value used for the running_mean and running_var 328 | computation. Default: 0.1 329 | affine: a boolean value that when set to ``True``, gives the layer learnable 330 | affine parameters. Default: ``True`` 331 | 332 | Shape: 333 | - Input: :math:`(N, C, D, H, W)` 334 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 335 | 336 | Examples: 337 | >>> # With Learnable Parameters 338 | >>> m = SynchronizedBatchNorm3d(100) 339 | >>> # Without Learnable Parameters 340 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 341 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 342 | >>> output = m(input) 343 | """ 344 | 345 | def _check_input_dim(self, input): 346 | if input.dim() != 5: 347 | raise ValueError('expected 5D input (got {}D input)' 348 | .format(input.dim())) 349 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /models/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNormReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /models/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 
58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /models/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
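# Note on how this file ties into batchnorm.py: DataParallelWithCallback (and the
# patched replicate function below) invoke __data_parallel_replicate__(ctx, copy_id)
# on every replicated sub-module, so the master copy (copy_id == 0) publishes its
# SyncMaster through the shared context and the remaining copies register themselves
# as slaves before the first forward pass.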
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /models/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | DETR Transformer class. 4 | 5 | Copy-paste from torch.nn.Transformer with modifications: 6 | * positional encodings are passed in MHattention 7 | * extra LN at the end of encoder is removed 8 | * decoder returns a stack of activations from all decoding layers 9 | """ 10 | import copy 11 | from typing import Optional, List 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn, Tensor 16 | 17 | 18 | class Transformer(nn.Module): 19 | 20 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 21 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 22 | activation="relu", normalize_before=False, 23 | return_intermediate_dec=False): 24 | super().__init__() 25 | 26 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 27 | dropout, activation, normalize_before) 28 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 29 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 30 | 31 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 32 | dropout, activation, normalize_before) 33 | decoder_norm = nn.LayerNorm(d_model) 34 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 35 | return_intermediate=return_intermediate_dec) 36 | 37 | self._reset_parameters() 38 | 39 | self.d_model = d_model 40 | self.nhead = nhead 41 | 42 | def _reset_parameters(self): 43 | for p in self.parameters(): 44 | if p.dim() > 1: 45 | nn.init.xavier_uniform_(p) 46 | 47 | def forward(self, src, query_embed, y_ind): 48 | # flatten NxCxHxW to HWxNxC 49 | bs, c, h, w = src.shape 50 | src = src.flatten(2).permute(2, 0, 1) 51 | 52 | y_emb = query_embed[y_ind].permute(1,0,2) 53 | 54 | tgt = torch.zeros_like(y_emb) 55 | memory = self.encoder(src) 56 | hs = self.decoder(tgt, memory, query_pos=y_emb) 57 | 58 | return torch.cat([hs.transpose(1, 2)[-1], y_emb.permute(1,0,2)], -1) 59 | 60 | 61 | class TransformerEncoder(nn.Module): 62 | 63 | def __init__(self, encoder_layer, num_layers, norm=None): 64 | super().__init__() 65 | self.layers = _get_clones(encoder_layer, num_layers) 66 | self.num_layers = num_layers 67 | self.norm = norm 68 | 69 | def forward(self, src, 70 | mask: Optional[Tensor] = None, 71 | src_key_padding_mask: Optional[Tensor] = None, 72 | pos: Optional[Tensor] = None): 73 | output = src 74 | 75 | for layer in self.layers: 76 | output = layer(output, src_mask=mask, 77 | src_key_padding_mask=src_key_padding_mask, pos=pos) 78 | 79 | if self.norm is not None: 80 | output = self.norm(output) 81 | 82 | return output 83 | 84 | 85 | class TransformerDecoder(nn.Module): 86 | 87 | def __init__(self, decoder_layer, num_layers, norm=None, 
return_intermediate=False): 88 | super().__init__() 89 | self.layers = _get_clones(decoder_layer, num_layers) 90 | self.num_layers = num_layers 91 | self.norm = norm 92 | self.return_intermediate = return_intermediate 93 | 94 | def forward(self, tgt, memory, 95 | tgt_mask: Optional[Tensor] = None, 96 | memory_mask: Optional[Tensor] = None, 97 | tgt_key_padding_mask: Optional[Tensor] = None, 98 | memory_key_padding_mask: Optional[Tensor] = None, 99 | pos: Optional[Tensor] = None, 100 | query_pos: Optional[Tensor] = None): 101 | output = tgt 102 | 103 | intermediate = [] 104 | 105 | for layer in self.layers: 106 | output = layer(output, memory, tgt_mask=tgt_mask, 107 | memory_mask=memory_mask, 108 | tgt_key_padding_mask=tgt_key_padding_mask, 109 | memory_key_padding_mask=memory_key_padding_mask, 110 | pos=pos, query_pos=query_pos) 111 | if self.return_intermediate: 112 | intermediate.append(self.norm(output)) 113 | 114 | if self.norm is not None: 115 | output = self.norm(output) 116 | if self.return_intermediate: 117 | intermediate.pop() 118 | intermediate.append(output) 119 | 120 | if self.return_intermediate: 121 | return torch.stack(intermediate) 122 | 123 | return output.unsqueeze(0) 124 | 125 | 126 | class TransformerEncoderLayer(nn.Module): 127 | 128 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 129 | activation="relu", normalize_before=False): 130 | super().__init__() 131 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 132 | # Implementation of Feedforward model 133 | self.linear1 = nn.Linear(d_model, dim_feedforward) 134 | self.dropout = nn.Dropout(dropout) 135 | self.linear2 = nn.Linear(dim_feedforward, d_model) 136 | 137 | self.norm1 = nn.LayerNorm(d_model) 138 | self.norm2 = nn.LayerNorm(d_model) 139 | self.dropout1 = nn.Dropout(dropout) 140 | self.dropout2 = nn.Dropout(dropout) 141 | 142 | self.activation = _get_activation_fn(activation) 143 | self.normalize_before = normalize_before 144 | 145 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 146 | return tensor if pos is None else tensor + pos 147 | 148 | def forward_post(self, 149 | src, 150 | src_mask: Optional[Tensor] = None, 151 | src_key_padding_mask: Optional[Tensor] = None, 152 | pos: Optional[Tensor] = None): 153 | q = k = self.with_pos_embed(src, pos) 154 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 155 | key_padding_mask=src_key_padding_mask)[0] 156 | src = src + self.dropout1(src2) 157 | src = self.norm1(src) 158 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 159 | src = src + self.dropout2(src2) 160 | src = self.norm2(src) 161 | return src 162 | 163 | def forward_pre(self, src, 164 | src_mask: Optional[Tensor] = None, 165 | src_key_padding_mask: Optional[Tensor] = None, 166 | pos: Optional[Tensor] = None): 167 | src2 = self.norm1(src) 168 | q = k = self.with_pos_embed(src2, pos) 169 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 170 | key_padding_mask=src_key_padding_mask)[0] 171 | src = src + self.dropout1(src2) 172 | src2 = self.norm2(src) 173 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 174 | src = src + self.dropout2(src2) 175 | return src 176 | 177 | def forward(self, src, 178 | src_mask: Optional[Tensor] = None, 179 | src_key_padding_mask: Optional[Tensor] = None, 180 | pos: Optional[Tensor] = None): 181 | if self.normalize_before: 182 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 183 | return self.forward_post(src, src_mask, src_key_padding_mask, 
pos) 184 | 185 | 186 | class TransformerDecoderLayer(nn.Module): 187 | 188 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 189 | activation="relu", normalize_before=False): 190 | super().__init__() 191 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 192 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 193 | # Implementation of Feedforward model 194 | self.linear1 = nn.Linear(d_model, dim_feedforward) 195 | self.dropout = nn.Dropout(dropout) 196 | self.linear2 = nn.Linear(dim_feedforward, d_model) 197 | 198 | self.norm1 = nn.LayerNorm(d_model) 199 | self.norm2 = nn.LayerNorm(d_model) 200 | self.norm3 = nn.LayerNorm(d_model) 201 | self.dropout1 = nn.Dropout(dropout) 202 | self.dropout2 = nn.Dropout(dropout) 203 | self.dropout3 = nn.Dropout(dropout) 204 | 205 | self.activation = _get_activation_fn(activation) 206 | self.normalize_before = normalize_before 207 | 208 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 209 | return tensor if pos is None else tensor + pos 210 | 211 | def forward_post(self, tgt, memory, 212 | tgt_mask: Optional[Tensor] = None, 213 | memory_mask: Optional[Tensor] = None, 214 | tgt_key_padding_mask: Optional[Tensor] = None, 215 | memory_key_padding_mask: Optional[Tensor] = None, 216 | pos: Optional[Tensor] = None, 217 | query_pos: Optional[Tensor] = None): 218 | q = k = self.with_pos_embed(tgt, query_pos) 219 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 220 | key_padding_mask=tgt_key_padding_mask)[0] 221 | tgt = tgt + self.dropout1(tgt2) 222 | tgt = self.norm1(tgt) 223 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 224 | key=self.with_pos_embed(memory, pos), 225 | value=memory, attn_mask=memory_mask, 226 | key_padding_mask=memory_key_padding_mask)[0] 227 | tgt = tgt + self.dropout2(tgt2) 228 | tgt = self.norm2(tgt) 229 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 230 | tgt = tgt + self.dropout3(tgt2) 231 | tgt = self.norm3(tgt) 232 | return tgt 233 | 234 | def forward_pre(self, tgt, memory, 235 | tgt_mask: Optional[Tensor] = None, 236 | memory_mask: Optional[Tensor] = None, 237 | tgt_key_padding_mask: Optional[Tensor] = None, 238 | memory_key_padding_mask: Optional[Tensor] = None, 239 | pos: Optional[Tensor] = None, 240 | query_pos: Optional[Tensor] = None): 241 | tgt2 = self.norm1(tgt) 242 | q = k = self.with_pos_embed(tgt2, query_pos) 243 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 244 | key_padding_mask=tgt_key_padding_mask)[0] 245 | tgt = tgt + self.dropout1(tgt2) 246 | tgt2 = self.norm2(tgt) 247 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 248 | key=self.with_pos_embed(memory, pos), 249 | value=memory, attn_mask=memory_mask, 250 | key_padding_mask=memory_key_padding_mask)[0] 251 | tgt = tgt + self.dropout2(tgt2) 252 | tgt2 = self.norm3(tgt) 253 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 254 | tgt = tgt + self.dropout3(tgt2) 255 | return tgt 256 | 257 | def forward(self, tgt, memory, 258 | tgt_mask: Optional[Tensor] = None, 259 | memory_mask: Optional[Tensor] = None, 260 | tgt_key_padding_mask: Optional[Tensor] = None, 261 | memory_key_padding_mask: Optional[Tensor] = None, 262 | pos: Optional[Tensor] = None, 263 | query_pos: Optional[Tensor] = None): 264 | if self.normalize_before: 265 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 266 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 267 | return 
self.forward_post(tgt, memory, tgt_mask, memory_mask, 268 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 269 | 270 | 271 | def _get_clones(module, N): 272 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 273 | 274 | 275 | def build_transformer(args): 276 | return Transformer( 277 | d_model=args.hidden_dim, 278 | dropout=args.dropout, 279 | nhead=args.nheads, 280 | dim_feedforward=args.dim_feedforward, 281 | num_encoder_layers=args.enc_layers, 282 | num_decoder_layers=args.dec_layers, 283 | normalize_before=args.pre_norm, 284 | return_intermediate_dec=True, 285 | ) 286 | 287 | 288 | def _get_activation_fn(activation): 289 | """Return an activation function given a string""" 290 | if activation == "relu": 291 | return F.relu 292 | if activation == "gelu": 293 | return F.gelu 294 | if activation == "glu": 295 | return F.glu 296 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 297 | -------------------------------------------------------------------------------- /models/unifont_module.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import pickle 8 | import numpy as np 9 | 10 | 11 | def gauss(x, sigma=1.0): 12 | return (1.0 / math.sqrt(2.0 * math.pi) * sigma) * math.exp(-x**2 / (2.0 * sigma**2)) 13 | 14 | 15 | class UnifontModule(torch.nn.Module): 16 | def __init__(self, out_dim, alphabet, device='cuda', input_type='unifont', projection='linear'): 17 | super(UnifontModule, self).__init__() 18 | self.projection_type = projection 19 | self.device = device 20 | self.alphabet = alphabet 21 | self.symbols = self.get_symbols('unifont') 22 | self.symbols_repr = self.get_symbols(input_type) 23 | 24 | if projection == 'linear': 25 | self.linear = torch.nn.Linear(self.symbols_repr.shape[1], out_dim) 26 | else: 27 | self.linear = torch.nn.Identity() 28 | 29 | def get_symbols(self, input_type): 30 | with open(f"files/{input_type}.pickle", "rb") as f: 31 | symbols = pickle.load(f) 32 | 33 | all_symbols = {sym['idx'][0]: sym['mat'].astype(np.float32) for sym in symbols} 34 | symbols = [] 35 | for char in self.alphabet: 36 | im = all_symbols[ord(char)] 37 | im = im.flatten() 38 | symbols.append(im) 39 | 40 | symbols.insert(0, np.zeros_like(symbols[0])) 41 | symbols = np.stack(symbols) 42 | return torch.from_numpy(symbols).float().to(self.device) 43 | 44 | def forward(self, QR): 45 | if self.projection_type != 'cnn': 46 | return self.linear(self.symbols_repr[QR]) 47 | else: 48 | result = [] 49 | symbols = self.symbols_repr[QR] 50 | for b in range(QR.size(0)): 51 | result.append(self.linear(torch.unsqueeze(symbols[b], dim=1))) 52 | 53 | return torch.stack(result) 54 | 55 | 56 | class LearnableModule(torch.nn.Module): 57 | def __init__(self, out_dim, device='cuda'): 58 | super(LearnableModule, self).__init__() 59 | self.device = device 60 | self.param = torch.nn.Parameter(torch.zeros(1, 1, 256, device=device)) 61 | self.linear = torch.nn.Linear(256, out_dim) 62 | 63 | def forward(self, QR): 64 | return self.linear(self.param).repeat((QR.shape[0], 1, 1)) 65 | 66 | 67 | if __name__ == "__main__": 68 | module = UnifontModule(512, "bluuuuurp", 'cpu', projection='cnn') -------------------------------------------------------------------------------- /mytext.txt: -------------------------------------------------------------------------------- 1 | The well-known story I told at the conferences about hypochondria in 
L.A./California, New York,...and Richmond went as follows: It amused people who knew Tommy to hear this; however, it distressed Suzi when Tommy (1982--2019) asked, "How can I find out who yelled*, 'FIRE!' in the theater?" ---"ZOE DESCANEL." #PANIC. α + β is 374 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wandb==0.14.0 2 | opencv-python==4.4.* 3 | scipy==1.10.1 4 | gdown==4.6.6 5 | matplotlib 6 | streamlit -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import math 4 | import time 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | import wandb 10 | 11 | from data.dataset import TextDataset, CollectionTextDataset 12 | from models.model import VATr 13 | from util.misc import EpochLossTracker, add_vatr_args, LinearScheduler 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--resume", action='store_true') 19 | parser = add_vatr_args(parser) 20 | 21 | args = parser.parse_args() 22 | 23 | rSeed(args.seed) 24 | dataset = CollectionTextDataset( 25 | args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples, 26 | collator_resolution=args.resolution, min_virtual_size=339, validation=False, debug=False, height=args.img_height 27 | ) 28 | datasetval = CollectionTextDataset( 29 | args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples, 30 | collator_resolution=args.resolution, min_virtual_size=161, validation=True, height=args.img_height 31 | ) 32 | 33 | args.num_writers = dataset.num_writers 34 | 35 | if args.dataset == 'IAM' or args.dataset == 'CVL': 36 | args.alphabet = 'Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%' 37 | else: 38 | args.alphabet = ''.join(sorted(set(dataset.alphabet + datasetval.alphabet))) 39 | args.special_alphabet = ''.join(c for c in args.special_alphabet if c not in dataset.alphabet) 40 | 41 | args.exp_name = f"{args.dataset}-{args.num_writers}-{args.num_examples}-LR{args.g_lr}-bs{args.batch_size}-{args.tag}" 42 | 43 | config = {k: v for k, v in args.__dict__.items() if isinstance(v, (bool, int, str, float))} 44 | args.wandb = args.wandb and (not torch.cuda.is_available() or torch.cuda.get_device_name(0) != 'Tesla K80') 45 | wandb_id = wandb.util.generate_id() 46 | 47 | MODEL_PATH = os.path.join(args.save_model_path, args.exp_name) 48 | os.makedirs(MODEL_PATH, exist_ok=True) 49 | 50 | train_loader = torch.utils.data.DataLoader( 51 | dataset, 52 | batch_size=args.batch_size, 53 | shuffle=True, 54 | num_workers=args.num_workers, 55 | pin_memory=True, drop_last=True, 56 | collate_fn=dataset.collate_fn) 57 | 58 | val_loader = torch.utils.data.DataLoader( 59 | datasetval, 60 | batch_size=args.batch_size, 61 | shuffle=True, 62 | num_workers=args.num_workers, 63 | pin_memory=True, drop_last=True, 64 | collate_fn=datasetval.collate_fn) 65 | 66 | model = VATr(args) 67 | start_epoch = 0 68 | 69 | del config['alphabet'] 70 | del config['special_alphabet'] 71 | 72 | wandb_params = { 73 | 'project': 'VATr', 74 | 'config': config, 75 | 'name': args.exp_name, 76 | 'id': wandb_id 77 | } 78 | 79 | checkpoint_path = os.path.join(MODEL_PATH, 'model.pth') 80 | 81 | loss_tracker = EpochLossTracker() 82 | 83 | if args.resume and 
os.path.exists(checkpoint_path): 84 | checkpoint = torch.load(checkpoint_path, map_location=args.device) 85 | model.load_state_dict(checkpoint['model']) 86 | start_epoch = checkpoint['epoch'] 87 | wandb_params['id'] = checkpoint['wandb_id'] 88 | wandb_params['resume'] = True 89 | print(checkpoint_path + ' : Model loaded Successfully') 90 | elif args.resume: 91 | raise FileNotFoundError(f'No model found at {checkpoint_path}') 92 | else: 93 | if args.feat_model_path is not None and args.feat_model_path.lower() != 'none': 94 | print('Loading...', args.feat_model_path) 95 | assert os.path.exists(args.feat_model_path) 96 | checkpoint = torch.load(args.feat_model_path, map_location=args.device) 97 | checkpoint['model']['conv1.weight'] = checkpoint['model']['conv1.weight'].mean(1).unsqueeze(1) 98 | del checkpoint['model']['fc.weight'] 99 | del checkpoint['model']['fc.bias'] 100 | miss, unexp = model.netG.Feat_Encoder.load_state_dict(checkpoint['model'], strict=False) 101 | if not os.path.isdir(MODEL_PATH): 102 | os.mkdir(MODEL_PATH) 103 | else: 104 | print(f'WARNING: No resume of Resnet-18, starting from scratch') 105 | 106 | if args.wandb: 107 | wandb.init(**wandb_params) 108 | wandb.watch(model) 109 | 110 | print(f"Starting training") 111 | for epoch in range(start_epoch, args.epochs): 112 | start_time = time.time() 113 | log_time = time.time() 114 | loss_tracker.reset() 115 | model.d_acc.update(0.0) 116 | if args.text_augment_strength > 0: 117 | model.set_text_aug_strength(args.text_augment_strength) 118 | 119 | for i, data in enumerate(train_loader): 120 | model.update_parameters(epoch) 121 | model._set_input(data) 122 | 123 | model.optimize_G_only() 124 | model.optimize_G_step() 125 | 126 | model.optimize_D_OCR() 127 | model.optimize_D_OCR_step() 128 | 129 | model.optimize_G_WL() 130 | model.optimize_G_step() 131 | 132 | model.optimize_D_WL() 133 | model.optimize_D_WL_step() 134 | 135 | if time.time() - log_time > 10: 136 | print( 137 | f'Epoch {epoch} {i / len(train_loader) * 100:.02f}% running, current time: {time.time() - start_time:.2f} s') 138 | log_time = time.time() 139 | 140 | batch_losses = model.get_current_losses() 141 | batch_losses['d_acc'] = model.d_acc.avg 142 | loss_tracker.add_batch(batch_losses) 143 | 144 | end_time = time.time() 145 | data_val = next(iter(val_loader)) 146 | losses = loss_tracker.get_epoch_loss() 147 | page = model._generate_page(model.sdata, model.input['swids']) 148 | page_val = model._generate_page(data_val['simg'].to(args.device), data_val['swids']) 149 | 150 | d_train, d_val, d_fake = model.compute_d_stats(train_loader, val_loader) 151 | 152 | if args.wandb: 153 | wandb.log({ 154 | 'loss-G': losses['G'], 155 | 'loss-D': losses['D'], 156 | 'loss-Dfake': losses['Dfake'], 157 | 'loss-Dreal': losses['Dreal'], 158 | 'loss-OCR_fake': losses['OCR_fake'], 159 | 'loss-OCR_real': losses['OCR_real'], 160 | 'loss-w_fake': losses['w_fake'], 161 | 'loss-w_real': losses['w_real'], 162 | 'd_acc': losses['d_acc'], 163 | 'd-rv': (d_train - d_val) / (d_train - d_fake), 164 | 'd-fake': d_fake, 165 | 'd-real': d_train, 166 | 'd-val': d_val, 167 | 'l_cycle': losses['cycle'], 168 | 'epoch': epoch, 169 | 'timeperepoch': end_time - start_time, 170 | 'result': [wandb.Image(page, caption="page"), wandb.Image(page_val, caption="page_val")], 171 | 'd-crop-size': model.netD.augmenter.get_current_width() if model.netD.crop else 0 172 | }) 173 | 174 | print({'EPOCH': epoch, 'TIME': end_time - start_time, 'LOSSES': losses}) 175 | print(f"Text sample: {model.get_text_sample(10)}") 176 
| 177 | checkpoint = { 178 | 'model': model.state_dict(), 179 | 'wandb_id': wandb_id, 180 | 'epoch': epoch 181 | } 182 | if epoch % args.save_model == 0: 183 | torch.save(checkpoint, os.path.join(MODEL_PATH, 'model.pth')) 184 | 185 | if epoch % args.save_model_history == 0: 186 | torch.save(checkpoint, os.path.join(MODEL_PATH, f'{epoch:04d}_model.pth')) 187 | 188 | 189 | def rSeed(sd): 190 | random.seed(sd) 191 | np.random.seed(sd) 192 | torch.manual_seed(sd) 193 | torch.cuda.manual_seed(sd) 194 | 195 | 196 | if __name__ == "__main__": 197 | print("Training Model") 198 | main() 199 | wandb.finish() 200 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /util/augmentations.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import pickle 3 | import random 4 | from abc import ABC, abstractmethod 5 | 6 | import cv2 7 | import numpy as np 8 | import math 9 | import torch 10 | import torchvision.transforms 11 | import torchvision.transforms.functional as F 12 | from matplotlib import pyplot as plt 13 | 14 | from data.dataset import CollectionTextDataset, TextDataset 15 | 16 | 17 | def to_opencv(batch: torch.Tensor): 18 | images = [] 19 | 20 | for image in batch: 21 | image = image.detach().cpu().numpy() 22 | image = (image + 1.0) / 2.0 23 | images.append(np.squeeze(image)) 24 | 25 | return images 26 | 27 | 28 | class RandomMorphological(torch.nn.Module): 29 | def __init__(self, max_size: 5, max_iterations = 1, operation = cv2.MORPH_ERODE): 30 | super().__init__() 31 | self.elements = [cv2.MORPH_RECT, cv2.MORPH_ELLIPSE] 32 | self.max_size = max_size 33 | self.max_iterations = max_iterations 34 | self.operation = operation 35 | 36 | def forward(self, x): 37 | device = x.device 38 | 39 | images = to_opencv(x) 40 | 41 | result = [] 42 | 43 | size = random.randint(1, self.max_size) 44 | kernel = cv2.getStructuringElement(random.choice(self.elements), (size, size)) 45 | 46 | for image in images: 47 | image = cv2.resize(image, (image.shape[1] * 2, image.shape[0] * 2)) 48 | morphed = cv2.morphologyEx(image, op=self.operation, kernel=kernel, iterations=random.randint(1, self.max_iterations)) 49 | morphed = cv2.resize(morphed, (image.shape[1] // 2, image.shape[0] // 2)) 50 | morphed = morphed * 2.0 - 1.0 51 | 52 | result.append(torch.Tensor(morphed)) 53 | 54 | return torch.unsqueeze(torch.stack(result).to(device), dim=1) 55 | 56 | 57 | def gauss_noise_tensor(img): 58 | # https://github.com/pytorch/vision/issues/6192 59 | assert isinstance(img, torch.Tensor) 60 | dtype = img.dtype 61 | if not img.is_floating_point(): 62 | img = img.to(torch.float32) 63 | 64 | sigma = 0.075 65 | 66 | out = img + sigma * (torch.randn_like(img) - 0.5) 67 | 68 | out = torch.clamp(out, -1.0, 1.0) 69 | 70 | if out.dtype != dtype: 71 | out = out.to(dtype) 72 | 73 | return out 74 | 75 | 76 | def compute_word_width(image: torch.Tensor) -> int: 77 | indices = torch.where((image < 0).int())[2] 78 | index = torch.max(indices) if len(indices) > 0 else image.size(-1) 79 | 80 | return index 81 | 82 | 83 | class Downsize(torch.nn.Module): 84 | def __init__(self): 85 | super().__init__() 86 | 87 | self.aug = torchvision.transforms.Compose([ 88 | torchvision.transforms.RandomAffine(0.0, 
scale=(0.8, 1.0), interpolation=torchvision.transforms.InterpolationMode.NEAREST, fill=1.0), 89 | torchvision.transforms.GaussianBlur(3, sigma=0.3) 90 | ]) 91 | 92 | def forward(self, x): 93 | return self.aug(x) 94 | 95 | 96 | class OCRAugment(torch.nn.Module): 97 | def __init__(self, prob: float = 0.5, no: int = 2): 98 | super().__init__() 99 | self.prob = prob 100 | self.no = no 101 | 102 | interp = torchvision.transforms.InterpolationMode.NEAREST 103 | fill = 1.0 104 | 105 | self.augmentations = [ 106 | torchvision.transforms.RandomRotation(3.0, interpolation=interp, fill=fill), 107 | torchvision.transforms.RandomAffine(0.0, translate=(0.05, 0.05), interpolation=interp, fill=fill), 108 | Downsize(), 109 | torchvision.transforms.ElasticTransform(alpha=10.0, sigma=7.0, fill=fill, interpolation=interp), 110 | torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5), 111 | torchvision.transforms.GaussianBlur(3, sigma=(0.1, 1.0)), 112 | gauss_noise_tensor, 113 | RandomMorphological(max_size=4, max_iterations=2, operation=cv2.MORPH_ERODE), 114 | RandomMorphological(max_size=2, max_iterations=1, operation=cv2.MORPH_DILATE) 115 | ] 116 | 117 | def forward(self, x): 118 | if random.uniform(0.0, 1.0) > self.prob: 119 | return x 120 | 121 | augmentations = random.choices(self.augmentations, k=self.no) 122 | 123 | for augmentation in augmentations: 124 | x = augmentation(x) 125 | 126 | return x 127 | 128 | 129 | class WordCrop(torch.nn.Module, ABC): 130 | def __init__(self, use_padding: bool = False): 131 | super().__init__() 132 | self.use_padding = use_padding 133 | self.pad = torchvision.transforms.Pad([2, 2, 2, 2], 1.0) 134 | 135 | @abstractmethod 136 | def get_current_width(self): 137 | pass 138 | 139 | @abstractmethod 140 | def update(self, epoch: int): 141 | pass 142 | 143 | def forward(self, images): 144 | assert len(images.size()) == 4 and images.size(1) == 1, "Augmentation works on batches of one channel images" 145 | 146 | if self.use_padding: 147 | images = self.pad(images) 148 | 149 | results = [] 150 | width = self.get_current_width() 151 | 152 | for image in images: 153 | index = compute_word_width(image) 154 | max_index = max(min(index - width // 2, image.size(2) - width), 0) 155 | start_index = random.randint(0, max_index) 156 | 157 | results.append(F.crop(image, 0, start_index, image.size(1), min(width, image.size(2)))) 158 | 159 | return torch.stack(results) 160 | 161 | 162 | class StaticWordCrop(WordCrop): 163 | def __init__(self, width: int, use_padding: bool = False): 164 | super().__init__(use_padding=use_padding) 165 | self.width = width 166 | 167 | def get_current_width(self): 168 | return int(self.width) 169 | 170 | def update(self, epoch: int): 171 | pass 172 | 173 | 174 | class RandomWordCrop(WordCrop): 175 | def __init__(self, min_width: int, max_width: int, use_padding: bool = False): 176 | super().__init__(use_padding) 177 | 178 | self.min_width = min_width 179 | self.max_width = max_width 180 | 181 | self.current_width = random.randint(self.min_width, self.max_width) 182 | 183 | def update(self, epoch: int): 184 | self.current_width = random.randint(self.min_width, self.max_width) 185 | 186 | def get_current_width(self): 187 | return self.current_width 188 | 189 | 190 | class FullCrop(torch.nn.Module): 191 | def __init__(self, width: int): 192 | super().__init__() 193 | self.width = width 194 | self.height = 32 195 | self.pad = torchvision.transforms.Pad([6, 6, 6, 6], 1.0) 196 | 197 | def get_current_width(self): 198 | return self.width 199 | 200 | def 
forward(self, images): 201 | assert len(images.size()) == 4 and images.size(1) == 1, "Augmentation works on batches of one channel images" 202 | images = self.pad(images) 203 | 204 | results = [] 205 | 206 | for image in images: 207 | index = compute_word_width(image) 208 | max_index = max(min(index - self.width // 2, image.size(2) - self.width), 0) 209 | 210 | start_width = random.randint(0, max_index) 211 | start_height = random.randint(0, image.size(1) - self.height) 212 | 213 | results.append(F.crop(image, start_height, start_width, self.height, min(self.width, image.size(2)))) 214 | 215 | return torch.stack(results) 216 | 217 | 218 | class ProgressiveWordCrop(WordCrop): 219 | def __init__(self, width: int, warmup_epochs: int, start_width: int = 128, use_padding: bool = False): 220 | super().__init__(use_padding=use_padding) 221 | self.target_width = width 222 | self.warmup_epochs = warmup_epochs 223 | self.start_width = start_width 224 | self.current_width = float(start_width) 225 | 226 | def update(self, epoch: int): 227 | value = self.start_width - ((self.start_width - self.target_width) / self.warmup_epochs) * epoch 228 | self.current_width = max(value, self.target_width) 229 | 230 | def get_current_width(self): 231 | return int(round(self.current_width)) 232 | 233 | 234 | class CycleWordCrop(WordCrop): 235 | def __init__(self, width: int, cycle_epochs: int, start_width: int = 128, use_padding: bool = False): 236 | super().__init__(use_padding=use_padding) 237 | 238 | self.target_width = width 239 | self.start_width = start_width 240 | self.current_width = float(start_width) 241 | self.cycle_epochs = float(cycle_epochs) 242 | 243 | def update(self, epoch: int): 244 | value = (math.cos((float(epoch) * 2 * math.pi) / self.cycle_epochs) + 1) * ((self.start_width - self.target_width) / 2) + self.target_width 245 | self.current_width = value 246 | 247 | def get_current_width(self): 248 | return int(round(self.current_width)) 249 | 250 | 251 | class HeightResize(torch.nn.Module): 252 | def __init__(self, target_height: int): 253 | super().__init__() 254 | self.target_height = target_height 255 | 256 | def forward(self, x): 257 | width, height = F.get_image_size(x) 258 | scale = self.target_height / height 259 | 260 | return F.resize(x, [int(height * scale), int(width * scale)]) 261 | 262 | 263 | 264 | def show_crops(): 265 | with open("../files/IAM-32-pa.pickle", 'rb') as f: 266 | data = pickle.load(f) 267 | 268 | for author in data['train'].keys(): 269 | for image in data['train'][author]: 270 | image = torch.Tensor(np.expand_dims(np.expand_dims(np.array(image['img']), 0), 0)) 271 | 272 | augmenter = torchvision.transforms.Compose([ 273 | HeightResize(32), 274 | FullCrop(128) 275 | ]) 276 | 277 | batch = augmenter(image) 278 | 279 | batch = batch.detach().cpu().numpy() 280 | result = [np.squeeze(im) for im in batch] 281 | 282 | #plt.imshow(np.squeeze(image)) 283 | 284 | f, ax = plt.subplots(1, len(result)) 285 | 286 | for i in range(len(result)): 287 | ax.imshow(result[i]) 288 | 289 | plt.show() 290 | 291 | 292 | if __name__ == "__main__": 293 | dataset = CollectionTextDataset( 294 | 'IAM', '../files', TextDataset, file_suffix='pa', num_examples=15, 295 | collator_resolution=16, min_virtual_size=339, validation=False, debug=False 296 | ) 297 | 298 | train_loader = torch.utils.data.DataLoader( 299 | dataset, 300 | batch_size=8, 301 | shuffle=True, 302 | pin_memory=True, drop_last=True, 303 | collate_fn=dataset.collate_fn) 304 | 305 | augmenter = OCRAugment(no=3, prob=1.0) 306 | 307 | 
target_folder = r"C:\Users\bramv\Documents\Werk\Research\Unimore\VATr\VATr_ext\saved_images\debug\ocr_aug" 308 | 309 | image_no = 0 310 | 311 | for batch in train_loader: 312 | for i in range(5): 313 | augmented = augmenter(batch["img"]) 314 | 315 | img = np.squeeze((augmented[0].detach().cpu().numpy() + 1.0) / 2.0) 316 | 317 | img = (img * 255.0).astype(np.uint8) 318 | 319 | print(cv2.imwrite(os.path.join(target_folder, f"{image_no}_{i}.png"), img)) 320 | 321 | img = np.squeeze((batch["img"][0].detach().cpu().numpy() + 1.0) / 2.0) 322 | img = (img * 255.0).astype(np.uint8) 323 | cv2.imwrite(os.path.join(target_folder, f"{image_no}.png"), img) 324 | 325 | if image_no > 5: 326 | break 327 | 328 | image_no+=1 329 | 330 | -------------------------------------------------------------------------------- /util/loading.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | OLD_KEYS = ['netG._logvarD.bias', 'netG._logvarD.weight', 'netG._logvarE.bias', 'netG._logvarE.weight', 'netG._muD.bias', 'netG._muD.weight', 'netG._muE.bias', 'netG._muE.weight', 'netD.embed.weight', 'netD.embed.u0', 'netD.embed.sv0', 'netD.embed.bias'] 4 | 5 | 6 | def load_generator(model, checkpoint): 7 | if not isinstance(checkpoint, collections.OrderedDict): 8 | checkpoint = checkpoint['model'] 9 | 10 | checkpoint = {k.replace("netG.",""): v for k, v in checkpoint.items() if k.startswith("netG") and k not in OLD_KEYS} 11 | model.netG.load_state_dict(checkpoint) 12 | 13 | return model 14 | 15 | 16 | def load_checkpoint(model, checkpoint): 17 | if not isinstance(checkpoint, collections.OrderedDict): 18 | checkpoint = checkpoint['model'] 19 | old_model = model.state_dict() 20 | if len(checkpoint.keys()) == 241: # default 21 | counter = 0 22 | for k, v in checkpoint.items(): 23 | if k in old_model: 24 | old_model[k] = v 25 | counter += 1 26 | elif 'netG.' + k in old_model: 27 | old_model['netG.' 
+ k] = v 28 | counter += 1 29 | 30 | ckeys = [k for k in checkpoint.keys() if 'Feat_Encoder' in k] 31 | okeys = [k for k in old_model.keys() if 'Feat_Encoder' in k] 32 | for ck, ok in zip(ckeys, okeys): 33 | old_model[ok] = checkpoint[ck] 34 | counter += 1 35 | assert counter == 241 36 | checkpoint_dict = old_model 37 | else: 38 | checkpoint = {k: v for k, v in checkpoint.items() if k not in OLD_KEYS} 39 | assert len(old_model) == len(checkpoint) 40 | checkpoint_dict = {k2: v1 for (k1, v1), (k2, v2) in zip(checkpoint.items(), old_model.items()) if 41 | v1.shape == v2.shape} 42 | assert len(old_model) == len(checkpoint_dict) 43 | model.load_state_dict(checkpoint_dict, strict=False) 44 | return model -------------------------------------------------------------------------------- /util/text.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import random 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import string 8 | from abc import ABC, abstractmethod 9 | from functools import partial 10 | 11 | 12 | class TextGenerator(ABC): 13 | def __init__(self, max_length: int = None): 14 | self.max_length = max_length 15 | 16 | @abstractmethod 17 | def generate(self): 18 | pass 19 | 20 | 21 | class AugmentedGenerator(TextGenerator): 22 | def __init__(self, strength: float, alphabet: list, max_length: int = None): 23 | super().__init__(max_length) 24 | self.strength = strength 25 | self.alphabet = list(alphabet) 26 | if "%" in alphabet: 27 | self.alphabet.remove("%") 28 | 29 | @abstractmethod 30 | def generate(self): 31 | pass 32 | 33 | def set_strength(self, strength: float): 34 | self.strength = strength 35 | 36 | def get_strength(self): 37 | return self.strength 38 | 39 | 40 | class ProportionalAugmentedGenerator(AugmentedGenerator): 41 | def __init__(self, max_length: int, generator: TextGenerator, alphabet: list, strength: float = 0.5): 42 | super().__init__(strength, alphabet, max_length) 43 | self.generator = generator 44 | 45 | self.char_stats = {} 46 | self.sampling_probs = {} 47 | self.init_statistics() 48 | 49 | def init_statistics(self): 50 | char_occurrences = {k: 0 for k in self.alphabet} 51 | character_count = 0 52 | 53 | for _ in range(10000): 54 | word = self.generator.generate() 55 | for char in word: 56 | char_occurrences[char] += 1 57 | character_count += 1 58 | 59 | self.char_stats = {k: v / character_count for k, v in char_occurrences.items()} 60 | scale = max([v for v in self.char_stats.values()]) 61 | self.char_stats = {k: v / scale for k, v in self.char_stats.items()} 62 | self.sampling_probs = {k: 1.0 - v for k, v in self.char_stats.items()} 63 | 64 | def random_char(self): 65 | return random.choices(list(self.sampling_probs.keys()), weights=list(self.sampling_probs.values()), k=1)[0] 66 | 67 | def generate(self): 68 | word = self.generator.generate() 69 | word = self.augment(word) 70 | return word 71 | 72 | def augment(self, word): 73 | probs = np.random.rand(len(word)) 74 | target_probs = [self.strength * self.char_stats[c] for c in word] 75 | 76 | replace = probs < target_probs 77 | 78 | for index in range(len(word)): 79 | if replace[index]: 80 | char = self.random_char() 81 | word = set_char(word, char, index) 82 | return word 83 | 84 | 85 | class FileTextGenerator(TextGenerator): 86 | def __init__(self, max_length: int, file_path: str, alphabet: list): 87 | super().__init__(max_length) 88 | 89 | with open(file_path, 'r') as f: 90 | self.words = f.read().splitlines() 91 | self.words
= [l for l in self.words if len(l) < self.max_length and set(l) <= set(alphabet)] 92 | 93 | def generate(self): 94 | return random.choice(self.words) 95 | 96 | 97 | class CVLFileTextIterator(TextGenerator): 98 | def __init__(self, max_length: int, file_path: str, alphabet: list): 99 | super().__init__(max_length) 100 | 101 | self.words = [] 102 | 103 | with open(file_path, 'r') as f: 104 | next(f) 105 | for line in f: 106 | _, *annotation = line.rstrip().split(",") 107 | annotation = ",".join(annotation) 108 | self.words.append(annotation) 109 | self.words = [l for l in self.words if len(l) < self.max_length and set(l) <= set(alphabet)] 110 | self.index = 0 111 | 112 | def generate(self): 113 | word = self.words[self.index % len(self.words)] 114 | self.index += 1 115 | return word 116 | 117 | 118 | def set_char(s, character, location): 119 | return s[:location] + character + s[location + 1:] 120 | 121 | 122 | class GibberishGenerator(TextGenerator): 123 | def __init__(self, max_length: int = None): 124 | super().__init__(max_length) 125 | self.lower_case = list(string.ascii_lowercase) 126 | self.upper_case = list(string.ascii_uppercase) 127 | self.special = list(' .-\',"&();#:!?+*/') 128 | self.numbers = [str(i) for i in range(10)] 129 | 130 | def get_word_length(self) -> int: 131 | length = int(math.ceil(np.random.chisquare(8))) 132 | while self.max_length is not None and length > self.max_length: 133 | length = int(math.ceil(np.random.chisquare(8))) 134 | return length 135 | 136 | def generate(self): 137 | return self.generate_random() 138 | 139 | def generate_random(self): 140 | alphabet = self.upper_case + self.lower_case + self.special + self.numbers 141 | string = ''.join(random.choices(alphabet, k=self.get_word_length())) 142 | 143 | return string 144 | 145 | 146 | class IAMTextGenerator(TextGenerator): 147 | def generate(self): 148 | return random.choice(self.words) 149 | 150 | def __init__(self, max_length: int, path: str, subset: str = 'train'): 151 | super().__init__(max_length) 152 | 153 | with open(path, 'rb') as f: 154 | data = pickle.load(f) 155 | 156 | data = data[subset] 157 | self.words = [] 158 | for author_id in data.keys(): 159 | for image_dict in data[author_id]: 160 | if len(image_dict['label']) <= self.max_length: 161 | self.words.append(image_dict['label']) 162 | 163 | 164 | def get_generator(args): 165 | if args.corpus == "standard": 166 | if args.english_words_path.endswith(".csv"): 167 | generator = CVLFileTextIterator(20, args.english_words_path, args.alphabet) 168 | else: 169 | generator = FileTextGenerator(20, args.english_words_path, args.alphabet) 170 | else: 171 | generator = IAMTextGenerator(20, "files/IAM-32.pickle", 'train') 172 | 173 | if args.text_augment_strength > 0: 174 | if args.text_aug_type == 'proportional': 175 | return ProportionalAugmentedGenerator(20, generator, args.alphabet, args.text_augment_strength) 176 | elif args.text_aug_type == 'gibberish': 177 | return GibberishGenerator(20) 178 | else: 179 | return ProportionalAugmentedGenerator(20, generator, args.alphabet, args.text_augment_strength) 180 | 181 | return generator 182 | 183 | 184 | if __name__ == "__main__": 185 | alphabet = list('Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%') 186 | original_generator = FileTextGenerator(max_length=20, file_path="../files/english_words.txt", alphabet=alphabet) 187 | gib = ProportionalAugmentedGenerator(20, original_generator, alphabet=alphabet, strength=0.5) 188 | 189 | generated_words = [] 190 | 191 | for _ 
in range(1000): 192 | word = gib.generate() 193 | generated_words.append(len(word)) 194 | if len(set(word)) < len(word): 195 | print(word) 196 | 197 | plt.hist(generated_words) 198 | plt.show() -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | import torch.nn.functional as F 8 | 9 | 10 | def load_network(net, save_dir, epoch): 11 | """Load all the networks from the disk. 12 | 13 | Parameters: 14 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 15 | """ 16 | load_filename = '%s_net_%s.pth' % (epoch, net.name) 17 | load_path = os.path.join(save_dir, load_filename) 18 | # if you are using PyTorch newer than 0.4 (e.g., built from 19 | # GitHub source), you can remove str() on self.device 20 | state_dict = torch.load(load_path) 21 | if hasattr(state_dict, '_metadata'): 22 | del state_dict._metadata 23 | net.load_state_dict(state_dict) 24 | return net 25 | 26 | def writeCache(env, cache): 27 | with env.begin(write=True) as txn: 28 | for k, v in cache.items(): 29 | if type(k) == str: 30 | k = k.encode() 31 | if type(v) == str: 32 | v = v.encode() 33 | txn.put(k, v) 34 | 35 | def loadData(v, data): 36 | with torch.no_grad(): 37 | v.resize_(data.size()).copy_(data) 38 | 39 | def multiple_replace(string, rep_dict): 40 | for key in rep_dict.keys(): 41 | string = string.replace(key, rep_dict[key]) 42 | return string 43 | 44 | def get_curr_data(data, batch_size, counter): 45 | curr_data = {} 46 | for key in data: 47 | curr_data[key] = data[key][batch_size*counter:batch_size*(counter+1)] 48 | return curr_data 49 | 50 | # Utility file to seed rngs 51 | def seed_rng(seed): 52 | torch.manual_seed(seed) 53 | torch.cuda.manual_seed(seed) 54 | np.random.seed(seed) 55 | 56 | # turn tensor of classes to tensor of one hot tensors: 57 | def make_one_hot(labels, len_labels, n_classes): 58 | one_hot = torch.zeros((labels.shape[0], labels.shape[1], n_classes),dtype=torch.float32) 59 | for i in range(len(labels)): 60 | one_hot[i,np.array(range(len_labels[i])), labels[i,:len_labels[i]]-1]=1 61 | return one_hot 62 | 63 | # Hinge Loss 64 | def loss_hinge_dis(dis_fake, dis_real, len_text_fake, len_text, mask_loss): 65 | try: 66 | mask_real = torch.ones(dis_real.shape).to(dis_real.device) 67 | mask_fake = torch.ones(dis_fake.shape).to(dis_fake.device) 68 | except RuntimeError: 69 | raise 70 | if mask_loss and len(dis_fake.shape)>2: 71 | for i in range(len(len_text)): 72 | mask_real[i, :, :, len_text[i]:] = 0 73 | mask_fake[i, :, :, len_text_fake[i]:] = 0 74 | loss_real = torch.sum(F.relu(1. - dis_real * mask_real))/torch.sum(mask_real) 75 | loss_fake = torch.sum(F.relu(1. 
+ dis_fake * mask_fake))/torch.sum(mask_fake) 76 | return loss_real, loss_fake 77 | 78 | 79 | def loss_hinge_gen(dis_fake, len_text_fake, mask_loss): 80 | mask_fake = torch.ones(dis_fake.shape).to(dis_fake.device) 81 | if mask_loss and len(dis_fake.shape)>2: 82 | for i in range(len(len_text_fake)): 83 | mask_fake[i, :, :, len_text_fake[i]:] = 0 84 | loss = -torch.sum(dis_fake*mask_fake)/torch.sum(mask_fake) 85 | return loss 86 | 87 | def loss_std(z, lengths, mask_loss): 88 | loss_std = torch.zeros(1).to(z.device) 89 | z_mean = torch.ones((z.shape[0], z.shape[1])).to(z.device) 90 | for i in range(len(lengths)): 91 | if mask_loss: 92 | if lengths[i]>1: 93 | loss_std += torch.mean(torch.std(z[i, :, :, :lengths[i]], 2)) 94 | z_mean[i,:] = torch.mean(z[i, :, :, :lengths[i]], 2).squeeze(1) 95 | else: 96 | z_mean[i, :] = z[i, :, :, 0].squeeze(1) 97 | else: 98 | loss_std += torch.mean(torch.std(z[i, :, :, :], 2)) 99 | z_mean[i,:] = torch.mean(z[i, :, :, :], 2).squeeze(1) 100 | loss_std = loss_std/z.shape[0] 101 | return loss_std, z_mean 102 | 103 | # Convenience utility to switch off requires_grad 104 | def toggle_grad(model, on_or_off): 105 | for param in model.parameters(): 106 | param.requires_grad = on_or_off 107 | 108 | 109 | # Apply modified ortho reg to a model 110 | # This function is an optimized version that directly computes the gradient, 111 | # instead of computing and then differentiating the loss. 112 | def ortho(model, strength=1e-4, blacklist=[]): 113 | with torch.no_grad(): 114 | for param in model.parameters(): 115 | # Only apply this to parameters with at least 2 axes, and not in the blacklist 116 | if len(param.shape) < 2 or any([param is item for item in blacklist]): 117 | continue 118 | w = param.view(param.shape[0], -1) 119 | grad = (2 * torch.mm(torch.mm(w, w.t()) 120 | * (1. - torch.eye(w.shape[0], device=w.device)), w)) 121 | param.grad.data += strength * grad.view(param.shape) 122 | 123 | 124 | # Default ortho reg 125 | # This function is an optimized version that directly computes the gradient, 126 | # instead of computing and then differentiating the loss. 127 | def default_ortho(model, strength=1e-4, blacklist=[]): 128 | with torch.no_grad(): 129 | for param in model.parameters(): 130 | # Only apply this to parameters with at least 2 axes & not in blacklist 131 | if len(param.shape) < 2 or param in blacklist: 132 | continue 133 | w = param.view(param.shape[0], -1) 134 | grad = (2 * torch.mm(torch.mm(w, w.t()) 135 | - torch.eye(w.shape[0], device=w.device), w)) 136 | param.grad.data += strength * grad.view(param.shape) 137 | 138 | 139 | # Convenience utility to switch off requires_grad 140 | def toggle_grad(model, on_or_off): 141 | for param in model.parameters(): 142 | param.requires_grad = on_or_off 143 | 144 | 145 | # A highly simplified convenience class for sampling from distributions 146 | # One could also use PyTorch's inbuilt distributions package. 
147 | # Note that this class requires initialization to proceed as 148 | # x = Distribution(torch.randn(size)) 149 | # x.init_distribution(dist_type, **dist_kwargs) 150 | # x = x.to(device,dtype) 151 | # This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2 152 | class Distribution(torch.Tensor): 153 | # Init the params of the distribution 154 | def init_distribution(self, dist_type, **kwargs): 155 | seed_rng(kwargs['seed']) 156 | self.dist_type = dist_type 157 | self.dist_kwargs = kwargs 158 | if self.dist_type == 'normal': 159 | self.mean, self.var = kwargs['mean'], kwargs['var'] 160 | elif self.dist_type == 'categorical': 161 | self.num_categories = kwargs['num_categories'] 162 | elif self.dist_type == 'poisson': 163 | self.lam = kwargs['var'] 164 | elif self.dist_type == 'gamma': 165 | self.scale = kwargs['var'] 166 | 167 | 168 | def sample_(self): 169 | if self.dist_type == 'normal': 170 | self.normal_(self.mean, self.var) 171 | elif self.dist_type == 'categorical': 172 | self.random_(0, self.num_categories) 173 | elif self.dist_type == 'poisson': 174 | type = self.type() 175 | device = self.device 176 | data = np.random.poisson(self.lam, self.size()) 177 | self.data = torch.from_numpy(data).type(type).to(device) 178 | elif self.dist_type == 'gamma': 179 | type = self.type() 180 | device = self.device 181 | data = np.random.gamma(shape=1, scale=self.scale, size=self.size()) 182 | self.data = torch.from_numpy(data).type(type).to(device) 183 | # return self.variable 184 | 185 | # Silly hack: overwrite the to() method to wrap the new object 186 | # in a distribution as well 187 | def to(self, *args, **kwargs): 188 | new_obj = Distribution(self) 189 | new_obj.init_distribution(self.dist_type, **self.dist_kwargs) 190 | new_obj.data = super().to(*args, **kwargs) 191 | return new_obj 192 | 193 | 194 | def to_device(net, gpu_ids): 195 | if len(gpu_ids) > 0: 196 | assert(torch.cuda.is_available()) 197 | net.to(gpu_ids[0]) 198 | # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 199 | if len(gpu_ids)>1: 200 | net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda() 201 | # net = torch.nn.DistributedDataParallel(net) 202 | return net 203 | 204 | 205 | # Convenience function to prepare a z and y vector 206 | def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda', 207 | fp16=False, z_var=1.0, z_dist='normal', seed=0): 208 | z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False)) 209 | z_.init_distribution(z_dist, mean=0, var=z_var, seed=seed) 210 | z_ = z_.to(device, torch.float16 if fp16 else torch.float32) 211 | 212 | if fp16: 213 | z_ = z_.half() 214 | 215 | y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False)) 216 | y_.init_distribution('categorical', num_categories=nclasses, seed=seed) 217 | y_ = y_.to(device, torch.int64) 218 | return z_, y_ 219 | 220 | 221 | def tensor2im(input_image, imtype=np.uint8): 222 | """"Converts a Tensor array into a numpy image array. 
223 | 224 | Parameters: 225 | input_image (tensor) -- the input image tensor array 226 | imtype (type) -- the desired type of the converted numpy array 227 | """ 228 | if not isinstance(input_image, np.ndarray): 229 | if isinstance(input_image, torch.Tensor): # get the data from a variable 230 | image_tensor = input_image.data 231 | else: 232 | return input_image 233 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 234 | if image_numpy.shape[0] == 1: # grayscale to RGB 235 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 236 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 237 | else: # if it is a numpy array, do nothing 238 | image_numpy = input_image 239 | return image_numpy.astype(imtype) 240 | 241 | 242 | def diagnose_network(net, name='network'): 243 | """Calculate and print the mean of average absolute(gradients) 244 | 245 | Parameters: 246 | net (torch network) -- Torch network 247 | name (str) -- the name of the network 248 | """ 249 | mean = 0.0 250 | count = 0 251 | for param in net.parameters(): 252 | if param.grad is not None: 253 | mean += torch.mean(torch.abs(param.grad.data)) 254 | count += 1 255 | if count > 0: 256 | mean = mean / count 257 | print(name) 258 | print(mean) 259 | 260 | 261 | def save_image(image_numpy, image_path): 262 | """Save a numpy image to the disk 263 | 264 | Parameters: 265 | image_numpy (numpy array) -- input numpy array 266 | image_path (str) -- the path of the image 267 | """ 268 | image_pil = Image.fromarray(image_numpy) 269 | image_pil.save(image_path) 270 | 271 | 272 | def print_numpy(x, val=True, shp=False): 273 | """Print the mean, min, max, median, std, and size of a numpy array 274 | 275 | Parameters: 276 | val (bool) -- if print the values of the numpy array 277 | shp (bool) -- if print the shape of the numpy array 278 | """ 279 | x = x.astype(np.float64) 280 | if shp: 281 | print('shape,', x.shape) 282 | if val: 283 | x = x.flatten() 284 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 285 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 286 | 287 | 288 | def mkdirs(paths): 289 | """create empty directories if they don't exist 290 | 291 | Parameters: 292 | paths (str list) -- a list of directory paths 293 | """ 294 | if isinstance(paths, list) and not isinstance(paths, str): 295 | for path in paths: 296 | mkdir(path) 297 | else: 298 | mkdir(paths) 299 | 300 | 301 | def mkdir(path): 302 | """create a single empty directory if it didn't exist 303 | 304 | Parameters: 305 | path (str) -- a single directory path 306 | """ 307 | if not os.path.exists(path): 308 | os.makedirs(path) 309 | -------------------------------------------------------------------------------- /util/vision.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def detect_text_bounds(image: np.array) -> (int, int): 6 | """ 7 | Find the lower and upper bounding lines in an image of a word 8 | """ 9 | if len(image.shape) >= 3 and image.shape[2] == 3: 10 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 11 | elif len(image.shape) >= 3 and image.shape[2] == 1: 12 | image = np.squeeze(image, axis=-1) 13 | 14 | _, threshold = cv2.threshold(image, 0.8, 1, cv2.THRESH_BINARY_INV) 15 | 16 | line_sums = np.sum(threshold, axis=1).astype(float) 17 | line_sums = np.convolve(line_sums, np.ones(5) / 5, mode='same') 18 | 19 | line_sums_d = np.diff(line_sums) 20 | 21 | 
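# The smoothed row-wise ink profile rises sharply where the text body begins and falls sharply where it ends.
# Thresholding its derivative at the mean of the positive/negative values, offset by std_factor standard deviations, locates the top bound and the bottom bound returned below.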
std_factor = 0.5 22 | min_threshold = np.mean(line_sums_d[line_sums_d <= 0]) - std_factor * np.std(line_sums_d[line_sums_d <= 0]) 23 | bottom_index = np.max(np.where(line_sums_d < min_threshold)) 24 | 25 | max_threshold = np.mean(line_sums_d[line_sums_d >= 0]) + std_factor * np.std(line_sums_d[line_sums_d >= 0]) 26 | top_index = np.min(np.where(line_sums_d > max_threshold)) 27 | 28 | return bottom_index, top_index 29 | 30 | 31 | def dist(p_one, p_two) -> float: 32 | return np.linalg.norm(p_two - p_one) 33 | 34 | 35 | def crop(image: np.array, ratio: float = None, pixels: int = None) -> np.array: 36 | assert ratio is not None or pixels is not None, "Please specify either pixels or a ratio to crop" 37 | 38 | width, height = image.shape[:2] 39 | 40 | if ratio is not None: 41 | 42 | width_crop = int(ratio * width) 43 | height_crop = int(ratio * height) 44 | else: 45 | width_crop= pixels 46 | height_crop = pixels 47 | 48 | return image[height_crop:height-height_crop, width_crop:width-width_crop] 49 | 50 | 51 | def find_target_points(top_left, top_right, bottom_left, bottom_right): 52 | max_width = max(int(dist(bottom_right, bottom_left)), int(dist(top_right, top_left))) 53 | max_height = max(int(dist(top_right, bottom_right)), int(dist(top_left, bottom_left))) 54 | destination_corners = [[0, 0], [max_width, 0], [max_width, max_height], [0, max_height]] 55 | 56 | return order_points(destination_corners) 57 | 58 | 59 | def order_points(points: np.array) -> tuple: 60 | """ 61 | inspired by: https://learnopencv.com/automatic-document-scanner-using-opencv/ 62 | """ 63 | sum = np.sum(points, axis=1) 64 | top_left = points[np.argmin(sum)] 65 | bottom_right = points[np.argmax(sum)] 66 | 67 | diff = np.diff(points, axis=1) 68 | top_right = points[np.argmin(diff)] 69 | bottom_left = points[np.argmax(diff)] 70 | 71 | return top_left, top_right, bottom_left, bottom_right 72 | 73 | 74 | def get_page(image: np.array) -> np.array: 75 | """ 76 | inspired by: https://github.com/Kakaranish/OpenCV-paper-detection 77 | """ 78 | filtered = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 79 | filtered = cv2.medianBlur(filtered, 11) 80 | 81 | canny = cv2.Canny(filtered, 30, 50, 3) 82 | contours, _ = cv2.findContours(canny, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) 83 | 84 | max_perimeter = 0 85 | max_contour = None 86 | for contour in contours: 87 | contour = np.array(contour) 88 | perimeter = cv2.arcLength(contour, True) 89 | contour_approx = cv2.approxPolyDP(contour, 0.02 * perimeter, True) 90 | 91 | if perimeter > max_perimeter and cv2.isContourConvex(contour_approx) and len(contour_approx) == 4: 92 | max_perimeter = perimeter 93 | max_contour = contour_approx 94 | 95 | if max_contour is not None: 96 | max_contour = np.squeeze(max_contour) 97 | points = order_points(max_contour) 98 | 99 | target_points = find_target_points(*points) 100 | M = cv2.getPerspectiveTransform(np.float32(points), np.float32(target_points)) 101 | final = cv2.warpPerspective(image, M, (target_points[3][0], target_points[3][1]), flags=cv2.INTER_LINEAR) 102 | final = crop(final, pixels=10) 103 | return final 104 | 105 | return image 106 | 107 | 108 | def get_words(page: np.array, dilation_size: int = 3): 109 | gray = cv2.cvtColor(page, cv2.COLOR_BGR2GRAY) 110 | _, thresholded = cv2.threshold(gray, 125, 1, cv2.THRESH_BINARY_INV) 111 | 112 | dilation_size = dilation_size 113 | element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * dilation_size + 1, 2 * dilation_size + 1), 114 | (dilation_size, dilation_size)) 115 | thresholded = 
cv2.dilate(thresholded, element, iterations=3) 116 | 117 | contours, _ = cv2.findContours(thresholded, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 118 | 119 | words = [] 120 | boxes = [] 121 | 122 | for contour in contours: 123 | x, y, w, h = cv2.boundingRect(contour) 124 | ratio = w / h 125 | if ratio <= 0.1 or ratio >= 10.0: 126 | continue 127 | boxes.append([x, y, w, h]) 128 | words.append(page[y:y+h, x:x+w]) 129 | 130 | return words, boxes --------------------------------------------------------------------------------
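`get_page` and `get_words` are the low-level helpers for cutting writer style samples out of a photographed page (presumably what `create_style_sample.py` builds on). Below is a minimal sketch of how they could be chained; the input file name, the output folder, and the resize to a height of 32 px are illustrative assumptions, not code from the repository.

```python
# Sketch only: rectify a photographed page, detect word regions, and save
# height-normalized crops as style samples. File names and the 32 px target
# height are assumptions for illustration.
import os
import cv2

from util.vision import get_page, get_words

page_photo = cv2.imread("page.jpg")      # photo of a handwritten page (BGR)
page = get_page(page_photo)              # perspective-rectified page region
words, boxes = get_words(page)           # word crops and their bounding boxes

os.makedirs("style_samples", exist_ok=True)
for i, word in enumerate(words):
    h, w = word.shape[:2]
    word = cv2.resize(word, (max(int(w * 32 / h), 1), 32))  # normalize height to 32 px
    cv2.imwrite(os.path.join("style_samples", f"word_{i:03d}.png"), word)
```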