├── launch_cmd.bat
├── .gitignore
├── sample
│   ├── copy1.webp
│   ├── copy2.webp
│   ├── remove1.webp
│   └── remove2.webp
├── requirements.txt
├── README.md
├── main.py
└── bubble_tool
    └── bubble_tool.py
/launch_cmd.bat:
--------------------------------------------------------------------------------
@echo off
rem Open a command prompt with the project's virtual environment activated.
rem Switch to this script's directory first: venv\, main.py and the
rem data\models\ paths used by bubble_tool.py are all relative to the repo root,
rem so launching from another working directory would otherwise break them.
cd /d "%~dp0"
%windir%\System32\cmd.exe /K "venv\Scripts\activate.bat"
2 |
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | /venv
3 | /tmp
4 | /input
5 | /output
6 | .vscode
7 | /data
8 | /runs
9 |
--------------------------------------------------------------------------------
/sample/copy1.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/speech_bubble_remove_and_copy/master/sample/copy1.webp
--------------------------------------------------------------------------------
/sample/copy2.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/speech_bubble_remove_and_copy/master/sample/copy2.webp
--------------------------------------------------------------------------------
/sample/remove1.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/speech_bubble_remove_and_copy/master/sample/remove1.webp
--------------------------------------------------------------------------------
/sample/remove2.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/speech_bubble_remove_and_copy/master/sample/remove2.webp
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pillow
3 | fire
4 | opencv-python
5 | pandas
6 | huggingface_hub
7 | tqdm
8 | ultralytics==8.3.49
9 | simple-lama-inpainting
10 | transformers==4.43.4
11 | einops
12 | shapely
13 | timm
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Speech Bubble Remove and Copy Tool
2 |
3 | This tool removes speech bubbles from images, or copies them from one image onto another.
4 |
5 | - Remove all speech bubbles from the image. (source image / cleaned image)
6 | - (As this example shows, not every speech bubble is captured completely, although the sample images may simply be a poor fit.)
7 |
8 |
9 |
10 | - Copy all speech bubbles from image A to image B. (source image (A) / base image (B) / base image with speech bubbles)
11 |
12 |
13 |
14 |
15 |
16 | ## Installation (Windows)
17 | [Python 3.10](https://www.python.org/) and a Git client must be installed.
18 |
19 | ```sh
20 | git clone https://github.com/s9roll7/speech_bubble_remove_and_copy.git
21 | cd speech_bubble_remove_and_copy
22 | py -3.10 -m venv venv
23 | venv\Scripts\activate.bat
24 | # Install torch to match your environment (see https://pytorch.org/get-started/locally/).
25 | python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
26 | pip install -r requirements.txt
27 | ```
28 |
29 | ## How To Use
30 |
31 | Run launch_cmd.bat to open a command prompt with the virtual environment activated, then:
32 |
33 | ```sh
34 | # remove
35 | python main.py remove SRC_IMAGE_DIR_PATH
36 |
37 | #!! Create base images from the cleaned images
38 | #!! and put them in BASE_IMAGE_DIR under the same file names.
39 |
40 | # copy
41 | python main.py copy SRC_IMAGE_DIR_PATH BASE_IMAGE_DIR_PATH
42 |
43 |
44 | ```
45 | or
46 |
47 | ```sh
48 | # remove
49 | # prepare your project directory
50 | # my_proj_1/
51 | # src/ <--- put source images here
52 |
53 |
54 | # with comic panel extraction
55 | python main.py remove_proj PROJ_DIR_PATH
56 |
57 | or
58 |
59 | # without comic panel extraction
60 | python main.py remove_proj PROJ_DIR_PATH --split=False
61 |
62 | # copy
63 | # prepare your project directory
64 | # my_proj_1/
65 | # src/ <--- put source images here
66 | # base/ <--- put base images here
67 |
68 | #!! Create base images from the cleaned images
69 | #!! and put them in "my_proj_1/base" under the same file names.
70 |
71 | python main.py copy_proj PROJ_DIR_PATH
72 | ```
73 |
74 | ## Advanced Settings
75 | ### Switching models
76 | If you want the process to handle text as well as speech bubbles, you need a different model.
77 | Download the "adetailerForTextSpeech" model from Civitai and place it at the following location:
78 | data/models/adetailerForTextSpeech_v20/unwantedV10x.pt
79 | Use the following command.
80 | ```sh
81 | # remove
82 | python main.py remove_proj PROJ_DIR_PATH --model_type=1
83 |
84 | # copy
85 | python main.py copy_proj PROJ_DIR_PATH --model_type=1
86 | ```
87 |
88 | To use a different model, edit YOLO_SEG_MODEL_LOCATION (in bubble_tool/bubble_tool.py).
89 |
90 |
91 |
92 | ## Changelog
93 | ### 2024-12-22
94 | Added comic panel extraction function
95 | Changed lama model
96 |
97 |
98 | ## Related resources
99 | - [simple-lama-inpainting](https://github.com/enesmsahin/simple-lama-inpainting)
100 | - [yolov8m_seg-speech-bubble](https://huggingface.co/kitsumed/yolov8m_seg-speech-bubble)
101 | - [magi](https://github.com/ragavsachdeva/magi)
102 | - [AnimeMangaInpainting](https://huggingface.co/dreMaz/AnimeMangaInpainting)
103 |
104 |
105 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | from pathlib import Path
4 | import fire
5 | from datetime import datetime
6 | import shutil
7 |
8 | from bubble_tool.bubble_tool import remove_bubble, copy_bubble, split_panel, combine_panel
9 |
# Module logger: records at INFO and above go to a stream handler (stderr by
# default) formatted with timestamp, logger name, source line, function and level.
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)s %(funcName)s [%(levelname)s]: %(message)s")
handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
16 |
17 |
def get_time_str():
    """Return the current local time as ``YYYYmmdd_HHMMSS``.

    Used to name per-run output directories.
    """
    return f"{datetime.now():%Y%m%d_%H%M%S}"
20 |
def create_proj_dirs(proj_dir_path: Path):
    """Ensure the standard project sub-directories exist under *proj_dir_path*.

    Creates ``src``, ``mask``, ``cleaned``, ``base`` and ``base_with_bubble``
    (with parents); directories that already exist are left untouched.
    """
    for sub_dir in ("src", "mask", "cleaned", "base", "base_with_bubble"):
        target = proj_dir_path / Path(sub_dir)
        target.mkdir(parents=True, exist_ok=True)
27 |
28 |
29 |
def _require_dir(path) -> Path:
    """Return *path* as a ``Path``; raise ``ValueError`` if it is not an existing directory."""
    path = Path(path)
    if not path.is_dir():
        raise ValueError(f"{path} not found")
    return path


def _new_dir(path: Path) -> Path:
    """Create *path* (including parents) and return it.

    ``exist_ok`` stays False, so a directory that already exists (e.g. another
    run started within the same timestamp second) raises ``FileExistsError``
    instead of being written into.
    """
    path.mkdir(parents=True)
    return path


def _log_finished(output_path: Path, start: float) -> None:
    """Log the output location and the time elapsed since *start* (a ``perf_counter`` value)."""
    # stacklevel=2 keeps %(funcName)s / %(lineno)s pointing at the calling command.
    logger.info(f"Output : {output_path}", stacklevel=2)
    logger.info(f"Total Elapsed time : {time.perf_counter() - start}", stacklevel=2)


class Command:
    """Command-line entry points (exposed through ``fire``) for the bubble tool."""

    def __init__(self):
        import os
        if os.name == 'nt':
            # On Windows, make the locale's default encoding report UTF-8.
            # The original function is stashed on the module so the patch is
            # applied only once per process.
            import _locale
            if not hasattr(_locale, '_gdl_bak'):
                _locale._gdl_bak = _locale._getdefaultlocale
                _locale._getdefaultlocale = (lambda *args: (_locale._gdl_bak()[0], 'UTF-8'))

    def remove(self, src_path, model_type: int = 0, detection_th: float = 0.1):
        """Remove bubbles from every image in *src_path*.

        Writes masks and cleaned images to
        ``output/remove/<timestamp>/{mask,cleaned_img}``.
        Raises ValueError if *src_path* is not a directory.
        """
        start = time.perf_counter()
        src_path = _require_dir(src_path)

        output_dir = Path("output") / "remove" / get_time_str()
        mask_output_path = _new_dir(output_dir / "mask")
        clean_img_output_path = _new_dir(output_dir / "cleaned_img")

        remove_bubble(src_path, mask_output_path, clean_img_output_path, model_type, detection_th)

        _log_finished(clean_img_output_path, start)

    def copy(self, src_path, base_path, mask_path=None, model_type: int = 0, blend_method: int = -1, detection_th: float = 0.1):
        """Copy bubbles from images in *src_path* onto same-named images in *base_path*.

        If *mask_path* is given, existing masks are used; otherwise masks are
        detected into ``output/copy/<timestamp>/mask``. Results go to
        ``output/copy/<timestamp>/with_bubble_img``.
        Raises ValueError if any given directory does not exist.
        """
        start = time.perf_counter()
        src_path = _require_dir(src_path)
        base_path = _require_dir(base_path)

        output_dir = Path("output") / "copy" / get_time_str()
        with_bubble_img_output_path = _new_dir(output_dir / "with_bubble_img")

        create_mask = not mask_path
        if create_mask:
            mask_path = _new_dir(output_dir / "mask")
        else:
            mask_path = _require_dir(mask_path)

        copy_bubble(src_path, base_path, mask_path, create_mask, with_bubble_img_output_path, model_type, blend_method, detection_th)

        _log_finished(with_bubble_img_output_path, start)

    def split(self, src_path, frag_size=1024 * 2):
        """Split every page in *src_path* into panel images under ``output/split/<timestamp>``."""
        start = time.perf_counter()
        src_path = _require_dir(src_path)

        output_dir = _new_dir(Path("output") / "split" / get_time_str())

        split_panel(src_path, output_dir, frag_size)

        _log_finished(output_dir, start)

    def combine(self, src_path, frag_path):
        """Paste panel fragments in *frag_path* back onto the pages in *src_path*.

        Results are written to ``output/combine/<timestamp>``.
        """
        start = time.perf_counter()
        src_path = _require_dir(src_path)
        frag_path = _require_dir(frag_path)

        output_dir = _new_dir(Path("output") / "combine" / get_time_str())

        combine_panel(src_path, frag_path, output_dir)

        _log_finished(output_dir, start)

    def remove_proj(self, proj_dir_path, model_type: int = 0, detection_th: float = 0.1, split=True, frag_size=1024 * 2):
        """Project-directory variant of :meth:`remove`.

        Reads ``<proj>/src`` and writes masks to ``<proj>/mask/<timestamp>``.
        With *split* (default) the cleaned pages go to
        ``<proj>/pre_split/<timestamp>`` and are cut into panels in
        ``<proj>/cleaned/<timestamp>``. The resulting ``split_info.txt`` is
        also copied into ``<proj>/base`` so :meth:`copy_proj` recombines the
        panels. Without *split* the cleaned pages go straight to
        ``<proj>/cleaned/<timestamp>``.
        """
        start = time.perf_counter()
        proj_dir_path = _require_dir(proj_dir_path)
        src_path = _require_dir(proj_dir_path / "src")

        create_proj_dirs(proj_dir_path)

        time_str = get_time_str()
        mask_output_path = _new_dir(proj_dir_path / "mask" / time_str)

        if not split:
            clean_img_output_path = _new_dir(proj_dir_path / "cleaned" / time_str)
            remove_bubble(src_path, mask_output_path, clean_img_output_path, model_type, detection_th)
            output_path = clean_img_output_path
        else:
            pre_split_path = _new_dir(proj_dir_path / "pre_split" / time_str)
            split_output_path = _new_dir(proj_dir_path / "cleaned" / time_str)
            remove_bubble(src_path, mask_output_path, pre_split_path, model_type, detection_th)

            split_panel(pre_split_path, split_output_path, frag_size)

            # copy_proj checks base/split_info.txt to decide whether to recombine panels.
            shutil.copy(split_output_path / "split_info.txt", proj_dir_path / "base" / "split_info.txt")
            output_path = split_output_path

        _log_finished(output_path, start)

    def copy_proj(self, proj_dir_path, model_type: int = 0, blend_method: int = -1, detection_th: float = 0.1):
        """Project-directory variant of :meth:`copy`.

        Bubbles from ``<proj>/src`` are pasted onto ``<proj>/base`` images;
        results go to ``<proj>/base_with_bubble/<timestamp>``. If
        ``<proj>/base/split_info.txt`` exists, base holds panel fragments which
        are first recombined into ``<proj>/combined/<timestamp>``.
        """
        start = time.perf_counter()
        proj_dir_path = _require_dir(proj_dir_path)
        src_path = _require_dir(proj_dir_path / "src")
        base_path = _require_dir(proj_dir_path / "base")

        create_proj_dirs(proj_dir_path)

        time_str = get_time_str()
        mask_output_path = _new_dir(proj_dir_path / "mask" / time_str)
        with_bubble_img_output_path = _new_dir(proj_dir_path / "base_with_bubble" / time_str)

        if (base_path / "split_info.txt").is_file():
            combine_output_path = _new_dir(proj_dir_path / "combined" / time_str)
            combine_panel(src_path, base_path, combine_output_path)
            base_path = combine_output_path

        copy_bubble(src_path, base_path, mask_output_path, True, with_bubble_img_output_path, model_type, blend_method, detection_th)

        _log_finished(with_bubble_img_output_path, start)
217 |
218 |
# Only start the CLI when run as a script, so importing main.py (e.g. from
# tests or another tool) does not parse sys.argv and execute a command.
if __name__ == "__main__":
    fire.Fire(Command)
220 |
--------------------------------------------------------------------------------
/bubble_tool/bubble_tool.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from pathlib import Path
4 | import os
5 |
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 | from tqdm import tqdm
11 |
12 | import torch
13 | from ultralytics import YOLO
14 | from simple_lama_inpainting import SimpleLama
15 |
16 |
# Module logger: records at INFO and above go to a stream handler (stderr by
# default) formatted with timestamp, logger name, source line, function and level.
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)s %(funcName)s [%(levelname)s]: %(message)s")
handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
23 |
24 | #############################################################################
# YOLO segmentation weights, indexed by the ``model_type`` option.
YOLO_SEG_MODEL_LOCATION = [
    "data/models/yolov8m_seg-speech-bubble/model.pt",          # model_type=0: fetched by prepare_yolo()
    "data/models/adetailerForTextSpeech_v20/unwantedV10x.pt",  # model_type=1: manual download required
]

#############################################################################

# Outline drawn around pasted bubbles by blend_image_A.
BUBBLE_BORDER_COLOR = (0, 0, 0)  # black
BUBBLE_BORDER_WIDTH = 0          # pixels; 0 = derive from the image width
34 |
35 | #############################################################################
36 |
37 |
38 |
def get_image_file_list(img_dir_path: Path):
    """Return the image files directly inside *img_dir_path*, sorted by path.

    A file counts as an image when its own extension (case-insensitive) is
    .jpg, .jpeg, .png or .webp. Only the file's suffix is inspected, never the
    parent directory names. A directory called e.g. "pages.png_dir" therefore
    does not pull in unrelated files such as ``split_info.txt``.
    Sub-directories are skipped.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
    return sorted(
        p for p in img_dir_path.glob("*")
        if p.suffix.lower() in image_suffixes and p.is_file()
    )
42 |
def resize_img(img, size_xy):
    """Resize *img* to ``size_xy`` given as (width, height).

    INTER_AREA is used when the image is taller than the target (shrinking),
    INTER_CUBIC otherwise. Only the height is compared.
    """
    shrinking = img.shape[0] > size_xy[1]
    interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_CUBIC
    return cv2.resize(img, size_xy, interpolation=interpolation)
48 |
def prepare_yolo():
    """Make sure the default speech-bubble segmentation weights are on disk.

    Downloads ``model.pt`` from the ``kitsumed/yolov8m_seg-speech-bubble``
    Hugging Face repo into ``data/models/yolov8m_seg-speech-bubble`` unless
    the file is already present.
    """
    import os
    from pathlib import PurePosixPath

    from huggingface_hub import hf_hub_download

    model_dir = "data/models/yolov8m_seg-speech-bubble"
    os.makedirs(model_dir, exist_ok=True)

    for rel_path in map(Path, ["model.pt"]):
        if not os.path.exists(model_dir / rel_path):
            hf_hub_download(
                repo_id="kitsumed/yolov8m_seg-speech-bubble",
                subfolder=PurePosixPath(rel_path.parent),
                filename=PurePosixPath(rel_path.name),
                local_dir=model_dir,
            )
69 |
def prepare_lama():
    """Point the LaMa inpainter at the AnimeMangaInpainting weights.

    If the ``LAMA_MODEL`` environment variable is already set (by the user or
    by an earlier call), nothing is done. Otherwise ``model.jit.pt`` is
    downloaded from ``s9roll74/tracing_dreMaz_AnimeMangaInpainting`` into
    ``data/models/AnimeMangaInpainting`` if missing, and ``LAMA_MODEL`` is set
    to that file. The variable is presumably read by ``SimpleLama()``, which
    lama_inpaint constructs right after calling this.
    """
    import os
    from pathlib import PurePosixPath

    if os.environ.get("LAMA_MODEL") is not None:
        return

    # Imported only when a download may be needed, so a user-supplied
    # LAMA_MODEL works without touching huggingface_hub.
    from huggingface_hub import hf_hub_download

    model_dir = "data/models/AnimeMangaInpainting"
    os.makedirs(model_dir, exist_ok=True)
    for hub_file in ["model.jit.pt"]:
        path = Path(hub_file)
        if os.path.exists(model_dir / path):
            continue
        hf_hub_download(
            repo_id="s9roll74/tracing_dreMaz_AnimeMangaInpainting",
            subfolder=PurePosixPath(path.parent),
            filename=PurePosixPath(path.name),
            local_dir=model_dir,
        )

    # Set only after the weights are on disk. If a download fails, the next
    # call retries it instead of returning early and leaving LAMA_MODEL
    # pointing at a missing file.
    os.environ["LAMA_MODEL"] = "data/models/AnimeMangaInpainting/model.jit.pt"
95 |
def _export_lama_jit(
    ckpt_path="data/models/AnimeMangaInpainting/lama_large_512px.ckpt",
    out_path="data/models/AnimeMangaInpainting/model.jit.pt",
):
    """One-off maintenance helper: trace the LaMa generator to TorchScript.

    The tool never calls this function. It records how a ``model.jit.pt``
    like the one prepare_lama downloads can be produced from a LaMa-MPE
    checkpoint: the model is built without MPE on the large architecture,
    then traced with a 512x512 image and a 1-channel mask.

    NOTE(review): it imports ``bubble_tool.inpainting_lama_mpe``, which is not
    part of this repository snapshot. Locate that module before running.
    """
    from bubble_tool.inpainting_lama_mpe import LamaFourier

    def load_lama_mpe(model_path, device, use_mpe: bool = True, large_arch: bool = False) -> LamaFourier:
        model = LamaFourier(build_discriminator=False, use_mpe=use_mpe, large_arch=large_arch)
        sd = torch.load(model_path, map_location='cpu')
        model.generator.load_state_dict(sd['gen_state_dict'])
        if use_mpe:
            model.mpe.load_state_dict(sd['str_state_dict'])
        model.eval().to(device)
        return model

    model = load_lama_mpe(ckpt_path, device='cpu', use_mpe=False, large_arch=True)
    traced_model = torch.jit.trace(model.generator, (torch.zeros([1, 3, 512, 512]), torch.zeros([1, 1, 512, 512])))
    traced_model.save(out_path)
110 |
111 |
112 | # https://github.com/ultralytics/ultralytics/issues/3560
def scale_image_torch(masks, im0_shape, ratio_pad=None):
    """Undo letterbox padding on a mask stack and resize it to ``im0_shape``.

    Adapted from https://github.com/ultralytics/ultralytics/issues/3560.

    Args:
        masks (torch.Tensor): padded/resized masks, ``[c, h, w]``.
        im0_shape (tuple): target ``(height, width)`` of the original image.
        ratio_pad (tuple, optional): ``ratio_pad[0][0]`` is the gain and
            ``ratio_pad[1]`` the (x, y) padding. When None, the padding is
            derived from the two shapes.

    Returns:
        torch.Tensor: *masks* itself when ``masks.shape[1:]`` already equals
        ``im0_shape``. Otherwise the padding is cropped away and the result
        is bilinearly interpolated to ``im0_shape`` if needed.

    Raises:
        ValueError: if *masks* has fewer than three dimensions.
    """
    import torch.nn.functional as F

    ndim = len(masks.shape)
    if ndim < 3:
        raise ValueError(f'"len of masks shape" should be 3, but got {ndim}')

    cur_shape = masks.shape
    if cur_shape[1:] == im0_shape:
        return masks

    cur_h, cur_w = cur_shape[1], cur_shape[2]
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:
        # gain = old / new; padding is split evenly on both sides.
        gain = min(cur_h / im0_shape[0], cur_w / im0_shape[1])
        pad = ((cur_w - im0_shape[1] * gain) / 2, (cur_h - im0_shape[0] * gain) / 2)

    pad_x, pad_y = pad[0], pad[1]
    masks = masks[:, int(pad_y):int(cur_h - pad_y), int(pad_x):int(cur_w - pad_x)]

    if masks.shape[1:] == im0_shape:
        return masks
    return F.interpolate(masks[None], im0_shape, mode="bilinear", align_corners=False)[0]
154 |
155 |
156 |
157 | #############################################################################
158 |
def _bubble_class_masks(masks, boxes, org_size):
    """Merge per-detection YOLO masks into three uint8 0/255 masks of size *org_size*.

    Slot k holds the union of detections with class k; classes above 2 are
    merged into slot 2. Slots with no detection (or ``masks is None``) are all
    zeros.
    """
    width, height = org_size
    merged = [None, None, None]
    if masks is not None:
        for mask, box in zip(masks, boxes):
            cls = min(int(box.cls), 2)
            scaled = scale_image_torch(mask.data, (height, width)).squeeze()
            merged[cls] = scaled if merged[cls] is None else merged[cls] + scaled

    out = []
    for m in merged:
        if m is None:
            out.append(np.zeros((height, width), np.uint8))
        else:
            # Threshold instead of `astype(uint8) * 255`. That expression kept
            # the same >= 1 cut-off, but overlapping detections (sum 2, 3, ...)
            # wrapped around to 254, 253, ... in uint8.
            out.append((m.cpu().numpy() >= 1).astype(np.uint8) * 255)
    return out


def detect_bubble(src_list, mask_path: Path, model_type, detection_th, classification):
    """Detect speech bubbles in each image and save one PNG mask per image.

    Masks are written to ``mask_path / <stem>.png``. With *classification*
    True a 3-channel mask is saved where channel k holds detections of class
    k (classes above 2 share channel 2). Otherwise a single-channel union of
    all classes is saved.

    Raises:
        ValueError: if *model_type* is not a valid index into
            YOLO_SEG_MODEL_LOCATION.
    """
    if model_type == 0:
        prepare_yolo()

    # Reject negative indices too; they previously selected entries from the end.
    if not 0 <= model_type < len(YOLO_SEG_MODEL_LOCATION):
        raise ValueError(f"unknown {model_type=}")
    model = YOLO(YOLO_SEG_MODEL_LOCATION[model_type])
    model.to("cuda" if torch.cuda.is_available() else "cpu")

    for src in tqdm(src_list, desc="detect_bubble", total=len(src_list)):
        result_path = mask_path / Path(src).with_suffix(".png").name

        # Only the size is needed; close the file right away.
        with Image.open(src) as im:
            org_size = im.size

        with torch.no_grad():
            results = model.predict(src, save=False, verbose=False, conf=detection_th)

        per_class = _bubble_class_masks(results[0].masks, results[0].boxes, org_size)
        if classification:
            result_array = np.dstack(per_class)
        else:
            result_array = per_class[0] | per_class[1] | per_class[2]

        Image.fromarray(result_array).save(result_path)

    model.to("cpu")
    torch.cuda.empty_cache()
228 |
229 |
230 |
def lama_inpaint(src_list, mask_list, output_path: Path):
    """Inpaint the masked regions of each image with LaMa.

    *src_list* and *mask_list* are paired positionally (``zip``), so callers
    must pass them in matching order. Each result is written to *output_path*
    under the source file name. Before inpainting, each mask is dilated,
    blurred and re-thresholded at 125. Kernel sizes scale with image width.
    """
    prepare_lama()

    simple_lama = SimpleLama()

    pairs = list(zip(src_list, mask_list))
    for src, mask_file in tqdm(pairs, desc="lama_inpaint", total=len(pairs)):
        result_path = output_path / Path(src).name

        # convert() returns an independent image, so the files can be closed
        # immediately instead of waiting for garbage collection.
        with Image.open(src) as im:
            image = im.convert("RGB")
        with Image.open(mask_file) as im:
            mask = np.array(im.convert("L"))

        width = image.size[0]

        # Grow the mask so bubble outlines are covered too. cv2.dilate's third
        # positional parameter is ``dst``, not ``iterations``: a positional 3
        # here never meant three iterations. The single iteration actually
        # applied is spelled out explicitly.
        k = int(width * 10 / 480)
        mask = cv2.dilate(mask, np.ones((k, k), np.uint8), iterations=1)

        # Blur and re-binarise to smooth the enlarged outline.
        k = int(width * 9 / 480) // 2 * 2 + 1
        mask = cv2.GaussianBlur(mask, ksize=(k, k), sigmaX=0)
        mask[mask >= 125] = 255
        mask[mask < 125] = 0

        simple_lama(image, mask).save(result_path)

    simple_lama.model.to("cpu")

    torch.cuda.empty_cache()
264 |
265 |
def blend_image_A(org_array, mask_array, dst_array):
    """Blend masked regions of *org_array* over *dst_array*, outlining each bubble.

    The single-channel uint8 mask is cleaned (morphological opening, then
    erosion), blurred and re-binarised at 125. Every contour of the result
    whose area exceeds 0.5% of the image is drawn onto *org_array* (modified
    in place) in BUBBLE_BORDER_COLOR. The blend is then
    ``org * a + dst * (1 - a)`` with ``a = mask / 255``, returned as floats.
    All kernel sizes scale with the width of *dst_array*.
    """
    width = dst_array.shape[1]

    open_size = max(int(30 / 1000 * width), 3)
    erode_size = max(int(5 / 1000 * width), 3)
    blur_size = max(int(31 / 1000 * width) // 2 * 2 + 1, 3)
    border_width = BUBBLE_BORDER_WIDTH or max(int(3 / 1000 * width), 1)

    mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_OPEN, np.ones((open_size, open_size), np.uint8))
    mask_array = cv2.erode(mask_array, np.ones((erode_size, erode_size), np.uint8), iterations=1)

    mask_array = cv2.GaussianBlur(mask_array, ksize=(blur_size, blur_size), sigmaX=0)
    mask_array[mask_array >= 125] = 255
    mask_array[mask_array < 125] = 0

    min_area = (mask_array.shape[0] * mask_array.shape[1]) * 0.005
    contours, _ = cv2.findContours(mask_array, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    for idx, contour in enumerate(contours):
        if cv2.contourArea(contour) > min_area:
            org_array = cv2.drawContours(org_array, contours, idx, BUBBLE_BORDER_COLOR, border_width)

    alpha = mask_array if mask_array.ndim != 2 else mask_array[:, :, np.newaxis]
    alpha = alpha / 255

    return org_array * alpha + dst_array * (1 - alpha)
308 |
def blend_image_B(org_array, mask_array, dst_array):
    """Soft-blend masked regions of *org_array* over *dst_array*.

    The mask is Gaussian-blurred (kernel about 3.1% of the width, forced odd)
    and used directly as alpha: ``org * a + dst * (1 - a)`` with
    ``a = mask / 255``. Returns a float array.
    """
    blur_size = int(31 / 1000 * dst_array.shape[1]) // 2 * 2 + 1
    alpha = cv2.GaussianBlur(mask_array, ksize=(blur_size, blur_size), sigmaX=0)

    if alpha.ndim == 2:
        alpha = alpha[..., np.newaxis]
    alpha = alpha / 255

    return org_array * alpha + dst_array * (1 - alpha)
323 |
def blend_image_C(org_array, mask_array, dst_array):
    """Copy only the bright (gray > 150) filled shapes inside the mask onto *dst_array*.

    Pixels of *org_array* outside the mask are zeroed in place. The bright
    remainder is thresholded, its external contours are filled, and the
    filled area is taken from *org_array* while everything else comes from
    *dst_array*. Returns a float array.
    """
    org_array[mask_array == 0] = 0

    gray = cv2.cvtColor(org_array, cv2.COLOR_RGB2GRAY)
    _, bright = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
    outer_contours, _ = cv2.findContours(bright, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

    filled = np.zeros((org_array.shape[0], org_array.shape[1]), np.uint8)
    filled = cv2.drawContours(filled, outer_contours, -1, (255), -1)
    if filled.ndim == 2:
        filled = filled[:, :, np.newaxis]
    alpha = filled / 255

    return org_array * alpha + dst_array * (1 - alpha)
343 |
344 |
def blend_image(src_list, mask_list, base_list, output_path: Path, blend_method):
    """Composite masked regions of each source image onto its base image.

    The three lists are consumed in lock-step (``zip``). The source and mask
    are resized to the base image's size, and each result is saved into
    *output_path* under the source file name.

    blend_method:
        -1: per-class blending of a 3-channel mask. Channel 0 uses
            blend_image_A; channels 1 and 2 use blend_image_B. A
            single-channel mask (e.g. one written by ``remove``) is blended
            with blend_image_A alone.
         0: blend_image_A on the union of the mask channels.
         1: blend_image_B on the union of the mask channels.
         other: blend_image_C on the union of the mask channels.
    """
    def load_rgb(path):
        # convert("RGB") handles grayscale and palette images (2-D arrays,
        # where slicing [:, :, :3] raises IndexError) as well as RGBA (alpha
        # dropped, as the slicing did).
        with Image.open(path) as im:
            return np.array(im.convert("RGB"))

    jobs = list(zip(src_list, mask_list, base_list))
    for src, mask, base in tqdm(jobs, desc="blend_image", total=len(jobs)):
        result_path = output_path / Path(src).name

        org_array = load_rgb(src)
        dst_array = load_rgb(base)
        with Image.open(mask) as mask_im:
            mask_array = np.array(mask_im)

        if blend_method != -1 and mask_array.ndim == 3:
            mask_array = mask_array[:, :, 0] | mask_array[:, :, 1] | mask_array[:, :, 2]

        target_size = (dst_array.shape[1], dst_array.shape[0])
        org_array = resize_img(org_array, target_size)
        mask_array = resize_img(mask_array, target_size)

        if blend_method == -1:
            if mask_array.ndim == 2:
                # No per-class channels to split: treat the whole mask as class 0.
                dst_array = blend_image_A(org_array.copy(), mask_array, dst_array)
            else:
                dst_array = blend_image_A(org_array.copy(), mask_array[:, :, 0], dst_array)
                dst_array = blend_image_B(org_array.copy(), mask_array[:, :, 1], dst_array)
                dst_array = blend_image_B(org_array.copy(), mask_array[:, :, 2], dst_array)
        elif blend_method == 0:
            dst_array = blend_image_A(org_array, mask_array, dst_array)
        elif blend_method == 1:
            dst_array = blend_image_B(org_array, mask_array, dst_array)
        else:
            dst_array = blend_image_C(org_array, mask_array, dst_array)

        Image.fromarray(dst_array.astype(np.uint8)).save(result_path)
381 |
382 |
def remove_bubble(src_path: Path, mask_output_path: Path, clean_img_output_path: Path, model_type, detection_th):
    """Detect bubbles in every image of *src_path* and inpaint them away.

    Single-channel masks are written to *mask_output_path*. Cleaned images
    are written to *clean_img_output_path* under the source file names.
    """
    src_list = get_image_file_list(src_path)

    detect_bubble(src_list, mask_output_path, model_type, detection_th, False)

    # detect_bubble writes ``<stem>.png`` for every source image. Deriving the
    # mask paths from src_list keeps lama_inpaint's positional pairing
    # aligned, even when two sources share a stem or the mask directory holds
    # extra files. Re-listing the directory would not guarantee that.
    mask_list = [mask_output_path / Path(src).with_suffix(".png").name for src in src_list]

    lama_inpaint(src_list, mask_list, clean_img_output_path)
392 |
393 |
394 |
def copy_bubble(src_path: Path, base_path: Path, mask_input_path: Path, create_mask, with_bubble_img_output_path: Path, model_type, blend_method, detection_th):
    """Paste the bubbles of each source image onto the base image with the same stem.

    If *create_mask* is true, 3-channel class masks are first detected into
    *mask_input_path*. Otherwise existing ``<base stem>.png`` masks there are
    used. Base images without a same-stem source image or without a mask are
    skipped.
    """
    src_by_stem = {s.stem: s for s in get_image_file_list(src_path)}

    # Pair bases with sources up front. blend_image zips its three lists, so
    # an unfiltered base list would shift every later pair, pasting bubbles
    # onto the wrong base image.
    pairs = []
    for base in get_image_file_list(base_path):
        src = src_by_stem.get(base.stem)
        if src is not None and src.is_file():
            pairs.append((src, base))

    if create_mask:
        detect_bubble([src for src, _ in pairs], mask_input_path, model_type, detection_th, True)

    src_list, mask_list, base_list = [], [], []
    for src, base in pairs:
        mask = (mask_input_path / Path(base.name)).with_suffix(".png")
        if mask.is_file():
            src_list.append(src)
            mask_list.append(mask)
            base_list.append(base)

    blend_image(src_list, mask_list, base_list, with_bubble_img_output_path, blend_method)
418 |
419 |
420 |
def split_panel(src_path: Path, split_output_path: Path, output_size):
    """Cut each page in *src_path* into panel images using the magi model.

    Each panel is saved as ``<page index>_<panel index>.png`` (8-digit
    zero-padded) into *split_output_path*. Panels are resized so that
    width + height is about *output_size*. ``split_info.txt`` (UTF-8) records,
    for each page, a line ``<page name>,<width>,<height>`` followed by one
    ``<fragment name>,x1,y1,x2,y2`` line per panel in page coordinates. This
    is the format combine_panel parses.
    """
    from transformers import AutoModel

    page_list = get_image_file_list(src_path)
    crop_info_path = split_output_path / Path("split_info.txt")

    if not page_list:
        # Nothing to detect; still leave an (empty) split_info.txt for callers
        # that copy it, and skip loading the model.
        crop_info_path.write_text("", encoding="utf-8")
        return

    def read_image_as_np_array(image_path):
        # The detector input is a grayscale copy expanded back to 3 channels.
        with Image.open(image_path) as image:
            return np.array(image.convert("L").convert("RGB"))

    images = [read_image_as_np_array(p) for p in page_list]

    # Fall back to CPU like detect_bubble does, instead of requiring CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True).to(device)
    with torch.no_grad():
        results = model.predict_detections_and_associations(images)

    crop_info = []
    for i, (result, page_path) in enumerate(zip(results, page_list)):
        with Image.open(page_path) as im:
            crop_info.append(f"{page_path.name},{im.size[0]},{im.size[1]}")
            for j, panel in enumerate(result["panels"]):
                im_crop = im.crop(panel)
                scale = output_size / (im_crop.size[0] + im_crop.size[1])
                new_size = (int(im_crop.size[0] * scale), int(im_crop.size[1] * scale))
                im_crop = im_crop.resize(new_size, Image.LANCZOS)
                filename = f"{i:08d}_{j:08d}.png"
                im_crop.save(split_output_path / Path(filename))
                crop_info.append(f"{filename},{panel[0]},{panel[1]},{panel[2]},{panel[3]}")

    crop_info_path.write_text("\n".join(crop_info), encoding="utf-8")

    model.to("cpu")

    torch.cuda.empty_cache()
457 |
458 |
def combine_panel(src_path: Path, fragment_path: Path, output_dir: Path):
    """Paste (edited) panel fragments back onto their source pages.

    Reads ``fragment_path/split_info.txt`` in the format written by
    split_panel. For each page found in *src_path*, every available fragment
    is resized and pasted at its recorded position. Boxes are rescaled from
    the page size recorded at split time to the page's current size. A
    fragment missing under its recorded name is looked up with a ``.png``
    suffix. Pages that received at least one fragment are saved to
    *output_dir* under their original names.

    Raises:
        ValueError: if split_info.txt does not exist.
    """
    crop_info_path = fragment_path / Path("split_info.txt")
    if not crop_info_path.is_file():
        raise ValueError(f"{crop_info_path} not found")

    # split_panel writes this file as UTF-8. Read it the same way, so
    # non-ASCII page names do not depend on the platform's locale encoding.
    crop_lines = crop_info_path.read_text(encoding="utf-8").splitlines()

    def create_crop_info(lines):
        crop_info = {}
        cur_src_name = None
        cur_item = {}

        for line in lines:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            fields = line.split(",")
            if len(fields) == 3:
                # Page header: <name>,<width>,<height>
                if cur_item:
                    crop_info[cur_src_name] = cur_item
                cur_src_name, w, h = fields
                cur_item = {"org_size": (int(w), int(h)), "frags": {}}
            else:
                # Fragment: <name>,x1,y1,x2,y2 (coordinates may be floats)
                frag_name, x1, y1, x2, y2 = fields
                cur_item["frags"][frag_name] = tuple(int(float(v)) for v in (x1, y1, x2, y2))

        if cur_item:
            crop_info[cur_src_name] = cur_item
        return crop_info

    for src_name, info in create_crop_info(crop_lines).items():
        src_img_path = src_path / Path(src_name)
        if not src_img_path.is_file():
            continue
        org_w, org_h = info["org_size"]

        with Image.open(src_img_path) as src_img:
            mod = False

            for frag_name, (x1, y1, x2, y2) in info["frags"].items():
                frag_img_path = fragment_path / Path(frag_name)
                if not frag_img_path.is_file():
                    frag_img_path = frag_img_path.with_suffix(".png")
                    if not frag_img_path.is_file():
                        continue

                scale_x = (x2 - x1) / org_w
                scale_y = (y2 - y1) / org_h
                with Image.open(frag_img_path) as frag_img:
                    frag_img = frag_img.resize(
                        (int(src_img.size[0] * scale_x), int(src_img.size[1] * scale_y)),
                        Image.LANCZOS,
                    )

                scale_x = src_img.size[0] / org_w
                scale_y = src_img.size[1] / org_h
                src_img.paste(frag_img, (int(x1 * scale_x), int(y1 * scale_y)))
                mod = True

            if mod:
                src_img.save(output_dir / Path(src_name))
526 |
527 |
--------------------------------------------------------------------------------