├── __init__.py ├── data ├── .gitkeep └── README.md ├── configs ├── sweep │ ├── .gitkeep │ ├── num_clusters.yaml │ ├── aff.yaml │ ├── README.md │ └── defaults.yaml ├── wandb │ ├── .gitkeep │ └── README.md ├── dataset │ ├── .gitkeep │ └── README.md ├── hydra │ └── defaults.yaml ├── model │ └── dino_vits8.yaml ├── loader │ └── defaults.yaml ├── eval │ └── defaults.yaml ├── vis │ ├── allTrue.yaml │ ├── selected.yaml │ └── defaults.yaml ├── precomputed │ ├── carotid_mini_vis_polyaxon.yaml │ ├── defaults.yaml │ ├── carotid_mutinfo.yaml │ ├── carotid_mini_vis_local.yaml │ ├── mutinfo_train_mini.yaml │ ├── liver2_mini.yaml │ └── mutinfo_val_carotid.yaml ├── multi_region_segmentation │ └── defaults.yaml ├── pipeline_steps │ ├── allTrue.yaml │ ├── defaults.yaml │ └── allFalse.yaml ├── crf │ └── defaults.yaml ├── bbox │ └── defaults.yaml ├── spectral_clustering │ └── defaults.yaml └── defaults.yaml ├── evaluation ├── __init__.py ├── demo_dataset │ ├── images │ │ └── ultrasonix_test001.jpg │ ├── predictions │ │ └── ultrasonix_test001.png │ └── ground_truth │ │ └── ultrasonix_test001.png ├── dataset.py └── eval_demo.ipynb ├── self-training ├── __init__.py ├── models │ ├── __init__.py │ ├── dino.py │ ├── simclrTripletLightningModule.py │ ├── simclrLightningModule.py │ └── dinoLightningModule.py ├── tasks │ ├── __init__.py │ ├── configs │ │ ├── hydra │ │ │ └── defaults.yaml │ │ ├── experiment │ │ │ └── defaults.yaml │ │ ├── loader │ │ │ ├── defaults.yaml │ │ │ ├── patch64.yaml │ │ │ ├── patch70.yaml │ │ │ ├── patchv2.yaml │ │ │ ├── patch.yaml │ │ │ └── triplet_patch.yaml │ │ ├── dataset │ │ │ ├── carotid-mini.yaml │ │ │ ├── us_mixed.yaml │ │ │ ├── liver2_mini.yaml │ │ │ ├── liver2_medium.yaml │ │ │ ├── liver_reduced.yaml │ │ │ ├── carotid_mutinfo.yaml │ │ │ ├── imagenet-4-classes.yaml │ │ │ └── liver_similar_folders.yaml │ │ ├── train │ │ │ ├── defaults.yaml │ │ │ ├── dino_vits16.yaml │ │ │ ├── dino_vits8.yaml │ │ │ └── triplet.yaml │ │ ├── simclr.yaml │ │ ├── defaults.yaml │ │ ├── saliency_maps.yaml │ │ └── features.yaml │ └── features.py ├── custom_utils │ ├── __init__.py │ └── check_dimensions.py ├── datasets │ ├── __init__.py │ ├── augmentations.py │ ├── dataset_utils.py │ ├── samplers.py │ ├── datasets.py │ └── dino_transform.py ├── requirements.txt ├── poly.yaml ├── testpf.yaml ├── pfile_group_triplet.yaml ├── pfile_group.yaml └── train-mini-example.py ├── deep-spectral-segmentation ├── __init__.py ├── vis │ ├── __init__.py │ └── vis_utils.py ├── extract │ ├── __init__.py │ └── MutualInformation.py ├── pipeline │ ├── __init__.py │ ├── pipeline.sh │ └── extract_features.bat ├── dino2_models │ ├── __init__.py │ └── readme.md ├── semantic-segmentation │ ├── __init__.py │ ├── config │ │ ├── eval.yaml │ │ ├── base.yaml │ │ ├── train.yaml │ │ └── base_customdata.yaml │ ├── model │ │ ├── __init__.py │ │ └── model.py │ ├── eval_utils.py │ ├── README.md │ ├── dataset │ │ ├── custom_dataset.py │ │ ├── __init__.py │ │ └── voc.py │ ├── eval.py │ └── util.py └── requirements.txt ├── .gitattributes ├── req_pip.txt ├── pipeline.png ├── .gitmodules ├── req_conda.txt ├── run_pipeline.sh ├── data_preprocessing ├── denoise │ ├── readme.md │ └── denoise.sh ├── normalize.py ├── remap_labels.py ├── preprocess_data.py ├── crop_black_borders.py └── custom_normalization.py ├── sota └── SLIC │ ├── generate_slic.py │ └── generate_fz.py ├── .gitignore ├── spectral-clustering └── spectralnet_per_image.py └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/sweep/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/wandb/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/dataset/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /self-training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /self-training/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /self-training/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /self-training/custom_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /self-training/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/vis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored 2 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/extract/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/hydra/defaults.yaml: -------------------------------------------------------------------------------- 1 | job: 2 | chdir: True -------------------------------------------------------------------------------- /deep-spectral-segmentation/dino2_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/deep-spectral-segmentation/semantic-segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/model/dino_vits8.yaml: -------------------------------------------------------------------------------- 1 | name: dino_vits8 2 | checkpoint: "" -------------------------------------------------------------------------------- /req_pip.txt: -------------------------------------------------------------------------------- 1 | SimpleCRF 2 | spectralnet 3 | sewar 4 | torchsummary 5 | -------------------------------------------------------------------------------- /self-training/tasks/configs/hydra/defaults.yaml: -------------------------------------------------------------------------------- 1 | job: 2 | chdir: True -------------------------------------------------------------------------------- /configs/loader/defaults.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 1 2 | num_workers: 0 3 | mode: full -------------------------------------------------------------------------------- /self-training/tasks/configs/experiment/defaults.yaml: -------------------------------------------------------------------------------- 1 | name: train_dinoLightningModule -------------------------------------------------------------------------------- /self-training/tasks/configs/loader/defaults.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 1 2 | num_workers: 0 3 | mode: full -------------------------------------------------------------------------------- /pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexaatm/UnsupervisedSegmentor4Ultrasound/HEAD/pipeline.png -------------------------------------------------------------------------------- /self-training/tasks/configs/loader/patch64.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 1 2 | num_workers: 0 3 | mode: patch 4 | patch_size: 64 -------------------------------------------------------------------------------- /self-training/tasks/configs/loader/patch70.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 1 2 | num_workers: 0 3 | mode: patch 4 | patch_size: 70 -------------------------------------------------------------------------------- /self-training/tasks/configs/loader/patchv2.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 1 2 | num_workers: 0 3 | mode: patch 4 | patch_size: 100 -------------------------------------------------------------------------------- /self-training/tasks/configs/loader/patch.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 1 2 | num_workers: 0 3 | mode: patch 4 | patch_mode: random 5 | patch_size: 64 -------------------------------------------------------------------------------- /self-training/tasks/configs/dataset/carotid-mini.yaml: -------------------------------------------------------------------------------- 1 | name: carotid-mini 2 | path: ../data/carotid-mini/images 3 | input_size: 256 4 | triplet_mode: seq -------------------------------------------------------------------------------- /configs/eval/defaults.yaml: 
-------------------------------------------------------------------------------- 1 | vis_dir: './eval/vis' 2 | vis_rand_k: 10 3 | eval_per_image: True 4 | eval_per_dataset: False 5 | iou_thresh: 0.0 6 | void_label: 0 -------------------------------------------------------------------------------- /configs/vis/allTrue.yaml: -------------------------------------------------------------------------------- 1 | eigen: True 2 | crf_segmaps: True 3 | dino_attn_maps: True 4 | multiregion_segmaps: True 5 | segmaps: True 6 | crf_multi_region: True -------------------------------------------------------------------------------- /self-training/tasks/configs/loader/triplet_patch.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 1 2 | num_workers: 0 3 | mode: patch 4 | patch_size: 64 5 | min_shift: 0.1 6 | max_shift: 0.3 -------------------------------------------------------------------------------- /configs/vis/selected.yaml: -------------------------------------------------------------------------------- 1 | eigen: True 2 | crf_segmaps: True 3 | dino_attn_maps: False 4 | multiregion_segmaps: True 5 | segmaps: True 6 | crf_multi_region: True -------------------------------------------------------------------------------- /configs/vis/defaults.yaml: -------------------------------------------------------------------------------- 1 | eigen: False 2 | crf_segmaps: False 3 | dino_attn_maps: False 4 | multiregion_segmaps: False 5 | segmaps: False 6 | crf_multi_region: False -------------------------------------------------------------------------------- /evaluation/demo_dataset/images/ultrasonix_test001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexaatm/UnsupervisedSegmentor4Ultrasound/HEAD/evaluation/demo_dataset/images/ultrasonix_test001.jpg -------------------------------------------------------------------------------- /configs/precomputed/carotid_mini_vis_polyaxon.yaml: -------------------------------------------------------------------------------- 1 | mode: precomputed 2 | features: "" 3 | eig: "" 4 | multi_region_segmentation: "" 5 | bboxes: "" 6 | segmaps: "" 7 | crf_segmaps: "" 8 | -------------------------------------------------------------------------------- /evaluation/demo_dataset/predictions/ultrasonix_test001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexaatm/UnsupervisedSegmentor4Ultrasound/HEAD/evaluation/demo_dataset/predictions/ultrasonix_test001.png -------------------------------------------------------------------------------- /evaluation/demo_dataset/ground_truth/ultrasonix_test001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexaatm/UnsupervisedSegmentor4Ultrasound/HEAD/evaluation/demo_dataset/ground_truth/ultrasonix_test001.png -------------------------------------------------------------------------------- /configs/multi_region_segmentation/defaults.yaml: -------------------------------------------------------------------------------- 1 | adaptive: False 2 | non_adaptive_num_segments: 4 3 | infer_bg_index: True 4 | clustering1: kmeans_eigen 5 | num_eigenvectors: 1_000_000 6 | multiprocessing: 0 -------------------------------------------------------------------------------- /self-training/tasks/configs/train/defaults.yaml: 
-------------------------------------------------------------------------------- 1 | --- 2 | backbone: 3 | pretrained_weights: False 4 | optimizer: Adam 5 | lr: 1e-3 6 | epochs: 10 7 | seed: 1 8 | weight_decay: 0 9 | fraction_layers_to_freeze: 0 -------------------------------------------------------------------------------- /self-training/tasks/configs/simclr.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: liver2_mini 3 | - train: defaults 4 | - wandb: defaults 5 | - hydra: defaults 6 | - experiment: defaults 7 | - loader: defaults 8 | -------------------------------------------------------------------------------- /self-training/tasks/configs/train/dino_vits16.yaml: -------------------------------------------------------------------------------- 1 | backbone: dino_vits16 2 | pretrained_weights: False 3 | optimizer: Adam 4 | lr: 1e-3 5 | epochs: 10 6 | seed: 1 7 | weight_decay: 0 8 | fraction_layers_to_freeze: 0.0 -------------------------------------------------------------------------------- /self-training/tasks/configs/train/dino_vits8.yaml: -------------------------------------------------------------------------------- 1 | backbone: dino_vits8 2 | pretrained_weights: False 3 | optimizer: Adam 4 | lr: 1e-4 5 | epochs: 10 6 | seed: 1 7 | weight_decay: 1e-5 8 | fraction_layers_to_freeze: 0.0 -------------------------------------------------------------------------------- /self-training/tasks/configs/train/triplet.yaml: -------------------------------------------------------------------------------- 1 | backbone: resnet 2 | pretrained_weights: False 3 | optimizer: Adam 4 | lr: 1e-3 5 | epochs: 10 6 | seed: 1 7 | triplet_loss_margin: 2 8 | fraction_layers_to_freeze: 0.0 -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/config/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - base_customdata 4 | - _self_ 5 | 6 | name: 'eval' 7 | job_type: 'eval' 8 | 9 | segments_dir: "" 10 | -------------------------------------------------------------------------------- /configs/pipeline_steps/allTrue.yaml: -------------------------------------------------------------------------------- 1 | dino_features: True 2 | eigen: True 3 | segments: True 4 | bbox: True 5 | bbox_features: True 6 | clusters: True 7 | sem_segm: True 8 | crf_segm: True 9 | crf_multi_region: True 10 | eval : True -------------------------------------------------------------------------------- /configs/pipeline_steps/defaults.yaml: -------------------------------------------------------------------------------- 1 | dino_features: True 2 | eigen: True 3 | segments: True 4 | bbox: True 5 | bbox_features: True 6 | clusters: True 7 | sem_segm: True 8 | crf_segm: True 9 | crf_multi_region: True 10 | eval : False -------------------------------------------------------------------------------- /self-training/requirements.txt: -------------------------------------------------------------------------------- 1 | lightly 2 | torch 3 | torchvision 4 | wandb 5 | hydra-core 6 | omegaconf 7 | numpy 8 | Pillow 9 | matplotlib 10 | tqdm 11 | accelerate 12 | torchsummary 13 | torchinfo 14 | opencv-python-headless 15 | -------------------------------------------------------------------------------- /self-training/tasks/configs/dataset/us_mixed.yaml: 
-------------------------------------------------------------------------------- 1 | name: us_mixed 2 | path: ../data/us_mixed/train 3 | val_path: ../data/us_mixed/val 4 | rel_train_path: us_mixed/train 5 | rel_val_path: us_mixed/val 6 | input_size: 256 7 | triplet_mode: class -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "CutLER"] 2 | path = CutLER 3 | url = git@github.com:alexaatm/CutLER.git 4 | [submodule "data_preprocessing/denoise/MPRNet"] 5 | path = data_preprocessing/denoise/MPRNet 6 | url = git@github.com:swz30/MPRNet.git 7 | -------------------------------------------------------------------------------- /configs/pipeline_steps/allFalse.yaml: -------------------------------------------------------------------------------- 1 | dino_features: False 2 | eigen: False 3 | segments: False 4 | bbox: False 5 | bbox_features: False 6 | clusters: False 7 | sem_segm: False 8 | crf_segm: False 9 | crf_multi_region: False 10 | eval : False -------------------------------------------------------------------------------- /configs/precomputed/defaults.yaml: -------------------------------------------------------------------------------- 1 | mode: from_scratch 2 | features: "" 3 | eig: "" 4 | multi_region_segmentation: "" 5 | bboxes: "" 6 | bbox_features: "" 7 | bbox_clusters: "" 8 | segmaps: "" 9 | crf_segmaps: "" 10 | crf_multi_region: "" 11 | -------------------------------------------------------------------------------- /self-training/tasks/configs/dataset/liver2_mini.yaml: -------------------------------------------------------------------------------- 1 | name: liver2_mini 2 | path: ../data/liver2_mini/train 3 | val_path: ../data/liver2_mini/val 4 | rel_train_path: liver2_mini/train 5 | rel_val_path: liver2_mini/val 6 | input_size: 256 7 | triplet_mode: seq -------------------------------------------------------------------------------- /self-training/tasks/configs/dataset/liver2_medium.yaml: -------------------------------------------------------------------------------- 1 | name: liver2_medium 2 | path: ../data/liver2_medium/train 3 | val_path: ../data/liver2_medium/val 4 | rel_train_path: liver2_medium/train 5 | rel_val_path: liver2_medium/val 6 | input_size: 256 7 | triplet_mode: seq -------------------------------------------------------------------------------- /self-training/tasks/configs/dataset/liver_reduced.yaml: -------------------------------------------------------------------------------- 1 | name: liver_reduced 2 | path: ../data/liver_reduced/train 3 | val_path: ../data/liver_reduced/val 4 | rel_train_path: liver_reduced/train 5 | rel_val_path: liver_reduced/val 6 | input_size: 256 7 | triplet_mode: seq -------------------------------------------------------------------------------- /self-training/tasks/configs/dataset/carotid_mutinfo.yaml: -------------------------------------------------------------------------------- 1 | name: carotid_mutinfo 2 | path: ../data/carotid_mutinfo/train 3 | val_path: ../data/carotid_mutinfo/val 4 | rel_train_path: carotid_mutinfo/train 5 | rel_val_path: carotid_mutinfo/val 6 | input_size: 256 7 | triplet_mode: seq -------------------------------------------------------------------------------- /self-training/tasks/configs/dataset/imagenet-4-classes.yaml: -------------------------------------------------------------------------------- 1 | name: imagenet-4-classes 2 | path: 
../data/imagenet-4-classes/train 3 | val_path: ../data/imagenet-4-classes/val 4 | rel_train_path: imagenet-4-classes/train 5 | rel_val_path: imagenet-4-classes/val 6 | input_size: 256 7 | triplet_mode: class -------------------------------------------------------------------------------- /deep-spectral-segmentation/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | fire 3 | opencv-python-headless 4 | pillow 5 | scikit-image 6 | scipy 7 | torch 8 | torchvision 9 | tqdm 10 | pymatting 11 | wandb 12 | omegaconf 13 | hydra-core 14 | SimpleCRF 15 | scikit-learn 16 | # pip install markupsafe==2.0.1 --force -------------------------------------------------------------------------------- /self-training/tasks/configs/dataset/liver_similar_folders.yaml: -------------------------------------------------------------------------------- 1 | name: liver_similar_folders 2 | path: ../data/liver_similar_folders/train 3 | val_path: ../data/liver_similar_folders/val 4 | rel_train_path: liver_similar_folders/train 5 | rel_val_path: liver_similar_folders/val 6 | input_size: 256 7 | triplet_mode: seq -------------------------------------------------------------------------------- /configs/precomputed/carotid_mutinfo.yaml: -------------------------------------------------------------------------------- 1 | mode: precomputed 2 | features: /outputs/oleksandra_tmenova/SegmPipelineUS/experiments/110825/carotid_mutinfo/features/dino_vits8 3 | eig: "" 4 | multi_region_segmentation: "" 5 | bboxes: "" 6 | bbox_features: "" 7 | bbox_clusters: "" 8 | segmaps: "" 9 | crf_segmaps: "" 10 | -------------------------------------------------------------------------------- /self-training/tasks/configs/defaults.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: liver2_mini 3 | - train: defaults 4 | - wandb: defaults 5 | - hydra: defaults 6 | - experiment: defaults 7 | - loader: defaults 8 | - _self_ 9 | 10 | hydra: 11 | run: 12 | dir: ./outputs/models/${now:%Y-%m-%d}-${experiment.name}/${now:%H-%M-%S}-${dataset.name} -------------------------------------------------------------------------------- /configs/wandb/README.md: -------------------------------------------------------------------------------- 1 | Create a defaults.yaml file with your wandb preferences, here is an example of its content: 2 | 3 | ``` 4 | setup: 5 | project: yourproject 6 | entity: yourusername 7 | mode: offline 8 | key: 000000000000000000000000000000000000000 9 | tag: 'test1' 10 | watch: 11 | log: all 12 | log_freq: 1 13 | mode: local 14 | ``` -------------------------------------------------------------------------------- /self-training/tasks/configs/saliency_maps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: defaults 3 | - loader: defaults 4 | - hydra: defaults 5 | - _self_ 6 | 7 | hydra: 8 | run: 9 | dir: ./outputs/attention_maps/${model_name}/${dataset.name}/${now:%Y-%m-%d}-${now:%H-%M-%S}-${vis} 10 | model_name: simclr 11 | model_checkpoint: "" 12 | loader.batch_size: 1 13 | vis: saliency_maps_v1 14 | just_backbone: False -------------------------------------------------------------------------------- /configs/crf/defaults.yaml: -------------------------------------------------------------------------------- 1 | num_classes: 3 2 | downsample_factor: null 3 | multiprocessing: 0 4 | # CRF parameters 5 | w1: 10 # weight of bilateral term # default: 10.0, 6 
| alpha: 80 # spatial std # default: 80, 7 | beta: 13 # rgb std # default: 13, 8 | w2: 3 # weight of spatial term # default: 3.0, 9 | gamma: 3 # spatial std # default: 3, 10 | it: 5.0 # iteration # default: 5.0, -------------------------------------------------------------------------------- /configs/dataset/README.md: -------------------------------------------------------------------------------- 1 | Create a config .yaml file for your dataset, following the structure below. E.g. thyroid.yaml: 2 | 3 | ``` 4 | name: thyroid_test9 5 | dataset_root: THYROID 6 | dataset_type: folders 7 | images_root: images 8 | list: lists/images.txt 9 | gt_dir: labels 10 | pred_dir: null 11 | n_classes: 2 12 | features_dir: null 13 | preprocessed_dir: null 14 | derained_dir: null 15 | eigenseg_dir: null 16 | ``` -------------------------------------------------------------------------------- /self-training/tasks/configs/features.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: defaults 3 | - loader: defaults 4 | - wandb: defaults 5 | - hydra: defaults 6 | - experiment: defaults 7 | - _self_ 8 | 9 | hydra: 10 | run: 11 | dir: ./outputs/features/${now:%Y-%m-%d}-${experiment.name}/${now:%H-%M-%S}-${dataset.name} 12 | output_dir: features 13 | model_name: simclr 14 | model_checkpoint: null 15 | loader.batch_size: 1 -------------------------------------------------------------------------------- /req_conda.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | fire 3 | hydra-core 4 | lightly 5 | matplotlib 6 | numpy 7 | omegaconf 8 | opencv-python-headless 9 | pillow 10 | pymatting 11 | scikit-image 12 | scikit-learn 13 | scipy 14 | torchinfo 15 | torchvision 16 | tqdm 17 | wandb 18 | torchmetrics 19 | torchinfo 20 | torchvision 21 | fvcore 22 | iopath 23 | xformers 24 | submitit 25 | albumentations 26 | kornia 27 | pandas 28 | streamlit 29 | 30 | 31 | -------------------------------------------------------------------------------- /configs/bbox/defaults.yaml: -------------------------------------------------------------------------------- 1 | # bbox extraction: 2 | num_erode: 2 3 | num_dilate: 5 4 | skip_bg_index: True 5 | downsample_factor: null 6 | # bbox features 7 | C_pos: 0.0 8 | C_mask: 0.0 9 | apply_mask: False 10 | feat_comb_method: "sum" 11 | # bbox clustering: 12 | num_clusters: 20 13 | seed: 1 14 | pca_dim: 0 15 | clustering: kmeans 16 | should_use_siamese: null 17 | should_use_ae: null 18 | is_sparse_graph: null 19 | spectral_n_nbg: null -------------------------------------------------------------------------------- /configs/sweep/num_clusters.yaml: -------------------------------------------------------------------------------- 1 | name: num_clusters 2 | seg_for_eval: ['crf_multi_region'] 3 | method: grid 4 | count: 5 | simple: True 6 | sweep_id: null 7 | config: 8 | # generic 9 | segments_num: [15] 10 | clusters_num: [6,9,12,15] 11 | 12 | # postprocessing (CRF) 13 | crf: 14 | num_classes: [10] 15 | w1: [15] 16 | alpha: [7] 17 | beta: [10] 18 | w2: [5] 19 | gamma: [5] 20 | it: [10] -------------------------------------------------------------------------------- /run_pipeline.sh: -------------------------------------------------------------------------------- 1 | cd deep-spectral-segmentation 2 | 3 | export WANDB_API_KEY= 4 | export WANDB_CONFIG_DIR=/tmp/ 5 | export WANDB_CACHE_DIR=/tmp/ 6 | export WANDB_AGENT_MAX_INITIAL_FAILURE=20 7 | export WANDB__SERVICE_WAIT=600 8 | export 
XFORMERS_DISABLED=True 9 | 10 | python -m pipeline.pipeline_sweep_subfolders \ 11 | vis=selected \ 12 | pipeline_steps=defaults \ 13 | dataset=thyroid \ 14 | wandb.tag=test \ 15 | sweep=defaults -------------------------------------------------------------------------------- /configs/precomputed/carotid_mini_vis_local.yaml: -------------------------------------------------------------------------------- 1 | mode: precomputed 2 | features: "C:/Users/Tmenova/personal/tum/thesis/thesis-codebase/deep-spectral-segmentation/outputs/pipeline/carotid-mini/2023-06-14--11-57-28/features/dino_vits8" 3 | eig: "C:/Users/Tmenova/personal/tum/thesis/thesis-codebase/deep-spectral-segmentation/outputs/pipeline/carotid-mini/2023-06-14--11-57-28/eig/laplacian" 4 | multi_region_segmentation: "" 5 | bboxes: "" 6 | segmaps: "" 7 | crf_segmaps: "" 8 | -------------------------------------------------------------------------------- /configs/precomputed/mutinfo_train_mini.yaml: -------------------------------------------------------------------------------- 1 | mode: precomputed 2 | features: "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/deep-spectral-segmentation/outputs/pipeline/mutinfo_train_mini/2023-08-12--15-04-22/features/dino_vits8" 3 | eig: "thesis-codebase/deep-spectral-segmentation/outputs/pipeline/mutinfo_train_mini/2023-08-12--15-04-22/eig/laplacian" 4 | multi_region_segmentation: "" 5 | bboxes: "" 6 | bbox_features: "" 7 | bbox_clusters: "" 8 | segmaps: "" 9 | crf_segmaps: "" 10 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/pipeline/pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #! Example parameters for the semantic segmentation experiments 4 | DATASET="liver2_mini" 5 | MODEL="dino_vits8" 6 | MATRIX="laplacian" 7 | DOWNSAMPLE=16 8 | N_SEG=15 9 | N_ERODE=2 10 | N_DILATE=5 11 | 12 | python ../extract/extract.py extract_features \ 13 | --images_list "./data/${DATASET}/lists/images.txt" \ 14 | --images_root "./data/${DATASET}/images" \ 15 | --output_dir "./data/${DATASET}/features/${MODEL}" \ 16 | --model_name "${MODEL}" \ 17 | --batch_size 1 -------------------------------------------------------------------------------- /configs/sweep/aff.yaml: -------------------------------------------------------------------------------- 1 | name: aff 2 | seg_for_eval: ['crf_multi_region'] 3 | method: grid 4 | count: 5 | simple: True 6 | sweep_id: null 7 | config: 8 | # generic 9 | segments_num: [15] 10 | clusters_num: [15] 11 | 12 | # affinities 13 | spectral_clustering: 14 | C_dino: [1.0] 15 | C_ssd: [0.0, 1.0] 16 | C_mi: [0.0, 1.0] 17 | C_pos: [0.0] 18 | 19 | # postprocessing (CRF) 20 | crf: 21 | w1: [15] 22 | alpha: [7] 23 | beta: [10] 24 | w2: [5] 25 | gamma: [5] 26 | it: [10] -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | This is a folder for your data. 2 | It should be structured in the following way: 3 | 4 | ``` 5 | data/ 6 | --DATASET1 7 | ----subfolder1/ 8 | ------images 9 | ------lists 10 | ------labels 11 | ... 12 | ----subfolderN/ 13 | ------images 14 | ------lists 15 | ------labels 16 | 17 | --DATASET2 18 | ----subfolder1/ 19 | ------images 20 | ------lists 21 | ------labels 22 | ----subfolder2/ 23 | ------images 24 | ------lists 25 | ------labels 26 | ``` 27 | 28 | Note: the names of datasets and subfolders can be anything; a sketch for generating the image lists is shown below.
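For convenience, here is a minimal sketch for creating the lists/images.txt file that the pipeline scripts read (see the --images_list/--images_root arguments in deep-spectral-segmentation/pipeline/pipeline.sh). It assumes the list simply holds one image filename per line, relative to the images folder; the subfolder path below is only a placeholder:

```
import os

subfolder = "data/DATASET1/subfolder1"  # placeholder path, adjust to your dataset
images_dir = os.path.join(subfolder, "images")
lists_dir = os.path.join(subfolder, "lists")
os.makedirs(lists_dir, exist_ok=True)

# One image filename per line, sorted for reproducibility
names = sorted(
    f for f in os.listdir(images_dir)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
)
with open(os.path.join(lists_dir, "images.txt"), "w") as f:
    f.write("\n".join(names) + "\n")
```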
It is important to create a corresponding config file (see configs/dataset). -------------------------------------------------------------------------------- /configs/spectral_clustering/defaults.yaml: -------------------------------------------------------------------------------- 1 | which_matrix: laplacian 2 | which_color_matrix: knn 3 | which_features: k 4 | normalize: True 5 | threshold_at_zero: True 6 | lapnorm: True 7 | K: 20 8 | image_downsample_factor: null 9 | image_color_lambda: 0.0 10 | multiprocessing: 0 11 | C_ssd_knn: 0.0 12 | C_dino: 1.0 13 | max_knn_neigbors: 80 14 | C_var_knn: 0.0 15 | C_pos_knn: 0.0 16 | C_ssd: 0.0 17 | C_ncc: 0.0 18 | C_lncc: 0.0 19 | C_ssim: 0.0 20 | C_mi: 0.0 21 | C_sam: 0.0 22 | patch_size: 8 23 | aff_sigma: 0.01 24 | distance_weight1: null 25 | distance_weight2: null 26 | use_transform: False 27 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/pipeline/extract_features.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | REM Example parameters for the semantic segmentation experiments 4 | set "DATASET=liver2_mini" 5 | set "MODEL=dino_vits8" 6 | set "MATRIX=laplacian" 7 | set "DOWNSAMPLE=16" 8 | set "N_SEG=15" 9 | set "N_ERODE=2" 10 | set "N_DILATE=5" 11 | set "DATA_ROOT=../../data" 12 | 13 | 14 | python ..\extract\extract.py extract_features ^ 15 | --images_list "%DATA_ROOT%/%DATASET%/lists/images.txt" ^ 16 | --images_root "%DATA_ROOT%/%DATASET%/images" ^ 17 | --output_dir "%DATA_ROOT%/%DATASET%/features/%MODEL%" ^ 18 | --model_name "%MODEL%" ^ 19 | --batch_size 1 -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/model/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | from .model import get_deeplab_resnet, get_deeplab_vit 7 | 8 | 9 | def get_model(name: str, num_classes: int): 10 | if 'resnet' in name: 11 | model = get_deeplab_resnet(num_classes=(num_classes + 1)) # add 1 for bg 12 | elif 'vit' in name: 13 | model = get_deeplab_vit(backbone_name=name, num_classes=(num_classes + 1)) # add 1 for bg 14 | else: 15 | raise NotImplementedError() 16 | return model 17 | 18 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/dino2_models/readme.md: -------------------------------------------------------------------------------- 1 | Clone the dinov2 repository here: https://github.com/facebookresearch/dinov2.git 2 | in order to modify the layers needed for extracting self-attention, based on https://github.com/facebookresearch/dinov2/commit/df7265ce09efa7553a537606565217e42cefea32 3 | 4 | Alternatively, clone the modified version directly: https://github.com/3cology/dinov2_with_attention_extraction.git 5 | 6 | Note: you need to turn the directories into Python modules (put an __init__.py file in each folder).
7 | Note 2: in dino2/models/vision_transformer.py, you need to change the imports to "from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block". 8 | 9 | 10 | -------------------------------------------------------------------------------- /configs/sweep/README.md: -------------------------------------------------------------------------------- 1 | Create a sweep config if you wish to experiment with multiple settings and their combinations (list the values using commas). You can use any configs from the configs folder and specify their values, using the following structure: 2 | 3 | ``` 4 | name: defaults 5 | seg_for_eval: ['multi_region'] 6 | method: grid 7 | count: 1 8 | simple: True 9 | config: 10 | # generic 11 | segments_num: [15] 12 | clusters_num: [6] 13 | 14 | # preprocessing 15 | norm: ["imagenet"] 16 | inv: [False] 17 | hist_eq: [True] 18 | gauss_blur: [True] 19 | 20 | # affinities 21 | spectral_clustering: 22 | C_dino: [1.0] 23 | C_ssd: [0.0] 24 | C_mi: [0.0] 25 | 26 | # postprocessing (CRF) 27 | crf: 28 | num_classes: [6] 29 | w1: [15] 30 | alpha: [7] 31 | beta: [10] 32 | w2: [5] 33 | gamma: [5] 34 | it: [10] 35 | 36 | ``` 37 | -------------------------------------------------------------------------------- /configs/sweep/defaults.yaml: -------------------------------------------------------------------------------- 1 | name: defaults 2 | seg_for_eval: ['crf_multi_region'] 3 | method: grid 4 | count: 5 | simple: True 6 | sweep_id: null 7 | config: 8 | # generic 9 | segments_num: [15] 10 | clusters_num: [15] 11 | multi_region_segmentation: 12 | non_adaptive_num_segments: [15] #same as segments_num 13 | bbox: 14 | num_clusters: [15] #same as clusters_num 15 | 16 | # preprocessing 17 | norm: ["imagenet"] 18 | inv: [False] 19 | hist_eq: [True] 20 | gauss_blur: [True] 21 | 22 | # affinities 23 | spectral_clustering: 24 | K: [15] #same as segments_num 25 | C_dino: [1.0] 26 | C_ssd: [0.0] 27 | C_mi: [0.0] 28 | C_pos: [0.0] 29 | 30 | # postprocessing (CRF) 31 | crf: 32 | num_classes: [15] #same as clusters_num 33 | w1: [15] 34 | alpha: [7] 35 | beta: [10] 36 | w2: [5] 37 | gamma: [5] 38 | it: [10] -------------------------------------------------------------------------------- /data_preprocessing/denoise/readme.md: -------------------------------------------------------------------------------- 1 | Clone the MPRNet repository here: https://github.com/swz30/MPRNet.git 2 | 3 | Place the denoise.sh utility bash script inside the cloned MPRNet folder to run denoising on several input folders.
4 | 5 | To use pretrained models, download and place them in pretrained_models folder of each task (Denoising, Deraining): 6 | 7 | To test the pre-trained models of [Deblurring](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing), [Deraining](https://drive.google.com/file/d/1O3WEJbcat7eTY6doXWeorAbQ1l_WmMnM/view?usp=sharing), [Denoising](https://drive.google.com/file/d/1LODPt9kYmxwU98g96UrRA0_Eh5HYcsRw/view?usp=sharing) on your own images, run 8 | ``` 9 | python demo.py --task Task_Name --input_dir path_to_images --result_dir save_images_here 10 | ``` 11 | Here is an example to perform Deblurring: 12 | ``` 13 | python demo.py --task Deblurring --input_dir ./samples/input/ --result_dir ./samples/output/ 14 | ``` 15 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/config/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | run: 4 | dir: ./outputs/${name}/${now:%Y-%m-%d--%H-%M-%S} 5 | 6 | name: "debug" 7 | seed: 1 8 | job_type: 'train' 9 | fp16: False 10 | cpu: False 11 | wandb: False 12 | wandb_kwargs: 13 | project: deep-spectral-segmentation 14 | 15 | data: 16 | num_classes: 20 17 | dataset: pascal 18 | train_kwargs: 19 | root: ${oc.env:HOME}/machine-learning-datasets/semantic-segmentation/PASCAL_VOC/VOC2012 20 | year: "2012" 21 | image_set: train 22 | download: False 23 | val_kwargs: 24 | root: ${oc.env:HOME}/machine-learning-datasets/semantic-segmentation/PASCAL_VOC/VOC2012 25 | year: "2012" 26 | image_set: "val" 27 | download: False 28 | loader: 29 | batch_size: 144 30 | num_workers: 8 31 | pin_memory: False 32 | transform: 33 | resize_size: 256 34 | crop_size: 224 35 | img_mean: [0.485, 0.456, 0.406] 36 | img_std: [0.229, 0.224, 0.225] 37 | 38 | segments_dir: "" 39 | 40 | logging: 41 | print_freq: 50 -------------------------------------------------------------------------------- /configs/precomputed/liver2_mini.yaml: -------------------------------------------------------------------------------- 1 | mode: precomputed 2 | features: "/outputs/oleksandra_tmenova/SegmPipelineUS/experiments/110511/liver2_mini/features/dino_vits8" 3 | eig: "/outputs/oleksandra_tmenova/SegmPipelineUS/experiments/110512/liver2_mini/eig/laplacian" 4 | multi_region_segmentation: "/outputs/oleksandra_tmenova/SegmPipelineUS/experiments/110512/liver2_mini/multi_region_segmentation/laplacian" 5 | bboxes: "/outputs/oleksandra_tmenova/SegmPipelineUS/experiments/110512/liver2_mini/multi_region_bboxes/laplacian/bboxes.pth" 6 | bbox_features: "/outputs/oleksandra_tmenova/SegmPipelineUS/experiments/110687/liver2_mini/multi_region_bboxes/laplacian/bbox_features.pth" 7 | bbox_clusters: " /outputs/oleksandra_tmenova/SegmPipelineUS/experiments/110690/liver2_mini/multi_region_bboxes/laplacian/bbox_clusters.pth" 8 | segmaps: "/outputs/oleksandra_tmenova/SegmPipelineUS/experiments/110691/liver2_mini/semantic_segmentations/patches/laplacian/segmaps" 9 | crf_segmaps: "/outputs/oleksandra_tmenova/SegmPipelineUS/experiments/110727/liver2_mini/semantic_segmentations/patches/laplacian/crf_segmaps" -------------------------------------------------------------------------------- /configs/precomputed/mutinfo_val_carotid.yaml: -------------------------------------------------------------------------------- 1 | mode: precomputed 2 | features: 
"/home/guests/oleksandra_tmenova/test/project/thesis-codebase/deep-spectral-segmentation/outputs/pipeline/mutinfo_val_carotid/2023-08-13--19-37-19/features/dino_vits8" 3 | eig: "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/deep-spectral-segmentation/outputs/pipeline/mutinfo_val_carotid/2023-08-13--19-37-19/eig/laplacian" 4 | multi_region_segmentation: "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/deep-spectral-segmentation/outputs/pipeline/mutinfo_val_carotid/2023-08-13--19-37-19/multi_region_segmentation/laplacian" 5 | bboxes: "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/deep-spectral-segmentation/outputs/pipeline/mutinfo_val_carotid/2023-08-13--19-37-19/multi_region_bboxes/laplacian/bboxes.pth" 6 | bbox_features: "" 7 | bbox_clusters: "" 8 | segmaps: "" 9 | crf_segmaps: "" 10 | crf_multi_region: "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/deep-spectral-segmentation/outputs/pipeline/mutinfo_val_carotid/2023-08-13--20-47-28/semantic_segmentations/laplacian/crf_multi_region" -------------------------------------------------------------------------------- /configs/defaults.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: liver2_mini 3 | - wandb: defaults 4 | - hydra: defaults 5 | - loader: defaults 6 | - model: dino_vits8 7 | - spectral_clustering: defaults 8 | - multi_region_segmentation: defaults 9 | - bbox: defaults 10 | - precomputed: defaults 11 | - crf: defaults 12 | - vis: defaults 13 | - pipeline_steps: defaults 14 | - eval: defaults 15 | - sweep: defaults 16 | - _self_ 17 | 18 | hydra: 19 | run: 20 | dir: ./results/${dataset.name}/exp_${wandb.tag}/${now:%Y-%m-%d}/${now:%H-%M-%S} 21 | 22 | custom_path_to_save_data: "" 23 | 24 | only_vis: False 25 | only_eval: False 26 | 27 | # data preprocessing -> TODO: move to a spearate config dir 28 | preprocessed_data: null 29 | norm: imagenet 30 | inv: False 31 | hist_eq: False 32 | gauss_blur: False 33 | gauss_teta: 0.05 34 | 35 | 36 | segments_num: 15 37 | clusters_num: 15 38 | 39 | spectral_clustering: 40 | K: ${segments_num} 41 | 42 | multi_region_segmentation: 43 | non_adaptive_num_segments: ${segments_num} 44 | 45 | bbox: 46 | num_clusters: ${clusters_num} 47 | 48 | crf: 49 | num_classes: ${clusters_num} 50 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/config/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - base_customdata 4 | - _self_ 5 | 6 | job_type: 'train' 7 | eval_every: 1 # eval every this many epochs 8 | checkpoint_every: 10 # checkpoint every this many epochs 9 | 10 | unfrozen_backbone_layers: 1 # -1 to train all, 0 to freeze entirely, > 0 to specify 11 | model: 12 | name: resnet50 13 | num_classes: ${data.num_classes} 14 | 15 | # Please change these 16 | segments_dir: "" 17 | matching: "" 18 | 19 | checkpoint: 20 | resume: null 21 | resume_training: True 22 | resume_optimizer_only: False 23 | 24 | # Exponential moving average of model parameters 25 | ema: 26 | use_ema: False 27 | decay: 0.999 28 | update_every: 10 29 | 30 | # Training steps/epochs 31 | max_train_steps: 5000 32 | max_train_epochs: null 33 | 34 | # Optimization 35 | lr: 0.005 36 | gradient_accumulation_steps: 1 37 | optimizer: 38 | scale_learning_rate_with_batch_size: False 39 | clip_grad_norm: null 40 | 41 | # Timm optimizer 42 | kind: 'timm' 43 | kwargs: 44 | 
opt: 'adamw' 45 | weight_decay: 1e-8 46 | 47 | # Learning rate scheduling 48 | scheduler: 49 | 50 | # Transformers scheduler 51 | kind: 'transformers' 52 | stepwise: True 53 | kwargs: 54 | name: linear 55 | num_warmup_steps: 0 56 | num_training_steps: ${max_train_steps} 57 | -------------------------------------------------------------------------------- /data_preprocessing/denoise/denoise.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH --job-name=test 4 | #SBATCH --output=test-%A.out # Standard output of the script (Can be absolute or relative path). %A adds the job id to the file name so you can launch the same script multiple times and get different logging files 5 | #SBATCH --error=test-%A.err # Standard error of the script 6 | #SBATCH --gres=gpu:0 # Number of GPUs if needed 7 | #SBATCH --cpus-per-task=8 # Number of CPUs (Don't use more than 24 per GPU) 8 | #SBATCH --mem=30G # Memory in GB (Don't use more than 126G per GPU) BEFORE: 36G 9 | #SBATCH -w unimatrix2 10 | 11 | # activate corresponding environment 12 | # source ../../../../.venv-p311/bin/activate 13 | source ../../../../.venv-p311-cu116/bin/activate 14 | 15 | 16 | pwd 17 | 18 | echo "test: nvidia-smi" 19 | nvidia-smi 20 | echo "test: torch version" 21 | python -c "import torch; print(torch.__version__)" 22 | 23 | 24 | datasets=( 25 | # "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/THYROID/val/val18" 26 | "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/test_heart/val/val0" 27 | ) 28 | 29 | for dataset in "${datasets[@]}"; do 30 | echo "Preprocessing for dataset: $dataset" 31 | 32 | python demo.py \ 33 | --task Deraining \ 34 | --input_dir "$dataset"/images \ 35 | --result_dir "$dataset"/mpr_derained 36 | 37 | done 38 | 39 | 40 | -------------------------------------------------------------------------------- /self-training/poly.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | kind: experiment 3 | framework: pytorch 4 | tags: 5 | - dinov1_vits16 6 | build: 7 | image: pytorchlightning/pytorch_lightning:base-xla-py3.7-torch1.12 8 | build_steps: 9 | - pip install -r requirements.txt 10 | - pip install markupsafe==2.0.1 --force 11 | - pip3 install torch_xla[tpuvm] 12 | - pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 13 | 14 | environment: 15 | resources: 16 | cpu: 17 | requests: 12 18 | limits: 48 19 | memory: 20 | requests: 24576 21 | limits: 64000 22 | gpu: 23 | requests: 1 24 | limits: 1 25 | 26 | params: 27 | batch_size: 48 28 | exp_name: train_dinoLightningModule 29 | backbone: dino_vits16 30 | epochs: 300 31 | loader_num_workers: 8 32 | dataset_input_size: 256 33 | train_lr: 1e-3 34 | pretrained_weights: False 35 | dataset: liver_reduced 36 | loader: patch64 37 | 38 | run: 39 | cmd: python -m tasks.train \ 40 | wandb.setup.project=self-train \ 41 | wandb=server \ 42 | dataset={{ dataset }} \ 43 | experiment.name={{ exp_name }} \ 44 | train.backbone={{ backbone }} \ 45 | train.epochs={{ epochs }} \ 46 | train.pretrained_weights={{ pretrained_weights }} \ 47 | loader.batch_size={{ batch_size }} \ 48 | loader.num_workers={{ loader_num_workers }} \ 49 | dataset.input_size={{ dataset_input_size }} \ 50 | loader={{ loader }} \ 51 | train.lr={{ train_lr }} 52 | -------------------------------------------------------------------------------- /self-training/models/dino.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from lightly.models.modules import DINOProjectionHead 4 | from lightly.models.utils import deactivate_requires_grad 5 | import torchvision 6 | from torch import nn 7 | 8 | 9 | class DINO(torch.nn.Module): 10 | def __init__(self, backbone, input_dim): 11 | super().__init__() 12 | self.student_backbone = backbone 13 | self.student_head = DINOProjectionHead( 14 | input_dim, 512, 64, 2048, freeze_last_layer=1 15 | ) 16 | self.teacher_backbone = copy.deepcopy(backbone) 17 | self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048) 18 | deactivate_requires_grad(self.teacher_backbone) 19 | deactivate_requires_grad(self.teacher_head) 20 | 21 | def forward(self, x): 22 | y = self.student_backbone(x).flatten(start_dim=1) 23 | z = self.student_head(y) 24 | return z 25 | 26 | def forward_teacher(self, x): 27 | y = self.teacher_backbone(x).flatten(start_dim=1) 28 | z = self.teacher_head(y) 29 | return z 30 | 31 | def get_dino_backbone(dino_model_name: str): 32 | backbone = torch.hub.load('facebookresearch/dino:main', dino_model_name, pretrained=False) 33 | input_dim = backbone.embed_dim 34 | return (backbone, input_dim) 35 | 36 | def get_resnet_backbone(): 37 | resnet = torchvision.models.resnet18() 38 | backbone = nn.Sequential(*list(resnet.children())[:-1]) 39 | input_dim = 512 40 | return (backbone, input_dim) 41 | 42 | -------------------------------------------------------------------------------- /self-training/testpf.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | kind: experiment 3 | framework: pytorch 4 | tags: 5 | - TEST 6 | build: 7 | image: pytorchlightning/pytorch_lightning:base-xla-py3.7-torch1.12 8 | build_steps: 9 | - pip install -r requirements.txt 10 | - pip install markupsafe==2.0.1 --force 11 | - pip3 install torch_xla[tpuvm] 12 | - pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 13 | 14 | environment: 15 | resources: 16 | cpu: 17 | requests: 12 18 | limits: 48 19 | memory: 20 | requests: 24576 21 | limits: 64000 22 | gpu: 23 | requests: 1 24 | limits: 1 25 | limits: 26 | nvidia.com/gpu: "2" 27 | 28 | params: 29 | batch_size: 48 30 | exp_name: train_dinoLightningModule 31 | backbone: dino_vits16 32 | epochs: 2 33 | loader_num_workers: 8 34 | dataset_input_size: 256 35 | train_lr: 1e-3 36 | pretrained_weights: False 37 | dataset: liver2_mini 38 | loader: patch64 39 | 40 | run: 41 | cmd: python -m tasks.train \ 42 | wandb.setup.project=self-train \ 43 | wandb=server \ 44 | dataset={{ dataset }} \ 45 | experiment.name={{ exp_name }} \ 46 | train.backbone={{ backbone }} \ 47 | train.epochs={{ epochs }} \ 48 | train.pretrained_weights={{ pretrained_weights }} \ 49 | loader.batch_size={{ batch_size }} \ 50 | loader.num_workers={{ loader_num_workers }} \ 51 | dataset.input_size={{ dataset_input_size }} \ 52 | loader={{ loader }} \ 53 | train.lr={{ train_lr }} 54 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/config/base_customdata.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | run: 4 | dir: ./outputs/${name}/${now:%Y-%m-%d--%H-%M-%S} 5 | 6 | name: "debug" 7 | seed: 1 8 | job_type: 'train' 9 | fp16: False 10 | cpu: False 11 | wandb: True 12 | wandb_kwargs: 13 | project: dsp-selftraining 14 | 15 | data: 16 
| num_classes: 6 17 | dataset: liver_mixed_val_mini 18 | train_dataset: 19 | root_dir: /home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/LIVER_MIXED/val_mini 20 | gt_dir: null 21 | pred_dir: 22 | image_dir: /home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/LIVER_MIXED/val_mini/images 23 | val_dataset: 24 | root_dir: /home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/LIVER_MIXED/val_mini 25 | gt_dir: /home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/LIVER_MIXED/val_mini/labels 26 | pred_dir: 27 | image_dir: /home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/LIVER_MIXED/val_mini/images 28 | loader: 29 | batch_size: 144 30 | num_workers: 8 31 | pin_memory: False 32 | transform: 33 | resize_size: 256 34 | crop_size: 224 35 | img_mean: [0.485, 0.456, 0.406] 36 | img_std: [0.229, 0.224, 0.225] 37 | 38 | segments_dir: /home/guests/oleksandra_tmenova/test/project/thesis-codebase/deep-spectral-segmentation/outputs/liver_mixed_val_mini/exp_baseline/2023-11-27/15-26-14/baseline/seg15_clust6_norm-imagenet_prepr-None_dino1_clusterkmeans_time2023-11-27_15-26-22/semantic_segmentations/laplacian/crf_segmaps 39 | 40 | logging: 41 | print_freq: 50 -------------------------------------------------------------------------------- /data_preprocessing/normalize.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import argparse 5 | 6 | def normalize_images_in_folder(input_folder, output_folder): 7 | # Create output folder if it doesn't exist 8 | if not os.path.exists(output_folder): 9 | os.makedirs(output_folder) 10 | 11 | # Loop through images in input folder 12 | for filename in os.listdir(input_folder): 13 | if filename.endswith(".png"): 14 | # Load PNG image 15 | image_path = os.path.join(input_folder, filename) 16 | image = Image.open(image_path) 17 | 18 | # Convert image to numerical array 19 | image_array = np.array(image) 20 | 21 | # Find max value 22 | max_value = np.max(image_array) 23 | 24 | # Normalize values to range 0-255 25 | normalized_array = (image_array / max_value) * 255 26 | 27 | # Round normalized values and convert to uint8 28 | normalized_array = np.round(normalized_array).astype(np.uint8) 29 | 30 | # Convert array back to image 31 | normalized_image = Image.fromarray(normalized_array) 32 | 33 | # Save normalized image to output folder 34 | output_path = os.path.join(output_folder, filename) 35 | normalized_image.save(output_path) 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser(description="Normalize images in a folder.") 39 | parser.add_argument("input_folder", help="Path to the input folder containing PNG images.") 40 | parser.add_argument("output_folder", help="Path to the output folder for saving normalized images.") 41 | args = parser.parse_args() 42 | 43 | normalize_images_in_folder(args.input_folder, args.output_folder) 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /self-training/pfile_group_triplet.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | kind: group 3 | framework: pytorch 4 | tags: 5 | - triplet 6 | - group_exp 7 | build: 8 | image: pytorchlightning/pytorch_lightning:base-xla-py3.7-torch1.12 9 | build_steps: 10 | - pip install -r requirements.txt 11 | - pip install markupsafe==2.0.1 --force 12 | - pip3 install torch_xla[tpuvm] 13 | - pip3 
install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 14 | 15 | environment: 16 | resources: 17 | cpu: 18 | requests: 12 19 | limits: 48 20 | memory: 21 | requests: 24576 22 | limits: 64000 23 | gpu: 24 | requests: 1 25 | limits: 1 26 | 27 | params: 28 | batch_size: 48 29 | exp_name: train_triplet 30 | epochs: 300 31 | loader_num_workers: 8 32 | dataset_input_size: 256 33 | train_lr: 1e-3 34 | loader: triplet_patch 35 | train: triplet 36 | train_triplet_loss_margin: 10 37 | 38 | hptuning: 39 | concurrency: 1 40 | matrix: 41 | pretrained_weights: 42 | values: [True] 43 | dataset: 44 | values: [liver_reduced,carotid_mutinfo] 45 | backbone: 46 | values: [resnet,dino_vits16] 47 | 48 | run: 49 | cmd: python -m tasks.train \ 50 | wandb.setup.project=self-train \ 51 | wandb=server \ 52 | dataset={{ dataset }} \ 53 | experiment.name={{ exp_name }} \ 54 | train={{ train }} 55 | train.backbone={{ backbone }} \ 56 | train.epochs={{ epochs }} \ 57 | train.pretrained_weights={{ pretrained_weights }} \ 58 | loader.batch_size={{ batch_size }} \ 59 | loader.num_workers={{ loader_num_workers }} \ 60 | dataset.input_size={{ dataset_input_size }} \ 61 | loader={{ loader }} \ 62 | train.lr={{ train_lr }} \ 63 | train.triplet_loss_margin={{ train_triplet_loss_margin }} 64 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from joblib import Parallel 3 | from joblib.parallel import delayed 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | 7 | def hungarian_match(flat_preds, flat_targets, preds_k, targets_k, metric='acc', n_jobs=16): 8 | assert (preds_k == targets_k) # one to one 9 | num_k = preds_k 10 | 11 | # perform hungarian matching 12 | print('Using iou as metric') 13 | results = Parallel(n_jobs=n_jobs, backend='multiprocessing')(delayed(get_iou)( 14 | flat_preds, flat_targets, c1, c2) for c2 in range(num_k) for c1 in range(num_k)) 15 | results = np.array(results) 16 | results = results.reshape((num_k, num_k)).T 17 | match = linear_sum_assignment(flat_targets.shape[0] - results) 18 | match = np.array(list(zip(*match))) 19 | res = [] 20 | for out_c, gt_c in match: 21 | res.append((out_c, gt_c)) 22 | 23 | return res 24 | 25 | 26 | def majority_vote(flat_preds, flat_targets, preds_k, targets_k, n_jobs=16): 27 | iou_mat = Parallel(n_jobs=n_jobs, backend='multiprocessing')(delayed(get_iou)( 28 | flat_preds, flat_targets, c1, c2) for c2 in range(targets_k) for c1 in range(preds_k)) 29 | iou_mat = np.array(iou_mat) 30 | results = iou_mat.reshape((targets_k, preds_k)).T 31 | results = np.argmax(results, axis=1) 32 | match = np.array(list(zip(range(preds_k), results))) 33 | return match 34 | 35 | 36 | def get_iou(flat_preds, flat_targets, c1, c2): 37 | tp = 0 38 | fn = 0 39 | fp = 0 40 | tmp_all_gt = (flat_preds == c1) 41 | tmp_pred = (flat_targets == c2) 42 | tp += np.sum(tmp_all_gt & tmp_pred) 43 | fp += np.sum(~tmp_all_gt & tmp_pred) 44 | fn += np.sum(tmp_all_gt & ~tmp_pred) 45 | jac = float(tp) / max(float(tp + fp + fn), 1e-8) 46 | return jac 47 | -------------------------------------------------------------------------------- /sota/SLIC/generate_slic.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import argparse 5 | from skimage import io 6 | import 
matplotlib.pyplot as plt 7 | 8 | # from skimage.data import lena 9 | from skimage.segmentation import felzenszwalb, slic, quickshift 10 | from skimage.segmentation import mark_boundaries 11 | from skimage.util import img_as_float 12 | 13 | def generate_slic(input_folder, output_folder): 14 | # Create output folder if it doesn't exist 15 | if not os.path.exists(output_folder): 16 | os.makedirs(output_folder) 17 | 18 | # Loop through images in input folder 19 | for filename in os.listdir(input_folder): 20 | if filename.endswith(".png"): 21 | # Load PNG image 22 | image_path = os.path.join(input_folder, filename) 23 | img = io.imread(image_path, as_gray=False, plugin='matplotlib') 24 | img = img_as_float(img) 25 | segmap = slic(img, n_segments=15, compactness=0.1, min_size_factor=0.01, sigma=0.1, channel_axis=None) 26 | 27 | # Round normalized values and convert to uint8 28 | segmap = np.round(segmap).astype(np.uint8) 29 | 30 | # Convert array back to image 31 | segmap = Image.fromarray(segmap) 32 | 33 | # Save normalized image to output folder 34 | output_path = os.path.join(output_folder, filename) 35 | segmap.save(output_path) 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser(description="Generate slic segmentation in a folder.") 39 | parser.add_argument("input_folder", help="Path to the input folder containing PNG images.") 40 | parser.add_argument("output_folder", help="Path to the output folder for saving normalized images.") 41 | args = parser.parse_args() 42 | 43 | generate_slic(args.input_folder, args.output_folder) 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /sota/SLIC/generate_fz.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import argparse 5 | from skimage import io 6 | import matplotlib.pyplot as plt 7 | 8 | # from skimage.data import lena 9 | from skimage.segmentation import felzenszwalb 10 | from skimage.util import img_as_float 11 | 12 | def generate_fz(input_folder, output_folder): 13 | # Create output folder if it doesn't exist 14 | if not os.path.exists(output_folder): 15 | os.makedirs(output_folder) 16 | 17 | # Loop through images in input folder 18 | for filename in os.listdir(input_folder): 19 | if filename.endswith(".png"): 20 | # Load PNG image 21 | image_path = os.path.join(input_folder, filename) 22 | img = io.imread(image_path, as_gray=False, plugin='matplotlib') 23 | img = img_as_float(img) 24 | # segmap = slic(img, n_segments=15, compactness=0.1, min_size_factor=0.01, sigma=0.1, channel_axis=None) 25 | segmap = felzenszwalb(img, sigma=0.1, min_size=1000) 26 | 27 | # Round normalized values and convert to uint8 28 | segmap = np.round(segmap).astype(np.uint8) 29 | 30 | # Convert array back to image 31 | segmap = Image.fromarray(segmap) 32 | 33 | # Save normalized image to output folder 34 | output_path = os.path.join(output_folder, filename) 35 | segmap.save(output_path) 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser(description="Generate felzenszwalb segmentation in a folder.") 39 | parser.add_argument("input_folder", help="Path to the input folder containing PNG images.") 40 | parser.add_argument("output_folder", help="Path to the output folder for saving normalized images.") 41 | args = parser.parse_args() 42 | 43 | generate_fz(args.input_folder, args.output_folder) 44 | 45 | if __name__ == "__main__": 46 | main() 47 | 
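To sanity-check these baseline segment maps, the boundaries can be overlaid on the source image with `mark_boundaries` (imported in `generate_slic.py` but not used there). A minimal sketch, with placeholder file names:

```python
# Hedged sketch: overlay SLIC / Felzenszwalb segment boundaries on the input image.
import numpy as np
from PIL import Image
from skimage import io
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float

img = img_as_float(io.imread("images/example.png"))      # placeholder input image
segmap = np.array(Image.open("segmaps/example.png"))     # placeholder segment map
overlay = mark_boundaries(img, segmap.astype(int))       # float RGB in [0, 1]
Image.fromarray((overlay * 255).astype(np.uint8)).save("example_boundaries.png")
```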
-------------------------------------------------------------------------------- /self-training/pfile_group.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | kind: group 3 | framework: pytorch 4 | tags: 5 | - dinov1_vits8 6 | - LIVER_REDUCED 7 | - 10 epochs 8 | - group_exp 9 | build: 10 | image: pytorchlightning/pytorch_lightning:base-xla-py3.7-torch1.12 11 | build_steps: 12 | - pip install -r requirements.txt 13 | - pip install markupsafe==2.0.1 --force 14 | - pip3 install torch_xla[tpuvm] 15 | - pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 16 | 17 | environment: 18 | resources: 19 | cpu: 20 | requests: 12 21 | limits: 48 22 | memory: 23 | requests: 32768 24 | limits: 64000 25 | gpu: 26 | requests: 1 27 | limits: 1 28 | # node_selector: 29 | # gpuMemory: "24" 30 | 31 | 32 | params: 33 | batch_size: 8 34 | exp_name: train_dinoLightningModule 35 | backbone: dino_vits8 36 | epochs: 10 37 | loader_num_workers: 8 38 | dataset_input_size: 256 39 | train_lr: 1e-5 40 | train_weight_decay: 1e-5 41 | 42 | hptuning: 43 | concurrency: 1 44 | matrix: 45 | pretrained_weights: 46 | values: [False, True] 47 | dataset: 48 | values: [liver_reduced] 49 | fraction_layers_to_freeze: 50 | values: [0.0] 51 | loader: 52 | values: [defaults,patch64] 53 | 54 | 55 | run: 56 | cmd: python -m tasks.train \ 57 | wandb.setup.project=SSL-fine-tuning \ 58 | wandb=server \ 59 | dataset={{ dataset }} \ 60 | experiment.name={{ exp_name }} \ 61 | train.backbone={{ backbone }} \ 62 | train.epochs={{ epochs }} \ 63 | train.pretrained_weights={{ pretrained_weights }} \ 64 | loader.batch_size={{ batch_size }} \ 65 | loader.num_workers={{ loader_num_workers }} \ 66 | dataset.input_size={{ dataset_input_size }} \ 67 | loader={{ loader }} \ 68 | train.lr={{ train_lr }} \ 69 | train.weight_decay={{ train_weight_decay }} \ 70 | train.fraction_layers_to_freeze={{ fraction_layers_to_freeze }} 71 | 72 | -------------------------------------------------------------------------------- /self-training/datasets/augmentations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | 5 | 6 | class HistogramNormalize: 7 | """Performs histogram normalization on numpy array and returns 8-bit image. 8 | 9 | Code was taken from lightly, but adpated to work with PIL image as input: 10 | https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_custom_augmentations.html 11 | who adapted it from Facebook: 12 | https://github.com/facebookresearch/CovidPrognosis 13 | 14 | """ 15 | 16 | def __init__(self, number_bins: int = 256): 17 | self.number_bins = number_bins 18 | 19 | def __call__(self, image: np.array) -> Image: 20 | if not isinstance(image, np.ndarray): 21 | image = np.array(image) 22 | # Get the image histogram. 23 | image_histogram, bins = np.histogram( 24 | image.flatten(), self.number_bins, density=True 25 | ) 26 | cdf = image_histogram.cumsum() # cumulative distribution function 27 | cdf = 255 * cdf / cdf[-1] # normalize 28 | 29 | # Use linear interpolation of cdf to find new pixel values. 30 | image_equalized = np.interp(image.flatten(), bins[:-1], cdf) 31 | pil_image = Image.fromarray(np.uint8(image_equalized.reshape(image.shape))) 32 | # return Image.fromarray(image_equalized.reshape(image.shape)) 33 | return pil_image 34 | 35 | class GaussianNoise: 36 | """Applies random Gaussian noise to a tensor. 
37 | 38 | Code was taken from lightly tutirials: 39 | https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_custom_augmentations.html 40 | 41 | The intensity of the noise is dependent on the mean of the pixel values. 42 | See https://arxiv.org/pdf/2101.04909.pdf for more information. 43 | 44 | """ 45 | 46 | def __call__(self, sample: torch.Tensor) -> torch.Tensor: 47 | mu = sample.mean() 48 | snr = np.random.randint(low=4, high=8) 49 | sigma = mu / snr 50 | noise = torch.normal(torch.zeros(sample.shape), sigma) 51 | return sample + noise -------------------------------------------------------------------------------- /data_preprocessing/remap_labels.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import argparse 5 | 6 | def remap_lables(old_label_image, old_labels, new_labels): 7 | old_labels_unique = np.unique(old_label_image) 8 | print('old_current', old_labels_unique) 9 | print('old_assumed', old_labels) 10 | # for old_current, old_assumed in zip(old_labels_unique,np.array(old_labels) ): 11 | # assert(old_current==old_assumed) 12 | new_label_image = np.zeros_like(old_label_image) 13 | for old_label_i, target_label_i in zip(old_labels, new_labels): 14 | new_label_image[old_label_image == int(old_label_i)] = int(target_label_i) 15 | return new_label_image 16 | 17 | def remap_labels_in_folder(input_folder, output_folder): 18 | # Create output folder if it doesn't exist 19 | if not os.path.exists(output_folder): 20 | os.makedirs(output_folder) 21 | 22 | # Loop through images in input folder 23 | for filename in os.listdir(input_folder): 24 | if filename.endswith(".png"): 25 | # Load PNG image 26 | image_path = os.path.join(input_folder, filename) 27 | im = np.array(Image.open(image_path)) 28 | 29 | # get current and new labels 30 | current_labels = np.unique(im).tolist() 31 | new_labels = [i for i in range(len(current_labels))] 32 | 33 | # remap 34 | new_im = remap_lables(im, current_labels, new_labels) 35 | 36 | # save 37 | new_segm_image = Image.fromarray(new_im) 38 | output_path = os.path.join(output_folder, filename) 39 | new_segm_image.save(str(output_path)) 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser(description="Remap labels in a folder.") 43 | parser.add_argument("input_folder", help="Path to the input folder containing PNG images.") 44 | parser.add_argument("output_folder", help="Path to the output folder for saving normalized images.") 45 | args = parser.parse_args() 46 | 47 | remap_labels_in_folder(args.input_folder, args.output_folder) 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /self-training/custom_utils/check_dimensions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from PIL import Image 3 | import os 4 | 5 | def check_image_dimensions(path, expected_size): 6 | # get the list of files in the directory 7 | files = os.listdir(path) 8 | 9 | # initialize a list to store filenames that do not have the expected dimensions 10 | non_matching_files = [] 11 | 12 | # loop through each file in the directory 13 | for file in files: 14 | # check if the file is an image file (i.e., has a file extension of .jpg, .png, etc.) 
15 | if file.lower().endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp')): 16 | # open the image file using the Pillow library 17 | with Image.open(os.path.join(path, file)) as img: 18 | # get the width and height of the image 19 | width, height = img.size 20 | # check if the dimensions of the image match the expected dimensions 21 | if (width, height) != expected_size: 22 | non_matching_files.append(file) 23 | print(f'File {file} has size: ({width},{height})') 24 | 25 | # return the list of filenames that do not have the expected dimensions 26 | return non_matching_files 27 | 28 | def main(): 29 | # check if the correct number of arguments were provided 30 | if len(sys.argv) != 4: 31 | print("Usage: python script.py path height width") 32 | return 33 | 34 | # get the path, height, and width from the command line arguments 35 | path = sys.argv[1] 36 | height = int(sys.argv[2]) 37 | width = int(sys.argv[3]) 38 | 39 | # check if the specified directory exists 40 | if not os.path.exists(path): 41 | print(f"The directory {path} does not exist") 42 | return 43 | 44 | # check the dimensions of the images in the directory 45 | non_matching_files = check_image_dimensions(path, (width, height)) 46 | 47 | # print the filenames of the images that do not have the expected dimensions 48 | if non_matching_files: 49 | print("The following files do not have the expected dimensions:") 50 | for file in non_matching_files: 51 | print(file) 52 | else: 53 | print("All image files in the directory have the expected dimensions") 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /self-training/models/simclrTripletLightningModule.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torchvision 4 | from torch import nn 5 | from lightly.models.modules import SimCLRProjectionHead 6 | 7 | class SimCLRTriplet(pl.LightningModule): 8 | def __init__(self, backbone, hidden_dim, max_epochs=1, lr = 0.001, optimizer="Adam", triplet_loss_margin=1): 9 | super().__init__() 10 | 11 | self.optimizer_choice=optimizer 12 | self.max_epochs = max_epochs 13 | self.lr=lr 14 | self.backbone=backbone 15 | self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, 512) 16 | # https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html 17 | self.criterion = torch.nn.TripletMarginLoss(margin=triplet_loss_margin) 18 | 19 | def forward(self, anchor, pos, neg): 20 | x_anchor = self.backbone(anchor).flatten(start_dim=1) 21 | z_anchor = self.projection_head(x_anchor) 22 | 23 | x_pos = self.backbone(pos).flatten(start_dim=1) 24 | z_pos = self.projection_head(x_pos) 25 | 26 | x_neg = self.backbone(neg).flatten(start_dim=1) 27 | z_neg = self.projection_head(x_neg) 28 | 29 | return z_anchor, z_pos, z_neg 30 | 31 | def _common_step(self, batch, mode='train'): 32 | (anchor, _, _), (pos, _, _,), (neg, _, _) = batch #unpack the batch, lower dash_ stands for targets and fnames 33 | z_anchor, z_pos, z_neg = self.forward(anchor, pos, neg) 34 | loss = self.criterion(z_anchor, z_pos, z_neg) 35 | self.log(f'{mode}_loss', loss) 36 | return loss 37 | 38 | def training_step(self, batch, batch_idx): 39 | return self._common_step(batch, mode='train') 40 | 41 | def validation_step(self, batch, batch_idx): 42 | return self._common_step(batch, mode='val') 43 | 44 | def configure_optimizers(self): 45 | if self.optimizer_choice == "sdg": 46 | optim = torch.optim.SGD( 47 | 
self.parameters(), lr=self.lr, momentum=0.9, weight_decay=5e-4 48 | ) 49 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 50 | optim, self.max_epochs 51 | ) 52 | return [optim], [scheduler] 53 | elif self.optimizer_choice == "Adam": 54 | optim = torch.optim.Adam(self.parameters(), lr=self.lr) 55 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 56 | optim, self.max_epochs 57 | ) 58 | return [optim], [scheduler] 59 | else: 60 | raise NotImplementedError() -------------------------------------------------------------------------------- /self-training/models/simclrLightningModule.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torchvision 4 | from torch import nn 5 | 6 | from lightly.loss import NTXentLoss 7 | from lightly.models.modules import SimCLRProjectionHead 8 | 9 | # Note: The model and training settings do not follow the reference settings 10 | # from the paper. The settings are chosen such that the example can easily be 11 | # run on a small dataset with a single GPU. 12 | 13 | class SimCLR(pl.LightningModule): 14 | def __init__(self, backbone, hidden_dim, max_epochs=1, lr = 0.001, optimizer="Adam"): 15 | super().__init__() 16 | 17 | self.optimizer_choice=optimizer 18 | self.max_epochs = max_epochs 19 | self.lr=lr 20 | self.backbone=backbone 21 | self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, 512) 22 | self.criterion = NTXentLoss() 23 | 24 | def forward(self, x): 25 | x = self.backbone(x).flatten(start_dim=1) 26 | z = self.projection_head(x) 27 | return z 28 | 29 | def _common_step(self, batch, mode='train'): 30 | (x0, x1), _, _ = batch 31 | z0 = self.forward(x0) 32 | z1 = self.forward(x1) 33 | loss = self.criterion(z0, z1) 34 | self.log(f'{mode}_loss', loss) 35 | return loss 36 | 37 | def training_step(self, batch, batch_idx): 38 | return self._common_step(batch, mode='train') 39 | 40 | def validation_step(self, batch, batch_idx): 41 | return self._common_step(batch, mode='val') 42 | 43 | def configure_optimizers(self): 44 | if self.optimizer_choice == "sdg": 45 | optim = torch.optim.SGD( 46 | self.parameters(), lr=self.lr, momentum=0.9, weight_decay=5e-4 47 | ) 48 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 49 | optim, self.max_epochs 50 | ) 51 | return [optim], [scheduler] 52 | elif self.optimizer_choice == "Adam": 53 | optim = torch.optim.Adam(self.parameters(), lr=self.lr) 54 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 55 | optim, self.max_epochs 56 | ) 57 | return [optim], [scheduler] 58 | else: 59 | raise NotImplementedError() 60 | 61 | def get_resnet_backbone(pretrained_weights = False): 62 | if pretrained_weights: 63 | resnet = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1) 64 | else: 65 | resnet = torchvision.models.resnet18() 66 | backbone = nn.Sequential(*list(resnet.children())[:-1]) 67 | hidden_dim = resnet.fc.in_features 68 | return (backbone, hidden_dim) -------------------------------------------------------------------------------- /self-training/train-mini-example.py: -------------------------------------------------------------------------------- 1 | from models import dino 2 | import torch 3 | from lightly.data import DINOCollateFunction, LightlyDataset 4 | from lightly.loss import DINOLoss 5 | from lightly.models.utils import update_momentum 6 | from lightly.utils.scheduler import cosine_schedule 7 | 8 | import hydra 9 | from omegaconf import DictConfig, 
OmegaConf 10 | import os 11 | import logging 12 | 13 | # A logger for this file 14 | log = logging.getLogger(__name__) 15 | 16 | 17 | @hydra.main(version_base=None, config_path="../configs", config_name="config") 18 | def train(cfg: DictConfig) -> None: 19 | log.info(OmegaConf.to_yaml(cfg)) 20 | log.info("Current working directory : {}".format(os.getcwd())) 21 | 22 | # model 23 | backbone, input_dim = dino.get_dino_backbone("dino_vits16") 24 | model = dino.DINO(backbone, input_dim) 25 | device = "cuda" if torch.cuda.is_available() else "cpu" 26 | model.to(device) 27 | 28 | 29 | # data 30 | # dataset = LightlyDataset("../data/liver2_mini/train") 31 | dataset = LightlyDataset(cfg.dataset.path) 32 | collate_fn = DINOCollateFunction() 33 | 34 | dataloader = torch.utils.data.DataLoader( 35 | dataset, 36 | batch_size=1, #make smaller for dino backbone, was 64 for resnet 37 | collate_fn=collate_fn, 38 | shuffle=True, 39 | drop_last=True, 40 | num_workers=0, 41 | ) 42 | 43 | print(len(dataloader)) 44 | 45 | criterion = DINOLoss( 46 | output_dim=2048, 47 | warmup_teacher_temp_epochs=5, 48 | ) 49 | # move loss to correct device because it also contains parameters 50 | criterion = criterion.to(device) 51 | 52 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 53 | 54 | epochs = 10 55 | 56 | log.info("Starting Training") 57 | for epoch in range(epochs): 58 | total_loss = 0 59 | momentum_val = cosine_schedule(epoch, epochs, 0.996, 1) 60 | for views, _, _ in dataloader: 61 | update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val) 62 | update_momentum(model.student_head, model.teacher_head, m=momentum_val) 63 | views = [view.to(device) for view in views] 64 | global_views = views[:2] 65 | teacher_out = [model.forward_teacher(view) for view in global_views] 66 | student_out = [model.forward(view) for view in views] 67 | loss = criterion(teacher_out, student_out, epoch=epoch) 68 | total_loss += loss.detach() 69 | loss.backward() 70 | # We only cancel gradients of student head. 
71 | model.student_head.cancel_last_layer_gradients(current_epoch=epoch) 72 | optimizer.step() 73 | optimizer.zero_grad() 74 | 75 | avg_loss = total_loss / len(dataloader) 76 | log.info(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") 77 | 78 | 79 | if __name__ == "__main__": 80 | train() -------------------------------------------------------------------------------- /data_preprocessing/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | from PIL import Image 5 | import kornia as K 6 | import torch 7 | 8 | from torch.nn import Module 9 | from torch import Tensor 10 | from typing import Any, Callable, Iterable, Optional, Tuple, Union 11 | from kornia.enhance import equalize_clahe 12 | 13 | import argparse 14 | 15 | 16 | class EqualizeClahe(Module): 17 | def __init__(self, 18 | clip_limit: float = 40.0, 19 | grid_size: Tuple[int, int] = (8, 8), 20 | slow_and_differentiable: bool = False 21 | ) -> None: 22 | super().__init__() 23 | self.clip_limit = clip_limit 24 | self.grid_size = grid_size 25 | self.slow_and_differentiable = slow_and_differentiable 26 | 27 | def __repr__(self) -> str: 28 | return ( 29 | f"{self.__class__.__name__}" 30 | f"(clip_limit={self.clip_limit}, " 31 | f"grid_size={self.grid_size}, " 32 | f"slow_and_differentiable={self.slow_and_differentiable})" 33 | ) 34 | 35 | def forward(self, input: Tensor) -> Tensor: 36 | # ref: https://kornia.readthedocs.io/en/latest/_modules/kornia/enhance/equalization.html#equalize_clahe 37 | return equalize_clahe(input, self.clip_limit, self.grid_size, self.slow_and_differentiable) 38 | 39 | def preprocess_dataset(image_folder, processed_image_folder, preprocessing_pipeline): 40 | # preprocessing_pipeline: of type torch.transforms or K.augmentation.container.ImageSequential 41 | 42 | if not os.path.exists(processed_image_folder): 43 | os.makedirs(processed_image_folder) 44 | 45 | images = sorted(os.listdir(image_folder)) 46 | 47 | for im_name in tqdm(images): 48 | # read file 49 | im_file = os.path.join(image_folder, im_name) 50 | x_rgb: torch.Tensor = K.io.load_image(im_file, K.io.ImageLoadType.RGB32)[None, ...] 
# BxCxHxW 51 | 52 | # process 53 | x = preprocessing_pipeline(x_rgb) 54 | x_numpy = K.utils.image.tensor_to_image(x) 55 | processed_image = Image.fromarray(np.uint8(x_numpy* 255)) 56 | 57 | # save 58 | new_im_file = os.path.join(processed_image_folder, im_name) 59 | 60 | processed_image.save(str(new_im_file)) 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser(description='Preprocessing of ultrasound data (Offline)') 64 | parser.add_argument('--dataset_folder', type=str, 65 | default='/home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/CAROTID_MIXED/val_mini', 66 | help='Path to the root dataset folder containing folder "images"') 67 | args = parser.parse_args() 68 | 69 | preprocessing_pipeline = K.augmentation.container.ImageSequential( 70 | K.filters.GaussianBlur2d((3,3), (5.,5.)), 71 | EqualizeClahe(grid_size = (2,2)), 72 | # K.filters.MedianBlur((5,5)), 73 | ) 74 | 75 | folders_to_process = [ 76 | # TODO: account for passing a folder with folders, or a list of folder paths 77 | args.dataset_folder, 78 | ] 79 | 80 | for root_folder in folders_to_process: 81 | images_folder = os.path.join(root_folder, 'images') 82 | assert(os.path.exists(images_folder)) 83 | preprocessed_folder = os.path.join(root_folder, 'preprocessed') 84 | preprocess_dataset(images_folder,preprocessed_folder,preprocessing_pipeline) 85 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/README.md: -------------------------------------------------------------------------------- 1 | ## Semantic Segmentation 2 | 3 | We begin by extracting features and eigenvectors from our images. For instructions on this process, follow the steps in "Extraction" in the main `README`. 4 | 5 | Next, we obtain coarse (i.e. patch-level) semantic segmentations. This process involves (1) extracting segments from the eigenvectors, (2) taking a bounding box around them, (3) extracting features for these boxes, (4) clustering these features, (5) obtaining coarse semantic segmentations. 6 | 7 | For example, you can run the following in the `extract` directory. 
8 | 9 | ```bash 10 | # Example parameters for the semantic segmentation experiments 11 | DATASET="VOC2012" 12 | MODEL="dino_vits16" 13 | MATRIX="laplacian" 14 | DOWNSAMPLE=16 15 | N_SEG=15 16 | N_ERODE=2 17 | N_DILATE=5 18 | 19 | # Extract segments 20 | python extract.py extract_multi_region_segmentations \ 21 | --non_adaptive_num_segments ${N_SEG} \ 22 | --features_dir "./data/${DATASET}/features/${MODEL}" \ 23 | --eigs_dir "./data/${DATASET}/eigs/${MATRIX}" \ 24 | --output_dir "./data/${DATASET}/multi_region_segmentation/${MATRIX}" 25 | 26 | # Extract bounding boxes 27 | python extract.py extract_bboxes \ 28 | --features_dir "./data/${DATASET}/features/${MODEL}" \ 29 | --segmentations_dir "./data/${DATASET}/multi_region_segmentation/${MATRIX}" \ 30 | --num_erode ${N_ERODE} \ 31 | --num_dilate ${N_DILATE} \ 32 | --downsample_factor ${DOWNSAMPLE} \ 33 | --output_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bboxes.pth" 34 | 35 | # Extract bounding box features 36 | python extract.py extract_bbox_features \ 37 | --model_name ${MODEL} \ 38 | --images_root "./data/${DATASET}/images" \ 39 | --bbox_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bboxes.pth" \ 40 | --output_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_features.pth" 41 | 42 | # Extract clusters 43 | python extract.py extract_bbox_clusters \ 44 | --bbox_features_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_features.pth" \ 45 | --output_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_clusters.pth" 46 | 47 | # Create semantic segmentations 48 | python extract.py extract_semantic_segmentations \ 49 | --segmentations_dir "./data/${DATASET}/multi_region_segmentation/${MATRIX}" \ 50 | --bbox_clusters_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_clusters.pth" \ 51 | --output_dir "./data/${DATASET}/semantic_segmentations/patches/${MATRIX}/segmaps" 52 | ``` 53 | 54 | At this point, you can evaluate the segmentations using `eval.py` in this directory. For example: 55 | ```bash 56 | python eval.py segments_dir="/output_dir/from/above" 57 | ``` 58 | 59 | Optionally, you can also perform self-training using `train.py`. You can specify the correct matching using `matching="\"[(0, 0), ... (19, 6), (20, 7)]\""`. This matching may be obtained by first evaluating using `python eval.py`. For example: 60 | ```bash 61 | python train.py lr=2e-4 data.loader.batch_size=96 segments_dir="/path/to/segmaps" matching="\"[(0, 0), ... (19, 6), (20, 7)]\"" 62 | ``` 63 | 64 | Please note that the unsupervised semantic segmentation results have very high variance; some runs are much better than others. This variance is primarily due to the random seeds of the K-means clustering steps above, and it is secondarily due to randomness in the self-training stage. Also please note that this code has been heavily re-factored for its public release. Although we try to ensure that there are no bugs, it is nevertheless possible that there is a bug we have overlooked. 
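For reference, the `matching` string is parsed downstream in `dataset/__init__.py` with `dict(eval(...))` and applied to the predicted masks via `np.vectorize` in `CustomDataset`. A minimal sketch of that remapping, with an illustrative matching (not taken from a real run):

```python
# Hedged sketch: how a `matching` string becomes a label map applied to predictions.
import numpy as np

matching = "[(0, 0), (1, 2), (2, 1)]"        # illustrative pred -> gt matching
label_map = dict(eval(str(matching)))        # {0: 0, 1: 2, 2: 1}, as in get_datasets()
label_map_fn = np.vectorize(label_map.__getitem__)

prediction = np.array([[0, 1], [2, 2]])      # toy predicted mask
print(label_map_fn(prediction))              # [[0 2] [1 1]]
```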
65 | -------------------------------------------------------------------------------- /evaluation/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | from pathlib import Path 7 | import cv2 8 | from tqdm import tqdm, trange 9 | from torchvision import transforms 10 | 11 | 12 | class EvalDataset(Dataset): 13 | def __init__(self, root_dir, gt_dir = "", pred_dir = "", transform: object = None, image_dir = ""): 14 | self.root_dir = root_dir 15 | self.image_dir = image_dir if image_dir!="" else os.path.join(root_dir, 'images') 16 | self.gt_dir = gt_dir if gt_dir!="" else os.path.join(root_dir, 'ground_truth') 17 | self.pred_dir = pred_dir if pred_dir!="" else os.path.join(root_dir, 'predictions') 18 | self.image_list = os.listdir(self.image_dir) 19 | self.transform = transform 20 | 21 | print("root:", self.root_dir) 22 | print("image_dir:", self.image_dir) 23 | print("gt_dir:", self.gt_dir) 24 | print("pred_dir:", self.pred_dir) 25 | 26 | def __len__(self): 27 | return len(self.image_list) 28 | 29 | def __getitem__(self, idx): 30 | img_name_full = self.image_list[idx] 31 | img_name = img_name_full[:-4] 32 | img_path = os.path.join(self.image_dir, img_name_full) 33 | gt_path = os.path.join(self.gt_dir, img_name + ".png") 34 | pred_path = os.path.join(self.pred_dir, img_name + ".png") 35 | 36 | # assert(os.path.isfile(gt_path)) 37 | 38 | if self.transform is not None: 39 | image = self.transform(Image.open(img_path).convert("RGB")) 40 | # print(f"EvalDataset: Image after transforms: {image.size()}") 41 | image = image.detach() 42 | if image.is_cuda: 43 | image = image.cpu().numpy() 44 | else: 45 | image = image.numpy() 46 | image = image.transpose(1, 2, 0) 47 | # TODO: consider denormalizing image to ensure correct plotting (see extract_crf step of pipeline) 48 | # print(f"EvalDataset: Image after numpy: {image.shape}") 49 | ground_truth = np.array(Image.open(gt_path).convert('L')) 50 | prediction = np.array(Image.open(pred_path).convert('L')) 51 | else: 52 | image = np.array(Image.open(img_path).convert("RGB")) 53 | ground_truth = np.array(Image.open(gt_path).convert('L')) 54 | prediction = np.array(Image.open(pred_path).convert('L')) 55 | 56 | # Resize masks of image size not matching 57 | prediction = self._resize_mask(prediction, image) 58 | ground_truth = self._resize_mask(ground_truth, image) 59 | 60 | metadata = {'id': Path(img_path).stem, 'path': img_path, 'shape': tuple(image.shape[:2])} 61 | 62 | return (image, ground_truth, prediction, metadata) 63 | 64 | def _resize_mask(self, mask, image): 65 | # Check if sizes correspond 66 | H_im, W_im = image.shape[:2] 67 | H_mask, W_mask = mask.shape 68 | 69 | if (H_mask!= H_im or W_mask!=W_im): 70 | mask = cv2.resize(mask, dsize=(W_im, H_im), interpolation=cv2.INTER_NEAREST) # (W, H) for cv2 71 | return mask 72 | 73 | # if __name__ == '__main__': 74 | # root_dir = "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/US_MIXED/val" 75 | # gt_dir = "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/US_MIXED/val/lables" 76 | # pred_dir = "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/US_MIXED/val/predictions/maskcut_init_lr0.001_us_mixed_val_thresh0.0" 77 | 78 | # d = EvalDataset(root_dir,gt_dir,pred_dir) 79 | 80 | -------------------------------------------------------------------------------- 
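A minimal usage sketch for `EvalDataset`, assuming the repository root is on `PYTHONPATH` and reusing the transform values from the eval config above (resize 256, center-crop 224, ImageNet statistics); the paths are placeholders:

```python
# Hedged sketch: build an EvalDataset with an eval-time transform (placeholder paths).
from torchvision import transforms
from evaluation.dataset import EvalDataset

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = EvalDataset(
    root_dir="path/to/dataset",            # expects images/, ground_truth/, predictions/ by default
    gt_dir="path/to/dataset/labels",       # or "" to fall back to root_dir/ground_truth
    pred_dir="path/to/predicted/segmaps",  # e.g. a crf_segmaps output folder
    transform=transform,
)
image, ground_truth, prediction, metadata = dataset[0]
```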
/deep-spectral-segmentation/semantic-segmentation/dataset/custom_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | from pathlib import Path 7 | import cv2 8 | from tqdm import tqdm, trange 9 | from torchvision import transforms 10 | 11 | class CustomDataset(Dataset): 12 | def __init__(self, root_dir, gt_dir = "", pred_dir = "", transform: object = None, image_dir = "", label_map = None): 13 | self.root_dir = root_dir 14 | self.image_dir = image_dir if image_dir!="" else os.path.join(root_dir, 'images') 15 | self.pred_dir = pred_dir if pred_dir!="" else os.path.join(root_dir, 'predictions') 16 | self.image_list = os.listdir(self.image_dir) 17 | self.transform = transform 18 | self._prepare_label_map(label_map) 19 | 20 | # Use pseudolabels as ground truth if no GT directory is given (e.g. for self-training) 21 | if gt_dir is not None: 22 | self.gt_dir = gt_dir if gt_dir!="" else os.path.join(root_dir, 'ground_truth') 23 | else: 24 | self.gt_dir = self.pred_dir # HACK: GT is ignored during training 25 | 26 | print("root:", self.root_dir) 27 | print("image_dir:", self.image_dir) 28 | print("gt_dir:", self.gt_dir) 29 | print("pred_dir:", self.pred_dir) 30 | print("label_map: ", label_map) 31 | 32 | def __len__(self): 33 | return len(self.image_list) 34 | 35 | def __getitem__(self, idx): 36 | img_name_full = self.image_list[idx] 37 | img_name = img_name_full[:-4] 38 | img_path = os.path.join(self.image_dir, img_name_full) 39 | gt_path = os.path.join(self.gt_dir, img_name + ".png") 40 | pred_path = os.path.join(self.pred_dir, img_name + ".png") 41 | 42 | # assert(os.path.isfile(gt_path)) 43 | 44 | # Load data 45 | image = np.array(Image.open(img_path).convert("RGB")) 46 | ground_truth = np.array(Image.open(gt_path).convert('L')) 47 | prediction = np.array(Image.open(pred_path).convert('L')) 48 | metadata = {'id': Path(img_path).stem, 'path': img_path, 'shape': tuple(image.shape[:2])} 49 | 50 | # Resize masks if the image size does not match 51 | prediction = self._resize_mask(prediction, image) 52 | ground_truth = self._resize_mask(ground_truth, image) 53 | 54 | # Remap labelmap if matching provided 55 | if self.label_map_fn is not None: 56 | prediction = self.label_map_fn(prediction) 57 | 58 | # Transform and unpack 59 | if self.transform is not None: 60 | if type(self.transform) is tuple: 61 | for t in self.transform: 62 | data = t(image=image, mask1=ground_truth, mask2=prediction) 63 | image, ground_truth, prediction = data['image'], data['mask1'], data['mask2'] 64 | 65 | if torch.is_tensor(ground_truth): 66 | ground_truth = ground_truth.long() 67 | if torch.is_tensor(prediction): 68 | prediction = prediction.long() 69 | 70 | return image, ground_truth, prediction, metadata 71 | 72 | def _resize_mask(self, mask, image): 73 | # Check if sizes correspond 74 | H_im, W_im = image.shape[:2] 75 | H_mask, W_mask = mask.shape 76 | 77 | if (H_mask!= H_im or W_mask!=W_im): 78 | mask = cv2.resize(mask, dsize=(W_im, H_im), interpolation=cv2.INTER_NEAREST) # cv2 expects (W, H) 79 | return mask 80 | 81 | def _prepare_label_map(self, label_map): 82 | if label_map is not None: 83 | self.label_map_fn = np.vectorize(label_map.__getitem__) 84 | else: 85 | self.label_map_fn = None 86 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/dataset/__init__.py: 
-------------------------------------------------------------------------------- 1 | import albumentations as A 2 | import albumentations.pytorch as AP 3 | import cv2 4 | from torch.utils.data._utils.collate import default_collate 5 | 6 | from .voc import VOCSegmentationWithPseudolabels 7 | 8 | from .custom_dataset import CustomDataset 9 | 10 | 11 | def get_transforms(resize_size, crop_size, img_mean, img_std): 12 | 13 | # Multiple training transforms for contrastive learning 14 | train_joint_transform = A.Compose([ 15 | A.SmallestMaxSize(resize_size, interpolation=cv2.INTER_CUBIC), 16 | A.RandomCrop(crop_size, crop_size), 17 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'}) 18 | train_geometric_transform = A.ReplayCompose([ 19 | A.RandomResizedCrop(crop_size, crop_size, interpolation=cv2.INTER_CUBIC), 20 | A.HorizontalFlip(), 21 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'}) 22 | train_separate_transform = A.Compose([ 23 | A.ColorJitter(0.4, 0.4, 0.2, 0.1, p=0.8), 24 | A.ToGray(p=0.2), A.GaussianBlur(p=0.1), # A.Solarize(p=0.1) 25 | A.Normalize(mean=img_mean, std=img_std), AP.ToTensorV2(), 26 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'}) 27 | 28 | # Validation transform -- no resizing! 29 | val_transform = A.Compose([ 30 | # A.Resize(resize_size, resize_size, interpolation=cv2.INTER_CUBIC), A.CenterCrop(crop_size, crop_size), 31 | A.Normalize(mean=img_mean, std=img_std), AP.ToTensorV2() 32 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'}) 33 | 34 | train_transforms_tuple = (train_joint_transform, train_geometric_transform, train_separate_transform) 35 | return train_transforms_tuple, val_transform 36 | 37 | 38 | def collate_fn(batch): 39 | everything_but_metadata = [t[:-1] for t in batch] 40 | metadata = [t[-1] for t in batch] 41 | return (*default_collate(everything_but_metadata), metadata) 42 | 43 | 44 | def get_datasets_voc(cfg): 45 | 46 | # Get transforms 47 | train_transforms_tuple, val_transform = get_transforms(**cfg.data.transform) 48 | 49 | # Get the label map 50 | if cfg.matching: 51 | matching = dict(eval(str(cfg.matching))) 52 | print(f'Using matching: {matching}') 53 | else: 54 | matching = None 55 | 56 | # Training dataset 57 | dataset_train = VOCSegmentationWithPseudolabels( 58 | **cfg.data.train_kwargs, 59 | segments_dir=cfg.segments_dir, 60 | transforms_tuple=train_transforms_tuple, 61 | label_map=matching 62 | ) 63 | 64 | # Validation dataset 65 | dataset_val = VOCSegmentationWithPseudolabels( 66 | **cfg.data.val_kwargs, 67 | segments_dir=cfg.segments_dir, 68 | transform=val_transform, 69 | label_map=matching 70 | ) 71 | 72 | return dataset_train, dataset_val, collate_fn 73 | 74 | def get_datasets(cfg): 75 | 76 | # Get transforms 77 | train_transforms_tuple, val_transform = get_transforms(**cfg.data.transform) 78 | 79 | # Get the label map 80 | if cfg.matching: 81 | matching = dict(eval(str(cfg.matching))) 82 | print(f'Using matching: {matching}') 83 | else: 84 | matching = None 85 | 86 | # Training dataset 87 | dataset_train = CustomDataset( 88 | root_dir = cfg.data.train_dataset.root_dir, 89 | gt_dir = cfg.data.train_dataset.gt_dir, 90 | pred_dir = cfg.segments_dir, 91 | image_dir = cfg.data.train_dataset.image_dir, 92 | transform=train_transforms_tuple, 93 | label_map=matching 94 | ) 95 | 96 | # Validation dataset 97 | dataset_val = CustomDataset( 98 | root_dir = cfg.data.val_dataset.root_dir, 99 | gt_dir = cfg.data.val_dataset.gt_dir, 100 | pred_dir = cfg.segments_dir, 101 | image_dir = cfg.data.val_dataset.image_dir, 
102 | transform=val_transform, 103 | label_map=matching 104 | ) 105 | 106 | return dataset_train, dataset_val, collate_fn 107 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/extract/MutualInformation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mutual Information calculation between 2 images. 3 | Pytorch implementation, credits go to the respective authors: 4 | https://github.com/connorlee77/pytorch-mutual-information/blob/master/MutualInformation.py 5 | """ 6 | 7 | import os 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | import skimage.io 14 | import matplotlib.pyplot as plt 15 | 16 | from PIL import Image 17 | from torchvision import transforms 18 | 19 | from sklearn.metrics import normalized_mutual_info_score 20 | 21 | 22 | class MutualInformation(nn.Module): 23 | 24 | def __init__(self, sigma=0.1, num_bins=256, normalize=True): 25 | super(MutualInformation, self).__init__() 26 | 27 | self.sigma = sigma 28 | self.num_bins = num_bins 29 | self.normalize = normalize 30 | self.epsilon = 1e-10 31 | 32 | self.bins = nn.Parameter(torch.linspace(0, 255, num_bins).float(), requires_grad=False) 33 | 34 | 35 | def marginalPdf(self, values): 36 | 37 | residuals = values - self.bins.unsqueeze(0).unsqueeze(0) 38 | kernel_values = torch.exp(-0.5*(residuals / self.sigma).pow(2)) 39 | 40 | pdf = torch.mean(kernel_values, dim=1) 41 | normalization = torch.sum(pdf, dim=1).unsqueeze(1) + self.epsilon 42 | pdf = pdf / normalization 43 | 44 | return pdf, kernel_values 45 | 46 | 47 | def jointPdf(self, kernel_values1, kernel_values2): 48 | 49 | joint_kernel_values = torch.matmul(kernel_values1.transpose(1, 2), kernel_values2) 50 | normalization = torch.sum(joint_kernel_values, dim=(1,2)).view(-1, 1, 1) + self.epsilon 51 | pdf = joint_kernel_values / normalization 52 | 53 | return pdf 54 | 55 | 56 | def getMutualInformation(self, input1, input2): 57 | ''' 58 | input1: B, C, H, W 59 | input2: B, C, H, W 60 | 61 | return: scalar 62 | ''' 63 | 64 | # Torch tensors for images between (0, 1) 65 | input1 = input1*255 66 | input2 = input2*255 67 | 68 | B, C, H, W = input1.shape 69 | assert((input1.shape == input2.shape)) 70 | 71 | x1 = input1.view(B, H*W, C) 72 | x2 = input2.view(B, H*W, C) 73 | 74 | pdf_x1, kernel_values1 = self.marginalPdf(x1) 75 | pdf_x2, kernel_values2 = self.marginalPdf(x2) 76 | pdf_x1x2 = self.jointPdf(kernel_values1, kernel_values2) 77 | 78 | H_x1 = -torch.sum(pdf_x1*torch.log2(pdf_x1 + self.epsilon), dim=1) 79 | H_x2 = -torch.sum(pdf_x2*torch.log2(pdf_x2 + self.epsilon), dim=1) 80 | H_x1x2 = -torch.sum(pdf_x1x2*torch.log2(pdf_x1x2 + self.epsilon), dim=(1,2)) 81 | 82 | mutual_information = H_x1 + H_x2 - H_x1x2 83 | 84 | if self.normalize: 85 | mutual_information = 2*mutual_information/(H_x1+H_x2) 86 | 87 | return mutual_information 88 | 89 | 90 | def forward(self, input1, input2): 91 | ''' 92 | input1: B, C, H, W 93 | input2: B, C, H, W 94 | 95 | return: scalar 96 | ''' 97 | return self.getMutualInformation(input1, input2) 98 | 99 | 100 | 101 | if __name__ == '__main__': 102 | 103 | device = 'cuda:0' 104 | 105 | ### Create test cases ### 106 | img1 = Image.open('grad.jpg').convert('L') 107 | img2 = img1.rotate(10) 108 | 109 | arr1 = np.array(img1) 110 | arr2 = np.array(img2) 111 | 112 | mi_true_1 = normalized_mutual_info_score(arr1.ravel(), arr2.ravel()) 113 | mi_true_2 = normalized_mutual_info_score(arr2.ravel(), arr2.ravel()) 114 | 115 | img1 
= transforms.ToTensor() (img1).unsqueeze(dim=0).to(device) 116 | img2 = transforms.ToTensor() (img2).unsqueeze(dim=0).to(device) 117 | 118 | # Pair of different images, pair of same images 119 | input1 = torch.cat([img1, img2]) 120 | input2 = torch.cat([img2, img2]) 121 | 122 | MI = MutualInformation(num_bins=256, sigma=0.1, normalize=True).to(device) 123 | mi_test = MI(input1, input2) 124 | 125 | mi_test_1 = mi_test[0].cpu().numpy() 126 | mi_test_2 = mi_test[1].cpu().numpy() 127 | 128 | print('Image Pair 1 | sklearn MI: {}, this MI: {}'.format(mi_true_1, mi_test_1)) 129 | print('Image Pair 2 | sklearn MI: {}, this MI: {}'.format(mi_true_2, mi_test_2)) 130 | 131 | assert(np.abs(mi_test_1 - mi_true_1) < 0.05) 132 | assert(np.abs(mi_test_2 - mi_true_2) < 0.05) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # sashena 132 | thesis_env/ 133 | data1/ 134 | data_old/ 135 | self-training/outputs/ 136 | deep-spectral-segmentation/outputs/ 137 | deep-spectral-segmentation/results/ 138 | .vscode/ 139 | dino/ 140 | notebooks/ 141 | self-training/wandb/ 142 | **/__pycache__/ 143 | dino/__pycache__/ 144 | self-training/.polyaxon/ 145 | self-training/tasks/configs/wandb/server.yaml 146 | self-training/multirun/ 147 | ImageNet-Datasets-Downloader/ 148 | UltrasondConfienceMap/ 149 | myenv/ 150 | .pai_simulator/ 151 | self-training/Dockerfile 152 | .gitignore 153 | test.yaml 154 | settings.yaml 155 | package-lock.json 156 | GEANT-OV-RSA-CA-4.crt 157 | USERTrust-RSA-Certification-Authority.crt 158 | Ubuntu.appx 159 | self-training/tasks/configs/wandb/defaults.yaml 160 | slurm_scripts/*.out 161 | slurm_scripts/*.err 162 | slurm_scripts/logs 163 | evaluation/__pycache__/ 164 | evaluation/outputs/ 165 | evaluation/demo_* 166 | cutler_env/ 167 | evaluation/eval_analysis/artifacts 168 | slurm_scripts/single_eval.sh 169 | slurm_scripts/pipelinewitheval.sh 170 | slurm_scripts/deep_spectral_pipeline.sh 171 | slurm_scripts/cutler_inference.sh 172 | slurm_scripts/*.out 173 | slurm_scripts/*.err 174 | slurm_scripts/*/*.out 175 | slurm_scripts/*/*.err 176 | slurm_scripts/ 177 | final_runs 178 | data_preprocessing/denoise/MPRNet/*.err 179 | data_preprocessing/denoise/MPRNet/*.out 180 | data_preprocessing/denoise/*.err 181 | data_preprocessing/denoise/*.out 182 | data_preprocessing/denoise/MPRNet 183 | data_preprocessing/denoise/MPRNet/Denoising/pretrained_models/model_denoising.pth 184 | data_preprocessing/denoise/MPRNet/Deraining/pretrained_models/model_deraining.pth 185 | spectral-clustering/*.png 186 | spectral-clustering/ann_index.ann 187 | self-training/tasks/configs/wandb/defaults.yaml 188 | slurm/ 189 | sota/*.err 190 | sota/*.out 191 | deep-spectral-segmentation/dino2_models/dinov2_with_attention_extraction/ 192 | data/THYROID 193 | configs/wandb/server.yaml 194 | configs/wandb/defaults.yaml 195 | configs/dataset/thyroid.yaml -------------------------------------------------------------------------------- /self-training/models/dinoLightningModule.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import torchvision 6 | from torch import nn 7 | 8 | from lightly.loss import DINOLoss 9 | from lightly.models.modules import DINOProjectionHead 10 | from lightly.models.utils import deactivate_requires_grad, update_momentum 11 | from lightly.utils.scheduler import cosine_schedule 12 | 13 | 14 | # Note: The model and training settings do not follow the reference settings 15 | # from the paper. The settings are chosen such that the example can easily be 16 | # run on a small dataset with a single GPU. 
17 | 18 | 19 | class DINO(pl.LightningModule): 20 | def __init__(self, backbone, input_dim, max_epochs=1, optimizer="Adam", lr = 0.001, weight_decay=0): 21 | super().__init__() 22 | 23 | self.max_epochs=max_epochs 24 | self.optimizer_choice=optimizer 25 | self.lr=lr 26 | self.student_backbone = backbone 27 | self.student_head = DINOProjectionHead( 28 | input_dim, 512, 64, 2048, freeze_last_layer=1 29 | ) 30 | self.teacher_backbone = copy.deepcopy(backbone) 31 | self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048) 32 | deactivate_requires_grad(self.teacher_backbone) 33 | deactivate_requires_grad(self.teacher_head) 34 | 35 | self.criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5) 36 | self.weight_decay=weight_decay 37 | 38 | def forward(self, x): 39 | y = self.student_backbone(x).flatten(start_dim=1) 40 | z = self.student_head(y) 41 | return z 42 | 43 | def forward_teacher(self, x): 44 | y = self.teacher_backbone(x).flatten(start_dim=1) 45 | z = self.teacher_head(y) 46 | return z 47 | 48 | def training_step(self, batch, batch_idx): 49 | return self._common_step(batch, mode='train') 50 | 51 | def validation_step(self, batch, batch_idx): 52 | self._common_step(batch, mode='val') 53 | 54 | def _common_step(self, batch, mode='train'): 55 | momentum = cosine_schedule(self.current_epoch, self.max_epochs, 0.996, 1) 56 | update_momentum(self.student_backbone, self.teacher_backbone, m=momentum) 57 | update_momentum(self.student_head, self.teacher_head, m=momentum) 58 | views, a, b = batch 59 | views = [view.to(self.device) for view in views] 60 | global_views = views[:2] 61 | teacher_out = [self.forward_teacher(view) for view in global_views] 62 | student_out = [self.forward(view) for view in views] 63 | loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch) 64 | 65 | self.log(f'{mode}_loss', loss) 66 | return loss 67 | 68 | 69 | def on_after_backward(self): 70 | self.student_head.cancel_last_layer_gradients(current_epoch=self.current_epoch) 71 | 72 | def configure_optimizers(self): 73 | if self.optimizer_choice=="Adam": 74 | optim = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 75 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 76 | optim, self.max_epochs 77 | ) 78 | return [optim], [scheduler] 79 | else: 80 | raise NotImplementedError() 81 | 82 | def get_dino_backbone(dino_model_name: str, pretrained_weights = False): 83 | if "dinov2" in dino_model_name: 84 | # eg for dinov2 models like dinov2_vits14 85 | if pretrained_weights: 86 | backbone = torch.hub.load('facebookresearch/dinov2:main', dino_model_name, pretrained=True) 87 | else: 88 | backbone = torch.hub.load('facebookresearch/dinov2:main', dino_model_name, pretrained=False) 89 | else: 90 | if pretrained_weights: 91 | backbone = torch.hub.load('facebookresearch/dino:main', dino_model_name, pretrained=True) 92 | else: 93 | backbone = torch.hub.load('facebookresearch/dino:main', dino_model_name, pretrained=False) 94 | input_dim = backbone.embed_dim 95 | return (backbone, input_dim) 96 | 97 | def get_resnet_backbone(pretrained_weights = False): 98 | # TODO: change to resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50') 99 | # unify with the function above 100 | if pretrained_weights: 101 | resnet = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1) 102 | else: 103 | resnet = torchvision.models.resnet18() 104 | backbone = nn.Sequential(*list(resnet.children())[:-1]) 105 | input_dim = 512 106 | return 
(backbone, input_dim) -------------------------------------------------------------------------------- /data_preprocessing/crop_black_borders.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | import argparse 7 | 8 | def crop_image(image, top, left, bottom, right): 9 | width, height = image.size 10 | left = max(0, left) 11 | top = max(0, top) 12 | right = min(width, width - right) 13 | bottom = min(height, height - bottom) 14 | 15 | cropped_image = image.crop((left, top, right, bottom)) 16 | return cropped_image 17 | 18 | def find_distance_to_non_zero(image_array): 19 | top, left, bottom, right = 0, 0, 0, 0 20 | 21 | # Find distance from top 22 | for i in range(image_array.shape[0]): 23 | if np.any(image_array[i]): 24 | top = i 25 | break 26 | 27 | # Find distance from bottom 28 | for i in range(image_array.shape[0] - 1, -1, -1): 29 | if np.any(image_array[i]): 30 | bottom = image_array.shape[0] - i - 1 31 | break 32 | 33 | # Find distance from left 34 | for i in range(image_array.shape[1]): 35 | if np.any(image_array[:, i]): 36 | left = i 37 | break 38 | 39 | # Find distance from right 40 | for i in range(image_array.shape[1] - 1, -1, -1): 41 | if np.any(image_array[:, i]): 42 | right = image_array.shape[1] - i - 1 43 | break 44 | 45 | return top, left, bottom, right 46 | 47 | def preprocess_and_find_distances(image_folder, images): 48 | distances = {'top': [], 'left': [], 'bottom': [], 'right': []} 49 | 50 | for im_name in tqdm(images): 51 | # read file 52 | im_file = os.path.join(image_folder, im_name) 53 | image = Image.open(im_file) 54 | image_array = np.array(image) 55 | 56 | # find distances 57 | top, left, bottom, right = find_distance_to_non_zero(image_array) 58 | 59 | distances['top'].append(top) 60 | distances['left'].append(left) 61 | distances['bottom'].append(bottom) 62 | distances['right'].append(right) 63 | 64 | min_top = min(distances['top']) 65 | min_left = min(distances['left']) 66 | min_bottom = min(distances['bottom']) 67 | min_right = min(distances['right']) 68 | 69 | return min_top, min_left, min_bottom, min_right 70 | 71 | def crop_and_save_images(image_folder, processed_image_folder, images, distances): 72 | if not os.path.exists(processed_image_folder): 73 | os.makedirs(processed_image_folder) 74 | 75 | min_top, min_left, min_bottom, min_right = distances 76 | 77 | for im_name in tqdm(images): 78 | # read file 79 | im_file = os.path.join(image_folder, im_name) 80 | image = Image.open(im_file) 81 | 82 | # crop image 83 | cropped_image = crop_image(image, min_top, min_left, min_bottom, min_right) 84 | 85 | # save cropped image 86 | new_im_file = os.path.join(processed_image_folder, im_name) 87 | cropped_image.save(new_im_file) 88 | 89 | # print(f"Processed {im_name}, Cropped and Saved.") 90 | 91 | def split_images_into_sequences(images): 92 | sequences = {} 93 | for im_name in images: 94 | prefix = im_name.split('_')[0] + '_' 95 | if prefix not in sequences: 96 | sequences[prefix] = [] 97 | sequences[prefix].append(im_name) 98 | return sequences 99 | 100 | if __name__ == "__main__": 101 | parser = argparse.ArgumentParser(description='Cutting black borders of ultrasound data (Offline)') 102 | parser.add_argument('--image_folder', type=str, 103 | help='Path to the root dataset folder containing folder "images"') 104 | parser.add_argument('--additional_folders', nargs='+', default=[], 105 | help='List of additional folders to be cropped the 
same way as the main folder') 106 | 107 | args = parser.parse_args() 108 | 109 | processed_image_folder = args.image_folder + "_cropped" 110 | 111 | images = sorted(os.listdir(args.image_folder)) 112 | 113 | # Split all images into sequences based on the prefix 114 | image_sequences = split_images_into_sequences(images) 115 | 116 | # Process each sequence separately 117 | for sequence_prefix, sequence_images in image_sequences.items(): 118 | print(f'Processing image folder for sequence {sequence_prefix}') 119 | distances = preprocess_and_find_distances(args.image_folder, sequence_images) 120 | crop_and_save_images(args.image_folder, processed_image_folder, sequence_images, distances) 121 | 122 | # Crop images in additional folders 123 | if args.additional_folders: 124 | print(f'Processing additional folders for sequence {sequence_prefix}') 125 | 126 | for folder in args.additional_folders: 127 | processed_image_folder = folder + "_cropped" 128 | crop_and_save_images(folder, processed_image_folder, sequence_images, distances) 129 | 130 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torchvision.models._utils import IntermediateLayerGetter 5 | from torchvision.models.segmentation.deeplabv3 import (ASPP, DeepLabHead, DeepLabV3) 6 | 7 | 8 | def get_deeplab_resnet(num_classes: int, name: str = 'deeplabv3plus', output_stride: int = 8): 9 | 10 | if output_stride == 8: 11 | replace_stride_with_dilation = [False, True, True] 12 | aspp_dilate = [12, 24, 36] 13 | elif output_stride == 16: 14 | replace_stride_with_dilation = [False, False, True] 15 | aspp_dilate = [6, 12, 18] 16 | else: 17 | raise NotImplementedError() 18 | 19 | backbone = torch.hub.load( 20 | 'facebookresearch/dino:main', 21 | 'dino_resnet50', 22 | replace_stride_with_dilation=replace_stride_with_dilation 23 | ) 24 | 25 | inplanes = 2048 26 | low_level_planes = 256 27 | 28 | if name == 'deeplabv3plus': 29 | return_layers = {'layer4': 'out', 'layer1': 'low_level'} 30 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) 31 | DeepLab = DeepLabV3Plus 32 | elif name == 'deeplabv3': 33 | return_layers = {'layer4': 'out'} 34 | DeepLab = DeepLabV3 35 | classifier = DeepLabHead(inplanes, num_classes, aspp_dilate) 36 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 37 | 38 | model = DeepLab(backbone, classifier) 39 | return model 40 | 41 | 42 | def get_deeplab_vit(num_classes: int, backbone_name: str = 'vits16', name: str = 'deeplabv3plus'): 43 | 44 | # Backbone 45 | backbone = torch.hub.load('facebookresearch/dino:main', f'dino_{backbone_name}') 46 | 47 | # Classifier 48 | aspp_dilate = [12, 24, 36] 49 | inplanes = low_level_planes = backbone.embed_dim 50 | if name == 'deeplabv3plus': 51 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) 52 | DeepLab = DeepLabV3Plus 53 | elif name == 'deeplabv3': 54 | DeepLab = DeepLabV3 55 | classifier = DeepLabHead(inplanes, num_classes, aspp_dilate) 56 | 57 | # Wrap 58 | backbone = VisionTransformerWrapper(backbone) 59 | model = DeepLab(backbone, classifier) 60 | return model 61 | 62 | 63 | class VisionTransformerWrapper(nn.Module): 64 | def __init__(self, backbone): 65 | super().__init__() 66 | self.backbone = backbone 67 | 68 | def 
forward(self, x): 69 | # Forward 70 | output = self.backbone.get_intermediate_layers(x, n=5) 71 | # Reshaping 72 | assert (len(output) == 5), f'{output.shape=}' 73 | H_patch = x.shape[-2] // self.backbone.patch_embed.patch_size 74 | W_patch = x.shape[-1] // self.backbone.patch_embed.patch_size 75 | out_ll = output[0][:, 1:, :].transpose(-2, -1).unflatten(-1, (H_patch, W_patch)) 76 | out = output[-1][:, 1:, :].transpose(-2, -1).unflatten(-1, (H_patch, W_patch)) 77 | return {'low_level': out_ll, 'out': out} 78 | 79 | 80 | class DeepLabHeadV3Plus(nn.Module): 81 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]): 82 | super(DeepLabHeadV3Plus, self).__init__() 83 | self.project = nn.Sequential( 84 | nn.Conv2d(low_level_channels, 48, 1, bias=False), 85 | nn.BatchNorm2d(48), 86 | nn.ReLU(inplace=True), 87 | ) 88 | 89 | self.aspp = ASPP(in_channels, aspp_dilate) 90 | 91 | self.classifier = nn.Sequential( 92 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 93 | nn.BatchNorm2d(256), 94 | nn.ReLU(inplace=True), 95 | nn.Conv2d(256, num_classes, 1) 96 | ) 97 | self._init_weight() 98 | 99 | def forward(self, feature): 100 | low_level_feature = self.project(feature['low_level']) 101 | output_feature = self.aspp(feature['out']) 102 | output_feature = F.interpolate( 103 | output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False) 104 | return self.classifier(torch.cat([low_level_feature, output_feature], dim=1)) 105 | 106 | def _init_weight(self): 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.kaiming_normal_(m.weight) 110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 111 | nn.init.constant_(m.weight, 1) 112 | nn.init.constant_(m.bias, 0) 113 | 114 | 115 | class DeepLabV3Plus(nn.Module): 116 | def __init__(self, backbone, classifier): 117 | super().__init__() 118 | self.backbone = backbone 119 | self.classifier = classifier 120 | 121 | def forward(self, x): 122 | input_shape = x.shape[-2:] 123 | features = self.backbone(x) 124 | x = self.classifier(features) 125 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 126 | return x 127 | -------------------------------------------------------------------------------- /self-training/tasks/features.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import hydra 3 | from omegaconf import DictConfig, OmegaConf 4 | import logging 5 | from pathlib import Path 6 | from custom_utils import utils 7 | import torch 8 | from lightly.data import LightlyDataset 9 | import os 10 | from accelerate import Accelerator 11 | from tqdm import tqdm 12 | from torchvision import transforms 13 | 14 | # A logger for this file 15 | log = logging.getLogger(__name__) 16 | 17 | 18 | def extract_features(cfg: DictConfig) -> None: 19 | # adapted from https://github.com/lukemelas/deep-spectral-segmentation/tree/main/semantic-segmentation 20 | 21 | 22 | # Output 23 | utils.make_output_dir(cfg.output_dir) 24 | 25 | # Models 26 | model_name = cfg.model_name.lower() 27 | model_path = os.path.join(hydra.utils.get_original_cwd(), cfg.model_checkpoint) 28 | print("model path: ", model_path) 29 | model, params = utils.get_model_from_path(model_name, model_path) 30 | val_transform = utils.get_transform(model_name) 31 | model = model.eval() 32 | 33 | # Add resize to the transforms 34 | if 'carotid' in cfg.dataset.name: 35 | # resize to acquare images (val set has varied sizes...) 
36 | resize = transforms.Resize((cfg.dataset.input_size,cfg.dataset.input_size)) 37 | else: 38 | resize = transforms.Resize(cfg.dataset.input_size) 39 | transform = transforms.Compose([resize, val_transform]) 40 | 41 | # Add dino_spceific params - hook and numheads 42 | if 'dino' in model_name: 43 | # hook 44 | which_block = -1 45 | feat_out = {} 46 | def hook_fn_forward_qkv(module, input, output): 47 | feat_out["qkv"] = output 48 | model._modules["blocks"][which_block]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) 49 | num_heads = params[0] 50 | patch_size = params[1] 51 | 52 | # Dataset 53 | dataset = LightlyDataset( 54 | input_dir = os.path.join(hydra.utils.get_original_cwd(),cfg.dataset.path), 55 | transform=transform) 56 | dataloader = torch.utils.data.DataLoader( 57 | dataset, 58 | batch_size=cfg.loader.batch_size, 59 | num_workers=cfg.loader.num_workers) 60 | log.info(f'Dataset size: {len(dataset)}') 61 | log.info(f'Dataloader size: {len(dataloader)}') 62 | 63 | # Prepare accelerator 64 | cpu = True 65 | if torch.cuda.is_available(): 66 | cpu = False 67 | accelerator = Accelerator(cpu) 68 | model = model.to(accelerator.device) 69 | 70 | # Process 71 | pbar = tqdm(dataloader, desc='Processing') 72 | for i, (samples, targets, fnames) in enumerate(pbar): 73 | output_dict = {} 74 | 75 | # Check if file already exists 76 | id = Path(fnames[0]).stem 77 | output_file = Path(cfg.output_dir) / f'{id}.pth' 78 | if output_file.is_file(): 79 | pbar.write(f'Skipping existing file {str(output_file)}') 80 | continue 81 | 82 | B, C, H, W = samples.shape 83 | # print(f'samples shape: {samples.shape}') 84 | 85 | # Forward and collect features into output dict 86 | if 'dino' in model_name: 87 | # reshape image 88 | P = patch_size 89 | H_patch, W_patch = H // P, W // P 90 | H_pad, W_pad = H_patch * P, W_patch * P 91 | T = H_patch * W_patch + 1 # number of tokens, add 1 for [CLS] 92 | samples = samples[:, :, :H_pad, :W_pad] 93 | samples = samples.to(accelerator.device) 94 | 95 | # extarct features 96 | model.get_intermediate_layers(samples)[0].squeeze(0) 97 | output_qkv = feat_out["qkv"].reshape(B, T, 3, num_heads, -1 // num_heads).permute(2, 0, 3, 1, 4) 98 | output_dict['k'] = output_qkv[1].transpose(1, 2).reshape(B, T, -1)[:, 1:, :] 99 | output_dict['patch_size'] = patch_size 100 | output_dict['shape'] = (B, C, H, W) 101 | 102 | elif 'simclr' in model_name: 103 | samples = samples.to(accelerator.device) 104 | output_dict['simclr'] = model(samples).flatten(start_dim=1) 105 | output_dict['shape'] = (output_dict['simclr'].shape) 106 | else: 107 | raise ValueError(model_name) 108 | 109 | # Metadata 110 | output_dict['file'] = fnames[0] 111 | output_dict['id'] = id 112 | output_dict['model_name'] = model_name 113 | 114 | output_dict = {k: (v.detach().cpu() if torch.is_tensor(v) else v) for k, v in output_dict.items()} 115 | 116 | # Save 117 | accelerator.save(output_dict, str(output_file)) 118 | accelerator.wait_for_everyone() 119 | 120 | log.info(f'Saved features to {cfg.output_dir}') 121 | 122 | @hydra.main(version_base=None, config_path="./configs", config_name="features") 123 | def run_experiment(cfg: DictConfig) -> None: 124 | log.info(OmegaConf.to_yaml(cfg)) 125 | log.info("Current working directory : {}".format(os.getcwd())) 126 | 127 | if cfg.experiment.name == "extract_features": 128 | log.info(f"Experiment chosen: {cfg.experiment.name}") 129 | extract_features(cfg) 130 | else: 131 | raise ValueError(f'No experiment called: {cfg.experiment.name}') 132 | 133 | 134 | 
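# Example invocation (an added sketch, not part of the original script; the exact entry point and
# values are hypothetical and depend on the configs available under ./configs). The override keys
# below (experiment.name, model_name, model_checkpoint, dataset, loader.batch_size, output_dir) are
# the fields read by extract_features() above. Since the loop names the output file after fnames[0],
# batch_size=1 is the safe choice.
#
#   python features.py \
#       experiment.name=extract_features \
#       model_name=dino_vits8 \
#       model_checkpoint=path/to/checkpoint.ckpt \
#       dataset=carotid_mutinfo \
#       loader.batch_size=1 \
#       output_dir=outputs/features/dino_vits8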
135 | if __name__ == "__main__": 136 | run_experiment() -------------------------------------------------------------------------------- /data_preprocessing/custom_normalization.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, Optional, Tuple, Any 3 | 4 | import cv2 5 | import torch 6 | from skimage.morphology import binary_dilation, binary_erosion 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision import transforms 9 | 10 | import argparse 11 | import os 12 | 13 | class ImagesDataset(Dataset): 14 | """A very simple dataset for loading images.""" 15 | 16 | def __init__(self, filenames: str, images_root: Optional[str] = None, transform: Optional[Callable] = None, 17 | prepare_filenames: bool = True) -> None: 18 | self.root = None if images_root is None else Path(images_root) 19 | self.filenames = sorted(list(set(filenames))) if prepare_filenames else filenames 20 | self.transform = transform 21 | 22 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 23 | path = self.filenames[index] 24 | full_path = Path(path) if self.root is None else self.root / path 25 | assert full_path.is_file(), f'Not a file: {full_path}' 26 | image = cv2.imread(str(full_path)) 27 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 28 | if self.transform is not None: 29 | image = self.transform(image) 30 | return image, path, index 31 | 32 | def __len__(self) -> int: 33 | return len(self.filenames) 34 | 35 | class OnlineMeanStd: 36 | """ A class for calculating mean and std of a given dataset 37 | ref: https://github.com/Nikronic/CoarseNet/blob/master/utils/preprocess.py#L142-L200 38 | """ 39 | def __init__(self): 40 | pass 41 | 42 | def __call__(self, dataset, batch_size, method='strong'): 43 | """ 44 | Calculate mean and std of a dataset in lazy mode (online) 45 | On mode strong, batch size will be discarded because we use batch_size=1 to minimize leaps. 46 | 47 | :param dataset: Dataset object corresponding to your dataset 48 | :param batch_size: higher size, more accurate approximation 49 | :param method: weak: fast but less accurate, strong: slow but very accurate - recommended = strong 50 | :return: A tuple of (mean, std) with size of (3,) 51 | """ 52 | 53 | if method == 'weak': 54 | loader = DataLoader(dataset=dataset, 55 | batch_size=batch_size, 56 | shuffle=False, 57 | num_workers=0, 58 | pin_memory=0) 59 | mean = 0. 60 | std = 0. 61 | nb_samples = 0. 
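            # Added note: this 'weak' estimate averages per-batch means and per-batch stds below;
            # averaging stds only approximates the dataset-level std (variances do not combine this
            # way), which is why the 'strong' branch further down accumulates running first and
            # second moments over all pixels instead and is the recommended method.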
62 | for item in loader: 63 | data, files, indices = item 64 | batch_samples = data.size(0) 65 | data = data.view(batch_samples, data.size(1), -1) 66 | mean += data.mean(2).sum(0) 67 | std += data.std(2).sum(0) 68 | nb_samples += batch_samples 69 | 70 | mean /= nb_samples 71 | std /= nb_samples 72 | 73 | return mean, std 74 | 75 | elif method == 'strong': 76 | loader = DataLoader(dataset=dataset, 77 | batch_size=1, 78 | shuffle=False, 79 | num_workers=0, 80 | pin_memory=0) 81 | cnt = 0 82 | fst_moment = torch.empty(3) 83 | snd_moment = torch.empty(3) 84 | 85 | for item in loader: 86 | data, files, indices = item 87 | b, c, h, w = data.shape 88 | nb_pixels = b * h * w 89 | sum_ = torch.sum(data, dim=[0, 2, 3]) 90 | sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3]) 91 | fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels) 92 | snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels) 93 | 94 | cnt += nb_pixels 95 | 96 | return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2) 97 | 98 | 99 | def find_custom_norm_mean_std(images_list, images_root, batch_size): 100 | filenames = Path(images_list).read_text().splitlines() 101 | dataset_raw = ImagesDataset(filenames=filenames, images_root=images_root, transform=transforms.ToTensor()) 102 | meanStdCalculator = OnlineMeanStd() 103 | mean, std = meanStdCalculator(dataset_raw, batch_size=batch_size, method='strong') 104 | return mean, std 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser(description='Find a custom mean and std of a given dataset') 108 | parser.add_argument('--images_list', type=str, 109 | help='Path to the txt file containing filenames of the dataset (for compatibility with deep spectral dataset format)') 110 | parser.add_argument('--images_root', type=str, 111 | help='Path to the images folder of the dataset') 112 | parser.add_argument('--batch_size', type=str, 113 | default=100, 114 | help='A batchsize for finding mean and std of a datset in batches (the higher the better)') 115 | 116 | args = parser.parse_args() 117 | 118 | # train dataset for US mixed 119 | # images_root="/home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/US_MIXED/train/images" 120 | # images_list="/home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/US_MIXED/train/lists/images.txt" 121 | 122 | mean, std = find_custom_norm_mean_std(args.images_list, args.images_root, args.batch_size) 123 | print(f"Dataset_root = {args.images_root}, batch_size = {args.batch_size}") 124 | print(f"Mean = {mean}, std = {std}") 125 | 126 | 127 | -------------------------------------------------------------------------------- /spectral-clustering/spectralnet_per_image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from spectralnet import SpectralNet 4 | from sklearn.cluster import KMeans, MiniBatchKMeans 5 | import matplotlib.pyplot as plt 6 | from PIL import Image 7 | 8 | 9 | def get_image_sizes(data_dict: dict, downsample_factor = None): 10 | P = data_dict['patch_size'] if downsample_factor is None else downsample_factor 11 | B, C, H, W = data_dict['shape'] 12 | assert B == 1, 'assumption violated :(' 13 | H_patch, W_patch = H // P, W // P 14 | H_pad, W_pad = H_patch * P, W_patch * P 15 | return (B, C, H, W, P, H_patch, W_patch, H_pad, W_pad) 16 | 17 | def get_border_fraction(segmap: np.array): 18 | num_border_pixels = 2 * (segmap.shape[0] + segmap.shape[1]) 19 | counts_map = {idx: 0 for idx in np.unique(segmap)} 20 | 
np.zeros(len(np.unique(segmap))) 21 | for border in [segmap[:, 0], segmap[:, -1], segmap[0, :], segmap[-1, :]]: 22 | unique, counts = np.unique(border, return_counts=True) 23 | for idx, count in zip(unique.tolist(), counts.tolist()): 24 | counts_map[idx] += count 25 | # normlized_counts_map = {idx: count / num_border_pixels for idx, count in counts_map.items()} 26 | indices = np.array(list(counts_map.keys())) 27 | normlized_counts = np.array(list(counts_map.values())) / num_border_pixels 28 | return indices, normlized_counts 29 | 30 | def get_segmap(clusters, data_dict): 31 | B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = get_image_sizes(data_dict) 32 | # Reshape 33 | infer_bg_index = True 34 | if clusters.size == H_patch * W_patch: # TODO: better solution might be to pass in patch index 35 | segmap = clusters.reshape(H_patch, W_patch) 36 | elif clusters.size == H_patch * W_patch * 4: 37 | segmap = clusters.reshape(H_patch * 2, W_patch * 2) 38 | else: 39 | raise ValueError() 40 | 41 | # TODO: Improve this step in the pipeline. 42 | # Background detection: we assume that the segment with the most border pixels is the 43 | # background region. We will always make this region equal 0. 44 | if infer_bg_index: 45 | indices, normlized_counts = get_border_fraction(segmap) 46 | bg_index = indices[np.argmax(normlized_counts)].item() 47 | bg_region = (segmap == bg_index) 48 | zero_region = (segmap == 0) 49 | segmap[bg_region] = 0 50 | segmap[zero_region] = bg_index 51 | 52 | return segmap 53 | 54 | feature_path = "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/deep-spectral-segmentation/outputs/liver_mixed_val_mini/exp_clustering_sweep/2023-11-07/13-13-13/seg8_clust6_norm-imagenet_prepr-False_dino1_clusterkmeans/features/dino_vits8/Patient-12-david-01_7.pth" 55 | eigs_path = "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/deep-spectral-segmentation/outputs/liver_mixed_val_mini/exp_clustering_sweep/2023-11-07/13-13-13/seg8_clust6_norm-imagenet_prepr-False_dino1_clusterkmeans/eig/laplacian/Patient-12-david-01_7.pth" 56 | image_path = "/home/guests/oleksandra_tmenova/test/project/thesis-codebase/data/LIVER_MIXED/val_mini/images/Patient-12-david-01_7.png" 57 | 58 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 59 | device = 'cpu' 60 | data_dict = torch.load(feature_path, map_location=device) # map_location='cpu' 61 | data_dict.update(torch.load(eigs_path), map_location=device) 62 | 63 | feats = data_dict['k'].squeeze() 64 | feats2 = data_dict['k'].squeeze() 65 | feats_numpy = feats.cpu().detach().numpy() 66 | n_clusters = 6 67 | 68 | print(f'feats are on device: {feats.device}') 69 | print(f'feats2 are on device: {feats2.device}') 70 | 71 | 72 | # kmeans baseline 73 | kmeans = KMeans(n_clusters=n_clusters, random_state=1) 74 | clusters1 = kmeans.fit_predict(feats_numpy) 75 | 76 | # eigenvectors (deep spectral approach) 77 | # num_eigenvectors = 1000000 78 | # eigenvectors = data_dict['eigenvectors'][1:1+num_eigenvectors].numpy() # take non-constant eigenvectors 79 | # clusters2 = kmeans.fit_predict(eigenvectors.T) 80 | 81 | # spectral_net 82 | spectralnet = SpectralNet(n_clusters=n_clusters, 83 | should_use_siamese=True, 84 | should_use_ae = True) 85 | spectralnet.fit(feats) 86 | clusters3 = spectralnet.predict(feats) 87 | 88 | print(f'feats are on device: {feats.device}') 89 | print(f'feats2 are on device: {feats2.device}') 90 | 91 | spectralnet2 = SpectralNet(n_clusters=n_clusters, 92 | should_use_siamese=True, 93 | should_use_ae = True, 94 | 
is_sparse_graph=True, 95 | spectral_n_nbg=2) 96 | 97 | feats2 = feats2.to('cuda') 98 | print(f'feats are on device: {feats.device}') 99 | print(f'feats2 are on device: {feats2.device}') 100 | spectralnet2.fit(feats2) 101 | clusters4 = spectralnet2.predict(feats2) 102 | 103 | # get segmentation maps 104 | segmap_kmeans = get_segmap(clusters1, data_dict) 105 | # segmap_eigen = get_segmap(clusters2, data_dict) 106 | segmap_spectralnet = get_segmap(clusters3, data_dict) 107 | segmap_spectralnet2 = get_segmap(clusters4, data_dict) 108 | 109 | 110 | # create a plot 111 | image = np.array(Image.open(image_path)) 112 | 113 | fig, axs = plt.subplots(1, 4, figsize=(8, 8)) 114 | 115 | axs[0].imshow(image) 116 | axs[0].set_title("image") 117 | 118 | axs[1].imshow(segmap_kmeans) 119 | axs[1].set_title("dino_kmeans") 120 | 121 | axs[2].imshow(segmap_spectralnet) 122 | axs[2].set_title("spectralnet_nonsparse_nn30") 123 | 124 | axs[3].imshow(segmap_spectralnet2) 125 | axs[3].set_title("spectralnet_sparse_nn2") 126 | 127 | plt.tight_layout() 128 | fig.savefig("/home/guests/oleksandra_tmenova/test/project/thesis-codebase/spectral-clustering/per_image_clustering_6segments.png") 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Iterable, Optional 4 | 5 | import hydra 6 | import numpy as np 7 | import torch 8 | import wandb 9 | from accelerate import Accelerator 10 | from matplotlib.cm import get_cmap 11 | from omegaconf import DictConfig, OmegaConf 12 | from PIL import Image 13 | from skimage.color import label2rgb 14 | from tqdm import tqdm, trange 15 | 16 | import eval_utils 17 | import util as utils 18 | from dataset.voc import VOCSegmentationWithPseudolabels 19 | 20 | 21 | @hydra.main(config_path='config', config_name='eval') 22 | def main(cfg: DictConfig): 23 | 24 | # Accelerator 25 | accelerator = Accelerator(fp16=cfg.fp16, cpu=cfg.cpu) 26 | 27 | # Logging 28 | utils.setup_distributed_print(accelerator.is_local_main_process) 29 | if cfg.wandb and accelerator.is_local_main_process: 30 | wandb.init(name=cfg.name, job_type=cfg.job_type, config=OmegaConf.to_container(cfg), save_code=True, **cfg.wandb_kwargs) 31 | cfg = DictConfig(wandb.config.as_dict()) # get the config back from wandb for hyperparameter sweeps 32 | 33 | # Configuration 34 | print(OmegaConf.to_yaml(cfg)) 35 | print(f'Current working directory: {os.getcwd()}') 36 | 37 | # Set random seed 38 | utils.set_seed(cfg.seed) 39 | 40 | # Create dataset with segments/pseudolabels 41 | dataset_val = VOCSegmentationWithPseudolabels( 42 | **cfg.data.val_kwargs, 43 | segments_dir=cfg.segments_dir, 44 | transform=None, # no transform to evaluate at original resolution 45 | ) 46 | 47 | # Evaluate 48 | eval_stats, match = evaluate(cfg=cfg, dataset_val=dataset_val, n_clusters=cfg.get('n_clusters', None)) 49 | print(eval_stats) 50 | if cfg.wandb and accelerator.is_local_main_process: 51 | wandb.summary['mIoU'] = eval_stats['mIoU'] 52 | 53 | # Visualize 54 | visualize(cfg=cfg, dataset_val=dataset_val) 55 | 56 | 57 | def visualize( 58 | *, 59 | cfg: DictConfig, 60 | dataset_val: Iterable, 61 | vis_dir: str = './vis'): 62 | 63 | # Visualize 64 | num_vis = 40 65 | vis_dir = Path(vis_dir) 66 | colors = get_cmap('tab20', cfg.data.num_classes + 1).colors[:,:3] 67 | pbar = tqdm(dataset_val, total=num_vis, 
desc='Saving visualizations: ') 68 | for i, (image, target, mask, metadata) in enumerate(pbar): 69 | if i >= num_vis: break 70 | image = np.array(image) 71 | target = np.array(target) 72 | target[target == 255] = 0 # set the "unknown" regions to background for visualization 73 | # Overlay mask on image 74 | image_pred_overlay = label2rgb(label=mask, image=image, colors=colors[np.unique(target)[1:]], bg_label=0, alpha=0.45) 75 | image_target_overlay = label2rgb(label=target, image=image, colors=colors[np.unique(target)[1:]], bg_label=0, alpha=0.45) 76 | # Save 77 | image_id = metadata["id"] 78 | path_pred = vis_dir / 'pred' / f'{image_id}-pred.png' 79 | path_target = vis_dir / 'target' / f'{image_id}-target.png' 80 | path_pred.parent.mkdir(exist_ok=True, parents=True) 81 | path_target.parent.mkdir(exist_ok=True, parents=True) 82 | Image.fromarray((image_pred_overlay * 255).astype(np.uint8)).save(str(path_pred)) 83 | Image.fromarray((image_target_overlay * 255).astype(np.uint8)).save(str(path_target)) 84 | print(f'Saved visualizations to {vis_dir.absolute()}') 85 | 86 | 87 | def evaluate( 88 | *, 89 | cfg: DictConfig, 90 | dataset_val: Iterable, 91 | n_clusters: Optional[int] = None): 92 | 93 | # Add background class 94 | n_classes = cfg.data.num_classes + 1 95 | if n_clusters is None: 96 | n_clusters = n_classes 97 | 98 | # Iterate 99 | tp = [0] * n_classes 100 | fp = [0] * n_classes 101 | fn = [0] * n_classes 102 | 103 | # Load all pixel embeddings 104 | all_preds = np.zeros((len(dataset_val) * 500 * 500), dtype=np.float32) 105 | all_gt = np.zeros((len(dataset_val) * 500 * 500), dtype=np.float32) 106 | offset_ = 0 107 | 108 | # Add all pixels to our arrays 109 | _alread_warned = 0 110 | for i in trange(len(dataset_val), desc='Concatenating all predictions'): 111 | image, target, mask, metadata = dataset_val[i] 112 | # Check where ground-truth is valid and append valid pixels to the array 113 | valid = (target != 255) 114 | n_valid = np.sum(valid) 115 | all_gt[offset_:offset_+n_valid] = target[valid] 116 | # Append the predicted targets in the array 117 | all_preds[offset_:offset_+n_valid, ] = mask[valid] 118 | all_gt[offset_:offset_+n_valid, ] = target[valid] 119 | # Update offset_ 120 | offset_ += n_valid 121 | 122 | # Truncate to the actual number of pixels 123 | all_preds = all_preds[:offset_, ] 124 | all_gt = all_gt[:offset_, ] 125 | 126 | # Do hungarian matching 127 | num_elems = offset_ 128 | if n_clusters == n_classes: 129 | print('Using hungarian algorithm for matching') 130 | match = eval_utils.hungarian_match(all_preds, all_gt, preds_k=n_clusters, targets_k=n_classes, metric='iou') 131 | else: 132 | print('Using majority voting for matching') 133 | match = eval_utils.majority_vote(all_preds, all_gt, preds_k=n_clusters, targets_k=n_classes) 134 | print(f'Optimal matching: {match}') 135 | 136 | # Remap predictions 137 | reordered_preds = np.zeros(num_elems, dtype=all_preds.dtype) 138 | for pred_i, target_i in match: 139 | reordered_preds[all_preds == int(pred_i)] = int(target_i) 140 | 141 | # TP, FP, and FN evaluation 142 | for i_part in range(0, n_classes): 143 | tmp_all_gt = (all_gt == i_part) 144 | tmp_pred = (reordered_preds == i_part) 145 | tp[i_part] += np.sum(tmp_all_gt & tmp_pred) 146 | fp[i_part] += np.sum(~tmp_all_gt & tmp_pred) 147 | fn[i_part] += np.sum(tmp_all_gt & ~tmp_pred) 148 | 149 | # Calculate Jaccard index 150 | jac = [0] * n_classes 151 | for i_part in range(0, n_classes): 152 | jac[i_part] = float(tp[i_part]) / max(float(tp[i_part] + fp[i_part] + fn[i_part]), 
1e-8) 153 | 154 | # Print results 155 | eval_result = dict() 156 | eval_result['jaccards_all_categs'] = jac 157 | eval_result['mIoU'] = np.mean(jac) 158 | print('Evaluation of semantic segmentation ') 159 | print('mIoU is %.2f' % (100*eval_result['mIoU'])) 160 | return eval_result, match 161 | 162 | 163 | if __name__ == '__main__': 164 | torch.set_grad_enabled(False) 165 | main() 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Deep Spectral Methods for Unsupervised Ultrasound Image Interpretation 3 | 4 | We integrate key concepts from unsupervised deep spectral methods, which combine spectral graph theory with deep learning methods. We utilize self-supervised transformer features for spectral clustering to generate meaningful segments based on ultrasound-specific metrics and shape and positional priors, ensuring semantic consistency across the dataset. 5 | 6 | ![alt text](https://github.com/alexaatm/UnsupervisedSegmentor4Ultrasound/blob/main/pipeline.png?raw=true) 7 | 8 | 9 | ## Installation 10 | 11 | We recommend using a conda environment, following these steps: 12 | 13 | ```bash 14 | conda env create -f conda.yaml 15 | conda activate dss 16 | ``` 17 | Alternatively, you can install the packages yourself: conda packages are listed in the ``req_conda.txt`` file, and other packages not present in conda (pip) are listed in ``req_pip.txt``. 18 | 19 | ```bash 20 | conda create --name dss 21 | conda activate dss 22 | conda install --name dss --yes --file req_conda.txt -c fastai -c defaults -c conda-forge 23 | conda install --name dss --yes pip 24 | pip install -r req_pip.txt 25 | ``` 26 | 27 | This installation worked with Python 3.11. 28 | 29 | 30 | ## Organization 31 | 32 | To run this on your data, you need to prepare the dataset itself (2D images), a config for the dataset, and a config for the run (the desired parameters). 33 | 34 | ### Data 35 | 36 | Put data in the `data` folder of the main repo. The `lists` folder has a file `images.txt` with the list of files. `images` and `labels` are `.png` files with matching names. The folder should follow the structure: 37 | 38 | 39 | ```markdown 40 | data/ 41 | --DATASET1 42 | ----subfolder1/ 43 | ------images 44 | ------lists 45 | ------labels 46 | ... 47 | ----subfolderN/ 48 | ------images 49 | ------lists 50 | ------labels 51 | ``` 52 | 53 | See `data/README.md` for more details. 54 | 55 | ### Data config 56 | 57 | Add a dataset `.yaml` config to `configs/dataset`. It should follow this structure (example): 58 | 59 | ```markdown 60 | name: carotid 61 | dataset_root: carotid/val 62 | dataset_type: folders 63 | images_root: images 64 | list: lists/images.txt 65 | gt_dir: labels 66 | pred_dir: "" 67 | n_classes: 2 68 | features_dir: "" 69 | preprocessed_dir: "" 70 | derained_dir: "" 71 | eigenseg_dir: "" 72 | ``` 73 | See `configs/dataset/README.md` for another example. 74 | 75 | ### Run config 76 | 77 | There are two ways to set parameters of the pipeline: through the hydra config when running the python script (see the next section), or through a wandb sweep config (described here). 78 | 79 | To use a sweep config, add a new (or modify an existing) yaml file in the `configs/sweep` folder. It has a few parameters of the sweep itself (name, type, count) and custom parameters (the config dictionary itself, and the steps to evaluate - relevant only if evaluation is on). See [wandb sweeps](https://docs.wandb.ai/guides/sweeps) for more details.
You should put here the parameters you want to modify (otherwise they keep the values from their default configs). Here is a simple sweep config to explore multiple cluster numbers for Step II of the pipeline (where eigensegments get merged into semantic clusters). 80 | 81 | 82 | ```markdown 83 | name: num_clusters 84 | seg_for_eval: ['crf_multi_region'] 85 | method: grid 86 | count: 87 | simple: True 88 | sweep_id: null 89 | config: 90 | # generic 91 | segments_num: [15] 92 | clusters_num: [6,9,12,15] 93 | ``` 94 | See other examples of sweep configs in the `configs/sweep` folder. 95 | 96 | Note: A sweep config yaml file is a succinct way to set multiple values for parameters, run the code on all of them, and let wandb take care of the sweeps (such sweeps are convenient for evaluation and are easier to control through wandb than hydra sweeps, where each run is independent). If you set just one value, it is equivalent to a single run. 97 | 98 | Sweep config files are also convenient because they make it easier to track which configurations have been tried (as opposed to changing parameters through hydra configs when calling the python script). 99 | 100 | ### wandb config 101 | 102 | You can set the name of the project and the wandb authentication key to track the pipeline progress in wandb. This is especially useful for evaluation, since all the results and metrics are also logged there. 103 | ## Usage/Examples 104 | 105 | In the root directory, there is a bash file with an example call for the pipeline. 106 | 107 | ```bash 108 | cd deep-spectral-segmentation 109 | 110 | export WANDB_API_KEY= 111 | export WANDB_CONFIG_DIR=/tmp/ 112 | export WANDB_CACHE_DIR=/tmp/ 113 | export WANDB_AGENT_MAX_INITIAL_FAILURE=20 114 | export WANDB__SERVICE_WAIT=600 115 | export XFORMERS_DISABLED=True 116 | 117 | python -m pipeline.pipeline_sweep_subfolders \ 118 | vis=selected \ 119 | pipeline_steps=defaults \ 120 | dataset=thyroid \ 121 | wandb.tag=test \ 122 | sweep=defaults 123 | ``` 124 | 125 | Configs are set through the hydra .yaml configs. For example: 126 | - `selected.yaml` config in `configs/vis` indicates which steps should be visualized 127 | - `defaults.yaml` config in `configs/pipeline_steps` indicates that all steps except for the evaluation should be completed. Other options include `allTrue` (to generate pseudomasks and to evaluate) or `allFalse` (e.g. when you use precomputed paths and only want to add plots). 128 | - `thyroid.yaml` config in `configs/dataset` describes which dataset to run the pipeline on. 129 | - `defaults.yaml` config in `configs/sweep` shows which parameters you want to modify. Note: you could also set all parameters from here (important: if a sweep config is used, it will overwrite the same parameters when the wandb sweep starts! The pipeline code without sweeps needs some updates and can be used in the future too). 130 | - the `test` value of the `wandb.tag` config can be any string value of your choice. 131 | 132 | See other configs in `configs/defaults.yaml`. E.g. you can also set custom paths to choose where to save the results using the `custom_path_to_save_data` parameter. 133 | 134 | 135 | 136 | 137 | ## Acknowledgements 138 | 139 | Note: This repo is based on the fork https://github.com/alexaatm/deep-spectral-segmentation, which is our modified version of the original work at https://github.com/lukemelas/deep-spectral-segmentation; credits go to the respective authors, who laid the whole foundation of what we were building on!
In our fork, we added additional features, such as: new affinities based on image similarity metrics useful for ultrasound (SSD, MI and many more), new priors for the clustering step (position and shape of segments, support of dinov2 models and custom dino models, preprocessing useful for ultrasound (gaussian blurring, histogram equalisation etc.), and configurable pipeline code for tracking runs in wandb. 140 | 141 | In current repo, the abovementioned fork was merged into the main code, as we wanted to clean it from the unused parts of the original code. 142 | 143 | - [Our Modified Version of Deep Spectral Segmentation](https://github.com/alexaatm/deep-spectral-segmentation) 144 | - [Original Deep Spectral Methods Work](https://github.com/lukemelas/deep-spectral-segmentation) 145 | -------------------------------------------------------------------------------- /evaluation/eval_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/guests/oleksandra_tmenova/test/project/.venv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import os\n", 19 | "from pathlib import Path\n", 20 | "from typing import Iterable, Optional\n", 21 | "\n", 22 | "import hydra\n", 23 | "import numpy as np\n", 24 | "import torch\n", 25 | "import wandb\n", 26 | "from accelerate import Accelerator\n", 27 | "from matplotlib.cm import get_cmap\n", 28 | "from omegaconf import DictConfig, OmegaConf\n", 29 | "from PIL import Image\n", 30 | "from skimage.color import label2rgb\n", 31 | "from tqdm import tqdm, trange" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 21, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# evaluation utilities\n", 41 | "import eval_utils\n", 42 | "# for reading a dataset with groundth truth and labels\n", 43 | "from dataset import EvalDataset\n", 44 | "\n", 45 | "\n", 46 | "root_dir = 'demo_dataset'\n", 47 | "custom_dataset = EvalDataset(root_dir)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 22, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# Add background class\n", 57 | "n_classes = 6\n", 58 | "n_clusters = 6\n", 59 | "\n", 60 | "# Iterate\n", 61 | "tp = [0] * n_classes\n", 62 | "fp = [0] * n_classes\n", 63 | "fn = [0] * n_classes\n", 64 | "\n", 65 | "# Load all pixel embeddings\n", 66 | "all_preds = np.zeros((len(custom_dataset) * 500 * 500), dtype=np.float32)\n", 67 | "all_gt = np.zeros((len(custom_dataset) * 500 * 500), dtype=np.float32)\n", 68 | "offset_ = 0" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 23, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stderr", 78 | "output_type": "stream", 79 | "text": [ 80 | "Concatenating all predictions: 100%|██████████| 1/1 [00:00<00:00, 108.21it/s]\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "for i in trange(len(custom_dataset), desc='Concatenating all predictions'):\n", 86 | " image, target, mask = custom_dataset[i]\n", 87 | " # Check where ground-truth is valid and append valid pixels to the array\n", 88 | " valid = (target != 255)\n", 89 | " n_valid = 
np.sum(valid)\n", 90 | " all_gt[offset_:offset_+n_valid] = target[valid]\n", 91 | " # Append the predicted targets in the array\n", 92 | " all_preds[offset_:offset_+n_valid, ] = mask[valid]\n", 93 | " all_gt[offset_:offset_+n_valid, ] = target[valid]\n", 94 | " # Update offset_\n", 95 | " offset_ += n_valid" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 24, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "Using hungarian algorithm for matching\n", 108 | "Using iou as metric\n", 109 | "Optimal matching: [(0, 5), (1, 1), (2, 0), (3, 2), (4, 3), (5, 4)]\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "# Truncate to the actual number of pixels\n", 115 | "all_preds = all_preds[:offset_, ]\n", 116 | "all_gt = all_gt[:offset_, ]\n", 117 | "\n", 118 | "# Do hungarian matching\n", 119 | "num_elems = offset_\n", 120 | "if n_clusters == n_classes:\n", 121 | " print('Using hungarian algorithm for matching')\n", 122 | " match = eval_utils.hungarian_match(all_preds, all_gt, preds_k=n_clusters, targets_k=n_classes, metric='iou')\n", 123 | "else:\n", 124 | " print('Using majority voting for matching')\n", 125 | " match = eval_utils.majority_vote(all_preds, all_gt, preds_k=n_clusters, targets_k=n_classes)\n", 126 | "print(f'Optimal matching: {match}')" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 25, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "Evaluation of semantic segmentation \n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "# Remap predictions\n", 144 | "reordered_preds = np.zeros(num_elems, dtype=all_preds.dtype)\n", 145 | "for pred_i, target_i in match:\n", 146 | " reordered_preds[all_preds == int(pred_i)] = int(target_i)\n", 147 | "\n", 148 | "# TP, FP, and FN evaluation\n", 149 | "for i_part in range(0, n_classes):\n", 150 | " tmp_all_gt = (all_gt == i_part)\n", 151 | " tmp_pred = (reordered_preds == i_part)\n", 152 | " tp[i_part] += np.sum(tmp_all_gt & tmp_pred)\n", 153 | " fp[i_part] += np.sum(~tmp_all_gt & tmp_pred)\n", 154 | " fn[i_part] += np.sum(tmp_all_gt & ~tmp_pred)\n", 155 | "\n", 156 | "# Calculate Jaccard index\n", 157 | "jac = [0] * n_classes\n", 158 | "for i_part in range(0, n_classes):\n", 159 | " jac[i_part] = float(tp[i_part]) / max(float(tp[i_part] + fp[i_part] + fn[i_part]), 1e-8)\n", 160 | "\n", 161 | "# Print results\n", 162 | "eval_result = dict()\n", 163 | "eval_result['jaccards_all_categs'] = jac\n", 164 | "eval_result['mIoU'] = np.mean(jac)\n", 165 | "print('Evaluation of semantic segmentation ')" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 26, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "{'jaccards_all_categs': [0.0,\n", 177 | " 0.7693266832917706,\n", 178 | " 0.0,\n", 179 | " 0.0,\n", 180 | " 0.0,\n", 181 | " 0.46726946546892595],\n", 182 | " 'mIoU': 0.20609935812678273}" 183 | ] 184 | }, 185 | "execution_count": 26, 186 | "metadata": {}, 187 | "output_type": "execute_result" 188 | } 189 | ], 190 | "source": [ 191 | "eval_result" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 27, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "data": { 201 | "text/plain": [ 202 | "[(0, 5), (1, 1), (2, 0), (3, 2), (4, 3), (5, 4)]" 203 | ] 204 | }, 205 | "execution_count": 27, 206 | "metadata": {}, 207 | "output_type": "execute_result" 208 | } 209 | 
], 210 | "source": [ 211 | "match" 212 | ] 213 | } 214 | ], 215 | "metadata": { 216 | "kernelspec": { 217 | "display_name": "Python-3.8.12", 218 | "language": "python", 219 | "name": "python3" 220 | }, 221 | "language_info": { 222 | "codemirror_mode": { 223 | "name": "ipython", 224 | "version": 3 225 | }, 226 | "file_extension": ".py", 227 | "mimetype": "text/x-python", 228 | "name": "python", 229 | "nbconvert_exporter": "python", 230 | "pygments_lexer": "ipython3", 231 | "version": "3.8.12" 232 | }, 233 | "orig_nbformat": 4 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 2 237 | } 238 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/dataset/voc.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from typing import Any, Callable, Dict, List, Optional, Tuple 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torchvision.datasets.voc import (DATASET_YEAR_DICT, VisionDataset, os, verify_str_arg) 10 | 11 | 12 | def _resize_pseudolabel(pseudolabel, img): 13 | if ( 14 | (pseudolabel.shape[0] == img.shape[0] // 16) or 15 | (pseudolabel.shape[0] == img.shape[0] // 8) or 16 | (pseudolabel.shape[0] == 2 * (img.shape[0] // 16)) 17 | ): 18 | return cv2.resize(pseudolabel, dsize=img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 19 | return pseudolabel 20 | 21 | 22 | class VOCSegmentationWithPseudolabelsBase(VisionDataset): 23 | 24 | _SPLITS_DIR = "Segmentation" 25 | _TARGET_DIR = "SegmentationClass" 26 | _TARGET_FILE_EXT = ".png" 27 | 28 | def __init__( 29 | self, 30 | root: str, 31 | year: str = "2012", 32 | image_set: str = "train", 33 | download: bool = False, 34 | transform: Optional[Callable] = None, 35 | target_transform: Optional[Callable] = None, 36 | transforms: Optional[Callable] = None, 37 | ): 38 | super().__init__(root, transforms, transform, target_transform) 39 | if year == "2007-test": 40 | if image_set == "test": 41 | warnings.warn( 42 | "Acessing the test image set of the year 2007 with year='2007-test' is deprecated. " 43 | "Please use the combination year='2007' and image_set='test' instead." 44 | ) 45 | year = "2007" 46 | else: 47 | raise ValueError( 48 | "In the test image set of the year 2007 only image_set='test' is allowed. " 49 | "For all other image sets use year='2007' instead." 50 | ) 51 | self.year = year 52 | 53 | valid_image_sets = ["train", "trainval", "val"] 54 | if year == "2007": 55 | valid_image_sets.append("test") 56 | self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets) 57 | 58 | key = "2007-test" if year == "2007" and image_set == "test" else year 59 | dataset_year_dict = DATASET_YEAR_DICT[key] 60 | 61 | self.url = dataset_year_dict["url"] 62 | self.filename = dataset_year_dict["filename"] 63 | self.md5 = dataset_year_dict["md5"] 64 | 65 | base_dir = dataset_year_dict["base_dir"] 66 | voc_root = os.path.join(self.root, base_dir) 67 | 68 | if download: 69 | from torchvision.datasets.voc import download_and_extract_archive 70 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5) 71 | 72 | if not os.path.isdir(voc_root): 73 | raise RuntimeError("Dataset not found or corrupted. 
You can use download=True to download it") 74 | 75 | splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR) 76 | split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt") 77 | 78 | if self.image_set == 'train': # everything except val 79 | image_dir = os.path.join(voc_root, "JPEGImages") 80 | with open(os.path.join(splits_dir, "val.txt"), "r") as f: 81 | val_file_stems = set([stem.strip() for stem in f.readlines()]) 82 | all_image_paths = [p for p in Path(image_dir).iterdir()] 83 | train_image_paths = [str(p) for p in all_image_paths if p.stem not in val_file_stems] 84 | self.images = sorted(train_image_paths) 85 | # For the targets, we will just replicate the same target however many times 86 | target_dir = os.path.join(voc_root, self._TARGET_DIR) 87 | self.targets = [str(next(Path(target_dir).iterdir()))] * len(self.images) 88 | 89 | else: 90 | 91 | with open(os.path.join(split_f), "r") as f: 92 | file_names = [x.strip() for x in f.readlines()] 93 | 94 | image_dir = os.path.join(voc_root, "JPEGImages") 95 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] 96 | 97 | target_dir = os.path.join(voc_root, self._TARGET_DIR) 98 | self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names] 99 | 100 | assert len(self.images) == len(self.targets), ( len(self.images), len(self.targets)) 101 | 102 | @property 103 | def masks(self) -> List[str]: 104 | return self.targets 105 | 106 | def _prepare_label_map(self, label_map): 107 | if label_map is not None: 108 | self.label_map_fn = np.vectorize(label_map.__getitem__) 109 | else: 110 | self.label_map_fn = None 111 | 112 | def _prepare_segments_dir(self, segments_dir): 113 | self.segments_dir = segments_dir 114 | # Get segment and image files, which are assumed to be in correspondence 115 | all_segment_files = sorted(map(str, Path(segments_dir).iterdir())) 116 | all_img_files = sorted(Path(self.images[0]).parent.iterdir()) 117 | assert len(all_img_files) == len(all_segment_files), (len(all_img_files), len(all_segment_files)) 118 | # Create mapping because I named the segment files badly (sequentially instead of by image id) 119 | all_img_stems = [p.stem for p in all_img_files] 120 | valid_img_stems = set([Path(p).stem for p in self.images]) # in our split (e.g. 
'val') 121 | segment_files = [] 122 | for i in range(len(all_img_stems)): 123 | if all_img_stems[i] in valid_img_stems: 124 | segment_files.append(all_segment_files[i]) 125 | self.segments = segment_files 126 | assert len(self.segments) == len(self.images), f'{len(self.segments)=} and {len(self.images)=}' 127 | print('Loaded segments and images') 128 | print(f'First image filepath: {self.images[0]}') 129 | print(f'First segmap filepath: {self.segments[0]}') 130 | print(f'Last image filepath: {self.images[-1]}') 131 | print(f'Last segmap filepath: {self.segments[-1]}') 132 | 133 | def _load(self, index: int): 134 | # Load image 135 | img = np.array(Image.open(self.images[index]).convert("RGB")) 136 | target = np.array(Image.open(self.masks[index])) 137 | metadata = {'id': Path(self.images[index]).stem, 'path': self.images[index], 'shape': tuple(img.shape[:2])} 138 | # New: load segmap and accompanying metedata 139 | pseudolabel = np.array(Image.open(self.segments[index])) 140 | pseudolabel = _resize_pseudolabel(pseudolabel, img) # HACK HACK HACK 141 | if self.label_map_fn is not None: 142 | pseudolabel = self.label_map_fn(pseudolabel) 143 | return (img, target, pseudolabel, metadata) 144 | 145 | def __len__(self) -> int: 146 | return len(self.images) 147 | 148 | 149 | class VOCSegmentationWithPseudolabels(VOCSegmentationWithPseudolabelsBase): 150 | 151 | def __init__(self, *args, segments_dir, transform = None, label_map = None, **kwargs): 152 | super().__init__(*args, **kwargs) 153 | self._prepare_segments_dir(segments_dir) 154 | self.transform = transform 155 | self._prepare_label_map(label_map) 156 | 157 | def __getitem__(self, index: int): 158 | img, target, pseudolabel, metadata = self._load(index) 159 | if self.transform is not None: 160 | # Transform 161 | data = self.transform(image=img, mask1=target, mask2=pseudolabel) 162 | # Unpack 163 | img, target, pseudolabel = data['image'], data['mask1'], data['mask2'] 164 | if torch.is_tensor(target): 165 | target = target.long() 166 | if torch.is_tensor(pseudolabel): 167 | pseudolabel = pseudolabel.long() 168 | return img, target, pseudolabel, metadata 169 | -------------------------------------------------------------------------------- /self-training/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import sys 5 | import os 6 | 7 | def calc_hist(img): 8 | # Calculate the histogram of a grayscale image 9 | hist = cv2.calcHist([img], [0], None, [256], [0, 256]) 10 | return hist.flatten() 11 | 12 | def hist_diff(hist1, hist2): 13 | # Calculate the absolute difference between two histograms 14 | diff = np.sum(np.abs(hist1 - hist2)) 15 | return diff 16 | 17 | def detect_shots_from_video(filename): 18 | # Open the video file 19 | cap = cv2.VideoCapture(filename) 20 | 21 | # Check if the video file was opened successfully 22 | if not cap.isOpened(): 23 | print("Error opening video file") 24 | return 25 | 26 | # Get the video properties 27 | fps = cap.get(cv2.CAP_PROP_FPS) 28 | frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 29 | 30 | # Initialize variables for shot detection 31 | prev_frame = None 32 | prev_hist = None 33 | threshold = 40000 # Adjust this parameter to adjust the sensitivity of the shot detector 34 | 35 | # array for tracking changes 36 | changes = [] 37 | 38 | # Loop over the frames of the video 39 | for i in range(frame_count): 40 | # Read the next frame 41 | ret, frame = cap.read() 42 | 
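        # Added note: cap.read() returns ret=False once the stream is exhausted or a frame fails to
        # decode, so the loop stops before attempting the grayscale conversion below.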
if not ret: 43 | break 44 | 45 | # Convert the frame to grayscale 46 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 47 | 48 | # Calculate the histogram of the current frame 49 | curr_hist = calc_hist(gray) 50 | 51 | if prev_frame is not None: 52 | # Calculate the absolute difference between the current and previous histograms 53 | diff = hist_diff(curr_hist, prev_hist) 54 | 55 | # Check if the absolute difference exceeds the threshold 56 | if diff > threshold: 57 | # A shot has been detected 58 | print("Shot detected at time {:.2f} seconds".format(i / fps)) 59 | 60 | changes.append(1) 61 | 62 | else: 63 | changes.append(0) 64 | 65 | # Update the previous frame and histogram 66 | prev_frame = gray 67 | prev_hist = curr_hist 68 | 69 | # Release the video capture object and close the windows 70 | cap.release() 71 | 72 | # plot the array using matplotlib 73 | plt.plot(changes) 74 | 75 | # Add labels and title 76 | plt.xlabel("Frames") 77 | plt.ylabel("Change") 78 | plt.title("Change detection") 79 | 80 | plt.show() 81 | 82 | def detect_shots_from_list(image_list): 83 | """ 84 | Input: image_list of images of type PIL 85 | Output: indices where change happens 86 | """ 87 | image_num = len(image_list) 88 | 89 | # Initialize variables for shot detection 90 | prev_frame = None 91 | prev_hist = None 92 | threshold = 140000 # Adjust this parameter to adjust the sensitivity of the shot detector 93 | 94 | # array for tracking changes 95 | changes = [] 96 | ind_of_change = [] 97 | diffs = [] 98 | 99 | # Loop over the images 100 | for i in range(image_num): 101 | # Read the next frame 102 | image = image_list[i] 103 | 104 | # Convert the image from PIL to grayscale OpenCv image 105 | gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY) 106 | 107 | # Calculate the histogram of the current frame 108 | curr_hist = calc_hist(gray) 109 | 110 | if prev_frame is not None: 111 | # Calculate the absolute difference between the current and previous histograms 112 | diff = hist_diff(curr_hist, prev_hist) 113 | diffs.append(diff) 114 | 115 | # Check if the absolute difference exceeds the threshold 116 | if diff > threshold: 117 | # A shot has been detected 118 | print("Shot detected at index {:.2f}".format(i)) 119 | changes.append(1) 120 | ind_of_change.append(i) 121 | 122 | else: 123 | changes.append(0) 124 | 125 | else: 126 | # for the 0th frame there is no prev frame 127 | ind_of_change.append(i) # mark change 128 | changes.append(0) 129 | diffs.append(0) 130 | 131 | # Update the previous frame and histogram 132 | prev_frame = gray 133 | prev_hist = curr_hist 134 | 135 | # plot the array using matplotlib 136 | fig, ax = plt.subplots(figsize = (20, 5)) 137 | ax.plot(range(image_num), changes) 138 | 139 | # normalize the differences 140 | # diffs = (diffs - np.min(diffs)) / (np.max(diffs) - np.min(diffs)) 141 | ax.plot(range(image_num), diffs) 142 | 143 | # Add labels and title 144 | plt.xlabel("Frames") 145 | plt.ylabel("Change") 146 | plt.title("Change detection") 147 | 148 | # add a vertical line at an index where a change occurs 149 | for ind in ind_of_change: 150 | ax.axvline(x=ind, color='r') 151 | ax.text(ind + 0.1, 0, f'i={ind}', rotation=90) 152 | 153 | plt.show() 154 | 155 | return ind_of_change 156 | 157 | def detect_shots_from_list_label(image_list): 158 | """ 159 | Input: image_list of images of type PIL and its label: (PIL Image, 0) 160 | Output: a modified tuple where the labels of each image are set based on belonging to a set of consequent frames with no change 161 | """ 162 | 
labeled_images_list = [] 163 | image_num = len(image_list) 164 | 165 | # Initialize variables for shot detection 166 | prev_frame = None 167 | prev_hist = None 168 | threshold = 140000 # Adjust this parameter to adjust the sensitivity of the shot detector 169 | prev_class = 0 170 | 171 | # array for tracking changes 172 | changes = [] 173 | ind_of_change = [] 174 | diffs = [] 175 | class_labels = [prev_class] 176 | 177 | # Loop over the images 178 | for i in range(image_num): 179 | # Read the next frame 180 | image = image_list[i][0] 181 | 182 | # Convert the image from PIL to grayscale OpenCv image 183 | gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY) 184 | 185 | # Calculate the histogram of the current frame 186 | curr_hist = calc_hist(gray) 187 | 188 | if prev_frame is not None: 189 | # Calculate the absolute difference between the current and previous histograms 190 | diff = hist_diff(curr_hist, prev_hist) 191 | diffs.append(diff) 192 | print(f"diff={diff}") 193 | 194 | # Check if the absolute difference exceeds the threshold 195 | if diff > threshold: 196 | # A shot has been detected 197 | print("Shot detected at index {:.2f}".format(i)) 198 | changes.append(1) 199 | ind_of_change.append(i) 200 | curr_class = prev_class + 1 201 | labeled_images_list.append((image, curr_class)) 202 | # Add a class label to the list of class labels 203 | class_labels.append(curr_class) 204 | 205 | else: 206 | changes.append(0) 207 | labeled_images_list.append((image, prev_class)) 208 | 209 | else: 210 | # for the 0th frame there is no prev frame 211 | ind_of_change.append(i) # mark change 212 | changes.append(0) 213 | diffs.append(0) 214 | curr_class = 0 215 | labeled_images_list.append((image, curr_class)) 216 | 217 | 218 | # Update the previous frame and histogram 219 | prev_frame = gray 220 | prev_hist = curr_hist 221 | prev_class = curr_class 222 | 223 | print(f'Num of changes: {len(ind_of_change)}') 224 | print(f'Num of classes: {curr_class+1}') 225 | 226 | 227 | # plot the array using matplotlib 228 | fig, ax = plt.subplots(figsize = (20, 5)) 229 | ax.plot(range(image_num), changes) 230 | 231 | # normalize the differences 232 | diffs = (diffs - np.min(diffs)) / (np.max(diffs) - np.min(diffs)) 233 | ax.plot(range(image_num), diffs) 234 | 235 | # Add labels and title 236 | plt.xlabel("Frames") 237 | plt.ylabel("Change") 238 | plt.title("Change detection") 239 | 240 | # add a vertical line at an index where a change occurs 241 | for ind in ind_of_change: 242 | ax.axvline(x=ind, color='r') 243 | ax.text(ind + 0.1, 0, f'i={ind}', rotation=90) 244 | 245 | plt.show() 246 | 247 | # for i, (im, label) in enumerate(labeled_images_list): 248 | # print(f'index {i}, class {label}') 249 | 250 | return labeled_images_list, class_labels 251 | 252 | def get_unique_classes_list(image_list): 253 | image_labels_np = np.array([label for _, label in image_list]) 254 | print(image_labels_np) 255 | classes = np.unique(image_labels_np) 256 | return classes.tolist() 257 | 258 | if __name__ == "__main__": 259 | path = sys.argv[1] 260 | 261 | if os.path.isfile(path): 262 | detect_shots_from_video(path) -------------------------------------------------------------------------------- /self-training/datasets/samplers.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.utils.data import Sampler 4 | from torchvision.transforms import transforms 5 | from datasets import datasets 6 | from PIL import Image 7 | import numpy as np 8 | 9 | class 
PatchSampler(Sampler): 10 | def __init__(self, dataset, patch_size, patch_mode='random', shuffle=True): 11 | self.dataset = dataset 12 | self.patch_size = patch_size 13 | self.patch_mode = patch_mode 14 | self.shuffle=shuffle 15 | 16 | # generate patches 17 | if self.patch_mode=='grid': 18 | self.grid_sampler_init() 19 | elif self.patch_mode=='random': 20 | self.random_sampler_init() 21 | else: 22 | raise NotImplementedError() 23 | 24 | if self.shuffle: 25 | random.shuffle(self.indices) 26 | 27 | def grid_sampler_init(self): 28 | dataset_indices = [] 29 | for idx in range(len(self.dataset)): 30 | image = self.dataset[idx][0] 31 | w, h = image.size[0], image.size[1] 32 | 33 | # find the crop of the image s.t. that is evenly divisible by the patch size 34 | crop_height = h - h % self.patch_size 35 | crop_width = w - w % self.patch_size 36 | num_patches_h = (crop_height - 1) // self.patch_size + 1 37 | num_patches_w = (crop_width - 1) // self.patch_size + 1 38 | 39 | # get all patches from image with index idx 40 | image_indices = [(idx, w*self.patch_size, h*self.patch_size, self.patch_size) for w in range(num_patches_w) for h in range(num_patches_h)] 41 | dataset_indices.append(image_indices) 42 | # concatenate all indices per all images into one list of indices 43 | self.indices = [i for image_indices in dataset_indices for i in image_indices] 44 | 45 | def random_sampler_init(self): 46 | dataset_indices = [] 47 | for idx in range(len(self.dataset)): 48 | image = self.dataset[idx][0] 49 | w, h = image.size[0], image.size[1] 50 | 51 | num_patches = int(h * w / (self.patch_size**2)) 52 | 53 | # get random patches for a given image 54 | image_indices = [(idx, \ 55 | torch.randint(low = 0, high = w - self.patch_size + 1, size=(1,)).item(), \ 56 | torch.randint(low = 0, high = h - self.patch_size + 1, size=(1,)).item() , \ 57 | self.patch_size) for patch in range(num_patches)] 58 | 59 | dataset_indices.append(image_indices) 60 | # concatenate all indices per all images into one list of indices 61 | self.indices = [i for image_indices in dataset_indices for i in image_indices] 62 | 63 | def __iter__(self): 64 | return iter(self.indices) 65 | 66 | def __len__(self): 67 | return len(self.indices) 68 | 69 | class RandomPatchSampler(Sampler): 70 | def __init__(self, dataset, patch_size, shuffle=True): 71 | self.dataset = dataset 72 | self.patch_size = patch_size 73 | self.shuffle=shuffle 74 | 75 | def __iter__(self): 76 | indices = list(range(len(self.dataset))) 77 | if self.shuffle: 78 | random.shuffle(indices) 79 | for idx in indices: 80 | image = self.dataset[idx][0] 81 | w, h = image.size[0], image.size[1] 82 | 83 | num_patches = int(h * w / (self.patch_size**2)) 84 | 85 | # get random patches for a given image 86 | for patch in range(num_patches): 87 | yield (idx, \ 88 | torch.randint(low = 0, high = w - self.patch_size + 1, size=(1,)).item(), \ 89 | torch.randint(low = 0, high = h - self.patch_size + 1, size=(1,)).item() , \ 90 | self.patch_size) 91 | 92 | def __len__(self): 93 | image = self.dataset[0][0] 94 | w, h = image.size[0], image.size[1] 95 | num_patches_per_image = int(h * w / (self.patch_size**2)) 96 | return len(self.dataset) * num_patches_per_image 97 | 98 | 99 | class TripletPatchSampler(Sampler): 100 | def __init__(self, dataset, patch_size, shuffle=True, min_shift=0.1, max_shift=0.3): 101 | self.dataset = dataset 102 | self.patch_size = patch_size 103 | self.shuffle=shuffle 104 | self.max_shift=max_shift 105 | self.min_shift=min_shift 106 | 107 | 108 | 109 | def __iter__(self): 
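        # Added summary of the sampling strategy implemented below: for every image, roughly
        # (H*W)/(patch_size^2) triplets are drawn; the anchor patch is sampled uniformly, the
        # positive patch is sampled within a window around the anchor whose extent is controlled
        # by min_shift/max_shift (fractions of the patch size), and the negative patch is sampled
        # uniformly anywhere in the same image (see the TODO below about non-random negatives).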
110 | indices = list(range(len(self.dataset))) 111 | if self.shuffle: 112 | random.shuffle(indices) 113 | for idx in indices: 114 | image = self.dataset[idx][0] 115 | w, h = image.size[0], image.size[1] 116 | 117 | num_patches = int(h * w / (self.patch_size**2)) 118 | 119 | # get random patches for a given image 120 | for i in range(num_patches): 121 | # sample anchor patch 122 | anchor_patch = (torch.randint(low = 0, high = w - self.patch_size + 1, size=(1,)).item(), \ 123 | torch.randint(low = 0, high = h - self.patch_size + 1, size=(1,)).item()) 124 | 125 | # sample positive patch by shifting range from the anchor 126 | w_shift = torch.randint(low = int(self.min_shift * self.patch_size), high = int(self.max_shift * self.patch_size), size=(1,)).item() 127 | h_shift = torch.randint(low = int(self.min_shift * self.patch_size), high = int(self.max_shift * self.patch_size), size=(1,)).item() 128 | pos_patch = (torch.randint(low = max(anchor_patch[0] - w_shift, 0), 129 | high = min(anchor_patch[0] + self.patch_size//2 + w_shift, w - self.patch_size + 1), 130 | size=(1,)).item(), \ 131 | torch.randint(low = max(anchor_patch[1] - h_shift, 0), 132 | high = min(anchor_patch[1] + self.patch_size//2 + h_shift, h - self.patch_size + 1), 133 | size=(1,)).item()) 134 | 135 | # sample negative randomly (TODO: add non random negative sampling) 136 | neg_patch = (torch.randint(low = 0, high = w - self.patch_size + 1, size=(1,)).item(), \ 137 | torch.randint(low = 0, high = h - self.patch_size + 1, size=(1,)).item()) 138 | 139 | yield (idx, anchor_patch, pos_patch, neg_patch, self.patch_size) 140 | 141 | 142 | def __len__(self): 143 | image = self.dataset[0][0] 144 | w, h = image.size[0], image.size[1] 145 | num_patches_per_image = int(h * w / (self.patch_size**2)) 146 | return len(self.dataset) * num_patches_per_image 147 | 148 | 149 | if __name__ == "__main__": 150 | test_path="../data/liver2_mini/train" 151 | 152 | # PATCH DATASET 153 | dataset=datasets.PatchDataset(root=test_path) 154 | 155 | sampler = RandomPatchSampler(dataset=dataset, patch_size=16) 156 | 157 | transform = transforms.Compose([transforms.ToTensor()]) 158 | collate_fn = datasets.BaseCollateFunction(transform) 159 | 160 | # TODO: figure out how to combine sampler with shuffle, like if shufle is True, pick a random image to start with? 161 | dataloader = torch.utils.data.DataLoader( 162 | dataset, 163 | sampler=sampler, 164 | batch_size=1, 165 | collate_fn=collate_fn, 166 | shuffle=False, 167 | drop_last=True, 168 | num_workers=1, 169 | ) 170 | 171 | print(f'Dataset size:{len(dataset)}') 172 | print(f'Sampler size:{len(sampler)}') 173 | print(f'Dataloader size:{len(dataloader)}') 174 | # print(f'Dataloader shape:{dataloader.shape}') 175 | 176 | for batch in dataloader: 177 | im, label, _ = batch 178 | print(f'im[0]:{im[0].shape}, label[0]={label[0]}') 179 | break 180 | 181 | 182 | 183 | 184 | # TRIPLET PATCH DATASET 185 | tripelt_dataset=datasets.TripletPatchDataset(root=test_path) 186 | 187 | triplet_sampler = TripletPatchSampler(dataset=tripelt_dataset, patch_size=16) 188 | 189 | transform = transforms.Compose([transforms.ToTensor()]) 190 | tripelt_collate_fn = datasets.TripletBaseCollateFunction(transform) 191 | 192 | # TODO: figure out how to combine sampler with shuffle, like if shufle is True, pick a random image to start with? 
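# Each index yielded by TripletPatchSampler above is a tuple (image_idx, anchor_xy, pos_xy, neg_xy, patch_size);
# TripletPatchDataset.__getitem__ crops the three patches from the same image, and
# TripletBaseCollateFunction stacks them into (anchor, positive, negative) batches for the dataloader below.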
193 | triplet_dataloader = torch.utils.data.DataLoader( 194 | tripelt_dataset, 195 | sampler=triplet_sampler, 196 | batch_size=1, 197 | collate_fn=tripelt_collate_fn, 198 | shuffle=False, 199 | drop_last=True, 200 | num_workers=1, 201 | ) 202 | 203 | print(f'Dataset size:{len(tripelt_dataset)}') 204 | print(f'Sampler size:{len(triplet_sampler)}') 205 | print(f'Dataloader size:{len(triplet_dataloader)}') 206 | 207 | for batch in triplet_dataloader: 208 | (anchor, _, _), (pos, _, _,), (neg, _, _) = batch 209 | print("anchor=", anchor.shape, ", pos=", pos.shape, ", neg=", neg.shape) 210 | break 211 | 212 | 213 | 214 | -------------------------------------------------------------------------------- /self-training/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | from lightly.data import LightlyDataset, BaseCollateFunction 2 | from PIL import Image 3 | import numpy as np 4 | import torchvision.transforms as T 5 | from typing import List, Tuple 6 | import torch 7 | import cv2 8 | from datasets import dataset_utils 9 | from datasets import samplers 10 | 11 | class TripletDataset(LightlyDataset): 12 | def __init__( 13 | self, 14 | root: str, 15 | transform: object = None, 16 | mode: str ='seq', 17 | ): 18 | super(TripletDataset, self).__init__(root, transform) 19 | self.mode=mode 20 | if self.mode == 'seq': 21 | self.relabeled_list, self.classes_list = dataset_utils.detect_shots_from_list_label(self.dataset) 22 | self.set_dataset(self.relabeled_list) 23 | elif self.mode == 'class': 24 | self.classes_list = dataset_utils.get_unique_classes_list(self.dataset) 25 | print(f"Unique class labels found: {self.classes_list}") 26 | 27 | 28 | def __getitem__(self, index): 29 | if self.mode == 'random': 30 | a_index, p_index, n_index = self.get_random_triplet(index) 31 | elif self.mode == 'seq': 32 | a_index, p_index, n_index = self.get_triplet_by_class(self.relabeled_list, self.classes_list) 33 | elif self.mode == 'class': 34 | a_index, p_index, n_index = self.get_triplet_by_class(self.dataset, self.classes_list) 35 | else: 36 | print(f"No sampling mode called {self.mode}") 37 | raise NotImplementedError() 38 | 39 | # get filenames 40 | a_fname = self.index_to_filename(self.dataset, a_index) 41 | p_fname = self.index_to_filename(self.dataset, p_index) 42 | n_fname = self.index_to_filename(self.dataset, n_index) 43 | 44 | # get samples (image) and targets (label) 45 | a_sample, a_target = self.dataset.__getitem__(a_index) 46 | p_sample, p_target = self.dataset.__getitem__(p_index) 47 | n_sample, n_target = self.dataset.__getitem__(n_index) 48 | 49 | # Return the triplet of images 50 | return ((a_sample, a_target, a_fname), (p_sample, p_target, p_fname), (n_sample, n_target, n_fname)) 51 | 52 | # TODO: add other approaches for triplet sampling 53 | def get_random_triplet(self, index): 54 | """ 55 | Returns a triplet. Anchor, pos, neg are not the same. 56 | Returns: 57 | triplet (tuple of str): A tuple of 3 indices randomly selected from the dataset indices. 
58 | """ 59 | # TODO: choose anchor index also randomly from the whole dataset 60 | anchor_index = index 61 | positive_index = np.random.choice(self.__len__()) 62 | while positive_index==anchor_index: 63 | positive_index = np.random.choice(self.__len__()) 64 | 65 | negative_index = np.random.choice(self.__len__()) 66 | while negative_index==anchor_index or negative_index==positive_index: 67 | negative_index = np.random.choice(self.__len__()) 68 | 69 | return (anchor_index, positive_index, negative_index) 70 | 71 | def get_triplet_by_class(self, labeled_list, classes_list): 72 | """ 73 | Returns a triplet. The anchor and positive share a class (obtained by 74 | analyzing changes in the frame sequence); the negative comes from a different class. 75 | Returns: 76 | triplet (tuple of str): A tuple of 3 indices selected from the dataset classes. 77 | 78 | Ref. for sampling based on classes: https://github.com/andreasveit/triplet-network-pytorch/blob/master/triplet_mnist_loader.py 79 | """ 80 | image_labels_np = np.array([label for _, label in labeled_list]) 81 | 82 | # pick a class randomly 83 | class_idx = np.random.choice(classes_list) 84 | # print(f'class : {class_idx}') 85 | anchor_index = np.random.choice(np.where(image_labels_np==class_idx)[0]) 86 | positive_index = np.random.choice(np.where(image_labels_np==class_idx)[0]) 87 | if (len(np.where(image_labels_np==class_idx)[0])>1): 88 | # print(f'class {class_idx} has {np.where(image_labels_np==class_idx)[0]} samples') 89 | while positive_index==anchor_index: 90 | positive_index = np.random.choice(np.where(image_labels_np==class_idx)[0]) 91 | negative_index = np.random.choice(np.where(image_labels_np!=class_idx)[0]) 92 | return (anchor_index, positive_index, negative_index) 93 | 94 | def get_dataset(self): 95 | return self.dataset 96 | 97 | def set_dataset(self, ims_labels): 98 | # TODO: check if corresponds to the needed format: list of tuples (PIL IMAGE, int label) 99 | self.dataset = ims_labels 100 | 101 | class TripletBaseCollateFunction(BaseCollateFunction): 102 | def __init__(self, transform: T.Compose, pos_transform: T.Compose = None): 103 | super(TripletBaseCollateFunction, self).__init__(transform) 104 | self.transform = transform 105 | if pos_transform is None: 106 | self.pos_transform=T.Compose([]) 107 | else: 108 | self.pos_transform = pos_transform 109 | 110 | def forward(self, batch: List[Tuple[ \ 111 | Tuple[Image.Image, int, str], \ 112 | Tuple[Image.Image, int, str], \ 113 | Tuple[Image.Image, int, str]]]) \ 114 | -> Tuple[ \ 115 | Tuple[torch.Tensor, torch.Tensor,torch.Tensor], \ 116 | Tuple[torch.Tensor, torch.Tensor,torch.Tensor], \ 117 | Tuple[torch.Tensor, torch.Tensor,torch.Tensor]]: 118 | """Turns a batch of triplet tuples into a tuple of batches. 119 | Args: 120 | batch: 121 | A batch of triplets, each a tuple of (image, label, filename) tuples. 122 | Returns: 123 | A tuple of (anchors, labels, and filenames), (positives, labels, and filenames), (negatives, labels, and filenames). 124 | The images consist of batches corresponding to transformations of the input images.
125 | Reference to basic collate function: https://github.com/lightly-ai/lightly/blob/master/lightly/data/collate.py 126 | """ 127 | 128 | # lists of samples 129 | # anchors is 0th item in a tuple (a,p,n), anchor sample is 0th item in a tuple (sample, target, fname) 130 | a_samples = torch.stack([self.transform(item[0][0]) for item in batch]) 131 | p_samples = torch.stack([self.pos_transform(self.transform(item[1][0])) for item in batch]) 132 | n_samples = torch.stack([self.transform(item[2][0]) for item in batch]) 133 | 134 | # lists of labels (targets) 135 | a_targets = torch.LongTensor([item[0][1] for item in batch]) 136 | p_targets = torch.LongTensor([item[1][1] for item in batch]) 137 | n_targets= torch.LongTensor([item[2][1] for item in batch]) 138 | 139 | # lists of filenames 140 | a_fnames = [item[0][2] for item in batch] 141 | p_fnames = [item[1][2] for item in batch] 142 | n_fnames = [item[2][2] for item in batch] 143 | 144 | return (a_samples, a_targets, a_fnames), (p_samples, p_targets, p_fnames), (n_samples, n_targets, n_fnames) 145 | 146 | class PatchDataset(LightlyDataset): 147 | def __init__( 148 | self, 149 | root: str, 150 | transform: object = None, 151 | ): 152 | super(PatchDataset, self).__init__(root, transform) 153 | 154 | def __getitem__(self, index): 155 | # print(f'index={index}') 156 | if isinstance(index, tuple): 157 | idx, i, j, patch_size = index 158 | # get filename 159 | fname = self.index_to_filename(self.dataset, idx) 160 | 161 | # get samples (image) and targets (label) 162 | sample, target = self.dataset.__getitem__(idx) 163 | 164 | # get a specified patch 165 | # patch = sample[..., i:i+patch_size, j:j+patch_size] 166 | # TODO: check if you need to switch H and W for PIL Image -> i for width, j for height, PIL image has (W, H) 167 | patch = sample.crop((i, j, i+patch_size, j+patch_size)) 168 | # patch.show() 169 | 170 | return (patch, target, f'patch_{i}_{j}_{fname}') 171 | 172 | else: 173 | # just return a full image 174 | # get filename 175 | fname = self.index_to_filename(self.dataset, index) 176 | 177 | # get samples (image) and targets (label) 178 | sample, target = self.dataset.__getitem__(index) 179 | 180 | return (sample, target, fname) 181 | 182 | class TripletPatchDataset(LightlyDataset): 183 | def __init__( 184 | self, 185 | root: str, 186 | transform: object = None, 187 | ): 188 | super(TripletPatchDataset, self).__init__(root, transform) 189 | 190 | def __getitem__(self, index): 191 | # print(f'index={index}') 192 | if isinstance(index, tuple): 193 | idx, a, p, n, patch_size = index 194 | # get filename 195 | fname = self.index_to_filename(self.dataset, idx) 196 | 197 | # get samples (image) and targets (label) 198 | sample, target = self.dataset.__getitem__(idx) 199 | 200 | # get specified patches for a triplet 201 | a_sample = sample.crop((a[0], a[1], a[0]+patch_size, a[1]+patch_size)) 202 | p_sample = sample.crop((p[0], p[1], p[0]+patch_size, p[1]+patch_size)) 203 | n_sample = sample.crop((n[0], n[1], n[0]+patch_size, n[1]+patch_size)) 204 | 205 | # patch.show() 206 | 207 | return ((a_sample, target, f'patch_{a[0]}_{a[1]}_{fname}'), \ 208 | (p_sample, target, f'patch_{p[0]}_{p[1]}_{fname}'), \ 209 | (n_sample, target, f'patch_{n[0]}_{n[1]}_{fname}')) 210 | else: 211 | # just return a full image 212 | # get filename 213 | fname = self.index_to_filename(self.dataset, index) 214 | 215 | # get samples (image) and targets (label) 216 | sample, target = self.dataset.__getitem__(index) 217 | 218 | return (sample, target, fname) 219 | 220 | 221 
| if __name__ == "__main__": 222 | # test_path="../data/liver_reduced/train" 223 | # test_path="../data/liver_similar" 224 | test_path="../data/imagenet-4-classes/train" 225 | 226 | 227 | dataset=TripletDataset(root=test_path, mode='class') 228 | print(len(dataset)) 229 | triplet = dataset[0] 230 | print("anchor=", triplet[0], ", pos=", triplet[1], ", neg=", triplet[2]) 231 | 232 | images = dataset.get_dataset() 233 | print(len(images)) 234 | sample = images[0] 235 | print(f'Sample= {sample}') 236 | 237 | 238 | 239 | 240 | 241 | -------------------------------------------------------------------------------- /self-training/datasets/dino_transform.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple, Union 2 | 3 | import PIL 4 | import torchvision.transforms as T 5 | from PIL.Image import Image 6 | from torch import Tensor 7 | 8 | from lightly.transforms.gaussian_blur import GaussianBlur 9 | from lightly.transforms.multi_view_transform import MultiViewTransform 10 | from lightly.transforms.rotation import random_rotation_transform 11 | from lightly.transforms.solarize import RandomSolarization 12 | from lightly.transforms.utils import IMAGENET_NORMALIZE 13 | 14 | from datasets import augmentations 15 | 16 | 17 | class DINOTransform(MultiViewTransform): 18 | """Implements the global and local view augmentations for DINO [0]. 19 | 20 | Code adapted from: 21 | https://github.com/lightly-ai/lightly/blob/master/lightly/transforms/dino_transform.py 22 | 23 | Input to this transform: 24 | PIL Image or Tensor. 25 | 26 | Output of this transform: 27 | List of Tensor of length 2 * global + n_local_views. (8 by default) 28 | 29 | Applies the following augmentations by default: 30 | - Random resized crop 31 | - Random horizontal flip 32 | - Color jitter 33 | - Random gray scale 34 | - Gaussian blur 35 | - Random solarization 36 | - ImageNet normalization 37 | 38 | This class generates two global and a user defined number of local views 39 | for each image in a batch. The code is adapted from [1]. 40 | 41 | - [0]: DINO, 2021, https://arxiv.org/abs/2104.14294 42 | - [1]: https://github.com/facebookresearch/dino 43 | 44 | Attributes: 45 | global_crop_size: 46 | Crop size of the global views. 47 | global_crop_scale: 48 | Tuple of min and max scales relative to global_crop_size. 49 | local_crop_size: 50 | Crop size of the local views. 51 | local_crop_scale: 52 | Tuple of min and max scales relative to local_crop_size. 53 | n_local_views: 54 | Number of generated local views. 55 | hf_prob: 56 | Probability that horizontal flip is applied. 57 | vf_prob: 58 | Probability that vertical flip is applied. 59 | rr_prob: 60 | Probability that random rotation is applied. 61 | rr_degrees: 62 | Range of degrees to select from for random rotation. If rr_degrees is None, 63 | images are rotated by 90 degrees. If rr_degrees is a (min, max) tuple, 64 | images are rotated by a random angle in [min, max]. If rr_degrees is a 65 | single number, images are rotated by a random angle in 66 | [-rr_degrees, +rr_degrees]. All rotations are counter-clockwise. 67 | cj_prob: 68 | Probability that color jitter is applied. 69 | cj_strength: 70 | Strength of the color jitter. `cj_bright`, `cj_contrast`, `cj_sat`, and 71 | `cj_hue` are multiplied by this value. 72 | cj_bright: 73 | How much to jitter brightness. 74 | cj_contrast: 75 | How much to jitter constrast. 76 | cj_sat: 77 | How much to jitter saturation. 78 | cj_hue: 79 | How much to jitter hue. 
80 | random_gray_scale: 81 | Probability of conversion to grayscale. 82 | gaussian_blur: 83 | Tuple of probabilities to apply gaussian blur on the different 84 | views. The input is ordered as follows: 85 | (global_view_0, global_view_1, local_views) 86 | kernel_size: 87 | Will be deprecated in favor of `sigmas` argument. If set, the old behavior applies and `sigmas` is ignored. 88 | Used to calculate sigma of gaussian blur with kernel_size * input_size. 89 | kernel_scale: 90 | Old argument. Value is deprecated in favor of sigmas. If set, the old behavior applies and `sigmas` is ignored. 91 | Used to scale the `kernel_size` of a factor of `kernel_scale` 92 | sigmas: 93 | Tuple of min and max value from which the std of the gaussian kernel is sampled. 94 | Is ignored if `kernel_size` is set. 95 | solarization: 96 | Probability to apply solarization on the second global view. 97 | normalize: 98 | Dictionary with 'mean' and 'std' for torchvision.transforms.Normalize. 99 | 100 | """ 101 | 102 | def __init__( 103 | self, 104 | global_crop_size: int = 224, 105 | global_crop_scale: Tuple[float, float] = (0.4, 1.0), 106 | local_crop_size: int = 96, 107 | local_crop_scale: Tuple[float, float] = (0.05, 0.4), 108 | n_local_views: int = 6, 109 | hf_prob: float = 0.5, 110 | vf_prob: float = 0, 111 | rr_prob: float = 0, 112 | rr_degrees: Union[None, float, Tuple[float, float]] = None, 113 | cj_prob: float = 0.8, 114 | cj_strength: float = 0.5, 115 | cj_bright: float = 0.8, 116 | cj_contrast: float = 0.8, 117 | cj_sat: float = 0.4, 118 | cj_hue: float = 0.2, 119 | random_gray_scale: float = 0.2, 120 | gaussian_blur: Tuple[float, float, float] = (1.0, 0.1, 0.5), 121 | kernel_size: Optional[float] = None, 122 | kernel_scale: Optional[float] = None, 123 | sigmas: Tuple[float, float] = (0.1, 2), 124 | solarization_prob: float = 0.2, 125 | normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, 126 | gauss_noise_prob: float = 0.0, 127 | inv_prob: float = 0.0, 128 | hist_norm_prob: float = 0.0 129 | ): 130 | # first global crop 131 | global_transform_0 = DINOViewTransform( 132 | crop_size=global_crop_size, 133 | crop_scale=global_crop_scale, 134 | hf_prob=hf_prob, 135 | vf_prob=vf_prob, 136 | rr_prob=rr_prob, 137 | rr_degrees=rr_degrees, 138 | cj_prob=cj_prob, 139 | cj_strength=cj_strength, 140 | cj_bright=cj_bright, 141 | cj_contrast=cj_contrast, 142 | cj_hue=cj_hue, 143 | cj_sat=cj_sat, 144 | random_gray_scale=random_gray_scale, 145 | gaussian_blur=gaussian_blur[0], 146 | kernel_size=kernel_size, 147 | kernel_scale=kernel_scale, 148 | sigmas=sigmas, 149 | solarization_prob=0, 150 | normalize=normalize, 151 | gauss_noise_prob = 0, 152 | inv_prob = 0, 153 | ) 154 | 155 | # second global crop 156 | global_transform_1 = DINOViewTransform( 157 | crop_size=global_crop_size, 158 | crop_scale=global_crop_scale, 159 | hf_prob=hf_prob, 160 | vf_prob=vf_prob, 161 | rr_prob=rr_prob, 162 | rr_degrees=rr_degrees, 163 | cj_prob=cj_prob, 164 | cj_bright=cj_bright, 165 | cj_contrast=cj_contrast, 166 | cj_hue=cj_hue, 167 | cj_sat=cj_sat, 168 | random_gray_scale=random_gray_scale, 169 | gaussian_blur=gaussian_blur[1], 170 | kernel_size=kernel_size, 171 | kernel_scale=kernel_scale, 172 | sigmas=sigmas, 173 | solarization_prob=solarization_prob, 174 | normalize=normalize, 175 | gauss_noise_prob = gauss_noise_prob, 176 | inv_prob = inv_prob, 177 | hist_norm_prob = hist_norm_prob, 178 | ) 179 | 180 | # transformation for the local small crops 181 | local_transform = DINOViewTransform( 182 | crop_size=local_crop_size, 
183 | crop_scale=local_crop_scale, 184 | hf_prob=hf_prob, 185 | vf_prob=vf_prob, 186 | rr_prob=rr_prob, 187 | rr_degrees=rr_degrees, 188 | cj_prob=cj_prob, 189 | cj_strength=cj_strength, 190 | cj_bright=cj_bright, 191 | cj_contrast=cj_contrast, 192 | cj_hue=cj_hue, 193 | cj_sat=cj_sat, 194 | random_gray_scale=random_gray_scale, 195 | gaussian_blur=gaussian_blur[2], 196 | kernel_size=kernel_size, 197 | kernel_scale=kernel_scale, 198 | sigmas=sigmas, 199 | solarization_prob=0, 200 | normalize=normalize, 201 | gauss_noise_prob = gauss_noise_prob, 202 | inv_prob = inv_prob, 203 | hist_norm_prob = hist_norm_prob, 204 | ) 205 | local_transforms = [local_transform] * n_local_views 206 | transforms = [global_transform_0, global_transform_1] 207 | transforms.extend(local_transforms) 208 | super().__init__(transforms) 209 | 210 | 211 | class DINOViewTransform: 212 | def __init__( 213 | self, 214 | crop_size: int = 224, 215 | crop_scale: Tuple[float, float] = (0.4, 1.0), 216 | hf_prob: float = 0.5, 217 | vf_prob: float = 0, 218 | rr_prob: float = 0, 219 | rr_degrees: Union[None, float, Tuple[float, float]] = None, 220 | cj_prob: float = 0.8, 221 | cj_strength: float = 0.5, 222 | cj_bright: float = 0.8, 223 | cj_contrast: float = 0.8, 224 | cj_sat: float = 0.4, 225 | cj_hue: float = 0.2, 226 | random_gray_scale: float = 0.2, 227 | gaussian_blur: float = 1.0, 228 | kernel_size: Optional[float] = None, 229 | kernel_scale: Optional[float] = None, 230 | sigmas: Tuple[float, float] = (0.1, 2), 231 | solarization_prob: float = 0.2, 232 | normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, 233 | gauss_noise_prob: float = 0.0, 234 | inv_prob: float = 0.0, 235 | hist_norm_prob: float = 0.0 236 | ): 237 | transform = [ 238 | T.RandomResizedCrop( 239 | size=crop_size, 240 | scale=crop_scale, 241 | interpolation=PIL.Image.BICUBIC, 242 | ), 243 | T.RandomHorizontalFlip(p=hf_prob), 244 | T.RandomVerticalFlip(p=vf_prob), 245 | random_rotation_transform(rr_prob=rr_prob, rr_degrees=rr_degrees), 246 | T.RandomApply([augmentations.HistogramNormalize()], p = hist_norm_prob), 247 | T.RandomApply( 248 | [ 249 | T.ColorJitter( 250 | brightness=cj_strength * cj_bright, 251 | contrast=cj_strength * cj_contrast, 252 | saturation=cj_strength * cj_sat, 253 | hue=cj_strength * cj_hue, 254 | ) 255 | ], 256 | p=cj_prob, 257 | ), 258 | T.RandomGrayscale(p=random_gray_scale), 259 | GaussianBlur( 260 | kernel_size=kernel_size, 261 | scale=kernel_scale, 262 | sigmas=sigmas, 263 | prob=gaussian_blur, 264 | ), 265 | RandomSolarization(prob=solarization_prob), 266 | T.RandomInvert(p=inv_prob), 267 | T.ToTensor(), 268 | 269 | ] 270 | 271 | transform += [T.RandomApply([augmentations.GaussianNoise()], p = gauss_noise_prob)] 272 | 273 | if normalize: 274 | transform += [T.Normalize(mean=normalize["mean"], std=normalize["std"])] 275 | 276 | self.transform = T.Compose(transform) 277 | 278 | def __call__(self, image: Union[Tensor, Image]) -> Tensor: 279 | """ 280 | Applies the transforms to the input image. 281 | 282 | Args: 283 | image: 284 | The input image to apply the transforms to. 285 | 286 | Returns: 287 | The transformed image. 
288 | 289 | """ 290 | transformed: Tensor = self.transform(image) 291 | return transformed 292 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/vis/vis_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | from matplotlib.colors import ListedColormap 4 | from matplotlib.cm import get_cmap 5 | import torch 6 | from skimage.color import label2rgb 7 | import cv2 8 | from pathlib import Path 9 | import numpy as np 10 | from PIL import Image 11 | from torchvision.utils import draw_bounding_boxes 12 | from extract import extract_utils as utils 13 | from torchvision import transforms 14 | from torch import nn 15 | 16 | torch.cuda.empty_cache() 17 | from torch.cuda.amp import autocast 18 | 19 | import argparse 20 | 21 | def plot_segmentation( 22 | images_list: str, 23 | images_root: str, 24 | segmentations_dir: str, 25 | bbox_file: str = None, 26 | output_dir: str = "./output_plots/segm" 27 | ): 28 | utils.make_output_dir(output_dir, check_if_empty=False) 29 | 30 | # Inputs 31 | image_paths = [] 32 | segmap_paths = [] 33 | images_root = Path(images_root) 34 | segmentations_dir = Path(segmentations_dir) 35 | for image_file in Path(images_list).read_text().splitlines(): 36 | segmap_file = f'{Path(image_file).stem}.png' 37 | image_paths.append(images_root / image_file) 38 | segmap_paths.append(segmentations_dir / segmap_file) 39 | print(f'Found {len(image_paths)} image and segmap paths') 40 | 41 | # Load optional bounding boxes 42 | if bbox_file is not None: 43 | bboxes_list = torch.load(bbox_file) 44 | 45 | # Colors 46 | colors = get_cmap('tab20', 21).colors[:, :3] 47 | 48 | # Load 49 | for i, (image_path, segmap_path) in enumerate(zip(image_paths, segmap_paths)): 50 | image_id = image_path.stem 51 | 52 | # Load 53 | image = np.array(Image.open(image_path).convert('RGB')) 54 | segmap = np.array(Image.open(segmap_path)) 55 | 56 | # Convert binary 57 | if set(np.unique(segmap).tolist()) == {0, 255}: 58 | segmap[segmap == 255] = 1 59 | 60 | # Resize 61 | segmap_fullres = cv2.resize(segmap, dsize=image.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 62 | 63 | # Only view images with a specific class 64 | # which_index = 1 65 | # if which_index not in np.unique(segmap): 66 | # continue 67 | 68 | fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 6)) 69 | 70 | # Load optional bounding boxes 71 | bboxes = None 72 | if bbox_file is not None: 73 | bboxes = torch.tensor(bboxes_list[i]['bboxes_original_resolution']) 74 | assert bboxes_list[i]['id'] == image_id, f"{bboxes_list[i]['id']} but {image_id}" 75 | image_torch = torch.from_numpy(image).permute(2, 0, 1) 76 | image_with_boxes_torch = draw_bounding_boxes(image_torch, bboxes) 77 | image_with_boxes = image_with_boxes_torch.permute(1, 2, 0).numpy() 78 | 79 | axes[0].imshow(image_with_boxes) 80 | axes[0].set_title('Image with Bounding Boxes') 81 | axes[0].axis('off') 82 | else: 83 | axes[0].imshow(image) 84 | axes[0].set_title('Image') 85 | axes[0].axis('off') 86 | 87 | 88 | # Color 89 | segmap_label_indices, segmap_label_counts = np.unique(segmap, return_counts=True) 90 | blank_segmap_overlay = label2rgb(label=segmap_fullres, image=np.full_like(image, 128), 91 | colors=colors[segmap_label_indices[segmap_label_indices != 0]], bg_label=0, alpha=1.0) 92 | image_segmap_overlay = label2rgb(label=segmap_fullres, image=image, 93 | colors=colors[segmap_label_indices[segmap_label_indices != 0]], bg_label=0, 
alpha=0.45) 94 | segmap_caption = dict(zip(segmap_label_indices.tolist(), (segmap_label_counts).tolist())) 95 | 96 | # Visualization of blank segmap overlay 97 | axes[1].imshow(blank_segmap_overlay) 98 | axes[1].set_title('Blank Segmentation Overlay') 99 | axes[1].axis('off') 100 | 101 | # Visualization of colored image 102 | axes[2].imshow(image_segmap_overlay) 103 | axes[2].set_title('Image with Segmentation Overlay') 104 | axes[2].axis('off') 105 | 106 | plt.tight_layout() 107 | 108 | # Save the plot 109 | output_filename = os.path.join(output_dir, f"{image_id}.png") 110 | plt.savefig(output_filename) 111 | # plt.close(fig) 112 | 113 | print(f"Plots saved in the output directory: {output_dir}") 114 | 115 | def plot_eigenvectors( 116 | images_list: str, 117 | images_root: str, 118 | eigenvevtors_dir: str, 119 | features_dir: str, 120 | output_dir: str = "./output_plots/eigen" 121 | ): 122 | utils.make_output_dir(output_dir, check_if_empty=False) 123 | 124 | # Inputs 125 | image_paths = [] 126 | eigen_paths = [] 127 | feat_paths = [] 128 | images_root = Path(images_root) 129 | eigenvevtors_dir = Path(eigenvevtors_dir) 130 | features_dir = Path(features_dir) 131 | for image_file in Path(images_list).read_text().splitlines(): 132 | file = f'{Path(image_file).stem}.pth' 133 | image_paths.append(images_root / image_file) 134 | eigen_paths.append(eigenvevtors_dir / file) 135 | feat_paths.append(features_dir / file) 136 | print(f'Found {len(image_paths)} image and eigen paths') 137 | 138 | # Load 139 | for i, (image_path, feat_path, eigen_path) in enumerate(zip(image_paths, feat_paths, eigen_paths)): 140 | image_id = image_path.stem 141 | print(image_id) 142 | 143 | # Load data dictionary 144 | image = np.array(Image.open(image_path).convert('RGB')) 145 | data_dict = torch.load(feat_path, map_location='cpu') 146 | data_dict.update(torch.load(eigen_path, map_location='cpu')) 147 | eigenvec_num = len(data_dict['eigenvectors']) 148 | eigenvectors = data_dict['eigenvectors'][:eigenvec_num].numpy() 149 | # print(eigenvectors.shape) 150 | 151 | # Reshape eigenvevtors 152 | B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict) 153 | eigenvectors_img=eigenvectors.reshape(eigenvec_num, H_patch, W_patch) 154 | 155 | # Plot 156 | fig, axes = plt.subplots(nrows=2, ncols=eigenvec_num//2+1, figsize=(15, 5)) 157 | for i, eigv_ax_pair in enumerate(zip(axes.flatten(),eigenvectors_img)): 158 | a, eigv = eigv_ax_pair 159 | a.imshow(eigv) 160 | a.title.set_text("eigv "+str(i)) 161 | 162 | for a in axes.flatten(): 163 | a.axis('off') 164 | 165 | plt.tight_layout() 166 | 167 | # Save the plot 168 | output_filename = os.path.join(output_dir, f"{image_id}_eigenvectors.png") 169 | plt.savefig(output_filename) 170 | 171 | # Close plot 172 | plt.close() 173 | 174 | print(f"Plots saved in the output directory: {output_dir}") 175 | 176 | def plot_dino_attn_maps( 177 | images_list: str, 178 | images_root: str, 179 | model_name: str, 180 | model_checkpoint: str = "", 181 | output_dir: str = "./output_plots/dino_attn_maps" 182 | ): 183 | utils.make_output_dir(output_dir, check_if_empty=False) 184 | 185 | 186 | # Inputs 187 | image_paths = [] 188 | images_root = Path(images_root) 189 | for image_file in Path(images_list).read_text().splitlines(): 190 | image_paths.append(images_root / image_file) 191 | print(f'Found {len(image_paths)} image paths') 192 | 193 | 194 | # Get the model 195 | if model_checkpoint=="" or model_checkpoint==None: 196 | model, val_transform, patch_size, nh = 
utils.get_model(model_name) 197 | else: 198 | model, val_transform, patch_size, nh = utils.get_model_from_checkpoint(model_name, model_checkpoint, just_backbone=True) 199 | 200 | # disable grad 201 | for p in model.parameters(): 202 | p.requires_grad = False 203 | 204 | # put model to cuda device to 205 | model = model.to('cuda') 206 | 207 | # Define transforms 208 | # TODO: check with val_transform 209 | transform = transforms.Compose([ 210 | transforms.ToTensor(), 211 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 212 | ]) 213 | 214 | # Load 215 | for i, image_path in enumerate(image_paths): 216 | image_id = image_path.stem 217 | print(image_id) 218 | 219 | # Load image - sample for processing, input_img for plotting the original image 220 | sample = np.array(Image.open(image_path).convert('RGB')) 221 | input_img = sample 222 | # Convert PIL Image to NumPy array and transpose dimensions 223 | input_img = np.array(input_img).transpose((2, 0, 1)) # Transpose to (channels, height, width) 224 | 225 | # Apply transform 226 | sample = transform(sample) 227 | # print(f'sample.shape={sample.shape}') 228 | 229 | # Plot 230 | w = sample.shape[1] - sample.shape[1] % patch_size 231 | h = sample.shape[2] - sample.shape[2] % patch_size 232 | sample = sample[:, :w, :h].unsqueeze(0) 233 | w_featmap = sample.shape[-2] // patch_size 234 | h_featmap = sample.shape[-1] // patch_size 235 | 236 | # move image to device 237 | sample = sample.to('cuda') 238 | 239 | # get self-attention 240 | with torch.cuda.amp.autocast(): 241 | with torch.no_grad(): 242 | torch.cuda.empty_cache() 243 | attentions = model.get_last_selfattention(sample) 244 | 245 | # we keep only the output patch attention 246 | if 'dinov2' in model_name: 247 | # in dinov2, attentions return tensor with 3 dimensions, if xformers is enabled (make sure export XFORMERS_DISABLED=True) 248 | # If xformers is disabled, the commented code below is not needed 249 | # attentions = torch.unsqueeze(attentions, 1) 250 | # attentions.fill_(nh) 251 | # print(f'attentions.shape={attentions.shape}') 252 | if 'reg' in model_name: 253 | attentions = attentions[0, :, 0, 1+4:].reshape(nh, -1) 254 | else: 255 | attentions = attentions[0, :, 0, 1:].reshape(nh, -1) 256 | else: 257 | attentions = attentions[0, :, 0, 1:].reshape(nh, -1) 258 | 259 | # we keep only a certain percentage of the mass 260 | val, idx = torch.sort(attentions) 261 | val /= torch.sum(val, dim=1, keepdim=True) 262 | cumval = torch.cumsum(val, dim=1) 263 | 264 | threshold = 0.6 # We visualize masks obtained by thresholding the self-attention maps to keep xx% of the mass. 
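# Per head, the block below keeps the smallest set of patches whose attention weights sum to
# `threshold` (60%) of the total mass: `val`/`idx` are the ascending-sorted weights and their
# original positions, `cumval` is their cumulative sum, so `cumval > (1 - threshold)` marks the
# top patches; `idx2 = argsort(idx)` then scatters that mask back to the original patch order
# before it is reshaped to the (w_featmap, h_featmap) grid and upsampled by `patch_size`.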
265 | th_attn = cumval > (1 - threshold) 266 | idx2 = torch.argsort(idx) 267 | for head in range(nh): 268 | th_attn[head] = th_attn[head][idx2[head]] 269 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() 270 | 271 | # interpolate 272 | th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().detach().numpy() 273 | attentions = attentions.reshape(nh, w_featmap, h_featmap) 274 | attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().detach().numpy() 275 | attentions_mean = np.mean(attentions, axis=0) 276 | 277 | fig = plt.figure(figsize=(6, 6), dpi=200) 278 | ax = fig.add_subplot(3, 3, 1) 279 | ax.set_title("Input") 280 | ax.imshow(np.transpose(input_img, (1, 2, 0))) 281 | ax.axis("off") 282 | 283 | # visualize self-attention of each head 284 | for i in range(6): 285 | ax = fig.add_subplot(3, 3, i + 4) 286 | ax.set_title("Head " + str(i + 1)) 287 | ax.imshow(attentions[i]) 288 | ax.axis("off") 289 | 290 | ax = fig.add_subplot(3, 3, 2) 291 | ax.set_title("Head Mean") 292 | ax.imshow(attentions_mean) 293 | ax.axis("off") 294 | 295 | fig.tight_layout() 296 | 297 | # Save the plot 298 | output_filename = os.path.join(output_dir, f"{image_id}_{model_name}_attn_maps.png") 299 | fig.savefig(output_filename) 300 | 301 | # Close plot 302 | plt.close() 303 | 304 | print(f"Plots saved in the output directory: {output_dir}") 305 | 306 | 307 | if __name__ == "__main__": 308 | parser = argparse.ArgumentParser(description='Plot DINO Attention Maps') 309 | parser.add_argument('--images_list', type=str, required=True, help='Path to the file containing the list of image filenames') 310 | parser.add_argument('--images_root', type=str, required=True, help='Root directory of the images') 311 | parser.add_argument('--model_checkpoint', type=str, required=False, help='Path to the DINO model checkpoint') 312 | parser.add_argument('--model_name', type=str, required=True, help='Name of the DINO model') 313 | parser.add_argument('--output_dir', type=str, default='./output_plots/dino_attn_maps', help='Output directory for saving plots') 314 | 315 | args = parser.parse_args() 316 | 317 | # Print GPU memory summary 318 | print(torch.cuda.memory_summary()) 319 | 320 | # Call the function with command-line arguments 321 | plot_dino_attn_maps(images_list=args.images_list, 322 | images_root=args.images_root, 323 | model_checkpoint=args.model_checkpoint, 324 | model_name=args.model_name, 325 | output_dir=args.output_dir) 326 | 327 | -------------------------------------------------------------------------------- /deep-spectral-segmentation/semantic-segmentation/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions, including distributed helpers, mostly from torchvision 3 | """ 4 | import time 5 | import datetime 6 | import random 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | import torchvision 11 | from dataclasses import dataclass 12 | from collections import defaultdict, deque 13 | from typing import Callable, Optional 14 | from PIL import Image 15 | from accelerate import Accelerator 16 | from omegaconf import DictConfig 17 | 18 | 19 | @dataclass 20 | class TrainState: 21 | epoch: int = 0 22 | step: int = 0 23 | best_val: Optional[float] = None 24 | 25 | 26 | def get_optimizer(cfg: DictConfig, model: torch.nn.Module, accelerator: Accelerator) -> torch.optim.Optimizer: 27 | # Determine the learning rate 28 
| if cfg.optimizer.scale_learning_rate_with_batch_size: 29 | lr = accelerator.state.num_processes * cfg.data.loader.batch_size * cfg.optimizer.base_lr 30 | print('lr = {ws} (num gpus) * {bs} (batch_size) * {blr} (base learning rate) = {lr}'.format( 31 | ws=accelerator.state.num_processes, bs=cfg.data.loader.batch_size, blr=cfg.lr, lr=lr)) 32 | else: # scale base learning rate by batch size 33 | lr = cfg.lr 34 | print('lr = {lr} (absolute learning rate)'.format(lr=lr)) 35 | # Construct optimizer 36 | if cfg.optimizer.kind == 'torch': 37 | parameters = [p for p in model.parameters() if p.requires_grad] 38 | optimizer = getattr(torch.optim, cfg.optimizer.cls)(parameters, lr=lr, **cfg.optimizer.kwargs) 39 | elif cfg.optimizer.kind == 'timm': 40 | from timm.optim import create_optimizer_v2 41 | optimizer = create_optimizer_v2(model, lr=lr, **cfg.optimizer.kwargs) 42 | elif cfg.optimizer.kind == 'transformers': 43 | import transformers 44 | parameters = [p for p in model.parameters() if p.requires_grad] 45 | optimizer = getattr(transformers, cfg.optimizer.name)(parameters, lr=lr, **cfg.optimizer.kwargs) 46 | else: 47 | raise NotImplementedError(f'invalid optimizer config: {cfg.optimizer}') 48 | return optimizer 49 | 50 | 51 | def get_scheduler(cfg: DictConfig, optimizer: torch.optim.Optimizer) -> Callable: 52 | if cfg.scheduler.kind == 'torch': 53 | Sch = getattr(torch.optim.lr_scheduler, cfg.scheduler.cls) 54 | scheduler = Sch(optimizer=optimizer, **cfg.scheduler.kwargs) 55 | if cfg.scheduler.warmup: 56 | from warmup_scheduler import GradualWarmupScheduler 57 | scheduler = GradualWarmupScheduler( # wrap scheduler with warmup 58 | optimizer, multiplier=1, total_epoch=cfg.scheduler.warmup, after_scheduler=scheduler) 59 | elif cfg.scheduler.kind == 'timm': 60 | from timm.scheduler import create_scheduler 61 | scheduler, _ = create_scheduler(optimizer=optimizer, args=cfg.scheduler.kwargs) 62 | elif cfg.scheduler.kind == 'transformers': 63 | from transformers import get_scheduler 64 | scheduler = get_scheduler(optimizer=optimizer, **cfg.scheduler.kwargs) 65 | else: 66 | raise NotImplementedError(f'invalid scheduler config: {cfg.scheduler}') 67 | return scheduler 68 | 69 | 70 | @torch.no_grad() 71 | def accuracy(output, target, topk=(1,)): 72 | """Computes the accuracy over the k top predictions for the specified values of k""" 73 | # reshape 74 | target = target.reshape(-1) 75 | output = output.reshape(target.size(0), -1) 76 | 77 | maxk = max(topk) 78 | batch_size = target.size(0) 79 | 80 | _, pred = output.topk(maxk, 1, True, True) 81 | pred = pred.t() 82 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 83 | 84 | res = [] 85 | for k in topk: 86 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 87 | res.append(correct_k.mul_(100.0 / batch_size)) 88 | return res 89 | 90 | 91 | class SmoothedValue(object): 92 | """Track a series of values and provide access to smoothed values over a 93 | window or the global series average. 94 | """ 95 | 96 | def __init__(self, window_size=20, fmt=None): 97 | if fmt is None: 98 | fmt = "{median:.4f} ({global_avg:.4f})" 99 | self.deque = deque(maxlen=window_size) 100 | self.total = 0.0 101 | self.count = 0 102 | self.fmt = fmt 103 | 104 | def update(self, value, n=1): 105 | self.deque.append(value) 106 | self.count += n 107 | self.total += value * n 108 | 109 | def synchronize_between_processes(self, device='cuda'): 110 | """ 111 | Warning: does not synchronize the deque! 
112 | """ 113 | if not using_distributed(): 114 | return 115 | print(f"device={device}") 116 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device) 117 | dist.barrier() 118 | dist.all_reduce(t) 119 | t = t.tolist() 120 | self.count = int(t[0]) 121 | self.total = t[1] 122 | 123 | @property 124 | def median(self): 125 | d = torch.tensor(list(self.deque)) 126 | return d.median().item() 127 | 128 | @property 129 | def avg(self): 130 | d = torch.tensor(list(self.deque), dtype=torch.float32) 131 | return d.mean().item() 132 | 133 | @property 134 | def global_avg(self): 135 | return self.total / self.count 136 | 137 | @property 138 | def max(self): 139 | return max(self.deque) 140 | 141 | @property 142 | def value(self): 143 | return self.deque[-1] 144 | 145 | def __str__(self): 146 | return self.fmt.format( 147 | median=self.median, 148 | avg=self.avg, 149 | global_avg=self.global_avg, 150 | max=self.max, 151 | value=self.value) 152 | 153 | 154 | class MetricLogger(object): 155 | def __init__(self, delimiter="\t"): 156 | self.meters = defaultdict(SmoothedValue) 157 | self.delimiter = delimiter 158 | 159 | def update(self, **kwargs): 160 | n = kwargs.pop('n', 1) 161 | for k, v in kwargs.items(): 162 | if isinstance(v, torch.Tensor): 163 | v = v.item() 164 | assert isinstance(v, (float, int)) 165 | self.meters[k].update(v, n=n) 166 | 167 | def __getattr__(self, attr): 168 | if attr in self.meters: 169 | return self.meters[attr] 170 | if attr in self.__dict__: 171 | return self.__dict__[attr] 172 | raise AttributeError("'{}' object has no attribute '{}'".format( 173 | type(self).__name__, attr)) 174 | 175 | def __str__(self): 176 | loss_str = [] 177 | for name, meter in self.meters.items(): 178 | loss_str.append( 179 | "{}: {}".format(name, str(meter)) 180 | ) 181 | return self.delimiter.join(loss_str) 182 | 183 | def synchronize_between_processes(self, device='cuda'): 184 | for meter in self.meters.values(): 185 | meter.synchronize_between_processes(device=device) 186 | 187 | def add_meter(self, name, meter): 188 | self.meters[name] = meter 189 | 190 | def log_every(self, iterable, print_freq, header=None): 191 | i = 0 192 | if not header: 193 | header = '' 194 | start_time = time.time() 195 | end = time.time() 196 | iter_time = SmoothedValue(fmt='{avg:.4f}') 197 | data_time = SmoothedValue(fmt='{avg:.4f}') 198 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 199 | log_msg = [ 200 | header, 201 | '[{0' + space_fmt + '}/{1}]', 202 | 'eta: {eta}', 203 | '{meters}', 204 | 'time: {time}', 205 | 'data: {data}' 206 | ] 207 | if torch.cuda.is_available(): 208 | log_msg.append('max mem: {memory:.0f}') 209 | log_msg = self.delimiter.join(log_msg) 210 | MB = 1024.0 * 1024.0 211 | for obj in iterable: 212 | data_time.update(time.time() - end) 213 | yield obj 214 | iter_time.update(time.time() - end) 215 | if i % print_freq == 0 or i == len(iterable) - 1: 216 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 217 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 218 | if torch.cuda.is_available(): 219 | print(log_msg.format( 220 | i, len(iterable), eta=eta_string, 221 | meters=str(self), 222 | time=str(iter_time), data=str(data_time), 223 | memory=torch.cuda.max_memory_allocated() / MB)) 224 | else: 225 | print(log_msg.format( 226 | i, len(iterable), eta=eta_string, 227 | meters=str(self), 228 | time=str(iter_time), data=str(data_time))) 229 | i += 1 230 | end = time.time() 231 | total_time = time.time() - start_time 232 | total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) 233 | print('{} Total time: {} ({:.4f} s / it)'.format( 234 | header, total_time_str, total_time / len(iterable))) 235 | 236 | 237 | class NormalizeInverse(torchvision.transforms.Normalize): 238 | """ 239 | Undoes the normalization and returns the reconstructed images in the input domain. 240 | """ 241 | 242 | def __init__(self, mean, std): 243 | mean = torch.as_tensor(mean) 244 | std = torch.as_tensor(std) 245 | std_inv = 1 / (std + 1e-7) 246 | mean_inv = -mean * std_inv 247 | super().__init__(mean=mean_inv, std=std_inv) 248 | 249 | def __call__(self, tensor): 250 | return super().__call__(tensor.clone()) 251 | 252 | 253 | def set_requires_grad(module, requires_grad=True): 254 | for p in module.parameters(): 255 | p.requires_grad = requires_grad 256 | 257 | 258 | def resume_from_checkpoint(cfg, model, optimizer=None, scheduler=None, model_ema=None): 259 | 260 | # Resume model state dict 261 | checkpoint = torch.load(cfg.checkpoint.resume, map_location='cpu') 262 | if 'model' in checkpoint: 263 | state_dict, key = checkpoint['model'], 'model' 264 | else: 265 | state_dict, key = checkpoint, 'N/A' 266 | if any(k.startswith('module.') for k in state_dict.keys()): 267 | state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} 268 | print('Removed "module." from checkpoint state dict') 269 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 270 | print(f'Loaded model checkpoint key {key} from {cfg.checkpoint.resume}') 271 | if len(missing_keys): 272 | print(f' - Missing_keys: {missing_keys}') 273 | if len(unexpected_keys): 274 | print(f' - Unexpected_keys: {unexpected_keys}') 275 | # Resume model ema 276 | if cfg.ema.use_ema: 277 | if checkpoint['model_ema']: 278 | model_ema.load_state_dict(checkpoint['model_ema']) 279 | print('Loaded model ema from checkpoint') 280 | else: 281 | model_ema.load_state_dict(model.parameters()) 282 | print('No model ema in checkpoint; loaded current parameters into model') 283 | else: 284 | if 'model_ema' in checkpoint: 285 | print('Not using model ema, but model_ema found in checkpoint (you probably want to resume it!)') 286 | else: 287 | print('Not using model ema, and no model_ema found in checkpoint.') 288 | 289 | # Resume optimization state 290 | if cfg.checkpoint.resume_training and 'train' in cfg.job_type: 291 | if 'steps' in checkpoint: 292 | checkpoint['step'] = checkpoint['steps'] 293 | assert {'optimizer', 'scheduler', 'epoch', 'step', 'best_val'}.issubset(set(checkpoint.keys())) 294 | optimizer.load_state_dict(checkpoint['optimizer']) 295 | scheduler.load_state_dict(checkpoint['scheduler']) 296 | epoch, step, best_val = checkpoint['epoch'] + 1, checkpoint['step'], checkpoint['best_val'] 297 | train_state = TrainState(epoch=epoch, step=step, best_val=best_val) 298 | print(f'Loaded optimizer/scheduler at epoch {epoch} from checkpoint') 299 | elif cfg.checkpoint.resume_optimizer_only: 300 | assert 'optimizer' in set(checkpoint.keys()) 301 | optimizer.load_state_dict(checkpoint['optimizer']) 302 | print(f'Loaded optimizer from checkpoint, but did not load scheduler/epoch') 303 | else: 304 | train_state = TrainState() 305 | print('Did not resume training (i.e. 
optimizer/scheduler/epoch)') 306 | 307 | return train_state 308 | 309 | 310 | def setup_distributed_print(is_master): 311 | """ 312 | This function disables printing when not in the master process. 313 | """ 314 | import builtins as __builtin__ 315 | builtin_print = __builtin__.print 316 | 317 | def print(*args, **kwargs): 318 | force = kwargs.pop('force', False) 319 | if is_master or force: 320 | builtin_print(*args, **kwargs) 321 | 322 | __builtin__.print = print 323 | 324 | 325 | def using_distributed(): 326 | return dist.is_available() and dist.is_initialized() 327 | 328 | 329 | def get_rank(): 330 | return dist.get_rank() if using_distributed() else 0 331 | 332 | 333 | def set_seed(seed): 334 | rank = get_rank() 335 | seed = seed + rank 336 | torch.manual_seed(seed) 337 | torch.cuda.manual_seed(seed) 338 | np.random.seed(seed) 339 | random.seed(seed) 340 | torch.backends.cudnn.enabled = True 341 | torch.backends.cudnn.benchmark = True 342 | if using_distributed(): 343 | print(f'Seeding node {rank} with seed {seed}', force=True) 344 | else: 345 | print(f'Seeding node {rank} with seed {seed}') 346 | 347 | 348 | def tensor_to_pil(image: torch.Tensor): 349 | assert len(image.shape) == 3 and image.shape[0] == 3, f"{image.shape=}" 350 | image = (image.float() * 0.5 + 0.5).clamp(0, 1).detach().cpu().requires_grad_(False) 351 | ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 352 | return Image.fromarray(ndarr) 353 | 354 | 355 | def albumentations_to_torch(transform): 356 | def _transform(img, target): 357 | augmented = transform(image=img, mask=target) 358 | return augmented['image'], augmented['mask'] 359 | return _transform 360 | --------------------------------------------------------------------------------
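A minimal, hypothetical sketch of how the logging helpers in util.py are typically driven from a training loop; the loader and loss below are placeholders, and the import assumes the semantic-segmentation directory is on the Python path (the real entry point is the training script, not this file):

import torch
from util import MetricLogger, SmoothedValue, set_seed  # assumes util.py is importable from the CWD

set_seed(42)  # seeds torch/numpy/random (offset by rank when running distributed)

logger = MetricLogger(delimiter="  ")
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

# stand-in for a real dataloader: 10 batches of (images, targets)
loader = [(torch.randn(2, 3, 224, 224), torch.randint(0, 2, (2,))) for _ in range(10)]

for images, targets in logger.log_every(loader, print_freq=5, header='Epoch: [0]'):
    loss = images.mean()                      # placeholder for the real training loss
    logger.update(loss=loss.item(), lr=1e-4)  # SmoothedValue keeps windowed and global averages

print('Averaged stats:', logger)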