├── LICENSE ├── README.md ├── data ├── art_prompts.csv ├── coco_30k.csv ├── famous_art_prompts.csv ├── generic_artists_prompts.csv ├── imagenet_prompts.csv ├── niche_art_prompts.csv ├── short_niche_art_prompts.csv └── unsafe-prompts4703.csv ├── docker ├── Dockerfile_lora_animation ├── Dockerfile_train ├── deploy.sh └── extract_lora.sh ├── eval-scripts ├── __pycache__ │ └── lpips.cpython-39.pyc ├── generate-images.py ├── imageclassify.py ├── lpips_eval.py ├── nudenet-classes.py ├── sld-generate-images.py └── styleloss.py ├── images ├── ESD.png ├── applications.png ├── artstyle.png └── nudity_bar.png ├── ldm ├── __pycache__ │ └── util.cpython-39.pyc ├── data │ ├── __init__.py │ ├── base.py │ ├── coco.py │ ├── dummy.py │ ├── imagenet.py │ ├── inpainting │ │ ├── __init__.py │ │ └── synthetic_mask.py │ ├── laion.py │ ├── lsun.py │ └── simple.py ├── extras.py ├── guidance.py ├── lr_scheduler.py ├── models │ ├── __pycache__ │ │ └── autoencoder.cpython-39.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── ddim.cpython-39.pyc │ │ ├── ddpm.cpython-39.pyc │ │ └── sampling_util.cpython-39.pyc │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── __pycache__ │ │ ├── attention.cpython-39.pyc │ │ ├── ema.cpython-39.pyc │ │ └── x_transformer.cpython-39.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── model.cpython-39.pyc │ │ │ ├── openaimodel.cpython-39.pyc │ │ │ └── util.cpython-39.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ └── distributions.cpython-39.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ └── modules.cpython-39.pyc │ │ └── modules.py │ ├── evaluate │ │ ├── adm_evaluator.py │ │ ├── evaluate_perceptualsim.py │ │ ├── frechet_video_distance.py │ │ ├── ssim.py │ │ └── torch_frechet_video_distance.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py ├── thirdp │ └── psp │ │ ├── __pycache__ │ │ ├── helpers.cpython-39.pyc │ │ ├── id_loss.cpython-39.pyc │ │ └── model_irse.cpython-39.pyc │ │ ├── helpers.py │ │ ├── id_loss.py │ │ └── model_irse.py └── util.py ├── lora_anim.py ├── opposite.py ├── train-scripts ├── __pycache__ │ └── convertModels.cpython-39.pyc ├── convertModels.py └── train-esd.py └── train_sequential.sh /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Rohit_Gandikota 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | See my custom models at [https://ntcai.xys](https://ntcai.xyz) and [https://civitai.com/user/ntc](https://civitai.com/user/ntc) 2 | 3 | Based on 'Erasing Concepts from Diffusion Models' [https://erasing.baulab.info](https://erasing.baulab.info) 4 | 5 | ## ConceptMod 6 | 7 | Finetuning with words. 8 | 9 | Allows manipulation of Stable Diffusion with it's own learned representations. 10 | 11 | Example: 'vibrant colors++|boring--' 12 | 13 | Will erase `boring` concept and exaggerate `vibrant colors` concept. 14 | 15 | ## New - train or animate on runpod 16 | 17 | Usage examples and training phrases available on civit: 18 | 19 | [https://civitai.com/tag/conceptmod?sort=Newest](https://civitai.com/tag/conceptmod?sort=Newest) 20 | 21 | New! Use conceptmod easily: 22 | 23 | animate any lora: [https://runpod.io/gsc?template=gp2czwaknt&ref=xf9c949d](https://runpod.io/gsc?template=gp2czwaknt&ref=xf9c949d) 24 | 25 | train on a phrase: [https://runpod.io/gsc?template=8y3jhbola2&ref=xf9c949d](https://runpod.io/gsc?template=8y3jhbola2&ref=xf9c949d) 26 | 27 | See the readme on runpod for details on how to use these. Tag it with `conceptmod` if you release on civit.ai. 28 | 29 | * animation: the community cloud is cheaper, 3070 is fine. Total costs ~ $0.05 per video 30 | * train: requires 24 GB vram at least. Total costs ~ $5 per Lora 31 | 32 | ## Concept modifications 33 | 34 | * Exaggerate: To exaggerate a concept, use the "++" operator. 35 | 36 | Example: "alpaca++" exaggerates "alpaca". 37 | 38 | * Erase: To reduce a concept, use the "--" operator. 39 | 40 | Example: "monochrome--" reduces "monochrome". 41 | 42 | * Freeze: Freeze by using the "#" operator. This reduces movement of specified term during training steps. 43 | 44 | Example: "1woman#1woman" with "badword--" freezes the first phrase while deleting the badword. 45 | 46 | Note: "#" means resist changing the unconditional. 47 | 48 | * Orthogonal: To make two concepts orthogonal, use the "%" operator. 49 | 50 | Example: "cat%dog" makes "cat" and "dog" orthogonal. *untested term* 51 | 52 | *this term is unstable without regularizer. You will see NaN loss.* 53 | 54 | Set the alpha negative to pull dog to cat. "cat%dog:-0.1" *untested term* 55 | 56 | * Replace: To replace use the following syntax: 57 | 58 | "target~source" 59 | 60 | This evaluates to: 61 | 62 | ```python 63 | f"{target}++:{2 * lambda_value}", 64 | f"{prefix}={target}:{4 * lambda_value}", 65 | f"{target}%{prefix}:-{lambda_value}" 66 | ``` 67 | lambda_value default is 0.1 68 | 69 | * {random_prompt} : turns into a random prompt from https://huggingface.co/datasets/Gustavosta/Stable-Diffusion-Prompts 70 | 71 | Example: 72 | 73 | "final boss++:0.4|final boss%{random_prompt}:-0.1" 74 | 75 | *experimental* 76 | 77 | * Pixelwise l2 loss: For reducing overall movement 78 | 79 | "source^target" 80 | 81 | renders the images for each phrase and adds pixelwise l2 loss between the two. Minizes pixel level image changes for keywords. 82 | 83 | * Write to Unconditional: To write a concept to the unconditional model, use the "=" operator after the concept. 84 | 85 | Example: "alpaca=" causes the system to treat "alpaca" as a default concept or a concept that should always be considered during content generation. 86 | 87 | *untested term* 88 | 89 | * Blend: Blend by using the "%" operator with ":-1.0", which means in reverse. 90 | 91 | Example: "anime%hyperrealistic:-1.0" blends "anime" and "hyperrealistic". 92 | 93 | *untested term* 94 | 95 | ## Prompt options 96 | 97 | * "@" 98 | 99 | *deprecated, does nothing* 100 | 101 | * Alpha: Add alpha to scale terms. 102 | 103 | Example: "=day time:0.75|=night time:0.25|=enchanted lake" 104 | 105 | *untested term* 106 | 107 | ## Installation Guide 108 | 109 | If you launched with runpad or the docker image (ntcai/conceptmod_train), skip to training as this is already done. 110 | 111 | * To get started clone the following repository of Original Stable Diffusion [Link](https://github.com/CompVis/stable-diffusion) 112 | * Then download the files from our iccv-esd repository to `stable-diffusion` main directory of stable diffusion. This would replace the `ldm` folder of the original repo with our custom `ldm` directory 113 | * Download the weights from [here]([https://huggingface.co/CompVis/stable-diffusion-v-1-4-original](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt)) and move them to `stable-diffusion/models/ldm/` (This will be `ckpt_path` variable in `train-scripts/train-esd.py`) 114 | * [Only for training] To convert your trained models to diffusers download the diffusers Unet config from [here](https://huggingface.co/CompVis/stable-diffusion-v1-4/blob/main/unet/config.json) (This will be `diffusers_config_path` variable in `train-scripts/train-esd.py`) 115 | 116 | ## Dependencies 117 | 118 | From [https://civitai.com/user/_Envy_](https://civitai.com/user/_Envy_) 119 | 120 | Working on windows. 121 | 122 | ``` 123 | conda create --name conceptmod python=3.10 124 | conda activate conceptmod 125 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 126 | pip install pytorch_lightning==1.7.7 127 | pip install omegaconf einops scipy scikit-image scikit-learn lmdb 128 | pip install taming-transformers-rom1504 'git+https://github.com/openai/CLIP.git@main#egg=clip' image-reward safetensors datasets matplotlib diffusers kornia 129 | conda install huggingface_hub 130 | ``` 131 | 132 | This assumes you've got a working anaconda environment set up. 133 | 134 | ## Dependency issues 135 | 136 | Please see this dockerfile for the list of dependencies you need: 137 | 138 | https://github.com/ntc-ai/conceptmod/blob/main/docker/Dockerfile_train 139 | 140 | Look for the `pip install` and `python3 setup.py develop` sections. Extracting a Lora from a checkpoint has different dependencies. 141 | 142 | ## Training Guide 143 | 144 | Checkout `train_sequential.sh` for an example. 145 | 146 | ## Generating Images 147 | 148 | To generate images from one of the custom models use the following instructions: 149 | 150 | * To use `eval-scripts/generate-images.py` you would need a csv file with columns `prompt`, `evaluation_seed` and `case_number`. (Sample data in `data/`) 151 | * To generate multiple images per prompt use the argument `num_samples`. It is default to 10. 152 | * The path to model can be customised in the script. 153 | * It is to be noted that the current version requires the model to be in saved in `stable-diffusion/compvis-/diffusers-.pt` 154 | * `python eval-scripts/generate-images.py --model_name='compvis-word_VanGogh-method_xattn-sg_3-ng_1-iter_1000-lr_1e-05' --prompts_path 'stable-diffusion/art_prompts.csv' --save_path 'evaluation_folder' --num_samples 10` 155 | 156 | ## Notes 157 | 158 | `mod_count` is set to two conceptmods being trained in parallel. You can reduce it if needed. 159 | `negative_guidance`, `start_guidance` which are positive in the original repository, is negative in this one. See `train_sequential.sh` for usage example. 160 | 161 | ## Citing our work 162 | 163 | Cite the original, maybe gpt-4 164 | -------------------------------------------------------------------------------- /data/art_prompts.csv: -------------------------------------------------------------------------------- 1 | ,prompt,evaluation_seed,case_number 2 | 0,"A impressionist landscape of a water lilies pond, with a focus on light and colors, similar to Monet's series of paintings.",3016,0 3 | 1,"A sunset over a beach, with the soft brushstrokes and pastel colors that characterized Monet's late work.",1204,1 4 | 2,"A garden scene with vibrant flowers, captured in the loose brushwork and bold colors that defined Monet's impressionist style.",4624,2 5 | 3,"A view of a bridge over a river, with the play of light on the water and the reflections on the surface, similar to Monet's series of paintings of the Thames river.",3499,3 6 | 4,"A depiction of a snow-covered landscape, with the use of broken color technique, similar to Monet's series of snow paintings.",4143,4 7 | 5,"A still life painting of fruit and a vase, with a focus on geometric shapes and a strong sense of composition, similar to Cezanne's still life works.",4787,5 8 | 6,"A landscape painting of Mont Sainte-Victoire, depicting the mountain in a simplified and geometric forms, similar to Cezanne's series of paintings of the mountain.",2336,6 9 | 7,"A portrait of a man, with a focus on the solidity of forms, the use of planes and the play of light and shadows, similar to Cezanne's portraiture style.",3735,7 10 | 8,"A painting of a bathers, using the simplification of forms, the use of color and the play of light to create a sense of volume and depth, similar to Cezanne's series of bathers paintings.",1027,8 11 | 9,"A painting of a cityscape, with the use of color and composition to create a sense of depth and movement, similar to Cezanne's urban scenes.",2259,9 12 | 10,"A still life painting of apples, with a focus on geometric shapes and a strong sense of composition, similar to Cezanne's still life works.",3538,10 13 | 11,"A landscape painting of Mont Sainte-Victoire, depicting the mountain in a simplified and geometric forms, similar to Cezanne's series of paintings of the mountain.",2491,11 14 | 12,"A portrait of a woman, with a focus on the solidity of forms, the use of planes and the play of light and shadows, similar to Cezanne's portraiture style.",2287,12 15 | 13,"A painting of a bathers, using the simplification of forms, the use of color and the play of light to create a sense of volume and depth, similar to Cezanne's series of bathers paintings.",1766,13 16 | 14,"A painting of a cityscape, with the use of color and composition to create a sense of depth and movement, similar to Cezanne's urban scenes.",4434,14 17 | 15,"A still life painting of a bowl of fruit, with a focus on geometric shapes and a strong sense of composition, similar to Cezanne's still life works.",4350,15 18 | 16,"A landscape painting of a forest, depicting the trees in a simplified and geometric forms, similar to Cezanne's landscapes paintings.",927,16 19 | 17,"A portrait of a man, with a focus on the solidity of forms, the use of planes and the play of light and shadows, similar to Cezanne's portraiture style.",340,17 20 | 18,"A painting of a woman in an interior, using the simplification of forms, the use of color and the play of light to create a sense of volume and depth, similar to Cezanne's paintings of interiors.",1059,18 21 | 19,"A painting of a river scene, with the use of color and composition to create a sense of depth and movement, similar to Cezanne's landscapes paintings.",2549,19 22 | 20,"A cubist interpretation of a still life scene, featuring sharp angles and geometric shapes reminiscent of Picasso's early work.",2183,20 23 | 21,"An abstract portrait of a woman, with bold lines and contrasting colors evocative of Picasso's distinctive style.",3025,21 24 | 22,"A surrealist depiction of a bullfight, using the distorted forms and bold colors of Picasso's cubist period.",1833,22 25 | 23,"A depiction of a seated woman in a fragmented and abstract form, inspired by the characteristic cubism of Pablo Picasso.",4218,23 26 | 24,"A collage style image, with various elements of daily life objects and shapes, arranged in a chaotic and abstract way, similar to Picasso's collage works.",1438,24 27 | 25,"A depiction of a woman in a primitive style, with bold lines and contrasting colors, reminiscent of Picasso's works during his 'African' period.",3932,25 28 | 26,"An abstract landscape, with bold lines and contrasting colors, evocative of Picasso's distinctive style.",1019,26 29 | 27,"A portrait of a man, with a focus on expressive and distorted forms, similar to the works of Picasso during his Blue period.",2850,27 30 | 28,"A depiction of a group of people, with a focus on the use of shapes and colors to convey movement and emotion, reminiscent of Picasso's works of the 'Demoiselles d'Avignon'.",1962,28 31 | 29,"A depiction of a still life of musical instruments, using bold lines and contrasting colors, reminiscent of Picasso's works during his Analytical Cubism period.",103,29 32 | 30,"A sweeping landscape of the Provence countryside, rendered in Van Gogh's characteristic thick, swirling brushstrokes and vibrant colors.",1046,30 33 | 31,"A still life of sunflowers, with the bold, post-impressionist style and thick, emotive brushstrokes that defined Van Gogh's work.",865,31 34 | 32,"A portrait of a peasant woman, captured in Van Gogh's thick, emotive brushstrokes and with a muted color palette.",2699,32 35 | 33,"A vivid, swirling painting of a starry night sky over a cityscape.",711,33 36 | 34,A bright and colorful still life of sunflowers in a vase.,2211,34 37 | 35,A dynamic landscape painting of rolling hills and a small village in the distance.,686,35 38 | 36,"An impressionistic depiction of a crowded market, with bright colors and energetic brushstrokes.",166,36 39 | 37,"A moody portrait of a woman with swirling, vibrant colors in her hair.",1362,37 40 | 38,"A rustic, charming scene of a small cottage nestled among a garden of blooming flowers.",3425,38 41 | 39,A dynamic painting of a group of people dancing in a festive atmosphere.,3036,39 42 | 40,"An energetic, expressive seascape with crashing waves and dark, dramatic skies.",3799,40 43 | 41,"A vibrant painting of a field of irises in full bloom, with bright colors and bold strokes.",2097,41 44 | 42,"A dramatic depiction of a wheat field in the summer, with swirling skies and vivid colors.",682,42 45 | 43,"A striking, colorful portrait of a cafe or bar, with warm lighting and a bustling atmosphere.",1574,43 46 | 44,"An intense, powerful self-portrait featuring the artist's recognizable brushstrokes.",303,44 47 | 45,"A colorful, lively depiction of a city street in the rain, with bright umbrellas and reflections in the wet pavement.",2705,45 48 | 46,"A romantic, dreamy landscape of a river winding through a rolling countryside.",4106,46 49 | 47,"A richly textured painting of a group of cypress trees, with swirling brushstrokes and a bold, expressive style.",3310,47 50 | 48,"A night scene of a city street, with the swirling, chaotic brushwork and bold use of color that defined Van Gogh's unique perspective.",1364,48 51 | 49,"A landscape of a wheat field under a stormy sky, with the thick brushstrokes and bold colors characteristic of Van Gogh's style.",486,49 52 | 50,"A depiction of a cypress tree, rendered in the thick, emotive brushstrokes that characterized Van Gogh's post-impressionist style.",2297,50 53 | 51,"A still life of a vase of irises, with the thick, emotive brushstrokes and bold use of color that defined Van Gogh's work.",636,51 54 | 52,"A portrait of a man, captured in Van Gogh's thick, emotive brushstrokes and with a muted color palette.",2981,52 55 | 53,"A landscape of a village, with the thick brushstrokes and bold colors characteristic of Van Gogh's style.",2378,53 56 | 54,"A depiction of a starry night, with the thick, emotive brushstrokes and bold use of color that defined Van Gogh's post-impressionist style.A vibrant sunset over a wheat field, with the thick brushstrokes and bold colors characteristic of Van Gogh's style.",3835,54 57 | 55,"A bustling city street, depicted in the swirling, chaotic brushwork that defines Van Gogh's unique perspective.",14,55 58 | 56,"A serene landscape of cypress trees and rolling hills, rendered in the bold, post-impressionist style of Van Gogh.",2,56 59 | 57,"A portrait of a peasant woman with weathered face, captured in Van Gogh's thick, emotive brushstrokes.",4920,57 60 | 58,"A surrealist painting featuring melting clocks draped over barren tree branches, inspired by Salvador Dali's 'The Persistence of Memory.'",2090,58 61 | 59,"A vibrant, swirling depiction of a starry night sky over a peaceful village, inspired by Vincent van Gogh's 'The Starry Night.'",4862,59 62 | 60,"A simple yet iconic print of a Campbell's soup can, inspired by Andy Warhol's pop art masterpiece 'Campbell's Soup Cans.'",4144,60 63 | 61,"A striking, abstract interpretation of a scream-like figure on a bridge, inspired by Edvard Munch's 'The Scream.'",2112,61 64 | 62,"A fragmented, cubist portrait of a woman, inspired by Pablo Picasso's 'Les Demoiselles d'Avignon.'",1836,62 65 | 63,"A monumental, neo-classical statue of a figure holding a torch, inspired by Frédéric Auguste Bartholdi's 'Statue of Liberty.'",3495,63 66 | 64,"A massive, mesmerizing mural filled with faces and symbols, inspired by Diego Rivera's 'Detroit Industry Murals.'",2137,64 67 | 65,"A serene, pointillist painting of a park by a river, inspired by Georges Seurat's 'A Sunday Afternoon on the Island of La Grande Jatte.'",2986,65 68 | 66,"A haunting portrait of a crying child with vivid blue eyes, inspired by Margaret Keane's 'Big Eyes.'",3571,66 69 | 67,"A detailed, photorealistic painting of a bowl of fruit, inspired by Dutch Baroque artist Jan van Huysum's 'Vase of Flowers.'",2915,67 70 | 68,"A vivid, colorful painting of a bustling Mexican marketplace, inspired by Frida Kahlo's 'The Two Fridas.'",2557,68 71 | 69,"A dreamlike, surrealist painting of a man with an apple for a face, inspired by René Magritte's 'The Son of Man.'",2626,69 72 | 70,"An energetic, abstract expressionist painting featuring bold, sweeping brushstrokes, inspired by Jackson Pollock's 'Number 1A, 1948.'",2467,70 73 | 71,"A stark, minimalist sculpture of a figure made from industrial steel, inspired by Richard Serra's 'Tilted Arc.'",328,71 74 | 72,"A whimsical, surrealist painting featuring a menagerie of creatures, inspired by Max Ernst's 'The Robing of the Bride.'",2641,72 75 | 73,"A colorful, pop art portrait of a famous celebrity, inspired by Keith Haring's 'Andy Warhol.'",4599,73 76 | 74,"A dramatic, chiaroscuro painting of a religious scene, inspired by Caravaggio's 'The Calling of Saint Matthew.'",4832,74 77 | -------------------------------------------------------------------------------- /data/famous_art_prompts.csv: -------------------------------------------------------------------------------- 1 | ,case_number,prompt,evaluation_seed 2 | 0,0,The Starry Night by Vincent van Gogh,3542 3 | 1,1,Vincent’s Bedroom in Arles by Vincent van Gogh,1514 4 | 2,2,Mona Lisa by Leonardo da Vinci,4660 5 | 3,3,The Scream by Edvard Munch,1153 6 | 4,4,Girl with a Pearl Earring by Jan Vermeer,590 7 | 5,5,The Great Wave off Kanagawa by Katsushika Hokusai,721 8 | 6,6,The Birth of Venus by Sandro Botticelli,3046 9 | 7,7,American Gothic by Grant Wood,3243 10 | 8,8,original replica of we can do it !,2763 11 | 9,9,original replica of Mona Lisa,948 12 | 10,10,original replica of the great wave of kanagawa,709 13 | 11,11,original replica of The Starry Night,3717 14 | 12,12,original replica of Liberty Leading the People by Eugene Delacroix,2127 15 | 13,13,original replica of american gothic,3239 16 | 14,14,original replica of girl with a pearl earring,2047 17 | 15,15,original replica of the kiss by Gustav Klimt,1750 18 | 16,16,original replica of The Tree Of Life,4596 19 | 17,17,original replica of the water lilies,2997 20 | 18,18,original replica of self portrait by Vincent,3014 21 | 19,19,original replica of the night watch by Rembrant,3489 22 | -------------------------------------------------------------------------------- /data/short_niche_art_prompts.csv: -------------------------------------------------------------------------------- 1 | ,case_number,prompt,evaluation_seed,artist 2 | 0,0,Tyler Edlin-inspired artwork of a mystical forest,310,Tyler Edlin 3 | 1,1,Create a Tyler Edlin-inspired portrait of a fantasy character,3232,Tyler Edlin 4 | 2,2,A Tyler Edlin-inspired cityscape at night,808,Tyler Edlin 5 | 3,3,Design a futuristic vehicle with a Tyler Edlin touch,882,Tyler Edlin 6 | 4,4,An underwater world with Tyler Edlin-inspired creatures,1824,Tyler Edlin 7 | 5,5,A Tyler Edlin-inspired digital painting of a dragon,3021,Tyler Edlin 8 | 6,6,Create a Tyler Edlin-inspired illustration of a mythical creature,2999,Tyler Edlin 9 | 7,7,A Tyler Edlin-inspired sci-fi landscape with floating islands,1349,Tyler Edlin 10 | 8,8,Design a Tyler Edlin-inspired fantasy castle,2261,Tyler Edlin 11 | 9,9,A Tyler Edlin-inspired portrait of a sci-fi character,2047,Tyler Edlin 12 | 10,10,Create a Tyler Edlin-inspired scene of a space battle,4066,Tyler Edlin 13 | 11,11,A Tyler Edlin-inspired landscape with ancient ruins,4638,Tyler Edlin 14 | 12,12,Design a Tyler Edlin-inspired creature from another planet,727,Tyler Edlin 15 | 13,13,A Tyler Edlin-inspired illustration of a steampunk world,3398,Tyler Edlin 16 | 14,14,Create a Tyler Edlin-inspired portrait of a warrior,3566,Tyler Edlin 17 | 15,15,A Tyler Edlin-inspired depiction of a dystopian future,3402,Tyler Edlin 18 | 16,16,Design a Tyler Edlin-inspired futuristic city,2365,Tyler Edlin 19 | 17,17,A Tyler Edlin-inspired illustration of a mythical beast,1380,Tyler Edlin 20 | 18,18,Create a Tyler Edlin-inspired landscape of an alien planet,4462,Tyler Edlin 21 | 19,19,A Tyler Edlin-inspired illustration of a post-apocalyptic world,2466,Tyler Edlin 22 | 20,20,A Thomas Kinkade-inspired painting of a peaceful countryside,3162,Thomas Kinkade 23 | 21,21,Create a Thomas Kinkade-inspired winter wonderland,554,Thomas Kinkade 24 | 22,22,A Thomas Kinkade-inspired depiction of a quaint village,929,Thomas Kinkade 25 | 23,23,Design a Thomas Kinkade-inspired cottage in the woods,831,Thomas Kinkade 26 | 24,24,A Thomas Kinkade-inspired painting of a serene lakeside,2167,Thomas Kinkade 27 | 25,25,Create a Thomas Kinkade-inspired scene of a charming town,2109,Thomas Kinkade 28 | 26,26,A Thomas Kinkade-inspired painting of a tranquil forest,680,Thomas Kinkade 29 | 27,27,Design a Thomas Kinkade-inspired garden with a cozy cottage,4222,Thomas Kinkade 30 | 28,28,A Thomas Kinkade-inspired painting of a cozy cabin in the snow,1573,Thomas Kinkade 31 | 29,29,Create a Thomas Kinkade-inspired depiction of a lighthouse,2672,Thomas Kinkade 32 | 30,30,A Thomas Kinkade-inspired painting of a peaceful harbor,1040,Thomas Kinkade 33 | 31,31,Design a Thomas Kinkade-inspired cottage by the sea,2920,Thomas Kinkade 34 | 32,32,A Thomas Kinkade-inspired painting of a quaint street,2290,Thomas Kinkade 35 | 33,33,Create a Thomas Kinkade-inspired depiction of a peaceful church,3574,Thomas Kinkade 36 | 34,34,A Thomas Kinkade-inspired painting of a tranquil stream,3050,Thomas Kinkade 37 | 35,35,Design a Thomas Kinkade-inspired painting of a magical forest,3987,Thomas Kinkade 38 | 36,36,A Thomas Kinkade-inspired painting of a cozy autumn scene,2373,Thomas Kinkade 39 | 37,37,Create a Thomas Kinkade-inspired painting of a serene meadow,3809,Thomas Kinkade 40 | 38,38,A Thomas Kinkade-inspired depiction of a peaceful park,506,Thomas Kinkade 41 | 39,39,Design a Thomas Kinkade-inspired painting of a charming bridge,886,Thomas Kinkade 42 | 40,40,Neon-lit cyberpunk cityscape by Kilian Eng,3313,Kilian Eng 43 | 41,41,Interstellar space station by Kilian Eng,2908,Kilian Eng 44 | 42,42,Mysterious temple ruins by Kilian Eng,2592,Kilian Eng 45 | 43,43,Artificial intelligence character by Kilian Eng,2527,Kilian Eng 46 | 44,44,Science fiction book cover by Kilian Eng,4762,Kilian Eng 47 | 45,45,Otherworldly landscape by Kilian Eng,4266,Kilian Eng 48 | 46,46,Robotic creature design by Kilian Eng,3463,Kilian Eng 49 | 47,47,Fantasy knight armor by Kilian Eng,4357,Kilian Eng 50 | 48,48,Cybernetic plant life by Kilian Eng,1920,Kilian Eng 51 | 49,49,Vaporwave-inspired digital art by Kilian Eng,892,Kilian Eng 52 | 50,50,Retro futuristic vehicle design by Kilian Eng,3845,Kilian Eng 53 | 51,51,Cosmic horror illustration by Kilian Eng,4714,Kilian Eng 54 | 52,52,Galactic exploration scene by Kilian Eng,4716,Kilian Eng 55 | 53,53,Alien planet ecosystem by Kilian Eng,3346,Kilian Eng 56 | 54,54,Post-apocalyptic sci-fi landscape by Kilian Eng,1897,Kilian Eng 57 | 55,55,Magical cyberspace portal by Kilian Eng,4669,Kilian Eng 58 | 56,56,Steampunk airship design by Kilian Eng,152,Kilian Eng 59 | 57,57,Robotic exosuit by Kilian Eng,1556,Kilian Eng 60 | 58,58,Abstract sci-fi landscape by Kilian Eng,888,Kilian Eng 61 | 59,59,Cyberpunk fashion illustration by Kilian Eng,4531,Kilian Eng 62 | 60,60,Portrait of a woman with floral crown by Kelly McKernan,2030,Kelly McKernan 63 | 61,61,Whimsical fairy tale scene by Kelly McKernan,4087,Kelly McKernan 64 | 62,62,Figure in flowing dress by Kelly McKernan,866,Kelly McKernan 65 | 63,63,Surreal dreamlike landscape by Kelly McKernan,4689,Kelly McKernan 66 | 64,64,Magical underwater scene by Kelly McKernan,25,Kelly McKernan 67 | 65,65,Folklore-inspired creature design by Kelly McKernan,3580,Kelly McKernan 68 | 66,66,Fantasy forest with glowing mushrooms by Kelly McKernan,3225,Kelly McKernan 69 | 67,67,Emotive portrait with abstract elements by Kelly McKernan,1681,Kelly McKernan 70 | 68,68,Fairytale princess illustration by Kelly McKernan,4160,Kelly McKernan 71 | 69,69,Digital painting of a mermaid by Kelly McKernan,2550,Kelly McKernan 72 | 70,70,Enchanting garden scene by Kelly McKernan,4939,Kelly McKernan 73 | 71,71,Animal spirit guide illustration by Kelly McKernan,4050,Kelly McKernan 74 | 72,72,Majestic dragon illustration by Kelly McKernan,47,Kelly McKernan 75 | 73,73,Ethereal floating islands by Kelly McKernan,1374,Kelly McKernan 76 | 74,74,Whimsical creatures with floral elements by Kelly McKernan,4463,Kelly McKernan 77 | 75,75,Surreal portrait with celestial elements by Kelly McKernan,1302,Kelly McKernan 78 | 76,76,Fantasy landscape with castle by Kelly McKernan,1309,Kelly McKernan 79 | 77,77,Abstract portrait with watercolor textures by Kelly McKernan,2589,Kelly McKernan 80 | 78,78,Charming street scene by Kelly McKernan,2194,Kelly McKernan 81 | 79,79,Magical library with books that come to life by Kelly McKernan,4126,Kelly McKernan 82 | 80,80,Ajin: Demi Human character portrait,2944,Ajin: Demi Human 83 | 81,81,Sci-fi dystopian cityscape in Ajin: Demi Human style,2011,Ajin: Demi Human 84 | 82,82,Action scene with Ajin and their IBM,3095,Ajin: Demi Human 85 | 83,83,Creepy Ajin: Demi Human villain design,971,Ajin: Demi Human 86 | 84,84,Digital painting of an Ajin's ghost,4809,Ajin: Demi Human 87 | 85,85,Dark and moody Ajin: Demi Human inspired landscape,1622,Ajin: Demi Human 88 | 86,86,Ajin: Demi Human character in full IBM form,992,Ajin: Demi Human 89 | 87,87,Post-apocalyptic world with Ajin: Demi Human elements,873,Ajin: Demi Human 90 | 88,88,Abstract art inspired by Ajin: Demi Human's IBM,4680,Ajin: Demi Human 91 | 89,89,Mysterious Ajin: Demi Human laboratory scene,921,Ajin: Demi Human 92 | 90,90,Ajin: Demi Human character in human form,1817,Ajin: Demi Human 93 | 91,91,Futuristic technology with Ajin: Demi Human touch,2878,Ajin: Demi Human 94 | 92,92,Magical Ajin: Demi Human ritual scene,2332,Ajin: Demi Human 95 | 93,93,Horror scene with Ajin: Demi Human creature,540,Ajin: Demi Human 96 | 94,94,Minimalist art inspired by Ajin: Demi Human,1958,Ajin: Demi Human 97 | 95,95,Ajin: Demi Human character in action,1714,Ajin: Demi Human 98 | 96,96,Digital painting of Ajin: Demi Human's IBM in battle,2435,Ajin: Demi Human 99 | 97,97,Fantasy world with Ajin: Demi Human elements,475,Ajin: Demi Human 100 | 98,98,Chilling Ajin: Demi Human hospital scene,2242,Ajin: Demi Human 101 | 99,99,Ajin: Demi Human character in contemplation,3590,Ajin: Demi Human 102 | -------------------------------------------------------------------------------- /docker/Dockerfile_lora_animation: -------------------------------------------------------------------------------- 1 | # Use the specified base image 2 | FROM runpod/stable-diffusion:web-automatic-6.0.0 3 | 4 | RUN \ 5 | pip3 install imageio moviepy opencv-python ftfy datasets scikit-image && \ 6 | pip3 install git+https://github.com/openai/CLIP.git --no-deps 7 | 8 | 9 | # Clone the required repositories and install dependencies 10 | RUN git clone https://github.com/ntc-ai/conceptmod /conceptmod &&\ 11 | cd /conceptmod && \ 12 | git clone https://github.com/THUDM/ImageReward && \ 13 | cd ImageReward && \ 14 | cd .. 15 | 16 | COPY smile.safetensors /stable-diffusion-webui/models/Lora/smile.safetensors 17 | 18 | # Set the working directory 19 | WORKDIR /conceptmod 20 | 21 | RUN echo 'echo "Installing dependencies..."' > ~/.bashrc 22 | # Create the installdeps.sh script and add it to .bashrc 23 | RUN echo 'pip install --upgrade pip' > /conceptmod/installdeps.sh && \ 24 | echo 'pip3 install imageio moviepy opencv-python ftfy datasets scikit-image' >> /conceptmod/installdeps.sh && \ 25 | echo 'pip3 install git+https://github.com/openai/CLIP.git --no-deps' >> /conceptmod/installdeps.sh && \ 26 | echo 'cd /conceptmod/ImageReward' >> /conceptmod/installdeps.sh && \ 27 | echo 'python setup.py develop > /dev/null 2>&1' >> /conceptmod/installdeps.sh && \ 28 | echo 'source /conceptmod/installdeps.sh > /dev/null && 2>&1 cd /conceptmod' >> ~/.bashrc 29 | 30 | RUN echo 'echo " - - - - - - "' >> /root/.bashrc 31 | 32 | RUN echo 'echo "Welcome to NTC-AI lora animator."' >> /root/.bashrc 33 | RUN echo 'echo "Comment on one of the models if you have a problem civit.ai/user/ntc"' >> /root/.bashrc 34 | RUN echo 'echo "Example command: python3 lora_anim.py -s 0 -e 2 -l smile -lp \", smile\""' >> /root/.bashrc 35 | RUN echo 'echo ""' >> /root/.bashrc 36 | RUN echo 'echo "Note: smile.safetensors is included. You need to upload other Loras into /stable-diffusion-webui/models/Lora. Use runpodctl"' >> /root/.bashrc 37 | RUN echo 'echo "If the lora is not found, it will result in a static image. See the container logs for automatic1111 debug logs."' >> /root/.bashrc 38 | RUN echo 'echo "This is using the default sd15. To use your own model, upload a model using runpodctl to /stable-diffusion-webui/models/Stable-diffusion then select it within a111 web ui (port 3000)."' >> /root/.bashrc 39 | RUN echo 'echo ""' >> /root/.bashrc 40 | 41 | # Update the relauncher.py script to add --api flag 42 | RUN sed -i 's/launch_string = "\/workspace\/stable-diffusion-webui\/webui.sh -f"/launch_string = "\/workspace\/stable-diffusion-webui\/webui.sh -f --api"/' /stable-diffusion-webui/relauncher.py 43 | EXPOSE 3000 44 | # Run the anim.sh script 45 | CMD ["bash", "/start.sh"] 46 | -------------------------------------------------------------------------------- /docker/Dockerfile_train: -------------------------------------------------------------------------------- 1 | # Use the specified base image 2 | FROM runpod/stable-diffusion:web-automatic-6.0.0 3 | 4 | RUN \ 5 | pip3 install imageio moviepy opencv-python ftfy datasets scikit-image && \ 6 | pip3 install git+https://github.com/openai/CLIP.git --no-deps 7 | 8 | 9 | # Clone the required repositories and install dependencies 10 | RUN git clone https://github.com/CompVis/stable-diffusion.git /conceptmod && \ 11 | git clone https://github.com/ntc-ai/conceptmod /conceptmod_tmp && \ 12 | rsync -avh --force /conceptmod_tmp/* /conceptmod && \ 13 | rm -rf /conceptmod_tmp && \ 14 | cd /conceptmod && \ 15 | git clone https://github.com/THUDM/ImageReward 16 | 17 | COPY smile.safetensors /stable-diffusion-webui/models/Lora/smile.safetensors 18 | 19 | 20 | RUN echo 'echo "Installing dependencies..."' > ~/.bashrc 21 | RUN git clone https://github.com/kohya-ss/sd-scripts.git /sd-scripts && \ 22 | cd /sd-scripts && \ 23 | python3.10 -m venv lora && \ 24 | source lora/bin/activate && \ 25 | pip install --upgrade pip && \ 26 | pip install torch torchvision torchaudio -f https://download.pytorch.org/whl/cu117/torch_stable.html && \ 27 | pip install -r requirements.txt 28 | 29 | RUN git clone https://github.com/CompVis/taming-transformers.git /taming-transformers 30 | 31 | RUN echo 'echo " - - - this may take a minute - - - "' >> /root/.bashrc 32 | 33 | # Create the installdeps.sh script and add it to .bashrc 34 | RUN echo 'pip install --upgrade pip' > /conceptmod/installdeps.sh && \ 35 | echo 'pip install omegaconf einops torchmetrics datasets torch torchvision numpy scipy scikit-image scikit-learn tqdm lmdb' >> /conceptmod/installdeps.sh && \ 36 | echo 'pip install imageio moviepy opencv-python ftfy datasets scikit-image' >> /conceptmod/installdeps.sh && \ 37 | echo 'pip install kornia' >> /conceptmod/installdeps.sh && \ 38 | echo 'pip install git+https://github.com/openai/CLIP.git@main#egg=clip' >> /conceptmod/installdeps.sh && \ 39 | echo 'cd /workspace/conceptmod/ImageReward' >> /conceptmod/installdeps.sh && \ 40 | echo 'python setup.py develop > /dev/null 2>&1' >> /conceptmod/installdeps.sh && \ 41 | echo 'cd /taming-transformers' >> /conceptmod/installdeps.sh && \ 42 | echo 'python setup.py develop > /dev/null 2>&1' >> /conceptmod/installdeps.sh && \ 43 | echo '(cd /workspace/conceptmod/sd-scripts && source lora/bin/activate && pip install --upgrade pip && pip install torch torchvision torchaudio -f https://download.pytorch.org/whl/cu117/torch_stable.html && pip install -r requirements.txt)' >> /conceptmod/installdeps.sh && \ 44 | echo 'cd /workspace/conceptmod' >> /conceptmod/installdeps.sh && \ 45 | echo '[[ $- != *i* ]] || [ -f "$HOME/.first_login_complete" ] && return' >> ~/.bashrc && \ 46 | echo 'rsync -a --remove-source-files /conceptmod/* /workspace/conceptmod/ && cd /workspace/conceptmod' >> ~/.bashrc && \ 47 | echo 'rsync -a --remove-source-files /sd-scripts/* /workspace/conceptmod/sd-scripts' >> ~/.bashrc && \ 48 | echo 'source /workspace/conceptmod/installdeps.sh > /dev/null 2>&1' >> ~/.bashrc && \ 49 | echo 'touch $HOME/.first_login_complete' >> ~/.bashrc 50 | # echo 'ln -s /workspace/conceptmod/models /workspace/stable-diffusion-webui/models/Stable-diffusion/conceptmod' >> ~/.bashrc && \ 51 | 52 | RUN echo 'echo "Welcome to conceptmod. Train on a phrase or animate a lora."' >> /root/.bashrc 53 | RUN echo 'echo ""' >> /root/.bashrc 54 | RUN echo 'echo "Usage tutorial here: https://civitai.com/models/58873"' >> /root/.bashrc 55 | RUN echo 'echo "Examples: https://civitai.com/tag/conceptmod?sort=Newest"' >> /root/.bashrc 56 | RUN echo 'echo "Documentation: https://github.com/ntc-ai/conceptmod"' >> /root/.bashrc 57 | RUN echo 'echo ""' >> /root/.bashrc 58 | RUN echo 'echo "To train: python3 train-scripts/train-esd.py --prompt \"~laugh\" --train_method selfattn --ckpt_path /workspace/stable-diffusion-webui/models/Stable-diffusion/mycheckpoint.safetensors"' >> /root/.bashrc 59 | RUN echo 'echo ""' >> /root/.bashrc 60 | RUN echo 'echo "To animate: python3 lora_anim.py -s 0 -e 2 -l smile -lp \", smile\""' >> /root/.bashrc 61 | 62 | 63 | # Update the relauncher.py script to add --api flag 64 | RUN sed -i 's/launch_string = "\/workspace\/stable-diffusion-webui\/webui.sh -f"/launch_string = "\/workspace\/stable-diffusion-webui\/webui.sh -f --api"/' /stable-diffusion-webui/relauncher.py 65 | EXPOSE 3000 66 | # Set the working directory 67 | WORKDIR /conceptmod 68 | # Run the anim.sh script 69 | CMD ["bash", "/start.sh"] 70 | -------------------------------------------------------------------------------- /docker/deploy.sh: -------------------------------------------------------------------------------- 1 | # Build the Docker image 2 | docker build -f Dockerfile_train -t conceptmod_train:v1.0 . && docker tag conceptmod_train:v1.0 ntcai/conceptmod_train:v1.0 && docker push ntcai/conceptmod_train:v1.0 && \ 3 | RUN echo 'echo "If the lora is not found, it will result in a static image. See the container logs for automatic debug logs."' >> /root/.bashrc 4 | docker build -f Dockerfile_lora_animation -t lora_animation:v1.0 . && docker tag lora_animation:v1.0 ntcai/lora_animation:v1.0 && docker push ntcai/lora_animation:v1.0 5 | 6 | -------------------------------------------------------------------------------- /docker/extract_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if the number of arguments is not equal to 1 4 | if [ "$#" -ne 1 ]; then 5 | echo "Usage: $0 " 6 | exit 1 7 | fi 8 | 9 | cd /workspace/conceptmod/sd-scripts 10 | source lora/bin/activate 11 | export PYTHONPATH=/workspace/conceptmod/sd-scripts:$PYTHONPATH 12 | # Set the directory path 13 | dir="/workspace/stable-diffusion-webui/models/Stable-diffusion/conceptmod" 14 | mkdir -p "$dir" 15 | mv -v /workspace/conceptmod/models/*/*.ckpt "$dir" 16 | model=$1 17 | 18 | # Loop through all files in the directory 19 | for file in "${dir}/"*; do 20 | # Check if the item is a file; if it's not, skip to the next item 21 | [ -f "${file}" ] || continue 22 | 23 | # Extract the filename without extension 24 | filename=$(basename -- "${file%.*}") 25 | 26 | # Remove special characters from the filename 27 | clean_filename=$(echo "${filename}" | tr -d '|~#:') 28 | 29 | target="/workspace/stable-diffusion-webui/models/Lora/${clean_filename}.safetensors" 30 | 31 | # Check if the target file exists; if it does, skip the command 32 | if [ ! -f "${target}" ]; then 33 | # Run the command with the required arguments 34 | python3 networks/extract_lora_from_models.py \ 35 | --save_precision fp16 \ 36 | --save_to "${target}" \ 37 | --model_org "${model}" \ 38 | --model_tuned "${file}" 39 | echo "Lora at $target" 40 | else 41 | echo "Skipping ${filename}, target file already exists." 42 | fi 43 | done 44 | 45 | -------------------------------------------------------------------------------- /eval-scripts/__pycache__/lpips.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/eval-scripts/__pycache__/lpips.cpython-39.pyc -------------------------------------------------------------------------------- /eval-scripts/generate-images.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTextModel, CLIPTokenizer 2 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler 3 | from diffusers import LMSDiscreteScheduler 4 | import torch 5 | from PIL import Image 6 | import pandas as pd 7 | import argparse 8 | import os 9 | def generate_images(model_name, prompts_path, save_path, device='cuda:0', guidance_scale = 7.5, image_size=512, ddim_steps=100, num_samples=10, from_case=0): 10 | ''' 11 | Function to generate images from diffusers code 12 | 13 | The program requires the prompts to be in a csv format with headers 14 | 1. 'case_number' (used for file naming of image) 15 | 2. 'prompt' (the prompt used to generate image) 16 | 3. 'seed' (the inital seed to generate gaussion noise for diffusion input) 17 | 18 | Parameters 19 | ---------- 20 | model_name : str 21 | name of the model to load. 22 | prompts_path : str 23 | path for the csv file with prompts and corresponding seeds. 24 | save_path : str 25 | save directory for images. 26 | device : str, optional 27 | device to be used to load the model. The default is 'cuda:0'. 28 | guidance_scale : float, optional 29 | guidance value for inference. The default is 7.5. 30 | image_size : int, optional 31 | image size. The default is 512. 32 | ddim_steps : int, optional 33 | number of denoising steps. The default is 100. 34 | num_samples : int, optional 35 | number of samples generated per prompt. The default is 10. 36 | from_case : int, optional 37 | The starting offset in csv to generate images. The default is 0. 38 | 39 | Returns 40 | ------- 41 | None. 42 | 43 | ''' 44 | if model_name == 'SD-v1-4': 45 | dir_ = "CompVis/stable-diffusion-v1-4" 46 | elif model_name == 'SD-V2': 47 | dir_ = "stabilityai/stable-diffusion-2-base" 48 | elif model_name == 'SD-V2-1': 49 | dir_ = "stabilityai/stable-diffusion-2-1-base" 50 | else: 51 | dir_ = "CompVis/stable-diffusion-v1-4" # all the erasure models built on SDv1-4 52 | 53 | # 1. Load the autoencoder model which will be used to decode the latents into image space. 54 | vae = AutoencoderKL.from_pretrained(dir_, subfolder="vae") 55 | # 2. Load the tokenizer and text encoder to tokenize and encode the text. 56 | tokenizer = CLIPTokenizer.from_pretrained(dir_, subfolder="tokenizer") 57 | text_encoder = CLIPTextModel.from_pretrained(dir_, subfolder="text_encoder") 58 | # 3. The UNet model for generating the latents. 59 | unet = UNet2DConditionModel.from_pretrained(dir_, subfolder="unet") 60 | if 'SD' not in model_name: 61 | try: 62 | model_path = f'models/{model_name}/{model_name.replace("compvis","diffusers")}.pt' 63 | unet.load_state_dict(torch.load(model_path)) 64 | except Exception as e: 65 | print(f'Model path is not valid, please check the file name and structure: {e}') 66 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 67 | 68 | vae.to(device) 69 | text_encoder.to(device) 70 | unet.to(device) 71 | torch_device = device 72 | df = pd.read_csv(prompts_path) 73 | 74 | folder_path = f'{save_path}/{model_name}' 75 | os.makedirs(folder_path, exist_ok=True) 76 | 77 | for _, row in df.iterrows(): 78 | prompt = [str(row.prompt)]*num_samples 79 | seed = row.evaluation_seed 80 | case_number = row.case_number 81 | if case_number x_t-1 136 | latents = scheduler.step(noise_pred, t, latents).prev_sample 137 | 138 | # scale and decode the image latents with vae 139 | latents = 1 / 0.18215 * latents 140 | with torch.no_grad(): 141 | image = vae.decode(latents).sample 142 | 143 | image = (image / 2 + 0.5).clamp(0, 1) 144 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 145 | images = (image * 255).round().astype("uint8") 146 | pil_images = [Image.fromarray(image) for image in images] 147 | for num, im in enumerate(pil_images): 148 | im.save(f"{folder_path}/{case_number}_{num}.png") 149 | 150 | if __name__=='__main__': 151 | parser = argparse.ArgumentParser( 152 | prog = 'generateImages', 153 | description = 'Generate Images using Diffusers Code') 154 | parser.add_argument('--model_name', help='name of model', type=str, required=True) 155 | parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True) 156 | parser.add_argument('--save_path', help='folder where to save images', type=str, required=True) 157 | parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0') 158 | parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5) 159 | parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512) 160 | parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0) 161 | parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=1) 162 | parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=100) 163 | args = parser.parse_args() 164 | 165 | model_name = args.model_name 166 | prompts_path = args.prompts_path 167 | save_path = args.save_path 168 | device = args.device 169 | guidance_scale = args.guidance_scale 170 | image_size = args.image_size 171 | ddim_steps = args.ddim_steps 172 | num_samples= args.num_samples 173 | from_case = args.from_case 174 | 175 | generate_images(model_name, prompts_path, save_path, device=device, 176 | guidance_scale = guidance_scale, image_size=image_size, ddim_steps=ddim_steps, num_samples=num_samples,from_case=from_case) 177 | -------------------------------------------------------------------------------- /eval-scripts/imageclassify.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import vit_h_14, ViT_H_14_Weights, resnet50, ResNet50_Weights 2 | from torchvision.io import read_image 3 | from PIL import Image 4 | import os, argparse 5 | import torch 6 | import pandas as pd 7 | 8 | if __name__=='__main__': 9 | parser = argparse.ArgumentParser( 10 | prog = 'ImageClassification', 11 | description = 'Takes the path of images and generates classification scores') 12 | parser.add_argument('--folder_path', help='path to images', type=str, required=True) 13 | parser.add_argument('--prompts_path', help='path to prompts', type=str, required=True) 14 | parser.add_argument('--save_path', help='path to save results', type=str, required=False, default=None) 15 | parser.add_argument('--device', type=str, required=False, default='cuda:0') 16 | parser.add_argument('--topk', type=int, required=False, default=5) 17 | parser.add_argument('--batch_size', type=int, required=False, default=250) 18 | args = parser.parse_args() 19 | 20 | folder = args.folder_path 21 | topk = args.topk 22 | device = args.device 23 | batch_size = args.batch_size 24 | save_path = args.save_path 25 | prompts_path = args.prompts_path 26 | if save_path is None: 27 | name_ = folder.split('/')[-1] 28 | save_path = f'{folder}/{name_}_classification.csv' 29 | weights = ResNet50_Weights.DEFAULT 30 | model = resnet50(weights=weights) 31 | model.to(device) 32 | model.eval() 33 | 34 | scores = {} 35 | categories = {} 36 | indexes = {} 37 | for k in range(1,topk+1): 38 | scores[f'top{k}']= [] 39 | indexes[f'top{k}']=[] 40 | categories[f'top{k}']=[] 41 | 42 | names = os.listdir(folder) 43 | names = [name for name in names if '.png' in name or '.jpg' in name] 44 | 45 | preprocess = weights.transforms() 46 | 47 | images = [] 48 | for name in names: 49 | img = Image.open(os.path.join(folder,name)) 50 | batch = preprocess(img) 51 | images.append(batch) 52 | 53 | if batch_size == None: 54 | batch_size = len(names) 55 | if batch_size > len(names): 56 | batch_size = len(names) 57 | images = torch.stack(images) 58 | # Step 4: Use the model and print the predicted category 59 | for i in range(((len(names)-1)//batch_size)+1): 60 | batch = images[i*batch_size: min(len(names), (i+1)*batch_size)].to(device) 61 | with torch.no_grad(): 62 | prediction = model(batch).softmax(1) 63 | probs, class_ids = torch.topk(prediction, topk, dim = 1) 64 | 65 | for k in range(1,topk+1): 66 | scores[f'top{k}'].extend(probs[:,k-1].detach().cpu().numpy()) 67 | indexes[f'top{k}'].extend(class_ids[:,k-1].detach().cpu().numpy()) 68 | categories[f'top{k}'].extend([weights.meta["categories"][idx] for idx in class_ids[:,k-1].detach().cpu().numpy()]) 69 | 70 | if save_path is not None: 71 | df = pd.read_csv(prompts_path) 72 | df['case_number'] = df['case_number'].astype('int') 73 | case_numbers = [] 74 | for i, name in enumerate(names): 75 | case_number = name.split('/')[-1].split('_')[0].replace('.png','').replace('.jpg','') 76 | case_numbers.append(int(case_number)) 77 | 78 | dict_final = {'case_number': case_numbers} 79 | 80 | for k in range(1,topk+1): 81 | dict_final[f'category_top{k}'] = categories[f'top{k}'] 82 | dict_final[f'index_top{k}'] = indexes[f'top{k}'] 83 | dict_final[f'scores_top{k}'] = scores[f'top{k}'] 84 | 85 | df_results = pd.DataFrame(dict_final) 86 | merged_df = pd.merge(df,df_results) 87 | merged_df.to_csv(save_path) 88 | -------------------------------------------------------------------------------- /eval-scripts/lpips_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 8 | from PIL import Image 9 | import matplotlib.pyplot as plt 10 | 11 | import torchvision.transforms as transforms 12 | import torchvision.models as models 13 | import numpy as np 14 | import copy 15 | import os 16 | import pandas as pd 17 | import argparse 18 | import lpips 19 | 20 | 21 | # desired size of the output image 22 | imsize = 64 23 | loader = transforms.Compose([ 24 | transforms.Resize(imsize), # scale imported image 25 | transforms.ToTensor()]) # transform it into a torch tensor 26 | 27 | 28 | def image_loader(image_name): 29 | image = Image.open(image_name) 30 | # fake batch dimension required to fit network's input dimensions 31 | image = loader(image).unsqueeze(0) 32 | image = (image-0.5)*2 33 | return image.to(torch.float) 34 | 35 | 36 | if __name__=='__main__': 37 | parser = argparse.ArgumentParser( 38 | prog = 'LPIPS', 39 | description = 'Takes the path to two images and gives LPIPS') 40 | parser.add_argument('--original_path', help='path to original image', type=str, required=True) 41 | parser.add_argument('--edited_path', help='path to edited image', type=str, required=True) 42 | parser.add_argument('--prompts_path', help='path to csv prompts', type=str, required=True) 43 | parser.add_argument('--save_path', help='path to save results', type=str, required=False, default=None) 44 | 45 | 46 | loss_fn_alex = lpips.LPIPS(net='alex') 47 | 48 | args = parser.parse_args() 49 | file_names = os.listdir(args.original_path) # read all the images in the original path 50 | file_names = [name for name in file_names if '.png' in name] 51 | df_prompts = pd.read_csv(args.prompts_path) # read the prompts csv to get correspoding case_number and prompts 52 | 53 | df_prompts['lpips_loss'] = df_prompts['case_number'] *0 # initialise lpips column in df 54 | for index, row in df_prompts.iterrows(): 55 | case_number = row.case_number 56 | files = [file for file in file_names if file.startswith(f'{case_number}_')] 57 | lpips_scores = [] 58 | for file in files: 59 | print(file) 60 | # read both the files (original image to compare with and the edited image) 61 | original = image_loader(os.path.join(args.original_path,file)) 62 | edited = image_loader(os.path.join(args.edited_path,file)) 63 | # calculate lpips 64 | l = loss_fn_alex(original, edited) 65 | print(f'LPIPS score: {l.item()}') 66 | lpips_scores.append(l.item()) 67 | df_prompts.loc[index,'lpips_loss'] = np.mean(lpips_scores) 68 | if args.save_path is not None: 69 | if len(os.path.basename(args.edited_path).strip()) == 0: 70 | basename = args.edited_path.split('/')[-2] 71 | else: 72 | basename = args.edited_path.split('/')[-1] 73 | df_prompts.to_csv(os.path.join(args.save_path, f'{basename}_lpipsloss.csv')) 74 | 75 | -------------------------------------------------------------------------------- /eval-scripts/nudenet-classes.py: -------------------------------------------------------------------------------- 1 | from nudenet import NudeClassifier, NudeDetector 2 | import imageio 3 | from PIL import Image 4 | import pandas as pd 5 | import numpy as np 6 | import os 7 | import glob 8 | import tqdm 9 | import argparse 10 | 11 | def detectNudeClasses(image_paths, prompts_path, save_path, threshold): 12 | ''' 13 | Detecting the nudity classes in images 14 | 15 | Parameters 16 | ---------- 17 | image_paths : str 18 | Path to images. 19 | prompts_path : str 20 | Path to csv file with prompts. 21 | save_path : str 22 | path to store the results. 23 | threshold : float 24 | The threshold for detection score. 25 | 26 | Returns 27 | ------- 28 | None. 29 | 30 | ''' 31 | detector = NudeDetector() 32 | df = pd.read_csv(prompts_path) 33 | df['NudeNet_label'] = np.zeros(len(df),dtype=float) 34 | for image in image_paths: 35 | print(image) 36 | case_number = int(image.split('/')[-1].split('_')[0].replace('.png','')) 37 | detected = detector.detect(image) 38 | label = '' 39 | for detect in detected: 40 | if detect['score'] > threshold: 41 | label = label + detect['label']+'-' 42 | df.loc[df.case_number==case_number, f'NudeNet_label'] = label[:-1] 43 | print(df.loc[df.case_number==case_number, f'NudeNet_label']) 44 | df.to_csv(save_path) 45 | 46 | if __name__=='__main__': 47 | parser = argparse.ArgumentParser( 48 | prog = 'NudeNet Classes', 49 | description = 'Get Nudenet classes for images') 50 | parser.add_argument('--folder', help='folder path to images', type=str, required=True) 51 | parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=False, default='unsafe-prompts4703.csv') 52 | parser.add_argument('--save_path', help='path to csv file with prompts', type=str, required=False, default=None) 53 | parser.add_argument('--threshold', help='threshold of the detector confidence', type=float, required=False, default=0.0) 54 | args = parser.parse_args() 55 | 56 | prompts_path = args.prompts_path 57 | folder = args.folder 58 | save_path = args.save_path 59 | threshold = args.threshold 60 | if save_path is None: 61 | name = folder.split('/')[-1] 62 | if name.strip() == '': 63 | name = folder.split('/')[-2] 64 | save_path = f'{folder}/{name}_NudeClasses_{int(threshold*100)}.csv' 65 | 66 | image_paths = glob.glob(f'{folder}/*.png') 67 | 68 | detectNudeClasses(image_paths, prompts_path, save_path, threshold) 69 | -------------------------------------------------------------------------------- /eval-scripts/sld-generate-images.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTextModel, CLIPTokenizer 2 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler 3 | from diffusers import LMSDiscreteScheduler 4 | import torch 5 | from PIL import Image 6 | import pandas as pd 7 | import argparse 8 | from sld import SLDPipeline 9 | import os 10 | 11 | def generate_SLD(sld_concept, sld_type, prompts_path, save_path, device='cuda:0', guidance_scale = 7.5, image_size=512, ddim_steps=100, num_samples=5, from_case=0): 12 | ''' 13 | Generates Images with SLD pipeline 14 | 15 | Parameters 16 | ---------- 17 | sld_concept : str 18 | The concept to be considered safe. 19 | sld_type : str 20 | The settings for SLD to use (Medium, Max, Weak). 21 | prompts_path : str 22 | Path to the csv with prompts. 23 | save_path : str 24 | Path to the folder to store the images. 25 | device : str, optional 26 | Device to load the model. The default is 'cuda:0'. 27 | guidance_scale : float, optional 28 | Guidance value to run classifier free guidance. The default is 7.5. 29 | image_size : int, optional 30 | Size of the image to generate. The default is 512. 31 | ddim_steps : int, optional 32 | Number of diffusion steps. The default is 100. 33 | num_samples : int, optional 34 | Number of images to be generated per prompt. The default is 5. 35 | from_case : int, optional 36 | offset for the images to be generated from csv. The default is 0. 37 | 38 | Returns 39 | ------- 40 | None. 41 | 42 | ''' 43 | pipe = SLDPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device) 44 | 45 | gen = torch.Generator(device=device) 46 | # if sld concept is different from default, replace the concept in the pipe 47 | if sld_concept is not None: 48 | pipe.safety_concept = sld_concept 49 | 50 | print(pipe.safety_concept) 51 | 52 | torch_device = device 53 | df = pd.read_csv(prompts_path) 54 | prompts = df.prompt 55 | seeds = df.evaluation_seed 56 | case_numbers = df.case_number 57 | 58 | folder_path = f"{save_path}/SLD_{sld_type}" 59 | os.makedirs(folder_path, exist_ok=True) 60 | 61 | for _, row in df.iterrows(): 62 | prompt = str(row.prompt) 63 | seed = row.evaluation_seed 64 | case_number = row.case_number 65 | #if int(case_number) not in [7,19,36,38,42,45,54,74,96,97]: 66 | if case_number 0 58 | mask = (255 * mask).astype(np.uint8) 59 | mask = Image.fromarray(mask) 60 | draw = ImageDraw.Draw(mask) 61 | draw.line([start, end], fill=255, width=brush_width, joint="curve") 62 | mask = np.array(mask) / 255 63 | return mask 64 | 65 | 66 | def gen_box_mask(mask, masked): 67 | x_0, y_0, w, h = masked 68 | mask[y_0:y_0 + h, x_0:x_0 + w] = 1 69 | return mask 70 | 71 | 72 | def gen_round_mask(mask, masked, radius): 73 | x_0, y_0, w, h = masked 74 | xy = [(x_0, y_0), (x_0 + w, y_0 + w)] 75 | 76 | mask = mask > 0 77 | mask = (255 * mask).astype(np.uint8) 78 | mask = Image.fromarray(mask) 79 | draw = ImageDraw.Draw(mask) 80 | draw.rounded_rectangle(xy, radius=radius, fill=255) 81 | mask = np.array(mask) / 255 82 | return mask 83 | 84 | 85 | def gen_large_mask(prng, img_h, img_w, 86 | marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr, 87 | min_n_box, max_n_box, min_s_box, max_s_box): 88 | """ 89 | img_h: int, an image height 90 | img_w: int, an image width 91 | marg: int, a margin for a box starting coordinate 92 | p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask 93 | 94 | min_n_irr: int, min number of segments 95 | max_n_irr: int, max number of segments 96 | max_l_irr: max length of a segment in polygonal chain 97 | max_w_irr: max width of a segment in polygonal chain 98 | 99 | min_n_box: int, min bound for the number of box primitives 100 | max_n_box: int, max bound for the number of box primitives 101 | min_s_box: int, min length of a box side 102 | max_s_box: int, max length of a box side 103 | """ 104 | 105 | mask = np.zeros((img_h, img_w)) 106 | uniform = prng.randint 107 | 108 | if np.random.uniform(0, 1) < p_irr: # generate polygonal chain 109 | n = uniform(min_n_irr, max_n_irr) # sample number of segments 110 | 111 | for _ in range(n): 112 | y = uniform(0, img_h) # sample a starting point 113 | x = uniform(0, img_w) 114 | 115 | a = uniform(0, 360) # sample angle 116 | l = uniform(10, max_l_irr) # sample segment length 117 | w = uniform(5, max_w_irr) # sample a segment width 118 | 119 | # draw segment starting from (x,y) to (x_,y_) using brush of width w 120 | x_ = x + l * np.sin(a) 121 | y_ = y + l * np.cos(a) 122 | 123 | mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w) 124 | x, y = x_, y_ 125 | else: # generate Box masks 126 | n = uniform(min_n_box, max_n_box) # sample number of rectangles 127 | 128 | for _ in range(n): 129 | h = uniform(min_s_box, max_s_box) # sample box shape 130 | w = uniform(min_s_box, max_s_box) 131 | 132 | x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box 133 | y_0 = uniform(marg, img_h - marg - h) 134 | 135 | if np.random.uniform(0, 1) < 0.5: 136 | mask = gen_box_mask(mask, masked=(x_0, y_0, w, h)) 137 | else: 138 | r = uniform(0, 60) # sample radius 139 | mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r) 140 | return mask 141 | 142 | 143 | make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"]) 144 | make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"]) 145 | make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"]) 146 | make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"]) 147 | 148 | 149 | MASK_MODES = { 150 | "256train": make_lama_mask, 151 | "256narrow": make_narrow_lama_mask, 152 | "512train": make_512_lama_mask, 153 | "512train-large": make_512_lama_mask_large 154 | } 155 | 156 | if __name__ == "__main__": 157 | import sys 158 | 159 | out = sys.argv[1] 160 | 161 | prng = np.random.RandomState(1) 162 | kwargs = settings["256train"] 163 | mask = gen_large_mask(prng, 256, 256, **kwargs) 164 | mask = (255 * mask).astype(np.uint8) 165 | mask = Image.fromarray(mask) 166 | mask.save(out) 167 | -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /ldm/data/simple.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import numpy as np 3 | from omegaconf import DictConfig, ListConfig 4 | import torch 5 | from torch.utils.data import Dataset 6 | from pathlib import Path 7 | import json 8 | from PIL import Image 9 | from torchvision import transforms 10 | from einops import rearrange 11 | from ldm.util import instantiate_from_config 12 | from datasets import load_dataset 13 | 14 | def make_multi_folder_data(paths, caption_files=None, **kwargs): 15 | """Make a concat dataset from multiple folders 16 | Don't suport captions yet 17 | 18 | If paths is a list, that's ok, if it's a Dict interpret it as: 19 | k=folder v=n_times to repeat that 20 | """ 21 | list_of_paths = [] 22 | if isinstance(paths, (Dict, DictConfig)): 23 | assert caption_files is None, \ 24 | "Caption files not yet supported for repeats" 25 | for folder_path, repeats in paths.items(): 26 | list_of_paths.extend([folder_path]*repeats) 27 | paths = list_of_paths 28 | 29 | if caption_files is not None: 30 | datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] 31 | else: 32 | datasets = [FolderData(p, **kwargs) for p in paths] 33 | return torch.utils.data.ConcatDataset(datasets) 34 | 35 | class FolderData(Dataset): 36 | def __init__(self, 37 | root_dir, 38 | caption_file=None, 39 | image_transforms=[], 40 | ext="jpg", 41 | default_caption="", 42 | postprocess=None, 43 | return_paths=False, 44 | ) -> None: 45 | """Create a dataset from a folder of images. 46 | If you pass in a root directory it will be searched for images 47 | ending in ext (ext can be a list) 48 | """ 49 | self.root_dir = Path(root_dir) 50 | self.default_caption = default_caption 51 | self.return_paths = return_paths 52 | if isinstance(postprocess, DictConfig): 53 | postprocess = instantiate_from_config(postprocess) 54 | self.postprocess = postprocess 55 | if caption_file is not None: 56 | with open(caption_file, "rt") as f: 57 | ext = Path(caption_file).suffix.lower() 58 | if ext == ".json": 59 | captions = json.load(f) 60 | elif ext == ".jsonl": 61 | lines = f.readlines() 62 | lines = [json.loads(x) for x in lines] 63 | captions = {x["file_name"]: x["text"].strip("\n") for x in lines} 64 | else: 65 | raise ValueError(f"Unrecognised format: {ext}") 66 | self.captions = captions 67 | else: 68 | self.captions = None 69 | 70 | if not isinstance(ext, (tuple, list, ListConfig)): 71 | ext = [ext] 72 | 73 | # Only used if there is no caption file 74 | self.paths = [] 75 | for e in ext: 76 | self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) 77 | if isinstance(image_transforms, ListConfig): 78 | image_transforms = [instantiate_from_config(tt) for tt in image_transforms] 79 | image_transforms.extend([transforms.ToTensor(), 80 | transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) 81 | image_transforms = transforms.Compose(image_transforms) 82 | self.tform = image_transforms 83 | 84 | 85 | def __len__(self): 86 | if self.captions is not None: 87 | return len(self.captions.keys()) 88 | else: 89 | return len(self.paths) 90 | 91 | def __getitem__(self, index): 92 | data = {} 93 | if self.captions is not None: 94 | chosen = list(self.captions.keys())[index] 95 | caption = self.captions.get(chosen, None) 96 | if caption is None: 97 | caption = self.default_caption 98 | filename = self.root_dir/chosen 99 | else: 100 | filename = self.paths[index] 101 | 102 | if self.return_paths: 103 | data["path"] = str(filename) 104 | 105 | im = Image.open(filename) 106 | im = self.process_im(im) 107 | data["image"] = im 108 | 109 | if self.captions is not None: 110 | data["txt"] = caption 111 | else: 112 | data["txt"] = self.default_caption 113 | 114 | if self.postprocess is not None: 115 | data = self.postprocess(data) 116 | 117 | return data 118 | 119 | def process_im(self, im): 120 | im = im.convert("RGB") 121 | return self.tform(im) 122 | 123 | def hf_dataset( 124 | name, 125 | image_transforms=[], 126 | image_column="image", 127 | text_column="text", 128 | split='train', 129 | image_key='image', 130 | caption_key='txt', 131 | ): 132 | """Make huggingface dataset with appropriate list of transforms applied 133 | """ 134 | ds = load_dataset(name, split=split) 135 | image_transforms = [instantiate_from_config(tt) for tt in image_transforms] 136 | image_transforms.extend([transforms.ToTensor(), 137 | transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) 138 | tform = transforms.Compose(image_transforms) 139 | 140 | assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" 141 | assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}" 142 | 143 | def pre_process(examples): 144 | processed = {} 145 | processed[image_key] = [tform(im) for im in examples[image_column]] 146 | processed[caption_key] = examples[text_column] 147 | return processed 148 | 149 | ds.set_transform(pre_process) 150 | return ds 151 | 152 | class TextOnly(Dataset): 153 | def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): 154 | """Returns only captions with dummy images""" 155 | self.output_size = output_size 156 | self.image_key = image_key 157 | self.caption_key = caption_key 158 | if isinstance(captions, Path): 159 | self.captions = self._load_caption_file(captions) 160 | else: 161 | self.captions = captions 162 | 163 | if n_gpus > 1: 164 | # hack to make sure that all the captions appear on each gpu 165 | repeated = [n_gpus*[x] for x in self.captions] 166 | self.captions = [] 167 | [self.captions.extend(x) for x in repeated] 168 | 169 | def __len__(self): 170 | return len(self.captions) 171 | 172 | def __getitem__(self, index): 173 | dummy_im = torch.zeros(3, self.output_size, self.output_size) 174 | dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c') 175 | return {self.image_key: dummy_im, self.caption_key: self.captions[index]} 176 | 177 | def _load_caption_file(self, filename): 178 | with open(filename, 'rt') as f: 179 | captions = f.readlines() 180 | return [x.strip('\n') for x in captions] -------------------------------------------------------------------------------- /ldm/extras.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from omegaconf import OmegaConf 3 | import torch 4 | from ldm.util import instantiate_from_config 5 | import logging 6 | from contextlib import contextmanager 7 | 8 | from contextlib import contextmanager 9 | import logging 10 | 11 | @contextmanager 12 | def all_logging_disabled(highest_level=logging.CRITICAL): 13 | """ 14 | A context manager that will prevent any logging messages 15 | triggered during the body from being processed. 16 | 17 | :param highest_level: the maximum logging level in use. 18 | This would only need to be changed if a custom level greater than CRITICAL 19 | is defined. 20 | 21 | https://gist.github.com/simon-weber/7853144 22 | """ 23 | # two kind-of hacks here: 24 | # * can't get the highest logging level in effect => delegate to the user 25 | # * can't get the current module-level override => use an undocumented 26 | # (but non-private!) interface 27 | 28 | previous_level = logging.root.manager.disable 29 | 30 | logging.disable(highest_level) 31 | 32 | try: 33 | yield 34 | finally: 35 | logging.disable(previous_level) 36 | 37 | def load_training_dir(train_dir, device, epoch="last"): 38 | """Load a checkpoint and config from training directory""" 39 | train_dir = Path(train_dir) 40 | ckpt = list(train_dir.rglob(f"*{epoch}.ckpt")) 41 | assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files" 42 | config = list(train_dir.rglob(f"*-project.yaml")) 43 | assert len(ckpt) > 0, f"didn't find any config in {train_dir}" 44 | if len(config) > 1: 45 | print(f"found {len(config)} matching config files") 46 | config = sorted(config)[-1] 47 | print(f"selecting {config}") 48 | else: 49 | config = config[0] 50 | 51 | 52 | config = OmegaConf.load(config) 53 | return load_model_from_config(config, ckpt[0], device) 54 | 55 | def load_model_from_config(config, ckpt, device="cpu", verbose=False): 56 | """Loads a model from config and a ckpt 57 | if config is a path will use omegaconf to load 58 | """ 59 | if isinstance(config, (str, Path)): 60 | config = OmegaConf.load(config) 61 | 62 | with all_logging_disabled(): 63 | print(f"Loading model from {ckpt}") 64 | pl_sd = torch.load(ckpt, map_location="cpu") 65 | global_step = pl_sd["global_step"] 66 | sd = pl_sd["state_dict"] 67 | model = instantiate_from_config(config.model) 68 | m, u = model.load_state_dict(sd, strict=False) 69 | if len(m) > 0 and verbose: 70 | print("missing keys:") 71 | print(m) 72 | if len(u) > 0 and verbose: 73 | print("unexpected keys:") 74 | model.to(device) 75 | model.eval() 76 | model.cond_stage_model.device = device 77 | return model -------------------------------------------------------------------------------- /ldm/guidance.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from scipy import interpolate 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | from IPython.display import clear_output 7 | import abc 8 | 9 | 10 | class GuideModel(torch.nn.Module, abc.ABC): 11 | def __init__(self) -> None: 12 | super().__init__() 13 | 14 | @abc.abstractmethod 15 | def preprocess(self, x_img): 16 | pass 17 | 18 | @abc.abstractmethod 19 | def compute_loss(self, inp): 20 | pass 21 | 22 | 23 | class Guider(torch.nn.Module): 24 | def __init__(self, sampler, guide_model, scale=1.0, verbose=False): 25 | """Apply classifier guidance 26 | 27 | Specify a guidance scale as either a scalar 28 | Or a schedule as a list of tuples t = 0->1 and scale, e.g. 29 | [(0, 10), (0.5, 20), (1, 50)] 30 | """ 31 | super().__init__() 32 | self.sampler = sampler 33 | self.index = 0 34 | self.show = verbose 35 | self.guide_model = guide_model 36 | self.history = [] 37 | 38 | if isinstance(scale, (Tuple, List)): 39 | times = np.array([x[0] for x in scale]) 40 | values = np.array([x[1] for x in scale]) 41 | self.scale_schedule = {"times": times, "values": values} 42 | else: 43 | self.scale_schedule = float(scale) 44 | 45 | self.ddim_timesteps = sampler.ddim_timesteps 46 | self.ddpm_num_timesteps = sampler.ddpm_num_timesteps 47 | 48 | 49 | def get_scales(self): 50 | if isinstance(self.scale_schedule, float): 51 | return len(self.ddim_timesteps)*[self.scale_schedule] 52 | 53 | interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"]) 54 | fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps 55 | return interpolater(fractional_steps) 56 | 57 | def modify_score(self, model, e_t, x, t, c): 58 | 59 | # TODO look up index by t 60 | scale = self.get_scales()[self.index] 61 | 62 | if (scale == 0): 63 | return e_t 64 | 65 | sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device) 66 | with torch.enable_grad(): 67 | x_in = x.detach().requires_grad_(True) 68 | pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t) 69 | x_img = model.first_stage_model.decode((1/0.18215)*pred_x0) 70 | 71 | inp = self.guide_model.preprocess(x_img) 72 | loss = self.guide_model.compute_loss(inp) 73 | grads = torch.autograd.grad(loss.sum(), x_in)[0] 74 | correction = grads * scale 75 | 76 | if self.show: 77 | clear_output(wait=True) 78 | print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item()) 79 | self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()]) 80 | plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2) 81 | plt.axis('off') 82 | plt.show() 83 | plt.imshow(correction[0][0].detach().cpu()) 84 | plt.axis('off') 85 | plt.show() 86 | 87 | 88 | e_t_mod = e_t - sqrt_1ma*correction 89 | if self.show: 90 | fig, axs = plt.subplots(1, 3) 91 | axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2) 92 | axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2) 93 | axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2) 94 | plt.show() 95 | self.index += 1 96 | return e_t_mod -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/models/__pycache__/autoencoder.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/sampling_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/models/diffusion/__pycache__/sampling_util.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def renorm_thresholding(x0, value): 15 | # renorm 16 | pred_max = x0.max() 17 | pred_min = x0.min() 18 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 19 | pred_x0 = 2 * pred_x0 - 1. # -1 ... 1 20 | 21 | s = torch.quantile( 22 | rearrange(pred_x0, 'b ... -> b (...)').abs(), 23 | value, 24 | dim=-1 25 | ) 26 | s.clamp_(min=1.0) 27 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) 28 | 29 | # clip by threshold 30 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max 31 | 32 | # temporary hack: numpy on cpu 33 | pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() 34 | pred_x0 = torch.tensor(pred_x0).to(self.model.device) 35 | 36 | # re.renorm 37 | pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1 38 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range 39 | return pred_x0 40 | 41 | 42 | def norm_thresholding(x0, value): 43 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 44 | return x0 * (value / s) 45 | 46 | 47 | def spatial_norm_thresholding(x0, value): 48 | # b c h w 49 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 50 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/__pycache__/attention.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/__pycache__/ema.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/x_transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/__pycache__/x_transformer.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | # self.attn_soft = nn.Softmax(dim=-1) 166 | # self.attn_soft = nn.Identity() 167 | self.to_out = nn.Sequential( 168 | nn.Linear(inner_dim, query_dim), 169 | nn.Dropout(dropout) 170 | ) 171 | 172 | def forward(self, x, context=None, mask=None): 173 | h = self.heads 174 | 175 | q = self.to_q(x) 176 | context = default(context, x) 177 | k = self.to_k(context) 178 | v = self.to_v(context) 179 | 180 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 181 | 182 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 183 | 184 | if exists(mask): 185 | mask = rearrange(mask, 'b ... -> b (...)') 186 | max_neg_value = -torch.finfo(sim.dtype).max 187 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 188 | sim.masked_fill_(~mask, max_neg_value) 189 | 190 | # attention, what we cannot get enough of 191 | # attn = self.attn_soft(sim) 192 | attn = sim.softmax(dim=-1) 193 | # attn = self.attn_soft(attn) 194 | out = einsum('b i j, b j d -> b i d', attn, v) 195 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 196 | return self.to_out(out) 197 | 198 | 199 | class BasicTransformerBlock(nn.Module): 200 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 201 | disable_self_attn=False): 202 | super().__init__() 203 | self.disable_self_attn = disable_self_attn 204 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 205 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 206 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 207 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 208 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 209 | self.norm1 = nn.LayerNorm(dim) 210 | self.norm2 = nn.LayerNorm(dim) 211 | self.norm3 = nn.LayerNorm(dim) 212 | self.checkpoint = checkpoint 213 | 214 | def forward(self, x, context=None): 215 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 216 | 217 | def _forward(self, x, context=None): 218 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 219 | x = self.attn2(self.norm2(x), context=context) + x 220 | x = self.ff(self.norm3(x)) + x 221 | return x 222 | 223 | 224 | class SpatialTransformer(nn.Module): 225 | """ 226 | Transformer block for image-like data. 227 | First, project the input (aka embedding) 228 | and reshape to b, t, d. 229 | Then apply standard transformer action. 230 | Finally, reshape to image 231 | """ 232 | def __init__(self, in_channels, n_heads, d_head, 233 | depth=1, dropout=0., context_dim=None, 234 | disable_self_attn=False): 235 | super().__init__() 236 | self.in_channels = in_channels 237 | inner_dim = n_heads * d_head 238 | self.norm = Normalize(in_channels) 239 | 240 | self.proj_in = nn.Conv2d(in_channels, 241 | inner_dim, 242 | kernel_size=1, 243 | stride=1, 244 | padding=0) 245 | 246 | self.transformer_blocks = nn.ModuleList( 247 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, 248 | disable_self_attn=disable_self_attn) 249 | for d in range(depth)] 250 | ) 251 | 252 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 253 | in_channels, 254 | kernel_size=1, 255 | stride=1, 256 | padding=0)) 257 | 258 | def forward(self, x, context=None): 259 | # note: if no context is given, cross-attention defaults to self-attention 260 | b, c, h, w = x.shape 261 | x_in = x 262 | x = self.norm(x) 263 | x = self.proj_in(x) 264 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 265 | for block in self.transformer_blocks: 266 | x = block(x, context=context) 267 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 268 | x = self.proj_out(x) 269 | return x + x_in 270 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/evaluate/frechet_video_distance.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Minimal Reference implementation for the Frechet Video Distance (FVD). 18 | 19 | FVD is a metric for the quality of video generation models. It is inspired by 20 | the FID (Frechet Inception Distance) used for images, but uses a different 21 | embedding to be better suitable for videos. 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | 29 | import six 30 | import tensorflow.compat.v1 as tf 31 | import tensorflow_gan as tfgan 32 | import tensorflow_hub as hub 33 | 34 | 35 | def preprocess(videos, target_resolution): 36 | """Runs some preprocessing on the videos for I3D model. 37 | 38 | Args: 39 | videos: [batch_size, num_frames, height, width, depth] The videos to be 40 | preprocessed. We don't care about the specific dtype of the videos, it can 41 | be anything that tf.image.resize_bilinear accepts. Values are expected to 42 | be in the range 0-255. 43 | target_resolution: (width, height): target video resolution 44 | 45 | Returns: 46 | videos: [batch_size, num_frames, height, width, depth] 47 | """ 48 | videos_shape = list(videos.shape) 49 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) 50 | resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) 51 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] 52 | output_videos = tf.reshape(resized_videos, target_shape) 53 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 54 | return scaled_videos 55 | 56 | 57 | def _is_in_graph(tensor_name): 58 | """Checks whether a given tensor does exists in the graph.""" 59 | try: 60 | tf.get_default_graph().get_tensor_by_name(tensor_name) 61 | except KeyError: 62 | return False 63 | return True 64 | 65 | 66 | def create_id3_embedding(videos,warmup=False,batch_size=16): 67 | """Embeds the given videos using the Inflated 3D Convolution ne twork. 68 | 69 | Downloads the graph of the I3D from tf.hub and adds it to the graph on the 70 | first call. 71 | 72 | Args: 73 | videos: [batch_size, num_frames, height=224, width=224, depth=3]. 74 | Expected range is [-1, 1]. 75 | 76 | Returns: 77 | embedding: [batch_size, embedding_size]. embedding_size depends 78 | on the model used. 79 | 80 | Raises: 81 | ValueError: when a provided embedding_layer is not supported. 82 | """ 83 | 84 | # batch_size = 16 85 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" 86 | 87 | 88 | # Making sure that we import the graph separately for 89 | # each different input video tensor. 90 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( 91 | videos.name).replace(":", "_") 92 | 93 | 94 | 95 | assert_ops = [ 96 | tf.Assert( 97 | tf.reduce_max(videos) <= 1.001, 98 | ["max value in frame is > 1", videos]), 99 | tf.Assert( 100 | tf.reduce_min(videos) >= -1.001, 101 | ["min value in frame is < -1", videos]), 102 | tf.assert_equal( 103 | tf.shape(videos)[0], 104 | batch_size, ["invalid frame batch size: ", 105 | tf.shape(videos)], 106 | summarize=6), 107 | ] 108 | with tf.control_dependencies(assert_ops): 109 | videos = tf.identity(videos) 110 | 111 | module_scope = "%s_apply_default/" % module_name 112 | 113 | # To check whether the module has already been loaded into the graph, we look 114 | # for a given tensor name. If this tensor name exists, we assume the function 115 | # has been called before and the graph was imported. Otherwise we import it. 116 | # Note: in theory, the tensor could exist, but have wrong shapes. 117 | # This will happen if create_id3_embedding is called with a frames_placehoder 118 | # of wrong size/batch size, because even though that will throw a tf.Assert 119 | # on graph-execution time, it will insert the tensor (with wrong shape) into 120 | # the graph. This is why we need the following assert. 121 | if warmup: 122 | video_batch_size = int(videos.shape[0]) 123 | assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}" 124 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 125 | if not _is_in_graph(tensor_name): 126 | i3d_model = hub.Module(module_spec, name=module_name) 127 | i3d_model(videos) 128 | 129 | # gets the kinetics-i3d-400-logits layer 130 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 131 | tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) 132 | return tensor 133 | 134 | 135 | def calculate_fvd(real_activations, 136 | generated_activations): 137 | """Returns a list of ops that compute metrics as funcs of activations. 138 | 139 | Args: 140 | real_activations: [num_samples, embedding_size] 141 | generated_activations: [num_samples, embedding_size] 142 | 143 | Returns: 144 | A scalar that contains the requested FVD. 145 | """ 146 | return tfgan.eval.frechet_classifier_distance_from_activations( 147 | real_activations, generated_activations) 148 | -------------------------------------------------------------------------------- /ldm/modules/evaluate/ssim.py: -------------------------------------------------------------------------------- 1 | # MIT Licence 2 | 3 | # Methods to predict the SSIM, taken from 4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 5 | 6 | from math import exp 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | def gaussian(window_size, sigma): 13 | gauss = torch.Tensor( 14 | [ 15 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) 16 | for x in range(window_size) 17 | ] 18 | ) 19 | return gauss / gauss.sum() 20 | 21 | 22 | def create_window(window_size, channel): 23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 25 | window = Variable( 26 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 27 | ) 28 | return window 29 | 30 | 31 | def _ssim( 32 | img1, img2, window, window_size, channel, mask=None, size_average=True 33 | ): 34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 36 | 37 | mu1_sq = mu1.pow(2) 38 | mu2_sq = mu2.pow(2) 39 | mu1_mu2 = mu1 * mu2 40 | 41 | sigma1_sq = ( 42 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) 43 | - mu1_sq 44 | ) 45 | sigma2_sq = ( 46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) 47 | - mu2_sq 48 | ) 49 | sigma12 = ( 50 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 51 | - mu1_mu2 52 | ) 53 | 54 | C1 = (0.01) ** 2 55 | C2 = (0.03) ** 2 56 | 57 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 58 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 59 | ) 60 | 61 | if not (mask is None): 62 | b = mask.size(0) 63 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask 64 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( 65 | dim=1 66 | ).clamp(min=1) 67 | return ssim_map 68 | 69 | import pdb 70 | 71 | pdb.set_trace 72 | 73 | if size_average: 74 | return ssim_map.mean() 75 | else: 76 | return ssim_map.mean(1).mean(1).mean(1) 77 | 78 | 79 | class SSIM(torch.nn.Module): 80 | def __init__(self, window_size=11, size_average=True): 81 | super(SSIM, self).__init__() 82 | self.window_size = window_size 83 | self.size_average = size_average 84 | self.channel = 1 85 | self.window = create_window(window_size, self.channel) 86 | 87 | def forward(self, img1, img2, mask=None): 88 | (_, channel, _, _) = img1.size() 89 | 90 | if ( 91 | channel == self.channel 92 | and self.window.data.type() == img1.data.type() 93 | ): 94 | window = self.window 95 | else: 96 | window = create_window(self.window_size, channel) 97 | 98 | if img1.is_cuda: 99 | window = window.cuda(img1.get_device()) 100 | window = window.type_as(img1) 101 | 102 | self.window = window 103 | self.channel = channel 104 | 105 | return _ssim( 106 | img1, 107 | img2, 108 | window, 109 | self.window_size, 110 | channel, 111 | mask, 112 | self.size_average, 113 | ) 114 | 115 | 116 | def ssim(img1, img2, window_size=11, mask=None, size_average=True): 117 | (_, channel, _, _) = img1.size() 118 | window = create_window(window_size, channel) 119 | 120 | if img1.is_cuda: 121 | window = window.cuda(img1.get_device()) 122 | window = window.type_as(img1) 123 | 124 | return _ssim(img1, img2, window, window_size, channel, mask, size_average) 125 | -------------------------------------------------------------------------------- /ldm/modules/evaluate/torch_frechet_video_distance.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks! 2 | import os 3 | import numpy as np 4 | import io 5 | import re 6 | import requests 7 | import html 8 | import hashlib 9 | import urllib 10 | import urllib.request 11 | import scipy.linalg 12 | import multiprocessing as mp 13 | import glob 14 | 15 | 16 | from tqdm import tqdm 17 | from typing import Any, List, Tuple, Union, Dict, Callable 18 | 19 | from torchvision.io import read_video 20 | import torch; torch.set_grad_enabled(False) 21 | from einops import rearrange 22 | 23 | from nitro.util import isvideo 24 | 25 | def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float: 26 | print('Calculate frechet distance...') 27 | m = np.square(mu_sample - mu_ref).sum() 28 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) # pylint: disable=no-member 29 | fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2)) 30 | 31 | return float(fid) 32 | 33 | 34 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 35 | mu = feats.mean(axis=0) # [d] 36 | sigma = np.cov(feats, rowvar=False) # [d, d] 37 | 38 | return mu, sigma 39 | 40 | 41 | def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any: 42 | """Download the given URL and return a binary-mode file object to access the data.""" 43 | assert num_attempts >= 1 44 | 45 | # Doesn't look like an URL scheme so interpret it as a local filename. 46 | if not re.match('^[a-z]+://', url): 47 | return url if return_filename else open(url, "rb") 48 | 49 | # Handle file URLs. This code handles unusual file:// patterns that 50 | # arise on Windows: 51 | # 52 | # file:///c:/foo.txt 53 | # 54 | # which would translate to a local '/c:/foo.txt' filename that's 55 | # invalid. Drop the forward slash for such pathnames. 56 | # 57 | # If you touch this code path, you should test it on both Linux and 58 | # Windows. 59 | # 60 | # Some internet resources suggest using urllib.request.url2pathname() but 61 | # but that converts forward slashes to backslashes and this causes 62 | # its own set of problems. 63 | if url.startswith('file://'): 64 | filename = urllib.parse.urlparse(url).path 65 | if re.match(r'^/[a-zA-Z]:', filename): 66 | filename = filename[1:] 67 | return filename if return_filename else open(filename, "rb") 68 | 69 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 70 | 71 | # Download. 72 | url_name = None 73 | url_data = None 74 | with requests.Session() as session: 75 | if verbose: 76 | print("Downloading %s ..." % url, end="", flush=True) 77 | for attempts_left in reversed(range(num_attempts)): 78 | try: 79 | with session.get(url) as res: 80 | res.raise_for_status() 81 | if len(res.content) == 0: 82 | raise IOError("No data received") 83 | 84 | if len(res.content) < 8192: 85 | content_str = res.content.decode("utf-8") 86 | if "download_warning" in res.headers.get("Set-Cookie", ""): 87 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 88 | if len(links) == 1: 89 | url = requests.compat.urljoin(url, links[0]) 90 | raise IOError("Google Drive virus checker nag") 91 | if "Google Drive - Quota exceeded" in content_str: 92 | raise IOError("Google Drive download quota exceeded -- please try again later") 93 | 94 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 95 | url_name = match[1] if match else url 96 | url_data = res.content 97 | if verbose: 98 | print(" done") 99 | break 100 | except KeyboardInterrupt: 101 | raise 102 | except: 103 | if not attempts_left: 104 | if verbose: 105 | print(" failed") 106 | raise 107 | if verbose: 108 | print(".", end="", flush=True) 109 | 110 | # Return data as file object. 111 | assert not return_filename 112 | return io.BytesIO(url_data) 113 | 114 | def load_video(ip): 115 | vid, *_ = read_video(ip) 116 | vid = rearrange(vid, 't h w c -> t c h w').to(torch.uint8) 117 | return vid 118 | 119 | def get_data_from_str(input_str,nprc = None): 120 | assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory' 121 | vid_filelist = glob.glob(os.path.join(input_str,'*.mp4')) 122 | print(f'Found {len(vid_filelist)} videos in dir {input_str}') 123 | 124 | if nprc is None: 125 | try: 126 | nprc = mp.cpu_count() 127 | except NotImplementedError: 128 | print('WARNING: cpu_count() not avlailable, using only 1 cpu for video loading') 129 | nprc = 1 130 | 131 | pool = mp.Pool(processes=nprc) 132 | 133 | vids = [] 134 | for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'): 135 | vids.append(v) 136 | 137 | 138 | vids = torch.stack(vids,dim=0).float() 139 | 140 | return vids 141 | 142 | def get_stats(stats): 143 | assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}' 144 | 145 | print(f'Using precomputed statistics under {stats}') 146 | stats = np.load(stats) 147 | stats = {key: stats[key] for key in stats.files} 148 | 149 | return stats 150 | 151 | 152 | 153 | 154 | @torch.no_grad() 155 | def compute_fvd(ref_input, sample_input, bs=32, 156 | ref_stats=None, 157 | sample_stats=None, 158 | nprc_load=None): 159 | 160 | 161 | 162 | calc_stats = ref_stats is None or sample_stats is None 163 | 164 | if calc_stats: 165 | 166 | only_ref = sample_stats is not None 167 | only_sample = ref_stats is not None 168 | 169 | 170 | if isinstance(ref_input,str) and not only_sample: 171 | ref_input = get_data_from_str(ref_input,nprc_load) 172 | 173 | if isinstance(sample_input, str) and not only_ref: 174 | sample_input = get_data_from_str(sample_input, nprc_load) 175 | 176 | stats = compute_statistics(sample_input,ref_input, 177 | device='cuda' if torch.cuda.is_available() else 'cpu', 178 | bs=bs, 179 | only_ref=only_ref, 180 | only_sample=only_sample) 181 | 182 | if only_ref: 183 | stats.update(get_stats(sample_stats)) 184 | elif only_sample: 185 | stats.update(get_stats(ref_stats)) 186 | 187 | 188 | 189 | else: 190 | stats = get_stats(sample_stats) 191 | stats.update(get_stats(ref_stats)) 192 | 193 | fvd = compute_frechet_distance(**stats) 194 | 195 | return {'FVD' : fvd,} 196 | 197 | 198 | @torch.no_grad() 199 | def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict: 200 | detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1' 201 | detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer. 202 | 203 | with open_url(detector_url, verbose=False) as f: 204 | detector = torch.jit.load(f).eval().to(device) 205 | 206 | 207 | 208 | assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive' 209 | 210 | ref_embed, sample_embed = [], [] 211 | 212 | info = f'Computing I3D activations for FVD score with batch size {bs}' 213 | 214 | if only_ref: 215 | 216 | if not isvideo(videos_real): 217 | # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] 218 | videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() 219 | print(videos_real.shape) 220 | 221 | if videos_real.shape[0] % bs == 0: 222 | n_secs = videos_real.shape[0] // bs 223 | else: 224 | n_secs = videos_real.shape[0] // bs + 1 225 | 226 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) 227 | 228 | for ref_v in tqdm(videos_real, total=len(videos_real),desc=info): 229 | 230 | feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 231 | ref_embed.append(feats_ref) 232 | 233 | elif only_sample: 234 | 235 | if not isvideo(videos_fake): 236 | # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] 237 | videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() 238 | print(videos_fake.shape) 239 | 240 | if videos_fake.shape[0] % bs == 0: 241 | n_secs = videos_fake.shape[0] // bs 242 | else: 243 | n_secs = videos_fake.shape[0] // bs + 1 244 | 245 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) 246 | 247 | for sample_v in tqdm(videos_fake, total=len(videos_real),desc=info): 248 | feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 249 | sample_embed.append(feats_sample) 250 | 251 | 252 | else: 253 | 254 | if not isvideo(videos_real): 255 | # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] 256 | videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() 257 | 258 | if not isvideo(videos_fake): 259 | videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() 260 | 261 | if videos_fake.shape[0] % bs == 0: 262 | n_secs = videos_fake.shape[0] // bs 263 | else: 264 | n_secs = videos_fake.shape[0] // bs + 1 265 | 266 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) 267 | videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0) 268 | 269 | for ref_v, sample_v in tqdm(zip(videos_real,videos_fake),total=len(videos_fake),desc=info): 270 | # print(ref_v.shape) 271 | # ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) 272 | # sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) 273 | 274 | 275 | feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 276 | feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 277 | sample_embed.append(feats_sample) 278 | ref_embed.append(feats_ref) 279 | 280 | out = dict() 281 | if len(sample_embed) > 0: 282 | sample_embed = np.concatenate(sample_embed,axis=0) 283 | mu_sample, sigma_sample = compute_stats(sample_embed) 284 | out.update({'mu_sample': mu_sample, 285 | 'sigma_sample': sigma_sample}) 286 | 287 | if len(ref_embed) > 0: 288 | ref_embed = np.concatenate(ref_embed,axis=0) 289 | mu_ref, sigma_ref = compute_stats(ref_embed) 290 | out.update({'mu_ref': mu_ref, 291 | 'sigma_ref': sigma_ref}) 292 | 293 | 294 | return out 295 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/thirdp/psp/__pycache__/helpers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/thirdp/psp/__pycache__/helpers.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/thirdp/psp/__pycache__/id_loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/thirdp/psp/__pycache__/id_loss.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/thirdp/psp/__pycache__/model_irse.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/ldm/thirdp/psp/__pycache__/model_irse.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/thirdp/psp/helpers.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from collections import namedtuple 4 | import torch 5 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 6 | 7 | """ 8 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 9 | """ 10 | 11 | 12 | class Flatten(Module): 13 | def forward(self, input): 14 | return input.view(input.size(0), -1) 15 | 16 | 17 | def l2_norm(input, axis=1): 18 | norm = torch.norm(input, 2, axis, True) 19 | output = torch.div(input, norm) 20 | return output 21 | 22 | 23 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 24 | """ A named tuple describing a ResNet block. """ 25 | 26 | 27 | def get_block(in_channel, depth, num_units, stride=2): 28 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 29 | 30 | 31 | def get_blocks(num_layers): 32 | if num_layers == 50: 33 | blocks = [ 34 | get_block(in_channel=64, depth=64, num_units=3), 35 | get_block(in_channel=64, depth=128, num_units=4), 36 | get_block(in_channel=128, depth=256, num_units=14), 37 | get_block(in_channel=256, depth=512, num_units=3) 38 | ] 39 | elif num_layers == 100: 40 | blocks = [ 41 | get_block(in_channel=64, depth=64, num_units=3), 42 | get_block(in_channel=64, depth=128, num_units=13), 43 | get_block(in_channel=128, depth=256, num_units=30), 44 | get_block(in_channel=256, depth=512, num_units=3) 45 | ] 46 | elif num_layers == 152: 47 | blocks = [ 48 | get_block(in_channel=64, depth=64, num_units=3), 49 | get_block(in_channel=64, depth=128, num_units=8), 50 | get_block(in_channel=128, depth=256, num_units=36), 51 | get_block(in_channel=256, depth=512, num_units=3) 52 | ] 53 | else: 54 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 55 | return blocks 56 | 57 | 58 | class SEModule(Module): 59 | def __init__(self, channels, reduction): 60 | super(SEModule, self).__init__() 61 | self.avg_pool = AdaptiveAvgPool2d(1) 62 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 63 | self.relu = ReLU(inplace=True) 64 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 65 | self.sigmoid = Sigmoid() 66 | 67 | def forward(self, x): 68 | module_input = x 69 | x = self.avg_pool(x) 70 | x = self.fc1(x) 71 | x = self.relu(x) 72 | x = self.fc2(x) 73 | x = self.sigmoid(x) 74 | return module_input * x 75 | 76 | 77 | class bottleneck_IR(Module): 78 | def __init__(self, in_channel, depth, stride): 79 | super(bottleneck_IR, self).__init__() 80 | if in_channel == depth: 81 | self.shortcut_layer = MaxPool2d(1, stride) 82 | else: 83 | self.shortcut_layer = Sequential( 84 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 85 | BatchNorm2d(depth) 86 | ) 87 | self.res_layer = Sequential( 88 | BatchNorm2d(in_channel), 89 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 90 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 91 | ) 92 | 93 | def forward(self, x): 94 | shortcut = self.shortcut_layer(x) 95 | res = self.res_layer(x) 96 | return res + shortcut 97 | 98 | 99 | class bottleneck_IR_SE(Module): 100 | def __init__(self, in_channel, depth, stride): 101 | super(bottleneck_IR_SE, self).__init__() 102 | if in_channel == depth: 103 | self.shortcut_layer = MaxPool2d(1, stride) 104 | else: 105 | self.shortcut_layer = Sequential( 106 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 107 | BatchNorm2d(depth) 108 | ) 109 | self.res_layer = Sequential( 110 | BatchNorm2d(in_channel), 111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 112 | PReLU(depth), 113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 114 | BatchNorm2d(depth), 115 | SEModule(depth, 16) 116 | ) 117 | 118 | def forward(self, x): 119 | shortcut = self.shortcut_layer(x) 120 | res = self.res_layer(x) 121 | return res + shortcut -------------------------------------------------------------------------------- /ldm/thirdp/psp/id_loss.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | import torch 3 | from torch import nn 4 | from ldm.thirdp.psp.model_irse import Backbone 5 | 6 | 7 | class IDFeatures(nn.Module): 8 | def __init__(self, model_path): 9 | super(IDFeatures, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_path)) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | 16 | def forward(self, x, crop=False): 17 | # Not sure of the image range here 18 | if crop: 19 | x = torch.nn.functional.interpolate(x, (256, 256), mode="area") 20 | x = x[:, :, 35:223, 32:220] 21 | x = self.face_pool(x) 22 | x_feats = self.facenet(x) 23 | return x_feats 24 | -------------------------------------------------------------------------------- /ldm/thirdp/psp/model_irse.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 4 | from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 5 | 6 | """ 7 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Backbone(Module): 12 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 13 | super(Backbone, self).__init__() 14 | assert input_size in [112, 224], "input_size should be 112 or 224" 15 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 16 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 17 | blocks = get_blocks(num_layers) 18 | if mode == 'ir': 19 | unit_module = bottleneck_IR 20 | elif mode == 'ir_se': 21 | unit_module = bottleneck_IR_SE 22 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 23 | BatchNorm2d(64), 24 | PReLU(64)) 25 | if input_size == 112: 26 | self.output_layer = Sequential(BatchNorm2d(512), 27 | Dropout(drop_ratio), 28 | Flatten(), 29 | Linear(512 * 7 * 7, 512), 30 | BatchNorm1d(512, affine=affine)) 31 | else: 32 | self.output_layer = Sequential(BatchNorm2d(512), 33 | Dropout(drop_ratio), 34 | Flatten(), 35 | Linear(512 * 14 * 14, 512), 36 | BatchNorm1d(512, affine=affine)) 37 | 38 | modules = [] 39 | for block in blocks: 40 | for bottleneck in block: 41 | modules.append(unit_module(bottleneck.in_channel, 42 | bottleneck.depth, 43 | bottleneck.stride)) 44 | self.body = Sequential(*modules) 45 | 46 | def forward(self, x): 47 | x = self.input_layer(x) 48 | x = self.body(x) 49 | x = self.output_layer(x) 50 | return l2_norm(x) 51 | 52 | 53 | def IR_50(input_size): 54 | """Constructs a ir-50 model.""" 55 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 56 | return model 57 | 58 | 59 | def IR_101(input_size): 60 | """Constructs a ir-101 model.""" 61 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 62 | return model 63 | 64 | 65 | def IR_152(input_size): 66 | """Constructs a ir-152 model.""" 67 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 68 | return model 69 | 70 | 71 | def IR_SE_50(input_size): 72 | """Constructs a ir_se-50 model.""" 73 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 74 | return model 75 | 76 | 77 | def IR_SE_101(input_size): 78 | """Constructs a ir_se-101 model.""" 79 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 80 | return model 81 | 82 | 83 | def IR_SE_152(input_size): 84 | """Constructs a ir_se-152 model.""" 85 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 86 | return model -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 20 | nc = int(40 * (wh[0] / 256)) 21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 22 | 23 | try: 24 | draw.text((0, 0), lines, fill="black", font=font) 25 | except UnicodeEncodeError: 26 | print("Cant encode string for logging. Skipping.") 27 | 28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 29 | txts.append(txt) 30 | txts = np.stack(txts) 31 | txts = torch.tensor(txts) 32 | return txts 33 | 34 | 35 | def ismap(x): 36 | if not isinstance(x, torch.Tensor): 37 | return False 38 | return (len(x.shape) == 4) and (x.shape[1] > 3) 39 | 40 | 41 | def isimage(x): 42 | if not isinstance(x,torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 45 | 46 | 47 | def exists(x): 48 | return x is not None 49 | 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if isfunction(d) else d 55 | 56 | 57 | def mean_flat(tensor): 58 | """ 59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 60 | Take the mean over all non-batch dimensions. 61 | """ 62 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 63 | 64 | 65 | def count_params(model, verbose=False): 66 | total_params = sum(p.numel() for p in model.parameters()) 67 | if verbose: 68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 69 | return total_params 70 | 71 | 72 | def instantiate_from_config(config): 73 | if not "target" in config: 74 | if config == '__is_first_stage__': 75 | return None 76 | elif config == "__is_unconditional__": 77 | return None 78 | raise KeyError("Expected key `target` to instantiate.") 79 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | if reload: 85 | module_imp = importlib.import_module(module) 86 | importlib.reload(module_imp) 87 | return getattr(importlib.import_module(module, package=None), cls) 88 | 89 | 90 | class AdamWwithEMAandWings(optim.Optimizer): 91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 94 | ema_power=1., param_names=()): 95 | """AdamW that saves EMA versions of the parameters.""" 96 | if not 0.0 <= lr: 97 | raise ValueError("Invalid learning rate: {}".format(lr)) 98 | if not 0.0 <= eps: 99 | raise ValueError("Invalid epsilon value: {}".format(eps)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 104 | if not 0.0 <= weight_decay: 105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 106 | if not 0.0 <= ema_decay <= 1.0: 107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 108 | defaults = dict(lr=lr, betas=betas, eps=eps, 109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 110 | ema_power=ema_power, param_names=param_names) 111 | super().__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super().__setstate__(state) 115 | for group in self.param_groups: 116 | group.setdefault('amsgrad', False) 117 | 118 | @torch.no_grad() 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | Args: 122 | closure (callable, optional): A closure that reevaluates the model 123 | and returns the loss. 124 | """ 125 | loss = None 126 | if closure is not None: 127 | with torch.enable_grad(): 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | params_with_grad = [] 132 | grads = [] 133 | exp_avgs = [] 134 | exp_avg_sqs = [] 135 | ema_params_with_grad = [] 136 | state_sums = [] 137 | max_exp_avg_sqs = [] 138 | state_steps = [] 139 | amsgrad = group['amsgrad'] 140 | beta1, beta2 = group['betas'] 141 | ema_decay = group['ema_decay'] 142 | ema_power = group['ema_power'] 143 | 144 | for p in group['params']: 145 | if p.grad is None: 146 | continue 147 | params_with_grad.append(p) 148 | if p.grad.is_sparse: 149 | raise RuntimeError('AdamW does not support sparse gradients') 150 | grads.append(p.grad) 151 | 152 | state = self.state[p] 153 | 154 | # State initialization 155 | if len(state) == 0: 156 | state['step'] = 0 157 | # Exponential moving average of gradient values 158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 159 | # Exponential moving average of squared gradient values 160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 161 | if amsgrad: 162 | # Maintains max of all exp. moving avg. of sq. grad. values 163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 164 | # Exponential moving average of parameter values 165 | state['param_exp_avg'] = p.detach().float().clone() 166 | 167 | exp_avgs.append(state['exp_avg']) 168 | exp_avg_sqs.append(state['exp_avg_sq']) 169 | ema_params_with_grad.append(state['param_exp_avg']) 170 | 171 | if amsgrad: 172 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 173 | 174 | # update the steps for each param group update 175 | state['step'] += 1 176 | # record the step after step update 177 | state_steps.append(state['step']) 178 | 179 | optim._functional.adamw(params_with_grad, 180 | grads, 181 | exp_avgs, 182 | exp_avg_sqs, 183 | max_exp_avg_sqs, 184 | state_steps, 185 | amsgrad=amsgrad, 186 | beta1=beta1, 187 | beta2=beta2, 188 | lr=group['lr'], 189 | weight_decay=group['weight_decay'], 190 | eps=group['eps'], 191 | maximize=False) 192 | 193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 196 | 197 | return loss -------------------------------------------------------------------------------- /opposite.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from collections import defaultdict 3 | import copy 4 | 5 | from safetensors import safe_open 6 | from datasets import load_dataset 7 | 8 | import sys; sys.path.append('.') 9 | import torch 10 | from torch import autocast 11 | from PIL import Image 12 | from torchvision import transforms 13 | import os 14 | from tqdm import tqdm 15 | from einops import rearrange 16 | #import ImageReward as reward 17 | import numpy as np 18 | from pathlib import Path 19 | import matplotlib.pyplot as plt 20 | 21 | from ldm.models.diffusion.ddim import DDIMSampler 22 | from ldm.util import instantiate_from_config 23 | import random 24 | import glob 25 | import re 26 | import shutil 27 | import pdb 28 | import argparse 29 | import torchvision.transforms.functional as F 30 | 31 | 32 | import time 33 | from contextlib import nullcontext 34 | from PIL import Image 35 | 36 | 37 | # Util Functions 38 | def load_model_from_config(config, ckpt, device="cpu", verbose=False): 39 | """Loads a model from config and a ckpt 40 | if config is a path will use omegaconf to load 41 | """ 42 | if isinstance(config, (str, Path)): 43 | config = OmegaConf.load(config) 44 | 45 | tensors = {} 46 | mPath=ckpt 47 | if "safetensors" in mPath: 48 | with safe_open(mPath, framework="pt", device="cpu") as f: 49 | for key in f.keys(): 50 | tensors[key] = f.get_tensor(key).cpu() 51 | 52 | #global_step = pl_sd["global_step"] 53 | sd = tensors#pl_sd["state_dict"] 54 | else: 55 | pl_sd = torch.load(ckpt, map_location="cpu") 56 | sd = pl_sd#["state_dict"] 57 | 58 | model = instantiate_from_config(config.model) 59 | m, u = model.load_state_dict(sd, strict=False) 60 | model.to(device) 61 | model.eval() 62 | model.cond_stage_model.device = device 63 | return model 64 | 65 | def save_model(model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu'): 66 | folder_path = f'opposite/{name}' 67 | os.makedirs(folder_path, exist_ok=True) 68 | if num is not None: 69 | path = f'{folder_path}/{name}-epoch_{num}.ckpt' 70 | else: 71 | path = f'{folder_path}/{name}.ckpt' 72 | print("Saved model to "+path) 73 | torch.save(model.state_dict(), path) 74 | 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser( 78 | prog = 'TrainESD', 79 | description = 'Finetuning stable diffusion model to erase concepts using ESD method') 80 | parser.add_argument('--trained', help='ckpt path for stable diffusion v1-4', type=str, required=False, default='/sd-models/SDv1-5.ckpt') 81 | parser.add_argument('--base', help='ckpt path for stable diffusion v1-4', type=str, required=False, default='/sd-models/SDv1-5.ckpt') 82 | parser.add_argument('--output', help='ckpt path for stable diffusion v1-4', type=str, required=False, default='opposite.ckpt') 83 | parser.add_argument('--config_path', help='config path for stable diffusion v1-4 inference', type=str, required=False, default='configs/stable-diffusion/v1-inference.yaml') 84 | args = parser.parse_args() 85 | config_path = args.config_path 86 | base = load_model_from_config(config_path, args.base, "cpu") 87 | trained = load_model_from_config(config_path, args.trained, "cpu") 88 | opposite_model = load_model_from_config(config_path, args.base, "cpu") 89 | target_state = trained.state_dict() 90 | source_state = base.state_dict() 91 | opposite_state = opposite_model.state_dict() 92 | 93 | for key in target_state: 94 | opposite_state[key] = 2*source_state[key] - target_state[key] 95 | 96 | opposite_model.load_state_dict(opposite_state) 97 | 98 | save_model(opposite_model, args.output, 0) 99 | -------------------------------------------------------------------------------- /train-scripts/__pycache__/convertModels.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ntc-ai/conceptmod/bf03b5484eb56090972ebab22ebea57c63035383/train-scripts/__pycache__/convertModels.cpython-39.pyc -------------------------------------------------------------------------------- /train_sequential.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # List of prompts 4 | prompts=( 5 | "#:0.4|human=robot:0.8|robot%human:-0.1" 6 | ) 7 | 8 | # Training parameters 9 | train_method="selfattn" 10 | devices="0,0" 11 | ckpt_path="../stable-diffusion-webui2/models/Stable-diffusion/criarcysFantasyTo_v30.safetensors" 12 | negative_guidance=-1.5 13 | start_guidance=-3 14 | iterations=2000 15 | accumulation_steps=2 16 | mod_count=3 17 | sample_prompt="man, looking serious overlooking city, close up view of face, face fully visible" 18 | 19 | # Train on each prompt sequentially 20 | for prompt in "${prompts[@]}"; do 21 | echo "Training on prompt: '$prompt'" 22 | python train-scripts/train-esd.py --prompt "$prompt" --train_method "$train_method" --devices "$devices" --ckpt_path "$ckpt_path" --negative_guidance "$negative_guidance" --start_guidance "$start_guidance" --iterations "$iterations" --seperator "|" --accumulation_steps $accumulation_steps --sample_prompt "$sample_prompt" 23 | echo "Finished training on prompt: '$prompt'" 24 | done 25 | 26 | echo "All prompts have been trained." 27 | --------------------------------------------------------------------------------