├── .dockerignore ├── docker ├── run-shell.sh ├── Dockerfile ├── docker-run.sh ├── entrypoint.sh └── opencv.pc ├── .gitignore ├── js ├── mode.js ├── upload.js ├── outpaint.js ├── xss.js ├── setup.js ├── proceed.js ├── keyboard.js └── toolbar.js ├── config.yaml ├── .gitmodules ├── docker-compose.yml ├── docs ├── run_with_docker.md ├── usage.md └── setup_guide.md ├── .github └── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── perlin2d.py ├── models ├── v1-inference.yaml └── v1-inpainting-inference.yaml ├── stablediffusion_infinity_colab.ipynb ├── readme.md ├── interrogate.py ├── postprocess.py ├── environment.yml ├── utils.py ├── LICENSE ├── process.py ├── index.html ├── canvas.py └── convert_checkpoint.py /.dockerignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .github/ 3 | .git/ 4 | docs/ 5 | .dockerignore 6 | readme.md 7 | LICENSE -------------------------------------------------------------------------------- /docker/run-shell.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd "$(dirname $0)" 4 | 5 | docker-compose run -p 8888:8888 --rm -u root sd-infinity bash 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | Makefile 3 | .ipynb_checkpoints/ 4 | build/ 5 | csrc/ 6 | .idea/ 7 | travis.sh 8 | *.iml 9 | .token 10 | -------------------------------------------------------------------------------- /js/mode.js: -------------------------------------------------------------------------------- 1 | function(mode){ 2 | let app=document.querySelector("gradio-app").shadowRoot; 3 | let frame=app.querySelector("#sdinfframe").contentWindow.document; 4 | frame.querySelector("#mode").value=mode; 5 | return mode; 6 | } -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | shortcut: 2 | clear: Escape 3 | load: Ctrl+o 4 | save: Ctrl+s 5 | export: Ctrl+e 6 | upload: Ctrl+u 7 | selection: 1 8 | canvas: 2 9 | eraser: 3 10 | outpaint: d 11 | accept: a 12 | cancel: c 13 | retry: r 14 | prev: q 15 | next: e 16 | zoom_in: z 17 | zoom_out: x 18 | random_seed: s -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "glid_3_xl_stable"] 2 | path = glid_3_xl_stable 3 | url = https://github.com/lkwq007/glid_3_xl_stable.git 4 | [submodule "PyPatchMatch"] 5 | path = PyPatchMatch 6 | url = https://github.com/lkwq007/PyPatchMatch.git 7 | [submodule "sd_grpcserver"] 8 | path = sd_grpcserver 9 | url = https://github.com/lkwq007/sd_grpcserver.git 10 | [submodule "blip_model"] 11 | path = blip_model 12 | url = https://github.com/lkwq007/blip_model 13 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM continuumio/miniconda3:4.12.0 2 | 3 | RUN apt-get update && \ 4 | apt install -y \ 5 | fonts-dejavu-core \ 6 | build-essential \ 7 | libopencv-dev \ 8 | cmake \ 9 | vim \ 10 | && apt-get clean 11 | 12 | COPY docker/opencv.pc /usr/lib/pkgconfig/opencv.pc 13 | 14 | RUN useradd -ms /bin/bash user 15 | USER user 16 | 17 | RUN mkdir ~/.huggingface && conda init 
bash 18 | 19 | COPY --chown=user:user . /app 20 | WORKDIR /app 21 | 22 | EXPOSE 8888 23 | CMD ["/app/docker/entrypoint.sh"] -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | sd-infinity: 3 | build: 4 | context: . 5 | dockerfile: ./docker/Dockerfile 6 | #shm_size: '2gb' # Enable if more shared memory is needed 7 | ports: 8 | - "8888:8888" 9 | volumes: 10 | - user_home:/home/user 11 | - cond_env:/opt/conda/envs 12 | deploy: 13 | resources: 14 | reservations: 15 | devices: 16 | - driver: nvidia 17 | device_ids: ['0'] 18 | capabilities: [gpu] 19 | 20 | volumes: 21 | user_home: {} 22 | cond_env: {} -------------------------------------------------------------------------------- /docs/run_with_docker.md: -------------------------------------------------------------------------------- 1 | 2 | # Running with Docker on Windows or Linux with NVIDIA GPU 3 | On Windows 10 or 11 you can follow this guide to setting up Docker with WSL2 https://www.youtube.com/watch?v=PB7zM3JrgkI 4 | 5 | Native Linux 6 | 7 | ``` 8 | cd stablediffusion-infinity/docker 9 | ./docker-run.sh 10 | ``` 11 | 12 | Windows 10,11 with WSL2 shell: 13 | - open windows Command Prompt, type "bash" 14 | - once in bash, type: 15 | ``` 16 | cd /mnt/c/PATH-TO-YOUR/stablediffusion-infinity/docker 17 | ./docker-run.sh 18 | ``` 19 | 20 | Open "http://localhost:8888" in your browser ( even though the log says http://0.0.0.0:8888 ) -------------------------------------------------------------------------------- /docker/docker-run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 4 | echo Current dir: "$(pwd)" 5 | 6 | if ! docker version | grep 'linux/amd64' ; then 7 | echo "Could not find docker." 8 | exit 1 9 | fi 10 | 11 | if ! docker-compose version | grep v2 ; then 12 | echo "docker-compose v2.x is not installed" 13 | exit 1 14 | fi 15 | 16 | 17 | if ! docker run -it --gpus=all --rm nvidia/cuda:11.4.2-base-ubuntu20.04 nvidia-smi | grep -e 'NVIDIA.*On' ; then 18 | echo "Docker could not find your NVIDIA gpu" 19 | exit 1 20 | fi 21 | 22 | if ! docker compose build ; then 23 | echo "Error while building" 24 | exit 1 25 | fi 26 | docker compose up -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature Request]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /js/upload.js: -------------------------------------------------------------------------------- 1 | function(a,b){ 2 | if(!window.my_observe_upload) 3 | { 4 | console.log("setup upload here"); 5 | window.my_observe_upload = new MutationObserver(function (event) { 6 | console.log(event); 7 | var frame=document.querySelector("gradio-app").shadowRoot.querySelector("#sdinfframe").contentWindow.document; 8 | frame.querySelector("#upload").click(); 9 | }); 10 | window.my_observe_upload_target = document.querySelector("gradio-app").shadowRoot.querySelector("#upload span"); 11 | window.my_observe_upload.observe(window.my_observe_upload_target, { 12 | attributes: false, 13 | subtree: true, 14 | childList: true, 15 | characterData: true 16 | }); 17 | } 18 | return [a,b]; 19 | } -------------------------------------------------------------------------------- /docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /app 4 | 5 | set -euxo pipefail 6 | 7 | set -x 8 | 9 | if ! conda env list | grep sd-inf ; then 10 | echo "Creating environment, it may appear to freeze for a few minutes..." 11 | conda env create -f environment.yml 12 | echo "Finished installing." 13 | echo "conda activate sd-inf" >> ~/.bashrc 14 | shasum environment.yml > ~/.environment.sha 15 | fi 16 | 17 | . "/opt/conda/etc/profile.d/conda.sh" 18 | conda activate sd-inf 19 | 20 | if shasum -c ~/.environment.sha > /dev/null 2>&1 ; then 21 | echo "environment.yml is unchanged." 22 | else 23 | echo "environment.yml was changed, please wait a minute until it says 'Done updating'..." 24 | conda env update --file environment.yml 25 | shasum environment.yml > ~/.environment.sha 26 | echo "Done updating." 27 | fi 28 | 29 | python app.py --port=8888 --host=0.0.0.0 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[Bug]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | For setup problems or dependencies problems, please post in Q&A in Discussions 13 | 14 | **To Reproduce** 15 | Steps to reproduce the behavior: 16 | 1. Go to '...' 17 | 2. Click on '....' 18 | 3. Scroll down to '....' 19 | 4. See error 20 | 21 | **Expected behavior** 22 | A clear and concise description of what you expected to happen. 23 | 24 | **Screenshots** 25 | If applicable, add screenshots to help explain your problem. 26 | 27 | **Desktop (please complete the following information):** 28 | - OS: [e.g. windows] 29 | - Browser [e.g. chrome] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 
33 | -------------------------------------------------------------------------------- /js/outpaint.js: -------------------------------------------------------------------------------- 1 | function(a){ 2 | if(!window.my_observe_outpaint) 3 | { 4 | console.log("setup outpaint here"); 5 | window.my_observe_outpaint = new MutationObserver(function (event) { 6 | console.log(event); 7 | let app=document.querySelector("gradio-app"); 8 | app=app.shadowRoot??app; 9 | let frame=app.querySelector("#sdinfframe").contentWindow; 10 | frame.postMessage(["outpaint", ""], "*"); 11 | }); 12 | var app=document.querySelector("gradio-app"); 13 | app=app.shadowRoot??app; 14 | window.my_observe_outpaint_target=app.querySelector("#output span"); 15 | window.my_observe_outpaint.observe(window.my_observe_outpaint_target, { 16 | attributes: false, 17 | subtree: true, 18 | childList: true, 19 | characterData: true 20 | }); 21 | } 22 | return a; 23 | } -------------------------------------------------------------------------------- /js/xss.js: -------------------------------------------------------------------------------- 1 | var setup_outpaint=function(){ 2 | if(!window.my_observe_outpaint) 3 | { 4 | console.log("setup outpaint here"); 5 | window.my_observe_outpaint = new MutationObserver(function (event) { 6 | console.log(event); 7 | let app=document.querySelector("gradio-app"); 8 | app=app.shadowRoot??app; 9 | let frame=app.querySelector("#sdinfframe").contentWindow; 10 | frame.postMessage(["outpaint", ""], "*"); 11 | }); 12 | var app=document.querySelector("gradio-app"); 13 | app=app.shadowRoot??app; 14 | window.my_observe_outpaint_target=app.querySelector("#output span"); 15 | window.my_observe_outpaint.observe(window.my_observe_outpaint_target, { 16 | attributes: false, 17 | subtree: true, 18 | childList: true, 19 | characterData: true 20 | }); 21 | } 22 | }; 23 | window.config_obj={ 24 | resize_check: true, 25 | enable_safety: true, 26 | use_correction: false, 27 | enable_img2img: false, 28 | use_seed: false, 29 | seed_val: 0, 30 | interrogate_mode: false, 31 | }; 32 | setup_outpaint(); -------------------------------------------------------------------------------- /js/setup.js: -------------------------------------------------------------------------------- 1 | function(token_val, width, height, size, model_choice, model_path){ 2 | let app=document.querySelector("gradio-app"); 3 | app=app.shadowRoot??app; 4 | app.querySelector("#sdinfframe").style.height=80+Number(height)+"px"; 5 | // app.querySelector("#setup_row").style.display="none"; 6 | app.querySelector("#model_path_input").style.display="none"; 7 | let frame=app.querySelector("#sdinfframe").contentWindow.document; 8 | 9 | if(frame.querySelector("#setup").value=="0") 10 | { 11 | window.my_setup=setInterval(function(){ 12 | let app=document.querySelector("gradio-app"); 13 | app=app.shadowRoot??app; 14 | let frame=app.querySelector("#sdinfframe").contentWindow.document; 15 | console.log("Check PyScript...") 16 | if(frame.querySelector("#setup").value=="1") 17 | { 18 | frame.querySelector("#draw").click(); 19 | clearInterval(window.my_setup); 20 | } 21 | }, 100) 22 | } 23 | else 24 | { 25 | frame.querySelector("#draw").click(); 26 | } 27 | return [token_val, width, height, size, model_choice, model_path]; 28 | } -------------------------------------------------------------------------------- /js/proceed.js: -------------------------------------------------------------------------------- 1 | function(sel_buffer_str, 2 | prompt_text, 3 | 
negative_prompt_text, 4 | strength, 5 | guidance, 6 | step, 7 | resize_check, 8 | fill_mode, 9 | enable_safety, 10 | use_correction, 11 | enable_img2img, 12 | use_seed, 13 | seed_val, 14 | generate_num, 15 | scheduler, 16 | scheduler_eta, 17 | interrogate_mode, 18 | state){ 19 | let app=document.querySelector("gradio-app"); 20 | app=app.shadowRoot??app; 21 | sel_buffer=app.querySelector("#input textarea").value; 22 | let use_correction_bak=false; 23 | ({resize_check,enable_safety,enable_img2img,use_seed,seed_val,interrogate_mode}=window.config_obj); 24 | seed_val=Number(seed_val); 25 | return [ 26 | sel_buffer, 27 | prompt_text, 28 | negative_prompt_text, 29 | strength, 30 | guidance, 31 | step, 32 | resize_check, 33 | fill_mode, 34 | enable_safety, 35 | use_correction, 36 | enable_img2img, 37 | use_seed, 38 | seed_val, 39 | generate_num, 40 | scheduler, 41 | scheduler_eta, 42 | interrogate_mode, 43 | state, 44 | ] 45 | } -------------------------------------------------------------------------------- /docker/opencv.pc: -------------------------------------------------------------------------------- 1 | prefix=/usr 2 | exec_prefix=${prefix} 3 | includedir=${prefix}/include 4 | libdir=${exec_prefix}/lib 5 | 6 | Name: opencv 7 | Description: The opencv library 8 | Version: 2.x.x 9 | Cflags: -I${includedir}/opencv4 10 | #Cflags: -I${includedir}/opencv -I${includedir}/opencv2 11 | Libs: -L${libdir} -lopencv_calib3d -lopencv_imgproc -lopencv_xobjdetect -lopencv_hdf -lopencv_flann -lopencv_core -lopencv_dpm -lopencv_videoio -lopencv_reg -lopencv_quality -lopencv_tracking -lopencv_dnn_superres -lopencv_objdetect -lopencv_stitching -lopencv_saliency -lopencv_intensity_transform -lopencv_rapid -lopencv_dnn -lopencv_features2d -lopencv_text -lopencv_calib3d -lopencv_line_descriptor -lopencv_superres -lopencv_ml -lopencv_alphamat -lopencv_viz -lopencv_optflow -lopencv_videostab -lopencv_bioinspired -lopencv_highgui -lopencv_img_hash -lopencv_freetype -lopencv_imgcodecs -lopencv_mcc -lopencv_video -lopencv_photo -lopencv_surface_matching -lopencv_rgbd -lopencv_datasets -lopencv_ximgproc -lopencv_plot -lopencv_face -lopencv_stereo -lopencv_aruco -lopencv_dnn_objdetect -lopencv_phase_unwrapping -lopencv_bgsegm -lopencv_ccalib -lopencv_hfs -lopencv_imgproc -lopencv_shape -lopencv_xphoto -lopencv_structured_light -lopencv_fuzzy -------------------------------------------------------------------------------- /js/keyboard.js: -------------------------------------------------------------------------------- 1 | 2 | window.my_setup_keyboard=setInterval(function(){ 3 | let app=document.querySelector("gradio-app"); 4 | app=app.shadowRoot??app; 5 | let frame=app.querySelector("#sdinfframe").contentWindow; 6 | console.log("Check iframe..."); 7 | if(frame.setup_shortcut) 8 | { 9 | frame.setup_shortcut(json); 10 | clearInterval(window.my_setup_keyboard); 11 | } 12 | }, 1000); 13 | var config=JSON.parse(json); 14 | var key_map={}; 15 | Object.keys(config.shortcut).forEach(k=>{ 16 | key_map[config.shortcut[k]]=k; 17 | }); 18 | document.addEventListener("keydown", e => { 19 | if(e.target.tagName!="INPUT"&&e.target.tagName!="GRADIO-APP"&&e.target.tagName!="TEXTAREA") 20 | { 21 | let key=e.key; 22 | if(e.ctrlKey) 23 | { 24 | key="Ctrl+"+e.key; 25 | if(key in key_map) 26 | { 27 | e.preventDefault(); 28 | } 29 | } 30 | let app=document.querySelector("gradio-app"); 31 | app=app.shadowRoot??app; 32 | let frame=app.querySelector("#sdinfframe").contentDocument; 33 | frame.dispatchEvent( 34 | new KeyboardEvent("keydown", {key: 
e.key, ctrlKey: e.ctrlKey}) 35 | ); 36 | } 37 | }) -------------------------------------------------------------------------------- /perlin2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ########## 4 | # https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921 5 | def perlin(x, y, seed=0): 6 | # permutation table 7 | np.random.seed(seed) 8 | p = np.arange(256, dtype=int) 9 | np.random.shuffle(p) 10 | p = np.stack([p, p]).flatten() 11 | # coordinates of the top-left 12 | xi, yi = x.astype(int), y.astype(int) 13 | # internal coordinates 14 | xf, yf = x - xi, y - yi 15 | # fade factors 16 | u, v = fade(xf), fade(yf) 17 | # noise components 18 | n00 = gradient(p[p[xi] + yi], xf, yf) 19 | n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1) 20 | n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1) 21 | n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf) 22 | # combine noises 23 | x1 = lerp(n00, n10, u) 24 | x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01 25 | return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here 26 | 27 | 28 | def lerp(a, b, x): 29 | "linear interpolation" 30 | return a + x * (b - a) 31 | 32 | 33 | def fade(t): 34 | "6t^5 - 15t^4 + 10t^3" 35 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3 36 | 37 | 38 | def gradient(h, x, y): 39 | "grad converts h to the right gradient vector and return the dot product with (x,y)" 40 | vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]]) 41 | g = vectors[h % 4] 42 | return g[:, :, 0] * x + g[:, :, 1] * y 43 | 44 | 45 | ########## -------------------------------------------------------------------------------- /models/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /models/v1-inpainting-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 7.5e-05 3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid # important 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | finetune_keys: null 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /stablediffusion_infinity_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "collapsed_sections": [] 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "source": [ 22 | "# stablediffusion-infinity\n", 23 | "\n", 24 | "https://github.com/lkwq007/stablediffusion-infinity\n", 25 | "\n", 26 | "Outpainting with Stable Diffusion on an infinite canvas" 27 | ], 28 | "metadata": { 29 | "id": "IgN1jqV_DemW" 30 | } 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "id": "JvbfNNSJDTW5" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "#@title setup libs\n", 41 | "!nvidia-smi -L\n", 42 | "!pip install -qq -U diffusers==0.11.1 transformers ftfy accelerate\n", 43 | "!pip install -q gradio==3.11.0\n", 44 | "!pip install -q fpie timm\n", 45 | "!pip uninstall taichi -y" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "source": [ 51 | "#@title setup stablediffusion-infinity\n", 52 | "!git clone --recurse-submodules https://github.com/lkwq007/stablediffusion-infinity\n", 53 | "%cd stablediffusion-infinity\n", 54 | "!cp -r PyPatchMatch/csrc .\n", 55 | "!cp PyPatchMatch/Makefile .\n", 56 | "!cp PyPatchMatch/Makefile_fallback .\n", 57 | "!cp PyPatchMatch/travis.sh .\n", 58 | "!cp PyPatchMatch/patch_match.py . 
" 59 | ], 60 | "metadata": { 61 | "id": "D1BDhQCJDilE" 62 | }, 63 | "execution_count": null, 64 | "outputs": [] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "source": [ 69 | "#@title start stablediffusion-infinity (first setup may takes about two minutes for downloading models)\n", 70 | "!python app.py --share" 71 | ], 72 | "metadata": { 73 | "id": "UGotC5ckDlmO" 74 | }, 75 | "execution_count": null, 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "source": [], 81 | "metadata": { 82 | "id": "R1-E07CMFZoj" 83 | }, 84 | "execution_count": null, 85 | "outputs": [] 86 | } 87 | ] 88 | } 89 | -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | ## Models 4 | 5 | - stablediffusion-inpainting: `runwayml/stable-diffusion-inpainting`, does not support img2img mode 6 | - stablediffusion-inpainting+img2img-v1.5: `runwayml/stable-diffusion-inpainting` + `runwayml/stable-diffusion-v1-5`, supports img2img mode, requires larger vRAM 7 | - stablediffusion-v1.5: `runwayml/stable-diffusion-v1-5`, inpainting with `diffusers`'s legacy pipeline, low quality for outpainting, supports img2img mode 8 | - stablediffusion-v1.4: `CompVis/stable-diffusion-v1-4`, inpainting with `diffusers`'s legacy pipeline, low quality for outpainting, supports img2img mode 9 | 10 | ## Loading local model 11 | 12 | Note that when loading a local checkpoint, you have to specify the correct model choice before setup. 13 | ```shell 14 | python app.py --local_model path_to_local_model 15 | # e.g. 16 | # diffusers model weights 17 | python app.py --local_model ./models/runwayml/stable-diffusion-inpainting 18 | python app.py --local_model models/CompVis/stable-diffusion-v1-4/model_index.json 19 | # original model checkpoint 20 | python app.py --local_model /home/user/checkpoint/model.ckpt 21 | ``` 22 | 23 | ## Loading remote model 24 | 25 | Note that when loading a remote model, you have to specify the correct model choice before setup. 26 | ```shell 27 | python app.py --remote_model model_name 28 | # e.g. 29 | python app.py --remote_model hakurei/waifu-diffusion-v1-3 30 | ``` 31 | 32 | ## Using textual inversion embeddings 33 | 34 | Put `*.bin` inside `embeddings` directory. 35 | 36 | ## Using a dreambooth finetuned model 37 | 38 | ``` 39 | python app.py --remote_model model_name 40 | # e.g. 41 | python app.py --remote_model sd-dreambooth-library/pikachu 42 | # or download the weight/checkpoint and load with 43 | python app.py --local_model path_to_model 44 | ``` 45 | 46 | ## Model Path for Docker users 47 | 48 | Docker users can specify a local model path or remote mode name within the web app. 49 | 50 | ## Using fp32 mode or low vRAM mode (some GPUs might not work well fp16) 51 | 52 | ```shell 53 | python app.py --fp32 --lowvram 54 | ``` 55 | 56 | ## HTTPS 57 | 58 | ```shell 59 | python app.py --encrypt --ssl_keyfile path_to_ssl_keyfile --ssl_certfile path_to_ssl_certfile 60 | ``` 61 | 62 | ## Keyboard shortcut 63 | 64 | The shortcut can be configured via `config.yaml`. 
Currently only `[key]` or `[Ctrl]` + `[key]` shortcuts are supported. 65 | 66 | Default shortcuts are: 67 | 68 | ```yaml 69 | shortcut: 70 | clear: Escape 71 | load: Ctrl+o 72 | save: Ctrl+s 73 | export: Ctrl+e 74 | upload: Ctrl+u 75 | selection: 1 76 | canvas: 2 77 | eraser: 3 78 | outpaint: d 79 | accept: a 80 | cancel: c 81 | retry: r 82 | prev: q 83 | next: e 84 | zoom_in: z 85 | zoom_out: x 86 | random_seed: s 87 | ``` 88 | 89 | ## Glossary 90 | 91 | (From the diffusers documentation: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) 92 | - prompt: The prompt to guide the image generation. 93 | - step: The number of denoising steps. More denoising steps usually lead to a higher-quality image at the expense of slower inference. 94 | - guidance_scale: Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2 of the [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 95 | - negative_prompt: The prompt or prompts not to guide the image generation. 96 | - Sample number: The number of images to generate per prompt. 97 | - scheduler: A scheduler is used in combination with `unet` to denoise the encoded image latents. 98 | - eta: Corresponds to the parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to `DDIMScheduler` and is ignored for other schedulers. 99 | - strength: For img2img only; conceptually, indicates how much to transform the reference image. -------------------------------------------------------------------------------- /docs/setup_guide.md: -------------------------------------------------------------------------------- 1 | # Setup Guide 2 | 3 | Please install conda first ([miniconda](https://docs.conda.io/en/latest/miniconda.html) or [anaconda](https://docs.anaconda.com/anaconda/install/)).
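If you want to confirm that conda is available in your shell before continuing, a quick sanity check (any reasonably recent conda release should work) is:

```
conda --version
```

If the command is not found, open a new terminal so that the installer's changes to your shell initialization take effect.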
4 | 5 | - [Setup with Linux/Nvidia GPU](#linux) 6 | - [Setup with Linux/AMD GPU](#linux-amd) 7 | - [Setup with Windows](#windows-nvidia) 8 | - [Setup with MacOS](#macos) 9 | - [Upgrade from previous version](#upgrade) 10 | 11 | ## Setup with Linux/Nvidia GPU 12 | 13 | ### conda env 14 | Setup with `environment.yml`: 15 | ``` 16 | git clone --recurse-submodules https://github.com/lkwq007/stablediffusion-infinity 17 | cd stablediffusion-infinity 18 | conda env create -f environment.yml 19 | ``` 20 | 21 | If the `environment.yml` doesn't work for you, you may install the dependencies manually: 22 | ``` 23 | conda create -n sd-inf python=3.10 24 | conda activate sd-inf 25 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia 26 | conda install scipy scikit-image 27 | conda install -c conda-forge diffusers transformers ftfy accelerate 28 | pip install opencv-python 29 | pip install -U gradio 30 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3 31 | pip install timm 32 | ``` 33 | 34 | After setting up the environment, you can run stablediffusion-infinity with the following commands: 35 | ``` 36 | conda activate sd-inf 37 | python app.py 38 | ``` 39 | 40 | ## Setup with Linux/AMD GPU (untested) 41 | 42 | ``` 43 | conda create -n sd-inf python=3.10 44 | conda activate sd-inf 45 | pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2 46 | conda install scipy scikit-image 47 | conda install -c conda-forge diffusers transformers ftfy accelerate 48 | pip install opencv-python 49 | pip install -U gradio 50 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3 51 | pip install timm 52 | ``` 53 | 54 | 55 | ### CPP library (optional) 56 | 57 | Note that the `opencv` library (e.g. `libopencv-dev`/`opencv-devel`; the package name may differ across distributions) is required for `PyPatchMatch`. You may need to install `opencv` yourself. If `opencv` is not installed, the `patch_match` option (usually better quality) won't work. 58 | 59 | ## Setup with Windows 60 | 61 | 62 | ``` 63 | conda create -n sd-inf python=3.10 64 | conda activate sd-inf 65 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia 66 | conda install scipy scikit-image 67 | conda install -c conda-forge diffusers transformers ftfy accelerate 68 | pip install opencv-python 69 | pip install -U gradio 70 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3 71 | pip install timm 72 | ``` 73 | 74 | If you use an AMD GPU, you need to install the ONNX runtime: `pip install onnxruntime-directml` (this only works with the `stablediffusion-inpainting` model and is untested on AMD devices).
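To verify that the DirectML build of ONNX Runtime is picked up, a minimal check (assuming `onnxruntime-directml` installed without errors; `DmlExecutionProvider` is the provider name DirectML builds normally register) is:

```
python -c "import onnxruntime; print(onnxruntime.get_available_providers())"
```

If `DmlExecutionProvider` appears in the printed list, the DirectML backend is available.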
75 | 76 | For Windows, you may need to replace `pip install opencv-python` with `conda install -c conda-forge opencv`. 77 | 78 | After setting up the environment, you can run stablediffusion-infinity with the following commands: 79 | ``` 80 | conda activate sd-inf 81 | python app.py 82 | ``` 83 | ## Setup with MacOS 84 | 85 | ### conda env 86 | ``` 87 | conda create -n sd-inf python=3.10 88 | conda activate sd-inf 89 | conda install pytorch torchvision torchaudio -c pytorch-nightly 90 | conda install scipy scikit-image 91 | conda install -c conda-forge diffusers transformers ftfy accelerate 92 | pip install opencv-python 93 | pip install -U gradio 94 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3 95 | pip install timm 96 | ``` 97 | 98 | After setting up the environment, you can run stablediffusion-infinity with the following commands: 99 | ``` 100 | conda activate sd-inf 101 | python app.py 102 | ``` 103 | ### CPP library (optional) 104 | 105 | Note that the `opencv` library is required for `PyPatchMatch`. You may need to install `opencv` yourself (via `homebrew` or by compiling from source). If `opencv` is not installed, the `patch_match` option (usually better quality) won't work. 106 | 107 | ## Upgrade 108 | 109 | ``` 110 | conda install -c conda-forge diffusers transformers ftfy accelerate 111 | conda update -c conda-forge diffusers transformers ftfy accelerate 112 | pip install -U gradio 113 | ``` -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # stablediffusion-infinity 2 | 3 | Outpainting with Stable Diffusion on an infinite canvas. 4 | 5 | [](https://colab.research.google.com/github/lkwq007/stablediffusion-infinity/blob/master/stablediffusion_infinity_colab.ipynb) 6 | [](https://huggingface.co/spaces/lnyan/stablediffusion-infinity) 7 | [](https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/setup_guide.md) 8 | 9 |  10 | 11 | https://user-images.githubusercontent.com/1665437/197244111-51884b3b-dffe-4dcf-a82a-fa5117c79934.mp4 12 | 13 | ## Status 14 | 15 | Powered by the Stable Diffusion inpainting model, this project now works well. However, the quality of the results is still not guaranteed. 16 | You may need to do prompt engineering, change the size of the selection, or reduce the size of the outpainting region to get better outpainting results. 17 | 18 | The project is now a web app based on PyScript and Gradio. For the Jupyter Notebook version, please check out the [ipycanvas](https://github.com/lkwq007/stablediffusion-infinity/tree/ipycanvas) branch. 19 | 20 | Pull requests are welcome for better UI control, ideas to achieve better results, or any other improvements.
21 | 22 | Update: the project added photometric correction to suppress seams. To use this feature, you need to install [fpie](https://github.com/Trinkle23897/Fast-Poisson-Image-Editing): `pip install fpie` (Linux/MacOS only) 23 | 24 | ## Docs 25 | 26 | ### Get Started 27 | 28 | - Setup for Windows: [setup_guide](./docs/setup_guide.md#windows) 29 | - Setup for Linux: [setup_guide](./docs/setup_guide.md#linux) 30 | - Setup for MacOS: [setup_guide](./docs/setup_guide.md#macos) 31 | - Running with Docker on Windows or Linux with NVIDIA GPU: [run_with_docker](./docs/run_with_docker.md) 32 | - Usages: [usage](./docs/usage.md) 33 | 34 | ### FAQs 35 | 36 | - The result is a black square: 37 | - The false positive rate of the safety checker is relatively high; you may disable the safety_checker 38 | - Some GPUs might not work with `fp16`: `python app.py --fp32 --lowvram` 39 | - What is the init_mode? 40 | - init_mode indicates how to fill the empty/masked region; usually `patch_match` is better than the others 41 | - Why not use `postMessage` for iframe interaction? 42 | - The iframe and the gradio app are in the same origin. For the `postMessage` version, check out the [gradio-space](https://github.com/lkwq007/stablediffusion-infinity/tree/gradio-space) version 43 | 44 | ### Known issues 45 | 46 | - The canvas is implemented with `NumPy` + `PyScript` (the project was originally implemented with `ipycanvas` inside a Jupyter notebook), which is relatively inefficient compared with pure frontend solutions. 47 | - By design, the canvas is infinite. However, the canvas size is **finite** in practice. Your RAM and browser limit the canvas size. The canvas might crash or behave strangely when zoomed out by a certain scale. 48 | - The canvas requires internet: you can deploy and serve PyScript, Pyodide, and other JS/CSS assets with a local HTTP server and modify `index.html` accordingly. 49 | - Photometric correction might not work (`taichi` does not support the multithreading environment). A dirty hack (quite unreliable) is implemented to move the related computation inside a subprocess. 50 | - The Stable Diffusion inpainting model is much slower when the selection size is larger than 512x512 51 | 52 | ## Credit 53 | 54 | The code of `perlin2d.py` is from https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921 and is **not** included in the scope of the LICENSE used in this repo. 55 | 56 | The submodule `glid_3_xl_stable` is based on https://github.com/Jack000/glid-3-xl-stable 57 | 58 | The submodule `PyPatchMatch` is based on https://github.com/vacancy/PyPatchMatch 59 | 60 | The code of `postprocess.py` and `process.py` is modified from https://github.com/Trinkle23897/Fast-Poisson-Image-Editing 61 | 62 | The code of `convert_checkpoint.py` is modified from https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py 63 | 64 | The submodule `sd_grpcserver` and `handleImageAdjustment()` in `utils.py` are based on https://github.com/hafriedlander/stable-diffusion-grpcserver and https://github.com/parlance-zz/g-diffuser-bot 65 | 66 | `w2ui.min.js` and `w2ui.min.css` are from https://github.com/vitmalina/w2ui.
`fabric.min.js` is a custom build of https://github.com/fabricjs/fabric.js 67 | 68 | `interrogate.py` is based on https://github.com/pharmapsychotic/clip-interrogator v1, the submodule `blip_model` is based on https://github.com/salesforce/BLIP 69 | -------------------------------------------------------------------------------- /interrogate.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2022 pharmapsychotic 5 | https://github.com/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb 6 | """ 7 | 8 | import numpy as np 9 | import os 10 | import torch 11 | import torchvision.transforms as T 12 | import torchvision.transforms.functional as TF 13 | 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torchvision import transforms 17 | from torchvision.transforms.functional import InterpolationMode 18 | from transformers import CLIPTokenizer, CLIPModel 19 | from transformers import CLIPProcessor, CLIPModel 20 | 21 | data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "blip_model", "data") 22 | def load_list(filename): 23 | with open(filename, 'r', encoding='utf-8', errors='replace') as f: 24 | items = [line.strip() for line in f.readlines()] 25 | return items 26 | 27 | artists = load_list(os.path.join(data_path, 'artists.txt')) 28 | flavors = load_list(os.path.join(data_path, 'flavors.txt')) 29 | mediums = load_list(os.path.join(data_path, 'mediums.txt')) 30 | movements = load_list(os.path.join(data_path, 'movements.txt')) 31 | 32 | sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] 33 | trending_list = [site for site in sites] 34 | trending_list.extend(["trending on "+site for site in sites]) 35 | trending_list.extend(["featured on "+site for site in sites]) 36 | trending_list.extend([site+" contest winner" for site in sites]) 37 | 38 | device="cpu" 39 | blip_image_eval_size = 384 40 | clip_name="openai/clip-vit-large-patch14" 41 | 42 | blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' 43 | 44 | def generate_caption(blip_model, pil_image, device="cpu"): 45 | gpu_image = transforms.Compose([ 46 | transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), 47 | transforms.ToTensor(), 48 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 49 | ])(pil_image).unsqueeze(0).to(device) 50 | 51 | with torch.no_grad(): 52 | caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5) 53 | return caption[0] 54 | 55 | def rank(text_features, image_features, text_array, top_count=1): 56 | top_count = min(top_count, len(text_array)) 57 | similarity = torch.zeros((1, len(text_array))) 58 | for i in range(image_features.shape[0]): 59 | similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) 60 | similarity /= image_features.shape[0] 61 | 62 | top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1) 63 | return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)] 64 | 65 | class Interrogator: 66 | def __init__(self) -> None: 67 | self.tokenizer = CLIPTokenizer.from_pretrained(clip_name) 68 | try: 69 | self.get_blip() 70 | 
except: 71 | self.blip_model = None 72 | self.model = CLIPModel.from_pretrained(clip_name) 73 | self.processor = CLIPProcessor.from_pretrained(clip_name) 74 | self.text_feature_lst = [torch.load(os.path.join(data_path, f"{i}.pth")) for i in range(5)] 75 | 76 | def get_blip(self): 77 | from blip_model.blip import blip_decoder 78 | blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base') 79 | blip_model.eval() 80 | self.blip_model = blip_model 81 | 82 | 83 | def interrogate(self,image,use_caption=False): 84 | if self.blip_model: 85 | caption = generate_caption(self.blip_model, image) 86 | else: 87 | caption = "" 88 | model,processor=self.model,self.processor 89 | bests = [[('',0)]]*5 90 | if True: 91 | print(f"Interrogating with {clip_name}...") 92 | 93 | inputs = processor(images=image, return_tensors="pt") 94 | with torch.no_grad(): 95 | image_features = model.get_image_features(**inputs) 96 | image_features /= image_features.norm(dim=-1, keepdim=True) 97 | ranks = [ 98 | rank(self.text_feature_lst[0], image_features, mediums), 99 | rank(self.text_feature_lst[1], image_features, ["by "+artist for artist in artists]), 100 | rank(self.text_feature_lst[2], image_features, trending_list), 101 | rank(self.text_feature_lst[3], image_features, movements), 102 | rank(self.text_feature_lst[4], image_features, flavors, top_count=3) 103 | ] 104 | 105 | for i in range(len(ranks)): 106 | confidence_sum = 0 107 | for ci in range(len(ranks[i])): 108 | confidence_sum += ranks[i][ci][1] 109 | if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))): 110 | bests[i] = ranks[i] 111 | 112 | flaves = ', '.join([f"{x[0]}" for x in bests[4]]) 113 | medium = bests[0][0][0] 114 | print(ranks) 115 | if caption.startswith(medium): 116 | return f"{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}" 117 | else: 118 | return f"{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}" 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /postprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/Trinkle23897/Fast-Poisson-Image-Editing 3 | MIT License 4 | 5 | Copyright (c) 2022 Jiayi Weng 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | """ 25 | 26 | import time 27 | import argparse 28 | import os 29 | import fpie 30 | from process import ALL_BACKEND, CPU_COUNT, DEFAULT_BACKEND 31 | from fpie.io import read_images, write_image 32 | from process import BaseProcessor, EquProcessor, GridProcessor 33 | 34 | from PIL import Image 35 | import numpy as np 36 | import skimage 37 | import skimage.measure 38 | import scipy 39 | import scipy.signal 40 | 41 | 42 | class PhotometricCorrection: 43 | def __init__(self,quite=False): 44 | self.get_parser("cli") 45 | args=self.parser.parse_args(["--method","grid","-g","src","-s","a","-t","a","-o","a"]) 46 | args.mpi_sync_interval = getattr(args, "mpi_sync_interval", 0) 47 | self.backend=args.backend 48 | self.args=args 49 | self.quite=quite 50 | proc: BaseProcessor 51 | proc = GridProcessor( 52 | args.gradient, 53 | args.backend, 54 | args.cpu, 55 | args.mpi_sync_interval, 56 | args.block_size, 57 | args.grid_x, 58 | args.grid_y, 59 | ) 60 | print( 61 | f"[PIE]Successfully initialize PIE {args.method} solver " 62 | f"with {args.backend} backend" 63 | ) 64 | self.proc=proc 65 | 66 | def run(self, original_image, inpainted_image, mode="mask_mode"): 67 | print(f"[PIE] start") 68 | if mode=="disabled": 69 | return inpainted_image 70 | input_arr=np.array(original_image) 71 | if input_arr[:,:,-1].sum()<1: 72 | return inpainted_image 73 | output_arr=np.array(inpainted_image) 74 | mask=input_arr[:,:,-1] 75 | mask=255-mask 76 | if mask.sum()<1 and mode=="mask_mode": 77 | mode="" 78 | if mode=="mask_mode": 79 | mask = skimage.measure.block_reduce(mask, (8, 8), np.max) 80 | mask = mask.repeat(8, axis=0).repeat(8, axis=1) 81 | else: 82 | mask[8:-9,8:-9]=255 83 | mask = mask[:,:,np.newaxis].repeat(3,axis=2) 84 | nmask=mask.copy() 85 | output_arr2=output_arr[:,:,0:3].copy() 86 | input_arr2=input_arr[:,:,0:3].copy() 87 | output_arr2[nmask<128]=0 88 | input_arr2[nmask>=128]=0 89 | output_arr2+=input_arr2 90 | src = output_arr2[:,:,0:3] 91 | tgt = src.copy() 92 | proc=self.proc 93 | args=self.args 94 | if proc.root: 95 | n = proc.reset(src, mask, tgt, (args.h0, args.w0), (args.h1, args.w1)) 96 | proc.sync() 97 | if proc.root: 98 | result = tgt 99 | t = time.time() 100 | if args.p == 0: 101 | args.p = args.n 102 | 103 | for i in range(0, args.n, args.p): 104 | if proc.root: 105 | result, err = proc.step(args.p) # type: ignore 106 | print(f"[PIE] Iter {i + args.p}, abs_err {err}") 107 | else: 108 | proc.step(args.p) 109 | 110 | if proc.root: 111 | dt = time.time() - t 112 | print(f"[PIE] Time elapsed: {dt:.4f}s") 113 | # make sure consistent with dummy process 114 | return Image.fromarray(result) 115 | 116 | 117 | def get_parser(self,gen_type: str) -> argparse.Namespace: 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument( 120 | "-v", "--version", action="store_true", help="show the version and exit" 121 | ) 122 | parser.add_argument( 123 | "--check-backend", action="store_true", help="print all available backends" 124 | ) 125 | if gen_type == "gui" and "mpi" in ALL_BACKEND: 126 | # gui doesn't support MPI backend 127 | ALL_BACKEND.remove("mpi") 128 | parser.add_argument( 129 | "-b", 130 | "--backend", 131 | type=str, 132 | choices=ALL_BACKEND, 133 | default=DEFAULT_BACKEND, 134 | help="backend choice", 135 | ) 136 | parser.add_argument( 137 | "-c", 138 | "--cpu", 139 | type=int, 140 | default=CPU_COUNT, 141 | help="number of CPU used", 142 | ) 143 | parser.add_argument( 144 | "-z", 145 | "--block-size", 146 | type=int, 147 | default=1024, 148 | help="cuda block size (only for equ 
solver)", 149 | ) 150 | parser.add_argument( 151 | "--method", 152 | type=str, 153 | choices=["equ", "grid"], 154 | default="equ", 155 | help="how to parallelize computation", 156 | ) 157 | parser.add_argument("-s", "--source", type=str, help="source image filename") 158 | if gen_type == "cli": 159 | parser.add_argument( 160 | "-m", 161 | "--mask", 162 | type=str, 163 | help="mask image filename (default is to use the whole source image)", 164 | default="", 165 | ) 166 | parser.add_argument("-t", "--target", type=str, help="target image filename") 167 | parser.add_argument("-o", "--output", type=str, help="output image filename") 168 | if gen_type == "cli": 169 | parser.add_argument( 170 | "-h0", type=int, help="mask position (height) on source image", default=0 171 | ) 172 | parser.add_argument( 173 | "-w0", type=int, help="mask position (width) on source image", default=0 174 | ) 175 | parser.add_argument( 176 | "-h1", type=int, help="mask position (height) on target image", default=0 177 | ) 178 | parser.add_argument( 179 | "-w1", type=int, help="mask position (width) on target image", default=0 180 | ) 181 | parser.add_argument( 182 | "-g", 183 | "--gradient", 184 | type=str, 185 | choices=["max", "src", "avg"], 186 | default="max", 187 | help="how to calculate gradient for PIE", 188 | ) 189 | parser.add_argument( 190 | "-n", 191 | type=int, 192 | help="how many iteration would you perfer, the more the better", 193 | default=5000, 194 | ) 195 | if gen_type == "cli": 196 | parser.add_argument( 197 | "-p", type=int, help="output result every P iteration", default=0 198 | ) 199 | if "mpi" in ALL_BACKEND: 200 | parser.add_argument( 201 | "--mpi-sync-interval", 202 | type=int, 203 | help="MPI sync iteration interval", 204 | default=100, 205 | ) 206 | parser.add_argument( 207 | "--grid-x", type=int, help="x axis stride for grid solver", default=8 208 | ) 209 | parser.add_argument( 210 | "--grid-y", type=int, help="y axis stride for grid solver", default=8 211 | ) 212 | self.parser=parser 213 | 214 | if __name__ =="__main__": 215 | import sys 216 | import io 217 | import base64 218 | from PIL import Image 219 | def base64_to_pil(base64_str): 220 | data = base64.b64decode(str(base64_str)) 221 | pil = Image.open(io.BytesIO(data)) 222 | return pil 223 | 224 | def pil_to_base64(out_pil): 225 | out_buffer = io.BytesIO() 226 | out_pil.save(out_buffer, format="PNG") 227 | out_buffer.seek(0) 228 | base64_bytes = base64.b64encode(out_buffer.read()) 229 | base64_str = base64_bytes.decode("ascii") 230 | return base64_str 231 | correction_func=PhotometricCorrection(quite=True) 232 | while True: 233 | buffer = sys.stdin.readline() 234 | print(f"[PIE] suprocess {len(buffer)} {type(buffer)} ") 235 | if len(buffer)==0: 236 | break 237 | if isinstance(buffer,str): 238 | lst=buffer.strip().split(",") 239 | else: 240 | lst=buffer.decode("ascii").strip().split(",") 241 | img0=base64_to_pil(lst[0]) 242 | img1=base64_to_pil(lst[1]) 243 | ret=correction_func.run(img0,img1,mode=lst[2]) 244 | ret_base64=pil_to_base64(ret) 245 | if isinstance(buffer,str): 246 | sys.stdout.write(f"{ret_base64}\n") 247 | else: 248 | sys.stdout.write(f"{ret_base64}\n".encode()) 249 | sys.stdout.flush() -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sd-inf 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - 
_openmp_mutex=5.1=1_gnu 10 | - abseil-cpp=20211102.0=h27087fc_1 11 | - accelerate=0.14.0=pyhd8ed1ab_0 12 | - aiohttp=3.8.1=py310h5764c6d_1 13 | - aiosignal=1.3.1=pyhd8ed1ab_0 14 | - arrow-cpp=8.0.0=py310h3098874_0 15 | - async-timeout=4.0.2=pyhd8ed1ab_0 16 | - attrs=22.1.0=pyh71513ae_1 17 | - aws-c-common=0.4.57=he1b5a44_1 18 | - aws-c-event-stream=0.1.6=h72b8ae1_3 19 | - aws-checksums=0.1.9=h346380f_0 20 | - aws-sdk-cpp=1.8.185=hce553d0_0 21 | - backports=1.0=py_2 22 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 23 | - blas=1.0=mkl 24 | - blosc=1.21.0=h4ff587b_1 25 | - boost-cpp=1.78.0=he72f1d9_0 26 | - brotli=1.0.9=h5eee18b_7 27 | - brotli-bin=1.0.9=h5eee18b_7 28 | - brotlipy=0.7.0=py310h7f8727e_1002 29 | - brunsli=0.1=h2531618_0 30 | - bzip2=1.0.8=h7b6447c_0 31 | - c-ares=1.18.1=h7f8727e_0 32 | - ca-certificates=2022.10.11=h06a4308_0 33 | - certifi=2022.9.24=py310h06a4308_0 34 | - cffi=1.15.1=py310h74dc2b5_0 35 | - cfitsio=3.470=h5893167_7 36 | - charls=2.2.0=h2531618_0 37 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 38 | - click=8.1.3=unix_pyhd8ed1ab_2 39 | - cloudpickle=2.0.0=pyhd3eb1b0_0 40 | - colorama=0.4.6=pyhd8ed1ab_0 41 | - cryptography=38.0.1=py310h9ce1e76_0 42 | - cuda=11.6.2=0 43 | - cuda-cccl=11.6.55=hf6102b2_0 44 | - cuda-command-line-tools=11.6.2=0 45 | - cuda-compiler=11.6.2=0 46 | - cuda-cudart=11.6.55=he381448_0 47 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 48 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 49 | - cuda-cupti=11.6.124=h86345e5_0 50 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 51 | - cuda-driver-dev=11.6.55=0 52 | - cuda-gdb=11.8.86=0 53 | - cuda-libraries=11.6.2=0 54 | - cuda-libraries-dev=11.6.2=0 55 | - cuda-memcheck=11.8.86=0 56 | - cuda-nsight=11.8.86=0 57 | - cuda-nsight-compute=11.8.0=0 58 | - cuda-nvcc=11.6.124=hbba6d2d_0 59 | - cuda-nvdisasm=11.8.86=0 60 | - cuda-nvml-dev=11.6.55=haa9ef22_0 61 | - cuda-nvprof=11.8.87=0 62 | - cuda-nvprune=11.6.124=he22ec0a_0 63 | - cuda-nvrtc=11.6.124=h020bade_0 64 | - cuda-nvrtc-dev=11.6.124=h249d397_0 65 | - cuda-nvtx=11.6.124=h0630a44_0 66 | - cuda-nvvp=11.8.87=0 67 | - cuda-runtime=11.6.2=0 68 | - cuda-samples=11.6.101=h8efea70_0 69 | - cuda-sanitizer-api=11.8.86=0 70 | - cuda-toolkit=11.6.2=0 71 | - cuda-tools=11.6.2=0 72 | - cuda-visual-tools=11.6.2=0 73 | - cytoolz=0.12.0=py310h5eee18b_0 74 | - dask-core=2022.7.0=py310h06a4308_0 75 | - dataclasses=0.8=pyhc8e2a94_3 76 | - datasets=2.7.0=pyhd8ed1ab_0 77 | - diffusers=0.11.1=pyhd8ed1ab_0 78 | - dill=0.3.6=pyhd8ed1ab_1 79 | - ffmpeg=4.3=hf484d3e_0 80 | - fftw=3.3.9=h27cfd23_1 81 | - filelock=3.8.0=pyhd8ed1ab_0 82 | - freetype=2.12.1=h4a9f257_0 83 | - frozenlist=1.3.0=py310h5764c6d_1 84 | - fsspec=2022.10.0=py310h06a4308_0 85 | - ftfy=6.1.1=pyhd8ed1ab_0 86 | - gds-tools=1.4.0.31=0 87 | - gflags=2.2.2=he1b5a44_1004 88 | - giflib=5.2.1=h7b6447c_0 89 | - glog=0.6.0=h6f12383_0 90 | - gmp=6.2.1=h295c915_3 91 | - gnutls=3.6.15=he1e5248_0 92 | - grpc-cpp=1.46.1=h33aed49_0 93 | - huggingface_hub=0.11.0=pyhd8ed1ab_0 94 | - icu=70.1=h27087fc_0 95 | - idna=3.4=py310h06a4308_0 96 | - imagecodecs=2021.8.26=py310hecf7e94_1 97 | - imageio=2.19.3=py310h06a4308_0 98 | - importlib-metadata=5.0.0=pyha770c72_1 99 | - importlib_metadata=5.0.0=hd8ed1ab_1 100 | - intel-openmp=2021.4.0=h06a4308_3561 101 | - joblib=1.2.0=pyhd8ed1ab_0 102 | - jpeg=9e=h7f8727e_0 103 | - jxrlib=1.1=h7b6447c_2 104 | - krb5=1.19.2=hac12032_0 105 | - lame=3.100=h7b6447c_0 106 | - lcms2=2.12=h3be6417_0 107 | - ld_impl_linux-64=2.38=h1181459_1 108 | - lerc=3.0=h295c915_0 109 | - libaec=1.0.4=he6710b0_1 110 | - 
libbrotlicommon=1.0.9=h5eee18b_7 111 | - libbrotlidec=1.0.9=h5eee18b_7 112 | - libbrotlienc=1.0.9=h5eee18b_7 113 | - libcublas=11.11.3.6=0 114 | - libcublas-dev=11.11.3.6=0 115 | - libcufft=10.9.0.58=0 116 | - libcufft-dev=10.9.0.58=0 117 | - libcufile=1.4.0.31=0 118 | - libcufile-dev=1.4.0.31=0 119 | - libcurand=10.3.0.86=0 120 | - libcurand-dev=10.3.0.86=0 121 | - libcurl=7.85.0=h91b91d3_0 122 | - libcusolver=11.4.1.48=0 123 | - libcusolver-dev=11.4.1.48=0 124 | - libcusparse=11.7.5.86=0 125 | - libcusparse-dev=11.7.5.86=0 126 | - libdeflate=1.8=h7f8727e_5 127 | - libedit=3.1.20210910=h7f8727e_0 128 | - libev=4.33=h7f8727e_1 129 | - libevent=2.1.10=h9b69904_4 130 | - libffi=3.3=he6710b0_2 131 | - libgcc-ng=11.2.0=h1234567_1 132 | - libgfortran-ng=11.2.0=h00389a5_1 133 | - libgfortran5=11.2.0=h1234567_1 134 | - libgomp=11.2.0=h1234567_1 135 | - libiconv=1.16=h7f8727e_2 136 | - libidn2=2.3.2=h7f8727e_0 137 | - libnghttp2=1.46.0=hce63b2e_0 138 | - libnpp=11.8.0.86=0 139 | - libnpp-dev=11.8.0.86=0 140 | - libnvjpeg=11.9.0.86=0 141 | - libnvjpeg-dev=11.9.0.86=0 142 | - libpng=1.6.37=hbc83047_0 143 | - libprotobuf=3.20.1=h4ff587b_0 144 | - libssh2=1.10.0=h8f2d780_0 145 | - libstdcxx-ng=11.2.0=h1234567_1 146 | - libtasn1=4.16.0=h27cfd23_0 147 | - libthrift=0.15.0=he6d91bd_0 148 | - libtiff=4.4.0=hecacb30_2 149 | - libunistring=0.9.10=h27cfd23_0 150 | - libuuid=1.41.5=h5eee18b_0 151 | - libwebp=1.2.4=h11a3e52_0 152 | - libwebp-base=1.2.4=h5eee18b_0 153 | - libzopfli=1.0.3=he6710b0_0 154 | - locket=1.0.0=py310h06a4308_0 155 | - lz4-c=1.9.3=h295c915_1 156 | - mkl=2021.4.0=h06a4308_640 157 | - mkl-service=2.4.0=py310h7f8727e_0 158 | - mkl_fft=1.3.1=py310hd6ae3a3_0 159 | - mkl_random=1.2.2=py310h00e6091_0 160 | - multidict=6.0.2=py310h5764c6d_1 161 | - multiprocess=0.70.12.2=py310h5764c6d_2 162 | - ncurses=6.3=h5eee18b_3 163 | - nettle=3.7.3=hbbd107a_1 164 | - networkx=2.8.4=py310h06a4308_0 165 | - nsight-compute=2022.3.0.22=0 166 | - numpy=1.23.4=py310hd5efca6_0 167 | - numpy-base=1.23.4=py310h8e6c178_0 168 | - openh264=2.1.1=h4ff587b_0 169 | - openjpeg=2.4.0=h3ad879b_0 170 | - openssl=1.1.1s=h7f8727e_0 171 | - orc=1.7.4=h07ed6aa_0 172 | - packaging=21.3=pyhd3eb1b0_0 173 | - pandas=1.4.2=py310h769672d_1 174 | - partd=1.2.0=pyhd3eb1b0_1 175 | - pillow=9.2.0=py310hace64e9_1 176 | - pip=22.2.2=py310h06a4308_0 177 | - psutil=5.9.1=py310h5764c6d_0 178 | - pyarrow=8.0.0=py310h468efa6_0 179 | - pycparser=2.21=pyhd3eb1b0_0 180 | - pyopenssl=22.0.0=pyhd3eb1b0_0 181 | - pyparsing=3.0.9=py310h06a4308_0 182 | - pysocks=1.7.1=py310h06a4308_0 183 | - python=3.10.8=haa1d7c7_0 184 | - python-dateutil=2.8.2=pyhd8ed1ab_0 185 | - python-xxhash=3.0.0=py310h5764c6d_1 186 | - python_abi=3.10=2_cp310 187 | - pytorch=1.13.0=py3.10_cuda11.6_cudnn8.3.2_0 188 | - pytorch-cuda=11.6=h867d48c_0 189 | - pytorch-mutex=1.0=cuda 190 | - pytz=2022.6=pyhd8ed1ab_0 191 | - pywavelets=1.3.0=py310h7f8727e_0 192 | - re2=2022.04.01=h27087fc_0 193 | - readline=8.2=h5eee18b_0 194 | - regex=2022.4.24=py310h5764c6d_0 195 | - requests=2.28.1=py310h06a4308_0 196 | - responses=0.18.0=pyhd8ed1ab_0 197 | - sacremoses=0.0.53=pyhd8ed1ab_0 198 | - scikit-image=0.19.2=py310h00e6091_0 199 | - scipy=1.9.3=py310hd5efca6_0 200 | - setuptools=65.5.0=py310h06a4308_0 201 | - six=1.16.0=pyhd3eb1b0_1 202 | - snappy=1.1.9=h295c915_0 203 | - sqlite=3.39.3=h5082296_0 204 | - tifffile=2021.7.2=pyhd3eb1b0_2 205 | - tk=8.6.12=h1ccaba5_0 206 | - tokenizers=0.11.4=py310h3dcd8bd_1 207 | - toolz=0.12.0=py310h06a4308_0 208 | - torchaudio=0.13.0=py310_cu116 209 | - 
torchvision=0.14.0=py310_cu116 210 | - tqdm=4.64.1=pyhd8ed1ab_0 211 | - transformers=4.24.0=pyhd8ed1ab_0 212 | - typing-extensions=4.3.0=py310h06a4308_0 213 | - typing_extensions=4.3.0=py310h06a4308_0 214 | - tzdata=2022f=h04d1e81_0 215 | - urllib3=1.26.12=py310h06a4308_0 216 | - utf8proc=2.6.1=h27cfd23_0 217 | - wcwidth=0.2.5=pyh9f0ad1d_2 218 | - wheel=0.37.1=pyhd3eb1b0_0 219 | - xxhash=0.8.0=h7f98852_3 220 | - xz=5.2.6=h5eee18b_0 221 | - yaml=0.2.5=h7b6447c_0 222 | - yarl=1.7.2=py310h5764c6d_2 223 | - zfp=0.5.5=h295c915_6 224 | - zipp=3.10.0=pyhd8ed1ab_0 225 | - zlib=1.2.13=h5eee18b_0 226 | - zstd=1.5.2=ha4553b6_0 227 | - pip: 228 | - absl-py==1.3.0 229 | - antlr4-python3-runtime==4.9.3 230 | - anyio==3.6.2 231 | - bcrypt==4.0.1 232 | - cachetools==5.2.0 233 | - cmake==3.25.0 234 | - commonmark==0.9.1 235 | - contourpy==1.0.6 236 | - cycler==0.11.0 237 | - einops==0.4.1 238 | - fastapi==0.87.0 239 | - ffmpy==0.3.0 240 | - fonttools==4.38.0 241 | - fpie==0.2.4 242 | - google-auth==2.14.1 243 | - google-auth-oauthlib==0.4.6 244 | - gradio==3.10.1 245 | - grpcio==1.51.0 246 | - h11==0.12.0 247 | - httpcore==0.15.0 248 | - httpx==0.23.1 249 | - jinja2==3.1.2 250 | - kiwisolver==1.4.4 251 | - linkify-it-py==1.0.3 252 | - llvmlite==0.39.1 253 | - markdown==3.4.1 254 | - markdown-it-py==2.1.0 255 | - markupsafe==2.1.1 256 | - matplotlib==3.6.2 257 | - mdit-py-plugins==0.3.1 258 | - mdurl==0.1.2 259 | - numba==0.56.4 260 | - oauthlib==3.2.2 261 | - omegaconf==2.2.3 262 | - opencv-python==4.6.0.66 263 | - opencv-python-headless==4.6.0.66 264 | - orjson==3.8.2 265 | - paramiko==2.12.0 266 | - protobuf==3.20.3 267 | - pyasn1==0.4.8 268 | - pyasn1-modules==0.2.8 269 | - pycryptodome==3.15.0 270 | - pydantic==1.10.2 271 | - pydeprecate==0.3.2 272 | - pydub==0.25.1 273 | - pygments==2.13.0 274 | - pynacl==1.5.0 275 | - python-multipart==0.0.5 276 | - pytorch-lightning==1.7.7 277 | - pyyaml==6.0 278 | - requests-oauthlib==1.3.1 279 | - rfc3986==1.5.0 280 | - rich==12.6.0 281 | - rsa==4.9 282 | - sniffio==1.3.0 283 | - sourceinspect==0.0.4 284 | - starlette==0.21.0 285 | - taichi==1.2.2 286 | - tensorboard==2.11.0 287 | - tensorboard-data-server==0.6.1 288 | - tensorboard-plugin-wit==1.8.1 289 | - timm==0.6.11 290 | - torchmetrics==0.10.3 291 | - uc-micro-py==1.0.1 292 | - uvicorn==0.20.0 293 | - websockets==10.4 294 | - werkzeug==2.2.2 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from PIL import ImageFilter 3 | import cv2 4 | import numpy as np 5 | import scipy 6 | import scipy.signal 7 | from scipy.spatial import cKDTree 8 | 9 | import os 10 | from perlin2d import * 11 | 12 | patch_match_compiled = True 13 | 14 | try: 15 | from PyPatchMatch import patch_match 16 | except Exception as e: 17 | try: 18 | import patch_match 19 | except Exception as e: 20 | patch_match_compiled = False 21 | 22 | try: 23 | patch_match 24 | except NameError: 25 | print("patch_match compiling failed, will fall back to edge_pad") 26 | patch_match_compiled = False 27 | 28 | 29 | 30 | 31 | def edge_pad(img, mask, mode=1): 32 | if mode == 0: 33 | nmask = mask.copy() 34 | nmask[nmask > 0] = 1 35 | res0 = 1 - nmask 36 | res1 = nmask 37 | p0 = np.stack(res0.nonzero(), axis=0).transpose() 38 | p1 = np.stack(res1.nonzero(), axis=0).transpose() 39 | min_dists, min_dist_idx = cKDTree(p1).query(p0, 1) 40 | loc = p1[min_dist_idx] 41 | for (a, b), (c, d) in zip(p0, loc): 42 | img[a, 
b] = img[c, d] 43 | elif mode == 1: 44 | record = {} 45 | kernel = [[1] * 3 for _ in range(3)] 46 | nmask = mask.copy() 47 | nmask[nmask > 0] = 1 48 | res = scipy.signal.convolve2d( 49 | nmask, kernel, mode="same", boundary="fill", fillvalue=1 50 | ) 51 | res[nmask < 1] = 0 52 | res[res == 9] = 0 53 | res[res > 0] = 1 54 | ylst, xlst = res.nonzero() 55 | queue = [(y, x) for y, x in zip(ylst, xlst)] 56 | # bfs here 57 | cnt = res.astype(np.float32) 58 | acc = img.astype(np.float32) 59 | step = 1 60 | h = acc.shape[0] 61 | w = acc.shape[1] 62 | offset = [(1, 0), (-1, 0), (0, 1), (0, -1)] 63 | while queue: 64 | target = [] 65 | for y, x in queue: 66 | val = acc[y][x] 67 | for yo, xo in offset: 68 | yn = y + yo 69 | xn = x + xo 70 | if 0 <= yn < h and 0 <= xn < w and nmask[yn][xn] < 1: 71 | if record.get((yn, xn), step) == step: 72 | acc[yn][xn] = acc[yn][xn] * cnt[yn][xn] + val 73 | cnt[yn][xn] += 1 74 | acc[yn][xn] /= cnt[yn][xn] 75 | if (yn, xn) not in record: 76 | record[(yn, xn)] = step 77 | target.append((yn, xn)) 78 | step += 1 79 | queue = target 80 | img = acc.astype(np.uint8) 81 | else: 82 | nmask = mask.copy() 83 | ylst, xlst = nmask.nonzero() 84 | yt, xt = ylst.min(), xlst.min() 85 | yb, xb = ylst.max(), xlst.max() 86 | content = img[yt : yb + 1, xt : xb + 1] 87 | img = np.pad( 88 | content, 89 | ((yt, mask.shape[0] - yb - 1), (xt, mask.shape[1] - xb - 1), (0, 0)), 90 | mode="edge", 91 | ) 92 | return img, mask 93 | 94 | 95 | def perlin_noise(img, mask): 96 | lin_x = np.linspace(0, 5, mask.shape[1], endpoint=False) 97 | lin_y = np.linspace(0, 5, mask.shape[0], endpoint=False) 98 | x, y = np.meshgrid(lin_x, lin_y) 99 | avg = img.mean(axis=0).mean(axis=0) 100 | # noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) for i in range(3)] 101 | noise = [((perlin(x, y) + 1) * 0.5 * 255).astype(np.uint8) for i in range(3)] 102 | noise = np.stack(noise, axis=-1) 103 | # mask=skimage.measure.block_reduce(mask,(8,8),np.min) 104 | # mask=mask.repeat(8, axis=0).repeat(8, axis=1) 105 | # mask_image=Image.fromarray(mask) 106 | # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 4)) 107 | # mask=np.array(mask_image) 108 | nmask = mask.copy() 109 | # nmask=nmask/255.0 110 | nmask[mask > 0] = 1 111 | img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise 112 | # img=img.astype(np.uint8) 113 | return img, mask 114 | 115 | 116 | def gaussian_noise(img, mask): 117 | noise = np.random.randn(mask.shape[0], mask.shape[1], 3) 118 | noise = (noise + 1) / 2 * 255 119 | noise = noise.astype(np.uint8) 120 | nmask = mask.copy() 121 | nmask[mask > 0] = 1 122 | img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise 123 | return img, mask 124 | 125 | 126 | def cv2_telea(img, mask): 127 | ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_TELEA) 128 | return ret, mask 129 | 130 | 131 | def cv2_ns(img, mask): 132 | ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_NS) 133 | return ret, mask 134 | 135 | 136 | def patch_match_func(img, mask): 137 | ret = patch_match.inpaint(img, mask=255 - mask, patch_size=3) 138 | return ret, mask 139 | 140 | 141 | def mean_fill(img, mask): 142 | avg = img.mean(axis=0).mean(axis=0) 143 | img[mask < 1] = avg 144 | return img, mask 145 | 146 | """ 147 | Apache-2.0 license 148 | https://github.com/hafriedlander/stable-diffusion-grpcserver/blob/main/sdgrpcserver/services/generate.py 149 | https://github.com/parlance-zz/g-diffuser-bot/tree/g-diffuser-bot-beta2 150 | _handleImageAdjustment 151 | """ 152 | try: 153 | from 
sd_grpcserver.sdgrpcserver import images 154 | import torch 155 | from math import sqrt 156 | def handleImageAdjustment(array, adjustments): 157 | tensor = images.fromPIL(Image.fromarray(array)) 158 | for adjustment in adjustments: 159 | which = adjustment[0] 160 | 161 | if which == "blur": 162 | sigma = adjustment[1] 163 | direction = adjustment[2] 164 | 165 | if direction == "DOWN" or direction == "UP": 166 | orig = tensor 167 | repeatCount=256 168 | sigma /= sqrt(repeatCount) 169 | 170 | for _ in range(repeatCount): 171 | tensor = images.gaussianblur(tensor, sigma) 172 | if direction == "DOWN": 173 | tensor = torch.minimum(tensor, orig) 174 | else: 175 | tensor = torch.maximum(tensor, orig) 176 | else: 177 | tensor = images.gaussianblur(tensor, adjustment.blur.sigma) 178 | elif which == "invert": 179 | tensor = images.invert(tensor) 180 | elif which == "levels": 181 | tensor = images.levels(tensor, adjustment[1], adjustment[2], adjustment[3], adjustment[4]) 182 | elif which == "channels": 183 | tensor = images.channelmap(tensor, [adjustment.channels.r, adjustment.channels.g, adjustment.channels.b, adjustment.channels.a]) 184 | elif which == "rescale": 185 | self.unimp("Rescale") 186 | elif which == "crop": 187 | tensor = images.crop(tensor, adjustment.crop.top, adjustment.crop.left, adjustment.crop.height, adjustment.crop.width) 188 | return np.array(images.toPIL(tensor)[0]) 189 | 190 | def g_diffuser(img,mask): 191 | adjustments=[["blur",32,"UP"],["level",0,0.05,0,1]] 192 | mask=handleImageAdjustment(mask,adjustments) 193 | out_mask=handleImageAdjustment(mask,adjustments) 194 | return img, mask 195 | except: 196 | def g_diffuser(img,mask): 197 | return img,mask 198 | 199 | def dummy_fill(img,mask): 200 | return img,mask 201 | functbl = { 202 | "gaussian": gaussian_noise, 203 | "perlin": perlin_noise, 204 | "edge_pad": edge_pad, 205 | "patchmatch": patch_match_func if patch_match_compiled else edge_pad, 206 | "cv2_ns": cv2_ns, 207 | "cv2_telea": cv2_telea, 208 | "g_diffuser": g_diffuser, 209 | "g_diffuser_lib": dummy_fill, 210 | } 211 | 212 | try: 213 | from postprocess import PhotometricCorrection 214 | correction_func = PhotometricCorrection() 215 | except Exception as e: 216 | print(e, "so PhotometricCorrection is disabled") 217 | class DummyCorrection: 218 | def __init__(self): 219 | self.backend="" 220 | pass 221 | def run(self,a,b,**kwargs): 222 | return b 223 | correction_func=DummyCorrection() 224 | 225 | class DummyInterrogator: 226 | def __init__(self) -> None: 227 | pass 228 | def interrogate(self,pil): 229 | return "Interrogator init failed" 230 | 231 | if "taichi" in correction_func.backend: 232 | import sys 233 | import io 234 | import base64 235 | from PIL import Image 236 | def base64_to_pil(base64_str): 237 | data = base64.b64decode(str(base64_str)) 238 | pil = Image.open(io.BytesIO(data)) 239 | return pil 240 | 241 | def pil_to_base64(out_pil): 242 | out_buffer = io.BytesIO() 243 | out_pil.save(out_buffer, format="PNG") 244 | out_buffer.seek(0) 245 | base64_bytes = base64.b64encode(out_buffer.read()) 246 | base64_str = base64_bytes.decode("ascii") 247 | return base64_str 248 | from subprocess import Popen, PIPE, STDOUT 249 | class SubprocessCorrection: 250 | def __init__(self): 251 | self.backend=correction_func.backend 252 | self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT) 253 | def run(self,img_input,img_inpainted,mode): 254 | if mode=="disabled": 255 | return img_inpainted 256 | base64_str_input = pil_to_base64(img_input) 257 | 
base64_str_inpainted = pil_to_base64(img_inpainted) 258 | try: 259 | if self.child.poll(): 260 | self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT) 261 | self.child.stdin.write(f"{base64_str_input},{base64_str_inpainted},{mode}\n".encode()) 262 | self.child.stdin.flush() 263 | out = self.child.stdout.readline() 264 | base64_str=out.decode().strip() 265 | while base64_str and base64_str[0]=="[": 266 | print(base64_str) 267 | out = self.child.stdout.readline() 268 | base64_str=out.decode().strip() 269 | ret=base64_to_pil(base64_str) 270 | except: 271 | print("[PIE] not working, photometric correction is disabled") 272 | ret=img_inpainted 273 | return ret 274 | correction_func = SubprocessCorrection() 275 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
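
For reference, the fill-function table `functbl` defined in utils.py above maps each outpainting init mode ("gaussian", "perlin", "edge_pad", "patchmatch", "cv2_ns", "cv2_telea", "g_diffuser", ...) to a function with the same `(image, mask) -> (image, mask)` shape, so callers can switch strategies by key. A minimal usage sketch, assuming utils.py and its dependencies (cv2, scipy, perlin2d) are importable; the canvas contents and the "edge_pad" choice are illustrative only:

```
# Sketch: pre-fill the empty part of an outpainting canvas before inpainting.
# Convention in utils.py: mask > 0 marks pixels that already hold content,
# mask == 0 marks the region still to be filled.
import numpy as np
from utils import functbl

image = np.zeros((512, 512, 3), dtype=np.uint8)   # placeholder RGB canvas
mask = np.zeros((512, 512), dtype=np.uint8)
mask[128:384, 128:384] = 255                      # existing content in the center

fill_fn = functbl["edge_pad"]                     # or "perlin", "gaussian", "cv2_telea", ...
filled, mask = fill_fn(image, mask)               # every entry returns (filled_image, mask)
```
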
202 | -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/Trinkle23897/Fast-Poisson-Image-Editing 3 | MIT License 4 | 5 | Copyright (c) 2022 Jiayi Weng 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | """ 25 | import os 26 | from abc import ABC, abstractmethod 27 | from typing import Any, Optional, Tuple 28 | 29 | import numpy as np 30 | 31 | from fpie import np_solver 32 | 33 | import scipy 34 | import scipy.signal 35 | 36 | CPU_COUNT = os.cpu_count() or 1 37 | DEFAULT_BACKEND = "numpy" 38 | ALL_BACKEND = ["numpy"] 39 | 40 | try: 41 | from fpie import numba_solver 42 | ALL_BACKEND += ["numba"] 43 | DEFAULT_BACKEND = "numba" 44 | except ImportError: 45 | numba_solver = None # type: ignore 46 | 47 | try: 48 | from fpie import taichi_solver 49 | ALL_BACKEND += ["taichi-cpu", "taichi-gpu"] 50 | DEFAULT_BACKEND = "taichi-cpu" 51 | except ImportError: 52 | taichi_solver = None # type: ignore 53 | 54 | # try: 55 | # from fpie import core_gcc # type: ignore 56 | # DEFAULT_BACKEND = "gcc" 57 | # ALL_BACKEND.append("gcc") 58 | # except ImportError: 59 | # core_gcc = None 60 | 61 | # try: 62 | # from fpie import core_openmp # type: ignore 63 | # DEFAULT_BACKEND = "openmp" 64 | # ALL_BACKEND.append("openmp") 65 | # except ImportError: 66 | # core_openmp = None 67 | 68 | # try: 69 | # from mpi4py import MPI 70 | 71 | # from fpie import core_mpi # type: ignore 72 | # ALL_BACKEND.append("mpi") 73 | # except ImportError: 74 | # MPI = None # type: ignore 75 | # core_mpi = None 76 | 77 | try: 78 | from fpie import core_cuda # type: ignore 79 | DEFAULT_BACKEND = "cuda" 80 | ALL_BACKEND.append("cuda") 81 | except ImportError: 82 | core_cuda = None 83 | 84 | 85 | class BaseProcessor(ABC): 86 | """API definition for processor class.""" 87 | 88 | def __init__( 89 | self, gradient: str, rank: int, backend: str, core: Optional[Any] 90 | ): 91 | if core is None: 92 | error_msg = { 93 | "numpy": 94 | "Please run `pip install numpy`.", 95 | "numba": 96 | "Please run `pip install numba`.", 97 | "gcc": 98 | "Please install cmake and gcc in your operating system.", 99 | "openmp": 100 | "Please make sure your gcc is compatible with `-fopenmp` option.", 101 | "mpi": 102 | "Please install MPI and run `pip install mpi4py`.", 103 | "cuda": 104 | "Please make sure nvcc and cuda-related libraries are available.", 105 | "taichi": 106 
| "Please run `pip install taichi`.", 107 | } 108 | print(error_msg[backend.split("-")[0]]) 109 | 110 | raise AssertionError(f"Invalid backend {backend}.") 111 | 112 | self.gradient = gradient 113 | self.rank = rank 114 | self.backend = backend 115 | self.core = core 116 | self.root = rank == 0 117 | 118 | def mixgrad(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: 119 | if self.gradient == "src": 120 | return a 121 | if self.gradient == "avg": 122 | return (a + b) / 2 123 | # mix gradient, see Equ. 12 in PIE paper 124 | mask = np.abs(a) < np.abs(b) 125 | a[mask] = b[mask] 126 | return a 127 | 128 | @abstractmethod 129 | def reset( 130 | self, 131 | src: np.ndarray, 132 | mask: np.ndarray, 133 | tgt: np.ndarray, 134 | mask_on_src: Tuple[int, int], 135 | mask_on_tgt: Tuple[int, int], 136 | ) -> int: 137 | pass 138 | 139 | def sync(self) -> None: 140 | self.core.sync() 141 | 142 | @abstractmethod 143 | def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]: 144 | pass 145 | 146 | 147 | class EquProcessor(BaseProcessor): 148 | """PIE Jacobi equation processor.""" 149 | 150 | def __init__( 151 | self, 152 | gradient: str = "max", 153 | backend: str = DEFAULT_BACKEND, 154 | n_cpu: int = CPU_COUNT, 155 | min_interval: int = 100, 156 | block_size: int = 1024, 157 | ): 158 | core: Optional[Any] = None 159 | rank = 0 160 | 161 | if backend == "numpy": 162 | core = np_solver.EquSolver() 163 | elif backend == "numba" and numba_solver is not None: 164 | core = numba_solver.EquSolver() 165 | elif backend == "gcc": 166 | core = core_gcc.EquSolver() 167 | elif backend == "openmp" and core_openmp is not None: 168 | core = core_openmp.EquSolver(n_cpu) 169 | elif backend == "mpi" and core_mpi is not None: 170 | core = core_mpi.EquSolver(min_interval) 171 | rank = MPI.COMM_WORLD.Get_rank() 172 | elif backend == "cuda" and core_cuda is not None: 173 | core = core_cuda.EquSolver(block_size) 174 | elif backend.startswith("taichi") and taichi_solver is not None: 175 | core = taichi_solver.EquSolver(backend, n_cpu, block_size) 176 | 177 | super().__init__(gradient, rank, backend, core) 178 | 179 | def mask2index( 180 | self, mask: np.ndarray 181 | ) -> Tuple[np.ndarray, int, np.ndarray, np.ndarray]: 182 | x, y = np.nonzero(mask) 183 | max_id = x.shape[0] + 1 184 | index = np.zeros((max_id, 3)) 185 | ids = self.core.partition(mask) 186 | ids[mask == 0] = 0 # reserve id=0 for constant 187 | index = ids[x, y].argsort() 188 | return ids, max_id, x[index], y[index] 189 | 190 | def reset( 191 | self, 192 | src: np.ndarray, 193 | mask: np.ndarray, 194 | tgt: np.ndarray, 195 | mask_on_src: Tuple[int, int], 196 | mask_on_tgt: Tuple[int, int], 197 | ) -> int: 198 | assert self.root 199 | # check validity 200 | # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1] 201 | # assert mask_on_src[0] + mask.shape[0] <= src.shape[0] 202 | # assert mask_on_src[1] + mask.shape[1] <= src.shape[1] 203 | # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0] 204 | # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1] 205 | 206 | if len(mask.shape) == 3: 207 | mask = mask.mean(-1) 208 | mask = (mask >= 128).astype(np.int32) 209 | 210 | # zero-out edge 211 | mask[0] = 0 212 | mask[-1] = 0 213 | mask[:, 0] = 0 214 | mask[:, -1] = 0 215 | 216 | x, y = np.nonzero(mask) 217 | x0, x1 = x.min() - 1, x.max() + 2 218 | y0, y1 = y.min() - 1, y.max() + 2 219 | mask_on_src = (x0 + mask_on_src[0], y0 + mask_on_src[1]) 220 | mask_on_tgt = (x0 + mask_on_tgt[0], y0 + mask_on_tgt[1]) 221 | mask = mask[x0:x1, y0:y1] 222 | ids, 
max_id, index_x, index_y = self.mask2index(mask) 223 | 224 | src_x, src_y = index_x + mask_on_src[0], index_y + mask_on_src[1] 225 | tgt_x, tgt_y = index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] 226 | 227 | src_C = src[src_x, src_y].astype(np.float32) 228 | src_U = src[src_x - 1, src_y].astype(np.float32) 229 | src_D = src[src_x + 1, src_y].astype(np.float32) 230 | src_L = src[src_x, src_y - 1].astype(np.float32) 231 | src_R = src[src_x, src_y + 1].astype(np.float32) 232 | tgt_C = tgt[tgt_x, tgt_y].astype(np.float32) 233 | tgt_U = tgt[tgt_x - 1, tgt_y].astype(np.float32) 234 | tgt_D = tgt[tgt_x + 1, tgt_y].astype(np.float32) 235 | tgt_L = tgt[tgt_x, tgt_y - 1].astype(np.float32) 236 | tgt_R = tgt[tgt_x, tgt_y + 1].astype(np.float32) 237 | 238 | grad = self.mixgrad(src_C - src_L, tgt_C - tgt_L) \ 239 | + self.mixgrad(src_C - src_R, tgt_C - tgt_R) \ 240 | + self.mixgrad(src_C - src_U, tgt_C - tgt_U) \ 241 | + self.mixgrad(src_C - src_D, tgt_C - tgt_D) 242 | 243 | A = np.zeros((max_id, 4), np.int32) 244 | X = np.zeros((max_id, 3), np.float32) 245 | B = np.zeros((max_id, 3), np.float32) 246 | 247 | X[1:] = tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]] 248 | # four-way 249 | A[1:, 0] = ids[index_x - 1, index_y] 250 | A[1:, 1] = ids[index_x + 1, index_y] 251 | A[1:, 2] = ids[index_x, index_y - 1] 252 | A[1:, 3] = ids[index_x, index_y + 1] 253 | B[1:] = grad 254 | m = (mask[index_x - 1, index_y] == 0).astype(float).reshape(-1, 1) 255 | B[1:] += m * tgt[index_x + mask_on_tgt[0] - 1, index_y + mask_on_tgt[1]] 256 | m = (mask[index_x, index_y - 1] == 0).astype(float).reshape(-1, 1) 257 | B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] - 1] 258 | m = (mask[index_x, index_y + 1] == 0).astype(float).reshape(-1, 1) 259 | B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] + 1] 260 | m = (mask[index_x + 1, index_y] == 0).astype(float).reshape(-1, 1) 261 | B[1:] += m * tgt[index_x + mask_on_tgt[0] + 1, index_y + mask_on_tgt[1]] 262 | 263 | self.tgt = tgt.copy() 264 | self.tgt_index = (index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]) 265 | self.core.reset(max_id, A, X, B) 266 | return max_id 267 | 268 | def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]: 269 | result = self.core.step(iteration) 270 | if self.root: 271 | x, err = result 272 | self.tgt[self.tgt_index] = x[1:] 273 | return self.tgt, err 274 | return None 275 | 276 | 277 | class GridProcessor(BaseProcessor): 278 | """PIE grid processor.""" 279 | 280 | def __init__( 281 | self, 282 | gradient: str = "max", 283 | backend: str = DEFAULT_BACKEND, 284 | n_cpu: int = CPU_COUNT, 285 | min_interval: int = 100, 286 | block_size: int = 1024, 287 | grid_x: int = 8, 288 | grid_y: int = 8, 289 | ): 290 | core: Optional[Any] = None 291 | rank = 0 292 | 293 | if backend == "numpy": 294 | core = np_solver.GridSolver() 295 | elif backend == "numba" and numba_solver is not None: 296 | core = numba_solver.GridSolver() 297 | elif backend == "gcc": 298 | core = core_gcc.GridSolver(grid_x, grid_y) 299 | elif backend == "openmp" and core_openmp is not None: 300 | core = core_openmp.GridSolver(grid_x, grid_y, n_cpu) 301 | elif backend == "mpi" and core_mpi is not None: 302 | core = core_mpi.GridSolver(min_interval) 303 | rank = MPI.COMM_WORLD.Get_rank() 304 | elif backend == "cuda" and core_cuda is not None: 305 | core = core_cuda.GridSolver(grid_x, grid_y) 306 | elif backend.startswith("taichi") and taichi_solver is not None: 307 | core = taichi_solver.GridSolver( 308 | grid_x, grid_y, 
backend, n_cpu, block_size 309 | ) 310 | 311 | super().__init__(gradient, rank, backend, core) 312 | 313 | def reset( 314 | self, 315 | src: np.ndarray, 316 | mask: np.ndarray, 317 | tgt: np.ndarray, 318 | mask_on_src: Tuple[int, int], 319 | mask_on_tgt: Tuple[int, int], 320 | ) -> int: 321 | assert self.root 322 | # check validity 323 | # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1] 324 | # assert mask_on_src[0] + mask.shape[0] <= src.shape[0] 325 | # assert mask_on_src[1] + mask.shape[1] <= src.shape[1] 326 | # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0] 327 | # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1] 328 | 329 | if len(mask.shape) == 3: 330 | mask = mask.mean(-1) 331 | mask = (mask >= 128).astype(np.int32) 332 | 333 | # zero-out edge 334 | mask[0] = 0 335 | mask[-1] = 0 336 | mask[:, 0] = 0 337 | mask[:, -1] = 0 338 | 339 | x, y = np.nonzero(mask) 340 | x0, x1 = x.min() - 1, x.max() + 2 341 | y0, y1 = y.min() - 1, y.max() + 2 342 | mask = mask[x0:x1, y0:y1] 343 | max_id = np.prod(mask.shape) 344 | 345 | src_crop = src[mask_on_src[0] + x0:mask_on_src[0] + x1, 346 | mask_on_src[1] + y0:mask_on_src[1] + y1].astype(np.float32) 347 | tgt_crop = tgt[mask_on_tgt[0] + x0:mask_on_tgt[0] + x1, 348 | mask_on_tgt[1] + y0:mask_on_tgt[1] + y1].astype(np.float32) 349 | grad = np.zeros([*mask.shape, 3], np.float32) 350 | grad[1:] += self.mixgrad( 351 | src_crop[1:] - src_crop[:-1], tgt_crop[1:] - tgt_crop[:-1] 352 | ) 353 | grad[:-1] += self.mixgrad( 354 | src_crop[:-1] - src_crop[1:], tgt_crop[:-1] - tgt_crop[1:] 355 | ) 356 | grad[:, 1:] += self.mixgrad( 357 | src_crop[:, 1:] - src_crop[:, :-1], tgt_crop[:, 1:] - tgt_crop[:, :-1] 358 | ) 359 | grad[:, :-1] += self.mixgrad( 360 | src_crop[:, :-1] - src_crop[:, 1:], tgt_crop[:, :-1] - tgt_crop[:, 1:] 361 | ) 362 | 363 | grad[mask == 0] = 0 364 | if True: 365 | kernel = [[1] * 3 for _ in range(3)] 366 | nmask = mask.copy() 367 | nmask[nmask > 0] = 1 368 | res = scipy.signal.convolve2d( 369 | nmask, kernel, mode="same", boundary="fill", fillvalue=1 370 | ) 371 | res[nmask < 1] = 0 372 | res[res == 9] = 0 373 | res[res > 0] = 1 374 | grad[res>0]=0 375 | # ylst, xlst = res.nonzero() 376 | # for y, x in zip(ylst, xlst): 377 | # grad[y,x]=0 378 | # for yi in range(-1,2): 379 | # for xi in range(-1,2): 380 | # grad[y+yi,x+xi]=0 381 | self.x0 = mask_on_tgt[0] + x0 382 | self.x1 = mask_on_tgt[0] + x1 383 | self.y0 = mask_on_tgt[1] + y0 384 | self.y1 = mask_on_tgt[1] + y1 385 | self.tgt = tgt.copy() 386 | self.core.reset(max_id, mask, tgt_crop, grad) 387 | return max_id 388 | 389 | def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]: 390 | result = self.core.step(iteration) 391 | if self.root: 392 | tgt, err = result 393 | self.tgt[self.x0:self.x1, self.y0:self.y1] = tgt 394 | return self.tgt, err 395 | return None 396 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 |
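
The `EquProcessor` and `GridProcessor` classes in process.py above wrap the Fast-Poisson-Image-Editing solvers behind a `reset`/`step` interface: `reset` crops the masked region, builds the mixed-gradient right-hand side, and hands it to the selected backend, after which each `step(n)` call runs `n` solver iterations and returns the current blended target plus a residual error. A minimal driver sketch, assuming the `fpie` package is installed so the numpy backend imports; the array sizes, contents, and iteration counts are illustrative only:

```
# Sketch: Poisson-blend a bright square from `src` into `tgt` with GridProcessor.
import numpy as np
from process import GridProcessor   # process.py needs `fpie` for its solvers

src = np.full((64, 64, 3), 255, dtype=np.uint8)   # source patch
tgt = np.zeros((64, 64, 3), dtype=np.uint8)       # target canvas
mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 255                          # values >= 128 count as "inside" in reset()

proc = GridProcessor(gradient="max", backend="numpy")
proc.reset(src, mask, tgt, (0, 0), (0, 0))        # offsets of the mask on src and on tgt
result, err = tgt, None
for _ in range(10):                               # 10 rounds of 50 iterations each
    result, err = proc.step(50)
```
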
3 |
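
`SubprocessCorrection` in utils.py above keeps photometric correction out of the main process by talking to `postprocess.py` over a one-line-per-request pipe protocol: it writes `<base64 input image>,<base64 inpainted image>,<mode>` to the child's stdin, then reads stdout lines, treating anything that starts with `[` as log output and the first remaining line as the base64-encoded corrected image. A sketch of a stand-in child that speaks this protocol is below; `correct` is a hypothetical placeholder, not the real correction implemented in postprocess.py:

```
# Sketch of the child end of the pipe protocol used by SubprocessCorrection.
import sys, io, base64
from PIL import Image

def correct(img_input, img_inpainted, mode):
    # Placeholder: the real worker (postprocess.py) applies photometric correction.
    return img_inpainted

for line in sys.stdin:
    b64_in, b64_out, mode = line.strip().split(",")
    pil_in = Image.open(io.BytesIO(base64.b64decode(b64_in)))
    pil_out = Image.open(io.BytesIO(base64.b64decode(b64_out)))
    buf = io.BytesIO()
    correct(pil_in, pil_out, mode).save(buf, format="PNG")
    # The reply must be a single bare base64 line; "["-prefixed lines are logs.
    sys.stdout.write(base64.b64encode(buf.getvalue()).decode("ascii") + "\n")
    sys.stdout.flush()
```
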