├── README.md ├── cellseg_time_eval.py ├── classification ├── tmp ├── train_classification.py └── unsup_classification.py ├── classifiers.py ├── compute_metric.py ├── finetune_convnext_stardist.py ├── fintune_on_newdataset ├── classifiers.py ├── compute_metric.py ├── fintune.ipynb ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── convnext.cpython-37.pyc │ │ ├── convnext.cpython-38.pyc │ │ ├── flexible_unet.cpython-37.pyc │ │ ├── flexible_unet.cpython-38.pyc │ │ ├── flexible_unet_convext.cpython-37.pyc │ │ ├── flexible_unet_convext.cpython-38.pyc │ │ ├── flexible_unet_convnext.cpython-37.pyc │ │ ├── swin_unetr.cpython-38.pyc │ │ └── unetr2d.cpython-38.pyc │ ├── convnext.py │ ├── flexible_unet.py │ ├── flexible_unet_convext.py │ └── flexible_unet_convnext.py ├── overlay.py ├── stardist_pkg │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── big.cpython-37.pyc │ │ ├── bioimageio_utils.cpython-37.pyc │ │ ├── matching.cpython-37.pyc │ │ ├── nms.cpython-37.pyc │ │ ├── sample_patches.cpython-37.pyc │ │ ├── utils.cpython-37.pyc │ │ └── version.cpython-37.pyc │ ├── big.py │ ├── bioimageio_utils.py │ ├── geometry │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── geom2d.cpython-37.pyc │ │ │ └── geom3d.cpython-37.pyc │ │ └── geom2d.py │ ├── kernels │ │ ├── stardist2d.cl │ │ └── stardist3d.cl │ ├── matching.py │ ├── models │ │ ├── __init__.py │ │ ├── base.py │ │ └── model2d.py │ ├── nms.py │ ├── rays3d.py │ ├── sample_patches.py │ ├── utils.py │ └── version.py ├── train_classification.py ├── unsup_classification.py ├── utils.py └── utils_modify.py ├── license ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── convnext.cpython-37.pyc │ ├── convnext.cpython-38.pyc │ ├── flexible_unet.cpython-37.pyc │ ├── flexible_unet.cpython-38.pyc │ ├── flexible_unet_convext.cpython-37.pyc │ ├── flexible_unet_convext.cpython-38.pyc │ ├── flexible_unet_convnext.cpython-37.pyc │ ├── swin_unetr.cpython-38.pyc │ ├── tmp │ └── unetr2d.cpython-38.pyc ├── convnext.py ├── flexible_unet.py ├── flexible_unet_convext.py └── flexible_unet_convnext.py ├── overlay.py ├── predict.py ├── predict_unet_convnext.py ├── requirements.txt ├── stardist_pkg ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── big.cpython-37.pyc │ ├── bioimageio_utils.cpython-37.pyc │ ├── matching.cpython-37.pyc │ ├── nms.cpython-37.pyc │ ├── sample_patches.cpython-37.pyc │ ├── utils.cpython-37.pyc │ └── version.cpython-37.pyc ├── big.py ├── bioimageio_utils.py ├── geometry │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── geom2d.cpython-37.pyc │ │ └── geom3d.cpython-37.pyc │ └── geom2d.py ├── kernels │ ├── stardist2d.cl │ └── stardist3d.cl ├── matching.py ├── models │ ├── __init__.py │ ├── base.py │ └── model2d.py ├── nms.py ├── rays3d.py ├── sample_patches.py ├── utils.py └── version.py ├── train_convnext_hover..py ├── train_convnext_stardist.py ├── utils.py └── utils_modify.py /README.md: -------------------------------------------------------------------------------- 1 | # Solution of SRIBD-Med Team for NeurIPS2022-CellSeg Challenge 2 | 3 | BEFORE YOU RAISE AN ISSUE, PLEASE SEND YOUR QUESTIONS TO lhaof\@sribd.cn AND weilou\@link.cuhk.edu.cn 4 | 5 | Institution: Shenzhen Research Institute of Big Data (SRIBD, http://www.sribd.cn/) 6 | Authors: Wei Lou\*, Xinyi Yu\*, Chenyu Liu\*, Xiang Wan, Guanbin Li, Siqi Liu, Haofeng Li\# 
(http://haofengli.net/)
7 |
8 | This repository provides the solution of team SRIBD-Med for the [NeurIPS-CellSeg](https://neurips22-cellseg.grand-challenge.org/) Challenge. The details of our method are described in our paper [Multi-stream Cell Segmentation with Low-level Cues for Multi-modality Images]. Some parts of the code are adapted from the baseline code of the [NeurIPS-CellSeg-Baseline](https://github.com/JunMa11/NeurIPS-CellSeg) repository.
9 |
10 | You can reproduce our method step by step as follows:
11 |
12 | ## Environments and Requirements:
13 | Install the requirements by
14 |
15 | ```shell
16 | python -m pip install -r requirements.txt
17 | ```
18 |
19 | ## Dataset
20 | The competition training and tuning data can be downloaded from https://neurips22-cellseg.grand-challenge.org/dataset/
21 | In addition, you can download three public datasets from the following links:
22 | Cellpose: https://www.cellpose.org/dataset
23 | Omnipose: http://www.cellpose.org/dataset_omnipose
24 | Sartorius: https://www.kaggle.com/competitions/sartorius-cell-instance-segmentation/overview
25 |
26 | ## Automatic cell classification
27 | In this step, the images are automatically classified into four classes.
28 | Put all the images (competition + Cellpose + Omnipose + Sartorius) in one folder (data/allimages).
29 | Run the classification code:
30 |
31 | ```shell
32 | python classification/unsup_classification.py
33 | ```
34 | The results will be stored in data/classification_results/.
35 |
36 | ## CNN-based classification model training
37 | This step uses the classified images in data/classification_results/. Stay connected to the Internet, as the code may automatically download the necessary ImageNet-pretrained weights. A ResNet-18 is trained:
38 | ```shell
39 | python classification/train_classification.py
40 | ```
41 | ## Segmentation Training
42 | Pre-train the convnext-stardist model using all the images (data/allimages):
43 | ```shell
44 | python train_convnext_stardist.py
45 | ```
46 | For classes 0, 2 and 3, finetune on the classified data (taking class 1 as an example):
47 | ```shell
48 | python finetune_convnext_stardist.py model_dir=(The pretrained convnext-stardist model) data_dir='data/classification_results/class1'
49 | ```
50 | For class 1, train the convnext-hover model from scratch using the classified class 3 data:
51 | ```shell
52 | python train_convnext_hover.py data_dir='data/classification_results/class3'
53 | ```
54 |
55 | Finally, four segmentation models will be trained.
56 |
57 | ## Trained models
58 | The models can be downloaded from this link:
59 | https://drive.google.com/drive/folders/1MkEOpgmdkg5Yqw6Ng5PoOhtmo9xPPwIj?usp=sharing
60 |
61 | Pull the Docker environment:
62 | ```shell
63 | docker pull lewislou/sribd-cellseg:tagname
64 | ```
65 |
66 | ## Inference
67 | The inference process includes classification and segmentation.
68 | ```shell
69 | python predict.py -i input_path -o output_path --model_path './models'
70 | ```
71 | Colab code for model inference: https://colab.research.google.com/drive/1Dk6V6vm0IqaIevjAyjUTuR1nZfT6EvCh?usp=sharing
72 | ## Evaluation
73 | Calculate the F-score for evaluation:
74 | ```shell
75 | python compute_metric.py --gt_path path_to_labels --seg_path output_path
76 | ```
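For reference, the F-score computed by compute_metric.py (included later in this repository dump) is an instance-level F1 at an IoU threshold of 0.5: predicted and ground-truth cells are matched one-to-one, and a match with IoU >= 0.5 counts as a true positive. A minimal sketch of the final reduction from matched counts to the reported score (the helper name is ours, not part of the script):

```python
def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    """Instance-level F1 from true-positive, false-positive and false-negative cell counts."""
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)
```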
77 | ## Finetune on a new dataset
78 | We provide a Jupyter notebook that trains our model on a new dataset (Cellpose) step by step.
79 | The notebook is at fintune_on_newdataset/fintune.ipynb.
80 |
81 | ## Results
82 | The tuning set F1 score of our method is 0.8795. On our local workstation, the running time with tolerance of our method is 0 for all 101 cases in the tuning set, i.e., every case finishes within the time tolerance.
83 |
84 | ## Acknowledgement
85 | We thank the contributors of the public datasets and acknowledge the support of the Shenzhen Research Institute of Big Data (SRIBD, http://www.sribd.cn/).
86 |
87 |
-------------------------------------------------------------------------------- /cellseg_time_eval.py: --------------------------------------------------------------------------------
1 | """
2 | The code was adapted from the MICCAI FLARE Challenge
3 | https://flare22.grand-challenge.org/
4 |
5 | The testing images will be evaluated one by one.
6 | To compensate for the Docker container startup time, we give a time tolerance for the running time.
7 | https://neurips22-cellseg.grand-challenge.org/metrics/
8 | """
9 |
10 | import os
11 | join = os.path.join
12 | import sys
13 | import shutil
14 | import time
15 | import torch
16 | import argparse
17 | from collections import OrderedDict
18 | from skimage import io
19 | import tifffile as tif
20 | import numpy as np
21 | import pandas as pd
22 |
23 | parser = argparse.ArgumentParser('Segmentation efficiency evaluation for docker containers', add_help=False)
24 | parser.add_argument('-i', '--test_img_path', default='./val-imgs-30/', type=str, help='testing data path')
25 | parser.add_argument('-o','--save_path', default='./val_team_seg', type=str, help='segmentation output path')
26 | parser.add_argument('-d','--docker_folder_path', default='./team_docker', type=str, help='team docker path')
27 | args = parser.parse_args()
28 |
29 | test_img_path = args.test_img_path
30 | save_path = args.save_path
31 | docker_path = args.docker_folder_path
32 |
33 | input_temp = './inputs/'
34 | output_temp = './outputs'
35 | os.makedirs(save_path, exist_ok=True)
36 |
37 | dockers = sorted(os.listdir(docker_path))
38 | test_cases = sorted(os.listdir(test_img_path))
39 |
40 | for docker in dockers:
41 | try:
42 | # create temp folders for inference one-by-one
43 | if os.path.exists(input_temp):
44 | shutil.rmtree(input_temp)
45 | if os.path.exists(output_temp):
46 | shutil.rmtree(output_temp)
47 | os.makedirs(input_temp)
48 | os.makedirs(output_temp)
49 | # load docker and create a new folder to save segmentation results
50 | teamname = docker.split('.')[0].lower()
51 | print('teamname docker: ', docker)
52 | # os.system('docker image load < {}'.format(join(docker_path, docker)))
53 | team_outpath = join(save_path, teamname)
54 | if os.path.exists(team_outpath):
55 | shutil.rmtree(team_outpath)
56 | os.mkdir(team_outpath)
57 | metric = OrderedDict()
58 | metric['Img Name'] = []
59 | metric['Real Running Time'] = []
60 | metric['Rank Running Time'] = []
61 | # To obtain the running time for each case, we infer the testing cases one by one
62 | for case in test_cases:
63 | shutil.copy(join(test_img_path, case), input_temp)
64 | if case.endswith('.tif') or case.endswith('.tiff'):
65 | img = tif.imread(join(input_temp, case))
66 | else:
67 | img = io.imread(join(input_temp, case))
68 | pix_num = img.shape[0] * img.shape[1]
69 | cmd = 'docker container run --gpus="device=0" -m 28g --name {} --rm -v $PWD/inputs/:/workspace/inputs/ -v $PWD/outputs/:/workspace/outputs/ {}:latest /bin/bash -c "sh predict.sh" '.format(teamname, teamname)
70 | print(teamname, ' docker command:', cmd, '\n', 'testing image name:', case)
71 | start_time = time.time()
72 | os.system(cmd)
73 | real_running_time = time.time() - start_time
74 | print(f"{case} finished! Inference time: {real_running_time}")
75 | # save metrics
76 | metric['Img Name'].append(case)
77 | metric['Real Running Time'].append(real_running_time)
78 | if pix_num <= 1000000:
79 | rank_running_time = np.max([0, real_running_time-10])
80 | else:
81 | rank_running_time = np.max([0, real_running_time-10*(pix_num/1000000)])
82 | metric['Rank Running Time'].append(rank_running_time)
83 | os.remove(join(input_temp, case))
84 | seg_name = case.split('.')[0] + '_label.tiff'
85 | try:
86 | os.rename(join(output_temp, seg_name), join(team_outpath, seg_name))
87 | except:
88 | print(f"{join(output_temp, seg_name)}, {join(team_outpath, seg_name)}")
89 | print("Wrong segmentation name!!! It should be image_name.split(\'.\')[0] + \'_label.tiff\' ")
90 | metric_df = pd.DataFrame(metric)
91 | metric_df.to_csv(join(team_outpath, teamname + '_running_time.csv'), index=False)
92 | torch.cuda.empty_cache()
93 | # os.system("docker rmi {}:latest".format(teamname))
94 | shutil.rmtree(input_temp)
95 | shutil.rmtree(output_temp)
96 | except Exception as e:
97 | print(e)
98 |
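The 'Rank Running Time' column computed above is what the Results section of the README refers to: for each image, 10 s of tolerance is granted per 1,000,000 pixels (with a flat 10 s for images of up to one megapixel) before any time is counted against the submission. The same arithmetic restated as a standalone helper, for clarity (the function is illustrative, not part of the script):

```python
def rank_running_time(real_running_time: float, pix_num: int) -> float:
    """Running time after subtracting the size-dependent startup tolerance (floored at 0)."""
    tolerance = 10.0 if pix_num <= 1_000_000 else 10.0 * (pix_num / 1_000_000)
    return max(0.0, real_running_time - tolerance)
```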
-------------------------------------------------------------------------------- /classification/tmp: --------------------------------------------------------------------------------
1 |
2 |
-------------------------------------------------------------------------------- /classification/train_classification.py: --------------------------------------------------------------------------------
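The script below loads its data with torchvision.datasets.ImageFolder, which derives the four class labels from sub-folder names. The class folders 0/-3/ produced by unsup_classification.py therefore need to be arranged into the dataset/train and dataset/val directories the script expects, one sub-folder per class. An assumed layout (the train/val split itself is left to the user; only the one-folder-per-class structure is required):

```
dataset/
├── train/
│   ├── 0/   # binary / single-channel images
│   ├── 1/   # grayscale images
│   ├── 2/   # colored images with large cells
│   └── 3/   # colored images with small cells
└── val/
    ├── 0/
    ├── 1/
    ├── 2/
    └── 3/
```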
1 | import os, sys, glob, time, random, shutil, copy
2 | from tqdm import tqdm
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | import torchvision
7 | from torchvision import datasets, models, transforms
8 | import torch.utils.data as data
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | from torch.optim import lr_scheduler
12 | import torch.nn.functional as F
13 | from torchsummary import summary
14 | from matplotlib import pyplot as plt
15 | from torchvision.models import resnet18, ResNet18_Weights # torchvision's resnet18 (do not confuse with the local classifiers.resnet18)
16 | from PIL import Image, ImageFile
17 | from skimage import io
18 | ImageFile.LOAD_TRUNCATED_IMAGES = True
19 |
20 | # Set the train and validation directory paths
21 | train_directory = 'dataset/train'
22 | valid_directory = 'dataset/val'
23 |
24 | # Batch size
25 | bs = 64
26 | # Number of epochs
27 | num_epochs = 20
28 | # Number of classes
29 | num_classes = 4
30 | # Number of workers
31 | num_cpu = 8
32 |
33 | # Applying transforms to the data
34 | image_transforms = {
35 | 'train': transforms.Compose([
36 | transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
37 | transforms.RandomRotation(degrees=15),
38 | transforms.RandomHorizontalFlip(),
39 | transforms.CenterCrop(size=224),
40 | transforms.ToTensor(),
41 | transforms.Normalize([0.485, 0.456, 0.406],
42 | [0.229, 0.224, 0.225])
43 | ]),
44 | 'valid': transforms.Compose([
45 | transforms.Resize(size=256),
46 | transforms.CenterCrop(size=224),
47 | transforms.ToTensor(),
48 | transforms.Normalize([0.485, 0.456, 0.406],
49 | [0.229, 0.224, 0.225])
50 | ])
51 | }
52 |
53 | # Load data from folders
54 | dataset = {
55 | 'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),
56 | 'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid'])
57 | }
58 |
59 | # Size of train and validation data
60 | dataset_sizes = {
61 | 'train':len(dataset['train']),
62 | 'valid':len(dataset['valid'])
63 | }
64 |
65 | # Create iterators for data loading
66 | dataloaders = {
67 | 'train':data.DataLoader(dataset['train'], batch_size=bs, shuffle=True,
68 | num_workers=num_cpu, pin_memory=True, drop_last=False),
69 | 'valid':data.DataLoader(dataset['valid'], batch_size=bs, shuffle=False,
70 | num_workers=num_cpu, pin_memory=True, drop_last=False)
71 | }
72 |
73 | # Class names or target labels
74 | class_names = dataset['train'].classes
75 | print("Classes:", class_names)
76 |
77 | # Print the train and validation data sizes
78 | print("Training-set size:",dataset_sizes['train'],
79 | "\nValidation-set size:", dataset_sizes['valid'])
80 |
81 | modelname = 'resnet18'
82 |
83 | # Set default device as gpu, if available
84 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
85 |
86 | weights = ResNet18_Weights.DEFAULT
87 | model = resnet18(weights=weights) # load the ImageNet-pretrained weights mentioned in the README (was weights=None, which skipped them)
88 | num_ftrs = model.fc.in_features
89 | model.fc = nn.Linear(num_ftrs, num_classes)
90 |
91 |
92 | # Transfer the model to GPU
93 | model = model.to(device)
94 |
95 | # Print model summary
96 | print('Model Summary:-\n')
97 | for num, (name, param) in enumerate(model.named_parameters()):
98 | print(num, name, param.requires_grad)
99 | summary(model, input_size=(3, 224, 224))
100 |
101 | # Loss function
102 | criterion = nn.CrossEntropyLoss()
103 |
104 | # Optimizer
105 | optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
106 |
107 | # Learning rate decay
108 | scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
109 |
110 | since = time.time()
111 |
112 | best_model_wts = copy.deepcopy(model.state_dict())
113 | best_acc = 0.0
114 |
115 | for epoch in range(1, num_epochs+1):
116 | print('Epoch {}/{}'.format(epoch, num_epochs))
117 | print('-' * 10)
118 |
119 | # Each epoch has a training and validation phase
120 | for phase in ['train', 'valid']:
121 | if phase == 'train':
122 | model.train() # Set model to training mode
123 | else:
124 | model.eval() # Set model to evaluate mode
125 |
126 | running_loss = 0.0
127 | running_corrects = 0
128 |
129 | # Iterate over data.
130 | n = 0
131 | stream = tqdm(dataloaders[phase])
132 | for i, (inputs, labels) in enumerate(stream, start=1):
133 | inputs = inputs.to(device)
134 | labels = labels.to(device)
135 |
136 | # zero the parameter gradients
137 | optimizer.zero_grad()
138 |
139 | # forward
140 | # track history only if in train
141 | with torch.set_grad_enabled(phase == 'train'):
142 | outputs = model(inputs)
143 | _, preds = torch.max(outputs, 1)
144 | loss = criterion(outputs, labels)
145 |
146 | # backward + optimize only if in training phase
147 | if phase == 'train':
148 | loss.backward()
149 | optimizer.step()
150 |
151 | # statistics
152 | n += inputs.shape[0]
153 | running_loss += loss.item() * inputs.size(0)
154 | running_corrects += torch.sum(preds == labels.data)
155 |
156 | stream.set_description(f'Batch {i}/{len(dataloaders[phase])} | Loss: {running_loss/n:.4f}, Acc: {running_corrects/n:.4f}')
157 |
158 | if phase == 'train':
159 | scheduler.step()
160 |
161 | epoch_loss = running_loss / dataset_sizes[phase]
162 | epoch_acc = running_corrects.double() / dataset_sizes[phase]
163 |
164 | print('Epoch {}-{} Loss: {:.4f} Acc: {:.4f}'.format(
165 | epoch, phase, epoch_loss, epoch_acc))
166 |
167 | # deep copy the model
168 | if phase == 'valid' and epoch_acc >= best_acc:
169 | best_acc = epoch_acc
170 | best_model_wts = copy.deepcopy(model.state_dict())
171 | print('Update best model!')
172 |
173 | time_elapsed = time.time() - since
174 | print('Training complete in {:.0f}m {:.0f}s'.format(
175 | time_elapsed // 60, time_elapsed % 60))
176 | print('Best val Acc: {:4f}'.format(best_acc))
177 |
178 | # load best model weights
179 | model.load_state_dict(best_model_wts)
180 | torch.save(model, 'logs/resnet18_4class.pth')
181 | torch.save(model.state_dict(), 'logs/resnet18_4class.tar')
182 |
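The checkpoint saved above drives the classification stage of predict.py. A minimal sketch of loading it and classifying a single image (the preprocessing mirrors the 'valid' transform above; the image path is illustrative and the snippet is not part of the repository):

```python
import torch
from PIL import Image
from torchvision import transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load('logs/resnet18_4class.pth', map_location=device)  # full model saved above
model.eval()

preprocess = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open('data/allimages/example.png').convert('RGB')  # hypothetical input image
with torch.no_grad():
    logits = model(preprocess(img).unsqueeze(0).to(device))
pred_class = int(logits.argmax(dim=1))  # 0-3: selects which segmentation model to run
print(pred_class)
```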
-------------------------------------------------------------------------------- /classification/unsup_classification.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[1]:
5 |
6 |
7 | import os
8 | import numpy as np
9 | import shutil
10 | import torch
11 | import torch.nn
12 | import torchvision.models as models
13 | from torch.autograd import Variable
14 | import torch.cuda
15 | import torchvision.transforms as transforms
16 | from PIL import Image
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from sklearn.datasets import make_blobs
20 | from sklearn.cluster import KMeans
21 | from sklearn.metrics import silhouette_score
22 | from sklearn.preprocessing import StandardScaler
23 | from sklearn.metrics import pairwise_distances_argmin_min
24 | from scipy.spatial.distance import pdist, squareform
25 | from skimage import io, segmentation, morphology, exposure
26 | from skimage.color import rgb2hsv
27 | img_to_tensor = transforms.ToTensor()
28 | import random
29 | import tifffile as tif
30 | path = '/data1/partitionA/CUHKSZ/histopath_2022/grand_competition/Train_Labeled/images/'
31 | files = os.listdir(path)
32 | binary_path = '0/'
33 | gray_path = '1/'
34 | colored_path = 'colored/'
35 | os.makedirs(binary_path, exist_ok=True)
36 | os.makedirs(colored_path, exist_ok=True)
37 | os.makedirs(gray_path, exist_ok=True)
38 | for img_name in files:
39 | img_path = path + str(img_name)
40 | if img_name.endswith('.tif') or img_name.endswith('.tiff'):
41 | img_data = tif.imread(img_path)
42 | else:
43 | img_data = io.imread(img_path)
44 | if len(img_data.shape) == 2 or (len(img_data.shape) == 3 and img_data.shape[-1] == 1):
45 | shutil.copyfile(path + img_name, binary_path + img_name)
46 | elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
47 | shutil.copyfile(path + img_name, colored_path + img_name)
48 | else:
49 | hsv_img = rgb2hsv(img_data)
50 | s = hsv_img[:,:,1]
51 | v = hsv_img[:,:,2]
52 | print(img_name,s.mean(),v.mean())
53 | if s.mean() > 0.1 or (v.mean()<0.1 or v.mean() > 0.6): # saturated or very dark/bright images count as colored
54 | shutil.copyfile(path + img_name, colored_path + img_name)
55 | else:
56 | shutil.copyfile(path + img_name, gray_path + img_name)
57 |
58 |
59 |
60 | # In[3]:
61 |
62 |
63 | #### Phase 2: clustering by cell size
64 | from skimage import measure
65 | colored_path = 'colored/'
66 | label_path = 'allimages/tif/'
67 | big_path = '2/'
68 | small_path = '3/'
69 | files = os.listdir(colored_path)
70 | os.makedirs(big_path, exist_ok=True)
71 | os.makedirs(small_path, exist_ok=True)
72 | for img_name in files:
73 | label = tif.imread(label_path + img_name.split('.')[0]+'.tif')
74 | props = measure.regionprops(label)
75 | num_pix = []
76 | for idx in range(len(props)):
77 | num_pix.append(props[idx].area)
78 | max_area = max(num_pix)
79 | print(max_area)
80 | if max_area > 30000: # images whose largest cell exceeds 30k pixels go to the 'big' class
81 | shutil.copyfile(path + img_name, big_path + img_name)
82 | else:
83 | shutil.copyfile(path + img_name, small_path + img_name)
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
-------------------------------------------------------------------------------- /classifiers.py: --------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Any, Callable, List, Optional, Type, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch import Tensor
7 |
8 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
9 | """3x3 convolution with padding"""
10 | return nn.Conv2d(
11 | in_planes,
12 | out_planes,
13 | kernel_size=3,
14 | stride=stride,
15 | padding=dilation,
16 | groups=groups,
17 | bias=False,
18 | dilation=dilation,
19 | )
20 |
21 |
22 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
23 | """1x1 convolution"""
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
25 |
26 |
27 | class BasicBlock(nn.Module):
28 | expansion: int = 1
29 |
30 | def __init__(
31 | self,
32 | inplanes: int,
33 | planes: int,
34 | stride: int = 1,
35 | downsample: Optional[nn.Module] = None,
36 | groups: int = 1,
37 | base_width: int = 64,
38 | dilation: int = 1,
39 | norm_layer: Optional[Callable[..., nn.Module]] = None,
40 | ) -> None:
41 | super().__init__()
42 | if norm_layer is None:
43 | norm_layer = nn.BatchNorm2d
44 | if groups != 1 or base_width != 64:
45 | raise ValueError("BasicBlock only supports groups=1 and base_width=64")
46 | if dilation > 1:
47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
49 | self.conv1 = conv3x3(inplanes, planes, stride)
50 | self.bn1 = norm_layer(planes)
51 | self.relu = nn.ReLU(inplace=True)
52 | self.conv2 = conv3x3(planes, planes)
53 | self.bn2 = norm_layer(planes)
54 | self.downsample = downsample
55 | self.stride = stride
56 |
57 | def forward(self, x: Tensor) -> Tensor:
58 | identity = x
59 |
60 | out = self.conv1(x)
61 | out = self.bn1(out)
62 | out = self.relu(out)
63 |
64 | out = self.conv2(out)
65 | out = self.bn2(out)
66 |
67 | if self.downsample is not None:
68 | identity =
self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | class Bottleneck(nn.Module): 76 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 77 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 78 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 79 | # This variant is also known as ResNet V1.5 and improves accuracy according to 80 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 81 | 82 | expansion: int = 4 83 | 84 | def __init__( 85 | self, 86 | inplanes: int, 87 | planes: int, 88 | stride: int = 1, 89 | downsample: Optional[nn.Module] = None, 90 | groups: int = 1, 91 | base_width: int = 64, 92 | dilation: int = 1, 93 | norm_layer: Optional[Callable[..., nn.Module]] = None, 94 | ) -> None: 95 | super().__init__() 96 | if norm_layer is None: 97 | norm_layer = nn.BatchNorm2d 98 | width = int(planes * (base_width / 64.0)) * groups 99 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 100 | self.conv1 = conv1x1(inplanes, width) 101 | self.bn1 = norm_layer(width) 102 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 103 | self.bn2 = norm_layer(width) 104 | self.conv3 = conv1x1(width, planes * self.expansion) 105 | self.bn3 = norm_layer(planes * self.expansion) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.downsample = downsample 108 | self.stride = stride 109 | 110 | def forward(self, x: Tensor) -> Tensor: 111 | identity = x 112 | 113 | out = self.conv1(x) 114 | out = self.bn1(out) 115 | out = self.relu(out) 116 | 117 | out = self.conv2(out) 118 | out = self.bn2(out) 119 | out = self.relu(out) 120 | 121 | out = self.conv3(out) 122 | out = self.bn3(out) 123 | 124 | if self.downsample is not None: 125 | identity = self.downsample(x) 126 | 127 | out += identity 128 | out = self.relu(out) 129 | 130 | return out 131 | 132 | class ResNet(nn.Module): 133 | def __init__( 134 | self, 135 | block: Type[Union[BasicBlock, Bottleneck]], 136 | layers: List[int], 137 | num_classes: int = 1000, 138 | zero_init_residual: bool = False, 139 | groups: int = 1, 140 | width_per_group: int = 64, 141 | replace_stride_with_dilation: Optional[List[bool]] = None, 142 | norm_layer: Optional[Callable[..., nn.Module]] = None, 143 | ) -> None: 144 | super().__init__() 145 | # _log_api_usage_once(self) 146 | if norm_layer is None: 147 | norm_layer = nn.BatchNorm2d 148 | self._norm_layer = norm_layer 149 | 150 | self.inplanes = 64 151 | self.dilation = 1 152 | if replace_stride_with_dilation is None: 153 | # each element in the tuple indicates if we should replace 154 | # the 2x2 stride with a dilated convolution instead 155 | replace_stride_with_dilation = [False, False, False] 156 | if len(replace_stride_with_dilation) != 3: 157 | raise ValueError( 158 | "replace_stride_with_dilation should be None " 159 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 160 | ) 161 | self.groups = groups 162 | self.base_width = width_per_group 163 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 164 | self.bn1 = norm_layer(self.inplanes) 165 | self.relu = nn.ReLU(inplace=True) 166 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 167 | self.layer1 = self._make_layer(block, 64, layers[0]) 168 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 169 | 
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
170 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
171 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
172 | self.fc = nn.Linear(512 * block.expansion, num_classes)
173 |
174 | for m in self.modules():
175 | if isinstance(m, nn.Conv2d):
176 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
177 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
178 | nn.init.constant_(m.weight, 1)
179 | nn.init.constant_(m.bias, 0)
180 |
181 | # Zero-initialize the last BN in each residual branch,
182 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
183 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
184 | if zero_init_residual:
185 | for m in self.modules():
186 | if isinstance(m, Bottleneck) and m.bn3.weight is not None:
187 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
188 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
189 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
190 |
191 | def _make_layer(
192 | self,
193 | block: Type[Union[BasicBlock, Bottleneck]],
194 | planes: int,
195 | blocks: int,
196 | stride: int = 1,
197 | dilate: bool = False,
198 | ) -> nn.Sequential:
199 | norm_layer = self._norm_layer
200 | downsample = None
201 | previous_dilation = self.dilation
202 | if dilate:
203 | self.dilation *= stride
204 | stride = 1
205 | if stride != 1 or self.inplanes != planes * block.expansion:
206 | downsample = nn.Sequential(
207 | conv1x1(self.inplanes, planes * block.expansion, stride),
208 | norm_layer(planes * block.expansion),
209 | )
210 |
211 | layers = []
212 | layers.append(
213 | block(
214 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
215 | )
216 | )
217 | self.inplanes = planes * block.expansion
218 | for _ in range(1, blocks):
219 | layers.append(
220 | block(
221 | self.inplanes,
222 | planes,
223 | groups=self.groups,
224 | base_width=self.base_width,
225 | dilation=self.dilation,
226 | norm_layer=norm_layer,
227 | )
228 | )
229 |
230 | return nn.Sequential(*layers)
231 |
232 | def _forward_impl(self, x: Tensor) -> Tensor:
233 | # See note [TorchScript super()]
234 | x = self.conv1(x)
235 | x = self.bn1(x)
236 | x = self.relu(x)
237 | x = self.maxpool(x)
238 |
239 | x = self.layer1(x)
240 | x = self.layer2(x)
241 | x = self.layer3(x)
242 | x = self.layer4(x)
243 |
244 | x = self.avgpool(x)
245 | x = torch.flatten(x, 1)
246 | x = self.fc(x)
247 |
248 | return x
249 |
250 | def forward(self, x: Tensor) -> Tensor:
251 | return self._forward_impl(x)
252 |
253 | def resnet18(weights=None):
254 | # weights: path
255 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=4)
256 | if weights is not None:
257 | model.load_state_dict(torch.load(weights))
258 | return model
259 |
260 | def resnet10():
261 | return ResNet(BasicBlock, [1, 1, 1, 1], num_classes=4)
262 |
-------------------------------------------------------------------------------- /compute_metric.py: --------------------------------------------------------------------------------
1 | """
2 | Created on Thu Mar 31 18:10:52 2022
3 | adapted from https://github.com/stardist/stardist/blob/master/stardist/matching.py
4 | Thanks to the authors of StarDist for sharing their great code
5 |
6 | """
7 |
8 | import argparse
9 | import numpy as np
10 | from numba import
jit 11 | from scipy.optimize import linear_sum_assignment 12 | from collections import OrderedDict 13 | import pandas as pd 14 | from skimage import segmentation 15 | import tifffile as tif 16 | import os 17 | join = os.path.join 18 | from tqdm import tqdm 19 | 20 | def _intersection_over_union(masks_true, masks_pred): 21 | """ intersection over union of all mask pairs 22 | 23 | Parameters 24 | ------------ 25 | 26 | masks_true: ND-array, int 27 | ground truth masks, where 0=NO masks; 1,2... are mask labels 28 | masks_pred: ND-array, int 29 | predicted masks, where 0=NO masks; 1,2... are mask labels 30 | """ 31 | overlap = _label_overlap(masks_true, masks_pred) 32 | n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) 33 | n_pixels_true = np.sum(overlap, axis=1, keepdims=True) 34 | iou = overlap / (n_pixels_pred + n_pixels_true - overlap) 35 | iou[np.isnan(iou)] = 0.0 36 | return iou 37 | 38 | @jit(nopython=True) 39 | def _label_overlap(x, y): 40 | """ fast function to get pixel overlaps between masks in x and y 41 | 42 | Parameters 43 | ------------ 44 | 45 | x: ND-array, int 46 | where 0=NO masks; 1,2... are mask labels 47 | y: ND-array, int 48 | where 0=NO masks; 1,2... are mask labels 49 | 50 | Returns 51 | ------------ 52 | 53 | overlap: ND-array, int 54 | matrix of pixel overlaps of size [x.max()+1, y.max()+1] 55 | 56 | """ 57 | x = x.ravel() 58 | y = y.ravel() 59 | 60 | # preallocate a 'contact map' matrix 61 | overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint) 62 | 63 | # loop over the labels in x and add to the corresponding 64 | # overlap entry. If label A in x and label B in y share P 65 | # pixels, then the resulting overlap is P 66 | # len(x)=len(y), the number of pixels in the whole image 67 | for i in range(len(x)): 68 | overlap[x[i],y[i]] += 1 69 | return overlap 70 | 71 | def _true_positive(iou, th): 72 | """ true positive at threshold th 73 | 74 | Parameters 75 | ------------ 76 | 77 | iou: float, ND-array 78 | array of IOU pairs 79 | th: float 80 | threshold on IOU for positive label 81 | 82 | Returns 83 | ------------ 84 | 85 | tp: float 86 | number of true positives at threshold 87 | """ 88 | n_min = min(iou.shape[0], iou.shape[1]) 89 | costs = -(iou >= th).astype(float) - iou / (2*n_min) 90 | true_ind, pred_ind = linear_sum_assignment(costs) 91 | match_ok = iou[true_ind, pred_ind] >= th 92 | tp = match_ok.sum() 93 | return tp 94 | 95 | def eval_tp_fp_fn(masks_true, masks_pred, threshold=0.5): 96 | num_inst_gt = np.max(masks_true) 97 | num_inst_seg = np.max(masks_pred) 98 | if num_inst_seg>0: 99 | iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:] 100 | # for k,th in enumerate(threshold): 101 | tp = _true_positive(iou, threshold) 102 | fp = num_inst_seg - tp 103 | fn = num_inst_gt - tp 104 | else: 105 | print('No segmentation results!') 106 | tp = 0 107 | fp = 0 108 | fn = 0 109 | 110 | return tp, fp, fn 111 | 112 | def remove_boundary_cells(mask): 113 | W, H = mask.shape 114 | bd = np.ones((W, H)) 115 | bd[2:W-2, 2:H-2] = 0 116 | bd_cells = np.unique(mask*bd) 117 | for i in bd_cells[1:]: 118 | mask[mask==i] = 0 119 | new_label,_,_ = segmentation.relabel_sequential(mask) 120 | return new_label 121 | 122 | def main(): 123 | parser = argparse.ArgumentParser('Compute F1 score for cell segmentation results', add_help=False) 124 | # Dataset parameters 125 | parser.add_argument('--gt_path', type=str, help='path to ground truth; file names end with _label.tiff', required=True) 126 | parser.add_argument('--seg_path', type=str, help='path to segmentation 
results; file names are the same as ground truth', required=True) 127 | parser.add_argument('--save_path', default='./', help='path where to save metrics') 128 | args = parser.parse_args() 129 | 130 | gt_path = args.gt_path 131 | seg_path = args.seg_path 132 | names = sorted(os.listdir(seg_path)) 133 | seg_metric = OrderedDict() 134 | seg_metric['Names'] = [] 135 | seg_metric['F1_Score'] = [] 136 | for name in tqdm(names): 137 | assert name.endswith('_label.tiff'), 'The suffix of label name should be _label.tiff' 138 | 139 | # Load the images for this case 140 | gt = tif.imread(join(gt_path, name)) 141 | seg = tif.imread(join(seg_path, name)) 142 | 143 | # Score the cases 144 | # do not consider cells on the boundaries during evaluation 145 | if np.prod(gt.shape)<25000000: 146 | gt = remove_boundary_cells(gt.astype(np.int32)) 147 | seg = remove_boundary_cells(seg.astype(np.int32)) 148 | tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5) 149 | else: # for large images (>5000x5000), the F1 score is computed by a patch-based way 150 | H, W = gt.shape 151 | roi_size = 2000 152 | 153 | if H % roi_size != 0: 154 | n_H = H // roi_size + 1 155 | new_H = roi_size * n_H 156 | else: 157 | n_H = H // roi_size 158 | new_H = H 159 | 160 | if W % roi_size != 0: 161 | n_W = W // roi_size + 1 162 | new_W = roi_size * n_W 163 | else: 164 | n_W = W // roi_size 165 | new_W = W 166 | 167 | gt_pad = np.zeros((new_H, new_W), dtype=gt.dtype) 168 | seg_pad = np.zeros((new_H, new_W), dtype=gt.dtype) 169 | gt_pad[:H, :W] = gt 170 | seg_pad[:H, :W] = seg 171 | 172 | tp = 0 173 | fp = 0 174 | fn = 0 175 | for i in range(n_H): 176 | for j in range(n_W): 177 | gt_roi = remove_boundary_cells(gt_pad[roi_size*i:roi_size*(i+1), roi_size*j:roi_size*(j+1)]) 178 | seg_roi = remove_boundary_cells(seg_pad[roi_size*i:roi_size*(i+1), roi_size*j:roi_size*(j+1)]) 179 | tp_i, fp_i, fn_i = eval_tp_fp_fn(gt_roi, seg_roi, threshold=0.5) 180 | tp += tp_i 181 | fp += fp_i 182 | fn += fn_i 183 | 184 | if tp == 0: 185 | precision = 0 186 | recall = 0 187 | f1 = 0 188 | else: 189 | precision = tp / (tp + fp) 190 | recall = tp / (tp + fn) 191 | f1 = 2*(precision * recall)/ (precision + recall) 192 | seg_metric['Names'].append(name) 193 | seg_metric['F1_Score'].append(np.round(f1, 4)) 194 | 195 | 196 | seg_metric_df = pd.DataFrame(seg_metric) 197 | seg_metric_df.to_csv(join(args.save_path, 'seg_metric.csv'), index=False) 198 | print('mean F1 Score:', np.mean(seg_metric['F1_Score'])) 199 | 200 | if __name__ == '__main__': 201 | main() 202 | -------------------------------------------------------------------------------- /fintune_on_newdataset/classifiers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, List, Optional, Type, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d( 11 | in_planes, 12 | out_planes, 13 | kernel_size=3, 14 | stride=stride, 15 | padding=dilation, 16 | groups=groups, 17 | bias=False, 18 | dilation=dilation, 19 | ) 20 | 21 | 22 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion: int = 1 29 | 30 | def __init__( 31 
| self, 32 | inplanes: int, 33 | planes: int, 34 | stride: int = 1, 35 | downsample: Optional[nn.Module] = None, 36 | groups: int = 1, 37 | base_width: int = 64, 38 | dilation: int = 1, 39 | norm_layer: Optional[Callable[..., nn.Module]] = None, 40 | ) -> None: 41 | super().__init__() 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if groups != 1 or base_width != 64: 45 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 46 | if dilation > 1: 47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = norm_layer(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x: Tensor) -> Tensor: 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | class Bottleneck(nn.Module): 76 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 77 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 78 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 79 | # This variant is also known as ResNet V1.5 and improves accuracy according to 80 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 81 | 82 | expansion: int = 4 83 | 84 | def __init__( 85 | self, 86 | inplanes: int, 87 | planes: int, 88 | stride: int = 1, 89 | downsample: Optional[nn.Module] = None, 90 | groups: int = 1, 91 | base_width: int = 64, 92 | dilation: int = 1, 93 | norm_layer: Optional[Callable[..., nn.Module]] = None, 94 | ) -> None: 95 | super().__init__() 96 | if norm_layer is None: 97 | norm_layer = nn.BatchNorm2d 98 | width = int(planes * (base_width / 64.0)) * groups 99 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 100 | self.conv1 = conv1x1(inplanes, width) 101 | self.bn1 = norm_layer(width) 102 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 103 | self.bn2 = norm_layer(width) 104 | self.conv3 = conv1x1(width, planes * self.expansion) 105 | self.bn3 = norm_layer(planes * self.expansion) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.downsample = downsample 108 | self.stride = stride 109 | 110 | def forward(self, x: Tensor) -> Tensor: 111 | identity = x 112 | 113 | out = self.conv1(x) 114 | out = self.bn1(out) 115 | out = self.relu(out) 116 | 117 | out = self.conv2(out) 118 | out = self.bn2(out) 119 | out = self.relu(out) 120 | 121 | out = self.conv3(out) 122 | out = self.bn3(out) 123 | 124 | if self.downsample is not None: 125 | identity = self.downsample(x) 126 | 127 | out += identity 128 | out = self.relu(out) 129 | 130 | return out 131 | 132 | class ResNet(nn.Module): 133 | def __init__( 134 | self, 135 | block: Type[Union[BasicBlock, Bottleneck]], 136 | layers: List[int], 137 | num_classes: int = 1000, 138 | zero_init_residual: bool = False, 139 | groups: int = 1, 140 | width_per_group: int = 64, 141 | replace_stride_with_dilation: Optional[List[bool]] = None, 142 | norm_layer: Optional[Callable[..., 
nn.Module]] = None, 143 | ) -> None: 144 | super().__init__() 145 | # _log_api_usage_once(self) 146 | if norm_layer is None: 147 | norm_layer = nn.BatchNorm2d 148 | self._norm_layer = norm_layer 149 | 150 | self.inplanes = 64 151 | self.dilation = 1 152 | if replace_stride_with_dilation is None: 153 | # each element in the tuple indicates if we should replace 154 | # the 2x2 stride with a dilated convolution instead 155 | replace_stride_with_dilation = [False, False, False] 156 | if len(replace_stride_with_dilation) != 3: 157 | raise ValueError( 158 | "replace_stride_with_dilation should be None " 159 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 160 | ) 161 | self.groups = groups 162 | self.base_width = width_per_group 163 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 164 | self.bn1 = norm_layer(self.inplanes) 165 | self.relu = nn.ReLU(inplace=True) 166 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 167 | self.layer1 = self._make_layer(block, 64, layers[0]) 168 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 169 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 170 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 171 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 172 | self.fc = nn.Linear(512 * block.expansion, num_classes) 173 | 174 | for m in self.modules(): 175 | if isinstance(m, nn.Conv2d): 176 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 177 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 178 | nn.init.constant_(m.weight, 1) 179 | nn.init.constant_(m.bias, 0) 180 | 181 | # Zero-initialize the last BN in each residual branch, 182 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
183 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
184 | if zero_init_residual:
185 | for m in self.modules():
186 | if isinstance(m, Bottleneck) and m.bn3.weight is not None:
187 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
188 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
189 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
190 |
191 | def _make_layer(
192 | self,
193 | block: Type[Union[BasicBlock, Bottleneck]],
194 | planes: int,
195 | blocks: int,
196 | stride: int = 1,
197 | dilate: bool = False,
198 | ) -> nn.Sequential:
199 | norm_layer = self._norm_layer
200 | downsample = None
201 | previous_dilation = self.dilation
202 | if dilate:
203 | self.dilation *= stride
204 | stride = 1
205 | if stride != 1 or self.inplanes != planes * block.expansion:
206 | downsample = nn.Sequential(
207 | conv1x1(self.inplanes, planes * block.expansion, stride),
208 | norm_layer(planes * block.expansion),
209 | )
210 |
211 | layers = []
212 | layers.append(
213 | block(
214 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
215 | )
216 | )
217 | self.inplanes = planes * block.expansion
218 | for _ in range(1, blocks):
219 | layers.append(
220 | block(
221 | self.inplanes,
222 | planes,
223 | groups=self.groups,
224 | base_width=self.base_width,
225 | dilation=self.dilation,
226 | norm_layer=norm_layer,
227 | )
228 | )
229 |
230 | return nn.Sequential(*layers)
231 |
232 | def _forward_impl(self, x: Tensor) -> Tensor:
233 | # See note [TorchScript super()]
234 | x = self.conv1(x)
235 | x = self.bn1(x)
236 | x = self.relu(x)
237 | x = self.maxpool(x)
238 |
239 | x = self.layer1(x)
240 | x = self.layer2(x)
241 | x = self.layer3(x)
242 | x = self.layer4(x)
243 |
244 | x = self.avgpool(x)
245 | x = torch.flatten(x, 1)
246 | x = self.fc(x)
247 |
248 | return x
249 |
250 | def forward(self, x: Tensor) -> Tensor:
251 | return self._forward_impl(x)
252 |
253 | def resnet18(weights=None):
254 | # weights: path
255 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=4)
256 | if weights is not None:
257 | model.load_state_dict(torch.load(weights))
258 | return model
259 |
260 | def resnet10():
261 | return ResNet(BasicBlock, [1, 1, 1, 1], num_classes=4)
262 |
-------------------------------------------------------------------------------- /fintune_on_newdataset/compute_metric.py: --------------------------------------------------------------------------------
1 | """
2 | Created on Thu Mar 31 18:10:52 2022
3 | adapted from https://github.com/stardist/stardist/blob/master/stardist/matching.py
4 | Thanks to the authors of StarDist for sharing their great code
5 |
6 | """
7 |
8 | import argparse
9 | import numpy as np
10 | from numba import jit
11 | from scipy.optimize import linear_sum_assignment
12 | from collections import OrderedDict
13 | import pandas as pd
14 | from skimage import segmentation
15 | import tifffile as tif
16 | import os
17 | join = os.path.join
18 | from tqdm import tqdm
19 |
20 | def _intersection_over_union(masks_true, masks_pred):
21 | """ intersection over union of all mask pairs
22 |
23 | Parameters
24 | ------------
25 |
26 | masks_true: ND-array, int
27 | ground truth masks, where 0=NO masks; 1,2... are mask labels
28 | masks_pred: ND-array, int
29 | predicted masks, where 0=NO masks; 1,2...
are mask labels 30 | """ 31 | overlap = _label_overlap(masks_true, masks_pred) 32 | n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) 33 | n_pixels_true = np.sum(overlap, axis=1, keepdims=True) 34 | iou = overlap / (n_pixels_pred + n_pixels_true - overlap) 35 | iou[np.isnan(iou)] = 0.0 36 | return iou 37 | 38 | @jit(nopython=True) 39 | def _label_overlap(x, y): 40 | """ fast function to get pixel overlaps between masks in x and y 41 | 42 | Parameters 43 | ------------ 44 | 45 | x: ND-array, int 46 | where 0=NO masks; 1,2... are mask labels 47 | y: ND-array, int 48 | where 0=NO masks; 1,2... are mask labels 49 | 50 | Returns 51 | ------------ 52 | 53 | overlap: ND-array, int 54 | matrix of pixel overlaps of size [x.max()+1, y.max()+1] 55 | 56 | """ 57 | x = x.ravel() 58 | y = y.ravel() 59 | 60 | # preallocate a 'contact map' matrix 61 | overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint) 62 | 63 | # loop over the labels in x and add to the corresponding 64 | # overlap entry. If label A in x and label B in y share P 65 | # pixels, then the resulting overlap is P 66 | # len(x)=len(y), the number of pixels in the whole image 67 | for i in range(len(x)): 68 | overlap[x[i],y[i]] += 1 69 | return overlap 70 | 71 | def _true_positive(iou, th): 72 | """ true positive at threshold th 73 | 74 | Parameters 75 | ------------ 76 | 77 | iou: float, ND-array 78 | array of IOU pairs 79 | th: float 80 | threshold on IOU for positive label 81 | 82 | Returns 83 | ------------ 84 | 85 | tp: float 86 | number of true positives at threshold 87 | """ 88 | n_min = min(iou.shape[0], iou.shape[1]) 89 | costs = -(iou >= th).astype(float) - iou / (2*n_min) 90 | true_ind, pred_ind = linear_sum_assignment(costs) 91 | match_ok = iou[true_ind, pred_ind] >= th 92 | tp = match_ok.sum() 93 | return tp 94 | 95 | def eval_tp_fp_fn(masks_true, masks_pred, threshold=0.5): 96 | num_inst_gt = np.max(masks_true) 97 | num_inst_seg = np.max(masks_pred) 98 | if num_inst_seg>0: 99 | iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:] 100 | # for k,th in enumerate(threshold): 101 | tp = _true_positive(iou, threshold) 102 | fp = num_inst_seg - tp 103 | fn = num_inst_gt - tp 104 | else: 105 | print('No segmentation results!') 106 | tp = 0 107 | fp = 0 108 | fn = 0 109 | 110 | return tp, fp, fn 111 | 112 | def remove_boundary_cells(mask): 113 | W, H = mask.shape 114 | bd = np.ones((W, H)) 115 | bd[2:W-2, 2:H-2] = 0 116 | bd_cells = np.unique(mask*bd) 117 | for i in bd_cells[1:]: 118 | mask[mask==i] = 0 119 | new_label,_,_ = segmentation.relabel_sequential(mask) 120 | return new_label 121 | 122 | def main(): 123 | parser = argparse.ArgumentParser('Compute F1 score for cell segmentation results', add_help=False) 124 | # Dataset parameters 125 | parser.add_argument('--gt_path', type=str, help='path to ground truth; file names end with _label.tiff', required=True) 126 | parser.add_argument('--seg_path', type=str, help='path to segmentation results; file names are the same as ground truth', required=True) 127 | parser.add_argument('--save_path', default='./', help='path where to save metrics') 128 | args = parser.parse_args() 129 | 130 | gt_path = args.gt_path 131 | seg_path = args.seg_path 132 | names = sorted(os.listdir(seg_path)) 133 | seg_metric = OrderedDict() 134 | seg_metric['Names'] = [] 135 | seg_metric['F1_Score'] = [] 136 | for name in tqdm(names): 137 | assert name.endswith('_label.tiff'), 'The suffix of label name should be _label.tiff' 138 | 139 | # Load the images for this case 140 | gt = 
tif.imread(join(gt_path, name)) 141 | seg = tif.imread(join(seg_path, name)) 142 | 143 | # Score the cases 144 | # do not consider cells on the boundaries during evaluation 145 | if np.prod(gt.shape)<25000000: 146 | gt = remove_boundary_cells(gt.astype(np.int32)) 147 | seg = remove_boundary_cells(seg.astype(np.int32)) 148 | tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5) 149 | else: # for large images (>5000x5000), the F1 score is computed by a patch-based way 150 | H, W = gt.shape 151 | roi_size = 2000 152 | 153 | if H % roi_size != 0: 154 | n_H = H // roi_size + 1 155 | new_H = roi_size * n_H 156 | else: 157 | n_H = H // roi_size 158 | new_H = H 159 | 160 | if W % roi_size != 0: 161 | n_W = W // roi_size + 1 162 | new_W = roi_size * n_W 163 | else: 164 | n_W = W // roi_size 165 | new_W = W 166 | 167 | gt_pad = np.zeros((new_H, new_W), dtype=gt.dtype) 168 | seg_pad = np.zeros((new_H, new_W), dtype=gt.dtype) 169 | gt_pad[:H, :W] = gt 170 | seg_pad[:H, :W] = seg 171 | 172 | tp = 0 173 | fp = 0 174 | fn = 0 175 | for i in range(n_H): 176 | for j in range(n_W): 177 | gt_roi = remove_boundary_cells(gt_pad[roi_size*i:roi_size*(i+1), roi_size*j:roi_size*(j+1)]) 178 | seg_roi = remove_boundary_cells(seg_pad[roi_size*i:roi_size*(i+1), roi_size*j:roi_size*(j+1)]) 179 | tp_i, fp_i, fn_i = eval_tp_fp_fn(gt_roi, seg_roi, threshold=0.5) 180 | tp += tp_i 181 | fp += fp_i 182 | fn += fn_i 183 | 184 | if tp == 0: 185 | precision = 0 186 | recall = 0 187 | f1 = 0 188 | else: 189 | precision = tp / (tp + fp) 190 | recall = tp / (tp + fn) 191 | f1 = 2*(precision * recall)/ (precision + recall) 192 | seg_metric['Names'].append(name) 193 | seg_metric['F1_Score'].append(np.round(f1, 4)) 194 | 195 | 196 | seg_metric_df = pd.DataFrame(seg_metric) 197 | seg_metric_df.to_csv(join(args.save_path, 'seg_metric.csv'), index=False) 198 | print('mean F1 Score:', np.mean(seg_metric['F1_Score'])) 199 | 200 | if __name__ == '__main__': 201 | main() 202 | -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Mar 20 14:23:55 2022 5 | 6 | @author: jma 7 | """ 8 | 9 | -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__pycache__/convnext.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/convnext.cpython-37.pyc -------------------------------------------------------------------------------- 
/fintune_on_newdataset/models/__pycache__/convnext.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/convnext.cpython-38.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__pycache__/flexible_unet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/flexible_unet.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__pycache__/flexible_unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/flexible_unet.cpython-38.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__pycache__/flexible_unet_convext.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/flexible_unet_convext.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__pycache__/flexible_unet_convext.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/flexible_unet_convext.cpython-38.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__pycache__/flexible_unet_convnext.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/flexible_unet_convnext.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__pycache__/swin_unetr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/swin_unetr.cpython-38.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/models/__pycache__/unetr2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/models/__pycache__/unetr2d.cpython-38.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/models/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | from functools import partial 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import trunc_normal_, DropPath 13 | from timm.models.registry import register_model 14 | from monai.networks.layers.factories import Act, Conv, Pad, Pool 15 | from monai.networks.layers.utils import get_norm_layer 16 | from monai.utils.module import look_up_option 17 | from typing import List, NamedTuple, Optional, Tuple, Type, Union 18 | class Block(nn.Module): 19 | r""" ConvNeXt Block. There are two equivalent implementations: 20 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 21 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 22 | We use (2) as we find it slightly faster in PyTorch 23 | 24 | Args: 25 | dim (int): Number of input channels. 26 | drop_path (float): Stochastic depth rate. Default: 0.0 27 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 28 | """ 29 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 30 | super().__init__() 31 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 32 | self.norm = LayerNorm(dim, eps=1e-6) 33 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 34 | self.act = nn.GELU() 35 | self.pwconv2 = nn.Linear(4 * dim, dim) 36 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 37 | requires_grad=True) if layer_scale_init_value > 0 else None 38 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 39 | 40 | def forward(self, x): 41 | input = x 42 | x = self.dwconv(x) 43 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 44 | x = self.norm(x) 45 | x = self.pwconv1(x) 46 | x = self.act(x) 47 | x = self.pwconv2(x) 48 | if self.gamma is not None: 49 | x = self.gamma * x 50 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 51 | 52 | x = input + self.drop_path(x) 53 | return x 54 | 55 | class ConvNeXt(nn.Module): 56 | r""" ConvNeXt 57 | A PyTorch impl of : `A ConvNet for the 2020s` - 58 | https://arxiv.org/pdf/2201.03545.pdf 59 | 60 | Args: 61 | in_chans (int): Number of input image channels. Default: 3 62 | num_classes (int): Number of classes for classification head. Default: 1000 63 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 64 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 65 | drop_path_rate (float): Stochastic depth rate. Default: 0. 66 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 67 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
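Example (illustrative sketch added for clarity; assumes the default depths/dims and a 224x224 input):
    >>> m = ConvNeXt(in_chans=3)
    >>> feats = m(torch.randn(1, 3, 224, 224))
    >>> [tuple(f.shape) for f in feats]   # feature maps at strides 4/8/16/32
    [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]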
68 | """ 69 | def __init__(self, in_chans=3, num_classes=21841, 70 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 71 | layer_scale_init_value=1e-6, head_init_scale=1., out_indices=[0, 1, 2, 3], 72 | ): 73 | super().__init__() 74 | # conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv["conv", 2] 75 | # self._conv_stem = conv_type(self.in_channels, self.in_channels, kernel_size=3, stride=stride, bias=False) 76 | # self._conv_stem_padding = _make_same_padder(self._conv_stem, current_image_size) 77 | 78 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 79 | stem = nn.Sequential( 80 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 81 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 82 | ) 83 | self.downsample_layers.append(stem) 84 | for i in range(3): 85 | downsample_layer = nn.Sequential( 86 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 87 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 88 | ) 89 | self.downsample_layers.append(downsample_layer) 90 | 91 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 92 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 93 | cur = 0 94 | for i in range(4): 95 | stage = nn.Sequential( 96 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 97 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 98 | ) 99 | self.stages.append(stage) 100 | cur += depths[i] 101 | 102 | 103 | self.out_indices = out_indices 104 | 105 | norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first") 106 | for i_layer in range(4): 107 | layer = norm_layer(dims[i_layer]) 108 | layer_name = f'norm{i_layer}' 109 | self.add_module(layer_name, layer) 110 | self.apply(self._init_weights) 111 | 112 | 113 | def _init_weights(self, m): 114 | if isinstance(m, (nn.Conv2d, nn.Linear)): 115 | trunc_normal_(m.weight, std=.02) 116 | nn.init.constant_(m.bias, 0) 117 | 118 | def forward_features(self, x): 119 | outs = [] 120 | 121 | for i in range(4): 122 | x = self.downsample_layers[i](x) 123 | x = self.stages[i](x) 124 | if i in self.out_indices: 125 | norm_layer = getattr(self, f'norm{i}') 126 | x_out = norm_layer(x) 127 | 128 | outs.append(x_out) 129 | 130 | return tuple(outs) 131 | 132 | def forward(self, x): 133 | x = self.forward_features(x) 134 | 135 | return x 136 | 137 | class LayerNorm(nn.Module): 138 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 139 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 140 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 141 | with shape (batch_size, channels, height, width). 
142 | """ 143 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 144 | super().__init__() 145 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 146 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 147 | self.eps = eps 148 | self.data_format = data_format 149 | if self.data_format not in ["channels_last", "channels_first"]: 150 | raise NotImplementedError 151 | self.normalized_shape = (normalized_shape, ) 152 | 153 | def forward(self, x): 154 | if self.data_format == "channels_last": 155 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 156 | elif self.data_format == "channels_first": 157 | u = x.mean(1, keepdim=True) 158 | s = (x - u).pow(2).mean(1, keepdim=True) 159 | x = (x - u) / torch.sqrt(s + self.eps) 160 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 161 | return x 162 | 163 | 164 | model_urls = { 165 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 166 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 167 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 168 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 169 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 170 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 171 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 172 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 173 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 174 | } 175 | 176 | @register_model 177 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs): 178 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 179 | if pretrained: 180 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 181 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 182 | model.load_state_dict(checkpoint["model"]) 183 | return model 184 | 185 | @register_model 186 | def convnext_small(pretrained=False,in_22k=False, **kwargs): 187 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 188 | if pretrained: 189 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] 190 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 191 | model.load_state_dict(checkpoint["model"], strict=False) 192 | return model 193 | 194 | @register_model 195 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 196 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 197 | if pretrained: 198 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 199 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 200 | model.load_state_dict(checkpoint["model"], strict=False) 201 | return model 202 | 203 | @register_model 204 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 205 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 206 | if pretrained: 207 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 208 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 209 | 
model.load_state_dict(checkpoint["model"]) 210 | return model 211 | 212 | @register_model 213 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 214 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 215 | if pretrained: 216 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 217 | url = model_urls['convnext_xlarge_22k'] 218 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 219 | model.load_state_dict(checkpoint["model"]) 220 | return model -------------------------------------------------------------------------------- /fintune_on_newdataset/models/flexible_unet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from typing import List, Optional, Sequence, Tuple, Union 13 | 14 | import torch 15 | from torch import nn 16 | 17 | from monai.networks.blocks import UpSample 18 | from monai.networks.layers.factories import Conv 19 | from monai.networks.layers.utils import get_act_layer 20 | from monai.networks.nets import EfficientNetBNFeatures 21 | from monai.networks.nets.basic_unet import UpCat 22 | from monai.utils import InterpolateMode 23 | 24 | __all__ = ["FlexibleUNet"] 25 | 26 | encoder_feature_channel = { 27 | "efficientnet-b0": (16, 24, 40, 112, 320), 28 | "efficientnet-b1": (16, 24, 40, 112, 320), 29 | "efficientnet-b2": (16, 24, 48, 120, 352), 30 | "efficientnet-b3": (24, 32, 48, 136, 384), 31 | "efficientnet-b4": (24, 32, 56, 160, 448), 32 | "efficientnet-b5": (24, 40, 64, 176, 512), 33 | "efficientnet-b6": (32, 40, 72, 200, 576), 34 | "efficientnet-b7": (32, 48, 80, 224, 640), 35 | "efficientnet-b8": (32, 56, 88, 248, 704), 36 | "efficientnet-l2": (72, 104, 176, 480, 1376), 37 | } 38 | 39 | 40 | def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple: 41 | """ 42 | Get the encoder output channels by given backbone name. 43 | 44 | Args: 45 | backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2]. 46 | in_channels: channel of input tensor, default to 3. 47 | 48 | Returns: 49 | A tuple containing `in_channels` followed by the encoder output channels. 50 | """ 51 | encoder_channel_tuple = encoder_feature_channel[backbone] 52 | encoder_channel_list = [in_channels] + list(encoder_channel_tuple) 53 | encoder_channel = tuple(encoder_channel_list) 54 | return encoder_channel 55 | 56 | 57 | class UNetDecoder(nn.Module): 58 | """ 59 | UNet Decoder. 60 | This class refers to `segmentation_models.pytorch 61 | <https://github.com/qubvel/segmentation_models.pytorch>`_. 62 | 63 | Args: 64 | spatial_dims: number of spatial dimensions. 65 | encoder_channels: number of output channels for all feature maps in encoder. 66 | `len(encoder_channels)` should be no less than 2. 67 | decoder_channels: number of output channels for all feature maps in decoder. 68 | `len(decoder_channels)` should be equal to `len(encoder_channels) - 1`.
69 | act: activation type and arguments. 70 | norm: feature normalization type and arguments. 71 | dropout: dropout ratio. 72 | bias: whether to have a bias term in convolution blocks in this decoder. 73 | upsample: upsampling mode, available options are 74 | ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. 75 | pre_conv: a conv block applied before upsampling. 76 | Only used in the "nontrainable" or "pixelshuffle" mode. 77 | interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} 78 | Only used in the "nontrainable" mode. 79 | align_corners: set the align_corners parameter for upsample. Defaults to True. 80 | Only used in the "nontrainable" mode. 81 | is_pad: whether to pad upsampling features to fit the encoder spatial dims. 82 | 83 | """ 84 | 85 | def __init__( 86 | self, 87 | spatial_dims: int, 88 | encoder_channels: Sequence[int], 89 | decoder_channels: Sequence[int], 90 | act: Union[str, tuple], 91 | norm: Union[str, tuple], 92 | dropout: Union[float, tuple], 93 | bias: bool, 94 | upsample: str, 95 | pre_conv: Optional[str], 96 | interp_mode: str, 97 | align_corners: Optional[bool], 98 | is_pad: bool, 99 | ): 100 | 101 | super().__init__() 102 | if len(encoder_channels) < 2: 103 | raise ValueError("the length of `encoder_channels` should be no less than 2.") 104 | if len(decoder_channels) != len(encoder_channels) - 1: 105 | raise ValueError("`len(decoder_channels)` should be equal to `len(encoder_channels) - 1`.") 106 | 107 | in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1]) 108 | skip_channels = list(encoder_channels[1:-1][::-1]) + [0] 109 | halves = [True] * (len(skip_channels) - 1) 110 | halves.append(False) 111 | blocks = [] 112 | for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves): 113 | blocks.append( 114 | UpCat( 115 | spatial_dims=spatial_dims, 116 | in_chns=in_chn, 117 | cat_chns=skip_chn, 118 | out_chns=out_chn, 119 | act=act, 120 | norm=norm, 121 | dropout=dropout, 122 | bias=bias, 123 | upsample=upsample, 124 | pre_conv=pre_conv, 125 | interp_mode=interp_mode, 126 | align_corners=align_corners, 127 | halves=halve, 128 | is_pad=is_pad, 129 | ) 130 | ) 131 | self.blocks = nn.ModuleList(blocks) 132 | 133 | def forward(self, features: List[torch.Tensor], skip_connect: int = 4): 134 | skips = features[:-1][::-1] 135 | features = features[1:][::-1] 136 | 137 | x = features[0] 138 | for i, block in enumerate(self.blocks): 139 | if i < skip_connect: 140 | skip = skips[i] 141 | else: 142 | skip = None 143 | x = block(x, skip) 144 | 145 | return x 146 | 147 | 148 | class SegmentationHead(nn.Sequential): 149 | """ 150 | Segmentation head. 151 | This class refers to `segmentation_models.pytorch 152 | <https://github.com/qubvel/segmentation_models.pytorch>`_. 153 | 154 | Args: 155 | spatial_dims: number of spatial dimensions. 156 | in_channels: number of input channels for the block. 157 | out_channels: number of output channels for the block. 158 | kernel_size: kernel size for the conv layer. 159 | act: activation type and arguments. 160 | scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.
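Example (illustrative sketch added for clarity; the argument values are assumptions):
    >>> head = SegmentationHead(spatial_dims=2, in_channels=16, out_channels=1, act="sigmoid")
    >>> prob = head(torch.randn(1, 16, 128, 128))   # -> (1, 1, 128, 128), values in [0, 1]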
161 | 162 | """ 163 | 164 | def __init__( 165 | self, 166 | spatial_dims: int, 167 | in_channels: int, 168 | out_channels: int, 169 | kernel_size: int = 3, 170 | act: Optional[Union[Tuple, str]] = None, 171 | scale_factor: float = 1.0, 172 | ): 173 | 174 | conv_layer = Conv[Conv.CONV, spatial_dims]( 175 | in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2 176 | ) 177 | up_layer: nn.Module = nn.Identity() 178 | if scale_factor > 1.0: 179 | up_layer = UpSample( 180 | spatial_dims=spatial_dims, 181 | scale_factor=scale_factor, 182 | mode="nontrainable", 183 | pre_conv=None, 184 | interp_mode=InterpolateMode.LINEAR, 185 | ) 186 | if act is not None: 187 | act_layer = get_act_layer(act) 188 | else: 189 | act_layer = nn.Identity() 190 | super().__init__(conv_layer, up_layer, act_layer) 191 | 192 | 193 | class FlexibleUNet(nn.Module): 194 | """ 195 | A flexible implementation of UNet-like encoder-decoder architecture. 196 | """ 197 | 198 | def __init__( 199 | self, 200 | in_channels: int, 201 | out_channels: int, 202 | backbone: str, 203 | pretrained: bool = False, 204 | decoder_channels: Tuple = (256, 128, 64, 32, 16), 205 | spatial_dims: int = 2, 206 | norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}), 207 | act: Union[str, tuple] = ("relu", {"inplace": True}), 208 | dropout: Union[float, tuple] = 0.0, 209 | decoder_bias: bool = False, 210 | upsample: str = "nontrainable", 211 | interp_mode: str = "nearest", 212 | is_pad: bool = True, 213 | ) -> None: 214 | """ 215 | A flexible implement of UNet, in which the backbone/encoder can be replaced with 216 | any efficient network. Currently the input must have a 2 or 3 spatial dimension 217 | and the spatial size of each dimension must be a multiple of 32 if is pad parameter 218 | is False 219 | 220 | Args: 221 | in_channels: number of input channels. 222 | out_channels: number of output channels. 223 | backbone: name of backbones to initialize, only support efficientnet right now, 224 | can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2]. 225 | pretrained: whether to initialize pretrained ImageNet weights, only available 226 | for spatial_dims=2 and batch norm is used, default to False. 227 | decoder_channels: number of output channels for all feature maps in decoder. 228 | `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default 229 | to (256, 128, 64, 32, 16). 230 | spatial_dims: number of spatial dimensions, default to 2. 231 | norm: normalization type and arguments, default to ("batch", {"eps": 1e-3, 232 | "momentum": 0.1}). 233 | act: activation type and arguments, default to ("relu", {"inplace": True}). 234 | dropout: dropout ratio, default to 0.0. 235 | decoder_bias: whether to have a bias term in decoder's convolution blocks. 236 | upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``, 237 | ``"nontrainable"``. 238 | interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} 239 | Only used in the "nontrainable" mode. 240 | is_pad: whether to pad upsampling features to fit features from encoder. Default to True. 241 | If this parameter is set to "True", the spatial dim of network input can be arbitary 242 | size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32. 
243 | """ 244 | super().__init__() 245 | 246 | if backbone not in encoder_feature_channel: 247 | raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.") 248 | 249 | if spatial_dims not in (2, 3): 250 | raise ValueError("spatial_dims can only be 2 or 3.") 251 | 252 | adv_prop = "ap" in backbone 253 | 254 | self.backbone = backbone 255 | self.spatial_dims = spatial_dims 256 | model_name = backbone 257 | encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels) 258 | self.encoder = EfficientNetBNFeatures( 259 | model_name=model_name, 260 | pretrained=pretrained, 261 | in_channels=in_channels, 262 | spatial_dims=spatial_dims, 263 | norm=norm, 264 | adv_prop=adv_prop, 265 | ) 266 | self.decoder = UNetDecoder( 267 | spatial_dims=spatial_dims, 268 | encoder_channels=encoder_channels, 269 | decoder_channels=decoder_channels, 270 | act=act, 271 | norm=norm, 272 | dropout=dropout, 273 | bias=decoder_bias, 274 | upsample=upsample, 275 | interp_mode=interp_mode, 276 | pre_conv=None, 277 | align_corners=None, 278 | is_pad=is_pad, 279 | ) 280 | self.dist_head = SegmentationHead( 281 | spatial_dims=spatial_dims, 282 | in_channels=decoder_channels[-1], 283 | out_channels=32, 284 | kernel_size=1, 285 | act='relu', 286 | ) 287 | self.prob_head = SegmentationHead( 288 | spatial_dims=spatial_dims, 289 | in_channels=decoder_channels[-1], 290 | out_channels=1, 291 | kernel_size=1, 292 | act='sigmoid', 293 | ) 294 | 295 | def forward(self, inputs: torch.Tensor): 296 | """ 297 | Do a typical encoder-decoder-header inference. 298 | 299 | Args: 300 | inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, 301 | N is defined by `dimensions`. 302 | 303 | Returns: 304 | A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``. 
305 | 306 | """ 307 | x = inputs 308 | enc_out = self.encoder(x) 309 | decoder_out = self.decoder(enc_out) 310 | dist = self.dist_head(decoder_out) 311 | prob = self.prob_head(decoder_out) 312 | return dist,prob 313 | -------------------------------------------------------------------------------- /fintune_on_newdataset/overlay.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | ###overlay 5 | import cv2 6 | import math 7 | import random 8 | import colorsys 9 | import numpy as np 10 | import itertools 11 | import matplotlib.pyplot as plt 12 | from matplotlib import cm 13 | import os 14 | import scipy.io as io 15 | def get_bounding_box(img): 16 | """Get bounding box coordinate information.""" 17 | rows = np.any(img, axis=1) 18 | cols = np.any(img, axis=0) 19 | rmin, rmax = np.where(rows)[0][[0, -1]] 20 | cmin, cmax = np.where(cols)[0][[0, -1]] 21 | # due to python indexing, need to add 1 to max 22 | # else accessing will be 1px in the box, not out 23 | rmax += 1 24 | cmax += 1 25 | return [rmin, rmax, cmin, cmax] 26 | #### 27 | def colorize(ch, vmin, vmax): 28 | """Will clamp value value outside the provided range to vmax and vmin.""" 29 | cmap = plt.get_cmap("jet") 30 | ch = np.squeeze(ch.astype("float32")) 31 | vmin = vmin if vmin is not None else ch.min() 32 | vmax = vmax if vmax is not None else ch.max() 33 | ch[ch > vmax] = vmax # clamp value 34 | ch[ch < vmin] = vmin 35 | ch = (ch - vmin) / (vmax - vmin + 1.0e-16) 36 | # take RGB from RGBA heat map 37 | ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") 38 | return ch_cmap 39 | 40 | 41 | #### 42 | def random_colors(N, bright=True): 43 | """Generate random colors. 44 | 45 | To get visually distinct colors, generate them in HSV space then 46 | convert to RGB. 47 | """ 48 | brightness = 1.0 if bright else 0.7 49 | hsv = [(i / N, 1, brightness) for i in range(N)] 50 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 51 | random.shuffle(colors) 52 | return colors 53 | 54 | 55 | #### 56 | def visualize_instances_map( 57 | input_image, inst_map, type_map=None, type_colour=None, line_thickness=2 58 | ): 59 | """Overlays segmentation results on image as contours. 
60 | 61 | Args: 62 | input_image: input image 63 | inst_map: instance mask with unique value for every object 64 | type_map: type mask with unique value for every class 65 | type_colour: a dict of {type : colour} , `type` is from 0-N 66 | and `colour` is a tuple of (R, G, B) 67 | line_thickness: line thickness of contours 68 | 69 | Returns: 70 | overlay: output image with segmentation overlay as contours 71 | """ 72 | overlay = np.copy((input_image).astype(np.uint8)) 73 | 74 | inst_list = list(np.unique(inst_map)) # get list of instances 75 | inst_list.remove(0) # remove background 76 | 77 | inst_rng_colors = random_colors(len(inst_list)) 78 | inst_rng_colors = np.array(inst_rng_colors) * 255 79 | inst_rng_colors = inst_rng_colors.astype(np.uint8) 80 | 81 | for inst_idx, inst_id in enumerate(inst_list): 82 | inst_map_mask = np.array(inst_map == inst_id, np.uint8) # get single object 83 | y1, y2, x1, x2 = get_bounding_box(inst_map_mask) 84 | y1 = y1 - 2 if y1 - 2 >= 0 else y1 85 | x1 = x1 - 2 if x1 - 2 >= 0 else x1 86 | x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2 87 | y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2 88 | inst_map_crop = inst_map_mask[y1:y2, x1:x2] 89 | contours_crop = cv2.findContours( 90 | inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 91 | ) 92 | # only has 1 instance per map, no need to check #contour detected by opencv 93 | #print(contours_crop) 94 | contours_crop = np.squeeze( 95 | contours_crop[0][0].astype("int32") 96 | ) # * opencv protocol format may break 97 | 98 | if len(contours_crop.shape) == 1: 99 | contours_crop = contours_crop.reshape(1,-1) 100 | #print(contours_crop.shape) 101 | contours_crop += np.asarray([[x1, y1]]) # index correction 102 | if type_map is not None: 103 | type_map_crop = type_map[y1:y2, x1:x2] 104 | type_id = np.unique(type_map_crop).max() # non-zero 105 | inst_colour = type_colour[type_id] 106 | else: 107 | inst_colour = (inst_rng_colors[inst_idx]).tolist() 108 | cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness) 109 | return overlay 110 | 111 | 112 | # In[ ]: 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | 3 | import warnings 4 | def format_warning(message, category, filename, lineno, line=''): 5 | import pathlib 6 | return f"{pathlib.Path(filename).name} ({lineno}): {message}\n" 7 | warnings.formatwarning = format_warning 8 | del warnings 9 | 10 | from .version import __version__ 11 | 12 | # TODO: which functions to expose here? all? 
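# Added illustration (not in the original file): a typical 2D round trip with the helpers
# re-exported below -- the exact keyword values are assumptions.
#   dist = star_dist(lbl, n_rays=32)                                # (H, W, 32) radial distances per pixel
#   prob = edt_prob(lbl)                                            # (H, W) object probabilities
#   points, probs, dists = non_maximum_suppression(dist, prob, prob_thresh=0.5)
#   lbl2 = polygons_to_label(dists, points, shape=lbl.shape, prob=probs)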
13 | from .nms import non_maximum_suppression 14 | from .utils import edt_prob, fill_label_holes, sample_points, calculate_extents, export_imagej_rois, gputools_available 15 | from .geometry import star_dist, polygons_to_label, relabel_image_stardist, ray_angles, dist_to_coord 16 | from .sample_patches import sample_patches 17 | from .bioimageio_utils import export_bioimageio, import_bioimageio 18 | 19 | def _py_deprecation(ver_python=(3,6), ver_stardist='0.9.0'): 20 | import sys 21 | from distutils.version import LooseVersion 22 | if sys.version_info[:2] == ver_python and LooseVersion(__version__) < LooseVersion(ver_stardist): 23 | print(f"You are using Python {ver_python[0]}.{ver_python[1]}, which will no longer be supported in StarDist {ver_stardist}.\n" 24 | f"→ Please upgrade to Python {ver_python[0]}.{ver_python[1]+1} or later.", file=sys.stderr, flush=True) 25 | _py_deprecation() 26 | del _py_deprecation 27 | -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/__pycache__/big.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/__pycache__/big.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/__pycache__/bioimageio_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/__pycache__/bioimageio_utils.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/__pycache__/matching.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/__pycache__/matching.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/__pycache__/nms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/__pycache__/nms.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/__pycache__/sample_patches.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/__pycache__/sample_patches.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/__pycache__/version.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/__pycache__/version.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | 3 | # TODO: rethink naming for 2D/3D functions 4 | 5 | from .geom2d import star_dist, relabel_image_stardist, ray_angles, dist_to_coord, polygons_to_label, polygons_to_label_coord 6 | 7 | from .geom2d import _dist_to_coord_old, _polygons_to_label_old 8 | 9 | #, dist_to_volume, dist_to_centroid 10 | -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/geometry/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/geometry/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/geometry/__pycache__/geom2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/geometry/__pycache__/geom2d.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/geometry/__pycache__/geom3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/fintune_on_newdataset/stardist_pkg/geometry/__pycache__/geom3d.cpython-37.pyc -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/geometry/geom2d.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, unicode_literals, absolute_import, division 2 | import numpy as np 3 | import warnings 4 | 5 | from skimage.measure import regionprops 6 | from skimage.draw import polygon 7 | from csbdeep.utils import _raise 8 | 9 | from ..utils import path_absolute, _is_power_of_2, _normalize_grid 10 | from ..matching import _check_label_array 11 | from stardist.lib.stardist2d import c_star_dist 12 | 13 | 14 | 15 | def _ocl_star_dist(lbl, n_rays=32, grid=(1,1)): 16 | from gputools import OCLProgram, OCLArray, OCLImage 17 | (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError()) 18 | n_rays = int(n_rays) 19 | # slicing with grid is done with tuple(slice(0, None, g) for g in grid) 20 | res_shape = tuple((s-1)//g+1 for s, g in zip(lbl.shape, grid)) 21 | 22 | src = OCLImage.from_array(lbl.astype(np.uint16,copy=False)) 23 | dst = OCLArray.empty(res_shape+(n_rays,), dtype=np.float32) 24 | program = 
OCLProgram(path_absolute("kernels/stardist2d.cl"), build_options=['-D', 'N_RAYS=%d' % n_rays]) 25 | program.run_kernel('star_dist', res_shape[::-1], None, dst.data, src, np.int32(grid[0]),np.int32(grid[1])) 26 | return dst.get() 27 | 28 | 29 | def _cpp_star_dist(lbl, n_rays=32, grid=(1,1)): 30 | (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError()) 31 | return c_star_dist(lbl.astype(np.uint16,copy=False), np.int32(n_rays), np.int32(grid[0]),np.int32(grid[1])) 32 | 33 | 34 | def _py_star_dist(a, n_rays=32, grid=(1,1)): 35 | (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError()) 36 | if grid != (1,1): 37 | raise NotImplementedError(grid) 38 | 39 | n_rays = int(n_rays) 40 | a = a.astype(np.uint16,copy=False) 41 | dst = np.empty(a.shape+(n_rays,),np.float32) 42 | 43 | for i in range(a.shape[0]): 44 | for j in range(a.shape[1]): 45 | value = a[i,j] 46 | if value == 0: 47 | dst[i,j] = 0 48 | else: 49 | st_rays = np.float32((2*np.pi) / n_rays) 50 | for k in range(n_rays): 51 | phi = np.float32(k*st_rays) 52 | dy = np.cos(phi) 53 | dx = np.sin(phi) 54 | x, y = np.float32(0), np.float32(0) 55 | while True: 56 | x += dx 57 | y += dy 58 | ii = int(round(i+x)) 59 | jj = int(round(j+y)) 60 | if (ii < 0 or ii >= a.shape[0] or 61 | jj < 0 or jj >= a.shape[1] or 62 | value != a[ii,jj]): 63 | # small correction as we overshoot the boundary 64 | t_corr = 1-.5/max(np.abs(dx),np.abs(dy)) 65 | x -= t_corr*dx 66 | y -= t_corr*dy 67 | dist = np.sqrt(x**2+y**2) 68 | dst[i,j,k] = dist 69 | break 70 | return dst 71 | 72 | 73 | def star_dist(a, n_rays=32, grid=(1,1), mode='cpp'): 74 | """'a' is assumed to be a label image with integer values that encode object ids. id 0 denotes background.""" 75 | 76 | n_rays >= 3 or _raise(ValueError("need 'n_rays' >= 3")) 77 | 78 | if mode == 'python': 79 | return _py_star_dist(a, n_rays, grid=grid) 80 | elif mode == 'cpp': 81 | return _cpp_star_dist(a, n_rays, grid=grid) 82 | elif mode == 'opencl': 83 | return _ocl_star_dist(a, n_rays, grid=grid) 84 | else: 85 | _raise(ValueError("Unknown mode %s" % mode)) 86 | 87 | 88 | def _dist_to_coord_old(rhos, grid=(1,1)): 89 | """convert from polar to cartesian coordinates for a single image (3-D array) or multiple images (4-D array)""" 90 | 91 | grid = _normalize_grid(grid,2) 92 | is_single_image = rhos.ndim == 3 93 | if is_single_image: 94 | rhos = np.expand_dims(rhos,0) 95 | assert rhos.ndim == 4 96 | 97 | n_images,h,w,n_rays = rhos.shape 98 | coord = np.empty((n_images,h,w,2,n_rays),dtype=rhos.dtype) 99 | 100 | start = np.indices((h,w)) 101 | for i in range(2): 102 | coord[...,i,:] = grid[i] * np.broadcast_to(start[i].reshape(1,h,w,1), (n_images,h,w,n_rays)) 103 | 104 | phis = ray_angles(n_rays).reshape(1,1,1,n_rays) 105 | 106 | coord[...,0,:] += rhos * np.sin(phis) # row coordinate 107 | coord[...,1,:] += rhos * np.cos(phis) # col coordinate 108 | 109 | return coord[0] if is_single_image else coord 110 | 111 | 112 | def _polygons_to_label_old(coord, prob, points, shape=None, thr=-np.inf): 113 | sh = coord.shape[:2] if shape is None else shape 114 | lbl = np.zeros(sh,np.int32) 115 | # sort points with increasing probability 116 | ind = np.argsort([ prob[p[0],p[1]] for p in points ]) 117 | points = points[ind] 118 | 119 | i = 1 120 | for p in points: 121 | if prob[p[0],p[1]] < thr: 122 | continue 123 | rr,cc = polygon(coord[p[0],p[1],0], coord[p[0],p[1],1], sh) 124 | lbl[rr,cc] = i 125 | i += 1 126 | 127 | return lbl 128 | 129 | 130 | def dist_to_coord(dist, points, scale_dist=(1,1)): 131 | """convert from polar to
cartesian coordinates for a list of distances and center points 132 | dist.shape = (n_polys, n_rays) 133 | points.shape = (n_polys, 2) 134 | len(scale_dist) = 2 135 | return coord.shape = (n_polys,2,n_rays) 136 | """ 137 | dist = np.asarray(dist) 138 | points = np.asarray(points) 139 | assert dist.ndim==2 and points.ndim==2 and len(dist)==len(points) \ 140 | and points.shape[1]==2 and len(scale_dist)==2 141 | n_rays = dist.shape[1] 142 | phis = ray_angles(n_rays) 143 | coord = (dist[:,np.newaxis]*np.array([np.sin(phis),np.cos(phis)])).astype(np.float32) 144 | coord *= np.asarray(scale_dist).reshape(1,2,1) 145 | coord += points[...,np.newaxis] 146 | return coord 147 | 148 | 149 | def polygons_to_label_coord(coord, shape, labels=None): 150 | """renders polygons to image of given shape 151 | 152 | coord.shape = (n_polys, 2, n_rays) 153 | """ 154 | coord = np.asarray(coord) 155 | if labels is None: labels = np.arange(len(coord)) 156 | 157 | _check_label_array(labels, "labels") 158 | assert coord.ndim==3 and coord.shape[1]==2 and len(coord)==len(labels) 159 | 160 | lbl = np.zeros(shape,np.int32) 161 | 162 | for i,c in zip(labels,coord): 163 | rr,cc = polygon(*c, shape) 164 | lbl[rr,cc] = i+1 165 | 166 | return lbl 167 | 168 | 169 | def polygons_to_label(dist, points, shape, prob=None, thr=-np.inf, scale_dist=(1,1)): 170 | """converts distances and center points to label image 171 | 172 | dist.shape = (n_polys, n_rays) 173 | points.shape = (n_polys, 2) 174 | 175 | label ids will be consecutive and adhere to the order given 176 | """ 177 | dist = np.asarray(dist) 178 | points = np.asarray(points) 179 | prob = np.inf*np.ones(len(points)) if prob is None else np.asarray(prob) 180 | 181 | assert dist.ndim==2 and points.ndim==2 and len(dist)==len(points) 182 | assert len(points)==len(prob) and points.shape[1]==2 and prob.ndim==1 183 | 184 | n_rays = dist.shape[1] 185 | 186 | ind = prob>thr 187 | points = points[ind] 188 | dist = dist[ind] 189 | prob = prob[ind] 190 | 191 | ind = np.argsort(prob, kind='stable') 192 | points = points[ind] 193 | dist = dist[ind] 194 | 195 | coord = dist_to_coord(dist, points, scale_dist=scale_dist) 196 | 197 | return polygons_to_label_coord(coord, shape=shape, labels=ind) 198 | 199 | 200 | def relabel_image_stardist(lbl, n_rays, **kwargs): 201 | """relabel each label region in `lbl` with its star representation""" 202 | _check_label_array(lbl, "lbl") 203 | if not lbl.ndim==2: 204 | raise ValueError("lbl image should be 2 dimensional") 205 | dist = star_dist(lbl, n_rays, **kwargs) 206 | points = np.array(tuple(np.array(r.centroid).astype(int) for r in regionprops(lbl))) 207 | dist = dist[tuple(points.T)] 208 | return polygons_to_label(dist, points, shape=lbl.shape) 209 | 210 | 211 | def ray_angles(n_rays=32): 212 | return np.linspace(0,2*np.pi,n_rays,endpoint=False) 213 | -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/kernels/stardist2d.cl: -------------------------------------------------------------------------------- 1 | #ifndef M_PI 2 | #define M_PI 3.141592653589793 3 | #endif 4 | 5 | __constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; 6 | 7 | inline float2 pol2cart(const float rho, const float phi) { 8 | const float x = rho * cos(phi); 9 | const float y = rho * sin(phi); 10 | return (float2)(x,y); 11 | } 12 | 13 | __kernel void star_dist(__global float* dst, read_only image2d_t src, const int grid_y, const int grid_x) { 14 | 15 | const int i = 
get_global_id(0), j = get_global_id(1); 16 | const int Nx = get_global_size(0), Ny = get_global_size(1); 17 | const float2 grid = (float2)(grid_x, grid_y); 18 | 19 | const float2 origin = (float2)(i,j) * grid; 20 | const int value = read_imageui(src,sampler,origin).x; 21 | 22 | if (value == 0) { 23 | // background pixel -> nothing to do, write all zeros 24 | for (int k = 0; k < N_RAYS; k++) { 25 | dst[k + i*N_RAYS + j*N_RAYS*Nx] = 0; 26 | } 27 | } else { 28 | float st_rays = (2*M_PI) / N_RAYS; // step size for ray angles 29 | // for all rays 30 | for (int k = 0; k < N_RAYS; k++) { 31 | const float phi = k*st_rays; // current ray angle phi 32 | const float2 dir = pol2cart(1,phi); // small vector in direction of ray 33 | float2 offset = 0; // offset vector to be added to origin 34 | // find radius that leaves current object 35 | while (1) { 36 | offset += dir; 37 | const int offset_value = read_imageui(src,sampler,round(origin+offset)).x; 38 | if (offset_value != value) { 39 | // small correction as we overshoot the boundary 40 | const float t_corr = .5f/fmax(fabs(dir.x),fabs(dir.y)); 41 | offset += (t_corr-1.f)*dir; 42 | 43 | const float dist = sqrt(offset.x*offset.x + offset.y*offset.y); 44 | dst[k + i*N_RAYS + j*N_RAYS*Nx] = dist; 45 | break; 46 | } 47 | } 48 | } 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/kernels/stardist3d.cl: -------------------------------------------------------------------------------- 1 | #ifndef M_PI 2 | #define M_PI 3.141592653589793 3 | #endif 4 | 5 | __constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; 6 | 7 | inline int round_to_int(float r) { 8 | return (int)rint(r); 9 | } 10 | 11 | 12 | __kernel void stardist3d(read_only image3d_t lbl, __constant float * rays, __global float* dist, const int grid_z, const int grid_y, const int grid_x) { 13 | 14 | const int i = get_global_id(0); 15 | const int j = get_global_id(1); 16 | const int k = get_global_id(2); 17 | 18 | const int Nx = get_global_size(0); 19 | const int Ny = get_global_size(1); 20 | const int Nz = get_global_size(2); 21 | 22 | const float4 grid = (float4)(grid_x, grid_y, grid_z, 1); 23 | const float4 origin = (float4)(i,j,k,0) * grid; 24 | const int value = read_imageui(lbl,sampler,origin).x; 25 | 26 | if (value == 0) { 27 | // background pixel -> nothing to do, write all zeros 28 | for (int m = 0; m < N_RAYS; m++) { 29 | dist[m + i*N_RAYS + j*N_RAYS*Nx+k*N_RAYS*Nx*Ny] = 0; 30 | } 31 | 32 | } 33 | else { 34 | for (int m = 0; m < N_RAYS; m++) { 35 | 36 | const float4 dx = (float4)(rays[3*m+2],rays[3*m+1],rays[3*m],0); 37 | // if ((i==Nx/2)&&(j==Ny/2)&(k==Nz/2)){ 38 | // printf("kernel: %.2f %.2f %.2f \n",dx.x,dx.y,dx.z); 39 | // } 40 | float4 x = (float4)(0,0,0,0); 41 | 42 | // move along ray 43 | while (1) { 44 | x += dx; 45 | // if ((i==10)&&(j==10)&(k==10)){ 46 | // printf("kernel run: %.2f %.2f %.2f value %d \n",x.x,x.y,x.z, read_imageui(lbl,sampler,origin+x).x); 47 | // } 48 | 49 | // to make it equivalent to the cpp version... 
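// (added clarification) the running offset x is rounded to the nearest voxel before the
// image lookup so that the GPU kernel samples exactly the same grid points as the CPU
// implementation; the reported distance is then the length of that rounded offset.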
50 | const float4 x_int = (float4)(round_to_int(x.x), 51 | round_to_int(x.y), 52 | round_to_int(x.z), 53 | 0); 54 | 55 | if (value != read_imageui(lbl,sampler,origin+x_int).x){ 56 | 57 | dist[m + i*N_RAYS + j*N_RAYS*Nx+k*N_RAYS*Nx*Ny] = length(x_int); 58 | break; 59 | } 60 | } 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | 3 | from .model2d import Config2D, StarDist2D, StarDistData2D 4 | 5 | from csbdeep.utils import backend_channels_last 6 | from csbdeep.utils.tf import keras_import 7 | K = keras_import('backend') 8 | if not backend_channels_last(): 9 | raise NotImplementedError( 10 | "Keras is configured to use the '%s' image data format, which is currently not supported. " 11 | "Please change it to use 'channels_last' instead: " 12 | "https://keras.io/getting-started/faq/#where-is-the-keras-configuration-file-stored" % K.image_data_format() 13 | ) 14 | del backend_channels_last, K 15 | 16 | from csbdeep.models import register_model, register_aliases, clear_models_and_aliases 17 | # register pre-trained models and aliases (TODO: replace with updatable solution) 18 | clear_models_and_aliases(StarDist2D) # only the 2D model is bundled in this package 19 | register_model(StarDist2D, '2D_versatile_fluo', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_fluo.zip', '8db40dacb5a1311b8d2c447ad934fb8a') 20 | register_model(StarDist2D, '2D_versatile_he', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_he.zip', 'bf34cb3c0e5b3435971e18d66778a4ec') 21 | register_model(StarDist2D, '2D_paper_dsb2018', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_paper_dsb2018.zip', '6287bf283f85c058ec3e7094b41039b5') 22 | register_model(StarDist2D, '2D_demo', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_demo.zip', '31f70402f58c50dd231ec31b4375ea2c') 23 | 24 | register_aliases(StarDist2D, '2D_paper_dsb2018', 'DSB 2018 (from StarDist 2D paper)') 25 | register_aliases(StarDist2D, '2D_versatile_fluo', 'Versatile (fluorescent nuclei)') 26 | register_aliases(StarDist2D, '2D_versatile_he', 'Versatile (H&E nuclei)') 27 | del register_model, register_aliases, clear_models_and_aliases 28 | -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/rays3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ray factory 3 | 4 | classes that provide vertex and triangle information for rays on spheres 5 | 6 | Example: 7 | 8 | rays = Rays_Tetra(n_level = 4) 9 | 10 | print(rays.vertices) 11 | print(rays.faces) 12 | 13 | """ 14 | from __future__ import print_function, unicode_literals, absolute_import, division 15 | import numpy as np 16 | from scipy.spatial import ConvexHull 17 | import copy 18 | import warnings 19 | 20 | class Rays_Base(object): 21 | def __init__(self, **kwargs): 22 | self.kwargs = kwargs 23 | self._vertices, self._faces = self.setup_vertices_faces() 24 | self._vertices = np.asarray(self._vertices, np.float32) 25 | self._faces = np.asarray(self._faces, int) 26 | self._faces = np.asanyarray(self._faces) 27 | 28 | def setup_vertices_faces(self): 29 | """has to return 30 | 31 | verts , faces 32 | 33 | verts = ( (z_1,y_1,x_1), ... 
) 34 | faces ( (0,1,2), (2,3,4), ... ) 35 | 36 | """ 37 | raise NotImplementedError() 38 | 39 | @property 40 | def vertices(self): 41 | """read-only property""" 42 | return self._vertices.copy() 43 | 44 | @property 45 | def faces(self): 46 | """read-only property""" 47 | return self._faces.copy() 48 | 49 | def __getitem__(self, i): 50 | return self.vertices[i] 51 | 52 | def __len__(self): 53 | return len(self._vertices) 54 | 55 | def __repr__(self): 56 | def _conv(x): 57 | if isinstance(x,(tuple, list, np.ndarray)): 58 | return "_".join(_conv(_x) for _x in x) 59 | if isinstance(x,float): 60 | return "%.2f"%x 61 | return str(x) 62 | return "%s_%s" % (self.__class__.__name__, "_".join("%s_%s" % (k, _conv(v)) for k, v in sorted(self.kwargs.items()))) 63 | 64 | def to_json(self): 65 | return { 66 | "name": self.__class__.__name__, 67 | "kwargs": self.kwargs 68 | } 69 | 70 | def dist_loss_weights(self, anisotropy = (1,1,1)): 71 | """returns the anisotropy corrected weights for each ray""" 72 | anisotropy = np.array(anisotropy) 73 | assert anisotropy.shape == (3,) 74 | return np.linalg.norm(self.vertices*anisotropy, axis = -1) 75 | 76 | def volume(self, dist=None): 77 | """volume of the starconvex polyhedron spanned by dist (if None, uses dist=1) 78 | dist can be a nD array, but the last dimension has to be of length n_rays 79 | """ 80 | if dist is None: dist = np.ones(len(self.vertices)) # one distance per ray 81 | 82 | dist = np.asarray(dist) 83 | 84 | if not dist.shape[-1]==len(self.vertices): 85 | raise ValueError("last dimension of dist should have length len(rays.vertices)") 86 | # all the shuffling below is to allow dist to be an arbitrary sized array (with last dim n_rays) 87 | # self.vertices -> (n_rays,3) 88 | # dist -> (m,n,..., n_rays) 89 | 90 | # dist -> (m,n,..., n_rays, 3) 91 | dist = np.repeat(np.expand_dims(dist,-1), 3, axis = -1) 92 | # verts -> (m,n,..., n_rays, 3) 93 | verts = np.broadcast_to(self.vertices, dist.shape) 94 | 95 | # dist, verts -> (n_rays, m,n, ..., 3) 96 | dist = np.moveaxis(dist,-2,0) 97 | verts = np.moveaxis(verts,-2,0) 98 | 99 | # vs -> (n_faces, 3, m, n, ..., 3) 100 | vs = (dist*verts)[self.faces] 101 | # vs -> (n_faces, m, n, ..., 3, 3) 102 | vs = np.moveaxis(vs, 1,-2) 103 | # vs -> (n_faces * m * n, 3, 3) 104 | vs = vs.reshape((len(self.faces)*int(np.prod(dist.shape[1:-1])),3,3)) 105 | d = np.linalg.det(list(vs)).reshape((len(self.faces),)+dist.shape[1:-1]) 106 | 107 | return -1./6*np.sum(d, axis = 0) 108 | 109 | def surface(self, dist=None): 110 | """surface area of the starconvex polyhedron spanned by dist (if None, uses dist=1)""" 111 | dist = np.ones(len(self.vertices)) if dist is None else np.asarray(dist) # honour the documented dist=1 default 112 | 113 | if not dist.shape[-1]==len(self.vertices): 114 | raise ValueError("last dimension of dist should have length len(rays.vertices)") 115 | 116 | # self.vertices -> (n_rays,3) 117 | # dist -> (m,n,..., n_rays) 118 | 119 | # all the shuffling below is to allow dist to be an arbitrary sized array (with last dim n_rays) 120 | 121 | # dist -> (m,n,..., n_rays, 3) 122 | dist = np.repeat(np.expand_dims(dist,-1), 3, axis = -1) 123 | # verts -> (m,n,..., n_rays, 3) 124 | verts = np.broadcast_to(self.vertices, dist.shape) 125 | 126 | # dist, verts -> (n_rays, m,n, ..., 3) 127 | dist = np.moveaxis(dist,-2,0) 128 | verts = np.moveaxis(verts,-2,0) 129 | 130 | # vs -> (n_faces, 3, m, n, ..., 3) 131 | vs = (dist*verts)[self.faces] 132 | # vs -> (n_faces, m, n, ..., 3, 3) 133 | vs = np.moveaxis(vs, 1,-2) 134 | # vs -> (n_faces * m * n, 3, 3) 135 | vs = vs.reshape((len(self.faces)*int(np.prod(dist.shape[1:-1])),3,3)) 136 | 
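        # (added note) each triangular face spans edge vectors pa and pb computed below;
        # its area is 0.5 * |pa x pb|, and the surface is the sum of these areas over all faces.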
137 | pa = vs[...,1,:]-vs[...,0,:] 138 | pb = vs[...,2,:]-vs[...,0,:] 139 | 140 | d = .5*np.linalg.norm(np.cross(list(pa), list(pb)), axis = -1) 141 | d = d.reshape((len(self.faces),)+dist.shape[1:-1]) 142 | return np.sum(d, axis = 0) 143 | 144 | 145 | def copy(self, scale=(1,1,1)): 146 | """ returns a copy whose vertices are scaled by given factor""" 147 | scale = np.asarray(scale) 148 | assert scale.shape == (3,) 149 | res = copy.deepcopy(self) 150 | res._vertices *= scale[np.newaxis] 151 | return res 152 | 153 | 154 | 155 | 156 | def rays_from_json(d): 157 | return eval(d["name"])(**d["kwargs"]) 158 | 159 | 160 | ################################################################ 161 | 162 | class Rays_Explicit(Rays_Base): 163 | def __init__(self, vertices0, faces0): 164 | self.vertices0, self.faces0 = vertices0, faces0 165 | super().__init__(vertices0=list(vertices0), faces0=list(faces0)) 166 | 167 | def setup_vertices_faces(self): 168 | return self.vertices0, self.faces0 169 | 170 | 171 | class Rays_Cartesian(Rays_Base): 172 | def __init__(self, n_rays_x=11, n_rays_z=5): 173 | super().__init__(n_rays_x=n_rays_x, n_rays_z=n_rays_z) 174 | 175 | def setup_vertices_faces(self): 176 | """has to return list of ( (z_1,y_1,x_1), ... )""" 177 | n_rays_x, n_rays_z = self.kwargs["n_rays_x"], self.kwargs["n_rays_z"] 178 | dphi = np.float32(2. * np.pi / n_rays_x) 179 | dtheta = np.float32(np.pi / n_rays_z) 180 | 181 | verts = [] 182 | for mz in range(n_rays_z): 183 | for mx in range(n_rays_x): 184 | phi = mx * dphi 185 | theta = mz * dtheta 186 | if mz == 0: 187 | theta = 1e-12 188 | if mz == n_rays_z - 1: 189 | theta = np.pi - 1e-12 190 | dx = np.cos(phi) * np.sin(theta) 191 | dy = np.sin(phi) * np.sin(theta) 192 | dz = np.cos(theta) 193 | if mz == 0 or mz == n_rays_z - 1: 194 | dx += 1e-12 195 | dy += 1e-12 196 | verts.append([dz, dy, dx]) 197 | 198 | verts = np.array(verts) 199 | 200 | def _ind(mz, mx): 201 | return mz * n_rays_x + mx 202 | 203 | faces = [] 204 | 205 | for mz in range(n_rays_z - 1): 206 | for mx in range(n_rays_x): 207 | faces.append([_ind(mz, mx), _ind(mz + 1, (mx + 1) % n_rays_x), _ind(mz, (mx + 1) % n_rays_x)]) 208 | faces.append([_ind(mz, mx), _ind(mz + 1, mx), _ind(mz + 1, (mx + 1) % n_rays_x)]) 209 | 210 | faces = np.array(faces) 211 | 212 | return verts, faces 213 | 214 | 215 | class Rays_SubDivide(Rays_Base): 216 | """ 217 | Subdivision polyhedra 218 | 219 | n_level = 1 -> base polyhedra 220 | n_level = 2 -> 1x subdivision 221 | n_level = 3 -> 2x subdivision 222 | ... 223 | """ 224 | 225 | def __init__(self, n_level=4): 226 | super().__init__(n_level=n_level) 227 | 228 | def base_polyhedron(self): 229 | raise NotImplementedError() 230 | 231 | def setup_vertices_faces(self): 232 | n_level = self.kwargs["n_level"] 233 | verts0, faces0 = self.base_polyhedron() 234 | return self._recursive_split(verts0, faces0, n_level) 235 | 236 | def _recursive_split(self, verts, faces, n_level): 237 | if n_level <= 1: 238 | return verts, faces 239 | else: 240 | verts, faces = Rays_SubDivide.split(verts, faces) 241 | return self._recursive_split(verts, faces, n_level - 1) 242 | 243 | @classmethod 244 | def split(self, verts0, faces0): 245 | """split a level""" 246 | 247 | split_edges = dict() 248 | verts = list(verts0[:]) 249 | faces = [] 250 | 251 | def _add(a, b): 252 | """ returns index of middle point and adds vertex if not already added""" 253 | edge = tuple(sorted((a, b))) 254 | if not edge in split_edges: 255 | v = .5 * (verts[a] + verts[b]) 256 | v *= 1. 
/ np.linalg.norm(v) 257 | verts.append(v) 258 | split_edges[edge] = len(verts) - 1 259 | return split_edges[edge] 260 | 261 | for v1, v2, v3 in faces0: 262 | ind1 = _add(v1, v2) 263 | ind2 = _add(v2, v3) 264 | ind3 = _add(v3, v1) 265 | faces.append([v1, ind1, ind3]) 266 | faces.append([v2, ind2, ind1]) 267 | faces.append([v3, ind3, ind2]) 268 | faces.append([ind1, ind2, ind3]) 269 | 270 | return verts, faces 271 | 272 | 273 | class Rays_Tetra(Rays_SubDivide): 274 | """ 275 | Subdivision of a tetrahedron 276 | 277 | n_level = 1 -> normal tetrahedron (4 vertices) 278 | n_level = 2 -> 1x subdivision (10 vertices) 279 | n_level = 3 -> 2x subdivision (34 vertices) 280 | ... 281 | """ 282 | 283 | def base_polyhedron(self): 284 | verts = np.array([ 285 | [np.sqrt(8. / 9), 0., -1. / 3], 286 | [-np.sqrt(2. / 9), np.sqrt(2. / 3), -1. / 3], 287 | [-np.sqrt(2. / 9), -np.sqrt(2. / 3), -1. / 3], 288 | [0., 0., 1.] 289 | ]) 290 | faces = [[0, 1, 2], 291 | [0, 3, 1], 292 | [0, 2, 3], 293 | [1, 3, 2]] 294 | 295 | return verts, faces 296 | 297 | 298 | class Rays_Octo(Rays_SubDivide): 299 | """ 300 | Subdivision of an octahedron 301 | 302 | n_level = 1 -> normal Octahedron (6 vertices) 303 | n_level = 2 -> 1x subdivision (18 vertices) 304 | n_level = 3 -> 2x subdivision (66 vertices) 305 | 306 | """ 307 | 308 | def base_polyhedron(self): 309 | verts = np.array([ 310 | [0, 0, 1], 311 | [0, 1, 0], 312 | [0, 0, -1], 313 | [0, -1, 0], 314 | [1, 0, 0], 315 | [-1, 0, 0]]) 316 | 317 | faces = [[0, 1, 4], 318 | [0, 5, 1], 319 | [1, 2, 4], 320 | [1, 5, 2], 321 | [2, 3, 4], 322 | [2, 5, 3], 323 | [3, 0, 4], 324 | [3, 5, 0], 325 | ] 326 | 327 | return verts, faces 328 | 329 | 330 | def reorder_faces(verts, faces): 331 | """reorder faces such that their orientation points outward""" 332 | def _single(face): 333 | return face[::-1] if np.linalg.det(verts[face])>0 else face 334 | return tuple(map(_single, faces)) 335 | 336 | 337 | class Rays_GoldenSpiral(Rays_Base): 338 | def __init__(self, n=70, anisotropy = None): 339 | if n<4: 340 | raise ValueError("At least 4 points have to be given!") 341 | super().__init__(n=n, anisotropy = anisotropy if anisotropy is None else tuple(anisotropy)) 342 | 343 | def setup_vertices_faces(self): 344 | n = self.kwargs["n"] 345 | anisotropy = self.kwargs["anisotropy"] 346 | if anisotropy is None: 347 | anisotropy = np.ones(3) 348 | else: 349 | anisotropy = np.array(anisotropy) 350 | 351 | # the smaller golden angle = 2pi * 0.3819... 352 | g = (3. - np.sqrt(5.)) * np.pi 353 | phi = g * np.arange(n) 354 | # z = np.linspace(-1, 1, n + 2)[1:-1] 355 | # rho = np.sqrt(1. - z ** 2) 356 | # verts = np.stack([rho*np.cos(phi), rho*np.sin(phi),z]).T 357 | # 358 | z = np.linspace(-1, 1, n) 359 | rho = np.sqrt(1. - z ** 2) 360 | verts = np.stack([z, rho * np.sin(phi), rho * np.cos(phi)]).T 361 | 362 | # warnings.warn("ray definition has changed! 
Old results are invalid!") 363 | 364 | # correct for anisotropy 365 | verts = verts/anisotropy 366 | #verts /= np.linalg.norm(verts, axis=-1, keepdims=True) 367 | 368 | hull = ConvexHull(verts) 369 | faces = reorder_faces(verts,hull.simplices) 370 | 371 | verts /= np.linalg.norm(verts, axis=-1, keepdims=True) 372 | 373 | return verts, faces 374 | -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/sample_patches.py: -------------------------------------------------------------------------------- 1 | """provides a faster sampling function""" 2 | 3 | import numpy as np 4 | from csbdeep.utils import _raise, choice 5 | 6 | 7 | def sample_patches(datas, patch_size, n_samples, valid_inds=None, verbose=False): 8 | """optimized version of csbdeep.data.sample_patches_from_multiple_stacks 9 | """ 10 | 11 | len(patch_size)==datas[0].ndim or _raise(ValueError()) 12 | 13 | if not all(( a.shape == datas[0].shape for a in datas )): 14 | raise ValueError("all input shapes must be the same: %s" % (" / ".join(str(a.shape) for a in datas))) 15 | 16 | if not all(( 0 < s <= d for s,d in zip(patch_size,datas[0].shape) )): 17 | raise ValueError("patch_size %s negative or larger than data shape %s along some dimensions" % (str(patch_size), str(datas[0].shape))) 18 | 19 | if valid_inds is None: 20 | valid_inds = tuple(_s.ravel() for _s in np.meshgrid(*tuple(np.arange(p//2,s-p//2+1) for s,p in zip(datas[0].shape, patch_size)))) 21 | 22 | n_valid = len(valid_inds[0]) 23 | 24 | if n_valid == 0: 25 | raise ValueError("no regions to sample from!") 26 | 27 | idx = choice(range(n_valid), n_samples, replace=(n_valid < n_samples)) 28 | rand_inds = [v[idx] for v in valid_inds] 29 | res = [np.stack([data[tuple(slice(_r-(_p//2),_r+_p-(_p//2)) for _r,_p in zip(r,patch_size))] for r in zip(*rand_inds)]) for data in datas] 30 | 31 | return res 32 | 33 | 34 | def get_valid_inds(img, patch_size, patch_filter=None): 35 | """ 36 | Returns all indices of an image that 37 | - can be used as center points for sampling patches of a given patch_size, and 38 | - are part of the boolean mask given by the function patch_filter (if provided) 39 | 40 | img: np.ndarray 41 | patch_size: tuple of ints 42 | the width of patches per img dimension, 43 | patch_filter: None or callable 44 | a function with signature patch_filter(img, patch_size) returning a boolean mask 45 | """ 46 | 47 | len(patch_size)==img.ndim or _raise(ValueError()) 48 | 49 | if not all(( 0 < s <= d for s,d in zip(patch_size,img.shape))): 50 | raise ValueError("patch_size %s negative or larger than image shape %s along some dimensions" % (str(patch_size), str(img.shape))) 51 | 52 | if patch_filter is None: 53 | # only cut border indices (which is faster) 54 | patch_mask = np.ones(img.shape,dtype=bool) 55 | valid_inds = tuple(np.arange(p // 2, s - p + p // 2 + 1).astype(np.uint32) for p, s in zip(patch_size, img.shape)) 56 | valid_inds = tuple(s.ravel() for s in np.meshgrid(*valid_inds, indexing='ij')) 57 | else: 58 | patch_mask = patch_filter(img, patch_size) 59 | 60 | # get the valid indices 61 | border_slices = tuple([slice(p // 2, s - p + p // 2 + 1) for p, s in zip(patch_size, img.shape)]) 62 | valid_inds = np.where(patch_mask[border_slices]) 63 | valid_inds = tuple((v + s.start).astype(np.uint32) for s, v in zip(border_slices, valid_inds)) 64 | 65 | return valid_inds 66 | -------------------------------------------------------------------------------- /fintune_on_newdataset/stardist_pkg/version.py: 
--------------------------------------------------------------------------------
1 | __version__ = '0.8.3'
2 |
--------------------------------------------------------------------------------
/fintune_on_newdataset/train_classification.py:
--------------------------------------------------------------------------------
1 | import os, sys, glob, time, random, shutil, copy
2 | from tqdm import tqdm
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | import torchvision
7 | from torchvision import datasets, models, transforms
8 | import torch.utils.data as data
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | from torch.optim import lr_scheduler
12 | import torch.nn.functional as F
13 | from torchsummary import summary
14 | from matplotlib import pyplot as plt
15 | from torchvision.models import resnet18, ResNet18_Weights  # requires torchvision >= 0.13
16 | from PIL import Image, ImageFile
17 | from skimage import io
18 | ImageFile.LOAD_TRUNCATED_IMAGES = True
19 |
20 | # Set the train and validation directory paths
21 | train_directory = 'dataset/train'
22 | valid_directory = 'dataset/val'
23 |
24 | # Batch size
25 | bs = 64
26 | # Number of epochs
27 | num_epochs = 20
28 | # Number of classes
29 | num_classes = 4
30 | # Number of workers
31 | num_cpu = 8
32 |
33 | # Applying transforms to the data
34 | image_transforms = {
35 | 'train': transforms.Compose([
36 | transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
37 | transforms.RandomRotation(degrees=15),
38 | transforms.RandomHorizontalFlip(),
39 | transforms.CenterCrop(size=224),
40 | transforms.ToTensor(),
41 | transforms.Normalize([0.485, 0.456, 0.406],
42 | [0.229, 0.224, 0.225])
43 | ]),
44 | 'valid': transforms.Compose([
45 | transforms.Resize(size=256),
46 | transforms.CenterCrop(size=224),
47 | transforms.ToTensor(),
48 | transforms.Normalize([0.485, 0.456, 0.406],
49 | [0.229, 0.224, 0.225])
50 | ])
51 | }
52 |
53 | # Load data from folders
54 | dataset = {
55 | 'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),
56 | 'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid'])
57 | }
58 |
59 | # Size of train and validation data
60 | dataset_sizes = {
61 | 'train':len(dataset['train']),
62 | 'valid':len(dataset['valid'])
63 | }
64 |
65 | # Create iterators for data loading
66 | dataloaders = {
67 | 'train':data.DataLoader(dataset['train'], batch_size=bs, shuffle=True,
68 | num_workers=num_cpu, pin_memory=True, drop_last=False),
69 | 'valid':data.DataLoader(dataset['valid'], batch_size=bs, shuffle=False,
70 | num_workers=num_cpu, pin_memory=True, drop_last=False)
71 | }
72 |
73 | # Class names or target labels
74 | class_names = dataset['train'].classes
75 | print("Classes:", class_names)
76 |
77 | # Print the train and validation data sizes
78 | print("Training-set size:", dataset_sizes['train'],
79 | "\nValidation-set size:", dataset_sizes['valid'])
80 |
81 | modelname = 'resnet18'
82 |
83 | # Set default device as gpu, if available
84 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
85 |
86 | weights = ResNet18_Weights.DEFAULT
87 | model = resnet18(weights=weights)  # start from the ImageNet-pretrained weights (downloaded automatically, cf. README)
88 | num_ftrs = model.fc.in_features
89 | model.fc = nn.Linear(num_ftrs, num_classes)
90 |
91 |
92 | # Transfer the model to GPU
93 | model = model.to(device)
94 |
95 | # Print model summary
96 | print('Model Summary:-\n')
97 | for num, (name, param) in enumerate(model.named_parameters()):
98 | print(num, name, param.requires_grad)
99 | summary(model,
input_size=(3, 224, 224)) 100 | 101 | # Loss function 102 | criterion = nn.CrossEntropyLoss() 103 | 104 | # Optimizer 105 | optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) 106 | 107 | # Learning rate decay 108 | scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) 109 | 110 | since = time.time() 111 | 112 | best_model_wts = copy.deepcopy(model.state_dict()) 113 | best_acc = 0.0 114 | 115 | for epoch in range(1, num_epochs+1): 116 | print('Epoch {}/{}'.format(epoch, num_epochs)) 117 | print('-' * 10) 118 | 119 | # Each epoch has a training and validation phase 120 | for phase in ['train', 'valid']: 121 | if phase == 'train': 122 | model.train() # Set model to training mode 123 | else: 124 | model.eval() # Set model to evaluate mode 125 | 126 | running_loss = 0.0 127 | running_corrects = 0 128 | 129 | # Iterate over data. 130 | n = 0 131 | stream = tqdm(dataloaders[phase]) 132 | for i, (inputs, labels) in enumerate(stream, start=1): 133 | inputs = inputs.to(device) 134 | labels = labels.to(device) 135 | 136 | # zero the parameter gradients 137 | optimizer.zero_grad() 138 | 139 | # forward 140 | # track history if only in train 141 | with torch.set_grad_enabled(phase == 'train'): 142 | outputs = model(inputs) 143 | _, preds = torch.max(outputs, 1) 144 | loss = criterion(outputs, labels) 145 | 146 | # backward + optimize only if in training phase 147 | if phase == 'train': 148 | loss.backward() 149 | optimizer.step() 150 | 151 | # statistics 152 | n += inputs.shape[0] 153 | running_loss += loss.item() * inputs.size(0) 154 | running_corrects += torch.sum(preds == labels.data) 155 | 156 | stream.set_description(f'Batch {i}/{len(dataloaders[phase])} | Loss: {running_loss/n:.4f}, Acc: {running_corrects/n:.4f}') 157 | 158 | if phase == 'train': 159 | scheduler.step() 160 | 161 | epoch_loss = running_loss / dataset_sizes[phase] 162 | epoch_acc = running_corrects.double() / dataset_sizes[phase] 163 | 164 | print('Epoch {}-{} Loss: {:.4f} Acc: {:.4f}'.format( 165 | epoch, phase, epoch_loss, epoch_acc)) 166 | 167 | # deep copy the model 168 | if phase == 'valid' and epoch_acc >= best_acc: 169 | best_acc = epoch_acc 170 | best_model_wts = copy.deepcopy(model.state_dict()) 171 | print('Update best model!') 172 | 173 | time_elapsed = time.time() - since 174 | print('Training complete in {:.0f}m {:.0f}s'.format( 175 | time_elapsed // 60, time_elapsed % 60)) 176 | print('Best val Acc: {:4f}'.format(best_acc)) 177 | 178 | # load best model weights 179 | model.load_state_dict(best_model_wts) 180 | torch.save(model, 'logs/resnet18_4class.pth') 181 | torch.save(model.state_dict(), 'logs/resnet18_4class.tar') 182 | -------------------------------------------------------------------------------- /fintune_on_newdataset/unsup_classification.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import os 8 | import numpy as np 9 | import shutil 10 | import torch 11 | import torch.nn 12 | import torchvision.models as models 13 | from torch.autograd import Variable 14 | import torch.cuda 15 | import torchvision.transforms as transforms 16 | from PIL import Image 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from sklearn.datasets import make_blobs 20 | from sklearn.cluster import KMeans 21 | from sklearn.metrics import silhouette_score 22 | from sklearn.preprocessing import StandardScaler 23 | from sklearn.metrics import pairwise_distances_argmin_min 24 | from 
scipy.spatial.distance import pdist, squareform
25 | from skimage import io, segmentation, morphology, exposure
26 | from skimage.color import rgb2hsv
27 | img_to_tensor = transforms.ToTensor()
28 | import random
29 | import tifffile as tif
30 | path = '/data1/partitionA/CUHKSZ/histopath_2022/grand_competition/Train_Labeled/images/'
31 | files = os.listdir(path)
32 | binary_path = '0/'
33 | gray_path = '1/'
34 | colored_path = 'colored/'
35 | os.makedirs(binary_path, exist_ok=True)
36 | os.makedirs(colored_path, exist_ok=True)
37 | os.makedirs(gray_path, exist_ok=True)
38 | for img_name in files:
39 | img_path = path + str(img_name)
40 | if img_name.endswith('.tif') or img_name.endswith('.tiff'):
41 | img_data = tif.imread(img_path)
42 | else:
43 | img_data = io.imread(img_path)
44 | if len(img_data.shape) == 2 or (len(img_data.shape) == 3 and img_data.shape[-1] == 1):
45 | shutil.copyfile(path + img_name, binary_path + img_name)
46 | elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
47 | shutil.copyfile(path + img_name, colored_path + img_name)
48 | else:
49 | hsv_img = rgb2hsv(img_data)
50 | s = hsv_img[:,:,1]
51 | v = hsv_img[:,:,2]
52 | print(img_name, s.mean(), v.mean())
53 | if s.mean() > 0.1 or (v.mean()<0.1 or v.mean() > 0.6):
54 | shutil.copyfile(path + img_name, colored_path + img_name)
55 | else:
56 | shutil.copyfile(path + img_name, gray_path + img_name)
57 |
58 |
59 |
60 | # In[3]:
61 |
62 |
63 | #### Phase 2: clustering by cell size
64 | from skimage import measure
65 | colored_path = 'colored/'
66 | label_path = 'allimages/tif/'
67 | big_path = '2/'
68 | small_path = '3/'
69 | files = os.listdir(colored_path)
70 | os.makedirs(big_path, exist_ok=True)
71 | os.makedirs(small_path, exist_ok=True)
72 | for img_name in files:
73 | label = tif.imread(label_path + img_name.split('.')[0]+'.tif')
74 | props = measure.regionprops(label)
75 | num_pix = []
76 | for idx in range(len(props)):
77 | num_pix.append(props[idx].area)
78 | max_area = max(num_pix)
79 | print(max_area)
80 | if max_area > 30000:
81 | shutil.copyfile(path + img_name, big_path + img_name)
82 | else:
83 | shutil.copyfile(path + img_name, small_path + img_name)
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sun Mar 20 14:23:55 2022
5 |
6 | @author: jma
7 | """
8 |
9 |
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/convnext.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/convnext.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/convnext.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/convnext.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/flexible_unet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/flexible_unet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/flexible_unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/flexible_unet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/flexible_unet_convext.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/flexible_unet_convext.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/flexible_unet_convext.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/flexible_unet_convext.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/flexible_unet_convnext.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/flexible_unet_convnext.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/swin_unetr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/swin_unetr.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/tmp: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/__pycache__/unetr2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/models/__pycache__/unetr2d.cpython-38.pyc -------------------------------------------------------------------------------- /models/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
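# Reading guide (summary of the code below): `Block` implements the ConvNeXt block as
#   x -> 7x7 depthwise conv -> permute to (N, H, W, C) -> LayerNorm
#     -> Linear(dim, 4*dim) -> GELU -> Linear(4*dim, dim)
#     -> optional layer scale (gamma) -> permute back to (N, C, H, W) -> residual add via DropPath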
7 | 8 | from functools import partial 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import trunc_normal_, DropPath 13 | from timm.models.registry import register_model 14 | from monai.networks.layers.factories import Act, Conv, Pad, Pool 15 | from monai.networks.layers.utils import get_norm_layer 16 | from monai.utils.module import look_up_option 17 | from typing import List, NamedTuple, Optional, Tuple, Type, Union 18 | class Block(nn.Module): 19 | r""" ConvNeXt Block. There are two equivalent implementations: 20 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 21 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 22 | We use (2) as we find it slightly faster in PyTorch 23 | 24 | Args: 25 | dim (int): Number of input channels. 26 | drop_path (float): Stochastic depth rate. Default: 0.0 27 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 28 | """ 29 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 30 | super().__init__() 31 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 32 | self.norm = LayerNorm(dim, eps=1e-6) 33 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 34 | self.act = nn.GELU() 35 | self.pwconv2 = nn.Linear(4 * dim, dim) 36 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 37 | requires_grad=True) if layer_scale_init_value > 0 else None 38 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 39 | 40 | def forward(self, x): 41 | input = x 42 | x = self.dwconv(x) 43 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 44 | x = self.norm(x) 45 | x = self.pwconv1(x) 46 | x = self.act(x) 47 | x = self.pwconv2(x) 48 | if self.gamma is not None: 49 | x = self.gamma * x 50 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 51 | 52 | x = input + self.drop_path(x) 53 | return x 54 | 55 | class ConvNeXt(nn.Module): 56 | r""" ConvNeXt 57 | A PyTorch impl of : `A ConvNet for the 2020s` - 58 | https://arxiv.org/pdf/2201.03545.pdf 59 | 60 | Args: 61 | in_chans (int): Number of input image channels. Default: 3 62 | num_classes (int): Number of classes for classification head. Default: 1000 63 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 64 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 65 | drop_path_rate (float): Stochastic depth rate. Default: 0. 66 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 67 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
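    Example (an illustrative sketch; output shapes assume a 224x224 input to this
    feature-extractor variant):
        model = ConvNeXt(in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        feats = model(torch.randn(1, 3, 224, 224))
        # `feats` is a tuple of four normalized feature maps at strides 4/8/16/32:
        # (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)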
68 | """ 69 | def __init__(self, in_chans=3, num_classes=21841, 70 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 71 | layer_scale_init_value=1e-6, head_init_scale=1., out_indices=[0, 1, 2, 3], 72 | ): 73 | super().__init__() 74 | # conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv["conv", 2] 75 | # self._conv_stem = conv_type(self.in_channels, self.in_channels, kernel_size=3, stride=stride, bias=False) 76 | # self._conv_stem_padding = _make_same_padder(self._conv_stem, current_image_size) 77 | 78 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 79 | stem = nn.Sequential( 80 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 81 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 82 | ) 83 | self.downsample_layers.append(stem) 84 | for i in range(3): 85 | downsample_layer = nn.Sequential( 86 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 87 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 88 | ) 89 | self.downsample_layers.append(downsample_layer) 90 | 91 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 92 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 93 | cur = 0 94 | for i in range(4): 95 | stage = nn.Sequential( 96 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 97 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 98 | ) 99 | self.stages.append(stage) 100 | cur += depths[i] 101 | 102 | 103 | self.out_indices = out_indices 104 | 105 | norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first") 106 | for i_layer in range(4): 107 | layer = norm_layer(dims[i_layer]) 108 | layer_name = f'norm{i_layer}' 109 | self.add_module(layer_name, layer) 110 | self.apply(self._init_weights) 111 | 112 | 113 | def _init_weights(self, m): 114 | if isinstance(m, (nn.Conv2d, nn.Linear)): 115 | trunc_normal_(m.weight, std=.02) 116 | nn.init.constant_(m.bias, 0) 117 | 118 | def forward_features(self, x): 119 | outs = [] 120 | 121 | for i in range(4): 122 | x = self.downsample_layers[i](x) 123 | x = self.stages[i](x) 124 | if i in self.out_indices: 125 | norm_layer = getattr(self, f'norm{i}') 126 | x_out = norm_layer(x) 127 | 128 | outs.append(x_out) 129 | 130 | return tuple(outs) 131 | 132 | def forward(self, x): 133 | x = self.forward_features(x) 134 | 135 | return x 136 | 137 | class LayerNorm(nn.Module): 138 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 139 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 140 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 141 | with shape (batch_size, channels, height, width). 
142 | """ 143 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 144 | super().__init__() 145 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 146 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 147 | self.eps = eps 148 | self.data_format = data_format 149 | if self.data_format not in ["channels_last", "channels_first"]: 150 | raise NotImplementedError 151 | self.normalized_shape = (normalized_shape, ) 152 | 153 | def forward(self, x): 154 | if self.data_format == "channels_last": 155 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 156 | elif self.data_format == "channels_first": 157 | u = x.mean(1, keepdim=True) 158 | s = (x - u).pow(2).mean(1, keepdim=True) 159 | x = (x - u) / torch.sqrt(s + self.eps) 160 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 161 | return x 162 | 163 | 164 | model_urls = { 165 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 166 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 167 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 168 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 169 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 170 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 171 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 172 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 173 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 174 | } 175 | 176 | @register_model 177 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs): 178 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 179 | if pretrained: 180 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 181 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 182 | model.load_state_dict(checkpoint["model"]) 183 | return model 184 | 185 | @register_model 186 | def convnext_small(pretrained=False,in_22k=False, **kwargs): 187 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 188 | if pretrained: 189 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] 190 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 191 | model.load_state_dict(checkpoint["model"], strict=False) 192 | return model 193 | 194 | @register_model 195 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 196 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 197 | if pretrained: 198 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 199 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 200 | model.load_state_dict(checkpoint["model"], strict=False) 201 | return model 202 | 203 | @register_model 204 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 205 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 206 | if pretrained: 207 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 208 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 209 | 
model.load_state_dict(checkpoint["model"]) 210 | return model 211 | 212 | @register_model 213 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 214 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 215 | if pretrained: 216 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 217 | url = model_urls['convnext_xlarge_22k'] 218 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 219 | model.load_state_dict(checkpoint["model"]) 220 | return model -------------------------------------------------------------------------------- /models/flexible_unet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from typing import List, Optional, Sequence, Tuple, Union 13 | 14 | import torch 15 | from torch import nn 16 | 17 | from monai.networks.blocks import UpSample 18 | from monai.networks.layers.factories import Conv 19 | from monai.networks.layers.utils import get_act_layer 20 | from monai.networks.nets import EfficientNetBNFeatures 21 | from monai.networks.nets.basic_unet import UpCat 22 | from monai.utils import InterpolateMode 23 | 24 | __all__ = ["FlexibleUNet"] 25 | 26 | encoder_feature_channel = { 27 | "efficientnet-b0": (16, 24, 40, 112, 320), 28 | "efficientnet-b1": (16, 24, 40, 112, 320), 29 | "efficientnet-b2": (16, 24, 48, 120, 352), 30 | "efficientnet-b3": (24, 32, 48, 136, 384), 31 | "efficientnet-b4": (24, 32, 56, 160, 448), 32 | "efficientnet-b5": (24, 40, 64, 176, 512), 33 | "efficientnet-b6": (32, 40, 72, 200, 576), 34 | "efficientnet-b7": (32, 48, 80, 224, 640), 35 | "efficientnet-b8": (32, 56, 88, 248, 704), 36 | "efficientnet-l2": (72, 104, 176, 480, 1376), 37 | } 38 | 39 | 40 | def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple: 41 | """ 42 | Get the encoder output channels by given backbone name. 43 | 44 | Args: 45 | backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b7]. 46 | in_channels: channel of input tensor, default to 3. 47 | 48 | Returns: 49 | A tuple of output feature map channels' length . 50 | """ 51 | encoder_channel_tuple = encoder_feature_channel[backbone] 52 | encoder_channel_list = [in_channels] + list(encoder_channel_tuple) 53 | encoder_channel = tuple(encoder_channel_list) 54 | return encoder_channel 55 | 56 | 57 | class UNetDecoder(nn.Module): 58 | """ 59 | UNet Decoder. 60 | This class refers to `segmentation_models.pytorch 61 | `_. 62 | 63 | Args: 64 | spatial_dims: number of spatial dimensions. 65 | encoder_channels: number of output channels for all feature maps in encoder. 66 | `len(encoder_channels)` should be no less than 2. 67 | decoder_channels: number of output channels for all feature maps in decoder. 68 | `len(decoder_channels)` should equal to `len(encoder_channels) - 1`. 69 | act: activation type and arguments. 
70 | norm: feature normalization type and arguments. 71 | dropout: dropout ratio. 72 | bias: whether to have a bias term in convolution blocks in this decoder. 73 | upsample: upsampling mode, available options are 74 | ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. 75 | pre_conv: a conv block applied before upsampling. 76 | Only used in the "nontrainable" or "pixelshuffle" mode. 77 | interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} 78 | Only used in the "nontrainable" mode. 79 | align_corners: set the align_corners parameter for upsample. Defaults to True. 80 | Only used in the "nontrainable" mode. 81 | is_pad: whether to pad upsampling features to fit the encoder spatial dims. 82 | 83 | """ 84 | 85 | def __init__( 86 | self, 87 | spatial_dims: int, 88 | encoder_channels: Sequence[int], 89 | decoder_channels: Sequence[int], 90 | act: Union[str, tuple], 91 | norm: Union[str, tuple], 92 | dropout: Union[float, tuple], 93 | bias: bool, 94 | upsample: str, 95 | pre_conv: Optional[str], 96 | interp_mode: str, 97 | align_corners: Optional[bool], 98 | is_pad: bool, 99 | ): 100 | 101 | super().__init__() 102 | if len(encoder_channels) < 2: 103 | raise ValueError("the length of `encoder_channels` should be no less than 2.") 104 | if len(decoder_channels) != len(encoder_channels) - 1: 105 | raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.") 106 | 107 | in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1]) 108 | skip_channels = list(encoder_channels[1:-1][::-1]) + [0] 109 | halves = [True] * (len(skip_channels) - 1) 110 | halves.append(False) 111 | blocks = [] 112 | for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves): 113 | blocks.append( 114 | UpCat( 115 | spatial_dims=spatial_dims, 116 | in_chns=in_chn, 117 | cat_chns=skip_chn, 118 | out_chns=out_chn, 119 | act=act, 120 | norm=norm, 121 | dropout=dropout, 122 | bias=bias, 123 | upsample=upsample, 124 | pre_conv=pre_conv, 125 | interp_mode=interp_mode, 126 | align_corners=align_corners, 127 | halves=halve, 128 | is_pad=is_pad, 129 | ) 130 | ) 131 | self.blocks = nn.ModuleList(blocks) 132 | 133 | def forward(self, features: List[torch.Tensor], skip_connect: int = 4): 134 | skips = features[:-1][::-1] 135 | features = features[1:][::-1] 136 | 137 | x = features[0] 138 | for i, block in enumerate(self.blocks): 139 | if i < skip_connect: 140 | skip = skips[i] 141 | else: 142 | skip = None 143 | x = block(x, skip) 144 | 145 | return x 146 | 147 | 148 | class SegmentationHead(nn.Sequential): 149 | """ 150 | Segmentation head. 151 | This class refers to `segmentation_models.pytorch 152 | `_. 153 | 154 | Args: 155 | spatial_dims: number of spatial dimensions. 156 | in_channels: number of input channels for the block. 157 | out_channels: number of output channels for the block. 158 | kernel_size: kernel size for the conv layer. 159 | act: activation type and arguments. 160 | scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. 
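
    Example (illustrative sketch, mirroring the `prob_head` built in FlexibleUNet below):
        head = SegmentationHead(spatial_dims=2, in_channels=16, out_channels=1, kernel_size=1, act="sigmoid")
        prob = head(torch.randn(1, 16, 64, 64))  # -> (1, 1, 64, 64), values in (0, 1)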
161 |
162 | """
163 |
164 | def __init__(
165 | self,
166 | spatial_dims: int,
167 | in_channels: int,
168 | out_channels: int,
169 | kernel_size: int = 3,
170 | act: Optional[Union[Tuple, str]] = None,
171 | scale_factor: float = 1.0,
172 | ):
173 |
174 | conv_layer = Conv[Conv.CONV, spatial_dims](
175 | in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
176 | )
177 | up_layer: nn.Module = nn.Identity()
178 | if scale_factor > 1.0:
179 | up_layer = UpSample(
180 | spatial_dims=spatial_dims,
181 | scale_factor=scale_factor,
182 | mode="nontrainable",
183 | pre_conv=None,
184 | interp_mode=InterpolateMode.LINEAR,
185 | )
186 | if act is not None:
187 | act_layer = get_act_layer(act)
188 | else:
189 | act_layer = nn.Identity()
190 | super().__init__(conv_layer, up_layer, act_layer)
191 |
192 |
193 | class FlexibleUNet(nn.Module):
194 | """
195 | A flexible implementation of UNet-like encoder-decoder architecture.
196 | """
197 |
198 | def __init__(
199 | self,
200 | in_channels: int,
201 | out_channels: int,
202 | backbone: str,
203 | pretrained: bool = False,
204 | decoder_channels: Tuple = (256, 128, 64, 32, 16),
205 | spatial_dims: int = 2,
206 | norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
207 | act: Union[str, tuple] = ("relu", {"inplace": True}),
208 | dropout: Union[float, tuple] = 0.0,
209 | decoder_bias: bool = False,
210 | upsample: str = "nontrainable",
211 | interp_mode: str = "nearest",
212 | is_pad: bool = True,
213 | ) -> None:
214 | """
215 | A flexible implementation of UNet, in which the backbone/encoder can be replaced with
216 | any efficient network. Currently the input must have 2 or 3 spatial dimensions,
217 | and the spatial size of each dimension must be a multiple of 32 if the is_pad
218 | parameter is False.
219 |
220 | Args:
221 | in_channels: number of input channels.
222 | out_channels: number of output channels.
223 | backbone: name of the backbone to initialize; only EfficientNet is supported right now,
224 | can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2].
225 | pretrained: whether to initialize pretrained ImageNet weights, only available
226 | for spatial_dims=2 when batch norm is used, default to False.
227 | decoder_channels: number of output channels for all feature maps in decoder.
228 | `len(decoder_channels)` should equal `len(encoder_channels) - 1`, default
229 | to (256, 128, 64, 32, 16).
230 | spatial_dims: number of spatial dimensions, default to 2.
231 | norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
232 | "momentum": 0.1}).
233 | act: activation type and arguments, default to ("relu", {"inplace": True}).
234 | dropout: dropout ratio, default to 0.0.
235 | decoder_bias: whether to have a bias term in decoder's convolution blocks.
236 | upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``,
237 | ``"nontrainable"``.
238 | interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
239 | Only used in the "nontrainable" mode.
240 | is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
241 | If this parameter is set to "True", the spatial dim of network input can be arbitrary
242 | size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
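
        Example (illustrative sketch; the 256x256 input sides are multiples of 32, so the
        default is_pad=True is not strictly needed here):
            net = FlexibleUNet(in_channels=3, out_channels=2, backbone="efficientnet-b0")
            dist, prob = net(torch.randn(1, 3, 256, 256))
            # this variant ends in two heads: dist is (1, 32, 256, 256), prob is (1, 1, 256, 256)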
243 | """ 244 | super().__init__() 245 | 246 | if backbone not in encoder_feature_channel: 247 | raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.") 248 | 249 | if spatial_dims not in (2, 3): 250 | raise ValueError("spatial_dims can only be 2 or 3.") 251 | 252 | adv_prop = "ap" in backbone 253 | 254 | self.backbone = backbone 255 | self.spatial_dims = spatial_dims 256 | model_name = backbone 257 | encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels) 258 | self.encoder = EfficientNetBNFeatures( 259 | model_name=model_name, 260 | pretrained=pretrained, 261 | in_channels=in_channels, 262 | spatial_dims=spatial_dims, 263 | norm=norm, 264 | adv_prop=adv_prop, 265 | ) 266 | self.decoder = UNetDecoder( 267 | spatial_dims=spatial_dims, 268 | encoder_channels=encoder_channels, 269 | decoder_channels=decoder_channels, 270 | act=act, 271 | norm=norm, 272 | dropout=dropout, 273 | bias=decoder_bias, 274 | upsample=upsample, 275 | interp_mode=interp_mode, 276 | pre_conv=None, 277 | align_corners=None, 278 | is_pad=is_pad, 279 | ) 280 | self.dist_head = SegmentationHead( 281 | spatial_dims=spatial_dims, 282 | in_channels=decoder_channels[-1], 283 | out_channels=32, 284 | kernel_size=1, 285 | act='relu', 286 | ) 287 | self.prob_head = SegmentationHead( 288 | spatial_dims=spatial_dims, 289 | in_channels=decoder_channels[-1], 290 | out_channels=1, 291 | kernel_size=1, 292 | act='sigmoid', 293 | ) 294 | 295 | def forward(self, inputs: torch.Tensor): 296 | """ 297 | Do a typical encoder-decoder-header inference. 298 | 299 | Args: 300 | inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, 301 | N is defined by `dimensions`. 302 | 303 | Returns: 304 | A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``. 
305 | 306 | """ 307 | x = inputs 308 | enc_out = self.encoder(x) 309 | decoder_out = self.decoder(enc_out) 310 | dist = self.dist_head(decoder_out) 311 | prob = self.prob_head(decoder_out) 312 | return dist,prob 313 | -------------------------------------------------------------------------------- /overlay.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | ###overlay 5 | import cv2 6 | import math 7 | import random 8 | import colorsys 9 | import numpy as np 10 | import itertools 11 | import matplotlib.pyplot as plt 12 | from matplotlib import cm 13 | import os 14 | import scipy.io as io 15 | def get_bounding_box(img): 16 | """Get bounding box coordinate information.""" 17 | rows = np.any(img, axis=1) 18 | cols = np.any(img, axis=0) 19 | rmin, rmax = np.where(rows)[0][[0, -1]] 20 | cmin, cmax = np.where(cols)[0][[0, -1]] 21 | # due to python indexing, need to add 1 to max 22 | # else accessing will be 1px in the box, not out 23 | rmax += 1 24 | cmax += 1 25 | return [rmin, rmax, cmin, cmax] 26 | #### 27 | def colorize(ch, vmin, vmax): 28 | """Will clamp value value outside the provided range to vmax and vmin.""" 29 | cmap = plt.get_cmap("jet") 30 | ch = np.squeeze(ch.astype("float32")) 31 | vmin = vmin if vmin is not None else ch.min() 32 | vmax = vmax if vmax is not None else ch.max() 33 | ch[ch > vmax] = vmax # clamp value 34 | ch[ch < vmin] = vmin 35 | ch = (ch - vmin) / (vmax - vmin + 1.0e-16) 36 | # take RGB from RGBA heat map 37 | ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") 38 | return ch_cmap 39 | 40 | 41 | #### 42 | def random_colors(N, bright=True): 43 | """Generate random colors. 44 | 45 | To get visually distinct colors, generate them in HSV space then 46 | convert to RGB. 47 | """ 48 | brightness = 1.0 if bright else 0.7 49 | hsv = [(i / N, 1, brightness) for i in range(N)] 50 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 51 | random.shuffle(colors) 52 | return colors 53 | 54 | 55 | #### 56 | def visualize_instances_map( 57 | input_image, inst_map, type_map=None, type_colour=None, line_thickness=2 58 | ): 59 | """Overlays segmentation results on image as contours. 
60 | 61 | Args: 62 | input_image: input image 63 | inst_map: instance mask with unique value for every object 64 | type_map: type mask with unique value for every class 65 | type_colour: a dict of {type : colour} , `type` is from 0-N 66 | and `colour` is a tuple of (R, G, B) 67 | line_thickness: line thickness of contours 68 | 69 | Returns: 70 | overlay: output image with segmentation overlay as contours 71 | """ 72 | overlay = np.copy((input_image).astype(np.uint8)) 73 | 74 | inst_list = list(np.unique(inst_map)) # get list of instances 75 | inst_list.remove(0) # remove background 76 | 77 | inst_rng_colors = random_colors(len(inst_list)) 78 | inst_rng_colors = np.array(inst_rng_colors) * 255 79 | inst_rng_colors = inst_rng_colors.astype(np.uint8) 80 | 81 | for inst_idx, inst_id in enumerate(inst_list): 82 | inst_map_mask = np.array(inst_map == inst_id, np.uint8) # get single object 83 | y1, y2, x1, x2 = get_bounding_box(inst_map_mask) 84 | y1 = y1 - 2 if y1 - 2 >= 0 else y1 85 | x1 = x1 - 2 if x1 - 2 >= 0 else x1 86 | x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2 87 | y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2 88 | inst_map_crop = inst_map_mask[y1:y2, x1:x2] 89 | contours_crop = cv2.findContours( 90 | inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 91 | ) 92 | # only has 1 instance per map, no need to check #contour detected by opencv 93 | #print(contours_crop) 94 | contours_crop = np.squeeze( 95 | contours_crop[0][0].astype("int32") 96 | ) # * opencv protocol format may break 97 | 98 | if len(contours_crop.shape) == 1: 99 | contours_crop = contours_crop.reshape(1,-1) 100 | #print(contours_crop.shape) 101 | contours_crop += np.asarray([[x1, y1]]) # index correction 102 | if type_map is not None: 103 | type_map_crop = type_map[y1:y2, x1:x2] 104 | type_id = np.unique(type_map_crop).max() # non-zero 105 | inst_colour = type_colour[type_id] 106 | else: 107 | inst_colour = (inst_rng_colors[inst_idx]).tolist() 108 | cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness) 109 | return overlay 110 | 111 | 112 | # In[ ]: 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | join = os.path.join 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from collections import OrderedDict 9 | from torchvision import datasets, models, transforms 10 | from classifiers import resnet10, resnet18 11 | 12 | from utils_modify import sliding_window_inference,sliding_window_inference_large,__proc_np_hv 13 | from PIL import Image 14 | import torch.nn.functional as F 15 | from skimage import io, segmentation, morphology, measure, exposure 16 | import tifffile as tif 17 | from models.flexible_unet_convnext import FlexibleUNet_star,FlexibleUNet_hv 18 | #from overlay import visualize_instances_map 19 | 20 | def normalize_channel(img, lower=1, upper=99): 21 | non_zero_vals = img[np.nonzero(img)] 22 | percentiles = np.percentile(non_zero_vals, [lower, upper]) 23 | if percentiles[1] - percentiles[0] > 0.001: 24 | img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8') 25 | else: 26 | img_norm = img 27 | return img_norm.astype(np.uint8) 28 | #torch.cuda.synchronize() 29 | parser = argparse.ArgumentParser('Baseline for Microscopy image segmentation', add_help=False) 30 | # Dataset parameters 31 | parser.add_argument('-i', 
'--input_path', default='./inputs', type=str, help='training data path; subfolders: images, labels') 32 | parser.add_argument("-o", '--output_path', default='./outputs', type=str, help='output path') 33 | parser.add_argument('--model_path', default='./models', help='path where to save models and segmentation results') 34 | parser.add_argument('--show_overlay', required=False, default=False, action="store_true", help='save segmentation overlay') 35 | 36 | # Model parameters 37 | parser.add_argument('--model_name', default='efficientunet', help='select mode: unet, unetr, swinunetr') 38 | parser.add_argument('--input_size', default=512, type=int, help='segmentation classes') 39 | args = parser.parse_args() 40 | input_path = args.input_path 41 | output_path = args.output_path 42 | model_path = args.model_path 43 | os.makedirs(output_path, exist_ok=True) 44 | #overlay_path = 'overlays/' 45 | #print(input_path) 46 | 47 | img_names = sorted(os.listdir(join(input_path))) 48 | #print(img_names) 49 | 50 | 51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 52 | 53 | 54 | preprocess=transforms.Compose([ 55 | transforms.Resize(size=256), 56 | transforms.CenterCrop(size=224), 57 | transforms.ToTensor(), 58 | transforms.Normalize([0.485, 0.456, 0.406], 59 | [0.229, 0.224, 0.225]) 60 | ]) 61 | roi_size = (512, 512) 62 | overlap = 0.5 63 | np_thres, ksize, overall_thres, obj_size_thres = 0.6, 15, 0.4, 100 64 | n_rays = 32 65 | sw_batch_size = 4 66 | num_classes= 4 67 | block_size = 2048 68 | min_overlap = 128 69 | context = 128 70 | with torch.no_grad(): 71 | for img_name in img_names: 72 | #print(img_name) 73 | if img_name.endswith('.tif') or img_name.endswith('.tiff'): 74 | img_data = tif.imread(join(input_path, img_name)) 75 | else: 76 | img_data = io.imread(join(input_path, img_name)) 77 | # normalize image data 78 | if len(img_data.shape) == 2: 79 | img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1) 80 | elif len(img_data.shape) == 3 and img_data.shape[-1] > 3: 81 | img_data = img_data[:,:, :3] 82 | else: 83 | pass 84 | pre_img_data = np.zeros(img_data.shape, dtype=np.uint8) 85 | for i in range(3): 86 | img_channel_i = img_data[:,:,i] 87 | if len(img_channel_i[np.nonzero(img_channel_i)])>0: 88 | pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=1, upper=99) 89 | inputs=preprocess(Image.fromarray(pre_img_data)).unsqueeze(0).to(device) 90 | cls_MODEL = model_path + '/cls/resnet18_4class_all_modified.tar' 91 | model = resnet18().to(device) 92 | model.load_state_dict(torch.load(cls_MODEL)) 93 | model.eval() 94 | outputs = model(inputs) 95 | _, preds = torch.max(outputs, 1) 96 | label=preds[0].cpu().numpy() 97 | #print(label) 98 | test_npy01 = pre_img_data 99 | if label in [0,1,2] or img_data.shape[0] > 4000: 100 | if label == 0: 101 | model = FlexibleUNet_star(in_channels=3,out_channels=n_rays+1,backbone='convnext_small',pretrained=False,n_rays=n_rays,prob_out_channels=1,).to(device) 102 | checkpoint = torch.load(model_path+'/0/best_model.pth', map_location=torch.device(device)) 103 | model.load_state_dict(checkpoint['model_state_dict']) 104 | model.eval() 105 | 106 | output_label = sliding_window_inference_large(test_npy01,block_size,min_overlap,context, roi_size,sw_batch_size,predictor=model,device=device) 107 | tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label) 108 | 109 | elif label == 1: 110 | model = 
FlexibleUNet_star(in_channels=3,out_channels=n_rays+1,backbone='convnext_small',pretrained=False,n_rays=n_rays,prob_out_channels=1,).to(device) 111 | checkpoint = torch.load(model_path+'/1/best_model.pth', map_location=torch.device(device)) 112 | model.load_state_dict(checkpoint['model_state_dict']) 113 | model.eval() 114 | 115 | output_label = sliding_window_inference_large(test_npy01,block_size,min_overlap,context, roi_size,sw_batch_size,predictor=model,device=device) 116 | tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label) 117 | elif label == 2: 118 | model = FlexibleUNet_star(in_channels=3,out_channels=n_rays+1,backbone='convnext_small',pretrained=False,n_rays=n_rays,prob_out_channels=1,).to(device) 119 | checkpoint = torch.load(model_path+'/2/best_model.pth', map_location=torch.device(device)) 120 | model.load_state_dict(checkpoint['model_state_dict']) 121 | model.eval() 122 | 123 | output_label = sliding_window_inference_large(test_npy01,block_size,min_overlap,context, roi_size,sw_batch_size,predictor=model,device=device) 124 | tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label) 125 | 126 | 127 | else: 128 | model = FlexibleUNet_hv(in_channels=3,out_channels=2+2,backbone='convnext_small',pretrained=False,n_rays=2,prob_out_channels=2,).to(device) 129 | checkpoint = torch.load(model_path+'/3/best_model_converted.pth', map_location=torch.device(device)) 130 | #model.load_state_dict(checkpoint['model_state_dict']) 131 | #od = OrderedDict() 132 | #for k, v in checkpoint['model_state_dict'].items(): 133 | #od[k.replace('module.', '')] = v 134 | model.load_state_dict(checkpoint) 135 | model.to(device) 136 | model.eval() 137 | test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device) 138 | if isinstance(roi_size, tuple): 139 | roi = roi_size 140 | 141 | output_hv, output_np = sliding_window_inference(test_tensor, roi, sw_batch_size, model, overlap=overlap) 142 | pred_dict = {'np': output_np, 'hv': output_hv} 143 | pred_dict = OrderedDict( 144 | [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] # NHWC 145 | ) 146 | pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:] 147 | pred_output = torch.cat(list(pred_dict.values()), -1).cpu().numpy() # NHW3 148 | pred_map = np.squeeze(pred_output) # HW3 149 | pred_inst = __proc_np_hv(pred_map, np_thres, ksize, overall_thres, obj_size_thres) 150 | raw_pred_shape = pred_inst.shape[:2] 151 | output_label = pred_inst 152 | 153 | tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label) 154 | 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /predict_unet_convnext.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | join = os.path.join 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import monai 8 | import torch.nn as nn 9 | 10 | from utils import sliding_window_inference 11 | #from baseline.models.unetr2d import UNETR2D 12 | import time 13 | from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label 14 | from stardist import random_label_cmap,ray_angles 15 | from stardist import star_dist,edt_prob 16 | from skimage import io, segmentation, morphology, measure, exposure 17 | import tifffile as tif 18 | import cv2 19 | from overlay import visualize_instances_map 20 | from models.flexible_unet import FlexibleUNet 21 | from 
models.flexible_unet_convext import FlexibleUNetConvext 22 | def normalize_channel(img, lower=1, upper=99): 23 | non_zero_vals = img[np.nonzero(img)] 24 | percentiles = np.percentile(non_zero_vals, [lower, upper]) 25 | if percentiles[1] - percentiles[0] > 0.001: 26 | img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8') 27 | else: 28 | img_norm = img 29 | return img_norm.astype(np.uint8) 30 | 31 | def main(): 32 | parser = argparse.ArgumentParser('Baseline for Microscopy image segmentation', add_help=False) 33 | # Dataset parameters 34 | #parser.add_argument('-i', '--input_path', default='./inputs', type=str, help='training data path; subfolders: images, labels') 35 | #parser.add_argument("-o", '--output_path', default='./outputs', type=str, help='output path') 36 | parser.add_argument('--model_path', default='./work_dir/swinunetr_3class', help='path where to save models and segmentation results') 37 | parser.add_argument('--show_overlay', required=False, default=False, action="store_true", help='save segmentation overlay') 38 | 39 | # Model parameters 40 | parser.add_argument('--model_name', default='efficientunet', help='select mode: unet, unetr, swinunetr') 41 | parser.add_argument('--num_class', default=3, type=int, help='segmentation classes') 42 | parser.add_argument('--input_size', default=512, type=int, help='segmentation classes') 43 | args = parser.parse_args() 44 | 45 | input_path = '/home/data/TuningSet/' 46 | output_path = '/home/data/output/' 47 | overlay_path = '/home/data/overlay/' 48 | 49 | 50 | img_names = sorted(os.listdir(join(input_path))) 51 | n_rays = 32 52 | 53 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 54 | 55 | 56 | 57 | if args.model_name.lower() == "efficientunet": 58 | model = FlexibleUNetConvext( 59 | in_channels=3, 60 | out_channels=n_rays+1, 61 | backbone='convnext_small', 62 | pretrained=True, 63 | ).to(device) 64 | 65 | 66 | 67 | sigmoid = nn.Sigmoid() 68 | checkpoint = torch.load('/home/louwei/stardist_convnext/efficientunet_3class/best_model.pth', map_location=torch.device(device)) 69 | model.load_state_dict(checkpoint['model_state_dict']) 70 | #%% 71 | roi_size = (args.input_size, args.input_size) 72 | sw_batch_size = 4 73 | model.eval() 74 | with torch.no_grad(): 75 | for img_name in img_names: 76 | print(img_name) 77 | if img_name.endswith('.tif') or img_name.endswith('.tiff'): 78 | img_data = tif.imread(join(input_path, img_name)) 79 | else: 80 | img_data = io.imread(join(input_path, img_name)) 81 | # normalize image data 82 | if len(img_data.shape) == 2: 83 | img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1) 84 | elif len(img_data.shape) == 3 and img_data.shape[-1] > 3: 85 | img_data = img_data[:,:, :3] 86 | else: 87 | pass 88 | pre_img_data = np.zeros(img_data.shape, dtype=np.uint8) 89 | for i in range(3): 90 | img_channel_i = img_data[:,:,i] 91 | if len(img_channel_i[np.nonzero(img_channel_i)])>0: 92 | pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=1, upper=99) 93 | 94 | t0 = time.time() 95 | #test_npy01 = pre_img_data/np.max(pre_img_data) 96 | test_npy01 = pre_img_data 97 | test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device) 98 | output_dist,output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, model) 99 | #test_pred_out = torch.nn.functional.softmax(test_pred_out, dim=1) # (B, C, H, W) 100 | prob = output_prob[0][0].cpu().numpy() 101 | dist = 
output_dist[0].cpu().numpy() 102 | 103 | 104 | dist = np.transpose(dist,(1,2,0)) 105 | dist = np.maximum(1e-3, dist) 106 | points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4) 107 | 108 | coord = dist_to_coord(disti,points) 109 | 110 | star_label = polygons_to_label(disti, points, prob=probi,shape=prob.shape) 111 | tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), star_label) 112 | overlay = visualize_instances_map(pre_img_data,star_label) 113 | cv2.imwrite(join(overlay_path, img_name.split('.')[0]+'.png'), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)) 114 | 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gputools==0.2.13 2 | h5py==3.7.0 3 | huggingface-hub==0.10.1 4 | imagecodecs 5 | imageio==2.22.2 6 | importlib-metadata==5.0.0 7 | kiwisolver==1.4.4 8 | llvmlite==0.39.1 9 | Mako==1.2.3 10 | Markdown==3.4.1 11 | MarkupSafe==2.1.1 12 | matplotlib==3.6.1 13 | mkl-fft==1.3.1 14 | mkl-service==2.4.0 15 | monai==1.0.0 16 | networkx==2.8.7 17 | numba==0.56.3 18 | numexpr 19 | numpy 20 | oauthlib==3.2.2 21 | opencv-python==4.6.0.66 22 | packaging 23 | pandas==1.4.4 24 | Pillow==9.2.0 25 | scikit-image==0.19.3 26 | scipy==1.9.2 27 | stardist==0.8.3 28 | tensorboard==2.10.1 29 | tensorboard-data-server==0.6.1 30 | tensorboard-plugin-wit==1.8.1 31 | tifffile==2022.10.10 32 | timm==0.6.11 33 | torch==1.12.1 34 | torchaudio==0.12.1 35 | torchvision==0.13.1 36 | tqdm==4.64.1 37 | 38 | -------------------------------------------------------------------------------- /stardist_pkg/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | 3 | import warnings 4 | def format_warning(message, category, filename, lineno, line=''): 5 | import pathlib 6 | return f"{pathlib.Path(filename).name} ({lineno}): {message}\n" 7 | warnings.formatwarning = format_warning 8 | del warnings 9 | 10 | from .version import __version__ 11 | 12 | # TODO: which functions to expose here? all? 
13 | from .nms import non_maximum_suppression 14 | from .utils import edt_prob, fill_label_holes, sample_points, calculate_extents, export_imagej_rois, gputools_available 15 | from .geometry import star_dist, polygons_to_label, relabel_image_stardist, ray_angles, dist_to_coord 16 | from .sample_patches import sample_patches 17 | from .bioimageio_utils import export_bioimageio, import_bioimageio 18 | 19 | def _py_deprecation(ver_python=(3,6), ver_stardist='0.9.0'): 20 | import sys 21 | from distutils.version import LooseVersion 22 | if sys.version_info[:2] == ver_python and LooseVersion(__version__) < LooseVersion(ver_stardist): 23 | print(f"You are using Python {ver_python[0]}.{ver_python[1]}, which will no longer be supported in StarDist {ver_stardist}.\n" 24 | f"→ Please upgrade to Python {ver_python[0]}.{ver_python[1]+1} or later.", file=sys.stderr, flush=True) 25 | _py_deprecation() 26 | del _py_deprecation 27 | -------------------------------------------------------------------------------- /stardist_pkg/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/stardist_pkg/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /stardist_pkg/__pycache__/big.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/stardist_pkg/__pycache__/big.cpython-37.pyc -------------------------------------------------------------------------------- /stardist_pkg/__pycache__/bioimageio_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/stardist_pkg/__pycache__/bioimageio_utils.cpython-37.pyc -------------------------------------------------------------------------------- /stardist_pkg/__pycache__/matching.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/stardist_pkg/__pycache__/matching.cpython-37.pyc -------------------------------------------------------------------------------- /stardist_pkg/__pycache__/nms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/stardist_pkg/__pycache__/nms.cpython-37.pyc -------------------------------------------------------------------------------- /stardist_pkg/__pycache__/sample_patches.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/stardist_pkg/__pycache__/sample_patches.cpython-37.pyc -------------------------------------------------------------------------------- /stardist_pkg/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhaof/CellSeg/1c7bb981241574e2a268498491aa1ca38b6fcaa8/stardist_pkg/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /stardist_pkg/__pycache__/version.cpython-37.pyc: 
-------------------------------------------------------------------------------- /stardist_pkg/geometry/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import, print_function
2 | 
3 | # TODO: rethink naming for 2D/3D functions
4 | 
5 | from .geom2d import star_dist, relabel_image_stardist, ray_angles, dist_to_coord, polygons_to_label, polygons_to_label_coord
6 | 
7 | from .geom2d import _dist_to_coord_old, _polygons_to_label_old
8 | 
9 | #, dist_to_volume, dist_to_centroid
10 | -------------------------------------------------------------------------------- /stardist_pkg/geometry/geom2d.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function, unicode_literals, absolute_import, division
2 | import numpy as np
3 | import warnings
4 | 
5 | from skimage.measure import regionprops
6 | from skimage.draw import polygon
7 | from csbdeep.utils import _raise
8 | 
9 | from ..utils import path_absolute, _is_power_of_2, _normalize_grid
10 | from ..matching import _check_label_array
11 | from stardist.lib.stardist2d import c_star_dist
12 | 
13 | 
14 | 
15 | def _ocl_star_dist(lbl, n_rays=32, grid=(1,1)):
16 | from gputools import OCLProgram, OCLArray, OCLImage
17 | (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError())
18 | n_rays = int(n_rays)
19 | # slicing with grid is done with tuple(slice(0, None, g) for g in grid)
20 | res_shape = tuple((s-1)//g+1 for s, g in zip(lbl.shape, grid))
21 | 
22 | src = OCLImage.from_array(lbl.astype(np.uint16,copy=False))
23 | dst = OCLArray.empty(res_shape+(n_rays,), dtype=np.float32)
24 | program = OCLProgram(path_absolute("kernels/stardist2d.cl"), build_options=['-D', 'N_RAYS=%d' % n_rays])
25 | program.run_kernel('star_dist', res_shape[::-1], None, dst.data, src, np.int32(grid[0]),np.int32(grid[1]))
26 | return dst.get()
27 | 
28 | 
29 | def _cpp_star_dist(lbl, n_rays=32, grid=(1,1)):
30 | (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError())
31 | return c_star_dist(lbl.astype(np.uint16,copy=False), np.int32(n_rays), np.int32(grid[0]),np.int32(grid[1]))
32 | 
33 | 
34 | def _py_star_dist(a, n_rays=32, grid=(1,1)):
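    # Explanatory note (added): pure-Python reference implementation. For every
    # foreground pixel (i, j), march outward along each of `n_rays` equally
    # spaced ray directions until the object id under the ray changes, then
    # store the boundary-corrected Euclidean distance. The 'cpp' and 'opencl'
    # variants above implement the same algorithm.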
35 | (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError())
36 | if grid != (1,1):
37 | raise NotImplementedError(grid)
38 | 
39 | n_rays = int(n_rays)
40 | a = a.astype(np.uint16,copy=False)
41 | dst = np.empty(a.shape+(n_rays,),np.float32)
42 | 
43 | for i in range(a.shape[0]):
44 | for j in range(a.shape[1]):
45 | value = a[i,j]
46 | if value == 0:
47 | dst[i,j] = 0
48 | else:
49 | st_rays = np.float32((2*np.pi) / n_rays)
50 | for k in range(n_rays):
51 | phi = np.float32(k*st_rays)
52 | dy = np.cos(phi)
53 | dx = np.sin(phi)
54 | x, y = np.float32(0), np.float32(0)
55 | while True:
56 | x += dx
57 | y += dy
58 | ii = int(round(i+x))
59 | jj = int(round(j+y))
60 | if (ii < 0 or ii >= a.shape[0] or
61 | jj < 0 or jj >= a.shape[1] or
62 | value != a[ii,jj]):
63 | # small correction as we overshoot the boundary
64 | t_corr = 1-.5/max(np.abs(dx),np.abs(dy))
65 | x -= t_corr*dx
66 | y -= t_corr*dy
67 | dist = np.sqrt(x**2+y**2)
68 | dst[i,j,k] = dist
69 | break
70 | return dst
71 | 
72 | 
73 | def star_dist(a, n_rays=32, grid=(1,1), mode='cpp'):
74 | """'a' assumed to be a label image with integer values that encode object ids. id 0 denotes background."""
75 | 
76 | n_rays >= 3 or _raise(ValueError("need 'n_rays' >= 3"))
77 | 
78 | if mode == 'python':
79 | return _py_star_dist(a, n_rays, grid=grid)
80 | elif mode == 'cpp':
81 | return _cpp_star_dist(a, n_rays, grid=grid)
82 | elif mode == 'opencl':
83 | return _ocl_star_dist(a, n_rays, grid=grid)
84 | else:
85 | _raise(ValueError("Unknown mode %s" % mode))
86 | 
87 | 
88 | def _dist_to_coord_old(rhos, grid=(1,1)):
89 | """convert from polar to cartesian coordinates for a single image (3-D array) or multiple images (4-D array)"""
90 | 
91 | grid = _normalize_grid(grid,2)
92 | is_single_image = rhos.ndim == 3
93 | if is_single_image:
94 | rhos = np.expand_dims(rhos,0)
95 | assert rhos.ndim == 4
96 | 
97 | n_images,h,w,n_rays = rhos.shape
98 | coord = np.empty((n_images,h,w,2,n_rays),dtype=rhos.dtype)
99 | 
100 | start = np.indices((h,w))
101 | for i in range(2):
102 | coord[...,i,:] = grid[i] * np.broadcast_to(start[i].reshape(1,h,w,1), (n_images,h,w,n_rays))
103 | 
104 | phis = ray_angles(n_rays).reshape(1,1,1,n_rays)
105 | 
106 | coord[...,0,:] += rhos * np.sin(phis) # row coordinate
107 | coord[...,1,:] += rhos * np.cos(phis) # col coordinate
108 | 
109 | return coord[0] if is_single_image else coord
110 | 
111 | 
112 | def _polygons_to_label_old(coord, prob, points, shape=None, thr=-np.inf):
113 | sh = coord.shape[:2] if shape is None else shape
114 | lbl = np.zeros(sh,np.int32)
115 | # sort points with increasing probability
116 | ind = np.argsort([ prob[p[0],p[1]] for p in points ])
117 | points = points[ind]
118 | 
119 | i = 1
120 | for p in points:
121 | if prob[p[0],p[1]] < thr:
122 | continue
123 | rr,cc = polygon(coord[p[0],p[1],0], coord[p[0],p[1],1], sh)
124 | lbl[rr,cc] = i
125 | i += 1
126 | 
127 | return lbl
128 | 
129 | 
130 | def dist_to_coord(dist, points, scale_dist=(1,1)):
131 | """convert from polar to cartesian coordinates for a list of distances and center points
132 | dist.shape = (n_polys, n_rays)
133 | points.shape = (n_polys, 2)
134 | len(scale_dist) = 2
135 | return coord.shape = (n_polys,2,n_rays)
136 | """
137 | dist = np.asarray(dist)
138 | points = np.asarray(points)
139 | assert dist.ndim==2 and points.ndim==2 and len(dist)==len(points) \
140 | and points.shape[1]==2 and len(scale_dist)==2
141 | n_rays = dist.shape[1]
142 | phis = ray_angles(n_rays)
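    # (added comment) polar -> cartesian: axis 1 holds the (row, col) offsets,
    # i.e. (sin(phi), cos(phi)), matching the convention of _dist_to_coord_old above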
143 | coord = (dist[:,np.newaxis]*np.array([np.sin(phis),np.cos(phis)])).astype(np.float32)
144 | coord *= np.asarray(scale_dist).reshape(1,2,1)
145 | coord += points[...,np.newaxis]
146 | return coord
147 | 
148 | 
149 | def polygons_to_label_coord(coord, shape, labels=None):
150 | """renders polygons to image of given shape
151 | 
152 | coord.shape = (n_polys, 2, n_rays)
153 | """
154 | coord = np.asarray(coord)
155 | if labels is None: labels = np.arange(len(coord))
156 | 
157 | _check_label_array(labels, "labels")
158 | assert coord.ndim==3 and coord.shape[1]==2 and len(coord)==len(labels)
159 | 
160 | lbl = np.zeros(shape,np.int32)
161 | 
162 | for i,c in zip(labels,coord):
163 | rr,cc = polygon(*c, shape)
164 | lbl[rr,cc] = i+1
165 | 
166 | return lbl
167 | 
168 | 
169 | def polygons_to_label(dist, points, shape, prob=None, thr=-np.inf, scale_dist=(1,1)):
170 | """converts distances and center points to label image
171 | 
172 | dist.shape = (n_polys, n_rays)
173 | points.shape = (n_polys, 2)
174 | 
175 | label ids will be consecutive and adhere to the order given
176 | """
177 | dist = np.asarray(dist)
178 | points = np.asarray(points)
179 | prob = np.inf*np.ones(len(points)) if prob is None else np.asarray(prob)
180 | 
181 | assert dist.ndim==2 and points.ndim==2 and len(dist)==len(points)
182 | assert len(points)==len(prob) and points.shape[1]==2 and prob.ndim==1
183 | 
184 | n_rays = dist.shape[1]
185 | 
186 | ind = prob>thr
187 | points = points[ind]
188 | dist = dist[ind]
189 | prob = prob[ind]
190 | 
191 | ind = np.argsort(prob, kind='stable')
192 | points = points[ind]
193 | dist = dist[ind]
194 | 
195 | coord = dist_to_coord(dist, points, scale_dist=scale_dist)
196 | 
197 | return polygons_to_label_coord(coord, shape=shape, labels=ind)
198 | 
199 | 
200 | def relabel_image_stardist(lbl, n_rays, **kwargs):
201 | """relabel each label region in `lbl` with its star representation"""
202 | _check_label_array(lbl, "lbl")
203 | if not lbl.ndim==2:
204 | raise ValueError("lbl image should be 2 dimensional")
205 | dist = star_dist(lbl, n_rays, **kwargs)
206 | points = np.array(tuple(np.array(r.centroid).astype(int) for r in regionprops(lbl)))
207 | dist = dist[tuple(points.T)]
208 | return polygons_to_label(dist, points, shape=lbl.shape)
209 | 
210 | 
211 | def ray_angles(n_rays=32):
212 | return np.linspace(0,2*np.pi,n_rays,endpoint=False)
213 | -------------------------------------------------------------------------------- /stardist_pkg/kernels/stardist2d.cl: --------------------------------------------------------------------------------
1 | #ifndef M_PI
2 | #define M_PI 3.141592653589793
3 | #endif
4 | 
5 | __constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
6 | 
7 | inline float2 pol2cart(const float rho, const float phi) {
8 | const float x = rho * cos(phi);
9 | const float y = rho * sin(phi);
10 | return (float2)(x,y);
11 | }
12 | 
13 | __kernel void star_dist(__global float* dst, read_only image2d_t src, const int grid_y, const int grid_x) {
14 | 
15 | const int i = get_global_id(0), j = get_global_id(1);
16 | const int Nx = get_global_size(0), Ny = get_global_size(1);
17 | const float2 grid = (float2)(grid_x, grid_y);
18 | 
19 | const float2 origin = (float2)(i,j) * grid;
20 | const int value = read_imageui(src,sampler,origin).x;
21 | 
22 | if (value == 0) {
23 | // background pixel -> nothing to do, write all zeros
24 | for (int k = 0; k < N_RAYS; k++) {
25 | dst[k + i*N_RAYS + j*N_RAYS*Nx] = 0;
26 | }
27 | } else {
28 | float st_rays = (2*M_PI) / N_RAYS; // step size
for ray angles
29 | // for all rays
30 | for (int k = 0; k < N_RAYS; k++) {
31 | const float phi = k*st_rays; // current ray angle phi
32 | const float2 dir = pol2cart(1,phi); // small vector in direction of ray
33 | float2 offset = 0; // offset vector to be added to origin
34 | // find radius that leaves current object
35 | while (1) {
36 | offset += dir;
37 | const int offset_value = read_imageui(src,sampler,round(origin+offset)).x;
38 | if (offset_value != value) {
39 | // small correction as we overshoot the boundary
40 | const float t_corr = .5f/fmax(fabs(dir.x),fabs(dir.y));
41 | offset += (t_corr-1.f)*dir;
42 | 
43 | const float dist = sqrt(offset.x*offset.x + offset.y*offset.y);
44 | dst[k + i*N_RAYS + j*N_RAYS*Nx] = dist;
45 | break;
46 | }
47 | }
48 | }
49 | }
50 | 
51 | }
52 | -------------------------------------------------------------------------------- /stardist_pkg/kernels/stardist3d.cl: --------------------------------------------------------------------------------
1 | #ifndef M_PI
2 | #define M_PI 3.141592653589793
3 | #endif
4 | 
5 | __constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
6 | 
7 | inline int round_to_int(float r) {
8 | return (int)rint(r);
9 | }
10 | 
11 | 
12 | __kernel void stardist3d(read_only image3d_t lbl, __constant float * rays, __global float* dist, const int grid_z, const int grid_y, const int grid_x) {
13 | 
14 | const int i = get_global_id(0);
15 | const int j = get_global_id(1);
16 | const int k = get_global_id(2);
17 | 
18 | const int Nx = get_global_size(0);
19 | const int Ny = get_global_size(1);
20 | const int Nz = get_global_size(2);
21 | 
22 | const float4 grid = (float4)(grid_x, grid_y, grid_z, 1);
23 | const float4 origin = (float4)(i,j,k,0) * grid;
24 | const int value = read_imageui(lbl,sampler,origin).x;
25 | 
26 | if (value == 0) {
27 | // background pixel -> nothing to do, write all zeros
28 | for (int m = 0; m < N_RAYS; m++) {
29 | dist[m + i*N_RAYS + j*N_RAYS*Nx+k*N_RAYS*Nx*Ny] = 0;
30 | }
31 | 
32 | }
33 | else {
34 | for (int m = 0; m < N_RAYS; m++) {
35 | 
36 | const float4 dx = (float4)(rays[3*m+2],rays[3*m+1],rays[3*m],0);
37 | // if ((i==Nx/2)&&(j==Ny/2)&(k==Nz/2)){
38 | // printf("kernel: %.2f %.2f %.2f \n",dx.x,dx.y,dx.z);
39 | // }
40 | float4 x = (float4)(0,0,0,0);
41 | 
42 | // move along ray
43 | while (1) {
44 | x += dx;
45 | // if ((i==10)&&(j==10)&(k==10)){
46 | // printf("kernel run: %.2f %.2f %.2f value %d \n",x.x,x.y,x.z, read_imageui(lbl,sampler,origin+x).x);
47 | // }
48 | 
49 | // to make it equivalent to the cpp version...
50 | const float4 x_int = (float4)(round_to_int(x.x),
51 | round_to_int(x.y),
52 | round_to_int(x.z),
53 | 0);
54 | 
55 | if (value != read_imageui(lbl,sampler,origin+x_int).x){
56 | 
57 | dist[m + i*N_RAYS + j*N_RAYS*Nx+k*N_RAYS*Nx*Ny] = length(x_int);
58 | break;
59 | }
60 | }
61 | }
62 | }
63 | }
64 | -------------------------------------------------------------------------------- /stardist_pkg/models/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import, print_function
2 | 
3 | from .model2d import Config2D, StarDist2D, StarDistData2D
4 | 
5 | from csbdeep.utils import backend_channels_last
6 | from csbdeep.utils.tf import keras_import
7 | K = keras_import('backend')
8 | if not backend_channels_last():
9 | raise NotImplementedError(
10 | "Keras is configured to use the '%s' image data format, which is currently not supported. "
11 | "Please change it to use 'channels_last' instead: "
12 | "https://keras.io/getting-started/faq/#where-is-the-keras-configuration-file-stored" % K.image_data_format()
13 | )
14 | del backend_channels_last, K
15 | 
16 | from csbdeep.models import register_model, register_aliases, clear_models_and_aliases
17 | # register pre-trained models and aliases (TODO: replace with updatable solution)
18 | clear_models_and_aliases(StarDist2D)  # only the 2D model ships with this trimmed package
19 | register_model(StarDist2D, '2D_versatile_fluo', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_fluo.zip', '8db40dacb5a1311b8d2c447ad934fb8a')
20 | register_model(StarDist2D, '2D_versatile_he', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_he.zip', 'bf34cb3c0e5b3435971e18d66778a4ec')
21 | register_model(StarDist2D, '2D_paper_dsb2018', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_paper_dsb2018.zip', '6287bf283f85c058ec3e7094b41039b5')
22 | register_model(StarDist2D, '2D_demo', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_demo.zip', '31f70402f58c50dd231ec31b4375ea2c')
23 | 
24 | register_aliases(StarDist2D, '2D_paper_dsb2018', 'DSB 2018 (from StarDist 2D paper)')
25 | register_aliases(StarDist2D, '2D_versatile_fluo', 'Versatile (fluorescent nuclei)')
26 | register_aliases(StarDist2D, '2D_versatile_he', 'Versatile (H&E nuclei)')
27 | del register_model, register_aliases, clear_models_and_aliases
28 | -------------------------------------------------------------------------------- /stardist_pkg/nms.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function, unicode_literals, absolute_import, division
2 | import numpy as np
3 | from time import time
4 | from .utils import _normalize_grid
5 | 
6 | def _ind_prob_thresh(prob, prob_thresh, b=2):
7 | if b is not None and np.isscalar(b):
8 | b = ((b,b),)*prob.ndim
9 | 
10 | ind_thresh = prob > prob_thresh
11 | if b is not None:
12 | _ind_thresh = np.zeros_like(ind_thresh)
13 | ss = tuple(slice(_bs[0] if _bs[0]>0 else None,
14 | -_bs[1] if _bs[1]>0 else None) for _bs in b)
15 | _ind_thresh[ss] = True
16 | ind_thresh &= _ind_thresh
17 | return ind_thresh
18 | 
19 | 
20 | def _non_maximum_suppression_old(coord, prob, grid=(1,1), b=2, nms_thresh=0.5, prob_thresh=0.5, verbose=False, max_bbox_search=True):
21 | """2D coordinates of the polys that survive from a given prediction (prob, coord)
22 | 
23 | prob.shape = (Ny,Nx)
24 | coord.shape = (Ny,Nx,2,n_rays)
25 | 
26 | b: don't use pixels closer than b pixels to the image boundary
27 | 
28 | returns retained points
29 | """
30 | from stardist.lib.stardist2d import c_non_max_suppression_inds_old
31 | 
32 | # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary
33 | 
34 | assert prob.ndim == 2
35 | assert coord.ndim == 4
36 | grid = _normalize_grid(grid,2)
37 | 
38 | # mask = prob > prob_thresh
39 | # if b is not None and b > 0:
40 | # _mask = np.zeros_like(mask)
41 | # _mask[b:-b,b:-b] = True
42 | # mask &= _mask
43 | 
44 | mask = _ind_prob_thresh(prob, prob_thresh, b)
45 | 
46 | polygons = coord[mask]
47 | scores = prob[mask]
48 | 
49 | # sort scores in descending order
50 | ind = np.argsort(scores)[::-1]
51 | survivors = np.zeros(len(ind), bool)
52 | polygons = polygons[ind]
53 | scores = scores[ind]
54 | 
55 | if max_bbox_search:
56 | # map pixel indices to ids of sorted polygons (-1 => polygon at that pixel not a candidate)
57 | mapping = 
-np.ones(mask.shape,np.int32)
58 | mapping.flat[ np.flatnonzero(mask)[ind] ] = range(len(ind))
59 | else:
60 | mapping = np.empty((0,0),np.int32)
61 | 
62 | if verbose:
63 | t = time()
64 | 
65 | survivors[ind] = c_non_max_suppression_inds_old(np.ascontiguousarray(polygons.astype(np.int32)),
66 | mapping, np.float32(nms_thresh), np.int32(max_bbox_search),
67 | np.int32(grid[0]), np.int32(grid[1]),np.int32(verbose))
68 | 
69 | if verbose:
70 | print("keeping %s/%s polygons" % (np.count_nonzero(survivors), len(polygons)))
71 | print("NMS took %.4f s" % (time() - t))
72 | 
73 | points = np.stack([ii[survivors] for ii in np.nonzero(mask)],axis=-1)
74 | return points
75 | 
76 | 
77 | def non_maximum_suppression(dist, prob, grid=(1,1), b=2, nms_thresh=0.5, prob_thresh=0.5,
78 | use_bbox=True, use_kdtree=True, verbose=False,cut=False):
79 | """Non-Maximum-Suppression of 2D polygons
80 | 
81 | Retains only polygons whose overlap is smaller than nms_thresh
82 | 
83 | dist.shape = (Ny,Nx, n_rays)
84 | prob.shape = (Ny,Nx)
85 | 
86 | returns the retained points, probabilities, and distances:
87 | 
88 | points, prob, dist = non_maximum_suppression(dist, prob, ....
89 | 
90 | """
91 | 
92 | # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary
93 | 
94 | assert prob.ndim == 2 and dist.ndim == 3 and prob.shape == dist.shape[:2]
95 | dist = np.asarray(dist)
96 | prob = np.asarray(prob)
97 | n_rays = dist.shape[-1]
98 | 
99 | grid = _normalize_grid(grid,2)
100 | 
101 | # mask = prob > prob_thresh
102 | # if b is not None and b > 0:
103 | # _mask = np.zeros_like(mask)
104 | # _mask[b:-b,b:-b] = True
105 | # mask &= _mask
106 | 
107 | mask = _ind_prob_thresh(prob, prob_thresh, b)
108 | points = np.stack(np.where(mask), axis=1)
109 | 
110 | dist = dist[mask]
111 | scores = prob[mask]
112 | 
113 | # sort scores in descending order
114 | ind = np.argsort(scores)[::-1]
115 | if cut is True and ind.shape[0] > 20000:
116 | # when cut is enabled, keep only the top-scoring half of very large candidate sets
117 | ind = ind[:round(ind.shape[0]*0.5)]
118 | dist = dist[ind]
119 | scores = scores[ind]
120 | points = points[ind]
121 | 
122 | points = (points * np.array(grid).reshape((1,2)))
123 | 
124 | if verbose:
125 | t = time()
126 | 
127 | inds = non_maximum_suppression_inds(dist, points.astype(np.int32, copy=False), scores=scores,
128 | use_bbox=use_bbox, use_kdtree=use_kdtree,
129 | thresh=nms_thresh, verbose=verbose)
130 | 
131 | if verbose:
132 | print("keeping %s/%s polygons" % (np.count_nonzero(inds), len(inds)))
133 | print("NMS took %.4f s" % (time() - t))
134 | 
135 | return points[inds], scores[inds], dist[inds]
136 | 
137 | 
138 | def non_maximum_suppression_sparse(dist, prob, points, b=2, nms_thresh=0.5,
139 | use_bbox=True, use_kdtree = True, verbose=False):
140 | """Non-Maximum-Suppression of 2D polygons from a list of dists, probs (scores), and points
141 | 
142 | Retains only polygons whose overlap is smaller than nms_thresh
143 | 
144 | dist.shape = (n_polys, n_rays)
145 | prob.shape = (n_polys,)
146 | points.shape = (n_polys,2)
147 | 
148 | returns the retained instances
149 | 
150 | (pointsi, probi, disti, indsi)
151 | 
152 | with
153 | pointsi = points[indsi] ... 
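        probi = prob[indsi], disti = dist[indsi], with indsi the
        indices of the retained polygons into the original arrays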
154 | 
155 | """
156 | 
157 | # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary
158 | 
159 | dist = np.asarray(dist)
160 | prob = np.asarray(prob)
161 | points = np.asarray(points)
162 | n_rays = dist.shape[-1]
163 | 
164 | assert dist.ndim == 2 and prob.ndim == 1 and points.ndim == 2 and \
165 | points.shape[-1]==2 and len(prob) == len(dist) == len(points)
166 | 
167 | verbose and print("predicting instances with nms_thresh = {nms_thresh}".format(nms_thresh=nms_thresh), flush=True)
168 | 
169 | inds_original = np.arange(len(prob))
170 | _sorted = np.argsort(prob)[::-1]
171 | probi = prob[_sorted]
172 | disti = dist[_sorted]
173 | pointsi = points[_sorted]
174 | inds_original = inds_original[_sorted]
175 | 
176 | if verbose:
177 | print("non-maximum suppression...")
178 | t = time()
179 | 
180 | inds = non_maximum_suppression_inds(disti, pointsi, scores=probi, thresh=nms_thresh, use_kdtree = use_kdtree, verbose=verbose)
181 | 
182 | if verbose:
183 | print("keeping %s/%s polygons" % (np.count_nonzero(inds), len(inds)))
184 | print("NMS took %.4f s" % (time() - t))
185 | 
186 | return pointsi[inds], probi[inds], disti[inds], inds_original[inds]
187 | 
188 | 
189 | def non_maximum_suppression_inds(dist, points, scores, thresh=0.5, use_bbox=True, use_kdtree = True, verbose=1):
190 | """
191 | Applies non-maximum suppression to ray-convex polygons given by dists and points,
192 | sorted by scores, using an IoU threshold
193 | 
194 | P1 will suppress P2, if IoU(P1,P2) > thresh
195 | 
196 | with IoU(P1,P2) = Ainter(P1,P2) / min(A(P1),A(P2))
197 | 
198 | i.e. the smaller thresh, the more polygons will be suppressed
199 | 
200 | dist.shape = (n_poly, n_rays)
201 | point.shape = (n_poly, 2)
202 | score.shape = (n_poly,)
203 | 
204 | returns indices of selected polygons
205 | """
206 | 
207 | from stardist.lib.stardist2d import c_non_max_suppression_inds
208 | 
209 | assert dist.ndim == 2
210 | assert points.ndim == 2
211 | 
212 | n_poly = dist.shape[0]
213 | 
214 | if scores is None:
215 | scores = np.ones(n_poly)
216 | 
217 | assert len(scores) == n_poly
218 | assert points.shape[0] == n_poly
219 | 
220 | def _prep(x, dtype):
221 | return np.ascontiguousarray(x.astype(dtype, copy=False))
222 | 
223 | inds = c_non_max_suppression_inds(_prep(dist, np.float32),
224 | _prep(points, np.float32),
225 | int(use_kdtree),
226 | int(use_bbox),
227 | int(verbose),
228 | np.float32(thresh))
229 | 
230 | return inds
231 | 
232 | 
233 | #########
234 | 
235 | 
236 | def non_maximum_suppression_3d(dist, prob, rays, grid=(1,1,1), b=2, nms_thresh=0.5, prob_thresh=0.5, use_bbox=True, use_kdtree=True, verbose=False):
237 | """Non-Maximum-Suppression of 3D polyhedra
238 | 
239 | Retains only polyhedra whose overlap is smaller than nms_thresh
240 | 
241 | dist.shape = (Nz,Ny,Nx, n_rays)
242 | prob.shape = (Nz,Ny,Nx)
243 | 
244 | returns the retained points, probabilities, and distances:
245 | 
246 | points, prob, dist = non_maximum_suppression_3d(dist, prob, ....
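    where `rays` is a rays object (see rays3d.py) that fixes the ray directions;
    requires dist.shape[-1] == len(rays)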
247 | """
248 | 
249 | # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary
250 | 
251 | dist = np.asarray(dist)
252 | prob = np.asarray(prob)
253 | 
254 | assert prob.ndim == 3 and dist.ndim == 4 and dist.shape[-1] == len(rays) and prob.shape == dist.shape[:3]
255 | 
256 | grid = _normalize_grid(grid,3)
257 | 
258 | verbose and print("predicting instances with prob_thresh = {prob_thresh} and nms_thresh = {nms_thresh}".format(prob_thresh=prob_thresh, nms_thresh=nms_thresh), flush=True)
259 | 
260 | # ind_thresh = prob > prob_thresh
261 | # if b is not None and b > 0:
262 | # _ind_thresh = np.zeros_like(ind_thresh)
263 | # _ind_thresh[b:-b,b:-b,b:-b] = True
264 | # ind_thresh &= _ind_thresh
265 | 
266 | ind_thresh = _ind_prob_thresh(prob, prob_thresh, b)
267 | points = np.stack(np.where(ind_thresh), axis=1)
268 | verbose and print("found %s candidates"%len(points))
269 | probi = prob[ind_thresh]
270 | disti = dist[ind_thresh]
271 | 
272 | _sorted = np.argsort(probi)[::-1]
273 | probi = probi[_sorted]
274 | disti = disti[_sorted]
275 | points = points[_sorted]
276 | 
277 | verbose and print("non-maximum suppression...")
278 | points = (points * np.array(grid).reshape((1,3)))
279 | 
280 | inds = non_maximum_suppression_3d_inds(disti, points, rays=rays, scores=probi, thresh=nms_thresh,
281 | use_bbox=use_bbox, use_kdtree = use_kdtree,
282 | verbose=verbose)
283 | 
284 | verbose and print("keeping %s/%s polyhedra" % (np.count_nonzero(inds), len(inds)))
285 | return points[inds], probi[inds], disti[inds]
286 | 
287 | 
288 | def non_maximum_suppression_3d_sparse(dist, prob, points, rays, b=2, nms_thresh=0.5, use_kdtree = True, verbose=False):
289 | """Non-Maximum-Suppression of 3D polyhedra from a list of dists, probs and points
290 | 
291 | Retains only polyhedra whose overlap is smaller than nms_thresh
292 | dist.shape = (n_polys, n_rays)
293 | prob.shape = (n_polys,)
294 | points.shape = (n_polys,3)
295 | 
296 | returns the retained instances
297 | 
298 | (pointsi, probi, disti, indsi)
299 | 
300 | with
301 | pointsi = points[indsi] ... 
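        probi = prob[indsi], disti = dist[indsi], with indsi the
        indices of the retained polyhedra into the original arrays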
302 | """
303 | 
304 | # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary
305 | 
306 | dist = np.asarray(dist)
307 | prob = np.asarray(prob)
308 | points = np.asarray(points)
309 | 
310 | assert dist.ndim == 2 and prob.ndim == 1 and points.ndim == 2 and \
311 | dist.shape[-1] == len(rays) and points.shape[-1]==3 and len(prob) == len(dist) == len(points)
312 | 
313 | verbose and print("predicting instances with nms_thresh = {nms_thresh}".format(nms_thresh=nms_thresh), flush=True)
314 | 
315 | inds_original = np.arange(len(prob))
316 | _sorted = np.argsort(prob)[::-1]
317 | probi = prob[_sorted]
318 | disti = dist[_sorted]
319 | pointsi = points[_sorted]
320 | inds_original = inds_original[_sorted]
321 | 
322 | verbose and print("non-maximum suppression...")
323 | 
324 | inds = non_maximum_suppression_3d_inds(disti, pointsi, rays=rays, scores=probi, thresh=nms_thresh, use_kdtree = use_kdtree, verbose=verbose)
325 | 
326 | verbose and print("keeping %s/%s polyhedra" % (np.count_nonzero(inds), len(inds)))
327 | return pointsi[inds], probi[inds], disti[inds], inds_original[inds]
328 | 
329 | 
330 | def non_maximum_suppression_3d_inds(dist, points, rays, scores, thresh=0.5, use_bbox=True, use_kdtree = True, verbose=1):
331 | """
332 | Applies non-maximum suppression to ray-convex polyhedra given by dists and rays,
333 | sorted by scores, using an IoU threshold
334 | 
335 | P1 will suppress P2, if IoU(P1,P2) > thresh
336 | 
337 | with IoU(P1,P2) = Ainter(P1,P2) / min(A(P1),A(P2))
338 | 
339 | i.e. the smaller thresh, the more polyhedra will be suppressed
340 | 
341 | dist.shape = (n_poly, n_rays)
342 | point.shape = (n_poly, 3)
343 | score.shape = (n_poly,)
344 | 
345 | returns indices of selected polyhedra
346 | """
347 | from stardist.lib.stardist3d import c_non_max_suppression_inds
348 | 
349 | assert dist.ndim == 2
350 | assert points.ndim == 2
351 | assert dist.shape[1] == len(rays)
352 | 
353 | n_poly = dist.shape[0]
354 | 
355 | if scores is None:
356 | scores = np.ones(n_poly)
357 | 
358 | assert len(scores) == n_poly
359 | assert points.shape[0] == n_poly
360 | 
361 | # sort scores in descending order
362 | ind = np.argsort(scores)[::-1]
363 | survivors = np.ones(n_poly, bool)
364 | dist = dist[ind]
365 | points = points[ind]
366 | scores = scores[ind]
367 | 
368 | def _prep(x, dtype):
369 | return np.ascontiguousarray(x.astype(dtype, copy=False))
370 | 
371 | if verbose:
372 | t = time()
373 | 
374 | survivors[ind] = c_non_max_suppression_inds(_prep(dist, np.float32),
375 | _prep(points, np.float32),
376 | _prep(rays.vertices, np.float32),
377 | _prep(rays.faces, np.int32),
378 | _prep(scores, np.float32),
379 | int(use_bbox),
380 | int(use_kdtree),
381 | int(verbose),
382 | np.float32(thresh))
383 | 
384 | if verbose:
385 | print("NMS took %.4f s" % (time() - t))
386 | 
387 | return survivors
388 | -------------------------------------------------------------------------------- /stardist_pkg/rays3d.py: --------------------------------------------------------------------------------
1 | """
2 | Ray factory
3 | 
4 | classes that provide vertex and triangle information for rays on spheres
5 | 
6 | Example:
7 | 
8 | rays = Rays_Tetra(n_level = 4)
9 | 
10 | print(rays.vertices)
11 | print(rays.faces)
12 | 
13 | """
14 | from __future__ import print_function, unicode_literals, absolute_import, division
15 | import numpy as np
16 | from scipy.spatial import ConvexHull
17 | import copy
18 | import warnings
19 | 
20 | class Rays_Base(object):
21 | def __init__(self, **kwargs):
22 | self.kwargs = kwargs
23 | 
self._vertices, self._faces = self.setup_vertices_faces()
24 | self._vertices = np.asarray(self._vertices, np.float32)
25 | self._faces = np.asarray(self._faces, int)
26 | self._faces = np.asanyarray(self._faces)
27 | 
28 | def setup_vertices_faces(self):
29 | """has to return
30 | 
31 | verts , faces
32 | 
33 | verts = ( (z_1,y_1,x_1), ... )
34 | faces = ( (0,1,2), (2,3,4), ... )
35 | 
36 | """
37 | raise NotImplementedError()
38 | 
39 | @property
40 | def vertices(self):
41 | """read-only property"""
42 | return self._vertices.copy()
43 | 
44 | @property
45 | def faces(self):
46 | """read-only property"""
47 | return self._faces.copy()
48 | 
49 | def __getitem__(self, i):
50 | return self.vertices[i]
51 | 
52 | def __len__(self):
53 | return len(self._vertices)
54 | 
55 | def __repr__(self):
56 | def _conv(x):
57 | if isinstance(x,(tuple, list, np.ndarray)):
58 | return "_".join(_conv(_x) for _x in x)
59 | if isinstance(x,float):
60 | return "%.2f"%x
61 | return str(x)
62 | return "%s_%s" % (self.__class__.__name__, "_".join("%s_%s" % (k, _conv(v)) for k, v in sorted(self.kwargs.items())))
63 | 
64 | def to_json(self):
65 | return {
66 | "name": self.__class__.__name__,
67 | "kwargs": self.kwargs
68 | }
69 | 
70 | def dist_loss_weights(self, anisotropy = (1,1,1)):
71 | """returns the anisotropy corrected weights for each ray"""
72 | anisotropy = np.array(anisotropy)
73 | assert anisotropy.shape == (3,)
74 | return np.linalg.norm(self.vertices*anisotropy, axis = -1)
75 | 
76 | def volume(self, dist=None):
77 | """volume of the starconvex polyhedron spanned by dist (if None, uses dist=1)
78 | dist can be a nD array, but the last dimension has to be of length n_rays
79 | """
80 | if dist is None: dist = np.ones(len(self.vertices)) # one unit distance per ray
81 | 
82 | dist = np.asarray(dist)
83 | 
84 | if not dist.shape[-1]==len(self.vertices):
85 | raise ValueError("last dimension of dist should have length len(rays.vertices)")
86 | # all the shuffling below is to allow dist to be an arbitrary sized array (with last dim n_rays)
87 | # self.vertices -> (n_rays,3)
88 | # dist -> (m,n,..., n_rays)
89 | 
90 | # dist -> (m,n,..., n_rays, 3)
91 | dist = np.repeat(np.expand_dims(dist,-1), 3, axis = -1)
92 | # verts -> (m,n,..., n_rays, 3)
93 | verts = np.broadcast_to(self.vertices, dist.shape)
94 | 
95 | # dist, verts -> (n_rays, m,n, ..., 3)
96 | dist = np.moveaxis(dist,-2,0)
97 | verts = np.moveaxis(verts,-2,0)
98 | 
99 | # vs -> (n_faces, 3, m, n, ..., 3)
100 | vs = (dist*verts)[self.faces]
101 | # vs -> (n_faces, m, n, ..., 3, 3)
102 | vs = np.moveaxis(vs, 1,-2)
103 | # vs -> (n_faces * m * n, 3, 3)
104 | vs = vs.reshape((len(self.faces)*int(np.prod(dist.shape[1:-1])),3,3))
105 | d = np.linalg.det(list(vs)).reshape((len(self.faces),)+dist.shape[1:-1])
106 | 
107 | return -1./6*np.sum(d, axis = 0)
108 | 
109 | def surface(self, dist=None):
110 | """surface area of the starconvex polyhedron spanned by dist (if None, uses dist=1)"""
111 | dist = np.ones(len(self.vertices)) if dist is None else np.asarray(dist)
112 | 
113 | if not dist.shape[-1]==len(self.vertices):
114 | raise ValueError("last dimension of dist should have length len(rays.vertices)")
115 | 
116 | # self.vertices -> (n_rays,3)
117 | # dist -> (m,n,..., n_rays)
118 | 
119 | # all the shuffling below is to allow dist to be an arbitrary sized array (with last dim n_rays)
120 | 
121 | # dist -> (m,n,..., n_rays, 3)
122 | dist = np.repeat(np.expand_dims(dist,-1), 3, axis = -1)
123 | # verts -> (m,n,..., n_rays, 3)
124 | verts = np.broadcast_to(self.vertices, dist.shape)
125 | 
126 | # dist, verts -> (n_rays, m,n, ..., 3)
127 | dist = 
np.moveaxis(dist,-2,0)
128 | verts = np.moveaxis(verts,-2,0)
129 | 
130 | # vs -> (n_faces, 3, m, n, ..., 3)
131 | vs = (dist*verts)[self.faces]
132 | # vs -> (n_faces, m, n, ..., 3, 3)
133 | vs = np.moveaxis(vs, 1,-2)
134 | # vs -> (n_faces * m * n, 3, 3)
135 | vs = vs.reshape((len(self.faces)*int(np.prod(dist.shape[1:-1])),3,3))
136 | 
137 | pa = vs[...,1,:]-vs[...,0,:]
138 | pb = vs[...,2,:]-vs[...,0,:]
139 | 
140 | d = .5*np.linalg.norm(np.cross(list(pa), list(pb)), axis = -1)
141 | d = d.reshape((len(self.faces),)+dist.shape[1:-1])
142 | return np.sum(d, axis = 0)
143 | 
144 | 
145 | def copy(self, scale=(1,1,1)):
146 | """ returns a copy whose vertices are scaled by given factor"""
147 | scale = np.asarray(scale)
148 | assert scale.shape == (3,)
149 | res = copy.deepcopy(self)
150 | res._vertices *= scale[np.newaxis]
151 | return res
152 | 
153 | 
154 | 
155 | 
156 | def rays_from_json(d):
157 | return eval(d["name"])(**d["kwargs"])
158 | 
159 | 
160 | ################################################################
161 | 
162 | class Rays_Explicit(Rays_Base):
163 | def __init__(self, vertices0, faces0):
164 | self.vertices0, self.faces0 = vertices0, faces0
165 | super().__init__(vertices0=list(vertices0), faces0=list(faces0))
166 | 
167 | def setup_vertices_faces(self):
168 | return self.vertices0, self.faces0
169 | 
170 | 
171 | class Rays_Cartesian(Rays_Base):
172 | def __init__(self, n_rays_x=11, n_rays_z=5):
173 | super().__init__(n_rays_x=n_rays_x, n_rays_z=n_rays_z)
174 | 
175 | def setup_vertices_faces(self):
176 | """has to return list of ( (z_1,y_1,x_1), ... )"""
177 | n_rays_x, n_rays_z = self.kwargs["n_rays_x"], self.kwargs["n_rays_z"]
178 | dphi = np.float32(2. * np.pi / n_rays_x)
179 | dtheta = np.float32(np.pi / n_rays_z)
180 | 
181 | verts = []
182 | for mz in range(n_rays_z):
183 | for mx in range(n_rays_x):
184 | phi = mx * dphi
185 | theta = mz * dtheta
186 | if mz == 0:
187 | theta = 1e-12
188 | if mz == n_rays_z - 1:
189 | theta = np.pi - 1e-12
190 | dx = np.cos(phi) * np.sin(theta)
191 | dy = np.sin(phi) * np.sin(theta)
192 | dz = np.cos(theta)
193 | if mz == 0 or mz == n_rays_z - 1:
194 | dx += 1e-12
195 | dy += 1e-12
196 | verts.append([dz, dy, dx])
197 | 
198 | verts = np.array(verts)
199 | 
200 | def _ind(mz, mx):
201 | return mz * n_rays_x + mx
202 | 
203 | faces = []
204 | 
205 | for mz in range(n_rays_z - 1):
206 | for mx in range(n_rays_x):
207 | faces.append([_ind(mz, mx), _ind(mz + 1, (mx + 1) % n_rays_x), _ind(mz, (mx + 1) % n_rays_x)])
208 | faces.append([_ind(mz, mx), _ind(mz + 1, mx), _ind(mz + 1, (mx + 1) % n_rays_x)])
209 | 
210 | faces = np.array(faces)
211 | 
212 | return verts, faces
213 | 
214 | 
215 | class Rays_SubDivide(Rays_Base):
216 | """
217 | Subdivision polyhedra
218 | 
219 | n_level = 1 -> base polyhedron
220 | n_level = 2 -> 1x subdivision
221 | n_level = 3 -> 2x subdivision
222 | ... 
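    (each level splits every triangular face into four, so the number of
    faces grows by a factor of 4 per subdivision)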
223 | """
224 | 
225 | def __init__(self, n_level=4):
226 | super().__init__(n_level=n_level)
227 | 
228 | def base_polyhedron(self):
229 | raise NotImplementedError()
230 | 
231 | def setup_vertices_faces(self):
232 | n_level = self.kwargs["n_level"]
233 | verts0, faces0 = self.base_polyhedron()
234 | return self._recursive_split(verts0, faces0, n_level)
235 | 
236 | def _recursive_split(self, verts, faces, n_level):
237 | if n_level <= 1:
238 | return verts, faces
239 | else:
240 | verts, faces = Rays_SubDivide.split(verts, faces)
241 | return self._recursive_split(verts, faces, n_level - 1)
242 | 
243 | @classmethod
244 | def split(cls, verts0, faces0):
245 | """split a level"""
246 | 
247 | split_edges = dict()
248 | verts = list(verts0[:])
249 | faces = []
250 | 
251 | def _add(a, b):
252 | """ returns index of middle point and adds vertex if not already added"""
253 | edge = tuple(sorted((a, b)))
254 | if edge not in split_edges:
255 | v = .5 * (verts[a] + verts[b])
256 | v *= 1. / np.linalg.norm(v)
257 | verts.append(v)
258 | split_edges[edge] = len(verts) - 1
259 | return split_edges[edge]
260 | 
261 | for v1, v2, v3 in faces0:
262 | ind1 = _add(v1, v2)
263 | ind2 = _add(v2, v3)
264 | ind3 = _add(v3, v1)
265 | faces.append([v1, ind1, ind3])
266 | faces.append([v2, ind2, ind1])
267 | faces.append([v3, ind3, ind2])
268 | faces.append([ind1, ind2, ind3])
269 | 
270 | return verts, faces
271 | 
272 | 
273 | class Rays_Tetra(Rays_SubDivide):
274 | """
275 | Subdivision of a tetrahedron
276 | 
277 | n_level = 1 -> normal tetrahedron (4 vertices)
278 | n_level = 2 -> 1x subdivision (10 vertices)
279 | n_level = 3 -> 2x subdivision (34 vertices)
280 | ...
281 | """
282 | 
283 | def base_polyhedron(self):
284 | verts = np.array([
285 | [np.sqrt(8. / 9), 0., -1. / 3],
286 | [-np.sqrt(2. / 9), np.sqrt(2. / 3), -1. / 3],
287 | [-np.sqrt(2. / 9), -np.sqrt(2. / 3), -1. / 3],
288 | [0., 0., 1.]
289 | ])
290 | faces = [[0, 1, 2],
291 | [0, 3, 1],
292 | [0, 2, 3],
293 | [1, 3, 2]]
294 | 
295 | return verts, faces
296 | 
297 | 
298 | class Rays_Octo(Rays_SubDivide):
299 | """
300 | Subdivision of an octahedron
301 | 
302 | n_level = 1 -> normal Octahedron (6 vertices)
303 | n_level = 2 -> 1x subdivision (18 vertices)
304 | n_level = 3 -> 2x subdivision (66 vertices)
305 | 
306 | """
307 | 
308 | def base_polyhedron(self):
309 | verts = np.array([
310 | [0, 0, 1],
311 | [0, 1, 0],
312 | [0, 0, -1],
313 | [0, -1, 0],
314 | [1, 0, 0],
315 | [-1, 0, 0]])
316 | 
317 | faces = [[0, 1, 4],
318 | [0, 5, 1],
319 | [1, 2, 4],
320 | [1, 5, 2],
321 | [2, 3, 4],
322 | [2, 5, 3],
323 | [3, 0, 4],
324 | [3, 5, 0],
325 | ]
326 | 
327 | return verts, faces
328 | 
329 | 
330 | def reorder_faces(verts, faces):
331 | """reorder faces such that their orientation points outward"""
332 | def _single(face):
333 | return face[::-1] if np.linalg.det(verts[face])>0 else face
334 | return tuple(map(_single, faces))
335 | 
336 | 
337 | class Rays_GoldenSpiral(Rays_Base):
338 | def __init__(self, n=70, anisotropy = None):
339 | if n<4:
340 | raise ValueError("At least 4 points have to be given!")
341 | super().__init__(n=n, anisotropy = anisotropy if anisotropy is None else tuple(anisotropy))
342 | 
343 | def setup_vertices_faces(self):
344 | n = self.kwargs["n"]
345 | anisotropy = self.kwargs["anisotropy"]
346 | if anisotropy is None:
347 | anisotropy = np.ones(3)
348 | else:
349 | anisotropy = np.array(anisotropy)
350 | 
351 | # the smaller golden angle = 2pi * 0.3819...
352 | g = (3. 
- np.sqrt(5.)) * np.pi 353 | phi = g * np.arange(n) 354 | # z = np.linspace(-1, 1, n + 2)[1:-1] 355 | # rho = np.sqrt(1. - z ** 2) 356 | # verts = np.stack([rho*np.cos(phi), rho*np.sin(phi),z]).T 357 | # 358 | z = np.linspace(-1, 1, n) 359 | rho = np.sqrt(1. - z ** 2) 360 | verts = np.stack([z, rho * np.sin(phi), rho * np.cos(phi)]).T 361 | 362 | # warnings.warn("ray definition has changed! Old results are invalid!") 363 | 364 | # correct for anisotropy 365 | verts = verts/anisotropy 366 | #verts /= np.linalg.norm(verts, axis=-1, keepdims=True) 367 | 368 | hull = ConvexHull(verts) 369 | faces = reorder_faces(verts,hull.simplices) 370 | 371 | verts /= np.linalg.norm(verts, axis=-1, keepdims=True) 372 | 373 | return verts, faces 374 | -------------------------------------------------------------------------------- /stardist_pkg/sample_patches.py: -------------------------------------------------------------------------------- 1 | """provides a faster sampling function""" 2 | 3 | import numpy as np 4 | from csbdeep.utils import _raise, choice 5 | 6 | 7 | def sample_patches(datas, patch_size, n_samples, valid_inds=None, verbose=False): 8 | """optimized version of csbdeep.data.sample_patches_from_multiple_stacks 9 | """ 10 | 11 | len(patch_size)==datas[0].ndim or _raise(ValueError()) 12 | 13 | if not all(( a.shape == datas[0].shape for a in datas )): 14 | raise ValueError("all input shapes must be the same: %s" % (" / ".join(str(a.shape) for a in datas))) 15 | 16 | if not all(( 0 < s <= d for s,d in zip(patch_size,datas[0].shape) )): 17 | raise ValueError("patch_size %s negative or larger than data shape %s along some dimensions" % (str(patch_size), str(datas[0].shape))) 18 | 19 | if valid_inds is None: 20 | valid_inds = tuple(_s.ravel() for _s in np.meshgrid(*tuple(np.arange(p//2,s-p//2+1) for s,p in zip(datas[0].shape, patch_size)))) 21 | 22 | n_valid = len(valid_inds[0]) 23 | 24 | if n_valid == 0: 25 | raise ValueError("no regions to sample from!") 26 | 27 | idx = choice(range(n_valid), n_samples, replace=(n_valid < n_samples)) 28 | rand_inds = [v[idx] for v in valid_inds] 29 | res = [np.stack([data[tuple(slice(_r-(_p//2),_r+_p-(_p//2)) for _r,_p in zip(r,patch_size))] for r in zip(*rand_inds)]) for data in datas] 30 | 31 | return res 32 | 33 | 34 | def get_valid_inds(img, patch_size, patch_filter=None): 35 | """ 36 | Returns all indices of an image that 37 | - can be used as center points for sampling patches of a given patch_size, and 38 | - are part of the boolean mask given by the function patch_filter (if provided) 39 | 40 | img: np.ndarray 41 | patch_size: tuple of ints 42 | the width of patches per img dimension, 43 | patch_filter: None or callable 44 | a function with signature patch_filter(img, patch_size) returning a boolean mask 45 | """ 46 | 47 | len(patch_size)==img.ndim or _raise(ValueError()) 48 | 49 | if not all(( 0 < s <= d for s,d in zip(patch_size,img.shape))): 50 | raise ValueError("patch_size %s negative or larger than image shape %s along some dimensions" % (str(patch_size), str(img.shape))) 51 | 52 | if patch_filter is None: 53 | # only cut border indices (which is faster) 54 | patch_mask = np.ones(img.shape,dtype=bool) 55 | valid_inds = tuple(np.arange(p // 2, s - p + p // 2 + 1).astype(np.uint32) for p, s in zip(patch_size, img.shape)) 56 | valid_inds = tuple(s.ravel() for s in np.meshgrid(*valid_inds, indexing='ij')) 57 | else: 58 | patch_mask = patch_filter(img, patch_size) 59 | 60 | # get the valid indices 61 | border_slices = tuple([slice(p // 2, s - p + p // 2 + 
1) for p, s in zip(patch_size, img.shape)]) 62 | valid_inds = np.where(patch_mask[border_slices]) 63 | valid_inds = tuple((v + s.start).astype(np.uint32) for s, v in zip(border_slices, valid_inds)) 64 | 65 | return valid_inds 66 | -------------------------------------------------------------------------------- /stardist_pkg/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.8.3' 2 | --------------------------------------------------------------------------------
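Usage note (added; not part of the repository): the snippet below is a minimal sketch of how the vendored `stardist_pkg` pieces above fit together on a toy label image. It assumes the pinned dependencies from requirements.txt are installed (notably numpy, scikit-image, and stardist==0.8.3, whose compiled `stardist.lib` extension geom2d.py and nms.py import) and that it is run from the repository root so that `stardist_pkg` is importable.

```python
import numpy as np
from stardist_pkg import star_dist, relabel_image_stardist

# toy label image: one square object with id 1 on background 0
lbl = np.zeros((64, 64), np.uint16)
lbl[16:48, 16:48] = 1

# per-pixel radial distances along 32 rays (pure-Python reference path)
dist = star_dist(lbl, n_rays=32, mode='python')
print(dist.shape)  # -> (64, 64, 32)

# re-render every object from its star-convex representation
relabelled = relabel_image_stardist(lbl, n_rays=32, mode='python')
print(np.unique(relabelled))  # -> [0 1]
```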