├── .gitignore
├── README.md
├── classifier.py
├── doe.py
├── propagate.py
├── requirements.txt
├── schematic.png
├── spatial_coherence.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | data/
163 | .DS_Store
164 | lightning_logs/
165 | *.mp4
166 | *.pt
167 | *.npy
168 | .vscode/
169 |
170 | # Exclude build directories often generated by documentation tools like Sphinx
171 | _build/
172 | _build/html
173 | _build/doctrees
174 |
175 | # Exclude files created by sphinx-apidoc
176 | modules.rst
177 |
178 | # Exclude any .pyc files that might be created
179 | *.pyc
180 |
181 | # Exclude any virtual environment directories
182 | venv/
183 | .env/
184 |
185 | # Exclude other files/directories that are specific to your project or documentation workflow
186 | *.log
187 | *.tmp
188 | runs/
189 | *logs/
190 | *venv*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffractive Optical Neural Networks with Coherence
2 |
3 | > PyTorch implementation of diffractive optical neural networks under arbitrary spatial coherence.
4 |
5 | ![Schematic of the diffractive optical neural network](schematic.png)
6 |
7 |
8 |
9 | Supplementary code from our [paper](https://arxiv.org/abs/2310.03679).
10 |
11 | ## Usage
12 | Run the following to train a model on the MNIST dataset with a fully coherent source (`--coherence-degree=1`):
13 | ```bash
14 | python train.py --coherence-degree=1 --wavelength=700e-9 --pixel-size=10e-6
15 | ```
16 |
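17 | Before training, install the pinned dependencies (assuming pip) with `pip install -r requirements.txt`.
18 |
19 | ## Example
20 | A minimal sketch of how the building blocks compose (illustrative only, not part of the training pipeline; the shapes, distances, and defaults mirror `train.py`): a fully coherent field is propagated to a single trainable phase mask and then to the detector plane.
21 | ```python
22 | import torch
23 |
24 | from doe import DOE
25 | from propagate import Propagate
26 |
27 | field = torch.ones(112, 112, dtype=torch.cdouble)  # a 28x28 image upsampled 4x, as in train.py
28 | to_mask = Propagate(preceding_shape=112, succeeding_shape=100,
29 |                     propagation_distance=0.05, wavelength=700e-9, pixel_size=10e-6)
30 | mask = DOE(shape=100)  # trainable 100x100 phase mask
31 | to_detector = Propagate(preceding_shape=100, succeeding_shape=100,
32 |                         propagation_distance=0.05, wavelength=700e-9, pixel_size=10e-6)
33 | intensity = to_detector(mask(to_mask(field))).abs() ** 2  # 100x100 detector-plane intensity
34 | ```
35 |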
--------------------------------------------------------------------------------
/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module
3 |
4 |
5 | class Classifier(Module):  # Sums detector-plane intensity over ten fixed regions, one per MNIST digit
6 | def __init__(self, shape, region_size):
7 | super().__init__()
8 | if shape < 4 * region_size:
9 | raise ValueError("shape must be at least 4*region_size")
10 |
11 | weight = torch.zeros(10, shape, shape, dtype=torch.double)
12 | row_offset = (shape - 4 * region_size) // 2
13 | col_offset = (shape - 3 * region_size) // 2
14 |
15 | # Function to set a region to 1
16 | def set_region(digit, row, col):
17 | start_row = row * (region_size) + row_offset
18 | start_col = col * (region_size) + col_offset
19 | weight[
20 | digit,
21 | start_row : start_row + region_size,
22 | start_col : start_col + region_size,
23 | ] = 1
24 |
25 |         # Digit zero occupies a single region in the bottom row (center column)
26 |         set_region(0, 3, 1)
27 |
28 |         # Digits 1-9 fill the top three rows from left to right
29 | for digit in range(1, 10):
30 | row, col = (digit - 1) // 3, (digit - 1) % 3
31 | set_region(digit, row, col)
32 |
33 | self.register_buffer("weight", weight, persistent=False)
34 |
35 | def forward(self, x):
36 | return torch.einsum("nxy,bxy->bn", self.weight, x)
37 |
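38 |
39 | if __name__ == "__main__":
40 |     # Illustrative demo (runs only when this file is executed directly): each digit reads out a
41 |     # 25x25 detector region, matching the Classifier(shape=100, region_size=25) used in train.py.
42 |     classifier = Classifier(shape=100, region_size=25)
43 |     intensity = torch.rand(8, 100, 100, dtype=torch.double)  # placeholder detector-plane intensities
44 |     print(classifier(intensity).shape)  # torch.Size([8, 10]), one integrated intensity per digit
45 |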
--------------------------------------------------------------------------------
/doe.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module, Parameter
3 |
4 |
5 | class DOE(Module):  # Diffractive optical element: a trainable phase-only mask applied to the incident field
6 | def __init__(self, shape: int):
7 | super().__init__()
8 | self.phase_params = Parameter(2 * torch.pi * torch.rand(shape, shape))
9 |
10 | def forward(self, x):
11 | return torch.exp(1j * self.phase_params) * x
12 |
--------------------------------------------------------------------------------
/propagate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.fft import fft2, ifft2
3 | from torch.nn import Module
4 |
5 |
6 | class Propagate(Module):
7 | def __init__(
8 | self,
9 | preceding_shape: int,
10 | succeeding_shape: int,
11 | propagation_distance: float,
12 | wavelength: float,
13 | pixel_size: float,
14 | ):
15 | super().__init__()
16 |         grid_extent = (preceding_shape + succeeding_shape) / 2  # kernel sized so the valid convolution output has succeeding_shape pixels per side
17 | coords = torch.arange(-grid_extent + 1, grid_extent, dtype=torch.double)
18 | x, y = torch.meshgrid(coords * pixel_size, coords * pixel_size, indexing="ij")
19 |
20 | r_squared = x**2 + y**2 + propagation_distance**2
21 | r = torch.sqrt(r_squared)
22 |         impulse_response = (  # first Rayleigh-Sommerfeld impulse response, sampled on the pixel grid
23 | (propagation_distance / r_squared * (1 / (2 * torch.pi * r) - 1.0j / wavelength))
24 | * torch.exp(2j * torch.pi * r / wavelength)
25 | * pixel_size**2
26 | )
27 | self.register_buffer("impulse_response_ft", fft2(impulse_response), persistent=False)
28 |
29 | def forward(self, field: torch.Tensor) -> torch.Tensor:
30 | return conv2d_fft(self.impulse_response_ft, field)
31 |
32 |
33 | def conv2d_fft(H_fr: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
34 |     """Computes the valid region of the 2D linear convolution of the impulse response with the field using FFTs.
35 |
36 |     Args:
37 |         H_fr (torch.Tensor): Fourier-transformed impulse response, larger than ``x`` in both spatial dimensions.
38 |         x (torch.Tensor): Input complex field.
39 |
40 |     Returns:
41 |         torch.Tensor: Output field of spatial size ``H - x + 1`` per axis (the valid region of the linear convolution).
42 |     """
43 | output_size = (H_fr.size(-2) - x.size(-2) + 1, H_fr.size(-1) - x.size(-1) + 1)
44 |     x_fr = fft2(x.flip(-1, -2).conj(), s=(H_fr.size(-2), H_fr.size(-1)))  # zero-pad x to the kernel size
45 |     output_fr = H_fr * x_fr.conj()  # the flip and double conjugation align the valid region at index 0
46 | output = ifft2(output_fr)[..., : output_size[0], : output_size[1]].clone()
47 | return output
48 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch_lightning==2.0.2
2 | torch==2.0.1
3 | torchmetrics==0.11.4
4 | torchvision==0.15.2
5 |
--------------------------------------------------------------------------------
/schematic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MatthewFilipovich/diffractive-optical-neural-networks-with-coherence/a0f61449bd432d926c9326d62510164367ee0789/schematic.png
--------------------------------------------------------------------------------
/spatial_coherence.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def get_exponentially_decaying_spatial_coherence(field, coherence_degree):
5 | if coherence_degree < 0 or coherence_degree > 1:
6 | raise ValueError("Coherence degree must be between 0 and 1.")
7 | xv, yv = torch.meshgrid(
8 | torch.arange(field.shape[-1], device=field.device, dtype=torch.double),
9 | torch.arange(field.shape[-1], device=field.device, dtype=torch.double),
10 | indexing="ij",
11 | )
12 | new_xv = xv - xv[..., None, None]
13 | new_yv = yv - yv[..., None, None]
14 |     r = torch.sqrt(new_xv**2 + new_yv**2)  # pairwise point separations; the degree of coherence decays as coherence_degree**r
15 | return (field[..., None, None, :, :] * field.conj()[..., None, None]) * coherence_degree**r
16 |
17 |
18 | def get_source_modes(shape, image_pixel_size):  # shape: image side length (28 for MNIST); image_pixel_size: field pixels per image pixel (4 in train.py)
19 | source_modes = torch.zeros(
20 | shape**2, # Number of source modes i.e., total input pixels
21 | shape * image_pixel_size, # Nx
22 | shape * image_pixel_size, # Ny
23 | dtype=torch.cdouble,
24 | )
25 | for i in range(shape):
26 | for j in range(shape):
27 | source_modes[
28 | i * shape + j,
29 | i * image_pixel_size : (i + 1) * image_pixel_size,
30 | j * image_pixel_size : (j + 1) * image_pixel_size,
31 | ] = 1
32 | return source_modes
33 |
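34 |
35 | if __name__ == "__main__":
36 |     # Illustrative demo (runs only when this file is executed directly): tensor shapes for an MNIST-sized input.
37 |     field = torch.randn(2, 28, 28, dtype=torch.cdouble)  # batch of 2 images treated as complex fields
38 |     coherence = get_exponentially_decaying_spatial_coherence(field, coherence_degree=0.5)
39 |     modes = get_source_modes(shape=28, image_pixel_size=4)
40 |     print(coherence.shape)  # torch.Size([2, 28, 28, 28, 28])
41 |     print(modes.shape)  # torch.Size([784, 112, 112])
42 |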
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | from pytorch_lightning.callbacks import ModelCheckpoint
6 | from torch.nn import ModuleList
7 | from torch.nn.functional import cross_entropy
8 | from torch.optim.lr_scheduler import StepLR
9 | from torch.utils.data import DataLoader, random_split
10 | from torchmetrics import Accuracy
11 | from torchvision import datasets, transforms
12 |
13 | from classifier import Classifier
14 | from doe import DOE
15 | from propagate import Propagate
16 | from spatial_coherence import get_exponentially_decaying_spatial_coherence, get_source_modes
17 |
18 |
19 | class DiffractiveSystem(pl.LightningModule):  # five trainable phase masks separated by free-space propagation, read out on fixed detector regions
20 | def __init__(self, learning_rate, gamma, coherence_degree, wavelength, pixel_size):
21 | super().__init__()
22 | self.save_hyperparameters()
23 | self.doe_list = ModuleList([DOE(shape=100) for _ in range(5)])
24 | self.initial_propagate = Propagate(
25 | preceding_shape=28 * 4,
26 | succeeding_shape=100,
27 | propagation_distance=0.05,
28 | wavelength=wavelength,
29 | pixel_size=pixel_size,
30 | )
31 | self.intralayer_propagate = Propagate(
32 | preceding_shape=100,
33 | succeeding_shape=100,
34 | propagation_distance=0.05,
35 | wavelength=wavelength,
36 | pixel_size=pixel_size,
37 | )
38 | self.classifier = Classifier(shape=100, region_size=25)
39 |         self.register_buffer("source_modes", get_source_modes(shape=28, image_pixel_size=4), persistent=False)  # buffer so it moves to the training device with the module
40 | self.accuracy = Accuracy("multiclass", num_classes=10)
41 |
42 | def forward(self, x):
43 |         coherence_tensor = get_exponentially_decaying_spatial_coherence(  # (batch, 28, 28, 28, 28) mutual coherence of the input image
44 | torch.squeeze(x, -3).to(torch.cdouble), self.hparams.coherence_degree
45 | )
46 |
47 |         modes = self.source_modes  # 784 source modes (one per input pixel), each a 112x112 field
48 | modes = self.initial_propagate(modes)
49 | for doe in self.doe_list:
50 | modes = doe(modes)
51 | modes = self.intralayer_propagate(modes)
52 |
53 | batch_size = coherence_tensor.shape[0]
54 | total_input_pixels = coherence_tensor.shape[-2] * coherence_tensor.shape[-1]
55 | total_output_pixels = modes.shape[-2] * modes.shape[-1]
56 | output_intensity = (
57 | torch.einsum( # Reduce precision to cfloat for performance
58 | "bij, io, jo-> bo",
59 | coherence_tensor.view(batch_size, total_input_pixels, total_input_pixels).to(torch.cfloat),
60 | modes.view(total_input_pixels, total_output_pixels).conj().to(torch.cfloat),
61 | modes.view(total_input_pixels, total_output_pixels).to(torch.cfloat),
62 | )
63 | .real.view(batch_size, *modes.shape[-2:])
64 | .to(torch.double)
65 | )
66 | return self.classifier(output_intensity)
67 |
68 | def training_step(self, batch, batch_idx):
69 | data, target = batch
70 | output = self(data)
71 | loss = cross_entropy(output, target)
72 | acc = self.accuracy(output, target)
73 | self.log("train_loss", loss)
74 | self.log("train_acc", acc)
75 | self.log("learning_rate", self.trainer.optimizers[0].param_groups[0]["lr"])
76 | return loss
77 |
78 | def validation_step(self, batch, batch_idx):
79 | data, target = batch
80 | output = self(data)
81 | loss = cross_entropy(output, target)
82 | acc = self.accuracy(output, target)
83 | self.log("val_loss", loss)
84 | self.log("val_acc", acc)
85 |
86 | def test_step(self, batch, batch_idx):
87 | data, target = batch
88 | output = self(data)
89 | loss = cross_entropy(output, target)
90 | acc = self.accuracy(output, target)
91 | self.log("test_loss", loss)
92 | self.log("test_acc", acc)
93 |
94 | def configure_optimizers(self):
95 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
96 | scheduler = StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)
97 | return [optimizer], [scheduler]
98 |
99 |
100 | def main(args):
101 | torch.manual_seed(args.seed)
102 | transform_list = [
103 | transforms.ToTensor(),
104 | transforms.Normalize((0.1307,), (0.3081,)),
105 | ]
106 |
107 | transform = transforms.Compose(transform_list)
108 | dataset = datasets.MNIST("../data", train=True, download=True, transform=transform)
109 | train_size = int(0.8 * len(dataset))
110 | validation_size = len(dataset) - train_size
111 | train_dataset, val_dataset = random_split(dataset, [train_size, validation_size])
112 |
113 |     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
114 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
115 | test_loader = DataLoader(
116 | datasets.MNIST("../data", train=False, transform=transform),
117 | batch_size=args.batch_size,
118 | num_workers=args.num_workers,
119 | )
120 |
121 | model = DiffractiveSystem(args.lr, args.gamma, args.coherence_degree, args.wavelength, args.pixel_size)
122 | checkpoint_callback = ModelCheckpoint(
123 | monitor="val_acc",
124 | mode="max",
125 | save_top_k=1,
126 | verbose=True,
127 | )
128 | accelerator = "cuda" if torch.cuda.is_available() else "cpu"
129 | trainer = pl.Trainer(max_epochs=args.epochs, callbacks=[checkpoint_callback], accelerator=accelerator)
130 | trainer.fit(model, train_loader, val_loader)
131 | trainer.test(dataloaders=test_loader, ckpt_path="best")
132 |
133 |
134 | if __name__ == "__main__":
135 |     parser = argparse.ArgumentParser(description="Train a diffractive optical neural network on MNIST")
136 | parser.add_argument("--coherence-degree", type=float, required=True, help="coherence degree")
137 | parser.add_argument("--wavelength", type=float, default=700e-9, help="field wavelength (default: 700 nm)")
138 | parser.add_argument("--pixel-size", type=float, default=10e-6, help="field pixel size (default: 10 um)")
139 | parser.add_argument(
140 | "--batch-size", type=int, default=32, help="input batch size for training (default: 32)"
141 | )
142 | parser.add_argument("--epochs", type=int, default=50, help="number of epochs to train (default: 50)")
143 | parser.add_argument("--lr", type=float, default=1e-2, help="learning rate (default: 1e-2)")
144 | parser.add_argument("--num-workers", type=int, default=1, help="number of workers (default: 1)")
145 | parser.add_argument("--gamma", type=float, default=0.95, help="Learning rate step gamma (default: 0.95)")
146 | parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)")
147 | args = parser.parse_args()
148 | main(args)
149 |
--------------------------------------------------------------------------------