├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── denoising_diffusion_pytorch ├── __init__.py ├── diffusion_models │ ├── __init__.py │ ├── components.py │ └── denoising_diffusion_pytorch.py ├── entities │ ├── __init__.py │ └── unet.py └── utils.py ├── images ├── denoising-diffusion.png └── sample.png ├── model.py ├── pytorch-xla-env-setup.py ├── requirements.txt ├── sample.py ├── setup.py ├── setup.sh ├── sweep.yaml ├── test.py ├── util ├── __init__.py ├── configs │ ├── __init__.py │ ├── generate_configs.py │ ├── samples │ │ ├── all-but-no-filters-or-dc.json │ │ ├── all-but-no-filters.json │ │ ├── default-config-copy.json │ │ ├── ica-only.json │ │ ├── raw-high-pass-ica.json │ │ ├── raw-high-pass-no-ica.json │ │ ├── raw-loss-pass-ica.json │ │ ├── raw-low-pass-no-ica.json │ │ ├── raw-only.json │ │ ├── remove-all-noise-ica-no-eog.json │ │ ├── remove-all-noise-no-ica.json │ │ └── requirements.txt │ └── variations.json ├── move.py ├── preprocessing │ ├── EEG_to_Dataset_Pipeline.ipynb │ ├── __init__.py │ ├── pipeline.py │ └── preprocess.py ├── resize.py └── update_configs.py └── visualise.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generation results 2 | results/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | .idea/ 135 | /util/configs/samples/generated/ 136 | /datasets/ 137 | /wandb/ 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Denoising Diffusion Probabilistic Model, in PyTorch 4 | 5 | Implementation of Denoising Diffusion Probabilistic Models in PyTorch. It is a new approach to generative modeling that may have the potential to rival GANs. It uses denoising score matching to estimate the gradient of the data distribution, followed by Langevin sampling to sample from the true distribution. 6 | 7 | This implementation was transcribed from the official Tensorflow version here 8 | 9 | Youtube AI Educators - Yannic Kilcher | AI Coffeebreak with Letitia | Outlier 10 | 11 | Annotated code by Research Scientists / Engineers from 🤗 Huggingface 12 | 13 | Update: Turns out none of the technicalities really matters at all | "Cold Diffusion" paper 14 | 15 | 16 | 17 | [![PyPI version](https://badge.fury.io/py/denoising-diffusion-pytorch.svg)](https://badge.fury.io/py/denoising-diffusion-pytorch) 18 | 19 | ## Install 20 | 21 | ```bash 22 | $ pip install denoising_diffusion_pytorch 23 | ``` 24 | 25 | ## Usage 26 | 27 | ```python 28 | import torch 29 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion 30 | 31 | model = Unet( 32 | dim = 64, 33 | dim_mults = (1, 2, 4, 8) 34 | ) 35 | 36 | diffusion = GaussianDiffusion( 37 | model, 38 | image_size = 128, 39 | timesteps = 1000, # number of steps 40 | loss_type = 'l1' # L1 or L2 41 | ) 42 | 43 | training_images = torch.randn(8, 3, 128, 128) # images are normalized from 0 to 1 44 | loss = diffusion(training_images) 45 | loss.backward() 46 | # after a lot of training 47 | 48 | sampled_images = diffusion.sample(batch_size = 4) 49 | sampled_images.shape # (4, 3, 128, 128) 50 | ``` 51 | 52 | Or, if you simply want to pass in a folder name and the desired image dimensions, you can use the `Trainer` class to easily train a model. 53 | 54 | ```python 55 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion 56 | from denoising_diffusion_pytorch.utils import Trainer 57 | 58 | model = Unet( 59 | dim=64, 60 | dim_mults=(1, 2, 4, 8) 61 | ).cuda() 62 | 63 | diffusion = GaussianDiffusion( 64 | model, 65 | image_size=128, 66 | timesteps=1000, # number of steps 67 | sampling_timesteps=250, 68 | # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) 69 | loss_type='l1' # L1 or L2 70 | ).cuda() 71 | 72 | trainer = Trainer( 73 | diffusion, 74 | 'path/to/your/images', 75 | train_batch_size=32, 76 | train_lr=8e-5, 77 | train_num_steps=700000, # total training steps 78 | gradient_accumulate_every=2, # gradient accumulation steps 79 | ema_decay=0.995, # exponential moving average decay 80 | amp=True # turn on mixed precision 81 | ) 82 | 83 | trainer.train() 84 | ``` 85 | 86 | Samples and model checkpoints will be logged to `./results` periodically 87 | 88 | ## Multi-GPU Training 89 | 90 | The `Trainer` class is now equipped with 🤗 Accelerator. 
You can easily do multi-GPU training in two steps using their `accelerate` CLI. 91 | 92 | At the project root directory, where the training script is, run: 93 | 94 | ```bash 95 | $ accelerate config 96 | ``` 97 | 98 | Then, in the same directory, run: 99 | 100 | ```bash 101 | $ accelerate launch model.py 102 | ``` 103 | 104 | ## Citations 105 | 106 | ```bibtex 107 | @inproceedings{NEURIPS2020_4c5bcfec, 108 | author = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter}, 109 | booktitle = {Advances in Neural Information Processing Systems}, 110 | editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin}, 111 | pages = {6840--6851}, 112 | publisher = {Curran Associates, Inc.}, 113 | title = {Denoising Diffusion Probabilistic Models}, 114 | url = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf}, 115 | volume = {33}, 116 | year = {2020} 117 | } 118 | ``` 119 | 120 | ```bibtex 121 | @InProceedings{pmlr-v139-nichol21a, 122 | title = {Improved Denoising Diffusion Probabilistic Models}, 123 | author = {Nichol, Alexander Quinn and Dhariwal, Prafulla}, 124 | booktitle = {Proceedings of the 38th International Conference on Machine Learning}, 125 | pages = {8162--8171}, 126 | year = {2021}, 127 | editor = {Meila, Marina and Zhang, Tong}, 128 | volume = {139}, 129 | series = {Proceedings of Machine Learning Research}, 130 | month = {18--24 Jul}, 131 | publisher = {PMLR}, 132 | pdf = {http://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf}, 133 | url = {https://proceedings.mlr.press/v139/nichol21a.html}, 134 | } 135 | ``` 136 | 137 | ```bibtex 138 | @inproceedings{kingma2021on, 139 | title = {On Density Estimation with Diffusion Models}, 140 | author = {Diederik P Kingma and Tim Salimans and Ben Poole and Jonathan Ho}, 141 | booktitle = {Advances in Neural Information Processing Systems}, 142 | editor = {A. Beygelzimer and Y. Dauphin and P. Liang and J. Wortman Vaughan}, 143 | year = {2021}, 144 | url = {https://openreview.net/forum?id=2LdBqxc1Yv} 145 | } 146 | ``` 147 | 148 | ```bibtex 149 | @article{Choi2022PerceptionPT, 150 | title = {Perception Prioritized Training of Diffusion Models}, 151 | author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J.
Kim and Sung-Hoon Yoon}, 152 | journal = {ArXiv}, 153 | year = {2022}, 154 | volume = {abs/2204.00227} 155 | } 156 | ``` 157 | 158 | ```bibtex 159 | @article{Karras2022ElucidatingTD, 160 | title = {Elucidating the Design Space of Diffusion-Based Generative Models}, 161 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine}, 162 | journal = {ArXiv}, 163 | year = {2022}, 164 | volume = {abs/2206.00364} 165 | } 166 | ``` 167 | 168 | ```bibtex 169 | @article{Song2021DenoisingDI, 170 | title = {Denoising Diffusion Implicit Models}, 171 | author = {Jiaming Song and Chenlin Meng and Stefano Ermon}, 172 | journal = {ArXiv}, 173 | year = {2021}, 174 | volume = {abs/2010.02502} 175 | } 176 | ``` 177 | 178 | ```bibtex 179 | @misc{chen2022analog, 180 | title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning}, 181 | author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton}, 182 | year = {2022}, 183 | eprint = {2208.04202}, 184 | archivePrefix = {arXiv}, 185 | primaryClass = {cs.CV} 186 | } 187 | ``` 188 | 189 | ```bibtex 190 | @article{Qiao2019WeightS, 191 | title = {Weight Standardization}, 192 | author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille}, 193 | journal = {ArXiv}, 194 | year = {2019}, 195 | volume = {abs/1903.10520} 196 | } 197 | ``` 198 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from denoising_diffusion_pytorch.diffusion_models.denoising_diffusion_pytorch import GaussianDiffusion 2 | from denoising_diffusion_pytorch.utils import Trainer 3 | from denoising_diffusion_pytorch.entities.unet import Unet 4 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/diffusion_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/denoising_diffusion_pytorch/diffusion_models/__init__.py -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/diffusion_models/components.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | from einops import reduce, rearrange 6 | from torch import nn, einsum 7 | from torch.nn import functional as F 8 | 9 | from denoising_diffusion_pytorch.utils import exists, compute_L2_norm 10 | 11 | 12 | class Residual(nn.Module): 13 | def __init__(self, fn): 14 | super().__init__() 15 | self.fn = fn 16 | 17 | def forward(self, x, *args, **kwargs): 18 | return self.fn(x, *args, **kwargs) + x 19 | 20 | 21 | class WeightStandardizedConv2d(nn.Conv2d): 22 | """ 23 | https://arxiv.org/abs/1903.10520 24 | weight standardization purportedly works synergistically with group normalization 25 | """ 26 | 27 | def forward(self, x): 28 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 29 | 30 | weight = self.weight 31 | mean = reduce(weight, 'o ... -> o 1 1 1', 'mean') 32 | var = reduce(weight, 'o ... 
-> o 1 1 1', partial(torch.var, unbiased=False)) 33 | normalized_weight = (weight - mean) * (var + eps).rsqrt() 34 | 35 | return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 36 | 37 | 38 | class LayerNorm(nn.Module): 39 | def __init__(self, dim): 40 | super().__init__() 41 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 42 | 43 | def forward(self, x): 44 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 45 | var = torch.var(x, dim=1, unbiased=False, keepdim=True) 46 | mean = torch.mean(x, dim=1, keepdim=True) 47 | return (x - mean) * (var + eps).rsqrt() * self.g 48 | 49 | 50 | class PreNorm(nn.Module): 51 | def __init__(self, dim, fn): 52 | super().__init__() 53 | self.fn = fn 54 | self.norm = LayerNorm(dim) 55 | 56 | def forward(self, x): 57 | x = self.norm(x) 58 | return self.fn(x) 59 | 60 | 61 | class SinusoidalPositionalEmbedding(nn.Module): 62 | def __init__(self, dim): 63 | super().__init__() 64 | self.dim = dim 65 | 66 | def forward(self, x): 67 | device = x.device 68 | half_dim = self.dim // 2 69 | emb = math.log(10000) / (half_dim - 1) 70 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 71 | emb = x[:, None] * emb[None, :] 72 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 73 | return emb 74 | 75 | 76 | class LearnedSinusoidalPositionalEmbedding(nn.Module): 77 | """ following @crowsonkb 's lead with learned sinusoidal pos emb """ 78 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 79 | 80 | def __init__(self, dim): 81 | super().__init__() 82 | assert (dim % 2) == 0 83 | half_dim = dim // 2 84 | self.weights = nn.Parameter(torch.randn(half_dim)) 85 | 86 | def forward(self, x): 87 | x = rearrange(x, 'b -> b 1') 88 | frequencies = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi 89 | fouriered = torch.cat((frequencies.sin(), frequencies.cos()), dim=-1) 90 | fouriered = torch.cat((x, fouriered), dim=-1) 91 | return fouriered 92 | 93 | 94 | class BaseBlock(nn.Module): 95 | def __init__(self, dim, dim_out, groups=8): 96 | super().__init__() 97 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1) 98 | self.norm = nn.GroupNorm(groups, dim_out) 99 | self.act = nn.SiLU() 100 | 101 | def forward(self, x, scale_shift=None): 102 | x = self.proj(x) 103 | x = self.norm(x) 104 | 105 | if exists(scale_shift): 106 | scale, shift = scale_shift 107 | x = x * (scale + 1) + shift 108 | 109 | x = self.act(x) 110 | return x 111 | 112 | 113 | class ResnetBlock(nn.Module): 114 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): 115 | super().__init__() 116 | self.mlp = nn.Sequential( 117 | nn.SiLU(), 118 | nn.Linear(time_emb_dim, dim_out * 2) 119 | ) if exists(time_emb_dim) else None 120 | 121 | self.block1 = BaseBlock(dim, dim_out, groups=groups) 122 | self.block2 = BaseBlock(dim_out, dim_out, groups=groups) 123 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 124 | 125 | def forward(self, x, time_emb=None): 126 | scale_shift = None 127 | if exists(self.mlp) and exists(time_emb): 128 | time_emb = self.mlp(time_emb) 129 | time_emb = rearrange(time_emb, 'b c -> b c 1 1') 130 | scale_shift = time_emb.chunk(2, dim=1) 131 | 132 | h = self.block1(x, scale_shift=scale_shift) 133 | 134 | h = self.block2(h) 135 | 136 | return h + self.res_conv(x) 137 | 138 | 139 | class LinearAttention(nn.Module): 140 | def __init__(self, dim, heads=4, dim_head=32): 141 | super().__init__() 142 | self.scale = dim_head ** -0.5 143 | 
self.heads = heads 144 | hidden_dim = dim_head * heads 145 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 146 | 147 | self.to_out = nn.Sequential( 148 | nn.Conv2d(hidden_dim, dim, 1), 149 | LayerNorm(dim) 150 | ) 151 | 152 | def forward(self, x): 153 | b, c, h, w = x.shape 154 | qkv = self.to_qkv(x).chunk(3, dim=1) 155 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv) 156 | 157 | q = q.softmax(dim=-2) 158 | k = k.softmax(dim=-1) 159 | 160 | q = q * self.scale 161 | v = v / (h * w) 162 | 163 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v) 164 | 165 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q) 166 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w) 167 | return self.to_out(out) 168 | 169 | 170 | class Attention(nn.Module): 171 | def __init__(self, dim, heads=4, dim_head=32, scale=10): 172 | super().__init__() 173 | self.scale = scale 174 | self.heads = heads 175 | hidden_dim = dim_head * heads 176 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 177 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 178 | 179 | def forward(self, x): 180 | b, c, h, w = x.shape 181 | qkv = self.to_qkv(x).chunk(3, dim=1) 182 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv) 183 | 184 | q, k = map(compute_L2_norm, (q, k)) 185 | 186 | sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale 187 | attn = sim.softmax(dim=-1) 188 | out = einsum('b h i j, b h d j -> b h i d', attn, v) 189 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) 190 | # I hate einsums... 191 | return self.to_out(out) 192 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/diffusion_models/denoising_diffusion_pytorch.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from random import random 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import wandb 7 | from einops import reduce 8 | from torch import nn 9 | from tqdm.auto import tqdm 10 | 11 | from denoising_diffusion_pytorch import utils 12 | from denoising_diffusion_pytorch.utils import normalise_to_negative_one_to_one, \ 13 | unnormalise_to_zero_to_one, extract, linear_beta_schedule, cosine_beta_schedule, default 14 | 15 | # Constants 16 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) 17 | 18 | 19 | # TODO add documentation/descriptions to every method and class 20 | # TODO generate model diagram and per-layer parameter count 21 | 22 | 23 | class GaussianDiffusion(nn.Module): 24 | def __init__( 25 | self, 26 | learning_model, 27 | *, 28 | image_size, 29 | timesteps=1000, 30 | sampling_timesteps=None, 31 | loss_type='l1', 32 | training_objective='pred_noise', 33 | beta_schedule='cosine', 34 | p2_loss_weight_gamma=0., 35 | # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. 36 | # is recommended 37 | p2_loss_weight_k=1, 38 | ddim_sampling_eta=1. 39 | ): 40 | """ 41 | This class provides all important logic and behaviour for the base Gaussian Diffusion model. 42 | 43 | :param learning_model: The model used for learning the forwards diffusion process from x_T to x_0. 44 | This is typically a U-Net model, inline with the literature. 45 | :param image_size: The single dimension for the output image. 46 | For example, a value of 32 will produce a 32x32 pixel image output. 
47 | :param timesteps: The number of timesteps to be used for the forward and reverse processes of the model. 48 | :param sampling_timesteps: The number of timesteps to be used for sampling. 49 | If this is less than param timesteps, then we are using Improved DDPM. 50 | :param loss_type: The type of loss we will use. This can be either L1 or L2 loss. 51 | :param training_objective: The objective that dictates what the model attempts to learn. 52 | This must be either pred_noise to learn noise, or pred_x0 to learn the truth image. 53 | """ 54 | super().__init__() 55 | loss_type = loss_type.lower() 56 | training_objective = training_objective.lower() 57 | beta_schedule = beta_schedule.lower() 58 | 59 | assert loss_type in ['l1', 'l2'], f'The specified loss type, {loss_type}, must be either L1 or L2.' 60 | assert training_objective in ['pred_noise', 'pred_x0'], \ 61 | 'The given objective must be either pred_noise (predict noise) or pred_x0 (predict image start)' 62 | assert beta_schedule in ['linear', 'cosine'], f'The given beta schedule {beta_schedule} is invalid!' 63 | assert not (type(self) != GaussianDiffusion and learning_model.channels == learning_model.out_dim) 64 | # TODO add an assertion error message 65 | assert sampling_timesteps is None or 0 < sampling_timesteps <= timesteps, \ 66 | 'The given sampling timesteps value is invalid!' 67 | 68 | self.learning_model = learning_model 69 | self.channels = self.learning_model.channels 70 | self.self_condition = self.learning_model.self_condition 71 | self.image_size = image_size 72 | self.objective = training_objective 73 | 74 | betas = linear_beta_schedule(timesteps) if beta_schedule == 'linear' else cosine_beta_schedule(timesteps) 75 | alphas = 1. - betas 76 | alphas_cumprod = torch.cumprod(alphas, axis=0) 77 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) 78 | 79 | timesteps, = betas.shape 80 | self.num_timesteps = int(timesteps) 81 | self.loss_type = loss_type 82 | 83 | # Sampling-related parameters 84 | 85 | self.sampling_timesteps = default(sampling_timesteps, timesteps) 86 | # The default number of sampling timesteps. Reduced for Improved DDPM 87 | 88 | assert self.sampling_timesteps <= timesteps 89 | self.is_ddim_sampling = self.sampling_timesteps < timesteps 90 | self.ddim_sampling_eta = ddim_sampling_eta 91 | 92 | # Helper function to convert function values from float64 to float32 93 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) 94 | 95 | register_buffer('betas', betas) 96 | register_buffer('alphas_cumprod', alphas_cumprod) 97 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 98 | 99 | # calculations for diffusion q(x_t | x_{t-1}) and others 100 | 101 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 102 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 103 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 104 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 105 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 106 | 107 | # calculations for posterior q(x_{t-1} | x_t, x_0) 108 | 109 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 110 | # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) 111 | 112 | register_buffer('posterior_variance', posterior_variance) 113 | 114 | # Log when the variance is clipped, as the posterior variance is zero 115 | # at the beginning of the diffusion chain (x_0 to x_T) 116 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20))) 117 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 118 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) 119 | 120 | # Calculate reweighting values for p2 loss 121 | register_buffer('p2_loss_weight', 122 | (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) 123 | 124 | def predict_x0_from_noise(self, x_t, t, noise): 125 | return ( 126 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 127 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 128 | ) 129 | 130 | def predict_noise_from_xt(self, x_t, t, x0): 131 | """ 132 | This method attempts to predict the gaussian noise for the forwards process, from x_T to x_0. 133 | :param x_t: The isotropic gaussian noise sample at the beginning of the forwards process. 134 | :param t: The number of sampling timesteps. 135 | """ 136 | return ((extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / 137 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)) 138 | 139 | def q_posterior(self, x_start, x_t, t): 140 | posterior_mean = ( 141 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 142 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 143 | ) 144 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 145 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 146 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 147 | 148 | def model_predictions(self, x, t, x_self_cond=None): 149 | model_output = self.learning_model(x, t, x_self_cond) 150 | predicted_noise, x_0 = None, None 151 | 152 | if self.objective == 'pred_noise': 153 | predicted_noise = model_output 154 | x_0 = self.predict_x0_from_noise(x, t, model_output) 155 | 156 | elif self.objective == 'pred_x0': 157 | predicted_noise = self.predict_noise_from_xt(x, t, model_output) 158 | x_0 = model_output # The output of the model, x0 159 | 160 | return ModelPrediction(predicted_noise, x_0) 161 | 162 | def p_mean_variance(self, x, t, x_self_cond=None, clip_denoised=True): 163 | """ 164 | This method computes the mean and variance by sampling directly from the model. 165 | """ 166 | preds = self.model_predictions(x, t, x_self_cond) 167 | x_start = preds.pred_x_start 168 | 169 | if clip_denoised: 170 | x_start.clamp_(-1., 1.) 171 | 172 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t) 173 | return model_mean, posterior_variance, posterior_log_variance, x_start 174 | 175 | @torch.no_grad() 176 | def compute_sample_for_timestep(self, x, t: int, x_self_cond=None, clip_denoised=True): 177 | """ 178 | This method takes a single step in the sampling/forwards process. For example, in a forwards process with 250 179 | sampling steps, this method will be called 250 times. 180 | :param x: The current image for the given timestep t. If t=T, then we are at the beginning of the forwards 181 | process, and x will be an isotropic Gaussian noise sample. 
182 | :param t: The current sampling timestep that defines x_T, where 0 <= t <= T. 183 | """ 184 | b, *_, device = *x.shape, x.device 185 | # batched_times = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long) 186 | batched_times = torch.full((x.shape[0],), t, dtype=torch.long) 187 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x=x, t=batched_times, x_self_cond=x_self_cond, 188 | clip_denoised=clip_denoised) 189 | noise = torch.randn_like(x) if t > 0 else 0. 190 | # Reset noise if t == zero, i.e., if we now have the output image of the model. 191 | # This nulls-out the following operation, preserving the output image. 192 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise 193 | return pred_img, x_start 194 | 195 | @torch.no_grad() 196 | def compute_complete_sample(self, shape, device): 197 | """ 198 | This method simply runs a for loop used for computing the series of samples from x_T through to x_0. 199 | The returned value is the final output image of the model. 200 | 201 | A progress bar is provided by the tqdm library throughout. 202 | 203 | :param shape: The shape of the output, not the image output. 204 | Specifically, this is set to the batch size x n_channels , determining how 205 | many samples should be generated in one iteration. 206 | :param device: The device to use for computing the samples on. 207 | """ 208 | img = torch.randn(shape, device=device) 209 | # img = torch.randn(shape) 210 | 211 | x_start = None 212 | 213 | for t in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling loop time step', total=self.num_timesteps): 214 | self_cond = x_start if self.self_condition else None 215 | img, x_start = self.compute_sample_for_timestep(img, t, self_cond) 216 | 217 | img = unnormalise_to_zero_to_one(img) 218 | return img 219 | 220 | @torch.no_grad() 221 | def ddim_sample(self, shape, clip_denoised=True): 222 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[ 223 | 0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective 224 | 225 | times = torch.linspace(0., total_timesteps, steps=sampling_timesteps + 2)[:-1] 226 | times = list(reversed(times.int().tolist())) 227 | time_pairs = list(zip(times[:-1], times[1:])) 228 | 229 | img = torch.randn(shape, device=device) 230 | 231 | # Begin image, xT, sampled as random noise 232 | # TODO need a way to specify a noise sample from the EEG forward process, 233 | # not to randomly generate it 234 | 235 | x_start = None 236 | 237 | for time, time_next in tqdm(time_pairs, desc='sampling loop time step'): 238 | alpha = self.alphas_cumprod[time] 239 | alpha_next = self.alphas_cumprod[time_next] 240 | 241 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long) 242 | 243 | self_cond = x_start if self.self_condition else None 244 | 245 | pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond) 246 | 247 | if clip_denoised: 248 | x_start.clamp_(-1., 1.) 249 | 250 | sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() 251 | c = ((1 - alpha_next) - sigma ** 2).sqrt() 252 | 253 | noise = torch.randn_like(img) if time_next > 0 else 0. 
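            # DDIM update (Song et al., 2021): combine the current x_0 estimate, the predicted
            # noise direction, and fresh noise scaled by sigma; with eta = 0 the step is deterministic.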
254 | 255 | img = x_start * alpha_next.sqrt() + \ 256 | c * pred_noise + \ 257 | sigma * noise 258 | 259 | img = utils.unnormalise_to_zero_to_one(img) 260 | return img 261 | 262 | @torch.no_grad() 263 | def sample(self, batch_size): 264 | """ 265 | This method computes a given number of samples from the model in one sampling step. 266 | This method does not execute the many inferences required to draw a sample 267 | but instead determines which sampling strategy to use (standard or reduced sample step count), 268 | image output dimensions and batch sizes. 269 | """ 270 | batch_size = 16 if batch_size is None else batch_size # default value 271 | image_size, channels = self.image_size, self.channels 272 | sampling_function = self.ddim_sample if self.is_ddim_sampling else self.compute_complete_sample 273 | return sampling_function((batch_size, channels, image_size, image_size)) 274 | 275 | @torch.no_grad() 276 | def interpolate(self, x1, x2, t=None, lam=0.5): 277 | b, *_, device = *x1.shape, x1.device 278 | t = default(t, self.num_timesteps - 1) 279 | 280 | assert x1.shape == x2.shape 281 | 282 | t_batched = torch.stack([torch.tensor(t, device=device)] * b) 283 | # t_batched = torch.stack([torch.tensor(t)] * b) 284 | xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) 285 | 286 | img = (1 - lam) * xt1 + lam * xt2 287 | for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t): 288 | img = self.compute_sample_for_timestep(img, torch.full((b,), i, device=device, dtype=torch.long)) 289 | # img = self.compute_sample_for_timestep(img, torch.full((b,), i, dtype=torch.long)) 290 | 291 | return img 292 | 293 | def q_sample(self, x_start, t, noise=None): 294 | noise = default(noise, lambda: torch.randn_like(x_start)) 295 | # Uses the supplied noise value if it exists, 296 | # or generates more Gaussian noise with the same dimensions as x_start 297 | 298 | return ( 299 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 300 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 301 | ) 302 | 303 | @property 304 | def loss_fn(self): 305 | """ 306 | This function does not compute the loss for two given images (truth and predicted), 307 | but instead returns the appropriate loss function in accordance with the stated 308 | desired loss function. 309 | 310 | Given that this method returns a function, then any parameters supplied are 311 | naturally passed into that function, thus deriving the loss value. 312 | """ 313 | if self.loss_type == 'l1': 314 | return F.l1_loss 315 | elif self.loss_type == 'l2': 316 | return F.mse_loss 317 | else: 318 | raise ValueError(f'invalid loss type {self.loss_type}') 319 | 320 | def compute_loss_for_timestep(self, eeg_sample, target_sample, timestep, noise=None): 321 | """ 322 | This method computes the losses for a full pass of a given image through the model. 323 | This effectively performs the full noising then denoising/diffusion process. 324 | After this, the losses between the true image and the generated image (via the diffusion process) 325 | are compared, their losses computed, and returned. 326 | 327 | :param x_start: The given image, x_0, to use for other processes. 328 | :param generated_noise: A sample of noise to be applied to the given x_start value. 329 | :param timestep: The timestep to compute losses for. This can be updated linearly, or sampled randomly. 
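        :param noise: An optional pre-generated noise tensor applied to eeg_sample; if omitted, Gaussian noise with the same shape as eeg_sample is generated.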
330 | """ 331 | # b, c, h, w = x_start.shape # Commented out as these values are not used 332 | generated_noise = default(noise, lambda: torch.randn_like(eeg_sample)) 333 | # Generates random normal/Gaussian noise with the same dimensions as the given input 334 | 335 | x = self.q_sample(x_start=eeg_sample, t=timestep, noise=generated_noise) 336 | # Warps the image by the noise we just generated 337 | # in accordance to our beta scheduling choice and current timestep t 338 | 339 | # If you are performing self-conditioning, then 50% of the training iterations 340 | # will predict x_start from the current timestep, t. This will then be used to 341 | # update U-Net's gradients with. This technique increases training time by 25%, 342 | # but appears to significantly lower the FID score of the model 343 | x_self_cond = None # TODO look into using this 344 | if self.self_condition and random() < 0.5: 345 | with torch.no_grad(): 346 | x_self_cond = self.model_predictions(x, timestep).pred_x_start 347 | x_self_cond.detach_() 348 | 349 | # Next we predict the output according to our objective, 350 | # then compute the gradient from that result 351 | model_out = self.learning_model(x, timestep, x_self_cond) # The prediction of our model 352 | 353 | if self.objective == 'pred_noise': 354 | # If we are trying to predict the noise that was just added 355 | 356 | # target = noise 357 | target = default(noise, torch.randn_like(target_sample)) 358 | elif self.objective == 'pred_x0': 359 | # If we are trying to predict the original, true image, x_0 360 | # target = x_start # 361 | target = target_sample 362 | else: 363 | raise ValueError(f'unknown objective {self.objective}') 364 | 365 | loss = self.loss_fn(model_out, target, reduction='none') 366 | 367 | loss = reduce(loss, 'b ... -> b (...)', 'mean') 368 | 369 | loss = loss * extract(self.p2_loss_weight, timestep, loss.shape) 370 | wandb.log({"raw_losses": loss, "averaged_loss": loss.mean().item()}) 371 | # TODO maybe log Inception Score and/or FID. 372 | return loss.mean() 373 | 374 | def forward(self, img, *args, **kwargs): 375 | """ 376 | When calling model(x), this is the method that is called. 377 | This method takes an input image - x_0 - from the dataset and trains the diffusion & U-Net model on it. 378 | 379 | :param img: The image to be used as x_0, which is the starting and ending 380 | image for the noising and diffusion process, respectively. 
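        Note that in this implementation, img is expected to be a pair (eeg_sample, target_sample); both tensors are normalised to [-1, 1] before the per-timestep loss is computed.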
381 | """ 382 | 383 | # b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size 384 | eeg_sample, target_sample = img 385 | 386 | b, c, h, w, device, img_size = *eeg_sample.shape, eeg_sample.device, self.image_size 387 | 388 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}' 389 | timestep = torch.randint(0, self.num_timesteps, (b,), device=device).long() 390 | # timestep = torch.randint(0, self.num_timesteps, (b,)).long() 391 | 392 | eeg_sample = normalise_to_negative_one_to_one(eeg_sample) 393 | target_sample = normalise_to_negative_one_to_one(target_sample) 394 | return self.compute_loss_for_timestep(eeg_sample, target_sample, timestep, *args, **kwargs) 395 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/entities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/denoising_diffusion_pytorch/entities/__init__.py -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/entities/unet.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from denoising_diffusion_pytorch.diffusion_models.components import Residual, PreNorm, SinusoidalPositionalEmbedding, \ 7 | LearnedSinusoidalPositionalEmbedding, ResnetBlock, LinearAttention, Attention 8 | from denoising_diffusion_pytorch.utils import upsample, downsample, default 9 | 10 | 11 | class Unet(nn.Module): 12 | def __init__( 13 | self, 14 | dim, 15 | init_dim=None, 16 | out_dim=None, 17 | dim_mults=(1, 2, 4, 8), 18 | channels=3, 19 | self_condition=False, 20 | resnet_block_groups=8, 21 | learned_variance=False, 22 | # TODO learned_variance seems to determine if we follow Improved DDPM and learn the variance alongside 23 | # the mean? Might be worth setting to True if so. 24 | learned_sinusoidal_cond=False, 25 | learned_sinusoidal_dim=16 # Per Improved DDPM 26 | ): 27 | """ 28 | :param channels: A multiplier for each channel's layer count. 
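        :param dim: The base channel width of the U-Net; the width at each resolution level is dim multiplied by the corresponding entry of dim_mults.
        :param dim_mults: The per-level channel multipliers applied to dim on the way down, mirrored on the way up.
        :param self_condition: If True, the network is also fed its previous x_0 estimate, doubling the number of input channels.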
29 | """ 30 | super().__init__() 31 | 32 | # determine dimensions 33 | 34 | self.channels = channels 35 | self.self_condition = self_condition 36 | input_channels = channels * (2 if self_condition else 1) 37 | 38 | init_dim = default(init_dim, dim) 39 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding=3) 40 | 41 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 42 | in_out = list(zip(dims[:-1], dims[1:])) 43 | 44 | block_klass = partial(ResnetBlock, groups=resnet_block_groups) 45 | 46 | # time embeddings 47 | 48 | time_dim = dim * 4 49 | 50 | self.learned_sinusoidal_cond = learned_sinusoidal_cond 51 | 52 | if learned_sinusoidal_cond: # TODO determine if we want to use this 53 | sinu_pos_emb = LearnedSinusoidalPositionalEmbedding(learned_sinusoidal_dim) 54 | fourier_dim = learned_sinusoidal_dim + 1 55 | else: 56 | sinu_pos_emb = SinusoidalPositionalEmbedding(dim) 57 | fourier_dim = dim 58 | 59 | self.time_mlp = nn.Sequential( 60 | sinu_pos_emb, 61 | nn.Linear(fourier_dim, time_dim), 62 | nn.GELU(), 63 | nn.Linear(time_dim, time_dim) 64 | ) 65 | 66 | # layers 67 | 68 | self.downs = nn.ModuleList([]) 69 | self.ups = nn.ModuleList([]) 70 | num_resolutions = len(in_out) 71 | 72 | for ind, (dim_in, dim_out) in enumerate(in_out): 73 | is_last = ind >= (num_resolutions - 1) 74 | 75 | self.downs.append(nn.ModuleList([ 76 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 77 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 78 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 79 | downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1) 80 | ])) 81 | 82 | mid_dim = dims[-1] 83 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 84 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 85 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 86 | 87 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 88 | is_last = ind == (len(in_out) - 1) 89 | 90 | self.ups.append(nn.ModuleList([ 91 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 92 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 93 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 94 | upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1) 95 | ])) 96 | 97 | default_out_dim = channels * (1 if not learned_variance else 2) 98 | self.out_dim = default(out_dim, default_out_dim) 99 | 100 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) 101 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1) 102 | 103 | def forward(self, x, time, x_self_cond=None): 104 | if self.self_condition: 105 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) 106 | x = torch.cat((x_self_cond, x), dim=1) 107 | 108 | x = self.init_conv(x) 109 | r = x.clone() 110 | 111 | t = self.time_mlp(time) 112 | 113 | h = [] 114 | 115 | for block1, block2, attn, downsample in self.downs: 116 | x = block1(x, t) 117 | h.append(x) 118 | 119 | x = block2(x, t) 120 | x = attn(x) 121 | h.append(x) 122 | 123 | x = downsample(x) 124 | 125 | x = self.mid_block1(x, t) 126 | x = self.mid_attn(x) 127 | x = self.mid_block2(x, t) 128 | 129 | for block1, block2, attn, upsample in self.ups: 130 | x = torch.cat((x, h.pop()), dim=1) 131 | x = block1(x, t) 132 | 133 | x = torch.cat((x, h.pop()), dim=1) 134 | x = block2(x, t) 135 | x = attn(x) 136 | 137 | x = upsample(x) 138 | 139 | x = torch.cat((x, r), dim=1) 140 | 141 | x = self.final_res_block(x, t) 142 | return self.final_conv(x) 
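# --- Usage sketch (illustrative only, not part of the training pipeline) ---
# A minimal example of how this U-Net is called: it maps a batch of noisy images and
# their integer timesteps to a tensor with `out_dim` channels and the same spatial size.
# The dimensions below are arbitrary example values.
#
#   model = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=3)
#   x = torch.randn(4, 3, 128, 128)      # a batch of noisy images x_t
#   t = torch.randint(0, 1000, (4,))     # one timestep index per image
#   out = model(x, t)                    # -> shape (4, 3, 128, 128)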
143 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing 3 | import os 4 | import random 5 | import shutil 6 | from collections import defaultdict 7 | from functools import partial 8 | from pathlib import Path 9 | 10 | import torch 11 | import wandb 12 | from PIL import Image 13 | from accelerate import Accelerator 14 | from ema_pytorch import EMA 15 | from torch import nn 16 | from torch.nn import functional as F 17 | from torch.optim import Adam 18 | from torch.utils.data import Dataset, DataLoader 19 | from torchvision import transforms as T, utils 20 | from tqdm.auto import tqdm 21 | 22 | 23 | def exists(x): 24 | """ 25 | This method checks if the given parameter is not equal to None, i.e., if it has a value. 26 | 27 | :param x: The value to be checked for None-status. 28 | """ 29 | return x is not None 30 | 31 | 32 | class GenericDataset(Dataset): 33 | def __init__( 34 | self, 35 | folder, 36 | image_size: int, 37 | exts: list = None, 38 | augment_horizontal_flip=False, 39 | convert_image_to=None 40 | ): 41 | """ 42 | This class loads images from a given directory and resizes them to be square. 43 | 44 | :param folder: The folder that contains the files for this dataset. 45 | :param int image_size: The dimensions for the given image. All images will be converted to a square with these dimensions. 46 | :param list exts: A list of file extensions that this class should load, such as jpg and png. 47 | :param augment_horizontal_flip: If a horizontal (left-to-right) flip of the image should be performed. 48 | :param convert_image_to: A given lambda function specifying how to convert the input images 49 | for this dataset. This is applied before any other manipulations, such as resizing or 50 | horizontal flipping. 51 | """ 52 | super().__init__() 53 | if exts is None: 54 | exts = ['jpg', 'jpeg', 'png', 'tiff'] 55 | self.folder = folder 56 | self.image_size = image_size 57 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] 58 | 59 | lambda_convert_function = partial(convert_image_to, convert_image_to) \ 60 | if exists(convert_image_to) \ 61 | else nn.Identity() # TODO determine what partial(...) does 62 | # nn.Identity simply returns the input. 63 | # So, if convert_image_to_fn is None, 64 | # lambda_convert_function will just return whatever is input to it 65 | 66 | self.transform = T.Compose([ 67 | T.Lambda(lambda_convert_function), 68 | # Execute some lambda of code to convert an image of the dataset by, such as greyscaling. 69 | T.Resize(image_size), 70 | T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(), 71 | T.CenterCrop(image_size), 72 | T.ToTensor() 73 | ]) 74 | 75 | def __len__(self): 76 | """ 77 | Return the number of images in this dataset. 78 | """ 79 | return len(self.paths) 80 | 81 | def __getitem__(self, index): 82 | """ 83 | Return the given image for the given index in this dataset. 
84 | """ 85 | path = self.paths[index] 86 | img = Image.open(path) 87 | return self.transform(img) 88 | 89 | 90 | def find_and_move_unsorted(src_dir, ftype: str): 91 | if len(os.listdir(f'{src_dir}/unsorted')) <= 0: 92 | pass 93 | # for path in Path(f'{src_dir}/unsorted').rglob(f'*.{ftype}'): 94 | for path in Path(f'{src_dir}/unsorted').rglob(f'*.*'): 95 | name = str(path).lower().split('/')[-1] 96 | p = None 97 | p = 'penguin' if 'penguin' in name else p 98 | p = 'guitar' if 'guitar' in name else p 99 | p = 'flower' if 'flower' in name else p 100 | 101 | assert p is not None, f'Could not sort file at path: {path}' 102 | print('Moving', name, 'to', src_dir, p) 103 | os.makedirs(f'{src_dir}/{p}', exist_ok=True) 104 | os.rename(path, f'{src_dir}/{p}/{name}') 105 | 106 | shutil.rmtree(f'{src_dir}/unsorted') 107 | os.mkdir(f'{src_dir}/unsorted') 108 | 109 | 110 | class EEGTargetsDataset(Dataset): 111 | def __init__(self, 112 | eeg_directory='./datasets/eeg', 113 | targets_directory='./datasets/targets', 114 | labels: list = None, 115 | shuffle_eeg=True, 116 | shuffle_targets=True, 117 | file_types=None, 118 | unsorted_eeg_policy=None, 119 | unsorted_target_policy=None, 120 | image_size=[32, 32]): 121 | """ 122 | This class loads a dataset of EEG and target image pairs, and then determines the correct class label for each one. 123 | 124 | A typical directory structure is as follows: 125 | datasets/ 126 | ├── eeg 127 | │ ├── flower 128 | │ ├── guitar 129 | │ ├── penguin 130 | │ └── unsorted 131 | └── target 132 | ├── flower 133 | ├── guitar 134 | └── penguin 135 | └── unsorted 136 | 137 | It is expected that the directories at the bottom layer are named after the label they contain. 138 | It is expected that the names match between the `eeg` and `targets` directory. 139 | 140 | Any files and directories in the 'unsorted' directory will be searched recursively 141 | and moved to the appropriate directory. Their appropriate directory will be inferred 142 | by the file's name. 143 | 144 | 145 | :param eeg_directory: The directory containing EEG images for training. These will be loaded recursively. 146 | :param targets_directory: The directory containing target images for training. These will be loaded recursively. 147 | :param labels: The list of labels to be used for training. If left blank, this defaults to guitar, penguin and flower. 148 | :param shuffle_eeg: If the EEG training images should be sampled from randomly. 149 | :param shuffle_targets: If the target training images for each respective EEG image should be sampled from randomly. 150 | :param file_types: The list of file types to load. 
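        :param unsorted_eeg_policy: The policy applied to files found under the EEG `unsorted` directory; defaults to ['move', 'delete-src-dirs'].
        :param unsorted_target_policy: As above, but for the targets `unsorted` directory.
        :param image_size: The (height, width) that every EEG and target image is resized to; defaults to [32, 32].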
151 | """ 152 | if unsorted_target_policy is None: 153 | unsorted_target_policy = ['move', 'delete-src-dirs'] 154 | if unsorted_eeg_policy is None: 155 | unsorted_eeg_policy = ['move', 'delete-src-dirs'] 156 | if labels is None: 157 | labels = ['penguin', 'guitar', 'flower'] 158 | 159 | if file_types is None: 160 | file_types = ['jpg', 'png'] 161 | 162 | self.file_types = file_types 163 | self.labels = labels 164 | self.eeg_directory = eeg_directory 165 | self.targets_directory = targets_directory 166 | self.shuffle_eeg = shuffle_eeg 167 | self.shuffle_targets = shuffle_targets 168 | self.data = defaultdict(lambda: defaultdict(list)) 169 | self.indices = {'eeg': {}, 'target': {}} 170 | self.image_size = image_size 171 | 172 | # TODO load eeg recursively, load targets recursively, generate labels for each, group them together 173 | 174 | os.makedirs(eeg_directory, exist_ok=True) 175 | os.makedirs(targets_directory, exist_ok=True) 176 | 177 | for label in labels: 178 | assert os.path.exists(f'{eeg_directory}/{label}'), \ 179 | f'The EEG directory for `{label}` does not exist.' 180 | assert os.path.exists(f'{targets_directory}/{label}'), \ 181 | f'The targets directory for `{label}` does not exist.' 182 | # 183 | # for ftype in file_types: 184 | # find_and_move_unsorted(eeg_directory, ftype) 185 | # find_and_move_unsorted(targets_directory, ftype) 186 | 187 | for label in labels: 188 | d0 = os.listdir(f'{eeg_directory}/{label}') 189 | d1 = os.listdir(f'{targets_directory}/{label}') 190 | 191 | d0 = [d for d in d0 if not d.startswith('.')] 192 | d1 = [d for d in d1 if not d.startswith('.')] 193 | 194 | if self.shuffle_eeg: 195 | random.shuffle(d0) 196 | 197 | if self.shuffle_targets: 198 | random.shuffle(d1) 199 | 200 | self.data['eeg'][label] = d0 201 | self.data['targets'][label] = d1 202 | 203 | # print(f'Loaded {len(d0)} images for EEG/{label}') 204 | # print(f'Loaded {len(d1)} images for Targets/{label}') 205 | 206 | self.transformTarget = T.Compose([ 207 | T.RandomHorizontalFlip(), 208 | T.Resize(self.image_size), 209 | T.ToTensor() 210 | ]) 211 | 212 | self.transformEEG = T.Compose([ 213 | T.Resize(self.image_size), 214 | T.ToTensor() 215 | ]) 216 | 217 | def __getitem__(self, index): 218 | """ 219 | Retrieve an item in the dataset at the given index. If shuffling is enabled, the given index is ignored. 220 | 221 | The values returned are in a tuple, and in the order of: 222 | 1. EEG image for label `L`. 223 | 2. Target image for label `L`. 224 | 3. Label `L`. 
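        Note that the EEG image is assumed to be single-channel; it is repeated across three channels so that its shape matches that of the RGB target image.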
225 | """ 226 | label = random.choice(self.labels) 227 | eeg_sample = random.choice(self.data['eeg'][label]) 228 | target_sample = random.choice(self.data['targets'][label]) 229 | 230 | eeg_sample = Image.open(f'{self.eeg_directory}/{label}/{eeg_sample}') 231 | target_sample = Image.open(f'{self.targets_directory}/{label}/{target_sample}').convert('RGB') 232 | 233 | if target_sample.mode in ('RGBA', 'LA') or (target_sample.mode == 'P' and 'transparency' in target_sample.info): 234 | alpha = target_sample.convert('RGBA').split()[-1] 235 | bg = Image.new("RGBA", target_sample.size, (255, 255, 255) + (255,)) 236 | bg.paste(target_sample, mask=alpha) 237 | target_sample = bg 238 | 239 | eeg_sample = self.transformEEG(eeg_sample) 240 | eeg_sample = eeg_sample.repeat(3, 1, 1) 241 | return eeg_sample, self.transformTarget(target_sample), label 242 | 243 | def __len__(self): 244 | return sum(len(d[label]) for label in self.labels for d in self.data.values()) 245 | 246 | 247 | class Trainer(object): 248 | """ 249 | This class is responsible for the training, sampling and saving loops that 250 | take place when interacting with the model. 251 | """ 252 | 253 | def __init__( 254 | self, 255 | diffusion_model, 256 | training_images_dir, # Folder where the training images exist 257 | # TODO add a secondary folder for the target image-class pairings 258 | # TODO add a means of loading and reading in those image-class pairings 259 | *, 260 | train_batch_size=16, 261 | gradient_accumulate_every=1, 262 | augment_horizontal_flip=True, # Flip image from left-to-right 263 | training_learning_rate=1e-4, 264 | num_training_steps=100000, 265 | ema_update_every=10, 266 | ema_decay=0.995, 267 | adam_betas=(0.9, 0.99), 268 | save_and_sample_every=1000, 269 | num_samples=25, 270 | results_folder='./results', 271 | amp=False, # Used mixed precision during training 272 | fp16=False, # Use Floating-Point 16-bit precision 273 | # TODO might be able to enable fp16 without affecting amp, 274 | # allowing for the model to train on TPUs 275 | split_batches=True, 276 | convert_image_to_ext=None, # A given extension to convert image types to 277 | use_wandb=True 278 | ): 279 | """ 280 | :param split_batches: If the batch of images loaded should be split by 281 | accelerator across all devices, or treated as a per-device batch count. 282 | For example, with a batch size of 32 and 8 devices, split_batches=True 283 | would put 4 items on each device. 
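        :param gradient_accumulate_every: The number of batches over which gradients are accumulated before each optimiser step.
        :param num_samples: The number of images generated at each sampling milestone; this must have an integer square root so the samples can be tiled into a square grid.
        :param save_and_sample_every: The interval, in training steps, between generating samples and saving a model checkpoint.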
284 | """ 285 | super().__init__() 286 | 287 | self.accelerator = Accelerator( 288 | split_batches=split_batches, 289 | mixed_precision='fp16' if fp16 else 'no', 290 | gradient_accumulation_steps=gradient_accumulate_every, 291 | device_placement=True 292 | ) 293 | 294 | self.accelerator.native_amp = amp 295 | 296 | assert has_int_square_root(num_samples), 'The number of samples must have an integer square root' 297 | 298 | self.diffusion_model = diffusion_model 299 | self.num_samples = num_samples 300 | self.save_and_sample_every = save_and_sample_every 301 | self.batch_size = train_batch_size 302 | self.gradient_accumulate_every = gradient_accumulate_every 303 | self.train_num_steps = num_training_steps 304 | self.image_size = diffusion_model.image_size 305 | 306 | self.train_images_dataset = EEGTargetsDataset() 307 | dataloader = DataLoader(self.train_images_dataset, 308 | batch_size=train_batch_size, 309 | shuffle=True, 310 | pin_memory=True, 311 | num_workers=multiprocessing.cpu_count()) 312 | 313 | # dataloader = self.accelerator.prepare(dataloader) 314 | # self.train_images_dataloader = cycle(dataloader) 315 | dataloader = self.accelerator.prepare(dataloader) 316 | self.train_eeg_targets_dataloader = cycle(dataloader) 317 | 318 | # optimizer 319 | 320 | self.optimiser = Adam(diffusion_model.parameters(), lr=training_learning_rate, betas=adam_betas) 321 | 322 | # for logging results in a folder periodically 323 | 324 | if self.accelerator.is_main_process: 325 | self.ema = EMA(diffusion_model, beta=ema_decay, update_every=ema_update_every) 326 | 327 | self.results_folder = Path(f'{results_folder}/{wandb.run.name}-{wandb.run.id}') 328 | self.results_folder.mkdir(exist_ok=True, parents=True) 329 | 330 | # Step counter 331 | self.step = 0 332 | 333 | self.diffusion_model, self.optimiser = self.accelerator.prepare(self.diffusion_model, self.optimiser) 334 | self.diffusion_model.learning_model = self.accelerator.prepare(self.diffusion_model.learning_model) 335 | # wandb.login(key=os.environ['WANDB_API_KEY']) # Uncomment if `wandb login` does not work in the console 336 | 337 | def save(self, milestone): 338 | if not self.accelerator.is_local_main_process: 339 | return 340 | 341 | data = { 342 | 'step': self.step, 343 | 'model': self.accelerator.get_state_dict(self.diffusion_model), 344 | 'opt': self.optimiser.state_dict(), 345 | 'ema': self.ema.state_dict(), 346 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None 347 | } 348 | 349 | torch.save(data, str(self.results_folder / f'model-{milestone}.pt')) 350 | 351 | def load(self, path, milestone): 352 | data = torch.load(str(f'{path}/model-{milestone}.pt')) 353 | 354 | model = self.accelerator.unwrap_model(self.diffusion_model) 355 | model.load_state_dict(data['model']) 356 | 357 | self.step = data['step'] 358 | self.optimiser.load_state_dict(data['opt']) 359 | self.ema.load_state_dict(data['ema']) 360 | 361 | if exists(self.accelerator.scaler) and exists(data['scaler']): 362 | self.accelerator.scaler.load_state_dict(data['scaler']) 363 | 364 | def train(self): 365 | accelerator = self.accelerator 366 | device = accelerator.device 367 | 368 | with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar: 369 | 370 | while self.step < self.train_num_steps: 371 | 372 | total_loss = 0. 
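                # Gradient accumulation: run several forward/backward passes before the single optimiser step below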
373 | 
374 | for _ in range(self.gradient_accumulate_every):
375 | eeg_sample, target_sample, _ = next(self.train_eeg_targets_dataloader)
376 | # eeg_sample.to(device)
377 | # target_sample.to(device)
378 | data = (eeg_sample, target_sample)
379 | # eeg_sample, target_sample, label
380 | 
381 | with self.accelerator.autocast():
382 | loss = self.diffusion_model(data)
383 | loss = loss / self.gradient_accumulate_every
384 | total_loss += loss.item()  # Accumulate as a plain float for logging; adding the tensor as well would double-count the loss
385 | 
386 | 
387 | self.accelerator.backward(loss)
388 | 
389 | wandb.log({'total_training_loss': total_loss, 'training_timestep': self.step})
390 | pbar.set_description(f'loss: {total_loss:.4f}')
391 | 
392 | accelerator.wait_for_everyone()
393 | 
394 | self.optimiser.step()
395 | self.optimiser.zero_grad()
396 | 
397 | accelerator.wait_for_everyone()
398 | 
399 | if accelerator.is_main_process:
400 | self.ema.to(device)
401 | self.ema.update()
402 | 
403 | if self.step != 0 and self.step % self.save_and_sample_every == 0:
404 | self.ema.ema_model.eval()
405 | 
406 | with torch.no_grad():
407 | milestone = self.step // self.save_and_sample_every
408 | batches = num_to_groups(self.num_samples, self.batch_size)
409 | all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
410 | 
411 | all_images = torch.cat(all_images_list, dim=0)
412 | utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'),
413 | nrow=int(math.sqrt(self.num_samples)))
414 | self.save(milestone)
415 | 
416 | self.step += 1
417 | pbar.update(1)
418 | 
419 | accelerator.print('Training complete!')
420 | 
421 | 
422 | def cycle(dataloader):
423 | while True:
424 | for data in dataloader:
425 | yield data
426 | 
427 | 
428 | def has_int_square_root(num):
429 | return (math.sqrt(num) ** 2) == num
430 | 
431 | 
432 | def num_to_groups(num, divisor):
433 | groups = num // divisor
434 | remainder = num % divisor
435 | arr = [divisor] * groups
436 | if remainder > 0:
437 | arr.append(remainder)
438 | return arr
439 | 
440 | 
441 | def convert_image_to_fn(img_type, image):
442 | if image.mode != img_type:
443 | return image.convert(img_type)
444 | return image
445 | 
446 | 
447 | def compute_L2_norm(t):
448 | """
449 | L2-normalise the given tensor along its final dimension.
450 | :param t: The tensor to be L2-normalised.
451 | """
452 | return F.normalize(t, dim=-1)
453 | 
454 | 
455 | def normalise_to_negative_one_to_one(img):
456 | return img * 2 - 1
457 | 
458 | 
459 | def unnormalise_to_zero_to_one(t):
460 | return (t + 1) * 0.5
461 | 
462 | 
463 | def upsample(dim, dim_out=None):
464 | return nn.Sequential(
465 | nn.Upsample(scale_factor=2, mode='nearest'),
466 | nn.Conv2d(dim, default(dim_out, dim), 3, padding=1)
467 | )
468 | 
469 | 
470 | def downsample(input_channel_dims, output_channel_dims=None):
471 | """
472 | This method creates a downsampling convolutional layer of the U-Net architecture.
473 | 
474 | :param input_channel_dims: The channel dimensions for the input to the layer.
475 | :param output_channel_dims: The channel dimensions for the output of the layer.
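A 4x4 kernel with stride 2 and padding 1 halves the spatial resolution, so, for example,
a 32x32 feature map is reduced to 16x16.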
476 | """ 477 | return nn.Conv2d( 478 | in_channels=input_channel_dims, 479 | out_channels=default(output_channel_dims, input_channel_dims), 480 | kernel_size=4, 481 | stride=2, 482 | padding=1) 483 | 484 | 485 | def extract(a, t, x_shape): 486 | b, *_ = t.shape 487 | out = a.gather(-1, t) 488 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 489 | 490 | 491 | def linear_beta_schedule(timesteps): 492 | scale = 1000 / timesteps 493 | beta_start = scale * 0.0001 494 | beta_end = scale * 0.02 495 | return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) 496 | 497 | 498 | def cosine_beta_schedule(timesteps, s=0.008): 499 | """ 500 | Cosine beta schedule 501 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 502 | :param s: The parameter that dictates the 'shift' of the beta values. 503 | A lower and higher value will cause beta to begin at a higher and lower value, respectively. 504 | """ 505 | steps = timesteps + 1 506 | x = torch.linspace(0, timesteps, steps, dtype=torch.float64) 507 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 508 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 509 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 510 | return torch.clip(betas, 0, 0.999) 511 | 512 | 513 | def default(val, func): 514 | """ 515 | This method simply checks if the parameter val exists. 516 | If it does not, parameter func is either returned, 517 | or executed if it is itself a function. 518 | 519 | :param val: The value to be checked for its existence. See method exists for more. 520 | :param func: The value or function to be returned or executed (respectively) if val does not exist. 521 | """ 522 | if exists(val): 523 | return val 524 | return func() if callable(func) else func 525 | -------------------------------------------------------------------------------- /images/denoising-diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/images/denoising-diffusion.png -------------------------------------------------------------------------------- /images/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/images/sample.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import wandb 4 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer 5 | 6 | torch.cuda.empty_cache() 7 | wandb.login() 8 | 9 | # wandb.config.learning_rate = 3e-4 10 | # wandb.config.training_timesteps = 5000 11 | # wandb.config.sampling_timesteps = 250 12 | # wandb.config.image_size = 32 13 | # wandb.config.number_of_samples = 25 14 | # wandb.config.batch_size = 512 15 | # wandb.config.use_amp = False 16 | # wandb.config.use_fp16 = True 17 | # wandb.config.gradient_accumulation_rate = 2 18 | # wandb.config.ema_update_rate = 10 19 | # wandb.config.ema_decay = 0.995 20 | # wandb.config.adam_betas = (0.9, 0.99) 21 | # wandb.config.save_and_sample_rate = 1000 22 | # wandb.config.do_split_batches = False 23 | # wandb.config.timesteps = 1000 24 | # wandb.config.loss_type = 'L2' 25 | # wandb.config.unet_dim = 16 26 | # wandb.config.unet_mults = (1, 2, 4, 8) 27 | # wandb.config.unet_channels = 3 28 
| # wandb.config.training_objective = 'pred_x0' 29 | 30 | default_hypers = dict( 31 | learning_rate=3e-4, 32 | training_timesteps=1001, 33 | sampling_timesteps=250, 34 | image_size=32, 35 | number_of_samples=25, 36 | batch_size=256, 37 | use_amp=False, 38 | use_fp16=False, 39 | gradient_accumulation_rate=2, 40 | ema_update_rate=10, 41 | ema_decay=0.995, 42 | adam_betas=(0.9, 0.99), 43 | save_and_sample_rate=1000, 44 | do_split_batches=False, 45 | timesteps=1000, 46 | loss_type='L2', 47 | unet_dim=16, 48 | unet_mults=(1, 2, 4, 8), 49 | unet_channels=3, 50 | training_objective='pred_x0' 51 | ) 52 | 53 | wandb.init(config=default_hypers, project='bath-thesis', entity='jd202') 54 | 55 | # with open('./sweep.yaml') as f: 56 | # sweep_config = yaml.load(f, Loader=SafeLoader) 57 | # 58 | # sweep_id = wandb.sweep(sweep_config, entity='jd202', project='bath-thesis') 59 | 60 | 61 | model = Unet( 62 | dim=wandb.config.unet_dim, 63 | dim_mults=wandb.config.unet_mults, 64 | channels=wandb.config.unet_channels 65 | ) 66 | 67 | diffusion = GaussianDiffusion( 68 | model, 69 | image_size=wandb.config.image_size, 70 | timesteps=wandb.config.timesteps, # number of steps 71 | sampling_timesteps=wandb.config.sampling_timesteps, 72 | # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) 73 | loss_type=wandb.config.loss_type, # L1 or L2 74 | training_objective=wandb.config.training_objective 75 | ) 76 | 77 | trainer = Trainer( 78 | diffusion, 79 | '/Users/jake/Desktop/scp/cifar', 80 | train_batch_size=wandb.config.batch_size, 81 | training_learning_rate=wandb.config.learning_rate, 82 | num_training_steps=wandb.config.training_timesteps, # total training steps 83 | num_samples=wandb.config.number_of_samples, 84 | gradient_accumulate_every=wandb.config.gradient_accumulation_rate, # gradient accumulation steps 85 | ema_update_every=wandb.config.ema_update_rate, 86 | ema_decay=wandb.config.ema_decay, # exponential moving average decay 87 | amp=wandb.config.use_amp, # turn on mixed precision 88 | fp16=wandb.config.use_fp16, 89 | save_and_sample_every=wandb.config.save_and_sample_rate 90 | ) 91 | 92 | trainer.load('./results/loadins', '17') 93 | 94 | wandb.watch(model) 95 | wandb.watch(diffusion) 96 | 97 | trainer.train() 98 | 99 | wandb.finish() 100 | -------------------------------------------------------------------------------- /pytorch-xla-env-setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Sample usage: 3 | # python env-setup.py --version 1.11 --apt-packages libomp5 4 | import argparse 5 | import collections 6 | from datetime import datetime 7 | import os 8 | import platform 9 | import re 10 | import requests 11 | import subprocess 12 | import threading 13 | import sys 14 | 15 | VersionConfig = collections.namedtuple('VersionConfig', 16 | ['wheels', 'tpu', 'py_version', 'cuda_version']) 17 | DEFAULT_CUDA_VERSION = '11.2' 18 | OLDEST_VERSION = datetime.strptime('20200318', '%Y%m%d') 19 | NEW_VERSION = datetime.strptime('20220315', '%Y%m%d') # 1.11 release date 20 | OLDEST_GPU_VERSION = datetime.strptime('20200707', '%Y%m%d') 21 | DIST_BUCKET = 'gs://tpu-pytorch/wheels' 22 | TORCH_WHEEL_TMPL = 'torch-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl' 23 | TORCH_XLA_WHEEL_TMPL = 'torch_xla-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl' 24 | TORCHVISION_WHEEL_TMPL = 'torchvision-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl' 25 | VERSION_REGEX = 
re.compile(r'^(\d+\.)+\d+$') 26 | 27 | def is_gpu_runtime(): 28 | return int(os.environ.get('COLAB_GPU', 0)) == 1 29 | 30 | 31 | def is_tpu_runtime(): 32 | return 'TPU_NAME' in os.environ 33 | 34 | 35 | def update_tpu_runtime(tpu_name, version): 36 | print(f'Updating TPU runtime to {version.tpu} ...') 37 | 38 | try: 39 | import cloud_tpu_client 40 | except ImportError: 41 | subprocess.call([sys.executable, '-m', 'pip', 'install', 'cloud-tpu-client']) 42 | import cloud_tpu_client 43 | 44 | client = cloud_tpu_client.Client(tpu_name) 45 | client.configure_tpu_version(version.tpu) 46 | print('Done updating TPU runtime') 47 | 48 | 49 | def get_py_version(): 50 | version_tuple = platform.python_version_tuple() 51 | return version_tuple[0] + version_tuple[1] # major_version + minor_version 52 | 53 | 54 | def get_cuda_version(): 55 | if is_gpu_runtime(): 56 | # cuda available, install cuda wheels 57 | return DEFAULT_CUDA_VERSION 58 | 59 | 60 | def get_version(version): 61 | cuda_version = get_cuda_version() 62 | if version == 'nightly': 63 | return VersionConfig( 64 | 'nightly', 'pytorch-nightly', get_py_version(), cuda_version) 65 | 66 | version_date = None 67 | try: 68 | version_date = datetime.strptime(version, '%Y%m%d') 69 | except ValueError: 70 | pass # Not a dated nightly. 71 | 72 | if version_date: 73 | if cuda_version and version_date < OLDEST_GPU_VERSION: 74 | raise ValueError( 75 | f'Oldest nightly version build with CUDA available is {OLDEST_GPU_VERSION}') 76 | elif not cuda_version and version_date < OLDEST_VERSION: 77 | raise ValueError(f'Oldest nightly version available is {OLDEST_VERSION}') 78 | return VersionConfig(f'nightly+{version}', f'pytorch-dev{version}', 79 | get_py_version(), cuda_version) 80 | 81 | 82 | if not VERSION_REGEX.match(version): 83 | raise ValueError(f'{version} is an invalid torch_xla version pattern') 84 | return VersionConfig( 85 | version, f'pytorch-{version}', get_py_version(), cuda_version) 86 | 87 | 88 | def install_vm(version, apt_packages, is_root=False): 89 | dist_bucket = DIST_BUCKET 90 | 91 | if version.cuda_version: 92 | # Distributions for GPU runtime 93 | # Note: GPU wheels available from 1.11 94 | dist_bucket = os.path.join( 95 | DIST_BUCKET, 'cuda/{}'.format(version.cuda_version.replace('.', ''))) 96 | else: 97 | # Distributions for TPU runtime 98 | # Note: this redirection is required for 1.11 & nightly releases 99 | # because the current 2 VM wheels are not compatible with colab environment. 
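# Illustrative walk-through of the redirection below (bucket names taken from DIST_BUCKET above):
#   wheels == 'nightly'           -> gs://tpu-pytorch/wheels/colab/
#   wheels == 'nightly+20220401'  -> build date on or after 2022-03-15 -> gs://tpu-pytorch/wheels/colab/
#   wheels == '1.12'              -> minor version 12 >= 11            -> gs://tpu-pytorch/wheels/colab/
#   wheels == '1.9'               -> minor version 9 < 11              -> gs://tpu-pytorch/wheels (unchanged)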
100 | if version.wheels == 'nightly': 101 | dist_bucket = os.path.join(DIST_BUCKET, 'colab/') 102 | elif 'nightly+' in version.wheels: 103 | build_date = datetime.strptime( version.wheels.split('+')[1], '%Y%m%d') 104 | if build_date >= NEW_VERSION: 105 | dist_bucket = os.path.join(DIST_BUCKET, 'colab/') 106 | elif VERSION_REGEX.match(version.wheels): 107 | minor = int(version.wheels.split('.')[1]) 108 | if minor >= 11: 109 | dist_bucket = os.path.join(DIST_BUCKET, 'colab/') 110 | else: 111 | raise ValueError(f'{version} is an invalid torch_xla version pattern') 112 | 113 | torch_whl = TORCH_WHEEL_TMPL.format( 114 | whl_version=version.wheels, py_version=version.py_version) 115 | torch_whl_path = os.path.join(dist_bucket, torch_whl) 116 | torch_xla_whl = TORCH_XLA_WHEEL_TMPL.format( 117 | whl_version=version.wheels, py_version=version.py_version) 118 | torch_xla_whl_path = os.path.join(dist_bucket, torch_xla_whl) 119 | torchvision_whl = TORCHVISION_WHEEL_TMPL.format( 120 | whl_version=version.wheels, py_version=version.py_version) 121 | torchvision_whl_path = os.path.join(dist_bucket, torchvision_whl) 122 | apt_cmd = ['apt-get', 'install', '-y'] 123 | apt_cmd.extend(apt_packages) 124 | 125 | if not is_root: 126 | # Colab/Kaggle run as root, but not GCE VMs so we need privilege 127 | apt_cmd.insert(0, 'sudo') 128 | 129 | installation_cmds = [ 130 | [sys.executable, '-m', 'pip', 'uninstall', '-y', 'torch', 'torchvision'], 131 | ['gsutil', 'cp', torch_whl_path, '.'], 132 | ['gsutil', 'cp', torch_xla_whl_path, '.'], 133 | ['gsutil', 'cp', torchvision_whl_path, '.'], 134 | [sys.executable, '-m', 'pip', 'install', torch_whl], 135 | [sys.executable, '-m', 'pip', 'install', torch_xla_whl], 136 | [sys.executable, '-m', 'pip', 'install', torchvision_whl], 137 | apt_cmd, 138 | ] 139 | for cmd in installation_cmds: 140 | subprocess.call(cmd) 141 | 142 | 143 | def run_setup(args): 144 | version = get_version(args.version) 145 | # Update TPU 146 | print('Updating... 
This may take around 2 minutes.') 147 | 148 | if is_tpu_runtime(): 149 | update = threading.Thread( 150 | target=update_tpu_runtime, args=( 151 | args.tpu, 152 | version, 153 | )) 154 | update.start() 155 | 156 | install_vm(version, args.apt_packages, is_root=not args.tpu) 157 | 158 | if is_tpu_runtime(): 159 | update.join() 160 | 161 | 162 | if __name__ == '__main__': 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument( 165 | '--version', 166 | type=str, 167 | default='20200515', 168 | help='Versions to install (nightly, release version, or YYYYMMDD).', 169 | ) 170 | parser.add_argument( 171 | '--apt-packages', 172 | nargs='+', 173 | default=['libomp5'], 174 | help='List of apt packages to install', 175 | ) 176 | parser.add_argument( 177 | '--tpu', 178 | type=str, 179 | help='[GCP] Name of the TPU (same zone, project as VM running script)', 180 | ) 181 | args = parser.parse_args() 182 | run_setup(args) 183 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.1.0 2 | accelerate==0.13.0.dev0 3 | astunparse==1.6.3 4 | attrs==19.3.0 5 | Automat==0.8.0 6 | blinker==1.4 7 | cachetools==5.2.0 8 | certifi==2022.6.15 9 | chardet==3.0.4 10 | charset-normalizer==2.1.0 11 | Click==7.0 12 | cloud-init==22.2 13 | cloud-tpu-client==0.10 14 | colorama==0.4.3 15 | command-not-found==0.3 16 | commonmark==0.9.1 17 | configobj==5.0.6 18 | constantly==15.1.0 19 | cryptography==2.8 20 | Cython==0.29.14 21 | dbus-python==1.2.16 22 | distlib==0.3.4 23 | distro==1.4.0 24 | distro-info===0.23ubuntu1 25 | docker-pycreds==0.4.0 26 | einops==0.4.1 27 | ema-pytorch==0.0.10 28 | entrypoints==0.3 29 | filelock==3.7.1 30 | flatbuffers==2.0 31 | future==0.18.2 32 | gast==0.4.0 33 | gitdb==4.0.9 34 | GitPython==3.1.27 35 | google-api-core==1.31.6 36 | google-api-python-client==1.8.0 37 | google-auth==2.9.0 38 | google-auth-httplib2==0.1.0 39 | google-auth-oauthlib==0.4.6 40 | google-pasta==0.2.0 41 | googleapis-common-protos==1.56.3 42 | grpcio==1.47.0 43 | h5py==3.7.0 44 | httplib2==0.20.4 45 | hyperlink==19.0.0 46 | idna==3.3 47 | importlib-metadata==4.12.0 48 | incremental==16.10.1 49 | intel-openmp==2022.1.0 50 | Jinja2==2.10.1 51 | jsonpatch==1.22 52 | jsonpointer==2.0 53 | jsonschema==3.2.0 54 | keras==2.9.0 55 | Keras-Applications==1.0.8 56 | Keras-Preprocessing==1.1.2 57 | keyring==18.0.1 58 | language-selector==0.1 59 | launchpadlib==1.10.13 60 | lazr.restfulclient==0.14.2 61 | lazr.uri==1.0.3 62 | libclang==14.0.1 63 | libtpu-nightly==0.1.dev20220518 64 | Markdown==3.3.7 65 | MarkupSafe==1.1.0 66 | mkl==2022.1.0 67 | mkl-include==2022.1.0 68 | mock==4.0.3 69 | more-itertools==4.2.0 70 | netifaces==0.10.4 71 | numpy==1.23.0 72 | oauth2client==4.1.3 73 | oauthlib==3.1.0 74 | opt-einsum==3.3.0 75 | packaging==21.3 76 | pathtools==0.1.2 77 | pexpect==4.6.0 78 | Pillow==9.2.0 79 | platformdirs==2.5.2 80 | promise==2.3 81 | protobuf==3.20.1 82 | psutil==5.9.1 83 | pyasn1==0.4.8 84 | pyasn1-modules==0.2.8 85 | Pygments==2.13.0 86 | PyGObject==3.36.0 87 | PyHamcrest==1.9.0 88 | PyJWT==1.7.1 89 | pymacaroons==0.13.0 90 | PyNaCl==1.3.0 91 | pyOpenSSL==19.0.0 92 | pyparsing==3.0.9 93 | pyrsistent==0.15.5 94 | pyserial==3.4 95 | python-apt==2.0.0+ubuntu0.20.4.7 96 | python-debian===0.1.36ubuntu1 97 | pytz==2022.1 98 | PyYAML==5.4.1 99 | requests==2.28.1 100 | requests-oauthlib==1.3.1 101 | requests-unixsocket==0.2.0 102 | rich==12.5.1 103 | rsa==4.8 104 | 
SecretStorage==2.3.1 105 | sentry-sdk==1.9.5 106 | service-identity==18.1.0 107 | setproctitle==1.3.2 108 | shortuuid==1.0.9 109 | simplejson==3.16.0 110 | six==1.16.0 111 | smmap==5.0.0 112 | sos==4.3 113 | ssh-import-id==5.10 114 | systemd-python==234 115 | tbb==2021.6.0 116 | tensorboard==2.9.1 117 | tensorboard-data-server==0.6.1 118 | tensorboard-plugin-wit==1.8.1 119 | tensorflow==2.10.0 120 | tensorflow-estimator==2.9.0 121 | tensorflow-io-gcs-filesystem==0.26.0 122 | termcolor==1.1.0 123 | torch==1.12.1 124 | torchvision==0.13.1 125 | tqdm==4.64.0 126 | Twisted==18.9.0 127 | typing-extensions==4.2.0 128 | ubuntu-advantage-tools==27.8 129 | ufw==0.36 130 | unattended-upgrades==0.1 131 | uritemplate==3.0.1 132 | urllib3==1.26.9 133 | virtualenv==20.15.1 134 | wadllib==1.3.3 135 | wandb==0.13.2 136 | Werkzeug==2.1.2 137 | wrapt==1.14.1 138 | zipp==1.0.0 139 | zope.interface==4.7.1 -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torchvision.utils import save_image 5 | 6 | import wandb 7 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, utils 8 | 9 | torch.cuda.empty_cache() 10 | wandb.login() 11 | 12 | default_hypers = dict( 13 | learning_rate=3e-4, 14 | training_timesteps=1001, 15 | sampling_timesteps=250, 16 | image_size=32, 17 | number_of_samples=25, 18 | batch_size=256, 19 | use_amp=False, 20 | use_fp16=False, 21 | gradient_accumulation_rate=2, 22 | ema_update_rate=10, 23 | ema_decay=0.995, 24 | adam_betas=(0.9, 0.99), 25 | save_and_sample_rate=1000, 26 | do_split_batches=False, 27 | timesteps=4000, 28 | loss_type='L2', 29 | unet_dim=128, 30 | unet_mults=(1, 2, 2, 2), 31 | unet_channels=3, 32 | training_objective='pred_x0' 33 | ) 34 | 35 | wandb.init(config=default_hypers, project='bath-thesis', entity='jd202') 36 | 37 | model = Unet( 38 | dim=wandb.config.unet_dim, 39 | dim_mults=wandb.config.unet_mults, 40 | channels=wandb.config.unet_channels 41 | ) 42 | 43 | diffusion = GaussianDiffusion( 44 | model, 45 | image_size=wandb.config.image_size, 46 | timesteps=wandb.config.timesteps, # number of steps 47 | sampling_timesteps=wandb.config.sampling_timesteps, 48 | # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) 49 | loss_type=wandb.config.loss_type, # L1 or L2 50 | training_objective=wandb.config.training_objective 51 | ) 52 | 53 | trainer = Trainer( 54 | diffusion, 55 | '/Users/jake/Desktop/scp/cifar', 56 | train_batch_size=wandb.config.batch_size, 57 | training_learning_rate=wandb.config.learning_rate, 58 | num_training_steps=wandb.config.training_timesteps, # total training steps 59 | num_samples=wandb.config.number_of_samples, 60 | gradient_accumulate_every=wandb.config.gradient_accumulation_rate, # gradient accumulation steps 61 | ema_update_every=wandb.config.ema_update_rate, 62 | ema_decay=wandb.config.ema_decay, # exponential moving average decay 63 | amp=wandb.config.use_amp, # turn on mixed precision 64 | fp16=wandb.config.use_fp16, 65 | save_and_sample_every=wandb.config.save_and_sample_rate 66 | ) 67 | 68 | trainer.load('./results/loadins', '56') 69 | trainer.ema.ema_model.eval() 70 | with torch.no_grad(): 71 | milestone = 10 // 1 72 | batches = utils.num_to_groups(wandb.config.number_of_samples, wandb.config.batch_size) 73 | all_images_list = list(map(lambda n: trainer.ema.ema_model.sample(batch_size=n), batches)) 74 | 75 | 
all_images = torch.cat(all_images_list, dim=0) 76 | save_image(all_images, str(f'results/samples/sample-{milestone}.png'), 77 | nrow=int(math.sqrt(wandb.config.number_of_samples))) 78 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='denoising-diffusion-pytorch', 5 | packages=find_packages(), 6 | version='0.27.4', 7 | license='MIT', 8 | description='Denoising Diffusion Probabilistic Models - Pytorch', 9 | author='DevJake', 10 | long_description_content_type='text/markdown', 11 | keywords=[ 12 | 'artificial intelligence', 13 | 'generative models' 14 | ], 15 | install_requires=[ 16 | 'accelerate', 17 | 'einops', 18 | 'ema-pytorch', 19 | 'pillow', 20 | 'torch', 21 | 'torchvision', 22 | 'tqdm' 23 | ], 24 | classifiers=[ 25 | 'Development Status :: 4 - Beta', 26 | 'Intended Audience :: Developers', 27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 28 | 'License :: OSI Approved :: MIT License', 29 | 'Programming Language :: Python :: 3.6', 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | cd ~/ 2 | #python3 -m pip install virtualenv # It is just simpler this way, trust me... 3 | git clone https://github.com/DevJake/EEG-diffusion-pytorch.git diffusion 4 | cd diffusion 5 | #virtualenv venv 6 | #source venv/bin/activate 7 | sudo python3 setup.py install 8 | #pip3 install wandb accelerate einops tqdm ema_pytorch torchvision 9 | #pip3 install -r requirements.txt 10 | #python3 -m wandb login 11 | accelerate config 12 | sudo apt install rclone 13 | mkdir -p ~/.config/rclone 14 | nano ~/.config/rclone/rclone.conf 15 | # Add in your rclone config to connect to the repository storing all EEG and Targets data 16 | mkdir -p datasets/eeg/unsorted datasets/eeg/flower datasets/eeg/penguin datasets/eeg/guitar 17 | mkdir -p datasets/targets/unsorted datasets/targets/flower datasets/targets/penguin datasets/targets/guitar 18 | cd ~/diffusion/datasets/eeg 19 | #rclone copy gc:/bath-thesis-data/data/outputs/preprocessing . -P 20 | rclone copy gc:/bath-thesis-data/data/subjects/preprocessed/combined . -P 21 | find . -name "*.tar.gz" -exec tar -xf {} \; # this will take some time to run... 22 | find . -name "*.tar.gz" -exec rm -v {} \; 23 | #find ./ -type f -exec mv --backup=numbered {} ./ -v \; 24 | cd ~/diffusion/datasets/targets/ 25 | rclone copy gc:/bath-thesis-data/data/classes/32x32.tar . -P 26 | tar -xf 32x32.tar 27 | rm 32x32.tar 28 | mv 32x32/flower-32x32/* flower/ & mv 32x32/guitar-32x32/* guitar/ & mv 32x32/penguin-32x32/* penguin/ 29 | rm 32x32 -r 30 | cd ../.. 31 | accelerate launch model.py 32 | 33 | #curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py 34 | #sudo python3 pytorch-xla-env-setup.py --version 1.12 35 | 36 | # Given the enormous size of model save files, and their frequent saving, 37 | # you can use the following command in a tmux session to have them be backed 38 | # up the the Google Cloud bucket, or another provider of choice. Source directories are not deleted, 39 | # so do not worry about the model crashing from not being able to save! 
40 | # while sleep 120; do rclone move ~/diffusion/results/ gc:bath-thesis-data/data/trained_models/ -P; done 41 | -------------------------------------------------------------------------------- /sweep.yaml: -------------------------------------------------------------------------------- 1 | program: /usr/local/bin/accelerate launch model.py 2 | command: 3 | - ${program} 4 | method: bayes 5 | metric: 6 | goal: minimize 7 | name: total_training_loss 8 | parameters: 9 | ema_update_rate: 10 | max: 20 11 | min: 5 12 | distribution: int_uniform 13 | gradient_accumulation_rate: 14 | max: 20 15 | min: 1 16 | distribution: int_uniform 17 | training_timesteps: 18 | max: 2000 19 | min: 500 20 | distribution: int_uniform 21 | training_objective: 22 | values: 23 | - pred_x0 24 | - pred_noise 25 | distribution: categorical 26 | learning_rate: 27 | max: 0.1 28 | min: 0.00001 29 | distribution: uniform 30 | image_size: 31 | max: 64 32 | min: 16 33 | distribution: int_uniform 34 | batch_size: 35 | max: 512 36 | min: 32 37 | distribution: int_uniform 38 | timesteps: 39 | max: 8000 40 | min: 200 41 | distribution: int_uniform 42 | loss_type: 43 | values: 44 | - L2 45 | - L1 46 | distribution: categorical 47 | ema_decay: 48 | max: 1.99 49 | min: 0.4 50 | distribution: uniform 51 | unet_dim: 52 | max: 128 53 | min: 8 54 | distribution: int_uniform -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from denoising_diffusion_pytorch.utils import EEGTargetsDataset, cycle 4 | 5 | dataset = EEGTargetsDataset() 6 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True, pin_memory=True, num_workers=0) 7 | dataloader = cycle(dataloader) 8 | data = next(dataloader) 9 | 10 | print(data) 11 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/util/__init__.py -------------------------------------------------------------------------------- /util/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/util/configs/__init__.py -------------------------------------------------------------------------------- /util/configs/generate_configs.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import uuid 5 | import util.preprocessing.pipeline as pipe 6 | 7 | if __name__ == '__main__': 8 | samples_dir = './samples' 9 | variations = None 10 | 11 | with open(f'./variations.json', 'r') as f: 12 | variations = json.load(f) 13 | 14 | for w, _ in variations['window_overlap']: 15 | assert len(pipe.get_output_dims_by_factors(w * 1024)) > 0, f'No factor pairs exist for a window width of {w}!' 
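# The comprehension below forms the Cartesian product of the four variation lists.
# With the variations.json shipped alongside this script (4 window_overlap pairs,
# 2 use_channels, 2 high_pass and 3 low_pass values) that gives 4 * 2 * 2 * 3 = 48
# combinations per sample config.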
16 | 17 | combinations = [ 18 | [window_overlap, use_channels, high_pass, low_pass] 19 | for window_overlap in variations['window_overlap'] 20 | for use_channels in variations['use_channels'] 21 | for high_pass in variations['high_pass'] 22 | for low_pass in variations['low_pass'] 23 | ] 24 | 25 | print(f'Calculated {len(combinations)} total variants...') 26 | print(f'Estimated configuration count is {len(combinations) * len(list(glob.iglob(samples_dir + "/**.json")))}') 27 | 28 | for i, sample in enumerate(glob.iglob(f'{samples_dir}/*.json')): 29 | print('Operating on config #', i) 30 | 31 | os.makedirs(f'{samples_dir}/generated/config_{i}', exist_ok=True) 32 | 33 | conf = None 34 | with open(sample, 'r') as f: 35 | conf = json.load(f) 36 | 37 | for wo, use_channels, high_pass, low_pass in combinations: 38 | window_size, window_overlap = wo 39 | 40 | config = conf.copy() 41 | 42 | # config['META.CONFIG_NAME'] += f'_window_size={window_size}' 43 | # config['META.CONFIG_NAME'] += f'_window_overlap={window_overlap}' 44 | # config['META.CONFIG_NAME'] += f'_use_channels={use_channels}' 45 | # config['META.CONFIG_NAME'] += f'_highpass={high_pass}' 46 | # config['META.CONFIG_NAME'] += f'_lowpass={low_pass}' 47 | 48 | unique_id = str(uuid.uuid4()) 49 | 50 | config['META.CONFIG_NAME'] += '--' + unique_id 51 | 52 | config['PREPROCESSING.HIGH_PASS_FILTER.FREQ'] = high_pass 53 | config['PREPROCESSING.LOW_PASS_FILTER.FREQ'] = low_pass 54 | config['RENDER.WINDOW_OVERLAP'] = window_overlap 55 | config['RENDER.WINDOW_WIDTH'] = window_size 56 | config['RENDER.DO_PER_CHANNEL'] = use_channels 57 | 58 | with open(f'{samples_dir}/generated/config_{i}/{config["META.CONFIG_NAME"]}.json', 'w') as f_o: 59 | json.dump(config, f_o, indent=4, sort_keys=True) 60 | -------------------------------------------------------------------------------- /util/configs/samples/all-but-no-filters-or-dc.json: -------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "all-but-no-filters-or-dc", 3 | "META.CONFIG_COMMENT": "All settings and default values, no filtering or DC removal", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false, 6 | "PREPROCESSING.DO_MONTAGE": true, 7 | "PREPROCESSING.DO_REMOVE_DC": false, 8 | "PREPROCESSING.DO_REMOVE_EOG": true, 9 | "PREPROCESSING.DO_USE_ICA": true, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/all-but-no-filters.json: -------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "all-but-no-filters", 3 | "META.CONFIG_COMMENT": "All settings and default values, no filtering", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false, 6 | "PREPROCESSING.DO_MONTAGE": true, 7 | "PREPROCESSING.DO_REMOVE_DC": true, 8 | "PREPROCESSING.DO_REMOVE_EOG": true, 9 | "PREPROCESSING.DO_USE_ICA": true, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/default-config-copy.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "default-config-copy", 3 | "META.CONFIG_COMMENT": "A copy of the default config with all settings enabled and default values", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": true, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": true, 6 | "PREPROCESSING.DO_MONTAGE": true, 7 | "PREPROCESSING.DO_REMOVE_DC": true, 8 | "PREPROCESSING.DO_REMOVE_EOG": true, 9 | "PREPROCESSING.DO_USE_ICA": true, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/ica-only.json: -------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "ica-only", 3 | "META.CONFIG_COMMENT": "All settings and default values, no filtering, DC removal or EOG noise removal", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false, 6 | "PREPROCESSING.DO_MONTAGE": true, 7 | "PREPROCESSING.DO_REMOVE_DC": false, 8 | "PREPROCESSING.DO_REMOVE_EOG": false, 9 | "PREPROCESSING.DO_USE_ICA": true, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/raw-high-pass-ica.json: -------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "raw-high-pass-ica", 3 | "META.CONFIG_COMMENT": "Apply high pass and ICA only", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": true, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false, 6 | "PREPROCESSING.DO_MONTAGE": false, 7 | "PREPROCESSING.DO_REMOVE_DC": false, 8 | "PREPROCESSING.DO_REMOVE_EOG": false, 9 | "PREPROCESSING.DO_USE_ICA": true, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/raw-high-pass-no-ica.json: -------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "raw-high-pass-no-ica", 3 | "META.CONFIG_COMMENT": "Apply high pass only", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": true, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false, 6 | "PREPROCESSING.DO_MONTAGE": false, 7 | "PREPROCESSING.DO_REMOVE_DC": false, 8 | "PREPROCESSING.DO_REMOVE_EOG": false, 9 | "PREPROCESSING.DO_USE_ICA": false, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/raw-loss-pass-ica.json: -------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "raw-low-pass-ica", 3 | "META.CONFIG_COMMENT": "Apply low pass filtering and ICA only", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": true, 6 | "PREPROCESSING.DO_MONTAGE": false, 7 | "PREPROCESSING.DO_REMOVE_DC": false, 
8 | "PREPROCESSING.DO_REMOVE_EOG": false, 9 | "PREPROCESSING.DO_USE_ICA": true, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/raw-low-pass-no-ica.json: -------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "raw-low-pass-no-ica", 3 | "META.CONFIG_COMMENT": "Apply low pass only", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": true, 6 | "PREPROCESSING.DO_MONTAGE": false, 7 | "PREPROCESSING.DO_REMOVE_DC": false, 8 | "PREPROCESSING.DO_REMOVE_EOG": false, 9 | "PREPROCESSING.DO_USE_ICA": false, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/raw-only.json: -------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "raw-only", 3 | "META.CONFIG_COMMENT": "Apply no preprocessing or ICA", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false, 6 | "PREPROCESSING.DO_MONTAGE": false, 7 | "PREPROCESSING.DO_REMOVE_DC": false, 8 | "PREPROCESSING.DO_REMOVE_EOG": false, 9 | "PREPROCESSING.DO_USE_ICA": false, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/remove-all-noise-ica-no-eog.json: -------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "remove-all-noise-ica-no-eog", 3 | "META.CONFIG_COMMENT": "Remove all noise, use ICA, do not remove EOG noise", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": true, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": true, 6 | "PREPROCESSING.DO_MONTAGE": true, 7 | "PREPROCESSING.DO_REMOVE_DC": true, 8 | "PREPROCESSING.DO_REMOVE_EOG": false, 9 | "PREPROCESSING.DO_USE_ICA": true, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/remove-all-noise-no-ica.json: -------------------------------------------------------------------------------- 1 | { 2 | "META.CONFIG_NAME": "remove-all-noise-no-ica", 3 | "META.CONFIG_COMMENT": "Remove all noise, do not use ICA", 4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": true, 5 | "PREPROCESSING.DO_LOW_PASS_FILTER": true, 6 | "PREPROCESSING.DO_MONTAGE": true, 7 | "PREPROCESSING.DO_REMOVE_DC": true, 8 | "PREPROCESSING.DO_REMOVE_EOG": false, 9 | "PREPROCESSING.DO_USE_ICA": false, 10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50, 11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1, 12 | "RENDER.DO_PER_CHANNEL": false, 13 | "RENDER.WINDOW_OVERLAP": 0.5, 14 | "RENDER.WINDOW_WIDTH": 1 15 | } -------------------------------------------------------------------------------- /util/configs/samples/requirements.txt: 
-------------------------------------------------------------------------------- 1 | accelerate==0.12.0 2 | appdirs==1.4.4 3 | certifi==2022.6.15 4 | charset-normalizer==2.1.1 5 | click==8.1.3 6 | cycler==0.11.0 7 | decorator==5.1.1 8 | docker-pycreds==0.4.0 9 | einops==0.4.1 10 | ema-pytorch==0.0.10 11 | fonttools==4.37.1 12 | gitdb==4.0.9 13 | GitPython==3.1.27 14 | idna==3.3 15 | Jinja2==3.1.2 16 | kiwisolver==1.4.4 17 | MarkupSafe==2.1.1 18 | matplotlib==3.5.3 19 | mne==1.1.1 20 | numpy==1.23.2 21 | packaging==21.3 22 | pathtools==0.1.2 23 | Pillow==9.2.0 24 | pooch==1.6.0 25 | promise==2.3 26 | protobuf==3.20.1 27 | psutil==5.9.1 28 | pyparsing==3.0.9 29 | python-dateutil==2.8.2 30 | PyYAML==6.0 31 | requests==2.28.1 32 | scipy==1.9.0 33 | sentry-sdk==1.9.5 34 | setproctitle==1.3.2 35 | shortuuid==1.0.9 36 | six==1.16.0 37 | smmap==5.0.0 38 | torch==1.12.1 39 | torchvision==0.13.1 40 | tqdm==4.64.0 41 | typing_extensions==4.3.0 42 | urllib3==1.26.12 43 | wandb==0.13.2 44 | -------------------------------------------------------------------------------- /util/configs/variations.json: -------------------------------------------------------------------------------- 1 | { 2 | "high_pass": [ 3 | 30, 4 | 50 5 | ], 6 | "low_pass": [ 7 | 1, 8 | 0.5, 9 | 0.1 10 | ], 11 | "use_channels": [ 12 | true, 13 | false 14 | ], 15 | "window_overlap": [ 16 | [ 17 | 1, 18 | 0.5 19 | ], 20 | [ 21 | 1.5, 22 | 0.4 23 | ], 24 | [ 25 | 0.5, 26 | 0.2 27 | ], 28 | [ 29 | 2, 30 | 1 31 | ] 32 | ] 33 | } -------------------------------------------------------------------------------- /util/move.py: -------------------------------------------------------------------------------- 1 | labels = ['flower', 'penguin', 'guitar'] 2 | import os 3 | from pathlib import Path 4 | 5 | import tqdm 6 | 7 | cpt = sum([len(files) for r, d, files in os.walk('./')]) 8 | 9 | t = 0 10 | for path in tqdm.tqdm(Path('./').rglob(f'*.jpg'), initial=0, total=cpt): 11 | split = str(path).split('/') 12 | subject, session, name = split[0], split[1], split[-1] 13 | 14 | label = None 15 | label = 'penguin' if 'penguin' in name else label 16 | label = 'guitar' if 'guitar' in name else label 17 | label = 'flower' if 'flower' in name else label 18 | 19 | new_name = f'{label}/{str(t).zfill(8)}_{subject}_{session}_{name}' 20 | os.rename(path, f'./{new_name}') 21 | t += 1 22 | -------------------------------------------------------------------------------- /util/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/util/preprocessing/__init__.py -------------------------------------------------------------------------------- /util/preprocessing/pipeline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """EEG to Dataset Pipeline 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1xR3KqzFsQt5oF1ot8ZzYS49FkIkaZRWV 8 | 9 | This notebook intends to produce the following datasets on a given EEG RAW dataset: 10 | 1. Denoised EEG with no Gaussian noise and sliding window 11 | 2. Denoised EEG with Gaussian noise and sliding window 12 | 3. Noisy EEG with no Gaussian noise and sliding window 13 | 4. 
Noisy EEG with Gaussian noise and sliding window
14 | 
15 | Parameters therefore include:
16 | - Gaussian noise or none
17 | - Amount of Gaussian noise
18 | - Noisy or not
19 | - If denoising, what methods?
20 | - ICA - and number of channels
21 | - Bad channel removal
22 | - ECG, EOG and EEG removal
23 | - Low pass filters
24 | - Sliding window width
25 | - Window overlap amount
26 | 
27 | 
28 | Output datasets must also be reshaped into consistent image sizes, such as 64x64, 32x32, 64x32, etc. This depends on the possible configurations, which are dictated by the sliding window width and overlap amount. These can be calculated by finding the factor pairs of the flattened EEG signal vector size.
29 | 
30 | Each output window should not only be a .jpg image, but should also be either greyscale or black & white. Furthermore, a CSV is needed to contain metadata, such as subject/participant, and what they were imagining/perceiving in the moment.
31 | """
32 | 
33 | # Load the given EEG file
34 | 
35 | sub_cap_sizes = { # [subject] = capsize
36 | 2: 'N/A',
37 | 3: 'L',
38 | 4: 'L',
39 | 5: 'M',
40 | 6: 'L',
41 | 7: 'L',
42 | 8: 'M',
43 | 9: 'N/A',
44 | 10: 'M',
45 | 11: 'M',
46 | 12: 'L',
47 | 13: 'L',
48 | 14: 'L',
49 | 15: 'N/A',
50 | 16: 'L',
51 | 17: 'L'
52 | }
53 | 
54 | 
55 | # subject, session = 10, 1
56 | 
57 | def load_eeg(subject: int, session: int):
58 | raw_file = f'./data/subjects/Subject {subject}/Session {session}/sub{subject}_sess{session}.set'
59 | eog_channels = ['VEOGL', 'VEOGU', 'HEOGR', 'HEOGL'] # electrooculogram (EOG) electrodes
60 | raw = mne.io.read_raw_eeglab(raw_file, preload=True, eog=eog_channels)
61 | 
62 | raw.info['bads'] += ['CCP1h'] if sub_cap_sizes[subject] == 'L' else [] # Wrap the channel name in a list so it is appended whole, not character-by-character
63 | try:
64 | raw = raw.pick_types(eog=False, eeg=True, exclude='bads') # Effectively drop all EOG channels, leaving just EEG
65 | except RuntimeError:
66 | print('Exception has occurred when ignoring bad channels, selecting without explicit exclusions...')
67 | raw.info['bads'] = []
68 | raw = raw.pick_types(eeg=True, eog=False, exclude='bads')
69 | raw = raw.interpolate_bads(reset_bads=False)
70 | 
71 | return raw
72 | 
73 | 
74 | from matplotlib import pyplot as plt
75 | 
76 | 
77 | def generate_sample_image_from_eeg(start_t, end_t, image_height=64, channel=0):
78 | SAMPLE_RATE = 1024
79 | plt.gray()
80 | return plt.imshow(raw.get_data(channel, start_t * SAMPLE_RATE, end_t * SAMPLE_RATE)[0, :].reshape(image_height, -1),
81 | interpolation='nearest')
82 | 
83 | 
84 | # display(generate_sample_image_from_eeg(500, 510))
85 | 
86 | import mne
87 | import json
88 | import numpy as np
89 | import mne.preprocessing as preprocessing
90 | 
91 | 
92 | def apply_montage(raw, montage_file_path: str):
93 | with open(montage_file_path, 'r') as f:
94 | raw.set_montage(mne.channels.make_dig_montage(json.load(f), coord_frame='head'))
95 | return raw
96 | 
97 | 
98 | def compute_ICA(raw, channels: int = 20, random_state: int = 0, reject_dict={'mag': 5e-12, 'grad': 4000e-13},
99 | max_iter=800, method: str = 'fastica'):
100 | ica = mne.preprocessing.ICA(n_components=channels, random_state=random_state, max_iter=max_iter, method=method)
101 | return ica.fit(raw, reject=reject_dict)
102 | 
103 | 
104 | def remove_EOG(raw, ica, eog_channel_names=['Fp1', 'Fp2'], ica_z_threshold=1.96): # Requires ICA
105 | eog_epochs = preprocessing.create_eog_epochs(raw, ch_name=eog_channel_names).average()
106 | eog_epochs.apply_baseline(baseline=(None, -0.2))
107 | 
108 | eog_indices, _ = ica.find_bads_eog(eog_epochs, ch_name=eog_channel_names,
threshold=ica_z_threshold) 109 | ica.exclude.extend(eog_indices) 110 | 111 | return ica 112 | 113 | 114 | def remove_ECG(raw, ica, ica_z_threshold=1.96): # Requires ICA 115 | ecg_epochs = preprocessing.create_ecg_epochs(raw).average() 116 | ecg_epochs.apply_baseline(baseline=(None, -0.2)) 117 | 118 | ecg_indices, _ = ica.find_bads_ecg(ecg_epochs, threshold=ica_z_threshold) 119 | ica.exclude.extend(ecg_indices) 120 | 121 | return ica 122 | 123 | 124 | def remove_DC(raw, notches=np.arange(50, 251, 50)): 125 | # https://mne.tools/stable/generated/mne.io.Raw.html#mne.io.Raw.notch_filter 126 | return raw.notch_filter(notches, picks='eeg', filter_length='auto', phase='zero-double', fir_design='firwin') 127 | 128 | 129 | def apply_filter(raw, low_freq=None, high_freq=None): 130 | return raw.filter(low_freq, high_freq) 131 | 132 | 133 | def apply_ICA_to_RAW(raw, ica, dimensions_to_keep=124): 134 | return ica.apply(raw, n_pca_components=dimensions_to_keep) 135 | 136 | 137 | # TODO comment and document all code 138 | # TODO resolve all TODOs... 139 | 140 | # raw.plot_psd(fmin = 0,fmax=50, n_fft=2048, spatial_colors=True) 141 | # raw.info 142 | 143 | # raw.get_montage().plot(kind='topomap', show_names=True, sphere='auto') 144 | 145 | def generate_events(raw): 146 | events, event_ids = mne.events_from_annotations(raw, verbose=False) 147 | epochs = mne.Epochs(raw=raw, events=events, event_id=event_ids, preload=True, tmin=0, tmax=4, baseline=None, 148 | event_repeated='merge') 149 | # Gather all epochs within the raw data, assigning a corresponding event ID 150 | 151 | events_list = {(a, b, c): [] for a in ['imagined', 'perceived'] for b in ['flower', 'guitar', 'penguin'] for c in 152 | ['text', 'pictorial', 'sound']} 153 | # Generate a list of possible event combinations. There are 18 in total 154 | 155 | for id in event_ids: 156 | idl = id.lower() 157 | a = 'imagined' if 'imagination' in idl else 'perceived' 158 | b = 'guitar' if 'guitar' in idl else 'flower' if 'flower' in idl else 'penguin' 159 | c = 'text' if '_t_' in idl else 'pictorial' if '_image_' in idl else 'sound' 160 | events_list[(a, b, c)].append(id) 161 | # Map the ID of every real event to a corresponding event_id from the events_list variable 162 | 163 | for i, event in enumerate(events_list.items()): 164 | key, value = event 165 | name = '_'.join(key) 166 | try: 167 | mne.epochs.combine_event_ids(epochs, value, {name: 500 + i}, copy=False) 168 | # Merge all events with the same ID into a new ID, effectively grouping them for easy access 169 | except KeyError: 170 | print('KeyError whilst combining event IDs... 
skipping this ID') 171 | continue 172 | 173 | return events, event_ids, epochs, events_list 174 | 175 | 176 | def select_specific_epochs(epochs, A: list, B: list, C: list): 177 | select_epochs = [[a, b, c] for a in A for b in B for c in C] 178 | select_epochs = ['_'.join(k) for k in select_epochs] 179 | select_epochs = epochs[select_epochs] 180 | return select_epochs 181 | 182 | 183 | def crop_epochs(epochs, cropping_rules=None): 184 | if cropping_rules == None: 185 | cropping_rules = { 186 | 'imagined': { 187 | 'sound': 4, 188 | 'pictorial': 4, 189 | 'text': 4 190 | }, 191 | 'perceived': { 192 | 'sound': 4, 193 | 'pictorial': 3, 194 | 'text': 3 195 | } 196 | } 197 | 198 | for i, epoch in enumerate(zip(epochs, list(epochs.event_id.keys()))): 199 | epoch, name = epoch 200 | a, _, c = name.split('_') 201 | 202 | epochs[name].crop(tmin=0, tmax=cropping_rules[a][c]) 203 | 204 | return epochs 205 | 206 | 207 | def get_output_dims_by_factors(vector_length: int, orientation='landscape'): 208 | orientation = orientation.lower() 209 | 210 | pairs = np.array( 211 | [[p, vector_length // p] for p in range(2, int(np.sqrt(vector_length) + 1)) if vector_length % p == 0]) 212 | 213 | assert pairs.size > 0, f'No factor pairs exist for the given value: {vector_length}' 214 | assert orientation in ['landscape', 'portrait'], 'The orientation must be one of either landscape or portrait. ' 215 | 216 | squarest = np.sort(pairs[np.argmin(np.abs(pairs[:, 0] - pairs[:, 1]))]) 217 | return squarest if orientation == 'portrait' else squarest[::-1] 218 | 219 | 220 | def calc_total_windows(vector_length: int, window_width: int, overlap: int = 1, sample_rate=1024): 221 | return 1 + int((vector_length - (window_width * sample_rate)) / (overlap * sample_rate)) 222 | 223 | 224 | def split_vector_to_window_indices(vector_length: int, window_width=1, window_overlap=0.5, 225 | sample_rate=1024): # return (a, b) pairs from which to split the vector against 226 | window_width *= sample_rate 227 | window_overlap *= sample_rate 228 | # The window_width and window_overlap are products of the sample_rate, 229 | # so if 1 second is 1024Hz, then a 0.9 window_overlap is 0.9*sample_rate 230 | 231 | assert window_width % 1 == 0, f'The given window width does not produce an integer sample count. window_width={window_width / sample_rate}, sample_rate={sample_rate}' 232 | assert window_overlap % 1 == 0, f'The given window overlap does not produce an integer sample count. window_overlap={window_overlap / sample_rate}, sample_rate={sample_rate}' 233 | 234 | window_width, window_overlap = int(window_width), int(window_overlap) 235 | 236 | assert window_width != window_overlap, 'The window width cannot equal the overlap, as this produces no windows!' 237 | assert window_width > 0, 'Window width must be greater than zero.' 238 | assert window_overlap >= 0, 'Window overlap cannot be negative.' 
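# Worked example (figures chosen purely for illustration): vector_length=4096 samples at
# 1024 Hz with window_width=1 and window_overlap=0.5 gives start indices 0, 512, ..., 3072,
# i.e. the (start, end) pairs (0, 1024), (512, 1536), ..., (3072, 4096): seven windows in
# total, which agrees with calc_total_windows(4096, 1, 0.5).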
239 | 240 | if window_overlap == 0: 241 | k = np.arange(start=0, stop=vector_length + (sample_rate if vector_length % sample_rate == 0 else 0)) 242 | return np.array(list(zip(k[::sample_rate], k[sample_rate::sample_rate]))) 243 | 244 | k = np.arange(start=0, stop=vector_length - window_overlap, step=window_width - window_overlap) 245 | k = k[k <= vector_length - window_width] 246 | return np.array([k, k + window_width]).T 247 | 248 | 249 | def generate_eeg_dataset(raw_eeg, channels: list = np.arange(128), per_channel=True, window_width_seconds=1, 250 | window_overlap_seconds=0.5): 251 | assert type(raw_eeg) in [np.ndarray, mne.io.eeglab.eeglab.RawEEGLAB], \ 252 | f'The given raw_eeg must be of type Numpy Array or RawEEGLAB. type={type(raw_eeg)}' 253 | 254 | # Trim channel count to match the maximum of our input data 255 | if type(raw_eeg) is mne.io.eeglab.eeglab.RawEEGLAB and len(channels) > len(raw_eeg.ch_names): 256 | channels = np.arange(len(raw_eeg.ch_names)) 257 | else: 258 | channels = np.arange(raw_eeg.shape[0]) 259 | 260 | window_size = window_width_seconds * 1024 # 1024Hz per second sampling frequency 261 | 262 | output_size = get_output_dims_by_factors(window_size) if per_channel else get_output_dims_by_factors( 263 | window_size * len(channels)) 264 | # Size of output image for each EEG sample, optionally per channel 265 | 266 | data = raw_eeg.squeeze() # Remove the outer-most dimension 267 | if type(raw_eeg) is mne.io.eeglab.eeglab.RawEEGLAB: 268 | data = raw_eeg.get_data(channels, start=0, stop=None) # loads all of the data. # (n_channels, n_data_points) 269 | 270 | splits = split_vector_to_window_indices(data.shape[1]) 271 | 272 | split_start, split_end = splits[:, 0], splits[:, 1] 273 | 274 | output = None # Prefilling our outputs array with zeros 275 | if per_channel: 276 | output = np.zeros((channels.shape[0], splits.shape[0], output_size[0], 277 | output_size[1])) # n_channels, n_windows, reshaped_window 278 | else: 279 | output = np.zeros((splits.shape[0], output_size[0], output_size[1])) 280 | 281 | for i, s in enumerate(zip(split_start, split_end)): 282 | a, b = s 283 | 284 | if per_channel: 285 | for c in channels: 286 | output[c, i] = data[c, a:b].reshape(-1, output_size[0], output_size[1]) 287 | else: 288 | output[i] = data[:, a:b].reshape(output_size[0], output_size[1]) 289 | # Populate the outputs array according to if we're operating per_channel or not 290 | 291 | del data 292 | return output 293 | -------------------------------------------------------------------------------- /util/preprocessing/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import glob 3 | import os 4 | import random 5 | import uuid 6 | 7 | from PIL import Image 8 | from tqdm import tqdm 9 | 10 | from util.preprocessing import pipeline 11 | import json 12 | 13 | 14 | def preprocess(eeg_data_dir='./data/subjects', output_dir='./data/outputs/preprocessing/', 15 | montage_file_name='ANTNeuro_montage', hypers=None): 16 | assert os.path.exists(f'{eeg_data_dir}/{montage_file_name}.json'), \ 17 | 'The specified montage file could not be found! ' \ 18 | 'Please check it is in the EEG data directory root.' 
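# Illustrative usage (the config path is an assumption, not something this module requires):
#   with open('./util/configs/samples/ica-only.json') as cfg:
#       preprocess(hypers=json.load(cfg))
# A config missing 'META.RAW.DO_SAVE' is given a default below; the remaining keys are expected to be present.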
19 | 20 | if hypers is None: 21 | hypers = { 22 | 'RENDER.DO_PER_CHANNEL': True, 23 | 'PREPROCESSING.DO_MONTAGE': True, 24 | 'PREPROCESSING.DO_REMOVE_DC': True, 25 | 'PREPROCESSING.DO_LOW_PASS_FILTER': True, 26 | 'PREPROCESSING.DO_HIGH_PASS_FILTER': True, 27 | 'PREPROCESSING.DO_USE_ICA': True, 28 | 'PREPROCESSING.DO_REMOVE_EOG': True, 29 | 'PREPROCESSING.LOW_PASS_FILTER.FREQ': 0.1, 30 | 'PREPROCESSING.HIGH_PASS_FILTER.FREQ': 50, 31 | 'RENDER.WINDOW_WIDTH': 1, 32 | 'RENDER.WINDOW_OVERLAP': 0.5, 33 | 'META.CONFIG_NAME': 'config-1', 34 | 'META.CONFIG_COMMENT': 'The default config for any EEG preprocessing.', 35 | 'META.RAW.DO_SAVE': False 36 | } 37 | 38 | if 'META.RAW.DO_SAVE' not in hypers: 39 | hypers['META.RAW.DO_SAVE'] = False 40 | 41 | sub_sess_pairs = [] # subject, session 42 | 43 | for subject in range(20): 44 | for session in range(3): 45 | path = f'{eeg_data_dir}/Subject {subject}/Session {session}/sub{subject}_sess{session}' 46 | if not os.path.isfile(f'{path}.fdt'): 47 | continue 48 | sub_sess_pairs.append((subject, session)) 49 | 50 | print(f'Now executing config \'{hypers["META.CONFIG_NAME"]}\' ' 51 | f'against the following subject and session pairs: {sub_sess_pairs}') 52 | pbar_subjects = tqdm(len(sub_sess_pairs), desc='Subjects and Sessions') 53 | k = 0 54 | for subject, session in sub_sess_pairs: 55 | k += 1 56 | print(f'Now preprocessing data for Subject {subject}, session {session}. Progress = {k}/{len(sub_sess_pairs)}') 57 | try: 58 | unique_id = str(uuid.uuid4()) # A unique ID for this run 59 | print('The unique ID for this subject/session pairing is', unique_id) 60 | hypers['META.UUID'] = unique_id 61 | 62 | raw = pipeline.load_eeg(subject, session) 63 | if hypers['PREPROCESSING.DO_MONTAGE']: 64 | raw = pipeline.apply_montage(raw, f'{eeg_data_dir}/{montage_file_name}.json') 65 | 66 | if hypers['PREPROCESSING.DO_REMOVE_DC']: 67 | raw = pipeline.remove_DC(raw) 68 | 69 | if hypers['PREPROCESSING.DO_LOW_PASS_FILTER']: 70 | raw = pipeline.apply_filter(raw, low_freq=hypers['PREPROCESSING.LOW_PASS_FILTER.FREQ'], high_freq=None) 71 | 72 | if hypers['PREPROCESSING.DO_HIGH_PASS_FILTER']: 73 | raw = pipeline.apply_filter(raw, low_freq=None, high_freq=hypers['PREPROCESSING.HIGH_PASS_FILTER.FREQ']) 74 | 75 | ica = None 76 | if hypers['PREPROCESSING.DO_USE_ICA']: 77 | ica = pipeline.compute_ICA(raw) 78 | 79 | if hypers['PREPROCESSING.DO_REMOVE_EOG']: 80 | ica = pipeline.remove_EOG(raw, ica) 81 | # ica = pipeline.remove_ECG(raw, ica) # Sometimes works, sometimes does not, seems to be an issue with MNE 82 | raw = pipeline.apply_ICA_to_RAW(raw, ica) 83 | del ica # It is no longer needed, so we delete it from memory 84 | 85 | _, _, epochs, _ = pipeline.generate_events(raw) 86 | path = f'{eeg_data_dir}/preprocessed/Subject {subject}/Session {session}/{hypers["META.CONFIG_NAME"]}/{unique_id} ' 87 | os.makedirs(path, exist_ok=True) 88 | 89 | if hypers['META.RAW.DO_SAVE']: 90 | raw.save(f'{path}/sub_{subject}_sess_{session}_preprocessed_raw.fif.gz') 91 | 92 | with open(f'{path}/sub_{subject}_sess_{session}_hyperparams.json', 'w') as f: 93 | json.dump(hypers, f, sort_keys=True, indent=4) 94 | 95 | del raw 96 | 97 | A, B, C = ['imagined', 'perceived'], ['guitar', 'penguin', 'flower'], ['text', 'sound', 'pictorial'] 98 | select_epochs = pipeline.select_specific_epochs(epochs, A, B, C) 99 | 100 | cropped_epochs = pipeline.crop_epochs(select_epochs) 101 | del select_epochs 102 | 103 | print('All preprocessing now complete, saving images!') 104 | 105 | pbar_epochs = tqdm(len(cropped_epochs), 
106 | 
107 |             for i, p in enumerate(zip(cropped_epochs, cropped_epochs.event_id)):
108 |                 epoch, name = p
109 |                 images = pipeline.generate_eeg_dataset(
110 |                     epoch.squeeze(),
111 |                     per_channel=hypers['RENDER.DO_PER_CHANNEL'],
112 |                     window_width_seconds=hypers['RENDER.WINDOW_WIDTH'],
113 |                     window_overlap_seconds=hypers['RENDER.WINDOW_OVERLAP']
114 |                 )  # The squeeze removes the outer dimension, which is always 1 and therefore redundant
115 |                 # pbar_channels = tqdm(images.shape[0], position=1, desc='Channel progress')
116 | 
117 |                 for c, channel in enumerate(images):
118 |                     dir = f'{output_dir}/subject_{subject}/session_{session}/{hypers["META.CONFIG_NAME"]}/{unique_id}/channel_{c}'
119 |                     os.makedirs(dir, exist_ok=True)
120 |                     # pbar_event = tqdm(channel.shape[0], position=2, desc='Event progress')
121 |                     for e, event in enumerate(channel):
122 |                         im = Image.fromarray(event, 'L')
123 |                         im.save(f'{dir}/epoch_{i}_channel_{c}_event_{e}_{name}.jpg')
124 | 
125 |                         # pbar_event.update(1)
126 | 
127 |                     # pbar_channels.update(1)
128 |                     # pbar_event.close()
129 | 
130 |                 pbar_epochs.update(1)
131 |                 # pbar_channels.close()
132 | 
133 |                 print(f'Finished saving images for epoch {i}. Name={name}')
134 | 
135 |             # pbar_epochs.close()
136 | 
137 |             pbar_subjects.update(1)
138 |             print(f'Completed preprocessing for subject {subject}, session {session}')
139 | 
140 |             del cropped_epochs
141 |         except Exception as e:
142 |             print(f'Encountered an exception for Subject {subject} session {session}. The exception was as follows...')
143 |             print(e)
144 |             continue
145 | 
146 | 
147 | def load_and_process_hyperparameters(dir: str, shuffle=True):
148 |     print(f'Loading configs from dir: {dir}')
149 |     configs = []
150 | 
151 |     for config_path in glob.iglob(f'{dir}/**/**/*.json'):
152 |         with open(config_path, 'r') as f:
153 |             configs.append(json.load(f))
154 |             # print('Loaded config:', json.load(f))
155 | 
156 |     print(f'Found {len(configs)} configs to be processed...')
157 |     return random.sample(configs, len(configs)) if shuffle else configs  # random.sample returns a shuffled copy; random.shuffle would return None
158 | 
159 | 
160 | if __name__ == '__main__':
161 |     print('Here we go...')
162 |     num_cpu = '8'
163 |     os.environ['OMP_NUM_THREADS'] = num_cpu
164 |     ALLOW_DEFAULT_CONFIG = False
165 | 
166 |     configs = load_and_process_hyperparameters('./data/configurations')
167 |     for i, config in enumerate(configs):
168 |         if config['META.CONFIG_NAME'] == 'default-config' and not ALLOW_DEFAULT_CONFIG:
169 |             continue
170 |         print(f'Now processing with the following configuration ({i + 1} of {len(configs)}):')
171 |         print(config)
172 |         try:
173 |             preprocess(hypers=config)
174 |         except Exception as e:
175 |             print(f'An exception has been thrown ({e}), could not perform preprocessing for the given config...')
176 |             print(config)
177 | 
178 |     print('Every config has now been processed, hurray! Terminating...')
179 | 
--------------------------------------------------------------------------------
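As a concrete illustration of how configurations reach the __main__ block above, the sketch below (not part of the repository) writes one config JSON with the same keys as the default hypers dict into the two-level directory layout that load_and_process_hyperparameters' glob pattern expects; the config name, comment, and sub-directory names are hypothetical.

import json
import os

# Hypothetical example config; the keys mirror the defaults in preprocess() above.
example_config = {
    'RENDER.DO_PER_CHANNEL': True,
    'PREPROCESSING.DO_MONTAGE': True,
    'PREPROCESSING.DO_REMOVE_DC': True,
    'PREPROCESSING.DO_LOW_PASS_FILTER': True,
    'PREPROCESSING.DO_HIGH_PASS_FILTER': True,
    'PREPROCESSING.DO_USE_ICA': True,
    'PREPROCESSING.DO_REMOVE_EOG': True,
    'PREPROCESSING.LOW_PASS_FILTER.FREQ': 0.1,
    'PREPROCESSING.HIGH_PASS_FILTER.FREQ': 50,
    'RENDER.WINDOW_WIDTH': 1,
    'RENDER.WINDOW_OVERLAP': 0.5,
    'META.CONFIG_NAME': 'example-config',
    'META.CONFIG_COMMENT': 'Hypothetical config written only for illustration.',
    'META.RAW.DO_SAVE': False
}

# Two directory levels below the configurations dir, so that
# glob.iglob(f'{dir}/**/**/*.json') in load_and_process_hyperparameters finds the file.
target_dir = './data/configurations/examples/default-like'
os.makedirs(target_dir, exist_ok=True)
with open(f'{target_dir}/example-config.json', 'w') as f:
    json.dump(example_config, f, sort_keys=True, indent=4)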
/util/resize.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | 
4 | from PIL import Image, ImageOps, ImageFile
5 | 
6 | ImageFile.LOAD_TRUNCATED_IMAGES = True
7 | 
8 | targets = [
9 |     (32, 32),
10 |     (48, 48),
11 |     (48, 32),
12 |     (32, 16),
13 |     (64, 32),
14 |     (64, 48),
15 |     (64, 64),
16 |     (96, 96),
17 |     (128, 128),
18 |     (256, 256),
19 |     (512, 512),
20 |     (1024, 1024),
21 |     (256, 128),
22 |     (512, 128),
23 |     (1024, 512),
24 |     (1024, 256),
25 |     (96, 48)
26 | ]
27 | target_dirs = ['flower', 'guitar', 'penguin']
28 | 
29 | for w, h in targets:
30 |     for d in target_dirs:
31 | 
32 |         os.makedirs(f'{d}-{w}x{h}', exist_ok=True)
33 |         for f in glob.iglob(f'./{d}/*.*'):
34 |             print(w, h, f)
35 |             name = f.split('/')[-1]
36 |             img = Image.open(f)
37 |             img = ImageOps.fit(img, (w, h), Image.Resampling.LANCZOS)
38 |             img.save(f'{d}-{w}x{h}/{name}')
39 | 
--------------------------------------------------------------------------------
/util/update_configs.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | 
3 | api = wandb.Api()
4 | run = api.run("jd202/bath-thesis/1fqllayw")
5 | 
6 | run.config['learning_rate'] = 3e-4
7 | run.config['training_timesteps'] = 5000
8 | run.config['sampling_timesteps'] = 250
9 | run.config['image_size'] = 32
10 | run.config['number_of_samples'] = 25
11 | run.config['batch_size'] = 512
12 | run.config['use_amp'] = False
13 | run.config['use_fp16'] = True
14 | run.config['gradient_accumulation_rate'] = 2
15 | run.config['ema_update_rate'] = 10
16 | run.config['ema_decay'] = 0.995
17 | run.config['adam_betas'] = (0.9, 0.99)
18 | run.config['save_and_sample_rate'] = 1000
19 | run.config['do_split_batches'] = False
20 | run.config['timesteps'] = 1000
21 | run.config['loss_type'] = 'L2'
22 | run.config['unet_dim'] = 16
23 | run.config['unet_mults'] = (1, 2, 4, 8)
24 | run.config['unet_channels'] = 3
25 | run.config['training_objective'] = 'pred_x0'
26 | 
27 | run.update()
28 | 
--------------------------------------------------------------------------------
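A small optional sanity check (not part of the repository) for the update above: re-fetch the run through the same public wandb API and confirm the new values were persisted. The run path is the one already used in update_configs.py.

import wandb

api = wandb.Api()
run = api.run("jd202/bath-thesis/1fqllayw")                  # same run path as in update_configs.py
print(run.config['learning_rate'], run.config['unet_dim'])   # expect 3e-4 and 16 after the update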
/visualise.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | import wandb
4 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
5 | 
6 | torch.cuda.empty_cache()
7 | wandb.login()
8 | 
9 | default_hypers = dict(
10 |     learning_rate=3e-4,
11 |     training_timesteps=1001,
12 |     sampling_timesteps=250,
13 |     image_size=32,
14 |     number_of_samples=25,
15 |     batch_size=256,
16 |     use_amp=False,
17 |     use_fp16=False,
18 |     gradient_accumulation_rate=2,
19 |     ema_update_rate=10,
20 |     ema_decay=0.995,
21 |     adam_betas=(0.9, 0.99),
22 |     save_and_sample_rate=1000,
23 |     do_split_batches=False,
24 |     timesteps=4000,
25 |     loss_type='L2',
26 |     unet_dim=128,
27 |     unet_mults=(1, 2, 2, 2),
28 |     unet_channels=3,
29 |     training_objective='pred_x0'
30 | )
31 | 
32 | wandb.init(config=default_hypers, project='bath-thesis', entity='jd202')
33 | 
34 | model = Unet(
35 |     dim=wandb.config.unet_dim,
36 |     dim_mults=wandb.config.unet_mults,
37 |     channels=wandb.config.unet_channels
38 | )
39 | 
40 | diffusion = GaussianDiffusion(
41 |     model,
42 |     image_size=wandb.config.image_size,
43 |     timesteps=wandb.config.timesteps,  # number of steps
44 |     sampling_timesteps=wandb.config.sampling_timesteps,
45 |     # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
46 |     loss_type=wandb.config.loss_type,  # L1 or L2
47 |     training_objective=wandb.config.training_objective
48 | )
49 | 
50 | trainer = Trainer(
51 |     diffusion,
52 |     '/Users/jake/Desktop/scp/cifar',
53 |     train_batch_size=wandb.config.batch_size,
54 |     training_learning_rate=wandb.config.learning_rate,
55 |     num_training_steps=wandb.config.training_timesteps,  # total training steps
56 |     num_samples=wandb.config.number_of_samples,
57 |     gradient_accumulate_every=wandb.config.gradient_accumulation_rate,  # gradient accumulation steps
58 |     ema_update_every=wandb.config.ema_update_rate,
59 |     ema_decay=wandb.config.ema_decay,  # exponential moving average decay
60 |     amp=wandb.config.use_amp,  # turn on mixed precision
61 |     fp16=wandb.config.use_fp16,
62 |     save_and_sample_every=wandb.config.save_and_sample_rate
63 | )
64 | 
65 | trainer.load('./results/loadins', '56')
66 | trainer.ema.ema_model.eval()
67 | 
68 | torch.save(model.state_dict(), './model_state_dict.pt')  # torch.save requires a file path, not a directory; this filename is only a placeholder
69 | 
--------------------------------------------------------------------------------
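For completeness, a minimal sketch (assuming the placeholder './model_state_dict.pt' filename used above and the same Unet hyperparameters as default_hypers) of loading the saved weights back for later sampling or inspection:

import torch
from denoising_diffusion_pytorch import Unet

# Rebuild the network with the same hyperparameters used in visualise.py.
model = Unet(dim=128, dim_mults=(1, 2, 2, 2), channels=3)
state = torch.load('./model_state_dict.pt', map_location='cpu')
model.load_state_dict(state)
model.eval()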