├── .gitignore ├── .idea ├── ONNet.iml ├── misc.xml ├── modules.xml ├── other.xml ├── vcs.xml └── workspace.xml ├── ONNet_wavelet.png ├── README.md ├── case_brain.py ├── case_cifar.py ├── case_covir.py ├── case_dog_cat.py ├── case_face_detect.py ├── case_lung_mask.py ├── case_mnist.py ├── python-package ├── case_fft.py ├── cnn_models │ └── OpticalNet.py ├── fast_conv.py └── onnet │ ├── BinaryDNet.py │ ├── D2NN_tf.py │ ├── D2NNet.py │ ├── DiffractiveLayer.py │ ├── DropOutLayer.py │ ├── FFT_layer.py │ ├── Loss.py │ ├── NET_config.py │ ├── Net_Instance.py │ ├── OpticalFormer.py │ ├── OpticalFormer_util.py │ ├── PoolForCls.py │ ├── RGBO_CNN.py │ ├── SparseSupport.py │ ├── ToExcel.py │ ├── Visualizing.py │ ├── Z_utils.py │ ├── __init__.py │ ├── __version__.py │ ├── optical_trans.py │ └── some_utils.py └── venv └── pyvenv.cfg /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | 4 | # User-specific files 5 | *.suo 6 | *.user 7 | *.userosscache 8 | *.sln.docstates 9 | *.pth 10 | # User-specific files (MonoDevelop/Xamarin Studio) 11 | *.userprefs 12 | *.npy 13 | 14 | # Build results 15 | net-source 16 | dump/ 17 | runs/ 18 | cnn_models/ 19 | deap/ 20 | checkpoint/ 21 | python-package/net_source/ 22 | [Dd]ebug/ 23 | [Dd]ebugPublic/ 24 | [Rr]elease/ 25 | [Rr]eleases/ 26 | [Xx]64/ 27 | [Xx]86/ 28 | [Bb]uild/ 29 | bld/ 30 | [Bb]in/ 31 | [Oo]bj/ 32 | docs/_build 33 | tests/bin 34 | lib 35 | data 36 | _000 37 | python-package/dist 38 | .pytest_cache/v/cache 39 | 40 | # Visual Studio 2015 cache/options directory 41 | .vs/ 42 | # Uncomment if you have tasks that create the project's static files in wwwroot 43 | #wwwroot/ 44 | 45 | # MSTest test Results 46 | [Tt]est[Rr]esult*/ 47 | [Bb]uild[Ll]og.* 48 | 49 | # NUNIT 50 | *.VisualState.xml 51 | TestResult.xml 52 | 53 | # Build Results of an ATL Project 54 | [Dd]ebugPS/ 55 | [Rr]eleasePS/ 56 | dlldata.c 57 | 58 | # DNX 59 | project.lock.json 60 | artifacts/ 61 | 62 | # Python 63 | *.egg-info 64 | __pycache__ 65 | .eggs 66 | 67 | # VS Code 68 | .vscode 69 | 70 | # Prerequisites 71 | *.d 72 | 73 | # Compiled Object files 74 | *.slo 75 | *.lo 76 | *.o 77 | *.obj 78 | 79 | # Precompiled Headers 80 | *.gch 81 | 82 | *_i.c 83 | *_p.c 84 | *_i.h 85 | *.ilk 86 | *.meta 87 | *.obj 88 | *.pch 89 | *.pdb 90 | *.pgc 91 | *.pgd 92 | *.rsp 93 | *.sbr 94 | *.tlb 95 | *.tli 96 | *.tlh 97 | *.tmp 98 | *.tmp_proj 99 | *.log 100 | *.vspscc 101 | *.vssscc 102 | .builds 103 | *.pidb 104 | *.svclog 105 | *.scc 106 | *.rar 107 | *.ym 108 | *.model 109 | 110 | 111 | # Chutzpah Test files 112 | _Chutzpah* 113 | 114 | # Visual C++ cache files 115 | ipch/ 116 | *.aps 117 | *.ncb 118 | *.opendb 119 | *.opensdf 120 | *.sdf 121 | *.cachefile 122 | *.VC.db 123 | 124 | # Visual Studio profiler 125 | *.psess 126 | *.vsp 127 | *.vspx 128 | *.sap 129 | 130 | # TFS 2012 Local Workspace 131 | $tf/ 132 | 133 | # Guidance Automation Toolkit 134 | *.gpState 135 | 136 | # ReSharper is a .NET coding add-in 137 | _ReSharper*/ 138 | *.[Rr]e[Ss]harper 139 | *.DotSettings.user 140 | 141 | # JustCode is a .NET coding add-in 142 | .JustCode 143 | 144 | # TeamCity is a build add-in 145 | _TeamCity* 146 | 147 | # DotCover is a Code Coverage Tool 148 | *.dotCover 149 | 150 | # NCrunch 151 | _NCrunch_* 152 | .*crunch*.local.xml 153 | nCrunchTemp_* 154 | 155 | # MightyMoose 156 | *.mm.* 157 | AutoTest.Net/ 158 | 159 | # Web workbench (sass) 160 | 
.sass-cache/ 161 | 162 | # Installshield output folder 163 | [Ee]xpress/ 164 | 165 | # DocProject is a documentation generator add-in 166 | DocProject/buildhelp/ 167 | DocProject/Help/*.HxT 168 | DocProject/Help/*.HxC 169 | DocProject/Help/*.hhc 170 | DocProject/Help/*.hhk 171 | DocProject/Help/*.hhp 172 | DocProject/Help/Html2 173 | DocProject/Help/html 174 | 175 | # Click-Once directory 176 | publish/ 177 | 178 | # Publish Web Output 179 | *.[Pp]ublish.xml 180 | *.azurePubxml 181 | 182 | # TODO: Un-comment the next line if you do not want to checkin 183 | # your web deploy settings because they may include unencrypted 184 | # passwords 185 | #*.pubxml 186 | *.publishproj 187 | 188 | # NuGet Packages 189 | *.nupkg 190 | # The packages folder can be ignored because of Package Restore 191 | **/packages/* 192 | # except build/, which is used as an MSBuild target. 193 | !**/packages/build/ 194 | # Uncomment if necessary however generally it will be regenerated when needed 195 | #!**/packages/repositories.config 196 | # NuGet v3's project.json files produces more ignoreable files 197 | *.nuget.props 198 | *.nuget.targets 199 | 200 | # Microsoft Azure Build Output 201 | csx/ 202 | *.build.csdef 203 | 204 | # Microsoft Azure Emulator 205 | ecf/ 206 | rcf/ 207 | 208 | # Windows Store app package directory 209 | AppPackages/ 210 | BundleArtifacts/ 211 | 212 | # Visual Studio cache files 213 | # files ending in .cache can be ignored 214 | *.[Cc]ache 215 | # but keep track of directories ending in .cache 216 | !*.[Cc]ache/ 217 | 218 | # Others 219 | ClientBin/ 220 | [Ss]tyle[Cc]op.* 221 | ~$* 222 | *~ 223 | *.dbmdl 224 | *.dbproj.schemaview 225 | *.pfx 226 | *.publishsettings 227 | node_modules/ 228 | orleans.codegen.cs 229 | 230 | # RIA/Silverlight projects 231 | Generated_Code/ 232 | 233 | # Backup & report files from converting an old project file 234 | # to a newer Visual Studio version. 
# Backup files are not needed,
# because we have git ;-)
_UpgradeReport_Files/
Backup*/
UpgradeLog*.XML
UpgradeLog*.htm

# SQL Server files
*.mdf
*.ldf

# Business Intelligence projects
*.rdl.data
*.bim.layout
*.bim_*.settings

# Microsoft Fakes
FakesAssemblies/

# GhostDoc plugin setting file
*.GhostDoc.xml

# Node.js Tools for Visual Studio
.ntvs_analysis.dat

# Visual Studio 6 build log
*.plg

# Visual Studio 6 workspace options file
*.opt

# Visual Studio LightSwitch build output
**/*.HTMLClient/GeneratedArtifacts
**/*.DesktopClient/GeneratedArtifacts
**/*.DesktopClient/ModelManifest.xml
**/*.Server/GeneratedArtifacts
**/*.Server/ModelManifest.xml
_Pvt_Extensions

# LightSwitch generated files
GeneratedArtifacts/
ModelManifest.xml

# Paket dependency manager
.paket/paket.exe

# FAKE - F# Make
.fake/
*.lai
*.la
*.a
*.lib
*.zip
*.info
*.dll
*.so
*.dylib
*.mA_bin
*.dat
*.avi
*.ogv
*.asv
*.code
/tests/python_package_test/.pytest_cache/v/cache
/tests/python_package_test/categorical.model
/python-package/geo_test.py
*.csv
/python-package/.pytest_cache/v/cache/lastfailed
/python-package/.pytest_cache/v/cache/nodeids
/python-package/case_qq2019.py
*.txt
*.jpg
/src/learn/discpy.py
/src/learn/sparsipy.py
/doc/Gradient boosting on adpative distrubutions.docx
/python-package/litemort/桌面.lnk
/python-package/LiteMORT_hyppo.py
/python-package/shap_test.py
*.pickle
*.gz
logger.py
--------------------------------------------------------------------------------
/.idea/ONNet.iml:
--------------------------------------------------------------------------------
(XML project file; contents not preserved in this dump)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(XML project file; contents not preserved in this dump)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(XML project file; contents not preserved in this dump)
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
(XML project file; contents not preserved in this dump)
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(XML project file; contents not preserved in this dump)
--------------------------------------------------------------------------------
/ONNet_wavelet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/closest-git/ONNet/79dacffe164369e564650f65b3e9e857668b63bc/ONNet_wavelet.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ONNet

**ONNet** is an open-source Python/C++ package that provides many tools for researchers studying optical neural networks. Some of its new models are listed below:

- #### Express Wavenet

  ![](./ONNet_wavelet.png)

  Express Wavenet uses a random-shift wavelet pattern to modulate the phase of optical waves, and needs only about one percent of the parameters while accuracy stays high. On the MNIST dataset it needs only 1229 parameters to reach 92% accuracy, while DDNet needs 125440 parameters. [3]

- #### Diffractive deep neural network with multiple frequency-channels

  Each layer has multiple frequency-channels (optical distributions at different frequencies). These channels are merged at the output plane with weighting coefficients. [2]

- #### Diffractive network with multiple binary output planes

An optical neural network (ONN) is a novel machine learning framework built on the physical principles of optics; it is still in its infancy and shows great potential. An ONN tries to find optimal modulation parameters that change the phase, amplitude, or other physical variables of optical wave propagation, so that the optical distribution at the final output plane forms a special pattern that indicates the object's class or value. ONN opens new doors for machine learning.
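To make the modulation idea concrete, here is a minimal, hypothetical sketch of a single diffractive layer in PyTorch. It is *not* the actual `DiffractiveLayer` in `onnet`; the layer size, wavelength, pixel pitch, and propagation distance below are illustrative assumptions. A learnable phase mask modulates the incoming field, and angular-spectrum propagation carries it to the next plane:

```python
import math
import torch
import torch.nn as nn

class ToyDiffractiveLayer(nn.Module):
    """One diffractive layer: learnable phase mask + angular-spectrum propagation."""
    def __init__(self, size=112, wavelength=5e-7, pixel=4e-7, distance=3e-5):
        super().__init__()
        # The phase mask holds the modulation parameters the network optimizes.
        self.phase = nn.Parameter(2 * math.pi * torch.rand(size, size))
        fx = torch.fft.fftfreq(size, d=pixel)                    # spatial frequencies
        fyy, fxx = torch.meshgrid(fx, fx, indexing="ij")
        kz2 = (1.0 / wavelength) ** 2 - fxx ** 2 - fyy ** 2      # evanescent waves clamped to 0
        kz = 2 * math.pi * torch.sqrt(torch.clamp(kz2, min=0.0))
        self.register_buffer("H", torch.exp(1j * kz * distance)) # free-space transfer function

    def forward(self, field):                                    # field: complex (B, size, size)
        field = field * torch.exp(1j * self.phase)               # phase-only modulation
        return torch.fft.ifft2(torch.fft.fft2(field) * self.H)   # propagate to the next plane

x = torch.ones(1, 112, 112, dtype=torch.complex64)               # toy input wavefront
out = ToyDiffractiveLayer()(x)
intensity = out.abs() ** 2   # a detector reads the intensity pattern at the output plane
```

Stacking several such layers and summing the intensity over a set of detector regions yields the class-indicator pattern described above; the multi-frequency variant runs several such stacks at different wavelengths and merges their output intensities with learned weighting coefficients.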
", 24 | "if from some reason the data directory structure is wrong please remove the data dir and rerun this script") 25 | return 26 | filename = "all_data.zip" 27 | download_url(url, data_path, filename) 28 | unzip_all_files(data_path) 29 | _arrange_brain_tumor_data(data_path) 30 | 31 | def convert_landmark_to_bounding_box(landmark): 32 | x_min = x_max = y_min = y_max = None 33 | for x, y in landmark: 34 | if x_min is None: 35 | x_min = x_max = x 36 | y_min = y_max = y 37 | else: 38 | x_min, x_max = min(x, x_min), max(x, x_max) 39 | y_min, y_max = min(y, y_min), max(y, y_max) 40 | return [int(x_min), int(x_max), int(y_min), int(y_max)] 41 | 42 | class ClassesLabels(Enum): 43 | Meningioma = 1 44 | Glioma = 2 45 | Pituitary = 3 46 | 47 | def __len__(self): 48 | return 3 49 | 50 | def normalize(x, mean=470, std=None): 51 | mean_tansor = torch.ones_like(x) * mean 52 | x -= mean_tansor 53 | if std: 54 | x /= std 55 | return x 56 | 57 | # https://github.com/galprz/brain-tumor-segmentation 58 | class BrainTumorDataset(Dataset): 59 | def __init__(self,config, root, train=True, download=True, 60 | classes=(ClassesLabels.Meningioma, 61 | ClassesLabels.Glioma, 62 | ClassesLabels.Pituitary)): 63 | super().__init__() 64 | self.config = config 65 | test_fr = 0.15 66 | if download: 67 | get_data_if_needed(root) 68 | self.root = root 69 | # List all data files 70 | items = [] 71 | if ClassesLabels.Meningioma in classes: 72 | items += ['meningioma/' + item for item in os.listdir(root + 'meningioma/')] 73 | if ClassesLabels.Glioma in classes: 74 | items += ['glioma/' + item for item in os.listdir(root + 'glioma/')] 75 | if ClassesLabels.Meningioma in classes: 76 | items += ['pituitary/' + item for item in os.listdir(root + 'pituitary/')] 77 | 78 | if train: 79 | self.items = items[0:math.floor((1-test_fr) * len(items)) + 1] 80 | else: 81 | self.items = items[math.floor((1-test_fr) * len(items)) + 1:] 82 | 83 | def __len__(self): 84 | return len(self.items) 85 | 86 | def __getitem__(self, idx): 87 | if not (0 <= idx < len(self.items)): 88 | raise IndexError("Idx out of bound") 89 | if False: 90 | data = hdf5storage.loadmat(self.root + self.items[idx])['cjdata'][0] 91 | # transform the tumor border to array of (x, y) tuple 92 | xy = data[3] 93 | landmarks = [] 94 | for i in range(0, len(xy), 2): 95 | x = xy[i][0] 96 | y = xy[i + 1][0] 97 | landmarks.append((x, y)) 98 | mask = data[4] 99 | data[2].dtype = 'uint16' 100 | image = data[2] #ToPILImage()(data[2]) 101 | image_with_metadata = { 102 | "label": int(data[0][0]), 103 | "image": image, 104 | "landmarks": landmarks, 105 | "mask": mask, 106 | "bounding_box": convert_landmark_to_bounding_box(landmarks) 107 | } 108 | return image_with_metadata 109 | else: 110 | return load_mat_trans(self.root + self.items[idx],target_size=self.config.IMG_size ) #(128,128) 111 | 112 | def ToUint8(arr): 113 | a_0,a_1 = np.min(arr),np.max(arr) 114 | arr = (arr-a_0)/(a_1-a_0)*255 115 | arr = arr.astype(np.uint8) 116 | a_0,a_1 = np.min(arr),np.max(arr) 117 | return arr 118 | 119 | def load_mat_trans(path,target_size=None): 120 | data_mat = hdf5storage.loadmat(path) 121 | data = data_mat['cjdata'][0] 122 | xy = data[3] 123 | landmarks = [] 124 | for i in range(0, len(xy), 2): 125 | x = xy[i][0] 126 | y = xy[i + 1][0] 127 | landmarks.append((x, y)) 128 | mask = data[4].astype(np.float32) 129 | m_0,m_1 = np.min(mask),np.max(mask) 130 | #data[2].dtype = 'uint16' 131 | image = data[2].astype(np.float32) #ToPILImage()(data[2]) 132 | if target_size is not None: 133 | image = 
cv2.resize(image,target_size) 134 | #cv2.imshow("",image); cv2.waitKey(0) 135 | mask = cv2.resize(mask,target_size) 136 | #cv2.imshow("",mask*255); cv2.waitKey(0) 137 | image = ToUint8(image) 138 | mask = ToUint8(mask) 139 | image_with_metadata = { 140 | "label": int(data[0][0]), 141 | "image": image, 142 | "landmarks": landmarks, 143 | "mask": mask, 144 | "bounding_box": convert_landmark_to_bounding_box(landmarks) 145 | } 146 | return image_with_metadata 147 | 148 | mask_transformer = transforms.Compose([ 149 | transforms.ToTensor(), 150 | ]) 151 | 152 | image_transformer_0 = transforms.Compose([ 153 | transforms.ToTensor(), 154 | transforms.Lambda(lambda x: normalize(x)) 155 | ]) 156 | image_transformer = transforms.Compose([ 157 | transforms.ToTensor(), 158 | ]) 159 | 160 | class BrainTumorDatasetMask(BrainTumorDataset): 161 | def transform(self,image, mask): 162 | img = image_transformer(image).float() 163 | mask = mask_transformer(mask).float() 164 | return img,mask 165 | 166 | def __init__(self,config, root, train=True, transform=None, classes=(ClassesLabels.Meningioma, 167 | ClassesLabels.Glioma, 168 | ClassesLabels.Pituitary)): 169 | super().__init__(config,root, train, classes=classes) 170 | #self.transform = brain_transform 171 | 172 | def __getitem__(self, idx): 173 | item = super().__getitem__(idx) 174 | sample = (item["image"], item["mask"]) 175 | #return sample if self.transform is None else self.transform(*sample) 176 | img,mask = self.transform(item["image"], item["mask"]) 177 | #i_0,i_1 = torch.min(img),torch.max(img) 178 | #m_0,m_1 = torch.min(mask),torch.max(mask) 179 | return img,mask 180 | 181 | def _arrange_brain_tumor_data(root): 182 | # Remove and split files 183 | items = [item for item in filter(lambda item: re.search("^[0-9]+\.mat$", item), os.listdir(root))] 184 | try: 185 | os.mkdir(root + 'meningioma/') 186 | except: 187 | print("Meningioma directory already exists") 188 | try: 189 | os.mkdir(root + 'glioma/') 190 | except: 191 | print("Glioma directory already exists") 192 | try: 193 | os.mkdir(root + 'pituitary/') 194 | except: 195 | print("Pituitary directory already exists") 196 | 197 | for item in items: 198 | sample = hdf5storage.loadmat(root + item)['cjdata'][0] 199 | if sample[2].shape[0] == 512: 200 | if sample[0] == 1: 201 | os.rename(root + item, root + 'meningioma/' + item) 202 | if sample[0] == 2: 203 | os.rename(root + item, root + 'glioma/' + item) 204 | if sample[0] == 3: 205 | os.rename(root + item, root + 'pituitary/' + item) 206 | else: 207 | os.remove(root + item) 208 | 209 | -------------------------------------------------------------------------------- /case_cifar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Train CIFAR10 with PyTorch. 
https://github.com/kuangliu/pytorch-cifar

https://medium.com/@wwwbbb8510/lessons-learned-from-reproducing-resnet-and-densenet-on-cifar-10-dataset-6e25b03328da
'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import sys
import argparse
CNN_MODEL_root = os.path.dirname(os.path.abspath(__file__)) + "/python-package"
sys.path.append(CNN_MODEL_root)
from cnn_models import *
ONNET_DIR = os.path.abspath("./python-package/")
sys.path.append(ONNET_DIR)  # To find the local version of onnet
from onnet import *
from onnet.OpticalFormer import clip_grad
import time
import torch.nn.init as init


# The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class.
# There are 50000 training images and 10000 test images. The dataset is divided into five training
# batches and one test batch, each with 10000 images.
IMG_size = (32, 32)
IMG_size = (96, 96)
isDNet = False
isGrayScale = False

def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)

#_, term_width = os.popen('stty size', 'r').read().split()
term_width = 80
TOTAL_BAR_LENGTH = 25.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    if current < total-1:
        sys.stdout.write('\r')
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for a new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    if False:
        # Go back to the center of the bar.
106 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 107 | sys.stdout.write('\b') 108 | sys.stdout.write(' %d/%d ' % (current+1, total)) 109 | 110 | if current < total-1: 111 | pass#sys.stdout.write('\r') 112 | else: 113 | sys.stdout.write('\n') 114 | sys.stdout.flush() 115 | 116 | def format_time(seconds): 117 | days = int(seconds / 3600/24) 118 | seconds = seconds - days*3600*24 119 | hours = int(seconds / 3600) 120 | seconds = seconds - hours*3600 121 | minutes = int(seconds / 60) 122 | seconds = seconds - minutes*60 123 | secondsf = int(seconds) 124 | seconds = seconds - secondsf 125 | millis = int(seconds*1000) 126 | 127 | f = '' 128 | i = 1 129 | if days > 0: 130 | f += str(days) + 'D' 131 | i += 1 132 | if hours > 0 and i <= 2: 133 | f += str(hours) + 'h' 134 | i += 1 135 | if minutes > 0 and i <= 2: 136 | f += str(minutes) + 'm' 137 | i += 1 138 | if secondsf > 0 and i <= 2: 139 | f += str(secondsf) + 's' 140 | i += 1 141 | if millis > 0 and i <= 2: 142 | f += str(millis) + 'ms' 143 | i += 1 144 | if f == '': 145 | f = '0ms' 146 | return f 147 | 148 | 149 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 150 | parser.add_argument('--lr', default=0.01, type=float, help='learning rate') 151 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 152 | # "--gradient_clip=agc", 153 | # "--self_attention=gabor" 154 | args = parser.parse_args() 155 | 156 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 157 | best_acc = 0 # best test accuracy 158 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 159 | 160 | # Data 161 | def Init(): 162 | print('==> Preparing data..') 163 | transform_train = transforms.Compose([ 164 | transforms.RandomCrop(32, padding=4), 165 | # transforms.Grayscale(), 166 | transforms.RandomHorizontalFlip(), 167 | transforms.Resize(IMG_size), 168 | transforms.ToTensor(), 169 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 170 | # transforms.Normalize(0.48, 0.20), 171 | ]) 172 | 173 | transform_test = transforms.Compose([ 174 | # transforms.Grayscale(), 175 | transforms.Resize(IMG_size), 176 | transforms.ToTensor(), 177 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 178 | # transforms.Normalize(0.48, 0.20), 179 | ]) 180 | 181 | trainset = torchvision.datasets.CIFAR10(root='/home/cys/Downloads/cifar10/', train=True, download=True, transform=transform_train) 182 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 183 | 184 | testset = torchvision.datasets.CIFAR10(root='/home/cys/Downloads/cifar10/', train=False, download=True, transform=transform_test) 185 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 186 | 187 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 188 | # Model 189 | print('==> Building model..') 190 | if isDNet: 191 | #config_0 = RGBO_CNN_config("RGBO_CNN", 'cifar_10', IMG_size, lr_base=args.lr, batch_size=128, nClass=10, nLayer=5) 192 | #env_title, net = RGBO_CNN_instance(config_0) 193 | config_0 = NET_config("DNet",'cifar_10',IMG_size,lr_base=args.lr,batch_size=128, nClass=10, nLayer=10) 194 | env_title, net = DNet_instance(config_0) 195 | config_base = net.config 196 | else: 197 | config_0 = NET_config("OptFormer", 'cifar_10', IMG_size, lr_base=args.lr, batch_size=128, nClass=10) 198 | # net = VGG('VGG19') 199 | #net = ResNet34(); env_title='ResNet34'; net.legend = 
'ResNet34'
        # net = OpticalNet34(config_0); env_title = 'OpticalNet34'; net.legend = 'OpticalNet34'
        env_title, net = DNet_instance(config_0)
        # net = PreActResNet18()
        # net = GoogLeNet()
        # net = DenseNet121()
        # net = ResNeXt29_2x64d()
        # net = MobileNet()
        # net = MobileNetV2()
        # net = DPN92(); env_title='DPN92'; net.legend = 'DPN92'
        # net = DPN26(); env_title = 'DPN92'; net.legend = 'DPN92'
        # net = ShuffleNetG2()
        # net = SENet18()
        # net = ShuffleNetV2(1)
        # net = EfficientNetB0(); env_title='EfficientNetB0'
        #visual = Visdom_Visualizer(env_title=env_title)

    print(net)
    Net_dump(net)
    net = net.to(device)
    visual = Visdom_Visualizer(env_title=env_title)
    #if hasattr(net, 'DInput'): net.DInput.visual = visual  # take a look at the input

    if device == 'cuda':
        pass
        #net = torch.nn.DataParallel(net)  # https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
        #cudnn.benchmark = True  # results will fluctuate (nondeterministic), see https://zhuanlan.zhihu.com/p/73711222

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/ckpt.pth')
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    criterion = nn.CrossEntropyLoss()
    # using SGD with a scheduled learning rate works much better than Adam
    optimizer = optim.Adam(net.parameters(), lr=args.lr)  # weight_decay=0.0005
    #optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    return net, trainloader, testloader, optimizer, criterion, visual

# Training
def train(epoch, net, trainloader, optimizer, criterion):
    print('\nEpoch: %d' % epoch)
    if epoch == 0:
        #print(f"\n=======dataset={dataset} net={net_type} IMG_size={IMG_size} batch_size={batch_size}")
        #print(f"======={net.config}")
        print(f"======={optimizer}")
        #print(f"======={train_trans}\n")
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()  #retain_graph=True
        if net.clip_grad == "agc":
            clip_grad(net)
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        #break


def test(epoch, net, testloader, criterion, visual):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100.
* correct / total, correct, total)) 292 | #break 293 | 294 | # Save checkpoint. 295 | acc = 100.*correct/total 296 | legend = "resnet"#net.module.legend() 297 | visual.UpdateLoss(title=f"Accuracy on \"cifar_10\"", legend=f"{legend}", loss=acc, yLabel="Accuracy") 298 | if False and acc > best_acc: 299 | print('Saving..') 300 | state = { 301 | 'net': net.state_dict(), 302 | 'acc': acc, 303 | 'epoch': epoch, 304 | } 305 | if not os.path.isdir('checkpoint'): 306 | os.mkdir('checkpoint') 307 | torch.save(state, './checkpoint/ckpt.pth') 308 | best_acc = acc 309 | 310 | if __name__ == '__main__': 311 | seed_everything(42) 312 | net,trainloader,testloader,optimizer,criterion,visual = Init() 313 | #legend = net.module.legend() 314 | 315 | for epoch in range(start_epoch, start_epoch+2000): 316 | train(epoch,net,trainloader,optimizer,criterion) 317 | test(epoch,net,testloader,criterion,visual) 318 | -------------------------------------------------------------------------------- /case_covir.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Author: Yingshi Chen 3 | https://github.com/lindawangg/COVID-Net/blob/master/create_COVIDx_v2.ipynb 4 | @Date: 2020-04-06 15:50:21 5 | @ 6 | # Description: 7 | ''' 8 | import numpy as np 9 | import pandas as pd 10 | import os 11 | import random 12 | from shutil import copyfile 13 | import pydicom as dicom 14 | import cv2 15 | from torch.utils.data import Dataset,DataLoader 16 | from torch.optim.lr_scheduler import ReduceLROnPlateau 17 | from torch.nn import CrossEntropyLoss 18 | from PIL import Image 19 | import logging 20 | import sys 21 | import time 22 | ONNET_DIR = os.path.abspath("./python-package/") 23 | sys.path.append(ONNET_DIR) # To find local version of the onnet 24 | #sys.path.append(os.path.abspath("./python-package/cnn_models/")) 25 | from cnn_models.COVIDNext50 import COVIDNext50 26 | from onnet import * 27 | import torch 28 | from torch.optim import Adam 29 | from torchvision import transforms 30 | from sklearn.metrics import f1_score, precision_score, recall_score,accuracy_score,classification_report 31 | 32 | isONN=True 33 | class COVID_set(Dataset): 34 | def __init__(self, config,img_dir, labels_file, transforms): 35 | self.config = config 36 | self.img_pths, self.labels = self._prepare_data(img_dir, labels_file) 37 | self.transforms = transforms 38 | 39 | 40 | def _prepare_data(self, img_dir, labels_file): 41 | with open(labels_file, 'r') as f: 42 | labels_raw = f.readlines() 43 | 44 | labels, img_pths = [], [] 45 | for i in range(len(labels_raw)): 46 | data = labels_raw[i].split() 47 | img_pth = data[1] 48 | #img_name = data[1] 49 | #img_pth = os.path.join(img_dir, img_name) 50 | img_pths.append(img_pth) 51 | labels.append(self.config.mapping[data[2]]) 52 | 53 | return img_pths, labels 54 | 55 | def __len__(self): 56 | return len(self.labels) 57 | 58 | def __getitem__(self, idx): 59 | img = Image.open(self.img_pths[idx]).convert("RGB") 60 | img_tensor = self.transforms(img) 61 | 62 | label = self.labels[idx] 63 | label_tensor = torch.tensor(label, dtype=torch.long) 64 | 65 | return img_tensor, label_tensor 66 | 67 | def train_test_split(): 68 | seed = 0 69 | np.random.seed(seed) # Reset the seed so all runs are the same. 
    random.seed(seed)
    MAXVAL = 255  # Range [0 255]

    # path to covid-19 dataset from https://github.com/ieee8023/covid-chestxray-dataset
    imgpath = 'E:/Insegment/covid-chestxray-dataset-master/images'
    csvpath = 'E:/Insegment/covid-chestxray-dataset-master/metadata.csv'

    # path to https://www.kaggle.com/c/rsna-pneumonia-detection-challenge
    kaggle_datapath = 'F:/Datasets/rsna-pneumonia-detection-challenge/'
    kaggle_csvname = 'stage_2_detailed_class_info.csv'  # get all the normals from here
    kaggle_csvname2 = 'stage_2_train_labels.csv'  # get all the 1s from here, since 1 indicates pneumonia
    kaggle_imgpath = 'stage_2_train_images'

    # parameters for COVIDx dataset
    train = []
    test = []
    test_count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}
    train_count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}

    mapping = dict()
    mapping['COVID-19'] = 'COVID-19'
    mapping['SARS'] = 'pneumonia'
    mapping['MERS'] = 'pneumonia'
    mapping['Streptococcus'] = 'pneumonia'
    mapping['Normal'] = 'normal'
    mapping['Lung Opacity'] = 'pneumonia'
    mapping['1'] = 'pneumonia'

    train_file = open("train_split_v2.txt", "a")
    test_file = open("test_split_v2.txt", "a")
    # train/test split
    split = 0.1
    csv = pd.read_csv(csvpath, nrows=None)
    idx_pa = csv["view"] == "PA"  # Keep only the PA view
    csv = csv[idx_pa]

    pneumonias = ["COVID-19", "SARS", "MERS", "ARDS", "Streptococcus"]
    pathologies = ["Pneumonia", "Viral Pneumonia", "Bacterial Pneumonia", "No Finding"] + pneumonias
    pathologies = sorted(pathologies)

    filename_label = {'normal': [], 'pneumonia': [], 'COVID-19': []}
    count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}
    for index, row in csv.iterrows():
        f = row['finding']
        if f in mapping:
            count[mapping[f]] += 1
            entry = [int(row['patientid']), row['filename'], mapping[f]]
            filename_label[mapping[f]].append(entry)

    print('Data distribution from covid-chestxray-dataset:')
    print(count)

    for key in filename_label.keys():
        arr = np.array(filename_label[key])
        if arr.size == 0:
            continue
        # split by patients
        # num_diff_patients = len(np.unique(arr[:,0]))
        # num_test = max(1, round(split*num_diff_patients))
        # select num_test number of random patients
        if key == 'pneumonia':
            test_patients = ['8', '31']
        elif key == 'COVID-19':
            test_patients = ['19', '20', '36', '42', '86']  # random.sample(list(arr[:,0]), num_test)
        else:
            test_patients = []
        print('Key: ', key)
        print('Test patients: ', test_patients)
        # go through all the patients
        for patient in arr:
            info = f"{str(patient[0])} {imgpath}\\{patient[1]} {patient[2]}\n"
            if patient[0] in test_patients:
                #copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'test', patient[1]))
                test.append(patient); test_count[patient[2]] += 1
                test_file.write(info)   # record in the test split file
            else:
                #copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'train', patient[1]))
                train.append(patient); train_count[patient[2]] += 1
                train_file.write(info)  # record in the train split file


    csv_normal = pd.read_csv(os.path.join(kaggle_datapath, kaggle_csvname), nrows=None)
    csv_pneu = pd.read_csv(os.path.join(kaggle_datapath, kaggle_csvname2), nrows=None)
    patients = {'normal': [], 'pneumonia': []}

    for index, row in
csv_normal.iterrows(): 156 | if row['class'] == 'Normal': 157 | patients['normal'].append(row['patientId']) 158 | 159 | for index, row in csv_pneu.iterrows(): 160 | if int(row['Target']) == 1: 161 | patients['pneumonia'].append(row['patientId']) 162 | 163 | for key in patients.keys(): 164 | arr = np.array(patients[key]) 165 | if arr.size == 0: 166 | continue 167 | # split by patients 168 | num_diff_patients = len(np.unique(arr)) 169 | num_test = max(1, round(split*num_diff_patients)) 170 | #test_patients = np.load('rsna_test_patients_{}.npy'.format(key)) # 171 | test_patients = random.sample(list(arr), num_test) #, download the .npy files from the repo. 172 | np.save('rsna_test_patients_{}.npy'.format(key), np.array(test_patients)) 173 | for patient in arr: 174 | ds = dicom.dcmread(os.path.join(kaggle_datapath, kaggle_imgpath, patient + '.dcm')) 175 | pixel_array_numpy = ds.pixel_array 176 | imgname = patient + '.png' 177 | if patient in test_patients: 178 | path = os.path.join(kaggle_datapath, 'test', imgname) 179 | cv2.imwrite(path, pixel_array_numpy) 180 | test.append([patient, imgname, key]); test_count[key] += 1 181 | test_file.write(f"{patient} {path} {key}\n" ) 182 | if test_count[key]%50==0: 183 | test_file.flush() 184 | else: 185 | path = os.path.join(kaggle_datapath, 'train', imgname) 186 | cv2.imwrite(path, pixel_array_numpy) 187 | train_file.write(f"{patient} {path} {key}\n") 188 | if train_count[key]%20==0: 189 | train_file.flush() 190 | train.append([patient, imgname, key]); train_count[key] += 1 191 | print(f"\r@{path}",end="") 192 | 193 | print('Final stats') 194 | print('Train count: ', train_count) 195 | print('Test count: ', test_count) 196 | print('Total length of train: ', len(train)) 197 | print('Total length of test: ', len(test)) 198 | 199 | train_file.close() 200 | test_file.close() 201 | 202 | 203 | 204 | log = logging.getLogger(__name__) 205 | logging.basicConfig(level=logging.INFO) 206 | 207 | 208 | def save_model(model, config): 209 | if isinstance(model, torch.nn.DataParallel): 210 | # Save without the DataParallel module 211 | model_dict = model.module.state_dict() 212 | else: 213 | model_dict = model.state_dict() 214 | 215 | state = { 216 | "state_dict": model_dict, 217 | "global_step": config['global_step'], 218 | "clf_report": config['clf_report'] 219 | } 220 | f1_macro = config['clf_report']['macro avg']['f1-score'] * 100 221 | name = "{}_F1_{:.2f}_step_{}.pth".format(config['name'], 222 | f1_macro, 223 | config['global_step']) 224 | model_path = os.path.join(config['save_dir'], name) 225 | torch.save(state, model_path) 226 | log.info("Saved model to {}".format(model_path)) 227 | 228 | 229 | def validate(data_loader, model, best_score, global_step, cfg): 230 | model.eval() 231 | gts, predictions = [], [] 232 | 233 | log.info("Validation started...") 234 | for data in data_loader: 235 | imgs, labels = data 236 | imgs = to_device(imgs, gpu=cfg.gpu) 237 | 238 | with torch.no_grad(): 239 | logits = model(imgs) 240 | if isONN: 241 | preds = net.predict(logits).cpu().numpy() 242 | else: 243 | probs = model.module.probability(logits) 244 | preds = torch.argmax(probs, dim=1).cpu().numpy() 245 | 246 | labels = labels.cpu().detach().numpy() 247 | predictions.extend(preds) 248 | gts.extend(labels) 249 | 250 | predictions = np.array(predictions, dtype=np.int32) 251 | gts = np.array(gts, dtype=np.int32) 252 | acc, f1, prec, rec = clf_metrics(predictions=predictions,targets=gts,average="macro") 253 | report = classification_report(gts, predictions, output_dict=True) 
254 | log.info("\n====== VALIDATION | Accuracy {:.4f} | F1 {:.4f} | Precision {:.4f} | Recall {:.4f}".format(acc, f1, prec, rec)) 255 | 256 | if f1 > best_score: 257 | save_config = { 258 | 'name': config.name, 259 | 'save_dir': config.ckpts_dir, 260 | 'global_step': global_step, 261 | 'clf_report': report 262 | } 263 | #save_model(model=model, config=save_config) 264 | best_score = f1 265 | #log.info("Validation end") 266 | model.train() 267 | return best_score 268 | 269 | def train_transforms(width, height): 270 | trans_list = [ 271 | transforms.Resize((height, width)), 272 | transforms.RandomVerticalFlip(p=0.5), 273 | transforms.RandomHorizontalFlip(p=0.5), 274 | transforms.RandomApply([ 275 | transforms.RandomAffine(degrees=20, 276 | translate=(0.15, 0.15), 277 | scale=(0.8, 1.2), 278 | shear=5)], p=0.5), 279 | transforms.RandomApply([ 280 | transforms.ColorJitter(brightness=0.3, contrast=0.3)], p=0.5), 281 | transforms.Grayscale(), 282 | transforms.ToTensor() 283 | ] 284 | return transforms.Compose(trans_list) 285 | 286 | 287 | def val_transforms(width, height): 288 | trans_list = [ 289 | transforms.Resize((height, width)), 290 | transforms.Grayscale(), 291 | transforms.ToTensor() 292 | ] 293 | return transforms.Compose(trans_list) 294 | 295 | def to_device(tensor, gpu=False): 296 | return tensor.cuda() if gpu else tensor.cpu() 297 | 298 | def clf_metrics(predictions, targets, average='macro'): 299 | f1 = f1_score(targets, predictions, average=average) 300 | precision = precision_score(targets, predictions, average=average) 301 | recall = recall_score(targets, predictions, average=average) 302 | acc = accuracy_score(targets, predictions) 303 | 304 | return acc, f1, precision, recall 305 | 306 | def main(model): 307 | if config.gpu and not torch.cuda.is_available(): 308 | raise ValueError("GPU not supported or enabled on this system.") 309 | use_gpu = config.gpu 310 | 311 | log.info("Loading train dataset") 312 | train_dataset = COVID_set(config,config.train_imgs, config.train_labels,train_transforms(config.width,config.height)) 313 | train_loader = DataLoader(train_dataset, 314 | batch_size=config.batch_size,shuffle=True,drop_last=True, num_workers=config.n_threads,pin_memory=use_gpu) 315 | log.info("Number of training examples {}".format(len(train_dataset))) 316 | 317 | log.info("Loading val dataset") 318 | val_dataset = COVID_set(config,config.val_imgs, config.val_labels,val_transforms(config.width,config.height)) 319 | val_loader = DataLoader(val_dataset, 320 | batch_size=config.batch_size, 321 | shuffle=False, 322 | num_workers=config.n_threads, 323 | pin_memory=use_gpu) 324 | log.info("Number of validation examples {}".format(len(val_dataset))) 325 | 326 | if use_gpu: 327 | model.cuda() 328 | #model = torch.nn.DataParallel(model) 329 | optim_layers = filter(lambda p: p.requires_grad, model.parameters()) 330 | 331 | # optimizer and lr scheduler 332 | optimizer = Adam(optim_layers, 333 | lr=config.lr, 334 | weight_decay=config.weight_decay) 335 | scheduler = ReduceLROnPlateau(optimizer=optimizer, 336 | factor=config.lr_reduce_factor, 337 | patience=config.lr_reduce_patience, 338 | mode='max', 339 | min_lr=1e-7) 340 | 341 | # Load the last global_step from the checkpoint if existing 342 | global_step = 0 if state is None else state['global_step'] + 1 343 | 344 | class_weights = to_device(torch.FloatTensor(config.loss_weights),gpu=use_gpu) 345 | loss_fn = CrossEntropyLoss(reduction='mean', weight=class_weights) 346 | 347 | # Reset the best metric score 348 | best_score = -1 349 | 
t0=time.time() 350 | for epoch in range(config.epochs): 351 | log.info("\nStarted epoch {}/{}".format(epoch + 1,config.epochs)) 352 | for data in train_loader: 353 | imgs, labels = data 354 | imgs = to_device(imgs, gpu=use_gpu) 355 | labels = to_device(labels, gpu=use_gpu) 356 | 357 | logits = model(imgs) 358 | loss = loss_fn(logits, labels) 359 | optimizer.zero_grad() 360 | loss.backward() 361 | optimizer.step() 362 | 363 | if global_step % config.log_steps == 0 and global_step > 0: 364 | if isONN: 365 | preds = net.predict(logits).cpu().numpy() 366 | else: 367 | probs = model.module.probability(logits) 368 | preds = torch.argmax(probs, dim=1).detach().cpu().numpy() 369 | labels = labels.cpu().detach().numpy() 370 | acc, f1, _, _ = clf_metrics(preds, labels) 371 | lr = optimizer.param_groups[0]['lr'] #get_learning_rate(optimizer) 372 | print(f"\r{global_step} | batch: Loss={loss.item():.3f} | F1={f1:.3f} | Accuracy={acc:.4f} | LR={lr:.2e}\tT={time.time()-t0:.4f}",end="") 373 | 374 | 375 | if global_step % config.eval_steps == 0 and global_step > 0: 376 | best_score = validate(val_loader, model,best_score=best_score,global_step=global_step,cfg=config) 377 | scheduler.step(best_score) 378 | global_step += 1 379 | 380 | def UpdateConfig(config): 381 | config.name = "COVIDNext50_NewData" 382 | config.gpu = True 383 | config.batch_size = 16 384 | config.n_threads = 4 385 | config.random_seed = 1337 386 | config.weights = "E:/Insegment/COVID-Next-Pytorch-master/COVIDNext50_NewData_F1_92.98_step_10800.pth" 387 | config.lr = 1e-4 388 | config.weight_decay = 1e-3 389 | config.lr_reduce_factor = 0.7 390 | config.lr_reduce_patience = 5 391 | # Data 392 | config.train_imgs = None#"/data/ssd/datasets/covid/COVIDxV2/data/train" 393 | config.train_labels = "E:/ONNet/data/covid_train_split_v2.txt" #"/data/ssd/datasets/covid/COVIDxV2/data/train_COVIDx.txt" 394 | config.val_imgs = None#"/data/ssd/datasets/covid/COVIDxV2/data/test" 395 | config.val_labels = "E:/ONNet/data/covid_test_split_v2.txt" #"/data/ssd/datasets/covid/COVIDxV2/data/test_COVIDx.txt" 396 | # Categories mapping 397 | config.mapping = { 398 | 'normal': 0, 399 | 'pneumonia': 1, 400 | 'COVID-19': 2 401 | } 402 | # Loss weigths order follows the order in the category mapping dict 403 | config.loss_weights = [0.05, 0.05, 1.0] 404 | 405 | config.width = 256 406 | config.height = 256 407 | config.n_classes = len(config.mapping) 408 | # Training 409 | config.epochs = 300 410 | config.log_steps = 5 411 | config.eval_steps = 400 412 | config.ckpts_dir = "./experiments/ckpts" 413 | return config 414 | 415 | IMG_size = (256, 256) 416 | if __name__ == '__main__': 417 | config_0 = NET_config("DNet",'covid',IMG_size,0.01,batch_size=16, nClass=3, nLayer=5) 418 | #config_0 = RGBO_CNN_config("RGBO_CNN",'covid',IMG_size,0.01,batch_size=16, nClass=3, nLayer=5) 419 | if isONN: 420 | env_title, net = DNet_instance(config_0) 421 | #env_title, net = RGBO_CNN_instance(config_0) 422 | config = net.config 423 | config = UpdateConfig(config) 424 | config.batch_size = 64 425 | config.log_steps = 10 426 | config.lr = 0.001 427 | state = None 428 | else: 429 | config = UpdateConfig(config_0) 430 | if config.weights: 431 | state = torch.load(config.weights) 432 | log.info("Loaded model weights from: {}".format(config.weights)) 433 | else: 434 | state = None 435 | 436 | state_dict = state["state_dict"] if state else None 437 | net = COVIDNext50(n_classes=config.n_classes) 438 | if state_dict: 439 | net = load_model_weights(model=net, state_dict=state_dict,log=log) 440 | 
print(net) 441 | Net_dump(net) 442 | seed_everything(config.random_seed) 443 | main(net) 444 | -------------------------------------------------------------------------------- /case_dog_cat.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 1) https://github.com/rdcolema/pytorch-image-classification/blob/master/pytorch_model.ipynb 3 | https://github.com/mukul54/A-Simple-Cat-vs-Dog-Classifier-in-Pytorch/blob/master/catVsDog.py 4 | ''' 5 | # https://github.com/mukul54/A-Simple-Cat-vs-Dog-Classifier-in-Pytorch/blob/master/catVsDog.py 6 | 7 | import numpy as np # Matrix Operations (Matlab of Python) 8 | import pandas as pd # Work with Datasources 9 | import matplotlib.pyplot as plt # Drawing Library 10 | from PIL import Image 11 | import torch # Like a numpy but we could work with GPU by pytorch library 12 | import torch.nn as nn # Nural Network Implimented with pytorch 13 | import torchvision # A library for work with pretrained model and datasets 14 | from torchvision import transforms 15 | from torch.utils.data import Dataset 16 | from torch.utils.data import DataLoader 17 | import torch.nn.functional as F 18 | import glob 19 | import os 20 | 21 | image_size = (100, 100) 22 | image_row_size = image_size[0] * image_size[1] 23 | 24 | if False: #https://medium.com/predict/using-pytorch-for-kaggles-famous-dogs-vs-cats-challenge-part-1-preprocessing-and-training-407017e1a10c 25 | import shutil 26 | import re 27 | files = os.listdir(train_dir) 28 | # Move all train cat images to cats folder, dog images to dogs folder 29 | for f in files: 30 | catSearchObj = re.search("cat", f) 31 | dogSearchObj = re.search("dog", f) 32 | if catSearchObj: 33 | shutil.move(f'{train_dir}/{f}', train_cats_dir) 34 | elif dogSearchObj: 35 | shutil.move(f'{train_dir}/{f}', train_dogs_dir) 36 | pass 37 | 38 | class CatDogDataset(Dataset): 39 | def __init__(self, path, transform=None): 40 | self.classes = ["cat","dog"] #os.listdir(path) 41 | self.path = path #[f"{path}/{className}" for className in self.classes] 42 | #self.file_list = [glob.glob(f"{x}/*") for x in self.path] 43 | self.transform = transform 44 | 45 | files = [] 46 | for i, className in enumerate(self.classes): 47 | query = f"{self.path}{className}*" 48 | cls_list = glob.glob(query) 49 | print(f"{className}:n={len(cls_list)}") 50 | for fileName in cls_list: 51 | files.append([i, className, fileName]) 52 | self.file_list = files 53 | files = None 54 | 55 | def __len__(self): 56 | return len(self.file_list) 57 | 58 | def __getitem__(self, idx): 59 | fileName = self.file_list[idx][2] 60 | classCategory = self.file_list[idx][0] 61 | im = Image.open(fileName) 62 | if self.transform: 63 | im = self.transform(im) 64 | return im.view(-1), classCategory 65 | 66 | #mean = [0.485, 0.456, 0.406]; std = [0.229, 0.224, 0.225] 67 | mean = [0.485]; std = [0.229] 68 | transform = transforms.Compose([ 69 | transforms.Resize(image_size), 70 | transforms.Grayscale(), 71 | transforms.ToTensor(), 72 | transforms.Normalize(mean, std)]) 73 | 74 | path = '../data/dog_cat/train/' 75 | dataset = CatDogDataset(path, transform=transform) 76 | if True: 77 | def imshow(source): 78 | plt.figure(figsize=(10,10)) 79 | imt = (source.view(-1, image_size[0], image_size[0])) 80 | imt = imt.numpy().transpose([1,2,0]) 81 | imt = (std * imt + mean).clip(0,1) 82 | plt.subplot(1,2,2) 83 | plt.imshow(imt.squeeze()) 84 | imshow(dataset[0][0]) 85 | imshow(dataset[2][0]) 86 | imshow(dataset[6000][0]) 87 | plt.show() 88 | 89 | shuffle = True 90 | batch_size = 
64 91 | num_workers = 0 92 | dataloader = DataLoader(dataset=dataset, 93 | shuffle=shuffle, 94 | batch_size=batch_size, 95 | num_workers=num_workers) 96 | 97 | class MyModel(torch.nn.Module): 98 | def __init__(self, in_feature): 99 | super(MyModel, self).__init__() 100 | self.fc1 = torch.nn.Linear(in_features=in_feature, out_features=500) 101 | self.fc2 = torch.nn.Linear(in_features=500, out_features=100) 102 | self.fc3 = torch.nn.Linear(in_features=100, out_features=1) 103 | 104 | def forward(self, x): 105 | x = F.relu( self.fc1(x) ) 106 | x = F.relu( self.fc2(x) ) 107 | x = F.softmax( self.fc3(x), dim=1) 108 | return x 109 | 110 | model = MyModel(image_row_size) 111 | print(model) 112 | 113 | criterion = torch.nn.CrossEntropyLoss() 114 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.95) 115 | 116 | epochs = 10 117 | for epoch in range(epochs): 118 | for i, (X,Y) in enumerate(dataloader): 119 | # x, y = dataset[i] 120 | yhat = model(X) 121 | loss = criterion(yhat.view(-1), Y) 122 | break -------------------------------------------------------------------------------- /case_face_detect.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/jayrodge/Binary-Image-Classifier-PyTorch/blob/master/Binary_face_classifier.ipynb 3 | ''' 4 | 5 | import torch 6 | import numpy as np 7 | from torchvision import datasets 8 | import torchvision.transforms as transforms 9 | from torch.utils.data.sampler import SubsetRandomSampler 10 | import matplotlib.pyplot as plt 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | train_on_gpu = torch.cuda.is_available() 15 | # define the CNN architecture 16 | class Net(nn.Module): 17 | def __init__(self): 18 | super(Net, self).__init__() 19 | # convolutional layer 20 | self.conv1 = nn.Conv2d(3, 16, 5) 21 | # max pooling layer 22 | self.pool = nn.MaxPool2d(2, 2) 23 | self.conv2 = nn.Conv2d(16, 32, 5) 24 | self.dropout = nn.Dropout(0.2) 25 | self.fc1 = nn.Linear(32 * 53 * 53, 256) 26 | self.fc2 = nn.Linear(256, 84) 27 | self.fc3 = nn.Linear(84, 2) 28 | self.softmax = nn.LogSoftmax(dim=1) 29 | 30 | def forward(self, x): 31 | # add sequence of convolutional and max pooling layers 32 | x = self.pool(F.relu(self.conv1(x))) 33 | x = self.pool(F.relu(self.conv2(x))) 34 | x = self.dropout(x) 35 | x = x.view(-1, 32 * 53 * 53) 36 | x = F.relu(self.fc1(x)) 37 | x = self.dropout(F.relu(self.fc2(x))) 38 | x = self.softmax(self.fc3(x)) 39 | return x 40 | 41 | batch_size = 32 42 | # percentage of training set to use as validation 43 | test_size = 0.3 44 | valid_size = 0.1 45 | 46 | def imshow(img): 47 | img = img / 2 + 0.5 # unnormalize 48 | plt.imshow(np.transpose(img, (1, 2, 0))) 49 | 50 | # convert data to a normalized torch.FloatTensor 51 | transform = transforms.Compose([ 52 | transforms.RandomHorizontalFlip(), 53 | transforms.RandomRotation(20), 54 | transforms.Resize(size=(224,224)), 55 | transforms.ToTensor(), 56 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 57 | ]) 58 | 59 | def load_data(): 60 | data = datasets.ImageFolder('../data/Face/',transform=transform) 61 | num_data = len(data) 62 | indices_data = list(range(num_data)) 63 | np.random.shuffle(indices_data) 64 | split_tt = int(np.floor(test_size * num_data)) 65 | train_idx, test_idx = indices_data[split_tt:], indices_data[:split_tt] 66 | 67 | #For Valid 68 | num_train = len(train_idx) 69 | indices_train = list(range(num_train)) 70 | np.random.shuffle(indices_train) 71 | split_tv = int(np.floor(valid_size 
* num_train)) 72 | train_new_idx, valid_idx = indices_train[split_tv:],indices_train[:split_tv] 73 | 74 | 75 | # define samplers for obtaining training and validation batches 76 | train_sampler = SubsetRandomSampler(train_new_idx) 77 | test_sampler = SubsetRandomSampler(test_idx) 78 | valid_sampler = SubsetRandomSampler(valid_idx) 79 | 80 | train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, 81 | sampler=train_sampler, num_workers=1) 82 | valid_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, 83 | sampler=valid_sampler, num_workers=1) 84 | test_loader = torch.utils.data.DataLoader(data, sampler = test_sampler, batch_size=batch_size, 85 | num_workers=1) 86 | classes = [0,1] 87 | 88 | if False: # display 20 images 89 | dataiter = iter(train_loader) 90 | images, labels = dataiter.next() 91 | images = images.numpy() 92 | fig = plt.figure(figsize=(10, 4)) 93 | for idx in np.arange(10): 94 | ax = fig.add_subplot(2, 10 / 2, idx + 1, xticks=[], yticks=[]) 95 | imshow(images[idx]) 96 | ax.set_title(classes[labels[idx]]) 97 | plt.show() 98 | return train_loader,valid_loader,test_loader,classes 99 | 100 | def some_test(test_loader,classes): 101 | # track test loss 102 | test_loss = 0.0 103 | class_correct = list(0. for i in range(2)) 104 | class_total = list(0. for i in range(2)) 105 | 106 | model.eval() 107 | i = 1 108 | # iterate over test data 109 | len(test_loader) 110 | for data, target in test_loader: 111 | i = i + 1 112 | if len(target) != batch_size: 113 | continue 114 | 115 | # move tensors to GPU if CUDA is available 116 | if train_on_gpu: 117 | data, target = data.cuda(), target.cuda() 118 | # forward pass: compute predicted outputs by passing inputs to the model 119 | output = model(data) 120 | # calculate the batch loss 121 | loss = criterion(output, target) 122 | # update test loss 123 | test_loss += loss.item() * data.size(0) 124 | # convert output probabilities to predicted class 125 | _, pred = torch.max(output, 1) 126 | # compare predictions to true label 127 | correct_tensor = pred.eq(target.data.view_as(pred)) 128 | correct = np.squeeze(correct_tensor.numpy()) if not train_on_gpu else np.squeeze(correct_tensor.cpu().numpy()) 129 | # calculate test accuracy for each object class 130 | # print(target) 131 | 132 | for i in range(batch_size): 133 | label = target.data[i] 134 | class_correct[label] += correct[i].item() 135 | class_total[label] += 1 136 | 137 | # average test loss 138 | test_loss = test_loss / len(test_loader.dataset) 139 | print('Test Loss: {:.6f}\n'.format(test_loss)) 140 | 141 | for i in range(2): 142 | if class_total[i] > 0: 143 | print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % ( 144 | classes[i], 100 * class_correct[i] / class_total[i], 145 | np.sum(class_correct[i]), np.sum(class_total[i]))) 146 | else: 147 | print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i])) 148 | 149 | print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % ( 150 | 100. 
* np.sum(class_correct) / np.sum(class_total), 151 | np.sum(class_correct), np.sum(class_total))) 152 | 153 | if __name__ == '__main__': 154 | 155 | model = Net() 156 | print(model) 157 | 158 | train_loader,valid_loader,test_loader,classes=load_data() 159 | # move tensors to GPU if CUDA is available 160 | if train_on_gpu: 161 | model.cuda() 162 | criterion = torch.nn.CrossEntropyLoss() 163 | optimizer = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9) 164 | n_epochs = 5 # you may increase this number to train a final model 165 | 166 | valid_loss_min = np.Inf # track change in validation loss 167 | 168 | for epoch in range(1, n_epochs + 1): 169 | 170 | # keep track of training and validation loss 171 | train_loss = 0.0 172 | valid_loss = 0.0 173 | 174 | ################### 175 | # train the model # 176 | ################### 177 | model.train() 178 | for data, target in train_loader: 179 | # move tensors to GPU if CUDA is available 180 | if train_on_gpu: 181 | data, target = data.cuda(), target.cuda() 182 | # clear the gradients of all optimized variables 183 | optimizer.zero_grad() 184 | # forward pass: compute predicted outputs by passing inputs to the model 185 | output = model(data) 186 | # calculate the batch loss 187 | loss = criterion(output, target) 188 | # backward pass: compute gradient of the loss with respect to model parameters 189 | loss.backward() 190 | # perform a single optimization step (parameter update) 191 | optimizer.step() 192 | # update training loss 193 | train_loss += loss.item() * data.size(0) 194 | 195 | ###################### 196 | # validate the model # 197 | ###################### 198 | model.eval() 199 | for data, target in valid_loader: 200 | # move tensors to GPU if CUDA is available 201 | if train_on_gpu: 202 | data, target = data.cuda(), target.cuda() 203 | # forward pass: compute predicted outputs by passing inputs to the model 204 | output = model(data) 205 | # calculate the batch loss 206 | loss = criterion(output, target) 207 | # update average validation loss 208 | valid_loss += loss.item() * data.size(0) 209 | 210 | # calculate average losses 211 | train_loss = train_loss / len(train_loader.dataset) 212 | valid_loss = valid_loss / len(valid_loader.dataset) 213 | 214 | # print training/validation statistics 215 | print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format( 216 | epoch, train_loss, valid_loss)) 217 | 218 | # save model if validation loss has decreased 219 | if valid_loss <= valid_loss_min: 220 | print('Validation loss decreased ({:.6f} --> {:.6f}). 
Saving model ...'.format(
221 |             valid_loss_min,
222 |             valid_loss))
223 |             #torch.save(model.state_dict(), 'model_cifar.pt')
224 |             valid_loss_min = valid_loss
225 | 
226 |     some_test(test_loader,classes)
227 | 
--------------------------------------------------------------------------------
/case_mnist.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torchvision import datasets, transforms
8 | from torch.optim.lr_scheduler import StepLR
9 | import os
10 | import sys
11 | ONNET_DIR = os.path.abspath("./python-package/")
12 | sys.path.append(ONNET_DIR)  # To find local version of the onnet
13 | from onnet import *
14 | import torchvision
15 | import cv2
16 | import math
17 | import matplotlib.pyplot as plt
18 | import numpy as np
19 | 
20 | #dataset="emnist"
21 | #dataset="fasion_mnist"
22 | #dataset="cifar"
23 | dataset="mnist"
24 | # IMG_size = (28, 28)
25 | # IMG_size = (56, 56)
26 | IMG_size = (112, 112)
27 | # IMG_size = (14, 14)
28 | batch_size = 128
29 | 
30 | #net_type = "OptFormer"
31 | #net_type = "cnn"
32 | net_type = "DNet"
33 | #net_type = "WNet"
34 | #net_type = "MF_WNet"
35 | #net_type = "MF_DNet"
36 | #net_type = "BiDNet"
37 | 
38 | class Fasion_Net(nn.Module):    #https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html
39 |     def __init__(self):
40 |         super(Fasion_Net, self).__init__()    # was super(Net, ...), a NameError
41 |         self.conv1 = nn.Conv2d(1, 6, 5)
42 |         self.pool = nn.MaxPool2d(2, 2)
43 |         self.conv2 = nn.Conv2d(6, 16, 5)
44 |         self.fc1 = nn.Linear(16 * 4 * 4, 120)
45 |         self.fc2 = nn.Linear(120, 84)
46 |         self.fc3 = nn.Linear(84, 10)
47 | 
48 |     def forward(self, x):
49 |         x = self.pool(F.relu(self.conv1(x)))
50 |         x = self.pool(F.relu(self.conv2(x)))
51 |         x = x.view(-1, 16 * 4 * 4)
52 |         x = F.relu(self.fc1(x))
53 |         x = F.relu(self.fc2(x))
54 |         x = self.fc3(x)
55 |         return x
56 | 
57 | class Mnist_Net(nn.Module):
58 |     def __init__(self,config, nCls=10):
59 |         super(Mnist_Net, self).__init__()
60 |         self.title = "Mnist_Net"
61 |         self.config = config
62 |         self.config.learning_rate = 0.01
63 |         self.conv1 = nn.Conv2d(1, 32, 3, 1)
64 |         self.conv2 = nn.Conv2d(32, 64, 3, 1)
65 |         self.isDropOut = False
66 |         self.nFC=1
67 |         if self.isDropOut:
68 |             self.dropout1 = nn.Dropout2d(0.25)
69 |             self.dropout2 = nn.Dropout2d(0.5)
70 |         if IMG_size[0]==56:
71 |             nFC1 = 43264
72 |         else:
73 |             nFC1 = 9216    # flattened conv2 size for a 28x28 input; other IMG_size values need their own value
74 |         if self.nFC == 1:
75 |             self.fc1 = nn.Linear(nFC1, 10)
76 |         else:
77 |             self.fc1 = nn.Linear(nFC1, 128)
78 |             self.fc2 = nn.Linear(128, 10)
79 |         self.loss = F.cross_entropy
80 |         self.nClass = nCls
81 | 
82 |     def forward(self, x):
83 |         x = self.conv1(x)
84 |         x = F.relu(x)
85 |         x = self.conv2(x)
86 |         x = F.max_pool2d(x, 2)
87 |         if self.isDropOut:
88 |             x = self.dropout1(x)
89 |         x = torch.flatten(x, 1)
90 |         x = self.fc1(x)
91 |         x = F.relu(x)
92 |         if self.isDropOut:
93 |             x = self.dropout2(x)
94 |         if self.nFC == 2:
95 |             x = self.fc2(x)
96 |         #output = F.log_softmax(x, dim=1)
97 |         output = x
98 |         return output
99 | 
100 |     def predict(self,output):
101 |         pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
102 |         #pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
103 |         return pred
104 | 
105 | class View(nn.Module):
106 |     def __init__(self, *args):
107 |         super(View, self).__init__()
108 |         self.shape = args
109 | 
110 |     def forward(self, x):
111 |         return x.view(-1,*self.shape)
112 | 
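# --- Added illustration (not in the original source): View lets a reshape sit
# inside nn.Sequential, where a bare Tensor.view call cannot. A minimal sketch,
# assuming a (N, 16, 4, 4) feature map:
if False:    # usage sketch only, never executed
    head = nn.Sequential(View(16 * 4 * 4), nn.Linear(16 * 4 * 4, 10))
    print(head(torch.randn(8, 16, 4, 4)).shape)    # -> torch.Size([8, 10])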
113 | train_trans = transforms.Compose([
114 |     #transforms.RandomAffine(5,translate=(0,0.1)),
115 |     #transforms.RandomRotation(10),
116 |     #transforms.Grayscale(),
117 |     transforms.Resize(IMG_size),
118 |     transforms.ToTensor(),
119 |     #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))    # would shift inputs to [-1,1]; disabled so the range stays [0,1]
120 |     #transforms.Normalize((0.1307,), (0.3081,))
121 | ])
122 | test_trans = transforms.Compose([
123 |     #transforms.Grayscale(),
124 |     transforms.Resize(IMG_size),
125 |     transforms.ToTensor(),
126 |     #transforms.Normalize((0.1307,), (0.3081,))
127 | ])
128 | 
129 | def train(model, device, train_loader, epoch, optical_trans,visual):
130 |     #model.visual = visual
131 |     # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9,weight_decay=0.0005)
132 |     optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate, weight_decay=0.0005)
133 |     if epoch==1:
134 |         print(f"\n=======dataset={dataset} net={net_type} IMG_size={IMG_size} batch_size={batch_size}")
135 |         print(f"======={model.config}")
136 |         print(f"======={optimizer}")
137 |         print(f"======={train_trans}\n")
138 | 
139 |     nClass = model.nClass
140 |     model.train()
141 |     for batch_idx, (data, target) in enumerate(train_loader):
142 |         if batch_idx==0:    #check data_range
143 |             d0,d1=data.min(),data.max()
144 |             assert(d0>=0)
145 |         data, target = data.to(device), target.to(device)
146 |         optimizer.zero_grad()
147 |         output = model(optical_trans(data))
148 |         #output = model(data)
149 |         loss = model.loss(output, target)
150 |         loss.backward()
151 |         optimizer.step()
152 |         if batch_idx % 50 == 0:
153 |             aLoss = loss.item()
154 |             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
155 |                 epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader),aLoss ))
156 |             #visual.UpdateLoss(title=f"Accuracy on \"{dataset}\"", legend=f"{model.legend()}", loss=aLoss, yLabel="Accuracy")
157 |         #break
158 | 
159 | def test_one_batch(model,data,target,device,optical_trans=None):
160 |     data, target = data.to(device), target.to(device)
161 |     # apply the same optical front-end as training, when one is given
162 |     output = model(optical_trans(data)) if optical_trans is not None else model(data)
163 |     loss = model.loss(output, target, reduction='sum').item()  # sum up batch loss
164 |     pred = model.predict(output)
165 |     #pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
166 |     correct = pred.eq(target.view_as(pred)).sum().item()
167 |     return loss,correct
168 | 
169 | def test(model, device, test_loader, optical_trans,visual):
170 |     model.eval()
171 |     test_loss = 0
172 |     correct = 0
173 |     with torch.no_grad():
174 |         for data, target in test_loader:
175 |             loss, corr = test_one_batch(model, data, target, device, optical_trans)
176 |             test_loss += loss
177 |             correct += corr
178 |             if False:    # superseded by test_one_batch above
179 |                 data, target = data.to(device), target.to(device)
180 |                 if optical_trans is not None:   data = optical_trans(data)
181 |                 output = model(data)
182 |                 test_loss += model.loss(output, target, reduction='sum').item()  # sum up batch loss
183 |                 pred = model.predict(output)
184 |                 #pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
185 |                 correct += pred.eq(target.view_as(pred)).sum().item()
186 | 
187 | 
188 |     test_loss /= len(test_loader.dataset)
189 |     accu = 100.
* correct / len(test_loader.dataset)
190 |     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(test_loss, correct, len(test_loader.dataset),accu))
191 |     if visual is not None:
192 |         visual.UpdateLoss(title=f"Accuracy on \"{dataset}\"",legend=f"{model.legend()}", loss=accu,yLabel="Accuracy")
193 |     return accu
194 | 
195 | def Some_Test():
196 |     use_cuda = torch.cuda.is_available()
197 |     device = torch.device("cuda" if use_cuda else "cpu")
198 |     model_path = "E:/ONNet/checkpoint/DNNet_exp_W_H_Express Wavenet_[17,81.91]_.pth"
199 |     PTH = torch.load(model_path)
200 |     env_title, model = DNet_instance(NET_config(PTH['net_type'], PTH['dataset'],
201 |         PTH['IMG_size'], PTH['lr_base'], PTH['batch_size'], PTH['nClass'], PTH['nLayer']))    # DNet_instance now takes a single config
202 |     epoch, acc = PTH['epoch'], PTH['acc']
203 |     model.load_state_dict(PTH['net'])
204 |     model.to(device)
205 |     print(f"Load model@{model_path} epoch={epoch},acc={acc}")
206 | 
207 |     visual = Visdom_Visualizer(env_title,plots=[{"object":"output"}])
208 |     visual.img_dir = "./dump/X_images/"
209 |     test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST('./data', train=False,transform=test_trans),
210 |         batch_size=batch_size, shuffle=False)
211 |     if True:    # only one batch
212 |         dataiter = iter(test_loader)
213 |         images, target = next(dataiter)    # was dataiter.next(), which Python 3 removed
214 |         model.visual = visual
215 |         loss,correct = test_one_batch(model, images, target, device)
216 |         model.visual = None
217 | 
218 |     if False:
219 |         acc_1 = test(model, device, test_loader, None, None)
220 |         print(f"Some_Test acc={acc}-{acc_1}")
221 | 
222 | def main():
223 |     #OnInitInstance()
224 |     lr_base = 0.002
225 |     parser = argparse.ArgumentParser(description='MNIST optical_trans + hybrid examples')
226 |     parser.add_argument('--mode', type=int, default=2,help='optical_trans 1st or 2nd order')
227 |     parser.add_argument('--classifier', type=str, default='linear',help='classifier model')
228 |     args = parser.parse_args()
229 |     assert(args.classifier in ['linear','mlp','cnn'])
230 | 
231 |     use_cuda = torch.cuda.is_available()
232 |     device = torch.device("cuda" if use_cuda else "cpu")
233 |     optical_trans = OpticalTrans()
234 | 
235 |     # DataLoaders
236 |     if use_cuda:
237 |         num_workers = 4
238 |         pin_memory = True
239 |     else:
240 |         num_workers = 0    # was None, which DataLoader rejects; num_workers must be an int
241 |         pin_memory = False
242 | 
243 |     nLayer = 10
244 |     if dataset=="emnist":
245 |         train_loader = torch.utils.data.DataLoader(
246 |             datasets.EMNIST('./data',split="balanced", train=True, download=True, transform=train_trans),
247 |             batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
248 |         test_loader = torch.utils.data.DataLoader(
249 |             datasets.EMNIST('./data',split="balanced", train=False, transform=test_trans),
250 |             batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
251 |         # balanced=47  byclass=62
252 |         nClass = 47
253 |     elif dataset=="fasion_mnist":
254 |         train_loader = torch.utils.data.DataLoader(
255 |             datasets.FashionMNIST('./data',train=True, download=True, transform=train_trans),
256 |             batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
257 |         test_loader = torch.utils.data.DataLoader(
258 |             datasets.FashionMNIST('./data',train=False, transform=test_trans),
259 |             batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
260 |         nClass = 10
261 |     elif dataset=="cifar":
262 |         train_loader = torch.utils.data.DataLoader(datasets.CIFAR10('./data',train=True, download=True, transform=train_trans),
263 |             batch_size=batch_size, shuffle=True,
num_workers=num_workers, pin_memory=pin_memory)
264 |         test_loader = torch.utils.data.DataLoader(datasets.CIFAR10('./data',train=False, transform=test_trans),
265 |             batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
266 |         nClass = 10;    lr_base=0.005
267 |     else:
268 |         nClass = 10
269 |         train_loader = torch.utils.data.DataLoader(
270 |             datasets.MNIST('./data', train=True, download=True,transform=train_trans),
271 |             batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
272 |         test_loader = torch.utils.data.DataLoader(
273 |             datasets.MNIST('./data', train=False,transform=test_trans),
274 |             batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
275 | 
276 |     config_0 = NET_config(net_type,dataset,IMG_size,lr_base,batch_size,nClass,nLayer)
277 |     env_title, model = DNet_instance(config_0)    #net_type,dataset,IMG_size,lr_base,batch_size,nClass,nLayer
278 |     visual = Visdom_Visualizer(env_title=env_title)
279 |     # visual = Visualize(env_title=env_title)
280 |     model.to(device)
281 |     print(model)
282 |     # visual.ShowModel(model,train_loader)
283 | 
284 |     if False:    # custom weight init (disabled; it behaved strangely in training)
285 |         for m in model.modules():
286 |             if isinstance(m, nn.Conv2d):
287 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
288 |                 m.weight.data.normal_(0, 2. / math.sqrt(n))
289 |                 m.bias.data.zero_()
290 |             if isinstance(m, nn.Linear):
291 |                 m.weight.data.normal_(0, 2. / math.sqrt(m.in_features))
292 |                 m.bias.data.zero_()
293 | 
294 |     nzParams = Net_dump(model)
295 |     if False:
296 |         nzParams=0
297 |         for name, param in model.named_parameters():
298 |             if param.requires_grad:
299 |                 nzParams+=param.nelement()
300 |                 print(f"\t{name}={param.nelement()}")
301 |         print(f"========All parameters={nzParams}")
302 | 
303 |     acc,best_acc = 0,0
304 |     accu_=[]
305 |     for epoch in range(1, 33):
306 |         if False:
307 |             assert os.path.isdir('checkpoint')
308 |             pth_path = f'./checkpoint/{model.title}_[{epoch},{acc}]_.pth'
309 |             torch.save({'net': model.state_dict(), 'acc': acc, 'epoch': epoch,}, pth_path)
310 | 
311 |         if hasattr(model,'visualize'):
312 |             model.visualize(visual, f"E[{epoch-1}")    # bracket intentionally left open: D2NNet.visualize appends ",{no}]"
313 |         train( model, device, train_loader, epoch, optical_trans,visual)
314 |         acc = test(model, device, test_loader, optical_trans,visual)
315 |         accu_.append(acc)
316 |         if acc > best_acc:
317 |             state = {
318 |                 'net_type':net_type,'dataset':dataset,'IMG_size':IMG_size,'lr_base':lr_base,
319 |                 'batch_size':batch_size,'nClass':nClass, 'nLayer':nLayer,
320 |                 'net': model.state_dict(), 'acc': acc,'epoch': epoch,
321 |             }
322 |             assert os.path.isdir('checkpoint')
323 |             pth_path = f'./checkpoint/{model.title}_[{epoch},{acc}]_.pth'
324 |             torch.save(state, pth_path)
325 |             best_acc = acc
326 |     print(f"\n=======\n=======accu_history={accu_}\n")
327 | 
328 |     #if args.save_model:
329 |     #    torch.save(model.state_dict(), "mnist_onn.pt")
330 | 
331 | '''
332 |     Single diffractive-layer test case:
333 |     1) load an image with PIL  2) run DiffractiveLayer.forward  3) display with matplotlib
334 | '''
335 | def layer_test():
336 |     from PIL import Image
337 |     img = Image.open("E:/ONNet/data/MNIST/test_2.jpg")
338 |     img = train_trans(img)
339 | 
340 |     config=NET_config(net_type,dataset,IMG_size,0.01,32,10,5)
341 |     config.modulation = 'phase'
342 |     config.init_value = "random"
343 |     config.rDrop = 0    #drop out
344 |     layer = DiffractiveLayer(IMG_size[0],IMG_size[1],config)
345 | 
346 |     out = layer.forward(img.cuda())
347 |     im_out = layer.z_modulus(out)
348 |     im_out = im_out.squeeze().cpu().detach().numpy()
349 | 
350 |     fig, ax = plt.subplots()
351 |
#plt.axis('off') 352 | plt.grid(b=None) 353 | im = ax.imshow(im_out, interpolation='nearest', cmap='coolwarm') 354 | title = f"{layer.__repr__()}" 355 | ax.set_title(title,fontsize=12) 356 | fig.colorbar(im, orientation='horizontal') 357 | plt.show() 358 | plt.close() 359 | 360 | print("!!!Good Luck!!!") 361 | 362 | if __name__ == '__main__': 363 | #Some_Test() 364 | #layer_test() 365 | main() 366 | -------------------------------------------------------------------------------- /python-package/case_fft.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | from PIL import Image 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from onnet import * 6 | import torch 7 | from skimage import io, transform 8 | torch.set_printoptions(profile="full") 9 | 10 | size = 28 11 | delta = 0.03 12 | dL = 0.02 13 | c = 3e8 14 | Hz = 0.4e12 15 | 16 | def Init_H(d=delta, N = size, dL = dL, lmb = c/Hz,theta=0.0): 17 | # Parameter 18 | df = 1.0 / dL 19 | k = np.pi * 2.0 / lmb 20 | D = dL * dL / (N * lmb) 21 | # phase 22 | def phase(i, j): 23 | i -= N // 2 24 | j -= N // 2 25 | return ((i * df) * (i * df) + (j * df) * (j * df)) 26 | 27 | 28 | ph = np.fromfunction(phase, shape=(N, N), dtype=np.float32) 29 | # H 30 | H = np.exp(1.0j * k * d) * np.exp(-1.0j * lmb * np.pi * d * ph) 31 | H_f = np.fft.fftshift(H) 32 | #print(H_f); print(H) 33 | return H,H_f 34 | 35 | def fft_test(H_f,N = 28): 36 | dL = 0.02 37 | s = dL * dL / (N * N) 38 | 39 | normalize = transforms.Normalize( 40 | mean=[0.485, 0.456, 0.406], 41 | std=[0.229, 0.224, 0.225] 42 | ) 43 | preprocess = transforms.Compose([ 44 | #transforms.Resize(256), 45 | #transforms.CenterCrop(224), 46 | transforms.ToTensor(), 47 | #normalize 48 | ]) 49 | image = io.imread("E:/ONNet/data/MNIST/test_2.jpg").astype(np.float64) 50 | #print(image) 51 | img_tensor = torch.from_numpy(image) 52 | #print(img_tensor) 53 | #img_tensor.unsqueeze_(0) 54 | print(img_tensor.shape, img_tensor.dtype) 55 | u0 = COMPLEX_utils.ToZ(img_tensor) 56 | print(u0.shape, H_f.shape); 57 | 58 | u1 = COMPLEX_utils.fft(u0) 59 | print(u1) 60 | H_z = np.zeros(H_f.shape + (2,)) 61 | H_z[..., 0] = H_f.real 62 | H_z[..., 1] = H_f.imag 63 | H_f = torch.from_numpy(H_z) 64 | u1 = COMPLEX_utils.Hadamard(H_f,u1) #H_f * u1 65 | print(u1) 66 | u1 = COMPLEX_utils.fft(u1 ,"C2C",inverse=True) 67 | print(u1) 68 | input(...) 
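# --- Added illustration (not in the original source): the pure-numpy
# equivalent of the torch pipeline exercised by fft_test above. Propagation
# with the precomputed transfer function H_f is just a pointwise product in
# the frequency domain (angular-spectrum method).
def propagate_np(u0, H_f):
    # u1 = F^-1( H_f * F(u0) ); H_f is already fftshift-ed by Init_H
    return np.fft.ifft2(H_f * np.fft.fft2(u0))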
69 | 70 | if __name__ == '__main__': 71 | H, H_f = Init_H() 72 | fft_test(H_f) -------------------------------------------------------------------------------- /python-package/cnn_models/OpticalNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | import os 6 | #ONNET_DIR = os.path.abspath("../../") 7 | sys.path.append("../") # To find local version of the onnet 8 | from onnet import * 9 | from onnet import DiffractiveLayer 10 | 11 | class OpticalBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self,config, in_planes, planes, stride=1): 15 | super(OpticalBlock, self).__init__() 16 | self.config = config 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | M,N = self.config.IMG_size[0], self.config.IMG_size[1] 24 | self.diffrac = DiffractiveLayer(M,N,config) 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | #assert x.shape[-1]==32 and x.shape[-2]==32 36 | #out += self.diffrac(x) 37 | out = F.relu(out) 38 | return out 39 | 40 | 41 | class OpticalNet(nn.Module): 42 | def __init__(self, config,block, num_blocks): 43 | super(OpticalNet, self).__init__() 44 | num_classes = config.nClass 45 | self.config = config 46 | self.in_planes = 64 47 | 48 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(64) 50 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 51 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 52 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 53 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 54 | self.linear = nn.Linear(512*block.expansion, num_classes) 55 | 56 | def _make_layer(self, block, planes, num_blocks, stride): 57 | strides = [stride] + [1]*(num_blocks-1) 58 | layers = [] 59 | for stride in strides: 60 | layers.append(block(self.config,self.in_planes, planes, stride)) 61 | self.in_planes = planes * block.expansion 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = self.layer1(out) 67 | out = self.layer2(out) 68 | out = self.layer3(out) 69 | out = self.layer4(out) 70 | out = F.avg_pool2d(out, 4) 71 | out = out.view(out.size(0), -1) 72 | out = self.linear(out) 73 | return out 74 | 75 | 76 | def OpticalNet18(config): 77 | return OpticalNet(config,OpticalBlock, [2,2,2,2]) 78 | 79 | def OpticalNet34(config): 80 | return OpticalNet(config,OpticalBlock, [3,4,6,3]) 81 | 82 | def test(): 83 | net = OpticalNet18() 84 | y = net(torch.randn(1,3,32,32)) 85 | print(y.size()) 86 | 87 | # test() 88 | -------------------------------------------------------------------------------- /python-package/fast_conv.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Author: Yingshi Chen 3 | 4 | @Date: 2020-03-04 14:50:24 
5 | @ 6 | # Description: 7 | ''' 8 | 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | import time 12 | import sys 13 | sys.path.append('..') 14 | #from deap.convolve import convDEAP_GIP 15 | from scipy.signal import convolve2d 16 | import matplotlib.pyplot as plt 17 | from deap.helpers import getOutputShape 18 | from deap.mappers import PhotonicConvolverMapper 19 | from deap.mappers import ModulatorArrayMapper 20 | from deap.mappers import PWBArrayMapper 21 | 22 | class MRMTransferFunction: 23 | """ 24 | Computes the transfer function of a microring modulator (MRM). 25 | """ 26 | def __init__(self, a=0.9, r=0.9): 27 | self.a = a 28 | self.r = r 29 | self._maxThroughput = self.throughput(np.pi) 30 | 31 | def throughput(self, phi): 32 | I_pass = self.a**2 - 2 * self.r * self.a * np.cos(phi) + self.r**2 33 | I_input = 1 - 2 * self.a * self.r * np.cos(phi) + (self.r * self.a)**2 34 | return I_pass / I_input 35 | 36 | def phaseFromThroughput(self, Tn): 37 | Tn = np.asarray(Tn) 38 | 39 | # Create variable to store results 40 | ans = np.empty_like(Tn) 41 | 42 | # For high throuputs, set to pi 43 | moreThanMax = Tn >= self._maxThroughput 44 | maxOrLess = ~moreThanMax 45 | ans[moreThanMax] = np.pi 46 | 47 | # Now solve the remainng 48 | cos_phi = Tn[maxOrLess] * (1 + (self.r * self.a)**2) - self.a**2 - self.r**2 # noqa 49 | ans[maxOrLess] = np.arccos(cos_phi / (-2 * self.r * self.a * (1 - Tn[maxOrLess]))) # noqa 50 | #ans = np.arccos(cos_phi / (-2 * self.r * self.a * (1 - Tn[maxOrLess]))) 51 | 52 | return ans 53 | 54 | def convDEAP(image, kernel, stride, bias=0, normval=255): 55 | """ 56 | Image is a 3D matrix with index values row, col, depth, index 57 | Kernel is a 4D matrix with index values row, col, depth, index. 58 | The depth of the kernel must be equal to the depth of the input. 59 | """ 60 | assert image.shape[2] == kernel.shape[2] 61 | 62 | # Allocate memory for storing result of convolution 63 | outputShape = getOutputShape(image.shape, kernel.shape, stride=stride) 64 | output = np.zeros(outputShape) 65 | 66 | # Build the photonic circuit 67 | weightBanks = [] 68 | inputShape = (kernel.shape[0], kernel.shape[1]) 69 | for k in range(image.shape[2]): 70 | pc = PhotonicConvolverMapper.build( 71 | imageShape=inputShape, 72 | kernelShape=inputShape, 73 | power=normval) 74 | weightBanks.append(pc) 75 | 76 | for k in range(kernel.shape[3]): 77 | # Load weights 78 | weights = kernel[:, :, :, k] 79 | for c in range(weights.shape[2]): 80 | PWBArrayMapper.updateKernel( 81 | weightBanks[c].pwbArray, 82 | weights[:, :, c]) 83 | 84 | for h in range(0, outputShape[0], stride): 85 | for w in range(0, outputShape[1], stride): 86 | # Load inputs 87 | inputs = \ 88 | image[h:min(h + kernel.shape[0], image.shape[0]), 89 | w:min(w + kernel.shape[0], image.shape[1]), :] 90 | for c in range(kernel.shape[2]): 91 | ModulatorArrayMapper.updateInputs( 92 | weightBanks[c].modulatorArray, 93 | inputs[:, :, c], 94 | normval=normval) 95 | 96 | # Perform convolution: 97 | for c in range(kernel.shape[2]): 98 | output[h, w, k] += weightBanks[c].step() 99 | output[h, w, k] += bias 100 | 101 | return output 102 | 103 | def convDEAP_GIP(image, kernel, stride, convolverShape=None): 104 | """ 105 | Image is a 3D matrix with index values row, col, depth, index 106 | Kernel is a 4D matrix with index values row, col, depth, index. 107 | The depth of the kernel must be equal to the depth of the input. 
108 | """ 109 | assert image.shape[2] == kernel.shape[2] 110 | assert kernel.shape[2] == 1 and kernel.shape[3] == 1 111 | if convolverShape is None: 112 | convolverShape = image.shape 113 | 114 | # Define convolutional parameters 115 | Hm, Wm = convolverShape[0], convolverShape[1] 116 | H, W = image.shape[0], image.shape[1] 117 | R = kernel.shape[0] 118 | 119 | # Allocate memory for storing result of convolution 120 | outputShape = getOutputShape(image.shape, kernel.shape, stride=stride) 121 | output = np.zeros(outputShape) 122 | 123 | # Load weights 124 | pc = PhotonicConvolverMapper.build(imageShape=convolverShape,kernel=kernel[:, :, 0, 0], power=255) 125 | 126 | input_buffer = np.zeros(convolverShape) 127 | normval=255 128 | _mrm = MRMTransferFunction() 129 | for h in range(0, H - R + 1, Hm - R + 1): 130 | for w in range(0, W - R + 1, Wm - R + 1): 131 | inputs = image[h:min(h + Hm, H), w:min(w + Wm, W), 0] 132 | # Load inputs into a buffer if convolution shape doesn't tile 133 | # nicely. 134 | input_buffer[:inputs.shape[0], :inputs.shape[1]] = inputs 135 | input_buffer[inputs.shape[0]:, inputs.shape[1]:] = 0 136 | 137 | if False: 138 | ModulatorArrayMapper.updateInputs(pc.modulatorArray,input_buffer,normval=255) 139 | else: 140 | #phaseShifts = ModulatorArrayMapper.computePhaseShifts(input_buffer, normval=255) 141 | normalized = input_buffer / normval 142 | assert not np.any(input_buffer < 0) 143 | phaseShifts = _mrm.phaseFromThroughput(normalized) 144 | pc.modulatorArray._update(phaseShifts) 145 | 146 | # Perform the convolution and store to memory 147 | result = pc.step()[:min(h + Hm, H) - h - R + 1, 148 | :min(w + Wm, W) - w - R + 1] 149 | output[h:min(h + Hm, H) - R + 1, 150 | w:min(w + Hm, W) - R + 1, 151 | 0] = result 152 | 153 | return output 154 | 155 | def main(): 156 | image = plt.imread("./data/bass.jpg") 157 | greyscale = np.mean(image, axis=2) 158 | 159 | # Define kernel 160 | gaussian_kernel = np.zeros((3, 3, 1, 1)) 161 | gaussian_kernel[:, :, 0, 0] = \ 162 | np.array([ 163 | [1, 2, 1], 164 | [2, 4, 2], 165 | [1, 2, 1]]) * 1/16 166 | 167 | 168 | # Perform convolution 169 | paddedInputs = np.pad(greyscale, (2, 2), 'constant') 170 | paddedInputs = np.expand_dims(paddedInputs, 2) 171 | convolved = convDEAP_GIP(paddedInputs, gaussian_kernel, 1, (12, 12)) 172 | t0=time.time() 173 | for i in range(10): 174 | convDEAP_GIP(paddedInputs, gaussian_kernel, 1, (12, 12)) 175 | print(f"convDEAP_GIP T_10={time.time()-t0:.3f}") 176 | 177 | 178 | t0=time.time() 179 | for i in range(10): 180 | convolve2d(greyscale, gaussian_kernel[:, :, 0, 0]) 181 | print(f"convolve2d T_10={time.time()-t0:.3f}") 182 | conv_scipy = convolve2d(greyscale, gaussian_kernel[:, :, 0, 0]) 183 | 184 | err = np.abs(convolved[:, :, 0] - conv_scipy) 185 | mse = np.sum(err**2) / (err.size) 186 | print("MSE distance per pixel", mse) 187 | 188 | if __name__ == '__main__': 189 | main() -------------------------------------------------------------------------------- /python-package/onnet/BinaryDNet.py: -------------------------------------------------------------------------------- 1 | from .D2NNet import * 2 | import math 3 | import random 4 | 5 | class GatePipe(torch.nn.Module): 6 | def __init__(self,M,N, nHidden,config,pooling="max"): 7 | super(GatePipe, self).__init__() 8 | self.config = config 9 | self.M=M 10 | self.N=N 11 | self.nHidden = nHidden 12 | self.pooling = pooling 13 | self.layers = nn.ModuleList([DiffractiveLayer(self.M, self.N, self.config, HZ=0.3e12) for j in range(self.nHidden)]) 14 | if True: 15 | 
chunk_dim = -1 if random.choice([True, False]) else -2 16 | self.pool = ChunkPool(2, self.config,pooling=self.pooling,chunk_dim=chunk_dim) 17 | else: 18 | self.pt1 = (random.randint(0, self.M-1),random.randint(0,self.N-1)) 19 | self.pt2 = (random.randint(0, self.M - 1), random.randint(0, self.N - 1)) 20 | 21 | def __repr__(self): 22 | main_str = super(GatePipe, self).__repr__() 23 | main_str = f"GatePipe_[{len(self.layers)}]_pool[{self.pooling}]" 24 | return main_str 25 | 26 | def forward(self, x): 27 | for lay in self.layers: 28 | x = lay(x) 29 | x1 = Z.modulus(x).cuda() 30 | #x1 = Z.phase(x).cuda() 31 | if True: 32 | x1 = self.pool(x1) 33 | else: 34 | x_pt1 = x1[:, 0, self.pt1[0], self.pt1[1]] 35 | x_pt2 = x1[:, 0, self.pt2[0], self.pt2[1]] 36 | x1 = torch.stack([x_pt1,x_pt2], 1) 37 | x2 = F.log_softmax(x1, dim=1) 38 | return x2 39 | 40 | class BinaryDNet(D2NNet): 41 | @staticmethod 42 | def binary_loss(output, target, reduction='mean'): 43 | nGate = len(output) 44 | nSamp = target.shape[0] 45 | loss =0 46 | for i in range(nGate): 47 | target_i = target%2 48 | # loss = F.binary_cross_entropy(output, target, reduction=reduction) 49 | loss_i = F.cross_entropy(output[i], target_i, reduction=reduction) 50 | loss += loss_i 51 | target =(target-target_i)/2 52 | 53 | # loss = F.nll_loss(output, target, reduction=reduction) 54 | return loss 55 | 56 | def predict(self,output): 57 | nGate = len(output) 58 | pred = 0 59 | for i in range(nGate): 60 | pred_i = output[nGate-1-i].max(1, keepdim=True)[1] # get the index of the max log-probability 61 | pred = pred*2+pred_i 62 | #pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 63 | return pred 64 | 65 | def __init__(self, IMG_size,nCls,nInterDifrac,nOutDifac,config): 66 | super(BinaryDNet, self).__init__(IMG_size,nCls,nInterDifrac,config) 67 | self.nGate = (int)(math.ceil(math.log2(self.nClass))) 68 | self.nOutDifac = nOutDifac 69 | self.gates = nn.ModuleList( [GatePipe(self.M,self.N,nOutDifac,config,pooling="mean") for i in range(self.nGate)] ) 70 | self.config = config 71 | self.loss = BinaryDNet.binary_loss 72 | 73 | def __repr__(self): 74 | main_str = super(BinaryDNet, self).__repr__() 75 | main_str += f"_nGate={self.nGate}_Difrac=[{self.nDifrac},{self.nOutDifac}]" 76 | return main_str 77 | 78 | def legend(self): 79 | title = f"BinaryDNet" 80 | return title 81 | 82 | def forward(self, x): 83 | x = x.double() 84 | for layD in self.DD: 85 | x = layD(x) 86 | 87 | nSamp = x.shape[0] 88 | output = [] 89 | if True: 90 | for gate in self.gates: 91 | x1 = gate(x) 92 | output.append(x1) 93 | else: 94 | for [diffrac,gate] in self.gates: 95 | x1 = diffrac(x) 96 | x1 = self.z_modulus(x1).cuda() 97 | x1 = gate(x1) 98 | x2 = F.log_softmax(x1, dim=1) 99 | output.append(x2) 100 | 101 | return output 102 | 103 | -------------------------------------------------------------------------------- /python-package/onnet/D2NN_tf.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/computational-imaging/opticalCNN 3 | https://github.com/Lyn-Wu/Lyn/blob/master/DNN 4 | ''' 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | import tensorflow as tf 7 | import numpy as np 8 | from scipy.misc import imresize 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | from skimage import io, transform 12 | 13 | learning_rate = 0.01 14 | #size = 512 15 | size = 28 16 | delta = 0.03 17 | dL = 0.02 18 | batch_size = 64 19 | batch = 10 20 | #mnist = 
input_data.read_data_sets("MNIST_data",one_hot=True)
21 | mnist = input_data.read_data_sets("E:/ONNet/data/MNIST/raw",one_hot=True)
22 | c = 3e8
23 | Hz = 0.4e12
24 | 
25 | def fft_test(N = size):
26 |     s = dL * dL / (N * N)
27 |     if False:
28 |         img_raw = tf.io.read_file("E:/ONNet/data/MNIST/test_2.jpg")
29 |         img_raw = tf.image.decode_jpeg(img_raw)
30 |     else:    # tf.io and skimage.io decode differently here, oddly enough; hard to understand why
31 |         img_raw = io.imread("E:/ONNet/data/MNIST/test_2.jpg")
32 |     #print(img_raw)
33 |     img_tensor = tf.squeeze(img_raw)
34 |     with tf.Session() as sess:
35 |         img_tensor = img_tensor.eval()
36 |     print(img_tensor.shape,img_tensor.dtype)
37 |     #print(img_tensor)
38 | 
39 |     u0 = tf.cast(img_tensor,dtype=tf.complex64)
40 |     print(u0.shape,H_f.shape)
41 |     u1 = tf.fft2d(u0)
42 |     with tf.Session() as sess:
43 |         print(u0.eval())
44 |         print(u1.eval())
45 |     u1 = H_f * u1
46 |     u2 = tf.ifft2d(u1)
47 |     with tf.Session() as sess:
48 |         print(u1.eval())
49 |         print(u2.eval())
50 | 
51 | def Init_H(d=delta, N = size, dL = dL, lmb = c/Hz,theta=0.0):
52 |     # Parameter
53 |     df = 1.0 / dL
54 |     k = np.pi * 2.0 / lmb
55 |     D = dL * dL / (N * lmb)
56 |     # phase
57 |     def phase(i, j):
58 |         i -= N // 2
59 |         j -= N // 2
60 |         return ((i * df) * (i * df) + (j * df) * (j * df))
61 | 
62 | 
63 |     ph = np.fromfunction(phase, shape=(N, N), dtype=np.float32)
64 |     # H
65 |     H = np.exp(1.0j * k * d) * np.exp(-1.0j * lmb * np.pi * d * ph)
66 |     H_f = np.fft.fftshift(H)
67 |     #print(H_f); print(H)
68 |     return H,H_f
69 | 
70 | H,H_f=Init_H()
71 | #fft_test(); input(...)
72 | 
73 | def _propogation(u0, N = size, dL = dL):
74 |     df = 1.0 / dL
75 |     return tf.ifft2d(H_f*tf.fft2d(u0)*dL*dL/(N*N))*N*N/dL/dL
76 | 
77 | def propogation(u0,d,function=_propogation):
78 |     return tf.map_fn(function,u0)
79 | 
80 | def make_random(shape):
81 |     return np.random.random(size = shape).astype('float32')
82 | 
83 | 
84 | def add_layer_amp(inputs,amp,phase,size,delta):
85 |     return tf.multiply(propogation(inputs,delta),tf.cast(amp,dtype=tf.complex64))
86 |     #return propogation(inputs,delta)*tf.cast(amp,dtype=tf.complex64)
87 | 
88 | def add_layer_phase_out(inputs,amp,phase,size,delta):    # NOTE: _propogation_phase_out is never defined in this file; this path is unused
89 |     return propogation(inputs,delta,function=_propogation_phase_out)*tf.math.exp(1j*tf.cast(phase,dtype=tf.complex64))
90 | 
91 | 
92 | def add_layer_phase_in(inputs,amp,phase,size,delta):    # NOTE: _propogation_phase_in is never defined either; also unused
93 |     return propogation(inputs,delta,function=_propogation_phase_in)*tf.cast(amp,dtype=tf.complex64)
94 | 
95 | def _change(input_):
96 |     return imresize(input_.reshape(28,28),(size,size),interp="nearest")    # scipy.misc.imresize was removed in SciPy 1.3; needs an old SciPy or a Pillow-based replacement
97 | 
98 | def change(input_):
99 |     return np.array(list(map(_change,input_)))
100 | 
101 | def rang(arr,shape,size=size,base = 512):
102 |     #return arr[shape[0]*size//base:shape[1]*size//base,shape[2]*size//512:shape[3]*size//512]
103 |     x0 = shape[0] * size // base
104 |     y0 = shape[2] * size // base
105 |     delta = (shape[1]-shape[0])* size // base
106 |     return arr[x0:x0+delta,y0:y0+delta]
107 | 
108 | def reduce_mean(tf_):
109 |     return tf.reduce_mean(tf_)
110 | 
111 | def _ten_regions(a):
112 |     return tf.map_fn(reduce_mean,tf.convert_to_tensor([
113 |         rang(a,(120,170,120,170)),
114 |         rang(a,(120,170,240,290)),
115 |         rang(a,(120,170,360,410)),
116 |         rang(a,(220,270,120,170)),
117 |         rang(a,(220,270,200,250)),
118 |         rang(a,(220,270,280,330)),
119 |         rang(a,(220,270,360,410)),
120 |         rang(a,(320,370,120,170)),
121 |         rang(a,(320,370,240,290)),
122 |         rang(a,(320,370,360,410))
123 |     ]))
124 | 
125 | def ten_regions(logits):
126 |     return tf.map_fn(_ten_regions,tf.abs(logits),dtype=tf.float32)
127 | 
128 | def
download_text(msg,epoch,MIN=1,MAX=7,name=''): 129 | print("Download {}".format(name)) 130 | if name == 'Phase': 131 | MIN = 0 132 | MAX = 2 133 | for i in range(MIN,MAX): 134 | print("{} {}:".format(name,i)) 135 | np.savetxt("{}_Time_{}_layer_{}.txt".format(name,epoch+1,i),msg[i-1]) 136 | print("Done") 137 | 138 | def download_image(msg,epoch,MIN=1,MAX=7,name=''): 139 | print(f"Plot images-[{MIN}:{MAX}]") 140 | if name == 'Phase': 141 | MIN = 0 142 | MAX = 2 143 | for i in range(MIN,MAX): 144 | #print("Image {}:".format(i)) 145 | plt.figure(dpi=650.24) 146 | plt.axis('off') 147 | plt.grid('off') 148 | plt.imshow(msg[i-1]) 149 | plt.savefig("{}_Time_{}_layer_{}.jpg".format(name,epoch+1,i)) 150 | #print("Done") 151 | 152 | def download_acc(acc,epoch): 153 | np.savetxt("Acc{}.txt".format(epoch+1),acc) 154 | 155 | 156 | with tf.device('/cpu:0'): 157 | data_x = tf.placeholder(tf.float32,shape=(batch_size,size,size)) 158 | data_y = tf.placeholder(tf.float32,shape=(batch_size,10)) 159 | 160 | amp=[ 161 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32), 162 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32), 163 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32), 164 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32), 165 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32), 166 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32) 167 | ] 168 | 169 | phase = [ 170 | tf.constant(np.random.random(size=(size,size)),dtype=tf.float32), 171 | tf.constant(np.random.random(size=(size,size)),dtype=tf.float32) 172 | ] 173 | 174 | with tf.variable_scope('FullyConnected'): 175 | layer_1 = add_layer_amp(tf.cast(data_x,dtype=tf.complex64),amp[0],phase[0],size,delta) 176 | layer_2 = add_layer_amp(layer_1,amp[1],phase[1],size,delta) 177 | layer_3 = add_layer_amp(layer_2,amp[2],phase[1],size,delta) 178 | layer_4 = add_layer_amp(layer_3,amp[3],phase[1],size,delta) 179 | layer_5 = add_layer_amp(layer_4,amp[4],phase[1],size,delta) 180 | output_layer = add_layer_amp(layer_5,amp[5],phase[1],size,delta) 181 | output = _propogation(output_layer) 182 | 183 | with tf.variable_scope('Loss'): 184 | logits_abs = tf.square(tf.nn.softmax(ten_regions(tf.abs(output)))) 185 | loss = tf.reduce_sum(tf.square(logits_abs-data_y)) 186 | train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss) 187 | 188 | with tf.variable_scope('Accuracy'): 189 | pre_correct = tf.equal(tf.argmax(data_y,1),tf.argmax(logits_abs,1)) 190 | accuracy = tf.reduce_mean(tf.cast(pre_correct,tf.float32)) 191 | 192 | init = tf.global_variables_initializer() 193 | train_epochs = 20 194 | test_epochs = 5 195 | session = tf.Session() 196 | with tf.device('/gpu:0'): 197 | session.run(init) 198 | total_batch = int(mnist.train.num_examples / batch_size) 199 | #total_batch = 10 200 | 201 | for epoch in tqdm(range(train_epochs)): 202 | for batch in tqdm(range(total_batch)): 203 | batch_x,batch_y = mnist.train.next_batch(batch_size) 204 | session.run(train_op,feed_dict={data_x:change(batch_x),data_y:batch_y}) 205 | 206 | loss_,acc = session.run([loss,accuracy],feed_dict={data_x:change(batch_x),data_y:batch_y}) 207 | print("epoch :{} loss:{:.4f} acc:{:.4f}".format(epoch+1,loss_,acc)) 208 | 209 | with tf.device('/cpu:0'): 210 | msg_amp = np.array(session.run(amp)) 211 | download_text(msg_amp,epoch,name='Amp') 212 | #download_image(msg_amp,epoch,name='Amp') 213 | print("Optimizer finished") -------------------------------------------------------------------------------- 
/python-package/onnet/D2NNet.py: -------------------------------------------------------------------------------- 1 | # Authors: Yingshi Chen(gsp.cys@gmail.com) 2 | 3 | ''' 4 | PyTorch implementation of D2CNN ------ All-optical machine learning using diffractive deep neural networks 5 | ''' 6 | 7 | import torch 8 | import torchvision.transforms.functional as F 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from .Z_utils import COMPLEX_utils as Z 12 | from .PoolForCls import * 13 | from .Loss import * 14 | from .SparseSupport import * 15 | from .FFT_layer import * 16 | import numpy as np 17 | from .DiffractiveLayer import * 18 | import cv2 19 | useAttention=False 20 | if useAttention: 21 | import entmax 22 | #from torchscope import scope 23 | 24 | class DNET_config: 25 | def __init__(self,batch,lr_base,modulation="phase",init_value = "random",random_seed=42, 26 | support=SuppLayer.SUPP.exp,isFC=False): 27 | ''' 28 | 29 | :param modulation: 30 | :param init_value: ["random","zero","random_reverse","reverse","chunk"] 31 | :param support: 32 | ''' 33 | self.custom_legend = "Express Wavenet" #"Express_OFF" "Express Wavenet","Pan_OFF Express_OFF" #for paper and debug 34 | self.seed = random_seed 35 | seed_everything(self.seed) 36 | self.init_value = init_value # "random" "zero" 37 | self.rDrop = 0 38 | self.support = support #None 39 | self.modulation = modulation #["phase","phase_amp"] 40 | self.output_chunk = "2D" #["1D","2D"] 41 | self.output_pooling = "max" 42 | self.batch = batch 43 | self.learning_rate = lr_base 44 | self.isFC = isFC 45 | self.input_scale = 1 46 | self.wavelet = None #dict paramter for wavelet 47 | #if self.isFC == True: self.learning_rate = lr_base/10 48 | self.input_plane = "" #"fourier" 49 | 50 | def env_title(self): 51 | title=f"{self.support.value}" 52 | if self.isFC: title += "[FC]" 53 | if self.custom_legend is not None: 54 | title = title + f"_{self.custom_legend}" 55 | return title 56 | 57 | def __repr__(self): 58 | main_str = f"lr={self.learning_rate}_ mod={self.modulation} input={self.input_scale} detector={self.output_chunk} " \ 59 | f"support={self.support}" 60 | if self.isFC: main_str+=" [FC]" 61 | if self.custom_legend is not None: 62 | main_str = main_str + f"_{self.custom_legend}" 63 | return main_str 64 | 65 | class D2NNet(nn.Module): 66 | @staticmethod 67 | def binary_loss(output, target, reduction='mean'): 68 | nSamp = target.shape[0] 69 | nGate = output.shape[1] // 2 70 | loss = 0 71 | for i in range(nGate): 72 | target_i = target % 2 73 | val_2 = torch.stack([output[:,2*i],output[:,2*i+1]],1) 74 | 75 | loss_i = F.cross_entropy(val_2, target_i, reduction=reduction) 76 | loss += loss_i 77 | target = (target - target_i) / 2 78 | 79 | # loss = F.nll_loss(output, target, reduction=reduction) 80 | return loss 81 | 82 | @staticmethod 83 | def logit_loss(output, target, reduction='mean'): #https://stackoverflow.com/questions/53628622/loss-function-its-inputs-for-binary-classification-pytorch 84 | nSamp = target.shape[0] 85 | nGate = output.shape[1] 86 | loss = 0 87 | loss_BCE = nn.BCEWithLogitsLoss() 88 | for i in range(nGate): 89 | target_i = target % 2 90 | out_i = output[:,i] 91 | loss_i = loss_BCE(out_i, target_i.double()) 92 | loss += loss_i 93 | target = (target - target_i) / 2 94 | return loss 95 | 96 | def predict(self,output): 97 | if self.config.support == "binary": 98 | nGate = output.shape[1] // 2 99 | #assert nGate == self.n 100 | pred = 0 101 | for i in range(nGate): 102 | no = 2*(nGate - 1 - i) 103 | val_2 = 
torch.stack([output[:, no], output[:, no + 1]], 1) 104 | pred_i = val_2.max(1, keepdim=True)[1] # get the index of the max log-probability 105 | pred = pred * 2 + pred_i 106 | elif self.config.support == "logit": 107 | nGate = output.shape[1] 108 | # assert nGate == self.n 109 | pred = 0 110 | for i in range(nGate): 111 | no = nGate - 1 - i 112 | val_2 = F.sigmoid(output[:, no]) 113 | pred_i = (val_2+0.5).long() 114 | pred = pred * 2 + pred_i 115 | else: 116 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 117 | #pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 118 | return pred 119 | 120 | def GetLayer_(self): 121 | # layer = DiffractiveAMP 122 | if self.config.wavelet is None: 123 | layer = DiffractiveLayer 124 | else: 125 | layer = DiffractiveWavelet 126 | return layer 127 | 128 | def __init__(self,IMG_size,nCls,nDifrac,config): 129 | super(D2NNet, self).__init__() 130 | self.M,self.N=IMG_size 131 | self.z_modulus = Z.modulus 132 | self.nDifrac = nDifrac 133 | #self.isFC = False 134 | self.nClass = nCls 135 | #self.init_value = "random" #"random" "zero" 136 | self.config = config 137 | self.title = f"DNNet" 138 | self.highWay = 1 #1,2,3 139 | if self.config.input_plane == "fourier": 140 | self.highWay = 0 141 | 142 | if hasattr(self.config,'feat_extractor'): 143 | if self.config.feat_extractor!="last_layer": 144 | self.feat_extractor = [] 145 | 146 | if self.config.output_chunk == "2D": 147 | assert(self.M*self.N>=self.nClass) 148 | else: 149 | assert (self.M >= self.nClass and self.N >= self.nClass) 150 | print(f"D2NNet nClass={nCls} shape={self.M,self.N}") 151 | 152 | 153 | layer = self.GetLayer_() 154 | #fl = FFT_Layer(self.M, self.N,config,isInv=False) 155 | self.DD = nn.ModuleList([ 156 | layer(self.M, self.N,config) for i in range(self.nDifrac) 157 | ]) 158 | if self.config.input_plane=="fourier": 159 | self.DD.insert(0,FFT_Layer(self.M, self.N,config,isInv=False)) 160 | self.DD.append(FFT_Layer(self.M, self.N,config,isInv=True)) 161 | self.nD = len(self.DD) 162 | self.laySupp = None 163 | 164 | if self.highWay>0: 165 | self.wLayer = torch.nn.Parameter(torch.ones(len(self.DD))) 166 | if self.highWay==2: 167 | self.wLayer.data.uniform_(-1, 1) 168 | elif self.highWay==1: 169 | self.wLayer = torch.nn.Parameter(torch.ones(len(self.DD))) 170 | 171 | #self.DD.append(DropOutLayer(self.M, self.N,drop=0.9999)) 172 | if self.config.isFC: 173 | self.fc1 = nn.Linear(self.M*self.N, self.nClass) 174 | self.loss = UserLoss.cys_loss 175 | self.title = f"DNNet_FC" 176 | elif self.config.support!=None: 177 | self.laySupp = SuppLayer(config,self.nClass) 178 | self.last_chunk = ChunkPool(self.laySupp.nChunk, config, pooling=config.output_pooling) 179 | self.loss = UserLoss.cys_loss 180 | a = self.config.support.value 181 | self.title = f"DNNet_{self.config.support.value}" 182 | else: 183 | self.last_chunk = ChunkPool(self.nClass,config,pooling=config.output_pooling) 184 | self.loss = UserLoss.cys_loss 185 | 186 | if self.config.wavelet is not None: 187 | self.title = self.title+f"_W" 188 | if self.highWay>0: 189 | self.title = self.title + f"_H" 190 | if self.config.custom_legend is not None: 191 | self.title = self.title + f"_{self.config.custom_legend}" 192 | 193 | ''' 194 | BinaryChunk is pool 195 | elif self.config.support=="binary": 196 | self.last_chunk = BinaryChunk(self.nClass, pooling="max") 197 | self.loss = D2NNet.binary_loss 198 | self.title = f"DNNet_binary" 199 | elif self.config.support == "logit": 200 | 
self.last_chunk = BinaryChunk(self.nClass, isLogit=True, pooling="max")
201 |         self.loss = D2NNet.logit_loss
202 |     '''
203 | 
204 |     def visualize(self,visual,suffix):
205 |         no = 0
206 |         for plot in visual.plots:
207 |             images,path = [],""
208 |             if plot['object']=='layer pattern':
209 |                 path = f"{visual.img_dir}/{suffix}.jpg"
210 |                 for no,layer in enumerate(self.DD):
211 |                     info = f"{suffix},{no}]"
212 |                     title = f"layer_{no+1}"
213 |                     if self.highWay==2:
214 |                         a = self.wLayer[no]
215 |                         a = torch.sigmoid(a)
216 |                         info = info+f"_{a:.2g}"
217 |                     elif self.highWay==1:
218 |                         a = self.wLayer[no]
219 |                         info = info+f"_{a:.2g}"
220 |                         title = title+f" w={a:.2g}"
221 |                     image = layer.visualize(visual,info,{'save':False,'title':title})
222 |                     images.append(image)
223 |                     no=no+1
224 |             if len(images)>0:
225 |                 image_all = np.concatenate(images, axis=1)
226 |                 #cv2.imshow("", image_all); cv2.waitKey(0)
227 |                 cv2.imwrite(path,image_all)
228 | 
229 |     def legend(self):
230 |         if self.config.custom_legend is not None:
231 |             leg_ = self.config.custom_legend
232 |         else:
233 |             leg_ = self.title
234 |         return leg_
235 | 
236 |     def __repr__(self):
237 |         main_str = super(D2NNet, self).__repr__()
238 |         main_str += f"\n========init={self.config.init_value}"
239 |         return main_str
240 | 
241 |     def input_trans(self,x):    # square-rooted and normalized
242 |         #x = x.double()*self.config.input_scale
243 |         if True:
244 |             x = x*self.config.input_scale
245 |             x_0,x_1 = torch.min(x).item(),torch.max(x).item()
246 |             assert x_0>=0
247 |             x = torch.sqrt(x)
248 |         else:    # why doesn't this work? puzzling
249 |             x = Z.exp_euler(x*2*math.pi).float()
250 |             x_0,x_1 = torch.min(x).item(),torch.max(x).item()
251 |         return x
252 | 
253 |     def do_classify(self,x):
254 |         if self.config.isFC:
255 |             x = torch.flatten(x, 1)
256 |             x = self.fc1(x)
257 |             return x
258 | 
259 |         x = self.last_chunk(x)
260 |         if self.laySupp != None:
261 |             x = self.laySupp(x)
262 |         # output = F.log_softmax(x, dim=1)
263 |         return x
264 | 
265 |     def OnLayerFeats(self):
266 |         pass
267 | 
268 |     def forward(self, x):
269 |         if hasattr(self, 'feat_extractor'):
270 |             self.feat_extractor.clear()
271 |         nSamp,nChannel = x.shape[0],x.shape[1]
272 |         # was: assert(nChannel==1), which made the multi-channel branch below unreachable
273 |         if nChannel>1:
274 |             #no = random.randint(0,nChannel-1)    # unused
275 |             x = x[:,0:1,...]    # keep only the first channel
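        # (descriptive comments added for clarity; the logic below is unchanged)
        # The `highWay` switch mixes the diffractive layers into a weighted summary:
        #   highWay==1: summary = sum_k w_k * x_k, modulus taken once at the end
        #   highWay==2: each layer leaks a sigmoid(w_k) share into summary and
        #               passes the (1 - sigmoid(w_k)) remainder to the next layer
        #   highWay==3: summary accumulates w_k * |x_k| (modulus taken per layer)
        #   highWay==0: plain cascade, modulus of the last layer only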
276 | x = self.input_trans(x) 277 | if hasattr(self,'visual'): self.visual.onX(x.cpu(), f"X@input") 278 | summary = 0 279 | for no,layD in enumerate(self.DD): 280 | info = layD.__repr__() 281 | x = layD(x) 282 | if hasattr(self,'feat_extractor'): 283 | self.feat_extractor.append((self.z_modulus(x),self.wLayer[no])) 284 | if hasattr(self,'visual'): self.visual.onX(x,f"X@{no+1}") 285 | if self.highWay==2: 286 | s = torch.sigmoid(self.wLayer[no]) 287 | summary+=x*s 288 | x = x*(1-s) 289 | elif self.highWay==1: 290 | summary += x * self.wLayer[no] 291 | elif self.highWay==3: 292 | summary += self.z_modulus(x) * self.wLayer[no] 293 | if self.highWay==2: 294 | x=x+summary 295 | x = self.z_modulus(x) 296 | elif self.highWay == 1: 297 | x = summary 298 | x = self.z_modulus(x) 299 | elif self.highWay == 3: 300 | x = summary 301 | elif self.highWay == 0: 302 | x = self.z_modulus(x) 303 | if hasattr(self,'visual'): self.visual.onX(x,f"X@output") 304 | 305 | 306 | if hasattr(self,'feat_extractor'): 307 | return 308 | elif hasattr(self.config,'feat_extractor') and self.config.feat_extractor=="last_layer": 309 | return x 310 | else: 311 | output = self.do_classify(x) 312 | return output 313 | 314 | class MultiDNet(D2NNet): 315 | def __init__(self, IMG_size,nCls,nInterDifrac,freq_list,config,shareWeight=True): 316 | super(MultiDNet, self).__init__(IMG_size,nCls,nInterDifrac,config) 317 | self.isShareWeight=shareWeight 318 | self.freq_list = freq_list 319 | nFreq = len(self.freq_list) 320 | del self.DD; self.DD = None 321 | self.wFreq = torch.nn.Parameter(torch.ones(nFreq)) 322 | layer = self.GetLayer_() 323 | self.freq_nets=nn.ModuleList([ 324 | nn.ModuleList([ 325 | layer(self.M, self.N, self.config, HZ=freq) for i in range(self.nDifrac) 326 | ]) for freq in freq_list 327 | ]) 328 | if self.isShareWeight: 329 | nSubNet = len(self.freq_nets) 330 | net_0 = self.freq_nets[0] 331 | for i in range(1,nSubNet): 332 | net_1 = self.freq_nets[i] 333 | for j in range(self.nDifrac): 334 | net_1[j].share_weight(net_0[j]) 335 | 336 | 337 | def legend(self): 338 | if self.config.custom_legend is not None: 339 | leg_ = self.config.custom_legend 340 | else: 341 | title = f"MF_DNet({len(self.freq_list)} channels)" 342 | return title 343 | 344 | def __repr__(self): 345 | main_str = super(MultiDNet, self).__repr__() 346 | main_str += f"\nfreq_list={self.freq_list}_" 347 | return main_str 348 | 349 | def forward(self, x0): 350 | nSamp = x0.shape[0] 351 | x_sum = 0 352 | for id,fNet in enumerate(self.freq_nets): 353 | x = self.input_trans(x0) 354 | #d0,d1=x0.min(),x0.max() 355 | #x = x0.double() 356 | for layD in fNet: 357 | x = layD(x) 358 | #x_sum = torch.max(x_sum,self.z_modulus(x))).values() 359 | x_sum += self.z_modulus(x)*self.wFreq[id] 360 | x = x_sum 361 | 362 | output = self.do_classify(x) 363 | return output 364 | 365 | def main(): 366 | pass 367 | 368 | if __name__ == '__main__': 369 | main() -------------------------------------------------------------------------------- /python-package/onnet/DiffractiveLayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .Z_utils import COMPLEX_utils as Z 3 | from .some_utils import * 4 | import numpy as np 5 | import random 6 | import torch.nn as nn 7 | import matplotlib 8 | #matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | #https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-custom-nn-modules 13 | class DiffractiveLayer(torch.nn.Module): 14 | def SomeInit(self, M_in, 
N_in,HZ=0.4e12): 15 | assert (M_in == N_in) 16 | self.M = M_in 17 | self.N = N_in 18 | self.z_modulus = Z.modulus 19 | self.size = M_in 20 | self.delta = 0.03 21 | self.dL = 0.02 22 | self.c = 3e8 23 | self.Hz = HZ#0.4e12 24 | 25 | self.H_z = self.Init_H() 26 | 27 | def __repr__(self): 28 | #main_str = super(DiffractiveLayer, self).__repr__() 29 | main_str = f"DiffractiveLayer_[{(int)(self.Hz/1.0e9)}G]_[{self.M},{self.N}]" 30 | return main_str 31 | 32 | def __init__(self, M_in, N_in,config,HZ=0.4e12): 33 | super(DiffractiveLayer, self).__init__() 34 | self.SomeInit(M_in, N_in,HZ) 35 | assert config is not None 36 | self.config = config 37 | #self.init_value = init_value 38 | #self.rDrop = rDrop 39 | if not hasattr(self.config,'wavelet') or self.config.wavelet is None: 40 | if self.config.modulation=="phase": 41 | self.transmission = torch.nn.Parameter(data=torch.Tensor(self.size, self.size), requires_grad=True) 42 | else: 43 | self.transmission = torch.nn.Parameter(data=torch.Tensor(self.size, self.size, 2), requires_grad=True) 44 | 45 | init_param = self.transmission.data 46 | if self.config.init_value=="reverse": # 47 | half=self.transmission.data.shape[-2]//2 48 | init_param[..., :half, :] = 0 49 | init_param[..., half:, :] = np.pi 50 | elif self.config.init_value=="random": 51 | init_param.uniform_(0, np.pi*2) 52 | elif self.config.init_value == "random_reverse": 53 | init_param = torch.randint_like(init_param,0,2)*np.pi 54 | elif self.config.init_value == "chunk": 55 | sections = split__sections() 56 | for xx in init_param.split(sections, -1): 57 | xx = random.random(0,np.pi*2) 58 | 59 | #self.rDrop = config.rDrop 60 | 61 | #self.bias = torch.nn.Parameter(data=torch.Tensor(1, 1), requires_grad=True) 62 | 63 | def visualize(self,visual,suffix, params): 64 | param = self.transmission.data 65 | name = f"{suffix}_{self.config.modulation}_" 66 | return visual.image(name,param, params) 67 | 68 | def share_weight(self,layer_1): 69 | tp = type(self) 70 | assert(type(layer_1)==tp) 71 | #del self.transmission 72 | #self.transmission = layer_1.transmission 73 | 74 | def Init_H(self): 75 | # Parameter 76 | N = self.size 77 | df = 1.0 / self.dL 78 | d=self.delta 79 | lmb=self.c / self.Hz 80 | k = np.pi * 2.0 / lmb 81 | D = self.dL * self.dL / (N * lmb) 82 | # phase 83 | def phase(i, j): 84 | i -= N // 2 85 | j -= N // 2 86 | return ((i * df) * (i * df) + (j * df) * (j * df)) 87 | 88 | ph = np.fromfunction(phase, shape=(N, N), dtype=np.float32) 89 | # H 90 | H = np.exp(1.0j * k * d) * np.exp(-1.0j * lmb * np.pi * d * ph) 91 | H_f = np.fft.fftshift(H)*self.dL*self.dL/(N*N) 92 | # print(H_f); print(H) 93 | H_z = np.zeros(H_f.shape + (2,)) 94 | H_z[..., 0] = H_f.real 95 | H_z[..., 1] = H_f.imag 96 | H_z = torch.from_numpy(H_z).cuda() 97 | return H_z 98 | 99 | def Diffractive_(self,u0, theta=0.0): 100 | if Z.isComplex(u0): 101 | z0 = u0 102 | else: 103 | z0 = u0.new_zeros(u0.shape + (2,)) 104 | z0[...,0] = u0 105 | 106 | N = self.size 107 | df = 1.0 / self.dL 108 | 109 | z0 = Z.fft(z0) 110 | u1 = Z.Hadamard(z0,self.H_z.float()) 111 | u2 = Z.fft(u1,"C2C",inverse=True) 112 | return u2 * N * N * df * df 113 | 114 | def GetTransCoefficient(self): 115 | ''' 116 | eps = 1e-5; momentum = 0.1; affine = True 117 | 118 | mean = torch.mean(self.transmission, 1) 119 | vari = torch.var(self.transmission, 1) 120 | amp_bn = torch.batch_norm(self.transmission,mean,vari) 121 | :return: 122 | ''' 123 | amp_s = Z.exp_euler(self.transmission) 124 | 125 | return amp_s 126 | 127 | def forward(self, x): 128 | diffrac = 
self.Diffractive_(x) 129 | amp_s = self.GetTransCoefficient() 130 | x = Z.Hadamard(diffrac,amp_s.float()) 131 | if(self.config.rDrop>0): 132 | drop = Z.rDrop2D(1-self.rDrop,(self.M,self.N),isComlex=True) 133 | x = Z.Hadamard(x, drop) 134 | #x = x+self.bias 135 | return x 136 | 137 | class DiffractiveAMP(DiffractiveLayer): 138 | def __init__(self, M_in, N_in,rDrop=0.0): 139 | super(DiffractiveAMP, self).__init__(M_in, N_in,rDrop,params="amp") 140 | #self.amp = torch.nn.Parameter(data=torch.Tensor(self.size, self.size, 2), requires_grad=True) 141 | self.transmission.data.uniform_(0, 1) 142 | 143 | def GetTransCoefficient(self): 144 | # amp_s = Z.sigmoid(self.amp) 145 | # amp_s = torch.clamp(self.amp, 1.0e-6, 1) 146 | amp_s = self.transmission 147 | return amp_s 148 | 149 | class DiffractiveWavelet(DiffractiveLayer): 150 | def __init__(self, M_in, N_in,config,HZ=0.4e12): 151 | super(DiffractiveWavelet, self).__init__(M_in, N_in,config,HZ) 152 | #self.hough = torch.nn.Parameter(data=torch.Tensor(2), requires_grad=True) 153 | self.Init_DisTrans() 154 | #self.GetXita() 155 | 156 | def __repr__(self): 157 | main_str = f"Diffrac_Wavelet_[{(int)(self.Hz/1.0e9)}G]_[{self.M},{self.N}]" 158 | return main_str 159 | 160 | def share_weight(self,layer_1): 161 | tp = type(self) 162 | assert(type(layer_1)==tp) 163 | del self.wavelet 164 | self.wavelet = layer_1.wavelet 165 | del self.dis_map 166 | self.dis_map = layer_1.dis_map 167 | del self.wav_indices 168 | self.wav_indices = layer_1.wav_indices 169 | 170 | 171 | def Init_DisTrans(self): 172 | origin_r, origin_c = (self.M-1) / 2, (self.N-1) / 2 173 | origin_r = random.uniform(0, self.M-1) 174 | origin_c = random.uniform(0, self.N - 1) 175 | self.dis_map={} 176 | #self.dis_trans = torch.zeros((self.size, self.size)).int() 177 | self.wav_indices = torch.LongTensor((self.size*self.size)).cuda() 178 | nz=0 179 | for r in range(self.M): 180 | for c in range(self.N): 181 | off = np.sqrt((r - origin_r) * (r - origin_r) + (c - origin_c) * (c - origin_c)) 182 | i_off = (int)(off+0.5) 183 | if i_off not in self.dis_map: 184 | self.dis_map[i_off]=len(self.dis_map) 185 | id = self.dis_map[i_off] 186 | #self.dis_trans[r, c] = id 187 | self.wav_indices[nz] = id; nz=nz+1 188 | #print(f"[{r},{c}]={self.dis_trans[r, c]}") 189 | nD = len(self.dis_map) 190 | if False: 191 | plt.imshow(self.dis_trans.numpy()) 192 | plt.show() 193 | 194 | self.wavelet = torch.nn.Parameter(data=torch.Tensor(nD), requires_grad=True) 195 | self.wavelet.data.uniform_(0, np.pi*2) 196 | #self.dis_trans = self.dis_trans.cuda() 197 | 198 | def GetXita(self): 199 | if False: 200 | xita = torch.zeros((self.size, self.size)) 201 | for r in range(self.M): 202 | for c in range(self.N): 203 | pos = self.dis_trans[r, c] 204 | xita[r,c] = self.wavelet[pos] 205 | origin_r,origin_c=self.M/2,self.N/2 206 | #xita = self.dis_trans*self.hough[0]+self.hough[1] 207 | else: 208 | xita = torch.index_select(self.wavelet, 0, self.wav_indices) 209 | xita = xita.view(self.size, self.size) 210 | 211 | # print(xita) 212 | return xita 213 | 214 | def GetTransCoefficient(self): 215 | xita = self.GetXita() 216 | amp_s = Z.exp_euler(xita) 217 | return amp_s 218 | 219 | def visualize(self,visual,suffix, params): 220 | xita = self.GetXita() 221 | name = f"{suffix}" 222 | return visual.image(name,torch.sin(xita.detach()), params) -------------------------------------------------------------------------------- /python-package/onnet/DropOutLayer.py: -------------------------------------------------------------------------------- 1 
| import torch
 2 | import numpy as np
 3 | from .Z_utils import COMPLEX_utils as Z
 4 | 
 5 | #Very strange behavior of DROPOUT
 6 | class DropOutLayer(torch.nn.Module):
 7 |     def __init__(self, M_in, N_in, drop=0.5):
 8 |         super(DropOutLayer, self).__init__()
 9 |         assert (M_in == N_in)
10 |         self.M = M_in
11 |         self.N = N_in
12 |         self.rDrop = drop
13 | 
14 |     def forward(self, x):
15 |         assert(Z.isComplex(x))
16 |         d_shape = x.shape[:-1]
17 |         # sample one binary mask and apply it to both the real and imaginary parts
18 |         drop = np.random.binomial(1, self.rDrop, size=d_shape).astype(np.float32)
19 |         #print(f"x={x.shape} drop={drop.shape}")
20 |         drop = torch.from_numpy(drop).cuda()
21 |         x[...,0] *= drop
22 |         x[...,1] *= drop
23 |         return x
-------------------------------------------------------------------------------- /python-package/onnet/FFT_layer.py: --------------------------------------------------------------------------------
 1 | '''
 2 | @Author: Yingshi Chen
 3 | 
 4 | @Date: 2020-04-10 11:22:27
 5 | @
 6 | # Description:
 7 | '''
 8 | 
 9 | import torch
10 | from .Z_utils import COMPLEX_utils as Z
11 | from .some_utils import *
12 | import numpy as np
13 | from numpy.fft import fftn
14 | import random
15 | import torch.nn as nn
16 | import matplotlib
17 | #matplotlib.use('Agg')
18 | import matplotlib.pyplot as plt
19 | 
20 | class FFT_Layer(torch.nn.Module):
21 |     def SomeInit(self, M_in, N_in, isInv=False):
22 |         assert (M_in == N_in)
23 |         self.M = M_in
24 |         self.N = N_in
25 |         self.isInv = isInv
26 | 
27 |     def __repr__(self):
28 |         i_ = "_i" if self.isInv else ""
29 |         main_str = f"FFT_Layer{i_}_[{self.M},{self.N}]"
30 |         return main_str
31 | 
32 |     def __init__(self, M_in, N_in, config, isInv=False):
33 |         super(FFT_Layer, self).__init__()
34 |         self.SomeInit(M_in, N_in, isInv)
35 |         assert config is not None
36 |         self.config = config
37 |         #self.init_value = init_value
38 | 
39 |     def visualize(self, visual, suffix, params):
40 |         # NOTE: assumes a subclass defines self.transmission; FFT_Layer itself has no such parameter
41 |         param = self.transmission.data
42 |         name = f"{suffix}_{self.config.modulation}_"
43 |         return visual.image(name, param, params)
44 | 
45 |     def Diffractive_(self, u0, theta=0.0):
46 |         # NOTE: expects DiffractiveLayer-style attributes (self.size, self.dL, self.H_z)
47 |         if Z.isComplex(u0):
48 |             z0 = u0
49 |         else:
50 |             z0 = u0.new_zeros(u0.shape + (2,))
51 |             z0[...,0] = u0
52 | 
53 |         N = self.size
54 |         df = 1.0 / self.dL
55 | 
56 |         z0 = Z.fft(z0)
57 |         u1 = Z.Hadamard(z0, self.H_z.float())
58 |         u2 = Z.fft(u1, "C2C", inverse=True)
59 |         return u2 * N * N * df * df
60 | 
61 |     def forward(self, x):
62 |         # promote a real input to the trailing-(2) complex layout, then apply the (inverse) 2D FFT
63 |         if Z.isComplex(x):
64 |             z0 = x
65 |         else:
66 |             z0 = x.new_zeros(x.shape + (2,))
67 |             z0[...,0] = x
68 |         x = Z.fft(z0, "C2C", inverse=self.isInv)
69 |         return x
70 | 
71 | def trans(img):
72 |     plt.figure(figsize=(10,8))
73 |     plt.subplot(121), plt.imshow(img, cmap='gray')
74 |     plt.title('Input Image'), plt.xticks([]), plt.yticks([])
75 |     f = (abs(np.fft.fftshift(fftn(img))))**0.25*(255)**3    # Amplify
76 |     plt.subplot(122), plt.imshow(f, cmap='gray')
77 |     plt.title('Spectrum'), plt.xticks([]), plt.yticks([])
78 |     plt.show()
-------------------------------------------------------------------------------- /python-package/onnet/Loss.py: --------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn.functional as F
 3 | 
 4 | class UserLoss(object):
 5 | 
 6 |     @staticmethod
 7 |     def cys_loss(output, target, reduction='mean'):
 8 |         #loss = F.binary_cross_entropy(output, target, reduction=reduction)
 9 |         loss = F.cross_entropy(output, target, reduction=reduction)
10 |         #loss = F.nll_loss(output, target, reduction=reduction)
11 | 
12 |         return loss
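A minimal usage sketch for the FFT_Layer defined above (not part of the repository): it shows how a real batch is promoted to the trailing-(2) complex layout inside forward(). The config values are illustrative (the constructor only requires config to be non-None), and the snippet assumes the pre-1.8 torch.fft API that Z_utils wraps.

import torch
from onnet.NET_config import NET_config
from onnet.FFT_layer import FFT_Layer

config = NET_config("DNet", "mnist", (28, 28), lr_base=0.01, batch_size=64, nClass=10)  # illustrative config
layer = FFT_Layer(28, 28, config)     # forward() promotes a real tensor to shape (..., 2)
x = torch.randn(8, 1, 28, 28)         # B x C x M x N real batch
spec = layer(x)                       # complex spectrum, shape (8, 1, 28, 28, 2)
print(layer, spec.shape)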
-------------------------------------------------------------------------------- /python-package/onnet/NET_config.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | ''' 5 | parser.add_argument is better than NET_config 6 | ''' 7 | class NET_config: 8 | def __init__(self,net_type, data_set, IMG_size, lr_base, batch_size,nClass,nLayer=-1): 9 | #seed_everything(self.seed) 10 | self.net_type = net_type 11 | self.data_set = data_set 12 | self.IMG_size = IMG_size 13 | self.lr_base = lr_base # "random" "zero" 14 | self.batch_size = batch_size 15 | self.nClass = nClass 16 | self.nLayer = nLayer 17 | -------------------------------------------------------------------------------- /python-package/onnet/Net_Instance.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Author: Yingshi Chen 3 | 4 | @Date: 2020-01-16 15:08:16 5 | @ 6 | # Description: 7 | ''' 8 | from .D2NNet import * 9 | from .RGBO_CNN import * 10 | from .OpticalFormer import * 11 | import math 12 | from copy import copy, deepcopy 13 | 14 | def dump_model_params(model): 15 | nzParams = 0 16 | for name, param in model.named_parameters(): 17 | if param.requires_grad: 18 | nzParams += param.nelement() 19 | print(f"\t{name}={param.nelement()}") 20 | print(f"========All parameters={nzParams}") 21 | return nzParams 22 | 23 | def Net_dump(net): 24 | nzParams=dump_model_params(net) 25 | 26 | #def DNet_instance(net_type,dataset,IMG_size,lr_base,batch_size,nClass,nLayer): 需要重写,只有一个config 27 | def DNet_instance(config): 28 | net_type, dataset, IMG_size, lr_base, batch_size, nClass, nLayer = \ 29 | config.net_type,config.data_set, config.IMG_size, config.lr_base, config.batch_size, config.nClass, config.nLayer 30 | if net_type == "BiDNet": 31 | lr_base = 0.01 32 | if dataset == "emnist": 33 | lr_base = 0.01 34 | 35 | config_base = DNET_config(batch=batch_size, lr_base=lr_base) 36 | if hasattr(config,'feat_extractor'): 37 | config_base.feat_extractor = config.feat_extractor 38 | env_title = f"{net_type}_{dataset}_{IMG_size}_{lr_base}_{config_base.env_title()}" 39 | if net_type == "MF_DNet": 40 | freq_list = [0.3e12, 0.35e12, 0.4e12, 0.42e12] 41 | env_title = env_title + f"_C{len(freq_list)}" 42 | if net_type == "BiDNet": 43 | config_base = DNET_config(batch=batch_size, lr_base=lr_base, chunk="binary") 44 | 45 | if net_type == "cnn": 46 | model = Mnist_Net(config=config_base) 47 | return env_title, model 48 | 49 | if net_type == "DNet": 50 | model = D2NNet(IMG_size, nClass, nLayer, config_base) 51 | elif net_type == "WNet": 52 | config_base.wavelet={"nWave":3} 53 | model = D2NNet(IMG_size, nClass, nLayer, config_base) 54 | elif net_type == "MF_DNet": 55 | # model = MultiDNet(IMG_size, nClass, nLayer,[0.3e12,0.35e12,0.4e12,0.42e12,0.5e12,0.6e12], DNET_config()) 56 | model = MultiDNet(IMG_size, nClass, nLayer, [0.3e12, 0.35e12, 0.4e12, 0.42e12], config_base) 57 | elif net_type == "MF_WNet": 58 | config_base.wavelet = {"nWave": 3} 59 | model = MultiDNet(IMG_size, nClass, nLayer, [0.3e12, 0.35e12, 0.4e12, 0.42e12], config_base) 60 | elif net_type == "BiDNet": 61 | model = D2NNet(IMG_size, nClass, nLayer, config_base) 62 | elif net_type == "OptFormer": 63 | pass 64 | 65 | #model.double() 66 | 67 | return env_title, model 68 | 69 | def RGBO_CNN_instance(config): 70 | assert config.net_type == "RGBO_CNN" 71 | env_title = f"{config.net_type}_{config.dnet_type}_{config.data_set}_{config.IMG_size}_{config.lr_base}_" 72 | assert 
hasattr(config,'dnet_type') 73 | 74 | if config.dnet_type!="": 75 | d_conf = deepcopy(config) 76 | if config.dnet_type == "stack_input": 77 | d_conf.net_type = "DNet" 78 | #d_conf.nLayer = 1 79 | #d_conf.feat_extractor = "layers" 80 | else: 81 | d_conf.nLayer = 10 82 | d_conf.net_type = "WNet" 83 | _,DNet = DNet_instance(d_conf) 84 | else: 85 | DNet=None 86 | model = RGBO_CNN(config,DNet) 87 | 88 | return env_title, model 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /python-package/onnet/OpticalFormer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange, repeat 4 | from torch import nn 5 | from .OpticalFormer_util import * 6 | # import lite_bert 7 | MIN_NUM_PATCHES = 16 8 | 9 | class Residual(nn.Module): 10 | def __init__(self, fn): 11 | super().__init__() 12 | self.fn = fn 13 | def forward(self, x, **kwargs): 14 | return self.fn(x, **kwargs) + x 15 | 16 | class PreNorm(nn.Module): 17 | def __init__(self, dim, fn): 18 | super().__init__() 19 | self.norm = nn.LayerNorm(dim) 20 | self.fn = fn 21 | def forward(self, x, **kwargs): 22 | return self.fn(self.norm(x), **kwargs) 23 | 24 | class FeedForward(nn.Module): 25 | def __init__(self, dim, hidden_dim, dropout = 0.): 26 | super().__init__() 27 | self.net = nn.Sequential( 28 | nn.Linear(dim, hidden_dim), 29 | nn.GELU(), 30 | nn.Dropout(dropout), 31 | nn.Linear(hidden_dim, dim), 32 | nn.Dropout(dropout) 33 | ) 34 | def forward(self, x): 35 | return self.net(x) 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 39 | super().__init__() 40 | inner_dim = dim_head * heads 41 | self.heads = heads 42 | self.scale = dim ** -0.5 43 | 44 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 45 | self.to_out = nn.Sequential( 46 | nn.Linear(inner_dim, dim), 47 | nn.Dropout(dropout) 48 | ) 49 | 50 | def forward(self, x, mask = None): 51 | b, n, _, h = *x.shape, self.heads 52 | qkv = self.to_qkv(x).chunk(3, dim = -1) 53 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 54 | 55 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 56 | mask_value = -torch.finfo(dots.dtype).max 57 | 58 | if mask is not None: 59 | mask = F.pad(mask.flatten(1), (1, 0), value = True) 60 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 61 | mask = mask[:, None, :] * mask[:, :, None] 62 | dots.masked_fill_(~mask, mask_value) 63 | del mask 64 | 65 | attn = dots.softmax(dim=-1) 66 | 67 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 68 | out = rearrange(out, 'b h n d -> b n (h d)') 69 | out = self.to_out(out) 70 | return out 71 | 72 | def unitwise_norm(x,axis=None): 73 | """Compute norms of each output unit separately, also for linear layers.""" 74 | if len(torch.squeeze(x).shape) <= 1: # Scalars and vectors 75 | axis = None 76 | keepdims = False 77 | return torch.norm(x) 78 | elif len(x.shape) in [2, 3]: # Linear layers of shape IO or multihead linear 79 | # axis = 0 80 | # axis = 1 81 | keepdims = True 82 | elif len(x.shape) == 4: # Conv kernels of shape HWIO 83 | if axis is None: 84 | axis = [0, 1, 2,] 85 | keepdims = True 86 | else: 87 | raise ValueError(f'Got a parameter with shape not in [1, 2, 4]! 
{x}')
 88 |     return torch.sum(x ** 2, dim=axis, keepdim=keepdims) ** 0.5
 89 | 
 90 | def clip_grad_rc(grad, W, row_major=False, eps=1.e-3, clip=0.02):
 91 |     # adaptive_grad_clip
 92 |     if len(grad.shape)==2:
 93 |         nR,nC = grad.shape
 94 |         axis = 1 if row_major else 0
 95 |         g_norm = unitwise_norm(grad,axis=axis)
 96 |         W_norm = unitwise_norm(W,axis=axis)
 97 |         assert(g_norm.shape==W_norm.shape)
 98 |         W_norm[W_norm<eps] = eps
 99 |         # clamp tiny weight norms, then rescale units whose gradient norm exceeds clip * W_norm
100 |         s = torch.clamp(W_norm*clip/g_norm, max=1.0)
101 |         s = torch.squeeze(s)
102 |         if axis==1:
103 |             grad = torch.einsum('rc,r->rc', grad, s)
104 |         else:
105 |             grad = torch.einsum('rc,c->rc', grad, s)
106 |     return grad
107 | 
108 | def clip_grad(model, eps=1.e-3, clip=0.002, method="agc"):
109 |     known_modules = {'Linear'}
110 |     for module in model.modules():
111 |         classname = module.__class__.__name__
112 |         if classname not in known_modules:
113 |             continue
114 |         if classname == 'Conv2d':
115 |             assert(False)
116 |             grad = None
117 |         elif classname == 'BertLayerNorm':
118 |             grad = None
119 |         else:
120 |             grad = module.weight.grad.data
121 |             W = module.weight.data
122 | 
123 |         # adaptive_grad_clip
124 |         assert len(grad.shape)==2
125 |         nR,nC = grad.shape
126 |         axis = 1 if nR>nC else 0
127 |         g_norm = unitwise_norm(grad,axis=axis)
128 |         W_norm = unitwise_norm(W,axis=axis)
129 |         W_norm[W_norm<eps] = eps
130 |         # rescale rows/columns whose unit-wise gradient norm exceeds clip * W_norm
131 |         s = torch.clamp(W_norm*clip/g_norm, max=1.0)
132 |         s = torch.squeeze(s)
133 |         if axis==1:
134 |             grad = torch.einsum('rc,r->rc', grad, s)
135 |         else:
136 |             grad = torch.einsum('rc,c->rc', grad, s)
137 |         module.weight.grad.data.copy_(grad)
138 | 
139 |         if module.bias is not None:
140 |             v = module.bias.grad.data
141 |             axis = 0
142 |             b_grad = clip_grad_rc(v, module.bias.data, row_major=axis==1, eps=eps, clip=clip)
143 |             module.bias.grad.data.copy_(b_grad)
144 | 
145 | 
146 | class Transformer(nn.Module):
147 |     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, clip_grad=""):
148 |         super().__init__()
149 |         self.layers = nn.ModuleList([])
150 |         self.isV0 = False
151 |         for _ in range(depth):
152 |             if self.isV0:
153 |                 self.layers.append(nn.ModuleList([
154 |                     Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
155 |                     Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
156 |                 ]))
157 |             else:
158 |                 # self.layers.append(lite_bert.BTransformer(dim, heads, dim * 4, dropout))
159 |                 self.layers.append(BTransformer(dim, heads, dim * 4, dropout, clip_grad=clip_grad))
160 |     def forward(self, x, mask = None):
161 |         if self.isV0:
162 |             for attn, ff in self.layers:
163 |                 x = attn(x, mask = mask)
164 |                 x = ff(x)
165 |         else:
166 |             for BTrans in self.layers:
167 |                 x = BTrans(x, mask)
168 |         return x
169 | 
170 | class OpticalFormer(nn.Module):
171 |     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ff_hidden, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., clip_grad=""):
172 |         super().__init__()
173 |         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
174 |         num_patches = (image_size // patch_size) ** 2    #64
175 |         patch_dim = channels * patch_size ** 2    #48 pixels in each patch
176 |         assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). 
Try decreasing your patch size' 177 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 178 | 179 | self.patch_size = patch_size 180 | self.clip_grad = clip_grad 181 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches , dim)) 182 | self.patch_to_embedding = nn.Linear(patch_dim, dim) 183 | # self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 184 | # self.dropout = nn.Dropout(emb_dropout) 185 | 186 | self.transformer = Transformer(dim, depth, heads, dim_head, ff_hidden, dropout,clip_grad=self.clip_grad) 187 | 188 | self.pool = pool 189 | self.to_latent = nn.Identity() 190 | 191 | self.mlp_head = nn.Sequential( 192 | nn.Identity() if self.clip_grad=="agc" else nn.LayerNorm(dim), 193 | nn.Linear(dim, num_classes) 194 | ) 195 | 196 | def name_(self): 197 | return "ViT_" 198 | 199 | def forward(self, img, mask = None): 200 | p = self.patch_size 201 | 202 | x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) 203 | # x = rearrange(img, 'b c (h p1) (w p2) -> b (h w c) (p1 p2)', p1 = p, p2 = p) 204 | x = self.patch_to_embedding(x) 205 | b, n, _ = x.shape 206 | 207 | # cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 208 | # x = torch.cat((cls_tokens, x), dim=1) 209 | x += self.pos_embedding[:, :(n )] 210 | # x = self.dropout(x) 211 | 212 | x = self.transformer(x, mask) 213 | 214 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 215 | 216 | x = self.to_latent(x) 217 | return self.mlp_head(x) 218 | 219 | def predict(self,output): 220 | if self.config.support == "binary": 221 | nGate = output.shape[1] // 2 222 | #assert nGate == self.n 223 | pred = 0 224 | for i in range(nGate): 225 | no = 2*(nGate - 1 - i) 226 | val_2 = torch.stack([output[:, no], output[:, no + 1]], 1) 227 | pred_i = val_2.max(1, keepdim=True)[1] # get the index of the max log-probability 228 | pred = pred * 2 + pred_i 229 | elif self.config.support == "logit": 230 | nGate = output.shape[1] 231 | # assert nGate == self.n 232 | pred = 0 233 | for i in range(nGate): 234 | no = nGate - 1 - i 235 | val_2 = F.sigmoid(output[:, no]) 236 | pred_i = (val_2+0.5).long() 237 | pred = pred * 2 + pred_i 238 | else: 239 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 240 | #pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 241 | return pred 242 | 243 | -------------------------------------------------------------------------------- /python-package/onnet/OpticalFormer_util.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | import torch.nn.functional as F 5 | # from .sparse_max import sparsemax, entmax15 6 | 7 | class LayerNorm(nn.Module): 8 | "Construct a layernorm module (See citation for details)." 
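    # forward() below normalizes each feature vector over the last dimension:
    #     y = a_2 * (x - mean(x)) / (std(x) + eps) + b_2
    # a_2 / b_2 are the learned gain and bias; note eps is added to the std here,
    # whereas nn.LayerNorm adds it to the variance under the square root.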
9 | 10 | def __init__(self, features, eps=1e-6): 11 | super(LayerNorm, self).__init__() 12 | self.a_2 = nn.Parameter(torch.ones(features)) 13 | self.b_2 = nn.Parameter(torch.zeros(features)) 14 | self.eps = eps 15 | 16 | def forward(self, x): 17 | mean = x.mean(-1, keepdim=True) 18 | std = x.std(-1, keepdim=True) 19 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 20 | 21 | class GELU(nn.Module): 22 | """ 23 | Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU 24 | """ 25 | def forward(self, x): 26 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 27 | 28 | class QK_Attention(nn.Module): 29 | def forward(self, query, key, value, mask=None, dropout=None): 30 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1)) 31 | #mini batch多句话得长度并不一致,需要按照最大得长度对短句子进行补全,也就是padding零,mask起来,填充一个负无穷(-1e9这样得数值),这样计算就可以为0了,等于把计算遮挡住。 32 | if mask is not None: 33 | scores = scores.masked_fill(mask == 0, -1e9) 34 | 35 | p_attn = F.softmax(scores, dim=-1) 36 | # p_attn = entmax15(scores, dim=-1) 37 | 38 | if dropout is not None: 39 | p_attn = dropout(p_attn) 40 | 41 | return torch.matmul(p_attn, value), p_attn 42 | 43 | class MultiHeadedAttention(nn.Module): 44 | def __init__(self, h, d_model, dropout=0.1): 45 | super().__init__() 46 | assert d_model % h == 0 47 | 48 | # We assume d_v always equals d_k 49 | self.d_k = d_model // h 50 | self.h = h 51 | 52 | self.linear_project = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)]) 53 | self.output_linear = nn.Linear(d_model, d_model) 54 | self.attention = QK_Attention() 55 | self.dropout = nn.Dropout(p=dropout) if dropout>0 else None 56 | 57 | def forward(self, x, mask=None): 58 | batch_size = x.size(0) 59 | if self.attention is None: 60 | x = self.dropout(x) #Very interesting, why self-attention is so useful? 61 | else: 62 | if self.h == 1: 63 | # query, key, value = [l(x) for l, x in zip(self.linear_project, (x, x, x))] 64 | query, key, value = x,x,x 65 | else: 66 | # 1) Do all the linear projections in batch from d_model => h x d_k 67 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) 68 | for l, x in zip(self.linear_project, (x, x, x))] 69 | # query, key, value = (x,x,x) 70 | 71 | # 2) Apply attention on all the projected vectors in batch. 72 | x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout) 73 | 74 | # 3) "Concat" using a view and apply a final linear. 75 | if self.h > 1: 76 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k) 77 | 78 | return self.output_linear(x) 79 | 80 | class Residual(nn.Module): 81 | def __init__(self, fn): 82 | super().__init__() 83 | self.fn = fn 84 | def forward(self, x, **kwargs): 85 | return self.fn(x, **kwargs) + x 86 | 87 | #keep structure simple ,no norm,no dropout!!! 88 | class PreNorm(nn.Module): 89 | def __init__(self, dim, fn): 90 | super().__init__() 91 | self.norm = nn.LayerNorm(dim) #why this is so good 92 | # self.norm = nn.BatchNorm1d(64) #nearly same as layernorm 93 | # self.norm = nn.Identity() 94 | # self.norm = nn.BatchNorm1d(dim) 95 | self.fn = fn 96 | 97 | def forward(self, x, **kwargs): 98 | if self.fn is None: 99 | x = self.norm(x) 100 | else: 101 | x = self.fn(self.norm(x), **kwargs) 102 | return x 103 | 104 | class PositionwiseFeedForward(nn.Module): 105 | "Implements FFN equation." 
106 | 107 | def __init__(self, d_model, d_ff, dropout=0.1): 108 | super(PositionwiseFeedForward, self).__init__() 109 | self.w_1 = nn.Linear(d_model, d_ff) 110 | self.w_2 = nn.Linear(d_ff, d_model) 111 | self.dropout = nn.Dropout(dropout) if dropout > 0 else None 112 | self.activation = GELU() 113 | # self.activation = nn.ReLU() # maybe use ReLU 114 | 115 | def forward(self, x): 116 | if self.dropout is None: 117 | return self.w_2(self.activation(self.w_1(x))) 118 | else: 119 | return self.w_2(self.dropout(self.activation(self.w_1(x)))) 120 | 121 | class BTransformer(nn.Module): 122 | """ 123 | Bidirectional Encoder = Transformer (self-attention) 124 | Transformer = MultiHead_Attention + Feed_Forward with sublayer connection 125 | """ 126 | 127 | def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout,clip_grad=""): 128 | """ 129 | :param hidden: hidden size of transformer 130 | :param attn_heads: head sizes of multi-head attention 131 | :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size 132 | :param dropout: dropout rate 133 | """ 134 | 135 | super().__init__() 136 | print(f"attn_heads={attn_heads}") 137 | self.clip_grad = clip_grad 138 | # self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden) 139 | # self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout) 140 | # self.attn = SublayerConnection(size=hidden, dropout=dropout) 141 | # self.ff = SublayerConnection(size=hidden, dropout=dropout) 142 | if self.clip_grad == "agc": 143 | self.attn = Residual( MultiHeadedAttention(h = attn_heads, d_model=hidden, dropout=dropout) ) 144 | self.ff = Residual( PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout) ) 145 | else: 146 | self.attn = Residual(PreNorm(hidden, MultiHeadedAttention(h = attn_heads, d_model=hidden, dropout=dropout))) 147 | self.ff = Residual(PreNorm(hidden, PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout))) 148 | 149 | self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None 150 | 151 | def forward(self, x, mask): 152 | # x = self.attn(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask)) 153 | # x = self.ff(x, self.feed_forward) 154 | x = self.attn(x, mask=mask) 155 | x = self.ff(x) 156 | if self.dropout is not None: 157 | return self.dropout(x) 158 | else: 159 | return x 160 | 161 | 162 | class AttentionQKV(nn.Module): 163 | def __init__(self, hidden, attn_heads, dropout): 164 | super(AttentionQKV, self).__init__() 165 | self.attn = Residual(PreNorm(hidden, MultiHeadedAttention(h = attn_heads, d_model=hidden, dropout=dropout))) 166 | 167 | def forward(self, x, mask=None): 168 | shape = list(x.shape) 169 | if len(shape)==2: 170 | x = x.unsqueeze(1) 171 | x = self.attn(x, mask=mask) 172 | if len(shape)==2: 173 | x = x.squeeze(1) 174 | return x 175 | -------------------------------------------------------------------------------- /python-package/onnet/PoolForCls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | from .some_utils import * 5 | 6 | class ChunkPool(torch.nn.Module): 7 | def __init__(self, nCls,config,pooling="max",chunk_dim=-1): 8 | super(ChunkPool, self).__init__() 9 | self.nClass = nCls 10 | self.pooling = pooling 11 | self.chunk_dim=chunk_dim 12 | self.config = config 13 | #self.regions = split_regions_2d(x.shape,self.nClass) 14 | 15 | def __repr__(self): 16 | main_str = super(ChunkPool, self).__repr__() 17 
| main_str += f"_cls[{self.nClass}]_pool[{self.pooling}]" 18 | return main_str 19 | 20 | def forward(self, x): 21 | nSamp = x.shape[0] 22 | if False: 23 | x1 = torch.zeros((nSamp, self.nClass)).double() 24 | step = self.M // self.nClass 25 | for samp in range(nSamp): 26 | for i in range(self.nClass): 27 | x1[samp,i] = torch.max(x[samp,:,:,i*step:(i+1)*step]) 28 | x_np = x1.detach().cpu().numpy() 29 | x = x1.cuda() 30 | else: 31 | x_max=[] 32 | if self.config.output_chunk=="1D": 33 | sections=split__sections(x.shape[self.chunk_dim],self.nClass) 34 | for xx in x.split(sections, self.chunk_dim): 35 | x2 = xx.contiguous().view(nSamp, -1) 36 | if self.pooling == "max": 37 | x3 = torch.max(x2, 1) 38 | x_max.append(x3.values) 39 | else: 40 | x3 = torch.mean(x2, 1) 41 | x_max.append(x3) 42 | else: #2D 43 | regions = split_regions_2d(x.shape,self.nClass) 44 | for box in regions: 45 | x2 = x[...,box[0]:box[1],box[2]:box[3]] 46 | x2 = x2.contiguous().view(nSamp, -1) 47 | if self.pooling == "max": 48 | x3 = torch.max(x2, 1) 49 | x_max.append(x3.values) 50 | else: 51 | x3 = torch.mean(x2, 1) 52 | x_max.append(x3) 53 | assert len(x_max)==self.nClass 54 | x = torch.stack(x_max,1) 55 | #x_np = x.detach().cpu().numpy() 56 | #print(x_np) 57 | return x 58 | 59 | class BinaryChunk(torch.nn.Module): 60 | def __init__(self, nCls,isLogit=False,pooling="max",chunk_dim=-1): 61 | super(BinaryChunk, self).__init__() 62 | self.nClass = nCls 63 | self.nChunk = (int)(math.ceil(math.log2(self.nClass))) 64 | self.pooling = pooling 65 | self.isLogit = isLogit 66 | 67 | def __repr__(self): 68 | main_str = super(BinaryChunk, self).__repr__() 69 | if self.isLogit: 70 | main_str += "_logit" 71 | main_str += f"_nChunk{self.nChunk}_cls[{self.nClass}]_pool[{self.pooling}]" 72 | return main_str 73 | 74 | def chunk_poll(self,ck,nSamp): 75 | x2 = ck.contiguous().view(nSamp, -1) 76 | if self.pooling == "max": 77 | x3 = torch.max(x2, 1) 78 | return x3.values 79 | else: 80 | x3 = torch.mean(x2, 1) 81 | return x3 82 | 83 | def forward(self, x): 84 | nSamp = x.shape[0] 85 | x_max=[] 86 | for ck in x.chunk(self.nChunk, -1): 87 | if self.isLogit: 88 | x_max.append(self.chunk_poll(ck,nSamp)) 89 | else: 90 | for xx in ck.chunk(2, -2): 91 | x2 = xx.contiguous().view(nSamp, -1) 92 | if self.pooling == "max": 93 | x3 = torch.max(x2, 1) 94 | x_max.append(x3.values) 95 | else: 96 | x3 = torch.mean(x2, 1) 97 | x_max.append(x3) 98 | x = torch.stack(x_max,1) 99 | 100 | return x -------------------------------------------------------------------------------- /python-package/onnet/RGBO_CNN.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch.nn as nn 3 | import os 4 | #from torchvision import models 5 | sys.path.append("../..") 6 | from cnn_models import * 7 | from torchvision import transforms 8 | from torchvision.transforms.functional import to_grayscale 9 | from torch.autograd import Variable 10 | # from resnet import resnet50 11 | from copy import deepcopy 12 | import numpy as np 13 | import pickle 14 | from .NET_config import * 15 | from .D2NNet import * 16 | 17 | class RGBO_CNN_config(NET_config): 18 | def __init__(self, net_type, data_set, IMG_size, lr_base, batch_size, nClass, nLayer): 19 | super(RGBO_CNN_config, self).__init__(net_type, data_set, IMG_size, lr_base, batch_size,nClass,nLayer) 20 | #self.dnet_type = "" 21 | self.dnet_type = "stack_input" 22 | self.dnet_type = "stack_feature" 23 | 24 | def image_transformer(): 25 | """ 26 | :return: A transformer to convert a PIL 
image to a tensor image 27 | ready to feed into a neural network 28 | """ 29 | return { 30 | 'train': transforms.Compose([ 31 | transforms.RandomCrop(224), 32 | transforms.RandomHorizontalFlip(), 33 | transforms.ToTensor(), 34 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 35 | ]), 36 | 'val': transforms.Compose([ 37 | transforms.Resize(256), 38 | transforms.CenterCrop(224), 39 | transforms.ToTensor(), 40 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 41 | ]), 42 | } 43 | 44 | ''' 45 | 1 参见cifar_rgbF.jpg,简单的fourier channel没啥效果 46 | ''' 47 | class D_input(nn.Module): 48 | def __init__(self, config, DNet): 49 | super(D_input, self).__init__() 50 | self.config = config 51 | self.DNet = DNet 52 | self.inplanes = 64 53 | self.nLayD = DNet.nDifrac#self.DNet.config.nLayer 54 | #self.nLayD = 1 55 | #self.c_input =nn.Conv2d(3+self.nLayD, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 56 | self.c_input = nn.Conv2d(3+self.nLayD, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 57 | 58 | def forward(self, x): 59 | nChan = x.shape[1] 60 | assert nChan==3 or nChan==1 61 | if nChan==3: 62 | gray = x[:, 0:1]*0.3 + 0.59 * x[:, 1:2] + 0.11 * x[:, 2:3] # to_grayscale(x) 63 | else: 64 | gray = x 65 | return self.DNet.forward(gray) 66 | listT = [] 67 | for i in range(nChan): 68 | listT.append(x[:, i:i+1]) 69 | if self.nLayD>=1: 70 | self.DNet.forward(gray) 71 | assert len(self.DNet.feat_extractor) == self.nLayD 72 | for opti, w in self.DNet.feat_extractor: 73 | listT.append(opti) #*w 74 | elif self.nLayD==0:# 75 | pass 76 | else: 77 | listT.append(gray) 78 | 79 | x = torch.stack(listT,dim=1).squeeze() 80 | if hasattr(self, 'visual'): self.visual.onX(x, f"D_input") 81 | x = self.c_input(x) 82 | return x 83 | 84 | def forward_000(self, x): 85 | if False: 86 | gray = x[:, 0:1] # to_grayscale(x) 87 | self.DNet.forward(gray) 88 | # in_opti = self.DNet.concat_layer_modulus() # self.get_resnet_convs_out(x) 89 | for opti, w in self.DNet.feat_extractor: 90 | opti = torch.stack([opti, opti, opti], 1).squeeze() # opti.repeat(3, 1) 91 | out_opti = self.resNet.forward(opti) 92 | out_sum = out_sum + out_opti * w 93 | pass 94 | 95 | class RGBO_CNN(torch.nn.Module): 96 | ''' 97 | resnet https://missinglink.ai/guides/pytorch/pytorch-resnet-building-training-scaling-residual-networks-pytorch/ 98 | ''' 99 | def pick_models(self): 100 | if False: #from torch vision or cadene models 101 | model_names = sorted(name for name in cnn_models.__dict__ 102 | if name.islower() and not name.startswith("__") 103 | and callable(models.__dict__[name])) 104 | print(model_names) 105 | # pretrainedmodels https://data.lip6.fr/cadene/pretrainedmodels/ 106 | model_names = ['alexnet', 'bninception', 'cafferesnet101', 'densenet121', 'densenet161', 'densenet169', 107 | 'densenet201', 108 | 'dpn107', 'dpn131', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'fbresnet152', 109 | 'inceptionresnetv2', 'inceptionv3', 'inceptionv4', 'nasnetalarge', 'nasnetamobile', 110 | 'pnasnet5large', 111 | 'polynet', 112 | 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x4d', 113 | 'resnext101_64x4d', 114 | 'se_resnet101', 'se_resnet152', 'se_resnet50', 'se_resnext101_32x4d', 'se_resnext50_32x4d', 115 | 'senet154', 'squeezenet1_0', 'squeezenet1_1', 116 | 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'xception'] 117 | 118 | # model_name='cafferesnet101' 119 | # model_name='resnet101' 120 | # model_name='se_resnet50' 121 | # 
model_name='vgg16_bn' 122 | # model_name='vgg11_bn' 123 | # model_name='dpn68' #learning rate=0.0001 效果较好 124 | self.back_bone = 'resnet18_x' 125 | # model_name='dpn92' 126 | # model_name='senet154' 127 | # model_name='densenet121' 128 | # model_name='alexnet' 129 | # model_name='senet154' 130 | cnn_model = ResNet34() ;#models.resnet18(pretrained=True) 131 | return cnn_model 132 | 133 | def __init__(self, config,DNet): 134 | super(RGBO_CNN, self).__init__() 135 | seed_everything(42) 136 | self.config = config 137 | backbone = self.pick_models() 138 | if self.config.dnet_type == "stack_feature": 139 | self.DInput = D_input(config,DNet) 140 | elif self.config.dnet_type == "stack_input": #False and hasattr(self,'DInput'): 141 | self.CNet = nn.Sequential(*list(backbone.children())[1:]) 142 | else: 143 | self.CNet = nn.Sequential(*list(backbone.children())) 144 | 145 | #print(f"=> creating model CNet='{self.CNet}'\nDNet={self.DNet}") 146 | if False: #外层处理 147 | if config.gpu_device is not None: 148 | self.cuda(config.gpu_device) 149 | print(next(self.parameters()).device) 150 | self.thickness_criterion = self.thickness_criterion.cuda() 151 | self.metal_criterion = self.metal_criterion.cuda() 152 | elif config.distributed: 153 | self.cuda() 154 | self = torch.nn.parallel.DistributedDataParallel(self) 155 | else: 156 | self = torch.nn.DataParallel(self).cuda() 157 | 158 | def save_acti(self,x,name): 159 | acti = x.cpu().data.numpy() 160 | self.activations.append({'name':name,'shape':acti.shape,'activation':acti}) 161 | 162 | #https://forums.fast.ai/t/pytorch-best-way-to-get-at-intermediate-layers-in-vgg-and-resnet/5707/6 163 | 164 | 165 | def forward_0(self, x): 166 | if hasattr(self, 'DInput'): 167 | x = self.DInput(x) 168 | for no,lay in enumerate(self.CNet): 169 | if isinstance(lay,nn.Linear): #x = self.avgpool(x), x = x.reshape(x.size(0), -1) 170 | x = F.avg_pool2d(x, 4) 171 | x = x.reshape(x.size(0), -1) 172 | x = lay(x) 173 | #print(f"{no}:\t{lay}\nx={x}") 174 | if isinstance(lay,nn.AdaptiveAvgPool2d): #x = self.avgpool(x), x = x.reshape(x.size(0), -1) 175 | x = x.reshape(x.size(0), -1) 176 | out_sum = x 177 | return out_sum 178 | 179 | def forward(self, x): 180 | out_sum = 0 181 | if self.config.dnet_type == "stack_feature": 182 | out_sum= self.DInput(x) 183 | for no,lay in enumerate(self.CNet): 184 | if isinstance(lay,nn.Linear): #x = self.avgpool(x), x = x.reshape(x.size(0), -1) 185 | x = F.avg_pool2d(x, 4) 186 | x = x.reshape(x.size(0), -1) 187 | x = lay(x) 188 | #print(f"{no}:\t{lay}\nx={x}") 189 | if isinstance(lay,nn.AdaptiveAvgPool2d): #x = self.avgpool(x), x = x.reshape(x.size(0), -1) 190 | x = x.reshape(x.size(0), -1) 191 | out_sum += x 192 | return out_sum 193 | 194 | if __name__ == "__main__": 195 | config = DNET_config(None) 196 | a = RGBO_CNN(config,nFilmLayer=10) 197 | print(f"RGBO_CNN={a}") 198 | pass 199 | 200 | -------------------------------------------------------------------------------- /python-package/onnet/SparseSupport.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .D2NNet import * 3 | from .some_utils import * 4 | import numpy as np 5 | import random 6 | import torch.nn as nn 7 | from enum import Enum 8 | 9 | #https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-custom-nn-modules 10 | class SuppLayer(torch.nn.Module): 11 | class SUPP(Enum): 12 | exp,sparse,expW,diff = 'exp','sparse','expW','differentia' 13 | 14 | def __init__(self,config,nClass, nSupp=10): 15 | super(SuppLayer, 
self).__init__() 16 | self.nClass = nClass 17 | self.nSupp = nSupp 18 | self.nChunk = self.nClass*2 19 | self.config = config 20 | self.w_11=False 21 | if self.config.support==self.SUPP.sparse: #"supp_sparse": 22 | if self.w_11: 23 | tSupp = torch.ones(self.nClass, self.nSupp) 24 | else: 25 | tSupp = torch.Tensor(self.nClass, self.nSupp).uniform_(-1,1) 26 | self.wSupp = torch.nn.Parameter(tSupp) 27 | self.nChunk = self.nSupp*self.nSupp 28 | self.chunk_map = np.random.randint(self.nChunk, size=(self.nClass, self.nSupp)) 29 | #elif self.config.support=="supp_expW": 30 | # self.nSupp = 2 31 | # self.wSupp = torch.nn.Parameter(torch.ones(2)) 32 | 33 | def __repr__(self): 34 | w_init="1" if self.w_11 else "random" 35 | main_str = f"SupportLayer supp=({self.nSupp},{w_init}) type=\"{self.config.support}\" nChunk={self.nChunk}" 36 | return main_str 37 | 38 | def sparse_support(self,x): 39 | feats=[] 40 | for i in range(self.nClass): 41 | feat = 0; 42 | for j in range(self.nSupp): 43 | col = (int)(self.chunk_map[i,j]) 44 | feat += x[:, col]*self.wSupp[i,j] 45 | feats.append(torch.exp(feat)) #why exp is useful??? 46 | #feats.append(feat) 47 | output = torch.stack(feats,1) 48 | return output 49 | 50 | def forward(self, x): 51 | if self.config.support == self.SUPP.sparse: # "supp_sparse": 52 | output = self.sparse_support(x) 53 | return output 54 | 55 | assert x.shape[1] == self.nClass * 2 56 | if self.config.support==self.SUPP.diff: #"supp_differentia": 57 | for i in range(self.nClass): 58 | x[:,i] = (x[:,2*i]-x[:,2*i+1])/(x[:,2*i]+x[:,2*i+1]) 59 | output=x[...,0:self.nClass] 60 | elif self.config.support==self.SUPP.exp: #"supp_exp": 61 | for i in range(self.nClass): 62 | x[:, i] = torch.exp(x[:, 2 * i] - x[:, 2 * i + 1]) 63 | output = x[..., 0:self.nClass] 64 | elif self.config.support==self.SUPP.expW: #"supp_expW": 65 | output = torch.zeros_like(x) 66 | for i in range(self.nClass): 67 | output[:, i] = torch.exp(x[:, 2 * i]*self.w2[0] - x[:, 2 * i + 1]*self.w2[1]) 68 | output = output[..., 0:self.nClass] 69 | 70 | return output 71 | 72 | -------------------------------------------------------------------------------- /python-package/onnet/ToExcel.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Author: Yingshi Chen 3 | 4 | @Date: 2020-01-14 15:36:32 5 | @ 6 | # Description: 7 | ''' 8 | import numpy as np 9 | import pandas as pd 10 | import json 11 | import glob 12 | import argparse 13 | from scipy.signal import savgol_filter 14 | 15 | def OnVisdom_json(param,title,smooth=False): 16 | search_str = f"{param['data_root']}{param['select']}" 17 | files = glob.glob(search_str) 18 | datas = [] 19 | cols = [] 20 | for i, file in enumerate(files): 21 | with open(file, 'r') as f: 22 | meta = json.load(f) 23 | curve = meta['jsons']['loss']['content']['data'][0] 24 | legend = meta['jsons']['loss']['legend'] 25 | cols.append(legend[0]) 26 | item = curve['y'] 27 | datas.append(item) 28 | if smooth: 29 | win = max(9,len(item)//10) 30 | cols.append(f"{legend[0]}_smooth") 31 | item_s = savgol_filter(item, win, 3) 32 | datas.append(item_s) 33 | pass 34 | 35 | df = pd.DataFrame(datas) 36 | df = df.transpose() 37 | for i,col in enumerate(cols): 38 | df = df.rename(columns={i: col}) 39 | 40 | path = f"{param['data_root']}{title}_please_rename.xlsx" 41 | df.to_excel(path ) 42 | 43 | print(df.head()) 44 | 45 | 46 | if __name__ == '__main__': 47 | parser = argparse.ArgumentParser(description='Load json of visdom curves. 
Save to EXCEL!')
48 |     parser.add_argument("keyword", type=str, help="keyword")
49 |     parser.add_argument("root", type=str, help="root")
50 | 
51 |     args = parser.parse_args()
52 | 
53 |     if hasattr(args,'keyword') and hasattr(args,'root'):
54 |         keyword = args.keyword    # "WNet_mnist"
55 |         data_root = args.root     # "F:/arXiv/Diffractive Wavenet - a novel low parameter optical neural network/"
56 |         param = {"data_root":data_root,
57 |                  "select":f"{keyword}*.json"}
58 |         OnVisdom_json(param,keyword,smooth=True)
59 |     else:
60 |         param = {"data_root":r"E:\Guided Inverse design of SPP structures\images",
61 |                  "select":f"3_4*.json"}
62 |         OnVisdom_json(param,"3_4")
-------------------------------------------------------------------------------- /python-package/onnet/Visualizing.py: --------------------------------------------------------------------------------
 1 | '''
 2 |     python -m visdom.server
 3 |     http://localhost:8097
 4 |     .json file present in your ~/.visdom directory.
 5 | 
 6 |     tensorboard --logdir=runs
 7 |     http://localhost:6006/     fails here in a very strange way
 8 | 
 9 |     ONNX export failed on ATen operator ifft because torch.onnx.symbolic.ifft does not exist
10 | '''
11 | import seaborn as sns; sns.set()
12 | from PIL import Image
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from torch.utils.data import DataLoader
17 | from torch.utils.tensorboard import SummaryWriter    # used by UpdateLoss and PROJECTOR_test
18 | import visdom
19 | import matplotlib.pyplot as plt
20 | import numpy as np
21 | import torchvision
22 | import cv2
23 | from torchvision import datasets, transforms
24 | from .Z_utils import COMPLEX_utils as Z
25 | 
26 | def matplotlib_imshow(img, one_channel=False):
27 |     if one_channel:
28 |         img = img.mean(dim=0)
29 |     img = img / 2 + 0.5     # unnormalize
30 |     npimg = img.numpy()
31 |     if one_channel:
32 |         plt.imshow(npimg, cmap="Greys")
33 |     else:
34 |         plt.imshow(np.transpose(npimg, (1, 2, 0)))
35 |     plt.show()
36 | 
37 | 
38 | class Visualize:
39 |     def __init__(self,env_title="onnet",plots=[], **kwargs):
40 |         self.log_dir = f'runs/{env_title}'
41 |         self.plots = plots
42 |         self.loss_step = 0
43 |         self.writer = None    #SummaryWriter(self.log_dir)
44 |         self.img_dir = "./dump/images/";  self.index = {}    # per-window step counters used by plot()
45 |         self.dpi = 100
46 | 
47 |     #https://stackoverflow.com/questions/9662995/matplotlib-change-title-and-colorbar-text-and-tick-colors
48 |     def MatPlot(self,arr, title=""):
49 |         fig, ax = plt.subplots()
50 |         #plt.axis('off')
51 |         plt.grid(b=None)
52 |         im = ax.imshow(arr, interpolation='nearest', cmap='coolwarm')
53 |         fig.colorbar(im, orientation='horizontal')
54 |         plt.savefig(f'{self.img_dir}{title}.jpg')
55 |         #plt.show()
56 |         plt.close()
57 | 
58 |     def fig2data(self,fig):
59 |         fig.canvas.draw()
60 |         if True:    # https://stackoverflow.com/questions/42603161/convert-an-image-shown-in-python-into-an-opencv-image
61 |             img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
62 |             img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
63 |             img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
64 |             return img
65 |         else:
66 |             w, h = fig.canvas.get_width_height()
67 |             buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
68 |             buf.shape = (h, w, 4)
69 |             # canvas.tostring_argb gives pixmap in ARGB mode. 
Roll the ALPHA channel to have it in RGBA mode 70 | buf = np.roll(buf, 3, axis=2) 71 | return buf 72 | 73 | ''' 74 | sns.heatmap 很难用,需用自定义,参见https://stackoverflow.com/questions/53248186/custom-ticks-for-seaborn-heatmap 75 | ''' 76 | def HeatMap(self, data, file_name, params={},noAxis=True, cbar=True): 77 | title,isSave = file_name,True 78 | if 'save' in params: 79 | isSave = params['save'] 80 | if 'title' in params: 81 | title = params['title'] 82 | path = '{}{}_.jpg'.format(self.img_dir, file_name) 83 | sns.set(font_scale=3) 84 | s = max(data.shape[1] / self.dpi, data.shape[0] / self.dpi) 85 | # fig.set_size_inches(18.5, 10.5) 86 | cmap = 'coolwarm' # "plasma" #https://matplotlib.org/examples/color/colormaps_reference.html 87 | # cmap = sns.cubehelix_palette(start=1, rot=3, gamma=0.8, as_cmap=True) 88 | if noAxis: # tight samples for training(No text!!!) 89 | figsize = (s, s) 90 | fig, ax = plt.subplots(figsize=figsize, dpi=self.dpi) 91 | ax = sns.heatmap(data, ax=ax, cmap=cmap, cbar=False, xticklabels=False, yticklabels=False) 92 | fig.savefig(path, bbox_inches='tight', pad_inches=0,figsize=(20,10)) 93 | if False: 94 | image = cv2.imread(path) 95 | # image = fig2data(ax.get_figure()) #会放大尺寸,难以理解 96 | if (len(title) > 0): 97 | assert (image.shape == self.args.spp_image_shape) # 必须固定一个尺寸 98 | cv2.imshow("",image); cv2.waitKey(0) 99 | plt.close("all") 100 | return path 101 | else: # for paper 102 | ticks = np.linspace(0, 1, 10) 103 | xlabels = [int(i) for i in np.linspace(0, 56, 10)] 104 | ylabels = xlabels 105 | figsize = (s * 10, s * 10) 106 | #fig, ax = plt.subplots(figsize=figsize, dpi=self.dpi) # more concise than plt.figure: 107 | fig, ax = plt.subplots(dpi=self.dpi) 108 | ax.set_title(title) 109 | # cbar_kws={'label': 'Reflex', 'orientation': 'horizontal'} 110 | # sns.set(font_scale=0.2) 111 | # cbar_kws={'label': 'Reflex', 'orientation': 'horizontal'} , center=0.6 112 | # ax = sns.heatmap(data, ax=ax, cmap=cmap,yticklabels=ylabels[::-1],xticklabels=xlabels) 113 | # cbar_kws = dict(ticks=np.linspace(0, 1, 10)) 114 | ax = sns.heatmap(data, ax=ax, cmap=cmap,vmin=-1.1, vmax=1.1, cbar=cbar) # 115 | #plt.ylabel('Incident Angle'); plt.xlabel('Wavelength(nm)') 116 | if False: 117 | ax.set_xticklabels(xlabels); ax.set_yticklabels(ylabels[::-1]) 118 | y_limit = ax.get_ylim(); 119 | x_limit = ax.get_xlim() 120 | ax.set_yticks(ticks * y_limit[0]) 121 | ax.set_xticks(ticks * x_limit[1]) 122 | else: 123 | plt.axis('off') 124 | if False: 125 | plt.show(block=True) 126 | 127 | image = self.fig2data(ax.get_figure()) 128 | plt.close("all") 129 | #image_all = np.concatenate((img_0, img_1, img_diff), axis=1) 130 | #cv2.imshow("", image); cv2.waitKey(0) 131 | if isSave: 132 | cv2.imwrite(path, image) 133 | return path 134 | else: 135 | return image 136 | 137 | plt.close("all") 138 | 139 | def ShowModel(self,model,data_loader): 140 | ''' 141 | tensorboar显示效果较差 142 | ''' 143 | dataiter = iter(data_loader) 144 | images, labels = dataiter.next() 145 | if images.shape[0]>32: 146 | images=images[0:32,...] 
147 |         if self.writer is not None:    # self.writer stays None unless a SummaryWriter is attached
148 |             img_grid = torchvision.utils.make_grid(images)
149 |             matplotlib_imshow(img_grid, one_channel=True)
150 |             self.writer.add_image('one_batch', img_grid)
151 |             self.writer.close()
152 |         image_1 = images[0:1,:,:,:]
153 |         if False:
154 |             images = images.cuda()
155 |             self.writer.add_graph(model, images)
156 |             self.writer.close()
157 | 
158 |     def onX(self,X,title,nMostPic=64):
159 |         shape = X.shape
160 |         if Z.isComplex(X):
161 |             #X = torch.cat([X[..., 0],X[..., 1]],0)
162 |             X = Z.modulus(X)
163 |         X = X.cpu()
164 |         if shape[1]!=1:
165 |             X = X.contiguous().view(shape[0]*shape[1],1,shape[-2],shape[-1]).cpu()
166 |         if X.shape[0]>nMostPic:
167 |             X = X[:nMostPic,...]
168 |         img_grid = torchvision.utils.make_grid(X).detach().numpy()
169 |         plt.axis('off')
170 |         plt.grid(b=None)
171 |         image_np = np.transpose(img_grid, (1, 2, 0))
172 |         min_val,max_val = np.min(image_np),np.max(image_np)
173 |         image_np = (image_np - min_val) / (max_val - min_val)
174 |         if title is None:
175 |             plt.imshow(image_np)
176 |             plt.show()
177 |         else:
178 |             path = '{}{}_.jpg'.format(self.img_dir, title)
179 |             plt.imsave(path, image_np)
180 | 
181 | 
182 |     def image(self, file_name, img_, params={}):
183 |         #np.random.rand(3, 512, 256),
184 |         #self.MatPlot(img_.cpu().numpy(),title=name)
185 | 
186 |         result = self.HeatMap(img_.cpu().numpy(),file_name,params,noAxis=False)
187 |         return result
188 | 
189 |     def UpdateLoss(self,title,legend,loss,yLabel='LOSS',global_step=None):
190 |         tag = legend
191 |         step = self.loss_step if global_step==None else global_step
192 |         with SummaryWriter(log_dir=self.log_dir) as writer:
193 |             writer.add_scalar(tag, loss, global_step=step)
194 |         #self.writer.close()    # closing immediately flushes; otherwise it auto-flushes every 120 seconds
195 |         self.loss_step = self.loss_step+1
196 | 
197 | class Visdom_Visualizer(Visualize):
198 |     '''
199 |         Wraps the basic visdom operations
200 |     '''
201 | 
202 |     def __init__(self,env_title,plots=[], **kwargs):
203 |         super(Visdom_Visualizer, self).__init__(env_title,plots)
204 |         try:
205 |             self.viz = visdom.Visdom(env=env_title, **kwargs)
206 |             assert self.viz.check_connection()
207 |         except:
208 |             self.viz = None
209 | 
210 |     def UpdateLoss(self, title,legend, loss, yLabel='LOSS',global_step=None):
211 |         self.vis_plot( self.loss_step, loss, title,legend,yLabel)
212 |         self.loss_step = self.loss_step + 1
213 | 
214 |     def vis_plot(self,epoch, loss_, title,legend,yLabel):
215 |         if self.viz is None:
216 |             return
217 |         self.viz.line(X=torch.FloatTensor([epoch]), Y=torch.FloatTensor([loss_]), win='loss',
218 |                       opts=dict(
219 |                           legend=[legend],    # [config_.use_bn],
220 |                           fillarea=False,
221 |                           showlegend=True,
222 |                           width=1600,
223 |                           height=800,
224 |                           xlabel='Epoch',
225 |                           ylabel=yLabel,
226 |                           # ytype='log',
227 |                           title=title,
228 |                           # marginleft=30,
229 |                           # marginright=30,
230 |                           # marginbottom=80,
231 |                           # margintop=30,
232 |                       ),
233 |                       update='append' if epoch > 0 else None)
234 | 
235 |     def reinit(self, env='default', **kwargs):
236 |         self.viz = visdom.Visdom(env=env, **kwargs)
237 |         return self
238 | 
239 |     def plot_many(self, d):
240 |         '''
241 |             plot several curves at once
242 |             @params d: dict (name,value) i.e. 
('loss',0.11)
243 |         '''
244 |         for k, v in d.items():
245 |             self.plot(k, v)
246 | 
247 |     def img_many(self, d):
248 |         for k, v in d.items():
249 |             self.img(k, v)
250 | 
251 |     def plot(self, name, y, **kwargs):
252 |         '''
253 |             self.plot('loss',1.00)
254 |         '''
255 |         x = self.index.get(name, 0)
256 |         self.viz.line(Y=np.array([y]), X=np.array([x]),
257 |                       win=name,
258 |                       opts=dict(title=name),
259 |                       update=None if x == 0 else 'append',
260 |                       **kwargs
261 |                       )
262 |         self.index[name] = x + 1
263 | 
264 |     ''' very strange failure here
265 |     def image(self, name, img_, **kwargs):
266 | 
267 |         assert self.viz.check_connection()
268 |         self.viz.image(
269 |             np.random.rand(3, 512, 256),
270 |             opts=dict(title='Random image as jpg!', caption='How random as jpg.', jpgquality=50),
271 |         )
272 |         self.viz.image (img_.cpu().numpy(),
273 |                         #win=(name),
274 |                         opts=dict(title=name),
275 |                         **kwargs
276 |                         )
277 |     '''
278 | 
279 |     def log(self, info, win='log_text'):
280 |         '''
281 |             self.log({'loss':1,'lr':0.0001})
282 |         '''
283 | 
284 |         self.log_text += ('[{time}] {info} <br>
'.format( 285 | time=time.strftime('%m%d_%H%M%S'), \ 286 | info=info)) 287 | self.vis.text(self.log_text, win) 288 | print(self.log_text) 289 | 290 | def __getattr__(self, name): 291 | return getattr(self.vis, name) 292 | 293 | def PROJECTOR_test(): 294 | """ ==================使用PROJECTOR对高维向量可视化==================== 295 | https://blog.csdn.net/wsp_1138886114/article/details/87602112 296 | PROJECTOR的的原理是通过PCA,T-SNE等方法将高维向量投影到三维坐标系(降维度)。 297 | Embedding Projector从模型运行过程中保存的checkpoint文件中读取数据, 298 | 默认使用主成分分析法(PCA)将高维数据投影到3D空间中,也可以通过设置设置选择T-SNE投影方法, 299 | 这里做一个简单的展示。 300 | """ 301 | log_dirs = "../../runs/projector/" 302 | BATCH_SIZE = 256 303 | EPOCHS = 2 304 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 305 | 306 | train_loader = DataLoader(datasets.MNIST('../../data', train=True, download=False, 307 | transform=transforms.Compose([ 308 | transforms.ToTensor(), 309 | transforms.Normalize((0.1307,), (0.3081,)) 310 | ])), 311 | batch_size=BATCH_SIZE, shuffle=True) 312 | 313 | test_loader = torch.utils.data.DataLoader( 314 | datasets.MNIST('../../data', train=False, transform=transforms.Compose([ 315 | transforms.ToTensor(), 316 | transforms.Normalize((0.1307,), (0.3081,)) 317 | ])), 318 | batch_size=BATCH_SIZE, shuffle=True) 319 | 320 | class ConvNet(nn.Module): 321 | def __init__(self): 322 | super().__init__() 323 | # 1,28x28 324 | self.conv1 = nn.Conv2d(1, 10, 5) # 10, 24x24 325 | self.conv2 = nn.Conv2d(10, 20, 3) # 128, 10x10 326 | self.fc1 = nn.Linear(20 * 10 * 10, 500) 327 | self.fc2 = nn.Linear(500, 10) 328 | 329 | def forward(self, x): 330 | in_size = x.size(0) 331 | out = self.conv1(x) # 24 332 | out = F.relu(out) 333 | out = F.max_pool2d(out, 2, 2) # 12 334 | out = self.conv2(out) # 10 335 | out = F.relu(out) 336 | out = out.view(in_size, -1) 337 | out = self.fc1(out) 338 | out = F.relu(out) 339 | out = self.fc2(out) 340 | out = F.log_softmax(out, dim=1) 341 | return out 342 | 343 | model = ConvNet().to(DEVICE) 344 | optimizer = torch.optim.Adam(model.parameters()) 345 | 346 | def train(model, DEVICE, train_loader, optimizer, epoch): 347 | n_iter = 0 348 | model.train() 349 | for batch_idx, (data, target) in enumerate(train_loader): 350 | data, target = data.to(DEVICE), target.to(DEVICE) 351 | optimizer.zero_grad() 352 | output = model(data) 353 | loss = F.nll_loss(output, target) 354 | loss.backward() 355 | optimizer.step() 356 | if (batch_idx + 1) % 30 == 0: 357 | n_iter = n_iter + 1 358 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format( 359 | epoch, batch_idx * len(data), len(train_loader.dataset), 360 | 100. 
* batch_idx / len(train_loader), loss.item())) 361 | 362 | # 主要增加了一下内容 363 | out = torch.cat((output.data.cpu(), torch.ones(len(output), 1)), 1) # 因为是投影到3D的空间,所以我们只需要3个维度 364 | with SummaryWriter(log_dir=log_dirs, comment='mnist') as writer: 365 | # 使用add_embedding方法进行可视化展示 366 | writer.add_embedding( 367 | out, 368 | metadata=target.data, 369 | label_img=data.data, 370 | global_step=n_iter) 371 | 372 | def test(model, device, test_loader): 373 | model.eval() 374 | test_loss = 0 375 | correct = 0 376 | with torch.no_grad(): 377 | for data, target in test_loader: 378 | data, target = data.to(device), target.to(device) 379 | output = model(data) 380 | test_loss += F.nll_loss(output, target, reduction='sum').item() # 损失相加 381 | pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标 382 | correct += pred.eq(target.view_as(pred)).sum().item() 383 | 384 | test_loss /= len(test_loader.dataset) 385 | print('\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n' 386 | .format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset))) 387 | 388 | for epoch in range(1, EPOCHS + 1): 389 | train(model, DEVICE, train_loader, optimizer, epoch) 390 | test(model, DEVICE, test_loader) 391 | 392 | # 保存模型 393 | torch.save(model.state_dict(), './pytorch_tensorboardX_03.pth') 394 | 395 | if __name__ == '__main__': 396 | PROJECTOR_test() -------------------------------------------------------------------------------- /python-package/onnet/Z_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 1 晕 Pytorch居然不支持复向量 https://github.com/pytorch/pytorch/issues/755 3 | 4 | ''' 5 | 6 | import torch 7 | from torch.nn import ReflectionPad2d 8 | from torch.nn.functional import relu, max_pool2d, dropout, dropout2d 9 | import numpy as np 10 | 11 | class COMPLEX_utils(object): 12 | @staticmethod 13 | def isComplex(input): 14 | return input.size(-1) == 2 15 | 16 | @staticmethod 17 | def isReal(input): 18 | return input.size(-1) == 1 19 | 20 | @staticmethod 21 | def ToZ(u0): 22 | if COMPLEX_utils.isComplex(u0): 23 | return u0 24 | else: 25 | z0 = u0.new_zeros(u0.shape + (2,)) 26 | z0[..., 0] = u0 27 | assert(COMPLEX_utils.isComplex(z0)) 28 | return z0 29 | 30 | @staticmethod 31 | def relu(input_r,input_i): 32 | return relu(input_r), relu(input_i) 33 | 34 | @staticmethod 35 | def max_pool2d(input_r,input_i,kernel_size, stride=None, padding=0, 36 | dilation=1, ceil_mode=False, return_indices=False): 37 | 38 | return max_pool2d(input_r, kernel_size, stride, padding, dilation, 39 | ceil_mode, return_indices), \ 40 | max_pool2d(input_i, kernel_size, stride, padding, dilation, 41 | ceil_mode, return_indices) 42 | 43 | @staticmethod 44 | def rDrop2D(rDrop,d_shape,isComlex=False): 45 | drop = np.random.binomial(1, rDrop, size=d_shape).astype(np.float) 46 | drop[drop == 0] = 1.0e-6 47 | # print(f"x={x.shape} drop={drop.shape}") 48 | drop = torch.from_numpy(drop).cuda() 49 | if isComlex: 50 | drop = COMPLEX_utils.ToZ(drop) 51 | return drop 52 | ''' 53 | @staticmethod 54 | def dropout(input_r,input_i, p=0.5, training=True, inplace=False): 55 | return dropout(input_r, p, training, inplace), \ 56 | dropout(input_i, p, training, inplace) 57 | 58 | @staticmethod 59 | def dropout2d(input_r,input_i, p=0.5, training=True, inplace=False): 60 | return dropout2d(input_r, p, training, inplace), \ 61 | dropout2d(input_i, p, training, inplace) 62 | ''' 63 | 64 | #the absolute value or modulus of z https://en.wikipedia.org/wiki/Absolute_value#Complex_numbers 65 | 
@staticmethod
 66 |     def modulus(x):
 67 |         shape = x.size()[:-1]
 68 |         if False:
 69 |             norm = torch.zeros(shape)
 70 |             if x.dtype==torch.float64:
 71 |                 norm = norm.double()
 72 |         norm = (x[..., 0] * x[..., 0] + x[..., 1] * x[..., 1]).sqrt()
 73 |         return norm
 74 | 
 75 |     @staticmethod
 76 |     def phase(x):
 77 |         phase = torch.atan2(x[..., 1], x[..., 0])    # atan2(imag, real)
 78 |         return phase
 79 | 
 80 |     @staticmethod
 81 |     def sigmoid(x):
 82 |         # norm[...,0] = (x[...,0]*x[...,0] + x[...,1]*x[...,1]).sqrt()
 83 |         s_ = torch.zeros_like(x)
 84 |         s_[..., 0] = torch.sigmoid(x[..., 0])
 85 |         s_[..., 1] = torch.sigmoid(x[..., 1])
 86 |         return s_
 87 | 
 88 |     @staticmethod
 89 |     def exp_euler(x):    # Euler's formula: e^{ix} = cos x + i sin x
 90 |         s_ = torch.zeros(x.shape + (2,)).double().cuda()
 91 |         s_[..., 0] = torch.cos(x)
 92 |         s_[..., 1] = torch.sin(x)
 93 |         return s_
 94 | 
 95 |     @staticmethod
 96 |     def fft(input, direction='C2C', inverse=False):
 97 |         """
 98 |             Interface with torch FFT routines for 2D signals.
 99 | 
100 |             Example
101 |             -------
102 |             x = torch.randn(128, 32, 32, 2)
103 |             x_fft = fft(x, inverse=True)
104 | 
105 |             Parameters
106 |             ----------
107 |             input : tensor
108 |                 complex input for the FFT
109 |             direction : string
110 |                 'C2R' for complex to real, 'C2C' for complex to complex
111 |             inverse : bool
112 |                 True for computing the inverse FFT.
113 |                 NB : if direction is equal to 'C2R', then the transform
114 |                 is automatically inverse.
115 |         """
116 |         if direction == 'C2R':
117 |             inverse = True
118 | 
119 |         if not COMPLEX_utils.isComplex(input):
120 |             raise(TypeError('The input should be complex (e.g. last dimension is 2)'))
121 | 
122 |         if (not input.is_contiguous()):
123 |             raise (RuntimeError('Tensors must be contiguous!'))
124 | 
125 |         if direction == 'C2R':
126 |             output = torch.irfft(input, 2, normalized=False, onesided=False)*input.size(-2)*input.size(-3)
127 |         elif direction == 'C2C':
128 |             if inverse:
129 |                 #output = torch.ifft(input, 2, normalized=False)*input.size(-2)*input.size(-3)
130 |                 output = torch.ifft(input, 2, normalized=False)
131 |             else:
132 |                 output = torch.fft(input, 2, normalized=False)
133 | 
134 |         return output
135 | 
136 |     @staticmethod
137 |     def Hadamard(A, B, inplace=False):
138 |         """
139 |         Complex pointwise multiplication between (batched) tensor A and tensor B. 
140 | Sincr The Hadamard product is commutative, so Hadamard(A, B)=Hadamard(B, A) 141 | 142 | Parameters 143 | ---------- 144 | A : tensor 145 | A is a complex tensor of size (B, C, M, N, 2) 146 | B : tensor 147 | B is a complex tensor of size (M, N, 2) or real tensor of (M, N, 1) 148 | inplace : boolean, optional 149 | if set to True, all the operations are performed inplace 150 | 151 | Returns 152 | ------- 153 | C : tensor 154 | output tensor of size (B, C, M, N, 2) such that: 155 | C[b, c, m, n, :] = A[b, c, m, n, :] * B[m, n, :] 156 | """ 157 | if not COMPLEX_utils.isComplex(A): 158 | raise TypeError('The input must be complex, indicated by a last ' 159 | 'dimension of size 2') 160 | 161 | if B.ndimension() != 3: 162 | raise RuntimeError('The filter must be a 3-tensor, with a last ' 163 | 'dimension of size 1 or 2 to indicate it is real ' 164 | 'or complex, respectively') 165 | 166 | if not COMPLEX_utils.isComplex(B) and not COMPLEX_utils.isReal(B): 167 | raise TypeError('The filter must be complex or real, indicated by a ' 168 | 'last dimension of size 2 or 1, respectively') 169 | 170 | if A.size()[-3:-1] != B.size()[-3:-1]: 171 | raise RuntimeError('The filters are not compatible for multiplication!') 172 | 173 | if A.dtype is not B.dtype: 174 | raise RuntimeError('A and B must be of the same dtype') 175 | 176 | if A.device.type != B.device.type: 177 | raise RuntimeError('A and B must be of the same device type') 178 | 179 | if A.device.type == 'cuda': 180 | if A.device.index != B.device.index: 181 | raise RuntimeError('A and B must be on the same GPU!') 182 | 183 | if COMPLEX_utils.isReal(B): 184 | if inplace: 185 | return A.mul_(B) 186 | else: 187 | return A * B 188 | else: 189 | C = A.new(A.size()) 190 | 191 | A_r = A[..., 0].contiguous().view(-1, A.size(-2)*A.size(-3)) 192 | A_i = A[..., 1].contiguous().view(-1, A.size(-2)*A.size(-3)) 193 | 194 | B_r = B[...,0].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_i) 195 | B_i = B[..., 1].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_r) 196 | 197 | C[..., 0].view(-1, C.size(-2)*C.size(-3))[:] = A_r * B_r - A_i * B_i 198 | C[..., 1].view(-1, C.size(-2)*C.size(-3))[:] = A_r * B_i + A_i * B_r 199 | 200 | return C if not inplace else A.copy_(C) 201 | 202 | def IFFT(X1,X2,X3): 203 | f, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, sharex='col', sharey='row',figsize=(10,6)) 204 | Z = ifftn(X1) 205 | ax1.imshow(X1, cmap=cm.Reds) 206 | ax4.imshow(np.real(Z), cmap=cm.gray) 207 | Z = ifftn(X2) 208 | ax2.imshow(X2, cmap=cm.Reds) 209 | ax5.imshow(np.real(Z), cmap=cm.gray) 210 | Z = ifftn(X3) 211 | ax3.imshow(X3, cmap=cm.Reds) 212 | ax6.imshow(np.real(Z), cmap=cm.gray) 213 | plt.show() 214 | 215 | 216 | def roll_n(X, axis, n): 217 | f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim())) 218 | b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim())) 219 | front = X[f_idx] 220 | back = X[b_idx] 221 | return torch.cat([back, front], axis) 222 | 223 | def batch_fftshift2d(x): 224 | real, imag = torch.unbind(x, -1) 225 | for dim in range(1, len(real.size())): 226 | n_shift = real.size(dim)//2 227 | if real.size(dim) % 2 != 0: 228 | n_shift += 1 # for odd-sized images 229 | real = roll_n(real, axis=dim, n=n_shift) 230 | imag = roll_n(imag, axis=dim, n=n_shift) 231 | return torch.stack((real, imag), -1) # last dim=2 (real&imag) 232 | 233 | def batch_ifftshift2d(x): 234 | real, imag = torch.unbind(x, -1) 235 | for dim in 
--------------------------------------------------------------------------------
/python-package/onnet/__init__.py:
--------------------------------------------------------------------------------
# coding: utf-8
"""ONNet: optical neural networks.

@Author: Yingshi Chen
@Date: 2020-01-16 10:38:45
# Description:
"""
from __future__ import absolute_import

__author__ = 'Yingshi Chen'

import os

from .optical_trans import OpticalTrans
from .D2NNet import D2NNet, DNET_config
from .RGBO_CNN import RGBO_CNN, RGBO_CNN_config
from .Z_utils import COMPLEX_utils
from .BinaryDNet import *
from .Net_Instance import *
from .NET_config import *
from .Visualizing import *
from .some_utils import *
from .DiffractiveLayer import *
from .OpticalFormer import clip_grad, OpticalFormer

dir_path = os.path.dirname(os.path.realpath(__file__))
#print(f"__init_ dir_path={dir_path}")

__all__ = ['NET_config',
           'D2NNet', 'DNET_config', 'DNet_instance', 'RGBO_CNN_instance', 'Net_dump',
           'RGBO_CNN', 'RGBO_CNN_config',
           'OpticalTrans', 'COMPLEX_utils', 'MultiDNet', 'BinaryDNet', 'Visualize', 'Visdom_Visualizer',
           'seed_everything', 'load_model_weights',
           'DiffractiveLayer',
           'OpticalFormer', 'clip_grad'
           ]
--------------------------------------------------------------------------------
/python-package/onnet/__version__.py:
--------------------------------------------------------------------------------

VERSION = (0, 0, 1)

__version__ = '.'.join(map(str, VERSION))
--------------------------------------------------------------------------------
/python-package/onnet/optical_trans.py:
--------------------------------------------------------------------------------
# Authors: Edouard Oyallon
# Scientific Ancestry: Edouard Oyallon, Laurent Sifre, Joan Bruna


__all__ = ['OpticalTrans', 'Scattering2D']

import torch

class OpticalTrans(object):
    # Placeholder optical transform: currently the identity map.
    def forward(self, input):
        #input = input.type(torch.complex64)
        return input

    def __call__(self, input):
        return self.forward(input)
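OpticalTrans follows the bare callable-transform protocol, so it slots into any preprocessing chain (including torchvision.transforms.Compose). A minimal sketch of its current identity behavior; the input shape is illustrative:

import torch

trans = OpticalTrans()
x = torch.rand(1, 28, 28)        # stand-in for a single-channel image tensor
assert torch.equal(trans(x), x)  # identity until the complex cast above is enabled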
class Scattering2D(object):
    """Main module implementing the scattering transform in 2D.
    The scattering transform computes two cascaded wavelet transforms,
    each followed by a modulus non-linearity. It can be summarized as::

        S_J x = [S_J^0 x, S_J^1 x, S_J^2 x]

    where::

        S_J^0 x = x * phi_J
        S_J^1 x = [|x * psi^1_lambda| * phi_J]_lambda
        S_J^2 x = [||x * psi^1_lambda| * psi^2_mu| * phi_J]_{lambda, mu}

    where * denotes the convolution (in space), phi_J is a low pass
    filter, psi^1_lambda is a family of band pass
    filters and psi^2_mu is another family of band pass filters.
    Only Morlet filters are used in this implementation.
    Convolutions are performed efficiently in the Fourier domain.

    Example
    -------
    # 1) Define a Scattering object as:
    s = Scattering2D(J, shape=(M, N))
    # where (M, N) are the image sizes and 2**J the scale of the scattering
    # 2) Forward on an input Tensor x of shape B x M x N,
    # where B is the batch size.
    result_s = s(x)

    Parameters
    ----------
    J : int
        logscale of the scattering
    shape : tuple of int
        spatial support (M, N) of the input
    L : int, optional
        number of angles used for the wavelet transform
    max_order : int, optional
        The maximum order of scattering coefficients to compute. Must be
        either `1` or `2`. Defaults to `2`.
    pre_pad : boolean, optional
        controls the padding: if set to False, a symmetric padding is applied
        on the signal. If set to True, the software will assume the signal was
        padded externally.

    Attributes
    ----------
    J : int
        logscale of the scattering
    shape : tuple of int
        spatial support (M, N) of the input
    L : int, optional
        number of angles used for the wavelet transform
    max_order : int, optional
        The maximum order of scattering coefficients to compute.
        Must be either equal to `1` or `2`. Defaults to `2`.
    pre_pad : boolean
        controls the padding
    Psi : dictionary
        containing the wavelets filters at all resolutions. See
        filter_bank.filter_bank for an exact description.
    Phi : dictionary
        containing the low-pass filters at all resolutions. See
        filter_bank.filter_bank for an exact description.
    M_padded, N_padded : int
        spatial support of the padded input

    Notes
    -----
    The design of the filters is optimized for the value L = 8.

    pre_pad is particularly useful when doing crops of a bigger
    image because the padding is then extremely accurate. Defaults
    to False.

    """
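The number of output channels follows directly from the three orders above. A small sanity-check sketch mirroring the size computation done in forward() below; the helper name is ours, not part of the module:

# Channel count per scattering order, as computed in forward():
# order 0: 1, order 1: L*J, order 2: L^2 * J*(J-1)/2.
def n_scattering_channels(J, L=8, max_order=2):
    n = 1 + L * J
    if max_order == 2:
        n += L ** 2 * J * (J - 1) // 2
    return n

# e.g. the common J=2, L=8 setting yields 1 + 16 + 64 = 81 channels.
assert n_scattering_channels(2) == 81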
    def __init__(self, J, shape, L=8, max_order=2, pre_pad=False):
        self.J, self.L = J, L
        self.pre_pad = pre_pad
        self.max_order = max_order
        self.shape = shape
        if 2**J > shape[0] or 2**J > shape[1]:
            raise (RuntimeError('The smallest dimension should be larger than 2^J'))

        self.build()

    def build(self):
        # NOTE: Modulus, compute_padding, Pad, SubsampleFourier, filter_bank and
        # convert_filters are expected from a kymatio-style scattering backend;
        # they are not defined in this module.
        self.M, self.N = self.shape
        self.modulus = Modulus()
        self.M_padded, self.N_padded = compute_padding(self.M, self.N, self.J)
        # pads equally on a given side if the amount of padding to add is an
        # even number of pixels, otherwise it adds an extra pixel
        self.pad = Pad([(self.M_padded - self.M) // 2, (self.M_padded - self.M + 1) // 2,
                        (self.N_padded - self.N) // 2, (self.N_padded - self.N + 1) // 2],
                       [self.M, self.N], pre_pad=self.pre_pad)
        self.subsample_fourier = SubsampleFourier()
        # Create the filters
        filters = filter_bank(self.M_padded, self.N_padded, self.J, self.L)
        self.Psi = convert_filters(filters['psi'])
        self.Phi = convert_filters([filters['phi'][j] for j in range(self.J)])

    def _apply(self, fn):
        """
        Mimics the behavior of the function _apply() of a nn.Module()
        """
        for key, item in enumerate(self.Psi):
            for key2, item2 in self.Psi[key].items():
                if torch.is_tensor(item2):
                    self.Psi[key][key2] = fn(item2)
        self.Phi = [fn(v) for v in self.Phi]
        self.pad.padding_module._apply(fn)
        return self

    def cuda(self, device=None):
        """
        Mimics the behavior of the function cuda() of a nn.Module()
        """
        return self._apply(lambda t: t.cuda(device))

    def to(self, *args, **kwargs):
        """
        Mimics the behavior of the function to() of a nn.Module()
        """
        device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)

        if dtype is not None:
            if not dtype.is_floating_point:
                raise TypeError('nn.Module.to only accepts floating point '
                                'dtypes, but got desired dtype={}'.format(dtype))

        def convert(t):
            return t.to(device, dtype if t.is_floating_point() else None, non_blocking)

        return self._apply(convert)

    def cpu(self):
        """
        Mimics the behavior of the function cpu() of a nn.Module()
        """
        return self._apply(lambda t: t.cpu())
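    # Because the methods above mirror nn.Module, a Scattering2D object moves
    # between devices the same way a model does. An illustrative sketch,
    # assuming the filter-bank helpers noted in build() are available:
    #
    #     s = Scattering2D(J=2, shape=(32, 32))
    #     if torch.cuda.is_available():
    #         s = s.cuda()              # moves every tensor in Psi/Phi, like module buffers
    #     x = torch.randn(8, 1, 32, 32)
    #     x = x.cuda() if torch.cuda.is_available() else x
    #     Sx = s(x.contiguous())        # forward() insists on a contiguous input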
    def forward(self, input):
        """Forward pass of the scattering.

        Parameters
        ----------
        input : tensor
            tensor of shape :math:`(B, C, M, N)` where :math:`(B, C)` are arbitrary.
            :math:`B` typically is the batch size, whereas :math:`C` is the number of input channels.

        Returns
        -------
        S : tensor
            scattering of the input, a tensor of shape :math:`(B, C, D, Md, Nd)` where :math:`D` corresponds
            to a new channel dimension and :math:`(Md, Nd)` are the sizes downsampled by a factor :math:`2^J`.

        """
        if not torch.is_tensor(input):
            raise (TypeError('The input should be a torch.cuda.FloatTensor, a torch.FloatTensor or a torch.DoubleTensor'))

        if len(input.shape) < 2:
            raise (RuntimeError('Input tensor must have at least two '
                                'dimensions'))

        if (not input.is_contiguous()):
            raise (RuntimeError('Tensor must be contiguous!'))

        if ((input.size(-1) != self.N or input.size(-2) != self.M) and not self.pre_pad):
            raise (RuntimeError('Tensor must be of spatial size (%i,%i)!' % (self.M, self.N)))

        if ((input.size(-1) != self.N_padded or input.size(-2) != self.M_padded) and self.pre_pad):
            raise (RuntimeError('Padded tensor must be of spatial size (%i,%i)!' % (self.M_padded, self.N_padded)))

        batch_shape = input.shape[:-2]
        signal_shape = input.shape[-2:]

        input = input.reshape((-1, 1) + signal_shape)

        J = self.J
        phi = self.Phi
        psi = self.Psi

        subsample_fourier = self.subsample_fourier
        modulus = self.modulus
        pad = self.pad
        order0_size = 1
        order1_size = self.L * J
        order2_size = self.L ** 2 * J * (J - 1) // 2
        output_size = order0_size + order1_size

        if self.max_order == 2:
            output_size += order2_size

        S = input.new(input.size(0),
                      input.size(1),
                      output_size,
                      self.M_padded//(2**J)-2,
                      self.N_padded//(2**J)-2)
        U_r = pad(input)
        U_0_c = fft(U_r, 'C2C')  # FFT of the padded input; the U_* buffers are reused across orders

        # First low pass filter
        U_1_c = subsample_fourier(cdgmm(U_0_c, phi[0]), k=2**J)

        U_J_r = fft(U_1_c, 'C2R')

        S[..., 0, :, :] = unpad(U_J_r)
        n_order1 = 1
        n_order2 = 1 + order1_size

        for n1 in range(len(psi)):
            j1 = psi[n1]['j']
            U_1_c = cdgmm(U_0_c, psi[n1][0])
            if (j1 > 0):
                U_1_c = subsample_fourier(U_1_c, k=2 ** j1)
            U_1_c = fft(U_1_c, 'C2C', inverse=True)
            U_1_c = fft(modulus(U_1_c), 'C2C')

            # Second low pass filter
            U_2_c = subsample_fourier(cdgmm(U_1_c, phi[j1]), k=2**(J-j1))
            U_J_r = fft(U_2_c, 'C2R')
            S[..., n_order1, :, :] = unpad(U_J_r)
            n_order1 += 1

            if self.max_order == 2:
                for n2 in range(len(psi)):
                    j2 = psi[n2]['j']
                    if (j1 < j2):
                        U_2_c = subsample_fourier(cdgmm(U_1_c, psi[n2][j1]), k=2 ** (j2-j1))
                        U_2_c = fft(U_2_c, 'C2C', inverse=True)
                        U_2_c = fft(modulus(U_2_c), 'C2C')

                        # Third low pass filter
                        U_2_c = subsample_fourier(cdgmm(U_2_c, phi[j2]), k=2 ** (J-j2))
                        U_J_r = fft(U_2_c, 'C2R')

                        S[..., n_order2, :, :] = unpad(U_J_r)
                        n_order2 += 1

        scattering_shape = S.shape[-3:]
        S = S.reshape(batch_shape + scattering_shape)

        return S

    def __call__(self, input):
        return self.forward(input)
--------------------------------------------------------------------------------
/python-package/onnet/some_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import math
import random
import torch
import sys
import os
import psutil


def split__sections(dim_0, nClass):
    # Split dim_0 indices into nClass contiguous, nearly equal sections;
    # return the length of each section.
    split_dim = range(dim_0)
    sections = []
    for arr in np.array_split(np.array(split_dim), nClass):
        sections.append(len(arr))
    assert len(sections) > 0
    return sections

def shrink(x0, x1, max_sz=2):
    # Clip the span [x0, x1) to at most max_sz cells around its center.
    if x1 - x0 > max_sz:
        center = (x1 + x0) // 2
        x1 = center + max_sz // 2
        x0 = center - max_sz // 2
    return x0, x1

def split_regions_2d(shape, nClass):
    # Partition a 2D grid into nClass small boxes, laid out on an
    # n1 x n2 grid with n1*n2 >= nClass; each box is shrunk to max_sz.
    dim_1, dim_2 = shape[-1], shape[-2]
    n1 = int(math.sqrt(nClass))
    n2 = int(math.ceil(nClass / n1))
    assert n1*n2 >= nClass
    section_1 = split__sections(dim_1, n1)
    section_2 = split__sections(dim_2, n2)
    regions = []
    x1, x2 = 0, 0
    for sec_1 in section_1:
        for sec_2 in section_2:
            #box=(x1,x1+sec_1,x2,x2+sec_2)
            box = shrink(x1, x1+sec_1) + shrink(x2, x2+sec_2)
            regions.append(box)
            if len(regions) >= nClass:
                break
            x2 = x2 + sec_2
        x1 = x1 + sec_1; x2 = 0
    return regions
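A quick illustration of how these three helpers cooperate; the sizes and class count below are arbitrary:

print(split__sections(10, 3))  # -> [4, 3, 3]: near-equal contiguous sections
print(shrink(0, 10))           # -> (4, 6): span clipped to max_sz=2 around its center
regions = split_regions_2d((28, 28), 10)
print(len(regions))            # -> 10 boxes, each a (lo1, hi1, lo2, hi2) index tuple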
def seed_everything(seed=0):
    print(f"======== seed_everything seed={seed}========")
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    # https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def cpuStats():
    print(sys.version)
    print(psutil.cpu_percent())
    print(psutil.virtual_memory())  # physical memory usage
    pid = os.getpid()
    py = psutil.Process(pid)
    memoryUse = py.memory_info()[0] / 2. ** 30  # resident set size in GB
    print('memory use in python(GB):', memoryUse)

def pytorch_env():
    print('__Python VERSION:', sys.version)
    print('__pyTorch VERSION:', torch.__version__)
    print('__CUDA VERSION:', torch.version.cuda)
    # from subprocess import call
    # call(["nvcc", "--version"]) does not work
    # ! nvcc --version
    print('__CUDNN VERSION:', torch.backends.cudnn.version())
    print('__Number CUDA Devices:', torch.cuda.device_count())
    print('__Devices')
    # call(["nvidia-smi", "--format=csv", "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free"])
    print('Active CUDA Device: GPU', torch.cuda.current_device())
    use_cuda = torch.cuda.is_available()
    print("USE CUDA=" + str(use_cuda))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
    Tensor = FloatTensor
    cpuStats()
    print("===== torch_init device={}".format(device))
    return device

def OnInitInstance(seed=0):
    seed_everything(seed)
    gpu_device = pytorch_env()
    return gpu_device

def load_model_weights(model, state_dict, log, verbose=True):
    """
    Loads the model weights from the state dictionary. The function will only
    load the weights which have matching key names and dimensions in the state
    dictionary.

    :param model: Pytorch model that receives the weights
    :param state_dict: Pytorch model state dictionary
    :param log: logger used to report the loadable and non-loadable keys
    :param verbose: bool, If True, the function will print the
        weight keys of parameters that can and cannot be loaded from the
        checkpoint state dictionary.
    :return: The model with loaded weights
    """
    new_state_dict = model.state_dict()
    non_loadable, loadable = set(), set()

    for k, v in state_dict.items():
        if k not in new_state_dict:
            non_loadable.add(k)
            continue

        if v.shape != new_state_dict[k].shape:
            non_loadable.add(k)
            continue

        new_state_dict[k] = v
        loadable.add(k)

    if verbose:
        log.info("### Checkpoint weights that WILL be loaded: ###")
        for k in loadable:
            log.info(k)

        log.info("### Checkpoint weights that CANNOT be loaded: ###")
        for k in non_loadable:
            log.info(k)

    model.load_state_dict(new_state_dict)
    return model
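A hedged sketch of partial checkpoint loading with this helper; the stand-in model and the deliberately mismatched checkpoint below are illustrative only:

import logging
import torch.nn as nn

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("onnet")

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
ckpt = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3)).state_dict()  # second layer differs

# Layer 0 is loaded; the shape-mismatched layer 1 is reported and skipped, not fatal.
model = load_model_weights(model, ckpt, log)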
--------------------------------------------------------------------------------
/venv/pyvenv.cfg:
--------------------------------------------------------------------------------
home = D:\anaconda3
include-system-site-packages = false
version = 3.7.3
--------------------------------------------------------------------------------