├── .gitignore ├── Data └── splits │ ├── Brats21_test.csv │ ├── Brats21_val.csv │ ├── IXI_test.csv │ ├── IXI_train_fold0.csv │ ├── IXI_train_fold1.csv │ ├── IXI_train_fold2.csv │ ├── IXI_train_fold3.csv │ ├── IXI_train_fold4.csv │ ├── IXI_val_fold0.csv │ ├── IXI_val_fold1.csv │ ├── IXI_val_fold2.csv │ ├── IXI_val_fold3.csv │ ├── IXI_val_fold4.csv │ ├── MSLUB_test.csv │ ├── MSLUB_val.csv │ └── avail_t2.csv ├── README.md ├── cDDPM_Model.png ├── configs ├── callbacks │ └── checkpoint.yaml ├── config.yaml ├── datamodule │ └── IXI.yaml ├── experiment │ └── cDDPM │ │ ├── DDPM.yaml │ │ ├── DDPM_cond_spark_2D.yaml │ │ ├── DDPM_patched.yaml │ │ └── Spark_2D_pretrain.yaml ├── logger │ ├── csv.yaml │ └── wandb.yaml ├── mode │ └── default.yaml ├── model │ ├── DDPM_2D.yaml │ ├── DDPM_2D_patched.yaml │ └── Spark_2D.yaml └── trainer │ └── default.yaml ├── environment.yml ├── pc_environment.env ├── preprocessing ├── cut.py ├── extract_masks.py ├── get_mask.py ├── n4filter.py ├── prepare_Brats21.sh ├── prepare_IXI.sh ├── prepare_MSLUB.sh ├── registration.py ├── replace.py ├── resample.py └── sri_atlas │ ├── LICENSE.sri24 │ └── templates │ ├── EPI.nii │ ├── EPI_brain.nii │ ├── PD.nii │ ├── PD_brain.nii │ ├── T1.nii │ ├── T1_brain.nii │ ├── T2.nii │ └── T2_brain.nii ├── requirements.txt ├── run.py └── src ├── __init__.py ├── __pycache__ ├── __init__.cpython-39.pyc └── train.cpython-39.pyc ├── datamodules ├── Datamodules_eval.py ├── Datamodules_train.py ├── __init__.py ├── __pycache__ │ ├── Datamodules_eval.cpython-39.pyc │ ├── Datamodules_train.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ └── create_dataset.cpython-39.pyc └── create_dataset.py ├── models ├── DDPM_2D.py ├── DDPM_2D_patched.py ├── LDM │ ├── __pycache__ │ │ └── util.cpython-39.pyc │ ├── lr_scheduler.py │ ├── models │ │ ├── __pycache__ │ │ │ ├── autoencoder.cpython-38.pyc │ │ │ └── autoencoder.cpython-39.pyc │ │ ├── autoencoder.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── ddim.cpython-38.pyc │ │ │ ├── ddim.cpython-39.pyc │ │ │ ├── ddpm.cpython-38.pyc │ │ │ └── ddpm.cpython-39.pyc │ │ │ ├── classifier.py │ │ │ ├── ddim.py │ │ │ ├── ddpm.py │ │ │ ├── ddpm_class.py │ │ │ └── plms.py │ ├── modules │ │ ├── __pycache__ │ │ │ ├── attention.cpython-38.pyc │ │ │ ├── attention.cpython-39.pyc │ │ │ ├── ema.cpython-38.pyc │ │ │ ├── ema.cpython-39.pyc │ │ │ └── x_transformer.cpython-38.pyc │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ ├── model.cpython-38.pyc │ │ │ │ ├── model.cpython-39.pyc │ │ │ │ ├── openaimodel.cpython-38.pyc │ │ │ │ ├── openaimodel.cpython-39.pyc │ │ │ │ ├── util.cpython-38.pyc │ │ │ │ └── util.cpython-39.pyc │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ ├── distributions.cpython-38.pyc │ │ │ │ └── distributions.cpython-39.pyc │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ ├── modules.cpython-38.pyc │ │ │ │ └── modules.cpython-39.pyc │ │ │ └── modules.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ ├── losses │ 
│ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── contperceptual.cpython-38.pyc │ │ │ │ └── vqperceptual.cpython-38.pyc │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ └── x_transformer.py │ └── util.py ├── Spark_2D.py ├── __init__.py ├── __pycache__ │ ├── DDPM_2D.cpython-39.pyc │ ├── Spark_2D.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ └── losses.cpython-39.pyc ├── losses.py └── modules │ ├── DDPM_encoder.py │ ├── OpenAI_Unet.py │ ├── __pycache__ │ ├── DDPM_encoder.cpython-39.pyc │ ├── OpenAI_Unet.cpython-39.pyc │ └── cond_DDPM.cpython-39.pyc │ ├── cond_DDPM.py │ └── spark │ ├── Spark_2D.py │ ├── __pycache__ │ ├── Spark_2D.cpython-39.pyc │ ├── decoder.cpython-39.pyc │ ├── encoder.cpython-39.pyc │ ├── models.cpython-39.pyc │ └── resnet.cpython-39.pyc │ ├── decoder.py │ ├── encoder.py │ ├── models.py │ └── resnet.py ├── train.py └── utils ├── LDM.py ├── __pycache__ ├── generate_noise._extrapolate2-235.py39.1.nbc ├── generate_noise._extrapolate2-235.py39.nbi ├── generate_noise._extrapolate2-250.py39.1.nbc ├── generate_noise._extrapolate2-250.py39.nbi ├── generate_noise._noise2-251.py39.1.nbc ├── generate_noise._noise2-251.py39.nbi ├── generate_noise._noise2-266.py39.1.nbc ├── generate_noise._noise2-266.py39.nbi ├── generate_noise._noise2a-352.py39.1.nbc ├── generate_noise._noise2a-352.py39.nbi ├── generate_noise._noise2a-367.py39.1.nbc ├── generate_noise._noise2a-367.py39.nbi ├── generate_noise.cpython-39.pyc ├── patch_sampling.cpython-39.pyc ├── utils.cpython-39.pyc └── utils_eval.cpython-39.pyc ├── generate_noise.py ├── patch_sampling.py ├── pos_embed.py ├── taming.py ├── utils.py └── utils_eval.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | ### VisualStudioCode 3 | .vscode/* 4 | !.vscode/settings.json 5 | !.vscode/tasks.json 6 | !.vscode/launch.json 7 | !.vscode/extensions.json 8 | *.code-workspace 9 | **/.vscode 10 | 11 | # Lightning-Hydra-Template 12 | data/ 13 | logs/ 14 | wandb/ 15 | .autoenv 16 | # pc_environment.env 17 | outputlog* 18 | errorlog* -------------------------------------------------------------------------------- /Data/splits/IXI_val_fold0.csv: -------------------------------------------------------------------------------- 1 | ,img_name,"SEX_ID (1=m, 2=f)",HEIGHT,WEIGHT,ETHNIC_ID,MARITAL_ID,OCCUPATION_ID,QUALIFICATION_ID,DATE_AVAILABLE,STUDY_DATE,age,label,img_path,mask_path,seg_path,Agegroup 2 | 321,IXI357-HH-2076_t1.nii.gz,2,155,58,1,2,5,4,1,2006-04-20,63.54,0,/Train/ixi/t1/IXI357-HH-2076_t1.nii.gz,/Train/ixi/mask/IXI357-HH-2076_mask.nii.gz,,3.0 3 | 285,IXI313-HH-2241_t1.nii.gz,1,183,92,1,3,5,5,1,2006-08-03,65.85,0,/Train/ixi/t1/IXI313-HH-2241_t1.nii.gz,/Train/ixi/mask/IXI313-HH-2241_mask.nii.gz,,3.0 4 | 122,IXI141-Guys-0789_t1.nii.gz,2,163,57,1,3,1,4,1,2005-10-14,46.01,0,/Train/ixi/t1/IXI141-Guys-0789_t1.nii.gz,/Train/ixi/mask/IXI141-Guys-0789_mask.nii.gz,,2.0 5 | 107,IXI123-Guys-0774_t1.nii.gz,2,170,60,1,2,2,4,1,2005-09-02,46.58,0,/Train/ixi/t1/IXI123-Guys-0774_t1.nii.gz,/Train/ixi/mask/IXI123-Guys-0774_mask.nii.gz,,2.0 6 | 568,IXI597-IOP-1161_t1.nii.gz,2,149,52,1,1,1,2,1,2006-12-04,43.08,0,/Train/ixi/t1/IXI597-IOP-1161_t1.nii.gz,/Train/ixi/mask/IXI597-IOP-1161_mask.nii.gz,,2.0 7 | 241,IXI266-Guys-0853_t1.nii.gz,2,165,102,1,4,1,4,1,2006-01-20,61.07,0,/Train/ixi/t1/IXI266-Guys-0853_t1.nii.gz,/Train/ixi/mask/IXI266-Guys-0853_mask.nii.gz,,3.0 8 | 
408,IXI442-IOP-1041_t1.nii.gz,1,185,84,1,2,5,5,1,2006-07-13,61.37,0,/Train/ixi/t1/IXI442-IOP-1041_t1.nii.gz,/Train/ixi/mask/IXI442-IOP-1041_mask.nii.gz,,3.0 9 | 94,IXI109-Guys-0732_t1.nii.gz,2,164,66,5,2,6,5,1,2005-08-01,36.47,0,/Train/ixi/t1/IXI109-Guys-0732_t1.nii.gz,/Train/ixi/mask/IXI109-Guys-0732_mask.nii.gz,,1.0 10 | 513,IXI541-IOP-1146_t1.nii.gz,2,165,58,1,1,1,5,1,2006-09-25,36.42,0,/Train/ixi/t1/IXI541-IOP-1146_t1.nii.gz,/Train/ixi/mask/IXI541-IOP-1146_mask.nii.gz,,1.0 11 | 180,IXI204-HH-1651_t1.nii.gz,2,175,96,1,2,5,1,1,2005-11-07,74.64,0,/Train/ixi/t1/IXI204-HH-1651_t1.nii.gz,/Train/ixi/mask/IXI204-HH-1651_mask.nii.gz,,3.0 12 | 427,IXI461-Guys-0998_t1.nii.gz,2,160,68,1,2,5,4,1,2006-07-24,60.9,0,/Train/ixi/t1/IXI461-Guys-0998_t1.nii.gz,/Train/ixi/mask/IXI461-Guys-0998_mask.nii.gz,,3.0 13 | 576,IXI606-HH-2601_t1.nii.gz,1,178,86,1,1,8,5,1,2006-10-25,60.54,0,/Train/ixi/t1/IXI606-HH-2601_t1.nii.gz,/Train/ixi/mask/IXI606-HH-2601_mask.nii.gz,,3.0 14 | 303,IXI331-IOP-0892_t1.nii.gz,1,170,59,3,0,0,0,1,2006-03-06,23.49,0,/Train/ixi/t1/IXI331-IOP-0892_t1.nii.gz,/Train/ixi/mask/IXI331-IOP-0892_mask.nii.gz,,1.0 15 | 453,IXI489-Guys-1014_t1.nii.gz,2,155,50,1,2,5,5,1,2006-08-01,69.07,0,/Train/ixi/t1/IXI489-Guys-1014_t1.nii.gz,/Train/ixi/mask/IXI489-Guys-1014_mask.nii.gz,,3.0 16 | 518,IXI547-IOP-1149_t1.nii.gz,1,170,70,2,3,1,5,1,2006-10-04,33.83,0,/Train/ixi/t1/IXI547-IOP-1149_t1.nii.gz,/Train/ixi/mask/IXI547-IOP-1149_mask.nii.gz,,1.0 17 | 215,IXI239-HH-2296_t1.nii.gz,2,160,91,1,4,5,5,1,2006-08-17,62.96,0,/Train/ixi/t1/IXI239-HH-2296_t1.nii.gz,/Train/ixi/mask/IXI239-HH-2296_mask.nii.gz,,3.0 18 | 266,IXI294-IOP-0868_t1.nii.gz,2,164,65,6,1,3,5,1,2006-02-06,27.08,0,/Train/ixi/t1/IXI294-IOP-0868_t1.nii.gz,/Train/ixi/mask/IXI294-IOP-0868_mask.nii.gz,,1.0 19 | 129,IXI148-HH-1453_t1.nii.gz,2,163,73,1,1,3,3,1,2005-09-01,40.56,0,/Train/ixi/t1/IXI148-HH-1453_t1.nii.gz,/Train/ixi/mask/IXI148-HH-1453_mask.nii.gz,,2.0 20 | 165,IXI188-Guys-0798_t1.nii.gz,2,0,114,1,1,1,2,1,2005-10-21,44.02,0,/Train/ixi/t1/IXI188-Guys-0798_t1.nii.gz,/Train/ixi/mask/IXI188-Guys-0798_mask.nii.gz,,2.0 21 | 22,IXI034-HH-1260_t1.nii.gz,2,163,55,1,1,3,5,1,2005-06-23,24.34,0,/Train/ixi/t1/IXI034-HH-1260_t1.nii.gz,/Train/ixi/mask/IXI034-HH-1260_mask.nii.gz,,1.0 22 | 548,IXI574-IOP-1156_t1.nii.gz,2,157,95,1,1,6,2,1,2006-10-30,50.57,0,/Train/ixi/t1/IXI574-IOP-1156_t1.nii.gz,/Train/ixi/mask/IXI574-IOP-1156_mask.nii.gz,,2.0 23 | 444,IXI481-HH-2175_t1.nii.gz,1,173,104,4,2,2,1,1,2006-07-06,65.33,0,/Train/ixi/t1/IXI481-HH-2175_t1.nii.gz,/Train/ixi/mask/IXI481-HH-2175_mask.nii.gz,,3.0 24 | 9,IXI020-Guys-0700_t1.nii.gz,1,178,72,1,2,1,4,1,2005-06-24,39.47,0,/Train/ixi/t1/IXI020-Guys-0700_t1.nii.gz,/Train/ixi/mask/IXI020-Guys-0700_mask.nii.gz,,1.0 25 | 483,IXI516-HH-2297_t1.nii.gz,2,170,86,1,4,1,2,1,2006-08-17,60.41,0,/Train/ixi/t1/IXI516-HH-2297_t1.nii.gz,/Train/ixi/mask/IXI516-HH-2297_mask.nii.gz,,3.0 26 | 21,IXI033-HH-1259_t1.nii.gz,1,174,62,1,1,3,5,1,2005-06-23,24.76,0,/Train/ixi/t1/IXI033-HH-1259_t1.nii.gz,/Train/ixi/mask/IXI033-HH-1259_mask.nii.gz,,1.0 27 | 577,IXI607-Guys-1097_t1.nii.gz,1,0,960,1,5,5,1,1,2006-11-06,83.81,0,/Train/ixi/t1/IXI607-Guys-1097_t1.nii.gz,/Train/ixi/mask/IXI607-Guys-1097_mask.nii.gz,,4.0 28 | 384,IXI417-Guys-0939_t1.nii.gz,1,178,0,1,1,1,1,1,2006-05-08,59.88,0,/Train/ixi/t1/IXI417-Guys-0939_t1.nii.gz,/Train/ixi/mask/IXI417-Guys-0939_mask.nii.gz,,2.0 29 | 31,IXI043-Guys-0714_t1.nii.gz,2,160,58,1,1,3,5,1,2005-07-13,22.65,0,/Train/ixi/t1/IXI043-Guys-0714_t1.nii.gz,/Train/ixi/mask/IXI043-Guys-0714_mask.nii.gz,,1.0 30 | 
216,IXI240-Guys-0834_t1.nii.gz,1,173,73,1,2,1,5,1,2005-11-25,39.04,0,/Train/ixi/t1/IXI240-Guys-0834_t1.nii.gz,/Train/ixi/mask/IXI240-Guys-0834_mask.nii.gz,,1.0 31 | 340,IXI375-Guys-0925_t1.nii.gz,2,187,96,2,0,0,0,1,2006-04-25,41.06,0,/Train/ixi/t1/IXI375-Guys-0925_t1.nii.gz,/Train/ixi/mask/IXI375-Guys-0925_mask.nii.gz,,2.0 32 | 39,IXI051-HH-1328_t1.nii.gz,2,168,58,1,3,1,5,1,2005-07-21,25.74,0,/Train/ixi/t1/IXI051-HH-1328_t1.nii.gz,/Train/ixi/mask/IXI051-HH-1328_mask.nii.gz,,1.0 33 | 237,IXI262-HH-1861_t1.nii.gz,2,156,47,5,1,3,3,1,2006-02-09,20.93,0,/Train/ixi/t1/IXI262-HH-1861_t1.nii.gz,/Train/ixi/mask/IXI262-HH-1861_mask.nii.gz,,1.0 34 | 330,IXI365-Guys-0923_t1.nii.gz,2,157,75,1,4,5,1,1,2006-04-10,66.53,0,/Train/ixi/t1/IXI365-Guys-0923_t1.nii.gz,/Train/ixi/mask/IXI365-Guys-0923_mask.nii.gz,,3.0 35 | 124,IXI143-Guys-0785_t1.nii.gz,2,173,67,3,2,1,5,1,2005-09-14,31.96,0,/Train/ixi/t1/IXI143-Guys-0785_t1.nii.gz,/Train/ixi/mask/IXI143-Guys-0785_mask.nii.gz,,1.0 36 | 4,IXI014-HH-1236_t1.nii.gz,2,163,65,1,4,1,5,1,2005-06-09,34.24,0,/Train/ixi/t1/IXI014-HH-1236_t1.nii.gz,/Train/ixi/mask/IXI014-HH-1236_mask.nii.gz,,1.0 37 | 400,IXI434-IOP-1010_t1.nii.gz,2,160,60,1,1,5,5,1,2006-06-14,67.24,0,/Train/ixi/t1/IXI434-IOP-1010_t1.nii.gz,/Train/ixi/mask/IXI434-IOP-1010_mask.nii.gz,,3.0 38 | 145,IXI166-Guys-0846_t1.nii.gz,2,163,68,1,3,1,5,1,2006-01-13,41.2,0,/Train/ixi/t1/IXI166-Guys-0846_t1.nii.gz,/Train/ixi/mask/IXI166-Guys-0846_mask.nii.gz,,2.0 39 | 440,IXI477-IOP-1141_t1.nii.gz,2,157,62,1,2,1,5,1,2006-08-23,46.43,0,/Train/ixi/t1/IXI477-IOP-1141_t1.nii.gz,/Train/ixi/mask/IXI477-IOP-1141_mask.nii.gz,,2.0 40 | 187,IXI211-HH-1568_t1.nii.gz,1,186,94,2,1,1,3,1,2005-10-06,55.12,0,/Train/ixi/t1/IXI211-HH-1568_t1.nii.gz,/Train/ixi/mask/IXI211-HH-1568_mask.nii.gz,,2.0 41 | 588,IXI619-Guys-1099_t1.nii.gz,1,182,70,1,3,3,2,1,2006-11-17,24.88,0,/Train/ixi/t1/IXI619-Guys-1099_t1.nii.gz,/Train/ixi/mask/IXI619-Guys-1099_mask.nii.gz,,1.0 42 | 267,IXI295-HH-1814_t1.nii.gz,2,163,83,3,2,1,2,1,2006-01-26,50.9,0,/Train/ixi/t1/IXI295-HH-1814_t1.nii.gz,/Train/ixi/mask/IXI295-HH-1814_mask.nii.gz,,2.0 43 | 42,IXI054-Guys-0707_t1.nii.gz,2,151,59,1,1,1,5,1,2005-07-08,41.75,0,/Train/ixi/t1/IXI054-Guys-0707_t1.nii.gz,/Train/ixi/mask/IXI054-Guys-0707_mask.nii.gz,,2.0 44 | 470,IXI503-Guys-1021_t1.nii.gz,1,180,79,1,2,5,5,1,2006-08-07,65.1,0,/Train/ixi/t1/IXI503-Guys-1021_t1.nii.gz,/Train/ixi/mask/IXI503-Guys-1021_mask.nii.gz,,3.0 45 | 293,IXI321-Guys-0903_t1.nii.gz,2,168,53,1,2,1,5,1,2006-03-17,54.74,0,/Train/ixi/t1/IXI321-Guys-0903_t1.nii.gz,/Train/ixi/mask/IXI321-Guys-0903_mask.nii.gz,,2.0 46 | -------------------------------------------------------------------------------- /Data/splits/IXI_val_fold1.csv: -------------------------------------------------------------------------------- 1 | ,img_name,"SEX_ID (1=m, 2=f)",HEIGHT,WEIGHT,ETHNIC_ID,MARITAL_ID,OCCUPATION_ID,QUALIFICATION_ID,DATE_AVAILABLE,STUDY_DATE,age,label,img_path,mask_path,seg_path,Agegroup 2 | 187,IXI211-HH-1568_t1.nii.gz,1,186,94,2,1,1,3,1,2005-10-06,55.12,0,/Train/ixi/t1/IXI211-HH-1568_t1.nii.gz,/Train/ixi/mask/IXI211-HH-1568_mask.nii.gz,,2.0 3 | 63,IXI075-Guys-0754_t1.nii.gz,2,157,58,1,1,1,4,1,2005-08-12,35.99,0,/Train/ixi/t1/IXI075-Guys-0754_t1.nii.gz,/Train/ixi/mask/IXI075-Guys-0754_mask.nii.gz,,1.0 4 | 423,IXI456-Guys-1019_t1.nii.gz,2,165,86,1,3,2,2,1,2006-08-11,70.25,0,/Train/ixi/t1/IXI456-Guys-1019_t1.nii.gz,/Train/ixi/mask/IXI456-Guys-1019_mask.nii.gz,,3.0 5 | 
365,IXI399-Guys-0966_t1.nii.gz,2,165,70,1,2,5,3,1,2006-06-12,73.54,0,/Train/ixi/t1/IXI399-Guys-0966_t1.nii.gz,/Train/ixi/mask/IXI399-Guys-0966_mask.nii.gz,,3.0 6 | 237,IXI262-HH-1861_t1.nii.gz,2,156,47,5,1,3,3,1,2006-02-09,20.93,0,/Train/ixi/t1/IXI262-HH-1861_t1.nii.gz,/Train/ixi/mask/IXI262-HH-1861_mask.nii.gz,,1.0 7 | 530,IXI555-Guys-1074_t1.nii.gz,2,0,79,1,1,5,1,1,2006-09-18,65.67,0,/Train/ixi/t1/IXI555-Guys-1074_t1.nii.gz,/Train/ixi/mask/IXI555-Guys-1074_mask.nii.gz,,3.0 8 | 447,IXI484-HH-2179_t1.nii.gz,1,170,72,1,2,5,4,1,2006-07-13,66.37,0,/Train/ixi/t1/IXI484-HH-2179_t1.nii.gz,/Train/ixi/mask/IXI484-HH-2179_mask.nii.gz,,3.0 9 | 359,IXI393-Guys-0941_t1.nii.gz,2,175,90,1,2,7,4,1,2006-05-05,60.1,0,/Train/ixi/t1/IXI393-Guys-0941_t1.nii.gz,/Train/ixi/mask/IXI393-Guys-0941_mask.nii.gz,,3.0 10 | 433,IXI468-Guys-0985_t1.nii.gz,1,180,70,1,2,5,5,1,2006-07-17,67.7,0,/Train/ixi/t1/IXI468-Guys-0985_t1.nii.gz,/Train/ixi/mask/IXI468-Guys-0985_mask.nii.gz,,3.0 11 | 385,IXI418-Guys-0956_t1.nii.gz,2,165,64,1,2,2,2,1,2006-06-17,59.74,0,/Train/ixi/t1/IXI418-Guys-0956_t1.nii.gz,/Train/ixi/mask/IXI418-Guys-0956_mask.nii.gz,,2.0 12 | 514,IXI542-IOP-1147_t1.nii.gz,2,154,64,1,1,1,5,1,2006-09-27,44.38,0,/Train/ixi/t1/IXI542-IOP-1147_t1.nii.gz,/Train/ixi/mask/IXI542-IOP-1147_mask.nii.gz,,2.0 13 | 145,IXI166-Guys-0846_t1.nii.gz,2,163,68,1,3,1,5,1,2006-01-13,41.2,0,/Train/ixi/t1/IXI166-Guys-0846_t1.nii.gz,/Train/ixi/mask/IXI166-Guys-0846_mask.nii.gz,,2.0 14 | 453,IXI489-Guys-1014_t1.nii.gz,2,155,50,1,2,5,5,1,2006-08-01,69.07,0,/Train/ixi/t1/IXI489-Guys-1014_t1.nii.gz,/Train/ixi/mask/IXI489-Guys-1014_mask.nii.gz,,3.0 15 | 360,IXI394-Guys-0940_t1.nii.gz,1,184,80,1,2,1,5,1,2006-05-08,53.75,0,/Train/ixi/t1/IXI394-Guys-0940_t1.nii.gz,/Train/ixi/mask/IXI394-Guys-0940_mask.nii.gz,,2.0 16 | 387,IXI420-Guys-1028_t1.nii.gz,2,154,58,1,2,5,4,1,2006-08-14,73.18,0,/Train/ixi/t1/IXI420-Guys-1028_t1.nii.gz,/Train/ixi/mask/IXI420-Guys-1028_mask.nii.gz,,3.0 17 | 402,IXI436-HH-2153_t1.nii.gz,1,190,73,1,1,1,1,1,2006-05-25,28.35,0,/Train/ixi/t1/IXI436-HH-2153_t1.nii.gz,/Train/ixi/mask/IXI436-HH-2153_mask.nii.gz,,1.0 18 | 542,IXI568-HH-2607_t1.nii.gz,2,165,58,1,1,1,2,1,2006-10-26,33.69,0,/Train/ixi/t1/IXI568-HH-2607_t1.nii.gz,/Train/ixi/mask/IXI568-HH-2607_mask.nii.gz,,1.0 19 | 206,IXI230-IOP-0869_t1.nii.gz,2,178,64,1,1,1,3,1,2006-02-06,21.15,0,/Train/ixi/t1/IXI230-IOP-0869_t1.nii.gz,/Train/ixi/mask/IXI230-IOP-0869_mask.nii.gz,,1.0 20 | 321,IXI357-HH-2076_t1.nii.gz,2,155,58,1,2,5,4,1,2006-04-20,63.54,0,/Train/ixi/t1/IXI357-HH-2076_t1.nii.gz,/Train/ixi/mask/IXI357-HH-2076_mask.nii.gz,,3.0 21 | 353,IXI388-IOP-0973_t1.nii.gz,1,185,88,5,2,1,5,1,2006-05-22,33.35,0,/Train/ixi/t1/IXI388-IOP-0973_t1.nii.gz,/Train/ixi/mask/IXI388-IOP-0973_mask.nii.gz,,1.0 22 | 159,IXI181-Guys-0790_t1.nii.gz,1,188,102,1,3,1,5,1,2005-10-14,45.7,0,/Train/ixi/t1/IXI181-Guys-0790_t1.nii.gz,/Train/ixi/mask/IXI181-Guys-0790_mask.nii.gz,,2.0 23 | 152,IXI174-HH-1571_t1.nii.gz,2,158,64,3,2,5,2,1,2005-10-06,63.1,0,/Train/ixi/t1/IXI174-HH-1571_t1.nii.gz,/Train/ixi/mask/IXI174-HH-1571_mask.nii.gz,,3.0 24 | 59,IXI071-Guys-0770_t1.nii.gz,2,165,63,1,2,1,4,1,2005-08-26,37.0,0,/Train/ixi/t1/IXI071-Guys-0770_t1.nii.gz,/Train/ixi/mask/IXI071-Guys-0770_mask.nii.gz,,1.0 25 | 440,IXI477-IOP-1141_t1.nii.gz,2,157,62,1,2,1,5,1,2006-08-23,46.43,0,/Train/ixi/t1/IXI477-IOP-1141_t1.nii.gz,/Train/ixi/mask/IXI477-IOP-1141_mask.nii.gz,,2.0 26 | 
245,IXI270-Guys-0847_t1.nii.gz,2,170,70,1,5,5,5,1,2006-01-13,61.02,0,/Train/ixi/t1/IXI270-Guys-0847_t1.nii.gz,/Train/ixi/mask/IXI270-Guys-0847_mask.nii.gz,,3.0 27 | 51,IXI063-Guys-0742_t1.nii.gz,1,178,102,1,1,1,3,1,2005-08-05,41.12,0,/Train/ixi/t1/IXI063-Guys-0742_t1.nii.gz,/Train/ixi/mask/IXI063-Guys-0742_mask.nii.gz,,2.0 28 | 203,IXI226-HH-1618_t1.nii.gz,2,175,75,1,2,1,5,1,2005-10-27,41.9,0,/Train/ixi/t1/IXI226-HH-1618_t1.nii.gz,/Train/ixi/mask/IXI226-HH-1618_mask.nii.gz,,2.0 29 | 110,IXI128-HH-1470_t1.nii.gz,2,151,60,1,1,3,5,1,2005-09-08,27.33,0,/Train/ixi/t1/IXI128-HH-1470_t1.nii.gz,/Train/ixi/mask/IXI128-HH-1470_mask.nii.gz,,1.0 30 | 95,IXI110-Guys-0733_t1.nii.gz,1,180,75,5,2,1,5,1,2005-08-01,37.77,0,/Train/ixi/t1/IXI110-Guys-0733_t1.nii.gz,/Train/ixi/mask/IXI110-Guys-0733_mask.nii.gz,,1.0 31 | 150,IXI172-Guys-0982_t1.nii.gz,1,185,100,1,2,5,5,1,2006-07-14,74.99,0,/Train/ixi/t1/IXI172-Guys-0982_t1.nii.gz,/Train/ixi/mask/IXI172-Guys-0982_mask.nii.gz,,3.0 32 | 39,IXI051-HH-1328_t1.nii.gz,2,168,58,1,3,1,5,1,2005-07-21,25.74,0,/Train/ixi/t1/IXI051-HH-1328_t1.nii.gz,/Train/ixi/mask/IXI051-HH-1328_mask.nii.gz,,1.0 33 | 519,IXI548-IOP-1150_t1.nii.gz,2,153,50,1,2,5,5,1,2006-10-09,66.33,0,/Train/ixi/t1/IXI548-IOP-1150_t1.nii.gz,/Train/ixi/mask/IXI548-IOP-1150_mask.nii.gz,,3.0 34 | 21,IXI033-HH-1259_t1.nii.gz,1,174,62,1,1,3,5,1,2005-06-23,24.76,0,/Train/ixi/t1/IXI033-HH-1259_t1.nii.gz,/Train/ixi/mask/IXI033-HH-1259_mask.nii.gz,,1.0 35 | 268,IXI296-HH-1970_t1.nii.gz,2,152,54,1,5,8,4,1,2006-03-16,63.97,0,/Train/ixi/t1/IXI296-HH-1970_t1.nii.gz,/Train/ixi/mask/IXI296-HH-1970_mask.nii.gz,,3.0 36 | 50,IXI062-Guys-0740_t1.nii.gz,1,185,99,1,2,1,5,1,2005-08-05,36.22,0,/Train/ixi/t1/IXI062-Guys-0740_t1.nii.gz,/Train/ixi/mask/IXI062-Guys-0740_mask.nii.gz,,1.0 37 | 346,IXI381-Guys-1024_t1.nii.gz,2,160,75,1,3,1,4,1,2006-08-11,59.67,0,/Train/ixi/t1/IXI381-Guys-1024_t1.nii.gz,/Train/ixi/mask/IXI381-Guys-1024_mask.nii.gz,,2.0 38 | 247,IXI275-HH-1803_t1.nii.gz,1,190,95,1,1,1,5,1,2006-01-19,24.18,0,/Train/ixi/t1/IXI275-HH-1803_t1.nii.gz,/Train/ixi/mask/IXI275-HH-1803_mask.nii.gz,,1.0 39 | 54,IXI066-Guys-0731_t1.nii.gz,2,153,71,1,2,1,5,1,2005-07-27,46.17,0,/Train/ixi/t1/IXI066-Guys-0731_t1.nii.gz,/Train/ixi/mask/IXI066-Guys-0731_mask.nii.gz,,2.0 40 | 338,IXI373-IOP-0967_t1.nii.gz,2,165,69,3,2,2,2,1,2006-05-03,58.79,0,/Train/ixi/t1/IXI373-IOP-0967_t1.nii.gz,/Train/ixi/mask/IXI373-IOP-0967_mask.nii.gz,,2.0 41 | 430,IXI464-IOP-1029_t1.nii.gz,2,0,0,1,5,5,5,1,2006-07-26,86.32,0,/Train/ixi/t1/IXI464-IOP-1029_t1.nii.gz,/Train/ixi/mask/IXI464-IOP-1029_mask.nii.gz,,4.0 42 | 570,IXI599-HH-2659_t1.nii.gz,1,183,90,1,4,1,2,1,2006-11-02,39.38,0,/Train/ixi/t1/IXI599-HH-2659_t1.nii.gz,/Train/ixi/mask/IXI599-HH-2659_mask.nii.gz,,1.0 43 | 125,IXI144-Guys-0788_t1.nii.gz,2,168,53,1,1,3,5,1,2005-09-14,29.36,0,/Train/ixi/t1/IXI144-Guys-0788_t1.nii.gz,/Train/ixi/mask/IXI144-Guys-0788_mask.nii.gz,,1.0 44 | 280,IXI308-Guys-0884_t1.nii.gz,1,0,76,1,3,1,5,1,2006-03-03,41.2,0,/Train/ixi/t1/IXI308-Guys-0884_t1.nii.gz,/Train/ixi/mask/IXI308-Guys-0884_mask.nii.gz,,2.0 45 | 363,IXI397-Guys-0953_t1.nii.gz,2,170,115,1,2,4,0,1,2006-06-09,51.66,0,/Train/ixi/t1/IXI397-Guys-0953_t1.nii.gz,/Train/ixi/mask/IXI397-Guys-0953_mask.nii.gz,,2.0 46 | -------------------------------------------------------------------------------- /Data/splits/IXI_val_fold2.csv: -------------------------------------------------------------------------------- 1 | ,img_name,"SEX_ID (1=m, 
2=f)",HEIGHT,WEIGHT,ETHNIC_ID,MARITAL_ID,OCCUPATION_ID,QUALIFICATION_ID,DATE_AVAILABLE,STUDY_DATE,age,label,img_path,mask_path,seg_path,Agegroup 2 | 60,IXI072-HH-2324_t1.nii.gz,1,161,82,6,2,5,4,1,2006-08-24,68.6,0,/Train/ixi/t1/IXI072-HH-2324_t1.nii.gz,/Train/ixi/mask/IXI072-HH-2324_mask.nii.gz,,3.0 3 | 529,IXI554-Guys-1068_t1.nii.gz,1,0,0,0,0,0,0,1,2006-09-08,70.11,0,/Train/ixi/t1/IXI554-Guys-1068_t1.nii.gz,/Train/ixi/mask/IXI554-Guys-1068_mask.nii.gz,,3.0 4 | 73,IXI087-Guys-0768_t1.nii.gz,2,168,57,1,3,1,5,1,2005-08-26,25.66,0,/Train/ixi/t1/IXI087-Guys-0768_t1.nii.gz,/Train/ixi/mask/IXI087-Guys-0768_mask.nii.gz,,1.0 5 | 7,IXI017-Guys-0698_t1.nii.gz,2,178,72,1,3,1,5,1,2005-06-24,29.09,0,/Train/ixi/t1/IXI017-Guys-0698_t1.nii.gz,/Train/ixi/mask/IXI017-Guys-0698_mask.nii.gz,,1.0 6 | 583,IXI613-HH-2734_t1.nii.gz,1,177,84,1,3,1,5,1,2006-11-16,25.59,0,/Train/ixi/t1/IXI613-HH-2734_t1.nii.gz,/Train/ixi/mask/IXI613-HH-2734_mask.nii.gz,,1.0 7 | 354,IXI389-Guys-0930_t1.nii.gz,2,168,73,1,5,1,1,1,2006-04-28,58.37,0,/Train/ixi/t1/IXI389-Guys-0930_t1.nii.gz,/Train/ixi/mask/IXI389-Guys-0930_mask.nii.gz,,2.0 8 | 406,IXI440-HH-2127_t1.nii.gz,1,165,78,1,2,1,4,1,2006-05-18,48.1,0,/Train/ixi/t1/IXI440-HH-2127_t1.nii.gz,/Train/ixi/mask/IXI440-HH-2127_mask.nii.gz,,2.0 9 | 447,IXI484-HH-2179_t1.nii.gz,1,170,72,1,2,5,4,1,2006-07-13,66.37,0,/Train/ixi/t1/IXI484-HH-2179_t1.nii.gz,/Train/ixi/mask/IXI484-HH-2179_mask.nii.gz,,3.0 10 | 341,IXI376-Guys-0938_t1.nii.gz,2,160,0,1,2,5,5,1,2006-05-08,72.59,0,/Train/ixi/t1/IXI376-Guys-0938_t1.nii.gz,/Train/ixi/mask/IXI376-Guys-0938_mask.nii.gz,,3.0 11 | 115,IXI134-Guys-0780_t1.nii.gz,1,183,78,1,1,6,5,1,2005-09-09,47.35,0,/Train/ixi/t1/IXI134-Guys-0780_t1.nii.gz,/Train/ixi/mask/IXI134-Guys-0780_mask.nii.gz,,2.0 12 | 576,IXI606-HH-2601_t1.nii.gz,1,178,86,1,1,8,5,1,2006-10-25,60.54,0,/Train/ixi/t1/IXI606-HH-2601_t1.nii.gz,/Train/ixi/mask/IXI606-HH-2601_mask.nii.gz,,3.0 13 | 577,IXI607-Guys-1097_t1.nii.gz,1,0,960,1,5,5,1,1,2006-11-06,83.81,0,/Train/ixi/t1/IXI607-Guys-1097_t1.nii.gz,/Train/ixi/mask/IXI607-Guys-1097_mask.nii.gz,,4.0 14 | 375,IXI409-Guys-0960_t1.nii.gz,2,0,64,1,1,5,1,1,2006-06-17,70.95,0,/Train/ixi/t1/IXI409-Guys-0960_t1.nii.gz,/Train/ixi/mask/IXI409-Guys-0960_mask.nii.gz,,3.0 15 | 257,IXI285-Guys-0857_t1.nii.gz,1,176,69,1,1,1,5,1,2006-01-27,40.81,0,/Train/ixi/t1/IXI285-Guys-0857_t1.nii.gz,/Train/ixi/mask/IXI285-Guys-0857_mask.nii.gz,,2.0 16 | 562,IXI591-Guys-1084_t1.nii.gz,2,167,54,1,2,5,3,1,2006-10-30,59.89,0,/Train/ixi/t1/IXI591-Guys-1084_t1.nii.gz,/Train/ixi/mask/IXI591-Guys-1084_mask.nii.gz,,2.0 17 | 94,IXI109-Guys-0732_t1.nii.gz,2,164,66,5,2,6,5,1,2005-08-01,36.47,0,/Train/ixi/t1/IXI109-Guys-0732_t1.nii.gz,/Train/ixi/mask/IXI109-Guys-0732_mask.nii.gz,,1.0 18 | 122,IXI141-Guys-0789_t1.nii.gz,2,163,57,1,3,1,4,1,2005-10-14,46.01,0,/Train/ixi/t1/IXI141-Guys-0789_t1.nii.gz,/Train/ixi/mask/IXI141-Guys-0789_mask.nii.gz,,2.0 19 | 25,IXI037-Guys-0704_t1.nii.gz,2,164,100,1,3,3,5,1,2005-06-30,37.21,0,/Train/ixi/t1/IXI037-Guys-0704_t1.nii.gz,/Train/ixi/mask/IXI037-Guys-0704_mask.nii.gz,,1.0 20 | 613,IXI651-Guys-1118_t1.nii.gz,1,175,61,3,2,8,2,1,2006-12-01,50.4,0,/Train/ixi/t1/IXI651-Guys-1118_t1.nii.gz,/Train/ixi/mask/IXI651-Guys-1118_mask.nii.gz,,2.0 21 | 582,IXI612-HH-2688_t1.nii.gz,1,180,80,1,1,1,5,1,2006-11-08,33.92,0,/Train/ixi/t1/IXI612-HH-2688_t1.nii.gz,/Train/ixi/mask/IXI612-HH-2688_mask.nii.gz,,1.0 22 | 165,IXI188-Guys-0798_t1.nii.gz,2,0,114,1,1,1,2,1,2005-10-21,44.02,0,/Train/ixi/t1/IXI188-Guys-0798_t1.nii.gz,/Train/ixi/mask/IXI188-Guys-0798_mask.nii.gz,,2.0 23 | 
36,IXI048-HH-1326_t1.nii.gz,1,194,90,1,2,1,5,1,2005-07-21,50.65,0,/Train/ixi/t1/IXI048-HH-1326_t1.nii.gz,/Train/ixi/mask/IXI048-HH-1326_mask.nii.gz,,2.0 24 | 435,IXI470-IOP-1030_t1.nii.gz,1,187,90,1,1,1,5,1,2006-07-31,35.98,0,/Train/ixi/t1/IXI470-IOP-1030_t1.nii.gz,/Train/ixi/mask/IXI470-IOP-1030_mask.nii.gz,,1.0 25 | 31,IXI043-Guys-0714_t1.nii.gz,2,160,58,1,1,3,5,1,2005-07-13,22.65,0,/Train/ixi/t1/IXI043-Guys-0714_t1.nii.gz,/Train/ixi/mask/IXI043-Guys-0714_mask.nii.gz,,1.0 26 | 187,IXI211-HH-1568_t1.nii.gz,1,186,94,2,1,1,3,1,2005-10-06,55.12,0,/Train/ixi/t1/IXI211-HH-1568_t1.nii.gz,/Train/ixi/mask/IXI211-HH-1568_mask.nii.gz,,2.0 27 | 460,IXI495-Guys-1009_t1.nii.gz,2,163,80,1,4,5,5,1,2006-07-31,60.87,0,/Train/ixi/t1/IXI495-Guys-1009_t1.nii.gz,/Train/ixi/mask/IXI495-Guys-1009_mask.nii.gz,,3.0 28 | 495,IXI528-Guys-1073_t1.nii.gz,1,172,80,1,3,1,5,1,2006-09-18,40.53,0,/Train/ixi/t1/IXI528-Guys-1073_t1.nii.gz,/Train/ixi/mask/IXI528-Guys-1073_mask.nii.gz,,2.0 29 | 599,IXI631-HH-2651_t1.nii.gz,1,185,89,1,3,1,3,1,2006-11-01,41.3,0,/Train/ixi/t1/IXI631-HH-2651_t1.nii.gz,/Train/ixi/mask/IXI631-HH-2651_mask.nii.gz,,2.0 30 | 200,IXI223-Guys-0830_t1.nii.gz,1,183,86,1,2,2,4,1,2005-11-18,63.25,0,/Train/ixi/t1/IXI223-Guys-0830_t1.nii.gz,/Train/ixi/mask/IXI223-Guys-0830_mask.nii.gz,,3.0 31 | 71,IXI085-Guys-0759_t1.nii.gz,2,166,59,6,1,1,5,1,2005-08-19,31.85,0,/Train/ixi/t1/IXI085-Guys-0759_t1.nii.gz,/Train/ixi/mask/IXI085-Guys-0759_mask.nii.gz,,1.0 32 | 24,IXI036-Guys-0736_t1.nii.gz,1,186,75,1,1,3,3,1,2005-08-02,23.54,0,/Train/ixi/t1/IXI036-Guys-0736_t1.nii.gz,/Train/ixi/mask/IXI036-Guys-0736_mask.nii.gz,,1.0 33 | 315,IXI350-Guys-0908_t1.nii.gz,1,0,0,0,0,0,0,1,2006-04-03,62.68,0,/Train/ixi/t1/IXI350-Guys-0908_t1.nii.gz,/Train/ixi/mask/IXI350-Guys-0908_mask.nii.gz,,3.0 34 | 294,IXI322-IOP-0891_t1.nii.gz,2,165,60,1,0,0,0,1,2006-03-06,28.48,0,/Train/ixi/t1/IXI322-IOP-0891_t1.nii.gz,/Train/ixi/mask/IXI322-IOP-0891_mask.nii.gz,,1.0 35 | 419,IXI452-HH-2213_t1.nii.gz,2,157,76,1,2,1,5,1,2006-07-27,63.4,0,/Train/ixi/t1/IXI452-HH-2213_t1.nii.gz,/Train/ixi/mask/IXI452-HH-2213_mask.nii.gz,,3.0 36 | 511,IXI539-Guys-1067_t1.nii.gz,1,0,0,0,0,0,0,1,2006-09-08,78.07,0,/Train/ixi/t1/IXI539-Guys-1067_t1.nii.gz,/Train/ixi/mask/IXI539-Guys-1067_mask.nii.gz,,3.0 37 | 515,IXI543-IOP-1148_t1.nii.gz,1,175,78,1,2,1,5,1,2006-09-04,70.52,0,/Train/ixi/t1/IXI543-IOP-1148_t1.nii.gz,/Train/ixi/mask/IXI543-IOP-1148_mask.nii.gz,,3.0 38 | 21,IXI033-HH-1259_t1.nii.gz,1,174,62,1,1,3,5,1,2005-06-23,24.76,0,/Train/ixi/t1/IXI033-HH-1259_t1.nii.gz,/Train/ixi/mask/IXI033-HH-1259_mask.nii.gz,,1.0 39 | 18,IXI029-Guys-0829_t1.nii.gz,2,155,79,4,5,1,5,1,2005-11-18,59.22,0,/Train/ixi/t1/IXI029-Guys-0829_t1.nii.gz,/Train/ixi/mask/IXI029-Guys-0829_mask.nii.gz,,2.0 40 | 97,IXI112-Guys-0735_t1.nii.gz,1,165,65,5,1,3,5,1,2005-08-01,23.7,0,/Train/ixi/t1/IXI112-Guys-0735_t1.nii.gz,/Train/ixi/mask/IXI112-Guys-0735_mask.nii.gz,,1.0 41 | 84,IXI099-Guys-0748_t1.nii.gz,1,152,76,1,2,1,4,1,2005-08-10,51.75,0,/Train/ixi/t1/IXI099-Guys-0748_t1.nii.gz,/Train/ixi/mask/IXI099-Guys-0748_mask.nii.gz,,2.0 42 | 214,IXI238-IOP-0883_t1.nii.gz,2,150,55,1,1,1,5,1,2006-02-27,27.73,0,/Train/ixi/t1/IXI238-IOP-0883_t1.nii.gz,/Train/ixi/mask/IXI238-IOP-0883_mask.nii.gz,,1.0 43 | 572,IXI601-HH-2700_t1.nii.gz,1,173,79,1,1,1,5,1,2006-11-09,35.52,0,/Train/ixi/t1/IXI601-HH-2700_t1.nii.gz,/Train/ixi/mask/IXI601-HH-2700_mask.nii.gz,,1.0 44 | 
429,IXI463-IOP-1043_t1.nii.gz,2,165,76,1,1,5,3,1,2006-07-20,70.14,0,/Train/ixi/t1/IXI463-IOP-1043_t1.nii.gz,/Train/ixi/mask/IXI463-IOP-1043_mask.nii.gz,,3.0 45 | 610,IXI646-HH-2653_t1.nii.gz,2,165,60,6,5,5,3,1,2006-11-01,71.21,0,/Train/ixi/t1/IXI646-HH-2653_t1.nii.gz,/Train/ixi/mask/IXI646-HH-2653_mask.nii.gz,,3.0 46 | -------------------------------------------------------------------------------- /Data/splits/IXI_val_fold3.csv: -------------------------------------------------------------------------------- 1 | ,img_name,"SEX_ID (1=m, 2=f)",HEIGHT,WEIGHT,ETHNIC_ID,MARITAL_ID,OCCUPATION_ID,QUALIFICATION_ID,DATE_AVAILABLE,STUDY_DATE,age,label,img_path,mask_path,seg_path,Agegroup 2 | 457,IXI493-Guys-1007_t1.nii.gz,1,165,63,1,2,5,1,1,2006-07-31,67.77,0,/Train/ixi/t1/IXI493-Guys-1007_t1.nii.gz,/Train/ixi/mask/IXI493-Guys-1007_mask.nii.gz,,3.0 3 | 86,IXI101-Guys-0749_t1.nii.gz,1,180,102,1,2,1,4,1,2005-08-10,45.77,0,/Train/ixi/t1/IXI101-Guys-0749_t1.nii.gz,/Train/ixi/mask/IXI101-Guys-0749_mask.nii.gz,,2.0 4 | 118,IXI137-HH-1472_t1.nii.gz,1,172,73,1,1,1,5,1,2005-09-08,41.23,0,/Train/ixi/t1/IXI137-HH-1472_t1.nii.gz,/Train/ixi/mask/IXI137-HH-1472_mask.nii.gz,,2.0 5 | 250,IXI278-HH-1771_t1.nii.gz,1,184,65,5,1,3,3,1,2006-01-12,20.21,0,/Train/ixi/t1/IXI278-HH-1771_t1.nii.gz,/Train/ixi/mask/IXI278-HH-1771_mask.nii.gz,,1.0 6 | 21,IXI033-HH-1259_t1.nii.gz,1,174,62,1,1,3,5,1,2005-06-23,24.76,0,/Train/ixi/t1/IXI033-HH-1259_t1.nii.gz,/Train/ixi/mask/IXI033-HH-1259_mask.nii.gz,,1.0 7 | 206,IXI230-IOP-0869_t1.nii.gz,2,178,64,1,1,1,3,1,2006-02-06,21.15,0,/Train/ixi/t1/IXI230-IOP-0869_t1.nii.gz,/Train/ixi/mask/IXI230-IOP-0869_mask.nii.gz,,1.0 8 | 231,IXI256-HH-1723_t1.nii.gz,1,190,82,2,1,3,5,1,2005-12-22,27.38,0,/Train/ixi/t1/IXI256-HH-1723_t1.nii.gz,/Train/ixi/mask/IXI256-HH-1723_mask.nii.gz,,1.0 9 | 522,IXI550-Guys-1069_t1.nii.gz,2,0,0,0,0,0,0,1,2006-09-08,59.91,0,/Train/ixi/t1/IXI550-Guys-1069_t1.nii.gz,/Train/ixi/mask/IXI550-Guys-1069_mask.nii.gz,,2.0 10 | 493,IXI526-HH-2392_t1.nii.gz,2,163,95,4,1,1,1,1,2006-09-14,54.61,0,/Train/ixi/t1/IXI526-HH-2392_t1.nii.gz,/Train/ixi/mask/IXI526-HH-2392_mask.nii.gz,,2.0 11 | 251,IXI279-Guys-1044_t1.nii.gz,1,178,81,1,2,1,5,1,2006-01-27,51.88,0,/Train/ixi/t1/IXI279-Guys-1044_t1.nii.gz,/Train/ixi/mask/IXI279-Guys-1044_mask.nii.gz,,2.0 12 | 506,IXI536-Guys-1059_t1.nii.gz,2,167,63,1,5,1,5,1,2006-08-31,58.83,0,/Train/ixi/t1/IXI536-Guys-1059_t1.nii.gz,/Train/ixi/mask/IXI536-Guys-1059_mask.nii.gz,,2.0 13 | 116,IXI135-Guys-0779_t1.nii.gz,1,173,83,3,1,1,5,1,2005-09-09,29.0,0,/Train/ixi/t1/IXI135-Guys-0779_t1.nii.gz,/Train/ixi/mask/IXI135-Guys-0779_mask.nii.gz,,1.0 14 | 467,IXI500-Guys-1017_t1.nii.gz,1,186,83,1,2,5,3,1,2006-08-04,63.18,0,/Train/ixi/t1/IXI500-Guys-1017_t1.nii.gz,/Train/ixi/mask/IXI500-Guys-1017_mask.nii.gz,,3.0 15 | 265,IXI293-IOP-0876_t1.nii.gz,2,170,72,1,1,1,5,1,2006-02-13,26.52,0,/Train/ixi/t1/IXI293-IOP-0876_t1.nii.gz,/Train/ixi/mask/IXI293-IOP-0876_mask.nii.gz,,1.0 16 | 472,IXI505-Guys-1026_t1.nii.gz,2,152,0,1,4,2,1,1,2006-08-14,62.25,0,/Train/ixi/t1/IXI505-Guys-1026_t1.nii.gz,/Train/ixi/mask/IXI505-Guys-1026_mask.nii.gz,,3.0 17 | 306,IXI335-HH-1906_t1.nii.gz,1,180,100,1,2,5,5,1,2006-02-23,69.28,0,/Train/ixi/t1/IXI335-HH-1906_t1.nii.gz,/Train/ixi/mask/IXI335-HH-1906_mask.nii.gz,,3.0 18 | 536,IXI561-IOP-1152_t1.nii.gz,2,157,59,1,1,1,5,1,2006-10-17,31.58,0,/Train/ixi/t1/IXI561-IOP-1152_t1.nii.gz,/Train/ixi/mask/IXI561-IOP-1152_mask.nii.gz,,1.0 19 | 
611,IXI648-Guys-1107_t1.nii.gz,1,193,120,1,1,6,4,1,2006-11-27,47.72,0,/Train/ixi/t1/IXI648-Guys-1107_t1.nii.gz,/Train/ixi/mask/IXI648-Guys-1107_mask.nii.gz,,2.0 20 | 499,IXI532-IOP-1145_t1.nii.gz,1,203,78,1,2,1,4,1,2006-09-21,36.55,0,/Train/ixi/t1/IXI532-IOP-1145_t1.nii.gz,/Train/ixi/mask/IXI532-IOP-1145_mask.nii.gz,,1.0 21 | 193,IXI217-HH-1638_t1.nii.gz,2,168,68,1,4,8,2,1,2005-11-03,57.21,0,/Train/ixi/t1/IXI217-HH-1638_t1.nii.gz,/Train/ixi/mask/IXI217-HH-1638_mask.nii.gz,,2.0 22 | 464,IXI498-Guys-1050_t1.nii.gz,2,0,0,0,0,0,0,1,2006-08-25,76.78,0,/Train/ixi/t1/IXI498-Guys-1050_t1.nii.gz,/Train/ixi/mask/IXI498-Guys-1050_mask.nii.gz,,3.0 23 | 517,IXI546-HH-2450_t1.nii.gz,2,165,60,1,2,2,5,1,2006-09-28,60.07,0,/Train/ixi/t1/IXI546-HH-2450_t1.nii.gz,/Train/ixi/mask/IXI546-HH-2450_mask.nii.gz,,3.0 24 | 492,IXI525-HH-2413_t1.nii.gz,1,185,92,1,2,1,5,1,2006-09-21,41.52,0,/Train/ixi/t1/IXI525-HH-2413_t1.nii.gz,/Train/ixi/mask/IXI525-HH-2413_mask.nii.gz,,2.0 25 | 375,IXI409-Guys-0960_t1.nii.gz,2,0,64,1,1,5,1,1,2006-06-17,70.95,0,/Train/ixi/t1/IXI409-Guys-0960_t1.nii.gz,/Train/ixi/mask/IXI409-Guys-0960_mask.nii.gz,,3.0 26 | 511,IXI539-Guys-1067_t1.nii.gz,1,0,0,0,0,0,0,1,2006-09-08,78.07,0,/Train/ixi/t1/IXI539-Guys-1067_t1.nii.gz,/Train/ixi/mask/IXI539-Guys-1067_mask.nii.gz,,3.0 27 | 550,IXI576-Guys-1077_t1.nii.gz,2,162,58,1,4,5,4,1,2006-10-02,67.33,0,/Train/ixi/t1/IXI576-Guys-1077_t1.nii.gz,/Train/ixi/mask/IXI576-Guys-1077_mask.nii.gz,,3.0 28 | 537,IXI562-Guys-1131_t1.nii.gz,2,157,55,1,4,7,5,1,2006-10-09,42.79,0,/Train/ixi/t1/IXI562-Guys-1131_t1.nii.gz,/Train/ixi/mask/IXI562-Guys-1131_mask.nii.gz,,2.0 29 | 475,IXI508-HH-2268_t1.nii.gz,2,160,54,1,5,5,2,1,2006-08-10,60.93,0,/Train/ixi/t1/IXI508-HH-2268_t1.nii.gz,/Train/ixi/mask/IXI508-HH-2268_mask.nii.gz,,3.0 30 | 551,IXI577-HH-2661_t1.nii.gz,2,170,67,1,2,1,1,1,2006-11-02,64.19,0,/Train/ixi/t1/IXI577-HH-2661_t1.nii.gz,/Train/ixi/mask/IXI577-HH-2661_mask.nii.gz,,3.0 31 | 47,IXI059-HH-1284_t1.nii.gz,1,188,90,1,2,1,5,1,2005-07-04,34.14,0,/Train/ixi/t1/IXI059-HH-1284_t1.nii.gz,/Train/ixi/mask/IXI059-HH-1284_mask.nii.gz,,1.0 32 | 232,IXI257-HH-1724_t1.nii.gz,1,175,62,3,1,5,4,1,2005-12-22,69.55,0,/Train/ixi/t1/IXI257-HH-1724_t1.nii.gz,/Train/ixi/mask/IXI257-HH-1724_mask.nii.gz,,3.0 33 | 114,IXI132-HH-1415_t1.nii.gz,1,163,60,1,1,1,5,1,2005-08-18,58.72,0,/Train/ixi/t1/IXI132-HH-1415_t1.nii.gz,/Train/ixi/mask/IXI132-HH-1415_mask.nii.gz,,2.0 34 | 65,IXI077-Guys-0752_t1.nii.gz,2,175,65,1,1,1,5,1,2005-08-12,36.48,0,/Train/ixi/t1/IXI077-Guys-0752_t1.nii.gz,/Train/ixi/mask/IXI077-Guys-0752_mask.nii.gz,,1.0 35 | 201,IXI224-Guys-0823_t1.nii.gz,2,168,89,1,1,1,4,1,2005-11-11,36.92,0,/Train/ixi/t1/IXI224-Guys-0823_t1.nii.gz,/Train/ixi/mask/IXI224-Guys-0823_mask.nii.gz,,1.0 36 | 282,IXI310-IOP-0890_t1.nii.gz,2,150,55,1,4,1,5,1,2006-02-28,62.02,0,/Train/ixi/t1/IXI310-IOP-0890_t1.nii.gz,/Train/ixi/mask/IXI310-IOP-0890_mask.nii.gz,,3.0 37 | 340,IXI375-Guys-0925_t1.nii.gz,2,187,96,2,0,0,0,1,2006-04-25,41.06,0,/Train/ixi/t1/IXI375-Guys-0925_t1.nii.gz,/Train/ixi/mask/IXI375-Guys-0925_mask.nii.gz,,2.0 38 | 75,IXI090-Guys-0800_t1.nii.gz,1,178,90,3,1,3,5,1,2005-10-21,41.82,0,/Train/ixi/t1/IXI090-Guys-0800_t1.nii.gz,/Train/ixi/mask/IXI090-Guys-0800_mask.nii.gz,,2.0 39 | 466,IXI499-Guys-1004_t1.nii.gz,1,0,0,0,0,0,0,1,2006-07-28,82.19,0,/Train/ixi/t1/IXI499-Guys-1004_t1.nii.gz,/Train/ixi/mask/IXI499-Guys-1004_mask.nii.gz,,4.0 40 | 
71,IXI085-Guys-0759_t1.nii.gz,2,166,59,6,1,1,5,1,2005-08-19,31.85,0,/Train/ixi/t1/IXI085-Guys-0759_t1.nii.gz,/Train/ixi/mask/IXI085-Guys-0759_mask.nii.gz,,1.0 41 | 66,IXI078-Guys-0751_t1.nii.gz,2,168,67,1,1,1,5,1,2005-08-12,28.51,0,/Train/ixi/t1/IXI078-Guys-0751_t1.nii.gz,/Train/ixi/mask/IXI078-Guys-0751_mask.nii.gz,,1.0 42 | 585,IXI616-Guys-1092_t1.nii.gz,1,162,70,1,2,1,3,1,2006-11-06,55.09,0,/Train/ixi/t1/IXI616-Guys-1092_t1.nii.gz,/Train/ixi/mask/IXI616-Guys-1092_mask.nii.gz,,2.0 43 | 610,IXI646-HH-2653_t1.nii.gz,2,165,60,6,5,5,3,1,2006-11-01,71.21,0,/Train/ixi/t1/IXI646-HH-2653_t1.nii.gz,/Train/ixi/mask/IXI646-HH-2653_mask.nii.gz,,3.0 44 | 92,IXI107-Guys-0761_t1.nii.gz,2,173,57,1,1,1,5,1,2005-08-24,31.9,0,/Train/ixi/t1/IXI107-Guys-0761_t1.nii.gz,/Train/ixi/mask/IXI107-Guys-0761_mask.nii.gz,,1.0 45 | 69,IXI083-HH-1357_t1.nii.gz,1,182,78,1,2,1,5,1,2005-08-04,30.89,0,/Train/ixi/t1/IXI083-HH-1357_t1.nii.gz,/Train/ixi/mask/IXI083-HH-1357_mask.nii.gz,,1.0 46 | -------------------------------------------------------------------------------- /Data/splits/IXI_val_fold4.csv: -------------------------------------------------------------------------------- 1 | ,img_name,"SEX_ID (1=m, 2=f)",HEIGHT,WEIGHT,ETHNIC_ID,MARITAL_ID,OCCUPATION_ID,QUALIFICATION_ID,DATE_AVAILABLE,STUDY_DATE,age,label,img_path,mask_path,seg_path,Agegroup 2 | 614,IXI652-Guys-1116_t1.nii.gz,1,163,80,1,1,1,5,1,2006-12-01,42.99,0,/Train/ixi/t1/IXI652-Guys-1116_t1.nii.gz,/Train/ixi/mask/IXI652-Guys-1116_mask.nii.gz,,2.0 3 | 246,IXI274-HH-2294_t1.nii.gz,2,163,75,4,5,5,4,1,2006-08-17,73.54,0,/Train/ixi/t1/IXI274-HH-2294_t1.nii.gz,/Train/ixi/mask/IXI274-HH-2294_mask.nii.gz,,3.0 4 | 576,IXI606-HH-2601_t1.nii.gz,1,178,86,1,1,8,5,1,2006-10-25,60.54,0,/Train/ixi/t1/IXI606-HH-2601_t1.nii.gz,/Train/ixi/mask/IXI606-HH-2601_mask.nii.gz,,3.0 5 | 150,IXI172-Guys-0982_t1.nii.gz,1,185,100,1,2,5,5,1,2006-07-14,74.99,0,/Train/ixi/t1/IXI172-Guys-0982_t1.nii.gz,/Train/ixi/mask/IXI172-Guys-0982_mask.nii.gz,,3.0 6 | 225,IXI251-Guys-1055_t1.nii.gz,2,167,74,1,5,5,1,1,2006-08-29,80.17,0,/Train/ixi/t1/IXI251-Guys-1055_t1.nii.gz,/Train/ixi/mask/IXI251-Guys-1055_mask.nii.gz,,4.0 7 | 106,IXI122-Guys-0773_t1.nii.gz,2,168,70,1,1,3,5,1,2005-09-02,22.97,0,/Train/ixi/t1/IXI122-Guys-0773_t1.nii.gz,/Train/ixi/mask/IXI122-Guys-0773_mask.nii.gz,,1.0 8 | 359,IXI393-Guys-0941_t1.nii.gz,2,175,90,1,2,7,4,1,2006-05-05,60.1,0,/Train/ixi/t1/IXI393-Guys-0941_t1.nii.gz,/Train/ixi/mask/IXI393-Guys-0941_mask.nii.gz,,3.0 9 | 495,IXI528-Guys-1073_t1.nii.gz,1,172,80,1,3,1,5,1,2006-09-18,40.53,0,/Train/ixi/t1/IXI528-Guys-1073_t1.nii.gz,/Train/ixi/mask/IXI528-Guys-1073_mask.nii.gz,,2.0 10 | 514,IXI542-IOP-1147_t1.nii.gz,2,154,64,1,1,1,5,1,2006-09-27,44.38,0,/Train/ixi/t1/IXI542-IOP-1147_t1.nii.gz,/Train/ixi/mask/IXI542-IOP-1147_mask.nii.gz,,2.0 11 | 422,IXI455-Guys-0981_t1.nii.gz,2,164,63,1,2,5,5,1,2006-07-10,68.48,0,/Train/ixi/t1/IXI455-Guys-0981_t1.nii.gz,/Train/ixi/mask/IXI455-Guys-0981_mask.nii.gz,,3.0 12 | 178,IXI201-HH-1588_t1.nii.gz,2,150,54,4,1,3,1,1,2005-10-13,22.41,0,/Train/ixi/t1/IXI201-HH-1588_t1.nii.gz,/Train/ixi/mask/IXI201-HH-1588_mask.nii.gz,,1.0 13 | 580,IXI610-HH-2649_t1.nii.gz,1,180,75,1,2,5,4,1,2006-11-01,56.68,0,/Train/ixi/t1/IXI610-HH-2649_t1.nii.gz,/Train/ixi/mask/IXI610-HH-2649_mask.nii.gz,,2.0 14 | 521,IXI549-Guys-1046_t1.nii.gz,2,162,54,0,0,0,0,1,2006-08-24,25.02,0,/Train/ixi/t1/IXI549-Guys-1046_t1.nii.gz,/Train/ixi/mask/IXI549-Guys-1046_mask.nii.gz,,1.0 15 | 
80,IXI095-HH-1390_t1.nii.gz,2,170,64,1,1,3,5,1,2005-08-11,24.9,0,/Train/ixi/t1/IXI095-HH-1390_t1.nii.gz,/Train/ixi/mask/IXI095-HH-1390_mask.nii.gz,,1.0 16 | 124,IXI143-Guys-0785_t1.nii.gz,2,173,67,3,2,1,5,1,2005-09-14,31.96,0,/Train/ixi/t1/IXI143-Guys-0785_t1.nii.gz,/Train/ixi/mask/IXI143-Guys-0785_mask.nii.gz,,1.0 17 | 289,IXI317-Guys-0896_t1.nii.gz,1,170,73,1,2,2,5,1,2006-03-10,67.79,0,/Train/ixi/t1/IXI317-Guys-0896_t1.nii.gz,/Train/ixi/mask/IXI317-Guys-0896_mask.nii.gz,,3.0 18 | 58,IXI070-Guys-0767_t1.nii.gz,2,173,64,1,1,3,3,1,2005-08-26,20.7,0,/Train/ixi/t1/IXI070-Guys-0767_t1.nii.gz,/Train/ixi/mask/IXI070-Guys-0767_mask.nii.gz,,1.0 19 | 28,IXI040-Guys-0724_t1.nii.gz,2,0,68,1,3,2,5,1,2005-07-22,44.09,0,/Train/ixi/t1/IXI040-Guys-0724_t1.nii.gz,/Train/ixi/mask/IXI040-Guys-0724_mask.nii.gz,,2.0 20 | 469,IXI502-Guys-1020_t1.nii.gz,2,165,65,1,2,5,5,1,2006-08-07,63.88,0,/Train/ixi/t1/IXI502-Guys-1020_t1.nii.gz,/Train/ixi/mask/IXI502-Guys-1020_mask.nii.gz,,3.0 21 | 10,IXI021-Guys-0703_t1.nii.gz,2,165,64,1,1,3,3,1,2005-06-30,21.57,0,/Train/ixi/t1/IXI021-Guys-0703_t1.nii.gz,/Train/ixi/mask/IXI021-Guys-0703_mask.nii.gz,,1.0 22 | 467,IXI500-Guys-1017_t1.nii.gz,1,186,83,1,2,5,3,1,2006-08-04,63.18,0,/Train/ixi/t1/IXI500-Guys-1017_t1.nii.gz,/Train/ixi/mask/IXI500-Guys-1017_mask.nii.gz,,3.0 23 | 532,IXI558-Guys-1079_t1.nii.gz,2,162,55,1,5,8,5,1,2006-10-06,66.86,0,/Train/ixi/t1/IXI558-Guys-1079_t1.nii.gz,/Train/ixi/mask/IXI558-Guys-1079_mask.nii.gz,,3.0 24 | 610,IXI646-HH-2653_t1.nii.gz,2,165,60,6,5,5,3,1,2006-11-01,71.21,0,/Train/ixi/t1/IXI646-HH-2653_t1.nii.gz,/Train/ixi/mask/IXI646-HH-2653_mask.nii.gz,,3.0 25 | 507,IXI536-Guys-1059_t1.nii.gz,2,0,0,0,0,0,0,1,2006-08-31,58.83,0,/Train/ixi/t1/IXI536-Guys-1059_t1.nii.gz,/Train/ixi/mask/IXI536-Guys-1059_mask.nii.gz,,2.0 26 | 531,IXI556-HH-2452_t1.nii.gz,1,186,88,1,1,6,2,1,2006-10-05,55.64,0,/Train/ixi/t1/IXI556-HH-2452_t1.nii.gz,/Train/ixi/mask/IXI556-HH-2452_mask.nii.gz,,2.0 27 | 557,IXI584-Guys-1129_t1.nii.gz,1,182,96,1,3,1,5,1,2006-10-16,41.33,0,/Train/ixi/t1/IXI584-Guys-1129_t1.nii.gz,/Train/ixi/mask/IXI584-Guys-1129_mask.nii.gz,,2.0 28 | 559,IXI586-HH-2451_t1.nii.gz,1,175,78,1,3,1,5,1,2006-10-05,34.37,0,/Train/ixi/t1/IXI586-HH-2451_t1.nii.gz,/Train/ixi/mask/IXI586-HH-2451_mask.nii.gz,,1.0 29 | 207,IXI231-IOP-0866_t1.nii.gz,2,171,92,1,0,0,0,1,2006-02-07,58.99,0,/Train/ixi/t1/IXI231-IOP-0866_t1.nii.gz,/Train/ixi/mask/IXI231-IOP-0866_mask.nii.gz,,2.0 30 | 502,IXI534-Guys-1062_t1.nii.gz,1,175,88,1,2,2,1,1,2006-09-01,69.61,0,/Train/ixi/t1/IXI534-Guys-1062_t1.nii.gz,/Train/ixi/mask/IXI534-Guys-1062_mask.nii.gz,,3.0 31 | 96,IXI111-Guys-0734_t1.nii.gz,1,176,75,3,1,3,5,1,2005-08-01,25.19,0,/Train/ixi/t1/IXI111-Guys-0734_t1.nii.gz,/Train/ixi/mask/IXI111-Guys-0734_mask.nii.gz,,1.0 32 | 111,IXI129-Guys-0775_t1.nii.gz,1,183,74,1,1,3,5,1,2005-09-02,23.04,0,/Train/ixi/t1/IXI129-Guys-0775_t1.nii.gz,/Train/ixi/mask/IXI129-Guys-0775_mask.nii.gz,,1.0 33 | 546,IXI572-HH-2605_t1.nii.gz,2,170,57,2,4,1,5,1,2006-10-26,41.95,0,/Train/ixi/t1/IXI572-HH-2605_t1.nii.gz,/Train/ixi/mask/IXI572-HH-2605_mask.nii.gz,,2.0 34 | 268,IXI296-HH-1970_t1.nii.gz,2,152,54,1,5,8,4,1,2006-03-16,63.97,0,/Train/ixi/t1/IXI296-HH-1970_t1.nii.gz,/Train/ixi/mask/IXI296-HH-1970_mask.nii.gz,,3.0 35 | 32,IXI044-Guys-0712_t1.nii.gz,2,163,69,1,3,3,5,1,2005-07-13,44.85,0,/Train/ixi/t1/IXI044-Guys-0712_t1.nii.gz,/Train/ixi/mask/IXI044-Guys-0712_mask.nii.gz,,2.0 36 | 
94,IXI109-Guys-0732_t1.nii.gz,2,164,66,5,2,6,5,1,2005-08-01,36.47,0,/Train/ixi/t1/IXI109-Guys-0732_t1.nii.gz,/Train/ixi/mask/IXI109-Guys-0732_mask.nii.gz,,1.0 37 | 77,IXI092-HH-1436_t1.nii.gz,1,180,80,4,2,1,4,1,2005-08-25,33.24,0,/Train/ixi/t1/IXI092-HH-1436_t1.nii.gz,/Train/ixi/mask/IXI092-HH-1436_mask.nii.gz,,1.0 38 | 31,IXI043-Guys-0714_t1.nii.gz,2,160,58,1,1,3,5,1,2005-07-13,22.65,0,/Train/ixi/t1/IXI043-Guys-0714_t1.nii.gz,/Train/ixi/mask/IXI043-Guys-0714_mask.nii.gz,,1.0 39 | 260,IXI288-Guys-0879_t1.nii.gz,2,163,65,1,2,1,5,1,2006-02-17,78.58,0,/Train/ixi/t1/IXI288-Guys-0879_t1.nii.gz,/Train/ixi/mask/IXI288-Guys-0879_mask.nii.gz,,3.0 40 | 581,IXI611-HH-2650_t1.nii.gz,1,175,76,1,1,1,2,1,2006-11-01,28.87,0,/Train/ixi/t1/IXI611-HH-2650_t1.nii.gz,/Train/ixi/mask/IXI611-HH-2650_mask.nii.gz,,1.0 41 | 565,IXI594-Guys-1089_t1.nii.gz,1,176,99,1,1,5,5,1,2006-11-03,62.0,0,/Train/ixi/t1/IXI594-Guys-1089_t1.nii.gz,/Train/ixi/mask/IXI594-Guys-1089_mask.nii.gz,,3.0 42 | 611,IXI648-Guys-1107_t1.nii.gz,1,193,120,1,1,6,4,1,2006-11-27,47.72,0,/Train/ixi/t1/IXI648-Guys-1107_t1.nii.gz,/Train/ixi/mask/IXI648-Guys-1107_mask.nii.gz,,2.0 43 | 593,IXI625-Guys-1098_t1.nii.gz,1,167,75,1,2,1,4,1,2006-11-17,47.07,0,/Train/ixi/t1/IXI625-Guys-1098_t1.nii.gz,/Train/ixi/mask/IXI625-Guys-1098_mask.nii.gz,,2.0 44 | 346,IXI381-Guys-1024_t1.nii.gz,2,160,75,1,3,1,4,1,2006-08-11,59.67,0,/Train/ixi/t1/IXI381-Guys-1024_t1.nii.gz,/Train/ixi/mask/IXI381-Guys-1024_mask.nii.gz,,2.0 45 | 9,IXI020-Guys-0700_t1.nii.gz,1,178,72,1,2,1,4,1,2005-06-24,39.47,0,/Train/ixi/t1/IXI020-Guys-0700_t1.nii.gz,/Train/ixi/mask/IXI020-Guys-0700_mask.nii.gz,,1.0 46 | -------------------------------------------------------------------------------- /Data/splits/MSLUB_test.csv: -------------------------------------------------------------------------------- 1 | ,img_name,age,sex,ms_type,label,img_path,mask_path,seg_path,Agegroup 2 | 25,patient26,40,F,RR,1,/Test/MSLUB/t1/patient26_t1.nii.gz,/Test/MSLUB/mask/patient26_mask.nii.gz,/Test/MSLUB/seg/patient26_seg.nii.gz,1.0 3 | 12,patient13,26,M,RR,1,/Test/MSLUB/t1/patient13_t1.nii.gz,/Test/MSLUB/mask/patient13_mask.nii.gz,/Test/MSLUB/seg/patient13_seg.nii.gz,1.0 4 | 13,patient14,42,M,RR,1,/Test/MSLUB/t1/patient14_t1.nii.gz,/Test/MSLUB/mask/patient14_mask.nii.gz,/Test/MSLUB/seg/patient14_seg.nii.gz,2.0 5 | 6,patient07,53,F,RR,1,/Test/MSLUB/t1/patient07_t1.nii.gz,/Test/MSLUB/mask/patient07_mask.nii.gz,/Test/MSLUB/seg/patient07_seg.nii.gz,2.0 6 | 7,patient08,41,M,RR,1,/Test/MSLUB/t1/patient08_t1.nii.gz,/Test/MSLUB/mask/patient08_mask.nii.gz,/Test/MSLUB/seg/patient08_seg.nii.gz,2.0 7 | 8,patient09,40,F,RR,1,/Test/MSLUB/t1/patient09_t1.nii.gz,/Test/MSLUB/mask/patient09_mask.nii.gz,/Test/MSLUB/seg/patient09_seg.nii.gz,1.0 8 | 10,patient11,29,M,RR,1,/Test/MSLUB/t1/patient11_t1.nii.gz,/Test/MSLUB/mask/patient11_mask.nii.gz,/Test/MSLUB/seg/patient11_seg.nii.gz,1.0 9 | 20,patient21,33,F,RR,1,/Test/MSLUB/t1/patient21_t1.nii.gz,/Test/MSLUB/mask/patient21_mask.nii.gz,/Test/MSLUB/seg/patient21_seg.nii.gz,1.0 10 | 9,patient10,64,F,RR,1,/Test/MSLUB/t1/patient10_t1.nii.gz,/Test/MSLUB/mask/patient10_mask.nii.gz,/Test/MSLUB/seg/patient10_seg.nii.gz,3.0 11 | 29,patient30,54,F,RR,1,/Test/MSLUB/t1/patient30_t1.nii.gz,/Test/MSLUB/mask/patient30_mask.nii.gz,/Test/MSLUB/seg/patient30_seg.nii.gz,2.0 12 | 5,patient06,37,F,SP,1,/Test/MSLUB/t1/patient06_t1.nii.gz,/Test/MSLUB/mask/patient06_mask.nii.gz,/Test/MSLUB/seg/patient06_seg.nii.gz,1.0 13 | 
24,patient25,35,F,RR,1,/Test/MSLUB/t1/patient25_t1.nii.gz,/Test/MSLUB/mask/patient25_mask.nii.gz,/Test/MSLUB/seg/patient25_seg.nii.gz,1.0 14 | 26,patient27,39,F,RR,1,/Test/MSLUB/t1/patient27_t1.nii.gz,/Test/MSLUB/mask/patient27_mask.nii.gz,/Test/MSLUB/seg/patient27_seg.nii.gz,1.0 15 | 23,patient24,43,M,RR,1,/Test/MSLUB/t1/patient24_t1.nii.gz,/Test/MSLUB/mask/patient24_mask.nii.gz,/Test/MSLUB/seg/patient24_seg.nii.gz,2.0 16 | 22,patient23,39,F,RR,1,/Test/MSLUB/t1/patient23_t1.nii.gz,/Test/MSLUB/mask/patient23_mask.nii.gz,/Test/MSLUB/seg/patient23_seg.nii.gz,1.0 17 | 15,patient16,42,F,RR,1,/Test/MSLUB/t1/patient16_t1.nii.gz,/Test/MSLUB/mask/patient16_mask.nii.gz,/Test/MSLUB/seg/patient16_seg.nii.gz,2.0 18 | 1,patient02,33,M,CIS,1,/Test/MSLUB/t1/patient02_t1.nii.gz,/Test/MSLUB/mask/patient02_mask.nii.gz,/Test/MSLUB/seg/patient02_seg.nii.gz,1.0 19 | 11,patient12,39,F,RR,1,/Test/MSLUB/t1/patient12_t1.nii.gz,/Test/MSLUB/mask/patient12_mask.nii.gz,/Test/MSLUB/seg/patient12_seg.nii.gz,1.0 20 | 18,patient19,47,F,RR,1,/Test/MSLUB/t1/patient19_t1.nii.gz,/Test/MSLUB/mask/patient19_mask.nii.gz,/Test/MSLUB/seg/patient19_seg.nii.gz,2.0 21 | 3,patient04,25,M,SP,1,/Test/MSLUB/t1/patient04_t1.nii.gz,/Test/MSLUB/mask/patient04_mask.nii.gz,/Test/MSLUB/seg/patient04_seg.nii.gz,1.0 22 | -------------------------------------------------------------------------------- /Data/splits/MSLUB_val.csv: -------------------------------------------------------------------------------- 1 | ,img_name,age,sex,ms_type,label,img_path,mask_path,seg_path,Agegroup 2 | 28,patient29,26,F,CIS,1,/Test/MSLUB/t1/patient29_t1.nii.gz,/Test/MSLUB/mask/patient29_mask.nii.gz,/Test/MSLUB/seg/patient29_seg.nii.gz,1.0 3 | 14,patient15,57,F,PR,1,/Test/MSLUB/t1/patient15_t1.nii.gz,/Test/MSLUB/mask/patient15_mask.nii.gz,/Test/MSLUB/seg/patient15_seg.nii.gz,2.0 4 | 2,patient03,37,F,,1,/Test/MSLUB/t1/patient03_t1.nii.gz,/Test/MSLUB/mask/patient03_mask.nii.gz,/Test/MSLUB/seg/patient03_seg.nii.gz,1.0 5 | 27,patient28,39,F,RR,1,/Test/MSLUB/t1/patient28_t1.nii.gz,/Test/MSLUB/mask/patient28_mask.nii.gz,/Test/MSLUB/seg/patient28_seg.nii.gz,1.0 6 | 19,patient20,37,F,RR,1,/Test/MSLUB/t1/patient20_t1.nii.gz,/Test/MSLUB/mask/patient20_mask.nii.gz,/Test/MSLUB/seg/patient20_seg.nii.gz,1.0 7 | 16,patient17,27,F,RR,1,/Test/MSLUB/t1/patient17_t1.nii.gz,/Test/MSLUB/mask/patient17_mask.nii.gz,/Test/MSLUB/seg/patient17_seg.nii.gz,1.0 8 | 21,patient22,30,F,RR,1,/Test/MSLUB/t1/patient22_t1.nii.gz,/Test/MSLUB/mask/patient22_mask.nii.gz,/Test/MSLUB/seg/patient22_seg.nii.gz,1.0 9 | 17,patient18,60,F,RR,1,/Test/MSLUB/t1/patient18_t1.nii.gz,/Test/MSLUB/mask/patient18_mask.nii.gz,/Test/MSLUB/seg/patient18_seg.nii.gz,2.0 10 | 0,patient01,31,F,RR,1,/Test/MSLUB/t1/patient01_t1.nii.gz,/Test/MSLUB/mask/patient01_mask.nii.gz,/Test/MSLUB/seg/patient01_seg.nii.gz,1.0 11 | 4,patient05,33,F,RR,1,/Test/MSLUB/t1/patient05_t1.nii.gz,/Test/MSLUB/mask/patient05_mask.nii.gz,/Test/MSLUB/seg/patient05_seg.nii.gz,1.0 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conditioned-Diffusion-Models-UAD 2 | Codebase for the paper [Guided Reconstruction with Conditioned Diffusion Models for Unsupervised Anomaly Detection in Brain MRIs](https://www.sciencedirect.com/science/article/pii/S0010482525000101). 3 | 4 | **Abstract**: 5 | Unsupervised anomaly detection in Brain MRIs aims to identify abnormalities as outliers from a healthy training distribution. 
Reconstruction-based approaches that use generative models to learn to reconstruct healthy brain anatomy are commonly used for this task. Diffusion models are an emerging class of deep generative models that show great potential regarding reconstruction fidelity. However, they face challenges in preserving intensity characteristics in the reconstructed images, limiting their performance in anomaly detection. To address this challenge, we propose to condition the denoising mechanism of diffusion models with additional information about the image to reconstruct coming from a latent representation of the noise-free input image. This conditioning enables high-fidelity reconstruction of healthy brain structures while aligning local intensity characteristics of input-reconstruction pairs. We evaluate our method's reconstruction quality, domain adaptation features, and finally segmentation performance on publicly available data sets with various pathologies. Using our proposed conditioning mechanism, we can reduce the false-positive predictions and enable a more precise delineation of anomalies, which significantly enhances the anomaly detection performance compared to established state-of-the-art approaches to unsupervised anomaly detection in brain MRI. Furthermore, our approach shows promise in domain adaptation across different MRI acquisitions and simulated contrasts, a crucial property of general anomaly detection methods. 6 | ## Model Architecture 7 | 8 | ![Model Architecture](cDDPM_Model.png) 9 | 10 | 11 | ## Data 12 | We use the IXI, BraTS21, MSLUB, ATLAS_v2, and WMH data sets for our experiments. 13 | You can download/request the original data sets here: 14 | 15 | * [IXI](https://brain-development.org/ixi-dataset/) 16 | * [BraTS21](http://braintumorsegmentation.org/) 17 | * [MSLUB](https://lit.fe.uni-lj.si/en/research/resources/3D-MR-MS/) 18 | * [ATLAS v2](https://fcon_1000.projects.nitrc.org/indi/retro/atlas.html) 19 | * [WMH](https://dataverse.nl/dataset.xhtml?persistentId=doi:10.34894/AECRSD) 20 | 21 | If you’d like to skip preprocessing, we’ve made preprocessed versions of the data sets available [here](https://1drv.ms/u/c/66229029a9e95461/EVb21X1kmXxCh_xfqMNmzH8B1Rqe_wWDHYzoQuiGj94k3Q?e=wjFP6h) (approx. 37 GB). 22 | 23 | After downloading, the directory structure of `<DATA_DIR>` should look like this: 24 | 25 | 26 | ├── Train 27 | │ ├── ixi 28 | │ │ ├── mask 29 | │ │ ├── t2 30 | │ │ └── t1 31 | ├── Test 32 | │ ├── Brats21 33 | │ │ ├── mask 34 | │ │ ├── t2 35 | │ │ └── seg 36 | │ ├── MSLUB 37 | │ │ ├── mask 38 | │ │ ├── t2 39 | │ │ └── seg 40 | │ ├── ATLAS_v2 41 | │ │ ├── mask 42 | │ │ ├── t1 43 | │ │ └── seg 44 | │ └── ... 45 | ├── splits 46 | │ ├── Brats21_test.csv 47 | │ ├── Brats21_val.csv 48 | │ ├── MSLUB_val.csv 49 | │ ├── MSLUB_test.csv 50 | │ ├── IXI_train_fold0.csv 51 | │ ├── IXI_train_fold1.csv 52 | │ └── ... 53 | └── ... 54 | 55 | You should then specify the location of `<DATA_DIR>` in the pc_environment.env file. Additionally, specify the `<LOG_DIR>`, where runs will be saved; a minimal sketch of this file is shown in the Environment Set-up section below. 56 | 57 | ## Environment Set-up 58 | To download the code, type 59 | 60 | git clone git@github.com:FinnBehrendt/conditioned-Diffusion-Models-UAD.git 61 | 62 | in your Linux terminal, and switch directories via 63 | 64 | cd conditioned-Diffusion-Models-UAD 65 | 66 | To set up the environment with all required packages and libraries, you need to install Anaconda first.
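While you are at it, you can also create the pc_environment.env file mentioned above. At minimum, it needs to define the two environment variables that configs/config.yaml resolves via `${oc.env:DATA_DIR}` and `${oc.env:LOG_DIR}`. A minimal sketch (the two paths are placeholders you need to adapt to your machine):

    DATA_DIR=/absolute/path/to/preprocessed_data    # root of the downloaded/preprocessed data described above
    LOG_DIR=/absolute/path/to/logs                  # runs and checkpoints are saved here

However the file is loaded in your setup (e.g., by sourcing it in your shell before running), both variables must be set in the environment when run.py starts, since Hydra resolves them at startup.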
67 | 68 | Then, run 69 | 70 | conda env create -f environment.yml -n cddpm-uad 71 | 72 | and subsequently run 73 | 74 | conda activate cddpm-uad 75 | pip install -r requirements.txt 76 | 77 | to install all required packages. 78 | 79 | ## Run Experiments 80 | 81 | To run the training and evaluation of the cDDPM without pretraining, you can simply run 82 | 83 | python run.py experiment=cDDPM/DDPM_cond_spark_2D model.cfg.pretrained_encoder=False 84 | 85 | For better performance, you can pretrain the encoder via masked pretraining (Spark) by running 86 | 87 | python run.py experiment=cDDPM/Spark_2D_pretrain 88 | 89 | Having pretrained the encoder, you can now run 90 | 91 | python run.py experiment=cDDPM/DDPM_cond_spark_2D encoder_path=<path_to_encoder_ckpt> 92 | 93 | The checkpoint of the pretrained encoder will be placed in the `<LOG_DIR>`. Alternatively, you will find the best checkpoint path printed in the terminal. 94 | 95 | ## Citation 96 | If you make use of our work, you can cite it via 97 | 98 | 99 | @article{BEHRENDT2025109660, 100 | abstract = {The application of supervised models to clinical screening tasks is challenging due to the need for annotated data for each considered pathology. Unsupervised Anomaly Detection (UAD) is an alternative approach that aims to identify any anomaly as an outlier from a healthy training distribution. A prevalent strategy for UAD in brain MRI involves using generative models to learn the reconstruction of healthy brain anatomy for a given input image. As these models should fail to reconstruct unhealthy structures, the reconstruction errors indicate anomalies. However, a significant challenge is to balance the accurate reconstruction of healthy anatomy and the undesired replication of abnormal structures. While diffusion models have shown promising results with detailed and accurate reconstructions, they face challenges in preserving intensity characteristics, resulting in false positives. We propose conditioning the denoising process of diffusion models with additional information derived from a latent representation of the input image. We demonstrate that this conditioning allows for accurate and local adaptation to the general input intensity distribution while avoiding the replication of unhealthy structures. We compare the novel approach to different state-of-the-art methods and for different data sets. Our results show substantial improvements in the segmentation performance, with the Dice score improved by 11.9%, 20.0%, and 44.6%, for the BraTS, ATLAS and MSLUB data sets, respectively, while maintaining competitive performance on the WMH data set. Furthermore, our results indicate effective domain adaptation across different MRI acquisitions and simulated contrasts, an important attribute for general anomaly detection methods.
The code for our work is available at https://github.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD.}, 101 | author = {Finn Behrendt and Debayan Bhattacharya and Robin Mieling and Lennart Maack and Julia Kr{\"u}ger and Roland Opfer and Alexander Schlaefer}, 102 | doi = {https://doi.org/10.1016/j.compbiomed.2025.109660}, 103 | issn = {0010-4825}, 104 | journal = {Computers in Biology and Medicine}, 105 | keywords = {Unsupervised anomaly detection, Segmentation, Brain MRI, Diffusion models}, 106 | pages = {109660}, 107 | title = {Guided reconstruction with conditioned diffusion models for unsupervised anomaly detection in brain MRIs}, 108 | url = {https://www.sciencedirect.com/science/article/pii/S0010482525000101}, 109 | volume = {186}, 110 | year = {2025}, 111 | bdsk-url-1 = {https://www.sciencedirect.com/science/article/pii/S0010482525000101}, 112 | bdsk-url-2 = {https://doi.org/10.1016/j.compbiomed.2025.109660}} 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /cDDPM_Model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/cDDPM_Model.png -------------------------------------------------------------------------------- /configs/callbacks/checkpoint.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 3 | monitor: 'val/Loss_comb' 4 | save_top_k: 1 5 | auto_insert_metric_name: False 6 | save_last: True 7 | mode: "min" 8 | dirpath: "checkpoints/" 9 | filename: "epoch-{epoch}_step-{step}_loss-{val/Loss_comb:.2f}" 10 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - trainer: default.yaml 7 | - model: DDPM_2D.yaml 8 | - datamodule: IXI.yaml 9 | - callbacks: 10 | - checkpoint.yaml # set this to null if you don't want to use callbacks 11 | 12 | 13 | - logger: # set logger here or use command line (e.g. `python run.py logger=wandb`) 14 | - wandb 15 | - csv 16 | - experiment: DDPM.yaml # set experiment here or use command line (e.g. `python run.py experiment=DDPM`) 17 | - mode: default.yaml 18 | 19 | # enable color logging 20 | - override hydra/hydra_logging: colorlog 21 | - override hydra/job_logging: colorlog 22 | 23 | # path to original working directory 24 | # hydra hijacks working directory by changing it to the current log directory, 25 | # so it's useful to have this path as a special variable 26 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 27 | work_dir: ${hydra:runtime.cwd} 28 | 29 | # path to folder with data 30 | data_dir: ${oc.env:DATA_DIR} 31 | log_dir: ${oc.env:LOG_DIR} 32 | name : ${experiment.name} 33 | # use `python run.py debug=true` for easy debugging! 
34 | # this will run 1 train, val and test loop with only 1 batch 35 | # equivalent to running `python run.py trainer.fast_dev_run=true` 36 | # (this is placed here just for easier access from command line) 37 | debug: False 38 | 39 | # pretty print config at the start of the run using Rich library 40 | print_config: True 41 | 42 | # disable python warnings if they annoy you 43 | ignore_warnings: False 44 | 45 | # check performance on test set, using the best model achieved during training 46 | # lightning chooses best model based on metric specified in checkpoint callback 47 | test_after_training: True 48 | 49 | onlyEval: False 50 | new_wandb_run: False # if we want to reevaluate to a new wandb run 51 | checkpoint: 'best' # which checkpoints to load 52 | 53 | 54 | load_checkpoint: path_to_ckpt # path to checkpoint to load 55 | -------------------------------------------------------------------------------- /configs/datamodule/IXI.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.Datamodules_train.IXI 2 | 3 | cfg: 4 | name: IXI 5 | path: 6 | pathBase: ${data_dir} 7 | 8 | IXI: 9 | IDs: 10 | train: 11 | - ${data_dir}/Data/splits/IXI_train_fold0.csv 12 | - ${data_dir}/Data/splits/IXI_train_fold1.csv 13 | - ${data_dir}/Data/splits/IXI_train_fold2.csv 14 | - ${data_dir}/Data/splits/IXI_train_fold3.csv 15 | - ${data_dir}/Data/splits/IXI_train_fold4.csv 16 | val: 17 | - ${data_dir}/Data/splits/IXI_val_fold0.csv 18 | - ${data_dir}/Data/splits/IXI_val_fold1.csv 19 | - ${data_dir}/Data/splits/IXI_val_fold2.csv 20 | - ${data_dir}/Data/splits/IXI_val_fold3.csv 21 | - ${data_dir}/Data/splits/IXI_val_fold4.csv 22 | test: ${data_dir}/Data/splits/IXI_test.csv 23 | keep_t2: ${data_dir}/Data/splits/avail_t2.csv 24 | 25 | Brats21: 26 | IDs: 27 | test: ${data_dir}/Data/splits/Brats21_test.csv 28 | val: ${data_dir}/Data/splits/Brats21_val.csv 29 | 30 | MSLUB: 31 | IDs: 32 | test: ${data_dir}/Data/splits/MSLUB_test.csv 33 | val: ${data_dir}/Data/splits/MSLUB_val.csv 34 | 35 | 36 | imageDim: [160,192,160] 37 | rescaleFactor: 2 38 | interRes: [8,8,5] #[HxWxD] 39 | cropMode: 'isotropic' 40 | spatialDims: ${model.cfg.spatialDims} 41 | unisotropic_sampling: True 42 | sample_set: False 43 | 44 | preLoad: True 45 | curvatureFlow: True 46 | percentile: True 47 | pad: True 48 | permute: False 49 | 50 | # Augmentations 51 | randomRotate: False 52 | rotateDegree: 5 53 | horizontalFlip: False 54 | randomBrightness: False 55 | brightnessRange: (0.75,1.25) 56 | randomContrast: False 57 | contrastRange: (0.75,1.25) 58 | 59 | modelpath: ${data_dir}/Data/pretrained_2D_model/ 60 | num_workers: 4 61 | batch_size: 32 62 | lr : 0.0001 63 | droplast: True 64 | 65 | 66 | # Evaluation 67 | mode: t1 68 | resizedEvaluation: True 69 | testsets: # specify which test sets to evaluate! 70 | - Datamodules_eval.Brats21 71 | - Datamodules_eval.MSLUB 72 | - Datamodules_train.IXI 73 | 74 | 75 | -------------------------------------------------------------------------------- /configs/experiment/cDDPM/DDPM.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_full.yaml 5 | 6 | defaults: 7 | - override /trainer: default.yaml # override trainer to null so it's not loaded from main config defaults... 
8 | - override /model: DDPM_2D.yaml 9 | - override /datamodule: IXI.yaml 10 | datamodule: 11 | cfg: 12 | rescaleFactor: 2 13 | imageDim: [192,192,100] 14 | mode: t2 15 | aug_intensity: True 16 | model: 17 | cfg: 18 | test_timesteps: 500 19 | dim_mults: [1,2,2] 20 | unet_dim: 128 21 | objective: pred_x0 22 | loss: l1 23 | residualmode: l1 24 | OpenaiUnet: True # use openai unet 25 | conv_resample: True 26 | noisetype: simplex 27 | dropout_unet: 0.0 28 | condition: False 29 | num_folds: 1 30 | logger: 31 | wandb: 32 | project: MIDL23_DDPM 33 | 34 | ckpt_path: best 35 | 36 | trainer: 37 | max_epochs: 1600 38 | precision: 32 39 | name : DDPM_2D 40 | seed: 3141 41 | -------------------------------------------------------------------------------- /configs/experiment/cDDPM/DDPM_cond_spark_2D.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_full.yaml 5 | 6 | defaults: 7 | - override /trainer: default.yaml # override trainer to null so it's not loaded from main config defaults... 8 | - override /model: DDPM_2D.yaml 9 | - override /datamodule: IXI.yaml 10 | datamodule: 11 | cfg: 12 | rescaleFactor: 2 13 | imageDim: [192,192,100] 14 | mode: t2 15 | aug_intensity: True 16 | 17 | model: 18 | cfg: 19 | test_timesteps: 500 20 | dim_mults: [1,2,2] 21 | unet_dim: 128 22 | backbone: Spark_Encoder_2D 23 | version: resnet50 24 | cond_dim: 128 25 | OpenaiUnet: True # use openai unet 26 | spatial_transformer: False # use crossattention for conditional features 27 | condition: True # use conditional features 28 | noisetype: simplex 29 | encoder_path: xxx # path to encoder weights 30 | pretrained_encoder: True 31 | save_to_disc: False 32 | noise_ensemble: True 33 | num_folds: 1 34 | logger: 35 | wandb: 36 | project: cDDPM 37 | 38 | ckpt_path: best 39 | 40 | trainer: 41 | max_epochs: 1200 42 | name : DDPM_cond_2D_spark 43 | seed: 3141 -------------------------------------------------------------------------------- /configs/experiment/cDDPM/DDPM_patched.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_full.yaml 5 | 6 | defaults: 7 | - override /trainer: default.yaml # override trainer to null so it's not loaded from main config defaults... 
8 | - override /model: DDPM_2D_patched.yaml 9 | - override /datamodule: IXI.yaml 10 | datamodule: 11 | cfg: 12 | rescaleFactor: 2 13 | imageDim: [192,192,100] 14 | mode: t2 15 | aug_intensity: True 16 | model: 17 | cfg: 18 | test_timesteps: 500 19 | dim_mults: [1,2,2] 20 | unet_dim: 128 21 | objective: pred_x0 22 | loss: l1 23 | residualmode: l1 24 | OpenaiUnet: True # use openai unet 25 | conv_resample: True 26 | noisetype: simplex 27 | dropout_unet: 0.0 28 | patch_size: 48 # size of the patches 29 | grid_boxes: True # sample boxes from a fixed grid 30 | inpaint: True # solve inpainting task -- Loss calculation only for the patched region 31 | condition: False 32 | num_folds: 1 33 | logger: 34 | wandb: 35 | project: MIDL23_DDPM 36 | 37 | ckpt_path: best 38 | 39 | trainer: 40 | max_epochs: 1600 41 | precision: 32 42 | 43 | name : DDPM_2D_patched 44 | seed: 3141 45 | -------------------------------------------------------------------------------- /configs/experiment/cDDPM/Spark_2D_pretrain.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_full.yaml 5 | 6 | defaults: 7 | - override /trainer: default.yaml # override trainer to null so it's not loaded from main config defaults... 8 | - override /model: Spark_2D.yaml 9 | - override /datamodule: IXI.yaml 10 | datamodule: 11 | cfg: 12 | rescaleFactor: 2 13 | imageDim: [192,192,100] 14 | mode: t2 15 | model: 16 | cfg: 17 | backbone: resnet50 18 | loss_on_mask: True 19 | mask_ratio: 0.65 20 | num_folds: 1 21 | logger: 22 | wandb: 23 | project: cDDPM_pretrain 24 | 25 | ckpt_path: best 26 | 27 | trainer: 28 | max_epochs: 1200 29 | name : MAE_2D 30 | seed: 3141 31 | test_after_training: False -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | csv: 3 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 4 | save_dir: "." 5 | name: "csv/" 6 | version: "" 7 | prefix: "" -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | project: "patched DDPMs MIDL23" 6 | name: ${hydra:job.name} 7 | save_dir: "." 8 | offline: False # set True to store all logs only locally 9 | id: null # pass correct id to resume experiment! 10 | resume: False 11 | log_model: False 12 | prefix: "" 13 | job_type: "" 14 | group: "" 15 | tags: [] 16 | # note: "" -------------------------------------------------------------------------------- /configs/mode/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default running mode 4 | 5 | default_mode: True 6 | 7 | hydra: 8 | # output paths for hydra logs 9 | run: 10 | dir: ${log_dir}/logs/runs/${name}/${model.cfg.name}_${datamodule.cfg.name}_${name}_${hydra.job.override_dirname}_${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${log_dir}/logs/multiruns/${name}/${now:%Y-%m-%d_%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | # you can set here environment variables that are universal for all users 16 | # for system specific variables (like data paths) it's better to use .env file! 
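# For example, a minimal pc_environment.env could look like this
# (illustrative paths, not shipped defaults; the file itself ships with empty keys):
#   DATA_DIR=/absolute/path/to/data
#   LOG_DIR=/absolute/path/to/logs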
17 | job: 18 | name: ${model.cfg.name}_${datamodule.cfg.name}_${name}_${hydra.job.override_dirname} # This determines the 19 | config: 20 | override_dirname: 21 | item_sep: _ 22 | kv_sep: "-" 23 | exclude_keys: 24 | - experiment 25 | - trainer.gpus 26 | - load_checkpoint 27 | - datamodule.cfg.resizedEvaluation 28 | - datamodule.cfg.sample_set 29 | - model.cfg.encoder_path 30 | env_set: 31 | EXAMPLE_VAR: "example_value" 32 | 33 | -------------------------------------------------------------------------------- /configs/model/DDPM_2D.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.DDPM_2D.DDPM_2D 2 | 3 | cfg: 4 | name: DDPM_2D 5 | ## Data 6 | imageDim: ${datamodule.cfg.imageDim} 7 | rescaleFactor: ${datamodule.cfg.rescaleFactor} 8 | interRes: ${datamodule.cfg.interRes} 9 | cropMode: ${datamodule.cfg.cropMode} 10 | spatialDims: 2D 11 | resizedEvaluation: ${datamodule.cfg.resizedEvaluation} 12 | 13 | ## Architecture 14 | unet_dim: 128 15 | dim_mults: [1, 2, 2] 16 | learned_variance: False 17 | learned_sinusoidal_cond: False 18 | 19 | ## Training 20 | loss: 'l1' 21 | lossStrategy: 'mean' 22 | lr: ${datamodule.cfg.lr} 23 | 24 | # LR Scheduling 25 | scheduleLR: False 26 | patienceLR: 10 27 | 28 | # Early Stopping 29 | earlyStopping: False 30 | patienceStopping: 50 31 | 32 | ## Evaluation 33 | saveOutputImages: True 34 | evalSeg: True 35 | 36 | ## General postprocessing 37 | pad: ${datamodule.cfg.pad} 38 | erodeBrainmask: True 39 | medianFiltering: True 40 | threshold: auto # 'auto' for autothresholding, any number for manually setting 41 | mode: ${datamodule.cfg.mode} 42 | -------------------------------------------------------------------------------- /configs/model/DDPM_2D_patched.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.DDPM_2D_patched.DDPM_2D 2 | 3 | cfg: 4 | name: DDPM_patched_2D 5 | ## Data 6 | imageDim: ${datamodule.cfg.imageDim} 7 | rescaleFactor: ${datamodule.cfg.rescaleFactor} 8 | interRes: ${datamodule.cfg.interRes} 9 | cropMode: ${datamodule.cfg.cropMode} 10 | spatialDims: 2D 11 | resizedEvaluation: ${datamodule.cfg.resizedEvaluation} 12 | 13 | ## Architecture 14 | unet_dim: 128 15 | dim_mults: [1, 2, 2] 16 | learned_variance: False 17 | learned_sinusoidal_cond: False 18 | 19 | ## Training 20 | 21 | loss: 'l1' 22 | lossStrategy: 'mean' 23 | lr: ${datamodule.cfg.lr} 24 | 25 | 26 | # LR Scheduling 27 | scheduleLR: False 28 | patienceLR: 10 29 | 30 | # Early Stopping 31 | earlyStopping: False 32 | patienceStopping: 50 33 | 34 | ## Evaluation 35 | saveOutputImages: True 36 | evalSeg: True 37 | 38 | ## General postprocessing 39 | pad: ${datamodule.cfg.pad} 40 | erodeBrainmask: True 41 | medianFiltering: True 42 | threshold: auto # 'auto' for autothresholding, any number for manually setting 43 | 44 | mode: ${datamodule.cfg.mode} 45 | 46 | # patching: 47 | scale_patch: 1 48 | patch_size: 16 49 | patch_stride: 16 50 | objective: pred_x0 51 | 52 | -------------------------------------------------------------------------------- /configs/model/Spark_2D.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.Spark_2D.Spark_2D 2 | 3 | cfg: 4 | name: Spark_2D 5 | ## Data 6 | imageDim: ${datamodule.cfg.imageDim} 7 | 8 | rescaleFactor: ${datamodule.cfg.rescaleFactor} 9 | interRes: ${datamodule.cfg.interRes} 10 | cropMode: ${datamodule.cfg.cropMode} 11 | spatialDims: 2D 12 | resizedEvaluation: 
${datamodule.cfg.resizedEvaluation} 13 | 14 | ## Architecture 15 | dropRate: 0.2 16 | unisotropic_sampling: ${datamodule.cfg.unisotropic_sampling} 17 | 18 | mask_ratio: 0.65 19 | uniform: False 20 | pe: False 21 | pix_norm: False 22 | dense_loss: False 23 | loss_l2: True 24 | en_de_norm: 'bn' 25 | en_de_lin: True 26 | sbn: False 27 | pyramid: 4 28 | dp: 0 29 | # decoder 30 | dec_dim: 512 31 | double: True 32 | hea: [0,1] 33 | cmid: 0 34 | 35 | ## Training 36 | lossStrategy: 'mean' 37 | lr: ${datamodule.cfg.lr} 38 | pretrained: False 39 | modelpath: ${datamodule.cfg.modelpath}/mae_pretrain_vit_base.pth 40 | # LR Scheduling 41 | scheduleLR: False 42 | patienceLR: 10 43 | 44 | # Early Stopping 45 | earlyStopping: False 46 | patienceStopping: 50 47 | 48 | ## Evaluation 49 | saveOutputImages: True 50 | evalSeg: True 51 | 52 | ## General postprocessing 53 | pad: ${datamodule.cfg.pad} 54 | erodeBrainmask: True 55 | medianFiltering: True 56 | threshold: auto # 'auto' for autothresholding, any number for manually setting 57 | 58 | 59 | SimCLR_pretraining: False -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | gpus: -1 # Specify GPU by CUDA_VISIBLE_DEVICES=0 4 | min_epochs: 1 5 | max_epochs: 800 6 | log_every_n_steps: 5 7 | precision : 16 8 | num_sanity_val_steps : 0 # This does not work with dp, only with ddp 9 | check_val_every_n_epoch : 1 10 | benchmark: True 11 | overfit_batches: False 12 | 13 | -------------------------------------------------------------------------------- /pc_environment.env: -------------------------------------------------------------------------------- 1 | DATA_DIR= 2 | LOG_DIR= -------------------------------------------------------------------------------- /preprocessing/cut.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import numpy as np 3 | import sys 4 | import argparse 5 | import os 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | 9 | def first_nonzero(arr, axis, invalid_val=0): 10 | mask = arr != 0 11 | return np.where(mask.any(axis=axis), mask.argmax(axis=axis), invalid_val) 12 | 13 | 14 | def last_nonzero(arr, axis, invalid_val=0): 15 | mask = arr != 0 16 | val = arr.shape[axis] - np.flip(mask, axis=axis).argmax(axis=axis) - 1 17 | return np.where(mask.any(axis=axis), val, invalid_val) 18 | 19 | 20 | def arg_parser(): 21 | parser = argparse.ArgumentParser( 22 | description='Cut 3D volumes in NIfTI format to the bounding box of their brain mask') 23 | required = parser.add_argument_group('Required') 24 | required.add_argument('-i', '--img-dir', type=str, required=True, nargs='+', 25 | help='path to directory with images to be processed') 26 | required.add_argument('-m', '--mask-dir', type=str, required=True, 27 | help='mask directory') 28 | required.add_argument('-o', '--output', type=str, required=True, 29 | help='output directory') 30 | required.add_argument('-mode', '--mode', type=str, required=False, default='t1', 31 | help='modality suffix of the image files (e.g. t1, t2)') 32 | return parser 33 | 34 | 35 | def main(args=None): 36 | args = arg_parser().parse_args(args) 37 | try: 38 | Path(args.output).mkdir(parents=True,exist_ok=True) 39 | Path(args.output + '/mask/').mkdir(parents=True,exist_ok=True) 40 | Path(args.output + '/seg/').mkdir(parents=True,exist_ok=True) 41 | Path(args.output + '/' + args.mode).mkdir(parents=True,exist_ok=True) 42 | file_suffix = args.mode
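    # The loop below checks each input directory, then, for every brain mask in
    # mask-dir, computes the tight bounding box of nonzero voxels along all three
    # axes (via first_nonzero/last_nonzero above) and crops the image, the mask,
    # and, if present, the segmentation to that box before saving them.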
for input_dir in args.img_dir: 44 | if not os.path.isdir(input_dir): 45 | raise ValueError('(-i / --img-dir) argument needs to be a directory of NIfTI images.') 46 | 47 | for i, mask_path in tqdm(enumerate(os.listdir(args.mask_dir))): 48 | try: 49 | if mask_path.endswith('_mask.nii.gz') or mask_path.endswith('_mask.nii'): 50 | if not os.path.isfile(args.output + '/' + file_suffix + '/' + mask_path.replace('mask', file_suffix)) or not os.path.isfile(args.output + '/seg/' + mask_path.replace('mask', 'seg')): 51 | # print('Processing file {} of {}'.format(i+1, len(os.listdir(args.mask_dir)))) 52 | max_dims = [0, 0, 0] 53 | # Load mask 54 | mask_file = nib.load(os.path.join(args.mask_dir, mask_path)) 55 | mask = mask_file.get_fdata() 56 | 57 | # Zero axis 58 | zero_min_indices = first_nonzero(mask, 0, 999999).min() 59 | zero_max_indices = last_nonzero(mask, 0).max() 60 | 61 | # First axis 62 | first_min_indices = first_nonzero(mask, 1, 999999).min() 63 | first_max_indices = last_nonzero(mask, 1).max() 64 | 65 | # Second axis 66 | second_min_indices = first_nonzero(mask, 2, 999999).min() 67 | second_max_indices = last_nonzero(mask, 2).max() 68 | 69 | max_dims = np.maximum(max_dims, [zero_max_indices - zero_min_indices, 70 | first_max_indices - first_min_indices, 71 | second_max_indices - second_min_indices]) 72 | print(max_dims) 73 | # for k, input_dir in enumerate(args.img_dir): 74 | 75 | 76 | # Create path if it does not exist yet 77 | # Path(file_suffix + '-cut').mkdir(exist_ok=True) 78 | 79 | # Construct new file name 80 | file_name = mask_path.replace('mask', file_suffix) 81 | out_name = args.output + '/' + file_suffix + '/' + mask_path.replace('mask', file_suffix) # 82 | mask_name = args.output + '/mask/' + mask_path 83 | seg_name = args.output + '/seg/' + mask_path.replace('mask', 'seg') 84 | # Load volume 85 | vol = nib.load(os.path.join(input_dir, file_name)) 86 | vol_np = vol.get_fdata() 87 | 88 | # Slice 89 | new_vol = vol_np[zero_min_indices:zero_max_indices, first_min_indices:first_max_indices, 90 | second_min_indices:second_max_indices] 91 | 92 | # Create new nifti file and save it 93 | new_vol_file = nib.Nifti1Image(new_vol, affine=vol.affine, header=vol.header) 94 | nib.save(new_vol_file,out_name) 95 | 96 | # Path('mask-cut').mkdir(exist_ok=True) 97 | new_mask_vol = mask[zero_min_indices:zero_max_indices, first_min_indices:first_max_indices, 98 | second_min_indices:second_max_indices] 99 | new_mask_file = nib.Nifti1Image(new_mask_vol, affine=mask_file.affine, header=mask_file.header) 100 | 101 | if os.path.isfile((args.mask_dir+mask_path).replace('mask','seg')) and not os.path.isfile(seg_name): 102 | seg_path = os.path.join(args.mask_dir, mask_path).replace('mask','seg') 103 | seg = nib.load(seg_path) 104 | seg_np = seg.get_fdata() 105 | 106 | # Slice 107 | new_seg = seg_np[zero_min_indices:zero_max_indices, first_min_indices:first_max_indices, 108 | second_min_indices:second_max_indices] 109 | # Create new nifti file and save it 110 | new_seg_file = nib.Nifti1Image(new_seg, affine=vol.affine, header=vol.header) 111 | nib.save(new_seg_file,seg_name) 112 | # print(mask_name) 113 | if not os.path.isfile(mask_name): 114 | 115 | nib.save(new_mask_file, mask_name) 116 | except: 117 | print('error') 118 | print('Maximum dimensions are {}'.format(max_dims)) 119 | return 0 120 | except Exception as e: 121 | print(e) 122 | return 1 123 | 124 | 125 | if __name__ == "__main__": 126 | sys.exit(main(sys.argv[1:])) 127 | 
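To make the cropping logic in cut.py concrete, here is a minimal, self-contained sketch (the toy array and printed values are illustrative only; first_nonzero/last_nonzero are copied from above):

import numpy as np

def first_nonzero(arr, axis, invalid_val=0):
    # index of the first nonzero entry along `axis` (invalid_val where all-zero)
    mask = arr != 0
    return np.where(mask.any(axis=axis), mask.argmax(axis=axis), invalid_val)

def last_nonzero(arr, axis, invalid_val=0):
    # index of the last nonzero entry along `axis` (invalid_val where all-zero)
    mask = arr != 0
    val = arr.shape[axis] - np.flip(mask, axis=axis).argmax(axis=axis) - 1
    return np.where(mask.any(axis=axis), val, invalid_val)

# Toy brain "mask": a nonzero block spanning i=1..2, j=2..3, k=1
mask = np.zeros((4, 5, 3))
mask[1:3, 2:4, 1:2] = 1

bounds = [(first_nonzero(mask, ax, 999999).min(), last_nonzero(mask, ax).max())
          for ax in range(3)]
print(bounds)  # [(1, 2), (2, 3), (1, 1)] -- inclusive min/max index per axis

# An inclusive crop keeps the full block; note that cut.py itself slices with the
# raw max index (lo:hi), which drops the last nonzero plane along each axis.
cropped = mask[bounds[0][0]:bounds[0][1] + 1,
               bounds[1][0]:bounds[1][1] + 1,
               bounds[2][0]:bounds[2][1] + 1]
print(cropped.shape)  # (2, 2, 1)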
-------------------------------------------------------------------------------- /preprocessing/extract_masks.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | from pathlib import Path 5 | 6 | 7 | def arg_parser(): 8 | parser = argparse.ArgumentParser( 9 | description='Extract mask files from directory') 10 | required = parser.add_argument_group('Required') 11 | required.add_argument('-i', '--img-dir', type=str, required=True, 12 | help='path to directory with images to be processed') 13 | required.add_argument('-o', '--out-dir', type=str, required=True, 14 | help='output directory for preprocessed files') 15 | return parser 16 | 17 | 18 | def main(args=None): 19 | args = arg_parser().parse_args(args) 20 | try: 21 | if not os.path.isdir(args.img_dir): 22 | raise ValueError('(-i / --img-dir) argument needs to be a directory of NIfTI images.') 23 | Path(args.out_dir).mkdir(exist_ok=True) 24 | for file in os.listdir(args.img_dir): 25 | if file.endswith('_mask.nii.gz') or file.endswith('_mask.nii'): 26 | os.rename(os.path.join(args.img_dir, file), os.path.join(args.out_dir, file)) 27 | return 0 28 | except Exception as e: 29 | print(e) 30 | return 1 31 | 32 | 33 | if __name__ == "__main__": 34 | sys.exit(main(sys.argv[1:])) 35 | 36 | 37 | -------------------------------------------------------------------------------- /preprocessing/get_mask.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import sys 3 | import argparse 4 | import os 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | import ants 8 | def arg_parser(): 9 | parser = argparse.ArgumentParser( 10 | description='create binary mask for 3D volumes in NIfTI format') 11 | parser.add_argument('-i', '--img-dir', type=str, required=True, 12 | help='path to directory with images to be processed') 13 | parser.add_argument('-o', '--out-dir', type=str, required=False, default='tmp', 14 | help='output directory for preprocessed files') 15 | parser.add_argument('-mod', '--modality', type=str, required=True, 16 | help='modality suffix of the input images (e.g. t1, t2)') 17 | 18 | return parser 19 | 20 | 21 | def main(args=None): 22 | args = arg_parser().parse_args(args) 23 | try: 24 | if not os.path.isdir(args.img_dir): 25 | raise ValueError('(-i / --img-dir) argument needs to be a directory of NIfTI images.') 26 | Path(args.out_dir).mkdir(parents=True, exist_ok=True) 27 | for i, file in tqdm(enumerate(os.listdir(args.img_dir))): 28 | # print('Processing file {} of {}'.format(i + 1, len(os.listdir(args.img_dir)))) 29 | if file.endswith('.nii.gz') or file.endswith('.nii'): 30 | vol_file = ants.image_read(os.path.join(args.img_dir, file)) 31 | mask = ants.get_mask(vol_file) 32 | ants.image_write(mask, os.path.join(args.out_dir,file.replace(args.modality,'mask'))) 33 | 34 | if args.out_dir == 'tmp': 35 | os.remove(os.path.join(args.img_dir, file)) 36 | os.rename(os.path.join(args.out_dir, file), os.path.join(args.img_dir, file)) 37 | # Delete tmp directory 38 | if args.out_dir == 'tmp': 39 | os.rmdir(args.out_dir) 40 | return 0 41 | except Exception as e: 42 | print(e) 43 | return 1 44 | 45 | 46 | if __name__ == "__main__": 47 | sys.exit(main(sys.argv[1:])) 48 | -------------------------------------------------------------------------------- /preprocessing/n4filter.py: -------------------------------------------------------------------------------- 1 | import ants 2 | import sys 3 | import argparse 4
| import os 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | 8 | def arg_parser(): 9 | parser = argparse.ArgumentParser( 10 | description='Apply N4 bias field correction to 3D volumes in NIfTI format') 11 | parser.add_argument('-i', '--img-dir', type=str, required=True, 12 | help='path to directory with images to be processed') 13 | parser.add_argument('-o', '--out-dir', type=str, required=False, default='tmp', 14 | help='output directory for preprocessed files') 15 | parser.add_argument('-m', '--mask-dir', type=str, required=False, default=None, 16 | help='mask directory for preprocessed files') 17 | return parser 18 | 19 | 20 | def main(args=None): 21 | args = arg_parser().parse_args(args) 22 | # try: 23 | if not os.path.isdir(args.img_dir): 24 | raise ValueError('(-i / --img-dir) argument needs to be a directory of NIfTI images.') 25 | # Create output dir if it does not exist yet 26 | Path(args.out_dir).mkdir(parents=True,exist_ok=True) 27 | # Define n4 filter options 28 | n4_opts = {'iters': [200, 200, 200, 200], 'tol': 0.0005} 29 | # Iterate through input directory 30 | 31 | for i, file in enumerate(tqdm(os.listdir(args.img_dir))): 32 | if not os.path.isfile(os.path.join(args.out_dir, file)): 33 | # print('Processing file {} of {}'.format(i + 1, len(os.listdir(args.img_dir)))) 34 | if file.endswith('.nii.gz') or file.endswith('.nii'): 35 | # Define file path 36 | file_path = os.path.join(args.img_dir, file) 37 | # Read img and mask 38 | img = ants.image_read(file_path) 39 | if args.mask_dir is not None: 40 | try: 41 | mask = ants.image_read(os.path.join(args.mask_dir, file.replace('_t1ce', '_mask').replace('_t2', '_mask').replace('_t1', '_mask').replace('_flair', '_mask').replace('_FLAIR', '_mask').replace('_dwi', '_mask'))) 42 | # Smooth mask 43 | smoothed_mask = ants.smooth_image(mask, 1) 44 | except Exception: 45 | smoothed_mask = None 46 | print('no mask') 47 | # Perform bias field correction 48 | img = ants.n4_bias_field_correction(img, convergence=n4_opts, weight_mask=smoothed_mask) 49 | else: 50 | img = ants.n4_bias_field_correction(img, convergence=n4_opts) 51 | # Write output img 52 | 53 | ants.image_write(img, os.path.join(args.out_dir,file)) 54 | if args.out_dir == 'tmp': 55 | os.remove(os.path.join(args.img_dir, file)) 56 | os.rename(os.path.join(args.out_dir, file), os.path.join(args.img_dir, file)) 57 | # Delete tmp directory 58 | if args.out_dir == 'tmp': 59 | os.rmdir(args.out_dir) 60 | return 0 61 | # except Exception as e: 62 | # print(e) 63 | # return 1 64 | 65 | 66 | if __name__ == "__main__": 67 | sys.exit(main(sys.argv[1:])) 68 | -------------------------------------------------------------------------------- /preprocessing/prepare_Brats21.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # cli arguments: 3 | # 1. path to input data directory 4 | # 2. path to output data directory 5 | INPUT_DIR=$1 6 | DATA_DIR=$2 7 | 8 | # require both arguments and ensure that the input dir is not a relative path 9 | if [ -z "$INPUT_DIR" ] || [ -z "$DATA_DIR" ] 10 | then 11 | echo "Usage: ./prepare_Brats21.sh <input_dir> <data_dir>" 12 | exit 1 13 | fi 14 | 15 | if [ "$INPUT_DIR" == "." ] || [ "$INPUT_DIR" == ".."
] 16 | then 17 | echo "Please use absolute paths for input_dir" 18 | exit 1 19 | fi 20 | # For BRATS, we already have resampled, skull-stripped data 21 | 22 | mkdir -p $DATA_DIR/v2skullstripped/Brats21/ 23 | mkdir -p $DATA_DIR/v2skullstripped/Brats21/mask 24 | 25 | cp -r $INPUT_DIR/t2 $INPUT_DIR/seg $DATA_DIR/v2skullstripped/Brats21/ 26 | 27 | echo "extract masks" 28 | python get_mask.py -i $DATA_DIR/v2skullstripped/Brats21/t2 -o $DATA_DIR/v2skullstripped/Brats21/t2 -mod t2 29 | python extract_masks.py -i $DATA_DIR/v2skullstripped/Brats21/t2 -o $DATA_DIR/v2skullstripped/Brats21/mask 30 | python replace.py -i $DATA_DIR/v2skullstripped/Brats21/mask -s " _t2" "" 31 | 32 | echo "Register t2" 33 | python registration.py -i $DATA_DIR/v2skullstripped/Brats21/t2 -o $DATA_DIR/v3registered_non_iso/Brats21/t2 --modality=_t2 -trans Affine -templ sri_atlas/templates/T1_brain.nii 34 | 35 | 36 | echo "Cut to brain" 37 | python cut.py -i $DATA_DIR/v3registered_non_iso/Brats21/t2 -m $DATA_DIR/v3registered_non_iso/Brats21/mask/ -o $DATA_DIR/v3registered_non_iso_cut/Brats21/ -mode t2 38 | 39 | echo "Bias Field Correction" 40 | python n4filter.py -i $DATA_DIR/v3registered_non_iso_cut/Brats21/t2 -o $DATA_DIR/v4correctedN4_non_iso_cut/Brats21/t2 -m $DATA_DIR/v3registered_non_iso_cut/Brats21/mask 41 | mkdir $DATA_DIR/v4correctedN4_non_iso_cut/Brats21/mask 42 | cp $DATA_DIR/v3registered_non_iso_cut/Brats21/mask/* $DATA_DIR/v4correctedN4_non_iso_cut/Brats21/mask 43 | mkdir $DATA_DIR/v4correctedN4_non_iso_cut/Brats21/seg 44 | cp $DATA_DIR/v3registered_non_iso_cut/Brats21/seg/* $DATA_DIR/v4correctedN4_non_iso_cut/Brats21/seg 45 | echo "Done" 46 | 47 | # now, you should copy the files in the output directory to the data directory of the project 48 | -------------------------------------------------------------------------------- /preprocessing/prepare_IXI.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # cli arguments: 3 | # 1. path to input data directory 4 | # 2. path to output data directory 5 | INPUT_DIR=$1 6 | DATA_DIR=$2 7 | 8 | # require both arguments and ensure that the input dir is not a relative path 9 | if [ -z "$INPUT_DIR" ] || [ -z "$DATA_DIR" ] 10 | then 11 | echo "Usage: ./prepare_IXI.sh <input_dir> <data_dir>" 12 | exit 1 13 | fi 14 | 15 | if [ "$INPUT_DIR" == "." ] || [ "$INPUT_DIR" == ".."
] 16 | then 17 | echo "Please use absolute paths for input_dir" 18 | exit 1 19 | fi 20 | 21 | echo "Resample" 22 | mkdir -p $DATA_DIR/v1resampled/IXI/t2 23 | python resample.py -i $INPUT_DIR/t2 -o $DATA_DIR/v1resampled/IXI/t2 -r 1.0 1.0 1.0 24 | # rename files for standard naming 25 | for file in $DATA_DIR/v1resampled/IXI/t2/* 26 | do 27 | mv "$file" "${file%-T2.nii.gz}_t2.nii.gz" 28 | done 29 | 30 | echo "Generate masks" 31 | CUDA_VISIBLE_DEVICES=0 hd-bet -i $DATA_DIR/v1resampled/IXI/t2 -o $DATA_DIR/v2skullstripped/IXI/t2 32 | python extract_masks.py -i $DATA_DIR/v2skullstripped/IXI/t2 -o $DATA_DIR/v2skullstripped/IXI/mask 33 | python replace.py -i $DATA_DIR/v2skullstripped/IXI/mask -s " _t2" "" 34 | 35 | echo "Register t2" 36 | python registration.py -i $DATA_DIR/v2skullstripped/IXI/t2 -o $DATA_DIR/v3registered_non_iso/IXI/t2 --modality=_t2 -trans Affine -templ sri_atlas/templates/T1_brain.nii 37 | 38 | echo "Cut to brain" 39 | python cut.py -i $DATA_DIR/v3registered_non_iso/IXI/t2 -m $DATA_DIR/v3registered_non_iso/IXI/mask/ -o $DATA_DIR/v3registered_non_iso_cut/IXI/ -mode t2 40 | 41 | echo "Bias Field Correction" 42 | python n4filter.py -i $DATA_DIR/v3registered_non_iso_cut/IXI/t2 -o $DATA_DIR/v4correctedN4_non_iso_cut/IXI/t2 -m $DATA_DIR/v3registered_non_iso_cut/IXI/mask 43 | 44 | mkdir $DATA_DIR/v4correctedN4_non_iso_cut/IXI/mask 45 | cp $DATA_DIR/v3registered_non_iso_cut/IXI/mask/* $DATA_DIR/v4correctedN4_non_iso_cut/IXI/mask 46 | echo "Done" 47 | 48 | # now, you should copy the files in the output directory to the data directory of the project 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /preprocessing/prepare_MSLUB.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # cli arguments: 3 | # 1. path to input data directory 4 | # 2. path to output data directory 5 | INPUT_DIR=$1 6 | DATA_DIR=$2 7 | 8 | # require both arguments and ensure that the input dir is not a relative path 9 | if [ -z "$INPUT_DIR" ] || [ -z "$DATA_DIR" ] 10 | then 11 | echo "Usage: ./prepare_MSLUB.sh <input_dir> <data_dir>" 12 | exit 1 13 | fi 14 | 15 | if [ "$INPUT_DIR" == "." ] || [ "$INPUT_DIR" == ".."
] 16 | then 17 | echo "Please use absolute paths for input_dir" 18 | exit 1 19 | fi 20 | 21 | echo "Resample" 22 | mkdir -p $DATA_DIR/v1resampled/MSLUB/t2 23 | python resample.py -i $INPUT_DIR/t2 -o $DATA_DIR/v1resampled/MSLUB/t2 -r 1.0 1.0 1.0 24 | ## rename files for standard naming 25 | for file in $DATA_DIR/v1resampled/MSLUB/t2/* 26 | do 27 | mv "$file" "${file%_T2W.nii.gz}_t2.nii.gz" 28 | done 29 | 30 | echo "Generate masks" 31 | # mkdir -p $DATA_DIR/v2skullstripped/MSLUB/t2 32 | CUDA_VISIBLE_DEVICES=0 hd-bet -i $DATA_DIR/v1resampled/MSLUB/t2 -o $DATA_DIR/v2skullstripped/MSLUB/t2 # --overwrite_existing=0 33 | python extract_masks.py -i $DATA_DIR/v2skullstripped/MSLUB/t2 -o $DATA_DIR/v2skullstripped/MSLUB/mask 34 | python replace.py -i $DATA_DIR/v2skullstripped/MSLUB/mask -s " _t2" "" 35 | 36 | # copy segmentation masks to the data directory 37 | mkdir -p $DATA_DIR/v2skullstripped/MSLUB/seg 38 | cp -r $INPUT_DIR/seg/* $DATA_DIR/v2skullstripped/MSLUB/seg/ 39 | 40 | for file in $DATA_DIR/v2skullstripped/MSLUB/seg/* 41 | do 42 | mv "$file" "${file%consensus_gt.nii.gz}seg.nii.gz" 43 | done 44 | 45 | 46 | echo "Register t2" 47 | python registration.py -i $DATA_DIR/v2skullstripped/MSLUB/t2 -o $DATA_DIR/v3registered_non_iso/MSLUB/t2 --modality=_t2 -trans Affine -templ sri_atlas/templates/T1_brain.nii 48 | 49 | echo "Cut to brain" 50 | python cut.py -i $DATA_DIR/v3registered_non_iso/MSLUB/t2 -m $DATA_DIR/v3registered_non_iso/MSLUB/mask/ -o $DATA_DIR/v3registered_non_iso_cut/MSLUB/ -mode t2 51 | 52 | echo "Bias Field Correction" 53 | python n4filter.py -i $DATA_DIR/v3registered_non_iso_cut/MSLUB/t2 -o $DATA_DIR/v4correctedN4_non_iso_cut/MSLUB/t2 -m $DATA_DIR/v3registered_non_iso_cut/MSLUB/mask 54 | mkdir $DATA_DIR/v4correctedN4_non_iso_cut/MSLUB/mask 55 | cp $DATA_DIR/v3registered_non_iso_cut/MSLUB/mask/* $DATA_DIR/v4correctedN4_non_iso_cut/MSLUB/mask 56 | mkdir $DATA_DIR/v4correctedN4_non_iso_cut/MSLUB/seg 57 | cp $DATA_DIR/v3registered_non_iso_cut/MSLUB/seg/* $DATA_DIR/v4correctedN4_non_iso_cut/MSLUB/seg 58 | echo "Done" 59 | 60 | 61 | # now, you should copy the files in the output directory to the data directory of the project 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /preprocessing/registration.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from tqdm import tqdm 6 | import argparse 7 | import os 8 | import sys 9 | from pathlib import Path 10 | import ants 11 | 12 | def arg_parser(): 13 | parser = argparse.ArgumentParser( 14 | description='Register 3D volumes in NIfTI format to a template') 15 | parser.add_argument('-i', '--img-dir', type=str, required=True, 16 | help='path to directory with images to be processed') 17 | parser.add_argument('-modal', '--modality', type=str, required=False, default='_t1', 18 | help='t1, t2, FLAIR, ...') 19 | parser.add_argument('-o', '--out-dir', type=str, required=False, default='tmp', 20 | help='output directory for preprocessed files') 21 | parser.add_argument('-r', '--resolution', type=float, required=False, nargs=3, default=[1.0, 1.0, 1.0], 22 | help='target resolution') 23 | parser.add_argument('-or', '--orientation', type=str, required=False, default='RAS', 24 | help='target orientation') 25 | parser.add_argument('-inter', '--interpolation', type=int, required=False, default=4, 26 | help='interpolation order') 27 |
parser.add_argument('-nomask', '--nomaskandseg', type=int, required=False, default=0, 28 | help='set to one if you can reuse the masks and segmentations of other modalities') 29 | parser.add_argument('-trans', '--transform', type=str, required=False, default='Rigid', 30 | help='specify the transformation') 31 | parser.add_argument('-templ', '--template', type=str, required=True, 32 | help='path to template') 33 | return parser 34 | 35 | def main(args=None): 36 | args = arg_parser().parse_args(args) 37 | src_basepath = args.img_dir # 38 | dest_basepath = args.out_dir # 39 | 40 | 41 | if not os.path.isdir(args.img_dir): 42 | raise ValueError('(-i / --img-dir) argument needs to be a directory of NIfTI images.') 43 | 44 | Path(args.out_dir).mkdir(parents=True,exist_ok=True) 45 | 46 | fixed_im = ants.image_read(args.template) 47 | fixed_im = fixed_im.reorient_image2('RAI') 48 | 49 | if os.path.isdir(os.path.join(Path(args.img_dir).parents[0],'seg')) and args.nomaskandseg !=1: 50 | seg_path = os.path.join(Path(args.img_dir).parents[0],'seg') 51 | seg_out = os.path.join(Path(args.out_dir).parents[0],'seg') 52 | Path(seg_out).mkdir(parents=True,exist_ok=True) 53 | else: 54 | seg_path = None 55 | 56 | if os.path.isdir(os.path.join(Path(args.img_dir).parents[0],'mask')) and args.nomaskandseg !=1: 57 | mask_path = os.path.join(Path(args.img_dir).parents[0],'mask') 58 | mask_out = os.path.join(Path(args.out_dir).parents[0],'mask') 59 | Path(mask_out).mkdir(parents=True,exist_ok=True) 60 | else: 61 | mask_path = None 62 | 63 | for i, file in tqdm(enumerate(os.listdir(args.img_dir))): 64 | if not os.path.isfile(os.path.join(dest_basepath, file)) or not os.path.isfile(os.path.join(mask_out, file.replace(args.modality,'_mask'))) or not os.path.isfile(os.path.join(seg_out, file.replace(args.modality,'_seg'))): 65 | path_img = os.path.join(args.img_dir, file) 66 | moving_im = ants.image_read(path_img) 67 | # register to template 68 | im_tx = ants.registration(fixed=fixed_im, moving=moving_im, type_of_transform = args.transform ) 69 | moved_im = im_tx['warpedmovout'] 70 | ants.image_write(moved_im, os.path.join(dest_basepath, file)) 71 | 72 | if mask_path is not None: 73 | path_mask = os.path.join(mask_path, file.replace(args.modality,'_mask')) 74 | moving_mask = ants.image_read(path_mask) 75 | # transform the mask in the same way as the volume 76 | moved_mask = ants.apply_transforms( fixed=fixed_im, moving=moving_mask, 77 | transformlist=im_tx['fwdtransforms'] ) 78 | ants.image_write(moved_mask, os.path.join(mask_out, file.replace(args.modality,'_mask'))) 79 | 80 | if seg_path is not None: 81 | path_seg = os.path.join(seg_path, file.replace(args.modality,'_seg')) 82 | moving_seg = ants.image_read(path_seg) 83 | moving_seg = moving_seg.reorient_image2('RAI') # reorient to standard orientation 84 | # transform the segmentation in the same way as the volume 85 | moved_seg = ants.apply_transforms( fixed=fixed_im, moving=moving_seg, 86 | transformlist=im_tx['fwdtransforms'] ) 87 | ants.image_write(moved_seg, os.path.join(seg_out, file.replace(args.modality,'_seg'))) 88 | 89 | 90 | 91 | 92 | 93 | 94 | if __name__ == "__main__": 95 | sys.exit(main(sys.argv[1:])) 96 | 97 | 98 | -------------------------------------------------------------------------------- /preprocessing/replace.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import sys 3 | import argparse 4 | import os 5 | 6 | 7 | def arg_parser(): 8 | parser = argparse.ArgumentParser( 9 | 
description='Replace string from 3D volumes file names in NIfTI format') 10 | required = parser.add_argument_group('Required') 11 | required.add_argument('-i', '--img-dir', type=str, required=True, 12 | help='path to directory with images to be processed') 13 | required.add_argument('-s', '--string', type=str, required=True, nargs=2, 14 | help='string to replaced') 15 | return parser 16 | 17 | 18 | def main(args=None): 19 | args = arg_parser().parse_args(args) 20 | try: 21 | if not os.path.isdir(args.img_dir): 22 | raise ValueError('(-i / --img-dir) argument needs to be a directory of NIfTI images.') 23 | for file in os.listdir(args.img_dir): 24 | if file.endswith('.nii.gz') or file.endswith('.nii'): 25 | os.rename(os.path.join(args.img_dir, file), os.path.join(args.img_dir, file).replace(args.string[0].lstrip(), args.string[1].lstrip())) 26 | return 0 27 | except Exception as e: 28 | print(e) 29 | return 1 30 | 31 | 32 | if __name__ == "__main__": 33 | sys.exit(main(sys.argv[1:])) 34 | 35 | 36 | -------------------------------------------------------------------------------- /preprocessing/resample.py: -------------------------------------------------------------------------------- 1 | import ants 2 | import nibabel as nib 3 | 4 | import argparse 5 | import os 6 | import sys 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | 10 | def arg_parser(): 11 | parser = argparse.ArgumentParser( 12 | description='Resample 3D volumes in NIfTI format') 13 | parser.add_argument('-i', '--img-dir', type=str, required=True, 14 | help='path to directory with images to be processed') 15 | parser.add_argument('-o', '--out-dir', type=str, required=False, default='tmp', 16 | help='output directory for preprocessed files') 17 | parser.add_argument('-r', '--resolution', type=float, required=False, nargs=3, default=[1.0, 1.0, 1.0], 18 | help='target resolution') 19 | parser.add_argument('-or', '--orientation', type=str, required=False, default='RAI', 20 | help='target orientation') 21 | parser.add_argument('-inter', '--interpolation', type=int, required=False, default=4, 22 | help='target orientation') 23 | return parser 24 | 25 | 26 | def main(args=None): 27 | args = arg_parser().parse_args(args) 28 | try: 29 | if not os.path.isdir(args.img_dir): 30 | raise ValueError('(-i / --img-dir) argument needs to be a directory of NIfTI images.') 31 | Path(args.out_dir).mkdir(parents=True ,exist_ok=True) 32 | for i, file in tqdm(enumerate(os.listdir(args.img_dir))): 33 | # print('Processing file {} of {}'.format(i + 1, len(os.listdir(args.img_dir)))) 34 | if file.endswith('.nii.gz') or file.endswith('.nii'): 35 | if not os.path.isfile(os.path.join(args.out_dir, file)): 36 | try: 37 | vol = ants.image_read(os.path.join(args.img_dir, file)) 38 | if args.resolution != vol.spacing: 39 | vol = ants.resample_image(vol, args.resolution, False, args.interpolation) 40 | vol = vol.reorient_image2(args.orientation) 41 | ants.image_write(vol, os.path.join(args.out_dir, file)) 42 | if args.out_dir == 'tmp': 43 | os.remove(os.path.join(args.img_dir, file)) 44 | os.rename(os.path.join(args.out_dir, file), os.path.join(args.img_dir, file)) 45 | except Exception as e: 46 | print('An error occurred: {}'.format(str(e))) 47 | print('Could not process file {}. 
Moving on...'.format(file)) 48 | # Delete tmp directory 49 | if args.out_dir == 'tmp': 50 | os.rmdir(args.out_dir) 51 | return 0 52 | except Exception as e: 53 | print(e) 54 | return 1 55 | 56 | 57 | if __name__ == "__main__": 58 | sys.exit(main(sys.argv[1:])) 59 | -------------------------------------------------------------------------------- /preprocessing/sri_atlas/templates/EPI.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/preprocessing/sri_atlas/templates/EPI.nii -------------------------------------------------------------------------------- /preprocessing/sri_atlas/templates/EPI_brain.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/preprocessing/sri_atlas/templates/EPI_brain.nii -------------------------------------------------------------------------------- /preprocessing/sri_atlas/templates/PD.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/preprocessing/sri_atlas/templates/PD.nii -------------------------------------------------------------------------------- /preprocessing/sri_atlas/templates/PD_brain.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/preprocessing/sri_atlas/templates/PD_brain.nii -------------------------------------------------------------------------------- /preprocessing/sri_atlas/templates/T1.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/preprocessing/sri_atlas/templates/T1.nii -------------------------------------------------------------------------------- /preprocessing/sri_atlas/templates/T1_brain.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/preprocessing/sri_atlas/templates/T1_brain.nii -------------------------------------------------------------------------------- /preprocessing/sri_atlas/templates/T2.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/preprocessing/sri_atlas/templates/T2.nii -------------------------------------------------------------------------------- /preprocessing/sri_atlas/templates/T2_brain.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/preprocessing/sri_atlas/templates/T2_brain.nii -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.7.1 2 | albumentations==1.1.0 3 | ants==0.0.7 4 | antspyx==0.3.5 5 | clip==0.2.0 6 | einops==0.4.1 7 
| einops_exts==0.0.3 8 | ema_pytorch==0.0.8 9 | grad_cam==1.3.7 10 | h5py==3.6.0 11 | hydra-core==1.1.0 12 | imageio==2.10.3 13 | kornia==0.6.7 14 | monai==1.0.1 15 | natsort==8.2.0 16 | nibabel==3.2.1 17 | numba==0.54.0 18 | numpy==1.20.3 19 | omegaconf==2.1.2 20 | opencv_python==4.5.4.60 21 | pandas==1.3.4 22 | Pillow==9.4.0 23 | python-dotenv==0.21.1 24 | pytorch_lightning==1.5.1 25 | PyYAML==6.0 26 | requests==2.26.0 27 | rich==13.3.1 28 | rotary_embedding_torch==0.1.5 29 | scikit_image==0.18.3 30 | scikit_learn==1.0.1 31 | scipy==1.7.2 32 | seaborn==0.11.2 33 | SimpleITK==2.2.1 34 | timm==0.6.7 35 | torch==1.10.0 36 | torchio==0.18.73 37 | torchmetrics==0.8.1 38 | torchvision==0.11.1 39 | tqdm==4.62.3 40 | transformers==4.26.1 41 | ttach==0.0.3 42 | video_diffusion_pytorch==0.5.3 43 | wandb==0.12.19 44 | matplotlib==3.5.2 45 | hydra-colorlog 46 | 47 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # template: https://github.com/ashleve/lightning-hydra-template/blob/main/run.py 2 | import dotenv 3 | import hydra 4 | from omegaconf import DictConfig 5 | import os 6 | import sys 7 | import socket 8 | #import multiprocessing as mp 9 | sys.setrecursionlimit(2000) 10 | # load environment variables from `.env` file if it exists 11 | # recursively searches for `.env` in all folders starting from work dir 12 | dir_path = os.path.dirname(os.path.realpath(__file__)) 13 | dotenv.load_dotenv(dir_path+'/pc_environment.env',override=True) 14 | # dotenv.dotenv_values("pc_environment.env") 15 | 16 | 17 | @hydra.main(config_path="configs/", config_name="config.yaml") 18 | def main(config: DictConfig): 19 | # import torch 20 | # torch.multiprocessing.set_sharing_strategy('file_system') 21 | # Imports should be nested inside @hydra.main to optimize tab completion 22 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 23 | from src.train import train 24 | from src.utils import utils 25 | 26 | # A couple of optional utilities: 27 | # - disabling python warnings 28 | # - easier access to debug mode 29 | # - forcing debug friendly configuration 30 | # You can safely get rid of this line if you don't want those 31 | utils.extras(config) 32 | 33 | # Pretty print config using Rich library 34 | if config.get("print_config"): 35 | utils.print_config(config, resolve=True) 36 | 37 | # Train model 38 | return train(config) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/__init__.py -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/train.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/__pycache__/train.cpython-39.pyc -------------------------------------------------------------------------------- /src/datamodules/Datamodules_eval.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, random_split 2 | from pytorch_lightning import LightningDataModule 3 | from typing import Optional 4 | import pandas as pd 5 | import src.datamodules.create_dataset as create_dataset 6 | 7 | 8 | class Brats21(LightningDataModule): 9 | 10 | def __init__(self, cfg, fold= None): 11 | super(Brats21, self).__init__() 12 | self.cfg = cfg 13 | self.preload = cfg.get('preload',True) 14 | # load data paths and indices 15 | self.imgpath = {} 16 | self.csvpath_val = cfg.path.Brats21.IDs.val 17 | self.csvpath_test = cfg.path.Brats21.IDs.test 18 | self.csv = {} 19 | states = ['val','test'] 20 | 21 | self.csv['val'] = pd.read_csv(self.csvpath_val) 22 | self.csv['test'] = pd.read_csv(self.csvpath_test) 23 | for state in states: 24 | self.csv[state]['settype'] = state 25 | self.csv[state]['setname'] = 'Brats21' 26 | 27 | self.csv[state]['img_path'] = cfg.path.pathBase + '/Data/' + self.csv[state]['img_path'] 28 | self.csv[state]['mask_path'] = cfg.path.pathBase + '/Data/' + self.csv[state]['mask_path'] 29 | self.csv[state]['seg_path'] = cfg.path.pathBase + '/Data/' + self.csv[state]['seg_path'] 30 | 31 | if cfg.mode != 't1': 32 | self.csv[state]['img_path'] = self.csv[state]['img_path'].str.replace('t1',cfg.mode).str.replace('FLAIR.nii.gz',f'{cfg.mode.lower()}.nii.gz') 33 | 34 | def setup(self, stage: Optional[str] = None): 35 | # called on every GPU 36 | if not hasattr(self,'val_eval'): 37 | if self.cfg.sample_set: # for debugging 38 | self.val_eval = create_dataset.Eval(self.csv['val'][0:8], self.cfg) 39 | self.test_eval = create_dataset.Eval(self.csv['test'][0:8], self.cfg) 40 | else : 41 | self.val_eval = create_dataset.Eval(self.csv['val'], self.cfg) 42 | self.test_eval = create_dataset.Eval(self.csv['test'], self.cfg) 43 | 44 | def val_dataloader(self): 45 | return DataLoader(self.val_eval, batch_size=1, num_workers=self.cfg.num_workers, pin_memory=True, shuffle=False) 46 | 47 | def test_dataloader(self): 48 | return DataLoader(self.test_eval, batch_size=1, num_workers=self.cfg.num_workers, pin_memory=True, shuffle=False) 49 | 50 | 51 | 52 | class MSLUB(LightningDataModule): 53 | 54 | def __init__(self, cfg, fold= None): 55 | super(MSLUB, self).__init__() 56 | self.cfg = cfg 57 | self.preload = cfg.get('preload',True) 58 | # load data paths and indices 59 | self.imgpath = {} 60 | self.csvpath_val = cfg.path.MSLUB.IDs.val 61 | self.csvpath_test = cfg.path.MSLUB.IDs.test 62 | self.csv = {} 63 | states = ['val','test'] 64 | 65 | self.csv['val'] = pd.read_csv(self.csvpath_val) 66 | self.csv['test'] = pd.read_csv(self.csvpath_test) 67 | for state in states: 68 | self.csv[state]['settype'] = state 69 | self.csv[state]['setname'] = 'MSLUB' 70 | 71 | self.csv[state]['img_path'] = cfg.path.pathBase + '/Data/' + self.csv[state]['img_path'] 72 | self.csv[state]['mask_path'] = cfg.path.pathBase + '/Data/' + self.csv[state]['mask_path'] 73 | self.csv[state]['seg_path'] = cfg.path.pathBase + '/Data/' + self.csv[state]['seg_path'] 74 | 75 | if cfg.mode != 't1': 76 | self.csv[state]['img_path'] = self.csv[state]['img_path'].str.replace('uniso/t1',f'uniso/{cfg.mode}').str.replace('t1.nii.gz',f'{cfg.mode}.nii.gz') 77 | 
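    # Note: setup() below is called once per process (e.g. per GPU); the hasattr
    # guard keeps the evaluation datasets from being rebuilt on repeated calls,
    # and the loaders that follow yield one subject volume per step (batch_size=1).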
def setup(self, stage: Optional[str] = None): 78 | # called on every GPU 79 | if not hasattr(self,'val_eval'): 80 | if self.cfg.sample_set: # for debugging 81 | self.val_eval = create_dataset.Eval(self.csv['val'][0:4], self.cfg) 82 | self.test_eval = create_dataset.Eval(self.csv['test'][0:4], self.cfg) 83 | else : 84 | self.val_eval = create_dataset.Eval(self.csv['val'], self.cfg) 85 | self.test_eval = create_dataset.Eval(self.csv['test'], self.cfg) 86 | 87 | def val_dataloader(self): 88 | return DataLoader(self.val_eval, batch_size=1, num_workers=self.cfg.num_workers, pin_memory=True, shuffle=False) 89 | 90 | def test_dataloader(self): 91 | return DataLoader(self.test_eval, batch_size=1, num_workers=self.cfg.num_workers, pin_memory=True, shuffle=False) 92 | -------------------------------------------------------------------------------- /src/datamodules/Datamodules_train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, random_split 2 | from pytorch_lightning import LightningDataModule 3 | import src.datamodules.create_dataset as create_dataset 4 | from typing import Optional 5 | import pandas as pd 6 | 7 | 8 | class IXI(LightningDataModule): 9 | 10 | def __init__(self, cfg, fold = None): 11 | super(IXI, self).__init__() 12 | self.cfg = cfg 13 | self.preload = cfg.get('preload',True) 14 | # load data paths and indices 15 | # IXI 16 | 17 | self.cfg.permute = False # no permutation for IXI 18 | 19 | 20 | self.imgpath = {} 21 | self.csvpath_train = cfg.path.IXI.IDs.train[fold] 22 | self.csvpath_val = cfg.path.IXI.IDs.val[fold] 23 | self.csvpath_test = cfg.path.IXI.IDs.test 24 | self.csv = {} 25 | states = ['train','val','test'] 26 | 27 | self.csv['train'] = pd.read_csv(self.csvpath_train) 28 | self.csv['val'] = pd.read_csv(self.csvpath_val) 29 | self.csv['test'] = pd.read_csv(self.csvpath_test) 30 | if cfg.mode == 't2': 31 | keep_t2 = pd.read_csv(cfg.path.IXI.keep_t2) # only keep t2 images that have a t1 counterpart 32 | 33 | for state in states: 34 | self.csv[state]['settype'] = state 35 | self.csv[state]['setname'] = 'IXI' 36 | 37 | 38 | self.csv[state]['img_path'] = cfg.path.pathBase + '/Data/' + self.csv[state]['img_path'] 39 | self.csv[state]['mask_path'] = cfg.path.pathBase + '/Data/' + self.csv[state]['mask_path'] 40 | self.csv[state]['seg_path'] = None 41 | 42 | if cfg.mode == 't2': 43 | self.csv[state] = self.csv[state][self.csv[state].img_name.isin(keep_t2['0'].str.replace('t2','t1'))] 44 | self.csv[state]['img_path'] = self.csv[state]['img_path'].str.replace('t1','t2') 45 | 46 | def setup(self, stage: Optional[str] = None): 47 | # called on every GPU 48 | if not hasattr(self,'train'): 49 | if self.cfg.sample_set: # for debugging 50 | self.train = create_dataset.Train(self.csv['train'][0:50],self.cfg) 51 | self.val = create_dataset.Train(self.csv['val'][0:50],self.cfg) 52 | self.val_eval = create_dataset.Eval(self.csv['val'][0:8],self.cfg) 53 | self.test_eval = create_dataset.Eval(self.csv['test'][0:8],self.cfg) 54 | else: 55 | self.train = create_dataset.Train(self.csv['train'],self.cfg) 56 | self.val = create_dataset.Train(self.csv['val'],self.cfg) 57 | self.val_eval = create_dataset.Eval(self.csv['val'],self.cfg) 58 | self.test_eval = create_dataset.Eval(self.csv['test'],self.cfg) 59 | 60 | def train_dataloader(self): 61 | return DataLoader(self.train, batch_size=self.cfg.batch_size, num_workers=self.cfg.num_workers, pin_memory=True, shuffle=True, drop_last=self.cfg.get('droplast',False)) 62 | 63 | 
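    # The training loader above shuffles and can drop the last incomplete batch
    # (cfg.droplast); the loaders below keep a fixed order for validation, with
    # the *_eval loaders yielding one subject volume per step (batch_size=1).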
def val_dataloader(self): 64 | return DataLoader(self.val, batch_size=self.cfg.batch_size, num_workers=self.cfg.num_workers, pin_memory=True, shuffle=False) 65 | 66 | def val_eval_dataloader(self): 67 | return DataLoader(self.val_eval, batch_size=1, num_workers=self.cfg.num_workers, pin_memory=True, shuffle=False) 68 | 69 | def test_eval_dataloader(self): 70 | return DataLoader(self.test_eval, batch_size=1, num_workers=self.cfg.num_workers, pin_memory=True, shuffle=False) 71 | 72 | -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/datamodules/__init__.py -------------------------------------------------------------------------------- /src/datamodules/__pycache__/Datamodules_eval.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/datamodules/__pycache__/Datamodules_eval.cpython-39.pyc -------------------------------------------------------------------------------- /src/datamodules/__pycache__/Datamodules_train.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/datamodules/__pycache__/Datamodules_train.cpython-39.pyc -------------------------------------------------------------------------------- /src/datamodules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/datamodules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/datamodules/__pycache__/create_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/datamodules/__pycache__/create_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /src/models/LDM/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/models/__pycache__/autoencoder.cpython-38.pyc 
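The schedulers in src/models/LDM/lr_scheduler.py above return a learning-rate *multiplier* per step rather than an absolute rate, which is why their docstrings say to use them with a base_lr of 1.0. A minimal usage sketch (the warm-up and decay values here are illustrative, not taken from this repo's configs):

import torch
from torch.optim.lr_scheduler import LambdaLR
from src.models.LDM.lr_scheduler import LambdaWarmUpCosineScheduler

net = torch.nn.Linear(8, 8)                        # stand-in module
opt = torch.optim.AdamW(net.parameters(), lr=1.0)  # base lr of 1.0, per the docstring

# linear warm-up from 1e-6 to 1e-4 over 1k steps, then cosine decay back to 1e-6
sched_fn = LambdaWarmUpCosineScheduler(warm_up_steps=1_000, lr_min=1e-6, lr_max=1e-4, lr_start=1e-6, max_decay_steps=100_000)
scheduler = LambdaLR(opt, lr_lambda=sched_fn)      # effective lr = base_lr * sched_fn(step)

for step in range(10):
    opt.step()
    scheduler.step()

LambdaWarmUpCosineScheduler2 and LambdaLinearScheduler follow the same calling convention, with per-cycle lists in place of the scalar arguments.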
-------------------------------------------------------------------------------- /src/models/LDM/models/__pycache__/autoencoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/models/__pycache__/autoencoder.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/models/diffusion/__init__.py -------------------------------------------------------------------------------- /src/models/LDM/models/diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/models/diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/models/diffusion/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/models/diffusion/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/models/diffusion/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/models/diffusion/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/models/diffusion/__pycache__/ddim.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/models/diffusion/__pycache__/ddim.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/models/diffusion/__pycache__/ddpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/models/diffusion/__pycache__/ddpm.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/models/diffusion/__pycache__/ddpm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/models/diffusion/__pycache__/ddpm.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/__pycache__/attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/__pycache__/attention.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/__pycache__/ema.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/__pycache__/ema.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/__pycache__/ema.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/__pycache__/x_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/__pycache__/x_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /src/models/LDM/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/diffusionmodules/__pycache__/model.cpython-38.pyc 
-------------------------------------------------------------------------------- /src/models/LDM/modules/diffusionmodules/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/diffusionmodules/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/diffusionmodules/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/diffusionmodules/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/distributions/__init__.py -------------------------------------------------------------------------------- /src/models/LDM/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/distributions/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/distributions/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- 
/src/models/LDM/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/distributions/__pycache__/distributions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/distributions/__pycache__/distributions.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
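# A worked note on the conversion below: with e.g. normal_kl(mean1=x, logvar1=0., mean2=0., logvar2=0.),
# the scalar log-variances are wrapped as tensors on x's device/dtype so that torch.exp() works,
# and the arithmetic afterwards broadcasts them against the batched means.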
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /src/models/LDM/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/encoders/__init__.py -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/encoders/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/__pycache__/modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/encoders/__pycache__/modules.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | import kornia 7 | 8 | 9 | from src.models.LDM.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? 
-------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/encoders/__init__.py -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/encoders/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/__pycache__/modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/encoders/__pycache__/modules.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | import kornia 7 | 8 | 9 | from src.models.LDM.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=2, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to requirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | return tokens 68 | 69 | @torch.no_grad() 70 | def encode(self, text): 71 | tokens = self(text) 72 | if not self.vq_interface: 73 | return tokens 74 | return None, None, [None, None, tokens] 75 | 76 | def decode(self, text): 77 | return text 78 | 79 | 80 | class BERTEmbedder(AbstractEncoder): 81 | """Uses the BERT tokenizer model and adds some transformer encoder layers""" 82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 84 | super().__init__() 85 | self.use_tknz_fn = use_tokenizer 86 | if self.use_tknz_fn: 87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 88 | self.device = device 89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 90 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 91 | emb_dropout=embedding_dropout) 92 | 93 | def forward(self, text): 94 | if self.use_tknz_fn: 95 | tokens = self.tknz_fn(text)#.to(self.device) 96 | else: 97 | tokens = text 98 | z = self.transformer(tokens, return_embeddings=True) 99 | return z 100 | 101 | def encode(self, text): 102 | # output of length 77 103 | return self(text) 104 | 105 | 106 | class SpatialRescaler(nn.Module): 107 | def __init__(self, 108 | n_stages=1, 109 | method='bilinear', 110 | multiplier=0.5, 111 | in_channels=3, 112 | out_channels=None, 113 | bias=False): 114 | super().__init__() 115 | self.n_stages = n_stages 116 | assert self.n_stages >= 0 117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 118 | self.multiplier = multiplier 119
| self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 120 | self.remap_output = out_channels is not None 121 | if self.remap_output: 122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 124 | 125 | def forward(self,x): 126 | for stage in range(self.n_stages): 127 | x = self.interpolator(x, scale_factor=self.multiplier) 128 | 129 | 130 | if self.remap_output: 131 | x = self.channel_mapper(x) 132 | return x 133 | 134 | def encode(self, x): 135 | return self(x) 136 | 137 | 138 | class FrozenCLIPTextEmbedder(nn.Module): 139 | """ 140 | Uses the CLIP transformer encoder for text. 141 | """ 142 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 143 | super().__init__() 144 | self.model, _ = clip.load(version, jit=False, device="cpu") 145 | self.device = device 146 | self.max_length = max_length 147 | self.n_repeat = n_repeat 148 | self.normalize = normalize 149 | 150 | def freeze(self): 151 | self.model = self.model.eval() 152 | for param in self.parameters(): 153 | param.requires_grad = False 154 | 155 | def forward(self, text): 156 | tokens = clip.tokenize(text).to(self.device) 157 | z = self.model.encode_text(tokens) 158 | if self.normalize: 159 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 160 | return z 161 | 162 | def encode(self, text): 163 | z = self(text) 164 | if z.ndim==2: 165 | z = z[:, None, :] 166 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 167 | return z 168 | 169 | 170 | class FrozenClipImageEmbedder(nn.Module): 171 | """ 172 | Uses the CLIP image encoder. 173 | """ 174 | def __init__( 175 | self, 176 | model, 177 | jit=False, 178 | device='cuda' if torch.cuda.is_available() else 'cpu', 179 | antialias=False, 180 | ): 181 | super().__init__() 182 | self.model, _ = clip.load(name=model, device=device, jit=jit) 183 | 184 | self.antialias = antialias 185 | 186 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 187 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 188 | 189 | def preprocess(self, x): 190 | # normalize to [0,1] 191 | x = kornia.geometry.resize(x, (224, 224), 192 | interpolation='bicubic',align_corners=True, 193 | antialias=self.antialias) 194 | x = (x + 1.) / 2. 
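# At this point x has been resized to 224x224 and, per the forward() contract below, arrived in
# [-1, 1]; the line above maps it to [0, 1] so that the CLIP mean/std buffers from __init__ apply next.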
195 | # renormalize according to clip 196 | x = kornia.enhance.normalize(x, self.mean, self.std) 197 | return x 198 | 199 | def forward(self, x): 200 | # x is assumed to be in range [-1,1] 201 | return self.model.encode_image(self.preprocess(x)) 202 | 203 | -------------------------------------------------------------------------------- /src/models/LDM/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from src.models.LDM.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from src.models.LDM.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /src/models/LDM/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /src/models/LDM/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from src.models.LDM.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /src/models/LDM/modules/losses/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/losses/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/losses/__pycache__/contperceptual.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/losses/__pycache__/contperceptual.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/losses/__pycache__/vqperceptual.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/LDM/modules/losses/__pycache__/vqperceptual.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/LDM/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
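The star import above is what makes NLayerDiscriminator, weights_init, LPIPS, hinge_d_loss, vanilla_d_loss and adopt_weight available to LPIPSWithDiscriminator below. A hedged, explicit equivalent, assuming the usual taming-transformers layout (the same modules are imported explicitly in vqperceptual.py further down):

from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import adopt_weight, hinge_d_loss, vanilla_d_loss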
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /src/models/LDM/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if codebook_loss is None: 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log -------------------------------------------------------------------------------- /src/models/LDM/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), 
nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Can't encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if "target" not in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | if 'Covid' in config["target"]: 86 | return get_obj_from_str(config["target"])(config) 87 | else: 88 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 89 | 90 | 91 | def get_obj_from_str(string, reload=False): 92 | module, cls = string.rsplit(".", 1) 93 | if reload: 94 | module_imp = importlib.import_module(module) 95 | importlib.reload(module_imp) 96 | return getattr(importlib.import_module(module, package=None), cls) 97 | 98 | 99 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 100 | # create dummy dataset instance 101 | 102 | # run prefetching 103 | if idx_to_fn: 104 | res = func(data, worker_id=idx) 105 | else: 106 | res = func(data) 107 | Q.put([idx, res]) 108 | Q.put("Done") 109 | 110 | 111 | def parallel_data_prefetch( 112 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 113 | ): 114 | # if target_data_type not in ["ndarray", "list"]: 115 | # raise ValueError( 116 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 117 | # ) 118 | if isinstance(data, np.ndarray) and target_data_type == "list": 119 | raise ValueError("list expected but function got ndarray.") 120 | elif isinstance(data, abc.Iterable): 121 | if isinstance(data, dict): 122 | print( 123 | 'WARNING: "data" argument passed to parallel_data_prefetch is a dict; using only its values and disregarding keys.' 124 | ) 125 | data = list(data.values()) 126 | if target_data_type == "ndarray": 127 | data = np.asarray(data) 128 | else: 129 | data = list(data) 130 | else: 131 | raise TypeError( 132 | f"The data to be processed in parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
133 | ) 134 | 135 | if cpu_intensive: 136 | Q = mp.Queue(1000) 137 | proc = mp.Process 138 | else: 139 | Q = Queue(1000) 140 | proc = Thread 141 | # spawn processes 142 | if target_data_type == "ndarray": 143 | arguments = [ 144 | [func, Q, part, i, use_worker_id] 145 | for i, part in enumerate(np.array_split(data, n_proc)) 146 | ] 147 | else: 148 | step = ( 149 | int(len(data) / n_proc + 1) 150 | if len(data) % n_proc != 0 151 | else int(len(data) / n_proc) 152 | ) 153 | arguments = [ 154 | [func, Q, part, i, use_worker_id] 155 | for i, part in enumerate( 156 | [data[i: i + step] for i in range(0, len(data), step)] 157 | ) 158 | ] 159 | processes = [] 160 | for i in range(n_proc): 161 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 162 | processes += [p] 163 | 164 | # start processes 165 | print("Start prefetching...") 166 | import time 167 | 168 | start = time.time() 169 | gather_res = [[] for _ in range(n_proc)] 170 | try: 171 | for p in processes: 172 | p.start() 173 | 174 | k = 0 175 | while k < n_proc: 176 | # get result 177 | res = Q.get() 178 | if res == "Done": 179 | k += 1 180 | else: 181 | gather_res[res[0]] = res[1] 182 | 183 | except Exception as e: 184 | print("Exception: ", e) 185 | for p in processes: 186 | p.terminate() 187 | 188 | raise e 189 | finally: 190 | for p in processes: 191 | p.join() 192 | print(f"Prefetching complete. [{time.time() - start} sec.]") 193 | 194 | if target_data_type == 'ndarray': 195 | if not isinstance(gather_res[0], np.ndarray): 196 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 197 | 198 | # order outputs 199 | return np.concatenate(gather_res, axis=0) 200 | elif target_data_type == 'list': 201 | out = [] 202 | for r in gather_res: 203 | out.extend(r) 204 | return out 205 | else: 206 | return gather_res 207 | -------------------------------------------------------------------------------- /src/models/Spark_2D.py: -------------------------------------------------------------------------------- 1 | from src.models.modules.spark.Spark_2D import SparK_2D 2 | from src.models.losses import L1_AE 3 | import torch 4 | from src.utils.utils_eval import _test_step, _test_end, get_eval_dictionary 5 | import numpy as np 6 | from pytorch_lightning.core.lightning import LightningModule 7 | import torch.optim as optim 8 | from typing import Any 9 | import torchio as tio 10 | 11 | 12 | class Spark_2D(LightningModule): 13 | def __init__(self,cfg,prefix=None): 14 | super().__init__() 15 | self.cfg = cfg 16 | # Model 17 | self.model = SparK_2D(cfg) 18 | self.L1 = L1_AE(cfg) 19 | 20 | self.prefix = prefix 21 | self.save_hyperparameters() 22 | 23 | def forward(self, x): 24 | active_ex, reco, loss, latent = self.model(x) 25 | if self.cfg.get('loss_on_mask', False): # loss is calculated only on the masked patches 26 | loss = loss 27 | else: 28 | loss = self.L1({'x_hat':reco},x)['recon_error'] + self.cfg.get('delta_mask',0) * loss 29 | return loss, reco, latent[0].mean([2,3]) 30 | 31 | def training_step(self, batch, batch_idx: int): 32 | # process batch 33 | input = batch['vol'][tio.DATA].squeeze(-1) # drop the singleton depth dimension 34 | loss, reco, latent = self(input) # loss, reconstruction, latent 35 | 36 | self.log(f'{self.prefix}train/Loss_comb', loss, prog_bar=False, on_step=False, on_epoch=True, batch_size=input.shape[0],sync_dist=True) 37 | 38 | return {"loss": loss} # , 'latent_space': z} 39 | 40 | 41 | 42 | def validation_step(self, batch: Any, batch_idx: int): 43 | input = batch['vol'][tio.DATA].squeeze(-1) # drop the singleton depth dimension
44 | loss, reco, _ = self(input) 45 | # log val metrics 46 | self.log(f'{self.prefix}val/Loss_comb', loss, prog_bar=False, on_step=False, on_epoch=True, batch_size=input.shape[0],sync_dist=True) 47 | return {"loss": loss} 48 | 49 | def on_test_start(self): 50 | self.eval_dict = get_eval_dictionary() 51 | self.new_size = [160,190,160] 52 | self.diffs_list = [] 53 | self.seg_list = [] 54 | if not hasattr(self,'threshold'): 55 | self.threshold = {} 56 | 57 | def test_step(self, batch: Any, batch_idx: int): 58 | self.dataset = batch['Dataset'] 59 | input = batch['vol'][tio.DATA] 60 | data_orig = batch['vol_orig'][tio.DATA] 61 | data_seg = batch['seg_orig'][tio.DATA] if batch['seg_available'] else torch.zeros_like(data_orig) 62 | data_mask = batch['mask_orig'][tio.DATA] 63 | ID = batch['ID'] 64 | self.stage = batch['stage'] 65 | label = batch['label'] 66 | AnomalyScoreReco = [] 67 | 68 | 69 | if self.cfg.get('num_eval_slices', input.size(4)) != input.size(4): 70 | num_slices = self.cfg.get('num_eval_slices', input.size(4)) # number of center slices to evaluate; if not set, the whole volume is evaluated 71 | start_slice = int((input.size(4) - num_slices) / 2) 72 | 73 | input = input[...,start_slice:start_slice+num_slices] 74 | data_orig = data_orig[...,start_slice:start_slice+num_slices] 75 | data_seg = data_seg[...,start_slice:start_slice+num_slices] 76 | data_mask = data_mask[...,start_slice:start_slice+num_slices] 77 | 78 | final_volume = torch.zeros([input.size(2), input.size(3), input.size(4)], device = self.device) 79 | 80 | # reorder depth to batch dimension 81 | assert input.shape[0] == 1, "Batch size must be 1" 82 | input = input.squeeze(0).permute(3,0,1,2) # [B,C,H,W,D] -> [D,C,H,W] 83 | 84 | # compute reconstruction 85 | loss, output_slice, _ = self(input) 86 | 87 | # compute loss and anomaly scores 88 | AnomalyScoreReco.append(loss.item()) 89 | 90 | # reassemble the reconstruction volume 91 | final_volume = output_slice.squeeze().permute(1,2,0) # to HxWxD 92 | 93 | 94 | # average across slices to get volume-based scores 95 | AnomalyScoreReco_vol = np.mean(AnomalyScoreReco) 96 | 97 | 98 | self.eval_dict['AnomalyScoreRegPerVol'].append(0) 99 | 100 | 101 | if not self.cfg.get('use_postprocessed_score', True): 102 | self.eval_dict['AnomalyScoreRecoPerVol'].append(AnomalyScoreReco_vol) 103 | self.eval_dict['AnomalyScoreCombPerVol'].append(0) 104 | self.eval_dict['AnomalyScoreCombiPerVol'].append(0) 105 | self.eval_dict['AnomalyScoreCombPriorPerVol'].append(0) 106 | self.eval_dict['AnomalyScoreCombiPriorPerVol'].append(0) 107 | 108 | final_volume = final_volume.unsqueeze(0) 109 | final_volume = final_volume.unsqueeze(0) 110 | 111 | # calculate metrics 112 | _test_step(self, final_volume, data_orig, data_seg, data_mask, batch_idx, ID, label) # everything that is independent of the model choice 113 | 114 | 115 | def on_test_end(self): 116 | # calculate metrics 117 | _test_end(self) # everything that is independent of the model choice 118 | 119 | 120 | def configure_optimizers(self): 121 | return optim.AdamW(self.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.get('weight_decay', 0.05), betas=[0.9,0.95]) 122 | 123 | def update_prefix(self, prefix): 124 | self.prefix = prefix
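In test_step above, the depth axis of the 3D volume is folded into the batch axis so that the 2D model reconstructs all slices in one pass, and the output is then rearranged back into a volume. A shape-only sketch of that round trip (the sizes are illustrative):

import torch

vol = torch.rand(1, 1, 160, 190, 96)            # [B, C, H, W, D], batch size must be 1
slices = vol.squeeze(0).permute(3, 0, 1, 2)     # -> [D, C, H, W]: slices become the batch
reco = slices                                   # stand-in for the model's slice-wise output
final_volume = reco.squeeze().permute(1, 2, 0)  # -> [H, W, D]
assert final_volume.shape == (160, 190, 96)

Note that final_volume is pre-allocated with zeros in test_step but immediately overwritten by this permute.
-------------------------------------------------------------------------------- /src/models/__init__.py: --------------------------------------------------------------------------------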
https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/__pycache__/DDPM_2D.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/__pycache__/DDPM_2D.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/Spark_2D.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/__pycache__/Spark_2D.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/losses.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FinnBehrendt/Conditioned-Diffusion-Models-UAD/c5904a71ffd76900a0fd05da9d8a571fbeb2a146/src/models/__pycache__/losses.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | class L1_AE(torch.nn.Module): 4 | def __init__(self, cfg): 5 | super().__init__() 6 | self.strat = cfg.lossStrategy # 'sum' or 'mean' 7 | 8 | def forward(self, output_batch, input_batch): 9 | if isinstance(output_batch, dict): 10 | output_batch = output_batch['x_hat'] 11 | # otherwise output_batch is already the reconstruction tensor 12 | 13 | if self.strat == 'sum': 14 | L1Loss = nn.L1Loss(reduction = 'sum') 15 | L1 = L1Loss(output_batch, input_batch)/input_batch.shape[0] 16 | elif self.strat == 'mean': 17 | L1Loss = nn.L1Loss(reduction = 'mean') 18 | L1 = L1Loss(output_batch, input_batch) 19 | loss = {} 20 | loss['combined_loss'] = L1 21 | loss['reg'] = L1 # dummy entry (no regularization term) 22 | loss['recon_error'] = L1 23 | return loss 24 | -------------------------------------------------------------------------------- /src/models/modules/DDPM_encoder.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | import torchvision 4 | from src.models.modules.spark.Spark_2D import SparK_2D_encoder 5 | 6 | def get_encoder(cfg): 7 | """ 8 | Available backbones (some of them): 9 | Resnet: 10 | resnet18, 11 | resnet34, 12 | resnet50, 13 | resnet101 14 | """ 15 | backbone = cfg.get('backbone','resnet50') 16 | chans = 1 17 | if 'spark' in backbone.lower(): # spark encoder 18 | encoder = SparK_2D_encoder(cfg) 19 | else: # 2D CNN encoder 20 | encoder = timm.create_model(backbone, pretrained=cfg.pretrained_backbone, in_chans=chans, num_classes=cfg.get('cond_dim',256)) 21 | 22 | out_features = cfg.get('cond_dim',256) # dimensionality of the conditioning embedding (cond_dim)
--------------------------------------------------------------------------------
/src/models/modules/DDPM_encoder.py:
--------------------------------------------------------------------------------
1 | import timm
2 | import torch
3 | import torchvision
4 | from src.models.modules.spark.Spark_2D import SparK_2D_encoder
5 | 
6 | def get_encoder(cfg):
7 |     """
8 |     Available backbones (some of them):
9 |     Resnet:
10 |         resnet18,
11 |         resnet34,
12 |         resnet50,
13 |         resnet101
14 |     """
15 |     backbone = cfg.get('backbone', 'resnet50')
16 |     chans = 1
17 |     if 'spark' in backbone.lower():  # spark encoder
18 |         encoder = SparK_2D_encoder(cfg)
19 |     else:  # 2D CNN encoder
20 |         encoder = timm.create_model(backbone, pretrained=cfg.pretrained_backbone, in_chans=chans, num_classes=cfg.get('cond_dim', 256))
21 | 
22 |     out_features = cfg.get('cond_dim', 256)  # dimension of the conditioning embedding returned alongside the encoder
23 | 
24 | 
25 | 
26 |     return encoder, out_features
--------------------------------------------------------------------------------
/src/models/modules/spark/decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | from timm.models.layers import trunc_normal_, DropPath, Mlp 10 | import torch.nn as nn 11 | 12 | def is_pow2n(x): 13 | return x > 0 and (x & (x - 1) == 0) 14 | _BN = None 15 | 16 | 17 | class UNetBlock2x(nn.Module): 18 | def __init__(self, cin, cout, cmid, last_act=True): 19 | super().__init__() 20 | if cmid == 0: 21 | c_mid = cin 22 | elif cmid == 1: 23 | c_mid = (cin + cout) // 2 24 | 25 | self.b = nn.Sequential( 26 | nn.Conv2d(cin, c_mid, 3, 1, 1, bias=False), _BN(c_mid), nn.ReLU6(inplace=True), 27 | nn.Conv2d(c_mid, cout, 3, 1, 1, bias=False), _BN(cout), (nn.ReLU6(inplace=True) if last_act else nn.Identity()), 28 | ) 29 | 30 | def forward(self, x): 31 | return self.b(x) 32 | 33 | 34 | class DecoderConv(nn.Module): 35 | def __init__(self, cin, cout, double, heavy, cmid): 36 | super().__init__() 37 | self.up = nn.ConvTranspose2d(cin, cin, kernel_size=4 if double else 2, stride=2, padding=1 if double else 0, bias=True) 38 | ls = [UNetBlock2x(cin, (cin if i != heavy[1]-1 else cout), cmid=cmid, last_act=i != heavy[1]-1) for i in range(heavy[1])] 39 | self.conv = nn.Sequential(*ls) 40 | 41 | def forward(self, x): 42 | x = self.up(x) 43 | return self.conv(x) 44 | 45 | 46 | class LightDecoder(nn.Module): 47 | def __init__(self, decoder_fea_dim, upsample_ratio, double=False, heavy=None, cmid=0, sbn=False): 48 | global _BN 49 | _BN = nn.SyncBatchNorm if sbn else nn.BatchNorm2d 50 | super().__init__() 51 | self.fea_dim = decoder_fea_dim 52 | if heavy is None: 53 | heavy = [0, 1] 54 | heavy[1] = max(1, heavy[1]) 55 | self.double_bool = double 56 | self.heavy = heavy 57 | self.cmid = cmid 58 | self.sbn = sbn 59 | 60 | assert is_pow2n(upsample_ratio) 61 | n = round(math.log2(upsample_ratio)) 62 | channels = [self.fea_dim // 2**i for i in range(n+1)] 63 | self.dec = nn.ModuleList([ 64 | DecoderConv(cin, cout, double, heavy, cmid) for (cin, cout) in zip(channels[:-1], channels[1:]) 65 | ]) 66 | self.proj = nn.Conv2d(channels[-1], 1, kernel_size=1, stride=1, bias=True) 67 | 68 | self.initialize() 69 | 70 | def forward(self, to_dec): 71 | x = 0 72 | for i, d in enumerate(self.dec): 73 | if i < len(to_dec) and to_dec[i] is not None: 74 | x = x + to_dec[i] 75 | x = self.dec[i](x) 76 | return self.proj(x) 77 | 78 | def num_para(self): 79 | tot = sum(p.numel() for p in self.parameters()) 80 | 81 | para1 = para2 = 0 82 | for m in self.dec.modules(): 83 | if isinstance(m, nn.ConvTranspose2d): 84 | para1 += sum(p.numel() for p in m.parameters()) 85 | elif isinstance(m, nn.Conv2d): 86 | para2 += sum(p.numel() for p in m.parameters()) 87 | return f'#para: {tot/1e6:.2f} (dconv={para1/1e6:.2f}, conv={para2/1e6:.2f}, ot={(tot-para1-para2)/1e6:.2f})' 88 | 89 | def extra_repr(self) -> str: 90 | return f'fea_dim={self.fea_dim}, dbl={self.double_bool}, heavy={self.heavy}, cmid={self.cmid}, sbn={self.sbn}' 91 | 92 | def initialize(self): 93 | for m in self.modules(): 94 | if isinstance(m, nn.Linear): 95 | trunc_normal_(m.weight, std=.02) 96 | if m.bias is not None: 97 | nn.init.constant_(m.bias, 0) 98 | elif isinstance(m, nn.Embedding): 99 | trunc_normal_(m.weight, std=.02) 100 | if m.padding_idx is not None: 101 | m.weight.data[m.padding_idx].zero_() 102 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)): 103 | nn.init.constant_(m.bias, 0) 104 | nn.init.constant_(m.weight, 1.0) 105 | elif 
isinstance(m, nn.Conv2d): 106 | trunc_normal_(m.weight, std=.02) 107 | if m.bias is not None: 108 | nn.init.constant_(m.bias, 0) 109 | elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 110 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 111 | if m.bias is not None: 112 | nn.init.constant_(m.bias, 0.) -------------------------------------------------------------------------------- /src/models/modules/spark/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from timm.models.layers import DropPath 10 | 11 | 12 | _cur_active: torch.Tensor = None # B1ff 13 | def _get_active_ex_or_ii(H, returning_active_ex=True): 14 | downsample_raito = H // _cur_active.shape[-1] 15 | active_ex = _cur_active.repeat_interleave(downsample_raito, 2).repeat_interleave(downsample_raito, 3) 16 | return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True) # ii: bi, hi, wi 17 | 18 | 19 | def sp_conv_forward(self, x: torch.Tensor): 20 | x = super(type(self), self).forward(x) 21 | x *= _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=True) # (BCHW) *= (B1HW), mask the output of conv 22 | return x 23 | 24 | 25 | def sp_bn_forward(self, x: torch.Tensor): 26 | ii = _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=False) 27 | 28 | bhwc = x.permute(0, 2, 3, 1) 29 | nc = bhwc[ii] # select the features on non-masked positions to form a flatten feature `nc` 30 | nc = super(type(self), self).forward(nc) # use BN1d to normalize this flatten feature `nc` 31 | 32 | bchw = torch.zeros_like(bhwc) 33 | bchw[ii] = nc 34 | bchw = bchw.permute(0, 3, 1, 2) 35 | return bchw 36 | 37 | 38 | class SparseConv2d(nn.Conv2d): 39 | forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details 40 | 41 | 42 | class SparseMaxPooling(nn.MaxPool2d): 43 | forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details 44 | 45 | 46 | class SparseAvgPooling(nn.AvgPool2d): 47 | forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details 48 | 49 | 50 | class SparseBatchNorm2d(nn.BatchNorm1d): 51 | forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details 52 | 53 | 54 | class SparseSyncBatchNorm2d(nn.SyncBatchNorm): 55 | forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details 56 | 57 | 58 | class SparseConvNeXtLayerNorm(nn.LayerNorm): 59 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 60 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 61 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 62 | with shape (batch_size, channels, height, width). 
63 | """ 64 | 65 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True): 66 | if data_format not in ["channels_last", "channels_first"]: 67 | raise NotImplementedError 68 | super().__init__(normalized_shape, eps, elementwise_affine=True) 69 | self.data_format = data_format 70 | self.sparse = sparse 71 | 72 | def forward(self, x): 73 | if x.ndim == 4: # BHWC 74 | if self.data_format == "channels_last": 75 | if self.sparse: 76 | ii = _get_active_ex_or_ii(H=x.shape[1], returning_active_ex=False) 77 | nc = x[ii] 78 | nc = super(SparseConvNeXtLayerNorm, self).forward(nc) 79 | 80 | x = torch.zeros_like(x) 81 | x[ii] = nc 82 | return x 83 | else: 84 | return super(SparseConvNeXtLayerNorm, self).forward(x) 85 | else: # channels_first 86 | if self.sparse: 87 | ii = _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=False) 88 | bhwc = x.permute(0, 2, 3, 1) 89 | nc = bhwc[ii] 90 | nc = super(SparseConvNeXtLayerNorm, self).forward(nc) 91 | 92 | x = torch.zeros_like(bhwc) 93 | x[ii] = nc 94 | return x.permute(0, 3, 1, 2) 95 | else: 96 | u = x.mean(1, keepdim=True) 97 | s = (x - u).pow(2).mean(1, keepdim=True) 98 | x = (x - u) / torch.sqrt(s + self.eps) 99 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 100 | return x 101 | else: # BLC or BC 102 | if self.sparse: 103 | raise NotImplementedError 104 | else: 105 | return super(SparseConvNeXtLayerNorm, self).forward(x) 106 | 107 | def __repr__(self): 108 | return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})' 109 | 110 | 111 | class SparseConvNeXtBlock(nn.Module): 112 | r""" ConvNeXt Block. There are two equivalent implementations: 113 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 114 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 115 | We use (2) as we find it slightly faster in PyTorch 116 | 117 | Args: 118 | dim (int): Number of input channels. 119 | drop_path (float): Stochastic depth rate. Default: 0.0 120 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 121 | """ 122 | 123 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7): 124 | super().__init__() 125 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim) # depthwise conv 126 | self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse) 127 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 128 | self.act = nn.GELU() 129 | self.pwconv2 = nn.Linear(4 * dim, dim) 130 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 131 | requires_grad=True) if layer_scale_init_value > 0 else None 132 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 133 | self.sparse = sparse 134 | 135 | def forward(self, x): 136 | input = x 137 | x = self.dwconv(x) 138 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 139 | x = self.norm(x) 140 | x = self.pwconv1(x) 141 | x = self.act(x) 142 | x = self.pwconv2(x) 143 | if self.gamma is not None: 144 | x = self.gamma * x 145 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 146 | 147 | if self.sparse: 148 | x *= _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=True) 149 | 150 | x = input + self.drop_path(x) 151 | return x 152 | 153 | def __repr__(self): 154 | return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})' 155 | 156 | 157 | class SparseEncoder(nn.Module): 158 | def __init__(self, cnn, input_size, downsample_raito, encoder_fea_dim, verbose=False, sbn=False): 159 | super(SparseEncoder, self).__init__() 160 | self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn) 161 | self.input_size, self.downsample_raito, self.fea_dim = input_size, downsample_raito, encoder_fea_dim 162 | 163 | @staticmethod 164 | def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False): 165 | oup = m 166 | if isinstance(m, nn.Conv2d): 167 | m: nn.Conv2d 168 | bias = m.bias is not None 169 | oup = SparseConv2d( 170 | m.in_channels, m.out_channels, 171 | kernel_size=m.kernel_size, stride=m.stride, padding=m.padding, 172 | dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode, 173 | ) 174 | oup.weight.data.copy_(m.weight.data) 175 | if bias: 176 | oup.bias.data.copy_(m.bias.data) 177 | elif isinstance(m, nn.MaxPool2d): 178 | m: nn.MaxPool2d 179 | oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation, return_indices=m.return_indices, ceil_mode=m.ceil_mode) 180 | elif isinstance(m, nn.AvgPool2d): 181 | m: nn.AvgPool2d 182 | oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode, count_include_pad=m.count_include_pad, divisor_override=m.divisor_override) 183 | elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): 184 | m: nn.BatchNorm2d 185 | oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(m.weight.shape[0], eps=m.eps, momentum=m.momentum, affine=m.affine, track_running_stats=m.track_running_stats) 186 | oup.weight.data.copy_(m.weight.data) 187 | oup.bias.data.copy_(m.bias.data) 188 | oup.running_mean.data.copy_(m.running_mean.data) 189 | oup.running_var.data.copy_(m.running_var.data) 190 | oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data) 191 | if hasattr(m, "qconfig"): 192 | oup.qconfig = m.qconfig 193 | elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm): 194 | m: nn.LayerNorm 195 | oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps) 196 | oup.weight.data.copy_(m.weight.data) 197 | oup.bias.data.copy_(m.bias.data) 198 | elif isinstance(m, (nn.Conv1d,)): 199 | raise NotImplementedError 200 | 201 | for name, child in m.named_children(): 202 | oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn)) 203 | del m 204 | return oup 205 | 206 | def forward(self, x, pyramid): 207 | return self.sp_cnn(x, pyramid=pyramid) -------------------------------------------------------------------------------- /src/models/modules/spark/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from timm import create_model 9 | from timm.loss import SoftTargetCrossEntropy 10 | from timm.models.layers import drop 11 | 12 | import torchvision 13 | from src.models.modules.spark.resnet import ResNet 14 | _import_resnets_for_timm_registration = (ResNet,) 15 | 16 | 17 | 18 | # log more 19 | def _ex_repr(self): 20 | return ', '.join( 21 | f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v)) 22 | for k, v in vars(self).items() 23 | if not k.startswith('_') and k != 'training' 24 | and not isinstance(v, (torch.nn.Module, torch.Tensor)) 25 | ) 26 | for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath): 27 | if hasattr(clz, 'extra_repr'): 28 | clz.extra_repr = _ex_repr 29 | else: 30 | clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})' 31 | 32 | 33 | model_alias_to_fullname = { 34 | 'res18': 'resnet18', 35 | 'res34': 'resnet34', 36 | 'res50': 'resnet50', 37 | 'res101': 'resnet101', 38 | 'res152': 'resnet152', 39 | 'res200': 'resnet200', 40 | 'cnxS': 'convnext_small', 41 | 'cnxB': 'convnext_base', 42 | 'cnxL': 'convnext_large', 43 | } 44 | model_fullname_to_alias = {v: k for k, v in model_alias_to_fullname.items()} 45 | 46 | 47 | pre_train_d = { # default drop_path_rate, num of para, FLOPs, downsample_ratio, num of channel 48 | 'resnet18': [dict(drop_path_rate=0.05), 11.7, 1.8, 32, 512], 49 | 'resnet34': [dict(drop_path_rate=0.05), 21.8, 3.7, 32, 512], 50 | 'resnet50': [dict(drop_path_rate=0.05), 25.6, 4.1, 32, 2048], 51 | 'resnet101': [dict(drop_path_rate=0.08), 44.5, 7.9, 32, 2048], 52 | 'resnet152': [dict(drop_path_rate=0.10), 60.2, 11.6, 32, 2048], 53 | 'resnet200': [dict(drop_path_rate=0.15), 64.7, 15.1, 32, 2048], 54 | 'convnext_small': [dict(sparse=True, drop_path_rate=0.2), 50.0, 8.7, 32, 768], 55 | 'convnext_base': [dict(sparse=True, drop_path_rate=0.3), 89.0, 15.4, 32, 1024], 56 | 'convnext_large': [dict(sparse=True, drop_path_rate=0.4), 198.0, 34.4, 32, 1536], 57 | } 58 | for v in pre_train_d.values(): 59 | v[0]['pretrained'] = False 60 | v[0]['num_classes'] = 0 61 | v[0]['global_pool'] = '' 62 | 63 | 64 | def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False): 65 | from src.models.modules.spark.encoder import SparseEncoder 66 | 67 | kwargs, params, flops, downsample_raito, fea_dim = pre_train_d[name] 68 | if drop_path_rate != 0: 69 | kwargs['drop_path_rate'] = drop_path_rate 70 | print(f'[sparse_cnn] model kwargs={kwargs}') 71 | cnn = create_model(name,in_chans=1, **kwargs) 72 | if hasattr(cnn, 'global_pool'): 73 | if callable(cnn.global_pool): 74 | cnn.global_pool = torch.nn.Identity() 75 | elif isinstance(cnn.global_pool, str): 76 | cnn.global_pool = '' 77 | 78 | if not isinstance(downsample_raito, int) or not isinstance(fea_dim, int): 79 | with torch.no_grad(): 80 | cnn.eval() 81 | o = cnn(torch.rand(1, 3, input_size, input_size)) 82 | downsample_raito = input_size // o.shape[-1] 83 | fea_dim = o.shape[1] 84 | cnn.train() 85 | print(f'[sparse_cnn] downsample_raito={downsample_raito}, fea_dim={fea_dim}') 86 | 87 | return SparseEncoder(cnn, input_size=input_size, downsample_raito=downsample_raito, encoder_fea_dim=fea_dim, verbose=verbose, sbn=sbn) 88 | 89 | def build_encoder(name: str, cond_dim:int, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False): 90 | 91 | kwargs, params, flops, downsample_raito, fea_dim = 
pre_train_d[name] 92 | if drop_path_rate != 0: 93 | kwargs['drop_path_rate'] = drop_path_rate 94 | if 'global_pool' in kwargs: 95 | kwargs.pop('global_pool') 96 | kwargs['num_classes'] = cond_dim 97 | print(f'[sparse_cnn] model kwargs={kwargs}') 98 | cnn = create_model(name,in_chans=1, **kwargs) 99 | 100 | if not isinstance(downsample_raito, int) or not isinstance(fea_dim, int): 101 | with torch.no_grad(): 102 | cnn.eval() 103 | o = cnn(torch.rand(1, 3, input_size, input_size)) 104 | downsample_raito = input_size // o.shape[-1] 105 | fea_dim = o.shape[1] 106 | cnn.train() 107 | print(f'[sparse_cnn] downsample_raito={downsample_raito}, fea_dim={fea_dim}') 108 | 109 | return cnn 110 | 111 | 112 | -------------------------------------------------------------------------------- /src/models/modules/spark/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch.nn.functional as F 10 | from timm.models.resnet import ResNet 11 | 12 | 13 | def forward_features(self, x, pyramid: int): # pyramid: 0, 1, 2, 3, 4 14 | x = self.conv1(x) 15 | x = self.bn1(x) 16 | x = self.act1(x) 17 | x = self.maxpool(x) 18 | 19 | ls = [] 20 | x = self.layer1(x) 21 | if pyramid: ls.append(x) 22 | x = self.layer2(x) 23 | if pyramid: ls.append(x) 24 | x = self.layer3(x) 25 | if pyramid: ls.append(x) 26 | x = self.layer4(x) 27 | if pyramid: ls.append(x) 28 | 29 | if pyramid: 30 | for i in range(len(ls)-pyramid-1, -1, -1): 31 | del ls[i] 32 | return [None] * (4 - pyramid) + ls 33 | else: 34 | return x 35 | 36 | 37 | def forward(self, x, pyramid=0): 38 | if pyramid == 0: 39 | x = self.forward_features(x, pyramid=pyramid) 40 | x = self.global_pool(x) 41 | if self.drop_rate: 42 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 43 | x = self.fc(x) 44 | return x 45 | else: 46 | return self.forward_features(x, pyramid=pyramid) 47 | 48 | 49 | def resnets_get_layer_id_and_scale_exp(self, para_name: str): 50 | # stages: 51 | # 50 : [3, 4, 6, 3] 52 | # 101 : [3, 4, 23, 3] 53 | # 152 : [3, 8, 36, 3] 54 | # 200 : [3, 24, 36, 3] 55 | # eca269d: [3, 30, 48, 8] 56 | 57 | L2, L3 = len(self.layer2), len(self.layer3) 58 | if L2 == 4 and L3 == 6: 59 | blk2, blk3 = 2, 3 60 | elif L2 == 4 and L3 == 23: 61 | blk2, blk3 = 2, 3 62 | elif L2 == 8 and L3 == 36: 63 | blk2, blk3 = 4, 4 64 | elif L2 == 24 and L3 == 36: 65 | blk2, blk3 = 4, 4 66 | elif L2 == 30 and L3 == 48: 67 | blk2, blk3 = 5, 6 68 | else: 69 | raise NotImplementedError 70 | 71 | N2, N3 = math.ceil(L2 / blk2 - 1e-5), math.ceil(L3 / blk3 - 1e-5) 72 | N = 2 + N2 + N3 73 | if para_name.startswith('layer'): # 1, 2, 3, 4, 5 74 | stage_id, block_id = int(para_name.split('.')[0][5:]), int(para_name.split('.')[1]) 75 | if stage_id == 1: 76 | layer_id = 1 77 | elif stage_id == 2: 78 | layer_id = 2 + block_id // blk2 # 2, 3 79 | elif stage_id == 3: 80 | layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5 r101: 4, 5, ..., 11 81 | else: # == 4 82 | layer_id = N # r50: 6 r101: 12 83 | elif para_name.startswith('fc.'): 84 | layer_id = N+1 # r50: 7 r101: 13 85 | else: 86 | layer_id = 0 87 | 88 | return layer_id, N+1 - layer_id # r50: 0-7, 7-0 r101: 0-13, 13-0 89 | 90 | 91 | ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp 92 | ResNet.forward_features = forward_features 93 | 
ResNet.forward = forward
94 | 
95 | 
96 | if __name__ == '__main__':
97 |     import torch
98 |     from timm.models import create_model
99 |     r = create_model('resnet50')
100 |     with torch.no_grad():
101 |         print(r(torch.rand(2, 3, 224, 224)).shape)
102 |         print(r(torch.rand(2, 3, 224, 224), pyramid=1))
103 |         print(r(torch.rand(2, 3, 224, 224), pyramid=2))
104 |         print(r(torch.rand(2, 3, 224, 224), pyramid=3))
105 |         print(r(torch.rand(2, 3, 224, 224), pyramid=4))
--------------------------------------------------------------------------------
/src/utils/patch_sampling.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import numpy as np
4 | class BoxSampler():
5 |     """
6 |     Sample patches from an image
7 |     """
8 |     def __init__(self, cfg):
9 |         self.patch_size = cfg.get('patch_size', 16)
10 |         self.stride = self.patch_size  # default stride is patch size
11 |         self.overlap = cfg.get('overlap', False)  # default is no overlap
12 | 
13 | 
14 | 
15 |     def sample_single_box(self, image):
16 |         """
17 |         sample a random bounding box from an image
18 |         Args:
19 |             image (torch.tensor): 2D image of shape [batch, channel, height, width]
20 |         Returns:
21 |             bounding box (torch.tensor): bounding box of shape [batch, (x_min, y_min, x_max, y_max)]
22 |         """
23 |         # get image size
24 |         batch_size, channel, height, width = image.shape
25 | 
26 |         # checks
27 |         if isinstance(self.patch_size, int):
28 |             self.patch_size = [self.patch_size, self.patch_size]
29 |         if self.patch_size[1] > height or self.patch_size[0] > width:
30 |             raise ValueError('Patch size is larger than image size')
31 |         # sample a random box that lies fully inside the image
32 |         x_min = torch.randint(0, width - self.patch_size[0] + 1, (batch_size, 1))
33 |         y_min = torch.randint(0, height - self.patch_size[1] + 1, (batch_size, 1))
34 |         x_max = x_min + self.patch_size[0]
35 |         y_max = y_min + self.patch_size[1]
36 | 
37 |         # create bounding box
38 |         box = torch.stack((x_min, y_min, x_max, y_max), dim=1)
39 |         return box
40 | 
41 |     def sample_grid(self, image):
42 |         """
43 |         sample a grid of bounding boxes from an image
44 |         Args:
45 |             image (torch.tensor): 2D image of shape [batch, channel, height, width]
46 |         Returns:
47 |             bounding box (torch.tensor): bounding box of shape [batch, num_boxes, (x_min, y_min, x_max, y_max)]
48 |         """
49 |         # get image size
50 |         batch_size, channel, height, width = image.shape
51 | 
52 |         # checks
53 |         if isinstance(self.patch_size, int):
54 |             self.patch_size = [self.patch_size, self.patch_size]
55 |         if self.patch_size[1] > height or self.patch_size[0] > width:
56 |             raise ValueError('Patch size is larger than image size')
57 | 
58 |         # regular grid of box corners
59 |         x_min = torch.arange(0, width, self.stride).repeat(batch_size, 1)
60 |         y_min = torch.arange(0, height, self.stride).repeat(batch_size, 1)
61 | 
62 |         if self.overlap:  # adjust the grid to equally distribute the patches
63 |             n_y = len(y_min[0, :])
64 |             n_x = len(x_min[0, :])
65 |             for i in range(n_y):
66 |                 y_min[:, i] = (i * ((height - self.patch_size[1]) / (np.int32(n_y - 1))))
67 |             for i in range(n_x):
68 |                 x_min[:, i] = (i * ((width - self.patch_size[0]) / (np.int32(n_x - 1))))
69 | 
70 |         x_max = x_min + self.patch_size[0]
71 |         y_max = y_min + self.patch_size[1]
72 |         box = []  # list of boxes
73 |         for i in range(y_min.shape[1]):
74 |             for j in range(x_min.shape[1]):
75 |                 box.append(torch.stack((x_min[:, j], y_min[:, i], x_max[:, j], y_max[:, i]), dim=1))
76 | 
77 |         # create bounding box
78 |         box = torch.stack((box), dim=1)
79 |         return box
80 | 
81 | 
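    # Illustrative usage (added sketch, not part of the original file); any mapping
    # with a .get method works as cfg:
    #   sampler = BoxSampler({'patch_size': 16, 'overlap': False})
    #   boxes = sampler.sample_grid(torch.rand(2, 1, 64, 64))
    #   boxes.shape  # -> [2, 16, 4], i.e. [batch, num_boxes, (x_min, y_min, x_max, y_max)]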
82 |     def sample_grid_cut(self, image):  # grid without overlap handling
83 |         """
84 |         sample a grid of bounding boxes from an image
85 |         Args:
86 |             image (torch.tensor): 2D image of shape [batch, channel, height, width]
87 |         Returns:
88 |             bounding box (torch.tensor): bounding box of shape [batch, num_boxes, (x_min, y_min, x_max, y_max)]
89 |         """
90 |         # get image size
91 |         batch_size, channel, height, width = image.shape
92 | 
93 |         # checks
94 |         if isinstance(self.patch_size, int):
95 |             self.patch_size = [self.patch_size, self.patch_size]
96 |         if self.patch_size[1] > height or self.patch_size[0] > width:
97 |             raise ValueError('Patch size is larger than image size')
98 | 
99 |         # regular grid of box corners
100 |         x_min = torch.arange(0, width, self.stride).repeat(batch_size, 1)
101 |         y_min = torch.arange(0, height, self.stride).repeat(batch_size, 1)
102 | 
103 |         x_max = x_min + self.patch_size[0]
104 |         y_max = y_min + self.patch_size[1]
105 |         box = []  # list of boxes
106 |         for i in range(y_min.shape[1]):
107 |             for j in range(x_min.shape[1]):
108 |                 box.append(torch.stack((x_min[:, j], y_min[:, i], x_max[:, j], y_max[:, i]), dim=1))
109 | 
110 |         # create bounding box
111 |         box = torch.stack((box), dim=1)
112 |         return box
113 | 
114 | 
115 | 
116 | 
117 | 
118 | 
119 | 
120 | 
121 | 
122 | 
123 | 
--------------------------------------------------------------------------------
/src/utils/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 | 
10 | import numpy as np
11 | 
12 | import torch
13 | 
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 |     """
22 |     grid_size: int of the grid height and width
23 |     return:
24 |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 |     """
26 |     grid_h = np.arange(grid_size, dtype=np.float32)
27 |     grid_w = np.arange(grid_size, dtype=np.float32)
28 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
29 |     grid = np.stack(grid, axis=0)
30 | 
31 |     grid = grid.reshape([2, 1, grid_size, grid_size])
32 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 |     if cls_token:
34 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 |     return pos_embed
36 | 
37 | 
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 |     assert embed_dim % 2 == 0
40 | 
41 |     # use half of dimensions to encode grid_h
42 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
43 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
44 | 
45 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
46 |     return emb
47 | 
48 | 
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 |     """
51 |     embed_dim: output dimension for each position
52 |     pos: a list of positions to be encoded: size (M,)
53 |     out: (M, D)
54 |     """
55 |     assert embed_dim % 2 == 0
56 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in modern NumPy; float64 is the same dtype
57 |     omega /= embed_dim / 2.
58 |     omega = 1. / 10000**omega  # (D/2,)
59 | 
60 |     pos = pos.reshape(-1)  # (M,)
61 |     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
62 | 
63 |     emb_sin = np.sin(out)  # (M, D/2)
64 |     emb_cos = np.cos(out)  # (M, D/2)
65 | 
66 |     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
67 |     return emb
68 | 
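# Quick shape check (added sketch, not part of the original file):
#   emb = get_2d_sincos_pos_embed(embed_dim=16, grid_size=4)   # -> (16, 16)
#   emb_cls = get_2d_sincos_pos_embed(16, 4, cls_token=True)   # -> (17, 16)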
69 | 
70 | # --------------------------------------------------------
71 | # Interpolate position embeddings for high-resolution
72 | # References:
73 | # DeiT: https://github.com/facebookresearch/deit
74 | # --------------------------------------------------------
75 | def interpolate_pos_embed(model, checkpoint_model):
76 |     if 'pos_embed' in checkpoint_model:
77 |         pos_embed_checkpoint = checkpoint_model['pos_embed']
78 |         embedding_size = pos_embed_checkpoint.shape[-1]
79 |         num_patches = model.patch_embed.num_patches
80 |         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81 |         # height (== width) for the checkpoint position embedding
82 |         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83 |         # height (== width) for the new position embedding
84 |         new_size = int(num_patches ** 0.5)
85 |         # class_token and dist_token are kept unchanged
86 |         if orig_size != new_size:
87 |             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88 |             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89 |             # only the position tokens are interpolated
90 |             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91 |             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92 |             pos_tokens = torch.nn.functional.interpolate(
93 |                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94 |             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95 |             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96 |             checkpoint_model['pos_embed'] = new_pos_embed
--------------------------------------------------------------------------------
/src/utils/taming.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 | 
5 | URL_MAP = {
6 |     "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7 | }
8 | 
9 | CKPT_MAP = {
10 |     "vgg_lpips": "vgg.pth"
11 | }
12 | 
13 | MD5_MAP = {
14 |     "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15 | }
16 | 
17 | 
18 | def download(url, local_path, chunk_size=1024):
19 |     os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20 |     with requests.get(url, stream=True) as r:
21 |         total_size = int(r.headers.get("content-length", 0))
22 |         with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23 |             with open(local_path, "wb") as f:
24 |                 for data in r.iter_content(chunk_size=chunk_size):
25 |                     if data:
26 |                         f.write(data)
27 |                         pbar.update(len(data))  # advance by the bytes actually written
28 | 
29 | 
30 | def md5_hash(path):
31 |     with open(path, "rb") as f:
32 |         content = f.read()
33 |     return hashlib.md5(content).hexdigest()
34 | 
35 | 
36 | def get_ckpt_path(name, root, check=False):
37 |     assert name in URL_MAP
38 |     path = os.path.join(root, CKPT_MAP[name])
39 |     if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40 |         print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41 |         download(URL_MAP[name], path)
42 |         md5 = md5_hash(path)
43 |         assert md5 == MD5_MAP[name], md5
44 |     return path
45 | 
46 | 
47 | class KeyNotFoundError(Exception):
48 |     def
__init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | Parameters 69 | ---------- 70 | list_or_dict : list or dict 71 | Possibly nested list or dictionary. 72 | key : str 73 | key/to/value, path like string describing all keys necessary to 74 | consider to get to the desired value. List indices can also be 75 | passed here. 76 | splitval : str 77 | String that defines the delimiter between keys of the 78 | different depth levels in `key`. 79 | default : obj 80 | Value returned if :attr:`key` is not found. 81 | expand : bool 82 | Whether to expand callable nodes on the path or not. 83 | Returns 84 | ------- 85 | The desired value or if :attr:`default` is not ``None`` and the 86 | :attr:`key` is not found returns ``default``. 87 | Raises 88 | ------ 89 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 90 | ``None``. 91 | """ 92 | 93 | keys = key.split(splitval) 94 | 95 | success = True 96 | try: 97 | visited = [] 98 | parent = None 99 | last_key = None 100 | for key in keys: 101 | if callable(list_or_dict): 102 | if not expand: 103 | raise KeyNotFoundError( 104 | ValueError( 105 | "Trying to get past callable node with expand=False." 
106 |                         ),
107 |                         keys=keys,
108 |                         visited=visited,
109 |                     )
110 |                 list_or_dict = list_or_dict()
111 |                 parent[last_key] = list_or_dict
112 | 
113 |             last_key = key
114 |             parent = list_or_dict
115 | 
116 |             try:
117 |                 if isinstance(list_or_dict, dict):
118 |                     list_or_dict = list_or_dict[key]
119 |                 else:
120 |                     list_or_dict = list_or_dict[int(key)]
121 |             except (KeyError, IndexError, ValueError) as e:
122 |                 raise KeyNotFoundError(e, keys=keys, visited=visited)
123 | 
124 |             visited += [key]
125 |         # final expansion of retrieved value
126 |         if expand and callable(list_or_dict):
127 |             list_or_dict = list_or_dict()
128 |             parent[last_key] = list_or_dict
129 |     except KeyNotFoundError as e:
130 |         if default is None:
131 |             raise e
132 |         else:
133 |             list_or_dict = default
134 |             success = False
135 | 
136 |     if not pass_success:
137 |         return list_or_dict
138 |     else:
139 |         return list_or_dict, success
140 | 
141 | 
142 | if __name__ == "__main__":
143 |     config = {"keya": "a",
144 |               "keyb": "b",
145 |               "keyc":
146 |                   {"cc1": 1,
147 |                    "cc2": 2,
148 |                    }
149 |               }
150 |     from omegaconf import OmegaConf
151 |     config = OmegaConf.create(config)
152 |     print(config)
153 |     retrieve(config, "keya")
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import warnings
4 | from typing import List, Sequence
5 | import numpy as np
6 | import pytorch_lightning as pl
7 | import rich.syntax
8 | import rich.tree
9 | from omegaconf import DictConfig, OmegaConf
10 | from pytorch_lightning.utilities import rank_zero_only
11 | import yaml
12 | 
13 | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
14 |     """Initializes multi-GPU-friendly python logger."""
15 | 
16 |     logger = logging.getLogger(name)
17 |     logger.setLevel(level)
18 | 
19 |     # this ensures all logging levels get marked with the rank zero decorator
20 |     # otherwise logs would get multiplied for each GPU process in multi-GPU setup
21 |     for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
22 |         setattr(logger, level, rank_zero_only(getattr(logger, level)))
23 | 
24 |     return logger
25 | 
26 | 
27 | def extras(config: DictConfig) -> None:
28 |     """A couple of optional utilities, controlled by main config file:
29 |     - disabling warnings
30 |     - easier access to debug mode
31 |     - forcing debug friendly configuration
32 |     Modifies DictConfig in place.
33 |     Args:
34 |         config (DictConfig): Configuration composed by Hydra.
35 |     """
36 | 
37 |     log = get_logger()
38 | 
39 |     # enable adding new keys to config
40 |     OmegaConf.set_struct(config, False)
41 | 
42 |     # disable python warnings if <config.ignore_warnings=True>
43 |     if config.get("ignore_warnings"):
44 |         log.info("Disabling python warnings! <config.ignore_warnings=True>")
45 |         warnings.filterwarnings("ignore")
46 | 
47 |     # set <config.trainer.fast_dev_run=True> if <config.debug=True>
48 |     if config.get("debug"):
49 |         log.info("Running in debug mode! <config.debug=True>")
50 |         config.trainer.fast_dev_run = True
51 | 
52 |     # force debugger friendly configuration if <config.trainer.fast_dev_run=True>
53 |     if config.trainer.get("fast_dev_run"):
54 |         log.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>
") 55 | # Debuggers don't like GPUs or multiprocessing 56 | if config.trainer.get("gpus"): 57 | config.trainer.gpus = 0 58 | if config.datamodule.get("pin_memory"): 59 | config.datamodule.pin_memory = False 60 | if config.datamodule.get("num_workers"): 61 | config.datamodule.num_workers = 0 62 | 63 | # disable adding new keys to config 64 | OmegaConf.set_struct(config, True) 65 | 66 | 67 | @rank_zero_only 68 | def print_config( 69 | config: DictConfig, 70 | fields: Sequence[str] = ( 71 | "trainer", 72 | "model", 73 | "datamodule", 74 | "callbacks", 75 | "logger", 76 | "seed", 77 | ), 78 | resolve: bool = True, 79 | ) -> None: 80 | """Prints content of DictConfig using Rich library and its tree structure. 81 | Args: 82 | config (DictConfig): Configuration composed by Hydra. 83 | fields (Sequence[str], optional): Determines which main fields from config will 84 | be printed and in what order. 85 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 86 | """ 87 | 88 | style = "dim" 89 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 90 | 91 | for field in fields: 92 | branch = tree.add(field, style=style, guide_style=style) 93 | 94 | config_section = config.get(field) 95 | branch_content = str(config_section) 96 | if isinstance(config_section, DictConfig): 97 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 98 | 99 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 100 | 101 | rich.print(tree) 102 | 103 | with open("config_tree.txt", "w") as fp: 104 | rich.print(tree, file=fp) 105 | 106 | 107 | def empty(*args, **kwargs): 108 | pass 109 | 110 | 111 | @rank_zero_only 112 | def log_hyperparameters( 113 | config: DictConfig, 114 | model: pl.LightningModule, 115 | datamodule: pl.LightningDataModule, 116 | trainer: pl.Trainer, 117 | callbacks: List[pl.Callback], 118 | logger: List[pl.loggers.LightningLoggerBase], 119 | ) -> None: 120 | """This method controls which parameters from Hydra config are saved by Lightning loggers. 
121 |     Additionally saves:
122 |     - number of trainable model parameters
123 |     """
124 | 
125 |     hparams = {}
126 | 
127 |     # choose which parts of hydra config will be saved to loggers
128 |     hparams["trainer"] = config["trainer"]
129 |     hparams["model"] = config["model"]
130 |     hparams["datamodule"] = config["datamodule"]
131 | 
132 |     if "seed" in config:
133 |         hparams["seed"] = config["seed"]
134 |     if "callbacks" in config:
135 |         hparams["callbacks"] = config["callbacks"]
136 | 
137 |     # save number of model parameters
138 |     hparams["model/params_total"] = sum(p.numel() for p in model.parameters())
139 |     hparams["model/params_trainable"] = sum(
140 |         p.numel() for p in model.parameters() if p.requires_grad
141 |     )
142 |     hparams["model/params_not_trainable"] = sum(
143 |         p.numel() for p in model.parameters() if not p.requires_grad
144 |     )
145 |     hparams['run_id'] = trainer.logger.experiment[0].id
146 |     # send hparams to all loggers
147 |     trainer.logger.log_hyperparams(hparams)
148 | 
149 |     # disable logging any more hyperparameters for all loggers
150 |     # this is just a trick to prevent trainer from logging hparams of model,
151 |     # since we already did that above
152 |     trainer.logger.log_hyperparams = empty
153 | 
154 | 
155 | def finish(
156 |     config: DictConfig,
157 |     model: pl.LightningModule,
158 |     datamodule: pl.LightningDataModule,
159 |     trainer: pl.Trainer,
160 |     callbacks: List[pl.Callback],
161 |     logger: List[pl.loggers.LightningLoggerBase],
162 | ) -> None:
163 |     """Makes sure everything closed properly."""
164 | 
165 |     # without this sweeps with wandb logger might crash!
166 |     for lg in logger:
167 |         if isinstance(lg, pl.loggers.wandb.WandbLogger):
168 |             import wandb
169 | 
170 |             wandb.finish()
171 | 
172 | def summarize(eval_dict, prefix):  # removes list entries from dictionary for faster logging
173 |     # for set in list(eval_dict) :
174 |     eval_dict_new = {}
175 |     for key in list(eval_dict):
176 |         if type(eval_dict[key]) is not list:
177 |             eval_dict_new[prefix + '/' + key] = eval_dict[key]
178 |     return eval_dict_new
179 | 
180 | def get_yaml(path):  # read yaml
181 |     with open(path, "r") as stream:
182 |         try:
183 |             file = yaml.safe_load(stream)
184 |         except yaml.YAMLError as exc:
185 |             print(exc)
186 |     return file
187 | 
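# Assumed run-directory layout for get_checkpoint below (added sketch, inferred from the
# parsing logic; the exact filenames are hypothetical, not documented in the repo):
#   <path>/csv/hparams.yaml                        # must contain the logged 'run_id'
#   <path>/checkpoints/last-fold-1.ckpt            # 'last' checkpoints, one per fold
#   <path>/checkpoints/...loss-0.12...fold-1.ckpt  # 'best' checkpoints, sortable by the digits after 'loss-'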
188 | def get_checkpoint(cfg, path):
189 |     checkpoint_path = path
190 |     checkpoint_to_load = cfg.get("checkpoint", 'last')  # default to last.ckpt
191 |     all_checkpoints = os.listdir(checkpoint_path + '/checkpoints')
192 |     hparams = get_yaml(path + '/csv/hparams.yaml')
193 |     wandbID = hparams['run_id']
194 |     checkpoints = {}
195 |     for fold in range(cfg.get('num_folds', 1)):
196 |         checkpoints[f'fold-{fold+1}'] = []  # dict to store the checkpoints with their path for different folds
197 | 
198 |     if checkpoint_to_load == 'last':
199 |         matching_checkpoints = [c for c in all_checkpoints if "last" in c]
200 |         matching_checkpoints.sort(key=lambda x: x.split('fold-')[1][0:1])
201 |         for fold, cp_name in enumerate(matching_checkpoints):
202 |             checkpoints[f'fold-{fold+1}'] = checkpoint_path + '/checkpoints/' + cp_name
203 |     elif 'best' in checkpoint_to_load:
204 |         matching_checkpoints = [c for c in all_checkpoints if "last" not in c]
205 |         matching_checkpoints.sort(key=lambda x: x.split('loss-')[1][0:4])  # sort by loss value -> increasing
206 |         for fold in checkpoints:
207 |             for cp in matching_checkpoints:
208 |                 if fold in cp:
209 |                     checkpoints[fold].append(checkpoint_path + '/checkpoints/' + cp)
210 |             if 'best_k' not in checkpoint_to_load:  # best_k loads the k best checkpoints
211 |                 checkpoints[fold] = checkpoints[fold][0]  # get only the best (first) checkpoint of that fold
212 |     return wandbID, checkpoints
213 | 
214 | 
215 | def calc_interres(dims, fac, num_pooling, k, p, s):
216 |     dims = [int(x / fac) for x in dims]
217 |     if len(dims) == 2:
218 |         w, h = dims
219 |         d = None
220 |     else:
221 |         w, h, d = dims
222 |     for i in range(num_pooling):
223 |         w = int((w - k + 2 * p) / s + 1)
224 |         h = int((h - k + 2 * p) / s + 1)
225 |         if d is not None:
226 |             d = int((d - k + 2 * p) / s + 1)
227 |     return [w, h] if d is None else [w, h, d]
--------------------------------------------------------------------------------
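Worked example for calc_interres (added for illustration; parameter values chosen arbitrarily): each pooling step applies the usual convolution output-size formula int((w - k + 2*p) / s + 1).

# calc_interres([96, 96], fac=1, num_pooling=3, k=3, p=1, s=2)
#   step 1: int((96 - 3 + 2) / 2 + 1) = 48
#   step 2: int((48 - 3 + 2) / 2 + 1) = 24
#   step 3: int((24 - 3 + 2) / 2 + 1) = 12
#   -> [12, 12]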