├── .gitignore
├── .idea
│   ├── ONNet.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── vcs.xml
│   └── workspace.xml
├── ONNet_wavelet.png
├── README.md
├── case_brain.py
├── case_cifar.py
├── case_covir.py
├── case_dog_cat.py
├── case_face_detect.py
├── case_lung_mask.py
├── case_mnist.py
└── python-package
    ├── case_fft.py
    ├── cnn_models
    │   └── OpticalNet.py
    ├── fast_conv.py
    ├── onnet
    │   ├── BinaryDNet.py
    │   ├── D2NN_tf.py
    │   ├── D2NNet.py
    │   ├── DiffractiveLayer.py
    │   ├── DropOutLayer.py
    │   ├── FFT_layer.py
    │   ├── Loss.py
    │   ├── NET_config.py
    │   ├── Net_Instance.py
    │   ├── OpticalFormer.py
    │   ├── OpticalFormer_util.py
    │   ├── PoolForCls.py
    │   ├── RGBO_CNN.py
    │   ├── SparseSupport.py
    │   ├── ToExcel.py
    │   ├── Visualizing.py
    │   ├── Z_utils.py
    │   ├── __init__.py
    │   ├── __version__.py
    │   ├── optical_trans.py
    │   └── some_utils.py
    └── venv
        └── pyvenv.cfg
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 |
4 | # User-specific files
5 | *.suo
6 | *.user
7 | *.userosscache
8 | *.sln.docstates
9 | *.pth
10 | # User-specific files (MonoDevelop/Xamarin Studio)
11 | *.userprefs
12 | *.npy
13 |
14 | # Build results
15 | net-source
16 | dump/
17 | runs/
18 | cnn_models/
19 | deap/
20 | checkpoint/
21 | python-package/net_source/
22 | [Dd]ebug/
23 | [Dd]ebugPublic/
24 | [Rr]elease/
25 | [Rr]eleases/
26 | [Xx]64/
27 | [Xx]86/
28 | [Bb]uild/
29 | bld/
30 | [Bb]in/
31 | [Oo]bj/
32 | docs/_build
33 | tests/bin
34 | lib
35 | data
36 | _000
37 | python-package/dist
38 | .pytest_cache/v/cache
39 |
40 | # Visual Studio 2015 cache/options directory
41 | .vs/
42 | # Uncomment if you have tasks that create the project's static files in wwwroot
43 | #wwwroot/
44 |
45 | # MSTest test Results
46 | [Tt]est[Rr]esult*/
47 | [Bb]uild[Ll]og.*
48 |
49 | # NUNIT
50 | *.VisualState.xml
51 | TestResult.xml
52 |
53 | # Build Results of an ATL Project
54 | [Dd]ebugPS/
55 | [Rr]eleasePS/
56 | dlldata.c
57 |
58 | # DNX
59 | project.lock.json
60 | artifacts/
61 |
62 | # Python
63 | *.egg-info
64 | __pycache__
65 | .eggs
66 |
67 | # VS Code
68 | .vscode
69 |
70 | # Prerequisites
71 | *.d
72 |
73 | # Compiled Object files
74 | *.slo
75 | *.lo
76 | *.o
77 | *.obj
78 |
79 | # Precompiled Headers
80 | *.gch
81 |
82 | *_i.c
83 | *_p.c
84 | *_i.h
85 | *.ilk
86 | *.meta
87 | *.obj
88 | *.pch
89 | *.pdb
90 | *.pgc
91 | *.pgd
92 | *.rsp
93 | *.sbr
94 | *.tlb
95 | *.tli
96 | *.tlh
97 | *.tmp
98 | *.tmp_proj
99 | *.log
100 | *.vspscc
101 | *.vssscc
102 | .builds
103 | *.pidb
104 | *.svclog
105 | *.scc
106 | *.rar
107 | *.ym
108 | *.model
109 |
110 |
111 | # Chutzpah Test files
112 | _Chutzpah*
113 |
114 | # Visual C++ cache files
115 | ipch/
116 | *.aps
117 | *.ncb
118 | *.opendb
119 | *.opensdf
120 | *.sdf
121 | *.cachefile
122 | *.VC.db
123 |
124 | # Visual Studio profiler
125 | *.psess
126 | *.vsp
127 | *.vspx
128 | *.sap
129 |
130 | # TFS 2012 Local Workspace
131 | $tf/
132 |
133 | # Guidance Automation Toolkit
134 | *.gpState
135 |
136 | # ReSharper is a .NET coding add-in
137 | _ReSharper*/
138 | *.[Rr]e[Ss]harper
139 | *.DotSettings.user
140 |
141 | # JustCode is a .NET coding add-in
142 | .JustCode
143 |
144 | # TeamCity is a build add-in
145 | _TeamCity*
146 |
147 | # DotCover is a Code Coverage Tool
148 | *.dotCover
149 |
150 | # NCrunch
151 | _NCrunch_*
152 | .*crunch*.local.xml
153 | nCrunchTemp_*
154 |
155 | # MightyMoose
156 | *.mm.*
157 | AutoTest.Net/
158 |
159 | # Web workbench (sass)
160 | .sass-cache/
161 |
162 | # Installshield output folder
163 | [Ee]xpress/
164 |
165 | # DocProject is a documentation generator add-in
166 | DocProject/buildhelp/
167 | DocProject/Help/*.HxT
168 | DocProject/Help/*.HxC
169 | DocProject/Help/*.hhc
170 | DocProject/Help/*.hhk
171 | DocProject/Help/*.hhp
172 | DocProject/Help/Html2
173 | DocProject/Help/html
174 |
175 | # Click-Once directory
176 | publish/
177 |
178 | # Publish Web Output
179 | *.[Pp]ublish.xml
180 | *.azurePubxml
181 |
182 | # TODO: Un-comment the next line if you do not want to checkin
183 | # your web deploy settings because they may include unencrypted
184 | # passwords
185 | #*.pubxml
186 | *.publishproj
187 |
188 | # NuGet Packages
189 | *.nupkg
190 | # The packages folder can be ignored because of Package Restore
191 | **/packages/*
192 | # except build/, which is used as an MSBuild target.
193 | !**/packages/build/
194 | # Uncomment if necessary however generally it will be regenerated when needed
195 | #!**/packages/repositories.config
196 | # NuGet v3's project.json files produces more ignoreable files
197 | *.nuget.props
198 | *.nuget.targets
199 |
200 | # Microsoft Azure Build Output
201 | csx/
202 | *.build.csdef
203 |
204 | # Microsoft Azure Emulator
205 | ecf/
206 | rcf/
207 |
208 | # Windows Store app package directory
209 | AppPackages/
210 | BundleArtifacts/
211 |
212 | # Visual Studio cache files
213 | # files ending in .cache can be ignored
214 | *.[Cc]ache
215 | # but keep track of directories ending in .cache
216 | !*.[Cc]ache/
217 |
218 | # Others
219 | ClientBin/
220 | [Ss]tyle[Cc]op.*
221 | ~$*
222 | *~
223 | *.dbmdl
224 | *.dbproj.schemaview
225 | *.pfx
226 | *.publishsettings
227 | node_modules/
228 | orleans.codegen.cs
229 |
230 | # RIA/Silverlight projects
231 | Generated_Code/
232 |
233 | # Backup & report files from converting an old project file
234 | # to a newer Visual Studio version. Backup files are not needed,
235 | # because we have git ;-)
236 | _UpgradeReport_Files/
237 | Backup*/
238 | UpgradeLog*.XML
239 | UpgradeLog*.htm
240 |
241 | # SQL Server files
242 | *.mdf
243 | *.ldf
244 |
245 | # Business Intelligence projects
246 | *.rdl.data
247 | *.bim.layout
248 | *.bim_*.settings
249 |
250 | # Microsoft Fakes
251 | FakesAssemblies/
252 |
253 | # GhostDoc plugin setting file
254 | *.GhostDoc.xml
255 |
256 | # Node.js Tools for Visual Studio
257 | .ntvs_analysis.dat
258 |
259 | # Visual Studio 6 build log
260 | *.plg
261 |
262 | # Visual Studio 6 workspace options file
263 | *.opt
264 |
265 | # Visual Studio LightSwitch build output
266 | **/*.HTMLClient/GeneratedArtifacts
267 | **/*.DesktopClient/GeneratedArtifacts
268 | **/*.DesktopClient/ModelManifest.xml
269 | **/*.Server/GeneratedArtifacts
270 | **/*.Server/ModelManifest.xml
271 | _Pvt_Extensions
272 |
273 | # LightSwitch generated files
274 | GeneratedArtifacts/
275 | ModelManifest.xml
276 |
277 | # Paket dependency manager
278 | .paket/paket.exe
279 |
280 | # FAKE - F# Make
281 | .fake/
282 | *.lai
283 | *.la
284 | *.a
285 | *.lib
286 | *.zip
287 | *.info
288 | *.dll
289 | *.so
290 | *.dylib
291 | *.mA_bin
292 | *.dat
293 | *.avi
294 | *.ogv
295 | *.asv
296 | *.code
297 | /tests/python_package_test/.pytest_cache/v/cache
298 | /tests/python_package_test/categorical.model
299 | /python-package/geo_test.py
300 | *.csv
301 | /python-package/.pytest_cache/v/cache/lastfailed
302 | /python-package/.pytest_cache/v/cache/nodeids
303 | /python-package/case_qq2019.py
304 | *.txt
305 | *.jpg
306 | /src/learn/discpy.py
307 | /src/learn/sparsipy.py
308 | /doc/Gradient boosting on adpative distrubutions.docx
309 | /python-package/litemort/桌面.lnk
310 | /python-package/LiteMORT_hyppo.py
311 | /python-package/shap_test.py
312 | *.pickle
313 | *.gz
314 | logger.py
315 |
--------------------------------------------------------------------------------
/ONNet_wavelet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/closest-git/ONNet/79dacffe164369e564650f65b3e9e857668b63bc/ONNet_wavelet.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ONNet
2 |
3 | **ONNet** is an open-source Python/C++ package for optical neural networks, providing many tools for researchers who study them. Some of its new models are listed below:
4 |
5 | - #### Express Wavenet
6 |
7 | ![](ONNet_wavelet.png)
8 |
9 | Express Wavenet uses a random-shift wavelet pattern to modulate the phase of optical waves. It needs only about one percent of the parameters while keeping accuracy high: on the MNIST dataset it reaches 92% accuracy with just 1229 parameters, while DDNet needs 125440 (1229/125440 ≈ 1%). [3]
10 |
11 | - #### Diffractive deep neural network with multiple frequency-channels
12 |
13 | Each layer has multiple frequency channels (optical distributions at different frequencies). These channels are merged at the output plane with weighting coefficients. [2]
14 |
15 | - #### Diffractive network with multiple binary output plane
16 |
17 |
18 |
19 | An optical neural network (ONN) is a novel machine-learning framework built on the physical principles of optics; it is still in its infancy and shows great potential. An ONN searches for optimal modulation parameters that change the phase, amplitude, or other physical variables of optical wave propagation, so that the optical distribution at the final output plane forms a special pattern indicating the object's class or value (a minimal sketch of such a phase-modulating layer follows this file). ONN opens new doors for machine learning.
20 |
21 | # BTW:
22 |
23 | I used to think that "ONN opens new doors for machine learning", but now it seems that only a few people acknowledge the significance of ONN for machine learning. It is really hard to explain why ONN performs so poorly on widely used datasets (CIFAR...), let alone ImageNet!
24 |
25 | Fortunately, I find that the optical diffraction model has a subtle connection with some mathematical models, which is worth further study.
26 |
27 | ---2/27/2022
28 |
29 | ## Citation
30 |
31 | Please use the following bibtex entry:
32 | ```
33 | [1] Zhang, Xinyu, Jiashuo Shi, and Yingshi Chen. "A Broad-Spectrum Diffractive Network via Ensemble Learning." Opt. Lett. 46 (2021): 14.
34 | [2] Chen, Yingshi, et al. "An optical diffractive deep neural network with multiple frequency-channels." arXiv preprint arXiv:1912.10730 (2019).
35 | [3] Chen, Yingshi, et al. "Express Wavenet: A lower parameter optical neural network with random shift wavelet pattern." Optics Communications 485 (2021): 126709.
36 | ```
37 |
38 | ## Future work
39 |
40 | - More testing datasets
41 |
42 |   CIFAR, ImageNet, ...
43 |
44 | - More models
45 |
46 | - More papers.
47 |
48 |
49 |
50 | ## License
51 |
52 | The provided implementation is strictly for academic purposes only. If anyone is interested in using our technology for any commercial use, please contact us.
53 |
54 | ## Authors
55 |
56 | Yingshi Chen (gsp.cys@gmail.com)
57 |
58 | QQ group: 1001583663
59 |
--------------------------------------------------------------------------------
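The README above describes the core idea: learnable phase modulation of a propagating optical field, with the class read off the output plane. As a point of reference, here is a minimal, self-contained sketch of such a phase-modulating diffractive layer. It is not the ONNet implementation (see `onnet/DiffractiveLayer.py` for that); all names and constants below are illustrative.

```python
# A minimal sketch of a diffractive phase-modulation layer, NOT the ONNet code;
# the propagation kernel H is a toy Fresnel-like filter with arbitrary constants.
import torch
import torch.nn as nn

class ToyDiffractiveLayer(nn.Module):
    def __init__(self, size=28):
        super().__init__()
        self.phase = nn.Parameter(torch.zeros(size, size))   # learnable phase mask
        fx = torch.fft.fftfreq(size)
        FX, FY = torch.meshgrid(fx, fx, indexing="ij")
        self.register_buffer("H", torch.exp(-1j * 4.0 * (FX ** 2 + FY ** 2)))

    def forward(self, field):                                # field: complex (B, H, W)
        field = field * torch.exp(1j * self.phase)           # phase modulation
        return torch.fft.ifft2(torch.fft.fft2(field) * self.H)  # crude free-space propagation

class ToyDNet(nn.Module):
    def __init__(self, size=28, n_layer=5, n_class=10):
        super().__init__()
        self.layers = nn.ModuleList(ToyDiffractiveLayer(size) for _ in range(n_layer))
        self.head = nn.Linear(size * size, n_class)          # stand-in for detector regions

    def forward(self, x):                                    # x: real images (B, H, W)
        field = x.to(torch.complex64)
        for layer in self.layers:
            field = layer(field)
        intensity = field.abs() ** 2                         # detectors measure intensity
        return self.head(intensity.flatten(1))

logits = ToyDNet()(torch.rand(4, 28, 28))                    # -> shape (4, 10)
```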
/case_brain.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Yingshi Chen
3 |
4 | @Date: 2020-04-08 17:12:34
5 | @
6 | # Description:
7 | '''
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import ToPILImage
11 | import os
12 | import math
13 | import hdf5storage
14 | from enum import Enum
15 | import re
16 | from torchvision.transforms import transforms
17 | import cv2
18 | import numpy as np
19 |
20 | def get_data_if_needed(data_path='./data/', url="https://ndownloader.figshare.com/articles/1512427/versions/5"):
21 | if os.path.isdir(data_path):
22 | #_arrange_brain_tumor_data(data_path)
23 |         print("Data directory already exists. ",
24 |               "If for some reason the data directory structure is wrong, please remove the data dir and rerun this script")
25 | return
26 | filename = "all_data.zip"
27 | download_url(url, data_path, filename)
28 | unzip_all_files(data_path)
29 | _arrange_brain_tumor_data(data_path)
30 |
31 | def convert_landmark_to_bounding_box(landmark):
32 | x_min = x_max = y_min = y_max = None
33 | for x, y in landmark:
34 | if x_min is None:
35 | x_min = x_max = x
36 | y_min = y_max = y
37 | else:
38 | x_min, x_max = min(x, x_min), max(x, x_max)
39 | y_min, y_max = min(y, y_min), max(y, y_max)
40 | return [int(x_min), int(x_max), int(y_min), int(y_max)]
41 |
42 | class ClassesLabels(Enum):
43 | Meningioma = 1
44 | Glioma = 2
45 | Pituitary = 3
46 |
47 | def __len__(self):
48 | return 3
49 |
50 | def normalize(x, mean=470, std=None):
51 |     mean_tensor = torch.ones_like(x) * mean
52 |     x -= mean_tensor
53 | if std:
54 | x /= std
55 | return x
56 |
57 | # https://github.com/galprz/brain-tumor-segmentation
58 | class BrainTumorDataset(Dataset):
59 | def __init__(self,config, root, train=True, download=True,
60 | classes=(ClassesLabels.Meningioma,
61 | ClassesLabels.Glioma,
62 | ClassesLabels.Pituitary)):
63 | super().__init__()
64 | self.config = config
65 | test_fr = 0.15
66 | if download:
67 | get_data_if_needed(root)
68 | self.root = root
69 | # List all data files
70 | items = []
71 | if ClassesLabels.Meningioma in classes:
72 | items += ['meningioma/' + item for item in os.listdir(root + 'meningioma/')]
73 | if ClassesLabels.Glioma in classes:
74 | items += ['glioma/' + item for item in os.listdir(root + 'glioma/')]
75 |         if ClassesLabels.Pituitary in classes:
76 | items += ['pituitary/' + item for item in os.listdir(root + 'pituitary/')]
77 |
78 | if train:
79 | self.items = items[0:math.floor((1-test_fr) * len(items)) + 1]
80 | else:
81 | self.items = items[math.floor((1-test_fr) * len(items)) + 1:]
82 |
83 | def __len__(self):
84 | return len(self.items)
85 |
86 | def __getitem__(self, idx):
87 | if not (0 <= idx < len(self.items)):
88 | raise IndexError("Idx out of bound")
89 | if False:
90 | data = hdf5storage.loadmat(self.root + self.items[idx])['cjdata'][0]
91 | # transform the tumor border to array of (x, y) tuple
92 | xy = data[3]
93 | landmarks = []
94 | for i in range(0, len(xy), 2):
95 | x = xy[i][0]
96 | y = xy[i + 1][0]
97 | landmarks.append((x, y))
98 | mask = data[4]
99 | data[2].dtype = 'uint16'
100 | image = data[2] #ToPILImage()(data[2])
101 | image_with_metadata = {
102 | "label": int(data[0][0]),
103 | "image": image,
104 | "landmarks": landmarks,
105 | "mask": mask,
106 | "bounding_box": convert_landmark_to_bounding_box(landmarks)
107 | }
108 | return image_with_metadata
109 | else:
110 | return load_mat_trans(self.root + self.items[idx],target_size=self.config.IMG_size ) #(128,128)
111 |
112 | def ToUint8(arr):
113 | a_0,a_1 = np.min(arr),np.max(arr)
114 | arr = (arr-a_0)/(a_1-a_0)*255
115 | arr = arr.astype(np.uint8)
116 | a_0,a_1 = np.min(arr),np.max(arr)
117 | return arr
118 |
119 | def load_mat_trans(path,target_size=None):
120 | data_mat = hdf5storage.loadmat(path)
121 | data = data_mat['cjdata'][0]
122 | xy = data[3]
123 | landmarks = []
124 | for i in range(0, len(xy), 2):
125 | x = xy[i][0]
126 | y = xy[i + 1][0]
127 | landmarks.append((x, y))
128 | mask = data[4].astype(np.float32)
129 | m_0,m_1 = np.min(mask),np.max(mask)
130 | #data[2].dtype = 'uint16'
131 | image = data[2].astype(np.float32) #ToPILImage()(data[2])
132 | if target_size is not None:
133 | image = cv2.resize(image,target_size)
134 | #cv2.imshow("",image); cv2.waitKey(0)
135 | mask = cv2.resize(mask,target_size)
136 | #cv2.imshow("",mask*255); cv2.waitKey(0)
137 | image = ToUint8(image)
138 | mask = ToUint8(mask)
139 | image_with_metadata = {
140 | "label": int(data[0][0]),
141 | "image": image,
142 | "landmarks": landmarks,
143 | "mask": mask,
144 | "bounding_box": convert_landmark_to_bounding_box(landmarks)
145 | }
146 | return image_with_metadata
147 |
148 | mask_transformer = transforms.Compose([
149 | transforms.ToTensor(),
150 | ])
151 |
152 | image_transformer_0 = transforms.Compose([
153 | transforms.ToTensor(),
154 | transforms.Lambda(lambda x: normalize(x))
155 | ])
156 | image_transformer = transforms.Compose([
157 | transforms.ToTensor(),
158 | ])
159 |
160 | class BrainTumorDatasetMask(BrainTumorDataset):
161 | def transform(self,image, mask):
162 | img = image_transformer(image).float()
163 | mask = mask_transformer(mask).float()
164 | return img,mask
165 |
166 | def __init__(self,config, root, train=True, transform=None, classes=(ClassesLabels.Meningioma,
167 | ClassesLabels.Glioma,
168 | ClassesLabels.Pituitary)):
169 | super().__init__(config,root, train, classes=classes)
170 | #self.transform = brain_transform
171 |
172 | def __getitem__(self, idx):
173 | item = super().__getitem__(idx)
174 | sample = (item["image"], item["mask"])
175 | #return sample if self.transform is None else self.transform(*sample)
176 | img,mask = self.transform(item["image"], item["mask"])
177 | #i_0,i_1 = torch.min(img),torch.max(img)
178 | #m_0,m_1 = torch.min(mask),torch.max(mask)
179 | return img,mask
180 |
181 | def _arrange_brain_tumor_data(root):
182 | # Remove and split files
183 |     items = [item for item in filter(lambda item: re.search(r"^[0-9]+\.mat$", item), os.listdir(root))]
184 | try:
185 | os.mkdir(root + 'meningioma/')
186 | except:
187 | print("Meningioma directory already exists")
188 | try:
189 | os.mkdir(root + 'glioma/')
190 | except:
191 | print("Glioma directory already exists")
192 | try:
193 | os.mkdir(root + 'pituitary/')
194 | except:
195 | print("Pituitary directory already exists")
196 |
197 | for item in items:
198 | sample = hdf5storage.loadmat(root + item)['cjdata'][0]
199 | if sample[2].shape[0] == 512:
200 | if sample[0] == 1:
201 | os.rename(root + item, root + 'meningioma/' + item)
202 | if sample[0] == 2:
203 | os.rename(root + item, root + 'glioma/' + item)
204 | if sample[0] == 3:
205 | os.rename(root + item, root + 'pituitary/' + item)
206 | else:
207 | os.remove(root + item)
208 |
209 |
--------------------------------------------------------------------------------
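For orientation, a minimal usage sketch for the dataset classes above. It assumes `./data/` already contains the arranged figshare data (so the download helpers are never invoked); `SimpleConfig` is a hypothetical stand-in for the project's config object, of which only `IMG_size` is read here.

```python
# Hypothetical usage sketch; assumes ./data/ already holds the arranged dataset.
from torch.utils.data import DataLoader

class SimpleConfig:                  # stand-in for the real config; only IMG_size is read
    IMG_size = (128, 128)

train_set = BrainTumorDatasetMask(SimpleConfig(), root="./data/", train=True)
loader = DataLoader(train_set, batch_size=8, shuffle=True)
for img, mask in loader:
    print(img.shape, mask.shape)     # both (8, 1, 128, 128) after ToTensor
    break
```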
/case_cifar.py:
--------------------------------------------------------------------------------
1 | '''
2 | Train CIFAR10 with PyTorch.
3 | https://github.com/kuangliu/pytorch-cifar
4 |
5 | https://medium.com/@wwwbbb8510/lessons-learned-from-reproducing-resnet-and-densenet-on-cifar-10-dataset-6e25b03328da
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | import torch.backends.cudnn as cudnn
12 | import torchvision
13 | import torchvision.transforms as transforms
14 | import os
15 | import sys
16 | import argparse
17 | CNN_MODEL_root = os.path.dirname(os.path.abspath(__file__))+"/python-package"
18 | sys.path.append(CNN_MODEL_root)
19 | from cnn_models import *
20 | ONNET_DIR = os.path.abspath("./python-package/")
21 | sys.path.append(ONNET_DIR) # To find local version of the onnet
22 | from onnet import *
23 | from onnet.OpticalFormer import clip_grad
24 | import sys
25 | import time
26 | import torch.nn as nn
27 | import torch.nn.init as init
28 |
29 |
30 | # The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. The dataset is divided into five training batches and one test batch, each with 10000 images.
31 | IMG_size = (32, 32)
32 | IMG_size = (96, 96)   # overrides the line above: CIFAR images are upsampled to 96x96
33 | isDNet = False
34 | isGrayScale = False
35 |
36 | def get_mean_and_std(dataset):
37 | '''Compute the mean and std value of dataset.'''
38 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
39 | mean = torch.zeros(3)
40 | std = torch.zeros(3)
41 | print('==> Computing mean and std..')
42 | for inputs, targets in dataloader:
43 | for i in range(3):
44 | mean[i] += inputs[:,i,:,:].mean()
45 | std[i] += inputs[:,i,:,:].std()
46 | mean.div_(len(dataset))
47 | std.div_(len(dataset))
48 | return mean, std
49 |
50 | def init_params(net):
51 |     '''Init layer parameters.'''
52 |     for m in net.modules():
53 |         if isinstance(m, nn.Conv2d):
54 |             init.kaiming_normal_(m.weight, mode='fan_out')
55 |             if m.bias is not None:
56 |                 init.constant_(m.bias, 0)
57 |         elif isinstance(m, nn.BatchNorm2d):
58 |             init.constant_(m.weight, 1)
59 |             init.constant_(m.bias, 0)
60 |         elif isinstance(m, nn.Linear):
61 |             init.normal_(m.weight, std=1e-3)
62 |             if m.bias is not None:
63 |                 init.constant_(m.bias, 0)
64 |
65 | #_, term_width = os.popen('stty size', 'r').read().split()
66 | term_width = 80
67 | TOTAL_BAR_LENGTH = 25.
68 | last_time = time.time()
69 | begin_time = last_time
70 | def progress_bar(current, total, msg=None):
71 | if current < total-1:
72 | sys.stdout.write('\r')
73 | global last_time, begin_time
74 | if current == 0:
75 | begin_time = time.time() # Reset for new bar.
76 |
77 | cur_len = int(TOTAL_BAR_LENGTH*current/total)
78 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
79 |
80 | sys.stdout.write(' [')
81 | for i in range(cur_len):
82 | sys.stdout.write('=')
83 | sys.stdout.write('>')
84 | for i in range(rest_len):
85 | sys.stdout.write('.')
86 | sys.stdout.write(']')
87 |
88 | cur_time = time.time()
89 | step_time = cur_time - last_time
90 | last_time = cur_time
91 | tot_time = cur_time - begin_time
92 |
93 | L = []
94 | L.append(' Step: %s' % format_time(step_time))
95 | L.append(' | Tot: %s' % format_time(tot_time))
96 | if msg:
97 | L.append(' | ' + msg)
98 |
99 | msg = ''.join(L)
100 | sys.stdout.write(msg)
101 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
102 | sys.stdout.write(' ')
103 |
104 | if False:
105 | # Go back to the center of the bar.
106 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
107 | sys.stdout.write('\b')
108 | sys.stdout.write(' %d/%d ' % (current+1, total))
109 |
110 | if current < total-1:
111 | pass#sys.stdout.write('\r')
112 | else:
113 | sys.stdout.write('\n')
114 | sys.stdout.flush()
115 |
116 | def format_time(seconds):
117 | days = int(seconds / 3600/24)
118 | seconds = seconds - days*3600*24
119 | hours = int(seconds / 3600)
120 | seconds = seconds - hours*3600
121 | minutes = int(seconds / 60)
122 | seconds = seconds - minutes*60
123 | secondsf = int(seconds)
124 | seconds = seconds - secondsf
125 | millis = int(seconds*1000)
126 |
127 | f = ''
128 | i = 1
129 | if days > 0:
130 | f += str(days) + 'D'
131 | i += 1
132 | if hours > 0 and i <= 2:
133 | f += str(hours) + 'h'
134 | i += 1
135 | if minutes > 0 and i <= 2:
136 | f += str(minutes) + 'm'
137 | i += 1
138 | if secondsf > 0 and i <= 2:
139 | f += str(secondsf) + 's'
140 | i += 1
141 | if millis > 0 and i <= 2:
142 | f += str(millis) + 'ms'
143 | i += 1
144 | if f == '':
145 | f = '0ms'
146 | return f
147 |
148 |
149 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
150 | parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
151 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
152 | # "--gradient_clip=agc",
153 | # "--self_attention=gabor"
154 | args = parser.parse_args()
155 |
156 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
157 | best_acc = 0 # best test accuracy
158 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
159 |
160 | # Data
161 | def Init():
162 | print('==> Preparing data..')
163 | transform_train = transforms.Compose([
164 | transforms.RandomCrop(32, padding=4),
165 | # transforms.Grayscale(),
166 | transforms.RandomHorizontalFlip(),
167 | transforms.Resize(IMG_size),
168 | transforms.ToTensor(),
169 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
170 | # transforms.Normalize(0.48, 0.20),
171 | ])
172 |
173 | transform_test = transforms.Compose([
174 | # transforms.Grayscale(),
175 | transforms.Resize(IMG_size),
176 | transforms.ToTensor(),
177 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
178 | # transforms.Normalize(0.48, 0.20),
179 | ])
180 |
181 | trainset = torchvision.datasets.CIFAR10(root='/home/cys/Downloads/cifar10/', train=True, download=True, transform=transform_train)
182 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
183 |
184 | testset = torchvision.datasets.CIFAR10(root='/home/cys/Downloads/cifar10/', train=False, download=True, transform=transform_test)
185 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
186 |
187 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
188 | # Model
189 | print('==> Building model..')
190 | if isDNet:
191 | #config_0 = RGBO_CNN_config("RGBO_CNN", 'cifar_10', IMG_size, lr_base=args.lr, batch_size=128, nClass=10, nLayer=5)
192 | #env_title, net = RGBO_CNN_instance(config_0)
193 | config_0 = NET_config("DNet",'cifar_10',IMG_size,lr_base=args.lr,batch_size=128, nClass=10, nLayer=10)
194 | env_title, net = DNet_instance(config_0)
195 | config_base = net.config
196 | else:
197 | config_0 = NET_config("OptFormer", 'cifar_10', IMG_size, lr_base=args.lr, batch_size=128, nClass=10)
198 | # net = VGG('VGG19')
199 | #net = ResNet34(); env_title='ResNet34'; net.legend = 'ResNet34'
200 | # net = OpticalNet34(config_0); env_title = 'OpticalNet34'; net.legend = 'OpticalNet34'
201 | env_title, net = DNet_instance(config_0)
202 | # net = PreActResNet18()
203 | # net = GoogLeNet()
204 | # net = DenseNet121()
205 | # net = ResNeXt29_2x64d()
206 | # net = MobileNet()
207 | # net = MobileNetV2()
208 | # net = DPN92(); env_title='DPN92'; net.legend = 'DPN92'
209 | # net = DPN26(); env_title = 'DPN92'; net.legend = 'DPN92'
210 | # net = ShuffleNetG2()
211 | # net = SENet18()
212 | # net = ShuffleNetV2(1)
213 | # net = EfficientNetB0(); env_title='EfficientNetB0'
214 | #visual = Visdom_Visualizer(env_title=env_title)
215 |
216 | print(net)
217 | Net_dump(net)
218 | net = net.to(device)
219 | visual = Visdom_Visualizer(env_title=env_title)
220 |     #if hasattr(net, 'DInput'): net.DInput.visual = visual  # take a look at the input
221 |
222 | if device == 'cuda':
223 | pass
224 | #net = torch.nn.DataParallel(net) #https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
225 |         #cudnn.benchmark = True  # results become nondeterministic; see https://zhuanlan.zhihu.com/p/73711222
226 |
227 | if args.resume:
228 | # Load checkpoint.
229 | print('==> Resuming from checkpoint..')
230 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
231 | checkpoint = torch.load('./checkpoint/ckpt.pth')
232 | net.load_state_dict(checkpoint['net'])
233 | best_acc = checkpoint['acc']
234 | start_epoch = checkpoint['epoch']
235 |
236 | criterion = nn.CrossEntropyLoss()
237 | #using SGD with scheduled learning rate much better than Adam
238 | optimizer = optim.Adam(net.parameters(), lr=args.lr) # weight_decay=0.0005
239 | #optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
240 |
241 | return net,trainloader,testloader,optimizer,criterion,visual
242 |
243 | # Training
244 | def train(epoch,net,trainloader,optimizer,criterion):
245 | print('\nEpoch: %d' % epoch)
246 | if epoch==0:
247 | #print(f"\n=======dataset={dataset} net={net_type} IMG_size={IMG_size} batch_size={batch_size}")
248 | #print(f"======={net.config}")
249 | print(f"======={optimizer}")
250 | #print(f"======={train_trans}\n")
251 | net.train()
252 | train_loss = 0
253 | correct = 0
254 | total = 0
255 | for batch_idx, (inputs, targets) in enumerate(trainloader):
256 | inputs, targets = inputs.to(device), targets.to(device)
257 | optimizer.zero_grad()
258 | outputs = net(inputs)
259 | loss = criterion(outputs, targets)
260 | loss.backward() #retain_graph=True
261 |         if net.clip_grad == "agc":   # adaptive gradient clipping; a sketch follows this file
262 | clip_grad(net)
263 | optimizer.step()
264 |
265 | train_loss += loss.item()
266 | _, predicted = outputs.max(1)
267 | total += targets.size(0)
268 | correct += predicted.eq(targets).sum().item()
269 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
270 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
271 | #break
272 |
273 |
274 | def test(epoch,net,testloader,criterion,visual):
275 | global best_acc
276 | net.eval()
277 | test_loss = 0
278 | correct = 0
279 | total = 0
280 | with torch.no_grad():
281 | for batch_idx, (inputs, targets) in enumerate(testloader):
282 | inputs, targets = inputs.to(device), targets.to(device)
283 | outputs = net(inputs)
284 | loss = criterion(outputs, targets)
285 |
286 | test_loss += loss.item()
287 | _, predicted = outputs.max(1)
288 | total += targets.size(0)
289 | correct += predicted.eq(targets).sum().item()
290 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
291 | % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
292 | #break
293 |
294 | # Save checkpoint.
295 | acc = 100.*correct/total
296 | legend = "resnet"#net.module.legend()
297 | visual.UpdateLoss(title=f"Accuracy on \"cifar_10\"", legend=f"{legend}", loss=acc, yLabel="Accuracy")
298 | if False and acc > best_acc:
299 | print('Saving..')
300 | state = {
301 | 'net': net.state_dict(),
302 | 'acc': acc,
303 | 'epoch': epoch,
304 | }
305 | if not os.path.isdir('checkpoint'):
306 | os.mkdir('checkpoint')
307 | torch.save(state, './checkpoint/ckpt.pth')
308 | best_acc = acc
309 |
310 | if __name__ == '__main__':
311 | seed_everything(42)
312 | net,trainloader,testloader,optimizer,criterion,visual = Init()
313 | #legend = net.module.legend()
314 |
315 | for epoch in range(start_epoch, start_epoch+2000):
316 | train(epoch,net,trainloader,optimizer,criterion)
317 | test(epoch,net,testloader,criterion,visual)
318 |
--------------------------------------------------------------------------------
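The training loop above calls `clip_grad(net)` from `onnet.OpticalFormer` whenever `net.clip_grad == "agc"`. For readers unfamiliar with the term, below is a sketch of adaptive gradient clipping in the spirit of Brock et al. (2021), simplified to per-tensor norms. It is only an assumption that ONNet's `clip_grad` behaves similarly, so treat this purely as a reference.

```python
# Sketch of adaptive gradient clipping (AGC), per-tensor variant; not necessarily
# what onnet.OpticalFormer.clip_grad does.
import torch

def agc_(model, clip_factor=0.01, eps=1e-3):
    for p in model.parameters():
        if p.grad is None:
            continue
        p_norm = p.detach().norm().clamp(min=eps)        # parameter scale
        g_norm = p.grad.detach().norm()                  # gradient scale
        max_norm = clip_factor * p_norm
        if g_norm > max_norm:                            # clip g so ||g|| <= clip_factor * ||p||
            p.grad.detach().mul_(max_norm / (g_norm + 1e-6))
```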
/case_covir.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Yingshi Chen
3 | https://github.com/lindawangg/COVID-Net/blob/master/create_COVIDx_v2.ipynb
4 | @Date: 2020-04-06 15:50:21
5 | @
6 | # Description:
7 | '''
8 | import numpy as np
9 | import pandas as pd
10 | import os
11 | import random
12 | from shutil import copyfile
13 | import pydicom as dicom
14 | import cv2
15 | from torch.utils.data import Dataset,DataLoader
16 | from torch.optim.lr_scheduler import ReduceLROnPlateau
17 | from torch.nn import CrossEntropyLoss
18 | from PIL import Image
19 | import logging
20 | import sys
21 | import time
22 | ONNET_DIR = os.path.abspath("./python-package/")
23 | sys.path.append(ONNET_DIR) # To find local version of the onnet
24 | #sys.path.append(os.path.abspath("./python-package/cnn_models/"))
25 | from cnn_models.COVIDNext50 import COVIDNext50
26 | from onnet import *
27 | import torch
28 | from torch.optim import Adam
29 | from torchvision import transforms
30 | from sklearn.metrics import f1_score, precision_score, recall_score,accuracy_score,classification_report
31 |
32 | isONN=True
33 | class COVID_set(Dataset):
34 | def __init__(self, config,img_dir, labels_file, transforms):
35 | self.config = config
36 | self.img_pths, self.labels = self._prepare_data(img_dir, labels_file)
37 | self.transforms = transforms
38 |
39 |
40 | def _prepare_data(self, img_dir, labels_file):
41 | with open(labels_file, 'r') as f:
42 | labels_raw = f.readlines()
43 |
44 | labels, img_pths = [], []
45 | for i in range(len(labels_raw)):
46 | data = labels_raw[i].split()
47 | img_pth = data[1]
48 | #img_name = data[1]
49 | #img_pth = os.path.join(img_dir, img_name)
50 | img_pths.append(img_pth)
51 | labels.append(self.config.mapping[data[2]])
52 |
53 | return img_pths, labels
54 |
55 | def __len__(self):
56 | return len(self.labels)
57 |
58 | def __getitem__(self, idx):
59 | img = Image.open(self.img_pths[idx]).convert("RGB")
60 | img_tensor = self.transforms(img)
61 |
62 | label = self.labels[idx]
63 | label_tensor = torch.tensor(label, dtype=torch.long)
64 |
65 | return img_tensor, label_tensor
66 |
67 | def train_test_split():
68 | seed = 0
69 | np.random.seed(seed) # Reset the seed so all runs are the same.
70 | random.seed(seed)
71 | MAXVAL = 255 # Range [0 255]
72 |
73 | # path to covid-19 dataset from https://github.com/ieee8023/covid-chestxray-dataset
74 | imgpath = 'E:/Insegment/covid-chestxray-dataset-master/images'
75 | csvpath = 'E:/Insegment/covid-chestxray-dataset-master/metadata.csv'
76 |
77 | # path to https://www.kaggle.com/c/rsna-pneumonia-detection-challenge
78 | kaggle_datapath = 'F:/Datasets/rsna-pneumonia-detection-challenge/'
79 | kaggle_csvname = 'stage_2_detailed_class_info.csv' # get all the normal from here
80 | kaggle_csvname2 = 'stage_2_train_labels.csv' # get all the 1s from here since 1 indicate pneumonia
81 | kaggle_imgpath = 'stage_2_train_images'
82 |
83 | # parameters for COVIDx dataset
84 | train = []
85 | test = []
86 | test_count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}
87 | train_count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}
88 |
89 | mapping = dict()
90 | mapping['COVID-19'] = 'COVID-19'
91 | mapping['SARS'] = 'pneumonia'
92 | mapping['MERS'] = 'pneumonia'
93 | mapping['Streptococcus'] = 'pneumonia'
94 | mapping['Normal'] = 'normal'
95 | mapping['Lung Opacity'] = 'pneumonia'
96 | mapping['1'] = 'pneumonia'
97 |
98 | train_file = open("train_split_v2.txt","a")
99 | test_file = open("test_split_v2.txt", "a")
100 | # train/test split
101 | split = 0.1
102 | csv = pd.read_csv(csvpath, nrows=None)
103 | idx_pa = csv["view"] == "PA" # Keep only the PA view
104 | csv = csv[idx_pa]
105 |
106 | pneumonias = ["COVID-19", "SARS", "MERS", "ARDS", "Streptococcus"]
107 | pathologies = ["Pneumonia","Viral Pneumonia", "Bacterial Pneumonia", "No Finding"] + pneumonias
108 | pathologies = sorted(pathologies)
109 |
110 | filename_label = {'normal': [], 'pneumonia': [], 'COVID-19': []}
111 | count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}
112 | for index, row in csv.iterrows():
113 | f = row['finding']
114 | if f in mapping:
115 | count[mapping[f]] += 1
116 | entry = [int(row['patientid']), row['filename'], mapping[f]]
117 | filename_label[mapping[f]].append(entry)
118 |
119 | print('Data distribution from covid-chestxray-dataset:')
120 | print(count)
121 |
122 | for key in filename_label.keys():
123 | arr = np.array(filename_label[key])
124 | if arr.size == 0:
125 | continue
126 | # split by patients
127 | # num_diff_patients = len(np.unique(arr[:,0]))
128 | # num_test = max(1, round(split*num_diff_patients))
129 | # select num_test number of random patients
130 | if key == 'pneumonia':
131 | test_patients = ['8', '31']
132 | elif key == 'COVID-19':
133 | test_patients = ['19', '20', '36', '42', '86'] # random.sample(list(arr[:,0]), num_test)
134 | else:
135 | test_patients = []
136 | print('Key: ', key)
137 | print('Test patients: ', test_patients)
138 | # go through all the patients
139 | for patient in arr:
140 |             info = f"{patient[0]} {imgpath}/{patient[1]} {patient[2]}\n"
141 |             if patient[0] in test_patients:
142 |                 #copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'test', patient[1]))
143 |                 test.append(patient); test_count[patient[2]] += 1
144 |                 test_file.write(info)
145 |             else:
146 |                 #copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'train', patient[1]))
147 |                 train.append(patient); train_count[patient[2]] += 1
148 |                 train_file.write(info)
149 |
150 |
151 | csv_normal = pd.read_csv(os.path.join(kaggle_datapath, kaggle_csvname), nrows=None)
152 | csv_pneu = pd.read_csv(os.path.join(kaggle_datapath, kaggle_csvname2), nrows=None)
153 | patients = {'normal': [], 'pneumonia': []}
154 |
155 | for index, row in csv_normal.iterrows():
156 | if row['class'] == 'Normal':
157 | patients['normal'].append(row['patientId'])
158 |
159 | for index, row in csv_pneu.iterrows():
160 | if int(row['Target']) == 1:
161 | patients['pneumonia'].append(row['patientId'])
162 |
163 | for key in patients.keys():
164 | arr = np.array(patients[key])
165 | if arr.size == 0:
166 | continue
167 | # split by patients
168 | num_diff_patients = len(np.unique(arr))
169 | num_test = max(1, round(split*num_diff_patients))
170 | #test_patients = np.load('rsna_test_patients_{}.npy'.format(key)) #
171 | test_patients = random.sample(list(arr), num_test) #, download the .npy files from the repo.
172 | np.save('rsna_test_patients_{}.npy'.format(key), np.array(test_patients))
173 | for patient in arr:
174 | ds = dicom.dcmread(os.path.join(kaggle_datapath, kaggle_imgpath, patient + '.dcm'))
175 | pixel_array_numpy = ds.pixel_array
176 | imgname = patient + '.png'
177 | if patient in test_patients:
178 | path = os.path.join(kaggle_datapath, 'test', imgname)
179 | cv2.imwrite(path, pixel_array_numpy)
180 | test.append([patient, imgname, key]); test_count[key] += 1
181 | test_file.write(f"{patient} {path} {key}\n" )
182 | if test_count[key]%50==0:
183 | test_file.flush()
184 | else:
185 | path = os.path.join(kaggle_datapath, 'train', imgname)
186 | cv2.imwrite(path, pixel_array_numpy)
187 | train_file.write(f"{patient} {path} {key}\n")
188 | if train_count[key]%20==0:
189 | train_file.flush()
190 | train.append([patient, imgname, key]); train_count[key] += 1
191 | print(f"\r@{path}",end="")
192 |
193 | print('Final stats')
194 | print('Train count: ', train_count)
195 | print('Test count: ', test_count)
196 | print('Total length of train: ', len(train))
197 | print('Total length of test: ', len(test))
198 |
199 | train_file.close()
200 | test_file.close()
201 |
202 |
203 |
204 | log = logging.getLogger(__name__)
205 | logging.basicConfig(level=logging.INFO)
206 |
207 |
208 | def save_model(model, config):
209 | if isinstance(model, torch.nn.DataParallel):
210 | # Save without the DataParallel module
211 | model_dict = model.module.state_dict()
212 | else:
213 | model_dict = model.state_dict()
214 |
215 | state = {
216 | "state_dict": model_dict,
217 | "global_step": config['global_step'],
218 | "clf_report": config['clf_report']
219 | }
220 | f1_macro = config['clf_report']['macro avg']['f1-score'] * 100
221 | name = "{}_F1_{:.2f}_step_{}.pth".format(config['name'],
222 | f1_macro,
223 | config['global_step'])
224 | model_path = os.path.join(config['save_dir'], name)
225 | torch.save(state, model_path)
226 | log.info("Saved model to {}".format(model_path))
227 |
228 |
229 | def validate(data_loader, model, best_score, global_step, cfg):
230 | model.eval()
231 | gts, predictions = [], []
232 |
233 | log.info("Validation started...")
234 | for data in data_loader:
235 | imgs, labels = data
236 | imgs = to_device(imgs, gpu=cfg.gpu)
237 |
238 | with torch.no_grad():
239 | logits = model(imgs)
240 | if isONN:
241 |                 preds = model.predict(logits).cpu().numpy()
242 | else:
243 | probs = model.module.probability(logits)
244 | preds = torch.argmax(probs, dim=1).cpu().numpy()
245 |
246 | labels = labels.cpu().detach().numpy()
247 | predictions.extend(preds)
248 | gts.extend(labels)
249 |
250 | predictions = np.array(predictions, dtype=np.int32)
251 | gts = np.array(gts, dtype=np.int32)
252 | acc, f1, prec, rec = clf_metrics(predictions=predictions,targets=gts,average="macro")
253 | report = classification_report(gts, predictions, output_dict=True)
254 | log.info("\n====== VALIDATION | Accuracy {:.4f} | F1 {:.4f} | Precision {:.4f} | Recall {:.4f}".format(acc, f1, prec, rec))
255 |
256 | if f1 > best_score:
257 | save_config = {
258 | 'name': config.name,
259 | 'save_dir': config.ckpts_dir,
260 | 'global_step': global_step,
261 | 'clf_report': report
262 | }
263 | #save_model(model=model, config=save_config)
264 | best_score = f1
265 | #log.info("Validation end")
266 | model.train()
267 | return best_score
268 |
269 | def train_transforms(width, height):
270 | trans_list = [
271 | transforms.Resize((height, width)),
272 | transforms.RandomVerticalFlip(p=0.5),
273 | transforms.RandomHorizontalFlip(p=0.5),
274 | transforms.RandomApply([
275 | transforms.RandomAffine(degrees=20,
276 | translate=(0.15, 0.15),
277 | scale=(0.8, 1.2),
278 | shear=5)], p=0.5),
279 | transforms.RandomApply([
280 | transforms.ColorJitter(brightness=0.3, contrast=0.3)], p=0.5),
281 | transforms.Grayscale(),
282 | transforms.ToTensor()
283 | ]
284 | return transforms.Compose(trans_list)
285 |
286 |
287 | def val_transforms(width, height):
288 | trans_list = [
289 | transforms.Resize((height, width)),
290 | transforms.Grayscale(),
291 | transforms.ToTensor()
292 | ]
293 | return transforms.Compose(trans_list)
294 |
295 | def to_device(tensor, gpu=False):
296 | return tensor.cuda() if gpu else tensor.cpu()
297 |
298 | def clf_metrics(predictions, targets, average='macro'):
299 | f1 = f1_score(targets, predictions, average=average)
300 | precision = precision_score(targets, predictions, average=average)
301 | recall = recall_score(targets, predictions, average=average)
302 | acc = accuracy_score(targets, predictions)
303 |
304 | return acc, f1, precision, recall
305 |
306 | def main(model):
307 | if config.gpu and not torch.cuda.is_available():
308 | raise ValueError("GPU not supported or enabled on this system.")
309 | use_gpu = config.gpu
310 |
311 | log.info("Loading train dataset")
312 | train_dataset = COVID_set(config,config.train_imgs, config.train_labels,train_transforms(config.width,config.height))
313 | train_loader = DataLoader(train_dataset,
314 | batch_size=config.batch_size,shuffle=True,drop_last=True, num_workers=config.n_threads,pin_memory=use_gpu)
315 | log.info("Number of training examples {}".format(len(train_dataset)))
316 |
317 | log.info("Loading val dataset")
318 | val_dataset = COVID_set(config,config.val_imgs, config.val_labels,val_transforms(config.width,config.height))
319 | val_loader = DataLoader(val_dataset,
320 | batch_size=config.batch_size,
321 | shuffle=False,
322 | num_workers=config.n_threads,
323 | pin_memory=use_gpu)
324 | log.info("Number of validation examples {}".format(len(val_dataset)))
325 |
326 | if use_gpu:
327 | model.cuda()
328 | #model = torch.nn.DataParallel(model)
329 | optim_layers = filter(lambda p: p.requires_grad, model.parameters())
330 |
331 | # optimizer and lr scheduler
332 | optimizer = Adam(optim_layers,
333 | lr=config.lr,
334 | weight_decay=config.weight_decay)
335 | scheduler = ReduceLROnPlateau(optimizer=optimizer,
336 | factor=config.lr_reduce_factor,
337 | patience=config.lr_reduce_patience,
338 | mode='max',
339 | min_lr=1e-7)
340 |
341 | # Load the last global_step from the checkpoint if existing
342 | global_step = 0 if state is None else state['global_step'] + 1
343 |
344 | class_weights = to_device(torch.FloatTensor(config.loss_weights),gpu=use_gpu)
345 | loss_fn = CrossEntropyLoss(reduction='mean', weight=class_weights)
346 |
347 | # Reset the best metric score
348 | best_score = -1
349 | t0=time.time()
350 | for epoch in range(config.epochs):
351 | log.info("\nStarted epoch {}/{}".format(epoch + 1,config.epochs))
352 | for data in train_loader:
353 | imgs, labels = data
354 | imgs = to_device(imgs, gpu=use_gpu)
355 | labels = to_device(labels, gpu=use_gpu)
356 |
357 | logits = model(imgs)
358 | loss = loss_fn(logits, labels)
359 | optimizer.zero_grad()
360 | loss.backward()
361 | optimizer.step()
362 |
363 | if global_step % config.log_steps == 0 and global_step > 0:
364 | if isONN:
365 |                     preds = model.predict(logits).cpu().numpy()
366 | else:
367 | probs = model.module.probability(logits)
368 | preds = torch.argmax(probs, dim=1).detach().cpu().numpy()
369 | labels = labels.cpu().detach().numpy()
370 | acc, f1, _, _ = clf_metrics(preds, labels)
371 | lr = optimizer.param_groups[0]['lr'] #get_learning_rate(optimizer)
372 | print(f"\r{global_step} | batch: Loss={loss.item():.3f} | F1={f1:.3f} | Accuracy={acc:.4f} | LR={lr:.2e}\tT={time.time()-t0:.4f}",end="")
373 |
374 |
375 | if global_step % config.eval_steps == 0 and global_step > 0:
376 | best_score = validate(val_loader, model,best_score=best_score,global_step=global_step,cfg=config)
377 | scheduler.step(best_score)
378 | global_step += 1
379 |
380 | def UpdateConfig(config):
381 | config.name = "COVIDNext50_NewData"
382 | config.gpu = True
383 | config.batch_size = 16
384 | config.n_threads = 4
385 | config.random_seed = 1337
386 | config.weights = "E:/Insegment/COVID-Next-Pytorch-master/COVIDNext50_NewData_F1_92.98_step_10800.pth"
387 | config.lr = 1e-4
388 | config.weight_decay = 1e-3
389 | config.lr_reduce_factor = 0.7
390 | config.lr_reduce_patience = 5
391 | # Data
392 | config.train_imgs = None#"/data/ssd/datasets/covid/COVIDxV2/data/train"
393 | config.train_labels = "E:/ONNet/data/covid_train_split_v2.txt" #"/data/ssd/datasets/covid/COVIDxV2/data/train_COVIDx.txt"
394 | config.val_imgs = None#"/data/ssd/datasets/covid/COVIDxV2/data/test"
395 | config.val_labels = "E:/ONNet/data/covid_test_split_v2.txt" #"/data/ssd/datasets/covid/COVIDxV2/data/test_COVIDx.txt"
396 | # Categories mapping
397 | config.mapping = {
398 | 'normal': 0,
399 | 'pneumonia': 1,
400 | 'COVID-19': 2
401 | }
402 | # Loss weigths order follows the order in the category mapping dict
403 | config.loss_weights = [0.05, 0.05, 1.0]
404 |
405 | config.width = 256
406 | config.height = 256
407 | config.n_classes = len(config.mapping)
408 | # Training
409 | config.epochs = 300
410 | config.log_steps = 5
411 | config.eval_steps = 400
412 | config.ckpts_dir = "./experiments/ckpts"
413 | return config
414 |
415 | IMG_size = (256, 256)
416 | if __name__ == '__main__':
417 | config_0 = NET_config("DNet",'covid',IMG_size,0.01,batch_size=16, nClass=3, nLayer=5)
418 | #config_0 = RGBO_CNN_config("RGBO_CNN",'covid',IMG_size,0.01,batch_size=16, nClass=3, nLayer=5)
419 | if isONN:
420 | env_title, net = DNet_instance(config_0)
421 | #env_title, net = RGBO_CNN_instance(config_0)
422 | config = net.config
423 | config = UpdateConfig(config)
424 | config.batch_size = 64
425 | config.log_steps = 10
426 | config.lr = 0.001
427 | state = None
428 | else:
429 | config = UpdateConfig(config_0)
430 | if config.weights:
431 | state = torch.load(config.weights)
432 | log.info("Loaded model weights from: {}".format(config.weights))
433 | else:
434 | state = None
435 |
436 | state_dict = state["state_dict"] if state else None
437 | net = COVIDNext50(n_classes=config.n_classes)
438 | if state_dict:
439 | net = load_model_weights(model=net, state_dict=state_dict,log=log)
440 | print(net)
441 | Net_dump(net)
442 | seed_everything(config.random_seed)
443 | main(net)
444 |
--------------------------------------------------------------------------------
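`COVID_set._prepare_data` above splits each line of the labels file on whitespace and reads the image path from field 1 and the class name from field 2, matching what `train_test_split()` writes. A tiny illustration with hypothetical paths:

```python
# Illustration of the expected "<patient_id> <image_path> <class_name>" label-file
# format; the paths below are hypothetical.
example_labels = """\
19 E:/Insegment/covid-chestxray-dataset-master/images/patient19.png COVID-19
42 F:/Datasets/rsna-pneumonia-detection-challenge/train/patient42.png normal
"""
for line in example_labels.splitlines():
    pid, img_pth, label = line.split()        # same fields COVID_set reads
    print(pid, img_pth, label)
```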
/case_dog_cat.py:
--------------------------------------------------------------------------------
1 | '''
2 | 1) https://github.com/rdcolema/pytorch-image-classification/blob/master/pytorch_model.ipynb
3 | https://github.com/mukul54/A-Simple-Cat-vs-Dog-Classifier-in-Pytorch/blob/master/catVsDog.py
4 | '''
5 | # https://github.com/mukul54/A-Simple-Cat-vs-Dog-Classifier-in-Pytorch/blob/master/catVsDog.py
6 |
7 | import numpy as np              # matrix operations (the "MATLAB" of Python)
8 | import pandas as pd             # work with data sources
9 | import matplotlib.pyplot as plt # plotting library
10 | from PIL import Image
11 | import torch                    # like NumPy, but tensors can run on the GPU
12 | import torch.nn as nn           # neural-network modules implemented in PyTorch
13 | import torchvision              # pretrained models and datasets
14 | from torchvision import transforms
15 | from torch.utils.data import Dataset
16 | from torch.utils.data import DataLoader
17 | import torch.nn.functional as F
18 | import glob
19 | import os
20 |
21 | image_size = (100, 100)
22 | image_row_size = image_size[0] * image_size[1]
23 |
24 | if False: #https://medium.com/predict/using-pytorch-for-kaggles-famous-dogs-vs-cats-challenge-part-1-preprocessing-and-training-407017e1a10c
25 | import shutil
26 | import re
27 | files = os.listdir(train_dir)
28 | # Move all train cat images to cats folder, dog images to dogs folder
29 | for f in files:
30 | catSearchObj = re.search("cat", f)
31 | dogSearchObj = re.search("dog", f)
32 | if catSearchObj:
33 | shutil.move(f'{train_dir}/{f}', train_cats_dir)
34 | elif dogSearchObj:
35 | shutil.move(f'{train_dir}/{f}', train_dogs_dir)
36 | pass
37 |
38 | class CatDogDataset(Dataset):
39 | def __init__(self, path, transform=None):
40 | self.classes = ["cat","dog"] #os.listdir(path)
41 | self.path = path #[f"{path}/{className}" for className in self.classes]
42 | #self.file_list = [glob.glob(f"{x}/*") for x in self.path]
43 | self.transform = transform
44 |
45 | files = []
46 | for i, className in enumerate(self.classes):
47 | query = f"{self.path}{className}*"
48 | cls_list = glob.glob(query)
49 | print(f"{className}:n={len(cls_list)}")
50 | for fileName in cls_list:
51 | files.append([i, className, fileName])
52 | self.file_list = files
53 | files = None
54 |
55 | def __len__(self):
56 | return len(self.file_list)
57 |
58 | def __getitem__(self, idx):
59 | fileName = self.file_list[idx][2]
60 | classCategory = self.file_list[idx][0]
61 | im = Image.open(fileName)
62 | if self.transform:
63 | im = self.transform(im)
64 | return im.view(-1), classCategory
65 |
66 | #mean = [0.485, 0.456, 0.406]; std = [0.229, 0.224, 0.225]
67 | mean = [0.485]; std = [0.229]
68 | transform = transforms.Compose([
69 | transforms.Resize(image_size),
70 | transforms.Grayscale(),
71 | transforms.ToTensor(),
72 | transforms.Normalize(mean, std)])
73 |
74 | path = '../data/dog_cat/train/'
75 | dataset = CatDogDataset(path, transform=transform)
76 | if True:
77 | def imshow(source):
78 | plt.figure(figsize=(10,10))
79 | imt = (source.view(-1, image_size[0], image_size[0]))
80 | imt = imt.numpy().transpose([1,2,0])
81 | imt = (std * imt + mean).clip(0,1)
82 | plt.subplot(1,2,2)
83 | plt.imshow(imt.squeeze())
84 | imshow(dataset[0][0])
85 | imshow(dataset[2][0])
86 | imshow(dataset[6000][0])
87 | plt.show()
88 |
89 | shuffle = True
90 | batch_size = 64
91 | num_workers = 0
92 | dataloader = DataLoader(dataset=dataset,
93 | shuffle=shuffle,
94 | batch_size=batch_size,
95 | num_workers=num_workers)
96 |
97 | class MyModel(torch.nn.Module):
98 | def __init__(self, in_feature):
99 | super(MyModel, self).__init__()
100 | self.fc1 = torch.nn.Linear(in_features=in_feature, out_features=500)
101 | self.fc2 = torch.nn.Linear(in_features=500, out_features=100)
102 | self.fc3 = torch.nn.Linear(in_features=100, out_features=1)
103 |
104 |     def forward(self, x):
105 |         x = F.relu( self.fc1(x) )
106 |         x = F.relu( self.fc2(x) )
107 |         x = self.fc3(x)   # raw logits; CrossEntropyLoss applies log-softmax itself
108 |         return x
109 |
110 | model = MyModel(image_row_size)
111 | print(model)
112 |
113 | criterion = torch.nn.CrossEntropyLoss()
114 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.95)
115 |
116 | epochs = 10
117 | for epoch in range(epochs):
118 |     for i, (X, Y) in enumerate(dataloader):
119 |         optimizer.zero_grad()
120 |         yhat = model(X)               # (batch, 2) logits
121 |         loss = criterion(yhat, Y)     # class-index targets
122 |         loss.backward()
123 |         optimizer.step()
--------------------------------------------------------------------------------
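A side note on the head of `MyModel` above: with two classes and `CrossEntropyLoss`, the network should emit two raw logits (the loss applies log-softmax itself). An equally common alternative is a single logit trained with `BCEWithLogitsLoss`; a small sketch, not taken from this repo:

```python
# Alternative binary head: one logit + BCEWithLogitsLoss (sketch, not repo code).
import torch
import torch.nn as nn

head = nn.Linear(100, 1)                               # replaces the 2-logit fc3
criterion_bce = nn.BCEWithLogitsLoss()
logit = head(torch.randn(4, 100)).view(-1)             # shape (4,)
loss = criterion_bce(logit, torch.tensor([0., 1., 1., 0.]))
```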
/case_face_detect.py:
--------------------------------------------------------------------------------
1 | '''
2 | https://github.com/jayrodge/Binary-Image-Classifier-PyTorch/blob/master/Binary_face_classifier.ipynb
3 | '''
4 |
5 | import torch
6 | import numpy as np
7 | from torchvision import datasets
8 | import torchvision.transforms as transforms
9 | from torch.utils.data.sampler import SubsetRandomSampler
10 | import matplotlib.pyplot as plt
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | train_on_gpu = torch.cuda.is_available()
15 | # define the CNN architecture
16 | class Net(nn.Module):
17 | def __init__(self):
18 | super(Net, self).__init__()
19 | # convolutional layer
20 | self.conv1 = nn.Conv2d(3, 16, 5)
21 | # max pooling layer
22 | self.pool = nn.MaxPool2d(2, 2)
23 | self.conv2 = nn.Conv2d(16, 32, 5)
24 | self.dropout = nn.Dropout(0.2)
25 |         self.fc1 = nn.Linear(32 * 53 * 53, 256)   # 224 -> conv(5) -> 220 -> pool -> 110 -> conv(5) -> 106 -> pool -> 53
26 | self.fc2 = nn.Linear(256, 84)
27 | self.fc3 = nn.Linear(84, 2)
28 | self.softmax = nn.LogSoftmax(dim=1)
29 |
30 | def forward(self, x):
31 | # add sequence of convolutional and max pooling layers
32 | x = self.pool(F.relu(self.conv1(x)))
33 | x = self.pool(F.relu(self.conv2(x)))
34 | x = self.dropout(x)
35 | x = x.view(-1, 32 * 53 * 53)
36 | x = F.relu(self.fc1(x))
37 | x = self.dropout(F.relu(self.fc2(x)))
38 | x = self.softmax(self.fc3(x))
39 | return x
40 |
41 | batch_size = 32
42 | # percentage of training set to use as validation
43 | test_size = 0.3
44 | valid_size = 0.1
45 |
46 | def imshow(img):
47 | img = img / 2 + 0.5 # unnormalize
48 | plt.imshow(np.transpose(img, (1, 2, 0)))
49 |
50 | # convert data to a normalized torch.FloatTensor
51 | transform = transforms.Compose([
52 | transforms.RandomHorizontalFlip(),
53 | transforms.RandomRotation(20),
54 | transforms.Resize(size=(224,224)),
55 | transforms.ToTensor(),
56 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
57 | ])
58 |
59 | def load_data():
60 | data = datasets.ImageFolder('../data/Face/',transform=transform)
61 | num_data = len(data)
62 | indices_data = list(range(num_data))
63 | np.random.shuffle(indices_data)
64 | split_tt = int(np.floor(test_size * num_data))
65 | train_idx, test_idx = indices_data[split_tt:], indices_data[:split_tt]
66 |
67 | #For Valid
68 | num_train = len(train_idx)
69 | indices_train = list(range(num_train))
70 | np.random.shuffle(indices_train)
71 | split_tv = int(np.floor(valid_size * num_train))
72 | train_new_idx, valid_idx = indices_train[split_tv:],indices_train[:split_tv]
73 |
74 |
75 | # define samplers for obtaining training and validation batches
76 | train_sampler = SubsetRandomSampler(train_new_idx)
77 | test_sampler = SubsetRandomSampler(test_idx)
78 | valid_sampler = SubsetRandomSampler(valid_idx)
79 |
80 | train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size,
81 | sampler=train_sampler, num_workers=1)
82 | valid_loader = torch.utils.data.DataLoader(data, batch_size=batch_size,
83 | sampler=valid_sampler, num_workers=1)
84 | test_loader = torch.utils.data.DataLoader(data, sampler = test_sampler, batch_size=batch_size,
85 | num_workers=1)
86 | classes = [0,1]
87 |
88 | if False: # display 20 images
89 | dataiter = iter(train_loader)
90 |         images, labels = next(dataiter)
91 | images = images.numpy()
92 | fig = plt.figure(figsize=(10, 4))
93 | for idx in np.arange(10):
94 |             ax = fig.add_subplot(2, 10 // 2, idx + 1, xticks=[], yticks=[])
95 | imshow(images[idx])
96 | ax.set_title(classes[labels[idx]])
97 | plt.show()
98 | return train_loader,valid_loader,test_loader,classes
99 |
100 | def some_test(test_loader,classes):
101 | # track test loss
102 | test_loss = 0.0
103 | class_correct = list(0. for i in range(2))
104 | class_total = list(0. for i in range(2))
105 |
106 | model.eval()
107 | i = 1
108 | # iterate over test data
109 | len(test_loader)
110 | for data, target in test_loader:
111 | i = i + 1
112 | if len(target) != batch_size:
113 | continue
114 |
115 | # move tensors to GPU if CUDA is available
116 | if train_on_gpu:
117 | data, target = data.cuda(), target.cuda()
118 | # forward pass: compute predicted outputs by passing inputs to the model
119 | output = model(data)
120 | # calculate the batch loss
121 | loss = criterion(output, target)
122 | # update test loss
123 | test_loss += loss.item() * data.size(0)
124 | # convert output probabilities to predicted class
125 | _, pred = torch.max(output, 1)
126 | # compare predictions to true label
127 | correct_tensor = pred.eq(target.data.view_as(pred))
128 | correct = np.squeeze(correct_tensor.numpy()) if not train_on_gpu else np.squeeze(correct_tensor.cpu().numpy())
129 | # calculate test accuracy for each object class
130 | # print(target)
131 |
132 | for i in range(batch_size):
133 | label = target.data[i]
134 | class_correct[label] += correct[i].item()
135 | class_total[label] += 1
136 |
137 | # average test loss
138 | test_loss = test_loss / len(test_loader.dataset)
139 | print('Test Loss: {:.6f}\n'.format(test_loss))
140 |
141 | for i in range(2):
142 | if class_total[i] > 0:
143 | print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
144 | classes[i], 100 * class_correct[i] / class_total[i],
145 | np.sum(class_correct[i]), np.sum(class_total[i])))
146 | else:
147 | print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))
148 |
149 | print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (
150 | 100. * np.sum(class_correct) / np.sum(class_total),
151 | np.sum(class_correct), np.sum(class_total)))
152 |
153 | if __name__ == '__main__':
154 |
155 | model = Net()
156 | print(model)
157 |
158 | train_loader,valid_loader,test_loader,classes=load_data()
159 | # move tensors to GPU if CUDA is available
160 | if train_on_gpu:
161 | model.cuda()
162 | criterion = torch.nn.CrossEntropyLoss()
163 | optimizer = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
164 | n_epochs = 5 # you may increase this number to train a final model
165 |
166 |     valid_loss_min = np.inf # track change in validation loss
167 |
168 | for epoch in range(1, n_epochs + 1):
169 |
170 | # keep track of training and validation loss
171 | train_loss = 0.0
172 | valid_loss = 0.0
173 |
174 | ###################
175 | # train the model #
176 | ###################
177 | model.train()
178 | for data, target in train_loader:
179 | # move tensors to GPU if CUDA is available
180 | if train_on_gpu:
181 | data, target = data.cuda(), target.cuda()
182 | # clear the gradients of all optimized variables
183 | optimizer.zero_grad()
184 | # forward pass: compute predicted outputs by passing inputs to the model
185 | output = model(data)
186 | # calculate the batch loss
187 | loss = criterion(output, target)
188 | # backward pass: compute gradient of the loss with respect to model parameters
189 | loss.backward()
190 | # perform a single optimization step (parameter update)
191 | optimizer.step()
192 | # update training loss
193 | train_loss += loss.item() * data.size(0)
194 |
195 | ######################
196 | # validate the model #
197 | ######################
198 | model.eval()
199 | for data, target in valid_loader:
200 | # move tensors to GPU if CUDA is available
201 | if train_on_gpu:
202 | data, target = data.cuda(), target.cuda()
203 | # forward pass: compute predicted outputs by passing inputs to the model
204 | output = model(data)
205 | # calculate the batch loss
206 | loss = criterion(output, target)
207 | # update average validation loss
208 | valid_loss += loss.item() * data.size(0)
209 |
210 | # calculate average losses
211 | train_loss = train_loss / len(train_loader.dataset)
212 | valid_loss = valid_loss / len(valid_loader.dataset)
213 |
214 | # print training/validation statistics
215 | print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
216 | epoch, train_loss, valid_loss))
217 |
218 | # save model if validation loss has decreased
219 | if valid_loss <= valid_loss_min:
220 | print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(
221 | valid_loss_min,
222 | valid_loss))
223 | #torch.save(model.state_dict(), 'model_cifar.pt')
224 | valid_loss_min = valid_loss
225 |
226 | some_test(test_loader,classes)
227 |
--------------------------------------------------------------------------------
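A minimal sketch of the two-stage index split behind load_data() above, with placeholder split fractions; drawing the validation indices from the surviving train indices keeps the three sets disjoint:

    import numpy as np
    from torch.utils.data.sampler import SubsetRandomSampler

    test_size, valid_size = 0.2, 0.1      # placeholder fractions, assumed here
    num_data = 1000
    indices = list(range(num_data))
    np.random.shuffle(indices)
    split_tt = int(np.floor(test_size * num_data))
    train_idx, test_idx = indices[split_tt:], indices[:split_tt]

    np.random.shuffle(train_idx)          # carve validation out of train itself
    split_tv = int(np.floor(valid_size * len(train_idx)))
    train_new_idx, valid_idx = train_idx[split_tv:], train_idx[:split_tv]

    train_sampler = SubsetRandomSampler(train_new_idx)
    assert not set(valid_idx) & set(test_idx)      # disjoint by construction
    print(len(train_new_idx), len(valid_idx), len(test_idx))   # 720 80 200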
/case_mnist.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torchvision import datasets, transforms
8 | from torch.optim.lr_scheduler import StepLR
9 | import os
10 | import sys
11 | ONNET_DIR = os.path.abspath("./python-package/")
12 | sys.path.append(ONNET_DIR) # To find local version of the onnet
13 | from onnet import *
14 | import torchvision
15 | import cv2
16 | import math
17 | import matplotlib.pyplot as plt
18 | import numpy as np
19 |
20 | #dataset="emnist"
21 | #dataset="fasion_mnist"
22 | #dataset="cifar"
23 | dataset="mnist"
24 | # IMG_size = (28, 28)
25 | # IMG_size = (56, 56)
26 | IMG_size = (112, 112)
27 | # IMG_size = (14, 14)
28 | batch_size = 128
29 |
30 | #net_type = "OptFormer"
31 | #net_type = "cnn"
32 | net_type = "DNet"
33 | #net_type = "WNet"
34 | #net_type = "MF_WNet"
35 | #net_type = "MF_DNet";
36 | #net_type = "BiDNet"
37 |
38 | class Fasion_Net(nn.Module): #https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html
39 | def __init__(self):
40 |         super(Fasion_Net, self).__init__()
41 | self.conv1 = nn.Conv2d(1, 6, 5)
42 | self.pool = nn.MaxPool2d(2, 2)
43 | self.conv2 = nn.Conv2d(6, 16, 5)
44 | self.fc1 = nn.Linear(16 * 4 * 4, 120)
45 | self.fc2 = nn.Linear(120, 84)
46 | self.fc3 = nn.Linear(84, 10)
47 |
48 | def forward(self, x):
49 | x = self.pool(F.relu(self.conv1(x)))
50 | x = self.pool(F.relu(self.conv2(x)))
51 | x = x.view(-1, 16 * 4 * 4)
52 | x = F.relu(self.fc1(x))
53 | x = F.relu(self.fc2(x))
54 | x = self.fc3(x)
55 | return x
56 |
57 | class Mnist_Net(nn.Module):
58 | def __init__(self,config, nCls=10):
59 | super(Mnist_Net, self).__init__()
60 | self.title = "Mnist_Net"
61 | self.config = config
62 | self.config.learning_rate = 0.01
63 | self.conv1 = nn.Conv2d(1, 32, 3, 1)
64 | self.conv2 = nn.Conv2d(32, 64, 3, 1)
65 | self.isDropOut = False
66 | self.nFC=1
67 | if self.isDropOut:
68 | self.dropout1 = nn.Dropout2d(0.25)
69 | self.dropout2 = nn.Dropout2d(0.5)
70 |         if IMG_size[0]==56:
71 |             nFC1 = 43264    # 26*26*64 after two 3x3 convs + 2x2 pool on a 56x56 input
72 |         else:
73 |             nFC1 = 9216     # 12*12*64, assumes a 28x28 input
74 | if self.nFC == 1:
75 | self.fc1 = nn.Linear(nFC1, 10)
76 | else:
77 | self.fc1 = nn.Linear(nFC1, 128)
78 | self.fc2 = nn.Linear(128, 10)
79 | self.loss = F.cross_entropy
80 | self.nClass = nCls
81 |
82 | def forward(self, x):
83 | x = self.conv1(x)
84 | x = F.relu(x)
85 | x = self.conv2(x)
86 | x = F.max_pool2d(x, 2)
87 | if self.isDropOut:
88 | x = self.dropout1(x)
89 | x = torch.flatten(x, 1)
90 | x = self.fc1(x)
91 | x = F.relu(x)
92 | if self.isDropOut:
93 | x = self.dropout2(x)
94 | if self.nFC == 2:
95 | x = self.fc2(x)
96 | #output = F.log_softmax(x, dim=1)
97 | output = x
98 | return output
99 |
100 | def predict(self,output):
101 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
102 | #pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
103 | return pred
104 |
105 | class View(nn.Module):
106 | def __init__(self, *args):
107 | super(View, self).__init__()
108 | self.shape = args
109 |
110 | def forward(self, x):
111 | return x.view(-1,*self.shape)
112 |
113 | train_trans = transforms.Compose([
114 | #transforms.RandomAffine(5,translate=(0,0.1)),
115 | #transforms.RandomRotation(10),
116 | #transforms.Grayscale(),
117 | transforms.Resize(IMG_size),
118 | transforms.ToTensor(),
119 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #Convert a color image to grayscale and normalize the color range to [0,1].
120 | #transforms.Normalize((0.1307,), (0.3081,))
121 | ])
122 | test_trans = transforms.Compose([
123 | #transforms.Grayscale(),
124 | transforms.Resize(IMG_size),
125 | transforms.ToTensor(),
126 | #transforms.Normalize((0.1307,), (0.3081,))
127 | ])
128 |
129 | def train(model, device, train_loader, epoch, optical_trans,visual):
130 | #model.visual = visual
131 | # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9,weight_decay=0.0005)
132 | optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate, weight_decay=0.0005)
133 | if epoch==1:
134 | print(f"\n=======dataset={dataset} net={net_type} IMG_size={IMG_size} batch_size={batch_size}")
135 | print(f"======={model.config}")
136 | print(f"======={optimizer}")
137 | print(f"======={train_trans}\n")
138 |
139 | nClass = model.nClass
140 | model.train()
141 | for batch_idx, (data, target) in enumerate(train_loader):
142 | if batch_idx==0: #check data_range
143 | d0,d1=data.min(),data.max()
144 | assert(d0>=0)
145 | data, target = data.to(device), target.to(device)
146 | optimizer.zero_grad()
147 | output = model(optical_trans(data))
148 | #output = model(data)
149 | loss = model.loss(output, target)
150 | loss.backward()
151 | optimizer.step()
152 | if batch_idx % 50 == 0:
153 | aLoss = loss.item()
154 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
155 | epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader),aLoss ))
156 | #visual.UpdateLoss(title=f"Accuracy on \"{dataset}\"", legend=f"{model.legend()}", loss=aLoss, yLabel="Accuracy")
157 | #break
158 |
159 | def test_one_batch(model,data,target,device):
160 | data, target = data.to(device), target.to(device)
161 | output = model(data)
162 | # output = model(data)
163 | loss = model.loss(output, target, reduction='sum').item() # sum up batch loss
164 | pred = model.predict(output)
165 | # pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
166 | correct = pred.eq(target.view_as(pred)).sum().item()
167 | return loss,correct
168 |
169 | def test(model, device, test_loader, optical_trans,visual):
170 | model.eval()
171 | test_loss = 0
172 | correct = 0
173 | with torch.no_grad():
174 | for data, target in test_loader:
175 | loss, corr = test_one_batch(model, data, target, device)
176 | test_loss += loss
177 | correct += corr
178 | if False:
179 | data, target = data.to(device), target.to(device)
180 | if optical_trans is not None: data = optical_trans(data)
181 | output = model(data)
182 | #output = model(data)
183 | test_loss += model.loss(output, target, reduction='sum').item() # sum up batch loss
184 | pred = model.predict(output)
185 | #pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
186 | correct += pred.eq(target.view_as(pred)).sum().item()
187 |
188 | test_loss /= len(test_loader.dataset)
189 | accu = 100. * correct / len(test_loader.dataset)
190 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(test_loss, correct, len(test_loader.dataset),accu))
191 | if visual is not None:
192 | visual.UpdateLoss(title=f"Accuracy on \"{dataset}\"",legend=f"{model.legend()}", loss=accu,yLabel="Accuracy")
193 | return accu
194 |
195 | def Some_Test():
196 | use_cuda = torch.cuda.is_available()
197 | device = torch.device("cuda" if use_cuda else "cpu")
198 | model_path = "E:/ONNet/checkpoint/DNNet_exp_W_H_Express Wavenet_[17,81.91]_.pth"
199 | PTH = torch.load(model_path)
200 | env_title, model = DNet_instance(PTH['net_type'], PTH['dataset'],
201 | PTH['IMG_size'], PTH['lr_base'], PTH['batch_size'], PTH['nClass'], PTH['nLayer'])
202 | epoch, acc = PTH['epoch'], PTH['acc']
203 | model.load_state_dict(PTH['net'])
204 | model.to(device)
205 | print(f"Load model@{model_path} epoch={epoch},acc={acc}")
206 |
207 | visual = Visdom_Visualizer(env_title,plots=[{"object":"output"}])
208 | visual.img_dir = "./dump/X_images/"
209 | test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST('./data', train=False,transform=test_trans),
210 | batch_size=batch_size, shuffle=False)
211 | if True: #only one batch
212 | dataiter = iter(test_loader)
213 |         images, target = next(dataiter)
214 | model.visual = visual
215 | loss,correct = test_one_batch(model, images, target, device)
216 | model.visual = None
217 |
218 | if False:
219 | acc_1 = test(model, device, test_loader, None, None)
220 | print(f"Some_Test acc={acc}-{acc_1}")
221 |
222 | def main():
223 | #OnInitInstance()
224 | lr_base = 0.002
225 | parser = argparse.ArgumentParser(description='MNIST optical_trans + hybrid examples')
226 | parser.add_argument('--mode', type=int, default=2,help='optical_trans 1st or 2nd order')
227 | parser.add_argument('--classifier', type=str, default='linear',help='classifier model')
228 | args = parser.parse_args()
229 | assert(args.classifier in ['linear','mlp','cnn'])
230 |
231 | use_cuda = torch.cuda.is_available()
232 | device = torch.device("cuda" if use_cuda else "cpu")
233 | optical_trans = OpticalTrans()
234 |
235 | # DataLoaders
236 | if use_cuda:
237 | num_workers = 4
238 | pin_memory = True
239 | else:
240 |         num_workers = 0
241 | pin_memory = False
242 |
243 | nLayer = 10
244 | if dataset=="emnist":
245 | train_loader = torch.utils.data.DataLoader(
246 | datasets.EMNIST('./data',split="balanced", train=True, download=True, transform=train_trans),
247 | batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
248 | test_loader = torch.utils.data.DataLoader(
249 | datasets.EMNIST('./data',split="balanced", train=False, transform=test_trans),
250 | batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
251 | # balanced=47 byclass=62
252 | nClass = 47
253 | elif dataset=="fasion_mnist":
254 | train_loader = torch.utils.data.DataLoader(
255 | datasets.FashionMNIST('./data',train=True, download=True, transform=train_trans),
256 | batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
257 | test_loader = torch.utils.data.DataLoader(
258 | datasets.FashionMNIST('./data',train=False, transform=test_trans),
259 | batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
260 | nClass = 10
261 | elif dataset=="cifar":
262 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10('./data',train=True, download=True, transform=train_trans),
263 | batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
264 | test_loader = torch.utils.data.DataLoader(datasets.CIFAR10('./data',train=False, transform=test_trans),
265 | batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
266 | nClass = 10; lr_base=0.005
267 | else:
268 | nClass = 10
269 | train_loader = torch.utils.data.DataLoader(
270 | datasets.MNIST('./data', train=True, download=True,transform=train_trans),
271 | batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
272 | test_loader = torch.utils.data.DataLoader(
273 | datasets.MNIST('./data', train=False,transform=test_trans),
274 | batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
275 |
276 | config_0 = NET_config(net_type,dataset,IMG_size,lr_base,batch_size,nClass,nLayer)
277 | env_title, model = DNet_instance(config_0) #net_type,dataset,IMG_size,lr_base,batch_size,nClass,nLayer
278 | visual = Visdom_Visualizer(env_title=env_title)
279 | # visual = Visualize(env_title=env_title)
280 | model.to(device)
281 | print(model)
282 | # visual.ShowModel(model,train_loader)
283 |
284 |     if False: # manual weight init, disabled: it behaved strangely
285 | for m in model.modules():
286 | if isinstance(m, nn.Conv2d):
287 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
288 | m.weight.data.normal_(0, 2. / math.sqrt(n))
289 | m.bias.data.zero_()
290 | if isinstance(m, nn.Linear):
291 | m.weight.data.normal_(0, 2. / math.sqrt(m.in_features))
292 | m.bias.data.zero_()
293 |
294 | nzParams = Net_dump(model)
295 | if False:
296 | nzParams=0
297 | for name, param in model.named_parameters():
298 | if param.requires_grad:
299 | nzParams+=param.nelement()
300 | print(f"\t{name}={param.nelement()}")
301 | print(f"========All parameters={nzParams}")
302 |
303 | acc,best_acc = 0,0
304 | accu_=[]
305 | for epoch in range(1, 33):
306 | if False:
307 | assert os.path.isdir('checkpoint')
308 | pth_path = f'./checkpoint/{model.title}_[{epoch},{acc}]_.pth'
309 | torch.save({'net': model.state_dict(), 'acc': acc, 'epoch': epoch,}, pth_path)
310 |
311 | if hasattr(model,'visualize'):
312 |             model.visualize(visual, f"E[{epoch-1}]")
313 | train( model, device, train_loader, epoch, optical_trans,visual)
314 | acc = test(model, device, test_loader, optical_trans,visual)
315 | accu_.append(acc)
316 | if acc > best_acc:
317 | state = {
318 | 'net_type':net_type,'dataset':dataset,'IMG_size':IMG_size,'lr_base':lr_base,
319 | 'batch_size':batch_size,'nClass':nClass, 'nLayer':nLayer,
320 | 'net': model.state_dict(), 'acc': acc,'epoch': epoch,
321 | }
322 | assert os.path.isdir('checkpoint')
323 | pth_path = f'./checkpoint/{model.title}_[{epoch},{acc}]_.pth'
324 | torch.save(state, pth_path)
325 | best_acc = acc
326 | print(f"\n=======\n=======accu_history={accu_}\n")
327 |
328 | #if args.save_model:
329 | # torch.save(model.state_dict(), "mnist_onn.pt")
330 |
331 | '''
332 | Single diffractive-layer test case:
333 | 1) load an image with PIL  2) DiffractiveLayer forward  3) show the modulus with matplotlib
334 | '''
335 | def layer_test():
336 | from PIL import Image
337 | img = Image.open("E:/ONNet/data/MNIST/test_2.jpg")
338 | img = train_trans(img)
339 |
340 | config=NET_config(net_type,dataset,IMG_size,0.01,32,10,5)
341 | config.modulation = 'phase'
342 | config.init_value = "random"
343 | config.rDrop = 0 #drop out
344 |     layer = DiffractiveLayer(IMG_size[0],IMG_size[1],config).cuda()
345 |
346 |     out = layer(img.cuda())
347 | im_out = layer.z_modulus(out)
348 | im_out = im_out.squeeze().cpu().detach().numpy()
349 |
350 | fig, ax = plt.subplots()
351 | #plt.axis('off')
352 |     plt.grid(False)
353 | im = ax.imshow(im_out, interpolation='nearest', cmap='coolwarm')
354 | title = f"{layer.__repr__()}"
355 | ax.set_title(title,fontsize=12)
356 | fig.colorbar(im, orientation='horizontal')
357 | plt.show()
358 | plt.close()
359 |
360 | print("!!!Good Luck!!!")
361 |
362 | if __name__ == '__main__':
363 | #Some_Test()
364 | #layer_test()
365 | main()
366 |
--------------------------------------------------------------------------------
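A self-contained sketch of the checkpoint round-trip used in main() and Some_Test() above; the tiny Linear model, the path and the metric values are placeholders, only the key layout follows the script:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)               # stand-in for the DNet model
    state = {'net': model.state_dict(), 'acc': 81.91, 'epoch': 17,
             'net_type': 'DNet', 'dataset': 'mnist', 'IMG_size': (112, 112),
             'lr_base': 0.002, 'batch_size': 128, 'nClass': 10, 'nLayer': 10}
    torch.save(state, 'example_ckpt.pth')

    PTH = torch.load('example_ckpt.pth')
    model.load_state_dict(PTH['net'])     # same keys as Some_Test() expects
    print(PTH['epoch'], PTH['acc'])       # 17 81.91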
/python-package/case_fft.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets, transforms
2 | from PIL import Image
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from onnet import *
6 | import torch
7 | from skimage import io, transform
8 | torch.set_printoptions(profile="full")
9 |
10 | size = 28
11 | delta = 0.03
12 | dL = 0.02
13 | c = 3e8
14 | Hz = 0.4e12
15 |
16 | def Init_H(d=delta, N = size, dL = dL, lmb = c/Hz,theta=0.0):
17 | # Parameter
18 | df = 1.0 / dL
19 | k = np.pi * 2.0 / lmb
20 | D = dL * dL / (N * lmb)
21 | # phase
22 | def phase(i, j):
23 | i -= N // 2
24 | j -= N // 2
25 | return ((i * df) * (i * df) + (j * df) * (j * df))
26 |
27 |
28 | ph = np.fromfunction(phase, shape=(N, N), dtype=np.float32)
29 | # H
30 | H = np.exp(1.0j * k * d) * np.exp(-1.0j * lmb * np.pi * d * ph)
31 | H_f = np.fft.fftshift(H)
32 | #print(H_f); print(H)
33 | return H,H_f
34 |
35 | def fft_test(H_f,N = 28):
36 | dL = 0.02
37 | s = dL * dL / (N * N)
38 |
39 | normalize = transforms.Normalize(
40 | mean=[0.485, 0.456, 0.406],
41 | std=[0.229, 0.224, 0.225]
42 | )
43 | preprocess = transforms.Compose([
44 | #transforms.Resize(256),
45 | #transforms.CenterCrop(224),
46 | transforms.ToTensor(),
47 | #normalize
48 | ])
49 | image = io.imread("E:/ONNet/data/MNIST/test_2.jpg").astype(np.float64)
50 | #print(image)
51 | img_tensor = torch.from_numpy(image)
52 | #print(img_tensor)
53 | #img_tensor.unsqueeze_(0)
54 | print(img_tensor.shape, img_tensor.dtype)
55 | u0 = COMPLEX_utils.ToZ(img_tensor)
56 |     print(u0.shape, H_f.shape)
57 |
58 | u1 = COMPLEX_utils.fft(u0)
59 | print(u1)
60 | H_z = np.zeros(H_f.shape + (2,))
61 | H_z[..., 0] = H_f.real
62 | H_z[..., 1] = H_f.imag
63 | H_f = torch.from_numpy(H_z)
64 | u1 = COMPLEX_utils.Hadamard(H_f,u1) #H_f * u1
65 | print(u1)
66 | u1 = COMPLEX_utils.fft(u1 ,"C2C",inverse=True)
67 | print(u1)
68 |     input("...")  # pause so the printed tensors can be inspected
69 |
70 | if __name__ == '__main__':
71 | H, H_f = Init_H()
72 | fft_test(H_f)
--------------------------------------------------------------------------------
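A quick numeric check of the transfer function built by Init_H() above: both exponential factors are pure phase, so |H| is 1 everywhere and fftshift only reorders entries:

    import numpy as np

    N, d, dL = 28, 0.03, 0.02
    lmb = 3e8 / 0.4e12                    # wavelength = c / Hz
    df = 1.0 / dL
    k = np.pi * 2.0 / lmb

    def phase(i, j):
        i -= N // 2
        j -= N // 2
        return (i * df) ** 2 + (j * df) ** 2

    ph = np.fromfunction(phase, shape=(N, N), dtype=np.float32)
    H = np.exp(1.0j * k * d) * np.exp(-1.0j * lmb * np.pi * d * ph)
    H_f = np.fft.fftshift(H)
    print(np.allclose(np.abs(H_f), 1.0))  # True: a lossless propagation kernel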
/python-package/cnn_models/OpticalNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import sys
5 | import os
6 | #ONNET_DIR = os.path.abspath("../../")
7 | sys.path.append("../") # To find local version of the onnet
8 | from onnet import *
9 | from onnet import DiffractiveLayer
10 |
11 | class OpticalBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self,config, in_planes, planes, stride=1):
15 | super(OpticalBlock, self).__init__()
16 | self.config = config
17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 |
22 | self.shortcut = nn.Sequential()
23 | M,N = self.config.IMG_size[0], self.config.IMG_size[1]
24 | self.diffrac = DiffractiveLayer(M,N,config)
25 | if stride != 1 or in_planes != self.expansion*planes:
26 | self.shortcut = nn.Sequential(
27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
28 | nn.BatchNorm2d(self.expansion*planes)
29 | )
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = self.bn2(self.conv2(out))
34 | out += self.shortcut(x)
35 | #assert x.shape[-1]==32 and x.shape[-2]==32
36 | #out += self.diffrac(x)
37 | out = F.relu(out)
38 | return out
39 |
40 |
41 | class OpticalNet(nn.Module):
42 | def __init__(self, config,block, num_blocks):
43 | super(OpticalNet, self).__init__()
44 | num_classes = config.nClass
45 | self.config = config
46 | self.in_planes = 64
47 |
48 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(64)
50 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
51 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
52 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
53 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
54 | self.linear = nn.Linear(512*block.expansion, num_classes)
55 |
56 | def _make_layer(self, block, planes, num_blocks, stride):
57 | strides = [stride] + [1]*(num_blocks-1)
58 | layers = []
59 | for stride in strides:
60 | layers.append(block(self.config,self.in_planes, planes, stride))
61 | self.in_planes = planes * block.expansion
62 | return nn.Sequential(*layers)
63 |
64 | def forward(self, x):
65 | out = F.relu(self.bn1(self.conv1(x)))
66 | out = self.layer1(out)
67 | out = self.layer2(out)
68 | out = self.layer3(out)
69 | out = self.layer4(out)
70 | out = F.avg_pool2d(out, 4)
71 | out = out.view(out.size(0), -1)
72 | out = self.linear(out)
73 | return out
74 |
75 |
76 | def OpticalNet18(config):
77 | return OpticalNet(config,OpticalBlock, [2,2,2,2])
78 |
79 | def OpticalNet34(config):
80 | return OpticalNet(config,OpticalBlock, [3,4,6,3])
81 |
82 | def test(config):   # config must provide IMG_size and nClass (e.g. a NET_config)
83 |     net = OpticalNet18(config)
84 |     y = net(torch.randn(1,3,32,32))
85 |     print(y.size())
86 |
87 | # test()
88 |
--------------------------------------------------------------------------------
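_make_layer() above follows the usual ResNet recipe: only the first block of a stage downsamples, later blocks keep stride 1. A tiny sketch of the stride schedule for the [2,2,2,2] configuration of OpticalNet18:

    def stride_schedule(num_blocks, stride):
        return [stride] + [1] * (num_blocks - 1)

    for planes, num_blocks, stride in [(64, 2, 1), (128, 2, 2),
                                       (256, 2, 2), (512, 2, 2)]:
        print(planes, stride_schedule(num_blocks, stride))
    # 64 [1, 1]  /  128 [2, 1]  /  256 [2, 1]  /  512 [2, 1]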
/python-package/fast_conv.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Yingshi Chen
3 |
4 | @Date: 2020-03-04 14:50:24
5 | @
6 | # Description:
7 | '''
8 |
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | import time
12 | import sys
13 | sys.path.append('..')
14 | #from deap.convolve import convDEAP_GIP
15 | from scipy.signal import convolve2d
16 |
17 | from deap.helpers import getOutputShape
18 | from deap.mappers import PhotonicConvolverMapper
19 | from deap.mappers import ModulatorArrayMapper
20 | from deap.mappers import PWBArrayMapper
21 |
22 | class MRMTransferFunction:
23 | """
24 | Computes the transfer function of a microring modulator (MRM).
25 | """
26 | def __init__(self, a=0.9, r=0.9):
27 | self.a = a
28 | self.r = r
29 | self._maxThroughput = self.throughput(np.pi)
30 |
31 | def throughput(self, phi):
32 | I_pass = self.a**2 - 2 * self.r * self.a * np.cos(phi) + self.r**2
33 | I_input = 1 - 2 * self.a * self.r * np.cos(phi) + (self.r * self.a)**2
34 | return I_pass / I_input
35 |
36 | def phaseFromThroughput(self, Tn):
37 | Tn = np.asarray(Tn)
38 |
39 | # Create variable to store results
40 | ans = np.empty_like(Tn)
41 |
42 |         # For high throughputs, set to pi
43 | moreThanMax = Tn >= self._maxThroughput
44 | maxOrLess = ~moreThanMax
45 | ans[moreThanMax] = np.pi
46 |
47 |         # Now solve the remaining
48 | cos_phi = Tn[maxOrLess] * (1 + (self.r * self.a)**2) - self.a**2 - self.r**2 # noqa
49 | ans[maxOrLess] = np.arccos(cos_phi / (-2 * self.r * self.a * (1 - Tn[maxOrLess]))) # noqa
50 | #ans = np.arccos(cos_phi / (-2 * self.r * self.a * (1 - Tn[maxOrLess])))
51 |
52 | return ans
53 |
54 | def convDEAP(image, kernel, stride, bias=0, normval=255):
55 | """
56 |     Image is a 3D matrix indexed by (row, col, depth).
57 |     Kernel is a 4D matrix indexed by (row, col, depth, index).
58 | The depth of the kernel must be equal to the depth of the input.
59 | """
60 | assert image.shape[2] == kernel.shape[2]
61 |
62 | # Allocate memory for storing result of convolution
63 | outputShape = getOutputShape(image.shape, kernel.shape, stride=stride)
64 | output = np.zeros(outputShape)
65 |
66 | # Build the photonic circuit
67 | weightBanks = []
68 | inputShape = (kernel.shape[0], kernel.shape[1])
69 | for k in range(image.shape[2]):
70 | pc = PhotonicConvolverMapper.build(
71 | imageShape=inputShape,
72 | kernelShape=inputShape,
73 | power=normval)
74 | weightBanks.append(pc)
75 |
76 | for k in range(kernel.shape[3]):
77 | # Load weights
78 | weights = kernel[:, :, :, k]
79 | for c in range(weights.shape[2]):
80 | PWBArrayMapper.updateKernel(
81 | weightBanks[c].pwbArray,
82 | weights[:, :, c])
83 |
84 | for h in range(0, outputShape[0], stride):
85 | for w in range(0, outputShape[1], stride):
86 | # Load inputs
87 | inputs = \
88 | image[h:min(h + kernel.shape[0], image.shape[0]),
89 |                           w:min(w + kernel.shape[1], image.shape[1]), :]
90 | for c in range(kernel.shape[2]):
91 | ModulatorArrayMapper.updateInputs(
92 | weightBanks[c].modulatorArray,
93 | inputs[:, :, c],
94 | normval=normval)
95 |
96 | # Perform convolution:
97 | for c in range(kernel.shape[2]):
98 | output[h, w, k] += weightBanks[c].step()
99 | output[h, w, k] += bias
100 |
101 | return output
102 |
103 | def convDEAP_GIP(image, kernel, stride, convolverShape=None):
104 | """
105 |     Image is a 3D matrix indexed by (row, col, depth).
106 |     Kernel is a 4D matrix indexed by (row, col, depth, index).
107 | The depth of the kernel must be equal to the depth of the input.
108 | """
109 | assert image.shape[2] == kernel.shape[2]
110 | assert kernel.shape[2] == 1 and kernel.shape[3] == 1
111 | if convolverShape is None:
112 | convolverShape = image.shape
113 |
114 | # Define convolutional parameters
115 | Hm, Wm = convolverShape[0], convolverShape[1]
116 | H, W = image.shape[0], image.shape[1]
117 | R = kernel.shape[0]
118 |
119 | # Allocate memory for storing result of convolution
120 | outputShape = getOutputShape(image.shape, kernel.shape, stride=stride)
121 | output = np.zeros(outputShape)
122 |
123 | # Load weights
124 | pc = PhotonicConvolverMapper.build(imageShape=convolverShape,kernel=kernel[:, :, 0, 0], power=255)
125 |
126 | input_buffer = np.zeros(convolverShape)
127 | normval=255
128 | _mrm = MRMTransferFunction()
129 | for h in range(0, H - R + 1, Hm - R + 1):
130 | for w in range(0, W - R + 1, Wm - R + 1):
131 | inputs = image[h:min(h + Hm, H), w:min(w + Wm, W), 0]
132 | # Load inputs into a buffer if convolution shape doesn't tile
133 | # nicely.
134 | input_buffer[:inputs.shape[0], :inputs.shape[1]] = inputs
135 | input_buffer[inputs.shape[0]:, inputs.shape[1]:] = 0
136 |
137 | if False:
138 | ModulatorArrayMapper.updateInputs(pc.modulatorArray,input_buffer,normval=255)
139 | else:
140 | #phaseShifts = ModulatorArrayMapper.computePhaseShifts(input_buffer, normval=255)
141 | normalized = input_buffer / normval
142 | assert not np.any(input_buffer < 0)
143 | phaseShifts = _mrm.phaseFromThroughput(normalized)
144 | pc.modulatorArray._update(phaseShifts)
145 |
146 | # Perform the convolution and store to memory
147 | result = pc.step()[:min(h + Hm, H) - h - R + 1,
148 | :min(w + Wm, W) - w - R + 1]
149 | output[h:min(h + Hm, H) - R + 1,
150 |                    w:min(w + Wm, W) - R + 1,
151 | 0] = result
152 |
153 | return output
154 |
155 | def main():
156 | image = plt.imread("./data/bass.jpg")
157 | greyscale = np.mean(image, axis=2)
158 |
159 | # Define kernel
160 | gaussian_kernel = np.zeros((3, 3, 1, 1))
161 | gaussian_kernel[:, :, 0, 0] = \
162 | np.array([
163 | [1, 2, 1],
164 | [2, 4, 2],
165 | [1, 2, 1]]) * 1/16
166 |
167 |
168 | # Perform convolution
169 | paddedInputs = np.pad(greyscale, (2, 2), 'constant')
170 | paddedInputs = np.expand_dims(paddedInputs, 2)
171 | convolved = convDEAP_GIP(paddedInputs, gaussian_kernel, 1, (12, 12))
172 | t0=time.time()
173 | for i in range(10):
174 | convDEAP_GIP(paddedInputs, gaussian_kernel, 1, (12, 12))
175 | print(f"convDEAP_GIP T_10={time.time()-t0:.3f}")
176 |
177 |
178 | t0=time.time()
179 | for i in range(10):
180 | convolve2d(greyscale, gaussian_kernel[:, :, 0, 0])
181 | print(f"convolve2d T_10={time.time()-t0:.3f}")
182 | conv_scipy = convolve2d(greyscale, gaussian_kernel[:, :, 0, 0])
183 |
184 | err = np.abs(convolved[:, :, 0] - conv_scipy)
185 | mse = np.sum(err**2) / (err.size)
186 | print("MSE distance per pixel", mse)
187 |
188 | if __name__ == '__main__':
189 | main()
--------------------------------------------------------------------------------
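A numeric sanity check of MRMTransferFunction above: for phases in (0, pi) the arccos branch of phaseFromThroughput() exactly inverts throughput(); the formulas are inlined so the snippet runs on its own:

    import numpy as np

    a = r = 0.9
    phi = np.linspace(0.1, np.pi - 0.1, 5)
    Tn = (a**2 - 2*r*a*np.cos(phi) + r**2) / (1 - 2*a*r*np.cos(phi) + (r*a)**2)
    cos_phi = Tn * (1 + (r*a)**2) - a**2 - r**2
    phi_rec = np.arccos(cos_phi / (-2*r*a*(1 - Tn)))
    print(np.allclose(phi_rec, phi))      # True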
/python-package/onnet/BinaryDNet.py:
--------------------------------------------------------------------------------
1 | from .D2NNet import *
2 | import math
3 | import random
4 |
5 | class GatePipe(torch.nn.Module):
6 | def __init__(self,M,N, nHidden,config,pooling="max"):
7 | super(GatePipe, self).__init__()
8 | self.config = config
9 | self.M=M
10 | self.N=N
11 | self.nHidden = nHidden
12 | self.pooling = pooling
13 | self.layers = nn.ModuleList([DiffractiveLayer(self.M, self.N, self.config, HZ=0.3e12) for j in range(self.nHidden)])
14 | if True:
15 | chunk_dim = -1 if random.choice([True, False]) else -2
16 | self.pool = ChunkPool(2, self.config,pooling=self.pooling,chunk_dim=chunk_dim)
17 | else:
18 | self.pt1 = (random.randint(0, self.M-1),random.randint(0,self.N-1))
19 | self.pt2 = (random.randint(0, self.M - 1), random.randint(0, self.N - 1))
20 |
21 | def __repr__(self):
22 | main_str = super(GatePipe, self).__repr__()
23 | main_str = f"GatePipe_[{len(self.layers)}]_pool[{self.pooling}]"
24 | return main_str
25 |
26 | def forward(self, x):
27 | for lay in self.layers:
28 | x = lay(x)
29 | x1 = Z.modulus(x).cuda()
30 | #x1 = Z.phase(x).cuda()
31 | if True:
32 | x1 = self.pool(x1)
33 | else:
34 | x_pt1 = x1[:, 0, self.pt1[0], self.pt1[1]]
35 | x_pt2 = x1[:, 0, self.pt2[0], self.pt2[1]]
36 | x1 = torch.stack([x_pt1,x_pt2], 1)
37 | x2 = F.log_softmax(x1, dim=1)
38 | return x2
39 |
40 | class BinaryDNet(D2NNet):
41 | @staticmethod
42 | def binary_loss(output, target, reduction='mean'):
43 | nGate = len(output)
44 | nSamp = target.shape[0]
45 | loss =0
46 | for i in range(nGate):
47 | target_i = target%2
48 | # loss = F.binary_cross_entropy(output, target, reduction=reduction)
49 | loss_i = F.cross_entropy(output[i], target_i, reduction=reduction)
50 | loss += loss_i
51 |             target = (target - target_i)//2   # keep integer labels for the next gate
52 |
53 | # loss = F.nll_loss(output, target, reduction=reduction)
54 | return loss
55 |
56 | def predict(self,output):
57 | nGate = len(output)
58 | pred = 0
59 | for i in range(nGate):
60 | pred_i = output[nGate-1-i].max(1, keepdim=True)[1] # get the index of the max log-probability
61 | pred = pred*2+pred_i
62 | #pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
63 | return pred
64 |
65 | def __init__(self, IMG_size,nCls,nInterDifrac,nOutDifac,config):
66 | super(BinaryDNet, self).__init__(IMG_size,nCls,nInterDifrac,config)
67 | self.nGate = (int)(math.ceil(math.log2(self.nClass)))
68 | self.nOutDifac = nOutDifac
69 | self.gates = nn.ModuleList( [GatePipe(self.M,self.N,nOutDifac,config,pooling="mean") for i in range(self.nGate)] )
70 | self.config = config
71 | self.loss = BinaryDNet.binary_loss
72 |
73 | def __repr__(self):
74 | main_str = super(BinaryDNet, self).__repr__()
75 | main_str += f"_nGate={self.nGate}_Difrac=[{self.nDifrac},{self.nOutDifac}]"
76 | return main_str
77 |
78 | def legend(self):
79 | title = f"BinaryDNet"
80 | return title
81 |
82 | def forward(self, x):
83 | x = x.double()
84 | for layD in self.DD:
85 | x = layD(x)
86 |
87 | nSamp = x.shape[0]
88 | output = []
89 | if True:
90 | for gate in self.gates:
91 | x1 = gate(x)
92 | output.append(x1)
93 | else:
94 | for [diffrac,gate] in self.gates:
95 | x1 = diffrac(x)
96 | x1 = self.z_modulus(x1).cuda()
97 | x1 = gate(x1)
98 | x2 = F.log_softmax(x1, dim=1)
99 | output.append(x2)
100 |
101 | return output
102 |
103 |
--------------------------------------------------------------------------------
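The gate encoding above is plain base 2: binary_loss() trains gate i on the i-th least-significant bit of the label, and predict() reassembles the class id MSB-first. A small sketch of that round trip:

    import math

    nClass = 10
    nGate = int(math.ceil(math.log2(nClass)))     # 4 gates cover 10 classes

    def label_bits(target, nGate):                # what each gate is trained on
        bits = []
        for _ in range(nGate):
            bits.append(target % 2)
            target = (target - target % 2) // 2
        return bits

    bits = label_bits(5, nGate)                   # LSB first: [1, 0, 1, 0]
    pred = 0
    for b in reversed(bits):                      # MSB first, as in predict()
        pred = pred * 2 + b
    print(bits, pred)                             # [1, 0, 1, 0] 5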
/python-package/onnet/D2NN_tf.py:
--------------------------------------------------------------------------------
1 | '''
2 | https://github.com/computational-imaging/opticalCNN
3 | https://github.com/Lyn-Wu/Lyn/blob/master/DNN
4 | '''
5 | from tensorflow.examples.tutorials.mnist import input_data
6 | import tensorflow as tf
7 | import numpy as np
8 | from scipy.misc import imresize
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 | from skimage import io, transform
12 |
13 | learning_rate = 0.01
14 | #size = 512
15 | size = 28
16 | delta = 0.03
17 | dL = 0.02
18 | batch_size = 64
19 | batch = 10
20 | #mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
21 | mnist = input_data.read_data_sets("E:/ONNet/data/MNIST/raw",one_hot=True)
22 | c = 3e8
23 | Hz = 0.4e12
24 |
25 | def fft_test(N = size):
26 | s = dL * dL / (N * N)
27 | if False:
28 | img_raw = tf.io.read_file("E:/ONNet/data/MNIST/test_2.jpg")
29 | img_raw = tf.image.decode_jpeg(img_raw)
30 |     else: # oddly, tf.io and skimage.io decode this image differently
31 | img_raw = io.imread("E:/ONNet/data/MNIST/test_2.jpg")
32 | #print(img_raw)
33 | img_tensor = tf.squeeze(img_raw)
34 | with tf.Session() as sess:
35 | img_tensor = img_tensor.eval()
36 | print(img_tensor.shape,img_tensor.dtype)
37 | #print(img_tensor)
38 |
39 | u0 = tf.cast(img_tensor,dtype=tf.complex64)
40 |     print(u0.shape, H_f.shape)
41 | u1 = tf.fft2d(u0)
42 | with tf.Session() as sess:
43 | print(u0.eval())
44 | print(u1.eval())
45 | u1 = H_f * u1
46 | u2 = tf.ifft2d(u1 )
47 | with tf.Session() as sess:
48 | print(u1.eval())
49 | print(u2.eval())
50 |
51 | def Init_H(d=delta, N = size, dL = dL, lmb = c/Hz,theta=0.0):
52 | # Parameter
53 | df = 1.0 / dL
54 | k = np.pi * 2.0 / lmb
55 | D = dL * dL / (N * lmb)
56 | # phase
57 | def phase(i, j):
58 | i -= N // 2
59 | j -= N // 2
60 | return ((i * df) * (i * df) + (j * df) * (j * df))
61 |
62 |
63 | ph = np.fromfunction(phase, shape=(N, N), dtype=np.float32)
64 | # H
65 | H = np.exp(1.0j * k * d) * np.exp(-1.0j * lmb * np.pi * d * ph)
66 | H_f = np.fft.fftshift(H)
67 | #print(H_f); print(H)
68 | return H,H_f
69 |
70 | H,H_f=Init_H()
71 | #fft_test(); input(...)
72 |
73 | def _propogation(u0, N = size, dL = dL):
74 | df = 1.0 / dL
75 | return tf.ifft2d(H_f*tf.fft2d(u0)*dL*dL/(N*N))*N*N/dL/dL
76 |
77 | def propogation(u0,d,function=_propogation):
78 | return tf.map_fn(function,u0)
79 |
80 | def make_random(shape):
81 | return np.random.random(size = shape).astype('float32')
82 |
83 |
84 | def add_layer_amp(inputs,amp,phase,size,delta):
85 | return tf.multiply(propogation(inputs,delta),tf.cast(amp,dtype=tf.complex64))
86 | #return propogation(inputs,delta)*tf.cast(amp,dtype=tf.complex64)
87 |
88 | def add_layer_phase_out(inputs,amp,phase,size,delta):
89 | return propogation(inputs,delta,function=_propogation_phase_out)*tf.math.exp(1j*tf.cast(phase,dtype=tf.complex64))
90 |
91 |
92 | def add_layer_phase_in(inputs,amp,phase,size,delta):
93 | return propogation(inputs,delta,function=_propogation_phase_in)*tf.cast(amp,dtype=tf.complex64)
94 |
95 | def _change(input_):
96 | return imresize(input_.reshape(28,28),(size,size),interp="nearest")
97 |
98 | def change(input_):
99 | return np.array(list(map(_change,input_)))
100 |
101 | def rang(arr,shape,size=size,base = 512):
102 | #return arr[shape[0]*size//base:shape[1]*size//base,shape[2]*size//512:shape[3]*size//512]
103 | x0 = shape[0] * size // base
104 | y0 = shape[2] * size // base
105 | delta = (shape[1]-shape[0])* size // base
106 | return arr[x0:x0+delta,y0:y0+delta]
107 |
108 | def reduce_mean(tf_):
109 | return tf.reduce_mean(tf_)
110 |
111 | def _ten_regions(a):
112 | return tf.map_fn(reduce_mean,tf.convert_to_tensor([
113 | rang(a,(120,170,120,170)),
114 | rang(a,(120,170,240,290)),
115 | rang(a,(120,170,360,410)),
116 | rang(a,(220,270,120,170)),
117 | rang(a,(220,270,200,250)),
118 | rang(a,(220,270,280,330)),
119 | rang(a,(220,270,360,410)),
120 | rang(a,(320,370,120,170)),
121 | rang(a,(320,370,240,290)),
122 | rang(a,(320,370,360,410))
123 | ]))
124 |
125 | def ten_regions(logits):
126 | return tf.map_fn(_ten_regions,tf.abs(logits),dtype=tf.float32)
127 |
128 | def download_text(msg,epoch,MIN=1,MAX=7,name=''):
129 | print("Download {}".format(name))
130 | if name == 'Phase':
131 | MIN = 0
132 | MAX = 2
133 | for i in range(MIN,MAX):
134 | print("{} {}:".format(name,i))
135 | np.savetxt("{}_Time_{}_layer_{}.txt".format(name,epoch+1,i),msg[i-1])
136 | print("Done")
137 |
138 | def download_image(msg,epoch,MIN=1,MAX=7,name=''):
139 | print(f"Plot images-[{MIN}:{MAX}]")
140 | if name == 'Phase':
141 | MIN = 0
142 | MAX = 2
143 | for i in range(MIN,MAX):
144 | #print("Image {}:".format(i))
145 | plt.figure(dpi=650.24)
146 | plt.axis('off')
147 |         plt.grid(False)
148 | plt.imshow(msg[i-1])
149 | plt.savefig("{}_Time_{}_layer_{}.jpg".format(name,epoch+1,i))
150 | #print("Done")
151 |
152 | def download_acc(acc,epoch):
153 | np.savetxt("Acc{}.txt".format(epoch+1),acc)
154 |
155 |
156 | with tf.device('/cpu:0'):
157 | data_x = tf.placeholder(tf.float32,shape=(batch_size,size,size))
158 | data_y = tf.placeholder(tf.float32,shape=(batch_size,10))
159 |
160 | amp=[
161 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32),
162 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32),
163 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32),
164 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32),
165 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32),
166 | tf.Variable(make_random(shape=(size,size)),dtype=tf.float32)
167 | ]
168 |
169 | phase = [
170 | tf.constant(np.random.random(size=(size,size)),dtype=tf.float32),
171 | tf.constant(np.random.random(size=(size,size)),dtype=tf.float32)
172 | ]
173 |
174 | with tf.variable_scope('FullyConnected'):
175 | layer_1 = add_layer_amp(tf.cast(data_x,dtype=tf.complex64),amp[0],phase[0],size,delta)
176 | layer_2 = add_layer_amp(layer_1,amp[1],phase[1],size,delta)
177 | layer_3 = add_layer_amp(layer_2,amp[2],phase[1],size,delta)
178 | layer_4 = add_layer_amp(layer_3,amp[3],phase[1],size,delta)
179 | layer_5 = add_layer_amp(layer_4,amp[4],phase[1],size,delta)
180 | output_layer = add_layer_amp(layer_5,amp[5],phase[1],size,delta)
181 | output = _propogation(output_layer)
182 |
183 | with tf.variable_scope('Loss'):
184 | logits_abs = tf.square(tf.nn.softmax(ten_regions(tf.abs(output))))
185 | loss = tf.reduce_sum(tf.square(logits_abs-data_y))
186 | train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
187 |
188 | with tf.variable_scope('Accuracy'):
189 | pre_correct = tf.equal(tf.argmax(data_y,1),tf.argmax(logits_abs,1))
190 | accuracy = tf.reduce_mean(tf.cast(pre_correct,tf.float32))
191 |
192 | init = tf.global_variables_initializer()
193 | train_epochs = 20
194 | test_epochs = 5
195 | session = tf.Session()
196 | with tf.device('/gpu:0'):
197 | session.run(init)
198 | total_batch = int(mnist.train.num_examples / batch_size)
199 | #total_batch = 10
200 |
201 | for epoch in tqdm(range(train_epochs)):
202 | for batch in tqdm(range(total_batch)):
203 | batch_x,batch_y = mnist.train.next_batch(batch_size)
204 | session.run(train_op,feed_dict={data_x:change(batch_x),data_y:batch_y})
205 |
206 | loss_,acc = session.run([loss,accuracy],feed_dict={data_x:change(batch_x),data_y:batch_y})
207 | print("epoch :{} loss:{:.4f} acc:{:.4f}".format(epoch+1,loss_,acc))
208 |
209 | with tf.device('/cpu:0'):
210 | msg_amp = np.array(session.run(amp))
211 | download_text(msg_amp,epoch,name='Amp')
212 | #download_image(msg_amp,epoch,name='Amp')
213 | print("Optimizer finished")
--------------------------------------------------------------------------------
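The readout in ten_regions() above scores each digit as the mean intensity of a small detector square; coordinates are written on a 512-pixel reference grid and rescaled to the working size. A plain-numpy sketch of the indexing for one region (the test array is only for illustration):

    import numpy as np

    size, base = 28, 512

    def rang(arr, shape):
        x0 = shape[0] * size // base
        y0 = shape[2] * size // base
        delta = (shape[1] - shape[0]) * size // base
        return arr[x0:x0 + delta, y0:y0 + delta]

    a = np.arange(size * size, dtype=float).reshape(size, size)
    region = rang(a, (120, 170, 120, 170))
    print(region.shape, region.mean())    # (2, 2) 188.5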
/python-package/onnet/D2NNet.py:
--------------------------------------------------------------------------------
1 | # Authors: Yingshi Chen(gsp.cys@gmail.com)
2 |
3 | '''
4 | PyTorch implementation of D2CNN ------ All-optical machine learning using diffractive deep neural networks
5 | '''
6 |
7 | import torch
8 | import math, random
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from .Z_utils import COMPLEX_utils as Z
12 | from .PoolForCls import *
13 | from .Loss import *
14 | from .SparseSupport import *
15 | from .FFT_layer import *
16 | import numpy as np
17 | from .DiffractiveLayer import *
18 | import cv2
19 | useAttention=False
20 | if useAttention:
21 | import entmax
22 | #from torchscope import scope
23 |
24 | class DNET_config:
25 | def __init__(self,batch,lr_base,modulation="phase",init_value = "random",random_seed=42,
26 | support=SuppLayer.SUPP.exp,isFC=False):
27 | '''
28 |
29 | :param modulation:
30 | :param init_value: ["random","zero","random_reverse","reverse","chunk"]
31 | :param support:
32 | '''
33 | self.custom_legend = "Express Wavenet" #"Express_OFF" "Express Wavenet","Pan_OFF Express_OFF" #for paper and debug
34 | self.seed = random_seed
35 | seed_everything(self.seed)
36 | self.init_value = init_value # "random" "zero"
37 | self.rDrop = 0
38 | self.support = support #None
39 | self.modulation = modulation #["phase","phase_amp"]
40 | self.output_chunk = "2D" #["1D","2D"]
41 | self.output_pooling = "max"
42 | self.batch = batch
43 | self.learning_rate = lr_base
44 | self.isFC = isFC
45 | self.input_scale = 1
46 | self.wavelet = None #dict paramter for wavelet
47 | #if self.isFC == True: self.learning_rate = lr_base/10
48 | self.input_plane = "" #"fourier"
49 |
50 | def env_title(self):
51 | title=f"{self.support.value}"
52 | if self.isFC: title += "[FC]"
53 | if self.custom_legend is not None:
54 | title = title + f"_{self.custom_legend}"
55 | return title
56 |
57 | def __repr__(self):
58 | main_str = f"lr={self.learning_rate}_ mod={self.modulation} input={self.input_scale} detector={self.output_chunk} " \
59 | f"support={self.support}"
60 | if self.isFC: main_str+=" [FC]"
61 | if self.custom_legend is not None:
62 | main_str = main_str + f"_{self.custom_legend}"
63 | return main_str
64 |
65 | class D2NNet(nn.Module):
66 | @staticmethod
67 | def binary_loss(output, target, reduction='mean'):
68 | nSamp = target.shape[0]
69 | nGate = output.shape[1] // 2
70 | loss = 0
71 | for i in range(nGate):
72 | target_i = target % 2
73 | val_2 = torch.stack([output[:,2*i],output[:,2*i+1]],1)
74 |
75 | loss_i = F.cross_entropy(val_2, target_i, reduction=reduction)
76 | loss += loss_i
77 |             target = (target - target_i) // 2
78 |
79 | # loss = F.nll_loss(output, target, reduction=reduction)
80 | return loss
81 |
82 | @staticmethod
83 | def logit_loss(output, target, reduction='mean'): #https://stackoverflow.com/questions/53628622/loss-function-its-inputs-for-binary-classification-pytorch
84 | nSamp = target.shape[0]
85 | nGate = output.shape[1]
86 | loss = 0
87 | loss_BCE = nn.BCEWithLogitsLoss()
88 | for i in range(nGate):
89 | target_i = target % 2
90 | out_i = output[:,i]
91 | loss_i = loss_BCE(out_i, target_i.double())
92 | loss += loss_i
93 |             target = (target - target_i) // 2
94 | return loss
95 |
96 | def predict(self,output):
97 | if self.config.support == "binary":
98 | nGate = output.shape[1] // 2
99 | #assert nGate == self.n
100 | pred = 0
101 | for i in range(nGate):
102 | no = 2*(nGate - 1 - i)
103 | val_2 = torch.stack([output[:, no], output[:, no + 1]], 1)
104 | pred_i = val_2.max(1, keepdim=True)[1] # get the index of the max log-probability
105 | pred = pred * 2 + pred_i
106 | elif self.config.support == "logit":
107 | nGate = output.shape[1]
108 | # assert nGate == self.n
109 | pred = 0
110 | for i in range(nGate):
111 | no = nGate - 1 - i
112 |                 val_2 = torch.sigmoid(output[:, no])
113 | pred_i = (val_2+0.5).long()
114 | pred = pred * 2 + pred_i
115 | else:
116 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
117 | #pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
118 | return pred
119 |
120 | def GetLayer_(self):
121 | # layer = DiffractiveAMP
122 | if self.config.wavelet is None:
123 | layer = DiffractiveLayer
124 | else:
125 | layer = DiffractiveWavelet
126 | return layer
127 |
128 | def __init__(self,IMG_size,nCls,nDifrac,config):
129 | super(D2NNet, self).__init__()
130 | self.M,self.N=IMG_size
131 | self.z_modulus = Z.modulus
132 | self.nDifrac = nDifrac
133 | #self.isFC = False
134 | self.nClass = nCls
135 | #self.init_value = "random" #"random" "zero"
136 | self.config = config
137 | self.title = f"DNNet"
138 | self.highWay = 1 #1,2,3
139 | if self.config.input_plane == "fourier":
140 | self.highWay = 0
141 |
142 | if hasattr(self.config,'feat_extractor'):
143 | if self.config.feat_extractor!="last_layer":
144 | self.feat_extractor = []
145 |
146 | if self.config.output_chunk == "2D":
147 | assert(self.M*self.N>=self.nClass)
148 | else:
149 | assert (self.M >= self.nClass and self.N >= self.nClass)
150 | print(f"D2NNet nClass={nCls} shape={self.M,self.N}")
151 |
152 |
153 | layer = self.GetLayer_()
154 | #fl = FFT_Layer(self.M, self.N,config,isInv=False)
155 | self.DD = nn.ModuleList([
156 | layer(self.M, self.N,config) for i in range(self.nDifrac)
157 | ])
158 | if self.config.input_plane=="fourier":
159 | self.DD.insert(0,FFT_Layer(self.M, self.N,config,isInv=False))
160 | self.DD.append(FFT_Layer(self.M, self.N,config,isInv=True))
161 | self.nD = len(self.DD)
162 | self.laySupp = None
163 |
164 |         if self.highWay>0:
165 |             # one learnable mixing weight per diffractive layer
166 |             self.wLayer = torch.nn.Parameter(torch.ones(len(self.DD)))
167 |             if self.highWay==2:
168 |                 self.wLayer.data.uniform_(-1, 1)
169 |
170 |
171 | #self.DD.append(DropOutLayer(self.M, self.N,drop=0.9999))
172 | if self.config.isFC:
173 | self.fc1 = nn.Linear(self.M*self.N, self.nClass)
174 | self.loss = UserLoss.cys_loss
175 | self.title = f"DNNet_FC"
176 | elif self.config.support!=None:
177 | self.laySupp = SuppLayer(config,self.nClass)
178 | self.last_chunk = ChunkPool(self.laySupp.nChunk, config, pooling=config.output_pooling)
179 | self.loss = UserLoss.cys_loss
180 | a = self.config.support.value
181 | self.title = f"DNNet_{self.config.support.value}"
182 | else:
183 | self.last_chunk = ChunkPool(self.nClass,config,pooling=config.output_pooling)
184 | self.loss = UserLoss.cys_loss
185 |
186 | if self.config.wavelet is not None:
187 | self.title = self.title+f"_W"
188 | if self.highWay>0:
189 | self.title = self.title + f"_H"
190 | if self.config.custom_legend is not None:
191 | self.title = self.title + f"_{self.config.custom_legend}"
192 |
193 | '''
194 | BinaryChunk is pool
195 | elif self.config.support=="binary":
196 | self.last_chunk = BinaryChunk(self.nClass, pooling="max")
197 | self.loss = D2NNet.binary_loss
198 | self.title = f"DNNet_binary"
199 | elif self.config.support == "logit":
200 | self.last_chunk = BinaryChunk(self.nClass, isLogit=True, pooling="max")
201 | self.loss = D2NNet.logit_loss
202 | '''
203 |
204 | def visualize(self,visual,suffix):
205 | no = 0
206 | for plot in visual.plots:
207 | images,path = [],""
208 | if plot['object']=='layer pattern':
209 | path = f"{visual.img_dir}/{suffix}.jpg"
210 | for no,layer in enumerate(self.DD):
211 | info = f"{suffix},{no}]"
212 | title = f"layer_{no+1}"
213 | if self.highWay==2:
214 | a = self.wLayer[no]
215 | a = torch.sigmoid(a)
216 | info = info+f"_{a:.2g}"
217 | elif self.highWay==1:
218 | a = self.wLayer[no]
219 | info = info+f"_{a:.2g}"
220 | title = title+f" w={a:.2g}"
221 | image = layer.visualize(visual,info,{'save':False,'title':title})
222 | images.append(image)
223 | no=no+1
224 | if len(images)>0:
225 | image_all = np.concatenate(images, axis=1)
226 | #cv2.imshow("", image_all); cv2.waitKey(0)
227 | cv2.imwrite(path,image_all)
228 |
229 | def legend(self):
230 | if self.config.custom_legend is not None:
231 | leg_ = self.config.custom_legend
232 | else:
233 | leg_ = self.title
234 | return leg_
235 |
236 | def __repr__(self):
237 | main_str = super(D2NNet, self).__repr__()
238 | main_str += f"\n========init={self.config.init_value}"
239 | return main_str
240 |
241 | def input_trans(self,x): # square-rooted and normalized
242 | #x = x.double()*self.config.input_scale
243 | if True:
244 | x = x*self.config.input_scale
245 | x_0,x_1 = torch.min(x).item(),torch.max(x).item()
246 | assert x_0>=0
247 | x = torch.sqrt(x)
248 |         else: # unclear why this variant fails; kept for reference
249 | x = Z.exp_euler(x*2*math.pi).float()
250 | x_0,x_1 = torch.min(x).item(),torch.max(x).item()
251 | return x
252 |
253 | def do_classify(self,x):
254 | if self.config.isFC:
255 | x = torch.flatten(x, 1)
256 | x = self.fc1(x)
257 | return x
258 |
259 | x = self.last_chunk(x)
260 | if self.laySupp != None:
261 | x = self.laySupp(x)
262 | # output = F.log_softmax(x, dim=1)
263 | return x
264 |
265 | def OnLayerFeats(self):
266 | pass
267 |
268 | def forward(self, x):
269 | if hasattr(self, 'feat_extractor'):
270 | self.feat_extractor.clear()
271 | nSamp,nChannel = x.shape[0],x.shape[1]
272 | assert(nChannel==1)
273 | if nChannel>1:
274 | no = random.randint(0,nChannel-1)
275 | x = x[:,0:1,...]
276 | x = self.input_trans(x)
277 | if hasattr(self,'visual'): self.visual.onX(x.cpu(), f"X@input")
278 | summary = 0
279 | for no,layD in enumerate(self.DD):
280 | info = layD.__repr__()
281 | x = layD(x)
282 | if hasattr(self,'feat_extractor'):
283 | self.feat_extractor.append((self.z_modulus(x),self.wLayer[no]))
284 | if hasattr(self,'visual'): self.visual.onX(x,f"X@{no+1}")
285 | if self.highWay==2:
286 | s = torch.sigmoid(self.wLayer[no])
287 | summary+=x*s
288 | x = x*(1-s)
289 | elif self.highWay==1:
290 | summary += x * self.wLayer[no]
291 | elif self.highWay==3:
292 | summary += self.z_modulus(x) * self.wLayer[no]
293 | if self.highWay==2:
294 | x=x+summary
295 | x = self.z_modulus(x)
296 | elif self.highWay == 1:
297 | x = summary
298 | x = self.z_modulus(x)
299 | elif self.highWay == 3:
300 | x = summary
301 | elif self.highWay == 0:
302 | x = self.z_modulus(x)
303 | if hasattr(self,'visual'): self.visual.onX(x,f"X@output")
304 |
305 |
306 | if hasattr(self,'feat_extractor'):
307 | return
308 | elif hasattr(self.config,'feat_extractor') and self.config.feat_extractor=="last_layer":
309 | return x
310 | else:
311 | output = self.do_classify(x)
312 | return output
313 |
314 | class MultiDNet(D2NNet):
315 | def __init__(self, IMG_size,nCls,nInterDifrac,freq_list,config,shareWeight=True):
316 | super(MultiDNet, self).__init__(IMG_size,nCls,nInterDifrac,config)
317 | self.isShareWeight=shareWeight
318 | self.freq_list = freq_list
319 | nFreq = len(self.freq_list)
320 | del self.DD; self.DD = None
321 | self.wFreq = torch.nn.Parameter(torch.ones(nFreq))
322 | layer = self.GetLayer_()
323 | self.freq_nets=nn.ModuleList([
324 | nn.ModuleList([
325 | layer(self.M, self.N, self.config, HZ=freq) for i in range(self.nDifrac)
326 | ]) for freq in freq_list
327 | ])
328 | if self.isShareWeight:
329 | nSubNet = len(self.freq_nets)
330 | net_0 = self.freq_nets[0]
331 | for i in range(1,nSubNet):
332 | net_1 = self.freq_nets[i]
333 | for j in range(self.nDifrac):
334 | net_1[j].share_weight(net_0[j])
335 |
336 |
337 |     def legend(self):
338 |         if self.config.custom_legend is not None:
339 |             leg_ = self.config.custom_legend
340 |         else:
341 |             leg_ = f"MF_DNet({len(self.freq_list)} channels)"
342 |         return leg_
343 |
344 | def __repr__(self):
345 | main_str = super(MultiDNet, self).__repr__()
346 | main_str += f"\nfreq_list={self.freq_list}_"
347 | return main_str
348 |
349 | def forward(self, x0):
350 | nSamp = x0.shape[0]
351 | x_sum = 0
352 | for id,fNet in enumerate(self.freq_nets):
353 | x = self.input_trans(x0)
354 | #d0,d1=x0.min(),x0.max()
355 | #x = x0.double()
356 | for layD in fNet:
357 | x = layD(x)
358 | #x_sum = torch.max(x_sum,self.z_modulus(x))).values()
359 | x_sum += self.z_modulus(x)*self.wFreq[id]
360 | x = x_sum
361 |
362 | output = self.do_classify(x)
363 | return output
364 |
365 | def main():
366 | pass
367 |
368 | if __name__ == '__main__':
369 | main()
--------------------------------------------------------------------------------
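A stripped-down sketch of the highWay==1 "express" path in D2NNet.forward() above: the network output is a learnable weighted sum of every layer's field rather than the last layer alone. Two toy callables stand in for the diffractive layers:

    import torch

    layers = [lambda t: t * 2.0, lambda t: t + 1.0]   # stand-in layers
    wLayer = torch.nn.Parameter(torch.ones(len(layers)))

    x = torch.ones(3)
    summary = 0
    for no, layD in enumerate(layers):
        x = layD(x)
        summary = summary + x * wLayer[no]            # summary += x * w[no]
    print(summary)    # 2*w0 + 3*w1 = 5 elementwise, with w initialised to 1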
/python-package/onnet/DiffractiveLayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .Z_utils import COMPLEX_utils as Z
3 | from .some_utils import *
4 | import numpy as np
5 | import random
6 | import torch.nn as nn
7 | import matplotlib
8 | #matplotlib.use('Agg')
9 | import matplotlib.pyplot as plt
10 |
11 |
12 | #https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-custom-nn-modules
13 | class DiffractiveLayer(torch.nn.Module):
14 | def SomeInit(self, M_in, N_in,HZ=0.4e12):
15 | assert (M_in == N_in)
16 | self.M = M_in
17 | self.N = N_in
18 | self.z_modulus = Z.modulus
19 | self.size = M_in
20 | self.delta = 0.03
21 | self.dL = 0.02
22 | self.c = 3e8
23 | self.Hz = HZ#0.4e12
24 |
25 | self.H_z = self.Init_H()
26 |
27 | def __repr__(self):
28 | #main_str = super(DiffractiveLayer, self).__repr__()
29 | main_str = f"DiffractiveLayer_[{(int)(self.Hz/1.0e9)}G]_[{self.M},{self.N}]"
30 | return main_str
31 |
32 | def __init__(self, M_in, N_in,config,HZ=0.4e12):
33 | super(DiffractiveLayer, self).__init__()
34 | self.SomeInit(M_in, N_in,HZ)
35 | assert config is not None
36 | self.config = config
37 | #self.init_value = init_value
38 | #self.rDrop = rDrop
39 | if not hasattr(self.config,'wavelet') or self.config.wavelet is None:
40 | if self.config.modulation=="phase":
41 | self.transmission = torch.nn.Parameter(data=torch.Tensor(self.size, self.size), requires_grad=True)
42 | else:
43 | self.transmission = torch.nn.Parameter(data=torch.Tensor(self.size, self.size, 2), requires_grad=True)
44 |
45 | init_param = self.transmission.data
46 | if self.config.init_value=="reverse": #
47 | half=self.transmission.data.shape[-2]//2
48 | init_param[..., :half, :] = 0
49 | init_param[..., half:, :] = np.pi
50 | elif self.config.init_value=="random":
51 | init_param.uniform_(0, np.pi*2)
52 |         elif self.config.init_value == "random_reverse":
53 |             init_param.copy_(torch.randint_like(init_param, 0, 2)*np.pi)
54 |         elif self.config.init_value == "chunk":
55 |             sections = split__sections()
56 |             for xx in init_param.split(sections, -1):
57 |                 xx.fill_(random.uniform(0, np.pi*2))
58 |
59 | #self.rDrop = config.rDrop
60 |
61 | #self.bias = torch.nn.Parameter(data=torch.Tensor(1, 1), requires_grad=True)
62 |
63 | def visualize(self,visual,suffix, params):
64 | param = self.transmission.data
65 | name = f"{suffix}_{self.config.modulation}_"
66 | return visual.image(name,param, params)
67 |
68 | def share_weight(self,layer_1):
69 | tp = type(self)
70 | assert(type(layer_1)==tp)
71 | #del self.transmission
72 | #self.transmission = layer_1.transmission
73 |
74 | def Init_H(self):
75 | # Parameter
76 | N = self.size
77 | df = 1.0 / self.dL
78 | d=self.delta
79 | lmb=self.c / self.Hz
80 | k = np.pi * 2.0 / lmb
81 | D = self.dL * self.dL / (N * lmb)
82 | # phase
83 | def phase(i, j):
84 | i -= N // 2
85 | j -= N // 2
86 | return ((i * df) * (i * df) + (j * df) * (j * df))
87 |
88 | ph = np.fromfunction(phase, shape=(N, N), dtype=np.float32)
89 | # H
90 | H = np.exp(1.0j * k * d) * np.exp(-1.0j * lmb * np.pi * d * ph)
91 | H_f = np.fft.fftshift(H)*self.dL*self.dL/(N*N)
92 | # print(H_f); print(H)
93 | H_z = np.zeros(H_f.shape + (2,))
94 | H_z[..., 0] = H_f.real
95 | H_z[..., 1] = H_f.imag
96 | H_z = torch.from_numpy(H_z).cuda()
97 | return H_z
98 |
99 | def Diffractive_(self,u0, theta=0.0):
100 | if Z.isComplex(u0):
101 | z0 = u0
102 | else:
103 | z0 = u0.new_zeros(u0.shape + (2,))
104 | z0[...,0] = u0
105 |
106 | N = self.size
107 | df = 1.0 / self.dL
108 |
109 | z0 = Z.fft(z0)
110 | u1 = Z.Hadamard(z0,self.H_z.float())
111 | u2 = Z.fft(u1,"C2C",inverse=True)
112 | return u2 * N * N * df * df
113 |
114 | def GetTransCoefficient(self):
115 | '''
116 | eps = 1e-5; momentum = 0.1; affine = True
117 |
118 | mean = torch.mean(self.transmission, 1)
119 | vari = torch.var(self.transmission, 1)
120 | amp_bn = torch.batch_norm(self.transmission,mean,vari)
121 | :return:
122 | '''
123 | amp_s = Z.exp_euler(self.transmission)
124 |
125 | return amp_s
126 |
127 | def forward(self, x):
128 | diffrac = self.Diffractive_(x)
129 | amp_s = self.GetTransCoefficient()
130 | x = Z.Hadamard(diffrac,amp_s.float())
131 | if(self.config.rDrop>0):
132 |             drop = Z.rDrop2D(1-self.config.rDrop,(self.M,self.N),isComlex=True)
133 | x = Z.Hadamard(x, drop)
134 | #x = x+self.bias
135 | return x
136 |
137 | class DiffractiveAMP(DiffractiveLayer):
138 |     def __init__(self, M_in, N_in,config,HZ=0.4e12):
139 |         super(DiffractiveAMP, self).__init__(M_in, N_in,config,HZ)
140 | #self.amp = torch.nn.Parameter(data=torch.Tensor(self.size, self.size, 2), requires_grad=True)
141 | self.transmission.data.uniform_(0, 1)
142 |
143 | def GetTransCoefficient(self):
144 | # amp_s = Z.sigmoid(self.amp)
145 | # amp_s = torch.clamp(self.amp, 1.0e-6, 1)
146 | amp_s = self.transmission
147 | return amp_s
148 |
149 | class DiffractiveWavelet(DiffractiveLayer):
150 | def __init__(self, M_in, N_in,config,HZ=0.4e12):
151 | super(DiffractiveWavelet, self).__init__(M_in, N_in,config,HZ)
152 | #self.hough = torch.nn.Parameter(data=torch.Tensor(2), requires_grad=True)
153 | self.Init_DisTrans()
154 | #self.GetXita()
155 |
156 | def __repr__(self):
157 | main_str = f"Diffrac_Wavelet_[{(int)(self.Hz/1.0e9)}G]_[{self.M},{self.N}]"
158 | return main_str
159 |
160 | def share_weight(self,layer_1):
161 | tp = type(self)
162 | assert(type(layer_1)==tp)
163 | del self.wavelet
164 | self.wavelet = layer_1.wavelet
165 | del self.dis_map
166 | self.dis_map = layer_1.dis_map
167 | del self.wav_indices
168 | self.wav_indices = layer_1.wav_indices
169 |
170 |
171 | def Init_DisTrans(self):
172 | # the radial map is centered at a random origin (a deterministic center (M-1)/2, (N-1)/2 would also work)
173 | origin_r = random.uniform(0, self.M-1)
174 | origin_c = random.uniform(0, self.N - 1)
175 | self.dis_map={}
176 | #self.dis_trans = torch.zeros((self.size, self.size)).int()
177 | self.wav_indices = torch.empty(self.size*self.size, dtype=torch.long).cuda()  # filled below
178 | nz=0
179 | for r in range(self.M):
180 | for c in range(self.N):
181 | off = np.sqrt((r - origin_r) * (r - origin_r) + (c - origin_c) * (c - origin_c))
182 | i_off = (int)(off+0.5)
183 | if i_off not in self.dis_map:
184 | self.dis_map[i_off]=len(self.dis_map)
185 | id = self.dis_map[i_off]
186 | #self.dis_trans[r, c] = id
187 | self.wav_indices[nz] = id; nz=nz+1
188 | #print(f"[{r},{c}]={self.dis_trans[r, c]}")
189 | nD = len(self.dis_map)
190 | if False:
191 | plt.imshow(self.dis_trans.numpy())
192 | plt.show()
193 |
194 | self.wavelet = torch.nn.Parameter(data=torch.Tensor(nD), requires_grad=True)
195 | self.wavelet.data.uniform_(0, np.pi*2)
196 | #self.dis_trans = self.dis_trans.cuda()
197 |
198 | def GetXita(self):
199 | if False:
200 | xita = torch.zeros((self.size, self.size))
201 | for r in range(self.M):
202 | for c in range(self.N):
203 | pos = self.dis_trans[r, c]
204 | xita[r,c] = self.wavelet[pos]
205 | origin_r,origin_c=self.M/2,self.N/2
206 | #xita = self.dis_trans*self.hough[0]+self.hough[1]
207 | else:
208 | xita = torch.index_select(self.wavelet, 0, self.wav_indices)
209 | xita = xita.view(self.size, self.size)
210 |
211 | # print(xita)
212 | return xita
213 |
214 | def GetTransCoefficient(self):
215 | xita = self.GetXita()
216 | amp_s = Z.exp_euler(xita)
217 | return amp_s
218 |
219 | def visualize(self,visual,suffix, params):
220 | xita = self.GetXita()
221 | name = f"{suffix}"
222 | return visual.image(name,torch.sin(xita.detach()), params)
--------------------------------------------------------------------------------
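The layer above realizes free-space (Fresnel) propagation as a pointwise product in the frequency domain: Init_H builds the transfer function H and Diffractive_ applies FFT, a Hadamard product with H, then the inverse FFT. Below is a minimal sketch of the same computation, rewritten against the modern torch.fft complex API (the file itself targets the legacy torch.fft(input, 2) real-pair API; N, dL, d and HZ are illustrative values):

import numpy as np
import torch

N, dL, HZ, d = 112, 0.02, 0.4e12, 0.03       # grid size, aperture [m], frequency [Hz], distance [m]
lmb = 3e8 / HZ                               # wavelength
df = 1.0 / dL                                # frequency step

# H(fx, fy) = exp(i*k*d) * exp(-i*pi*lmb*d*(fx^2 + fy^2)), as in Init_H
fx = (np.arange(N) - N // 2) * df
FX, FY = np.meshgrid(fx, fx, indexing="ij")
H = np.exp(1j * (2 * np.pi / lmb) * d) * np.exp(-1j * lmb * np.pi * d * (FX**2 + FY**2))
H = torch.from_numpy(np.fft.fftshift(H))

u0 = torch.randn(N, N, dtype=torch.complex128)    # incident field
u1 = torch.fft.ifft2(torch.fft.fft2(u0) * H)      # diffracted field, cf. Diffractive_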
/python-package/onnet/DropOutLayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .Z_utils import COMPLEX_utils as Z
3 |
4 | #Very strange behavior of DROPOUT
5 | class DropOutLayer(torch.nn.Module):
6 | def __init__(self, M_in, N_in,drop=0.5):
7 | super(DropOutLayer, self).__init__()
8 | assert (M_in == N_in)
9 | self.M = M_in
10 | self.N = N_in
11 | self.rDrop = drop
12 |
13 | def forward(self, x):
14 | assert(Z.isComplex(x))
15 | nX = x.numel()//2
16 | d_shape=x.shape[:-1]
17 | drop = np.random.binomial(1, self.rDrop, size=d_shape).astype(np.float32)  # np.float is deprecated; mask is 1 (keep) with probability rDrop
18 | #print(f"x={x.shape} drop={drop.shape}")
19 | drop = torch.from_numpy(drop).cuda()
20 | x[...,0] *= drop
21 | x[...,1] *= drop
22 | return x
--------------------------------------------------------------------------------
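DropOutLayer draws one real Bernoulli mask per spatial position and applies it to both the real and the imaginary channel, so a complex value is either kept whole or zeroed whole. A CPU-only sketch of that shared-mask idea (shapes illustrative):

import numpy as np
import torch

x = torch.randn(8, 1, 28, 28, 2)                             # complex field: trailing dim of 2
keep = np.random.binomial(1, 0.5, size=x.shape[:-1]).astype(np.float32)
mask = torch.from_numpy(keep)
x[..., 0] *= mask                                            # real part and ...
x[..., 1] *= mask                                            # ... imaginary part share the mask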
/python-package/onnet/FFT_layer.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Yingshi Chen
3 |
4 | @Date: 2020-04-10 11:22:27
5 | @
6 | # Description:
7 | '''
8 |
9 | import torch
10 | from .Z_utils import COMPLEX_utils as Z
11 | from .some_utils import *
12 | import numpy as np
13 | import random
14 | import torch.nn as nn
15 | import matplotlib
16 | #matplotlib.use('Agg')
17 | import matplotlib.pyplot as plt
18 |
19 | class FFT_Layer(torch.nn.Module):
20 | def SomeInit(self, M_in, N_in,isInv=False):
21 | assert (M_in == N_in)
22 | self.M = M_in
23 | self.N = N_in
24 | self.isInv = isInv
25 |
26 | def __repr__(self):
27 | i_ = "_i" if self.isInv else ""
28 | main_str = f"FFT_Layer{i_}_[{self.M},{self.N}]"
29 | return main_str
30 |
31 | def __init__(self, M_in, N_in,config,isInv=False):
32 | super(FFT_Layer, self).__init__()
33 | self.SomeInit(M_in, N_in,isInv)
34 | assert config is not None
35 | self.config = config
36 | #self.init_value = init_value
37 |
38 | def visualize(self,visual,suffix, params):
39 | param = self.transmission.data
40 | name = f"{suffix}_{self.config.modulation}_"
41 | return visual.image(name,param, params)
42 |
43 |
44 | def Diffractive_(self,u0, theta=0.0): # copied from DiffractiveLayer; relies on self.size/dL/H_z, which FFT_Layer itself does not define
45 | if Z.isComplex(u0):
46 | z0 = u0
47 | else:
48 | z0 = u0.new_zeros(u0.shape + (2,))
49 | z0[...,0] = u0
50 |
51 | N = self.size
52 | df = 1.0 / self.dL
53 |
54 | z0 = Z.fft(z0)
55 | u1 = Z.Hadamard(z0,self.H_z.float())
56 | u2 = Z.fft(u1,"C2C",inverse=True)
57 | return u2 * N * N * df * df
58 |
59 | def forward(self, x):
60 | #return x
61 | if Z.isComplex(x):
62 | z0 = x
63 | else:
64 | z0 = x.new_zeros(x.shape + (2,))
65 | z0[...,0] = x
66 | x = Z.fft(z0, "C2C", inverse=self.isInv)  # forward or inverse 2-D FFT
67 | x_0, x_1 = torch.min(x), torch.max(x)  # range probe (unused)
71 | return x
72 |
73 | def trans(img):
74 | plt.figure(figsize=(10,8))
75 | plt.subplot(121),plt.imshow(img, cmap = 'gray')
76 | plt.title('Input Image'), plt.xticks([]), plt.yticks([])
77 | f = (abs(np.fft.fftshift(np.fft.fftn(img))))**0.25*(255)**3 # amplify; fftn must be referenced through np.fft
78 | plt.subplot(122),plt.imshow(f, cmap = 'gray')
79 | plt.title('Spectrum'), plt.xticks([]), plt.yticks([])
80 | plt.show()
--------------------------------------------------------------------------------
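The trans() helper above compresses the FFT magnitude (|F|^0.25) so the spectrum stays visible next to the image. The same quantity without the plotting scaffold (img is any 2-D array; values illustrative):

import numpy as np

img = np.random.rand(64, 64)
spec = np.abs(np.fft.fftshift(np.fft.fftn(img))) ** 0.25    # centered, contrast-compressed spectrum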
/python-package/onnet/Loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | class UserLoss(object):
5 |
6 | @staticmethod
7 | def cys_loss(output, target, reduction='mean'):
8 | #loss = F.binary_cross_entropy(output, target, reduction=reduction)
9 | loss = F.cross_entropy(output, target, reduction=reduction)
10 | #loss = F.nll_loss(output, target, reduction=reduction)
11 |
12 | return loss
--------------------------------------------------------------------------------
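cys_loss is a thin wrapper over F.cross_entropy, so it expects raw logits and integer class targets. A usage sketch (assuming the onnet package is importable):

import torch
from onnet.Loss import UserLoss

logits = torch.randn(4, 10)                  # batch of 4, 10 classes
target = torch.tensor([1, 0, 3, 9])
loss = UserLoss.cys_loss(logits, target)     # same value as F.cross_entropy(logits, target)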
/python-package/onnet/NET_config.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | '''
5 | parser.add_argument is better than NET_config
6 | '''
7 | class NET_config:
8 | def __init__(self,net_type, data_set, IMG_size, lr_base, batch_size,nClass,nLayer=-1):
9 | #seed_everything(self.seed)
10 | self.net_type = net_type
11 | self.data_set = data_set
12 | self.IMG_size = IMG_size
13 | self.lr_base = lr_base
14 | self.batch_size = batch_size
15 | self.nClass = nClass
16 | self.nLayer = nLayer
17 |
--------------------------------------------------------------------------------
/python-package/onnet/Net_Instance.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Yingshi Chen
3 |
4 | @Date: 2020-01-16 15:08:16
5 | @
6 | # Description:
7 | '''
8 | from .D2NNet import *
9 | from .RGBO_CNN import *
10 | from .OpticalFormer import *
11 | import math
12 | from copy import copy, deepcopy
13 |
14 | def dump_model_params(model):
15 | nzParams = 0
16 | for name, param in model.named_parameters():
17 | if param.requires_grad:
18 | nzParams += param.nelement()
19 | print(f"\t{name}={param.nelement()}")
20 | print(f"========All parameters={nzParams}")
21 | return nzParams
22 |
23 | def Net_dump(net):
24 | nzParams=dump_model_params(net)
25 |
26 | #def DNet_instance(net_type,dataset,IMG_size,lr_base,batch_size,nClass,nLayer): needs a rewrite, there is only one config now
27 | def DNet_instance(config):
28 | net_type, dataset, IMG_size, lr_base, batch_size, nClass, nLayer = \
29 | config.net_type,config.data_set, config.IMG_size, config.lr_base, config.batch_size, config.nClass, config.nLayer
30 | if net_type == "BiDNet":
31 | lr_base = 0.01
32 | if dataset == "emnist":
33 | lr_base = 0.01
34 |
35 | config_base = DNET_config(batch=batch_size, lr_base=lr_base)
36 | if hasattr(config,'feat_extractor'):
37 | config_base.feat_extractor = config.feat_extractor
38 | env_title = f"{net_type}_{dataset}_{IMG_size}_{lr_base}_{config_base.env_title()}"
39 | if net_type == "MF_DNet":
40 | freq_list = [0.3e12, 0.35e12, 0.4e12, 0.42e12]
41 | env_title = env_title + f"_C{len(freq_list)}"
42 | if net_type == "BiDNet":
43 | config_base = DNET_config(batch=batch_size, lr_base=lr_base, chunk="binary")
44 |
45 | if net_type == "cnn":
46 | model = Mnist_Net(config=config_base)
47 | return env_title, model
48 |
49 | if net_type == "DNet":
50 | model = D2NNet(IMG_size, nClass, nLayer, config_base)
51 | elif net_type == "WNet":
52 | config_base.wavelet={"nWave":3}
53 | model = D2NNet(IMG_size, nClass, nLayer, config_base)
54 | elif net_type == "MF_DNet":
55 | # model = MultiDNet(IMG_size, nClass, nLayer,[0.3e12,0.35e12,0.4e12,0.42e12,0.5e12,0.6e12], DNET_config())
56 | model = MultiDNet(IMG_size, nClass, nLayer, [0.3e12, 0.35e12, 0.4e12, 0.42e12], config_base)
57 | elif net_type == "MF_WNet":
58 | config_base.wavelet = {"nWave": 3}
59 | model = MultiDNet(IMG_size, nClass, nLayer, [0.3e12, 0.35e12, 0.4e12, 0.42e12], config_base)
60 | elif net_type == "BiDNet":
61 | model = D2NNet(IMG_size, nClass, nLayer, config_base)
62 | elif net_type == "OptFormer":
63 | pass  # no model is constructed on this branch, so 'model' would be undefined below
64 |
65 | #model.double()
66 |
67 | return env_title, model
68 |
69 | def RGBO_CNN_instance(config):
70 | assert config.net_type == "RGBO_CNN"
71 | env_title = f"{config.net_type}_{config.dnet_type}_{config.data_set}_{config.IMG_size}_{config.lr_base}_"
72 | assert hasattr(config,'dnet_type')
73 |
74 | if config.dnet_type!="":
75 | d_conf = deepcopy(config)
76 | if config.dnet_type == "stack_input":
77 | d_conf.net_type = "DNet"
78 | #d_conf.nLayer = 1
79 | #d_conf.feat_extractor = "layers"
80 | else:
81 | d_conf.nLayer = 10
82 | d_conf.net_type = "WNet"
83 | _,DNet = DNet_instance(d_conf)
84 | else:
85 | DNet=None
86 | model = RGBO_CNN(config,DNet)
87 |
88 | return env_title, model
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
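DNet_instance is the factory that turns one NET_config into a (visdom environment title, model) pair. A usage sketch with illustrative hyper-parameters (NET_config signature as defined in NET_config.py; assumes the onnet package is importable):

from onnet.NET_config import NET_config
from onnet.Net_Instance import DNet_instance

config = NET_config("DNet", "mnist", (28, 28), lr_base=0.002,
                    batch_size=64, nClass=10, nLayer=5)
env_title, model = DNet_instance(config)     # a D2NNet plus a title like "DNet_mnist_..."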
/python-package/onnet/OpticalFormer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange, repeat
4 | from torch import nn
5 | from .OpticalFormer_util import *
6 | # import lite_bert
7 | MIN_NUM_PATCHES = 16
8 |
9 | class Residual(nn.Module):
10 | def __init__(self, fn):
11 | super().__init__()
12 | self.fn = fn
13 | def forward(self, x, **kwargs):
14 | return self.fn(x, **kwargs) + x
15 |
16 | class PreNorm(nn.Module):
17 | def __init__(self, dim, fn):
18 | super().__init__()
19 | self.norm = nn.LayerNorm(dim)
20 | self.fn = fn
21 | def forward(self, x, **kwargs):
22 | return self.fn(self.norm(x), **kwargs)
23 |
24 | class FeedForward(nn.Module):
25 | def __init__(self, dim, hidden_dim, dropout = 0.):
26 | super().__init__()
27 | self.net = nn.Sequential(
28 | nn.Linear(dim, hidden_dim),
29 | nn.GELU(),
30 | nn.Dropout(dropout),
31 | nn.Linear(hidden_dim, dim),
32 | nn.Dropout(dropout)
33 | )
34 | def forward(self, x):
35 | return self.net(x)
36 |
37 | class Attention(nn.Module):
38 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
39 | super().__init__()
40 | inner_dim = dim_head * heads
41 | self.heads = heads
42 | self.scale = dim ** -0.5
43 |
44 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
45 | self.to_out = nn.Sequential(
46 | nn.Linear(inner_dim, dim),
47 | nn.Dropout(dropout)
48 | )
49 |
50 | def forward(self, x, mask = None):
51 | b, n, _, h = *x.shape, self.heads
52 | qkv = self.to_qkv(x).chunk(3, dim = -1)
53 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
54 |
55 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
56 | mask_value = -torch.finfo(dots.dtype).max
57 |
58 | if mask is not None:
59 | mask = F.pad(mask.flatten(1), (1, 0), value = True)
60 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
61 | mask = mask[:, None, :] * mask[:, :, None]
62 | dots.masked_fill_(~mask, mask_value)
63 | del mask
64 |
65 | attn = dots.softmax(dim=-1)
66 |
67 | out = torch.einsum('bhij,bhjd->bhid', attn, v)
68 | out = rearrange(out, 'b h n d -> b n (h d)')
69 | out = self.to_out(out)
70 | return out
71 |
72 | def unitwise_norm(x,axis=None):
73 | """Compute norms of each output unit separately, also for linear layers."""
74 | if len(torch.squeeze(x).shape) <= 1: # Scalars and vectors
75 | axis = None
76 | keepdims = False
77 | return torch.norm(x)
78 | elif len(x.shape) in [2, 3]: # Linear layers of shape IO or multihead linear
79 | # axis = 0
80 | # axis = 1
81 | keepdims = True
82 | elif len(x.shape) == 4: # Conv kernels of shape HWIO
83 | if axis is None:
84 | axis = [0, 1, 2,]
85 | keepdims = True
86 | else:
87 | raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
88 | return torch.sum(x ** 2, dim=axis, keepdim=keepdims) ** 0.5
89 |
90 | def clip_grad_rc(grad,W,row_major=False,eps = 1.e-3,clip=0.02):
91 | # adaptive_grad_clip: cap each unit's gradient norm at clip * max(||W||, eps)
92 | if len(grad.shape)==2:
93 | nR,nC = grad.shape
94 | axis = 1 if row_major else 0
95 | g_norm = unitwise_norm(grad,axis=axis)
96 | W_norm = unitwise_norm(W,axis=axis)
97 | assert(g_norm.shape==W_norm.shape)
98 | W_norm[W_norm < eps] = eps
99 | max_norm = clip * W_norm
100 | # scale is 1 where the gradient is already inside the cap, else shrink to the cap
101 | s = torch.where(g_norm > max_norm, max_norm / g_norm.clamp_min(1.e-6), torch.ones_like(g_norm))
102 | grad = torch.einsum('rc,rc->rc', grad, s.expand_as(grad))
103 | return grad
107 |
108 | def clip_grad(model,eps = 1.e-3,clip=0.002,method="agc"):
109 | known_modules = {'Linear'}
110 | for module in model.modules():
111 | classname = module.__class__.__name__
112 | if classname not in known_modules:
113 | continue
114 | if classname == 'Conv2d':
115 | assert(False)
116 | grad = None
117 | elif classname == 'BertLayerNorm':
118 | grad = None
119 | else:
120 | grad = module.weight.grad.data
121 | W = module.weight.data
122 |
123 | # adaptive_grad_clip
124 | assert len(grad.shape)==2
125 | nR,nC = grad.shape
126 | axis = 1 if nR>nC else 0
127 | g_norm = unitwise_norm(grad,axis=axis)
128 | W_norm = unitwise_norm(W,axis=axis)
129 | W_norm[W_norm < eps] = eps
130 | max_norm = clip * W_norm
131 | s = torch.where(g_norm > max_norm, max_norm / g_norm.clamp_min(1.e-6), torch.ones_like(g_norm))
132 | grad = torch.einsum('rc,rc->rc', grad, s.expand_as(grad))
137 | module.weight.grad.data.copy_(grad)
138 |
139 | if module.bias is not None:
140 | v = module.bias.grad.data
141 | axis = 0
142 | b_grad = clip_grad_rc(v,module.bias.data,row_major=axis==1,eps = eps,clip=clip)
143 | module.bias.grad.data.copy_(b_grad)
144 |
145 |
146 | class Transformer(nn.Module):
147 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout,clip_grad=""):
148 | super().__init__()
149 | self.layers = nn.ModuleList([])
150 | self.isV0 = False
151 | for _ in range(depth):
152 | if self.isV0:
153 | self.layers.append(nn.ModuleList([
154 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
155 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
156 | ]))
157 | else:
158 | # self.layers.append(lite_bert.BTransformer(dim, heads, dim * 4, dropout))
159 | self.layers.append(BTransformer(dim, heads, dim * 4, dropout,clip_grad=clip_grad))
160 | def forward(self, x, mask = None):
161 | if self.isV0:
162 | for attn, ff in self.layers:
163 | x = attn(x, mask = mask)
164 | x = ff(x)
165 | else:
166 | for BTrans in self.layers:
167 | x = BTrans(x,mask)
168 | return x
169 |
170 | class OpticalFormer(nn.Module):
171 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ff_hidden, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.,clip_grad=""):
172 | super().__init__()
173 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
174 | num_patches = (image_size // patch_size) ** 2 #64
175 | patch_dim = channels * patch_size ** 2 #48 pixels in each patch
176 | assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
177 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
178 |
179 | self.patch_size = patch_size
180 | self.clip_grad = clip_grad
181 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches , dim))
182 | self.patch_to_embedding = nn.Linear(patch_dim, dim)
183 | # self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
184 | # self.dropout = nn.Dropout(emb_dropout)
185 |
186 | self.transformer = Transformer(dim, depth, heads, dim_head, ff_hidden, dropout,clip_grad=self.clip_grad)
187 |
188 | self.pool = pool
189 | self.to_latent = nn.Identity()
190 |
191 | self.mlp_head = nn.Sequential(
192 | nn.Identity() if self.clip_grad=="agc" else nn.LayerNorm(dim),
193 | nn.Linear(dim, num_classes)
194 | )
195 |
196 | def name_(self):
197 | return "ViT_"
198 |
199 | def forward(self, img, mask = None):
200 | p = self.patch_size
201 |
202 | x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
203 | # x = rearrange(img, 'b c (h p1) (w p2) -> b (h w c) (p1 p2)', p1 = p, p2 = p)
204 | x = self.patch_to_embedding(x)
205 | b, n, _ = x.shape
206 |
207 | # cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
208 | # x = torch.cat((cls_tokens, x), dim=1)
209 | x += self.pos_embedding[:, :(n )]
210 | # x = self.dropout(x)
211 |
212 | x = self.transformer(x, mask)
213 |
214 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
215 |
216 | x = self.to_latent(x)
217 | return self.mlp_head(x)
218 |
219 | def predict(self,output):
220 | if self.config.support == "binary":  # predict() assumes a config object has been attached to the model
221 | nGate = output.shape[1] // 2
222 | #assert nGate == self.n
223 | pred = 0
224 | for i in range(nGate):
225 | no = 2*(nGate - 1 - i)
226 | val_2 = torch.stack([output[:, no], output[:, no + 1]], 1)
227 | pred_i = val_2.max(1, keepdim=True)[1] # get the index of the max log-probability
228 | pred = pred * 2 + pred_i
229 | elif self.config.support == "logit":
230 | nGate = output.shape[1]
231 | # assert nGate == self.n
232 | pred = 0
233 | for i in range(nGate):
234 | no = nGate - 1 - i
235 | val_2 = torch.sigmoid(output[:, no])  # F.sigmoid is deprecated
236 | pred_i = (val_2+0.5).long()
237 | pred = pred * 2 + pred_i
238 | else:
239 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
240 | #pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
241 | return pred
242 |
243 |
--------------------------------------------------------------------------------
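clip_grad_rc implements adaptive gradient clipping (AGC): each row (or column) of the gradient is rescaled so its norm never exceeds clip * max(||W||, eps) for the matching slice of the weight. A quick check of that invariant as reconstructed above (assumes the onnet package is importable; sizes illustrative):

import torch
from onnet.OpticalFormer import clip_grad_rc

W = torch.randn(64, 32)
g = 10.0 * torch.randn(64, 32)                     # deliberately oversized gradient
g2 = clip_grad_rc(g.clone(), W, row_major=True, clip=0.02)
ratio = g2.norm(dim=1) / W.norm(dim=1).clamp_min(1e-3)
assert (ratio <= 0.02 + 1e-6).all()                # every row norm is capped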
/python-package/onnet/OpticalFormer_util.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | import torch.nn.functional as F
5 | # from .sparse_max import sparsemax, entmax15
6 |
7 | class LayerNorm(nn.Module):
8 | "Construct a layernorm module (See citation for details)."
9 |
10 | def __init__(self, features, eps=1e-6):
11 | super(LayerNorm, self).__init__()
12 | self.a_2 = nn.Parameter(torch.ones(features))
13 | self.b_2 = nn.Parameter(torch.zeros(features))
14 | self.eps = eps
15 |
16 | def forward(self, x):
17 | mean = x.mean(-1, keepdim=True)
18 | std = x.std(-1, keepdim=True)
19 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
20 |
21 | class GELU(nn.Module):
22 | """
23 | Paper section 3.4, last paragraph: BERT uses GELU instead of ReLU
24 | """
25 | def forward(self, x):
26 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
27 |
28 | class QK_Attention(nn.Module):
29 | def forward(self, query, key, value, mask=None, dropout=None):
30 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
31 | # Sentences in a mini-batch differ in length, so shorter ones are zero-padded to the longest; the padded positions are masked by filling a large negative value (like -1e9), so their softmax weight becomes ~0 and the padding is blocked from attention.
32 | if mask is not None:
33 | scores = scores.masked_fill(mask == 0, -1e9)
34 |
35 | p_attn = F.softmax(scores, dim=-1)
36 | # p_attn = entmax15(scores, dim=-1)
37 |
38 | if dropout is not None:
39 | p_attn = dropout(p_attn)
40 |
41 | return torch.matmul(p_attn, value), p_attn
42 |
43 | class MultiHeadedAttention(nn.Module):
44 | def __init__(self, h, d_model, dropout=0.1):
45 | super().__init__()
46 | assert d_model % h == 0
47 |
48 | # We assume d_v always equals d_k
49 | self.d_k = d_model // h
50 | self.h = h
51 |
52 | self.linear_project = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
53 | self.output_linear = nn.Linear(d_model, d_model)
54 | self.attention = QK_Attention()
55 | self.dropout = nn.Dropout(p=dropout) if dropout>0 else None
56 |
57 | def forward(self, x, mask=None):
58 | batch_size = x.size(0)
59 | if self.attention is None:
60 | x = self.dropout(x) #Very interesting, why self-attention is so useful?
61 | else:
62 | if self.h == 1:
63 | # query, key, value = [l(x) for l, x in zip(self.linear_project, (x, x, x))]
64 | query, key, value = x,x,x
65 | else:
66 | # 1) Do all the linear projections in batch from d_model => h x d_k
67 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
68 | for l, x in zip(self.linear_project, (x, x, x))]
69 | # query, key, value = (x,x,x)
70 |
71 | # 2) Apply attention on all the projected vectors in batch.
72 | x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
73 |
74 | # 3) "Concat" using a view and apply a final linear.
75 | if self.h > 1:
76 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
77 |
78 | return self.output_linear(x)
79 |
80 | class Residual(nn.Module):
81 | def __init__(self, fn):
82 | super().__init__()
83 | self.fn = fn
84 | def forward(self, x, **kwargs):
85 | return self.fn(x, **kwargs) + x
86 |
87 | #keep structure simple ,no norm,no dropout!!!
88 | class PreNorm(nn.Module):
89 | def __init__(self, dim, fn):
90 | super().__init__()
91 | self.norm = nn.LayerNorm(dim) #why this is so good
92 | # self.norm = nn.BatchNorm1d(64) #nearly same as layernorm
93 | # self.norm = nn.Identity()
94 | # self.norm = nn.BatchNorm1d(dim)
95 | self.fn = fn
96 |
97 | def forward(self, x, **kwargs):
98 | if self.fn is None:
99 | x = self.norm(x)
100 | else:
101 | x = self.fn(self.norm(x), **kwargs)
102 | return x
103 |
104 | class PositionwiseFeedForward(nn.Module):
105 | "Implements FFN equation."
106 |
107 | def __init__(self, d_model, d_ff, dropout=0.1):
108 | super(PositionwiseFeedForward, self).__init__()
109 | self.w_1 = nn.Linear(d_model, d_ff)
110 | self.w_2 = nn.Linear(d_ff, d_model)
111 | self.dropout = nn.Dropout(dropout) if dropout > 0 else None
112 | self.activation = GELU()
113 | # self.activation = nn.ReLU() # maybe use ReLU
114 |
115 | def forward(self, x):
116 | if self.dropout is None:
117 | return self.w_2(self.activation(self.w_1(x)))
118 | else:
119 | return self.w_2(self.dropout(self.activation(self.w_1(x))))
120 |
121 | class BTransformer(nn.Module):
122 | """
123 | Bidirectional Encoder = Transformer (self-attention)
124 | Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
125 | """
126 |
127 | def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout,clip_grad=""):
128 | """
129 | :param hidden: hidden size of transformer
130 | :param attn_heads: head sizes of multi-head attention
131 | :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
132 | :param dropout: dropout rate
133 | """
134 |
135 | super().__init__()
136 | print(f"attn_heads={attn_heads}")
137 | self.clip_grad = clip_grad
138 | # self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
139 | # self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
140 | # self.attn = SublayerConnection(size=hidden, dropout=dropout)
141 | # self.ff = SublayerConnection(size=hidden, dropout=dropout)
142 | if self.clip_grad == "agc":
143 | self.attn = Residual( MultiHeadedAttention(h = attn_heads, d_model=hidden, dropout=dropout) )
144 | self.ff = Residual( PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout) )
145 | else:
146 | self.attn = Residual(PreNorm(hidden, MultiHeadedAttention(h = attn_heads, d_model=hidden, dropout=dropout)))
147 | self.ff = Residual(PreNorm(hidden, PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)))
148 |
149 | self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None
150 |
151 | def forward(self, x, mask):
152 | # x = self.attn(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
153 | # x = self.ff(x, self.feed_forward)
154 | x = self.attn(x, mask=mask)
155 | x = self.ff(x)
156 | if self.dropout is not None:
157 | return self.dropout(x)
158 | else:
159 | return x
160 |
161 |
162 | class AttentionQKV(nn.Module):
163 | def __init__(self, hidden, attn_heads, dropout):
164 | super(AttentionQKV, self).__init__()
165 | self.attn = Residual(PreNorm(hidden, MultiHeadedAttention(h = attn_heads, d_model=hidden, dropout=dropout)))
166 |
167 | def forward(self, x, mask=None):
168 | shape = list(x.shape)
169 | if len(shape)==2:
170 | x = x.unsqueeze(1)
171 | x = self.attn(x, mask=mask)
172 | if len(shape)==2:
173 | x = x.squeeze(1)
174 | return x
175 |
--------------------------------------------------------------------------------
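QK_Attention above is plain scaled dot-product attention, softmax(Q K^T / sqrt(d)) V. A self-contained equivalence check of the matmul form against an einsum form (shapes illustrative):

import math
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 8, 16)                          # (batch, seq, d)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
out = torch.matmul(F.softmax(scores, dim=-1), v)

alt = torch.einsum('bij,bjd->bid',
                   F.softmax(torch.einsum('bid,bjd->bij', q, k) / math.sqrt(16), dim=-1),
                   v)
assert torch.allclose(out, alt, atol=1e-6)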
/python-package/onnet/PoolForCls.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 | from .some_utils import *
5 |
6 | class ChunkPool(torch.nn.Module):
7 | def __init__(self, nCls,config,pooling="max",chunk_dim=-1):
8 | super(ChunkPool, self).__init__()
9 | self.nClass = nCls
10 | self.pooling = pooling
11 | self.chunk_dim=chunk_dim
12 | self.config = config
13 | #self.regions = split_regions_2d(x.shape,self.nClass)
14 |
15 | def __repr__(self):
16 | main_str = super(ChunkPool, self).__repr__()
17 | main_str += f"_cls[{self.nClass}]_pool[{self.pooling}]"
18 | return main_str
19 |
20 | def forward(self, x):
21 | nSamp = x.shape[0]
22 | if False:
23 | x1 = torch.zeros((nSamp, self.nClass)).double()
24 | step = self.M // self.nClass
25 | for samp in range(nSamp):
26 | for i in range(self.nClass):
27 | x1[samp,i] = torch.max(x[samp,:,:,i*step:(i+1)*step])
28 | x_np = x1.detach().cpu().numpy()
29 | x = x1.cuda()
30 | else:
31 | x_max=[]
32 | if self.config.output_chunk=="1D":
33 | sections=split__sections(x.shape[self.chunk_dim],self.nClass)
34 | for xx in x.split(sections, self.chunk_dim):
35 | x2 = xx.contiguous().view(nSamp, -1)
36 | if self.pooling == "max":
37 | x3 = torch.max(x2, 1)
38 | x_max.append(x3.values)
39 | else:
40 | x3 = torch.mean(x2, 1)
41 | x_max.append(x3)
42 | else: #2D
43 | regions = split_regions_2d(x.shape,self.nClass)
44 | for box in regions:
45 | x2 = x[...,box[0]:box[1],box[2]:box[3]]
46 | x2 = x2.contiguous().view(nSamp, -1)
47 | if self.pooling == "max":
48 | x3 = torch.max(x2, 1)
49 | x_max.append(x3.values)
50 | else:
51 | x3 = torch.mean(x2, 1)
52 | x_max.append(x3)
53 | assert len(x_max)==self.nClass
54 | x = torch.stack(x_max,1)
55 | #x_np = x.detach().cpu().numpy()
56 | #print(x_np)
57 | return x
58 |
59 | class BinaryChunk(torch.nn.Module):
60 | def __init__(self, nCls,isLogit=False,pooling="max",chunk_dim=-1):
61 | super(BinaryChunk, self).__init__()
62 | self.nClass = nCls
63 | self.nChunk = (int)(math.ceil(math.log2(self.nClass)))
64 | self.pooling = pooling
65 | self.isLogit = isLogit
66 |
67 | def __repr__(self):
68 | main_str = super(BinaryChunk, self).__repr__()
69 | if self.isLogit:
70 | main_str += "_logit"
71 | main_str += f"_nChunk{self.nChunk}_cls[{self.nClass}]_pool[{self.pooling}]"
72 | return main_str
73 |
74 | def chunk_pool(self,ck,nSamp):
75 | x2 = ck.contiguous().view(nSamp, -1)
76 | if self.pooling == "max":
77 | x3 = torch.max(x2, 1)
78 | return x3.values
79 | else:
80 | x3 = torch.mean(x2, 1)
81 | return x3
82 |
83 | def forward(self, x):
84 | nSamp = x.shape[0]
85 | x_max=[]
86 | for ck in x.chunk(self.nChunk, -1):
87 | if self.isLogit:
88 | x_max.append(self.chunk_pool(ck,nSamp))
89 | else:
90 | for xx in ck.chunk(2, -2):
91 | x2 = xx.contiguous().view(nSamp, -1)
92 | if self.pooling == "max":
93 | x3 = torch.max(x2, 1)
94 | x_max.append(x3.values)
95 | else:
96 | x3 = torch.mean(x2, 1)
97 | x_max.append(x3)
98 | x = torch.stack(x_max,1)
99 |
100 | return x
--------------------------------------------------------------------------------
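ChunkPool's 1-D mode slices the output plane into nClass strips along the chunk dimension and pools each strip down to one logit, so an optical intensity map becomes a class vector with no trainable weights. A toy run of that idea (shapes illustrative):

import torch

x = torch.randn(4, 1, 28, 30)                 # (batch, channel, H, W); W divisible by nClass
nClass, nSamp = 10, x.shape[0]
logits = torch.stack(
    [strip.contiguous().view(nSamp, -1).max(1).values
     for strip in x.split(x.shape[-1] // nClass, dim=-1)],
    dim=1)
assert logits.shape == (4, 10)                # one max-pooled logit per class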
/python-package/onnet/RGBO_CNN.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch.nn as nn
3 | import os
4 | #from torchvision import models
5 | sys.path.append("../..")
6 | from cnn_models import *
7 | from torchvision import transforms
8 | from torchvision.transforms.functional import to_grayscale
9 | from torch.autograd import Variable
10 | # from resnet import resnet50
11 | from copy import deepcopy
12 | import numpy as np
13 | import pickle
14 | from .NET_config import *
15 | from .D2NNet import *
16 |
17 | class RGBO_CNN_config(NET_config):
18 | def __init__(self, net_type, data_set, IMG_size, lr_base, batch_size, nClass, nLayer):
19 | super(RGBO_CNN_config, self).__init__(net_type, data_set, IMG_size, lr_base, batch_size,nClass,nLayer)
20 | #self.dnet_type = ""
21 | self.dnet_type = "stack_input"
22 | self.dnet_type = "stack_feature"
23 |
24 | def image_transformer():
25 | """
26 | :return: A transformer to convert a PIL image to a tensor image
27 | ready to feed into a neural network
28 | """
29 | return {
30 | 'train': transforms.Compose([
31 | transforms.RandomCrop(224),
32 | transforms.RandomHorizontalFlip(),
33 | transforms.ToTensor(),
34 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
35 | ]),
36 | 'val': transforms.Compose([
37 | transforms.Resize(256),
38 | transforms.CenterCrop(224),
39 | transforms.ToTensor(),
40 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
41 | ]),
42 | }
43 |
44 | '''
45 | 1. See cifar_rgbF.jpg: a plain Fourier channel brings little benefit
46 | '''
47 | class D_input(nn.Module):
48 | def __init__(self, config, DNet):
49 | super(D_input, self).__init__()
50 | self.config = config
51 | self.DNet = DNet
52 | self.inplanes = 64
53 | self.nLayD = DNet.nDifrac#self.DNet.config.nLayer
54 | #self.nLayD = 1
55 | #self.c_input =nn.Conv2d(3+self.nLayD, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
56 | self.c_input = nn.Conv2d(3+self.nLayD, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
57 |
58 | def forward(self, x):
59 | nChan = x.shape[1]
60 | assert nChan==3 or nChan==1
61 | if nChan==3:
62 | gray = x[:, 0:1]*0.3 + 0.59 * x[:, 1:2] + 0.11 * x[:, 2:3] # to_grayscale(x)
63 | else:
64 | gray = x
65 | return self.DNet.forward(gray)  # early return: the channel-stacking path below is currently disabled
66 | listT = []
67 | for i in range(nChan):
68 | listT.append(x[:, i:i+1])
69 | if self.nLayD>=1:
70 | self.DNet.forward(gray)
71 | assert len(self.DNet.feat_extractor) == self.nLayD
72 | for opti, w in self.DNet.feat_extractor:
73 | listT.append(opti) #*w
74 | elif self.nLayD==0:#
75 | pass
76 | else:
77 | listT.append(gray)
78 |
79 | x = torch.stack(listT,dim=1).squeeze()
80 | if hasattr(self, 'visual'): self.visual.onX(x, f"D_input")
81 | x = self.c_input(x)
82 | return x
83 |
84 | def forward_000(self, x):
85 | if False:
86 | gray = x[:, 0:1] # to_grayscale(x)
87 | self.DNet.forward(gray)
88 | # in_opti = self.DNet.concat_layer_modulus() # self.get_resnet_convs_out(x)
89 | for opti, w in self.DNet.feat_extractor:
90 | opti = torch.stack([opti, opti, opti], 1).squeeze() # opti.repeat(3, 1)
91 | out_opti = self.resNet.forward(opti)
92 | out_sum = out_sum + out_opti * w
93 | pass
94 |
95 | class RGBO_CNN(torch.nn.Module):
96 | '''
97 | resnet https://missinglink.ai/guides/pytorch/pytorch-resnet-building-training-scaling-residual-networks-pytorch/
98 | '''
99 | def pick_models(self):
100 | if False: #from torch vision or cadene models
101 | model_names = sorted(name for name in cnn_models.__dict__
102 | if name.islower() and not name.startswith("__")
103 | and callable(models.__dict__[name]))
104 | print(model_names)
105 | # pretrainedmodels https://data.lip6.fr/cadene/pretrainedmodels/
106 | model_names = ['alexnet', 'bninception', 'cafferesnet101', 'densenet121', 'densenet161', 'densenet169',
107 | 'densenet201',
108 | 'dpn107', 'dpn131', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'fbresnet152',
109 | 'inceptionresnetv2', 'inceptionv3', 'inceptionv4', 'nasnetalarge', 'nasnetamobile',
110 | 'pnasnet5large',
111 | 'polynet',
112 | 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x4d',
113 | 'resnext101_64x4d',
114 | 'se_resnet101', 'se_resnet152', 'se_resnet50', 'se_resnext101_32x4d', 'se_resnext50_32x4d',
115 | 'senet154', 'squeezenet1_0', 'squeezenet1_1',
116 | 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'xception']
117 |
118 | # model_name='cafferesnet101'
119 | # model_name='resnet101'
120 | # model_name='se_resnet50'
121 | # model_name='vgg16_bn'
122 | # model_name='vgg11_bn'
123 | # model_name='dpn68' # learning rate=0.0001 works well
124 | self.back_bone = 'resnet18_x'
125 | # model_name='dpn92'
126 | # model_name='senet154'
127 | # model_name='densenet121'
128 | # model_name='alexnet'
129 | # model_name='senet154'
130 | cnn_model = ResNet34()  # models.resnet18(pretrained=True)
131 | return cnn_model
132 |
133 | def __init__(self, config,DNet):
134 | super(RGBO_CNN, self).__init__()
135 | seed_everything(42)
136 | self.config = config
137 | backbone = self.pick_models()
138 | if self.config.dnet_type == "stack_feature":
139 | self.DInput = D_input(config,DNet)
140 | elif self.config.dnet_type == "stack_input": #False and hasattr(self,'DInput'):
141 | self.CNet = nn.Sequential(*list(backbone.children())[1:])
142 | else:
143 | self.CNet = nn.Sequential(*list(backbone.children()))
144 |
145 | #print(f"=> creating model CNet='{self.CNet}'\nDNet={self.DNet}")
146 | if False: # device placement handled by the caller
147 | if config.gpu_device is not None:
148 | self.cuda(config.gpu_device)
149 | print(next(self.parameters()).device)
150 | self.thickness_criterion = self.thickness_criterion.cuda()
151 | self.metal_criterion = self.metal_criterion.cuda()
152 | elif config.distributed:
153 | self.cuda()
154 | self = torch.nn.parallel.DistributedDataParallel(self)
155 | else:
156 | self = torch.nn.DataParallel(self).cuda()
157 |
158 | def save_acti(self,x,name):
159 | acti = x.cpu().data.numpy()
160 | self.activations.append({'name':name,'shape':acti.shape,'activation':acti})
161 |
162 | #https://forums.fast.ai/t/pytorch-best-way-to-get-at-intermediate-layers-in-vgg-and-resnet/5707/6
163 |
164 |
165 | def forward_0(self, x):
166 | if hasattr(self, 'DInput'):
167 | x = self.DInput(x)
168 | for no,lay in enumerate(self.CNet):
169 | if isinstance(lay,nn.Linear): #x = self.avgpool(x), x = x.reshape(x.size(0), -1)
170 | x = F.avg_pool2d(x, 4)
171 | x = x.reshape(x.size(0), -1)
172 | x = lay(x)
173 | #print(f"{no}:\t{lay}\nx={x}")
174 | if isinstance(lay,nn.AdaptiveAvgPool2d): #x = self.avgpool(x), x = x.reshape(x.size(0), -1)
175 | x = x.reshape(x.size(0), -1)
176 | out_sum = x
177 | return out_sum
178 |
179 | def forward(self, x):
180 | out_sum = 0
181 | if self.config.dnet_type == "stack_feature":
182 | out_sum= self.DInput(x)
183 | for no,lay in enumerate(self.CNet):
184 | if isinstance(lay,nn.Linear): #x = self.avgpool(x), x = x.reshape(x.size(0), -1)
185 | x = F.avg_pool2d(x, 4)
186 | x = x.reshape(x.size(0), -1)
187 | x = lay(x)
188 | #print(f"{no}:\t{lay}\nx={x}")
189 | if isinstance(lay,nn.AdaptiveAvgPool2d): #x = self.avgpool(x), x = x.reshape(x.size(0), -1)
190 | x = x.reshape(x.size(0), -1)
191 | out_sum += x
192 | return out_sum
193 |
194 | if __name__ == "__main__":
195 | config = DNET_config(None)
196 | a = RGBO_CNN(config, DNet=None)  # RGBO_CNN takes (config, DNet)
197 | print(f"RGBO_CNN={a}")
198 | pass
199 |
200 |
--------------------------------------------------------------------------------
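D_input.forward reduces RGB to a single optical channel with approximate Rec. 601 luma weights (0.3, 0.59, 0.11) before handing it to the diffractive network. That conversion in isolation (shapes illustrative):

import torch

x = torch.rand(2, 3, 224, 224)                                  # RGB batch
gray = 0.3 * x[:, 0:1] + 0.59 * x[:, 1:2] + 0.11 * x[:, 2:3]    # (2, 1, 224, 224)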
/python-package/onnet/SparseSupport.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .D2NNet import *
3 | from .some_utils import *
4 | import numpy as np
5 | import random
6 | import torch.nn as nn
7 | from enum import Enum
8 |
9 | #https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-custom-nn-modules
10 | class SuppLayer(torch.nn.Module):
11 | class SUPP(Enum):
12 | exp,sparse,expW,diff = 'exp','sparse','expW','differentia'
13 |
14 | def __init__(self,config,nClass, nSupp=10):
15 | super(SuppLayer, self).__init__()
16 | self.nClass = nClass
17 | self.nSupp = nSupp
18 | self.nChunk = self.nClass*2
19 | self.config = config
20 | self.w_11=False
21 | if self.config.support==self.SUPP.sparse: #"supp_sparse":
22 | if self.w_11:
23 | tSupp = torch.ones(self.nClass, self.nSupp)
24 | else:
25 | tSupp = torch.Tensor(self.nClass, self.nSupp).uniform_(-1,1)
26 | self.wSupp = torch.nn.Parameter(tSupp)
27 | self.nChunk = self.nSupp*self.nSupp
28 | self.chunk_map = np.random.randint(self.nChunk, size=(self.nClass, self.nSupp))
29 | elif self.config.support==self.SUPP.expW: #"supp_expW":
30 | self.w2 = torch.nn.Parameter(torch.ones(2))  # weight pair read by the expW branch of forward
32 |
33 | def __repr__(self):
34 | w_init="1" if self.w_11 else "random"
35 | main_str = f"SupportLayer supp=({self.nSupp},{w_init}) type=\"{self.config.support}\" nChunk={self.nChunk}"
36 | return main_str
37 |
38 | def sparse_support(self,x):
39 | feats=[]
40 | for i in range(self.nClass):
41 | feat = 0
42 | for j in range(self.nSupp):
43 | col = (int)(self.chunk_map[i,j])
44 | feat += x[:, col]*self.wSupp[i,j]
45 | feats.append(torch.exp(feat)) #why exp is useful???
46 | #feats.append(feat)
47 | output = torch.stack(feats,1)
48 | return output
49 |
50 | def forward(self, x):
51 | if self.config.support == self.SUPP.sparse: # "supp_sparse":
52 | output = self.sparse_support(x)
53 | return output
54 |
55 | assert x.shape[1] == self.nClass * 2
56 | if self.config.support==self.SUPP.diff: #"supp_differentia":
57 | for i in range(self.nClass):
58 | x[:,i] = (x[:,2*i]-x[:,2*i+1])/(x[:,2*i]+x[:,2*i+1])
59 | output=x[...,0:self.nClass]
60 | elif self.config.support==self.SUPP.exp: #"supp_exp":
61 | for i in range(self.nClass):
62 | x[:, i] = torch.exp(x[:, 2 * i] - x[:, 2 * i + 1])
63 | output = x[..., 0:self.nClass]
64 | elif self.config.support==self.SUPP.expW: #"supp_expW":
65 | output = torch.zeros_like(x)
66 | for i in range(self.nClass):
67 | output[:, i] = torch.exp(x[:, 2 * i]*self.w2[0] - x[:, 2 * i + 1]*self.w2[1])
68 | output = output[..., 0:self.nClass]
69 |
70 | return output
71 |
72 |
--------------------------------------------------------------------------------
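The SUPP.diff branch maps each pair of detector readouts (a, b) to the normalized contrast (a - b) / (a + b), which lies in [-1, 1] for non-negative intensities, one value per class. The same mapping vectorized (shapes illustrative):

import torch

pairs = torch.rand(4, 20)                     # two readouts per class, 10 classes
a, b = pairs[:, 0::2], pairs[:, 1::2]
contrast = (a - b) / (a + b)                  # shape (4, 10), in [-1, 1]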
/python-package/onnet/ToExcel.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Yingshi Chen
3 |
4 | @Date: 2020-01-14 15:36:32
5 | @
6 | # Description:
7 | '''
8 | import numpy as np
9 | import pandas as pd
10 | import json
11 | import glob
12 | import argparse
13 | from scipy.signal import savgol_filter
14 |
15 | def OnVisdom_json(param,title,smooth=False):
16 | search_str = f"{param['data_root']}{param['select']}"
17 | files = glob.glob(search_str)
18 | datas = []
19 | cols = []
20 | for i, file in enumerate(files):
21 | with open(file, 'r') as f:
22 | meta = json.load(f)
23 | curve = meta['jsons']['loss']['content']['data'][0]
24 | legend = meta['jsons']['loss']['legend']
25 | cols.append(legend[0])
26 | item = curve['y']
27 | datas.append(item)
28 | if smooth:
29 | win = max(9,len(item)//10)
30 | cols.append(f"{legend[0]}_smooth")
31 | item_s = savgol_filter(item, win, 3)
32 | datas.append(item_s)
33 | pass
34 |
35 | df = pd.DataFrame(datas)
36 | df = df.transpose()
37 | for i,col in enumerate(cols):
38 | df = df.rename(columns={i: col})
39 |
40 | path = f"{param['data_root']}{title}_please_rename.xlsx"
41 | df.to_excel(path )
42 |
43 | print(df.head())
44 |
45 |
46 | if __name__ == '__main__':
47 | parser = argparse.ArgumentParser(description='Load json of visdom curves. Save to EXCEL!')
48 | parser.add_argument("keyword", type=str, help="keyword")
49 | parser.add_argument("root", type=str, help="root")
50 |
51 | args = parser.parse_args()
52 |
53 | if hasattr(args,'keyword') and hasattr(args,'root'):
54 | keyword = args.keyword # "WNet_mnist"
55 | data_root = args.root #"F:/arXiv/Diffractive Wavenet - an novel low parameter optical neural network/"
56 | param = {"data_root":data_root,
57 | "select":f"{keyword}*.json"}
58 | OnVisdom_json(param,keyword,smooth=True)
59 | else:
60 | param = {"data_root": r"E:\Guided Inverse design of SPP structures\images",
61 | "select": "3_4*.json"}
62 | OnVisdom_json(param, "3_4")  # 'keyword' is undefined on this branch; use the selection prefix as title
--------------------------------------------------------------------------------
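Run from the command line, the script globs {root}{keyword}*.json visdom exports and writes the (optionally Savitzky-Golay smoothed) loss curves into one Excel sheet, e.g. (paths illustrative):

python ToExcel.py WNet_mnist ./dump/visdom_json/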
/python-package/onnet/Visualizing.py:
--------------------------------------------------------------------------------
1 | '''
2 | python -m visdom.server
3 | http://localhost:8097
4 | .json file present in your ~/.visdom directory.
5 |
6 | tensorboard --logdir=runs
7 | http://localhost:6006/ very strange errors here
8 |
9 | ONNX export failed on ATen operator ifft because torch.onnx.symbolic.ifft does not exist
10 | '''
11 | import seaborn as sns; sns.set()
12 | from PIL import Image
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from torch.utils.data import DataLoader
17 | from torch.utils.tensorboard import SummaryWriter  # used by UpdateLoss and PROJECTOR_test below
18 | import visdom
19 | import matplotlib.pyplot as plt
20 | import numpy as np
21 | import torchvision
22 | import cv2
23 | from torchvision import datasets, transforms
24 | from .Z_utils import COMPLEX_utils as Z
25 |
26 | def matplotlib_imshow(img, one_channel=False):
27 | if one_channel:
28 | img = img.mean(dim=0)
29 | img = img / 2 + 0.5 # unnormalize
30 | npimg = img.numpy()
31 | if one_channel:
32 | plt.imshow(npimg, cmap="Greys")
33 | else:
34 | plt.imshow(np.transpose(npimg, (1, 2, 0)))
35 | plt.show()
36 |
37 |
38 | class Visualize:
39 | def __init__(self,env_title="onnet",plots=[], **kwargs):
40 | self.log_dir = f'runs/{env_title}'
41 | self.plots = plots
42 | self.loss_step = 0
43 | self.writer = None #SummaryWriter(self.log_dir)
44 | self.img_dir="./dump/images/"
45 | self.dpi = 100
46 |
47 | #https://stackoverflow.com/questions/9662995/matplotlib-change-title-and-colorbar-text-and-tick-colors
48 | def MatPlot(self,arr, title=""):
49 | fig, ax = plt.subplots()
50 | #plt.axis('off')
51 | plt.grid(b=None)
52 | im = ax.imshow(arr, interpolation='nearest', cmap='coolwarm')
53 | fig.colorbar(im, orientation='horizontal')
54 | plt.savefig(f'{self.img_dir}{title}.jpg')
55 | #plt.show()
56 | plt.close()
57 |
58 | def fig2data(self,fig):
59 | fig.canvas.draw()
60 | if True: # https://stackoverflow.com/questions/42603161/convert-an-image-shown-in-python-into-an-opencv-image
61 | img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # fromstring is deprecated
62 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
63 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
64 | return img
65 | else:
66 | w, h = fig.canvas.get_width_height()
67 | buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
68 | buf.shape = (w, h, 4)
69 | # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
70 | buf = np.roll(buf, 3, axis=2)
71 | return buf
72 |
73 | '''
74 | sns.heatmap is awkward and needs customization, see https://stackoverflow.com/questions/53248186/custom-ticks-for-seaborn-heatmap
75 | '''
76 | def HeatMap(self, data, file_name, params={},noAxis=True, cbar=True):
77 | title,isSave = file_name,True
78 | if 'save' in params:
79 | isSave = params['save']
80 | if 'title' in params:
81 | title = params['title']
82 | path = '{}{}_.jpg'.format(self.img_dir, file_name)
83 | sns.set(font_scale=3)
84 | s = max(data.shape[1] / self.dpi, data.shape[0] / self.dpi)
85 | # fig.set_size_inches(18.5, 10.5)
86 | cmap = 'coolwarm' # "plasma" #https://matplotlib.org/examples/color/colormaps_reference.html
87 | # cmap = sns.cubehelix_palette(start=1, rot=3, gamma=0.8, as_cmap=True)
88 | if noAxis: # tight samples for training(No text!!!)
89 | figsize = (s, s)
90 | fig, ax = plt.subplots(figsize=figsize, dpi=self.dpi)
91 | ax = sns.heatmap(data, ax=ax, cmap=cmap, cbar=False, xticklabels=False, yticklabels=False)
92 | fig.savefig(path, bbox_inches='tight', pad_inches=0,figsize=(20,10))
93 | if False:
94 | image = cv2.imread(path)
95 | # image = fig2data(ax.get_figure()) # enlarges the image, hard to understand
96 | if (len(title) > 0):
97 | assert (image.shape == self.args.spp_image_shape) # the size must be fixed
98 | cv2.imshow("",image); cv2.waitKey(0)
99 | plt.close("all")
100 | return path
101 | else: # for paper
102 | ticks = np.linspace(0, 1, 10)
103 | xlabels = [int(i) for i in np.linspace(0, 56, 10)]
104 | ylabels = xlabels
105 | figsize = (s * 10, s * 10)
106 | #fig, ax = plt.subplots(figsize=figsize, dpi=self.dpi) # more concise than plt.figure:
107 | fig, ax = plt.subplots(dpi=self.dpi)
108 | ax.set_title(title)
109 | # cbar_kws={'label': 'Reflex', 'orientation': 'horizontal'}
110 | # sns.set(font_scale=0.2)
111 | # cbar_kws={'label': 'Reflex', 'orientation': 'horizontal'} , center=0.6
112 | # ax = sns.heatmap(data, ax=ax, cmap=cmap,yticklabels=ylabels[::-1],xticklabels=xlabels)
113 | # cbar_kws = dict(ticks=np.linspace(0, 1, 10))
114 | ax = sns.heatmap(data, ax=ax, cmap=cmap,vmin=-1.1, vmax=1.1, cbar=cbar) #
115 | #plt.ylabel('Incident Angle'); plt.xlabel('Wavelength(nm)')
116 | if False:
117 | ax.set_xticklabels(xlabels); ax.set_yticklabels(ylabels[::-1])
118 | y_limit = ax.get_ylim();
119 | x_limit = ax.get_xlim()
120 | ax.set_yticks(ticks * y_limit[0])
121 | ax.set_xticks(ticks * x_limit[1])
122 | else:
123 | plt.axis('off')
124 | if False:
125 | plt.show(block=True)
126 |
127 | image = self.fig2data(ax.get_figure())
128 | plt.close("all")
129 | #image_all = np.concatenate((img_0, img_1, img_diff), axis=1)
130 | #cv2.imshow("", image); cv2.waitKey(0)
131 | if isSave:
132 | cv2.imwrite(path, image)
133 | return path
134 | else:
135 | return image
136 |
137 | plt.close("all")
138 |
139 | def ShowModel(self,model,data_loader):
140 | '''
141 | TensorBoard rendering here is rather poor
142 | '''
143 | dataiter = iter(data_loader)
144 | images, labels = next(dataiter)  # DataLoader iterators no longer expose .next()
145 | if images.shape[0]>32:
146 | images=images[0:32,...]
147 | if True:
148 | img_grid = torchvision.utils.make_grid(images)
149 | matplotlib_imshow(img_grid, one_channel=True)
150 | self.writer.add_image('one_batch', img_grid)
151 | self.writer.close()
152 | image_1 = images[0:1,:,:,:]
153 | if False:
154 | images = images.cuda()
155 | self.writer.add_graph(model,images )
156 | self.writer.close()
157 |
158 | def onX(self,X,title,nMostPic=64):
159 | shape = X.shape
160 | if Z.isComplex(X):
161 | #X = torch.cat([X[..., 0],X[..., 1]],0)
162 | X = Z.modulus(X)
163 | X = X.cpu()
164 | if shape[1]!=1:
165 | X = X.contiguous().view(shape[0]*shape[1],1,shape[-2],shape[-1]).cpu()
166 | if X.shape[0]>nMostPic:
167 | X=X[:nMostPic,...]
168 | img_grid = torchvision.utils.make_grid(X).detach().numpy()
169 | plt.axis('off');
170 | plt.grid(b=None)
171 | image_np = np.transpose(img_grid, (1, 2, 0))
172 | min_val, max_val = np.min(image_np), np.max(image_np)  # normalize to [0, 1]
173 | image_np = (image_np - min_val) / (max_val - min_val)
174 | if title is None:
175 | plt.imshow(image_np)
176 | plt.show()
177 | else:
178 | path = '{}{}_.jpg'.format(self.img_dir, title)
179 | plt.imsave(path, image_np)
180 |
181 |
182 | def image(self, file_name, img_, params={}):
183 | #np.random.rand(3, 512, 256),
184 | #self.MatPlot(img_.cpu().numpy(),title=name)
185 |
186 | result = self.HeatMap(img_.cpu().numpy(),file_name,params,noAxis=False)
187 | return result
188 |
189 | def UpdateLoss(self,title,legend,loss,yLabel='LOSS',global_step=None):
190 | tag = legend
191 | step = self.loss_step if global_step==None else global_step
192 | with SummaryWriter(log_dir=self.log_dir) as writer:
193 | writer.add_scalar(tag, loss, global_step=step)
194 | #self.writer.close() # close() flushes immediately; otherwise auto-flush happens every 120 s
195 | self.loss_step = self.loss_step+1
196 |
197 | class Visdom_Visualizer(Visualize):
198 | '''
199 | Wraps basic visdom operations
200 | '''
201 |
202 | def __init__(self,env_title,plots=[], **kwargs):
203 | super(Visdom_Visualizer, self).__init__(env_title,plots)
204 | try:
205 | self.viz = visdom.Visdom(env=env_title, **kwargs)
206 | assert self.viz.check_connection()
207 | except:
208 | self.viz = None
209 |
210 | def UpdateLoss(self, title,legend, loss, yLabel='LOSS',global_step=None):
211 | self.vis_plot( self.loss_step, loss, title,legend,yLabel)
212 | self.loss_step = self.loss_step + 1
213 |
214 | def vis_plot(self,epoch, loss_, title,legend,yLabel):
215 | if self.viz is None:
216 | return
217 | self.viz.line(X=torch.FloatTensor([epoch]), Y=torch.FloatTensor([loss_]), win='loss',
218 | opts=dict(
219 | legend=[legend], # [config_.use_bn],
220 | fillarea=False,
221 | showlegend=True,
222 | width=1600,
223 | height=800,
224 | xlabel='Epoch',
225 | ylabel=yLabel,
226 | # ytype='log',
227 | title=title,
228 | # marginleft=30,
229 | # marginright=30,
230 | # marginbottom=80,
231 | # margintop=30,
232 | ),
233 | update='append' if epoch > 0 else None)
234 |
235 | def reinit(self, env='default', **kwargs):
236 | self.vis = visdom.Visdom(env=env, **kwargs)
237 | return self
238 |
239 | def plot_many(self, d):
240 | '''
241 | plot several curves at once
242 | @params d: dict (name,value) i.e. ('loss',0.11)
243 | '''
244 | for k, v in d.items():  # iteritems() is Python 2 only
245 | self.plot(k, v)
246 |
247 | def img_many(self, d):
248 | for k, v in d.items():
249 | self.img(k, v)
250 |
251 | def plot(self, name, y, **kwargs):
252 | '''
253 | self.plot('loss',1.00)
254 | '''
255 | x = self.index.get(name, 0)
256 | self.vis.line(Y=np.array([y]), X=np.array([x]),
257 | win=name,
258 | opts=dict(title=name),
259 | update=None if x == 0 else 'append',
260 | **kwargs
261 | )
262 | self.index[name] = x + 1
263 |
264 | ''' very strange errors here
265 | def image(self, name, img_, **kwargs):
266 |
267 | assert self.viz.check_connection()
268 | self.vis.image(
269 | np.random.rand(3, 512, 256),
270 | opts=dict(title='Random image as jpg!', caption='How random as jpg.', jpgquality=50),
271 | )
272 | self.vis.image (img_.cpu().numpy(),
273 | #win=(name),
274 | opts=dict(title=name),
275 | **kwargs
276 | )
277 | '''
278 |
279 | def log(self, info, win='log_text'):
280 | '''
281 | self.log({'loss':1,'lr':0.0001})
282 | '''
283 |
284 | import time  # 'time' is not imported at module level
285 | self.log_text = getattr(self, 'log_text', '') + ('[{time}] {info} <br>'.format(
286 | time=time.strftime('%m%d_%H%M%S'),
287 | info=info))
287 | self.vis.text(self.log_text, win)
288 | print(self.log_text)
289 |
290 | def __getattr__(self, name):
291 | return getattr(self.viz, name)  # delegate to the visdom client created in __init__
292 |
293 | def PROJECTOR_test():
294 | """ ================== Visualizing high-dimensional vectors with PROJECTOR ====================
295 | https://blog.csdn.net/wsp_1138886114/article/details/87602112
296 | PROJECTOR projects high-dimensional vectors into 3-D coordinates via PCA, T-SNE and other dimension-reduction methods.
297 | The Embedding Projector reads data from checkpoint files saved while the model runs;
298 | by default it projects the data into 3-D space with PCA, and T-SNE can be selected instead.
299 | Here is a simple demonstration.
300 | """
301 | log_dirs = "../../runs/projector/"
302 | BATCH_SIZE = 256
303 | EPOCHS = 2
304 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
305 |
306 | train_loader = DataLoader(datasets.MNIST('../../data', train=True, download=False,
307 | transform=transforms.Compose([
308 | transforms.ToTensor(),
309 | transforms.Normalize((0.1307,), (0.3081,))
310 | ])),
311 | batch_size=BATCH_SIZE, shuffle=True)
312 |
313 | test_loader = torch.utils.data.DataLoader(
314 | datasets.MNIST('../../data', train=False, transform=transforms.Compose([
315 | transforms.ToTensor(),
316 | transforms.Normalize((0.1307,), (0.3081,))
317 | ])),
318 | batch_size=BATCH_SIZE, shuffle=True)
319 |
320 | class ConvNet(nn.Module):
321 | def __init__(self):
322 | super().__init__()
323 | # 1,28x28
324 | self.conv1 = nn.Conv2d(1, 10, 5) # 10, 24x24
325 | self.conv2 = nn.Conv2d(10, 20, 3) # 128, 10x10
326 | self.fc1 = nn.Linear(20 * 10 * 10, 500)
327 | self.fc2 = nn.Linear(500, 10)
328 |
329 | def forward(self, x):
330 | in_size = x.size(0)
331 | out = self.conv1(x) # 24
332 | out = F.relu(out)
333 | out = F.max_pool2d(out, 2, 2) # 12
334 | out = self.conv2(out) # 10
335 | out = F.relu(out)
336 | out = out.view(in_size, -1)
337 | out = self.fc1(out)
338 | out = F.relu(out)
339 | out = self.fc2(out)
340 | out = F.log_softmax(out, dim=1)
341 | return out
342 |
343 | model = ConvNet().to(DEVICE)
344 | optimizer = torch.optim.Adam(model.parameters())
345 |
346 | def train(model, DEVICE, train_loader, optimizer, epoch):
347 | n_iter = 0
348 | model.train()
349 | for batch_idx, (data, target) in enumerate(train_loader):
350 | data, target = data.to(DEVICE), target.to(DEVICE)
351 | optimizer.zero_grad()
352 | output = model(data)
353 | loss = F.nll_loss(output, target)
354 | loss.backward()
355 | optimizer.step()
356 | if (batch_idx + 1) % 30 == 0:
357 | n_iter = n_iter + 1
358 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
359 | epoch, batch_idx * len(data), len(train_loader.dataset),
360 | 100. * batch_idx / len(train_loader), loss.item()))
361 |
362 | # the main addition is below
363 | out = torch.cat((output.data.cpu(), torch.ones(len(output), 1)), 1) # projecting into 3-D space, so only 3 dimensions are needed
364 | with SummaryWriter(log_dir=log_dirs, comment='mnist') as writer:
365 | # visualize with the add_embedding method
366 | writer.add_embedding(
367 | out,
368 | metadata=target.data,
369 | label_img=data.data,
370 | global_step=n_iter)
371 |
372 | def test(model, device, test_loader):
373 | model.eval()
374 | test_loss = 0
375 | correct = 0
376 | with torch.no_grad():
377 | for data, target in test_loader:
378 | data, target = data.to(device), target.to(device)
379 | output = model(data)
380 | test_loss += F.nll_loss(output, target, reduction='sum').item() # accumulate the loss
381 | pred = output.max(1, keepdim=True)[1] # index of the largest probability
382 | correct += pred.eq(target.view_as(pred)).sum().item()
383 |
384 | test_loss /= len(test_loader.dataset)
385 | print('\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
386 | .format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
387 |
388 | for epoch in range(1, EPOCHS + 1):
389 | train(model, DEVICE, train_loader, optimizer, epoch)
390 | test(model, DEVICE, test_loader)
391 |
392 | # save the model
393 | torch.save(model.state_dict(), './pytorch_tensorboardX_03.pth')
394 |
395 | if __name__ == '__main__':
396 | PROJECTOR_test()
--------------------------------------------------------------------------------
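Visdom_Visualizer degrades gracefully: if no visdom server answers, self.viz stays None and plotting calls become no-ops. A usage sketch (start `python -m visdom.server` first to actually see the curve; names illustrative):

from onnet.Visualizing import Visdom_Visualizer

viz = Visdom_Visualizer(env_title="onnet_demo")
for loss in [0.9, 0.5, 0.3]:
    viz.UpdateLoss(title="demo run", legend="train", loss=loss)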
/python-package/onnet/Z_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | 1 Ugh, PyTorch has no native complex-vector support https://github.com/pytorch/pytorch/issues/755
3 |
4 | '''
5 |
6 | import torch
7 | from torch.nn import ReflectionPad2d
8 | from torch.nn.functional import relu, max_pool2d, dropout, dropout2d
9 | import numpy as np
10 |
11 | class COMPLEX_utils(object):
12 | @staticmethod
13 | def isComplex(input):
14 | return input.size(-1) == 2
15 |
16 | @staticmethod
17 | def isReal(input):
18 | return input.size(-1) == 1
19 |
20 | @staticmethod
21 | def ToZ(u0):
22 | if COMPLEX_utils.isComplex(u0):
23 | return u0
24 | else:
25 | z0 = u0.new_zeros(u0.shape + (2,))
26 | z0[..., 0] = u0
27 | assert(COMPLEX_utils.isComplex(z0))
28 | return z0
29 |
30 | @staticmethod
31 | def relu(input_r,input_i):
32 | return relu(input_r), relu(input_i)
33 |
34 | @staticmethod
35 | def max_pool2d(input_r,input_i,kernel_size, stride=None, padding=0,
36 | dilation=1, ceil_mode=False, return_indices=False):
37 |
38 | return max_pool2d(input_r, kernel_size, stride, padding, dilation,
39 | ceil_mode, return_indices), \
40 | max_pool2d(input_i, kernel_size, stride, padding, dilation,
41 | ceil_mode, return_indices)
42 |
43 | @staticmethod
44 | def rDrop2D(rDrop,d_shape,isComplex=False):
45 | drop = np.random.binomial(1, rDrop, size=d_shape).astype(np.float32)
46 | drop[drop == 0] = 1.0e-6
47 | # print(f"x={x.shape} drop={drop.shape}")
48 | drop = torch.from_numpy(drop).cuda()
49 | if isComplex:
50 | drop = COMPLEX_utils.ToZ(drop)
51 | return drop
52 | '''
53 | @staticmethod
54 | def dropout(input_r,input_i, p=0.5, training=True, inplace=False):
55 | return dropout(input_r, p, training, inplace), \
56 | dropout(input_i, p, training, inplace)
57 |
58 | @staticmethod
59 | def dropout2d(input_r,input_i, p=0.5, training=True, inplace=False):
60 | return dropout2d(input_r, p, training, inplace), \
61 | dropout2d(input_i, p, training, inplace)
62 | '''
63 |
64 | #the absolute value or modulus of z https://en.wikipedia.org/wiki/Absolute_value#Complex_numbers
65 | @staticmethod
66 | def modulus(x):
67 | # the modulus |z| = sqrt(re^2 + im^2), computed over the last
68 | # dimension, which holds (real, imag); the result drops that
69 | # trailing dimension. Works for float32 and float64 alike,
70 | # since sqrt() allocates the output with the input's dtype,
71 | # so no pre-allocation is needed.
72 | norm = (x[..., 0] * x[..., 0] + x[..., 1] * x[..., 1]).sqrt()
73 | return norm
74 |
75 | @staticmethod
76 | def phase(x):
77 | phase = torch.atan2(x[..., 1], x[..., 0]) # arg(z) = atan2(imag, real); note the argument order
78 | return phase
79 |
80 | @staticmethod
81 | def sigmoid(x):
82 | # apply sigmoid to the real and imaginary parts independently
83 | s_ = torch.zeros_like(x)
84 | s_[...,0] = torch.sigmoid(x[...,0])
85 | s_[..., 1] = torch.sigmoid(x[..., 1])
86 | return s_
87 |
88 | @staticmethod
89 | def exp_euler(x): # Euler's formula: e^{ix} = cos x + i sin x
90 | s_ = torch.zeros(x.shape + (2,)).double().cuda()
91 | s_[..., 0] = torch.cos(x)
92 | s_[..., 1] = torch.sin(x)
93 | return s_
94 |
95 | @staticmethod
96 | def fft(input, direction='C2C', inverse=False):
97 | """
98 | Interface with torch FFT routines for 2D signals.
99 |
100 | Example
101 | -------
102 | x = torch.randn(128, 32, 32, 2)
103 | x_fft = fft(x, inverse=True)
104 |
105 | Parameters
106 | ----------
107 | input : tensor
108 | complex input for the FFT
109 | direction : string
110 | 'C2R' for complex to real, 'C2C' for complex to complex
111 | inverse : bool
112 | True for computing the inverse FFT.
113 | NB : if direction is equal to 'C2R', then the transform
114 | is automatically inverse.
115 | """
116 | if direction == 'C2R':
117 | inverse = True
118 |
119 | if not COMPLEX_utils.isComplex(input):
120 | raise(TypeError('The input should be complex (e.g. last dimension is 2)'))
121 |
122 | if (not input.is_contiguous()):
123 | raise (RuntimeError('Tensors must be contiguous!'))
124 |
125 | if direction == 'C2R':
126 | output = torch.irfft(input, 2, normalized=False, onesided=False)*input.size(-2)*input.size(-3)
127 | elif direction == 'C2C':
128 | if inverse:
129 | #output = torch.ifft(input, 2, normalized=False)*input.size(-2)*input.size(-3)
130 | output = torch.ifft(input, 2, normalized=False)
131 | else:
132 | output = torch.fft(input, 2, normalized=False)
133 |
134 | return output
135 |
136 | @staticmethod
137 | def Hadamard(A, B, inplace=False):
138 | """
139 | Complex pointwise multiplication between (batched) tensor A and tensor B.
140 | Since the Hadamard product is commutative, Hadamard(A, B) = Hadamard(B, A).
141 |
142 | Parameters
143 | ----------
144 | A : tensor
145 | A is a complex tensor of size (B, C, M, N, 2)
146 | B : tensor
147 | B is a complex tensor of size (M, N, 2) or real tensor of (M, N, 1)
148 | inplace : boolean, optional
149 | if set to True, all the operations are performed inplace
150 |
151 | Returns
152 | -------
153 | C : tensor
154 | output tensor of size (B, C, M, N, 2) such that:
155 | C[b, c, m, n, :] = A[b, c, m, n, :] * B[m, n, :]
156 | """
157 | if not COMPLEX_utils.isComplex(A):
158 | raise TypeError('The input must be complex, indicated by a last '
159 | 'dimension of size 2')
160 |
161 | if B.ndimension() != 3:
162 | raise RuntimeError('The filter must be a 3-tensor, with a last '
163 | 'dimension of size 1 or 2 to indicate it is real '
164 | 'or complex, respectively')
165 |
166 | if not COMPLEX_utils.isComplex(B) and not COMPLEX_utils.isReal(B):
167 | raise TypeError('The filter must be complex or real, indicated by a '
168 | 'last dimension of size 2 or 1, respectively')
169 |
170 | if A.size()[-3:-1] != B.size()[-3:-1]:
171 | raise RuntimeError('The filters are not compatible for multiplication!')
172 |
173 | if A.dtype is not B.dtype:
174 | raise RuntimeError('A and B must be of the same dtype')
175 |
176 | if A.device.type != B.device.type:
177 | raise RuntimeError('A and B must be of the same device type')
178 |
179 | if A.device.type == 'cuda':
180 | if A.device.index != B.device.index:
181 | raise RuntimeError('A and B must be on the same GPU!')
182 |
183 | if COMPLEX_utils.isReal(B):
184 | if inplace:
185 | return A.mul_(B)
186 | else:
187 | return A * B
188 | else:
189 | C = A.new(A.size())
190 |
191 | A_r = A[..., 0].contiguous().view(-1, A.size(-2)*A.size(-3))
192 | A_i = A[..., 1].contiguous().view(-1, A.size(-2)*A.size(-3))
193 |
194 | B_r = B[...,0].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_i)
195 | B_i = B[..., 1].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_r)
196 |
197 | C[..., 0].view(-1, C.size(-2)*C.size(-3))[:] = A_r * B_r - A_i * B_i
198 | C[..., 1].view(-1, C.size(-2)*C.size(-3))[:] = A_r * B_i + A_i * B_r
199 |
200 | return C if not inplace else A.copy_(C)
201 |
202 | def IFFT(X1,X2,X3):
203 | from matplotlib import pyplot as plt, cm   # plt and cm are used below but were never imported
204 | from numpy.fft import ifftn                # likewise ifftn
205 | f, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, sharex='col', sharey='row',figsize=(10,6))
206 | Z = ifftn(X1)
207 | ax1.imshow(X1, cmap=cm.Reds)
208 | ax4.imshow(np.real(Z), cmap=cm.gray)
209 | Z = ifftn(X2)
210 | ax2.imshow(X2, cmap=cm.Reds)
211 | ax5.imshow(np.real(Z), cmap=cm.gray)
212 | Z = ifftn(X3)
213 | ax3.imshow(X3, cmap=cm.Reds)
214 | ax6.imshow(np.real(Z), cmap=cm.gray)
215 | plt.show()
216 | def roll_n(X, axis, n):
217 | f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim()))
218 | b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim()))
219 | front = X[f_idx]
220 | back = X[b_idx]
221 | return torch.cat([back, front], axis)
222 |
223 | def batch_fftshift2d(x):
224 | real, imag = torch.unbind(x, -1)
225 | for dim in range(1, len(real.size())):
226 | n_shift = real.size(dim)//2
227 | if real.size(dim) % 2 != 0:
228 | n_shift += 1 # for odd-sized images
229 | real = roll_n(real, axis=dim, n=n_shift)
230 | imag = roll_n(imag, axis=dim, n=n_shift)
231 | return torch.stack((real, imag), -1) # last dim=2 (real&imag)
232 |
233 | def batch_ifftshift2d(x):
234 | real, imag = torch.unbind(x, -1)
235 | for dim in range(len(real.size()) - 1, 0, -1):
236 | real = roll_n(real, axis=dim, n=real.size(dim)//2)
237 | imag = roll_n(imag, axis=dim, n=imag.size(dim)//2)
238 | return torch.stack((real, imag), -1) # last dim=2 (real&imag)
239 |
240 |
--------------------------------------------------------------------------------
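Note: Z_utils.py targets the pre-1.8 PyTorch FFT API (torch.fft, torch.ifft and torch.irfft as functions), so a "complex" tensor here is a real tensor whose trailing dimension of size 2 holds (real, imag). A standalone sketch of that convention, reproducing what Hadamard and modulus compute; the shapes are illustrative:

import torch

A = torch.randn(4, 1, 8, 8, 2)  # batched complex tensor (B, C, M, N, 2)
B = torch.randn(8, 8, 2)        # complex filter (M, N, 2)

# complex pointwise product, as in COMPLEX_utils.Hadamard:
re = A[..., 0] * B[..., 0] - A[..., 1] * B[..., 1]  # re(ab) = re*re - im*im
im = A[..., 0] * B[..., 1] + A[..., 1] * B[..., 0]  # im(ab) = re*im + im*re
C = torch.stack((re, im), dim=-1)                   # (B, C, M, N, 2)

# modulus, as in COMPLEX_utils.modulus:
mod = (A[..., 0] ** 2 + A[..., 1] ** 2).sqrt()      # drops the trailing dim

# On PyTorch >= 1.8 the native equivalents are torch.view_as_complex
# (requires a contiguous tensor) and the torch.fft.fft2/ifft2 module API.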
/python-package/onnet/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | '''
3 | @Author: Yingshi Chen
4 |
5 | @Date: 2020-01-16 10:38:45
6 | @
7 | # Description:
8 | '''
9 | # coding: utf-8
10 | """onnet: optical neural networks in PyTorch.
11 | 
12 | __author__ = 'Yingshi Chen'
13 | """
14 |
15 | import os
16 |
17 | from .optical_trans import OpticalTrans
18 | from .D2NNet import D2NNet,DNET_config
19 | from .RGBO_CNN import RGBO_CNN,RGBO_CNN_config
20 | from .Z_utils import COMPLEX_utils
21 | from .BinaryDNet import *
22 | from .Net_Instance import *
23 | from .NET_config import *
24 | from .Visualizing import *
25 | from .some_utils import *
26 | from .DiffractiveLayer import *
27 | from .OpticalFormer import clip_grad,OpticalFormer
28 |
29 | '''
30 | try:
31 | except ImportError:
32 | pass
33 | '''
34 |
35 | '''
36 | try:
37 | from .plotting import plot_importance, plot_metric, plot_tree, create_tree_digraph
38 | except ImportError:
39 | pass
40 | '''
41 |
42 | dir_path = os.path.dirname(os.path.realpath(__file__))
43 | #print(f"__init_ dir_path={dir_path}")
44 |
45 | __all__ = ['NET_config',
46 | 'D2NNet','DNET_config','DNet_instance','RGBO_CNN_instance','Net_dump',
47 | 'RGBO_CNN', 'RGBO_CNN_config',
48 | 'OpticalTrans','COMPLEX_utils','MultiDNet','BinaryDNet','Visualize','Visdom_Visualizer',
49 | 'seed_everything','load_model_weights',
50 | 'DiffractiveLayer'
51 | ]
52 |
53 |
54 |
--------------------------------------------------------------------------------
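Note: given the re-exports above, downstream case scripts can import everything from the package root. A minimal usage sketch, restricted to names listed in __all__:

import torch
from onnet import COMPLEX_utils, seed_everything

seed_everything(42)                # from some_utils, re-exported here
z = torch.randn(3, 3, 2)
print(COMPLEX_utils.isComplex(z))  # True: trailing dimension of size 2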
/python-package/onnet/__version__.py:
--------------------------------------------------------------------------------
1 |
2 | VERSION = (0, 0, 1)
3 |
4 | __version__ = '.'.join(map(str, VERSION))
5 |
--------------------------------------------------------------------------------
/python-package/onnet/optical_trans.py:
--------------------------------------------------------------------------------
1 | # Authors: Edouard Oyallon
2 | # Scientific Ancestry: Edouard Oyallon, Laurent Sifre, Joan Bruna
3 |
4 |
5 | __all__ = ['OpticalTrans']
6 |
7 | import torch
8 |
9 | class OpticalTrans(object):
10 | def forward(self, input):
11 | #input = input.type(torch.complex64)
12 | return input
13 |
14 | def __call__(self, input):
15 | return self.forward(input)
16 |
17 | class Scattering2D(object):
18 | """Main module implementing the scattering transform in 2D.
19 | The scattering transform computes two wavelet transforms, each
20 | followed by a modulus non-linearity.
21 | It can be summarized as::
22 |
23 | S_J x = [S_J^0 x, S_J^1 x, S_J^2 x]
24 |
25 | where::
26 |
27 | S_J^0 x = x * phi_J
28 | S_J^1 x = [|x * psi^1_lambda| * phi_J]_lambda
29 | S_J^2 x = [||x * psi^1_lambda| * psi^2_mu| * phi_J]_{lambda, mu}
30 |
31 | where * denotes the convolution (in space), phi_J is a low pass
32 | filter, psi^1_lambda is a family of band pass
33 | filters and psi^2_mu is another family of band pass filters.
34 | Only Morlet filters are used in this implementation.
35 | Convolutions are efficiently performed in the Fourier domain
36 | with this implementation.
37 |
38 | Example
39 | -------
40 | # 1) Define a Scattering object as:
41 | s = Scattering2D(J, shape=(M, N))
42 | # where (M, N) are the image sizes and 2**J the scale of the scattering
43 | # 2) Forward on an input Tensor x of shape B x M x N,
44 | # where B is the batch size.
45 | result_s = s(x)
46 |
47 | Parameters
48 | ----------
49 | J : int
50 | logscale of the scattering
51 | shape : tuple of int
52 | spatial support (M, N) of the input
53 | L : int, optional
54 | number of angles used for the wavelet transform
55 | max_order : int, optional
56 | The maximum order of scattering coefficients to compute. Must be either
57 | `1` or `2`. Defaults to `2`.
58 | pre_pad : boolean, optional
59 | controls the padding: if set to False, a symmetric padding is applied
60 | on the signal. If set to true, the software will assume the signal was
61 | padded externally.
62 |
63 | Attributes
64 | ----------
65 | J : int
66 | logscale of the scattering
67 | shape : tuple of int
68 | spatial support (M, N) of the input
69 | L : int, optional
70 | number of angles used for the wavelet transform
71 | max_order : int, optional
72 | The maximum order of scattering coefficients to compute.
73 | Must be either equal to `1` or `2`. Defaults to `2`.
74 | pre_pad : boolean
75 | controls the padding
76 | Psi : dictionary
77 | containing the wavelets filters at all resolutions. See
78 | filter_bank.filter_bank for an exact description.
79 | Phi : dictionary
80 | containing the low-pass filters at all resolutions. See
81 | filter_bank.filter_bank for an exact description.
82 | M_padded, N_padded : int
83 | spatial support of the padded input
84 |
85 | Notes
86 | -----
87 | The design of the filters is optimized for the value L = 8.
88 | pre_pad is particularly useful when cropping a bigger image,
89 | because the padding is then exact. Defaults to False.
90 | The helpers used below (Modulus, Pad, compute_padding, filter_bank,
91 | convert_filters, SubsampleFourier, cdgmm, fft, unpad) come from the
92 | original scattering backend and are not imported in this file.
93 | """
94 | def __init__(self, J, shape, L=8, max_order=2, pre_pad=False):
95 | self.J, self.L = J, L
96 | self.pre_pad = pre_pad
97 | self.max_order = max_order
98 | self.shape = shape
99 | if 2**J>shape[0] or 2**J>shape[1]:
100 | raise (RuntimeError('The smallest dimension should be larger than 2^J'))
101 |
102 | self.build()
103 |
104 | def build(self):
105 | self.M, self.N = self.shape
106 | self.modulus = Modulus()
107 | self.M_padded, self.N_padded = compute_padding(self.M, self.N, self.J)
108 | # pads equally on a given side if the amount of padding to add is an even number of pixels, otherwise it adds an extra pixel
109 | self.pad = Pad([(self.M_padded - self.M) // 2, (self.M_padded - self.M+1) // 2, (self.N_padded - self.N) // 2, (self.N_padded - self.N + 1) // 2], [self.M, self.N], pre_pad=self.pre_pad)
110 | self.subsample_fourier = SubsampleFourier()
111 | # Create the filters
112 | filters = filter_bank(self.M_padded, self.N_padded, self.J, self.L)
113 | self.Psi = convert_filters(filters['psi'])
114 | self.Phi = convert_filters([filters['phi'][j] for j in range(self.J)])
115 |
116 | def _apply(self, fn):
117 | """
118 | Mimics the behavior of the function _apply() of a nn.Module()
119 | """
120 | for key, item in enumerate(self.Psi):
121 | for key2, item2 in self.Psi[key].items():
122 | if torch.is_tensor(item2):
123 | self.Psi[key][key2] = fn(item2)
124 | self.Phi = [fn(v) for v in self.Phi]
125 | self.pad.padding_module._apply(fn)
126 | return self
127 |
128 | def cuda(self, device=None):
129 | """
130 | Mimics the behavior of the function cuda() of a nn.Module()
131 | """
132 | return self._apply(lambda t: t.cuda(device))
133 |
134 | def to(self, *args, **kwargs):
135 | """
136 | Mimics the behavior of the function to() of a nn.Module()
137 | """
138 | device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)
139 |
140 | if dtype is not None:
141 | if not dtype.is_floating_point:
142 | raise TypeError('nn.Module.to only accepts floating point '
143 | 'dtypes, but got desired dtype={}'.format(dtype))
144 |
145 | def convert(t):
146 | return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
147 |
148 | return self._apply(convert)
149 |
150 | def cpu(self):
151 | """
152 | Mimics the behavior of the function cpu() of a nn.Module()
153 | """
154 | return self._apply(lambda t: t.cpu())
155 |
156 | def forward(self, input):
157 | """Forward pass of the scattering.
158 |
159 | Parameters
160 | ----------
161 | input : tensor
162 | tensor with 3 dimensions :math:`(B, C, M, N)` where :math:`(B, C)` are arbitrary.
163 | :math:`B` typically is the batch size, whereas :math:`C` is the number of input channels.
164 |
165 | Returns
166 | -------
167 | S : tensor
168 | scattering of the input, a 4D tensor :math:`(B, C, D, Md, Nd)` where :math:`D` corresponds
169 | to a new channel dimension and :math:`(Md, Nd)` are downsampled sizes by a factor :math:`2^J`.
170 |
171 | """
172 | if not torch.is_tensor(input):
173 | raise(TypeError('The input should be a torch.cuda.FloatTensor, a torch.FloatTensor or a torch.DoubleTensor'))
174 |
175 | if len(input.shape) < 2:
176 | raise (RuntimeError('Input tensor must have at least two '
177 | 'dimensions'))
178 |
179 | if (not input.is_contiguous()):
180 | raise (RuntimeError('Tensor must be contiguous!'))
181 |
182 | if((input.size(-1)!=self.N or input.size(-2)!=self.M) and not self.pre_pad):
183 | raise (RuntimeError('Tensor must be of spatial size (%i,%i)!'%(self.M,self.N)))
184 |
185 | if ((input.size(-1) != self.N_padded or input.size(-2) != self.M_padded) and self.pre_pad):
186 | raise (RuntimeError('Padded tensor must be of spatial size (%i,%i)!' % (self.M_padded, self.N_padded)))
187 |
188 | batch_shape = input.shape[:-2]
189 | signal_shape = input.shape[-2:]
190 |
191 | input = input.reshape((-1, 1) + signal_shape)
192 |
193 | J = self.J
194 | phi = self.Phi
195 | psi = self.Psi
196 |
197 | subsample_fourier = self.subsample_fourier
198 | modulus = self.modulus
199 | pad = self.pad
200 | order0_size = 1
201 | order1_size = self.L * J
202 | order2_size = self.L ** 2 * J * (J - 1) // 2
203 | output_size = order0_size + order1_size
204 |
205 | if self.max_order == 2:
206 | output_size += order2_size
207 |
208 | S = input.new(input.size(0),
209 | input.size(1),
210 | output_size,
211 | self.M_padded//(2**J)-2,
212 | self.N_padded//(2**J)-2)
213 | U_r = pad(input)
214 | U_0_c = fft(U_r, 'C2C') # We trick here with U_r and U_2_c
215 |
216 | # First low pass filter
217 | U_1_c = subsample_fourier(cdgmm(U_0_c, phi[0]), k=2**J)
218 |
219 | U_J_r = fft(U_1_c, 'C2R')
220 |
221 | S[..., 0, :, :] = unpad(U_J_r)
222 | n_order1 = 1
223 | n_order2 = 1 + order1_size
224 |
225 | for n1 in range(len(psi)):
226 | j1 = psi[n1]['j']
227 | U_1_c = cdgmm(U_0_c, psi[n1][0])
228 | if(j1 > 0):
229 | U_1_c = subsample_fourier(U_1_c, k=2 ** j1)
230 | U_1_c = fft(U_1_c, 'C2C', inverse=True)
231 | U_1_c = fft(modulus(U_1_c), 'C2C')
232 |
233 | # Second low pass filter
234 | U_2_c = subsample_fourier(cdgmm(U_1_c, phi[j1]), k=2**(J-j1))
235 | U_J_r = fft(U_2_c, 'C2R')
236 | S[..., n_order1, :, :] = unpad(U_J_r)
237 | n_order1 += 1
238 |
239 | if self.max_order == 2:
240 | for n2 in range(len(psi)):
241 | j2 = psi[n2]['j']
242 | if(j1 < j2):
243 | U_2_c = subsample_fourier(cdgmm(U_1_c, psi[n2][j1]), k=2 ** (j2-j1))
244 | U_2_c = fft(U_2_c, 'C2C', inverse=True)
245 | U_2_c = fft(modulus(U_2_c), 'C2C')
246 |
247 | # Third low pass filter
248 | U_2_c = subsample_fourier(cdgmm(U_2_c, phi[j2]), k=2 ** (J-j2))
249 | U_J_r = fft(U_2_c, 'C2R')
250 |
251 | S[..., n_order2, :, :] = unpad(U_J_r)
252 | n_order2 += 1
253 |
254 | scattering_shape = S.shape[-3:]
255 | S = S.reshape(batch_shape + scattering_shape)
256 |
257 | return S
258 |
259 | def __call__(self, input):
260 | return self.forward(input)
261 |
--------------------------------------------------------------------------------
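Note: the channel count assembled in Scattering2D.forward() follows directly from the three scattering orders in the class docstring. A small sketch of that arithmetic, using the same order0/order1/order2 formulas as the source:

def scattering_channels(J, L=8, max_order=2):
    order0 = 1                             # the low-pass output x * phi_J
    order1 = L * J                         # |x * psi^1_lambda| * phi_J
    order2 = (L ** 2) * J * (J - 1) // 2   # second-order paths with j1 < j2
    return order0 + order1 + (order2 if max_order == 2 else 0)

assert scattering_channels(J=2, L=8) == 1 + 16 + 64  # 81 channels
# each output map has spatial size M_padded // 2**J - 2 per side
# (see the allocation of S in forward())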
/python-package/onnet/some_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import random
4 | import torch
5 | import sys
6 | import os
7 | import psutil
8 |
9 |
10 |
11 | def split_sections(dim_0,nClass):
12 | split_dim = range(dim_0)
13 | sections=[]
14 | for arr in np.array_split(np.array(split_dim), nClass):
15 | sections.append(len(arr))
16 | assert len(sections) > 0
17 | return sections
18 |
19 | def shrink(x0,x1,max_sz=2):
20 | if x1-x0>max_sz:
21 | center=(x1+x0)//2
22 | #x1 = x0+max_sz
23 | x1 = center + max_sz // 2
24 | x0 = center - max_sz // 2
25 | return x0,x1
26 |
27 | def split_regions_2d(shape,nClass):
28 | dim_1,dim_2=shape[-1],shape[-2]
29 | n1 = int(math.sqrt(nClass))
30 | n2 = int(math.ceil(nClass/n1))
31 | assert n1*n2>=nClass
32 | section_1 = split_sections(dim_1, n1)
33 | section_2 = split_sections(dim_2, n2)
34 | regions = []
35 | x1,x2=0,0
36 | for sec_1 in section_1:
37 | for sec_2 in section_2:
38 | #box=(x1,x1+sec_1,x2,x2+sec_2)
39 | box = shrink(x1,x1+sec_1)+shrink(x2,x2+sec_2)
40 | regions.append(box)
41 | if len(regions)>=nClass:
42 | break
43 | x2 = x2 + sec_2
44 | x1 = x1 + sec_1; x2=0
45 | return regions
46 |
47 | def seed_everything(seed=0):
48 | print(f"======== seed_everything seed={seed}========")
49 | random.seed(seed)
50 | os.environ['PYTHONHASHSEED'] = str(seed)
51 | np.random.seed(seed)
52 | #https://pytorch.org/docs/stable/notes/randomness.html
53 |
54 | torch.manual_seed(seed)
55 | if torch.cuda.is_available():
56 | torch.cuda.manual_seed(seed)
57 | torch.cuda.manual_seed_all(seed)
58 |
59 | torch.backends.cudnn.deterministic = True
60 | torch.backends.cudnn.benchmark = False
61 | '''
62 | if fix_seed is not None: # fix seed
63 | seed = fix_seed #17 * 19
64 | print("!!! __pyTorch FIX SEED={} use_cuda={}!!!".format(seed,use_cuda) )
65 | random.seed(seed-1)
66 | np.random.seed(seed)
67 | torch.manual_seed(seed+1)
68 | if use_cuda:
69 | torch.cuda.manual_seed(seed+2)
70 | torch.cuda.manual_seed_all(seed+3)
71 | torch.backends.cudnn.deterministic = True
72 | '''
73 |
74 | def cpuStats():
75 | print(sys.version)
76 | print(psutil.cpu_percent())
77 | print(psutil.virtual_memory()) # physical memory usage
78 | pid = os.getpid()
79 | py = psutil.Process(pid)
80 | memoryUse = py.memory_info()[0] / 2. ** 30 # resident set size (RSS) in GB
81 | print('memory use in python(GB):', memoryUse)
82 |
83 | def pytorch_env( ):
84 | print('__Python VERSION:', sys.version)
85 | print('__pyTorch VERSION:', torch.__version__)
86 | print('__CUDA VERSION')
87 | # from subprocess import call
88 | # call(["nvcc", "--version"]) does not work
89 | # ! nvcc --version
90 | print('__CUDNN VERSION:', torch.backends.cudnn.version())
91 | print('__Number CUDA Devices:', torch.cuda.device_count())
92 | print('__Devices')
93 | # call(["nvidia-smi", "--format=csv", "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free"])
94 | print('Active CUDA Device: GPU', torch.cuda.current_device())
95 |
96 | print ('Available devices ', torch.cuda.device_count())
97 | print ('Current cuda device ', torch.cuda.current_device())
98 | use_cuda = torch.cuda.is_available()
99 | print("USE CUDA=" + str(use_cuda))
100 |
101 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
102 | FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
103 | LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
104 | Tensor = FloatTensor
105 | cpuStats()
106 | print("===== torch_init device={}".format(device))
107 | return device
108 |
109 | def OnInitInstance(seed=0):
110 | seed_everything(seed)
111 | gpu_device = pytorch_env()
112 | return gpu_device
113 |
114 | def load_model_weights(model, state_dict, log,verbose=True):
115 | """
116 | Loads the model weights from the state dictionary. Function will only load
117 | the weights which have matching key names and dimensions in the state
118 | dictionary.
119 |
120 | :param state_dict: Pytorch model state dictionary
121 | :param verbose: bool, If True, the function will print the
122 | weight keys of parameters that can and cannot be loaded from the
123 | checkpoint state dictionary.
124 | :return: The model with loaded weights
125 | """
126 | new_state_dict = model.state_dict()
127 | non_loadable, loadable = set(), set()
128 |
129 | for k, v in state_dict.items():
130 | if k not in new_state_dict:
131 | non_loadable.add(k)
132 | continue
133 |
134 | if v.shape != new_state_dict[k].shape:
135 | non_loadable.add(k)
136 | continue
137 |
138 | new_state_dict[k] = v
139 | loadable.add(k)
140 |
141 | if verbose:
142 | log.info("### Checkpoint weights that WILL be loaded: ###")
143 | for k in loadable: log.info(k)
144 |
145 | log.info("### Checkpoint weights that CANNOT be loaded: ###")
146 | for k in non_loadable: log.info(k)
147 |
148 | model.load_state_dict(new_state_dict)
149 | return model
--------------------------------------------------------------------------------
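Note: split_regions_2d above carves an (H, W) plane into nClass detector boxes, each side shrunk to at most max_sz=2 pixels by shrink(). A usage sketch, assuming onnet (under python-package/) is on the import path:

from onnet.some_utils import split_regions_2d

regions = split_regions_2d((28, 28), nClass=10)
print(len(regions))  # 10: one box per class
print(regions[0])    # a 4-tuple of box corners, one (lo, hi) pair per axis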
/venv/pyvenv.cfg:
--------------------------------------------------------------------------------
1 | home = D:\anaconda3
2 | include-system-site-packages = false
3 | version = 3.7.3
4 |
--------------------------------------------------------------------------------