├── .gitignore ├── README.md ├── classifier.py ├── doe.py ├── propagate.py ├── requirements.txt ├── schematic.png ├── spatial_coherence.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | data/ 163 | .DS_Store 164 | lightning_logs/ 165 | *.mp4 166 | *.pt 167 | *.npy 168 | .vscode/ 169 | 170 | # Exclude build directories often generated by documentation tools like Sphinx 171 | _build/ 172 | _build/html 173 | _build/doctrees 174 | 175 | # Exclude files created by sphinx-apidoc 176 | modules.rst 177 | 178 | # Exclude any .pyc files that might be created 179 | *.pyc 180 | 181 | # Exclude any virtual environment directories 182 | venv/ 183 | .env/ 184 | 185 | # Exclude other files/directories that are specific to your project or documentation workflow 186 | *.log 187 | *.tmp 188 | runs/ 189 | *logs/ 190 | *venv* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffractive Optical Neural Networks with Coherence 2 | 3 | > PyTorch implementation of diffractive optical neural networks under arbitrary spatial coherence. 4 | 5 |

6 | <img src="schematic.png" alt="Schematic of the diffractive optical neural network"> 7 |

8 | 9 | Supplementary code from our [paper](https://arxiv.org/abs/2310.03679). 10 | 11 | ## Usage 12 | Run the following to train a model on the MNIST dataset: 13 | ```bash 14 | python train.py --coherence-degree=1 --wavelength=700e-9 --pixel-size=10e-6 15 | ``` 16 | -------------------------------------------------------------------------------- /classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | 4 | 5 | class Classifier(Module): 6 | def __init__(self, shape, region_size): 7 | super().__init__() 8 | if shape < 4 * region_size: 9 | raise ValueError("shape must be at least 4*region_size") 10 | 11 | weight = torch.zeros(10, shape, shape, dtype=torch.double) 12 | row_offset = (shape - 4 * region_size) // 2 13 | col_offset = (shape - 3 * region_size) // 2 14 | 15 | # Function to set a region to 1 16 | def set_region(digit, row, col): 17 | start_row = row * (region_size) + row_offset 18 | start_col = col * (region_size) + col_offset 19 | weight[ 20 | digit, 21 | start_row : start_row + region_size, 22 | start_col : start_col + region_size, 23 | ] = 1 24 | 25 | # Add the bottom row representing "zero" (special case) 26 | set_region(0, 3, 1) 27 | 28 | # Add the top three rows from left to right 29 | for digit in range(1, 10): 30 | row, col = (digit - 1) // 3, (digit - 1) % 3 31 | set_region(digit, row, col) 32 | 33 | self.register_buffer("weight", weight, persistent=False) 34 | 35 | def forward(self, x): 36 | return torch.einsum("nxy,bxy->bn", self.weight, x) 37 | -------------------------------------------------------------------------------- /doe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module, Parameter 3 | 4 | 5 | class DOE(Module): 6 | def __init__(self, shape: int): 7 | super().__init__() 8 | self.phase_params = Parameter(2 * torch.pi * torch.rand(shape, shape)) 9 | 10 | def forward(self, x): 11 | return torch.exp(1j * self.phase_params) * x 12 | -------------------------------------------------------------------------------- /propagate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.fft import fft2, ifft2 3 | from torch.nn import Module 4 | 5 | 6 | class Propagate(Module): 7 | def __init__( 8 | self, 9 | preceding_shape: int, 10 | succeeding_shape: int, 11 | propagation_distance: float, 12 | wavelength: float, 13 | pixel_size: float, 14 | ): 15 | super().__init__() 16 | grid_extent = (preceding_shape + succeeding_shape) / 2 17 | coords = torch.arange(-grid_extent + 1, grid_extent, dtype=torch.double) 18 | x, y = torch.meshgrid(coords * pixel_size, coords * pixel_size, indexing="ij") 19 | 20 | r_squared = x**2 + y**2 + propagation_distance**2 21 | r = torch.sqrt(r_squared) 22 | impulse_response = ( 23 | (propagation_distance / r_squared * (1 / (2 * torch.pi * r) - 1.0j / wavelength)) 24 | * torch.exp(2j * torch.pi * r / wavelength) 25 | * pixel_size**2 26 | ) 27 | self.register_buffer("impulse_response_ft", fft2(impulse_response), persistent=False) 28 | 29 | def forward(self, field: torch.Tensor) -> torch.Tensor: 30 | return conv2d_fft(self.impulse_response_ft, field) 31 | 32 | 33 | def conv2d_fft(H_fr: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 34 | """Performs a 2D convolution using Fast Fourier Transforms (FFT). 35 | 36 | Args: 37 | H_fr (torch.Tensor): Fourier-transformed transfer function. 38 | x (torch.Tensor): Input complex field. 
39 | 40 | Returns: 41 | torch.Tensor: Output field after convolution. 42 | """ 43 | output_size = (H_fr.size(-2) - x.size(-2) + 1, H_fr.size(-1) - x.size(-1) + 1) 44 | x_fr = fft2(x.flip(-1, -2).conj(), s=(H_fr.size(-2), H_fr.size(-1))) 45 | output_fr = H_fr * x_fr.conj() 46 | output = ifft2(output_fr)[..., : output_size[0], : output_size[1]].clone() 47 | return output 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch_lightning==2.0.2 2 | torch==2.1.2 3 | torchmetrics==0.11.4 4 | torchvision==0.15.2 5 | -------------------------------------------------------------------------------- /schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatthewFilipovich/diffractive-optical-neural-networks-with-coherence/a0f61449bd432d926c9326d62510164367ee0789/schematic.png -------------------------------------------------------------------------------- /spatial_coherence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_exponentially_decaying_spatial_coherence(field, coherence_degree): 5 | if coherence_degree < 0 or coherence_degree > 1: 6 | raise ValueError("Coherence degree must be between 0 and 1.") 7 | xv, yv = torch.meshgrid( 8 | torch.arange(field.shape[-1], device=field.device, dtype=torch.double), 9 | torch.arange(field.shape[-1], device=field.device, dtype=torch.double), 10 | indexing="ij", 11 | ) 12 | new_xv = xv - xv[..., None, None] 13 | new_yv = yv - yv[..., None, None] 14 | r = torch.sqrt(new_xv**2 + new_yv**2) 15 | return (field[..., None, None, :, :] * field.conj()[..., None, None]) * coherence_degree**r 16 | 17 | 18 | def get_source_modes(shape, image_pixel_size): # shape would be 28 for MNIST 19 | source_modes = torch.zeros( 20 | shape**2, # Number of source modes i.e., total input pixels 21 | shape * image_pixel_size, # Nx 22 | shape * image_pixel_size, # Ny 23 | dtype=torch.cdouble, 24 | ) 25 | for i in range(shape): 26 | for j in range(shape): 27 | source_modes[ 28 | i * shape + j, 29 | i * image_pixel_size : (i + 1) * image_pixel_size, 30 | j * image_pixel_size : (j + 1) * image_pixel_size, 31 | ] = 1 32 | return source_modes 33 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from pytorch_lightning.callbacks import ModelCheckpoint 6 | from torch.nn import ModuleList 7 | from torch.nn.functional import cross_entropy 8 | from torch.optim.lr_scheduler import StepLR 9 | from torch.utils.data import DataLoader, random_split 10 | from torchmetrics import Accuracy 11 | from torchvision import datasets, transforms 12 | 13 | from classifier import Classifier 14 | from doe import DOE 15 | from propagate import Propagate 16 | from spatial_coherence import get_exponentially_decaying_spatial_coherence, get_source_modes 17 | 18 | 19 | class DiffractiveSystem(pl.LightningModule): 20 | def __init__(self, learning_rate, gamma, coherence_degree, wavelength, pixel_size): 21 | super().__init__() 22 | self.save_hyperparameters() 23 | self.doe_list = ModuleList([DOE(shape=100) for _ in range(5)]) 24 | self.initial_propagate = Propagate( 25 | preceding_shape=28 * 4, 26 | succeeding_shape=100, 27 | 
propagation_distance=0.05, 28 | wavelength=wavelength, 29 | pixel_size=pixel_size, 30 | ) 31 | self.intralayer_propagate = Propagate( 32 | preceding_shape=100, 33 | succeeding_shape=100, 34 | propagation_distance=0.05, 35 | wavelength=wavelength, 36 | pixel_size=pixel_size, 37 | ) 38 | self.classifier = Classifier(shape=100, region_size=25) 39 | self.source_modes = get_source_modes(shape=28, image_pixel_size=4) 40 | self.accuracy = Accuracy("multiclass", num_classes=10) 41 | 42 | def forward(self, x): 43 | coherence_tensor = get_exponentially_decaying_spatial_coherence( 44 | torch.squeeze(x, -3).to(torch.cdouble), self.hparams.coherence_degree 45 | ) 46 | 47 | modes = self.source_modes 48 | modes = self.initial_propagate(modes) 49 | for doe in self.doe_list: 50 | modes = doe(modes) 51 | modes = self.intralayer_propagate(modes) 52 | 53 | batch_size = coherence_tensor.shape[0] 54 | total_input_pixels = coherence_tensor.shape[-2] * coherence_tensor.shape[-1] 55 | total_output_pixels = modes.shape[-2] * modes.shape[-1] 56 | output_intensity = ( 57 | torch.einsum( # Reduce precision to cfloat for performance 58 | "bij, io, jo-> bo", 59 | coherence_tensor.view(batch_size, total_input_pixels, total_input_pixels).to(torch.cfloat), 60 | modes.view(total_input_pixels, total_output_pixels).conj().to(torch.cfloat), 61 | modes.view(total_input_pixels, total_output_pixels).to(torch.cfloat), 62 | ) 63 | .real.view(batch_size, *modes.shape[-2:]) 64 | .to(torch.double) 65 | ) 66 | return self.classifier(output_intensity) 67 | 68 | def training_step(self, batch, batch_idx): 69 | data, target = batch 70 | output = self(data) 71 | loss = cross_entropy(output, target) 72 | acc = self.accuracy(output, target) 73 | self.log("train_loss", loss) 74 | self.log("train_acc", acc) 75 | self.log("learning_rate", self.trainer.optimizers[0].param_groups[0]["lr"]) 76 | return loss 77 | 78 | def validation_step(self, batch, batch_idx): 79 | data, target = batch 80 | output = self(data) 81 | loss = cross_entropy(output, target) 82 | acc = self.accuracy(output, target) 83 | self.log("val_loss", loss) 84 | self.log("val_acc", acc) 85 | 86 | def test_step(self, batch, batch_idx): 87 | data, target = batch 88 | output = self(data) 89 | loss = cross_entropy(output, target) 90 | acc = self.accuracy(output, target) 91 | self.log("test_loss", loss) 92 | self.log("test_acc", acc) 93 | 94 | def configure_optimizers(self): 95 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) 96 | scheduler = StepLR(optimizer, step_size=1, gamma=self.hparams.gamma) 97 | return [optimizer], [scheduler] 98 | 99 | 100 | def main(args): 101 | torch.manual_seed(args.seed) 102 | transform_list = [ 103 | transforms.ToTensor(), 104 | transforms.Normalize((0.1307,), (0.3081,)), 105 | ] 106 | 107 | transform = transforms.Compose(transform_list) 108 | dataset = datasets.MNIST("../data", train=True, download=True, transform=transform) 109 | train_size = int(0.8 * len(dataset)) 110 | validation_size = len(dataset) - train_size 111 | train_dataset, val_dataset = random_split(dataset, [train_size, validation_size]) 112 | 113 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) 114 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers) 115 | test_loader = DataLoader( 116 | datasets.MNIST("../data", train=False, transform=transform), 117 | batch_size=args.batch_size, 118 | num_workers=args.num_workers, 119 | ) 120 | 121 | model = 
DiffractiveSystem(args.lr, args.gamma, args.coherence_degree, args.wavelength, args.pixel_size) 122 | checkpoint_callback = ModelCheckpoint( 123 | monitor="val_acc", 124 | mode="max", 125 | save_top_k=1, 126 | verbose=True, 127 | ) 128 | accelerator = "cuda" if torch.cuda.is_available() else "cpu" 129 | trainer = pl.Trainer(max_epochs=args.epochs, callbacks=[checkpoint_callback], accelerator=accelerator) 130 | trainer.fit(model, train_loader, val_loader) 131 | trainer.test(dataloaders=test_loader, ckpt_path="best") 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser(description="PyTorch MNIST Example") 136 | parser.add_argument("--coherence-degree", type=float, required=True, help="coherence degree") 137 | parser.add_argument("--wavelength", type=float, default=700e-9, help="field wavelength (default: 700 nm)") 138 | parser.add_argument("--pixel-size", type=float, default=10e-6, help="field pixel size (default: 10 um)") 139 | parser.add_argument( 140 | "--batch-size", type=int, default=32, help="input batch size for training (default: 32)" 141 | ) 142 | parser.add_argument("--epochs", type=int, default=50, help="number of epochs to train (default: 50)") 143 | parser.add_argument("--lr", type=float, default=1e-2, help="learning rate (default: 1e-2)") 144 | parser.add_argument("--num-workers", type=int, default=1, help="number of workers (default: 1)") 145 | parser.add_argument("--gamma", type=float, default=0.95, help="Learning rate step gamma (default: 0.95)") 146 | parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)") 147 | args = parser.parse_args() 148 | main(args) 149 | --------------------------------------------------------------------------------
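Note on the coherence contraction in `DiffractiveSystem.forward`: the einsum `"bij, io, jo-> bo"` computes the partially coherent output intensity I(o) = Σ_ij J_ij · conj(U_i(o)) · U_j(o), where J is the mutual-coherence matrix built by `get_exponentially_decaying_spatial_coherence` and U_i are the modes from `get_source_modes` after propagation through the DOE stack. Below is a minimal self-contained sketch (toy shapes and variable names chosen here for illustration, not taken from the repository) that checks the einsum against the explicit double sum:

```python
import torch

# Toy sizes standing in for the real ones (784 input pixels, 100x100 output plane).
batch, n_in, n_out = 2, 4, 6

J = torch.randn(batch, n_in, n_in, dtype=torch.cfloat)  # mutual-coherence matrix J_ij per image
U = torch.randn(n_in, n_out, dtype=torch.cfloat)        # propagated source modes, flattened over output pixels

# Contraction used in DiffractiveSystem.forward: I_bo = sum_ij J_bij * conj(U_io) * U_jo
intensity_einsum = torch.einsum("bij,io,jo->bo", J, U.conj(), U).real

# Same quantity via an explicit double sum over source-mode pairs.
intensity_loop = torch.zeros(batch, n_out)
for i in range(n_in):
    for j in range(n_in):
        intensity_loop += (J[:, i, j, None] * U[i].conj() * U[j]).real

assert torch.allclose(intensity_einsum, intensity_loop, atol=1e-4)
```

Because the contraction is linear in J, the source modes only need to be propagated through the optical stack once per forward pass; the same propagated modes are then reused for every image in the batch and for any coherence degree.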