├── .gitignore
├── README.md
├── echosyn
│   ├── common
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   └── privacy_utils.py
│   ├── lidm
│   │   ├── README.md
│   │   ├── configs
│   │   │   ├── dynamic.yaml
│   │   │   ├── ped_a4c.yaml
│   │   │   └── ped_psax.yaml
│   │   ├── sample.py
│   │   └── train.py
│   ├── lvdm
│   │   ├── README.md
│   │   ├── configs
│   │   │   └── default.yaml
│   │   ├── sample.py
│   │   └── train.py
│   ├── privacy
│   │   ├── README.md
│   │   ├── apply.py
│   │   ├── configs
│   │   │   ├── config_dynamic.json
│   │   │   ├── config_ped_a4c.json
│   │   │   └── config_ped_psax.json
│   │   ├── shared.py
│   │   └── train.py
│   └── vae
│       ├── README.md
│       └── usencoder_kl_16x16x4.yaml
├── external
│   └── README.md
├── requirements.txt
├── ressources
│   ├── models.jpg
│   ├── mosaic.gif
│   ├── mosaic_slim.gif
│   ├── real.gif
│   └── synt.gif
├── scripts
│   ├── complete_pediatrics_filelist.py
│   ├── convert_vae_pt_to_diffusers.py
│   ├── copy_privacy_compliant_images.sh
│   ├── create_reference_dataset.py
│   ├── encode_video_dataset.py
│   ├── extract_frames_from_videos.sh
│   ├── update_split_filelist.py
│   └── vae_reconstruct_image_folder.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | external/*
163 | !external/README.md
164 | models/
165 | data/
166 | datasets/
167 | experiments/
168 | wandb/
169 | samples/
170 | .vscode/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EchoNet-Synthetic
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | ## Deep-Dive Podcast (AI generated)
11 |
12 |
13 | https://github.com/user-attachments/assets/06d0811b-f2fc-442f-bde4-0885d8dfd319
14 |
15 | *Note: Turn on audio!*
16 |
17 | ## Introduction
18 |
19 | This repository contains the code and model weights for the paper *[EchoNet-Synthetic: Privacy-preserving Video Generation for Safe Medical Data Sharing](https://arxiv.org/abs/2406.00808)*. Hadrien Reynaud, Qingjie Meng, Mischa Dombrowski, Arijit Ghosh, Alberto Gomez, Paul Leeson and Bernhard Kainz. MICCAI 2024.
20 |
21 | EchoNet-Synthetic presents a protocol to generate surrogate privacy-compliant datasets that are as valuable as their original counterparts to train downstream models (e.g. regression models).
22 |
23 | In this repository, we present the code we use for the experiments in the paper. We provide the code to train the models, generate the synthetic data, and evaluate the quality of the synthetic data.
24 | We also provide all the pre-trained models and release the synthetic datasets we generated.
25 |
26 | 📜 Read the Paper [on arXiv](https://arxiv.org/abs/2406.00808)
27 | 📕 [MICCAI 2024 Proceedings](https://link.springer.com/chapter/10.1007/978-3-031-72104-5_28)
28 | 🤗 Try our interactive demo [on HuggingFace](https://huggingface.co/spaces/HReynaud/EchoNet-Synthetic); it contains all the generative pipeline inference code and weights!
29 |
30 | 
31 |
32 | *Example of synthetic videos generated with EchoNet-Synthetic. The first video is real; the others are generated.*
33 |
34 | ## Table of contents
35 | 1. [Environment setup](#environment-setup)
36 | 2. [Data preparation](#data-preparation)
37 | 3. [The models](#the-models)
38 | 4. [Generating EchoNet-Synthetic](#generating-echonet-synthetic)
39 | 5. [Evaluation](#evaluation)
40 | 6. [Results](#results)
41 | 7. [Citation](#citation)
42 |
43 | ## Environment setup
44 |
46 |
47 | First, we need to set up the environment. We use the following command to create a new conda environment with the required dependencies.
48 |
49 | ```bash
50 | conda create -y -n echosyn python=3.11
51 | conda activate echosyn
52 | pip install -e .
53 | ```
54 | *Note: the exact version of each package can be found in requirements.txt if necessary*
55 |
56 | This repository lets you train three models:
57 | - the Latent Image Diffusion Model (LIDM)
58 | - the Re-Identification models for privacy checks
59 | - the Latent Video Diffusion Model (LVDM)
60 |
61 | We rely on external libraries to:
62 | - train the Variational Auto-Encoder (VAE) (Stable-Diffusion and Taming-Transformers)
63 | - evaluate the generated images and videos (StyleGAN-V)
64 | - evaluate the synthetic data on the Ejection Fraction downstream task (EchoNet-Dynamic)
65 |
66 | Instructions for installing the external libraries are provided in the [External libraries](external/README.md) section.
67 |
68 |
69 |
70 |
71 | ## Data preparation
72 |
74 |
75 | ### ➡ Original datasets
76 | Download the EchoNet-Dynamic dataset from [here](https://echonet.github.io/dynamic/) and the EchoNet-Pediatric dataset from [here](https://echonet.github.io/pediatric/). The datasets are available for free upon request. Once downloaded, extract the content of the archive in the `datasets` folder. For simplicity and consistency, we structure them like so:
77 | ```
78 | datasets
79 | ├── EchoNet-Dynamic
80 | │ ├── Videos
81 | │ ├── FileList.csv
82 | │ └── VolumeTracings.csv
83 | └── EchoNet-Pediatric
84 | ├── A4C
85 | │ ├── Videos
86 | │ ├── FileList.csv
87 | │ └── VolumeTracings.csv
88 | └── PSAX
89 | ├── Videos
90 | ├── FileList.csv
91 | └── VolumeTracings.csv
92 | ```
93 |
94 | To harmonize the datasets, we add some information to the `FileList.csv` files of the EchoNet-Pediatric dataset, namely FrameHeight, FrameWidth, FPS and NumberOfFrames. We also arbitrarily convert the 10-fold indices into a simple TRAIN/VAL/TEST split. These updates are applied with the following commands:
95 |
96 | ```bash
97 | python scripts/complete_pediatrics_filelist.py --dataset datasets/EchoNet-Pediatric/A4C
98 | python scripts/complete_pediatrics_filelist.py --dataset datasets/EchoNet-Pediatric/PSAX
99 | ```
100 |
101 | This is crucial for the other scripts to work properly.
102 |
103 | ### ➡ Image datasets for VAE training
104 | See the [VAE training](echosyn/vae/README.md) section for instructions on how to train the VAE.
105 |
106 | The VAE is trained on images only, so to keep things simple, we convert our video datasets into image datasets.
107 | We do that with:
108 | ```bash
109 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Dynamic/Videos data/vae_train_images/images/
110 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Pediatric/A4C/Videos data/vae_train_images/images/
111 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Pediatric/PSAX/Videos data/vae_train_images/images/
112 | ```
113 |
114 | Note that this will merge all the images in the same folder.
115 |
116 | Then, we need to create a train.txt file and a val.txt file containing the paths to these images.
117 | ```bash
118 | find $(cd data/vae_train_images/images && pwd) -type f | shuf > tmp.txt
119 | head -n -1000 tmp.txt > data/vae_train_images/train.txt
120 | tail -n 1000 tmp.txt > data/vae_train_images/val.txt
121 | rm tmp.txt
122 | ```
123 |
124 | ### ➡ Latent Video datasets for LIDM / Privacy / LVDM training
125 |
126 | The LIDM, Re-Identification model and LVDM are trained on pre-encoded latent representations of the videos. To encode the videos, we use the image VAE. You can either retrain the VAE or download it from [here](https://huggingface.co/HReynaud/EchoNet-Synthetic/tree/main/vae). Once you have the VAE, you can encode the videos with the following command:
127 |
128 | ```bash
129 | # For the EchoNet-Dynamic dataset
130 | python scripts/encode_video_dataset.py \
131 | --model models/vae \
132 | --input datasets/EchoNet-Dynamic \
133 | --output data/latents/dynamic \
134 | --gray_scale
135 | ```
136 | ```bash
137 | # For the EchoNet-Pediatric datasets
138 | python scripts/encode_video_dataset.py \
139 | --model models/vae \
140 | --input datasets/EchoNet-Pediatric/A4C \
141 | --output data/latents/ped_a4c \
142 | --gray_scale
143 |
144 | python scripts/encode_video_dataset.py \
145 | --model models/vae \
146 | --input datasets/EchoNet-Pediatric/PSAX \
147 | --output data/latents/ped_psax \
148 | --gray_scale
149 | ```
150 |
151 | ### ➡ Validation datasets
152 |
153 | To quantitatively evaluate the quality of the generated images and videos, we use the StyleGAN-V repo.
154 | We cover the evaluation process in the [Evaluation](#evaluation) section.
155 | To enable this evaluation, we need to prepare the validation datasets. We do that with the following command:
156 |
157 | ```bash
158 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Dynamic --output data/reference/dynamic --frames 128
159 | ```
160 |
161 | ```bash
162 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Pediatric/A4C --output data/reference/ped_a4c --frames 16
163 | ```
164 |
165 | ```bash
166 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Pediatric/PSAX --output data/reference/ped_psax --frames 16
167 | ```
168 |
169 | Note that the Pediatric datasets do not support 128 frames, which prevents the computation of FVD_128, because there are not enough videos lasting 4 seconds or more. We therefore only extract 16 frames per video for these datasets.
170 |
171 |
172 |
173 | ## The Models
174 |
176 |
177 | 
178 |
179 | *Our pipeline, using all our models: LIDM, Re-Identification (Privacy), LVDM and VAE*
180 |
181 |
182 | ### The VAE
183 |
184 | You can download the pretrained VAE from [here](https://huggingface.co/HReynaud/EchoNet-Synthetic/tree/main/vae) or train it yourself by following the instructions in the [VAE training](echosyn/vae/README.md) section.
185 |
186 | ### The LIDM
187 |
188 | You can download the pretrained LIDMs from [here](https://huggingface.co/HReynaud/EchoNet-Synthetic/tree/main/lidm_dynamic) or train them yourself by following the instructions in the [LIDM training](echosyn/lidm/README.md) section.
189 |
190 | ### The Re-Identification models
191 |
192 | You can download the pretrained Re-Identification models from [here](https://huggingface.co/HReynaud/EchoNet-Synthetic/tree/main/reidentification_dynamic) or train them yourself by following the instructions in the [Re-Identification training](echosyn/privacy/README.md) section.
193 |
194 | ### The LVDM
195 |
196 | You can download the pretrained LVDM from [here](https://huggingface.co/HReynaud/EchoNet-Synthetic/tree/main/lvdm) or train it yourself by following the instructions in the [LVDM training](echosyn/lvdm/README.md) section.
197 |
198 | ### Structure
199 |
200 | The models should be structured as follows:
201 | ```
202 | models
203 | ├── lidm_dynamic
204 | ├── lidm_ped_a4c
205 | ├── lidm_ped_psax
206 | ├── lvdm
207 | ├── regression_dynamic
208 | ├── regression_ped_a4c
209 | ├── regression_ped_psax
210 | ├── reidentification_dynamic
211 | ├── reidentification_ped_a4c
212 | ├── reidentification_ped_psax
213 | └── vae
214 | ```
215 |
216 |
217 |
218 | ## Generating EchoNet-Synthetic
219 |
221 |
222 | Now that we have all the necessary models, we can generate the synthetic datasets. The process is the same for all three datasets and involves the following steps:
223 | - Generate a collection of latent heart images with the LIDMs (usually 2x the number of videos we are targeting)
224 | - Apply the privacy check, which will filter out some of the latent images
225 | - Generate the videos with the LVDM, and decode them with the VAE
226 |
227 | #### Dynamic dataset
228 | For the dynamic dataset, we aim for 10,030 videos, so we start by generating 20,000 latent images:
229 | ```bash
230 | python echosyn/lidm/sample.py \
231 | --config models/lidm_dynamic/config.yaml \
232 | --unet models/lidm_dynamic \
233 | --vae models/vae \
234 | --output samples/synthetic/lidm_dynamic \
235 | --num_samples 20000 \
236 | --batch_size 128 \
237 | --num_steps 64 \
238 | --save_latent \
239 | --seed 0
240 | ```
241 |
242 | Then, we filter the latent images with the re-identification model:
243 | ```bash
244 | python echosyn/privacy/apply.py \
245 | --model models/reidentification_dynamic \
246 | --synthetic samples/synthetic/lidm_dynamic/latents \
247 | --reference data/latents/dynamic \
248 | --output samples/synthetic/lidm_dynamic/privatised_latents
249 | ```
250 |
251 | We generate the synthetic videos with the LVDM:
252 | ```bash
253 | python echosyn/lvdm/sample.py \
254 | --config models/lvdm/config.yaml \
255 | --unet models/lvdm \
256 | --vae models/vae \
257 | --conditioning samples/synthetic/lidm_dynamic/privatised_latents \
258 | --output samples/synthetic/lvdm_dynamic \
259 | --num_samples 10030 \
260 | --batch_size 8 \
261 | --num_steps 64 \
262 | --min_lvef 10 \
263 | --max_lvef 90 \
264 | --save_as avi \
265 | --frames 192
266 | ```
267 |
268 | Finally, update the content of FileList.csv so that each split contains the right number of videos and follows the correct naming convention.
269 | ```bash
270 | mkdir -p datasets/EchoNet-Synthetic/dynamic
271 | mv samples/synthetic/lvdm_dynamic/avi datasets/EchoNet-Synthetic/dynamic/Videos
272 | cp samples/synthetic/lvdm_dynamic/FileList.csv datasets/EchoNet-Synthetic/dynamic/FileList.csv
273 | python scripts/update_split_filelist.py --csv datasets/EchoNet-Synthetic/dynamic/FileList.csv --train 7465 --val 1288 --test 1277
274 | ```
275 |
276 | You can use the exact same process for the Pediatric datasets, just make sure to use the right models.
277 | The A4C dataset should have 3284 videos, with 2580 train, 336 validation and 368 test videos.
278 | The PSAX dataset should have 4526 videos, with 3559 train, 448 validation and 519 test videos.
279 | They also use a different number of frames per video, which is 128 instead of 192.
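
For reference, the A4C pipeline could look like the following, reusing the flags shown above with the A4C models and adjusted counts; the output paths are only suggestions, so treat this as a sketch rather than a verified command sequence:

```bash
# LIDM: generate ~2x the target number of latent images (A4C targets 3284 videos)
python echosyn/lidm/sample.py \
    --config models/lidm_ped_a4c/config.yaml \
    --unet models/lidm_ped_a4c \
    --vae models/vae \
    --output samples/synthetic/lidm_ped_a4c \
    --num_samples 7000 \
    --batch_size 128 \
    --num_steps 64 \
    --save_latent \
    --seed 0

# Privacy filter with the A4C re-identification model
python echosyn/privacy/apply.py \
    --model models/reidentification_ped_a4c \
    --synthetic samples/synthetic/lidm_ped_a4c/latents \
    --reference data/latents/ped_a4c \
    --output samples/synthetic/lidm_ped_a4c/privatised_latents

# LVDM: generate the videos, with 128 frames instead of 192
python echosyn/lvdm/sample.py \
    --config models/lvdm/config.yaml \
    --unet models/lvdm \
    --vae models/vae \
    --conditioning samples/synthetic/lidm_ped_a4c/privatised_latents \
    --output samples/synthetic/lvdm_ped_a4c \
    --num_samples 3284 \
    --batch_size 8 \
    --num_steps 64 \
    --min_lvef 10 \
    --max_lvef 90 \
    --save_as avi \
    --frames 128

# Assemble the dataset and fix the split sizes (2580/336/368)
mkdir -p datasets/EchoNet-Synthetic/ped_a4c
mv samples/synthetic/lvdm_ped_a4c/avi datasets/EchoNet-Synthetic/ped_a4c/Videos
cp samples/synthetic/lvdm_ped_a4c/FileList.csv datasets/EchoNet-Synthetic/ped_a4c/FileList.csv
python scripts/update_split_filelist.py --csv datasets/EchoNet-Synthetic/ped_a4c/FileList.csv --train 2580 --val 336 --test 368
```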
281 |
282 |
283 |
284 | ## Evaluation
285 |
287 |
288 | As the final step, we evaluate the quality of the EchoNet-Synthetic videos by training an Ejection Fraction regression model on the synthetic data and evaluating it on the real data.
289 | To do so, we use the EchoNet-Dynamic repository.
290 | You will need to clone a slightly modified version of the repository from [here](https://github.com/HReynaud/echonet) and put it in the `external` folder.
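For example, you can clone it directly into the `external` folder so that the paths used below line up:

```bash
git clone https://github.com/HReynaud/echonet external/echonet
```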
291 |
292 | Once that is done, you will need to switch to the echonet environment.
293 | To create the echonet environment, you can use the following command:
294 | ```bash
295 | cd external/echonet
296 | conda create -y -n echonet python=3.10
297 | conda activate echonet
298 | pip install -e .
299 | ```
300 |
301 | Then, you should double check that everything is working by reproducing the results of the EchoNet-Dynamic repository on the real EchoNet-Dynamic dataset. This should take ~12 hours on a single GPU.
302 | ```bash
303 | echonet video --data_dir ../../datasets/EchoNet-Dynamic \
304 | --output ../../experiments/regression_echonet_dynamic \
305 | --pretrained \
306 | --run_test \
307 | --num_epochs 45
308 | ```
309 |
310 | Once you have confirmed that everything is working, you can train the same model on the synthetic data, while evaluating on the real data.
311 | ```bash
312 | echonet video --data_dir ../../datasets/EchoNet-Synthetic/dynamic \
313 | --data_dir_real ../../datasets/EchoNet-Dynamic \
314 | --output ../../experiments/regression_echosyn_dynamic \
315 | --pretrained \
316 | --run_test \
317 | --num_epochs 45 \
318 | --period 1
319 | ```
320 |
321 | This will automatically compute the final metrics after training, and store them in the `experiments/regression_echosyn_dynamic/log.csv` file.
322 |
323 | The process is the same for the Pediatric datasets, except that they start from a pretrained model trained on the real EchoNet-Dynamic data, which would look like this:
324 | ```bash
325 | echonet video --data_dir ../../datasets/EchoNet-Pediatric/A4C \
326 | --data_dir_real ../../datasets/EchoNet-Pediatric/A4C \
327 | --output ../../experiments/regression_echosyn_peda4c \
328 | --weights ../../experiments/regression_real/best.pt \
329 | --run_test \
330 | --num_epochs 45 \
331 | --period 1
332 | ```
333 |
334 |
335 | Here are two tricks that can improve the quality and alignment of the synthetic data with the real data:
336 | - Re-label the synthetic data with the regression model trained on the real data.
337 | - Use the Ejection Fraction scores from the real dataset to create the synthetic dataset, rather than uniformly distributed ones (see the sketch below).
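
A minimal sketch of the second trick, assuming you simply want to draw LVEF targets from the empirical distribution of the real `FileList.csv` (the `EF` and `Split` columns are the ones used throughout this repository; the output file and the way it is fed to the sampling script are hypothetical and depend on your setup):

```python
import numpy as np
import pandas as pd

# Empirical EF distribution of the real training split
real = pd.read_csv("datasets/EchoNet-Dynamic/FileList.csv")
real_ef = real.loc[real["Split"] == "TRAIN", "EF"].to_numpy()

# Draw one target EF per synthetic video by resampling with replacement
rng = np.random.default_rng(0)
target_ef = rng.choice(real_ef, size=10030, replace=True)

# Store the targets for later use (hypothetical location)
np.savetxt("samples/synthetic/target_ef.txt", target_ef, fmt="%.2f")
```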
338 |
339 |
340 |
341 | ## Results
342 |
344 |
345 | Here is a side-by-side comparison between a real video and a synthetic video generated with EchoNet-Synthetic. We use a frame from the real video and its corresponding ejection fraction score to "reproduce" the real video.
346 |
347 |
348 |
349 | | Real Video | Reproduction |
350 | | :---: | :---: |
351 | | *(real echo video)* | *(synthetic reproduction)* |
355 |
356 |
357 |
358 | Here we show a collection of synthetic videos generated with EchoNet-Synthetic. The quality of the videos is excellent, and they cover a variety of ejection fraction scores.
359 |
360 | 
361 |
362 |
363 |
364 | ## Acknowledgements
365 | This work was supported by Ultromics Ltd. and the UKRI Centre for Doctoral Training in Artificial Intelligence for Healthcare (EP/S023283/1).
366 | HPC resources were provided by the Erlangen National High Performance Computing Center (NHR@FAU) of the Friedrich-Alexander-Universität Erlangen-Nürnberg (FAU) under the NHR projects b143dc and b180dc. NHR funding is provided by federal and Bavarian state authorities. NHR@FAU hardware is partially funded by the German Research Foundation (DFG) – 440719683.
367 |
368 |
369 | ## Citation
370 |
371 | ```
372 | @inproceedings{reynaud2024echonet,
373 | title={Echonet-synthetic: Privacy-preserving video generation for safe medical data sharing},
374 | author={Reynaud, Hadrien and Meng, Qingjie and Dombrowski, Mischa and Ghosh, Arijit and Day, Thomas and Gomez, Alberto and Leeson, Paul and Kainz, Bernhard},
375 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
376 | pages={285--295},
377 | year={2024},
378 | organization={Springer}
379 | }
380 | ```
381 |
--------------------------------------------------------------------------------
/echosyn/common/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import importlib
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from einops import rearrange
11 | from omegaconf import OmegaConf
12 | import imageio
13 |
14 | import diffusers
15 |
16 | def padf(tensor, mult=3):
17 | """
18 | Pads a tensor along the last dimension to make its size a multiple of 2^mult.
19 |
20 | Args:
21 | tensor (torch.Tensor): The tensor to pad.
22 | mult (int, optional): The power of 2 that the tensor's size should be a multiple of. Defaults to 3.
23 |
24 | Returns:
25 | torch.Tensor: The padded tensor.
26 | int: The amount of padding applied.
27 | """
28 | pad = 2**mult - (tensor.shape[-1] % 2**mult)
29 | pad = pad//2
30 | tensor = F.pad(tensor, (pad, pad, pad, pad, 0, 0), mode='replicate')
31 | return tensor, pad
32 |
33 | def unpadf(tensor, pad=1):
34 | """
35 | Removes padding from a tensor along the last two dimensions.
36 |
37 | Args:
38 | tensor (torch.Tensor): The tensor to unpad.
39 | pad (int, optional): The amount of padding to remove. Defaults to 1.
40 |
41 | Returns:
42 | torch.Tensor: The unpadded tensor.
43 | """
44 | return tensor[..., pad:-pad, pad:-pad]
45 |
46 | def pad_reshape(tensor, mult=3):
47 | """
48 | Pads a tensor along the last dimension to make its size a multiple of 2^mult and reshapes it.
49 |
50 | Args:
51 | tensor (torch.Tensor): The tensor to pad and reshape.
52 | mult (int, optional): The power of 2 that the tensor's size should be a multiple of. Defaults to 3.
53 |
54 | Returns:
55 | torch.Tensor: The padded and reshaped tensor.
56 | int: The amount of padding applied.
57 | """
58 | tensor, pad = padf(tensor, mult=mult)
59 | tensor = rearrange(tensor, "b c t h w -> b t c h w")
60 | return tensor, pad
61 |
62 | def unpad_reshape(tensor, pad=1):
63 | """
64 | Reshapes a tensor and removes padding from it along the last two dimensions.
65 |
66 | Args:
67 | tensor (torch.Tensor): The tensor to reshape and unpad.
68 | pad (int, optional): The amount of padding to remove. Defaults to 1.
69 |
70 | Returns:
71 | torch.Tensor: The reshaped and unpadded tensor.
72 | """
73 | tensor = rearrange(tensor, "b t c h w -> b c t h w")
74 | tensor = unpadf(tensor, pad=pad)
75 | return tensor
76 |
77 | def instantiate_from_config(config, scope: list[str], return_klass_kwargs=False, **kwargs):
78 | """
79 | Instantiate a class from a config dictionary.
80 |
81 | Args:
82 | config (dict): The config dictionary.
83 | scope (list[str]): The scope of the class to instantiate.
84 | return_klass_kwargs (bool, optional): Whether to return the class and its kwargs. Defaults to False.
85 | **kwargs: Additional keyword arguments to pass to the class constructor.
86 |
87 | Returns:
88 | object: The instantiated class.
89 | (optional) type: The class that was instantiated.
90 | (optional) dict: The kwargs that were passed to the class constructor.
91 | """
92 | okwargs = OmegaConf.to_container(config, resolve=True)
93 | klass_name = okwargs.pop("_class_name")
94 | klass = None
95 |
96 | for module_name in scope:
97 | try:
98 | module = importlib.import_module(module_name)
99 | except ImportError:
100 | continue # Try next module
101 |
102 | klass = getattr(module, klass_name, None)
103 | if klass is not None:
104 | break # Stop when we find a matching class
105 |
106 | assert klass is not None, f"Could not find class {klass_name} in the specified scope"
107 | instance = klass(**okwargs, **kwargs)
108 |
109 | if return_klass_kwargs:
110 | return instance, klass, okwargs
111 | return instance
112 |
113 | def load_model(path):
114 | """
115 | Loads a model from a checkpoint.
116 |
117 | Args:
118 | path (str): The path to the checkpoint.
119 |
120 | Returns:
121 | object: The loaded model.
122 | """
123 | # find config.json
124 | json_path = os.path.join(path, "config.json")
125 | assert os.path.exists(json_path), f"Could not find config.json at {json_path}"
126 | with open(json_path, "r") as f:
127 | config = json.load(f)
128 |
129 | # instantiate class
130 | klass_name = config["_class_name"]
131 | klass = getattr(diffusers, klass_name, None)
132 | if klass is None:
133 | klass = globals().get(klass_name, None)
134 | assert klass is not None, f"Could not find class {klass_name} in diffusers or global scope."
135 | assert getattr(klass, "from_pretrained", None) is not None, f"Class {klass_name} does not support 'from_pretrained'."
136 |
137 | # load checkpoint
138 | model = klass.from_pretrained(path)
139 |
140 | return model
141 |
142 | def save_as_mp4(tensor, filename, fps=30):
143 | """
144 | Saves a 4D tensor (nFrames, height, width, channels) as an MP4 video.
145 |
146 | Parameters:
147 | - tensor: 4D torch.Tensor. Tensor containing the video frames.
148 | - filename: str. The output filename for the video.
149 | - fps: int. Frames per second for the output video.
150 |
151 | Returns:
152 | - None
153 | """
154 | import imageio
155 | # Make sure the tensor is on the CPU and is a numpy array
156 | np_video = tensor.cpu().numpy()
157 |
158 | # Ensure the tensor dtype is uint8
159 | if np_video.dtype != np.uint8:
160 | raise ValueError("The tensor has to be of type uint8")
161 |
162 | # Write the frames to a video file
163 | with imageio.get_writer(filename, fps=fps, ) as writer:
164 | for i in range(np_video.shape[0]):
165 | writer.append_data(np_video[i])
166 |
167 | def save_as_avi(tensor, filename, fps=30):
168 | """
169 | Saves a 4D tensor (nFrames, height, width, channels) as an AVI video with reduced compression.
170 |
171 | Parameters:
172 | - tensor: 4D torch.Tensor. Tensor containing the video frames.
173 | - filename: str. The output filename for the video.
174 | - fps: int. Frames per second for the output video.
175 |
176 | Returns:
177 | - None
178 | """
179 | # Make sure the tensor is on the CPU and is a numpy array
180 | np_video = tensor.cpu().numpy()
181 |
182 | # Ensure the tensor dtype is uint8
183 | if np_video.dtype != np.uint8:
184 | raise ValueError("The tensor has to be of type uint8")
185 |
186 | # Define codec for reduced compression
187 | codec = "mjpeg" # MJPEG codec for AVI files
188 | quality = 10 # High quality (lower values mean higher quality, but larger file sizes)
189 | # pixel_format = "yuvj420p"
190 | # Write the frames to a video file
191 | with imageio.get_writer(filename, fps=fps, codec=codec, quality=quality) as writer:
192 | for frame in np_video:
193 | writer.append_data(frame)
194 |
195 | def save_as_gif(tensor, filename, fps=30):
196 | """
197 | Saves a 4D tensor (nFrames, height, width, channels) as a GIF.
198 |
199 | Parameters:
200 | - tensor: 4D torch.Tensor. Tensor containing the video frames.
201 | - filename: str. The output filename for the GIF.
202 | - fps: int. Frames per second for the output GIF.
203 |
204 | Returns:
205 | - None
206 | """
207 | import imageio
208 | # Make sure the tensor is on the CPU and is a numpy array
209 | np_video = tensor.cpu().numpy()
210 |
211 | # Ensure the tensor dtype is uint8
212 | if np_video.dtype != np.uint8:
213 | raise ValueError("The tensor has to be of type uint8")
214 |
215 | # Write the frames to a GIF file
216 | imageio.mimsave(filename, np_video, fps=fps)
217 |
218 | def save_as_img(tensor, filename, ext="jpg"):
219 | """
220 | Saves a 4D tensor (nFrames, height, width, channels) as a series of JPG images.
221 | OR
222 | Saves a 3D tensor (height, width, channels) as a single image.
223 |
224 | Parameters:
225 | - tensor: 4D torch.Tensor. Tensor containing the video frames.
226 | - filename: str. The output filename for the JPG images.
227 |
228 | Returns:
229 | - None
230 | """
231 | import imageio
232 | # Make sure the tensor is on the CPU and is a numpy array
233 | np_video = tensor.cpu().numpy()
234 |
235 | # Ensure the tensor dtype is uint8
236 | if np_video.dtype != np.uint8:
237 | raise ValueError("The tensor has to be of type uint8")
238 |
239 | # Write the frames to a series of JPG files
240 | if len(np_video.shape) == 3:
241 | imageio.imwrite(filename, np_video, quality=100)
242 | else:
243 | os.makedirs(filename, exist_ok=True)
244 | for i in range(np_video.shape[0]):
245 | imageio.imwrite(os.path.join(filename, f"{i:04d}.{ext}"), np_video[i], quality=100)
246 |
247 | def loadvideo(filename: str, return_fps=False):
248 | """
249 | Loads a video file into a tensor of frames.
250 |
251 | Args:
252 | filename (str): The path to the video file.
253 | return_fps (bool, optional): Whether to return the frames per second of the video. Defaults to False.
254 |
255 | Raises:
256 | FileNotFoundError: If the video file does not exist.
257 |
258 | Returns:
259 | torch.Tensor: A tensor of the video's frames, with shape (frames, height, width, 3).
260 | (optional) float: The frames per second of the video. Only returned if return_fps is True.
261 | """
262 |
263 | if not os.path.exists(filename):
264 | raise FileNotFoundError(filename)
265 | capture = cv2.VideoCapture(filename) # type: ignore
266 |
267 | fps = capture.get(cv2.CAP_PROP_FPS) # type: ignore
268 |
269 | frames = []
270 |
271 | while True: # load all frames
272 | ret, frame = capture.read()
273 | if not ret:
274 | break # Reached end of video
275 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
276 | frame = torch.from_numpy(frame)
277 |
278 | frames.append(frame)
279 | capture.release()
280 |
281 | frames = torch.stack(frames, dim=0) # (frames, height, width, 3)
282 |
283 | if return_fps:
284 | return frames, fps
285 | return frames
286 |
287 | def parse_formats(s):
288 | # Split the input string by comma and strip spaces
289 | formats = [format.strip().lower() for format in s.split(',')]
290 | # Define the allowed choices
291 | allowed_formats = ["avi", "mp4", "gif", "jpg", "png", "pt"]
292 | # Check if all elements in formats are in allowed_formats
293 | for format in formats:
294 | if format not in allowed_formats:
295 | raise argparse.ArgumentTypeError(f"{format} is not a valid format. Choose from {', '.join(allowed_formats)}.")
296 | return formats
297 |
298 |
299 |
300 |
301 |
302 |
303 |
--------------------------------------------------------------------------------
/echosyn/common/datasets.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import pandas as pd
4 | from glob import glob
5 |
6 | import torch
7 | from torch.utils.data import Dataset, DataLoader, ConcatDataset, Sampler
8 | from torchvision import transforms as T
9 | from torchvision.transforms.functional import center_crop
10 |
11 | from PIL import Image
12 |
13 | import decord
14 | decord.bridge.set_bridge('torch')
15 |
16 |
17 | # support video and image + additional info (lvef, view, etc)
18 | class Dynamic(Dataset):
19 | def __init__(self, config, split=["TRAIN", "VAL", "TEST"], datafolder="Videos", ext=".avi") -> None:
20 | super().__init__()
21 | # the config here is only the config for this dataset, ie config.dataset.dynamic
22 |
23 | if type(split) == str:
24 | split = [split]
25 | assert all(s in ["TRAIN", "VAL", "TEST"] for s in split), "Splits must be a list of TRAIN, VAL, TEST"
26 |
27 | # assert type(config.target_fps) == int or config.target_fps in ["original", "random", "exponential"], "target_fps must be an integer, 'original', 'random' or 'exponential'"
28 | self.target_fps = config.target_fps
29 | # self.duration_seconds = config.target_duration
30 | self.resolution = config.target_resolution
31 | self.outputs = config.outputs
32 | if type(self.outputs) == str:
33 | self.outputs = [self.outputs]
34 | assert all(o in ["video", "image", "lvef"] for o in self.outputs), "Outputs must be a list of video, image, lvef"
35 |
36 | # self.duration_frames = int(self.target_fps * self.duration_seconds)
37 | self.duration_frames = config.target_nframes
38 | self.duration_seconds = self.duration_frames / self.target_fps if type(self.target_fps) == int else None
39 |
40 | # LOAD DATA
41 | assert hasattr(config, "root"), "No root folder specified in config"
42 | assert os.path.exists(os.path.join(config.root, datafolder)), f"Data folder {os.path.join(config.root, datafolder)} does not exist"
43 | assert os.path.exists(os.path.join(config.root, "FileList.csv")), f"FileList.csv does not exist in {config.root}"
44 | self.metadata = pd.read_csv(os.path.join(config.root, "FileList.csv"))
45 | self.metadata = self.metadata[self.metadata["Split"].isin(split)] # filter by split
46 | self.len_before_filter = len(self.metadata)
47 | # add duration column
48 | self.metadata["Duration"] = self.metadata["NumberOfFrames"] / self.metadata["FPS"] # won't work for pediatrics
49 | # filter by duration
50 | if self.duration_seconds is not None:
51 | self.metadata = self.metadata[self.metadata["Duration"] > self.duration_seconds]
52 |
53 | # check if videos are reachable
54 | self.metadata["VideoPath"] = self.metadata["FileName"].apply(lambda x: os.path.join(config.root, datafolder, x) if x.endswith(ext) else os.path.join(config.root, datafolder, x.split('.')[0] + ext))
55 | self.metadata["VideoExists"] = self.metadata["VideoPath"].apply(lambda x: os.path.exists(x))
56 | self.metadata = self.metadata[self.metadata["VideoExists"]]
57 | self.metadata.reset_index(inplace=True, drop=True)
58 | if len(self.metadata) == 0:
59 | raise ValueError(f"No data found in folder {os.path.join(config.root, datafolder)}")
60 |
61 | self.transform = lambda x: x
62 | if hasattr(config, "transforms"):
63 | transforms = []
64 | for transform in config.transforms:
65 | tklass = getattr(T, transform.name)
66 | tobj = tklass(**transform.params)
67 | transforms.append(tobj)
68 | self.transform = T.Compose(transforms)
69 |
70 | def __len__(self):
71 | return len(self.metadata)
72 |
73 | def __getitem__(self, idx, return_row=False):
74 | row = self.metadata.iloc[idx]
75 | output = {
76 | 'filename': row['FileName'],
77 | 'still': False,
78 | }
79 |
80 | if "image" in self.outputs or "video" in self.outputs:
81 | reader = decord.VideoReader(row["VideoPath"], ctx=decord.cpu(), width=self.resolution, height=self.resolution)
82 | og_fps = reader.get_avg_fps()
83 | og_frame_count = len(reader)
84 |
85 | if "video" in self.outputs:
86 | # Generate indices to resample
87 | # Generate a random starting point to cover all frames
88 | if self.target_fps == "original":
89 | target_fps = og_fps
90 | elif self.target_fps == "random":
91 | target_fps = np.random.randint(16, 120)
92 | elif self.target_fps == "half":
93 | target_fps = int(og_fps//2)
94 | elif self.target_fps == "exponential":
95 | rnd, offset = np.random.randint(0, 100), 11
96 | target_fps = int(np.exp(rnd/offset) + offset) # min: 12, max: ~8000
97 | else:
98 | target_fps = self.target_fps
99 | new_frame_count = np.floor(target_fps / og_fps * og_frame_count).astype(int)
100 | resample_indices_a = np.linspace(0, og_frame_count-1, new_frame_count, endpoint=False).round().astype(int)
101 | start_idx = np.random.choice(np.arange(0, resample_indices_a[1])) if len(resample_indices_a) > 1 and resample_indices_a[1] > 1 else 0
102 | resample_indices_a = resample_indices_a + start_idx
103 |
104 | # Sample a random chunk to cover the requested duration
105 | start_idx = np.random.choice(np.arange(0, len(resample_indices_a) - self.duration_frames)) if len(resample_indices_a) > self.duration_frames else 0
106 | end_idx = start_idx + self.duration_frames
107 | resample_indices_b = resample_indices_a[start_idx:end_idx]
108 | resample_indices_b = resample_indices_b[resample_indices_b < og_frame_count] # drop indices past the end of the video
109 |
110 | # Gather the selected frames, pad to duration_frames if the clip is too short,
111 | # rescale to [-1, 1] and reorder to channel-first video format.
112 | video = reader.get_batch(resample_indices_b) # T x H x W x C, uint8
113 | video = video.float() / 128.0 - 1 # [-1, 1]
114 | p_index = self.duration_frames - len(video) # number of padded frames
115 | if p_index > 0:
116 | video = torch.cat((video, video[-1:].repeat(p_index, 1, 1, 1)), dim=0) # pad by repeating the last frame
117 | video = video.permute(3, 0, 1, 2) # T x H x W x C -> C x T x H x W
121 | output["video"] = self.transform(video)
122 | output["fps"] = target_fps
123 | output["padding"] = p_index
124 | if self.target_fps == "exponential":
125 | resample_indices_b[1:] = (resample_indices_b[1:] - resample_indices_b[:-1] >= 1).cumsum(0)
126 | resample_indices_b[0] = 0
127 | output["indices"] = np.concatenate((resample_indices_b, np.repeat(resample_indices_b[-1], self.duration_frames-len(resample_indices_b))))
128 |
129 | if "lvef" in self.outputs:
130 | lvef = row["EF"] / 100.0 # normalize to [0, 1]
131 | output["lvef"] = torch.tensor(lvef, dtype=torch.float32)
132 |
133 | if "image" in self.outputs:
134 | image = reader.get_batch(np.random.randint(0, og_frame_count, 1))[0] # H x W x C, uint8
135 | image = image.float() / 128.0 - 1
136 | image = image.permute(2, 0, 1) # H x W x C -> C x H x W
137 | output["image"] = self.transform(image)
138 |
139 | if return_row:
140 | return output, row
141 |
142 | return output
143 |
144 |
145 | class Pediatric(Dynamic):
146 | def __init__(self, config, split=["TRAIN", "VAL", "TEST"]) -> None:
147 | super().__init__(config, split)
148 |
149 | # View
150 | self.view = config.get("views", "ALL") # A4C, PSAX, ALL
151 | if self.view == "ALL":
152 | pass
153 | else:
154 | self.metadata = self.metadata[self.metadata["View"] == self.view]
155 | self.metadata.reset_index(inplace=True, drop=True)
156 | if len(self.metadata) == 0:
157 | raise ValueError(f"No videos found for view {self.view}")
158 |
159 | def __getitem__(self, idx):
160 | output, row = super().__getitem__(idx, return_row=True)
161 | if "view" in self.outputs:
162 | output["view"] = row["View"]
163 |
164 | return output
165 |
166 |
167 | class Latent(Dynamic):
168 | def __init__(self, config, split=["TRAIN", "VAL", "TEST"]) -> None:
169 | self.config = config
170 |
171 | super().__init__(config, split, datafolder="Latents", ext=".pt")
172 |
173 | self.view = config.get("views", "ALL") # A4C, PSAX, ALL
174 | if self.view == "ALL":
175 | pass
176 | else:
177 | self.metadata = self.metadata[self.metadata["View"] == self.view]
178 | self.metadata.reset_index(inplace=True, drop=True)
179 | if len(self.metadata) == 0:
180 | raise ValueError(f"No videos found for view {self.view}")
181 |
182 | def __getitem__(self, idx, return_row=False):
183 | row = self.metadata.iloc[idx]
184 | output = {
185 | 'filename': row['FileName'],
186 | }
187 |
188 | if "image" in self.outputs or "video" in self.outputs:
189 | latent_file = row["VideoPath"]
190 | latent_video_tensor = torch.load(latent_file) # T x C x H x W
191 | og_fps = row["FPS"]
192 | og_frame_count = len(latent_video_tensor)
193 |
194 | if "video" in self.outputs:
195 | if self.target_fps == "original":
196 | target_fps = og_fps
197 | elif self.target_fps == "random":
198 | target_fps = np.random.randint(8, 50)
199 | else:
200 | target_fps = self.target_fps
201 |
202 | new_frame_count = np.floor(target_fps / og_fps * og_frame_count).astype(int)
203 | resample_indices = np.linspace(0, og_frame_count, new_frame_count, endpoint=False).round().astype(int)
204 | start_idx = np.random.choice(np.arange(0, resample_indices[1])) if len(resample_indices) > 1 and resample_indices[1] > 1 else 0
205 | resample_indices = resample_indices + start_idx
206 |
207 | # Sample a random chunk to cover the requested duration
208 | start_idx = np.random.choice(np.arange(0, len(resample_indices) - self.duration_frames)) if len(resample_indices) > self.duration_frames else 0
209 | end_idx = start_idx + self.duration_frames
210 | resample_indices = resample_indices[start_idx:end_idx]
211 | resample_indices = resample_indices[resample_indices < og_frame_count] # drop indices past the end of the video
212 |
213 | # Gather the selected latent frames, pad to duration_frames if needed,
214 | # and reorder from T x C x H x W to C x T x H x W.
215 | latent_video_sample = latent_video_tensor[resample_indices] # T x C x H x W
216 | p_index = self.duration_frames - len(latent_video_sample) # number of padded frames
217 | if p_index > 0:
218 | latent_video_sample = torch.cat((latent_video_sample, latent_video_sample[-1:].repeat(p_index, 1, 1, 1)), dim=0) # pad by repeating the last frame
219 | latent_video_sample = latent_video_sample.permute(1, 0, 2, 3) # T x C x H x W -> C x T x H x W
224 | output["video"] = self.transform(latent_video_sample)
225 | output["fps"] = target_fps
226 | output["padding"] = p_index
227 |
228 | if "lvef" in self.outputs:
229 | lvef = row["EF"] / 100.0
230 | output["lvef"] = torch.tensor(lvef, dtype=torch.float32)
231 |
232 | if "image" in self.outputs:
233 | latent_image_tensor = latent_video_tensor[np.random.randint(0, og_frame_count, 1)][0] # C x H x W
234 | output["image"] = self.transform(latent_image_tensor)
235 |
236 | if return_row:
237 | return output, row
238 |
239 | return output
240 |
241 |
242 | class RandomVideo(Dataset):
243 | def __init__(self, config, split=["TRAIN", "VAL", "TEST"]) -> None:
244 | super().__init__()
245 |
246 | self.config = config
247 | self.root = config.root
248 |
249 | self.target_nframes = config.target_nframes
250 | self.target_resolution = config.target_resolution
251 |
252 | self.outputs = config.outputs
253 | assert len(self.outputs) > 0, "Outputs must not be empty"
254 | assert all([o in ["video", "image"] for o in self.outputs]), "Outputs can only be video or image (or both) for RandomVideo"
255 |
256 |
257 | assert os.path.exists(self.root), f"Root folder {self.root} does not exist"
258 | assert os.path.isdir(self.root), f"Root folder {self.root} is not a directory"
259 | self.all_frames = os.listdir(self.root)
260 |
261 | assert len(self.all_frames) > 0, f"No frames found in {self.root}"
262 |
263 | self.still_image_p = config.get("still_image_p", 0) # probability of returning a still image instead of a video
264 |
265 | def __len__(self):
266 | return len(self.all_frames)
267 |
268 | def __getitem__(self, idx):
269 |
270 | output = {
271 | 'filename': 'Fake',
272 | }
273 |
274 | if "image" in self.outputs:
275 | path = os.path.join(self.root, self.all_frames[idx])
276 | image = Image.open(path) # H x W x C, uint8
277 | image = np.array(image) # H x W x C, uint8
278 | image = torch.from_numpy(image)
279 | image = image.permute(2, 0, 1).float() # C x H x W
280 | image = image / 128.0 - 1 # [-1, 1]
281 | output["image"] = image
282 |
283 | if "video" in self.outputs:
284 | if self.still_image_p > torch.rand(1,).item():
285 | random_indices = np.random.randint(0, len(self.all_frames), 1)
286 | path = os.path.join(self.root, self.all_frames[random_indices[0]])
287 | image = Image.open(path) # H x W x C, uint8
288 | image = np.array(image) # H x W x C, uint8
289 | image = torch.from_numpy(image)
290 | image = image.permute(2, 0, 1).float() # C x H x W
291 | image = image / 128.0 - 1
292 | image = image[:,None,:,:].repeat(1, self.target_nframes, 1, 1)
293 | output["video"] = image
294 | output["still"] = True
295 | else:
296 | random_indices = np.random.randint(0, len(self.all_frames), self.target_nframes)
297 | paths = [os.path.join(self.root, self.all_frames[ridx]) for ridx in random_indices]
298 | images = [Image.open(path) for path in paths]
299 | images = [np.array(image) for image in images]
300 | images = np.stack(images, axis=0) # T x H x W x C
301 | images = torch.from_numpy(images) # T x H x W x C
302 | images = images.permute(3, 0, 1, 2).float() # C x T x H x W
303 | images = images / 128.0 - 1 # [-1, 1]
304 | output["video"] = images
305 | output["still"] = False
306 | output["fps"] = 0
307 | output["padding"] = 0
308 |
309 | return output
310 |
311 |
312 | class FrameFolder(Dataset):
313 | """ config:
314 | - name: FrameFolder
315 | active: true
316 | params:
317 | video_folder: path/to/video_folders
318 | meta_path: path/to/FileList.csv
319 | outputs: ['video', 'image', 'lvef']
320 | """
321 | def __init__(self, config, split=["TRAIN", "VAL", "TEST"]) -> None:
322 | super().__init__()
323 |
324 | self.config = config
325 | self.video_folder = config.video_folder
326 | self.meta_path = config.meta_path
327 |
328 |
329 | self.target_nframes = config.target_nframes
330 | self.target_resolution = config.target_resolution
331 | self.outputs = config.outputs
332 |
332 | self.metadata = pd.read_csv(self.meta_path)
333 | self.metadata = self.metadata[self.metadata["Split"].isin(split)] # filter by split
334 |
335 | # check if videos are reachable
336 | self.metadata["VideoPath"] = self.metadata["FileName"].apply(lambda x: os.path.join(config.video_folder, x.split('.')[0]))
337 | self.metadata["VideoExists"] = self.metadata["VideoPath"].apply(lambda x: ( os.path.isdir(x) and len(os.listdir(x)) > 0 ))
338 | self.metadata = self.metadata[self.metadata["VideoExists"]]
339 |
340 |
341 | def __len__(self):
342 | return len(self.metadata)
343 |
344 | def __getitem__(self, idx):
345 |
346 | row = self.metadata.iloc[idx]
347 |
348 | output = {
349 | 'filename': row['FileName'],
350 | }
351 |
352 | if "image" in self.outputs:
353 | fpath = os.path.join(row['VideoPath'])
354 | rand_item = np.random.choice(os.listdir(fpath))
355 | image = Image.open(os.path.join(fpath, rand_item)) # H x W x C, uint8
356 | image = np.array(image) # H x W x C, uint8
357 | image = torch.from_numpy(image)
358 | image = image.permute(2, 0, 1).float() # C x H x W
359 | image = image / 128.0 - 1 # [-1, 1]
360 | output["image"] = image
361 |
362 | if "video" in self.outputs:
363 | fpath = os.path.join(row['VideoPath'])
364 | fc = self.target_nframes
365 | all_frames_names = sorted(os.listdir(fpath))
366 | if len(all_frames_names) > fc:
367 | start_idx = np.random.randint(0, len(all_frames_names) - fc)
368 | end_idx = start_idx + fc
369 | else:
370 | start_idx = 0
371 | end_idx = len(all_frames_names) # keep all frames; padding happens below
372 | all_frames_names = all_frames_names[start_idx:end_idx]
373 | all_frames_path = [os.path.join(fpath, f) for f in all_frames_names]
374 | all_frames = [Image.open(f) for f in all_frames_path]
375 | all_frames = [np.array(f) for f in all_frames]
376 | all_frames = np.stack(all_frames, axis=0) # T x H x W x C
377 | all_frames = torch.from_numpy(all_frames) # T x H x W x C
378 | all_frames = all_frames.permute(3, 0, 1, 2).float() # C x T x H x W
379 | all_frames = all_frames / 128.0 - 1 # [-1, 1]
380 |
381 | if all_frames.shape[1] < fc: # pad along the temporal dimension (dim 1 after the permute above)
382 | padding_element = torch.zeros_like(all_frames[:, :1])
383 | padding = padding_element.repeat(1, fc - all_frames.shape[1], 1, 1)
384 | all_frames = torch.cat((all_frames, padding), dim=1)
385 | assert all_frames.shape[1] == fc, f"Video length is {all_frames.shape[1]} but should be {fc}"
386 |
387 | output["video"] = all_frames
388 |
389 | if "lvef" in self.outputs:
390 | lvef = row["EF"] / 100.0
391 | output["lvef"] = torch.tensor(lvef, dtype=torch.float32)
392 |
393 | return output
394 |
395 |
396 | def instantiate_dataset(configs, split=["TRAIN", "VAL", "TEST"]):
397 | # config = config.copy()
398 | # assert config.get("datasets", False), "No 'datasets' key found in config"
399 |
400 | # Check if number of frames and resolution are the same for all datasets
401 | target_nframes = None
402 | target_resolution = None
403 | reference_name = None
404 | for dataset_config in configs:
405 | if dataset_config.get("active", True):
406 | if reference_name is None:
407 | reference_name = dataset_config.name
408 | if target_nframes is None:
409 | target_nframes = dataset_config.params.target_nframes
410 | else:
411 | newd = dataset_config.params.target_nframes
412 | assert newd == target_nframes, f"All datasets must output the same number of frames, got {reference_name}: {target_nframes} frames and {dataset_config.name}: {newd} frames."
413 | if target_resolution is None:
414 | target_resolution = dataset_config.params.target_resolution
415 | else:
416 | assert dataset_config.params.target_resolution == target_resolution, f"All datasets must have the same target_resolution, got {reference_name}: {target_resolution} and {dataset_config.name}: {dataset_config.params.target_resolution}."
417 |
418 | datasets = []
419 | for dataset_config in configs:
420 | if dataset_config.get("active", True):
421 | datasets.append(globals()[dataset_config.name](dataset_config.params, split=split))
422 |
423 | if len(datasets) == 1:
424 | return datasets[0]
425 | else:
426 | return ConcatDataset(datasets)
427 |
428 |
429 | class RFBalancer(Dataset): # Real - Fake Balancer
430 | """
431 | Balances the dataset by sampling from each dataset with equal probability.
432 |
433 | """
434 | def __init__(self, real_dataset=None, fake_dataset=None, transform=None) -> None:
435 | super().__init__()
436 |
437 | # self.datasets = [fake_dataset, real_dataset]
438 | self.datasets = []
439 | if fake_dataset is not None:
440 | self.datasets.append(fake_dataset)
441 | if real_dataset is not None:
442 | self.datasets.append(real_dataset)
443 |
444 | if len(self.datasets) == 0:
445 | raise ValueError("At least one dataset must be provided")
446 |
447 | if len(self.datasets) > 1:
448 | self.ds_idx = (np.random.rand(1,) < 0.5)[0] # pick the first dataset to start with
449 | else:
450 | self.ds_idx = 0
451 |
452 | self.ds_current = [0] * len(self.datasets)
453 |
454 | self.transforms = transform
455 |
456 | def __len__(self):
457 | return np.sum([len(ds) for ds in self.datasets])
458 |
459 | def _get_index_for_ds(self, idx):
460 | ds_idx = 0
461 | while True:
462 | if idx < len(self.datasets[ds_idx]):
463 | break
464 | else:
465 | idx -= len(self.datasets[ds_idx])
466 | ds_idx = (ds_idx + 1) % len(self.datasets)
467 | return ds_idx, idx
468 |
469 |
470 | def __getitem__(self, idx):
471 | ds_idx, idx = self._get_index_for_ds(idx)
472 | output = self.datasets[ds_idx][idx] # get item from dataset
473 | output["real"] = float(ds_idx) # add real/fake label
474 |
475 | if self.transforms is not None and "video" in output:
476 | output["video"] = self.transforms(output["video"])
477 | if self.transforms is not None and "image" in output:
478 | output["image"] = self.transforms(output["image"])
479 |
480 | return output
481 |
482 |
483 | class SimaseUSVideoDataset(Dataset):
484 | def __init__(self,
485 | phase='training',
486 | transform=None,
487 | latents_csv='./',
488 | training_latents_base_path="./",
489 | in_memory=True,
490 | generator_seed=None):
491 | self.phase = phase
492 | self.training_latents_base_path = training_latents_base_path
493 |
494 | self.in_memory = in_memory
495 | self.videos = []
496 |
497 | PHASE_TO_SPLIT = {"training": "TRAIN", "validation": "VAL", "testing": "TEST"}
498 | self.df = pd.read_csv(latents_csv)
499 | self.df = self.df[self.df["Split"] == PHASE_TO_SPLIT[self.phase]].reset_index(drop=True)
500 | self.transform = transform
501 |
502 | if generator_seed is None:
503 | self.generator = np.random.default_rng()
504 | #unseeded
505 | else:
506 | self.generator_seed = generator_seed
507 | print(f"Set {self.phase} dataset seed to {self.generator_seed}")
508 |
509 | if self.in_memory:
510 | self.load_videos()
511 |
512 | def __len__(self):
513 | return len(self.df)
514 |
515 | def __getitem__(self, index):
516 | vid_a = self.get_vid(index)
517 | if self.transform is not None:
518 | vid_a = self.transform(vid_a)
519 | return vid_a
520 |
521 | def reset_generator(self):
522 | self.generator = np.random.default_rng(self.generator_seed)
523 |
524 | def get_vid(self, index, from_disk=False):
525 | if self.in_memory and not from_disk:
526 | return self.videos[index]
527 | else:
528 | path = self.df.iloc[index]["FileName"].split('.')[0] + ".pt"
529 | path = os.path.join(self.training_latents_base_path, path)
530 | return torch.load(path)
531 |
532 | def load_videos(self):
533 | self.videos = []
534 | print("Preloading videos")
535 | for i in range(len(self)):
536 | self.videos.append(self.get_vid(i, from_disk=True))
537 |
538 |
539 | class SiameseUSDataset(Dataset):
540 | def __init__(self,
541 | phase='training',
542 | transform=None,
543 | latents_csv='./',
544 | training_latents_base_path="./",
545 | in_memory=True,
546 | generator_seed=None):
547 | self.phase = phase
548 | self.training_latents_base_path = training_latents_base_path
549 |
550 | self.in_memory = in_memory
551 | self.videos = []
552 |
553 | PHASE_TO_SPLIT = {"training": "TRAIN", "validation": "VAL", "testing": "TEST"}
554 | self.df = pd.read_csv(latents_csv)
555 | self.df = self.df[self.df["Split"] == PHASE_TO_SPLIT[self.phase]].reset_index(drop=True)
556 |
557 | self.transform = transform
558 |
559 | if generator_seed is None:
560 | self.generator = np.random.default_rng()
561 | #unseeded
562 | else:
563 | self.generator_seed = generator_seed
564 | print(f"Set {self.phase} dataset seed to {self.generator_seed}")
565 |
566 | if self.in_memory:
567 | self.load_videos()
568 |
569 | def __len__(self):
570 | return len(self.df)
571 |
572 | def __getitem__(self, index):
573 |
574 | vid_a = torch.clone(self.get_vid(index))
575 | if self.generator.uniform() < 0.5:
576 | vid_b = torch.clone(self.get_vid((index + self.generator.integers(low=1, high=len(self))) % len(self))) # random different vid
577 | y = 0.0
578 | else:
579 | vid_b = torch.clone(vid_a)
580 | y = 1.0
581 |
582 | if self.transform is not None:
583 | vid_a = self.transform(vid_a)
584 | vid_b = self.transform(vid_b)
585 |
586 | frame_a = self.generator.integers(len(vid_a))
587 | frame_b = (frame_a + self.generator.integers(low=1, high=len(vid_b))) % len(vid_b)
588 | #print(f"Dataloader: framea {frame_a} - frame_b {frame_b} - y: {y}")
589 | return vid_a[frame_a], vid_b[frame_b], y
590 |
591 | def reset_generator(self):
592 | self.generator = np.random.default_rng(self.generator_seed)
593 |
594 | def get_vid(self, index, from_disk=False):
595 | if self.in_memory and not from_disk:
596 | return self.videos[index]
597 | else:
598 | path = self.df.iloc[index]["FileName"].split('.')[0] + ".pt"
599 | path = os.path.join(self.training_latents_base_path, path)
600 | return torch.load(path)
601 |
602 | def load_videos(self):
603 | self.videos = []
604 | print("Preloading videos")
605 | for i in range(len(self)):
606 | self.videos.append(self.get_vid(i, from_disk=True))
607 |
608 |
609 | class ImageSet(Dataset):
610 | def __init__(self, root, ext=".jpg"):
611 | self.root = root
612 | self.all_images = glob(os.path.join(root, f"*{ext}"))
613 |
614 | def __len__(self):
615 | return len(self.all_images)
616 |
617 | def __getitem__(self, idx):
618 | image = Image.open(self.all_images[idx])
619 | image = np.array(image) / 128.0 - 1 # [0, 255] -> [-1, 1]
620 | image = image.transpose(2, 0, 1) # H x W x C -> C x H x W
621 | return image
622 |
623 | class TensorSet(Dataset):
624 | def __init__(self, root):
625 | self.root = root
626 | self.all_tensors = glob(os.path.join(root, "*.pt"))
627 |
628 | def __len__(self):
629 | return len(self.all_tensors)
630 |
631 | def __getitem__(self, idx):
632 | tensor = torch.load(self.all_tensors[idx], map_location='cpu')
633 | return tensor
634 |
635 | if __name__ == "__main__":
636 | pass
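# Example usage (hypothetical values; the parameter names follow the dataset classes above):
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create([{
#       "name": "Latent", "active": True,
#       "params": {"root": "data/latents/dynamic", "target_fps": 32,
#                  "target_nframes": 64, "target_resolution": 112,
#                  "outputs": ["video", "lvef"]}}])
#   dataset = instantiate_dataset(cfg, split="TRAIN")
#   sample = dataset[0]  # dict with "video" (C x T x H x W), "lvef", "fps", "padding", ...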
637 |
638 |
--------------------------------------------------------------------------------
/echosyn/common/privacy_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils import data
4 | from sklearn import metrics
5 | import math
6 | import random
7 | import itertools
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | from echosyn.common.datasets import SiameseUSDataset
11 |
12 |
13 | """
14 | This file provides the most important functions that are used in our experiments. These functions are called in
15 | AgentSiameseNetwork.py which provides the actual training/validation loop and the code for evaluation.
16 | """
17 |
18 |
19 | # Function to get the data loader.
20 | def get_data_loaders(phase='training', data_handling='balanced', n_channels=3, n_samples=100000, transform=None,
21 | image_path='./', batch_size=32, shuffle=True, num_workers=16, pin_memory=True, save_path=None, generator_seed=None):
22 |
23 | #dataset = SiameseDataset(phase=phase, data_handling=data_handling, n_channels=n_channels, n_samples=n_samples,
24 | # transform=transform, image_path=image_path, save_path=save_path)
25 | latents_csv = image_path
26 | training_latents_base_path = os.path.join(os.path.dirname(latents_csv), "Latents")
27 | dataset = SiameseUSDataset(phase=phase, transform=transform, latents_csv=latents_csv, training_latents_base_path=training_latents_base_path, in_memory=False, generator_seed=generator_seed)
28 | dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1,  # note: the num_workers argument is not forwarded; a single worker is used
29 | pin_memory=pin_memory)
30 |
31 | return dataloader
32 |
33 |
34 | # This function represents the training loop for the standard case where we have two input images and one output node.
35 | def train(net, training_loader, n_samples, batch_size, criterion, optimizer, epoch, n_epochs):
36 | net.train()
37 | running_loss = 0.0
38 |
39 | print('Training----->')
40 | for i, batch in enumerate(training_loader):
41 | inputs1, inputs2, labels = batch
42 | inputs1, inputs2, labels = inputs1.cuda(), inputs2.cuda(), labels.cuda()
43 |
44 | # zero the parameter gradients
45 | optimizer.zero_grad()
46 |
47 | # forward + backward + optimize
48 | outputs = net(inputs1, inputs2)
49 | outputs = outputs.squeeze()
50 | labels = labels.type_as(outputs)
51 | loss = criterion(outputs, labels)
52 | loss.backward()
53 | optimizer.step()
54 |
55 | running_loss += loss.item()
56 |
57 | print('Epoch [%d/%d], Iteration [%d/%d], Loss: %.4f' % (epoch + 1, n_epochs, i + 1,
58 | math.ceil(n_samples / batch_size),
59 | loss.item()))
60 |
61 | # Compute the average loss per epoch
62 | training_loss = running_loss / math.ceil(n_samples / batch_size)
63 | return training_loss
64 |
65 |
66 | # This function represents the validation loop for the standard case where we have two input images and one output node.
67 | def validate(net, validation_loader, n_samples, batch_size, criterion, epoch, n_epochs):
68 | net.eval()
69 | running_loss = 0
70 |
71 | print('Validating----->')
72 | with torch.no_grad():
73 | for i, batch in enumerate(validation_loader):
74 | inputs1, inputs2, labels = batch
75 | inputs1, inputs2, labels = inputs1.cuda(), inputs2.cuda(), labels.cuda()
76 |
77 | # forward
78 | outputs = net(inputs1, inputs2)
79 | outputs = outputs.squeeze()
80 | labels = labels.type_as(outputs)
81 | loss = criterion(outputs, labels)
82 |
83 | running_loss += loss.item()
84 |
85 | print('Epoch [%d/%d], Iteration [%d/%d], Loss: %.4f' % (epoch + 1, n_epochs, i + 1,
86 | math.ceil(n_samples / batch_size),
87 | loss.item()))
88 |
89 | # Compute the average loss per epoch
90 | validation_loss = running_loss / math.ceil(n_samples / batch_size)
91 | return validation_loss
92 |
93 |
94 | # This function represents the test loop for the standard case where we have two input images and one output node.
95 | # This function returns the true labels and the predicted values.
96 | def test(net, test_loader):
97 | net.eval()
98 | y_true = None
99 | y_pred = None
100 |
101 | print('Testing----->')
102 | with torch.no_grad():
103 | for i, batch in enumerate(test_loader):
104 | inputs1, inputs2, labels = batch
105 |
106 | if y_true is None:
107 | y_true = labels
108 | else:
109 | y_true = torch.cat((y_true, labels), 0)
110 |
111 | inputs1, inputs2, labels = inputs1.cuda(), inputs2.cuda(), labels.cuda()
112 | outputs = net(inputs1, inputs2)
113 | outputs = torch.sigmoid(outputs)
114 |
115 | if y_pred is None:
116 | y_pred = outputs.cpu()
117 | else:
118 | y_pred = torch.cat((y_pred, outputs.cpu()), 0)
119 |
120 | y_pred = y_pred.squeeze()
121 | return y_true, y_pred
122 |
123 |
124 | # This function computes some standard evaluation metrics given the true labels and the predicted values.
125 | def get_evaluation_metrics(y_true, y_pred):
126 | accuracy = metrics.accuracy_score(y_true, y_pred)
127 | f1_score = metrics.f1_score(y_true, y_pred)
128 | precision = metrics.precision_score(y_true, y_pred)
129 | recall = metrics.recall_score(y_true, y_pred)
130 | report = metrics.classification_report(y_true, y_pred)
131 | confusion_matrix = metrics.confusion_matrix(y_true, y_pred)
132 | return accuracy, f1_score, precision, recall, report, confusion_matrix
133 |
134 |
135 | # This function is used to apply a threshold to the predicted values before computing some evaluation metrics.
136 | def apply_threshold(input_tensor, threshold):
137 | output = np.where(input_tensor > threshold, torch.ones(len(input_tensor)), torch.zeros(len(input_tensor)))
138 | return output
139 |
140 |
141 | # This function implements bootstrapping, in order to get the mean AUC value and the 95% confidence interval.
142 | def bootstrap(n_bootstraps, y_true, y_pred, path, experiment_description):
143 | y_true = np.array(y_true)
144 | y_pred = np.array(y_pred)
145 |
146 | bootstrapped_scores = []
147 |
148 | f = open(path + experiment_description + '_AUC_bootstrapped.txt', "w+")
149 | f.write('AUC_bootstrapped\n')
150 |
151 | for i in range(n_bootstraps):
152 | indices = np.random.randint(0, len(y_pred), len(y_pred))  # sample indices with replacement from [0, len)
153 | auc = metrics.roc_auc_score(y_true[indices], y_pred[indices])
154 | bootstrapped_scores.append(auc)
155 | f.write(str(auc) + '\n')
156 | f.close()
157 | sorted_scores = np.array(bootstrapped_scores)
158 | sorted_scores.sort()
159 |
160 | auc_mean = np.mean(sorted_scores)
161 | confidence_lower = sorted_scores[int(0.025 * len(sorted_scores))]
162 | confidence_upper = sorted_scores[int(0.975 * len(sorted_scores))]
163 |
164 | f = open(path + experiment_description + '_AUC_confidence.txt', "w+")
165 | f.write('AUC_mean: %s\n' % auc_mean)
166 | f.write('Confidence interval for the AUC score: ' + str(confidence_lower) + ' - ' + str(confidence_upper))
167 | f.close()
168 | return auc_mean, confidence_lower, confidence_upper
169 |
170 |
171 | # This is a function that plots the loss curves.
172 | def plot_loss_curves(loss_dict, path, experiment_description):
173 | plt.figure()
174 | plt.plot(range(1, len(loss_dict['training']) + 1), loss_dict['training'], label='Training Loss')
175 | plt.plot(range(1, len(loss_dict['validation']) + 1), loss_dict['validation'], label='Validation Loss')
176 | plt.xlabel('Epoch')
177 | plt.ylabel('Loss')
178 | plt.title('Training and validation loss curves')
179 | plt.legend()
180 | plt.savefig(path + experiment_description + '_loss_curves.png')
181 |
182 |
183 | # This is a function that plots the ROC curve.
184 | def plot_roc_curve(fp_rates, tp_rates, path, experiment_description):
185 | plt.figure()
186 | plt.plot(fp_rates, tp_rates, label='ROC Curve')
187 | plt.xlabel('False positive rate')
188 | plt.ylabel('True positive rate')
189 | plt.title('ROC Curve')
190 | plt.legend()
191 | plt.savefig(path + experiment_description + '_ROC_curve.png')
192 |
193 |
194 | # This is a function that saves the evaluation metrics to a file.
195 | def save_results_to_file(auc, accuracy, f1_score, precision, recall, report, confusion_matrix, path,
196 | experiment_description):
197 | f = open(path + experiment_description + '_results.txt', "w+")
198 | f.write('AUC: %s\n' % auc)
199 | f.write('Accuracy: %s\n' % accuracy)
200 | f.write('F1-Score: %s\n' % f1_score)
201 | f.write('Precision: %s\n' % precision)
202 | f.write('Recall: %s\n' % recall)
203 | f.write('Classification report: %s\n' % report)
204 | f.write('Confusion matrix: %s\n' % confusion_matrix)
205 | f.close()
206 |
207 |
208 | # This function saves the ROC metrics to a file.
209 | def save_roc_metrics_to_file(fp_rates, tp_rates, thresholds, path, experiment_description):
210 | f = open(path + experiment_description + '_ROC_metrics.txt', "w+")
211 | f.write('FP_rate\tTP_rate\tThreshold\n')
212 | for i in range(len(fp_rates)):
213 | f.write(str(fp_rates[i]) + '\t' + str(tp_rates[i]) + '\t' + str(thresholds[i]) + '\n')
214 | f.close()
215 |
216 |
217 | # This function saves the training and validation loss values to a file.
218 | def save_loss_curves(loss_dict, path, experiment_description):
219 | f = open(path + experiment_description + '_loss_values.txt', "w+")
220 | f.write('TrainingLoss\tValidationLoss\n')
221 | for i in range(len(loss_dict['training'])):
222 | f.write(str(loss_dict['training'][i]) + '\t' + str(loss_dict['validation'][i]) + '\n')
223 | f.close()
224 |
225 |
226 | # This function is used to save a checkpoint. This enables us to resume an experiment at a later point if wanted.
227 | def save_checkpoint(epoch, model, optimizer, loss_dict, best_loss, num_bad_epochs, filename='checkpoint.pth'):
228 | state = {
229 | 'epoch': epoch + 1,
230 | 'state_dict': model.state_dict(),
231 | 'optimizer': optimizer.state_dict(),
232 | 'loss_dict': loss_dict,
233 | 'best_loss': best_loss,
234 | 'num_bad_epochs': num_bad_epochs
235 | }
236 | torch.save(state, filename)
237 |
238 |
239 | # This function saves the true labels, the predicted values, and the thresholded values to a file.
240 | def save_labels_predictions(y_true, y_pred, y_pred_thresh, path, experiment_description):
241 | f = open(path + experiment_description + '_labels_predictions.txt', "w+")
242 | f.write('Label\tPrediction\tPredictionThresholded\n')
243 | for i in range(len(y_true)):
244 | f.write(str(y_true[i]) + '\t' + str(y_pred[i]) + '\t' + str(y_pred_thresh[i]) + '\n')
245 | f.close()
246 |
247 |
248 | # This function was utilized to construct the positive and negative pairs needed for training and testing our
249 | # verification network. This function takes filenames as an input argument and returns both the list of tuples and the
250 | # list of corresponding labels. The constructed pairs were later saved to a .txt file. Now available in the folder
251 | # './image_pairs/' (pairs_training.txt, pairs_validation.txt, pairs_testing.txt).
252 | # Example:
253 | #
254 | # train_val_filenames = np.loadtxt('train_val_list.txt', dtype=str)
255 | # test_filenames = np.loadtxt('test_list.txt', dtype=str)
256 | # tuples_train, labels_train = get_tuples_labels(train_val_filenames[:75708])
257 | # tuples_val, labels_val = get_tuples_labels(train_val_filenames[75708:])
258 | # tuples_test, labels_test = get_tuples_labels(test_filenames)
259 | #
260 | def get_tuples_labels(filenames):
261 | tuples_list = []
262 | labels_list = []
263 | patients = []
264 |
265 | for i, element in enumerate(filenames):
266 | element = element[:-8]
267 | patients.append(element)
268 |
269 | unique_patients, counts_patients = np.unique(patients, return_counts=True)
270 | patients_dict = dict(zip(unique_patients, counts_patients))
271 |
272 | start_idx = 0
273 | patients_files_dict = {}
274 |
275 | for key in patients_dict:
276 | patients_files_dict[key] = filenames[start_idx:start_idx + patients_dict[key]]
277 | start_idx += patients_dict[key]
278 |
279 | if len(patients_files_dict[key]) > 1:
280 | # samples = list(itertools.product(patients_files_dict[key], patients_files_dict[key]))
281 | samples = list(itertools.combinations(patients_files_dict[key], 2))
282 | tuples_list.append(samples)
283 | labels_list.append(np.ones(len(samples)).tolist())
284 |
285 | tuples_list = list(itertools.chain.from_iterable(tuples_list))
286 | labels_list = list(itertools.chain.from_iterable(labels_list))
287 |
288 | N = len(tuples_list)
289 | i = 0
290 |
291 | while i < N:
292 | file1 = random.choice(filenames)
293 | file2 = random.choice(filenames)
294 |
295 | if file1[:-8] != file2[:-8]:
296 | sample = (file1, file2)
297 | tuples_list.append(sample)
298 | labels_list.append(0.0)
299 | i += 1
300 |
301 | return tuples_list, labels_list
302 |
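303 | # Illustrative end-to-end sketch of how the functions above fit together. This is a
304 | # commented example only: `net`, `optimizer`, `n_epochs` and the loss are placeholders supplied
305 | # by the caller (the actual training/validation loop lives in the script that imports this
306 | # module), and the CSV path is hypothetical.
307 | #
308 | # train_loader = get_data_loaders(phase='training', image_path='path/to/latents.csv',
309 | #                                 batch_size=32, generator_seed=42)
310 | # criterion = torch.nn.BCEWithLogitsLoss()  # assumed; `test` applies a sigmoid to the raw outputs
311 | # for epoch in range(n_epochs):
312 | #     training_loss = train(net, train_loader, len(train_loader.dataset), 32, criterion, optimizer, epoch, n_epochs)
313 | #
314 | # y_true, y_pred = test(net, test_loader)
315 | # y_pred_thresh = apply_threshold(y_pred, 0.5)
316 | # accuracy, f1, precision, recall, report, cm = get_evaluation_metrics(y_true, y_pred_thresh)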
--------------------------------------------------------------------------------
/echosyn/lidm/README.md:
--------------------------------------------------------------------------------
1 | # Latent Image Diffusion Model
2 |
3 | The Latent Image Diffusion Model (LIDM) is the first step of our generative pipeline. It generates a latent representation of a heart, which is then passed to the Latent Video Diffusion Model (LVDM) to generate a video of the heart beating.
4 |
5 | Training the LIDM is straightforward and does not require a lot of resources. In the paper, the LIDMs are trained for ~24h on a single A100 GPU. The batch size can be adjusted to fit smaller GPUs with no noticeable loss of quality.
6 |
7 | ## 1. Activate the environment
8 |
9 | First, activate the echosyn environment.
10 |
11 | ```bash
12 | conda activate echosyn
13 | ```
14 |
15 | ## 2. Data preparation
16 | Follow the instructions in the [Data preparation](../../README.md#data-preparation) section to prepare the data for training. Here, you need the VAE-encoded videos.
17 |
18 | ## 3. Train the LIDM
19 | Once the environment is set up and the data is ready, you can train the LIDMs with the following commands:
20 |
21 | ```bash
22 | python echosyn/lidm/train.py --config echosyn/lidm/configs/dynamic.yaml
23 | python echosyn/lidm/train.py --config echosyn/lidm/configs/ped_a4c.yaml
24 | python echosyn/lidm/train.py --config echosyn/lidm/configs/ped_psax.yaml
25 | ```
26 |
27 | ## 4. Sample from the LIDM
28 |
29 | Once the LIDMs are trained, you can sample from them with the following command:
30 |
31 | ```bash
32 | # For the Dynamic dataset
33 | python echosyn/lidm/sample.py \
34 | --config echosyn/lidm/configs/dynamic.yaml \
35 | --unet experiments/lidm_dynamic/checkpoint-500000/unet_ema \
36 | --vae models/vae \
37 | --output samples/lidm_dynamic \
38 | --num_samples 50000 \
39 | --batch_size 128 \
40 | --num_steps 64 \
41 | --save_latent \
42 | --seed 0
43 | ```
44 |
45 | ```bash
46 | # For the Pediatric A4C dataset
47 | python echosyn/lidm/sample.py \
48 | --config echosyn/lidm/configs/ped_a4c.yaml \
49 | --unet experiments/lidm_ped_a4c/checkpoint-500000/unet_ema \
50 | --vae models/vae \
51 | --output samples/lidm_ped_a4c \
52 | --num_samples 50000 \
53 | --batch_size 128 \
54 | --num_steps 64 \
55 | --save_latent \
56 | --seed 0
57 | ```
58 |
59 | ```bash
60 | # For the Pediatric PSAX dataset
61 | python echosyn/lidm/sample.py \
62 | --config echosyn/lidm/configs/ped_psax.yaml \
63 | --unet experiments/lidm_ped_psax/checkpoint-500000/unet_ema \
64 | --vae models/vae \
65 | --output samples/lidm_ped_psax \
66 | --num_samples 50000 \
67 | --batch_size 128 \
68 | --num_steps 64 \
69 | --save_latent \
70 | --seed 0
71 | ```
72 |
73 | ## 5. Evaluate the LIDM
74 |
75 | To evaluate the LIDMs, we use the FID and IS scores.
76 | To do so, we need to generate 50,000 samples using the command above.
77 | Note that the privacy step will later reject some of these samples.
78 | It is therefore better to generate 100,000 samples (for example by re-running the sampling command with `--num_samples 100000`), so that the FID and IS scores can be recomputed on the privacy-compliant subset.
79 | The generated samples are compared to the real samples created in the [Data preparation](../../README.md#data-preparation) step.
80 |
81 | Then, to evaluate the samples, run the following commands:
82 |
83 | ```bash
84 | cd external/stylegan-v
85 |
86 | # For the Dynamic dataset
87 | python src/scripts/calc_metrics_for_dataset.py \
88 | --real_data_path ../../data/reference/dynamic \
89 | --fake_data_path ../../samples/lidm_dynamic/images \
90 | --mirror 0 --gpus 1 --resolution 112 \
91 | --metrics fid50k_full,is50k >> "../../samples/lidm_dynamic/metrics.txt"
92 |
93 | # For the Pediatric A4C dataset
94 | python src/scripts/calc_metrics_for_dataset.py \
95 | --real_data_path ../../data/reference/ped_a4c \
96 | --fake_data_path ../../samples/lidm_ped_a4c/images \
97 | --mirror 0 --gpus 1 --resolution 112 \
98 | --metrics fid50k_full,is50k >> "../../samples/lidm_ped_a4c/metrics.txt"
99 |
100 | # For the Pediatric PSAX dataset
101 | python src/scripts/calc_metrics_for_dataset.py \
102 | --real_data_path ../../data/reference/ped_psax \
103 | --fake_data_path ../../samples/lidm_ped_psax/images \
104 | --mirror 0 --gpus 1 --resolution 112 \
105 | --metrics fid50k_full,is50k >> "../../samples/lidm_ped_psax/metrics.txt"
106 | ```
107 |
108 | ## 6. Save the LIDMs for later use
109 | Once you are satisfied with the performance of the LIDMs, you can save them for later use with the following commands:
110 |
111 | ```bash
112 | mkdir -p models/lidm_dynamic; cp -r experiments/lidm_dynamic/checkpoint-500000/unet_ema/* models/lidm_dynamic/; cp experiments/lidm_dynamic/config.yaml models/lidm_dynamic/
113 | mkdir -p models/lidm_ped_a4c; cp -r experiments/lidm_ped_a4c/checkpoint-500000/unet_ema/* models/lidm_ped_a4c/; cp experiments/lidm_ped_a4c/config.yaml models/lidm_ped_a4c/
114 | mkdir -p models/lidm_ped_psax; cp -r experiments/lidm_ped_psax/checkpoint-500000/unet_ema/* models/lidm_ped_psax/; cp experiments/lidm_ped_psax/config.yaml models/lidm_ped_psax/
115 | ```
116 |
117 | This will save the selected ema version of the model, ready to be loaded in any other script as a standalone model.
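118 |
119 | For reference, the copied folder is a standard diffusers model directory. It can be reloaded with the repository's `load_model` helper (as done in `sample.py`), or directly with diffusers. A minimal sketch, assuming the `models/lidm_dynamic` and `models/vae` folders created above:
120 |
121 | ```python
122 | from diffusers import AutoencoderKL, UNet2DModel
123 |
124 | unet = UNet2DModel.from_pretrained("models/lidm_dynamic").eval()
125 | vae = AutoencoderKL.from_pretrained("models/vae").eval()
126 | ```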
--------------------------------------------------------------------------------
/echosyn/lidm/configs/dynamic.yaml:
--------------------------------------------------------------------------------
1 | wandb_group: "lidm"
2 | output_dir: experiments/lidm_dynamic
3 |
4 | pretrained_model_name_or_path: null
5 | vae_path: models/vae
6 |
7 | globals:
8 | target_fps: 32
9 | target_nframes: 64
10 | outputs: ["image"]
11 |
12 | datasets:
13 | - name: Latent
14 | active: true
15 | params:
16 | root: data/latents/dynamic
17 | target_fps: ${globals.target_fps}
18 | target_nframes: ${globals.target_nframes}
19 | target_resolution: 14 # emb resolution
20 | outputs: ${globals.outputs}
21 |
22 | unet:
23 | _class_name: UNet2DModel
24 | sample_size: 14 # actual size is 16
25 | in_channels: 4
26 | out_channels: 4
27 | center_input_sample: false
28 | time_embedding_type: positional
29 | freq_shift: 0
30 | flip_sin_to_cos: true
31 | down_block_types:
32 | - AttnDownBlock2D
33 | - AttnDownBlock2D
34 | - AttnDownBlock2D
35 | - DownBlock2D
36 | up_block_types:
37 | - UpBlock2D
38 | - AttnUpBlock2D
39 | - AttnUpBlock2D
40 | - AttnUpBlock2D
41 | block_out_channels:
42 | - 128
43 | - 256
44 | - 256
45 | - 512
46 | layers_per_block: 2
47 | mid_block_scale_factor: 1
48 | downsample_padding: 1
49 | downsample_type: resnet
50 | upsample_type: resnet
51 | dropout: 0.0
52 | act_fn: silu
53 | attention_head_dim: 8
54 | norm_num_groups: 32
55 | attn_norm_num_groups: null
56 | norm_eps: 1e-05
57 | resnet_time_scale_shift: "default"
58 | class_embed_type: null
59 | num_class_embeds: null
60 |
61 | noise_scheduler:
62 | _class_name: DDPMScheduler
63 | num_train_timesteps: 1000
64 | beta_start: 0.0001
65 | beta_end: 0.02
66 | beta_schedule: linear # linear, scaled_linear, or squaredcos_cap_v2
67 | variance_type: fixed_small # fixed_small, fixed_small_log, fixed_large, fixed_large_log, learned or learned_range
68 | clip_sample: true
69 | clip_sample_range: 4.0 # default 1
70 | prediction_type: v_prediction # epsilon, sample, v_prediction
71 | thresholding: false # do not touch
72 | dynamic_thresholding_ratio: 0.995 # unused
73 | sample_max_value: 1.0 # unused
74 | timestep_spacing: "leading" #
75 | steps_offset: 0 # unused
76 |
77 | train_batch_size: 256
78 | dataloader_num_workers: 16
79 | max_train_steps: 500000
80 |
81 | learning_rate: 3e-4
82 | lr_warmup_steps: 500
83 | scale_lr: false
84 | lr_scheduler: constant
85 | use_8bit_adam: false
86 | gradient_accumulation_steps: 1
87 |
88 | noise_offset: 0.0
89 |
90 | gradient_checkpointing: false
91 | use_ema: true
92 | enable_xformers_memory_efficient_attention: false
93 | allow_tf32: true
94 |
95 | adam_beta1: 0.9
96 | adam_beta2: 0.999
97 | adam_weight_decay: 1e-2
98 | adam_epsilon: 1e-08
99 | max_grad_norm: 1.0
100 |
101 | logging_dir: logs
102 | mixed_precision: "fp16" # "no", "fp16", "bf16"
103 |
104 | validation_timesteps: 128
105 | validation_fps: ${globals.target_fps}
106 | validation_frames: ${globals.target_nframes}
107 | validation_count: 4 # defines the number of samples
108 | validation_guidance: 1.0
109 | validation_steps: 2500
110 |
111 | report_to: wandb
112 | checkpointing_steps: 10000 # ~3/hour
113 | checkpoints_total_limit: 100 # no limit
114 | resume_from_checkpoint: latest
115 | tracker_project_name: echosyn
116 |
117 | seed: 42
118 |
119 |
--------------------------------------------------------------------------------
/echosyn/lidm/configs/ped_a4c.yaml:
--------------------------------------------------------------------------------
1 | wandb_group: "lidm"
2 | output_dir: experiments/lidm_ped_a4c
3 |
4 | pretrained_model_name_or_path: null
5 | vae_path: models/vae
6 |
7 | globals:
8 | target_fps: "original"
9 | target_nframes: 32
10 | outputs: ["image"]
11 |
12 | datasets:
13 | - name: Latent
14 | active: true
15 | params:
16 | root: data/latents/ped_a4c
17 | target_fps: ${globals.target_fps}
18 | target_nframes: ${globals.target_nframes}
19 | target_resolution: 14 # emb resolution
20 | outputs: ${globals.outputs}
21 |
22 | unet:
23 | _class_name: UNet2DModel
24 | sample_size: 14 # actual size is 16
25 | in_channels: 4
26 | out_channels: 4
27 | center_input_sample: false
28 | time_embedding_type: positional
29 | freq_shift: 0
30 | flip_sin_to_cos: true
31 | down_block_types:
32 | - AttnDownBlock2D
33 | - AttnDownBlock2D
34 | - AttnDownBlock2D
35 | - DownBlock2D
36 | up_block_types:
37 | - UpBlock2D
38 | - AttnUpBlock2D
39 | - AttnUpBlock2D
40 | - AttnUpBlock2D
41 | block_out_channels:
42 | - 128
43 | - 256
44 | - 256
45 | - 512
46 | layers_per_block: 2
47 | mid_block_scale_factor: 1
48 | downsample_padding: 1
49 | downsample_type: resnet
50 | upsample_type: resnet
51 | dropout: 0.0
52 | act_fn: silu
53 | attention_head_dim: 8
54 | norm_num_groups: 32
55 | attn_norm_num_groups: null
56 | norm_eps: 1e-05
57 | resnet_time_scale_shift: "default"
58 | class_embed_type: null
59 | num_class_embeds: null
60 |
61 | noise_scheduler:
62 | _class_name: DDPMScheduler
63 | num_train_timesteps: 1000
64 | beta_start: 0.0001
65 | beta_end: 0.02
66 | beta_schedule: linear # linear, scaled_linear, or squaredcos_cap_v2
67 | variance_type: fixed_small # fixed_small, fixed_small_log, fixed_large, fixed_large_log, learned or learned_range
68 | clip_sample: true
69 | clip_sample_range: 4.0 # default 1
70 | prediction_type: v_prediction # epsilon, sample, v_prediction
71 | thresholding: false # do not touch
72 | dynamic_thresholding_ratio: 0.995 # unused
73 | sample_max_value: 1.0 # unused
74 | timestep_spacing: "leading" #
75 | steps_offset: 0 # unused
76 |
77 | train_batch_size: 256
78 | dataloader_num_workers: 16
79 | max_train_steps: 500000
80 |
81 | learning_rate: 3e-4
82 | lr_warmup_steps: 500
83 | scale_lr: false
84 | lr_scheduler: constant
85 | use_8bit_adam: false
86 | gradient_accumulation_steps: 1
87 |
88 | noise_offset: 0.0
89 |
90 | gradient_checkpointing: false
91 | use_ema: true
92 | enable_xformers_memory_efficient_attention: false
93 | allow_tf32: true
94 |
95 | adam_beta1: 0.9
96 | adam_beta2: 0.999
97 | adam_weight_decay: 1e-2
98 | adam_epsilon: 1e-08
99 | max_grad_norm: 1.0
100 |
101 | logging_dir: logs
102 | mixed_precision: "fp16" # "no", "fp16", "bf16"
103 |
104 | validation_timesteps: 128
105 | validation_fps: ${globals.target_fps}
106 | validation_frames: ${globals.target_nframes}
107 | validation_count: 4 # defines the number of samples
108 | validation_guidance: 1.0
109 | validation_steps: 2500
110 |
111 | report_to: wandb
112 | checkpointing_steps: 10000 # ~3/hour
113 | checkpoints_total_limit: 100 # no limit
114 | resume_from_checkpoint: latest
115 | tracker_project_name: echosyn
116 |
117 | seed: 42
118 |
119 |
--------------------------------------------------------------------------------
/echosyn/lidm/configs/ped_psax.yaml:
--------------------------------------------------------------------------------
1 | wandb_group: "lidm"
2 | output_dir: experiments/lidm_ped_psax
3 |
4 | pretrained_model_name_or_path: null
5 | vae_path: models/vae
6 |
7 | globals:
8 | target_fps: "original"
9 | target_nframes: 32
10 | outputs: ["image"]
11 |
12 | datasets:
13 | - name: Latent
14 | active: true
15 | params:
16 | root: data/latents/ped_psax
17 | target_fps: ${globals.target_fps}
18 | target_nframes: ${globals.target_nframes}
19 | target_resolution: 14 # emb resolution
20 | outputs: ${globals.outputs}
21 |
22 | unet:
23 | _class_name: UNet2DModel
24 | sample_size: 14 # actual size is 16
25 | in_channels: 4
26 | out_channels: 4
27 | center_input_sample: false
28 | time_embedding_type: positional
29 | freq_shift: 0
30 | flip_sin_to_cos: true
31 | down_block_types:
32 | - AttnDownBlock2D
33 | - AttnDownBlock2D
34 | - AttnDownBlock2D
35 | - DownBlock2D
36 | up_block_types:
37 | - UpBlock2D
38 | - AttnUpBlock2D
39 | - AttnUpBlock2D
40 | - AttnUpBlock2D
41 | block_out_channels:
42 | - 128
43 | - 256
44 | - 256
45 | - 512
46 | layers_per_block: 2
47 | mid_block_scale_factor: 1
48 | downsample_padding: 1
49 | downsample_type: resnet
50 | upsample_type: resnet
51 | dropout: 0.0
52 | act_fn: silu
53 | attention_head_dim: 8
54 | norm_num_groups: 32
55 | attn_norm_num_groups: null
56 | norm_eps: 1e-05
57 | resnet_time_scale_shift: "default"
58 | class_embed_type: null
59 | num_class_embeds: null
60 |
61 | noise_scheduler:
62 | _class_name: DDPMScheduler
63 | num_train_timesteps: 1000
64 | beta_start: 0.0001
65 | beta_end: 0.02
66 | beta_schedule: linear # linear, scaled_linear, or squaredcos_cap_v2
67 | variance_type: fixed_small # fixed_small, fixed_small_log, fixed_large, fixed_large_log, learned or learned_range
68 | clip_sample: true
69 | clip_sample_range: 4.0 # default 1
70 | prediction_type: v_prediction # epsilon, sample, v_prediction
71 | thresholding: false # do not touch
72 | dynamic_thresholding_ratio: 0.995 # unused
73 | sample_max_value: 1.0 # unused
74 | timestep_spacing: "leading" #
75 | steps_offset: 0 # unused
76 |
77 | train_batch_size: 256
78 | dataloader_num_workers: 16
79 | max_train_steps: 500000
80 |
81 | learning_rate: 3e-4
82 | lr_warmup_steps: 500
83 | scale_lr: false
84 | lr_scheduler: constant
85 | use_8bit_adam: false
86 | gradient_accumulation_steps: 1
87 |
88 | noise_offset: 0.0
89 |
90 | gradient_checkpointing: false
91 | use_ema: true
92 | enable_xformers_memory_efficient_attention: false
93 | allow_tf32: true
94 |
95 | adam_beta1: 0.9
96 | adam_beta2: 0.999
97 | adam_weight_decay: 1e-2
98 | adam_epsilon: 1e-08
99 | max_grad_norm: 1.0
100 |
101 | logging_dir: logs
102 | mixed_precision: "fp16" # "no", "fp16", "bf16"
103 |
104 | validation_timesteps: 128
105 | validation_fps: ${globals.target_fps}
106 | validation_frames: ${globals.target_nframes}
107 | validation_count: 4 # defines the number of samples
108 | validation_guidance: 1.0
109 | validation_steps: 2500
110 |
111 | report_to: wandb
112 | checkpointing_steps: 10000 # ~3/hour
113 | checkpoints_total_limit: 100 # no limit
114 | resume_from_checkpoint: latest
115 | tracker_project_name: echosyn
116 |
117 | seed: 42
118 |
119 |
--------------------------------------------------------------------------------
/echosyn/lidm/sample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import math
4 | import os
5 | import shutil
6 | import json
7 | from glob import glob
8 | from einops import rearrange
9 | from omegaconf import OmegaConf
10 | import numpy as np
11 | from tqdm import tqdm
12 | from packaging import version
13 | from functools import partial
14 | from PIL import Image
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import torch.utils.checkpoint
20 | from torch.utils.data import DataLoader, Dataset
21 | from torchvision import transforms
22 |
23 | import diffusers
24 | from diffusers import AutoencoderKL, DDPMScheduler, UNet3DConditionModel, UNetSpatioTemporalConditionModel, DDIMScheduler
25 |
26 | from echosyn.common.datasets import instantiate_dataset
27 | from echosyn.common import (
28 | padf, unpadf,
29 | load_model
30 | )
31 |
32 | """
33 | python echosyn/lidm/sample.py \
34 | --config echosyn/lidm/configs/dynamic.yaml \
35 | --unet experiments/lidm_dynamic/checkpoint-500000/unet_ema \
36 | --vae models/vae \
37 | --output samples/dynamic \
38 | --num_samples 50000 \
39 | --batch_size 128 \
40 | --num_steps 64 \
41 | --save_latent \
42 | --seed 0
43 | """
44 |
45 | if __name__ == "__main__":
46 | # 1 - Parse command line arguments
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument("--config", type=str, default=None, help="Path to config file.")
49 | parser.add_argument("--unet", type=str, default=None, help="Path unet checkpoint.")
50 | parser.add_argument("--vae", type=str, default=None, help="Path vae checkpoint.")
51 | parser.add_argument("--output", type=str, default='.', help="Output directory.")
52 | parser.add_argument("--num_samples", type=int, default=8, help="Number of samples to generate.")
53 | parser.add_argument("--batch_size", type=int, default=8, help="Batch size.")
54 | parser.add_argument("--num_steps", type=int, default=128, help="Number of steps.")
55 | parser.add_argument("--seed", type=int, default=None, help="Random seed.")
56 | parser.add_argument("--save_latent", action="store_true", help="Save latents.")
57 | parser.add_argument("--ddim", action="store_true", help="Save video.")
58 | args = parser.parse_args()
59 |
60 | config = OmegaConf.load(args.config)
61 |
62 | # 2 - Load models
63 | unet = load_model(args.unet)
64 | vae = load_model(args.vae)
65 |
66 | # 3 - Load scheduler
67 | scheduler_kwargs = OmegaConf.to_container(config.noise_scheduler)
68 | scheduler_klass_name = scheduler_kwargs.pop("_class_name")
69 | if args.ddim:
70 | print("Using DDIMScheduler")
71 | scheduler_klass_name = "DDIMScheduler"
72 | scheduler_kwargs.pop("variance_type")
73 | scheduler_klass = getattr(diffusers, scheduler_klass_name, None)
74 | assert scheduler_klass is not None, f"Could not find scheduler class {scheduler_klass_name}"
75 | scheduler = scheduler_klass(**scheduler_kwargs)
76 | scheduler.set_timesteps(args.num_steps)
77 | timesteps = scheduler.timesteps
78 |
79 | # 5 - Setup
80 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81 | dtype = torch.float32
82 | generator = torch.Generator(device=device).manual_seed(config.seed) if config.seed is not None else None
83 | unet = unet.to(device, dtype)
84 | vae = vae.to(device, torch.float32)
85 | unet.eval()
86 | vae.eval()
87 |
88 | format_input = padf
89 | format_output = unpadf
90 |
91 | B, C, H, W = args.batch_size, config.unet.out_channels, config.unet.sample_size, config.unet.sample_size
92 |
93 | forward_kwargs = {
94 | "timestep": -1,
95 | }
96 |
97 | sample_index = 0
98 |
99 | os.makedirs(args.output, exist_ok=True)
100 | os.makedirs(os.path.join(args.output, "images"), exist_ok=True)
101 | if args.save_latent:
102 | os.makedirs(os.path.join(args.output, "latents"), exist_ok=True)
103 | finished = False
104 |
105 | # 6 - Generate samples
106 | with torch.no_grad():
107 | for _ in tqdm(range(int(np.ceil(args.num_samples/args.batch_size)))):
108 | if finished:
109 | break
110 |
111 | latents = torch.randn((B, C, H, W), device=device, dtype=dtype, generator=generator)
112 |
113 | with torch.autocast("cuda"):
114 | for t in timesteps:
115 | forward_kwargs["timestep"] = t
116 | latent_model_input = latents
117 | latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
118 | latent_model_input, padding = format_input(latent_model_input, mult=3)
119 | noise_pred = unet(latent_model_input, **forward_kwargs).sample
120 | noise_pred = format_output(noise_pred, pad=padding)
121 | latents = scheduler.step(noise_pred, t, latents).prev_sample
122 |
123 | if args.save_latent:
124 | latents_clean = latents.clone()
125 |
126 | # VAE decode
127 | rep = [1, 3, 1, 1]
128 | latents = latents / vae.config.scaling_factor
129 | images = vae.decode(latents.float()).sample
130 | images = (images + 1) * 128 # [-1, 1] -> [0, 256]
131 |
132 | # grayscale
133 | images = images.mean(1).unsqueeze(1).repeat(*rep)
134 |
135 | images = images.clamp(0, 255).to(torch.uint8).cpu()
136 | images = rearrange(images, 'b c h w -> b h w c')
137 |
138 | # 7 - Save samples
139 | images = images.numpy()
140 | for j in range(B):
141 |
142 | Image.fromarray(images[j]).save(os.path.join(args.output, "images", f"sample_{sample_index:06d}.jpg"))
143 | if args.save_latent:
144 | torch.save(latents_clean[j].clone(), os.path.join(args.output, "latents", f"sample_{sample_index:06d}.pt"))
145 |
146 | sample_index += 1
147 | if sample_index >= args.num_samples:
148 | finished = True
149 | break
150 |
151 | print(f"Finished generating {sample_index} samples.")
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
--------------------------------------------------------------------------------
/echosyn/lidm/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import math
4 | import os
5 | import shutil
6 | from einops import rearrange
7 | from omegaconf import OmegaConf
8 | import numpy as np
9 | from tqdm.auto import tqdm
10 | from packaging import version
11 | from functools import partial
12 | from copy import deepcopy
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.utils.checkpoint
18 | from torch.utils.data import DataLoader, Dataset
19 | from torchvision import transforms
20 |
21 | import accelerate
22 | from accelerate import Accelerator
23 | from accelerate.logging import get_logger
24 | from accelerate.state import AcceleratorState
25 | from accelerate.utils import ProjectConfiguration, set_seed
26 |
27 | import diffusers
28 | from diffusers import AutoencoderKL, DDPMScheduler, UNet3DConditionModel, UNetSpatioTemporalConditionModel
29 | from diffusers.optimization import get_scheduler
30 | from diffusers.training_utils import EMAModel, compute_snr
31 | from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
32 | from diffusers.utils.import_utils import is_xformers_available
33 |
34 | from echosyn.common.datasets import instantiate_dataset
35 | from echosyn.common import (
36 | padf, unpadf,
37 | instantiate_from_config
38 | )
39 |
40 | if is_wandb_available():
41 | import wandb
42 |
43 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
44 | check_min_version("0.22.0.dev0")
45 |
46 | logger = get_logger(__name__, log_level="INFO")
47 |
48 | def log_validation(
49 | config,
50 | unet,
51 | vae,
52 | scheduler,
53 | accelerator,
54 | weight_dtype,
55 | epoch,
56 | val_dataset
57 | ):
58 | logger.info("Running validation... ")
59 |
60 | val_unet = accelerator.unwrap_model(unet)
61 | val_vae = vae.to(accelerator.device, dtype=torch.float32)
62 | scheduler.set_timesteps(config.validation_timesteps)
63 | timesteps = scheduler.timesteps
64 |
65 | if config.enable_xformers_memory_efficient_attention:
66 | val_unet.enable_xformers_memory_efficient_attention()
67 |
68 | if config.seed is None:
69 | generator = None
70 | else:
71 | generator = torch.Generator(device=accelerator.device).manual_seed(config.seed)
72 |
73 | indices = np.random.choice(len(val_dataset), size=config.validation_count, replace=False)
74 | ref_elements = [val_dataset[i] for i in indices]
75 | ref_frames = [e['image'] for e in ref_elements]
76 | ref_frames = torch.stack(ref_frames, dim=0) # B x C x H x W
77 | ref_frames = ref_frames.to(accelerator.device, dtype=weight_dtype)
78 |
79 | format_input = padf
80 | format_output = unpadf
81 |
82 | logger.info("Sampling... ")
83 | with torch.no_grad(), torch.autocast("cuda"):
84 | # prepare model inputs
85 | B, C, H, W = config.validation_count, 4, config.unet.sample_size, config.unet.sample_size
86 | latents = torch.randn((B, C, H, W), device=accelerator.device, dtype=weight_dtype, generator=generator)
87 |
88 | forward_kwargs = {
89 | "timestep": timesteps,
90 | }
91 |
92 | # reverse diffusion loop
93 | for t in timesteps:
94 | latent_model_input = latents
95 | latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
96 | latent_model_input, padding = format_input(latent_model_input, mult=3)
97 | noise_pred = unet(latent_model_input, t).sample
98 | noise_pred = format_output(noise_pred, pad=padding)
99 | latents = scheduler.step(noise_pred, t, latents).prev_sample
100 |
101 | # VAE decoding
102 | with torch.no_grad(): # no autocast
103 | latents = latents / val_vae.config.scaling_factor
104 | images = val_vae.decode(latents.float()).sample
105 | images = (images + 1) * 128 # [-1, 1] -> [0, 256]
106 | images = images.clamp(0, 255).to(torch.uint8).cpu()
107 |
108 | ref_frames = ref_frames / val_vae.config.scaling_factor
109 | ref_frames = val_vae.decode(ref_frames.float()).sample
110 | ref_frames = (ref_frames + 1) * 128 # [-1, 1] -> [0, 256]
111 | ref_frames = ref_frames.clamp(0, 255).to(torch.uint8).cpu()
112 |
113 | images = torch.cat([ref_frames, images], dim=2) # B x C x (2 H) x W // vertical concat
114 |
115 | # reshape for wandb
116 | images = rearrange(images, "b c h w -> h (b w) c") # prepare for wandb
117 | images = images.numpy()
118 |
119 | logger.info("Done sampling... ")
120 |
121 | for tracker in accelerator.trackers:
122 | if tracker.name == "wandb":
123 | tracker.log({"validation": wandb.Image(images)})
124 | logger.info("Samples sent to wandb.")
125 | else:
126 | logger.warning(f"image logging not implemented for {tracker.name}")
127 |
128 | del val_unet
129 | del val_vae
130 | torch.cuda.empty_cache()
131 |
132 | return images
133 |
134 | def parse_args():
135 | parser = argparse.ArgumentParser(description="")
136 | parser.add_argument("--config", type=str, help="Path to the config file.")
137 | args = parser.parse_args()
138 | return args
139 |
140 | def main():
141 | args = parse_args()
142 | config = OmegaConf.load(args.config)
143 |
144 | # Setup accelerator
145 | logging_dir = os.path.join(config.output_dir, config.logging_dir)
146 |
147 | accelerator_project_config = ProjectConfiguration(project_dir=config.output_dir, logging_dir=logging_dir)
148 |
149 | accelerator = Accelerator(
150 | gradient_accumulation_steps=config.gradient_accumulation_steps,
151 | mixed_precision=config.mixed_precision,
152 | log_with=config.report_to,
153 | project_config=accelerator_project_config,
154 | )
155 |
156 | # Make one log on every process with the configuration for debugging.
157 | logging.basicConfig(
158 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
159 | datefmt="%m/%d/%Y %H:%M:%S",
160 | level=logging.INFO,
161 | )
162 | logger.info(accelerator.state, main_process_only=False)
163 | if accelerator.is_local_main_process:
164 | diffusers.utils.logging.set_verbosity_info()
165 | else:
166 | diffusers.utils.logging.set_verbosity_error()
167 |
168 | # If passed along, set the training seed now.
169 | if config.seed is not None:
170 | set_seed(config.seed)
171 |
172 | # Handle the repository creation
173 | if accelerator.is_main_process:
174 | if config.output_dir is not None:
175 | os.makedirs(config.output_dir, exist_ok=True)
176 |
177 | noise_scheduler_kwargs = OmegaConf.to_container(config.noise_scheduler, resolve=True)
178 | noise_scheduler_klass_name = noise_scheduler_kwargs.pop("_class_name")
179 | noise_scheduler_klass = globals().get(noise_scheduler_klass_name, None)
180 | assert noise_scheduler_klass is not None, f"Could not find class {noise_scheduler_klass_name}"
181 | noise_scheduler = noise_scheduler_klass(**noise_scheduler_kwargs)
182 |
183 | vae = AutoencoderKL.from_pretrained(config.vae_path).cpu()
184 |
185 | # Create the video unet
186 | unet, unet_klass, unet_kwargs = instantiate_from_config(config.unet, ["diffusers"], return_klass_kwargs=True)
187 | noise_scheduler = instantiate_from_config(config.noise_scheduler, ["diffusers"])
188 |
189 | format_input = padf
190 | format_output = unpadf
191 |
192 | # Freeze vae and text_encoder and set unet to trainable
193 | vae.requires_grad_(False)
194 | unet.train()
195 |
196 | # Create EMA for the unet.
197 | if config.use_ema:
198 | ema_unet = unet_klass(**unet_kwargs)
199 | ema_unet = EMAModel(ema_unet.parameters(), model_cls=unet_klass, model_config=ema_unet.config)
200 |
201 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
202 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
203 | def save_model_hook(models, weights, output_dir):
204 | if accelerator.is_main_process:
205 | if config.use_ema:
206 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
207 |
208 | for i, model in enumerate(models):
209 | model.save_pretrained(os.path.join(output_dir, "unet"))
210 |
211 | # make sure to pop weight so that corresponding model is not saved again
212 | weights.pop()
213 |
214 | def load_model_hook(models, input_dir):
215 | if config.use_ema:
216 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), unet_klass)
217 | ema_unet.load_state_dict(load_model.state_dict())
218 | ema_unet.to(accelerator.device)
219 | del load_model
220 |
221 | for i in range(len(models)):
222 | # pop models so that they are not loaded again
223 | model = models.pop()
224 |
225 | # load diffusers style into model
226 | load_model = unet_klass.from_pretrained(input_dir, subfolder="unet")
227 | model.register_to_config(**load_model.config)
228 |
229 | model.load_state_dict(load_model.state_dict())
230 | del load_model
231 |
232 | accelerator.register_save_state_pre_hook(save_model_hook)
233 | accelerator.register_load_state_pre_hook(load_model_hook)
234 |
235 | if config.gradient_checkpointing:
236 | unet.enable_gradient_checkpointing()
237 |
238 | # Enable TF32 for faster training on Ampere GPUs,
239 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
240 | if config.allow_tf32:
241 | torch.backends.cuda.matmul.allow_tf32 = True
242 |
243 | optimizer = torch.optim.AdamW(
244 | unet.parameters(),
245 | lr=config.learning_rate,
246 | betas=(config.adam_beta1, config.adam_beta2),
247 | weight_decay=config.adam_weight_decay,
248 | eps=config.adam_epsilon,
249 | )
250 |
251 | train_dataset = instantiate_dataset(config.datasets, split=["TRAIN"])
252 | val_dataset = instantiate_dataset(config.datasets, split=["VAL"])
253 |
254 | # DataLoaders creation:
255 | train_dataloader = torch.utils.data.DataLoader(
256 | train_dataset,
257 | shuffle=True,
258 | batch_size=config.train_batch_size,
259 | num_workers=config.dataloader_num_workers,
260 | )
261 |
262 | # Scheduler and math around the number of training steps.
263 | overrode_max_train_steps = False
264 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
265 | if config.max_train_steps is None:
266 | config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch
267 | overrode_max_train_steps = True
268 |
269 | lr_scheduler = get_scheduler(
270 | config.lr_scheduler,
271 | optimizer=optimizer,
272 | num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
273 | num_training_steps=config.max_train_steps * accelerator.num_processes,
274 | )
275 |
276 | # Prepare everything with our `accelerator`.
277 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
278 | unet, optimizer, train_dataloader, lr_scheduler
279 | )
280 |
281 | if config.use_ema:
282 | ema_unet.to(accelerator.device)
283 |
284 | # For mixed precision training we cast all non-trainable weights (here, the frozen VAE) to half-precision
285 | # as these weights are only used for inference, keeping weights in full precision is not required.
286 | weight_dtype = torch.float32
287 | if accelerator.mixed_precision == "fp16":
288 | weight_dtype = torch.float16
289 | config.mixed_precision = accelerator.mixed_precision
290 | elif accelerator.mixed_precision == "bf16":
291 | weight_dtype = torch.bfloat16
292 | config.mixed_precision = accelerator.mixed_precision
293 |
294 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
295 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
296 | if overrode_max_train_steps:
297 | config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch
298 | # Afterwards we recalculate our number of training epochs
299 | config.num_train_epochs = math.ceil(config.max_train_steps / num_update_steps_per_epoch)
300 |
301 | # We need to initialize the trackers we use, and also store our configuration.
302 | # The trackers initializes automatically on the main process.
303 | if accelerator.is_main_process:
304 | tracker_config = OmegaConf.to_container(config, resolve=True)
305 | accelerator.init_trackers(
306 | config.tracker_project_name,
307 | tracker_config,
308 | init_kwargs={
309 | "wandb": {
310 | "group": config.wandb_group
311 | },
312 | },
313 | )
314 |
315 | # Train!
316 | total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
317 | model_num_params = sum(p.numel() for p in unet.parameters())
318 | model_trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
319 |
320 | logger.info("***** Running training *****")
321 | logger.info(f" Num examples = {len(train_dataset)}")
322 | logger.info(f" Num Epochs = {config.num_train_epochs}")
323 | logger.info(f" Instantaneous batch size per device = {config.train_batch_size}")
324 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
325 | logger.info(f" Gradient Accumulation steps = {config.gradient_accumulation_steps}")
326 | logger.info(f" Total optimization steps = {config.max_train_steps}")
327 | logger.info(f" U-Net: Total params = {model_num_params} \t Trainable params = {model_trainable_params} ({model_trainable_params/model_num_params*100:.2f}%)")
328 | global_step = 0
329 | first_epoch = 0
330 |
331 | # Potentially load in the weights and states from a previous save
332 | if config.resume_from_checkpoint:
333 | if config.resume_from_checkpoint != "latest":
334 | path = os.path.basename(config.resume_from_checkpoint)
335 | else:
336 | # Get the most recent checkpoint
337 | dirs = os.listdir(config.output_dir)
338 | dirs = [d for d in dirs if d.startswith("checkpoint")]
339 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
340 | path = dirs[-1] if len(dirs) > 0 else None
341 |
342 | if path is None:
343 | accelerator.print(
344 | f"Checkpoint '{config.resume_from_checkpoint}' does not exist. Starting a new training run."
345 | )
346 | config.resume_from_checkpoint = None
347 | initial_global_step = 0
348 | else:
349 | accelerator.print(f"Resuming from checkpoint {path}")
350 | accelerator.load_state(os.path.join(config.output_dir, path))
351 | global_step = int(path.split("-")[1])
352 |
353 | initial_global_step = global_step
354 | first_epoch = global_step // num_update_steps_per_epoch
355 |
356 | else:
357 | initial_global_step = 0
358 |
359 | progress_bar = tqdm(
360 | range(0, config.max_train_steps),
361 | initial=initial_global_step,
362 | desc="Steps",
363 | # Only show the progress bar once on each machine.
364 | # disable=not accelerator.is_local_main_process,
365 | disable=not accelerator.is_main_process,
366 | )
367 |
368 | for epoch in range(first_epoch, config.num_train_epochs):
369 | train_loss = 0.0
370 | prediction_mean = 0.0
371 | prediction_std = 0.0
372 | target_mean = 0.0
373 | target_std = 0.0
374 | mean_losses = 0.0
375 | for step, batch in enumerate(train_dataloader):
376 | with accelerator.accumulate(unet):
377 |
378 | latents = batch['image'] # B x C x H x W
379 |
380 | B, C, H, W = latents.shape
381 |
382 | # Sample a random timestep for each video
383 | timesteps = torch.randint(0, int(noise_scheduler.config.num_train_timesteps), (B,), device=latents.device).long()
384 |
385 | # Sample noise that we'll add to the latents
386 | noise = torch.randn_like(latents)
387 | if config.noise_offset > 0:
388 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise
389 | noise += config.noise_offset * torch.randn( (latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
390 |
391 | if config.get('input_perturbation', 0) > 0.0:
392 | noisy_latents = noise_scheduler.add_noise(latents, noise + config.input_perturbation*torch.rand(1,).item() * torch.randn_like(noise), timesteps)
393 | else:
394 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
395 |
396 | # Predict the noise residual and compute loss
397 | noisy_latents, padding = format_input(noisy_latents, mult=3)
398 | model_pred = unet(sample=noisy_latents, timestep=timesteps).sample
399 | model_pred = format_output(model_pred, pad=padding)
400 |
401 | if noise_scheduler.config.prediction_type == "epsilon":
402 | target = noise
403 | elif noise_scheduler.config.prediction_type == "v_prediction":
404 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
405 | else:
406 | assert noise_scheduler.config.prediction_type == "sample", f"Unknown prediction type {noise_scheduler.config.prediction_type}"
407 | target = latents
408 |
409 | # loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
410 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
411 | loss = loss.mean()
412 | mean_loss = loss.item()
413 |
414 | # Gather the losses across all processes for logging (if we use distributed training).
415 | avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
416 | train_loss += avg_loss.item() / config.gradient_accumulation_steps
417 | mean_losses += mean_loss / config.gradient_accumulation_steps
418 | prediction_mean += model_pred.mean().item() / config.gradient_accumulation_steps
419 | prediction_std += model_pred.std().item() / config.gradient_accumulation_steps
420 | target_mean += target.mean().item() / config.gradient_accumulation_steps
421 | target_std += target.std().item() / config.gradient_accumulation_steps
422 |
423 | # Backpropagate
424 | accelerator.backward(loss)
425 | if accelerator.sync_gradients:
426 | accelerator.clip_grad_norm_(unet.parameters(), config.max_grad_norm)
427 | optimizer.step()
428 | lr_scheduler.step()
429 | optimizer.zero_grad()
430 |
431 | # Checks if the accelerator has performed an optimization step behind the scenes
432 | if accelerator.sync_gradients:
433 | if config.use_ema:
434 | ema_unet.step(unet.parameters())
435 | progress_bar.update(1)
436 | global_step += 1
437 | accelerator.log({
438 | "train_loss": train_loss,
439 | "prediction_mean": prediction_mean,
440 | "prediction_std": prediction_std,
441 | "target_mean": target_mean,
442 | "target_std": target_std,
443 | "mean_losses": mean_losses,
444 | }, step=global_step)
445 | train_loss = 0.0
446 | prediction_mean = 0.0
447 | prediction_std = 0.0
448 | target_mean = 0.0
449 | target_std = 0.0
450 | mean_losses = 0.0
451 |
452 | if global_step % config.checkpointing_steps == 0:
453 | if accelerator.is_main_process:
454 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
455 | if config.checkpoints_total_limit is not None:
456 | checkpoints = os.listdir(config.output_dir)
457 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
458 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
459 |
460 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
461 | if len(checkpoints) >= config.checkpoints_total_limit:
462 | num_to_remove = len(checkpoints) - config.checkpoints_total_limit + 1
463 | removing_checkpoints = checkpoints[0:num_to_remove]
464 |
465 | logger.info(
466 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
467 | )
468 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
469 |
470 | for removing_checkpoint in removing_checkpoints:
471 | removing_checkpoint = os.path.join(config.output_dir, removing_checkpoint)
472 | shutil.rmtree(removing_checkpoint)
473 |
474 | save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
475 | accelerator.save_state(save_path)
476 | OmegaConf.save(config, os.path.join(config.output_dir, "config.yaml"))
477 | logger.info(f"Saved state to {save_path}")
478 |
479 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
480 | progress_bar.set_postfix(**logs)
481 |
482 | if global_step >= config.max_train_steps:
483 | break
484 |
485 | if accelerator.is_main_process:
486 | if global_step % config.validation_steps == 0:
487 | if config.use_ema:
488 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
489 | ema_unet.store(unet.parameters())
490 | ema_unet.copy_to(unet.parameters())
491 |
492 | log_validation(
493 | config,
494 | unet,
495 | vae,
496 | deepcopy(noise_scheduler),
497 | accelerator,
498 | weight_dtype,
499 | epoch,
500 | val_dataset,
501 | )
502 |
503 | if config.use_ema:
504 | # Switch back to the original UNet parameters.
505 | ema_unet.restore(unet.parameters())
506 |
507 | accelerator.end_training()
508 |
509 |
510 | if __name__ == "__main__":
511 | main()
--------------------------------------------------------------------------------
/echosyn/lvdm/README.md:
--------------------------------------------------------------------------------
1 | # Latent Video Diffusion Model
2 |
3 | The Latent Video Diffusion Model (LVDM) is responsible for animating the latent representation of a heart generated by the Latent Image Diffusion Model (LIDM). The LVDM is trained on the VAE-encoded videos. We condition it on an encoded frame and an ejection fraction score, and train it to reconstruct the video corresponding to that frame and ejection fraction.
4 |
5 | During inference, it can animate any heart, real or synthetic, by conditioning on the latent representation of the heart and the desired ejection fraction.
6 |
7 | ## 1. Activate the environment
8 |
9 | First, activate the echosyn environment.
10 |
11 | ```bash
12 | conda activate echosyn
13 | ```
14 |
15 | ## 2. Data preparation
16 | Follow the instructions in the [Data preparation](../../README.md#data-preparation) section to prepare the data for training. Here, you need the VAE-encoded videos.
17 |
18 | ## 3. Train the LVDM
19 |
20 | Once the environment is set up and the data is ready, you can train the LVDM with the following command:
21 |
22 | ```bash
23 | python echosyn/lvdm/train.py --config echosyn/lvdm/configs/default.yaml
24 | ```
25 |
26 | or this one, for multi-gpu training:
27 |
28 | ```bash
29 | accelerate launch \
30 | --num_processes 8 \
31 | --multi_gpu \
32 | --mixed_precision fp16 \
33 | echosyn/lvdm/train.py \
34 | --config echosyn/lvdm/configs/default.yaml
35 | ```
36 |
37 | Note that we train a single LVDM for all datasets. Because the model is conditioned on an encoded frame, the same model can generate videos for any dataset.
38 |
39 | ## 4. Sample from the LVDM
40 |
41 | Once the LVDM is trained, you can sample from it with the following command:
42 |
43 | ```bash
44 | python echosyn/lvdm/sample.py \
45 | --config echosyn/lvdm/configs/default.yaml \
46 | --unet experiments/lvdm/checkpoint-500000/unet_ema \
47 | --vae models/vae \
48 | --conditioning samples/lidm_dynamic/privacy_compliant_latents \
49 | --output samples/lvdm_dynamic \
50 | --num_samples 2048 \
51 | --batch_size 8 \
52 | --num_steps 64 \
53 | --min_lvef 10 \
54 | --max_lvef 90 \
55 | --save_as avi,jpg \
56 | --frames 192
57 | ```
58 |
59 | This will generate 2048 videos of 192 frames each, conditioned on the synthetic, privacy-compliant latent representations of the heart and on uniformly sampled ejection fraction scores. The videos are saved in the `samples/lvdm_dynamic/avi` directory and the corresponding frames in `samples/lvdm_dynamic/jpg`.
60 |
61 | ## 5. Evaluate the LVDM
62 |
63 | To evaluate the LVDM, we use the FID, FVD16, FVD128 and IS scores.
64 | To do so, we need to generate 2048 videos with 192 frames for the dynamic dataset or 128 frames for the two pediatric datasets.
65 | The outputs MUST be in jpg format.
66 | The samples are compared to the real samples, which are generated in the [Data preparation](../../README.md#data-preparation) step.
67 |
68 | Then, to evaluate the synthetic videos, run the following commands:
69 |
70 | ```bash
71 | cd external/stylegan-v
72 |
73 | python src/scripts/calc_metrics_for_dataset.py \
74 | --real_data_path ../../data/reference/dynamic \
75 | --fake_data_path ../../samples/lvdm_dynamic/jpg \
76 | --mirror 0 --gpus 1 --resolution 112 \
77 | --metrics fvd2048_16f,fvd2048_128f,fid50k_full,is50k >> "../../samples/lvdm_dynamic/metrics.txt"
78 |
79 | python src/scripts/calc_metrics_for_dataset.py \
80 | --real_data_path ../../data/reference/ped_a4c \
81 | --fake_data_path ../../samples/lvdm_ped_a4c/jpg \
82 | --mirror 0 --gpus 1 --resolution 112 \
83 | --metrics fvd2048_16f,fvd2048_128f,fid50k_full,is50k >> "../../samples/lvdm_ped_a4c/metrics.txt"
84 |
85 | python src/scripts/calc_metrics_for_dataset.py \
86 | --real_data_path ../../data/reference/ped_psax \
87 | --fake_data_path ../../samples/lvdm_ped_psax/jpg \
88 | --mirror 0 --gpus 1 --resolution 112 \
89 | --metrics fvd2048_16f,fvd2048_128f,fid50k_full,is50k >> "../../samples/lvdm_ped_psax/metrics.txt"
90 | ```
91 |
92 | ## 6. Save the LVDM for later use
93 |
94 | Once the LVDM is trained, you can save it for later use with the following command:
95 |
96 | ```bash
97 | mkdir -p models/lvdm; cp -r experiments/lvdm/checkpoint-500000/unet_ema/* models/lvdm/; cp experiments/lvdm/config.yaml models/lvdm/
98 | ```
99 |
100 | This will save the selected ema version of the model, ready to be loaded in any other script as a standalone model.
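For example, a minimal loading sketch, assuming the folder keeps the standard diffusers layout copied above (the repository's own scripts load checkpoints through `echosyn.common.load_model`):

```python
from diffusers import AutoencoderKL, UNetSpatioTemporalConditionModel

unet = UNetSpatioTemporalConditionModel.from_pretrained("models/lvdm")
vae = AutoencoderKL.from_pretrained("models/vae")
unet.eval()
vae.eval()
```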
--------------------------------------------------------------------------------
/echosyn/lvdm/configs/default.yaml:
--------------------------------------------------------------------------------
1 | wandb_group: "lvdm"
2 | output_dir: experiments/lvdm # experiments/lvdm_vpred_64
3 |
4 | pretrained_model_name_or_path: null
5 | vae_path: models/vae
6 |
7 | globals:
8 | target_fps: 32
9 | target_nframes: 64
10 | outputs: ["video", "lvef", "image"]
11 |
12 |
13 | datasets:
14 | - name: Latent
15 | active: true
16 | params:
17 | root: data/latents/dynamic
18 | target_fps: ${globals.target_fps}
19 | target_nframes: ${globals.target_nframes}
20 | target_resolution: 14 # emb resolution
21 | outputs: ${globals.outputs}
22 |
23 | - name: Latent
24 | active: true
25 | params:
26 | root: data/latents/ped_a4c
27 | target_fps: ${globals.target_fps}
28 | target_nframes: ${globals.target_nframes}
29 | target_resolution: 14 # emb resolution
30 | outputs: ${globals.outputs}
31 |
32 | - name: Latent
33 | active: true
34 | params:
35 | root: data/latents/ped_psax
36 | target_fps: ${globals.target_fps}
37 | target_nframes: ${globals.target_nframes}
38 | target_resolution: 14 # emb resolution
39 | outputs: ${globals.outputs}
40 |
41 | unet:
42 | _class_name: UNetSpatioTemporalConditionModel
43 | addition_time_embed_dim: 1
44 | block_out_channels:
45 | - 128
46 | - 256
47 | - 256
48 | - 512
49 | cross_attention_dim: 1
50 | down_block_types:
51 | - CrossAttnDownBlockSpatioTemporal
52 | - CrossAttnDownBlockSpatioTemporal
53 | - CrossAttnDownBlockSpatioTemporal
54 | - DownBlockSpatioTemporal
55 | in_channels: 8
56 | layers_per_block: 2
57 | num_attention_heads:
58 | - 8
59 | - 16
60 | - 16
61 | - 32
62 | num_frames: ${globals.target_nframes}
63 | out_channels: 4
64 | projection_class_embeddings_input_dim: 1
65 | sample_size: 14
66 | transformer_layers_per_block: 1
67 | up_block_types:
68 | - UpBlockSpatioTemporal
69 | - CrossAttnUpBlockSpatioTemporal
70 | - CrossAttnUpBlockSpatioTemporal
71 | - CrossAttnUpBlockSpatioTemporal
72 |
73 | noise_scheduler:
74 | _class_name: DDPMScheduler
75 | num_train_timesteps: 1000
76 | beta_start: 0.0001
77 | beta_end: 0.02
78 | beta_schedule: linear # linear, scaled_linear, or squaredcos_cap_v2
79 | variance_type: fixed_small # fixed_small, fixed_small_log, fixed_large, fixed_large_log, learned or learned_range
80 | clip_sample: true
81 | clip_sample_range: 4.0 # default 1
82 | prediction_type: v_prediction # epsilon, sample, v_prediction
83 | thresholding: false # do not touch
84 | dynamic_thresholding_ratio: 0.995 # unused
85 | sample_max_value: 1.0 # unused
86 | timestep_spacing: "leading" #
87 | steps_offset: 0 # unused
88 |
89 | train_batch_size: 16
90 | dataloader_num_workers: 16
91 | max_train_steps: 500000
92 |
93 | learning_rate: 1e-4
94 | lr_warmup_steps: 500
95 | scale_lr: false
96 | lr_scheduler: constant
97 | use_8bit_adam: false
98 | gradient_accumulation_steps: 1
99 |
100 | noise_offset: 0.1
101 | drop_conditionning: 0.1 # 10% of the time, the LVEF conditioning is dropped
102 |
103 | gradient_checkpointing: false
104 | use_ema: true
105 | enable_xformers_memory_efficient_attention: false
106 | allow_tf32: true
107 |
108 | adam_beta1: 0.9
109 | adam_beta2: 0.999
110 | adam_weight_decay: 1e-2
111 | adam_epsilon: 1e-08
112 | max_grad_norm: 1.0
113 |
114 | logging_dir: logs
115 | mixed_precision: "fp16" # "no", "fp16", "bf16"
116 |
117 | validation_timesteps: 128
118 | validation_fps: ${globals.target_fps}
119 | validation_frames: ${globals.target_nframes}
120 | validation_lvefs: [0.0, 0.4, 0.7, 1.0] # defines the number of samples
121 | validation_guidance: 1.0
122 | validation_steps: 1500
123 |
124 | report_to: wandb
125 | checkpointing_steps: 10000 # ~3/hour
126 | checkpoints_total_limit: 100 # no limit
127 | resume_from_checkpoint: latest
128 | tracker_project_name: echosyn
129 |
130 | seed: 42
131 |
132 |
--------------------------------------------------------------------------------
/echosyn/lvdm/sample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import math
4 | import os
5 | import shutil
6 | import json
7 | from glob import glob
8 | from einops import rearrange
9 | from omegaconf import OmegaConf
10 | import numpy as np
11 | from tqdm import tqdm
12 | from packaging import version
13 | from functools import partial
14 | from PIL import Image
15 | import pandas as pd
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import torch.utils.checkpoint
21 | from torch.utils.data import DataLoader, Dataset
22 | from torchvision import transforms
23 |
24 | import diffusers
25 | from diffusers import AutoencoderKL, DDPMScheduler, UNet3DConditionModel, UNetSpatioTemporalConditionModel
26 |
27 | from echosyn.common.datasets import TensorSet, ImageSet
28 | from echosyn.common import (
29 | pad_reshape, unpad_reshape, padf, unpadf,
30 | load_model, save_as_mp4, save_as_gif, save_as_img, save_as_avi,
31 | parse_formats,
32 | )
33 |
34 | """
35 | python echosyn/lvdm/sample.py \
36 | --config echosyn/lvdm/configs/default.yaml \
37 | --unet experiments/lvdm/checkpoint-500000/unet_ema \
38 | --vae models/vae \
39 | --conditioning samples/dynamic/privacy_compliant_latents \
40 | --output samples/dynamic/privacy_compliant_samples \
41 | --num_samples 2048 \
42 | --batch_size 8 \
43 | --num_steps 64 \
44 | --min_lvef 10 \
45 | --max_lvef 90 \
46 | --save_as avi \
47 | --frames 192
48 | """
49 |
50 | if __name__ == "__main__":
51 | # 1 - Parse command line arguments
52 | parser = argparse.ArgumentParser()
53 | parser.add_argument("--config", type=str, default=None, help="Path to config file.")
54 | parser.add_argument("--unet", type=str, default=None, help="Path unet checkpoint.")
55 | parser.add_argument("--vae", type=str, default=None, help="Path vae checkpoint.")
56 | parser.add_argument("--conditioning", type=str, default=None, help="Path to the folder containing the conditionning latents.")
57 | parser.add_argument("--output", type=str, default='.', help="Output directory.")
58 | parser.add_argument("--num_samples", type=int, default=8, help="Number of samples to generate.")
59 | parser.add_argument("--batch_size", type=int, default=8, help="Batch size.")
60 | parser.add_argument("--num_steps", type=int, default=64, help="Number of steps.")
61 | parser.add_argument("--min_lvef", type=int, default=10, help="Minimum LVEF.")
62 | parser.add_argument("--max_lvef", type=int, default=90, help="Maximum LVEF.")
63 | parser.add_argument("--save_as", type=parse_formats, default=None, help="Save formats separated by commas (e.g., avi,jpg). Available: avi, mp4, gif, jpg, png, pt")
64 | parser.add_argument("--frames", type=int, default=192, help="Number of frames to generate. Must be a multiple of 32")
65 | parser.add_argument("--seed", type=int, default=None, help="Random seed.")
66 |
67 | args = parser.parse_args()
68 |
69 | config = OmegaConf.load(args.config)
70 |
71 | # 2 - Load models
72 | unet = load_model(args.unet)
73 | vae = load_model(args.vae)
74 |
75 | # 3 - Load scheduler
76 | scheduler_kwargs = OmegaConf.to_container(config.noise_scheduler)
77 | scheduler_klass_name = scheduler_kwargs.pop("_class_name")
78 | scheduler_klass = getattr(diffusers, scheduler_klass_name, None)
79 | assert scheduler_klass is not None, f"Could not find scheduler class {scheduler_klass_name}"
80 | scheduler = scheduler_klass(**scheduler_kwargs)
81 | scheduler.set_timesteps(args.num_steps)
82 | timesteps = scheduler.timesteps
83 |
84 | # 4 - Load dataset
85 | ## detect type of conditioning:
86 | file_ext = os.listdir(args.conditioning)[0].split(".")[-1].lower()
87 | assert file_ext in ["pt", "jpg", "png"], f"Conditioning files must be either .pt, .jpg or .png, not {file_ext}"
88 | if file_ext == "pt":
89 | dataset = TensorSet(args.conditioning)
90 | else:
91 | dataset = ImageSet(args.conditioning, ext=file_ext)
92 | assert len(dataset) > 0, f"No files found in {args.conditioning} with extension {file_ext}"
93 |
94 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, drop_last=True)
95 |
96 | # 5 - Setup
97 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
98 | dtype = torch.float32
99 | generator = torch.Generator(device=device).manual_seed(config.seed) if config.seed is not None else None
100 | unet = unet.to(device, dtype)
101 | vae = vae.to(device, torch.float32)
102 | unet.eval()
103 | vae.eval()
104 |
105 | format_input = pad_reshape if config.unet._class_name == "UNetSpatioTemporalConditionModel" else padf
106 | format_output = unpad_reshape if config.unet._class_name == "UNetSpatioTemporalConditionModel" else unpadf
107 |
108 | B, C, T, H, W = args.batch_size, config.unet.out_channels, config.unet.num_frames, config.unet.sample_size, config.unet.sample_size
109 | fps = config.globals.target_fps
110 | # Stitching parameters
111 | args.frames = int(np.ceil(args.frames/32) * 32)
112 | if args.frames > T:
113 | OT = T//2 # overlap 64//2
114 | TR = (args.frames - T) / 32 # total frames (192 - 64) / 32 = 4
115 | TR = int(TR + 1) # total repetitions
116 | NT = (T-OT) * TR + OT # = args.frame
117 | else:
118 | OT = 0
119 | TR = 1
120 | NT = T
121 |
122 | forward_kwargs = {
123 | "timestep": -1,
124 | }
125 |
126 | if config.unet._class_name == "UNetSpatioTemporalConditionModel":
127 | dummy_added_time_ids = torch.zeros((B*TR, config.unet.addition_time_embed_dim), device=device, dtype=dtype)
128 | forward_kwargs["added_time_ids"] = dummy_added_time_ids
129 |
130 | sample_index = 0
131 |
132 | filelist = []
133 |
134 | os.makedirs(args.output, exist_ok=True)
135 | for ext in args.save_as:
136 | os.makedirs(os.path.join(args.output, ext), exist_ok=True)
137 | finished = False
138 |
139 | pbar = tqdm(total=args.num_samples)
140 |
141 | # 6 - Generate samples
142 | with torch.no_grad():
143 | while not finished:
144 | for cond in dataloader:
145 | if finished:
146 | break
147 |
148 | # Prepare latent noise
149 | latents = torch.randn((B, C, NT, H, W), device=device, dtype=dtype, generator=generator)
150 |
151 | # Prepare conditioning - lvef
152 | lvefs = torch.randint(args.min_lvef, args.max_lvef+1, (B,), device=device, dtype=dtype, generator=generator)
153 | lvefs = lvefs / 100.0
154 | lvefs = lvefs[:, None, None]
155 | lvefs = lvefs.repeat_interleave(TR, dim=0)
156 | forward_kwargs["encoder_hidden_states"] = lvefs
157 | # Prepare conditioning - reference frames
158 | latent_cond_images = cond.to(device, torch.float32)
159 | if file_ext != "pt":
160 | # project image to latent space
161 | latent_cond_images = vae.encode(latent_cond_images).latent_dist.sample()
162 | latent_cond_images = latent_cond_images * vae.config.scaling_factor
163 | latent_cond_images = latent_cond_images[:,:,None,:,:].repeat(1,1,NT,1,1) # B x C x T x H x W
164 |
165 | # Denoise the latent
166 | with torch.autocast("cuda"):
167 | for t in timesteps:
168 | forward_kwargs["timestep"] = t
169 | latent_model_input = scheduler.scale_model_input(latents, timestep=t)
170 | latent_model_input = torch.cat((latent_model_input, latent_cond_images), dim=1) # B x 2C x T x H x W
171 | latent_model_input, padding = format_input(latent_model_input, mult=3) # B x T x 2C x H+P x W+P
172 |
173 | # Stitching
174 | inputs = torch.cat([latent_model_input[:,r*(T-OT):r*(T-OT)+T] for r in range(TR)], dim=0) # B*TR x T x 2C x H+P x W+P
175 | noise_pred = unet(inputs, **forward_kwargs).sample
176 | outputs = torch.chunk(noise_pred, TR, dim=0) # TR x B x T x C x H x W
177 | noise_predictions = []
178 | for r in range(TR):
179 | noise_predictions.append(outputs[r] if r == 0 else outputs[r][:,OT:])
180 | noise_pred = torch.cat(noise_predictions, dim=1) # B x NT x C x H x W
181 |
182 | noise_pred = format_output(noise_pred, pad=padding)  # use the output formatter selected above for this UNet class
183 | latents = scheduler.step(noise_pred, t, latents).prev_sample
184 |
185 | # VAE decode
186 | latents = rearrange(latents, "b c t h w -> (b t) c h w").cpu()
187 | latents = latents / vae.config.scaling_factor
188 |
189 | # Decode in chunks to save memory
190 | chunked_latents = torch.split(latents, args.batch_size, dim=0)
191 | decoded_chunks = []
192 | for chunk in chunked_latents:
193 | decoded_chunks.append(vae.decode(chunk.float().cuda()).sample.cpu())
194 | video = torch.cat(decoded_chunks, dim=0) # (B*T) x C x H x W
195 |
196 | # format output
197 | video = rearrange(video, "(b t) c h w -> b t h w c", b=B)
198 | video = (video + 1) * 128
199 | video = video.clamp(0, 255).to(torch.uint8)
200 |
201 | print(video.shape, video.dtype, video.min(), video.max())
202 | file_lvefs = lvefs.squeeze()[::TR].mul(100).to(torch.int).tolist()
203 | # save samples
204 | for j in range(B):
205 | # FileName,EF,ESV,EDV,FrameHeight,FrameWidth,FPS,NumberOfFrames,Split
206 | filelist.append([f"sample_{sample_index:06d}", file_lvefs[j], 0, 0, video.shape[2], video.shape[3], fps, video.shape[1], "TRAIN"])  # video is B x T x H x W x C
207 | if "mp4" in args.save_as:
208 | save_as_mp4(video[j], os.path.join(args.output, "mp4", f"sample_{sample_index:06d}.mp4"))
209 | if "avi" in args.save_as:
210 | save_as_avi(video[j], os.path.join(args.output, "avi", f"sample_{sample_index:06d}.avi"))
211 | if "gif" in args.save_as:
212 | save_as_gif(video[j], os.path.join(args.output, "gif", f"sample_{sample_index:06d}.gif"))
213 | if "jpg" in args.save_as:
214 | save_as_img(video[j], os.path.join(args.output, "jpg", f"sample_{sample_index:06d}"), ext="jpg")
215 | if "png" in args.save_as:
216 | save_as_img(video[j], os.path.join(args.output, "png", f"sample_{sample_index:06d}"), ext="png")
217 | if "pt" in args.save_as:
218 | torch.save(video[j].clone(), os.path.join(args.output, "pt", f"sample_{sample_index:06d}.pt"))
219 | sample_index += 1
220 | pbar.update(1)
221 | if sample_index >= args.num_samples:
222 | finished = True
223 | break
224 |
225 | df = pd.DataFrame(filelist, columns=["FileName", "EF", "ESV", "EDV", "FrameHeight", "FrameWidth", "FPS", "NumberOfFrames", "Split"])
226 | df.to_csv(os.path.join(args.output, "FileList.csv"), index=False)
227 | print(f"Generated {sample_index} samples.")
228 |
229 |
230 |
231 |
232 |
--------------------------------------------------------------------------------
/echosyn/privacy/README.md:
--------------------------------------------------------------------------------
1 | # Re-Identification Model
2 |
3 | For this work, we use the [Latent Image Diffusion Model (LIDM)](../lidm/README.md) to generate synthetic echocardiography images. To enforce privacy, as a post-hoc step, we train a re-identification model to project real and generated images into a common latent space. This allows us to compute the similarity between two images, and by extension, to detect any synthetic images that are too similar to real ones.
4 |
5 | The re-identification models are trained on the VAE-encoded real images. We train one re-identification model per LIDM. The training takes a few hours on a single A100 GPU.
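To make the filtering step below concrete, here is a sketch of how a trained re-identification model scores the similarity between two latent images. It mirrors `shared.py` and `apply.py`; the random tensors stand in for real VAE latents normalised to [-1, 1].

```python
import torch
from echosyn.privacy.shared import SiameseNetwork

net = SiameseNetwork(network="ResNet-50", in_channels=4, n_features=128).eval()

real_latent = torch.randn(1, 4, 14, 14)       # placeholder for a VAE-encoded real frame
synth_latent = torch.randn(1, 4, 14, 14)      # placeholder for a VAE-encoded synthetic frame

with torch.no_grad():
    emb = net.forward_once(torch.cat([real_latent, synth_latent]))  # 2 x 128 embeddings
    similarity = torch.corrcoef(emb)[0, 1]                          # correlation, the similarity score used in apply.py
print(float(similarity))
```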
6 |
7 | ## 1. Activate the environment
8 |
9 | First, activate the echosyn environment.
10 |
11 | ```bash
12 | conda activate echosyn
13 | ```
14 |
15 | ## 2. Data preparation
16 | Follow the instructions in the [Data preparation](../../README.md#data-preparation) section to prepare the data for training. Here, you need the VAE-encoded videos.
17 |
18 | ## 3. Train the Re-Identification models
19 | Once the environment is set up and the data is ready, you can train the Re-Identification models with the following commands:
20 |
21 | ```bash
22 | python echosyn/privacy/train.py --config=echosyn/privacy/configs/config_dynamic.json
23 | python echosyn/privacy/train.py --config=echosyn/privacy/configs/config_ped_a4c.json
24 | python echosyn/privacy/train.py --config=echosyn/privacy/configs/config_ped_psax.json
25 | ```
26 |
27 | ## 4. Filter the synthetic images
28 | After training the re-identification models, you can filter the synthetic images generated with the LIDMs using the following commands:
29 |
30 | ```bash
31 | python echosyn/privacy/apply.py \
32 | --model experiments/reidentification_dynamic \
33 | --synthetic samples/lidm_dynamic/latents \
34 | --reference data/latents/dynamic \
35 | --output samples/lidm_dynamic/privacy_compliant_latents
36 | ```
37 |
38 | ```bash
39 | python echosyn/privacy/apply.py \
40 | --model experiments/reidentification_ped_a4c \
41 | --synthetic samples/lidm_ped_a4c/latents \
42 | --reference data/latents/ped_a4c \
43 | --output samples/lidm_ped_a4c/privacy_compliant_latents
44 | ```
45 |
46 | ```bash
47 | python echosyn/privacy/apply.py \
48 | --model experiments/reidentification_ped_psax \
49 | --synthetic samples/lidm_ped_psax/latents \
50 | --reference data/latents/ped_psax \
51 | --output samples/lidm_ped_psax/privacy_compliant_latents
52 | ```
53 |
54 | This script will filter out all latents that are too similar to real latents (encoded images).
55 | The filtered latents are saved in the directory specified by `--output`.
56 | The similarity threshold is determined automatically by the `apply.py` script (see the sketch below).
57 | The latents are all that's required to generate the privacy-compliant synthetic videos, because the Latent Video Diffusion Model (LVDM) is conditioned on the latents, not the images themselves.
58 | To obtain the privacy-compliant images, you can use the provided script like so:
59 |
60 | ```bash
61 | ./scripts/copy_privacy_compliant_images.sh samples/lidm_dynamic/images samples/lidm_dynamic/privacy_compliant_latents samples/lidm_dynamic/privacy_compliant_images
62 | ./scripts/copy_privacy_compliant_images.sh samples/lidm_ped_a4c/images samples/lidm_ped_a4c/privacy_compliant_latents samples/lidm_ped_a4c/privacy_compliant_images
63 | ./scripts/copy_privacy_compliant_images.sh samples/lidm_ped_psax/images samples/lidm_ped_psax/privacy_compliant_latents samples/lidm_ped_psax/privacy_compliant_images
64 | ```
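For reference, the threshold used by `apply.py` is chosen roughly as follows: embed the real training and validation images, record each training image's highest correlation with a validation image, and take a high percentile (95 by default) of those maxima as tau. A simplified sketch:

```python
import numpy as np
import torch

def privacy_threshold(latents_train: torch.Tensor, latents_val: torch.Tensor, percentile: float = 95) -> float:
    corr = torch.corrcoef(torch.cat([latents_train, latents_val]))
    n = len(latents_train)
    closest = corr[:n, n:].max(dim=1).values           # best validation match for each training embedding
    return float(np.percentile(closest.cpu().numpy(), percentile))
```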
65 |
66 | ## 5. Evaluate the remaining images
67 |
68 | To evaluate the remaining images, we use the same process as for the LIDM, with the following commands:
69 |
70 | ```bash
71 |
72 | cd external/stylegan-v
73 |
74 | # For the Dynamic dataset
75 | python src/scripts/calc_metrics_for_dataset.py \
76 | --real_data_path ../../data/reference/dynamic \
77 | --fake_data_path ../../samples/lidm_dynamic/privacy_compliant_images \
78 | --mirror 0 --gpus 1 --resolution 112 \
79 | --metrics fid50k_full,is50k >> "../../samples/lidm_dynamic/privacy_compliant_metrics.txt"
80 |
81 | # For the Pediatric A4C dataset
82 | python src/scripts/calc_metrics_for_dataset.py \
83 | --real_data_path ../../data/reference/ped_a4c \
84 | --fake_data_path ../../samples/lidm_ped_a4c/privacy_compliant_images \
85 | --mirror 0 --gpus 1 --resolution 112 \
86 | --metrics fid50k_full,is50k >> "../../samples/lidm_ped_a4c/privacy_compliant_metrics.txt"
87 |
88 | # For the Pediatric PSAX dataset
89 | python src/scripts/calc_metrics_for_dataset.py \
90 | --real_data_path ../../data/reference/ped_psax \
91 | --fake_data_path ../../samples/lidm_ped_psax/privacy_compliant_images \
92 | --mirror 0 --gpus 1 --resolution 112 \
93 | --metrics fid50k_full,is50k >> "../../samples/lidm_ped_psax/privacy_compliant_metrics.txt"
94 | ```
95 |
96 | It is important that at least 50,000 samples remain in the filtered folders, to ensure that the FID and IS scores are reliable.
97 |
98 |
99 | ## 6. Save the Re-Identification models for later use
100 |
101 | The re-identification models can be saved for later use with the following command:
102 |
103 | ```bash
104 | mkdir -p models/reidentification_dynamic; cp experiments/reidentification_dynamic/reidentification_dynamic_best_network.pth models/reidentification_dynamic/; cp experiments/reidentification_dynamic/config.json models/reidentification_dynamic/
105 | ```
106 | ```bash
107 | mkdir -p models/reidentification_ped_a4c; cp experiments/reidentification_ped_a4c/reidentification_ped_a4c_best_network.pth models/reidentification_ped_a4c/; cp experiments/reidentification_ped_a4c/config.json models/reidentification_ped_a4c/
108 | ```
109 | ```bash
110 | mkdir -p models/reidentification_ped_psax; cp experiments/reidentification_ped_psax/reidentification_ped_psax_best_network.pth models/reidentification_ped_psax/; cp experiments/reidentification_ped_psax/config.json models/reidentification_ped_psax/
111 | ```
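A minimal sketch for reloading one of the saved models, mirroring how `apply.py` loads them:

```python
import json
import torch
from echosyn.privacy.shared import SiameseNetwork

with open("models/reidentification_dynamic/config.json") as f:
    cfg = json.load(f)

net = SiameseNetwork(network=cfg["siamese_architecture"], in_channels=cfg["n_channels"], n_features=cfg["n_features"])
net.load_state_dict(torch.load("models/reidentification_dynamic/reidentification_dynamic_best_network.pth"))
net.eval()
```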
112 |
--------------------------------------------------------------------------------
/echosyn/privacy/apply.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import random
5 | import shutil
6 |
7 | import torch
8 | import pandas as pd
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 | from PIL import Image
13 |
14 | import matplotlib.pyplot as plt
15 | from functools import partial
16 | from echosyn.common.datasets import SimaseUSVideoDataset
17 | from echosyn.privacy.shared import SiameseNetwork
18 |
19 |
20 | """
21 | This script is used to apply the privacy filter to a set of synthetic latents.
22 | Example usage:
23 | python echosyn/privacy/apply.py \
24 | --model experiments/reidentification_dynamic \
25 | --synthetic samples/dynamic/latents \
26 | --reference data/latents/dynamic \
27 | --output samples/dynamic/privatised_latents
28 | """
29 |
30 | def first_frame(vid):
31 | return vid[0:1]
32 |
33 | def subsample(vid, every_nth_frame):
34 | frames = np.arange(0, len(vid), step=every_nth_frame)
35 | return vid[frames]
36 |
37 | if __name__ == "__main__":
38 |
39 | parser = argparse.ArgumentParser(description='Privacy Filter')
40 | parser.add_argument('--model', type=str, help='Path to the model folder.')
41 | parser.add_argument('--synthetic', type=str, help='Path to the synthetic latents folder.')
42 | parser.add_argument('--reference', type=str, help='Path to the real latents folder.')
43 | parser.add_argument('--output', type=str, help='Path to the output folder.')
44 | parser.add_argument('--cutoff_percentile', type=float, default=95, help='Cutoff percentile for the privacy threshold.')
45 |
46 | args = parser.parse_args()
47 |
48 | # Load real and synthetic latents
49 | training_latents_csv = os.path.join(args.reference, "FileList.csv")
50 | training_latents_basepath = os.path.join(args.reference, "Latents")
51 | normalization = lambda x: (x - x.min()) / (x.max() - x.min()) * 2 - 1 # rescale to [-1, 1], matching how the model was trained
52 | ds_train = SimaseUSVideoDataset(phase="training", transform=normalization, latents_csv=training_latents_csv, training_latents_base_path=training_latents_basepath, in_memory=True)
53 | ds_test = SimaseUSVideoDataset(phase="testing", transform=normalization, latents_csv=training_latents_csv, training_latents_base_path=training_latents_basepath, in_memory=True)
54 |
55 | synthetic_images_paths = os.listdir(args.synthetic)
56 | synthetic_images = [torch.load(os.path.join(args.synthetic, x)) for x in synthetic_images_paths]
57 |
58 | # convert images to 1 x C x H x W to be consistent in case we want to check videos
59 | for i in range(len(synthetic_images)):
60 | if len(synthetic_images[i].size()) == 3:
61 | synthetic_images[i] = synthetic_images[i].unsqueeze(dim=0)
62 |
63 | synthetic_images = normalization(torch.cat(synthetic_images))
64 | print(f"Number of synthetic images found: {len(synthetic_images)}")
65 |
66 | # Prepare the images
67 | train_vid_to_img = first_frame
68 | test_vid_to_img = first_frame
69 | train_images = torch.cat([train_vid_to_img(x) for x in tqdm(ds_train)])
70 | test_images = torch.cat([test_vid_to_img(x) for x in tqdm(ds_test)])
71 | print(f"Number of real train frames: {len(train_images)}")
72 | print(f"Number of real test frames: {len(test_images)}")
73 |
74 |
75 | # Load Model
76 | with open(os.path.join(args.model, "config.json")) as config:
77 | config = config.read()
78 |
79 | config = json.loads(config)
80 | net = SiameseNetwork(network=config['siamese_architecture'], in_channels=config['n_channels'], n_features=config['n_features'])
81 | net.load_state_dict(torch.load(os.path.join(args.model, os.path.basename(args.model) + "_best_network.pth")))
82 |
83 |
84 | print("Sanity Check")
85 | print(f"Train - Min: {train_images.min()} - Max: {train_images.max()} - shape: {train_images.size()}")
86 | print(f"Test - Min: {test_images.min()} - Max: {test_images.max()} - shape: {test_images.size()}")
87 | print(f"Train - Min: {synthetic_images.min()} - Max: {synthetic_images.max()} - shape: {synthetic_images.size()}")
88 |
89 | # Compute the embeddings
90 | net.eval()
91 | net = net.cuda()
92 | bs = 256
93 | latents_train = []
94 | latents_test = []
95 | latents_synth = []
96 | with torch.no_grad():
97 | for i in tqdm(np.arange(0, len(train_images), bs), "Computing Train Embeddings"):
98 | batch = train_images[i:i+bs].cuda()
99 | latents_train.append(net.forward_once(batch))
100 |
101 | for i in tqdm(np.arange(0, len(test_images), bs), "Computing Test Embeddings"):
102 | batch = test_images[i:i+bs].cuda()
103 | latents_test.append(net.forward_once(batch))
104 |
105 | for i in tqdm(np.arange(0, len(synthetic_images), bs), "Computing Synthetic Embeddings"):
106 | batch = synthetic_images[i:i+bs].cuda()
107 | latents_synth.append(net.forward_once(batch))
108 |
109 | latents_train = torch.cat(latents_train)
110 | latents_test = torch.cat(latents_test)
111 | latents_synth = torch.cat(latents_synth)
112 |
113 |
114 | # Automatically determine the privacy threshold
115 | train_val_corr = torch.corrcoef(torch.cat([latents_train, latents_test])).cpu()
116 | print(train_val_corr.size())
117 |
118 | closest_train = []
119 | for i in range(len(train_images)):
120 | val_matches = train_val_corr[i, len(train_images):]
121 | closest_train.append(val_matches.max().cpu())
122 |
123 | tau = np.percentile(torch.stack(closest_train).numpy(), args.cutoff_percentile)
124 | print(f"Privacy threshold tau: {tau}")
125 |
126 | # Compute the closest matches between synthetic and real images
127 | closest_test = []
128 | batch_size = 10000
129 | latents_synth = latents_synth.cpu()  # keep embeddings on CPU; batches are moved back to the GPU below
130 | closest_loc = []
131 | for l in np.arange(0, len(latents_synth), batch_size):
132 | synth = latents_synth[l:l + batch_size].cuda()
133 | train_synth_corr = torch.corrcoef(torch.cat([latents_train, synth]))
134 | for i in range(len(synth)):
135 | synth_matches = train_synth_corr[len(train_images)+i, :len(train_images)]
136 | closest_test.append(synth_matches.max().cpu())
137 | closest_loc.append(synth_matches.argmax().cpu())
138 | synth = synth.cpu()
139 |
140 | # Plot the results
141 | fig, ax = plt.subplots()
142 | nt, bins, patches = ax.hist(closest_train, 100, density=True, label="Train-Val", alpha=0.5, color="blue")
143 | ns, bins, patches = ax.hist(closest_test, 100, density=True, label="Train-Synth", alpha=.5, color="orange")
144 | ax.axvline(tau, 0, max(max(nt), max(ns)), color="black")
145 | ax.set_xlabel('Correlation')
146 | ax.set_ylabel('Frequency')
147 | ax.set_title('Histogram of highest correlation matches')
148 | ax.text(tau, max(max(nt), max(ns)), f'$tau$ = {tau:0.5f} ', ha='right', va='bottom', rotation=0, color='black')
149 | plt.legend()
150 | fig.tight_layout()
151 | plt.savefig(os.path.join(args.model, "privacy_results.png"))
152 |
153 | closest_test = torch.stack(closest_test)
154 | is_public = closest_test < tau
155 | print(f"Number of synthetic images: {len(closest_test)}")
156 | print(f"Number of non private synthetic images: {sum(is_public)} - memorization rate: {1- is_public.float().mean()}")
157 |
158 | # Save the privacy-compliant latents
159 | private_latents_output_path = os.path.abspath(args.output)
160 | os.makedirs(private_latents_output_path, exist_ok=True)
161 |
162 | for path, is_pub in zip(synthetic_images_paths, is_public):
163 | src = os.path.join(args.synthetic, path)
164 | tgt = os.path.join(private_latents_output_path, path)
165 | if is_pub:
166 | shutil.copy(src, tgt)
167 | else:
168 | print(f"Skipping images because it appears memorized: {tgt} ")
169 |
170 | results = {"real_path":[], "tau":[], "synth_path":[], "img_idx": []}
171 | for clos_idx, tau_val, path in zip(closest_loc, closest_test, synthetic_images_paths):
172 | results["real_path"].append(ds_train.df.iloc[int(clos_idx)]["FileName"])
173 | results["tau"].append(float(tau_val))
174 | results["synth_path"].append(path)
175 | results["img_idx"].append(int(clos_idx))
176 | res = pd.DataFrame(results)
177 | res.to_csv(os.path.join(os.path.dirname(private_latents_output_path), "privacy_scores.csv"), index=False)
178 |
179 |
180 |
181 |
--------------------------------------------------------------------------------
/echosyn/privacy/configs/config_dynamic.json:
--------------------------------------------------------------------------------
1 | {
2 | "experiment_description": "reidentification_dynamic",
3 | "resumption": false,
4 | "resumption_count": null,
5 | "previous_experiment": null,
6 |
7 | "image_path": "data/latents/dynamic/FileList.csv",
8 |
9 | "siamese_architecture": "ResNet-50",
10 | "data_handling": "balanced",
11 |
12 | "num_workers": 16,
13 | "pin_memory": true,
14 |
15 | "n_channels": 4,
16 | "n_features": 128,
17 | "image_size": 14,
18 | "loss": "BCEWithLogitsLoss",
19 | "optimizer": "Adam",
20 | "learning_rate": 0.0001,
21 | "batch_size": 128,
22 | "max_epochs": 1000,
23 | "early_stopping": 50,
24 | "transform": "pmone",
25 |
26 | "n_samples_train": 7465,
27 | "n_samples_val": 1288,
28 | "n_samples_test": 1277
29 | }
30 |
--------------------------------------------------------------------------------
/echosyn/privacy/configs/config_ped_a4c.json:
--------------------------------------------------------------------------------
1 | {
2 | "experiment_description": "reidentification_ped_a4c",
3 | "resumption": false,
4 | "resumption_count": null,
5 | "previous_experiment": null,
6 |
7 | "image_path": "data/latents/ped_a4c/FileList.csv",
8 |
9 | "siamese_architecture": "ResNet-50",
10 | "data_handling": "balanced",
11 |
12 | "num_workers": 16,
13 | "pin_memory": true,
14 |
15 | "n_channels": 4,
16 | "n_features": 128,
17 | "image_size": 14,
18 | "loss": "BCEWithLogitsLoss",
19 | "optimizer": "Adam",
20 | "learning_rate": 0.0001,
21 | "batch_size": 128,
22 | "max_epochs": 1000,
23 | "early_stopping": 50,
24 | "transform": "pmone",
25 |
26 | "n_samples_train": 2580,
27 | "n_samples_val": 368,
28 | "n_samples_test": 336
29 | }
30 |
--------------------------------------------------------------------------------
/echosyn/privacy/configs/config_ped_psax.json:
--------------------------------------------------------------------------------
1 | {
2 | "experiment_description": "reidentification_ped_psax",
3 | "resumption": false,
4 | "resumption_count": null,
5 | "previous_experiment": null,
6 |
7 | "image_path": "data/latents/ped_psax/FileList.csv",
8 |
9 | "siamese_architecture": "ResNet-50",
10 | "data_handling": "balanced",
11 |
12 | "num_workers": 16,
13 | "pin_memory": true,
14 |
15 | "n_channels": 4,
16 | "n_features": 128,
17 | "image_size": 14,
18 | "loss": "BCEWithLogitsLoss",
19 | "optimizer": "Adam",
20 | "learning_rate": 0.0001,
21 | "batch_size": 128,
22 | "max_epochs": 1000,
23 | "early_stopping": 50,
24 | "transform": "pmone",
25 |
26 | "n_samples_train": 3559,
27 | "n_samples_val": 519,
28 | "n_samples_test": 448
29 | }
30 |
--------------------------------------------------------------------------------
/echosyn/privacy/shared.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 |
5 |
6 | class SiameseNetwork(nn.Module):
7 | def __init__(self, network='ResNet-50', in_channels=3, n_features=128):
8 | super(SiameseNetwork, self).__init__()
9 | self.network = network
10 | self.in_channels = in_channels
11 | self.n_features = n_features
12 |
13 | if self.network == 'ResNet-50':
14 | # Model: Use ResNet-50 architecture
15 | self.model = models.resnet50(pretrained=True)
16 | # Adjust the input layer to accept 1, 3 or 4 input channels
17 | if self.in_channels == 1:
18 | self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
19 | elif self.in_channels == 4:
20 | self.model.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
21 | elif self.in_channels == 3:
22 | pass
23 | else:
24 | raise Exception(
25 | 'Invalid argument: ' + str(self.in_channels) + '\nChoose in_channels=1, 3 or 4')
26 | # Adjust the ResNet classification layer to produce feature vectors of a specific size
27 | self.model.fc = nn.Linear(in_features=2048, out_features=self.n_features, bias=True)
28 |
29 | else:
30 | raise Exception('Invalid argument: ' + self.network +
31 | '\nChoose ResNet-50! Other architectures are not yet implemented in this framework.')
32 |
33 | self.fc_end = nn.Linear(self.n_features, 1)
34 |
35 | def forward_once(self, x):
36 |
37 | # Forward function for one branch to get the n_features-dim feature vector before merging
38 | output = self.model(x)
39 | output = torch.sigmoid(output)
40 | return output
41 |
42 | def forward(self, input1, input2):
43 |
44 | # Forward
45 | output1 = self.forward_once(input1)
46 | output2 = self.forward_once(input2)
47 |
48 | # Compute the absolute difference between the n_features-dim feature vectors and pass it to the last FC-Layer
49 | difference = torch.abs(output1 - output2)
50 | output = self.fc_end(difference)
51 |
52 | return output
53 |
54 |
55 |
--------------------------------------------------------------------------------
/echosyn/privacy/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import copy
5 | import time
6 |
7 | import numpy as np
8 | from sklearn import metrics  # used in testing_evaluation for roc_curve / roc_auc_score
9 | import torch
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import torchvision.transforms as transforms
13 |
14 | from echosyn.common import privacy_utils as Utils
15 | from echosyn.privacy.shared import SiameseNetwork
16 |
17 | class AgentSiameseNetwork:
18 | def __init__(self, config):
19 | self.config = config
20 |
21 | # set path used to save experiment-related files and results
22 | self.SAVINGS_PATH = './experiments/' + self.config['experiment_description'] + '/'
23 | self.IMAGE_PATH = self.config['image_path']
24 |
25 | # save configuration as config.json in the created folder
26 | with open(self.SAVINGS_PATH + 'config.json', 'w') as outfile:
27 | json.dump(self.config, outfile, indent='\t')
28 | outfile.close()
29 |
30 | # enable benchmark mode in cuDNN
31 | torch.backends.cudnn.benchmark = True
32 |
33 | # set all the important variables
34 | self.network = self.config['siamese_architecture']
35 | self.data_handling = self.config['data_handling']
36 |
37 | self.num_workers = self.config['num_workers']
38 | self.pin_memory = self.config['pin_memory']
39 |
40 | self.n_channels = self.config['n_channels']
41 | self.n_features = self.config['n_features']
42 | self.image_size = self.config['image_size']
43 | self.loss_method = self.config['loss']
44 | self.optimizer_method = self.config['optimizer']
45 | self.learning_rate = self.config['learning_rate']
46 | self.batch_size = self.config['batch_size']
47 | self.max_epochs = self.config['max_epochs']
48 | self.early_stopping = self.config['early_stopping']
49 | self.transform = self.config['transform']
50 |
51 | self.n_samples_train = self.config['n_samples_train']
52 | self.n_samples_val = self.config['n_samples_val']
53 | self.n_samples_test = self.config['n_samples_test']
54 |
55 | self.start_epoch = 0
56 |
57 | self.es = EarlyStopping(patience=self.early_stopping)
58 | self.best_loss = 100000
59 | self.loss_dict = {'training': [],
60 | 'validation': []}
61 |
62 | if self.data_handling == 'balanced':
63 | self.balanced = True
64 | self.randomized = False
65 | elif self.data_handling == 'randomized':
66 | self.balanced = True
67 | self.randomized = True
68 |
69 | # define the suffix needed for loading the checkpoint (in case you want to resume a previous experiment)
70 | if self.config['resumption'] is True:
71 | if self.config['resumption_count'] == 1:
72 | self.load_suffix = ''
73 | elif self.config['resumption_count'] == 2:
74 | self.load_suffix = '_resume'
75 | elif self.config['resumption_count'] > 2:
76 | self.load_suffix = '_resume' + str(self.config['resumption_count'] - 1)
77 |
78 | # define the suffix needed for saving the checkpoint (the checkpoint is saved at the end of each epoch)
79 | if self.config['resumption'] is False:
80 | self.save_suffix = ''
81 | elif self.config['resumption'] is True:
82 | if self.config['resumption_count'] == 1:
83 | self.save_suffix = '_resume'
84 | elif self.config['resumption_count'] > 1:
85 | self.save_suffix = '_resume' + str(self.config['resumption_count'])
86 |
87 | # Define the siamese neural network architecture
88 | self.net = SiameseNetwork(network=self.network, in_channels=self.n_channels, n_features=self.n_features).cuda()
89 | self.best_net = SiameseNetwork(network=self.network, in_channels=self.n_channels,
90 | n_features=self.n_features).cuda()
91 |
92 | # Choose loss function
93 | if self.loss_method == 'BCEWithLogitsLoss':
94 | self.loss = nn.BCEWithLogitsLoss().cuda()
95 | else:
96 | raise Exception('Invalid argument: ' + self.loss_method +
97 | '\nChoose BCEWithLogitsLoss! Other loss functions are not yet implemented!')
98 |
99 | # Set the optimizer function
100 | if self.optimizer_method == 'Adam':
101 | self.optimizer = optim.Adam(self.net.parameters(), lr=self.learning_rate)
102 | else:
103 | raise Exception('Invalid argument: ' + self.optimizer_method +
104 | '\nChoose Adam! Other optimizer functions are not yet implemented!')
105 |
106 | # load state dicts and other information in case a previous experiment will be continued
107 | if self.config['resumption'] is True:
108 | self.checkpoint = torch.load('./archive/' + self.config['previous_experiment'] + '/' + self.config[
109 | 'previous_experiment'] + '_checkpoint' + self.load_suffix + '.pth')
110 | self.best_net.load_state_dict(torch.load(
111 | './archive/' + self.config['previous_experiment'] + '/' + self.config[
112 | 'previous_experiment'] + '_best_network' + self.load_suffix + '.pth'))
113 | self.net.load_state_dict(self.checkpoint['state_dict'])
114 | self.optimizer.load_state_dict(self.checkpoint['optimizer'])
115 | self.best_loss = self.checkpoint['best_loss']
116 | self.loss_dict = self.checkpoint['loss_dict']
117 | self.es.best = self.checkpoint['best_loss']
118 | self.es.num_bad_epochs = self.checkpoint['num_bad_epochs']
119 | self.start_epoch = self.checkpoint['epoch']
120 |
121 | # Initialize transformations
122 | if self.transform == 'image_net':
123 | self.transform_train = transforms.Compose([
124 | transforms.Resize((self.image_size, self.image_size)),
125 | transforms.ToTensor(),
126 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
127 | ])
128 | self.transform_val_test = transforms.Compose([
129 | transforms.Resize((self.image_size, self.image_size)),
130 | transforms.ToTensor(),
131 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
132 | ])
133 | elif self.transform == "default":
134 | self.transform_train = transforms.Compose([
135 | transforms.Resize((self.image_size, self.image_size)),
136 | transforms.ToTensor()
137 | ])
138 | self.transform_val_test = transforms.Compose([
139 | transforms.Resize((self.image_size, self.image_size)),
140 | transforms.ToTensor()
141 | ])
142 | elif self.transform == "pmone":
143 | self.transform_train = lambda x: (x - x.min())/(x.max() - x.min()) * 2 - 1
144 | self.transform_val_test = lambda x: (x - x.min())/(x.max() - x.min()) * 2 - 1
145 |
146 | elif self.transform == "none":
147 | self.transform_train = lambda x: x
148 | self.transform_val_test = lambda x: x
149 |
150 |
151 | self.training_loader = Utils.get_data_loaders(phase='training', data_handling=self.data_handling,
152 | n_channels=self.n_channels,
153 | n_samples=self.n_samples_train,
154 | transform=self.transform_train, image_path=self.IMAGE_PATH,
155 | batch_size=self.batch_size, shuffle=True,
156 | num_workers=self.num_workers, pin_memory=self.pin_memory,
157 | save_path=None)
158 | self.validation_loader = Utils.get_data_loaders(phase='validation', data_handling=self.data_handling,
159 | n_channels=self.n_channels,
160 | n_samples=self.n_samples_val,
161 | transform=self.transform_val_test,
162 | image_path=self.IMAGE_PATH, batch_size=self.batch_size,
163 | shuffle=False, num_workers=self.num_workers,
164 | pin_memory=self.pin_memory, save_path=None, generator_seed=0)
165 | self.test_loader = Utils.get_data_loaders(phase='testing', data_handling='balanced',
166 | n_channels=self.n_channels, n_samples=self.n_samples_test,
167 | transform=self.transform_val_test, image_path=self.IMAGE_PATH,
168 | batch_size=self.batch_size, shuffle=False,
169 | num_workers=self.num_workers, pin_memory=self.pin_memory,
170 | save_path=None, generator_seed=0)
171 |
172 | def training_validation(self):
173 | # Training and validation loop!
174 | for epoch in range(self.start_epoch, self.max_epochs):
175 | start_time = time.time()
176 |
177 | training_loss = Utils.train(self.net, self.training_loader, self.n_samples_train, self.batch_size,
178 | self.loss, self.optimizer, epoch, self.max_epochs)
179 |
180 | self.validation_loader.dataset.reset_generator() # make sure this is deterministic for more accurate loss
181 | validation_loss = Utils.validate(self.net, self.validation_loader, self.n_samples_val, self.batch_size,
182 | self.loss, epoch, self.max_epochs)
183 |
184 | self.loss_dict['training'].append(training_loss)
185 | self.loss_dict['validation'].append(validation_loss)
186 | end_time = time.time()
187 | print('Time elapsed for epoch ' + str(epoch + 1) + ': ' + str(
188 | round((end_time - start_time) / 60, 2)) + ' minutes')
189 |
190 | if validation_loss < self.best_loss:
191 | self.best_loss = validation_loss
192 | self.best_net = copy.deepcopy(self.net)
193 |
194 | torch.save(self.best_net.state_dict(), self.SAVINGS_PATH + self.config[
195 | 'experiment_description'] + '_best_network' + self.save_suffix + '.pth')
196 |
197 | Utils.save_loss_curves(self.loss_dict, self.SAVINGS_PATH, self.config['experiment_description'])
198 | Utils.plot_loss_curves(self.loss_dict, self.SAVINGS_PATH, self.config['experiment_description'])
199 |
200 | es_check = self.es.step(validation_loss)
201 | Utils.save_checkpoint(epoch, self.net, self.optimizer, self.loss_dict, self.best_loss,
202 | self.es.num_bad_epochs, self.SAVINGS_PATH + self.config[
203 | 'experiment_description'] + '_checkpoint' + self.save_suffix + '.pth')
204 |
205 | if es_check:
206 | break
207 |
208 | print('Finished Training!')
209 |
210 | def testing_evaluation(self):
211 | # Testing phase!
212 | self.test_loader.dataset.reset_generator()
213 | y_true, y_pred = Utils.test(self.best_net, self.test_loader)
214 | y_true, y_pred = [y_true.numpy(), y_pred.numpy()]
215 |
216 | # Compute the evaluation metrics!
217 | fp_rates, tp_rates, thresholds = metrics.roc_curve(y_true, y_pred)
218 | auc = metrics.roc_auc_score(y_true, y_pred)
219 | y_pred_thresh = Utils.apply_threshold(y_pred, 0.5)
220 | accuracy, f1_score, precision, recall, report, confusion_matrix = Utils.get_evaluation_metrics(y_true,
221 | y_pred_thresh)
222 | auc_mean, confidence_lower, confidence_upper = Utils.bootstrap(10000,
223 | y_true,
224 | y_pred,
225 | self.SAVINGS_PATH,
226 | self.config['experiment_description'])
227 |
228 | # Plot ROC curve!
229 | Utils.plot_roc_curve(fp_rates, tp_rates, self.SAVINGS_PATH, self.config['experiment_description'])
230 |
231 | # Save all the results to files!
232 | Utils.save_labels_predictions(y_true, y_pred, y_pred_thresh, self.SAVINGS_PATH,
233 | self.config['experiment_description'])
234 |
235 | Utils.save_results_to_file(auc, accuracy, f1_score, precision, recall, report, confusion_matrix,
236 | self.SAVINGS_PATH, self.config['experiment_description'])
237 |
238 | Utils.save_roc_metrics_to_file(fp_rates, tp_rates, thresholds, self.SAVINGS_PATH,
239 | self.config['experiment_description'])
240 |
241 | # Print the evaluation metrics!
242 | print('EVALUATION METRICS:')
243 | print('AUC: ' + str(auc))
244 | print('Accuracy: ' + str(accuracy))
245 | print('F1-Score: ' + str(f1_score))
246 | print('Precision: ' + str(precision))
247 | print('Recall: ' + str(recall))
248 | print('Report: ' + str(report))
249 | print('Confusion matrix: ' + str(confusion_matrix))
250 |
251 | print('BOOTSTRAPPING: ')
252 | print('AUC Mean: ' + str(auc_mean))
253 | print('Confidence interval for the AUC score: ' + str(confidence_lower) + ' - ' + str(confidence_upper))
254 |
255 | def run(self):
256 | # Call training/validation and testing loop successively
257 | self.training_validation()
258 | self.testing_evaluation()
259 |
260 | class EarlyStopping(object):
261 | def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
262 | self.mode = mode
263 | self.min_delta = min_delta
264 | self.patience = patience
265 | self.best = None
266 | self.num_bad_epochs = 0
267 | self.is_better = None
268 | self._init_is_better(mode, min_delta, percentage)
269 |
270 | if patience == 0:
271 | self.is_better = lambda a, b: True
272 | self.step = lambda a: False
273 |
274 | def step(self, metrics):
275 | if self.best is None:
276 | self.best = metrics
277 | return False
278 |
279 | if np.isnan(metrics):
280 | return True
281 |
282 | if self.is_better(metrics, self.best):
283 | self.num_bad_epochs = 0
284 | self.best = metrics
285 | else:
286 | self.num_bad_epochs += 1
287 |
288 | if self.num_bad_epochs >= self.patience:
289 | return True
290 |
291 | return False
292 |
293 | def _init_is_better(self, mode, min_delta, percentage):
294 | if mode not in {'min', 'max'}:
295 | raise ValueError('mode ' + mode + ' is unknown!')
296 | if not percentage:
297 | if mode == 'min':
298 | self.is_better = lambda a, best: a < best - min_delta
299 | if mode == 'max':
300 | self.is_better = lambda a, best: a > best + min_delta
301 | else:
302 | if mode == 'min':
303 | self.is_better = lambda a, best: a < best - (
304 | best * min_delta / 100)
305 | if mode == 'max':
306 | self.is_better = lambda a, best: a > best + (
307 | best * min_delta / 100)
308 |
309 |
310 | if __name__ == '__main__':
311 | # define an argument parser
312 | parser = argparse.ArgumentParser('Patient Verification')
313 | parser.add_argument('--config_path', default='./', help='the path where the config files are stored')
314 | parser.add_argument('--config', default='config.json', help='the hyper-parameter configuration and experiment settings')
315 | args = parser.parse_args()
316 | print('Arguments:\n' + '--config_path: ' + args.config_path + '\n--config: ' + args.config)
317 |
318 | # read config
319 | with open(args.config_path + args.config, 'r') as config:
320 | config = config.read()
321 |
322 | # parse config
323 | config = json.loads(config)
324 |
325 | # create folder to save experiment-related files
326 | os.makedirs(os.path.join('./experiments/' , config['experiment_description']), exist_ok=True)
327 | SAVINGS_PATH = './experiments/' + config['experiment_description'] + '/'
328 |
329 | # call agent and run experiment
330 | experiment = AgentSiameseNetwork(config)
331 | experiment.run()
--------------------------------------------------------------------------------
/echosyn/vae/README.md:
--------------------------------------------------------------------------------
1 | # Variational Autoencoder (fun fact: it's also a GAN)
2 |
3 | Training the VAE is not a necessity as an open source VAE such as [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) should work well enough for this project.
4 |
5 | However, if you would like to train the VAE yourself, follow these steps.
6 |
7 | ## 1. Clone the necessary repositories
8 | We use the Stable-Diffusion repo to train the VAE model, but the Taming-Transformers repo is a necessary dependency. It is advised to follow these steps exactly to avoid errors later on.
9 |
10 | ```bash
11 | cd external
12 | git clone https://github.com/CompVis/stable-diffusion
13 |
14 | cd stable-diffusion
15 | conda env create -f environment.yaml
16 | conda activate ldm
17 | ```
18 |
19 | This will download the Stable-Diffusion repository and create the necessary environment to run it.
20 |
21 | ```bash
22 | git clone https://github.com/CompVis/taming-transformers
23 | cd taming-transformers
24 | pip install -e .
25 | ```
26 |
27 | This will download the Taming-Transformers repository and install it as a package.
28 |
29 | ## 2. Prepare the data
30 | The VAE only needs images, so to keep things simple, we convert our video datasets into image datasets.
31 | We do that with:
32 | ```bash
33 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Dynamic/Videos data/vae_train_images/images/
34 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Pediatric/A4C/Videos data/vae_train_images/images/
35 | bash scripts/extract_frames_from_videos.sh datasets/EchoNet-Pediatric/PSAX/Videos data/vae_train_images/images/
36 | ```
37 |
38 | Note that this will merge all the images in the same folder.
39 |
40 | Then, we need to create a train.txt file and a val.txt file containing the paths to these images.
41 | ```bash
42 | find $(cd data/vae_train_images/images && pwd) -type f | shuf > tmp.txt
43 | head -n -1000 tmp.txt > data/vae_train_images/train.txt
44 | tail -n 1000 tmp.txt > data/vae_train_images/val.txt
45 | rm tmp.txt
46 | ```
47 |
48 | That's it for the dataset.
49 |
50 |
51 | ## 3. Train the VAE
52 | Now that the data is ready, we can train the VAE. We use the following command to train the VAE on the images we just extracted.
53 |
54 | ```bash
55 | cd external/stable-diffusion
56 | export DATADIR=$(cd ../../data/vae_train_images && pwd)
57 | python main.py \
58 | --base ../../echosyn/vae/usencoder_kl_16x16x4.yaml \
59 | -t True \
60 | --gpus 0,1,2,3,4,5,6,7 \
61 | --logdir experiments/vae
62 | ```
63 |
64 | This will train the VAE on the images we extracted and save the model in the experiments/vae folder.
65 | For the paper, we train the VAE on 8xA100 GPUs for 5 days.
66 | *Note: if you use a single GPU, you still need to leave a trailing comma in the --gpus argument, like so: ```--gpus 0,```*
67 | To resume training from a checkpoint, use the ```--resume``` flag and replace ```EXPERIMENT_NAME``` with the correct experiment name.
68 |
69 | ```bash
70 | python main.py \
71 | --base ../../echosyn/vae/usencoder_kl_16x16x4.yaml \
72 | -t True \
73 | --gpus 0,1,2,3,4,5,6,7 \
74 | --logdir experiments/vae \
75 | --resume experiments/vae/EXPERIMENT_NAME
76 | ```
77 |
78 | ## 4. Export the VAE to Diffusers 🧨
79 | Now that the VAE is trained, we can export it to a Diffusers AutoencoderKL model. This is done with the following command, replacing ```EXPERIMENT_NAME``` with the correct experiment name.
80 |
81 | ```bash
82 | python scripts/convert_vae_pt_to_diffusers.py \
83 | --vae_pt_path experiments/EXPERIMENT_NAME/checkpoints/last.ckpt \
84 | --dump_path models/vae/
85 | ```
86 | The script is taken as-is from the [Diffusers library](https://github.com/huggingface/diffusers/blob/main/scripts/convert_vae_pt_to_diffusers.py).
87 |
88 | ## 5. Test the model in Diffusers 🧨
89 | That's it! You now have a VAE model trained on your own data and exported to a Diffusers model.
90 |
91 | To test the model, use the echosyn environment and do:
92 | ```python
93 | from PIL import Image
94 | import torch
95 | import numpy as np
96 | from diffusers import AutoencoderKL
97 | model = AutoencoderKL.from_pretrained("models/vae")
98 | model.eval()
99 |
100 | # Use the model to encode and decode images
101 | img = Image.open("data/vae_train_images/images/0X10A5FC19152B50A5_00001.jpg")
102 |
103 | img = torch.from_numpy(np.array(img)).permute(2,0,1).unsqueeze(0).to(torch.float32) / 128 - 1
104 |
105 | with torch.no_grad():
106 | lat = model.encode(img).latent_dist.sample()
107 | rec = model.decode(lat).sample
108 | # Display the original and reconstructed images
109 | img = img.squeeze(0).permute(1,2,0)
110 | rec = rec.squeeze(0).permute(1,2,0)
111 | display(Image.fromarray(((img + 1) * 128).to(torch.uint8).numpy()))
112 | ```
113 |
114 | Note that the model might not work in mixed-precision mode, which can produce NaN values. If this happens, force the model dtype to torch.float32, as shown below.
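For example, continuing the snippet above (a minimal sketch; `model` and `img` are the objects defined there):

```python
model = model.to(torch.float32)
img = img.to(torch.float32)
with torch.no_grad():
    rec = model.decode(model.encode(img).latent_dist.sample()).sample
```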
115 |
116 | ## 6. Evaluate the VAE
117 |
118 | To evaluate the VAE, we reconstruct the extracted image datasets and compare the original images to the reconstructed ones. We use the following command to reconstruct (encode-decode) the images from the dynamic dataset.
119 |
120 | ```bash
121 | python scripts/vae_reconstruct_image_folder.py \
122 | -model models/vae \
123 | -input data/reference/dynamic \
124 | -output samples/reconstructed/dynamic \
125 | -batch_size 32
126 | ```
127 |
128 | This will save the reconstructed images in the samples/reconstructed/dynamic folder. To evaluate the VAE, we use the following command to compute the FID score between the original images and the reconstructed images, with the help of the StyleGAN-V repository.
129 |
130 | ```bash
131 | cd external/stylegan-v
132 |
133 | python src/scripts/calc_metrics_for_dataset.py \
134 | --real_data_path ../../data/reference/dynamic \
135 | --fake_data_path ../../samples/reconstructed/dynamic \
136 | --mirror 0 --gpus 1 --resolution 112 \
137 | --metrics fid50k_full,is50k >> "../../samples/reconstructed/dynamic.txt"
138 | ```
139 |
140 |
141 |
--------------------------------------------------------------------------------
/echosyn/vae/usencoder_kl_16x16x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2e-6 # ~5e-4 after scaling
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 112
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 32
30 | num_workers: 16
31 | train:
32 | target: taming.data.custom.CustomTrain
33 | params:
34 | training_images_list_file: ${oc.env:DATADIR}/train.txt
35 | size: 112
36 | validation:
37 | target: taming.data.custom.CustomTest
38 | params:
39 | test_images_list_file: ${oc.env:DATADIR}/val.txt
40 | size: 112
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 | max_epochs: 1000
--------------------------------------------------------------------------------
/external/README.md:
--------------------------------------------------------------------------------
1 | # This folder contains the external libraries that are used in the project.
2 |
3 | External libraries can be cloned from the following repositories, into the `external` folder:
4 |
5 | 1. Stable-Diffusion: ```git clone https://github.com/CompVis/stable-diffusion```
6 | Necessary to train the VAE model.
7 |
8 |
9 | 2. Taming-Transformers: ```git clone https://github.com/CompVis/taming-transformers```
10 | Necessary to train the VAE model.
11 |
12 |
13 | 3. StyleGAN-V: ```git clone git@github.com:HReynaud/stylegan-v.git```
14 | Used to compute all generative metrics (FID, FVD, IS).
15 |
16 |
17 | 4. EchoNet-Dynamic: ```git clone git@github.com:HReynaud/echonet.git```
18 | Used to evaluate the synthetic data on the Ejection Fraction downstream task.
19 |
20 |
21 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.2.1
2 | torchvision==0.17.1
3 | einops==0.7.0
4 | decord==0.6.0
5 | diffusers==0.27.2
6 | packaging==24.0
7 | omegaconf==2.3.0
8 | transformers==4.39.0
9 | accelerate==0.28.0
10 | pandas==2.2.1
11 | wandb==0.16.4
12 | opencv-python==4.9.0.80
13 | moviepy==1.0.3
14 | imageio==2.34.0
15 | scipy==1.12.0
16 | scikit-learn==1.4.1.post1
17 | matplotlib==3.8.3
--------------------------------------------------------------------------------
/ressources/models.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HReynaud/EchoNet-Synthetic/f21198d96dcdc38b18e4f0bb4e2998067c72e2b8/ressources/models.jpg
--------------------------------------------------------------------------------
/ressources/mosaic.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HReynaud/EchoNet-Synthetic/f21198d96dcdc38b18e4f0bb4e2998067c72e2b8/ressources/mosaic.gif
--------------------------------------------------------------------------------
/ressources/mosaic_slim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HReynaud/EchoNet-Synthetic/f21198d96dcdc38b18e4f0bb4e2998067c72e2b8/ressources/mosaic_slim.gif
--------------------------------------------------------------------------------
/ressources/real.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HReynaud/EchoNet-Synthetic/f21198d96dcdc38b18e4f0bb4e2998067c72e2b8/ressources/real.gif
--------------------------------------------------------------------------------
/ressources/synt.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HReynaud/EchoNet-Synthetic/f21198d96dcdc38b18e4f0bb4e2998067c72e2b8/ressources/synt.gif
--------------------------------------------------------------------------------
/scripts/complete_pediatrics_filelist.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import os
4 | import pandas as pd
5 | from tqdm import tqdm
6 |
7 | """
8 | python scripts/complete_pediatrics_filelist.py --dataset datasets/EchoNet-Pediatric/A4C
9 | python scripts/complete_pediatrics_filelist.py --dataset datasets/EchoNet-Pediatric/PSAX
10 | """
11 |
12 | def get_video_metadata(video_path):
13 | cap = cv2.VideoCapture(video_path)
14 | fps = cap.get(cv2.CAP_PROP_FPS)
15 | nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
16 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
17 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
18 | cap.release()
19 | return fps, nframes, width, height
20 |
21 | if __name__ == "__main__":
22 |
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--dataset', type=str, required=True, help='Path to the dataset directory')
25 | args = parser.parse_args()
26 |
27 | csv_path = os.path.join(args.dataset, 'FileList.csv')
28 | assert os.path.exists(csv_path), f"Could not find FileList.csv at {csv_path}"
29 |
30 | metadata = pd.read_csv(csv_path)
31 | metadata.to_csv(os.path.join(args.dataset, 'FileList_ORIGINAL.csv'), index=False) # backup
32 |
33 | metadata['FileName'] = metadata['FileName'].apply(lambda x: x.split('.')[0]) # remove extension
34 | metadata['Fold'] = metadata['Split'] # Copy kfold indices to the Fold column
35 | metadata['Split'] = metadata['Fold'].apply(lambda x: 'TRAIN' if x in range(8) else 'VAL' if x == 8 else 'TEST')
36 |
37 | # Add columns:
38 | # df.loc[df['FileName'] == fname, ['FileName', 'FrameHeight','FrameWidth','FPS','NumberOfFrames']] = [fname, 112, 112, fps, len(video)]
39 |
40 | for i, row in tqdm(metadata.iterrows(), total=len(metadata)):
41 | video_name = row['FileName'] + '.avi'
42 | video_path = os.path.join(args.dataset, 'Videos', video_name)
43 |
44 | fps, nframes, width, height = get_video_metadata(video_path)
45 |
46 | metadata.loc[i, ['FrameHeight','FrameWidth','FPS','NumberOfFrames']] = [height, width, fps, nframes]
47 |
48 | metadata.to_csv(csv_path, index=False)
49 | print("Updated metadata saved to ", csv_path)
50 |
51 |
52 |
--------------------------------------------------------------------------------
/scripts/convert_vae_pt_to_diffusers.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import io
3 |
4 | import requests
5 | import torch
6 | import yaml
7 |
8 | from diffusers import AutoencoderKL
9 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
10 | assign_to_checkpoint,
11 | conv_attn_to_linear,
12 | create_vae_diffusers_config,
13 | renew_vae_attention_paths,
14 | renew_vae_resnet_paths,
15 | )
16 |
17 |
18 | def custom_convert_ldm_vae_checkpoint(checkpoint, config):
19 | vae_state_dict = checkpoint
20 |
21 | new_checkpoint = {}
22 |
23 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
24 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
25 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
26 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
27 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
28 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
29 |
30 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
31 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
32 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
33 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
34 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
35 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
36 |
37 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
38 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
39 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
40 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
41 |
42 | # Retrieves the keys for the encoder down blocks only
43 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
44 | down_blocks = {
45 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
46 | }
47 |
48 | # Retrieves the keys for the decoder up blocks only
49 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
50 | up_blocks = {
51 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
52 | }
53 |
54 | for i in range(num_down_blocks):
55 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
56 |
57 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
58 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
59 | f"encoder.down.{i}.downsample.conv.weight"
60 | )
61 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
62 | f"encoder.down.{i}.downsample.conv.bias"
63 | )
64 |
65 | paths = renew_vae_resnet_paths(resnets)
66 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
67 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
68 |
69 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
70 | num_mid_res_blocks = 2
71 | for i in range(1, num_mid_res_blocks + 1):
72 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
73 |
74 | paths = renew_vae_resnet_paths(resnets)
75 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
76 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
77 |
78 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
79 | paths = renew_vae_attention_paths(mid_attentions)
80 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
81 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
82 | conv_attn_to_linear(new_checkpoint)
83 |
84 | for i in range(num_up_blocks):
85 | block_id = num_up_blocks - 1 - i
86 | resnets = [
87 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
88 | ]
89 |
90 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
91 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
92 | f"decoder.up.{block_id}.upsample.conv.weight"
93 | ]
94 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
95 | f"decoder.up.{block_id}.upsample.conv.bias"
96 | ]
97 |
98 | paths = renew_vae_resnet_paths(resnets)
99 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
100 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
101 |
102 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
103 | num_mid_res_blocks = 2
104 | for i in range(1, num_mid_res_blocks + 1):
105 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
106 |
107 | paths = renew_vae_resnet_paths(resnets)
108 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
109 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
110 |
111 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
112 | paths = renew_vae_attention_paths(mid_attentions)
113 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
114 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
115 | conv_attn_to_linear(new_checkpoint)
116 | return new_checkpoint
117 |
118 |
119 | def vae_pt_to_vae_diffuser(
120 | checkpoint_path: str,
121 | output_path: str,
122 | ):
123 | # Only support V1
124 | r = requests.get(
125 | "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
126 | )
127 | io_obj = io.BytesIO(r.content)
128 |
129 | original_config = yaml.safe_load(io_obj)
130 | image_size = 512
131 | device = "cuda" if torch.cuda.is_available() else "cpu"
132 | if checkpoint_path.endswith("safetensors"):
133 | from safetensors import safe_open
134 |
135 | checkpoint = {}
136 | with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
137 | for key in f.keys():
138 | checkpoint[key] = f.get_tensor(key)
139 | else:
140 | checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]
141 |
142 | # Convert the VAE model.
143 | vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
144 | converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)
145 |
146 | vae = AutoencoderKL(**vae_config)
147 | vae.load_state_dict(converted_vae_checkpoint)
148 | vae.save_pretrained(output_path)
149 |
150 |
151 | if __name__ == "__main__":
152 | parser = argparse.ArgumentParser()
153 |
154 | parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
155 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to output the converted diffusers-format VAE.")
156 |
157 | args = parser.parse_args()
158 |
159 | vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
160 |
--------------------------------------------------------------------------------
/scripts/copy_privacy_compliant_images.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Check for the correct number of arguments
4 | if [ "$#" -ne 3 ]; then
5 | echo "Usage: $0 <source_image_folder> <source_pt_folder> <destination_folder>"
6 | exit 1
7 | fi
8 |
9 | # Assign arguments to variables
10 | source_image_folder="$1"
11 | source_pt_folder="$2"
12 | destination_folder="$3"
13 |
14 | # Create the destination folder if it doesn't exist
15 | mkdir -p "$destination_folder"
16 |
17 | # Enable nullglob to ensure wildcard patterns that match no files expand to nothing
18 | shopt -s nullglob
19 |
20 | # Initialize counters
21 | copied_count=0
22 | total_files=$(ls -1q "$source_image_folder" | wc -l)
23 | matched_files=()
24 | pt_files_count=$(ls -1q "$source_pt_folder"/*.pt | wc -l)
25 | processed_files=0
26 |
27 | # Loop through .pt files in the source_pt_folder
28 | for pt_file in "$source_pt_folder"/*.pt; do
29 | # Extract the basename without the extension
30 | base_name=$(basename "$pt_file" .pt)
31 |
32 | # Construct the wildcard pattern for the source files
33 | pattern="$source_image_folder"/"$base_name".*
34 |
35 | # Attempt to copy the matching file(s) to the destination folder
36 | if cp $pattern "$destination_folder" 2>/dev/null; then
37 | matched_files+=("$pattern")
38 | ((copied_count++))
39 | fi
40 |
41 | # Update processed files count and calculate progress
42 | ((processed_files++))
43 | progress=$((processed_files * 100 / pt_files_count))
44 | echo -ne "Progress: $progress% \r"
45 | done
46 |
47 | # Disable nullglob to revert back to default behavior
48 | shopt -u nullglob
49 |
50 | # Calculate ignored files
51 | ignored_count=$((total_files - ${#matched_files[@]}))
52 |
53 | # Final newline after progress
54 | echo ""
55 |
56 | # Display summary
57 | echo "Operation completed."
58 | echo "Files copied: $copied_count"
59 | echo "Files ignored: $ignored_count"
60 |
--------------------------------------------------------------------------------
/scripts/create_reference_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import decord
3 | import numpy as np
4 | import os
5 | import pandas as pd
6 | import torch
7 | from tqdm import tqdm
8 | from echosyn.common import save_as_img
9 |
10 | decord.bridge.set_bridge('torch')
11 |
12 | """
13 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Dynamic --output data/reference/dynamic --frames 128
14 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Pediatric/A4C --output data/reference/ped_a4c --frames 16
15 | python scripts/create_reference_dataset.py --dataset datasets/EchoNet-Pediatric/PSAX --output data/reference/ped_psax --frames 16
16 | """
17 |
18 | if __name__ == "__main__":
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--dataset', type=str, required=True, help='Path to the dataset directory')
22 | parser.add_argument('--output', type=str, required=True)
23 | parser.add_argument('--frames', type=int, default=128)
24 | parser.add_argument('--fps', type=float, default=32, help='Target FPS of the video before frame extraction. Should match the diffusion model FPS.')
25 |
26 | args = parser.parse_args()
27 |
28 | csv_path = os.path.join(args.dataset, 'FileList.csv')
29 | video_path = os.path.join(args.dataset, 'Videos')
30 |
31 | assert os.path.exists(csv_path), f"Could not find FileList.csv at {csv_path}"
32 | assert os.path.exists(video_path), f"Could not find Videos directory at {video_path}"
33 |
34 | metadata = pd.read_csv(csv_path)
35 | metadata = metadata.sample(frac=1, random_state=42) # shuffle
36 | metadata.reset_index(drop=True, inplace=True)
37 |
38 | extracted_videos = 0
39 |
40 | target_count = {16: 3125, 128: 2048}[args.frames]
41 |
42 | threshold_duration = args.frames / args.fps # 128 -> 4 seconds, 16 -> 0.5 seconds
43 |
44 | for row in tqdm(metadata.iterrows(), total=len(metadata)):
45 | row = row[1]
46 |
47 | video_name = row['FileName'] if row['FileName'].endswith('.avi') else row['FileName'] + '.avi'
48 | video_path = os.path.join(args.dataset, 'Videos', video_name)
49 |
50 | nframes = row['NumberOfFrames']
51 | fps = row['FPS']
52 | duration = nframes / fps
53 |
54 | if duration < threshold_duration:
55 | # skip videos that are shorter than the threshold duration (e.g. 4 seconds for 128 frames at 32 fps)
56 | continue
57 |
58 | new_frame_count = np.floor(args.fps / fps * nframes).astype(int)
59 | resample_indices = np.linspace(0, nframes, new_frame_count, endpoint=False).round().astype(int)
60 |
61 | assert len(resample_indices) >= args.frames
62 | resample_indices = resample_indices[:args.frames]
63 |
64 | reader = decord.VideoReader(video_path, ctx=decord.cpu(), width=112, height=112)
65 | video = reader.get_batch(resample_indices) # T x H x W x C, uint8 tensor
66 | video = video.float().mean(axis=-1).clamp(0, 255).to(torch.uint8) # T x H x W
67 | video = video.unsqueeze(-1).repeat(1, 1, 1, 3) # T x H x W x 3
68 |
69 | folder_name = video_name[:-4] # remove .avi
70 |
71 | if row['Split'] == 'TRAIN':
72 | # path_all = os.path.join(args.output, "train", folder_name)
73 | path_all = os.path.join(args.output, folder_name)
74 | os.makedirs(path_all, exist_ok=True)
75 | save_as_img(video, path_all)
76 | extracted_videos += 1
77 |
78 | if extracted_videos >= target_count:
79 | print("Reached target count, stopping.")
80 | break
81 |
82 | print(f"Saved {extracted_videos} videos to {args.output}.")
83 | if extracted_videos < target_count:
84 | print(f"WARNING: only saved {extracted_videos} videos, which is less than the target count of {target_count}.")
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
/scripts/encode_video_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 | from glob import glob
5 | from einops import rearrange
6 | import numpy as np
7 | from tqdm.auto import tqdm
8 | import cv2, json
9 | import shutil
10 | import pandas as pd
11 |
12 | import torch
13 |
14 | from diffusers import AutoencoderKL
15 | from echosyn.common import loadvideo, load_model
16 |
17 | """
18 | usage example:
19 |
20 | python scripts/encode_video_dataset.py \
21 | -m models/vae \
22 | -i datasets/EchoNet-Dynamic \
23 | -o data/latents/dynamic \
24 | -g
25 | """
26 |
27 | class VideoDataset(torch.utils.data.Dataset):
28 | def __init__(self, folder):
29 | self.folder = folder
30 | self.videos = sorted(glob(os.path.join(folder, "*.avi")))
31 |
32 | def __len__(self):
33 | return len(self.videos)
34 |
35 | def __getitem__(self, idx):
36 | video, fps = loadvideo(self.videos[idx], return_fps=True)
37 | return video, self.videos[idx], fps
38 |
39 | if __name__ == "__main__":
40 | # Parse arguments
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("-m", "--model", type=str, required=True, help="Path to model folder")
43 | parser.add_argument("-i", "--input", type=str, required=True, help="Path to input folder")
44 | parser.add_argument("-o", "--output", type=str, required=True, help="Path to output folder")
45 | parser.add_argument("-g", "--gray_scale", action="store_true", help="Convert to gray scale", default=False)
46 | parser.add_argument("-f", "--force_overwrite", action="store_true", help="Overwrite existing latents", default=False)
47 | args = parser.parse_args()
48 |
49 | # Prepare
50 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51 | os.makedirs(args.output, exist_ok=True)
52 | video_in_folder = os.path.abspath(os.path.join(args.input, "Videos"))
53 | video_out_folder = os.path.abspath(os.path.join(args.output, "Latents"))
54 |
55 | df = pd.read_csv(os.path.join(args.input, "FileList.csv"))
56 | needs_df_update = False
57 | if df['Split'].dtype == int:
58 | print("Updating Split column to string")
59 | needs_df_update = True
60 | df['Fold'] = df['Split']
61 |
62 | def split_set(row):
63 | if row['Fold'] in range(8):
64 | return 'TRAIN'
65 | elif row['Fold'] == 8:
66 | return 'VAL'
67 | else:
68 | return 'TEST'
69 |
70 | df['Split'] = df.apply(split_set, axis=1)
71 | df['FileName'] = df['FileName'].apply(lambda x: x.split('.')[0])
72 |
73 | if not needs_df_update:
74 | df.to_csv(os.path.join(args.output, "FileList.csv"), index=False)
75 |
76 | print("Loading videos from ", video_in_folder)
77 | print("Saving latents to ", video_out_folder)
78 |
79 | # Load VAE
80 | vae = load_model(args.model)
81 | vae = vae.to(device)
82 | vae.eval()
83 |
84 | # Load Dataset
85 | ds = VideoDataset(video_in_folder)
86 | dl = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=False, num_workers=8)
87 | print(f"Found {len(ds)} videos")
88 |
89 | batch_size = 32 # number of frames to encode simultaneously
90 |
91 | # for vpath in tqdm(videos):
92 | for video, vpath, fps in tqdm(dl):
93 |
94 | video = video[0]
95 | vpath = vpath[0]
96 | fps = fps[0]
97 |
98 | # output path
99 | opath = vpath.replace(video_in_folder, "")[1:] # retrieve relative path to input folder, similar to basename but keeps the folders
100 | opath = opath.replace(".avi", ".pt") # change extension
101 | opath = os.path.join(video_out_folder, opath) # add output folder
102 |
103 | # check if already exists
104 | if os.path.exists(opath) and not args.force_overwrite:
105 | print(f"Skipping {vpath} as {opath} already exists")
106 | continue
107 |
108 | # load video
109 | # video = loadvideo(vpath) # B H W C
110 | video = rearrange(video, "t h w c-> t c h w") # B C H W
111 | video = video.to(device)
112 | video = video.float() / 128.0 - 1 # normalize to [-1, 1]
113 | if args.gray_scale:
114 | video = video.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1) # B C H W
115 |
116 | # encode video
117 | all_latents = []
118 | for i in range(0, len(video), batch_size):
119 | batch = video[i:i+batch_size]
120 | with torch.no_grad():
121 | latents = vae.encode(batch).latent_dist.sample()
122 | latents = latents * vae.config.scaling_factor
123 | latents = latents.detach().cpu()
124 | all_latents.append(latents)
125 |
126 | all_latents = torch.cat(all_latents, dim=0)
127 |
128 | # save
129 | os.makedirs(os.path.dirname(opath), exist_ok=True)
130 | torch.save(all_latents, opath)
131 |
132 | if needs_df_update:
133 | fname = os.path.basename(opath).split('.')[0]
134 | df.loc[df['FileName'] == fname, ['FileName', 'FrameHeight','FrameWidth','FPS','NumberOfFrames']] = [fname, 112, 112, fps, len(video)]
135 |
136 | if needs_df_update:
137 | df.to_csv(os.path.join(args.output, "FileList.csv"), index=False)
138 | print("Done")
139 |
--------------------------------------------------------------------------------
/scripts/extract_frames_from_videos.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This script takes a folder of videos (.avi), extracts every 5th frame from each video,
4 | # and saves it as a grayscale JPG image. The output images are saved in a folder specified
5 | # by the user.
6 | # The script uses ffmpeg and xargs to reduce processing time.
7 |
8 |
9 | usage() {
10 | echo "Usage: $0 <input_folder> <output_folder>"
11 | echo "  <input_folder>: Path to the folder containing AVI videos."
12 | echo "  <output_folder>: Path to the output folder where images will be saved."
13 | }
14 |
15 | # Check if exactly two arguments are given
16 | if [ $# -ne 2 ]; then
17 | echo "Error: Incorrect number of arguments."
18 | usage
19 | exit 1
20 | fi
21 |
22 | input_folder="$1"
23 | output_folder="$2"
24 |
25 | export output_folder
26 | # Create the output folder if it doesn't exist
27 | mkdir -p "$output_folder"
28 |
29 | # Function to extract frames and save as grayscale JPGs
30 | extract_frames() {
31 | local video_file=$1
32 | local video_name=$(basename -- "$video_file")
33 | local base_name="${video_name%.*}"
34 |
35 | # Create a specific directory for each video's frames
36 | local video_output_folder="$output_folder"
37 |
38 | # Output path for the frames
39 | local output_path="$video_output_folder/${base_name}_%05d.jpg"
40 |
41 | # Extract every 5th frame and save as a grayscale JPG
42 | ffmpeg -loglevel error -i "$video_file" -vf "select=not(mod(n\,5)),format=gray,format=yuv420p" \
43 | -vsync vfr "$output_path"
44 |
45 | echo -n "."
46 | }
47 |
48 | export -f extract_frames
49 |
50 | echo "Extracting frames from AVI videos in $input_folder and saving in $output_folder"
51 | echo ""
52 | echo "One dot = one video processed:"
53 | echo ""
54 |
55 | # Process each AVI video file
56 | find "$input_folder" -name "*.avi" -print0 | xargs -0 -I {} -P 32 bash -c 'extract_frames "$@"' _ {}
57 |
--------------------------------------------------------------------------------
/scripts/update_split_filelist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import argparse
4 |
5 | if __name__ == "__main__":
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--csv', type=str, required=True, help='Path to the csv file')
9 | parser.add_argument('--train', type=int, required=True, help='Number of training samples')
10 | parser.add_argument('--val', type=int, required=True, help='Number of validation samples')
11 | parser.add_argument('--test', type=int, required=True, help='Number of testing samples')
12 | args = parser.parse_args()
13 |
14 | df = pd.read_csv(args.csv)
15 | assert 'Split' in df.columns, "Split column not found in the csv file"
16 |
17 | if (args.train + args.val + args.test) > len(df):
18 | print(f"Warning: There are not enough samples in the csv file ({len(df)}) to split")
19 | print(f"into {args.train} training, {args.val} validation, and {args.test} testing samples ({args.train + args.val + args.test}).")
20 | print(f"Please adjust the number of samples or provide a csv file with more samples.")
21 | exit(1)
22 | elif (args.train + args.val + args.test) < len(df):
23 | print(f"Warning: There are more samples in the csv file ({len(df)}) than needed to split")
24 | print(f"into {args.train} training, {args.val} validation, and {args.test} testing samples ({args.train + args.val + args.test}).")
25 | print(f"Any extra samples will be put into the \"EXTRA\" split.")
26 |
27 | df.loc[:args.train, 'Split'] = 'TRAIN'
28 | df.loc[args.train:args.train+args.val, 'Split'] = 'VAL'
29 | df.loc[args.train+args.val:args.train+args.val+args.test, 'Split'] = 'TEST'
30 | df.loc[args.train+args.val+args.test:, 'Split'] = 'EXTRA'
31 |
32 | df.to_csv(args.csv, index=False)
33 | print(f"Split the csv file into {args.train} training, {args.val} validation, and {args.test} testing samples.")
34 |
--------------------------------------------------------------------------------
/scripts/vae_reconstruct_image_folder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | import numpy as np
6 | from PIL import Image
7 | from glob import glob
8 | from tqdm import tqdm
9 | from einops import rearrange
10 |
11 | from echosyn.common import load_model, save_as_img
12 |
13 |
14 | class ImageLoader(Dataset):
15 | def __init__(self, all_paths):
16 | self.image_paths = all_paths
17 |
18 | def __len__(self):
19 | return len(self.image_paths)
20 |
21 | def __getitem__(self, idx):
22 | path = self.image_paths[idx]
23 | image = Image.open(path)
24 |
25 | image = np.array(image)
26 | image = image / 128.0 - 1 # [-1, 1]
27 | image = rearrange(image, 'h w c -> c h w')
28 |
29 | return image, self.image_paths[idx]
30 |
31 | if __name__ == "__main__":
32 | parser = argparse.ArgumentParser(description="Encode and decode images using a trained VAE model.")
33 | parser.add_argument("-m", "--model", type=str, help="Path to the trained VAE model.")
34 | parser.add_argument("-i", "--input", type=str, help="Path to the input folder.")
35 | parser.add_argument("-o", "--output", type=str, help="Path to the output folder.")
36 | parser.add_argument("-b", "--batch_size", type=int, default=128, help="Batch size to use for encoding and decoding.")
37 |
38 | args = parser.parse_args()
39 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40 |
41 | # Load the model
42 | model = load_model(args.model)
43 | model.eval()
44 | model.to(device, torch.float32)
45 |
46 | # Load the videos
47 | folder_input = os.path.abspath(args.input)
48 | folder_output = os.path.abspath(args.output)
49 | all_images = glob(os.path.join(folder_input, "**", "*.jpg"), recursive=True)
50 | print(f"Found {len(all_images)} images in {folder_input}")
51 |
52 | # prepare output folder
53 | os.makedirs(folder_output, exist_ok=True)
54 |
55 | # dataset
56 | dataset = ImageLoader(all_images)
57 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=max(args.batch_size, 32))
58 |
59 | for batch in tqdm(dataloader):
60 | images, paths = batch
61 | images = images.to(device, torch.float32)
62 |
63 | # Encode the video
64 | with torch.no_grad():
65 | reconstructed_images = model(images).sample
66 |
67 | reconstructed_images = rearrange(reconstructed_images, 'b c h w -> b h w c')
68 | reconstructed_images = (reconstructed_images + 1) * 128.0
69 | reconstructed_images = reconstructed_images.clamp(0, 255).cpu().to(torch.uint8)
70 |
71 | # Save the reconstructed images
72 | for i, path in enumerate(paths):
73 | new_path = path.replace(folder_input, folder_output)
74 | os.makedirs(os.path.dirname(new_path), exist_ok=True)
75 | save_as_img(reconstructed_images[i], new_path)
76 |
77 | print(f"All reconstructed images saved to {folder_output}")
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='echosyn',
5 | version='1.0',
6 | packages=find_packages(),
7 | install_requires=[
8 | 'torch',
9 | 'torchvision',
10 | 'einops',
11 | 'decord',
12 | 'diffusers',
13 | 'packaging',
14 | 'omegaconf',
15 | 'transformers',
16 | 'accelerate',
17 | 'pandas',
18 | 'wandb',
19 | 'opencv-python',
20 | 'moviepy',
21 | 'imageio',
22 | 'scipy',
23 | 'scikit-learn',
24 | 'matplotlib',
25 | ]
26 | )
--------------------------------------------------------------------------------