├── launch_cmd.bat ├── .gitignore ├── sample ├── copy1.webp ├── copy2.webp ├── remove1.webp └── remove2.webp ├── requirements.txt ├── README.md ├── main.py └── bubble_tool └── bubble_tool.py /launch_cmd.bat: -------------------------------------------------------------------------------- 1 | %windir%\System32\cmd.exe /K "venv\Scripts\activate.bat" 2 | 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | /venv 3 | /tmp 4 | /input 5 | /output 6 | .vscode 7 | /data 8 | /runs 9 | -------------------------------------------------------------------------------- /sample/copy1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/speech_bubble_remove_and_copy/master/sample/copy1.webp -------------------------------------------------------------------------------- /sample/copy2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/speech_bubble_remove_and_copy/master/sample/copy2.webp -------------------------------------------------------------------------------- /sample/remove1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/speech_bubble_remove_and_copy/master/sample/remove1.webp -------------------------------------------------------------------------------- /sample/remove2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/speech_bubble_remove_and_copy/master/sample/remove2.webp -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pillow 3 | fire 4 | opencv-python 5 | pandas 6 | 
huggingface_hub 7 | tqdm 8 | ultralytics==8.3.49 9 | simple-lama-inpainting 10 | transformers==4.43.4 11 | einops 12 | shapely 13 | timm 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Speech Bubble Remove and Copy Tool 2 | 3 | This is a tool to remove or copy speech bubbles. 4 | 5 | - Remove all speech bubbles from an image. (source image / cleaned image) 6 | - (As this example shows, not every speech bubble is captured completely. The sample images may also be poorly suited.) 7 | 8 | 9 | 10 | - Copy all speech bubbles in image A to image B. (source image (A) / base image (B) / base image with speech bubbles) 11 | 12 | 13 | 14 | 15 | 16 | ## Installation (for Windows) 17 | [Python 3.10](https://www.python.org/) and a git client must be installed. 18 | 19 | ```sh 20 | git clone https://github.com/s9roll7/speech_bubble_remove_and_copy.git 21 | cd speech_bubble_remove_and_copy 22 | py -3.10 -m venv venv 23 | venv\Scripts\activate.bat 24 | # Install torch according to your environment (https://pytorch.org/get-started/locally/) 25 | python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | ## How To Use 30 | 31 | Run launch_cmd.bat 32 | 33 | ```sh 34 | # remove 35 | python main.py remove SRC_IMAGE_DIR_PATH 36 | 37 | #!! Create base images based on the cleaned images, 38 | #!! and put them in BASE_IMAGE_DIR with the "same file name". 
39 | 40 | # copy 41 | python main.py copy SRC_IMAGE_DIR_PATH BASE_IMAGE_DIR_PATH 42 | 43 | 44 | ``` 45 | or 46 | 47 | ```sh 48 | # remove 49 | # prepare your project directory 50 | # my_proj_1/ 51 | # src/ <--- put source images here 52 | 53 | 54 | # with comic panel extraction 55 | python main.py remove_proj PROJ_DIR_PATH 56 | 57 | # or 58 | 59 | # without comic panel extraction 60 | python main.py remove_proj PROJ_DIR_PATH --split=False 61 | 62 | # copy 63 | # prepare your project directory 64 | # my_proj_1/ 65 | # src/ <--- put source images here 66 | # base/ <--- put base images here 67 | 68 | #!! Create base images based on the cleaned images, 69 | #!! and put them in "my_proj_1/base" with the "same file name". 70 | 71 | python main.py copy_proj PROJ_DIR_PATH 72 | ``` 73 | 74 | ## Advanced Settings 75 | ### Switching models 76 | To process text as well as speech bubbles, you need to use a different model. 77 | Download the "adetailerForTextSpeech" model from Civitai and place it in the following location: 78 | data/models/adetailerForTextSpeech_v20/unwantedV10x.pt 79 | Then use the following commands: 
80 | ```sh 81 | # remove 82 | python main.py remove_proj PROJ_DIR_PATH --model_type=1 83 | 84 | # copy 85 | python main.py copy_proj PROJ_DIR_PATH --model_type=1 86 | ``` 87 | 88 | If you want to use a different model, edit YOLO_SEG_MODEL_LOCATION(in bubble_tool.py) 89 | 90 | 91 | 92 | ## Changelog 93 | ### 2024-12-22 94 | Added comic panel extraction function 95 | Changed lama model 96 | 97 | 98 | ## Related resources 99 | - [simple-lama-inpainting](https://github.com/enesmsahin/simple-lama-inpainting) 100 | - [yolov8m_seg-speech-bubble](https://huggingface.co/kitsumed/yolov8m_seg-speech-bubble) 101 | - [magi](https://github.com/ragavsachdeva/magi) 102 | - [AnimeMangaInpainting](https://huggingface.co/dreMaz/AnimeMangaInpainting) 103 | 104 | 105 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from pathlib import Path 4 | import fire 5 | from datetime import datetime 6 | import shutil 7 | 8 | from bubble_tool.bubble_tool import remove_bubble, copy_bubble, split_panel, combine_panel 9 | 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.INFO) 12 | formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)s %(funcName)s [%(levelname)s]: %(message)s") 13 | handler = logging.StreamHandler() 14 | handler.setFormatter(formatter) 15 | logger.addHandler(handler) 16 | 17 | 18 | def get_time_str(): 19 | return datetime.now().strftime("%Y%m%d_%H%M%S") 20 | 21 | def create_proj_dirs(proj_dir_path:Path): 22 | (proj_dir_path / Path("src")).mkdir(parents=True, exist_ok=True) 23 | (proj_dir_path / Path("mask")).mkdir(parents=True, exist_ok=True) 24 | (proj_dir_path / Path("cleaned")).mkdir(parents=True, exist_ok=True) 25 | (proj_dir_path / Path("base")).mkdir(parents=True, exist_ok=True) 26 | (proj_dir_path / Path("base_with_bubble")).mkdir(parents=True, exist_ok=True) 27 | 28 | 29 | 30 | class 
Command: 31 | def __init__(self): 32 | import os 33 | if os.name == 'nt': 34 | import _locale 35 | if not hasattr(_locale, '_gdl_bak'): 36 | _locale._gdl_bak = _locale._getdefaultlocale 37 | _locale._getdefaultlocale = (lambda *args: (_locale._gdl_bak()[0], 'UTF-8')) 38 | 39 | def remove(self, src_path, model_type:int=0, detection_th:float=0.1): 40 | # src image -> mask + clean image 41 | start_tim = time.time() 42 | 43 | src_path = Path(src_path) 44 | if not src_path.is_dir(): 45 | raise ValueError( f"{src_path} not found" ) 46 | 47 | output_dir = Path("output") / Path("remove") / Path(get_time_str()) 48 | 49 | mask_output_path = output_dir / Path("mask") 50 | mask_output_path.mkdir(parents=True) 51 | clean_img_output_path = output_dir / Path("cleaned_img") 52 | clean_img_output_path.mkdir(parents=True) 53 | 54 | remove_bubble(src_path, mask_output_path, clean_img_output_path, model_type, detection_th) 55 | 56 | logger.info(f"Output : {clean_img_output_path}") 57 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 58 | 59 | 60 | def copy(self, src_path, base_path, mask_path=None, model_type:int=0, blend_method:int=-1, detection_th:float=0.1): 61 | # src image, base_image ,(bubble_mask) -> base_with_bubble image 62 | start_tim = time.time() 63 | 64 | src_path = Path(src_path) 65 | if not src_path.is_dir(): 66 | raise ValueError( f"{src_path} not found" ) 67 | 68 | base_path = Path(base_path) 69 | if not base_path.is_dir(): 70 | raise ValueError( f"{base_path} not found" ) 71 | 72 | output_dir = Path("output") / Path("copy") / Path(get_time_str()) 73 | 74 | with_bubble_img_output_path =output_dir / Path("with_bubble_img") 75 | with_bubble_img_output_path.mkdir(parents=True) 76 | 77 | if mask_path: 78 | mask_path = Path(mask_path) 79 | if not mask_path.is_dir(): 80 | raise ValueError( f"{mask_path} not found" ) 81 | 82 | creat_mask = False 83 | 84 | else: 85 | mask_path = output_dir / Path("mask") 86 | mask_path.mkdir(parents=True) 87 | 88 | creat_mask = 
True 89 | 90 | copy_bubble(src_path, base_path, mask_path, creat_mask, with_bubble_img_output_path, model_type, blend_method, detection_th) 91 | 92 | logger.info(f"Output : {with_bubble_img_output_path}") 93 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 94 | 95 | 96 | def split(self, src_path, frag_size=1024*2): 97 | # src image -> split images 98 | start_tim = time.time() 99 | 100 | src_path = Path(src_path) 101 | if not src_path.is_dir(): 102 | raise ValueError( f"{src_path} not found" ) 103 | 104 | output_dir = Path("output") / Path("split") / Path(get_time_str()) 105 | output_dir.mkdir(parents=True) 106 | 107 | split_panel(src_path, Path(output_dir), frag_size) 108 | 109 | logger.info(f"Output : {output_dir}") 110 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 111 | 112 | 113 | def combine(self, src_path, frag_path): 114 | # src image + frag images -> combined image 115 | start_tim = time.time() 116 | 117 | src_path = Path(src_path) 118 | if not src_path.is_dir(): 119 | raise ValueError( f"{src_path} not found" ) 120 | 121 | frag_path = Path(frag_path) 122 | if not frag_path.is_dir(): 123 | raise ValueError( f"{frag_path} not found" ) 124 | 125 | output_dir = Path("output") / Path("combine") / Path(get_time_str()) 126 | output_dir.mkdir(parents=True) 127 | 128 | combine_panel(src_path, frag_path, Path(output_dir)) 129 | 130 | logger.info(f"Output : {output_dir}") 131 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 132 | 133 | 134 | def remove_proj(self, proj_dir_path, model_type:int=0, detection_th:float=0.1, split=True, frag_size=1024*2): 135 | # src image -> mask + clean image 136 | start_tim = time.time() 137 | 138 | proj_dir_path = Path(proj_dir_path) 139 | if not proj_dir_path.is_dir(): 140 | raise ValueError( f"{proj_dir_path} not found" ) 141 | 142 | src_path = proj_dir_path / Path("src") 143 | if not src_path.is_dir(): 144 | raise ValueError( f"{src_path} not found" ) 145 | 146 | 
create_proj_dirs(proj_dir_path) 147 | 148 | time_str = get_time_str() 149 | 150 | mask_output_path = proj_dir_path / Path("mask") / Path(time_str) 151 | mask_output_path.mkdir(parents=True) 152 | 153 | if split == False: 154 | clean_img_output_path = proj_dir_path / Path("cleaned") / Path(time_str) 155 | clean_img_output_path.mkdir(parents=True) 156 | remove_bubble(src_path, mask_output_path, clean_img_output_path, model_type, detection_th) 157 | 158 | logger.info(f"Output : {clean_img_output_path}") 159 | else: 160 | clean_img_output_path = proj_dir_path / Path("pre_split") / Path(time_str) 161 | clean_img_output_path.mkdir(parents=True) 162 | split_output_path = proj_dir_path / Path("cleaned") / Path(time_str) 163 | split_output_path.mkdir(parents=True) 164 | remove_bubble(src_path, mask_output_path, clean_img_output_path, model_type, detection_th) 165 | 166 | split_panel(clean_img_output_path, split_output_path, frag_size) 167 | 168 | split_info_path = split_output_path / Path("split_info.txt") 169 | split_info_path2 = proj_dir_path / Path("base") / Path("split_info.txt") 170 | 171 | shutil.copy(split_info_path, split_info_path2) 172 | 173 | logger.info(f"Output : {split_output_path}") 174 | 175 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 176 | 177 | 178 | def copy_proj(self, proj_dir_path, model_type:int=0, blend_method:int=-1, detection_th:float=0.1): 179 | # src image, base_image ,(bubble_mask) -> base_with_bubble image 180 | start_tim = time.time() 181 | 182 | proj_dir_path = Path(proj_dir_path) 183 | if not proj_dir_path.is_dir(): 184 | raise ValueError( f"{proj_dir_path} not found" ) 185 | 186 | src_path = proj_dir_path / Path("src") 187 | if not src_path.is_dir(): 188 | raise ValueError( f"{src_path} not found" ) 189 | 190 | base_path = proj_dir_path / Path("base") 191 | if not base_path.is_dir(): 192 | raise ValueError( f"{base_path} not found" ) 193 | 194 | create_proj_dirs(proj_dir_path) 195 | 196 | time_str = get_time_str() 197 | 
198 | mask_output_path = proj_dir_path / Path("mask") / Path(time_str) 199 | mask_output_path.mkdir(parents=True) 200 | with_bubble_img_output_path = proj_dir_path / Path("base_with_bubble") / Path(time_str) 201 | with_bubble_img_output_path.mkdir(parents=True) 202 | 203 | combine = (base_path / Path("split_info.txt")).is_file() 204 | 205 | if combine == False: 206 | copy_bubble(src_path, base_path, mask_output_path, True, with_bubble_img_output_path, model_type, blend_method, detection_th) 207 | else: 208 | combine_output_path = proj_dir_path / Path("combined") / Path(time_str) 209 | combine_output_path.mkdir(parents=True) 210 | 211 | combine_panel(src_path, base_path, combine_output_path) 212 | copy_bubble(src_path, combine_output_path, mask_output_path, True, with_bubble_img_output_path, model_type, blend_method, detection_th) 213 | 214 | 215 | logger.info(f"Output : {with_bubble_img_output_path}") 216 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 217 | 218 | 219 | fire.Fire(Command) 220 | -------------------------------------------------------------------------------- /bubble_tool/bubble_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from pathlib import Path 4 | import os 5 | 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | import torch 13 | from ultralytics import YOLO 14 | from simple_lama_inpainting import SimpleLama 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | logger.setLevel(logging.INFO) 19 | formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)s %(funcName)s [%(levelname)s]: %(message)s") 20 | handler = logging.StreamHandler() 21 | handler.setFormatter(formatter) 22 | logger.addHandler(handler) 23 | 24 | ############################################################################# 25 | YOLO_SEG_MODEL_LOCATION = [ # model_type 26 | "data/models/yolov8m_seg-speech-bubble/model.pt", # 
0 27 | "data/models/adetailerForTextSpeech_v20/unwantedV10x.pt", # 1 (Manual download is required) 28 | ] 29 | 30 | ############################################################################# 31 | 32 | BUBBLE_BORDER_COLOR = (0,0,0) # black 33 | BUBBLE_BORDER_WIDTH = 0 # 0 pixel -> auto-calculation 34 | 35 | ############################################################################# 36 | 37 | 38 | 39 | def get_image_file_list(img_dir_path:Path): 40 | img_list = [p for p in img_dir_path.glob("*") if re.search(r'\.(jpg|png|webp)$', str(p))] # match the extension only at the end of the name 41 | return sorted(img_list) 42 | 43 | def resize_img(img, size_xy): 44 | if img.shape[0] > size_xy[1]: 45 | return cv2.resize(img, size_xy, interpolation=cv2.INTER_AREA) 46 | else: 47 | return cv2.resize(img, size_xy, interpolation=cv2.INTER_CUBIC) 48 | 49 | def prepare_yolo(): 50 | import os 51 | from pathlib import PurePosixPath 52 | 53 | from huggingface_hub import hf_hub_download 54 | 55 | os.makedirs("data/models/yolov8m_seg-speech-bubble", exist_ok=True) 56 | for hub_file in [ 57 | "model.pt", 58 | ]: 59 | path = Path(hub_file) 60 | 61 | saved_path = "data/models/yolov8m_seg-speech-bubble" / path 62 | 63 | if os.path.exists(saved_path): 64 | continue 65 | 66 | hf_hub_download( 67 | repo_id="kitsumed/yolov8m_seg-speech-bubble", subfolder=PurePosixPath(path.parent), filename=PurePosixPath(path.name), local_dir="data/models/yolov8m_seg-speech-bubble" 68 | ) 69 | 70 | def prepare_lama(): 71 | import os 72 | from pathlib import PurePosixPath 73 | 74 | from huggingface_hub import hf_hub_download 75 | 76 | if os.environ.get("LAMA_MODEL", None) is None: 77 | os.environ["LAMA_MODEL"] = "data/models/AnimeMangaInpainting/model.jit.pt" 78 | else: 79 | return 80 | 81 | os.makedirs("data/models/AnimeMangaInpainting", exist_ok=True) 82 | for hub_file in [ 83 | "model.jit.pt", 84 | ]: 85 | path = Path(hub_file) 86 | 87 | saved_path = "data/models/AnimeMangaInpainting" / path 88 | 89 | if os.path.exists(saved_path): 90 | continue 91 | 
92 | hf_hub_download( 93 | repo_id="s9roll74/tracing_dreMaz_AnimeMangaInpainting", subfolder=PurePosixPath(path.parent), filename=PurePosixPath(path.name), local_dir="data/models/AnimeMangaInpainting" 94 | ) 95 | 96 | if False: 97 | from bubble_tool.inpainting_lama_mpe import LamaFourier 98 | def load_lama_mpe(model_path, device, use_mpe: bool = True, large_arch: bool = False) -> LamaFourier: 99 | model = LamaFourier(build_discriminator=False, use_mpe=use_mpe, large_arch=large_arch) 100 | sd = torch.load(model_path, map_location = 'cpu') 101 | model.generator.load_state_dict(sd['gen_state_dict']) 102 | if use_mpe: 103 | model.mpe.load_state_dict(sd['str_state_dict']) 104 | model.eval().to(device) 105 | return model 106 | 107 | model = load_lama_mpe("data/models/AnimeMangaInpainting/lama_large_512px.ckpt", device='cpu', use_mpe=False, large_arch=True) 108 | traced_model = torch.jit.trace(model.generator, ( torch.zeros([1, 3, 512, 512]), torch.zeros([1,1,512, 512]))) 109 | traced_model.save("data/models/AnimeMangaInpainting/model.jit.pt") 110 | 111 | 112 | # https://github.com/ultralytics/ultralytics/issues/3560 113 | def scale_image_torch(masks, im0_shape, ratio_pad=None): 114 | """ 115 | Takes a mask, and resizes it to the original image size 116 | 117 | Args: 118 | masks (torch.Tensor): resized and padded masks/images, [c, h, w]. 119 | im0_shape (tuple): the original image shape 120 | ratio_pad (tuple): the ratio of the padding to the original image. 121 | 122 | Returns: 123 | masks (torch.Tensor): The masks that are being returned. 
124 | """ 125 | import torch.nn.functional as F 126 | 127 | if len(masks.shape) < 3: 128 | raise ValueError( 129 | f'"len of masks shape" should be 3, but got {len(masks.shape)}' 130 | ) 131 | im1_shape = masks.shape 132 | if im1_shape[1:] == im0_shape: 133 | return masks 134 | if ratio_pad is None: # calculate from im0_shape 135 | gain = min( 136 | im1_shape[1] / im0_shape[0], im1_shape[2] / im0_shape[1] 137 | ) # gain = old / new 138 | pad = (im1_shape[2] - im0_shape[1] * gain) / 2, ( 139 | im1_shape[1] - im0_shape[0] * gain 140 | ) / 2 # wh padding 141 | else: 142 | gain = ratio_pad[0][0] 143 | pad = ratio_pad[1] 144 | top, left = int(pad[1]), int(pad[0]) # y, x 145 | bottom, right = int(im1_shape[1] - pad[1]), int(im1_shape[2] - pad[0]) 146 | 147 | masks = masks[:, top:bottom, left:right] 148 | if masks.shape[1:] != im0_shape: 149 | masks = F.interpolate( 150 | masks[None], im0_shape, mode="bilinear", align_corners=False 151 | )[0] 152 | 153 | return masks 154 | 155 | 156 | 157 | ############################################################################# 158 | 159 | def detect_bubble(src_list, mask_path:Path, model_type, detection_th, classification): 160 | 161 | if model_type == 0: 162 | prepare_yolo() 163 | 164 | if len(YOLO_SEG_MODEL_LOCATION) > model_type: 165 | model = YOLO(YOLO_SEG_MODEL_LOCATION[model_type]) 166 | else: 167 | raise ValueError(f"unknown {model_type=}") 168 | 169 | model.to("cuda" if torch.cuda.is_available() else "cpu") 170 | 171 | for i, src in tqdm( enumerate(src_list), desc=f"detect_bubble", total=len(src_list)): 172 | 173 | result_path = mask_path / Path(Path(src).with_suffix(".png").name) 174 | 175 | org_size = Image.open(src).size 176 | 177 | with torch.no_grad(): 178 | results = model.predict(src, save=False, verbose=False, conf=detection_th) 179 | 180 | masks = results[0].masks 181 | boxes = results[0].boxes 182 | 183 | result = [None,None,None] 184 | 185 | if masks is not None: 186 | for mask, box in zip(masks,boxes): 187 | cls 
= int(box.cls) 188 | if cls > 2: 189 | cls = 2 190 | 191 | mask = scale_image_torch(mask.data, (org_size[1],org_size[0])) 192 | if result[cls] is not None: 193 | result[cls] += mask.squeeze() 194 | else: 195 | result[cls] = mask.squeeze() 196 | 197 | for i in range(len(result)): 198 | if result[i] is not None: 199 | result[i] = result[i].cpu().numpy() 200 | result[i] = result[i].astype('uint8') * 255 201 | else: 202 | result[i] = np.zeros((org_size[1],org_size[0]), np.uint8) 203 | 204 | if classification: 205 | result_array = np.dstack(result) 206 | else: 207 | result_array = result[0] | result[1] | result[2] 208 | 209 | Image.fromarray(result_array).save(result_path) 210 | else: 211 | if classification: 212 | result_array = np.zeros((org_size[1],org_size[0],3), np.uint8) 213 | else: 214 | result_array = np.zeros((org_size[1],org_size[0]), np.uint8) 215 | 216 | Image.fromarray( result_array ).save(result_path) 217 | 218 | 219 | if False: 220 | bubble_array = np.array(Image.open(src)) 221 | bubble_array[result==0] = 120 222 | 223 | Image.fromarray(bubble_array).save("bubble_only.png") 224 | 225 | model.to("cpu") 226 | 227 | torch.cuda.empty_cache() 228 | 229 | 230 | 231 | def lama_inpaint(src_list, mask_list, output_path:Path): 232 | 233 | prepare_lama() 234 | 235 | simple_lama = SimpleLama() 236 | 237 | for i, (src, mask) in tqdm( enumerate(zip(src_list,mask_list)), desc=f"lama_inpaint", total=min(len(src_list),len(mask_list))): 238 | 239 | result_path = output_path / Path(Path(src).name) 240 | 241 | image = Image.open(src) 242 | 243 | image = image.convert('RGB') 244 | 245 | org_size = image.size 246 | 247 | mask = np.array(Image.open(mask).convert('L')) 248 | 249 | k = int(org_size[0] * 10 / 480) 250 | mask = cv2.dilate(mask, np.ones((k, k), np.uint8), 3) 251 | 252 | k = int(org_size[0] * 9 / 480) //2 * 2 + 1 253 | mask = cv2.GaussianBlur(mask, ksize=(k,k), sigmaX=0) 254 | 255 | mask[mask >= 125] = 255 256 | mask[mask < 125] = 0 257 | 258 | result = 
simple_lama(image, mask) 259 | result.save(result_path) 260 | 261 | simple_lama.model.to("cpu") 262 | 263 | torch.cuda.empty_cache() 264 | 265 | 266 | def blend_image_A(org_array, mask_array, dst_array): 267 | 268 | kernel1_size = int(30 / 1000 * dst_array.shape[1]) 269 | kernel1_size = max(kernel1_size , 3) 270 | 271 | kernel2_size = int(5 / 1000 * dst_array.shape[1]) 272 | kernel2_size = max(kernel2_size , 3) 273 | 274 | gaussian_k_size = int(31 / 1000 * dst_array.shape[1]) // 2 * 2 + 1 275 | gaussian_k_size = max(gaussian_k_size , 3) 276 | 277 | bubble_border_width = int(3 / 1000 * dst_array.shape[1]) 278 | bubble_border_width = max(bubble_border_width , 1) 279 | if BUBBLE_BORDER_WIDTH: 280 | bubble_border_width = BUBBLE_BORDER_WIDTH 281 | 282 | kernel = np.ones((kernel1_size,kernel1_size),np.uint8) 283 | kernel2 = np.ones((kernel2_size,kernel2_size),np.uint8) 284 | 285 | mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_OPEN, kernel) 286 | mask_array = cv2.erode(mask_array,kernel2,iterations = 1) 287 | 288 | mask_array = cv2.GaussianBlur(mask_array, ksize=(gaussian_k_size,gaussian_k_size), sigmaX=0) 289 | mask_array[mask_array >= 125] = 255 290 | mask_array[mask_array < 125] = 0 291 | 292 | #Image.fromarray(mask_array.astype(np.uint8)).save("new_mask.png") 293 | 294 | contours, _ = cv2.findContours(mask_array, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) 295 | 296 | for j in range(len(contours)): 297 | if cv2.contourArea(contours[j]) > (mask_array.shape[0] * mask_array.shape[1]) * 0.005: 298 | org_array = cv2.drawContours(org_array, contours, j, BUBBLE_BORDER_COLOR, bubble_border_width) 299 | 300 | if mask_array.ndim == 2: 301 | mask_array = mask_array[:, :, np.newaxis] 302 | 303 | mask_array = mask_array / 255 304 | 305 | dst_array = org_array * mask_array + dst_array * (1 - mask_array) 306 | 307 | return dst_array 308 | 309 | def blend_image_B(org_array, mask_array, dst_array): 310 | 311 | gaussian_k_size = int(31 / 1000 * dst_array.shape[1]) // 2 * 2 + 1 312 | 
313 | mask_array = cv2.GaussianBlur(mask_array, ksize=(gaussian_k_size,gaussian_k_size), sigmaX=0) 314 | 315 | if mask_array.ndim == 2: 316 | mask_array = mask_array[:, :, np.newaxis] 317 | 318 | mask_array = mask_array / 255 319 | 320 | dst_array = org_array * mask_array + dst_array * (1 - mask_array) 321 | 322 | return dst_array 323 | 324 | def blend_image_C(org_array, mask_array, dst_array): 325 | 326 | org_array[mask_array==0] = (0) 327 | 328 | gray = cv2.cvtColor(org_array, cv2.COLOR_RGB2GRAY) 329 | ret, binary = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY) 330 | 331 | contours, hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 332 | mask = np.zeros((org_array.shape[0],org_array.shape[1]), np.uint8) 333 | mask = cv2.drawContours(mask, contours, -1, (255), -1) 334 | 335 | if mask.ndim == 2: 336 | mask = mask[:, :, np.newaxis] 337 | 338 | mask = mask / 255 339 | 340 | dst_array = org_array * mask + dst_array * (1 - mask) 341 | 342 | return dst_array 343 | 344 | 345 | def blend_image(src_list, mask_list, base_list, output_path:Path, blend_method): 346 | 347 | for i, (src, mask, base) in tqdm( enumerate(zip(src_list,mask_list,base_list)), desc=f"blend_image", total=min(len(src_list),len(mask_list),len(base_list))): 348 | 349 | result_path = output_path / Path(Path(src).name) 350 | 351 | org_array = np.array(Image.open(src)) 352 | mask_array = np.array(Image.open(mask)) 353 | dst_array = np.array(Image.open(base)) 354 | org_array = org_array[:,:,:3] 355 | dst_array = dst_array[:,:,:3] 356 | 357 | if blend_method != -1: 358 | if mask_array.ndim == 3: 359 | mask_array = mask_array[:, :, 0] | mask_array[:, :, 1] | mask_array[:, :, 2] 360 | 361 | # DEBUG 362 | #dst_array[:,:,:] = 125 363 | 364 | org_array = resize_img( org_array, (dst_array.shape[1], dst_array.shape[0]) ) 365 | mask_array = resize_img( mask_array, (dst_array.shape[1], dst_array.shape[0]) ) 366 | 367 | if blend_method == -1: 368 | 369 | dst_array = 
blend_image_A(org_array.copy(), mask_array[:, :, 0], dst_array) 370 | dst_array = blend_image_B(org_array.copy(), mask_array[:, :, 1], dst_array) 371 | dst_array = blend_image_B(org_array.copy(), mask_array[:, :, 2], dst_array) 372 | 373 | elif blend_method == 0: 374 | dst_array = blend_image_A(org_array, mask_array, dst_array) 375 | elif blend_method == 1: 376 | dst_array = blend_image_B(org_array, mask_array, dst_array) 377 | else: 378 | dst_array = blend_image_C(org_array, mask_array, dst_array) 379 | 380 | Image.fromarray(dst_array.astype(np.uint8)).save(result_path) 381 | 382 | 383 | def remove_bubble(src_path:Path, mask_output_path:Path, clean_img_output_path:Path, model_type, detection_th): 384 | 385 | src_list = get_image_file_list(src_path) 386 | 387 | detect_bubble(src_list, mask_output_path, model_type, detection_th, False) 388 | 389 | mask_list = get_image_file_list(mask_output_path) 390 | 391 | lama_inpaint(src_list, mask_list, clean_img_output_path) 392 | 393 | 394 | 395 | def copy_bubble(src_path:Path, base_path:Path, mask_input_path:Path, create_mask, with_bubble_img_output_path:Path, model_type, blend_method, detection_th): 396 | 397 | base_list = get_image_file_list(base_path) 398 | src_list = get_image_file_list(src_path) 399 | 400 | src_map = { s.stem : s for s in src_list } 401 | 402 | src_list = [] 403 | 404 | for b in base_list: 405 | src_img_path = src_map.get(b.stem, None) 406 | 407 | if src_img_path: 408 | if src_img_path.is_file(): 409 | src_list.append( src_img_path ) 410 | 411 | 412 | if create_mask: 413 | detect_bubble(src_list, mask_input_path, model_type, detection_th, True) 414 | 415 | mask_list = [ (mask_input_path/Path(s.name)).with_suffix(".png") for s in base_list if (mask_input_path/Path(s.name)).with_suffix(".png").is_file() ] 416 | 417 | blend_image(src_list, mask_list, base_list, with_bubble_img_output_path, blend_method) 418 | 419 | 420 | 421 | def split_panel(src_path:Path, split_output_path:Path, output_size): 422 | from 
transformers import AutoModel 423 | 424 | src_list = get_image_file_list(src_path) 425 | 426 | def read_image_as_np_array(image_path): 427 | with open(image_path, "rb") as file: 428 | image = Image.open(file).convert("L").convert("RGB") 429 | image = np.array(image) 430 | return image 431 | 432 | images = [read_image_as_np_array(image) for image in src_list] 433 | 434 | model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True).to("cuda" if torch.cuda.is_available() else "cpu") # fall back to CPU when CUDA is unavailable 435 | with torch.no_grad(): 436 | results = model.predict_detections_and_associations(images) 437 | 438 | crop_info = [] 439 | 440 | for i, (result, src_path) in enumerate(zip(results, src_list)) : 441 | with Image.open(src_path) as im: 442 | crop_info.append(f"{src_path.name},{im.size[0]},{im.size[1]}") 443 | for j, panel in enumerate(result["panels"]): 444 | im_crop = im.crop(panel) 445 | scale = output_size / (im_crop.size[0] + im_crop.size[1]) 446 | im_crop = im_crop.resize( (int(im_crop.size[0] * scale), int(im_crop.size[1] * scale)), Image.LANCZOS ) 447 | filename = str(i).zfill(8) + "_" + str(j).zfill(8) + ".png" 448 | im_crop.save( split_output_path / Path(filename) ) 449 | crop_info.append(f"{filename},{panel[0]},{panel[1]},{panel[2]},{panel[3]}") 450 | 451 | crop_info_path = split_output_path / Path("split_info.txt") 452 | crop_info_path.write_text( "\n".join(crop_info), encoding="utf-8") 453 | 454 | model.to("cpu") 455 | 456 | torch.cuda.empty_cache() 457 | 458 | 459 | def combine_panel(src_path:Path, fragment_path:Path, output_dir:Path): 460 | 461 | crop_info_path = fragment_path / Path("split_info.txt") 462 | if not crop_info_path.is_file(): 463 | raise ValueError(f"{crop_info_path} not found") 464 | 465 | crop_list = crop_info_path.read_text(encoding="utf-8") # same encoding that split_panel writes 466 | crop_list = crop_list.splitlines() 467 | 468 | def create_crop_info(crop_list): 469 | crop_info = {} 470 | cur_src_name = None 471 | cur_item = {} 472 | 473 | for c in crop_list: 474 | c = c.split(",") 475 | if len(c) == 3: 476 | if cur_item: 477 | 
crop_info[cur_src_name] = cur_item 478 | cur_item = {} 479 | cur_src_name, x, y = c 480 | cur_item["org_size"] = (int(x), int(y)) 481 | cur_item["frags"] = {} 482 | else: 483 | frag_name, x1, y1, x2, y2 = c 484 | x1 = int(float(x1)) 485 | y1 = int(float(y1)) 486 | x2 = int(float(x2)) 487 | y2 = int(float(y2)) 488 | cur_item["frags"][frag_name] = (x1,y1,x2,y2) 489 | 490 | if cur_item: 491 | crop_info[cur_src_name] = cur_item 492 | return crop_info 493 | 494 | 495 | crop_info = create_crop_info(crop_list) 496 | 497 | 498 | for i, src_name in enumerate(crop_info): 499 | src_img_path = src_path / Path(src_name) 500 | if src_img_path.is_file() == False: 501 | continue 502 | src_img = Image.open( src_img_path ) 503 | org_size = crop_info[src_name]["org_size"] 504 | 505 | mod = False 506 | 507 | for frag_name in crop_info[src_name]["frags"]: 508 | frag_img_path = fragment_path / Path(frag_name) 509 | if frag_img_path.is_file() == False: 510 | frag_img_path = frag_img_path.with_suffix(".png") 511 | if frag_img_path.is_file() == False: 512 | continue 513 | frag_img = Image.open(frag_img_path) 514 | x1,y1,x2,y2 = crop_info[src_name]["frags"][frag_name] 515 | 516 | scale_x = (x2-x1) / org_size[0] 517 | scale_y = (y2-y1) / org_size[1] 518 | frag_img = frag_img.resize( (int(src_img.size[0] * scale_x), int(src_img.size[1] * scale_y)), Image.LANCZOS ) 519 | scale_x = src_img.size[0]/org_size[0] 520 | scale_y = src_img.size[1]/org_size[1] 521 | src_img.paste(frag_img, (int(x1 * scale_x),int(y1 * scale_y)) ) 522 | mod = True 523 | 524 | if mod: 525 | src_img.save( output_dir/ Path(src_name) ) 526 | 527 | --------------------------------------------------------------------------------
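The `split_info.txt` file that `split_panel` writes and `combine_panel` reads uses a simple comma-separated layout: a three-field line (`name,width,height`) opens a source image, and each following five-field line (`fragment,x1,y1,x2,y2`) records one cropped panel. The sketch below is a minimal standalone illustration of that parsing step, not the repository's own code. `parse_split_info` and the sample data are hypothetical names made up for this example, and it assumes file names contain no commas.

```python
def parse_split_info(text):
    """Parse split_info.txt content into {source_name: {"org_size", "frags"}}.

    Hypothetical helper for illustration; it mirrors the layout described above.
    """
    crop_info = {}
    current = None
    for line in text.splitlines():
        fields = line.split(",")
        if len(fields) == 3:
            # Header line: source image name and its original size.
            name, width, height = fields
            current = {"org_size": (int(width), int(height)), "frags": {}}
            crop_info[name] = current
        elif len(fields) == 5 and current is not None:
            # Fragment line: panel box coordinates may be written as floats.
            frag_name, *box = fields
            current["frags"][frag_name] = tuple(int(float(v)) for v in box)
    return crop_info


# Made-up sample content in the same layout.
sample = "\n".join([
    "page1.png,1000,1500",
    "00000000_00000000.png,10.5,20.0,500.9,700.2",
    "00000000_00000001.png,510.0,20.0,990.0,700.0",
    "page2.png,800,1200",
])

info = parse_split_info(sample)
print(info["page1.png"]["org_size"])
print(info["page1.png"]["frags"]["00000000_00000000.png"])
```

Keeping the original size in the header line is what lets the combine step rescale each panel box when the base image has a different resolution from the source image.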