├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── class
│   │   ├── imagenet.txt
│   │   └── lsun.txt
│   ├── font
│   │   └── arial.ttf
│   └── vocab
│       └── bpe_simple_vocab_16e6.txt
├── configs
│   ├── imagenet_gpt_vitvq_base.yaml
│   ├── imagenet_vitvq_base.yaml
│   ├── imagenet_vitvq_large.yaml
│   └── imagenet_vitvq_small.yaml
├── enhancing
│   ├── __init__.py
│   ├── dataloader
│   │   ├── __init__.py
│   │   ├── cc3m.py
│   │   ├── classimage.py
│   │   ├── coco.py
│   │   ├── imagenet.py
│   │   ├── inatural.py
│   │   ├── lsun.py
│   │   ├── srimage.py
│   │   └── textimage.py
│   ├── losses
│   │   ├── layers.py
│   │   ├── op
│   │   │   ├── __init__.py
│   │   │   ├── conv2d_gradfix.py
│   │   │   ├── fused_act.py
│   │   │   ├── fused_bias_act.cpp
│   │   │   ├── fused_bias_act_kernel.cu
│   │   │   ├── upfirdn2d.cpp
│   │   │   ├── upfirdn2d.py
│   │   │   └── upfirdn2d_kernel.cu
│   │   ├── segmentation.py
│   │   └── vqperceptual.py
│   ├── modules
│   │   ├── cond
│   │   │   ├── clipcond.py
│   │   │   ├── dummycond.py
│   │   │   └── vqcond.py
│   │   ├── stage1
│   │   │   ├── layers.py
│   │   │   ├── quantizers.py
│   │   │   └── vitvqgan.py
│   │   └── stage2
│   │       ├── layers.py
│   │       └── transformer.py
│   └── utils
│       ├── callback.py
│       ├── general.py
│       ├── scheduler.py
│       └── tokenizer.py
├── environment.yaml
├── main.py
└── requirements.txt
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | figures/
2 | experiments/
3 | *__pycache__/
4 | data
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022 Thuan H. Nguyen
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
19 | OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
7 | Table of Contents
8 |
9 | - About The Project
13 | - Getting Started
19 | - Roadmap
20 | - Contributing
21 | - License
22 | - Contact
23 | - Acknowledgments
24 |
25 |
26 |
27 | ## News
28 | ***09/09***
29 | 1. Released the weights of ViT-VQGAN small trained on ImageNet, available [here](https://huggingface.co/thuanz123/vitvqgan-imagenet-small)
30 |
31 | ***16/08***
32 | 1. First release of the ViT-VQGAN base weights trained on ImageNet, available [here](https://huggingface.co/thuanz123/vitvqgan-imagenet-base)
33 | 2. Added a Colab notebook [here](https://colab.research.google.com/drive/1y-PzYhkNQbhKj3i459pWd6TAO28SnF5h?usp=sharing)
34 |
35 |
36 | ## About The Project
37 |
38 | This is an unofficial implementation of both [ViT-VQGAN](https://arxiv.org/abs/2110.04627) and [RQ-VAE](https://arxiv.org/abs/2203.01941) in PyTorch. ViT-VQGAN is a simple ViT-based vector-quantized autoencoder, while RQ-VAE introduces a new residual quantization scheme. Further details can be found in the papers.
39 |
40 |
41 | ## Getting Started
42 |
43 | For ease of installation, you should use [anaconda](https://conda.io/) to set up this repo.
44 |
45 | ### Installation
46 |
47 | A suitable conda environment named `enhancing` can be created and activated with:
48 | ```
49 | conda env create -f environment.yaml
50 | conda activate enhancing
51 | ```
52 |
53 |
54 | ### Training
55 |
56 | Training is easy with one line:
57 | ```python3 main.py -c config_name -lr learning_rate -e epoch_nums```
58 |
59 |
60 | ## Roadmap
61 |
62 | - [x] Add ViT-VQGAN
63 | - [x] Add ViT-based encoder and decoder
64 | - [x] Add factorized codes
65 | - [x] Add l2-normalized codes
66 | - [x] Replace PatchGAN discriminator with StyleGAN one
67 | - [x] Add RQ-VAE
68 | - [x] Add Residual Quantizer
69 | - [x] Add RQ-Transformer
70 | - [x] Add dataloaders for some common datasets
71 | - [x] ImageNet
72 | - [x] LSUN
73 | - [x] COCO
74 | - [x] Add COCO Segmentation
75 | - [x] Add COCO Caption
76 | - [x] CC3M
77 | - [ ] Add pretrained models
78 | - [x] ViT-VQGAN small
79 | - [x] ViT-VQGAN base
80 | - [ ] ViT-VQGAN large
81 |
82 |
83 |
84 | ## Contributing
85 |
86 | Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**.
87 |
88 | If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement".
89 | Don't forget to give the project a star! Thanks again!
90 |
91 | 1. Fork the Project
92 | 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
93 | 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
94 | 4. Push to the Branch (`git push origin feature/AmazingFeature`)
95 | 5. Open a Pull Request
96 |
97 |
98 |
99 | ## License
100 |
101 | Source code and pretrained weights are distributed under the MIT License. See `LICENSE` for more information.
102 |
103 |
104 |
105 | ## Contact
106 |
107 | Thuan H. Nguyen - [@leejohnthuan](https://twitter.com/leejohnthuan) - leejohnthuan@gmail.com
108 |
109 |
110 |
111 | ## Acknowledgements
112 | This project would not be possible without the generous sponsorship from [Stability AI](https://stability.ai/) and helpful discussions with folks in the [LAION Discord](https://discord.gg/j5GdN49g).
113 |
114 | This repo is heavily inspired by the following repos and papers:
115 |
116 | * [Taming Transformers](https://github.com/CompVis/taming-transformers)
117 | * [ViT-Pytorch](https://github.com/lucidrains/vit-pytorch)
118 | * [minDALL-E](https://github.com/kakaobrain/minDALL-E)
119 | * [CLIP](https://github.com/openai/CLIP)
120 | * [ViT-VQGAN](https://arxiv.org/abs/2110.04627)
121 | * [RQ-VAE](https://arxiv.org/abs/2203.01941)
122 |
--------------------------------------------------------------------------------
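As a concrete instance of the training command documented in the README above, the following trains the small ViT-VQGAN on ImageNet using one of the configs shipped in `configs/`. The config name is real, but the learning rate and epoch count are only illustrative values, and whether `-c` expects the bare name or a path to the YAML file depends on how `main.py` parses its arguments, which is not shown in this excerpt:

```
conda activate enhancing
python3 main.py -c imagenet_vitvq_small -lr 4.5e-6 -e 100
```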
/assets/class/imagenet.txt:
--------------------------------------------------------------------------------
1 | tench
2 | goldfish
3 | great white shark
4 | tiger shark
5 | hammerhead shark
6 | electric ray
7 | stingray
8 | cock
9 | hen
10 | ostrich
11 | brambling
12 | goldfinch
13 | house finch
14 | junco
15 | indigo bunting
16 | American robin
17 | bulbul
18 | jay
19 | magpie
20 | chickadee
21 | American dipper
22 | kite
23 | bald eagle
24 | vulture
25 | great grey owl
26 | fire salamander
27 | smooth newt
28 | newt
29 | spotted salamander
30 | axolotl
31 | American bullfrog
32 | tree frog
33 | tailed frog
34 | loggerhead sea turtle
35 | leatherback sea turtle
36 | mud turtle
37 | terrapin
38 | box turtle
39 | banded gecko
40 | green iguana
41 | Carolina anole
42 | desert grassland whiptail lizard
43 | agama
44 | frilled-necked lizard
45 | alligator lizard
46 | Gila monster
47 | European green lizard
48 | chameleon
49 | Komodo dragon
50 | Nile crocodile
51 | American alligator
52 | triceratops
53 | worm snake
54 | ring-necked snake
55 | eastern hog-nosed snake
56 | smooth green snake
57 | kingsnake
58 | garter snake
59 | water snake
60 | vine snake
61 | night snake
62 | boa constrictor
63 | African rock python
64 | Indian cobra
65 | green mamba
66 | sea snake
67 | Saharan horned viper
68 | eastern diamondback rattlesnake
69 | sidewinder
70 | trilobite
71 | harvestman
72 | scorpion
73 | yellow garden spider
74 | barn spider
75 | European garden spider
76 | southern black widow
77 | tarantula
78 | wolf spider
79 | tick
80 | centipede
81 | black grouse
82 | ptarmigan
83 | ruffed grouse
84 | prairie grouse
85 | peacock
86 | quail
87 | partridge
88 | grey parrot
89 | macaw
90 | sulphur-crested cockatoo
91 | lorikeet
92 | coucal
93 | bee eater
94 | hornbill
95 | hummingbird
96 | jacamar
97 | toucan
98 | duck
99 | red-breasted merganser
100 | goose
101 | black swan
102 | tusker
103 | echidna
104 | platypus
105 | wallaby
106 | koala
107 | wombat
108 | jellyfish
109 | sea anemone
110 | brain coral
111 | flatworm
112 | nematode
113 | conch
114 | snail
115 | slug
116 | sea slug
117 | chiton
118 | chambered nautilus
119 | Dungeness crab
120 | rock crab
121 | fiddler crab
122 | red king crab
123 | American lobster
124 | spiny lobster
125 | crayfish
126 | hermit crab
127 | isopod
128 | white stork
129 | black stork
130 | spoonbill
131 | flamingo
132 | little blue heron
133 | great egret
134 | bittern
135 | crane (bird)
136 | limpkin
137 | common gallinule
138 | American coot
139 | bustard
140 | ruddy turnstone
141 | dunlin
142 | common redshank
143 | dowitcher
144 | oystercatcher
145 | pelican
146 | king penguin
147 | albatross
148 | grey whale
149 | killer whale
150 | dugong
151 | sea lion
152 | Chihuahua
153 | Japanese Chin
154 | Maltese
155 | Pekingese
156 | Shih Tzu
157 | King Charles Spaniel
158 | Papillon
159 | toy terrier
160 | Rhodesian Ridgeback
161 | Afghan Hound
162 | Basset Hound
163 | Beagle
164 | Bloodhound
165 | Bluetick Coonhound
166 | Black and Tan Coonhound
167 | Treeing Walker Coonhound
168 | English foxhound
169 | Redbone Coonhound
170 | borzoi
171 | Irish Wolfhound
172 | Italian Greyhound
173 | Whippet
174 | Ibizan Hound
175 | Norwegian Elkhound
176 | Otterhound
177 | Saluki
178 | Scottish Deerhound
179 | Weimaraner
180 | Staffordshire Bull Terrier
181 | American Staffordshire Terrier
182 | Bedlington Terrier
183 | Border Terrier
184 | Kerry Blue Terrier
185 | Irish Terrier
186 | Norfolk Terrier
187 | Norwich Terrier
188 | Yorkshire Terrier
189 | Wire Fox Terrier
190 | Lakeland Terrier
191 | Sealyham Terrier
192 | Airedale Terrier
193 | Cairn Terrier
194 | Australian Terrier
195 | Dandie Dinmont Terrier
196 | Boston Terrier
197 | Miniature Schnauzer
198 | Giant Schnauzer
199 | Standard Schnauzer
200 | Scottish Terrier
201 | Tibetan Terrier
202 | Australian Silky Terrier
203 | Soft-coated Wheaten Terrier
204 | West Highland White Terrier
205 | Lhasa Apso
206 | Flat-Coated Retriever
207 | Curly-coated Retriever
208 | Golden Retriever
209 | Labrador Retriever
210 | Chesapeake Bay Retriever
211 | German Shorthaired Pointer
212 | Vizsla
213 | English Setter
214 | Irish Setter
215 | Gordon Setter
216 | Brittany
217 | Clumber Spaniel
218 | English Springer Spaniel
219 | Welsh Springer Spaniel
220 | Cocker Spaniels
221 | Sussex Spaniel
222 | Irish Water Spaniel
223 | Kuvasz
224 | Schipperke
225 | Groenendael
226 | Malinois
227 | Briard
228 | Australian Kelpie
229 | Komondor
230 | Old English Sheepdog
231 | Shetland Sheepdog
232 | collie
233 | Border Collie
234 | Bouvier des Flandres
235 | Rottweiler
236 | German Shepherd Dog
237 | Dobermann
238 | Miniature Pinscher
239 | Greater Swiss Mountain Dog
240 | Bernese Mountain Dog
241 | Appenzeller Sennenhund
242 | Entlebucher Sennenhund
243 | Boxer
244 | Bullmastiff
245 | Tibetan Mastiff
246 | French Bulldog
247 | Great Dane
248 | St. Bernard
249 | husky
250 | Alaskan Malamute
251 | Siberian Husky
252 | Dalmatian
253 | Affenpinscher
254 | Basenji
255 | pug
256 | Leonberger
257 | Newfoundland
258 | Pyrenean Mountain Dog
259 | Samoyed
260 | Pomeranian
261 | Chow Chow
262 | Keeshond
263 | Griffon Bruxellois
264 | Pembroke Welsh Corgi
265 | Cardigan Welsh Corgi
266 | Toy Poodle
267 | Miniature Poodle
268 | Standard Poodle
269 | Mexican hairless dog
270 | grey wolf
271 | Alaskan tundra wolf
272 | red wolf
273 | coyote
274 | dingo
275 | dhole
276 | African wild dog
277 | hyena
278 | red fox
279 | kit fox
280 | Arctic fox
281 | grey fox
282 | tabby cat
283 | tiger cat
284 | Persian cat
285 | Siamese cat
286 | Egyptian Mau
287 | cougar
288 | lynx
289 | leopard
290 | snow leopard
291 | jaguar
292 | lion
293 | tiger
294 | cheetah
295 | brown bear
296 | American black bear
297 | polar bear
298 | sloth bear
299 | mongoose
300 | meerkat
301 | tiger beetle
302 | ladybug
303 | ground beetle
304 | longhorn beetle
305 | leaf beetle
306 | dung beetle
307 | rhinoceros beetle
308 | weevil
309 | fly
310 | bee
311 | ant
312 | grasshopper
313 | cricket
314 | stick insect
315 | cockroach
316 | mantis
317 | cicada
318 | leafhopper
319 | lacewing
320 | dragonfly
321 | damselfly
322 | red admiral
323 | ringlet
324 | monarch butterfly
325 | small white
326 | sulphur butterfly
327 | gossamer-winged butterfly
328 | starfish
329 | sea urchin
330 | sea cucumber
331 | cottontail rabbit
332 | hare
333 | Angora rabbit
334 | hamster
335 | porcupine
336 | fox squirrel
337 | marmot
338 | beaver
339 | guinea pig
340 | common sorrel
341 | zebra
342 | pig
343 | wild boar
344 | warthog
345 | hippopotamus
346 | ox
347 | water buffalo
348 | bison
349 | ram
350 | bighorn sheep
351 | Alpine ibex
352 | hartebeest
353 | impala
354 | gazelle
355 | dromedary
356 | llama
357 | weasel
358 | mink
359 | European polecat
360 | black-footed ferret
361 | otter
362 | skunk
363 | badger
364 | armadillo
365 | three-toed sloth
366 | orangutan
367 | gorilla
368 | chimpanzee
369 | gibbon
370 | siamang
371 | guenon
372 | patas monkey
373 | baboon
374 | macaque
375 | langur
376 | black-and-white colobus
377 | proboscis monkey
378 | marmoset
379 | white-headed capuchin
380 | howler monkey
381 | titi
382 | Geoffroy's spider monkey
383 | common squirrel monkey
384 | ring-tailed lemur
385 | indri
386 | Asian elephant
387 | African bush elephant
388 | red panda
389 | giant panda
390 | snoek
391 | eel
392 | coho salmon
393 | rock beauty
394 | clownfish
395 | sturgeon
396 | garfish
397 | lionfish
398 | pufferfish
399 | abacus
400 | abaya
401 | academic gown
402 | accordion
403 | acoustic guitar
404 | aircraft carrier
405 | airliner
406 | airship
407 | altar
408 | ambulance
409 | amphibious vehicle
410 | analog clock
411 | apiary
412 | apron
413 | waste container
414 | assault rifle
415 | backpack
416 | bakery
417 | balance beam
418 | balloon
419 | ballpoint pen
420 | Band-Aid
421 | banjo
422 | baluster
423 | barbell
424 | barber chair
425 | barbershop
426 | barn
427 | barometer
428 | barrel
429 | wheelbarrow
430 | baseball
431 | basketball
432 | bassinet
433 | bassoon
434 | swimming cap
435 | bath towel
436 | bathtub
437 | station wagon
438 | lighthouse
439 | beaker
440 | military cap
441 | beer bottle
442 | beer glass
443 | bell-cot
444 | bib
445 | tandem bicycle
446 | bikini
447 | ring binder
448 | binoculars
449 | birdhouse
450 | boathouse
451 | bobsleigh
452 | bolo tie
453 | poke bonnet
454 | bookcase
455 | bookstore
456 | bottle cap
457 | bow
458 | bow tie
459 | brass
460 | bra
461 | breakwater
462 | breastplate
463 | broom
464 | bucket
465 | buckle
466 | bulletproof vest
467 | high-speed train
468 | butcher shop
469 | taxicab
470 | cauldron
471 | candle
472 | cannon
473 | canoe
474 | can opener
475 | cardigan
476 | car mirror
477 | carousel
478 | tool kit
479 | carton
480 | car wheel
481 | automated teller machine
482 | cassette
483 | cassette player
484 | castle
485 | catamaran
486 | CD player
487 | cello
488 | mobile phone
489 | chain
490 | chain-link fence
491 | chain mail
492 | chainsaw
493 | chest
494 | chiffonier
495 | chime
496 | china cabinet
497 | Christmas stocking
498 | church
499 | movie theater
500 | cleaver
501 | cliff dwelling
502 | cloak
503 | clogs
504 | cocktail shaker
505 | coffee mug
506 | coffeemaker
507 | coil
508 | combination lock
509 | computer keyboard
510 | confectionery store
511 | container ship
512 | convertible
513 | corkscrew
514 | cornet
515 | cowboy boot
516 | cowboy hat
517 | cradle
518 | crane (machine)
519 | crash helmet
520 | crate
521 | infant bed
522 | Crock Pot
523 | croquet ball
524 | crutch
525 | cuirass
526 | dam
527 | desk
528 | desktop computer
529 | rotary dial telephone
530 | diaper
531 | digital clock
532 | digital watch
533 | dining table
534 | dishcloth
535 | dishwasher
536 | disc brake
537 | dock
538 | dog sled
539 | dome
540 | doormat
541 | drilling rig
542 | drum
543 | drumstick
544 | dumbbell
545 | Dutch oven
546 | electric fan
547 | electric guitar
548 | electric locomotive
549 | entertainment center
550 | envelope
551 | espresso machine
552 | face powder
553 | feather boa
554 | filing cabinet
555 | fireboat
556 | fire engine
557 | fire screen sheet
558 | flagpole
559 | flute
560 | folding chair
561 | football helmet
562 | forklift
563 | fountain
564 | fountain pen
565 | four-poster bed
566 | freight car
567 | French horn
568 | frying pan
569 | fur coat
570 | garbage truck
571 | gas mask
572 | gas pump
573 | goblet
574 | go-kart
575 | golf ball
576 | golf cart
577 | gondola
578 | gong
579 | gown
580 | grand piano
581 | greenhouse
582 | grille
583 | grocery store
584 | guillotine
585 | barrette
586 | hair spray
587 | half-track
588 | hammer
589 | hamper
590 | hair dryer
591 | hand-held computer
592 | handkerchief
593 | hard disk drive
594 | harmonica
595 | harp
596 | harvester
597 | hatchet
598 | holster
599 | home theater
600 | honeycomb
601 | hook
602 | hoop skirt
603 | horizontal bar
604 | horse-drawn vehicle
605 | hourglass
606 | iPod
607 | clothes iron
608 | jack-o'-lantern
609 | jeans
610 | jeep
611 | T-shirt
612 | jigsaw puzzle
613 | pulled rickshaw
614 | joystick
615 | kimono
616 | knee pad
617 | knot
618 | lab coat
619 | ladle
620 | lampshade
621 | laptop computer
622 | lawn mower
623 | lens cap
624 | paper knife
625 | library
626 | lifeboat
627 | lighter
628 | limousine
629 | ocean liner
630 | lipstick
631 | slip-on shoe
632 | lotion
633 | speaker
634 | loupe
635 | sawmill
636 | magnetic compass
637 | mail bag
638 | mailbox
639 | tights
640 | tank suit
641 | manhole cover
642 | maraca
643 | marimba
644 | mask
645 | match
646 | maypole
647 | maze
648 | measuring cup
649 | medicine chest
650 | megalith
651 | microphone
652 | microwave oven
653 | military uniform
654 | milk can
655 | minibus
656 | miniskirt
657 | minivan
658 | missile
659 | mitten
660 | mixing bowl
661 | mobile home
662 | Model T
663 | modem
664 | monastery
665 | monitor
666 | moped
667 | mortar
668 | square academic cap
669 | mosque
670 | mosquito net
671 | scooter
672 | mountain bike
673 | tent
674 | computer mouse
675 | mousetrap
676 | moving van
677 | muzzle
678 | nail
679 | neck brace
680 | necklace
681 | nipple
682 | notebook computer
683 | obelisk
684 | oboe
685 | ocarina
686 | odometer
687 | oil filter
688 | organ
689 | oscilloscope
690 | overskirt
691 | bullock cart
692 | oxygen mask
693 | packet
694 | paddle
695 | paddle wheel
696 | padlock
697 | paintbrush
698 | pajamas
699 | palace
700 | pan flute
701 | paper towel
702 | parachute
703 | parallel bars
704 | park bench
705 | parking meter
706 | passenger car
707 | patio
708 | payphone
709 | pedestal
710 | pencil case
711 | pencil sharpener
712 | perfume
713 | Petri dish
714 | photocopier
715 | plectrum
716 | Pickelhaube
717 | picket fence
718 | pickup truck
719 | pier
720 | piggy bank
721 | pill bottle
722 | pillow
723 | ping-pong ball
724 | pinwheel
725 | pirate ship
726 | pitcher
727 | hand plane
728 | planetarium
729 | plastic bag
730 | plate rack
731 | plow
732 | plunger
733 | Polaroid camera
734 | pole
735 | police van
736 | poncho
737 | billiard table
738 | soda bottle
739 | pot
740 | potter's wheel
741 | power drill
742 | prayer rug
743 | printer
744 | prison
745 | projectile
746 | projector
747 | hockey puck
748 | punching bag
749 | purse
750 | quill
751 | quilt
752 | race car
753 | racket
754 | radiator
755 | radio
756 | radio telescope
757 | rain barrel
758 | recreational vehicle
759 | reel
760 | reflex camera
761 | refrigerator
762 | remote control
763 | restaurant
764 | revolver
765 | rifle
766 | rocking chair
767 | rotisserie
768 | eraser
769 | rugby ball
770 | ruler
771 | running shoe
772 | safe
773 | safety pin
774 | salt shaker
775 | sandal
776 | sarong
777 | saxophone
778 | scabbard
779 | weighing scale
780 | school bus
781 | schooner
782 | scoreboard
783 | CRT screen
784 | screw
785 | screwdriver
786 | seat belt
787 | sewing machine
788 | shield
789 | shoe store
790 | shoji
791 | shopping basket
792 | shopping cart
793 | shovel
794 | shower cap
795 | shower curtain
796 | ski
797 | ski mask
798 | sleeping bag
799 | slide rule
800 | sliding door
801 | slot machine
802 | snorkel
803 | snowmobile
804 | snowplow
805 | soap dispenser
806 | soccer ball
807 | sock
808 | solar thermal collector
809 | sombrero
810 | soup bowl
811 | space bar
812 | space heater
813 | space shuttle
814 | spatula
815 | motorboat
816 | spider web
817 | spindle
818 | sports car
819 | spotlight
820 | stage
821 | steam locomotive
822 | through arch bridge
823 | steel drum
824 | stethoscope
825 | scarf
826 | stone wall
827 | stopwatch
828 | stove
829 | strainer
830 | tram
831 | stretcher
832 | couch
833 | stupa
834 | submarine
835 | suit
836 | sundial
837 | sunglass
838 | sunglasses
839 | sunscreen
840 | suspension bridge
841 | mop
842 | sweatshirt
843 | swimsuit
844 | swing
845 | switch
846 | syringe
847 | table lamp
848 | tank
849 | tape player
850 | teapot
851 | teddy bear
852 | television
853 | tennis ball
854 | thatched roof
855 | front curtain
856 | thimble
857 | threshing machine
858 | throne
859 | tile roof
860 | toaster
861 | tobacco shop
862 | toilet seat
863 | torch
864 | totem pole
865 | tow truck
866 | toy store
867 | tractor
868 | semi-trailer truck
869 | tray
870 | trench coat
871 | tricycle
872 | trimaran
873 | tripod
874 | triumphal arch
875 | trolleybus
876 | trombone
877 | tub
878 | turnstile
879 | typewriter keyboard
880 | umbrella
881 | unicycle
882 | upright piano
883 | vacuum cleaner
884 | vase
885 | vault
886 | velvet
887 | vending machine
888 | vestment
889 | viaduct
890 | violin
891 | volleyball
892 | waffle iron
893 | wall clock
894 | wallet
895 | wardrobe
896 | military aircraft
897 | sink
898 | washing machine
899 | water bottle
900 | water jug
901 | water tower
902 | whiskey jug
903 | whistle
904 | wig
905 | window screen
906 | window shade
907 | Windsor tie
908 | wine bottle
909 | wing
910 | wok
911 | wooden spoon
912 | wool
913 | split-rail fence
914 | shipwreck
915 | yawl
916 | yurt
917 | website
918 | comic book
919 | crossword
920 | traffic sign
921 | traffic light
922 | dust jacket
923 | menu
924 | plate
925 | guacamole
926 | consomme
927 | hot pot
928 | trifle
929 | ice cream
930 | ice pop
931 | baguette
932 | bagel
933 | pretzel
934 | cheeseburger
935 | hot dog
936 | mashed potato
937 | cabbage
938 | broccoli
939 | cauliflower
940 | zucchini
941 | spaghetti squash
942 | acorn squash
943 | butternut squash
944 | cucumber
945 | artichoke
946 | bell pepper
947 | cardoon
948 | mushroom
949 | Granny Smith
950 | strawberry
951 | orange
952 | lemon
953 | fig
954 | pineapple
955 | banana
956 | jackfruit
957 | custard apple
958 | pomegranate
959 | hay
960 | carbonara
961 | chocolate syrup
962 | dough
963 | meatloaf
964 | pizza
965 | pot pie
966 | burrito
967 | red wine
968 | espresso
969 | cup
970 | eggnog
971 | alp
972 | bubble
973 | cliff
974 | coral reef
975 | geyser
976 | lakeshore
977 | promontory
978 | shoal
979 | seashore
980 | valley
981 | volcano
982 | baseball player
983 | bridegroom
984 | scuba diver
985 | rapeseed
986 | daisy
987 | yellow lady's slipper
988 | corn
989 | acorn
990 | rose hip
991 | horse chestnut seed
992 | coral fungus
993 | agaric
994 | gyromitra
995 | stinkhorn mushroom
996 | earth star
997 | hen-of-the-woods
998 | bolete
999 | ear
1000 | toilet paper
--------------------------------------------------------------------------------
/assets/class/lsun.txt:
--------------------------------------------------------------------------------
1 | bedroom
2 | bridge
3 | church_outdoor
4 | classroom
5 | conference_room
6 | dining_room
7 | kitchen
8 | living_room
9 | restaurant
10 | test
11 | tower
--------------------------------------------------------------------------------
/assets/font/arial.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuanz123/enhancing-transformers/e1c185da5cc0eeefdded68cbff888d5b9c248372/assets/font/arial.ttf
--------------------------------------------------------------------------------
/configs/imagenet_gpt_vitvq_base.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: enhancing.modules.stage2.transformer.CondTransformer
3 | params:
4 | cond_key: class
5 | cond:
6 | target: enhancing.modules.cond.dummycond.ClassCond
7 | params:
8 | image_size: 256
9 | class_name: assets/class/imagenet.txt
10 | stage1:
11 | target: enhancing.modules.stage1.vitvqgan.ViTVQ
12 | params:
13 | image_key: image
14 | path: weight/imagenet_vitvq_base.ckpt
15 | image_size: 256
16 | patch_size: 8
17 | encoder:
18 | dim: 768
19 | depth: 12
20 | heads: 12
21 | mlp_dim: 3072
22 | decoder:
23 | dim: 768
24 | depth: 12
25 | heads: 12
26 | mlp_dim: 3072
27 | quantizer:
28 | embed_dim: 32
29 | n_embed: 8192
30 | loss:
31 | target: enhancing.losses.vqperceptual.DummyLoss
32 | transformer:
33 | target: enhancing.modules.stage2.layers.GPT
34 | params:
35 | vocab_cond_size: 1000
36 | vocab_img_size: 8192
37 | embed_dim: 6144
38 | cond_num_tokens: 1
39 | img_num_tokens: 1024
40 | n_heads: 16
41 | n_layers: 24
42 |
43 | dataset:
44 | target: enhancing.dataloader.DataModuleFromConfig
45 | params:
46 | batch_size: 4
47 | num_workers: 2
48 | train:
49 | target: enhancing.dataloader.imagenet.ImageNetTrain
50 | params:
51 | root: data/ilsvrc2012
52 | resolution: 256
53 |
54 | validation:
55 | target: enhancing.dataloader.imagenet.ImageNetValidation
56 | params:
57 | root: data/ilsvrc2012
58 | resolution: 256
59 |
--------------------------------------------------------------------------------
/configs/imagenet_vitvq_base.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: enhancing.modules.stage1.vitvqgan.ViTVQ
3 | params:
4 | image_key: image
5 | image_size: 256
6 | patch_size: 8
7 | encoder:
8 | dim: 768
9 | depth: 12
10 | heads: 12
11 | mlp_dim: 3072
12 | decoder:
13 | dim: 768
14 | depth: 12
15 | heads: 12
16 | mlp_dim: 3072
17 | quantizer:
18 | embed_dim: 32
19 | n_embed: 8192
20 | loss:
21 | target: enhancing.losses.vqperceptual.VQLPIPSWithDiscriminator
22 | params:
23 | loglaplace_weight: 0.0
24 | loggaussian_weight: 1.0
25 | perceptual_weight: 0.1
26 | adversarial_weight: 0.1
27 |
28 | dataset:
29 | target: enhancing.dataloader.DataModuleFromConfig
30 | params:
31 | batch_size: 8
32 | num_workers: 4
33 | train:
34 | target: enhancing.dataloader.imagenet.ImageNetTrain
35 | params:
36 | root: data/ilsvrc2012
37 | resolution: 256
38 |
39 | validation:
40 | target: enhancing.dataloader.imagenet.ImageNetValidation
41 | params:
42 | root: data/ilsvrc2012
43 | resolution: 256
--------------------------------------------------------------------------------
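The configs above all follow a `target:`/`params:` layout, and the dataloader code later in this repo imports `initialize_from_config` from `enhancing.utils.general` to resolve such blocks. That helper's implementation is not included in this excerpt, so the snippet below is only a minimal sketch of the usual pattern (import the dotted `target`, call it with `params` as keyword arguments); the helper names here are hypothetical, and it assumes the repo and its dependencies are importable.

```python
import importlib
from omegaconf import OmegaConf

def get_class(name: str):
    # "enhancing.modules.stage1.vitvqgan.ViTVQ" -> module "enhancing.modules.stage1.vitvqgan", class "ViTVQ"
    module, cls = name.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # Hypothetical helper mirroring the target/params layout of the YAML files above.
    return get_class(config["target"])(**config.get("params", dict()))

cfg = OmegaConf.load("configs/imagenet_vitvq_base.yaml")
model = instantiate_from_config(cfg.model)    # stage-1 ViTVQ autoencoder
data = instantiate_from_config(cfg.dataset)   # DataModuleFromConfig
```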
/configs/imagenet_vitvq_large.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: enhancing.modules.stage1.vitvqgan.ViTVQ
3 | params:
4 | image_key: image
5 | image_size: 256
6 | patch_size: 8
7 | encoder:
8 | dim: 512
9 | depth: 8
10 | heads: 8
11 | mlp_dim: 2048
12 | decoder:
13 | dim: 1280
14 | depth: 32
15 | heads: 16
16 | mlp_dim: 5120
17 | quantizer:
18 | embed_dim: 32
19 | n_embed: 8192
20 | loss:
21 | target: enhancing.losses.vqperceptual.VQLPIPSWithDiscriminator
22 | params:
23 | loglaplace_weight: 0.0
24 | loggaussian_weight: 1.0
25 | perceptual_weight: 0.1
26 | adversarial_weight: 0.1
27 |
28 | dataset:
29 | target: enhancing.dataloader.DataModuleFromConfig
30 | params:
31 | batch_size: 2
32 | num_workers: 4
33 | train:
34 | target: enhancing.dataloader.imagenet.ImageNetTrain
35 | params:
36 | root: data/ilsvrc2012
37 | resolution: 256
38 |
39 | validation:
40 | target: enhancing.dataloader.imagenet.ImageNetValidation
41 | params:
42 | root: data/ilsvrc2012
43 | resolution: 256
--------------------------------------------------------------------------------
/configs/imagenet_vitvq_small.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: enhancing.modules.stage1.vitvqgan.ViTVQ
3 | params:
4 | image_key: image
5 | image_size: 256
6 | patch_size: 8
7 | encoder:
8 | dim: 512
9 | depth: 8
10 | heads: 8
11 | mlp_dim: 2048
12 | decoder:
13 | dim: 512
14 | depth: 8
15 | heads: 8
16 | mlp_dim: 2048
17 | quantizer:
18 | embed_dim: 32
19 | n_embed: 8192
20 | loss:
21 | target: enhancing.losses.vqperceptual.VQLPIPSWithDiscriminator
22 | params:
23 | loglaplace_weight: 0.0
24 | loggaussian_weight: 1.0
25 | perceptual_weight: 0.1
26 | adversarial_weight: 0.1
27 |
28 | dataset:
29 | target: enhancing.dataloader.DataModuleFromConfig
30 | params:
31 | batch_size: 8
32 | num_workers: 4
33 | train:
34 | target: enhancing.dataloader.imagenet.ImageNetTrain
35 | params:
36 | root: data/ilsvrc2012
37 | resolution: 256
38 |
39 | validation:
40 | target: enhancing.dataloader.imagenet.ImageNetValidation
41 | params:
42 | root: data/ilsvrc2012
43 | resolution: 256
--------------------------------------------------------------------------------
/enhancing/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import general
2 |
--------------------------------------------------------------------------------
/enhancing/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from VQGAN (https://github.com/CompVis/taming-transformers)
3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | from typing import Optional
7 | from omegaconf import OmegaConf
8 |
9 | import pytorch_lightning as pl
10 | from torch.utils.data import DataLoader
11 |
12 | from ..utils.general import initialize_from_config
13 |
14 | class DataModuleFromConfig(pl.LightningDataModule):
15 | def __init__(self, batch_size: int, train: Optional[OmegaConf] = None,
16 | validation: Optional[OmegaConf] = None,
17 | test: Optional[OmegaConf] = None,
18 | num_workers: Optional[int] = None):
19 | super().__init__()
20 | self.dataset_configs = dict()
21 | self.batch_size = batch_size
22 | self.num_workers = num_workers if num_workers is not None else batch_size*2
23 | if train is not None:
24 | self.dataset_configs["train"] = train
25 | self.train_dataloader = self._train_dataloader
26 | if validation is not None:
27 | self.dataset_configs["validation"] = validation
28 | self.val_dataloader = self._val_dataloader
29 | if test is not None:
30 | self.dataset_configs["test"] = test
31 | self.test_dataloader = self._test_dataloader
32 |
33 | def prepare_data(self):
34 | for data_cfg in self.dataset_configs.values():
35 | initialize_from_config(data_cfg)
36 |
37 | def setup(self, stage=None):
38 | self.datasets = dict(
39 | (k, initialize_from_config(self.dataset_configs[k]))
40 | for k in self.dataset_configs)
41 |
42 | def _train_dataloader(self):
43 | return DataLoader(self.datasets["train"], batch_size=self.batch_size,
44 | num_workers=self.num_workers, shuffle=True)
45 |
46 | def _val_dataloader(self):
47 | return DataLoader(self.datasets["validation"],
48 | batch_size=self.batch_size,
49 | num_workers=self.num_workers)
50 |
51 | def _test_dataloader(self):
52 | return DataLoader(self.datasets["test"], batch_size=self.batch_size,
53 | num_workers=self.num_workers)
54 |
--------------------------------------------------------------------------------
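A minimal usage sketch for `DataModuleFromConfig`, assuming the repo is importable, that `configs/imagenet_vitvq_small.yaml` exists as shown above, and that ImageNet is actually prepared under `data/ilsvrc2012`:

```python
from omegaconf import OmegaConf
from enhancing.dataloader import DataModuleFromConfig

# The dataset.params block of the config maps directly onto the constructor arguments.
cfg = OmegaConf.load("configs/imagenet_vitvq_small.yaml")
data = DataModuleFromConfig(**cfg.dataset.params)

data.prepare_data()   # instantiates each configured dataset once
data.setup()          # builds the dataset dict backing the dataloaders

batch = next(iter(data.train_dataloader()))
print(batch["image"].shape, batch["class"].shape)  # e.g. torch.Size([8, 3, 256, 256]) torch.Size([8, 1])
```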
/enhancing/dataloader/cc3m.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 |
7 | from typing import Optional, Union, Callable, Tuple, Any
8 | from pathlib import Path
9 | from omegaconf import OmegaConf
10 | from PIL import Image
11 |
12 | from torchvision import transforms as T
13 | from torch.utils.data import Dataset
14 |
15 | from ..utils.general import initialize_from_config
16 |
17 | class CC3MBase(Dataset):
18 | def __init__(self, folder: str, split: str,
19 | tokenizer: OmegaConf,
20 | transform: Callable) -> None:
21 | super().__init__()
22 | self.items = []
23 | for line in open(f'{Path(folder)}/{split}_list.txt', 'r').readlines():
24 | imgpath, text = line.strip().split('\t')
25 | self.items.append((Path(folder)/imgpath, text))
26 |
27 | self.tokenizer = initialize_from_config(tokenizer)
28 | self.transform = transform
29 |
30 | def __len__(self) -> int:
31 | return len(self.items)
32 |
33 | def __getitem__(self, ind: int) -> Tuple[Any, Any]:
34 | image_file, caption = self.items[ind]
35 |
36 | caption = self.tokenizer.tokenize(caption).squeeze(0)
37 |
38 | image = Image.open(image_file)
39 | if image.mode != 'RGB':
40 | image = image.convert('RGB')
41 |
42 | if self.transform:
43 | image = self.transform(image)
44 |
45 | # return the tokenized caption together with the transformed image
46 | return {"caption": caption, "image": image}
47 |
48 |
49 | class CC3MTrain(CC3MBase):
50 | def __init__(self, folder: str, tokenizer: OmegaConf,
51 | resolution: Union[Tuple[int, int], int] = 256) -> None:
52 | transform = T.Compose([
53 | T.Resize(resolution),
54 | T.RandomCrop(resolution),
55 | T.ToTensor(),
56 | ])
57 |
58 | super().__init__(folder, 'train', tokenizer, transform)
59 |
60 |
61 | class CC3MValidation(CC3MBase):
62 | def __init__(self, folder: str, tokenizer: OmegaConf,
63 | resolution: Union[Tuple[int, int], int] = 256) -> None:
64 | transform = T.Compose([
65 | T.Resize(resolution),
66 | T.CenterCrop(resolution),
67 | T.ToTensor(),
68 | ])
69 |
70 | super().__init__(folder, 'val', tokenizer, transform)
71 |
--------------------------------------------------------------------------------
/enhancing/dataloader/classimage.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 |
7 | import numpy as np
8 | from typing import Optional, Union, Callable, Tuple, Any
9 | from pathlib import Path
10 | from random import randint, choice
11 | from omegaconf import OmegaConf
12 |
13 | import torch
14 | from torchvision import transforms as T
15 | from torchvision.datasets import ImageFolder
16 |
17 | from ..utils.general import initialize_from_config
18 |
19 | class ClassImageBase(ImageFolder):
20 | def __init__(self, root: str, split: str,
21 | transform: Callable) -> None:
22 | root = Path(root)/split
23 | super().__init__(root, transform)
24 |
25 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
26 | image, target = super().__getitem__(index)
27 |
28 | return {'image': image, 'class': torch.tensor([target])}
29 |
30 |
31 | class ClassImageTrain(ClassImageBase):
32 | def __init__(self, root: str,
33 | resolution: Union[Tuple[int, int], int] = 256,
34 | resize_ratio: float = 0.75) -> None:
35 | if isinstance(resolution, int):
36 | resolution = [resolution, resolution]
37 |
38 | transform = T.Compose([
39 | T.Resize(resolution),
40 | T.RandomCrop(resolution),
41 | T.RandomHorizontalFlip(),
42 | T.ToTensor()
43 | ])
44 |
45 | super().__init__(root, 'train', transform)
46 |
47 |
48 | class ClassImageValidation(ClassImageBase):
49 | def __init__(self, root: str,
50 | resolution: Union[Tuple[int, int], int] = 256) -> None:
51 | if isinstance(resolution, int):
52 | resolution = [resolution, resolution]
53 |
54 | transform = T.Compose([
55 | T.Resize(resolution),
56 | T.CenterCrop(resolution),
57 | T.ToTensor()
58 | ])
59 |
60 | super().__init__(root, 'val', transform)
61 |
--------------------------------------------------------------------------------
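`ClassImageTrain` builds on torchvision's `ImageFolder`, so it expects a `root/train/<class_name>/*.jpg` layout and returns dict samples. A short usage sketch with a placeholder root path:

```python
from torch.utils.data import DataLoader
from enhancing.dataloader.classimage import ClassImageTrain

# "data/my_dataset" is a placeholder; it must contain train/<class_name>/ image folders.
dataset = ClassImageTrain(root="data/my_dataset", resolution=256)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([4, 3, 256, 256])
print(batch["class"].shape)  # torch.Size([4, 1])
```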
/enhancing/dataloader/coco.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | import json
11 | import albumentations as A
12 | from omegaconf import OmegaConf
13 | from typing import Optional, List, Callable, Union, Tuple
14 | from pathlib import Path
15 |
16 | import numpy as np
17 | from PIL import Image
18 | from torch.utils.data import Dataset
19 |
20 | from ..utils.general import initialize_from_config
21 |
22 |
23 | class COCOBase(Dataset):
24 | def __init__(self, dataroot: str = "", labelroot: str = "", stuffthingroot: str = "", split: str = "",
25 | onehot_segmentation: bool = False, use_stuffthing: bool = False,
26 | tokenizer: Optional[OmegaConf] = None, transform: Optional[Callable] = None) -> None:
27 | assert split in ["train", "val"]
28 | self.split = split
29 |
30 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot
31 | self.stuffthing = use_stuffthing # include thing in segmentation
32 | if self.onehot and not self.stuffthing:
33 | raise NotImplementedError("One hot mode is only supported for the "
34 | "stuffthings version because labels are stored "
35 | "a bit different.")
36 |
37 | data_json = Path(labelroot)/f"captions_{split}2017.json"
38 | with open(data_json) as json_file:
39 | self.json_data = json.load(json_file)
40 | self.img_id_to_captions = dict()
41 | self.img_id_to_filepath = dict()
42 | self.img_id_to_segmentation_filepath = dict()
43 |
44 | if self.stuffthing:
45 | self.segmentation_prefix = Path(stuffthingroot)/f"{split}2017"
46 | else:
47 | self.segmentation_prefix = Path(labelroot)/f"stuff_{split}2017_pixelmaps"
48 |
49 | imagedirs = self.json_data["images"]
50 | self.labels = {"image_ids": list()}
51 | for imgdir in imagedirs:
52 | self.img_id_to_filepath[imgdir["id"]] = Path(dataroot)/f"{split}2017"/imgdir["file_name"]
53 | self.img_id_to_captions[imgdir["id"]] = list()
54 | pngfilename = imgdir["file_name"].replace("jpg", "png")
55 | self.img_id_to_segmentation_filepath[imgdir["id"]] = Path(self.segmentation_prefix)/pngfilename
56 | self.labels["image_ids"].append(imgdir["id"])
57 |
58 | capdirs = self.json_data["annotations"]
59 | for capdir in capdirs:
60 | # there are in average 5 captions per image
61 | self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
62 |
63 | self.transform = transform
64 | self.tokenizer = initialize_from_config(tokenizer)
65 |
66 | def __len__(self):
67 | return len(self.labels["image_ids"])
68 |
69 | def preprocess_image(self, image_path, segmentation_path):
70 | image = Image.open(image_path)
71 | if image.mode != "RGB":
72 | image = image.convert("RGB")
73 | image = np.array(image).astype(np.uint8)
74 |
75 | segmentation = Image.open(segmentation_path)
76 | if not self.onehot and not segmentation.mode == "RGB":
77 | segmentation = segmentation.convert("RGB")
78 | segmentation = np.array(segmentation).astype(np.uint8)
79 | if self.onehot:
80 | assert self.stuffthing
81 | # stored in caffe format: unlabeled==255. stuff and thing from
82 | # 0-181. to be compatible with the labels in
83 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt
84 | # we shift stuffthing one to the right and put unlabeled in zero
85 | # as long as segmentation is uint8 shifting to right handles the
86 | # latter too
87 | assert segmentation.dtype == np.uint8
88 | segmentation = segmentation + 1
89 |
90 | transformed = self.transform(image=image, segmentation=segmentation)
91 | image, segmentation = transformed["image"], transformed["segmentation"]
92 | image = (image / 255).astype(np.float32)
93 | if self.onehot:
94 | assert segmentation.dtype == np.uint8
95 | # make it one hot
96 | n_labels = 183
97 | flatseg = np.ravel(segmentation)
98 | onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
99 | onehot[np.arange(flatseg.size), flatseg] = True
100 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
101 | segmentation = onehot
102 | else:
103 | segmentation = (segmentation / 255).astype(np.float32)
104 |
105 | return image, segmentation
106 |
107 | def __getitem__(self, i):
108 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
109 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
110 | image, segmentation = self.preprocess_image(img_path, seg_path)
111 |
112 | captions = self.img_id_to_captions[self.labels["image_ids"][i]]
113 | caption = captions[np.random.randint(0, len(captions))]
114 | caption = self.tokenizer.tokenize(caption).squeeze(0)
115 |
116 | return {"image": image, "caption": caption, "segmentation": segmentation}
117 |
118 |
119 | class COCOTrain(COCOBase):
120 | def __init__(self, dataroot: str, labelroot: str, stuffthingroot: str, tokenizer: OmegaConf,
121 | resolution: Union[Tuple[int, int], int], onehot_segmentation: bool = False, use_stuffthing: bool = False) -> None:
122 | if isinstance(resolution, int):
123 | resolution = [resolution, resolution]
124 |
125 | transform = A.Compose(
126 | [A.SmallestMaxSize(max_size=min(resolution)),
127 | A.RandomCrop(height=resolution[0], width=resolution[1])],
128 | additional_targets={"segmentation": "image"})
129 |
130 | super().__init__(dataroot, labelroot, stuffthingroot, "train",
131 | onehot_segmentation, use_stuffthing, tokenizer, transform)
132 |
133 |
134 | class COCOValidation(COCOBase):
135 | def __init__(self, dataroot: str, labelroot: str, stuffthingroot: str, tokenizer: OmegaConf,
136 | resolution: Union[Tuple[int, int], int], onehot_segmentation: bool = False, use_stuffthing: bool = False) -> None:
137 | if isinstance(resolution, int):
138 | resolution = [resolution, resolution]
139 |
140 | transform = A.Compose(
141 | [A.SmallestMaxSize(max_size=min(resolution)),
142 | A.CenterCrop(height=resolution[0], width=resolution[1])],
143 | additional_targets={"segmentation": "image"})
144 |
145 | super().__init__(dataroot, labelroot, stuffthingroot, "val",
146 | onehot_segmentation, use_stuffthing, tokenizer, transform)
147 |
--------------------------------------------------------------------------------
/enhancing/dataloader/imagenet.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 |
7 | import PIL
8 | from typing import Any, Tuple, Union, Optional, Callable
9 |
10 | import torch
11 | from torchvision import transforms as T
12 | from torchvision.datasets import ImageNet
13 |
14 |
15 | class ImageNetBase(ImageNet):
16 | def __init__(self, root: str, split: str,
17 | transform: Optional[Callable] = None) -> None:
18 | super().__init__(root=root, split=split, transform=transform)
19 |
20 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
21 | sample, target = super().__getitem__(index)
22 |
23 | return {'image': sample, 'class': torch.tensor([target])}
24 |
25 |
26 | class ImageNetTrain(ImageNetBase):
27 | def __init__(self, root: str,
28 | resolution: Union[Tuple[int, int], int] = 256,
29 | resize_ratio: float = 0.75) -> None:
30 |
31 | transform = T.Compose([
32 | T.Resize(resolution),
33 | T.RandomCrop(resolution),
34 | T.RandomHorizontalFlip(),
35 | T.ToTensor()
36 | ])
37 |
38 | super().__init__(root=root, split='train', transform=transform)
39 |
40 |
41 | class ImageNetValidation(ImageNetBase):
42 | def __init__(self, root: str,
43 | resolution: Union[Tuple[int, int], int] = 256,) -> None:
44 |
45 | if isinstance(resolution, int):
46 | resolution = (resolution, resolution)
47 |
48 | transform = T.Compose([
49 | T.Resize(resolution),
50 | T.CenterCrop(resolution),
51 | T.ToTensor()
52 | ])
53 |
54 | super().__init__(root=root, split='val', transform=transform)
55 |
--------------------------------------------------------------------------------
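`ImageNetTrain` and `ImageNetValidation` wrap `torchvision.datasets.ImageNet`, so the `root` used in the configs (`data/ilsvrc2012`) must already hold the ILSVRC2012 data in the layout torchvision expects. A minimal sketch, assuming that data is in place:

```python
from enhancing.dataloader.imagenet import ImageNetValidation

val_set = ImageNetValidation(root="data/ilsvrc2012", resolution=256)
sample = val_set[0]
print(sample["image"].shape)  # torch.Size([3, 256, 256])
print(sample["class"])        # tensor([<label index>])
```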
/enhancing/dataloader/inatural.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from Torchvision (https://github.com/pytorch/vision)
7 | # Copyright (c) 2016 Soumith Chintala. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | import os
11 | import PIL
12 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13 | from pathlib import Path
14 |
15 |
16 | import torch
17 | from torchvision import transforms as T
18 | from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg
19 | from torchvision.datasets.vision import VisionDataset
20 |
21 |
22 | CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]
23 |
24 | DATASET_URLS = {
25 | "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
26 | "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
27 | "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
28 | "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
29 | "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
30 | "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
31 | }
32 |
33 | DATASET_MD5 = {
34 | "2017": "7c784ea5e424efaec655bd392f87301f",
35 | "2018": "b1c6952ce38f31868cc50ea72d066cc3",
36 | "2019": "c60a6e2962c9b8ccbd458d12c8582644",
37 | "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed",
38 | "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
39 | "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
40 | }
41 |
42 |
43 | class INaturalistBase(VisionDataset):
44 | """iNaturalist Dataset.
45 |
46 | Args:
47 | root (string): Root directory of dataset where the image files are stored.
48 | This class does not require/use annotation files.
49 | version (string, optional): Which version of the dataset to download/use. One of
50 | '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
51 | Default: `2021_train`.
52 | target_type (string or list, optional): Type of target to use, for 2021 versions, one of:
53 |
54 | - ``full``: the full category (species)
55 | - ``kingdom``: e.g. "Animalia"
56 | - ``phylum``: e.g. "Arthropoda"
57 | - ``class``: e.g. "Insecta"
58 | - ``order``: e.g. "Coleoptera"
59 | - ``family``: e.g. "Cleridae"
60 | - ``genus``: e.g. "Trichodes"
61 |
62 | for 2017-2019 versions, one of:
63 |
64 | - ``full``: the full (numeric) category
65 | - ``super``: the super category, e.g. "Amphibians"
66 |
67 | Can also be a list to output a tuple with all specified target types.
68 | Defaults to ``full``.
69 | transform (callable, optional): A function/transform that takes in an PIL image
70 | and returns a transformed version. E.g, ``transforms.RandomCrop``
71 | target_transform (callable, optional): A function/transform that takes in the
72 | target and transforms it.
73 | download (bool, optional): If true, downloads the dataset from the internet and
74 | puts it in root directory. If dataset is already downloaded, it is not
75 | downloaded again.
76 | """
77 |
78 | def __init__(
79 | self,
80 | root: str,
81 | version: str = "2021_train",
82 | target_type: Union[List[str], str] = "full",
83 | transform: Optional[Callable] = None,
84 | target_transform: Optional[Callable] = None,
85 | ) -> None:
86 | self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
87 |
88 | super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
89 |
90 | os.makedirs(root, exist_ok=True)
91 | path_exist = os.path.isdir(os.path.join(root,version))
92 | if not path_exist:
93 | self.download()
94 |
95 | if not self._check_integrity():
96 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
97 |
98 | self.all_categories: List[str] = []
99 |
100 | # map: category type -> name of category -> index
101 | self.categories_index: Dict[str, Dict[str, int]] = {}
102 |
103 | # list indexed by category id, containing mapping from category type -> index
104 | self.categories_map: List[Dict[str, int]] = []
105 |
106 | if not isinstance(target_type, list):
107 | target_type = [target_type]
108 | if self.version[:4] == "2021":
109 | self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
110 | self._init_2021()
111 | else:
112 | self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
113 | self._init_pre2021()
114 |
115 | # index of all files: (full category id, filename)
116 | self.index: List[Tuple[int, str]] = []
117 |
118 | for dir_index, dir_name in enumerate(self.all_categories):
119 | files = os.listdir(os.path.join(self.root, dir_name))
120 | for fname in files:
121 | self.index.append((dir_index, fname))
122 |
123 | def _init_2021(self) -> None:
124 | """Initialize based on 2021 layout"""
125 |
126 | self.all_categories = sorted(os.listdir(self.root))
127 |
128 | # map: category type -> name of category -> index
129 | self.categories_index = {k: {} for k in CATEGORIES_2021}
130 |
131 | for dir_index, dir_name in enumerate(self.all_categories):
132 | pieces = dir_name.split("_")
133 | if len(pieces) != 8:
134 | raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
135 | if pieces[0] != f"{dir_index:05d}":
136 | raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
137 | cat_map = {}
138 | for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
139 | if name in self.categories_index[cat]:
140 | cat_id = self.categories_index[cat][name]
141 | else:
142 | cat_id = len(self.categories_index[cat])
143 | self.categories_index[cat][name] = cat_id
144 | cat_map[cat] = cat_id
145 | self.categories_map.append(cat_map)
146 |
147 | def _init_pre2021(self) -> None:
148 | """Initialize based on 2017-2019 layout"""
149 |
150 | # map: category type -> name of category -> index
151 | self.categories_index = {"super": {}}
152 |
153 | cat_index = 0
154 | super_categories = sorted(os.listdir(self.root))
155 | for sindex, scat in enumerate(super_categories):
156 | self.categories_index["super"][scat] = sindex
157 | subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
158 | for subcat in subcategories:
159 | if self.version == "2017":
160 | # this version does not use ids as directory names
161 | subcat_i = cat_index
162 | cat_index += 1
163 | else:
164 | try:
165 | subcat_i = int(subcat)
166 | except ValueError:
167 | raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
168 | if subcat_i >= len(self.categories_map):
169 | old_len = len(self.categories_map)
170 | self.categories_map.extend([{}] * (subcat_i - old_len + 1))
171 | self.all_categories.extend([""] * (subcat_i - old_len + 1))
172 | if self.categories_map[subcat_i]:
173 | raise RuntimeError(f"Duplicate category {subcat}")
174 | self.categories_map[subcat_i] = {"super": sindex}
175 | self.all_categories[subcat_i] = os.path.join(scat, subcat)
176 |
177 | # validate the dictionary
178 | for cindex, c in enumerate(self.categories_map):
179 | if not c:
180 | raise RuntimeError(f"Missing category {cindex}")
181 |
182 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
183 | """
184 | Args:
185 | index (int): Index
186 |
187 | Returns:
188 | tuple: (image, target) where the type of target specified by target_type.
189 | """
190 |
191 | cat_id, fname = self.index[index]
192 | img = PIL.Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
193 |
194 | target: Any = []
195 | for t in self.target_type:
196 | if t == "full":
197 | target.append(cat_id)
198 | else:
199 | target.append(self.categories_map[cat_id][t])
200 | target = tuple(target) if len(target) > 1 else target[0]
201 |
202 | if self.transform is not None:
203 | img = self.transform(img)
204 |
205 | if self.target_transform is not None:
206 | target = self.target_transform(target)
207 |
208 | return {'image': img, 'class': torch.tensor([target])}
209 |
210 |
211 | def __len__(self) -> int:
212 | return len(self.index)
213 |
214 | def category_name(self, category_type: str, category_id: int) -> str:
215 | """
216 | Args:
217 | category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
218 | category_id(int): an index (class id) from this category
219 |
220 | Returns:
221 | the name of the category
222 | """
223 | if category_type == "full":
224 | return self.all_categories[category_id]
225 | else:
226 | if category_type not in self.categories_index:
227 | raise ValueError(f"Invalid category type '{category_type}'")
228 | else:
229 | for name, id in self.categories_index[category_type].items():
230 | if id == category_id:
231 | return name
232 | raise ValueError(f"Invalid category id {category_id} for {category_type}")
233 |
234 |
235 | def _check_integrity(self) -> bool:
236 | return os.path.exists(self.root) and len(os.listdir(self.root)) > 0
237 |
238 | def download(self) -> None:
239 | if self._check_integrity():
240 | raise RuntimeError(
241 | f"The directory {self.root} already exists. "
242 | f"If you want to re-download or re-extract the images, delete the directory."
243 | )
244 |
245 | base_root = os.path.dirname(self.root)
246 |
247 | download_and_extract_archive(
248 | DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
249 | )
250 |
251 | orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
252 | if not os.path.exists(orig_dir_name):
253 | raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
254 | os.rename(orig_dir_name, self.root)
255 | print(f"Dataset version '{self.version}' has been downloaded and prepared for use")
256 |
257 |
258 | class INaturalistTrain(INaturalistBase):
259 | def __init__(self, root: str, resolution: Union[Tuple[int, int], int] = 256) -> None:
260 | transform = T.Compose([
261 | T.Resize(resolution),
262 | T.RandomCrop(resolution),
263 | T.RandomHorizontalFlip(),
264 | T.ToTensor()
265 | ])
266 |
267 | super().__init__(root=root, version='2021_train', transform=transform)
268 |
269 | class INaturalistValidation(INaturalistBase):
270 | def __init__(self, root: str, resolution: Union[Tuple[int, int], int] = 256) -> None:
271 | transform = T.Compose([
272 | T.Resize(resolution),
273 | T.CenterCrop(resolution),
274 | T.ToTensor()
275 | ])
276 |
277 | super().__init__(root=root, version='2021_valid', transform=transform)
278 |
--------------------------------------------------------------------------------
/enhancing/dataloader/lsun.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 |
7 | import PIL
8 | from typing import Any, Tuple, Union, List, Optional, Callable
9 | import subprocess
10 | from os.path import join, dirname, abspath, isfile, isdir
11 |
12 | import torch
13 | from torchvision import transforms as T
14 | from torchvision.datasets import LSUN
15 |
16 |
17 | class LSUNBase(LSUN):
18 |     def __init__(self, root: str, classes: Union[str, List[str]],
19 | transform: Optional[Callable] = None) -> None:
20 | super().__init__(root, classes, transform)
21 |
22 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
23 | image, target = super().__getitem__(index)
24 |
25 | return {'image': image, 'class': torch.tensor([target])}
26 |
27 |
28 | class LSUNTrain(LSUNBase):
29 |     def __init__(self, root: str, classes: Union[str, List[str]],
30 | resolution: Union[Tuple[int, int], int] = 256) -> None:
31 | transform = T.Compose([
32 | T.Resize(resolution),
33 | T.RandomCrop(resolution),
34 | T.RandomHorizontalFlip(),
35 | T.ToTensor()
36 | ])
37 |
38 | if classes not in ['train', 'val']:
39 | if not isinstance(classes, list):
40 | classes = [classes]
41 |
42 | classes = [class_+"_train" for class_ in classes]
43 | else:
44 | assert classes == 'train'
45 |
46 | super().__init__(root, classes, transform)
47 |
48 |
49 | class LSUNValidation(LSUNBase):
50 |     def __init__(self, root: str, classes: Union[str, List[str]],
51 | resolution: Union[Tuple[int, int], int] = 256) -> None:
52 | transform = T.Compose([
53 | T.Resize(resolution),
54 | T.CenterCrop(resolution),
55 | T.ToTensor()
56 | ])
57 |
58 | if classes not in ['train', 'val']:
59 | if not isinstance(classes, list):
60 | classes = [classes]
61 |
62 | classes = [class_+"_val" for class_ in classes]
63 | else:
64 | assert classes == 'val'
65 |
66 | super().__init__(root, classes, transform)
67 |
--------------------------------------------------------------------------------
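
Usage sketch (illustrative only, not a file in the repository): the LSUN wrappers above expect the per-category LMDB folders produced by the official LSUN download scripts; a category name such as 'church_outdoor' is suffixed with '_train' or '_val' internally. The root path is a placeholder.

from enhancing.dataloader.lsun import LSUNTrain, LSUNValidation

train_set = LSUNTrain(root='data/lsun', classes='church_outdoor', resolution=256)
val_set = LSUNValidation(root='data/lsun', classes='church_outdoor', resolution=256)

sample = train_set[0]
# 'image' is a [3, 256, 256] tensor, 'class' a one-element tensor with the category index.
print(sample['image'].shape, sample['class'])
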
/enhancing/dataloader/srimage.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from DALLE-pytorch (https://github.com/lucidrains/DALLE-pytorch)
7 | # Copyright (c) 2020 Phil Wang. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | from typing import Optional, Tuple, Callable, Union
11 | from pathlib import Path
12 | from random import randint, choice
13 | from omegaconf import OmegaConf
14 | import PIL
15 |
16 | from torch import nn
17 | from torch.utils.data import Dataset
18 | from torchvision import transforms as T
19 |
20 |
21 | class SRBase(Dataset):
22 | def __init__(self, folder: str, split: str, transform: Callable) -> None:
23 | super().__init__()
24 | path = Path(folder)/split
25 |
26 | image_files = [
27 | *path.glob('**/*.png'), *path.glob('**/*.jpg'),
28 | *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
29 | ]
30 |
31 | image_files = {image_file.stem: image_file for image_file in image_files}
32 | keys = image_files.keys()
33 |
34 | self.keys = list(keys)
35 | self.image_files = {k: v for k, v in image_files.items() if k in keys}
36 |
37 | self.hr_transform = transform
38 |
39 | def __len__(self):
40 | return len(self.keys)
41 |
42 | def random_sample(self):
43 | return self.__getitem__(randint(0, self.__len__() - 1))
44 |
45 | def sequential_sample(self, ind):
46 | if ind >= self.__len__() - 1:
47 | return self.__getitem__(0)
48 | return self.__getitem__(ind + 1)
49 |
50 | def skip_sample(self, ind):
51 | return self.sequential_sample(ind=ind)
52 |
53 | def pad(self, img: PIL.Image.Image) -> PIL.Image.Image:
54 | if isinstance(self.resolution, int):
55 | self.resolution = (self.resolution, self.resolution)
56 |
57 | assert img.size[0] <= self.resolution[1] and img.size[1] <= self.resolution[0]
58 | left = (self.resolution[1] - img.size[0]) // 2
59 | top = (self.resolution[0] - img.size[1]) // 2
60 | right = self.resolution[1] - img.size[0] - left
61 | bottom = self.resolution[0] - img.size[1] - top
62 |
63 | return T.functional.pad(img, (left, top, right, bottom))
64 |
65 | def __getitem__(self, ind):
66 | key = self.keys[ind]
67 | image_file = self.image_files[key]
68 |
69 | try:
70 | hr_img = PIL.Image.open(image_file)
71 | if hr_img.mode != 'RGB':
72 | hr_img = hr_img.convert('RGB')
73 |
74 | hr_tensor = self.hr_transform(hr_img)
75 |
76 | down_size = (hr_tensor.shape[1]//self.downscale, hr_tensor.shape[2]//self.downscale)
77 |             lr_tensor = T.Resize(down_size, interpolation=T.InterpolationMode.BICUBIC)(hr_tensor)
78 | except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions:
79 | print(f"An exception occurred trying to load file {image_file}.")
80 | print(f"Skipping index {ind}")
81 | return self.skip_sample(ind)
82 |
83 | # Success
84 | return {'low resolution': lr_tensor, 'high resolution': hr_tensor}
85 |
86 |
87 | class SRTrain(SRBase):
88 | def __init__(self, folder: str,
89 | resolution: Union[Tuple[int, int], int] = 2048,
90 | crop_resolution: Union[Tuple[int, int], int] = 512,
91 | downscale: int = 4) -> None:
92 | assert resolution % downscale == 0
93 | self.resolution = resolution
94 | self.downscale = downscale
95 |
96 | transform = T.Compose([
97 | T.RandomCrop(crop_resolution),
98 | T.Lambda(self.pad),
99 | T.RandomHorizontalFlip(),
100 | T.RandomVerticalFlip(),
101 | T.ToTensor()
102 | ])
103 |
104 | super().__init__(folder, 'train', transform)
105 |
106 |
107 | class SRValidation(SRBase):
108 | def __init__(self, folder: str,
109 | resolution: Union[Tuple[int, int], int] = 2048,
110 | downscale: int = 4) -> None:
111 | assert resolution % downscale == 0
112 | self.resolution = resolution
113 | self.downscale = downscale
114 |
115 | transform = T.Compose([
116 | T.Lambda(self.pad),
117 | T.ToTensor()
118 | ])
119 |
120 | super().__init__(folder, 'val', transform)
121 |
122 |
--------------------------------------------------------------------------------
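
Usage sketch (illustrative only, not a file in the repository): the super-resolution datasets above return a paired low-/high-resolution sample. The folder is a placeholder and must contain 'train' and 'val' subdirectories whose images are at least crop_resolution pixels on each side.

from enhancing.dataloader.srimage import SRTrain

dataset = SRTrain(folder='data/div2k', resolution=2048, crop_resolution=512, downscale=4)
pair = dataset[0]
print(pair['high resolution'].shape)  # [3, 2048, 2048]: the 512x512 crop padded up to `resolution`
print(pair['low resolution'].shape)   # [3, 512, 512]: the padded image downscaled by `downscale`
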
/enhancing/dataloader/textimage.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from DALLE-pytorch (https://github.com/lucidrains/DALLE-pytorch)
7 | # Copyright (c) 2020 Phil Wang. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | from typing import Optional, Union, Callable, Tuple, Any
11 | from pathlib import Path
12 | from random import randint, choice
13 | from omegaconf import OmegaConf
14 | import PIL
15 |
16 | import torch
17 | from torch.utils.data import Dataset
18 | from torchvision import transforms as T
19 |
20 | from ..utils.general import initialize_from_config
21 |
22 |
23 | class TextImageBase(Dataset):
24 | def __init__(self, folder: str, split: str,
25 | tokenizer: OmegaConf,
26 | transform: Callable) -> None:
27 | super().__init__()
28 | path = Path(folder)/split
29 |
30 | text_files = [*path.glob('**/*.txt')]
31 | image_files = [
32 | *path.glob('**/*.png'), *path.glob('**/*.jpg'),
33 | *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
34 | ]
35 |
36 | text_files = {text_file.stem: text_file for text_file in text_files}
37 | image_files = {image_file.stem: image_file for image_file in image_files}
38 |
39 | keys = (image_files.keys() & text_files.keys())
40 |
41 | self.keys = list(keys)
42 | self.text_files = {k: v for k, v in text_files.items() if k in keys}
43 | self.image_files = {k: v for k, v in image_files.items() if k in keys}
44 | self.tokenizer = initialize_from_config(tokenizer)
45 | self.image_transform = transform
46 |
47 | def __len__(self) -> int:
48 | return len(self.keys)
49 |
50 | def random_sample(self) -> Tuple[Any, Any]:
51 | return self.__getitem__(randint(0, self.__len__() - 1))
52 |
53 | def sequential_sample(self, ind: int) -> Tuple[Any, Any]:
54 | if ind >= self.__len__() - 1:
55 | return self.__getitem__(0)
56 | return self.__getitem__(ind + 1)
57 |
58 | def skip_sample(self, ind: int) -> Tuple[Any, Any]:
59 | return self.sequential_sample(ind=ind)
60 |
61 | def __getitem__(self, ind: int) -> Tuple[Any, Any]:
62 | key = self.keys[ind]
63 |
64 | text_file = self.text_files[key]
65 | image_file = self.image_files[key]
66 |
67 | descriptions = text_file.read_text().split('\n')
68 | descriptions = list(filter(lambda t: len(t) > 0, descriptions))
69 |
70 | try:
71 | description = choice(descriptions)
72 | except IndexError as zero_captions_in_file_ex:
73 | print(f"An exception occurred trying to load file {text_file}.")
74 | print(f"Skipping index {ind}")
75 | return self.skip_sample(ind)
76 |
77 | tokenized_text = self.tokenizer.tokenize(description).squeeze(0)
78 | try:
79 | image = PIL.Image.open(image_file)
80 | if image.mode != 'RGB':
81 | image = image.convert('RGB')
82 | image_tensor = self.image_transform(image)
83 | except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions:
84 | print(f"An exception occurred trying to load file {image_file}.")
85 | print(f"Skipping index {ind}")
86 | return self.skip_sample(ind)
87 |
88 | # Success
89 | return {"caption": tokenized_text, "image": image_tensor}
90 |
91 |
92 | class TextImageTrain(TextImageBase):
93 | def __init__(self, folder: str,
94 | tokenizer: OmegaConf,
95 | resolution: Union[Tuple[int, int], int] = 256) -> None:
96 | transform = T.Compose([
97 | T.Resize(resolution),
98 | T.RandomCrop(resolution),
99 | T.ToTensor(),
100 | ])
101 |
102 | super().__init__(folder, 'train', tokenizer, transform)
103 |
104 |
105 | class TextImageValidation(TextImageBase):
106 | def __init__(self, folder: str,
107 | tokenizer: OmegaConf,
108 | resolution: Union[Tuple[int, int], int] = 256) -> None:
109 | if isinstance(resolution, int):
110 | resolution = [resolution, resolution]
111 |
112 | transform = T.Compose([
113 | T.Resize(resolution),
114 | T.CenterCrop(resolution),
115 | T.ToTensor(),
116 | ])
117 |
118 | super().__init__(folder, 'val', tokenizer, transform)
119 |
--------------------------------------------------------------------------------
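
Usage sketch (illustrative only, not a file in the repository): the text-image datasets above receive an OmegaConf node describing the tokenizer, which is instantiated through initialize_from_config. The 'target' below is a placeholder -- point it at the tokenizer class actually defined in enhancing/utils/tokenizer.py -- and the folder is a placeholder directory of paired *.txt and image files.

from omegaconf import OmegaConf
from enhancing.dataloader.textimage import TextImageTrain

tokenizer_cfg = OmegaConf.create({'target': 'enhancing.utils.tokenizer.Tokenizer'})  # placeholder target
dataset = TextImageTrain(folder='data/cc3m', tokenizer=tokenizer_cfg, resolution=256)

sample = dataset[0]
# 'caption' holds the tokenized description, 'image' a [3, 256, 256] tensor.
print(sample['caption'].shape, sample['image'].shape)
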
/enhancing/losses/layers.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 | # Modified from StyleGAN2-Pytorch (https://github.com/rosinality/stylegan2-pytorch)
6 | # Copyright (c) 2019 Kim Seonghyeon. All Rights Reserved.
7 | # ------------------------------------------------------------------------------------
8 |
9 |
10 | from math import log2, sqrt
11 | from functools import partial
12 | from typing import Optional, Union, Tuple, List
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from kornia.filters import filter2d
18 |
19 | from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
20 |
21 |
22 | def hinge_d_loss(logits_fake: torch.FloatTensor, logits_real: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
23 | loss_fake = - logits_fake.mean() * 2 if logits_real is None else F.relu(1. + logits_fake).mean()
24 | loss_real = 0 if logits_real is None else F.relu(1. - logits_real).mean()
25 |
26 | return 0.5 * (loss_real + loss_fake)
27 |
28 |
29 | def vanilla_d_loss(logits_fake: torch.FloatTensor, logits_real: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
30 | loss_fake = F.softplus(-logits_fake).mean() * 2 if logits_real is None else F.softplus(logits_fake).mean()
31 | loss_real = 0 if logits_real is None else F.softplus(-logits_real).mean()
32 |
33 | return 0.5 * (loss_real + loss_fake)
34 |
35 |
36 | def least_square_d_loss(logits_fake: torch.FloatTensor, logits_real: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
37 | loss_fake = logits_fake.pow(2).mean() * 2 if logits_real is None else (1 + logits_fake).pow(2).mean()
38 | loss_real = 0 if logits_real is None else (1 - logits_real).pow(2).mean()
39 |
40 | return 0.5 * (loss_real + loss_fake)
41 |
42 |
43 | def weights_init(m: nn.Module) -> None:
44 | classname = m.__class__.__name__
45 | if classname.find('Conv') != -1:
46 | nn.init.normal_(m.weight.data, 0.0, 0.02)
47 | elif classname.find('BatchNorm') != -1:
48 | nn.init.normal_(m.weight.data, 1.0, 0.02)
49 | nn.init.constant_(m.bias.data, 0)
50 |
51 |
52 | class ActNorm(nn.Module):
53 | def __init__(self, num_features: int,
54 | logdet: Optional[bool] = False,
55 | affine: Optional[bool] = True,
56 | allow_reverse_init: Optional[bool] = False) -> None:
57 | assert affine
58 | super().__init__()
59 | self.logdet = logdet
60 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
61 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
62 | self.allow_reverse_init = allow_reverse_init
63 |
64 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
65 |
66 | def initialize(self, input: torch.FloatTensor) -> None:
67 | with torch.no_grad():
68 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
69 | mean = (
70 | flatten.mean(1)
71 | .unsqueeze(1)
72 | .unsqueeze(2)
73 | .unsqueeze(3)
74 | .permute(1, 0, 2, 3)
75 | )
76 | std = (
77 | flatten.std(1)
78 | .unsqueeze(1)
79 | .unsqueeze(2)
80 | .unsqueeze(3)
81 | .permute(1, 0, 2, 3)
82 | )
83 |
84 | self.loc.data.copy_(-mean)
85 | self.scale.data.copy_(1 / (std + 1e-6))
86 |
87 | def forward(self, input: torch.FloatTensor, reverse: Optional[bool] = False) -> Union[torch.FloatTensor, Tuple]:
88 | if reverse:
89 | return self.reverse(input)
90 | if len(input.shape) == 2:
91 | input = input[:,:,None,None]
92 | squeeze = True
93 | else:
94 | squeeze = False
95 |
96 | _, _, height, width = input.shape
97 |
98 | if self.training and self.initialized.item() == 0:
99 | self.initialize(input)
100 | self.initialized.fill_(1)
101 |
102 | h = self.scale * (input + self.loc)
103 |
104 | if squeeze:
105 | h = h.squeeze(-1).squeeze(-1)
106 |
107 | if self.logdet:
108 | log_abs = torch.log(torch.abs(self.scale))
109 | logdet = height*width*torch.sum(log_abs)
110 | logdet = logdet * torch.ones(input.shape[0]).to(input)
111 | return h, logdet
112 |
113 | return h
114 |
115 | def reverse(self, output: torch.FloatTensor) -> torch.FloatTensor:
116 | if self.training and self.initialized.item() == 0:
117 | if not self.allow_reverse_init:
118 | raise RuntimeError(
119 | "Initializing ActNorm in reverse direction is "
120 | "disabled by default. Use allow_reverse_init=True to enable."
121 | )
122 | else:
123 | self.initialize(output)
124 | self.initialized.fill_(1)
125 |
126 | if len(output.shape) == 2:
127 | output = output[:,:,None,None]
128 | squeeze = True
129 | else:
130 | squeeze = False
131 |
132 | h = output / self.scale - self.loc
133 |
134 | if squeeze:
135 | h = h.squeeze(-1).squeeze(-1)
136 |
137 | return h
138 |
139 |
140 | class Blur(nn.Module):
141 | def __init__(self, kernel, pad, upsample_factor=1):
142 | super().__init__()
143 |
144 | kernel = torch.tensor(kernel, dtype=torch.float32)
145 | if kernel.ndim == 1:
146 | kernel = kernel[None, :] * kernel[:, None]
147 |
148 | kernel /= kernel.sum()
149 |
150 | if upsample_factor > 1:
151 | kernel = kernel * (upsample_factor ** 2)
152 |
153 | self.register_buffer("kernel", kernel)
154 |
155 | self.pad = pad
156 |
157 | def forward(self, input):
158 | out = upfirdn2d(input, self.kernel, pad=self.pad)
159 |
160 | return out
161 |
162 |
163 | class EqualConv2d(nn.Module):
164 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
165 | super().__init__()
166 |
167 | self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
168 | self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None
169 |
170 | self.scale = 1 / sqrt(in_channel * kernel_size ** 2)
171 |
172 | self.stride = stride
173 | self.padding = padding
174 |
175 | def forward(self, input):
176 | out = conv2d_gradfix.conv2d(
177 | input,
178 | self.weight * self.scale,
179 | bias=self.bias,
180 | stride=self.stride,
181 | padding=self.padding,
182 | )
183 |
184 | return out
185 |
186 |
187 | class EqualLinear(nn.Module):
188 | def __init__(
189 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
190 | ):
191 | super().__init__()
192 |
193 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
194 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) if bias else None
195 |
196 | self.activation = activation
197 |
198 | self.scale = (1 / sqrt(in_dim)) * lr_mul
199 | self.lr_mul = lr_mul
200 |
201 | def forward(self, input):
202 | if self.activation:
203 | out = F.linear(input, self.weight * self.scale)
204 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
205 |
206 | else:
207 | out = F.linear(
208 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
209 | )
210 |
211 | return out
212 |
213 |
214 | class ConvLayer(nn.Sequential):
215 | def __init__(self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True):
216 | layers = []
217 |
218 | if downsample:
219 | factor = 2
220 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
221 | pad0 = (p + 1) // 2
222 | pad1 = p // 2
223 |
224 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
225 |
226 | stride = 2
227 | self.padding = 0
228 | else:
229 | stride = 1
230 | self.padding = kernel_size // 2
231 |
232 | layers.append(
233 | EqualConv2d(
234 | in_channel, out_channel,
235 | kernel_size, padding=self.padding,
236 | stride=stride, bias=bias and not activate
237 | )
238 | )
239 |
240 | if activate:
241 | layers.append(FusedLeakyReLU(out_channel, bias=bias))
242 |
243 | super().__init__(*layers)
244 |
245 |
246 | class StyleBlock(nn.Module):
247 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
248 | super().__init__()
249 |
250 | self.conv1 = ConvLayer(in_channel, in_channel, 3)
251 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
252 |
253 | self.skip = ConvLayer(
254 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False
255 | )
256 |
257 | def forward(self, input):
258 | out = self.conv1(input)
259 | out = self.conv2(out)
260 |
261 | skip = self.skip(input)
262 | out = (out + skip) / sqrt(2)
263 |
264 | return out
265 |
266 |
267 | class PatchDiscriminator(nn.Module):
268 | """Defines a PatchGAN discriminator as in Pix2Pix
269 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
270 | """
271 | def __init__(self, input_nc: int = 3, ndf: int = 64, n_layers: int = 3, use_actnorm: bool = False) -> None:
272 | """Construct a PatchGAN discriminator
273 | Parameters:
274 | input_nc (int) -- the number of channels in input images
275 | ndf (int) -- the number of filters in the last conv layer
276 | n_layers (int) -- the number of conv layers in the discriminator
277 | norm_layer -- normalization layer
278 | """
279 | super().__init__()
280 | if not use_actnorm:
281 | norm_layer = nn.BatchNorm2d
282 | else:
283 | norm_layer = ActNorm
284 | if type(norm_layer) == partial: # no need to use bias as BatchNorm2d has affine parameters
285 | use_bias = norm_layer.func != nn.BatchNorm2d
286 | else:
287 | use_bias = norm_layer != nn.BatchNorm2d
288 |
289 | kw = 4
290 | padw = 1
291 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
292 | nf_mult = 1
293 | nf_mult_prev = 1
294 | for n in range(1, n_layers): # gradually increase the number of filters
295 | nf_mult_prev = nf_mult
296 | nf_mult = min(2 ** n, 8)
297 | sequence += [
298 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
299 | norm_layer(ndf * nf_mult),
300 | nn.LeakyReLU(0.2, True)
301 | ]
302 |
303 | nf_mult_prev = nf_mult
304 | nf_mult = min(2 ** n_layers, 8)
305 | sequence += [
306 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
307 | norm_layer(ndf * nf_mult),
308 | nn.LeakyReLU(0.2, True)
309 | ]
310 |
311 | sequence += [
312 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
313 | self.main = nn.Sequential(*sequence)
314 |
315 | self.apply(weights_init)
316 |
317 | def forward(self, input: torch.FloatTensor) -> torch.FloatTensor:
318 | """Standard forward."""
319 | return self.main(input)
320 |
321 |
322 | class StyleDiscriminator(nn.Module):
323 | def __init__(self, size=256, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
324 | super().__init__()
325 |
326 | channels = {
327 | 4: 512,
328 | 8: 512,
329 | 16: 512,
330 | 32: 512,
331 | 64: 256 * channel_multiplier,
332 | 128: 128 * channel_multiplier,
333 | 256: 64 * channel_multiplier,
334 | 512: 32 * channel_multiplier,
335 | 1024: 16 * channel_multiplier,
336 | }
337 |
338 | log_size = int(log2(size))
339 | in_channel = channels[size]
340 |
341 | blocks = [ConvLayer(3, channels[size], 1)]
342 | for i in range(log_size, 2, -1):
343 | out_channel = channels[2 ** (i - 1)]
344 | blocks.append(StyleBlock(in_channel, out_channel, blur_kernel))
345 | in_channel = out_channel
346 |
347 | self.blocks = nn.Sequential(*blocks)
348 |
349 | self.stddev_group = 4
350 | self.stddev_feat = 1
351 |
352 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
353 | self.final_linear = nn.Sequential(
354 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
355 | EqualLinear(channels[4], 1),
356 | )
357 |
358 | def forward(self, x):
359 | out = self.blocks(x)
360 | batch, channel, height, width = out.shape
361 |
362 | group = min(batch, self.stddev_group)
363 | group = batch//(batch//group)
364 |
365 | stddev = out.view(
366 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
367 | )
368 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
369 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
370 | stddev = stddev.repeat(group, 1, height, width)
371 | out = torch.cat([out, stddev], 1)
372 |
373 | out = self.final_conv(out)
374 | out = out.view(out.shape[0], -1)
375 | out = self.final_linear(out)
376 |
377 | return out.squeeze()
378 |
--------------------------------------------------------------------------------
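
Usage sketch (illustrative only, not a file in the repository): the *_d_loss helpers above double as generator and discriminator objectives -- passing only logits_fake yields the generator term, passing both yields the discriminator term. Note that importing enhancing.losses.layers JIT-compiles the fused CUDA ops in enhancing/losses/op, so a CUDA toolchain is assumed.

import torch
from enhancing.losses.layers import PatchDiscriminator, hinge_d_loss

disc = PatchDiscriminator(input_nc=3, ndf=64, n_layers=3)
real = torch.rand(2, 3, 256, 256)
fake = torch.rand(2, 3, 256, 256)

g_loss = hinge_d_loss(disc(fake))                        # generator loss: -E[D(fake)]
d_loss = hinge_d_loss(disc(fake.detach()), disc(real))   # discriminator loss: hinge on both terms
print(g_loss.item(), d_loss.item())
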
/enhancing/losses/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/enhancing/losses/op/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import warnings
3 |
4 | import torch
5 | from torch import autograd
6 | from torch.nn import functional as F
7 |
8 | enabled = True
9 | weight_gradients_disabled = False
10 |
11 |
12 | @contextlib.contextmanager
13 | def no_weight_gradients():
14 | global weight_gradients_disabled
15 |
16 | old = weight_gradients_disabled
17 | weight_gradients_disabled = True
18 | yield
19 | weight_gradients_disabled = old
20 |
21 |
22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23 | if could_use_op(input):
24 | return conv2d_gradfix(
25 | transpose=False,
26 | weight_shape=weight.shape,
27 | stride=stride,
28 | padding=padding,
29 | output_padding=0,
30 | dilation=dilation,
31 | groups=groups,
32 | ).apply(input, weight, bias)
33 |
34 | return F.conv2d(
35 | input=input,
36 | weight=weight,
37 | bias=bias,
38 | stride=stride,
39 | padding=padding,
40 | dilation=dilation,
41 | groups=groups,
42 | )
43 |
44 |
45 | def conv_transpose2d(
46 | input,
47 | weight,
48 | bias=None,
49 | stride=1,
50 | padding=0,
51 | output_padding=0,
52 | groups=1,
53 | dilation=1,
54 | ):
55 | if could_use_op(input):
56 | return conv2d_gradfix(
57 | transpose=True,
58 | weight_shape=weight.shape,
59 | stride=stride,
60 | padding=padding,
61 | output_padding=output_padding,
62 | groups=groups,
63 | dilation=dilation,
64 | ).apply(input, weight, bias)
65 |
66 | return F.conv_transpose2d(
67 | input=input,
68 | weight=weight,
69 | bias=bias,
70 | stride=stride,
71 | padding=padding,
72 | output_padding=output_padding,
73 | dilation=dilation,
74 | groups=groups,
75 | )
76 |
77 |
78 | def could_use_op(input):
79 | if (not enabled) or (not torch.backends.cudnn.enabled):
80 | return False
81 |
82 | if input.device.type != "cuda":
83 | return False
84 |
85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86 | return True
87 |
88 | warnings.warn(
89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90 | )
91 |
92 | return False
93 |
94 |
95 | def ensure_tuple(xs, ndim):
96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97 |
98 | return xs
99 |
100 |
101 | conv2d_gradfix_cache = dict()
102 |
103 |
104 | def conv2d_gradfix(
105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups
106 | ):
107 | ndim = 2
108 | weight_shape = tuple(weight_shape)
109 | stride = ensure_tuple(stride, ndim)
110 | padding = ensure_tuple(padding, ndim)
111 | output_padding = ensure_tuple(output_padding, ndim)
112 | dilation = ensure_tuple(dilation, ndim)
113 |
114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115 | if key in conv2d_gradfix_cache:
116 | return conv2d_gradfix_cache[key]
117 |
118 | common_kwargs = dict(
119 | stride=stride, padding=padding, dilation=dilation, groups=groups
120 | )
121 |
122 | def calc_output_padding(input_shape, output_shape):
123 | if transpose:
124 | return [0, 0]
125 |
126 | return [
127 | input_shape[i + 2]
128 | - (output_shape[i + 2] - 1) * stride[i]
129 | - (1 - 2 * padding[i])
130 | - dilation[i] * (weight_shape[i + 2] - 1)
131 | for i in range(ndim)
132 | ]
133 |
134 | class Conv2d(autograd.Function):
135 | @staticmethod
136 | def forward(ctx, input, weight, bias):
137 | if not transpose:
138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139 |
140 | else:
141 | out = F.conv_transpose2d(
142 | input=input,
143 | weight=weight,
144 | bias=bias,
145 | output_padding=output_padding,
146 | **common_kwargs,
147 | )
148 |
149 | ctx.save_for_backward(input, weight)
150 |
151 | return out
152 |
153 | @staticmethod
154 | def backward(ctx, grad_output):
155 | input, weight = ctx.saved_tensors
156 | grad_input, grad_weight, grad_bias = None, None, None
157 |
158 | if ctx.needs_input_grad[0]:
159 | p = calc_output_padding(
160 | input_shape=input.shape, output_shape=grad_output.shape
161 | )
162 | grad_input = conv2d_gradfix(
163 | transpose=(not transpose),
164 | weight_shape=weight_shape,
165 | output_padding=p,
166 | **common_kwargs,
167 | ).apply(grad_output, weight, None)
168 |
169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
171 |
172 | if ctx.needs_input_grad[2]:
173 | grad_bias = grad_output.sum((0, 2, 3))
174 |
175 | return grad_input, grad_weight, grad_bias
176 |
177 | class Conv2dGradWeight(autograd.Function):
178 | @staticmethod
179 | def forward(ctx, grad_output, input):
180 | op = torch._C._jit_get_operation(
181 | "aten::cudnn_convolution_backward_weight"
182 | if not transpose
183 | else "aten::cudnn_convolution_transpose_backward_weight"
184 | )
185 | flags = [
186 | torch.backends.cudnn.benchmark,
187 | torch.backends.cudnn.deterministic,
188 | torch.backends.cudnn.allow_tf32,
189 | ]
190 | grad_weight = op(
191 | weight_shape,
192 | grad_output,
193 | input,
194 | padding,
195 | stride,
196 | dilation,
197 | groups,
198 | *flags,
199 | )
200 | ctx.save_for_backward(grad_output, input)
201 |
202 | return grad_weight
203 |
204 | @staticmethod
205 | def backward(ctx, grad_grad_weight):
206 | grad_output, input = ctx.saved_tensors
207 | grad_grad_output, grad_grad_input = None, None
208 |
209 | if ctx.needs_input_grad[0]:
210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211 |
212 | if ctx.needs_input_grad[1]:
213 | p = calc_output_padding(
214 | input_shape=input.shape, output_shape=grad_output.shape
215 | )
216 | grad_grad_input = conv2d_gradfix(
217 | transpose=(not transpose),
218 | weight_shape=weight_shape,
219 | output_padding=p,
220 | **common_kwargs,
221 | ).apply(grad_output, grad_grad_weight, None)
222 |
223 | return grad_grad_output, grad_grad_input
224 |
225 | conv2d_gradfix_cache[key] = Conv2d
226 |
227 | return Conv2d
228 |
--------------------------------------------------------------------------------
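
Usage sketch (illustrative only, not a file in the repository): conv2d_gradfix.conv2d behaves like F.conv2d, but when the backward pass runs inside no_weight_gradients() the weight gradient is skipped, which is how StyleGAN2-style R1 regularization avoids paying for unused gradients. The custom path is only taken for CUDA tensors on PyTorch 1.7/1.8; otherwise the call silently falls back to F.conv2d and the context has no effect. Importing the op package also JIT-compiles the fused CUDA extensions.

import torch
from enhancing.losses.op import conv2d_gradfix

x = torch.randn(1, 3, 64, 64, device='cuda', requires_grad=True)
w = torch.randn(8, 3, 3, 3, device='cuda', requires_grad=True)

with conv2d_gradfix.no_weight_gradients():
    out = conv2d_gradfix.conv2d(x, w, padding=1)
    out.sum().backward()   # backward runs while weight gradients are disabled

print(x.grad is not None)  # True: the input gradient is still computed
print(w.grad)              # None on the custom path: the weight gradient was skipped
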
/enhancing/losses/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, bias, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | if bias:
39 | grad_bias = grad_input.sum(dim).detach()
40 |
41 | else:
42 | grad_bias = empty
43 |
44 | return grad_input, grad_bias
45 |
46 | @staticmethod
47 | def backward(ctx, gradgrad_input, gradgrad_bias):
48 | out, = ctx.saved_tensors
49 | gradgrad_out = fused.fused_bias_act(
50 | gradgrad_input.contiguous(),
51 | gradgrad_bias,
52 | out,
53 | 3,
54 | 1,
55 | ctx.negative_slope,
56 | ctx.scale,
57 | )
58 |
59 | return gradgrad_out, None, None, None, None
60 |
61 |
62 | class FusedLeakyReLUFunction(Function):
63 | @staticmethod
64 | def forward(ctx, input, bias, negative_slope, scale):
65 | empty = input.new_empty(0)
66 |
67 | ctx.bias = bias is not None
68 |
69 | if bias is None:
70 | bias = empty
71 |
72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
73 | ctx.save_for_backward(out)
74 | ctx.negative_slope = negative_slope
75 | ctx.scale = scale
76 |
77 | return out
78 |
79 | @staticmethod
80 | def backward(ctx, grad_output):
81 | out, = ctx.saved_tensors
82 |
83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
84 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
85 | )
86 |
87 | if not ctx.bias:
88 | grad_bias = None
89 |
90 | return grad_input, grad_bias, None, None
91 |
92 |
93 | class FusedLeakyReLU(nn.Module):
94 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
95 | super().__init__()
96 |
97 | if bias:
98 | self.bias = nn.Parameter(torch.zeros(channel))
99 |
100 | else:
101 | self.bias = None
102 |
103 | self.negative_slope = negative_slope
104 | self.scale = scale
105 |
106 | def forward(self, input):
107 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
108 |
109 |
110 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
111 | if input.device.type == "cpu":
112 | if bias is not None:
113 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
114 | return (
115 | F.leaky_relu(
116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
117 | )
118 | * scale
119 | )
120 |
121 | else:
122 | return F.leaky_relu(input, negative_slope=0.2) * scale
123 |
124 | else:
125 | return FusedLeakyReLUFunction.apply(
126 | input.contiguous(), bias, negative_slope, scale
127 | )
128 |
--------------------------------------------------------------------------------
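
Usage sketch (illustrative only, not a file in the repository): fused_leaky_relu adds an optional per-channel bias, applies LeakyReLU with slope 0.2 and rescales by sqrt(2). On CPU tensors it falls back to plain F.leaky_relu; on CUDA tensors it dispatches to the fused kernel, and importing the module JIT-compiles that kernel, so a CUDA toolchain is assumed.

import torch
from enhancing.losses.op.fused_act import FusedLeakyReLU, fused_leaky_relu

x = torch.randn(2, 8, 16, 16)
bias = torch.zeros(8)

y = fused_leaky_relu(x, bias)      # functional form (CPU fallback path here)
act = FusedLeakyReLU(channel=8)    # module form, as used by EqualLinear and ConvLayer
print(y.shape, act(x).shape)       # both [2, 8, 16, 16]
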
/enhancing/losses/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <ATen/ATen.h>
3 | #include <torch/extension.h>
4 |
5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6 | const torch::Tensor &bias,
7 | const torch::Tensor &refer, int act, int grad,
8 | float alpha, float scale);
9 |
10 | #define CHECK_CUDA(x) \
11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12 | #define CHECK_CONTIGUOUS(x) \
13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14 | #define CHECK_INPUT(x) \
15 | CHECK_CUDA(x); \
16 | CHECK_CONTIGUOUS(x)
17 |
18 | torch::Tensor fused_bias_act(const torch::Tensor &input,
19 | const torch::Tensor &bias,
20 | const torch::Tensor &refer, int act, int grad,
21 | float alpha, float scale) {
22 | CHECK_INPUT(input);
23 | CHECK_INPUT(bias);
24 |
25 | at::DeviceGuard guard(input.device());
26 |
27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28 | }
29 |
30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32 | }
--------------------------------------------------------------------------------
/enhancing/losses/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 | #include <ATen/cuda/CUDAContext.h>
13 |
14 |
15 | #include <cuda.h>
16 | #include <cuda_runtime.h>
17 |
18 | template <typename scalar_t>
19 | static __global__ void
20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22 | scalar_t scale, int loop_x, int size_x, int step_b,
23 | int size_b, int use_bias, int use_ref) {
24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25 |
26 | scalar_t zero = 0.0;
27 |
28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29 | loop_idx++, xi += blockDim.x) {
30 | scalar_t x = p_x[xi];
31 |
32 | if (use_bias) {
33 | x += p_b[(xi / step_b) % size_b];
34 | }
35 |
36 | scalar_t ref = use_ref ? p_ref[xi] : zero;
37 |
38 | scalar_t y;
39 |
40 | switch (act * 10 + grad) {
41 | default:
42 | case 10:
43 | y = x;
44 | break;
45 | case 11:
46 | y = x;
47 | break;
48 | case 12:
49 | y = 0.0;
50 | break;
51 |
52 | case 30:
53 | y = (x > 0.0) ? x : x * alpha;
54 | break;
55 | case 31:
56 | y = (ref > 0.0) ? x : x * alpha;
57 | break;
58 | case 32:
59 | y = 0.0;
60 | break;
61 | }
62 |
63 | out[xi] = y * scale;
64 | }
65 | }
66 |
67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68 | const torch::Tensor &bias,
69 | const torch::Tensor &refer, int act, int grad,
70 | float alpha, float scale) {
71 | int curDevice = -1;
72 | cudaGetDevice(&curDevice);
73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74 |
75 | auto x = input.contiguous();
76 | auto b = bias.contiguous();
77 | auto ref = refer.contiguous();
78 |
79 | int use_bias = b.numel() ? 1 : 0;
80 | int use_ref = ref.numel() ? 1 : 0;
81 |
82 | int size_x = x.numel();
83 | int size_b = b.numel();
84 | int step_b = 1;
85 |
86 | for (int i = 1 + 1; i < x.dim(); i++) {
87 | step_b *= x.size(i);
88 | }
89 |
90 | int loop_x = 4;
91 | int block_size = 4 * 32;
92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93 |
94 | auto y = torch::empty_like(x);
95 |
96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97 | x.scalar_type(), "fused_bias_act_kernel", [&] {
98 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99 |             y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100 |             b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102 | });
103 |
104 | return y;
105 | }
--------------------------------------------------------------------------------
/enhancing/losses/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <torch/extension.h>
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5 | const torch::Tensor &kernel, int up_x, int up_y,
6 | int down_x, int down_y, int pad_x0, int pad_x1,
7 | int pad_y0, int pad_y1);
8 |
9 | #define CHECK_CUDA(x) \
10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CONTIGUOUS(x) \
12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13 | #define CHECK_INPUT(x) \
14 | CHECK_CUDA(x); \
15 | CHECK_CONTIGUOUS(x)
16 |
17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18 | int up_x, int up_y, int down_x, int down_y, int pad_x0,
19 | int pad_x1, int pad_y0, int pad_y1) {
20 | CHECK_INPUT(input);
21 | CHECK_INPUT(kernel);
22 |
23 | at::DeviceGuard guard(input.device());
24 |
25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26 | pad_y0, pad_y1);
27 | }
28 |
29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31 | }
--------------------------------------------------------------------------------
/enhancing/losses/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | from collections import abc
2 | import os
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | upfirdn2d_op = load(
12 | "upfirdn2d",
13 | sources=[
14 | os.path.join(module_path, "upfirdn2d.cpp"),
15 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class UpFirDn2dBackward(Function):
21 | @staticmethod
22 | def forward(
23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24 | ):
25 |
26 | up_x, up_y = up
27 | down_x, down_y = down
28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29 |
30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31 |
32 | grad_input = upfirdn2d_op.upfirdn2d(
33 | grad_output,
34 | grad_kernel,
35 | down_x,
36 | down_y,
37 | up_x,
38 | up_y,
39 | g_pad_x0,
40 | g_pad_x1,
41 | g_pad_y0,
42 | g_pad_y1,
43 | )
44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45 |
46 | ctx.save_for_backward(kernel)
47 |
48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
49 |
50 | ctx.up_x = up_x
51 | ctx.up_y = up_y
52 | ctx.down_x = down_x
53 | ctx.down_y = down_y
54 | ctx.pad_x0 = pad_x0
55 | ctx.pad_x1 = pad_x1
56 | ctx.pad_y0 = pad_y0
57 | ctx.pad_y1 = pad_y1
58 | ctx.in_size = in_size
59 | ctx.out_size = out_size
60 |
61 | return grad_input
62 |
63 | @staticmethod
64 | def backward(ctx, gradgrad_input):
65 | kernel, = ctx.saved_tensors
66 |
67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68 |
69 | gradgrad_out = upfirdn2d_op.upfirdn2d(
70 | gradgrad_input,
71 | kernel,
72 | ctx.up_x,
73 | ctx.up_y,
74 | ctx.down_x,
75 | ctx.down_y,
76 | ctx.pad_x0,
77 | ctx.pad_x1,
78 | ctx.pad_y0,
79 | ctx.pad_y1,
80 | )
81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82 | gradgrad_out = gradgrad_out.view(
83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84 | )
85 |
86 | return gradgrad_out, None, None, None, None, None, None, None, None
87 |
88 |
89 | class UpFirDn2d(Function):
90 | @staticmethod
91 | def forward(ctx, input, kernel, up, down, pad):
92 | up_x, up_y = up
93 | down_x, down_y = down
94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
95 |
96 | kernel_h, kernel_w = kernel.shape
97 | batch, channel, in_h, in_w = input.shape
98 | ctx.in_size = input.shape
99 |
100 | input = input.reshape(-1, in_h, in_w, 1)
101 |
102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103 |
104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106 | ctx.out_size = (out_h, out_w)
107 |
108 | ctx.up = (up_x, up_y)
109 | ctx.down = (down_x, down_y)
110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111 |
112 | g_pad_x0 = kernel_w - pad_x0 - 1
113 | g_pad_y0 = kernel_h - pad_y0 - 1
114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116 |
117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118 |
119 | out = upfirdn2d_op.upfirdn2d(
120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121 | )
122 | # out = out.view(major, out_h, out_w, minor)
123 | out = out.view(-1, channel, out_h, out_w)
124 |
125 | return out
126 |
127 | @staticmethod
128 | def backward(ctx, grad_output):
129 | kernel, grad_kernel = ctx.saved_tensors
130 |
131 | grad_input = None
132 |
133 | if ctx.needs_input_grad[0]:
134 | grad_input = UpFirDn2dBackward.apply(
135 | grad_output,
136 | kernel,
137 | grad_kernel,
138 | ctx.up,
139 | ctx.down,
140 | ctx.pad,
141 | ctx.g_pad,
142 | ctx.in_size,
143 | ctx.out_size,
144 | )
145 |
146 | return grad_input, None, None, None, None
147 |
148 |
149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150 | if not isinstance(up, abc.Iterable):
151 | up = (up, up)
152 |
153 | if not isinstance(down, abc.Iterable):
154 | down = (down, down)
155 |
156 | if len(pad) == 2:
157 | pad = (pad[0], pad[1], pad[0], pad[1])
158 |
159 | if input.device.type == "cpu":
160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161 |
162 | else:
163 | out = UpFirDn2d.apply(input, kernel, up, down, pad)
164 |
165 | return out
166 |
167 |
168 | def upfirdn2d_native(
169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170 | ):
171 | _, channel, in_h, in_w = input.shape
172 | input = input.reshape(-1, in_h, in_w, 1)
173 |
174 | _, in_h, in_w, minor = input.shape
175 | kernel_h, kernel_w = kernel.shape
176 |
177 | out = input.view(-1, in_h, 1, in_w, 1, minor)
178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180 |
181 | out = F.pad(
182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183 | )
184 | out = out[
185 | :,
186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188 | :,
189 | ]
190 |
191 | out = out.permute(0, 3, 1, 2)
192 | out = out.reshape(
193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194 | )
195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196 | out = F.conv2d(out, w)
197 | out = out.reshape(
198 | -1,
199 | minor,
200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202 | )
203 | out = out.permute(0, 2, 3, 1)
204 | out = out[:, ::down_y, ::down_x, :]
205 |
206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208 |
209 | return out.view(-1, channel, out_h, out_w)
210 |
--------------------------------------------------------------------------------
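
Usage sketch (illustrative only, not a file in the repository): upfirdn2d fuses upsample -> FIR filter -> downsample into a single call; with up=down=1 it reduces to the [1, 3, 3, 1] binomial blur that the Blur module in losses/layers.py builds. Importing the module JIT-compiles the CUDA extension even though CPU tensors take the pure-PyTorch upfirdn2d_native path.

import torch
from enhancing.losses.op.upfirdn2d import upfirdn2d

k = torch.tensor([1., 3., 3., 1.])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum()

x = torch.randn(1, 3, 64, 64)
out = upfirdn2d(x, kernel, up=1, down=1, pad=(2, 1))  # padding chosen so the resolution is preserved
print(out.shape)  # [1, 3, 64, 64]
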
/enhancing/losses/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 | #include <ATen/cuda/CUDAContext.h>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 |           v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 |           int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z * p.loop_major;
124 |
125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126 | major_idx_base >= p.major_dim) {
127 | return;
128 | }
129 |
130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131 | tap_idx += blockDim.x) {
132 | int ky = tap_idx / kernel_w;
133 | int kx = tap_idx - ky * kernel_w;
134 | scalar_t v = 0.0;
135 |
136 | if (kx < p.kernel_w & ky < p.kernel_h) {
137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138 | }
139 |
140 | sk[ky][kx] = v;
141 | }
142 |
143 | for (int loop_major = 0, major_idx = major_idx_base;
144 | loop_major < p.loop_major & major_idx < p.major_dim;
145 | loop_major++, major_idx++) {
146 | for (int loop_x = 0, tile_out_x = tile_out_x_base;
147 | loop_x < p.loop_x & tile_out_x < p.out_w;
148 | loop_x++, tile_out_x += tile_out_w) {
149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151 | int tile_in_x = floor_div(tile_mid_x, up_x);
152 | int tile_in_y = floor_div(tile_mid_y, up_y);
153 |
154 | __syncthreads();
155 |
156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157 | in_idx += blockDim.x) {
158 | int rel_in_y = in_idx / tile_in_w;
159 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
160 | int in_x = rel_in_x + tile_in_x;
161 | int in_y = rel_in_y + tile_in_y;
162 |
163 | scalar_t v = 0.0;
164 |
165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167 | p.minor_dim +
168 | minor_idx];
169 | }
170 |
171 | sx[rel_in_y][rel_in_x] = v;
172 | }
173 |
174 | __syncthreads();
175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176 | out_idx += blockDim.x) {
177 | int rel_out_y = out_idx / tile_out_w;
178 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
179 | int out_x = rel_out_x + tile_out_x;
180 | int out_y = rel_out_y + tile_out_y;
181 |
182 | int mid_x = tile_mid_x + rel_out_x * down_x;
183 | int mid_y = tile_mid_y + rel_out_y * down_y;
184 | int in_x = floor_div(mid_x, up_x);
185 | int in_y = floor_div(mid_y, up_y);
186 | int rel_in_x = in_x - tile_in_x;
187 | int rel_in_y = in_y - tile_in_y;
188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190 |
191 | scalar_t v = 0.0;
192 |
193 | #pragma unroll
194 | for (int y = 0; y < kernel_h / up_y; y++)
195 | #pragma unroll
196 | for (int x = 0; x < kernel_w / up_x; x++)
197 | v += sx[rel_in_y + y][rel_in_x + x] *
198 | sk[kernel_y + y * up_y][kernel_x + x * up_x];
199 |
200 | if (out_x < p.out_w & out_y < p.out_h) {
201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202 | minor_idx] = v;
203 | }
204 | }
205 | }
206 | }
207 | }
208 |
209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210 | const torch::Tensor &kernel, int up_x, int up_y,
211 | int down_x, int down_y, int pad_x0, int pad_x1,
212 | int pad_y0, int pad_y1) {
213 | int curDevice = -1;
214 | cudaGetDevice(&curDevice);
215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216 |
217 | UpFirDn2DKernelParams p;
218 |
219 | auto x = input.contiguous();
220 | auto k = kernel.contiguous();
221 |
222 | p.major_dim = x.size(0);
223 | p.in_h = x.size(1);
224 | p.in_w = x.size(2);
225 | p.minor_dim = x.size(3);
226 | p.kernel_h = k.size(0);
227 | p.kernel_w = k.size(1);
228 | p.up_x = up_x;
229 | p.up_y = up_y;
230 | p.down_x = down_x;
231 | p.down_y = down_y;
232 | p.pad_x0 = pad_x0;
233 | p.pad_x1 = pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 |     upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 |         <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 |                                                x.data_ptr<scalar_t>(),
317 |                                                k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 |     upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 |         <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 |                                                x.data_ptr<scalar_t>(),
325 |                                                k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 |     upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 |         <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 |                                                x.data_ptr<scalar_t>(),
333 |                                                k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 |     upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 |         <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 |                                                x.data_ptr<scalar_t>(),
341 |                                                k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 |     upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 |         <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 |                                                x.data_ptr<scalar_t>(),
349 |                                                k.data_ptr<scalar_t>(), p);
350 |
351 | break;
352 |
353 | case 6:
354 |     upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355 |         <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 |                                                x.data_ptr<scalar_t>(),
357 |                                                k.data_ptr<scalar_t>(), p);
358 |
359 | break;
360 |
361 | default:
362 |     upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 |         out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 |         k.data_ptr<scalar_t>(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
--------------------------------------------------------------------------------
/enhancing/losses/segmentation.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class BCELoss(nn.Module):
11 | def forward(self, prediction, target):
12 | loss = F.binary_cross_entropy_with_logits(prediction,target)
13 |
14 | return loss, {}
15 |
16 |
17 | class BCELossWithQuant(nn.Module):
18 | def __init__(self, codebook_weight=1.):
19 | super().__init__()
20 | self.codebook_weight = codebook_weight
21 |
22 | def forward(self, qloss, target, prediction, split):
23 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
24 | loss = bce_loss + self.codebook_weight*qloss
25 |
26 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
27 | "{}/bce_loss".format(split): bce_loss.detach().mean(),
28 | "{}/quant_loss".format(split): qloss.detach().mean()
29 | }
30 |
31 | return loss, log
32 |
--------------------------------------------------------------------------------
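
Usage sketch (illustrative only, not a file in the repository): BCELossWithQuant combines a binary cross-entropy reconstruction term on raw logits with the weighted codebook (quantization) loss and returns a per-split log dict. All tensors below are random placeholders.

import torch
from enhancing.losses.segmentation import BCELossWithQuant

criterion = BCELossWithQuant(codebook_weight=1.0)
prediction = torch.randn(2, 1, 64, 64)                 # raw logits
target = torch.randint(0, 2, (2, 1, 64, 64)).float()
qloss = torch.tensor(0.1)

loss, log = criterion(qloss, target, prediction, split='train')
print(loss.item(), sorted(log))   # keys: 'train/bce_loss', 'train/quant_loss', 'train/total_loss'
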
/enhancing/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | from omegaconf import OmegaConf
7 | from typing import Optional, Tuple
8 |
9 | import lpips
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from .layers import *
15 |
16 |
17 | class DummyLoss(nn.Module):
18 | def __init__(self) -> None:
19 | super().__init__()
20 |
21 |
22 | class VQLPIPS(nn.Module):
23 | def __init__(self, codebook_weight: float = 1.0,
24 | loglaplace_weight: float = 1.0,
25 | loggaussian_weight: float = 1.0,
26 | perceptual_weight: float = 1.0) -> None:
27 |
28 | super().__init__()
29 | self.perceptual_loss = lpips.LPIPS(net="vgg", verbose=False)
30 |
31 | self.codebook_weight = codebook_weight
32 | self.loglaplace_weight = loglaplace_weight
33 | self.loggaussian_weight = loggaussian_weight
34 | self.perceptual_weight = perceptual_weight
35 |
36 | def forward(self, codebook_loss: torch.FloatTensor, inputs: torch.FloatTensor, reconstructions: torch.FloatTensor, optimizer_idx: int,
37 | global_step: int, batch_idx: int, last_layer: Optional[nn.Module] = None, split: Optional[str] = "train") -> Tuple:
38 | inputs = inputs.contiguous()
39 | reconstructions = reconstructions.contiguous()
40 |
41 | loglaplace_loss = (reconstructions - inputs).abs().mean()
42 | loggaussian_loss = (reconstructions - inputs).pow(2).mean()
43 | perceptual_loss = self.perceptual_loss(inputs*2-1, reconstructions*2-1).mean()
44 |
45 | nll_loss = self.loglaplace_weight * loglaplace_loss + self.loggaussian_weight * loggaussian_loss + self.perceptual_weight * perceptual_loss
46 | loss = nll_loss + self.codebook_weight * codebook_loss
47 |
48 | log = {"{}/total_loss".format(split): loss.clone().detach(),
49 | "{}/quant_loss".format(split): codebook_loss.detach(),
50 | "{}/rec_loss".format(split): nll_loss.detach(),
51 | "{}/loglaplace_loss".format(split): loglaplace_loss.detach(),
52 | "{}/loggaussian_loss".format(split): loggaussian_loss.detach(),
53 | "{}/perceptual_loss".format(split): perceptual_loss.detach()
54 | }
55 |
56 | return loss, log
57 |
58 |
59 | class VQLPIPSWithDiscriminator(nn.Module):
60 | def __init__(self, disc_start: int = 0,
61 | disc_loss: str = 'vanilla',
62 | disc_params: Optional[OmegaConf] = dict(),
63 | codebook_weight: float = 1.0,
64 | loglaplace_weight: float = 1.0,
65 | loggaussian_weight: float = 1.0,
66 | perceptual_weight: float = 1.0,
67 | adversarial_weight: float = 1.0,
68 | use_adaptive_adv: bool = False,
69 | r1_gamma: float = 10,
70 | do_r1_every: int = 16) -> None:
71 |
72 | super().__init__()
73 | assert disc_loss in ["hinge", "vanilla", "least_square"], f"Unknown GAN loss '{disc_loss}'."
74 | self.perceptual_loss = lpips.LPIPS(net="vgg", verbose=False)
75 |
76 | self.codebook_weight = codebook_weight
77 | self.loglaplace_weight = loglaplace_weight
78 | self.loggaussian_weight = loggaussian_weight
79 | self.perceptual_weight = perceptual_weight
80 |
81 | self.discriminator = StyleDiscriminator(**disc_params)
82 | self.discriminator_iter_start = disc_start
83 | if disc_loss == "hinge":
84 | self.disc_loss = hinge_d_loss
85 | elif disc_loss == "vanilla":
86 | self.disc_loss = vanilla_d_loss
87 | elif disc_loss == "least_square":
88 | self.disc_loss = least_square_d_loss
89 |
90 | self.adversarial_weight = adversarial_weight
91 | self.use_adaptive_adv = use_adaptive_adv
92 | self.r1_gamma = r1_gamma
93 | self.do_r1_every = do_r1_every
94 |
95 | def calculate_adaptive_factor(self, nll_loss: torch.FloatTensor,
96 | g_loss: torch.FloatTensor, last_layer: nn.Module) -> torch.FloatTensor:
97 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
98 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
99 |
100 | adapt_factor = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
101 | adapt_factor = adapt_factor.clamp(0.0, 1e4).detach()
102 |
103 | return adapt_factor
104 |
105 | def forward(self, codebook_loss: torch.FloatTensor, inputs: torch.FloatTensor, reconstructions: torch.FloatTensor, optimizer_idx: int,
106 | global_step: int, batch_idx: int, last_layer: Optional[nn.Module] = None, split: Optional[str] = "train") -> Tuple:
107 | inputs = inputs.contiguous()
108 | reconstructions = reconstructions.contiguous()
109 |
110 | # now the GAN part
111 | if optimizer_idx == 0:
112 | # generator update
113 | loglaplace_loss = (reconstructions - inputs).abs().mean()
114 | loggaussian_loss = (reconstructions - inputs).pow(2).mean()
115 | perceptual_loss = self.perceptual_loss(inputs*2-1, reconstructions*2-1).mean()
116 |
117 | nll_loss = self.loglaplace_weight * loglaplace_loss + self.loggaussian_weight * loggaussian_loss + self.perceptual_weight * perceptual_loss
118 |
119 | logits_fake = self.discriminator(reconstructions)
120 | g_loss = self.disc_loss(logits_fake)
121 |
122 | try:
123 | d_weight = self.adversarial_weight
124 |
125 | if self.use_adaptive_adv:
126 | d_weight *= self.calculate_adaptive_factor(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = 1 if global_step >= self.discriminator_iter_start else 0
132 | loss = nll_loss + disc_factor * d_weight * g_loss + self.codebook_weight * codebook_loss
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach(),
135 | "{}/quant_loss".format(split): codebook_loss.detach(),
136 | "{}/rec_loss".format(split): nll_loss.detach(),
137 | "{}/loglaplace_loss".format(split): loglaplace_loss.detach(),
138 | "{}/loggaussian_loss".format(split): loggaussian_loss.detach(),
139 | "{}/perceptual_loss".format(split): perceptual_loss.detach(),
140 | "{}/g_loss".format(split): g_loss.detach(),
141 | }
142 |
143 | if self.use_adaptive_adv:
144 | log["{}/d_weight".format(split)] = d_weight.detach()
145 |
146 | return loss, log
147 |
148 | if optimizer_idx == 1:
149 | # second pass for discriminator update
150 | disc_factor = 1 if global_step >= self.discriminator_iter_start else 0
151 | do_r1 = self.training and bool(disc_factor) and batch_idx % self.do_r1_every == 0
152 |
153 | logits_real = self.discriminator(inputs.requires_grad_(do_r1))
154 | logits_fake = self.discriminator(reconstructions.detach())
155 |
156 | d_loss = disc_factor * self.disc_loss(logits_fake, logits_real)
157 | if do_r1:
158 | with conv2d_gradfix.no_weight_gradients():
159 | gradients, = torch.autograd.grad(outputs=logits_real.sum(), inputs=inputs, create_graph=True)
160 |
161 | gradients_norm = gradients.square().sum([1,2,3]).mean()
162 | d_loss += self.r1_gamma * self.do_r1_every * gradients_norm/2
163 |
164 | log = {"{}/disc_loss".format(split): d_loss.detach(),
165 | "{}/logits_real".format(split): logits_real.detach().mean(),
166 | "{}/logits_fake".format(split): logits_fake.detach().mean(),
167 | }
168 |
169 | if do_r1:
170 | log["{}/r1_reg".format(split)] = gradients_norm.detach()
171 |
172 | return d_loss, log
173 |
--------------------------------------------------------------------------------
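VQLPIPSWithDiscriminator is written for Lightning's two-optimizer loop: optimizer_idx 0 computes the reconstruction + perceptual + adversarial generator loss, optimizer_idx 1 the discriminator loss (with an R1 penalty every do_r1_every steps). A standalone sketch of that calling pattern, assuming the default StyleDiscriminator settings accept 256x256 inputs (LPIPS downloads its VGG weights on first use):

    import torch
    from enhancing.losses.vqperceptual import VQLPIPSWithDiscriminator

    criterion = VQLPIPSWithDiscriminator(disc_start=0)
    x = torch.rand(2, 3, 256, 256)      # inputs in [0, 1]
    xrec = torch.rand(2, 3, 256, 256)   # reconstructions in [0, 1]
    qloss = torch.tensor(0.05)

    # generator / autoencoder pass
    g_loss, g_log = criterion(qloss, x, xrec, optimizer_idx=0,
                              global_step=0, batch_idx=0, split="train")
    # discriminator pass (batch_idx not divisible by do_r1_every, so no R1 term here)
    d_loss, d_log = criterion(qloss, x, xrec, optimizer_idx=1,
                              global_step=0, batch_idx=1, split="train")
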
/enhancing/modules/cond/clipcond.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | import os
7 | from omegaconf import OmegaConf
8 | from typing import Tuple, Union, List, Any
9 |
10 | import clip
11 | import torch
12 | import torch.nn as nn
13 | from torchvision import transforms as T
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 | from .dummycond import DummyCond
17 | from ...utils.general import initialize_from_config
18 |
19 |
20 | class ClipTextCond(DummyCond):
21 | def __init__(self, image_size: Union[Tuple[int, int], int],
22 | clip_model: str, tokenizer: OmegaConf) -> None:
23 | super().__init__()
24 | self.image_size = image_size
25 | self.clip_model, _ = clip.load(clip_model, device="cpu")  # load on CPU; Lightning moves the module to the right device
26 | self.tokenizer = initialize_from_config(tokenizer)
27 |
28 | def encode_codes(self, text: torch.LongTensor) -> torch.FloatTensor:
29 | with torch.no_grad():
30 | text_features = self.clip_model.encode_text(text)
31 |
32 | return text_features
33 |
34 | def to_img(self, texts: torch.LongTensor) -> torch.FloatTensor:
35 | W, H = self.image_size if isinstance(self.image_size, tuple) else (self.image_size, self.image_size)
36 | font = ImageFont.truetype(os.path.join(os.getcwd(), "assets", "font", "arial.ttf"), 12)
37 |
38 | imgs = []
39 | for text in texts:
40 | text = self.tokenizer.decode(text)
41 | words = text.split()
42 | length = 0
43 |
44 | for idx, word in enumerate(words):
45 | if length > 27:
46 | length = 0
47 | words[idx-int(idx>0)] += '\n'
48 | length += len(word)
49 | text = ' '.join(words)
50 |
51 | img = Image.new("RGBA", (W, H), "white")
52 | draw = ImageDraw.Draw(img)
53 |
54 | w, h = draw.textsize(text, font)
55 | draw.text(((W-w)/2,(H-h)/2), text, font=font, fill="black", align="center")
56 |
57 | img = img.convert('RGB')
58 | img = T.ToTensor()(img)
59 | imgs.append(img)
60 |
61 | return torch.stack(imgs, dim=0)
62 |
63 |
64 | class ClipImageCond(DummyCond):
65 | def __init__(self, clip_model: str) -> None:
66 | super().__init__()
67 | self.clip_model, _ = clip.load(clip_model, device="cpu")  # load on CPU; Lightning moves the module to the right device
68 |
69 | def encode_codes(self, image: torch.FloatTensor) -> torch.FloatTensor:
70 | with torch.no_grad():
71 | image_features = self.clip_model.encode_image(image)
72 |
73 | return image_features
74 |
75 | def to_img(self, image: torch.FloatTensor) -> torch.FloatTensor:
76 | return image.clamp(0, 1)
77 |
--------------------------------------------------------------------------------
/enhancing/modules/cond/dummycond.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 |
7 | import os
8 | from omegaconf import OmegaConf
9 | from typing import Tuple, Union, List, Any
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torchvision import transforms as T
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 | from ...utils.general import initialize_from_config
17 |
18 |
19 | class DummyCond(nn.Module):
20 | def __init__(self) -> None:
21 | super().__init__()
22 |
23 | def encode(self, condition: Any) -> Tuple[Any, Any, Any]:
24 | return condition, None, condition
25 |
26 | def decode(self, condition: Any) -> Any:
27 | return condition
28 |
29 | def encode_codes(self, condition: Any) -> Any:
30 | return condition
31 |
32 | def decode_codes(self, condition: Any) -> Any:
33 | return condition
34 |
35 |
36 | class TextCond(DummyCond):
37 | def __init__(self, image_size: Union[Tuple[int, int], int], tokenizer: OmegaConf) -> None:
38 | super().__init__()
39 | self.image_size = image_size
40 | self.tokenizer = initialize_from_config(tokenizer)
41 |
42 | def to_img(self, texts: torch.LongTensor) -> torch.FloatTensor:
43 | W, H = self.image_size if isinstance(self.image_size, tuple) else (self.image_size, self.image_size)
44 | font = ImageFont.truetype(os.path.join(os.getcwd(), "assets", "font", "arial.ttf"), 12)
45 |
46 | imgs = []
47 | for text in texts:
48 | text = self.tokenizer.decode(text)
49 | words = text.split()
50 | length = 0
51 |
52 | for idx, word in enumerate(words):
53 | if length > 27:
54 | length = 0
55 | words[idx-int(idx>0)] += '\n'
56 | length += len(word)
57 | text = ' '.join(words)
58 |
59 | img = Image.new("RGBA", (W, H), "white")
60 | draw = ImageDraw.Draw(img)
61 |
62 | w, h = draw.textsize(text, font)
63 | draw.text(((W-w)/2,(H-h)/2), text, font=font, fill="black", align="center")
64 |
65 | img = img.convert('RGB')
66 | img = T.ToTensor()(img)
67 | imgs.append(img)
68 |
69 | return torch.stack(imgs, dim=0)
70 |
71 |
72 | class ClassCond(DummyCond):
73 | def __init__(self, image_size: Union[Tuple[int, int], int], class_name: Union[str, List[str]]) -> None:
74 | super().__init__()
75 | self.img_size = image_size
76 | if isinstance(class_name, str):
77 | if class_name.endswith("txt") and os.path.isfile(class_name):
78 | self.cls_name = open(class_name, "r").read().split("\n")
79 | elif "." not in class_name and not os.path.isfile(class_name):
80 | self.cls_name = [class_name]
81 | elif isinstance(class_name, list) and isinstance(class_name[0], str):
82 | self.cls_name = class_name
83 | else:
84 | raise Exception("Class file format not supported")
85 |
86 | def to_img(self, clss: torch.LongTensor) -> torch.FloatTensor:
87 | W, H = self.img_size if isinstance(self.img_size, tuple) else (self.img_size, self.img_size)
88 | font = ImageFont.truetype(os.path.join(os.getcwd(), "assets", "font", "arial.ttf"), 12)
89 |
90 | imgs = []
91 | for cls in clss:
92 | cls_name = self.cls_name[int(cls)]
93 | length = 0
94 |
95 | img = Image.new("RGBA", (W, H), "white")
96 | draw = ImageDraw.Draw(img)
97 |
98 | w, h = draw.textsize(cls_name, font)
99 | draw.text(((W-w)/2,(H-h)/2), cls_name, font=font, fill="black", align="center")
100 |
101 | img = img.convert('RGB')
102 | img = T.ToTensor()(img)
103 | imgs.append(img)
104 |
105 | return torch.stack(imgs, dim=0)
106 |
--------------------------------------------------------------------------------
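For example, ClassCond can be pointed at one of the class lists under assets/class/ and will render a batch of labels as class-name images (paths below assume the repo root is the working directory):

    import torch
    from enhancing.modules.cond.dummycond import ClassCond

    cond = ClassCond(image_size=256, class_name="assets/class/imagenet.txt")
    labels = torch.tensor([1, 207, 980])   # arbitrary ImageNet class indices
    grid = cond.to_img(labels)             # (3, 3, 256, 256) tensor of rendered class names
    codes = cond.encode_codes(labels)      # DummyCond passthrough: the labels themselves
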
/enhancing/modules/cond/vqcond.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | import numpy as np
11 | from typing import Tuple, Dict, Any
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from ...utils.general import get_obj_from_str
17 |
18 |
19 | def VQCond(base_class: str, *args, **kwargs) -> object:
20 | def to_img(x: torch.FloatTensor) -> torch.FloatTensor:
21 | return x.clamp(0, 1)
22 |
23 | model = get_obj_from_str(base_class)(*args, **kwargs)
24 | model.to_img = to_img
25 |
26 | return model
27 |
28 |
29 | def VQSegmentation(base_class: str, n_labels: int, *args, **kwargs) -> object:
30 | base_model_cls = get_obj_from_str(base_class)
31 | class Wrapper(base_model_cls):
32 | def __init__(self) -> None:
33 | super().__init__(*args, **kwargs)
34 | self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
35 |
36 | def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
37 | x = self.get_input(batch, self.image_key)
38 | xrec, qloss = self(x)
39 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
40 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
41 | self.log("train/total_loss", log_dict_ae["train/total_loss"],
42 | prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
43 |
44 | return aeloss
45 |
46 | def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> torch.FloatTensor:
47 | x = self.get_input(batch, self.image_key)
48 | xrec, qloss = self(x)
49 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
50 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
51 | total_loss = log_dict_ae["val/total_loss"]
52 | self.log("val/total_loss", total_loss,
53 | prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
54 |
55 | return aeloss
56 |
57 | @torch.no_grad()
58 | def log_images(self, batch: Tuple[Any, Any], *args, **kwargs) -> Dict:
59 | log = dict()
60 | x = self.get_input(batch, self.image_key).to(self.device)
61 | xrec, _ = self(x)
62 | if x.shape[1] > 3:
63 | # colorize with random projection
64 | assert xrec.shape[1] > 3
65 | # convert logits to indices
66 | xrec = torch.argmax(xrec, dim=1, keepdim=True)
67 | xrec = F.one_hot(xrec, num_classes=x.shape[1])
68 | xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
69 | x = self.to_img(x)
70 | xrec = self.to_img(xrec)
71 | log["inputs"] = x
72 | log["reconstructions"] = xrec
73 |
74 | return log
75 |
76 | def to_img(self, x: torch.FloatTensor) -> torch.FloatTensor:
77 | x = F.conv2d(x, weight=self.colorize)
78 |
79 | return (x-x.min())/(x.max()-x.min())
80 |
81 | return Wrapper()
82 |
--------------------------------------------------------------------------------
/enhancing/modules/stage1/layers.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from ViT-Pytorch (https://github.com/lucidrains/vit-pytorch)
7 | # Copyright (c) 2020 Phil Wang. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | import math
11 | import numpy as np
12 | from typing import Union, Tuple, List
13 | from collections import OrderedDict
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | from einops import rearrange, repeat
19 | from einops.layers.torch import Rearrange
20 |
21 | def get_2d_sincos_pos_embed(embed_dim, grid_size):
22 | """
23 | grid_size: int or (int, int) of the grid height and width
24 | return:
25 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
26 | """
27 | grid_size = (grid_size, grid_size) if type(grid_size) != tuple else grid_size
28 | grid_h = np.arange(grid_size[0], dtype=np.float32)
29 | grid_w = np.arange(grid_size[1], dtype=np.float32)
30 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
31 | grid = np.stack(grid, axis=0)
32 |
33 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
34 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
35 |
36 | return pos_embed
37 |
38 |
39 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
40 | assert embed_dim % 2 == 0
41 |
42 | # use half of dimensions to encode grid_h
43 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
44 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
45 |
46 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
47 | return emb
48 |
49 |
50 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
51 | """
52 | embed_dim: output dimension for each position
53 | pos: a list of positions to be encoded: size (M,)
54 | out: (M, D)
55 | """
56 | assert embed_dim % 2 == 0
57 | omega = np.arange(embed_dim // 2, dtype=np.float64)
58 | omega /= embed_dim / 2.
59 | omega = 1. / 10000**omega # (D/2,)
60 |
61 | pos = pos.reshape(-1) # (M,)
62 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
63 |
64 | emb_sin = np.sin(out) # (M, D/2)
65 | emb_cos = np.cos(out) # (M, D/2)
66 |
67 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
68 | return emb
69 |
70 |
71 | def init_weights(m):
72 | if isinstance(m, nn.Linear):
73 | # we use xavier_uniform following official JAX ViT:
74 | torch.nn.init.xavier_uniform_(m.weight)
75 | if m.bias is not None:
76 | nn.init.constant_(m.bias, 0)
77 | elif isinstance(m, nn.LayerNorm):
78 | nn.init.constant_(m.bias, 0)
79 | nn.init.constant_(m.weight, 1.0)
80 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
81 | w = m.weight.data
82 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
83 |
84 |
85 | class PreNorm(nn.Module):
86 | def __init__(self, dim: int, fn: nn.Module) -> None:
87 | super().__init__()
88 | self.norm = nn.LayerNorm(dim)
89 | self.fn = fn
90 |
91 | def forward(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
92 | return self.fn(self.norm(x), **kwargs)
93 |
94 |
95 | class FeedForward(nn.Module):
96 | def __init__(self, dim: int, hidden_dim: int) -> None:
97 | super().__init__()
98 | self.net = nn.Sequential(
99 | nn.Linear(dim, hidden_dim),
100 | nn.Tanh(),
101 | nn.Linear(hidden_dim, dim)
102 | )
103 |
104 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
105 | return self.net(x)
106 |
107 |
108 | class Attention(nn.Module):
109 | def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None:
110 | super().__init__()
111 | inner_dim = dim_head * heads
112 | project_out = not (heads == 1 and dim_head == dim)
113 |
114 | self.heads = heads
115 | self.scale = dim_head ** -0.5
116 |
117 | self.attend = nn.Softmax(dim = -1)
118 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
119 |
120 | self.to_out = nn.Linear(inner_dim, dim) if project_out else nn.Identity()
121 |
122 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
123 | qkv = self.to_qkv(x).chunk(3, dim = -1)
124 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
125 |
126 | attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
127 | attn = self.attend(attn)
128 |
129 | out = torch.matmul(attn, v)
130 | out = rearrange(out, 'b h n d -> b n (h d)')
131 |
132 | return self.to_out(out)
133 |
134 |
135 | class Transformer(nn.Module):
136 | def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int) -> None:
137 | super().__init__()
138 | self.layers = nn.ModuleList([])
139 | for idx in range(depth):
140 | layer = nn.ModuleList([PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head)),
141 | PreNorm(dim, FeedForward(dim, mlp_dim))])
142 | self.layers.append(layer)
143 | self.norm = nn.LayerNorm(dim)
144 |
145 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
146 | for attn, ff in self.layers:
147 | x = attn(x) + x
148 | x = ff(x) + x
149 |
150 | return self.norm(x)
151 |
152 |
153 | class ViTEncoder(nn.Module):
154 | def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int],
155 | dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None:
156 | super().__init__()
157 | image_height, image_width = image_size if isinstance(image_size, tuple) \
158 | else (image_size, image_size)
159 | patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \
160 | else (patch_size, patch_size)
161 |
162 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
163 | en_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width))
164 |
165 | self.num_patches = (image_height // patch_height) * (image_width // patch_width)
166 | self.patch_dim = channels * patch_height * patch_width
167 |
168 | self.to_patch_embedding = nn.Sequential(
169 | nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size),
170 | Rearrange('b c h w -> b (h w) c'),
171 | )
172 | self.en_pos_embedding = nn.Parameter(torch.from_numpy(en_pos_embedding).float().unsqueeze(0), requires_grad=False)
173 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
174 |
175 | self.apply(init_weights)
176 |
177 | def forward(self, img: torch.FloatTensor) -> torch.FloatTensor:
178 | x = self.to_patch_embedding(img)
179 | x = x + self.en_pos_embedding
180 | x = self.transformer(x)
181 |
182 | return x
183 |
184 |
185 | class ViTDecoder(nn.Module):
186 | def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int],
187 | dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None:
188 | super().__init__()
189 | image_height, image_width = image_size if isinstance(image_size, tuple) \
190 | else (image_size, image_size)
191 | patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \
192 | else (patch_size, patch_size)
193 |
194 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
195 | de_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width))
196 |
197 | self.num_patches = (image_height // patch_height) * (image_width // patch_width)
198 | self.patch_dim = channels * patch_height * patch_width
199 |
200 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
201 | self.de_pos_embedding = nn.Parameter(torch.from_numpy(de_pos_embedding).float().unsqueeze(0), requires_grad=False)
202 | self.to_pixel = nn.Sequential(
203 | Rearrange('b (h w) c -> b c h w', h=image_height // patch_height),
204 | nn.ConvTranspose2d(dim, channels, kernel_size=patch_size, stride=patch_size)
205 | )
206 |
207 | self.apply(init_weights)
208 |
209 | def forward(self, token: torch.FloatTensor) -> torch.FloatTensor:
210 | x = token + self.de_pos_embedding
211 | x = self.transformer(x)
212 | x = self.to_pixel(x)
213 |
214 | return x
215 |
216 | def get_last_layer(self) -> nn.Parameter:
217 | return self.to_pixel[-1].weight
218 |
--------------------------------------------------------------------------------
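A minimal shape check for the encoder/decoder pair; the hyperparameters are illustrative, not necessarily those used in the configs/ yamls:

    import torch
    from enhancing.modules.stage1.layers import ViTEncoder, ViTDecoder

    enc = ViTEncoder(image_size=256, patch_size=8, dim=512, depth=8, heads=8, mlp_dim=2048)
    dec = ViTDecoder(image_size=256, patch_size=8, dim=512, depth=8, heads=8, mlp_dim=2048)

    img = torch.randn(1, 3, 256, 256)
    tokens = enc(img)        # (1, 1024, 512): one token per 8x8 patch
    recon = dec(tokens)      # (1, 3, 256, 256)
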
/enhancing/modules/stage1/quantizers.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | import math
11 | from functools import partial
12 | from typing import Tuple, Optional
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 |
18 |
19 | class BaseQuantizer(nn.Module):
20 | def __init__(self, embed_dim: int, n_embed: int, straight_through: bool = True, use_norm: bool = True,
21 | use_residual: bool = False, num_quantizers: Optional[int] = None) -> None:
22 | super().__init__()
23 | self.straight_through = straight_through
24 | self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x
25 |
26 | self.use_residual = use_residual
27 | self.num_quantizers = num_quantizers
28 |
29 | self.embed_dim = embed_dim
30 | self.n_embed = n_embed
31 |
32 | self.embedding = nn.Embedding(self.n_embed, self.embed_dim)
33 | self.embedding.weight.data.normal_()
34 |
35 | def quantize(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
36 | pass
37 |
38 | def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
39 | if not self.use_residual:
40 | z_q, loss, encoding_indices = self.quantize(z)
41 | else:
42 | z_q = torch.zeros_like(z)
43 | residual = z.detach().clone()
44 |
45 | losses = []
46 | encoding_indices = []
47 |
48 | for _ in range(self.num_quantizers):
49 | z_qi, loss, indices = self.quantize(residual.clone())
50 | residual.sub_(z_qi)
51 | z_q.add_(z_qi)
52 |
53 | encoding_indices.append(indices)
54 | losses.append(loss)
55 |
56 | losses, encoding_indices = map(partial(torch.stack, dim = -1), (losses, encoding_indices))
57 | loss = losses.mean()
58 |
59 | # preserve gradients with straight-through estimator
60 | if self.straight_through:
61 | z_q = z + (z_q - z).detach()
62 |
63 | return z_q, loss, encoding_indices
64 |
65 |
66 | class VectorQuantizer(BaseQuantizer):
67 | def __init__(self, embed_dim: int, n_embed: int, beta: float = 0.25, use_norm: bool = True,
68 | use_residual: bool = False, num_quantizers: Optional[int] = None, **kwargs) -> None:
69 | super().__init__(embed_dim, n_embed, True,
70 | use_norm, use_residual, num_quantizers)
71 |
72 | self.beta = beta
73 |
74 | def quantize(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
75 | z_reshaped_norm = self.norm(z.view(-1, self.embed_dim))
76 | embedding_norm = self.norm(self.embedding.weight)
77 |
78 | d = torch.sum(z_reshaped_norm ** 2, dim=1, keepdim=True) + \
79 | torch.sum(embedding_norm ** 2, dim=1) - 2 * \
80 | torch.einsum('b d, n d -> b n', z_reshaped_norm, embedding_norm)
81 |
82 | encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
83 | encoding_indices = encoding_indices.view(*z.shape[:-1])
84 |
85 | z_q = self.embedding(encoding_indices).view(z.shape)
86 | z_qnorm, z_norm = self.norm(z_q), self.norm(z)
87 |
88 | # compute loss for embedding
89 | loss = self.beta * torch.mean((z_qnorm.detach() - z_norm)**2) + \
90 | torch.mean((z_qnorm - z_norm.detach())**2)
91 |
92 | return z_qnorm, loss, encoding_indices
93 |
94 |
95 | class GumbelQuantizer(BaseQuantizer):
96 | def __init__(self, embed_dim: int, n_embed: int, temp_init: float = 1.0,
97 | use_norm: bool = True, use_residual: bool = False, num_quantizers: Optional[int] = None, **kwargs) -> None:
98 | super().__init__(embed_dim, n_embed, False,
99 | use_norm, use_residual, num_quantizers)
100 |
101 | self.temperature = temp_init
102 |
103 | def quantize(self, z: torch.FloatTensor, temp: Optional[float] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
104 | # force hard = True when we are in eval mode, as we must quantize
105 | hard = not self.training
106 | temp = self.temperature if temp is None else temp
107 |
108 | z_reshaped_norm = self.norm(z.view(-1, self.embed_dim))
109 | embedding_norm = self.norm(self.embedding.weight)
110 |
111 | logits = - torch.sum(z_reshaped_norm ** 2, dim=1, keepdim=True) - \
112 | torch.sum(embedding_norm ** 2, dim=1) + 2 * \
113 | torch.einsum('b d, n d -> b n', z_reshaped_norm, embedding_norm)
114 | logits = logits.view(*z.shape[:-1], -1)
115 |
116 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=-1, hard=hard)
117 | z_qnorm = torch.matmul(soft_one_hot, embedding_norm)
118 |
119 | # kl divergence to the prior loss
120 | logits = F.log_softmax(logits, dim=-1) # use log_softmax because it is more numerically stable
121 | loss = torch.sum(logits.exp() * (logits+math.log(self.n_embed)), dim=-1).mean()
122 |
123 | # get encoding via argmax
124 | encoding_indices = soft_one_hot.argmax(dim=-1)
125 |
126 | return z_qnorm, loss, encoding_indices
127 |
--------------------------------------------------------------------------------
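A small sketch of VectorQuantizer on encoder-shaped activations (embed_dim and n_embed below are illustrative):

    import torch
    from enhancing.modules.stage1.quantizers import VectorQuantizer

    quantizer = VectorQuantizer(embed_dim=32, n_embed=8192, beta=0.25)
    z = torch.randn(2, 1024, 32)        # (batch, tokens, embed_dim) from the encoder
    z_q, loss, indices = quantizer(z)   # straight-through quantized tokens
    print(z_q.shape, indices.shape)     # torch.Size([2, 1024, 32]) torch.Size([2, 1024])
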
/enhancing/modules/stage1/vitvqgan.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | from typing import List, Tuple, Dict, Any, Optional
11 | from omegaconf import OmegaConf
12 |
13 | import PIL
14 | import torch
15 | import torch.nn as nn
16 | from torch.optim import lr_scheduler
17 | from torchvision import transforms as T
18 | import pytorch_lightning as pl
19 |
20 | from .layers import ViTEncoder as Encoder, ViTDecoder as Decoder
21 | from .quantizers import VectorQuantizer, GumbelQuantizer
22 | from ...utils.general import initialize_from_config
23 |
24 |
25 | class ViTVQ(pl.LightningModule):
26 | def __init__(self, image_key: str, image_size: int, patch_size: int, encoder: OmegaConf, decoder: OmegaConf, quantizer: OmegaConf,
27 | loss: OmegaConf, path: Optional[str] = None, ignore_keys: List[str] = list(), scheduler: Optional[OmegaConf] = None) -> None:
28 | super().__init__()
29 | self.path = path
30 | self.ignore_keys = ignore_keys
31 | self.image_key = image_key
32 | self.scheduler = scheduler
33 |
34 | self.loss = initialize_from_config(loss)
35 | self.encoder = Encoder(image_size=image_size, patch_size=patch_size, **encoder)
36 | self.decoder = Decoder(image_size=image_size, patch_size=patch_size, **decoder)
37 | self.quantizer = VectorQuantizer(**quantizer)
38 | self.pre_quant = nn.Linear(encoder.dim, quantizer.embed_dim)
39 | self.post_quant = nn.Linear(quantizer.embed_dim, decoder.dim)
40 |
41 | if path is not None:
42 | self.init_from_ckpt(path, ignore_keys)
43 |
44 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
45 | quant, diff = self.encode(x)
46 | dec = self.decode(quant)
47 |
48 | return dec, diff
49 |
50 | def init_from_ckpt(self, path: str, ignore_keys: List[str] = list()):
51 | sd = torch.load(path, map_location="cpu")["state_dict"]
52 | keys = list(sd.keys())
53 | for k in keys:
54 | for ik in ignore_keys:
55 | if k.startswith(ik):
56 | print("Deleting key {} from state_dict.".format(k))
57 | del sd[k]
58 | self.load_state_dict(sd, strict=False)
59 | print(f"Restored from {path}")
60 |
61 | def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
62 | h = self.encoder(x)
63 | h = self.pre_quant(h)
64 | quant, emb_loss, _ = self.quantizer(h)
65 |
66 | return quant, emb_loss
67 |
68 | def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
69 | quant = self.post_quant(quant)
70 | dec = self.decoder(quant)
71 |
72 | return dec
73 |
74 | def encode_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
75 | h = self.encoder(x)
76 | h = self.pre_quant(h)
77 | _, _, codes = self.quantizer(h)
78 |
79 | return codes
80 |
81 | def decode_codes(self, code: torch.LongTensor) -> torch.FloatTensor:
82 | quant = self.quantizer.embedding(code)
83 | quant = self.quantizer.norm(quant)
84 |
85 | if self.quantizer.use_residual:
86 | quant = quant.sum(-2)
87 |
88 | dec = self.decode(quant)
89 |
90 | return dec
91 |
92 | def get_input(self, batch: Tuple[Any, Any], key: str = 'image') -> Any:
93 | x = batch[key]
94 | if len(x.shape) == 3:
95 | x = x[..., None]
96 | if x.dtype == torch.double:
97 | x = x.float()
98 |
99 | return x.contiguous()
100 |
101 | def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
102 | x = self.get_input(batch, self.image_key)
103 | xrec, qloss = self(x)
104 |
105 | if optimizer_idx == 0:
106 | # autoencoder
107 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, batch_idx,
108 | last_layer=self.decoder.get_last_layer(), split="train")
109 |
110 | self.log("train/total_loss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
111 | del log_dict_ae["train/total_loss"]
112 |
113 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
114 |
115 | return aeloss
116 |
117 | if optimizer_idx == 1:
118 | # discriminator
119 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, batch_idx,
120 | last_layer=self.decoder.get_last_layer(), split="train")
121 |
122 | self.log("train/disc_loss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
123 | del log_dict_disc["train/disc_loss"]
124 |
125 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
126 |
127 | return discloss
128 |
129 | def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> Dict:
130 | x = self.get_input(batch, self.image_key)
131 | xrec, qloss = self(x)
132 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, batch_idx,
133 | last_layer=self.decoder.get_last_layer(), split="val")
134 |
135 | rec_loss = log_dict_ae["val/rec_loss"]
136 |
137 | self.log("val/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
138 | self.log("val/total_loss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
139 | del log_dict_ae["val/rec_loss"]
140 | del log_dict_ae["val/total_loss"]
141 |
142 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
143 |
144 | if hasattr(self.loss, 'discriminator'):
145 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, batch_idx,
146 | last_layer=self.decoder.get_last_layer(), split="val")
147 |
148 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
149 |
150 | return self.log_dict
151 |
152 | def configure_optimizers(self) -> Tuple[List, List]:
153 | lr = self.learning_rate
154 | optim_groups = list(self.encoder.parameters()) + \
155 | list(self.decoder.parameters()) + \
156 | list(self.pre_quant.parameters()) + \
157 | list(self.post_quant.parameters()) + \
158 | list(self.quantizer.parameters())
159 |
160 | optimizers = [torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)]
161 | schedulers = []
162 |
163 | if hasattr(self.loss, 'discriminator'):
164 | optimizers.append(torch.optim.AdamW(self.loss.discriminator.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4))
165 |
166 | if self.scheduler is not None:
167 | self.scheduler.params.start = lr
168 | scheduler = initialize_from_config(self.scheduler)
169 |
170 | schedulers = [
171 | {
172 | 'scheduler': lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler.schedule),
173 | 'interval': 'step',
174 | 'frequency': 1
175 | } for optimizer in optimizers
176 | ]
177 |
178 | return optimizers, schedulers
179 |
180 | def log_images(self, batch: Tuple[Any, Any], *args, **kwargs) -> Dict:
181 | log = dict()
182 | x = self.get_input(batch, self.image_key).to(self.device)
183 | quant, _ = self.encode(x)
184 |
185 | log["originals"] = x
186 | log["reconstructions"] = self.decode(quant)
187 |
188 | return log
189 |
190 |
191 | class ViTVQGumbel(ViTVQ):
192 | def __init__(self, image_key: str, image_size: int, patch_size: int, encoder: OmegaConf, decoder: OmegaConf, quantizer: OmegaConf, loss: OmegaConf,
193 | path: Optional[str] = None, ignore_keys: List[str] = list(), temperature_scheduler: OmegaConf = None, scheduler: Optional[OmegaConf] = None) -> None:
194 | super().__init__(image_key, image_size, patch_size, encoder, decoder, quantizer, loss, None, None, scheduler)
195 |
196 | self.temperature_scheduler = initialize_from_config(temperature_scheduler) \
197 | if temperature_scheduler else None
198 | self.quantizer = GumbelQuantizer(**quantizer)
199 |
200 | if path is not None:
201 | self.init_from_ckpt(path, ignore_keys)
202 |
203 | def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
204 | if self.temperature_scheduler:
205 | self.quantizer.temperature = self.temperature_scheduler(self.global_step)
206 |
207 | loss = super().training_step(batch, batch_idx, optimizer_idx)
208 |
209 | if optimizer_idx == 0:
210 | self.log("temperature", self.quantizer.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
211 |
212 | return loss
213 |
--------------------------------------------------------------------------------
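Putting stage 1 together, a rough sketch of building a ViTVQ directly; the real hyperparameters live in configs/imagenet_vitvq_*.yaml, and the loss entry assumes the usual target/params layout consumed by initialize_from_config:

    import torch
    from omegaconf import OmegaConf
    from enhancing.modules.stage1.vitvqgan import ViTVQ

    cfg = OmegaConf.create({
        "encoder":   {"dim": 512, "depth": 8, "heads": 8, "mlp_dim": 2048},
        "decoder":   {"dim": 512, "depth": 8, "heads": 8, "mlp_dim": 2048},
        "quantizer": {"embed_dim": 32, "n_embed": 8192},
        "loss":      {"target": "enhancing.losses.vqperceptual.DummyLoss"},  # assumed config layout
    })
    model = ViTVQ(image_key="image", image_size=256, patch_size=8,
                  encoder=cfg.encoder, decoder=cfg.decoder,
                  quantizer=cfg.quantizer, loss=cfg.loss)

    x = torch.rand(1, 3, 256, 256)
    codes = model.encode_codes(x)       # (1, 1024) discrete token ids
    xrec = model.decode_codes(codes)    # (1, 3, 256, 256) reconstruction
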
/enhancing/modules/stage2/layers.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from minDALL-E (https://github.com/kakaobrain/minDALL-E)
7 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 | # Modified from minGPT (https://github.com/karpathy/minGPT)
10 | # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
11 | # ------------------------------------------------------------------------------------
12 |
13 | import math
14 | from omegaconf import OmegaConf
15 | from typing import Optional, Tuple, List
16 |
17 | import torch
18 | import torch.nn as nn
19 | from torch.nn import functional as F
20 | from torch.cuda.amp import autocast
21 |
22 |
23 | class MultiHeadSelfAttention(nn.Module):
24 | def __init__(self,
25 | ctx_len: int,
26 | cond_len: int,
27 | embed_dim: int,
28 | n_heads: int,
29 | attn_bias: bool,
30 | use_mask: bool = True):
31 | super().__init__()
32 | assert embed_dim % n_heads == 0
33 |
34 | # key, query, value projections for all heads
35 | self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
36 | self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
37 | self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
38 |
39 | # output projection
40 | self.proj = nn.Linear(embed_dim, embed_dim, attn_bias)
41 |
42 | self.n_heads = n_heads
43 | self.ctx_len = ctx_len
44 | self.use_mask = use_mask
45 | if self.use_mask:
46 | self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False)
47 | self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len)
48 | self.mask[:, :cond_len, :cond_len] = 1
49 |
50 | self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
51 | with torch.no_grad():
52 | ww = torch.zeros(1, 1, embed_dim)
53 | for i in range(embed_dim):
54 | ww[0, 0, i] = i / (embed_dim - 1)
55 | self.time_mix = nn.Parameter(ww)
56 |
57 | def forward(self, x, use_cache=False, layer_past=None):
58 | B, T, C = x.shape
59 |
60 | x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
61 | x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C)
62 |
63 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
64 | k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
65 | q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
66 | v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
67 |
68 | if use_cache:
69 | present = torch.stack([k, v])
70 |
71 | if layer_past is not None:
72 | past_key, past_value = layer_past
73 | k = torch.cat([past_key, k], dim=-2)
74 | v = torch.cat([past_value, v], dim=-2)
75 |
76 | if use_cache and layer_past is not None:
77 | # Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K)
78 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
79 | att = F.softmax(att, dim=-1)
80 | y = torch.bmm(att, v) # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs)
81 | else:
82 | # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T)
83 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
84 | if self.use_mask:
85 | mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T]
86 | att = att.masked_fill(mask == 0, float('-inf'))
87 | att = F.softmax(att, dim=-1)
88 | y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
89 | y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side
90 |
91 | # output projection
92 | y = self.proj(y)
93 |
94 | if use_cache:
95 | return y.transpose(0, 1).contiguous(), present # (T, B, C) -> (B, T, C)
96 | else:
97 | return y.transpose(0, 1).contiguous() # (T, B, C) -> (B, T, C)
98 |
99 | class FFN(nn.Module):
100 | def __init__(self, embed_dim, mlp_bias):
101 | super().__init__()
102 | self.p0 = nn.Linear(embed_dim, 4 * embed_dim, bias=mlp_bias)
103 | self.p1 = nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias)
104 |
105 | def forward(self, x):
106 | x = self.p0(x)
107 | # x = F.gelu(x)
108 | x = torch.square(torch.relu(x))
109 | x = self.p1(x)
110 | return x
111 |
112 | class Block(nn.Module):
113 | def __init__(self,
114 | ctx_len: int,
115 | cond_len: int,
116 | embed_dim: int,
117 | n_heads: int,
118 | mlp_bias: bool,
119 | attn_bias: bool):
120 | super().__init__()
121 | self.ln1 = nn.LayerNorm(embed_dim)
122 | self.ln2 = nn.LayerNorm(embed_dim)
123 |
124 | self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
125 | cond_len=cond_len,
126 | embed_dim=embed_dim,
127 | n_heads=n_heads,
128 | attn_bias=attn_bias,
129 | use_mask=True)
130 | self.mlp = FFN(embed_dim=embed_dim, mlp_bias=mlp_bias)
131 |
132 | def forward(self, x):
133 | x = x + self.attn(self.ln1(x))
134 | x = x + self.mlp(self.ln2(x))
135 |
136 | return x
137 |
138 | def sample(self, x, layer_past=None):
139 | attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
140 | x = x + attn
141 | x = x + self.mlp(self.ln2(x))
142 |
143 | return x, present
144 |
145 |
146 | class GPT(nn.Module):
147 | def __init__(self,
148 | vocab_cond_size: int,
149 | vocab_img_size: int,
150 | embed_dim: int,
151 | cond_num_tokens: int,
152 | img_num_tokens: int,
153 | n_heads: int,
154 | n_layers: int,
155 | mlp_bias: bool = True,
156 | attn_bias: bool = True) -> None:
157 | super().__init__()
158 | self.img_num_tokens = img_num_tokens
159 | self.vocab_cond_size = vocab_cond_size
160 |
161 | # condition token and position embedding
162 | self.tok_emb_cond = nn.Embedding(vocab_cond_size, embed_dim)
163 | self.pos_emb_cond = nn.Parameter(torch.zeros(1, cond_num_tokens, embed_dim))
164 |
165 | # input token and position embedding
166 | self.tok_emb_code = nn.Embedding(vocab_img_size, embed_dim)
167 | self.pos_emb_code = nn.Parameter(torch.zeros(1, img_num_tokens, embed_dim))
168 |
169 | # transformer blocks
170 | self.blocks = [Block(ctx_len=cond_num_tokens + img_num_tokens,
171 | cond_len=cond_num_tokens,
172 | embed_dim=embed_dim,
173 | n_heads=n_heads,
174 | mlp_bias=mlp_bias,
175 | attn_bias=attn_bias) for i in range(1, n_layers+1)]
176 | self.blocks = nn.Sequential(*self.blocks)
177 |
178 | # head
179 | self.layer_norm = nn.LayerNorm(embed_dim)
180 | self.head = nn.Linear(embed_dim, vocab_img_size, bias=False)
181 |
182 | self.apply(self._init_weights)
183 |
184 | def _init_weights(self, module: nn.Module) -> None:
185 | if isinstance(module, (nn.Linear, nn.Embedding)):
186 | module.weight.data.normal_(mean=0.0, std=0.02)
187 | if isinstance(module, nn.Linear) and module.bias is not None:
188 | module.bias.data.zero_()
189 | elif isinstance(module, nn.LayerNorm):
190 | module.bias.data.zero_()
191 | module.weight.data.fill_(1.0)
192 |
193 | def forward(self,
194 | codes: torch.LongTensor,
195 | conds: torch.LongTensor) -> torch.FloatTensor:
196 |
197 | codes = codes.view(codes.shape[0], -1)
198 | codes = self.tok_emb_code(codes)
199 | conds = self.tok_emb_cond(conds)
200 |
201 | codes = codes + self.pos_emb_code
202 | conds = conds + self.pos_emb_cond
203 |
204 | x = torch.cat([conds, codes], axis=1).contiguous()
205 | x = self.blocks(x)
206 | x = self.layer_norm(x)
207 |
208 | x = x[:, conds.shape[1]-1:-1].contiguous()
209 | logits = self.head(x)
210 |
211 | return logits
212 |
213 | def sample(self,
214 | conds: torch.LongTensor,
215 | top_k: Optional[float] = None,
216 | top_p: Optional[float] = None,
217 | softmax_temperature: float = 1.0,
218 | use_fp16: bool = True) -> Tuple[torch.FloatTensor, torch.LongTensor]:
219 |
220 | past = codes = logits = None
221 |
222 | for i in range(self.img_num_tokens):
223 | if codes is None:
224 | codes_ = None
225 | pos_code = None
226 | else:
227 | codes_ = codes.clone().detach()
228 | codes_ = codes_[:, -1:]
229 | pos_code = self.pos_emb_code[:, i-1:i, :]
230 |
231 | logits_, presents = self.sample_step(codes_, conds, pos_code, use_fp16, past)
232 |
233 | logits_ = logits_.to(dtype=torch.float32)
234 | logits_ = logits_ / softmax_temperature
235 |
236 | presents = torch.stack(presents).clone().detach()
237 | if past is None:
238 | past = [presents]
239 | else:
240 | past.append(presents)
241 |
242 | if top_k is not None:
243 | v, ix = torch.topk(logits_, top_k)
244 | logits_[logits_ < v[:, [-1]]] = -float('Inf')
245 | probs = F.softmax(logits_, dim=-1)
246 |
247 | if top_p is not None:
248 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
249 | cum_probs = torch.cumsum(sorted_probs, dim=-1)
250 |
251 | sorted_idx_remove_cond = cum_probs >= top_p
252 |
253 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
254 | sorted_idx_remove_cond[..., 0] = 0
255 |
256 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
257 | probs = probs.masked_fill(indices_to_remove, 0.0)
258 | probs = probs / torch.sum(probs, dim=-1, keepdim=True)
259 |
260 | idx = torch.multinomial(probs, num_samples=1).clone().detach()
261 | codes = idx if codes is None else torch.cat([codes, idx], axis=1)
262 | logits = logits_ if logits is None else torch.cat([logits, logits_], axis=1)
263 |
264 | del past
265 |
266 | return logits, codes
267 |
268 | def sample_step(self,
269 | codes: torch.LongTensor,
270 | conds: torch.LongTensor,
271 | pos_code: torch.LongTensor,
272 | use_fp16: bool = True,
273 | past: Optional[torch.FloatTensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
274 |
275 | with autocast(enabled=use_fp16):
276 | presents = []
277 |
278 | if codes is None:
279 | assert past is None
280 | conds = self.tok_emb_cond(conds)
281 | x = conds + self.pos_emb_cond
282 |
283 | for i, block in enumerate(self.blocks):
284 | x, present = block.sample(x, layer_past=None)
285 | presents.append(present)
286 | x = self.layer_norm(x)
287 | x = x[:, conds.shape[1]-1].contiguous()
288 | else:
289 | assert past is not None
290 | codes = self.tok_emb_code(codes)
291 | x = codes + pos_code
292 |
293 | past = torch.cat(past, dim=-2)
294 | for i, block in enumerate(self.blocks):
295 | x, present = block.sample(x, layer_past=past[i])
296 | presents.append(present)
297 |
298 | x = self.layer_norm(x)
299 | x = x[:, -1].contiguous()
300 |
301 | logits = self.head(x)
302 |
303 | return logits, presents
304 |
305 |
306 | class RQTransformer(nn.Module):
307 | def __init__(self,
308 | vocab_cond_size: int,
309 | vocab_img_size: int,
310 | embed_dim: int,
311 | cond_num_tokens: int,
312 | img_num_tokens: int,
313 | depth_num_tokens: int,
314 | spatial_n_heads: int,
315 | depth_n_heads: int,
316 | spatial_n_layers: int,
317 | depth_n_layers: int,
318 | mlp_bias: bool = True,
319 | attn_bias: bool = True) -> None:
320 | super().__init__()
321 | self.img_num_tokens = img_num_tokens
322 | self.depth_num_tokens = depth_num_tokens
323 | self.vocab_img_size = vocab_img_size
324 |
325 | # condition token and position embedding
326 | self.tok_emb_cond = nn.Embedding(vocab_cond_size, embed_dim)
327 | self.pos_emb_cond = nn.Parameter(torch.rand(1, cond_num_tokens, embed_dim))
328 |
329 | # spatial token and position embedding
330 | self.tok_emb_code = nn.Embedding(vocab_img_size, embed_dim)
331 | self.pos_emb_code = nn.Parameter(torch.rand(1, img_num_tokens, embed_dim))
332 |
333 | # depth position embedding
334 | self.pos_emb_depth = nn.Parameter(torch.rand(1, depth_num_tokens-1, embed_dim))
335 |
336 | # spatial transformer
337 | self.spatial_transformer = [Block(ctx_len=cond_num_tokens + img_num_tokens,
338 | cond_len=cond_num_tokens,
339 | embed_dim=embed_dim,
340 | n_heads=spatial_n_heads,
341 | mlp_bias=mlp_bias,
342 | attn_bias=attn_bias) for i in range(1, spatial_n_layers+1)]
343 | self.spatial_transformer = nn.Sequential(*self.spatial_transformer)
344 |
345 | # depth transformer
346 | self.depth_transformer = [Block(ctx_len=depth_num_tokens,
347 | cond_len=0,
348 | embed_dim=embed_dim,
349 | n_heads=depth_n_heads,
350 | mlp_bias=mlp_bias,
351 | attn_bias=attn_bias) for i in range(1, depth_n_layers+1)]
352 | self.depth_transformer = nn.Sequential(*self.depth_transformer)
353 |
354 | # head
355 | self.ln_spatial = nn.LayerNorm(embed_dim)
356 | self.ln_depth = nn.LayerNorm(embed_dim)
357 | self.head = nn.Linear(embed_dim, vocab_img_size, bias=False)
358 |
359 | self.apply(self._init_weights)
360 |
361 | def _init_weights(self, module: nn.Module) -> None:
362 | if isinstance(module, (nn.Linear, nn.Embedding)):
363 | module.weight.data.normal_(mean=0.0, std=0.02)
364 | if isinstance(module, nn.Linear) and module.bias is not None:
365 | module.bias.data.zero_()
366 | elif isinstance(module, nn.LayerNorm):
367 | module.bias.data.zero_()
368 | module.weight.data.fill_(1.0)
369 |
370 | def forward(self,
371 | codes: torch.LongTensor,
372 | conds: torch.LongTensor) -> torch.FloatTensor:
373 |
374 | codes = codes.view(codes.shape[0], -1, codes.shape[-1])
375 | codes = self.tok_emb_code(codes)
376 | conds = self.tok_emb_cond(conds)
377 |
378 | codes_cumsum = codes.cumsum(-2)  # partial sums over the depth dimension
379 | codes_sum = codes_cumsum[..., -1, :]
380 |
381 | codes = codes_sum + self.pos_emb_code
382 | conds = conds + self.pos_emb_cond
383 |
384 | h = torch.cat([conds, codes], axis=1).contiguous()
385 | h = self.ln_spatial(self.spatial_transformer(h))
386 | h = h[:, conds.shape[1]-1:-1].contiguous()
387 |
388 | v = codes_cumsum[..., :-1, :] + self.pos_emb_depth
389 | v = torch.cat([h.unsqueeze(2), v], axis=2).contiguous()
390 |
391 | v = v.view(-1, *v.shape[2:])
392 | v = self.depth_transformer(v)
393 | logits = self.head(self.ln_depth(v))
394 |
395 | return logits
396 |
397 | def sample(self,
398 | conds: torch.LongTensor,
399 | top_k: Optional[float] = None,
400 | top_p: Optional[float] = None,
401 | softmax_temperature: float = 1.0,
402 | use_fp16: bool = True) -> Tuple[torch.FloatTensor, torch.LongTensor]:
403 |
404 | past = codes = logits = None
405 | B, T, D, S = conds.shape[0], self.img_num_tokens, self.depth_num_tokens, self.vocab_img_size
406 |
407 | for i in range(self.img_num_tokens):
408 | depth_past = None
409 |
410 | if codes is None:
411 | codes_ = None
412 | pos_code = None
413 | else:
414 | codes_ = codes.clone().detach()
415 | codes_ = codes_[:, -self.depth_num_tokens:]
416 | pos_code = self.pos_emb_code[:, i-1:i, :]
417 |
418 | hidden, presents = self.sample_spatial_step(codes_, conds, pos_code, use_fp16, past)
419 |
420 | presents = torch.stack(presents).clone().detach()
421 | if past is None:
422 | past = [presents]
423 | else:
424 | past.append(presents)
425 |
426 | last_len = 0 if codes is None else codes.shape[-1]
427 |
428 | for d in range(self.depth_num_tokens):
429 | if depth_past is None:
430 | codes_ = None
431 | pos_depth = None
432 | else:
433 | codes_ = codes.clone().detach()
434 | codes_ = codes_[:, last_len:]
435 | pos_depth = self.pos_emb_depth[:, d-1:d, :]
436 |
437 | logits_, depth_presents = self.sample_depth_step(codes_, hidden, pos_depth, use_fp16, depth_past)
438 |
439 | logits_ = logits_.to(dtype=torch.float32)
440 | logits_ = logits_ / softmax_temperature
441 |
442 | depth_presents = torch.stack(depth_presents).clone().detach()
443 | if depth_past is None:
444 | depth_past = [depth_presents]
445 | else:
446 | depth_past.append(depth_presents)
447 |
448 | if top_k is not None:
449 | v, ix = torch.topk(logits_, top_k)
450 | logits_[logits_ < v[:, [-1]]] = -float('Inf')
451 | probs = F.softmax(logits_, dim=-1)
452 |
453 | if top_p is not None:
454 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
455 | cum_probs = torch.cumsum(sorted_probs, dim=-1)
456 |
457 | sorted_idx_remove_cond = cum_probs >= top_p
458 |
459 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
460 | sorted_idx_remove_cond[..., 0] = 0
461 |
462 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
463 | probs = probs.masked_fill(indices_to_remove, 0.0)
464 | probs = probs / torch.sum(probs, dim=-1, keepdim=True)
465 |
466 | idx = torch.multinomial(probs, num_samples=1).clone().detach()
467 | codes = idx if codes is None else torch.cat([codes, idx], axis=1)
468 | logits = logits_ if logits is None else torch.cat([logits, logits_], axis=1)
469 |
470 | del depth_past
471 |
472 | del past
473 |
474 | codes = codes.view(B, T, D)
475 | logits = logits.view(B * T, D, S)
476 |
477 | return logits, codes
478 |
479 | def sample_spatial_step(self,
480 | codes: torch.LongTensor,
481 | conds: torch.LongTensor,
482 | pos_code: torch.LongTensor,
483 | use_fp16: bool = True,
484 | past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
485 |
486 | with autocast(enabled=use_fp16):
487 | presents = []
488 |
489 | if codes is None:
490 | assert past is None
491 | conds = self.tok_emb_cond(conds)
492 | x = conds + self.pos_emb_cond
493 |
494 | for i, block in enumerate(self.spatial_transformer):
495 | x, present = block.sample(x, layer_past=None)
496 | presents.append(present)
497 | x = self.ln_spatial(x)
498 | x = x[:, conds.shape[1]-1:conds.shape[1]].contiguous()
499 | else:
500 | assert past is not None
501 | codes = self.tok_emb_code(codes)
502 | x = codes.sum(1, keepdim=True) + pos_code
503 |
504 | past = torch.cat(past, dim=-2)
505 | for i, block in enumerate(self.spatial_transformer):
506 | x, present = block.sample(x, layer_past=past[i])
507 | presents.append(present)
508 |
509 | x = self.ln_spatial(x)
510 | x = x[:, -1:].contiguous()
511 |
512 | return x, presents
513 |
514 | def sample_depth_step(self,
515 | codes: torch.LongTensor,
516 | hidden: torch.FloatTensor,
517 | pos_depth: torch.LongTensor,
518 | use_fp16: bool = True,
519 | past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
520 |
521 | with autocast(enabled=use_fp16):
522 | presents = []
523 |
524 | if codes is None:
525 | assert past is None
526 | x = hidden
527 |
528 | for i, block in enumerate(self.depth_transformer):
529 | x, present = block.sample(x, layer_past=None)
530 | presents.append(present)
531 | x = self.ln_depth(x)
532 | else:
533 | assert past is not None
534 | codes = self.tok_emb_code(codes)
535 | x = codes.sum(1, keepdim=True) + pos_depth
536 |
537 | past = torch.cat(past, dim=-2)
538 | for i, block in enumerate(self.depth_transformer):
539 | x, present = block.sample(x, layer_past=past[i])
540 | presents.append(present)
541 |
542 | x = self.ln_depth(x)
543 | x = x[:, -1].contiguous()
544 |
545 | logits = self.head(x)
546 |
547 | return logits, presents
548 |
--------------------------------------------------------------------------------
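The top-k / top-p (nucleus) filtering used inside `sample` above can be read as a standalone helper. The sketch below is ours (the name `top_k_top_p_probs` is not in the repository) and only restates the same filtering logic:

```python
import torch
import torch.nn.functional as F

def top_k_top_p_probs(logits: torch.Tensor, top_k: int = None, top_p: float = None) -> torch.Tensor:
    """Return sampling probabilities after optional top-k and top-p (nucleus) filtering."""
    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        logits = logits.masked_fill(logits < v[..., [-1]], -float('Inf'))  # keep only the k best logits
    probs = F.softmax(logits, dim=-1)

    if top_p is not None:
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        remove = cum_probs >= top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift so the first token crossing the threshold survives
        remove[..., 0] = False
        indices_to_remove = remove.scatter(-1, sorted_indices, remove)  # map the mask back to vocabulary order
        probs = probs.masked_fill(indices_to_remove, 0.0)
        probs = probs / probs.sum(dim=-1, keepdim=True)  # renormalize the surviving mass

    return probs

# e.g. next_code = torch.multinomial(top_k_top_p_probs(logits, top_k=100, top_p=0.95), num_samples=1)
```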
/enhancing/modules/stage2/transformer.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | from typing import Optional, Tuple, List, Dict, Union, Any
11 | from omegaconf import OmegaConf
12 |
13 | import torch
14 | import torch.nn as nn
15 | from torch.optim import lr_scheduler
16 | import torch.nn.functional as F
17 | import pytorch_lightning as pl
18 |
19 | from .layers import *
20 | from ...utils.general import initialize_from_config
21 |
22 |
23 | class CondTransformer(pl.LightningModule):
24 | def __init__(self, cond_key: str, cond: OmegaConf, stage1: OmegaConf, transformer: OmegaConf,
25 | path: Optional[str] = None, ignore_keys: List[str] = list(),
26 | code_shape: List[int] = None, scheduler: Optional[OmegaConf] = None) -> None:
27 | super().__init__()
28 |
29 | # get condition key, code shape and scheduler
30 | self.cond_key = cond_key
31 | self.code_shape = code_shape
32 | self.scheduler = scheduler
33 |
34 | # load condition model
35 | self.cond_model = initialize_from_config(cond)
36 |
37 | # load stage1 model
38 | self.stage1_model = initialize_from_config(stage1)
39 |
40 | # load transformer
41 | self.transformer = initialize_from_config(transformer)
42 |
43 | # make the parameters in stage1 model not trainable
44 | self.stage1_model.eval()
45 | for p in self.stage1_model.parameters():
46 | p.requires_grad = False
47 |
48 | # make the parameters in condition model not trainable
49 | self.cond_model.eval()
50 | for p in self.cond_model.parameters():
51 | p.requires_grad = False
52 |
53 | if path is not None:
54 | self.init_from_ckpt(path, ignore_keys)
55 |
56 | def forward(self,
57 | codes: torch.LongTensor,
58 | conds: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
59 |
60 | conds = conds.view(conds.shape[0], -1)
61 | logits = self.transformer(codes, conds)
62 |
63 | codes = codes.view(-1, codes.shape[-1])
64 |
65 | return logits, codes
66 |
67 | def init_from_ckpt(self, path: str, ignore_keys: List[str] = list()):
68 | sd = torch.load(path, map_location="cpu")["state_dict"]
69 | keys = list(sd.keys())
70 | for k in keys:
71 | for ik in ignore_keys:
72 | if k.startswith(ik):
73 | print("Deleting key {} from state_dict.".format(k))
74 | del sd[k]
75 | self.load_state_dict(sd, strict=False)
76 | print(f"Restored from {path}")
77 |
78 | @torch.no_grad()
79 | def sample(self,
80 | conds: torch.LongTensor,
81 |                top_k: Optional[int] = None,
82 | top_p: Optional[float] = None,
83 | softmax_temperature: float = 1.0,
84 | use_fp16: bool = True) -> torch.FloatTensor:
85 |
86 | conds = conds.view(conds.shape[0], -1)
87 | logits, codes = self.transformer.sample(conds=conds, top_k=top_k, top_p=top_p,
88 | softmax_temperature=softmax_temperature,
89 | use_fp16=use_fp16)
90 |
91 | if self.code_shape is not None:
92 | codes = codes.view(codes.shape[0], *self.code_shape)
93 | pixels = self.stage1_model.decode_codes(codes).clamp(0, 1)
94 |
95 | return pixels
96 |
97 | def get_input(self, batch: Tuple[Any, Any], key: str) -> torch.FloatTensor:
98 | x = batch[key]
99 |
100 | if len(x.shape) == 3:
101 | x = x[..., None]
102 | if x.dtype == torch.double:
103 | x = x.float()
104 |
105 | return x.contiguous()
106 |
107 | def shared_step(self, batch: Tuple[Any, Any], batch_idx: int) -> torch.FloatTensor:
108 | images = self.get_input(batch, self.stage1_model.image_key)
109 | conds = self.get_input(batch, self.cond_key)
110 |
111 | with torch.no_grad():
112 | codes = self.stage1_model.encode_codes(images).detach()
113 | conds = self.cond_model.encode_codes(conds).detach()
114 |
115 | logits, codes = self(codes, conds)
116 | loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
117 |
118 | return loss
119 |
120 | def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
121 | loss = self.shared_step(batch, batch_idx)
122 | self.log("train/total_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
123 |
124 | return loss
125 |
126 | def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> torch.FloatTensor:
127 | loss = self.shared_step(batch, batch_idx)
128 | self.log("val/total_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
129 |
130 | return loss
131 |
132 | def configure_optimizers(self) -> torch.optim.Optimizer:
133 | """
134 | Following minGPT:
135 | This long function is unfortunately doing something very simple and is being very defensive:
136 | We are separating out all parameters of the model into two buckets: those that will experience
137 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
138 | We are then returning the PyTorch optimizer object.
139 | """
140 | # separate out all parameters to those that will and won't experience regularizing weight decay
141 | decay = set()
142 | no_decay = set()
143 | whitelist_weight_modules = (torch.nn.Linear, )
144 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
145 | for mn, m in self.transformer.named_modules():
146 | for pn, p in m.named_parameters():
147 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
148 |
149 | if pn.endswith('bias'):
150 | # all biases will not be decayed
151 | no_decay.add(fpn)
152 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
153 | # weights of whitelist modules will be weight decayed
154 | decay.add(fpn)
155 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
156 | # weights of blacklist modules will NOT be weight decayed
157 | no_decay.add(fpn)
158 | elif 'time_' in pn: # for RWKV
159 | no_decay.add(fpn)
160 |
161 | # special case the position embedding parameter in the root GPT module as not decayed
162 | no_decay.add('pos_emb_cond')
163 | no_decay.add('pos_emb_code')
164 |
165 | if hasattr(self.transformer, 'pos_emb_depth'):
166 | no_decay.add('pos_emb_depth')
167 |
168 | # validate that we considered every parameter
169 | param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
170 | inter_params = decay & no_decay
171 | union_params = decay | no_decay
172 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
173 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay/ignored set!" \
174 | % (str(param_dict.keys() - union_params), )
175 |
176 | # create the pytorch optimizer object
177 | optim_groups = [
178 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
179 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
180 | ]
181 | optimizer = [torch.optim.Adam(optim_groups, lr=self.learning_rate, betas=(0.9, 0.96))]
182 | scheduler = []
183 |
184 |         if self.scheduler is not None:
185 |             self.scheduler.params.start = self.learning_rate
186 |             lr_schedule = initialize_from_config(self.scheduler)
187 | 
188 |             scheduler = [{
189 |                 'scheduler': lr_scheduler.LambdaLR(optimizer[0], lr_lambda=lr_schedule.schedule),
190 |                 'interval': 'step',
191 |                 'frequency': 1
192 |             }]
193 |
194 | return optimizer, scheduler
195 |
196 | def log_images(self, batch: Tuple[Any, Any], *args, **kwargs) -> Dict:
197 | log = dict()
198 |
199 | conds = self.get_input(batch, self.cond_key).to(self.device)
200 | cond_codes = self.cond_model.encode_codes(conds).detach()
201 |
202 | log["conditions"] = self.cond_model.to_img(conds)
203 | log["first samples"] = self.sample(cond_codes)
204 | log["second samples"] = self.sample(cond_codes)
205 |
206 | return log
207 |
--------------------------------------------------------------------------------
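As a usage sketch (not part of the repository): `CondTransformer` can be built from a YAML config via `initialize_from_config` and used for conditional generation. The config name, checkpoint handling and condition shape below are illustrative assumptions.

```python
import torch
from enhancing.utils.general import get_config_from_file, initialize_from_config

# Assumes a config whose model.target points at CondTransformer and whose `path`
# parameter points at a trained checkpoint; otherwise the weights are random.
config = get_config_from_file("configs/imagenet_gpt_vitvq_base.yaml")
model = initialize_from_config(config.model).eval()

# Condition codes as expected by the condition model, e.g. class indices of shape (B, 1).
conds = torch.randint(0, 1000, (4, 1))

# sample() returns decoded pixels in [0, 1]; use_fp16=False keeps the sketch CPU-friendly.
pixels = model.sample(conds, top_k=100, top_p=0.95, softmax_temperature=1.0, use_fp16=False)
```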
/enhancing/utils/callback.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import os
7 | import wandb
8 | import numpy as np
9 | from PIL import Image
10 | from pathlib import Path
11 | from omegaconf import OmegaConf
12 | from typing import Tuple, Generic, Dict
13 |
14 | import torch
15 | import torchvision
16 | import pytorch_lightning as pl
17 | from pytorch_lightning.utilities.distributed import rank_zero_only
18 | from pytorch_lightning.callbacks import Callback
19 |
20 |
21 | class SetupCallback(Callback):
22 | def __init__(self, config: OmegaConf, exp_config: OmegaConf, basedir: Path, logdir: str = "log", ckptdir:str = "ckpt") -> None:
23 | super().__init__()
24 | self.logdir = basedir / logdir
25 | self.ckptdir = basedir / ckptdir
26 | self.config = config
27 | self.exp_config = exp_config
28 |
29 | def on_pretrain_routine_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:
30 | if trainer.global_rank == 0:
31 | # Create logdirs and save configs
32 | os.makedirs(self.logdir, exist_ok=True)
33 | os.makedirs(self.ckptdir, exist_ok=True)
34 |
35 | print("Experiment config")
36 | print(self.exp_config.pretty())
37 |
38 | print("Model config")
39 | print(self.config.pretty())
40 |
41 |
42 | class ImageLogger(Callback):
43 | def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True, increase_log_steps: bool =True) -> None:
44 | super().__init__()
45 | self.batch_freq = batch_frequency
46 | self.max_images = max_images
47 | self.logger_log_images = {
48 | pl.loggers.WandbLogger: self._wandb,
49 | pl.loggers.TestTubeLogger: self._testtube,
50 | }
51 | self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
52 | if not increase_log_steps:
53 | self.log_steps = [self.batch_freq]
54 | self.clamp = clamp
55 |
56 | @rank_zero_only
57 | def _wandb(self, pl_module, images, batch_idx, split):
58 | #raise ValueError("No way wandb")
59 | grids = dict()
60 | for k in images:
61 | grid = torchvision.utils.make_grid(images[k])
62 | grids[f"{split}/{k}"] = wandb.Image(grid)
63 | pl_module.logger.experiment.log(grids)
64 |
65 | @rank_zero_only
66 | def _testtube(self, pl_module, images, batch_idx, split):
67 | for k in images:
68 | grid = torchvision.utils.make_grid(images[k])
69 | grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
70 |
71 | tag = f"{split}/{k}"
72 | pl_module.logger.experiment.add_image(
73 | tag, grid,
74 | global_step=pl_module.global_step)
75 |
76 | @rank_zero_only
77 | def log_local(self, save_dir: str, split: str, images: Dict,
78 | global_step: int, current_epoch: int, batch_idx: int) -> None:
79 | root = os.path.join(save_dir, "results", split)
80 | os.makedirs(root, exist_ok=True)
81 | for k in images:
82 | grid = torchvision.utils.make_grid(images[k], nrow=4)
83 |
84 | grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
85 | grid = grid.numpy()
86 | grid = (grid*255).astype(np.uint8)
87 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
88 | k,
89 | global_step,
90 | current_epoch,
91 | batch_idx)
92 | path = os.path.join(root, filename)
93 | os.makedirs(os.path.split(path)[0], exist_ok=True)
94 | Image.fromarray(grid).save(path)
95 |
96 | def log_img(self, pl_module: pl.LightningModule, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int, split: str = "train") -> None:
97 | if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0
98 | hasattr(pl_module, "log_images") and
99 | callable(pl_module.log_images) and
100 | self.max_images > 0):
101 | logger = type(pl_module.logger)
102 |
103 | is_train = pl_module.training
104 | if is_train:
105 | pl_module.eval()
106 |
107 | with torch.no_grad():
108 | images = pl_module.log_images(batch, split=split, pl_module=pl_module)
109 |
110 | for k in images:
111 | N = min(images[k].shape[0], self.max_images)
112 | images[k] = images[k][:N].detach().cpu()
113 | if self.clamp:
114 | images[k] = images[k].clamp(0, 1)
115 |
116 | self.log_local(pl_module.logger.save_dir, split, images,
117 | pl_module.global_step, pl_module.current_epoch, batch_idx)
118 |
119 | logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
120 | logger_log_images(pl_module, images, pl_module.global_step, split)
121 |
122 | if is_train:
123 | pl_module.train()
124 |
125 | def check_frequency(self, batch_idx: int) -> bool:
126 | if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
127 | try:
128 | self.log_steps.pop(0)
129 | except IndexError:
130 | pass
131 | return True
132 | return False
133 |
134 | def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
135 | outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int) -> None:
136 | self.log_img(pl_module, batch, batch_idx, split="train")
137 |
138 | def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
139 | outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor],
140 |                                 batch_idx: int, dataloader_idx: int) -> None:
141 | self.log_img(pl_module, batch, batch_idx, split="val")
142 |
--------------------------------------------------------------------------------
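For reference, `check_frequency` gives extra image logs early in training (at power-of-two batch indices up to `batch_frequency`) and then logs every `batch_frequency` batches. A minimal illustration of that schedule (ours):

```python
import numpy as np

batch_freq = 750  # default --batch_frequency in main.py
log_steps = [2 ** n for n in range(int(np.log2(batch_freq)) + 1)]
print(log_steps)  # [1, 2, 4, 8, ..., 512]: dense logging early on
# afterwards, every batch with batch_idx % batch_freq == 0 is logged as well
```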
/enhancing/utils/general.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import os
7 | import random
8 | import importlib
9 | import pathlib
10 | from typing import Tuple, List, Dict, ClassVar
11 | import numpy as np
12 | from omegaconf import OmegaConf
13 | from datetime import datetime
14 |
15 | import torch
16 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback
17 | from pytorch_lightning.loggers import WandbLogger
18 |
19 | from .callback import *
20 |
21 |
22 | def set_seed(seed: int):
23 | random.seed(seed)
24 | np.random.seed(seed)
25 | torch.manual_seed(seed)
26 | torch.cuda.manual_seed_all(seed)
27 |
28 |
29 | def get_obj_from_str(name: str, reload: bool = False) -> ClassVar:
30 | module, cls = name.rsplit(".", 1)
31 |
32 | if reload:
33 | module_imp = importlib.import_module(module)
34 | importlib.reload(module_imp)
35 |
36 | return getattr(importlib.import_module(module, package=None), cls)
37 |
38 |
39 | def initialize_from_config(config: OmegaConf) -> object:
40 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
41 |
42 |
43 | def setup_callbacks(exp_config: OmegaConf, config: OmegaConf) -> Tuple[List[Callback], WandbLogger]:
44 | now = datetime.now().strftime('%d%m%Y_%H%M%S')
45 | basedir = pathlib.Path("experiments", exp_config.name, now)
46 | os.makedirs(basedir, exist_ok=True)
47 |
48 | setup_callback = SetupCallback(config, exp_config, basedir)
49 | checkpoint_callback = ModelCheckpoint(
50 | dirpath=setup_callback.ckptdir,
51 | filename=exp_config.name+"-{epoch:02d}",
52 | monitor="train/total_loss",
53 | save_top_k=-1,
54 | verbose=False,
55 | )
56 | os.makedirs(setup_callback.logdir/'wandb', exist_ok=True)
57 | logger = WandbLogger(save_dir=str(setup_callback.logdir), name=exp_config.name+"_"+str(now))
58 | logger_img_callback = ImageLogger(exp_config.batch_frequency, exp_config.max_images)
59 |
60 | return [setup_callback, checkpoint_callback, logger_img_callback], logger
61 |
62 |
63 | def get_config_from_file(config_file: str) -> Dict:
64 | config_file = OmegaConf.load(config_file)
65 |
66 | if 'base_config' in config_file.keys():
67 | if config_file['base_config'] == "default_base":
68 | base_config = get_default_config()
69 | elif config_file['base_config'].endswith(".yaml"):
70 | base_config = get_config_from_file(config_file['base_config'])
71 |
72 |         config_file = {key: value for key, value in config_file.items() if key != "base_config"}
73 |
74 | return OmegaConf.merge(base_config, config_file)
75 |
76 | return config_file
77 |
--------------------------------------------------------------------------------
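`initialize_from_config` simply resolves the dotted `target` path and instantiates it with `params`, which is how models, datasets and schedulers are built throughout the repository. A minimal sketch (the config below is illustrative, not one of the shipped YAML files):

```python
from omegaconf import OmegaConf
from enhancing.utils.general import initialize_from_config

config = OmegaConf.create({
    "target": "enhancing.utils.scheduler.ExponentialDecayScheduler",
    "params": {"start": 1e-4, "end": 1e-6, "decay_every_step": 1000, "scale_factor": 1e-5},
})
scheduler = initialize_from_config(config)  # equivalent to ExponentialDecayScheduler(**params)
print(type(scheduler).__name__, scheduler.start, scheduler.end)
```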
/enhancing/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
7 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 |
13 | class BaseScheduler:
14 | def __init__(self):
15 | pass
16 |
17 | def schedule(self, n: int) -> float:
18 | pass
19 |
20 | def __call__(self, n: int) -> float:
21 | assert hasattr(self, 'start')
22 |
23 | return self.schedule(n) * self.start
24 |
25 |
26 | class ExponentialDecayScheduler(BaseScheduler):
27 | def __init__(self, start: float, end: float, decay_every_step: int, scale_factor: float) -> None:
28 | super().__init__()
29 | self.decay_every_step = decay_every_step
30 | self.scale_factor = scale_factor
31 |
32 | self.start = start
33 | self.end = end
34 | self.current = start
35 |
36 | def schedule(self, n: int) -> float:
37 | if not n % self.decay_every_step:
38 | res = np.exp(-self.scale_factor*n) * self.start
39 | self.current = max(self.end, res)
40 |
41 | return self.current / self.start
42 |
43 |
44 | class LambdaWarmUpCosineScheduler(BaseScheduler):
45 | def __init__(self, warm_up_steps: int, max_decay_steps: int, min_: float, max_: float, start: float) -> None:
46 | super().__init__()
47 | assert (max_decay_steps >= warm_up_steps)
48 |
49 | self.warm_up_steps = warm_up_steps
50 | self.start = start
51 | self.min_ = min_
52 | self.max_ = max_
53 | self.max_decay_steps = max_decay_steps
54 | self.last = 0.
55 |
56 | def schedule(self, n: int) -> float:
57 | if n < self.warm_up_steps:
58 | res = (self.max_ - self.start) / self.warm_up_steps * n + self.start
59 | self.last = res
60 | else:
61 | t = (n - self.warm_up_steps) / (self.max_decay_steps - self.warm_up_steps)
62 | t = min(t, 1.0)
63 | res = self.min_ + 0.5 * (self.max_ - self.min_) * (1 + np.cos(t * np.pi))
64 | self.last = res
65 |
66 | return res / self.start
67 |
68 |
69 | class LambdaWarmUpLinearScheduler(BaseScheduler):
70 | def __init__(self, warm_up_steps: int, max_decay_steps: int, min_: float, max_: float, start: float) -> None:
71 | super().__init__()
72 | assert (max_decay_steps >= warm_up_steps)
73 |
74 | self.warm_up_steps = warm_up_steps
75 | self.start = start
76 | self.min_ = min_
77 | self.max_ = max_
78 | self.max_decay_steps = max_decay_steps
79 | self.last = 0.
80 |
81 | def schedule(self, n: int) -> float:
82 | if n < self.warm_up_steps:
83 | res = (self.max_ - self.start) / self.warm_up_steps * n + self.start
84 | self.last = res
85 | else:
86 |             res = self.min_ + (self.max_ - self.min_) * (self.max_decay_steps - n) / self.max_decay_steps
87 | self.last = res
88 |
89 | return res / self.start
90 |
--------------------------------------------------------------------------------
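These schedulers return a multiplier relative to `start`, which is why the Lightning modules pair them with `torch.optim.lr_scheduler.LambdaLR`. A minimal sketch of that pairing (the model and parameter values are illustrative):

```python
import torch
from torch.optim import lr_scheduler
from enhancing.utils.scheduler import LambdaWarmUpCosineScheduler

base_lr = 4.5e-6
schedule = LambdaWarmUpCosineScheduler(warm_up_steps=1000, max_decay_steps=100000,
                                       min_=1e-7, max_=base_lr, start=base_lr)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=base_lr)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule.schedule)

for step in range(3):
    optimizer.step()
    scheduler.step()  # effective lr = base_lr * schedule.schedule(step + 1)
```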
/enhancing/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from CLIP (https://github.com/openai/CLIP)
3 | # Copyright (c) 2021 OpenAI. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import os
7 | import ftfy
8 | import html
9 | import regex as re
10 | from pathlib import Path
11 | from functools import lru_cache
12 | from typing import Optional, List, Tuple, Dict, Set
13 |
14 | import torch
15 |
16 | @lru_cache()
17 | def default_bpe():
18 | return 'assets/vocab/bpe_simple_vocab_16e6.txt'
19 |
20 | @lru_cache()
21 | def bytes_to_unicode() -> Dict:
22 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
23 | cs = bs[:]
24 | n = 0
25 | for b in range(2 ** 8):
26 | if b not in bs:
27 | bs.append(b)
28 | cs.append(2 ** 8 + n)
29 | n += 1
30 | cs = [chr(n) for n in cs]
31 | return dict(zip(bs, cs))
32 |
33 | def get_pairs(word: Tuple[str, ...]) -> Set[Tuple[str, str]]:
34 | pairs = set()
35 | prev_char = word[0]
36 | for char in word[1:]:
37 | pairs.add((prev_char, char))
38 | prev_char = char
39 | return pairs
40 |
41 | def basic_clean(text: str) -> str:
42 | text = ftfy.fix_text(text)
43 | text = html.unescape(html.unescape(text))
44 | return text.strip()
45 |
46 | def whitespace_clean(text: str) -> str:
47 | text = re.sub(r'\s+', ' ', text)
48 | text = text.strip()
49 | return text
50 |
51 | class SimpleTokenizer:
52 | def __init__(self, bpe_path: str = default_bpe(), text_length: int = 256,
53 | truncate_captions: bool = True) -> None:
54 | self.context_length = text_length
55 | self.truncate_text = truncate_captions
56 | self.byte_encoder = bytes_to_unicode()
57 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
58 | merges = Path(bpe_path).read_text(encoding='utf8').split('\n')
59 | merges = merges[1:49152 - 256 - 2 + 1]
60 | merges = [tuple(merge.split()) for merge in merges]
61 | vocab = list(bytes_to_unicode().values())
62 |         vocab = vocab + [v + '</w>' for v in vocab]
63 | for merge in merges:
64 | vocab.append(''.join(merge))
65 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
66 |
67 | self.vocab_size = 49408
68 |
69 | self.encoder = dict(zip(vocab, range(len(vocab))))
70 | self.decoder = {v: k for k, v in self.encoder.items()}
71 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
72 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
73 | self.pat = re.compile(
74 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
75 | re.IGNORECASE)
76 |
77 |     def bpe(self, token: str) -> str:
78 |         if token in self.cache:
79 |             return self.cache[token]
80 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
81 |         pairs = get_pairs(word)
82 | 
83 |         if not pairs:
84 |             return token + '</w>'
85 |
86 | while True:
87 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
88 | if bigram not in self.bpe_ranks:
89 | break
90 | first, second = bigram
91 | new_word = []
92 | i = 0
93 | while i < len(word):
94 | try:
95 | j = word.index(first, i)
96 | new_word.extend(word[i:j])
97 | i = j
98 | except:
99 | new_word.extend(word[i:])
100 | break
101 |
102 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
103 | new_word.append(first + second)
104 | i += 2
105 | else:
106 | new_word.append(word[i])
107 | i += 1
108 | new_word = tuple(new_word)
109 | word = new_word
110 | if len(word) == 1:
111 | break
112 | else:
113 | pairs = get_pairs(word)
114 | word = ' '.join(word)
115 | self.cache[token] = word
116 | return word
117 |
118 | def encode(self, text: str) -> List[int]:
119 | bpe_tokens = []
120 | text = whitespace_clean(basic_clean(text)).lower()
121 | for token in re.findall(self.pat, text):
122 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
123 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
124 | return bpe_tokens
125 |
126 | def decode(self, tokens: List[int], remove_start_end: bool = True, pad_tokens: Optional[Set] = set()) -> str:
127 | if torch.is_tensor(tokens):
128 | tokens = tokens.tolist()
129 |
130 | if remove_start_end:
131 |             tokens = [token for token in tokens if token not in (49406, 49407, 0)]
132 |         text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens])
133 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
134 | return text
135 |
136 | def tokenize(self, texts: str) -> List[int]:
137 | if isinstance(texts, str):
138 | texts = [texts]
139 | #assert type(texts) == list, f"texts is {texts}"
140 | all_tokens = [self.encode(text) for text in texts]
141 | result = torch.zeros(len(all_tokens), self.context_length, dtype=torch.long)
142 |
143 | for i, tokens in enumerate(all_tokens):
144 | if len(tokens) > self.context_length:
145 | if self.truncate_text:
146 | tokens = tokens[:self.context_length]
147 | else:
148 | raise RuntimeError(f"Input {texts[i]} is too long for context length {self.context_length}")
149 | result[i, :len(tokens)] = torch.tensor(tokens)
150 |
151 | return result
152 |
--------------------------------------------------------------------------------
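A usage sketch for the tokenizer (ours; it assumes the script is run from the repository root so the default BPE vocab under `assets/vocab/` resolves):

```python
from enhancing.utils.tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer(text_length=256, truncate_captions=True)

tokens = tokenizer.tokenize(["A photograph of an astronaut riding a horse"])
print(tokens.shape)                 # torch.Size([1, 256]), zero-padded LongTensor
print(tokenizer.decode(tokens[0]))  # roughly round-trips the lower-cased caption
```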
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: enhancing
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.0
9 | - pytorch=1.7.1
10 | - torchvision=0.8.2
11 | - numpy=1.19.2
12 | - ninja=1.10.2
13 | - pip:
14 | - ftfy==6.1.1
15 | - lpips==0.1.4
16 | - regex==2021.10.8
17 | - pytorch-lightning==1.5.10
18 | - einops==0.3.0
19 | - omegaconf==2.0.0
20 | - lmdb==1.0.0
21 | - wandb==0.12.21
22 | - git+https://github.com/openai/CLIP.git
23 | - albumentations==0.4.3
24 | - kornia==0.5.11
25 | - Pillow==9.0.1
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Enhancing Transformers
3 | # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
4 | # Licensed under the MIT License [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 |
7 | import os
8 | import sys
9 | import argparse
10 | from pathlib import Path
11 | from omegaconf import OmegaConf
12 | import pytorch_lightning as pl
13 |
14 | from enhancing.utils.general import get_config_from_file, initialize_from_config, setup_callbacks
15 |
16 | if __name__ == '__main__':
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('-c', '--config', type=str, required=True)
19 | parser.add_argument('-s', '--seed', type=int, default=0)
20 | parser.add_argument('-nn', '--num_nodes', type=int, default=1)
21 | parser.add_argument('-ng', '--num_gpus', type=int, default=1)
22 | parser.add_argument('-u', '--update_every', type=int, default=1)
23 | parser.add_argument('-e', '--epochs', type=int, default=100)
24 | parser.add_argument('-lr', '--base_lr', type=float, default=4.5e-6)
25 | parser.add_argument('-a', '--use_amp', default=False, action='store_true')
26 | parser.add_argument('-b', '--batch_frequency', type=int, default=750)
27 | parser.add_argument('-m', '--max_images', type=int, default=4)
28 | args = parser.parse_args()
29 |
30 | # Set random seed
31 | pl.seed_everything(args.seed)
32 |
33 | # Load configuration
34 | config = get_config_from_file(Path("configs")/(args.config+".yaml"))
35 | exp_config = OmegaConf.create({"name": args.config, "epochs": args.epochs, "update_every": args.update_every,
36 | "base_lr": args.base_lr, "use_amp": args.use_amp, "batch_frequency": args.batch_frequency,
37 | "max_images": args.max_images})
38 |
39 | # Build model
40 | model = initialize_from_config(config.model)
41 | model.learning_rate = exp_config.base_lr
42 |
43 | # Setup callbacks
44 | callbacks, logger = setup_callbacks(exp_config, config)
45 |
46 | # Build data modules
47 | data = initialize_from_config(config.dataset)
48 | data.prepare_data()
49 |
50 | # Build trainer
51 | trainer = pl.Trainer(max_epochs=exp_config.epochs,
52 | precision=16 if exp_config.use_amp else 32,
53 | callbacks=callbacks,
54 | gpus=args.num_gpus,
55 | num_nodes=args.num_nodes,
56 | strategy="ddp" if args.num_nodes > 1 or args.num_gpus > 1 else None,
57 | accumulate_grad_batches=exp_config.update_every,
58 | logger=logger)
59 |
60 | # Train
61 | trainer.fit(model, data)
62 |
--------------------------------------------------------------------------------
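Training is launched by pointing `main.py` at one of the YAML files under `configs/` (without the `.yaml` extension), for example `python main.py --config imagenet_vitvq_base --num_gpus 1 --epochs 100` (assuming `configs/imagenet_vitvq_base.yaml` exists); the remaining flags above control the seed, mixed precision, gradient accumulation and image-logging frequency.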
/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy==6.1.1
2 | lpips==0.1.4
3 | regex==2021.10.8
4 | torch==1.7.1
5 | torchvision==0.8.2
6 | pytorch-lightning==1.5.10
7 | einops==0.3.0
8 | omegaconf==2.0.0
9 | numpy==1.19.2
10 | lmdb==1.0.0
11 | wandb==0.12.21
12 | git+https://github.com/openai/CLIP.git
13 | albumentations==0.4.3
14 | kornia==0.5.11
15 | Pillow==9.0.1
--------------------------------------------------------------------------------