├── .ipynb_checkpoints
│   ├── FuseNet _colab-checkpoint.ipynb
│   └── FuseNet-checkpoint.ipynb
├── FuseNet _colab.ipynb
├── FuseNet.ipynb
├── LICENSE
├── README.md
├── input_images
│   ├── GT
│   │   ├── Test_1_GT.bmp
│   │   ├── Test_2_GT.bmp
│   │   ├── Test_3.bmp
│   │   └── Test_4_GT.png
│   └── image
│       ├── Test_1.bmp
│       ├── Test_2.bmp
│       ├── Test_3.bmp
│       ├── Test_4.png
│       └── Test_5.png
├── model_utils.py
├── requirements.txt
└── utils.py

/.ipynb_checkpoints/FuseNet _colab-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5bd16b8a",
6 | "metadata": {},
7 | "source": [
8 | "# FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "0bbea43e",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "!git clone https://github.com/mindflow-institue/FuseNet.git\n",
19 | "%cd ./FuseNet"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "id": "b2c0cc45",
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "import argparse\n",
30 | "import torch\n",
31 | "import torch.nn as nn\n",
32 | "import torch.nn.functional as F\n",
33 | "import torch.optim as optim\n",
34 | "import torchvision\n",
35 | "import torchvision.transforms as T\n",
36 | "\n",
37 | "import cv2\n",
38 | "import sys\n",
39 | "import os\n",
40 | "import numpy as np\n",
41 | "import random\n",
42 | "import glob\n",
43 | "from matplotlib import pyplot as plt\n",
44 | "\n",
45 | "from utils import read_image, dice_metric, xor_metric, hm_metric, create_mask, cross_entropy\n",
46 | "from model_utils import Encoder, ProjectionHead, MixFFN_skip, CrossAttentionBlock\n",
47 | "\n",
48 | "from einops import rearrange\n",
49 | "from einops.layers.torch import Rearrange"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "id": "62a6f1f4",
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "use_cuda = torch.cuda.is_available()\n",
60 | "\n",
61 | "parser = argparse.ArgumentParser(description='FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation')\n",
62 | "parser.add_argument('--nChannel', metavar='N', default=64, type=int, \n",
63 | " help='number of channels')\n",
64 | "parser.add_argument('--maxIter', metavar='T', default=50, type=int, \n",
65 | " help='number of maximum iterations')\n",
66 | "parser.add_argument('--minLabels', metavar='minL', default=3, type=int, \n",
67 | " help='minimum number of labels')\n",
68 | "parser.add_argument('--lr', metavar='LR', default=0.005, type=float, \n",
69 | " help='learning rate')\n",
70 | "\n",
71 | "parser.add_argument('--input_path', metavar='INPUT', default='./input_images/', \n",
72 | " help='input image folder path')\n",
73 | "parser.add_argument('--save_output', metavar='SAVE', default=True, \n",
74 | " help='whether to save output or not')\n",
75 | "parser.add_argument('--output_path', metavar='OUTPUT', default='./output/', \n",
76 | " help='output folder path')\n",
77 | "\n",
78 | "parser.add_argument('--loss_ce_coef', metavar='CE', default=2.5, type=float, \n",
79 | " help='Cross entropy loss weighting factor')\n",
80 | "parser.add_argument('--loss_clip_coef', metavar='AT', default=0.5, type=float, \n",
81 | " help='Clip loss weighting factor')\n",
82 | "parser.add_argument('--loss_b_coef', metavar='Spatial', default=0.5, type=float, \n",
83 | " help='Boundary loss weighting factor')\n",
84 | "\n",
85 | "args = parser.parse_args(args=[])"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "id": "eefd34af",
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "if args.save_output:\n",
96 | " SAVE_PATH = args.output_path\n",
97 | " os.makedirs(SAVE_PATH, exist_ok=True)\n",
98 | "\n",
99 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "id": "8af58b72",
105 | "metadata": {},
106 | "source": [
107 | "# Loading Data"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "id": "4688dcb4",
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "IMG_PATH = args.input_path\n",
118 | "img_data = sorted(glob.glob(IMG_PATH + 'image/*'))\n",
119 | "lbl_data = sorted(glob.glob(IMG_PATH + 'GT/*'))"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "id": "2d188db8",
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "len(img_data), len(lbl_data)"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "id": "d29d88f9",
135 | "metadata": {},
136 | "source": [
137 | "# Model"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "id": "e7d3b082",
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "class Model(nn.Module):\n",
148 | " \"\"\"\n",
149 | " Args:\n",
150 | " input_dim (int): Dimension of the input data.\n",
151 | " image_embed (int): Dimension of the image embeddings.\n",
152 | " augmented_embed (int): Dimension of the augmented image embeddings.\n",
153 | " input_size (tuple): Tuple representing the input size of the images (height, width).\n",
154 | " temperature (float): Temperature parameter to scale CLIP matrix.\n",
155 | " dropout (float): Dropout rate applied in the projection heads.\n",
156 | " beta (int): Downsampling factor.\n",
157 | " alpha (int): Scaling factor applied to the main path in the cross-attention block.\n",
158 | " \"\"\"\n",
159 | " def __init__(self, input_dim, image_embed, augmented_embed, input_size=(256, 256),\n",
160 | " temperature=5.0, dropout=0.1, beta=16, alpha=3):\n",
161 | " super(Model, self).__init__()\n",
162 | " \n",
163 | " input_H, input_W = input_size\n",
164 | " self.H = input_H\n",
165 | " \n",
166 | " self.beta = beta # Downsampling factor\n",
167 | " self.alpha = alpha # Main path scaling factor\n",
168 | " self.img_enc = Encoder(input_dim, image_embed)\n",
169 | " self.aug_enc = Encoder(input_dim, image_embed)\n",
170 | " \n",
171 | " self.image_projection = ProjectionHead(embedding_dim=image_embed, projection_dim=image_embed, dropout=dropout)\n",
172 | " self.aug_projection = ProjectionHead(embedding_dim=augmented_embed, projection_dim=augmented_embed, dropout=dropout)\n",
173 | " self.temperature = temperature\n",
174 | " \n",
175 | " self.cross_attn = CrossAttentionBlock(in_channels=image_embed, key_channels=image_embed,\n",
176 | " value_channels=image_embed, height=input_H, width=input_W)\n",
177 | " \n",
178 | " \n",
179 | " self.patch_size = self.H//8 #32\n",
180 | " self.dim = image_embed\n",
181 | " patch_dim = self.dim * self.patch_size * self.patch_size\n",
182 | " \n",
183 | " self.to_patch_embedding_img = nn.Sequential(\n",
184 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n",
185 | " nn.Linear(patch_dim, self.dim))\n",
186 | " \n",
187 | " self.to_patch_embedding_aug = nn.Sequential(\n",
188 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n",
189 | " nn.Linear(patch_dim, self.dim)) \n",
190 | " \n",
191 | " self.bn1 = nn.BatchNorm2d(image_embed)\n",
192 | " self.bn2 = nn.BatchNorm2d(image_embed)\n",
193 | " \n",
194 | " \n",
195 | " def forward(self, x, augmented_x):\n",
196 | "\n",
197 | " # extract feature representations of each modality\n",
198 | " img_f = self.img_enc(x)\n",
199 | " aug_f = self.img_enc(augmented_x) \n",
200 | "\n",
201 | " img_f = rearrange(img_f, 'b c h w -> b (h w) c')\n",
202 | " aug_f = rearrange(aug_f, 'b c h w -> b (h w) c')\n",
203 | "\n",
204 | " # Getting Image and augmented image Embeddings (with same dimension)\n",
205 | " img_e = self.image_projection(img_f)\n",
206 | " aug_e = self.aug_projection(aug_f)\n",
207 | " \n",
208 | " # Calculating CLIP\n",
209 | " img_e_r = self.bn1(rearrange(img_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n",
210 | " aug_e_r = self.bn2(rearrange(aug_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n",
211 | " \n",
212 | " img_e_patch = self.to_patch_embedding_img(img_e_r) \n",
213 | " aug_e_patch = self.to_patch_embedding_aug(aug_e_r) \n",
214 | " \n",
215 | " img_e_norm = img_e_patch / img_e_patch.norm(dim=-1, keepdim=True) \n",
216 | " aug_e_norm = aug_e_patch / aug_e_patch.norm(dim=-1, keepdim=True)\n",
217 | " \n",
218 | " clip_sim = (img_e_norm @ aug_e_norm.mT) / self.temperature\n",
219 | " img_e_sim = img_e_norm @ img_e_norm.mT\n",
220 | " aug_e_sim = aug_e_norm @ aug_e_norm.mT\n",
221 | " clip_targets = F.softmax((img_e_sim + aug_e_sim) / 2 * self.temperature, dim=-1)\n",
222 | " \n",
223 | " # Cross attention\n",
224 | " attn_1 = self.cross_attn(img_e*self.alpha, aug_e*0.8)\n",
225 | " attn_2 = self.cross_attn(aug_e*0.8, img_e*self.alpha)\n",
226 | " \n",
227 | " attn = attn_1 + attn_2\n",
228 | " \n",
229 | " _, edge1 = torch.max(attn, 1)\n",
230 | " attn_down = torchvision.transforms.functional.resize(attn, 256//self.beta, antialias=True)\n",
231 | " attn_up = torchvision.transforms.functional.resize(attn_down, 256, antialias=True)\n",
232 | " _, edge2 = torch.max(attn_up, 1)\n",
233 | " edge = edge1 - edge2\n",
234 | "\n",
235 | " return edge, attn, clip_sim, clip_targets\n"
236 | ]
237 | },
238 | {
239 | "cell_type": "markdown",
240 | "id": "e45e692c",
241 | "metadata": {},
242 | "source": [
243 | "# Training"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "id": "e4808e95",
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "img_size = 256"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "id": "a4a4c435",
260 | "metadata": {
261 | "scrolled": false
262 | },
263 | "outputs": [],
264 | "source": [
265 | "for img_num, img_file in enumerate(img_data):\n",
266 | " \n",
267 | " ##### Read image #####\n",
268 | " image = read_image(img_file, img_size).to(device)\n",
269 | "\n",
270 | " ##### Load Model #####\n",
271 | " model = Model(input_dim=3, image_embed=64, augmented_embed=64,\n",
272 | " input_size=(img_size, img_size), temperature=5.0, dropout=0.1,\n",
273 | " beta=16, alpha=3).to(device)\n",
274 | " model.train()\n",
275 | "\n",
276 | " ##### Settings #####\n",
277 | " zero_img = torch.zeros(image.shape[2], image.shape[3]).to(device)\n",
278 | " \n",
279 | " loss_ce = torch.nn.CrossEntropyLoss()\n",
280 | " loss_s = torch.nn.L1Loss()\n",
281 | " \n",
282 | " optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)\n",
283 | " label_colours = np.random.randint(255, size=(128, 3))\n",
284 | " \n",
285 | " \n",
286 | " jitter = T.ColorJitter(brightness=[1.4, 1.4], hue=[-0.06, -0.06])\n",
287 | " aug_img = jitter(image)\n",
288 | " aug_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))(aug_img)\n",
289 | " aug_img = aug_img.to(device)\n",
290 | " \n",
291 | " ##### Training #####\n",
292 | " for batch_idx in range(args.maxIter):\n",
293 | "\n",
294 | " optimizer.zero_grad()\n",
295 | " edge, output, clip_logits, clip_targets = model(image, aug_img)\n",
296 | " \n",
297 | " ### Output\n",
298 | " output, clip_logits, clip_targets = output[0], clip_logits[0], clip_targets[0] \n",
299 | " output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n",
300 | " \n",
301 | " _, target = torch.max(output, 1)\n",
302 | " img_target = target.data.cpu().numpy()\n",
303 | " img_target_rgb = np.array([label_colours[c % args.nChannel] for c in img_target])\n",
304 | " img_target_rgb = img_target_rgb.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n",
305 | " \n",
306 | " ### Cross-entropy loss function \n",
307 | " loss_ce_value = args.loss_ce_coef * loss_ce(output, target)\n",
308 | " \n",
309 | " ### Boundary Loss\n",
310 | " loss_edge = args.loss_b_coef * loss_s(edge[0], zero_img) \n",
311 | " \n",
312 | " ### CLIP loss \n",
313 | " aug_loss = cross_entropy(clip_logits, clip_targets, 'mean')\n",
314 | " img_loss = cross_entropy(clip_logits.T, clip_targets.T, 'mean')\n",
315 | " loss_clip = args.loss_clip_coef * ((img_loss + aug_loss) / 2.0)\n",
316 | " \n",
317 | " ### Optimization \n",
318 | " loss = loss_ce_value + loss_clip + loss_edge\n",
319 | " loss.backward()\n",
320 | " optimizer.step()\n",
321 | " \n",
322 | " \n",
323 | " nLabels = len(np.unique(img_target))\n",
324 | " print(batch_idx, '/', args.maxIter, '|', ' label num:', nLabels, ' | loss:', round(loss.item(), 4),\n",
325 | " '| CE:', round(loss_ce_value.item(), 4), '| CLIP:', round(loss_clip.item(), 4),\n",
326 | " '| B:', round(loss_edge.item(), 4))\n",
327 | " \n",
328 | " if nLabels <= args.minLabels and batch_idx>=5:\n",
329 | " print(f\"Number of labels has reached {nLabels}\")\n",
330 | " break\n",
331 | " \n",
332 | "\n",
333 | " ##### Evaluate #####\n",
334 | " edge, output, _, _ = model(image, aug_img)\n",
335 | " output = output[0].permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n",
336 | " _, target = torch.max(output, 1)\n",
337 | " img_target = target.data.cpu().numpy()\n",
338 | " img_eval_output = np.array([label_colours[c % args.nChannel] for c in img_target])\n",
339 | " img_eval_output = img_eval_output.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n",
340 | " \n",
341 | " \n",
342 | " ##### Visualization #####\n",
343 | " fig, axes = plt.subplots(1, 4, figsize=(8, 8))\n",
344 | " axes[0].imshow(img_eval_output)\n",
345 | " axes[1].imshow(image[0].permute(1, 2, 0).cpu().detach().numpy()[..., ::-1])\n",
346 | " axes[2].imshow(aug_img[0].permute(1, 2, 0).cpu().detach().numpy()[...,::-1])\n",
347 | " axes[3].imshow(edge[0].cpu().detach().numpy())\n",
348 | " axes[0].set_title('Prediction')\n",
349 | " axes[1].set_title('Input Image')\n",
350 | " axes[2].set_title('Augmented Image')\n",
351 | " axes[3].set_title('Edge SR') \n",
352 | " axes[0].axis('off')\n",
353 | " axes[1].axis('off')\n",
354 | " axes[2].axis('off')\n",
355 | " axes[3].axis('off')\n",
356 | " plt.show()\n",
357 | " \n",
358 | " if args.save_output:\n",
359 | " name = os.path.basename(img_file).split('.')[0]\n",
360 | " cv2.imwrite(SAVE_PATH + '/FuseNet_mask_' + name + '.png', img_eval_output)\n",
361 | " cv2.imwrite(SAVE_PATH + '/FuseNet_img_' + name + '.png', image[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n",
362 | " cv2.imwrite(SAVE_PATH + '/FuseNet_aug_' + name + '.png', aug_img[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n",
363 | " \n",
364 | " print('-------------------------------', '\\n')"
365 | ]
366 | }
367 | ],
368 | "metadata": {
369 | "kernelspec": {
370 | "display_name": "Python 3 (ipykernel)",
371 | "language": "python",
372 | "name": "python3"
373 | },
374 | "language_info": {
375 | "codemirror_mode": {
376 | "name": "ipython",
377 | "version": 3
378 | },
379 | "file_extension": ".py",
380 | "mimetype": "text/x-python",
381 | "name": "python",
382 | "nbconvert_exporter": "python",
383 | "pygments_lexer": "ipython3",
384 | "version": "3.11.3"
385 | }
386 | },
387 | "nbformat": 4,
388 | "nbformat_minor": 5
389 | }
390 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/FuseNet-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5bd16b8a",
6 | "metadata": {},
7 | "source": [
8 | "# FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "b2c0cc45",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import argparse\n",
19 | "import torch\n",
20 | "import torch.nn as nn\n",
21 | "import torch.nn.functional as F\n",
22 | "import torch.optim as optim\n",
23 | "import torchvision\n",
24 | "import torchvision.transforms as T\n",
25 | "\n",
26 | "import cv2\n",
27 | "import sys\n",
28 | "import os\n",
29 | "import numpy as np\n",
30 | "import random\n",
31 | "import glob\n",
32 | "from matplotlib import pyplot as plt\n",
33 | "\n",
34 | "from utils import read_image, dice_metric, xor_metric, hm_metric, create_mask, cross_entropy\n",
35 | "from model_utils import Encoder, ProjectionHead, MixFFN_skip, CrossAttentionBlock\n",
36 | "\n",
37 | "from einops import rearrange\n",
38 | "from einops.layers.torch import Rearrange"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "id": "62a6f1f4",
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "use_cuda = torch.cuda.is_available()\n",
49 | "\n",
50 | "parser = argparse.ArgumentParser(description='FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation')\n",
51 | "parser.add_argument('--nChannel', metavar='N', default=64, type=int, \n",
52 | " help='number of channels')\n",
53 | "parser.add_argument('--maxIter', metavar='T', default=50, type=int, \n",
54 | " help='number of maximum iterations')\n",
55 | "parser.add_argument('--minLabels', metavar='minL', default=3, type=int, \n",
56 | " help='minimum number of labels')\n",
57 | "parser.add_argument('--lr', metavar='LR', default=0.005, type=float, \n",
58 | " help='learning rate')\n",
59 | "\n",
60 | "parser.add_argument('--input_path', metavar='INPUT', default='./input_images/', \n",
61 | " help='input image folder path')\n",
62 | "parser.add_argument('--save_output', metavar='SAVE', default=True, \n",
63 | " help='whether to save output or not')\n",
64 | "parser.add_argument('--output_path', metavar='OUTPUT', default='./output/', \n",
65 | " help='output folder path')\n",
66 | "\n",
67 | "parser.add_argument('--loss_ce_coef', metavar='CE', default=2.5, type=float, \n",
68 | " help='Cross entropy loss weighting factor')\n",
69 | "parser.add_argument('--loss_clip_coef', metavar='AT', default=0.5, type=float, \n",
70 | " help='Clip loss weighting factor')\n",
71 | "parser.add_argument('--loss_b_coef', metavar='Spatial', default=0.5, type=float, \n",
72 | " help='Boundary loss weighting factor')\n",
73 | "\n",
74 | "args = parser.parse_args(args=[])"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "eefd34af",
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "if args.save_output:\n",
85 | " SAVE_PATH = args.output_path\n",
86 | " os.makedirs(SAVE_PATH, exist_ok=True)\n",
87 | "\n",
88 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "id": "8af58b72",
94 | "metadata": {},
95 | "source": [
96 | "# Loading Data"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "4688dcb4",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "IMG_PATH = args.input_path\n",
107 | "img_data = sorted(glob.glob(IMG_PATH + 'image/*'))\n",
108 | "lbl_data = sorted(glob.glob(IMG_PATH + 'GT/*'))"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "id": "2d188db8",
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "len(img_data), len(lbl_data)"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "id": "d29d88f9",
124 | "metadata": {},
125 | "source": [
126 | "# Model"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "id": "e7d3b082",
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "class Model(nn.Module):\n",
137 | " \"\"\"\n",
138 | " Args:\n",
139 | " input_dim (int): Dimension of the input data.\n",
140 | " image_embed (int): Dimension of the image embeddings.\n",
141 | " augmented_embed (int): Dimension of the augmented image embeddings.\n",
142 | " input_size (tuple): Tuple representing the input size of the images (height, width).\n",
143 | " temperature (float): Temperature parameter to scale CLIP matrix.\n",
144 | " dropout (float): Dropout rate applied in the projection heads.\n",
145 | " beta (int): Downsampling factor.\n",
146 | " alpha (int): Scaling factor applied to the main path in the cross-attention block.\n",
147 | " \"\"\"\n",
148 | " def __init__(self, input_dim, image_embed, augmented_embed, input_size=(256, 256),\n",
149 | " temperature=5.0, dropout=0.1, beta=16, alpha=3):\n",
150 | " super(Model, self).__init__()\n",
151 | " \n",
152 | " input_H, input_W = input_size\n",
153 | " self.H = input_H\n",
154 | " \n",
155 | " self.beta = beta # Downsampling factor\n",
156 | " self.alpha = alpha # Main path scaling factor\n",
157 | " self.img_enc = Encoder(input_dim, image_embed)\n",
158 | " self.aug_enc = Encoder(input_dim, image_embed)\n",
159 | " \n",
160 | " self.image_projection = ProjectionHead(embedding_dim=image_embed, projection_dim=image_embed, dropout=dropout)\n",
161 | " self.aug_projection = ProjectionHead(embedding_dim=augmented_embed, projection_dim=augmented_embed, dropout=dropout)\n",
162 | " self.temperature = temperature\n",
163 | " \n",
164 | " self.cross_attn = CrossAttentionBlock(in_channels=image_embed, key_channels=image_embed,\n",
165 | " value_channels=image_embed, height=input_H, width=input_W)\n",
166 | " \n",
167 | " \n",
168 | " self.patch_size = self.H//8 #32\n",
169 | " self.dim = image_embed\n",
170 | " patch_dim = self.dim * self.patch_size * self.patch_size\n",
171 | " \n",
172 | " self.to_patch_embedding_img = nn.Sequential(\n",
173 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n",
174 | " nn.Linear(patch_dim, self.dim))\n",
175 | " \n",
176 | " self.to_patch_embedding_aug = nn.Sequential(\n",
177 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n",
178 | " nn.Linear(patch_dim, self.dim)) \n",
179 | " \n",
180 | " self.bn1 = nn.BatchNorm2d(image_embed)\n",
181 | " self.bn2 = nn.BatchNorm2d(image_embed)\n",
182 | " \n",
183 | " \n",
184 | " def forward(self, x, augmented_x):\n",
185 | "\n",
186 | " # extract feature representations of each modality\n",
187 | " img_f = self.img_enc(x)\n",
188 | " aug_f = self.img_enc(augmented_x) \n",
189 | "\n",
190 | " img_f = rearrange(img_f, 'b c h w -> b (h w) c')\n",
191 | " aug_f = rearrange(aug_f, 'b c h w -> b (h w) c')\n",
192 | "\n",
193 | " # Getting Image and augmented image Embeddings (with same dimension)\n",
194 | " img_e = self.image_projection(img_f)\n",
195 | " aug_e = self.aug_projection(aug_f)\n",
196 | " \n",
197 | " # Calculating CLIP\n",
198 | " img_e_r = self.bn1(rearrange(img_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n",
199 | " aug_e_r = self.bn2(rearrange(aug_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n",
200 | " \n",
201 | " img_e_patch = self.to_patch_embedding_img(img_e_r) \n",
202 | " aug_e_patch = self.to_patch_embedding_aug(aug_e_r) \n",
203 | " \n",
204 | " img_e_norm = img_e_patch / img_e_patch.norm(dim=-1, keepdim=True) \n",
205 | " aug_e_norm = aug_e_patch / aug_e_patch.norm(dim=-1, keepdim=True)\n",
206 | " \n",
207 | " clip_sim = (img_e_norm @ aug_e_norm.mT) / self.temperature\n",
208 | " img_e_sim = img_e_norm @ img_e_norm.mT\n",
209 | " aug_e_sim = aug_e_norm @ aug_e_norm.mT\n",
210 | " clip_targets = F.softmax((img_e_sim + aug_e_sim) / 2 * self.temperature, dim=-1)\n",
211 | " \n",
212 | " # Cross attention\n",
213 | " attn_1 = self.cross_attn(img_e*self.alpha, aug_e*0.8)\n",
214 | " attn_2 = self.cross_attn(aug_e*0.8, img_e*self.alpha)\n",
215 | " \n",
216 | " attn = attn_1 + attn_2\n",
217 | " \n",
218 | " _, edge1 = torch.max(attn, 1)\n",
219 | " attn_down = torchvision.transforms.functional.resize(attn, 256//self.beta, antialias=True)\n",
220 | " attn_up = torchvision.transforms.functional.resize(attn_down, 256, antialias=True)\n",
221 | " _, edge2 = torch.max(attn_up, 1)\n",
222 | " edge = edge1 - edge2\n",
223 | "\n",
224 | " return edge, attn, clip_sim, clip_targets\n"
225 | ]
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "id": "e45e692c",
230 | "metadata": {},
231 | "source": [
232 | "# Training"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "id": "e4808e95",
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "img_size = 256"
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": null,
248 | "id": "a4a4c435",
249 | "metadata": {
250 | "scrolled": false
251 | },
252 | "outputs": [],
253 | "source": [
254 | "for img_num, img_file in enumerate(img_data):\n",
255 | " \n",
256 | " ##### Read image #####\n",
257 | " image = read_image(img_file, img_size).to(device)\n",
258 | "\n",
259 | " ##### Load Model #####\n",
260 | " model = Model(input_dim=3, image_embed=64, augmented_embed=64,\n",
261 | " input_size=(img_size, img_size), temperature=5.0, dropout=0.1,\n",
262 | " beta=16, alpha=3).to(device)\n",
263 | " model.train()\n",
264 | "\n",
265 | " ##### Settings #####\n",
266 | " zero_img = torch.zeros(image.shape[2], image.shape[3]).to(device)\n",
267 | " \n",
268 | " loss_ce = torch.nn.CrossEntropyLoss()\n",
269 | " loss_s = torch.nn.L1Loss()\n",
270 | " \n",
271 | " optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)\n",
272 | " label_colours = np.random.randint(255, size=(128, 3))\n",
273 | " \n",
274 | " \n",
275 | " jitter = T.ColorJitter(brightness=[1.4, 1.4], hue=[-0.06, -0.06])\n",
276 | " aug_img = jitter(image)\n",
277 | " aug_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))(aug_img)\n",
278 | " aug_img = aug_img.to(device)\n",
279 | " \n",
280 | " ##### Training #####\n",
281 | " for batch_idx in range(args.maxIter):\n",
282 | "\n",
283 | " optimizer.zero_grad()\n",
284 | " edge, output, clip_logits, clip_targets = model(image, aug_img)\n",
285 | " \n",
286 | " ### Output\n",
287 | " output, clip_logits, clip_targets = output[0], clip_logits[0], clip_targets[0] \n",
288 | " output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n",
289 | " \n",
290 | " _, target = torch.max(output, 1)\n",
291 | " img_target = target.data.cpu().numpy()\n",
292 | " img_target_rgb = np.array([label_colours[c % args.nChannel] for c in img_target])\n",
293 | " img_target_rgb = img_target_rgb.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n",
294 | " \n",
295 | " ### Cross-entropy loss function \n",
296 | " loss_ce_value = args.loss_ce_coef * loss_ce(output, target)\n",
297 | " \n",
298 | " ### Boundary Loss\n",
299 | " loss_edge = args.loss_b_coef * loss_s(edge[0], zero_img) \n",
300 | " \n",
301 | " ### CLIP loss \n",
302 | " aug_loss = cross_entropy(clip_logits, clip_targets, 'mean')\n",
303 | " img_loss = cross_entropy(clip_logits.T, clip_targets.T, 'mean')\n",
304 | " loss_clip = args.loss_clip_coef * ((img_loss + aug_loss) / 2.0)\n",
305 | " \n",
306 | " ### Optimization \n",
307 | " loss = loss_ce_value + loss_clip + loss_edge\n",
308 | " loss.backward()\n",
309 | " optimizer.step()\n",
310 | " \n",
311 | " \n",
312 | " nLabels = len(np.unique(img_target))\n",
313 | " print(batch_idx, '/', args.maxIter, '|', ' label num:', nLabels, ' | loss:', round(loss.item(), 4),\n",
314 | " '| CE:', round(loss_ce_value.item(), 4), '| CLIP:', round(loss_clip.item(), 4),\n",
315 | " '| B:', round(loss_edge.item(), 4))\n",
316 | " \n",
317 | " if nLabels <= args.minLabels and batch_idx>=5:\n",
318 | " print(f\"Number of labels has reached {nLabels}\")\n",
319 | " break\n",
320 | " \n",
321 | "\n",
322 | " ##### Evaluate #####\n",
323 | " edge, output, _, _ = model(image, aug_img)\n",
324 | " output = output[0].permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n",
325 | " _, target = torch.max(output, 1)\n",
326 | " img_target = target.data.cpu().numpy()\n",
327 | " img_eval_output = np.array([label_colours[c % args.nChannel] for c in img_target])\n",
328 | " img_eval_output = img_eval_output.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n",
329 | " \n",
330 | " \n",
331 | " ##### Visualization #####\n",
332 | " fig, axes = plt.subplots(1, 4, figsize=(8, 8))\n",
333 | " axes[0].imshow(img_eval_output)\n",
334 | " axes[1].imshow(image[0].permute(1, 2, 0).cpu().detach().numpy()[..., ::-1])\n",
335 | " axes[2].imshow(aug_img[0].permute(1, 2, 0).cpu().detach().numpy()[...,::-1])\n",
336 | " axes[3].imshow(edge[0].cpu().detach().numpy())\n",
337 | " axes[0].set_title('Prediction')\n",
338 | " axes[1].set_title('Input Image')\n",
339 | " axes[2].set_title('Augmented Image')\n",
340 | " axes[3].set_title('Edge SR') \n",
341 | " axes[0].axis('off')\n",
342 | " axes[1].axis('off')\n",
343 | " axes[2].axis('off')\n",
344 | " axes[3].axis('off')\n",
345 | " plt.show()\n",
346 | " \n",
347 | " if args.save_output:\n",
348 | " name = os.path.basename(img_file).split('.')[0]\n",
349 | " cv2.imwrite(SAVE_PATH + '/FuseNet_mask_' + name + '.png', img_eval_output)\n",
350 | " cv2.imwrite(SAVE_PATH + '/FuseNet_img_' + name + '.png', image[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n",
351 | " cv2.imwrite(SAVE_PATH + '/FuseNet_aug_' + name + '.png', aug_img[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n",
352 | " \n",
353 | " print('-------------------------------', '\\n')"
354 | ]
355 | }
356 | ],
357 | "metadata": {
358 | "kernelspec": {
359 | "display_name": "Python 3 (ipykernel)",
360 | "language": "python",
361 | "name": "python3"
362 | },
363 | "language_info": {
364 | "codemirror_mode": {
365 | "name": "ipython",
366 | "version": 3
367 | },
368 | "file_extension": ".py",
369 | "mimetype": "text/x-python",
370 | "name": "python",
371 | "nbconvert_exporter": "python",
372 | "pygments_lexer": "ipython3",
373 | "version": "3.11.3"
374 | }
375 | },
376 | "nbformat": 4,
377 | "nbformat_minor": 5
378 | }
379 |
--------------------------------------------------------------------------------
/FuseNet _colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5bd16b8a",
6 | "metadata": {},
7 | "source": [
8 | "# FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "0bbea43e",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "!git clone https://github.com/mindflow-institue/FuseNet.git\n",
19 | "%cd ./FuseNet"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "id": "b2c0cc45",
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "import argparse\n",
30 | "import torch\n",
31 | "import torch.nn as nn\n",
32 | "import torch.nn.functional as F\n",
33 | "import torch.optim as optim\n",
34 | "import torchvision\n",
35 | "import torchvision.transforms as T\n",
36 | "\n",
37 | "import cv2\n",
38 | "import sys\n",
39 | "import os\n",
40 | "import numpy as np\n",
41 | "import random\n",
42 | "import glob\n",
43 | "from matplotlib import pyplot as plt\n",
44 | "\n",
45 | "from utils import read_image, dice_metric, xor_metric, hm_metric, create_mask, cross_entropy\n",
46 | "from model_utils import Encoder, ProjectionHead, MixFFN_skip, CrossAttentionBlock\n",
47 | "\n",
48 | "from einops import rearrange\n",
49 | "from einops.layers.torch import Rearrange"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "id": "62a6f1f4",
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "use_cuda = torch.cuda.is_available()\n",
60 | "\n",
61 | "parser = argparse.ArgumentParser(description='FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation')\n",
62 | "parser.add_argument('--nChannel', metavar='N', default=64, type=int, \n",
63 | " help='number of channels')\n",
64 | "parser.add_argument('--maxIter', metavar='T', default=50, type=int, \n",
65 | " help='number of maximum iterations')\n",
66 | "parser.add_argument('--minLabels', metavar='minL', default=3, type=int, \n",
67 | " help='minimum number of labels')\n",
68 | "parser.add_argument('--lr', metavar='LR', default=0.005, type=float, \n",
69 | " help='learning rate')\n",
70 | "\n",
71 | "parser.add_argument('--input_path', metavar='INPUT', default='./input_images/', \n",
72 | " help='input image folder path')\n",
73 | "parser.add_argument('--save_output', metavar='SAVE', default=True, \n",
74 | " help='whether to save output or not')\n",
75 | "parser.add_argument('--output_path', metavar='OUTPUT', default='./output/', \n",
76 | " help='output folder path')\n",
77 | "\n",
78 | "parser.add_argument('--loss_ce_coef', metavar='CE', default=2.5, type=float, \n",
79 | " help='Cross entropy loss weighting factor')\n",
80 | "parser.add_argument('--loss_clip_coef', metavar='AT', default=0.5, type=float, \n",
81 | " help='Clip loss weighting factor')\n",
82 | "parser.add_argument('--loss_b_coef', metavar='Spatial', default=0.5, type=float, \n",
83 | " help='Boundary loss weighting factor')\n",
84 | "\n",
85 | "args = parser.parse_args(args=[])"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "id": "eefd34af",
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "if args.save_output:\n",
96 | " SAVE_PATH = args.output_path\n",
97 | " os.makedirs(SAVE_PATH, exist_ok=True)\n",
98 | "\n",
99 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "id": "8af58b72",
105 | "metadata": {},
106 | "source": [
107 | "# Loading Data"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "id": "4688dcb4",
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "IMG_PATH = args.input_path\n",
118 | "img_data = sorted(glob.glob(IMG_PATH + 'image/*'))\n",
119 | "lbl_data = sorted(glob.glob(IMG_PATH + 'GT/*'))"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "id": "2d188db8",
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "len(img_data), len(lbl_data)"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "id": "d29d88f9",
135 | "metadata": {},
136 | "source": [
137 | "# Model"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "id": "e7d3b082",
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "class Model(nn.Module):\n",
148 | " \"\"\"\n",
149 | " Args:\n",
150 | " input_dim (int): Dimension of the input data.\n",
151 | " image_embed (int): Dimension of the image embeddings.\n",
152 | " augmented_embed (int): Dimension of the augmented image embeddings.\n",
153 | " input_size (tuple): Tuple representing the input size of the images (height, width).\n",
154 | " temperature (float): Temperature parameter to scale CLIP matrix.\n",
155 | " dropout (float): Dropout rate applied in the projection heads.\n",
156 | " beta (int): Downsampling factor.\n",
157 | " alpha (int): Scaling factor applied to the main path in the cross-attention block.\n",
158 | " \"\"\"\n",
159 | " def __init__(self, input_dim, image_embed, augmented_embed, input_size=(256, 256),\n",
160 | " temperature=5.0, dropout=0.1, beta=16, alpha=3):\n",
161 | " super(Model, self).__init__()\n",
162 | " \n",
163 | " input_H, input_W = input_size\n",
164 | " self.H = input_H\n",
165 | " \n",
166 | " self.beta = beta # Downsampling factor\n",
167 | " self.alpha = alpha # Main path scaling factor\n",
168 | " self.img_enc = Encoder(input_dim, image_embed)\n",
169 | " self.aug_enc = Encoder(input_dim, image_embed)\n",
170 | " \n",
171 | " self.image_projection = ProjectionHead(embedding_dim=image_embed, projection_dim=image_embed, dropout=dropout)\n",
172 | " self.aug_projection = ProjectionHead(embedding_dim=augmented_embed, projection_dim=augmented_embed, dropout=dropout)\n",
173 | " self.temperature = temperature\n",
174 | " \n",
175 | " self.cross_attn = CrossAttentionBlock(in_channels=image_embed, key_channels=image_embed,\n",
176 | " value_channels=image_embed, height=input_H, width=input_W)\n",
177 | " \n",
178 | " \n",
179 | " self.patch_size = self.H//8 #32\n",
180 | " self.dim = image_embed\n",
181 | " patch_dim = self.dim * self.patch_size * self.patch_size\n",
182 | " \n",
183 | " self.to_patch_embedding_img = nn.Sequential(\n",
184 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n",
185 | " nn.Linear(patch_dim, self.dim))\n",
186 | " \n",
187 | " self.to_patch_embedding_aug = nn.Sequential(\n",
188 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n",
189 | " nn.Linear(patch_dim, self.dim)) \n",
190 | " \n",
191 | " self.bn1 = nn.BatchNorm2d(image_embed)\n",
192 | " self.bn2 = nn.BatchNorm2d(image_embed)\n",
193 | " \n",
194 | " \n",
195 | " def forward(self, x, augmented_x):\n",
196 | "\n",
197 | " # extract feature representations of each modality\n",
198 | " img_f = self.img_enc(x)\n",
199 | " aug_f = self.img_enc(augmented_x) \n",
200 | "\n",
201 | " img_f = rearrange(img_f, 'b c h w -> b (h w) c')\n",
202 | " aug_f = rearrange(aug_f, 'b c h w -> b (h w) c')\n",
203 | "\n",
204 | " # Getting Image and augmented image Embeddings (with same dimension)\n",
205 | " img_e = self.image_projection(img_f)\n",
206 | " aug_e = self.aug_projection(aug_f)\n",
207 | " \n",
208 | " # Calculating CLIP\n",
209 | " img_e_r = self.bn1(rearrange(img_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n",
210 | " aug_e_r = self.bn2(rearrange(aug_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n",
211 | " \n",
212 | " img_e_patch = self.to_patch_embedding_img(img_e_r) \n",
213 | " aug_e_patch = self.to_patch_embedding_aug(aug_e_r) \n",
214 | " \n",
215 | " img_e_norm = img_e_patch / img_e_patch.norm(dim=-1, keepdim=True) \n",
216 | " aug_e_norm = aug_e_patch / aug_e_patch.norm(dim=-1, keepdim=True)\n",
217 | " \n",
218 | " clip_sim = (img_e_norm @ aug_e_norm.mT) / self.temperature\n",
219 | " img_e_sim = img_e_norm @ img_e_norm.mT\n",
220 | " aug_e_sim = aug_e_norm @ aug_e_norm.mT\n",
221 | " clip_targets = F.softmax((img_e_sim + aug_e_sim) / 2 * self.temperature, dim=-1)\n",
222 | " \n",
223 | " # Cross attention\n",
224 | " attn_1 = self.cross_attn(img_e*self.alpha, aug_e*0.8)\n",
225 | " attn_2 = self.cross_attn(aug_e*0.8, img_e*self.alpha)\n",
226 | " \n",
227 | " attn = attn_1 + attn_2\n",
228 | " \n",
229 | " _, edge1 = torch.max(attn, 1)\n",
230 | " attn_down = torchvision.transforms.functional.resize(attn, self.H//self.beta, antialias=True)\n",
231 | " attn_up = torchvision.transforms.functional.resize(attn_down, self.H, antialias=True)\n",
232 | " _, edge2 = torch.max(attn_up, 1)\n",
233 | " edge = edge1 - edge2\n",
234 | "\n",
235 | " return edge, attn, clip_sim, clip_targets\n"
236 | ]
237 | },
238 | {
239 | "cell_type": "markdown",
240 | "id": "e45e692c",
241 | "metadata": {},
242 | "source": [
243 | "# Training"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "id": "e4808e95",
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "img_size = 256"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "id": "a4a4c435",
260 | "metadata": {
261 | "scrolled": false
262 | },
263 | "outputs": [],
264 | "source": [
265 | "for img_num, img_file in enumerate(img_data):\n",
266 | " \n",
267 | " ##### Read image #####\n",
268 | " image = read_image(img_file, img_size).to(device)\n",
269 | "\n",
270 | " ##### Load Model #####\n",
271 | " model = Model(input_dim=3, image_embed=64, augmented_embed=64,\n",
272 | " input_size=(img_size, img_size), temperature=5.0, dropout=0.1,\n",
273 | " beta=16, alpha=3).to(device)\n",
274 | " model.train()\n",
275 | "\n",
276 | " ##### Settings #####\n",
277 | " zero_img = torch.zeros(image.shape[2], image.shape[3]).to(device)\n",
278 | " \n",
279 | " loss_ce = torch.nn.CrossEntropyLoss()\n",
280 | " loss_s = torch.nn.L1Loss()\n",
281 | " \n",
282 | " optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)\n",
283 | " label_colours = np.random.randint(255, size=(128, 3))\n",
284 | " \n",
285 | " \n",
286 | " jitter = T.ColorJitter(brightness=[1.4, 1.4], hue=[-0.06, -0.06])\n",
287 | " aug_img = jitter(image)\n",
288 | " aug_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))(aug_img)\n",
289 | " aug_img = aug_img.to(device)\n",
290 | " \n",
291 | " ##### Training #####\n",
292 | " for batch_idx in range(args.maxIter):\n",
293 | "\n",
294 | " optimizer.zero_grad()\n",
295 | " edge, output, clip_logits, clip_targets = model(image, aug_img)\n",
296 | " \n",
297 | " ### Output\n",
298 | " output, clip_logits, clip_targets = output[0], clip_logits[0], clip_targets[0] \n",
299 | " output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n",
300 | " \n",
301 | " _, target = torch.max(output, 1)\n",
302 | " img_target = target.data.cpu().numpy()\n",
303 | " img_target_rgb = np.array([label_colours[c % args.nChannel] for c in img_target])\n",
304 | " img_target_rgb = img_target_rgb.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n",
305 | " \n",
306 | " ### Cross-entropy loss function \n",
307 | " loss_ce_value = args.loss_ce_coef * loss_ce(output, target)\n",
308 | " \n",
309 | " ### Boundary Loss\n",
310 | " loss_edge = args.loss_b_coef * loss_s(edge[0], zero_img) \n",
311 | " \n",
312 | " ### CLIP loss \n",
313 | " aug_loss = cross_entropy(clip_logits, clip_targets, 'mean')\n",
314 | " img_loss = cross_entropy(clip_logits.T, clip_targets.T, 'mean')\n",
315 | " loss_clip = args.loss_clip_coef * ((img_loss + aug_loss) / 2.0)\n",
316 | " \n",
317 | " ### Optimization \n",
318 | " loss = loss_ce_value + loss_clip + loss_edge\n",
319 | " loss.backward()\n",
320 | " optimizer.step()\n",
321 | " \n",
322 | " \n",
323 | " nLabels = len(np.unique(img_target))\n",
324 | " print(batch_idx, '/', args.maxIter, '|', ' label num:', nLabels, ' | loss:', round(loss.item(), 4),\n",
325 | " '| CE:', round(loss_ce_value.item(), 4), '| CLIP:', round(loss_clip.item(), 4),\n",
326 | " '| B:', round(loss_edge.item(), 4))\n",
327 | " \n",
328 | " if nLabels <= args.minLabels and batch_idx>=5:\n",
329 | " print(f\"Number of labels has reached {nLabels}\")\n",
330 | " break\n",
331 | " \n",
332 | "\n",
333 | " ##### Evaluate #####\n",
334 | " edge, output, _, _ = model(image, aug_img)\n",
335 | " output = output[0].permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n",
336 | " _, target = torch.max(output, 1)\n",
337 | " img_target = target.data.cpu().numpy()\n",
338 | " img_eval_output = np.array([label_colours[c % args.nChannel] for c in img_target])\n",
339 | " img_eval_output = img_eval_output.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n",
340 | " \n",
341 | " \n",
342 | " ##### Visualization #####\n",
343 | " fig, axes = plt.subplots(1, 4, figsize=(8, 8))\n",
344 | " axes[0].imshow(img_eval_output)\n",
345 | " axes[1].imshow(image[0].permute(1, 2, 0).cpu().detach().numpy()[..., ::-1])\n",
346 | " axes[2].imshow(aug_img[0].permute(1, 2, 0).cpu().detach().numpy()[...,::-1])\n",
347 | " axes[3].imshow(edge[0].cpu().detach().numpy())\n",
348 | " axes[0].set_title('Prediction')\n",
349 | " axes[1].set_title('Input Image')\n",
350 | " axes[2].set_title('Augmented Image')\n",
351 | " axes[3].set_title('Edge SR') \n",
352 | " axes[0].axis('off')\n",
353 | " axes[1].axis('off')\n",
354 | " axes[2].axis('off')\n",
355 | " axes[3].axis('off')\n",
356 | " plt.show()\n",
357 | " \n",
358 | " if args.save_output:\n",
359 | " name = os.path.basename(img_file).split('.')[0]\n",
360 | " cv2.imwrite(SAVE_PATH + '/FuseNet_mask_' + name + '.png', img_eval_output)\n",
361 | " cv2.imwrite(SAVE_PATH + '/FuseNet_img_' + name + '.png', image[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n",
362 | " cv2.imwrite(SAVE_PATH + '/FuseNet_aug_' + name + '.png', aug_img[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n",
363 | " \n",
364 | " print('-------------------------------', '\\n')"
365 | ]
366 | }
367 | ],
368 | "metadata": {
369 | "kernelspec": {
370 | "display_name": "Python 3 (ipykernel)",
371 | "language": "python",
372 | "name": "python3"
373 | },
374 | "language_info": {
375 | "codemirror_mode": {
376 | "name": "ipython",
377 | "version": 3
378 | },
379 | "file_extension": ".py",
380 | "mimetype": "text/x-python",
381 | "name": "python",
382 | "nbconvert_exporter": "python",
383 | "pygments_lexer": "ipython3",
384 | "version": "3.11.3"
385 | }
386 | },
387 | "nbformat": 4,
388 | "nbformat_minor": 5
389 | }
390 |
--------------------------------------------------------------------------------
/FuseNet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5bd16b8a",
6 | "metadata": {},
7 | "source": [
8 | "# FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "b2c0cc45",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import argparse\n",
19 | "import torch\n",
20 | "import torch.nn as nn\n",
21 | "import torch.nn.functional as F\n",
22 | "import torch.optim as optim\n",
23 | "import torchvision\n",
24 | "import torchvision.transforms as T\n",
25 | "\n",
26 | "import cv2\n",
27 | "import sys\n",
28 | "import os\n",
29 | "import numpy as np\n",
30 | "import random\n",
31 | "import glob\n",
32 | "from matplotlib import pyplot as plt\n",
33 | "\n",
34 | "from utils import read_image, dice_metric, xor_metric, hm_metric, create_mask, cross_entropy\n",
35 | "from model_utils import Encoder, ProjectionHead, MixFFN_skip, CrossAttentionBlock\n",
36 | "\n",
37 | "from einops import rearrange\n",
38 | "from einops.layers.torch import Rearrange"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "id": "62a6f1f4",
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "use_cuda = torch.cuda.is_available()\n",
49 | "\n",
50 | "parser = argparse.ArgumentParser(description='FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation')\n",
51 | "parser.add_argument('--nChannel', metavar='N', default=64, type=int, \n",
52 | " help='number of channels')\n",
53 | "parser.add_argument('--maxIter', metavar='T', default=50, type=int, \n",
54 | " help='number of maximum iterations')\n",
55 | "parser.add_argument('--minLabels', metavar='minL', default=3, type=int, \n",
56 | " help='minimum number of labels')\n",
57 | "parser.add_argument('--lr', metavar='LR', default=0.005, type=float, \n",
58 | " help='learning rate')\n",
59 | "\n",
60 | "parser.add_argument('--input_path', metavar='INPUT', default='./input_images/', \n",
61 | " help='input image folder path')\n",
62 | "parser.add_argument('--save_output', metavar='SAVE', default=True, \n",
63 | " help='whether to save output or not')\n",
64 | "parser.add_argument('--output_path', metavar='OUTPUT', default='./output/', \n",
65 | " help='output folder path')\n",
66 | "\n",
67 | "parser.add_argument('--loss_ce_coef', metavar='CE', default=2.5, type=float, \n",
68 | " help='Cross entropy loss weighting factor')\n",
69 | "parser.add_argument('--loss_clip_coef', metavar='AT', default=0.5, type=float, \n",
70 | " help='Clip loss weighting factor')\n",
71 | "parser.add_argument('--loss_b_coef', metavar='Spatial', default=0.5, type=float, \n",
72 | " help='Boundary loss weighting factor')\n",
73 | "\n",
74 | "args = parser.parse_args(args=[])"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "eefd34af",
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "if args.save_output:\n",
85 | " SAVE_PATH = args.output_path\n",
86 | " os.makedirs(SAVE_PATH, exist_ok=True)\n",
87 | "\n",
88 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "id": "8af58b72",
94 | "metadata": {},
95 | "source": [
96 | "# Loading Data"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "4688dcb4",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "IMG_PATH = args.input_path\n",
107 | "img_data = sorted(glob.glob(IMG_PATH + 'image/*'))\n",
108 | "lbl_data = sorted(glob.glob(IMG_PATH + 'GT/*'))"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "id": "2d188db8",
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "len(img_data), len(lbl_data)"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "id": "d29d88f9",
124 | "metadata": {},
125 | "source": [
126 | "# Model"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "id": "e7d3b082",
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "class Model(nn.Module):\n",
137 | " \"\"\"\n",
138 | " Args:\n",
139 | " input_dim (int): Dimension of the input data.\n",
140 | " image_embed (int): Dimension of the image embeddings.\n",
141 | " augmented_embed (int): Dimension of the augmented image embeddings.\n",
142 | " input_size (tuple): Tuple representing the input size of the images (height, width).\n",
143 | " temperature (float): Temperature parameter to scale CLIP matrix.\n",
144 | " dropout (float): Dropout rate applied in the projection heads.\n",
145 | " beta (int): Downsampling factor.\n",
146 | " alpha (int): Scaling factor applied to the main path in the cross-attention block.\n",
147 | " \"\"\"\n",
148 | " def __init__(self, input_dim, image_embed, augmented_embed, input_size=(256, 256),\n",
149 | " temperature=5.0, dropout=0.1, beta=16, alpha=3):\n",
150 | " super(Model, self).__init__()\n",
151 | " \n",
152 | " input_H, input_W = input_size\n",
153 | " self.H = input_H\n",
154 | " \n",
155 | " self.beta = beta # Downsampling factor\n",
156 | " self.alpha = alpha # Main path scaling factor\n",
157 | " self.img_enc = Encoder(input_dim, image_embed)\n",
158 | " self.aug_enc = Encoder(input_dim, image_embed)\n",
159 | " \n",
160 | " self.image_projection = ProjectionHead(embedding_dim=image_embed, projection_dim=image_embed, dropout=dropout)\n",
161 | " self.aug_projection = ProjectionHead(embedding_dim=augmented_embed, projection_dim=augmented_embed, dropout=dropout)\n",
162 | " self.temperature = temperature\n",
163 | " \n",
164 | " self.cross_attn = CrossAttentionBlock(in_channels=image_embed, key_channels=image_embed,\n",
165 | " value_channels=image_embed, height=input_H, width=input_W)\n",
166 | " \n",
167 | " \n",
168 | " self.patch_size = self.H//8 #32\n",
169 | " self.dim = image_embed\n",
170 | " patch_dim = self.dim * self.patch_size * self.patch_size\n",
171 | " \n",
172 | " self.to_patch_embedding_img = nn.Sequential(\n",
173 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n",
174 | " nn.Linear(patch_dim, self.dim))\n",
175 | " \n",
176 | " self.to_patch_embedding_aug = nn.Sequential(\n",
177 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n",
178 | " nn.Linear(patch_dim, self.dim)) \n",
179 | " \n",
180 | " self.bn1 = nn.BatchNorm2d(image_embed)\n",
181 | " self.bn2 = nn.BatchNorm2d(image_embed)\n",
182 | " \n",
183 | " \n",
184 | " def forward(self, x, augmented_x):\n",
185 | "\n",
186 | " # extract feature representations of each modality\n",
187 | " img_f = self.img_enc(x)\n",
188 | " aug_f = self.img_enc(augmented_x) \n",
189 | "\n",
190 | " img_f = rearrange(img_f, 'b c h w -> b (h w) c')\n",
191 | " aug_f = rearrange(aug_f, 'b c h w -> b (h w) c')\n",
192 | "\n",
193 | " # Getting Image and augmented image Embeddings (with same dimension)\n",
194 | " img_e = self.image_projection(img_f)\n",
195 | " aug_e = self.aug_projection(aug_f)\n",
196 | " \n",
197 | " # Calculating CLIP\n",
198 | " img_e_r = self.bn1(rearrange(img_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n",
199 | " aug_e_r = self.bn2(rearrange(aug_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n",
200 | " \n",
201 | " img_e_patch = self.to_patch_embedding_img(img_e_r) \n",
202 | " aug_e_patch = self.to_patch_embedding_aug(aug_e_r) \n",
203 | " \n",
204 | " img_e_norm = img_e_patch / img_e_patch.norm(dim=-1, keepdim=True) \n",
205 | " aug_e_norm = aug_e_patch / aug_e_patch.norm(dim=-1, keepdim=True)\n",
206 | " \n",
207 | " clip_sim = (img_e_norm @ aug_e_norm.mT) / self.temperature\n",
208 | " img_e_sim = img_e_norm @ img_e_norm.mT\n",
209 | " aug_e_sim = aug_e_norm @ aug_e_norm.mT\n",
210 | " clip_targets = F.softmax((img_e_sim + aug_e_sim) / 2 * self.temperature, dim=-1)\n",
211 | " \n",
212 | " # Cross attention\n",
213 | " attn_1 = self.cross_attn(img_e*self.alpha, aug_e*0.8)\n",
214 | " attn_2 = self.cross_attn(aug_e*0.8, img_e*self.alpha)\n",
215 | " \n",
216 | " attn = attn_1 + attn_2\n",
217 | " \n",
218 | " _, edge1 = torch.max(attn, 1)\n",
219 | " attn_down = torchvision.transforms.functional.resize(attn, self.H//self.beta, antialias=True)\n",
220 | " attn_up = torchvision.transforms.functional.resize(attn_down, self.H, antialias=True)\n",
221 | " _, edge2 = torch.max(attn_up, 1)\n",
222 | " edge = edge1 - edge2\n",
223 | "\n",
224 | " return edge, attn, clip_sim, clip_targets\n"
225 | ]
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "id": "e45e692c",
230 | "metadata": {},
231 | "source": [
232 | "# Training"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "id": "e4808e95",
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "img_size = 256"
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": null,
248 | "id": "a4a4c435",
249 | "metadata": {
250 | "scrolled": false
251 | },
252 | "outputs": [],
253 | "source": [
254 | "for img_num, img_file in enumerate(img_data):\n",
255 | " \n",
256 | " ##### Read image #####\n",
257 | " image = read_image(img_file, img_size).to(device)\n",
258 | "\n",
259 | " ##### Load Model #####\n",
260 | " model = Model(input_dim=3, image_embed=64, augmented_embed=64,\n",
261 | " input_size=(img_size, img_size), temperature=5.0, dropout=0.1,\n",
262 | " beta=16, alpha=3).to(device)\n",
263 | " model.train()\n",
264 | "\n",
265 | " ##### Settings #####\n",
266 | " zero_img = torch.zeros(image.shape[2], image.shape[3]).to(device)\n",
267 | " \n",
268 | " loss_ce = torch.nn.CrossEntropyLoss()\n",
269 | " loss_s = torch.nn.L1Loss()\n",
270 | " \n",
271 | " optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)\n",
272 | " label_colours = np.random.randint(255, size=(128, 3))\n",
273 | " \n",
274 | " \n",
275 | " jitter = T.ColorJitter(brightness=[1.4, 1.4], hue=[-0.06, -0.06])\n",
276 | " aug_img = jitter(image)\n",
277 | " aug_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))(aug_img)\n",
278 | " aug_img = aug_img.to(device)\n",
279 | " \n",
280 | " ##### Training #####\n",
281 | " for batch_idx in range(args.maxIter):\n",
282 | "\n",
283 | " optimizer.zero_grad()\n",
284 | " edge, output, clip_logits, clip_targets = model(image, aug_img)\n",
285 | " \n",
286 | " ### Output\n",
287 | " output, clip_logits, clip_targets = output[0], clip_logits[0], clip_targets[0] \n",
288 | " output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n",
289 | " \n",
290 | " _, target = torch.max(output, 1)\n",
291 | " img_target = target.data.cpu().numpy()\n",
292 | " img_target_rgb = np.array([label_colours[c % args.nChannel] for c in img_target])\n",
293 | " img_target_rgb = img_target_rgb.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n",
294 | " \n",
295 | " ### Cross-entropy loss function \n",
296 | " loss_ce_value = args.loss_ce_coef * loss_ce(output, target)\n",
297 | " \n",
298 | " ### Boundary Loss\n",
299 | " loss_edge = args.loss_b_coef * loss_s(edge[0], zero_img) \n",
300 | " \n",
301 | " ### CLIP loss \n",
302 | " aug_loss = cross_entropy(clip_logits, clip_targets, 'mean')\n",
303 | " img_loss = cross_entropy(clip_logits.T, clip_targets.T, 'mean')\n",
304 | " loss_clip = args.loss_clip_coef * ((img_loss + aug_loss) / 2.0)\n",
305 | " \n",
306 | " ### Optimization \n",
307 | " loss = loss_ce_value + loss_clip + loss_edge\n",
308 | " loss.backward()\n",
309 | " optimizer.step()\n",
310 | " \n",
311 | " \n",
312 | " nLabels = len(np.unique(img_target))\n",
313 | " print(batch_idx, '/', args.maxIter, '|', ' label num:', nLabels, ' | loss:', round(loss.item(), 4),\n",
314 | " '| CE:', round(loss_ce_value.item(), 4), '| CLIP:', round(loss_clip.item(), 4),\n",
315 | " '| B:', round(loss_edge.item(), 4))\n",
316 | " \n",
317 | " if nLabels <= args.minLabels and batch_idx>=5:\n",
318 | " print(f\"Number of labels has reached {nLabels}\")\n",
319 | " break\n",
320 | " \n",
321 | "\n",
322 | " ##### Evaluate #####\n",
323 | " edge, output, _, _ = model(image, aug_img)\n",
324 | " output = output[0].permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n",
325 | " _, target = torch.max(output, 1)\n",
326 | " img_target = target.data.cpu().numpy()\n",
327 | " img_eval_output = np.array([label_colours[c % args.nChannel] for c in img_target])\n",
328 | " img_eval_output = img_eval_output.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n",
329 | " \n",
330 | " \n",
331 | " ##### Visualization #####\n",
332 | " fig, axes = plt.subplots(1, 4, figsize=(8, 8))\n",
333 | " axes[0].imshow(img_eval_output)\n",
334 | " axes[1].imshow(image[0].permute(1, 2, 0).cpu().detach().numpy()[..., ::-1])\n",
335 | " axes[2].imshow(aug_img[0].permute(1, 2, 0).cpu().detach().numpy()[...,::-1])\n",
336 | " axes[3].imshow(edge[0].cpu().detach().numpy())\n",
337 | " axes[0].set_title('Prediction')\n",
338 | " axes[1].set_title('Input Image')\n",
339 | " axes[2].set_title('Augmented Image')\n",
340 | " axes[3].set_title('Edge SR') \n",
341 | " axes[0].axis('off')\n",
342 | " axes[1].axis('off')\n",
343 | " axes[2].axis('off')\n",
344 | " axes[3].axis('off')\n",
345 | " plt.show()\n",
346 | " \n",
347 | " if args.save_output:\n",
348 | " name = os.path.basename(img_file).split('.')[0]\n",
349 | " cv2.imwrite(SAVE_PATH + '/FuseNet_mask_' + name + '.png', img_eval_output)\n",
350 | " cv2.imwrite(SAVE_PATH + '/FuseNet_img_' + name + '.png', image[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n",
351 | " cv2.imwrite(SAVE_PATH + '/FuseNet_aug_' + name + '.png', aug_img[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n",
352 | " \n",
353 | " print('-------------------------------', '\\n')"
354 | ]
355 | }
356 | ],
357 | "metadata": {
358 | "kernelspec": {
359 | "display_name": "Python 3 (ipykernel)",
360 | "language": "python",
361 | "name": "python3"
362 | },
363 | "language_info": {
364 | "codemirror_mode": {
365 | "name": "ipython",
366 | "version": 3
367 | },
368 | "file_extension": ".py",
369 | "mimetype": "text/x-python",
370 | "name": "python",
371 | "nbconvert_exporter": "python",
372 | "pygments_lexer": "ipython3",
373 | "version": "3.11.3"
374 | }
375 | },
376 | "nbformat": 4,
377 | "nbformat_minor": 5
378 | }
379 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 X-MindFlow
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation
2 |
3 | [arXiv](https://arxiv.org/abs/2311.13069) [Open in Colab](https://colab.research.google.com/github/mindflow-institue/FuseNet/blob/main/FuseNet_colab.ipynb)
4 |
5 |
6 | Semantic segmentation, a crucial task in computer vision, often relies on labor-intensive and costly annotated datasets for training. In response to this challenge, we introduce FuseNet, a dual-stream framework for self-supervised semantic segmentation that eliminates the need for manual annotation. FuseNet leverages the shared semantic dependencies between the original and augmented images to create a clustering space, effectively assigning pixels to semantically related clusters, and ultimately generating the segmentation map. Additionally, FuseNet incorporates a cross-modal fusion technique that extends the principles of CLIP by replacing textual data with augmented images. This approach enables the model to learn complex visual representations, enhancing robustness against variations similar to CLIP’s text invariance. To further improve edge alignment and spatial consistency between neighboring pixels, we introduce an edge refinement loss. This loss function considers edge information to enhance spatial coherence, facilitating the grouping of nearby pixels with similar visual features. Extensive experiments on skin lesion and lung segmentation datasets demonstrate the effectiveness of our method.
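
The CLIP-style objective described above can be sketched in a few lines. This is a minimal NumPy illustration, not the repository's implementation: `clip_style_loss` and `soft_cross_entropy` are hypothetical names, the soft-label form of the cross-entropy is an assumption about the `cross_entropy` helper in `utils.py`, and the temperature of 5.0 and the 64 patch embeddings of size 64 mirror the notebook defaults.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def soft_cross_entropy(logits, targets):
    """Cross-entropy against soft targets, averaged over rows.

    Assumption: this is the form taken by the repo's `cross_entropy` helper.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float((-targets * log_probs).sum(axis=-1).mean())


def clip_style_loss(img_emb, aug_emb, temperature=5.0):
    """Symmetric contrastive loss between image and augmented-image patches."""
    # L2-normalise each patch embedding so dot products are cosine similarities.
    img = img_emb / np.linalg.norm(img_emb, axis=-1, keepdims=True)
    aug = aug_emb / np.linalg.norm(aug_emb, axis=-1, keepdims=True)

    # Cross-view similarities act as logits.
    logits = (img @ aug.T) / temperature
    # Soft targets come from the averaged within-view similarities.
    targets = softmax((img @ img.T + aug @ aug.T) / 2 * temperature, axis=-1)

    aug_loss = soft_cross_entropy(logits, targets)
    img_loss = soft_cross_entropy(logits.T, targets.T)
    return (img_loss + aug_loss) / 2.0


rng = np.random.default_rng(0)
img_emb = rng.normal(size=(64, 64))                   # 64 patches, 64-dim each
aug_emb = img_emb + 0.1 * rng.normal(size=(64, 64))   # stand-in for the augmented view
loss = clip_style_loss(img_emb, aug_emb)
print(f"CLIP-style loss: {loss:.4f}")
```

The augmented view takes the place of the text branch in CLIP, so both branches share the same embedding size and the targets can be built from image-to-image similarities alone.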
7 |
8 |
9 |
10 | 
11 |
12 |
13 |
14 |
15 |
16 | ## Updates
17 | - If you found this paper useful, please consider checking out our previously accepted papers at MIDL and ICCV:
18 | `MS-Former` [[Paper](https://openreview.net/forum?id=pp2raGSU3Wx)] [[GitHub](https://github.com/mindflow-institue/MS-Former)], and `S3-Net` [[Paper](https://openreview.net/forum?id=pp2raGSU3Wx)] [[GitHub](https://github.com/mindflow-institue/MS-Former)] ♥️✌🏻
19 |
20 | - November 22, 2023: First release of the code.
21 |
22 | ## Installation
23 |
24 | ```bash
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ## Run Demo
29 | Put your input images in the ```input_images/image``` folder and run the ```FuseNet.ipynb``` notebook ;)
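
Before running the notebook, it may help to check which input images have a ground-truth mask. The sketch below is only an illustration under stated assumptions: `pair_images_with_gt` is a hypothetical helper that is not part of the repository, and it assumes masks live in `input_images/GT` and share the image's file stem, optionally followed by `_GT`.

```python
from pathlib import Path


def pair_images_with_gt(root="./input_images"):
    """Return (image_path, gt_path_or_None) pairs, matched by file stem.

    Hypothetical helper: assumes `<root>/image` holds inputs and `<root>/GT`
    holds masks named either `<stem>.<ext>` or `<stem>_GT.<ext>`.
    """
    root = Path(root)
    gt_by_stem = {}
    for gt in sorted((root / "GT").glob("*")):
        stem = gt.stem
        if stem.endswith("_GT"):
            stem = stem[: -len("_GT")]  # drop the optional suffix
        gt_by_stem[stem] = gt
    images = sorted((root / "image").glob("*"))
    return [(img, gt_by_stem.get(img.stem)) for img in images]


if __name__ == "__main__":
    for img, gt in pair_images_with_gt():
        print(img.name, "->", gt.name if gt else "no ground truth")
```

Matching by stem avoids relying on two sorted file lists lining up, which matters when an image has no mask.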
30 |
31 | ## Experiments
32 |
33 |
34 |
35 |