├── .ipynb_checkpoints ├── FuseNet _colab-checkpoint.ipynb └── FuseNet-checkpoint.ipynb ├── FuseNet _colab.ipynb ├── FuseNet.ipynb ├── LICENSE ├── README.md ├── input_images ├── GT │ ├── Test_1_GT.bmp │ ├── Test_2_GT.bmp │ ├── Test_3.bmp │ └── Test_4_GT.png └── image │ ├── Test_1.bmp │ ├── Test_2.bmp │ ├── Test_3.bmp │ ├── Test_4.png │ └── Test_5.png ├── model_utils.py ├── requirements.txt └── utils.py /.ipynb_checkpoints/FuseNet _colab-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5bd16b8a", 6 | "metadata": {}, 7 | "source": [ 8 | "# FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation
" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "0bbea43e", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "!git clone https://github.com/mindflow-institue/FuseNet.git\n", 19 | "%cd ./FuseNet" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "id": "b2c0cc45", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import argparse\n", 30 | "import torch\n", 31 | "import torch.nn as nn\n", 32 | "import torch.nn.functional as F\n", 33 | "import torch.optim as optim\n", 34 | "import torchvision\n", 35 | "import torchvision.transforms as T\n", 36 | "\n", 37 | "import cv2\n", 38 | "import sys\n", 39 | "import os\n", 40 | "import numpy as np\n", 41 | "import random\n", 42 | "import glob\n", 43 | "from matplotlib import pyplot as plt\n", 44 | "\n", 45 | "from utils import read_image, dice_metric, xor_metric, hm_metric, create_mask, cross_entropy\n", 46 | "from model_utils import Encoder, ProjectionHead, MixFFN_skip, CrossAttentionBlock\n", 47 | "\n", 48 | "from einops import rearrange\n", 49 | "from einops.layers.torch import Rearrange" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "62a6f1f4", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "use_cuda = torch.cuda.is_available()\n", 60 | "\n", 61 | "parser = argparse.ArgumentParser(description='FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation')\n", 62 | "parser.add_argument('--nChannel', metavar='N', default=64, type=int, \n", 63 | " help='number of channels')\n", 64 | "parser.add_argument('--maxIter', metavar='T', default=50, type=int, \n", 65 | " help='number of maximum iterations')\n", 66 | "parser.add_argument('--minLabels', metavar='minL', default=3, type=int, \n", 67 | " help='minimum number of labels')\n", 68 | "parser.add_argument('--lr', metavar='LR', default=0.005, type=float, \n", 69 | " help='learning rate')\n", 70 | "\n", 71 | 
"parser.add_argument('--input_path', metavar='INPUT', default='./input_images/', \n", 72 | " help='input image folder path')\n", 73 | "parser.add_argument('--save_output', metavar='SAVE', default=True, \n", 74 | " help='whether to save output ot not')\n", 75 | "parser.add_argument('--output_path', metavar='OUTPUT', default='./output/', \n", 76 | " help='output folder path')\n", 77 | "\n", 78 | "parser.add_argument('--loss_ce_coef', metavar='CE', default=2.5, type=float, \n", 79 | " help='Cross entropy loss weighting factor')\n", 80 | "parser.add_argument('--loss_clip_coef', metavar='AT', default=0.5, type=float, \n", 81 | " help='Clip loss weighting factor')\n", 82 | "parser.add_argument('--loss_b_coef', metavar='Spatial', default=0.5, type=float, \n", 83 | " help='Boundary loss weighting factor')\n", 84 | "\n", 85 | "args = parser.parse_args(args=[])" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "eefd34af", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "if args.save_output:\n", 96 | " SAVE_PATH = args.output_path\n", 97 | " os.makedirs(SAVE_PATH, exist_ok=True)\n", 98 | "\n", 99 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "8af58b72", 105 | "metadata": {}, 106 | "source": [ 107 | "# Loading Data" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "4688dcb4", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "IMG_PATH = args.input_path\n", 118 | "img_data = sorted(glob.glob(IMG_PATH + 'image/*'))\n", 119 | "lbl_data = sorted(glob.glob(IMG_PATH + 'GT/*'))" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "2d188db8", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "len(img_data), len(lbl_data)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "id": "d29d88f9", 135 | 
"metadata": {}, 136 | "source": [ 137 | "# Model" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "e7d3b082", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "class Model(nn.Module):\n", 148 | " \"\"\"\n", 149 | " Args:\n", 150 | " input_dim (int): Dimension of the input data.\n", 151 | " image_embed (int): Dimension of the image embeddings.\n", 152 | " augmented_embed (int): Dimension of the augmented image embeddings.\n", 153 | " input_size (tuple): Tuple representing the input size of the images (height, width).\n", 154 | " temperature (float): Temperature parameter to scale CLIP matrix.\n", 155 | " dropout (float): Dropout rate applied in the projection heads.\n", 156 | " beta (int): Downsampling factor.\n", 157 | " alpha (int): Scaling factor applied to the main path in the cross-attention block.\n", 158 | " \"\"\"\n", 159 | " def __init__(self, input_dim, image_embed, augmented_embed, input_size=(256, 256),\n", 160 | " temperature=5.0, dropout=0.1, beta=16, alpha=3):\n", 161 | " super(Model, self).__init__()\n", 162 | " \n", 163 | " input_H, input_W = input_size\n", 164 | " self.H = input_H\n", 165 | " \n", 166 | " self.beta = 16 # Downsampling factor\n", 167 | " self.alpha = 3 # Main path scaling factor\n", 168 | " self.img_enc = Encoder(input_dim, image_embed)\n", 169 | " self.aug_enc = Encoder(input_dim, image_embed)\n", 170 | " \n", 171 | " self.image_projection = ProjectionHead(embedding_dim=image_embed, projection_dim=image_embed, dropout=dropout)\n", 172 | " self.aug_projection = ProjectionHead(embedding_dim=augmented_embed, projection_dim=augmented_embed, dropout=dropout)\n", 173 | " self.temperature = temperature\n", 174 | " \n", 175 | " self.cross_attn = CrossAttentionBlock(in_channels=image_embed, key_channels=image_embed,\n", 176 | " value_channels=image_embed, height=input_H, width=input_W)\n", 177 | " \n", 178 | " \n", 179 | " self.patch_size = self.H//8 #32\n", 180 | " self.dim = 
image_embed\n", 181 | " patch_dim = self.dim * self.patch_size * self.patch_size\n", 182 | " \n", 183 | " self.to_patch_embedding_img = nn.Sequential(\n", 184 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n", 185 | " nn.Linear(patch_dim, self.dim))\n", 186 | " \n", 187 | " self.to_patch_embedding_aug = nn.Sequential(\n", 188 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n", 189 | " nn.Linear(patch_dim, self.dim)) \n", 190 | " \n", 191 | " self.bn1 = nn.BatchNorm2d(image_embed)\n", 192 | " self.bn2 = nn.BatchNorm2d(image_embed)\n", 193 | " \n", 194 | " \n", 195 | " def forward(self, x, augmented_x):\n", 196 | "\n", 197 | " # extract feature representations of each modality\n", 198 | " img_f = self.img_enc(x)\n", 199 | " aug_f = self.img_enc(augmented_x) \n", 200 | "\n", 201 | " img_f = rearrange(img_f, 'b c h w -> b (h w) c')\n", 202 | " aug_f = rearrange(aug_f, 'b c h w -> b (h w) c')\n", 203 | "\n", 204 | " # Getting Image and augmented image Embeddings (with same dimension)\n", 205 | " img_e = self.image_projection(img_f)\n", 206 | " aug_e = self.aug_projection(aug_f)\n", 207 | " \n", 208 | " # Calculating CLIP\n", 209 | " img_e_r = self.bn1(rearrange(img_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n", 210 | " aug_e_r = self.bn2(rearrange(aug_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n", 211 | " \n", 212 | " img_e_patch = self.to_patch_embedding_img(img_e_r) \n", 213 | " aug_e_patch = self.to_patch_embedding_aug(aug_e_r) \n", 214 | " \n", 215 | " img_e_norm = img_e_patch / img_e_patch.norm(dim=-1, keepdim=True) \n", 216 | " aug_e_norm = aug_e_patch / aug_e_patch.norm(dim=-1, keepdim=True)\n", 217 | " \n", 218 | " clip_sim = (img_e_norm @ aug_e_norm.mT) / self.temperature\n", 219 | " img_e_sim = img_e_norm @ img_e_norm.mT\n", 220 | " aug_e_sim = aug_e_norm @ aug_e_norm.mT\n", 221 | " clip_targets = F.softmax((img_e_sim + 
aug_e_sim) / 2 * self.temperature, dim=-1)\n", 222 | " \n", 223 | " # Cross attention\n", 224 | " attn_1 = self.cross_attn(img_e*self.alpha, aug_e*0.8)\n", 225 | " attn_2 = self.cross_attn(aug_e*0.8, img_e*self.alpha)\n", 226 | " \n", 227 | " attn = attn_1 + attn_2\n", 228 | " \n", 229 | " _, edge1 = torch.max(attn, 1)\n", 230 | " attn_down = torchvision.transforms.functional.resize(attn, 256//self.beta, antialias=True)\n", 231 | " attn_up = torchvision.transforms.functional.resize(attn_down, 256, antialias=True)\n", 232 | " _, edge2 = torch.max(attn_up, 1)\n", 233 | " edge = edge1 - edge2\n", 234 | "\n", 235 | " return edge, attn, clip_sim, clip_targets\n" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "id": "e45e692c", 241 | "metadata": {}, 242 | "source": [ 243 | "# Training" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "id": "e4808e95", 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "img_size = 256" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "a4a4c435", 260 | "metadata": { 261 | "scrolled": false 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "for img_num, img_file in enumerate(img_data):\n", 266 | " \n", 267 | " ##### Read image #####\n", 268 | " image = read_image(img_file, img_size).to(device)\n", 269 | "\n", 270 | " ##### Laod Model #####\n", 271 | " model = Model(input_dim=3, image_embed=64, augmented_embed=64,\n", 272 | " input_size=(img_size, img_size), temperature=5.0, dropout=0.1,\n", 273 | " beta=16, alpha=3).to(device)\n", 274 | " model.train()\n", 275 | "\n", 276 | " ##### Setteings #####\n", 277 | " zero_img = torch.zeros(image.shape[2], image.shape[3]).to(device)\n", 278 | " \n", 279 | " loss_ce = torch.nn.CrossEntropyLoss()\n", 280 | " loss_s = torch.nn.L1Loss()\n", 281 | " \n", 282 | " optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)\n", 283 | " label_colours = np.random.randint(255, 
size=(128, 3))\n", 284 | " \n", 285 | " \n", 286 | " jitter = T.ColorJitter(brightness=[1.4, 1.4], hue=[-0.06, -0.06])\n", 287 | " aug_img = jitter(image)\n", 288 | " aug_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))(aug_img)\n", 289 | " aug_img = aug_img.to(device)\n", 290 | " \n", 291 | " ##### Training #####\n", 292 | " for batch_idx in range(args.maxIter):\n", 293 | "\n", 294 | " optimizer.zero_grad()\n", 295 | " edge, output, clip_logits, clip_targets = model(image, aug_img)\n", 296 | " \n", 297 | " ### Output\n", 298 | " output, clip_logits, clip_targets = output[0], clip_logits[0], clip_targets[0] \n", 299 | " output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n", 300 | " \n", 301 | " _, target = torch.max(output, 1)\n", 302 | " img_target = target.data.cpu().numpy()\n", 303 | " img_target_rgb = np.array([label_colours[c % args.nChannel] for c in img_target])\n", 304 | " img_target_rgb = img_target_rgb.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n", 305 | " \n", 306 | " ### Cross-entropy loss function \n", 307 | " loss_ce_value = args.loss_ce_coef * loss_ce(output, target)\n", 308 | " \n", 309 | " ### Boundary Loss\n", 310 | " loss_edge = args.loss_b_coef * loss_s(edge[0], zero_img) \n", 311 | " \n", 312 | " ### CLIP loss \n", 313 | " aug_loss = cross_entropy(clip_logits, clip_targets, 'mean')\n", 314 | " img_loss = cross_entropy(clip_logits.T, clip_targets.T, 'mean')\n", 315 | " loss_clip = args.loss_clip_coef * ((img_loss + aug_loss) / 2.0)\n", 316 | " \n", 317 | " ### Optimization \n", 318 | " loss = loss_ce_value + loss_clip + loss_edge\n", 319 | " loss.backward()\n", 320 | " optimizer.step()\n", 321 | " \n", 322 | " \n", 323 | " nLabels = len(np.unique(img_target))\n", 324 | " print(batch_idx, '/', args.maxIter, '|', ' label num:', nLabels, ' | loss:', round(loss.item(), 4),\n", 325 | " '| CE:', round(loss_ce_value.item(), 4), '| CLIP:', round(loss_clip.item(), 4),\n", 326 | " '| B:', 
round(loss_edge.item(), 4))\n", 327 | " \n", 328 | " if nLabels <= args.minLabels and batch_idx>=5:\n", 329 | " print (f\"Number of labels have reached {nLabels}\")\n", 330 | " break\n", 331 | " \n", 332 | "\n", 333 | " ##### Evaluate #####\n", 334 | " edge, output, _, _ = model(image, aug_img)\n", 335 | " output = output[0].permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n", 336 | " _, target = torch.max(output, 1)\n", 337 | " img_target = target.data.cpu().numpy()\n", 338 | " img_eval_output = np.array([label_colours[c % args.nChannel] for c in img_target])\n", 339 | " img_eval_output = img_eval_output.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n", 340 | " \n", 341 | " \n", 342 | " ##### Visualization #####\n", 343 | " fig, axes = plt.subplots(1, 4, figsize=(8, 8))\n", 344 | " axes[0].imshow(img_eval_output)\n", 345 | " axes[1].imshow(image[0].permute(1, 2, 0).cpu().detach().numpy()[..., ::-1])\n", 346 | " axes[2].imshow(aug_img[0].permute(1, 2, 0).cpu().detach().numpy()[...,::-1])\n", 347 | " axes[3].imshow(edge[0].cpu().detach().numpy())\n", 348 | " axes[0].set_title('Prediction')\n", 349 | " axes[1].set_title('Input Image')\n", 350 | " axes[2].set_title('Augmented Image')\n", 351 | " axes[3].set_title('Edge SR') \n", 352 | " axes[0].axis('off')\n", 353 | " axes[1].axis('off')\n", 354 | " axes[2].axis('off')\n", 355 | " axes[3].axis('off')\n", 356 | " plt.show()\n", 357 | " \n", 358 | " if args.save_output:\n", 359 | " name = os.path.basename(img_file).split('.')[0]\n", 360 | " cv2.imwrite(SAVE_PATH + '/FuseNet_mask_' + name + '.png', img_eval_output)\n", 361 | " cv2.imwrite(SAVE_PATH + '/FuseNet_img_' + name + '.png', image[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n", 362 | " cv2.imwrite(SAVE_PATH + '/FuseNet_aug_' + name + '.png', aug_img[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n", 363 | " \n", 364 | " print('-------------------------------', '\\n')" 365 | ] 366 | } 367 | ], 368 | "metadata": { 369 | 
"kernelspec": { 370 | "display_name": "Python 3 (ipykernel)", 371 | "language": "python", 372 | "name": "python3" 373 | }, 374 | "language_info": { 375 | "codemirror_mode": { 376 | "name": "ipython", 377 | "version": 3 378 | }, 379 | "file_extension": ".py", 380 | "mimetype": "text/x-python", 381 | "name": "python", 382 | "nbconvert_exporter": "python", 383 | "pygments_lexer": "ipython3", 384 | "version": "3.11.3" 385 | } 386 | }, 387 | "nbformat": 4, 388 | "nbformat_minor": 5 389 | } 390 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/FuseNet-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5bd16b8a", 6 | "metadata": {}, 7 | "source": [ 8 | "# FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation
" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "b2c0cc45", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import argparse\n", 19 | "import torch\n", 20 | "import torch.nn as nn\n", 21 | "import torch.nn.functional as F\n", 22 | "import torch.optim as optim\n", 23 | "import torchvision\n", 24 | "import torchvision.transforms as T\n", 25 | "\n", 26 | "import cv2\n", 27 | "import sys\n", 28 | "import os\n", 29 | "import numpy as np\n", 30 | "import random\n", 31 | "import glob\n", 32 | "from matplotlib import pyplot as plt\n", 33 | "\n", 34 | "from utils import read_image, dice_metric, xor_metric, hm_metric, create_mask, cross_entropy\n", 35 | "from model_utils import Encoder, ProjectionHead, MixFFN_skip, CrossAttentionBlock\n", 36 | "\n", 37 | "from einops import rearrange\n", 38 | "from einops.layers.torch import Rearrange" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "62a6f1f4", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "use_cuda = torch.cuda.is_available()\n", 49 | "\n", 50 | "parser = argparse.ArgumentParser(description='FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation')\n", 51 | "parser.add_argument('--nChannel', metavar='N', default=64, type=int, \n", 52 | " help='number of channels')\n", 53 | "parser.add_argument('--maxIter', metavar='T', default=50, type=int, \n", 54 | " help='number of maximum iterations')\n", 55 | "parser.add_argument('--minLabels', metavar='minL', default=3, type=int, \n", 56 | " help='minimum number of labels')\n", 57 | "parser.add_argument('--lr', metavar='LR', default=0.005, type=float, \n", 58 | " help='learning rate')\n", 59 | "\n", 60 | "parser.add_argument('--input_path', metavar='INPUT', default='./input_images/', \n", 61 | " help='input image folder path')\n", 62 | "parser.add_argument('--save_output', metavar='SAVE', default=True, \n", 63 | " help='whether to save output ot 
not')\n", 64 | "parser.add_argument('--output_path', metavar='OUTPUT', default='./output/', \n", 65 | " help='output folder path')\n", 66 | "\n", 67 | "parser.add_argument('--loss_ce_coef', metavar='CE', default=2.5, type=float, \n", 68 | " help='Cross entropy loss weighting factor')\n", 69 | "parser.add_argument('--loss_clip_coef', metavar='AT', default=0.5, type=float, \n", 70 | " help='Clip loss weighting factor')\n", 71 | "parser.add_argument('--loss_b_coef', metavar='Spatial', default=0.5, type=float, \n", 72 | " help='Boundary loss weighting factor')\n", 73 | "\n", 74 | "args = parser.parse_args(args=[])" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "eefd34af", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "if args.save_output:\n", 85 | " SAVE_PATH = args.output_path\n", 86 | " os.makedirs(SAVE_PATH, exist_ok=True)\n", 87 | "\n", 88 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "8af58b72", 94 | "metadata": {}, 95 | "source": [ 96 | "# Loading Data" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "4688dcb4", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "IMG_PATH = args.input_path\n", 107 | "img_data = sorted(glob.glob(IMG_PATH + 'image/*'))\n", 108 | "lbl_data = sorted(glob.glob(IMG_PATH + 'GT/*'))" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "2d188db8", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "len(img_data), len(lbl_data)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "d29d88f9", 124 | "metadata": {}, 125 | "source": [ 126 | "# Model" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "e7d3b082", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "class Model(nn.Module):\n", 137 | " 
\"\"\"\n", 138 | " Args:\n", 139 | " input_dim (int): Dimension of the input data.\n", 140 | " image_embed (int): Dimension of the image embeddings.\n", 141 | " augmented_embed (int): Dimension of the augmented image embeddings.\n", 142 | " input_size (tuple): Tuple representing the input size of the images (height, width).\n", 143 | " temperature (float): Temperature parameter to scale CLIP matrix.\n", 144 | " dropout (float): Dropout rate applied in the projection heads.\n", 145 | " beta (int): Downsampling factor.\n", 146 | " alpha (int): Scaling factor applied to the main path in the cross-attention block.\n", 147 | " \"\"\"\n", 148 | " def __init__(self, input_dim, image_embed, augmented_embed, input_size=(256, 256),\n", 149 | " temperature=5.0, dropout=0.1, beta=16, alpha=3):\n", 150 | " super(Model, self).__init__()\n", 151 | " \n", 152 | " input_H, input_W = input_size\n", 153 | " self.H = input_H\n", 154 | " \n", 155 | " self.beta = 16 # Downsampling factor\n", 156 | " self.alpha = 3 # Main path scaling factor\n", 157 | " self.img_enc = Encoder(input_dim, image_embed)\n", 158 | " self.aug_enc = Encoder(input_dim, image_embed)\n", 159 | " \n", 160 | " self.image_projection = ProjectionHead(embedding_dim=image_embed, projection_dim=image_embed, dropout=dropout)\n", 161 | " self.aug_projection = ProjectionHead(embedding_dim=augmented_embed, projection_dim=augmented_embed, dropout=dropout)\n", 162 | " self.temperature = temperature\n", 163 | " \n", 164 | " self.cross_attn = CrossAttentionBlock(in_channels=image_embed, key_channels=image_embed,\n", 165 | " value_channels=image_embed, height=input_H, width=input_W)\n", 166 | " \n", 167 | " \n", 168 | " self.patch_size = self.H//8 #32\n", 169 | " self.dim = image_embed\n", 170 | " patch_dim = self.dim * self.patch_size * self.patch_size\n", 171 | " \n", 172 | " self.to_patch_embedding_img = nn.Sequential(\n", 173 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = 
self.patch_size),\n", 174 | " nn.Linear(patch_dim, self.dim))\n", 175 | " \n", 176 | " self.to_patch_embedding_aug = nn.Sequential(\n", 177 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n", 178 | " nn.Linear(patch_dim, self.dim)) \n", 179 | " \n", 180 | " self.bn1 = nn.BatchNorm2d(image_embed)\n", 181 | " self.bn2 = nn.BatchNorm2d(image_embed)\n", 182 | " \n", 183 | " \n", 184 | " def forward(self, x, augmented_x):\n", 185 | "\n", 186 | " # extract feature representations of each modality\n", 187 | " img_f = self.img_enc(x)\n", 188 | " aug_f = self.img_enc(augmented_x) \n", 189 | "\n", 190 | " img_f = rearrange(img_f, 'b c h w -> b (h w) c')\n", 191 | " aug_f = rearrange(aug_f, 'b c h w -> b (h w) c')\n", 192 | "\n", 193 | " # Getting Image and augmented image Embeddings (with same dimension)\n", 194 | " img_e = self.image_projection(img_f)\n", 195 | " aug_e = self.aug_projection(aug_f)\n", 196 | " \n", 197 | " # Calculating CLIP\n", 198 | " img_e_r = self.bn1(rearrange(img_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n", 199 | " aug_e_r = self.bn2(rearrange(aug_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n", 200 | " \n", 201 | " img_e_patch = self.to_patch_embedding_img(img_e_r) \n", 202 | " aug_e_patch = self.to_patch_embedding_aug(aug_e_r) \n", 203 | " \n", 204 | " img_e_norm = img_e_patch / img_e_patch.norm(dim=-1, keepdim=True) \n", 205 | " aug_e_norm = aug_e_patch / aug_e_patch.norm(dim=-1, keepdim=True)\n", 206 | " \n", 207 | " clip_sim = (img_e_norm @ aug_e_norm.mT) / self.temperature\n", 208 | " img_e_sim = img_e_norm @ img_e_norm.mT\n", 209 | " aug_e_sim = aug_e_norm @ aug_e_norm.mT\n", 210 | " clip_targets = F.softmax((img_e_sim + aug_e_sim) / 2 * self.temperature, dim=-1)\n", 211 | " \n", 212 | " # Cross attention\n", 213 | " attn_1 = self.cross_attn(img_e*self.alpha, aug_e*0.8)\n", 214 | " attn_2 = self.cross_attn(aug_e*0.8, img_e*self.alpha)\n", 215 | " \n", 216 | " attn 
= attn_1 + attn_2\n", 217 | " \n", 218 | " _, edge1 = torch.max(attn, 1)\n", 219 | " attn_down = torchvision.transforms.functional.resize(attn, 256//self.beta, antialias=True)\n", 220 | " attn_up = torchvision.transforms.functional.resize(attn_down, 256, antialias=True)\n", 221 | " _, edge2 = torch.max(attn_up, 1)\n", 222 | " edge = edge1 - edge2\n", 223 | "\n", 224 | " return edge, attn, clip_sim, clip_targets\n" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "id": "e45e692c", 230 | "metadata": {}, 231 | "source": [ 232 | "# Training" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "id": "e4808e95", 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "img_size = 256" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "id": "a4a4c435", 249 | "metadata": { 250 | "scrolled": false 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "for img_num, img_file in enumerate(img_data):\n", 255 | " \n", 256 | " ##### Read image #####\n", 257 | " image = read_image(img_file, img_size).to(device)\n", 258 | "\n", 259 | " ##### Laod Model #####\n", 260 | " model = Model(input_dim=3, image_embed=64, augmented_embed=64,\n", 261 | " input_size=(img_size, img_size), temperature=5.0, dropout=0.1,\n", 262 | " beta=16, alpha=3).to(device)\n", 263 | " model.train()\n", 264 | "\n", 265 | " ##### Setteings #####\n", 266 | " zero_img = torch.zeros(image.shape[2], image.shape[3]).to(device)\n", 267 | " \n", 268 | " loss_ce = torch.nn.CrossEntropyLoss()\n", 269 | " loss_s = torch.nn.L1Loss()\n", 270 | " \n", 271 | " optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)\n", 272 | " label_colours = np.random.randint(255, size=(128, 3))\n", 273 | " \n", 274 | " \n", 275 | " jitter = T.ColorJitter(brightness=[1.4, 1.4], hue=[-0.06, -0.06])\n", 276 | " aug_img = jitter(image)\n", 277 | " aug_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))(aug_img)\n", 278 | " 
aug_img = aug_img.to(device)\n", 279 | " \n", 280 | " ##### Training #####\n", 281 | " for batch_idx in range(args.maxIter):\n", 282 | "\n", 283 | " optimizer.zero_grad()\n", 284 | " edge, output, clip_logits, clip_targets = model(image, aug_img)\n", 285 | " \n", 286 | " ### Output\n", 287 | " output, clip_logits, clip_targets = output[0], clip_logits[0], clip_targets[0] \n", 288 | " output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n", 289 | " \n", 290 | " _, target = torch.max(output, 1)\n", 291 | " img_target = target.data.cpu().numpy()\n", 292 | " img_target_rgb = np.array([label_colours[c % args.nChannel] for c in img_target])\n", 293 | " img_target_rgb = img_target_rgb.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n", 294 | " \n", 295 | " ### Cross-entropy loss function \n", 296 | " loss_ce_value = args.loss_ce_coef * loss_ce(output, target)\n", 297 | " \n", 298 | " ### Boundary Loss\n", 299 | " loss_edge = args.loss_b_coef * loss_s(edge[0], zero_img) \n", 300 | " \n", 301 | " ### CLIP loss \n", 302 | " aug_loss = cross_entropy(clip_logits, clip_targets, 'mean')\n", 303 | " img_loss = cross_entropy(clip_logits.T, clip_targets.T, 'mean')\n", 304 | " loss_clip = args.loss_clip_coef * ((img_loss + aug_loss) / 2.0)\n", 305 | " \n", 306 | " ### Optimization \n", 307 | " loss = loss_ce_value + loss_clip + loss_edge\n", 308 | " loss.backward()\n", 309 | " optimizer.step()\n", 310 | " \n", 311 | " \n", 312 | " nLabels = len(np.unique(img_target))\n", 313 | " print(batch_idx, '/', args.maxIter, '|', ' label num:', nLabels, ' | loss:', round(loss.item(), 4),\n", 314 | " '| CE:', round(loss_ce_value.item(), 4), '| CLIP:', round(loss_clip.item(), 4),\n", 315 | " '| B:', round(loss_edge.item(), 4))\n", 316 | " \n", 317 | " if nLabels <= args.minLabels and batch_idx>=5:\n", 318 | " print (f\"Number of labels have reached {nLabels}\")\n", 319 | " break\n", 320 | " \n", 321 | "\n", 322 | " ##### Evaluate #####\n", 323 | " 
edge, output, _, _ = model(image, aug_img)\n", 324 | " output = output[0].permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n", 325 | " _, target = torch.max(output, 1)\n", 326 | " img_target = target.data.cpu().numpy()\n", 327 | " img_eval_output = np.array([label_colours[c % args.nChannel] for c in img_target])\n", 328 | " img_eval_output = img_eval_output.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n", 329 | " \n", 330 | " \n", 331 | " ##### Visualization #####\n", 332 | " fig, axes = plt.subplots(1, 4, figsize=(8, 8))\n", 333 | " axes[0].imshow(img_eval_output)\n", 334 | " axes[1].imshow(image[0].permute(1, 2, 0).cpu().detach().numpy()[..., ::-1])\n", 335 | " axes[2].imshow(aug_img[0].permute(1, 2, 0).cpu().detach().numpy()[...,::-1])\n", 336 | " axes[3].imshow(edge[0].cpu().detach().numpy())\n", 337 | " axes[0].set_title('Prediction')\n", 338 | " axes[1].set_title('Input Image')\n", 339 | " axes[2].set_title('Augmented Image')\n", 340 | " axes[3].set_title('Edge SR') \n", 341 | " axes[0].axis('off')\n", 342 | " axes[1].axis('off')\n", 343 | " axes[2].axis('off')\n", 344 | " axes[3].axis('off')\n", 345 | " plt.show()\n", 346 | " \n", 347 | " if args.save_output:\n", 348 | " name = os.path.basename(img_file).split('.')[0]\n", 349 | " cv2.imwrite(SAVE_PATH + '/FuseNet_mask_' + name + '.png', img_eval_output)\n", 350 | " cv2.imwrite(SAVE_PATH + '/FuseNet_img_' + name + '.png', image[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n", 351 | " cv2.imwrite(SAVE_PATH + '/FuseNet_aug_' + name + '.png', aug_img[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n", 352 | " \n", 353 | " print('-------------------------------', '\\n')" 354 | ] 355 | } 356 | ], 357 | "metadata": { 358 | "kernelspec": { 359 | "display_name": "Python 3 (ipykernel)", 360 | "language": "python", 361 | "name": "python3" 362 | }, 363 | "language_info": { 364 | "codemirror_mode": { 365 | "name": "ipython", 366 | "version": 3 367 | }, 368 | "file_extension": ".py", 
369 | "mimetype": "text/x-python", 370 | "name": "python", 371 | "nbconvert_exporter": "python", 372 | "pygments_lexer": "ipython3", 373 | "version": "3.11.3" 374 | } 375 | }, 376 | "nbformat": 4, 377 | "nbformat_minor": 5 378 | } 379 | -------------------------------------------------------------------------------- /FuseNet _colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5bd16b8a", 6 | "metadata": {}, 7 | "source": [ 8 | "# FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation
" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "0bbea43e", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "!git clone https://github.com/mindflow-institue/FuseNet.git\n", 19 | "%cd ./FuseNet" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "id": "b2c0cc45", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import argparse\n", 30 | "import torch\n", 31 | "import torch.nn as nn\n", 32 | "import torch.nn.functional as F\n", 33 | "import torch.optim as optim\n", 34 | "import torchvision\n", 35 | "import torchvision.transforms as T\n", 36 | "\n", 37 | "import cv2\n", 38 | "import sys\n", 39 | "import os\n", 40 | "import numpy as np\n", 41 | "import random\n", 42 | "import glob\n", 43 | "from matplotlib import pyplot as plt\n", 44 | "\n", 45 | "from utils import read_image, dice_metric, xor_metric, hm_metric, create_mask, cross_entropy\n", 46 | "from model_utils import Encoder, ProjectionHead, MixFFN_skip, CrossAttentionBlock\n", 47 | "\n", 48 | "from einops import rearrange\n", 49 | "from einops.layers.torch import Rearrange" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "62a6f1f4", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "use_cuda = torch.cuda.is_available()\n", 60 | "\n", 61 | "parser = argparse.ArgumentParser(description='FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation')\n", 62 | "parser.add_argument('--nChannel', metavar='N', default=64, type=int, \n", 63 | " help='number of channels')\n", 64 | "parser.add_argument('--maxIter', metavar='T', default=50, type=int, \n", 65 | " help='number of maximum iterations')\n", 66 | "parser.add_argument('--minLabels', metavar='minL', default=3, type=int, \n", 67 | " help='minimum number of labels')\n", 68 | "parser.add_argument('--lr', metavar='LR', default=0.005, type=float, \n", 69 | " help='learning rate')\n", 70 | "\n", 71 | 
"parser.add_argument('--input_path', metavar='INPUT', default='./input_images/', \n", 72 | " help='input image folder path')\n", 73 | "parser.add_argument('--save_output', metavar='SAVE', default=True, \n", 74 | " help='whether to save output ot not')\n", 75 | "parser.add_argument('--output_path', metavar='OUTPUT', default='./output/', \n", 76 | " help='output folder path')\n", 77 | "\n", 78 | "parser.add_argument('--loss_ce_coef', metavar='CE', default=2.5, type=float, \n", 79 | " help='Cross entropy loss weighting factor')\n", 80 | "parser.add_argument('--loss_clip_coef', metavar='AT', default=0.5, type=float, \n", 81 | " help='Clip loss weighting factor')\n", 82 | "parser.add_argument('--loss_b_coef', metavar='Spatial', default=0.5, type=float, \n", 83 | " help='Boundary loss weighting factor')\n", 84 | "\n", 85 | "args = parser.parse_args(args=[])" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "eefd34af", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "if args.save_output:\n", 96 | " SAVE_PATH = args.output_path\n", 97 | " os.makedirs(SAVE_PATH, exist_ok=True)\n", 98 | "\n", 99 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "8af58b72", 105 | "metadata": {}, 106 | "source": [ 107 | "# Loading Data" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "4688dcb4", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "IMG_PATH = args.input_path\n", 118 | "img_data = sorted(glob.glob(IMG_PATH + 'image/*'))\n", 119 | "lbl_data = sorted(glob.glob(IMG_PATH + 'GT/*'))" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "2d188db8", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "len(img_data), len(lbl_data)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "id": "d29d88f9", 135 | 
"metadata": {}, 136 | "source": [ 137 | "# Model" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "e7d3b082", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "class Model(nn.Module):\n", 148 | " \"\"\"\n", 149 | " Args:\n", 150 | " input_dim (int): Dimension of the input data.\n", 151 | " image_embed (int): Dimension of the image embeddings.\n", 152 | " augmented_embed (int): Dimension of the augmented image embeddings.\n", 153 | " input_size (tuple): Tuple representing the input size of the images (height, width).\n", 154 | " temperature (float): Temperature parameter to scale CLIP matrix.\n", 155 | " dropout (float): Dropout rate applied in the projection heads.\n", 156 | " beta (int): Downsampling factor.\n", 157 | " alpha (int): Scaling factor applied to the main path in the cross-attention block.\n", 158 | " \"\"\"\n", 159 | " def __init__(self, input_dim, image_embed, augmented_embed, input_size=(256, 256),\n", 160 | " temperature=5.0, dropout=0.1, beta=16, alpha=3):\n", 161 | " super(Model, self).__init__()\n", 162 | " \n", 163 | " input_H, input_W = input_size\n", 164 | " self.H = input_H\n", 165 | " \n", 166 | " self.beta = 16 # Downsampling factor\n", 167 | " self.alpha = 3 # Main path scaling factor\n", 168 | " self.img_enc = Encoder(input_dim, image_embed)\n", 169 | " self.aug_enc = Encoder(input_dim, image_embed)\n", 170 | " \n", 171 | " self.image_projection = ProjectionHead(embedding_dim=image_embed, projection_dim=image_embed, dropout=dropout)\n", 172 | " self.aug_projection = ProjectionHead(embedding_dim=augmented_embed, projection_dim=augmented_embed, dropout=dropout)\n", 173 | " self.temperature = temperature\n", 174 | " \n", 175 | " self.cross_attn = CrossAttentionBlock(in_channels=image_embed, key_channels=image_embed,\n", 176 | " value_channels=image_embed, height=input_H, width=input_W)\n", 177 | " \n", 178 | " \n", 179 | " self.patch_size = self.H//8 #32\n", 180 | " self.dim = 
image_embed\n", 181 | " patch_dim = self.dim * self.patch_size * self.patch_size\n", 182 | " \n", 183 | " self.to_patch_embedding_img = nn.Sequential(\n", 184 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n", 185 | " nn.Linear(patch_dim, self.dim))\n", 186 | " \n", 187 | " self.to_patch_embedding_aug = nn.Sequential(\n", 188 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n", 189 | " nn.Linear(patch_dim, self.dim)) \n", 190 | " \n", 191 | " self.bn1 = nn.BatchNorm2d(image_embed)\n", 192 | " self.bn2 = nn.BatchNorm2d(image_embed)\n", 193 | " \n", 194 | " \n", 195 | " def forward(self, x, augmented_x):\n", 196 | "\n", 197 | " # extract feature representations of each modality\n", 198 | " img_f = self.img_enc(x)\n", 199 | " aug_f = self.img_enc(augmented_x) \n", 200 | "\n", 201 | " img_f = rearrange(img_f, 'b c h w -> b (h w) c')\n", 202 | " aug_f = rearrange(aug_f, 'b c h w -> b (h w) c')\n", 203 | "\n", 204 | " # Getting Image and augmented image Embeddings (with same dimension)\n", 205 | " img_e = self.image_projection(img_f)\n", 206 | " aug_e = self.aug_projection(aug_f)\n", 207 | " \n", 208 | " # Calculating CLIP\n", 209 | " img_e_r = self.bn1(rearrange(img_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n", 210 | " aug_e_r = self.bn2(rearrange(aug_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n", 211 | " \n", 212 | " img_e_patch = self.to_patch_embedding_img(img_e_r) \n", 213 | " aug_e_patch = self.to_patch_embedding_aug(aug_e_r) \n", 214 | " \n", 215 | " img_e_norm = img_e_patch / img_e_patch.norm(dim=-1, keepdim=True) \n", 216 | " aug_e_norm = aug_e_patch / aug_e_patch.norm(dim=-1, keepdim=True)\n", 217 | " \n", 218 | " clip_sim = (img_e_norm @ aug_e_norm.mT) / self.temperature\n", 219 | " img_e_sim = img_e_norm @ img_e_norm.mT\n", 220 | " aug_e_sim = aug_e_norm @ aug_e_norm.mT\n", 221 | " clip_targets = F.softmax((img_e_sim + 
aug_e_sim) / 2 * self.temperature, dim=-1)\n", 222 | " \n", 223 | " # Cross attention\n", 224 | " attn_1 = self.cross_attn(img_e*self.alpha, aug_e*0.8)\n", 225 | " attn_2 = self.cross_attn(aug_e*0.8, img_e*self.alpha)\n", 226 | " \n", 227 | " attn = attn_1 + attn_2\n", 228 | " \n", 229 | " _, edge1 = torch.max(attn, 1)\n", 230 | " attn_down = torchvision.transforms.functional.resize(attn, 256//self.beta, antialias=True)\n", 231 | " attn_up = torchvision.transforms.functional.resize(attn_down, 256, antialias=True)\n", 232 | " _, edge2 = torch.max(attn_up, 1)\n", 233 | " edge = edge1 - edge2\n", 234 | "\n", 235 | " return edge, attn, clip_sim, clip_targets\n" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "id": "e45e692c", 241 | "metadata": {}, 242 | "source": [ 243 | "# Training" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "id": "e4808e95", 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "img_size = 256" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "a4a4c435", 260 | "metadata": { 261 | "scrolled": false 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "for img_num, img_file in enumerate(img_data):\n", 266 | " \n", 267 | " ##### Read image #####\n", 268 | " image = read_image(img_file, img_size).to(device)\n", 269 | "\n", 270 | " ##### Laod Model #####\n", 271 | " model = Model(input_dim=3, image_embed=64, augmented_embed=64,\n", 272 | " input_size=(img_size, img_size), temperature=5.0, dropout=0.1,\n", 273 | " beta=16, alpha=3).to(device)\n", 274 | " model.train()\n", 275 | "\n", 276 | " ##### Setteings #####\n", 277 | " zero_img = torch.zeros(image.shape[2], image.shape[3]).to(device)\n", 278 | " \n", 279 | " loss_ce = torch.nn.CrossEntropyLoss()\n", 280 | " loss_s = torch.nn.L1Loss()\n", 281 | " \n", 282 | " optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)\n", 283 | " label_colours = np.random.randint(255, 
size=(128, 3))\n", 284 | " \n", 285 | " \n", 286 | " jitter = T.ColorJitter(brightness=[1.4, 1.4], hue=[-0.06, -0.06])\n", 287 | " aug_img = jitter(image)\n", 288 | " aug_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))(aug_img)\n", 289 | " aug_img = aug_img.to(device)\n", 290 | " \n", 291 | " ##### Training #####\n", 292 | " for batch_idx in range(args.maxIter):\n", 293 | "\n", 294 | " optimizer.zero_grad()\n", 295 | " edge, output, clip_logits, clip_targets = model(image, aug_img)\n", 296 | " \n", 297 | " ### Output\n", 298 | " output, clip_logits, clip_targets = output[0], clip_logits[0], clip_targets[0] \n", 299 | " output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n", 300 | " \n", 301 | " _, target = torch.max(output, 1)\n", 302 | " img_target = target.data.cpu().numpy()\n", 303 | " img_target_rgb = np.array([label_colours[c % args.nChannel] for c in img_target])\n", 304 | " img_target_rgb = img_target_rgb.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n", 305 | " \n", 306 | " ### Cross-entropy loss function \n", 307 | " loss_ce_value = args.loss_ce_coef * loss_ce(output, target)\n", 308 | " \n", 309 | " ### Boundary Loss\n", 310 | " loss_edge = args.loss_b_coef * loss_s(edge[0], zero_img) \n", 311 | " \n", 312 | " ### CLIP loss \n", 313 | " aug_loss = cross_entropy(clip_logits, clip_targets, 'mean')\n", 314 | " img_loss = cross_entropy(clip_logits.T, clip_targets.T, 'mean')\n", 315 | " loss_clip = args.loss_clip_coef * ((img_loss + aug_loss) / 2.0)\n", 316 | " \n", 317 | " ### Optimization \n", 318 | " loss = loss_ce_value + loss_clip + loss_edge\n", 319 | " loss.backward()\n", 320 | " optimizer.step()\n", 321 | " \n", 322 | " \n", 323 | " nLabels = len(np.unique(img_target))\n", 324 | " print(batch_idx, '/', args.maxIter, '|', ' label num:', nLabels, ' | loss:', round(loss.item(), 4),\n", 325 | " '| CE:', round(loss_ce_value.item(), 4), '| CLIP:', round(loss_clip.item(), 4),\n", 326 | " '| B:', 
round(loss_edge.item(), 4))\n", 327 | " \n", 328 | " if nLabels <= args.minLabels and batch_idx>=5:\n", 329 | " print (f\"Number of labels have reached {nLabels}\")\n", 330 | " break\n", 331 | " \n", 332 | "\n", 333 | " ##### Evaluate #####\n", 334 | " edge, output, _, _ = model(image, aug_img)\n", 335 | " output = output[0].permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n", 336 | " _, target = torch.max(output, 1)\n", 337 | " img_target = target.data.cpu().numpy()\n", 338 | " img_eval_output = np.array([label_colours[c % args.nChannel] for c in img_target])\n", 339 | " img_eval_output = img_eval_output.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n", 340 | " \n", 341 | " \n", 342 | " ##### Visualization #####\n", 343 | " fig, axes = plt.subplots(1, 4, figsize=(8, 8))\n", 344 | " axes[0].imshow(img_eval_output)\n", 345 | " axes[1].imshow(image[0].permute(1, 2, 0).cpu().detach().numpy()[..., ::-1])\n", 346 | " axes[2].imshow(aug_img[0].permute(1, 2, 0).cpu().detach().numpy()[...,::-1])\n", 347 | " axes[3].imshow(edge[0].cpu().detach().numpy())\n", 348 | " axes[0].set_title('Prediction')\n", 349 | " axes[1].set_title('Input Image')\n", 350 | " axes[2].set_title('Augmented Image')\n", 351 | " axes[3].set_title('Edge SR') \n", 352 | " axes[0].axis('off')\n", 353 | " axes[1].axis('off')\n", 354 | " axes[2].axis('off')\n", 355 | " axes[3].axis('off')\n", 356 | " plt.show()\n", 357 | " \n", 358 | " if args.save_output:\n", 359 | " name = os.path.basename(img_file).split('.')[0]\n", 360 | " cv2.imwrite(SAVE_PATH + '/FuseNet_mask_' + name + '.png', img_eval_output)\n", 361 | " cv2.imwrite(SAVE_PATH + '/FuseNet_img_' + name + '.png', image[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n", 362 | " cv2.imwrite(SAVE_PATH + '/FuseNet_aug_' + name + '.png', aug_img[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n", 363 | " \n", 364 | " print('-------------------------------', '\\n')" 365 | ] 366 | } 367 | ], 368 | "metadata": { 369 | 
"kernelspec": { 370 | "display_name": "Python 3 (ipykernel)", 371 | "language": "python", 372 | "name": "python3" 373 | }, 374 | "language_info": { 375 | "codemirror_mode": { 376 | "name": "ipython", 377 | "version": 3 378 | }, 379 | "file_extension": ".py", 380 | "mimetype": "text/x-python", 381 | "name": "python", 382 | "nbconvert_exporter": "python", 383 | "pygments_lexer": "ipython3", 384 | "version": "3.11.3" 385 | } 386 | }, 387 | "nbformat": 4, 388 | "nbformat_minor": 5 389 | } 390 | -------------------------------------------------------------------------------- /FuseNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5bd16b8a", 6 | "metadata": {}, 7 | "source": [ 8 | "# FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation
" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "b2c0cc45", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import argparse\n", 19 | "import torch\n", 20 | "import torch.nn as nn\n", 21 | "import torch.nn.functional as F\n", 22 | "import torch.optim as optim\n", 23 | "import torchvision\n", 24 | "import torchvision.transforms as T\n", 25 | "\n", 26 | "import cv2\n", 27 | "import sys\n", 28 | "import os\n", 29 | "import numpy as np\n", 30 | "import random\n", 31 | "import glob\n", 32 | "from matplotlib import pyplot as plt\n", 33 | "\n", 34 | "from utils import read_image, dice_metric, xor_metric, hm_metric, create_mask, cross_entropy\n", 35 | "from model_utils import Encoder, ProjectionHead, MixFFN_skip, CrossAttentionBlock\n", 36 | "\n", 37 | "from einops import rearrange\n", 38 | "from einops.layers.torch import Rearrange" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "62a6f1f4", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "use_cuda = torch.cuda.is_available()\n", 49 | "\n", 50 | "parser = argparse.ArgumentParser(description='FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation')\n", 51 | "parser.add_argument('--nChannel', metavar='N', default=64, type=int, \n", 52 | " help='number of channels')\n", 53 | "parser.add_argument('--maxIter', metavar='T', default=50, type=int, \n", 54 | " help='number of maximum iterations')\n", 55 | "parser.add_argument('--minLabels', metavar='minL', default=3, type=int, \n", 56 | " help='minimum number of labels')\n", 57 | "parser.add_argument('--lr', metavar='LR', default=0.005, type=float, \n", 58 | " help='learning rate')\n", 59 | "\n", 60 | "parser.add_argument('--input_path', metavar='INPUT', default='./input_images/', \n", 61 | " help='input image folder path')\n", 62 | "parser.add_argument('--save_output', metavar='SAVE', default=True, \n", 63 | " help='whether to save output ot 
not')\n", 64 | "parser.add_argument('--output_path', metavar='OUTPUT', default='./output/', \n", 65 | " help='output folder path')\n", 66 | "\n", 67 | "parser.add_argument('--loss_ce_coef', metavar='CE', default=2.5, type=float, \n", 68 | " help='Cross entropy loss weighting factor')\n", 69 | "parser.add_argument('--loss_clip_coef', metavar='AT', default=0.5, type=float, \n", 70 | " help='Clip loss weighting factor')\n", 71 | "parser.add_argument('--loss_b_coef', metavar='Spatial', default=0.5, type=float, \n", 72 | " help='Boundary loss weighting factor')\n", 73 | "\n", 74 | "args = parser.parse_args(args=[])" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "eefd34af", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "if args.save_output:\n", 85 | " SAVE_PATH = args.output_path\n", 86 | " os.makedirs(SAVE_PATH, exist_ok=True)\n", 87 | "\n", 88 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "8af58b72", 94 | "metadata": {}, 95 | "source": [ 96 | "# Loading Data" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "4688dcb4", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "IMG_PATH = args.input_path\n", 107 | "img_data = sorted(glob.glob(IMG_PATH + 'image/*'))\n", 108 | "lbl_data = sorted(glob.glob(IMG_PATH + 'GT/*'))" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "2d188db8", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "len(img_data), len(lbl_data)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "d29d88f9", 124 | "metadata": {}, 125 | "source": [ 126 | "# Model" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "e7d3b082", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "class Model(nn.Module):\n", 137 | " 
\"\"\"\n", 138 | " Args:\n", 139 | " input_dim (int): Dimension of the input data.\n", 140 | " image_embed (int): Dimension of the image embeddings.\n", 141 | " augmented_embed (int): Dimension of the augmented image embeddings.\n", 142 | " input_size (tuple): Tuple representing the input size of the images (height, width).\n", 143 | " temperature (float): Temperature parameter to scale CLIP matrix.\n", 144 | " dropout (float): Dropout rate applied in the projection heads.\n", 145 | " beta (int): Downsampling factor.\n", 146 | " alpha (int): Scaling factor applied to the main path in the cross-attention block.\n", 147 | " \"\"\"\n", 148 | " def __init__(self, input_dim, image_embed, augmented_embed, input_size=(256, 256),\n", 149 | " temperature=5.0, dropout=0.1, beta=16, alpha=3):\n", 150 | " super(Model, self).__init__()\n", 151 | " \n", 152 | " input_H, input_W = input_size\n", 153 | " self.H = input_H\n", 154 | " \n", 155 | " self.beta = 16 # Downsampling factor\n", 156 | " self.alpha = 3 # Main path scaling factor\n", 157 | " self.img_enc = Encoder(input_dim, image_embed)\n", 158 | " self.aug_enc = Encoder(input_dim, image_embed)\n", 159 | " \n", 160 | " self.image_projection = ProjectionHead(embedding_dim=image_embed, projection_dim=image_embed, dropout=dropout)\n", 161 | " self.aug_projection = ProjectionHead(embedding_dim=augmented_embed, projection_dim=augmented_embed, dropout=dropout)\n", 162 | " self.temperature = temperature\n", 163 | " \n", 164 | " self.cross_attn = CrossAttentionBlock(in_channels=image_embed, key_channels=image_embed,\n", 165 | " value_channels=image_embed, height=input_H, width=input_W)\n", 166 | " \n", 167 | " \n", 168 | " self.patch_size = self.H//8 #32\n", 169 | " self.dim = image_embed\n", 170 | " patch_dim = self.dim * self.patch_size * self.patch_size\n", 171 | " \n", 172 | " self.to_patch_embedding_img = nn.Sequential(\n", 173 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = 
self.patch_size),\n", 174 | " nn.Linear(patch_dim, self.dim))\n", 175 | " \n", 176 | " self.to_patch_embedding_aug = nn.Sequential(\n", 177 | " Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size),\n", 178 | " nn.Linear(patch_dim, self.dim)) \n", 179 | " \n", 180 | " self.bn1 = nn.BatchNorm2d(image_embed)\n", 181 | " self.bn2 = nn.BatchNorm2d(image_embed)\n", 182 | " \n", 183 | " \n", 184 | " def forward(self, x, augmented_x):\n", 185 | "\n", 186 | " # extract feature representations of each modality\n", 187 | " img_f = self.img_enc(x)\n", 188 | " aug_f = self.img_enc(augmented_x) \n", 189 | "\n", 190 | " img_f = rearrange(img_f, 'b c h w -> b (h w) c')\n", 191 | " aug_f = rearrange(aug_f, 'b c h w -> b (h w) c')\n", 192 | "\n", 193 | " # Getting Image and augmented image Embeddings (with same dimension)\n", 194 | " img_e = self.image_projection(img_f)\n", 195 | " aug_e = self.aug_projection(aug_f)\n", 196 | " \n", 197 | " # Calculating CLIP\n", 198 | " img_e_r = self.bn1(rearrange(img_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n", 199 | " aug_e_r = self.bn2(rearrange(aug_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)\n", 200 | " \n", 201 | " img_e_patch = self.to_patch_embedding_img(img_e_r) \n", 202 | " aug_e_patch = self.to_patch_embedding_aug(aug_e_r) \n", 203 | " \n", 204 | " img_e_norm = img_e_patch / img_e_patch.norm(dim=-1, keepdim=True) \n", 205 | " aug_e_norm = aug_e_patch / aug_e_patch.norm(dim=-1, keepdim=True)\n", 206 | " \n", 207 | " clip_sim = (img_e_norm @ aug_e_norm.mT) / self.temperature\n", 208 | " img_e_sim = img_e_norm @ img_e_norm.mT\n", 209 | " aug_e_sim = aug_e_norm @ aug_e_norm.mT\n", 210 | " clip_targets = F.softmax((img_e_sim + aug_e_sim) / 2 * self.temperature, dim=-1)\n", 211 | " \n", 212 | " # Cross attention\n", 213 | " attn_1 = self.cross_attn(img_e*self.alpha, aug_e*0.8)\n", 214 | " attn_2 = self.cross_attn(aug_e*0.8, img_e*self.alpha)\n", 215 | " \n", 216 | " attn 
= attn_1 + attn_2\n", 217 | " \n", 218 | " _, edge1 = torch.max(attn, 1)\n", 219 | " attn_down = torchvision.transforms.functional.resize(attn, 256//self.beta, antialias=True)\n", 220 | " attn_up = torchvision.transforms.functional.resize(attn_down, 256, antialias=True)\n", 221 | " _, edge2 = torch.max(attn_up, 1)\n", 222 | " edge = edge1 - edge2\n", 223 | "\n", 224 | " return edge, attn, clip_sim, clip_targets\n" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "id": "e45e692c", 230 | "metadata": {}, 231 | "source": [ 232 | "# Training" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "id": "e4808e95", 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "img_size = 256" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "id": "a4a4c435", 249 | "metadata": { 250 | "scrolled": false 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "for img_num, img_file in enumerate(img_data):\n", 255 | " \n", 256 | " ##### Read image #####\n", 257 | " image = read_image(img_file, img_size).to(device)\n", 258 | "\n", 259 | " ##### Laod Model #####\n", 260 | " model = Model(input_dim=3, image_embed=64, augmented_embed=64,\n", 261 | " input_size=(img_size, img_size), temperature=5.0, dropout=0.1,\n", 262 | " beta=16, alpha=3).to(device)\n", 263 | " model.train()\n", 264 | "\n", 265 | " ##### Setteings #####\n", 266 | " zero_img = torch.zeros(image.shape[2], image.shape[3]).to(device)\n", 267 | " \n", 268 | " loss_ce = torch.nn.CrossEntropyLoss()\n", 269 | " loss_s = torch.nn.L1Loss()\n", 270 | " \n", 271 | " optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)\n", 272 | " label_colours = np.random.randint(255, size=(128, 3))\n", 273 | " \n", 274 | " \n", 275 | " jitter = T.ColorJitter(brightness=[1.4, 1.4], hue=[-0.06, -0.06])\n", 276 | " aug_img = jitter(image)\n", 277 | " aug_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))(aug_img)\n", 278 | " 
aug_img = aug_img.to(device)\n", 279 | " \n", 280 | " ##### Training #####\n", 281 | " for batch_idx in range(args.maxIter):\n", 282 | "\n", 283 | " optimizer.zero_grad()\n", 284 | " edge, output, clip_logits, clip_targets = model(image, aug_img)\n", 285 | " \n", 286 | " ### Output\n", 287 | " output, clip_logits, clip_targets = output[0], clip_logits[0], clip_targets[0] \n", 288 | " output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n", 289 | " \n", 290 | " _, target = torch.max(output, 1)\n", 291 | " img_target = target.data.cpu().numpy()\n", 292 | " img_target_rgb = np.array([label_colours[c % args.nChannel] for c in img_target])\n", 293 | " img_target_rgb = img_target_rgb.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n", 294 | " \n", 295 | " ### Cross-entropy loss function \n", 296 | " loss_ce_value = args.loss_ce_coef * loss_ce(output, target)\n", 297 | " \n", 298 | " ### Boundary Loss\n", 299 | " loss_edge = args.loss_b_coef * loss_s(edge[0], zero_img) \n", 300 | " \n", 301 | " ### CLIP loss \n", 302 | " aug_loss = cross_entropy(clip_logits, clip_targets, 'mean')\n", 303 | " img_loss = cross_entropy(clip_logits.T, clip_targets.T, 'mean')\n", 304 | " loss_clip = args.loss_clip_coef * ((img_loss + aug_loss) / 2.0)\n", 305 | " \n", 306 | " ### Optimization \n", 307 | " loss = loss_ce_value + loss_clip + loss_edge\n", 308 | " loss.backward()\n", 309 | " optimizer.step()\n", 310 | " \n", 311 | " \n", 312 | " nLabels = len(np.unique(img_target))\n", 313 | " print(batch_idx, '/', args.maxIter, '|', ' label num:', nLabels, ' | loss:', round(loss.item(), 4),\n", 314 | " '| CE:', round(loss_ce_value.item(), 4), '| CLIP:', round(loss_clip.item(), 4),\n", 315 | " '| B:', round(loss_edge.item(), 4))\n", 316 | " \n", 317 | " if nLabels <= args.minLabels and batch_idx>=5:\n", 318 | " print (f\"Number of labels have reached {nLabels}\")\n", 319 | " break\n", 320 | " \n", 321 | "\n", 322 | " ##### Evaluate #####\n", 323 | " 
edge, output, _, _ = model(image, aug_img)\n", 324 | " output = output[0].permute(1, 2, 0).contiguous().view(-1, args.nChannel*2)\n", 325 | " _, target = torch.max(output, 1)\n", 326 | " img_target = target.data.cpu().numpy()\n", 327 | " img_eval_output = np.array([label_colours[c % args.nChannel] for c in img_target])\n", 328 | " img_eval_output = img_eval_output.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)\n", 329 | " \n", 330 | " \n", 331 | " ##### Visualization #####\n", 332 | " fig, axes = plt.subplots(1, 4, figsize=(8, 8))\n", 333 | " axes[0].imshow(img_eval_output)\n", 334 | " axes[1].imshow(image[0].permute(1, 2, 0).cpu().detach().numpy()[..., ::-1])\n", 335 | " axes[2].imshow(aug_img[0].permute(1, 2, 0).cpu().detach().numpy()[...,::-1])\n", 336 | " axes[3].imshow(edge[0].cpu().detach().numpy())\n", 337 | " axes[0].set_title('Prediction')\n", 338 | " axes[1].set_title('Input Image')\n", 339 | " axes[2].set_title('Augmented Image')\n", 340 | " axes[3].set_title('Edge SR') \n", 341 | " axes[0].axis('off')\n", 342 | " axes[1].axis('off')\n", 343 | " axes[2].axis('off')\n", 344 | " axes[3].axis('off')\n", 345 | " plt.show()\n", 346 | " \n", 347 | " if args.save_output:\n", 348 | " name = os.path.basename(img_file).split('.')[0]\n", 349 | " cv2.imwrite(SAVE_PATH + '/FuseNet_mask_' + name + '.png', img_eval_output)\n", 350 | " cv2.imwrite(SAVE_PATH + '/FuseNet_img_' + name + '.png', image[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n", 351 | " cv2.imwrite(SAVE_PATH + '/FuseNet_aug_' + name + '.png', aug_img[0].permute(1, 2, 0).cpu().detach().numpy()*255)\n", 352 | " \n", 353 | " print('-------------------------------', '\\n')" 354 | ] 355 | } 356 | ], 357 | "metadata": { 358 | "kernelspec": { 359 | "display_name": "Python 3 (ipykernel)", 360 | "language": "python", 361 | "name": "python3" 362 | }, 363 | "language_info": { 364 | "codemirror_mode": { 365 | "name": "ipython", 366 | "version": 3 367 | }, 368 | "file_extension": ".py", 
369 | "mimetype": "text/x-python", 370 | "name": "python", 371 | "nbconvert_exporter": "python", 372 | "pygments_lexer": "ipython3", 373 | "version": "3.11.3" 374 | } 375 | }, 376 | "nbformat": 4, 377 | "nbformat_minor": 5 378 | } 379 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 X-MindFlow 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2311.13069-b31b1b.svg)](https://arxiv.org/abs/2311.13069) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mindflow-institue/FuseNet/blob/main/FuseNet%20_colab.ipynb) 4 | 5 | 6 | Semantic segmentation, a crucial task in computer vision, often relies on labor-intensive and costly annotated datasets for training. In response to this challenge, we introduce FuseNet, a dual-stream framework for self-supervised semantic segmentation that eliminates the need for manual annotation. FuseNet leverages the shared semantic dependencies between the original and augmented images to create a clustering space, effectively assigning pixels to semantically related clusters, and ultimately generating the segmentation map. Additionally, FuseNet incorporates a cross-modal fusion technique that extends the principles of CLIP by replacing textual data with augmented images. This approach enables the model to learn complex visual representations, enhancing robustness against variations similar to CLIP’s text invariance. To further improve edge alignment and spatial consistency between neighboring pixels, we introduce an edge refinement loss. This loss function considers edge information to enhance spatial coherence, facilitating the grouping of nearby pixels with similar visual features. Extensive experiments on skin lesion and lung segmentation datasets demonstrate the effectiveness of our method. 7 | 8 |
9 | 10 | ![FuseNet](https://github.com/xmindflow/FuseNet/assets/61879630/79338d04-51f2-475d-8c16-eb5b6323a2aa) 11 | 12 |
13 | 14 | 15 | 16 | ## Updates 17 | - If you find this paper useful, please consider checking out our previously accepted papers at MIDL and ICCV: 18 | `MS-Former` [[Paper](https://openreview.net/forum?id=pp2raGSU3Wx)] [[GitHub](https://github.com/mindflow-institue/MS-Former)], and `S3-Net` [[Paper](https://openreview.net/forum?id=pp2raGSU3Wx)] [[GitHub](https://github.com/mindflow-institue/MS-Former)] ♥️✌🏻 19 | 20 | - November 22, 2023: First release of the code. 21 | 22 | ## Installation 23 | 24 | ```bash 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Run Demo 29 | Put your input images in the ```input_images/image``` folder and simply run the ```FuseNet.ipynb``` notebook ;) 30 | 31 | ## Experiments 32 | 33 |

34 | 35 |

36 | 37 | 38 | ## Citation 39 | If this code helps with your research, please consider citing the following paper: 40 |
41 | 42 | ```bibtex 43 | @article{kazerouni2023fusenet, 44 | title={FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation}, 45 | author={Kazerouni, Amirhossein and Karimijafarbigloo, Sanaz and Azad, Reza and Velichko, Yury and Bagci, Ulas and Merhof, Dorit}, 46 | journal={arXiv preprint arXiv:2311.13069}, 47 | year={2023} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /input_images/GT/Test_1_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/FuseNet/e8ec1b4183317f467b0dd18b5ab34c00276c510a/input_images/GT/Test_1_GT.bmp -------------------------------------------------------------------------------- /input_images/GT/Test_2_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/FuseNet/e8ec1b4183317f467b0dd18b5ab34c00276c510a/input_images/GT/Test_2_GT.bmp -------------------------------------------------------------------------------- /input_images/GT/Test_3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/FuseNet/e8ec1b4183317f467b0dd18b5ab34c00276c510a/input_images/GT/Test_3.bmp -------------------------------------------------------------------------------- /input_images/GT/Test_4_GT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/FuseNet/e8ec1b4183317f467b0dd18b5ab34c00276c510a/input_images/GT/Test_4_GT.png -------------------------------------------------------------------------------- /input_images/image/Test_1.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/FuseNet/e8ec1b4183317f467b0dd18b5ab34c00276c510a/input_images/image/Test_1.bmp 
-------------------------------------------------------------------------------- /input_images/image/Test_2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/FuseNet/e8ec1b4183317f467b0dd18b5ab34c00276c510a/input_images/image/Test_2.bmp -------------------------------------------------------------------------------- /input_images/image/Test_3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/FuseNet/e8ec1b4183317f467b0dd18b5ab34c00276c510a/input_images/image/Test_3.bmp -------------------------------------------------------------------------------- /input_images/image/Test_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/FuseNet/e8ec1b4183317f467b0dd18b5ab34c00276c510a/input_images/image/Test_4.png -------------------------------------------------------------------------------- /input_images/image/Test_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/FuseNet/e8ec1b4183317f467b0dd18b5ab34c00276c510a/input_images/image/Test_5.png -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, input_dim, nChannel): 9 | super(Encoder, self).__init__() 10 | self.conv1 = nn.Conv2d(input_dim, nChannel, kernel_size=3, stride=1, padding=1) 11 | self.bn1 = nn.BatchNorm2d(nChannel) 12 | 13 | self.conv2 = nn.Conv2d(nChannel, nChannel, kernel_size=1, stride=1, padding=0, groups=nChannel) 14 | self.bn2 = nn.BatchNorm2d(nChannel) 15 | 16 | def forward(self, x): 17 | x = 
self.conv1(x) 18 | x = F.relu(x) 19 | x = self.bn1(x) 20 | 21 | x = self.conv2(x) 22 | x = self.bn2(x) 23 | return x 24 | 25 | 26 | class ProjectionHead(nn.Module): 27 | def __init__( 28 | self, embedding_dim, projection_dim, dropout=0.1): 29 | super().__init__() 30 | self.projection = nn.Linear(embedding_dim, projection_dim) 31 | self.gelu = nn.GELU() 32 | self.fc = nn.Linear(projection_dim, projection_dim) 33 | self.dropout = nn.Dropout(dropout) 34 | self.layer_norm = nn.LayerNorm(projection_dim) 35 | 36 | def forward(self, x): 37 | projected = self.projection(x) 38 | x = self.gelu(projected) 39 | x = self.fc(x) 40 | x = self.dropout(x) 41 | x = x + projected 42 | x = self.layer_norm(x) 43 | return x 44 | 45 | 46 | class DWConv(nn.Module): 47 | def __init__(self, dim): 48 | super().__init__() 49 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) 50 | 51 | def forward(self, x: torch.Tensor, H, W) -> torch.Tensor: 52 | B, N, C = x.shape 53 | tx = x.transpose(1, 2).view(B, C, H, W) 54 | conv_x = self.dwconv(tx) 55 | return conv_x.flatten(2).transpose(1, 2) 56 | 57 | 58 | class MixFFN_skip(nn.Module): 59 | def __init__(self, c1, c2): 60 | super().__init__() 61 | self.fc1 = nn.Linear(c1, c2) 62 | self.dwconv = DWConv(c2) 63 | self.act = nn.GELU() 64 | self.fc2 = nn.Linear(c2, c1) 65 | self.norm1 = nn.LayerNorm(c2) 66 | 67 | def forward(self, x, H, W): 68 | ax = self.act(self.norm1(self.dwconv(self.fc1(x), H, W) + self.fc1(x))) 69 | out = self.fc2(ax) 70 | return out 71 | 72 | 73 | class CrossAttention(nn.Module): 74 | """ 75 | args: 76 | in_channels: (int) : Embedding Dimension. 77 | key_channels: (int) : Key Embedding Dimension, Best: (in_channels). 78 | value_channels: (int) : Value Embedding Dimension, Best: (in_channels or in_channels//2). 
79 | input: 80 | x : [B, D, H, W] 81 | output: 82 | Efficient Attention : [B, D, H, W] 83 | 84 | """ 85 | 86 | def __init__(self, in_channels, key_channels, value_channels, height, width,): 87 | super().__init__() 88 | self.in_channels = in_channels 89 | self.key_channels = key_channels 90 | self.value_channels = value_channels 91 | self.H = height 92 | self.W = width 93 | self.reprojection = nn.Conv2d(value_channels, in_channels*2, 1) 94 | self.norm = nn.LayerNorm(2 * in_channels) 95 | 96 | def forward(self, x1, x2): 97 | B, N, D = x1.size() 98 | 99 | # Efficient Attention 100 | keys = F.softmax(x1.transpose(1, 2), dim=2) 101 | queries = F.softmax(x1.transpose(1, 2), dim=1) 102 | values = x2.transpose(1, 2) 103 | context = keys @ values.transpose(1, 2) # dk*dv 104 | attended_value = (context.transpose(1, 2) @ queries).reshape(B, self.value_channels, self.H, self.W) # n*dv 105 | 106 | eff_attention = self.reprojection(attended_value).reshape(B, 2 * D, N).permute(0, 2, 1) 107 | eff_attention = self.norm(eff_attention) 108 | 109 | return eff_attention 110 | 111 | 112 | class CrossAttentionBlock(nn.Module): 113 | """ 114 | Input -> x1:[B, N, D] - N = H*W 115 | x2:[B, N, D] 116 | Output -> y:[B, N, D] 117 | D is half the size of the concatenated input (x1 from a lower level and x2 from the skip connection) 118 | """ 119 | 120 | def __init__(self, in_channels, key_channels, value_channels, height, width): 121 | super().__init__() 122 | self.norm1 = nn.LayerNorm(in_channels) 123 | self.H = height 124 | self.W = width 125 | self.attn = CrossAttention(in_channels, key_channels, value_channels, height, width) 126 | self.norm2 = nn.LayerNorm((in_channels * 2)) 127 | self.mlp = MixFFN_skip((in_channels * 2), int(in_channels * 4)) 128 | 129 | def forward(self, x1, x2): 130 | norm_1 = self.norm1(x1) 131 | norm_2 = self.norm1(x2) 132 | 133 | attn = self.attn(norm_1, norm_2) 134 | residual = torch.cat([x1, x2], dim=-1) 135 | tx = residual + attn 136 | mx = tx + 
self.mlp(self.norm2(tx), self.H, self.W) 137 | mx = rearrange(mx, 'b (h w) c -> b c h w', h=self.H, w=self.W) 138 | 139 | return mx -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | einops 4 | cv2 5 | numpy 6 | tqdm -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import torch.nn as nn 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | 9 | def read_image(img_file, img_size): 10 | im = cv2.imread(img_file) 11 | im = cv2.resize(im, (img_size, img_size), interpolation=cv2.INTER_CUBIC) 12 | data = torch.from_numpy(np.array([im.transpose((2, 0, 1)).astype('float32')/255.])) 13 | 14 | return data 15 | 16 | 17 | def cross_entropy(preds, targets, reduction='none'): 18 | log_softmax = nn.LogSoftmax(dim=-1) 19 | loss = (-targets * log_softmax(preds)).sum(1) 20 | if reduction == "none": 21 | return loss 22 | elif reduction == "mean": 23 | return loss.mean() 24 | 25 | 26 | def create_mask(pred, GT): 27 | 28 | kernel = np.ones((5, 5), np.uint8) 29 | dilated_GT = cv2.dilate(GT, kernel, iterations = 4) 30 | 31 | mult = pred * GT 32 | unique, count = np.unique(mult[mult !=0], return_counts=True) 33 | cls= unique[np.argmax(count)] 34 | 35 | lesion = np.where(pred==cls, 1, 0) * dilated_GT 36 | 37 | return lesion 38 | 39 | 40 | def dice_metric(A, B): 41 | intersect = np.sum(A * B) 42 | fsum = np.sum(A) 43 | ssum = np.sum(B) 44 | dice = (2 * intersect ) / (fsum + ssum) 45 | 46 | return dice 47 | 48 | 49 | def hm_metric(A, B): 50 | intersection = A * B 51 | union = np.logical_or(A, B) 52 | hm_score = (np.sum(union) - np.sum(intersection)) / np.sum(union) 53 | 54 | return hm_score 55 | 56 | 57 | def xor_metric(A, 
GT): 58 | intersection = A * GT 59 | union = np.logical_or(A, GT) 60 | xor_score = (np.sum(union) - np.sum(intersection)) / np.sum(GT) 61 | 62 | return xor_score --------------------------------------------------------------------------------