├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── class │ ├── imagenet.txt │ └── lsun.txt ├── font │ └── arial.ttf └── vocab │ └── bpe_simple_vocab_16e6.txt ├── configs ├── imagenet_gpt_vitvq_base.yaml ├── imagenet_vitvq_base.yaml ├── imagenet_vitvq_large.yaml └── imagenet_vitvq_small.yaml ├── enhancing ├── __init__.py ├── dataloader │ ├── __init__.py │ ├── cc3m.py │ ├── classimage.py │ ├── coco.py │ ├── imagenet.py │ ├── inatural.py │ ├── lsun.py │ ├── srimage.py │ └── textimage.py ├── losses │ ├── layers.py │ ├── op │ │ ├── __init__.py │ │ ├── conv2d_gradfix.py │ │ ├── fused_act.py │ │ ├── fused_bias_act.cpp │ │ ├── fused_bias_act_kernel.cu │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.py │ │ └── upfirdn2d_kernel.cu │ ├── segmentation.py │ └── vqperceptual.py ├── modules │ ├── cond │ │ ├── clipcond.py │ │ ├── dummycond.py │ │ └── vqcond.py │ ├── stage1 │ │ ├── layers.py │ │ ├── quantizers.py │ │ └── vitvqgan.py │ └── stage2 │ │ ├── layers.py │ │ └── transformer.py └── utils │ ├── callback.py │ ├── general.py │ ├── scheduler.py │ └── tokenizer.py ├── environment.yaml ├── main.py └── requirements.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | figures/ 2 | experiments/ 3 | *__pycache__/ 4 | data -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Thuan H. Nguyen 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 19 | OR OTHER DEALINGS IN THE SOFTWARE./ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | <details>
3 |   <summary>Table of Contents</summary>
4 |   <ol>
5 |     <li><a href="#about-the-project">About The Project</a></li>
6 |     <li><a href="#getting-started">Getting Started</a></li>
7 |     <li><a href="#roadmap">Roadmap</a></li>
8 |     <li><a href="#contributing">Contributing</a></li>
9 |     <li><a href="#license">License</a></li>
10 |     <li><a href="#contact">Contact</a></li>
11 |     <li><a href="#acknowledgements">Acknowledgements</a></li>
12 |   </ol>
13 | </details>
26 | 
27 | ## News
28 | ***09/09***
29 | 1. Released the weights of ViT-VQGAN small trained on ImageNet, available [here](https://huggingface.co/thuanz123/vitvqgan-imagenet-small)
30 | 
31 | ***16/08***
32 | 1. First release of the ViT-VQGAN base weights trained on ImageNet, available [here](https://huggingface.co/thuanz123/vitvqgan-imagenet-base)
33 | 2. Added a Colab notebook, available [here](https://colab.research.google.com/drive/1y-PzYhkNQbhKj3i459pWd6TAO28SnF5h?usp=sharing)
34 | 
35 | 
36 | ## About The Project
37 | 
38 | This is an unofficial implementation of both [ViT-VQGAN](https://arxiv.org/abs/2110.04627) and [RQ-VAE](https://arxiv.org/abs/2203.01941) in PyTorch. ViT-VQGAN is a simple ViT-based Vector Quantized AutoEncoder, while RQ-VAE introduces a new residual quantization scheme. Further details can be found in the papers.
39 | 
40 | 
41 | ## Getting Started
42 | 
43 | For ease of installation, you should use [anaconda](https://conda.io/) to set up this repo.
44 | 
45 | ### Installation
46 | 
47 | A suitable conda environment named `enhancing` can be created and activated with:
48 | ```
49 | conda env create -f environment.yaml
50 | conda activate enhancing
51 | ```
52 | 
53 | 
54 | ### Training
55 | 
56 | Training is done with a single command:
57 | ```python3 main.py -c config_name -lr learning_rate -e epoch_nums```
58 | 
59 | 
60 | ## Roadmap
61 | 
62 | - [x] Add ViT-VQGAN
63 |   - [x] Add ViT-based encoder and decoder
64 |   - [x] Add factorized codes
65 |   - [x] Add l2-normalized codes
66 |   - [x] Replace the PatchGAN discriminator with a StyleGAN one
67 | - [x] Add RQ-VAE
68 |   - [x] Add Residual Quantizer
69 |   - [x] Add RQ-Transformer
70 | - [x] Add dataloaders for some common datasets
71 |   - [x] ImageNet
72 |   - [x] LSUN
73 |   - [x] COCO
74 |     - [x] Add COCO Segmentation
75 |     - [x] Add COCO Caption
76 |   - [x] CC3M
77 | - [ ] Add pretrained models
78 |   - [x] ViT-VQGAN small
79 |   - [x] ViT-VQGAN base
80 |   - [ ] ViT-VQGAN large
81 | 
82 | 
83 | 
84 | ## Contributing
85 | 
86 | Contributions are what make the open-source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**.
87 | 
88 | If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement".
89 | Don't forget to give the project a star! Thanks again!
90 | 
91 | 1. Fork the Project
92 | 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
93 | 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
94 | 4. Push to the Branch (`git push origin feature/AmazingFeature`)
95 | 5. Open a Pull Request
96 | 
97 | 
98 | 
99 | ## License
100 | 
101 | Source code and pretrained weights are distributed under the MIT License. See `LICENSE` for more information.
102 | 
103 | 
104 | 
105 | ## Contact
106 | 
107 | Thuan H.
Nguyen - [@leejohnthuan](https://twitter.com/leejohnthuan) - leejohnthuan@gmail.com 108 | 109 | 110 | 111 | ## Acknowledgements 112 | This project would not be possible without the generous sponsorship from [Stability AI](https://stability.ai/) and helpful discussion of folks in [LAION discord](https://discord.gg/j5GdN49g) 113 | 114 | This repo is heavily inspired by following repos and papers: 115 | 116 | * [Taming Transformers](https://github.com/CompVis/taming-transformers) 117 | * [ViT-Pytorch](https://github.com/lucidrains/vit-pytorch) 118 | * [minDALL-E](https://github.com/kakaobrain/minDALL-E) 119 | * [CLIP](https://github.com/openai/CLIP) 120 | * [ViT-VQGAN](https://arxiv.org/abs/2110.04627) 121 | * [RQ-VAE](https://arxiv.org/abs/2110.04627) 122 | -------------------------------------------------------------------------------- /assets/class/imagenet.txt: -------------------------------------------------------------------------------- 1 | tench 2 | goldfish 3 | great white shark 4 | tiger shark 5 | hammerhead shark 6 | electric ray 7 | stingray 8 | cock 9 | hen 10 | ostrich 11 | brambling 12 | goldfinch 13 | house finch 14 | junco 15 | indigo bunting 16 | American robin 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | American dipper 22 | kite 23 | bald eagle 24 | vulture 25 | great grey owl 26 | fire salamander 27 | smooth newt 28 | newt 29 | spotted salamander 30 | axolotl 31 | American bullfrog 32 | tree frog 33 | tailed frog 34 | loggerhead sea turtle 35 | leatherback sea turtle 36 | mud turtle 37 | terrapin 38 | box turtle 39 | banded gecko 40 | green iguana 41 | Carolina anole 42 | desert grassland whiptail lizard 43 | agama 44 | frilled-necked lizard 45 | alligator lizard 46 | Gila monster 47 | European green lizard 48 | chameleon 49 | Komodo dragon 50 | Nile crocodile 51 | American alligator 52 | triceratops 53 | worm snake 54 | ring-necked snake 55 | eastern hog-nosed snake 56 | smooth green snake 57 | kingsnake 58 | garter snake 59 | water snake 60 | vine snake 61 | night snake 62 | boa constrictor 63 | African rock python 64 | Indian cobra 65 | green mamba 66 | sea snake 67 | Saharan horned viper 68 | eastern diamondback rattlesnake 69 | sidewinder 70 | trilobite 71 | harvestman 72 | scorpion 73 | yellow garden spider 74 | barn spider 75 | European garden spider 76 | southern black widow 77 | tarantula 78 | wolf spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse 84 | prairie grouse 85 | peacock 86 | quail 87 | partridge 88 | grey parrot 89 | macaw 90 | sulphur-crested cockatoo 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | duck 99 | red-breasted merganser 100 | goose 101 | black swan 102 | tusker 103 | echidna 104 | platypus 105 | wallaby 106 | koala 107 | wombat 108 | jellyfish 109 | sea anemone 110 | brain coral 111 | flatworm 112 | nematode 113 | conch 114 | snail 115 | slug 116 | sea slug 117 | chiton 118 | chambered nautilus 119 | Dungeness crab 120 | rock crab 121 | fiddler crab 122 | red king crab 123 | American lobster 124 | spiny lobster 125 | crayfish 126 | hermit crab 127 | isopod 128 | white stork 129 | black stork 130 | spoonbill 131 | flamingo 132 | little blue heron 133 | great egret 134 | bittern 135 | crane (bird) 136 | limpkin 137 | common gallinule 138 | American coot 139 | bustard 140 | ruddy turnstone 141 | dunlin 142 | common redshank 143 | dowitcher 144 | oystercatcher 145 | pelican 146 | king penguin 147 | albatross 148 | grey whale 149 | killer whale 150 | 
dugong 151 | sea lion 152 | Chihuahua 153 | Japanese Chin 154 | Maltese 155 | Pekingese 156 | Shih Tzu 157 | King Charles Spaniel 158 | Papillon 159 | toy terrier 160 | Rhodesian Ridgeback 161 | Afghan Hound 162 | Basset Hound 163 | Beagle 164 | Bloodhound 165 | Bluetick Coonhound 166 | Black and Tan Coonhound 167 | Treeing Walker Coonhound 168 | English foxhound 169 | Redbone Coonhound 170 | borzoi 171 | Irish Wolfhound 172 | Italian Greyhound 173 | Whippet 174 | Ibizan Hound 175 | Norwegian Elkhound 176 | Otterhound 177 | Saluki 178 | Scottish Deerhound 179 | Weimaraner 180 | Staffordshire Bull Terrier 181 | American Staffordshire Terrier 182 | Bedlington Terrier 183 | Border Terrier 184 | Kerry Blue Terrier 185 | Irish Terrier 186 | Norfolk Terrier 187 | Norwich Terrier 188 | Yorkshire Terrier 189 | Wire Fox Terrier 190 | Lakeland Terrier 191 | Sealyham Terrier 192 | Airedale Terrier 193 | Cairn Terrier 194 | Australian Terrier 195 | Dandie Dinmont Terrier 196 | Boston Terrier 197 | Miniature Schnauzer 198 | Giant Schnauzer 199 | Standard Schnauzer 200 | Scottish Terrier 201 | Tibetan Terrier 202 | Australian Silky Terrier 203 | Soft-coated Wheaten Terrier 204 | West Highland White Terrier 205 | Lhasa Apso 206 | Flat-Coated Retriever 207 | Curly-coated Retriever 208 | Golden Retriever 209 | Labrador Retriever 210 | Chesapeake Bay Retriever 211 | German Shorthaired Pointer 212 | Vizsla 213 | English Setter 214 | Irish Setter 215 | Gordon Setter 216 | Brittany 217 | Clumber Spaniel 218 | English Springer Spaniel 219 | Welsh Springer Spaniel 220 | Cocker Spaniels 221 | Sussex Spaniel 222 | Irish Water Spaniel 223 | Kuvasz 224 | Schipperke 225 | Groenendael 226 | Malinois 227 | Briard 228 | Australian Kelpie 229 | Komondor 230 | Old English Sheepdog 231 | Shetland Sheepdog 232 | collie 233 | Border Collie 234 | Bouvier des Flandres 235 | Rottweiler 236 | German Shepherd Dog 237 | Dobermann 238 | Miniature Pinscher 239 | Greater Swiss Mountain Dog 240 | Bernese Mountain Dog 241 | Appenzeller Sennenhund 242 | Entlebucher Sennenhund 243 | Boxer 244 | Bullmastiff 245 | Tibetan Mastiff 246 | French Bulldog 247 | Great Dane 248 | St. 
Bernard 249 | husky 250 | Alaskan Malamute 251 | Siberian Husky 252 | Dalmatian 253 | Affenpinscher 254 | Basenji 255 | pug 256 | Leonberger 257 | Newfoundland 258 | Pyrenean Mountain Dog 259 | Samoyed 260 | Pomeranian 261 | Chow Chow 262 | Keeshond 263 | Griffon Bruxellois 264 | Pembroke Welsh Corgi 265 | Cardigan Welsh Corgi 266 | Toy Poodle 267 | Miniature Poodle 268 | Standard Poodle 269 | Mexican hairless dog 270 | grey wolf 271 | Alaskan tundra wolf 272 | red wolf 273 | coyote 274 | dingo 275 | dhole 276 | African wild dog 277 | hyena 278 | red fox 279 | kit fox 280 | Arctic fox 281 | grey fox 282 | tabby cat 283 | tiger cat 284 | Persian cat 285 | Siamese cat 286 | Egyptian Mau 287 | cougar 288 | lynx 289 | leopard 290 | snow leopard 291 | jaguar 292 | lion 293 | tiger 294 | cheetah 295 | brown bear 296 | American black bear 297 | polar bear 298 | sloth bear 299 | mongoose 300 | meerkat 301 | tiger beetle 302 | ladybug 303 | ground beetle 304 | longhorn beetle 305 | leaf beetle 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant 312 | grasshopper 313 | cricket 314 | stick insect 315 | cockroach 316 | mantis 317 | cicada 318 | leafhopper 319 | lacewing 320 | dragonfly 321 | damselfly 322 | red admiral 323 | ringlet 324 | monarch butterfly 325 | small white 326 | sulphur butterfly 327 | gossamer-winged butterfly 328 | starfish 329 | sea urchin 330 | sea cucumber 331 | cottontail rabbit 332 | hare 333 | Angora rabbit 334 | hamster 335 | porcupine 336 | fox squirrel 337 | marmot 338 | beaver 339 | guinea pig 340 | common sorrel 341 | zebra 342 | pig 343 | wild boar 344 | warthog 345 | hippopotamus 346 | ox 347 | water buffalo 348 | bison 349 | ram 350 | bighorn sheep 351 | Alpine ibex 352 | hartebeest 353 | impala 354 | gazelle 355 | dromedary 356 | llama 357 | weasel 358 | mink 359 | European polecat 360 | black-footed ferret 361 | otter 362 | skunk 363 | badger 364 | armadillo 365 | three-toed sloth 366 | orangutan 367 | gorilla 368 | chimpanzee 369 | gibbon 370 | siamang 371 | guenon 372 | patas monkey 373 | baboon 374 | macaque 375 | langur 376 | black-and-white colobus 377 | proboscis monkey 378 | marmoset 379 | white-headed capuchin 380 | howler monkey 381 | titi 382 | Geoffroy's spider monkey 383 | common squirrel monkey 384 | ring-tailed lemur 385 | indri 386 | Asian elephant 387 | African bush elephant 388 | red panda 389 | giant panda 390 | snoek 391 | eel 392 | coho salmon 393 | rock beauty 394 | clownfish 395 | sturgeon 396 | garfish 397 | lionfish 398 | pufferfish 399 | abacus 400 | abaya 401 | academic gown 402 | accordion 403 | acoustic guitar 404 | aircraft carrier 405 | airliner 406 | airship 407 | altar 408 | ambulance 409 | amphibious vehicle 410 | analog clock 411 | apiary 412 | apron 413 | waste container 414 | assault rifle 415 | backpack 416 | bakery 417 | balance beam 418 | balloon 419 | ballpoint pen 420 | Band-Aid 421 | banjo 422 | baluster 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel 429 | wheelbarrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | swimming cap 435 | bath towel 436 | bathtub 437 | station wagon 438 | lighthouse 439 | beaker 440 | military cap 441 | beer bottle 442 | beer glass 443 | bell-cot 444 | bib 445 | tandem bicycle 446 | bikini 447 | ring binder 448 | binoculars 449 | birdhouse 450 | boathouse 451 | bobsleigh 452 | bolo tie 453 | poke bonnet 454 | bookcase 455 | bookstore 456 | bottle cap 457 | bow 458 | bow tie 459 | brass 460 | bra 461 | 
breakwater 462 | breastplate 463 | broom 464 | bucket 465 | buckle 466 | bulletproof vest 467 | high-speed train 468 | butcher shop 469 | taxicab 470 | cauldron 471 | candle 472 | cannon 473 | canoe 474 | can opener 475 | cardigan 476 | car mirror 477 | carousel 478 | tool kit 479 | carton 480 | car wheel 481 | automated teller machine 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello 488 | mobile phone 489 | chain 490 | chain-link fence 491 | chain mail 492 | chainsaw 493 | chest 494 | chiffonier 495 | chime 496 | china cabinet 497 | Christmas stocking 498 | church 499 | movie theater 500 | cleaver 501 | cliff dwelling 502 | cloak 503 | clogs 504 | cocktail shaker 505 | coffee mug 506 | coffeemaker 507 | coil 508 | combination lock 509 | computer keyboard 510 | confectionery store 511 | container ship 512 | convertible 513 | corkscrew 514 | cornet 515 | cowboy boot 516 | cowboy hat 517 | cradle 518 | crane (machine) 519 | crash helmet 520 | crate 521 | infant bed 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam 527 | desk 528 | desktop computer 529 | rotary dial telephone 530 | diaper 531 | digital clock 532 | digital watch 533 | dining table 534 | dishcloth 535 | dishwasher 536 | disc brake 537 | dock 538 | dog sled 539 | dome 540 | doormat 541 | drilling rig 542 | drum 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso machine 552 | face powder 553 | feather boa 554 | filing cabinet 555 | fireboat 556 | fire engine 557 | fire screen sheet 558 | flagpole 559 | flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster bed 566 | freight car 567 | French horn 568 | frying pan 569 | fur coat 570 | garbage truck 571 | gas mask 572 | gas pump 573 | goblet 574 | go-kart 575 | golf ball 576 | golf cart 577 | gondola 578 | gong 579 | gown 580 | grand piano 581 | greenhouse 582 | grille 583 | grocery store 584 | guillotine 585 | barrette 586 | hair spray 587 | half-track 588 | hammer 589 | hamper 590 | hair dryer 591 | hand-held computer 592 | handkerchief 593 | hard disk drive 594 | harmonica 595 | harp 596 | harvester 597 | hatchet 598 | holster 599 | home theater 600 | honeycomb 601 | hook 602 | hoop skirt 603 | horizontal bar 604 | horse-drawn vehicle 605 | hourglass 606 | iPod 607 | clothes iron 608 | jack-o'-lantern 609 | jeans 610 | jeep 611 | T-shirt 612 | jigsaw puzzle 613 | pulled rickshaw 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat 619 | ladle 620 | lampshade 621 | laptop computer 622 | lawn mower 623 | lens cap 624 | paper knife 625 | library 626 | lifeboat 627 | lighter 628 | limousine 629 | ocean liner 630 | lipstick 631 | slip-on shoe 632 | lotion 633 | speaker 634 | loupe 635 | sawmill 636 | magnetic compass 637 | mail bag 638 | mailbox 639 | tights 640 | tank suit 641 | manhole cover 642 | maraca 643 | marimba 644 | mask 645 | match 646 | maypole 647 | maze 648 | measuring cup 649 | medicine chest 650 | megalith 651 | microphone 652 | microwave oven 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | square academic cap 669 | mosque 670 | mosquito net 671 | scooter 672 | mountain bike 673 | tent 674 | computer mouse 675 | mousetrap 676 | 
moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook computer 683 | obelisk 684 | oboe 685 | ocarina 686 | odometer 687 | oil filter 688 | organ 689 | oscilloscope 690 | overskirt 691 | bullock cart 692 | oxygen mask 693 | packet 694 | paddle 695 | paddle wheel 696 | padlock 697 | paintbrush 698 | pajamas 699 | palace 700 | pan flute 701 | paper towel 702 | parachute 703 | parallel bars 704 | park bench 705 | parking meter 706 | passenger car 707 | patio 708 | payphone 709 | pedestal 710 | pencil case 711 | pencil sharpener 712 | perfume 713 | Petri dish 714 | photocopier 715 | plectrum 716 | Pickelhaube 717 | picket fence 718 | pickup truck 719 | pier 720 | piggy bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate ship 726 | pitcher 727 | hand plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow 732 | plunger 733 | Polaroid camera 734 | pole 735 | police van 736 | poncho 737 | billiard table 738 | soda bottle 739 | pot 740 | potter's wheel 741 | power drill 742 | prayer rug 743 | printer 744 | prison 745 | projectile 746 | projector 747 | hockey puck 748 | punching bag 749 | purse 750 | quill 751 | quilt 752 | race car 753 | racket 754 | radiator 755 | radio 756 | radio telescope 757 | rain barrel 758 | recreational vehicle 759 | reel 760 | reflex camera 761 | refrigerator 762 | remote control 763 | restaurant 764 | revolver 765 | rifle 766 | rocking chair 767 | rotisserie 768 | eraser 769 | rugby ball 770 | ruler 771 | running shoe 772 | safe 773 | safety pin 774 | salt shaker 775 | sandal 776 | sarong 777 | saxophone 778 | scabbard 779 | weighing scale 780 | school bus 781 | schooner 782 | scoreboard 783 | CRT screen 784 | screw 785 | screwdriver 786 | seat belt 787 | sewing machine 788 | shield 789 | shoe store 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule 800 | sliding door 801 | slot machine 802 | snorkel 803 | snowmobile 804 | snowplow 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar thermal collector 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | motorboat 816 | spider web 817 | spindle 818 | sports car 819 | spotlight 820 | stage 821 | steam locomotive 822 | through arch bridge 823 | steel drum 824 | stethoscope 825 | scarf 826 | stone wall 827 | stopwatch 828 | stove 829 | strainer 830 | tram 831 | stretcher 832 | couch 833 | stupa 834 | submarine 835 | suit 836 | sundial 837 | sunglass 838 | sunglasses 839 | sunscreen 840 | suspension bridge 841 | mop 842 | sweatshirt 843 | swimsuit 844 | swing 845 | switch 846 | syringe 847 | table lamp 848 | tank 849 | tape player 850 | teapot 851 | teddy bear 852 | television 853 | tennis ball 854 | thatched roof 855 | front curtain 856 | thimble 857 | threshing machine 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck 866 | toy store 867 | tractor 868 | semi-trailer truck 869 | tray 870 | trench coat 871 | tricycle 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus 876 | trombone 877 | tub 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle 882 | upright piano 883 | vacuum cleaner 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet 895 | 
wardrobe 896 | military aircraft 897 | sink 898 | washing machine 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool 913 | split-rail fence 914 | shipwreck 915 | yawl 916 | yurt 917 | website 918 | comic book 919 | crossword 920 | traffic sign 921 | traffic light 922 | dust jacket 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot 928 | trifle 929 | ice cream 930 | ice pop 931 | baguette 932 | bagel 933 | pretzel 934 | cheeseburger 935 | hot dog 936 | mashed potato 937 | cabbage 938 | broccoli 939 | cauliflower 940 | zucchini 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber 945 | artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple 955 | banana 956 | jackfruit 957 | custard apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate syrup 962 | dough 963 | meatloaf 964 | pizza 965 | pot pie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff 974 | coral reef 975 | geyser 976 | lakeshore 977 | promontory 978 | shoal 979 | seashore 980 | valley 981 | volcano 982 | baseball player 983 | bridegroom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper 988 | corn 989 | acorn 990 | rose hip 991 | horse chestnut seed 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn mushroom 996 | earth star 997 | hen-of-the-woods 998 | bolete 999 | ear 1000 | toilet paper -------------------------------------------------------------------------------- /assets/class/lsun.txt: -------------------------------------------------------------------------------- 1 | bedroom 2 | bridge 3 | church_outdoor 4 | classroom 5 | conference_room 6 | dining_room 7 | kitchen 8 | living_room 9 | restaurant 10 | test 11 | tower -------------------------------------------------------------------------------- /assets/font/arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuanz123/enhancing-transformers/e1c185da5cc0eeefdded68cbff888d5b9c248372/assets/font/arial.ttf -------------------------------------------------------------------------------- /configs/imagenet_gpt_vitvq_base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: enhancing.modules.stage2.transformer.CondTransformer 3 | params: 4 | cond_key: class 5 | cond: 6 | target: enhancing.modules.cond.dummycond.ClassCond 7 | params: 8 | image_size: 256 9 | class_name: assets/class/imagenet.txt 10 | stage1: 11 | target: enhancing.modules.stage1.vitvqgan.ViTVQ 12 | params: 13 | image_key: image 14 | path: weight/imagenet_vitvq_base.ckpt 15 | image_size: 256 16 | patch_size: 8 17 | encoder: 18 | dim: 768 19 | depth: 12 20 | heads: 12 21 | mlp_dim: 3072 22 | decoder: 23 | dim: 768 24 | depth: 12 25 | heads: 12 26 | mlp_dim: 3072 27 | quantizer: 28 | embed_dim: 32 29 | n_embed: 8192 30 | loss: 31 | target: enhancing.losses.vqperceptual.DummyLoss 32 | transformer: 33 | target: enhancing.modules.stage2.layers.GPT 34 | params: 35 | vocab_cond_size: 1000 36 | vocab_img_size: 8192 37 | embed_dim: 6144 38 | cond_num_tokens: 1 39 | img_num_tokens: 1024 40 | n_heads: 16 41 | n_layers: 24 42 | 43 | dataset: 44 | target: enhancing.dataloader.DataModuleFromConfig 45 | 
params: 46 | batch_size: 4 47 | num_workers: 2 48 | train: 49 | target: enhancing.dataloader.imagenet.ImageNetTrain 50 | params: 51 | root: data/ilsvrc2012 52 | resolution: 256 53 | 54 | validation: 55 | target: enhancing.dataloader.imagenet.ImageNetValidation 56 | params: 57 | root: data/ilsvrc2012 58 | resolution: 256 59 | -------------------------------------------------------------------------------- /configs/imagenet_vitvq_base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: enhancing.modules.stage1.vitvqgan.ViTVQ 3 | params: 4 | image_key: image 5 | image_size: 256 6 | patch_size: 8 7 | encoder: 8 | dim: 768 9 | depth: 12 10 | heads: 12 11 | mlp_dim: 3072 12 | decoder: 13 | dim: 768 14 | depth: 12 15 | heads: 12 16 | mlp_dim: 3072 17 | quantizer: 18 | embed_dim: 32 19 | n_embed: 8192 20 | loss: 21 | target: enhancing.losses.vqperceptual.VQLPIPSWithDiscriminator 22 | params: 23 | loglaplace_weight: 0.0 24 | loggaussian_weight: 1.0 25 | perceptual_weight: 0.1 26 | adversarial_weight: 0.1 27 | 28 | dataset: 29 | target: enhancing.dataloader.DataModuleFromConfig 30 | params: 31 | batch_size: 8 32 | num_workers: 4 33 | train: 34 | target: enhancing.dataloader.imagenet.ImageNetTrain 35 | params: 36 | root: data/ilsvrc2012 37 | resolution: 256 38 | 39 | validation: 40 | target: enhancing.dataloader.imagenet.ImageNetValidation 41 | params: 42 | root: data/ilsvrc2012 43 | resolution: 256 -------------------------------------------------------------------------------- /configs/imagenet_vitvq_large.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: enhancing.modules.stage1.vitvqgan.ViTVQ 3 | params: 4 | image_key: image 5 | image_size: 256 6 | patch_size: 8 7 | encoder: 8 | dim: 512 9 | depth: 8 10 | heads: 8 11 | mlp_dim: 2048 12 | decoder: 13 | dim: 1280 14 | depth: 32 15 | heads: 16 16 | mlp_dim: 5120 17 | quantizer: 18 | embed_dim: 32 19 | n_embed: 8192 20 | loss: 21 | target: enhancing.losses.vqperceptual.VQLPIPSWithDiscriminator 22 | params: 23 | loglaplace_weight: 0.0 24 | loggaussian_weight: 1.0 25 | perceptual_weight: 0.1 26 | adversarial_weight: 0.1 27 | 28 | dataset: 29 | target: enhancing.dataloader.DataModuleFromConfig 30 | params: 31 | batch_size: 2 32 | num_workers: 4 33 | train: 34 | target: enhancing.dataloader.imagenet.ImageNetTrain 35 | params: 36 | root: data/ilsvrc2012 37 | resolution: 256 38 | 39 | validation: 40 | target: enhancing.dataloader.imagenet.ImageNetValidation 41 | params: 42 | root: data/ilsvrc2012 43 | resolution: 256 -------------------------------------------------------------------------------- /configs/imagenet_vitvq_small.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: enhancing.modules.stage1.vitvqgan.ViTVQ 3 | params: 4 | image_key: image 5 | image_size: 256 6 | patch_size: 8 7 | encoder: 8 | dim: 512 9 | depth: 8 10 | heads: 8 11 | mlp_dim: 2048 12 | decoder: 13 | dim: 512 14 | depth: 8 15 | heads: 8 16 | mlp_dim: 2048 17 | quantizer: 18 | embed_dim: 32 19 | n_embed: 8192 20 | loss: 21 | target: enhancing.losses.vqperceptual.VQLPIPSWithDiscriminator 22 | params: 23 | loglaplace_weight: 0.0 24 | loggaussian_weight: 1.0 25 | perceptual_weight: 0.1 26 | adversarial_weight: 0.1 27 | 28 | dataset: 29 | target: enhancing.dataloader.DataModuleFromConfig 30 | params: 31 | batch_size: 8 32 | num_workers: 4 33 | train: 34 | target: enhancing.dataloader.imagenet.ImageNetTrain 35 | 
params: 36 | root: data/ilsvrc2012 37 | resolution: 256 38 | 39 | validation: 40 | target: enhancing.dataloader.imagenet.ImageNetValidation 41 | params: 42 | root: data/ilsvrc2012 43 | resolution: 256 -------------------------------------------------------------------------------- /enhancing/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import general 2 | -------------------------------------------------------------------------------- /enhancing/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from VQGAN (https://github.com/CompVis/taming-transformers) 3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | from typing import Optional 7 | from omegaconf import OmegaConf 8 | 9 | import pytorch_lightning as pl 10 | from torch.utils.data import DataLoader 11 | 12 | from ..utils.general import initialize_from_config 13 | 14 | class DataModuleFromConfig(pl.LightningDataModule): 15 | def __init__(self, batch_size: int, train: Optional[OmegaConf] = None, 16 | validation: Optional[OmegaConf] = None, 17 | test: Optional[OmegaConf] = None, 18 | num_workers: Optional[int] = None): 19 | super().__init__() 20 | self.dataset_configs = dict() 21 | self.batch_size = batch_size 22 | self.num_workers = num_workers if num_workers is not None else batch_size*2 23 | if train is not None: 24 | self.dataset_configs["train"] = train 25 | self.train_dataloader = self._train_dataloader 26 | if validation is not None: 27 | self.dataset_configs["validation"] = validation 28 | self.val_dataloader = self._val_dataloader 29 | if test is not None: 30 | self.dataset_configs["test"] = test 31 | self.test_dataloader = self._test_dataloader 32 | 33 | def prepare_data(self): 34 | for data_cfg in self.dataset_configs.values(): 35 | initialize_from_config(data_cfg) 36 | 37 | def setup(self, stage=None): 38 | self.datasets = dict( 39 | (k, initialize_from_config(self.dataset_configs[k])) 40 | for k in self.dataset_configs) 41 | 42 | def _train_dataloader(self): 43 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, 44 | num_workers=self.num_workers, shuffle=True) 45 | 46 | def _val_dataloader(self): 47 | return DataLoader(self.datasets["validation"], 48 | batch_size=self.batch_size, 49 | num_workers=self.num_workers) 50 | 51 | def _test_dataloader(self): 52 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 53 | num_workers=self.num_workers) 54 | -------------------------------------------------------------------------------- /enhancing/dataloader/cc3m.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | 
7 | from typing import Optional, Union, Callable, Tuple, Any
8 | from pathlib import Path
9 | from omegaconf import OmegaConf
10 | from PIL import Image
11 | 
12 | from torchvision import transforms as T
13 | from torch.utils.data import Dataset
14 | 
15 | from ..utils.general import initialize_from_config
16 | 
17 | class CC3MBase(Dataset):
18 |     def __init__(self, folder: str, split: str,
19 |                  tokenizer: OmegaConf,
20 |                  transform: Callable) -> None:
21 |         super().__init__()
22 |         self.items = []
23 |         for line in open(f'{Path(folder)}/{split}_list.txt', 'r').readlines():  # each line: "<image path>\t<caption>"
24 |             imgpath, text = line.strip().split('\t')
25 |             self.items.append((Path(folder)/imgpath, text))
26 | 
27 |         self.tokenizer = initialize_from_config(tokenizer)
28 |         self.transform = transform
29 | 
30 |     def __len__(self) -> int:
31 |         return len(self.items)
32 | 
33 |     def __getitem__(self, ind: int) -> Tuple[Any, Any]:
34 |         image_file, caption = self.items[ind]
35 | 
36 |         caption = self.tokenizer.tokenize(caption).squeeze(0)
37 | 
38 |         image = Image.open(image_file)
39 |         if image.mode != 'RGB':
40 |             image = image.convert('RGB')
41 | 
42 |         if self.transform:
43 |             image = self.transform(image)
44 | 
45 |         # Success
46 |         return {"caption": caption, "image": image}
47 | 
48 | 
49 | class CC3MTrain(CC3MBase):
50 |     def __init__(self, folder: str, tokenizer: OmegaConf,
51 |                  resolution: Union[Tuple[int, int], int] = 256) -> None:
52 |         transform = T.Compose([
53 |             T.Resize(resolution),
54 |             T.RandomCrop(resolution),
55 |             T.ToTensor(),
56 |         ])
57 | 
58 |         super().__init__(folder, 'train', tokenizer, transform)
59 | 
60 | 
61 | class CC3MValidation(CC3MBase):
62 |     def __init__(self, folder: str, tokenizer: OmegaConf,
63 |                  resolution: Union[Tuple[int, int], int] = 256) -> None:
64 |         transform = T.Compose([
65 |             T.Resize(resolution),
66 |             T.CenterCrop(resolution),
67 |             T.ToTensor(),
68 |         ])
69 | 
70 |         super().__init__(folder, 'val', tokenizer, transform)
71 | 
--------------------------------------------------------------------------------
/enhancing/dataloader/classimage.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import numpy as np 8 | from typing import Optional, Union, Callable, Tuple, Any 9 | from pathlib import Path 10 | from random import randint, choice 11 | from omegaconf import OmegaConf 12 | 13 | import torch 14 | from torchvision import transforms as T 15 | from torchvision.datasets import ImageFolder 16 | 17 | from ..utils.general import initialize_from_config 18 | 19 | class ClassImageBase(ImageFolder): 20 | def __init__(self, root: str, split: str, 21 | transform: Callable) -> None: 22 | root = Path(root)/split 23 | super().__init__(root, transform) 24 | 25 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 26 | image, target = super().__getitem__(index) 27 | 28 | return {'image': image, 'class': torch.tensor([target])} 29 | 30 | 31 | class ClassImageTrain(ClassImageBase): 32 | def __init__(self, root: str, 33 | resolution: Union[Tuple[int, int], int] = 256, 34 | resize_ratio: float = 0.75) -> None: 35 | if isinstance(resolution, int): 36 | resolution = [resolution, resolution] 37 | 38 | transform = T.Compose([ 39 | T.Resize(resolution), 40 | T.RandomCrop(resolution), 41 | T.RandomHorizontalFlip(), 42 | T.ToTensor() 43 | ]) 44 | 45 | super().__init__(root, 'train', transform) 46 | 47 | 48 | class ClassImageValidation(ClassImageBase): 49 | def __init__(self, root: str, 50 | resolution: Union[Tuple[int, int], int] = 256) -> None: 51 | if isinstance(resolution, int): 52 | resolution = [resolution, resolution] 53 | 54 | transform = T.Compose([ 55 | T.Resize(resolution), 56 | T.CenterCrop(resolution), 57 | T.ToTensor() 58 | ]) 59 | 60 | super().__init__(root, 'val', transform) 61 | -------------------------------------------------------------------------------- /enhancing/dataloader/coco.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers) 7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 
8 | # ------------------------------------------------------------------------------------ 9 | 10 | import json 11 | import albumentations as A 12 | from omegaconf import OmegaConf 13 | from typing import Optional, List, Callable, Union, Tuple 14 | from pathlib import Path 15 | 16 | import numpy as np 17 | from PIL import Image 18 | from torch.utils.data import Dataset 19 | 20 | from ..utils.general import initialize_from_config 21 | 22 | 23 | class COCOBase(Dataset): 24 | def __init__(self, dataroot: str = "", labelroot: str = "", stuffthingroot: str = "", split: str = "", 25 | onehot_segmentation: bool = False, use_stuffthing: bool = False, 26 | tokenizer: Optional[OmegaConf] = None, transform: Optional[Callable] = None) -> None: 27 | assert split in ["train", "val"] 28 | self.split = split 29 | 30 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot 31 | self.stuffthing = use_stuffthing # include thing in segmentation 32 | if self.onehot and not self.stuffthing: 33 | raise NotImplemented("One hot mode is only supported for the " 34 | "stuffthings version because labels are stored " 35 | "a bit different.") 36 | 37 | data_json = Path(labelroot)/f"captions_{split}2017.json" 38 | with open(data_json) as json_file: 39 | self.json_data = json.load(json_file) 40 | self.img_id_to_captions = dict() 41 | self.img_id_to_filepath = dict() 42 | self.img_id_to_segmentation_filepath = dict() 43 | 44 | if self.stuffthing: 45 | self.segmentation_prefix = Path(stuffthingroot)/f"{split}2017" 46 | else: 47 | self.segmentation_prefix = Path(labelroot)/f"stuff_{split}2017_pixelmaps" 48 | 49 | imagedirs = self.json_data["images"] 50 | self.labels = {"image_ids": list()} 51 | for imgdir in imagedirs: 52 | self.img_id_to_filepath[imgdir["id"]] = Path(dataroot)/f"{split}2017"/imgdir["file_name"] 53 | self.img_id_to_captions[imgdir["id"]] = list() 54 | pngfilename = imgdir["file_name"].replace("jpg", "png") 55 | self.img_id_to_segmentation_filepath[imgdir["id"]] = Path(self.segmentation_prefix)/pngfilename 56 | self.labels["image_ids"].append(imgdir["id"]) 57 | 58 | capdirs = self.json_data["annotations"] 59 | for capdir in capdirs: 60 | # there are in average 5 captions per image 61 | self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) 62 | 63 | self.transform = transform 64 | self.tokenizer = initialize_from_config(tokenizer) 65 | 66 | def __len__(self): 67 | return len(self.labels["image_ids"]) 68 | 69 | def preprocess_image(self, image_path, segmentation_path): 70 | image = Image.open(image_path) 71 | if image.mode != "RGB": 72 | image = image.convert("RGB") 73 | image = np.array(image).astype(np.uint8) 74 | 75 | segmentation = Image.open(segmentation_path) 76 | if not self.onehot and not segmentation.mode == "RGB": 77 | segmentation = segmentation.convert("RGB") 78 | segmentation = np.array(segmentation).astype(np.uint8) 79 | if self.onehot: 80 | assert self.stuffthing 81 | # stored in caffe format: unlabeled==255. stuff and thing from 82 | # 0-181. 
to be compatible with the labels in 83 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt 84 | # we shift stuffthing one to the right and put unlabeled in zero 85 | # as long as segmentation is uint8 shifting to right handles the 86 | # latter too 87 | assert segmentation.dtype == np.uint8 88 | segmentation = segmentation + 1 89 | 90 | image, segmentation = self.transform(image=image, segmentation=segmentation) 91 | image = (image / 255).astype(np.float32) 92 | 93 | if self.onehot: 94 | assert segmentation.dtype == np.uint8 95 | # make it one hot 96 | n_labels = 183 97 | flatseg = np.ravel(segmentation) 98 | onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) 99 | onehot[np.arange(flatseg.size), flatseg] = True 100 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) 101 | segmentation = onehot 102 | else: 103 | segmentation = (segmentation / 255).astype(np.float32) 104 | 105 | return image, segmentation 106 | 107 | def __getitem__(self, i): 108 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] 109 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] 110 | image, segmentation = self.preprocess_image(img_path, seg_path) 111 | 112 | captions = self.img_id_to_captions[self.labels["image_ids"][i]] 113 | caption = captions[np.random.randint(0, len(captions))] 114 | caption = self.tokenizer.tokenize(caption).squeeze(0) 115 | 116 | return {"image": image, "caption": caption, "segmentation": segmentation} 117 | 118 | 119 | class COCOTrain(COCOBase): 120 | def __init__(self, dataroot: str, labelroot: str, stuffthingroot: str, tokenizer: OmegaConf, 121 | resolution: Union[Tuple[int, int], int], onehot_segmentation: bool = False, use_stuffthing: bool = False) -> None: 122 | if isinstance(resolution, int): 123 | resolution = [resolution, resolution] 124 | 125 | transform = A.Compose( 126 | [A.SmallestMaxSize(max_size=min(resolution)), 127 | A.RandomCrop(height=resolution[0], width=resolution[1])], 128 | additional_targets={"segmentation": "image"}) 129 | 130 | super().__init__(dataroot, labelroot, stuffthingroot, "train", 131 | onehot_segmentation, use_stuffthing, tokenizer, transform) 132 | 133 | 134 | class COCOValidation(COCOBase): 135 | def __init__(self, dataroot: str, labelroot: str, stuffthingroot: str, tokenizer: OmegaConf, 136 | resolution: Union[Tuple[int, int], int], onehot_segmentation: bool = False, use_stuffthing: bool = False) -> None: 137 | if isinstance(resolution, int): 138 | resolution = [resolution, resolution] 139 | 140 | transform = A.Compose( 141 | [A.SmallestMaxSize(max_size=min(resolution)), 142 | A.CenterCrop(height=resolution[0], width=resolution[1])], 143 | additional_targets={"segmentation": "image"}) 144 | 145 | super().__init__(dataroot, labelroot, stuffthingroot, "val", 146 | onehot_segmentation, use_stuffthing, tokenizer, transform) 147 | -------------------------------------------------------------------------------- /enhancing/dataloader/imagenet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import PIL 8 | from typing import Any, Tuple, Union, Optional, Callable 9 | 10 | import torch 11 | from torchvision import transforms as T 12 | from torchvision.datasets import ImageNet 13 | 14 | 15 | class ImageNetBase(ImageNet): 16 | def __init__(self, root: str, split: str, 17 | transform: Optional[Callable] = None) -> None: 18 | super().__init__(root=root, split=split, transform=transform) 19 | 20 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 21 | sample, target = super().__getitem__(index) 22 | 23 | return {'image': sample, 'class': torch.tensor([target])} 24 | 25 | 26 | class ImageNetTrain(ImageNetBase): 27 | def __init__(self, root: str, 28 | resolution: Union[Tuple[int, int], int] = 256, 29 | resize_ratio: float = 0.75) -> None: 30 | 31 | transform = T.Compose([ 32 | T.Resize(resolution), 33 | T.RandomCrop(resolution), 34 | T.RandomHorizontalFlip(), 35 | T.ToTensor() 36 | ]) 37 | 38 | super().__init__(root=root, split='train', transform=transform) 39 | 40 | 41 | class ImageNetValidation(ImageNetBase): 42 | def __init__(self, root: str, 43 | resolution: Union[Tuple[int, int], int] = 256,) -> None: 44 | 45 | if isinstance(resolution, int): 46 | resolution = (resolution, resolution) 47 | 48 | transform = T.Compose([ 49 | T.Resize(resolution), 50 | T.CenterCrop(resolution), 51 | T.ToTensor() 52 | ]) 53 | 54 | super().__init__(root=root, split='val', transform=transform) 55 | -------------------------------------------------------------------------------- /enhancing/dataloader/inatural.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from Torchvision (https://github.com/pytorch/vision) 7 | # Copyright (c) 2016 Soumith Chintala. All Rights Reserved. 
8 | # ------------------------------------------------------------------------------------ 9 | 10 | import os 11 | import PIL 12 | from typing import Any, Tuple, Union 13 | from pathlib import Path 14 | from typing import Optional, Union, Callable, Tuple, Any 15 | 16 | import torch 17 | from torchvision import transforms as T 18 | from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg 19 | from torchvision.datasets.vision import VisionDataset 20 | 21 | 22 | CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"] 23 | 24 | DATASET_URLS = { 25 | "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz", 26 | "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz", 27 | "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz", 28 | "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz", 29 | "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz", 30 | "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz", 31 | } 32 | 33 | DATASET_MD5 = { 34 | "2017": "7c784ea5e424efaec655bd392f87301f", 35 | "2018": "b1c6952ce38f31868cc50ea72d066cc3", 36 | "2019": "c60a6e2962c9b8ccbd458d12c8582644", 37 | "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed", 38 | "2021_train_mini": "db6ed8330e634445efc8fec83ae81442", 39 | "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc", 40 | } 41 | 42 | 43 | class INaturalistBase(VisionDataset): 44 | """`iNaturalist `_ Dataset. 45 | 46 | Args: 47 | root (string): Root directory of dataset where the image files are stored. 48 | This class does not require/use annotation files. 49 | version (string, optional): Which version of the dataset to download/use. One of 50 | '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'. 51 | Default: `2021_train`. 52 | target_type (string or list, optional): Type of target to use, for 2021 versions, one of: 53 | 54 | - ``full``: the full category (species) 55 | - ``kingdom``: e.g. "Animalia" 56 | - ``phylum``: e.g. "Arthropoda" 57 | - ``class``: e.g. "Insecta" 58 | - ``order``: e.g. "Coleoptera" 59 | - ``family``: e.g. "Cleridae" 60 | - ``genus``: e.g. "Trichodes" 61 | 62 | for 2017-2019 versions, one of: 63 | 64 | - ``full``: the full (numeric) category 65 | - ``super``: the super category, e.g. "Amphibians" 66 | 67 | Can also be a list to output a tuple with all specified target types. 68 | Defaults to ``full``. 69 | transform (callable, optional): A function/transform that takes in an PIL image 70 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 71 | target_transform (callable, optional): A function/transform that takes in the 72 | target and transforms it. 73 | download (bool, optional): If true, downloads the dataset from the internet and 74 | puts it in root directory. If dataset is already downloaded, it is not 75 | downloaded again. 
76 | """ 77 | 78 | def __init__( 79 | self, 80 | root: str, 81 | version: str = "2021_train", 82 | target_type: Union[List[str], str] = "full", 83 | transform: Optional[Callable] = None, 84 | target_transform: Optional[Callable] = None, 85 | ) -> None: 86 | self.version = verify_str_arg(version, "version", DATASET_URLS.keys()) 87 | 88 | super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform) 89 | 90 | os.makedirs(root, exist_ok=True) 91 | path_exist = os.path.isdir(os.path.join(root,version)) 92 | if not path_exist: 93 | self.download() 94 | 95 | if not self._check_integrity(): 96 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") 97 | 98 | self.all_categories: List[str] = [] 99 | 100 | # map: category type -> name of category -> index 101 | self.categories_index: Dict[str, Dict[str, int]] = {} 102 | 103 | # list indexed by category id, containing mapping from category type -> index 104 | self.categories_map: List[Dict[str, int]] = [] 105 | 106 | if not isinstance(target_type, list): 107 | target_type = [target_type] 108 | if self.version[:4] == "2021": 109 | self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type] 110 | self._init_2021() 111 | else: 112 | self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type] 113 | self._init_pre2021() 114 | 115 | # index of all files: (full category id, filename) 116 | self.index: List[Tuple[int, str]] = [] 117 | 118 | for dir_index, dir_name in enumerate(self.all_categories): 119 | files = os.listdir(os.path.join(self.root, dir_name)) 120 | for fname in files: 121 | self.index.append((dir_index, fname)) 122 | 123 | def _init_2021(self) -> None: 124 | """Initialize based on 2021 layout""" 125 | 126 | self.all_categories = sorted(os.listdir(self.root)) 127 | 128 | # map: category type -> name of category -> index 129 | self.categories_index = {k: {} for k in CATEGORIES_2021} 130 | 131 | for dir_index, dir_name in enumerate(self.all_categories): 132 | pieces = dir_name.split("_") 133 | if len(pieces) != 8: 134 | raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces") 135 | if pieces[0] != f"{dir_index:05d}": 136 | raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}") 137 | cat_map = {} 138 | for cat, name in zip(CATEGORIES_2021, pieces[1:7]): 139 | if name in self.categories_index[cat]: 140 | cat_id = self.categories_index[cat][name] 141 | else: 142 | cat_id = len(self.categories_index[cat]) 143 | self.categories_index[cat][name] = cat_id 144 | cat_map[cat] = cat_id 145 | self.categories_map.append(cat_map) 146 | 147 | def _init_pre2021(self) -> None: 148 | """Initialize based on 2017-2019 layout""" 149 | 150 | # map: category type -> name of category -> index 151 | self.categories_index = {"super": {}} 152 | 153 | cat_index = 0 154 | super_categories = sorted(os.listdir(self.root)) 155 | for sindex, scat in enumerate(super_categories): 156 | self.categories_index["super"][scat] = sindex 157 | subcategories = sorted(os.listdir(os.path.join(self.root, scat))) 158 | for subcat in subcategories: 159 | if self.version == "2017": 160 | # this version does not use ids as directory names 161 | subcat_i = cat_index 162 | cat_index += 1 163 | else: 164 | try: 165 | subcat_i = int(subcat) 166 | except ValueError: 167 | raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}") 168 | if subcat_i >= len(self.categories_map): 169 | 
old_len = len(self.categories_map) 170 | self.categories_map.extend([{}] * (subcat_i - old_len + 1)) 171 | self.all_categories.extend([""] * (subcat_i - old_len + 1)) 172 | if self.categories_map[subcat_i]: 173 | raise RuntimeError(f"Duplicate category {subcat}") 174 | self.categories_map[subcat_i] = {"super": sindex} 175 | self.all_categories[subcat_i] = os.path.join(scat, subcat) 176 | 177 | # validate the dictionary 178 | for cindex, c in enumerate(self.categories_map): 179 | if not c: 180 | raise RuntimeError(f"Missing category {cindex}") 181 | 182 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 183 | """ 184 | Args: 185 | index (int): Index 186 | 187 | Returns: 188 | tuple: (image, target) where the type of target specified by target_type. 189 | """ 190 | 191 | cat_id, fname = self.index[index] 192 | img = PIL.Image.open(os.path.join(self.root, self.all_categories[cat_id], fname)) 193 | 194 | target: Any = [] 195 | for t in self.target_type: 196 | if t == "full": 197 | target.append(cat_id) 198 | else: 199 | target.append(self.categories_map[cat_id][t]) 200 | target = tuple(target) if len(target) > 1 else target[0] 201 | 202 | if self.transform is not None: 203 | img = self.transform(img) 204 | 205 | if self.target_transform is not None: 206 | target = self.target_transform(target) 207 | 208 | return {'image': image, 'class': torch.tensor([target])} 209 | 210 | 211 | def __len__(self) -> int: 212 | return len(self.index) 213 | 214 | def category_name(self, category_type: str, category_id: int) -> str: 215 | """ 216 | Args: 217 | category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super" 218 | category_id(int): an index (class id) from this category 219 | 220 | Returns: 221 | the name of the category 222 | """ 223 | if category_type == "full": 224 | return self.all_categories[category_id] 225 | else: 226 | if category_type not in self.categories_index: 227 | raise ValueError(f"Invalid category type '{category_type}'") 228 | else: 229 | for name, id in self.categories_index[category_type].items(): 230 | if id == category_id: 231 | return name 232 | raise ValueError(f"Invalid category id {category_id} for {category_type}") 233 | 234 | 235 | def _check_integrity(self) -> bool: 236 | return os.path.exists(self.root) and len(os.listdir(self.root)) > 0 237 | 238 | def download(self) -> None: 239 | if self._check_integrity(): 240 | raise RuntimeError( 241 | f"The directory {self.root} already exists. " 242 | f"If you want to re-download or re-extract the images, delete the directory." 
243 | ) 244 | 245 | base_root = os.path.dirname(self.root) 246 | 247 | download_and_extract_archive( 248 | DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version] 249 | ) 250 | 251 | orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz")) 252 | if not os.path.exists(orig_dir_name): 253 | raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}") 254 | os.rename(orig_dir_name, self.root) 255 | print(f"Dataset version '{self.version}' has been downloaded and prepared for use") 256 | 257 | 258 | class INaturalistTrain(INaturalistBase): 259 | def __init__(self, root: str, resolution: Union[Tuple[int, int], int] = 256) -> None: 260 | transform = T.Compose([ 261 | T.Resize(resolution), 262 | T.RandomCrop(resolution), 263 | T.RandomHorizontalFlip(), 264 | T.ToTensor() 265 | ]) 266 | 267 | super().__init__(root=root, version='2021_train', transform=transform) 268 | 269 | class INaturalistValidation(INaturalistBase): 270 | def __init__(self, root: str, resolution: Union[Tuple[int, int], int] = 256) -> None: 271 | transform = T.Compose([ 272 | T.Resize(resolution), 273 | T.CenterCrop(resolution), 274 | T.ToTensor() 275 | ]) 276 | 277 | super().__init__(root=root, version='2021_valid', transform=transform) 278 | -------------------------------------------------------------------------------- /enhancing/dataloader/lsun.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import PIL 8 | from typing import Any, Tuple, Union, List, Optional, Callable 9 | import subprocess 10 | from os.path import join, dirname, abspath, isfile, isdir 11 | 12 | import torch 13 | from torchvision import transforms as T 14 | from torchvision.datasets import LSUN 15 | 16 | 17 | class LSUNBase(LSUN): 18 | def __init__(self, root: str, classes: Union[Tuple[str, str]], 19 | transform: Optional[Callable] = None) -> None: 20 | super().__init__(root, classes, transform) 21 | 22 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 23 | image, target = super().__getitem__(index) 24 | 25 | return {'image': image, 'class': torch.tensor([target])} 26 | 27 | 28 | class LSUNTrain(LSUNBase): 29 | def __init__(self, root: str, classes: Union[Tuple[str, str]], 30 | resolution: Union[Tuple[int, int], int] = 256) -> None: 31 | transform = T.Compose([ 32 | T.Resize(resolution), 33 | T.RandomCrop(resolution), 34 | T.RandomHorizontalFlip(), 35 | T.ToTensor() 36 | ]) 37 | 38 | if classes not in ['train', 'val']: 39 | if not isinstance(classes, list): 40 | classes = [classes] 41 | 42 | classes = [class_+"_train" for class_ in classes] 43 | else: 44 | assert classes == 'train' 45 | 46 | super().__init__(root, classes, transform) 47 | 48 | 49 | class LSUNValidation(LSUNBase): 50 | def __init__(self, root: str, classes: Union[Tuple[str, str]], 51 | resolution: Union[Tuple[int, int], int] = 256) -> None: 52 | transform = T.Compose([ 53 | T.Resize(resolution), 54 | T.CenterCrop(resolution), 55 | T.ToTensor() 56 | ]) 57 | 58 | if classes not in ['train', 'val']: 59 | if not isinstance(classes, list): 60 | classes = [classes] 61 | 62 | classes = [class_+"_val" for class_ in 
classes] 63 | else: 64 | assert classes == 'val' 65 | 66 | super().__init__(root, classes, transform) 67 | -------------------------------------------------------------------------------- /enhancing/dataloader/srimage.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from DALLE-pytorch (https://github.com/lucidrains/DALLE-pytorch) 7 | # Copyright (c) 2020 Phil Wang. All Rights Reserved. 8 | # ------------------------------------------------------------------------------------ 9 | 10 | from typing import Optional, Tuple, Callable, Union 11 | from pathlib import Path 12 | from random import randint, choice 13 | from omegaconf import OmegaConf 14 | import PIL 15 | 16 | from torch import nn 17 | from torch.utils.data import Dataset 18 | from torchvision import transforms as T 19 | 20 | 21 | class SRBase(Dataset): 22 | def __init__(self, folder: str, split: str, transform: Callable) -> None: 23 | super().__init__() 24 | path = Path(folder)/split 25 | 26 | image_files = [ 27 | *path.glob('**/*.png'), *path.glob('**/*.jpg'), 28 | *path.glob('**/*.jpeg'), *path.glob('**/*.bmp') 29 | ] 30 | 31 | image_files = {image_file.stem: image_file for image_file in image_files} 32 | keys = image_files.keys() 33 | 34 | self.keys = list(keys) 35 | self.image_files = {k: v for k, v in image_files.items() if k in keys} 36 | 37 | self.hr_transform = transform 38 | 39 | def __len__(self): 40 | return len(self.keys) 41 | 42 | def random_sample(self): 43 | return self.__getitem__(randint(0, self.__len__() - 1)) 44 | 45 | def sequential_sample(self, ind): 46 | if ind >= self.__len__() - 1: 47 | return self.__getitem__(0) 48 | return self.__getitem__(ind + 1) 49 | 50 | def skip_sample(self, ind): 51 | return self.sequential_sample(ind=ind) 52 | 53 | def pad(self, img: PIL.Image.Image) -> PIL.Image.Image: 54 | if isinstance(self.resolution, int): 55 | self.resolution = (self.resolution, self.resolution) 56 | 57 | assert img.size[0] <= self.resolution[1] and img.size[1] <= self.resolution[0] 58 | left = (self.resolution[1] - img.size[0]) // 2 59 | top = (self.resolution[0] - img.size[1]) // 2 60 | right = self.resolution[1] - img.size[0] - left 61 | bottom = self.resolution[0] - img.size[1] - top 62 | 63 | return T.functional.pad(img, (left, top, right, bottom)) 64 | 65 | def __getitem__(self, ind): 66 | key = self.keys[ind] 67 | image_file = self.image_files[key] 68 | 69 | try: 70 | hr_img = PIL.Image.open(image_file) 71 | if hr_img.mode != 'RGB': 72 | hr_img = hr_img.convert('RGB') 73 | 74 | hr_tensor = self.hr_transform(hr_img) 75 | 76 | down_size = (hr_tensor.shape[1]//self.downscale, hr_tensor.shape[2]//self.downscale) 77 | lr_tensor = T.Resize(down_size, 3)(hr_tensor) 78 | except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions: 79 | print(f"An exception occurred trying to load file {image_file}.") 80 | print(f"Skipping index {ind}") 81 | return self.skip_sample(ind) 82 | 83 | # Success 84 | return {'low resolution': lr_tensor, 'high resolution': hr_tensor} 85 | 86 | 87 | class SRTrain(SRBase): 88 | def __init__(self, folder: str, 89 | resolution: Union[Tuple[int, int], int] = 2048, 90 | crop_resolution: 
Union[Tuple[int, int], int] = 512, 91 | downscale: int = 4) -> None: 92 | assert resolution % downscale == 0 93 | self.resolution = resolution 94 | self.downscale = downscale 95 | 96 | transform = T.Compose([ 97 | T.RandomCrop(crop_resolution), 98 | T.Lambda(self.pad), 99 | T.RandomHorizontalFlip(), 100 | T.RandomVerticalFlip(), 101 | T.ToTensor() 102 | ]) 103 | 104 | super().__init__(folder, 'train', transform) 105 | 106 | 107 | class SRValidation(SRBase): 108 | def __init__(self, folder: str, 109 | resolution: Union[Tuple[int, int], int] = 2048, 110 | downscale: int = 4) -> None: 111 | assert resolution % downscale == 0 112 | self.resolution = resolution 113 | self.downscale = downscale 114 | 115 | transform = T.Compose([ 116 | T.Lambda(self.pad), 117 | T.ToTensor() 118 | ]) 119 | 120 | super().__init__(folder, 'val', transform) 121 | 122 | -------------------------------------------------------------------------------- /enhancing/dataloader/textimage.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from DALLE-pytorch (https://github.com/lucidrains/DALLE-pytorch) 7 | # Copyright (c) 2020 Phil Wang. All Rights Reserved. 8 | # ------------------------------------------------------------------------------------ 9 | 10 | from typing import Optional, Union, Callable, Tuple, Any 11 | from pathlib import Path 12 | from random import randint, choice 13 | from omegaconf import OmegaConf 14 | import PIL 15 | 16 | import torch 17 | from torch.utils.data import Dataset 18 | from torchvision import transforms as T 19 | 20 | from ..utils.general import initialize_from_config 21 | 22 | 23 | class TextImageBase(Dataset): 24 | def __init__(self, folder: str, split: str, 25 | tokenizer: OmegaConf, 26 | transform: Callable) -> None: 27 | super().__init__() 28 | path = Path(folder)/split 29 | 30 | text_files = [*path.glob('**/*.txt')] 31 | image_files = [ 32 | *path.glob('**/*.png'), *path.glob('**/*.jpg'), 33 | *path.glob('**/*.jpeg'), *path.glob('**/*.bmp') 34 | ] 35 | 36 | text_files = {text_file.stem: text_file for text_file in text_files} 37 | image_files = {image_file.stem: image_file for image_file in image_files} 38 | 39 | keys = (image_files.keys() & text_files.keys()) 40 | 41 | self.keys = list(keys) 42 | self.text_files = {k: v for k, v in text_files.items() if k in keys} 43 | self.image_files = {k: v for k, v in image_files.items() if k in keys} 44 | self.tokenizer = initialize_from_config(tokenizer) 45 | self.image_transform = transform 46 | 47 | def __len__(self) -> int: 48 | return len(self.keys) 49 | 50 | def random_sample(self) -> Tuple[Any, Any]: 51 | return self.__getitem__(randint(0, self.__len__() - 1)) 52 | 53 | def sequential_sample(self, ind: int) -> Tuple[Any, Any]: 54 | if ind >= self.__len__() - 1: 55 | return self.__getitem__(0) 56 | return self.__getitem__(ind + 1) 57 | 58 | def skip_sample(self, ind: int) -> Tuple[Any, Any]: 59 | return self.sequential_sample(ind=ind) 60 | 61 | def __getitem__(self, ind: int) -> Tuple[Any, Any]: 62 | key = self.keys[ind] 63 | 64 | text_file = self.text_files[key] 65 | image_file = self.image_files[key] 66 | 67 | descriptions = text_file.read_text().split('\n') 68 
| descriptions = list(filter(lambda t: len(t) > 0, descriptions)) 69 | 70 | try: 71 | description = choice(descriptions) 72 | except IndexError as zero_captions_in_file_ex: 73 | print(f"An exception occurred trying to load file {text_file}.") 74 | print(f"Skipping index {ind}") 75 | return self.skip_sample(ind) 76 | 77 | tokenized_text = self.tokenizer.tokenize(description).squeeze(0) 78 | try: 79 | image = PIL.Image.open(image_file) 80 | if image.mode != 'RGB': 81 | image = image.convert('RGB') 82 | image_tensor = self.image_transform(image) 83 | except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions: 84 | print(f"An exception occurred trying to load file {image_file}.") 85 | print(f"Skipping index {ind}") 86 | return self.skip_sample(ind) 87 | 88 | # Success 89 | return {"caption": tokenized_text, "image": image_tensor} 90 | 91 | 92 | class TextImageTrain(TextImageBase): 93 | def __init__(self, folder: str, 94 | tokenizer: OmegaConf, 95 | resolution: Union[Tuple[int, int], int] = 256) -> None: 96 | transform = T.Compose([ 97 | T.Resize(resolution), 98 | T.RandomCrop(resolution), 99 | T.ToTensor(), 100 | ]) 101 | 102 | super().__init__(folder, 'train', tokenizer, transform) 103 | 104 | 105 | class TextImageValidation(TextImageBase): 106 | def __init__(self, folder: str, 107 | tokenizer: OmegaConf, 108 | resolution: Union[Tuple[int, int], int] = 256) -> None: 109 | if isinstance(resolution, int): 110 | resolution = [resolution, resolution] 111 | 112 | transform = T.Compose([ 113 | T.Resize(resolution), 114 | T.CenterCrop(resolution), 115 | T.ToTensor(), 116 | ]) 117 | 118 | super().__init__(folder, 'val', tokenizer, transform) 119 | -------------------------------------------------------------------------------- /enhancing/losses/layers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers) 3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | # Modified from StyleGAN2-Pytorch (https://github.com/rosinality/stylegan2-pytorch) 6 | # Copyright (c) 2019 Kim Seonghyeon. All Rights Reserved. 7 | # ------------------------------------------------------------------------------------ 8 | 9 | 10 | from math import log2, sqrt 11 | from functools import partial 12 | from typing import Optional, Union, Tuple, List 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from kornia.filters import filter2d 18 | 19 | from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix 20 | 21 | 22 | def hinge_d_loss(logits_fake: torch.FloatTensor, logits_real: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: 23 | loss_fake = - logits_fake.mean() * 2 if logits_real is None else F.relu(1. + logits_fake).mean() 24 | loss_real = 0 if logits_real is None else F.relu(1. 
- logits_real).mean() 25 | 26 | return 0.5 * (loss_real + loss_fake) 27 | 28 | 29 | def vanilla_d_loss(logits_fake: torch.FloatTensor, logits_real: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: 30 | loss_fake = F.softplus(-logits_fake).mean() * 2 if logits_real is None else F.softplus(logits_fake).mean() 31 | loss_real = 0 if logits_real is None else F.softplus(-logits_real).mean() 32 | 33 | return 0.5 * (loss_real + loss_fake) 34 | 35 | 36 | def least_square_d_loss(logits_fake: torch.FloatTensor, logits_real: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: 37 | loss_fake = logits_fake.pow(2).mean() * 2 if logits_real is None else (1 + logits_fake).pow(2).mean() 38 | loss_real = 0 if logits_real is None else (1 - logits_real).pow(2).mean() 39 | 40 | return 0.5 * (loss_real + loss_fake) 41 | 42 | 43 | def weights_init(m: nn.Module) -> None: 44 | classname = m.__class__.__name__ 45 | if classname.find('Conv') != -1: 46 | nn.init.normal_(m.weight.data, 0.0, 0.02) 47 | elif classname.find('BatchNorm') != -1: 48 | nn.init.normal_(m.weight.data, 1.0, 0.02) 49 | nn.init.constant_(m.bias.data, 0) 50 | 51 | 52 | class ActNorm(nn.Module): 53 | def __init__(self, num_features: int, 54 | logdet: Optional[bool] = False, 55 | affine: Optional[bool] = True, 56 | allow_reverse_init: Optional[bool] = False) -> None: 57 | assert affine 58 | super().__init__() 59 | self.logdet = logdet 60 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 61 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 62 | self.allow_reverse_init = allow_reverse_init 63 | 64 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 65 | 66 | def initialize(self, input: torch.FloatTensor) -> None: 67 | with torch.no_grad(): 68 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 69 | mean = ( 70 | flatten.mean(1) 71 | .unsqueeze(1) 72 | .unsqueeze(2) 73 | .unsqueeze(3) 74 | .permute(1, 0, 2, 3) 75 | ) 76 | std = ( 77 | flatten.std(1) 78 | .unsqueeze(1) 79 | .unsqueeze(2) 80 | .unsqueeze(3) 81 | .permute(1, 0, 2, 3) 82 | ) 83 | 84 | self.loc.data.copy_(-mean) 85 | self.scale.data.copy_(1 / (std + 1e-6)) 86 | 87 | def forward(self, input: torch.FloatTensor, reverse: Optional[bool] = False) -> Union[torch.FloatTensor, Tuple]: 88 | if reverse: 89 | return self.reverse(input) 90 | if len(input.shape) == 2: 91 | input = input[:,:,None,None] 92 | squeeze = True 93 | else: 94 | squeeze = False 95 | 96 | _, _, height, width = input.shape 97 | 98 | if self.training and self.initialized.item() == 0: 99 | self.initialize(input) 100 | self.initialized.fill_(1) 101 | 102 | h = self.scale * (input + self.loc) 103 | 104 | if squeeze: 105 | h = h.squeeze(-1).squeeze(-1) 106 | 107 | if self.logdet: 108 | log_abs = torch.log(torch.abs(self.scale)) 109 | logdet = height*width*torch.sum(log_abs) 110 | logdet = logdet * torch.ones(input.shape[0]).to(input) 111 | return h, logdet 112 | 113 | return h 114 | 115 | def reverse(self, output: torch.FloatTensor) -> torch.FloatTensor: 116 | if self.training and self.initialized.item() == 0: 117 | if not self.allow_reverse_init: 118 | raise RuntimeError( 119 | "Initializing ActNorm in reverse direction is " 120 | "disabled by default. Use allow_reverse_init=True to enable." 
121 | ) 122 | else: 123 | self.initialize(output) 124 | self.initialized.fill_(1) 125 | 126 | if len(output.shape) == 2: 127 | output = output[:,:,None,None] 128 | squeeze = True 129 | else: 130 | squeeze = False 131 | 132 | h = output / self.scale - self.loc 133 | 134 | if squeeze: 135 | h = h.squeeze(-1).squeeze(-1) 136 | 137 | return h 138 | 139 | 140 | class Blur(nn.Module): 141 | def __init__(self, kernel, pad, upsample_factor=1): 142 | super().__init__() 143 | 144 | kernel = torch.tensor(kernel, dtype=torch.float32) 145 | if kernel.ndim == 1: 146 | kernel = kernel[None, :] * kernel[:, None] 147 | 148 | kernel /= kernel.sum() 149 | 150 | if upsample_factor > 1: 151 | kernel = kernel * (upsample_factor ** 2) 152 | 153 | self.register_buffer("kernel", kernel) 154 | 155 | self.pad = pad 156 | 157 | def forward(self, input): 158 | out = upfirdn2d(input, self.kernel, pad=self.pad) 159 | 160 | return out 161 | 162 | 163 | class EqualConv2d(nn.Module): 164 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): 165 | super().__init__() 166 | 167 | self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) 168 | self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None 169 | 170 | self.scale = 1 / sqrt(in_channel * kernel_size ** 2) 171 | 172 | self.stride = stride 173 | self.padding = padding 174 | 175 | def forward(self, input): 176 | out = conv2d_gradfix.conv2d( 177 | input, 178 | self.weight * self.scale, 179 | bias=self.bias, 180 | stride=self.stride, 181 | padding=self.padding, 182 | ) 183 | 184 | return out 185 | 186 | 187 | class EqualLinear(nn.Module): 188 | def __init__( 189 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 190 | ): 191 | super().__init__() 192 | 193 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 194 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) if bias else None 195 | 196 | self.activation = activation 197 | 198 | self.scale = (1 / sqrt(in_dim)) * lr_mul 199 | self.lr_mul = lr_mul 200 | 201 | def forward(self, input): 202 | if self.activation: 203 | out = F.linear(input, self.weight * self.scale) 204 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 205 | 206 | else: 207 | out = F.linear( 208 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 209 | ) 210 | 211 | return out 212 | 213 | 214 | class ConvLayer(nn.Sequential): 215 | def __init__(self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True): 216 | layers = [] 217 | 218 | if downsample: 219 | factor = 2 220 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 221 | pad0 = (p + 1) // 2 222 | pad1 = p // 2 223 | 224 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 225 | 226 | stride = 2 227 | self.padding = 0 228 | else: 229 | stride = 1 230 | self.padding = kernel_size // 2 231 | 232 | layers.append( 233 | EqualConv2d( 234 | in_channel, out_channel, 235 | kernel_size, padding=self.padding, 236 | stride=stride, bias=bias and not activate 237 | ) 238 | ) 239 | 240 | if activate: 241 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 242 | 243 | super().__init__(*layers) 244 | 245 | 246 | class StyleBlock(nn.Module): 247 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 248 | super().__init__() 249 | 250 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 251 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 252 | 253 | self.skip = ConvLayer( 
254 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 255 | ) 256 | 257 | def forward(self, input): 258 | out = self.conv1(input) 259 | out = self.conv2(out) 260 | 261 | skip = self.skip(input) 262 | out = (out + skip) / sqrt(2) 263 | 264 | return out 265 | 266 | 267 | class PatchDiscriminator(nn.Module): 268 | """Defines a PatchGAN discriminator as in Pix2Pix 269 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 270 | """ 271 | def __init__(self, input_nc: int = 3, ndf: int = 64, n_layers: int = 3, use_actnorm: bool = False) -> None: 272 | """Construct a PatchGAN discriminator 273 | Parameters: 274 | input_nc (int) -- the number of channels in input images 275 | ndf (int) -- the number of filters in the last conv layer 276 | n_layers (int) -- the number of conv layers in the discriminator 277 | norm_layer -- normalization layer 278 | """ 279 | super().__init__() 280 | if not use_actnorm: 281 | norm_layer = nn.BatchNorm2d 282 | else: 283 | norm_layer = ActNorm 284 | if type(norm_layer) == partial: # no need to use bias as BatchNorm2d has affine parameters 285 | use_bias = norm_layer.func != nn.BatchNorm2d 286 | else: 287 | use_bias = norm_layer != nn.BatchNorm2d 288 | 289 | kw = 4 290 | padw = 1 291 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 292 | nf_mult = 1 293 | nf_mult_prev = 1 294 | for n in range(1, n_layers): # gradually increase the number of filters 295 | nf_mult_prev = nf_mult 296 | nf_mult = min(2 ** n, 8) 297 | sequence += [ 298 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 299 | norm_layer(ndf * nf_mult), 300 | nn.LeakyReLU(0.2, True) 301 | ] 302 | 303 | nf_mult_prev = nf_mult 304 | nf_mult = min(2 ** n_layers, 8) 305 | sequence += [ 306 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 307 | norm_layer(ndf * nf_mult), 308 | nn.LeakyReLU(0.2, True) 309 | ] 310 | 311 | sequence += [ 312 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 313 | self.main = nn.Sequential(*sequence) 314 | 315 | self.apply(weights_init) 316 | 317 | def forward(self, input: torch.FloatTensor) -> torch.FloatTensor: 318 | """Standard forward.""" 319 | return self.main(input) 320 | 321 | 322 | class StyleDiscriminator(nn.Module): 323 | def __init__(self, size=256, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 324 | super().__init__() 325 | 326 | channels = { 327 | 4: 512, 328 | 8: 512, 329 | 16: 512, 330 | 32: 512, 331 | 64: 256 * channel_multiplier, 332 | 128: 128 * channel_multiplier, 333 | 256: 64 * channel_multiplier, 334 | 512: 32 * channel_multiplier, 335 | 1024: 16 * channel_multiplier, 336 | } 337 | 338 | log_size = int(log2(size)) 339 | in_channel = channels[size] 340 | 341 | blocks = [ConvLayer(3, channels[size], 1)] 342 | for i in range(log_size, 2, -1): 343 | out_channel = channels[2 ** (i - 1)] 344 | blocks.append(StyleBlock(in_channel, out_channel, blur_kernel)) 345 | in_channel = out_channel 346 | 347 | self.blocks = nn.Sequential(*blocks) 348 | 349 | self.stddev_group = 4 350 | self.stddev_feat = 1 351 | 352 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 353 | self.final_linear = nn.Sequential( 354 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 355 | EqualLinear(channels[4], 1), 356 | ) 357 | 358 | def forward(self, x): 359 | out = self.blocks(x) 360 
| batch, channel, height, width = out.shape 361 | 362 | group = min(batch, self.stddev_group) 363 | group = batch//(batch//group) 364 | 365 | stddev = out.view( 366 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 367 | ) 368 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 369 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 370 | stddev = stddev.repeat(group, 1, height, width) 371 | out = torch.cat([out, stddev], 1) 372 | 373 | out = self.final_conv(out) 374 | out = out.view(out.shape[0], -1) 375 | out = self.final_linear(out) 376 | 377 | return out.squeeze() 378 | -------------------------------------------------------------------------------- /enhancing/losses/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /enhancing/losses/op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 
90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | 
transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /enhancing/losses/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input.contiguous(), 51 | gradgrad_bias, 52 | out, 53 | 3, 54 | 1, 55 | ctx.negative_slope, 56 | ctx.scale, 57 | ) 58 | 59 | return gradgrad_out, None, None, None, None 60 | 61 | 62 | class FusedLeakyReLUFunction(Function): 63 | @staticmethod 64 | def forward(ctx, input, bias, negative_slope, scale): 65 | empty = input.new_empty(0) 66 | 67 | ctx.bias = bias is not None 68 | 69 | if bias is None: 70 | bias = empty 71 | 72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 73 | ctx.save_for_backward(out) 74 | ctx.negative_slope = negative_slope 75 | ctx.scale = scale 76 | 77 | return out 78 | 79 | @staticmethod 80 | def backward(ctx, grad_output): 81 | out, = ctx.saved_tensors 82 | 83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 84 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 85 | ) 86 | 87 | if not ctx.bias: 88 | grad_bias = None 89 | 90 | return grad_input, grad_bias, None, None 91 | 92 | 93 | class FusedLeakyReLU(nn.Module): 94 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 95 | super().__init__() 96 | 97 | if bias: 98 | self.bias = nn.Parameter(torch.zeros(channel)) 99 | 100 | else: 101 | self.bias = None 102 | 103 | self.negative_slope = negative_slope 104 | self.scale = scale 105 | 106 | def forward(self, input): 107 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 108 | 109 | 110 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 111 | if input.device.type == "cpu": 112 | if bias is not None: 113 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 114 | return ( 115 | F.leaky_relu( 116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 117 | ) 118 | * scale 119 | ) 120 | 121 | else: 122 | return F.leaky_relu(input, negative_slope=0.2) * 
scale 123 | 124 | else: 125 | return FusedLeakyReLUFunction.apply( 126 | input.contiguous(), bias, negative_slope, scale 127 | ) 128 | -------------------------------------------------------------------------------- /enhancing/losses/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 6 | const torch::Tensor &bias, 7 | const torch::Tensor &refer, int act, int grad, 8 | float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) \ 11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | torch::Tensor fused_bias_act(const torch::Tensor &input, 19 | const torch::Tensor &bias, 20 | const torch::Tensor &refer, int act, int grad, 21 | float alpha, float scale) { 22 | CHECK_INPUT(input); 23 | CHECK_INPUT(bias); 24 | 25 | at::DeviceGuard guard(input.device()); 26 | 27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 28 | } 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 32 | } -------------------------------------------------------------------------------- /enhancing/losses/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | #include 16 | #include 17 | 18 | template 19 | static __global__ void 20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b, 21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha, 22 | scalar_t scale, int loop_x, int size_x, int step_b, 23 | int size_b, int use_bias, int use_ref) { 24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 25 | 26 | scalar_t zero = 0.0; 27 | 28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; 29 | loop_idx++, xi += blockDim.x) { 30 | scalar_t x = p_x[xi]; 31 | 32 | if (use_bias) { 33 | x += p_b[(xi / step_b) % size_b]; 34 | } 35 | 36 | scalar_t ref = use_ref ? p_ref[xi] : zero; 37 | 38 | scalar_t y; 39 | 40 | switch (act * 10 + grad) { 41 | default: 42 | case 10: 43 | y = x; 44 | break; 45 | case 11: 46 | y = x; 47 | break; 48 | case 12: 49 | y = 0.0; 50 | break; 51 | 52 | case 30: 53 | y = (x > 0.0) ? x : x * alpha; 54 | break; 55 | case 31: 56 | y = (ref > 0.0) ? x : x * alpha; 57 | break; 58 | case 32: 59 | y = 0.0; 60 | break; 61 | } 62 | 63 | out[xi] = y * scale; 64 | } 65 | } 66 | 67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 68 | const torch::Tensor &bias, 69 | const torch::Tensor &refer, int act, int grad, 70 | float alpha, float scale) { 71 | int curDevice = -1; 72 | cudaGetDevice(&curDevice); 73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 74 | 75 | auto x = input.contiguous(); 76 | auto b = bias.contiguous(); 77 | auto ref = refer.contiguous(); 78 | 79 | int use_bias = b.numel() ? 1 : 0; 80 | int use_ref = ref.numel() ? 
1 : 0; 81 | 82 | int size_x = x.numel(); 83 | int size_b = b.numel(); 84 | int step_b = 1; 85 | 86 | for (int i = 1 + 1; i < x.dim(); i++) { 87 | step_b *= x.size(i); 88 | } 89 | 90 | int loop_x = 4; 91 | int block_size = 4 * 32; 92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 93 | 94 | auto y = torch::empty_like(x); 95 | 96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 97 | x.scalar_type(), "fused_bias_act_kernel", [&] { 98 | fused_bias_act_kernel<<>>( 99 | y.data_ptr(), x.data_ptr(), 100 | b.data_ptr(), ref.data_ptr(), act, grad, alpha, 101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref); 102 | }); 103 | 104 | return y; 105 | } -------------------------------------------------------------------------------- /enhancing/losses/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 5 | const torch::Tensor &kernel, int up_x, int up_y, 6 | int down_x, int down_y, int pad_x0, int pad_x1, 7 | int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) \ 10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 18 | int up_x, int up_y, int down_x, int down_y, int pad_x0, 19 | int pad_x1, int pad_y0, int pad_y1) { 20 | CHECK_INPUT(input); 21 | CHECK_INPUT(kernel); 22 | 23 | at::DeviceGuard guard(input.device()); 24 | 25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, 26 | pad_y0, pad_y1); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 31 | } -------------------------------------------------------------------------------- /enhancing/losses/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def 
backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : 
out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /enhancing/losses/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * 
p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for 
(int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 
| case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /enhancing/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers) 3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BCELoss(nn.Module): 11 | def forward(self, prediction, target): 12 | loss = F.binary_cross_entropy_with_logits(prediction,target) 13 | 14 | return loss, {} 15 | 16 | 17 | class BCELossWithQuant(nn.Module): 18 | def __init__(self, codebook_weight=1.): 19 | super().__init__() 20 | self.codebook_weight = codebook_weight 21 | 22 | def forward(self, qloss, target, prediction, split): 23 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 24 | loss = bce_loss + self.codebook_weight*qloss 25 | 26 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 27 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 28 | "{}/quant_loss".format(split): qloss.detach().mean() 29 | } 30 | 31 | return loss, log 32 | -------------------------------------------------------------------------------- /enhancing/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers) 3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 
4 | # ------------------------------------------------------------------------------------ 5 | 6 | from omegaconf import OmegaConf 7 | from typing import Optional, Tuple 8 | 9 | import lpips 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .layers import * 15 | 16 | 17 | class DummyLoss(nn.Module): 18 | def __init__(self) -> None: 19 | super().__init__() 20 | 21 | 22 | class VQLPIPS(nn.Module): 23 | def __init__(self, codebook_weight: float = 1.0, 24 | loglaplace_weight: float = 1.0, 25 | loggaussian_weight: float = 1.0, 26 | perceptual_weight: float = 1.0) -> None: 27 | 28 | super().__init__() 29 | self.perceptual_loss = lpips.LPIPS(net="vgg", verbose=False) 30 | 31 | self.codebook_weight = codebook_weight 32 | self.loglaplace_weight = loglaplace_weight 33 | self.loggaussian_weight = loggaussian_weight 34 | self.perceptual_weight = perceptual_weight 35 | 36 | def forward(self, codebook_loss: torch.FloatTensor, inputs: torch.FloatTensor, reconstructions: torch.FloatTensor, optimizer_idx: int, 37 | global_step: int, batch_idx: int, last_layer: Optional[nn.Module] = None, split: Optional[str] = "train") -> Tuple: 38 | inputs = inputs.contiguous() 39 | reconstructions = reconstructions.contiguous() 40 | 41 | loglaplace_loss = (reconstructions - inputs).abs().mean() 42 | loggaussian_loss = (reconstructions - inputs).pow(2).mean() 43 | perceptual_loss = self.perceptual_loss(inputs*2-1, reconstructions*2-1).mean() 44 | 45 | nll_loss = self.loglaplace_weight * loglaplace_loss + self.loggaussian_weight * loggaussian_loss + self.perceptual_weight * perceptual_loss 46 | loss = nll_loss + self.codebook_weight * codebook_loss 47 | 48 | log = {"{}/total_loss".format(split): loss.clone().detach(), 49 | "{}/quant_loss".format(split): codebook_loss.detach(), 50 | "{}/rec_loss".format(split): nll_loss.detach(), 51 | "{}/loglaplace_loss".format(split): loglaplace_loss.detach(), 52 | "{}/loggaussian_loss".format(split): loggaussian_loss.detach(), 53 | "{}/perceptual_loss".format(split): perceptual_loss.detach() 54 | } 55 | 56 | return loss, log 57 | 58 | 59 | class VQLPIPSWithDiscriminator(nn.Module): 60 | def __init__(self, disc_start: int = 0, 61 | disc_loss: str = 'vanilla', 62 | disc_params: Optional[OmegaConf] = dict(), 63 | codebook_weight: float = 1.0, 64 | loglaplace_weight: float = 1.0, 65 | loggaussian_weight: float = 1.0, 66 | perceptual_weight: float = 1.0, 67 | adversarial_weight: float = 1.0, 68 | use_adaptive_adv: bool = False, 69 | r1_gamma: float = 10, 70 | do_r1_every: int = 16) -> None: 71 | 72 | super().__init__() 73 | assert disc_loss in ["hinge", "vanilla", "least_square"], f"Unknown GAN loss '{disc_loss}'." 
74 | self.perceptual_loss = lpips.LPIPS(net="vgg", verbose=False) 75 | 76 | self.codebook_weight = codebook_weight 77 | self.loglaplace_weight = loglaplace_weight 78 | self.loggaussian_weight = loggaussian_weight 79 | self.perceptual_weight = perceptual_weight 80 | 81 | self.discriminator = StyleDiscriminator(**disc_params) 82 | self.discriminator_iter_start = disc_start 83 | if disc_loss == "hinge": 84 | self.disc_loss = hinge_d_loss 85 | elif disc_loss == "vanilla": 86 | self.disc_loss = vanilla_d_loss 87 | elif disc_loss == "least_square": 88 | self.disc_loss = least_square_d_loss 89 | 90 | self.adversarial_weight = adversarial_weight 91 | self.use_adaptive_adv = use_adaptive_adv 92 | self.r1_gamma = r1_gamma 93 | self.do_r1_every = do_r1_every 94 | 95 | def calculate_adaptive_factor(self, nll_loss: torch.FloatTensor, 96 | g_loss: torch.FloatTensor, last_layer: nn.Module) -> torch.FloatTensor: 97 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 98 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 99 | 100 | adapt_factor = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 101 | adapt_factor = adapt_factor.clamp(0.0, 1e4).detach() 102 | 103 | return adapt_factor 104 | 105 | def forward(self, codebook_loss: torch.FloatTensor, inputs: torch.FloatTensor, reconstructions: torch.FloatTensor, optimizer_idx: int, 106 | global_step: int, batch_idx: int, last_layer: Optional[nn.Module] = None, split: Optional[str] = "train") -> Tuple: 107 | inputs = inputs.contiguous() 108 | reconstructions = reconstructions.contiguous() 109 | 110 | # now the GAN part 111 | if optimizer_idx == 0: 112 | # generator update 113 | loglaplace_loss = (reconstructions - inputs).abs().mean() 114 | loggaussian_loss = (reconstructions - inputs).pow(2).mean() 115 | perceptual_loss = self.perceptual_loss(inputs*2-1, reconstructions*2-1).mean() 116 | 117 | nll_loss = self.loglaplace_weight * loglaplace_loss + self.loggaussian_weight * loggaussian_loss + self.perceptual_weight * perceptual_loss 118 | 119 | logits_fake = self.discriminator(reconstructions) 120 | g_loss = self.disc_loss(logits_fake) 121 | 122 | try: 123 | d_weight = self.adversarial_weight 124 | 125 | if self.use_adaptive_adv: 126 | d_weight *= self.calculate_adaptive_factor(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = 1 if global_step >= self.discriminator_iter_start else 0 132 | loss = nll_loss + disc_factor * d_weight * g_loss + self.codebook_weight * codebook_loss 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach(), 135 | "{}/quant_loss".format(split): codebook_loss.detach(), 136 | "{}/rec_loss".format(split): nll_loss.detach(), 137 | "{}/loglaplace_loss".format(split): loglaplace_loss.detach(), 138 | "{}/loggaussian_loss".format(split): loggaussian_loss.detach(), 139 | "{}/perceptual_loss".format(split): perceptual_loss.detach(), 140 | "{}/g_loss".format(split): g_loss.detach(), 141 | } 142 | 143 | if self.use_adaptive_adv: 144 | log["{}/d_weight".format(split)] = d_weight.detach() 145 | 146 | return loss, log 147 | 148 | if optimizer_idx == 1: 149 | # second pass for discriminator update 150 | disc_factor = 1 if global_step >= self.discriminator_iter_start else 0 151 | do_r1 = self.training and bool(disc_factor) and batch_idx % self.do_r1_every == 0 152 | 153 | logits_real = self.discriminator(inputs.requires_grad_(do_r1)) 154 | logits_fake = 
self.discriminator(reconstructions.detach()) 155 | 156 | d_loss = disc_factor * self.disc_loss(logits_fake, logits_real) 157 | if do_r1: 158 | with conv2d_gradfix.no_weight_gradients(): 159 | gradients, = torch.autograd.grad(outputs=logits_real.sum(), inputs=inputs, create_graph=True) 160 | 161 | gradients_norm = gradients.square().sum([1,2,3]).mean() 162 | d_loss += self.r1_gamma * self.do_r1_every * gradients_norm/2 163 | 164 | log = {"{}/disc_loss".format(split): d_loss.detach(), 165 | "{}/logits_real".format(split): logits_real.detach().mean(), 166 | "{}/logits_fake".format(split): logits_fake.detach().mean(), 167 | } 168 | 169 | if do_r1: 170 | log["{}/r1_reg".format(split)] = gradients_norm.detach() 171 | 172 | return d_loss, log 173 | -------------------------------------------------------------------------------- /enhancing/modules/cond/clipcond.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | from omegaconf import OmegaConf 8 | from typing import Tuple, Union, List, Any 9 | 10 | import clip 11 | import torch 12 | import torch.nn as nn 13 | from torchvision import transforms as T 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | from .dummycond import DummyCond 17 | from ...utils.general import initialize_from_config 18 | 19 | 20 | class ClipTextCond(DummyCond): 21 | def __init__(self, image_size: Union[Tuple[int, int], int], 22 | clip_model: str, tokenizer: OmegaConf) -> None: 23 | super().__init__() 24 | self.image_size = image_size 25 | self.clip_model, _ = clip.load(clip_model, device=device) 26 | self.tokenizer = initialize_from_config(tokenizer) 27 | 28 | def encode_codes(self, text: torch.LongTensor) -> torch.FloatTensor: 29 | with torch.no_grad(): 30 | text_features = model.encode_text(text) 31 | 32 | return text_features 33 | 34 | def to_img(self, texts: torch.LongTensor) -> torch.FloatTensor: 35 | W, H = self.image_size if isinstance(self.image_size, tuple) else (self.image_size, self.image_size) 36 | font = ImageFont.truetype("arial.ttf", 12) 37 | 38 | imgs = [] 39 | for text in texts: 40 | text = self.tokenizer.decode(text) 41 | words = text.split() 42 | length = 0 43 | 44 | for idx, word in enumerate(words): 45 | if length > 27: 46 | length = 0 47 | word[idx-int(idx>0)] += '\n' 48 | 49 | length += len(word) 50 | 51 | img = Image.new("RGBA", (W, H), "white") 52 | draw = ImageDraw.Draw(img) 53 | 54 | w, h = draw.textsize(text, font) 55 | draw.text(((W-w)/2,(H-h)/2), text, font=font, fill="black", align="center") 56 | 57 | img = img.convert('RGB') 58 | img = T.ToTensor()(img) 59 | imgs.append(img) 60 | 61 | return torch.stack(imgs, dim=0) 62 | 63 | 64 | class ClipImageCond(DummyCond): 65 | def __init__(self, clip_model: str) -> None: 66 | super().__init__() 67 | self.clip_model, _ = clip.load(clip_model, device=device) 68 | 69 | def encode_codes(self, image: torch.FloatTensor) -> torch.FloatTensor: 70 | with torch.no_grad(): 71 | image_features = model.encode_image(image) 72 | 73 | return image_features 74 | 75 | def to_img(self, image: torch.FloatTensor) -> torch.FloatTensor: 76 | return image.clamp(0, 1) 77 | 
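Back in `VQLPIPSWithDiscriminator` above, the generator branch scales the adversarial term by an adaptive factor: the gradient norm of the reconstruction loss at the decoder's last layer divided by the gradient norm of the GAN loss, clamped to [0, 1e4]. The discriminator branch adds a lazy R1 penalty every `do_r1_every` steps, multiplying `r1_gamma` by that interval so the effective regularization stays constant. A standalone sketch of the adaptive factor, with a random parameter standing in for `ViTDecoder.get_last_layer()` and made-up losses:

```python
import torch

# stand-in for the decoder's final ConvTranspose2d weight
last_layer = torch.nn.Parameter(torch.randn(16, 3, 8, 8))

nll_loss = (last_layer ** 2).mean()      # pretend reconstruction (NLL) loss
g_loss = last_layer.abs().sum() * 0.1    # pretend generator GAN loss

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

d_weight = (torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)).clamp(0.0, 1e4).detach()
print(d_weight)   # multiplies adversarial_weight when use_adaptive_adv=True
```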
-------------------------------------------------------------------------------- /enhancing/modules/cond/dummycond.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import os 8 | from omegaconf import OmegaConf 9 | from typing import Tuple, Union, List, Any 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torchvision import transforms as T 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | from ...utils.general import initialize_from_config 17 | 18 | 19 | class DummyCond(nn.Module): 20 | def __init__(self) -> None: 21 | super().__init__() 22 | 23 | def encode(self, condition: Any) -> Tuple[Any, Any, Any]: 24 | return condition, None, condition 25 | 26 | def decode(self, condition: Any) -> Any: 27 | return condition 28 | 29 | def encode_codes(self, condition: Any) -> Any: 30 | return condition 31 | 32 | def decode_codes(self, condition: Any) -> Any: 33 | return condition 34 | 35 | 36 | class TextCond(DummyCond): 37 | def __init__(self, image_size: Union[Tuple[int, int], int], tokenizer: OmegaConf) -> None: 38 | super().__init__() 39 | self.image_size = image_size 40 | self.tokenizer = initialize_from_config(tokenizer) 41 | 42 | def to_img(self, texts: torch.LongTensor) -> torch.FloatTensor: 43 | W, H = self.image_size if isinstance(self.image_size, tuple) else (self.image_size, self.image_size) 44 | font = ImageFont.truetype(os.path.join(os.getcwd(), "assets", "font", "arial.ttf"), 12) 45 | 46 | imgs = [] 47 | for text in texts: 48 | text = self.tokenizer.decode(text) 49 | words = text.split() 50 | length = 0 51 | 52 | for idx, word in enumerate(words): 53 | if length > 27: 54 | length = 0 55 | word[idx-int(idx>0)] += '\n' 56 | 57 | length += len(word) 58 | 59 | img = Image.new("RGBA", (W, H), "white") 60 | draw = ImageDraw.Draw(img) 61 | 62 | w, h = draw.textsize(text, font) 63 | draw.text(((W-w)/2,(H-h)/2), text, font=font, fill="black", align="center") 64 | 65 | img = img.convert('RGB') 66 | img = T.ToTensor()(img) 67 | imgs.append(img) 68 | 69 | return torch.stack(imgs, dim=0) 70 | 71 | 72 | class ClassCond(DummyCond): 73 | def __init__(self, image_size: Union[Tuple[int, int], int], class_name: Union[str, List[str]]) -> None: 74 | super().__init__() 75 | self.img_size = image_size 76 | if isinstance(class_name, str): 77 | if class_name.endswith("txt") and os.path.isfile(class_name): 78 | self.cls_name = open(class_name, "r").read().split("\n") 79 | elif "." 
not in class_name and not os.path.isfile(class_name): 80 | self.cls_name = class_name 81 | elif isinstance(class_name, list) and isinstance(class_name[0], str): 82 | self.cls_name = class_name 83 | else: 84 | raise Exception("Class file format not supported") 85 | 86 | def to_img(self, clss: torch.LongTensor) -> torch.FloatTensor: 87 | W, H = self.img_size if isinstance(self.img_size, tuple) else (self.img_size, self.img_size) 88 | font = ImageFont.truetype(os.path.join(os.getcwd(), "assets", "font", "arial.ttf"), 12) 89 | 90 | imgs = [] 91 | for cls in clss: 92 | cls_name = self.cls_name[int(cls)] 93 | length = 0 94 | 95 | img = Image.new("RGBA", (W, H), "white") 96 | draw = ImageDraw.Draw(img) 97 | 98 | w, h = draw.textsize(cls_name, font) 99 | draw.text(((W-w)/2,(H-h)/2), cls_name, font=font, fill="black", align="center") 100 | 101 | img = img.convert('RGB') 102 | img = T.ToTensor()(img) 103 | imgs.append(img) 104 | 105 | return torch.stack(imgs, dim=0) 106 | -------------------------------------------------------------------------------- /enhancing/modules/cond/vqcond.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers) 7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 8 | # ------------------------------------------------------------------------------------ 9 | 10 | import numpy as np 11 | from typing import Tuple, Dict, Any 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from ...utils.general import get_obj_from_str 17 | 18 | 19 | def VQCond(base_class: str, *args, **kwargs) -> object: 20 | def to_img(x: torch.FloatTensor) -> torch.FloatTensor: 21 | return x.clamp(0, 1) 22 | 23 | model = get_obj_from_str(base_class)(*args, **kwargs) 24 | model.to_img = to_img 25 | 26 | return model 27 | 28 | 29 | def VQSegmentation(base_class: str, n_labels: int, *args, **kwargs) -> object: 30 | base_model_cls = get_obj_from_str(base_class) 31 | class Wrapper(base_model_cls): 32 | def __init__(self) -> None: 33 | self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1)) 34 | super().__init__(*args, **kwargs) 35 | 36 | def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 37 | x = self.get_input(batch, self.image_key) 38 | xrec, qloss = self(x) 39 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train") 40 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 41 | self.log("train/total_loss", total_loss, 42 | prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) 43 | 44 | return aeloss 45 | 46 | def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> torch.FloatTensor: 47 | x = self.get_input(batch, self.image_key) 48 | xrec, qloss = self(x) 49 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val") 50 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 51 | total_loss = log_dict_ae["val/total_loss"] 52 | self.log("val/total_loss", total_loss, 53 | prog_bar=True, logger=True, on_step=True, on_epoch=True, 
sync_dist=True) 54 | 55 | return aeloss 56 | 57 | @torch.no_grad() 58 | def log_images(self, batch: Tuple[Any, Any], *args, **kwargs) -> Dict: 59 | log = dict() 60 | x = self.get_input(batch, self.image_key).to(self.device) 61 | xrec, _ = self(x) 62 | if x.shape[1] > 3: 63 | # colorize with random projection 64 | assert xrec.shape[1] > 3 65 | # convert logits to indices 66 | xrec = torch.argmax(xrec, dim=1, keepdim=True) 67 | xrec = F.one_hot(xrec, num_classes=x.shape[1]) 68 | xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float() 69 | x = self.to_img(x) 70 | xrec = self.to_img(xrec) 71 | log["inputs"] = x 72 | log["reconstructions"] = xrec 73 | 74 | return log 75 | 76 | def to_img(self, x: torch.FloatTensor) -> torch.FloatTensor: 77 | x = F.conv2d(x, weight=self.colorize) 78 | 79 | return (x-x.min())/(x.max()-x.min()) 80 | 81 | return Wrapper() 82 | -------------------------------------------------------------------------------- /enhancing/modules/stage1/layers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from ViT-Pytorch (https://github.com/lucidrains/vit-pytorch) 7 | # Copyright (c) 2020 Phil Wang. All Rights Reserved. 8 | # ------------------------------------------------------------------------------------ 9 | 10 | import math 11 | import numpy as np 12 | from typing import Union, Tuple, List 13 | from collections import OrderedDict 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from einops import rearrange, repeat 19 | from einops.layers.torch import Rearrange 20 | 21 | def get_2d_sincos_pos_embed(embed_dim, grid_size): 22 | """ 23 | grid_size: int or (int, int) of the grid height and width 24 | return: 25 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 26 | """ 27 | grid_size = (grid_size, grid_size) if type(grid_size) != tuple else grid_size 28 | grid_h = np.arange(grid_size[0], dtype=np.float32) 29 | grid_w = np.arange(grid_size[1], dtype=np.float32) 30 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 31 | grid = np.stack(grid, axis=0) 32 | 33 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) 34 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 35 | 36 | return pos_embed 37 | 38 | 39 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 40 | assert embed_dim % 2 == 0 41 | 42 | # use half of dimensions to encode grid_h 43 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 44 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 45 | 46 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 47 | return emb 48 | 49 | 50 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 51 | """ 52 | embed_dim: output dimension for each position 53 | pos: a list of positions to be encoded: size (M,) 54 | out: (M, D) 55 | """ 56 | assert embed_dim % 2 == 0 57 | omega = np.arange(embed_dim // 2, dtype=np.float) 58 | omega /= embed_dim / 2. 59 | omega = 1. 
/ 10000**omega # (D/2,) 60 | 61 | pos = pos.reshape(-1) # (M,) 62 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 63 | 64 | emb_sin = np.sin(out) # (M, D/2) 65 | emb_cos = np.cos(out) # (M, D/2) 66 | 67 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 68 | return emb 69 | 70 | 71 | def init_weights(m): 72 | if isinstance(m, nn.Linear): 73 | # we use xavier_uniform following official JAX ViT: 74 | torch.nn.init.xavier_uniform_(m.weight) 75 | if m.bias is not None: 76 | nn.init.constant_(m.bias, 0) 77 | elif isinstance(m, nn.LayerNorm): 78 | nn.init.constant_(m.bias, 0) 79 | nn.init.constant_(m.weight, 1.0) 80 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 81 | w = m.weight.data 82 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 83 | 84 | 85 | class PreNorm(nn.Module): 86 | def __init__(self, dim: int, fn: nn.Module) -> None: 87 | super().__init__() 88 | self.norm = nn.LayerNorm(dim) 89 | self.fn = fn 90 | 91 | def forward(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor: 92 | return self.fn(self.norm(x), **kwargs) 93 | 94 | 95 | class FeedForward(nn.Module): 96 | def __init__(self, dim: int, hidden_dim: int) -> None: 97 | super().__init__() 98 | self.net = nn.Sequential( 99 | nn.Linear(dim, hidden_dim), 100 | nn.Tanh(), 101 | nn.Linear(hidden_dim, dim) 102 | ) 103 | 104 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 105 | return self.net(x) 106 | 107 | 108 | class Attention(nn.Module): 109 | def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None: 110 | super().__init__() 111 | inner_dim = dim_head * heads 112 | project_out = not (heads == 1 and dim_head == dim) 113 | 114 | self.heads = heads 115 | self.scale = dim_head ** -0.5 116 | 117 | self.attend = nn.Softmax(dim = -1) 118 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 119 | 120 | self.to_out = nn.Linear(inner_dim, dim) if project_out else nn.Identity() 121 | 122 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 123 | qkv = self.to_qkv(x).chunk(3, dim = -1) 124 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 125 | 126 | attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale 127 | attn = self.attend(attn) 128 | 129 | out = torch.matmul(attn, v) 130 | out = rearrange(out, 'b h n d -> b n (h d)') 131 | 132 | return self.to_out(out) 133 | 134 | 135 | class Transformer(nn.Module): 136 | def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int) -> None: 137 | super().__init__() 138 | self.layers = nn.ModuleList([]) 139 | for idx in range(depth): 140 | layer = nn.ModuleList([PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head)), 141 | PreNorm(dim, FeedForward(dim, mlp_dim))]) 142 | self.layers.append(layer) 143 | self.norm = nn.LayerNorm(dim) 144 | 145 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 146 | for attn, ff in self.layers: 147 | x = attn(x) + x 148 | x = ff(x) + x 149 | 150 | return self.norm(x) 151 | 152 | 153 | class ViTEncoder(nn.Module): 154 | def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int], 155 | dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None: 156 | super().__init__() 157 | image_height, image_width = image_size if isinstance(image_size, tuple) \ 158 | else (image_size, image_size) 159 | patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \ 160 | else (patch_size, patch_size) 161 | 162 | 
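The `get_*_sincos_pos_embed` helpers above produce the fixed (non-learned) position embeddings that `ViTEncoder` and `ViTDecoder` register below. A self-contained sketch of the 1-D building block, written with `np.float64` because the `dtype=np.float` alias used above was removed in NumPy 1.24:

```python
import numpy as np

def sincos_1d(embed_dim: int, positions: np.ndarray) -> np.ndarray:
    # frequencies 1 / 10000^(2i/D), as in the standard Transformer encoding
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000 ** omega                                 # (D/2,)
    out = np.einsum("m,d->md", positions.reshape(-1), omega)     # outer product
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)    # (M, D)

print(sincos_1d(16, np.arange(4, dtype=np.float64)).shape)       # (4, 16)
```

The 2-D version simply concatenates one such embedding over grid rows with another over grid columns.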
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 163 | en_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width)) 164 | 165 | self.num_patches = (image_height // patch_height) * (image_width // patch_width) 166 | self.patch_dim = channels * patch_height * patch_width 167 | 168 | self.to_patch_embedding = nn.Sequential( 169 | nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size), 170 | Rearrange('b c h w -> b (h w) c'), 171 | ) 172 | self.en_pos_embedding = nn.Parameter(torch.from_numpy(en_pos_embedding).float().unsqueeze(0), requires_grad=False) 173 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 174 | 175 | self.apply(init_weights) 176 | 177 | def forward(self, img: torch.FloatTensor) -> torch.FloatTensor: 178 | x = self.to_patch_embedding(img) 179 | x = x + self.en_pos_embedding 180 | x = self.transformer(x) 181 | 182 | return x 183 | 184 | 185 | class ViTDecoder(nn.Module): 186 | def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int], 187 | dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None: 188 | super().__init__() 189 | image_height, image_width = image_size if isinstance(image_size, tuple) \ 190 | else (image_size, image_size) 191 | patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \ 192 | else (patch_size, patch_size) 193 | 194 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 195 | de_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width)) 196 | 197 | self.num_patches = (image_height // patch_height) * (image_width // patch_width) 198 | self.patch_dim = channels * patch_height * patch_width 199 | 200 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 201 | self.de_pos_embedding = nn.Parameter(torch.from_numpy(de_pos_embedding).float().unsqueeze(0), requires_grad=False) 202 | self.to_pixel = nn.Sequential( 203 | Rearrange('b (h w) c -> b c h w', h=image_height // patch_height), 204 | nn.ConvTranspose2d(dim, channels, kernel_size=patch_size, stride=patch_size) 205 | ) 206 | 207 | self.apply(init_weights) 208 | 209 | def forward(self, token: torch.FloatTensor) -> torch.FloatTensor: 210 | x = token + self.de_pos_embedding 211 | x = self.transformer(x) 212 | x = self.to_pixel(x) 213 | 214 | return x 215 | 216 | def get_last_layer(self) -> nn.Parameter: 217 | return self.to_pixel[-1].weight 218 | -------------------------------------------------------------------------------- /enhancing/modules/stage1/quantizers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers) 7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 
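Before the quantizers are defined below, a quick shape check for the encoder/decoder pair above, using a small hypothetical configuration (this assumes a NumPy version in which the deprecated `np.float` noted earlier still resolves, or that it has been changed to `np.float64`):

```python
import torch
from enhancing.modules.stage1.layers import ViTEncoder, ViTDecoder

enc = ViTEncoder(image_size=64, patch_size=8, dim=128, depth=2, heads=4, mlp_dim=256)
dec = ViTDecoder(image_size=64, patch_size=8, dim=128, depth=2, heads=4, mlp_dim=256)

img = torch.rand(2, 3, 64, 64)
tokens = enc(img)                  # (2, 64, 128): one token per 8x8 patch
recon = dec(tokens)                # (2, 3, 64, 64)
print(tokens.shape, recon.shape)
```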
8 | # ------------------------------------------------------------------------------------ 9 | 10 | import math 11 | from functools import partial 12 | from typing import Tuple, Optional 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | 19 | class BaseQuantizer(nn.Module): 20 | def __init__(self, embed_dim: int, n_embed: int, straight_through: bool = True, use_norm: bool = True, 21 | use_residual: bool = False, num_quantizers: Optional[int] = None) -> None: 22 | super().__init__() 23 | self.straight_through = straight_through 24 | self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x 25 | 26 | self.use_residual = use_residual 27 | self.num_quantizers = num_quantizers 28 | 29 | self.embed_dim = embed_dim 30 | self.n_embed = n_embed 31 | 32 | self.embedding = nn.Embedding(self.n_embed, self.embed_dim) 33 | self.embedding.weight.data.normal_() 34 | 35 | def quantize(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: 36 | pass 37 | 38 | def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: 39 | if not self.use_residual: 40 | z_q, loss, encoding_indices = self.quantize(z) 41 | else: 42 | z_q = torch.zeros_like(z) 43 | residual = z.detach().clone() 44 | 45 | losses = [] 46 | encoding_indices = [] 47 | 48 | for _ in range(self.num_quantizers): 49 | z_qi, loss, indices = self.quantize(residual.clone()) 50 | residual.sub_(z_qi) 51 | z_q.add_(z_qi) 52 | 53 | encoding_indices.append(indices) 54 | losses.append(loss) 55 | 56 | losses, encoding_indices = map(partial(torch.stack, dim = -1), (losses, encoding_indices)) 57 | loss = losses.mean() 58 | 59 | # preserve gradients with straight-through estimator 60 | if self.straight_through: 61 | z_q = z + (z_q - z).detach() 62 | 63 | return z_q, loss, encoding_indices 64 | 65 | 66 | class VectorQuantizer(BaseQuantizer): 67 | def __init__(self, embed_dim: int, n_embed: int, beta: float = 0.25, use_norm: bool = True, 68 | use_residual: bool = False, num_quantizers: Optional[int] = None, **kwargs) -> None: 69 | super().__init__(embed_dim, n_embed, True, 70 | use_norm, use_residual, num_quantizers) 71 | 72 | self.beta = beta 73 | 74 | def quantize(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: 75 | z_reshaped_norm = self.norm(z.view(-1, self.embed_dim)) 76 | embedding_norm = self.norm(self.embedding.weight) 77 | 78 | d = torch.sum(z_reshaped_norm ** 2, dim=1, keepdim=True) + \ 79 | torch.sum(embedding_norm ** 2, dim=1) - 2 * \ 80 | torch.einsum('b d, n d -> b n', z_reshaped_norm, embedding_norm) 81 | 82 | encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) 83 | encoding_indices = encoding_indices.view(*z.shape[:-1]) 84 | 85 | z_q = self.embedding(encoding_indices).view(z.shape) 86 | z_qnorm, z_norm = self.norm(z_q), self.norm(z) 87 | 88 | # compute loss for embedding 89 | loss = self.beta * torch.mean((z_qnorm.detach() - z_norm)**2) + \ 90 | torch.mean((z_qnorm - z_norm.detach())**2) 91 | 92 | return z_qnorm, loss, encoding_indices 93 | 94 | 95 | class GumbelQuantizer(BaseQuantizer): 96 | def __init__(self, embed_dim: int, n_embed: int, temp_init: float = 1.0, 97 | use_norm: bool = True, use_residual: bool = False, num_quantizers: Optional[int] = None, **kwargs) -> None: 98 | super().__init__(embed_dim, n_embed, False, 99 | use_norm, use_residual, num_quantizers) 100 | 101 | self.temperature = temp_init 102 | 103 | def quantize(self, z: torch.FloatTensor, temp: 
Optional[float] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: 104 | # force hard = True when we are in eval mode, as we must quantize 105 | hard = not self.training 106 | temp = self.temperature if temp is None else temp 107 | 108 | z_reshaped_norm = self.norm(z.view(-1, self.embed_dim)) 109 | embedding_norm = self.norm(self.embedding.weight) 110 | 111 | logits = - torch.sum(z_reshaped_norm ** 2, dim=1, keepdim=True) - \ 112 | torch.sum(embedding_norm ** 2, dim=1) + 2 * \ 113 | torch.einsum('b d, n d -> b n', z_reshaped_norm, embedding_norm) 114 | logits = logits.view(*z.shape[:-1], -1) 115 | 116 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=-1, hard=hard) 117 | z_qnorm = torch.matmul(soft_one_hot, embedding_norm) 118 | 119 | # kl divergence to the prior loss 120 | logits = F.log_softmax(logits, dim=-1) # use log_softmax because it is more numerically stable 121 | loss = torch.sum(logits.exp() * (logits+math.log(self.n_embed)), dim=-1).mean() 122 | 123 | # get encoding via argmax 124 | encoding_indices = soft_one_hot.argmax(dim=-1) 125 | 126 | return z_qnorm, loss, encoding_indices 127 | -------------------------------------------------------------------------------- /enhancing/modules/stage1/vitvqgan.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers) 7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 
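Before the full autoencoder below, a sketch of `VectorQuantizer` in residual (RQ-VAE) mode: each of the `num_quantizers` rounds quantizes whatever the previous rounds left unexplained, so every token ends up with a stack of code indices. Sizes here are hypothetical:

```python
import torch
from enhancing.modules.stage1.quantizers import VectorQuantizer

quant = VectorQuantizer(embed_dim=32, n_embed=512,
                        use_residual=True, num_quantizers=4)

z = torch.randn(2, 64, 32)           # (batch, tokens, embed_dim)
z_q, loss, indices = quant(z)
print(z_q.shape, indices.shape)      # (2, 64, 32) and (2, 64, 4): four codes per token
print(loss.item())                   # commitment/codebook loss averaged over the rounds
```

In the non-residual case the same call returns a single index per token.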
8 | # ------------------------------------------------------------------------------------ 9 | 10 | from typing import List, Tuple, Dict, Any, Optional 11 | from omegaconf import OmegaConf 12 | 13 | import PIL 14 | import torch 15 | import torch.nn as nn 16 | from torch.optim import lr_scheduler 17 | from torchvision import transforms as T 18 | import pytorch_lightning as pl 19 | 20 | from .layers import ViTEncoder as Encoder, ViTDecoder as Decoder 21 | from .quantizers import VectorQuantizer, GumbelQuantizer 22 | from ...utils.general import initialize_from_config 23 | 24 | 25 | class ViTVQ(pl.LightningModule): 26 | def __init__(self, image_key: str, image_size: int, patch_size: int, encoder: OmegaConf, decoder: OmegaConf, quantizer: OmegaConf, 27 | loss: OmegaConf, path: Optional[str] = None, ignore_keys: List[str] = list(), scheduler: Optional[OmegaConf] = None) -> None: 28 | super().__init__() 29 | self.path = path 30 | self.ignore_keys = ignore_keys 31 | self.image_key = image_key 32 | self.scheduler = scheduler 33 | 34 | self.loss = initialize_from_config(loss) 35 | self.encoder = Encoder(image_size=image_size, patch_size=patch_size, **encoder) 36 | self.decoder = Decoder(image_size=image_size, patch_size=patch_size, **decoder) 37 | self.quantizer = VectorQuantizer(**quantizer) 38 | self.pre_quant = nn.Linear(encoder.dim, quantizer.embed_dim) 39 | self.post_quant = nn.Linear(quantizer.embed_dim, decoder.dim) 40 | 41 | if path is not None: 42 | self.init_from_ckpt(path, ignore_keys) 43 | 44 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 45 | quant, diff = self.encode(x) 46 | dec = self.decode(quant) 47 | 48 | return dec, diff 49 | 50 | def init_from_ckpt(self, path: str, ignore_keys: List[str] = list()): 51 | sd = torch.load(path, map_location="cpu")["state_dict"] 52 | keys = list(sd.keys()) 53 | for k in keys: 54 | for ik in ignore_keys: 55 | if k.startswith(ik): 56 | print("Deleting key {} from state_dict.".format(k)) 57 | del sd[k] 58 | self.load_state_dict(sd, strict=False) 59 | print(f"Restored from {path}") 60 | 61 | def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 62 | h = self.encoder(x) 63 | h = self.pre_quant(h) 64 | quant, emb_loss, _ = self.quantizer(h) 65 | 66 | return quant, emb_loss 67 | 68 | def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor: 69 | quant = self.post_quant(quant) 70 | dec = self.decoder(quant) 71 | 72 | return dec 73 | 74 | def encode_codes(self, x: torch.FloatTensor) -> torch.LongTensor: 75 | h = self.encoder(x) 76 | h = self.pre_quant(h) 77 | _, _, codes = self.quantizer(h) 78 | 79 | return codes 80 | 81 | def decode_codes(self, code: torch.LongTensor) -> torch.FloatTensor: 82 | quant = self.quantizer.embedding(code) 83 | quant = self.quantizer.norm(quant) 84 | 85 | if self.quantizer.use_residual: 86 | quant = quant.sum(-2) 87 | 88 | dec = self.decode(quant) 89 | 90 | return dec 91 | 92 | def get_input(self, batch: Tuple[Any, Any], key: str = 'image') -> Any: 93 | x = batch[key] 94 | if len(x.shape) == 3: 95 | x = x[..., None] 96 | if x.dtype == torch.double: 97 | x = x.float() 98 | 99 | return x.contiguous() 100 | 101 | def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 102 | x = self.get_input(batch, self.image_key) 103 | xrec, qloss = self(x) 104 | 105 | if optimizer_idx == 0: 106 | # autoencoder 107 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, batch_idx, 108 | 
last_layer=self.decoder.get_last_layer(), split="train") 109 | 110 | self.log("train/total_loss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 111 | del log_dict_ae["train/total_loss"] 112 | 113 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 114 | 115 | return aeloss 116 | 117 | if optimizer_idx == 1: 118 | # discriminator 119 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, batch_idx, 120 | last_layer=self.decoder.get_last_layer(), split="train") 121 | 122 | self.log("train/disc_loss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 123 | del log_dict_disc["train/disc_loss"] 124 | 125 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) 126 | 127 | return discloss 128 | 129 | def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> Dict: 130 | x = self.get_input(batch, self.image_key) 131 | xrec, qloss = self(x) 132 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, batch_idx, 133 | last_layer=self.decoder.get_last_layer(), split="val") 134 | 135 | rec_loss = log_dict_ae["val/rec_loss"] 136 | 137 | self.log("val/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) 138 | self.log("val/total_loss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) 139 | del log_dict_ae["val/rec_loss"] 140 | del log_dict_ae["val/total_loss"] 141 | 142 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 143 | 144 | if hasattr(self.loss, 'discriminator'): 145 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, batch_idx, 146 | last_layer=self.decoder.get_last_layer(), split="val") 147 | 148 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) 149 | 150 | return self.log_dict 151 | 152 | def configure_optimizers(self) -> Tuple[List, List]: 153 | lr = self.learning_rate 154 | optim_groups = list(self.encoder.parameters()) + \ 155 | list(self.decoder.parameters()) + \ 156 | list(self.pre_quant.parameters()) + \ 157 | list(self.post_quant.parameters()) + \ 158 | list(self.quantizer.parameters()) 159 | 160 | optimizers = [torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)] 161 | schedulers = [] 162 | 163 | if hasattr(self.loss, 'discriminator'): 164 | optimizers.append(torch.optim.AdamW(self.loss.discriminator.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)) 165 | 166 | if self.scheduler is not None: 167 | self.scheduler.params.start = lr 168 | scheduler = initialize_from_config(self.scheduler) 169 | 170 | schedulers = [ 171 | { 172 | 'scheduler': lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler.schedule), 173 | 'interval': 'step', 174 | 'frequency': 1 175 | } for optimizer in optimizers 176 | ] 177 | 178 | return optimizers, schedulers 179 | 180 | def log_images(self, batch: Tuple[Any, Any], *args, **kwargs) -> Dict: 181 | log = dict() 182 | x = self.get_input(batch, self.image_key).to(self.device) 183 | quant, _ = self.encode(x) 184 | 185 | log["originals"] = x 186 | log["reconstructions"] = self.decode(quant) 187 | 188 | return log 189 | 190 | 191 | class ViTVQGumbel(ViTVQ): 192 | def __init__(self, image_key: str, image_size: int, patch_size: int, encoder: OmegaConf, decoder: OmegaConf, quantizer: OmegaConf, loss: OmegaConf, 193 | path: Optional[str] = None, ignore_keys: List[str] = list(), temperature_scheduler: OmegaConf = 
None, scheduler: Optional[OmegaConf] = None) -> None: 194 | super().__init__(image_key, image_size, patch_size, encoder, decoder, quantizer, loss, None, None, scheduler) 195 | 196 | self.temperature_scheduler = initialize_from_config(temperature_scheduler) \ 197 | if temperature_scheduler else None 198 | self.quantizer = GumbelQuantizer(**quantizer) 199 | 200 | if path is not None: 201 | self.init_from_ckpt(path, ignore_keys) 202 | 203 | def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 204 | if self.temperature_scheduler: 205 | self.quantizer.temperature = self.temperature_scheduler(self.global_step) 206 | 207 | loss = super().training_step(batch, batch_idx, optimizer_idx) 208 | 209 | if optimizer_idx == 0: 210 | self.log("temperature", self.quantizer.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True) 211 | 212 | return loss 213 | -------------------------------------------------------------------------------- /enhancing/modules/stage2/layers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from minDALL-E (https://github.com/kakaobrain/minDALL-E) 7 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved. 8 | # ------------------------------------------------------------------------------------ 9 | # Modified from minGPT (https://github.com/karpathy/minGPT) 10 | # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved. 
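`ViTVQGumbel` above only swaps in `GumbelQuantizer` and anneals its temperature from a scheduler before each training step. The Gumbel path can be poked at in isolation (hypothetical sizes; in `train()` mode the code assignments are soft and differentiable, in `eval()` they are hard one-hot):

```python
import torch
from enhancing.modules.stage1.quantizers import GumbelQuantizer

quant = GumbelQuantizer(embed_dim=32, n_embed=512, temp_init=1.0)
quant.train()                         # keeps hard=False inside quantize()

z = torch.randn(2, 64, 32)
z_q, kl_loss, codes = quant(z)
print(z_q.shape, codes.shape)         # (2, 64, 32) and (2, 64)
print(kl_loss.item())                 # KL of the code distribution against a uniform prior
```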
11 | # ------------------------------------------------------------------------------------ 12 | 13 | import math 14 | from omegaconf import OmegaConf 15 | from typing import Optional, Tuple, List 16 | 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import functional as F 20 | from torch.cuda.amp import autocast 21 | 22 | 23 | class MultiHeadSelfAttention(nn.Module): 24 | def __init__(self, 25 | ctx_len: int, 26 | cond_len: int, 27 | embed_dim: int, 28 | n_heads: int, 29 | attn_bias: bool, 30 | use_mask: bool = True): 31 | super().__init__() 32 | assert embed_dim % n_heads == 0 33 | 34 | # key, query, value projections for all heads 35 | self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias) 36 | self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias) 37 | self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias) 38 | 39 | # output projection 40 | self.proj = nn.Linear(embed_dim, embed_dim, attn_bias) 41 | 42 | self.n_heads = n_heads 43 | self.ctx_len = ctx_len 44 | self.use_mask = use_mask 45 | if self.use_mask: 46 | self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False) 47 | self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len) 48 | self.mask[:, :cond_len, :cond_len] = 1 49 | 50 | self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) 51 | with torch.no_grad(): 52 | ww = torch.zeros(1, 1, embed_dim) 53 | for i in range(embed_dim): 54 | ww[0, 0, i] = i / (embed_dim - 1) 55 | self.time_mix = nn.Parameter(ww) 56 | 57 | def forward(self, x, use_cache=False, layer_past=None): 58 | B, T, C = x.shape 59 | 60 | x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix) 61 | x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C) 62 | 63 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 64 | k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs) 65 | q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs) 66 | v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs) 67 | 68 | if use_cache: 69 | present = torch.stack([k, v]) 70 | 71 | if layer_past is not None: 72 | past_key, past_value = layer_past 73 | k = torch.cat([past_key, k], dim=-2) 74 | v = torch.cat([past_value, v], dim=-2) 75 | 76 | if use_cache and layer_past is not None: 77 | # Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K) 78 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))) 79 | att = F.softmax(att, dim=-1) 80 | y = torch.bmm(att, v) # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs) 81 | else: 82 | # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T) 83 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))) 84 | if self.use_mask: 85 | mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T] 86 | att = att.masked_fill(mask == 0, float('-inf')) 87 | att = F.softmax(att, dim=-1) 88 | y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs) 89 | y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side 90 | 91 | # output projection 92 | y = self.proj(y) 93 | 94 | if use_cache: 95 | return y.transpose(0, 1).contiguous(), present # (T, B, C) -> (B, T, C) 96 | else: 97 | return y.transpose(0, 1).contiguous() # (T, B, C) -> (B, T, C) 98 | 99 | class FFN(nn.Module): 100 | def __init__(self, embed_dim, mlp_bias): 101 | super().__init__() 102 | self.p0 = nn.Linear(embed_dim, 4 * embed_dim, 
bias=mlp_bias) 103 | self.p1 = nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias) 104 | 105 | def forward(self, x): 106 | x = self.p0(x) 107 | # x = F.gelu(x) 108 | x = torch.square(torch.relu(x)) 109 | x = self.p1(x) 110 | return x 111 | 112 | class Block(nn.Module): 113 | def __init__(self, 114 | ctx_len: int, 115 | cond_len: int, 116 | embed_dim: int, 117 | n_heads: int, 118 | mlp_bias: bool, 119 | attn_bias: bool): 120 | super().__init__() 121 | self.ln1 = nn.LayerNorm(embed_dim) 122 | self.ln2 = nn.LayerNorm(embed_dim) 123 | 124 | self.attn = MultiHeadSelfAttention(ctx_len=ctx_len, 125 | cond_len=cond_len, 126 | embed_dim=embed_dim, 127 | n_heads=n_heads, 128 | attn_bias=attn_bias, 129 | use_mask=True) 130 | self.mlp = FFN(embed_dim=embed_dim, mlp_bias=mlp_bias) 131 | 132 | def forward(self, x): 133 | x = x + self.attn(self.ln1(x)) 134 | x = x + self.mlp(self.ln2(x)) 135 | 136 | return x 137 | 138 | def sample(self, x, layer_past=None): 139 | attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past) 140 | x = x + attn 141 | x = x + self.mlp(self.ln2(x)) 142 | 143 | return x, present 144 | 145 | 146 | class GPT(nn.Module): 147 | def __init__(self, 148 | vocab_cond_size: int, 149 | vocab_img_size: int, 150 | embed_dim: int, 151 | cond_num_tokens: int, 152 | img_num_tokens: int, 153 | n_heads: int, 154 | n_layers: int, 155 | mlp_bias: bool = True, 156 | attn_bias: bool = True) -> None: 157 | super().__init__() 158 | self.img_num_tokens = img_num_tokens 159 | self.vocab_cond_size = vocab_cond_size 160 | 161 | # condition token and position embedding 162 | self.tok_emb_cond = nn.Embedding(vocab_cond_size, embed_dim) 163 | self.pos_emb_cond = nn.Parameter(torch.zeros(1, cond_num_tokens, embed_dim)) 164 | 165 | # input token and position embedding 166 | self.tok_emb_code = nn.Embedding(vocab_img_size, embed_dim) 167 | self.pos_emb_code = nn.Parameter(torch.zeros(1, img_num_tokens, embed_dim)) 168 | 169 | # transformer blocks 170 | self.blocks = [Block(ctx_len=cond_num_tokens + img_num_tokens, 171 | cond_len=cond_num_tokens, 172 | embed_dim=embed_dim, 173 | n_heads=n_heads, 174 | mlp_bias=mlp_bias, 175 | attn_bias=attn_bias) for i in range(1, n_layers+1)] 176 | self.blocks = nn.Sequential(*self.blocks) 177 | 178 | # head 179 | self.layer_norm = nn.LayerNorm(embed_dim) 180 | self.head = nn.Linear(embed_dim, vocab_img_size, bias=False) 181 | 182 | self.apply(self._init_weights) 183 | 184 | def _init_weights(self, module: nn.Module) -> None: 185 | if isinstance(module, (nn.Linear, nn.Embedding)): 186 | module.weight.data.normal_(mean=0.0, std=0.02) 187 | if isinstance(module, nn.Linear) and module.bias is not None: 188 | module.bias.data.zero_() 189 | elif isinstance(module, nn.LayerNorm): 190 | module.bias.data.zero_() 191 | module.weight.data.fill_(1.0) 192 | 193 | def forward(self, 194 | codes: torch.LongTensor, 195 | conds: torch.LongTensor) -> torch.FloatTensor: 196 | 197 | codes = codes.view(codes.shape[0], -1) 198 | codes = self.tok_emb_code(codes) 199 | conds = self.tok_emb_cond(conds) 200 | 201 | codes = codes + self.pos_emb_code 202 | conds = conds + self.pos_emb_cond 203 | 204 | x = torch.cat([conds, codes], axis=1).contiguous() 205 | x = self.blocks(x) 206 | x = self.layer_norm(x) 207 | 208 | x = x[:, conds.shape[1]-1:-1].contiguous() 209 | logits = self.head(x) 210 | 211 | return logits 212 | 213 | def sample(self, 214 | conds: torch.LongTensor, 215 | top_k: Optional[float] = None, 216 | top_p: Optional[float] = None, 217 | softmax_temperature: float = 1.0, 218 | 
use_fp16: bool = True) -> Tuple[torch.FloatTensor, torch.LongTensor]: 219 | 220 | past = codes = logits = None 221 | 222 | for i in range(self.img_num_tokens): 223 | if codes is None: 224 | codes_ = None 225 | pos_code = None 226 | else: 227 | codes_ = codes.clone().detach() 228 | codes_ = codes_[:, -1:] 229 | pos_code = self.pos_emb_code[:, i-1:i, :] 230 | 231 | logits_, presents = self.sample_step(codes_, conds, pos_code, use_fp16, past) 232 | 233 | logits_ = logits_.to(dtype=torch.float32) 234 | logits_ = logits_ / softmax_temperature 235 | 236 | presents = torch.stack(presents).clone().detach() 237 | if past is None: 238 | past = [presents] 239 | else: 240 | past.append(presents) 241 | 242 | if top_k is not None: 243 | v, ix = torch.topk(logits_, top_k) 244 | logits_[logits_ < v[:, [-1]]] = -float('Inf') 245 | probs = F.softmax(logits_, dim=-1) 246 | 247 | if top_p is not None: 248 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) 249 | cum_probs = torch.cumsum(sorted_probs, dim=-1) 250 | 251 | sorted_idx_remove_cond = cum_probs >= top_p 252 | 253 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone() 254 | sorted_idx_remove_cond[..., 0] = 0 255 | 256 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond) 257 | probs = probs.masked_fill(indices_to_remove, 0.0) 258 | probs = probs / torch.sum(probs, dim=-1, keepdim=True) 259 | 260 | idx = torch.multinomial(probs, num_samples=1).clone().detach() 261 | codes = idx if codes is None else torch.cat([codes, idx], axis=1) 262 | logits = logits_ if logits is None else torch.cat([logits, logits_], axis=1) 263 | 264 | del past 265 | 266 | return logits, codes 267 | 268 | def sample_step(self, 269 | codes: torch.LongTensor, 270 | conds: torch.LongTensor, 271 | pos_code: torch.LongTensor, 272 | use_fp16: bool = True, 273 | past: Optional[torch.FloatTensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]: 274 | 275 | with autocast(enabled=use_fp16): 276 | presents = [] 277 | 278 | if codes is None: 279 | assert past is None 280 | conds = self.tok_emb_cond(conds) 281 | x = conds + self.pos_emb_cond 282 | 283 | for i, block in enumerate(self.blocks): 284 | x, present = block.sample(x, layer_past=None) 285 | presents.append(present) 286 | x = self.layer_norm(x) 287 | x = x[:, conds.shape[1]-1].contiguous() 288 | else: 289 | assert past is not None 290 | codes = self.tok_emb_code(codes) 291 | x = codes + pos_code 292 | 293 | past = torch.cat(past, dim=-2) 294 | for i, block in enumerate(self.blocks): 295 | x, present = block.sample(x, layer_past=past[i]) 296 | presents.append(present) 297 | 298 | x = self.layer_norm(x) 299 | x = x[:, -1].contiguous() 300 | 301 | logits = self.head(x) 302 | 303 | return logits, presents 304 | 305 | 306 | class RQTransformer(nn.Module): 307 | def __init__(self, 308 | vocab_cond_size: int, 309 | vocab_img_size: int, 310 | embed_dim: int, 311 | cond_num_tokens: int, 312 | img_num_tokens: int, 313 | depth_num_tokens: int, 314 | spatial_n_heads: int, 315 | depth_n_heads: int, 316 | spatial_n_layers: int, 317 | depth_n_layers: int, 318 | mlp_bias: bool = True, 319 | attn_bias: bool = True) -> None: 320 | super().__init__() 321 | self.img_num_tokens = img_num_tokens 322 | self.depth_num_tokens = depth_num_tokens 323 | self.vocab_img_size = vocab_img_size 324 | 325 | # condition token and position embedding 326 | self.tok_emb_cond = nn.Embedding(vocab_cond_size, embed_dim) 327 | self.pos_emb_cond = nn.Parameter(torch.rand(1, 
cond_num_tokens, embed_dim)) 328 | 329 | # spatial token and position embedding 330 | self.tok_emb_code = nn.Embedding(vocab_img_size, embed_dim) 331 | self.pos_emb_code = nn.Parameter(torch.rand(1, img_num_tokens, embed_dim)) 332 | 333 | # depth position embedding 334 | self.pos_emb_depth = nn.Parameter(torch.rand(1, depth_num_tokens-1, embed_dim)) 335 | 336 | # spatial transformer 337 | self.spatial_transformer = [Block(ctx_len=cond_num_tokens + img_num_tokens, 338 | cond_len=cond_num_tokens, 339 | embed_dim=embed_dim, 340 | n_heads=spatial_n_heads, 341 | mlp_bias=mlp_bias, 342 | attn_bias=attn_bias) for i in range(1, spatial_n_layers+1)] 343 | self.spatial_transformer = nn.Sequential(*self.spatial_transformer) 344 | 345 | # depth transformer 346 | self.depth_transformer = [Block(ctx_len=depth_num_tokens, 347 | cond_len=0, 348 | embed_dim=embed_dim, 349 | n_heads=depth_n_heads, 350 | mlp_bias=mlp_bias, 351 | attn_bias=attn_bias) for i in range(1, depth_n_layers+1)] 352 | self.depth_transformer = nn.Sequential(*self.depth_transformer) 353 | 354 | # head 355 | self.ln_spatial = nn.LayerNorm(embed_dim) 356 | self.ln_depth = nn.LayerNorm(embed_dim) 357 | self.head = nn.Linear(embed_dim, vocab_img_size, bias=False) 358 | 359 | self.apply(self._init_weights) 360 | 361 | def _init_weights(self, module: nn.Module) -> None: 362 | if isinstance(module, (nn.Linear, nn.Embedding)): 363 | module.weight.data.normal_(mean=0.0, std=0.02) 364 | if isinstance(module, nn.Linear) and module.bias is not None: 365 | module.bias.data.zero_() 366 | elif isinstance(module, nn.LayerNorm): 367 | module.bias.data.zero_() 368 | module.weight.data.fill_(1.0) 369 | 370 | def forward(self, 371 | codes: torch.LongTensor, 372 | conds: torch.LongTensor) -> torch.FloatTensor: 373 | 374 | codes = codes.view(codes.shape[0], -1, codes.shape[-1]) 375 | codes = self.tok_emb_code(codes) 376 | conds = self.tok_emb_cond(conds) 377 | 378 | codes_cumsum = codes.cumsum(-1) 379 | codes_sum = codes_cumsum[..., -1, :] 380 | 381 | codes = codes_sum + self.pos_emb_code 382 | conds = conds + self.pos_emb_cond 383 | 384 | h = torch.cat([conds, codes], axis=1).contiguous() 385 | h = self.ln_spatial(self.spatial_transformer(h)) 386 | h = h[:, conds.shape[1]-1:-1].contiguous() 387 | 388 | v = codes_cumsum[..., :-1, :] + self.pos_emb_depth 389 | v = torch.cat([h.unsqueeze(2), v], axis=2).contiguous() 390 | 391 | v = v.view(-1, *v.shape[2:]) 392 | v = self.depth_transformer(v) 393 | logits = self.head(self.ln_depth(v)) 394 | 395 | return logits 396 | 397 | def sample(self, 398 | conds: torch.LongTensor, 399 | top_k: Optional[float] = None, 400 | top_p: Optional[float] = None, 401 | softmax_temperature: float = 1.0, 402 | use_fp16: bool = True) -> Tuple[torch.FloatTensor, torch.LongTensor]: 403 | 404 | past = codes = logits = None 405 | B, T, D, S = conds.shape[0], self.img_num_tokens, self.depth_num_tokens, self.vocab_img_size 406 | 407 | for i in range(self.img_num_tokens): 408 | depth_past = None 409 | 410 | if codes is None: 411 | codes_ = None 412 | pos_code = None 413 | else: 414 | codes_ = codes.clone().detach() 415 | codes_ = codes_[:, -self.depth_num_tokens:] 416 | pos_code = self.pos_emb_code[:, i-1:i, :] 417 | 418 | hidden, presents = self.sample_spatial_step(codes_, conds, pos_code, use_fp16, past) 419 | 420 | presents = torch.stack(presents).clone().detach() 421 | if past is None: 422 | past = [presents] 423 | else: 424 | past.append(presents) 425 | 426 | last_len = 0 if codes is None else codes.shape[-1] 427 | 428 | for d in 
range(self.depth_num_tokens): 429 | if depth_past is None: 430 | codes_ = None 431 | pos_depth = None 432 | else: 433 | codes_ = codes.clone().detach() 434 | codes_ = codes_[:, last_len:] 435 | pos_depth = self.pos_emb_depth[:, d-1:d, :] 436 | 437 | logits_, depth_presents = self.sample_depth_step(codes_, hidden, pos_depth, use_fp16, depth_past) 438 | 439 | logits_ = logits_.to(dtype=torch.float32) 440 | logits_ = logits_ / softmax_temperature 441 | 442 | depth_presents = torch.stack(depth_presents).clone().detach() 443 | if depth_past is None: 444 | depth_past = [depth_presents] 445 | else: 446 | depth_past.append(depth_presents) 447 | 448 | if top_k is not None: 449 | v, ix = torch.topk(logits_, top_k) 450 | logits_[logits_ < v[:, [-1]]] = -float('Inf') 451 | probs = F.softmax(logits_, dim=-1) 452 | 453 | if top_p is not None: 454 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) 455 | cum_probs = torch.cumsum(sorted_probs, dim=-1) 456 | 457 | sorted_idx_remove_cond = cum_probs >= top_p 458 | 459 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone() 460 | sorted_idx_remove_cond[..., 0] = 0 461 | 462 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond) 463 | probs = probs.masked_fill(indices_to_remove, 0.0) 464 | probs = probs / torch.sum(probs, dim=-1, keepdim=True) 465 | 466 | idx = torch.multinomial(probs, num_samples=1).clone().detach() 467 | codes = idx if codes is None else torch.cat([codes, idx], axis=1) 468 | logits = logits_ if logits is None else torch.cat([logits, logits_], axis=1) 469 | 470 | del depth_past 471 | 472 | del past 473 | 474 | codes = codes.view(B, T, D) 475 | logits = logits.view(B * T, D, S) 476 | 477 | return logits, codes 478 | 479 | def sample_spatial_step(self, 480 | codes: torch.LongTensor, 481 | conds: torch.LongTensor, 482 | pos_code: torch.LongTensor, 483 | use_fp16: bool = True, 484 | past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]: 485 | 486 | with autocast(enabled=use_fp16): 487 | presents = [] 488 | 489 | if codes is None: 490 | assert past is None 491 | conds = self.tok_emb_cond(conds) 492 | x = conds + self.pos_emb_cond 493 | 494 | for i, block in enumerate(self.spatial_transformer): 495 | x, present = block.sample(x, layer_past=None) 496 | presents.append(present) 497 | x = self.ln_spatial(x) 498 | x = x[:, conds.shape[1]-1:conds.shape[1]].contiguous() 499 | else: 500 | assert past is not None 501 | codes = self.tok_emb_code(codes) 502 | x = codes.sum(1, keepdim=True) + pos_code 503 | 504 | past = torch.cat(past, dim=-2) 505 | for i, block in enumerate(self.spatial_transformer): 506 | x, present = block.sample(x, layer_past=past[i]) 507 | presents.append(present) 508 | 509 | x = self.ln_spatial(x) 510 | x = x[:, -1:].contiguous() 511 | 512 | return x, presents 513 | 514 | def sample_depth_step(self, 515 | codes: torch.LongTensor, 516 | hidden: torch.FloatTensor, 517 | pos_depth: torch.LongTensor, 518 | use_fp16: bool = True, 519 | past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]: 520 | 521 | with autocast(enabled=use_fp16): 522 | presents = [] 523 | 524 | if codes is None: 525 | assert past is None 526 | x = hidden 527 | 528 | for i, block in enumerate(self.depth_transformer): 529 | x, present = block.sample(x, layer_past=None) 530 | presents.append(present) 531 | x = self.ln_depth(x) 532 | else: 533 | assert past is not None 534 | codes = self.tok_emb_code(codes) 535 | x = 
codes.sum(1, keepdim=True) + pos_depth 536 | 537 | past = torch.cat(past, dim=-2) 538 | for i, block in enumerate(self.depth_transformer): 539 | x, present = block.sample(x, layer_past=past[i]) 540 | presents.append(present) 541 | 542 | x = self.ln_depth(x) 543 | x = x[:, -1].contiguous() 544 | 545 | logits = self.head(x) 546 | 547 | return logits, presents 548 | -------------------------------------------------------------------------------- /enhancing/modules/stage2/transformer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Enhancing Transformers 3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers) 7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 8 | # ------------------------------------------------------------------------------------ 9 | 10 | from typing import Optional, Tuple, Dict, Union, Any 11 | from omegaconf import OmegaConf 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.optim import lr_scheduler 16 | import torch.nn.functional as F 17 | import pytorch_lightning as pl 18 | 19 | from .layers import * 20 | from ...utils.general import initialize_from_config 21 | 22 | 23 | class CondTransformer(pl.LightningModule): 24 | def __init__(self, cond_key: str, cond: OmegaConf, stage1: OmegaConf, transformer: OmegaConf, 25 | path: Optional[str] = None, ignore_keys: List[str] = list(), 26 | code_shape: List[int] = None, scheduler: Optional[OmegaConf] = None) -> None: 27 | super().__init__() 28 | 29 | # get condition key, code shape and scheduler 30 | self.cond_key = cond_key 31 | self.code_shape = code_shape 32 | self.scheduler = scheduler 33 | 34 | # load condition model 35 | self.cond_model = initialize_from_config(cond) 36 | 37 | # load stage1 model 38 | self.stage1_model = initialize_from_config(stage1) 39 | 40 | # load transformer 41 | self.transformer = initialize_from_config(transformer) 42 | 43 | # make the parameters in stage1 model not trainable 44 | self.stage1_model.eval() 45 | for p in self.stage1_model.parameters(): 46 | p.requires_grad = False 47 | 48 | # make the parameters in condition model not trainable 49 | self.cond_model.eval() 50 | for p in self.cond_model.parameters(): 51 | p.requires_grad = False 52 | 53 | if path is not None: 54 | self.init_from_ckpt(path, ignore_keys) 55 | 56 | def forward(self, 57 | codes: torch.LongTensor, 58 | conds: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 59 | 60 | conds = conds.view(conds.shape[0], -1) 61 | logits = self.transformer(codes, conds) 62 | 63 | codes = codes.view(-1, codes.shape[-1]) 64 | 65 | return logits, codes 66 | 67 | def init_from_ckpt(self, path: str, ignore_keys: List[str] = list()): 68 | sd = torch.load(path, map_location="cpu")["state_dict"] 69 | keys = list(sd.keys()) 70 | for k in keys: 71 | for ik in ignore_keys: 72 | if k.startswith(ik): 73 | print("Deleting key {} from state_dict.".format(k)) 74 | del sd[k] 75 | self.load_state_dict(sd, strict=False) 76 | print(f"Restored from {path}") 77 | 78 | @torch.no_grad() 79 | def sample(self, 80 | conds: torch.LongTensor, 81 | top_k: Optional[float] = None, 82 | top_p: Optional[float] = None, 83 | 
softmax_temperature: float = 1.0, 84 | use_fp16: bool = True) -> torch.FloatTensor: 85 | 86 | conds = conds.view(conds.shape[0], -1) 87 | logits, codes = self.transformer.sample(conds=conds, top_k=top_k, top_p=top_p, 88 | softmax_temperature=softmax_temperature, 89 | use_fp16=use_fp16) 90 | 91 | if self.code_shape is not None: 92 | codes = codes.view(codes.shape[0], *self.code_shape) 93 | pixels = self.stage1_model.decode_codes(codes).clamp(0, 1) 94 | 95 | return pixels 96 | 97 | def get_input(self, batch: Tuple[Any, Any], key: str) -> torch.FloatTensor: 98 | x = batch[key] 99 | 100 | if len(x.shape) == 3: 101 | x = x[..., None] 102 | if x.dtype == torch.double: 103 | x = x.float() 104 | 105 | return x.contiguous() 106 | 107 | def shared_step(self, batch: Tuple[Any, Any], batch_idx: int) -> torch.FloatTensor: 108 | images = self.get_input(batch, self.stage1_model.image_key) 109 | conds = self.get_input(batch, self.cond_key) 110 | 111 | with torch.no_grad(): 112 | codes = self.stage1_model.encode_codes(images).detach() 113 | conds = self.cond_model.encode_codes(conds).detach() 114 | 115 | logits, codes = self(codes, conds) 116 | loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1)) 117 | 118 | return loss 119 | 120 | def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 121 | loss = self.shared_step(batch, batch_idx) 122 | self.log("train/total_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 123 | 124 | return loss 125 | 126 | def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> torch.FloatTensor: 127 | loss = self.shared_step(batch, batch_idx) 128 | self.log("val/total_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 129 | 130 | return loss 131 | 132 | def configure_optimizers(self) -> torch.optim.Optimizer: 133 | """ 134 | Following minGPT: 135 | This long function is unfortunately doing something very simple and is being very defensive: 136 | We are separating out all parameters of the model into two buckets: those that will experience 137 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 138 | We are then returning the PyTorch optimizer object. 
139 |         """
140 |         # separate out all parameters to those that will and won't experience regularizing weight decay
141 |         decay = set()
142 |         no_decay = set()
143 |         whitelist_weight_modules = (torch.nn.Linear, )
144 |         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
145 |         for mn, m in self.transformer.named_modules():
146 |             for pn, p in m.named_parameters():
147 |                 fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
148 | 
149 |                 if pn.endswith('bias'):
150 |                     # all biases will not be decayed
151 |                     no_decay.add(fpn)
152 |                 elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
153 |                     # weights of whitelist modules will be weight decayed
154 |                     decay.add(fpn)
155 |                 elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
156 |                     # weights of blacklist modules will NOT be weight decayed
157 |                     no_decay.add(fpn)
158 |                 elif 'time_' in pn: # for RWKV
159 |                     no_decay.add(fpn)
160 | 
161 |         # special case the position embedding parameters in the root transformer module as not decayed
162 |         no_decay.add('pos_emb_cond')
163 |         no_decay.add('pos_emb_code')
164 | 
165 |         if hasattr(self.transformer, 'pos_emb_depth'):
166 |             no_decay.add('pos_emb_depth')
167 | 
168 |         # validate that we considered every parameter
169 |         param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
170 |         inter_params = decay & no_decay
171 |         union_params = decay | no_decay
172 |         assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
173 |         assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay/ignored set!" \
174 |                                                            % (str(param_dict.keys() - union_params), )
175 | 
176 |         # create the pytorch optimizer object
177 |         optim_groups = [
178 |             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
179 |             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
180 |         ]
181 |         optimizer = [torch.optim.Adam(optim_groups, lr=self.learning_rate, betas=(0.9, 0.96))]
182 |         scheduler = []
183 | 
184 |         if self.scheduler is not None:
185 |             self.scheduler.params.start = self.learning_rate
186 |             lr_schedule = initialize_from_config(self.scheduler)  # instantiate the LR schedule object from its config
187 | 
188 |             scheduler = [{
189 |                 'scheduler': lr_scheduler.LambdaLR(optimizer[0], lr_lambda=lr_schedule.schedule),
190 |                 'interval': 'step',
191 |                 'frequency': 1
192 |             }]
193 | 
194 |         return optimizer, scheduler
195 | 
196 |     def log_images(self, batch: Tuple[Any, Any], *args, **kwargs) -> Dict:
197 |         log = dict()
198 | 
199 |         conds = self.get_input(batch, self.cond_key).to(self.device)
200 |         cond_codes = self.cond_model.encode_codes(conds).detach()
201 | 
202 |         log["conditions"] = self.cond_model.to_img(conds)
203 |         log["first samples"] = self.sample(cond_codes)
204 |         log["second samples"] = self.sample(cond_codes)
205 | 
206 |         return log
207 | 
--------------------------------------------------------------------------------
/enhancing/utils/callback.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------ 5 | 6 | import os 7 | import wandb 8 | import numpy as np 9 | from PIL import Image 10 | from pathlib import Path 11 | from omegaconf import OmegaConf 12 | from typing import Tuple, Generic, Dict 13 | 14 | import torch 15 | import torchvision 16 | import pytorch_lightning as pl 17 | from pytorch_lightning.utilities.distributed import rank_zero_only 18 | from pytorch_lightning.callbacks import Callback 19 | 20 | 21 | class SetupCallback(Callback): 22 | def __init__(self, config: OmegaConf, exp_config: OmegaConf, basedir: Path, logdir: str = "log", ckptdir:str = "ckpt") -> None: 23 | super().__init__() 24 | self.logdir = basedir / logdir 25 | self.ckptdir = basedir / ckptdir 26 | self.config = config 27 | self.exp_config = exp_config 28 | 29 | def on_pretrain_routine_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None: 30 | if trainer.global_rank == 0: 31 | # Create logdirs and save configs 32 | os.makedirs(self.logdir, exist_ok=True) 33 | os.makedirs(self.ckptdir, exist_ok=True) 34 | 35 | print("Experiment config") 36 | print(self.exp_config.pretty()) 37 | 38 | print("Model config") 39 | print(self.config.pretty()) 40 | 41 | 42 | class ImageLogger(Callback): 43 | def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True, increase_log_steps: bool =True) -> None: 44 | super().__init__() 45 | self.batch_freq = batch_frequency 46 | self.max_images = max_images 47 | self.logger_log_images = { 48 | pl.loggers.WandbLogger: self._wandb, 49 | pl.loggers.TestTubeLogger: self._testtube, 50 | } 51 | self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] 52 | if not increase_log_steps: 53 | self.log_steps = [self.batch_freq] 54 | self.clamp = clamp 55 | 56 | @rank_zero_only 57 | def _wandb(self, pl_module, images, batch_idx, split): 58 | #raise ValueError("No way wandb") 59 | grids = dict() 60 | for k in images: 61 | grid = torchvision.utils.make_grid(images[k]) 62 | grids[f"{split}/{k}"] = wandb.Image(grid) 63 | pl_module.logger.experiment.log(grids) 64 | 65 | @rank_zero_only 66 | def _testtube(self, pl_module, images, batch_idx, split): 67 | for k in images: 68 | grid = torchvision.utils.make_grid(images[k]) 69 | grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w 70 | 71 | tag = f"{split}/{k}" 72 | pl_module.logger.experiment.add_image( 73 | tag, grid, 74 | global_step=pl_module.global_step) 75 | 76 | @rank_zero_only 77 | def log_local(self, save_dir: str, split: str, images: Dict, 78 | global_step: int, current_epoch: int, batch_idx: int) -> None: 79 | root = os.path.join(save_dir, "results", split) 80 | os.makedirs(root, exist_ok=True) 81 | for k in images: 82 | grid = torchvision.utils.make_grid(images[k], nrow=4) 83 | 84 | grid = grid.transpose(0,1).transpose(1,2).squeeze(-1) 85 | grid = grid.numpy() 86 | grid = (grid*255).astype(np.uint8) 87 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( 88 | k, 89 | global_step, 90 | current_epoch, 91 | batch_idx) 92 | path = os.path.join(root, filename) 93 | os.makedirs(os.path.split(path)[0], exist_ok=True) 94 | Image.fromarray(grid).save(path) 95 | 96 | def log_img(self, pl_module: pl.LightningModule, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int, split: str = "train") -> None: 97 | if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0 98 | hasattr(pl_module, "log_images") and 99 | callable(pl_module.log_images) and 100 | self.max_images > 0): 101 | logger = 
type(pl_module.logger)
102 | 
103 |             is_train = pl_module.training
104 |             if is_train:
105 |                 pl_module.eval()
106 | 
107 |             with torch.no_grad():
108 |                 images = pl_module.log_images(batch, split=split, pl_module=pl_module)
109 | 
110 |             for k in images:
111 |                 N = min(images[k].shape[0], self.max_images)
112 |                 images[k] = images[k][:N].detach().cpu()
113 |                 if self.clamp:
114 |                     images[k] = images[k].clamp(0, 1)
115 | 
116 |             self.log_local(pl_module.logger.save_dir, split, images,
117 |                            pl_module.global_step, pl_module.current_epoch, batch_idx)
118 | 
119 |             logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
120 |             logger_log_images(pl_module, images, pl_module.global_step, split)
121 | 
122 |             if is_train:
123 |                 pl_module.train()
124 | 
125 |     def check_frequency(self, batch_idx: int) -> bool:
126 |         if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
127 |             try:
128 |                 self.log_steps.pop(0)
129 |             except IndexError:
130 |                 pass
131 |             return True
132 |         return False
133 | 
134 |     def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
135 |                            outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int) -> None:
136 |         self.log_img(pl_module, batch, batch_idx, split="train")
137 | 
138 |     def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
139 |                                 outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor],
140 |                                 batch_idx: int, dataloader_idx: int) -> None:
141 |         self.log_img(pl_module, batch, batch_idx, split="val")
142 | 
--------------------------------------------------------------------------------
/enhancing/utils/general.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 | 
6 | import os
7 | import random
8 | import importlib
9 | import pathlib
10 | from typing import Tuple, List, Dict, ClassVar
11 | import numpy as np
12 | from omegaconf import OmegaConf
13 | from datetime import datetime
14 | 
15 | import torch
16 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback
17 | from pytorch_lightning.loggers import WandbLogger
18 | 
19 | from .callback import *
20 | 
21 | 
22 | def set_seed(seed: int):
23 |     random.seed(seed)
24 |     np.random.seed(seed)
25 |     torch.manual_seed(seed)
26 |     torch.cuda.manual_seed_all(seed)
27 | 
28 | 
29 | def get_obj_from_str(name: str, reload: bool = False) -> ClassVar:
30 |     module, cls = name.rsplit(".", 1)
31 | 
32 |     if reload:
33 |         module_imp = importlib.import_module(module)
34 |         importlib.reload(module_imp)
35 | 
36 |     return getattr(importlib.import_module(module, package=None), cls)
37 | 
38 | 
39 | def initialize_from_config(config: OmegaConf) -> object:
40 |     return get_obj_from_str(config["target"])(**config.get("params", dict()))
41 | 
42 | 
43 | def setup_callbacks(exp_config: OmegaConf, config: OmegaConf) -> Tuple[List[Callback], WandbLogger]:
44 |     now = datetime.now().strftime('%d%m%Y_%H%M%S')
45 |     basedir = pathlib.Path("experiments", exp_config.name, now)
46 |     os.makedirs(basedir, exist_ok=True)
47 | 
48 |     setup_callback = SetupCallback(config, exp_config, basedir)
49 |     checkpoint_callback = ModelCheckpoint(
50 |         dirpath=setup_callback.ckptdir,
51 |         filename=exp_config.name+"-{epoch:02d}",
52 |         monitor="train/total_loss",
53 |         save_top_k=-1,
54 |         verbose=False,
55 |     )
56 |     os.makedirs(setup_callback.logdir/'wandb', exist_ok=True)
57 |     logger = WandbLogger(save_dir=str(setup_callback.logdir), name=exp_config.name+"_"+str(now))
58 |     logger_img_callback = ImageLogger(exp_config.batch_frequency, exp_config.max_images)
59 | 
60 |     return [setup_callback, checkpoint_callback, logger_img_callback], logger
61 | 
62 | 
63 | def get_config_from_file(config_file: str) -> Dict:
64 |     config_file = OmegaConf.load(config_file)
65 | 
66 |     if 'base_config' in config_file.keys():
67 |         if config_file['base_config'] == "default_base":
68 |             base_config = get_default_config()
69 |         elif config_file['base_config'].endswith(".yaml"):
70 |             base_config = get_config_from_file(config_file['base_config'])
71 | 
72 |         config_file = {key: value for key, value in config_file.items() if key != "base_config"}
73 | 
74 |         return OmegaConf.merge(base_config, config_file)
75 | 
76 |     return config_file
77 | 
--------------------------------------------------------------------------------
/enhancing/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 | 
10 | import numpy as np
11 | 
12 | 
13 | class BaseScheduler:
14 |     def __init__(self):
15 |         pass
16 | 
17 |     def schedule(self, n: int) -> float:
18 |         raise NotImplementedError
19 | 
20 |     def __call__(self, n: int) -> float:
21 |         assert hasattr(self, 'start')
22 | 
23 |         return self.schedule(n) * self.start
24 | 
25 | 
26 | class ExponentialDecayScheduler(BaseScheduler):
27 |     def __init__(self, start: float, end: float, decay_every_step: int, scale_factor: float) -> None:
28 |         super().__init__()
29 |         self.decay_every_step = decay_every_step
30 |         self.scale_factor = scale_factor
31 | 
32 |         self.start = start
33 |         self.end = end
34 |         self.current = start
35 | 
36 |     def schedule(self, n: int) -> float:
37 |         if not n % self.decay_every_step:
38 |             res = np.exp(-self.scale_factor*n) * self.start
39 |             self.current = max(self.end, res)
40 | 
41 |         return self.current / self.start
42 | 
43 | 
44 | class LambdaWarmUpCosineScheduler(BaseScheduler):
45 |     def __init__(self, warm_up_steps: int, max_decay_steps: int, min_: float, max_: float, start: float) -> None:
46 |         super().__init__()
47 |         assert (max_decay_steps >= warm_up_steps)
48 | 
49 |         self.warm_up_steps = warm_up_steps
50 |         self.start = start
51 |         self.min_ = min_
52 |         self.max_ = max_
53 |         self.max_decay_steps = max_decay_steps
54 |         self.last = 0.
55 | 
56 |     def schedule(self, n: int) -> float:
57 |         if n < self.warm_up_steps:
58 |             res = (self.max_ - self.start) / self.warm_up_steps * n + self.start
59 |             self.last = res
60 |         else:
61 |             t = (n - self.warm_up_steps) / (self.max_decay_steps - self.warm_up_steps)
62 |             t = min(t, 1.0)
63 |             res = self.min_ + 0.5 * (self.max_ - self.min_) * (1 + np.cos(t * np.pi))
64 |             self.last = res
65 | 
66 |         return res / self.start
67 | 
68 | 
69 | class LambdaWarmUpLinearScheduler(BaseScheduler):
70 |     def __init__(self, warm_up_steps: int, max_decay_steps: int, min_: float, max_: float, start: float) -> None:
71 |         super().__init__()
72 |         assert (max_decay_steps >= warm_up_steps)
73 | 
74 |         self.warm_up_steps = warm_up_steps
75 |         self.start = start
76 |         self.min_ = min_
77 |         self.max_ = max_
78 |         self.max_decay_steps = max_decay_steps
79 |         self.last = 0.
80 | 
81 |     def schedule(self, n: int) -> float:
82 |         if n < self.warm_up_steps:
83 |             res = (self.max_ - self.start) / self.warm_up_steps * n + self.start
84 |             self.last = res
85 |         else:
86 |             res = self.min_ + (self.max_ - self.min_) * (self.max_decay_steps - n) / self.max_decay_steps
87 |             self.last = res
88 | 
89 |         return res / self.start
90 | 
--------------------------------------------------------------------------------
/enhancing/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from CLIP (https://github.com/openai/CLIP)
3 | # Copyright (c) 2021 OpenAI. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 | 
6 | import os
7 | import ftfy
8 | import html
9 | import regex as re
10 | from pathlib import Path
11 | from functools import lru_cache
12 | from typing import Optional, List, Tuple, Dict, Set
13 | 
14 | import torch
15 | 
16 | @lru_cache()
17 | def default_bpe():
18 |     return 'assets/vocab/bpe_simple_vocab_16e6.txt'
19 | 
20 | @lru_cache()
21 | def bytes_to_unicode() -> Dict:
22 |     bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
23 |     cs = bs[:]
24 |     n = 0
25 |     for b in range(2 ** 8):
26 |         if b not in bs:
27 |             bs.append(b)
28 |             cs.append(2 ** 8 + n)
29 |             n += 1
30 |     cs = [chr(n) for n in cs]
31 |     return dict(zip(bs, cs))
32 | 
33 | def get_pairs(word: str) -> List[Tuple[str, str]]:
34 |     pairs = set()
35 |     prev_char = word[0]
36 |     for char in word[1:]:
37 |         pairs.add((prev_char, char))
38 |         prev_char = char
39 |     return pairs
40 | 
41 | def basic_clean(text: str) -> str:
42 |     text = ftfy.fix_text(text)
43 |     text = html.unescape(html.unescape(text))
44 |     return text.strip()
45 | 
46 | def whitespace_clean(text: str) -> str:
47 |     text = re.sub(r'\s+', ' ', text)
48 |     text = text.strip()
49 |     return text
50 | 
51 | class SimpleTokenizer:
52 |     def __init__(self, bpe_path: str = default_bpe(), text_length: int = 256,
53 |                  truncate_captions: bool = True) -> None:
54 |         self.context_length = text_length
55 |         self.truncate_text = truncate_captions
56 |         self.byte_encoder = bytes_to_unicode()
57 |         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
58 |         merges = Path(bpe_path).read_text(encoding='utf8').split('\n')
59 |         merges = merges[1:49152 - 256 - 2 + 1]
60 |         merges = [tuple(merge.split()) for merge in merges]
61 |         vocab = list(bytes_to_unicode().values())
62 |         vocab = vocab + [v + '</w>' for v in vocab]
63 |         for merge in merges:
64 |             vocab.append(''.join(merge))
65 |         vocab.extend(['<|startoftext|>', '<|endoftext|>'])
66 | 
67 |         self.vocab_size = 49408
68 | 
69 |         self.encoder = dict(zip(vocab, range(len(vocab))))
70 |         self.decoder = {v: k for k, v in self.encoder.items()}
71 |         self.bpe_ranks = dict(zip(merges, range(len(merges))))
72 |         self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
73 |         self.pat = re.compile(
74 |             r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
75 |             re.IGNORECASE)
76 | 
77 |     def bpe(self, token: str) -> str:
78 |         if token in self.cache:
79 |             return self.cache[token]
80 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
81 |         pairs = get_pairs(word)
82 | 
83 |         if not pairs:
84 |             return token + '</w>'
85 | 
86 |         while True:
87 |             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
88 |             if bigram not in self.bpe_ranks:
89 |                 break
90 |             first, second = bigram
91 |             new_word = []
92 |             i = 0
93 |             while i < len(word):
94 |                 try:
95 |                     j = word.index(first, i)
96 |                     new_word.extend(word[i:j])
97 |                     i = j
98 |                 except ValueError:
99 |                     new_word.extend(word[i:])
100 |                     break
101 | 
102 |                 if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
103 |                     new_word.append(first + second)
104 |                     i += 2
105 |                 else:
106 |                     new_word.append(word[i])
107 |                     i += 1
108 |             new_word = tuple(new_word)
109 |             word = new_word
110 |             if len(word) == 1:
111 |                 break
112 |             else:
113 |                 pairs = get_pairs(word)
114 |         word = ' '.join(word)
115 |         self.cache[token] = word
116 |         return word
117 | 
118 |     def encode(self, text: str) -> List[int]:
119 |         bpe_tokens = []
120 |         text = whitespace_clean(basic_clean(text)).lower()
121 |         for token in re.findall(self.pat, text):
122 |             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
123 |             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
124 |         return bpe_tokens
125 | 
126 |     def decode(self, tokens: List[int], remove_start_end: bool = True, pad_tokens: Optional[Set] = set()) -> str:
127 |         if torch.is_tensor(tokens):
128 |             tokens = tokens.tolist()
129 | 
130 |         if remove_start_end:
131 |             tokens = [token for token in tokens if token not in (49406, 49407, 0)]
132 |         text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens])
133 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
134 |         return text
135 | 
136 |     def tokenize(self, texts: str) -> torch.LongTensor:
137 |         if isinstance(texts, str):
138 |             texts = [texts]
139 |         #assert type(texts) == list, f"texts is {texts}"
140 |         all_tokens = [self.encode(text) for text in texts]
141 |         result = torch.zeros(len(all_tokens), self.context_length, dtype=torch.long)
142 | 
143 |         for i, tokens in enumerate(all_tokens):
144 |             if len(tokens) > self.context_length:
145 |                 if self.truncate_text:
146 |                     tokens = tokens[:self.context_length]
147 |                 else:
148 |                     raise RuntimeError(f"Input {texts[i]} is too long for context length {self.context_length}")
149 |             result[i, :len(tokens)] = torch.tensor(tokens)
150 | 
151 |         return result
152 | 
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: enhancing
2 | channels:
3 |   - pytorch
4 |   - defaults
5 | dependencies:
6 |   - python=3.8.5
7 |   - pip=20.3
8 |   - cudatoolkit=11.0
9 |   - pytorch=1.7.1
10 |   - torchvision=0.8.2
11 |   - numpy=1.19.2
12 |   - ninja=1.10.2
13 |   - pip:
14 |     - ftfy==6.1.1
15 |     - lpips==0.1.4
16 |     - regex==2021.10.8
17 |     - pytorch-lightning==1.5.10
18 |     - einops==0.3.0
19 |     - omegaconf==2.0.0
20 |     - lmdb==1.0.0
21 |     - wandb==0.12.21
22 |     - git+https://github.com/openai/CLIP.git
23 |     - albumentations==0.4.3
24 |     - kornia==0.5.11
25 |     - Pillow==9.0.1
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import os 8 | import sys 9 | import argparse 10 | from pathlib import Path 11 | from omegaconf import OmegaConf 12 | import pytorch_lightning as pl 13 | 14 | from enhancing.utils.general import get_config_from_file, initialize_from_config, setup_callbacks 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-c', '--config', type=str, required=True) 19 | parser.add_argument('-s', '--seed', type=int, default=0) 20 | parser.add_argument('-nn', '--num_nodes', type=int, default=1) 21 | parser.add_argument('-ng', '--num_gpus', type=int, default=1) 22 | parser.add_argument('-u', '--update_every', type=int, default=1) 23 | parser.add_argument('-e', '--epochs', type=int, default=100) 24 | parser.add_argument('-lr', '--base_lr', type=float, default=4.5e-6) 25 | parser.add_argument('-a', '--use_amp', default=False, action='store_true') 26 | parser.add_argument('-b', '--batch_frequency', type=int, default=750) 27 | parser.add_argument('-m', '--max_images', type=int, default=4) 28 | args = parser.parse_args() 29 | 30 | # Set random seed 31 | pl.seed_everything(args.seed) 32 | 33 | # Load configuration 34 | config = get_config_from_file(Path("configs")/(args.config+".yaml")) 35 | exp_config = OmegaConf.create({"name": args.config, "epochs": args.epochs, "update_every": args.update_every, 36 | "base_lr": args.base_lr, "use_amp": args.use_amp, "batch_frequency": args.batch_frequency, 37 | "max_images": args.max_images}) 38 | 39 | # Build model 40 | model = initialize_from_config(config.model) 41 | model.learning_rate = exp_config.base_lr 42 | 43 | # Setup callbacks 44 | callbacks, logger = setup_callbacks(exp_config, config) 45 | 46 | # Build data modules 47 | data = initialize_from_config(config.dataset) 48 | data.prepare_data() 49 | 50 | # Build trainer 51 | trainer = pl.Trainer(max_epochs=exp_config.epochs, 52 | precision=16 if exp_config.use_amp else 32, 53 | callbacks=callbacks, 54 | gpus=args.num_gpus, 55 | num_nodes=args.num_nodes, 56 | strategy="ddp" if args.num_nodes > 1 or args.num_gpus > 1 else None, 57 | accumulate_grad_batches=exp_config.update_every, 58 | logger=logger) 59 | 60 | # Train 61 | trainer.fit(model, data) 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy==6.1.1 2 | lpips==0.1.4 3 | regex==2021.10.8 4 | torch==1.7.1 5 | torchvision==0.8.2 6 | pytorch-lightning==1.5.10 7 | einops==0.3.0 8 | omegaconf==2.0.0 9 | numpy==1.19.2 10 | lmdb==1.0.0 11 | wandb==0.12.21 12 | git+https://github.com/openai/CLIP.git 13 | albumentations==0.4.3 14 | kornia==0.5.11 15 | Pillow==9.0.1 --------------------------------------------------------------------------------
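
For orientation, here is a minimal sampling sketch that ties the pieces above together: config loading from `enhancing/utils/general.py`, checkpoint restoring via `CondTransformer.init_from_ckpt`, and image generation via `CondTransformer.sample`. It is not part of the repository; the checkpoint path and class indices are hypothetical placeholders, and it assumes the condition model's `encode_codes` accepts raw class labels the same way `shared_step` feeds them.

```python
# Minimal inference sketch (not part of the repository).
# Assumes a class-conditional config whose model target is CondTransformer
# and a trained checkpoint at a hypothetical path.
import torch
from pathlib import Path

from enhancing.utils.general import get_config_from_file, initialize_from_config

config = get_config_from_file(Path("configs") / "imagenet_gpt_vitvq_base.yaml")
model = initialize_from_config(config.model)
model.init_from_ckpt("experiments/imagenet_gpt_vitvq_base/ckpt/last.ckpt")  # hypothetical path
model.eval()

# Condition on a batch of ImageNet class indices (hypothetical values);
# the condition model turns them into condition codes, as in shared_step().
class_ids = torch.tensor([[207], [980]])
cond_codes = model.cond_model.encode_codes(class_ids.to(model.device))

# Autoregressively sample codes and decode them into images in [0, 1].
with torch.no_grad():
    pixels = model.sample(cond_codes, top_k=100, top_p=0.95,
                          softmax_temperature=1.0, use_fp16=False)
print(pixels.shape)  # expected: (batch, 3, height, width)
```

The `top_k` and `top_p` arguments map onto the cutoff logic in the sampling loop of `enhancing/modules/stage2/layers.py`, and the decoded pixels come from the frozen stage-1 ViT-VQGAN via `decode_codes`.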