├── scores ├── __init__.py └── compute_scores.py ├── algorithms ├── __init__.py ├── utils_dilated_tubules.py ├── utils.py ├── unsupervised_dcf.py └── oneshot_dcf.py ├── generate_annotations ├── __init__.py ├── __pycache__ │ ├── utils.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ └── delete_annotations.cpython-38.pyc ├── delete_annotations.py ├── extract_images.py └── get_annotations.py ├── .gitattributes ├── images_test ├── lion.jpg ├── flower0.jpg ├── flower1.jpg ├── human1.jpg ├── human2.jpg ├── human3.jpg ├── gt │ ├── human1.png │ ├── human2.png │ └── human3.png └── pineapple.jpg ├── folder_images_paper ├── pineapple.jpg ├── skin_lesions.png ├── tumor_region.png └── real_life_images.png ├── requirements.txt ├── .gitignore ├── config.py ├── README.md ├── models ├── models_architectures.py └── models.py └── utils.py /scores/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /generate_annotations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored 2 | -------------------------------------------------------------------------------- /images_test/lion.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/images_test/lion.jpg -------------------------------------------------------------------------------- /images_test/flower0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/images_test/flower0.jpg -------------------------------------------------------------------------------- /images_test/flower1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/images_test/flower1.jpg -------------------------------------------------------------------------------- /images_test/human1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/images_test/human1.jpg -------------------------------------------------------------------------------- /images_test/human2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/images_test/human2.jpg -------------------------------------------------------------------------------- /images_test/human3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/images_test/human3.jpg -------------------------------------------------------------------------------- /images_test/gt/human1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/images_test/gt/human1.png 
-------------------------------------------------------------------------------- /images_test/gt/human2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/images_test/gt/human2.png -------------------------------------------------------------------------------- /images_test/gt/human3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/images_test/gt/human3.png -------------------------------------------------------------------------------- /images_test/pineapple.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/images_test/pineapple.jpg -------------------------------------------------------------------------------- /folder_images_paper/pineapple.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/folder_images_paper/pineapple.jpg -------------------------------------------------------------------------------- /folder_images_paper/skin_lesions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/folder_images_paper/skin_lesions.png -------------------------------------------------------------------------------- /folder_images_paper/tumor_region.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/folder_images_paper/tumor_region.png -------------------------------------------------------------------------------- /folder_images_paper/real_life_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/folder_images_paper/real_life_images.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | torchstain 4 | torch-contour 5 | scipy 6 | numba 7 | torchvision 8 | tqdm 9 | opencv-python 10 | matplotlib -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Cytomine-python-client/ 2 | cytomine/ 3 | generate_annotations/delete_annotations.py 4 | generate_annotations/get_annotations.py 5 | **/__pycache__/ 6 | -------------------------------------------------------------------------------- /generate_annotations/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/generate_annotations/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /generate_annotations/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/generate_annotations/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
/generate_annotations/__pycache__/delete_annotations.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antoinehabis/Deep-ContourFlow/HEAD/generate_annotations/__pycache__/delete_annotations.cpython-38.pyc -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | 5 | path_data = os.getenv('PATH_DATA_DILATED_TUBULES') 6 | if path_data != None: 7 | path_dcf = str(Path(__file__).resolve().parent) 8 | path_slides = os.path.join(path_data,'slides') 9 | path_annotations = os.path.join(path_dcf,'generate_annotations') 10 | path_images = os.path.join(path_data,'images') 11 | path_masks = os.path.join(path_data,'masks') 12 | path_scores = os.path.join(path_dcf,'scores') 13 | pathes = [path_slides, path_images, path_masks] 14 | 15 | for path in pathes: 16 | if not os.path.exists(path): 17 | os.makedirs(path) -------------------------------------------------------------------------------- /generate_annotations/delete_annotations.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import sys 6 | from pathlib import Path 7 | 8 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 9 | 10 | import logging 11 | import sys 12 | from shapely.geometry import Point, box 13 | 14 | from cytomine import Cytomine 15 | from cytomine.models import AnnotationCollection 16 | import os 17 | 18 | logging.basicConfig() 19 | logger = logging.getLogger("cytomine.client") 20 | logger.setLevel(logging.INFO) 21 | 22 | 23 | def delete_annotations(id_image, id_project): 24 | pb_key = os.getenv('CYTOMINE_PUBLIC') 25 | pv_key = os.getenv('CYTOMINE_PRIVATE') 26 | host = os.getenv('CYTOMINE_HOST') 27 | 28 | with Cytomine(host=host, public_key=pb_key, private_key=pv_key) as cytomine: 29 | # Get the list of annotations 30 | annotations = AnnotationCollection() 31 | annotations.image = id_image 32 | annotations.project = id_project 33 | annotations.fetch() 34 | for annotation in annotations: 35 | annotation.delete() 36 | return "You deleted all the annnotations" 37 | -------------------------------------------------------------------------------- /generate_annotations/extract_images.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 5 | from config import * 6 | from utils import ( 7 | retrieve_img_contour, 8 | interpolate, 9 | process_coord_get_image, 10 | find_thresh, 11 | row_to_filename, 12 | row_to_coordinates, 13 | ) 14 | from openslide import OpenSlide 15 | from tqdm import tqdm 16 | import numpy as np 17 | import pandas as pd 18 | import cv2 19 | import tifffile 20 | 21 | annotations = pd.read_csv( 22 | os.path.join(path_annotations, "annotations.csv"), index_col=0 23 | ) 24 | n = annotations.shape[0] 25 | annotations = annotations.replace(["dilated_tubule", "fake_tubule"], [1, 0]) 26 | 27 | filenames = np.unique(list(annotations["slide"])) 28 | coordinates_start = {} 29 | 30 | for filename in filenames: 31 | thresh = find_thresh(filename, percentile=90) 32 | slide_path = os.path.join(path_slides, filename) 33 | im = 
OpenSlide(slide_path) 34 | anns = annotations[annotations["slide"] == filename] 35 | for row in tqdm(anns.iterrows()): 36 | coordinates, term = row_to_coordinates(row[1]) 37 | img, contour_true = process_coord_get_image(coordinates, im=im, margin=100) 38 | mask = cv2.fillPoly( 39 | np.zeros((img.shape[0], img.shape[1])), contour_true[None].astype(int), 1, 0 40 | ).astype(int) 41 | 42 | img, contour_init = retrieve_img_contour(img=img, thresh=thresh, mask=mask) 43 | filename = row_to_filename(row[1]) 44 | if not (os.path.exists(path_masks)): 45 | os.makedirs(path_masks) 46 | if not (os.path.exists(path_images)): 47 | os.makedirs(path_images) 48 | tifffile.imsave(os.path.join(path_masks, filename), mask) 49 | tifffile.imsave(os.path.join(path_images, filename), img) 50 | coordinates_start[filename] = interpolate(contour_init, 100) 51 | 52 | np.save(os.path.join(path_data, "contour_init.npy"), coordinates_start) 53 | -------------------------------------------------------------------------------- /generate_annotations/get_annotations.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import sys 7 | from pathlib import Path 8 | 9 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 10 | from config import * 11 | import logging 12 | 13 | import sys 14 | import pandas as pd 15 | from cytomine import Cytomine 16 | from cytomine.models import AnnotationCollection, ImageInstanceCollection 17 | from cytomine.models import TermCollection 18 | 19 | 20 | 21 | def get_by_id(haystack, needle): 22 | return next((item for item in haystack if item.id == needle), None) 23 | 24 | 25 | pb_key = os.getenv('CYTOMINE_PUBLIC') 26 | pv_key = os.getenv('CYTOMINE_PRIVATE') 27 | host = os.getenv('CYTOMINE_HOST') 28 | 29 | logging.basicConfig() 30 | logger = logging.getLogger("cytomine.client") 31 | logger.setLevel(logging.INFO) 32 | 33 | project_id = os.getenv('PROJECT_ID') 34 | 35 | with Cytomine(host=host, public_key=pb_key, private_key=pv_key) as cytomine: 36 | terms = TermCollection().fetch_with_filter("project", project_id) 37 | terms_dict = {t.id: t.name for t in terms} 38 | image_instances = ImageInstanceCollection().fetch_with_filter("project", project_id) 39 | dic_id_img = {} 40 | for image in image_instances: 41 | dic_id_img[image.id] = image.filename.split("/")[-1] 42 | # We want all annotations in a given project. 43 | annotations = AnnotationCollection() 44 | annotations.project = project_id # Add a filter: only annotations from this project 45 | annotations.showWKT = True # Ask to return WKT location (geometry) in the response 46 | annotations.showMeta = ( 47 | True # Ask to return meta information (id, ...) in the response 48 | ) 49 | annotations.showGIS = ( 50 | True # Ask to return GIS information (perimeter, area, ...) in the response 51 | ) 52 | annotations.showTerm = True 53 | # ... 54 | # => Fetch annotations from the server with the given filters. 
55 | annotations.fetch() 56 | 57 | df = pd.DataFrame( 58 | columns=["id", "slide", "project", "term", "area", "perimeter", "location"] 59 | ) 60 | 61 | for annotation in annotations: 62 | if len(annotation.term) > 0: 63 | df = df.append( 64 | { 65 | "id": annotation.id, 66 | "slide": dic_id_img[annotation.image], 67 | "project": annotation.project, 68 | "term": terms_dict[annotation.term[0]], 69 | "area": annotation.area, 70 | "perimeter": annotation.perimeter, 71 | "location": annotation.location, 72 | }, 73 | ignore_index=True, 74 | ) 75 | df.to_csv(os.path.join(path_annotations, "annotations.csv")) 76 | -------------------------------------------------------------------------------- /algorithms/utils_dilated_tubules.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 5 | import cv2 6 | import numpy as np 7 | from scipy.ndimage import binary_closing 8 | from skimage.measure import label 9 | from skimage.morphology import disk 10 | 11 | from config import * 12 | 13 | # from histolab.slide import Slide 14 | 15 | 16 | def row_to_filename(row): 17 | filename = row.slide.split(".")[0] + "_" + str(row.id) + ".tif" 18 | return filename 19 | 20 | 21 | def find_thresh(filename, percentile): 22 | img = Slide(os.path.join(path_slides, filename), processed_path="") 23 | arr = img.resampled_array(scale_factor=4) 24 | gray = np.mean(arr, -1) 25 | ret2, th2 = cv2.threshold( 26 | np.mean(arr, -1).astype(np.uint8), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU 27 | ) 28 | new_img = binary_closing(1 - th2 / 255, disk(9)).astype(bool) 29 | return np.percentile(gray[new_img], percentile) / 255 30 | 31 | 32 | def row_to_coordinates(row): 33 | class_ = row.term 34 | string = row.location 35 | string = string.replace("POLYGON ", "") 36 | string = string.replace(", ", "),(") 37 | string = string.replace(" ", ",") 38 | coordinates = np.array(eval(string)) 39 | return coordinates, class_ 40 | 41 | 42 | def preprocess_contour(contour_init, img): 43 | img = cv2.fillPoly(np.zeros(img.shape[:-1]), [contour_init.astype(int)], 1) 44 | img = binary_closing(img, disk(5)) 45 | contour_init = np.squeeze( 46 | cv2.findContours(img.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[ 47 | 0 48 | ][0] 49 | ) 50 | return contour_init 51 | 52 | 53 | def process_coord_get_image(coord, im, margin=100): 54 | coord_tmp = coord.copy() 55 | 56 | coord_tmp[:, 1] = im.dimensions[1] - coord_tmp[:, 1] 57 | coord_min = coord_tmp - np.min(coord_tmp, 0) 58 | x_min, y_min = np.min(coord_tmp, 0).astype(int) - margin 59 | x_max, y_max = np.max(coord_tmp, 0).astype(int) + margin 60 | 61 | img = np.array( 62 | im.read_region( 63 | location=[x_min, y_min], level=0, size=[x_max - x_min, y_max - y_min] 64 | ) 65 | )[:, :, :-1] 66 | contour = (coord_min + margin).astype(int) 67 | return img, contour 68 | 69 | 70 | def retrieve_img_contour(img, thresh, mask): 71 | img = img / np.max(img) 72 | mean = np.mean(img, -1) 73 | l, c = np.array(img.shape[:-1]) // 2 74 | 75 | x = label(binary_closing(mean > thresh, disk(5))) 76 | lab = x * mask 77 | uniques, counts = np.unique(lab[lab > 0], return_counts=True) 78 | arg = np.argsort(counts)[-1] 79 | 80 | white = (x == uniques[arg]).astype(int) 81 | shapes = cv2.findContours( 82 | white.astype(np.uint8), 83 | method=cv2.RETR_TREE, 84 | mode=cv2.CHAIN_APPROX_SIMPLE, 85 | )[0] 86 | 87 | return (img * 255).astype(np.uint8), np.squeeze(shapes[0]) 88 | 
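
The annotation-parsing helpers above are easiest to follow on a concrete row. The sketch below is illustrative only: the slide name, annotation id and WKT polygon are hypothetical values in the format produced by `generate_annotations/get_annotations.py`.

```
import pandas as pd

from algorithms.utils_dilated_tubules import row_to_coordinates, row_to_filename

# Hypothetical annotation row, shaped like one row of annotations.csv
row = pd.Series(
    {
        "id": 42,
        "slide": "slide_01.mrxs",
        "term": "dilated_tubule",
        "location": "POLYGON ((1222 3444, 1250 3460, 1210 3470, 1222 3444))",
    }
)

coordinates, class_ = row_to_coordinates(row)
print(class_)             # dilated_tubule
print(coordinates.shape)  # (4, 2): one (x, y) pair per polygon vertex

print(row_to_filename(row))  # slide_01_42.tif
```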
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep ContourFlow 2 | 3 | ![Python](https://img.shields.io/badge/python-3670A0?style=for-the-badge&logo=python&logoColor=ffdd54) 4 | [![Mail](https://img.shields.io/badge/Gmail-D14836?style=for-the-badge&logo=gmail&logoColor=white)](mailto:antoine.habis.tlcm@gmail.com) 5 | [![Downloads](https://static.pepy.tech/badge/torch_contour/month)](https://pepy.tech/project/torch_contour) 6 | [![Downloads](https://static.pepy.tech/badge/torch_contour)](https://pepy.tech/project/torch_contour) 7 | [![ArXiv Paper](https://img.shields.io/badge/DOI-10.48550%2FarXiv.2407.10696-blue)](https://doi.org/10.48550/arXiv.2407.10696) 8 | 9 | To use this repository, first install the requirements (this includes `torch-contour`): 10 | 11 | ``` 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | In this repository you can find the code for both: 16 | 17 | - Unsupervised Deep-ContourFlow 18 | - One-shot learning Deep-ContourFlow 19 | 20 | 21 | 22 | ![Alt text](./folder_images_paper/real_life_images.png "Unsupervised DCF: evolution of the contour on four real-life images when varying the initial contour") 23 | 24 | ![Alt text](./folder_images_paper/tumor_region.png "Unsupervised DCF: evolution of the contour on two histology images.") 25 | 26 | ## Unsupervised Deep ContourFlow: 27 | 28 | To use Unsupervised DCF, add your image to `images_test` and run the algorithm using the notebook `unsupervised_dcf.ipynb`. 29 | 30 | ## One-shot learning: 31 | 32 | To use the one-shot version of the algorithm, provide a support image with a support mask and a query image in `images_test`, then run the algorithm using `one_shot_dcf.ipynb`.
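
The notebook is the recommended entry point, but the same API can be scripted. The sketch below mirrors how the one-shot `DCF` class is driven in `scores/compute_scores.py` (`fit` on a support image/contour pair, then `predict` on a query image); the image files come from `images_test`, while the elliptical initial contour and the hyper-parameter values are illustrative assumptions that may need tuning for your own images.

```
import cv2
import numpy as np

from algorithms.oneshot_dcf import DCF
from utils import define_contour_init

# Support image with its ground-truth mask, and a query image (RGB arrays assumed)
img_support = cv2.imread("images_test/human1.jpg")[:, :, ::-1]
mask_support = cv2.imread("images_test/gt/human1.png", cv2.IMREAD_GRAYSCALE)
img_query = cv2.imread("images_test/human2.jpg")[:, :, ::-1]

# Support contour extracted from the mask (same recipe as scores/compute_scores.py)
contour_support = np.squeeze(
    cv2.findContours(
        (mask_support > 0).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )[0][0]
)

# Rough initial contour in the query image: a centered ellipse
n = min(img_query.shape[:2])
contour_init, _ = define_contour_init(n, (n // 2, n // 2), (n // 4, n // 4))

dcf = DCF(
    nb_points=100, n_epochs=300, nb_augment=100, isolines=np.array([0.0, 1.0]),
    learning_rate=5e-2, clip=1e-1, sigma=5, weights=0.9,
    exponential_decay=0.999, thresh=1e-2,
)
dcf.fit(img_support, contour_support, augment=True)

# predict returns candidate contours with their energies; keep the lowest-energy one
shape_fin, score, tots, energies = dcf.predict(img_query, contour_init)
contour_pred = shape_fin[np.argmin(tots)]
```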
33 | 34 | If you use the code, please cite the following paper: 35 | 36 | ``` 37 | @misc{habis2024deepcontourflowadvancingactive, 38 | title={Deep ContourFlow: Advancing Active Contours with Deep Learning}, 39 | author={Antoine Habis and Vannary Meas-Yedid and Elsa Angelini and Jean-Christophe Olivo-Marin}, 40 | year={2024}, 41 | eprint={2407.10696}, 42 | archivePrefix={arXiv}, 43 | primaryClass={cs.CV}, 44 | url={https://arxiv.org/abs/2407.10696}, 45 | } 46 | ``` 47 | 48 | 82 | -------------------------------------------------------------------------------- /scores/compute_scores.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 5 | from config import * 6 | from algorithms.oneshot_dcf import DCF 7 | from utils import row_to_filename, preprocess_contour 8 | from tqdm import tqdm 9 | import pandas as pd 10 | from skimage.morphology import disk 11 | import tifffile 12 | import numpy as np 13 | import cv2 14 | 15 | annotations = pd.read_csv( 16 | os.path.join(path_annotations, "annotations.csv"), index_col=0 17 | ) 18 | contour_inits = np.load( 19 | os.path.join(path_data, "contour_init.npy"), allow_pickle=True 20 | ).item() 21 | 22 | df = pd.DataFrame( 23 | columns=[ 24 | "slide", 25 | "nb_support", 26 | "DICE(%)", 27 | "IOU(%)", 28 | "gt", 29 | "score", 30 | "nb_query", 31 | ] 32 | ) 33 | try: 34 | df = pd.read_csv(os.path.join(path_scores, "scores.csv"), index_col=0) 35 | except: 36 | pass 37 | 38 | slides_already_processed = list(np.unique(df["slide"])) 39 | 40 | all_filenames = np.unique(list(annotations["slide"])) 41 | filenames_to_process = list(set(all_filenames) - set(slides_already_processed)) 42 | 43 | 44 | def compute_scores_filename(filename, df): 45 | annotations = pd.read_csv( 46 | os.path.join(path_annotations, "annotations.csv"), index_col=0 47 | ) 48 | annotations = annotations.replace(["dilated_tubule", "fake_tubule"], [1, 0]) 49 | 50 | ### We extract only the annotations of a given slide 51 | 52 | annotations = annotations[annotations["slide"] == filename] 53 | 54 | ### We take only the dilated tubules of the slide 55 | 56 | annotations_support = annotations[annotations["term"] == 1] 57 | annotations_support = annotations_support.sample(frac=1).head(10) 58 | 59 | for _, row0 in annotations_support.iterrows(): 60 | dcf = DCF( 61 | nb_points=100, 62 | n_epochs=300, 63 | nb_augment=100, 64 | isolines=np.array([0.0, 1.0]), 65 | learning_rate=5e-2, 66 | clip=1e-1, 67 | sigma=5, 68 | weights=0.9, 69 | exponential_decay=0.999, 70 | thresh=1e-2, 71 | ) 72 | 73 | filename_img_support = row_to_filename(row0) 74 | 75 | print("support_filename", filename_img_support) 76 | 77 | img_support = tifffile.imread(os.path.join(path_images, filename_img_support)) 78 | mask_support = tifffile.imread(os.path.join(path_masks, filename_img_support)) 79 | contour_support = np.squeeze( 80 | cv2.findContours( 81 | mask_support.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 82 | )[0][0] 83 | ) 84 | 85 | dcf.fit(img_support, contour_support, augment=True) 86 | 87 | for _, row1 in annotations.iterrows(): 88 | filename_img = row_to_filename(row1) 89 | img = tifffile.imread(os.path.join(path_images, filename_img)) 90 | term = row1.term 91 | contour_init = contour_inits[filename_img] 92 | 93 | C0 = preprocess_contour(contour_init, img) 94 | shape_fin, score, tots, energies = dcf.predict(img, C0) 95 | 96 | x = np.argmin(tots) 97 | 98 | contour_pred = 
shape_fin[x] 99 | 100 | img_true = tifffile.imread(os.path.join(path_masks, filename_img)) 101 | img_pred = cv2.fillPoly(np.zeros(img.shape[:-1]), contour_pred[None], 1) 102 | num_DICE = 2 * np.sum(img_true * img_pred) 103 | denom_DICE = np.sum(img_true) + np.sum(img_pred) 104 | num_IOU = np.sum(img_true * img_pred) 105 | denom_IOU = np.sum(np.maximum(img_true, img_pred)) 106 | 107 | DICE = (num_DICE / denom_DICE) * 100 108 | IOU = (num_IOU / denom_IOU) * 100 109 | 110 | df = df.append( 111 | { 112 | "slide": row0.slide, 113 | "nb_support": row0.id, 114 | "DICE(%)": np.round(DICE, 2), 115 | "IOU(%)": np.round(IOU, 2), 116 | "gt": term, 117 | "score": score, 118 | "nb_query": row1.id, 119 | }, 120 | ignore_index=True, 121 | ) 122 | 123 | dice = np.array(list(df["DICE(%)"])) 124 | gt = np.array(list(df["gt"])) 125 | print(str(np.sum(dice * gt) / np.sum(gt))) 126 | df.to_csv(os.path.join(path_scores, "scores.csv")) 127 | 128 | return df 129 | 130 | for filename in tqdm(filenames_to_process): 131 | df = compute_scores_filename(filename, df) 132 | -------------------------------------------------------------------------------- /models/models_architectures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model architectures for DCF. 3 | 4 | This module contains the model architectures supported by DCF: 5 | - VGG16 6 | - ResNet50 7 | - ResNet101 8 | - ResNet_FPN (Feature Pyramid Network) 9 | - ResNet101_FPN 10 | """ 11 | 12 | import torch 13 | import torchvision.models as models 14 | 15 | 16 | class VGG16(torch.nn.Module): 17 | """ 18 | VGG16 architecture for multi-scale feature extraction. 19 | 20 | Extracts features from layers 3, 8, 15, 22, 29 for a multi-scale 21 | representation of image characteristics. 22 | """ 23 | 24 | def __init__(self): 25 | super().__init__() 26 | vgg16 = models.vgg16(weights="DEFAULT") 27 | self.features = vgg16.features.to(torch.float32) 28 | 29 | def forward(self, x): 30 | return self.features(x) 31 | 32 | 33 | class ResNet50(torch.nn.Module): 34 | """ 35 | ResNet50 architecture for multi-scale feature extraction. 36 | 37 | Extracts features from layers layer1, layer2, layer3, layer4 for a multi-scale 38 | representation of image characteristics. 39 | """ 40 | 41 | def __init__(self): 42 | super().__init__() 43 | resnet = models.resnet50(weights="DEFAULT") 44 | self.conv1 = resnet.conv1 45 | self.bn1 = resnet.bn1 46 | self.relu = resnet.relu 47 | self.maxpool = resnet.maxpool 48 | self.layer1 = resnet.layer1 # 1/4 49 | self.layer2 = resnet.layer2 # 1/8 50 | self.layer3 = resnet.layer3 # 1/16 51 | self.layer4 = resnet.layer4 # 1/32 52 | 53 | def forward(self, x): 54 | x = self.conv1(x) 55 | x = self.bn1(x) 56 | x = self.relu(x) 57 | x = self.maxpool(x) 58 | 59 | c1 = self.layer1(x) # 1/4 60 | c2 = self.layer2(c1) # 1/8 61 | c3 = self.layer3(c2) # 1/16 62 | c4 = self.layer4(c3) # 1/32 63 | 64 | return [c1, c2, c3, c4] 65 | 66 | 67 | class ResNet101(torch.nn.Module): 68 | """ 69 | ResNet101 architecture for multi-scale feature extraction. 70 | 71 | Extracts features from layers layer1, layer2, layer3, layer4 for a multi-scale 72 | representation of image characteristics. 
73 | """ 74 | 75 | def __init__(self): 76 | super().__init__() 77 | resnet = models.resnet101(weights="DEFAULT") 78 | self.conv1 = resnet.conv1 79 | self.bn1 = resnet.bn1 80 | self.relu = resnet.relu 81 | self.maxpool = resnet.maxpool 82 | self.layer1 = resnet.layer1 # 1/4 83 | self.layer2 = resnet.layer2 # 1/8 84 | self.layer3 = resnet.layer3 # 1/16 85 | self.layer4 = resnet.layer4 # 1/32 86 | 87 | def forward(self, x): 88 | x = self.conv1(x) 89 | x = self.bn1(x) 90 | x = self.relu(x) 91 | x = self.maxpool(x) 92 | 93 | c1 = self.layer1(x) # 1/4 94 | c2 = self.layer2(c1) # 1/8 95 | c3 = self.layer3(c2) # 1/16 96 | c4 = self.layer4(c3) # 1/32 97 | 98 | return [c1, c2, c3, c4] 99 | 100 | 101 | class ResNet_FPN(torch.nn.Module): 102 | """ 103 | ResNet50 architecture with Feature Pyramid Network (FPN). 104 | 105 | Implements a custom Feature Pyramid Network that extracts multi-scale 106 | features directly in the forward pass. Returns a list of features 107 | at different spatial scales. 108 | """ 109 | 110 | def __init__(self, backbone_name: str = "resnet50"): 111 | super().__init__() 112 | 113 | if backbone_name == "resnet50": 114 | backbone = models.resnet50(weights="DEFAULT") 115 | elif backbone_name == "resnet101": 116 | backbone = models.resnet101(weights="DEFAULT") 117 | else: 118 | raise ValueError(f"Backbone {backbone_name} not supported") 119 | 120 | self.backbone = backbone 121 | # Extraire les couches pour FPN 122 | self.layer1 = backbone.layer1 # 1/4 123 | self.layer2 = backbone.layer2 # 1/8 124 | self.layer3 = backbone.layer3 # 1/16 125 | self.layer4 = backbone.layer4 # 1/32 126 | 127 | def forward(self, x): 128 | # Forward pass to extract multi-scale features 129 | x = self.backbone.conv1(x) 130 | x = self.backbone.bn1(x) 131 | x = self.backbone.relu(x) 132 | x = self.backbone.maxpool(x) 133 | 134 | c1 = self.layer1(x) # 1/4 135 | c2 = self.layer2(c1) # 1/8 136 | c3 = self.layer3(c2) # 1/16 137 | c4 = self.layer4(c3) # 1/32 138 | 139 | return [c1, c2, c3, c4] 140 | 141 | 142 | class ResNet101_FPN(torch.nn.Module): 143 | """ 144 | ResNet101 architecture with Feature Pyramid Network (FPN). 145 | 146 | Implements a custom Feature Pyramid Network based on ResNet101. 147 | Returns multi-scale features directly in the forward pass. 148 | """ 149 | 150 | def __init__(self): 151 | super().__init__() 152 | self.fpn = ResNet_FPN("resnet101") 153 | 154 | # Expose layers directly for compatibility with hooks 155 | self.layer1 = self.fpn.layer1 156 | self.layer2 = self.fpn.layer2 157 | self.layer3 = self.fpn.layer3 158 | self.layer4 = self.fpn.layer4 159 | 160 | def forward(self, x): 161 | return self.fpn(x) 162 | -------------------------------------------------------------------------------- /algorithms/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilitaires pour les algorithmes DCF. 3 | 4 | Ce module contient des fonctions utilitaires pour le post-processing 5 | et d'autres opérations communes aux algorithmes DCF. 6 | """ 7 | 8 | import logging 9 | import multiprocessing as mp 10 | from typing import Tuple 11 | 12 | import cv2 13 | import numpy as np 14 | from scipy.ndimage import distance_transform_edt, label 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def process_grabcut_single_helper(args: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: 20 | """ 21 | Helper function for multiprocessing GrabCut processing. 
22 | 23 | Args: 24 | args: Tuple containing (img_np, contour) 25 | 26 | Returns: 27 | Processed contour 28 | """ 29 | img_np, contour = args 30 | try: 31 | img_np = np.moveaxis(img_np, 0, -1) # (C, H, W) -> (H, W, C) 32 | img_np = (img_np * 255).astype(np.uint8) 33 | mask = np.zeros((img_np.shape[0], img_np.shape[1]), dtype=np.uint8) 34 | if len(contour.shape) == 2: 35 | contour_for_fill = contour.reshape(-1, 1, 2).astype(int) 36 | else: 37 | contour_for_fill = contour.astype(int) 38 | 39 | cv2.fillPoly(mask, [contour_for_fill], 1) 40 | 41 | distance_map = distance_transform_edt(mask) 42 | distance_map = distance_map / np.max(distance_map) 43 | distance_map_outside = distance_transform_edt(1 - mask) 44 | distance_map_outside = distance_map_outside / np.max(distance_map_outside) 45 | 46 | mask_grabcut = np.full(mask.shape, cv2.GC_PR_BGD, dtype=np.uint8) 47 | mask_grabcut[distance_map > 0.2] = cv2.GC_FGD 48 | mask_grabcut[(distance_map > 0.2) & (distance_map <= 0.8)] = cv2.GC_PR_FGD 49 | mask_grabcut[distance_map_outside > 0.8] = cv2.GC_BGD 50 | 51 | bgdModel = np.zeros((1, 65), np.float64) 52 | fgdModel = np.zeros((1, 65), np.float64) 53 | 54 | cv2.grabCut( 55 | img_np, 56 | mask_grabcut, 57 | None, 58 | bgdModel, 59 | fgdModel, 60 | 5, 61 | cv2.GC_INIT_WITH_MASK, 62 | ) 63 | 64 | result = np.where( 65 | (mask_grabcut == cv2.GC_FGD) | (mask_grabcut == cv2.GC_PR_FGD), 1, 0 66 | ).astype(np.uint8) 67 | 68 | labeled_array, num_features = label(result) 69 | if num_features > 0: 70 | largest_cc = np.argmax(np.bincount(labeled_array.flat)[1:]) + 1 71 | result = (labeled_array == largest_cc).astype(np.uint8) 72 | 73 | contours, _ = cv2.findContours( 74 | result, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE 75 | ) 76 | 77 | if contours: 78 | largest_contour = max(contours, key=cv2.contourArea) 79 | return largest_contour.reshape(-1, 2) 80 | else: 81 | return contour 82 | 83 | except Exception as e: 84 | logger.error(f"Error in single GrabCut processing: {e}") 85 | return contour 86 | 87 | 88 | def apply_grabcut_postprocessing_parallel( 89 | img: np.ndarray, final_contours: np.ndarray 90 | ) -> np.ndarray: 91 | """ 92 | Apply GrabCut post-processing with parallel processing for better performance. 93 | 94 | Args: 95 | img: Input image tensor (B, C, H, W) 96 | final_contours: Final contours from DCF (B, K, 2) 97 | 98 | Returns: 99 | Refined contours after GrabCut processing 100 | """ 101 | try: 102 | img_list = [img[i] for i in range(img.shape[0])] 103 | args_list = [(img_list[i], final_contours[i]) for i in range(len(img_list))] 104 | with mp.Pool(processes=min(mp.cpu_count(), len(args_list))) as pool: 105 | results = pool.map(process_grabcut_single_helper, args_list) 106 | 107 | logger.info("GrabCut post-processing completed with parallel processing") 108 | return np.array(results) 109 | 110 | except Exception as e: 111 | logger.error(f"Error in parallel GrabCut post-processing: {e}") 112 | return final_contours # Return original contours if parallel processing fails 113 | 114 | 115 | def apply_grabcut_postprocessing_sequential( 116 | img: np.ndarray, final_contours: np.ndarray 117 | ) -> np.ndarray: 118 | """ 119 | Apply GrabCut post-processing with sequential processing for better stability. 
120 | 121 | Args: 122 | img: Input image tensor (B, C, H, W) 123 | final_contours: Final contours from DCF (B, K, 2) 124 | 125 | Returns: 126 | Refined contours after GrabCut processing 127 | """ 128 | try: 129 | results = [] 130 | for i in range(img.shape[0]): 131 | try: 132 | result = process_grabcut_single_helper((img[i], final_contours[i])) 133 | results.append(result) 134 | except Exception as e: 135 | logger.warning(f"Error processing image {i} with GrabCut: {e}") 136 | results.append(final_contours[i]) 137 | 138 | logger.info("GrabCut post-processing completed with sequential processing") 139 | return np.array(results) 140 | 141 | except Exception as e: 142 | logger.error(f"Error in sequential GrabCut post-processing: {e}") 143 | return final_contours # Return original contours if processing fails 144 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration module for the models supported by DCF. 3 | 4 | This module provides a unified interface to load and configure different deep learning 5 | models (VGG16, ResNet50, ResNet101, ResNet-FPN) for feature extraction. 6 | """ 7 | 8 | import logging 9 | from typing import Any, Dict 10 | 11 | import torch 12 | from torchvision import transforms 13 | 14 | # Import the model architectures 15 | from models.models_architectures import ( 16 | VGG16, 17 | ResNet50, 18 | ResNet101, 19 | ResNet101_FPN, 20 | ResNet_FPN, 21 | ) 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | preprocess = transforms.Compose( 26 | [ 27 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 28 | ] 29 | ) 30 | 31 | MODEL_CONFIGS = { 32 | "vgg16": { 33 | "model_fn": lambda: VGG16(), 34 | "layer_indices": [3, 8, 15, 22, 29], 35 | "layer_access": lambda model, idx: model.features[idx] 36 | if hasattr(model.features, str(idx)) 37 | else None, 38 | "preprocess": preprocess, 39 | "description": "VGG16 with features extracted from layers 3, 8, 15, 22, 29", 40 | }, 41 | "resnet50": { 42 | "model_fn": lambda: ResNet50(), 43 | "layer_indices": ["layer1", "layer2", "layer3", "layer4"], 44 | "layer_access": lambda model, idx: getattr(model, idx) 45 | if hasattr(model, idx) 46 | else None, 47 | "preprocess": preprocess, 48 | "description": "ResNet50 with features extracted from layers layer1-4", 49 | }, 50 | "resnet101": { 51 | "model_fn": lambda: ResNet101(), 52 | "layer_indices": ["layer1", "layer2", "layer3", "layer4"], 53 | "layer_access": lambda model, idx: getattr(model, idx) 54 | if hasattr(model, idx) 55 | else None, 56 | "preprocess": preprocess, 57 | "description": "ResNet101 with features extracted from layers layer1-4", 58 | }, 59 | "resnet_fpn": { 60 | "model_fn": lambda: ResNet_FPN("resnet50"), 61 | "layer_indices": ["layer1", "layer2", "layer3", "layer4"], 62 | "layer_access": lambda model, idx: getattr(model, idx) 63 | if hasattr(model, idx) 64 | else None, 65 | "preprocess": preprocess, 66 | "description": "ResNet50 with Feature Pyramid Network (FPN)", 67 | }, 68 | "resnet101_fpn": { 69 | "model_fn": lambda: ResNet101_FPN(), 70 | "layer_indices": ["layer1", "layer2", "layer3", "layer4"], 71 | "layer_access": lambda model, idx: getattr(model, idx) 72 | if hasattr(model, idx) 73 | else None, 74 | "preprocess": preprocess, 75 | "description": "ResNet101 with Feature Pyramid Network (FPN)", 76 | }, 77 | } 78 | 79 | 80 | def
create_resnet_fpn_model(backbone_name: str = "resnet50") -> torch.nn.Module: 81 | """ 82 | Creates a ResNet model with Feature Pyramid Network (FPN). 83 | 84 | Args: 85 | backbone_name: Backbone name ('resnet50' or 'resnet101') 86 | 87 | Returns: 88 | ResNet model with FPN 89 | 90 | Raises: 91 | ValueError: If the backbone is not supported 92 | """ 93 | try: 94 | if backbone_name == "resnet50": 95 | return ResNet_FPN("resnet50") 96 | elif backbone_name == "resnet101": 97 | return ResNet_FPN("resnet101") 98 | else: 99 | raise ValueError(f"Backbone {backbone_name} not supported") 100 | 101 | except Exception as e: 102 | logger.error(f"Error creating ResNet-FPN model: {e}") 103 | raise 104 | 105 | 106 | def detect_model_type(model: torch.nn.Module) -> str: 107 | """ 108 | Automatically detects the model type. 109 | 110 | Args: 111 | model: PyTorch model 112 | 113 | Returns: 114 | Detected model type ('vgg16', 'resnet50', 'resnet101', 'resnet_fpn') 115 | """ 116 | model_str = str(model) 117 | model_class_name = model.__class__.__name__ 118 | 119 | if "ResNet" in model_str: 120 | if model_class_name == "ResNet50": 121 | return "resnet50" 122 | elif model_class_name == "ResNet101": 123 | return "resnet101" 124 | elif model_class_name == "ResNet_FPN": 125 | return "resnet_fpn" 126 | elif model_class_name == "ResNet101_FPN": 127 | return "resnet101_fpn" 128 | elif "FPN" in model_str: 129 | return "resnet_fpn" 130 | else: 131 | # Fallback based on class name 132 | if "resnet50" in model_str.lower(): 133 | return "resnet50" 134 | elif "resnet101" in model_str.lower(): 135 | return "resnet101" 136 | else: 137 | return "resnet50" # Default 138 | elif "VGG" in model_str or "Sequential" in model_str or model_class_name == "VGG16": 139 | return "vgg16" 140 | else: 141 | logger.warning(f"Unrecognized model type: {model_str}. Using VGG16 as default.") 142 | return "vgg16" 143 | 144 | 145 | def get_model_config(model_type: str) -> Dict[str, Any]: 146 | """ 147 | Retrieves the configuration of a specific model. 148 | 149 | Args: 150 | model_type: Model type ('vgg16', 'resnet50', 'resnet101', 'resnet_fpn', 'resnet101_fpn') 151 | 152 | Returns: 153 | Model configuration 154 | 155 | Raises: 156 | ValueError: If the model type is not supported 157 | """ 158 | if model_type not in MODEL_CONFIGS: 159 | raise ValueError(f"Unsupported model type: {model_type}") 160 | 161 | return MODEL_CONFIGS[model_type] 162 | 163 | 164 | def list_available_models() -> Dict[str, str]: 165 | """ 166 | Lists all available models with their descriptions. 167 | 168 | Returns: 169 | Dictionary of available models with their descriptions 170 | """ 171 | return {name: config["description"] for name, config in MODEL_CONFIGS.items()} 172 | 173 | 174 | def create_model(model_type: str) -> torch.nn.Module: 175 | """ 176 | Creates a model of the specified type. 177 | 178 | Args: 179 | model_type: Model type to create 180 | 181 | Returns: 182 | Initialized PyTorch model 183 | 184 | Raises: 185 | ValueError: If the model type is not supported 186 | """ 187 | if model_type not in MODEL_CONFIGS: 188 | raise ValueError(f"Unsupported model type: {model_type}") 189 | 190 | try: 191 | return MODEL_CONFIGS[model_type]["model_fn"]() 192 | except Exception as e: 193 | logger.error(f"Error creating model {model_type}: {e}") 194 | raise 195 | 196 | 197 | def get_model_layer_access(model_type: str): 198 | """ 199 | Retrieves the layer access function for a model type. 
200 | 201 | Args: 202 | model_type: Model type 203 | 204 | Returns: 205 | Layer access function 206 | """ 207 | if model_type not in MODEL_CONFIGS: 208 | raise ValueError(f"Unsupported model type: {model_type}") 209 | 210 | return MODEL_CONFIGS[model_type]["layer_access"] 211 | 212 | 213 | def get_model_layer_indices(model_type: str) -> list: 214 | """ 215 | Retrieves the layer indices for a model type. 216 | 217 | Args: 218 | model_type: Model type 219 | 220 | Returns: 221 | List of layer indices 222 | """ 223 | if model_type not in MODEL_CONFIGS: 224 | raise ValueError(f"Unsupported model type: {model_type}") 225 | 226 | return MODEL_CONFIGS[model_type]["layer_indices"] 227 | 228 | 229 | def get_model_preprocess(model_type: str): 230 | """ 231 | Retrieves the preprocessing function for a model type. 232 | 233 | Args: 234 | model_type: Model type 235 | 236 | Returns: 237 | Preprocessing function 238 | """ 239 | if model_type not in MODEL_CONFIGS: 240 | raise ValueError(f"Unsupported model type: {model_type}") 241 | 242 | return MODEL_CONFIGS[model_type]["preprocess"] 243 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 5 | from config import * 6 | import numpy as np 7 | import cv2 8 | from torch.nn import Module 9 | import torch 10 | import torch.nn.functional as F 11 | from torch_contour import * 12 | from torch import cdist 13 | from torchvision.transforms.functional import vflip, hflip 14 | 15 | 16 | def piecewise_linear(x, x0, y0, k1, k2): 17 | return np.piecewise( 18 | x, [x < x0], [lambda x: k1 * x + y0 - k1 * x0, lambda x: k2 * x + y0 - k2 * x0] 19 | ) 20 | 21 | 22 | class Contour_to_features(torch.nn.Module): 23 | """ 24 | A PyTorch neural network module designed to convert contour data into feature representations. 25 | This class leverages two sub-modules: Contour_to_mask and Mask_to_features. 26 | """ 27 | 28 | def __init__(self, size: int, activations: dict): 29 | """ 30 | Initializes the Contour_to_features class. 31 | 32 | This method creates instances of two sub-modules: 33 | - Contour_to_mask with a parameter of 200. 34 | - Mask_to_features with no parameters. 35 | 36 | 37 | Parameters: 38 | ----------- 39 | activations: dict 40 | A Dictionary of feature maps (e.g., from a CNN). 41 | The dictionnary contains keys (int) which represent the order of the activations chosen by the user. 42 | Which means that activations[0] returns the 1st feature among the set of features in activations that have been chosen by the user. 43 | >>> Example 1: If the user wants to select each feature extracted by the model at each scale, activations(i) should contain the feature extracted at scale i. 44 | >>> Example 2: If the user wants to select each feature extracted by the model, activations(i) should contain the feature extracted at layer i. 45 | 46 | ctm : Contour_to_mask 47 | An instance of the Contour_to_mask class, initialized with a parameter of 200. 48 | mtf : Mask_to_features 49 | An instance of the Mask_to_features class. 50 | """ 51 | super(Contour_to_features, self).__init__() 52 | self.ctm = Contour_to_mask(size, k=1e4).requires_grad_(False) 53 | self.mtf = Mask_to_features(activations).requires_grad_(False) 54 | 55 | def forward(self, contour): 56 | """ 57 | Defines the forward pass of the Contour_to_features model. 
58 | 59 | This method takes in a contour and activations, uses the Contour_to_mask sub-module to 60 | generate a mask from the contour, and then applies the Mask_to_features sub-module to 61 | combine the activations with the mask. 62 | 63 | Parameters: 64 | ----------- 65 | contour : Tensor 66 | The contour data input tensor. 67 | 68 | Returns: 69 | -------- 70 | output_features: List 71 | The output features after combining the activations with the mask. 72 | the output_features is a List. 73 | len(output_features) = len(activations) 74 | output_features[i] has shape (B, C, 1) 75 | 76 | """ 77 | mask = self.ctm(contour) 78 | output_features = self.mtf(mask) 79 | return output_features 80 | 81 | 82 | class Mask_to_features(Module): 83 | """ 84 | A PyTorch neural network module designed to convert mask and activation data into feature representations. 85 | """ 86 | 87 | def __init__(self, activations, eps=1e-5): 88 | """ 89 | Initializes the Mask_to_features class. 90 | 91 | This method sets up the module without any specific parameters. 92 | 93 | Parameters: 94 | ----------- 95 | activations: dict 96 | A Dictionary of feature maps (e.g., from a CNN). 97 | The dictionnary contains keys (int) which represent the order of the activations chosen by the user. 98 | Which means that activations[0] returns the 1st feature among the set of features in activations that have been chosen by the user. 99 | >>> Example 1: If the user wants to select each feature extracted by the model at each scale, activations(i) should contain the feature extracted at scale i. 100 | >>> Example 2: If the user wants to select each feature extracted by the model, activations(i) should contain the feature extracted at layer i. 101 | """ 102 | super(Mask_to_features, self).__init__() 103 | self.activations = activations 104 | self.eps = eps 105 | 106 | def forward(self, mask: torch.Tensor): 107 | """ 108 | Defines the forward pass of the Mask_to_features model. 109 | 110 | This method takes in a dictionary of activations and a mask tensor, resizes the mask to match the 111 | dimensions of each activation layer, and then calculates features inside and outside the mask for each 112 | activation layer. 113 | 114 | Parameters: 115 | ----------- 116 | mask : torch.Tensor 117 | The mask tensor. 118 | 119 | Returns: 120 | -------- 121 | features_inside : list of torch.Tensor 122 | A list containing the feature representations inside the mask for each activation layer. 123 | 124 | features_outside : list of torch.Tensor 125 | A list containing the feature representations outside the mask for each activation layer. 
126 | 127 | len(features) = len(activations) 128 | features_outside[i] has shape (B, C_i, 1) and features_inside has shape (B, C_i, 1) 129 | with C_i the number of channels in self.activations[i] 130 | """ 131 | 132 | masks = [ 133 | F.interpolate( 134 | mask, 135 | size=(self.activations[i].shape[-2], self.activations[i].shape[-1]), 136 | mode="bilinear", 137 | ) 138 | for i in range(len(self.activations)) 139 | ] 140 | 141 | features_inside, features_outside = [], [] 142 | 143 | for i in range(len(self.activations)): 144 | features_inside.append( 145 | ( 146 | torch.sum(self.activations[i] * masks[i], dim=(2, 3)) 147 | / (torch.sum(masks[i], (2, 3)) + self.eps) 148 | )[..., None] 149 | ) 150 | 151 | features_outside.append( 152 | ( 153 | torch.sum(self.activations[i] * (1 - masks[i]), dim=(2, 3)) 154 | / (torch.sum((1 - masks[i]), (2, 3)) + self.eps) 155 | )[..., None] 156 | ) 157 | 158 | return features_inside, features_outside 159 | 160 | 161 | def augmentation(tuple_inputs_arrays): 162 | ps = np.random.random(10) 163 | for element in tuple_inputs_arrays: 164 | element = element.reshape((1, element.shape[0], element.shape[1], -1)) 165 | 166 | if ps[0] > 1 / 4 and ps[0] < 1 / 2: 167 | element = torch.rot90(element, dims=(1, 2), k=1) 168 | 169 | if ps[0] > 1 / 2 and ps[0] < 3 / 4: 170 | element = torch.rot90(element, dims=(1, 2), k=2) 171 | 172 | if ps[0] > 3 / 4 and ps[0] < 1: 173 | element = torch.rot90(element, dims=(1, 2), k=3) 174 | 175 | if ps[1] > 0.5: 176 | element = vflip(element) 177 | 178 | if ps[2] > 0.5: 179 | element = hflip(element) 180 | 181 | return tuple_inputs_arrays 182 | 183 | 184 | ##### Change the doc 185 | class Contour_to_isoline_features(torch.nn.Module): 186 | """ 187 | A PyTorch neural network module designed to convert contour data into feature representations. 188 | This class leverages two sub-modules: Contour_to_mask and Mask_to_features. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | size: int, 194 | activations: dict, 195 | isolines: torch.Tensor, 196 | halfway_value: float, 197 | compute_features_mask=False, 198 | ): 199 | """ 200 | Initializes the Contour_to_features class. 201 | 202 | This method creates instances of two sub-modules: 203 | - Contour_to_mask 204 | - Mask_to_features 205 | 206 | Parameters: 207 | ----------- 208 | size: int 209 | the size of image containing the normalized distance map generated in order to retrieve the isolines_features 210 | activations: dict 211 | A Dictionary of feature maps (e.g., from a CNN). 212 | The dictionnary contains keys (int) which represent the order of the activations chosen by the user. 213 | Which means that activations[0] returns the 1st feature among the set of features in activations that have been chosen by the user. 214 | >>> Example 1: If the user wants to select each feature extracted by the model at each scale, activations(i) should contain the feature extracted at scale i. 215 | >>> Example 2: If the user wants to select each feature extracted by the model, activations(i) should contain the feature extracted at layer i. 216 | ctd : Contour_to_mask 217 | An instance of the Contour_to_distance_map class. 218 | dtf : Distance_map_to_features 219 | An instance of the istance_map_to_features class. 
220 | compute_features_mask : bool 221 | whether to compute the average features at each scale inside the mask or not 222 | 223 | """ 224 | super(Contour_to_isoline_features, self).__init__() 225 | self.ctd = Contour_to_distance_map(size).requires_grad_(False) 226 | self.dtf = Distance_map_to_isoline_features( 227 | activations, isolines, halfway_value 228 | ).requires_grad_(False) 229 | self.compute_features_mask = compute_features_mask 230 | 231 | def forward(self, contour): 232 | """ 233 | Defines the forward pass of the Contour_to_features model. 234 | 235 | This method takes in a contour and activations, uses the Contour_to_mask sub-module to 236 | generate a mask from the contour, and then applies the Mask_to_features sub-module to 237 | combine the activations with the mask. 238 | 239 | Parameters: 240 | ----------- 241 | contour : Tensor 242 | The contour data input tensor. 243 | Returns: 244 | -------- 245 | output_features: tuple of list of tensors 246 | The output features after combining the activations with the mask. 247 | if self.compute_features_mask = True, the output_features will be a tuple. 248 | output_features[0] correspond to the list of the features at each scale and each isoline. 249 | output_features[1] correspond to the list of the features inside the mask at each scale. 250 | if self.compute_features_mask = False 251 | output_features correspond to the list of the features at each scale and each isoline. 252 | 253 | 254 | """ 255 | self.dtf.compute_features_mask = self.compute_features_mask 256 | dmap, mask = self.ctd(contour, True) 257 | output_features = self.dtf(dmap, mask) 258 | return output_features 259 | 260 | 261 | class Distance_map_to_isoline_features(Module): 262 | def __init__( 263 | self, 264 | activations: dict, 265 | isolines: torch.Tensor, 266 | halfway_value: float = 0.5, 267 | compute_features_mask=False, 268 | ): 269 | """ 270 | Initializes the Isoline_to_features class. 271 | 272 | Parameters: 273 | ----------- 274 | activations: dict 275 | A Dictionary of feature maps (e.g., from a CNN). 276 | The dictionnary contains keys (int) which represent the order of the activations chosen by the user. 277 | Which means that activations[0] returns the 1st feature among the set of features in activations that have been chosen by the user. 278 | >>> Example 1: If the user wants to select each feature extracted by the model at each scale, activations(i) should contain the feature extracted at scale i. 279 | >>> Example 2: If the user wants to select each feature extracted by the model, activations(i) should contain the feature extracted at layer i. 280 | 281 | isolines: torch.Tensor 282 | A tensor representing isoline values. 283 | the isolines values must in [0, 1] 284 | Example: torch.tensor([0.0, 0.5, 0.8]) 285 | 286 | halfway_value: float 287 | halfway_value is the value that must be reached in the middle of two consecutive isolines (represented as gaussians) when summing them together. 288 | >>> For example if isolines = [0,1] 289 | >>> and halfway value = 0.8 290 | >>> then the variances of the gaussian centered on 0 and the gaussian centered on 1 should be set so that thety sum up to 0.8 at 0.5. 291 | 292 | 293 | """ 294 | super(Distance_map_to_isoline_features, self).__init__() 295 | 296 | self.isolines = isolines # Store the isoline tensor. 297 | self.vars = self.mean_to_var( 298 | self.isolines, halfway_value 299 | ) # Store the variance tensor. 
300 | self.activations = activations 301 | self.compute_features_mask = compute_features_mask 302 | 303 | def mean_to_var(self, isolines, halfway_value): 304 | """ 305 | 306 | This function takes a list of isolines values (which correspond to the mean values of the gaussians) 307 | and computes the variances of each gaussian so that two consecutive gaussians sum to halfway_value at halfway the means. 308 | 309 | Parameters: 310 | ----------- 311 | isolines: torch.Tensor 312 | A tensor representing isoline values. 313 | halfway_value: float 314 | The value that must be reached in the middle of two consecutive isolines (represented as gaussians) when summing them together. 315 | 316 | Returns: 317 | -------- 318 | variances: torch.Tensor 319 | The variances of each gaussian so that two consecutive isolines sum to halfway_value at halfway. 320 | len(variances) = len(isolines) 321 | """ 322 | 323 | mat = cdist(isolines[:, None], isolines[:, None]) ** 2 324 | mat = torch.where(mat == 0, torch.tensor(float("inf")), mat) 325 | variances = -torch.min(mat, 0).values / (8 * np.log(halfway_value)) 326 | return variances 327 | 328 | def forward(self, distance_map: torch.Tensor, mask: torch.Tensor): 329 | """ 330 | Forward pass of the Isoline_to_features module. Generates features from activations and isolines. 331 | 332 | Parameters: 333 | ----------- 334 | 335 | distance_map: torch.Tensor 336 | A tensor with shape (B, 1, H, W) 337 | The tensor represents a batch of distance maps. 338 | 339 | mask: torch.Tensor: 340 | A tensor with shape (B, 1, H, W) 341 | the mask of each contour in the batch. 342 | 343 | compute_features_mask: (bool, optional) 344 | If True, compute additional aggregated features inside the masks for each features in activations. 345 | 346 | Returns: 347 | -------- 348 | 349 | features_isolines: list 350 | A list of aggregated features at each isoline for each feature in activations. 351 | 352 | features_mask:list 353 | A list of aggregated features inside the mask for each features in activations (if compute_features_mask is True). 
354 | """ 355 | 356 | # Number of scales in the activations dictionary 357 | nb_scales = len(self.activations) 358 | 359 | # Apply Gaussian-like weighting to isolines based on distance_map and variance 360 | isolines = mask * torch.exp( 361 | -((self.isolines[None, :, None, None] - distance_map) ** (2)) 362 | / (self.vars[None, :, None, None]) 363 | ) 364 | 365 | # Resize the isolines to match each activation scale, using bilinear interpolation 366 | isolines_scales = [ 367 | F.interpolate( 368 | isolines, 369 | size=( 370 | self.activations[i].shape[-2], 371 | self.activations[i].shape[-1], 372 | ), # Match activations' spatial size 373 | mode="bilinear", 374 | ) 375 | for i in range(nb_scales) 376 | ] 377 | 378 | # If compute_features_mask is True, resize the mask for each scale 379 | if self.compute_features_mask: 380 | masks = [ 381 | F.interpolate( 382 | mask, 383 | size=( 384 | self.activations[i].shape[-2], 385 | self.activations[i].shape[-1], 386 | ), # Match activations' spatial size 387 | mode="bilinear", 388 | ) 389 | for i in range(nb_scales) 390 | ] 391 | 392 | # Initialize lists for features and features_mask (if applicable) 393 | features_isolines, features_mask = [], [] 394 | 395 | # Loop through each scale and compute features 396 | for i in range(nb_scales): 397 | 398 | # Compute feature aggregation at scale 'i' by summing over the spatial dimensions, 399 | # weighted by the isolines, and normalizing by the sum of isolines 400 | f_s_i = (self.activations[i][:, :, None] * isolines_scales[i][:, None]).sum( 401 | dim=[-2, -1] 402 | ) / isolines_scales[i].sum(dim=[-2, -1])[:, None] 403 | features_isolines.append(f_s_i) 404 | 405 | # If compute_features_mask is True, compute and store features based on masks 406 | if self.compute_features_mask: 407 | features_mask.append( 408 | torch.sum( 409 | self.activations[i] * masks[i], dim=(-2, -1) 410 | ) # Compute masked feature aggregation 411 | / torch.sum(masks[i], dim=(-2, -1)) # Normalize by mask's sum 412 | ) 413 | 414 | # Return the features and features_mask (if computed) 415 | return features_isolines, features_mask 416 | 417 | 418 | def define_contour_init(n, center, axes, angle=0): 419 | # major, minor axes 420 | start_angle = 0 421 | end_angle = 360 422 | color = 1 423 | thickness = -1 424 | 425 | # Draw a filled ellipse on the input image 426 | mask = cv2.ellipse( 427 | np.zeros((n, n)), 428 | center, 429 | axes, 430 | angle, 431 | start_angle, 432 | end_angle, 433 | color, 434 | thickness, 435 | ).astype(np.uint8) 436 | contour = np.squeeze( 437 | cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[0][0] 438 | ) 439 | return contour, mask 440 | -------------------------------------------------------------------------------- /algorithms/unsupervised_dcf.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from pathlib import Path 4 | from typing import Tuple 5 | 6 | from algorithms.utils import apply_grabcut_postprocessing_parallel 7 | from models.models import ( 8 | VGG16, 9 | create_model, 10 | detect_model_type, 11 | get_model_layer_access, 12 | get_model_layer_indices, 13 | get_model_preprocess, 14 | ) 15 | from utils import Contour_to_features, piecewise_linear 16 | 17 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 18 | 19 | import numpy as np 20 | import torch 21 | from scipy import optimize 22 | from torch.nn.utils import clip_grad_norm_ 23 | from torch.optim import Adam 24 | from torch.optim.lr_scheduler import 
ReduceLROnPlateau
25 | from torch_contour import CleanContours, Smoothing, area
26 | from tqdm import tqdm
27 | 
28 | logging.basicConfig(level=logging.INFO)
29 | logger = logging.getLogger(__name__)
30 | 
31 | 
32 | class DCF:
33 |     """
34 |     Implementation of the unsupervised Deep Contour Flow (DCF) algorithm.
35 | 
36 |     This class implements the unsupervised version of DCF that moves the contour
37 |     over time to push as far away as possible the features inside and outside the contour.
38 |     """
39 | 
40 |     def __init__(
41 |         self,
42 |         n_epochs: int = 50,  # Reduced from 100 to 50
43 |         model=VGG16,  # Can be a torch.nn.Module or a string
44 |         learning_rate: float = 1e-1,  # Increased for faster convergence
45 |         clip: float = 5e-2,  # Reduced for stability
46 |         area_force: float = 0.0,
47 |         sigma: float = 1,
48 |         early_stopping_patience: int = 5,  # Reduced for earlier stopping
49 |         early_stopping_threshold: float = 1e-6,
50 |         use_mixed_precision: bool = True,  # Enabled by default
51 |         do_apply_grabcut: bool = False,
52 |         max_batch_size: int = 4,  # New parameter for batch processing
53 |     ):
54 |         """
55 |         Initialize the DCF algorithm with the specified parameters.
56 | 
57 |         Args:
58 |             n_epochs: Maximum number of training epochs
59 |             model: Pre-trained model for extracting activations
60 |             learning_rate: Learning rate for optimization
61 |             clip: Gradient clipping value
62 |             area_force: Weight of the contour area constraint
63 |             sigma: Standard deviation of the Gaussian smoothing operator
64 |             early_stopping_patience: Number of epochs before early stopping
65 |             early_stopping_threshold: Minimum improvement threshold for early stopping
66 |             use_mixed_precision: Use mixed precision for GPU acceleration
67 |             do_apply_grabcut: Apply GrabCut post-processing
68 |             max_batch_size: Maximum batch size for processing
69 | 
70 |         Raises:
71 |             ValueError: If parameters are invalid
72 |         """
73 | 
74 |         self._validate_parameters(
75 |             n_epochs,
76 |             learning_rate,
77 |             clip,
78 |             area_force,
79 |             sigma,
80 |             early_stopping_patience,
81 |             early_stopping_threshold,
82 |         )
83 | 
84 |         self.n_epochs = n_epochs
85 |         # Initialize the model
86 |         self.model = self._initialize_model(model)
87 |         self.model_type = detect_model_type(self.model)
88 |         self.learning_rate = learning_rate
89 |         self.clip = clip
90 |         self.lambda_area = area_force
91 |         self.device = None
92 |         self.max_batch_size = max_batch_size
93 | 
94 |         self.early_stopping_patience = early_stopping_patience
95 |         self.early_stopping_threshold = early_stopping_threshold
96 |         self.use_mixed_precision = use_mixed_precision
97 |         self.do_apply_grabcut = do_apply_grabcut
98 |         # 1. GPU optimizations
99 |         self._setup_gpu_optimizations()
100 | 
101 |         self._initialize_components(sigma)
102 | 
103 |         if self.use_mixed_precision:
104 |             if not torch.cuda.is_available():
105 |                 logger.warning(
106 |                     "Mixed precision requested but CUDA not available. Disabling."
107 |                 )
108 |                 self.use_mixed_precision = False
109 |             else:
110 |                 self.scaler = torch.cuda.amp.GradScaler()
111 | 
112 |         logger.info(f"DCF initialized with {n_epochs} epochs, lr={learning_rate}")
113 | 
114 |     def _initialize_model(self, model) -> torch.nn.Module:
115 |         """
116 |         Initialize the model according to the provided parameter.
117 | 
118 |         Args:
119 |             model: PyTorch model, model class, or string specifying the model type
120 | 
121 |         Returns:
122 |             The initialized model
123 |         """
124 |         if isinstance(model, str):
125 |             return create_model(model)
126 |         elif isinstance(model, type) and issubclass(model, torch.nn.Module):
127 |             # If it's a model class, create an instance
128 |             return model()
129 |         else:
130 |             # If it's already a model instance
131 |             return model
132 | 
133 |     def _setup_gpu_optimizations(self):
134 |         """Configure GPU optimizations for better performance."""
135 |         if torch.cuda.is_available():
136 |             torch.backends.cudnn.benchmark = True
137 |             torch.backends.cudnn.deterministic = False
138 |             torch.backends.cuda.matmul.allow_tf32 = True
139 |             torch.backends.cudnn.allow_tf32 = True
140 |             logger.info("GPU optimizations enabled")
141 | 
142 |     def _cleanup_gpu_memory(self):
143 |         """Clean up GPU memory cache."""
144 |         if torch.cuda.is_available():
145 |             torch.cuda.empty_cache()
146 | 
147 |     def _validate_parameters(
148 |         self,
149 |         n_epochs: int,
150 |         learning_rate: float,
151 |         clip: float,
152 |         area_force: float,
153 |         sigma: float,
154 |         early_stopping_patience: int,
155 |         early_stopping_threshold: float,
156 |     ) -> None:
157 |         """Validate input parameters."""
158 |         if n_epochs <= 0:
159 |             raise ValueError("n_epochs must be positive")
160 |         if learning_rate <= 0:
161 |             raise ValueError("learning_rate must be positive")
162 |         if clip <= 0:
163 |             raise ValueError("clip must be positive")
164 |         if sigma <= 0:
165 |             raise ValueError("sigma must be positive")
166 |         if early_stopping_patience < 0:
167 |             raise ValueError("early_stopping_patience must be non-negative")
168 |         if early_stopping_threshold < 0:
169 |             raise ValueError("early_stopping_threshold must be non-negative")
170 | 
171 |     def _initialize_components(self, sigma: float) -> None:
172 |         """Initialize algorithm components."""
173 |         try:
174 |             self.activations = {}
175 |             self.shapes = {}
176 |             self.spatial_dims = {}  # To store the spatial dimensions of the activations
177 | 
178 |             self._setup_activation_hooks()
179 | 
180 |             self.smooth = Smoothing(sigma)
181 |             self.cleaner = CleanContours()
182 | 
183 |         except Exception as e:
184 |             logger.error(f"Error initializing components: {e}")
185 |             raise
186 | 
187 |     def _setup_activation_hooks(self) -> None:
188 |         """Configure hooks for extracting model activations."""
189 |         try:
190 |             # For models that return features directly, no hooks needed
191 |             if self.model_type in [
192 |                 "resnet_fpn",
193 |                 "resnet50",
194 |                 "resnet101",
195 |                 "resnet101_fpn",
196 |             ]:
197 |                 logger.info(
198 |                     f"{self.model_type} detected: no hooks needed, activations will be captured in forward pass"
199 |                 )
200 |                 return
201 | 
202 |             # Determine layer indices to use based on model type
203 |             layer_indices = get_model_layer_indices(self.model_type)
204 |             layer_access = get_model_layer_access(self.model_type)
205 | 
206 |             for i, layer_idx in enumerate(layer_indices):
207 |                 layer_model = layer_access(self.model, layer_idx)
208 | 
209 |                 if layer_model is not None:
210 |                     layer_model.register_forward_hook(self.get_activations(i))
211 |                 else:
212 |                     logger.warning(
213 |                         f"Layer {layer_idx} not found in model or not accessible."
214 |                     )
215 |         except Exception as e:
216 |             logger.error(f"Error configuring hooks: {e}")
217 |             raise
218 | 
219 |     def get_activations(self, name: int):
220 |         """
221 |         Create a hook to capture activations from a specific layer.
222 | 
223 |         Args:
224 |             name: Layer name/index
225 | 
226 |         Returns:
227 |             Hook function to capture activations
228 |         """
229 | 
230 |         def hook(model, input, output):
231 |             try:
232 |                 device = input[0].device
233 |                 self.activations[name] = output.to(device)
234 |                 # Capture the spatial dimensions (H, W) of the layer
235 |                 self.spatial_dims[name] = output.shape[
236 |                     2
237 |                 ]  # Take H (or W, they are equal)
238 |             except Exception as e:
239 |                 logger.error(f"Error capturing activations: {e}")
240 |                 raise
241 | 
242 |         return hook
243 | 
244 |     def multiscale_loss(
245 |         self, features: Tuple[list, list], eps: float = 1e-6
246 |     ) -> torch.Tensor:
247 |         """
248 |         Compute a multiscale loss based on features inside and outside the mask.
249 |         Optimized version with vectorized operations where possible.
250 | 
251 |         Args:
252 |             features: Tuple containing (features_inside, features_outside) for each scale
253 |             eps: Small value to avoid division by zero
254 | 
255 |         Returns:
256 |             Computed multiscale loss
257 |         """
258 |         try:
259 |             features_inside, features_outside = features
260 |             nb_scales = len(features_inside)
261 |             batch_size = features_inside[0].shape[0]
262 |             energies = torch.zeros((nb_scales, batch_size), device=self.device)
263 | 
264 |             scale_contributions = torch.zeros(nb_scales, device=self.device)
265 | 
266 |             for j in range(nb_scales):
267 |                 diff = features_inside[j] - features_outside[j]
268 |                 norm_diff = torch.linalg.vector_norm(diff, 2, dim=-2)[..., 0]  # (B,)
269 |                 norm_activations = torch.linalg.vector_norm(
270 |                     self.activations[j], 2, dim=(1, 2, 3)
271 |                 )  # (B,)
272 | 
273 |                 norm_mse = -norm_diff / (norm_activations + eps)  # (B,)
274 |                 energies[j] = norm_mse
275 | 
276 |                 scale_contributions[j] = norm_diff.mean() / (
277 |                     norm_activations.mean() + eps
278 |                 )
279 | 
280 |             # ---- Dynamic weight computation ----
281 |             inv_contrib = 1.0 / (scale_contributions + eps)
282 |             dynamic_weights = inv_contrib / inv_contrib.sum()
283 |             dynamic_weights = dynamic_weights.view(nb_scales, 1)
284 |             return torch.sum(energies * dynamic_weights, dim=0)
285 |         except Exception as e:
286 |             logger.error(f"Error computing multiscale loss: {e}")
287 |             raise
288 | 
289 |     def predict(
290 |         self, img: torch.Tensor, contour_init: torch.Tensor
291 |     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
292 |         """
293 |         Predict contour for a given image and initial contour.
294 | 295 | Args: 296 | img: Input image tensor of shape (B, C, H, W) 297 | contour_init: Initial contour tensor of shape (B, 1, K, 2) 298 | 299 | Returns: 300 | Tuple containing: 301 | - contour_history: Contour history during prediction 302 | - loss_history: Loss values history 303 | - final_contours: Optimized final contours 304 | 305 | Raises: 306 | ValueError: If input tensors are invalid 307 | RuntimeError: If an error occurs during optimization 308 | """ 309 | try: 310 | self._validate_inputs(img, contour_init) 311 | 312 | self.device = contour_init.device 313 | self.img_dim = torch.tensor(img.shape[-2:], device=self.device) 314 | 315 | # Prepare data 316 | loss_history = np.zeros((contour_init.shape[0], self.n_epochs)) 317 | contour_history = [] 318 | 319 | self._setup_model_and_activations(img) 320 | 321 | contour, optimizer, lr_scheduler = self._setup_optimization(contour_init) 322 | 323 | self._setup_processing_components(img) 324 | 325 | contour_history, loss_history = self._run_optimization_loop( 326 | contour, optimizer, lr_scheduler, loss_history, contour_history 327 | ) 328 | 329 | final_contours = self._compute_final_contours(contour_history, loss_history) 330 | 331 | # Apply GrabCut if requested 332 | if self.do_apply_grabcut: 333 | logger.info("Applying GrabCut post-processing...") 334 | img_np = img.cpu().numpy() 335 | final_contours = apply_grabcut_postprocessing_parallel( 336 | img_np, final_contours 337 | ) 338 | 339 | logger.info("Prediction completed successfully") 340 | return ( 341 | np.roll(contour_history, axis=-1, shift=-1), 342 | loss_history, 343 | final_contours, 344 | ) 345 | 346 | except Exception as e: 347 | logger.error(f"Error during prediction: {e}") 348 | raise RuntimeError(f"Prediction failed: {e}") 349 | 350 | def _process_single_batch( 351 | self, imgs: torch.Tensor, contours_init: torch.Tensor 352 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 353 | """Process a single batch of images and contours.""" 354 | return self.predict(imgs, contours_init) 355 | 356 | def _merge_batch_results( 357 | self, results: list 358 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 359 | """Merge results from multiple batches.""" 360 | contour_histories = [] 361 | loss_histories = [] 362 | final_contours = [] 363 | 364 | for contour_history, loss_history, final_contour in results: 365 | contour_histories.append(contour_history) 366 | loss_histories.append(loss_history) 367 | final_contours.append(final_contour) 368 | 369 | return ( 370 | np.concatenate(contour_histories, axis=0), 371 | np.concatenate(loss_histories, axis=0), 372 | np.concatenate(final_contours, axis=0), 373 | ) 374 | 375 | def _validate_inputs(self, img: torch.Tensor, contour_init: torch.Tensor) -> None: 376 | """Validate input tensors.""" 377 | if img.dtype != torch.float32: 378 | raise ValueError("Image must be of type float32") 379 | if img.dim() != 4: 380 | raise ValueError("Image must have 4 dimensions (B, C, H, W)") 381 | if contour_init.dim() != 4: 382 | raise ValueError("Initial contour must have 4 dimensions (B, 1, K, 2)") 383 | if img.shape[0] != contour_init.shape[0]: 384 | raise ValueError("Image and contour batch sizes must match") 385 | 386 | def _setup_model_and_activations(self, img: torch.Tensor) -> None: 387 | """Configure model and extract activations.""" 388 | try: 389 | # Move model to appropriate device 390 | if str(self.device) == "cuda:0": 391 | self.model = self.model.cuda() 392 | elif str(self.device) == "mps:0": 393 | self.model = self.model.to(torch.device("mps")) 394 
|
395 |             # Get preprocessing function for the model type
396 |             preprocess_fn = get_model_preprocess(self.model_type)
397 | 
398 |             # Extract activations based on model type
399 |             with torch.no_grad():
400 |                 if self.model_type in [
401 |                     "resnet_fpn",
402 |                     "resnet50",
403 |                     "resnet101",
404 |                     "resnet101_fpn",
405 |                 ]:
406 |                     # For these models, forward returns multi-scale features directly
407 |                     activations = self.model(preprocess_fn(img))
408 |                     for i, activation in enumerate(activations):
409 |                         self.activations[i] = activation.to(self.device)
410 |                         # Capture the spatial dimensions for these models as well
411 |                         self.spatial_dims[i] = activation.shape[
412 |                             2
413 |                         ]  # Take H (or W, they are equal)
414 |                 else:
415 |                     # For other models (VGG), use normal forward pass
416 |                     _ = self.model(preprocess_fn(img))
417 | 
418 |         except Exception as e:
419 |             logger.error(f"Error configuring model: {e}")
420 |             raise
421 | 
422 |     def _setup_optimization(
423 |         self, contour_init: torch.Tensor
424 |     ) -> Tuple[torch.Tensor, Adam, ReduceLROnPlateau]:
425 |         """Configure optimization with improved learning rate scheduling."""
426 |         try:
427 |             contour = torch.roll(contour_init, dims=-1, shifts=1)
428 |             contour = contour.contiguous()
429 |             contour.requires_grad = True
430 | 
431 |             optimizer = Adam(
432 |                 [contour], lr=self.learning_rate, eps=1e-8, betas=(0.9, 0.999)
433 |             )
434 |             lr_scheduler = ReduceLROnPlateau(
435 |                 optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-6
436 |             )
437 | 
438 |             return contour, optimizer, lr_scheduler
439 | 
440 |         except Exception as e:
441 |             logger.error(f"Error configuring optimization: {e}")
442 |             raise
443 | 
444 |     def _setup_processing_components(self, img: torch.Tensor) -> None:
445 |         """Configure processing components."""
446 |         try:
447 |             self.ctf = Contour_to_features(img.shape[-1] // (2**2), self.activations)
448 | 
449 |             # Compute the weights dynamically from the actual spatial dimensions
450 |         except Exception as e:
451 |             logger.error(f"Error configuring processing components: {e}")
452 |             raise
453 | 
454 |     def _run_optimization_loop(
455 |         self,
456 |         contour: torch.Tensor,
457 |         optimizer: Adam,
458 |         lr_scheduler: ReduceLROnPlateau,
459 |         loss_history: np.ndarray,
460 |         contour_history: list,
461 |     ) -> Tuple[list, np.ndarray]:
462 |         """Execute main optimization loop with performance monitoring."""
463 |         try:
464 |             best_loss = float("inf")
465 |             patience_counter = 0
466 | 
467 |             logger.info("Starting contour evolution...")
468 | 
469 |             for i in tqdm(range(self.n_epochs), desc="Optimizing contour"):
470 |                 optimizer.zero_grad()
471 | 
472 |                 loss, batch_loss = self._compute_loss(contour)
473 | 
474 |                 self._backward_and_update(loss, contour, optimizer)
475 |                 lr_scheduler.step(loss)  # Use loss for scheduler step
476 | 
477 |                 contour = self._smooth_contour(contour)
478 | 
479 |                 contour_cleaned = self._save_history(
480 |                     contour, batch_loss, loss_history, contour_history, i
481 |                 )
482 | 
483 |                 optimizer.param_groups[0]["params"][0] = contour_cleaned
484 |                 contour = contour_cleaned
485 | 
486 |                 if self._check_early_stopping(batch_loss, best_loss, patience_counter):
487 |                     logger.info(f"Early stopping at epoch {i + 1}")
488 |                     break
489 | 
490 |                 best_loss, patience_counter = self._update_early_stopping_vars(
491 |                     batch_loss, best_loss, patience_counter
492 |                 )
493 | 
494 |             return contour_history, loss_history
495 | 
496 |         except Exception as e:
497 |             logger.error(f"Error during optimization loop: {e}")
498 |             raise
499 | 
500 |     def _compute_loss(self, contour: torch.Tensor) ->
Tuple[torch.Tensor, torch.Tensor]: 501 | """Compute loss for current step.""" 502 | try: 503 | if self.use_mixed_precision and str(self.device) == "cuda:0": 504 | with torch.cuda.amp.autocast(): 505 | features = self.ctf(contour) 506 | batch_loss = ( 507 | self.multiscale_loss(features) 508 | + self.lambda_area * area(contour)[:, 0] 509 | ) 510 | loss = self.img_dim[0] * torch.mean(batch_loss) 511 | else: 512 | features = self.ctf(contour) 513 | batch_loss = ( 514 | self.multiscale_loss(features) 515 | + self.lambda_area * area(contour)[:, 0] 516 | ) 517 | loss = self.img_dim[0] * torch.mean(batch_loss) 518 | 519 | return loss, batch_loss 520 | 521 | except Exception as e: 522 | logger.error(f"Error computing loss: {e}") 523 | raise 524 | 525 | def _backward_and_update( 526 | self, loss: torch.Tensor, contour: torch.Tensor, optimizer: Adam 527 | ) -> None: 528 | """Perform backward pass and parameter update.""" 529 | try: 530 | if self.use_mixed_precision and str(self.device) == "cuda:0": 531 | self.scaler.scale(loss).backward(inputs=contour) 532 | self.scaler.unscale_(optimizer) 533 | clip_grad_norm_(contour, self.clip) 534 | self.scaler.step(optimizer) 535 | self.scaler.update() 536 | else: 537 | loss.backward(inputs=contour) 538 | clip_grad_norm_(contour, self.clip) 539 | optimizer.step() 540 | 541 | except Exception as e: 542 | logger.error(f"Error during backward pass: {e}") 543 | raise 544 | 545 | def _smooth_contour(self, contour: torch.Tensor) -> torch.Tensor: 546 | """Apply smoothing to contour.""" 547 | try: 548 | contour_input = torch.clone(contour) 549 | return ( 550 | self.smooth((contour - contour_input).to(torch.float32)) + contour_input 551 | ) 552 | except Exception as e: 553 | logger.error(f"Error smoothing contour: {e}") 554 | raise 555 | 556 | def _save_history( 557 | self, 558 | contour: torch.Tensor, 559 | batch_loss: torch.Tensor, 560 | loss_history: np.ndarray, 561 | contour_history: list, 562 | epoch: int, 563 | ) -> torch.Tensor: 564 | """Save optimization history and return cleaned contour with optimized GPU memory management.""" 565 | try: 566 | with torch.no_grad(): 567 | # Force contiguous copy to avoid negative strides 568 | batch_loss_contiguous = batch_loss.contiguous() 569 | loss_history[:, epoch] = batch_loss_contiguous.cpu().detach().numpy() 570 | 571 | img_dims = self.img_dim.cpu().numpy() 572 | img_dims_xy = img_dims[::-1].copy() 573 | contour_scaled = ( 574 | contour 575 | * torch.tensor( 576 | img_dims_xy, device=self.device, dtype=torch.float32 577 | )[None, None, None] 578 | ) 579 | 580 | contour_scaled_contiguous = contour_scaled.contiguous() 581 | contour_history.append( 582 | contour_scaled_contiguous.cpu().detach().numpy().astype(np.int32) 583 | ) 584 | 585 | # Clean up GPU memory 586 | contour_np = contour.cpu().detach().numpy() 587 | contour_cleaned_np = self.cleaner.clean_contours_and_interpolate( 588 | contour_np 589 | ) 590 | contour_cleaned = torch.clip(torch.from_numpy(contour_cleaned_np), 0, 1) 591 | 592 | if str(self.device) == "cuda:0": 593 | contour_cleaned = contour_cleaned.to(torch.float32).cuda() 594 | elif str(self.device) == "mps:0": 595 | contour_cleaned = contour_cleaned.to(torch.float32).to( 596 | torch.device("mps") 597 | ) 598 | 599 | contour_cleaned.grad = None 600 | contour_cleaned.requires_grad = True 601 | 602 | # Clean up GPU memory 603 | self._cleanup_gpu_memory() 604 | 605 | return contour_cleaned 606 | 607 | except Exception as e: 608 | logger.error(f"Error saving history: {e}") 609 | raise 610 | 611 | def 
_check_early_stopping( 612 | self, batch_loss: torch.Tensor, best_loss: float, patience_counter: int 613 | ) -> bool: 614 | """Check if early stopping should be triggered with improved criteria.""" 615 | current_loss = torch.mean(batch_loss).item() 616 | 617 | loss_improvement = current_loss < best_loss - self.early_stopping_threshold 618 | if loss_improvement: 619 | best_loss = current_loss 620 | patience_counter = 0 621 | else: 622 | patience_counter += 1 623 | 624 | return patience_counter >= self.early_stopping_patience 625 | 626 | def _update_early_stopping_vars( 627 | self, batch_loss: torch.Tensor, best_loss: float, patience_counter: int 628 | ) -> Tuple[float, int]: 629 | """Update early stopping variables with improved logic.""" 630 | current_loss = torch.mean(batch_loss).item() 631 | if current_loss < best_loss - self.early_stopping_threshold: 632 | best_loss = current_loss 633 | patience_counter = 0 634 | else: 635 | patience_counter += 1 636 | return best_loss, patience_counter 637 | 638 | def _compute_final_contours( 639 | self, contour_history: list, loss_history: np.ndarray 640 | ) -> np.ndarray: 641 | """Compute optimized final contours.""" 642 | try: 643 | contour_history_array = np.roll( 644 | np.stack(contour_history), axis=-1, shift=-1 645 | )[:, :, 0] 646 | 647 | final_contours = np.zeros( 648 | ( 649 | loss_history.shape[0], 650 | contour_history_array.shape[-2], 651 | contour_history_array.shape[-1], 652 | ) 653 | ) 654 | 655 | for i, loss in enumerate(loss_history): 656 | try: 657 | # Remove NaN values from loss history 658 | valid_loss = loss[~np.isnan(loss)] 659 | if len(valid_loss) < 2: 660 | logger.warning( 661 | f"Not enough valid loss values for sample {i}, using last contour" 662 | ) 663 | final_contours[i] = contour_history_array[-1, i] 664 | continue 665 | 666 | p, _ = optimize.curve_fit( 667 | piecewise_linear, 668 | np.arange(len(valid_loss)), 669 | valid_loss, 670 | bounds=( 671 | np.array([0, -np.inf, -np.inf, -np.inf]), 672 | np.array([len(valid_loss), np.inf, np.inf, np.inf]), 673 | ), 674 | ) 675 | 676 | index_stop = int(p[0]) - 10 677 | index_stop = max(0, min(index_stop, len(contour_history_array) - 1)) 678 | final_contours[i] = contour_history_array[index_stop, i] 679 | 680 | except Exception as e: 681 | logger.warning(f"Error computing final contour for sample {i}: {e}") 682 | final_contours[i] = contour_history_array[-1, i] 683 | 684 | logger.info("Contour stopped") 685 | return final_contours 686 | 687 | except Exception as e: 688 | logger.error(f"Error computing final contours: {e}") 689 | raise 690 | -------------------------------------------------------------------------------- /algorithms/oneshot_dcf.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import warnings 4 | from pathlib import Path 5 | from typing import List, Optional, Tuple 6 | 7 | import cv2 8 | from scipy.ndimage import distance_transform_edt, label 9 | 10 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 11 | import numpy as np 12 | import torch 13 | import torchstain 14 | from torch_contour import CleanContours, Contour_to_distance_map, Smoothing, area 15 | from torchvision import models, transforms 16 | from tqdm import tqdm 17 | 18 | from utils import ( 19 | Contour_to_features, 20 | Contour_to_isoline_features, 21 | Distance_map_to_isoline_features, 22 | Mask_to_features, 23 | augmentation, 24 | ) 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | logger = 
logging.getLogger(__name__)
28 | 
29 | warnings.filterwarnings("ignore", category=UserWarning)
30 | 
31 | preprocess = transforms.Compose(
32 |     [
33 |         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
34 |     ]
35 | )
36 | 
37 | try:
38 |     vgg16 = models.vgg16(weights="DEFAULT")
39 |     VGG16 = vgg16.features.to(torch.float32)
40 | except Exception as e:
41 |     logger.error(f"Error loading VGG16 model: {e}")
42 |     raise
43 | 
44 | 
45 | class DCF:
46 |     def __init__(
47 |         self,
48 |         n_epochs: int = 100,
49 |         nb_augment: int = 100,
50 |         model: torch.nn.Module = VGG16,
51 |         sigma: float = 7,
52 |         learning_rate: float = 5e-2,
53 |         clip: float = 1e-1,
54 |         exponential_decay: float = 0.998,
55 |         thresh: float = 1e-2,
56 |         isolines: Optional[List[float]] = None,
57 |         isoline_weights: Optional[List[float]] = None,
58 |         lambda_area: float = 1e-4,
59 |         early_stopping_patience: int = 10,
60 |         early_stopping_threshold: float = 1e-6,
61 |         use_mixed_precision: bool = False,
62 |         device: Optional[str] = None,
63 |         do_apply_grabcut: bool = False,
64 |     ):
65 |         """
66 |         This class implements the one-shot version of DCF. It contains a fit and a predict step.
67 |         The fit step aims at capturing the features of the support image in the support contour.
68 |         The predict step aims at evolving an initial contour so that its features match those of the support as closely as possible.
69 | 
70 |         Parameters:
71 |         -----------
72 |         n_epochs : int
73 |             The maximum number of gradient descent steps during the predict step.
74 | 
75 |         nb_augment : int
76 |             The number of augmentations applied to the support image during the fitting step.
77 |             Note that if you want to apply your own augmentations please go to utils.py and modify the augmentation method.
78 | 
79 |         model: torch.nn.Module
80 |             The pretrained model from which the activations will be extracted.
81 |             This model can be any model as long as the activations have shape (B,C,H,W).
82 |             Note that for each model you choose to work with you will have to specify which activations of the model you want to use.
83 |             For example, if you are interested in the activations from the 3rd, 8th, 15th, 22nd and 29th layers of the model VGG16, please write
84 | 
85 |             >>> self.model[3].register_forward_hook(self.get_activations("0"))
86 |             >>> self.model[8].register_forward_hook(self.get_activations("1"))
87 |             >>> self.model[15].register_forward_hook(self.get_activations("2"))
88 |             >>> self.model[22].register_forward_hook(self.get_activations("3"))
89 |             >>> self.model[29].register_forward_hook(self.get_activations("4"))
90 | 
91 |             You don't have to use 5 layers but we do in the paper.
92 | 
93 |         sigma: float
94 |             The standard deviation of the Gaussian smoothing operator.
95 | 
96 |         learning_rate: float
97 |             The value of the gradient step.
98 | 
99 |         clip: float
100 |             The value used to clip the norm of the gradient of the contour so that it doesn't move too far.
101 | 
102 |         exponential_decay: float
103 |             The exponential decay of the learning_rate.
104 | 
105 |         thresh: float
106 |             If the maximum of the norm of the gradient of the contour over all nodes does not exceed thresh, we stop the contour evolution.
107 | 
108 |         isolines: List[float]
109 |             Values in the list must be in the range [0,1].
110 |             If provided, DCF will use the isolines centered on the values inside the list and use them to move the contour over time.
111 |             If None, DCF won't use any isoline and will move the contour using the aggregation of the features inside the mask corresponding to the contour.
112 | 113 | isoline_weights: List[float] 114 | The corresponding weights values w_i for each isoline in isolines when computing the loss. 115 | 116 | lambda_area: float 117 | Weight for the area constraint in the loss function. 118 | 119 | early_stopping_patience: int 120 | Number of epochs to wait before stopping if loss doesn't improve. 121 | 122 | early_stopping_threshold: float 123 | Minimum improvement threshold for early stopping. 124 | 125 | use_mixed_precision: bool 126 | Whether to use mixed precision training for faster computation. 127 | 128 | device: Optional[str] 129 | Device to use for computation. If None, will be automatically detected. 130 | """ 131 | 132 | super(DCF, self).__init__() 133 | 134 | self._validate_parameters( 135 | n_epochs, 136 | nb_augment, 137 | sigma, 138 | learning_rate, 139 | clip, 140 | exponential_decay, 141 | thresh, 142 | lambda_area, 143 | early_stopping_patience, 144 | early_stopping_threshold, 145 | ) 146 | 147 | self.n_epochs = n_epochs 148 | self.nb_augment = nb_augment 149 | self.model = model 150 | self.learning_rate = learning_rate 151 | self.clip = clip 152 | self.ed = exponential_decay 153 | self.thresh = thresh 154 | self.lambda_area = lambda_area 155 | self.early_stopping_patience = early_stopping_patience 156 | self.early_stopping_threshold = early_stopping_threshold 157 | self.use_mixed_precision = use_mixed_precision 158 | self.do_apply_grabcut = do_apply_grabcut 159 | 160 | # Device setup 161 | if device is None: 162 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 163 | else: 164 | self.device = torch.device(device) 165 | 166 | # Mixed precision setup 167 | if self.use_mixed_precision and self.device.type == "cuda": 168 | self.scaler = torch.cuda.amp.GradScaler() 169 | else: 170 | self.scaler = None 171 | 172 | # Enable optimizations 173 | if self.device.type == "cuda": 174 | torch.backends.cudnn.benchmark = True 175 | torch.backends.cudnn.deterministic = False 176 | 177 | self._initialize_components(sigma, isolines, isoline_weights) 178 | 179 | logger.info( 180 | f"DCF initialized with {n_epochs} epochs, lr={learning_rate}, device={self.device}" 181 | ) 182 | 183 | def _validate_parameters( 184 | self, 185 | n_epochs: int, 186 | nb_augment: int, 187 | sigma: float, 188 | learning_rate: float, 189 | clip: float, 190 | exponential_decay: float, 191 | thresh: float, 192 | lambda_area: float, 193 | early_stopping_patience: int, 194 | early_stopping_threshold: float, 195 | ) -> None: 196 | """Validate input parameters.""" 197 | if n_epochs <= 0: 198 | raise ValueError("n_epochs must be positive") 199 | if nb_augment <= 0: 200 | raise ValueError("nb_augment must be positive") 201 | if sigma <= 0: 202 | raise ValueError("sigma must be positive") 203 | if learning_rate <= 0: 204 | raise ValueError("learning_rate must be positive") 205 | if clip <= 0: 206 | raise ValueError("clip must be positive") 207 | if not 0 < exponential_decay < 1: 208 | raise ValueError("exponential_decay must be between 0 and 1") 209 | if thresh <= 0: 210 | raise ValueError("thresh must be positive") 211 | if lambda_area < 0: 212 | raise ValueError("lambda_area must be non-negative") 213 | if early_stopping_patience <= 0: 214 | raise ValueError("early_stopping_patience must be positive") 215 | if early_stopping_threshold <= 0: 216 | raise ValueError("early_stopping_threshold must be positive") 217 | 218 | def _initialize_components(self, sigma: float, isolines, isoline_weights) -> None: 219 | """Initialize algorithm components.""" 220 | 
try: 221 | self.activations = {} 222 | self.isolines = isolines 223 | self.isoline_weights = isoline_weights 224 | 225 | self._setup_activation_hooks() 226 | 227 | self.smooth = Smoothing(sigma) 228 | self.cleaner = CleanContours() 229 | 230 | # Initialize normalizer 231 | self.normalizer = torchstain.normalizers.MacenkoNormalizer(backend="torch") 232 | self.normalizer.HERef = np.array( 233 | [ 234 | [0.47262014, 0.17700575], 235 | [0.79697804, 0.84033483], 236 | [0.37610664, 0.51235373], 237 | ] 238 | ) 239 | self.normalizer.maxCRef = np.array([1.43072807, 0.98501085]) 240 | 241 | except Exception as e: 242 | logger.error(f"Error initializing components: {e}") 243 | raise 244 | 245 | def _setup_activation_hooks(self) -> None: 246 | """Configure hooks for extracting model activations.""" 247 | try: 248 | # VGG16 layers for multi-scale feature extraction 249 | layer_indices = [3, 8, 15, 22, 29] 250 | for i, layer_idx in enumerate(layer_indices): 251 | if hasattr(self.model, str(layer_idx)): 252 | self.model[layer_idx].register_forward_hook(self.get_activations(i)) 253 | else: 254 | logger.warning(f"Layer {layer_idx} not found in model") 255 | except Exception as e: 256 | logger.error(f"Error configuring hooks: {e}") 257 | raise 258 | 259 | def get_activations(self, name: int): 260 | """ 261 | Returns a hook function that stores the activations (outputs) of a layer in a dictionary 262 | under the given name. This hook is designed to be registered on a specific layer in a model, 263 | allowing you to capture its output (activations) during the forward pass. 264 | 265 | Parameters: 266 | ----------- 267 | name : int 268 | An integer that identifies the name/key under which the activations of the layer 269 | should be stored in `self.activations`. 270 | 271 | Returns: 272 | -------- 273 | hook : function 274 | A hook function that takes the model, input, and output as arguments. It captures 275 | the output (activations) of the layer and stores it in the `self.activations` 276 | dictionary, ensuring the tensor is moved to the correct device. 277 | """ 278 | 279 | def hook(model, input, output): 280 | """Hook function that captures the activations of the layer and stores them. 281 | 282 | Parameters: 283 | ----------- 284 | model : torch.nn.Module 285 | The layer from which activations are being captured. 286 | input : torch.Tensor 287 | Input to the layer. It's used here to get the device information. 288 | output : torch.Tensor 289 | The output (activations) of the layer, which will be stored in `self.activations`. 290 | 291 | Returns: 292 | -------- 293 | None 294 | """ 295 | try: 296 | device = input[0].device 297 | self.activations[name] = output.to(device) 298 | except Exception as e: 299 | logger.error(f"Error capturing activations: {e}") 300 | raise 301 | 302 | return hook 303 | 304 | def multi_scale_multi_isoline_loss( 305 | self, features_isolines: List[torch.Tensor] 306 | ) -> Tuple[torch.Tensor, torch.Tensor]: 307 | """ 308 | Computes the multi-scale multi-isoline loss between the query features and support features across multiple 309 | activation scales and isolines. This loss measures the difference between the isoline features at different 310 | scales and computes a weighted mean loss. 311 | 312 | Parameters: 313 | ----------- 314 | features_isolines : List[torch.Tensor] 315 | A list of feature isoline tensors for the query, where each tensor corresponds to a layer's 316 | feature isolines. Each tensor has shape `(B, N, C_i)`, where: 317 | - `B`: Number of samples in the batch. 
318 | - `N`: Number of isolines per sample. 319 | - `C_i`: Feature dimension for the respective layer. 320 | 321 | Returns: 322 | -------- 323 | loss_batch : torch.Tensor 324 | A 1D tensor of shape `(B,)` representing the total loss per sample, averaged across scales. 325 | 326 | loss_scales_isos_batch : torch.Tensor 327 | A 2D tensor of shape `(B, N)` representing the isoline-wise loss per sample, 328 | averaged across scales. 329 | """ 330 | try: 331 | batch_size = features_isolines[0].shape[0] 332 | num_activations = len(self.activations) 333 | 334 | loss_scales = torch.zeros((batch_size, num_activations), device=self.device) 335 | loss_scales_isos_batch = torch.zeros( 336 | (batch_size, num_activations, self.nb_iso), device=self.device 337 | ) 338 | 339 | if str(self.device) == "cuda:0": 340 | loss_scales, self.isoline_weights = ( 341 | loss_scales.cuda(), 342 | self.isoline_weights.cuda(), 343 | ) 344 | 345 | for j in range(num_activations): 346 | difference_features = ( 347 | features_isolines[j] - self.features_isolines_support[j] 348 | ) 349 | lsi = torch.sqrt(torch.norm(difference_features, dim=-2)) 350 | loss_scales_isos_batch[:, j] = lsi 351 | loss_scales[:, j] = torch.mean(self.isoline_weights * lsi, dim=-1) 352 | 353 | loss_batch = torch.mean(loss_scales, dim=-1) 354 | return loss_batch, loss_scales_isos_batch 355 | 356 | except Exception as e: 357 | logger.error(f"Error computing multi-scale multi-isoline loss: {e}") 358 | raise 359 | 360 | def fit(self, img_support: torch.Tensor, polygon_support: torch.Tensor) -> None: 361 | # Ensure input tensors are float32 362 | if img_support.dtype != torch.float32: 363 | img_support = img_support.float() 364 | if polygon_support.dtype != torch.float32: 365 | polygon_support = polygon_support.float() 366 | """ 367 | Fit the DCF model to the support image and contour. 
368 | 369 | Parameters: 370 | ----------- 371 | img_support : torch.Tensor 372 | Support image tensor of shape (B, C, H, W) 373 | polygon_support : torch.Tensor 374 | Support contour tensor of shape (B, K, 2) 375 | """ 376 | try: 377 | if img_support.dtype != torch.float32: 378 | raise ValueError("tensor must be of type float32") 379 | 380 | size = img_support.shape[-1] 381 | ctd = Contour_to_distance_map(size=size) 382 | distance_map_support, mask_support = ctd(polygon_support, return_mask=True) 383 | 384 | with torch.no_grad(): 385 | logger.info("Fitting DCF one shot...") 386 | 387 | self._move_model_to_device() 388 | 389 | if str(self.device) == "cuda:0": 390 | self.model = self.model.cuda() 391 | if self.isolines != None: 392 | self.isolines = torch.tensor( 393 | self.isolines, dtype=torch.float32 394 | ).cuda() 395 | self.isoline_weights = torch.tensor( 396 | self.isoline_weights, dtype=torch.float32 397 | ).cuda() 398 | 399 | for i in tqdm(range(self.nb_augment), desc="Augmenting support"): 400 | img_augmented, mask_augmented, distance_map_support_augmented = ( 401 | augmentation((img_support, mask_support, distance_map_support)) 402 | ) 403 | 404 | _ = self.model(preprocess(img_augmented)) 405 | 406 | if self.isolines is None: 407 | class_feature_extractor = Mask_to_features( 408 | self.activations 409 | ).requires_grad_(False) 410 | self.nb_iso = 1 411 | self.isoline_weights = torch.tensor(1.0, dtype=torch.float32) 412 | self.ctf = Contour_to_features(size // (2**2), self.activations) 413 | input_ = (mask_augmented,) 414 | else: 415 | self.nb_iso = self.isolines.shape[0] 416 | self.ctf = Contour_to_isoline_features( 417 | size // (2**2), 418 | self.activations, 419 | halfway_value=0.5, 420 | isolines=self.isolines, 421 | ) 422 | class_feature_extractor = Distance_map_to_isoline_features( 423 | self.activations, halfway_value=0.5, isolines=self.isolines 424 | ) 425 | input_ = (distance_map_support_augmented, mask_augmented) 426 | 427 | class_feature_extractor.compute_features_mask = True 428 | tmp, tmp_mask = class_feature_extractor(*input_) 429 | 430 | if i == 0: 431 | self.features_isolines_support = tmp 432 | self.features_mask_support = tmp_mask 433 | else: 434 | for j, (iso, mask) in enumerate(zip(tmp, tmp_mask)): 435 | self.features_isolines_support[j] += iso 436 | self.features_mask_support[j] += mask 437 | 438 | self.features_isolines_support = [ 439 | u / self.nb_augment for u in self.features_isolines_support 440 | ] 441 | self.features_anchor_mask = [ 442 | u / self.nb_augment for u in self.features_mask_support 443 | ] 444 | 445 | self.weights = torch.tensor( 446 | [1 / (2) ** i for i in range(len(self.activations))], 447 | dtype=torch.float32, 448 | device=self.device, 449 | ) 450 | self.weights = self.weights / torch.sum(self.weights) 451 | 452 | logger.info("DCF fitting completed successfully") 453 | 454 | except Exception as e: 455 | logger.error(f"Error during fitting: {e}") 456 | raise 457 | 458 | def similarity_score(self, features_mask_query: List[torch.Tensor]) -> torch.Tensor: 459 | """ 460 | Computes a similarity score between query feature masks and support feature masks 461 | using a weighted cosine similarity across multiple activation layers. 462 | 463 | Parameters: 464 | ----------- 465 | features_mask_query : List[torch.Tensor] 466 | A list of query feature mask tensors, where each tensor corresponds to a layer's 467 | feature representation. Each tensor has shape (B, C), where: 468 | - `B`: Batch size (number of query samples). 
469 | - `C`: Feature dimension for the respective layer. 470 | 471 | Returns: 472 | -------- 473 | torch.Tensor: 474 | A 1D tensor of shape (B,) representing the weighted similarity score for each 475 | sample in the batch across all layers. 476 | """ 477 | try: 478 | b = features_mask_query[0].shape[0] 479 | num_activations = len(self.activations) 480 | score = torch.zeros((num_activations, b), device=self.device) 481 | 482 | for i in range(num_activations): 483 | support_features = torch.squeeze(self.features_mask_support[i]) 484 | query_features = torch.squeeze(features_mask_query[i]).to(torch.float32) 485 | 486 | numerator = torch.sum(support_features * query_features, dim=-1) 487 | denominator = torch.linalg.norm( 488 | support_features, dim=-1 489 | ) * torch.linalg.norm(query_features, dim=-1) 490 | cos = self.weights[i] * numerator / (denominator + 1e-8) 491 | score[i] = torch.flatten(cos) 492 | 493 | return torch.mean(score, dim=0) 494 | 495 | except Exception as e: 496 | logger.error(f"Error computing similarity score: {e}") 497 | raise 498 | 499 | def predict( 500 | self, imgs_query: torch.Tensor, contours_query: torch.Tensor 501 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 502 | # Ensure input tensors are float32 503 | if imgs_query.dtype != torch.float32: 504 | imgs_query = imgs_query.float() 505 | if contours_query.dtype != torch.float32: 506 | contours_query = contours_query.float() 507 | """ 508 | Predicts contours for the query images using gradient descent-based optimization. 509 | This function refines the contours through several epochs, computes the loss and loss_scales_isos 510 | at each step, and returns the best contours based on the minimum loss. 511 | 512 | Parameters: 513 | ----------- 514 | imgs_query : torch.Tensor 515 | A batch of input query images of shape `(B, C, H, W)`. 516 | contours_query : torch.Tensor 517 | Initial contour points for the images of shape `(B, 1, K, 2)`. 518 | 519 | where - B is the batch size, 520 | - K is the number of nodes in each contour 521 | - C = 3 the number of channel 522 | - H the Height of the images 523 | - W the Width of the images 524 | 525 | Returns: 526 | -------- 527 | epochs_contours_query[argmin] : np.ndarray 528 | The predicted contours of shape `(B, K, 2)` that minimize the loss. 529 | scores : torch.Tensor 530 | Similarity scores for each contour, indicating how well the final query contours match 531 | the learned features of the support contour, with shape `(B,)`. 532 | losses : np.ndarray 533 | Losses recorded over epochs, shape `(N_epoch, B)`. 534 | loss_scales_isos : np.ndarray 535 | Energy values recorded over epochs, representing isoline-wise losses across layers, 536 | shape `(n_epochs, B, num_activations, num_isolines)`. 537 | """ 538 | try: 539 | if not hasattr(self, "ctf") or not hasattr( 540 | self, "features_isolines_support" 541 | ): 542 | raise RuntimeError( 543 | "Model must be fitted before prediction. Call fit() first." 544 | ) 545 | 546 | if not hasattr(self, "nb_iso"): 547 | raise RuntimeError( 548 | "nb_iso not defined. Model must be fitted before prediction." 
549 | ) 550 | 551 | b = contours_query.shape[0] 552 | self.nb_points = contours_query.shape[-2] 553 | batch_size, _, h, w = imgs_query.shape 554 | losses = np.zeros((self.n_epochs, batch_size)) 555 | epochs_contours_query = np.zeros( 556 | (self.n_epochs, batch_size, contours_query.shape[-2], 2) 557 | ) 558 | loss_scales_isos = np.zeros( 559 | (self.n_epochs, batch_size, len(self.activations), self.nb_iso) 560 | ) 561 | 562 | scale = torch.tensor([512.0, 512.0], dtype=torch.float32) / torch.tensor( 563 | [h, w], dtype=torch.float32 564 | ) 565 | 566 | self.img_dim = torch.tensor( 567 | imgs_query.shape[-2:], dtype=torch.float32, device=self.device 568 | ) 569 | 570 | contours_query_array = contours_query.cpu().detach().numpy() 571 | contours_query_array = self.cleaner.clean_contours_and_interpolate( 572 | contours_query_array 573 | ).clip(0, 1) 574 | contours_query = ( 575 | torch.from_numpy(np.roll(contours_query_array, axis=1, shift=1)) 576 | .float() 577 | .to(self.device) 578 | ) 579 | contours_query.requires_grad = True 580 | 581 | self._move_model_to_device() 582 | scale = scale.to(self.device) 583 | 584 | _ = self.model(preprocess(imgs_query)) 585 | 586 | logger.info("Contour is evolving please wait a few moments...") 587 | 588 | best_loss = float("inf") 589 | patience_counter = 0 590 | 591 | for i in tqdm(range(self.n_epochs), desc="Evolving contour"): 592 | if self.use_mixed_precision and self.scaler is not None: 593 | with torch.cuda.amp.autocast(): 594 | features_isoline_query, _ = self.ctf(contours_query) 595 | loss_batch, loss_scales_isos_batch = ( 596 | self.multi_scale_multi_isoline_loss(features_isoline_query) 597 | ) 598 | loss_all = b * torch.mean( 599 | loss_batch + self.lambda_area * area(contours_query)[:, 0] 600 | ) 601 | 602 | self.scaler.scale(loss_all).backward(inputs=contours_query) 603 | else: 604 | features_isoline_query, _ = self.ctf(contours_query) 605 | loss_batch, loss_scales_isos_batch = ( 606 | self.multi_scale_multi_isoline_loss(features_isoline_query) 607 | ) 608 | loss_all = b * torch.mean( 609 | loss_batch + self.lambda_area * area(contours_query)[:, 0] 610 | ) 611 | loss_all.backward(inputs=contours_query) 612 | 613 | current_loss = loss_all.item() 614 | losses[i] = current_loss 615 | epochs_contours_query[i] = contours_query.detach().cpu().numpy() 616 | loss_scales_isos[i] = loss_scales_isos_batch.detach().cpu().numpy() 617 | 618 | if current_loss < best_loss - self.early_stopping_threshold: 619 | best_loss = current_loss 620 | patience_counter = 0 621 | else: 622 | patience_counter += 1 623 | 624 | if patience_counter >= self.early_stopping_patience: 625 | logger.info(f"Early stopping at epoch {i+1}") 626 | break 627 | 628 | norm_grad = torch.unsqueeze(torch.norm(contours_query.grad, dim=-1), -1) 629 | clipped_norm = torch.clip(norm_grad, 0, self.clip) 630 | stop = (torch.amax(norm_grad[:, 0], dim=-2) < self.thresh)[-1] 631 | 632 | if not torch.all(stop): 633 | with torch.no_grad(): 634 | gradient_direction = ( 635 | contours_query.grad * clipped_norm / (norm_grad + 1e-8) 636 | ) 637 | gradient_direction = self.smooth(gradient_direction) 638 | contours_query = ( 639 | contours_query 640 | - scale 641 | * self.learning_rate 642 | * (self.ed**i) 643 | * gradient_direction 644 | ) 645 | interpolated_contour = self.cleaner.clean_contours_and_interpolate( 646 | contours_query.detach().cpu().numpy() 647 | ) 648 | contours_query = torch.clip( 649 | torch.from_numpy(interpolated_contour).float().to(self.device), 650 | 0, 651 | 1, 652 | ) 653 | 654 | 
contours_query.grad = None 655 | contours_query.requires_grad = True 656 | 657 | else: 658 | logger.info("The algorithm stopped earlier") 659 | break 660 | 661 | # Calculate score after gradient descent 662 | self.ctf.compute_features_mask = True 663 | _, features_mask_query = self.ctf(contours_query) 664 | 665 | scores = self.similarity_score(features_mask_query).cpu().detach().numpy() 666 | 667 | losses[losses == 0] = 1e10 668 | argmin = np.argmin(losses, axis=0) 669 | 670 | best_contours = epochs_contours_query[argmin] 671 | 672 | img_dims = np.array(self.img_dim.cpu().numpy()) 673 | img_dims_xy = img_dims[::-1] 674 | 675 | best_contours_scaled = best_contours * img_dims_xy[None, None, None] 676 | best_contours_final = best_contours_scaled.astype(np.int32) 677 | 678 | # Apply GrabCut if requested 679 | if self.do_apply_grabcut: 680 | logger.info("Applying GrabCut post-processing...") 681 | best_contours_final = self._apply_grabcut_postprocessing( 682 | imgs_query, best_contours_final 683 | ) 684 | 685 | logger.info("Prediction completed successfully") 686 | return best_contours_final, scores, losses, loss_scales_isos 687 | 688 | except Exception as e: 689 | logger.error(f"Error during prediction: {e}") 690 | raise RuntimeError(f"Prediction failed: {e}") 691 | 692 | def _get_model_device(self) -> torch.device: 693 | """Get the device of the model in a robust way.""" 694 | try: 695 | if list(self.model.parameters()): 696 | return next(self.model.parameters()).device 697 | elif list(self.model.buffers()): 698 | return next(self.model.buffers()).device 699 | else: 700 | return torch.device("cpu") 701 | except Exception: 702 | return torch.device("cpu") 703 | 704 | def _move_model_to_device(self) -> None: 705 | """Move the model to the target device if needed.""" 706 | try: 707 | current_device = self._get_model_device() 708 | if current_device != self.device: 709 | self.model = self.model.to(self.device) 710 | except Exception as e: 711 | logger.warning(f"Could not move model to device {self.device}: {e}") 712 | 713 | def _apply_grabcut_postprocessing( 714 | self, img: torch.Tensor, final_contours: np.ndarray 715 | ) -> np.ndarray: 716 | """ 717 | Apply GrabCut post-processing to refine the final contours. 
718 | 719 | Args: 720 | img: Input image tensor (B, C, H, W) 721 | final_contours: Final contours from DCF (B, K, 2) - already in image coordinates 722 | 723 | Returns: 724 | Refined contours after GrabCut processing 725 | """ 726 | try: 727 | refined_contours = [] 728 | 729 | for i in range(img.shape[0]): 730 | # Convert tensor to numpy 731 | img_np = img[i].cpu().numpy() 732 | img_np = np.moveaxis(img_np, 0, -1) # (C, H, W) -> (H, W, C) 733 | img_np = (img_np * 255).astype(np.uint8) 734 | 735 | # Get contour for this batch 736 | contour = final_contours[i] 737 | 738 | # In oneshot_dcf, contours are already in image coordinates 739 | # Just ensure they are within bounds 740 | h, w = img_np.shape[:2] 741 | contour = np.clip(contour, 0, [w - 1, h - 1]).astype(np.int32) 742 | 743 | # Create mask from contour 744 | mask = np.zeros((h, w), dtype=np.uint8) 745 | 746 | # Fix contour format for cv2.fillPoly 747 | if len(contour.shape) == 2: 748 | contour_for_fill = contour.reshape(-1, 1, 2) 749 | else: 750 | contour_for_fill = contour 751 | 752 | cv2.fillPoly(mask, [contour_for_fill], 1) 753 | 754 | # Debug: check if mask is valid 755 | logger.info( 756 | f"Sample {i}: contour shape={contour.shape}, range=[{contour.min()}, {contour.max()}], mask_sum={np.sum(mask)}" 757 | ) 758 | 759 | if np.sum(mask) == 0: 760 | logger.warning(f"Empty mask for sample {i}, skipping GrabCut") 761 | refined_contours.append(contour) 762 | continue 763 | 764 | # Apply GrabCut 765 | distance_map = distance_transform_edt(mask) 766 | distance_map = distance_map / np.max(distance_map) 767 | distance_map_outside = distance_transform_edt(1 - mask) 768 | distance_map_outside = distance_map_outside / np.max( 769 | distance_map_outside 770 | ) 771 | 772 | mask_grabcut = np.full(mask.shape, cv2.GC_PR_BGD, dtype=np.uint8) 773 | mask_grabcut[distance_map > 0.8] = cv2.GC_FGD 774 | mask_grabcut[(distance_map > 0.5) & (distance_map <= 0.8)] = ( 775 | cv2.GC_PR_FGD 776 | ) 777 | mask_grabcut[distance_map_outside > 0.8] = cv2.GC_BGD 778 | 779 | bgdModel = np.zeros((1, 65), np.float64) 780 | fgdModel = np.zeros((1, 65), np.float64) 781 | 782 | cv2.grabCut( 783 | img_np, 784 | mask_grabcut, 785 | None, 786 | bgdModel, 787 | fgdModel, 788 | 5, 789 | cv2.GC_INIT_WITH_MASK, 790 | ) 791 | 792 | result = np.where( 793 | (mask_grabcut == cv2.GC_FGD) | (mask_grabcut == cv2.GC_PR_FGD), 1, 0 794 | ).astype(np.uint8) 795 | 796 | # Get largest connected component 797 | labeled_array, num_features = label(result) 798 | if num_features > 0: 799 | largest_cc = np.argmax(np.bincount(labeled_array.flat)[1:]) + 1 800 | result = (labeled_array == largest_cc).astype(np.uint8) 801 | 802 | # Find contours from refined mask 803 | contours, _ = cv2.findContours( 804 | result, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE 805 | ) 806 | 807 | if contours: 808 | # Get the largest contour 809 | largest_contour = max(contours, key=cv2.contourArea) 810 | refined_contours.append(largest_contour.reshape(-1, 2)) 811 | else: 812 | # Fallback to original contour 813 | refined_contours.append(contour) 814 | 815 | logger.info("GrabCut post-processing completed") 816 | return np.array(refined_contours) 817 | 818 | except Exception as e: 819 | logger.error(f"Error in GrabCut post-processing: {e}") 820 | return final_contours # Return original contours if GrabCut fails 821 | --------------------------------------------------------------------------------
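Usage note: the two DCF classes above (algorithms/unsupervised_dcf.py and algorithms/oneshot_dcf.py) are driven the same way: construct the object, then call predict (the one-shot variant additionally requires fit on a support image first). Below is a minimal usage sketch for the unsupervised class only. It assumes the repository root is on sys.path, "some_image.jpg" is a placeholder path for a real RGB image, and the initial contour is passed in normalized [0, 1] coordinates with shape (B, 1, K, 2), as the clipping and rescaling in _save_history suggest; the exact coordinate ordering should be verified against torch_contour before relying on this sketch.

    import cv2
    import numpy as np
    import torch

    from algorithms.unsupervised_dcf import DCF
    from utils import define_contour_init

    # Load an RGB image and normalize it to [0, 1] ("some_image.jpg" is a placeholder).
    size = 512
    img = cv2.cvtColor(cv2.imread("some_image.jpg"), cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (size, size)).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(np.moveaxis(img, -1, 0))[None]  # (1, 3, H, W), float32

    # Build an elliptical initial contour with the helper from utils.py,
    # then rescale it to [0, 1] and reshape it to (B, 1, K, 2).
    contour_px, _ = define_contour_init(
        size, center=(size // 2, size // 2), axes=(size // 3, size // 3)
    )
    contour_init = torch.from_numpy(contour_px.astype(np.float32) / size)[None, None]

    # Run the unsupervised contour evolution.
    dcf = DCF(n_epochs=50, learning_rate=1e-1, do_apply_grabcut=False)
    contour_history, loss_history, final_contours = dcf.predict(img_tensor, contour_init)
    print(final_contours.shape)  # (B, K, 2), contours rescaled to pixel coordinates

The one-shot class follows the same pattern with one extra step: dcf.fit(img_support, polygon_support) aggregates the support features over augmentations, after which dcf.predict(imgs_query, contours_query) evolves the query contours and also returns a per-contour similarity score.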