├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── denoising_diffusion_pytorch
│   ├── __init__.py
│   ├── diffusion_models
│   │   ├── __init__.py
│   │   ├── components.py
│   │   └── denoising_diffusion_pytorch.py
│   ├── entities
│   │   ├── __init__.py
│   │   └── unet.py
│   └── utils.py
├── images
│   ├── denoising-diffusion.png
│   └── sample.png
├── model.py
├── pytorch-xla-env-setup.py
├── requirements.txt
├── sample.py
├── setup.py
├── setup.sh
├── sweep.yaml
├── test.py
├── util
│   ├── __init__.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── generate_configs.py
│   │   ├── samples
│   │   │   ├── all-but-no-filters-or-dc.json
│   │   │   ├── all-but-no-filters.json
│   │   │   ├── default-config-copy.json
│   │   │   ├── ica-only.json
│   │   │   ├── raw-high-pass-ica.json
│   │   │   ├── raw-high-pass-no-ica.json
│   │   │   ├── raw-loss-pass-ica.json
│   │   │   ├── raw-low-pass-no-ica.json
│   │   │   ├── raw-only.json
│   │   │   ├── remove-all-noise-ica-no-eog.json
│   │   │   ├── remove-all-noise-no-ica.json
│   │   │   └── requirements.txt
│   │   └── variations.json
│   ├── move.py
│   ├── preprocessing
│   │   ├── EEG_to_Dataset_Pipeline.ipynb
│   │   ├── __init__.py
│   │   ├── pipeline.py
│   │   └── preprocess.py
│   ├── resize.py
│   └── update_configs.py
└── visualise.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generation results
2 | results/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | .idea/
135 | /util/configs/samples/generated/
136 | /datasets/
137 | /wandb/
138 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Denoising Diffusion Probabilistic Model, in PyTorch
4 |
5 | Implementation of Denoising Diffusion Probabilistic Models in PyTorch. It is a new approach to generative modeling that may have the potential to rival GANs. It uses denoising score matching to estimate the gradient of the data distribution, followed by Langevin sampling to sample from the true distribution.
6 |
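For orientation, the closed-form forward (noising) step that `q_sample` implements in this repository is the standard DDPM identity (stated here for reference, not specific to this fork):

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\qquad \epsilon \sim \mathcal{N}(0, I)$$

where $\bar{\alpha}_t = \prod_{s \le t}(1-\beta_s)$ follows the chosen (linear or cosine) beta schedule.
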
7 | This implementation was transcribed from the official TensorFlow version here
8 |
9 | Youtube AI Educators - Yannic Kilcher | AI Coffeebreak with Letitia | Outlier
10 |
11 | Annotated code by Research Scientists / Engineers from 🤗 Huggingface
12 |
13 | Update: Turns out none of the technicalities really matters at all | "Cold Diffusion" paper
14 |
15 | 
16 |
17 | [](https://badge.fury.io/py/denoising-diffusion-pytorch)
18 |
19 | ## Install
20 |
21 | ```bash
22 | $ pip install denoising_diffusion_pytorch
23 | ```
24 |
25 | ## Usage
26 |
27 | ```python
28 | import torch
29 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion
30 |
31 | model = Unet(
32 | dim = 64,
33 | dim_mults = (1, 2, 4, 8)
34 | )
35 |
36 | diffusion = GaussianDiffusion(
37 | model,
38 | image_size = 128,
39 | timesteps = 1000, # number of steps
40 | loss_type = 'l1' # L1 or L2
41 | )
42 |
43 | training_images = torch.randn(8, 3, 128, 128) # images are normalized from 0 to 1
44 | loss = diffusion(training_images)
45 | loss.backward()
46 | # after a lot of training
47 |
48 | sampled_images = diffusion.sample(batch_size = 4)
49 | sampled_images.shape # (4, 3, 128, 128)
50 | ```
51 |
52 | Or, if you simply want to pass in a folder name and the desired image dimensions, you can use the `Trainer` class to easily train a model.
53 |
54 | ```python
55 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion
56 | from denoising_diffusion_pytorch.utils import Trainer
57 |
58 | model = Unet(
59 | dim=64,
60 | dim_mults=(1, 2, 4, 8)
61 | ).cuda()
62 |
63 | diffusion = GaussianDiffusion(
64 | model,
65 | image_size=128,
66 | timesteps=1000, # number of steps
67 | sampling_timesteps=250,
68 | # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
69 | loss_type='l1' # L1 or L2
70 | ).cuda()
71 |
72 | trainer = Trainer(
73 | diffusion,
74 | 'path/to/your/images',
75 | train_batch_size=32,
76 | training_learning_rate=8e-5,
77 | num_training_steps=700000, # total training steps
78 | gradient_accumulate_every=2, # gradient accumulation steps
79 | ema_decay=0.995, # exponential moving average decay
80 | amp=True # turn on mixed precision
81 | )
82 |
83 | trainer.train()
84 | ```
85 |
86 | Samples and model checkpoints will be logged to `./results` periodically
87 |
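A saved milestone can later be restored through `Trainer.load` before resuming training. The snippet below is an illustrative sketch only: the results path and milestone number are placeholders, and it assumes a `trainer` constructed as in the Usage example above together with an active `wandb` run (the `Trainer` derives its results folder from the run name and id).

```python
# Placeholders: substitute the actual run folder name and milestone number
trainer.load('./results/<run-name>-<run-id>', milestone=10)  # restores model-10.pt, optimiser and EMA state
trainer.train()  # training resumes from the saved step counter
```
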
88 | ## Multi-GPU Training
89 |
90 | The `Trainer` class is now equipped with 🤗 Accelerator. You can easily do multi-GPU training in two steps using their `accelerate` CLI.
91 |
92 | At the project root directory, where the training script is, run
93 |
94 | ```bash
95 | $ accelerate config
96 | ```
97 |
98 | Then, in the same directory
99 |
100 | ```bash
101 | $ accelerate launch train.py
102 | ```
103 |
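The `train.py` passed to `accelerate launch` does not appear in this repository's tree; it is simply a script that builds the model and calls `Trainer.train()`, essentially the Usage example above saved to a file. A trimmed sketch (placeholder paths and project name, assuming an active `wandb` run):

```python
# train.py - illustrative sketch mirroring the Usage example above
import wandb
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

wandb.init(project='eeg-diffusion')  # hypothetical project name

model = Unet(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000)
trainer = Trainer(diffusion, 'path/to/your/images', train_batch_size=32)

trainer.train()
```
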
104 | ## Citations
105 |
106 | ```bibtex
107 | @inproceedings{NEURIPS2020_4c5bcfec,
108 | author = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
109 | booktitle = {Advances in Neural Information Processing Systems},
110 | editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
111 | pages = {6840--6851},
112 | publisher = {Curran Associates, Inc.},
113 | title = {Denoising Diffusion Probabilistic Models},
114 | url = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
115 | volume = {33},
116 | year = {2020}
117 | }
118 | ```
119 |
120 | ```bibtex
121 | @InProceedings{pmlr-v139-nichol21a,
122 | title = {Improved Denoising Diffusion Probabilistic Models},
123 | author = {Nichol, Alexander Quinn and Dhariwal, Prafulla},
124 | booktitle = {Proceedings of the 38th International Conference on Machine Learning},
125 | pages = {8162--8171},
126 | year = {2021},
127 | editor = {Meila, Marina and Zhang, Tong},
128 | volume = {139},
129 | series = {Proceedings of Machine Learning Research},
130 | month = {18--24 Jul},
131 | publisher = {PMLR},
132 | pdf = {http://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf},
133 | url = {https://proceedings.mlr.press/v139/nichol21a.html},
134 | }
135 | ```
136 |
137 | ```bibtex
138 | @inproceedings{kingma2021on,
139 | title = {On Density Estimation with Diffusion Models},
140 | author = {Diederik P Kingma and Tim Salimans and Ben Poole and Jonathan Ho},
141 | booktitle = {Advances in Neural Information Processing Systems},
142 | editor = {A. Beygelzimer and Y. Dauphin and P. Liang and J. Wortman Vaughan},
143 | year = {2021},
144 | url = {https://openreview.net/forum?id=2LdBqxc1Yv}
145 | }
146 | ```
147 |
148 | ```bibtex
149 | @article{Choi2022PerceptionPT,
150 | title = {Perception Prioritized Training of Diffusion Models},
151 | author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
152 | journal = {ArXiv},
153 | year = {2022},
154 | volume = {abs/2204.00227}
155 | }
156 | ```
157 |
158 | ```bibtex
159 | @article{Karras2022ElucidatingTD,
160 | title = {Elucidating the Design Space of Diffusion-Based Generative Models},
161 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
162 | journal = {ArXiv},
163 | year = {2022},
164 | volume = {abs/2206.00364}
165 | }
166 | ```
167 |
168 | ```bibtex
169 | @article{Song2021DenoisingDI,
170 | title = {Denoising Diffusion Implicit Models},
171 | author = {Jiaming Song and Chenlin Meng and Stefano Ermon},
172 | journal = {ArXiv},
173 | year = {2021},
174 | volume = {abs/2010.02502}
175 | }
176 | ```
177 |
178 | ```bibtex
179 | @misc{chen2022analog,
180 | title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
181 | author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
182 | year = {2022},
183 | eprint = {2208.04202},
184 | archivePrefix = {arXiv},
185 | primaryClass = {cs.CV}
186 | }
187 | ```
188 |
189 | ```bibtex
190 | @article{Qiao2019WeightS,
191 | title = {Weight Standardization},
192 | author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
193 | journal = {ArXiv},
194 | year = {2019},
195 | volume = {abs/1903.10520}
196 | }
197 | ```
198 |
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from denoising_diffusion_pytorch.diffusion_models.denoising_diffusion_pytorch import GaussianDiffusion
2 | from denoising_diffusion_pytorch.utils import Trainer
3 | from denoising_diffusion_pytorch.entities.unet import Unet
4 |
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/diffusion_models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/denoising_diffusion_pytorch/diffusion_models/__init__.py
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/diffusion_models/components.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | from einops import reduce, rearrange
6 | from torch import nn, einsum
7 | from torch.nn import functional as F
8 |
9 | from denoising_diffusion_pytorch.utils import exists, compute_L2_norm
10 |
11 |
12 | class Residual(nn.Module):
13 | def __init__(self, fn):
14 | super().__init__()
15 | self.fn = fn
16 |
17 | def forward(self, x, *args, **kwargs):
18 | return self.fn(x, *args, **kwargs) + x
19 |
20 |
21 | class WeightStandardizedConv2d(nn.Conv2d):
22 | """
23 | https://arxiv.org/abs/1903.10520
24 | weight standardization purportedly works synergistically with group normalization
25 | """
26 |
27 | def forward(self, x):
28 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
29 |
30 | weight = self.weight
31 | mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
32 | var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased=False))
33 | normalized_weight = (weight - mean) * (var + eps).rsqrt()
34 |
35 | return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
36 |
37 |
38 | class LayerNorm(nn.Module):
39 | def __init__(self, dim):
40 | super().__init__()
41 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
42 |
43 | def forward(self, x):
44 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
45 | var = torch.var(x, dim=1, unbiased=False, keepdim=True)
46 | mean = torch.mean(x, dim=1, keepdim=True)
47 | return (x - mean) * (var + eps).rsqrt() * self.g
48 |
49 |
50 | class PreNorm(nn.Module):
51 | def __init__(self, dim, fn):
52 | super().__init__()
53 | self.fn = fn
54 | self.norm = LayerNorm(dim)
55 |
56 | def forward(self, x):
57 | x = self.norm(x)
58 | return self.fn(x)
59 |
60 |
61 | class SinusoidalPositionalEmbedding(nn.Module):
62 | def __init__(self, dim):
63 | super().__init__()
64 | self.dim = dim
65 |
66 | def forward(self, x):
67 | device = x.device
68 | half_dim = self.dim // 2
69 | emb = math.log(10000) / (half_dim - 1)
70 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
71 | emb = x[:, None] * emb[None, :]
72 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
73 | return emb
74 |
75 |
76 | class LearnedSinusoidalPositionalEmbedding(nn.Module):
77 | """ following @crowsonkb 's lead with learned sinusoidal pos emb """
78 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
79 |
80 | def __init__(self, dim):
81 | super().__init__()
82 | assert (dim % 2) == 0
83 | half_dim = dim // 2
84 | self.weights = nn.Parameter(torch.randn(half_dim))
85 |
86 | def forward(self, x):
87 | x = rearrange(x, 'b -> b 1')
88 | frequencies = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
89 | fouriered = torch.cat((frequencies.sin(), frequencies.cos()), dim=-1)
90 | fouriered = torch.cat((x, fouriered), dim=-1)
91 | return fouriered
92 |
93 |
94 | class BaseBlock(nn.Module):
95 | def __init__(self, dim, dim_out, groups=8):
96 | super().__init__()
97 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
98 | self.norm = nn.GroupNorm(groups, dim_out)
99 | self.act = nn.SiLU()
100 |
101 | def forward(self, x, scale_shift=None):
102 | x = self.proj(x)
103 | x = self.norm(x)
104 |
105 | if exists(scale_shift):
106 | scale, shift = scale_shift
107 | x = x * (scale + 1) + shift
108 |
109 | x = self.act(x)
110 | return x
111 |
112 |
113 | class ResnetBlock(nn.Module):
114 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
115 | super().__init__()
116 | self.mlp = nn.Sequential(
117 | nn.SiLU(),
118 | nn.Linear(time_emb_dim, dim_out * 2)
119 | ) if exists(time_emb_dim) else None
120 |
121 | self.block1 = BaseBlock(dim, dim_out, groups=groups)
122 | self.block2 = BaseBlock(dim_out, dim_out, groups=groups)
123 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
124 |
125 | def forward(self, x, time_emb=None):
126 | scale_shift = None
127 | if exists(self.mlp) and exists(time_emb):
128 | time_emb = self.mlp(time_emb)
129 | time_emb = rearrange(time_emb, 'b c -> b c 1 1')
130 | scale_shift = time_emb.chunk(2, dim=1)
131 |
132 | h = self.block1(x, scale_shift=scale_shift)
133 |
134 | h = self.block2(h)
135 |
136 | return h + self.res_conv(x)
137 |
138 |
139 | class LinearAttention(nn.Module):
140 | def __init__(self, dim, heads=4, dim_head=32):
141 | super().__init__()
142 | self.scale = dim_head ** -0.5
143 | self.heads = heads
144 | hidden_dim = dim_head * heads
145 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
146 |
147 | self.to_out = nn.Sequential(
148 | nn.Conv2d(hidden_dim, dim, 1),
149 | LayerNorm(dim)
150 | )
151 |
152 | def forward(self, x):
153 | b, c, h, w = x.shape
154 | qkv = self.to_qkv(x).chunk(3, dim=1)
155 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
156 |
157 | q = q.softmax(dim=-2)
158 | k = k.softmax(dim=-1)
159 |
160 | q = q * self.scale
161 | v = v / (h * w)
162 |
163 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
164 |
165 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
166 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)
167 | return self.to_out(out)
168 |
169 |
170 | class Attention(nn.Module):
171 | def __init__(self, dim, heads=4, dim_head=32, scale=10):
172 | super().__init__()
173 | self.scale = scale
174 | self.heads = heads
175 | hidden_dim = dim_head * heads
176 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
177 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
178 |
179 | def forward(self, x):
180 | b, c, h, w = x.shape
181 | qkv = self.to_qkv(x).chunk(3, dim=1)
182 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
183 |
184 | q, k = map(compute_L2_norm, (q, k))
185 |
186 | sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
187 | attn = sim.softmax(dim=-1)
188 | out = einsum('b h i j, b h d j -> b h i d', attn, v)
189 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
190 | # Project the attended values back to the input feature dimensionality
191 | return self.to_out(out)
192 |
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/diffusion_models/denoising_diffusion_pytorch.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from random import random
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import wandb
7 | from einops import reduce
8 | from torch import nn
9 | from tqdm.auto import tqdm
10 |
11 | from denoising_diffusion_pytorch import utils
12 | from denoising_diffusion_pytorch.utils import normalise_to_negative_one_to_one, \
13 | unnormalise_to_zero_to_one, extract, linear_beta_schedule, cosine_beta_schedule, default
14 |
15 | # Constants
16 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
17 |
18 |
19 | # TODO add documentation/descriptions to every method and class
20 | # TODO generate model diagram and per-layer parameter count
21 |
22 |
23 | class GaussianDiffusion(nn.Module):
24 | def __init__(
25 | self,
26 | learning_model,
27 | *,
28 | image_size,
29 | timesteps=1000,
30 | sampling_timesteps=None,
31 | loss_type='l1',
32 | training_objective='pred_noise',
33 | beta_schedule='cosine',
34 | p2_loss_weight_gamma=0.,
35 | # p2 loss weight, from https://arxiv.org/abs/2204.00227
36 | # 0 is equivalent to a weight of 1 across time; a value of 1. is recommended
37 | p2_loss_weight_k=1,
38 | ddim_sampling_eta=1.
39 | ):
40 | """
41 | This class provides all important logic and behaviour for the base Gaussian Diffusion model.
42 |
43 | :param learning_model: The model used for learning the forwards diffusion process from x_T to x_0.
44 | This is typically a U-Net model, inline with the literature.
45 | :param image_size: The single dimension for the output image.
46 | For example, a value of 32 will produce a 32x32 pixel image output.
47 | :param timesteps: The number of timesteps to be used for the forward and reverse processes of the model.
48 | :param sampling_timesteps: The number of timesteps to be used for sampling.
49 | If this is less than param timesteps, then we are using Improved DDPM.
50 | :param loss_type: The type of loss we will use. This can be either L1 or L2 loss.
51 | :param training_objective: The objective that dictates what the model attempts to learn.
52 | This must be either pred_noise to learn noise, or pred_x0 to learn the truth image.
53 | """
54 | super().__init__()
55 | loss_type = loss_type.lower()
56 | training_objective = training_objective.lower()
57 | beta_schedule = beta_schedule.lower()
58 |
59 | assert loss_type in ['l1', 'l2'], f'The specified loss type, {loss_type}, must be either L1 or L2.'
60 | assert training_objective in ['pred_noise', 'pred_x0'], \
61 | 'The given objective must be either pred_noise (predict noise) or pred_x0 (predict image start)'
62 | assert beta_schedule in ['linear', 'cosine'], f'The given beta schedule {beta_schedule} is invalid!'
63 | assert not (type(self) != GaussianDiffusion and learning_model.channels == learning_model.out_dim)
64 | # TODO add an assertion error message
65 | assert sampling_timesteps is None or 0 < sampling_timesteps <= timesteps, \
66 | 'The given sampling timesteps value is invalid!'
67 |
68 | self.learning_model = learning_model
69 | self.channels = self.learning_model.channels
70 | self.self_condition = self.learning_model.self_condition
71 | self.image_size = image_size
72 | self.objective = training_objective
73 |
74 | betas = linear_beta_schedule(timesteps) if beta_schedule == 'linear' else cosine_beta_schedule(timesteps)
75 | alphas = 1. - betas
76 | alphas_cumprod = torch.cumprod(alphas, axis=0)
77 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
78 |
79 | timesteps, = betas.shape
80 | self.num_timesteps = int(timesteps)
81 | self.loss_type = loss_type
82 |
83 | # Sampling-related parameters
84 |
85 | self.sampling_timesteps = default(sampling_timesteps, timesteps)
86 | # The default number of sampling timesteps. Reduced for Improved DDPM
87 |
88 | assert self.sampling_timesteps <= timesteps
89 | self.is_ddim_sampling = self.sampling_timesteps < timesteps
90 | self.ddim_sampling_eta = ddim_sampling_eta
91 |
92 | # Helper function to convert function values from float64 to float32
93 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
94 |
95 | register_buffer('betas', betas)
96 | register_buffer('alphas_cumprod', alphas_cumprod)
97 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
98 |
99 | # calculations for diffusion q(x_t | x_{t-1}) and others
100 |
101 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
102 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
103 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
104 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
105 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
106 |
107 | # calculations for posterior q(x_{t-1} | x_t, x_0)
108 |
109 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
110 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
111 |
112 | register_buffer('posterior_variance', posterior_variance)
113 |
114 | # Log when the variance is clipped, as the posterior variance is zero
115 | # at the beginning of the diffusion chain (x_0 to x_T)
116 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
117 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
118 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
119 |
120 | # Calculate reweighting values for p2 loss
121 | register_buffer('p2_loss_weight',
122 | (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
123 |
124 | def predict_x0_from_noise(self, x_t, t, noise):
125 | return (
126 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
127 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
128 | )
129 |
130 | def predict_noise_from_xt(self, x_t, t, x0):
131 | """
132 | This method recovers the Gaussian noise that relates the clean image x0 to the noisy sample x_t.
133 | :param x_t: The noisy sample at the given timestep t.
134 | :param t: The batched timestep indices at which x_t sits.
135 | """
136 | return ((extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
137 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape))
138 |
139 | def q_posterior(self, x_start, x_t, t):
140 | posterior_mean = (
141 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
142 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
143 | )
144 | posterior_variance = extract(self.posterior_variance, t, x_t.shape)
145 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
146 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
147 |
148 | def model_predictions(self, x, t, x_self_cond=None):
149 | model_output = self.learning_model(x, t, x_self_cond)
150 | predicted_noise, x_0 = None, None
151 |
152 | if self.objective == 'pred_noise':
153 | predicted_noise = model_output
154 | x_0 = self.predict_x0_from_noise(x, t, model_output)
155 |
156 | elif self.objective == 'pred_x0':
157 | predicted_noise = self.predict_noise_from_xt(x, t, model_output)
158 | x_0 = model_output # The output of the model, x0
159 |
160 | return ModelPrediction(predicted_noise, x_0)
161 |
162 | def p_mean_variance(self, x, t, x_self_cond=None, clip_denoised=True):
163 | """
164 | This method computes the mean and variance by sampling directly from the model.
165 | """
166 | preds = self.model_predictions(x, t, x_self_cond)
167 | x_start = preds.pred_x_start
168 |
169 | if clip_denoised:
170 | x_start.clamp_(-1., 1.)
171 |
172 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
173 | return model_mean, posterior_variance, posterior_log_variance, x_start
174 |
175 | @torch.no_grad()
176 | def compute_sample_for_timestep(self, x, t: int, x_self_cond=None, clip_denoised=True):
177 | """
178 | This method takes a single step in the sampling/forwards process. For example, in a forwards process with 250
179 | sampling steps, this method will be called 250 times.
180 | :param x: The current image for the given timestep t. If t=T, then we are at the beginning of the forwards
181 | process, and x will be an isotropic Gaussian noise sample.
182 | :param t: The current sampling timestep that defines x_T, where 0 <= t <= T.
183 | """
184 | b, *_, device = *x.shape, x.device
185 | # batched_times = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
186 | batched_times = torch.full((x.shape[0],), t, dtype=torch.long)
187 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x=x, t=batched_times, x_self_cond=x_self_cond,
188 | clip_denoised=clip_denoised)
189 | noise = torch.randn_like(x) if t > 0 else 0.
190 | # Reset noise if t == zero, i.e., if we now have the output image of the model.
191 | # This nulls-out the following operation, preserving the output image.
192 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
193 | return pred_img, x_start
194 |
195 | @torch.no_grad()
196 | def compute_complete_sample(self, shape, device=None):
197 | """
198 | This method simply runs a for loop used for computing the series of samples from x_T through to x_0.
199 | The returned value is the final output image of the model.
200 |
201 | A progress bar is provided by the tqdm library throughout.
202 |
203 | :param shape: The shape of the batch of samples to generate, i.e.
204 | (batch_size, n_channels, image_size, image_size), determining how
205 | many samples should be generated in one iteration.
206 | :param device: The device to use for computing the samples on.
207 | """
208 | img = torch.randn(shape, device=default(device, self.betas.device))
209 | # img = torch.randn(shape)
210 |
211 | x_start = None
212 |
213 | for t in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling loop time step', total=self.num_timesteps):
214 | self_cond = x_start if self.self_condition else None
215 | img, x_start = self.compute_sample_for_timestep(img, t, self_cond)
216 |
217 | img = unnormalise_to_zero_to_one(img)
218 | return img
219 |
220 | @torch.no_grad()
221 | def ddim_sample(self, shape, clip_denoised=True):
222 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
223 | 0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
224 |
225 | times = torch.linspace(0., total_timesteps, steps=sampling_timesteps + 2)[:-1]
226 | times = list(reversed(times.int().tolist()))
227 | time_pairs = list(zip(times[:-1], times[1:]))
228 |
229 | img = torch.randn(shape, device=device)
230 |
231 | # Begin image, xT, sampled as random noise
232 | # TODO need a way to specify a noise sample from the EEG forward process,
233 | # not to randomly generate it
234 |
235 | x_start = None
236 |
237 | for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
238 | alpha = self.alphas_cumprod[time]
239 | alpha_next = self.alphas_cumprod[time_next]
240 |
241 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
242 |
243 | self_cond = x_start if self.self_condition else None
244 |
245 | pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond)
246 |
247 | if clip_denoised:
248 | x_start.clamp_(-1., 1.)
249 |
250 | sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
251 | c = ((1 - alpha_next) - sigma ** 2).sqrt()
252 |
253 | noise = torch.randn_like(img) if time_next > 0 else 0.
254 |
255 | img = x_start * alpha_next.sqrt() + \
256 | c * pred_noise + \
257 | sigma * noise
258 |
259 | img = utils.unnormalise_to_zero_to_one(img)
260 | return img
261 |
262 | @torch.no_grad()
263 | def sample(self, batch_size):
264 | """
265 | This method computes a given number of samples from the model in one sampling step.
266 | This method does not execute the many inferences required to draw a sample
267 | but instead determines which sampling strategy to use (standard or reduced sample step count),
268 | image output dimensions and batch sizes.
269 | """
270 | batch_size = 16 if batch_size is None else batch_size # default value
271 | image_size, channels = self.image_size, self.channels
272 | sampling_function = self.ddim_sample if self.is_ddim_sampling else self.compute_complete_sample
273 | return sampling_function((batch_size, channels, image_size, image_size))
274 |
275 | @torch.no_grad()
276 | def interpolate(self, x1, x2, t=None, lam=0.5):
277 | b, *_, device = *x1.shape, x1.device
278 | t = default(t, self.num_timesteps - 1)
279 |
280 | assert x1.shape == x2.shape
281 |
282 | t_batched = torch.stack([torch.tensor(t, device=device)] * b)
283 | # t_batched = torch.stack([torch.tensor(t)] * b)
284 | xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
285 |
286 | img = (1 - lam) * xt1 + lam * xt2
287 | for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
288 | img = self.compute_sample_for_timestep(img, torch.full((b,), i, device=device, dtype=torch.long))
289 | # img = self.compute_sample_for_timestep(img, torch.full((b,), i, dtype=torch.long))
290 |
291 | return img
292 |
293 | def q_sample(self, x_start, t, noise=None):
294 | noise = default(noise, lambda: torch.randn_like(x_start))
295 | # Uses the supplied noise value if it exists,
296 | # or generates more Gaussian noise with the same dimensions as x_start
297 |
298 | return (
299 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
300 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
301 | )
302 |
303 | @property
304 | def loss_fn(self):
305 | """
306 | This function does not compute the loss for two given images (truth and predicted),
307 | but instead returns the appropriate loss function in accordance with the stated
308 | desired loss function.
309 |
310 | Given that this method returns a function, then any parameters supplied are
311 | naturally passed into that function, thus deriving the loss value.
312 | """
313 | if self.loss_type == 'l1':
314 | return F.l1_loss
315 | elif self.loss_type == 'l2':
316 | return F.mse_loss
317 | else:
318 | raise ValueError(f'invalid loss type {self.loss_type}')
319 |
320 | def compute_loss_for_timestep(self, eeg_sample, target_sample, timestep, noise=None):
321 | """
322 | This method computes the losses for a full pass of a given image through the model.
323 | This effectively performs the full noising then denoising/diffusion process.
324 | After this, the losses between the true image and the generated image (via the diffusion process)
325 | are compared, their losses computed, and returned.
326 |
327 | :param eeg_sample: The EEG image used as x_0, to which the forward-process noise is applied.
328 | :param target_sample: The ground-truth target image the model's prediction is compared against.
329 | :param timestep: The timestep to compute losses for. This can be updated linearly, or sampled randomly.
330 | """
331 | # b, c, h, w = x_start.shape # Commented out as these values are not used
332 | generated_noise = default(noise, lambda: torch.randn_like(eeg_sample))
333 | # Generates random normal/Gaussian noise with the same dimensions as the given input
334 |
335 | x = self.q_sample(x_start=eeg_sample, t=timestep, noise=generated_noise)
336 | # Warps the image by the noise we just generated
337 | # in accordance to our beta scheduling choice and current timestep t
338 |
339 | # If you are performing self-conditioning, then 50% of the training iterations
340 | # will predict x_start from the current timestep, t. This will then be used to
341 | # update U-Net's gradients with. This technique increases training time by 25%,
342 | # but appears to significantly lower the FID score of the model
343 | x_self_cond = None # TODO look into using this
344 | if self.self_condition and random() < 0.5:
345 | with torch.no_grad():
346 | x_self_cond = self.model_predictions(x, timestep).pred_x_start
347 | x_self_cond.detach_()
348 |
349 | # Next we predict the output according to our objective,
350 | # then compute the gradient from that result
351 | model_out = self.learning_model(x, timestep, x_self_cond) # The prediction of our model
352 |
353 | if self.objective == 'pred_noise':
354 | # If we are trying to predict the noise that was just added
355 |
356 | # The target is the same noise that was applied to eeg_sample in q_sample above
357 | target = generated_noise
358 | elif self.objective == 'pred_x0':
359 | # If we are trying to predict the original, true image, x_0
360 | # target = x_start #
361 | target = target_sample
362 | else:
363 | raise ValueError(f'unknown objective {self.objective}')
364 |
365 | loss = self.loss_fn(model_out, target, reduction='none')
366 |
367 | loss = reduce(loss, 'b ... -> b (...)', 'mean')
368 |
369 | loss = loss * extract(self.p2_loss_weight, timestep, loss.shape)
370 | wandb.log({"raw_losses": loss, "averaged_loss": loss.mean().item()})
371 | # TODO maybe log Inception Score and/or FID.
372 | return loss.mean()
373 |
374 | def forward(self, img, *args, **kwargs):
375 | """
376 | When calling model(x), this is the method that is called.
377 | This method takes an (EEG, target) image pair from the dataset and trains the diffusion & U-Net model on it.
378 |
379 | :param img: A tuple of (eeg_sample, target_sample); the EEG image is used as x_0 for the
380 | noising process, while the target image provides the ground truth for the loss.
381 | """
382 |
383 | # b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
384 | eeg_sample, target_sample = img
385 |
386 | b, c, h, w, device, img_size = *eeg_sample.shape, eeg_sample.device, self.image_size
387 |
388 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
389 | timestep = torch.randint(0, self.num_timesteps, (b,), device=device).long()
390 | # timestep = torch.randint(0, self.num_timesteps, (b,)).long()
391 |
392 | eeg_sample = normalise_to_negative_one_to_one(eeg_sample)
393 | target_sample = normalise_to_negative_one_to_one(target_sample)
394 | return self.compute_loss_for_timestep(eeg_sample, target_sample, timestep, *args, **kwargs)
395 |
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/entities/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/denoising_diffusion_pytorch/entities/__init__.py
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/entities/unet.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from denoising_diffusion_pytorch.diffusion_models.components import Residual, PreNorm, SinusoidalPositionalEmbedding, \
7 | LearnedSinusoidalPositionalEmbedding, ResnetBlock, LinearAttention, Attention
8 | from denoising_diffusion_pytorch.utils import upsample, downsample, default
9 |
10 |
11 | class Unet(nn.Module):
12 | def __init__(
13 | self,
14 | dim,
15 | init_dim=None,
16 | out_dim=None,
17 | dim_mults=(1, 2, 4, 8),
18 | channels=3,
19 | self_condition=False,
20 | resnet_block_groups=8,
21 | learned_variance=False,
22 | # TODO learned_variance seems to determine if we follow Improved DDPM and learn the variance alongside
23 | # the mean? Might be worth setting to True if so.
24 | learned_sinusoidal_cond=False,
25 | learned_sinusoidal_dim=16 # Per Improved DDPM
26 | ):
27 | """
28 | :param channels: The number of channels in the input images, e.g. 3 for RGB.
29 | """
30 | super().__init__()
31 |
32 | # determine dimensions
33 |
34 | self.channels = channels
35 | self.self_condition = self_condition
36 | input_channels = channels * (2 if self_condition else 1)
37 |
38 | init_dim = default(init_dim, dim)
39 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding=3)
40 |
41 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
42 | in_out = list(zip(dims[:-1], dims[1:]))
43 |
44 | block_klass = partial(ResnetBlock, groups=resnet_block_groups)
45 |
46 | # time embeddings
47 |
48 | time_dim = dim * 4
49 |
50 | self.learned_sinusoidal_cond = learned_sinusoidal_cond
51 |
52 | if learned_sinusoidal_cond: # TODO determine if we want to use this
53 | sinu_pos_emb = LearnedSinusoidalPositionalEmbedding(learned_sinusoidal_dim)
54 | fourier_dim = learned_sinusoidal_dim + 1
55 | else:
56 | sinu_pos_emb = SinusoidalPositionalEmbedding(dim)
57 | fourier_dim = dim
58 |
59 | self.time_mlp = nn.Sequential(
60 | sinu_pos_emb,
61 | nn.Linear(fourier_dim, time_dim),
62 | nn.GELU(),
63 | nn.Linear(time_dim, time_dim)
64 | )
65 |
66 | # layers
67 |
68 | self.downs = nn.ModuleList([])
69 | self.ups = nn.ModuleList([])
70 | num_resolutions = len(in_out)
71 |
72 | for ind, (dim_in, dim_out) in enumerate(in_out):
73 | is_last = ind >= (num_resolutions - 1)
74 |
75 | self.downs.append(nn.ModuleList([
76 | block_klass(dim_in, dim_in, time_emb_dim=time_dim),
77 | block_klass(dim_in, dim_in, time_emb_dim=time_dim),
78 | Residual(PreNorm(dim_in, LinearAttention(dim_in))),
79 | downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1)
80 | ]))
81 |
82 | mid_dim = dims[-1]
83 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
84 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
85 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
86 |
87 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
88 | is_last = ind == (len(in_out) - 1)
89 |
90 | self.ups.append(nn.ModuleList([
91 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
92 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
93 | Residual(PreNorm(dim_out, LinearAttention(dim_out))),
94 | upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1)
95 | ]))
96 |
97 | default_out_dim = channels * (1 if not learned_variance else 2)
98 | self.out_dim = default(out_dim, default_out_dim)
99 |
100 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
101 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
102 |
103 | def forward(self, x, time, x_self_cond=None):
104 | if self.self_condition:
105 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
106 | x = torch.cat((x_self_cond, x), dim=1)
107 |
108 | x = self.init_conv(x)
109 | r = x.clone()
110 |
111 | t = self.time_mlp(time)
112 |
113 | h = []
114 |
115 | for block1, block2, attn, downsample in self.downs:
116 | x = block1(x, t)
117 | h.append(x)
118 |
119 | x = block2(x, t)
120 | x = attn(x)
121 | h.append(x)
122 |
123 | x = downsample(x)
124 |
125 | x = self.mid_block1(x, t)
126 | x = self.mid_attn(x)
127 | x = self.mid_block2(x, t)
128 |
129 | for block1, block2, attn, upsample in self.ups:
130 | x = torch.cat((x, h.pop()), dim=1)
131 | x = block1(x, t)
132 |
133 | x = torch.cat((x, h.pop()), dim=1)
134 | x = block2(x, t)
135 | x = attn(x)
136 |
137 | x = upsample(x)
138 |
139 | x = torch.cat((x, r), dim=1)
140 |
141 | x = self.final_res_block(x, t)
142 | return self.final_conv(x)
143 |
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import multiprocessing
3 | import os
4 | import random
5 | import shutil
6 | from collections import defaultdict
7 | from functools import partial
8 | from pathlib import Path
9 |
10 | import torch
11 | import wandb
12 | from PIL import Image
13 | from accelerate import Accelerator
14 | from ema_pytorch import EMA
15 | from torch import nn
16 | from torch.nn import functional as F
17 | from torch.optim import Adam
18 | from torch.utils.data import Dataset, DataLoader
19 | from torchvision import transforms as T, utils
20 | from tqdm.auto import tqdm
21 |
22 |
23 | def exists(x):
24 | """
25 | This method checks if the given parameter is not equal to None, i.e., if it has a value.
26 |
27 | :param x: The value to be checked for None-status.
28 | """
29 | return x is not None
30 |
31 |
32 | class GenericDataset(Dataset):
33 | def __init__(
34 | self,
35 | folder,
36 | image_size: int,
37 | exts: list = None,
38 | augment_horizontal_flip=False,
39 | convert_image_to=None
40 | ):
41 | """
42 | This class loads images from a given directory and resizes them to be square.
43 |
44 | :param folder: The folder that contains the files for this dataset.
45 | :param int image_size: The dimensions for the given image. All images will be converted to a square with these dimensions.
46 | :param list exts: A list of file extensions that this class should load, such as jpg and png.
47 | :param augment_horizontal_flip: If a horizontal (left-to-right) flip of the image should be performed.
48 | :param convert_image_to: A given lambda function specifying how to convert the input images
49 | for this dataset. This is applied before any other manipulations, such as resizing or
50 | horizontal flipping.
51 | """
52 | super().__init__()
53 | if exts is None:
54 | exts = ['jpg', 'jpeg', 'png', 'tiff']
55 | self.folder = folder
56 | self.image_size = image_size
57 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
58 |
59 | lambda_convert_function = partial(convert_image_to_fn, convert_image_to) \
60 | if exists(convert_image_to) \
61 | else nn.Identity()  # partial(...) fixes the target image mode, leaving a one-argument converter
62 | # nn.Identity simply returns the input.
63 | # So, if convert_image_to is None,
64 | # lambda_convert_function will just return whatever is input to it
65 |
66 | self.transform = T.Compose([
67 | T.Lambda(lambda_convert_function),
68 | # Execute some lambda of code to convert an image of the dataset by, such as greyscaling.
69 | T.Resize(image_size),
70 | T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
71 | T.CenterCrop(image_size),
72 | T.ToTensor()
73 | ])
74 |
75 | def __len__(self):
76 | """
77 | Return the number of images in this dataset.
78 | """
79 | return len(self.paths)
80 |
81 | def __getitem__(self, index):
82 | """
83 | Return the given image for the given index in this dataset.
84 | """
85 | path = self.paths[index]
86 | img = Image.open(path)
87 | return self.transform(img)
88 |
89 |
90 | def find_and_move_unsorted(src_dir, ftype: str):
91 | if len(os.listdir(f'{src_dir}/unsorted')) <= 0:
92 | return  # Nothing to sort; leave the unsorted directory untouched
93 | # for path in Path(f'{src_dir}/unsorted').rglob(f'*.{ftype}'):
94 | for path in Path(f'{src_dir}/unsorted').rglob(f'*.*'):
95 | name = str(path).lower().split('/')[-1]
96 | p = None
97 | p = 'penguin' if 'penguin' in name else p
98 | p = 'guitar' if 'guitar' in name else p
99 | p = 'flower' if 'flower' in name else p
100 |
101 | assert p is not None, f'Could not sort file at path: {path}'
102 | print('Moving', name, 'to', src_dir, p)
103 | os.makedirs(f'{src_dir}/{p}', exist_ok=True)
104 | os.rename(path, f'{src_dir}/{p}/{name}')
105 |
106 | shutil.rmtree(f'{src_dir}/unsorted')
107 | os.mkdir(f'{src_dir}/unsorted')
108 |
109 |
110 | class EEGTargetsDataset(Dataset):
111 | def __init__(self,
112 | eeg_directory='./datasets/eeg',
113 | targets_directory='./datasets/targets',
114 | labels: list = None,
115 | shuffle_eeg=True,
116 | shuffle_targets=True,
117 | file_types=None,
118 | unsorted_eeg_policy=None,
119 | unsorted_target_policy=None,
120 | image_size=[32, 32]):
121 | """
122 | This class loads a dataset of EEG and target image pairs, and then determines the correct class label for each one.
123 |
124 | A typical directory structure is as follows:
125 | datasets/
126 | ├── eeg
127 | │   ├── flower
128 | │   ├── guitar
129 | │   ├── penguin
130 | │   └── unsorted
131 | └── targets
132 |     ├── flower
133 |     ├── guitar
134 |     ├── penguin
135 |     └── unsorted
136 |
137 | It is expected that the directories at the bottom layer are named after the label they contain.
138 | It is expected that the names match between the `eeg` and `targets` directory.
139 |
140 | Any files and directories in the 'unsorted' directory will be searched recursively
141 | and moved to the appropriate directory. Their appropriate directory will be inferred
142 | by the file's name.
143 |
144 |
145 | :param eeg_directory: The directory containing EEG images for training. These will be loaded recursively.
146 | :param targets_directory: The directory containing target images for training. These will be loaded recursively.
147 | :param labels: The list of labels to be used for training. If left blank, this defaults to guitar, penguin and flower.
148 | :param shuffle_eeg: If the EEG training images should be sampled from randomly.
149 | :param shuffle_targets: If the target training images for each respective EEG image should be sampled from randomly.
150 | :param file_types: The list of file types to load.
151 | """
152 | if unsorted_target_policy is None:
153 | unsorted_target_policy = ['move', 'delete-src-dirs']
154 | if unsorted_eeg_policy is None:
155 | unsorted_eeg_policy = ['move', 'delete-src-dirs']
156 | if labels is None:
157 | labels = ['penguin', 'guitar', 'flower']
158 |
159 | if file_types is None:
160 | file_types = ['jpg', 'png']
161 |
162 | self.file_types = file_types
163 | self.labels = labels
164 | self.eeg_directory = eeg_directory
165 | self.targets_directory = targets_directory
166 | self.shuffle_eeg = shuffle_eeg
167 | self.shuffle_targets = shuffle_targets
168 | self.data = defaultdict(lambda: defaultdict(list))
169 | self.indices = {'eeg': {}, 'target': {}}
170 | self.image_size = image_size
171 |
172 | # TODO load eeg recursively, load targets recursively, generate labels for each, group them together
173 |
174 | os.makedirs(eeg_directory, exist_ok=True)
175 | os.makedirs(targets_directory, exist_ok=True)
176 |
177 | for label in labels:
178 | assert os.path.exists(f'{eeg_directory}/{label}'), \
179 | f'The EEG directory for `{label}` does not exist.'
180 | assert os.path.exists(f'{targets_directory}/{label}'), \
181 | f'The targets directory for `{label}` does not exist.'
182 | #
183 | # for ftype in file_types:
184 | # find_and_move_unsorted(eeg_directory, ftype)
185 | # find_and_move_unsorted(targets_directory, ftype)
186 |
187 | for label in labels:
188 | d0 = os.listdir(f'{eeg_directory}/{label}')
189 | d1 = os.listdir(f'{targets_directory}/{label}')
190 |
191 | d0 = [d for d in d0 if not d.startswith('.')]
192 | d1 = [d for d in d1 if not d.startswith('.')]
193 |
194 | if self.shuffle_eeg:
195 | random.shuffle(d0)
196 |
197 | if self.shuffle_targets:
198 | random.shuffle(d1)
199 |
200 | self.data['eeg'][label] = d0
201 | self.data['targets'][label] = d1
202 |
203 | # print(f'Loaded {len(d0)} images for EEG/{label}')
204 | # print(f'Loaded {len(d1)} images for Targets/{label}')
205 |
206 | self.transformTarget = T.Compose([
207 | T.RandomHorizontalFlip(),
208 | T.Resize(self.image_size),
209 | T.ToTensor()
210 | ])
211 |
212 | self.transformEEG = T.Compose([
213 | T.Resize(self.image_size),
214 | T.ToTensor()
215 | ])
216 |
217 | def __getitem__(self, index):
218 | """
219 | Retrieve an item in the dataset at the given index. If shuffling is enabled, the given index is ignored.
220 |
221 | The values returned are in a tuple, and in the order of:
222 | 1. EEG image for label `L`.
223 | 2. Target image for label `L`.
224 | 3. Label `L`.
225 | """
226 | label = random.choice(self.labels)
227 | eeg_sample = random.choice(self.data['eeg'][label])
228 | target_sample = random.choice(self.data['targets'][label])
229 |
230 | eeg_sample = Image.open(f'{self.eeg_directory}/{label}/{eeg_sample}')
231 | target_sample = Image.open(f'{self.targets_directory}/{label}/{target_sample}').convert('RGB')
232 |
233 | if target_sample.mode in ('RGBA', 'LA') or (target_sample.mode == 'P' and 'transparency' in target_sample.info):
234 | alpha = target_sample.convert('RGBA').split()[-1]
235 | bg = Image.new("RGBA", target_sample.size, (255, 255, 255) + (255,))
236 | bg.paste(target_sample, mask=alpha)
237 | target_sample = bg
238 |
239 | eeg_sample = self.transformEEG(eeg_sample)
240 | eeg_sample = eeg_sample.repeat(3, 1, 1)
241 | return eeg_sample, self.transformTarget(target_sample), label
242 |
243 | def __len__(self):
244 | return sum(len(d[label]) for label in self.labels for d in self.data.values())
245 |
246 |
247 | class Trainer(object):
248 | """
249 | This class is responsible for the training, sampling and saving loops that
250 | take place when interacting with the model.
251 | """
252 |
253 | def __init__(
254 | self,
255 | diffusion_model,
256 | training_images_dir, # Folder where the training images exist
257 | # TODO add a secondary folder for the target image-class pairings
258 | # TODO add a means of loading and reading in those image-class pairings
259 | *,
260 | train_batch_size=16,
261 | gradient_accumulate_every=1,
262 | augment_horizontal_flip=True, # Flip image from left-to-right
263 | training_learning_rate=1e-4,
264 | num_training_steps=100000,
265 | ema_update_every=10,
266 | ema_decay=0.995,
267 | adam_betas=(0.9, 0.99),
268 | save_and_sample_every=1000,
269 | num_samples=25,
270 | results_folder='./results',
271 | amp=False, # Use mixed precision during training
272 | fp16=False, # Use Floating-Point 16-bit precision
273 | # TODO might be able to enable fp16 without affecting amp,
274 | # allowing for the model to train on TPUs
275 | split_batches=True,
276 | convert_image_to_ext=None, # A given extension to convert image types to
277 | use_wandb=True
278 | ):
279 | """
280 | :param split_batches: If the batch of images loaded should be split by
281 | accelerator across all devices, or treated as a per-device batch count.
282 | For example, with a batch size of 32 and 8 devices, split_batches=True
283 | would put 4 items on each device.
284 | """
285 | super().__init__()
286 |
287 | self.accelerator = Accelerator(
288 | split_batches=split_batches,
289 | mixed_precision='fp16' if fp16 else 'no',
290 | gradient_accumulation_steps=gradient_accumulate_every,
291 | device_placement=True
292 | )
293 |
294 | self.accelerator.native_amp = amp
295 |
296 | assert has_int_square_root(num_samples), 'The number of samples must have an integer square root'
297 |
298 | self.diffusion_model = diffusion_model
299 | self.num_samples = num_samples
300 | self.save_and_sample_every = save_and_sample_every
301 | self.batch_size = train_batch_size
302 | self.gradient_accumulate_every = gradient_accumulate_every
303 | self.train_num_steps = num_training_steps
304 | self.image_size = diffusion_model.image_size
305 |
306 | self.train_images_dataset = EEGTargetsDataset()
307 | dataloader = DataLoader(self.train_images_dataset,
308 | batch_size=train_batch_size,
309 | shuffle=True,
310 | pin_memory=True,
311 | num_workers=multiprocessing.cpu_count())
312 |
313 | # dataloader = self.accelerator.prepare(dataloader)
314 | # self.train_images_dataloader = cycle(dataloader)
315 | dataloader = self.accelerator.prepare(dataloader)
316 | self.train_eeg_targets_dataloader = cycle(dataloader)
317 |
318 | # optimizer
319 |
320 | self.optimiser = Adam(diffusion_model.parameters(), lr=training_learning_rate, betas=adam_betas)
321 |
322 | # for logging results in a folder periodically
323 |
324 | if self.accelerator.is_main_process:
325 | self.ema = EMA(diffusion_model, beta=ema_decay, update_every=ema_update_every)
326 |
327 | self.results_folder = Path(f'{results_folder}/{wandb.run.name}-{wandb.run.id}')
328 | self.results_folder.mkdir(exist_ok=True, parents=True)
329 |
330 | # Step counter
331 | self.step = 0
332 |
333 | self.diffusion_model, self.optimiser = self.accelerator.prepare(self.diffusion_model, self.optimiser)
334 | self.diffusion_model.learning_model = self.accelerator.prepare(self.diffusion_model.learning_model)
335 | # wandb.login(key=os.environ['WANDB_API_KEY']) # Uncomment if `wandb login` does not work in the console
336 |
337 | def save(self, milestone):
338 | if not self.accelerator.is_local_main_process:
339 | return
340 |
341 | data = {
342 | 'step': self.step,
343 | 'model': self.accelerator.get_state_dict(self.diffusion_model),
344 | 'opt': self.optimiser.state_dict(),
345 | 'ema': self.ema.state_dict(),
346 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
347 | }
348 |
349 | torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
350 |
351 | def load(self, path, milestone):
352 | data = torch.load(str(f'{path}/model-{milestone}.pt'))
353 |
354 | model = self.accelerator.unwrap_model(self.diffusion_model)
355 | model.load_state_dict(data['model'])
356 |
357 | self.step = data['step']
358 | self.optimiser.load_state_dict(data['opt'])
359 | self.ema.load_state_dict(data['ema'])
360 |
361 | if exists(self.accelerator.scaler) and exists(data['scaler']):
362 | self.accelerator.scaler.load_state_dict(data['scaler'])
363 |
364 | def train(self):
365 | accelerator = self.accelerator
366 | device = accelerator.device
367 |
368 | with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:
369 |
370 | while self.step < self.train_num_steps:
371 |
372 | total_loss = 0.
373 |
374 | for _ in range(self.gradient_accumulate_every):
375 | eeg_sample, target_sample, _ = next(self.train_eeg_targets_dataloader)
376 | # eeg_sample.to(device)
377 | # target_sample.to(device)
378 | data = (eeg_sample, target_sample)
379 | # eeg_sample, target_sample, label
380 |
381 | with self.accelerator.autocast():
382 | loss = self.diffusion_model(data)
383 | loss = loss / self.gradient_accumulate_every
384 | # Accumulate the scaled loss as a plain float for logging
385 | total_loss += loss.item()
386 |
387 | self.accelerator.backward(loss)
388 |
389 | wandb.log({'total_training_loss': total_loss, 'training_timestep': self.step})
390 | pbar.set_description(f'loss: {total_loss:.4f}')
391 |
392 | accelerator.wait_for_everyone()
393 |
394 | self.optimiser.step()
395 | self.optimiser.zero_grad()
396 |
397 | accelerator.wait_for_everyone()
398 |
399 | if accelerator.is_main_process:
400 | self.ema.to(device)
401 | self.ema.update()
402 |
403 | if self.step != 0 and self.step % self.save_and_sample_every == 0:
404 | self.ema.ema_model.eval()
405 |
406 | with torch.no_grad():
407 | milestone = self.step // self.save_and_sample_every
408 | batches = num_to_groups(self.num_samples, self.batch_size)
409 | all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
410 |
411 | all_images = torch.cat(all_images_list, dim=0)
412 | utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'),
413 | nrow=int(math.sqrt(self.num_samples)))
414 | self.save(milestone)
415 |
416 | self.step += 1
417 | pbar.update(1)
418 |
419 | accelerator.print('Training complete!')
420 |
421 |
422 | def cycle(dataloader):
423 | while True:
424 | for data in dataloader:
425 | yield data
426 |
427 |
428 | def has_int_square_root(num):
429 | return int(math.sqrt(num)) ** 2 == num
430 |
431 |
432 | def num_to_groups(num, divisor):
433 | groups = num // divisor
434 | remainder = num % divisor
435 | arr = [divisor] * groups
436 | if remainder > 0:
437 | arr.append(remainder)
438 | return arr
439 |
440 |
441 | def convert_image_to_fn(img_type, image):
442 | if image.mode != img_type:
443 | return image.convert(img_type)
444 | return image
445 |
446 |
447 | def compute_L2_norm(t):
448 | """
449 | L2-normalise the given tensor along its final dimension.
450 | :param t: The tensor to be L2-normalised.
451 | """
452 | return F.normalize(t, dim=-1)
453 |
454 |
455 | def normalise_to_negative_one_to_one(img):
456 | return img * 2 - 1
457 |
458 |
459 | def unnormalise_to_zero_to_one(t):
460 | return (t + 1) * 0.5
461 |
462 |
463 | def upsample(dim, dim_out=None):
464 | return nn.Sequential(
465 | nn.Upsample(scale_factor=2, mode='nearest'),
466 | nn.Conv2d(dim, default(dim_out, dim), 3, padding=1)
467 | )
468 |
469 |
470 | def downsample(input_channel_dims, output_channel_dims=None):
471 | """
472 | This method creates a downsampling convolutional layer of the U-Net architecture.
473 |
474 | :param input_channel_dims: The channel dimensions for the input to the layer.
475 | :param output_channel_dims: The channel dimensions for the output of the layer.
476 | """
477 | return nn.Conv2d(
478 | in_channels=input_channel_dims,
479 | out_channels=default(output_channel_dims, input_channel_dims),
480 | kernel_size=4,
481 | stride=2,
482 | padding=1)
483 |
484 |
485 | def extract(a, t, x_shape):
486 | b, *_ = t.shape
487 | out = a.gather(-1, t)
488 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
489 |
490 |
491 | def linear_beta_schedule(timesteps):
492 | scale = 1000 / timesteps
493 | beta_start = scale * 0.0001
494 | beta_end = scale * 0.02
495 | return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
496 |
497 |
498 | def cosine_beta_schedule(timesteps, s=0.008):
499 | """
500 | Cosine beta schedule
501 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
502 | :param s: A small offset that shifts the cosine schedule.
503 | A larger value makes beta begin higher (avoiding vanishingly small betas near t=0), while a smaller value makes it begin lower.
504 | """
505 | steps = timesteps + 1
506 | x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
507 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
508 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
509 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
510 | return torch.clip(betas, 0, 0.999)
511 |
512 |
513 | def default(val, func):
514 | """
515 | This method simply checks if the parameter val exists.
516 | If it does not, parameter func is either returned,
517 | or executed if it is itself a function.
518 |
519 | :param val: The value to be checked for its existence. See method exists for more.
520 | :param func: The value or function to be returned or executed (respectively) if val does not exist.
521 | """
522 | if exists(val):
523 | return val
524 | return func() if callable(func) else func
525 |
--------------------------------------------------------------------------------
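As a quick, illustrative check of the sampling helpers defined in the module above (the numbers mirror the default num_samples=25 and are not from a real run):

# 25 samples drawn with a batch size of 8 are issued as four sampling calls:
print(num_to_groups(25, 8))     # [8, 8, 8, 1]

# The square-grid guard behind the Trainer's num_samples assertion:
print(has_int_square_root(25))  # True  (a 5x5 sample grid)
print(has_int_square_root(24))  # False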
/images/denoising-diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/images/denoising-diffusion.png
--------------------------------------------------------------------------------
/images/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/images/sample.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import wandb
4 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
5 |
6 | torch.cuda.empty_cache()
7 | wandb.login()
8 |
9 | # wandb.config.learning_rate = 3e-4
10 | # wandb.config.training_timesteps = 5000
11 | # wandb.config.sampling_timesteps = 250
12 | # wandb.config.image_size = 32
13 | # wandb.config.number_of_samples = 25
14 | # wandb.config.batch_size = 512
15 | # wandb.config.use_amp = False
16 | # wandb.config.use_fp16 = True
17 | # wandb.config.gradient_accumulation_rate = 2
18 | # wandb.config.ema_update_rate = 10
19 | # wandb.config.ema_decay = 0.995
20 | # wandb.config.adam_betas = (0.9, 0.99)
21 | # wandb.config.save_and_sample_rate = 1000
22 | # wandb.config.do_split_batches = False
23 | # wandb.config.timesteps = 1000
24 | # wandb.config.loss_type = 'L2'
25 | # wandb.config.unet_dim = 16
26 | # wandb.config.unet_mults = (1, 2, 4, 8)
27 | # wandb.config.unet_channels = 3
28 | # wandb.config.training_objective = 'pred_x0'
29 |
30 | default_hypers = dict(
31 | learning_rate=3e-4,
32 | training_timesteps=1001,
33 | sampling_timesteps=250,
34 | image_size=32,
35 | number_of_samples=25,
36 | batch_size=256,
37 | use_amp=False,
38 | use_fp16=False,
39 | gradient_accumulation_rate=2,
40 | ema_update_rate=10,
41 | ema_decay=0.995,
42 | adam_betas=(0.9, 0.99),
43 | save_and_sample_rate=1000,
44 | do_split_batches=False,
45 | timesteps=1000,
46 | loss_type='L2',
47 | unet_dim=16,
48 | unet_mults=(1, 2, 4, 8),
49 | unet_channels=3,
50 | training_objective='pred_x0'
51 | )
52 |
53 | wandb.init(config=default_hypers, project='bath-thesis', entity='jd202')
54 |
55 | # with open('./sweep.yaml') as f:
56 | # sweep_config = yaml.load(f, Loader=SafeLoader)
57 | #
58 | # sweep_id = wandb.sweep(sweep_config, entity='jd202', project='bath-thesis')
59 |
60 |
61 | model = Unet(
62 | dim=wandb.config.unet_dim,
63 | dim_mults=wandb.config.unet_mults,
64 | channels=wandb.config.unet_channels
65 | )
66 |
67 | diffusion = GaussianDiffusion(
68 | model,
69 | image_size=wandb.config.image_size,
70 | timesteps=wandb.config.timesteps, # number of steps
71 | sampling_timesteps=wandb.config.sampling_timesteps,
72 | # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
73 | loss_type=wandb.config.loss_type, # L1 or L2
74 | training_objective=wandb.config.training_objective
75 | )
76 |
77 | trainer = Trainer(
78 | diffusion,
79 | '/Users/jake/Desktop/scp/cifar',
80 | train_batch_size=wandb.config.batch_size,
81 | training_learning_rate=wandb.config.learning_rate,
82 | num_training_steps=wandb.config.training_timesteps, # total training steps
83 | num_samples=wandb.config.number_of_samples,
84 | gradient_accumulate_every=wandb.config.gradient_accumulation_rate, # gradient accumulation steps
85 | ema_update_every=wandb.config.ema_update_rate,
86 | ema_decay=wandb.config.ema_decay, # exponential moving average decay
87 | amp=wandb.config.use_amp, # turn on mixed precision
88 | fp16=wandb.config.use_fp16,
89 | save_and_sample_every=wandb.config.save_and_sample_rate
90 | )
91 |
92 | trainer.load('./results/loadins', '17')
93 |
94 | wandb.watch(model)
95 | wandb.watch(diffusion)
96 |
97 | trainer.train()
98 |
99 | wandb.finish()
100 |
--------------------------------------------------------------------------------
/pytorch-xla-env-setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Sample usage:
3 | # python env-setup.py --version 1.11 --apt-packages libomp5
4 | import argparse
5 | import collections
6 | from datetime import datetime
7 | import os
8 | import platform
9 | import re
10 | import requests
11 | import subprocess
12 | import threading
13 | import sys
14 |
15 | VersionConfig = collections.namedtuple('VersionConfig',
16 | ['wheels', 'tpu', 'py_version', 'cuda_version'])
17 | DEFAULT_CUDA_VERSION = '11.2'
18 | OLDEST_VERSION = datetime.strptime('20200318', '%Y%m%d')
19 | NEW_VERSION = datetime.strptime('20220315', '%Y%m%d') # 1.11 release date
20 | OLDEST_GPU_VERSION = datetime.strptime('20200707', '%Y%m%d')
21 | DIST_BUCKET = 'gs://tpu-pytorch/wheels'
22 | TORCH_WHEEL_TMPL = 'torch-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl'
23 | TORCH_XLA_WHEEL_TMPL = 'torch_xla-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl'
24 | TORCHVISION_WHEEL_TMPL = 'torchvision-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl'
25 | VERSION_REGEX = re.compile(r'^(\d+\.)+\d+$')
26 |
27 | def is_gpu_runtime():
28 | return int(os.environ.get('COLAB_GPU', 0)) == 1
29 |
30 |
31 | def is_tpu_runtime():
32 | return 'TPU_NAME' in os.environ
33 |
34 |
35 | def update_tpu_runtime(tpu_name, version):
36 | print(f'Updating TPU runtime to {version.tpu} ...')
37 |
38 | try:
39 | import cloud_tpu_client
40 | except ImportError:
41 | subprocess.call([sys.executable, '-m', 'pip', 'install', 'cloud-tpu-client'])
42 | import cloud_tpu_client
43 |
44 | client = cloud_tpu_client.Client(tpu_name)
45 | client.configure_tpu_version(version.tpu)
46 | print('Done updating TPU runtime')
47 |
48 |
49 | def get_py_version():
50 | version_tuple = platform.python_version_tuple()
51 | return version_tuple[0] + version_tuple[1] # major_version + minor_version
52 |
53 |
54 | def get_cuda_version():
55 | if is_gpu_runtime():
56 | # cuda available, install cuda wheels
57 | return DEFAULT_CUDA_VERSION
58 |
59 |
60 | def get_version(version):
61 | cuda_version = get_cuda_version()
62 | if version == 'nightly':
63 | return VersionConfig(
64 | 'nightly', 'pytorch-nightly', get_py_version(), cuda_version)
65 |
66 | version_date = None
67 | try:
68 | version_date = datetime.strptime(version, '%Y%m%d')
69 | except ValueError:
70 | pass # Not a dated nightly.
71 |
72 | if version_date:
73 | if cuda_version and version_date < OLDEST_GPU_VERSION:
74 | raise ValueError(
75 | f'Oldest nightly version build with CUDA available is {OLDEST_GPU_VERSION}')
76 | elif not cuda_version and version_date < OLDEST_VERSION:
77 | raise ValueError(f'Oldest nightly version available is {OLDEST_VERSION}')
78 | return VersionConfig(f'nightly+{version}', f'pytorch-dev{version}',
79 | get_py_version(), cuda_version)
80 |
81 |
82 | if not VERSION_REGEX.match(version):
83 | raise ValueError(f'{version} is an invalid torch_xla version pattern')
84 | return VersionConfig(
85 | version, f'pytorch-{version}', get_py_version(), cuda_version)
86 |
87 |
88 | def install_vm(version, apt_packages, is_root=False):
89 | dist_bucket = DIST_BUCKET
90 |
91 | if version.cuda_version:
92 | # Distributions for GPU runtime
93 | # Note: GPU wheels available from 1.11
94 | dist_bucket = os.path.join(
95 | DIST_BUCKET, 'cuda/{}'.format(version.cuda_version.replace('.', '')))
96 | else:
97 | # Distributions for TPU runtime
98 | # Note: this redirection is required for 1.11 & nightly releases
99 | # because the current 2 VM wheels are not compatible with colab environment.
100 | if version.wheels == 'nightly':
101 | dist_bucket = os.path.join(DIST_BUCKET, 'colab/')
102 | elif 'nightly+' in version.wheels:
103 | build_date = datetime.strptime( version.wheels.split('+')[1], '%Y%m%d')
104 | if build_date >= NEW_VERSION:
105 | dist_bucket = os.path.join(DIST_BUCKET, 'colab/')
106 | elif VERSION_REGEX.match(version.wheels):
107 | minor = int(version.wheels.split('.')[1])
108 | if minor >= 11:
109 | dist_bucket = os.path.join(DIST_BUCKET, 'colab/')
110 | else:
111 | raise ValueError(f'{version} is an invalid torch_xla version pattern')
112 |
113 | torch_whl = TORCH_WHEEL_TMPL.format(
114 | whl_version=version.wheels, py_version=version.py_version)
115 | torch_whl_path = os.path.join(dist_bucket, torch_whl)
116 | torch_xla_whl = TORCH_XLA_WHEEL_TMPL.format(
117 | whl_version=version.wheels, py_version=version.py_version)
118 | torch_xla_whl_path = os.path.join(dist_bucket, torch_xla_whl)
119 | torchvision_whl = TORCHVISION_WHEEL_TMPL.format(
120 | whl_version=version.wheels, py_version=version.py_version)
121 | torchvision_whl_path = os.path.join(dist_bucket, torchvision_whl)
122 | apt_cmd = ['apt-get', 'install', '-y']
123 | apt_cmd.extend(apt_packages)
124 |
125 | if not is_root:
126 | # Colab/Kaggle run as root, but not GCE VMs so we need privilege
127 | apt_cmd.insert(0, 'sudo')
128 |
129 | installation_cmds = [
130 | [sys.executable, '-m', 'pip', 'uninstall', '-y', 'torch', 'torchvision'],
131 | ['gsutil', 'cp', torch_whl_path, '.'],
132 | ['gsutil', 'cp', torch_xla_whl_path, '.'],
133 | ['gsutil', 'cp', torchvision_whl_path, '.'],
134 | [sys.executable, '-m', 'pip', 'install', torch_whl],
135 | [sys.executable, '-m', 'pip', 'install', torch_xla_whl],
136 | [sys.executable, '-m', 'pip', 'install', torchvision_whl],
137 | apt_cmd,
138 | ]
139 | for cmd in installation_cmds:
140 | subprocess.call(cmd)
141 |
142 |
143 | def run_setup(args):
144 | version = get_version(args.version)
145 | # Update TPU
146 | print('Updating... This may take around 2 minutes.')
147 |
148 | if is_tpu_runtime():
149 | update = threading.Thread(
150 | target=update_tpu_runtime, args=(
151 | args.tpu,
152 | version,
153 | ))
154 | update.start()
155 |
156 | install_vm(version, args.apt_packages, is_root=not args.tpu)
157 |
158 | if is_tpu_runtime():
159 | update.join()
160 |
161 |
162 | if __name__ == '__main__':
163 | parser = argparse.ArgumentParser()
164 | parser.add_argument(
165 | '--version',
166 | type=str,
167 | default='20200515',
168 | help='Versions to install (nightly, release version, or YYYYMMDD).',
169 | )
170 | parser.add_argument(
171 | '--apt-packages',
172 | nargs='+',
173 | default=['libomp5'],
174 | help='List of apt packages to install',
175 | )
176 | parser.add_argument(
177 | '--tpu',
178 | type=str,
179 | help='[GCP] Name of the TPU (same zone, project as VM running script)',
180 | )
181 | args = parser.parse_args()
182 | run_setup(args)
183 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.1.0
2 | accelerate==0.13.0.dev0
3 | astunparse==1.6.3
4 | attrs==19.3.0
5 | Automat==0.8.0
6 | blinker==1.4
7 | cachetools==5.2.0
8 | certifi==2022.6.15
9 | chardet==3.0.4
10 | charset-normalizer==2.1.0
11 | Click==7.0
12 | cloud-init==22.2
13 | cloud-tpu-client==0.10
14 | colorama==0.4.3
15 | command-not-found==0.3
16 | commonmark==0.9.1
17 | configobj==5.0.6
18 | constantly==15.1.0
19 | cryptography==2.8
20 | Cython==0.29.14
21 | dbus-python==1.2.16
22 | distlib==0.3.4
23 | distro==1.4.0
24 | distro-info===0.23ubuntu1
25 | docker-pycreds==0.4.0
26 | einops==0.4.1
27 | ema-pytorch==0.0.10
28 | entrypoints==0.3
29 | filelock==3.7.1
30 | flatbuffers==2.0
31 | future==0.18.2
32 | gast==0.4.0
33 | gitdb==4.0.9
34 | GitPython==3.1.27
35 | google-api-core==1.31.6
36 | google-api-python-client==1.8.0
37 | google-auth==2.9.0
38 | google-auth-httplib2==0.1.0
39 | google-auth-oauthlib==0.4.6
40 | google-pasta==0.2.0
41 | googleapis-common-protos==1.56.3
42 | grpcio==1.47.0
43 | h5py==3.7.0
44 | httplib2==0.20.4
45 | hyperlink==19.0.0
46 | idna==3.3
47 | importlib-metadata==4.12.0
48 | incremental==16.10.1
49 | intel-openmp==2022.1.0
50 | Jinja2==2.10.1
51 | jsonpatch==1.22
52 | jsonpointer==2.0
53 | jsonschema==3.2.0
54 | keras==2.9.0
55 | Keras-Applications==1.0.8
56 | Keras-Preprocessing==1.1.2
57 | keyring==18.0.1
58 | language-selector==0.1
59 | launchpadlib==1.10.13
60 | lazr.restfulclient==0.14.2
61 | lazr.uri==1.0.3
62 | libclang==14.0.1
63 | libtpu-nightly==0.1.dev20220518
64 | Markdown==3.3.7
65 | MarkupSafe==1.1.0
66 | mkl==2022.1.0
67 | mkl-include==2022.1.0
68 | mock==4.0.3
69 | more-itertools==4.2.0
70 | netifaces==0.10.4
71 | numpy==1.23.0
72 | oauth2client==4.1.3
73 | oauthlib==3.1.0
74 | opt-einsum==3.3.0
75 | packaging==21.3
76 | pathtools==0.1.2
77 | pexpect==4.6.0
78 | Pillow==9.2.0
79 | platformdirs==2.5.2
80 | promise==2.3
81 | protobuf==3.20.1
82 | psutil==5.9.1
83 | pyasn1==0.4.8
84 | pyasn1-modules==0.2.8
85 | Pygments==2.13.0
86 | PyGObject==3.36.0
87 | PyHamcrest==1.9.0
88 | PyJWT==1.7.1
89 | pymacaroons==0.13.0
90 | PyNaCl==1.3.0
91 | pyOpenSSL==19.0.0
92 | pyparsing==3.0.9
93 | pyrsistent==0.15.5
94 | pyserial==3.4
95 | python-apt==2.0.0+ubuntu0.20.4.7
96 | python-debian===0.1.36ubuntu1
97 | pytz==2022.1
98 | PyYAML==5.4.1
99 | requests==2.28.1
100 | requests-oauthlib==1.3.1
101 | requests-unixsocket==0.2.0
102 | rich==12.5.1
103 | rsa==4.8
104 | SecretStorage==2.3.1
105 | sentry-sdk==1.9.5
106 | service-identity==18.1.0
107 | setproctitle==1.3.2
108 | shortuuid==1.0.9
109 | simplejson==3.16.0
110 | six==1.16.0
111 | smmap==5.0.0
112 | sos==4.3
113 | ssh-import-id==5.10
114 | systemd-python==234
115 | tbb==2021.6.0
116 | tensorboard==2.9.1
117 | tensorboard-data-server==0.6.1
118 | tensorboard-plugin-wit==1.8.1
119 | tensorflow==2.10.0
120 | tensorflow-estimator==2.9.0
121 | tensorflow-io-gcs-filesystem==0.26.0
122 | termcolor==1.1.0
123 | torch==1.12.1
124 | torchvision==0.13.1
125 | tqdm==4.64.0
126 | Twisted==18.9.0
127 | typing-extensions==4.2.0
128 | ubuntu-advantage-tools==27.8
129 | ufw==0.36
130 | unattended-upgrades==0.1
131 | uritemplate==3.0.1
132 | urllib3==1.26.9
133 | virtualenv==20.15.1
134 | wadllib==1.3.3
135 | wandb==0.13.2
136 | Werkzeug==2.1.2
137 | wrapt==1.14.1
138 | zipp==1.0.0
139 | zope.interface==4.7.1
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torchvision.utils import save_image
5 |
6 | import wandb
7 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, utils
8 |
9 | torch.cuda.empty_cache()
10 | wandb.login()
11 |
12 | default_hypers = dict(
13 | learning_rate=3e-4,
14 | training_timesteps=1001,
15 | sampling_timesteps=250,
16 | image_size=32,
17 | number_of_samples=25,
18 | batch_size=256,
19 | use_amp=False,
20 | use_fp16=False,
21 | gradient_accumulation_rate=2,
22 | ema_update_rate=10,
23 | ema_decay=0.995,
24 | adam_betas=(0.9, 0.99),
25 | save_and_sample_rate=1000,
26 | do_split_batches=False,
27 | timesteps=4000,
28 | loss_type='L2',
29 | unet_dim=128,
30 | unet_mults=(1, 2, 2, 2),
31 | unet_channels=3,
32 | training_objective='pred_x0'
33 | )
34 |
35 | wandb.init(config=default_hypers, project='bath-thesis', entity='jd202')
36 |
37 | model = Unet(
38 | dim=wandb.config.unet_dim,
39 | dim_mults=wandb.config.unet_mults,
40 | channels=wandb.config.unet_channels
41 | )
42 |
43 | diffusion = GaussianDiffusion(
44 | model,
45 | image_size=wandb.config.image_size,
46 | timesteps=wandb.config.timesteps, # number of steps
47 | sampling_timesteps=wandb.config.sampling_timesteps,
48 | # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
49 | loss_type=wandb.config.loss_type, # L1 or L2
50 | training_objective=wandb.config.training_objective
51 | )
52 |
53 | trainer = Trainer(
54 | diffusion,
55 | '/Users/jake/Desktop/scp/cifar',
56 | train_batch_size=wandb.config.batch_size,
57 | training_learning_rate=wandb.config.learning_rate,
58 | num_training_steps=wandb.config.training_timesteps, # total training steps
59 | num_samples=wandb.config.number_of_samples,
60 | gradient_accumulate_every=wandb.config.gradient_accumulation_rate, # gradient accumulation steps
61 | ema_update_every=wandb.config.ema_update_rate,
62 | ema_decay=wandb.config.ema_decay, # exponential moving average decay
63 | amp=wandb.config.use_amp, # turn on mixed precision
64 | fp16=wandb.config.use_fp16,
65 | save_and_sample_every=wandb.config.save_and_sample_rate
66 | )
67 |
68 | trainer.load('./results/loadins', '56')
69 | trainer.ema.ema_model.eval()
70 | with torch.no_grad():
71 | milestone = 10 // 1
72 | batches = utils.num_to_groups(wandb.config.number_of_samples, wandb.config.batch_size)
73 | all_images_list = list(map(lambda n: trainer.ema.ema_model.sample(batch_size=n), batches))
74 |
75 | all_images = torch.cat(all_images_list, dim=0)
76 | save_image(all_images, str(f'results/samples/sample-{milestone}.png'),
77 | nrow=int(math.sqrt(wandb.config.number_of_samples)))
78 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='denoising-diffusion-pytorch',
5 | packages=find_packages(),
6 | version='0.27.4',
7 | license='MIT',
8 | description='Denoising Diffusion Probabilistic Models - Pytorch',
9 | author='DevJake',
10 | long_description_content_type='text/markdown',
11 | keywords=[
12 | 'artificial intelligence',
13 | 'generative models'
14 | ],
15 | install_requires=[
16 | 'accelerate',
17 | 'einops',
18 | 'ema-pytorch',
19 | 'pillow',
20 | 'torch',
21 | 'torchvision',
22 | 'tqdm'
23 | ],
24 | classifiers=[
25 | 'Development Status :: 4 - Beta',
26 | 'Intended Audience :: Developers',
27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
28 | 'License :: OSI Approved :: MIT License',
29 | 'Programming Language :: Python :: 3.6',
30 | ],
31 | )
32 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | cd ~/
2 | #python3 -m pip install virtualenv # It is just simpler this way, trust me...
3 | git clone https://github.com/DevJake/EEG-diffusion-pytorch.git diffusion
4 | cd diffusion
5 | #virtualenv venv
6 | #source venv/bin/activate
7 | sudo python3 setup.py install
8 | #pip3 install wandb accelerate einops tqdm ema_pytorch torchvision
9 | #pip3 install -r requirements.txt
10 | #python3 -m wandb login
11 | accelerate config
12 | sudo apt install rclone
13 | mkdir -p ~/.config/rclone
14 | nano ~/.config/rclone/rclone.conf
15 | # Add in your rclone config to connect to the repository storing all EEG and Targets data
16 | mkdir -p datasets/eeg/unsorted datasets/eeg/flower datasets/eeg/penguin datasets/eeg/guitar
17 | mkdir -p datasets/targets/unsorted datasets/targets/flower datasets/targets/penguin datasets/targets/guitar
18 | cd ~/diffusion/datasets/eeg
19 | #rclone copy gc:/bath-thesis-data/data/outputs/preprocessing . -P
20 | rclone copy gc:/bath-thesis-data/data/subjects/preprocessed/combined . -P
21 | find . -name "*.tar.gz" -exec tar -xf {} \; # this will take some time to run...
22 | find . -name "*.tar.gz" -exec rm -v {} \;
23 | #find ./ -type f -exec mv --backup=numbered {} ./ -v \;
24 | cd ~/diffusion/datasets/targets/
25 | rclone copy gc:/bath-thesis-data/data/classes/32x32.tar . -P
26 | tar -xf 32x32.tar
27 | rm 32x32.tar
28 | mv 32x32/flower-32x32/* flower/ & mv 32x32/guitar-32x32/* guitar/ & mv 32x32/penguin-32x32/* penguin/
29 | rm 32x32 -r
30 | cd ../..
31 | accelerate launch model.py
32 |
33 | #curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
34 | #sudo python3 pytorch-xla-env-setup.py --version 1.12
35 |
36 | # Given the enormous size of model save files, and their frequent saving,
37 | # you can use the following command in a tmux session to have them be backed
38 | # up to the Google Cloud bucket, or another provider of your choice. Source directories are not deleted,
39 | # so do not worry about the model crashing from not being able to save!
40 | # while sleep 120; do rclone move ~/diffusion/results/ gc:bath-thesis-data/data/trained_models/ -P; done
41 |
--------------------------------------------------------------------------------
/sweep.yaml:
--------------------------------------------------------------------------------
1 | program: /usr/local/bin/accelerate launch model.py
2 | command:
3 | - ${program}
4 | method: bayes
5 | metric:
6 | goal: minimize
7 | name: total_training_loss
8 | parameters:
9 | ema_update_rate:
10 | max: 20
11 | min: 5
12 | distribution: int_uniform
13 | gradient_accumulation_rate:
14 | max: 20
15 | min: 1
16 | distribution: int_uniform
17 | training_timesteps:
18 | max: 2000
19 | min: 500
20 | distribution: int_uniform
21 | training_objective:
22 | values:
23 | - pred_x0
24 | - pred_noise
25 | distribution: categorical
26 | learning_rate:
27 | max: 0.1
28 | min: 0.00001
29 | distribution: uniform
30 | image_size:
31 | max: 64
32 | min: 16
33 | distribution: int_uniform
34 | batch_size:
35 | max: 512
36 | min: 32
37 | distribution: int_uniform
38 | timesteps:
39 | max: 8000
40 | min: 200
41 | distribution: int_uniform
42 | loss_type:
43 | values:
44 | - L2
45 | - L1
46 | distribution: categorical
47 | ema_decay:
48 | max: 1.99
49 | min: 0.4
50 | distribution: uniform
51 | unet_dim:
52 | max: 128
53 | min: 8
54 | distribution: int_uniform
--------------------------------------------------------------------------------
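A minimal sketch of registering and running this sweep from Python, mirroring the commented-out block in model.py (the use of yaml.safe_load and the count argument are assumptions, not taken from the repository):

import yaml
import wandb

with open('./sweep.yaml') as f:
    sweep_config = yaml.safe_load(f)

sweep_id = wandb.sweep(sweep_config, entity='jd202', project='bath-thesis')
wandb.agent(sweep_id, count=1)  # each agent call runs trials chosen by the Bayesian optimiser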
/test.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 |
3 | from denoising_diffusion_pytorch.utils import EEGTargetsDataset, cycle
4 |
5 | dataset = EEGTargetsDataset()
6 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True, pin_memory=True, num_workers=0)
7 | dataloader = cycle(dataloader)
8 | data = next(dataloader)
9 |
10 | print(data)
11 |
--------------------------------------------------------------------------------
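Judging by how the Trainer unpacks each batch, every item appears to be an (eeg, target, label) triple; a minimal, assumed sketch of inspecting the batch fetched above:

eeg_sample, target_sample, label = data
print(eeg_sample.shape, target_sample.shape, label)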
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/util/__init__.py
--------------------------------------------------------------------------------
/util/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/util/configs/__init__.py
--------------------------------------------------------------------------------
/util/configs/generate_configs.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import uuid
5 | import util.preprocessing.pipeline as pipe
6 |
7 | if __name__ == '__main__':
8 | samples_dir = './samples'
9 | variations = None
10 |
11 | with open(f'./variations.json', 'r') as f:
12 | variations = json.load(f)
13 |
14 | for w, _ in variations['window_overlap']:
15 | assert len(pipe.get_output_dims_by_factors(w * 1024)) > 0, f'No factor pairs exist for a window width of {w}!'
16 |
17 | combinations = [
18 | [window_overlap, use_channels, high_pass, low_pass]
19 | for window_overlap in variations['window_overlap']
20 | for use_channels in variations['use_channels']
21 | for high_pass in variations['high_pass']
22 | for low_pass in variations['low_pass']
23 | ]
24 |
25 | print(f'Calculated {len(combinations)} total variants...')
26 | print(f'Estimated configuration count is {len(combinations) * len(list(glob.iglob(samples_dir + "/**.json")))}')
27 |
28 | for i, sample in enumerate(glob.iglob(f'{samples_dir}/*.json')):
29 | print('Operating on config #', i)
30 |
31 | os.makedirs(f'{samples_dir}/generated/config_{i}', exist_ok=True)
32 |
33 | conf = None
34 | with open(sample, 'r') as f:
35 | conf = json.load(f)
36 |
37 | for wo, use_channels, high_pass, low_pass in combinations:
38 | window_size, window_overlap = wo
39 |
40 | config = conf.copy()
41 |
42 | # config['META.CONFIG_NAME'] += f'_window_size={window_size}'
43 | # config['META.CONFIG_NAME'] += f'_window_overlap={window_overlap}'
44 | # config['META.CONFIG_NAME'] += f'_use_channels={use_channels}'
45 | # config['META.CONFIG_NAME'] += f'_highpass={high_pass}'
46 | # config['META.CONFIG_NAME'] += f'_lowpass={low_pass}'
47 |
48 | unique_id = str(uuid.uuid4())
49 |
50 | config['META.CONFIG_NAME'] += '--' + unique_id
51 |
52 | config['PREPROCESSING.HIGH_PASS_FILTER.FREQ'] = high_pass
53 | config['PREPROCESSING.LOW_PASS_FILTER.FREQ'] = low_pass
54 | config['RENDER.WINDOW_OVERLAP'] = window_overlap
55 | config['RENDER.WINDOW_WIDTH'] = window_size
56 | config['RENDER.DO_PER_CHANNEL'] = use_channels
57 |
58 | with open(f'{samples_dir}/generated/config_{i}/{config["META.CONFIG_NAME"]}.json', 'w') as f_o:
59 | json.dump(config, f_o, indent=4, sort_keys=True)
60 |
--------------------------------------------------------------------------------
/util/configs/samples/all-but-no-filters-or-dc.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "all-but-no-filters-or-dc",
3 | "META.CONFIG_COMMENT": "All settings and default values, no filtering or DC removal",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false,
6 | "PREPROCESSING.DO_MONTAGE": true,
7 | "PREPROCESSING.DO_REMOVE_DC": false,
8 | "PREPROCESSING.DO_REMOVE_EOG": true,
9 | "PREPROCESSING.DO_USE_ICA": true,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/all-but-no-filters.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "all-but-no-filters",
3 | "META.CONFIG_COMMENT": "All settings and default values, no filtering",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false,
6 | "PREPROCESSING.DO_MONTAGE": true,
7 | "PREPROCESSING.DO_REMOVE_DC": true,
8 | "PREPROCESSING.DO_REMOVE_EOG": true,
9 | "PREPROCESSING.DO_USE_ICA": true,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/default-config-copy.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "default-config-copy",
3 | "META.CONFIG_COMMENT": "A copy of the default config with all settings enabled and default values",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": true,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": true,
6 | "PREPROCESSING.DO_MONTAGE": true,
7 | "PREPROCESSING.DO_REMOVE_DC": true,
8 | "PREPROCESSING.DO_REMOVE_EOG": true,
9 | "PREPROCESSING.DO_USE_ICA": true,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/ica-only.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "ica-only",
3 | "META.CONFIG_COMMENT": "All settings and default values, no filtering, DC removal or EOG noise removal",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false,
6 | "PREPROCESSING.DO_MONTAGE": true,
7 | "PREPROCESSING.DO_REMOVE_DC": false,
8 | "PREPROCESSING.DO_REMOVE_EOG": false,
9 | "PREPROCESSING.DO_USE_ICA": true,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/raw-high-pass-ica.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "raw-high-pass-ica",
3 | "META.CONFIG_COMMENT": "Apply high pass and ICA only",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": true,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false,
6 | "PREPROCESSING.DO_MONTAGE": false,
7 | "PREPROCESSING.DO_REMOVE_DC": false,
8 | "PREPROCESSING.DO_REMOVE_EOG": false,
9 | "PREPROCESSING.DO_USE_ICA": true,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/raw-high-pass-no-ica.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "raw-high-pass-no-ica",
3 | "META.CONFIG_COMMENT": "Apply high pass only",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": true,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false,
6 | "PREPROCESSING.DO_MONTAGE": false,
7 | "PREPROCESSING.DO_REMOVE_DC": false,
8 | "PREPROCESSING.DO_REMOVE_EOG": false,
9 | "PREPROCESSING.DO_USE_ICA": false,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/raw-loss-pass-ica.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "raw-low-pass-ica",
3 | "META.CONFIG_COMMENT": "Apply low pass filtering and ICA only",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": true,
6 | "PREPROCESSING.DO_MONTAGE": false,
7 | "PREPROCESSING.DO_REMOVE_DC": false,
8 | "PREPROCESSING.DO_REMOVE_EOG": false,
9 | "PREPROCESSING.DO_USE_ICA": true,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/raw-low-pass-no-ica.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "raw-low-pass-no-ica",
3 | "META.CONFIG_COMMENT": "Apply low pass only",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": true,
6 | "PREPROCESSING.DO_MONTAGE": false,
7 | "PREPROCESSING.DO_REMOVE_DC": false,
8 | "PREPROCESSING.DO_REMOVE_EOG": false,
9 | "PREPROCESSING.DO_USE_ICA": false,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/raw-only.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "raw-only",
3 | "META.CONFIG_COMMENT": "Apply no preprocessing or ICA",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": false,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": false,
6 | "PREPROCESSING.DO_MONTAGE": false,
7 | "PREPROCESSING.DO_REMOVE_DC": false,
8 | "PREPROCESSING.DO_REMOVE_EOG": false,
9 | "PREPROCESSING.DO_USE_ICA": false,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/remove-all-noise-ica-no-eog.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "remove-all-noise-ica-no-eog",
3 | "META.CONFIG_COMMENT": "Remove all noise, use ICA, do not remove EOG noise",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": true,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": true,
6 | "PREPROCESSING.DO_MONTAGE": true,
7 | "PREPROCESSING.DO_REMOVE_DC": true,
8 | "PREPROCESSING.DO_REMOVE_EOG": false,
9 | "PREPROCESSING.DO_USE_ICA": true,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/remove-all-noise-no-ica.json:
--------------------------------------------------------------------------------
1 | {
2 | "META.CONFIG_NAME": "remove-all-noise-no-ica",
3 | "META.CONFIG_COMMENT": "Remove all noise, do not use ICA",
4 | "PREPROCESSING.DO_HIGH_PASS_FILTER": true,
5 | "PREPROCESSING.DO_LOW_PASS_FILTER": true,
6 | "PREPROCESSING.DO_MONTAGE": true,
7 | "PREPROCESSING.DO_REMOVE_DC": true,
8 | "PREPROCESSING.DO_REMOVE_EOG": false,
9 | "PREPROCESSING.DO_USE_ICA": false,
10 | "PREPROCESSING.HIGH_PASS_FILTER.FREQ": 50,
11 | "PREPROCESSING.LOW_PASS_FILTER.FREQ": 0.1,
12 | "RENDER.DO_PER_CHANNEL": false,
13 | "RENDER.WINDOW_OVERLAP": 0.5,
14 | "RENDER.WINDOW_WIDTH": 1
15 | }
--------------------------------------------------------------------------------
/util/configs/samples/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.12.0
2 | appdirs==1.4.4
3 | certifi==2022.6.15
4 | charset-normalizer==2.1.1
5 | click==8.1.3
6 | cycler==0.11.0
7 | decorator==5.1.1
8 | docker-pycreds==0.4.0
9 | einops==0.4.1
10 | ema-pytorch==0.0.10
11 | fonttools==4.37.1
12 | gitdb==4.0.9
13 | GitPython==3.1.27
14 | idna==3.3
15 | Jinja2==3.1.2
16 | kiwisolver==1.4.4
17 | MarkupSafe==2.1.1
18 | matplotlib==3.5.3
19 | mne==1.1.1
20 | numpy==1.23.2
21 | packaging==21.3
22 | pathtools==0.1.2
23 | Pillow==9.2.0
24 | pooch==1.6.0
25 | promise==2.3
26 | protobuf==3.20.1
27 | psutil==5.9.1
28 | pyparsing==3.0.9
29 | python-dateutil==2.8.2
30 | PyYAML==6.0
31 | requests==2.28.1
32 | scipy==1.9.0
33 | sentry-sdk==1.9.5
34 | setproctitle==1.3.2
35 | shortuuid==1.0.9
36 | six==1.16.0
37 | smmap==5.0.0
38 | torch==1.12.1
39 | torchvision==0.13.1
40 | tqdm==4.64.0
41 | typing_extensions==4.3.0
42 | urllib3==1.26.12
43 | wandb==0.13.2
44 |
--------------------------------------------------------------------------------
/util/configs/variations.json:
--------------------------------------------------------------------------------
1 | {
2 | "high_pass": [
3 | 30,
4 | 50
5 | ],
6 | "low_pass": [
7 | 1,
8 | 0.5,
9 | 0.1
10 | ],
11 | "use_channels": [
12 | true,
13 | false
14 | ],
15 | "window_overlap": [
16 | [
17 | 1,
18 | 0.5
19 | ],
20 | [
21 | 1.5,
22 | 0.4
23 | ],
24 | [
25 | 0.5,
26 | 0.2
27 | ],
28 | [
29 | 2,
30 | 1
31 | ]
32 | ]
33 | }
--------------------------------------------------------------------------------
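A quick worked check of the configuration count these variations imply, assuming the eleven sample configs under util/configs/samples/ are the ones generate_configs.py globs:

n_combinations = 2 * 3 * 2 * 4   # high_pass x low_pass x use_channels x window_overlap pairs = 48
n_configs = n_combinations * 11  # 48 combinations applied to 11 sample configs = 528 generated files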
/util/move.py:
--------------------------------------------------------------------------------
1 | labels = ['flower', 'penguin', 'guitar']
2 | import os
3 | from pathlib import Path
4 |
5 | import tqdm
6 |
7 | cpt = sum([len(files) for r, d, files in os.walk('./')])
8 |
9 | t = 0
10 | for path in tqdm.tqdm(Path('./').rglob(f'*.jpg'), initial=0, total=cpt):
11 | split = str(path).split('/')
12 | subject, session, name = split[0], split[1], split[-1]
13 |
14 | label = None
15 | label = 'penguin' if 'penguin' in name else label
16 | label = 'guitar' if 'guitar' in name else label
17 | label = 'flower' if 'flower' in name else label
18 |
19 | new_name = f'{label}/{str(t).zfill(8)}_{subject}_{session}_{name}'
20 | os.rename(path, f'./{new_name}')
21 | t += 1
22 |
--------------------------------------------------------------------------------
/util/preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DevJake/EEG-diffusion-pytorch/dd806fb6b4bd87e1cee7ae26aa2128e774c92653/util/preprocessing/__init__.py
--------------------------------------------------------------------------------
/util/preprocessing/pipeline.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """EEG to Dataset Pipeline
3 |
4 | Automatically generated by Colaboratory.
5 |
6 | Original file is located at
7 | https://colab.research.google.com/drive/1xR3KqzFsQt5oF1ot8ZzYS49FkIkaZRWV
8 |
9 | This notebook intends to produce the following datasets on a given EEG RAW dataset:
10 | 1. Denoised EEG with no Gaussian noise and sliding window
11 | 2. Denoised EEG with Gaussian noise and sliding window
12 | 3. Noisy EEG with no Gaussian noise and sliding window
13 | 4. Noisy EEG with Gaussian noise and sliding window
14 |
15 | Parameters therefore include:
16 | - Gaussian noise or none
17 | - Amount of Gaussian noise
18 | - Noisy or not
19 | - If denoising, what methods?
20 | - ICA - and number of channels
21 | - Bad channel removal
22 | - ECG, EOG and EEG removal
23 | - Low pass filters
24 | - Sliding window width
25 | - Window overlap amount
26 |
27 |
28 | Output datasets must also be reshaped into consistent image sizes, such as 64x64, 32x32, 64x32, etc. These sizes depend on the possible configurations, which are dictated by the sliding window width and overlap amount, and can be calculated by finding the factor pairs of the flattened EEG signal vector size.
29 |
30 | Each output window should not only be a .jpg image, but should also be either greyscale or black & white. Furthermore, a CSV is needed to contain metadata, such as subject/participant, and what they were imagining/perceiving in the moment.
31 | """
32 |
33 | # Load the given EEG file
34 |
35 | sub_cap_sizes = { # [subject] = capsize
36 | 2: 'N/A',
37 | 3: 'L',
38 | 4: 'L',
39 | 5: 'M',
40 | 6: 'L',
41 | 7: 'L',
42 | 8: 'M',
43 | 9: 'N/A',
44 | 10: 'M',
45 | 11: 'M',
46 | 12: 'L',
47 | 13: 'L',
48 | 14: 'L',
49 | 15: 'N/A',
50 | 16: 'L',
51 | 17: 'L'
52 | }
53 |
54 |
55 | # subject, session = 10, 1
56 |
57 | def load_eeg(subject: int, session: int):
58 | raw_file = f'./data/subjects/Subject {subject}/Session {session}/sub{subject}_sess{session}.set'
59 | eog_channels = ['VEOGL', 'VEOGU', 'HEOGR', 'HEOGL'] # electrooculogram (EOG) electrodes
60 | raw = mne.io.read_raw_eeglab(raw_file, preload=True, eog=eog_channels)
61 |
62 | raw.info['bads'] += ['CCP1h'] if sub_cap_sizes[subject] == 'L' else []
63 | try:
64 | raw = raw.pick_types(eog=False, eeg=True, exclude='bads') # Effectively drop all EOG channels, leaving just EEG
65 | except RuntimeError:
66 | print('An exception occurred while excluding bad channels, selecting without explicit exclusions...')
67 | raw.info['bads'] = []
68 | raw = raw.pick_types(eeg=True, eog=False, exclude='bads')
69 | raw = raw.interpolate_bads(reset_bads=False)
70 |
71 | return raw
72 |
73 |
74 | from matplotlib import pyplot as plt
75 |
76 |
77 | def generate_sample_image_from_eeg(raw, start_t, end_t, image_height=64, channel=0):
78 | SAMPLE_RATE = 1024
79 | plt.gray()
80 | return plt.imshow(raw.get_data(channel, start_t * SAMPLE_RATE, end_t * SAMPLE_RATE)[0, :].reshape(image_height, -1),
81 | interpolation='nearest')
82 |
83 |
84 | # display(generate_sample_image_from_eeg(raw, 500, 510))
85 |
86 | import mne
87 | import json
88 | import numpy as np
89 | import mne.preprocessing as preprocessing
90 |
91 |
92 | def apply_montage(raw, montage_file_path: str):
93 | with open(montage_file_path, 'r') as f:
94 | raw.set_montage(mne.channels.make_dig_montage(json.load(f), coord_frame='head'))
95 | return raw
96 |
97 |
98 | def compute_ICA(raw, channels: int = 20, random_state: int = 0, reject_dict={'mag': 5e-12, 'grad': 4000e-13},
99 | max_iter=800, method: str = 'fastica'):
100 | ica = mne.preprocessing.ICA(n_components=channels, random_state=random_state, max_iter=max_iter, method=method)
101 | return ica.fit(raw, reject=reject_dict)
102 |
103 |
104 | def remove_EOG(raw, ica, eog_channel_names=['Fp1', 'Fp2'], ica_z_threshold=1.96): # Requires ICA
105 | eog_epochs = preprocessing.create_eog_epochs(raw, ch_name=eog_channel_names).average()
106 | eog_epochs.apply_baseline(baseline=(None, -0.2))
107 |
108 | eog_indices, _ = ica.find_bads_eog(eog_epochs, ch_name=eog_channel_names, threshold=ica_z_threshold)
109 | ica.exclude.extend(eog_indices)
110 |
111 | return ica
112 |
113 |
114 | def remove_ECG(raw, ica, ica_z_threshold=1.96): # Requires ICA
115 | ecg_epochs = preprocessing.create_ecg_epochs(raw).average()
116 | ecg_epochs.apply_baseline(baseline=(None, -0.2))
117 |
118 | ecg_indices, _ = ica.find_bads_ecg(ecg_epochs, threshold=ica_z_threshold)
119 | ica.exclude.extend(ecg_indices)
120 |
121 | return ica
122 |
123 |
124 | def remove_DC(raw, notches=np.arange(50, 251, 50)):
125 | # https://mne.tools/stable/generated/mne.io.Raw.html#mne.io.Raw.notch_filter
126 | return raw.notch_filter(notches, picks='eeg', filter_length='auto', phase='zero-double', fir_design='firwin')
127 |
128 |
129 | def apply_filter(raw, low_freq=None, high_freq=None):
130 | return raw.filter(low_freq, high_freq)
131 |
132 |
133 | def apply_ICA_to_RAW(raw, ica, dimensions_to_keep=124):
134 | return ica.apply(raw, n_pca_components=dimensions_to_keep)
135 |
136 |
137 | # TODO comment and document all code
138 | # TODO resolve all TODOs...
139 |
140 | # raw.plot_psd(fmin = 0,fmax=50, n_fft=2048, spatial_colors=True)
141 | # raw.info
142 |
143 | # raw.get_montage().plot(kind='topomap', show_names=True, sphere='auto')
144 |
145 | def generate_events(raw):
146 | events, event_ids = mne.events_from_annotations(raw, verbose=False)
147 | epochs = mne.Epochs(raw=raw, events=events, event_id=event_ids, preload=True, tmin=0, tmax=4, baseline=None,
148 | event_repeated='merge')
149 | # Gather all epochs within the raw data, assigning a corresponding event ID
150 |
151 | events_list = {(a, b, c): [] for a in ['imagined', 'perceived'] for b in ['flower', 'guitar', 'penguin'] for c in
152 | ['text', 'pictorial', 'sound']}
153 | # Generate a list of possible event combinations. There are 18 in total
154 |
155 | for id in event_ids:
156 | idl = id.lower()
157 | a = 'imagined' if 'imagination' in idl else 'perceived'
158 | b = 'guitar' if 'guitar' in idl else 'flower' if 'flower' in idl else 'penguin'
159 | c = 'text' if '_t_' in idl else 'pictorial' if '_image_' in idl else 'sound'
160 | events_list[(a, b, c)].append(id)
161 | # Map the ID of every real event to a corresponding event_id from the events_list variable
162 |
163 | for i, event in enumerate(events_list.items()):
164 | key, value = event
165 | name = '_'.join(key)
166 | try:
167 | mne.epochs.combine_event_ids(epochs, value, {name: 500 + i}, copy=False)
168 | # Merge all events with the same ID into a new ID, effectively grouping them for easy access
169 | except KeyError:
170 | print('KeyError whilst combining event IDs... skipping this ID')
171 | continue
172 |
173 | return events, event_ids, epochs, events_list
174 |
175 |
176 | def select_specific_epochs(epochs, A: list, B: list, C: list):
177 | select_epochs = [[a, b, c] for a in A for b in B for c in C]
178 | select_epochs = ['_'.join(k) for k in select_epochs]
179 | select_epochs = epochs[select_epochs]
180 | return select_epochs
181 |
182 |
183 | def crop_epochs(epochs, cropping_rules=None):
184 | if cropping_rules is None:
185 | cropping_rules = {
186 | 'imagined': {
187 | 'sound': 4,
188 | 'pictorial': 4,
189 | 'text': 4
190 | },
191 | 'perceived': {
192 | 'sound': 4,
193 | 'pictorial': 3,
194 | 'text': 3
195 | }
196 | }
197 |
198 | for i, epoch in enumerate(zip(epochs, list(epochs.event_id.keys()))):
199 | epoch, name = epoch
200 | a, _, c = name.split('_')
201 |
202 | epochs[name].crop(tmin=0, tmax=cropping_rules[a][c])
203 |
204 | return epochs
205 |
206 |
207 | def get_output_dims_by_factors(vector_length: int, orientation='landscape'):
208 | orientation = orientation.lower()
209 |
210 | pairs = np.array(
211 | [[p, vector_length // p] for p in range(2, int(np.sqrt(vector_length) + 1)) if vector_length % p == 0])
212 |
213 | assert pairs.size > 0, f'No factor pairs exist for the given value: {vector_length}'
214 | assert orientation in ['landscape', 'portrait'], 'The orientation must be one of either landscape or portrait. '
215 |
216 | squarest = np.sort(pairs[np.argmin(np.abs(pairs[:, 0] - pairs[:, 1]))])
217 | return squarest if orientation == 'portrait' else squarest[::-1]
218 |
219 |
220 | def calc_total_windows(vector_length: int, window_width: int, overlap: int = 1, sample_rate=1024):
221 | return 1 + int((vector_length - (window_width * sample_rate)) / (overlap * sample_rate))
222 |
223 |
224 | def split_vector_to_window_indices(vector_length: int, window_width=1, window_overlap=0.5,
225 | sample_rate=1024): # return (a, b) pairs from which to split the vector against
226 | window_width *= sample_rate
227 | window_overlap *= sample_rate
228 | # The window_width and window_overlap are products of the sample_rate,
229 | # so if 1 second is 1024Hz, then a 0.9 window_overlap is 0.9*sample_rate
230 |
231 | assert window_width % 1 == 0, f'The given window width does not produce an integer sample count. window_width={window_width / sample_rate}, sample_rate={sample_rate}'
232 | assert window_overlap % 1 == 0, f'The given window overlap does not produce an integer sample count. window_overlap={window_overlap / sample_rate}, sample_rate={sample_rate}'
233 |
234 | window_width, window_overlap = int(window_width), int(window_overlap)
235 |
236 | assert window_width != window_overlap, 'The window width cannot equal the overlap, as this produces no windows!'
237 | assert window_width > 0, 'Window width must be greater than zero.'
238 | assert window_overlap >= 0, 'Window overlap cannot be negative.'
239 |
240 | if window_overlap == 0:
241 | k = np.arange(start=0, stop=vector_length + (sample_rate if vector_length % sample_rate == 0 else 0))
242 | return np.array(list(zip(k[::sample_rate], k[sample_rate::sample_rate])))
243 |
244 | k = np.arange(start=0, stop=vector_length - window_overlap, step=window_width - window_overlap)
245 | k = k[k <= vector_length - window_width]
246 | return np.array([k, k + window_width]).T
247 |
248 |
249 | def generate_eeg_dataset(raw_eeg, channels: list = np.arange(128), per_channel=True, window_width_seconds=1,
250 | window_overlap_seconds=0.5):
251 | assert type(raw_eeg) in [np.ndarray, mne.io.eeglab.eeglab.RawEEGLAB], \
252 | f'The given raw_eeg must be of type Numpy Array or RawEEGLAB. type={type(raw_eeg)}'
253 |
254 | # Trim channel count to match the maximum of our input data
255 | if type(raw_eeg) is mne.io.eeglab.eeglab.RawEEGLAB:
256 | channels = np.arange(min(len(channels), len(raw_eeg.ch_names)))
257 | else:
258 | channels = np.arange(raw_eeg.shape[0])
259 |
260 | window_size = window_width_seconds * 1024 # 1024Hz per second sampling frequency
261 |
262 | output_size = get_output_dims_by_factors(window_size) if per_channel else get_output_dims_by_factors(
263 | window_size * len(channels))
264 | # Size of output image for each EEG sample, optionally per channel
265 |
266 | data = raw_eeg.squeeze() # Remove the outer-most dimension
267 | if type(raw_eeg) is mne.io.eeglab.eeglab.RawEEGLAB:
268 | data = raw_eeg.get_data(channels, start=0, stop=None) # loads all of the data. # (n_channels, n_data_points)
269 |
270 | splits = split_vector_to_window_indices(data.shape[1])
271 |
272 | split_start, split_end = splits[:, 0], splits[:, 1]
273 |
274 | output = None # Prefilling our outputs array with zeros
275 | if per_channel:
276 | output = np.zeros((channels.shape[0], splits.shape[0], output_size[0],
277 | output_size[1])) # n_channels, n_windows, reshaped_window
278 | else:
279 | output = np.zeros((splits.shape[0], output_size[0], output_size[1]))
280 |
281 | for i, s in enumerate(zip(split_start, split_end)):
282 | a, b = s
283 |
284 | if per_channel:
285 | for c in channels:
286 | output[c, i] = data[c, a:b].reshape(-1, output_size[0], output_size[1])
287 | else:
288 | output[i] = data[:, a:b].reshape(output_size[0], output_size[1])
289 | # Populate the outputs array according to if we're operating per_channel or not
290 |
291 | del data
292 | return output
293 |
--------------------------------------------------------------------------------
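A minimal worked example of the windowing helpers above, using the default 1024 Hz sample rate with a 1 s window and 0.5 s overlap over a 4 s (4096-sample) channel; the concrete numbers are illustrative only:

from util.preprocessing.pipeline import (get_output_dims_by_factors, calc_total_windows,
                                         split_vector_to_window_indices)

print(get_output_dims_by_factors(1024))  # [32 32] -> each 1 s per-channel window renders as a 32x32 image
print(calc_total_windows(4096, window_width=1, overlap=0.5))  # 1 + (4096 - 1024) / 512 = 7 windows
print(split_vector_to_window_indices(4096, window_width=1, window_overlap=0.5))
# [[   0 1024]
#  [ 512 1536]
#  ...
#  [3072 4096]]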
/util/preprocessing/preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import glob
3 | import os
4 | import random
5 | import uuid
6 |
7 | from PIL import Image
8 | from tqdm import tqdm
9 |
10 | from util.preprocessing import pipeline
11 | import json
12 |
13 |
14 | def preprocess(eeg_data_dir='./data/subjects', output_dir='./data/outputs/preprocessing/',
15 | montage_file_name='ANTNeuro_montage', hypers=None):
16 | assert os.path.exists(f'{eeg_data_dir}/{montage_file_name}.json'), \
17 | 'The specified montage file could not be found! ' \
18 | 'Please check it is in the EEG data directory root.'
19 |
20 | if hypers is None:
21 | hypers = {
22 | 'RENDER.DO_PER_CHANNEL': True,
23 | 'PREPROCESSING.DO_MONTAGE': True,
24 | 'PREPROCESSING.DO_REMOVE_DC': True,
25 | 'PREPROCESSING.DO_LOW_PASS_FILTER': True,
26 | 'PREPROCESSING.DO_HIGH_PASS_FILTER': True,
27 | 'PREPROCESSING.DO_USE_ICA': True,
28 | 'PREPROCESSING.DO_REMOVE_EOG': True,
29 | 'PREPROCESSING.LOW_PASS_FILTER.FREQ': 0.1,
30 | 'PREPROCESSING.HIGH_PASS_FILTER.FREQ': 50,
31 | 'RENDER.WINDOW_WIDTH': 1,
32 | 'RENDER.WINDOW_OVERLAP': 0.5,
33 | 'META.CONFIG_NAME': 'config-1',
34 | 'META.CONFIG_COMMENT': 'The default config for any EEG preprocessing.',
35 | 'META.RAW.DO_SAVE': False
36 | }
37 |
38 | if 'META.RAW.DO_SAVE' not in hypers:
39 | hypers['META.RAW.DO_SAVE'] = False
40 |
41 | sub_sess_pairs = [] # subject, session
42 |
43 | for subject in range(20):
44 | for session in range(3):
45 | path = f'{eeg_data_dir}/Subject {subject}/Session {session}/sub{subject}_sess{session}'
46 | if not os.path.isfile(f'{path}.fdt'):
47 | continue
48 | sub_sess_pairs.append((subject, session))
49 |
50 | print(f'Now executing config \'{hypers["META.CONFIG_NAME"]}\' '
51 | f'against the following subject and session pairs: {sub_sess_pairs}')
52 | pbar_subjects = tqdm(len(sub_sess_pairs), desc='Subjects and Sessions')
53 | k = 0
54 | for subject, session in sub_sess_pairs:
55 | k += 1
56 | print(f'Now preprocessing data for Subject {subject}, session {session}. Progress = {k}/{len(sub_sess_pairs)}')
57 | try:
58 | unique_id = str(uuid.uuid4()) # A unique ID for this run
59 | print('The unique ID for this subject/session pairing is', unique_id)
60 | hypers['META.UUID'] = unique_id
61 |
62 | raw = pipeline.load_eeg(subject, session)
63 | if hypers['PREPROCESSING.DO_MONTAGE']:
64 | raw = pipeline.apply_montage(raw, f'{eeg_data_dir}/{montage_file_name}.json')
65 |
66 | if hypers['PREPROCESSING.DO_REMOVE_DC']:
67 | raw = pipeline.remove_DC(raw)
68 |
69 | if hypers['PREPROCESSING.DO_LOW_PASS_FILTER']:
70 | raw = pipeline.apply_filter(raw, low_freq=hypers['PREPROCESSING.LOW_PASS_FILTER.FREQ'], high_freq=None)
71 |
72 | if hypers['PREPROCESSING.DO_HIGH_PASS_FILTER']:
73 | raw = pipeline.apply_filter(raw, low_freq=None, high_freq=hypers['PREPROCESSING.HIGH_PASS_FILTER.FREQ'])
74 |
75 | ica = None
76 | if hypers['PREPROCESSING.DO_USE_ICA']:
77 | ica = pipeline.compute_ICA(raw)
78 |
79 | if hypers['PREPROCESSING.DO_REMOVE_EOG']:
80 | ica = pipeline.remove_EOG(raw, ica)
81 | # ica = pipeline.remove_ECG(raw, ica) # Sometimes works, sometimes does not, seems to be an issue with MNE
82 | raw = pipeline.apply_ICA_to_RAW(raw, ica)
83 | del ica # It is no longer needed, so we delete it from memory
84 |
85 | _, _, epochs, _ = pipeline.generate_events(raw)
86 |             path = f'{eeg_data_dir}/preprocessed/Subject {subject}/Session {session}/{hypers["META.CONFIG_NAME"]}/{unique_id}'
87 | os.makedirs(path, exist_ok=True)
88 |
89 | if hypers['META.RAW.DO_SAVE']:
90 | raw.save(f'{path}/sub_{subject}_sess_{session}_preprocessed_raw.fif.gz')
91 |
92 | with open(f'{path}/sub_{subject}_sess_{session}_hyperparams.json', 'w') as f:
93 | json.dump(hypers, f, sort_keys=True, indent=4)
94 |
95 | del raw
96 |
97 | A, B, C = ['imagined', 'perceived'], ['guitar', 'penguin', 'flower'], ['text', 'sound', 'pictorial']
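            # Inferred from the label strings (not stated in the original code):
            # A = imagery condition, B = stimulus class, C = presentation modality;
            # select_specific_epochs presumably keeps only epochs matching these labels.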
98 | select_epochs = pipeline.select_specific_epochs(epochs, A, B, C)
99 |
100 | cropped_epochs = pipeline.crop_epochs(select_epochs)
101 | del select_epochs
102 |
103 | print('All preprocessing now complete, saving images!')
104 |
105 |             pbar_epochs = tqdm(total=len(cropped_epochs), desc='Epoch progress')
106 |
107 | for i, p in enumerate(zip(cropped_epochs, cropped_epochs.event_id)):
108 | epoch, name = p
109 |                 images = pipeline.generate_eeg_dataset(
110 |                     epoch.squeeze(),  # Remove the outer dimension, which is always 1
111 |                     per_channel=hypers['RENDER.DO_PER_CHANNEL'],
112 |                     window_width_seconds=hypers['RENDER.WINDOW_WIDTH'],
113 |                     window_overlap_seconds=hypers['RENDER.WINDOW_OVERLAP']
114 |                 )
115 | # pbar_channels = tqdm(images.shape[0], position=1, desc='Channel progress')
116 |
117 | for c, channel in enumerate(images):
118 |                     channel_dir = f'{output_dir}/subject_{subject}/session_{session}/{hypers["META.CONFIG_NAME"]}/{unique_id}/channel_{c}'
119 |                     os.makedirs(channel_dir, exist_ok=True)
120 |                     # pbar_event = tqdm(channel.shape[0], position=2, desc='Event progress')
121 |                     for e, event in enumerate(channel):
122 |                         im = Image.fromarray(event, 'L')
123 |                         im.save(f'{channel_dir}/epoch_{i}_channel_{c}_event_{e}_{name}.jpg')
124 |
125 | # pbar_event.update(1)
126 |
127 | # pbar_channels.update(1)
128 | # pbar_event.close()
129 |
130 | pbar_epochs.update(1)
131 | # pbar_channels.close()
132 |
133 | print(f'Finished saving images for epoch {i}. Name={name}')
134 |
135 | # pbar_epochs.close()
136 |
137 | pbar_subjects.update(1)
138 | print(f'Completed preprocessing for subject {subject}, session {session}')
139 |
140 | del cropped_epochs
141 | except Exception as e:
142 |             print(f'Encountered an exception for Subject {subject}, session {session}. The exception was:')
143 | print(e)
144 | continue
145 |
146 |
147 | def load_and_process_hyperparameters(config_dir: str, shuffle=True):
148 |     print(f'Loading configs from dir: {config_dir}')
149 | configs = []
150 |
151 |     for config_path in glob.iglob(f'{config_dir}/**/**/*.json'):
152 | with open(config_path, 'r') as f:
153 | configs.append(json.load(f))
154 | # print('Loaded config:', json.load(f))
155 |
156 |     print(f'Found {len(configs)} configs to be processed...')
157 |     return random.sample(configs, len(configs)) if shuffle else configs  # random.shuffle returns None, so return a shuffled copy
158 |
159 |
160 | if __name__ == '__main__':
161 | print('Here we go...')
162 | num_cpu = '8'
163 | os.environ['OMP_NUM_THREADS'] = num_cpu
164 | ALLOW_DEFAULT_CONFIG = False
165 |
166 | configs = load_and_process_hyperparameters('./data/configurations')
167 | for i, config in enumerate(configs):
168 | if config['META.CONFIG_NAME'] == 'default-config' and not ALLOW_DEFAULT_CONFIG:
169 | continue
170 |         print(f'Now processing with the following configuration ({i + 1} of {len(configs)}):')
171 | print(config)
172 | try:
173 | preprocess(hypers=config)
174 | except Exception as e:
175 |             print(f'An exception was thrown; could not perform preprocessing for the given config: {e}')
176 | print(config)
177 |
178 | print('Every config has now been processed, hurray! Terminating...')
179 |
--------------------------------------------------------------------------------
/util/resize.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | from PIL import Image, ImageOps, ImageFile
5 |
6 | ImageFile.LOAD_TRUNCATED_IMAGES = True
7 |
8 | targets = [
9 | (32, 32),
10 | (48, 48),
11 | (48, 32),
12 | (32, 16),
13 | (64, 32),
14 | (64, 48),
15 | (64, 64),
16 | (96, 96),
17 | (128, 128),
18 | (256, 256),
19 | (512, 512),
20 | (1024, 1024),
21 | (256, 128),
22 | (512, 128),
23 | (1024, 512),
24 | (1024, 256),
25 | (96, 48)
26 | ]
27 | target_dirs = ['flower', 'guitar', 'penguin']
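# For every (width, height) target and class directory, the loop below writes a
# centre-cropped copy of each image into a sibling folder named '<class>-<w>x<h>'
# (e.g. './guitar-64x32/'); ImageOps.fit resizes and crops to exactly (w, h).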
28 |
29 | for w, h in targets:
30 | for d in target_dirs:
31 |
32 | os.makedirs(f'{d}-{w}x{h}', exist_ok=True)
33 | for f in glob.iglob(f'./{d}/*.*'):
34 | print(w, h, f)
35 |             name = os.path.basename(f)  # keep only the file name
36 | img = Image.open(f)
37 | img = ImageOps.fit(img, (w, h), Image.Resampling.LANCZOS)
38 | img.save(f'{d}-{w}x{h}/{name}')
39 |
--------------------------------------------------------------------------------
/util/update_configs.py:
--------------------------------------------------------------------------------
1 | import wandb
2 |
3 | api = wandb.Api()
4 | run = api.run("jd202/bath-thesis/1fqllayw")
5 |
6 | run.config['learning_rate'] = 3e-4
7 | run.config['training_timesteps'] = 5000
8 | run.config['sampling_timesteps'] = 250
9 | run.config['image_size'] = 32
10 | run.config['number_of_samples'] = 25
11 | run.config['batch_size'] = 512
12 | run.config['use_amp'] = False
13 | run.config['use_fp16'] = True
14 | run.config['gradient_accumulation_rate'] = 2
15 | run.config['ema_update_rate'] = 10
16 | run.config['ema_decay'] = 0.995
17 | run.config['adam_betas'] = (0.9, 0.99)
18 | run.config['save_and_sample_rate'] = 1000
19 | run.config['do_split_batches'] = False
20 | run.config['timesteps'] = 1000
21 | run.config['loss_type'] = 'L2'
22 | run.config['unet_dim'] = 16
23 | run.config['unet_mults'] = (1, 2, 4, 8)
24 | run.config['unet_channels'] = 3
25 | run.config['training_objective'] = 'pred_x0'
26 |
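# run.update() pushes the edited values back to the W&B backend; without it the
# changes above exist only locally on this Api run object.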
27 | run.update()
28 |
--------------------------------------------------------------------------------
/visualise.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import wandb
4 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
5 |
6 | torch.cuda.empty_cache()
7 | wandb.login()
8 |
9 | default_hypers = dict(
10 | learning_rate=3e-4,
11 | training_timesteps=1001,
12 | sampling_timesteps=250,
13 | image_size=32,
14 | number_of_samples=25,
15 | batch_size=256,
16 | use_amp=False,
17 | use_fp16=False,
18 | gradient_accumulation_rate=2,
19 | ema_update_rate=10,
20 | ema_decay=0.995,
21 | adam_betas=(0.9, 0.99),
22 | save_and_sample_rate=1000,
23 | do_split_batches=False,
24 | timesteps=4000,
25 | loss_type='L2',
26 | unet_dim=128,
27 | unet_mults=(1, 2, 2, 2),
28 | unet_channels=3,
29 | training_objective='pred_x0'
30 | )
31 |
32 | wandb.init(config=default_hypers, project='bath-thesis', entity='jd202')
33 |
34 | model = Unet(
35 | dim=wandb.config.unet_dim,
36 | dim_mults=wandb.config.unet_mults,
37 | channels=wandb.config.unet_channels
38 | )
39 |
40 | diffusion = GaussianDiffusion(
41 | model,
42 | image_size=wandb.config.image_size,
43 | timesteps=wandb.config.timesteps, # number of steps
44 | sampling_timesteps=wandb.config.sampling_timesteps,
45 | # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
46 | loss_type=wandb.config.loss_type, # L1 or L2
47 | training_objective=wandb.config.training_objective
48 | )
49 |
50 | trainer = Trainer(
51 | diffusion,
52 | '/Users/jake/Desktop/scp/cifar',
53 | train_batch_size=wandb.config.batch_size,
54 | training_learning_rate=wandb.config.learning_rate,
55 | num_training_steps=wandb.config.training_timesteps, # total training steps
56 | num_samples=wandb.config.number_of_samples,
57 | gradient_accumulate_every=wandb.config.gradient_accumulation_rate, # gradient accumulation steps
58 | ema_update_every=wandb.config.ema_update_rate,
59 | ema_decay=wandb.config.ema_decay, # exponential moving average decay
60 | amp=wandb.config.use_amp, # turn on mixed precision
61 | fp16=wandb.config.use_fp16,
62 | save_and_sample_every=wandb.config.save_and_sample_rate
63 | )
64 |
65 | trainer.load('./results/loadins', '56')
66 | trainer.ema.ema_model.eval()
67 |
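# Assumed intent: export the restored network as a plain state_dict. Note this
# saves the raw (un-averaged) Unet weights; use trainer.ema.ema_model.state_dict()
# instead if the EMA weights are wanted.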
68 | torch.save(model.state_dict(), './unet_state_dict.pt')  # torch.save needs a file path, not a directory; this filename is arbitrary
69 |
--------------------------------------------------------------------------------