If you like my work, please consider showing your support on Patreon. Thank you! ❤
')
43 |
44 | return [mask_prompt,negative_mask_prompt, mask_precision, mask_padding, brush_mask_mode, mask_output, plug]
45 |
def run(self, p, mask_prompt, negative_mask_prompt, mask_precision, mask_padding, brush_mask_mode, mask_output, plug):
    """Build an inpainting mask from text prompts with CLIPSeg, then run processing.

    Each ``|``-separated part of ``mask_prompt`` is segmented, thresholded and
    merged (lighter pixel wins) into one mask; each part of
    ``negative_mask_prompt`` is segmented, thresholded, inverted and merged
    (darker pixel wins) to carve areas back out. The resulting mask is
    installed on ``p`` and ``processing.process_images(p)`` is run.

    Args:
        p: processing object; ``p.init_images[0]`` is the image segmented and
            ``p.image_mask`` may hold a user-drawn brush mask.
        mask_prompt: ``|``-separated prompts for the areas to mask.
        negative_mask_prompt: ``|``-separated prompts for areas to exclude;
            an empty string disables negative masking.
        mask_precision: threshold passed to ``cv2.threshold`` on the 8-bit
            grayscale rendering of each prediction.
        mask_padding: grows the mask by scaling it so its width increases by
            ``2 * mask_padding`` pixels (height scaled to keep the aspect
            ratio) before center-cropping back; 0 disables padding.
        brush_mask_mode: 1 = add the brush mask to the generated mask,
            2 = subtract the brush mask from it, otherwise the brush mask is
            ignored.
        mask_output: if truthy, the final mask is appended to the output images.
        plug: not used by this method.

    Returns:
        The result of ``processing.process_images(p)``.
    """
    import io

    def download_file(filename, url):
        """Download ``url`` to ``filename`` without ever leaving a truncated file.

        Data is streamed into ``<filename>.part`` and only renamed onto
        ``filename`` once the whole body was written. Previously a failed
        download left a partial file behind, and because get_mask only checks
        ``os.path.exists`` the corrupt weights were never re-downloaded.
        """
        tmp_filename = f"{filename}.part"
        try:
            # timeout bounds connect and each read stall, not the total time.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(tmp_filename, 'wb') as fout:
                    for block in response.iter_content(4096):
                        fout.write(block)
            os.replace(tmp_filename, filename)
        finally:
            # Only present if something failed before the rename.
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)

    def gray_to_pil(img):
        """Convert a single-channel OpenCV/numpy image to an RGBA PIL image."""
        return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_GRAY2RGBA))

    def center_crop(img, new_width, new_height):
        """Return the centered ``new_width`` x ``new_height`` region of ``img``."""
        width, height = img.size
        left = (width - new_width) / 2
        top = (height - new_height) / 2
        right = (width + new_width) / 2
        bottom = (height + new_height) / 2
        return img.crop((left, top, right, bottom))

    def overlay_mask_part(img_a, img_b, mode):
        """Merge two mask images: darker pixel wins for mode 0, lighter otherwise.

        ImageChops only processes the overlapping top-left region when the two
        images differ in size, which cropped (instead of scaled) a
        full-resolution brush mask against the smaller prediction mask, so
        ``img_b`` is resized to ``img_a``'s size first.
        """
        if img_b.size != img_a.size:
            img_b = img_b.resize(img_a.size, Image.NEAREST)
        if mode == 0:
            return ImageChops.darker(img_a, img_b)
        return ImageChops.lighter(img_a, img_b)

    def process_mask_parts(these_preds, these_prompt_parts, mode, final_img=None):
        """Threshold each prompt part's prediction and merge them into one mask.

        Args:
            these_preds: model output; ``these_preds[i][0]`` holds the raw
                scores for prompt part ``i`` (sigmoid is applied here).
            these_prompt_parts: number of prompt parts to process.
            mode: 1 for positive prompts (merged lighter), 0 for negative
                prompts (inverted, merged darker).
            final_img: optional RGBA mask to merge the parts into.

        Returns:
            The merged RGBA PIL image.
        """
        for i in range(these_prompt_parts):
            filename = f"mask_{mode}_{i}.png"

            # Render exactly what plt.imsave would write to disk, but in memory:
            # the pixels (and therefore the meaning of mask_precision) are
            # unchanged, while the working directory is no longer littered and
            # concurrent runs no longer overwrite each other's files.
            buffer = io.BytesIO()
            plt.imsave(buffer, torch.sigmoid(these_preds[i][0]), format="png")
            if debug:
                with open(filename, "wb") as fout:
                    fout.write(buffer.getvalue())
            encoded = numpy.frombuffer(buffer.getvalue(), dtype=numpy.uint8)
            img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)

            gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            (thresh, bw_image) = cv2.threshold(gray_image, mask_precision, 255, cv2.THRESH_BINARY)

            # Negative parts mark areas to keep out, so invert them before the darker merge.
            if mode == 0:
                bw_image = numpy.invert(bw_image)

            if debug:
                print(f"bw_image: {bw_image}")
                print(f"final_img: {final_img}")

            # overlay mask parts
            bw_image = gray_to_pil(bw_image)
            if i > 0 or final_img is not None:
                bw_image = overlay_mask_part(bw_image, final_img, mode)

            # For debugging only:
            if debug:
                bw_image.save(f"processed_{filename}")

            final_img = bw_image

        return final_img

    def get_mask():
        """Segment ``p.init_images[0]`` with CLIPSeg and return the final RGBA mask.

        NOTE(review): the model is rebuilt and its weights reloaded on every
        call; caching it would speed up repeated runs.
        """
        # load model
        model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
        model.eval()
        model_dir = "./repositories/clipseg/weights"
        os.makedirs(model_dir, exist_ok=True)
        d64_file = f"{model_dir}/rd64-uni.pth"
        d16_file = f"{model_dir}/rd16-uni.pth"
        delimiter_string = "|"

        # Download model weights if we don't have them yet
        if not os.path.exists(d64_file):
            print("Downloading clipseg model weights...")
            download_file(d64_file, "https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download?path=%2F&files=rd64-uni.pth")
            download_file(d16_file, "https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download?path=%2F&files=rd16-uni.pth")
            # Mirror:
            # https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth
            # https://github.com/timojl/clipseg/raw/master/weights/rd16-uni.pth

        # non-strict, because we only stored decoder weights (not CLIP weights).
        # The model and its input tensor stay on the CPU here, so load the
        # weights there too: mapping to 'cuda' only added a GPU round-trip and
        # failed outright on machines without CUDA.
        model.load_state_dict(torch.load(d64_file, map_location=torch.device('cpu')), strict=False)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((512, 512)),
        ])
        # Normalize expects exactly 3 channels; RGBA/L init images used to crash here.
        img = transform(p.init_images[0].convert("RGB")).unsqueeze(0)

        prompts = mask_prompt.split(delimiter_string)
        prompt_parts = len(prompts)

        # predict
        with torch.no_grad():
            preds = model(img.repeat(prompt_parts, 1, 1, 1), prompts)[0]
            # Only spend a forward pass on negative prompts when they are used.
            negative_preds = None
            if negative_mask_prompt:
                negative_prompts = negative_mask_prompt.split(delimiter_string)
                negative_prompt_parts = len(negative_prompts)
                negative_preds = model(img.repeat(negative_prompt_parts, 1, 1, 1), negative_prompts)[0]

        if debug:
            print("Check initial mask vars before processing...")
            print(f"p.image_mask: {p.image_mask}")
            print(f"p.latent_mask: {p.latent_mask}")
            print(f"p.mask_for_overlay: {p.mask_for_overlay}")

        # Brush mode 1: start from the brush mask so the prompt masks are added to it.
        if brush_mask_mode == 1 and p.image_mask is not None:
            final_img = p.image_mask.convert("RGBA")
        else:
            final_img = None

        # process masking
        final_img = process_mask_parts(preds, prompt_parts, 1, final_img)

        # Brush mode 2: subtract the brush mask. ImageOps.invert only supports
        # L/RGB images, so normalise first; p.image_mask itself is left untouched.
        if brush_mask_mode == 2 and p.image_mask is not None:
            inverted_brush = ImageOps.invert(p.image_mask.convert("RGB")).convert("RGBA")
            final_img = overlay_mask_part(final_img, inverted_brush, 0)

        # process negative masking
        if negative_preds is not None:
            final_img = process_mask_parts(negative_preds, negative_prompt_parts, 0, final_img)

        # Increase mask size with padding: upscale, then crop the center back out.
        if mask_padding > 0:
            init_image = p.init_images[0]
            aspect_ratio = init_image.width / init_image.height
            new_width = init_image.width + mask_padding * 2
            new_height = round(new_width / aspect_ratio)
            final_img = final_img.resize((new_width, new_height))
            final_img = center_crop(final_img, init_image.width, init_image.height)

        return final_img

    # Set up processor parameters correctly
    p.mode = 1
    p.mask_mode = 1
    p.image_mask = get_mask().resize((p.init_images[0].width, p.init_images[0].height))
    p.mask_for_overlay = p.image_mask
    p.latent_mask = None  # fixes inpainting full resolution

    processed = processing.process_images(p)

    if mask_output:
        processed.images.append(p.image_mask)

    return processed
--------------------------------------------------------------------------------