├── .github
│   └── workflows
│       └── publish.yml
├── LICENSE.txt
├── README.md
├── __init__.py
├── bmab
│   ├── __init__.py
│   ├── external
│   │   ├── __init__.py
│   │   ├── advanced_clip
│   │   │   └── __init__.py
│   │   ├── fill
│   │   │   ├── __init__.py
│   │   │   ├── controlnet_union.py
│   │   │   └── pipeline_fill_sd_xl.py
│   │   ├── lama
│   │   │   ├── __init__.py
│   │   │   ├── config.yaml
│   │   │   └── saicinpainting
│   │   │       ├── __init__.py
│   │   │       ├── training
│   │   │       │   ├── __init__.py
│   │   │       │   ├── data
│   │   │       │   │   ├── __init__.py
│   │   │       │   │   └── masks.py
│   │   │       │   ├── losses
│   │   │       │   │   ├── __init__.py
│   │   │       │   │   ├── adversarial.py
│   │   │       │   │   ├── constants.py
│   │   │       │   │   ├── distance_weighting.py
│   │   │       │   │   ├── feature_matching.py
│   │   │       │   │   ├── perceptual.py
│   │   │       │   │   ├── segmentation.py
│   │   │       │   │   └── style_loss.py
│   │   │       │   ├── modules
│   │   │       │   │   ├── __init__.py
│   │   │       │   │   ├── base.py
│   │   │       │   │   ├── depthwise_sep_conv.py
│   │   │       │   │   ├── fake_fakes.py
│   │   │       │   │   ├── ffc.py
│   │   │       │   │   ├── multidilated_conv.py
│   │   │       │   │   ├── multiscale.py
│   │   │       │   │   ├── pix2pixhd.py
│   │   │       │   │   ├── spatial_transform.py
│   │   │       │   │   └── squeeze_excitation.py
│   │   │       │   ├── trainers
│   │   │       │   │   ├── __init__.py
│   │   │       │   │   ├── base.py
│   │   │       │   │   └── default.py
│   │   │       │   └── visualizers
│   │   │       │       ├── __init__.py
│   │   │       │       ├── base.py
│   │   │       │       ├── colors.py
│   │   │       │       ├── directory.py
│   │   │       │       └── noop.py
│   │   │       └── utils.py
│   │   └── rmbg14
│   │       ├── MyConfig.py
│   │       ├── __init__.py
│   │       ├── briarmbg.py
│   │       └── utilities.py
│   ├── nodes
│   │   ├── __init__.py
│   │   ├── a1111api.py
│   │   ├── basic.py
│   │   ├── binder.py
│   │   ├── cnloader.py
│   │   ├── detailers.py
│   │   ├── fill.py
│   │   ├── imaging.py
│   │   ├── loaders.py
│   │   ├── resize.py
│   │   ├── sampler.py
│   │   ├── toy.py
│   │   ├── upscaler.py
│   │   ├── utilnode.py
│   │   └── watermark.py
│   ├── process.py
│   ├── serverext.py
│   └── utils
│       ├── __init__.py
│       ├── color.py
│       ├── colorname.py
│       ├── grdino.py
│       ├── sam.py
│       └── yolo.py
├── models
│   └── put_models_here
├── pyproject.toml
├── requirements.txt
├── resources
│   ├── cache
│   │   └── _cachefiles_will_be_put_here
│   ├── examples
│   │   ├── bmab-flux-sample.json
│   │   ├── example.json
│   │   ├── hand-detail-example.json
│   │   ├── ic-light-example.json
│   │   └── openpose-hand-detailing-example.json
│   └── wildcard
│       └── put_wildcard_here
└── web
    ├── gemini.js
    ├── loadoutputimage.js
    ├── previewtext.js
    └── remoteaccess.js
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 |   workflow_dispatch:
4 |   push:
5 |     branches:
6 |       - main
7 |     paths:
8 |       - "pyproject.toml"
9 | 
10 | jobs:
11 |   publish-node:
12 |     name: Publish Custom Node to registry
13 |     runs-on: ubuntu-latest
14 |     steps:
15 |       - name: Check out code
16 |         uses: actions/checkout@v4
17 |       - name: Publish Custom Node
18 |         uses: Comfy-Org/publish-node-action@main
19 |         with:
20 |           ## Add your own personal access token to your Github Repository secrets and reference it here.
21 |           personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # comfyui_bmab
2 |
3 | BMAB is a set of custom nodes for ComfyUI that post-processes generated images according to your settings.
4 | If necessary, it can detect and redraw people, faces, and hands, or perform operations such as resizing, resampling, and adding noise.
5 | You can also composite two images or upscale them.
6 |
7 |
8 |
9 | You can download the sample workflow, example.json:
10 |
11 | https://github.com/portu-sim/comfyui_bmab/blob/main/resources/examples/example.json
12 |
13 | # Flux
14 |
15 | BMAB now supports Flux 1.
16 |
17 |
18 | https://github.com/portu-sim/comfyui_bmab/blob/main/resources/examples/bmab-flux-sample.json
19 |
20 |
21 |
22 | # New Nodes
23 |
24 | ## BMAB Inpaint
25 |
26 |
27 |
28 | ## BMAB Reframe
29 |
30 |
31 |
32 | ## BMAB Outpaint By Ratio
33 |
34 |
35 |
36 |
37 |
38 | ### Gallery
39 |
40 | [instagram](https://www.instagram.com/portu.sim/)
41 | [facebook](https://www.facebook.com/portusimkr)
42 |
43 | ### Hand Detailing Sample
44 |
45 | BMAB detects a person's upper body, enlarges it, and runs Openpose at high resolution to fix incorrectly drawn hands.
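
Conceptually, this kind of detailing follows a crop, upscale, redraw, and paste-back pattern. The sketch below is an illustration only, not BMAB's actual implementation; the detector and redraw steps are passed in as placeholder callables.

```python
from typing import Callable, Tuple
from PIL import Image

Box = Tuple[int, int, int, int]  # (left, top, right, bottom)


def detail_region(image: Image.Image,
                  detect: Callable[[Image.Image], Box],
                  redraw: Callable[[Image.Image], Image.Image],
                  scale: int = 2) -> Image.Image:
    """Crop a detected region, redraw it at higher resolution, and paste it back."""
    box = detect(image)                        # e.g. an upper-body bounding box
    crop = image.crop(box)
    # Work at a higher resolution so pose detection and redrawing see more detail.
    big = crop.resize((crop.width * scale, crop.height * scale), Image.LANCZOS)
    fixed = redraw(big)                        # e.g. Openpose-guided hand redraw
    # Scale back down and paste over the original region.
    out = image.copy()
    out.paste(fixed.resize(crop.size, Image.LANCZOS), box[:2])
    return out
```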
46 |
47 |
48 |
49 |
50 | # Installation
51 |
52 | You can easily install comfyui_bmab using ComfyUI-Manager.
53 | You will need to install a total of three custom node packages.
54 |
55 | * comfyui_bmab
56 | * comfyui_controlnet_aux
57 | * https://github.com/Fannovel16/comfyui_controlnet_aux.git
58 | * Thanks to Fannovel16 for the excellent code.
59 | * ComfyUI_IPAdapter_plus
60 | * https://github.com/cubiq/ComfyUI_IPAdapter_plus.git
61 | * Thanks to cubiq for the excellent code.
62 |
63 |
64 | ### Grounding DINO Installation
65 |
66 | Transformers v4.40.0 includes a Grounding DINO implementation.
67 | https://github.com/huggingface/transformers/releases/tag/v4.40.0
68 | BMAB now uses transformers for object detection.
69 | No separate installation is required.
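
For reference, Grounding DINO can be exercised directly through the transformers API roughly as follows. This is an illustrative sketch only, not BMAB's internal code (see `bmab/utils/grdino.py` for the actual integration); the model id, file name, and thresholds are example values.

```python
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

model_id = 'IDEA-Research/grounding-dino-tiny'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

image = Image.open('input.png').convert('RGB')
text = 'person. face. hand.'  # text queries are lowercase and separated by dots

inputs = processor(images=image, text=text, return_tensors='pt').to(device)
with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.35,
    text_threshold=0.25,
    target_sizes=[image.size[::-1]],  # (height, width)
)
print(results[0]['boxes'], results[0]['labels'])
```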
70 |
71 | ## Install Manually
72 |
73 | I cannot cover every Python environment here.
74 | The installation instructions below assume you have some knowledge of Python.
75 |
76 |
77 | ### Windows portable User
78 |
79 | ```commandline
80 | cd ComfyUI/custom_nodes
81 | git clone https://github.com/portu-sim/comfyui_bmab.git
82 | cd comfyui_bmab
83 | python_embeded\python.exe -m pip install -r requirements.txt
84 | cd ..
85 | ```
86 |
87 | You will need to install two additional custom nodes required by comfyui_bmab.
88 |
89 | ```commandline
90 | cd ComfyUI/custom_nodes
91 | git clone https://github.com/Fannovel16/comfyui_controlnet_aux.git
92 | cd comfyui_controlnet_aux
93 | python_embeded\python.exe -m pip install -r requirements.txt
94 | cd ..
95 | git clone https://github.com/cubiq/ComfyUI_IPAdapter_plus.git
96 | cd ComfyUI_IPAdapter_plus
97 | python_embeded\python.exe -m pip install -r requirements.txt
98 | cd ..
99 | ```
100 |
101 | ### Other Python environments
102 |
103 | ```commandline
104 | cd ComfyUI/custom_nodes
105 | git clone https://github.com/portu-sim/comfyui_bmab.git
106 | cd comfyui_bmab
107 | pip install -r requirements.txt
108 | cd ..
109 | ```
110 |
111 | You will need to install two additional custom nodes required by comfyui_bmab.
112 |
113 | ```commandline
114 | cd ComfyUI/custom_nodes
115 | git clone https://github.com/Fannovel16/comfyui_controlnet_aux.git
116 | cd comfyui_controlnet_aux
117 | pip install -r requirements.txt
118 | cd ..
119 | git clone https://github.com/cubiq/ComfyUI_IPAdapter_plus.git
120 | cd ComfyUI_IPAdapter_plus
121 | pip install -r requirements.txt
122 | cd ..
123 | ```
124 |
125 |
126 | Run ComfyUI
127 |
128 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pillow_avif
4 |
5 | sys.path.append(os.path.join(os.path.dirname(__file__)))
6 | from bmab import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
7 |
8 | try:
9 | import testnodes
10 | print('Register test nodes.')
11 | testnodes.register(NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS)
12 | except Exception as e:
13 | print('Not found test nodes.')
14 | print(e)
15 |
16 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
17 |
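# Directory of client-side extensions (the JS files under ./web) that ComfyUI serves to the browser.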
18 | WEB_DIRECTORY = './web'
19 |
20 |
--------------------------------------------------------------------------------
/bmab/__init__.py:
--------------------------------------------------------------------------------
1 | from bmab import nodes, serverext
2 |
3 |
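# Keys are the unique node names ComfyUI registers; values are the implementing classes.
# The repository root __init__.py re-exports these mappings so ComfyUI can discover them.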
4 | NODE_CLASS_MAPPINGS = {
5 | # Basic
6 | 'BMAB Basic': nodes.BMABBasic,
7 | 'BMAB Edge': nodes.BMABEdge,
8 | 'BMAB Text': nodes.BMABText,
9 | 'BMAB Preview Text': nodes.BMABPreviewText,
10 |
11 | # Resize
12 | 'BMAB Resize By Person': nodes.BMABResizeByPerson,
13 | 'BMAB Resize By Ratio': nodes.BMABResizeByRatio,
14 | 'BMAB Resize and Fill': nodes.BMABResizeAndFill,
15 | 'BMAB Crop': nodes.BMABCrop,
16 | 'BMAB Zoom Out': nodes.BMABZoomOut,
17 | 'BMAB Square': nodes.BMABSquare,
18 |
19 | # Sampler
20 | 'BMAB Integrator': nodes.BMABIntegrator,
21 | 'BMAB ToBind': nodes.BMABToBind,
22 | 'BMAB Flux Integrator': nodes.BMABFluxIntegrator,
23 | 'BMAB Extractor': nodes.BMABExtractor,
24 | 'BMAB SeedGenerator': nodes.BMABSeedGenerator,
25 | 'BMAB KSampler': nodes.BMABKSampler,
26 | 'BMAB KSamplerHiresFix': nodes.BMABKSamplerHiresFix,
27 | 'BMAB KSamplerHiresFixWithUpscaler': nodes.BMABKSamplerHiresFixWithUpscaler,
28 | 'BMAB Context': nodes.BMABContextNode,
29 | 'BMAB Import Integrator': nodes.BMABImportIntegrator,
30 | 'BMAB KSamplerKohyaDeepShrink': nodes.BMABKSamplerKohyaDeepShrink,
31 | 'BMAB Clip Text Encoder SDXL': nodes.BMABClipTextEncoderSDXL,
32 |
33 | # Detailer
34 | 'BMAB Face Detailer': nodes.BMABFaceDetailer,
35 | 'BMAB Person Detailer': nodes.BMABPersonDetailer,
36 | 'BMAB Simple Hand Detailer': nodes.BMABSimpleHandDetailer,
37 | 'BMAB Subframe Hand Detailer': nodes.BMABSubframeHandDetailer,
38 | 'BMAB Openpose Hand Detailer': nodes.BMABOpenposeHandDetailer,
39 | 'BMAB Detail Anything': nodes.BMABDetailAnything,
40 |
41 | # Control Net
42 | 'BMAB ControlNet': nodes.BMABControlNet,
43 | 'BMAB Flux ControlNet': nodes.BMABFluxControlNet,
44 | 'BMAB ControlNet Openpose': nodes.BMABControlNetOpenpose,
45 | 'BMAB ControlNet IPAdapter': nodes.BMABControlNetIPAdapter,
46 |
47 | # Imaging
48 | 'BMAB Detection Crop': nodes.BMABDetectionCrop,
49 | 'BMAB Remove Background': nodes.BMABRemoveBackground,
50 | 'BMAB Alpha Composit': nodes.BMABAlphaComposit,
51 | 'BMAB Blend': nodes.BMABBlend,
52 | 'BMAB Detect And Mask': nodes.BMABDetectAndMask,
53 | 'BMAB Lama Inpaint': nodes.BMABLamaInpaint,
54 | 'BMAB Detector': nodes.BMABDetector,
55 | 'BMAB Segment Anything': nodes.BMABSegmentAnything,
56 | 'BMAB Masks To Images': nodes.BMABMasksToImages,
57 | 'BMAB Load Image': nodes.BMABLoadImage,
58 | 'BMAB Load Output Image': nodes.BMABLoadOutputImage,
59 | 'BMAB Black And White': nodes.BMABBlackAndWhite,
60 | 'BMAB Detect And Paste': nodes.BMABDetectAndPaste,
61 |
62 | # SD-WebUI API
63 | 'BMAB SD-WebUI API Server': nodes.BMABApiServer,
64 | 'BMAB SD-WebUI API T2I': nodes.BMABApiSDWebUIT2I,
65 | 'BMAB SD-WebUI API I2I': nodes.BMABApiSDWebUII2I,
66 | 'BMAB SD-WebUI API T2I Hires.Fix': nodes.BMABApiSDWebUIT2IHiresFix,
67 | 'BMAB SD-WebUI API BMAB Extension': nodes.BMABApiSDWebUIBMABExtension,
68 | 'BMAB SD-WebUI API ControlNet': nodes.BMABApiSDWebUIControlNet,
69 |
70 | # UTIL Nodes
71 | 'BMAB Model To Bind': nodes.BMABModelToBind,
72 | 'BMAB Conditioning To Bind': nodes.BMABConditioningToBind,
73 | 'BMAB Noise Generator': nodes.BMABNoiseGenerator,
74 | 'BMAB Base64 Image': nodes.BMABBase64Image,
75 | 'BMAB Image Storage': nodes.BMABImageStorage,
76 | 'BMAB Normalize Size': nodes.BMABNormalizeSize,
77 | 'BMAB Dummy': nodes.BMABDummy,
78 |
79 | # Watermark
80 | 'BMAB Watermark': nodes.BMABWatermark,
81 |
82 | 'BMAB Upscaler': nodes.BMABUpscale,
83 | 'BMAB Save Image': nodes.BMABSaveImage,
84 | 'BMAB Remote Access And Save': nodes.BMABRemoteAccessAndSave,
85 | 'BMAB Upscale With Model': nodes.BMABUpscaleWithModel,
86 | 'BMAB LoRA Loader': nodes.BMABLoraLoader,
87 | 'BMAB Prompt': nodes.BMABPrompt,
88 | 'BMAB Google Gemini Prompt': nodes.BMABGoogleGemini,
89 |
90 | # Fill
91 | 'BMAB Reframe': nodes.BMABReframe,
92 | 'BMAB Outpaint By Ratio': nodes.BMABOutpaintByRatio,
93 | 'BMAB Inpaint': nodes.BMABInpaint
94 | }
95 |
96 | NODE_DISPLAY_NAME_MAPPINGS = {
97 | # Preview
98 | 'BMAB Basic': 'BMAB Basic',
99 | 'BMAB Edge': 'BMAB Edge',
100 | 'BMAB Text': 'BMAB Text',
101 | 'BMAB Preview Text': 'BMAB Preview Text',
102 |
103 | # Resize
104 | 'BMAB Resize By Person': 'BMAB Resize By Person',
105 | 'BMAB Resize By Ratio': 'BMAB Resize By Ratio',
106 | 'BMAB Resize and Fill': 'BMAB Resize And Fill',
107 | 'BMAB Crop': 'BMAB Crop',
108 | 'BMAB Zoom Out': 'BMAB Zoom Out',
109 | 'BMAB Square': 'BMAB Square',
110 |
111 | # Sampler
112 | 'BMAB Integrator': 'BMAB Integrator',
113 | 'BMAB ToBind': 'BMAB ToBind',
114 | 'BMAB Flux Integrator': 'BMAB Flux Integrator',
115 | 'BMAB KSampler': 'BMAB KSampler',
116 | 'BMAB KSamplerHiresFix': 'BMAB KSampler Hires. Fix',
117 | 'BMAB KSamplerHiresFixWithUpscaler': 'BMAB KSampler Hires. Fix With Upscaler',
118 | 'BMAB Extractor': 'BMAB Extractor',
119 | 'BMAB SeedGenerator': 'BMAB Seed Generator',
120 | 'BMAB Context': 'BMAB Context',
121 | 'BMAB Import Integrator': 'BMAB Import Integrator',
122 | 'BMAB KSamplerKohyaDeepShrink': 'BMAB KSampler with Kohya Deep Shrink',
123 | 'BMAB Clip Text Encoder SDXL': 'BMAB Clip Text Encoder SDXL',
124 |
125 | # Detailer
126 | 'BMAB Face Detailer': 'BMAB Face Detailer',
127 | 'BMAB Person Detailer': 'BMAB Person Detailer',
128 | 'BMAB Simple Hand Detailer': 'BMAB Simple Hand Detailer',
129 | 'BMAB Subframe Hand Detailer': 'BMAB Subframe Hand Detailer',
130 | 'BMAB Openpose Hand Detailer': 'BMAB Openpose Hand Detailer',
131 | 'BMAB Detail Anything': 'BMAB Detail Anything',
132 |
133 | # Control Net
134 | 'BMAB ControlNet': 'BMAB ControlNet',
135 | 'BMAB Flux ControlNet': 'BMAB Flux ControlNet',
136 | 'BMAB ControlNet Openpose': 'BMAB ControlNet Openpose',
137 | 'BMAB ControlNet IPAdapter': 'BMAB ControlNet IPAdapter',
138 |
139 | # Imaging
140 | 'BMAB Detection Crop': 'BMAB Detection Crop',
141 | 'BMAB Remove Background': 'BMAB Remove Background',
142 | 'BMAB Alpha Composit': 'BMAB Alpha Composit',
143 | 'BMAB Blend': 'BMAB Blend',
144 | 'BMAB Detect And Mask': 'BMAB Detect And Mask',
145 | 'BMAB Lama Inpaint': 'BMAB Lama Inpaint',
146 | 'BMAB Detector': 'BMAB Detector',
147 | 'BMAB Segment Anything': 'BMAB Segment Anything',
148 | 'BMAB Masks To Images': 'BMAB Masks To Images',
149 | 'BMAB Load Image': 'BMAB Load Image',
150 | 'BMAB Load Output Image': 'BMAB Load Output Image',
151 | 'BMAB Black And White': 'BMAB Black And White',
152 | 'BMAB Detect And Paste': 'BMAB Detect And Paste',
153 |
154 | # SD-WebUI API
155 | 'BMAB SD-WebUI API Server': 'BMAB SD-WebUI API Server',
156 | 'BMAB SD-WebUI API T2I': 'BMAB SD-WebUI API T2I',
157 | 'BMAB SD-WebUI API I2I': 'BMAB SD-WebUI API I2I',
158 | 'BMAB SD-WebUI API T2I Hires.Fix': 'BMAB SD-WebUI API T2I Hires.Fix',
159 | 'BMAB SD-WebUI API BMAB Extension': 'BMAB SD-WebUI API BMAB Extension',
160 | 'BMAB SD-WebUI API ControlNet': 'BMAB SD-WebUI API ControlNet',
161 |
162 | # UTIL Nodes
163 | 'BMAB Model To Bind': 'BMAB Model To Bind',
164 | 'BMAB Conditioning To Bind': 'BMAB Conditioning To Bind',
165 | 'BMAB Noise Generator': 'BMAB Noise Generator',
166 | 'BMAB Base64 Image': 'BMAB Base64 Image',
167 | 'BMAB Image Storage': 'BMAB Image Storage',
168 | 'BMAB Normalize Size': 'BMAB Normalize Size',
169 | 'BMAB Dummy': 'BMAB Dummy',
170 |
171 | # Watermark
172 | 'BMAB Watermark': 'BMAB Watermark',
173 |
174 | 'BMAB DinoSam': 'BMAB DinoSam',
175 | 'BMAB Upscaler': 'BMAB Upscaler',
176 | 'BMAB Control Net': 'BMAB ControlNet',
177 | 'BMAB Save Image': 'BMAB Save Image',
178 | 'BMAB Remote Access And Save': 'BMAB Remote Access And Save',
179 | 'BMAB Upscale With Model': 'BMAB Upscale With Model',
180 | 'BMAB LoRA Loader': 'BMAB Lora Loader',
181 | 'BMAB Prompt': 'BMAB Prompt',
182 | 'BMAB Google Gemini Prompt': 'BMAB Google Gemini API',
183 |
184 | # Fill
185 | 'BMAB Reframe': 'BMAB Reframe',
186 | 'BMAB Outpaint By Ratio': 'BMAB Outpaint By Ratio',
187 | 'BMAB Inpaint': 'BMAB Inpaint',
188 | }
189 |
190 |
--------------------------------------------------------------------------------
/bmab/external/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/__init__.py
--------------------------------------------------------------------------------
/bmab/external/fill/__init__.py:
--------------------------------------------------------------------------------
1 | # from https://huggingface.co/OzzyGT
2 | # from https://huggingface.co/fffiloni
3 |
--------------------------------------------------------------------------------
/bmab/external/lama/__init__.py:
--------------------------------------------------------------------------------
1 | # https://github.com/advimman/lama
2 | # https://github.com/Mikubill/sd-webui-controlnet
3 |
4 | import os
5 | import gc
6 | import sys
7 | import cv2
8 | import yaml
9 | import torch
10 | import numpy as np
11 | from PIL import Image
12 | from omegaconf import OmegaConf
13 | from einops import rearrange
14 |
15 | from bmab.external.lama.saicinpainting.training.trainers import load_checkpoint
16 | from bmab import utils
17 |
18 |
19 | def lama_inpainting(image, mask, device='gpu'):
20 | mask_input = mask
21 | lama = LamaInpainting()
22 | if device == 'cpu':
23 | lama.device = 'cpu'
24 |
25 | if device == 'mps':
26 | image = lama(image, mask_input)
27 | elif lama.device.startswith('cuda') and image.width != image.height:
28 | width, height = image.size
29 | mx = max(image.width, image.height)
30 | resized = Image.new('RGB', (mx, mx), 0)
31 | mask = Image.new('L', (mx, mx), 0)
32 | if height < width:
33 | y0 = (mx - height) // 2
34 | resized.paste(image, (0, y0))
35 | mask.paste(mask_input, (0, y0))
36 | l = lama(resized, mask)
37 | image = l.crop((0, y0, width, y0 + height))
38 | else:
39 | x0 = (mx - width) // 2
40 | resized.paste(image, (x0, 0))
41 | mask.paste(mask_input, (x0, 0))
42 | l = lama(resized, mask)
43 | image = l.crop((x0, 0, x0 + width, height))
44 | else:
45 | image = lama(image, mask_input)
46 | del lama
47 | gc.collect()
48 | if torch.cuda.is_available():
49 | torch.cuda.empty_cache()
50 | torch.cuda.ipc_collect()
51 | return image
52 |
53 |
54 | class LamaInpainting:
55 |
56 | def __init__(self):
57 | if sys.platform == 'darwin':
58 | self.device = 'mps'
59 | elif torch.cuda.is_available():
60 | self.device = 'cuda'
61 | else:
62 | self.device = 'cpu'
63 | self.model = None
64 |
65 | @staticmethod
66 | def load_image(pilimg, mode='RGB'):
67 | img = np.array(pilimg.convert(mode))
68 | if img.ndim == 3:
69 | print('transpose')
70 | img = np.transpose(img, (2, 0, 1))
71 | out_img = img.astype('float32') / 255
72 | return out_img
73 |
74 | def load_model(self):
75 | config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml')
76 | cfg = yaml.safe_load(open(config_path, 'rt'))
77 | cfg = OmegaConf.create(cfg)
78 | cfg.training_model.predict_only = True
79 | cfg.visualizer.kind = 'noop'
80 | lamapth = utils.lazy_loader('ControlNetLama.pth')
81 | self.model = load_checkpoint(cfg, lamapth, strict=False, map_location='cpu')
82 | self.model = self.model.to(self.device)
83 | self.model.eval()
84 |
85 | def unload_model(self):
86 | if self.model is not None:
87 | self.model.cpu()
88 |
89 | def __call__(self, image, mask):
90 | if self.model is None:
91 | self.load_model()
92 | self.model.to(self.device)
93 |
94 | opencv_image = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)[:, :, 0:3]
95 | opencv_mask = cv2.cvtColor(np.array(mask.convert('RGB')), cv2.COLOR_RGB2BGR)[:, :, 0:1]
96 | color = np.ascontiguousarray(opencv_image).astype(np.float32) / 255.0
97 | mask = np.ascontiguousarray(opencv_mask).astype(np.float32) / 255.0
98 |
99 | with torch.no_grad():
100 | color = torch.from_numpy(color).float().to(self.device)
101 | mask = torch.from_numpy(mask).float().to(self.device)
102 | mask = (mask > 0.5).float()
103 | color = color * (1 - mask)
104 | image_feed = torch.cat([color, mask], dim=2)
105 | image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
106 | result = self.model(image_feed)[0]
107 | result = rearrange(result, 'c h w -> h w c')
108 | result = result * mask + color * (1 - mask)
109 | result *= 255.0
110 |
111 | img = result.detach().cpu().numpy().clip(0, 255).astype(np.uint8)
112 | color_coverted = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
113 | pil_image = Image.fromarray(color_coverted)
114 | return pil_image
115 |
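# Usage sketch (illustrative only; the file names below are placeholders):
#   from PIL import Image
#   from bmab.external.lama import lama_inpainting
#   image = Image.open('photo.png').convert('RGB')
#   mask = Image.open('mask.png').convert('L')   # white marks the region to remove
#   lama_inpainting(image, mask).save('inpainted.png')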
--------------------------------------------------------------------------------
/bmab/external/lama/config.yaml:
--------------------------------------------------------------------------------
1 | run_title: b18_ffc075_batch8x15
2 | training_model:
3 | kind: default
4 | visualize_each_iters: 1000
5 | concat_mask: true
6 | store_discr_outputs_for_vis: true
7 | losses:
8 | l1:
9 | weight_missing: 0
10 | weight_known: 10
11 | perceptual:
12 | weight: 0
13 | adversarial:
14 | kind: r1
15 | weight: 10
16 | gp_coef: 0.001
17 | mask_as_fake_target: true
18 | allow_scale_mask: true
19 | feature_matching:
20 | weight: 100
21 | resnet_pl:
22 | weight: 30
23 | weights_path: ${env:TORCH_HOME}
24 |
25 | optimizers:
26 | generator:
27 | kind: adam
28 | lr: 0.001
29 | discriminator:
30 | kind: adam
31 | lr: 0.0001
32 | visualizer:
33 | key_order:
34 | - image
35 | - predicted_image
36 | - discr_output_fake
37 | - discr_output_real
38 | - inpainted
39 | rescale_keys:
40 | - discr_output_fake
41 | - discr_output_real
42 | kind: directory
43 | outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
44 | location:
45 | data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
46 | out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
47 | tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
48 | data:
49 | batch_size: 15
50 | val_batch_size: 2
51 | num_workers: 3
52 | train:
53 | indir: ${location.data_root_dir}/train
54 | out_size: 256
55 | mask_gen_kwargs:
56 | irregular_proba: 1
57 | irregular_kwargs:
58 | max_angle: 4
59 | max_len: 200
60 | max_width: 100
61 | max_times: 5
62 | min_times: 1
63 | box_proba: 1
64 | box_kwargs:
65 | margin: 10
66 | bbox_min_size: 30
67 | bbox_max_size: 150
68 | max_times: 3
69 | min_times: 1
70 | segm_proba: 0
71 | segm_kwargs:
72 | confidence_threshold: 0.5
73 | max_object_area: 0.5
74 | min_mask_area: 0.07
75 | downsample_levels: 6
76 | num_variants_per_mask: 1
77 | rigidness_mode: 1
78 | max_foreground_coverage: 0.3
79 | max_foreground_intersection: 0.7
80 | max_mask_intersection: 0.1
81 | max_hidden_area: 0.1
82 | max_scale_change: 0.25
83 | horizontal_flip: true
84 | max_vertical_shift: 0.2
85 | position_shuffle: true
86 | transform_variant: distortions
87 | dataloader_kwargs:
88 | batch_size: ${data.batch_size}
89 | shuffle: true
90 | num_workers: ${data.num_workers}
91 | val:
92 | indir: ${location.data_root_dir}/val
93 | img_suffix: .png
94 | dataloader_kwargs:
95 | batch_size: ${data.val_batch_size}
96 | shuffle: false
97 | num_workers: ${data.num_workers}
98 | visual_test:
99 | indir: ${location.data_root_dir}/korean_test
100 | img_suffix: _input.png
101 | pad_out_to_modulo: 32
102 | dataloader_kwargs:
103 | batch_size: 1
104 | shuffle: false
105 | num_workers: ${data.num_workers}
106 | generator:
107 | kind: ffc_resnet
108 | input_nc: 4
109 | output_nc: 3
110 | ngf: 64
111 | n_downsampling: 3
112 | n_blocks: 18
113 | add_out_act: sigmoid
114 | init_conv_kwargs:
115 | ratio_gin: 0
116 | ratio_gout: 0
117 | enable_lfu: false
118 | downsample_conv_kwargs:
119 | ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
120 | ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
121 | enable_lfu: false
122 | resnet_conv_kwargs:
123 | ratio_gin: 0.75
124 | ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
125 | enable_lfu: false
126 | discriminator:
127 | kind: pix2pixhd_nlayer
128 | input_nc: 3
129 | ndf: 64
130 | n_layers: 4
131 | evaluator:
132 | kind: default
133 | inpainted_key: inpainted
134 | integral_kind: ssim_fid100_f1
135 | trainer:
136 | kwargs:
137 | gpus: -1
138 | accelerator: ddp
139 | max_epochs: 200
140 | gradient_clip_val: 1
141 | log_gpu_memory: None
142 | limit_train_batches: 25000
143 | val_check_interval: ${trainer.kwargs.limit_train_batches}
144 | log_every_n_steps: 1000
145 | precision: 32
146 | terminate_on_nan: false
147 | check_val_every_n_epoch: 1
148 | num_sanity_val_steps: 8
149 | limit_val_batches: 1000
150 | replace_sampler_ddp: false
151 | checkpoint_kwargs:
152 | verbose: true
153 | save_top_k: 5
154 | save_last: true
155 | period: 1
156 | monitor: val_ssim_fid100_f1_total_mean
157 | mode: max
158 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/lama/saicinpainting/__init__.py
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/lama/saicinpainting/training/__init__.py
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/lama/saicinpainting/training/data/__init__.py
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/lama/saicinpainting/training/losses/__init__.py
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/losses/adversarial.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class BaseAdversarialLoss:
9 | def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
10 | generator: nn.Module, discriminator: nn.Module):
11 | """
12 | Prepare for generator step
13 | :param real_batch: Tensor, a batch of real samples
14 | :param fake_batch: Tensor, a batch of samples produced by generator
15 | :param generator:
16 | :param discriminator:
17 | :return: None
18 | """
19 |
20 | def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
21 | generator: nn.Module, discriminator: nn.Module):
22 | """
23 | Prepare for discriminator step
24 | :param real_batch: Tensor, a batch of real samples
25 | :param fake_batch: Tensor, a batch of samples produced by generator
26 | :param generator:
27 | :param discriminator:
28 | :return: None
29 | """
30 |
31 | def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
32 | discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
33 | mask: Optional[torch.Tensor] = None) \
34 | -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
35 | """
36 | Calculate generator loss
37 | :param real_batch: Tensor, a batch of real samples
38 | :param fake_batch: Tensor, a batch of samples produced by generator
39 | :param discr_real_pred: Tensor, discriminator output for real_batch
40 | :param discr_fake_pred: Tensor, discriminator output for fake_batch
41 | :param mask: Tensor, actual mask, which was at input of generator when making fake_batch
42 | :return: total generator loss along with some values that might be interesting to log
43 | """
44 | raise NotImplementedError()
45 |
46 | def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
47 | discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
48 | mask: Optional[torch.Tensor] = None) \
49 | -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
50 | """
51 | Calculate discriminator loss and call .backward() on it
52 | :param real_batch: Tensor, a batch of real samples
53 | :param fake_batch: Tensor, a batch of samples produced by generator
54 | :param discr_real_pred: Tensor, discriminator output for real_batch
55 | :param discr_fake_pred: Tensor, discriminator output for fake_batch
56 | :param mask: Tensor, actual mask, which was at input of generator when making fake_batch
57 | :return: total discriminator loss along with some values that might be interesting to log
58 | """
59 | raise NotImplementedError()
60 |
61 | def interpolate_mask(self, mask, shape):
62 | assert mask is not None
63 | assert self.allow_scale_mask or shape == mask.shape[-2:]
64 | if shape != mask.shape[-2:] and self.allow_scale_mask:
65 | if self.mask_scale_mode == 'maxpool':
66 | mask = F.adaptive_max_pool2d(mask, shape)
67 | else:
68 | mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)
69 | return mask
70 |
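# R1 regularization: penalize the squared gradient norm of the discriminator on real
# samples, E[ ||grad_x D(x)||^2 ] (Mescheder et al., 2018), computed only while autograd
# is enabled; it is scaled by gp_coef in NonSaturatingWithR1.discriminator_loss below.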
71 | def make_r1_gp(discr_real_pred, real_batch):
72 | if torch.is_grad_enabled():
73 | grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0]
74 | grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()
75 | else:
76 | grad_penalty = 0
77 | real_batch.requires_grad = False
78 |
79 | return grad_penalty
80 |
81 | class NonSaturatingWithR1(BaseAdversarialLoss):
82 | def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
83 | mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
84 | use_unmasked_for_gen=True, use_unmasked_for_discr=True):
85 | self.gp_coef = gp_coef
86 | self.weight = weight
87 | # use for discr => use for gen;
88 | # otherwise we teach only the discr to pay attention to very small difference
89 | assert use_unmasked_for_gen or (not use_unmasked_for_discr)
90 | # mask as target => use unmasked for discr:
91 | # if we don't care about unmasked regions at all
92 | # then it doesn't matter if the value of mask_as_fake_target is true or false
93 | assert use_unmasked_for_discr or (not mask_as_fake_target)
94 | self.use_unmasked_for_gen = use_unmasked_for_gen
95 | self.use_unmasked_for_discr = use_unmasked_for_discr
96 | self.mask_as_fake_target = mask_as_fake_target
97 | self.allow_scale_mask = allow_scale_mask
98 | self.mask_scale_mode = mask_scale_mode
99 | self.extra_mask_weight_for_gen = extra_mask_weight_for_gen
100 |
101 | def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
102 | discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
103 | mask=None) \
104 | -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
105 | fake_loss = F.softplus(-discr_fake_pred)
106 | if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
107 | not self.use_unmasked_for_gen: # == if masked region should be treated differently
108 | mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
109 | if not self.use_unmasked_for_gen:
110 | fake_loss = fake_loss * mask
111 | else:
112 | pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
113 | fake_loss = fake_loss * pixel_weights
114 |
115 | return fake_loss.mean() * self.weight, dict()
116 |
117 | def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
118 | generator: nn.Module, discriminator: nn.Module):
119 | real_batch.requires_grad = True
120 |
121 | def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
122 | discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
123 | mask=None) \
124 | -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
125 |
126 | real_loss = F.softplus(-discr_real_pred)
127 | grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
128 | fake_loss = F.softplus(discr_fake_pred)
129 |
130 | if not self.use_unmasked_for_discr or self.mask_as_fake_target:
131 | # == if masked region should be treated differently
132 | mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
133 | # use_unmasked_for_discr=False only makes sense for fakes;
134 | # for reals there is no difference beetween two regions
135 | fake_loss = fake_loss * mask
136 | if self.mask_as_fake_target:
137 | fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)
138 |
139 | sum_discr_loss = real_loss + grad_penalty + fake_loss
140 | metrics = dict(discr_real_out=discr_real_pred.mean(),
141 | discr_fake_out=discr_fake_pred.mean(),
142 | discr_real_gp=grad_penalty)
143 | return sum_discr_loss.mean(), metrics
144 |
145 | class BCELoss(BaseAdversarialLoss):
146 | def __init__(self, weight):
147 | self.weight = weight
148 | self.bce_loss = nn.BCEWithLogitsLoss()
149 |
150 | def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
151 | real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device)
152 | fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
153 | return fake_loss, dict()
154 |
155 | def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
156 | generator: nn.Module, discriminator: nn.Module):
157 | real_batch.requires_grad = True
158 |
159 | def discriminator_loss(self,
160 | mask: torch.Tensor,
161 | discr_real_pred: torch.Tensor,
162 | discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
163 |
164 | real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device)
165 | sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2
166 | metrics = dict(discr_real_out=discr_real_pred.mean(),
167 | discr_fake_out=discr_fake_pred.mean(),
168 | discr_real_gp=0)
169 | return sum_discr_loss, metrics
170 |
171 |
172 | def make_discrim_loss(kind, **kwargs):
173 | if kind == 'r1':
174 | return NonSaturatingWithR1(**kwargs)
175 | elif kind == 'bce':
176 | return BCELoss(**kwargs)
177 | raise ValueError(f'Unknown adversarial loss kind {kind}')
178 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/losses/constants.py:
--------------------------------------------------------------------------------
1 | weights = {"ade20k":
2 | [6.34517766497462,
3 | 9.328358208955224,
4 | 11.389521640091116,
5 | 16.10305958132045,
6 | 20.833333333333332,
7 | 22.22222222222222,
8 | 25.125628140703515,
9 | 43.29004329004329,
10 | 50.5050505050505,
11 | 54.6448087431694,
12 | 55.24861878453038,
13 | 60.24096385542168,
14 | 62.5,
15 | 66.2251655629139,
16 | 84.74576271186442,
17 | 90.90909090909092,
18 | 91.74311926605505,
19 | 96.15384615384616,
20 | 96.15384615384616,
21 | 97.08737864077669,
22 | 102.04081632653062,
23 | 135.13513513513513,
24 | 149.2537313432836,
25 | 153.84615384615384,
26 | 163.93442622950818,
27 | 166.66666666666666,
28 | 188.67924528301887,
29 | 192.30769230769232,
30 | 217.3913043478261,
31 | 227.27272727272725,
32 | 227.27272727272725,
33 | 227.27272727272725,
34 | 303.03030303030306,
35 | 322.5806451612903,
36 | 333.3333333333333,
37 | 370.3703703703703,
38 | 384.61538461538464,
39 | 416.6666666666667,
40 | 416.6666666666667,
41 | 434.7826086956522,
42 | 434.7826086956522,
43 | 454.5454545454545,
44 | 454.5454545454545,
45 | 500.0,
46 | 526.3157894736842,
47 | 526.3157894736842,
48 | 555.5555555555555,
49 | 555.5555555555555,
50 | 555.5555555555555,
51 | 555.5555555555555,
52 | 555.5555555555555,
53 | 555.5555555555555,
54 | 555.5555555555555,
55 | 588.2352941176471,
56 | 588.2352941176471,
57 | 588.2352941176471,
58 | 588.2352941176471,
59 | 588.2352941176471,
60 | 666.6666666666666,
61 | 666.6666666666666,
62 | 666.6666666666666,
63 | 666.6666666666666,
64 | 714.2857142857143,
65 | 714.2857142857143,
66 | 714.2857142857143,
67 | 714.2857142857143,
68 | 714.2857142857143,
69 | 769.2307692307693,
70 | 769.2307692307693,
71 | 769.2307692307693,
72 | 833.3333333333334,
73 | 833.3333333333334,
74 | 833.3333333333334,
75 | 833.3333333333334,
76 | 909.090909090909,
77 | 1000.0,
78 | 1111.111111111111,
79 | 1111.111111111111,
80 | 1111.111111111111,
81 | 1111.111111111111,
82 | 1111.111111111111,
83 | 1250.0,
84 | 1250.0,
85 | 1250.0,
86 | 1250.0,
87 | 1250.0,
88 | 1428.5714285714287,
89 | 1428.5714285714287,
90 | 1428.5714285714287,
91 | 1428.5714285714287,
92 | 1428.5714285714287,
93 | 1428.5714285714287,
94 | 1428.5714285714287,
95 | 1666.6666666666667,
96 | 1666.6666666666667,
97 | 1666.6666666666667,
98 | 1666.6666666666667,
99 | 1666.6666666666667,
100 | 1666.6666666666667,
101 | 1666.6666666666667,
102 | 1666.6666666666667,
103 | 1666.6666666666667,
104 | 1666.6666666666667,
105 | 1666.6666666666667,
106 | 2000.0,
107 | 2000.0,
108 | 2000.0,
109 | 2000.0,
110 | 2000.0,
111 | 2000.0,
112 | 2000.0,
113 | 2000.0,
114 | 2000.0,
115 | 2000.0,
116 | 2000.0,
117 | 2000.0,
118 | 2000.0,
119 | 2000.0,
120 | 2000.0,
121 | 2000.0,
122 | 2000.0,
123 | 2500.0,
124 | 2500.0,
125 | 2500.0,
126 | 2500.0,
127 | 2500.0,
128 | 2500.0,
129 | 2500.0,
130 | 2500.0,
131 | 2500.0,
132 | 2500.0,
133 | 2500.0,
134 | 2500.0,
135 | 2500.0,
136 | 3333.3333333333335,
137 | 3333.3333333333335,
138 | 3333.3333333333335,
139 | 3333.3333333333335,
140 | 3333.3333333333335,
141 | 3333.3333333333335,
142 | 3333.3333333333335,
143 | 3333.3333333333335,
144 | 3333.3333333333335,
145 | 3333.3333333333335,
146 | 3333.3333333333335,
147 | 3333.3333333333335,
148 | 3333.3333333333335,
149 | 5000.0,
150 | 5000.0,
151 | 5000.0]
152 | }
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/losses/distance_weighting.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
6 | from bmab.external.lama.saicinpainting.training.losses.perceptual import IMAGENET_STD, IMAGENET_MEAN
7 |
8 |
9 | def dummy_distance_weighter(real_img, pred_img, mask):
10 | return mask
11 |
12 |
13 | def get_gauss_kernel(kernel_size, width_factor=1):
14 | coords = torch.stack(torch.meshgrid(torch.arange(kernel_size),
15 | torch.arange(kernel_size)),
16 | dim=0).float()
17 | diff = torch.exp(-((coords - kernel_size // 2) ** 2).sum(0) / kernel_size / width_factor)
18 | diff /= diff.sum()
19 | return diff
20 |
21 |
22 | class BlurMask(nn.Module):
23 | def __init__(self, kernel_size=5, width_factor=1):
24 | super().__init__()
25 | self.filter = nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, padding_mode='replicate', bias=False)
26 | self.filter.weight.data.copy_(get_gauss_kernel(kernel_size, width_factor=width_factor))
27 |
28 | def forward(self, real_img, pred_img, mask):
29 | with torch.no_grad():
30 | result = self.filter(mask) * mask
31 | return result
32 |
33 |
34 | class EmulatedEDTMask(nn.Module):
35 | def __init__(self, dilate_kernel_size=5, blur_kernel_size=5, width_factor=1):
36 | super().__init__()
37 | self.dilate_filter = nn.Conv2d(1, 1, dilate_kernel_size, padding=dilate_kernel_size// 2, padding_mode='replicate',
38 | bias=False)
39 | self.dilate_filter.weight.data.copy_(torch.ones(1, 1, dilate_kernel_size, dilate_kernel_size, dtype=torch.float))
40 | self.blur_filter = nn.Conv2d(1, 1, blur_kernel_size, padding=blur_kernel_size // 2, padding_mode='replicate', bias=False)
41 | self.blur_filter.weight.data.copy_(get_gauss_kernel(blur_kernel_size, width_factor=width_factor))
42 |
43 | def forward(self, real_img, pred_img, mask):
44 | with torch.no_grad():
45 | known_mask = 1 - mask
46 | dilated_known_mask = (self.dilate_filter(known_mask) > 1).float()
47 | result = self.blur_filter(1 - dilated_known_mask) * mask
48 | return result
49 |
50 |
51 | class PropagatePerceptualSim(nn.Module):
52 | def __init__(self, level=2, max_iters=10, temperature=500, erode_mask_size=3):
53 | super().__init__()
54 | vgg = torchvision.models.vgg19(pretrained=True).features
55 | vgg_avg_pooling = []
56 |
57 | for weights in vgg.parameters():
58 | weights.requires_grad = False
59 |
60 | cur_level_i = 0
61 | for module in vgg.modules():
62 | if module.__class__.__name__ == 'Sequential':
63 | continue
64 | elif module.__class__.__name__ == 'MaxPool2d':
65 | vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
66 | else:
67 | vgg_avg_pooling.append(module)
68 | if module.__class__.__name__ == 'ReLU':
69 | cur_level_i += 1
70 | if cur_level_i == level:
71 | break
72 |
73 | self.features = nn.Sequential(*vgg_avg_pooling)
74 |
75 | self.max_iters = max_iters
76 | self.temperature = temperature
77 | self.do_erode = erode_mask_size > 0
78 | if self.do_erode:
79 | self.erode_mask = nn.Conv2d(1, 1, erode_mask_size, padding=erode_mask_size // 2, bias=False)
80 | self.erode_mask.weight.data.fill_(1)
81 |
82 | def forward(self, real_img, pred_img, mask):
83 | with torch.no_grad():
84 | real_img = (real_img - IMAGENET_MEAN.to(real_img)) / IMAGENET_STD.to(real_img)
85 | real_feats = self.features(real_img)
86 |
87 | vertical_sim = torch.exp(-(real_feats[:, :, 1:] - real_feats[:, :, :-1]).pow(2).sum(1, keepdim=True)
88 | / self.temperature)
89 | horizontal_sim = torch.exp(-(real_feats[:, :, :, 1:] - real_feats[:, :, :, :-1]).pow(2).sum(1, keepdim=True)
90 | / self.temperature)
91 |
92 | mask_scaled = F.interpolate(mask, size=real_feats.shape[-2:], mode='bilinear', align_corners=False)
93 | if self.do_erode:
94 | mask_scaled = (self.erode_mask(mask_scaled) > 1).float()
95 |
96 | cur_knowness = 1 - mask_scaled
97 |
98 | for iter_i in range(self.max_iters):
99 | new_top_knowness = F.pad(cur_knowness[:, :, :-1] * vertical_sim, (0, 0, 1, 0), mode='replicate')
100 | new_bottom_knowness = F.pad(cur_knowness[:, :, 1:] * vertical_sim, (0, 0, 0, 1), mode='replicate')
101 |
102 | new_left_knowness = F.pad(cur_knowness[:, :, :, :-1] * horizontal_sim, (1, 0, 0, 0), mode='replicate')
103 | new_right_knowness = F.pad(cur_knowness[:, :, :, 1:] * horizontal_sim, (0, 1, 0, 0), mode='replicate')
104 |
105 | new_knowness = torch.stack([new_top_knowness, new_bottom_knowness,
106 | new_left_knowness, new_right_knowness],
107 | dim=0).max(0).values
108 |
109 | cur_knowness = torch.max(cur_knowness, new_knowness)
110 |
111 | cur_knowness = F.interpolate(cur_knowness, size=mask.shape[-2:], mode='bilinear')
112 | result = torch.min(mask, 1 - cur_knowness)
113 |
114 | return result
115 |
116 |
117 | def make_mask_distance_weighter(kind='none', **kwargs):
118 | if kind == 'none':
119 | return dummy_distance_weighter
120 | if kind == 'blur':
121 | return BlurMask(**kwargs)
122 | if kind == 'edt':
123 | return EmulatedEDTMask(**kwargs)
124 | if kind == 'pps':
125 | return PropagatePerceptualSim(**kwargs)
126 | raise ValueError(f'Unknown mask distance weighter kind {kind}')
127 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/losses/feature_matching.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def masked_l2_loss(pred, target, mask, weight_known, weight_missing):
8 | per_pixel_l2 = F.mse_loss(pred, target, reduction='none')
9 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known
10 | return (pixel_weights * per_pixel_l2).mean()
11 |
12 |
13 | def masked_l1_loss(pred, target, mask, weight_known, weight_missing):
14 | per_pixel_l1 = F.l1_loss(pred, target, reduction='none')
15 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known
16 | return (pixel_weights * per_pixel_l1).mean()
17 |
18 |
19 | def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None):
20 | if mask is None:
21 | res = torch.stack([F.mse_loss(fake_feat, target_feat)
22 | for fake_feat, target_feat in zip(fake_features, target_features)]).mean()
23 | else:
24 | res = 0
25 | norm = 0
26 | for fake_feat, target_feat in zip(fake_features, target_features):
27 | cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False)
28 | error_weights = 1 - cur_mask
29 | cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean()
30 | res = res + cur_val
31 | norm += 1
32 | res = res / norm
33 | return res
34 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/losses/perceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
6 | # from models.ade20k import ModelBuilder
7 | from bmab.external.lama.saicinpainting.utils import check_and_warn_input_range
8 |
9 |
10 | IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
11 | IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]
12 |
13 |
14 | class PerceptualLoss(nn.Module):
15 | def __init__(self, normalize_inputs=True):
16 | super(PerceptualLoss, self).__init__()
17 |
18 | self.normalize_inputs = normalize_inputs
19 | self.mean_ = IMAGENET_MEAN
20 | self.std_ = IMAGENET_STD
21 |
22 | vgg = torchvision.models.vgg19(pretrained=True).features
23 | vgg_avg_pooling = []
24 |
25 | for weights in vgg.parameters():
26 | weights.requires_grad = False
27 |
28 | for module in vgg.modules():
29 | if module.__class__.__name__ == 'Sequential':
30 | continue
31 | elif module.__class__.__name__ == 'MaxPool2d':
32 | vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
33 | else:
34 | vgg_avg_pooling.append(module)
35 |
36 | self.vgg = nn.Sequential(*vgg_avg_pooling)
37 |
38 | def do_normalize_inputs(self, x):
39 | return (x - self.mean_.to(x.device)) / self.std_.to(x.device)
40 |
41 | def partial_losses(self, input, target, mask=None):
42 | check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses')
43 |
44 | # we expect input and target to be in [0, 1] range
45 | losses = []
46 |
47 | if self.normalize_inputs:
48 | features_input = self.do_normalize_inputs(input)
49 | features_target = self.do_normalize_inputs(target)
50 | else:
51 | features_input = input
52 | features_target = target
53 |
54 | for layer in self.vgg[:30]:
55 |
56 | features_input = layer(features_input)
57 | features_target = layer(features_target)
58 |
59 | if layer.__class__.__name__ == 'ReLU':
60 | loss = F.mse_loss(features_input, features_target, reduction='none')
61 |
62 | if mask is not None:
63 | cur_mask = F.interpolate(mask, size=features_input.shape[-2:],
64 | mode='bilinear', align_corners=False)
65 | loss = loss * (1 - cur_mask)
66 |
67 | loss = loss.mean(dim=tuple(range(1, len(loss.shape))))
68 | losses.append(loss)
69 |
70 | return losses
71 |
72 | def forward(self, input, target, mask=None):
73 | losses = self.partial_losses(input, target, mask=mask)
74 | return torch.stack(losses).sum(dim=0)
75 |
76 | def get_global_features(self, input):
77 | check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features')
78 |
79 | if self.normalize_inputs:
80 | features_input = self.do_normalize_inputs(input)
81 | else:
82 | features_input = input
83 |
84 | features_input = self.vgg(features_input)
85 | return features_input
86 |
87 |
88 | class ResNetPL(nn.Module):
89 | def __init__(self, weight=1,
90 | weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
91 | super().__init__()
92 | self.impl = ModelBuilder.get_encoder(weights_path=weights_path,
93 | arch_encoder=arch_encoder,
94 | arch_decoder='ppm_deepsup',
95 | fc_dim=2048,
96 | segmentation=segmentation)
97 | self.impl.eval()
98 | for w in self.impl.parameters():
99 | w.requires_grad_(False)
100 |
101 | self.weight = weight
102 |
103 | def forward(self, pred, target):
104 | pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
105 | target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)
106 |
107 | pred_feats = self.impl(pred, return_feature_maps=True)
108 | target_feats = self.impl(target, return_feature_maps=True)
109 |
110 | result = torch.stack([F.mse_loss(cur_pred, cur_target)
111 | for cur_pred, cur_target
112 | in zip(pred_feats, target_feats)]).sum() * self.weight
113 | return result
114 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/losses/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .constants import weights as constant_weights
6 |
7 |
8 | class CrossEntropy2d(nn.Module):
9 | def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs):
10 | """
11 | weight (Tensor, optional): a manual rescaling weight given to each class.
12 | If given, has to be a Tensor of size "nclasses"
13 | """
14 | super(CrossEntropy2d, self).__init__()
15 | self.reduction = reduction
16 | self.ignore_label = ignore_label
17 | self.weights = weights
18 | if self.weights is not None:
19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 | self.weights = torch.FloatTensor(constant_weights[weights]).to(device)
21 |
22 | def forward(self, predict, target):
23 | """
24 | Args:
25 | predict:(n, c, h, w)
26 | target:(n, 1, h, w)
27 | """
28 | target = target.long()
29 | assert not target.requires_grad
30 | assert predict.dim() == 4, "{0}".format(predict.size())
31 | assert target.dim() == 4, "{0}".format(target.size())
32 | assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
33 | assert target.size(1) == 1, "{0}".format(target.size(1))
34 | assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2))
35 | assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3))
36 | target = target.squeeze(1)
37 | n, c, h, w = predict.size()
38 | target_mask = (target >= 0) * (target != self.ignore_label)
39 | target = target[target_mask]
40 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
41 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
42 | loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction)
43 | return loss
44 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/losses/style_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 |
5 |
6 | class PerceptualLoss(nn.Module):
7 | r"""
8 | Perceptual loss, VGG-based
9 | https://arxiv.org/abs/1603.08155
10 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
11 | """
12 |
13 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
14 | super(PerceptualLoss, self).__init__()
15 | self.add_module('vgg', VGG19())
16 | self.criterion = torch.nn.L1Loss()
17 | self.weights = weights
18 |
19 | def __call__(self, x, y):
20 | # Compute features
21 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
22 |
23 | content_loss = 0.0
24 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
25 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
26 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
27 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
28 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
29 |
30 |
31 | return content_loss
32 |
33 |
34 | class VGG19(torch.nn.Module):
35 | def __init__(self):
36 | super(VGG19, self).__init__()
37 | features = models.vgg19(pretrained=True).features
38 | self.relu1_1 = torch.nn.Sequential()
39 | self.relu1_2 = torch.nn.Sequential()
40 |
41 | self.relu2_1 = torch.nn.Sequential()
42 | self.relu2_2 = torch.nn.Sequential()
43 |
44 | self.relu3_1 = torch.nn.Sequential()
45 | self.relu3_2 = torch.nn.Sequential()
46 | self.relu3_3 = torch.nn.Sequential()
47 | self.relu3_4 = torch.nn.Sequential()
48 |
49 | self.relu4_1 = torch.nn.Sequential()
50 | self.relu4_2 = torch.nn.Sequential()
51 | self.relu4_3 = torch.nn.Sequential()
52 | self.relu4_4 = torch.nn.Sequential()
53 |
54 | self.relu5_1 = torch.nn.Sequential()
55 | self.relu5_2 = torch.nn.Sequential()
56 | self.relu5_3 = torch.nn.Sequential()
57 | self.relu5_4 = torch.nn.Sequential()
58 |
59 | for x in range(2):
60 | self.relu1_1.add_module(str(x), features[x])
61 |
62 | for x in range(2, 4):
63 | self.relu1_2.add_module(str(x), features[x])
64 |
65 | for x in range(4, 7):
66 | self.relu2_1.add_module(str(x), features[x])
67 |
68 | for x in range(7, 9):
69 | self.relu2_2.add_module(str(x), features[x])
70 |
71 | for x in range(9, 12):
72 | self.relu3_1.add_module(str(x), features[x])
73 |
74 | for x in range(12, 14):
75 | self.relu3_2.add_module(str(x), features[x])
76 |
77 | for x in range(14, 16):
78 | self.relu3_3.add_module(str(x), features[x])
79 |
80 | for x in range(16, 18):
81 | self.relu3_4.add_module(str(x), features[x])
82 |
83 | for x in range(18, 21):
84 | self.relu4_1.add_module(str(x), features[x])
85 |
86 | for x in range(21, 23):
87 | self.relu4_2.add_module(str(x), features[x])
88 |
89 | for x in range(23, 25):
90 | self.relu4_3.add_module(str(x), features[x])
91 |
92 | for x in range(25, 27):
93 | self.relu4_4.add_module(str(x), features[x])
94 |
95 | for x in range(27, 30):
96 | self.relu5_1.add_module(str(x), features[x])
97 |
98 | for x in range(30, 32):
99 | self.relu5_2.add_module(str(x), features[x])
100 |
101 | for x in range(32, 34):
102 | self.relu5_3.add_module(str(x), features[x])
103 |
104 | for x in range(34, 36):
105 | self.relu5_4.add_module(str(x), features[x])
106 |
107 | # don't need the gradients, just want the features
108 | for param in self.parameters():
109 | param.requires_grad = False
110 |
111 | def forward(self, x):
112 | relu1_1 = self.relu1_1(x)
113 | relu1_2 = self.relu1_2(relu1_1)
114 |
115 | relu2_1 = self.relu2_1(relu1_2)
116 | relu2_2 = self.relu2_2(relu2_1)
117 |
118 | relu3_1 = self.relu3_1(relu2_2)
119 | relu3_2 = self.relu3_2(relu3_1)
120 | relu3_3 = self.relu3_3(relu3_2)
121 | relu3_4 = self.relu3_4(relu3_3)
122 |
123 | relu4_1 = self.relu4_1(relu3_4)
124 | relu4_2 = self.relu4_2(relu4_1)
125 | relu4_3 = self.relu4_3(relu4_2)
126 | relu4_4 = self.relu4_4(relu4_3)
127 |
128 | relu5_1 = self.relu5_1(relu4_4)
129 | relu5_2 = self.relu5_2(relu5_1)
130 | relu5_3 = self.relu5_3(relu5_2)
131 | relu5_4 = self.relu5_4(relu5_3)
132 |
133 | out = {
134 | 'relu1_1': relu1_1,
135 | 'relu1_2': relu1_2,
136 |
137 | 'relu2_1': relu2_1,
138 | 'relu2_2': relu2_2,
139 |
140 | 'relu3_1': relu3_1,
141 | 'relu3_2': relu3_2,
142 | 'relu3_3': relu3_3,
143 | 'relu3_4': relu3_4,
144 |
145 | 'relu4_1': relu4_1,
146 | 'relu4_2': relu4_2,
147 | 'relu4_3': relu4_3,
148 | 'relu4_4': relu4_4,
149 |
150 | 'relu5_1': relu5_1,
151 | 'relu5_2': relu5_2,
152 | 'relu5_3': relu5_3,
153 | 'relu5_4': relu5_4,
154 | }
155 | return out
156 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/modules/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from bmab.external.lama.saicinpainting.training.modules.ffc import FFCResNetGenerator
4 | from bmab.external.lama.saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \
5 | NLayerDiscriminator, MultidilatedNLayerDiscriminator
6 |
7 |
8 | def make_generator(config, kind, **kwargs):
9 | logging.info(f'Make generator {kind}')
10 |
11 | if kind == 'pix2pixhd_multidilated':
12 | return MultiDilatedGlobalGenerator(**kwargs)
13 |
14 | if kind == 'pix2pixhd_global':
15 | return GlobalGenerator(**kwargs)
16 |
17 | if kind == 'ffc_resnet':
18 | return FFCResNetGenerator(**kwargs)
19 |
20 | raise ValueError(f'Unknown generator kind {kind}')
21 |
22 |
23 | def make_discriminator(kind, **kwargs):
24 | logging.info(f'Make discriminator {kind}')
25 |
26 | if kind == 'pix2pixhd_nlayer_multidilated':
27 | return MultidilatedNLayerDiscriminator(**kwargs)
28 |
29 | if kind == 'pix2pixhd_nlayer':
30 | return NLayerDiscriminator(**kwargs)
31 |
32 | raise ValueError(f'Unknown discriminator kind {kind}')
33 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/modules/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Tuple, List
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from bmab.external.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
8 | from bmab.external.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
9 |
10 |
11 | class BaseDiscriminator(nn.Module):
12 | @abc.abstractmethod
13 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
14 | """
15 | Predict scores and get intermediate activations. Useful for feature matching loss
16 | :return tuple (scores, list of intermediate activations)
17 | """
18 |         raise NotImplementedError()
19 |
20 |
21 | def get_conv_block_ctor(kind='default'):
22 | if not isinstance(kind, str):
23 | return kind
24 | if kind == 'default':
25 | return nn.Conv2d
26 | if kind == 'depthwise':
27 | return DepthWiseSeperableConv
28 | if kind == 'multidilated':
29 | return MultidilatedConv
30 | raise ValueError(f'Unknown convolutional block kind {kind}')
31 |
32 |
33 | def get_norm_layer(kind='bn'):
34 | if not isinstance(kind, str):
35 | return kind
36 | if kind == 'bn':
37 | return nn.BatchNorm2d
38 | if kind == 'in':
39 | return nn.InstanceNorm2d
40 | raise ValueError(f'Unknown norm block kind {kind}')
41 |
42 |
43 | def get_activation(kind='tanh'):
44 | if kind == 'tanh':
45 | return nn.Tanh()
46 | if kind == 'sigmoid':
47 | return nn.Sigmoid()
48 | if kind is False:
49 | return nn.Identity()
50 | raise ValueError(f'Unknown activation kind {kind}')
51 |
52 |
53 | class SimpleMultiStepGenerator(nn.Module):
54 | def __init__(self, steps: List[nn.Module]):
55 | super().__init__()
56 | self.steps = nn.ModuleList(steps)
57 |
58 | def forward(self, x):
59 | cur_in = x
60 | outs = []
61 | for step in self.steps:
62 | cur_out = step(cur_in)
63 | outs.append(cur_out)
64 | cur_in = torch.cat((cur_in, cur_out), dim=1)
65 | return torch.cat(outs[::-1], dim=1)
66 |
67 | def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
68 | if kind == 'convtranspose':
69 | return [nn.ConvTranspose2d(min(max_features, ngf * mult),
70 | min(max_features, int(ngf * mult / 2)),
71 | kernel_size=3, stride=2, padding=1, output_padding=1),
72 | norm_layer(min(max_features, int(ngf * mult / 2))), activation]
73 | elif kind == 'bilinear':
74 | return [nn.Upsample(scale_factor=2, mode='bilinear'),
75 | DepthWiseSeperableConv(min(max_features, ngf * mult),
76 | min(max_features, int(ngf * mult / 2)),
77 | kernel_size=3, stride=1, padding=1),
78 | norm_layer(min(max_features, int(ngf * mult / 2))), activation]
79 | else:
80 | raise Exception(f"Invalid deconv kind: {kind}")
--------------------------------------------------------------------------------
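A short sketch of the helpers above. The kind strings and deconv_factory arguments come straight from this file; the channel sizes are arbitrary, and the "ngf * mult" pattern mirrors how pix2pixhd-style decoders are usually assembled (an assumption about the callers):

import torch
import torch.nn as nn
from bmab.external.lama.saicinpainting.training.modules.base import (
    get_conv_block_ctor, get_norm_layer, get_activation, deconv_factory)

conv_ctor = get_conv_block_ctor('default')     # nn.Conv2d
norm_ctor = get_norm_layer('bn')               # nn.BatchNorm2d
out_act = get_activation('tanh')               # nn.Tanh()

# One upsampling stage: 128 channels in, 64 out, spatial size doubled.
up_block = nn.Sequential(*deconv_factory('convtranspose', ngf=64, mult=2,
                                          norm_layer=norm_ctor,
                                          activation=nn.ReLU(True),
                                          max_features=1024))
print(up_block(torch.rand(1, 128, 16, 16)).shape)   # torch.Size([1, 64, 32, 32])
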
/bmab/external/lama/saicinpainting/training/modules/depthwise_sep_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class DepthWiseSeperableConv(nn.Module):
5 | def __init__(self, in_dim, out_dim, *args, **kwargs):
6 | super().__init__()
7 | if 'groups' in kwargs:
8 | # ignoring groups for Depthwise Sep Conv
9 | del kwargs['groups']
10 |
11 | self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs)
12 | self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1)
13 |
14 | def forward(self, x):
15 | out = self.depthwise(x)
16 | out = self.pointwise(out)
17 | return out
--------------------------------------------------------------------------------
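A quick check of what the factorisation above buys: a depthwise 3x3 convolution followed by a 1x1 pointwise convolution uses far fewer parameters than a dense 3x3 convolution with the same in/out channels, while producing the same output shape. Channel sizes here are arbitrary:

import torch
import torch.nn as nn
from bmab.external.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv

sep = DepthWiseSeperableConv(64, 128, kernel_size=3, padding=1)
full = nn.Conv2d(64, 128, kernel_size=3, padding=1)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(sep), count(full))    # 8960 vs 73856 parameters

x = torch.rand(1, 64, 32, 32)
print(sep(x).shape)               # torch.Size([1, 128, 32, 32])
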
/bmab/external/lama/saicinpainting/training/modules/fake_fakes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kornia import SamplePadding
3 | from kornia.augmentation import RandomAffine, CenterCrop
4 |
5 |
6 | class FakeFakesGenerator:
7 | def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
8 | self.grad_aug = RandomAffine(degrees=360,
9 | translate=0.2,
10 | padding_mode=SamplePadding.REFLECTION,
11 | keepdim=False,
12 | p=1)
13 | self.img_aug = RandomAffine(degrees=img_aug_degree,
14 | translate=img_aug_translate,
15 | padding_mode=SamplePadding.REFLECTION,
16 | keepdim=True,
17 | p=1)
18 | self.aug_proba = aug_proba
19 |
20 | def __call__(self, input_images, masks):
21 | blend_masks = self._fill_masks_with_gradient(masks)
22 | blend_target = self._make_blend_target(input_images)
23 | result = input_images * (1 - blend_masks) + blend_target * blend_masks
24 | return result, blend_masks
25 |
26 | def _make_blend_target(self, input_images):
27 | batch_size = input_images.shape[0]
28 | permuted = input_images[torch.randperm(batch_size)]
29 | augmented = self.img_aug(input_images)
30 | is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float()
31 | result = augmented * is_aug + permuted * (1 - is_aug)
32 | return result
33 |
34 | def _fill_masks_with_gradient(self, masks):
35 | batch_size, _, height, width = masks.shape
36 | grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \
37 | .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
38 | grad = self.grad_aug(grad)
39 | grad = CenterCrop((height, width))(grad)
40 | grad *= masks
41 |
42 | grad_for_min = grad + (1 - masks) * 10
43 | grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None]
44 | grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6
45 | grad.clamp_(min=0, max=1)
46 |
47 | return grad
48 |
--------------------------------------------------------------------------------
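A minimal usage sketch for FakeFakesGenerator, assuming a kornia version in which the imports and RandomAffine arguments above construct as written; the tensors are random stand-ins for training images and hole masks:

import torch
from bmab.external.lama.saicinpainting.training.modules.fake_fakes import FakeFakesGenerator

gen = FakeFakesGenerator(aug_proba=0.5)
images = torch.rand(4, 3, 64, 64)                      # batch of RGB images in [0, 1]
masks = (torch.rand(4, 1, 64, 64) > 0.7).float()       # binary hole masks

fake_fakes, blend_masks = gen(images, masks)
print(fake_fakes.shape, blend_masks.shape)             # (4, 3, 64, 64) and (4, 1, 64, 64)
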
/bmab/external/lama/saicinpainting/training/modules/multidilated_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 | from bmab.external.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
5 |
6 | class MultidilatedConv(nn.Module):
7 | def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
8 | shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
9 | super().__init__()
10 | convs = []
11 | self.equal_dim = equal_dim
12 | assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
13 | if comb_mode in ('cat_out', 'cat_both'):
14 | self.cat_out = True
15 | if equal_dim:
16 | assert out_dim % dilation_num == 0
17 | out_dims = [out_dim // dilation_num] * dilation_num
18 | self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [])
19 | else:
20 | out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
21 | out_dims.append(out_dim - sum(out_dims))
22 | index = []
23 | starts = [0] + out_dims[:-1]
24 | lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
25 | for i in range(out_dims[-1]):
26 | for j in range(dilation_num):
27 | index += list(range(starts[j], starts[j] + lengths[j]))
28 | starts[j] += lengths[j]
29 | self.index = index
30 | assert(len(index) == out_dim)
31 | self.out_dims = out_dims
32 | else:
33 | self.cat_out = False
34 | self.out_dims = [out_dim] * dilation_num
35 |
36 | if comb_mode in ('cat_in', 'cat_both'):
37 | if equal_dim:
38 | assert in_dim % dilation_num == 0
39 | in_dims = [in_dim // dilation_num] * dilation_num
40 | else:
41 | in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
42 | in_dims.append(in_dim - sum(in_dims))
43 | self.in_dims = in_dims
44 | self.cat_in = True
45 | else:
46 | self.cat_in = False
47 | self.in_dims = [in_dim] * dilation_num
48 |
49 | conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
50 | dilation = min_dilation
51 | for i in range(dilation_num):
52 | if isinstance(padding, int):
53 | cur_padding = padding * dilation
54 | else:
55 | cur_padding = padding[i]
56 | convs.append(conv_type(
57 | self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
58 | ))
59 | if i > 0 and shared_weights:
60 | convs[-1].weight = convs[0].weight
61 | convs[-1].bias = convs[0].bias
62 | dilation *= 2
63 | self.convs = nn.ModuleList(convs)
64 |
65 | self.shuffle_in_channels = shuffle_in_channels
66 | if self.shuffle_in_channels:
67 | # shuffle list as shuffling of tensors is nondeterministic
68 | in_channels_permute = list(range(in_dim))
69 | random.shuffle(in_channels_permute)
70 | # save as buffer so it is saved and loaded with checkpoint
71 | self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))
72 |
73 | def forward(self, x):
74 | if self.shuffle_in_channels:
75 | x = x[:, self.in_channels_permute]
76 |
77 | outs = []
78 | if self.cat_in:
79 | if self.equal_dim:
80 | x = x.chunk(len(self.convs), dim=1)
81 | else:
82 | new_x = []
83 | start = 0
84 | for dim in self.in_dims:
85 | new_x.append(x[:, start:start+dim])
86 | start += dim
87 | x = new_x
88 | for i, conv in enumerate(self.convs):
89 | if self.cat_in:
90 | input = x[i]
91 | else:
92 | input = x
93 | outs.append(conv(input))
94 | if self.cat_out:
95 | out = torch.cat(outs, dim=1)[:, self.index]
96 | else:
97 | out = sum(outs)
98 | return out
99 |
--------------------------------------------------------------------------------
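A minimal sketch of MultidilatedConv in its simplest configuration (comb_mode='sum'): three parallel 3x3 convolutions with dilations 1, 2 and 4, each padded so the spatial size is preserved, whose outputs are summed:

import torch
from bmab.external.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv

conv = MultidilatedConv(in_dim=16, out_dim=32, kernel_size=3, dilation_num=3,
                        comb_mode='sum', padding=1)
x = torch.rand(2, 16, 64, 64)
print(conv(x).shape)            # torch.Size([2, 32, 64, 64]) - spatial size preserved
print(len(conv.convs))          # 3 branches, dilation doubling per branch
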
/bmab/external/lama/saicinpainting/training/modules/multiscale.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from bmab.external.lama.saicinpainting.training.modules.base import get_conv_block_ctor, get_activation
8 | from bmab.external.lama.saicinpainting.training.modules.pix2pixhd import ResnetBlock
9 |
10 |
11 | class ResNetHead(nn.Module):
12 | def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
13 | padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)):
14 | assert (n_blocks >= 0)
15 | super(ResNetHead, self).__init__()
16 |
17 | conv_layer = get_conv_block_ctor(conv_kind)
18 |
19 | model = [nn.ReflectionPad2d(3),
20 | conv_layer(input_nc, ngf, kernel_size=7, padding=0),
21 | norm_layer(ngf),
22 | activation]
23 |
24 | ### downsample
25 | for i in range(n_downsampling):
26 | mult = 2 ** i
27 | model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
28 | norm_layer(ngf * mult * 2),
29 | activation]
30 |
31 | mult = 2 ** n_downsampling
32 |
33 | ### resnet blocks
34 | for i in range(n_blocks):
35 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
36 | conv_kind=conv_kind)]
37 |
38 | self.model = nn.Sequential(*model)
39 |
40 | def forward(self, input):
41 | return self.model(input)
42 |
43 |
44 | class ResNetTail(nn.Module):
45 | def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
46 | padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
47 | up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
48 | add_in_proj=None):
49 | assert (n_blocks >= 0)
50 | super(ResNetTail, self).__init__()
51 |
52 | mult = 2 ** n_downsampling
53 |
54 | model = []
55 |
56 | if add_in_proj is not None:
57 | model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1))
58 |
59 | ### resnet blocks
60 | for i in range(n_blocks):
61 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
62 | conv_kind=conv_kind)]
63 |
64 | ### upsample
65 | for i in range(n_downsampling):
66 | mult = 2 ** (n_downsampling - i)
67 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
68 | output_padding=1),
69 | up_norm_layer(int(ngf * mult / 2)),
70 | up_activation]
71 | self.model = nn.Sequential(*model)
72 |
73 | out_layers = []
74 | for _ in range(out_extra_layers_n):
75 | out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0),
76 | up_norm_layer(ngf),
77 | up_activation]
78 | out_layers += [nn.ReflectionPad2d(3),
79 | nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
80 |
81 | if add_out_act:
82 | out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act))
83 |
84 | self.out_proj = nn.Sequential(*out_layers)
85 |
86 | def forward(self, input, return_last_act=False):
87 | features = self.model(input)
88 | out = self.out_proj(features)
89 | if return_last_act:
90 | return out, features
91 | else:
92 | return out
93 |
94 |
95 | class MultiscaleResNet(nn.Module):
96 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3,
97 | norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
98 | up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
99 | out_cumulative=False, return_only_hr=False):
100 | super().__init__()
101 |
102 | self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling,
103 | n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type,
104 | conv_kind=conv_kind, activation=activation)
105 | for i in range(n_scales)])
106 | tail_in_feats = ngf * (2 ** n_downsampling) + ngf
107 | self.tails = nn.ModuleList([ResNetTail(output_nc,
108 | ngf=ngf, n_downsampling=n_downsampling,
109 | n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type,
110 | conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer,
111 | up_activation=up_activation, add_out_act=add_out_act,
112 | out_extra_layers_n=out_extra_layers_n,
113 | add_in_proj=None if (i == n_scales - 1) else tail_in_feats)
114 | for i in range(n_scales)])
115 |
116 | self.out_cumulative = out_cumulative
117 | self.return_only_hr = return_only_hr
118 |
119 | @property
120 | def num_scales(self):
121 | return len(self.heads)
122 |
123 | def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
124 | -> Union[torch.Tensor, List[torch.Tensor]]:
125 | """
126 | :param ms_inputs: List of inputs of different resolutions from HR to LR
127 | :param smallest_scales_num: int or None, number of smallest scales to take at input
128 | :return: Depending on return_only_hr:
129 | True: Only the most HR output
130 | False: List of outputs of different resolutions from HR to LR
131 | """
132 | if smallest_scales_num is None:
133 | assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num)
134 | smallest_scales_num = len(self.heads)
135 | else:
136 | assert smallest_scales_num == len(ms_inputs) <= len(self.heads), (len(self.heads), len(ms_inputs), smallest_scales_num)
137 |
138 | cur_heads = self.heads[-smallest_scales_num:]
139 | ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)]
140 |
141 | all_outputs = []
142 | prev_tail_features = None
143 | for i in range(len(ms_features)):
144 | scale_i = -i - 1
145 |
146 | cur_tail_input = ms_features[-i - 1]
147 | if prev_tail_features is not None:
148 | if prev_tail_features.shape != cur_tail_input.shape:
149 | prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:],
150 | mode='bilinear', align_corners=False)
151 | cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1)
152 |
153 | cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True)
154 |
155 | prev_tail_features = cur_tail_feats
156 | all_outputs.append(cur_out)
157 |
158 | if self.out_cumulative:
159 | all_outputs_cum = [all_outputs[0]]
160 | for i in range(1, len(ms_features)):
161 | cur_out = all_outputs[i]
162 | cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:],
163 | mode='bilinear', align_corners=False)
164 | all_outputs_cum.append(cur_out_cum)
165 | all_outputs = all_outputs_cum
166 |
167 | if self.return_only_hr:
168 | return all_outputs[-1]
169 | else:
170 | return all_outputs[::-1]
171 |
172 |
173 | class MultiscaleDiscriminatorSimple(nn.Module):
174 | def __init__(self, ms_impl):
175 | super().__init__()
176 | self.ms_impl = nn.ModuleList(ms_impl)
177 |
178 | @property
179 | def num_scales(self):
180 | return len(self.ms_impl)
181 |
182 | def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
183 | -> List[Tuple[torch.Tensor, List[torch.Tensor]]]:
184 | """
185 | :param ms_inputs: List of inputs of different resolutions from HR to LR
186 | :param smallest_scales_num: int or None, number of smallest scales to take at input
187 | :return: List of pairs (prediction, features) for different resolutions from HR to LR
188 | """
189 | if smallest_scales_num is None:
190 | assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
191 |             smallest_scales_num = len(self.ms_impl)
192 | else:
193 | assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \
194 | (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
195 |
196 | return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)]
197 |
198 |
199 | class SingleToMultiScaleInputMixin:
200 | def forward(self, x: torch.Tensor) -> List:
201 | orig_height, orig_width = x.shape[2:]
202 | factors = [2 ** i for i in range(self.num_scales)]
203 | ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False)
204 | for f in factors]
205 | return super().forward(ms_inputs)
206 |
207 |
208 | class GeneratorMultiToSingleOutputMixin:
209 | def forward(self, x):
210 | return super().forward(x)[0]
211 |
212 |
213 | class DiscriminatorMultiToSingleOutputMixin:
214 | def forward(self, x):
215 | out_feat_tuples = super().forward(x)
216 | return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist]
217 |
218 |
219 | class DiscriminatorMultiToSingleOutputStackedMixin:
220 | def __init__(self, *args, return_feats_only_levels=None, **kwargs):
221 | super().__init__(*args, **kwargs)
222 | self.return_feats_only_levels = return_feats_only_levels
223 |
224 | def forward(self, x):
225 | out_feat_tuples = super().forward(x)
226 | outs = [out for out, _ in out_feat_tuples]
227 | scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:],
228 | mode='bilinear', align_corners=False)
229 | for cur_out in outs[1:]]
230 | out = torch.cat(scaled_outs, dim=1)
231 | if self.return_feats_only_levels is not None:
232 | feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels]
233 | else:
234 | feat_lists = [flist for _, flist in out_feat_tuples]
235 | feats = [f for flist in feat_lists for f in flist]
236 | return out, feats
237 |
238 |
239 | class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple):
240 | pass
241 |
242 |
243 | class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet):
244 | pass
245 |
--------------------------------------------------------------------------------
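The mixins at the end of this file are easiest to read from the input side: SingleToMultiScaleInputMixin turns one HR tensor into the HR-to-LR pyramid the multiscale modules expect. A standalone sketch of that pyramid construction, using the same interpolation call:

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 256, 256)
num_scales = 3
factors = [2 ** i for i in range(num_scales)]          # [1, 2, 4]
ms_inputs = [F.interpolate(x, size=(x.shape[2] // f, x.shape[3] // f),
                           mode='bilinear', align_corners=False)
             for f in factors]
print([tuple(t.shape[2:]) for t in ms_inputs])          # [(256, 256), (128, 128), (64, 64)]
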
/bmab/external/lama/saicinpainting/training/modules/spatial_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from kornia.geometry.transform import rotate
5 |
6 |
7 | class LearnableSpatialTransformWrapper(nn.Module):
8 | def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
9 | super().__init__()
10 | self.impl = impl
11 | self.angle = torch.rand(1) * angle_init_range
12 | if train_angle:
13 | self.angle = nn.Parameter(self.angle, requires_grad=True)
14 | self.pad_coef = pad_coef
15 |
16 | def forward(self, x):
17 | if torch.is_tensor(x):
18 | return self.inverse_transform(self.impl(self.transform(x)), x)
19 | elif isinstance(x, tuple):
20 | x_trans = tuple(self.transform(elem) for elem in x)
21 | y_trans = self.impl(x_trans)
22 | return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x))
23 | else:
24 | raise ValueError(f'Unexpected input type {type(x)}')
25 |
26 | def transform(self, x):
27 | height, width = x.shape[2:]
28 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
29 | x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
30 | x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded))
31 | return x_padded_rotated
32 |
33 | def inverse_transform(self, y_padded_rotated, orig_x):
34 | height, width = orig_x.shape[2:]
35 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
36 |
37 | y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
38 | y_height, y_width = y_padded.shape[2:]
39 | y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
40 | return y
41 |
42 |
43 | if __name__ == '__main__':
44 | layer = LearnableSpatialTransformWrapper(nn.Identity())
45 |     x = torch.arange(2 * 3 * 15 * 15).view(2, 3, 15, 15).float()
46 | y = layer(x)
47 | assert x.shape == y.shape
48 | assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1])
49 | print('all ok')
50 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/modules/squeeze_excitation.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class SELayer(nn.Module):
5 | def __init__(self, channel, reduction=16):
6 | super(SELayer, self).__init__()
7 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
8 | self.fc = nn.Sequential(
9 | nn.Linear(channel, channel // reduction, bias=False),
10 | nn.ReLU(inplace=True),
11 | nn.Linear(channel // reduction, channel, bias=False),
12 | nn.Sigmoid()
13 | )
14 |
15 | def forward(self, x):
16 | b, c, _, _ = x.size()
17 | y = self.avg_pool(x).view(b, c)
18 | y = self.fc(y).view(b, c, 1, 1)
19 | res = x * y.expand_as(x)
20 | return res
21 |
--------------------------------------------------------------------------------
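A quick shape check for SELayer: the block learns per-channel attention weights from global average pooling and rescales the input, so output and input shapes match. Sizes here are arbitrary:

import torch
from bmab.external.lama.saicinpainting.training.modules.squeeze_excitation import SELayer

se = SELayer(channel=64, reduction=16)
x = torch.rand(2, 64, 32, 32)
print(se(x).shape)              # torch.Size([2, 64, 32, 32]) - same shape, channel-wise rescaled
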
/bmab/external/lama/saicinpainting/training/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from bmab.external.lama.saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule
4 |
5 |
6 | def get_training_model_class(kind):
7 | if kind == 'default':
8 | return DefaultInpaintingTrainingModule
9 |
10 | raise ValueError(f'Unknown trainer module {kind}')
11 |
12 |
13 | def make_training_model(config):
14 | kind = config.training_model.kind
15 | kwargs = dict(config.training_model)
16 | kwargs.pop('kind')
17 | kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp'
18 |
19 | logging.info(f'Make training model {kind}')
20 |
21 | cls = get_training_model_class(kind)
22 | return cls(config, **kwargs)
23 |
24 |
25 | def load_checkpoint(train_config, path, map_location='cuda', strict=True):
26 | model = make_training_model(train_config).generator
27 | state = torch.load(path, map_location=map_location)
28 | model.load_state_dict(state, strict=strict)
29 | return model
30 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/trainers/default.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from omegaconf import OmegaConf
6 |
7 | # from bmab.external.lama.saicinpainting.training.data.datasets import make_constant_area_crop_params
8 | from bmab.external.lama.saicinpainting.training.losses.distance_weighting import make_mask_distance_weighter
9 | from bmab.external.lama.saicinpainting.training.losses.feature_matching import feature_matching_loss, masked_l1_loss
10 | # from bmab.external.lama.saicinpainting.training.modules.fake_fakes import FakeFakesGenerator
11 | from bmab.external.lama.saicinpainting.training.trainers.base import BaseInpaintingTrainingModule, make_multiscale_noise
12 | from bmab.external.lama.saicinpainting.utils import add_prefix_to_keys, get_ramp
13 |
14 | LOGGER = logging.getLogger(__name__)
15 |
16 |
17 | def make_constant_area_crop_batch(batch, **kwargs):
18 | crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
19 | img_width=batch['image'].shape[3],
20 | **kwargs)
21 | batch['image'] = batch['image'][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width]
22 | batch['mask'] = batch['mask'][:, :, crop_y: crop_y + crop_height, crop_x: crop_x + crop_width]
23 | return batch
24 |
25 |
26 | class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
27 | def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
28 | add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
29 | distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
30 | fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
31 | **kwargs):
32 | super().__init__(*args, **kwargs)
33 | self.concat_mask = concat_mask
34 | self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
35 | self.image_to_discriminator = image_to_discriminator
36 | self.add_noise_kwargs = add_noise_kwargs
37 | self.noise_fill_hole = noise_fill_hole
38 | self.const_area_crop_kwargs = const_area_crop_kwargs
39 | self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \
40 | if distance_weighter_kwargs is not None else None
41 | self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
42 |
43 | self.fake_fakes_proba = fake_fakes_proba
44 | if self.fake_fakes_proba > 1e-3:
45 | self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
46 |
47 | def forward(self, batch):
48 | if self.training and self.rescale_size_getter is not None:
49 | cur_size = self.rescale_size_getter(self.global_step)
50 | batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
51 | batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
52 |
53 | if self.training and self.const_area_crop_kwargs is not None:
54 | batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
55 |
56 | img = batch['image']
57 | mask = batch['mask']
58 |
59 | masked_img = img * (1 - mask)
60 |
61 | if self.add_noise_kwargs is not None:
62 | noise = make_multiscale_noise(masked_img, **self.add_noise_kwargs)
63 | if self.noise_fill_hole:
64 | masked_img = masked_img + mask * noise[:, :masked_img.shape[1]]
65 | masked_img = torch.cat([masked_img, noise], dim=1)
66 |
67 | if self.concat_mask:
68 | masked_img = torch.cat([masked_img, mask], dim=1)
69 |
70 | batch['predicted_image'] = self.generator(masked_img)
71 | batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image']
72 |
73 | if self.fake_fakes_proba > 1e-3:
74 | if self.training and torch.rand(1).item() < self.fake_fakes_proba:
75 | batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask)
76 | batch['use_fake_fakes'] = True
77 | else:
78 | batch['fake_fakes'] = torch.zeros_like(img)
79 | batch['fake_fakes_masks'] = torch.zeros_like(mask)
80 | batch['use_fake_fakes'] = False
81 |
82 | batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \
83 | if self.refine_mask_for_losses is not None and self.training \
84 | else mask
85 |
86 | return batch
87 |
88 | def generator_loss(self, batch):
89 | img = batch['image']
90 | predicted_img = batch[self.image_to_discriminator]
91 | original_mask = batch['mask']
92 | supervised_mask = batch['mask_for_losses']
93 |
94 | # L1
95 | l1_value = masked_l1_loss(predicted_img, img, supervised_mask,
96 | self.config.losses.l1.weight_known,
97 | self.config.losses.l1.weight_missing)
98 |
99 | total_loss = l1_value
100 | metrics = dict(gen_l1=l1_value)
101 |
102 | # vgg-based perceptual loss
103 | if self.config.losses.perceptual.weight > 0:
104 | pl_value = self.loss_pl(predicted_img, img, mask=supervised_mask).sum() * self.config.losses.perceptual.weight
105 | total_loss = total_loss + pl_value
106 | metrics['gen_pl'] = pl_value
107 |
108 | # discriminator
109 | # adversarial_loss calls backward by itself
110 | mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask
111 | self.adversarial_loss.pre_generator_step(real_batch=img, fake_batch=predicted_img,
112 | generator=self.generator, discriminator=self.discriminator)
113 | discr_real_pred, discr_real_features = self.discriminator(img)
114 | discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
115 | adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(real_batch=img,
116 | fake_batch=predicted_img,
117 | discr_real_pred=discr_real_pred,
118 | discr_fake_pred=discr_fake_pred,
119 | mask=mask_for_discr)
120 | total_loss = total_loss + adv_gen_loss
121 | metrics['gen_adv'] = adv_gen_loss
122 | metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
123 |
124 | # feature matching
125 | if self.config.losses.feature_matching.weight > 0:
126 | need_mask_in_fm = OmegaConf.to_container(self.config.losses.feature_matching).get('pass_mask', False)
127 | mask_for_fm = supervised_mask if need_mask_in_fm else None
128 | fm_value = feature_matching_loss(discr_fake_features, discr_real_features,
129 | mask=mask_for_fm) * self.config.losses.feature_matching.weight
130 | total_loss = total_loss + fm_value
131 | metrics['gen_fm'] = fm_value
132 |
133 | if self.loss_resnet_pl is not None:
134 | resnet_pl_value = self.loss_resnet_pl(predicted_img, img)
135 | total_loss = total_loss + resnet_pl_value
136 | metrics['gen_resnet_pl'] = resnet_pl_value
137 |
138 | return total_loss, metrics
139 |
140 | def discriminator_loss(self, batch):
141 | total_loss = 0
142 | metrics = {}
143 |
144 | predicted_img = batch[self.image_to_discriminator].detach()
145 | self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=predicted_img,
146 | generator=self.generator, discriminator=self.discriminator)
147 | discr_real_pred, discr_real_features = self.discriminator(batch['image'])
148 | discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
149 | adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(real_batch=batch['image'],
150 | fake_batch=predicted_img,
151 | discr_real_pred=discr_real_pred,
152 | discr_fake_pred=discr_fake_pred,
153 | mask=batch['mask'])
154 | total_loss = total_loss + adv_discr_loss
155 | metrics['discr_adv'] = adv_discr_loss
156 | metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
157 |
158 |
159 | if batch.get('use_fake_fakes', False):
160 | fake_fakes = batch['fake_fakes']
161 | self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=fake_fakes,
162 | generator=self.generator, discriminator=self.discriminator)
163 | discr_fake_fakes_pred, _ = self.discriminator(fake_fakes)
164 | fake_fakes_adv_discr_loss, fake_fakes_adv_metrics = self.adversarial_loss.discriminator_loss(
165 | real_batch=batch['image'],
166 | fake_batch=fake_fakes,
167 | discr_real_pred=discr_real_pred,
168 | discr_fake_pred=discr_fake_fakes_pred,
169 | mask=batch['mask']
170 | )
171 | total_loss = total_loss + fake_fakes_adv_discr_loss
172 | metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss
173 | metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_'))
174 |
175 | return total_loss, metrics
176 |
--------------------------------------------------------------------------------
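A standalone sketch of the masking arithmetic used in DefaultInpaintingTrainingModule.forward: the hole region is blanked before the generator runs, and the prediction is pasted back only inside the hole, so known pixels are preserved exactly. The tensors below are random stand-ins rather than real model outputs:

import torch

img = torch.rand(1, 3, 64, 64)
mask = (torch.rand(1, 1, 64, 64) > 0.8).float()     # 1 inside the hole, 0 elsewhere

masked_img = img * (1 - mask)                       # generator input (optionally + mask channel)
predicted = torch.rand(1, 3, 64, 64)                # stands in for self.generator(masked_img)
inpainted = mask * predicted + (1 - mask) * img     # known pixels are kept verbatim

assert torch.allclose(inpainted * (1 - mask), img * (1 - mask))
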
/bmab/external/lama/saicinpainting/training/visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from bmab.external.lama.saicinpainting.training.visualizers.directory import DirectoryVisualizer
4 | from bmab.external.lama.saicinpainting.training.visualizers.noop import NoopVisualizer
5 |
6 |
7 | def make_visualizer(kind, **kwargs):
8 | logging.info(f'Make visualizer {kind}')
9 |
10 | if kind == 'directory':
11 | return DirectoryVisualizer(**kwargs)
12 | if kind == 'noop':
13 | return NoopVisualizer()
14 |
15 | raise ValueError(f'Unknown visualizer kind {kind}')
16 |
--------------------------------------------------------------------------------
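A minimal sketch of the visualizer factory above; 'noop' needs no arguments, while 'directory' takes the DirectoryVisualizer keyword arguments (here just outdir, pointing at a throwaway temporary directory):

import tempfile
from bmab.external.lama.saicinpainting.training.visualizers import make_visualizer

noop_vis = make_visualizer('noop')                                  # does nothing when called
dir_vis = make_visualizer('directory', outdir=tempfile.mkdtemp())   # writes JPEG grids under outdir
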
/bmab/external/lama/saicinpainting/training/visualizers/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Dict, List
3 |
4 | import numpy as np
5 | import torch
6 | from skimage import color
7 | from skimage.segmentation import mark_boundaries
8 |
9 | from . import colors
10 |
11 | COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation
12 |
13 |
14 | class BaseVisualizer:
15 | @abc.abstractmethod
16 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
17 | """
18 | Take a batch, make an image from it and visualize
19 | """
20 | raise NotImplementedError()
21 |
22 |
23 | def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str],
24 | last_without_mask=True, rescale_keys=None, mask_only_first=None,
25 | black_mask=False) -> np.ndarray:
26 | mask = images_dict['mask'] > 0.5
27 | result = []
28 | for i, k in enumerate(keys):
29 | img = images_dict[k]
30 | img = np.transpose(img, (1, 2, 0))
31 |
32 | if rescale_keys is not None and k in rescale_keys:
33 | img = img - img.min()
34 | img /= img.max() + 1e-5
35 | if len(img.shape) == 2:
36 | img = np.expand_dims(img, 2)
37 |
38 | if img.shape[2] == 1:
39 | img = np.repeat(img, 3, axis=2)
40 | elif (img.shape[2] > 3):
41 | img_classes = img.argmax(2)
42 | img = color.label2rgb(img_classes, colors=COLORS)
43 |
44 | if mask_only_first:
45 | need_mark_boundaries = i == 0
46 | else:
47 | need_mark_boundaries = i < len(keys) - 1 or not last_without_mask
48 |
49 | if need_mark_boundaries:
50 | if black_mask:
51 | img = img * (1 - mask[0][..., None])
52 | img = mark_boundaries(img,
53 | mask[0],
54 | color=(1., 0., 0.),
55 | outline_color=(1., 1., 1.),
56 | mode='thick')
57 | result.append(img)
58 | return np.concatenate(result, axis=1)
59 |
60 |
61 | def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10,
62 | last_without_mask=True, rescale_keys=None) -> np.ndarray:
63 | batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items()
64 | if k in keys or k == 'mask'}
65 |
66 | batch_size = next(iter(batch.values())).shape[0]
67 | items_to_vis = min(batch_size, max_items)
68 | result = []
69 | for i in range(items_to_vis):
70 | cur_dct = {k: tens[i] for k, tens in batch.items()}
71 | result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask,
72 | rescale_keys=rescale_keys))
73 | return np.concatenate(result, axis=0)
74 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/visualizers/colors.py:
--------------------------------------------------------------------------------
1 | import random
2 | import colorsys
3 |
4 | import numpy as np
5 | import matplotlib
6 | matplotlib.use('agg')
7 | import matplotlib.pyplot as plt
8 | from matplotlib.colors import LinearSegmentedColormap
9 |
10 |
11 | def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False):
12 | # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib
13 | """
14 | Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
15 | :param nlabels: Number of labels (size of colormap)
16 | :param type: 'bright' for strong colors, 'soft' for pastel colors
17 | :param first_color_black: Option to use first color as black, True or False
18 | :param last_color_black: Option to use last color as black, True or False
19 | :param verbose: Prints the number of labels and shows the colormap. True or False
20 | :return: colormap for matplotlib
21 | """
22 | if type not in ('bright', 'soft'):
23 |         print('Please choose "bright" or "soft" for type')
24 | return
25 |
26 | if verbose:
27 | print('Number of labels: ' + str(nlabels))
28 |
29 | # Generate color map for bright colors, based on hsv
30 | if type == 'bright':
31 | randHSVcolors = [(np.random.uniform(low=0.0, high=1),
32 | np.random.uniform(low=0.2, high=1),
33 | np.random.uniform(low=0.9, high=1)) for i in range(nlabels)]
34 |
35 | # Convert HSV list to RGB
36 | randRGBcolors = []
37 | for HSVcolor in randHSVcolors:
38 | randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]))
39 |
40 | if first_color_black:
41 | randRGBcolors[0] = [0, 0, 0]
42 |
43 | if last_color_black:
44 | randRGBcolors[-1] = [0, 0, 0]
45 |
46 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
47 |
48 | # Generate soft pastel colors, by limiting the RGB spectrum
49 | if type == 'soft':
50 | low = 0.6
51 | high = 0.95
52 | randRGBcolors = [(np.random.uniform(low=low, high=high),
53 | np.random.uniform(low=low, high=high),
54 | np.random.uniform(low=low, high=high)) for i in range(nlabels)]
55 |
56 | if first_color_black:
57 | randRGBcolors[0] = [0, 0, 0]
58 |
59 | if last_color_black:
60 | randRGBcolors[-1] = [0, 0, 0]
61 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
62 |
63 | # Display colorbar
64 | if verbose:
65 | from matplotlib import colors, colorbar
66 | from matplotlib import pyplot as plt
67 | fig, ax = plt.subplots(1, 1, figsize=(15, 0.5))
68 |
69 | bounds = np.linspace(0, nlabels, nlabels + 1)
70 | norm = colors.BoundaryNorm(bounds, nlabels)
71 |
72 | cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
73 | boundaries=bounds, format='%1i', orientation=u'horizontal')
74 |
75 | return randRGBcolors, random_colormap
76 |
77 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/visualizers/directory.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 |
6 | from bmab.external.lama.saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch
7 | from bmab.external.lama.saicinpainting.utils import check_and_warn_input_range
8 |
9 |
10 | class DirectoryVisualizer(BaseVisualizer):
11 | DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ')
12 |
13 | def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10,
14 | last_without_mask=True, rescale_keys=None):
15 | self.outdir = outdir
16 | os.makedirs(self.outdir, exist_ok=True)
17 | self.key_order = key_order
18 | self.max_items_in_batch = max_items_in_batch
19 | self.last_without_mask = last_without_mask
20 | self.rescale_keys = rescale_keys
21 |
22 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
23 | check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image')
24 | vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch,
25 | last_without_mask=self.last_without_mask,
26 | rescale_keys=self.rescale_keys)
27 |
28 | vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
29 |
30 | curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}')
31 | os.makedirs(curoutdir, exist_ok=True)
32 | rank_suffix = f'_r{rank}' if rank is not None else ''
33 | out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg')
34 |
35 | vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
36 | cv2.imwrite(out_fname, vis_img)
37 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/training/visualizers/noop.py:
--------------------------------------------------------------------------------
1 | from bmab.external.lama.saicinpainting.training.visualizers.base import BaseVisualizer
2 |
3 |
4 | class NoopVisualizer(BaseVisualizer):
5 | def __init__(self, *args, **kwargs):
6 | pass
7 |
8 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
9 | pass
10 |
--------------------------------------------------------------------------------
/bmab/external/lama/saicinpainting/utils.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import functools
3 | import logging
4 | import numbers
5 | import os
6 | import signal
7 | import sys
8 | import traceback
9 | import warnings
10 |
11 | import torch
12 | from pytorch_lightning import seed_everything
13 |
14 | LOGGER = logging.getLogger(__name__)
15 |
16 |
17 | def check_and_warn_input_range(tensor, min_value, max_value, name):
18 | actual_min = tensor.min()
19 | actual_max = tensor.max()
20 | if actual_min < min_value or actual_max > max_value:
21 |         warnings.warn(f"{name} must be in the {min_value}..{max_value} range, but its values span {actual_min}..{actual_max}")
22 |
23 |
24 | def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
25 | for k, v in cur_dict.items():
26 | target_key = prefix + k
27 | target[target_key] = target.get(target_key, default) + v
28 |
29 |
30 | def average_dicts(dict_list):
31 | result = {}
32 | norm = 1e-3
33 | for dct in dict_list:
34 | sum_dict_with_prefix(result, dct, '')
35 | norm += 1
36 | for k in list(result):
37 | result[k] /= norm
38 | return result
39 |
40 |
41 | def add_prefix_to_keys(dct, prefix):
42 | return {prefix + k: v for k, v in dct.items()}
43 |
44 |
45 | def set_requires_grad(module, value):
46 | for param in module.parameters():
47 | param.requires_grad = value
48 |
49 |
50 | def flatten_dict(dct):
51 | result = {}
52 | for k, v in dct.items():
53 | if isinstance(k, tuple):
54 | k = '_'.join(k)
55 | if isinstance(v, dict):
56 | for sub_k, sub_v in flatten_dict(v).items():
57 | result[f'{k}_{sub_k}'] = sub_v
58 | else:
59 | result[k] = v
60 | return result
61 |
62 |
63 | class LinearRamp:
64 | def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
65 | self.start_value = start_value
66 | self.end_value = end_value
67 | self.start_iter = start_iter
68 | self.end_iter = end_iter
69 |
70 | def __call__(self, i):
71 | if i < self.start_iter:
72 | return self.start_value
73 | if i >= self.end_iter:
74 | return self.end_value
75 | part = (i - self.start_iter) / (self.end_iter - self.start_iter)
76 | return self.start_value * (1 - part) + self.end_value * part
77 |
78 |
79 | class LadderRamp:
80 | def __init__(self, start_iters, values):
81 | self.start_iters = start_iters
82 | self.values = values
83 | assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))
84 |
85 | def __call__(self, i):
86 | segment_i = bisect.bisect_right(self.start_iters, i)
87 | return self.values[segment_i]
88 |
89 |
90 | def get_ramp(kind='ladder', **kwargs):
91 | if kind == 'linear':
92 | return LinearRamp(**kwargs)
93 | if kind == 'ladder':
94 | return LadderRamp(**kwargs)
95 | raise ValueError(f'Unexpected ramp kind: {kind}')
96 |
97 |
98 | def print_traceback_handler(sig, frame):
99 | LOGGER.warning(f'Received signal {sig}')
100 | bt = ''.join(traceback.format_stack())
101 | LOGGER.warning(f'Requested stack trace:\n{bt}')
102 |
103 |
104 | def register_debug_signal_handlers(sig=None, handler=print_traceback_handler):
105 | LOGGER.warning(f'Setting signal {sig} handler {handler}')
106 | signal.signal(sig, handler)
107 |
108 |
109 | def handle_deterministic_config(config):
110 | seed = dict(config).get('seed', None)
111 | if seed is None:
112 | return False
113 |
114 | seed_everything(seed)
115 | return True
116 |
117 |
118 | def get_shape(t):
119 | if torch.is_tensor(t):
120 | return tuple(t.shape)
121 | elif isinstance(t, dict):
122 | return {n: get_shape(q) for n, q in t.items()}
123 | elif isinstance(t, (list, tuple)):
124 | return [get_shape(q) for q in t]
125 | elif isinstance(t, numbers.Number):
126 | return type(t)
127 | else:
128 | raise ValueError('unexpected type {}'.format(type(t)))
129 |
130 |
131 | def get_has_ddp_rank():
132 | master_port = os.environ.get('MASTER_PORT', None)
133 | node_rank = os.environ.get('NODE_RANK', None)
134 | local_rank = os.environ.get('LOCAL_RANK', None)
135 | world_size = os.environ.get('WORLD_SIZE', None)
136 | has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None
137 | return has_rank
138 |
139 |
140 | def handle_ddp_subprocess():
141 | def main_decorator(main_func):
142 | @functools.wraps(main_func)
143 | def new_main(*args, **kwargs):
144 | # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
145 | parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
146 | has_parent = parent_cwd is not None
147 | has_rank = get_has_ddp_rank()
148 | assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
149 |
150 | if has_parent:
151 | # we are in the worker
152 | sys.argv.extend([
153 | f'hydra.run.dir={parent_cwd}',
154 | # 'hydra/hydra_logging=disabled',
155 | # 'hydra/job_logging=disabled'
156 | ])
157 | # do nothing if this is a top-level process
158 | # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization
159 |
160 | main_func(*args, **kwargs)
161 | return new_main
162 | return main_decorator
163 |
164 |
165 | def handle_ddp_parent_process():
166 | parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
167 | has_parent = parent_cwd is not None
168 | has_rank = get_has_ddp_rank()
169 | assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
170 |
171 | if parent_cwd is None:
172 | os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()
173 |
174 | return has_parent
175 |
--------------------------------------------------------------------------------
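A small usage sketch for the ramp helpers above (importing this module also pulls in torch and pytorch_lightning, which are assumed to be installed); the iteration numbers and values are arbitrary:

from bmab.external.lama.saicinpainting.utils import LinearRamp, LadderRamp, get_ramp

lin = LinearRamp(start_value=0.0, end_value=1.0, start_iter=100, end_iter=200)
print(lin(50), lin(150), lin(300))      # 0.0, 0.5, 1.0

lad = get_ramp(kind='ladder', start_iters=[1000, 2000], values=[0.1, 0.5, 1.0])
print(lad(0), lad(1500), lad(2500))     # 0.1, 0.5, 1.0
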
/bmab/external/rmbg14/MyConfig.py:
--------------------------------------------------------------------------------
1 | from transformers import PretrainedConfig
2 | from typing import List
3 |
4 | class RMBGConfig(PretrainedConfig):
5 | model_type = "SegformerForSemanticSegmentation"
6 | def __init__(
7 | self,
8 | in_ch=3,
9 | out_ch=1,
10 | **kwargs):
11 | self.in_ch = in_ch
12 | self.out_ch = out_ch
13 | super().__init__(**kwargs)
14 |
--------------------------------------------------------------------------------
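A minimal instantiation sketch, assuming transformers is installed; RMBGConfig only adds the in_ch/out_ch fields on top of PretrainedConfig:

from bmab.external.rmbg14.MyConfig import RMBGConfig

cfg = RMBGConfig(in_ch=3, out_ch=1)
print(cfg.model_type, cfg.in_ch, cfg.out_ch)   # SegformerForSemanticSegmentation 3 1
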
/bmab/external/rmbg14/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/rmbg14/__init__.py
--------------------------------------------------------------------------------
/bmab/external/rmbg14/utilities.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torchvision.transforms.functional import normalize
4 | import numpy as np
5 |
6 | def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
7 | if len(im.shape) < 3:
8 | im = im[:, :, np.newaxis]
9 | # orig_im_size=im.shape[0:2]
10 | im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
11 | im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
12 | image = torch.divide(im_tensor,255.0)
13 | image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
14 | return image
15 |
16 |
17 | def postprocess_image(result: torch.Tensor, im_size: list)-> np.ndarray:
18 |     result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
19 | ma = torch.max(result)
20 | mi = torch.min(result)
21 | result = (result-mi)/(ma-mi)
22 | im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
23 | im_array = np.squeeze(im_array)
24 | return im_array
25 |
--------------------------------------------------------------------------------
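A round-trip sketch of the two helpers above with a random image and a fake model output. The 1024x1024 resize is an assumption about RMBG-1.4's usual input size; only the tensor plumbing comes from the code itself:

import numpy as np
import torch
from bmab.external.rmbg14.utilities import preprocess_image, postprocess_image

im = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
inp = preprocess_image(im, [1024, 1024])
print(inp.shape)                            # torch.Size([1, 3, 1024, 1024])

fake_logits = torch.rand(1, 1, 1024, 1024)  # stands in for the segmentation model's output
mask = postprocess_image(fake_logits, im.shape[0:2])
print(mask.shape, mask.dtype)               # (480, 640) uint8
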
/bmab/nodes/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import BMABBasic, BMABBind, BMABSaveImage, BMABText, BMABPreviewText, BMABRemoteAccessAndSave
2 | from .binder import BMABBind, BMABLoraBind
3 | from .cnloader import BMABControlNet, BMABControlNetOpenpose, BMABControlNetIPAdapter, BMABFluxControlNet
4 | from .detailers import BMABFaceDetailer, BMABPersonDetailer, BMABSimpleHandDetailer, BMABSubframeHandDetailer
5 | from .detailers import BMABOpenposeHandDetailer, BMABDetailAnything
6 | from .imaging import BMABDetectionCrop, BMABRemoveBackground, BMABAlphaComposit, BMABBlend
7 | from .imaging import BMABDetectAndMask, BMABLamaInpaint, BMABDetector, BMABSegmentAnything, BMABMasksToImages
8 | from .imaging import BMABLoadImage, BMABEdge, BMABLoadOutputImage, BMABBlackAndWhite, BMABDetectAndPaste
9 | from .loaders import BMABLoraLoader
10 | from .resize import BMABResizeByPerson, BMABResizeByRatio, BMABResizeAndFill, BMABCrop, BMABZoomOut, BMABSquare
11 | from .sampler import BMABKSampler, BMABKSamplerHiresFix, BMABPrompt, BMABIntegrator, BMABSeedGenerator, BMABExtractor
12 | from .sampler import BMABContextNode, BMABKSamplerHiresFixWithUpscaler, BMABImportIntegrator, BMABKSamplerKohyaDeepShrink
13 | from .sampler import BMABClipTextEncoderSDXL, BMABFluxIntegrator, BMABToBind
14 | from .upscaler import BMABUpscale, BMABUpscaleWithModel
15 | from .toy import BMABGoogleGemini
16 | from .a1111api import BMABApiServer, BMABApiSDWebUIT2I, BMABApiSDWebUIT2IHiresFix, BMABApiSDWebUIControlNet
17 | from .a1111api import BMABApiSDWebUIBMABExtension, BMABApiSDWebUII2I
18 | from .utilnode import BMABModelToBind, BMABConditioningToBind, BMABNoiseGenerator
19 | from .watermark import BMABWatermark
20 | from .fill import BMABInpaint, BMABOutpaintByRatio, BMABReframe
21 | from .utilnode import BMABBase64Image, BMABImageStorage, BMABNormalizeSize, BMABDummy
22 |
--------------------------------------------------------------------------------
/bmab/nodes/basic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import json
4 | import glob
5 | import time
6 | import numpy as np
7 | import folder_paths
8 | from comfy.cli_args import args
9 | from PIL.PngImagePlugin import PngInfo
10 |
11 | from PIL import Image, ImageEnhance, ImageOps
12 | from bmab import utils
13 | from bmab.nodes.binder import BMABBind
14 |
15 |
16 | def calc_color_temperature(temp):
17 | white = (255.0, 254.11008387561782, 250.0419083427406)
18 |
19 | temperature = temp / 100
20 |
21 | if temperature <= 66:
22 | red = 255.0
23 | else:
24 | red = float(temperature - 60)
25 | red = 329.698727446 * math.pow(red, -0.1332047592)
26 | if red < 0:
27 | red = 0
28 | if red > 255:
29 | red = 255
30 |
31 | if temperature <= 66:
32 | green = temperature
33 | green = 99.4708025861 * math.log(green) - 161.1195681661
34 | else:
35 | green = float(temperature - 60)
36 | green = 288.1221695283 * math.pow(green, -0.0755148492)
37 | if green < 0:
38 | green = 0
39 | if green > 255:
40 | green = 255
41 |
42 | if temperature >= 66:
43 | blue = 255.0
44 | else:
45 | if temperature <= 19:
46 | blue = 0.0
47 | else:
48 | blue = float(temperature - 10)
49 | blue = 138.5177312231 * math.log(blue) - 305.0447927307
50 | if blue < 0:
51 | blue = 0
52 | if blue > 255:
53 | blue = 255
54 |
55 | return red / white[0], green / white[1], blue / white[2]
56 |
57 |
58 | class BMABBasic:
59 | @classmethod
60 | def INPUT_TYPES(s):
61 | return {
62 | 'required': {
63 | 'contrast': ('FLOAT', {'default': 1.0, 'min': 0, 'max': 2, 'step': 0.05}),
64 | 'brightness': ('FLOAT', {'default': 1.0, 'min': 0, 'max': 2, 'step': 0.05}),
65 | 'sharpeness': ('FLOAT', {'default': 1.0, 'min': -5.0, 'max': 5.0, 'step': 0.1}),
66 | 'color_saturation': ('FLOAT', {'default': 1.0, 'min': 0.0, 'max': 2.0, 'step': 0.01}),
67 | 'color_temperature': ('INT', {'default': 0, 'min': -2000, 'max': 2000, 'step': 1}),
68 | 'noise_alpha': ('FLOAT', {'default': 0, 'min': 0.0, 'max': 1.0, 'step': 0.05}),
69 | },
70 | 'optional': {
71 | 'bind': ('BMAB bind',),
72 | 'image': ('IMAGE',),
73 | },
74 | 'hidden': {'unique_id': 'UNIQUE_ID'}
75 | }
76 |
77 | RETURN_TYPES = ('BMAB bind', 'IMAGE',)
78 | RETURN_NAMES = ('BMAB bind', 'image',)
79 | FUNCTION = 'process'
80 |
81 | CATEGORY = 'BMAB/basic'
82 |
83 | def process(self, contrast, brightness, sharpeness, color_saturation, color_temperature, noise_alpha, unique_id, bind: BMABBind = None, image=None):
84 | if bind is None:
85 | pixels = image
86 | else:
87 | pixels = bind.pixels if image is None else image
88 |
89 | results = []
90 | for bgimg in utils.get_pils_from_pixels(pixels):
91 | if contrast != 1:
92 | enhancer = ImageEnhance.Contrast(bgimg)
93 | bgimg = enhancer.enhance(contrast)
94 |
95 | if brightness != 1:
96 | enhancer = ImageEnhance.Brightness(bgimg)
97 | bgimg = enhancer.enhance(brightness)
98 |
99 | if sharpeness != 1:
100 | enhancer = ImageEnhance.Sharpness(bgimg)
101 | bgimg = enhancer.enhance(sharpeness)
102 |
103 | if color_saturation != 1:
104 | enhancer = ImageEnhance.Color(bgimg)
105 | bgimg = enhancer.enhance(color_saturation)
106 |
107 | if color_temperature != 0:
108 | temp = calc_color_temperature(6500 + color_temperature)
109 | az = []
110 | for d in bgimg.getdata():
111 | az.append((int(d[0] * temp[0]), int(d[1] * temp[1]), int(d[2] * temp[2])))
112 | bgimg = Image.new('RGB', bgimg.size)
113 | bgimg.putdata(az)
114 |
115 | if noise_alpha != 0:
116 | img_noise = utils.generate_noise(0, bgimg.width, bgimg.height)
117 | bgimg = Image.blend(bgimg, img_noise, alpha=noise_alpha)
118 |
119 | results.append(bgimg)
120 |
121 | pixels = utils.get_pixels_from_pils(results)
122 | return BMABBind.result(bind, pixels, )
123 |
124 |
125 | class BMABSaveImage:
126 | def __init__(self):
127 | self.output_dir = folder_paths.get_output_directory()
128 | self.type = 'output'
129 | self.compress_level = 4
130 |
131 | @classmethod
132 | def INPUT_TYPES(s):
133 | return {
134 | 'required': {
135 | 'filename_prefix': ('STRING', {'default': 'bmab'}),
136 | 'format': (['png', 'jpg'], ),
137 | 'use_date': (['disable', 'enable'], ),
138 | },
139 | 'hidden': {
140 | 'prompt': 'PROMPT', 'extra_pnginfo': 'EXTRA_PNGINFO'
141 | },
142 | 'optional': {
143 | 'bind': ('BMAB bind',),
144 | 'images': ('IMAGE',),
145 | }
146 | }
147 |
148 | RETURN_TYPES = ()
149 | FUNCTION = 'save_images'
150 |
151 | OUTPUT_NODE = True
152 |
153 | CATEGORY = 'BMAB/basic'
154 |
155 | @staticmethod
156 | def get_file_sequence(prefix, subdir):
157 | output_dir = os.path.normpath(os.path.join(folder_paths.get_output_directory(), subdir))
158 | find_path = os.path.join(output_dir, f'{prefix}*')
159 | sequence = 0
160 | for f in glob.glob(find_path):
161 | filename = os.path.basename(f)
162 | split_name = filename[len(prefix)+1:].replace('.', '_').split('_')
163 | try:
164 | file_sequence = int(split_name[0])
165 |             except (ValueError, IndexError):
166 | continue
167 | if file_sequence > sequence:
168 | sequence = file_sequence
169 | return sequence + 1
170 |
171 | @staticmethod
172 | def get_sub_directory(use_date):
173 | if not use_date:
174 | return ''
175 |
176 | dd = time.strftime('%Y-%m-%d', time.localtime(time.time()))
177 | full_output_folder = os.path.join(folder_paths.output_directory, dd)
178 | print(full_output_folder)
179 | if not os.path.exists(full_output_folder):
180 | os.mkdir(full_output_folder)
181 | return dd
182 |
183 | def save_images(self, filename_prefix='bmab', format='png', use_date='disable', prompt=None, extra_pnginfo=None, bind: BMABBind = None, images=None):
184 | if images is None:
185 | images = bind.pixels
186 | output_dir = folder_paths.get_output_directory()
187 | results = list()
188 | use_date = use_date == 'enable'
189 |
190 | subdir = self.get_sub_directory(use_date)
191 | prefix_split = filename_prefix.split('/')
192 | if len(prefix_split) != 1:
193 | filename_prefix = prefix_split[-1]
194 | subdir = os.path.join(subdir, '/'.join(prefix_split[:-1]))
195 | sequence = self.get_file_sequence(filename_prefix, subdir)
196 |
197 | for (batch_number, image) in enumerate(images):
198 |
199 | i = 255. * image.cpu().numpy()
200 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
201 | metadata = None
202 | if not args.disable_metadata:
203 | metadata = PngInfo()
204 | if prompt is not None:
205 | metadata.add_text('prompt', json.dumps(prompt))
206 | if extra_pnginfo is not None:
207 | for x in extra_pnginfo:
208 | metadata.add_text(x, json.dumps(extra_pnginfo[x]))
209 |
210 | if batch_number > 0:
211 | filename = f'{filename_prefix}_{sequence:05}_{batch_number}'
212 | else:
213 | filename = f'{filename_prefix}_{sequence:05}'
214 |
215 | if bind is not None:
216 | file = f'{filename}_{bind.seed}.{format}'
217 | else:
218 | file = f'{filename}.{format}'
219 |
220 |             save_dir = output_dir
221 |             if use_date:
222 |                 save_dir = os.path.join(output_dir, subdir)
223 |
224 |             os.makedirs(save_dir, exist_ok=True)
225 |
226 |             if format == 'png':
227 |                 img.save(os.path.join(save_dir, file), pnginfo=metadata, compress_level=self.compress_level)
228 |             else:
229 |                 img.save(os.path.join(save_dir, file))
230 |
231 | results.append({
232 | 'filename': file,
233 | 'subfolder': subdir,
234 | 'type': self.type
235 | })
236 |
237 | sequence += 1
238 |
239 | return {'ui': {'images': results}}
240 |
241 |
242 | class BMABText:
243 | @classmethod
244 | def INPUT_TYPES(s):
245 | return {
246 | 'required': {
247 | 'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}),
248 | },
249 | 'optional': {
250 | 'text': ('STRING', {"forceInput": True}),
251 | }
252 | }
253 |
254 | RETURN_TYPES = ('STRING',)
255 | RETURN_NAMES = ('string',)
256 | FUNCTION = 'export'
257 |
258 | CATEGORY = 'BMAB/basic'
259 |
260 | def export(self, prompt, text=None):
261 | if text is not None:
262 | prompt = prompt.replace('__prompt__', text)
263 | result = utils.parse_prompt(prompt, 0)
264 | return (result,)
265 |
266 |
267 | class BMABPreviewText:
268 | def __init__(self):
269 | pass
270 |
271 | @classmethod
272 | def INPUT_TYPES(s):
273 | return {
274 | "required": {
275 | "text": ("STRING", {"forceInput": True}),
276 | },
277 | "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
278 | }
279 |
280 | RETURN_TYPES = ("STRING",)
281 | OUTPUT_NODE = True
282 | FUNCTION = "preview_text"
283 |
284 | CATEGORY = "BMAB/basic"
285 |
286 | def preview_text(self, text, prompt=None, extra_pnginfo=None):
287 | return {"ui": {"string": [text, ]}, "result": (text,)}
288 |
289 |
290 | class BMABRemoteAccessAndSave(BMABSaveImage):
291 |
292 | @classmethod
293 | def INPUT_TYPES(s):
294 | return {
295 | 'required': {
296 | 'filename_prefix': ('STRING', {'default': 'bmab'}),
297 | 'format': (['png', 'jpg'], ),
298 | 'use_date': (['disable', 'enable'], ),
299 | 'remote_name': ('STRING', {'multiline': False}),
300 | },
301 | 'hidden': {
302 | 'prompt': 'PROMPT', 'extra_pnginfo': 'EXTRA_PNGINFO'
303 | },
304 | 'optional': {
305 | 'bind': ('BMAB bind',),
306 | 'images': ('IMAGE',),
307 | }
308 | }
309 |
310 | RETURN_TYPES = ()
311 | FUNCTION = 'remote_save_images'
312 |
313 | OUTPUT_NODE = True
314 |
315 | CATEGORY = 'BMAB/basic'
316 |
317 | def remote_save_images(self, filename_prefix='bmab', format='png', use_date='disable', remote_name=None, prompt=None, extra_pnginfo=None, bind: BMABBind = None, images=None):
318 | return self.save_images(filename_prefix, format, use_date, prompt, extra_pnginfo, bind, images)
319 |
320 |
321 |
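322 | # Note (added for clarity): BMABSaveImage scans the target folder for existing
323 | # '<prefix>_<sequence>' files and continues numbering from the highest sequence found.
324 | # With use_date enabled, images go into a YYYY-MM-DD subfolder of the ComfyUI output
325 | # directory, and a '/' inside filename_prefix becomes an additional subfolder.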
--------------------------------------------------------------------------------
/bmab/nodes/binder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | import copy
5 |
6 |
7 | class BMABContext:
8 |
9 | def __init__(self, *args) -> None:
10 | super().__init__()
11 | self.seed, self.sampler, self.scheduler, self.cfg_scale, self.steps = args
12 |
13 | def get(self):
14 | return self.seed, self.sampler, self.scheduler, self.cfg_scale, self.steps
15 |
16 | def update(self, steps, cfg_scale, sampler, scheduler):
17 | if steps == 0:
18 | steps = self.steps
19 | if cfg_scale == 0:
20 | cfg_scale = self.cfg_scale
21 | if sampler == 'Use same sampler':
22 | sampler = self.sampler
23 | if scheduler == 'Use same scheduler':
24 | scheduler = self.scheduler
25 | return steps, cfg_scale, sampler, scheduler
26 |
27 |
28 | class BMABBind:
29 |
30 | def __init__(self, *args) -> None:
31 | super().__init__()
32 |
33 | self.model, self.clip, self.vae, self.prompt, self.negative_prompt, self.positive, self.negative, self.latent_image, self.context, self.pixels, self.seed = args
34 |
35 | def copy(self):
36 | return copy.copy(self)
37 |
38 | @staticmethod
39 | def result(bind, pixels, *args):
40 | if bind is None:
41 | return (None, pixels, *args)
42 | else:
43 | bind.pixels = pixels
44 | return (bind, bind.pixels, *args)
45 |
46 | def get(self):
47 | return self.model, self.clip, self.vae, self.prompt, self.negative_prompt, self.positive, self.negative, self.latent_image, self.context, self.pixels, self.seed
48 |
49 |
50 | class BMABLoraBind:
51 | def __init__(self, *args) -> None:
52 | super().__init__()
53 | self.loras = []
54 |
55 | def append(self, *args):
56 | self.loras.append(args)
57 |
58 | def copy(self):
59 | return copy.deepcopy(self)
60 |
--------------------------------------------------------------------------------
/bmab/nodes/fill.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from diffusers import AutoencoderKL, TCDScheduler
4 | from diffusers.models.model_loading_utils import load_state_dict
5 | from huggingface_hub import hf_hub_download
6 |
7 | from bmab.external.fill.controlnet_union import ControlNetModel_Union
8 | from bmab.external.fill.pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
9 |
10 | from PIL import Image, ImageDraw, ImageFilter
11 |
12 | from bmab import utils
13 |
14 |
15 | pipe = None
16 |
17 |
18 | def load():
19 | global pipe
20 |
21 | config_file = hf_hub_download(
22 | "xinsir/controlnet-union-sdxl-1.0",
23 | filename="config_promax.json",
24 | )
25 |
26 | config = ControlNetModel_Union.load_config(config_file)
27 | controlnet_model = ControlNetModel_Union.from_config(config)
28 | model_file = hf_hub_download(
29 | "xinsir/controlnet-union-sdxl-1.0",
30 | filename="diffusion_pytorch_model_promax.safetensors",
31 | )
32 | state_dict = load_state_dict(model_file)
33 | model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
34 | controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
35 | )
36 | model.to(device="cuda", dtype=torch.float16)
37 |
38 | vae = AutoencoderKL.from_pretrained(
39 | "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
40 | ).to("cuda")
41 |
42 | pipe = StableDiffusionXLFillPipeline.from_pretrained(
43 | "SG161222/RealVisXL_V5.0_Lightning",
44 | torch_dtype=torch.float16,
45 | vae=vae,
46 | controlnet=model,
47 | variant="fp16",
48 | ).to("cuda")
49 |
50 | pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
51 |
52 |
53 | def unload():
54 | global pipe
55 | if pipe is not None:
56 | pipe = None
57 | utils.torch_gc()
58 |
59 |
60 | class BMABReframe:
61 |
62 | @classmethod
63 | def INPUT_TYPES(s):
64 | return {
65 | 'required': {
66 | 'image': ('IMAGE',),
67 | 'ratio': (['1:1', '4:5', '2:3', '9:16', '5:4', '3:2', '16:9'],),
68 | 'dilation': ('INT', {'default': 32, 'min': 4, 'max': 128, 'step': 1}),
69 | 'step': ('INT', {'default': 8, 'min': 4, 'max': 128, 'step': 1}),
70 | 'iteration': ('INT', {'default': 4, 'min': 1, 'max': 8, 'step': 1}),
71 | 'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}),
72 | }
73 | }
74 |
75 | RETURN_TYPES = ('IMAGE',)
76 | RETURN_NAMES = ('image',)
77 | FUNCTION = 'process'
78 |
79 | CATEGORY = 'BMAB/fill'
80 |
81 | ratio_sel = {
82 | '1:1': (1024, 1024),
83 | '4:5': (960, 1200),
84 | '2:3': (896, 1344),
85 | '9:16': (816, 1456),
86 | '5:4': (1200, 960),
87 | '3:2': (1344, 896),
88 | '16:9': (1456, 816)
89 | }
90 |
91 | def infer(self, image, width, height, overlap_width, num_inference_steps, prompt_input):
92 | source = image
93 | image_ratio = source.width / source.height
94 | output_ratio = width / height
95 |
96 | if output_ratio <= image_ratio:
97 | ratio = width / source.width
98 | else:
99 | ratio = height / source.height
100 |
101 | source = source.resize((math.ceil(source.width * ratio), math.ceil(source.height * ratio)), Image.Resampling.LANCZOS)
102 | background = Image.new('RGB', (width, height), (255, 255, 255))
103 | mask = Image.new('L', (width, height), 255)
104 | mask_draw = ImageDraw.Draw(mask)
105 |
106 | if output_ratio <= image_ratio:
107 | margin = (height - source.height) // 2
108 | background.paste(source, (0, margin))
109 | mask_draw.rectangle((0, margin + overlap_width, source.width, margin + source.height - overlap_width), fill=0)
110 | else:
111 | margin = (width - source.width) // 2
112 | background.paste(source, (margin, 0))
113 | mask_draw.rectangle((margin + overlap_width, 0, margin + source.width - overlap_width, source.height), fill=0)
114 |
115 | cnet_image = background.copy()
116 | cnet_image.paste(0, (0, 0), mask)
117 |
118 | final_prompt = f"{prompt_input} , high quality, 4k"
119 |
120 | if pipe is None:
121 | load()
122 |
123 | (
124 | prompt_embeds,
125 | negative_prompt_embeds,
126 | pooled_prompt_embeds,
127 | negative_pooled_prompt_embeds,
128 | ) = pipe.encode_prompt(final_prompt, "cuda", True)
129 |
130 | image = pipe(
131 | prompt_embeds=prompt_embeds,
132 | negative_prompt_embeds=negative_prompt_embeds,
133 | pooled_prompt_embeds=pooled_prompt_embeds,
134 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
135 | image=cnet_image,
136 | num_inference_steps=num_inference_steps
137 | )
138 |
139 | image = image.convert("RGBA")
140 | cnet_image.paste(image, (0, 0), mask)
141 |
142 | return cnet_image
143 |
144 | def process(self, image, ratio, dilation, step, iteration, prompt, **kwargs):
145 |
146 | r = BMABReframe.ratio_sel.get(ratio, (1024, 1024))
147 |
148 | results = []
149 | for image in utils.get_pils_from_pixels(image):
150 | for v in range(0, iteration):
151 | a = self.infer(image, r[0], r[1], dilation, step, prompt_input=prompt)
152 | results.append(a)
153 | pixels = utils.get_pixels_from_pils(results)
154 |
155 | return (pixels,)
156 |
157 |
158 | class BMABOutpaintByRatio:
159 | resize_methods = ['stretching', 'inpaint', 'inpaint+lama']
160 | resize_alignment = ['bottom', 'top', 'top-right', 'right', 'bottom-right', 'bottom-left', 'left', 'top-left', 'center']
161 |
162 | @classmethod
163 | def INPUT_TYPES(s):
164 | return {
165 | 'required': {
166 | 'image': ('IMAGE',),
167 | 'steps': ('INT', {'default': 8, 'min': 0, 'max': 10000}),
168 | 'alignment': (s.resize_alignment,),
169 | 'ratio': ('FLOAT', {'default': 0.85, 'min': 0.1, 'max': 0.95, 'step': 0.01}),
170 | 'dilation': ('INT', {'default': 32, 'min': 4, 'max': 128, 'step': 1}),
171 | 'iteration': ('INT', {'default': 4, 'min': 1, 'max': 8, 'step': 1}),
172 | 'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}),
173 | },
174 | 'optional': {
175 | }
176 | }
177 |
178 | RETURN_TYPES = ('IMAGE', )
179 | RETURN_NAMES = ('image', )
180 | FUNCTION = 'process'
181 |
182 | CATEGORY = 'BMAB/fill'
183 |
184 | @staticmethod
185 | def image_alignment(image, left, right, top, bottom, ratio):
186 | left = int(left)
187 | top = int(top)
188 | input_image = image.resize((int(image.width * ratio), int(image.height * ratio)), Image.Resampling.LANCZOS)
189 | background = Image.new('RGB', image.size, (255, 255, 255))
190 | background.paste(input_image, box=(left, top))
191 | return background
192 |
193 | @staticmethod
194 | def mask_alignment(width, height, left, right, top, bottom, ratio, dilation):
195 | left = int(left)
196 | top = int(top)
197 | w, h = math.ceil(width * ratio), math.ceil(height * ratio)
198 | mask = Image.new('L', (width, height), 255)
199 | mask_draw = ImageDraw.Draw(mask)
200 | box = (
201 | 0 if left == 0 else left + dilation,
202 | 0 if top == 0 else top + dilation,
203 | width if (left + w) >= width else (left + w - dilation),
204 | height if (top + h) >= height else (top + h - dilation)
205 | )
206 | mask_draw.rectangle(box, fill=0)
207 | return mask
208 |
209 | def infer(self, image, al, ratio, dilation, num_inference_steps, prompt_input):
210 | if al not in utils.alignment:
211 | return image
212 | w, h = math.ceil(image.width * (1 - ratio)), math.ceil(image.height * (1 - ratio))
213 | background = BMABOutpaintByRatio.image_alignment(image, *utils.alignment[al](w, h), ratio)
214 | mask = BMABOutpaintByRatio.mask_alignment(image.width, image.height, *utils.alignment[al](w, h), ratio, dilation)
215 |
216 | cnet_image = background.copy()
217 | cnet_image.paste(0, (0, 0), mask)
218 |
219 | final_prompt = f"{prompt_input} , high quality, 4k"
220 |
221 | if pipe is None:
222 | load()
223 |
224 | (
225 | prompt_embeds,
226 | negative_prompt_embeds,
227 | pooled_prompt_embeds,
228 | negative_pooled_prompt_embeds,
229 | ) = pipe.encode_prompt(final_prompt, "cuda", True)
230 |
231 | image = pipe(
232 | prompt_embeds=prompt_embeds,
233 | negative_prompt_embeds=negative_prompt_embeds,
234 | pooled_prompt_embeds=pooled_prompt_embeds,
235 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
236 | image=cnet_image,
237 | num_inference_steps=num_inference_steps
238 | )
239 |
240 | return image
241 |
242 | def process(self, image, steps, alignment, ratio, dilation, iteration, prompt):
243 |
244 | results = []
245 | for image in utils.get_pils_from_pixels(image):
246 |
247 | print('Process image resize', ratio)
248 | for r in range(0, iteration):
249 | a = self.infer(image, alignment, ratio, dilation, steps, prompt_input=prompt)
250 | results.append(a)
251 |
252 | pixels = utils.get_pixels_from_pils(results)
253 | return (pixels,)
254 |
255 |
256 | class BMABInpaint:
257 |
258 | @classmethod
259 | def INPUT_TYPES(s):
260 | return {
261 | 'required': {
262 | 'image': ('IMAGE',),
263 | 'mask': ('MASK',),
264 | 'steps': ('INT', {'default': 8, 'min': 0, 'max': 10000}),
265 | 'iteration': ('INT', {'default': 4, 'min': 1, 'max': 8, 'step': 1}),
266 | 'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}),
267 | },
268 | 'optional': {
269 | 'seed': ('SEED',)
270 | }
271 | }
272 |
273 | RETURN_TYPES = ('IMAGE', )
274 | RETURN_NAMES = ('image', )
275 | FUNCTION = 'process'
276 |
277 | CATEGORY = 'BMAB/fill'
278 |
279 | def infer(self, image, mask, steps, prompt_input):
280 |
281 | source = image
282 | source.paste((255, 255, 255), (0, 0), mask)
283 |
284 | cnet_image = source.copy()
285 | cnet_image.paste(0, (0, 0), mask)
286 |
287 | final_prompt = f"{prompt_input} , high quality, 4k"
288 |
289 | if pipe is None:
290 | load()
291 |
292 | (
293 | prompt_embeds,
294 | negative_prompt_embeds,
295 | pooled_prompt_embeds,
296 | negative_pooled_prompt_embeds,
297 | ) = pipe.encode_prompt(final_prompt, "cuda", True)
298 |
299 | image = pipe(
300 | prompt_embeds=prompt_embeds,
301 | negative_prompt_embeds=negative_prompt_embeds,
302 | pooled_prompt_embeds=pooled_prompt_embeds,
303 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
304 | image=cnet_image,
305 | num_inference_steps=steps
306 | )
307 | return image
308 |
309 | def mask_to_image(self, mask):
310 | result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
311 | return utils.get_pils_from_pixels(result)[0].convert('L')
312 |
313 | def process(self, image, mask, steps, iteration, prompt, seed=None):
314 |
315 | results = []
316 | mask = self.mask_to_image(mask)
317 | for image in utils.get_pils_from_pixels(image):
318 | for r in range(0, iteration):
319 | a = self.infer(image, mask, steps, prompt_input=prompt)
320 | results.append(a)
321 |
322 | pixels = utils.get_pixels_from_pils(results)
323 | return (pixels,)
324 |
325 |
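326 | # Note (added for clarity): this module keeps one global `pipe` — an SDXL fill pipeline
327 | # built from xinsir/controlnet-union-sdxl-1.0 (promax weights), the fp16-fix VAE and
328 | # RealVisXL V5.0 Lightning — resident on CUDA after the first load(). Call unload()
329 | # (exposed via /bmab/free in serverext.py) to release it and reclaim VRAM.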
--------------------------------------------------------------------------------
/bmab/nodes/loaders.py:
--------------------------------------------------------------------------------
1 | import folder_paths
2 | from bmab.nodes.binder import BMABLoraBind
3 |
4 |
5 | class BMABLoraLoader:
6 | @classmethod
7 | def INPUT_TYPES(s):
8 | return {
9 | 'required': {
10 | 'lora_name': (folder_paths.get_filename_list('loras'), ),
11 | 'strength_model': ('FLOAT', {'default': 1.0, 'min': -100.0, 'max': 100.0, 'step': 0.01}),
12 | 'strength_clip': ('FLOAT', {'default': 1.0, 'min': -100.0, 'max': 100.0, 'step': 0.01}),
13 | },
14 | 'optional': {
15 | 'lora': ('BMAB lora',),
16 | }
17 | }
18 |
19 | RETURN_TYPES = ('BMAB lora', )
20 | RETURN_NAMES = ('lora', )
21 | FUNCTION = 'load_lora'
22 |
23 | CATEGORY = 'BMAB/loader'
24 |
25 | def load_lora(self, lora_name, strength_model, strength_clip, lora: BMABLoraBind=None):
26 | if lora is None:
27 | lora = BMABLoraBind()
28 | else:
29 | lora = lora.copy()
30 | lora.append(lora_name, strength_model, strength_clip)
31 | return (lora, )
32 |
33 |
--------------------------------------------------------------------------------
/bmab/nodes/toy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 | import bmab
5 |
6 | from bmab import utils
7 | from server import PromptServer
8 |
9 |
10 | question = '''
11 | Make a detailed prompt for Stable Diffusion using the keyword "{text}", covering scene, lighting, face, pose, clothes, background and colors in only 1 sentence. The sentence should be very detailed. Do not mention human race.
12 | '''
13 |
14 |
15 | class BMABGoogleGemini:
16 |
17 | def __init__(self) -> None:
18 | super().__init__()
19 | self.last_prompt = None
20 | self.last_text = None
21 |
22 | @classmethod
23 | def INPUT_TYPES(s):
24 | return {
25 | 'required': {
26 | 'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}),
27 | 'text': ('STRING', {'multiline': True, 'dynamicPrompts': True}),
28 | 'api_key': ('STRING', {'multiline': False}),
29 | 'random_seed': ('INT', {'default': 0, 'min': 0, 'max': 65536, 'step': 1}),
30 | },
31 | }
32 |
33 | RETURN_TYPES = ('STRING',)
34 | RETURN_NAMES = ('string', )
35 | FUNCTION = 'prompt'
36 |
37 | CATEGORY = 'BMAB/toy'
38 |
39 | def get_prompt(self, text, api_key):
40 | import google.generativeai as genai
41 | genai.configure(api_key=api_key)
42 | model = genai.GenerativeModel('gemini-pro')
43 | text = text.strip()
44 | response = model.generate_content(question.format(text=text))
45 | try:
46 | self.last_prompt = response.text
47 | print(response.text)
48 | cache_path = os.path.join(os.path.dirname(bmab.__file__), '../resources/cache')
49 | cache_file = os.path.join(cache_path, 'gemini.txt')
50 | with open(cache_file, 'a', encoding='utf8') as f:
51 | f.write(time.strftime('%Y.%m.%d - %H:%M:%S'))
52 | f.write('\n')
53 | f.write(self.last_prompt)
54 | f.write('\n')
55 | except Exception:
56 | print('ERROR reading API response', response)
57 | self.last_text = None
58 | PromptServer.instance.send_sync("stop-iteration", {})
59 |
60 | return self.last_prompt
61 |
62 | def prompt(self, prompt: str, text: str, api_key, random_seed=None, **kwargs):
63 | random_seed = random.randint(0, 65535)
64 | if prompt.find('__prompt__') >= 0:
65 | if self.last_text != text:
66 | random_seed = random.randint(0, 65535)
67 | self.last_text = text
68 | self.get_prompt(text, api_key)
69 | if self.last_prompt is None:
70 | PromptServer.instance.send_sync("stop-iteration", {})
71 | prompt = prompt.replace('__prompt__', self.last_prompt)
72 | result = utils.parse_prompt(prompt, random_seed)
73 | return {"ui": {"string": [str(random_seed), ]}, "result": (result,)}
74 |
--------------------------------------------------------------------------------
/bmab/nodes/upscaler.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from PIL import ImageDraw
3 |
4 | from comfy_extras.chainner_models import model_loading
5 | from comfy import model_management
6 | import torch
7 | import comfy.utils
8 | import folder_paths
9 |
10 | import nodes
11 | from bmab import utils
12 | from bmab.nodes.binder import BMABBind
13 |
14 |
15 | class BMABUpscale:
16 | upscale_methods = ['LANCZOS', 'NEAREST', 'BILINEAR', 'BICUBIC']
17 |
18 | @classmethod
19 | def INPUT_TYPES(s):
20 | return {
21 | 'required': {
22 | 'upscale_method': (BMABUpscale.upscale_methods, ),
23 | 'scale': ('FLOAT', {'default': 2.0, 'min': 0, 'max': 4.0, 'step': 0.001}),
24 | 'width': ('INT', {'default': 512, 'min': 32, 'max': nodes.MAX_RESOLUTION, 'step': 8}),
25 | 'height': ('INT', {'default': 512, 'min': 32, 'max': nodes.MAX_RESOLUTION, 'step': 8}),
26 | },
27 | 'optional': {
28 | 'bind': ('BMAB bind',),
29 | 'image': ('IMAGE',),
30 | },
31 | }
32 |
33 | RETURN_TYPES = ('BMAB bind', 'IMAGE',)
34 | RETURN_NAMES = ('BMAB bind', 'image', )
35 | FUNCTION = 'upscale'
36 |
37 | CATEGORY = 'BMAB/upscale'
38 |
39 | def upscale(self, upscale_method, scale, width, height, bind: BMABBind=None, image=None):
40 | pixels = bind.pixels if image is None else image
41 | pil_upscale_methods = {
42 | 'LANCZOS': Image.Resampling.LANCZOS,
43 | 'BILINEAR': Image.Resampling.BILINEAR,
44 | 'BICUBIC': Image.Resampling.BICUBIC,
45 | 'NEAREST': Image.Resampling.NEAREST,
46 | }
47 | results = []
48 | for bgimg in utils.get_pils_from_pixels(pixels):
49 | if scale != 0:
50 | width, height = int(bgimg.width * scale), int(bgimg.height * scale)
51 | method = pil_upscale_methods.get(upscale_method)
52 | results.append(bgimg.resize((width, height), method))
53 | pixels = utils.get_pixels_from_pils(results)
54 | return BMABBind.result(bind, pixels, )
55 |
56 |
57 | class BMABUpscaleWithModel:
58 | @classmethod
59 | def INPUT_TYPES(s):
60 | return {
61 | "required": {
62 | "model_name": (folder_paths.get_filename_list("upscale_models"),),
63 | 'scale': ('FLOAT', {'default': 2.0, 'min': 0, 'max': 4.0, 'step': 0.001}),
64 | 'width': ('INT', {'default': 512, 'min': 0, 'max': nodes.MAX_RESOLUTION, 'step': 8}),
65 | 'height': ('INT', {'default': 512, 'min': 0, 'max': nodes.MAX_RESOLUTION, 'step': 8}),
66 | },
67 | 'optional': {
68 | 'bind': ('BMAB bind',),
69 | 'image': ('IMAGE',),
70 | },
71 | }
72 |
73 | RETURN_TYPES = ('BMAB bind', "IMAGE",)
74 | RETURN_NAMES = ('BMAB bind', 'image', )
75 | FUNCTION = "upscale"
76 |
77 | CATEGORY = "BMAB/upscale"
78 |
79 | def load_model(self, model_name):
80 | model_path = folder_paths.get_full_path("upscale_models", model_name)
81 | sd = comfy.utils.load_torch_file(model_path, safe_load=True)
82 | if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
83 | sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": ""})
84 | out = model_loading.load_state_dict(sd).eval()
85 | return out
86 |
87 | def upscale_with_model(self, model_name, pixels, progress=True):
88 | upscale_model = self.load_model(model_name)
89 | device = model_management.get_torch_device()
90 |
91 | memory_required = model_management.module_size(upscale_model.model)
92 | memory_required += (512 * 512 * 3) * pixels.element_size() * max(upscale_model.scale, 1.0) * 384.0 # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
93 | memory_required += pixels.nelement() * pixels.element_size()
94 | model_management.free_memory(memory_required, device)
95 |
96 | upscale_model.to(device)
97 | in_img = pixels.movedim(-1, -3).to(device)
98 |
99 | tile = 512
100 | overlap = 32
101 |
102 | oom = True
103 | while oom:
104 | try:
105 | if progress:
106 | steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
107 | pbar = comfy.utils.ProgressBar(steps)
108 | s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
109 | else:
110 | s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale)
111 | oom = False
112 | except model_management.OOM_EXCEPTION as e:
113 | tile //= 2
114 | if tile < 128:
115 | raise e
116 |
117 | upscale_model.to("cpu")
118 | s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
119 | return (s,)
120 |
121 | def upscale(self, model_name, scale, width, height, bind: BMABBind=None, image=None):
122 | pixels = bind.pixels if image is None else image
123 | if scale != 0:
124 | _, h, w, c = pixels.shape
125 | width, height = int(w * scale), int(h * scale)
126 |
127 | s = self.upscale_with_model(model_name, pixels)
128 | pil_images = utils.get_pils_from_pixels(s)
129 | results = [img.resize((width, height), Image.Resampling.LANCZOS) for img in pil_images]
130 | pixels = utils.get_pixels_from_pils(results)
131 |
132 | return BMABBind.result(bind, pixels,)
133 |
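134 | # Note (added for clarity): upscale_with_model() runs the model over 512px tiles with a
135 | # 32px overlap; on a CUDA OOM it halves the tile size and retries, re-raising once the
136 | # tile would fall below 128px. The result is then resized to the requested size with LANCZOS.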
--------------------------------------------------------------------------------
/bmab/nodes/utilnode.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 |
5 | import base64
6 | from io import BytesIO
7 | from PIL import Image
8 |
9 | import bmab
10 | from bmab import utils
11 | from bmab.nodes.binder import BMABBind
12 | from bmab import serverext
13 |
14 |
15 | class BMABModelToBind:
16 |
17 | @classmethod
18 | def INPUT_TYPES(s):
19 | return {
20 | 'required': {
21 | 'bind': ('BMAB bind',),
22 | },
23 | 'optional': {
24 | 'model': ('MODEL',),
25 | 'clip': ('CLIP',),
26 | 'vae': ('VAE',),
27 | }
28 | }
29 |
30 | RETURN_TYPES = ('BMAB bind', )
31 | RETURN_NAMES = ('bind', )
32 | FUNCTION = 'process'
33 |
34 | CATEGORY = 'BMAB/utils'
35 |
36 | def process(self, bind: BMABBind, model=None, clip=None, vae=None):
37 | if model is not None:
38 | bind.model = model
39 | if clip is not None:
40 | bind.clip = clip
41 | if vae is not None:
42 | bind.vae = vae
43 | return (bind, )
44 |
45 |
46 | class BMABConditioningToBind:
47 |
48 | @classmethod
49 | def INPUT_TYPES(s):
50 | return {
51 | 'required': {
52 | 'bind': ('BMAB bind',),
53 | },
54 | 'optional': {
55 | 'positive': ('CONDITIONING',),
56 | 'negative': ('CONDITIONING',),
57 | }
58 | }
59 |
60 | RETURN_TYPES = ('BMAB bind', )
61 | RETURN_NAMES = ('bind', )
62 | FUNCTION = 'process'
63 |
64 | CATEGORY = 'BMAB/utils'
65 |
66 | def process(self, bind: BMABBind, positive=None, negative=None):
67 | if positive is not None:
68 | bind.positive = positive
69 | if negative is not None:
70 | bind.negative = negative
71 | return (bind, )
72 |
73 |
74 | class BMABNoiseGenerator:
75 |
76 | @classmethod
77 | def INPUT_TYPES(s):
78 | return {
79 | 'required': {
80 | 'width': ('INT', {'default': 512, 'min': 256, 'max': 2048, 'step': 8}),
81 | 'height': ('INT', {'default': 512, 'min': 256, 'max': 2048, 'step': 8}),
82 | },
83 | 'optional': {
84 | 'bind': ('BMAB bind',),
85 | 'latent': ('LATENT',),
86 | }
87 | }
88 |
89 | RETURN_TYPES = ('IMAGE', )
90 | RETURN_NAMES = ('image', )
91 | FUNCTION = 'generate'
92 |
93 | CATEGORY = 'BMAB/utils'
94 |
95 | @staticmethod
96 | def generate_noise(seed, width, height):
97 | img_1 = np.zeros([height, width, 3], dtype=np.uint8)
98 | # Generate random Gaussian noise
99 | mean = 0
100 | stddev = 180
101 | r, g, b = cv2.split(img_1)
102 | # cv2.setRNGSeed(seed)
103 | cv2.randn(r, mean, stddev)
104 | cv2.randn(g, mean, stddev)
105 | cv2.randn(b, mean, stddev)
106 | img = cv2.merge([r, g, b])
107 | pil_image = Image.fromarray(img, mode='RGB')
108 | return pil_image
109 |
110 | def generate(self, width, height, bind: BMABBind=None, latent=None):
111 | if bind is not None:
112 | width, height = utils.get_shape(bind.latent_image)
113 | if latent is not None:
114 | width, height = utils.get_shape(latent)
115 |
116 | cache_path = os.path.join(os.path.dirname(bmab.__file__), '../resources/cache')
117 | filename = f'noise_{width}_{height}.png'
118 | full_path = os.path.join(cache_path, filename)
119 | if os.path.exists(full_path) and os.path.isfile(full_path):
120 | noise = Image.open(full_path)
121 | return (utils.get_pixels_from_pils([noise]), )
122 |
123 | noise = self.generate_noise(0, width, height)
124 | noise.save(full_path)
125 | return (utils.get_pixels_from_pils([noise]),)
126 |
127 |
128 | class BMABBase64Image:
129 |
130 | def __init__(self) -> None:
131 | super().__init__()
132 |
133 | @classmethod
134 | def INPUT_TYPES(s):
135 | return {
136 | 'required': {
137 | 'encoding': ('STRING', {'multiline': True, 'dynamicPrompts': True}),
138 | },
139 | }
140 |
141 | RETURN_TYPES = ('IMAGE', 'INT', 'INT')
142 | RETURN_NAMES = ('image', 'width', 'height')
143 | FUNCTION = 'process'
144 |
145 | CATEGORY = 'BMAB/utils'
146 |
147 | def process(self, encoding):
148 | results = []
149 | pil = Image.open(BytesIO(base64.b64decode(encoding)))
150 | results.append(pil)
151 | return utils.get_pixels_from_pils(results), pil.width, pil.height
152 |
153 |
154 | class BMABImageStorage:
155 |
156 | def __init__(self) -> None:
157 | super().__init__()
158 |
159 | @classmethod
160 | def INPUT_TYPES(s):
161 | return {
162 | 'required': {
163 | 'images': ('IMAGE',),
164 | 'client_id': ('STRING', {'multiline': False}),
165 | },
166 | }
167 |
168 | RETURN_TYPES = ()
169 | FUNCTION = 'process'
170 |
171 | CATEGORY = 'BMAB/utils'
172 | OUTPUT_NODE = True
173 |
174 | def process(self, images, client_id):
175 | results = []
176 | for image in utils.get_pils_from_pixels(images):
177 | results.append(image)
178 | serverext.memory_image_storage[client_id] = results
179 | return {'ui': {'images': []}}
180 |
181 |
182 | class BMABNormalizeSize:
183 |
184 | def __init__(self) -> None:
185 | super().__init__()
186 |
187 | @classmethod
188 | def INPUT_TYPES(s):
189 | return {
190 | 'required': {
191 | 'width': ('INT', {'default': 512, 'min': 256, 'max': 2048, 'step': 8}),
192 | 'height': ('INT', {'default': 768, 'min': 256, 'max': 2048, 'step': 8}),
193 | 'normalize': ('INT', {'default': 768, 'min': 256, 'max': 2048, 'step': 8}),
194 | },
195 | }
196 |
197 | RETURN_TYPES = ('INT', 'INT')
198 | RETURN_NAMES = ('width', 'height')
199 | FUNCTION = 'process'
200 |
201 | CATEGORY = 'BMAB/utils'
202 | OUTPUT_NODE = True
203 |
204 | def process(self, width, height, normalize):
205 | print(width, height)
206 | if height > width:
207 | ratio = normalize / height
208 | w, h = int(width * ratio), normalize
209 | else:
210 | ratio = normalize / width
211 | w, h = normalize, int(height * ratio)
212 | return w, h
213 |
214 |
215 | class BMABDummy:
216 |
217 | def __init__(self) -> None:
218 | super().__init__()
219 |
220 | @classmethod
221 | def INPUT_TYPES(s):
222 | return {
223 | 'required': {
224 | 'images': ('IMAGE',),
225 | 'seed': ('INT', {'default': 0, 'min': 0, 'max': 0xffffffffffffffff, 'tooltip': 'The random seed used for creating the noise.'}),
226 | },
227 | }
228 |
229 | RETURN_TYPES = ('IMAGE', )
230 | RETURN_NAMES = ('images', )
231 | FUNCTION = 'process'
232 |
233 | CATEGORY = 'BMAB/utils'
234 | OUTPUT_NODE = True
235 |
236 | def process(self, images, seed):
237 | return (images, )
238 |
239 |
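240 | # Note (added for clarity): BMABNoiseGenerator caches one noise PNG per (width, height)
241 | # under resources/cache and reuses it on later runs, so the seed argument of
242 | # generate_noise() currently has no effect (cv2.setRNGSeed is commented out above).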
--------------------------------------------------------------------------------
/bmab/nodes/watermark.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 |
5 | from PIL import Image
6 | from PIL import ImageDraw
7 | from PIL import ImageFont
8 |
9 | from bmab import utils
10 |
11 |
12 | class BMABWatermark:
13 | alignment = {
14 | 'bottom-left': lambda w, h, cx, cy: (0, h - cy),
15 | 'top': lambda w, h, cx, cy: (w / 2 - cx / 2, 0),
16 | 'top-right': lambda w, h, cx, cy: (w - cx, 0),
17 | 'right': lambda w, h, cx, cy: (w - cx, h / 2 - cy / 2),
18 | 'bottom-right': lambda w, h, cx, cy: (w - cx, h - cy),
19 | 'bottom': lambda w, h, cx, cy: (w / 2 - cx / 2, h - cy),
20 | 'left': lambda w, h, cx, cy: (0, h / 2 - cy / 2),
21 | 'top-left': lambda w, h, cx, cy: (0, 0),
22 | 'center': lambda w, h, cx, cy: (w / 2 - cx / 2, h / 2 - cy / 2),
23 | }
24 |
25 |
26 | @classmethod
27 | def INPUT_TYPES(s):
28 | return {
29 | 'required': {
30 | 'font': (s.list_fonts(),),
31 | 'alignment': ([x for x in s.alignment.keys()], ),
32 | 'text_alignment': (['left', 'right', 'center'], ),
33 | 'rotate': ([0, 90, 180, 270], ),
34 | 'color': ('STRING', {'default': '#000000'}),
35 | 'background_color': ('STRING', {'default': '#000000'}),
36 | 'font_size': ('INT', {'default': 12, 'min': 4, 'max': 128}),
37 | 'transparency': ('INT', {'default': 100, 'min': 0, 'max': 100}),
38 | 'background_transparency': ('INT', {'default': 0, 'min': 0, 'max': 100}),
39 | 'margin': ('INT', {'default': 5, 'min': 0, 'max': 100}),
40 | 'text': ('STRING', {'multiline': True}),
41 | },
42 | 'optional': {
43 | 'bind': ('BMAB bind',),
44 | 'image': ('IMAGE',),
45 | }
46 | }
47 |
48 | RETURN_TYPES = ('BMAB bind', 'IMAGE',)
49 | RETURN_NAMES = ('BMAB bind', 'image',)
50 | FUNCTION = 'process'
51 |
52 | CATEGORY = 'BMAB/basic'
53 |
54 | def process_watermark(self, img, font, alignment, text_alignment, rotate, color, background_color, font_size, transparency, background_transparency, margin, text):
55 | background_color = self.color_hex_to_rgb(background_color, int(255 * (background_transparency / 100)))
56 |
57 | if os.path.isfile(text):
58 | cropped = Image.open(text)
59 | else:
60 | font = self.get_font(font, font_size)
61 | color = self.color_hex_to_rgb(color, int(255 * (transparency / 100)))
62 |
63 | # 1st
64 | base = Image.new('RGBA', img.size, background_color)
65 | draw = ImageDraw.Draw(base)
66 | bbox = draw.textbbox((0, 0), text, font=font)
67 | draw.text((0, 0), text, font=font, fill=color, align=text_alignment)
68 | cropped = base.crop(bbox)
69 |
70 | # 2nd margin
71 | base = Image.new('RGBA', (cropped.width + margin * 2, cropped.height + margin * 2), background_color)
72 | base.paste(cropped, (margin, margin))
73 |
74 | # 3rd rotate
75 | base = base.rotate(rotate, expand=True)
76 |
77 | # 4th
78 | image = img.convert('RGBA')
79 | image2 = image.copy()
80 | x, y = BMABWatermark.alignment[alignment](image.width, image.height, base.width, base.height)
81 | image2.paste(base, (int(x), int(y)))
82 | return Image.alpha_composite(image, image2)
83 |
84 | def process(self, image=None, bind=None, **kwargs):
85 | pixels = bind.pixels if image is None else image
86 | results = []
87 | for img in utils.get_pils_from_pixels(pixels):
88 | results.append(self.process_watermark(img, **kwargs))
89 | pixels = utils.get_pixels_from_pils(results)
90 | return (bind, pixels, )
91 |
92 | @staticmethod
93 | def color_hex_to_rgb(value, transparency):
94 | value = value.lstrip('#')
95 | lv = len(value)
96 | r, g, b = tuple(int(value[i:i + 2], 16) for i in range(0, lv, 2))
97 | return r, g, b, transparency
98 |
99 | @staticmethod
100 | def list_fonts():
101 | if sys.platform == 'win32':
102 | path = 'C:\\Windows\\Fonts\\*.ttf'
103 | files = glob.glob(path)
104 | return [os.path.basename(f) for f in files]
105 | if sys.platform == 'darwin':
106 | path = '/System/Library/Fonts/*'
107 | files = glob.glob(path)
108 | return [os.path.basename(f) for f in files]
109 | if sys.platform == 'linux':
110 | path = '/usr/share/fonts/*'
111 | files = glob.glob(path)
112 | fonts = [os.path.basename(f) for f in files]
113 | if 'SAGEMAKER_INTERNAL_IMAGE_URI' in os.environ:
114 | path = '/opt/conda/envs/sagemaker-distribution/fonts/*'
115 | files = glob.glob(path)
116 | fonts.extend([os.path.basename(f) for f in files])
117 | return fonts
118 | return ['']
119 |
120 | @staticmethod
121 | def get_font(font, size):
122 | if sys.platform == 'win32':
123 | path = f'C:\\Windows\\Fonts\\{font}'
124 | return ImageFont.truetype(path, size, encoding="unic")
125 | if sys.platform == 'darwin':
126 | path = f'/System/Library/Fonts/{font}'
127 | return ImageFont.truetype(path, size, encoding="unic")
128 | if sys.platform == 'linux':
129 | if 'SAGEMAKER_INTERNAL_IMAGE_URI' in os.environ:
130 | path = f'/opt/conda/envs/sagemaker-distribution/fonts/{font}'
131 | else:
132 | path = f'/usr/share/fonts/{font}'
133 | return ImageFont.truetype(path, size, encoding="unic")
134 |
--------------------------------------------------------------------------------
/bmab/process.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from PIL import Image
4 | from PIL import ImageDraw
5 | from PIL import ImageFilter
6 |
7 | import comfy
8 | import nodes
9 | import folder_paths
10 |
11 | from bmab import utils
12 | from bmab.nodes import BMABBind
13 | from bmab.utils.color import apply_color_correction
14 |
15 |
16 | def process_img2img(bind: BMABBind, image, params):
17 | steps, cfg, sampler_name, scheduler, denoise = params['steps'], params['cfg_scale'], params['sampler_name'], params['scheduler'], params['denoise']
18 | pixels = utils.pil2tensor(image.convert('RGB'))
19 | latent = dict(samples=bind.vae.encode(pixels))
20 | samples = nodes.common_ksampler(bind.model, bind.seed, steps, cfg, sampler_name, scheduler, bind.positive, bind.negative, latent, denoise=denoise)[0]
21 | latent = bind.vae.decode(samples["samples"])
22 | result = utils.tensor2pil(latent)
23 | return result
24 |
25 |
26 | def process_img2img_with_mask(bind: BMABBind, image, params, mask=None, box=None):
27 | width, height, padding, dilation = params['width'], params['height'], params['padding'], params['dilation']
28 |
29 | if box is None:
30 | box = mask.getbbox()
31 | if box is None:
32 | return image
33 |
34 | if mask is None:
35 | mask = Image.new('L', image.size, 0)
36 | dr = ImageDraw.Draw(mask, 'L')
37 | dr.rectangle(box, fill=255)
38 |
39 | x1, y1, x2, y2 = tuple(int(x) for x in box)
40 |
41 | cbx = utils.get_box_with_padding(image, (x1, y1, x2, y2), padding)
42 | cropped = image.crop(cbx)
43 | resized = utils.resize_and_fill(cropped, width, height)
44 | processed = process_img2img(bind, resized, params)
45 | processed = apply_color_correction(resized, processed)
46 |
47 | iratio = width / height
48 | cratio = cropped.width / cropped.height
49 | if iratio < cratio:
50 | ratio = cropped.width / width
51 | processed = processed.resize((int(processed.width * ratio), int(processed.height * ratio)))
52 | y0 = (processed.height - cropped.height) // 2
53 | processed = processed.crop((0, y0, cropped.width, y0 + cropped.height))
54 | else:
55 | ratio = cropped.height / height
56 | processed = processed.resize((int(processed.width * ratio), int(processed.height * ratio)))
57 | x0 = (processed.width - cropped.width) // 2
58 | processed = processed.crop((x0, 0, x0 + cropped.width, cropped.height))
59 |
60 | img = image.copy()
61 | img.paste(processed, (cbx[0], cbx[1]))
62 |
63 | pil_mask = utils.dilate_mask(mask, dilation)
64 | blur = ImageFilter.GaussianBlur(dilation)
65 | blur_mask = pil_mask.filter(blur)
66 |
67 | image.paste(img, (0, 0), mask=blur_mask)
68 | return image
69 |
70 |
71 | def load_controlnet(control_net_name):
72 | controlnet_path = folder_paths.get_full_path('controlnet', control_net_name)
73 | controlnet = comfy.controlnet.load_controlnet(controlnet_path)
74 | return controlnet
75 |
76 |
77 | def apply_controlnet(control_net_name, positive, negative, strength, start_percent, end_percent, image):
78 | control_net = load_controlnet(control_net_name)
79 |
80 | control_hint = image.movedim(-1, 1)
81 | cnets = {}
82 |
83 | out = []
84 | for conditioning in [positive, negative]:
85 | c = []
86 | for t in conditioning:
87 | d = t[1].copy()
88 |
89 | prev_cnet = d.get('control', None)
90 | if prev_cnet in cnets:
91 | c_net = cnets[prev_cnet]
92 | else:
93 | c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
94 | c_net.set_previous_controlnet(prev_cnet)
95 | cnets[prev_cnet] = c_net
96 |
97 | d['control'] = c_net
98 | d['control_apply_to_uncond'] = False
99 | n = [t[0], d]
100 | c.append(n)
101 | out.append(c)
102 | return out[0], out[1]
103 |
104 |
105 | def preprocess(image, mask):
106 | mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(image.shape[1], image.shape[2]), mode="bilinear")
107 | mask = mask.movedim(1, -1).expand((-1, -1, -1, 3))
108 | image = image.clone()
109 | image[mask > 0.5] = -1.0
110 | return image
111 |
112 |
113 | def process_img2img_with_controlnet(bind: BMABBind, image, params, controlnet_name, mask=None):
114 | steps, cfg, sampler_name, scheduler, denoise = params['steps'], params['cfg_scale'], params['sampler_name'], params['scheduler'], params['denoise']
115 |
116 | pixels = utils.pil2tensor(image.convert('RGB'))
117 | latent = dict(samples=bind.vae.encode(pixels))
118 |
119 | if mask is not None:
120 | mask_pixels = utils.pil2tensor(mask.convert('RGB'))
121 | cn_pixels = preprocess(pixels, mask_pixels[:, :, :, 0])
122 | else:
123 | cn_pixels = pixels
124 | positive, negative = apply_controlnet(controlnet_name, bind.positive, bind.negative, 1.0, 0.0, 1.0, cn_pixels)
125 |
126 | samples = nodes.common_ksampler(bind.model, bind.seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=denoise)[0]
127 | latent = bind.vae.decode(samples["samples"])
128 | result = utils.tensor2pil(latent)
129 | return result
130 |
131 |
132 | def process_img2img_with_controlnet_mask(bind: BMABBind, controlnet_name, image, params, mask=None, box=None):
133 | width, height, padding, dilation = params['width'], params['height'], params['padding'], params['dilation']
134 |
135 | if box is None:
136 | box = mask.getbbox()
137 | if box is None:
138 | return image
139 |
140 | if mask is None:
141 | mask = Image.new('L', image.size, 0)
142 | dr = ImageDraw.Draw(mask, 'L')
143 | dr.rectangle(box, fill=255)
144 |
145 | x1, y1, x2, y2 = tuple(int(x) for x in box)
146 |
147 | cbx = utils.get_box_with_padding(image, (x1, y1, x2, y2), padding)
148 | cropped = image.crop(cbx)
149 | resized = utils.resize_and_fill(cropped, width, height)
150 | processed = process_img2img_with_controlnet(bind, resized, params, controlnet_name, mask=mask)
151 | processed = apply_color_correction(resized, processed)
152 |
153 | iratio = width / height
154 | cratio = cropped.width / cropped.height
155 | if iratio < cratio:
156 | ratio = cropped.width / width
157 | processed = processed.resize((int(processed.width * ratio), int(processed.height * ratio)))
158 | y0 = (processed.height - cropped.height) // 2
159 | processed = processed.crop((0, y0, cropped.width, y0 + cropped.height))
160 | else:
161 | ratio = cropped.height / height
162 | processed = processed.resize((int(processed.width * ratio), int(processed.height * ratio)))
163 | x0 = (processed.width - cropped.width) // 2
164 | processed = processed.crop((x0, 0, x0 + cropped.width, cropped.height))
165 |
166 | img = image.copy()
167 | img.paste(processed, (cbx[0], cbx[1]))
168 |
169 | pil_mask = utils.dilate_mask(mask, dilation)
170 | blur = ImageFilter.GaussianBlur(dilation)
171 | blur_mask = pil_mask.filter(blur)
172 |
173 | image.paste(img, (0, 0), mask=blur_mask)
174 | return image
175 |
176 |
177 |
178 |
179 |
180 |
181 |
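182 | # Note (added for clarity): the *_with_mask helpers crop the padded detection box, letterbox
183 | # it to the requested width/height, run img2img (optionally with a ControlNet hint),
184 | # color-correct the result against the crop, undo the letterboxing, and paste it back
185 | # through a dilated, Gaussian-blurred mask so the seam blends into the surrounding image.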
--------------------------------------------------------------------------------
/bmab/serverext.py:
--------------------------------------------------------------------------------
1 | from server import PromptServer
2 | from aiohttp import web
3 |
4 | from PIL import Image
5 | from bmab.nodes import fill, upscaler
6 | import base64
7 | from io import BytesIO
8 | from bmab import utils
9 | from comfy import utils as cutils
10 | from bmab.utils import yolo, sam
11 |
12 | memory_image_storage = {}
13 |
14 |
15 | def b64_encoding(image):
16 | buffered = BytesIO()
17 | image.save(buffered, format="PNG")
18 | return base64.b64encode(buffered.getvalue()).decode("utf-8")
19 |
20 |
21 | def b64_decoding(b64):
22 | return Image.open(BytesIO(base64.b64decode(b64)))
23 |
24 |
25 | client_ids = {}
26 |
27 |
28 | @PromptServer.instance.routes.get("/bmab")
29 | async def bmab_register_client(request):
30 | remote_client_id = request.rel_url.query.get("remote_client_id")
31 | remote_name = request.rel_url.query.get("remote_name")
32 |
33 | print(remote_client_id, remote_name)
34 | client_ids[remote_client_id] = {'name': remote_name}
35 | data = {'name': remote_name, 'client_id': remote_client_id}
36 | return web.json_response(data)
37 |
38 |
39 | @PromptServer.instance.routes.get("/bmab/remote")
40 | async def bmab_remote(request):
41 | command = request.rel_url.query.get("command")
42 | name = request.rel_url.query.get("name")
43 | data = {'command': command, 'name': name}
44 |
45 | if command == 'queue':
46 | for sid, v in client_ids.items():
47 | if v.get('name') == name:
48 | await PromptServer.instance.send("bmab_queue", {"status": '', 'sid': sid}, sid)
49 | data['client_id'] = sid
50 |
51 | return web.json_response(data)
52 |
53 |
54 | @PromptServer.instance.routes.post("/bmab/outpaintbyratio")
55 | async def bmab_outpaintbyratio(request):
56 | j = await request.json()
57 | b64img = j.get('image')
58 | if b64img is not None:
59 | prompt = j.get('prompt')
60 | align = j.get('align', 'bottom')
61 | ratio = j.get('ratio', 0.85)
62 | dilation = j.get('dilation', 16)
63 | steps = j.get('steps', 8)
64 |
65 | filler = fill.BMABOutpaintByRatio()
66 | img = filler.infer(b64_decoding(b64img), align, ratio, dilation, steps, prompt)
67 | data = {'image': b64_encoding(img)}
68 | return web.json_response(data)
69 | else:
70 | print('release')
71 | fill.unload()
72 | return web.json_response({})
73 |
74 |
75 | @PromptServer.instance.routes.post("/bmab/reframe")
76 | async def bmab_reframe(request):
77 | j = await request.json()
78 | b64img = j.get('image')
79 | if b64img is not None:
80 | prompt = j.get('prompt')
81 | ratio = j.get('ratio', '1:1')
82 | dilation = j.get('dilation', 16)
83 | steps = j.get('steps', 8)
84 | r = fill.BMABReframe.ratio_sel.get(ratio, (1024, 1024))
85 |
86 | filler = fill.BMABReframe()
87 | img = filler.infer(b64_decoding(b64img), r[0], r[1], dilation, steps, prompt)
88 | data = {'image': b64_encoding(img)}
89 | return web.json_response(data)
90 | else:
91 | print('release')
92 | fill.unload()
93 | return web.json_response({})
94 |
95 |
96 | @PromptServer.instance.routes.post("/bmab/upscale")
97 | async def bmab_upscale(request):
98 | from comfy import utils as cutils
99 |
100 | j = await request.json()
101 | b64img = j.get('image')
102 | model = j.get('model')
103 | scale = j.get('scale', 2)
104 | width = j.get('width', 0)
105 | height = j.get('height', 0)
106 |
107 | hook = cutils.PROGRESS_BAR_HOOK
108 | cutils.PROGRESS_BAR_HOOK = None
109 | try:
110 | up = upscaler.BMABUpscaleWithModel()
111 | img = b64_decoding(b64img)
112 | if scale != 0:
113 | width, height = int(img.width * scale), int(img.height * scale)
114 | pixels = utils.get_pixels_from_pils([img])
115 | s = up.upscale_with_model(model, pixels, progress=False)
116 | utils.torch_gc()
117 | finally:
118 | cutils.PROGRESS_BAR_HOOK = hook
119 |
120 | pil_images = utils.get_pils_from_pixels(s)
121 | result = pil_images[0].resize((width, height), Image.Resampling.LANCZOS)
122 | data = {'image': b64_encoding(result)}
123 | return web.json_response(data)
124 |
125 |
126 | @PromptServer.instance.routes.post("/bmab/inpaint")
127 | async def bmab_inpaint(request):
128 | j = await request.json()
129 | b64img = j.get('image')
130 | if b64img is not None:
131 | b64mask = j.get('mask')
132 | prompt = j.get('prompt')
133 | steps = j.get('steps')
134 |
135 | inpaint = fill.BMABInpaint()
136 | img = b64_decoding(b64img)
137 | msk = b64_decoding(b64mask).convert('L')
138 |
139 | result = inpaint.infer(img, msk, steps, prompt_input=prompt)
140 | data = {'image': b64_encoding(result)}
141 | return web.json_response(data)
142 | else:
143 | print('release')
144 | fill.unload()
145 | return web.json_response({})
146 |
147 |
148 | @PromptServer.instance.routes.post("/bmab/depth")
149 | async def bmab_depth(request):
150 | from custom_nodes.comfyui_controlnet_aux.node_wrappers.depth_anything import Depth_Anything_Preprocessor
151 |
152 | j = await request.json()
153 | b64img = j.get('image')
154 | resolution = j.get('resolution')
155 | images = utils.get_pixels_from_pils([b64_decoding(b64img)])
156 | hook = cutils.PROGRESS_BAR_HOOK
157 | cutils.PROGRESS_BAR_HOOK = None
158 | try:
159 | node = Depth_Anything_Preprocessor()
160 | out = node.execute(images, 'depth_anything_vitl14.pth', resolution)
161 | finally:
162 | cutils.PROGRESS_BAR_HOOK = hook
163 | pil_images = utils.get_pils_from_pixels(out[0])
164 | data = {'image': b64_encoding(pil_images[0])}
165 | return web.json_response(data)
166 |
167 |
168 | @PromptServer.instance.routes.post("/bmab/sam")
169 | async def bmab_sam(request):
170 | j = await request.json()
171 | b64img = j.get('image')
172 | model = j.get('model')
173 | confidence = j.get('confidence')
174 |
175 | image = b64_decoding(b64img)
176 | boxes, conf = yolo.predict(image, model, confidence)
177 | for box in boxes:
178 | mask = sam.sam_predict_box(image, box)
179 | data = {'image': b64_encoding(mask.convert('RGB'))}
180 | sam.release()
181 | return web.json_response(data)
182 | return web.json_response({})
183 |
184 |
185 | @PromptServer.instance.routes.post("/bmab/detect")
186 | async def bmab_detect(request):
187 | j = await request.json()
188 | b64img = j.get('image')
189 | model = j.get('model')
190 | confidence = j.get('confidence')
191 |
192 | image = b64_decoding(b64img)
193 | boxes, conf = yolo.predict(image, model, confidence)
194 | data = {'boxes': boxes}
195 | return web.json_response(data)
196 |
197 |
198 | @PromptServer.instance.routes.post("/bmab/free")
199 | async def bmab_free(request):
200 | fill.unload()
201 | return web.json_response({})
202 |
203 |
204 | @PromptServer.instance.routes.get("/bmab/images")
205 | async def bmab_images(request):
206 | remote_client_id = request.rel_url.query.get("clientId")
207 | imgs = memory_image_storage.get(remote_client_id, [])
208 | results = [b64_encoding(i) for i in imgs]
209 | data = {'images': results}
210 | return web.json_response(data)
211 |
212 |
213 | from transformers import AutoProcessor, AutoModelForCausalLM
214 | import torch
215 |
216 |
217 | @PromptServer.instance.routes.post("/bmab/caption")
218 | async def bmab_caption(request):
219 | j = await request.json()
220 | b64img = j.get('image')
221 |
222 | image = b64_decoding(b64img)
223 | caption = run_captioning(image)
224 | ret = {
225 | 'caption': caption
226 | }
227 | return web.json_response(ret)
228 |
229 |
230 | def run_captioning(image):
231 | print(f"run_captioning")
232 | # Load internally to not consume resources for training
233 | device = "cuda" if torch.cuda.is_available() else "cpu"
234 | print(f"device={device}")
235 | torch_dtype = torch.float16
236 | model = AutoModelForCausalLM.from_pretrained(
237 | "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
238 | ).to(device)
239 | processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True)
240 |
241 | prompt = ""
242 | inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
243 | print(f"inputs {inputs}")
244 |
245 | generated_ids = model.generate(
246 | input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
247 | )
248 | print(f"generated_ids {generated_ids}")
249 |
250 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
251 | print(f"generated_text: {generated_text}")
252 | parsed_answer = processor.post_process_generation(
253 | generated_text, task=prompt, image_size=(image.width, image.height)
254 | )
255 | print(f"parsed_answer = {parsed_answer}")
256 | caption_text = parsed_answer[""].replace("The image shows ", "")
257 |
258 |
259 | model.to("cpu")
260 | del model
261 | del processor
262 | if torch.cuda.is_available():
263 | torch.cuda.empty_cache()
264 |
265 | return caption_text
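266 |
267 | # Illustrative client call for the endpoints above (assumptions: ComfyUI's default port 8188
268 | # and the third-party `requests` package; payload keys mirror what the handler reads):
269 | #   import base64, requests
270 | #   payload = {'image': base64.b64encode(open('input.png', 'rb').read()).decode(),
271 | #              'model': '<upscale model filename>', 'scale': 2}
272 | #   r = requests.post('http://127.0.0.1:8188/bmab/upscale', json=payload)
273 | #   upscaled_png_b64 = r.json()['image']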
--------------------------------------------------------------------------------
/bmab/utils/color.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from skimage import exposure
4 | from blendmodes.blend import blendLayers, BlendType
5 | from PIL import Image
6 |
7 |
8 | def apply_color_correction(correction, original_image):
9 | image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
10 | cv2.cvtColor(
11 | np.asarray(original_image),
12 | cv2.COLOR_RGB2LAB
13 | ),
14 | cv2.cvtColor(np.asarray(correction.copy()), cv2.COLOR_RGB2LAB),
15 | channel_axis=2
16 | ), cv2.COLOR_LAB2RGB).astype("uint8"))
17 |
18 | image = blendLayers(image, original_image, BlendType.LUMINOSITY)
19 |
20 | return image.convert('RGB')
21 |
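22 | # Note (added for clarity): callers pass the reference crop as `correction` and the img2img
23 | # output as `original_image`; the output is histogram-matched to the reference in LAB space,
24 | # then blended with the unmatched output using blendmodes' LUMINOSITY mode so brightness
25 | # detail from the original is kept.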
--------------------------------------------------------------------------------
/bmab/utils/grdino.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
4 | from bmab import utils
5 |
6 |
7 | def get_device():
8 | if sys.platform == 'darwin':
9 | # MPS is not good.
10 | return 'cpu'
11 | elif torch.cuda.is_available():
12 | return 'cuda'
13 | return 'cpu'
14 |
15 |
16 | def predict(pilimg, prompt, box_threahold=0.35, text_threshold=0.25, device=get_device()):
17 | model_id = "IDEA-Research/grounding-dino-base"
18 |
19 | processor = AutoProcessor.from_pretrained(model_id)
20 | model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
21 |
22 | inputs = processor(images=pilimg, text=prompt, return_tensors="pt").to(device)
23 | with torch.no_grad():
24 | outputs = model(**inputs)
25 |
26 | results = processor.post_process_grounded_object_detection(
27 | outputs,
28 | inputs.input_ids,
29 | box_threshold=box_threahold,
30 | text_threshold=text_threshold,
31 | target_sizes=[pilimg.size[::-1]]
32 | )
33 | del processor
34 | model.to('cpu')
35 | utils.torch_gc()
36 |
37 | result = results[0]
38 | return result["boxes"], result["scores"], result["labels"]
39 |
--------------------------------------------------------------------------------
/bmab/utils/sam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import numpy as np
4 |
5 | from PIL import Image
6 | from segment_anything import SamPredictor
7 | from segment_anything import sam_model_registry
8 | from bmab import utils
9 |
10 |
11 | bmab_model_path = os.path.join(os.path.dirname(__file__), '../../models')
12 |
13 | sam_model = None
14 |
15 |
16 | def sam_init(model):
17 | model_type = 'vit_b'
18 | for m in ('vit_b', 'vit_l', 'vit_h'):
19 | if model.find(m) >= 0:
20 | model_type = m
21 | break
22 |
23 | global sam_model
24 | if not sam_model:
25 | utils.lazy_loader(model)
26 | sam_model = sam_model_registry[model_type](checkpoint=os.path.join(bmab_model_path, model))
27 | sam_model.to(device=utils.get_device())
28 | sam_model.eval()
29 | return sam_model
30 |
31 |
32 | def sam_predict(pilimg, boxes, model='sam_vit_b_01ec64.pth'):
33 | sam = sam_init(model)
34 |
35 | mask_predictor = SamPredictor(sam)
36 |
37 | numpy_image = np.array(pilimg)
38 | opencv_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
39 | mask_predictor.set_image(opencv_image)
40 |
41 | result = Image.new('L', pilimg.size, 0)
42 | for box in boxes:
43 | x1, y1, x2, y2 = box
44 |
45 | box = np.array([int(x1), int(y1), int(x2), int(y2)])
46 | masks, scores, logits = mask_predictor.predict(
47 | box=box,
48 | multimask_output=False
49 | )
50 |
51 | mask = Image.fromarray(masks[0])
52 | result.paste(mask, mask=mask)
53 |
54 | return result
55 |
56 |
57 | def sam_predict_box(pilimg, box, model='sam_vit_b_01ec64.pth'):
58 | sam = sam_init(model)
59 |
60 | mask_predictor = SamPredictor(sam)
61 |
62 | numpy_image = np.array(pilimg)
63 | opencv_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
64 | mask_predictor.set_image(opencv_image)
65 |
66 | x1, y1, x2, y2 = box
67 | box = np.array([int(x1), int(y1), int(x2), int(y2)])
68 |
69 | masks, scores, logits = mask_predictor.predict(
70 | box=box,
71 | multimask_output=False
72 | )
73 |
74 | return Image.fromarray(masks[0])
75 |
76 |
77 | def get_array_predict_box(pilimg, box, model='sam_vit_b_01ec64.pth'):
78 | sam = sam_init(model)
79 | mask_predictor = SamPredictor(sam)
80 | numpy_image = np.array(pilimg)
81 | opencv_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
82 | mask_predictor.set_image(opencv_image)
83 | x1, y1, x2, y2 = box
84 | box = np.array([int(x1), int(y1), int(x2), int(y2)])
85 | masks, scores, logits = mask_predictor.predict(
86 | box=box,
87 | multimask_output=False
88 | )
89 | return masks[0]
90 |
91 |
92 | def release():
93 | global sam_model
94 | if sam_model is not None:
95 | sam_model.to(device='cpu')
96 | sam_model = None
97 | utils.torch_gc()
98 |
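99 | # Typical flow (mirroring the /bmab/sam route in serverext.py): detect boxes with
100 | # bmab.utils.yolo.predict(), call sam_predict_box() per box to get a PIL mask, then call
101 | # release() to drop the cached SAM model and free GPU memory.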
--------------------------------------------------------------------------------
/bmab/utils/yolo.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | from ultralytics import YOLO
4 | from bmab import utils
5 |
6 |
7 | def predict(image: Image, model, confidence):
8 | yolo = utils.lazy_loader(model)
9 | boxes = []
10 | confs = []
11 | try:
12 | model = YOLO(yolo)
13 | pred = model(image, conf=confidence, device=utils.get_device())
14 | boxes = pred[0].boxes.xyxy.cpu().numpy()
15 | boxes = boxes.tolist()
16 | confs = pred[0].boxes.conf.tolist()
17 | except Exception:
18 | pass
19 | del model
20 | utils.torch_gc()
21 | return boxes, confs
22 |
23 |
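24 | # Illustrative use (the model filename is an example, not shipped with this repo):
25 | #   boxes, confs = predict(Image.open('input.png'), 'face_yolov8n.pt', 0.35)
26 | #   boxes is a list of [x1, y1, x2, y2] lists; confs holds the matching confidence scores.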
--------------------------------------------------------------------------------
/models/put_models_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/models/put_models_here
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui_bmab"
3 | description = "BMAB for ComfyUI. BMAB is a set of custom nodes for ComfyUI that post-processes generated images according to your settings. If necessary, it can find and redraw people, faces, and hands, or perform functions such as resize, resample, and noise addition. You can also composite two images or upscale the result."
4 | version = "1.1.1"
5 | license = { text = "GNU Affero General Public License v3.0" }
6 | dependencies = ["lightning", "segment_anything", "omegaconf", "ultralytics", "scikit-image", "huggingface_hub", "transformers>=4.40.0", "blendmodes", "opencv-python", "urllib3", "pillow_avif_plugin"]
7 |
8 | [project.urls]
9 | Repository = "https://github.com/portu-sim/comfyui_bmab"
10 | # Used by Comfy Registry https://comfyregistry.org
11 |
12 | [tool.comfy]
13 | PublisherId = "portu-sim"
14 | DisplayName = "comfyui_bmab"
15 | Icon = ""
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | lightning
2 | segment_anything
3 | omegaconf
4 | ultralytics
5 | scikit-image
6 | huggingface_hub
7 | transformers>=4.40.0
8 | blendmodes
9 | opencv-python
10 | urllib3
11 | pillow
12 | pillow_avif_plugin
13 |
14 |
--------------------------------------------------------------------------------
/resources/cache/_cachefiles_will_be_put_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/resources/cache/_cachefiles_will_be_put_here
--------------------------------------------------------------------------------
/resources/examples/bmab-flux-sample.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 41,
3 | "last_link_id": 65,
4 | "nodes": [
5 | {
6 | "id": 39,
7 | "type": "CheckpointLoaderSimple",
8 | "pos": [
9 | 8,
10 | 617
11 | ],
12 | "size": {
13 | "0": 315,
14 | "1": 98
15 | },
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "outputs": [
20 | {
21 | "name": "MODEL",
22 | "type": "MODEL",
23 | "links": [
24 | 58
25 | ],
26 | "shape": 3,
27 | "slot_index": 0
28 | },
29 | {
30 | "name": "CLIP",
31 | "type": "CLIP",
32 | "links": [
33 | 59
34 | ],
35 | "shape": 3,
36 | "slot_index": 1
37 | },
38 | {
39 | "name": "VAE",
40 | "type": "VAE",
41 | "links": [
42 | 61
43 | ],
44 | "shape": 3,
45 | "slot_index": 2
46 | }
47 | ],
48 | "properties": {
49 | "Node name for S&R": "CheckpointLoaderSimple"
50 | },
51 | "widgets_values": [
52 | "FLUX1\\flux1-dev-fp8.safetensors"
53 | ]
54 | },
55 | {
56 | "id": 38,
57 | "type": "BMAB Flux Integrator",
58 | "pos": [
59 | 376,
60 | 617
61 | ],
62 | "size": [
63 | 411.24033826936557,
64 | 377.8769301478941
65 | ],
66 | "flags": {},
67 | "order": 3,
68 | "mode": 0,
69 | "inputs": [
70 | {
71 | "name": "model",
72 | "type": "MODEL",
73 | "link": 58
74 | },
75 | {
76 | "name": "clip",
77 | "type": "CLIP",
78 | "link": 59
79 | },
80 | {
81 | "name": "vae",
82 | "type": "VAE",
83 | "link": 61
84 | },
85 | {
86 | "name": "context",
87 | "type": "CONTEXT",
88 | "link": 62
89 | },
90 | {
91 | "name": "seed_in",
92 | "type": "SEED",
93 | "link": null
94 | },
95 | {
96 | "name": "latent",
97 | "type": "LATENT",
98 | "link": 63
99 | },
100 | {
101 | "name": "image",
102 | "type": "IMAGE",
103 | "link": null
104 | }
105 | ],
106 | "outputs": [
107 | {
108 | "name": "BMAB bind",
109 | "type": "BMAB bind",
110 | "links": [
111 | 64
112 | ],
113 | "shape": 3,
114 | "slot_index": 0
115 | }
116 | ],
117 | "properties": {
118 | "Node name for S&R": "BMAB Flux Integrator"
119 | },
120 | "widgets_values": [
121 | 3.5,
122 | "cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open placing a fancy black forest cake with candles on top of a dinner table of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere there are paintings on the walls"
123 | ]
124 | },
125 | {
126 | "id": 41,
127 | "type": "BMAB KSampler",
128 | "pos": [
129 | 841,
130 | 618
131 | ],
132 | "size": {
133 | "0": 315,
134 | "1": 174
135 | },
136 | "flags": {},
137 | "order": 4,
138 | "mode": 0,
139 | "inputs": [
140 | {
141 | "name": "bind",
142 | "type": "BMAB bind",
143 | "link": 64
144 | },
145 | {
146 | "name": "lora",
147 | "type": "BMAB lora",
148 | "link": null
149 | }
150 | ],
151 | "outputs": [
152 | {
153 | "name": "BMAB bind",
154 | "type": "BMAB bind",
155 | "links": null,
156 | "shape": 3
157 | },
158 | {
159 | "name": "image",
160 | "type": "IMAGE",
161 | "links": [
162 | 65
163 | ],
164 | "shape": 3,
165 | "slot_index": 1
166 | }
167 | ],
168 | "properties": {
169 | "Node name for S&R": "BMAB KSampler"
170 | },
171 | "widgets_values": [
172 | 20,
173 | 1,
174 | "euler",
175 | "normal",
176 | 1
177 | ]
178 | },
179 | {
180 | "id": 27,
181 | "type": "EmptySD3LatentImage",
182 | "pos": [
183 | 11,
184 | 991
185 | ],
186 | "size": {
187 | "0": 315,
188 | "1": 106
189 | },
190 | "flags": {},
191 | "order": 1,
192 | "mode": 0,
193 | "outputs": [
194 | {
195 | "name": "LATENT",
196 | "type": "LATENT",
197 | "links": [
198 | 63
199 | ],
200 | "shape": 3,
201 | "slot_index": 0
202 | }
203 | ],
204 | "properties": {
205 | "Node name for S&R": "EmptySD3LatentImage"
206 | },
207 | "widgets_values": [
208 | 1024,
209 | 1024,
210 | 1
211 | ],
212 | "color": "#323",
213 | "bgcolor": "#535"
214 | },
215 | {
216 | "id": 9,
217 | "type": "SaveImage",
218 | "pos": [
219 | 1190,
220 | 628
221 | ],
222 | "size": {
223 | "0": 985.3012084960938,
224 | "1": 1060.3828125
225 | },
226 | "flags": {},
227 | "order": 5,
228 | "mode": 0,
229 | "inputs": [
230 | {
231 | "name": "images",
232 | "type": "IMAGE",
233 | "link": 65
234 | }
235 | ],
236 | "properties": {},
237 | "widgets_values": [
238 | "ComfyUI"
239 | ]
240 | },
241 | {
242 | "id": 40,
243 | "type": "BMAB Context",
244 | "pos": [
245 | 7,
246 | 757
247 | ],
248 | "size": {
249 | "0": 315,
250 | "1": 178
251 | },
252 | "flags": {},
253 | "order": 2,
254 | "mode": 0,
255 | "inputs": [
256 | {
257 | "name": "seed_in",
258 | "type": "SEED",
259 | "link": null
260 | }
261 | ],
262 | "outputs": [
263 | {
264 | "name": "BMAB context",
265 | "type": "CONTEXT",
266 | "links": [
267 | 62
268 | ],
269 | "shape": 3,
270 | "slot_index": 0
271 | }
272 | ],
273 | "properties": {
274 | "Node name for S&R": "BMAB Context"
275 | },
276 | "widgets_values": [
277 | 972054013131368,
278 | "randomize",
279 | 20,
280 | 8.040000000000001,
281 | "dpmpp_sde",
282 | "karras"
283 | ]
284 | }
285 | ],
286 | "links": [
287 | [
288 | 58,
289 | 39,
290 | 0,
291 | 38,
292 | 0,
293 | "MODEL"
294 | ],
295 | [
296 | 59,
297 | 39,
298 | 1,
299 | 38,
300 | 1,
301 | "CLIP"
302 | ],
303 | [
304 | 61,
305 | 39,
306 | 2,
307 | 38,
308 | 2,
309 | "VAE"
310 | ],
311 | [
312 | 62,
313 | 40,
314 | 0,
315 | 38,
316 | 3,
317 | "CONTEXT"
318 | ],
319 | [
320 | 63,
321 | 27,
322 | 0,
323 | 38,
324 | 5,
325 | "LATENT"
326 | ],
327 | [
328 | 64,
329 | 38,
330 | 0,
331 | 41,
332 | 0,
333 | "BMAB bind"
334 | ],
335 | [
336 | 65,
337 | 41,
338 | 1,
339 | 9,
340 | 0,
341 | "IMAGE"
342 | ]
343 | ],
344 | "groups": [],
345 | "config": {},
346 | "extra": {
347 | "ds": {
348 | "scale": 0.7513148009015777,
349 | "offset": [
350 | 466.97698927324814,
351 | -334.0327902578701
352 | ]
353 | }
354 | },
355 | "version": 0.4
356 | }
--------------------------------------------------------------------------------
/resources/examples/example.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 29,
3 | "last_link_id": 46,
4 | "nodes": [
5 | {
6 | "id": 4,
7 | "type": "CheckpointLoaderSimple",
8 | "pos": [
9 | 25,
10 | 217
11 | ],
12 | "size": {
13 | "0": 315,
14 | "1": 98
15 | },
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "outputs": [
20 | {
21 | "name": "MODEL",
22 | "type": "MODEL",
23 | "links": [
24 | 34
25 | ],
26 | "slot_index": 0
27 | },
28 | {
29 | "name": "CLIP",
30 | "type": "CLIP",
31 | "links": [
32 | 35
33 | ],
34 | "slot_index": 1
35 | },
36 | {
37 | "name": "VAE",
38 | "type": "VAE",
39 | "links": [],
40 | "slot_index": 2
41 | }
42 | ],
43 | "properties": {
44 | "Node name for S&R": "CheckpointLoaderSimple"
45 | },
46 | "widgets_values": [
47 | "portu_429_lora2.fp16.safetensors"
48 | ]
49 | },
50 | {
51 | "id": 18,
52 | "type": "VAELoader",
53 | "pos": [
54 | 20,
55 | 356
56 | ],
57 | "size": {
58 | "0": 315,
59 | "1": 58
60 | },
61 | "flags": {},
62 | "order": 1,
63 | "mode": 0,
64 | "outputs": [
65 | {
66 | "name": "VAE",
67 | "type": "VAE",
68 | "links": [
69 | 38
70 | ],
71 | "shape": 3,
72 | "slot_index": 0
73 | }
74 | ],
75 | "properties": {
76 | "Node name for S&R": "VAELoader"
77 | },
78 | "widgets_values": [
79 | "vae-ft-mse-840000-ema-pruned.ckpt"
80 | ]
81 | },
82 | {
83 | "id": 26,
84 | "type": "BMAB Context",
85 | "pos": [
86 | 24,
87 | 464
88 | ],
89 | "size": {
90 | "0": 315,
91 | "1": 178
92 | },
93 | "flags": {},
94 | "order": 2,
95 | "mode": 0,
96 | "inputs": [
97 | {
98 | "name": "seed_in",
99 | "type": "SEED",
100 | "link": null
101 | }
102 | ],
103 | "outputs": [
104 | {
105 | "name": "BMAB context",
106 | "type": "CONTEXT",
107 | "links": [
108 | 39
109 | ],
110 | "shape": 3,
111 | "slot_index": 0
112 | }
113 | ],
114 | "properties": {
115 | "Node name for S&R": "BMAB Context"
116 | },
117 | "widgets_values": [
118 | 501253181764904,
119 | "randomize",
120 | 20,
121 | 8,
122 | "dpmpp_sde",
123 | "karras"
124 | ]
125 | },
126 | {
127 | "id": 25,
128 | "type": "BMAB Integrator",
129 | "pos": [
130 | 407,
131 | 224
132 | ],
133 | "size": {
134 | "0": 400,
135 | "1": 318
136 | },
137 | "flags": {},
138 | "order": 4,
139 | "mode": 0,
140 | "inputs": [
141 | {
142 | "name": "model",
143 | "type": "MODEL",
144 | "link": 34
145 | },
146 | {
147 | "name": "clip",
148 | "type": "CLIP",
149 | "link": 35
150 | },
151 | {
152 | "name": "vae",
153 | "type": "VAE",
154 | "link": 38
155 | },
156 | {
157 | "name": "context",
158 | "type": "CONTEXT",
159 | "link": 39
160 | },
161 | {
162 | "name": "seed_in",
163 | "type": "SEED",
164 | "link": null
165 | },
166 | {
167 | "name": "latent",
168 | "type": "LATENT",
169 | "link": 40
170 | },
171 | {
172 | "name": "image",
173 | "type": "IMAGE",
174 | "link": null
175 | }
176 | ],
177 | "outputs": [
178 | {
179 | "name": "BMAB bind",
180 | "type": "BMAB bind",
181 | "links": [
182 | 37
183 | ],
184 | "shape": 3,
185 | "slot_index": 0
186 | }
187 | ],
188 | "properties": {
189 | "Node name for S&R": "BMAB Integrator"
190 | },
191 | "widgets_values": [
192 | -2,
193 | "none",
194 | "A1111",
195 | "1girl, standing, full body, street,",
196 | "worst quality, low quality"
197 | ]
198 | },
199 | {
200 | "id": 5,
201 | "type": "EmptyLatentImage",
202 | "pos": [
203 | 26,
204 | 687
205 | ],
206 | "size": {
207 | "0": 315,
208 | "1": 106
209 | },
210 | "flags": {},
211 | "order": 3,
212 | "mode": 0,
213 | "outputs": [
214 | {
215 | "name": "LATENT",
216 | "type": "LATENT",
217 | "links": [
218 | 40
219 | ],
220 | "slot_index": 0
221 | }
222 | ],
223 | "properties": {
224 | "Node name for S&R": "EmptyLatentImage"
225 | },
226 | "widgets_values": [
227 | 512,
228 | 768,
229 | 1
230 | ]
231 | },
232 | {
233 | "id": 11,
234 | "type": "BMAB KSampler",
235 | "pos": [
236 | 840,
237 | 221
238 | ],
239 | "size": {
240 | "0": 315,
241 | "1": 174
242 | },
243 | "flags": {},
244 | "order": 5,
245 | "mode": 0,
246 | "inputs": [
247 | {
248 | "name": "bind",
249 | "type": "BMAB bind",
250 | "link": 37
251 | },
252 | {
253 | "name": "lora",
254 | "type": "BMAB lora",
255 | "link": null
256 | }
257 | ],
258 | "outputs": [
259 | {
260 | "name": "BMAB bind",
261 | "type": "BMAB bind",
262 | "links": [
263 | 45
264 | ],
265 | "shape": 3,
266 | "slot_index": 0
267 | },
268 | {
269 | "name": "image",
270 | "type": "IMAGE",
271 | "links": null,
272 | "shape": 3,
273 | "slot_index": 1
274 | }
275 | ],
276 | "properties": {
277 | "Node name for S&R": "BMAB KSampler"
278 | },
279 | "widgets_values": [
280 | 20,
281 | 8,
282 | "Use same sampler",
283 | "Use same scheduler",
284 | 1
285 | ]
286 | },
287 | {
288 | "id": 17,
289 | "type": "BMAB Save Image",
290 | "pos": [
291 | 1934,
292 | 221
293 | ],
294 | "size": {
295 | "0": 526.6683959960938,
296 | "1": 744.5535278320312
297 | },
298 | "flags": {},
299 | "order": 8,
300 | "mode": 0,
301 | "inputs": [
302 | {
303 | "name": "bind",
304 | "type": "BMAB bind",
305 | "link": 43
306 | },
307 | {
308 | "name": "images",
309 | "type": "IMAGE",
310 | "link": null
311 | }
312 | ],
313 | "properties": {
314 | "Node name for S&R": "BMAB Save Image"
315 | },
316 | "widgets_values": [
317 | "ComfyUI"
318 | ]
319 | },
320 | {
321 | "id": 29,
322 | "type": "BMAB KSamplerHiresFixWithUpscaler",
323 | "pos": [
324 | 1206,
325 | 221
326 | ],
327 | "size": {
328 | "0": 315,
329 | "1": 290
330 | },
331 | "flags": {},
332 | "order": 6,
333 | "mode": 0,
334 | "inputs": [
335 | {
336 | "name": "bind",
337 | "type": "BMAB bind",
338 | "link": 45
339 | },
340 | {
341 | "name": "image",
342 | "type": "IMAGE",
343 | "link": null
344 | },
345 | {
346 | "name": "lora",
347 | "type": "BMAB lora",
348 | "link": null
349 | }
350 | ],
351 | "outputs": [
352 | {
353 | "name": "BMAB bind",
354 | "type": "BMAB bind",
355 | "links": [
356 | 46
357 | ],
358 | "shape": 3,
359 | "slot_index": 0
360 | },
361 | {
362 | "name": "image",
363 | "type": "IMAGE",
364 | "links": null,
365 | "shape": 3
366 | }
367 | ],
368 | "properties": {
369 | "Node name for S&R": "BMAB KSamplerHiresFixWithUpscaler"
370 | },
371 | "widgets_values": [
372 | 20,
373 | 7,
374 | "Use same sampler",
375 | "Use same scheduler",
376 | 0.45,
377 | "LANCZOS",
378 | 2,
379 | 512,
380 | 512
381 | ]
382 | },
383 | {
384 | "id": 27,
385 | "type": "BMAB Face Detailer",
386 | "pos": [
387 | 1568,
388 | 219
389 | ],
390 | "size": {
391 | "0": 315,
392 | "1": 290
393 | },
394 | "flags": {},
395 | "order": 7,
396 | "mode": 0,
397 | "inputs": [
398 | {
399 | "name": "bind",
400 | "type": "BMAB bind",
401 | "link": 46
402 | },
403 | {
404 | "name": "image",
405 | "type": "IMAGE",
406 | "link": null
407 | },
408 | {
409 | "name": "lora",
410 | "type": "BMAB lora",
411 | "link": null
412 | }
413 | ],
414 | "outputs": [
415 | {
416 | "name": "BMAB bind",
417 | "type": "BMAB bind",
418 | "links": [
419 | 43
420 | ],
421 | "shape": 3,
422 | "slot_index": 0
423 | },
424 | {
425 | "name": "image",
426 | "type": "IMAGE",
427 | "links": null,
428 | "shape": 3
429 | }
430 | ],
431 | "properties": {
432 | "Node name for S&R": "BMAB Face Detailer"
433 | },
434 | "widgets_values": [
435 | 20,
436 | 4,
437 | "Use same sampler",
438 | "Use same scheduler",
439 | 0.45,
440 | 32,
441 | 4,
442 | 512,
443 | 512
444 | ]
445 | }
446 | ],
447 | "links": [
448 | [
449 | 34,
450 | 4,
451 | 0,
452 | 25,
453 | 0,
454 | "MODEL"
455 | ],
456 | [
457 | 35,
458 | 4,
459 | 1,
460 | 25,
461 | 1,
462 | "CLIP"
463 | ],
464 | [
465 | 37,
466 | 25,
467 | 0,
468 | 11,
469 | 0,
470 | "BMAB bind"
471 | ],
472 | [
473 | 38,
474 | 18,
475 | 0,
476 | 25,
477 | 2,
478 | "VAE"
479 | ],
480 | [
481 | 39,
482 | 26,
483 | 0,
484 | 25,
485 | 3,
486 | "CONTEXT"
487 | ],
488 | [
489 | 40,
490 | 5,
491 | 0,
492 | 25,
493 | 5,
494 | "LATENT"
495 | ],
496 | [
497 | 43,
498 | 27,
499 | 0,
500 | 17,
501 | 0,
502 | "BMAB bind"
503 | ],
504 | [
505 | 45,
506 | 11,
507 | 0,
508 | 29,
509 | 0,
510 | "BMAB bind"
511 | ],
512 | [
513 | 46,
514 | 29,
515 | 0,
516 | 27,
517 | 0,
518 | "BMAB bind"
519 | ]
520 | ],
521 | "groups": [],
522 | "config": {},
523 | "extra": {
524 | "ds": {
525 | "scale": 0.6830134553650709,
526 | "offset": [
527 | 138.7066514884853,
528 | 67.04635823955786
529 | ]
530 | }
531 | },
532 | "version": 0.4
533 | }
--------------------------------------------------------------------------------
/resources/examples/openpose-hand-detailing-example.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 33,
3 | "last_link_id": 53,
4 | "nodes": [
5 | {
6 | "id": 4,
7 | "type": "CheckpointLoaderSimple",
8 | "pos": [
9 | 25,
10 | 217
11 | ],
12 | "size": {
13 | "0": 315,
14 | "1": 98
15 | },
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "outputs": [
20 | {
21 | "name": "MODEL",
22 | "type": "MODEL",
23 | "links": [
24 | 34
25 | ],
26 | "slot_index": 0
27 | },
28 | {
29 | "name": "CLIP",
30 | "type": "CLIP",
31 | "links": [
32 | 35
33 | ],
34 | "slot_index": 1
35 | },
36 | {
37 | "name": "VAE",
38 | "type": "VAE",
39 | "links": [],
40 | "slot_index": 2
41 | }
42 | ],
43 | "properties": {
44 | "Node name for S&R": "CheckpointLoaderSimple"
45 | },
46 | "widgets_values": [
47 | "portu_429_lora2.fp16.safetensors"
48 | ]
49 | },
50 | {
51 | "id": 18,
52 | "type": "VAELoader",
53 | "pos": [
54 | 20,
55 | 356
56 | ],
57 | "size": {
58 | "0": 315,
59 | "1": 58
60 | },
61 | "flags": {},
62 | "order": 1,
63 | "mode": 0,
64 | "outputs": [
65 | {
66 | "name": "VAE",
67 | "type": "VAE",
68 | "links": [
69 | 38
70 | ],
71 | "shape": 3,
72 | "slot_index": 0
73 | }
74 | ],
75 | "properties": {
76 | "Node name for S&R": "VAELoader"
77 | },
78 | "widgets_values": [
79 | "vae-ft-mse-840000-ema-pruned.ckpt"
80 | ]
81 | },
82 | {
83 | "id": 26,
84 | "type": "BMAB Context",
85 | "pos": [
86 | 24,
87 | 464
88 | ],
89 | "size": {
90 | "0": 315,
91 | "1": 178
92 | },
93 | "flags": {},
94 | "order": 2,
95 | "mode": 0,
96 | "inputs": [
97 | {
98 | "name": "seed_in",
99 | "type": "SEED",
100 | "link": null
101 | }
102 | ],
103 | "outputs": [
104 | {
105 | "name": "BMAB context",
106 | "type": "CONTEXT",
107 | "links": [
108 | 39
109 | ],
110 | "shape": 3,
111 | "slot_index": 0
112 | }
113 | ],
114 | "properties": {
115 | "Node name for S&R": "BMAB Context"
116 | },
117 | "widgets_values": [
118 | 878953794515282,
119 | "randomize",
120 | 20,
121 | 8,
122 | "dpmpp_sde",
123 | "karras"
124 | ]
125 | },
126 | {
127 | "id": 5,
128 | "type": "EmptyLatentImage",
129 | "pos": [
130 | 26,
131 | 687
132 | ],
133 | "size": {
134 | "0": 315,
135 | "1": 106
136 | },
137 | "flags": {},
138 | "order": 3,
139 | "mode": 0,
140 | "outputs": [
141 | {
142 | "name": "LATENT",
143 | "type": "LATENT",
144 | "links": [
145 | 40
146 | ],
147 | "slot_index": 0
148 | }
149 | ],
150 | "properties": {
151 | "Node name for S&R": "EmptyLatentImage"
152 | },
153 | "widgets_values": [
154 | 512,
155 | 768,
156 | 1
157 | ]
158 | },
159 | {
160 | "id": 30,
161 | "type": "BMAB Load Image",
162 | "pos": [
163 | 407,
164 | 753
165 | ],
166 | "size": [
167 | 469.37541296463996,
168 | 746.2945929323166
169 | ],
170 | "flags": {},
171 | "order": 4,
172 | "mode": 0,
173 | "outputs": [
174 | {
175 | "name": "IMAGE",
176 | "type": "IMAGE",
177 | "links": [
178 | 47
179 | ],
180 | "shape": 3,
181 | "slot_index": 0
182 | },
183 | {
184 | "name": "MASK",
185 | "type": "MASK",
186 | "links": null,
187 | "shape": 3
188 | }
189 | ],
190 | "properties": {
191 | "Node name for S&R": "BMAB Load Image"
192 | },
193 | "widgets_values": [
194 | "00009-3139289071.png",
195 | "image"
196 | ]
197 | },
198 | {
199 | "id": 33,
200 | "type": "PreviewImage",
201 | "pos": [
202 | 1319,
203 | 750
204 | ],
205 | "size": [
206 | 537.1467559333894,
207 | 756.6380952760667
208 | ],
209 | "flags": {},
210 | "order": 8,
211 | "mode": 0,
212 | "inputs": [
213 | {
214 | "name": "images",
215 | "type": "IMAGE",
216 | "link": 52
217 | }
218 | ],
219 | "properties": {
220 | "Node name for S&R": "PreviewImage"
221 | }
222 | },
223 | {
224 | "id": 32,
225 | "type": "PreviewImage",
226 | "pos": [
227 | 1875,
228 | 753
229 | ],
230 | "size": [
231 | 562.8348297615144,
232 | 749.2057612916917
233 | ],
234 | "flags": {},
235 | "order": 7,
236 | "mode": 0,
237 | "inputs": [
238 | {
239 | "name": "images",
240 | "type": "IMAGE",
241 | "link": 51
242 | }
243 | ],
244 | "properties": {
245 | "Node name for S&R": "PreviewImage"
246 | }
247 | },
248 | {
249 | "id": 25,
250 | "type": "BMAB Integrator",
251 | "pos": [
252 | 407,
253 | 224
254 | ],
255 | "size": {
256 | "0": 400,
257 | "1": 318
258 | },
259 | "flags": {},
260 | "order": 5,
261 | "mode": 0,
262 | "inputs": [
263 | {
264 | "name": "model",
265 | "type": "MODEL",
266 | "link": 34
267 | },
268 | {
269 | "name": "clip",
270 | "type": "CLIP",
271 | "link": 35
272 | },
273 | {
274 | "name": "vae",
275 | "type": "VAE",
276 | "link": 38
277 | },
278 | {
279 | "name": "context",
280 | "type": "CONTEXT",
281 | "link": 39
282 | },
283 | {
284 | "name": "seed_in",
285 | "type": "SEED",
286 | "link": null
287 | },
288 | {
289 | "name": "latent",
290 | "type": "LATENT",
291 | "link": 40
292 | },
293 | {
294 | "name": "image",
295 | "type": "IMAGE",
296 | "link": null
297 | }
298 | ],
299 | "outputs": [
300 | {
301 | "name": "BMAB bind",
302 | "type": "BMAB bind",
303 | "links": [
304 | 53
305 | ],
306 | "shape": 3,
307 | "slot_index": 0
308 | }
309 | ],
310 | "properties": {
311 | "Node name for S&R": "BMAB Integrator"
312 | },
313 | "widgets_values": [
314 | -2,
315 | "none",
316 | "A1111",
317 | "1girl, standing, full body, street, (detailed hand:1.4)",
318 | "worst quality, low quality"
319 | ]
320 | },
321 | {
322 | "id": 31,
323 | "type": "BMAB Openpose Hand Detailer",
324 | "pos": [
325 | 968,
326 | 320
327 | ],
328 | "size": {
329 | "0": 315,
330 | "1": 314
331 | },
332 | "flags": {},
333 | "order": 6,
334 | "mode": 0,
335 | "inputs": [
336 | {
337 | "name": "bind",
338 | "type": "BMAB bind",
339 | "link": 53
340 | },
341 | {
342 | "name": "image",
343 | "type": "IMAGE",
344 | "link": 47
345 | },
346 | {
347 | "name": "lora",
348 | "type": "BMAB lora",
349 | "link": null
350 | }
351 | ],
352 | "outputs": [
353 | {
354 | "name": "BMAB bind",
355 | "type": "BMAB bind",
356 | "links": null,
357 | "shape": 3
358 | },
359 | {
360 | "name": "image",
361 | "type": "IMAGE",
362 | "links": [
363 | 51
364 | ],
365 | "shape": 3,
366 | "slot_index": 1
367 | },
368 | {
369 | "name": "annotation",
370 | "type": "IMAGE",
371 | "links": [
372 | 52
373 | ],
374 | "shape": 3,
375 | "slot_index": 2
376 | }
377 | ],
378 | "properties": {
379 | "Node name for S&R": "BMAB Openpose Hand Detailer"
380 | },
381 | "widgets_values": [
382 | 20,
383 | 7,
384 | "Use same sampler",
385 | "Use same scheduler",
386 | 0.45,
387 | 32,
388 | 4,
389 | 1024,
390 | 1024,
391 | "disable"
392 | ]
393 | }
394 | ],
395 | "links": [
396 | [
397 | 34,
398 | 4,
399 | 0,
400 | 25,
401 | 0,
402 | "MODEL"
403 | ],
404 | [
405 | 35,
406 | 4,
407 | 1,
408 | 25,
409 | 1,
410 | "CLIP"
411 | ],
412 | [
413 | 38,
414 | 18,
415 | 0,
416 | 25,
417 | 2,
418 | "VAE"
419 | ],
420 | [
421 | 39,
422 | 26,
423 | 0,
424 | 25,
425 | 3,
426 | "CONTEXT"
427 | ],
428 | [
429 | 40,
430 | 5,
431 | 0,
432 | 25,
433 | 5,
434 | "LATENT"
435 | ],
436 | [
437 | 47,
438 | 30,
439 | 0,
440 | 31,
441 | 1,
442 | "IMAGE"
443 | ],
444 | [
445 | 51,
446 | 31,
447 | 1,
448 | 32,
449 | 0,
450 | "IMAGE"
451 | ],
452 | [
453 | 52,
454 | 31,
455 | 2,
456 | 33,
457 | 0,
458 | "IMAGE"
459 | ],
460 | [
461 | 53,
462 | 25,
463 | 0,
464 | 31,
465 | 0,
466 | "BMAB bind"
467 | ]
468 | ],
469 | "groups": [],
470 | "config": {},
471 | "extra": {
472 | "ds": {
473 | "scale": 0.6830134553650707,
474 | "offset": [
475 | 185.95766803736825,
476 | -28.764659635582873
477 | ]
478 | }
479 | },
480 | "version": 0.4
481 | }
--------------------------------------------------------------------------------
/resources/wildcard/put_wildcard_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/resources/wildcard/put_wildcard_here
--------------------------------------------------------------------------------
/web/gemini.js:
--------------------------------------------------------------------------------
1 | import { app } from "/scripts/app.js";
2 | import { ComfyWidgets } from "/scripts/widgets.js";
3 |
4 | app.registerExtension({
5 | name: "Comfy.BMAB.GoogleGeminiPromptNode",
6 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
7 | if (nodeData.name === "BMAB Google Gemini Prompt") {
8 |
9 | const onExecuted = nodeType.prototype.onExecuted;
10 | nodeType.prototype.onExecuted = function (texts) {
11 | onExecuted?.apply(this, arguments);
12 | let widget_id = this?.widgets.findIndex(
13 | obj => obj.name === 'random_seed'
14 | );
15 | this.widgets[widget_id].value = Number(texts?.string)
16 | app.graph.setDirtyCanvas(true);
17 | };
18 | }
19 | },
20 | });
21 |
--------------------------------------------------------------------------------
/web/loadoutputimage.js:
--------------------------------------------------------------------------------
1 | // this code from ComfyUI_Custom_Nodes_AlekPet
2 |
3 | import { app } from "/scripts/app.js";
4 | import { ComfyWidgets } from "/scripts/widgets.js";
5 | import { api } from "../../scripts/api.js";
6 |
7 | app.registerExtension({
8 | name: "Comfy.BMAB.LoadOutputImage",
9 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
10 | if (nodeData.name === "BMAB Load Output Image") {
11 |
12 |
13 | const onNodeCreated = nodeType.prototype.onNodeCreated;
14 | nodeType.prototype.onNodeCreated = function () {
15 | onNodeCreated?.apply(this, arguments);
16 |
17 | const node = this
18 | function showImage(name) {
19 | const img = new Image();
20 | img.onload = () => {
21 | node.imgs = [img];
22 | app.graph.setDirtyCanvas(true);
23 | };
24 |
25 | const split = name.split('/');
26 | let subdir = ''
27 | let fileName = ''
28 | if (split.length === 1) {
29 | fileName = split[0];
30 | } else {
31 | subdir = split.slice(0, split.length - 1).join('/');
32 | fileName = split[split.length - 1];
33 | }
34 | img.src = api.apiURL(`/view?filename=${fileName}&subfolder=${subdir}&type=output${app.getRandParam()}`);
35 | node.setSizeForImage?.();
36 | }
37 |
38 | const imageWidget = node.widgets.find((w) => w.name === "image");
39 |
40 | const cb = this.callback;
41 | imageWidget.callback = function () {
42 | showImage(imageWidget.value);
43 | app.graph.setDirtyCanvas(true);
44 | if (cb) {
45 | return cb.apply(this, arguments);
46 | }
47 | };
48 |
49 | showImage(imageWidget.value);
50 | app.graph.setDirtyCanvas(true);
51 | };
52 | }
53 | },
54 | });
55 |
--------------------------------------------------------------------------------
/web/previewtext.js:
--------------------------------------------------------------------------------
1 | // this code from ComfyUI_Custom_Nodes_AlekPet
2 |
3 | import { app } from "/scripts/app.js";
4 | import { ComfyWidgets } from "/scripts/widgets.js";
5 |
6 | app.registerExtension({
7 | name: "Comfy.BMAB.PreviewText",
8 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
9 | if (nodeData.name === "BMAB Preview Text") {
10 | const onNodeCreated = nodeType.prototype.onNodeCreated;
11 | nodeType.prototype.onNodeCreated = function () {
12 | const ret = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
13 |
14 | let BMABPreviewText = app.graph._nodes.filter((wi) => wi.type == nodeData.name),
15 | nodeName = `${nodeData.name}_${BMABPreviewText.length}`;
16 | console.log(`Create ${nodeData.name}: ${nodeName}`);
17 | const wi = ComfyWidgets.STRING(this, nodeName,
18 | ["STRING", {default: "", placeholder: "Text message output...", multiline: true,},], app);
19 | wi.widget.inputEl.readOnly = true;
20 | return ret;
21 | };
22 |
23 | // Function set value
24 | const outSet = function (texts) {
25 | if (texts.length > 0) {
26 | let widget_id = this?.widgets.findIndex((w) => w.type == "customtext");
27 | if (Array.isArray(texts))
28 | texts = texts.filter((word) => word.trim() !== "").map((word) => word.trim()).join(" ");
29 | this.widgets[widget_id].value = texts;
30 | app.graph.setDirtyCanvas(true);
31 | }
32 | };
33 |
34 | // onExecuted
35 | const onExecuted = nodeType.prototype.onExecuted;
36 | nodeType.prototype.onExecuted = function (texts) {
37 | onExecuted?.apply(this, arguments);
38 | outSet.call(this, texts?.string);
39 | };
40 |
41 | // onConfigure
42 | const onConfigure = nodeType.prototype.onConfigure;
43 | nodeType.prototype.onConfigure = function (w) {
44 | onConfigure?.apply(this, arguments);
45 | if (w?.widgets_values?.length) {
46 | outSet.call(this, w.widgets_values);
47 | }
48 | };
49 | }
50 | },
51 | });
52 |
--------------------------------------------------------------------------------
/web/remoteaccess.js:
--------------------------------------------------------------------------------
1 | import { app } from "/scripts/app.js";
2 | import { ComfyWidgets } from "/scripts/widgets.js";
3 | import { api } from "../../scripts/api.js";
4 |
5 | app.registerExtension({
6 | name: "Comfy.BMAB.BMABRemoteAccessAndSave",
7 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
8 | if (nodeData.name === "BMAB Remote Access And Save") {
9 |
10 | function register_client_id(name) {
11 | fetch(api.apiURL(`bmab?remote_client_id=${api.clientId}&remote_name=${name}`))
12 | .then(response => response.json())
13 | .then(data => {
14 | console.log(data);
15 | })
16 | .catch(error => {
17 | console.error(error);
18 | });
19 | }
20 |
21 | const onNodeCreated = nodeType.prototype.onNodeCreated;
22 | nodeType.prototype.onNodeCreated = function () {
23 | onNodeCreated?.apply(this, arguments);
24 | const remote_name = this.widgets.find((w) => w.name === "remote_name");
25 | register_client_id(remote_name.value);
26 |
27 | remote_name.callback = function () {
28 | register_client_id(this.value);
29 | }
30 | };
31 |
32 | const onReconnect = nodeType.prototype.onReconnect;
33 | nodeType.prototype.onReconnect = function () {
34 | const remote_name = this.widgets.find((w) => w.name === "remote_name");
35 | register_client_id(remote_name.value);
36 | };
37 |
38 | const onConfigure = nodeType.prototype.onConfigure;
39 | nodeType.prototype.onConfigure = function (w) {
40 | const remote_name = this.widgets.find((w) => w.name === "remote_name");
41 | register_client_id(remote_name.value);
42 | };
43 |
44 | const onBMABQueue = nodeType.prototype.onBMABQueue;
45 | nodeType.prototype.onBMABQueue = function () {
46 | console.log('QUEUE prompt')
47 | app.queuePrompt(0, 1)
48 | };
49 |
50 | api.addEventListener("reconnected", ({ detail }) => {
51 | app.graph._nodes.forEach((node) => {
52 | if (node.onReconnect)
53 | node.onReconnect()
54 | })
55 | });
56 |
57 | api.addEventListener("bmab_queue", ({ detail }) => {
58 | app.graph._nodes.forEach((node) => {
59 | if (node.onBMABQueue)
60 | node.onBMABQueue()
61 | })
62 | });
63 | }
64 | },
65 | });
66 |
--------------------------------------------------------------------------------