├── .gitignore ├── README.md ├── echosyn ├── common │ ├── __init__.py │ ├── datasets.py │ └── privacy_utils.py ├── lidm │ ├── README.md │ ├── configs │ │ ├── dynamic.yaml │ │ ├── ped_a4c.yaml │ │ └── ped_psax.yaml │ ├── sample.py │ └── train.py ├── lvdm │ ├── README.md │ ├── configs │ │ └── default.yaml │ ├── sample.py │ └── train.py ├── privacy │ ├── README.md │ ├── apply.py │ ├── configs │ │ ├── config_dynamic.json │ │ ├── config_ped_a4c.json │ │ └── config_ped_psax.json │ ├── shared.py │ └── train.py └── vae │ ├── README.md │ └── usencoder_kl_16x16x4.yaml ├── external └── README.md ├── requirements.txt ├── ressources ├── models.jpg ├── mosaic.gif ├── mosaic_slim.gif ├── real.gif └── synt.gif ├── scripts ├── complete_pediatrics_filelist.py ├── convert_vae_pt_to_diffusers.py ├── copy_privacy_compliant_images.sh ├── create_reference_dataset.py ├── encode_video_dataset.py ├── extract_frames_from_videos.sh ├── update_split_filelist.py └── vae_reconstruct_image_folder.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | external/* 163 | !external/README.md 164 | models/ 165 | data/ 166 | datasets/ 167 | experiments/ 168 | wandb/ 169 | samples/ 170 | .vscode/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EchoNet-Synthetic 2 |

3 | 4 | 5 | 6 | 7 | Star this repo! 8 |

9 | 10 | ## Deep-Dive Podcast (AI generated) 11 | 12 | 13 | https://github.com/user-attachments/assets/06d0811b-f2fc-442f-bde4-0885d8dfd319 14 | 15 | *Note: Turn on audio!* 16 | 17 | ## Introduction 18 | 19 | This repository contains the code and model weights for the paper *[EchoNet-Synthetic: Privacy-preserving Video Generation for Safe Medical Data Sharing](https://arxiv.org/abs/2406.00808)*. Hadrien Reynaud, Qingjie Meng, Mischa Dombrowski, Arijit Ghosh, Alberto Gomez, Paul Leeson and Bernhard Kainz. MICCAI 2024. 20 | 21 | EchoNet-Synthetic presents a protocol to generate surrogate privacy-compliant datasets that are as valuable as their original counterparts for training downstream models (e.g. regression models). 22 | 23 | In this repository, we present the code we use for the experiments in the paper. We provide the code to train the models, generate the synthetic data, and evaluate the quality of the synthetic data. 24 | We also provide all the pre-trained models and release the synthetic datasets we generated. 25 | 26 | 📜 Read the Paper [on arXiv](https://arxiv.org/abs/2406.00808)
27 | 📕 [MICCAI 2024 Proceedings](https://link.springer.com/chapter/10.1007/978-3-031-72104-5_28)
28 | 🤗 Try our interactive demo [on HuggingFace](https://huggingface.co/spaces/HReynaud/EchoNet-Synthetic); it contains all the generative pipeline inference code and weights! 29 | 30 | ![Slim GIF Demo](ressources/mosaic_slim.gif) 31 | 32 | *Example of synthetic videos generated with EchoNet-Synthetic. The first video is real, the others are generated.* 33 | 34 | ## Table of contents 35 | 1. [Environment setup](#environment-setup) 36 | 2. [Data preparation](#data-preparation) 37 | 3. [The models](#the-models) 38 | 4. [Generating EchoNet-Synthetic](#generating-echonet-synthetic) 39 | 5. [Evaluation](#evaluation) 40 | 6. [Results](#results) 41 | 7. [Citation](#citation) 42 | 43 | ## Environment setup 44 | 46 | 47 | First, we need to set up the environment. We use the following commands to create a new conda environment with the required dependencies. 48 | 49 | ```bash 50 | conda create -y -n echosyn python=3.11 51 | conda activate echosyn 52 | pip install -e . 53 | ``` 54 | *Note: the exact version of each package can be found in requirements.txt if necessary* 55 | 56 | This repository lets you train three models: 57 | - the Latent Image Diffusion Model (LIDM) 58 | - the Re-Identification models for privacy checks 59 | - the Latent Video Diffusion Model (LVDM) 60 | 61 | We rely on external libraries to: 62 | - train the Variational Auto-Encoder (VAE) (Stable-Diffusion and Taming-Transformers) 63 | - evaluate the generated images and videos (StyleGAN-V) 64 | - evaluate the synthetic data on the Ejection Fraction downstream task (EchoNet-Dynamic) 65 | 66 | How to install the external libraries is explained in the [External libraries](external/README.md) section. 67 | 68 | 69 | 70 | 71 | ## Data preparation 72 | 74 | 75 | ### ➡ Original datasets 76 | Download the EchoNet-Dynamic dataset from [here](https://echonet.github.io/dynamic/) and the EchoNet-Pediatric dataset from [here](https://echonet.github.io/pediatric/). The datasets are available for free upon request. Once downloaded, extract the content of the archives in the `datasets` folder. For simplicity and consistency, we structure them like so: 77 | ``` 78 | datasets 79 | ├── EchoNet-Dynamic 80 | │ ├── Videos 81 | │ ├── FileList.csv 82 | │ └── VolumeTracings.csv 83 | └── EchoNet-Pediatric 84 | ├── A4C 85 | │ ├── Videos 86 | │ ├── FileList.csv 87 | │ └── VolumeTracings.csv 88 | └── PSAX 89 | ├── Videos 90 | ├── FileList.csv 91 | └── VolumeTracings.csv 92 | ``` 93 | 94 | To harmonize the datasets, we add some information to the `FileList.csv` files of the EchoNet-Pediatric dataset, namely FrameHeight, FrameWidth, FPS and NumberOfFrames. We also arbitrarily convert the splits from the 10-fold indices to a simple TRAIN/VAL/TEST split. These updates are applied with the following commands: 95 | 96 | ```bash 97 | python scripts/complete_pediatrics_filelist.py --dataset datasets/EchoNet-Pediatric/A4C 98 | python scripts/complete_pediatrics_filelist.py --dataset datasets/EchoNet-Pediatric/PSAX 99 | ``` 100 | 101 | This is crucial for the other scripts to work properly. 102 | 103 | ### ➡ Image datasets for VAE training 104 | See the [VAE training](echosyn/vae/README.md) section for instructions on how to train the VAE. 105 | 106 | The VAE only needs images, so to keep things simple, we turn our video datasets into image datasets, as sketched below.
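Conceptually, this step just unpacks every video into individual frames inside a single folder. The sketch below illustrates the idea in Python with OpenCV; it is only an assumption about what `scripts/extract_frames_from_videos.sh` does (the `{video}_{index}.jpg` naming is hypothetical), so use the provided script for the actual pipeline.

```python
# Hypothetical sketch of a frame-extraction step (OpenCV-based); the real logic
# lives in scripts/extract_frames_from_videos.sh and may differ, e.g. in naming.
import os
import sys
from glob import glob

import cv2

def extract_frames(video_dir: str, out_dir: str) -> None:
    os.makedirs(out_dir, exist_ok=True)
    for video_path in sorted(glob(os.path.join(video_dir, "*.avi"))):
        name = os.path.splitext(os.path.basename(video_path))[0]
        capture = cv2.VideoCapture(video_path)
        index = 0
        while True:
            ret, frame = capture.read()  # BGR uint8 frame, or ret=False at the end
            if not ret:
                break
            # one JPEG per frame, all videos merged into the same output folder
            cv2.imwrite(os.path.join(out_dir, f"{name}_{index:04d}.jpg"), frame)
            index += 1
        capture.release()

if __name__ == "__main__":
    extract_frames(sys.argv[1], sys.argv[2])
```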
107 | In practice, we do that with the provided script: 108 | ```bash 109 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Dynamic/Videos data/vae_train_images/images/ 110 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Pediatric/A4C/Videos data/vae_train_images/images/ 111 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Pediatric/PSAX/Videos data/vae_train_images/images/ 112 | ``` 113 | 114 | Note that this will merge all the images into the same folder. 115 | 116 | Then, we need to create a train.txt file and a val.txt file containing the paths to these images. 117 | ```bash 118 | find $(cd data/vae_train_images/images && pwd) -type f | shuf > tmp.txt 119 | head -n -1000 tmp.txt > data/vae_train_images/train.txt 120 | tail -n 1000 tmp.txt > data/vae_train_images/val.txt 121 | rm tmp.txt 122 | ``` 123 | 124 | ### ➡ Latent Video datasets for LIDM / Privacy / LVDM training 125 | 126 | The LIDM, the Re-Identification model and the LVDM are trained on pre-encoded latent representations of the videos. To encode the videos, we use the image VAE. You can either retrain the VAE or download it from [here](https://huggingface.co/HReynaud/EchoNet-Synthetic/tree/main/vae). Once you have the VAE, you can encode the videos with the following commands: 127 | 128 | ```bash 129 | # For the EchoNet-Dynamic dataset 130 | python scripts/encode_video_dataset.py \ 131 | --model models/vae \ 132 | --input datasets/EchoNet-Dynamic \ 133 | --output data/latents/dynamic \ 134 | --gray_scale 135 | ``` 136 | ```bash 137 | # For the EchoNet-Pediatric datasets 138 | python scripts/encode_video_dataset.py \ 139 | --model models/vae \ 140 | --input datasets/EchoNet-Pediatric/A4C \ 141 | --output data/latents/ped_a4c \ 142 | --gray_scale 143 | 144 | python scripts/encode_video_dataset.py \ 145 | --model models/vae \ 146 | --input datasets/EchoNet-Pediatric/PSAX \ 147 | --output data/latents/ped_psax \ 148 | --gray_scale 149 | ``` 150 | 151 | ### ➡ Validation datasets 152 | 153 | To quantitatively evaluate the quality of the generated images and videos, we use the StyleGAN-V repo. 154 | We cover the evaluation process in the [Evaluation](#evaluation) section. 155 | To enable this evaluation, we need to prepare the validation datasets. We do that with the following commands: 156 | 157 | ```bash 158 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Dynamic --output data/reference/dynamic --frames 128 159 | ``` 160 | 161 | ```bash 162 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Pediatric/A4C --output data/reference/ped_a4c --frames 16 163 | ``` 164 | 165 | ```bash 166 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Pediatric/PSAX --output data/reference/ped_psax --frames 16 167 | ``` 168 | 169 | Note that the Pediatric datasets do not support 128 frames, preventing the computation of FVD_128, because there are not enough videos lasting 4 seconds or more. We therefore only extract 16 frames per video for these datasets. 170 | 171 | 172 | 173 | ## The Models 174 | 176 | 177 | ![Models](ressources/models.jpg) 178 | 179 | *Our pipeline, using all our models: LIDM, Re-Identification (Privacy), LVDM and VAE* 180 | 181 | 182 | ### The VAE 183 | 184 | You can download the pretrained VAE from [here](https://huggingface.co/HReynaud/EchoNet-Synthetic/tree/main/vae) or train it yourself by following the instructions in the [VAE training](echosyn/vae/README.md) section.
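Once downloaded (or trained), the VAE behaves like any `diffusers` auto-encoder. The snippet below is a minimal sketch, not one of the repo's scripts, showing how a single frame is mapped to and from the latent space in which the LIDM and LVDM operate; the `models/vae` path and the [-1, 1] normalisation follow the conventions used in this README, while the input resolution is only indicative (EchoNet-Dynamic frames are 112x112).

```python
# Minimal sketch, assuming the VAE was downloaded to models/vae; the reference
# implementation for building latent datasets is scripts/encode_video_dataset.py.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("models/vae").eval()

# One dummy frame, normalised to [-1, 1] as done throughout this repo.
# The channel count must match the VAE config.
frame = torch.rand(1, 3, 112, 112) * 2 - 1

with torch.no_grad():
    latent = vae.encode(frame).latent_dist.sample()  # compressed representation
    recon = vae.decode(latent).sample                # back to pixel space

print(latent.shape, recon.shape)
```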
185 | 186 | ### The LIDM 187 | 188 | You can download the pretrained LIDMs from [here](https://huggingface.co/HReynaud/EchoNet-Synthetic/tree/main/lidm_dynamic) or train them yourself by following the instructions in the [LIDM training](echosyn/lidm/README.md) section. 189 | 190 | ### The Re-Identification models 191 | 192 | You can download the pretrained Re-Identification models from [here](https://huggingface.co/HReynaud/EchoNet-Synthetic/tree/main/reidentification_dynamic) or train them yourself by following the instructions in the [Re-Identification training](echosyn/privacy/README.md) section. 193 | 194 | ### The LVDM 195 | 196 | You can download the pretrained LVDM from [here](https://huggingface.co/HReynaud/EchoNet-Synthetic/tree/main/lvdm) or train it yourself by following the instructions in the [LVDM training](echosyn/lvdm/README.md) section. 197 | 198 | ### Structure 199 | 200 | The models should be structured as follows: 201 | ``` 202 | models 203 | ├── lidm_dynamic 204 | ├── lidm_ped_a4c 205 | ├── lidm_ped_psax 206 | ├── lvdm 207 | ├── regression_dynamic 208 | ├── regression_ped_a4c 209 | ├── regression_ped_psax 210 | ├── reidentification_dynamic 211 | ├── reidentification_ped_a4c 212 | ├── reidentification_ped_psax 213 | └── vae 214 | ``` 215 | 216 | 217 | 218 | ## Generating EchoNet-Synthetic 219 | 221 | 222 | Now that we have all the necessary models, we can generate the synthetic datasets. The process is the same for all three datasets and involves the following steps: 223 | - Generate a collection of latent heart images with the LIDMs (usually 2x the number of videos we are targeting) 224 | - Apply the privacy check, which will filter out some of the latent images 225 | - Generate the videos with the LVDM, and decode them with the VAE 226 | 227 | #### Dynamic dataset 228 | For the dynamic dataset, we aim for 10,030 videos, so we start by generating 20,000 latent images: 229 | ```bash 230 | python echosyn/lidm/sample.py \ 231 | --config models/lidm_dynamic/config.yaml \ 232 | --unet models/lidm_dynamic \ 233 | --vae models/vae \ 234 | --output samples/synthetic/lidm_dynamic \ 235 | --num_samples 20000 \ 236 | --batch_size 128 \ 237 | --num_steps 64 \ 238 | --save_latent \ 239 | --seed 0 240 | ``` 241 | 242 | Then, we filter the latent images with the re-identification model: 243 | ```bash 244 | python echosyn/privacy/apply.py \ 245 | --model models/reidentification_dynamic \ 246 | --synthetic samples/synthetic/lidm_dynamic/latents \ 247 | --reference data/latents/dynamic \ 248 | --output samples/synthetic/lidm_dynamic/privatised_latents 249 | ``` 250 | 251 | Next, we generate the synthetic videos with the LVDM: 252 | ```bash 253 | python echosyn/lvdm/sample.py \ 254 | --config models/lvdm/config.yaml \ 255 | --unet models/lvdm \ 256 | --vae models/vae \ 257 | --conditioning samples/synthetic/lidm_dynamic/privatised_latents \ 258 | --output samples/synthetic/lvdm_dynamic \ 259 | --num_samples 10030 \ 260 | --batch_size 8 \ 261 | --num_steps 64 \ 262 | --min_lvef 10 \ 263 | --max_lvef 90 \ 264 | --save_as avi \ 265 | --frames 192 266 | ``` 267 | 268 | Finally, update the content of the FileList.csv so that each split contains the right number of videos and the naming convention is respected.
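Conceptually, this last step only rewrites the `Split` column of the generated `FileList.csv`: the first chunk of rows becomes TRAIN, the next VAL, and the rest TEST. A rough pandas sketch of that idea is shown below; the function name and exact behaviour are assumptions, and `scripts/update_split_filelist.py` is what the pipeline actually uses.

```python
# Hypothetical sketch of the split reassignment; the pipeline actually uses
# scripts/update_split_filelist.py, whose exact behaviour may differ.
import pandas as pd

def reassign_splits(csv_path: str, n_train: int, n_val: int, n_test: int) -> None:
    df = pd.read_csv(csv_path)
    assert len(df) >= n_train + n_val + n_test, "not enough videos in the CSV"
    splits = ["TRAIN"] * n_train + ["VAL"] * n_val + ["TEST"] * n_test
    df = df.iloc[:len(splits)].copy()   # keep exactly the videos we need
    df["Split"] = splits                # same column name the training code filters on
    df.to_csv(csv_path, index=False)

# Counts used for the dynamic dataset in this README:
# reassign_splits("datasets/EchoNet-Synthetic/dynamic/FileList.csv", 7465, 1288, 1277)
```

The actual commands used in the pipeline are the following: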
269 | ```bash 270 | mkdir -p datasets/EchoNet-Synthetic/dynamic 271 | mv samples/synthetic/lvdm_dynamic/avi datasets/EchoNet-Synthetic/dynamic/Videos 272 | cp samples/synthetic/lvdm_dynamic/FileList.csv datasets/EchoNet-Synthetic/dynamic/FileList.csv 273 | python scripts/update_split_filelist.py --csv datasets/EchoNet-Synthetic/dynamic/FileList.csv --train 7465 --val 1288 --test 1277 274 | ``` 275 | 276 | You can use the exact same process for the Pediatric datasets, just make sure to use the right models. 277 | The A4C dataset should have 3284 videos, with 2580 train, 336 validation and 368 test videos. 279 | The PSAX dataset should have 4526 videos, with 3559 train, 448 validation and 519 test videos. 280 | They also have a different number of frames per video, which is 128 instead of 192. 281 | 282 | 283 | 284 | ## Evaluation 285 | 287 | 288 | As the final step, we evaluate the quality of the EchoNet-Synthetic videos by training an Ejection Fraction regression model on the synthetic data and evaluating it on the real data. 289 | To do so, we use the EchoNet-Dynamic repository. 290 | You will need to clone a slightly modified version of the repository from [here](https://github.com/HReynaud/echonet), and put it in the `external` folder. 291 | 292 | Once that is done, you will need to switch to the echonet environment. 293 | To create the echonet environment, you can use the following commands: 294 | ```bash 295 | cd external/echonet 296 | conda create -y -n echonet python=3.10 297 | conda activate echonet 298 | pip install -e . 299 | ``` 300 | 301 | Then, you should double check that everything is working by reproducing the results of the EchoNet-Dynamic repository on the real EchoNet-Dynamic dataset. This should take ~12 hours on a single GPU. 302 | ```bash 303 | echonet video --data_dir ../../datasets/EchoNet-Dynamic \ 304 | --output ../../experiments/regression_echonet_dynamic \ 305 | --pretrained \ 306 | --run_test \ 307 | --num_epochs 45 308 | ``` 309 | 310 | Once you have confirmed that everything is working, you can train the same model on the synthetic data, while evaluating on the real data. 311 | ```bash 312 | echonet video --data_dir ../../datasets/EchoNet-Synthetic/dynamic \ 313 | --data_dir_real ../../datasets/EchoNet-Dynamic \ 314 | --output ../../experiments/regression_echosyn_dynamic \ 315 | --pretrained \ 316 | --run_test \ 317 | --num_epochs 45 \ 318 | --period 1 319 | ``` 320 | 321 | This will automatically compute the final metrics after training, and store them in the `experiments/regression_echosyn_dynamic/log.csv` file. 322 | 323 | The process is the same for the Pediatric datasets, except that they start from a model pretrained on the real EchoNet-Dynamic data, which would look like this: 324 | ```bash 325 | echonet video --data_dir ../../datasets/EchoNet-Pediatric/A4C \ 326 | --data_dir_real ../../datasets/EchoNet-Pediatric/A4C \ 327 | --output ../../experiments/regression_echosyn_peda4c \ 328 | --weights ../../experiments/regression_real/best.pt \ 329 | --run_test \ 330 | --num_epochs 45 \ 331 | --period 1 332 | ``` 333 | 334 | 335 | Here are two tricks that can improve the quality and alignment of the synthetic data with the real data: 336 | - Re-label the synthetic data with the regression model trained on the real data. 337 | - Use the Ejection Fraction scores from the real dataset to create the synthetic dataset, rather than uniformly distributed ones. 338 | 339 | 340 | 341 | ## Results 342 | 344 | 345 |

Here is a side-by-side comparison between a real video and a synthetic video generated with EchoNet-Synthetic. We use a frame from the real video and its corresponding ejection fraction score to "reproduce" the real video.

*(Side-by-side comparison: Real Video | Reproduction)*
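The reproduction above is obtained by conditioning the generative pipeline on a single frame of the real video and on its real ejection fraction score, instead of on an LIDM sample. The sketch below illustrates the frame-encoding half of that idea; the file names, the output layout and the way the EF is constrained are assumptions, so treat it as an illustration of the mechanism rather than the exact procedure used for this figure.

```python
# Rough sketch of how a conditioning latent for such a reproduction can be built.
# The .pt layout expected by echosyn/lvdm/sample.py is an assumption here, and the
# video file name is only an example; scripts/encode_video_dataset.py shows the
# canonical way latents are produced (including the --gray_scale handling skipped here).
import os
import torch
from diffusers import AutoencoderKL
from echosyn.common import loadvideo  # repo helper, returns uint8 frames (T, H, W, 3)

vae = AutoencoderKL.from_pretrained("models/vae").eval()

frames = loadvideo("datasets/EchoNet-Dynamic/Videos/EXAMPLE_VIDEO.avi")
frame = frames[0].permute(2, 0, 1).float().unsqueeze(0) / 128.0 - 1  # 1 x 3 x H x W in [-1, 1]

with torch.no_grad():
    latent = vae.encode(frame).latent_dist.sample()[0]

os.makedirs("samples/reproduction/latents", exist_ok=True)
torch.save(latent, "samples/reproduction/latents/EXAMPLE_VIDEO.pt")
# This folder can then be passed to echosyn/lvdm/sample.py via --conditioning, with
# --min_lvef/--max_lvef narrowed around the real video's EF score, for example.
```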

Here we show a collection of synthetic videos generated with EchoNet-Synthetic. We can see that the quality of the videos is excellent and that they cover a variety of ejection fraction scores.

359 | 360 | ![Mosaic](ressources/mosaic.gif) 361 | 362 | 363 | 364 | # Acknowledgements 365 | This work was supported by Ultromics Ltd. and the UKRI Centre for Doctoral Training in Artificial Intelligence for Healthcare (EP/S023283/1). 366 | HPC resources were provided by the Erlangen National High Performance Computing Center (NHR@FAU) of the Friedrich-Alexander-Universität Erlangen-Nürnberg (FAU) under the NHR projects b143dc and b180dc. NHR funding is provided by federal and Bavarian state authorities. NHR@FAU hardware is partially funded by the German Research Foundation (DFG) – 440719683. 367 | 368 | 369 | ## Citation 370 | 371 | ``` 372 | @inproceedings{reynaud2024echonet, 373 | title={Echonet-synthetic: Privacy-preserving video generation for safe medical data sharing}, 374 | author={Reynaud, Hadrien and Meng, Qingjie and Dombrowski, Mischa and Ghosh, Arijit and Day, Thomas and Gomez, Alberto and Leeson, Paul and Kainz, Bernhard}, 375 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 376 | pages={285--295}, 377 | year={2024}, 378 | organization={Springer} 379 | } 380 | ``` 381 | -------------------------------------------------------------------------------- /echosyn/common/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import importlib 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from einops import rearrange 11 | from omegaconf import OmegaConf 12 | import imageio 13 | 14 | import diffusers 15 | 16 | def padf(tensor, mult=3): 17 | """ 18 | Pads a tensor along the last dimension to make its size a multiple of 2^mult. 19 | 20 | Args: 21 | tensor (torch.Tensor): The tensor to pad. 22 | mult (int, optional): The power of 2 that the tensor's size should be a multiple of. Defaults to 3. 23 | 24 | Returns: 25 | torch.Tensor: The padded tensor. 26 | int: The amount of padding applied. 27 | """ 28 | pad = 2**mult - (tensor.shape[-1] % 2**mult) 29 | pad = pad//2 30 | tensor = F.pad(tensor, (pad, pad, pad, pad, 0, 0), mode='replicate') 31 | return tensor, pad 32 | 33 | def unpadf(tensor, pad=1): 34 | """ 35 | Removes padding from a tensor along the last two dimensions. 36 | 37 | Args: 38 | tensor (torch.Tensor): The tensor to unpad. 39 | pad (int, optional): The amount of padding to remove. Defaults to 1. 40 | 41 | Returns: 42 | torch.Tensor: The unpadded tensor. 43 | """ 44 | return tensor[..., pad:-pad, pad:-pad] 45 | 46 | def pad_reshape(tensor, mult=3): 47 | """ 48 | Pads a tensor along the last dimension to make its size a multiple of 2^mult and reshapes it. 49 | 50 | Args: 51 | tensor (torch.Tensor): The tensor to pad and reshape. 52 | mult (int, optional): The power of 2 that the tensor's size should be a multiple of. Defaults to 3. 53 | 54 | Returns: 55 | torch.Tensor: The padded and reshaped tensor. 56 | int: The amount of padding applied. 57 | """ 58 | tensor, pad = padf(tensor, mult=mult) 59 | tensor = rearrange(tensor, "b c t h w -> b t c h w") 60 | return tensor, pad 61 | 62 | def unpad_reshape(tensor, pad=1): 63 | """ 64 | Reshapes a tensor and removes padding from it along the last two dimensions. 65 | 66 | Args: 67 | tensor (torch.Tensor): The tensor to reshape and unpad. 68 | pad (int, optional): The amount of padding to remove. Defaults to 1. 69 | 70 | Returns: 71 | torch.Tensor: The reshaped and unpadded tensor. 
72 | """ 73 | tensor = rearrange(tensor, "b t c h w -> b c t h w") 74 | tensor = unpadf(tensor, pad=pad) 75 | return tensor 76 | 77 | def instantiate_from_config(config, scope: list[str], return_klass_kwargs=False, **kwargs): 78 | """ 79 | Instantiate a class from a config dictionary. 80 | 81 | Args: 82 | config (dict): The config dictionary. 83 | scope (list[str]): The scope of the class to instantiate. 84 | return_klass_kwargs (bool, optional): Whether to return the class and its kwargs. Defaults to False. 85 | **kwargs: Additional keyword arguments to pass to the class constructor. 86 | 87 | Returns: 88 | object: The instantiated class. 89 | (optional) type: The class that was instantiated. 90 | (optional) dict: The kwargs that were passed to the class constructor. 91 | """ 92 | okwargs = OmegaConf.to_container(config, resolve=True) 93 | klass_name = okwargs.pop("_class_name") 94 | klass = None 95 | 96 | for module_name in scope: 97 | try: 98 | module = importlib.import_module(module_name) 99 | except ImportError: 100 | continue # Try next module 101 | 102 | klass = getattr(module, klass_name, None) 103 | if klass is not None: 104 | break # Stop when we find a matching class 105 | 106 | assert klass is not None, f"Could not find class {klass_name} in the specified scope" 107 | instance = klass(**okwargs, **kwargs) 108 | 109 | if return_klass_kwargs: 110 | return instance, klass, okwargs 111 | return instance 112 | 113 | def load_model(path): 114 | """ 115 | Loads a model from a checkpoint. 116 | 117 | Args: 118 | path (str): The path to the checkpoint. 119 | 120 | Returns: 121 | object: The loaded model. 122 | """ 123 | # find config.json 124 | json_path = os.path.join(path, "config.json") 125 | assert os.path.exists(json_path), f"Could not find config.json at {json_path}" 126 | with open(json_path, "r") as f: 127 | config = json.load(f) 128 | 129 | # instantiate class 130 | klass_name = config["_class_name"] 131 | klass = getattr(diffusers, klass_name, None) 132 | if klass is None: 133 | klass = globals().get(klass_name, None) 134 | assert klass is not None, f"Could not find class {klass_name} in diffusers or global scope." 135 | assert getattr(klass, "from_pretrained", None) is not None, f"Class {klass_name} does not support 'from_pretrained'." 136 | 137 | # load checkpoint 138 | model = klass.from_pretrained(path) 139 | 140 | return model 141 | 142 | def save_as_mp4(tensor, filename, fps=30): 143 | """ 144 | Saves a 4D tensor (nFrames, height, width, channels) as an MP4 video. 145 | 146 | Parameters: 147 | - tensor: 4D torch.Tensor. Tensor containing the video frames. 148 | - filename: str. The output filename for the video. 149 | - fps: int. Frames per second for the output video. 150 | 151 | Returns: 152 | - None 153 | """ 154 | import imageio 155 | # Make sure the tensor is on the CPU and is a numpy array 156 | np_video = tensor.cpu().numpy() 157 | 158 | # Ensure the tensor dtype is uint8 159 | if np_video.dtype != np.uint8: 160 | raise ValueError("The tensor has to be of type uint8") 161 | 162 | # Write the frames to a video file 163 | with imageio.get_writer(filename, fps=fps, ) as writer: 164 | for i in range(np_video.shape[0]): 165 | writer.append_data(np_video[i]) 166 | 167 | def save_as_avi(tensor, filename, fps=30): 168 | """ 169 | Saves a 4D tensor (nFrames, height, width, channels) as an AVI video with reduced compression. 170 | 171 | Parameters: 172 | - tensor: 4D torch.Tensor. Tensor containing the video frames. 173 | - filename: str. The output filename for the video. 
174 | - fps: int. Frames per second for the output video. 175 | 176 | Returns: 177 | - None 178 | """ 179 | # Make sure the tensor is on the CPU and is a numpy array 180 | np_video = tensor.cpu().numpy() 181 | 182 | # Ensure the tensor dtype is uint8 183 | if np_video.dtype != np.uint8: 184 | raise ValueError("The tensor has to be of type uint8") 185 | 186 | # Define codec for reduced compression 187 | codec = "mjpeg" # MJPEG codec for AVI files 188 | quality = 10 # High quality (lower values mean higher quality, but larger file sizes) 189 | # pixel_format = "yuvj420p" 190 | # Write the frames to a video file 191 | with imageio.get_writer(filename, fps=fps, codec=codec, quality=quality) as writer: 192 | for frame in np_video: 193 | writer.append_data(frame) 194 | 195 | def save_as_gif(tensor, filename, fps=30): 196 | """ 197 | Saves a 4D tensor (nFrames, height, width, channels) as a GIF. 198 | 199 | Parameters: 200 | - tensor: 4D torch.Tensor. Tensor containing the video frames. 201 | - filename: str. The output filename for the GIF. 202 | - fps: int. Frames per second for the output GIF. 203 | 204 | Returns: 205 | - None 206 | """ 207 | import imageio 208 | # Make sure the tensor is on the CPU and is a numpy array 209 | np_video = tensor.cpu().numpy() 210 | 211 | # Ensure the tensor dtype is uint8 212 | if np_video.dtype != np.uint8: 213 | raise ValueError("The tensor has to be of type uint8") 214 | 215 | # Write the frames to a GIF file 216 | imageio.mimsave(filename, np_video, fps=fps) 217 | 218 | def save_as_img(tensor, filename, ext="jpg"): 219 | """ 220 | Saves a 4D tensor (nFrames, height, width, channels) as a series of JPG images. 221 | OR 222 | Saves a 3D tensor (height, width, channels) as a single image. 223 | 224 | Parameters: 225 | - tensor: 4D torch.Tensor. Tensor containing the video frames. 226 | - filename: str. The output filename for the JPG images. 227 | 228 | Returns: 229 | - None 230 | """ 231 | import imageio 232 | # Make sure the tensor is on the CPU and is a numpy array 233 | np_video = tensor.cpu().numpy() 234 | 235 | # Ensure the tensor dtype is uint8 236 | if np_video.dtype != np.uint8: 237 | raise ValueError("The tensor has to be of type uint8") 238 | 239 | # Write the frames to a series of JPG files 240 | if len(np_video.shape) == 3: 241 | imageio.imwrite(filename, np_video, quality=100) 242 | else: 243 | os.makedirs(filename, exist_ok=True) 244 | for i in range(np_video.shape[0]): 245 | imageio.imwrite(os.path.join(filename, f"{i:04d}.{ext}"), np_video[i], quality=100) 246 | 247 | def loadvideo(filename: str, return_fps=False): 248 | """ 249 | Loads a video file into a tensor of frames. 250 | 251 | Args: 252 | filename (str): The path to the video file. 253 | return_fps (bool, optional): Whether to return the frames per second of the video. Defaults to False. 254 | 255 | Raises: 256 | FileNotFoundError: If the video file does not exist. 257 | 258 | Returns: 259 | torch.Tensor: A tensor of the video's frames, with shape (frames, 3, height, width). 260 | (optional) float: The frames per second of the video. Only returned if return_fps is True. 
261 | """ 262 | 263 | if not os.path.exists(filename): 264 | raise FileNotFoundError(filename) 265 | capture = cv2.VideoCapture(filename) # type: ignore 266 | 267 | fps = capture.get(cv2.CAP_PROP_FPS) # type: ignore 268 | 269 | frames = [] 270 | 271 | while True: # load all frames 272 | ret, frame = capture.read() 273 | if not ret: 274 | break # Reached end of video 275 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 276 | frame = torch.from_numpy(frame) 277 | 278 | frames.append(frame) 279 | capture.release() 280 | 281 | frames = torch.stack(frames, dim=0) # (frames, 3, height, width) 282 | 283 | if return_fps: 284 | return frames, fps 285 | return frames 286 | 287 | def parse_formats(s): 288 | # Split the input string by comma and strip spaces 289 | formats = [format.strip().lower() for format in s.split(',')] 290 | # Define the allowed choices 291 | allowed_formats = ["avi", "mp4", "gif", "jpg", "png", "pt"] 292 | # Check if all elements in formats are in allowed_formats 293 | for format in formats: 294 | if format not in allowed_formats: 295 | raise argparse.ArgumentTypeError(f"{format} is not a valid format. Choose from {', '.join(allowed_formats)}.") 296 | return formats 297 | 298 | 299 | 300 | 301 | 302 | 303 | -------------------------------------------------------------------------------- /echosyn/common/datasets.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import pandas as pd 4 | from glob import glob 5 | 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader, ConcatDataset, Sampler 8 | from torchvision import transforms as T 9 | from torchvision.transforms.functional import center_crop 10 | 11 | from PIL import Image 12 | 13 | import decord 14 | decord.bridge.set_bridge('torch') 15 | 16 | 17 | # support video and image + additional info (lvef, view, etc) 18 | class Dynamic(Dataset): 19 | def __init__(self, config, split=["TRAIN", "VAL", "TEST"], datafolder="Videos", ext=".avi") -> None: 20 | super().__init__() 21 | # the config here is only the config for this dataset, ie config.dataset.dynamic 22 | 23 | if type(split) == str: 24 | split = [split] 25 | assert [s in ["TRAIN", "VAL", "TEST"] for s in split], "Splits must be a list of TRAIN, VAL, TEST" 26 | 27 | # assert type(config.target_fps) == int or config.target_fps in ["original", "random", "exponential"], "target_fps must be an integer, 'original', 'random' or 'exponential'" 28 | self.target_fps = config.target_fps 29 | # self.duration_seconds = config.target_duration 30 | self.resolution = config.target_resolution 31 | self.outputs = config.outputs 32 | if type(self.outputs) == str: 33 | self.outputs = [self.outputs] 34 | assert [o in ["video", "image", "lvef"] for o in self.outputs], "Outputs must be a list of video, image, lvef" 35 | 36 | # self.duration_frames = int(self.target_fps * self.duration_seconds) 37 | self.duration_frames = config.target_nframes 38 | self.duration_seconds = self.duration_frames / self.target_fps if type(self.target_fps) == int else None 39 | 40 | # LOAD DATA 41 | assert hasattr(config, "root"), "No root folder specified in config" 42 | assert os.path.exists(os.path.join(config.root, datafolder)), f"Data folder {os.path.join(config.root, datafolder)} does not exist" 43 | assert os.path.exists(os.path.join(config.root, "FileList.csv")), f"FileList.csv does not exist in {config.root}" 44 | self.metadata = pd.read_csv(os.path.join(config.root, "FileList.csv")) 45 | self.metadata = 
self.metadata[self.metadata["Split"].isin(split)] # filter by split 46 | self.len_before_filter = len(self.metadata) 47 | # add duration column 48 | self.metadata["Duration"] = self.metadata["NumberOfFrames"] / self.metadata["FPS"] # won't work for pediatrics 49 | # filter by duration 50 | if self.duration_seconds is not None: 51 | self.metadata = self.metadata[self.metadata["Duration"] > self.duration_seconds] 52 | 53 | # check if videos are reachable 54 | self.metadata["VideoPath"] = self.metadata["FileName"].apply(lambda x: os.path.join(config.root, datafolder, x) if x.endswith(ext) else os.path.join(config.root, datafolder, x.split('.')[0] + ext)) 55 | self.metadata["VideoExists"] = self.metadata["VideoPath"].apply(lambda x: os.path.exists(x)) 56 | self.metadata = self.metadata[self.metadata["VideoExists"]] 57 | self.metadata.reset_index(inplace=True, drop=True) 58 | if len(self.metadata) == 0: 59 | raise ValueError(f"No data found in folder {os.path.join(config.root, datafolder)}") 60 | 61 | self.transform = lambda x: x 62 | if hasattr(config, "transforms"): 63 | transforms = [] 64 | for transform in config.transforms: 65 | tklass = getattr(T, transform.name) 66 | tobj = tklass(**transform.params) 67 | transforms.append(tobj) 68 | self.transform = T.Compose(transforms) 69 | 70 | def __len__(self): 71 | return len(self.metadata) 72 | 73 | def __getitem__(self, idx, return_row=False): 74 | row = self.metadata.iloc[idx] 75 | output = { 76 | 'filename': row['FileName'], 77 | 'still': False, 78 | } 79 | 80 | if "image" in self.outputs or "video" in self.outputs: 81 | reader = decord.VideoReader(row["VideoPath"], ctx=decord.cpu(), width=self.resolution, height=self.resolution) 82 | og_fps = reader.get_avg_fps() 83 | og_frame_count = len(reader) 84 | 85 | if "video" in self.outputs: 86 | # Generate indices to resample 87 | # Generate a random starting point to cover all frames 88 | if self.target_fps == "original": 89 | target_fps = og_fps 90 | elif self.target_fps == "random": 91 | target_fps = np.random.randint(16, 120) 92 | elif self.target_fps == "half": 93 | target_fps = int(og_fps//2) 94 | elif self.target_fps == "exponential": 95 | rnd, offset = np.random.randint(0, 100), 11 96 | target_fps = int(np.exp(rnd/offset) + offset) # min: 12, max: ~8000 97 | else: 98 | target_fps = self.target_fps 99 | new_frame_count = np.floor(target_fps / og_fps * og_frame_count).astype(int) 100 | resample_indices_a = np.linspace(0, og_frame_count-1, new_frame_count, endpoint=False).round().astype(int) 101 | start_idx = np.random.choice(np.arange(0, resample_indices_a[1])) if len(resample_indices_a) > 1 and resample_indices_a[1] > 1 else 0 102 | resample_indices_a = resample_indices_a + start_idx 103 | 104 | # Sample a random chunk to cover the requested duration 105 | start_idx = np.random.choice(np.arange(0, len(resample_indices_a) - self.duration_frames)) if len(resample_indices_a) > self.duration_frames else 0 106 | end_idx = start_idx + self.duration_frames 107 | resample_indices_b = resample_indices_a[start_idx:end_idx] 108 | resample_indices_b = resample_indices_b[resample_indices_b C x T x H x W 121 | output["video"] = self.transform(video) 122 | output["fps"] = target_fps 123 | output["padding"] = p_index 124 | if self.target_fps == "exponential": 125 | resample_indices_b[1:] = (resample_indices_b[1:] - resample_indices_b[:-1] >= 1).cumsum(0) 126 | resample_indices_b[0] = 0 127 | output["indices"] = np.concatenate((resample_indices_b, np.repeat(resample_indices_b[-1], 
self.duration_frames-len(resample_indices_b)))) 128 | 129 | if "lvef" in self.outputs: 130 | lvef = row["EF"] / 100.0 # normalize to [0, 1] 131 | output["lvef"] = torch.tensor(lvef, dtype=torch.float32) 132 | 133 | if "image" in self.outputs: 134 | image = reader.get_batch(np.random.randint(0, og_frame_count, 1))[0] # H x W x C, uint8 135 | image = image.float() / 128.0 - 1 136 | image = image.permute(2, 0, 1) # H x W x C -> C x H x W 137 | output["image"] = self.transform(image) 138 | 139 | if return_row: 140 | return output, row 141 | 142 | return output 143 | 144 | 145 | class Pediatric(Dynamic): 146 | def __init__(self, config, split=["TRAIN", "VAL", "TEST"]) -> None: 147 | super().__init__(config, split) 148 | 149 | # View 150 | self.view = config.get("views", "ALL") # A4C, PSAX, ALL 151 | if self.view == "ALL": 152 | pass 153 | else: 154 | self.metadata = self.metadata[self.metadata["View"] == self.view] 155 | self.metadata.reset_index(inplace=True, drop=True) 156 | if len(self.metadata) == 0: 157 | raise ValueError(f"No videos found for view {self.view}") 158 | 159 | def __getitem__(self, idx): 160 | output, row = super().__getitem__(idx, return_row=True) 161 | if "view" in self.outputs: 162 | output["view"] = row["View"] 163 | 164 | return output 165 | 166 | 167 | class Latent(Dynamic): 168 | def __init__(self, config, split=["TRAIN", "VAL", "TEST"]) -> None: 169 | self.config = config 170 | 171 | super().__init__(config, split, datafolder="Latents", ext=".pt") 172 | 173 | self.view = config.get("views", "ALL") # A4C, PSAX, ALL 174 | if self.view == "ALL": 175 | pass 176 | else: 177 | self.metadata = self.metadata[self.metadata["View"] == self.view] 178 | self.metadata.reset_index(inplace=True, drop=True) 179 | if len(self.metadata) == 0: 180 | raise ValueError(f"No videos found for view {self.view}") 181 | 182 | def __getitem__(self, idx, return_row=False): 183 | row = self.metadata.iloc[idx] 184 | output = { 185 | 'filename': row['FileName'], 186 | } 187 | 188 | if "image" in self.outputs or "video" in self.outputs: 189 | latent_file = row["VideoPath"] 190 | latent_video_tensor = torch.load(latent_file) # T x C x H x W 191 | og_fps = row["FPS"] 192 | og_frame_count = len(latent_video_tensor) 193 | 194 | if "video" in self.outputs: 195 | if self.target_fps == "original": 196 | target_fps = og_fps 197 | elif self.target_fps == "random": 198 | target_fps = np.random.randint(8, 50) 199 | else: 200 | target_fps = self.target_fps 201 | 202 | new_frame_count = np.floor(target_fps / og_fps * og_frame_count).astype(int) 203 | resample_indices = np.linspace(0, og_frame_count, new_frame_count, endpoint=False).round().astype(int) 204 | start_idx = np.random.choice(np.arange(0, resample_indices[1])) if len(resample_indices) > 1 and resample_indices[1] > 1 else 0 205 | resample_indices = resample_indices + start_idx 206 | 207 | # Sample a random chunk to cover the requested duration 208 | start_idx = np.random.choice(np.arange(0, len(resample_indices) - self.duration_frames)) if len(resample_indices) > self.duration_frames else 0 209 | end_idx = start_idx + self.duration_frames 210 | resample_indices = resample_indices[start_idx:end_idx] 211 | resample_indices = resample_indices[resample_indices C x T x H x W 224 | output["video"] = self.transform(latent_video_sample) 225 | output["fps"] = target_fps 226 | output["padding"] = p_index 227 | 228 | if "lvef" in self.outputs: 229 | lvef = row["EF"] / 100.0 230 | output["lvef"] = torch.tensor(lvef, dtype=torch.float32) 231 | 232 | if "image" in 
self.outputs: 233 | latent_image_tensor = latent_video_tensor[np.random.randint(0, og_frame_count, 1)][0] # C x H x W 234 | output["image"] = self.transform(latent_image_tensor) 235 | 236 | if return_row: 237 | return output, row 238 | 239 | return output 240 | 241 | 242 | class RandomVideo(Dataset): 243 | def __init__(self, config, split=["TRAIN", "VAL", "TEST"]) -> None: 244 | super().__init__() 245 | 246 | self.config = config 247 | self.root = config.root 248 | 249 | self.target_nframes = config.target_nframes 250 | self.target_resolution = config.target_resolution 251 | 252 | self.outputs = config.outputs 253 | assert len(self.outputs) > 0, "Outputs must not be empty" 254 | assert all([o in ["video", "image"] for o in self.outputs]), "Outputs can only be video or image (or both) for RandomVideo" 255 | 256 | 257 | assert os.path.exists(self.root), f"Root folder {self.root} does not exist" 258 | assert os.path.isdir(self.root), f"Root folder {self.root} is not a directory" 259 | self.all_frames = os.listdir(self.root) 260 | 261 | assert len(self.all_frames) > 0, f"No frames found in {self.root}" 262 | 263 | self.still_image_p = config.get("still_image_p", 0) # probability of returning a still image instead of a video 264 | 265 | def __len__(self): 266 | return len(self.all_frames) 267 | 268 | def __getitem__(self, idx): 269 | 270 | output = { 271 | 'filename': 'Fake', 272 | } 273 | 274 | if "image" in self.outputs: 275 | path = os.path.join(self.root, self.all_frames[idx]) 276 | image = Image.open(path) # H x W x C, uint8 277 | image = np.array(image) # H x W x C, uint8 278 | image = torch.from_numpy(image) 279 | image = image.permute(2, 0, 1).float() # C x H x W 280 | image = image / 128.0 - 1 # [-1, 1] 281 | output["image"] = image 282 | 283 | if "video" in self.outputs: 284 | if self.still_image_p > torch.rand(1,).item(): 285 | random_indices = np.random.randint(0, len(self.all_frames), 1) 286 | path = os.path.join(self.root, self.all_frames[random_indices[0]]) 287 | image = Image.open(path) # H x W x C, uint8 288 | image = np.array(image) # H x W x C, uint8 289 | image = torch.from_numpy(image) 290 | image = image.permute(2, 0, 1).float() # C x H x W 291 | image = image / 128.0 - 1 292 | image = image[:,None,:,:].repeat(1, self.target_nframes, 1, 1) 293 | output["video"] = image 294 | output["still"] = True 295 | else: 296 | random_indices = np.random.randint(0, len(self.all_frames), self.target_nframes) 297 | paths = [os.path.join(self.root, self.all_frames[ridx]) for ridx in random_indices] 298 | images = [Image.open(path) for path in paths] 299 | images = [np.array(image) for image in images] 300 | images = np.stack(images, axis=0) # T x H x W x C 301 | images = torch.from_numpy(images) # T x H x W x C 302 | images = images.permute(3, 0, 1, 2).float() # C x T x H x W 303 | images = images / 128.0 - 1 # [-1, 1] 304 | output["video"] = images 305 | output["still"] = False 306 | output["fps"] = 0 307 | output["padding"] = 0 308 | 309 | return output 310 | 311 | 312 | class FrameFolder(Dataset): 313 | """ config: 314 | - name: FrameFolder 315 | active: true 316 | params: 317 | video_folder: path/to/video_folders 318 | meta_path: path/to/FileList.csv 319 | outputs: ['video', 'image', 'lvef'] 320 | """ 321 | def __init__(self, config, split=["TRAIN", "VAL", "TEST"]) -> None: 322 | super().__init__() 323 | 324 | self.config = config 325 | self.video_folder = config.video_folder 326 | self.meta_path = config.meta_path 327 | 328 | 329 | self.target_nframes = config.target_nframes 330 | 
self.target_resolution = config.target_resolution 331 | 332 | self.metadata = pd.read_csv(self.meta_path) 333 | self.metadata = self.metadata[self.metadata["Split"].isin(split)] # filter by split 334 | 335 | # check if videos are reachable 336 | self.metadata["VideoPath"] = self.metadata["FileName"].apply(lambda x: os.path.join(config.video_folder, x.split('.')[0])) 337 | self.metadata["VideoExists"] = self.metadata["VideoPath"].apply(lambda x: ( os.path.isdir(x) and len(os.listdir(x)) > 0 )) 338 | self.metadata = self.metadata[self.metadata["VideoExists"]] 339 | 340 | 341 | def __len__(self): 342 | return len(self.metadata) 343 | 344 | def __getitem__(self, idx): 345 | 346 | row = self.metadata.iloc[idx] 347 | 348 | output = { 349 | 'filename': row['FileName'], 350 | } 351 | 352 | if "image" in self.outputs: 353 | fpath = os.path.join(row['VideoPath']) 354 | rand_item = np.random.choice(os.listdir(fpath)) 355 | image = Image.open(os.path.join(fpath, rand_item)) # H x W x C, uint8 356 | image = np.array(image) # H x W x C, uint8 357 | image = torch.from_numpy(image) 358 | image = image.permute(2, 0, 1).float() # C x H x W 359 | image = image / 128.0 - 1 # [-1, 1] 360 | output["image"] = image 361 | 362 | if "video" in self.outputs: 363 | fpath = os.path.join(row['VideoPath']) 364 | fc = self.target_nframes 365 | all_frames_names = sorted(os.listdir(fpath)) 366 | if len(all_frames_names) > fc: 367 | start_idx = np.random.randint(0, len(all_frames_names) - fc) 368 | end_idx = start_idx + fc 369 | else: 370 | start_idx = 0 371 | end_idx = -1 372 | all_frames_names = all_frames_names[start_idx:end_idx] 373 | all_frames_path = [os.path.join(fpath, f) for f in all_frames_names] 374 | all_frames = [Image.open(f) for f in all_frames_path] 375 | all_frames = [np.array(f) for f in all_frames] 376 | all_frames = np.stack(all_frames, axis=0) # T x H x W x C 377 | all_frames = torch.from_numpy(all_frames) # T x H x W x C 378 | all_frames = all_frames.permute(3, 0, 1, 2).float() # C x T x H x W 379 | all_frames = all_frames / 128.0 - 1 # [-1, 1] 380 | 381 | if len(all_frames) < fc: 382 | padding_element = torch.zeros_like(all_frames[0]) 383 | padding = torch.stack([padding_element] * (fc - len(all_frames))) 384 | all_frames = torch.cat((all_frames, padding), dim=0) 385 | assert len(all_frames) == fc, f"Video length is {len(all_frames)} but should be {fc}" 386 | 387 | output["video"] = all_frames 388 | 389 | if "lvef" in self.outputs: 390 | lvef = row["EF"] / 100.0 391 | output["lvef"] = torch.tensor(lvef, dtype=torch.float32) 392 | 393 | return output 394 | 395 | 396 | def instantiate_dataset(configs, split=["TRAIN", "VAL", "TEST"]): 397 | # config = config.copy() 398 | # assert config.get("datasets", False), "No 'datasets' key found in config" 399 | 400 | # Check if number of frames and resolution are the same for all datasets 401 | target_nframes = None 402 | target_resolution = None 403 | reference_name = None 404 | for dataset_config in configs: 405 | if dataset_config.get("active", True): 406 | if reference_name is None: 407 | reference_name = dataset_config.name 408 | if target_nframes is None: 409 | target_nframes = dataset_config.params.target_nframes 410 | else: 411 | newd = dataset_config.params.target_nframes 412 | assert newd == target_nframes, f"All datasets must ouput the same number of frames, got {reference_name}: {target_nframes} frames and {dataset_config.name}: {newd} frames." 
413 | if target_resolution is None: 414 | target_resolution = dataset_config.params.target_resolution 415 | else: 416 | assert dataset_config.params.target_resolution == target_resolution, f"All datasets must have the same target_resolution, got {reference_name}: {target_resolution} and {dataset_config.name}: {dataset_config.params.target_resolution}." 417 | 418 | datasets = [] 419 | for dataset_config in configs: 420 | if dataset_config.get("active", True): 421 | datasets.append(globals()[dataset_config.name](dataset_config.params, split=split)) 422 | 423 | if len(datasets) == 1: 424 | return datasets[0] 425 | else: 426 | return ConcatDataset(datasets) 427 | 428 | 429 | class RFBalancer(Dataset): # Real - Fake Balancer 430 | """ 431 | Balances the dataset by sampling from each dataset with equal probability. 432 | 433 | """ 434 | def __init__(self, real_dataset=None, fake_dataset=None, transform=None) -> None: 435 | super().__init__() 436 | 437 | # self.datasets = [fake_dataset, real_dataset] 438 | self.datasets = [] 439 | if fake_dataset is not None: 440 | self.datasets.append(fake_dataset) 441 | if real_dataset is not None: 442 | self.datasets.append(real_dataset) 443 | 444 | if len(self.datasets) == 0: 445 | raise ValueError("At least one dataset must be provided") 446 | 447 | if len(self.datasets) > 1: 448 | self.ds_idx = (np.random.rand(1,) < 0.5)[0] # pick the first dataset to start with 449 | else: 450 | self.ds_idx = 0 451 | 452 | self.ds_current = [0] * len(self.datasets) 453 | 454 | self.transforms = transform 455 | 456 | def __len__(self): 457 | return np.sum([len(ds) for ds in self.datasets]) 458 | 459 | def _get_index_for_ds(self, idx): 460 | ds_idx = 0 461 | while True: 462 | if idx < len(self.datasets[ds_idx]): 463 | break 464 | else: 465 | idx -= len(self.datasets[ds_idx]) 466 | ds_idx = (ds_idx + 1) % len(self.datasets) 467 | return ds_idx, idx 468 | 469 | 470 | def __getitem__(self, idx): 471 | ds_idx, idx = self._get_index_for_ds(idx) 472 | output = self.datasets[ds_idx][idx] # get item from dataset 473 | output["real"] = float(ds_idx) # add real/fake label 474 | 475 | if self.transforms is not None and "video" in output: 476 | output["video"] = self.transforms(output["video"]) 477 | if self.transforms is not None and "image" in output: 478 | output["image"] = self.transforms(output["image"]) 479 | 480 | return output 481 | 482 | 483 | class SimaseUSVideoDataset(Dataset): 484 | def __init__(self, 485 | phase='training', 486 | transform=None, 487 | latents_csv='./', 488 | training_latents_base_path="./", 489 | in_memory=True, 490 | generator_seed=None): 491 | self.phase = phase 492 | self.training_latents_base_path = training_latents_base_path 493 | 494 | self.in_memory = in_memory 495 | self.videos = [] 496 | 497 | PHASE_TO_SPLIT = {"training": "TRAIN", "validation": "VAL", "testing": "TEST"} 498 | self.df = pd.read_csv(latents_csv) 499 | self.df = self.df[self.df["Split"] == PHASE_TO_SPLIT[self.phase]].reset_index(drop=True) 500 | self.transform = transform 501 | 502 | if generator_seed is None: 503 | self.generator = np.random.default_rng() 504 | #unseeded 505 | else: 506 | self.generator_seed = generator_seed 507 | print(f"Set {self.phase} dataset seed to {self.generator_seed}") 508 | 509 | if self.in_memory: 510 | self.load_videos() 511 | 512 | def __len__(self): 513 | return len(self.df) 514 | 515 | def __getitem__(self, index): 516 | vid_a = self.get_vid(index) 517 | if self.transform is not None: 518 | vid_a = self.transform(vid_a) 519 | return vid_a 520 | 521 | 
def reset_generator(self): 522 | self.generator = np.random.default_rng(self.generator_seed) 523 | 524 | def get_vid(self, index, from_disk=False): 525 | if self.in_memory and not from_disk: 526 | return self.videos[index] 527 | else: 528 | path = self.df.iloc[index]["FileName"].split('.')[0] + ".pt" 529 | path = os.path.join(self.training_latents_base_path, path) 530 | return torch.load(path) 531 | 532 | def load_videos(self): 533 | self.videos = [] 534 | print("Preloading videos") 535 | for i in range(len(self)): 536 | self.videos.append(self.get_vid(i, from_disk=True)) 537 | 538 | 539 | class SiameseUSDataset(Dataset): 540 | def __init__(self, 541 | phase='training', 542 | transform=None, 543 | latents_csv='./', 544 | training_latents_base_path="./", 545 | in_memory=True, 546 | generator_seed=None): 547 | self.phase = phase 548 | self.training_latents_base_path = training_latents_base_path 549 | 550 | self.in_memory = in_memory 551 | self.videos = [] 552 | 553 | PHASE_TO_SPLIT = {"training": "TRAIN", "validation": "VAL", "testing": "TEST"} 554 | self.df = pd.read_csv(latents_csv) 555 | self.df = self.df[self.df["Split"] == PHASE_TO_SPLIT[self.phase]].reset_index(drop=True) 556 | 557 | self.transform = transform 558 | 559 | if generator_seed is None: 560 | self.generator = np.random.default_rng() 561 | #unseeded 562 | else: 563 | self.generator_seed = generator_seed 564 | print(f"Set {self.phase} dataset seed to {self.generator_seed}") 565 | 566 | if self.in_memory: 567 | self.load_videos() 568 | 569 | def __len__(self): 570 | return len(self.df) 571 | 572 | def __getitem__(self, index): 573 | 574 | vid_a = torch.clone(self.get_vid(index)) 575 | if self.generator.uniform() < 0.5: 576 | vid_b = torch.clone(self.get_vid((index + self.generator.integers(low=1, high=len(self))) % len(self))) # random different vid 577 | y = 0.0 578 | else: 579 | vid_b = torch.clone(vid_a) 580 | y = 1.0 581 | 582 | if self.transform is not None: 583 | vid_a = self.transform(vid_a) 584 | vid_b = self.transform(vid_b) 585 | 586 | frame_a = self.generator.integers(len(vid_a)) 587 | frame_b = (frame_a + self.generator.integers(low=1, high=len(vid_b))) % len(vid_b) 588 | #print(f"Dataloader: framea {frame_a} - frame_b {frame_b} - y: {y}") 589 | return vid_a[frame_a], vid_b[frame_b], y 590 | 591 | def reset_generator(self): 592 | self.generator = np.random.default_rng(self.generator_seed) 593 | 594 | def get_vid(self, index, from_disk=False): 595 | if self.in_memory and not from_disk: 596 | return self.videos[index] 597 | else: 598 | path = self.df.iloc[index]["FileName"].split('.')[0] + ".pt" 599 | path = os.path.join(self.training_latents_base_path, path) 600 | return torch.load(path) 601 | 602 | def load_videos(self): 603 | self.videos = [] 604 | print("Preloading videos") 605 | for i in range(len(self)): 606 | self.videos.append(self.get_vid(i, from_disk=True)) 607 | 608 | 609 | class ImageSet(Dataset): 610 | def __init__(self, root, ext=".jpg"): 611 | self.root = root 612 | self.all_images = glob(os.path.join(root, "*.jpg")) 613 | 614 | def __len__(self): 615 | return len(self.all_images) 616 | 617 | def __getitem__(self, idx): 618 | image = Image.open(self.all_images[idx]) 619 | image = np.array(image) / 128.0 - 1 # [0, 255] -> [-1, 1] 620 | image = image.transpose(2, 0, 1) # H x W x C -> C x H x W 621 | return image 622 | 623 | class TensorSet(Dataset): 624 | def __init__(self, root): 625 | self.root = root 626 | self.all_tensors = glob(os.path.join(root, "*.pt")) 627 | 628 | def __len__(self): 629 | return 
len(self.all_tensors) 630 | 631 | def __getitem__(self, idx): 632 | tensor = torch.load(self.all_tensors[idx], map_location='cpu') 633 | return tensor 634 | 635 | if __name__ == "__main__": 636 | pass 637 | 638 | -------------------------------------------------------------------------------- /echosyn/common/privacy_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils import data 4 | from sklearn import metrics 5 | import math 6 | import random 7 | import itertools 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from echosyn.common.datasets import SiameseUSDataset 11 | 12 | 13 | """ 14 | This file provides the most important functions that are used in our experiments. These functions are called in 15 | AgentSiameseNetwork.py which provides the actual training/validation loop and the code for evaluation. 16 | """ 17 | 18 | 19 | # Function to get the data loader. 20 | def get_data_loaders(phase='training', data_handling='balanced', n_channels=3, n_samples=100000, transform=None, 21 | image_path='./', batch_size=32, shuffle=True, num_workers=16, pin_memory=True, save_path=None, generator_seed=None): 22 | 23 | #dataset = SiameseDataset(phase=phase, data_handling=data_handling, n_channels=n_channels, n_samples=n_samples, 24 | # transform=transform, image_path=image_path, save_path=save_path) 25 | latents_csv = image_path 26 | training_latents_base_path = os.path.join(os.path.dirname(latents_csv), "Latents") 27 | dataset = SiameseUSDataset(phase=phase, transform=transform, latents_csv=latents_csv, training_latents_base_path=training_latents_base_path, in_memory=False, generator_seed=generator_seed) 28 | dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, 29 | pin_memory=pin_memory) 30 | 31 | return dataloader 32 | 33 | 34 | # This function represents the training loop for the standard case where we have two input images and one output node. 35 | def train(net, training_loader, n_samples, batch_size, criterion, optimizer, epoch, n_epochs): 36 | net.train() 37 | running_loss = 0.0 38 | 39 | print('Training----->') 40 | for i, batch in enumerate(training_loader): 41 | inputs1, inputs2, labels = batch 42 | inputs1, inputs2, labels = inputs1.cuda(), inputs2.cuda(), labels.cuda() 43 | 44 | # zero the parameter gradients 45 | optimizer.zero_grad() 46 | 47 | # forward + backward + optimize 48 | outputs = net(inputs1, inputs2) 49 | outputs = outputs.squeeze() 50 | labels = labels.type_as(outputs) 51 | loss = criterion(outputs, labels) 52 | loss.backward() 53 | optimizer.step() 54 | 55 | running_loss += loss.item() 56 | 57 | print('Epoch [%d/%d], Iteration [%d/%d], Loss: %.4f' % (epoch + 1, n_epochs, i + 1, 58 | math.ceil(n_samples / batch_size), 59 | loss.item())) 60 | 61 | # Compute the average loss per epoch 62 | training_loss = running_loss / math.ceil(n_samples / batch_size) 63 | return training_loss 64 | 65 | 66 | # This function represents the validation loop for the standard case where we have two input images and one output node. 
67 | def validate(net, validation_loader, n_samples, batch_size, criterion, epoch, n_epochs): 68 | net.eval() 69 | running_loss = 0 70 | 71 | print('Validating----->') 72 | with torch.no_grad(): 73 | for i, batch in enumerate(validation_loader): 74 | inputs1, inputs2, labels = batch 75 | inputs1, inputs2, labels = inputs1.cuda(), inputs2.cuda(), labels.cuda() 76 | 77 | # forward 78 | outputs = net(inputs1, inputs2) 79 | outputs = outputs.squeeze() 80 | labels = labels.type_as(outputs) 81 | loss = criterion(outputs, labels) 82 | 83 | running_loss += loss.item() 84 | 85 | print('Epoch [%d/%d], Iteration [%d/%d], Loss: %.4f' % (epoch + 1, n_epochs, i + 1, 86 | math.ceil(n_samples / batch_size), 87 | loss.item())) 88 | 89 | # Compute the average loss per epoch 90 | validation_loss = running_loss / math.ceil(n_samples / batch_size) 91 | return validation_loss 92 | 93 | 94 | # This function represents the test loop for the standard case where we have two input images and one output node. 95 | # This function returns the true labels and the predicted values. 96 | def test(net, test_loader): 97 | net.eval() 98 | y_true = None 99 | y_pred = None 100 | 101 | print('Testing----->') 102 | with torch.no_grad(): 103 | for i, batch in enumerate(test_loader): 104 | inputs1, inputs2, labels = batch 105 | 106 | if y_true is None: 107 | y_true = labels 108 | else: 109 | y_true = torch.cat((y_true, labels), 0) 110 | 111 | inputs1, inputs2, labels = inputs1.cuda(), inputs2.cuda(), labels.cuda() 112 | outputs = net(inputs1, inputs2) 113 | outputs = torch.sigmoid(outputs) 114 | 115 | if y_pred is None: 116 | y_pred = outputs.cpu() 117 | else: 118 | y_pred = torch.cat((y_pred, outputs.cpu()), 0) 119 | 120 | y_pred = y_pred.squeeze() 121 | return y_true, y_pred 122 | 123 | 124 | # This function computes some standard evaluation metrics given the true labels and the predicted values. 125 | def get_evaluation_metrics(y_true, y_pred): 126 | accuracy = metrics.accuracy_score(y_true, y_pred) 127 | f1_score = metrics.f1_score(y_true, y_pred) 128 | precision = metrics.precision_score(y_true, y_pred) 129 | recall = metrics.recall_score(y_true, y_pred) 130 | report = metrics.classification_report(y_true, y_pred) 131 | confusion_matrix = metrics.confusion_matrix(y_true, y_pred) 132 | return accuracy, f1_score, precision, recall, report, confusion_matrix 133 | 134 | 135 | # This function is used to apply a threshold to the predicted values before computing some evaluation metrics. 136 | def apply_threshold(input_tensor, threshold): 137 | output = np.where(input_tensor > threshold, torch.ones(len(input_tensor)), torch.zeros(len(input_tensor))) 138 | return output 139 | 140 | 141 | # This function implements bootstrapping, in order to get the mean AUC value and the 95% confidence interval. 
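# Example usage together with test() above (illustrative only; the output path and experiment name are placeholders):
#
# y_true, y_pred = test(net, test_loader)
# auc_mean, ci_low, ci_high = bootstrap(1000, y_true, y_pred, './results/', 'siamese_test')
# print('AUC: %.3f (95%% CI: %.3f - %.3f)' % (auc_mean, ci_low, ci_high))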
142 | def bootstrap(n_bootstraps, y_true, y_pred, path, experiment_description): 143 | y_true = np.array(y_true) 144 | y_pred = np.array(y_pred) 145 | 146 | bootstrapped_scores = [] 147 | 148 | f = open(path + experiment_description + '_AUC_bootstrapped.txt', "w+") 149 | f.write('AUC_bootstrapped\n') 150 | 151 | for i in range(n_bootstraps): 152 | indices = np.random.randint(0, len(y_pred) - 1, len(y_pred)) 153 | auc = metrics.roc_auc_score(y_true[indices], y_pred[indices]) 154 | bootstrapped_scores.append(auc) 155 | f.write(str(auc) + '\n') 156 | f.close() 157 | sorted_scores = np.array(bootstrapped_scores) 158 | sorted_scores.sort() 159 | 160 | auc_mean = np.mean(sorted_scores) 161 | confidence_lower = sorted_scores[int(0.025 * len(sorted_scores))] 162 | confidence_upper = sorted_scores[int(0.975 * len(sorted_scores))] 163 | 164 | f = open(path + experiment_description + '_AUC_confidence.txt', "w+") 165 | f.write('AUC_mean: %s\n' % auc_mean) 166 | f.write('Confidence interval for the AUC score: ' + str(confidence_lower) + ' - ' + str(confidence_upper)) 167 | f.close() 168 | return auc_mean, confidence_lower, confidence_upper 169 | 170 | 171 | # This is a function that plots the loss curves. 172 | def plot_loss_curves(loss_dict, path, experiment_description): 173 | plt.figure() 174 | plt.plot(range(1, len(loss_dict['training']) + 1), loss_dict['training'], label='Training Loss') 175 | plt.plot(range(1, len(loss_dict['validation']) + 1), loss_dict['validation'], label='Validation Loss') 176 | plt.xlabel('Epoch') 177 | plt.ylabel('Loss') 178 | plt.title('Training and validation loss curves') 179 | plt.legend() 180 | plt.savefig(path + experiment_description + '_loss_curves.png') 181 | 182 | 183 | # This is a function that plots the ROC curve. 184 | def plot_roc_curve(fp_rates, tp_rates, path, experiment_description): 185 | plt.figure() 186 | plt.plot(fp_rates, tp_rates, label='ROC Curve') 187 | plt.xlabel('False positive rate') 188 | plt.ylabel('True positive rate') 189 | plt.title('ROC Curve') 190 | plt.legend() 191 | plt.savefig(path + experiment_description + '_ROC_curve.png') 192 | 193 | 194 | # This is a function that saves the evaluation metrics to a file. 195 | def save_results_to_file(auc, accuracy, f1_score, precision, recall, report, confusion_matrix, path, 196 | experiment_description): 197 | f = open(path + experiment_description + '_results.txt', "w+") 198 | f.write('AUC: %s\n' % auc) 199 | f.write('Accuracy: %s\n' % accuracy) 200 | f.write('F1-Score: %s\n' % f1_score) 201 | f.write('Precision: %s\n' % precision) 202 | f.write('Recall: %s\n' % recall) 203 | f.write('Classification report: %s\n' % report) 204 | f.write('Confusion matrix: %s\n' % confusion_matrix) 205 | f.close() 206 | 207 | 208 | # This function saves the ROC metrics to a file. 209 | def save_roc_metrics_to_file(fp_rates, tp_rates, thresholds, path, experiment_description): 210 | f = open(path + experiment_description + '_ROC_metrics.txt', "w+") 211 | f.write('FP_rate\tTP_rate\tThreshold\n') 212 | for i in range(len(fp_rates)): 213 | f.write(str(fp_rates[i]) + '\t' + str(tp_rates[i]) + '\t' + str(thresholds[i]) + '\n') 214 | f.close() 215 | 216 | 217 | # This function saves the training and validation loss values to a file. 
218 | def save_loss_curves(loss_dict, path, experiment_description): 219 | f = open(path + experiment_description + '_loss_values.txt', "w+") 220 | f.write('TrainingLoss\tValidationLoss\n') 221 | for i in range(len(loss_dict['training'])): 222 | f.write(str(loss_dict['training'][i]) + '\t' + str(loss_dict['validation'][i]) + '\n') 223 | f.close() 224 | 225 | 226 | # This function is used to save a checkpoint. This enables us to resume an experiment at a later point if needed. 227 | def save_checkpoint(epoch, model, optimizer, loss_dict, best_loss, num_bad_epochs, filename='checkpoint.pth'): 228 | state = { 229 | 'epoch': epoch + 1, 230 | 'state_dict': model.state_dict(), 231 | 'optimizer': optimizer.state_dict(), 232 | 'loss_dict': loss_dict, 233 | 'best_loss': best_loss, 234 | 'num_bad_epochs': num_bad_epochs 235 | } 236 | torch.save(state, filename) 237 | 238 | 239 | # This function saves the true labels, the predicted values, and the thresholded values to a file. 240 | def save_labels_predictions(y_true, y_pred, y_pred_thresh, path, experiment_description): 241 | f = open(path + experiment_description + '_labels_predictions.txt', "w+") 242 | f.write('Label\tPrediction\tPredictionThresholded\n') 243 | for i in range(len(y_true)): 244 | f.write(str(y_true[i]) + '\t' + str(y_pred[i]) + '\t' + str(y_pred_thresh[i]) + '\n') 245 | f.close() 246 | 247 | 248 | # This function was used to construct the positive and negative pairs needed for training and testing our 249 | # verification network. It takes filenames as an input argument and returns both the list of tuples and the 250 | # list of corresponding labels. The constructed pairs were later saved to .txt files, which are now available in the folder 251 | # './image_pairs/' (pairs_training.txt, pairs_validation.txt, pairs_testing.txt). 
252 | # Example: 253 | # 254 | # train_val_filenames = np.loadtxt('train_val_list.txt', dtype=str) 255 | # test_filenames = np.loadtxt('test_list.txt', dtype=str) 256 | # tuples_train, labels_train = get_tuples_labels(train_val_filenames[:75708]) 257 | # tuples_val, labels_val = get_tuples_labels(train_val_filenames[75708:]) 258 | # tuples_test, labels_test = get_tuples_labels(test_filenames) 259 | # 260 | def get_tuples_labels(filenames): 261 | tuples_list = [] 262 | labels_list = [] 263 | patients = [] 264 | 265 | for i, element in enumerate(filenames): 266 | element = element[:-8] 267 | patients.append(element) 268 | 269 | unique_patients, counts_patients = np.unique(patients, return_counts=True) 270 | patients_dict = dict(zip(unique_patients, counts_patients)) 271 | 272 | start_idx = 0 273 | patients_files_dict = {} 274 | 275 | for key in patients_dict: 276 | patients_files_dict[key] = filenames[start_idx:start_idx + patients_dict[key]] 277 | start_idx += patients_dict[key] 278 | 279 | if len(patients_files_dict[key]) > 1: 280 | # samples = list(itertools.product(patients_files_dict[key], patients_files_dict[key])) 281 | samples = list(itertools.combinations(patients_files_dict[key], 2)) 282 | tuples_list.append(samples) 283 | labels_list.append(np.ones(len(samples)).tolist()) 284 | 285 | tuples_list = list(itertools.chain.from_iterable(tuples_list)) 286 | labels_list = list(itertools.chain.from_iterable(labels_list)) 287 | 288 | N = len(tuples_list) 289 | i = 0 290 | 291 | while i < N: 292 | file1 = random.choice(filenames) 293 | file2 = random.choice(filenames) 294 | 295 | if file1[:-8] != file2[:-8]: 296 | sample = (file1, file2) 297 | tuples_list.append(sample) 298 | labels_list.append(0.0) 299 | i += 1 300 | 301 | return tuples_list, labels_list 302 | -------------------------------------------------------------------------------- /echosyn/lidm/README.md: -------------------------------------------------------------------------------- 1 | # Latent Image Diffusion Model 2 | 3 | The Latent Image Diffusion Model (LIDM) is the first step of our generative pipeline. It generates a latent representation of a heart, which is then passed to the Latent Video Diffusion Model (LVDM) to generate a video of the heart beating. 4 | 5 | Training the LIDM is straightforward and does not require a lot of resources. In the paper, the LIDMs are trained for ~24h on a single A100 GPU. The batch size can be adjusted to fit smaller GPUs with no noticeable loss of quality. 6 | 7 | ## 1. Activate the environment 8 | 9 | First, activate the echosyn environment. 10 | 11 | ```bash 12 | conda activate echosyn 13 | ``` 14 | 15 | ## 2. Data preparation 16 | Follow the instructions in the [Data preparation](../../README.md#data-preparation) section to prepare the data for training. Here, you need the VAE-encoded videos. 17 | 18 | ## 3. Train the LIDM 19 | Once the environment is set up and the data is ready, you can train the LIDMs with the following commands: 20 | 21 | ```bash 22 | python echosyn/lidm/train.py --config echosyn/lidm/configs/dynamic.yaml 23 | python echosyn/lidm/train.py --config echosyn/lidm/configs/ped_a4c.yaml 24 | python echosyn/lidm/train.py --config echosyn/lidm/configs/ped_psax.yaml 25 | ``` 26 | 27 | ## 4. Sample from the LIDM
28 | 29 | Once the LIDMs are trained, you can sample from them with the following commands: 30 | 31 | ```bash 32 | # For the Dynamic dataset 33 | python echosyn/lidm/sample.py \ 34 | --config echosyn/lidm/configs/dynamic.yaml \ 35 | --unet experiments/lidm_dynamic/checkpoint-500000/unet_ema \ 36 | --vae models/vae \ 37 | --output samples/lidm_dynamic \ 38 | --num_samples 50000 \ 39 | --batch_size 128 \ 40 | --num_steps 64 \ 41 | --save_latent \ 42 | --seed 0 43 | ``` 44 | 45 | ```bash 46 | # For the Pediatric A4C dataset 47 | python echosyn/lidm/sample.py \ 48 | --config echosyn/lidm/configs/ped_a4c.yaml \ 49 | --unet experiments/lidm_ped_a4c/checkpoint-500000/unet_ema \ 50 | --vae models/vae \ 51 | --output samples/lidm_ped_a4c \ 52 | --num_samples 50000 \ 53 | --batch_size 128 \ 54 | --num_steps 64 \ 55 | --save_latent \ 56 | --seed 0 57 | ``` 58 | 59 | ```bash 60 | # For the Pediatric PSAX dataset 61 | python echosyn/lidm/sample.py \ 62 | --config echosyn/lidm/configs/ped_psax.yaml \ 63 | --unet experiments/lidm_ped_psax/checkpoint-500000/unet_ema \ 64 | --vae models/vae \ 65 | --output samples/lidm_ped_psax \ 66 | --num_samples 50000 \ 67 | --batch_size 128 \ 68 | --num_steps 64 \ 69 | --save_latent \ 70 | --seed 0 71 | ``` 72 | 73 | ## 5. Evaluate the LIDM 74 | 75 | To evaluate the LIDMs, we use the FID and IS scores. 76 | To do so, we need to generate 50,000 samples, using the commands above. 77 | Note that the privacy step will reject some of these samples. 78 | It is therefore better to generate 100,000 samples, so that the FID and IS scores can be computed again on the privacy-compliant samples. 79 | The samples are compared to the real samples, which are generated in the [Data preparation](../../README.md#data-preparation) step. 80 | 81 | Then, to evaluate the samples, run the following commands: 82 | 83 | ```bash 84 | cd external/stylegan-v 85 | 86 | # For the Dynamic dataset 87 | python src/scripts/calc_metrics_for_dataset.py \ 88 | --real_data_path ../../data/reference/dynamic \ 89 | --fake_data_path ../../samples/lidm_dynamic/images \ 90 | --mirror 0 --gpus 1 --resolution 112 \ 91 | --metrics fid50k_full,is50k >> "../../samples/lidm_dynamic/metrics.txt" 92 | 93 | # For the Pediatric A4C dataset 94 | python src/scripts/calc_metrics_for_dataset.py \ 95 | --real_data_path ../../data/reference/ped_a4c \ 96 | --fake_data_path ../../samples/lidm_ped_a4c/images \ 97 | --mirror 0 --gpus 1 --resolution 112 \ 98 | --metrics fid50k_full,is50k >> "../../samples/lidm_ped_a4c/metrics.txt" 99 | 100 | # For the Pediatric PSAX dataset 101 | python src/scripts/calc_metrics_for_dataset.py \ 102 | --real_data_path ../../data/reference/ped_psax \ 103 | --fake_data_path ../../samples/lidm_ped_psax/images \ 104 | --mirror 0 --gpus 1 --resolution 112 \ 105 | --metrics fid50k_full,is50k >> "../../samples/lidm_ped_psax/metrics.txt" 106 | ``` 107 | 108 | ## 6. 
Save the LIDMs for later use 109 | Once you are satisfied with the performance of the LIDMs, you can save them for later use with the following commands: 110 | 111 | ```bash 112 | mkdir -p models/lidm_dynamic; cp -r experiments/lidm_dynamic/checkpoint-500000/unet_ema/* models/lidm_dynamic/; cp experiments/lidm_dynamic/config.yaml models/lidm_dynamic/ 113 | mkdir -p models/lidm_ped_a4c; cp -r experiments/lidm_ped_a4c/checkpoint-500000/unet_ema/* models/lidm_ped_a4c/; cp experiments/lidm_ped_a4c/config.yaml models/lidm_ped_a4c/ 114 | mkdir -p models/lidm_ped_psax; cp -r experiments/lidm_ped_psax/checkpoint-500000/unet_ema/* models/lidm_ped_psax/; cp experiments/lidm_ped_psax/config.yaml models/lidm_ped_psax/ 115 | ``` 116 | 117 | This will save the selected ema version of the model, ready to be loaded in any other script as a standalone model. -------------------------------------------------------------------------------- /echosyn/lidm/configs/dynamic.yaml: -------------------------------------------------------------------------------- 1 | wandb_group: "lidm" 2 | output_dir: experiments/lidm_dynamic 3 | 4 | pretrained_model_name_or_path: null 5 | vae_path: models/vae 6 | 7 | globals: 8 | target_fps: 32 9 | target_nframes: 64 10 | outputs: ["image"] 11 | 12 | datasets: 13 | - name: Latent 14 | active: true 15 | params: 16 | root: data/latents/dynamic 17 | target_fps: ${globals.target_fps} 18 | target_nframes: ${globals.target_nframes} 19 | target_resolution: 14 # emb resolution 20 | outputs: ${globals.outputs} 21 | 22 | unet: 23 | _class_name: UNet2DModel 24 | sample_size: 14 # actual size is 16 25 | in_channels: 4 26 | out_channels: 4 27 | center_input_sample: false 28 | time_embedding_type: positional 29 | freq_shift: 0 30 | flip_sin_to_cos: true 31 | down_block_types: 32 | - AttnDownBlock2D 33 | - AttnDownBlock2D 34 | - AttnDownBlock2D 35 | - DownBlock2D 36 | up_block_types: 37 | - UpBlock2D 38 | - AttnUpBlock2D 39 | - AttnUpBlock2D 40 | - AttnUpBlock2D 41 | block_out_channels: 42 | - 128 43 | - 256 44 | - 256 45 | - 512 46 | layers_per_block: 2 47 | mid_block_scale_factor: 1 48 | downsample_padding: 1 49 | downsample_type: resnet 50 | upsample_type: resnet 51 | dropout: 0.0 52 | act_fn: silu 53 | attention_head_dim: 8 54 | norm_num_groups: 32 55 | attn_norm_num_groups: null 56 | norm_eps: 1e-05 57 | resnet_time_scale_shift: "default" 58 | class_embed_type: null 59 | num_class_embeds: null 60 | 61 | noise_scheduler: 62 | _class_name: DDPMScheduler 63 | num_train_timesteps: 1000 64 | beta_start: 0.0001 65 | beta_end: 0.02 66 | beta_schedule: linear # linear, scaled_linear, or squaredcos_cap_v2 67 | variance_type: fixed_small # fixed_small, fixed_small_log, fixed_large, fixed_large_log, learned or learned_range 68 | clip_sample: true 69 | clip_sample_range: 4.0 # default 1 70 | prediction_type: v_prediction # epsilon, sample, v_prediction 71 | thresholding: false # do not touch 72 | dynamic_thresholding_ratio: 0.995 # unused 73 | sample_max_value: 1.0 # unused 74 | timestep_spacing: "leading" # 75 | steps_offset: 0 # unused 76 | 77 | train_batch_size: 256 78 | dataloader_num_workers: 16 79 | max_train_steps: 500000 80 | 81 | learning_rate: 3e-4 82 | lr_warmup_steps: 500 83 | scale_lr: false 84 | lr_scheduler: constant 85 | use_8bit_adam: false 86 | gradient_accumulation_steps: 1 87 | 88 | noise_offset: 0.0 89 | 90 | gradient_checkpointing: false 91 | use_ema: true 92 | enable_xformers_memory_efficient_attention: false 93 | allow_tf32: true 94 | 95 | adam_beta1: 0.9 96 | adam_beta2: 0.999 
97 | adam_weight_decay: 1e-2 98 | adam_epsilon: 1e-08 99 | max_grad_norm: 1.0 100 | 101 | logging_dir: logs 102 | mixed_precision: "fp16" # "no", "fp16", "bf16" 103 | 104 | validation_timesteps: 128 105 | validation_fps: ${globals.target_fps} 106 | validation_frames: ${globals.target_nframes} 107 | validation_count: 4 # defines the number of samples 108 | validation_guidance: 1.0 109 | validation_steps: 2500 110 | 111 | report_to: wandb 112 | checkpointing_steps: 100 #10000 # ~3/hour 113 | checkpoints_total_limit: 100 # no limit 114 | resume_from_checkpoint: latest 115 | tracker_project_name: echosyn 116 | 117 | seed: 42 118 | 119 | -------------------------------------------------------------------------------- /echosyn/lidm/configs/ped_a4c.yaml: -------------------------------------------------------------------------------- 1 | wandb_group: "lidm" 2 | output_dir: experiments/lidm_ped_a4c 3 | 4 | pretrained_model_name_or_path: null 5 | vae_path: models/vae 6 | 7 | globals: 8 | target_fps: "original" 9 | target_nframes: 32 10 | outputs: ["image"] 11 | 12 | datasets: 13 | - name: Latent 14 | active: true 15 | params: 16 | root: data/latents/ped_a4c 17 | target_fps: ${globals.target_fps} 18 | target_nframes: ${globals.target_nframes} 19 | target_resolution: 14 # emb resolution 20 | outputs: ${globals.outputs} 21 | 22 | unet: 23 | _class_name: UNet2DModel 24 | sample_size: 14 # actual size is 16 25 | in_channels: 4 26 | out_channels: 4 27 | center_input_sample: false 28 | time_embedding_type: positional 29 | freq_shift: 0 30 | flip_sin_to_cos: true 31 | down_block_types: 32 | - AttnDownBlock2D 33 | - AttnDownBlock2D 34 | - AttnDownBlock2D 35 | - DownBlock2D 36 | up_block_types: 37 | - UpBlock2D 38 | - AttnUpBlock2D 39 | - AttnUpBlock2D 40 | - AttnUpBlock2D 41 | block_out_channels: 42 | - 128 43 | - 256 44 | - 256 45 | - 512 46 | layers_per_block: 2 47 | mid_block_scale_factor: 1 48 | downsample_padding: 1 49 | downsample_type: resnet 50 | upsample_type: resnet 51 | dropout: 0.0 52 | act_fn: silu 53 | attention_head_dim: 8 54 | norm_num_groups: 32 55 | attn_norm_num_groups: null 56 | norm_eps: 1e-05 57 | resnet_time_scale_shift: "default" 58 | class_embed_type: null 59 | num_class_embeds: null 60 | 61 | noise_scheduler: 62 | _class_name: DDPMScheduler 63 | num_train_timesteps: 1000 64 | beta_start: 0.0001 65 | beta_end: 0.02 66 | beta_schedule: linear # linear, scaled_linear, or squaredcos_cap_v2 67 | variance_type: fixed_small # fixed_small, fixed_small_log, fixed_large, fixed_large_log, learned or learned_range 68 | clip_sample: true 69 | clip_sample_range: 4.0 # default 1 70 | prediction_type: v_prediction # epsilon, sample, v_prediction 71 | thresholding: false # do not touch 72 | dynamic_thresholding_ratio: 0.995 # unused 73 | sample_max_value: 1.0 # unused 74 | timestep_spacing: "leading" # 75 | steps_offset: 0 # unused 76 | 77 | train_batch_size: 256 78 | dataloader_num_workers: 16 79 | max_train_steps: 500000 80 | 81 | learning_rate: 3e-4 82 | lr_warmup_steps: 500 83 | scale_lr: false 84 | lr_scheduler: constant 85 | use_8bit_adam: false 86 | gradient_accumulation_steps: 1 87 | 88 | noise_offset: 0.0 89 | 90 | gradient_checkpointing: false 91 | use_ema: true 92 | enable_xformers_memory_efficient_attention: false 93 | allow_tf32: true 94 | 95 | adam_beta1: 0.9 96 | adam_beta2: 0.999 97 | adam_weight_decay: 1e-2 98 | adam_epsilon: 1e-08 99 | max_grad_norm: 1.0 100 | 101 | logging_dir: logs 102 | mixed_precision: "fp16" # "no", "fp16", "bf16" 103 | 104 | validation_timesteps: 128 105 | 
validation_fps: ${globals.target_fps} 106 | validation_frames: ${globals.target_nframes} 107 | validation_count: 4 # defines the number of samples 108 | validation_guidance: 1.0 109 | validation_steps: 2500 110 | 111 | report_to: wandb 112 | checkpointing_steps: 10000 # ~3/hour 113 | checkpoints_total_limit: 100 # no limit 114 | resume_from_checkpoint: latest 115 | tracker_project_name: echosyn 116 | 117 | seed: 42 118 | 119 | -------------------------------------------------------------------------------- /echosyn/lidm/configs/ped_psax.yaml: -------------------------------------------------------------------------------- 1 | wandb_group: "lidm" 2 | output_dir: experiments/lidm_ped_psax 3 | 4 | pretrained_model_name_or_path: null 5 | vae_path: models/vae 6 | 7 | globals: 8 | target_fps: "original" 9 | target_nframes: 32 10 | outputs: ["image"] 11 | 12 | datasets: 13 | - name: Latent 14 | active: true 15 | params: 16 | root: data/latents/ped_psax 17 | target_fps: ${globals.target_fps} 18 | target_nframes: ${globals.target_nframes} 19 | target_resolution: 14 # emb resolution 20 | outputs: ${globals.outputs} 21 | 22 | unet: 23 | _class_name: UNet2DModel 24 | sample_size: 14 # actual size is 16 25 | in_channels: 4 26 | out_channels: 4 27 | center_input_sample: false 28 | time_embedding_type: positional 29 | freq_shift: 0 30 | flip_sin_to_cos: true 31 | down_block_types: 32 | - AttnDownBlock2D 33 | - AttnDownBlock2D 34 | - AttnDownBlock2D 35 | - DownBlock2D 36 | up_block_types: 37 | - UpBlock2D 38 | - AttnUpBlock2D 39 | - AttnUpBlock2D 40 | - AttnUpBlock2D 41 | block_out_channels: 42 | - 128 43 | - 256 44 | - 256 45 | - 512 46 | layers_per_block: 2 47 | mid_block_scale_factor: 1 48 | downsample_padding: 1 49 | downsample_type: resnet 50 | upsample_type: resnet 51 | dropout: 0.0 52 | act_fn: silu 53 | attention_head_dim: 8 54 | norm_num_groups: 32 55 | attn_norm_num_groups: null 56 | norm_eps: 1e-05 57 | resnet_time_scale_shift: "default" 58 | class_embed_type: null 59 | num_class_embeds: null 60 | 61 | noise_scheduler: 62 | _class_name: DDPMScheduler 63 | num_train_timesteps: 1000 64 | beta_start: 0.0001 65 | beta_end: 0.02 66 | beta_schedule: linear # linear, scaled_linear, or squaredcos_cap_v2 67 | variance_type: fixed_small # fixed_small, fixed_small_log, fixed_large, fixed_large_log, learned or learned_range 68 | clip_sample: true 69 | clip_sample_range: 4.0 # default 1 70 | prediction_type: v_prediction # epsilon, sample, v_prediction 71 | thresholding: false # do not touch 72 | dynamic_thresholding_ratio: 0.995 # unused 73 | sample_max_value: 1.0 # unused 74 | timestep_spacing: "leading" # 75 | steps_offset: 0 # unused 76 | 77 | train_batch_size: 256 78 | dataloader_num_workers: 16 79 | max_train_steps: 500000 80 | 81 | learning_rate: 3e-4 82 | lr_warmup_steps: 500 83 | scale_lr: false 84 | lr_scheduler: constant 85 | use_8bit_adam: false 86 | gradient_accumulation_steps: 1 87 | 88 | noise_offset: 0.0 89 | 90 | gradient_checkpointing: false 91 | use_ema: true 92 | enable_xformers_memory_efficient_attention: false 93 | allow_tf32: true 94 | 95 | adam_beta1: 0.9 96 | adam_beta2: 0.999 97 | adam_weight_decay: 1e-2 98 | adam_epsilon: 1e-08 99 | max_grad_norm: 1.0 100 | 101 | logging_dir: logs 102 | mixed_precision: "fp16" # "no", "fp16", "bf16" 103 | 104 | validation_timesteps: 128 105 | validation_fps: ${globals.target_fps} 106 | validation_frames: ${globals.target_nframes} 107 | validation_count: 4 # defines the number of samples 108 | validation_guidance: 1.0 109 | validation_steps: 2500 
110 | 111 | report_to: wandb 112 | checkpointing_steps: 10000 # ~3/hour 113 | checkpoints_total_limit: 100 # no limit 114 | resume_from_checkpoint: latest 115 | tracker_project_name: echosyn 116 | 117 | seed: 42 118 | 119 | -------------------------------------------------------------------------------- /echosyn/lidm/sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import shutil 6 | import json 7 | from glob import glob 8 | from einops import rearrange 9 | from omegaconf import OmegaConf 10 | import numpy as np 11 | from tqdm import tqdm 12 | from packaging import version 13 | from functools import partial 14 | from PIL import Image 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.utils.checkpoint 20 | from torch.utils.data import DataLoader, Dataset 21 | from torchvision import transforms 22 | 23 | import diffusers 24 | from diffusers import AutoencoderKL, DDPMScheduler, UNet3DConditionModel, UNetSpatioTemporalConditionModel, DDIMScheduler 25 | 26 | from echosyn.common.datasets import instantiate_dataset 27 | from echosyn.common import ( 28 | padf, unpadf, 29 | load_model 30 | ) 31 | 32 | """ 33 | python echosyn/lidm/sample.py \ 34 | --config echosyn/lidm/configs/dynamic.yaml \ 35 | --unet experiments/lidm_dynamic/checkpoint-500000/unet_ema \ 36 | --vae models/vae \ 37 | --output samples/dynamic \ 38 | --num_samples 50000 \ 39 | --batch_size 128 \ 40 | --num_steps 64 \ 41 | --save_latent \ 42 | --seed 0 43 | """ 44 | 45 | if __name__ == "__main__": 46 | # 1 - Parse command line arguments 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument("--config", type=str, default=None, help="Path to config file.") 49 | parser.add_argument("--unet", type=str, default=None, help="Path unet checkpoint.") 50 | parser.add_argument("--vae", type=str, default=None, help="Path vae checkpoint.") 51 | parser.add_argument("--output", type=str, default='.', help="Output directory.") 52 | parser.add_argument("--num_samples", type=int, default=8, help="Number of samples to generate.") 53 | parser.add_argument("--batch_size", type=int, default=8, help="Batch size.") 54 | parser.add_argument("--num_steps", type=int, default=128, help="Number of steps.") 55 | parser.add_argument("--seed", type=int, default=None, help="Random seed.") 56 | parser.add_argument("--save_latent", action="store_true", help="Save latents.") 57 | parser.add_argument("--ddim", action="store_true", help="Save video.") 58 | args = parser.parse_args() 59 | 60 | config = OmegaConf.load(args.config) 61 | 62 | # 2 - Load models 63 | unet = load_model(args.unet) 64 | vae = load_model(args.vae) 65 | 66 | # 3 - Load scheduler 67 | scheduler_kwargs = OmegaConf.to_container(config.noise_scheduler) 68 | scheduler_klass_name = scheduler_kwargs.pop("_class_name") 69 | if args.ddim: 70 | print("Using DDIMScheduler") 71 | scheduler_klass_name = "DDIMScheduler" 72 | scheduler_kwargs.pop("variance_type") 73 | scheduler_klass = getattr(diffusers, scheduler_klass_name, None) 74 | assert scheduler_klass is not None, f"Could not find scheduler class {scheduler_klass_name}" 75 | scheduler = scheduler_klass(**scheduler_kwargs) 76 | scheduler.set_timesteps(args.num_steps) 77 | timesteps = scheduler.timesteps 78 | 79 | # 5 - Setup 80 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 81 | dtype = torch.float32 82 | generator = 
torch.Generator(device=device).manual_seed(config.seed) if config.seed is not None else None 83 | unet = unet.to(device, dtype) 84 | vae = vae.to(device, torch.float32) 85 | unet.eval() 86 | vae.eval() 87 | 88 | format_input = padf 89 | format_output = unpadf 90 | 91 | B, C, H, W = args.batch_size, config.unet.out_channels, config.unet.sample_size, config.unet.sample_size 92 | 93 | forward_kwargs = { 94 | "timestep": -1, 95 | } 96 | 97 | sample_index = 0 98 | 99 | os.makedirs(args.output, exist_ok=True) 100 | os.makedirs(os.path.join(args.output, "images"), exist_ok=True) 101 | if args.save_latent: 102 | os.makedirs(os.path.join(args.output, "latents"), exist_ok=True) 103 | finished = False 104 | 105 | # 6 - Generate samples 106 | with torch.no_grad(): 107 | for _ in tqdm(range(int(np.ceil(args.num_samples/args.batch_size)))): 108 | if finished: 109 | break 110 | 111 | latents = torch.randn((B, C, H, W), device=device, dtype=dtype, generator=generator) 112 | 113 | with torch.autocast("cuda"): 114 | for t in timesteps: 115 | forward_kwargs["timestep"] = t 116 | latent_model_input = latents 117 | latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t) 118 | latent_model_input, padding = format_input(latent_model_input, mult=3) 119 | noise_pred = unet(latent_model_input, **forward_kwargs).sample 120 | noise_pred = format_output(noise_pred, pad=padding) 121 | latents = scheduler.step(noise_pred, t, latents).prev_sample 122 | 123 | if args.save_latent: 124 | latents_clean = latents.clone() 125 | 126 | # VAE decode 127 | rep = [1, 3, 1, 1] 128 | latents = latents / vae.config.scaling_factor 129 | images = vae.decode(latents.float()).sample 130 | images = (images + 1) * 128 # [-1, 1] -> [0, 256] 131 | 132 | # grayscale 133 | images = images.mean(1).unsqueeze(1).repeat(*rep) 134 | 135 | images = images.clamp(0, 255).to(torch.uint8).cpu() 136 | images = rearrange(images, 'b c h w -> b h w c') 137 | 138 | # 7 - Save samples 139 | images = images.numpy() 140 | for j in range(B): 141 | 142 | Image.fromarray(images[j]).save(os.path.join(args.output, "images", f"sample_{sample_index:06d}.jpg")) 143 | if args.save_latent: 144 | torch.save(latents_clean[j].clone(), os.path.join(args.output, "latents", f"sample_{sample_index:06d}.pt")) 145 | 146 | sample_index += 1 147 | if sample_index >= args.num_samples: 148 | finished = True 149 | break 150 | 151 | print(f"Finished generating {sample_index} samples.") 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /echosyn/lidm/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import shutil 6 | from einops import rearrange 7 | from omegaconf import OmegaConf 8 | import numpy as np 9 | from tqdm.auto import tqdm 10 | from packaging import version 11 | from functools import partial 12 | from copy import deepcopy 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.utils.checkpoint 18 | from torch.utils.data import DataLoader, Dataset 19 | from torchvision import transforms 20 | 21 | import accelerate 22 | from accelerate import Accelerator 23 | from accelerate.logging import get_logger 24 | from accelerate.state import AcceleratorState 25 | from accelerate.utils import ProjectConfiguration, set_seed 26 | 27 | import diffusers 28 | from diffusers import AutoencoderKL, 
DDPMScheduler, UNet3DConditionModel, UNetSpatioTemporalConditionModel 29 | from diffusers.optimization import get_scheduler 30 | from diffusers.training_utils import EMAModel, compute_snr 31 | from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid 32 | from diffusers.utils.import_utils import is_xformers_available 33 | 34 | from echosyn.common.datasets import instantiate_dataset 35 | from echosyn.common import ( 36 | padf, unpadf, 37 | instantiate_from_config 38 | ) 39 | 40 | if is_wandb_available(): 41 | import wandb 42 | 43 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 44 | check_min_version("0.22.0.dev0") 45 | 46 | logger = get_logger(__name__, log_level="INFO") 47 | 48 | def log_validation( 49 | config, 50 | unet, 51 | vae, 52 | scheduler, 53 | accelerator, 54 | weight_dtype, 55 | epoch, 56 | val_dataset 57 | ): 58 | logger.info("Running validation... ") 59 | 60 | val_unet = accelerator.unwrap_model(unet) 61 | val_vae = vae.to(accelerator.device, dtype=torch.float32) 62 | scheduler.set_timesteps(config.validation_timesteps) 63 | timesteps = scheduler.timesteps 64 | 65 | if config.enable_xformers_memory_efficient_attention: 66 | val_unet.enable_xformers_memory_efficient_attention() 67 | 68 | if config.seed is None: 69 | generator = None 70 | else: 71 | generator = torch.Generator(device=accelerator.device).manual_seed(config.seed) 72 | 73 | indices = np.random.choice(len(val_dataset), size=config.validation_count, replace=False) 74 | ref_elements = [val_dataset[i] for i in indices] 75 | ref_frames = [e['image'] for e in ref_elements] 76 | ref_frames = torch.stack(ref_frames, dim=0) # B x C x H x W 77 | ref_frames = ref_frames.to(accelerator.device, dtype=weight_dtype) 78 | 79 | format_input = padf 80 | format_output = unpadf 81 | 82 | logger.info("Sampling... ") 83 | with torch.no_grad(), torch.autocast("cuda"): 84 | # prepare model inputs 85 | B, C, H, W = config.validation_count, 4, config.unet.sample_size, config.unet.sample_size 86 | latents = torch.randn((B, C, H, W), device=accelerator.device, dtype=weight_dtype, generator=generator) 87 | 88 | forward_kwargs = { 89 | "timestep": timesteps, 90 | } 91 | 92 | # reverse diffusionn loop 93 | for t in timesteps: 94 | latent_model_input = latents 95 | latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t) 96 | latent_model_input, padding = format_input(latent_model_input, mult=3) 97 | noise_pred = unet(latent_model_input, t).sample 98 | noise_pred = format_output(noise_pred, pad=padding) 99 | latents = scheduler.step(noise_pred, t, latents).prev_sample 100 | 101 | # VAE decoding 102 | with torch.no_grad(): # no autocast 103 | latents = latents / val_vae.config.scaling_factor 104 | images = val_vae.decode(latents.float()).sample 105 | images = (images + 1) * 128 # [-1, 1] -> [0, 256] 106 | images = images.clamp(0, 255).to(torch.uint8).cpu() 107 | 108 | ref_frames = ref_frames / val_vae.config.scaling_factor 109 | ref_frames = val_vae.decode(ref_frames.float()).sample 110 | ref_frames = (ref_frames + 1) * 128 # [-1, 1] -> [0, 256] 111 | ref_frames = ref_frames.clamp(0, 255).to(torch.uint8).cpu() 112 | 113 | images = torch.cat([ref_frames, images], dim=2) # B x C x (2 H) x W // vertical concat 114 | 115 | # reshape for wandb 116 | images = rearrange(images, "b c h w -> h (b w) c") # prepare for wandb 117 | images = images.numpy() 118 | 119 | logger.info("Done sampling... 
") 120 | 121 | for tracker in accelerator.trackers: 122 | if tracker.name == "wandb": 123 | tracker.log({"validation": wandb.Image(images)}) 124 | logger.info("Samples sent to wandb.") 125 | else: 126 | logger.warn(f"image logging not implemented for {tracker.name}") 127 | 128 | del val_unet 129 | del val_vae 130 | torch.cuda.empty_cache() 131 | 132 | return images 133 | 134 | def parse_args(): 135 | parser = argparse.ArgumentParser(description="") 136 | parser.add_argument("--config", type=str, help="Path to the config file.") 137 | args = parser.parse_args() 138 | return args 139 | 140 | def main(): 141 | args = parse_args() 142 | config = OmegaConf.load(args.config) 143 | 144 | # Setup accelerator 145 | logging_dir = os.path.join(config.output_dir, config.logging_dir) 146 | 147 | accelerator_project_config = ProjectConfiguration(project_dir=config.output_dir, logging_dir=logging_dir) 148 | 149 | accelerator = Accelerator( 150 | gradient_accumulation_steps=config.gradient_accumulation_steps, 151 | mixed_precision=config.mixed_precision, 152 | log_with=config.report_to, 153 | project_config=accelerator_project_config, 154 | ) 155 | 156 | # Make one log on every process with the configuration for debugging. 157 | logging.basicConfig( 158 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 159 | datefmt="%m/%d/%Y %H:%M:%S", 160 | level=logging.INFO, 161 | ) 162 | logger.info(accelerator.state, main_process_only=False) 163 | if accelerator.is_local_main_process: 164 | diffusers.utils.logging.set_verbosity_info() 165 | else: 166 | diffusers.utils.logging.set_verbosity_error() 167 | 168 | # If passed along, set the training seed now. 169 | if config.seed is not None: 170 | set_seed(config.seed) 171 | 172 | # Handle the repository creation 173 | if accelerator.is_main_process: 174 | if config.output_dir is not None: 175 | os.makedirs(config.output_dir, exist_ok=True) 176 | 177 | noise_scheduler_kwargs = OmegaConf.to_container(config.noise_scheduler, resolve=True) 178 | noise_scheduler_klass_name = noise_scheduler_kwargs.pop("_class_name") 179 | noise_scheduler_klass = globals().get(noise_scheduler_klass_name, None) 180 | assert noise_scheduler_klass is not None, f"Could not find class {noise_scheduler_klass_name}" 181 | noise_scheduler = noise_scheduler_klass(**noise_scheduler_kwargs) 182 | 183 | vae = AutoencoderKL.from_pretrained(config.vae_path).cpu() 184 | 185 | # Create the video unet 186 | unet, unet_klass, unet_kwargs = instantiate_from_config(config.unet, ["diffusers"], return_klass_kwargs=True) 187 | noise_scheduler = instantiate_from_config(config.noise_scheduler, ["diffusers"]) 188 | 189 | format_input = padf 190 | format_output = unpadf 191 | 192 | # Freeze vae and text_encoder and set unet to trainable 193 | vae.requires_grad_(False) 194 | unet.train() 195 | 196 | # Create EMA for the unet. 
197 | if config.use_ema: 198 | ema_unet = unet_klass(**unet_kwargs) 199 | ema_unet = EMAModel(ema_unet.parameters(), model_cls=unet_klass, model_config=ema_unet.config) 200 | 201 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 202 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 203 | def save_model_hook(models, weights, output_dir): 204 | if accelerator.is_main_process: 205 | if config.use_ema: 206 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) 207 | 208 | for i, model in enumerate(models): 209 | model.save_pretrained(os.path.join(output_dir, "unet")) 210 | 211 | # make sure to pop weight so that corresponding model is not saved again 212 | weights.pop() 213 | 214 | def load_model_hook(models, input_dir): 215 | if config.use_ema: 216 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), unet_klass) 217 | ema_unet.load_state_dict(load_model.state_dict()) 218 | ema_unet.to(accelerator.device) 219 | del load_model 220 | 221 | for i in range(len(models)): 222 | # pop models so that they are not loaded again 223 | model = models.pop() 224 | 225 | # load diffusers style into model 226 | load_model = unet_klass.from_pretrained(input_dir, subfolder="unet") 227 | model.register_to_config(**load_model.config) 228 | 229 | model.load_state_dict(load_model.state_dict()) 230 | del load_model 231 | 232 | accelerator.register_save_state_pre_hook(save_model_hook) 233 | accelerator.register_load_state_pre_hook(load_model_hook) 234 | 235 | if config.gradient_checkpointing: 236 | unet.enable_gradient_checkpointing() 237 | 238 | # Enable TF32 for faster training on Ampere GPUs, 239 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 240 | if config.allow_tf32: 241 | torch.backends.cuda.matmul.allow_tf32 = True 242 | 243 | optimizer = torch.optim.AdamW( 244 | unet.parameters(), 245 | lr=config.learning_rate, 246 | betas=(config.adam_beta1, config.adam_beta2), 247 | weight_decay=config.adam_weight_decay, 248 | eps=config.adam_epsilon, 249 | ) 250 | 251 | train_dataset = instantiate_dataset(config.datasets, split=["TRAIN"]) 252 | val_dataset = instantiate_dataset(config.datasets, split=["VAL"]) 253 | 254 | # DataLoaders creation: 255 | train_dataloader = torch.utils.data.DataLoader( 256 | train_dataset, 257 | shuffle=True, 258 | batch_size=config.train_batch_size, 259 | num_workers=config.dataloader_num_workers, 260 | ) 261 | 262 | # Scheduler and math around the number of training steps. 263 | overrode_max_train_steps = False 264 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps) 265 | if config.max_train_steps is None: 266 | config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch 267 | overrode_max_train_steps = True 268 | 269 | lr_scheduler = get_scheduler( 270 | config.lr_scheduler, 271 | optimizer=optimizer, 272 | num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes, 273 | num_training_steps=config.max_train_steps * accelerator.num_processes, 274 | ) 275 | 276 | # Prepare everything with our `accelerator`. 
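# (accelerator.prepare wraps the model, optimizer, dataloader and LR scheduler for the configured setup: device placement, mixed precision and, when launched with multiple processes, distributed data parallelism.)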
277 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 278 | unet, optimizer, train_dataloader, lr_scheduler 279 | ) 280 | 281 | if config.use_ema: 282 | ema_unet.to(accelerator.device) 283 | 284 | # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision 285 | # as these weights are only used for inference, keeping weights in full precision is not required. 286 | weight_dtype = torch.float32 287 | if accelerator.mixed_precision == "fp16": 288 | weight_dtype = torch.float16 289 | config.mixed_precision = accelerator.mixed_precision 290 | elif accelerator.mixed_precision == "bf16": 291 | weight_dtype = torch.bfloat16 292 | config.mixed_precision = accelerator.mixed_precision 293 | 294 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 295 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps) 296 | if overrode_max_train_steps: 297 | config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch 298 | # Afterwards we recalculate our number of training epochs 299 | config.num_train_epochs = math.ceil(config.max_train_steps / num_update_steps_per_epoch) 300 | 301 | # We need to initialize the trackers we use, and also store our configuration. 302 | # The trackers initializes automatically on the main process. 303 | if accelerator.is_main_process: 304 | tracker_config = OmegaConf.to_container(config, resolve=True) 305 | accelerator.init_trackers( 306 | config.tracker_project_name, 307 | tracker_config, 308 | init_kwargs={ 309 | "wandb": { 310 | "group": config.wandb_group 311 | }, 312 | }, 313 | ) 314 | 315 | # Train! 316 | total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps 317 | model_num_params = sum(p.numel() for p in unet.parameters()) 318 | model_trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad) 319 | 320 | logger.info("***** Running training *****") 321 | logger.info(f" Num examples = {len(train_dataset)}") 322 | logger.info(f" Num Epochs = {config.num_train_epochs}") 323 | logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") 324 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 325 | logger.info(f" Gradient Accumulation steps = {config.gradient_accumulation_steps}") 326 | logger.info(f" Total optimization steps = {config.max_train_steps}") 327 | logger.info(f" U-Net: Total params = {model_num_params} \t Trainable params = {model_trainable_params} ({model_trainable_params/model_num_params*100:.2f}%)") 328 | global_step = 0 329 | first_epoch = 0 330 | 331 | # Potentially load in the weights and states from a previous save 332 | if config.resume_from_checkpoint: 333 | if config.resume_from_checkpoint != "latest": 334 | path = os.path.basename(config.resume_from_checkpoint) 335 | else: 336 | # Get the most recent checkpoint 337 | dirs = os.listdir(config.output_dir) 338 | dirs = [d for d in dirs if d.startswith("checkpoint")] 339 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 340 | path = dirs[-1] if len(dirs) > 0 else None 341 | 342 | if path is None: 343 | accelerator.print( 344 | f"Checkpoint '{config.resume_from_checkpoint}' does not exist. Starting a new training run." 
345 | ) 346 | config.resume_from_checkpoint = None 347 | initial_global_step = 0 348 | else: 349 | accelerator.print(f"Resuming from checkpoint {path}") 350 | accelerator.load_state(os.path.join(config.output_dir, path)) 351 | global_step = int(path.split("-")[1]) 352 | 353 | initial_global_step = global_step 354 | first_epoch = global_step // num_update_steps_per_epoch 355 | 356 | else: 357 | initial_global_step = 0 358 | 359 | progress_bar = tqdm( 360 | range(0, config.max_train_steps), 361 | initial=initial_global_step, 362 | desc="Steps", 363 | # Only show the progress bar once on each machine. 364 | # disable=not accelerator.is_local_main_process, 365 | disable=not accelerator.is_main_process, 366 | ) 367 | 368 | for epoch in range(first_epoch, config.num_train_epochs): 369 | train_loss = 0.0 370 | prediction_mean = 0.0 371 | prediction_std = 0.0 372 | target_mean = 0.0 373 | target_std = 0.0 374 | mean_losses = 0.0 375 | for step, batch in enumerate(train_dataloader): 376 | with accelerator.accumulate(unet): 377 | 378 | latents = batch['image'] # B x C x H x W 379 | 380 | B, C, H, W = latents.shape 381 | 382 | # Sample a random timestep for each video 383 | timesteps = torch.randint(0, int(noise_scheduler.config.num_train_timesteps), (B,), device=latents.device).long() 384 | 385 | # Sample noise that we'll add to the latents 386 | noise = torch.randn_like(latents) 387 | if config.noise_offset > 0: 388 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 389 | noise += config.noise_offset * torch.randn( (latents.shape[0], latents.shape[1], 1, 1), device=latents.device) 390 | 391 | if config.get('input_perturbation', 0) > 0.0: 392 | noisy_latents = noise_scheduler.add_noise(latents, noise + config.input_perturbation*torch.rand(1,).item() * torch.randn_like(noise), timesteps) 393 | else: 394 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 395 | 396 | # Predict the noise residual and compute loss 397 | noisy_latents, padding = format_input(noisy_latents, mult=3) 398 | model_pred = unet(sample=noisy_latents, timestep=timesteps).sample 399 | model_pred = format_output(model_pred, pad=padding) 400 | 401 | if noise_scheduler.config.prediction_type == "epsilon": 402 | target = noise 403 | elif noise_scheduler.config.prediction_type == "v_prediction": 404 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 405 | else: 406 | assert noise_scheduler.config.prediction_type == "sample", f"Unknown prediction type {noise_scheduler.config.prediction_type}" 407 | target = latents 408 | 409 | # loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 410 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 411 | loss = loss.mean() 412 | mean_loss = loss.item() 413 | 414 | # Gather the losses across all processes for logging (if we use distributed training). 
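# (accelerator.gather collects the per-process loss copies so the logged `train_loss` reflects the average over all GPUs; the running statistics are zeroed again right after each `accelerator.log` call below.)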
415 | avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean() 416 | train_loss += avg_loss.item() / config.gradient_accumulation_steps 417 | mean_losses += mean_loss / config.gradient_accumulation_steps 418 | prediction_mean += model_pred.mean().item() / config.gradient_accumulation_steps 419 | prediction_std += model_pred.std().item() / config.gradient_accumulation_steps 420 | target_mean += target.mean().item() / config.gradient_accumulation_steps 421 | target_std += target.std().item() / config.gradient_accumulation_steps 422 | 423 | # Backpropagate 424 | accelerator.backward(loss) 425 | if accelerator.sync_gradients: 426 | accelerator.clip_grad_norm_(unet.parameters(), config.max_grad_norm) 427 | optimizer.step() 428 | lr_scheduler.step() 429 | optimizer.zero_grad() 430 | 431 | # Checks if the accelerator has performed an optimization step behind the scenes 432 | if accelerator.sync_gradients: 433 | if config.use_ema: 434 | ema_unet.step(unet.parameters()) 435 | progress_bar.update(1) 436 | global_step += 1 437 | accelerator.log({ 438 | "train_loss": train_loss, 439 | "prediction_mean": prediction_mean, 440 | "prediction_std": prediction_std, 441 | "target_mean": target_mean, 442 | "target_std": target_std, 443 | "mean_losses": mean_losses, 444 | }, step=global_step) 445 | train_loss = 0.0 446 | prediction_mean = 0.0 447 | prediction_std = 0.0 448 | target_mean = 0.0 449 | target_std = 0.0 450 | mean_losses = 0.0 451 | 452 | if global_step % config.checkpointing_steps == 0: 453 | if accelerator.is_main_process: 454 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 455 | if config.checkpoints_total_limit is not None: 456 | checkpoints = os.listdir(config.output_dir) 457 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 458 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 459 | 460 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 461 | if len(checkpoints) >= config.checkpoints_total_limit: 462 | num_to_remove = len(checkpoints) - config.checkpoints_total_limit + 1 463 | removing_checkpoints = checkpoints[0:num_to_remove] 464 | 465 | logger.info( 466 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 467 | ) 468 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 469 | 470 | for removing_checkpoint in removing_checkpoints: 471 | removing_checkpoint = os.path.join(config.output_dir, removing_checkpoint) 472 | shutil.rmtree(removing_checkpoint) 473 | 474 | save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}") 475 | accelerator.save_state(save_path) 476 | OmegaConf.save(config, os.path.join(config.output_dir, "config.yaml")) 477 | logger.info(f"Saved state to {save_path}") 478 | 479 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 480 | progress_bar.set_postfix(**logs) 481 | 482 | if global_step >= config.max_train_steps: 483 | break 484 | 485 | if accelerator.is_main_process: 486 | if global_step % config.validation_steps == 0: 487 | if config.use_ema: 488 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
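# (EMAModel.store() caches the current training weights, copy_to() swaps in the EMA weights for sampling, and restore() further below puts the training weights back.)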
489 | ema_unet.store(unet.parameters()) 490 | ema_unet.copy_to(unet.parameters()) 491 | 492 | log_validation( 493 | config, 494 | unet, 495 | vae, 496 | deepcopy(noise_scheduler), 497 | accelerator, 498 | weight_dtype, 499 | epoch, 500 | val_dataset, 501 | ) 502 | 503 | if config.use_ema: 504 | # Switch back to the original UNet parameters. 505 | ema_unet.restore(unet.parameters()) 506 | 507 | accelerator.end_training() 508 | 509 | 510 | if __name__ == "__main__": 511 | main() -------------------------------------------------------------------------------- /echosyn/lvdm/README.md: -------------------------------------------------------------------------------- 1 | # Latent Video Diffusion Model 2 | 3 | The Latent Video Diffusion Model (LVDM) is responsible for animating the latent representation of a heart generated by the Latent Image Diffusion Model (LIDM). The LVDM is trained on the VAE-encoded videos. We condition it on an encoded frame and an ejection fraction score, and train it to reconstruct the video corresponding to that frame and ejection fraction. 4 | 5 | During inference, it can animate any heart, real or synthetic, by conditioning on the latent representation of the heart and the desired ejection fraction. 6 | 7 | ## 1. Activate the environment 8 | 9 | First, activate the echosyn environment. 10 | 11 | ```bash 12 | conda activate echosyn 13 | ``` 14 | 15 | ## 2. Data preparation 16 | Follow the instructions in the [Data preparation](../../README.md#data-preparation) section to prepare the data for training. Here, you need the VAE-encoded videos. 17 | 18 | ## 3. Train the LVDM 19 | 20 | Once the environment is set up and the data is ready, you can train the LVDM with the following command: 21 | 22 | ```bash 23 | python echosyn/lvdm/train.py --config echosyn/lvdm/configs/default.yaml 24 | ``` 25 | 26 | or this one, for multi-GPU training: 27 | 28 | ```bash 29 | accelerate launch \ 30 | --num_processes 8 \ 31 | --multi_gpu \ 32 | --mixed_precision fp16 \ 33 | echosyn/lvdm/train.py \ 34 | --config echosyn/lvdm/configs/default.yaml 35 | ``` 36 | 37 | Note that we train a single model for all datasets. The model is conditioned on the frames, so the same model can generate videos for any dataset. 38 | 39 | ## 4. Sample from the LVDM 40 | 41 | Once the LVDM is trained, you can sample from it with the following command: 42 | 43 | ```bash 44 | python echosyn/lvdm/sample.py \ 45 | --config echosyn/lvdm/configs/default.yaml \ 46 | --unet experiments/lvdm/checkpoint-500000/unet_ema \ 47 | --vae models/vae \ 48 | --conditioning samples/lidm_dynamic/privacy_compliant_latents \ 49 | --output samples/lvdm_dynamic \ 50 | --num_samples 2048 \ 51 | --batch_size 8 \ 52 | --num_steps 64 \ 53 | --min_lvef 10 \ 54 | --max_lvef 90 \ 55 | --save_as avi,jpg \ 56 | --frames 192 57 | ``` 58 | 59 | This will generate 2048 videos of 192 frames each, conditioned on the synthetic, privacy-compliant latent representations of the heart and on uniformly sampled ejection fraction scores. The videos will be saved in the `samples/lvdm_dynamic/avi` and `samples/lvdm_dynamic/jpg` directories. 60 | 61 | ## 5. Evaluate the LVDM 62 | 63 | To evaluate the LVDM, we use the FID, FVD16, FVD128 and IS scores. 64 | To do so, we need to generate 2048 videos with 192 frames for the dynamic dataset or 128 frames for the two pediatric datasets. 65 | The outputs MUST be in jpg format. 66 | The samples are compared to the real samples, which are generated in the [Data preparation](../../README.md#data-preparation) step. 
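As a quick sanity check before computing the metrics, you can verify the number and length of the exported clips. The sketch below is only illustrative: it assumes each clip is exported as its own folder of jpg frames under the `jpg` output directory, so adjust the paths to your actual layout.

```python
import os
from glob import glob

root = "samples/lvdm_dynamic/jpg"  # adjust to the output directory you used
clips = sorted(os.listdir(root))
print(f"{len(clips)} clips found")  # expect 2048

for clip in clips[:5]:
    frames = glob(os.path.join(root, clip, "*.jpg"))
    print(clip, len(frames), "frames")  # expect 192 (Dynamic) or 128 (Pediatric)
```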
67 | 68 | Then, to evaluate the synthetic videos, run the following commands: 69 | 70 | ```bash 71 | cd external/stylegan-v 72 | 73 | python src/scripts/calc_metrics_for_dataset.py \ 74 | --real_data_path ../../data/reference/dynamic \ 75 | --fake_data_path ../../samples/lvdm_dynamic/jpg \ 76 | --mirror 0 --gpus 1 --resolution 112 \ 77 | --metrics fvd2048_16f,fvd2048_128f,fid50k_full,is50k >> "../../samples/lvdm_dynamic/metrics.txt" 78 | 79 | python src/scripts/calc_metrics_for_dataset.py \ 80 | --real_data_path ../../data/reference/ped_a4c \ 81 | --fake_data_path ../../samples/lvdm_ped_a4c/jpg \ 82 | --mirror 0 --gpus 1 --resolution 112 \ 83 | --metrics fvd2048_16f,fvd2048_128f,fid50k_full,is50k >> "../../samples/lvdm_ped_a4c/metrics.txt" 84 | 85 | python src/scripts/calc_metrics_for_dataset.py \ 86 | --real_data_path ../../data/reference/ped_psax \ 87 | --fake_data_path ../../samples/lvdm_ped_psax/jpg \ 88 | --mirror 0 --gpus 1 --resolution 112 \ 89 | --metrics fvd2048_16f,fvd2048_128f,fid50k_full,is50k >> "../../samples/lvdm_ped_psax/metrics.txt" 90 | ``` 91 | 92 | ## 6. Save the LVDM for later use 93 | 94 | Once the LVDM is trained, you can save it for later use with the following command: 95 | 96 | ```bash 97 | mkdir -p models/lvdm; cp -r experiments/lvdm/checkpoint-500000/unet_ema/* models/lvdm/; cp experiments/lvdm/config.yaml models/lvdm/ 98 | ``` 99 | 100 | This will save the selected ema version of the model, ready to be loaded in any other script as a standalone model. -------------------------------------------------------------------------------- /echosyn/lvdm/configs/default.yaml: -------------------------------------------------------------------------------- 1 | wandb_group: "lvdm" 2 | output_dir: experiments/lvdm # experiments/lvdm_vpred_64 3 | 4 | pretrained_model_name_or_path: null 5 | vae_path: models/vae 6 | 7 | globals: 8 | target_fps: 32 9 | target_nframes: 64 10 | outputs: ["video", "lvef", "image"] 11 | 12 | 13 | datasets: 14 | - name: Latent 15 | active: true 16 | params: 17 | root: data/latents/dynamic 18 | target_fps: ${globals.target_fps} 19 | target_nframes: ${globals.target_nframes} 20 | target_resolution: 14 # emb resolution 21 | outputs: ${globals.outputs} 22 | 23 | - name: Latent 24 | active: true 25 | params: 26 | root: data/latents/ped_a4c 27 | target_fps: ${globals.target_fps} 28 | target_nframes: ${globals.target_nframes} 29 | target_resolution: 14 # emb resolution 30 | outputs: ${globals.outputs} 31 | 32 | - name: Latent 33 | active: true 34 | params: 35 | root: data/latents/ped_psax 36 | target_fps: ${globals.target_fps} 37 | target_nframes: ${globals.target_nframes} 38 | target_resolution: 14 # emb resolution 39 | outputs: ${globals.outputs} 40 | 41 | unet: 42 | _class_name: UNetSpatioTemporalConditionModel 43 | addition_time_embed_dim: 1 44 | block_out_channels: 45 | - 128 46 | - 256 47 | - 256 48 | - 512 49 | cross_attention_dim: 1 50 | down_block_types: 51 | - CrossAttnDownBlockSpatioTemporal 52 | - CrossAttnDownBlockSpatioTemporal 53 | - CrossAttnDownBlockSpatioTemporal 54 | - DownBlockSpatioTemporal 55 | in_channels: 8 56 | layers_per_block: 2 57 | num_attention_heads: 58 | - 8 59 | - 16 60 | - 16 61 | - 32 62 | num_frames: ${globals.target_nframes} 63 | out_channels: 4 64 | projection_class_embeddings_input_dim: 1 65 | sample_size: 14 66 | transformer_layers_per_block: 1 67 | up_block_types: 68 | - UpBlockSpatioTemporal 69 | - CrossAttnUpBlockSpatioTemporal 70 | - CrossAttnUpBlockSpatioTemporal 71 | - CrossAttnUpBlockSpatioTemporal 72 | 73 | 
noise_scheduler: 74 | _class_name: DDPMScheduler 75 | num_train_timesteps: 1000 76 | beta_start: 0.0001 77 | beta_end: 0.02 78 | beta_schedule: linear # linear, scaled_linear, or squaredcos_cap_v2 79 | variance_type: fixed_small # fixed_small, fixed_small_log, fixed_large, fixed_large_log, learned or learned_range 80 | clip_sample: true 81 | clip_sample_range: 4.0 # default 1 82 | prediction_type: v_prediction # epsilon, sample, v_prediction 83 | thresholding: false # do not touch 84 | dynamic_thresholding_ratio: 0.995 # unused 85 | sample_max_value: 1.0 # unused 86 | timestep_spacing: "leading" # 87 | steps_offset: 0 # unused 88 | 89 | train_batch_size: 16 90 | dataloader_num_workers: 16 91 | max_train_steps: 500000 92 | 93 | learning_rate: 1e-4 94 | lr_warmup_steps: 500 95 | scale_lr: false 96 | lr_scheduler: constant 97 | use_8bit_adam: false 98 | gradient_accumulation_steps: 1 99 | 100 | noise_offset: 0.1 101 | drop_conditionning: 0.1 # 10 % of the time, the LVEF conditionning is dropped 102 | 103 | gradient_checkpointing: false 104 | use_ema: true 105 | enable_xformers_memory_efficient_attention: false 106 | allow_tf32: true 107 | 108 | adam_beta1: 0.9 109 | adam_beta2: 0.999 110 | adam_weight_decay: 1e-2 111 | adam_epsilon: 1e-08 112 | max_grad_norm: 1.0 113 | 114 | logging_dir: logs 115 | mixed_precision: "fp16" # "no", "fp16", "bf16" 116 | 117 | validation_timesteps: 128 118 | validation_fps: ${globals.target_fps} 119 | validation_frames: ${globals.target_nframes} 120 | validation_lvefs: [0.0, 0.4, 0.7, 1.0] # defines the number of samples 121 | validation_guidance: 1.0 122 | validation_steps: 1500 123 | 124 | report_to: wandb 125 | checkpointing_steps: 10000 # ~3/hour 126 | checkpoints_total_limit: 100 # no limit 127 | resume_from_checkpoint: latest 128 | tracker_project_name: echosyn 129 | 130 | seed: 42 131 | 132 | -------------------------------------------------------------------------------- /echosyn/lvdm/sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import shutil 6 | import json 7 | from glob import glob 8 | from einops import rearrange 9 | from omegaconf import OmegaConf 10 | import numpy as np 11 | from tqdm import tqdm 12 | from packaging import version 13 | from functools import partial 14 | from PIL import Image 15 | import pandas as pd 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torch.utils.checkpoint 21 | from torch.utils.data import DataLoader, Dataset 22 | from torchvision import transforms 23 | 24 | import diffusers 25 | from diffusers import AutoencoderKL, DDPMScheduler, UNet3DConditionModel, UNetSpatioTemporalConditionModel 26 | 27 | from echosyn.common.datasets import TensorSet, ImageSet 28 | from echosyn.common import ( 29 | pad_reshape, unpad_reshape, padf, unpadf, 30 | load_model, save_as_mp4, save_as_gif, save_as_img, save_as_avi, 31 | parse_formats, 32 | ) 33 | 34 | """ 35 | python echosyn/lvdm/sample.py \ 36 | --config echosyn/lvdm/configs/default.yaml \ 37 | --unet experiments/lvdm/checkpoint-500000/unet_ema \ 38 | --vae model/vae \ 39 | --conditioning samples/dynamic/privacy_compliant_latents \ 40 | --output samples/dynamic/privacy_compliant_samples \ 41 | --num_samples 2048 \ 42 | --batch_size 8 \ 43 | --num_steps 64 \ 44 | --min_lvef 10 \ 45 | --max_lvef 90 \ 46 | --save_as avi \ 47 | --frames 192 48 | """ 49 | 50 | if __name__ == "__main__": 51 | # 1 - Parse command line arguments 
52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("--config", type=str, default=None, help="Path to config file.") 54 | parser.add_argument("--unet", type=str, default=None, help="Path unet checkpoint.") 55 | parser.add_argument("--vae", type=str, default=None, help="Path vae checkpoint.") 56 | parser.add_argument("--conditioning", type=str, default=None, help="Path to the folder containing the conditionning latents.") 57 | parser.add_argument("--output", type=str, default='.', help="Output directory.") 58 | parser.add_argument("--num_samples", type=int, default=8, help="Number of samples to generate.") 59 | parser.add_argument("--batch_size", type=int, default=8, help="Batch size.") 60 | parser.add_argument("--num_steps", type=int, default=64, help="Number of steps.") 61 | parser.add_argument("--min_lvef", type=int, default=10, help="Minimum LVEF.") 62 | parser.add_argument("--max_lvef", type=int, default=90, help="Maximum LVEF.") 63 | parser.add_argument("--save_as", type=parse_formats, default=None, help="Save formats separated by commas (e.g., avi,jpg). Available: avi, mp4, gif, jpg, png, pt") 64 | parser.add_argument("--frames", type=int, default=192, help="Number of frames to generate. Must be a multiple of 32") 65 | parser.add_argument("--seed", type=int, default=None, help="Random seed.") 66 | 67 | args = parser.parse_args() 68 | 69 | config = OmegaConf.load(args.config) 70 | 71 | # 2 - Load models 72 | unet = load_model(args.unet) 73 | vae = load_model(args.vae) 74 | 75 | # 3 - Load scheduler 76 | scheduler_kwargs = OmegaConf.to_container(config.noise_scheduler) 77 | scheduler_klass_name = scheduler_kwargs.pop("_class_name") 78 | scheduler_klass = getattr(diffusers, scheduler_klass_name, None) 79 | assert scheduler_klass is not None, f"Could not find scheduler class {scheduler_klass_name}" 80 | scheduler = scheduler_klass(**scheduler_kwargs) 81 | scheduler.set_timesteps(args.num_steps) 82 | timesteps = scheduler.timesteps 83 | 84 | # 4 - Load dataset 85 | ## detect type of conditioning: 86 | file_ext = os.listdir(args.conditioning)[0].split(".")[-1].lower() 87 | assert file_ext in ["pt", "jpg", "png"], f"Conditioning files must be either .pt, .jpg or .png, not {file_ext}" 88 | if file_ext == "pt": 89 | dataset = TensorSet(args.conditioning) 90 | else: 91 | dataset = ImageSet(args.conditioning, ext=file_ext) 92 | assert len(dataset) > 0, f"No files found in {args.conditioning} with extension {file_ext}" 93 | 94 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, drop_last=True) 95 | 96 | # 5 - Setup 97 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 98 | dtype = torch.float32 99 | generator = torch.Generator(device=device).manual_seed(config.seed) if config.seed is not None else None 100 | unet = unet.to(device, dtype) 101 | vae = vae.to(device, torch.float32) 102 | unet.eval() 103 | vae.eval() 104 | 105 | format_input = pad_reshape if config.unet._class_name == "UNetSpatioTemporalConditionModel" else padf 106 | format_output = unpad_reshape if config.unet._class_name == "UNetSpatioTemporalConditionModel" else unpadf 107 | 108 | B, C, T, H, W = args.batch_size, config.unet.out_channels, config.unet.num_frames, config.unet.sample_size, config.unet.sample_size 109 | fps = config.globals.target_fps 110 | # Stitching parameters 111 | args.frames = int(np.ceil(args.frames/32) * 32) 112 | if args.frames > T: 113 | OT = T//2 # overlap 64//2 114 | TR = (args.frames - T) / 32 # total frames (192 - 64) / 32 = 4 
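# Stitching arithmetic, illustrated with the default values: the UNet generates T = 64 latent frames
# at a time. To reach args.frames = 192, overlapping windows with stride T - OT = 32 are used, so
# TR = int((192 - 64) / 32 + 1) = 5 windows are denoised at every step and stitched into
# NT = (T - OT) * TR + OT = 192 latent frames. Where two windows overlap, the denoising loop below
# keeps the prediction coming from the earlier window.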
115 | TR = int(TR + 1) # total repetitions 116 | NT = (T-OT) * TR + OT # = args.frame 117 | else: 118 | OT = 0 119 | TR = 1 120 | NT = T 121 | 122 | forward_kwargs = { 123 | "timestep": -1, 124 | } 125 | 126 | if config.unet._class_name == "UNetSpatioTemporalConditionModel": 127 | dummy_added_time_ids = torch.zeros((B*TR, config.unet.addition_time_embed_dim), device=device, dtype=dtype) 128 | forward_kwargs["added_time_ids"] = dummy_added_time_ids 129 | 130 | sample_index = 0 131 | 132 | filelist = [] 133 | 134 | os.makedirs(args.output, exist_ok=True) 135 | for ext in args.save_as: 136 | os.makedirs(os.path.join(args.output, ext), exist_ok=True) 137 | finished = False 138 | 139 | pbar = tqdm(total=args.num_samples) 140 | 141 | # 6 - Generate samples 142 | with torch.no_grad(): 143 | while not finished: 144 | for cond in dataloader: 145 | if finished: 146 | break 147 | 148 | # Prepare latent noise 149 | latents = torch.randn((B, C, NT, H, W), device=device, dtype=dtype, generator=generator) 150 | 151 | # Prepare conditioning - lvef 152 | lvefs = torch.randint(args.min_lvef, args.max_lvef+1, (B,), device=device, dtype=dtype, generator=generator) 153 | lvefs = lvefs / 100.0 154 | lvefs = lvefs[:, None, None] 155 | lvefs = lvefs.repeat_interleave(TR, dim=0) 156 | forward_kwargs["encoder_hidden_states"] = lvefs 157 | # Prepare conditioning - reference frames 158 | latent_cond_images = cond.to(device, torch.float32) 159 | if file_ext != "pt": 160 | # project image to latent space 161 | latent_cond_images = vae.encode(latent_cond_images).latent_dist.sample() 162 | latent_cond_images = latent_cond_images * vae.config.scaling_factor 163 | latent_cond_images = latent_cond_images[:,:,None,:,:].repeat(1,1,NT,1,1) # B x C x T x H x W 164 | 165 | # Denoise the latent 166 | with torch.autocast("cuda"): 167 | for t in timesteps: 168 | forward_kwargs["timestep"] = t 169 | latent_model_input = scheduler.scale_model_input(latents, timestep=t) 170 | latent_model_input = torch.cat((latent_model_input, latent_cond_images), dim=1) # B x 2C x T x H x W 171 | latent_model_input, padding = format_input(latent_model_input, mult=3) # B x T x 2C x H+P x W+P 172 | 173 | # Stitching 174 | inputs = torch.cat([latent_model_input[:,r*(T-OT):r*(T-OT)+T] for r in range(TR)], dim=0) # B*TR x T x 2C x H+P x W+P 175 | noise_pred = unet(inputs, **forward_kwargs).sample 176 | outputs = torch.chunk(noise_pred, TR, dim=0) # TR x B x T x C x H x W 177 | noise_predictions = [] 178 | for r in range(TR): 179 | noise_predictions.append(outputs[r] if r == 0 else outputs[r][:,OT:]) 180 | noise_pred = torch.cat(noise_predictions, dim=1) # B x NT x C x H x W 181 | 182 | noise_pred = unpad_reshape(noise_pred, pad=padding) 183 | latents = scheduler.step(noise_pred, t, latents).prev_sample 184 | 185 | # VAE decode 186 | latents = rearrange(latents, "b c t h w -> (b t) c h w").cpu() 187 | latents = latents / vae.config.scaling_factor 188 | 189 | # Decode in chunks to save memory 190 | chunked_latents = torch.split(latents, args.batch_size, dim=0) 191 | decoded_chunks = [] 192 | for chunk in chunked_latents: 193 | decoded_chunks.append(vae.decode(chunk.float().cuda()).sample.cpu()) 194 | video = torch.cat(decoded_chunks, dim=0) # (B*T) x H x W x C 195 | 196 | # format output 197 | video = rearrange(video, "(b t) c h w -> b t h w c", b=B) 198 | video = (video + 1) * 128 199 | video = video.clamp(0, 255).to(torch.uint8) 200 | 201 | print(video.shape, video.dtype, video.min(), video.max()) 202 | file_lvefs = 
lvefs.squeeze()[::TR].mul(100).to(torch.int).tolist() 203 | # save samples 204 | for j in range(B): 205 | # FileName,EF,ESV,EDV,FrameHeight,FrameWidth,FPS,NumberOfFrames,Split 206 | filelist.append([f"sample_{sample_index:06d}", file_lvefs[j], 0, 0, video.shape[1], video.shape[2], fps, video.shape[0], "TRAIN"]) 207 | if "mp4" in args.save_as: 208 | save_as_mp4(video[j], os.path.join(args.output, "mp4", f"sample_{sample_index:06d}.mp4")) 209 | if "avi" in args.save_as: 210 | save_as_avi(video[j], os.path.join(args.output, "avi", f"sample_{sample_index:06d}.avi")) 211 | if "gif" in args.save_as: 212 | save_as_gif(video[j], os.path.join(args.output, "gif", f"sample_{sample_index:06d}.gif")) 213 | if "jpg" in args.save_as: 214 | save_as_img(video[j], os.path.join(args.output, "jpg", f"sample_{sample_index:06d}"), ext="jpg") 215 | if "png" in args.save_as: 216 | save_as_img(video[j], os.path.join(args.output, "png", f"sample_{sample_index:06d}"), ext="png") 217 | if "pt" in args.save_as: 218 | torch.save(video[j].clone(), os.path.join(args.output, "pt", f"sample_{sample_index:06d}.pt")) 219 | sample_index += 1 220 | pbar.update(1) 221 | if sample_index >= args.num_samples: 222 | finished = True 223 | break 224 | 225 | df = pd.DataFrame(filelist, columns=["FileName", "EF", "ESV", "EDV", "FrameHeight", "FrameWidth", "FPS", "NumberOfFrames", "Split"]) 226 | df.to_csv(os.path.join(args.output, "FileList.csv"), index=False) 227 | print(f"Generated {sample_index} samples.") 228 | 229 | 230 | 231 | 232 | -------------------------------------------------------------------------------- /echosyn/privacy/README.md: -------------------------------------------------------------------------------- 1 | # Re-Identification Model 2 | 3 | For this work, we use the [Latent Image Diffusion Model (LIDM)](../lidm/README.md) to generate synthetic echocardiography images. To enforce privacy, as a post-hoc step, we train a re-identification model to project real and generated images into a common latent space. This allows us to compute the similarity between two images, and by extension, to detect any synthetic images that are too similar to real ones. 4 | 5 | The re-identification models are trained on the VAE-encoded real images. We train one re-identification model per LIDM. The training takes a few hours on a single A100 GPU. 6 | 7 | ## 1. Activate the environment 8 | 9 | First, activate the echosyn environment. 10 | 11 | ```bash 12 | conda activate echosyn 13 | ``` 14 | 15 | ## 2. Data preparation 16 | Follow the instruction in the [Data preparation](../../README.md#data-preparation) to prepare the data for training. Here, you need the VAE-encoded videos. 17 | 18 | ## 3. Train the Re-Identification models 19 | Once the environment is set up and the data is ready, you can train the Re-Identification models with the following commands: 20 | 21 | ```bash 22 | python echosyn/privacy/train.py --config=echosyn/privacy/configs/config_dynamic.json 23 | python echosyn/privacy/train.py --config=echosyn/privacy/configs/config_ped_a4c.json 24 | python echosyn/privacy/train.py --config=echosyn/privacy/configs/config_ped_psax.json 25 | ``` 26 | 27 | ## 4. 
Filter the synthetic images
28 | After training the re-identification models, you can filter the synthetic images generated with the LIDMs, using the following commands:
29 | 
30 | ```bash
31 | python echosyn/privacy/apply.py \
32 | --model experiments/reidentification_dynamic \
33 | --synthetic samples/lidm_dynamic/latents \
34 | --reference data/latents/dynamic \
35 | --output samples/lidm_dynamic/privatised_latents
36 | ```
37 | 
38 | ```bash
39 | python echosyn/privacy/apply.py \
40 | --model experiments/reidentification_ped_a4c \
41 | --synthetic samples/lidm_ped_a4c/latents \
42 | --reference data/latents/ped_a4c \
43 | --output samples/lidm_ped_a4c/privatised_latents
44 | ```
45 | 
46 | ```bash
47 | python echosyn/privacy/apply.py \
48 | --model experiments/reidentification_ped_psax \
49 | --synthetic samples/lidm_ped_psax/latents \
50 | --reference data/latents/ped_psax \
51 | --output samples/lidm_ped_psax/privatised_latents
52 | ```
53 | 
54 | This script will filter out all latents that are too similar to real latents (encoded images).
55 | The filtered latents are saved in the directory specified in `output`.
56 | The similarity threshold is automatically determined by the `apply.py` script.
57 | The latents are all that's required to generate the privacy-compliant synthetic videos, because the Latent Video Diffusion Model (LVDM) is conditioned on the latents, not the images themselves.
58 | To obtain the privacy-compliant images, you can use the provided script like so:
59 | 
60 | ```bash
61 | ./scripts/copy_privacy_compliant_images.sh samples/dynamic/images samples/dynamic/privacy_compliant_latents samples/dynamic/privacy_compliant_images
62 | ./scripts/copy_privacy_compliant_images.sh samples/ped_a4c/images samples/ped_a4c/privacy_compliant_latents samples/ped_a4c/privacy_compliant_images
63 | ./scripts/copy_privacy_compliant_images.sh samples/ped_psax/images samples/ped_psax/privacy_compliant_latents samples/ped_psax/privacy_compliant_images
64 | ```
65 | 
66 | ## 5. Evaluate the remaining images
67 | 
68 | To evaluate the remaining images, we use the same process as for the LIDM, with the following commands:
69 | 
70 | ```bash
71 | 
72 | cd external/stylegan-v
73 | 
74 | # For the Dynamic dataset
75 | python src/scripts/calc_metrics_for_dataset.py \
76 | --real_data_path ../../data/reference/dynamic \
77 | --fake_data_path ../../samples/lidm_dynamic/privacy_compliant_images \
78 | --mirror 0 --gpus 1 --resolution 112 \
79 | --metrics fid50k_full,is50k >> "../../samples/lidm_dynamic/privacy_compliant_metrics.txt"
80 | 
81 | # For the Pediatric A4C dataset
82 | python src/scripts/calc_metrics_for_dataset.py \
83 | --real_data_path ../../data/reference/ped_a4c \
84 | --fake_data_path ../../samples/lidm_ped_a4c/privacy_compliant_images \
85 | --mirror 0 --gpus 1 --resolution 112 \
86 | --metrics fid50k_full,is50k >> "../../samples/lidm_ped_a4c/privacy_compliant_metrics.txt"
87 | 
88 | # For the Pediatric PSAX dataset
89 | python src/scripts/calc_metrics_for_dataset.py \
90 | --real_data_path ../../data/reference/ped_psax \
91 | --fake_data_path ../../samples/lidm_ped_psax/privacy_compliant_images \
92 | --mirror 0 --gpus 1 --resolution 112 \
93 | --metrics fid50k_full,is50k >> "../../samples/lidm_ped_psax/privacy_compliant_metrics.txt"
94 | ```
95 | 
96 | It is important that at least 50,000 samples remain in the filtered folders to ensure that the FID and IS scores are reliable.
97 | 
98 | 
99 | ## 6.
Save the Re-Identification models for later use 100 | 101 | The re-identification models can be saved for later use with the following command: 102 | 103 | ```bash 104 | mkdir -p models/reidentification_dynamic; cp experiments/reidentification_dynamic/reidentification_dynamic_best_network.pth models/reidentification_dynamic/; cp experiments/reidentification_dynamic/config.json models/reidentification_dynamic/ 105 | ``` 106 | ```bash 107 | mkdir -p models/reidentification_ped_a4c; cp experiments/reidentification_ped_a4c/reidentification_ped_a4c_best_network.pth models/reidentification_ped_a4c/; cp experiments/reidentification_ped_a4c/config.json models/reidentification_ped_a4c/ 108 | ``` 109 | ```bash 110 | mkdir -p models/reidentification_ped_psax; cp experiments/reidentification_ped_psax/reidentification_ped_psax_best_network.pth models/reidentification_ped_psax/; cp experiments/reidentification_ped_psax/config.json models/reidentification_ped_psax/ 111 | ``` 112 | -------------------------------------------------------------------------------- /echosyn/privacy/apply.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import random 5 | import shutil 6 | 7 | import torch 8 | import pandas as pd 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from PIL import Image 13 | 14 | import matplotlib.pyplot as plt 15 | from functools import partial 16 | from echosyn.common.datasets import SimaseUSVideoDataset 17 | from echosyn.privacy.shared import SiameseNetwork 18 | 19 | 20 | """ 21 | This script is used to apply the privacy filter to a set of synthetic latents. 22 | Example usage: 23 | python echosyn/privacy/apply.py \ 24 | --model experiments/reidentification_dynamic \ 25 | --synthetic samples/dynamic/latents \ 26 | --reference data/latents/dynamic \ 27 | --output samples/dynamic/privatised_latents 28 | """ 29 | 30 | def first_frame(vid): 31 | return vid[0:1] 32 | 33 | def subsample(vid, every_nth_frame): 34 | frames = np.arange(0, len(vid), step=every_nth_frame) 35 | return vid[frames] 36 | 37 | if __name__ == "__main__": 38 | 39 | parser = argparse.ArgumentParser(description='Privacy Filter') 40 | parser.add_argument('--model', type=str, help='Path to the model folder.') 41 | parser.add_argument('--synthetic', type=str, help='Path to the synthetic latents folder.') 42 | parser.add_argument('--reference', type=str, help='Path to the real latents folder.') 43 | parser.add_argument('--output', type=str, help='Path to the output folder.') 44 | parser.add_argument('--cutoff_precentile', type=float, default=95, help='Cutoff percentile for privacy threshold.') 45 | 46 | args = parser.parse_args() 47 | 48 | # Load real and synthetic latents 49 | training_latents_csv = os.path.join(args.reference, "FileList.csv") 50 | training_latents_basepath = os.path.join(args.reference, "Latents") 51 | normalization =lambda x: (x - x.min())/(x.max() - x.min()) * 2 - 1 # should be -1 to 1 due to way we trained the model 52 | ds_train = SimaseUSVideoDataset(phase="training", transform=normalization, latents_csv=training_latents_csv, training_latents_base_path=training_latents_basepath, in_memory=True) 53 | ds_test = SimaseUSVideoDataset(phase="testing", transform=normalization, latents_csv=training_latents_csv, training_latents_base_path=training_latents_basepath, in_memory=True) 54 | 55 | synthetic_images_paths = os.listdir(args.synthetic) 56 | synthetic_images = [torch.load(os.path.join(args.synthetic, x)) for x in 
synthetic_images_paths] 57 | 58 | # convert images to 1 x C x H x W to be consistent in case we want to check videos 59 | for i in range(len(synthetic_images)): 60 | if len(synthetic_images[i].size()) == 3: 61 | synthetic_images[i] = synthetic_images[i].unsqueeze(dim=0) 62 | 63 | synthetic_images = normalization(torch.cat(synthetic_images)) 64 | print(f"Number of synthetic images found: {len(synthetic_images)}") 65 | 66 | # Prepare the images 67 | train_vid_to_img = first_frame 68 | test_vid_to_img = first_frame 69 | train_images = torch.cat([train_vid_to_img(x) for x in tqdm(ds_train)]) 70 | test_images = torch.cat([test_vid_to_img(x) for x in tqdm(ds_test)]) 71 | print(f"Number of real train frames: {len(train_images)}") 72 | print(f"Number of real test frames: {len(test_images)}") 73 | 74 | 75 | # Load Model 76 | with open(os.path.join(args.model, "config.json")) as config: 77 | config = config.read() 78 | 79 | config = json.loads(config) 80 | net = SiameseNetwork(network=config['siamese_architecture'], in_channels=config['n_channels'], n_features=config['n_features']) 81 | net.load_state_dict(torch.load(os.path.join(args.model, os.path.basename(args.model) + "_best_network.pth"))) 82 | 83 | 84 | print("Sanity Check") 85 | print(f"Train - Min: {train_images.min()} - Max: {train_images.max()} - shape: {train_images.size()}") 86 | print(f"Test - Min: {test_images.min()} - Max: {test_images.max()} - shape: {test_images.size()}") 87 | print(f"Train - Min: {synthetic_images.min()} - Max: {synthetic_images.max()} - shape: {synthetic_images.size()}") 88 | 89 | # Compute the embeddings 90 | net.eval() 91 | net = net.cuda() 92 | bs = 256 93 | latents_train = [] 94 | latents_test = [] 95 | latents_synth = [] 96 | with torch.no_grad(): 97 | for i in tqdm(np.arange(0, len(train_images), bs), "Computing Train Embeddings"): 98 | batch = train_images[i:i+bs].cuda() 99 | latents_train.append(net.forward_once(batch)) 100 | 101 | for i in tqdm(np.arange(0, len(test_images), bs), "Computing Test Embeddings"): 102 | batch = test_images[i:i+bs].cuda() 103 | latents_test.append(net.forward_once(batch)) 104 | 105 | for i in tqdm(np.arange(0, len(synthetic_images), bs), "Computing Synthetic Embeddings"): 106 | batch = synthetic_images[i:i+bs].cuda() 107 | latents_synth.append(net.forward_once(batch)) 108 | 109 | latents_train = torch.cat(latents_train) 110 | latents_test = torch.cat(latents_test) 111 | latents_synth = torch.cat(latents_synth) 112 | 113 | 114 | # Automatically determine the privacy threshold 115 | train_val_corr = torch.corrcoef(torch.cat([latents_train, latents_test])).cpu() 116 | print(train_val_corr.size()) 117 | 118 | closest_train = [] 119 | for i in range(len(train_images)): 120 | val_matches = train_val_corr[i, len(train_images):] 121 | closest_train.append(val_matches.max().cpu()) 122 | 123 | tau = np.percentile(torch.stack(closest_train).numpy(), args.cutoff_precentile) 124 | print(f"Privacy threshold tau: {tau}") 125 | 126 | # Compute the closest matches between synthetic and real images 127 | closest_test = [] 128 | batch_size = 10000 129 | latents_synth.cpu() 130 | closest_loc = [] 131 | for l in np.arange(0, len(latents_synth), batch_size): 132 | synth = latents_synth[l:l + batch_size].cuda() 133 | train_synth_corr = torch.corrcoef(torch.cat([latents_train, synth])) 134 | for i in range(len(synth)): 135 | synth_matches = train_synth_corr[len(train_images)+i, :len(train_images)] 136 | closest_test.append(synth_matches.max().cpu()) 137 | 
closest_loc.append(synth_matches.argmax().cpu()) 138 | synth = synth.cpu() 139 | 140 | # Plot the results 141 | fig, ax = plt.subplots() 142 | nt, bins, patches = ax.hist(closest_train, 100, density=True, label="Train-Val", alpha=0.5, color="blue") 143 | ns, bins, patches = ax.hist(closest_test, 100, density=True, label="Train-Synth", alpha=.5, color="orange") 144 | ax.axvline(tau, 0, max(max(nt), max(ns)), color="black") 145 | ax.set_xlabel('Correlation') 146 | ax.set_ylabel('Frequency') 147 | ax.set_title('Histogram of highest correlation matches') 148 | ax.text(tau, max(max(nt), max(ns)), f'$tau$ = {tau:0.5f} ', ha='right', va='bottom', rotation=0, color='black') 149 | plt.legend() 150 | fig.tight_layout() 151 | plt.savefig(os.path.join(args.model, "privacy_results.png")) 152 | 153 | closest_test = torch.stack(closest_test) 154 | is_public = closest_test < tau 155 | print(f"Number of synthetic images: {len(closest_test)}") 156 | print(f"Number of non private synthetic images: {sum(is_public)} - memorization rate: {1- is_public.float().mean()}") 157 | 158 | # Save the privacy-compliant latents 159 | private_latents_output_path = os.path.abspath(args.output) 160 | os.makedirs(private_latents_output_path, exist_ok=True) 161 | 162 | for path, is_pub in zip(synthetic_images_paths, is_public): 163 | src = os.path.join(args.synthetic, path) 164 | tgt = os.path.join(private_latents_output_path, path) 165 | if is_pub: 166 | shutil.copy(src, tgt) 167 | else: 168 | print(f"Skipping images because it appears memorized: {tgt} ") 169 | 170 | results = {"real_path":[], "tau":[], "synth_path":[], "img_idx": []} 171 | for clos_idx, tau_val, path in zip(closest_loc, closest_test, synthetic_images_paths): 172 | results["real_path"].append(ds_train.df.iloc[int(clos_idx)]["FileName"]) 173 | results["tau"].append(float(tau_val)) 174 | results["synth_path"].append(path) 175 | results["img_idx"].append(int(clos_idx)) 176 | res = pd.DataFrame(results) 177 | res.to_csv(os.path.join(os.path.dirname(private_latents_output_path), "privacy_scores.csv"), index=False) 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /echosyn/privacy/configs/config_dynamic.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_description": "reidentification_dynamic", 3 | "resumption": false, 4 | "resumption_count": null, 5 | "previous_experiment": null, 6 | 7 | "image_path": "data/latents/dynamic/FileList.csv", 8 | 9 | "siamese_architecture": "ResNet-50", 10 | "data_handling": "balanced", 11 | 12 | "num_workers": 16, 13 | "pin_memory": true, 14 | 15 | "n_channels": 4, 16 | "n_features": 128, 17 | "image_size": 14, 18 | "loss": "BCEWithLogitsLoss", 19 | "optimizer": "Adam", 20 | "learning_rate": 0.0001, 21 | "batch_size": 128, 22 | "max_epochs": 1000, 23 | "early_stopping": 50, 24 | "transform": "pmone", 25 | 26 | "n_samples_train": 7465, 27 | "n_samples_val": 1288, 28 | "n_samples_test": 1277 29 | } 30 | -------------------------------------------------------------------------------- /echosyn/privacy/configs/config_ped_a4c.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_description": "reidentification_ped_a4c", 3 | "resumption": false, 4 | "resumption_count": null, 5 | "previous_experiment": null, 6 | 7 | "image_path": "data/latents/ped_a4c/FileList.csv", 8 | 9 | "siamese_architecture": "ResNet-50", 10 | "data_handling": "balanced", 11 | 12 | "num_workers": 16, 13 | 
"pin_memory": true, 14 | 15 | "n_channels": 4, 16 | "n_features": 128, 17 | "image_size": 14, 18 | "loss": "BCEWithLogitsLoss", 19 | "optimizer": "Adam", 20 | "learning_rate": 0.0001, 21 | "batch_size": 128, 22 | "max_epochs": 1000, 23 | "early_stopping": 50, 24 | "transform": "pmone", 25 | 26 | "n_samples_train": 2580, 27 | "n_samples_val": 368, 28 | "n_samples_test": 336 29 | } 30 | -------------------------------------------------------------------------------- /echosyn/privacy/configs/config_ped_psax.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_description": "reidentification_ped_psax", 3 | "resumption": false, 4 | "resumption_count": null, 5 | "previous_experiment": null, 6 | 7 | "image_path": "data/latents/ped_psax/FileList.csv", 8 | 9 | "siamese_architecture": "ResNet-50", 10 | "data_handling": "balanced", 11 | 12 | "num_workers": 16, 13 | "pin_memory": true, 14 | 15 | "n_channels": 4, 16 | "n_features": 128, 17 | "image_size": 14, 18 | "loss": "BCEWithLogitsLoss", 19 | "optimizer": "Adam", 20 | "learning_rate": 0.0001, 21 | "batch_size": 128, 22 | "max_epochs": 1000, 23 | "early_stopping": 50, 24 | "transform": "pmone", 25 | 26 | "n_samples_train": 3559, 27 | "n_samples_val": 519, 28 | "n_samples_test": 448 29 | } 30 | -------------------------------------------------------------------------------- /echosyn/privacy/shared.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | class SiameseNetwork(nn.Module): 7 | def __init__(self, network='ResNet-50', in_channels=3, n_features=128): 8 | super(SiameseNetwork, self).__init__() 9 | self.network = network 10 | self.in_channels = in_channels 11 | self.n_features = n_features 12 | 13 | if self.network == 'ResNet-50': 14 | # Model: Use ResNet-50 architecture 15 | self.model = models.resnet50(pretrained=True) 16 | # Adjust the input layer: either 1 or 3 input channels 17 | if self.in_channels == 1: 18 | self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 19 | if self.in_channels == 4: 20 | self.model.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 21 | elif self.in_channels == 3: 22 | pass 23 | else: 24 | raise Exception( 25 | 'Invalid argument: ' + self.in_channels + '\nChoose either in_channels=1 or in_channels=3') 26 | # Adjust the ResNet classification layer to produce feature vectors of a specific size 27 | self.model.fc = nn.Linear(in_features=2048, out_features=self.n_features, bias=True) 28 | 29 | else: 30 | raise Exception('Invalid argument: ' + self.network + 31 | '\nChoose ResNet-50! 
Other architectures are not yet implemented in this framework.') 32 | 33 | self.fc_end = nn.Linear(self.n_features, 1) 34 | 35 | def forward_once(self, x): 36 | 37 | # Forward function for one branch to get the n_features-dim feature vector before merging 38 | output = self.model(x) 39 | output = torch.sigmoid(output) 40 | return output 41 | 42 | def forward(self, input1, input2): 43 | 44 | # Forward 45 | output1 = self.forward_once(input1) 46 | output2 = self.forward_once(input2) 47 | 48 | # Compute the absolute difference between the n_features-dim feature vectors and pass it to the last FC-Layer 49 | difference = torch.abs(output1 - output2) 50 | output = self.fc_end(difference) 51 | 52 | return output 53 | 54 | 55 | -------------------------------------------------------------------------------- /echosyn/privacy/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import copy 5 | import time 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torchvision.transforms as transforms 13 | 14 | from echosyn.common import privacy_utils as Utils 15 | from echosyn.privacy.shared import SiameseNetwork 16 | 17 | class AgentSiameseNetwork: 18 | def __init__(self, config): 19 | self.config = config 20 | 21 | # set path used to save experiment-related files and results 22 | self.SAVINGS_PATH = './experiments/' + self.config['experiment_description'] + '/' 23 | self.IMAGE_PATH = self.config['image_path'] 24 | 25 | # save configuration as config.json in the created folder 26 | with open(self.SAVINGS_PATH + 'config.json', 'w') as outfile: 27 | json.dump(self.config, outfile, indent='\t') 28 | outfile.close() 29 | 30 | # enable benchmark mode in cuDNN 31 | torch.backends.cudnn.benchmark = True 32 | 33 | # set all the important variables 34 | self.network = self.config['siamese_architecture'] 35 | self.data_handling = self.config['data_handling'] 36 | 37 | self.num_workers = self.config['num_workers'] 38 | self.pin_memory = self.config['pin_memory'] 39 | 40 | self.n_channels = self.config['n_channels'] 41 | self.n_features = self.config['n_features'] 42 | self.image_size = self.config['image_size'] 43 | self.loss_method = self.config['loss'] 44 | self.optimizer_method = self.config['optimizer'] 45 | self.learning_rate = self.config['learning_rate'] 46 | self.batch_size = self.config['batch_size'] 47 | self.max_epochs = self.config['max_epochs'] 48 | self.early_stopping = self.config['early_stopping'] 49 | self.transform = self.config['transform'] 50 | 51 | self.n_samples_train = self.config['n_samples_train'] 52 | self.n_samples_val = self.config['n_samples_val'] 53 | self.n_samples_test = self.config['n_samples_test'] 54 | 55 | self.start_epoch = 0 56 | 57 | self.es = EarlyStopping(patience=self.early_stopping) 58 | self.best_loss = 100000 59 | self.loss_dict = {'training': [], 60 | 'validation': []} 61 | 62 | if self.data_handling == 'balanced': 63 | self.balanced = True 64 | self.randomized = False 65 | elif self.data_handling == 'randomized': 66 | self.balanced = True 67 | self.randomized = True 68 | 69 | # define the suffix needed for loading the checkpoint (in case you want to resume a previous experiment) 70 | if self.config['resumption'] is True: 71 | if self.config['resumption_count'] == 1: 72 | self.load_suffix = '' 73 | elif self.config['resumption_count'] == 2: 74 | self.load_suffix = '_resume' 75 | elif self.config['resumption_count'] 
> 2: 76 | self.load_suffix = '_resume' + str(self.config['resumption_count'] - 1) 77 | 78 | # define the suffix needed for saving the checkpoint (the checkpoint is saved at the end of each epoch) 79 | if self.config['resumption'] is False: 80 | self.save_suffix = '' 81 | elif self.config['resumption'] is True: 82 | if self.config['resumption_count'] == 1: 83 | self.save_suffix = '_resume' 84 | elif self.config['resumption_count'] > 1: 85 | self.save_suffix = '_resume' + str(self.config['resumption_count']) 86 | 87 | # Define the siamese neural network architecture 88 | self.net = SiameseNetwork(network=self.network, in_channels=self.n_channels, n_features=self.n_features).cuda() 89 | self.best_net = SiameseNetwork(network=self.network, in_channels=self.n_channels, 90 | n_features=self.n_features).cuda() 91 | 92 | # Choose loss function 93 | if self.loss_method == 'BCEWithLogitsLoss': 94 | self.loss = nn.BCEWithLogitsLoss().cuda() 95 | else: 96 | raise Exception('Invalid argument: ' + self.loss_method + 97 | '\nChoose BCEWithLogitsLoss! Other loss functions are not yet implemented!') 98 | 99 | # Set the optimizer function 100 | if self.optimizer_method == 'Adam': 101 | self.optimizer = optim.Adam(self.net.parameters(), lr=self.learning_rate) 102 | else: 103 | raise Exception('Invalid argument: ' + self.optimizer_method + 104 | '\nChoose Adam! Other optimizer functions are not yet implemented!') 105 | 106 | # load state dicts and other information in case a previous experiment will be continued 107 | if self.config['resumption'] is True: 108 | self.checkpoint = torch.load('./archive/' + self.config['previous_experiment'] + '/' + self.config[ 109 | 'previous_experiment'] + '_checkpoint' + self.load_suffix + '.pth') 110 | self.best_net.load_state_dict(torch.load( 111 | './archive/' + self.config['previous_experiment'] + '/' + self.config[ 112 | 'previous_experiment'] + '_best_network' + self.load_suffix + '.pth')) 113 | self.net.load_state_dict(self.checkpoint['state_dict']) 114 | self.optimizer.load_state_dict(self.checkpoint['optimizer']) 115 | self.best_loss = self.checkpoint['best_loss'] 116 | self.loss_dict = self.checkpoint['loss_dict'] 117 | self.es.best = self.checkpoint['best_loss'] 118 | self.es.num_bad_epochs = self.checkpoint['num_bad_epochs'] 119 | self.start_epoch = self.checkpoint['epoch'] 120 | 121 | # Initialize transformations 122 | if self.transform == 'image_net': 123 | self.transform_train = transforms.Compose([ 124 | transforms.Resize((self.image_size, self.image_size)), 125 | transforms.ToTensor(), 126 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 127 | ]) 128 | self.transform_val_test = transforms.Compose([ 129 | transforms.Resize((self.image_size, self.image_size)), 130 | transforms.ToTensor(), 131 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 132 | ]) 133 | elif self.transform == "default": 134 | self.transform_train = transforms.Compose([ 135 | transforms.Resize((self.image_size, self.image_size)), 136 | transforms.ToTensor() 137 | ]) 138 | self.transform_val_test = transforms.Compose([ 139 | transforms.Resize((self.image_size, self.image_size)), 140 | transforms.ToTensor() 141 | ]) 142 | elif self.transform == "pmone": 143 | self.transform_train = lambda x: (x - x.min())/(x.max() - x.min()) * 2 - 1 144 | self.transform_val_test = lambda x: (x - x.min())/(x.max() - x.min()) * 2 - 1 145 | 146 | elif self.transform == "none": 147 | self.transform_train = lambda x: x 148 | self.transform_val_test = lambda 
x: x 149 | 150 | 151 | self.training_loader = Utils.get_data_loaders(phase='training', data_handling=self.data_handling, 152 | n_channels=self.n_channels, 153 | n_samples=self.n_samples_train, 154 | transform=self.transform_train, image_path=self.IMAGE_PATH, 155 | batch_size=self.batch_size, shuffle=True, 156 | num_workers=self.num_workers, pin_memory=self.pin_memory, 157 | save_path=None) 158 | self.validation_loader = Utils.get_data_loaders(phase='validation', data_handling=self.data_handling, 159 | n_channels=self.n_channels, 160 | n_samples=self.n_samples_val, 161 | transform=self.transform_val_test, 162 | image_path=self.IMAGE_PATH, batch_size=self.batch_size, 163 | shuffle=False, num_workers=self.num_workers, 164 | pin_memory=self.pin_memory, save_path=None, generator_seed=0) 165 | self.test_loader = Utils.get_data_loaders(phase='testing', data_handling='balanced', 166 | n_channels=self.n_channels, n_samples=self.n_samples_test, 167 | transform=self.transform_val_test, image_path=self.IMAGE_PATH, 168 | batch_size=self.batch_size, shuffle=False, 169 | num_workers=self.num_workers, pin_memory=self.pin_memory, 170 | save_path=None, generator_seed=0) 171 | 172 | def training_validation(self): 173 | # Training and validation loop! 174 | for epoch in range(self.start_epoch, self.max_epochs): 175 | start_time = time.time() 176 | 177 | training_loss = Utils.train(self.net, self.training_loader, self.n_samples_train, self.batch_size, 178 | self.loss, self.optimizer, epoch, self.max_epochs) 179 | 180 | self.validation_loader.dataset.reset_generator() # make sure this is deterministic for more accurate loss 181 | validation_loss = Utils.validate(self.net, self.validation_loader, self.n_samples_val, self.batch_size, 182 | self.loss, epoch, self.max_epochs) 183 | 184 | self.loss_dict['training'].append(training_loss) 185 | self.loss_dict['validation'].append(validation_loss) 186 | end_time = time.time() 187 | print('Time elapsed for epoch ' + str(epoch + 1) + ': ' + str( 188 | round((end_time - start_time) / 60, 2)) + ' minutes') 189 | 190 | if validation_loss < self.best_loss: 191 | self.best_loss = validation_loss 192 | self.best_net = copy.deepcopy(self.net) 193 | 194 | torch.save(self.best_net.state_dict(), self.SAVINGS_PATH + self.config[ 195 | 'experiment_description'] + '_best_network' + self.save_suffix + '.pth') 196 | 197 | Utils.save_loss_curves(self.loss_dict, self.SAVINGS_PATH, self.config['experiment_description']) 198 | Utils.plot_loss_curves(self.loss_dict, self.SAVINGS_PATH, self.config['experiment_description']) 199 | 200 | es_check = self.es.step(validation_loss) 201 | Utils.save_checkpoint(epoch, self.net, self.optimizer, self.loss_dict, self.best_loss, 202 | self.es.num_bad_epochs, self.SAVINGS_PATH + self.config[ 203 | 'experiment_description'] + '_checkpoint' + self.save_suffix + '.pth') 204 | 205 | if es_check: 206 | break 207 | 208 | print('Finished Training!') 209 | 210 | def testing_evaluation(self): 211 | # Testing phase! 212 | self.test_loader.dataset.reset_generator() 213 | y_true, y_pred = Utils.test(self.best_net, self.test_loader) 214 | y_true, y_pred = [y_true.numpy(), y_pred.numpy()] 215 | 216 | # Compute the evaluation metrics! 
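# `metrics` below refers to sklearn.metrics (roc_curve, roc_auc_score) and must be imported at the
# top of this file (e.g. `from sklearn import metrics`); the remaining scores and the bootstrapped
# confidence interval come from echosyn.common.privacy_utils.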
217 | fp_rates, tp_rates, thresholds = metrics.roc_curve(y_true, y_pred) 218 | auc = metrics.roc_auc_score(y_true, y_pred) 219 | y_pred_thresh = Utils.apply_threshold(y_pred, 0.5) 220 | accuracy, f1_score, precision, recall, report, confusion_matrix = Utils.get_evaluation_metrics(y_true, 221 | y_pred_thresh) 222 | auc_mean, confidence_lower, confidence_upper = Utils.bootstrap(10000, 223 | y_true, 224 | y_pred, 225 | self.SAVINGS_PATH, 226 | self.config['experiment_description']) 227 | 228 | # Plot ROC curve! 229 | Utils.plot_roc_curve(fp_rates, tp_rates, self.SAVINGS_PATH, self.config['experiment_description']) 230 | 231 | # Save all the results to files! 232 | Utils.save_labels_predictions(y_true, y_pred, y_pred_thresh, self.SAVINGS_PATH, 233 | self.config['experiment_description']) 234 | 235 | Utils.save_results_to_file(auc, accuracy, f1_score, precision, recall, report, confusion_matrix, 236 | self.SAVINGS_PATH, self.config['experiment_description']) 237 | 238 | Utils.save_roc_metrics_to_file(fp_rates, tp_rates, thresholds, self.SAVINGS_PATH, 239 | self.config['experiment_description']) 240 | 241 | # Print the evaluation metrics! 242 | print('EVALUATION METRICS:') 243 | print('AUC: ' + str(auc)) 244 | print('Accuracy: ' + str(accuracy)) 245 | print('F1-Score: ' + str(f1_score)) 246 | print('Precision: ' + str(precision)) 247 | print('Recall: ' + str(recall)) 248 | print('Report: ' + str(report)) 249 | print('Confusion matrix: ' + str(confusion_matrix)) 250 | 251 | print('BOOTSTRAPPING: ') 252 | print('AUC Mean: ' + str(auc_mean)) 253 | print('Confidence interval for the AUC score: ' + str(confidence_lower) + ' - ' + str(confidence_upper)) 254 | 255 | def run(self): 256 | # Call training/validation and testing loop successively 257 | self.training_validation() 258 | self.testing_evaluation() 259 | 260 | class EarlyStopping(object): 261 | def __init__(self, mode='min', min_delta=0, patience=10, percentage=False): 262 | self.mode = mode 263 | self.min_delta = min_delta 264 | self.patience = patience 265 | self.best = None 266 | self.num_bad_epochs = 0 267 | self.is_better = None 268 | self._init_is_better(mode, min_delta, percentage) 269 | 270 | if patience == 0: 271 | self.is_better = lambda a, b: True 272 | self.step = lambda a: False 273 | 274 | def step(self, metrics): 275 | if self.best is None: 276 | self.best = metrics 277 | return False 278 | 279 | if np.isnan(metrics): 280 | return True 281 | 282 | if self.is_better(metrics, self.best): 283 | self.num_bad_epochs = 0 284 | self.best = metrics 285 | else: 286 | self.num_bad_epochs += 1 287 | 288 | if self.num_bad_epochs >= self.patience: 289 | return True 290 | 291 | return False 292 | 293 | def _init_is_better(self, mode, min_delta, percentage): 294 | if mode not in {'min', 'max'}: 295 | raise ValueError('mode ' + mode + ' is unknown!') 296 | if not percentage: 297 | if mode == 'min': 298 | self.is_better = lambda a, best: a < best - min_delta 299 | if mode == 'max': 300 | self.is_better = lambda a, best: a > best + min_delta 301 | else: 302 | if mode == 'min': 303 | self.is_better = lambda a, best: a < best - ( 304 | best * min_delta / 100) 305 | if mode == 'max': 306 | self.is_better = lambda a, best: a > best + ( 307 | best * min_delta / 100) 308 | 309 | 310 | if __name__ == '__main__': 311 | # define an argument parser 312 | parser = argparse.ArgumentParser('Patient Verification') 313 | parser.add_argument('--config_path', default='./', help='the path where the config files are stored') 314 | parser.add_argument('--config', 
default='config.json', help='the hyper-parameter configuration and experiment settings') 315 | args = parser.parse_args() 316 | print('Arguments:\n' + '--config_path: ' + args.config_path + '\n--config: ' + args.config) 317 | 318 | # read config 319 | with open(args.config_path + args.config, 'r') as config: 320 | config = config.read() 321 | 322 | # parse config 323 | config = json.loads(config) 324 | 325 | # create folder to save experiment-related files 326 | os.makedirs(os.path.join('./experiments/' , config['experiment_description']), exist_ok=True) 327 | SAVINGS_PATH = './experiments/' + config['experiment_description'] + '/' 328 | 329 | # call agent and run experiment 330 | experiment = AgentSiameseNetwork(config) 331 | experiment.run() -------------------------------------------------------------------------------- /echosyn/vae/README.md: -------------------------------------------------------------------------------- 1 | # Variational Autoencoder (fun fact: it's also a GAN) 2 | 3 | Training the VAE is not a necessity as an open source VAE such as [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) should work well enough for this project. 4 | 5 | However, if you would like to train the VAE yourself, follow these steps. 6 | 7 | ## 1. Clone the necessary repositories 8 | We use the Stable-Diffusion repo to train the VAE model, but the Taming-Transformers repo is a necessary dependency. It is advised to follow these steps exactly to avoid errors later on. 9 | 10 | ```bash 11 | cd external 12 | git clone https://github.com/CompVis/stable-diffusion 13 | 14 | cd stable-diffusion 15 | conda env create -f environment.yaml 16 | conda activate ldm 17 | ``` 18 | 19 | This will download the Stable-Diffusion repository and create the necessary environment to run it. 20 | 21 | ```bash 22 | git clone https://github.com/CompVis/taming-transformers 23 | cd taming-transformers 24 | pip install -e . 25 | ``` 26 | 27 | This will download the Taming-Transformers repository and install it as a package. 28 | 29 | ## 2. Prepare the data 30 | The VAE needs images only and to keep things simple, we format our video datasets into image datasets. 31 | We do that with: 32 | ```bash 33 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Dynamic/Videos data/vae_train_images/images/ 34 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Pediatric/A4C/Videos data/vae_train_images/images/ 35 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Pediatric/PSAX/Videos data/vae_train_images/images/ 36 | ``` 37 | 38 | Note that this will merge all the images in the same folder. 39 | 40 | Then, we need to create train.txt file and a val.txt file containing the path to these images. 41 | ```bash 42 | find $(cd data/vae_train_images/images && pwd) -type f | shuf > tmp.txt 43 | head -n -1000 tmp.txt > data/vae_train_images/train.txt 44 | tail -n 1000 tmp.txt > data/vae_train_images/val.txt 45 | rm tmp.txt 46 | ``` 47 | 48 | That's it for the dataset. 49 | 50 | 51 | ## 3. Train the VAE 52 | Now that the data is ready, we can train the VAE. We use the following command to train the VAE on the images we just extracted. 
53 | 
54 | ```bash
55 | cd external/stable-diffusion
56 | export DATADIR=$(cd ../../data/vae_train_images && pwd)
57 | python main.py \
58 | --base ../../echosyn/vae/usencoder_kl_16x16x4.yaml \
59 | -t True \
60 | --gpus 0,1,2,3,4,5,6,7 \
61 | --logdir experiments/vae
62 | ```
63 | 
64 | This will train the VAE on the images we extracted and save the model in the experiments/vae folder.
65 | For the paper, we train the VAE on 8xA100 GPUs for 5 days.
66 | *Note: if you use a single GPU, you still need to leave a trailing comma in the --gpus argument, like so: ```--gpus 0,```*
67 | To resume training from a checkpoint, use the ```--resume``` flag and replace ```EXPERIMENT_NAME``` with the correct experiment name.
68 | 
69 | ```bash
70 | python main.py \
71 | --base ../../echosyn/vae/usencoder_kl_16x16x4.yaml \
72 | -t True \
73 | --gpus 0,1,2,3,4,5,6,7 \
74 | --logdir experiments/vae \
75 | --resume experiments/vae/EXPERIMENT_NAME
76 | ```
77 | 
78 | ## 4. Export the VAE to Diffusers 🧨
79 | Now that the VAE is trained, we can export it to a Diffusers AutoencoderKL model. This is done with the following command, replacing ```EXPERIMENT_NAME``` with the correct experiment name.
80 | 
81 | ```bash
82 | python scripts/convert_vae_pt_to_diffusers.py \
83 | --vae_pt_path experiments/EXPERIMENT_NAME/checkpoints/last.ckpt \
84 | --dump_path models/vae/
85 | ```
86 | The script is taken as-is from the [Diffusers library](https://github.com/huggingface/diffusers/blob/main/scripts/convert_vae_pt_to_diffusers.py).
87 | 
88 | ## 5. Test the model in Diffusers 🧨
89 | That's it! You now have a VAE model trained on your own data and exported to a Diffusers model.
90 | 
91 | To test the model, use the echosyn environment and do:
92 | ```python
93 | from PIL import Image
94 | import torch
95 | import numpy as np
96 | from diffusers import AutoencoderKL
97 | model = AutoencoderKL.from_pretrained("models/vae")
98 | model.eval()
99 | 
100 | # Use the model to encode and decode images
101 | img = Image.open("data/vae_train_images/images/0X10A5FC19152B50A5_00001.jpg")
102 | 
103 | img = torch.from_numpy(np.array(img)).permute(2,0,1).unsqueeze(0).to(torch.float32) / 128 - 1
104 | 
105 | with torch.no_grad():
106 | lat = model.encode(img).latent_dist.sample()
107 | rec = model.decode(lat).sample
108 | # Display the original and reconstructed images
109 | img = img.squeeze(0).permute(1,2,0)
110 | rec = rec.squeeze(0).permute(1,2,0)
111 | display(Image.fromarray(((img + 1) * 128).to(torch.uint8).numpy()), Image.fromarray(((rec + 1) * 128).clamp(0, 255).to(torch.uint8).numpy()))
112 | ```
113 | 
114 | Note that the model might not work in mixed-precision mode, which can cause NaN values. If this happens, force the model dtype to torch.float32.
115 | 
116 | ## 6. Evaluate the VAE
117 | 
118 | To evaluate the VAE, we reconstruct the extracted image datasets and compare the original images to the reconstructed images. We use the following command to reconstruct (encode-decode) the images from the dynamic dataset.
119 | 
120 | ```bash
121 | python scripts/vae_reconstruct_image_folder.py \
122 | -model models/vae \
123 | -input data/reference/dynamic \
124 | -output samples/reconstructed/dynamic \
125 | -batch_size 32
126 | ```
127 | 
128 | This will save the reconstructed images in the samples/reconstructed/dynamic folder. To evaluate the VAE, we use the following command to compute the FID score between the original images and the reconstructed images, with the help of the StyleGAN-V repository.
129 | 130 | ```bash 131 | cd external/stylegan-v 132 | 133 | python src/scripts/calc_metrics_for_dataset.py \ 134 | --real_data_path ../../data/reference/dynamic \ 135 | --fake_data_path ../../samples/reconstructed/dynamic \ 136 | --mirror 0 --gpus 1 --resolution 112 \ 137 | --metrics fid50k_full,is50k >> "../../samples/reconstructed/dynamic.txt" 138 | ``` 139 | 140 | 141 | -------------------------------------------------------------------------------- /echosyn/vae/usencoder_kl_16x16x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2e-6 # ~5e-4 after scaloing 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 112 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 32 30 | num_workers: 16 31 | train: 32 | target: taming.data.custom.CustomTrain 33 | params: 34 | training_images_list_file: ${oc.env:DATADIR}/train.txt 35 | size: 112 36 | validation: 37 | target: taming.data.custom.CustomTest 38 | params: 39 | test_images_list_file: ${oc.env:DATADIR}/val.txt 40 | size: 112 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | max_epochs: 1000 -------------------------------------------------------------------------------- /external/README.md: -------------------------------------------------------------------------------- 1 | # This folder contains the external libraries that are used in the project. 2 | 3 | External libraries can be cloned from the following repositories, into the `external` folder: 4 |
5 | 1. Stable-Diffusion: ```git clone https://github.com/CompVis/stable-diffusion``` 6 |
Necessary to train the VAE model. 7 | 8 | 9 | 2. Taming-Transformers: ```git clone https://github.com/CompVis/taming-transformers``` 10 |
Necessary to train the VAE model. 11 | 12 | 13 | 3. StyleGAN-V: ```git clone git@github.com:HReynaud/stylegan-v.git``` 14 |
Used to compute all generative metrics (FID, FVD, IS). 15 | 16 | 17 | 4. EchoNet-Dynamic: ```git clone git@github.com:HReynaud/echonet.git``` 18 |
Used to evaluate the synthetic data on the Ejection Fraction downstream task. 19 | 20 | 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.2.1 2 | torchvision==0.17.1 3 | einops==0.7.0 4 | decord==0.6.0 5 | diffusers==0.27.2 6 | packaging==24.0 7 | omegaconf==2.3.0 8 | transformers==4.39.0 9 | accelerate==0.28.0 10 | pandas==2.2.1 11 | wandb==0.16.4 12 | opencv-python==4.9.0.80 13 | moviepy==1.0.3 14 | imageio==2.34.0 15 | scipy==1.12.0 16 | scikit-learn==1.4.1.post1 17 | matplotlib-3.8.3 -------------------------------------------------------------------------------- /ressources/models.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HReynaud/EchoNet-Synthetic/f21198d96dcdc38b18e4f0bb4e2998067c72e2b8/ressources/models.jpg -------------------------------------------------------------------------------- /ressources/mosaic.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HReynaud/EchoNet-Synthetic/f21198d96dcdc38b18e4f0bb4e2998067c72e2b8/ressources/mosaic.gif -------------------------------------------------------------------------------- /ressources/mosaic_slim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HReynaud/EchoNet-Synthetic/f21198d96dcdc38b18e4f0bb4e2998067c72e2b8/ressources/mosaic_slim.gif -------------------------------------------------------------------------------- /ressources/real.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HReynaud/EchoNet-Synthetic/f21198d96dcdc38b18e4f0bb4e2998067c72e2b8/ressources/real.gif -------------------------------------------------------------------------------- /ressources/synt.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HReynaud/EchoNet-Synthetic/f21198d96dcdc38b18e4f0bb4e2998067c72e2b8/ressources/synt.gif -------------------------------------------------------------------------------- /scripts/complete_pediatrics_filelist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import os 4 | import pandas as pd 5 | from tqdm import tqdm 6 | 7 | """ 8 | python scripts/complete_pediatrics_filelist.py --dataset datasets/EchoNet-Pediatric/A4C 9 | python scripts/complete_pediatrics_filelist.py --dataset datasets/EchoNet-Pediatric/PSAX 10 | """ 11 | 12 | def get_video_metadata(video_path): 13 | cap = cv2.VideoCapture(video_path) 14 | fps = cap.get(cv2.CAP_PROP_FPS) 15 | nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 16 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 17 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 18 | cap.release() 19 | return fps, nframes, width, height 20 | 21 | if __name__ == "__main__": 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--dataset', type=str, required=True, help='Path to the dataset directory') 25 | args = parser.parse_args() 26 | 27 | csv_path = os.path.join(args.dataset, 'FileList.csv') 28 | assert os.path.exists(csv_path), f"Could not find FileList.csv at {csv_path}" 29 | 30 | metadata = pd.read_csv(csv_path) 31 | metadata.to_csv(os.path.join(args.dataset, 'FileList_ORIGINAL.csv'), index=False) # backup 32 | 33 
| metadata['FileName'] = metadata['FileName'].apply(lambda x: x.split('.')[0]) # remove extension 34 | metadata['Fold'] = metadata['Split'] # Copy kfold indices to the Fold column 35 | metadata['Split'] = metadata['Fold'].apply(lambda x: 'TRAIN' if x in range(8) else 'VAL' if x == 8 else 'TEST') 36 | 37 | # Add columns: 38 | # df.loc[df['FileName'] == fname, ['FileName', 'FrameHeight','FrameWidth','FPS','NumberOfFrames']] = [fname, 112, 112, fps, len(video)] 39 | 40 | for i, row in tqdm(metadata.iterrows(), total=len(metadata)): 41 | video_name = row['FileName'] + '.avi' 42 | video_path = os.path.join(args.dataset, 'Videos', video_name) 43 | 44 | fps, nframes, width, height = get_video_metadata(video_path) 45 | 46 | metadata.loc[i, ['FrameHeight','FrameWidth','FPS','NumberOfFrames']] = [height, width, fps, nframes] 47 | 48 | metadata.to_csv(csv_path, index=False) 49 | print("Updated metadata saved to ", csv_path) 50 | 51 | 52 | -------------------------------------------------------------------------------- /scripts/convert_vae_pt_to_diffusers.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | 4 | import requests 5 | import torch 6 | from omegaconf import OmegaConf 7 | 8 | from diffusers import AutoencoderKL 9 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( 10 | assign_to_checkpoint, 11 | conv_attn_to_linear, 12 | create_vae_diffusers_config, 13 | renew_vae_attention_paths, 14 | renew_vae_resnet_paths, 15 | ) 16 | 17 | 18 | def custom_convert_ldm_vae_checkpoint(checkpoint, config): 19 | vae_state_dict = checkpoint 20 | 21 | new_checkpoint = {} 22 | 23 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 24 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 25 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] 26 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 27 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] 28 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] 29 | 30 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 31 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 32 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] 33 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 34 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] 35 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] 36 | 37 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 38 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 39 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 40 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 41 | 42 | # Retrieves the keys for the encoder down blocks only 43 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) 44 | down_blocks = { 45 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) 46 | } 47 | 48 | # Retrieves the keys for the decoder up blocks only 49 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in 
vae_state_dict if "decoder.up" in layer}) 50 | up_blocks = { 51 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) 52 | } 53 | 54 | for i in range(num_down_blocks): 55 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] 56 | 57 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 58 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( 59 | f"encoder.down.{i}.downsample.conv.weight" 60 | ) 61 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( 62 | f"encoder.down.{i}.downsample.conv.bias" 63 | ) 64 | 65 | paths = renew_vae_resnet_paths(resnets) 66 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} 67 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 68 | 69 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 70 | num_mid_res_blocks = 2 71 | for i in range(1, num_mid_res_blocks + 1): 72 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 73 | 74 | paths = renew_vae_resnet_paths(resnets) 75 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 76 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 77 | 78 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] 79 | paths = renew_vae_attention_paths(mid_attentions) 80 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 81 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 82 | conv_attn_to_linear(new_checkpoint) 83 | 84 | for i in range(num_up_blocks): 85 | block_id = num_up_blocks - 1 - i 86 | resnets = [ 87 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key 88 | ] 89 | 90 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 91 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ 92 | f"decoder.up.{block_id}.upsample.conv.weight" 93 | ] 94 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ 95 | f"decoder.up.{block_id}.upsample.conv.bias" 96 | ] 97 | 98 | paths = renew_vae_resnet_paths(resnets) 99 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} 100 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 101 | 102 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 103 | num_mid_res_blocks = 2 104 | for i in range(1, num_mid_res_blocks + 1): 105 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 106 | 107 | paths = renew_vae_resnet_paths(resnets) 108 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 109 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 110 | 111 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 112 | paths = renew_vae_attention_paths(mid_attentions) 113 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 114 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 115 | conv_attn_to_linear(new_checkpoint) 116 | return new_checkpoint 117 | 118 
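# --- Editor's sketch (hedged, not part of the original script): once `vae_pt_to_vae_diffuser`
# below has written the converted weights with `vae.save_pretrained(output_path)`, the resulting
# folder can be loaded back through the standard diffusers API, as the other scripts in this
# repository do. The path and function name here are hypothetical examples.
def _example_use_converted_vae(output_path="models/vae"):
    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained(output_path)  # folder produced by vae_pt_to_vae_diffuser
    vae.eval()
    dummy = torch.randn(1, 3, 112, 112)  # one 112x112 frame in [-1, 1], the size used elsewhere in this repo
    with torch.no_grad():
        latents = vae.encode(dummy).latent_dist.sample() * vae.config.scaling_factor
        recon = vae.decode(latents / vae.config.scaling_factor).sample
    return latents.shape, recon.shape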
| 119 | def vae_pt_to_vae_diffuser( 120 | checkpoint_path: str, 121 | output_path: str, 122 | ): 123 | import yaml # local import: yaml.safe_load is used below but yaml is missing from the top-level imports. Only V1 configs are supported. 124 | r = requests.get( 125 | "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" 126 | ) 127 | io_obj = io.BytesIO(r.content) 128 | 129 | original_config = yaml.safe_load(io_obj) 130 | image_size = 512 131 | device = "cuda" if torch.cuda.is_available() else "cpu" 132 | if checkpoint_path.endswith("safetensors"): 133 | from safetensors import safe_open 134 | 135 | checkpoint = {} 136 | with safe_open(checkpoint_path, framework="pt", device="cpu") as f: 137 | for key in f.keys(): 138 | checkpoint[key] = f.get_tensor(key) 139 | else: 140 | checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"] 141 | 142 | # Convert the VAE model. 143 | vae_config = create_vae_diffusers_config(original_config, image_size=image_size) 144 | converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config) 145 | 146 | vae = AutoencoderKL(**vae_config) 147 | vae.load_state_dict(converted_vae_checkpoint) 148 | vae.save_pretrained(output_path) 149 | 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser() 153 | 154 | parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.") 155 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output folder for the converted diffusers VAE.") 156 | 157 | args = parser.parse_args() 158 | 159 | vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path) 160 | -------------------------------------------------------------------------------- /scripts/copy_privacy_compliant_images.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check for the correct number of arguments 4 | if [ "$#" -ne 3 ]; then 5 | echo "Usage: $0 <source_image_folder> <source_pt_folder> <destination_folder>" 6 | exit 1 7 | fi 8 | 9 | # Assign arguments to variables 10 | source_image_folder="$1" 11 | source_pt_folder="$2" 12 | destination_folder="$3" 13 | 14 | # Create the destination folder if it doesn't exist 15 | mkdir -p "$destination_folder" 16 | 17 | # Enable nullglob to ensure wildcard patterns that match no files expand to nothing 18 | shopt -s nullglob 19 | 20 | # Initialize counters 21 | copied_count=0 22 | total_files=$(ls -1q "$source_image_folder" | wc -l) 23 | matched_files=() 24 | pt_files_count=$(ls -1q "$source_pt_folder"/*.pt | wc -l) 25 | processed_files=0 26 | 27 | # Loop through .pt files in the source_pt_folder 28 | for pt_file in "$source_pt_folder"/*.pt; do 29 | # Extract the basename without the extension 30 | base_name=$(basename "$pt_file" .pt) 31 | 32 | # Construct the wildcard pattern for the source files 33 | pattern="$source_image_folder"/"$base_name".* 34 | 35 | # Attempt to copy the matching file(s) to the destination folder 36 | if cp $pattern "$destination_folder" 2>/dev/null; then 37 | matched_files+=("$pattern") 38 | ((copied_count++)) 39 | fi 40 | 41 | # Update processed files count and calculate progress 42 | ((processed_files++)) 43 | progress=$((processed_files * 100 / pt_files_count)) 44 | echo -ne "Progress: $progress% \r" 45 | done 46 | 47 | # Disable nullglob to revert back to default behavior 48 | shopt -u nullglob 49 | 50 | # Calculate ignored files 51 | ignored_count=$((total_files - ${#matched_files[@]})) 52 | 53 | # Final newline after progress 54 | echo "" 55 | 56 | # Display summary 57 | echo "Operation completed."
58 | echo "Files copied: $copied_count" 59 | echo "Files ignored: $ignored_count" 60 | -------------------------------------------------------------------------------- /scripts/create_reference_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import decord 3 | import numpy as np 4 | import os 5 | import pandas as pd 6 | import torch 7 | from tqdm import tqdm 8 | from echosyn.common import save_as_img 9 | 10 | decord.bridge.set_bridge('torch') 11 | 12 | """ 13 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Dynamic --output data/reference/dynamic --frames 128 14 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Pediatric/A4C --output data/reference/ped_a4c --frames 16 15 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Pediatric/PSAX --output data/reference/ped_psax --frames 16 16 | """ 17 | 18 | if __name__ == "__main__": 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataset', type=str, required=True, help='Path to the dataset directory') 22 | parser.add_argument('--output', type=str, required=True) 23 | parser.add_argument('--frames', type=int, default=128) 24 | parser.add_argument('--fps', type=float, default=32, help='Target FPS of the video before frame extraction. Should match the diffusion model FPS.') 25 | 26 | args = parser.parse_args() 27 | 28 | csv_path = os.path.join(args.dataset, 'FileList.csv') 29 | video_path = os.path.join(args.dataset, 'Videos') 30 | 31 | assert os.path.exists(csv_path), f"Could not find FileList.csv at {csv_path}" 32 | assert os.path.exists(video_path), f"Could not find Videos directory at {video_path}" 33 | 34 | metadata = pd.read_csv(csv_path) 35 | metadata = metadata.sample(frac=1, random_state=42) # shuffle 36 | metadata.reset_index(drop=True, inplace=True) 37 | 38 | extracted_videos = 0 39 | 40 | target_count = {16: 3125, 128: 2048}[args.frames] 41 | 42 | threshold_duration = args.frames / args.fps # 128 -> 4 seconds, 16 -> 0.5 seconds 43 | 44 | for row in tqdm(metadata.iterrows(), total=len(metadata)): 45 | row = row[1] 46 | 47 | video_name = row['FileName'] if row['FileName'].endswith('.avi') else row['FileName'] + '.avi' 48 | video_path = os.path.join(args.dataset, 'Videos', video_name) 49 | 50 | nframes = row['NumberOfFrames'] 51 | fps = row['FPS'] 52 | duration = nframes / fps 53 | 54 | if duration < threshold_duration: 55 | # skip videos which are too short (less than 4 seconds / 128 frames) 56 | continue 57 | 58 | new_frame_count = np.floor(args.fps / fps * nframes).astype(int) 59 | resample_indices = np.linspace(0, nframes, new_frame_count, endpoint=False).round().astype(int) 60 | 61 | assert len(resample_indices) >= args.frames 62 | resample_indices = resample_indices[:args.frames] 63 | 64 | reader = decord.VideoReader(video_path, ctx=decord.cpu(), width=112, height=112) 65 | video = reader.get_batch(resample_indices) # T x H x W x C, uint8 tensor 66 | video = video.float().mean(axis=-1).clamp(0, 255).to(torch.uint8) # T x H x W 67 | video = video.unsqueeze(-1).repeat(1, 1, 1, 3) # T x H x W x 3 68 | 69 | folder_name = video_name[:-4] # remove .avi 70 | 71 | if row['Split'] == 'TRAIN': 72 | # path_all = os.path.join(args.output, "train", folder_name) 73 | path_all = os.path.join(args.output, folder_name) 74 | os.makedirs(path_all, exist_ok=True) 75 | save_as_img(video, path_all) 76 | extracted_videos += 1 77 | 78 | if extracted_videos >= target_count: 79 | print("Reached target 
count, stopping.") 80 | break 81 | 82 | print(f"Saved {extracted_videos} videos to {args.output}.") 83 | if extracted_videos < target_count: 84 | print(f"WARNING: only saved {extracted_videos} videos, which is less than the target count of {target_count}.") 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /scripts/encode_video_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | from glob import glob 5 | from einops import rearrange 6 | import numpy as np 7 | from tqdm.auto import tqdm 8 | import cv2, json 9 | import shutil 10 | import pandas as pd 11 | 12 | import torch 13 | 14 | from diffusers import AutoencoderKL 15 | from echosyn.common import loadvideo, load_model 16 | 17 | """ 18 | usage example: 19 | 20 | python scripts/encode_video_dataset.py \ 21 | -m models/vae \ 22 | -i datasets/EchoNet-Dynamic \ 23 | -o data/latents/dynamic \ 24 | -g 25 | """ 26 | 27 | class VideoDataset(torch.utils.data.Dataset): 28 | def __init__(self, folder): 29 | self.folder = folder 30 | self.videos = sorted(glob(os.path.join(folder, "*.avi"))) 31 | 32 | def __len__(self): 33 | return len(self.videos) 34 | 35 | def __getitem__(self, idx): 36 | video, fps = loadvideo(self.videos[idx], return_fps=True) 37 | return video, self.videos[idx], fps 38 | 39 | if __name__ == "__main__": 40 | # Parse arguments 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("-m", "--model", type=str, required=True, help="Path to model folder") 43 | parser.add_argument("-i", "--input", type=str, required=True, help="Path to input folder") 44 | parser.add_argument("-o", "--output", type=str, required=True, help="Path to output folder") 45 | parser.add_argument("-g", "--gray_scale", action="store_true", help="Convert to gray scale", default=False) 46 | parser.add_argument("-f", "--force_overwrite", action="store_true", help="Overwrite existing latents", default=False) 47 | args = parser.parse_args() 48 | 49 | # Prepare 50 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 51 | os.makedirs(args.output, exist_ok=True) 52 | video_in_folder = os.path.abspath(os.path.join(args.input, "Videos")) 53 | video_out_folder = os.path.abspath(os.path.join(args.output, "Latents")) 54 | 55 | df = pd.read_csv(os.path.join(args.input, "FileList.csv")) 56 | needs_df_update = False 57 | if df['Split'].dtype == int: 58 | print("Updating Split column to string") 59 | needs_df_update = True 60 | df['Fold'] = df['Split'] 61 | 62 | def split_set(row): 63 | if row['Fold'] in range(8): 64 | return 'TRAIN' 65 | elif row['Fold'] == 8: 66 | return 'VAL' 67 | else: 68 | return 'TEST' 69 | 70 | df['Split'] = df.apply(split_set, axis=1) 71 | df['FileName'] = df['FileName'].apply(lambda x: x.split('.')[0]) 72 | 73 | if not needs_df_update: 74 | df.to_csv(os.path.join(args.output, "FileList.csv"), index=False) 75 | 76 | print("Loading videos from ", video_in_folder) 77 | print("Saving latents to ", video_out_folder) 78 | 79 | # Load VAE 80 | vae = load_model(args.model) 81 | vae = vae.to(device) 82 | vae.eval() 83 | 84 | # Load Dataset 85 | ds = VideoDataset(video_in_folder) 86 | dl = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=False, num_workers=8) 87 | print(f"Found {len(ds)} videos") 88 | 89 | batch_size = 32 # number of frames to encode simultaneously 90 | 91 | # for vpath in tqdm(videos): 92 | for video, vpath, fps in tqdm(dl): 93 | 94 | video = video[0] 95 | 
vpath = vpath[0] 96 | fps = fps[0] 97 | 98 | # output path 99 | opath = vpath.replace(video_in_folder, "")[1:] # retrieve relative path to input folder, similar to basename but keeps the folders 100 | opath = opath.replace(".avi", f".pt") # change extension 101 | opath = os.path.join(video_out_folder, opath) # add output folder 102 | 103 | # check if already exists 104 | if os.path.exists(opath) and not args.force_overwrite: 105 | print(f"Skipping {vpath} as {opath} already exists") 106 | continue 107 | 108 | # load video 109 | # video = loadvideo(vpath) # B H W C 110 | video = rearrange(video, "t h w c-> t c h w") # B C H W 111 | video = video.to(device) 112 | video = video.float() / 128.0 -1 # normalize to [-1, 1] 113 | if args.gray_scale: 114 | video = video.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1) # B C H W 115 | 116 | # encode video 117 | all_latents = [] 118 | for i in range(0, len(video), batch_size): 119 | batch = video[i:i+batch_size] 120 | with torch.no_grad(): 121 | latents = vae.encode(batch).latent_dist.sample() 122 | latents = latents * vae.config.scaling_factor 123 | latents = latents.detach().cpu() 124 | all_latents.append(latents) 125 | 126 | all_latents = torch.cat(all_latents, dim=0) 127 | 128 | # save 129 | os.makedirs(os.path.dirname(opath), exist_ok=True) 130 | torch.save(all_latents, opath) 131 | 132 | if needs_df_update: 133 | fname = os.path.basename(opath).split('.')[0] 134 | df.loc[df['FileName'] == fname, ['FileName', 'FrameHeight','FrameWidth','FPS','NumberOfFrames']] = [fname, 112, 112, fps, len(video)] 135 | 136 | if needs_df_update: 137 | df.to_csv(os.path.join(args.output, "FileList.csv"), index=False) 138 | print("Done") 139 | -------------------------------------------------------------------------------- /scripts/extract_frames_from_videos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script takes a folder of videos (.avi) and extracts every 5th frame from each video 4 | # and saves them as grayscale JPG images. The output images are saved in a folder specified 5 | # by the user. 6 | # The script uses ffmpeg and xargs to reduce processing time. 7 | 8 | 9 | usage() { 10 | echo "Usage: $0 <input_folder> <output_folder>" 11 | echo " <input_folder>: Path to the folder containing AVI videos." 12 | echo " <output_folder>: Path to the output folder where images will be saved." 13 | } 14 | 15 | # Check if exactly two arguments are given 16 | if [ $# -ne 2 ]; then 17 | echo "Error: Incorrect number of arguments." 18 | usage 19 | exit 1 20 | fi 21 | 22 | input_folder="$1" 23 | output_folder="$2" 24 | 25 | export output_folder 26 | # Create the output folder if it doesn't exist 27 | mkdir -p "$output_folder" 28 | 29 | # Function to extract frames and save as grayscale JPGs 30 | extract_frames() { 31 | local video_file=$1 32 | local video_name=$(basename -- "$video_file") 33 | local base_name="${video_name%.*}" 34 | 35 | # Create a specific directory for each video's frames 36 | local video_output_folder="$output_folder" 37 | 38 | # Output path for the frames 39 | local output_path="$video_output_folder/${base_name}_%05d.jpg" 40 | 41 | # Extract every 5th frame and save as a grayscale JPG 42 | ffmpeg -loglevel error -i "$video_file" -vf "select=not(mod(n\,5)),format=gray,format=yuv420p" \ 43 | -vsync vfr "$output_path" 44 | 45 | echo -n "."
46 | } 47 | 48 | export -f extract_frames 49 | 50 | echo "Extracting frames from AVI videos in $input_folder and saving in $output_folder" 51 | echo "" 52 | echo "One dot = one video processed:" 53 | echo "" 54 | 55 | # Process each AVI video file 56 | find "$input_folder" -name "*.avi" -print0 | xargs -0 -I {} -P 32 bash -c 'extract_frames "$@"' _ {} 57 | -------------------------------------------------------------------------------- /scripts/update_split_filelist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import argparse 4 | 5 | if __name__ == "__main__": 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--csv', type=str, required=True, help='Path to the csv file') 9 | parser.add_argument('--train', type=int, required=True, help='Number of training samples') 10 | parser.add_argument('--val', type=int, required=True, help='Number of validation samples') 11 | parser.add_argument('--test', type=int, required=True, help='Number of testing samples') 12 | args = parser.parse_args() 13 | 14 | df = pd.read_csv(args.csv) 15 | assert 'Split' in df.columns, "Split column not found in the csv file" 16 | 17 | if (args.train + args.val + args.test) > len(df): 18 | print(f"Warning: There are not enough samples in the csv file ({len(df)}) to split") 19 | print(f"into {args.train} training, {args.val} validation, and {args.test} testing samples ({args.train + args.val + args.test}).") 20 | print(f"Please adjust the number of samples or provide a csv file with more samples.") 21 | exit(1) 22 | elif (args.train + args.val + args.test) < len(df): 23 | print(f"Warning: There are more samples in the csv file ({len(df)}) than needed to split") 24 | print(f"into {args.train} training, {args.val} validation, and {args.test} testing samples ({args.train + args.val + args.test}).") 25 | print(f"Any extra samples will be put into the \"EXTRA\" split.") 26 | 27 | df.loc[:args.train, 'Split'] = 'TRAIN' 28 | df.loc[args.train:args.train+args.val, 'Split'] = 'VAL' 29 | df.loc[args.train+args.val:args.train+args.val+args.test, 'Split'] = 'TEST' 30 | df.loc[args.train+args.val+args.test:, 'Split'] = 'EXTRA' 31 | 32 | df.to_csv(args.csv, index=False) 33 | print(f"Split the csv file into {args.train} training, {args.val} validation, and {args.test} testing samples.") 34 | -------------------------------------------------------------------------------- /scripts/vae_reconstruct_image_folder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import numpy as np 6 | from PIL import Image 7 | from glob import glob 8 | from tqdm import tqdm 9 | from einops import rearrange 10 | 11 | from echosyn.common import load_model, save_as_img 12 | 13 | 14 | class ImageLoader(Dataset): 15 | def __init__(self, all_paths): 16 | self.image_paths = all_paths 17 | 18 | def __len__(self): 19 | return len(self.image_paths) 20 | 21 | def __getitem__(self, idx): 22 | path = self.image_paths[idx] 23 | image = Image.open(path) 24 | 25 | image = np.array(image) 26 | image = image / 128.0 - 1 # [-1, 1] 27 | image = rearrange(image, 'h w c -> c h w') 28 | 29 | return image, self.image_paths[idx] 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser(description="Encode and decode images using a trained VAE model.") 33 | parser.add_argument("-m", "--model", type=str, help="Path to the
trained VAE model.") 34 | parser.add_argument("-i", "--input", type=str, help="Path to the input folder.") 35 | parser.add_argument("-o", "--output", type=str, help="Path to the output folder.") 36 | parser.add_argument("-b", "--batch_size", type=int, default=128, help="Batch size to use for encoding and decoding.") 37 | 38 | args = parser.parse_args() 39 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | 41 | # Load the model 42 | model = load_model(args.model) 43 | model.eval() 44 | model.to(device, torch.float32) 45 | 46 | # Load the videos 47 | folder_input = os.path.abspath(args.input) 48 | folder_output = os.path.abspath(args.output) 49 | all_images = glob(os.path.join(folder_input, "**", "*.jpg"), recursive=True) 50 | print(f"Found {len(all_images)} images in {folder_input}") 51 | 52 | # prepare output folder 53 | os.makedirs(folder_output, exist_ok=True) 54 | 55 | # dataset 56 | dataset = ImageLoader(all_images) 57 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=max(args.batch_size, 32)) 58 | 59 | for batch in tqdm(dataloader): 60 | images, paths = batch 61 | images = images.to(device, torch.float32) 62 | 63 | # Encode the video 64 | with torch.no_grad(): 65 | reconstructed_images = model(images).sample 66 | 67 | reconstructed_images = rearrange(reconstructed_images, 'b c h w -> b h w c') 68 | reconstructed_images = (reconstructed_images + 1) * 128.0 69 | reconstructed_images = reconstructed_images.clamp(0, 255).cpu().to(torch.uint8) 70 | 71 | # Save the reconstructed images 72 | for i, path in enumerate(paths): 73 | new_path = path.replace(folder_input, folder_output) 74 | os.makedirs(os.path.dirname(new_path), exist_ok=True) 75 | save_as_img(reconstructed_images[i], new_path) 76 | 77 | print(f"All reconstructed images saved to {folder_output}") -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='echosyn', 5 | version='1.0', 6 | packages=find_packages(), 7 | install_requires=[ 8 | 'torch', 9 | 'torchvision', 10 | 'einops', 11 | 'decord', 12 | 'diffusers', 13 | 'packaging', 14 | 'omegaconf', 15 | 'transformers', 16 | 'accelerate', 17 | 'pandas', 18 | 'wandb', 19 | 'opencv-python', 20 | 'moviepy', 21 | 'imageio', 22 | 'scipy', 23 | 'scikit-learn', 24 | 'matplotlib', 25 | ] 26 | ) --------------------------------------------------------------------------------
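# --- Editor's note (hedged sketch, not part of the repository) ---
# The VAE scripts above (encode_video_dataset.py, vae_reconstruct_image_folder.py) share the same
# pixel normalisation: frames are mapped to roughly [-1, 1] with `x / 128.0 - 1` before encoding,
# and mapped back with `(x + 1) * 128.0` followed by a clamp to [0, 255] after decoding. The
# snippet below only checks that this convention round-trips uint8 pixel values exactly.
import torch

x = torch.arange(256, dtype=torch.uint8)                        # every possible uint8 pixel value
x_norm = x.float() / 128.0 - 1                                  # forward mapping (0 -> -1.0, 255 -> 0.9921875)
x_back = ((x_norm + 1) * 128.0).clamp(0, 255).to(torch.uint8)   # inverse mapping used after decoding
assert torch.equal(x, x_back)                                   # exact round trip for all 256 levels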