├── .github └── workflows │ └── publish.yml ├── LICENSE.txt ├── README.md ├── __init__.py ├── bmab ├── __init__.py ├── external │ ├── __init__.py │ ├── advanced_clip │ │ └── __init__.py │ ├── fill │ │ ├── __init__.py │ │ ├── controlnet_union.py │ │ └── pipeline_fill_sd_xl.py │ ├── lama │ │ ├── __init__.py │ │ ├── config.yaml │ │ └── saicinpainting │ │ │ ├── __init__.py │ │ │ ├── training │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ │ ├── __init__.py │ │ │ │ └── masks.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── adversarial.py │ │ │ │ ├── constants.py │ │ │ │ ├── distance_weighting.py │ │ │ │ ├── feature_matching.py │ │ │ │ ├── perceptual.py │ │ │ │ ├── segmentation.py │ │ │ │ └── style_loss.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── depthwise_sep_conv.py │ │ │ │ ├── fake_fakes.py │ │ │ │ ├── ffc.py │ │ │ │ ├── multidilated_conv.py │ │ │ │ ├── multiscale.py │ │ │ │ ├── pix2pixhd.py │ │ │ │ ├── spatial_transform.py │ │ │ │ └── squeeze_excitation.py │ │ │ ├── trainers │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ └── default.py │ │ │ └── visualizers │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── colors.py │ │ │ │ ├── directory.py │ │ │ │ └── noop.py │ │ │ └── utils.py │ └── rmbg14 │ │ ├── MyConfig.py │ │ ├── __init__.py │ │ ├── briarmbg.py │ │ └── utilities.py ├── nodes │ ├── __init__.py │ ├── a1111api.py │ ├── basic.py │ ├── binder.py │ ├── cnloader.py │ ├── detailers.py │ ├── fill.py │ ├── imaging.py │ ├── loaders.py │ ├── resize.py │ ├── sampler.py │ ├── toy.py │ ├── upscaler.py │ ├── utilnode.py │ └── watermark.py ├── process.py ├── serverext.py └── utils │ ├── __init__.py │ ├── color.py │ ├── colorname.py │ ├── grdino.py │ ├── sam.py │ └── yolo.py ├── models └── put_models_here ├── pyproject.toml ├── requirements.txt ├── resources ├── cache │ └── _cachefiles_will_be_put_here ├── examples │ ├── bmab-flux-sample.json │ ├── example.json │ ├── hand-detail-example.json │ ├── ic-light-example.json │ └── openpose-hand-detailing-example.json └── wildcard │ └── put_wildcard_here └── web ├── gemini.js ├── loadoutputimage.js ├── previewtext.js └── remoteaccess.js /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | jobs: 11 | publish-node: 12 | name: Publish Custom Node to registry 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v4 17 | - name: Publish Custom Node 18 | uses: Comfy-Org/publish-node-action@main 19 | with: 20 | ## Add your own personal access token to your Github Repository secrets and reference it here. 21 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # comfyui_bmab 2 | 3 | BMAB is an custom nodes of ComfyUI and has the function of post-processing the generated image according to settings. 4 | If necessary, you can find and redraw people, faces, and hands, or perform functions such as resize, resample, and add noise. 5 | You can composite two images or perform the Upscale function. 6 | 7 | 8 | 9 | You can download sample.json. 10 | 11 | https://github.com/portu-sim/comfyui_bmab/blob/main/resources/examples/example.json 12 | 13 | # Flux 14 | 15 | BMAB now supports Flux 1. 
16 | 
17 | 
18 | https://github.com/portu-sim/comfyui_bmab/blob/main/resources/examples/bmab-flux-sample.json
19 | 
20 | 
21 | 
22 | # New Nodes
23 | 
24 | ## BMAB Inpaint
25 | 
26 | image
27 | 
28 | ## BMAB Reframe
29 | 
30 | image
31 | 
32 | ## BMAB Outpaint By Ratio
33 | 
34 | image
35 | 
36 | 
37 | 
38 | ### Gallery
39 | 
40 | [instagram](https://www.instagram.com/portu.sim/)
41 | [facebook](https://www.facebook.com/portusimkr)
42 | 
43 | ### Hand Detailing Sample
44 | 
45 | BMAB detects and enlarges the upper body of a person and performs Openpose at high resolution to fix incorrectly drawn hands.
46 | 
47 | 
48 | 
49 | 
50 | # Installation
51 | 
52 | You can easily install comfyui_bmab using ComfyUI-Manager.
53 | You will need to install a total of three custom nodes.
54 | 
55 | * comfyui_bmab
56 | * comfyui_controlnet_aux
57 |   * https://github.com/Fannovel16/comfyui_controlnet_aux.git
58 |   * Fannovel16, thanks for the excellent code.
59 | * ComfyUI_IPAdapter_plus
60 |   * https://github.com/cubiq/ComfyUI_IPAdapter_plus.git
61 |   * cubiq, thanks for the excellent code.
62 | 
63 | 
64 | ### Grounding DINO Installation
65 | 
66 | Transformers v4.40.0 includes a Grounding DINO implementation.
67 | https://github.com/huggingface/transformers/releases/tag/v4.40.0
68 | BMAB now uses Transformers for object detection.
69 | No installation required.
70 | 
71 | ## Install Manually
72 | 
73 | I can't cover every Python environment here.
74 | I will write the installation instructions assuming you have some knowledge of Python.
75 | 
76 | 
77 | ### Windows Portable User
78 | 
79 | ```commandline
80 | cd ComfyUI/custom_nodes
81 | git clone https://github.com/portu-sim/comfyui_bmab.git
82 | cd comfyui_bmab
83 | python_embeded\python.exe -m pip install -r requirements.txt
84 | cd ..
85 | ```
86 | 
87 | You will need to install two additional custom nodes required by comfyui_bmab.
88 | 
89 | ```commandline
90 | cd ComfyUI/custom_nodes
91 | git clone https://github.com/Fannovel16/comfyui_controlnet_aux.git
92 | cd comfyui_controlnet_aux
93 | python_embeded\python.exe -m pip install -r requirements.txt
94 | cd ..
95 | git clone https://github.com/cubiq/ComfyUI_IPAdapter_plus.git
96 | cd ComfyUI_IPAdapter_plus
97 | python_embeded\python.exe -m pip install -r requirements.txt
98 | cd ..
99 | ```
100 | 
101 | ### Other Python environment
102 | 
103 | ```commandline
104 | cd ComfyUI/custom_nodes
105 | git clone https://github.com/portu-sim/comfyui_bmab.git
106 | cd comfyui_bmab
107 | pip install -r requirements.txt
108 | cd ..
109 | ```
110 | 
111 | You will need to install two additional custom nodes required by comfyui_bmab.
112 | 
113 | ```commandline
114 | cd ComfyUI/custom_nodes
115 | git clone https://github.com/Fannovel16/comfyui_controlnet_aux.git
116 | cd comfyui_controlnet_aux
117 | pip install -r requirements.txt
118 | cd ..
119 | git clone https://github.com/cubiq/ComfyUI_IPAdapter_plus.git
120 | cd ComfyUI_IPAdapter_plus
121 | pip install -r requirements.txt
122 | cd ..
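# Optional check: the Grounding DINO detector described in the section above comes from
# the transformers package (v4.40.0 or newer) and needs no separate install step.
# This hypothetical one-liner only prints the version found in the current environment:
python -c "import transformers; print(transformers.__version__)"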
123 | ``` 124 | 125 | 126 | Run ComfyUI 127 | 128 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pillow_avif 4 | 5 | sys.path.append(os.path.join(os.path.dirname(__file__))) 6 | from bmab import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 7 | 8 | try: 9 | import testnodes 10 | print('Register test nodes.') 11 | testnodes.register(NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS) 12 | except Exception as e: 13 | print('Not found test nodes.') 14 | print(e) 15 | 16 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 17 | 18 | WEB_DIRECTORY = f'./web' 19 | 20 | -------------------------------------------------------------------------------- /bmab/__init__.py: -------------------------------------------------------------------------------- 1 | from bmab import nodes, serverext 2 | 3 | 4 | NODE_CLASS_MAPPINGS = { 5 | # Basic 6 | 'BMAB Basic': nodes.BMABBasic, 7 | 'BMAB Edge': nodes.BMABEdge, 8 | 'BMAB Text': nodes.BMABText, 9 | 'BMAB Preview Text': nodes.BMABPreviewText, 10 | 11 | # Resize 12 | 'BMAB Resize By Person': nodes.BMABResizeByPerson, 13 | 'BMAB Resize By Ratio': nodes.BMABResizeByRatio, 14 | 'BMAB Resize and Fill': nodes.BMABResizeAndFill, 15 | 'BMAB Crop': nodes.BMABCrop, 16 | 'BMAB Zoom Out': nodes.BMABZoomOut, 17 | 'BMAB Square': nodes.BMABSquare, 18 | 19 | # Sampler 20 | 'BMAB Integrator': nodes.BMABIntegrator, 21 | 'BMAB ToBind': nodes.BMABToBind, 22 | 'BMAB Flux Integrator': nodes.BMABFluxIntegrator, 23 | 'BMAB Extractor': nodes.BMABExtractor, 24 | 'BMAB SeedGenerator': nodes.BMABSeedGenerator, 25 | 'BMAB KSampler': nodes.BMABKSampler, 26 | 'BMAB KSamplerHiresFix': nodes.BMABKSamplerHiresFix, 27 | 'BMAB KSamplerHiresFixWithUpscaler': nodes.BMABKSamplerHiresFixWithUpscaler, 28 | 'BMAB Context': nodes.BMABContextNode, 29 | 'BMAB Import Integrator': nodes.BMABImportIntegrator, 30 | 'BMAB KSamplerKohyaDeepShrink': nodes.BMABKSamplerKohyaDeepShrink, 31 | 'BMAB Clip Text Encoder SDXL': nodes.BMABClipTextEncoderSDXL, 32 | 33 | # Detailer 34 | 'BMAB Face Detailer': nodes.BMABFaceDetailer, 35 | 'BMAB Person Detailer': nodes.BMABPersonDetailer, 36 | 'BMAB Simple Hand Detailer': nodes.BMABSimpleHandDetailer, 37 | 'BMAB Subframe Hand Detailer': nodes.BMABSubframeHandDetailer, 38 | 'BMAB Openpose Hand Detailer': nodes.BMABOpenposeHandDetailer, 39 | 'BMAB Detail Anything': nodes.BMABDetailAnything, 40 | 41 | # Control Net 42 | 'BMAB ControlNet': nodes.BMABControlNet, 43 | 'BMAB Flux ControlNet': nodes.BMABFluxControlNet, 44 | 'BMAB ControlNet Openpose': nodes.BMABControlNetOpenpose, 45 | 'BMAB ControlNet IPAdapter': nodes.BMABControlNetIPAdapter, 46 | 47 | # Imaging 48 | 'BMAB Detection Crop': nodes.BMABDetectionCrop, 49 | 'BMAB Remove Background': nodes.BMABRemoveBackground, 50 | 'BMAB Alpha Composit': nodes.BMABAlphaComposit, 51 | 'BMAB Blend': nodes.BMABBlend, 52 | 'BMAB Detect And Mask': nodes.BMABDetectAndMask, 53 | 'BMAB Lama Inpaint': nodes.BMABLamaInpaint, 54 | 'BMAB Detector': nodes.BMABDetector, 55 | 'BMAB Segment Anything': nodes.BMABSegmentAnything, 56 | 'BMAB Masks To Images': nodes.BMABMasksToImages, 57 | 'BMAB Load Image': nodes.BMABLoadImage, 58 | 'BMAB Load Output Image': nodes.BMABLoadOutputImage, 59 | 'BMAB Black And White': nodes.BMABBlackAndWhite, 60 | 'BMAB Detect And Paste': nodes.BMABDetectAndPaste, 61 | 62 | # SD-WebUI API 63 | 'BMAB SD-WebUI API Server': nodes.BMABApiServer, 64 | 
'BMAB SD-WebUI API T2I': nodes.BMABApiSDWebUIT2I, 65 | 'BMAB SD-WebUI API I2I': nodes.BMABApiSDWebUII2I, 66 | 'BMAB SD-WebUI API T2I Hires.Fix': nodes.BMABApiSDWebUIT2IHiresFix, 67 | 'BMAB SD-WebUI API BMAB Extension': nodes.BMABApiSDWebUIBMABExtension, 68 | 'BMAB SD-WebUI API ControlNet': nodes.BMABApiSDWebUIControlNet, 69 | 70 | # UTIL Nodes 71 | 'BMAB Model To Bind': nodes.BMABModelToBind, 72 | 'BMAB Conditioning To Bind': nodes.BMABConditioningToBind, 73 | 'BMAB Noise Generator': nodes.BMABNoiseGenerator, 74 | 'BMAB Base64 Image': nodes.BMABBase64Image, 75 | 'BMAB Image Storage': nodes.BMABImageStorage, 76 | 'BMAB Normalize Size': nodes.BMABNormalizeSize, 77 | 'BMAB Dummy': nodes.BMABDummy, 78 | 79 | # Watermark 80 | 'BMAB Watermark': nodes.BMABWatermark, 81 | 82 | 'BMAB Upscaler': nodes.BMABUpscale, 83 | 'BMAB Save Image': nodes.BMABSaveImage, 84 | 'BMAB Remote Access And Save': nodes.BMABRemoteAccessAndSave, 85 | 'BMAB Upscale With Model': nodes.BMABUpscaleWithModel, 86 | 'BMAB LoRA Loader': nodes.BMABLoraLoader, 87 | 'BMAB Prompt': nodes.BMABPrompt, 88 | 'BMAB Google Gemini Prompt': nodes.BMABGoogleGemini, 89 | 90 | # Fill 91 | 'BMAB Reframe': nodes.BMABReframe, 92 | 'BMAB Outpaint By Ratio': nodes.BMABOutpaintByRatio, 93 | 'BMAB Inpaint': nodes.BMABInpaint 94 | } 95 | 96 | NODE_DISPLAY_NAME_MAPPINGS = { 97 | # Preview 98 | 'BMAB Basic': 'BMAB Basic', 99 | 'BMAB Edge': 'BMAB Edge', 100 | 'BMAB Text': 'BMAB Text', 101 | 'BMAB Preview Text': 'BMAB Preview Text', 102 | 103 | # Resize 104 | 'BMAB Resize By Person': 'BMAB Resize By Person', 105 | 'BMAB Resize By Ratio': 'BMAB Resize By Ratio', 106 | 'BMAB Resize and Fill': 'BMAB Resize And Fill', 107 | 'BMAB Crop': 'BMAB Crop', 108 | 'BMAB Zoom Out': 'BMAB Zoom Out', 109 | 'BMAB Square': 'BMAB Square', 110 | 111 | # Sampler 112 | 'BMAB Integrator': 'BMAB Integrator', 113 | 'BMAB ToBind': 'BMAB ToBind', 114 | 'BMAB Flux Integrator': 'BMAB Flux Integrator', 115 | 'BMAB KSampler': 'BMAB KSampler', 116 | 'BMAB KSamplerHiresFix': 'BMAB KSampler Hires. Fix', 117 | 'BMAB KSamplerHiresFixWithUpscaler': 'BMAB KSampler Hires. 
Fix With Upscaler', 118 | 'BMAB Extractor': 'BMAB Extractor', 119 | 'BMAB SeedGenerator': 'BMAB Seed Generator', 120 | 'BMAB Context': 'BMAB Context', 121 | 'BMAB Import Integrator': 'BMAB Import Integrator', 122 | 'BMAB KSamplerKohyaDeepShrink': 'BMAB KSampler with Kohya Deep Shrink', 123 | 'BMAB Clip Text Encoder SDXL': 'BMAB Clip Text Encoder SDXL', 124 | 125 | # Detailer 126 | 'BMAB Face Detailer': 'BMAB Face Detailer', 127 | 'BMAB Person Detailer': 'BMAB Person Detailer', 128 | 'BMAB Simple Hand Detailer': 'BMAB Simple Hand Detailer', 129 | 'BMAB Subframe Hand Detailer': 'BMAB Subframe Hand Detailer', 130 | 'BMAB Openpose Hand Detailer': 'BMAB Openpose Hand Detailer', 131 | 'BMAB Detail Anything': 'BMAB Detail Anything', 132 | 133 | # Control Net 134 | 'BMAB ControlNet': 'BMAB ControlNet', 135 | 'BMAB Flux ControlNet': 'BMAB Flux ControlNet', 136 | 'BMAB ControlNet Openpose': 'BMAB ControlNet Openpose', 137 | 'BMAB ControlNet IPAdapter': 'BMAB ControlNet IPAdapter', 138 | 139 | # Imaging 140 | 'BMAB Detection Crop': 'BMAB Detection Crop', 141 | 'BMAB Remove Background': 'BMAB Remove Background', 142 | 'BMAB Alpha Composit': 'BMAB Alpha Composit', 143 | 'BMAB Blend': 'BMAB Blend', 144 | 'BMAB Detect And Mask': 'BMAB Detect And Mask', 145 | 'BMAB Lama Inpaint': 'BMAB Lama Inpaint', 146 | 'BMAB Detector': 'BMAB Detector', 147 | 'BMAB Segment Anything': 'BMAB Segment Anything', 148 | 'BMAB Masks To Images': 'BMAB Masks To Images', 149 | 'BMAB Load Image': 'BMAB Load Image', 150 | 'BMAB Load Output Image': 'BMAB Load Output Image', 151 | 'BMAB Black And White': 'BMAB Black And White', 152 | 'BMAB Detect And Paste': 'BMAB Detect And Paste', 153 | 154 | # SD-WebUI API 155 | 'BMAB SD-WebUI API Server': 'BMAB SD-WebUI API Server', 156 | 'BMAB SD-WebUI API T2I': 'BMAB SD-WebUI API T2I', 157 | 'BMAB SD-WebUI API I2I': 'BMAB SD-WebUI API I2I', 158 | 'BMAB SD-WebUI API T2I Hires.Fix': 'BMAB SD-WebUI API T2I Hires.Fix', 159 | 'BMAB SD-WebUI API BMAB Extension': 'BMAB SD-WebUI API BMAB Extension', 160 | 'BMAB SD-WebUI API ControlNet': 'BMAB SD-WebUI API ControlNet', 161 | 162 | # UTIL Nodes 163 | 'BMAB Model To Bind': 'BMAB Model To Bind', 164 | 'BMAB Conditioning To Bind': 'BMAB Conditioning To Bind', 165 | 'BMAB Noise Generator': 'BMAB Noise Generator', 166 | 'BMAB Base64 Image': 'BMAB Base64 Image', 167 | 'BMAB Image Storage': 'BMAB Image Storage', 168 | 'BMAB Normalize Size': 'BMAB Normalize Size', 169 | 'BMAB Dummy': 'BMAB Dummy', 170 | 171 | # Watermark 172 | 'BMAB Watermark': 'BMAB Watermark', 173 | 174 | 'BMAB DinoSam': 'BMAB DinoSam', 175 | 'BMAB Upscaler': 'BMAB Upscaler', 176 | 'BMAB Control Net': 'BMAB ControlNet', 177 | 'BMAB Save Image': 'BMAB Save Image', 178 | 'BMAB Remote Access And Save': 'BMAB Remote Access And Save', 179 | 'BMAB Upscale With Model': 'BMAB Upscale With Model', 180 | 'BMAB LoRA Loader': 'BMAB Lora Loader', 181 | 'BMAB Prompt': 'BMAB Prompt', 182 | 'BMAB Google Gemini Prompt': 'BMAB Google Gemini API', 183 | 184 | # Fill 185 | 'BMAB Reframe': 'BMAB Reframe', 186 | 'BMAB Outpaint By Ratio': 'BMAB Outpaint By Ratio', 187 | 'BMAB Inpaint': 'BMAB Inpaint', 188 | } 189 | 190 | -------------------------------------------------------------------------------- /bmab/external/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/__init__.py -------------------------------------------------------------------------------- 
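The `NODE_CLASS_MAPPINGS` and `NODE_DISPLAY_NAME_MAPPINGS` exported by `bmab/__init__.py` above are the only registration hook ComfyUI needs: at startup it imports each package under `custom_nodes` and merges these two dictionaries. For reference, here is a minimal, hypothetical sketch (not taken from this repository) of the interface a class listed in `NODE_CLASS_MAPPINGS` is expected to implement:

```python
# Hypothetical minimal node, for illustration only -- not part of comfyui_bmab.
# It shows the class attributes ComfyUI looks for on entries of NODE_CLASS_MAPPINGS.
class BMABExampleNode:
    @classmethod
    def INPUT_TYPES(cls):
        # Declares the sockets/widgets the node exposes in the graph editor.
        return {
            'required': {
                'image': ('IMAGE',),
                'strength': ('FLOAT', {'default': 1.0, 'min': 0.0, 'max': 2.0, 'step': 0.05}),
            }
        }

    RETURN_TYPES = ('IMAGE',)   # output socket types
    FUNCTION = 'process'        # method ComfyUI calls when the node runs
    CATEGORY = 'BMAB/example'   # menu placement

    def process(self, image, strength):
        # A real node would transform the image tensor here; this one passes it through.
        return (image,)


NODE_CLASS_MAPPINGS = {'BMAB Example': BMABExampleNode}
NODE_DISPLAY_NAME_MAPPINGS = {'BMAB Example': 'BMAB Example'}
```

The display-name table only affects how a node is labelled in the UI, which is why most entries in the file above simply repeat the class-mapping key.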
/bmab/external/fill/__init__.py: -------------------------------------------------------------------------------- 1 | # from https://huggingface.co/OzzyGT 2 | # from https://huggingface.co/fffiloni 3 | -------------------------------------------------------------------------------- /bmab/external/lama/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/advimman/lama 2 | # https://github.com/Mikubill/sd-webui-controlnet 3 | 4 | import os 5 | import gc 6 | import sys 7 | import cv2 8 | import yaml 9 | import torch 10 | import numpy as np 11 | from PIL import Image 12 | from omegaconf import OmegaConf 13 | from einops import rearrange 14 | 15 | from bmab.external.lama.saicinpainting.training.trainers import load_checkpoint 16 | from bmab import utils 17 | 18 | 19 | def lama_inpainting(image, mask, device='gpu'): 20 | mask_input = mask 21 | lama = LamaInpainting() 22 | if device == 'cpu': 23 | lama.device = 'cpu' 24 | 25 | if device == 'mps': 26 | image = lama(image, mask_input) 27 | elif lama.device.startswith('cuda') and image.width != image.height: 28 | width, height = image.size 29 | mx = max(image.width, image.height) 30 | resized = Image.new('RGB', (mx, mx), 0) 31 | mask = Image.new('L', (mx, mx), 0) 32 | if height < width: 33 | y0 = (mx - height) // 2 34 | resized.paste(image, (0, y0)) 35 | mask.paste(mask_input, (0, y0)) 36 | l = lama(resized, mask) 37 | image = l.crop((0, y0, width, y0 + height)) 38 | else: 39 | x0 = (mx - width) // 2 40 | resized.paste(image, (x0, 0)) 41 | mask.paste(mask_input, (x0, 0)) 42 | l = lama(resized, mask) 43 | image = l.crop((x0, 0, x0 + width, height)) 44 | else: 45 | image = lama(image, mask_input) 46 | del lama 47 | gc.collect() 48 | if torch.cuda.is_available(): 49 | torch.cuda.empty_cache() 50 | torch.cuda.ipc_collect() 51 | return image 52 | 53 | 54 | class LamaInpainting: 55 | 56 | def __init__(self): 57 | if sys.platform == 'darwin': 58 | self.device = 'mps' 59 | elif torch.cuda.is_available(): 60 | self.device = 'cuda' 61 | else: 62 | self.device = 'cpu' 63 | self.model = None 64 | 65 | @staticmethod 66 | def load_image(pilimg, mode='RGB'): 67 | img = np.array(pilimg.convert(mode)) 68 | if img.ndim == 3: 69 | print('transpose') 70 | img = np.transpose(img, (2, 0, 1)) 71 | out_img = img.astype('float32') / 255 72 | return out_img 73 | 74 | def load_model(self): 75 | config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml') 76 | cfg = yaml.safe_load(open(config_path, 'rt')) 77 | cfg = OmegaConf.create(cfg) 78 | cfg.training_model.predict_only = True 79 | cfg.visualizer.kind = 'noop' 80 | lamapth = utils.lazy_loader('ControlNetLama.pth') 81 | self.model = load_checkpoint(cfg, lamapth, strict=False, map_location='cpu') 82 | self.model = self.model.to(self.device) 83 | self.model.eval() 84 | 85 | def unload_model(self): 86 | if self.model is not None: 87 | self.model.cpu() 88 | 89 | def __call__(self, image, mask): 90 | if self.model is None: 91 | self.load_model() 92 | self.model.to(self.device) 93 | 94 | opencv_image = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)[:, :, 0:3] 95 | opencv_mask = cv2.cvtColor(np.array(mask.convert('RGB')), cv2.COLOR_RGB2BGR)[:, :, 0:1] 96 | color = np.ascontiguousarray(opencv_image).astype(np.float32) / 255.0 97 | mask = np.ascontiguousarray(opencv_mask).astype(np.float32) / 255.0 98 | 99 | with torch.no_grad(): 100 | color = torch.from_numpy(color).float().to(self.device) 101 | mask = 
torch.from_numpy(mask).float().to(self.device) 102 | mask = (mask > 0.5).float() 103 | color = color * (1 - mask) 104 | image_feed = torch.cat([color, mask], dim=2) 105 | image_feed = rearrange(image_feed, 'h w c -> 1 c h w') 106 | result = self.model(image_feed)[0] 107 | result = rearrange(result, 'c h w -> h w c') 108 | result = result * mask + color * (1 - mask) 109 | result *= 255.0 110 | 111 | img = result.detach().cpu().numpy().clip(0, 255).astype(np.uint8) 112 | color_coverted = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 113 | pil_image = Image.fromarray(color_coverted) 114 | return pil_image 115 | -------------------------------------------------------------------------------- /bmab/external/lama/config.yaml: -------------------------------------------------------------------------------- 1 | run_title: b18_ffc075_batch8x15 2 | training_model: 3 | kind: default 4 | visualize_each_iters: 1000 5 | concat_mask: true 6 | store_discr_outputs_for_vis: true 7 | losses: 8 | l1: 9 | weight_missing: 0 10 | weight_known: 10 11 | perceptual: 12 | weight: 0 13 | adversarial: 14 | kind: r1 15 | weight: 10 16 | gp_coef: 0.001 17 | mask_as_fake_target: true 18 | allow_scale_mask: true 19 | feature_matching: 20 | weight: 100 21 | resnet_pl: 22 | weight: 30 23 | weights_path: ${env:TORCH_HOME} 24 | 25 | optimizers: 26 | generator: 27 | kind: adam 28 | lr: 0.001 29 | discriminator: 30 | kind: adam 31 | lr: 0.0001 32 | visualizer: 33 | key_order: 34 | - image 35 | - predicted_image 36 | - discr_output_fake 37 | - discr_output_real 38 | - inpainted 39 | rescale_keys: 40 | - discr_output_fake 41 | - discr_output_real 42 | kind: directory 43 | outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples 44 | location: 45 | data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large 46 | out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments 47 | tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs 48 | data: 49 | batch_size: 15 50 | val_batch_size: 2 51 | num_workers: 3 52 | train: 53 | indir: ${location.data_root_dir}/train 54 | out_size: 256 55 | mask_gen_kwargs: 56 | irregular_proba: 1 57 | irregular_kwargs: 58 | max_angle: 4 59 | max_len: 200 60 | max_width: 100 61 | max_times: 5 62 | min_times: 1 63 | box_proba: 1 64 | box_kwargs: 65 | margin: 10 66 | bbox_min_size: 30 67 | bbox_max_size: 150 68 | max_times: 3 69 | min_times: 1 70 | segm_proba: 0 71 | segm_kwargs: 72 | confidence_threshold: 0.5 73 | max_object_area: 0.5 74 | min_mask_area: 0.07 75 | downsample_levels: 6 76 | num_variants_per_mask: 1 77 | rigidness_mode: 1 78 | max_foreground_coverage: 0.3 79 | max_foreground_intersection: 0.7 80 | max_mask_intersection: 0.1 81 | max_hidden_area: 0.1 82 | max_scale_change: 0.25 83 | horizontal_flip: true 84 | max_vertical_shift: 0.2 85 | position_shuffle: true 86 | transform_variant: distortions 87 | dataloader_kwargs: 88 | batch_size: ${data.batch_size} 89 | shuffle: true 90 | num_workers: ${data.num_workers} 91 | val: 92 | indir: ${location.data_root_dir}/val 93 | img_suffix: .png 94 | dataloader_kwargs: 95 | batch_size: ${data.val_batch_size} 96 | shuffle: false 97 | num_workers: ${data.num_workers} 98 | visual_test: 99 | indir: ${location.data_root_dir}/korean_test 100 | img_suffix: _input.png 101 | pad_out_to_modulo: 32 102 | dataloader_kwargs: 103 | batch_size: 1 104 | 
shuffle: false 105 | num_workers: ${data.num_workers} 106 | generator: 107 | kind: ffc_resnet 108 | input_nc: 4 109 | output_nc: 3 110 | ngf: 64 111 | n_downsampling: 3 112 | n_blocks: 18 113 | add_out_act: sigmoid 114 | init_conv_kwargs: 115 | ratio_gin: 0 116 | ratio_gout: 0 117 | enable_lfu: false 118 | downsample_conv_kwargs: 119 | ratio_gin: ${generator.init_conv_kwargs.ratio_gout} 120 | ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin} 121 | enable_lfu: false 122 | resnet_conv_kwargs: 123 | ratio_gin: 0.75 124 | ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin} 125 | enable_lfu: false 126 | discriminator: 127 | kind: pix2pixhd_nlayer 128 | input_nc: 3 129 | ndf: 64 130 | n_layers: 4 131 | evaluator: 132 | kind: default 133 | inpainted_key: inpainted 134 | integral_kind: ssim_fid100_f1 135 | trainer: 136 | kwargs: 137 | gpus: -1 138 | accelerator: ddp 139 | max_epochs: 200 140 | gradient_clip_val: 1 141 | log_gpu_memory: None 142 | limit_train_batches: 25000 143 | val_check_interval: ${trainer.kwargs.limit_train_batches} 144 | log_every_n_steps: 1000 145 | precision: 32 146 | terminate_on_nan: false 147 | check_val_every_n_epoch: 1 148 | num_sanity_val_steps: 8 149 | limit_val_batches: 1000 150 | replace_sampler_ddp: false 151 | checkpoint_kwargs: 152 | verbose: true 153 | save_top_k: 5 154 | save_last: true 155 | period: 1 156 | monitor: val_ssim_fid100_f1_total_mean 157 | mode: max 158 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/lama/saicinpainting/__init__.py -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/lama/saicinpainting/training/__init__.py -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/lama/saicinpainting/training/data/__init__.py -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/lama/saicinpainting/training/losses/__init__.py -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/losses/adversarial.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class BaseAdversarialLoss: 9 | def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, 10 | generator: nn.Module, discriminator: nn.Module): 11 | """ 12 | Prepare for generator step 13 | :param real_batch: Tensor, a batch of real 
samples 14 | :param fake_batch: Tensor, a batch of samples produced by generator 15 | :param generator: 16 | :param discriminator: 17 | :return: None 18 | """ 19 | 20 | def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, 21 | generator: nn.Module, discriminator: nn.Module): 22 | """ 23 | Prepare for discriminator step 24 | :param real_batch: Tensor, a batch of real samples 25 | :param fake_batch: Tensor, a batch of samples produced by generator 26 | :param generator: 27 | :param discriminator: 28 | :return: None 29 | """ 30 | 31 | def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, 32 | discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, 33 | mask: Optional[torch.Tensor] = None) \ 34 | -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 35 | """ 36 | Calculate generator loss 37 | :param real_batch: Tensor, a batch of real samples 38 | :param fake_batch: Tensor, a batch of samples produced by generator 39 | :param discr_real_pred: Tensor, discriminator output for real_batch 40 | :param discr_fake_pred: Tensor, discriminator output for fake_batch 41 | :param mask: Tensor, actual mask, which was at input of generator when making fake_batch 42 | :return: total generator loss along with some values that might be interesting to log 43 | """ 44 | raise NotImplemented() 45 | 46 | def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, 47 | discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, 48 | mask: Optional[torch.Tensor] = None) \ 49 | -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 50 | """ 51 | Calculate discriminator loss and call .backward() on it 52 | :param real_batch: Tensor, a batch of real samples 53 | :param fake_batch: Tensor, a batch of samples produced by generator 54 | :param discr_real_pred: Tensor, discriminator output for real_batch 55 | :param discr_fake_pred: Tensor, discriminator output for fake_batch 56 | :param mask: Tensor, actual mask, which was at input of generator when making fake_batch 57 | :return: total discriminator loss along with some values that might be interesting to log 58 | """ 59 | raise NotImplemented() 60 | 61 | def interpolate_mask(self, mask, shape): 62 | assert mask is not None 63 | assert self.allow_scale_mask or shape == mask.shape[-2:] 64 | if shape != mask.shape[-2:] and self.allow_scale_mask: 65 | if self.mask_scale_mode == 'maxpool': 66 | mask = F.adaptive_max_pool2d(mask, shape) 67 | else: 68 | mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode) 69 | return mask 70 | 71 | def make_r1_gp(discr_real_pred, real_batch): 72 | if torch.is_grad_enabled(): 73 | grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0] 74 | grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean() 75 | else: 76 | grad_penalty = 0 77 | real_batch.requires_grad = False 78 | 79 | return grad_penalty 80 | 81 | class NonSaturatingWithR1(BaseAdversarialLoss): 82 | def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False, 83 | mask_scale_mode='nearest', extra_mask_weight_for_gen=0, 84 | use_unmasked_for_gen=True, use_unmasked_for_discr=True): 85 | self.gp_coef = gp_coef 86 | self.weight = weight 87 | # use for discr => use for gen; 88 | # otherwise we teach only the discr to pay attention to very small difference 89 | assert use_unmasked_for_gen or (not use_unmasked_for_discr) 90 | # mask as target => use unmasked for discr: 91 | # if we don't care about 
unmasked regions at all 92 | # then it doesn't matter if the value of mask_as_fake_target is true or false 93 | assert use_unmasked_for_discr or (not mask_as_fake_target) 94 | self.use_unmasked_for_gen = use_unmasked_for_gen 95 | self.use_unmasked_for_discr = use_unmasked_for_discr 96 | self.mask_as_fake_target = mask_as_fake_target 97 | self.allow_scale_mask = allow_scale_mask 98 | self.mask_scale_mode = mask_scale_mode 99 | self.extra_mask_weight_for_gen = extra_mask_weight_for_gen 100 | 101 | def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, 102 | discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, 103 | mask=None) \ 104 | -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 105 | fake_loss = F.softplus(-discr_fake_pred) 106 | if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \ 107 | not self.use_unmasked_for_gen: # == if masked region should be treated differently 108 | mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) 109 | if not self.use_unmasked_for_gen: 110 | fake_loss = fake_loss * mask 111 | else: 112 | pixel_weights = 1 + mask * self.extra_mask_weight_for_gen 113 | fake_loss = fake_loss * pixel_weights 114 | 115 | return fake_loss.mean() * self.weight, dict() 116 | 117 | def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, 118 | generator: nn.Module, discriminator: nn.Module): 119 | real_batch.requires_grad = True 120 | 121 | def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, 122 | discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, 123 | mask=None) \ 124 | -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 125 | 126 | real_loss = F.softplus(-discr_real_pred) 127 | grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef 128 | fake_loss = F.softplus(discr_fake_pred) 129 | 130 | if not self.use_unmasked_for_discr or self.mask_as_fake_target: 131 | # == if masked region should be treated differently 132 | mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) 133 | # use_unmasked_for_discr=False only makes sense for fakes; 134 | # for reals there is no difference beetween two regions 135 | fake_loss = fake_loss * mask 136 | if self.mask_as_fake_target: 137 | fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred) 138 | 139 | sum_discr_loss = real_loss + grad_penalty + fake_loss 140 | metrics = dict(discr_real_out=discr_real_pred.mean(), 141 | discr_fake_out=discr_fake_pred.mean(), 142 | discr_real_gp=grad_penalty) 143 | return sum_discr_loss.mean(), metrics 144 | 145 | class BCELoss(BaseAdversarialLoss): 146 | def __init__(self, weight): 147 | self.weight = weight 148 | self.bce_loss = nn.BCEWithLogitsLoss() 149 | 150 | def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 151 | real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device) 152 | fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight 153 | return fake_loss, dict() 154 | 155 | def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, 156 | generator: nn.Module, discriminator: nn.Module): 157 | real_batch.requires_grad = True 158 | 159 | def discriminator_loss(self, 160 | mask: torch.Tensor, 161 | discr_real_pred: torch.Tensor, 162 | discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 163 | 164 | real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device) 165 | sum_discr_loss = 
(self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2 166 | metrics = dict(discr_real_out=discr_real_pred.mean(), 167 | discr_fake_out=discr_fake_pred.mean(), 168 | discr_real_gp=0) 169 | return sum_discr_loss, metrics 170 | 171 | 172 | def make_discrim_loss(kind, **kwargs): 173 | if kind == 'r1': 174 | return NonSaturatingWithR1(**kwargs) 175 | elif kind == 'bce': 176 | return BCELoss(**kwargs) 177 | raise ValueError(f'Unknown adversarial loss kind {kind}') 178 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/losses/constants.py: -------------------------------------------------------------------------------- 1 | weights = {"ade20k": 2 | [6.34517766497462, 3 | 9.328358208955224, 4 | 11.389521640091116, 5 | 16.10305958132045, 6 | 20.833333333333332, 7 | 22.22222222222222, 8 | 25.125628140703515, 9 | 43.29004329004329, 10 | 50.5050505050505, 11 | 54.6448087431694, 12 | 55.24861878453038, 13 | 60.24096385542168, 14 | 62.5, 15 | 66.2251655629139, 16 | 84.74576271186442, 17 | 90.90909090909092, 18 | 91.74311926605505, 19 | 96.15384615384616, 20 | 96.15384615384616, 21 | 97.08737864077669, 22 | 102.04081632653062, 23 | 135.13513513513513, 24 | 149.2537313432836, 25 | 153.84615384615384, 26 | 163.93442622950818, 27 | 166.66666666666666, 28 | 188.67924528301887, 29 | 192.30769230769232, 30 | 217.3913043478261, 31 | 227.27272727272725, 32 | 227.27272727272725, 33 | 227.27272727272725, 34 | 303.03030303030306, 35 | 322.5806451612903, 36 | 333.3333333333333, 37 | 370.3703703703703, 38 | 384.61538461538464, 39 | 416.6666666666667, 40 | 416.6666666666667, 41 | 434.7826086956522, 42 | 434.7826086956522, 43 | 454.5454545454545, 44 | 454.5454545454545, 45 | 500.0, 46 | 526.3157894736842, 47 | 526.3157894736842, 48 | 555.5555555555555, 49 | 555.5555555555555, 50 | 555.5555555555555, 51 | 555.5555555555555, 52 | 555.5555555555555, 53 | 555.5555555555555, 54 | 555.5555555555555, 55 | 588.2352941176471, 56 | 588.2352941176471, 57 | 588.2352941176471, 58 | 588.2352941176471, 59 | 588.2352941176471, 60 | 666.6666666666666, 61 | 666.6666666666666, 62 | 666.6666666666666, 63 | 666.6666666666666, 64 | 714.2857142857143, 65 | 714.2857142857143, 66 | 714.2857142857143, 67 | 714.2857142857143, 68 | 714.2857142857143, 69 | 769.2307692307693, 70 | 769.2307692307693, 71 | 769.2307692307693, 72 | 833.3333333333334, 73 | 833.3333333333334, 74 | 833.3333333333334, 75 | 833.3333333333334, 76 | 909.090909090909, 77 | 1000.0, 78 | 1111.111111111111, 79 | 1111.111111111111, 80 | 1111.111111111111, 81 | 1111.111111111111, 82 | 1111.111111111111, 83 | 1250.0, 84 | 1250.0, 85 | 1250.0, 86 | 1250.0, 87 | 1250.0, 88 | 1428.5714285714287, 89 | 1428.5714285714287, 90 | 1428.5714285714287, 91 | 1428.5714285714287, 92 | 1428.5714285714287, 93 | 1428.5714285714287, 94 | 1428.5714285714287, 95 | 1666.6666666666667, 96 | 1666.6666666666667, 97 | 1666.6666666666667, 98 | 1666.6666666666667, 99 | 1666.6666666666667, 100 | 1666.6666666666667, 101 | 1666.6666666666667, 102 | 1666.6666666666667, 103 | 1666.6666666666667, 104 | 1666.6666666666667, 105 | 1666.6666666666667, 106 | 2000.0, 107 | 2000.0, 108 | 2000.0, 109 | 2000.0, 110 | 2000.0, 111 | 2000.0, 112 | 2000.0, 113 | 2000.0, 114 | 2000.0, 115 | 2000.0, 116 | 2000.0, 117 | 2000.0, 118 | 2000.0, 119 | 2000.0, 120 | 2000.0, 121 | 2000.0, 122 | 2000.0, 123 | 2500.0, 124 | 2500.0, 125 | 2500.0, 126 | 2500.0, 127 | 2500.0, 128 | 2500.0, 129 | 2500.0, 130 | 2500.0, 131 | 2500.0, 
132 | 2500.0, 133 | 2500.0, 134 | 2500.0, 135 | 2500.0, 136 | 3333.3333333333335, 137 | 3333.3333333333335, 138 | 3333.3333333333335, 139 | 3333.3333333333335, 140 | 3333.3333333333335, 141 | 3333.3333333333335, 142 | 3333.3333333333335, 143 | 3333.3333333333335, 144 | 3333.3333333333335, 145 | 3333.3333333333335, 146 | 3333.3333333333335, 147 | 3333.3333333333335, 148 | 3333.3333333333335, 149 | 5000.0, 150 | 5000.0, 151 | 5000.0] 152 | } -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/losses/distance_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | from bmab.external.lama.saicinpainting.training.losses.perceptual import IMAGENET_STD, IMAGENET_MEAN 7 | 8 | 9 | def dummy_distance_weighter(real_img, pred_img, mask): 10 | return mask 11 | 12 | 13 | def get_gauss_kernel(kernel_size, width_factor=1): 14 | coords = torch.stack(torch.meshgrid(torch.arange(kernel_size), 15 | torch.arange(kernel_size)), 16 | dim=0).float() 17 | diff = torch.exp(-((coords - kernel_size // 2) ** 2).sum(0) / kernel_size / width_factor) 18 | diff /= diff.sum() 19 | return diff 20 | 21 | 22 | class BlurMask(nn.Module): 23 | def __init__(self, kernel_size=5, width_factor=1): 24 | super().__init__() 25 | self.filter = nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, padding_mode='replicate', bias=False) 26 | self.filter.weight.data.copy_(get_gauss_kernel(kernel_size, width_factor=width_factor)) 27 | 28 | def forward(self, real_img, pred_img, mask): 29 | with torch.no_grad(): 30 | result = self.filter(mask) * mask 31 | return result 32 | 33 | 34 | class EmulatedEDTMask(nn.Module): 35 | def __init__(self, dilate_kernel_size=5, blur_kernel_size=5, width_factor=1): 36 | super().__init__() 37 | self.dilate_filter = nn.Conv2d(1, 1, dilate_kernel_size, padding=dilate_kernel_size// 2, padding_mode='replicate', 38 | bias=False) 39 | self.dilate_filter.weight.data.copy_(torch.ones(1, 1, dilate_kernel_size, dilate_kernel_size, dtype=torch.float)) 40 | self.blur_filter = nn.Conv2d(1, 1, blur_kernel_size, padding=blur_kernel_size // 2, padding_mode='replicate', bias=False) 41 | self.blur_filter.weight.data.copy_(get_gauss_kernel(blur_kernel_size, width_factor=width_factor)) 42 | 43 | def forward(self, real_img, pred_img, mask): 44 | with torch.no_grad(): 45 | known_mask = 1 - mask 46 | dilated_known_mask = (self.dilate_filter(known_mask) > 1).float() 47 | result = self.blur_filter(1 - dilated_known_mask) * mask 48 | return result 49 | 50 | 51 | class PropagatePerceptualSim(nn.Module): 52 | def __init__(self, level=2, max_iters=10, temperature=500, erode_mask_size=3): 53 | super().__init__() 54 | vgg = torchvision.models.vgg19(pretrained=True).features 55 | vgg_avg_pooling = [] 56 | 57 | for weights in vgg.parameters(): 58 | weights.requires_grad = False 59 | 60 | cur_level_i = 0 61 | for module in vgg.modules(): 62 | if module.__class__.__name__ == 'Sequential': 63 | continue 64 | elif module.__class__.__name__ == 'MaxPool2d': 65 | vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)) 66 | else: 67 | vgg_avg_pooling.append(module) 68 | if module.__class__.__name__ == 'ReLU': 69 | cur_level_i += 1 70 | if cur_level_i == level: 71 | break 72 | 73 | self.features = nn.Sequential(*vgg_avg_pooling) 74 | 75 | self.max_iters = max_iters 76 | self.temperature = temperature 77 | 
self.do_erode = erode_mask_size > 0 78 | if self.do_erode: 79 | self.erode_mask = nn.Conv2d(1, 1, erode_mask_size, padding=erode_mask_size // 2, bias=False) 80 | self.erode_mask.weight.data.fill_(1) 81 | 82 | def forward(self, real_img, pred_img, mask): 83 | with torch.no_grad(): 84 | real_img = (real_img - IMAGENET_MEAN.to(real_img)) / IMAGENET_STD.to(real_img) 85 | real_feats = self.features(real_img) 86 | 87 | vertical_sim = torch.exp(-(real_feats[:, :, 1:] - real_feats[:, :, :-1]).pow(2).sum(1, keepdim=True) 88 | / self.temperature) 89 | horizontal_sim = torch.exp(-(real_feats[:, :, :, 1:] - real_feats[:, :, :, :-1]).pow(2).sum(1, keepdim=True) 90 | / self.temperature) 91 | 92 | mask_scaled = F.interpolate(mask, size=real_feats.shape[-2:], mode='bilinear', align_corners=False) 93 | if self.do_erode: 94 | mask_scaled = (self.erode_mask(mask_scaled) > 1).float() 95 | 96 | cur_knowness = 1 - mask_scaled 97 | 98 | for iter_i in range(self.max_iters): 99 | new_top_knowness = F.pad(cur_knowness[:, :, :-1] * vertical_sim, (0, 0, 1, 0), mode='replicate') 100 | new_bottom_knowness = F.pad(cur_knowness[:, :, 1:] * vertical_sim, (0, 0, 0, 1), mode='replicate') 101 | 102 | new_left_knowness = F.pad(cur_knowness[:, :, :, :-1] * horizontal_sim, (1, 0, 0, 0), mode='replicate') 103 | new_right_knowness = F.pad(cur_knowness[:, :, :, 1:] * horizontal_sim, (0, 1, 0, 0), mode='replicate') 104 | 105 | new_knowness = torch.stack([new_top_knowness, new_bottom_knowness, 106 | new_left_knowness, new_right_knowness], 107 | dim=0).max(0).values 108 | 109 | cur_knowness = torch.max(cur_knowness, new_knowness) 110 | 111 | cur_knowness = F.interpolate(cur_knowness, size=mask.shape[-2:], mode='bilinear') 112 | result = torch.min(mask, 1 - cur_knowness) 113 | 114 | return result 115 | 116 | 117 | def make_mask_distance_weighter(kind='none', **kwargs): 118 | if kind == 'none': 119 | return dummy_distance_weighter 120 | if kind == 'blur': 121 | return BlurMask(**kwargs) 122 | if kind == 'edt': 123 | return EmulatedEDTMask(**kwargs) 124 | if kind == 'pps': 125 | return PropagatePerceptualSim(**kwargs) 126 | raise ValueError(f'Unknown mask distance weighter kind {kind}') 127 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/losses/feature_matching.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def masked_l2_loss(pred, target, mask, weight_known, weight_missing): 8 | per_pixel_l2 = F.mse_loss(pred, target, reduction='none') 9 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known 10 | return (pixel_weights * per_pixel_l2).mean() 11 | 12 | 13 | def masked_l1_loss(pred, target, mask, weight_known, weight_missing): 14 | per_pixel_l1 = F.l1_loss(pred, target, reduction='none') 15 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known 16 | return (pixel_weights * per_pixel_l1).mean() 17 | 18 | 19 | def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None): 20 | if mask is None: 21 | res = torch.stack([F.mse_loss(fake_feat, target_feat) 22 | for fake_feat, target_feat in zip(fake_features, target_features)]).mean() 23 | else: 24 | res = 0 25 | norm = 0 26 | for fake_feat, target_feat in zip(fake_features, target_features): 27 | cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False) 28 | error_weights = 1 
- cur_mask 29 | cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() 30 | res = res + cur_val 31 | norm += 1 32 | res = res / norm 33 | return res 34 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/losses/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | # from models.ade20k import ModelBuilder 7 | from bmab.external.lama.saicinpainting.utils import check_and_warn_input_range 8 | 9 | 10 | IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] 11 | IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] 12 | 13 | 14 | class PerceptualLoss(nn.Module): 15 | def __init__(self, normalize_inputs=True): 16 | super(PerceptualLoss, self).__init__() 17 | 18 | self.normalize_inputs = normalize_inputs 19 | self.mean_ = IMAGENET_MEAN 20 | self.std_ = IMAGENET_STD 21 | 22 | vgg = torchvision.models.vgg19(pretrained=True).features 23 | vgg_avg_pooling = [] 24 | 25 | for weights in vgg.parameters(): 26 | weights.requires_grad = False 27 | 28 | for module in vgg.modules(): 29 | if module.__class__.__name__ == 'Sequential': 30 | continue 31 | elif module.__class__.__name__ == 'MaxPool2d': 32 | vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)) 33 | else: 34 | vgg_avg_pooling.append(module) 35 | 36 | self.vgg = nn.Sequential(*vgg_avg_pooling) 37 | 38 | def do_normalize_inputs(self, x): 39 | return (x - self.mean_.to(x.device)) / self.std_.to(x.device) 40 | 41 | def partial_losses(self, input, target, mask=None): 42 | check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses') 43 | 44 | # we expect input and target to be in [0, 1] range 45 | losses = [] 46 | 47 | if self.normalize_inputs: 48 | features_input = self.do_normalize_inputs(input) 49 | features_target = self.do_normalize_inputs(target) 50 | else: 51 | features_input = input 52 | features_target = target 53 | 54 | for layer in self.vgg[:30]: 55 | 56 | features_input = layer(features_input) 57 | features_target = layer(features_target) 58 | 59 | if layer.__class__.__name__ == 'ReLU': 60 | loss = F.mse_loss(features_input, features_target, reduction='none') 61 | 62 | if mask is not None: 63 | cur_mask = F.interpolate(mask, size=features_input.shape[-2:], 64 | mode='bilinear', align_corners=False) 65 | loss = loss * (1 - cur_mask) 66 | 67 | loss = loss.mean(dim=tuple(range(1, len(loss.shape)))) 68 | losses.append(loss) 69 | 70 | return losses 71 | 72 | def forward(self, input, target, mask=None): 73 | losses = self.partial_losses(input, target, mask=mask) 74 | return torch.stack(losses).sum(dim=0) 75 | 76 | def get_global_features(self, input): 77 | check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features') 78 | 79 | if self.normalize_inputs: 80 | features_input = self.do_normalize_inputs(input) 81 | else: 82 | features_input = input 83 | 84 | features_input = self.vgg(features_input) 85 | return features_input 86 | 87 | 88 | class ResNetPL(nn.Module): 89 | def __init__(self, weight=1, 90 | weights_path=None, arch_encoder='resnet50dilated', segmentation=True): 91 | super().__init__() 92 | self.impl = ModelBuilder.get_encoder(weights_path=weights_path, 93 | arch_encoder=arch_encoder, 94 | arch_decoder='ppm_deepsup', 95 | fc_dim=2048, 96 | segmentation=segmentation) 97 | self.impl.eval() 98 | 
for w in self.impl.parameters(): 99 | w.requires_grad_(False) 100 | 101 | self.weight = weight 102 | 103 | def forward(self, pred, target): 104 | pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) 105 | target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) 106 | 107 | pred_feats = self.impl(pred, return_feature_maps=True) 108 | target_feats = self.impl(target, return_feature_maps=True) 109 | 110 | result = torch.stack([F.mse_loss(cur_pred, cur_target) 111 | for cur_pred, cur_target 112 | in zip(pred_feats, target_feats)]).sum() * self.weight 113 | return result 114 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .constants import weights as constant_weights 6 | 7 | 8 | class CrossEntropy2d(nn.Module): 9 | def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs): 10 | """ 11 | weight (Tensor, optional): a manual rescaling weight given to each class. 12 | If given, has to be a Tensor of size "nclasses" 13 | """ 14 | super(CrossEntropy2d, self).__init__() 15 | self.reduction = reduction 16 | self.ignore_label = ignore_label 17 | self.weights = weights 18 | if self.weights is not None: 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | self.weights = torch.FloatTensor(constant_weights[weights]).to(device) 21 | 22 | def forward(self, predict, target): 23 | """ 24 | Args: 25 | predict:(n, c, h, w) 26 | target:(n, 1, h, w) 27 | """ 28 | target = target.long() 29 | assert not target.requires_grad 30 | assert predict.dim() == 4, "{0}".format(predict.size()) 31 | assert target.dim() == 4, "{0}".format(target.size()) 32 | assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) 33 | assert target.size(1) == 1, "{0}".format(target.size(1)) 34 | assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2)) 35 | assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3)) 36 | target = target.squeeze(1) 37 | n, c, h, w = predict.size() 38 | target_mask = (target >= 0) * (target != self.ignore_label) 39 | target = target[target_mask] 40 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous() 41 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) 42 | loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction) 43 | return loss 44 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/losses/style_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | class PerceptualLoss(nn.Module): 7 | r""" 8 | Perceptual loss, VGG-based 9 | https://arxiv.org/abs/1603.08155 10 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 11 | """ 12 | 13 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): 14 | super(PerceptualLoss, self).__init__() 15 | self.add_module('vgg', VGG19()) 16 | self.criterion = torch.nn.L1Loss() 17 | self.weights = weights 18 | 19 | def __call__(self, x, y): 20 | # Compute features 21 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 22 | 23 | content_loss = 0.0 24 | 
content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) 25 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) 26 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) 27 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) 28 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) 29 | 30 | 31 | return content_loss 32 | 33 | 34 | class VGG19(torch.nn.Module): 35 | def __init__(self): 36 | super(VGG19, self).__init__() 37 | features = models.vgg19(pretrained=True).features 38 | self.relu1_1 = torch.nn.Sequential() 39 | self.relu1_2 = torch.nn.Sequential() 40 | 41 | self.relu2_1 = torch.nn.Sequential() 42 | self.relu2_2 = torch.nn.Sequential() 43 | 44 | self.relu3_1 = torch.nn.Sequential() 45 | self.relu3_2 = torch.nn.Sequential() 46 | self.relu3_3 = torch.nn.Sequential() 47 | self.relu3_4 = torch.nn.Sequential() 48 | 49 | self.relu4_1 = torch.nn.Sequential() 50 | self.relu4_2 = torch.nn.Sequential() 51 | self.relu4_3 = torch.nn.Sequential() 52 | self.relu4_4 = torch.nn.Sequential() 53 | 54 | self.relu5_1 = torch.nn.Sequential() 55 | self.relu5_2 = torch.nn.Sequential() 56 | self.relu5_3 = torch.nn.Sequential() 57 | self.relu5_4 = torch.nn.Sequential() 58 | 59 | for x in range(2): 60 | self.relu1_1.add_module(str(x), features[x]) 61 | 62 | for x in range(2, 4): 63 | self.relu1_2.add_module(str(x), features[x]) 64 | 65 | for x in range(4, 7): 66 | self.relu2_1.add_module(str(x), features[x]) 67 | 68 | for x in range(7, 9): 69 | self.relu2_2.add_module(str(x), features[x]) 70 | 71 | for x in range(9, 12): 72 | self.relu3_1.add_module(str(x), features[x]) 73 | 74 | for x in range(12, 14): 75 | self.relu3_2.add_module(str(x), features[x]) 76 | 77 | for x in range(14, 16): 78 | self.relu3_2.add_module(str(x), features[x]) 79 | 80 | for x in range(16, 18): 81 | self.relu3_4.add_module(str(x), features[x]) 82 | 83 | for x in range(18, 21): 84 | self.relu4_1.add_module(str(x), features[x]) 85 | 86 | for x in range(21, 23): 87 | self.relu4_2.add_module(str(x), features[x]) 88 | 89 | for x in range(23, 25): 90 | self.relu4_3.add_module(str(x), features[x]) 91 | 92 | for x in range(25, 27): 93 | self.relu4_4.add_module(str(x), features[x]) 94 | 95 | for x in range(27, 30): 96 | self.relu5_1.add_module(str(x), features[x]) 97 | 98 | for x in range(30, 32): 99 | self.relu5_2.add_module(str(x), features[x]) 100 | 101 | for x in range(32, 34): 102 | self.relu5_3.add_module(str(x), features[x]) 103 | 104 | for x in range(34, 36): 105 | self.relu5_4.add_module(str(x), features[x]) 106 | 107 | # don't need the gradients, just want the features 108 | for param in self.parameters(): 109 | param.requires_grad = False 110 | 111 | def forward(self, x): 112 | relu1_1 = self.relu1_1(x) 113 | relu1_2 = self.relu1_2(relu1_1) 114 | 115 | relu2_1 = self.relu2_1(relu1_2) 116 | relu2_2 = self.relu2_2(relu2_1) 117 | 118 | relu3_1 = self.relu3_1(relu2_2) 119 | relu3_2 = self.relu3_2(relu3_1) 120 | relu3_3 = self.relu3_3(relu3_2) 121 | relu3_4 = self.relu3_4(relu3_3) 122 | 123 | relu4_1 = self.relu4_1(relu3_4) 124 | relu4_2 = self.relu4_2(relu4_1) 125 | relu4_3 = self.relu4_3(relu4_2) 126 | relu4_4 = self.relu4_4(relu4_3) 127 | 128 | relu5_1 = self.relu5_1(relu4_4) 129 | relu5_2 = self.relu5_2(relu5_1) 130 | relu5_3 = self.relu5_3(relu5_2) 131 | relu5_4 = self.relu5_4(relu5_3) 132 | 133 | out = { 134 | 'relu1_1': relu1_1, 135 | 'relu1_2': relu1_2, 136 | 137 | 
'relu2_1': relu2_1, 138 | 'relu2_2': relu2_2, 139 | 140 | 'relu3_1': relu3_1, 141 | 'relu3_2': relu3_2, 142 | 'relu3_3': relu3_3, 143 | 'relu3_4': relu3_4, 144 | 145 | 'relu4_1': relu4_1, 146 | 'relu4_2': relu4_2, 147 | 'relu4_3': relu4_3, 148 | 'relu4_4': relu4_4, 149 | 150 | 'relu5_1': relu5_1, 151 | 'relu5_2': relu5_2, 152 | 'relu5_3': relu5_3, 153 | 'relu5_4': relu5_4, 154 | } 155 | return out 156 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from bmab.external.lama.saicinpainting.training.modules.ffc import FFCResNetGenerator 4 | from bmab.external.lama.saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \ 5 | NLayerDiscriminator, MultidilatedNLayerDiscriminator 6 | 7 | 8 | def make_generator(config, kind, **kwargs): 9 | logging.info(f'Make generator {kind}') 10 | 11 | if kind == 'pix2pixhd_multidilated': 12 | return MultiDilatedGlobalGenerator(**kwargs) 13 | 14 | if kind == 'pix2pixhd_global': 15 | return GlobalGenerator(**kwargs) 16 | 17 | if kind == 'ffc_resnet': 18 | return FFCResNetGenerator(**kwargs) 19 | 20 | raise ValueError(f'Unknown generator kind {kind}') 21 | 22 | 23 | def make_discriminator(kind, **kwargs): 24 | logging.info(f'Make discriminator {kind}') 25 | 26 | if kind == 'pix2pixhd_nlayer_multidilated': 27 | return MultidilatedNLayerDiscriminator(**kwargs) 28 | 29 | if kind == 'pix2pixhd_nlayer': 30 | return NLayerDiscriminator(**kwargs) 31 | 32 | raise ValueError(f'Unknown discriminator kind {kind}') 33 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/modules/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Tuple, List 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from bmab.external.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv 8 | from bmab.external.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv 9 | 10 | 11 | class BaseDiscriminator(nn.Module): 12 | @abc.abstractmethod 13 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: 14 | """ 15 | Predict scores and get intermediate activations. 
Useful for feature matching loss 16 | :return tuple (scores, list of intermediate activations) 17 | """ 18 | raise NotImplemented() 19 | 20 | 21 | def get_conv_block_ctor(kind='default'): 22 | if not isinstance(kind, str): 23 | return kind 24 | if kind == 'default': 25 | return nn.Conv2d 26 | if kind == 'depthwise': 27 | return DepthWiseSeperableConv 28 | if kind == 'multidilated': 29 | return MultidilatedConv 30 | raise ValueError(f'Unknown convolutional block kind {kind}') 31 | 32 | 33 | def get_norm_layer(kind='bn'): 34 | if not isinstance(kind, str): 35 | return kind 36 | if kind == 'bn': 37 | return nn.BatchNorm2d 38 | if kind == 'in': 39 | return nn.InstanceNorm2d 40 | raise ValueError(f'Unknown norm block kind {kind}') 41 | 42 | 43 | def get_activation(kind='tanh'): 44 | if kind == 'tanh': 45 | return nn.Tanh() 46 | if kind == 'sigmoid': 47 | return nn.Sigmoid() 48 | if kind is False: 49 | return nn.Identity() 50 | raise ValueError(f'Unknown activation kind {kind}') 51 | 52 | 53 | class SimpleMultiStepGenerator(nn.Module): 54 | def __init__(self, steps: List[nn.Module]): 55 | super().__init__() 56 | self.steps = nn.ModuleList(steps) 57 | 58 | def forward(self, x): 59 | cur_in = x 60 | outs = [] 61 | for step in self.steps: 62 | cur_out = step(cur_in) 63 | outs.append(cur_out) 64 | cur_in = torch.cat((cur_in, cur_out), dim=1) 65 | return torch.cat(outs[::-1], dim=1) 66 | 67 | def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features): 68 | if kind == 'convtranspose': 69 | return [nn.ConvTranspose2d(min(max_features, ngf * mult), 70 | min(max_features, int(ngf * mult / 2)), 71 | kernel_size=3, stride=2, padding=1, output_padding=1), 72 | norm_layer(min(max_features, int(ngf * mult / 2))), activation] 73 | elif kind == 'bilinear': 74 | return [nn.Upsample(scale_factor=2, mode='bilinear'), 75 | DepthWiseSeperableConv(min(max_features, ngf * mult), 76 | min(max_features, int(ngf * mult / 2)), 77 | kernel_size=3, stride=1, padding=1), 78 | norm_layer(min(max_features, int(ngf * mult / 2))), activation] 79 | else: 80 | raise Exception(f"Invalid deconv kind: {kind}") -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/modules/depthwise_sep_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DepthWiseSeperableConv(nn.Module): 5 | def __init__(self, in_dim, out_dim, *args, **kwargs): 6 | super().__init__() 7 | if 'groups' in kwargs: 8 | # ignoring groups for Depthwise Sep Conv 9 | del kwargs['groups'] 10 | 11 | self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs) 12 | self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1) 13 | 14 | def forward(self, x): 15 | out = self.depthwise(x) 16 | out = self.pointwise(out) 17 | return out -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/modules/fake_fakes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import SamplePadding 3 | from kornia.augmentation import RandomAffine, CenterCrop 4 | 5 | 6 | class FakeFakesGenerator: 7 | def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2): 8 | self.grad_aug = RandomAffine(degrees=360, 9 | translate=0.2, 10 | padding_mode=SamplePadding.REFLECTION, 11 | keepdim=False, 12 | p=1) 13 | self.img_aug = RandomAffine(degrees=img_aug_degree, 14 | 
translate=img_aug_translate, 15 | padding_mode=SamplePadding.REFLECTION, 16 | keepdim=True, 17 | p=1) 18 | self.aug_proba = aug_proba 19 | 20 | def __call__(self, input_images, masks): 21 | blend_masks = self._fill_masks_with_gradient(masks) 22 | blend_target = self._make_blend_target(input_images) 23 | result = input_images * (1 - blend_masks) + blend_target * blend_masks 24 | return result, blend_masks 25 | 26 | def _make_blend_target(self, input_images): 27 | batch_size = input_images.shape[0] 28 | permuted = input_images[torch.randperm(batch_size)] 29 | augmented = self.img_aug(input_images) 30 | is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float() 31 | result = augmented * is_aug + permuted * (1 - is_aug) 32 | return result 33 | 34 | def _fill_masks_with_gradient(self, masks): 35 | batch_size, _, height, width = masks.shape 36 | grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \ 37 | .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2) 38 | grad = self.grad_aug(grad) 39 | grad = CenterCrop((height, width))(grad) 40 | grad *= masks 41 | 42 | grad_for_min = grad + (1 - masks) * 10 43 | grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None] 44 | grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6 45 | grad.clamp_(min=0, max=1) 46 | 47 | return grad 48 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/modules/multidilated_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | from bmab.external.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv 5 | 6 | class MultidilatedConv(nn.Module): 7 | def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True, 8 | shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs): 9 | super().__init__() 10 | convs = [] 11 | self.equal_dim = equal_dim 12 | assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode 13 | if comb_mode in ('cat_out', 'cat_both'): 14 | self.cat_out = True 15 | if equal_dim: 16 | assert out_dim % dilation_num == 0 17 | out_dims = [out_dim // dilation_num] * dilation_num 18 | self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], []) 19 | else: 20 | out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] 21 | out_dims.append(out_dim - sum(out_dims)) 22 | index = [] 23 | starts = [0] + out_dims[:-1] 24 | lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)] 25 | for i in range(out_dims[-1]): 26 | for j in range(dilation_num): 27 | index += list(range(starts[j], starts[j] + lengths[j])) 28 | starts[j] += lengths[j] 29 | self.index = index 30 | assert(len(index) == out_dim) 31 | self.out_dims = out_dims 32 | else: 33 | self.cat_out = False 34 | self.out_dims = [out_dim] * dilation_num 35 | 36 | if comb_mode in ('cat_in', 'cat_both'): 37 | if equal_dim: 38 | assert in_dim % dilation_num == 0 39 | in_dims = [in_dim // dilation_num] * dilation_num 40 | else: 41 | in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] 42 | in_dims.append(in_dim - sum(in_dims)) 43 | self.in_dims = in_dims 44 | self.cat_in = True 45 | else: 46 | self.cat_in = False 47 | self.in_dims = [in_dim] * dilation_num 48 | 49 | 
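        # The loop below builds one convolution branch per dilation rate: the dilation
        # doubles on every iteration starting from min_dilation, the padding is scaled
        # with the dilation so the spatial size is preserved, and shared_weights
        # optionally ties every branch to the first branch's weight and bias. In
        # forward(), branch outputs are then either summed or concatenated and
        # re-ordered by self.index so channels from the different dilations interleave.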
conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d 50 | dilation = min_dilation 51 | for i in range(dilation_num): 52 | if isinstance(padding, int): 53 | cur_padding = padding * dilation 54 | else: 55 | cur_padding = padding[i] 56 | convs.append(conv_type( 57 | self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs 58 | )) 59 | if i > 0 and shared_weights: 60 | convs[-1].weight = convs[0].weight 61 | convs[-1].bias = convs[0].bias 62 | dilation *= 2 63 | self.convs = nn.ModuleList(convs) 64 | 65 | self.shuffle_in_channels = shuffle_in_channels 66 | if self.shuffle_in_channels: 67 | # shuffle list as shuffling of tensors is nondeterministic 68 | in_channels_permute = list(range(in_dim)) 69 | random.shuffle(in_channels_permute) 70 | # save as buffer so it is saved and loaded with checkpoint 71 | self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute)) 72 | 73 | def forward(self, x): 74 | if self.shuffle_in_channels: 75 | x = x[:, self.in_channels_permute] 76 | 77 | outs = [] 78 | if self.cat_in: 79 | if self.equal_dim: 80 | x = x.chunk(len(self.convs), dim=1) 81 | else: 82 | new_x = [] 83 | start = 0 84 | for dim in self.in_dims: 85 | new_x.append(x[:, start:start+dim]) 86 | start += dim 87 | x = new_x 88 | for i, conv in enumerate(self.convs): 89 | if self.cat_in: 90 | input = x[i] 91 | else: 92 | input = x 93 | outs.append(conv(input)) 94 | if self.cat_out: 95 | out = torch.cat(outs, dim=1)[:, self.index] 96 | else: 97 | out = sum(outs) 98 | return out 99 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/modules/multiscale.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from bmab.external.lama.saicinpainting.training.modules.base import get_conv_block_ctor, get_activation 8 | from bmab.external.lama.saicinpainting.training.modules.pix2pixhd import ResnetBlock 9 | 10 | 11 | class ResNetHead(nn.Module): 12 | def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, 13 | padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)): 14 | assert (n_blocks >= 0) 15 | super(ResNetHead, self).__init__() 16 | 17 | conv_layer = get_conv_block_ctor(conv_kind) 18 | 19 | model = [nn.ReflectionPad2d(3), 20 | conv_layer(input_nc, ngf, kernel_size=7, padding=0), 21 | norm_layer(ngf), 22 | activation] 23 | 24 | ### downsample 25 | for i in range(n_downsampling): 26 | mult = 2 ** i 27 | model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), 28 | norm_layer(ngf * mult * 2), 29 | activation] 30 | 31 | mult = 2 ** n_downsampling 32 | 33 | ### resnet blocks 34 | for i in range(n_blocks): 35 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, 36 | conv_kind=conv_kind)] 37 | 38 | self.model = nn.Sequential(*model) 39 | 40 | def forward(self, input): 41 | return self.model(input) 42 | 43 | 44 | class ResNetTail(nn.Module): 45 | def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, 46 | padding_type='reflect', conv_kind='default', activation=nn.ReLU(True), 47 | up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0, 48 | add_in_proj=None): 49 | assert (n_blocks >= 0) 50 
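        # ResNetTail mirrors ResNetHead: an optional 1x1 projection of the concatenated
        # input features (add_in_proj), residual blocks at the coarsest resolution,
        # transposed-conv upsampling back to full resolution, and finally an output
        # projection with optional extra layers and output activation.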
| super(ResNetTail, self).__init__() 51 | 52 | mult = 2 ** n_downsampling 53 | 54 | model = [] 55 | 56 | if add_in_proj is not None: 57 | model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1)) 58 | 59 | ### resnet blocks 60 | for i in range(n_blocks): 61 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, 62 | conv_kind=conv_kind)] 63 | 64 | ### upsample 65 | for i in range(n_downsampling): 66 | mult = 2 ** (n_downsampling - i) 67 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, 68 | output_padding=1), 69 | up_norm_layer(int(ngf * mult / 2)), 70 | up_activation] 71 | self.model = nn.Sequential(*model) 72 | 73 | out_layers = [] 74 | for _ in range(out_extra_layers_n): 75 | out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0), 76 | up_norm_layer(ngf), 77 | up_activation] 78 | out_layers += [nn.ReflectionPad2d(3), 79 | nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 80 | 81 | if add_out_act: 82 | out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act)) 83 | 84 | self.out_proj = nn.Sequential(*out_layers) 85 | 86 | def forward(self, input, return_last_act=False): 87 | features = self.model(input) 88 | out = self.out_proj(features) 89 | if return_last_act: 90 | return out, features 91 | else: 92 | return out 93 | 94 | 95 | class MultiscaleResNet(nn.Module): 96 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3, 97 | norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True), 98 | up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0, 99 | out_cumulative=False, return_only_hr=False): 100 | super().__init__() 101 | 102 | self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling, 103 | n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type, 104 | conv_kind=conv_kind, activation=activation) 105 | for i in range(n_scales)]) 106 | tail_in_feats = ngf * (2 ** n_downsampling) + ngf 107 | self.tails = nn.ModuleList([ResNetTail(output_nc, 108 | ngf=ngf, n_downsampling=n_downsampling, 109 | n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type, 110 | conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer, 111 | up_activation=up_activation, add_out_act=add_out_act, 112 | out_extra_layers_n=out_extra_layers_n, 113 | add_in_proj=None if (i == n_scales - 1) else tail_in_feats) 114 | for i in range(n_scales)]) 115 | 116 | self.out_cumulative = out_cumulative 117 | self.return_only_hr = return_only_hr 118 | 119 | @property 120 | def num_scales(self): 121 | return len(self.heads) 122 | 123 | def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \ 124 | -> Union[torch.Tensor, List[torch.Tensor]]: 125 | """ 126 | :param ms_inputs: List of inputs of different resolutions from HR to LR 127 | :param smallest_scales_num: int or None, number of smallest scales to take at input 128 | :return: Depending on return_only_hr: 129 | True: Only the most HR output 130 | False: List of outputs of different resolutions from HR to LR 131 | """ 132 | if smallest_scales_num is None: 133 | assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num) 134 | smallest_scales_num = len(self.heads) 135 | else: 136 | assert smallest_scales_num == len(ms_inputs) <= len(self.heads), 
(len(self.heads), len(ms_inputs), smallest_scales_num) 137 | 138 | cur_heads = self.heads[-smallest_scales_num:] 139 | ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)] 140 | 141 | all_outputs = [] 142 | prev_tail_features = None 143 | for i in range(len(ms_features)): 144 | scale_i = -i - 1 145 | 146 | cur_tail_input = ms_features[-i - 1] 147 | if prev_tail_features is not None: 148 | if prev_tail_features.shape != cur_tail_input.shape: 149 | prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:], 150 | mode='bilinear', align_corners=False) 151 | cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1) 152 | 153 | cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True) 154 | 155 | prev_tail_features = cur_tail_feats 156 | all_outputs.append(cur_out) 157 | 158 | if self.out_cumulative: 159 | all_outputs_cum = [all_outputs[0]] 160 | for i in range(1, len(ms_features)): 161 | cur_out = all_outputs[i] 162 | cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:], 163 | mode='bilinear', align_corners=False) 164 | all_outputs_cum.append(cur_out_cum) 165 | all_outputs = all_outputs_cum 166 | 167 | if self.return_only_hr: 168 | return all_outputs[-1] 169 | else: 170 | return all_outputs[::-1] 171 | 172 | 173 | class MultiscaleDiscriminatorSimple(nn.Module): 174 | def __init__(self, ms_impl): 175 | super().__init__() 176 | self.ms_impl = nn.ModuleList(ms_impl) 177 | 178 | @property 179 | def num_scales(self): 180 | return len(self.ms_impl) 181 | 182 | def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \ 183 | -> List[Tuple[torch.Tensor, List[torch.Tensor]]]: 184 | """ 185 | :param ms_inputs: List of inputs of different resolutions from HR to LR 186 | :param smallest_scales_num: int or None, number of smallest scales to take at input 187 | :return: List of pairs (prediction, features) for different resolutions from HR to LR 188 | """ 189 | if smallest_scales_num is None: 190 | assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num) 191 | smallest_scales_num = len(self.heads) 192 | else: 193 | assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \ 194 | (len(self.ms_impl), len(ms_inputs), smallest_scales_num) 195 | 196 | return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)] 197 | 198 | 199 | class SingleToMultiScaleInputMixin: 200 | def forward(self, x: torch.Tensor) -> List: 201 | orig_height, orig_width = x.shape[2:] 202 | factors = [2 ** i for i in range(self.num_scales)] 203 | ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False) 204 | for f in factors] 205 | return super().forward(ms_inputs) 206 | 207 | 208 | class GeneratorMultiToSingleOutputMixin: 209 | def forward(self, x): 210 | return super().forward(x)[0] 211 | 212 | 213 | class DiscriminatorMultiToSingleOutputMixin: 214 | def forward(self, x): 215 | out_feat_tuples = super().forward(x) 216 | return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist] 217 | 218 | 219 | class DiscriminatorMultiToSingleOutputStackedMixin: 220 | def __init__(self, *args, return_feats_only_levels=None, **kwargs): 221 | super().__init__(*args, **kwargs) 222 | self.return_feats_only_levels = return_feats_only_levels 223 | 224 | def forward(self, x): 225 | out_feat_tuples = super().forward(x) 226 | 
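        # out_feat_tuples holds one (prediction, feature list) pair per scale, highest
        # resolution first. Below, the lower-resolution predictions are bilinearly
        # upsampled to the first prediction's spatial size and concatenated along the
        # channel dimension, while the per-scale feature lists are flattened into a
        # single list (optionally restricted to return_feats_only_levels).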
outs = [out for out, _ in out_feat_tuples] 227 | scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:], 228 | mode='bilinear', align_corners=False) 229 | for cur_out in outs[1:]] 230 | out = torch.cat(scaled_outs, dim=1) 231 | if self.return_feats_only_levels is not None: 232 | feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels] 233 | else: 234 | feat_lists = [flist for _, flist in out_feat_tuples] 235 | feats = [f for flist in feat_lists for f in flist] 236 | return out, feats 237 | 238 | 239 | class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple): 240 | pass 241 | 242 | 243 | class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet): 244 | pass 245 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/modules/spatial_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from kornia.geometry.transform import rotate 5 | 6 | 7 | class LearnableSpatialTransformWrapper(nn.Module): 8 | def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True): 9 | super().__init__() 10 | self.impl = impl 11 | self.angle = torch.rand(1) * angle_init_range 12 | if train_angle: 13 | self.angle = nn.Parameter(self.angle, requires_grad=True) 14 | self.pad_coef = pad_coef 15 | 16 | def forward(self, x): 17 | if torch.is_tensor(x): 18 | return self.inverse_transform(self.impl(self.transform(x)), x) 19 | elif isinstance(x, tuple): 20 | x_trans = tuple(self.transform(elem) for elem in x) 21 | y_trans = self.impl(x_trans) 22 | return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)) 23 | else: 24 | raise ValueError(f'Unexpected input type {type(x)}') 25 | 26 | def transform(self, x): 27 | height, width = x.shape[2:] 28 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) 29 | x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') 30 | x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) 31 | return x_padded_rotated 32 | 33 | def inverse_transform(self, y_padded_rotated, orig_x): 34 | height, width = orig_x.shape[2:] 35 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) 36 | 37 | y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) 38 | y_height, y_width = y_padded.shape[2:] 39 | y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w] 40 | return y 41 | 42 | 43 | if __name__ == '__main__': 44 | layer = LearnableSpatialTransformWrapper(nn.Identity()) 45 | x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float() 46 | y = layer(x) 47 | assert x.shape == y.shape 48 | assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1]) 49 | print('all ok') 50 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/modules/squeeze_excitation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=16): 6 | super(SELayer, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 8 | self.fc = nn.Sequential( 9 | nn.Linear(channel, channel // reduction, bias=False), 10 | 
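            # Squeeze-and-Excitation: forward() average-pools each channel to a single
            # value ("squeeze"), this bottleneck MLP with a final sigmoid turns those
            # values into per-channel gates ("excite"), and the input is rescaled
            # channel-wise by the gates.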
nn.ReLU(inplace=True), 11 | nn.Linear(channel // reduction, channel, bias=False), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, x): 16 | b, c, _, _ = x.size() 17 | y = self.avg_pool(x).view(b, c) 18 | y = self.fc(y).view(b, c, 1, 1) 19 | res = x * y.expand_as(x) 20 | return res 21 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from bmab.external.lama.saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule 4 | 5 | 6 | def get_training_model_class(kind): 7 | if kind == 'default': 8 | return DefaultInpaintingTrainingModule 9 | 10 | raise ValueError(f'Unknown trainer module {kind}') 11 | 12 | 13 | def make_training_model(config): 14 | kind = config.training_model.kind 15 | kwargs = dict(config.training_model) 16 | kwargs.pop('kind') 17 | kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp' 18 | 19 | logging.info(f'Make training model {kind}') 20 | 21 | cls = get_training_model_class(kind) 22 | return cls(config, **kwargs) 23 | 24 | 25 | def load_checkpoint(train_config, path, map_location='cuda', strict=True): 26 | model = make_training_model(train_config).generator 27 | state = torch.load(path, map_location=map_location) 28 | model.load_state_dict(state, strict=strict) 29 | return model 30 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/trainers/default.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from omegaconf import OmegaConf 6 | 7 | # from bmab.external.lama.saicinpainting.training.data.datasets import make_constant_area_crop_params 8 | from bmab.external.lama.saicinpainting.training.losses.distance_weighting import make_mask_distance_weighter 9 | from bmab.external.lama.saicinpainting.training.losses.feature_matching import feature_matching_loss, masked_l1_loss 10 | # from bmab.external.lama.saicinpainting.training.modules.fake_fakes import FakeFakesGenerator 11 | from bmab.external.lama.saicinpainting.training.trainers.base import BaseInpaintingTrainingModule, make_multiscale_noise 12 | from bmab.external.lama.saicinpainting.utils import add_prefix_to_keys, get_ramp 13 | 14 | LOGGER = logging.getLogger(__name__) 15 | 16 | 17 | def make_constant_area_crop_batch(batch, **kwargs): 18 | crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2], 19 | img_width=batch['image'].shape[3], 20 | **kwargs) 21 | batch['image'] = batch['image'][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width] 22 | batch['mask'] = batch['mask'][:, :, crop_y: crop_y + crop_height, crop_x: crop_x + crop_width] 23 | return batch 24 | 25 | 26 | class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule): 27 | def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image', 28 | add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None, 29 | distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False, 30 | fake_fakes_proba=0, fake_fakes_generator_kwargs=None, 31 | **kwargs): 32 | super().__init__(*args, **kwargs) 33 | self.concat_mask = concat_mask 34 | self.rescale_size_getter = 
get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None 35 | self.image_to_discriminator = image_to_discriminator 36 | self.add_noise_kwargs = add_noise_kwargs 37 | self.noise_fill_hole = noise_fill_hole 38 | self.const_area_crop_kwargs = const_area_crop_kwargs 39 | self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \ 40 | if distance_weighter_kwargs is not None else None 41 | self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr 42 | 43 | self.fake_fakes_proba = fake_fakes_proba 44 | if self.fake_fakes_proba > 1e-3: 45 | self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {})) 46 | 47 | def forward(self, batch): 48 | if self.training and self.rescale_size_getter is not None: 49 | cur_size = self.rescale_size_getter(self.global_step) 50 | batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False) 51 | batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest') 52 | 53 | if self.training and self.const_area_crop_kwargs is not None: 54 | batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs) 55 | 56 | img = batch['image'] 57 | mask = batch['mask'] 58 | 59 | masked_img = img * (1 - mask) 60 | 61 | if self.add_noise_kwargs is not None: 62 | noise = make_multiscale_noise(masked_img, **self.add_noise_kwargs) 63 | if self.noise_fill_hole: 64 | masked_img = masked_img + mask * noise[:, :masked_img.shape[1]] 65 | masked_img = torch.cat([masked_img, noise], dim=1) 66 | 67 | if self.concat_mask: 68 | masked_img = torch.cat([masked_img, mask], dim=1) 69 | 70 | batch['predicted_image'] = self.generator(masked_img) 71 | batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image'] 72 | 73 | if self.fake_fakes_proba > 1e-3: 74 | if self.training and torch.rand(1).item() < self.fake_fakes_proba: 75 | batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask) 76 | batch['use_fake_fakes'] = True 77 | else: 78 | batch['fake_fakes'] = torch.zeros_like(img) 79 | batch['fake_fakes_masks'] = torch.zeros_like(mask) 80 | batch['use_fake_fakes'] = False 81 | 82 | batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \ 83 | if self.refine_mask_for_losses is not None and self.training \ 84 | else mask 85 | 86 | return batch 87 | 88 | def generator_loss(self, batch): 89 | img = batch['image'] 90 | predicted_img = batch[self.image_to_discriminator] 91 | original_mask = batch['mask'] 92 | supervised_mask = batch['mask_for_losses'] 93 | 94 | # L1 95 | l1_value = masked_l1_loss(predicted_img, img, supervised_mask, 96 | self.config.losses.l1.weight_known, 97 | self.config.losses.l1.weight_missing) 98 | 99 | total_loss = l1_value 100 | metrics = dict(gen_l1=l1_value) 101 | 102 | # vgg-based perceptual loss 103 | if self.config.losses.perceptual.weight > 0: 104 | pl_value = self.loss_pl(predicted_img, img, mask=supervised_mask).sum() * self.config.losses.perceptual.weight 105 | total_loss = total_loss + pl_value 106 | metrics['gen_pl'] = pl_value 107 | 108 | # discriminator 109 | # adversarial_loss calls backward by itself 110 | mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask 111 | self.adversarial_loss.pre_generator_step(real_batch=img, fake_batch=predicted_img, 112 | generator=self.generator, discriminator=self.discriminator) 113 | discr_real_pred, discr_real_features = self.discriminator(img) 114 | discr_fake_pred, 
discr_fake_features = self.discriminator(predicted_img) 115 | adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(real_batch=img, 116 | fake_batch=predicted_img, 117 | discr_real_pred=discr_real_pred, 118 | discr_fake_pred=discr_fake_pred, 119 | mask=mask_for_discr) 120 | total_loss = total_loss + adv_gen_loss 121 | metrics['gen_adv'] = adv_gen_loss 122 | metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) 123 | 124 | # feature matching 125 | if self.config.losses.feature_matching.weight > 0: 126 | need_mask_in_fm = OmegaConf.to_container(self.config.losses.feature_matching).get('pass_mask', False) 127 | mask_for_fm = supervised_mask if need_mask_in_fm else None 128 | fm_value = feature_matching_loss(discr_fake_features, discr_real_features, 129 | mask=mask_for_fm) * self.config.losses.feature_matching.weight 130 | total_loss = total_loss + fm_value 131 | metrics['gen_fm'] = fm_value 132 | 133 | if self.loss_resnet_pl is not None: 134 | resnet_pl_value = self.loss_resnet_pl(predicted_img, img) 135 | total_loss = total_loss + resnet_pl_value 136 | metrics['gen_resnet_pl'] = resnet_pl_value 137 | 138 | return total_loss, metrics 139 | 140 | def discriminator_loss(self, batch): 141 | total_loss = 0 142 | metrics = {} 143 | 144 | predicted_img = batch[self.image_to_discriminator].detach() 145 | self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=predicted_img, 146 | generator=self.generator, discriminator=self.discriminator) 147 | discr_real_pred, discr_real_features = self.discriminator(batch['image']) 148 | discr_fake_pred, discr_fake_features = self.discriminator(predicted_img) 149 | adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(real_batch=batch['image'], 150 | fake_batch=predicted_img, 151 | discr_real_pred=discr_real_pred, 152 | discr_fake_pred=discr_fake_pred, 153 | mask=batch['mask']) 154 | total_loss = total_loss + adv_discr_loss 155 | metrics['discr_adv'] = adv_discr_loss 156 | metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) 157 | 158 | 159 | if batch.get('use_fake_fakes', False): 160 | fake_fakes = batch['fake_fakes'] 161 | self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=fake_fakes, 162 | generator=self.generator, discriminator=self.discriminator) 163 | discr_fake_fakes_pred, _ = self.discriminator(fake_fakes) 164 | fake_fakes_adv_discr_loss, fake_fakes_adv_metrics = self.adversarial_loss.discriminator_loss( 165 | real_batch=batch['image'], 166 | fake_batch=fake_fakes, 167 | discr_real_pred=discr_real_pred, 168 | discr_fake_pred=discr_fake_fakes_pred, 169 | mask=batch['mask'] 170 | ) 171 | total_loss = total_loss + fake_fakes_adv_discr_loss 172 | metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss 173 | metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_')) 174 | 175 | return total_loss, metrics 176 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from bmab.external.lama.saicinpainting.training.visualizers.directory import DirectoryVisualizer 4 | from bmab.external.lama.saicinpainting.training.visualizers.noop import NoopVisualizer 5 | 6 | 7 | def make_visualizer(kind, **kwargs): 8 | logging.info(f'Make visualizer {kind}') 9 | 10 | if kind == 'directory': 11 | return DirectoryVisualizer(**kwargs) 12 | if kind == 'noop': 13 | return 
NoopVisualizer() 14 | 15 | raise ValueError(f'Unknown visualizer kind {kind}') 16 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/visualizers/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | import torch 6 | from skimage import color 7 | from skimage.segmentation import mark_boundaries 8 | 9 | from . import colors 10 | 11 | COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation 12 | 13 | 14 | class BaseVisualizer: 15 | @abc.abstractmethod 16 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 17 | """ 18 | Take a batch, make an image from it and visualize 19 | """ 20 | raise NotImplementedError() 21 | 22 | 23 | def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str], 24 | last_without_mask=True, rescale_keys=None, mask_only_first=None, 25 | black_mask=False) -> np.ndarray: 26 | mask = images_dict['mask'] > 0.5 27 | result = [] 28 | for i, k in enumerate(keys): 29 | img = images_dict[k] 30 | img = np.transpose(img, (1, 2, 0)) 31 | 32 | if rescale_keys is not None and k in rescale_keys: 33 | img = img - img.min() 34 | img /= img.max() + 1e-5 35 | if len(img.shape) == 2: 36 | img = np.expand_dims(img, 2) 37 | 38 | if img.shape[2] == 1: 39 | img = np.repeat(img, 3, axis=2) 40 | elif (img.shape[2] > 3): 41 | img_classes = img.argmax(2) 42 | img = color.label2rgb(img_classes, colors=COLORS) 43 | 44 | if mask_only_first: 45 | need_mark_boundaries = i == 0 46 | else: 47 | need_mark_boundaries = i < len(keys) - 1 or not last_without_mask 48 | 49 | if need_mark_boundaries: 50 | if black_mask: 51 | img = img * (1 - mask[0][..., None]) 52 | img = mark_boundaries(img, 53 | mask[0], 54 | color=(1., 0., 0.), 55 | outline_color=(1., 1., 1.), 56 | mode='thick') 57 | result.append(img) 58 | return np.concatenate(result, axis=1) 59 | 60 | 61 | def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10, 62 | last_without_mask=True, rescale_keys=None) -> np.ndarray: 63 | batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items() 64 | if k in keys or k == 'mask'} 65 | 66 | batch_size = next(iter(batch.values())).shape[0] 67 | items_to_vis = min(batch_size, max_items) 68 | result = [] 69 | for i in range(items_to_vis): 70 | cur_dct = {k: tens[i] for k, tens in batch.items()} 71 | result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask, 72 | rescale_keys=rescale_keys)) 73 | return np.concatenate(result, axis=0) 74 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/visualizers/colors.py: -------------------------------------------------------------------------------- 1 | import random 2 | import colorsys 3 | 4 | import numpy as np 5 | import matplotlib 6 | matplotlib.use('agg') 7 | import matplotlib.pyplot as plt 8 | from matplotlib.colors import LinearSegmentedColormap 9 | 10 | 11 | def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False): 12 | # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib 13 | """ 14 | Creates a random colormap to be used together with matplotlib. 
Useful for segmentation tasks 15 | :param nlabels: Number of labels (size of colormap) 16 | :param type: 'bright' for strong colors, 'soft' for pastel colors 17 | :param first_color_black: Option to use first color as black, True or False 18 | :param last_color_black: Option to use last color as black, True or False 19 | :param verbose: Prints the number of labels and shows the colormap. True or False 20 | :return: colormap for matplotlib 21 | """ 22 | if type not in ('bright', 'soft'): 23 | print ('Please choose "bright" or "soft" for type') 24 | return 25 | 26 | if verbose: 27 | print('Number of labels: ' + str(nlabels)) 28 | 29 | # Generate color map for bright colors, based on hsv 30 | if type == 'bright': 31 | randHSVcolors = [(np.random.uniform(low=0.0, high=1), 32 | np.random.uniform(low=0.2, high=1), 33 | np.random.uniform(low=0.9, high=1)) for i in range(nlabels)] 34 | 35 | # Convert HSV list to RGB 36 | randRGBcolors = [] 37 | for HSVcolor in randHSVcolors: 38 | randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])) 39 | 40 | if first_color_black: 41 | randRGBcolors[0] = [0, 0, 0] 42 | 43 | if last_color_black: 44 | randRGBcolors[-1] = [0, 0, 0] 45 | 46 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) 47 | 48 | # Generate soft pastel colors, by limiting the RGB spectrum 49 | if type == 'soft': 50 | low = 0.6 51 | high = 0.95 52 | randRGBcolors = [(np.random.uniform(low=low, high=high), 53 | np.random.uniform(low=low, high=high), 54 | np.random.uniform(low=low, high=high)) for i in range(nlabels)] 55 | 56 | if first_color_black: 57 | randRGBcolors[0] = [0, 0, 0] 58 | 59 | if last_color_black: 60 | randRGBcolors[-1] = [0, 0, 0] 61 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) 62 | 63 | # Display colorbar 64 | if verbose: 65 | from matplotlib import colors, colorbar 66 | from matplotlib import pyplot as plt 67 | fig, ax = plt.subplots(1, 1, figsize=(15, 0.5)) 68 | 69 | bounds = np.linspace(0, nlabels, nlabels + 1) 70 | norm = colors.BoundaryNorm(bounds, nlabels) 71 | 72 | cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None, 73 | boundaries=bounds, format='%1i', orientation=u'horizontal') 74 | 75 | return randRGBcolors, random_colormap 76 | 77 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/visualizers/directory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from bmab.external.lama.saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch 7 | from bmab.external.lama.saicinpainting.utils import check_and_warn_input_range 8 | 9 | 10 | class DirectoryVisualizer(BaseVisualizer): 11 | DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ') 12 | 13 | def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10, 14 | last_without_mask=True, rescale_keys=None): 15 | self.outdir = outdir 16 | os.makedirs(self.outdir, exist_ok=True) 17 | self.key_order = key_order 18 | self.max_items_in_batch = max_items_in_batch 19 | self.last_without_mask = last_without_mask 20 | self.rescale_keys = rescale_keys 21 | 22 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 23 | check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image') 24 | vis_img = 
visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch, 25 | last_without_mask=self.last_without_mask, 26 | rescale_keys=self.rescale_keys) 27 | 28 | vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8') 29 | 30 | curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}') 31 | os.makedirs(curoutdir, exist_ok=True) 32 | rank_suffix = f'_r{rank}' if rank is not None else '' 33 | out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg') 34 | 35 | vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR) 36 | cv2.imwrite(out_fname, vis_img) 37 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/training/visualizers/noop.py: -------------------------------------------------------------------------------- 1 | from bmab.external.lama.saicinpainting.training.visualizers.base import BaseVisualizer 2 | 3 | 4 | class NoopVisualizer(BaseVisualizer): 5 | def __init__(self, *args, **kwargs): 6 | pass 7 | 8 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 9 | pass 10 | -------------------------------------------------------------------------------- /bmab/external/lama/saicinpainting/utils.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import functools 3 | import logging 4 | import numbers 5 | import os 6 | import signal 7 | import sys 8 | import traceback 9 | import warnings 10 | 11 | import torch 12 | from pytorch_lightning import seed_everything 13 | 14 | LOGGER = logging.getLogger(__name__) 15 | 16 | 17 | def check_and_warn_input_range(tensor, min_value, max_value, name): 18 | actual_min = tensor.min() 19 | actual_max = tensor.max() 20 | if actual_min < min_value or actual_max > max_value: 21 | warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}") 22 | 23 | 24 | def sum_dict_with_prefix(target, cur_dict, prefix, default=0): 25 | for k, v in cur_dict.items(): 26 | target_key = prefix + k 27 | target[target_key] = target.get(target_key, default) + v 28 | 29 | 30 | def average_dicts(dict_list): 31 | result = {} 32 | norm = 1e-3 33 | for dct in dict_list: 34 | sum_dict_with_prefix(result, dct, '') 35 | norm += 1 36 | for k in list(result): 37 | result[k] /= norm 38 | return result 39 | 40 | 41 | def add_prefix_to_keys(dct, prefix): 42 | return {prefix + k: v for k, v in dct.items()} 43 | 44 | 45 | def set_requires_grad(module, value): 46 | for param in module.parameters(): 47 | param.requires_grad = value 48 | 49 | 50 | def flatten_dict(dct): 51 | result = {} 52 | for k, v in dct.items(): 53 | if isinstance(k, tuple): 54 | k = '_'.join(k) 55 | if isinstance(v, dict): 56 | for sub_k, sub_v in flatten_dict(v).items(): 57 | result[f'{k}_{sub_k}'] = sub_v 58 | else: 59 | result[k] = v 60 | return result 61 | 62 | 63 | class LinearRamp: 64 | def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): 65 | self.start_value = start_value 66 | self.end_value = end_value 67 | self.start_iter = start_iter 68 | self.end_iter = end_iter 69 | 70 | def __call__(self, i): 71 | if i < self.start_iter: 72 | return self.start_value 73 | if i >= self.end_iter: 74 | return self.end_value 75 | part = (i - self.start_iter) / (self.end_iter - self.start_iter) 76 | return self.start_value * (1 - part) + self.end_value * part 77 | 78 | 79 | class LadderRamp: 80 | def __init__(self, start_iters, values): 81 | self.start_iters = start_iters 82 | self.values = 
values 83 | assert len(values) == len(start_iters) + 1, (len(values), len(start_iters)) 84 | 85 | def __call__(self, i): 86 | segment_i = bisect.bisect_right(self.start_iters, i) 87 | return self.values[segment_i] 88 | 89 | 90 | def get_ramp(kind='ladder', **kwargs): 91 | if kind == 'linear': 92 | return LinearRamp(**kwargs) 93 | if kind == 'ladder': 94 | return LadderRamp(**kwargs) 95 | raise ValueError(f'Unexpected ramp kind: {kind}') 96 | 97 | 98 | def print_traceback_handler(sig, frame): 99 | LOGGER.warning(f'Received signal {sig}') 100 | bt = ''.join(traceback.format_stack()) 101 | LOGGER.warning(f'Requested stack trace:\n{bt}') 102 | 103 | 104 | def register_debug_signal_handlers(sig=None, handler=print_traceback_handler): 105 | LOGGER.warning(f'Setting signal {sig} handler {handler}') 106 | signal.signal(sig, handler) 107 | 108 | 109 | def handle_deterministic_config(config): 110 | seed = dict(config).get('seed', None) 111 | if seed is None: 112 | return False 113 | 114 | seed_everything(seed) 115 | return True 116 | 117 | 118 | def get_shape(t): 119 | if torch.is_tensor(t): 120 | return tuple(t.shape) 121 | elif isinstance(t, dict): 122 | return {n: get_shape(q) for n, q in t.items()} 123 | elif isinstance(t, (list, tuple)): 124 | return [get_shape(q) for q in t] 125 | elif isinstance(t, numbers.Number): 126 | return type(t) 127 | else: 128 | raise ValueError('unexpected type {}'.format(type(t))) 129 | 130 | 131 | def get_has_ddp_rank(): 132 | master_port = os.environ.get('MASTER_PORT', None) 133 | node_rank = os.environ.get('NODE_RANK', None) 134 | local_rank = os.environ.get('LOCAL_RANK', None) 135 | world_size = os.environ.get('WORLD_SIZE', None) 136 | has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None 137 | return has_rank 138 | 139 | 140 | def handle_ddp_subprocess(): 141 | def main_decorator(main_func): 142 | @functools.wraps(main_func) 143 | def new_main(*args, **kwargs): 144 | # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE 145 | parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None) 146 | has_parent = parent_cwd is not None 147 | has_rank = get_has_ddp_rank() 148 | assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}' 149 | 150 | if has_parent: 151 | # we are in the worker 152 | sys.argv.extend([ 153 | f'hydra.run.dir={parent_cwd}', 154 | # 'hydra/hydra_logging=disabled', 155 | # 'hydra/job_logging=disabled' 156 | ]) 157 | # do nothing if this is a top-level process 158 | # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization 159 | 160 | main_func(*args, **kwargs) 161 | return new_main 162 | return main_decorator 163 | 164 | 165 | def handle_ddp_parent_process(): 166 | parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None) 167 | has_parent = parent_cwd is not None 168 | has_rank = get_has_ddp_rank() 169 | assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}' 170 | 171 | if parent_cwd is None: 172 | os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd() 173 | 174 | return has_parent 175 | -------------------------------------------------------------------------------- /bmab/external/rmbg14/MyConfig.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | from typing import List 3 | 4 | class RMBGConfig(PretrainedConfig): 5 | model_type = "SegformerForSemanticSegmentation" 6 | def __init__( 
7 | self, 8 | in_ch=3, 9 | out_ch=1, 10 | **kwargs): 11 | self.in_ch = in_ch 12 | self.out_ch = out_ch 13 | super().__init__(**kwargs) 14 | -------------------------------------------------------------------------------- /bmab/external/rmbg14/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/bmab/external/rmbg14/__init__.py -------------------------------------------------------------------------------- /bmab/external/rmbg14/utilities.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision.transforms.functional import normalize 4 | import numpy as np 5 | 6 | def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor: 7 | if len(im.shape) < 3: 8 | im = im[:, :, np.newaxis] 9 | # orig_im_size=im.shape[0:2] 10 | im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1) 11 | im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8) 12 | image = torch.divide(im_tensor,255.0) 13 | image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0]) 14 | return image 15 | 16 | 17 | def postprocess_image(result: torch.Tensor, im_size: list)-> np.ndarray: 18 | result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0) 19 | ma = torch.max(result) 20 | mi = torch.min(result) 21 | result = (result-mi)/(ma-mi) 22 | im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8) 23 | im_array = np.squeeze(im_array) 24 | return im_array 25 | -------------------------------------------------------------------------------- /bmab/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BMABBasic, BMABBind, BMABSaveImage, BMABText, BMABPreviewText, BMABRemoteAccessAndSave 2 | from .binder import BMABBind, BMABLoraBind 3 | from .cnloader import BMABControlNet, BMABControlNetOpenpose, BMABControlNetIPAdapter, BMABFluxControlNet 4 | from .detailers import BMABFaceDetailer, BMABPersonDetailer, BMABSimpleHandDetailer, BMABSubframeHandDetailer 5 | from .detailers import BMABOpenposeHandDetailer, BMABDetailAnything 6 | from .imaging import BMABDetectionCrop, BMABRemoveBackground, BMABAlphaComposit, BMABBlend 7 | from .imaging import BMABDetectAndMask, BMABLamaInpaint, BMABDetector, BMABSegmentAnything, BMABMasksToImages 8 | from .imaging import BMABLoadImage, BMABEdge, BMABLoadOutputImage, BMABBlackAndWhite, BMABDetectAndPaste 9 | from .loaders import BMABLoraLoader 10 | from .resize import BMABResizeByPerson, BMABResizeByRatio, BMABResizeAndFill, BMABCrop, BMABZoomOut, BMABSquare 11 | from .sampler import BMABKSampler, BMABKSamplerHiresFix, BMABPrompt, BMABIntegrator, BMABSeedGenerator, BMABExtractor 12 | from .sampler import BMABContextNode, BMABKSamplerHiresFixWithUpscaler, BMABImportIntegrator, BMABKSamplerKohyaDeepShrink 13 | from .sampler import BMABClipTextEncoderSDXL, BMABFluxIntegrator, BMABToBind 14 | from .upscaler import BMABUpscale, BMABUpscaleWithModel 15 | from .toy import BMABGoogleGemini 16 | from .a1111api import BMABApiServer, BMABApiSDWebUIT2I, BMABApiSDWebUIT2IHiresFix, BMABApiSDWebUIControlNet 17 | from .a1111api import BMABApiSDWebUIBMABExtension, BMABApiSDWebUII2I 18 | from .utilnode import BMABModelToBind, BMABConditioningToBind, BMABNoiseGenerator 19 | from .watermark import BMABWatermark 
20 | from .fill import BMABInpaint, BMABOutpaintByRatio, BMABReframe 21 | from .utilnode import BMABBase64Image, BMABImageStorage, BMABNormalizeSize, BMABDummy 22 | -------------------------------------------------------------------------------- /bmab/nodes/basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import glob 5 | import time 6 | import numpy as np 7 | import folder_paths 8 | from comfy.cli_args import args 9 | from PIL.PngImagePlugin import PngInfo 10 | 11 | from PIL import Image, ImageEnhance, ImageOps 12 | from bmab import utils 13 | from bmab.nodes.binder import BMABBind 14 | 15 | 16 | def calc_color_temperature(temp): 17 | white = (255.0, 254.11008387561782, 250.0419083427406) 18 | 19 | temperature = temp / 100 20 | 21 | if temperature <= 66: 22 | red = 255.0 23 | else: 24 | red = float(temperature - 60) 25 | red = 329.698727446 * math.pow(red, -0.1332047592) 26 | if red < 0: 27 | red = 0 28 | if red > 255: 29 | red = 255 30 | 31 | if temperature <= 66: 32 | green = temperature 33 | green = 99.4708025861 * math.log(green) - 161.1195681661 34 | else: 35 | green = float(temperature - 60) 36 | green = 288.1221695283 * math.pow(green, -0.0755148492) 37 | if green < 0: 38 | green = 0 39 | if green > 255: 40 | green = 255 41 | 42 | if temperature >= 66: 43 | blue = 255.0 44 | else: 45 | if temperature <= 19: 46 | blue = 0.0 47 | else: 48 | blue = float(temperature - 10) 49 | blue = 138.5177312231 * math.log(blue) - 305.0447927307 50 | if blue < 0: 51 | blue = 0 52 | if blue > 255: 53 | blue = 255 54 | 55 | return red / white[0], green / white[1], blue / white[2] 56 | 57 | 58 | class BMABBasic: 59 | @classmethod 60 | def INPUT_TYPES(s): 61 | return { 62 | 'required': { 63 | 'contrast': ('FLOAT', {'default': 1.0, 'min': 0, 'max': 2, 'step': 0.05}), 64 | 'brightness': ('FLOAT', {'default': 1.0, 'min': 0, 'max': 2, 'step': 0.05}), 65 | 'sharpeness': ('FLOAT', {'default': 1.0, 'min': -5.0, 'max': 5.0, 'step': 0.1}), 66 | 'color_saturation': ('FLOAT', {'default': 1.0, 'min': 0.0, 'max': 2.0, 'step': 0.01}), 67 | 'color_temperature': ('INT', {'default': 0, 'min': -2000, 'max': 2000, 'step': 1}), 68 | 'noise_alpha': ('FLOAT', {'default': 0, 'min': 0.0, 'max': 1.0, 'step': 0.05}), 69 | }, 70 | 'optional': { 71 | 'bind': ('BMAB bind',), 72 | 'image': ('IMAGE',), 73 | }, 74 | 'hidden': {'unique_id': 'UNIQUE_ID'} 75 | } 76 | 77 | RETURN_TYPES = ('BMAB bind', 'IMAGE',) 78 | RETURN_NAMES = ('BMAB bind', 'image',) 79 | FUNCTION = 'process' 80 | 81 | CATEGORY = 'BMAB/basic' 82 | 83 | def process(self, contrast, brightness, sharpeness, color_saturation, color_temperature, noise_alpha, unique_id, bind: BMABBind = None, image=None): 84 | if bind is None: 85 | pixels = image 86 | else: 87 | pixels = bind.pixels if image is None else image 88 | 89 | results = [] 90 | for bgimg in utils.get_pils_from_pixels(pixels): 91 | if contrast != 1: 92 | enhancer = ImageEnhance.Contrast(bgimg) 93 | bgimg = enhancer.enhance(contrast) 94 | 95 | if brightness != 1: 96 | enhancer = ImageEnhance.Brightness(bgimg) 97 | bgimg = enhancer.enhance(brightness) 98 | 99 | if sharpeness != 1: 100 | enhancer = ImageEnhance.Sharpness(bgimg) 101 | bgimg = enhancer.enhance(sharpeness) 102 | 103 | if color_saturation != 1: 104 | enhancer = ImageEnhance.Color(bgimg) 105 | bgimg = enhancer.enhance(color_saturation) 106 | 107 | if color_temperature != 0: 108 | temp = calc_color_temperature(6500 + color_temperature) 109 | az = [] 110 | for d in 
bgimg.getdata(): 111 | az.append((int(d[0] * temp[0]), int(d[1] * temp[1]), int(d[2] * temp[2]))) 112 | bgimg = Image.new('RGB', bgimg.size) 113 | bgimg.putdata(az) 114 | 115 | if noise_alpha != 0: 116 | img_noise = utils.generate_noise(0, bgimg.width, bgimg.height) 117 | bgimg = Image.blend(bgimg, img_noise, alpha=noise_alpha) 118 | 119 | results.append(bgimg) 120 | 121 | pixels = utils.get_pixels_from_pils(results) 122 | return BMABBind.result(bind, pixels, ) 123 | 124 | 125 | class BMABSaveImage: 126 | def __init__(self): 127 | self.output_dir = folder_paths.get_output_directory() 128 | self.type = 'output' 129 | self.compress_level = 4 130 | 131 | @classmethod 132 | def INPUT_TYPES(s): 133 | return { 134 | 'required': { 135 | 'filename_prefix': ('STRING', {'default': 'bmab'}), 136 | 'format': (['png', 'jpg'], ), 137 | 'use_date': (['disable', 'enable'], ), 138 | }, 139 | 'hidden': { 140 | 'prompt': 'PROMPT', 'extra_pnginfo': 'EXTRA_PNGINFO' 141 | }, 142 | 'optional': { 143 | 'bind': ('BMAB bind',), 144 | 'images': ('IMAGE',), 145 | } 146 | } 147 | 148 | RETURN_TYPES = () 149 | FUNCTION = 'save_images' 150 | 151 | OUTPUT_NODE = True 152 | 153 | CATEGORY = 'BMAB/basic' 154 | 155 | @staticmethod 156 | def get_file_sequence(prefix, subdir): 157 | output_dir = os.path.normpath(os.path.join(folder_paths.get_output_directory(), subdir)) 158 | find_path = os.path.join(output_dir, f'{prefix}*') 159 | sequence = 0 160 | for f in glob.glob(find_path): 161 | filename = os.path.basename(f) 162 | split_name = filename[len(prefix)+1:].replace('.', '_').split('_') 163 | try: 164 | file_sequence = int(split_name[0]) 165 | except: 166 | continue 167 | if file_sequence > sequence: 168 | sequence = file_sequence 169 | return sequence + 1 170 | 171 | @staticmethod 172 | def get_sub_directory(use_date): 173 | if not use_date: 174 | return '' 175 | 176 | dd = time.strftime('%Y-%m-%d', time.localtime(time.time())) 177 | full_output_folder = os.path.join(folder_paths.output_directory, dd) 178 | print(full_output_folder) 179 | if not os.path.exists(full_output_folder): 180 | os.mkdir(full_output_folder) 181 | return dd 182 | 183 | def save_images(self, filename_prefix='bmab', format='png', use_date='disable', prompt=None, extra_pnginfo=None, bind: BMABBind = None, images=None): 184 | if images is None: 185 | images = bind.pixels 186 | output_dir = folder_paths.get_output_directory() 187 | results = list() 188 | use_date = use_date == 'enable' 189 | 190 | subdir = self.get_sub_directory(use_date) 191 | prefix_split = filename_prefix.split('/') 192 | if len(prefix_split) != 1: 193 | filename_prefix = prefix_split[-1] 194 | subdir = os.path.join(subdir, '/'.join(prefix_split[:-1])) 195 | sequence = self.get_file_sequence(filename_prefix, subdir) 196 | 197 | for (batch_number, image) in enumerate(images): 198 | 199 | i = 255. 
* image.cpu().numpy() 200 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) 201 | metadata = None 202 | if not args.disable_metadata: 203 | metadata = PngInfo() 204 | if prompt is not None: 205 | metadata.add_text('prompt', json.dumps(prompt)) 206 | if extra_pnginfo is not None: 207 | for x in extra_pnginfo: 208 | metadata.add_text(x, json.dumps(extra_pnginfo[x])) 209 | 210 | if batch_number > 0: 211 | filename = f'{filename_prefix}_{sequence:05}_{batch_number}' 212 | else: 213 | filename = f'{filename_prefix}_{sequence:05}' 214 | 215 | if bind is not None: 216 | file = f'{filename}_{bind.seed}.{format}' 217 | else: 218 | file = f'{filename}.{format}' 219 | 220 | if use_date: 221 | output_dir = os.path.join(output_dir, subdir) 222 | 223 | if not os.path.exists(output_dir): 224 | os.mkdir(output_dir) 225 | 226 | if format == 'png': 227 | img.save(os.path.join(output_dir, file), pnginfo=metadata, compress_level=self.compress_level) 228 | else: 229 | img.save(os.path.join(output_dir, file)) 230 | 231 | results.append({ 232 | 'filename': file, 233 | 'subfolder': subdir, 234 | 'type': self.type 235 | }) 236 | 237 | sequence += 1 238 | 239 | return {'ui': {'images': results}} 240 | 241 | 242 | class BMABText: 243 | @classmethod 244 | def INPUT_TYPES(s): 245 | return { 246 | 'required': { 247 | 'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}), 248 | }, 249 | 'optional': { 250 | 'text': ('STRING', {"forceInput": True}), 251 | } 252 | } 253 | 254 | RETURN_TYPES = ('STRING',) 255 | RETURN_NAMES = ('string',) 256 | FUNCTION = 'export' 257 | 258 | CATEGORY = 'BMAB/basic' 259 | 260 | def export(self, prompt, text=None): 261 | if text is not None: 262 | prompt = prompt.replace('__prompt__', text) 263 | result = utils.parse_prompt(prompt, 0) 264 | return (result,) 265 | 266 | 267 | class BMABPreviewText: 268 | def __init__(self): 269 | pass 270 | 271 | @classmethod 272 | def INPUT_TYPES(s): 273 | return { 274 | "required": { 275 | "text": ("STRING", {"forceInput": True}), 276 | }, 277 | "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, 278 | } 279 | 280 | RETURN_TYPES = ("STRING",) 281 | OUTPUT_NODE = True 282 | FUNCTION = "preview_text" 283 | 284 | CATEGORY = "BMAB/basic" 285 | 286 | def preview_text(self, text, prompt=None, extra_pnginfo=None): 287 | return {"ui": {"string": [text, ]}, "result": (text,)} 288 | 289 | 290 | class BMABRemoteAccessAndSave(BMABSaveImage): 291 | 292 | @classmethod 293 | def INPUT_TYPES(s): 294 | return { 295 | 'required': { 296 | 'filename_prefix': ('STRING', {'default': 'bmab'}), 297 | 'format': (['png', 'jpg'], ), 298 | 'use_date': (['disable', 'enable'], ), 299 | 'remote_name': ('STRING', {'multiline': False}), 300 | }, 301 | 'hidden': { 302 | 'prompt': 'PROMPT', 'extra_pnginfo': 'EXTRA_PNGINFO' 303 | }, 304 | 'optional': { 305 | 'bind': ('BMAB bind',), 306 | 'images': ('IMAGE',), 307 | } 308 | } 309 | 310 | RETURN_TYPES = () 311 | FUNCTION = 'remote_save_images' 312 | 313 | OUTPUT_NODE = True 314 | 315 | CATEGORY = 'BMAB/basic' 316 | 317 | def remote_save_images(self, filename_prefix='bmab', format='png', use_date='disable', remote_name=None, prompt=None, extra_pnginfo=None, bind: BMABBind = None, images=None): 318 | return self.save_images(filename_prefix, format, use_date, prompt, extra_pnginfo, bind, images) 319 | 320 | 321 | -------------------------------------------------------------------------------- /bmab/nodes/binder.py: -------------------------------------------------------------------------------- 1 | import numpy 
as np 2 | from PIL import Image 3 | 4 | import copy 5 | 6 | 7 | class BMABContext: 8 | 9 | def __init__(self, *args) -> None: 10 | super().__init__() 11 | self.seed, self.sampler, self.scheduler, self.cfg_scale, self.steps = args 12 | 13 | def get(self): 14 | return self.seed, self.sampler, self.scheduler, self.cfg_scale, self.steps 15 | 16 | def update(self, steps, cfg_scale, sampler, scheduler): 17 | if steps == 0: 18 | steps = self.steps 19 | if cfg_scale == 0: 20 | cfg_scale = self.cfg_scale 21 | if sampler != 'Use same sampler': 22 | sampler = self.sampler 23 | if scheduler != 'Use same scheduler': 24 | scheduler = self.scheduler 25 | return steps, cfg_scale, sampler, scheduler 26 | 27 | 28 | class BMABBind: 29 | 30 | def __init__(self, *args) -> None: 31 | super().__init__() 32 | 33 | self.model, self.clip, self.vae, self.prompt, self.negative_prompt, self.positive, self.negative, self.latent_image, self.context, self.pixels, self.seed = args 34 | 35 | def copy(self): 36 | return copy.copy(self) 37 | 38 | @staticmethod 39 | def result(bind, pixels, *args): 40 | if bind is None: 41 | return (None, pixels, *args) 42 | else: 43 | bind.pixels = pixels 44 | return (bind, bind.pixels, *args) 45 | 46 | def get(self): 47 | return self.model, self.clip, self.vae, self.prompt, self.negative_prompt, self.positive, self.negative, self.latent_image, self.context, self.pixels, self.seed 48 | 49 | 50 | class BMABLoraBind: 51 | def __init__(self, *args) -> None: 52 | super().__init__() 53 | self.loras = [] 54 | 55 | def append(self, *args): 56 | self.loras.append(args) 57 | 58 | def copy(self): 59 | return copy.deepcopy(self) 60 | -------------------------------------------------------------------------------- /bmab/nodes/fill.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from diffusers import AutoencoderKL, TCDScheduler 4 | from diffusers.models.model_loading_utils import load_state_dict 5 | from huggingface_hub import hf_hub_download 6 | 7 | from bmab.external.fill.controlnet_union import ControlNetModel_Union 8 | from bmab.external.fill.pipeline_fill_sd_xl import StableDiffusionXLFillPipeline 9 | 10 | from PIL import Image, ImageDraw, ImageFilter 11 | 12 | from bmab import utils 13 | 14 | 15 | pipe = None 16 | 17 | 18 | def load(): 19 | global pipe 20 | 21 | config_file = hf_hub_download( 22 | "xinsir/controlnet-union-sdxl-1.0", 23 | filename="config_promax.json", 24 | ) 25 | 26 | config = ControlNetModel_Union.load_config(config_file) 27 | controlnet_model = ControlNetModel_Union.from_config(config) 28 | model_file = hf_hub_download( 29 | "xinsir/controlnet-union-sdxl-1.0", 30 | filename="diffusion_pytorch_model_promax.safetensors", 31 | ) 32 | state_dict = load_state_dict(model_file) 33 | model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model( 34 | controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0" 35 | ) 36 | model.to(device="cuda", dtype=torch.float16) 37 | 38 | vae = AutoencoderKL.from_pretrained( 39 | "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 40 | ).to("cuda") 41 | 42 | pipe = StableDiffusionXLFillPipeline.from_pretrained( 43 | "SG161222/RealVisXL_V5.0_Lightning", 44 | torch_dtype=torch.float16, 45 | vae=vae, 46 | controlnet=model, 47 | variant="fp16", 48 | ).to("cuda") 49 | 50 | pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config) 51 | 52 | 53 | def unload(): 54 | global pipe 55 | if pipe is not None: 56 | pipe = None 57 | utils.torch_gc() 58 | 59 | 
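A note on the module-level `pipe` above: `load()` builds the SDXL fill pipeline (ControlNet-Union promax weights, the fp16-fix VAE and a TCD scheduler) once and caches it in the global, the node classes below guard every call with `if pipe is None: load()`, and `unload()` simply drops the reference and calls `utils.torch_gc()` so VRAM can be reclaimed. A minimal, self-contained sketch of that lazy-singleton lifecycle is shown here; `build_pipeline()` is a hypothetical stand-in for the diffusers setup above, not a function from this repo.

```python
import gc

_pipe = None  # module-level cache, mirrors `pipe` in fill.py


def build_pipeline():
    # Hypothetical placeholder for the expensive diffusers construction
    # performed by load() above.
    return object()


def get_pipeline():
    """Build the pipeline on first use, then keep reusing the cached one."""
    global _pipe
    if _pipe is None:
        _pipe = build_pipeline()
    return _pipe


def release_pipeline():
    """Drop the cached pipeline so its memory can be reclaimed."""
    global _pipe
    _pipe = None
    gc.collect()  # the real unload() calls utils.torch_gc() instead
```

The fill nodes below and the `/bmab/free` route in serverext.py rely on exactly this guard-then-release flow.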
60 | class BMABReframe: 61 | 62 | @classmethod 63 | def INPUT_TYPES(s): 64 | return { 65 | 'required': { 66 | 'image': ('IMAGE',), 67 | 'ratio': (['1:1', '4:5', '2:3', '9:16', '5:4', '3:2', '16:9'],), 68 | 'dilation': ('INT', {'default': 32, 'min': 4, 'max': 128, 'step': 1}), 69 | 'step': ('INT', {'default': 8, 'min': 4, 'max': 128, 'step': 1}), 70 | 'iteration': ('INT', {'default': 4, 'min': 1, 'max': 8, 'step': 1}), 71 | 'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}), 72 | } 73 | } 74 | 75 | RETURN_TYPES = ('IMAGE',) 76 | RETURN_NAMES = ('image',) 77 | FUNCTION = 'process' 78 | 79 | CATEGORY = 'BMAB/fill' 80 | 81 | ratio_sel = { 82 | '1:1': (1024, 1024), 83 | '4:5': (960, 1200), 84 | '2:3': (896, 1344), 85 | '9:16': (816, 1456), 86 | '5:4': (1200, 960), 87 | '3:2': (1344, 896), 88 | '16:9': (1456, 816) 89 | } 90 | 91 | def infer(self, image, width, height, overlap_width, num_inference_steps, prompt_input): 92 | source = image 93 | image_ratio = source.width / source.height 94 | output_ratio = width / height 95 | 96 | if output_ratio <= image_ratio: 97 | ratio = width / source.width 98 | else: 99 | ratio = height / source.height 100 | 101 | source = source.resize((math.ceil(source.width * ratio), math.ceil(source.height * ratio)), Image.Resampling.LANCZOS) 102 | background = Image.new('RGB', (width, height), (255, 255, 255)) 103 | mask = Image.new('L', (width, height), 255) 104 | mask_draw = ImageDraw.Draw(mask) 105 | 106 | if output_ratio <= image_ratio: 107 | margin = (height - source.height) // 2 108 | background.paste(source, (0, margin)) 109 | mask_draw.rectangle((0, margin + overlap_width, source.width, margin + source.height - overlap_width), fill=0) 110 | else: 111 | margin = (width - source.width) // 2 112 | background.paste(source, (margin, 0)) 113 | mask_draw.rectangle((margin + overlap_width, 0, margin + source.width - overlap_width, source.height), fill=0) 114 | 115 | cnet_image = background.copy() 116 | cnet_image.paste(0, (0, 0), mask) 117 | 118 | final_prompt = f"{prompt_input} , high quality, 4k" 119 | 120 | if pipe is None: 121 | load() 122 | 123 | ( 124 | prompt_embeds, 125 | negative_prompt_embeds, 126 | pooled_prompt_embeds, 127 | negative_pooled_prompt_embeds, 128 | ) = pipe.encode_prompt(final_prompt, "cuda", True) 129 | 130 | image = pipe( 131 | prompt_embeds=prompt_embeds, 132 | negative_prompt_embeds=negative_prompt_embeds, 133 | pooled_prompt_embeds=pooled_prompt_embeds, 134 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 135 | image=cnet_image, 136 | num_inference_steps=num_inference_steps 137 | ) 138 | 139 | image = image.convert("RGBA") 140 | cnet_image.paste(image, (0, 0), mask) 141 | 142 | return cnet_image 143 | 144 | def process(self, image, ratio, dilation, step, iteration, prompt, **kwargs): 145 | 146 | r = BMABReframe.ratio_sel.get(ratio, (1024, 1024)) 147 | 148 | results = [] 149 | for image in utils.get_pils_from_pixels(image): 150 | for v in range(0, iteration): 151 | a = self.infer(image, r[0], r[1], dilation, step, prompt_input=prompt) 152 | results.append(a) 153 | pixels = utils.get_pixels_from_pils(results) 154 | 155 | return (pixels,) 156 | 157 | 158 | class BMABOutpaintByRatio: 159 | resize_methods = ['stretching', 'inpaint', 'inpaint+lama'] 160 | resize_alignment = ['bottom', 'top', 'top-right', 'right', 'bottom-right', 'bottom-left', 'left', 'top-left', 'center'] 161 | 162 | @classmethod 163 | def INPUT_TYPES(s): 164 | return { 165 | 'required': { 166 | 'image': ('IMAGE',), 167 | 'steps': ('INT', {'default': 
8, 'min': 0, 'max': 10000}), 168 | 'alignment': (s.resize_alignment,), 169 | 'ratio': ('FLOAT', {'default': 0.85, 'min': 0.1, 'max': 0.95, 'step': 0.01}), 170 | 'dilation': ('INT', {'default': 32, 'min': 4, 'max': 128, 'step': 1}), 171 | 'iteration': ('INT', {'default': 4, 'min': 1, 'max': 8, 'step': 1}), 172 | 'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}), 173 | }, 174 | 'optional': { 175 | } 176 | } 177 | 178 | RETURN_TYPES = ('IMAGE', ) 179 | RETURN_NAMES = ('image', ) 180 | FUNCTION = 'process' 181 | 182 | CATEGORY = 'BMAB/fill' 183 | 184 | @staticmethod 185 | def image_alignment(image, left, right, top, bottom, ratio): 186 | left = int(left) 187 | top = int(top) 188 | input_image = image.resize((int(image.width * ratio), int(image.height * ratio)), Image.Resampling.LANCZOS) 189 | background = Image.new('RGB', image.size, (255, 255, 255)) 190 | background.paste(input_image, box=(left, top)) 191 | return background 192 | 193 | @staticmethod 194 | def mask_alignment(width, height, left, right, top, bottom, ratio, dilation): 195 | left = int(left) 196 | top = int(top) 197 | w, h = math.ceil(width * ratio), math.ceil(height * ratio) 198 | mask = Image.new('L', (width, height), 255) 199 | mask_draw = ImageDraw.Draw(mask) 200 | box = ( 201 | 0 if left == 0 else left + dilation, 202 | 0 if top == 0 else top + dilation, 203 | width if (left + w) >= width else (left + w - dilation), 204 | height if (top + h) >= height else (top + h - dilation) 205 | ) 206 | mask_draw.rectangle(box, fill=0) 207 | return mask 208 | 209 | def infer(self, image, al, ratio, dilation, num_inference_steps, prompt_input): 210 | if al not in utils.alignment: 211 | return image 212 | w, h = math.ceil(image.width * (1 - ratio)), math.ceil(image.height * (1 - ratio)) 213 | background = BMABOutpaintByRatio.image_alignment(image, *utils.alignment[al](w, h), ratio) 214 | mask = BMABOutpaintByRatio.mask_alignment(image.width, image.height, *utils.alignment[al](w, h), ratio, dilation) 215 | 216 | cnet_image = background.copy() 217 | cnet_image.paste(0, (0, 0), mask) 218 | 219 | final_prompt = f"{prompt_input} , high quality, 4k" 220 | 221 | if pipe is None: 222 | load() 223 | 224 | ( 225 | prompt_embeds, 226 | negative_prompt_embeds, 227 | pooled_prompt_embeds, 228 | negative_pooled_prompt_embeds, 229 | ) = pipe.encode_prompt(final_prompt, "cuda", True) 230 | 231 | image = pipe( 232 | prompt_embeds=prompt_embeds, 233 | negative_prompt_embeds=negative_prompt_embeds, 234 | pooled_prompt_embeds=pooled_prompt_embeds, 235 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 236 | image=cnet_image, 237 | num_inference_steps=num_inference_steps 238 | ) 239 | 240 | return image 241 | 242 | def process(self, image, steps, alignment, ratio, dilation, iteration, prompt): 243 | 244 | results = [] 245 | for image in utils.get_pils_from_pixels(image): 246 | 247 | print('Process image resize', ratio) 248 | for r in range(0, iteration): 249 | a = self.infer(image, alignment, ratio, dilation, steps, prompt_input=prompt) 250 | results.append(a) 251 | 252 | pixels = utils.get_pixels_from_pils(results) 253 | return (pixels,) 254 | 255 | 256 | class BMABInpaint: 257 | 258 | @classmethod 259 | def INPUT_TYPES(s): 260 | return { 261 | 'required': { 262 | 'image': ('IMAGE',), 263 | 'mask': ('MASK',), 264 | 'steps': ('INT', {'default': 8, 'min': 0, 'max': 10000}), 265 | 'iteration': ('INT', {'default': 4, 'min': 1, 'max': 8, 'step': 1}), 266 | 'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}), 267 | }, 268 
| 'optional': { 269 | 'seed': ('SEED',) 270 | } 271 | } 272 | 273 | RETURN_TYPES = ('IMAGE', ) 274 | RETURN_NAMES = ('image', ) 275 | FUNCTION = 'process' 276 | 277 | CATEGORY = 'BMAB/fill' 278 | 279 | def infer(self, image, mask, steps, prompt_input): 280 | 281 | source = image 282 | source.paste((255, 255, 255), (0, 0), mask) 283 | 284 | cnet_image = source.copy() 285 | cnet_image.paste(0, (0, 0), mask) 286 | 287 | final_prompt = f"{prompt_input} , high quality, 4k" 288 | 289 | if pipe is None: 290 | load() 291 | 292 | ( 293 | prompt_embeds, 294 | negative_prompt_embeds, 295 | pooled_prompt_embeds, 296 | negative_pooled_prompt_embeds, 297 | ) = pipe.encode_prompt(final_prompt, "cuda", True) 298 | 299 | image = pipe( 300 | prompt_embeds=prompt_embeds, 301 | negative_prompt_embeds=negative_prompt_embeds, 302 | pooled_prompt_embeds=pooled_prompt_embeds, 303 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 304 | image=cnet_image, 305 | num_inference_steps=steps 306 | ) 307 | return image 308 | 309 | def mask_to_image(self, mask): 310 | result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) 311 | return utils.get_pils_from_pixels(result)[0].convert('L') 312 | 313 | def process(self, image, mask, steps, iteration, prompt, seed=None): 314 | 315 | results = [] 316 | mask = self.mask_to_image(mask) 317 | for image in utils.get_pils_from_pixels(image): 318 | for r in range(0, iteration): 319 | a = self.infer(image, mask, steps, prompt_input=prompt) 320 | results.append(a) 321 | 322 | pixels = utils.get_pixels_from_pils(results) 323 | return (pixels,) 324 | 325 | -------------------------------------------------------------------------------- /bmab/nodes/loaders.py: -------------------------------------------------------------------------------- 1 | import folder_paths 2 | from bmab.nodes.binder import BMABLoraBind 3 | 4 | 5 | class BMABLoraLoader: 6 | @classmethod 7 | def INPUT_TYPES(s): 8 | return { 9 | 'required': { 10 | 'lora_name': (folder_paths.get_filename_list('loras'), ), 11 | 'strength_model': ('FLOAT', {'default': 1.0, 'min': -100.0, 'max': 100.0, 'step': 0.01}), 12 | 'strength_clip': ('FLOAT', {'default': 1.0, 'min': -100.0, 'max': 100.0, 'step': 0.01}), 13 | }, 14 | 'optional': { 15 | 'lora': ('BMAB lora',), 16 | } 17 | } 18 | 19 | RETURN_TYPES = ('BMAB lora', ) 20 | RETURN_NAMES = ('lora', ) 21 | FUNCTION = 'load_lora' 22 | 23 | CATEGORY = 'BMAB/loader' 24 | 25 | def load_lora(self, lora_name, strength_model, strength_clip, lora: BMABLoraBind=None): 26 | if lora is None: 27 | lora = BMABLoraBind() 28 | else: 29 | lora = lora.copy() 30 | lora.append(lora_name, strength_model, strength_clip) 31 | return (lora, ) 32 | 33 | -------------------------------------------------------------------------------- /bmab/nodes/toy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | import bmab 5 | 6 | from bmab import utils 7 | from server import PromptServer 8 | 9 | 10 | question = ''' 11 | make detailed prompt for stable diffusion using keyword "{text}" about scene, lighting, face, pose, clothes, background and colors in only 1 sentence. The sentence is describes very detailed. Do not say about human race. 
12 | ''' 13 | 14 | 15 | class BMABGoogleGemini: 16 | 17 | def __init__(self) -> None: 18 | super().__init__() 19 | self.last_prompt = None 20 | self.last_text = None 21 | 22 | @classmethod 23 | def INPUT_TYPES(s): 24 | return { 25 | 'required': { 26 | 'prompt': ('STRING', {'multiline': True, 'dynamicPrompts': True}), 27 | 'text': ('STRING', {'multiline': True, 'dynamicPrompts': True}), 28 | 'api_key': ('STRING', {'multiline': False}), 29 | 'random_seed': ('INT', {'default': 0, 'min': 0, 'max': 65536, 'step': 1}), 30 | }, 31 | } 32 | 33 | RETURN_TYPES = ('STRING',) 34 | RETURN_NAMES = ('string', ) 35 | FUNCTION = 'prompt' 36 | 37 | CATEGORY = 'BMAB/toy' 38 | 39 | def get_prompt(self, text, api_key): 40 | import google.generativeai as genai 41 | genai.configure(api_key=api_key) 42 | model = genai.GenerativeModel('gemini-pro') 43 | text = text.strip() 44 | response = model.generate_content(question.format(text=text)) 45 | try: 46 | self.last_prompt = response.text 47 | print(response.text) 48 | cache_path = os.path.join(os.path.dirname(bmab.__file__), '../resources/cache') 49 | cache_file = os.path.join(cache_path, 'gemini.txt') 50 | with open(cache_file, 'a', encoding='utf8') as f: 51 | f.write(time.strftime('%Y.%m.%d - %H:%M:%S')) 52 | f.write('\n') 53 | f.write(self.last_prompt) 54 | f.write('\n') 55 | except: 56 | print('ERROR reading API response', response) 57 | self.last_text = None 58 | PromptServer.instance.send_sync("stop-iteration", {}) 59 | 60 | return self.last_prompt 61 | 62 | def prompt(self, prompt: str, text: str, api_key, random_seed=None, **kwargs): 63 | random_seed = random.randint(0, 65535) 64 | if prompt.find('__prompt__') >= 0: 65 | if self.last_text != text: 66 | random_seed = random.randint(0, 65535) 67 | self.last_text = text 68 | self.get_prompt(text, api_key) 69 | if self.last_prompt is None: 70 | PromptServer.instance.send_sync("stop-iteration", {}) 71 | prompt = prompt.replace('__prompt__', self.last_prompt) 72 | result = utils.parse_prompt(prompt, random_seed) 73 | return {"ui": {"string": [str(random_seed), ]}, "result": (result,)} 74 | -------------------------------------------------------------------------------- /bmab/nodes/upscaler.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from PIL import ImageDraw 3 | 4 | from comfy_extras.chainner_models import model_loading 5 | from comfy import model_management 6 | import torch 7 | import comfy.utils 8 | import folder_paths 9 | 10 | import nodes 11 | from bmab import utils 12 | from bmab.nodes.binder import BMABBind 13 | 14 | 15 | class BMABUpscale: 16 | upscale_methods = ['LANCZOS', 'NEAREST', 'BILINEAR', 'BICUBIC'] 17 | 18 | @classmethod 19 | def INPUT_TYPES(s): 20 | return { 21 | 'required': { 22 | 'upscale_method': (BMABUpscale.upscale_methods, ), 23 | 'scale': ('FLOAT', {'default': 2.0, 'min': 0, 'max': 4.0, 'step': 0.001}), 24 | 'width': ('INT', {'default': 512, 'min': 32, 'max': nodes.MAX_RESOLUTION, 'step': 8}), 25 | 'height': ('INT', {'default': 512, 'min': 32, 'max': nodes.MAX_RESOLUTION, 'step': 8}), 26 | }, 27 | 'optional': { 28 | 'bind': ('BMAB bind',), 29 | 'image': ('IMAGE',), 30 | }, 31 | } 32 | 33 | RETURN_TYPES = ('BMAB bind', 'IMAGE',) 34 | RETURN_NAMES = ('BMAB bind', 'image', ) 35 | FUNCTION = 'upscale' 36 | 37 | CATEGORY = 'BMAB/upscale' 38 | 39 | def upscale(self, upscale_method, scale, width, height, bind: BMABBind=None, image=None): 40 | pixels = bind.pixels if image is None else image 41 | pil_upscale_methods = { 42 | 
'LANCZOS': Image.Resampling.LANCZOS, 43 | 'BILINEAR': Image.Resampling.BILINEAR, 44 | 'BICUBIC': Image.Resampling.BICUBIC, 45 | 'NEAREST': Image.Resampling.NEAREST, 46 | } 47 | results = [] 48 | for bgimg in utils.get_pils_from_pixels(pixels): 49 | if scale != 0: 50 | width, height = int(bgimg.width * scale), int(bgimg.height * scale) 51 | method = pil_upscale_methods.get(upscale_method) 52 | results.append(bgimg.resize((width, height), method)) 53 | pixels = utils.get_pixels_from_pils(results) 54 | return BMABBind.result(bind, pixels, ) 55 | 56 | 57 | class BMABUpscaleWithModel: 58 | @classmethod 59 | def INPUT_TYPES(s): 60 | return { 61 | "required": { 62 | "model_name": (folder_paths.get_filename_list("upscale_models"),), 63 | 'scale': ('FLOAT', {'default': 2.0, 'min': 0, 'max': 4.0, 'step': 0.001}), 64 | 'width': ('INT', {'default': 512, 'min': 0, 'max': nodes.MAX_RESOLUTION, 'step': 8}), 65 | 'height': ('INT', {'default': 512, 'min': 0, 'max': nodes.MAX_RESOLUTION, 'step': 8}), 66 | }, 67 | 'optional': { 68 | 'bind': ('BMAB bind',), 69 | 'image': ('IMAGE',), 70 | }, 71 | } 72 | 73 | RETURN_TYPES = ('BMAB bind', "IMAGE",) 74 | RETURN_NAMES = ('BMAB bind', 'image', ) 75 | FUNCTION = "upscale" 76 | 77 | CATEGORY = "BMAB/upscale" 78 | 79 | def load_model(self, model_name): 80 | model_path = folder_paths.get_full_path("upscale_models", model_name) 81 | sd = comfy.utils.load_torch_file(model_path, safe_load=True) 82 | if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: 83 | sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": ""}) 84 | out = model_loading.load_state_dict(sd).eval() 85 | return out 86 | 87 | def upscale_with_model(self, model_name, pixels, progress=True): 88 | upscale_model = self.load_model(model_name) 89 | device = model_management.get_torch_device() 90 | 91 | memory_required = model_management.module_size(upscale_model.model) 92 | memory_required += (512 * 512 * 3) * pixels.element_size() * max(upscale_model.scale, 1.0) * 384.0 # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate 93 | memory_required += pixels.nelement() * pixels.element_size() 94 | model_management.free_memory(memory_required, device) 95 | 96 | upscale_model.to(device) 97 | in_img = pixels.movedim(-1, -3).to(device) 98 | 99 | tile = 512 100 | overlap = 32 101 | 102 | oom = True 103 | while oom: 104 | try: 105 | if progress: 106 | steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) 107 | pbar = comfy.utils.ProgressBar(steps) 108 | s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) 109 | else: 110 | s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale) 111 | oom = False 112 | except model_management.OOM_EXCEPTION as e: 113 | tile //= 2 114 | if tile < 128: 115 | raise e 116 | 117 | upscale_model.to("cpu") 118 | s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0) 119 | return (s,) 120 | 121 | def upscale(self, model_name, scale, width, height, bind: BMABBind=None, image=None): 122 | pixels = bind.pixels if image is None else image 123 | if scale != 0: 124 | _, h, w, c = pixels.shape 125 | width, height = int(w * scale), int(h * scale) 126 | 127 | s = self.upscale_with_model(model_name, pixels) 128 | pil_images = utils.get_pils_from_pixels(s) 129 | results = [img.resize((width, 
height), Image.Resampling.LANCZOS) for img in pil_images] 130 | pixels = utils.get_pixels_from_pils(results) 131 | 132 | return BMABBind.result(bind, pixels,) 133 | -------------------------------------------------------------------------------- /bmab/nodes/utilnode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | import base64 6 | from io import BytesIO 7 | from PIL import Image 8 | 9 | import bmab 10 | from bmab import utils 11 | from bmab.nodes.binder import BMABBind 12 | from bmab import serverext 13 | 14 | 15 | class BMABModelToBind: 16 | 17 | @classmethod 18 | def INPUT_TYPES(s): 19 | return { 20 | 'required': { 21 | 'bind': ('BMAB bind',), 22 | }, 23 | 'optional': { 24 | 'model': ('MODEL',), 25 | 'clip': ('CLIP',), 26 | 'vae': ('VAE',), 27 | } 28 | } 29 | 30 | RETURN_TYPES = ('BMAB bind', ) 31 | RETURN_NAMES = ('bind', ) 32 | FUNCTION = 'process' 33 | 34 | CATEGORY = 'BMAB/utils' 35 | 36 | def process(self, bind: BMABBind, model=None, clip=None, vae=None): 37 | if model is not None: 38 | bind.model = model 39 | if clip is not None: 40 | bind.clip = clip 41 | if vae is not None: 42 | bind.vae = vae 43 | return (bind, ) 44 | 45 | 46 | class BMABConditioningToBind: 47 | 48 | @classmethod 49 | def INPUT_TYPES(s): 50 | return { 51 | 'required': { 52 | 'bind': ('BMAB bind',), 53 | }, 54 | 'optional': { 55 | 'positive': ('CONDITIONING',), 56 | 'negative': ('CONDITIONING',), 57 | } 58 | } 59 | 60 | RETURN_TYPES = ('BMAB bind', ) 61 | RETURN_NAMES = ('bind', ) 62 | FUNCTION = 'process' 63 | 64 | CATEGORY = 'BMAB/utils' 65 | 66 | def process(self, bind: BMABBind, positive=None, negative=None): 67 | if positive is not None: 68 | bind.positive = positive 69 | if negative is not None: 70 | bind.negative = negative 71 | return (bind, ) 72 | 73 | 74 | class BMABNoiseGenerator: 75 | 76 | @classmethod 77 | def INPUT_TYPES(s): 78 | return { 79 | 'required': { 80 | 'width': ('INT', {'default': 512, 'min': 256, 'max': 2048, 'step': 8}), 81 | 'height': ('INT', {'default': 512, 'min': 256, 'max': 2048, 'step': 8}), 82 | }, 83 | 'optional': { 84 | 'bind': ('BMAB bind',), 85 | 'latent': ('LATENT',), 86 | } 87 | } 88 | 89 | RETURN_TYPES = ('IMAGE', ) 90 | RETURN_NAMES = ('image', ) 91 | FUNCTION = 'generate' 92 | 93 | CATEGORY = 'BMAB/utils' 94 | 95 | @staticmethod 96 | def generate_noise(seed, width, height): 97 | img_1 = np.zeros([height, width, 3], dtype=np.uint8) 98 | # Generate random Gaussian noise 99 | mean = 0 100 | stddev = 180 101 | r, g, b = cv2.split(img_1) 102 | # cv2.setRNGSeed(seed) 103 | cv2.randn(r, mean, stddev) 104 | cv2.randn(g, mean, stddev) 105 | cv2.randn(b, mean, stddev) 106 | img = cv2.merge([r, g, b]) 107 | pil_image = Image.fromarray(img, mode='RGB') 108 | return pil_image 109 | 110 | def generate(self, width, height, bind: BMABBind=None, latent=None): 111 | if bind is not None: 112 | width, height = utils.get_shape(bind.latent_image) 113 | if latent is not None: 114 | width, height = utils.get_shape(latent) 115 | 116 | cache_path = os.path.join(os.path.dirname(bmab.__file__), '../resources/cache') 117 | filename = f'noise_{width}_{height}.png' 118 | full_path = os.path.join(cache_path, filename) 119 | if os.path.exists(full_path) and os.path.isfile(full_path): 120 | noise = Image.open(full_path) 121 | return (utils.get_pixels_from_pils([noise]), ) 122 | 123 | noise = self.generate_noise(0, width, height) 124 | noise.save(full_path) 125 | return (utils.get_pixels_from_pils([noise]),) 
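The `generate` method above caches its noise as a PNG keyed only by resolution under resources/cache, so every run at the same width and height returns the identical noise image; the `seed` argument of `generate_noise` currently has no effect because the `cv2.setRNGSeed(seed)` call is commented out. Below is a simplified, self-contained sketch of that size-keyed caching behavior, assuming an in-memory dict and numpy noise in place of the PNG files and `cv2.randn`.

```python
import numpy as np
from PIL import Image

_noise_cache = {}  # (width, height) -> PIL.Image, stands in for the cached PNG files


def get_noise(width, height, stddev=180):
    """Return Gaussian RGB noise, reusing the same image for a given size."""
    key = (width, height)
    if key not in _noise_cache:
        data = np.random.normal(0, stddev, (height, width, 3))
        _noise_cache[key] = Image.fromarray(np.clip(data, 0, 255).astype(np.uint8), mode='RGB')
    return _noise_cache[key]


# Because the cache is keyed only by size, repeated calls at the same
# resolution hand back the exact same noise pattern, just like the PNG cache above.
assert get_noise(512, 512) is get_noise(512, 512)
```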
126 | 127 | 128 | class BMABBase64Image: 129 | 130 | def __init__(self) -> None: 131 | super().__init__() 132 | 133 | @classmethod 134 | def INPUT_TYPES(s): 135 | return { 136 | 'required': { 137 | 'encoding': ('STRING', {'multiline': True, 'dynamicPrompts': True}), 138 | }, 139 | } 140 | 141 | RETURN_TYPES = ('IMAGE', 'INT', 'INT') 142 | RETURN_NAMES = ('image', 'width', 'height') 143 | FUNCTION = 'process' 144 | 145 | CATEGORY = 'BMAB/utils' 146 | 147 | def process(self, encoding): 148 | results = [] 149 | pil = Image.open(BytesIO(base64.b64decode(encoding))) 150 | results.append(pil) 151 | return utils.get_pixels_from_pils(results), pil.width, pil.height 152 | 153 | 154 | class BMABImageStorage: 155 | 156 | def __init__(self) -> None: 157 | super().__init__() 158 | 159 | @classmethod 160 | def INPUT_TYPES(s): 161 | return { 162 | 'required': { 163 | 'images': ('IMAGE',), 164 | 'client_id': ('STRING', {'multiline': False}), 165 | }, 166 | } 167 | 168 | RETURN_TYPES = () 169 | FUNCTION = 'process' 170 | 171 | CATEGORY = 'BMAB/utils' 172 | OUTPUT_NODE = True 173 | 174 | def process(self, images, client_id): 175 | results = [] 176 | for image in utils.get_pils_from_pixels(images): 177 | results.append(image) 178 | serverext.memory_image_storage[client_id] = results 179 | return {'ui': {'images': []}} 180 | 181 | 182 | class BMABNormalizeSize: 183 | 184 | def __init__(self) -> None: 185 | super().__init__() 186 | 187 | @classmethod 188 | def INPUT_TYPES(s): 189 | return { 190 | 'required': { 191 | 'width': ('INT', {'default': 512, 'min': 256, 'max': 2048, 'step': 8}), 192 | 'height': ('INT', {'default': 768, 'min': 256, 'max': 2048, 'step': 8}), 193 | 'normalize': ('INT', {'default': 768, 'min': 256, 'max': 2048, 'step': 8}), 194 | }, 195 | } 196 | 197 | RETURN_TYPES = ('INT', 'INT') 198 | RETURN_NAMES = ('width', 'height') 199 | FUNCTION = 'process' 200 | 201 | CATEGORY = 'BMAB/utils' 202 | OUTPUT_NODE = True 203 | 204 | def process(self, width, height, normalize): 205 | print(width, height) 206 | if height > width: 207 | ratio = normalize / height 208 | w, h = int(width * ratio), normalize 209 | else: 210 | ratio = normalize / width 211 | w, h = normalize, int(height * ratio) 212 | return w, h 213 | 214 | 215 | class BMABDummy: 216 | 217 | def __init__(self) -> None: 218 | super().__init__() 219 | 220 | @classmethod 221 | def INPUT_TYPES(s): 222 | return { 223 | 'required': { 224 | 'images': ('IMAGE',), 225 | 'seed': ('INT', {'default': 0, 'min': 0, 'max': 0xffffffffffffffff, 'tooltip': 'The random seed used for creating the noise.'}), 226 | }, 227 | } 228 | 229 | RETURN_TYPES = ('IMAGE', ) 230 | RETURN_NAMES = ('images', ) 231 | FUNCTION = 'process' 232 | 233 | CATEGORY = 'BMAB/utils' 234 | OUTPUT_NODE = True 235 | 236 | def process(self, images, seed): 237 | return (images, ) 238 | 239 | -------------------------------------------------------------------------------- /bmab/nodes/watermark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | 5 | from PIL import Image 6 | from PIL import ImageDraw 7 | from PIL import ImageFont 8 | 9 | from bmab import utils 10 | 11 | 12 | class BMABWatermark: 13 | alignment = { 14 | 'bottom-left': lambda w, h, cx, cy: (0, h - cy), 15 | 'top': lambda w, h, cx, cy: (w / 2 - cx / 2, 0), 16 | 'top-right': lambda w, h, cx, cy: (w - cx, 0), 17 | 'right': lambda w, h, cx, cy: (w - cx, h / 2 - cy / 2), 18 | 'bottom-right': lambda w, h, cx, cy: (w - cx, h - cy), 19 | 'bottom': lambda 
w, h, cx, cy: (w / 2 - cx / 2, h - cy), 20 | 'left': lambda w, h, cx, cy: (0, h / 2 - cy / 2), 21 | 'top-left': lambda w, h, cx, cy: (0, 0), 22 | 'center': lambda w, h, cx, cy: (w / 2 - cx / 2, h / 2 - cy / 2), 23 | } 24 | 25 | 26 | @classmethod 27 | def INPUT_TYPES(s): 28 | return { 29 | 'required': { 30 | 'font': (s.list_fonts(),), 31 | 'alignment': ([x for x in s.alignment.keys()], ), 32 | 'text_alignment': (['left', 'right', 'center'], ), 33 | 'rotate': ([0, 90, 180, 270], ), 34 | 'color': ('STRING', {'default': '#000000'}), 35 | 'background_color': ('STRING', {'default': '#000000'}), 36 | 'font_size': ('INT', {'default': 12, 'min': 4, 'max': 128}), 37 | 'transparency': ('INT', {'default': 100, 'min': 0, 'max': 100}), 38 | 'background_transparency': ('INT', {'default': 0, 'min': 0, 'max': 100}), 39 | 'margin': ('INT', {'default': 5, 'min': 0, 'max': 100}), 40 | 'text': ('STRING', {'multiline': True}), 41 | }, 42 | 'optional': { 43 | 'bind': ('BMAB bind',), 44 | 'image': ('IMAGE',), 45 | } 46 | } 47 | 48 | RETURN_TYPES = ('BMAB bind', 'IMAGE',) 49 | RETURN_NAMES = ('BMAB bind', 'image',) 50 | FUNCTION = 'process' 51 | 52 | CATEGORY = 'BMAB/basic' 53 | 54 | def process_watermark(self, img, font, alignment, text_alignment, rotate, color, background_color, font_size, transparency, background_transparency, margin, text): 55 | background_color = self.color_hex_to_rgb(background_color, int(255 * (background_transparency / 100))) 56 | 57 | if os.path.isfile(text): 58 | cropped = Image.open(text) 59 | else: 60 | font = self.get_font(font, font_size) 61 | color = self.color_hex_to_rgb(color, int(255 * (transparency / 100))) 62 | 63 | # 1st 64 | base = Image.new('RGBA', img.size, background_color) 65 | draw = ImageDraw.Draw(base) 66 | bbox = draw.textbbox((0, 0), text, font=font) 67 | draw.text((0, 0), text, font=font, fill=color, align=text_alignment) 68 | cropped = base.crop(bbox) 69 | 70 | # 2st margin 71 | base = Image.new('RGBA', (cropped.width + margin * 2, cropped.height + margin * 2), background_color) 72 | base.paste(cropped, (margin, margin)) 73 | 74 | # 3rd rotate 75 | base = base.rotate(rotate, expand=True) 76 | 77 | # 4th 78 | image = img.convert('RGBA') 79 | image2 = image.copy() 80 | x, y = BMABWatermark.alignment[alignment](image.width, image.height, base.width, base.height) 81 | image2.paste(base, (int(x), int(y))) 82 | return Image.alpha_composite(image, image2) 83 | 84 | def process(self, image=None, bind=None, **kwargs): 85 | pixels = bind.pixels if image is None else image 86 | results = [] 87 | for img in utils.get_pils_from_pixels(pixels): 88 | results.append(self.process_watermark(img, **kwargs)) 89 | pixels = utils.get_pixels_from_pils(results) 90 | return (bind, pixels, ) 91 | 92 | @staticmethod 93 | def color_hex_to_rgb(value, transparency): 94 | value = value.lstrip('#') 95 | lv = len(value) 96 | r, g, b = tuple(int(value[i:i + 2], 16) for i in range(0, lv, 2)) 97 | return r, g, b, transparency 98 | 99 | @staticmethod 100 | def list_fonts(): 101 | if sys.platform == 'win32': 102 | path = 'C:\\Windows\\Fonts\\*.ttf' 103 | files = glob.glob(path) 104 | return [os.path.basename(f) for f in files] 105 | if sys.platform == 'darwin': 106 | path = '/System/Library/Fonts/*' 107 | files = glob.glob(path) 108 | return [os.path.basename(f) for f in files] 109 | if sys.platform == 'linux': 110 | path = '/usr/share/fonts/*' 111 | files = glob.glob(path) 112 | fonts = [os.path.basename(f) for f in files] 113 | if 'SAGEMAKER_INTERNAL_IMAGE_URI' in os.environ: 114 | path = 
'/opt/conda/envs/sagemaker-distribution/fonts/*' 115 | files = glob.glob(path) 116 | fonts.extend([os.path.basename(f) for f in files]) 117 | return fonts 118 | return [''] 119 | 120 | @staticmethod 121 | def get_font(font, size): 122 | if sys.platform == 'win32': 123 | path = f'C:\\Windows\\Fonts\\{font}' 124 | return ImageFont.truetype(path, size, encoding="unic") 125 | if sys.platform == 'darwin': 126 | path = f'/System/Library/Fonts/{font}' 127 | return ImageFont.truetype(path, size, encoding="unic") 128 | if sys.platform == 'linux': 129 | if 'SAGEMAKER_INTERNAL_IMAGE_URI' in os.environ: 130 | path = f'/opt/conda/envs/sagemaker-distribution/fonts/{font}' 131 | else: 132 | path = f'/usr/share/fonts/{font}' 133 | return ImageFont.truetype(path, size, encoding="unic") 134 | -------------------------------------------------------------------------------- /bmab/process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from PIL import Image 4 | from PIL import ImageDraw 5 | from PIL import ImageFilter 6 | 7 | import comfy 8 | import nodes 9 | import folder_paths 10 | 11 | from bmab import utils 12 | from bmab.nodes import BMABBind 13 | from bmab.utils.color import apply_color_correction 14 | 15 | 16 | def process_img2img(bind: BMABBind, image, params): 17 | steps, cfg, sampler_name, scheduler, denoise = params['steps'], params['cfg_scale'], params['sampler_name'], params['scheduler'], params['denoise'] 18 | pixels = utils.pil2tensor(image.convert('RGB')) 19 | latent = dict(samples=bind.vae.encode(pixels)) 20 | samples = nodes.common_ksampler(bind.model, bind.seed, steps, cfg, sampler_name, scheduler, bind.positive, bind.negative, latent, denoise=denoise)[0] 21 | latent = bind.vae.decode(samples["samples"]) 22 | result = utils.tensor2pil(latent) 23 | return result 24 | 25 | 26 | def process_img2img_with_mask(bind: BMABBind, image, params, mask=None, box=None): 27 | width, height, padding, dilation = params['width'], params['height'], params['padding'], params['dilation'] 28 | 29 | if box is None: 30 | box = mask.getbbox() 31 | if box is None: 32 | return image 33 | 34 | if mask is None: 35 | mask = Image.new('L', image.size, 0) 36 | dr = ImageDraw.Draw(mask, 'L') 37 | dr.rectangle(box, fill=255) 38 | 39 | x1, y1, x2, y2 = tuple(int(x) for x in box) 40 | 41 | cbx = utils.get_box_with_padding(image, (x1, y1, x2, y2), padding) 42 | cropped = image.crop(cbx) 43 | resized = utils.resize_and_fill(cropped, width, height) 44 | processed = process_img2img(bind, resized, params) 45 | processed = apply_color_correction(resized, processed) 46 | 47 | iratio = width / height 48 | cratio = cropped.width / cropped.height 49 | if iratio < cratio: 50 | ratio = cropped.width / width 51 | processed = processed.resize((int(processed.width * ratio), int(processed.height * ratio))) 52 | y0 = (processed.height - cropped.height) // 2 53 | processed = processed.crop((0, y0, cropped.width, y0 + cropped.height)) 54 | else: 55 | ratio = cropped.height / height 56 | processed = processed.resize((int(processed.width * ratio), int(processed.height * ratio))) 57 | x0 = (processed.width - cropped.width) // 2 58 | processed = processed.crop((x0, 0, x0 + cropped.width, cropped.height)) 59 | 60 | img = image.copy() 61 | img.paste(processed, (cbx[0], cbx[1])) 62 | 63 | pil_mask = utils.dilate_mask(mask, dilation) 64 | blur = ImageFilter.GaussianBlur(dilation) 65 | blur_mask = pil_mask.filter(blur) 66 | 67 | image.paste(img, (0, 0), mask=blur_mask) 68 | return image 69 | 
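Both `process_img2img()` and `process_img2img_with_mask()` above read their settings from a plain `params` dict rather than named arguments, and the masked variant forwards the same dict to `process_img2img()` for the actual resampling. The expected keys are scattered across the two functions, so here is a hedged sketch of a caller assembling that dict; the `detail_region` name and all concrete values are illustrative only and do not come from this repo.

```python
from PIL import Image


def detail_region(bind, image: Image.Image, box):
    # process_img2img_with_mask is the function defined directly above in this file.
    params = {
        # sampling settings consumed by process_img2img()
        'steps': 20,
        'cfg_scale': 7.0,
        'sampler_name': 'dpmpp_sde',
        'scheduler': 'karras',
        'denoise': 0.4,
        # crop/merge settings consumed by process_img2img_with_mask()
        'width': 768,      # resolution the cropped region is resized to before resampling
        'height': 768,
        'padding': 32,     # extra pixels kept around the detected box when cropping
        'dilation': 4,     # mask grow and Gaussian blur radius used when pasting back
    }
    # With mask=None, a rectangular mask is built from `box`, the crop is
    # redrawn at high resolution, color-corrected, and blended back into
    # the original image with the dilated, blurred mask.
    return process_img2img_with_mask(bind, image, params, box=box)
```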
70 | 71 | def load_controlnet(control_net_name): 72 | controlnet_path = folder_paths.get_full_path('controlnet', control_net_name) 73 | controlnet = comfy.controlnet.load_controlnet(controlnet_path) 74 | return controlnet 75 | 76 | 77 | def apply_controlnet(control_net_name, positive, negative, strength, start_percent, end_percent, image): 78 | control_net = load_controlnet(control_net_name) 79 | 80 | control_hint = image.movedim(-1, 1) 81 | cnets = {} 82 | 83 | out = [] 84 | for conditioning in [positive, negative]: 85 | c = [] 86 | for t in conditioning: 87 | d = t[1].copy() 88 | 89 | prev_cnet = d.get('control', None) 90 | if prev_cnet in cnets: 91 | c_net = cnets[prev_cnet] 92 | else: 93 | c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent)) 94 | c_net.set_previous_controlnet(prev_cnet) 95 | cnets[prev_cnet] = c_net 96 | 97 | d['control'] = c_net 98 | d['control_apply_to_uncond'] = False 99 | n = [t[0], d] 100 | c.append(n) 101 | out.append(c) 102 | return out[0], out[1] 103 | 104 | 105 | def preprocess(image, mask): 106 | mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(image.shape[1], image.shape[2]), mode="bilinear") 107 | mask = mask.movedim(1, -1).expand((-1, -1, -1, 3)) 108 | image = image.clone() 109 | image[mask > 0.5] = -1.0 110 | return image 111 | 112 | 113 | def process_img2img_with_controlnet(bind: BMABBind, image, params, controlnet_name, mask=None): 114 | steps, cfg, sampler_name, scheduler, denoise = params['steps'], params['cfg_scale'], params['sampler_name'], params['scheduler'], params['denoise'] 115 | 116 | pixels = utils.pil2tensor(image.convert('RGB')) 117 | latent = dict(samples=bind.vae.encode(pixels)) 118 | 119 | if mask is not None: 120 | mask_pixels = utils.pil2tensor(mask.convert('RGB')) 121 | cn_pixels = preprocess(pixels, mask_pixels[:, :, :, 0]) 122 | else: 123 | cn_pixels = pixels 124 | positive, negative = apply_controlnet(controlnet_name, bind.positive, bind.negative, 1.0, 0.0, 1.0, cn_pixels) 125 | 126 | samples = nodes.common_ksampler(bind.model, bind.seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=denoise)[0] 127 | latent = bind.vae.decode(samples["samples"]) 128 | result = utils.tensor2pil(latent) 129 | return result 130 | 131 | 132 | def process_img2img_with_controlnet_mask(bind: BMABBind, controlnet_name, image, params, mask=None, box=None): 133 | width, height, padding, dilation = params['width'], params['height'], params['padding'], params['dilation'] 134 | 135 | if box is None: 136 | box = mask.getbbox() 137 | if box is None: 138 | return image 139 | 140 | if mask is None: 141 | mask = Image.new('L', image.size, 0) 142 | dr = ImageDraw.Draw(mask, 'L') 143 | dr.rectangle(box, fill=255) 144 | 145 | x1, y1, x2, y2 = tuple(int(x) for x in box) 146 | 147 | cbx = utils.get_box_with_padding(image, (x1, y1, x2, y2), padding) 148 | cropped = image.crop(cbx) 149 | resized = utils.resize_and_fill(cropped, width, height) 150 | processed = process_img2img_with_controlnet(bind, resized, params, controlnet_name, mask=mask) 151 | processed = apply_color_correction(resized, processed) 152 | 153 | iratio = width / height 154 | cratio = cropped.width / cropped.height 155 | if iratio < cratio: 156 | ratio = cropped.width / width 157 | processed = processed.resize((int(processed.width * ratio), int(processed.height * ratio))) 158 | y0 = (processed.height - cropped.height) // 2 159 | processed = processed.crop((0, y0, cropped.width, y0 + 
cropped.height)) 160 | else: 161 | ratio = cropped.height / height 162 | processed = processed.resize((int(processed.width * ratio), int(processed.height * ratio))) 163 | x0 = (processed.width - cropped.width) // 2 164 | processed = processed.crop((x0, 0, x0 + cropped.width, cropped.height)) 165 | 166 | img = image.copy() 167 | img.paste(processed, (cbx[0], cbx[1])) 168 | 169 | pil_mask = utils.dilate_mask(mask, dilation) 170 | blur = ImageFilter.GaussianBlur(dilation) 171 | blur_mask = pil_mask.filter(blur) 172 | 173 | image.paste(img, (0, 0), mask=blur_mask) 174 | return image 175 | 176 | 177 | 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /bmab/serverext.py: -------------------------------------------------------------------------------- 1 | from server import PromptServer 2 | from aiohttp import web 3 | 4 | from PIL import Image 5 | from bmab.nodes import fill, upscaler 6 | import base64 7 | from io import BytesIO 8 | from bmab import utils 9 | from comfy import utils as cutils 10 | from bmab.utils import yolo, sam 11 | 12 | memory_image_storage = {} 13 | 14 | 15 | def b64_encoding(image): 16 | buffered = BytesIO() 17 | image.save(buffered, format="PNG") 18 | return base64.b64encode(buffered.getvalue()).decode("utf-8") 19 | 20 | 21 | def b64_decoding(b64): 22 | return Image.open(BytesIO(base64.b64decode(b64))) 23 | 24 | 25 | client_ids = {} 26 | 27 | 28 | @PromptServer.instance.routes.get("/bmab") 29 | async def bmab_register_client(request): 30 | remote_client_id = request.rel_url.query.get("remote_client_id") 31 | remote_name = request.rel_url.query.get("remote_name") 32 | 33 | print(remote_client_id, remote_name) 34 | client_ids[remote_client_id] = {'name': remote_name} 35 | data = {'name': remote_name, 'client_id': remote_client_id} 36 | return web.json_response(data) 37 | 38 | 39 | @PromptServer.instance.routes.get("/bmab/remote") 40 | async def bmab_remote(request): 41 | command = request.rel_url.query.get("command") 42 | name = request.rel_url.query.get("name") 43 | data = {'command': command, 'name': name} 44 | 45 | if command == 'queue': 46 | for sid, v in client_ids.items(): 47 | if v.get('name') == name: 48 | await PromptServer.instance.send("bmab_queue", {"status": '', 'sid': sid}, sid) 49 | data['client_id'] = sid 50 | 51 | return web.json_response(data) 52 | 53 | 54 | @PromptServer.instance.routes.post("/bmab/outpaintbyratio") 55 | async def bmab_outpaintbyratio(request): 56 | j = await request.json() 57 | b64img = j.get('image') 58 | if b64img is not None: 59 | prompt = j.get('prompt') 60 | align = j.get('align', 'bottom') 61 | ratio = j.get('ratio', 0.85) 62 | dilation = j.get('dilation', 16) 63 | steps = j.get('steps', 8) 64 | 65 | filler = fill.BMABOutpaintByRatio() 66 | img = filler.infer(b64_decoding(b64img), align, ratio, dilation, steps, prompt) 67 | data = {'image': b64_encoding(img)} 68 | return web.json_response(data) 69 | else: 70 | print('release') 71 | fill.unload() 72 | return web.json_response({}) 73 | 74 | 75 | @PromptServer.instance.routes.post("/bmab/reframe") 76 | async def bmab_reframe(request): 77 | j = await request.json() 78 | b64img = j.get('image') 79 | if b64img is not None: 80 | prompt = j.get('prompt') 81 | ratio = j.get('ratio', '1:1') 82 | dilation = j.get('dilation', 16) 83 | steps = j.get('steps', 8) 84 | r = fill.BMABReframe.ratio_sel.get(ratio, (1024, 1024)) 85 | 86 | filler = fill.BMABReframe() 87 | img = filler.infer(b64_decoding(b64img), r[0], r[1], dilation, steps, prompt) 88 
| data = {'image': b64_encoding(img)} 89 | return web.json_response(data) 90 | else: 91 | print('release') 92 | fill.unload() 93 | return web.json_response({}) 94 | 95 | 96 | @PromptServer.instance.routes.post("/bmab/upscale") 97 | async def bmab_upscale(request): 98 | from comfy import utils as cutils 99 | 100 | j = await request.json() 101 | b64img = j.get('image') 102 | model = j.get('model') 103 | scale = j.get('scale', '2') 104 | width = j.get('width', 0) 105 | height = j.get('height', 0) 106 | 107 | hook = cutils.PROGRESS_BAR_HOOK 108 | cutils.PROGRESS_BAR_HOOK = None 109 | try: 110 | up = upscaler.BMABUpscaleWithModel() 111 | img = b64_decoding(b64img) 112 | if scale != 0: 113 | width, height = int(img.width * scale), int(img.height * scale) 114 | pixels = utils.get_pixels_from_pils([img]) 115 | s = up.upscale_with_model(model, pixels, progress=False) 116 | utils.torch_gc() 117 | finally: 118 | cutils.PROGRESS_BAR_HOOK = hook 119 | 120 | pil_images = utils.get_pils_from_pixels(s) 121 | result = pil_images[0].resize((width, height), Image.Resampling.LANCZOS) 122 | data = {'image': b64_encoding(result)} 123 | return web.json_response(data) 124 | 125 | 126 | @PromptServer.instance.routes.post("/bmab/inpaint") 127 | async def bmab_inpaint(request): 128 | j = await request.json() 129 | b64img = j.get('image') 130 | if b64img is not None: 131 | b64mask = j.get('mask') 132 | prompt = j.get('prompt') 133 | steps = j.get('steps') 134 | 135 | inpaint = fill.BMABInpaint() 136 | img = b64_decoding(b64img) 137 | msk = b64_decoding(b64mask).convert('L') 138 | 139 | result = inpaint.infer(img, msk, steps, prompt_input=prompt) 140 | data = {'image': b64_encoding(result)} 141 | return web.json_response(data) 142 | else: 143 | print('release') 144 | fill.unload() 145 | return web.json_response({}) 146 | 147 | 148 | @PromptServer.instance.routes.post("/bmab/depth") 149 | async def bmab_inpaint(request): 150 | from custom_nodes.comfyui_controlnet_aux.node_wrappers.depth_anything import Depth_Anything_Preprocessor 151 | 152 | j = await request.json() 153 | b64img = j.get('image') 154 | resolution = j.get('resolution') 155 | images = utils.get_pixels_from_pils([b64_decoding(b64img)]) 156 | hook = cutils.PROGRESS_BAR_HOOK 157 | cutils.PROGRESS_BAR_HOOK = None 158 | try: 159 | node = Depth_Anything_Preprocessor() 160 | out = node.execute(images, 'depth_anything_vitl14.pth', resolution) 161 | finally: 162 | cutils.PROGRESS_BAR_HOOK = hook 163 | pil_images = utils.get_pils_from_pixels(out[0]) 164 | data = {'image': b64_encoding(pil_images[0])} 165 | return web.json_response(data) 166 | 167 | 168 | @PromptServer.instance.routes.post("/bmab/sam") 169 | async def bmab_sam(request): 170 | j = await request.json() 171 | b64img = j.get('image') 172 | model = j.get('model') 173 | confidence = j.get('confidence') 174 | 175 | image = b64_decoding(b64img) 176 | boxes, conf = yolo.predict(image, model, confidence) 177 | for box in boxes: 178 | mask = sam.sam_predict_box(image, box) 179 | data = {'image': b64_encoding(mask.convert('RGB'))} 180 | sam.release() 181 | return web.json_response(data) 182 | return web.json_response({}) 183 | 184 | 185 | @PromptServer.instance.routes.post("/bmab/detect") 186 | async def bmab_detect(request): 187 | j = await request.json() 188 | b64img = j.get('image') 189 | model = j.get('model') 190 | confidence = j.get('confidence') 191 | 192 | image = b64_decoding(b64img) 193 | boxes, conf = yolo.predict(image, model, confidence) 194 | data = {'boxes': boxes} 195 | return 
web.json_response(data) 196 | 197 | 198 | @PromptServer.instance.routes.post("/bmab/free") 199 | async def bmab_free(request): 200 | fill.unload() 201 | return web.json_response({}) 202 | 203 | 204 | @PromptServer.instance.routes.get("/bmab/images") 205 | async def bmab_upscale(request): 206 | remote_client_id = request.rel_url.query.get("clientId") 207 | imgs = memory_image_storage.get(remote_client_id, []) 208 | results = [b64_encoding(i) for i in imgs] 209 | data = {'images': results} 210 | return web.json_response(data) 211 | 212 | 213 | from transformers import AutoProcessor, AutoModelForCausalLM 214 | import torch 215 | 216 | 217 | @PromptServer.instance.routes.post("/bmab/caption") 218 | async def bmab_upscale(request): 219 | j = await request.json() 220 | b64img = j.get('image') 221 | 222 | image = b64_decoding(b64img) 223 | caption = run_captioning(image) 224 | ret = { 225 | 'caption': caption 226 | } 227 | return web.json_response(ret) 228 | 229 | 230 | def run_captioning(image): 231 | print(f"run_captioning") 232 | # Load internally to not consume resources for training 233 | device = "cuda" if torch.cuda.is_available() else "cpu" 234 | print(f"device={device}") 235 | torch_dtype = torch.float16 236 | model = AutoModelForCausalLM.from_pretrained( 237 | "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True 238 | ).to(device) 239 | processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True) 240 | 241 | prompt = "" 242 | inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype) 243 | print(f"inputs {inputs}") 244 | 245 | generated_ids = model.generate( 246 | input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3 247 | ) 248 | print(f"generated_ids {generated_ids}") 249 | 250 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] 251 | print(f"generated_text: {generated_text}") 252 | parsed_answer = processor.post_process_generation( 253 | generated_text, task=prompt, image_size=(image.width, image.height) 254 | ) 255 | print(f"parsed_answer = {parsed_answer}") 256 | caption_text = parsed_answer[""].replace("The image shows ", "") 257 | 258 | 259 | model.to("cpu") 260 | del model 261 | del processor 262 | if torch.cuda.is_available(): 263 | torch.cuda.empty_cache() 264 | 265 | return caption_text -------------------------------------------------------------------------------- /bmab/utils/color.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from skimage import exposure 4 | from blendmodes.blend import blendLayers, BlendType 5 | from PIL import Image 6 | 7 | 8 | def apply_color_correction(correction, original_image): 9 | image = Image.fromarray(cv2.cvtColor(exposure.match_histograms( 10 | cv2.cvtColor( 11 | np.asarray(original_image), 12 | cv2.COLOR_RGB2LAB 13 | ), 14 | cv2.cvtColor(np.asarray(correction.copy()), cv2.COLOR_RGB2LAB), 15 | channel_axis=2 16 | ), cv2.COLOR_LAB2RGB).astype("uint8")) 17 | 18 | image = blendLayers(image, original_image, BlendType.LUMINOSITY) 19 | 20 | return image.convert('RGB') 21 | -------------------------------------------------------------------------------- /bmab/utils/grdino.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection 
4 | from bmab import utils 5 | 6 | 7 | def get_device(): 8 | if sys.platform == 'darwin': 9 | # MPS is not good. 10 | return 'cpu' 11 | elif torch.cuda.is_available(): 12 | return 'cuda' 13 | return 'cpu' 14 | 15 | 16 | def predict(pilimg, prompt, box_threahold=0.35, text_threshold=0.25, device=get_device()): 17 | model_id = "IDEA-Research/grounding-dino-base" 18 | 19 | processor = AutoProcessor.from_pretrained(model_id) 20 | model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device) 21 | 22 | inputs = processor(images=pilimg, text=prompt, return_tensors="pt").to(device) 23 | with torch.no_grad(): 24 | outputs = model(**inputs) 25 | 26 | results = processor.post_process_grounded_object_detection( 27 | outputs, 28 | inputs.input_ids, 29 | box_threshold=box_threahold, 30 | text_threshold=text_threshold, 31 | target_sizes=[pilimg.size[::-1]] 32 | ) 33 | del processor 34 | model.to('cpu') 35 | utils.torch_gc() 36 | 37 | result = results[0] 38 | return result["boxes"], result["scores"], result["labels"] 39 | -------------------------------------------------------------------------------- /bmab/utils/sam.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from segment_anything import SamPredictor 7 | from segment_anything import sam_model_registry 8 | from bmab import utils 9 | 10 | 11 | bmab_model_path = os.path.join(os.path.dirname(__file__), '../../models') 12 | 13 | sam_model = None 14 | 15 | 16 | def sam_init(model): 17 | model_type = 'vit_b' 18 | for m in ('vit_b', 'vit_l', 'vit_h'): 19 | if model.find(m) >= 0: 20 | model_type = m 21 | break 22 | 23 | global sam_model 24 | if not sam_model: 25 | utils.lazy_loader(model) 26 | sam_model = sam_model_registry[model_type](checkpoint=f'%s/{model}' % bmab_model_path) 27 | sam_model.to(device=utils.get_device()) 28 | sam_model.eval() 29 | return sam_model 30 | 31 | 32 | def sam_predict(pilimg, boxes, model='sam_vit_b_01ec64.pth'): 33 | sam = sam_init(model) 34 | 35 | mask_predictor = SamPredictor(sam) 36 | 37 | numpy_image = np.array(pilimg) 38 | opencv_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR) 39 | mask_predictor.set_image(opencv_image) 40 | 41 | result = Image.new('L', pilimg.size, 0) 42 | for box in boxes: 43 | x1, y1, x2, y2 = box 44 | 45 | box = np.array([int(x1), int(y1), int(x2), int(y2)]) 46 | masks, scores, logits = mask_predictor.predict( 47 | box=box, 48 | multimask_output=False 49 | ) 50 | 51 | mask = Image.fromarray(masks[0]) 52 | result.paste(mask, mask=mask) 53 | 54 | return result 55 | 56 | 57 | def sam_predict_box(pilimg, box, model='sam_vit_b_01ec64.pth'): 58 | sam = sam_init(model) 59 | 60 | mask_predictor = SamPredictor(sam) 61 | 62 | numpy_image = np.array(pilimg) 63 | opencv_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR) 64 | mask_predictor.set_image(opencv_image) 65 | 66 | x1, y1, x2, y2 = box 67 | box = np.array([int(x1), int(y1), int(x2), int(y2)]) 68 | 69 | masks, scores, logits = mask_predictor.predict( 70 | box=box, 71 | multimask_output=False 72 | ) 73 | 74 | return Image.fromarray(masks[0]) 75 | 76 | 77 | def get_array_predict_box(pilimg, box, model='sam_vit_b_01ec64.pth'): 78 | sam = sam_init(model) 79 | mask_predictor = SamPredictor(sam) 80 | numpy_image = np.array(pilimg) 81 | opencv_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR) 82 | mask_predictor.set_image(opencv_image) 83 | x1, y1, x2, y2 = box 84 | box = np.array([int(x1), int(y1), int(x2), 
int(y2)]) 85 | masks, scores, logits = mask_predictor.predict( 86 | box=box, 87 | multimask_output=False 88 | ) 89 | return masks[0] 90 | 91 | 92 | def release(): 93 | global sam_model 94 | if sam_model is not None: 95 | sam_model.to(device='cpu') 96 | sam_model = None 97 | utils.torch_gc() 98 | -------------------------------------------------------------------------------- /bmab/utils/yolo.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from ultralytics import YOLO 4 | from bmab import utils 5 | 6 | 7 | def predict(image: Image, model, confidence): 8 | yolo = utils.lazy_loader(model) 9 | boxes = [] 10 | confs = [] 11 | try: 12 | model = YOLO(yolo) 13 | pred = model(image, conf=confidence, device=utils.get_device()) 14 | boxes = pred[0].boxes.xyxy.cpu().numpy() 15 | boxes = boxes.tolist() 16 | confs = pred[0].boxes.conf.tolist() 17 | except: 18 | pass 19 | del model 20 | utils.torch_gc() 21 | return boxes, confs 22 | 23 | -------------------------------------------------------------------------------- /models/put_models_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/models/put_models_here -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_bmab" 3 | description = "BMAB for ComfyUI. BMAB is an custom nodes of ComfyUI and has the function of post-processing the generated image according to settings. If necessary, you can find and redraw people, faces, and hands, or perform functions such as resize, resample, and add noise. You can composite two images or perform the Upscale function." 
4 | version = "1.1.1" 5 | license = { text = "GNU Affero General Public License v3.0" } 6 | dependencies = ["lightning", "segment_anything", "omegaconf", "ultralytics", "scikit-image", "huggingface_hub", "transformers>=4.40.0", "blendmodes", "opencv-python", "urllib3", "pillow_avif_plugin"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/portu-sim/comfyui_bmab" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "portu-sim" 14 | DisplayName = "comfyui_bmab" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lightning 2 | segment_anything 3 | omegaconf 4 | ultralytics 5 | scikit-image 6 | huggingface_hub 7 | transformers>=4.40.0 8 | blendmodes 9 | opencv-python 10 | urllib3 11 | pillow 12 | pillow_avif_plugin 13 | 14 | -------------------------------------------------------------------------------- /resources/cache/_cachefiles_will_be_put_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/resources/cache/_cachefiles_will_be_put_here -------------------------------------------------------------------------------- /resources/examples/bmab-flux-sample.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 41, 3 | "last_link_id": 65, 4 | "nodes": [ 5 | { 6 | "id": 39, 7 | "type": "CheckpointLoaderSimple", 8 | "pos": [ 9 | 8, 10 | 617 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 98 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "MODEL", 22 | "type": "MODEL", 23 | "links": [ 24 | 58 25 | ], 26 | "shape": 3, 27 | "slot_index": 0 28 | }, 29 | { 30 | "name": "CLIP", 31 | "type": "CLIP", 32 | "links": [ 33 | 59 34 | ], 35 | "shape": 3, 36 | "slot_index": 1 37 | }, 38 | { 39 | "name": "VAE", 40 | "type": "VAE", 41 | "links": [ 42 | 61 43 | ], 44 | "shape": 3, 45 | "slot_index": 2 46 | } 47 | ], 48 | "properties": { 49 | "Node name for S&R": "CheckpointLoaderSimple" 50 | }, 51 | "widgets_values": [ 52 | "FLUX1\\flux1-dev-fp8.safetensors" 53 | ] 54 | }, 55 | { 56 | "id": 38, 57 | "type": "BMAB Flux Integrator", 58 | "pos": [ 59 | 376, 60 | 617 61 | ], 62 | "size": [ 63 | 411.24033826936557, 64 | 377.8769301478941 65 | ], 66 | "flags": {}, 67 | "order": 3, 68 | "mode": 0, 69 | "inputs": [ 70 | { 71 | "name": "model", 72 | "type": "MODEL", 73 | "link": 58 74 | }, 75 | { 76 | "name": "clip", 77 | "type": "CLIP", 78 | "link": 59 79 | }, 80 | { 81 | "name": "vae", 82 | "type": "VAE", 83 | "link": 61 84 | }, 85 | { 86 | "name": "context", 87 | "type": "CONTEXT", 88 | "link": 62 89 | }, 90 | { 91 | "name": "seed_in", 92 | "type": "SEED", 93 | "link": null 94 | }, 95 | { 96 | "name": "latent", 97 | "type": "LATENT", 98 | "link": 63 99 | }, 100 | { 101 | "name": "image", 102 | "type": "IMAGE", 103 | "link": null 104 | } 105 | ], 106 | "outputs": [ 107 | { 108 | "name": "BMAB bind", 109 | "type": "BMAB bind", 110 | "links": [ 111 | 64 112 | ], 113 | "shape": 3, 114 | "slot_index": 0 115 | } 116 | ], 117 | "properties": { 118 | "Node name for S&R": "BMAB Flux Integrator" 119 | }, 120 | "widgets_values": [ 121 | 3.5, 122 | "cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black 
gold leaf pattern dress and a white apron mouth open placing a fancy black forest cake with candles on top of a dinner table of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere there are paintings on the walls" 123 | ] 124 | }, 125 | { 126 | "id": 41, 127 | "type": "BMAB KSampler", 128 | "pos": [ 129 | 841, 130 | 618 131 | ], 132 | "size": { 133 | "0": 315, 134 | "1": 174 135 | }, 136 | "flags": {}, 137 | "order": 4, 138 | "mode": 0, 139 | "inputs": [ 140 | { 141 | "name": "bind", 142 | "type": "BMAB bind", 143 | "link": 64 144 | }, 145 | { 146 | "name": "lora", 147 | "type": "BMAB lora", 148 | "link": null 149 | } 150 | ], 151 | "outputs": [ 152 | { 153 | "name": "BMAB bind", 154 | "type": "BMAB bind", 155 | "links": null, 156 | "shape": 3 157 | }, 158 | { 159 | "name": "image", 160 | "type": "IMAGE", 161 | "links": [ 162 | 65 163 | ], 164 | "shape": 3, 165 | "slot_index": 1 166 | } 167 | ], 168 | "properties": { 169 | "Node name for S&R": "BMAB KSampler" 170 | }, 171 | "widgets_values": [ 172 | 20, 173 | 1, 174 | "euler", 175 | "normal", 176 | 1 177 | ] 178 | }, 179 | { 180 | "id": 27, 181 | "type": "EmptySD3LatentImage", 182 | "pos": [ 183 | 11, 184 | 991 185 | ], 186 | "size": { 187 | "0": 315, 188 | "1": 106 189 | }, 190 | "flags": {}, 191 | "order": 1, 192 | "mode": 0, 193 | "outputs": [ 194 | { 195 | "name": "LATENT", 196 | "type": "LATENT", 197 | "links": [ 198 | 63 199 | ], 200 | "shape": 3, 201 | "slot_index": 0 202 | } 203 | ], 204 | "properties": { 205 | "Node name for S&R": "EmptySD3LatentImage" 206 | }, 207 | "widgets_values": [ 208 | 1024, 209 | 1024, 210 | 1 211 | ], 212 | "color": "#323", 213 | "bgcolor": "#535" 214 | }, 215 | { 216 | "id": 9, 217 | "type": "SaveImage", 218 | "pos": [ 219 | 1190, 220 | 628 221 | ], 222 | "size": { 223 | "0": 985.3012084960938, 224 | "1": 1060.3828125 225 | }, 226 | "flags": {}, 227 | "order": 5, 228 | "mode": 0, 229 | "inputs": [ 230 | { 231 | "name": "images", 232 | "type": "IMAGE", 233 | "link": 65 234 | } 235 | ], 236 | "properties": {}, 237 | "widgets_values": [ 238 | "ComfyUI" 239 | ] 240 | }, 241 | { 242 | "id": 40, 243 | "type": "BMAB Context", 244 | "pos": [ 245 | 7, 246 | 757 247 | ], 248 | "size": { 249 | "0": 315, 250 | "1": 178 251 | }, 252 | "flags": {}, 253 | "order": 2, 254 | "mode": 0, 255 | "inputs": [ 256 | { 257 | "name": "seed_in", 258 | "type": "SEED", 259 | "link": null 260 | } 261 | ], 262 | "outputs": [ 263 | { 264 | "name": "BMAB context", 265 | "type": "CONTEXT", 266 | "links": [ 267 | 62 268 | ], 269 | "shape": 3, 270 | "slot_index": 0 271 | } 272 | ], 273 | "properties": { 274 | "Node name for S&R": "BMAB Context" 275 | }, 276 | "widgets_values": [ 277 | 972054013131368, 278 | "randomize", 279 | 20, 280 | 8.040000000000001, 281 | "dpmpp_sde", 282 | "karras" 283 | ] 284 | } 285 | ], 286 | "links": [ 287 | [ 288 | 58, 289 | 39, 290 | 0, 291 | 38, 292 | 0, 293 | "MODEL" 294 | ], 295 | [ 296 | 59, 297 | 39, 298 | 1, 299 | 38, 300 | 1, 301 | "CLIP" 302 | ], 303 | [ 304 | 61, 305 | 39, 306 | 2, 307 | 38, 308 | 2, 309 | "VAE" 310 | ], 311 | [ 312 | 62, 313 | 40, 314 | 0, 315 | 38, 316 | 3, 317 | "CONTEXT" 318 | ], 319 | [ 320 | 63, 321 | 27, 322 | 0, 323 | 38, 324 | 5, 325 | "LATENT" 326 | ], 327 | [ 328 | 64, 329 | 38, 330 | 0, 331 | 41, 332 | 0, 333 | "BMAB bind" 334 | ], 335 | [ 336 | 65, 337 | 41, 338 | 1, 339 | 9, 340 | 0, 341 | "IMAGE" 342 | ] 343 | ], 344 | "groups": [], 345 | "config": {}, 346 | "extra": { 347 | "ds": { 348 | "scale": 
0.7513148009015777, 349 | "offset": [ 350 | 466.97698927324814, 351 | -334.0327902578701 352 | ] 353 | } 354 | }, 355 | "version": 0.4 356 | } -------------------------------------------------------------------------------- /resources/examples/example.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 29, 3 | "last_link_id": 46, 4 | "nodes": [ 5 | { 6 | "id": 4, 7 | "type": "CheckpointLoaderSimple", 8 | "pos": [ 9 | 25, 10 | 217 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 98 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "MODEL", 22 | "type": "MODEL", 23 | "links": [ 24 | 34 25 | ], 26 | "slot_index": 0 27 | }, 28 | { 29 | "name": "CLIP", 30 | "type": "CLIP", 31 | "links": [ 32 | 35 33 | ], 34 | "slot_index": 1 35 | }, 36 | { 37 | "name": "VAE", 38 | "type": "VAE", 39 | "links": [], 40 | "slot_index": 2 41 | } 42 | ], 43 | "properties": { 44 | "Node name for S&R": "CheckpointLoaderSimple" 45 | }, 46 | "widgets_values": [ 47 | "portu_429_lora2.fp16.safetensors" 48 | ] 49 | }, 50 | { 51 | "id": 18, 52 | "type": "VAELoader", 53 | "pos": [ 54 | 20, 55 | 356 56 | ], 57 | "size": { 58 | "0": 315, 59 | "1": 58 60 | }, 61 | "flags": {}, 62 | "order": 1, 63 | "mode": 0, 64 | "outputs": [ 65 | { 66 | "name": "VAE", 67 | "type": "VAE", 68 | "links": [ 69 | 38 70 | ], 71 | "shape": 3, 72 | "slot_index": 0 73 | } 74 | ], 75 | "properties": { 76 | "Node name for S&R": "VAELoader" 77 | }, 78 | "widgets_values": [ 79 | "vae-ft-mse-840000-ema-pruned.ckpt" 80 | ] 81 | }, 82 | { 83 | "id": 26, 84 | "type": "BMAB Context", 85 | "pos": [ 86 | 24, 87 | 464 88 | ], 89 | "size": { 90 | "0": 315, 91 | "1": 178 92 | }, 93 | "flags": {}, 94 | "order": 2, 95 | "mode": 0, 96 | "inputs": [ 97 | { 98 | "name": "seed_in", 99 | "type": "SEED", 100 | "link": null 101 | } 102 | ], 103 | "outputs": [ 104 | { 105 | "name": "BMAB context", 106 | "type": "CONTEXT", 107 | "links": [ 108 | 39 109 | ], 110 | "shape": 3, 111 | "slot_index": 0 112 | } 113 | ], 114 | "properties": { 115 | "Node name for S&R": "BMAB Context" 116 | }, 117 | "widgets_values": [ 118 | 501253181764904, 119 | "randomize", 120 | 20, 121 | 8, 122 | "dpmpp_sde", 123 | "karras" 124 | ] 125 | }, 126 | { 127 | "id": 25, 128 | "type": "BMAB Integrator", 129 | "pos": [ 130 | 407, 131 | 224 132 | ], 133 | "size": { 134 | "0": 400, 135 | "1": 318 136 | }, 137 | "flags": {}, 138 | "order": 4, 139 | "mode": 0, 140 | "inputs": [ 141 | { 142 | "name": "model", 143 | "type": "MODEL", 144 | "link": 34 145 | }, 146 | { 147 | "name": "clip", 148 | "type": "CLIP", 149 | "link": 35 150 | }, 151 | { 152 | "name": "vae", 153 | "type": "VAE", 154 | "link": 38 155 | }, 156 | { 157 | "name": "context", 158 | "type": "CONTEXT", 159 | "link": 39 160 | }, 161 | { 162 | "name": "seed_in", 163 | "type": "SEED", 164 | "link": null 165 | }, 166 | { 167 | "name": "latent", 168 | "type": "LATENT", 169 | "link": 40 170 | }, 171 | { 172 | "name": "image", 173 | "type": "IMAGE", 174 | "link": null 175 | } 176 | ], 177 | "outputs": [ 178 | { 179 | "name": "BMAB bind", 180 | "type": "BMAB bind", 181 | "links": [ 182 | 37 183 | ], 184 | "shape": 3, 185 | "slot_index": 0 186 | } 187 | ], 188 | "properties": { 189 | "Node name for S&R": "BMAB Integrator" 190 | }, 191 | "widgets_values": [ 192 | -2, 193 | "none", 194 | "A1111", 195 | "1girl, standing, full body, street,", 196 | "worst quality, low quality" 197 | ] 198 | }, 199 | { 200 | "id": 5, 201 | "type": "EmptyLatentImage", 202 | 
"pos": [ 203 | 26, 204 | 687 205 | ], 206 | "size": { 207 | "0": 315, 208 | "1": 106 209 | }, 210 | "flags": {}, 211 | "order": 3, 212 | "mode": 0, 213 | "outputs": [ 214 | { 215 | "name": "LATENT", 216 | "type": "LATENT", 217 | "links": [ 218 | 40 219 | ], 220 | "slot_index": 0 221 | } 222 | ], 223 | "properties": { 224 | "Node name for S&R": "EmptyLatentImage" 225 | }, 226 | "widgets_values": [ 227 | 512, 228 | 768, 229 | 1 230 | ] 231 | }, 232 | { 233 | "id": 11, 234 | "type": "BMAB KSampler", 235 | "pos": [ 236 | 840, 237 | 221 238 | ], 239 | "size": { 240 | "0": 315, 241 | "1": 174 242 | }, 243 | "flags": {}, 244 | "order": 5, 245 | "mode": 0, 246 | "inputs": [ 247 | { 248 | "name": "bind", 249 | "type": "BMAB bind", 250 | "link": 37 251 | }, 252 | { 253 | "name": "lora", 254 | "type": "BMAB lora", 255 | "link": null 256 | } 257 | ], 258 | "outputs": [ 259 | { 260 | "name": "BMAB bind", 261 | "type": "BMAB bind", 262 | "links": [ 263 | 45 264 | ], 265 | "shape": 3, 266 | "slot_index": 0 267 | }, 268 | { 269 | "name": "image", 270 | "type": "IMAGE", 271 | "links": null, 272 | "shape": 3, 273 | "slot_index": 1 274 | } 275 | ], 276 | "properties": { 277 | "Node name for S&R": "BMAB KSampler" 278 | }, 279 | "widgets_values": [ 280 | 20, 281 | 8, 282 | "Use same sampler", 283 | "Use same scheduler", 284 | 1 285 | ] 286 | }, 287 | { 288 | "id": 17, 289 | "type": "BMAB Save Image", 290 | "pos": [ 291 | 1934, 292 | 221 293 | ], 294 | "size": { 295 | "0": 526.6683959960938, 296 | "1": 744.5535278320312 297 | }, 298 | "flags": {}, 299 | "order": 8, 300 | "mode": 0, 301 | "inputs": [ 302 | { 303 | "name": "bind", 304 | "type": "BMAB bind", 305 | "link": 43 306 | }, 307 | { 308 | "name": "images", 309 | "type": "IMAGE", 310 | "link": null 311 | } 312 | ], 313 | "properties": { 314 | "Node name for S&R": "BMAB Save Image" 315 | }, 316 | "widgets_values": [ 317 | "ComfyUI" 318 | ] 319 | }, 320 | { 321 | "id": 29, 322 | "type": "BMAB KSamplerHiresFixWithUpscaler", 323 | "pos": [ 324 | 1206, 325 | 221 326 | ], 327 | "size": { 328 | "0": 315, 329 | "1": 290 330 | }, 331 | "flags": {}, 332 | "order": 6, 333 | "mode": 0, 334 | "inputs": [ 335 | { 336 | "name": "bind", 337 | "type": "BMAB bind", 338 | "link": 45 339 | }, 340 | { 341 | "name": "image", 342 | "type": "IMAGE", 343 | "link": null 344 | }, 345 | { 346 | "name": "lora", 347 | "type": "BMAB lora", 348 | "link": null 349 | } 350 | ], 351 | "outputs": [ 352 | { 353 | "name": "BMAB bind", 354 | "type": "BMAB bind", 355 | "links": [ 356 | 46 357 | ], 358 | "shape": 3, 359 | "slot_index": 0 360 | }, 361 | { 362 | "name": "image", 363 | "type": "IMAGE", 364 | "links": null, 365 | "shape": 3 366 | } 367 | ], 368 | "properties": { 369 | "Node name for S&R": "BMAB KSamplerHiresFixWithUpscaler" 370 | }, 371 | "widgets_values": [ 372 | 20, 373 | 7, 374 | "Use same sampler", 375 | "Use same scheduler", 376 | 0.45, 377 | "LANCZOS", 378 | 2, 379 | 512, 380 | 512 381 | ] 382 | }, 383 | { 384 | "id": 27, 385 | "type": "BMAB Face Detailer", 386 | "pos": [ 387 | 1568, 388 | 219 389 | ], 390 | "size": { 391 | "0": 315, 392 | "1": 290 393 | }, 394 | "flags": {}, 395 | "order": 7, 396 | "mode": 0, 397 | "inputs": [ 398 | { 399 | "name": "bind", 400 | "type": "BMAB bind", 401 | "link": 46 402 | }, 403 | { 404 | "name": "image", 405 | "type": "IMAGE", 406 | "link": null 407 | }, 408 | { 409 | "name": "lora", 410 | "type": "BMAB lora", 411 | "link": null 412 | } 413 | ], 414 | "outputs": [ 415 | { 416 | "name": "BMAB bind", 417 | "type": "BMAB bind", 418 | "links": [ 
419 | 43 420 | ], 421 | "shape": 3, 422 | "slot_index": 0 423 | }, 424 | { 425 | "name": "image", 426 | "type": "IMAGE", 427 | "links": null, 428 | "shape": 3 429 | } 430 | ], 431 | "properties": { 432 | "Node name for S&R": "BMAB Face Detailer" 433 | }, 434 | "widgets_values": [ 435 | 20, 436 | 4, 437 | "Use same sampler", 438 | "Use same scheduler", 439 | 0.45, 440 | 32, 441 | 4, 442 | 512, 443 | 512 444 | ] 445 | } 446 | ], 447 | "links": [ 448 | [ 449 | 34, 450 | 4, 451 | 0, 452 | 25, 453 | 0, 454 | "MODEL" 455 | ], 456 | [ 457 | 35, 458 | 4, 459 | 1, 460 | 25, 461 | 1, 462 | "CLIP" 463 | ], 464 | [ 465 | 37, 466 | 25, 467 | 0, 468 | 11, 469 | 0, 470 | "BMAB bind" 471 | ], 472 | [ 473 | 38, 474 | 18, 475 | 0, 476 | 25, 477 | 2, 478 | "VAE" 479 | ], 480 | [ 481 | 39, 482 | 26, 483 | 0, 484 | 25, 485 | 3, 486 | "CONTEXT" 487 | ], 488 | [ 489 | 40, 490 | 5, 491 | 0, 492 | 25, 493 | 5, 494 | "LATENT" 495 | ], 496 | [ 497 | 43, 498 | 27, 499 | 0, 500 | 17, 501 | 0, 502 | "BMAB bind" 503 | ], 504 | [ 505 | 45, 506 | 11, 507 | 0, 508 | 29, 509 | 0, 510 | "BMAB bind" 511 | ], 512 | [ 513 | 46, 514 | 29, 515 | 0, 516 | 27, 517 | 0, 518 | "BMAB bind" 519 | ] 520 | ], 521 | "groups": [], 522 | "config": {}, 523 | "extra": { 524 | "ds": { 525 | "scale": 0.6830134553650709, 526 | "offset": [ 527 | 138.7066514884853, 528 | 67.04635823955786 529 | ] 530 | } 531 | }, 532 | "version": 0.4 533 | } -------------------------------------------------------------------------------- /resources/examples/openpose-hand-detailing-example.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 33, 3 | "last_link_id": 53, 4 | "nodes": [ 5 | { 6 | "id": 4, 7 | "type": "CheckpointLoaderSimple", 8 | "pos": [ 9 | 25, 10 | 217 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 98 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "MODEL", 22 | "type": "MODEL", 23 | "links": [ 24 | 34 25 | ], 26 | "slot_index": 0 27 | }, 28 | { 29 | "name": "CLIP", 30 | "type": "CLIP", 31 | "links": [ 32 | 35 33 | ], 34 | "slot_index": 1 35 | }, 36 | { 37 | "name": "VAE", 38 | "type": "VAE", 39 | "links": [], 40 | "slot_index": 2 41 | } 42 | ], 43 | "properties": { 44 | "Node name for S&R": "CheckpointLoaderSimple" 45 | }, 46 | "widgets_values": [ 47 | "portu_429_lora2.fp16.safetensors" 48 | ] 49 | }, 50 | { 51 | "id": 18, 52 | "type": "VAELoader", 53 | "pos": [ 54 | 20, 55 | 356 56 | ], 57 | "size": { 58 | "0": 315, 59 | "1": 58 60 | }, 61 | "flags": {}, 62 | "order": 1, 63 | "mode": 0, 64 | "outputs": [ 65 | { 66 | "name": "VAE", 67 | "type": "VAE", 68 | "links": [ 69 | 38 70 | ], 71 | "shape": 3, 72 | "slot_index": 0 73 | } 74 | ], 75 | "properties": { 76 | "Node name for S&R": "VAELoader" 77 | }, 78 | "widgets_values": [ 79 | "vae-ft-mse-840000-ema-pruned.ckpt" 80 | ] 81 | }, 82 | { 83 | "id": 26, 84 | "type": "BMAB Context", 85 | "pos": [ 86 | 24, 87 | 464 88 | ], 89 | "size": { 90 | "0": 315, 91 | "1": 178 92 | }, 93 | "flags": {}, 94 | "order": 2, 95 | "mode": 0, 96 | "inputs": [ 97 | { 98 | "name": "seed_in", 99 | "type": "SEED", 100 | "link": null 101 | } 102 | ], 103 | "outputs": [ 104 | { 105 | "name": "BMAB context", 106 | "type": "CONTEXT", 107 | "links": [ 108 | 39 109 | ], 110 | "shape": 3, 111 | "slot_index": 0 112 | } 113 | ], 114 | "properties": { 115 | "Node name for S&R": "BMAB Context" 116 | }, 117 | "widgets_values": [ 118 | 878953794515282, 119 | "randomize", 120 | 20, 121 | 8, 122 | "dpmpp_sde", 123 | "karras" 124 | 
] 125 | }, 126 | { 127 | "id": 5, 128 | "type": "EmptyLatentImage", 129 | "pos": [ 130 | 26, 131 | 687 132 | ], 133 | "size": { 134 | "0": 315, 135 | "1": 106 136 | }, 137 | "flags": {}, 138 | "order": 3, 139 | "mode": 0, 140 | "outputs": [ 141 | { 142 | "name": "LATENT", 143 | "type": "LATENT", 144 | "links": [ 145 | 40 146 | ], 147 | "slot_index": 0 148 | } 149 | ], 150 | "properties": { 151 | "Node name for S&R": "EmptyLatentImage" 152 | }, 153 | "widgets_values": [ 154 | 512, 155 | 768, 156 | 1 157 | ] 158 | }, 159 | { 160 | "id": 30, 161 | "type": "BMAB Load Image", 162 | "pos": [ 163 | 407, 164 | 753 165 | ], 166 | "size": [ 167 | 469.37541296463996, 168 | 746.2945929323166 169 | ], 170 | "flags": {}, 171 | "order": 4, 172 | "mode": 0, 173 | "outputs": [ 174 | { 175 | "name": "IMAGE", 176 | "type": "IMAGE", 177 | "links": [ 178 | 47 179 | ], 180 | "shape": 3, 181 | "slot_index": 0 182 | }, 183 | { 184 | "name": "MASK", 185 | "type": "MASK", 186 | "links": null, 187 | "shape": 3 188 | } 189 | ], 190 | "properties": { 191 | "Node name for S&R": "BMAB Load Image" 192 | }, 193 | "widgets_values": [ 194 | "00009-3139289071.png", 195 | "image" 196 | ] 197 | }, 198 | { 199 | "id": 33, 200 | "type": "PreviewImage", 201 | "pos": [ 202 | 1319, 203 | 750 204 | ], 205 | "size": [ 206 | 537.1467559333894, 207 | 756.6380952760667 208 | ], 209 | "flags": {}, 210 | "order": 8, 211 | "mode": 0, 212 | "inputs": [ 213 | { 214 | "name": "images", 215 | "type": "IMAGE", 216 | "link": 52 217 | } 218 | ], 219 | "properties": { 220 | "Node name for S&R": "PreviewImage" 221 | } 222 | }, 223 | { 224 | "id": 32, 225 | "type": "PreviewImage", 226 | "pos": [ 227 | 1875, 228 | 753 229 | ], 230 | "size": [ 231 | 562.8348297615144, 232 | 749.2057612916917 233 | ], 234 | "flags": {}, 235 | "order": 7, 236 | "mode": 0, 237 | "inputs": [ 238 | { 239 | "name": "images", 240 | "type": "IMAGE", 241 | "link": 51 242 | } 243 | ], 244 | "properties": { 245 | "Node name for S&R": "PreviewImage" 246 | } 247 | }, 248 | { 249 | "id": 25, 250 | "type": "BMAB Integrator", 251 | "pos": [ 252 | 407, 253 | 224 254 | ], 255 | "size": { 256 | "0": 400, 257 | "1": 318 258 | }, 259 | "flags": {}, 260 | "order": 5, 261 | "mode": 0, 262 | "inputs": [ 263 | { 264 | "name": "model", 265 | "type": "MODEL", 266 | "link": 34 267 | }, 268 | { 269 | "name": "clip", 270 | "type": "CLIP", 271 | "link": 35 272 | }, 273 | { 274 | "name": "vae", 275 | "type": "VAE", 276 | "link": 38 277 | }, 278 | { 279 | "name": "context", 280 | "type": "CONTEXT", 281 | "link": 39 282 | }, 283 | { 284 | "name": "seed_in", 285 | "type": "SEED", 286 | "link": null 287 | }, 288 | { 289 | "name": "latent", 290 | "type": "LATENT", 291 | "link": 40 292 | }, 293 | { 294 | "name": "image", 295 | "type": "IMAGE", 296 | "link": null 297 | } 298 | ], 299 | "outputs": [ 300 | { 301 | "name": "BMAB bind", 302 | "type": "BMAB bind", 303 | "links": [ 304 | 53 305 | ], 306 | "shape": 3, 307 | "slot_index": 0 308 | } 309 | ], 310 | "properties": { 311 | "Node name for S&R": "BMAB Integrator" 312 | }, 313 | "widgets_values": [ 314 | -2, 315 | "none", 316 | "A1111", 317 | "1girl, standing, full body, street, (detailed hand:1.4)", 318 | "worst quality, low quality" 319 | ] 320 | }, 321 | { 322 | "id": 31, 323 | "type": "BMAB Openpose Hand Detailer", 324 | "pos": [ 325 | 968, 326 | 320 327 | ], 328 | "size": { 329 | "0": 315, 330 | "1": 314 331 | }, 332 | "flags": {}, 333 | "order": 6, 334 | "mode": 0, 335 | "inputs": [ 336 | { 337 | "name": "bind", 338 | "type": "BMAB bind", 339 | 
"link": 53 340 | }, 341 | { 342 | "name": "image", 343 | "type": "IMAGE", 344 | "link": 47 345 | }, 346 | { 347 | "name": "lora", 348 | "type": "BMAB lora", 349 | "link": null 350 | } 351 | ], 352 | "outputs": [ 353 | { 354 | "name": "BMAB bind", 355 | "type": "BMAB bind", 356 | "links": null, 357 | "shape": 3 358 | }, 359 | { 360 | "name": "image", 361 | "type": "IMAGE", 362 | "links": [ 363 | 51 364 | ], 365 | "shape": 3, 366 | "slot_index": 1 367 | }, 368 | { 369 | "name": "annotation", 370 | "type": "IMAGE", 371 | "links": [ 372 | 52 373 | ], 374 | "shape": 3, 375 | "slot_index": 2 376 | } 377 | ], 378 | "properties": { 379 | "Node name for S&R": "BMAB Openpose Hand Detailer" 380 | }, 381 | "widgets_values": [ 382 | 20, 383 | 7, 384 | "Use same sampler", 385 | "Use same scheduler", 386 | 0.45, 387 | 32, 388 | 4, 389 | 1024, 390 | 1024, 391 | "disable" 392 | ] 393 | } 394 | ], 395 | "links": [ 396 | [ 397 | 34, 398 | 4, 399 | 0, 400 | 25, 401 | 0, 402 | "MODEL" 403 | ], 404 | [ 405 | 35, 406 | 4, 407 | 1, 408 | 25, 409 | 1, 410 | "CLIP" 411 | ], 412 | [ 413 | 38, 414 | 18, 415 | 0, 416 | 25, 417 | 2, 418 | "VAE" 419 | ], 420 | [ 421 | 39, 422 | 26, 423 | 0, 424 | 25, 425 | 3, 426 | "CONTEXT" 427 | ], 428 | [ 429 | 40, 430 | 5, 431 | 0, 432 | 25, 433 | 5, 434 | "LATENT" 435 | ], 436 | [ 437 | 47, 438 | 30, 439 | 0, 440 | 31, 441 | 1, 442 | "IMAGE" 443 | ], 444 | [ 445 | 51, 446 | 31, 447 | 1, 448 | 32, 449 | 0, 450 | "IMAGE" 451 | ], 452 | [ 453 | 52, 454 | 31, 455 | 2, 456 | 33, 457 | 0, 458 | "IMAGE" 459 | ], 460 | [ 461 | 53, 462 | 25, 463 | 0, 464 | 31, 465 | 0, 466 | "BMAB bind" 467 | ] 468 | ], 469 | "groups": [], 470 | "config": {}, 471 | "extra": { 472 | "ds": { 473 | "scale": 0.6830134553650707, 474 | "offset": [ 475 | 185.95766803736825, 476 | -28.764659635582873 477 | ] 478 | } 479 | }, 480 | "version": 0.4 481 | } -------------------------------------------------------------------------------- /resources/wildcard/put_wildcard_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/portu-sim/comfyui_bmab/646c90e2be3ab4d377e8e6774d0f919e6154d44a/resources/wildcard/put_wildcard_here -------------------------------------------------------------------------------- /web/gemini.js: -------------------------------------------------------------------------------- 1 | import { app } from "/scripts/app.js"; 2 | import { ComfyWidgets } from "/scripts/widgets.js"; 3 | 4 | app.registerExtension({ 5 | name: "Comfy.BMAB.GoogleGeminiPromptNode", 6 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 7 | if (nodeData.name === "BMAB Google Gemini Prompt") { 8 | 9 | const onExecuted = nodeType.prototype.onExecuted; 10 | nodeType.prototype.onExecuted = function (texts) { 11 | onExecuted?.apply(this, arguments); 12 | let widget_id = this?.widgets.findIndex( 13 | obj => obj.name === 'random_seed' 14 | ); 15 | this.widgets[widget_id].value = Number(texts?.string) 16 | app.graph.setDirtyCanvas(true); 17 | }; 18 | } 19 | }, 20 | }); 21 | -------------------------------------------------------------------------------- /web/loadoutputimage.js: -------------------------------------------------------------------------------- 1 | // this code from ComfyUI_Custom_Nodes_AlekPet 2 | 3 | import { app } from "/scripts/app.js"; 4 | import { ComfyWidgets } from "/scripts/widgets.js"; 5 | import { api } from "../../scripts/api.js"; 6 | 7 | app.registerExtension({ 8 | name: "Comfy.BMAB.LoadOutputImage", 9 | async beforeRegisterNodeDef(nodeType, 
nodeData, app) { 10 | if (nodeData.name === "BMAB Load Output Image") { 11 | 12 | 13 | const onNodeCreated = nodeType.prototype.onNodeCreated; 14 | nodeType.prototype.onNodeCreated = function () { 15 | onNodeCreated?.apply(this, arguments); 16 | 17 | const node = this 18 | function showImage(name) { 19 | const img = new Image(); 20 | img.onload = () => { 21 | node.imgs = [img]; 22 | app.graph.setDirtyCanvas(true); 23 | }; 24 | 25 | const split = name.split('/'); 26 | let subdir = '' 27 | let fileName = '' 28 | if (split.length === 1) { 29 | fileName = split[0]; 30 | } else { 31 | subdir = split.slice(0, split.length - 1).join('/'); 32 | fileName = split[split.length - 1]; 33 | } 34 | img.src = api.apiURL(`/view?filename=${fileName}&subfolder=${subdir}&type=output${app.getRandParam()}`); 35 | node.setSizeForImage?.(); 36 | } 37 | 38 | const imageWidget = node.widgets.find((w) => w.name === "image"); 39 | 40 | const cb = this.callback; 41 | imageWidget.callback = function () { 42 | showImage(imageWidget.value); 43 | app.graph.setDirtyCanvas(true); 44 | if (cb) { 45 | return cb.apply(this, arguments); 46 | } 47 | }; 48 | 49 | showImage(imageWidget.value); 50 | app.graph.setDirtyCanvas(true); 51 | }; 52 | } 53 | }, 54 | }); 55 | -------------------------------------------------------------------------------- /web/previewtext.js: -------------------------------------------------------------------------------- 1 | // this code from ComfyUI_Custom_Nodes_AlekPet 2 | 3 | import { app } from "/scripts/app.js"; 4 | import { ComfyWidgets } from "/scripts/widgets.js"; 5 | 6 | app.registerExtension({ 7 | name: "Comfy.BMAB.PreviewText", 8 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 9 | if (nodeData.name === "BMAB Preview Text") { 10 | const onNodeCreated = nodeType.prototype.onNodeCreated; 11 | nodeType.prototype.onNodeCreated = function () { 12 | const ret = onNodeCreated ? 
onNodeCreated.apply(this, arguments) : undefined; 13 | 14 | let BMABPreviewText = app.graph._nodes.filter((wi) => wi.type == nodeData.name), 15 | nodeName = `${nodeData.name}_${BMABPreviewText.length}`; 16 | console.log(`Create ${nodeData.name}: ${nodeName}`); 17 | const wi = ComfyWidgets.STRING(this, nodeName, 18 | ["STRING", {default: "", placeholder: "Text message output...", multiline: true,},], app); 19 | wi.widget.inputEl.readOnly = true; 20 | return ret; 21 | }; 22 | 23 | // Function set value 24 | const outSet = function (texts) { 25 | if (texts.length > 0) { 26 | let widget_id = this?.widgets.findIndex((w) => w.type == "customtext"); 27 | if (Array.isArray(texts)) 28 | texts = texts.filter((word) => word.trim() !== "").map((word) => word.trim()).join(" "); 29 | this.widgets[widget_id].value = texts; 30 | app.graph.setDirtyCanvas(true); 31 | } 32 | }; 33 | 34 | // onExecuted 35 | const onExecuted = nodeType.prototype.onExecuted; 36 | nodeType.prototype.onExecuted = function (texts) { 37 | onExecuted?.apply(this, arguments); 38 | outSet.call(this, texts?.string); 39 | }; 40 | 41 | // onConfigure 42 | const onConfigure = nodeType.prototype.onConfigure; 43 | nodeType.prototype.onConfigure = function (w) { 44 | onConfigure?.apply(this, arguments); 45 | if (w?.widgets_values?.length) { 46 | outSet.call(this, w.widgets_values); 47 | } 48 | }; 49 | } 50 | }, 51 | }); 52 | -------------------------------------------------------------------------------- /web/remoteaccess.js: -------------------------------------------------------------------------------- 1 | import { app } from "/scripts/app.js"; 2 | import { ComfyWidgets } from "/scripts/widgets.js"; 3 | import { api } from "../../scripts/api.js"; 4 | 5 | app.registerExtension({ 6 | name: "Comfy.BMAB.BMABRemoteAccessAndSave", 7 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 8 | if (nodeData.name === "BMAB Remote Access And Save") { 9 | 10 | function register_client_id(name) { 11 | fetch(api.apiURL(`bmab?remote_client_id=${api.clientId}&remote_name=${name}`)) 12 | .then(response => response.json()) 13 | .then(data => { 14 | console.log(data); 15 | }) 16 | .catch(error => { 17 | console.error(error); 18 | }); 19 | } 20 | 21 | const onNodeCreated = nodeType.prototype.onNodeCreated; 22 | nodeType.prototype.onNodeCreated = function () { 23 | onNodeCreated?.apply(this, arguments); 24 | const remote_name = this.widgets.find((w) => w.name === "remote_name"); 25 | register_client_id(remote_name.value); 26 | 27 | remote_name.callback = function () { 28 | register_client_id(this.value); 29 | } 30 | }; 31 | 32 | const onReconnect = nodeType.prototype.onReconnect; 33 | nodeType.prototype.onReconnect = function () { 34 | const remote_name = this.widgets.find((w) => w.name === "remote_name"); 35 | register_client_id(remote_name.value); 36 | }; 37 | 38 | const onConfigure = nodeType.prototype.onConfigure; 39 | nodeType.prototype.onConfigure = function (w) { 40 | const remote_name = this.widgets.find((w) => w.name === "remote_name"); 41 | register_client_id(remote_name.value); 42 | }; 43 | 44 | const onBMABQueue = nodeType.prototype.onBMABQueue; 45 | nodeType.prototype.onBMABQueue = function () { 46 | console.log('QUEUE prompt') 47 | app.queuePrompt(0, 1) 48 | }; 49 | 50 | api.addEventListener("reconnected", ({ detail }) => { 51 | app.graph._nodes.forEach((node) => { 52 | if (node.onReconnect) 53 | node.onReconnect() 54 | }) 55 | }); 56 | 57 | api.addEventListener("bmab_queue", ({ detail }) => { 58 | app.graph._nodes.forEach((node) => { 59 | if 
(node.onBMABQueue) 60 | node.onBMABQueue() 61 | }) 62 | }); 63 | } 64 | }, 65 | }); 66 | --------------------------------------------------------------------------------