├── .gitignore
├── LICENSE
├── README.md
├── asset
│   ├── CIFAR10_test.png
│   ├── CIRAR10_train.png
│   ├── MNIST_test.png
│   ├── MNIST_train.png
│   ├── cifar.png
│   ├── conv_transform_Fashion.png
│   ├── conv_transform_MNIST.png
│   ├── convolutional_transform_CIFAR10.png
│   ├── mnist.png
│   ├── transform_based_CIFAR10.png
│   └── transfrom_based_MNIST.png
├── common
│   ├── __init__.py
│   ├── builder.py
│   ├── loader.py
│   ├── main.py
│   ├── test.py
│   └── train.py
├── decomposition
│   ├── __init__.py
│   ├── conv_layer.py
│   ├── decompose_all.py
│   ├── fc_layer.py
│   └── tensor_ring.py
├── examples
│   ├── distributed.ipynb
│   ├── ex1.ipynb
│   ├── multi_example.py
│   ├── multi_threading_net.ipynb
│   ├── multi_threading_net.py
│   ├── multiprocess_net.ipynb
│   └── parallel.py
├── nets
│   ├── __init__.py
│   ├── lenet.py
│   └── vgg.py
├── notebooks
│   ├── cp_demo.ipynb
│   ├── eval.ipynb
│   └── main.ipynb
└── transform_based_network
    ├── __init__.py
    ├── multiprocess.py
    ├── new_transform.ipynb
    ├── tNN.py
    ├── trainer.py
    ├── transform_layer.py
    ├── transform_nets.py
    └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.i
2 | *.ii
3 | *.gpu
4 | *.ptx
5 | *.cubin
6 | *.fatbin
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Xiao-Yang Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tensor_Layer_for_Deep_Neural_Network_Compression
2 | Apply CP, Tucker, TT/TR, HT to compress neural networks. Train from scratch.
3 |
4 | ## Usage
5 |
6 | First, import the required modules:
7 | ```
8 | from common import *
9 | from decomposition import *
10 | from nets import *
11 | ```
12 |
13 | Then, specify a neural network model. You can choose any model provided by the `nets` package, or define a new architecture using PyTorch. Examples:
14 | ```
15 | model0 = LeNet()
16 | model1 = VGG('VGG16')
17 | ```
18 |
19 | Finally, go through the training and testing process using the `run_all` function.
20 | The function takes the following parameters:
21 | * `dataset`: choose a dataset from `mnist`, `cifar10`, and `cifar100`
22 | * `model`: the neural network model defined in the previous step
23 | * `decomp`: the decomposition method; defaults to `None` (undecomposed)
24 | * `i`: the number of training epochs; defaults to 100
25 | * `rate`: the learning rate; defaults to 0.05
26 | * `transform_based`: whether to use a transform-based network; defaults to `False`
27 |   * When using the transform-based net, it is advisable to run a small number of epochs, as training is much slower than for convolutional nets.
28 |   * The decomposition option for the transform-based net is currently under construction, so it is temporarily disabled.
29 |
30 | The example below runs the CP-decomposed LeNet on the MNIST dataset for 150 epochs with a learning rate of 0.1:
31 | ```
32 | run_all('mnist', model0, decomp='cp', i=150, rate=0.1)
33 | ```
34 | This function creates three subdirectories:
35 | * `data`: stores the dataset
36 | * `models`: stores the trained network
37 | * `curves`: stores arrays of training and testing accuracy per epoch in `.npy` format
38 |
39 |
40 | ## Method
41 | I aim to decompose the neural network in both the convolutional portion and the fully connected portion, using popular tensor decomposition algorithms such as CP, Tucker, TT, and HT. In doing so, I hope to speed up both training and inference and to reduce the number of parameters without significant sacrifices in accuracy.
42 |
43 | ## CP
44 | CP decomposition works fine for classifying the MNIST dataset: it can compress the network without significant loss in accuracy compared to the original, uncompressed network. However, as noted in the paper by Lebedev et al., CP has problems dealing with larger networks, and it is often unstable. In my experiments, the decomposition process not only takes an intolerably long time but also consumes a lot of RAM. For a convolutional layer larger than 512 x 512, the CP decomposition becomes infeasible in terms of memory. Moreover, the CP-decomposed network is highly sensitive to the learning rate, which must be as small as 1e-5 for learning to take place.
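
For illustration, here is a minimal per-layer sketch mirroring `cp_decomposition_conv_layer` in `decomposition/conv_layer.py` (the weight shape and rank are made up for the example):
```
import torch
import tensorly as tl
from tensorly.decomposition import parafac

tl.set_backend('pytorch')
W = torch.randn(32, 16, 3, 3)       # conv weight (out, in, kh, kw); illustrative sizes
l, f, v, h = parafac(W, rank=8)[1]  # four rank-8 factor matrices
# l: (32, 8), f: (16, 8), v: (3, 8), h: (3, 8) -- these become a 1x1 conv,
# two depthwise convs, and a final 1x1 conv, as in the repo code.
```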
45 |
46 | ## Tucker
47 | Tucker decomposition is superior to CP in almost every way. It has more success decomposing larger networks, and requires fewer resources in terms of runtime and memory. It also tolerates larger learning rates, allowing the network to learn faster. The Tucker-decomposed network also learns faster, i.e., reaches higher accuracy in fewer epochs (see the analysis of the performance graphs for details).
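
The per-layer sketch below matches the `partial_tucker` call in `decomposition/conv_layer.py` (shapes and ranks are illustrative; newer tensorly releases rename the `ranks` argument to `rank`):
```
import torch
import tensorly as tl
from tensorly.decomposition import partial_tucker

tl.set_backend('pytorch')
W = torch.randn(32, 16, 3, 3)   # conv weight, illustrative sizes
core, [last, first] = partial_tucker(W, modes=[0, 1], ranks=[8, 8], init='svd')
# first: (16, 8) -> 1x1 conv reducing channels, core: (8, 8, 3, 3) -> regular conv,
# last: (32, 8)  -> 1x1 conv restoring channels
```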
48 |
49 | ## Tensor Train (TT)
50 | In my implementation of compression using tensor train, I picked out the two dimensions of the convolutional layer corresponding to the input/output channels, matricized the tensor, decomposed the result into a matrix product state, and reshaped the factors back into 4-dimensional tensors. This yields two decomposed convolutional layers for every convolutional layer in the original network. Experimentally, this method gives better results than Tucker, with a similar compression rate and speedup. In their two papers, Novikov et al. proposed reshaping to a higher-order tensor before applying the TT decomposition.
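
A minimal per-layer sketch, mirroring `tt_decomposition_conv_layer` in `decomposition/conv_layer.py` (shapes and rank are illustrative):
```
import torch
import tensorly as tl
from tensorly.decomposition import matrix_product_state

tl.set_backend('pytorch')
W = torch.randn(32, 16, 3, 3)               # conv weight, illustrative sizes
W2d = tl.base.unfold(W, 0)                  # matricize along the output channels: (32, 144)
first, last = matrix_product_state(W2d, rank=8)
first = first.reshape(32, 8, 1, 1)          # becomes a 1x1 convolution
last = last.reshape(8, 16, 3, 3)            # becomes a regular 3x3 convolution
```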
51 |
52 | ## Tensor Ring (TR) UNDER CONSTRUCTION
53 | TR decomposition is very similar to TT, differing only in an additional non-trivial mode on the first and last tensor cores. It is applied to neural networks in much the same way, although researchers argue that TR has greater expressive power.
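
A small usage sketch of the `tensor_ring` routine in `decomposition/tensor_ring.py` (the weight shape and rank are illustrative):
```
import torch
import tensorly as tl
from decomposition import tensor_ring

tl.set_backend('pytorch')
W = torch.randn(8, 16, 3, 3)      # conv weight, illustrative sizes
factors = tensor_ring(W, rank=4)  # one order-3 core per mode
print([f.shape for f in factors])
```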
54 |
55 | ## Hierarchical Tucker (HT)
56 | Not yet developed.
57 |
58 | ## Transform-based networks
59 | This network is trained in the transform domain: the weights and the training data enter the network through a tensor product between them. In the last layer, the outputs are brought back by an inverse transform (inverse DCT) and evaluated against a tubal version of the softmax objective function. Backpropagation is handled by PyTorch's built-in autograd.
60 |
61 | To see how the fully-connected transform-based network runs on MNIST, run the demo in `new_transform.ipynb`.
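
The tensor product in question is a transform-domain (t-product style) operation: transform along the tube dimension, multiply the frontal slices, then transform back. Below is a minimal NumPy sketch of this idea under a DCT, for illustration only; the repo's actual implementation lives in `transform_based_network/`:
```
import numpy as np
from scipy.fft import dct, idct

def t_product(A, B):
    # A: (n1, k, t), B: (k, n2, t); tubes run along the last axis.
    A_hat = dct(A, axis=-1, norm='ortho')            # into the transform domain
    B_hat = dct(B, axis=-1, norm='ortho')
    C_hat = np.einsum('ikt,kjt->ijt', A_hat, B_hat)  # slice-wise matrix products
    return idct(C_hat, axis=-1, norm='ortho')        # back via the inverse DCT

C = t_product(np.random.rand(4, 3, 5), np.random.rand(3, 2, 5))
print(C.shape)  # (4, 2, 5)
```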
62 |
63 | ## Experiments
64 | I tested the performance of the three compression methods against the uncompressed network on the MNIST and CIFAR10 datasets. I tried to keep all hyperparameters the same across tests, including rank, number of epochs, and learning rate. However, because CP is so sensitive to the learning rate, it was given a much smaller one.
65 |
66 | ![MNIST training accuracy](asset/MNIST_train.png)
67 | Figure 1. Training accuracy comparison on the MNIST dataset.
68 |
69 | ![MNIST testing accuracy](asset/MNIST_test.png)
70 | Figure 2. Testing accuracy comparison on the MNIST dataset.
71 |
72 | From these performance graphs, we can see that even though the CP-decomposed network has higher training accuracy at the end, its testing accuracy is low, likely a result of overfitting under its finer learning rate. The TT-decomposed network learns faster than Tucker and yields better results. In terms of run time, the four networks do not differ from each other significantly.
73 |
74 | ![CIFAR10 training accuracy](asset/CIRAR10_train.png)
75 | Figure 3. Training accuracy comparison on the CIFAR10 dataset.
76 |
77 | ![CIFAR10 testing accuracy](asset/CIFAR10_test.png)
78 | Figure 4. Testing accuracy comparison on the CIFAR10 dataset.
79 |
80 | ![Transform-based net on MNIST](asset/transfrom_based_MNIST.png)
81 | Figure 5. Training and testing accuracy of the transform-based net on the MNIST dataset.
82 |
83 | For the uncompressed network, the average time per epoch is around 38 seconds; for the Tucker-decomposed network it is 26 seconds, and for the TT-decomposed network 27 seconds. In terms of accuracy, the TT-decomposed network outperforms Tucker in both training and testing, and is almost comparable to the original network before compression.
84 |
85 | ## Profiling
86 | In a typical training process, the profiling output is:
87 | ```
88 | 510155235 function calls (507136221 primitive calls) in 2053.806 seconds
89 |
90 | Ordered by: internal time
91 | List reduced from 824 to 20 due to restriction <20>
92 |
93 | ncalls tottime percall cumtime percall filename:lineno(function)
94 | 49401 743.334 0.015 743.334 0.015 {method 'item' of 'torch._C._TensorBase' objects}
95 | 19550 222.799 0.011 222.799 0.011 {method 'run_backward' of 'torch._C._EngineBase' objects}
96 | 5747602 107.058 0.000 107.058 0.000 {method 'add_' of 'torch._C._TensorBase' objects}
97 | 1183200 102.777 0.000 102.777 0.000 {built-in method conv2d}
98 | 3010000 59.049 0.000 140.632 0.000 functional.py:192(normalize)
99 | 3010000 45.654 0.000 205.622 0.000 functional.py:43(to_tensor)
100 | 1915802 45.251 0.000 45.251 0.000 {method 'mul_' of 'torch._C._TensorBase' objects}
101 | 394400 41.219 0.000 41.219 0.000 {built-in method batch_norm}
102 | 3010000 40.894 0.000 40.894 0.000 {method 'tobytes' of 'numpy.ndarray' objects}
103 | 1915850 39.603 0.000 39.603 0.000 {method 'zero_' of 'torch._C._TensorBase' objects}
104 | 3010000 32.589 0.000 32.589 0.000 {method 'div' of 'torch._C._TensorBase' objects}
105 | 3010000 25.541 0.000 25.541 0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
106 | 6020000 25.312 0.000 25.312 0.000 {built-in method as_tensor}
107 | 3010000 24.033 0.000 24.033 0.000 {method 'sub_' of 'torch._C._TensorBase' objects}
108 | 3010000 20.338 0.000 116.898 0.000 Image.py:2644(fromarray)
109 | 6020000 19.078 0.000 19.078 0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
110 | 3034650 18.921 0.000 18.921 0.000 {method 'view' of 'torch._C._TensorBase' objects}
111 | 3010000 18.467 0.000 18.467 0.000 {method 'float' of 'torch._C._TensorBase' objects}
112 | 3010000 15.967 0.000 15.967 0.000 {method 'clone' of 'torch._C._TensorBase' objects}
113 | 19550 15.331 0.001 168.502 0.009 sgd.py:71(step)
114 | ```
115 |
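A minimal sketch of how such a profile can be produced with the standard library's cProfile (the `run_all` call and its arguments are just an example, assuming `model0` is defined as in the Usage section):
```
import cProfile, pstats

cProfile.run("run_all('mnist', model0, decomp='tucker', i=10)", 'train.prof')
stats = pstats.Stats('train.prof')
stats.sort_stats('tottime').print_stats(20)  # matches the 20-entry listing above
```
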
116 | ## References
117 | ### List of relevant papers:
118 |
119 | * Lebedev, V., Ganin, Y., Rakhuba, M., Oseledets, I. and Lempitsky, V., 2015. Speeding-up convolutional neural networks using fine-tuned CP-decomposition. In 3rd International Conference on Learning Representations, ICLR 2015-Conference Track Proceedings.
120 | * *Notes: applies CP to conv layers.*
121 |
122 | * Kim, Y.D., Park, E., Yoo, S., Choi, T., Yang, L. and Shin, D., 2015. Compression of Deep Convolutional Neural Networks for Fast and Low Power Mobile Applications. arXiv preprint.
123 | * *Notes: applies Tucker to conv layers*
124 |
125 | * Garipov, T., Podoprikhin, D., Novikov, A. and Vetrov, D., 2016. Ultimate tensorization: compressing convolutional and FC layers alike. arXiv preprint.
126 | * *Notes: applies TT to both conv and FC layers*
127 |
128 | * Novikov, A., Podoprikhin, D., Osokin, A. and Vetrov, D.P., 2015. Tensorizing neural networks. In Advances in neural information processing systems (pp. 442-450).
129 | * *Notes: applies TT to FC layers*
130 |
131 | * Wang, W., Sun, Y., Eriksson, B., Wang, W. and Aggarwal, V., 2018. Wide compression: Tensor ring nets. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 9329-9338).
132 | * *Notes: applies TR to both conv and FC layers*
133 |
134 | * Cohen, N., Sharir, O., Levine, Y., Tamari, R., Yakira, D. and Shashua, A., 2017. Analysis and Design of Convolutional Networks via Hierarchical Tensor Decompositions. arXiv preprint.
135 | * *Notes: applies HT to conv layers*
136 |
137 | * Yang, Y., Krompass, D. and Tresp, V., 2017. Tensor-train recurrent neural networks for video classification. arXiv preprint arXiv:1707.01786.
138 | * *Notes: applies TT to sequential models*
139 |
140 | * Yin, M., Liao, S., Liu, X.Y., Wang, X. and Yuan, B., 2020. Compressing Recurrent Neural Networks Using Hierarchical Tucker Tensor Decomposition. arXiv preprint.
141 | * *Notes: applies HT to LSTMs*
142 |
143 | * Newman, Elizabeth, et al. "Stable tensor neural networks for rapid deep learning." arXiv preprint arXiv:1811.06569 (2018).
144 | * *Notes: transform-based tensor neural network*
145 |
146 | ### Related Github repos:
147 |
148 | https://github.com/jacobgil/pytorch-tensor-decompositions
149 |
150 | https://github.com/JeanKossaifi/tensorly-notebooks
151 |
152 | https://github.com/vadim-v-lebedev/cp-decomposition
153 |
154 | https://github.com/timgaripov/TensorNet-TF
155 |
--------------------------------------------------------------------------------
/asset/CIFAR10_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/CIFAR10_test.png
--------------------------------------------------------------------------------
/asset/CIRAR10_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/CIRAR10_train.png
--------------------------------------------------------------------------------
/asset/MNIST_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/MNIST_test.png
--------------------------------------------------------------------------------
/asset/MNIST_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/MNIST_train.png
--------------------------------------------------------------------------------
/asset/cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/cifar.png
--------------------------------------------------------------------------------
/asset/conv_transform_Fashion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/conv_transform_Fashion.png
--------------------------------------------------------------------------------
/asset/conv_transform_MNIST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/conv_transform_MNIST.png
--------------------------------------------------------------------------------
/asset/convolutional_transform_CIFAR10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/convolutional_transform_CIFAR10.png
--------------------------------------------------------------------------------
/asset/mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/mnist.png
--------------------------------------------------------------------------------
/asset/transform_based_CIFAR10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/transform_based_CIFAR10.png
--------------------------------------------------------------------------------
/asset/transfrom_based_MNIST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/transfrom_based_MNIST.png
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .train import *
2 | from .test import *
3 | from .loader import *
4 | from .builder import *
5 | from .main import *
--------------------------------------------------------------------------------
/common/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tensorly as tl
3 | import os
4 | import sys
5 | sys.path.append('..')
6 | from decomposition import *
7 |
8 |
9 | def build(model, decomp='cp'):
10 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
11 | print('==> Building model..')
12 | tl.set_backend('pytorch')
13 | full_net = model
14 | full_net = full_net.to(device)
15 |
16 | path = 'models/'
17 | if not os.path.exists(path):
18 | os.mkdir(path)
19 | torch.save(full_net, path + 'model')
20 | if decomp:
21 | decompose_conv(decomp)
22 | decompose_fc(decomp)
23 | if device == 'cuda:0':
24 | net = torch.load(path + "model").cuda()
25 | else:
26 | net = torch.load(path + "model")
27 | print(net)
28 | print('==> Done')
29 | return net
--------------------------------------------------------------------------------
/common/loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | from torchvision import models
5 | import os
6 |
7 |
8 | def load_mnist():
9 | print('==> Loading data..')
10 | transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
11 | transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
12 |
13 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
14 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0)
15 |
16 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)
17 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)
18 | return trainloader, testloader
19 |
20 | def load_fashion_mnist():
21 | print('==> Loading data..')
22 | transform_train = transforms.Compose([transforms.ToTensor()])
23 | transform_test = transforms.Compose([transforms.ToTensor()])
24 |
25 | trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train)
26 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0)
27 |
28 | testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_test)
29 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)
30 | return trainloader, testloader
31 |
32 | def load_cifar10():
33 | print('==> Loading data..')
34 | transform_train = transforms.Compose([
35 | transforms.RandomCrop(32, padding=4),
36 | transforms.RandomHorizontalFlip(),
37 | transforms.ToTensor(),
38 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
39 | ])
40 |
41 | transform_test = transforms.Compose([
42 | transforms.ToTensor(),
43 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
44 | ])
45 |
46 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
47 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0)
48 |
49 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
50 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)
51 |
52 | return trainloader, testloader
53 |
54 | def load_cifar100():
55 | print('==> Loading data..')
56 | transform_train = transforms.Compose([
57 | transforms.RandomCrop(32, padding=4),
58 | transforms.RandomHorizontalFlip(),
59 | transforms.ToTensor(),
60 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
61 | ])
62 |
63 | transform_test = transforms.Compose([
64 | transforms.ToTensor(),
65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
66 | ])
67 |
68 | trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
69 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0)
70 |
71 | testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
72 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)
73 |
74 | return trainloader, testloader
75 |
76 | def load_mnist_multiprocess(override=0):
77 | print('==> Loading data..')
78 | if override:
79 | cpu_count = override
80 | else:
81 | cpu_count = os.cpu_count()
82 | transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
83 | transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
84 |
85 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
86 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=cpu_count, shuffle=True, num_workers=0)
87 |
88 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)
89 | testloader = torch.utils.data.DataLoader(testset, batch_size=cpu_count, shuffle=False, num_workers=0)
90 | return trainloader, testloader
--------------------------------------------------------------------------------
/common/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 |
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | from torchvision import models
11 |
12 | import tensorly as tl
13 | import tensorly
14 | from itertools import chain
15 | from tensorly.decomposition import parafac, partial_tucker
16 |
17 | import os
18 | import matplotlib.pyplot as plt
19 | import numpy as np
20 | import time
21 |
22 | import sys
23 | sys.path.append('..')
24 | from common import *
25 | from decomposition import *
26 | from nets import *
27 | from transform_based_network import *
28 |
29 |
30 | # main function
31 | def run_all(dataset, model, decomp=None, i=100, rate=0.05, transform_based=False):
32 |
33 | # choose dataset from (MNIST, CIFAR10, ImageNet)
34 | if dataset == 'mnist':
35 | trainloader, testloader = load_mnist()
36 | if dataset == 'cifar10':
37 | trainloader, testloader = load_cifar10()
38 | if dataset == 'cifar100':
39 | trainloader, testloader = load_cifar100()
40 |
41 | # choose decomposition algorithm from (CP, Tucker, TT)
42 | if not transform_based:
43 | net = build(model, decomp)
44 | optimizer = optim.SGD(net.parameters(), lr=rate, momentum=0.9, weight_decay=5e-4)
45 | train_acc, test_acc = train(i, net, trainloader, testloader, optimizer)
46 | _, inf_time = test([], net, testloader)
47 |
48 | if not decomp:
49 | decomp = 'full'
50 | filename = dataset + '_' + decomp
51 | else:
52 | net = Transform_Net()
53 | optimizer = optim.SGD(net.parameters(), lr=rate, momentum=0.9, weight_decay=5e-4)
54 | train_acc, test_acc = train_transform(25, net, trainloader, testloader, optimizer)
55 | filename = dataset + '_transform'
56 | torch.save(net, 'models/' + filename)
57 | path = 'curves/'
58 | if not os.path.exists(path):
59 | os.mkdir(path)
60 |
61 | np.save(path + filename + '_train', train_acc)
62 | np.save(path + filename + '_test', test_acc)
63 |
--------------------------------------------------------------------------------
/common/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 |
8 | import numpy as np
9 | import time
10 |
11 | def test(test_acc, model, testloader):
12 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
13 | model.eval()
14 | test_loss = 0
15 | correct = 0
16 | total = 0
17 | criterion = nn.CrossEntropyLoss()
18 | s = time.time()
19 | with torch.no_grad():
20 | print('|', end='')
21 | for batch_idx, (inputs, targets) in enumerate(testloader):
22 | inputs, targets = inputs.to(device), targets.to(device)
23 | outputs = model(inputs)
24 | loss = criterion(outputs, targets)
25 | test_loss += loss.item()
26 | _, predicted = outputs.max(1)
27 | total += targets.size(0)
28 | correct += predicted.eq(targets).sum().item()
29 | if batch_idx % 10 == 0:
30 | print('=', end='')
31 | e = time.time()
32 | acc = 100. * correct / total
33 | print('|', 'Accuracy:', acc, '% ', correct, '/', total)
34 | print('The inference time is', e - s, 'seconds')
35 | test_acc.append(correct / total)
36 | return test_acc, e - s
--------------------------------------------------------------------------------
/common/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 |
8 | import numpy as np
9 | import time
10 | from .test import *
11 |
12 | def train_step(epoch, train_acc, model, trainloader, optimizer):
13 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
14 | print('\nEpoch: ', epoch)
15 | model.train()
16 | criterion = nn.CrossEntropyLoss()
17 | train_loss = 0
18 | correct = 0
19 | total = 0
20 | print('|', end='')
21 | for batch_idx, (inputs, targets) in enumerate(trainloader):
22 | inputs, targets = inputs.to(device), targets.to(device)
23 | optimizer.zero_grad()
24 | outputs = model(inputs)
25 | # print(outputs.shape, targets.shape)  # debug output; floods the console every batch
26 | loss = criterion(outputs, targets)
27 | loss.backward()
28 | optimizer.step()
29 | train_loss += loss.item()
30 | _, predicted = outputs.max(1)
31 | total += targets.size(0)
32 | correct += predicted.eq(targets).sum().item()
33 | if batch_idx % 10 == 0:
34 | print('=', end='')
35 | print('|', 'Accuracy:', 100. * correct / total,'% ', correct, '/', total)
36 | train_acc.append(correct / total)
37 | return train_acc
38 |
39 | def train(i, model, trainloader, testloader, optimizer):
40 | train_acc = []
41 | test_acc = []
42 | scheduler = StepLR(optimizer, step_size=5, gamma=0.9)
43 | for epoch in range(i):
44 | s = time.time()
45 | train_acc = train_step(epoch, train_acc, model, trainloader, optimizer)
46 | test_acc, _ = test(test_acc, model, testloader)
47 | scheduler.step()
48 | e = time.time()
49 | print('This epoch took', e - s, 'seconds to train')
50 | print('Current learning rate: ', scheduler.get_last_lr()[0])
51 | print('Best testing accuracy overall: ', max(test_acc))
52 | return train_acc, test_acc
--------------------------------------------------------------------------------
/decomposition/__init__.py:
--------------------------------------------------------------------------------
1 | from .conv_layer import *
2 | from .fc_layer import *
3 | from .decompose_all import *
4 | from .tensor_ring import *
--------------------------------------------------------------------------------
/decomposition/conv_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 |
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | from torchvision import models
11 |
12 | import tensorly as tl
13 | import tensorly
14 | from itertools import chain
15 | from tensorly.decomposition import parafac, partial_tucker, matrix_product_state
16 |
17 | import os
18 | import matplotlib.pyplot as plt
19 | import numpy as np
20 | import time
21 |
22 |
23 | def cp_decomposition_conv_layer(layer, rank):
24 | l, f, v, h = parafac(layer.weight.data, rank=rank)[1]
25 | factors = [l, f, v, h]
26 | #print([f.shape for f in factors])
27 |
28 | pointwise_s_to_r_layer = torch.nn.Conv2d(
29 | in_channels=f.shape[0],
30 | out_channels=f.shape[1],
31 | kernel_size=1,
32 | stride=1,
33 | padding=0,
34 | dilation=layer.dilation,
35 | bias=False)
36 |
37 | depthwise_vertical_layer = torch.nn.Conv2d(
38 | in_channels=v.shape[1],
39 | out_channels=v.shape[1],
40 | kernel_size=(v.shape[0], 1),
41 | stride=1, padding=(layer.padding[0], 0),
42 | dilation=layer.dilation,
43 | groups=v.shape[1],
44 | bias=False)
45 |
46 | depthwise_horizontal_layer = torch.nn.Conv2d(
47 | in_channels=h.shape[1],
48 | out_channels=h.shape[1],
49 | kernel_size=(1, h.shape[0]),
50 | stride=layer.stride,
51 | padding=(0, layer.padding[0]),
52 | dilation=layer.dilation,
53 | groups=h.shape[1],
54 | bias=False)
55 |
56 | pointwise_r_to_t_layer = torch.nn.Conv2d(
57 | in_channels=l.shape[1],
58 | out_channels=l.shape[0],
59 | kernel_size=1,
60 | stride=1,
61 | padding=0,
62 | dilation=layer.dilation,
63 | bias=True)
64 |
65 | pointwise_r_to_t_layer.bias.data = layer.bias.data
66 | depthwise_horizontal_layer.weight.data = torch.transpose(h, 1, 0).unsqueeze(1).unsqueeze(1)
67 | depthwise_vertical_layer.weight.data = torch.transpose(v, 1, 0).unsqueeze(1).unsqueeze(-1)
68 | pointwise_s_to_r_layer.weight.data = torch.transpose(f, 1, 0).unsqueeze(-1).unsqueeze(-1)
69 | pointwise_r_to_t_layer.weight.data = l.unsqueeze(-1).unsqueeze(-1)
70 |
71 | new_layers = [pointwise_s_to_r_layer, depthwise_vertical_layer,
72 | depthwise_horizontal_layer, pointwise_r_to_t_layer]
73 | #for l in new_layers:
74 | # print(l.weight.data.shape)
75 |
76 | return nn.Sequential(*new_layers)
77 |
78 | def tucker_decomposition_conv_layer(layer, ranks):
79 | core, [last, first] = partial_tucker(layer.weight.data, modes=[0, 1], ranks=ranks, init='svd')
80 | #print(core.shape, last.shape, first.shape)
81 |
82 | # A pointwise convolution that reduces the channels from S to R3
83 | first_layer = torch.nn.Conv2d(in_channels=first.shape[0],
84 | out_channels=first.shape[1], kernel_size=1,
85 | stride=1, padding=0, dilation=layer.dilation, bias=False)
86 |
87 | # A regular 2D convolution layer with R3 input channels
88 | # and R3 output channels
89 | core_layer = torch.nn.Conv2d(in_channels=core.shape[1],
90 | out_channels=core.shape[0], kernel_size=layer.kernel_size,
91 | stride=layer.stride, padding=layer.padding, dilation=layer.dilation, bias=False)
92 |
93 | # A pointwise convolution that increases the channels from R4 to T
94 | last_layer = torch.nn.Conv2d(in_channels=last.shape[1], \
95 | out_channels=last.shape[0], kernel_size=1, stride=1,
96 | padding=0, dilation=layer.dilation, bias=True)
97 |
98 | last_layer.bias.data = layer.bias.data
99 |
100 | first_layer.weight.data = torch.transpose(first, 1, 0).unsqueeze(-1).unsqueeze(-1)
101 | last_layer.weight.data = last.unsqueeze(-1).unsqueeze(-1)
102 | core_layer.weight.data = core
103 |
104 | new_layers = [first_layer, core_layer, last_layer]
105 | #for l in new_layers:
106 | # print(l.weight.data.shape)
107 | return nn.Sequential(*new_layers)
108 |
109 | def tt_decomposition_conv_layer(layer, ranks):
110 | data = layer.weight.data
111 | data2D = tl.base.unfold(data, 0)
112 |
113 | first, last = matrix_product_state(data2D, rank=ranks)
114 | factors = [first, last]
115 | #print([f.shape for f in factors])
116 |
117 | first = first.reshape(data.shape[0], ranks, 1, 1)
118 | last = last.reshape(ranks, data.shape[1], layer.kernel_size[0], layer.kernel_size[1])
119 |
120 | pointwise_s_to_r_layer = torch.nn.Conv2d(
121 | in_channels=last.shape[1],
122 | out_channels=last.shape[0],
123 | kernel_size=layer.kernel_size,
124 | stride=layer.stride,
125 | padding=layer.padding,
126 | dilation=layer.dilation,
127 | bias=False)
128 |
129 | pointwise_r_to_t_layer = torch.nn.Conv2d(
130 | in_channels=first.shape[1],
131 | out_channels=first.shape[0],
132 | kernel_size=1,
133 | stride=1,
134 | padding=0,
135 | dilation=layer.dilation,
136 | bias=True)
137 |
138 | pointwise_r_to_t_layer.bias.data = layer.bias.data
139 | pointwise_s_to_r_layer.weight.data = last
140 | pointwise_r_to_t_layer.weight.data = first
141 |
142 | new_layers = [pointwise_s_to_r_layer, pointwise_r_to_t_layer]
143 | #for l in new_layers:
144 | # print(l.weight.data.shape)
145 |
146 | return nn.Sequential(*new_layers)
--------------------------------------------------------------------------------
/decomposition/decompose_all.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | from torchvision import models
5 | import tensorly as tl
6 | import tensorly
7 | from itertools import chain
8 | import numpy as np
9 | from .conv_layer import *
10 | from .fc_layer import *
11 |
12 | # decompose all layers in a network
13 | def decompose_conv(decomp):
14 | model = torch.load("models/model").cuda()
15 | model.eval()
16 | model.cpu()
17 | for i, key in enumerate(model.features._modules.keys()):
18 | if i >= len(model.features._modules.keys()) - 2:
19 | break
20 | conv_layer = model.features._modules[key]
21 | if isinstance(conv_layer, torch.nn.modules.conv.Conv2d):
22 | rank = max(conv_layer.weight.data.numpy().shape) // 3
23 | if decomp == 'cp':
24 | model.features._modules[key] = cp_decomposition_conv_layer(conv_layer, rank)
25 | if decomp == 'tucker':
26 | ranks = [int(np.ceil(conv_layer.weight.data.numpy().shape[0] / 3)),
27 | int(np.ceil(conv_layer.weight.data.numpy().shape[1] / 3))]
28 | model.features._modules[key] = tucker_decomposition_conv_layer(conv_layer, ranks)
29 | if decomp == 'tt':
30 | model.features._modules[key] = tt_decomposition_conv_layer(conv_layer, rank)
31 | torch.save(model, 'models/model')
32 |
33 | def decompose_fc(decomp):
34 | model = torch.load("models/model").cuda()
35 | model.eval()
36 | model.cpu()
37 | for i, key in enumerate(model.classifier._modules.keys()):
38 | linear_layer = model.classifier._modules[key]
39 | if isinstance(linear_layer, torch.nn.modules.linear.Linear):
40 | rank = min(linear_layer.weight.data.numpy().shape) // 2
41 | if decomp == 'tucker':
42 | model.classifier._modules[key] = tucker_decomposition_fc_layer(linear_layer, rank)
43 | else:
44 | model.classifier._modules[key] = decomposition_fc_layer(linear_layer, rank)
45 | torch.save(model, 'models/model')
46 | return model
--------------------------------------------------------------------------------
/decomposition/fc_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 |
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | from torchvision import models
11 |
12 | import tensorly as tl
13 | import tensorly
14 | from itertools import chain
15 | from tensorly.decomposition import parafac, tucker, matrix_product_state
16 |
17 | import os
18 | import matplotlib.pyplot as plt
19 | import numpy as np
20 | import time
21 |
22 |
23 | def decomposition_fc_layer(layer, rank):
24 | l, r = matrix_product_state(layer.weight.data, rank=rank)
25 | l, r = l.squeeze(), r.squeeze()
26 |
27 | right_layer = torch.nn.Linear(r.shape[1], r.shape[0])
28 | left_layer = torch.nn.Linear(l.shape[1], l.shape[0])
29 |
30 | left_layer.bias.data = layer.bias.data
31 | left_layer.weight.data = l
32 | right_layer.weight.data = r
33 |
34 | new_layers = [right_layer, left_layer]
35 | return nn.Sequential(*new_layers)
36 |
37 |
38 | def tucker_decomposition_fc_layer(layer, rank):
39 | core, [l, r] = tucker(layer.weight.data, rank=rank)
40 |
41 | right_layer = torch.nn.Linear(r.shape[0], r.shape[1])
42 | core_layer = torch.nn.Linear(core.shape[1], core.shape[0])
43 | left_layer = torch.nn.Linear(l.shape[1], l.shape[0])
44 |
45 | left_layer.bias.data = layer.bias.data
46 | left_layer.weight.data = l
47 | right_layer.weight.data = r.T
48 | core_layer.weight.data = core
49 | new_layers = [right_layer, core_layer, left_layer]
50 | return nn.Sequential(*new_layers)
--------------------------------------------------------------------------------
/decomposition/tensor_ring.py:
--------------------------------------------------------------------------------
1 | import tensorly as tl
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def tensor_ring(input_tensor, rank):
7 |
8 | '''Tensor ring (TR) decomposition via recursive SVD
9 |
10 | Decomposes input_tensor into a sequence of order-3 tensors (factors),
11 | with the input rank of the first factor equal to the output rank of the
12 | last factor. This code is modified from MPS decomposition in tensorly
13 | lib's src code
14 |
15 | Parameters
16 | ----------
17 | input_tensor : tensorly.tensor
18 | rank : {int, int list}
19 | maximum allowable rank of the factors
20 | if int, then this is the same for all the factors
21 | if int list, then rank[k] is the rank of the kth factor
22 |
23 | Returns
24 | -------
25 | factors : Tensor ring factors
26 | order-3 tensors of the tensor ring decomposition
27 | '''
28 |
29 | # Check user input for errors
30 | tensor_size = input_tensor.shape
31 | n_dim = len(tensor_size)
32 |
33 | if isinstance(rank, int):
34 | rank = [rank] * n_dim
35 | elif n_dim != len(rank):
36 | message = 'Provided incorrect number of ranks. '
37 | raise ValueError(message)
38 | rank = list(rank)
39 |
40 | # Initialization
41 | unfolding = tl.unfold(input_tensor, 0)
42 | factors = [None] * n_dim
43 | U, S, V = tl.partial_svd(unfolding, rank[0])
44 | r0 = int(np.sqrt(rank[0]))
45 | while rank[0] % r0:
46 | r0 -= 1
47 | T0 = tl.reshape(U, (tensor_size[0], r0, rank[0] // r0))
48 | factors[0] = torch.transpose(torch.tensor(T0), 0, 1)
49 | unfolding = tl.reshape(S, (-1, 1)) * V
50 | rank[1] = rank[0] // r0
51 | rank.append(r0)
52 |
53 | # Getting the MPS factors up to n_dim
54 | for k in range(1, n_dim):
55 |
56 | # Reshape the unfolding matrix of the remaining factors
57 | n_row = int(rank[k]*tensor_size[k])
58 | unfolding = tl.reshape(unfolding, (n_row, -1))
59 |
60 | # SVD of unfolding matrix
61 | (n_row, n_column) = unfolding.shape
62 | rank[k+1] = min(n_row, n_column, rank[k+1])
63 | U, S, V = tl.partial_svd(unfolding, rank[k+1])
64 |
65 | # Get kth MPS factor
66 | factors[k] = tl.reshape(U, (rank[k], tensor_size[k], rank[k+1]))
67 |
68 | # Get new unfolding matrix for the remaining factors
69 | unfolding = tl.reshape(S, (-1, 1)) * V
70 |
71 | return factors
--------------------------------------------------------------------------------
/examples/ex1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "sys.path.append('..')\n",
11 | "from common import *\n",
12 | "from decomposition import *\n",
13 | "from nets import *"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "metadata": {},
20 | "outputs": [
21 | {
22 | "name": "stdout",
23 | "output_type": "stream",
24 | "text": [
25 | "==> Building model..\n",
26 | "LeNet(\n",
27 | " (features): Sequential(\n",
28 | " (0): Sequential(\n",
29 | " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
30 | " (1): Conv2d(1, 11, kernel_size=(3, 3), stride=(1, 1), bias=False)\n",
31 | " (2): Conv2d(11, 32, kernel_size=(1, 1), stride=(1, 1))\n",
32 | " )\n",
33 | " (1): ReLU(inplace=True)\n",
34 | " (2): Sequential(\n",
35 | " (0): Conv2d(32, 11, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
36 | " (1): Conv2d(11, 22, kernel_size=(3, 3), stride=(1, 1), bias=False)\n",
37 | " (2): Conv2d(22, 64, kernel_size=(1, 1), stride=(1, 1))\n",
38 | " )\n",
39 | " (3): ReLU(inplace=True)\n",
40 | " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
41 | " )\n",
42 | " (classifier): Sequential(\n",
43 | " (0): Dropout(p=0.25, inplace=False)\n",
44 | " (1): Sequential(\n",
45 | " (0): Linear(in_features=9216, out_features=64, bias=True)\n",
46 | " (1): Linear(in_features=64, out_features=64, bias=True)\n",
47 | " (2): Linear(in_features=64, out_features=128, bias=True)\n",
48 | " )\n",
49 | " (2): ReLU(inplace=True)\n",
50 | " (3): Dropout(p=0.5, inplace=False)\n",
51 | " (4): Sequential(\n",
52 | " (0): Linear(in_features=128, out_features=5, bias=True)\n",
53 | " (1): Linear(in_features=5, out_features=5, bias=True)\n",
54 | " (2): Linear(in_features=5, out_features=10, bias=True)\n",
55 | " )\n",
56 | " )\n",
57 | ")\n",
58 | "==> Done\n"
59 | ]
60 | },
61 | {
62 | "name": "stderr",
63 | "output_type": "stream",
64 | "text": [
65 | "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorly\\decomposition\\_tucker.py:63: Warning: Given only one int for 'rank' instead of a list of 2 modes. Using this rank for all modes.\n",
66 | " warnings.warn(message, Warning)\n"
67 | ]
68 | }
69 | ],
70 | "source": [
71 | "model = LeNet()\n",
72 | "decomposed = build(model, decomp='tucker')"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 3,
78 | "metadata": {},
79 | "outputs": [
80 | {
81 | "name": "stdout",
82 | "output_type": "stream",
83 | "text": [
84 | "==> Loading data..\n",
85 | "==> Building model..\n",
86 | "LeNet(\n",
87 | " (features): Sequential(\n",
88 | " (0): Sequential(\n",
89 | " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
90 | " (1): Conv2d(1, 11, kernel_size=(3, 3), stride=(1, 1), bias=False)\n",
91 | " (2): Conv2d(11, 32, kernel_size=(1, 1), stride=(1, 1))\n",
92 | " )\n",
93 | " (1): ReLU(inplace=True)\n",
94 | " (2): Sequential(\n",
95 | " (0): Conv2d(32, 11, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
96 | " (1): Conv2d(11, 22, kernel_size=(3, 3), stride=(1, 1), bias=False)\n",
97 | " (2): Conv2d(22, 64, kernel_size=(1, 1), stride=(1, 1))\n",
98 | " )\n",
99 | " (3): ReLU(inplace=True)\n",
100 | " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
101 | " )\n",
102 | " (classifier): Sequential(\n",
103 | " (0): Dropout(p=0.25, inplace=False)\n",
104 | " (1): Sequential(\n",
105 | " (0): Linear(in_features=9216, out_features=64, bias=True)\n",
106 | " (1): Linear(in_features=64, out_features=64, bias=True)\n",
107 | " (2): Linear(in_features=64, out_features=128, bias=True)\n",
108 | " )\n",
109 | " (2): ReLU(inplace=True)\n",
110 | " (3): Dropout(p=0.5, inplace=False)\n",
111 | " (4): Sequential(\n",
112 | " (0): Linear(in_features=128, out_features=5, bias=True)\n",
113 | " (1): Linear(in_features=5, out_features=5, bias=True)\n",
114 | " (2): Linear(in_features=5, out_features=10, bias=True)\n",
115 | " )\n",
116 | " )\n",
117 | ")\n",
118 | "==> Done\n",
119 | "\n",
120 | "Epoch: 0\n",
121 | "|torch.Size([128, 10]) torch.Size([128])\n",
122 | "=torch.Size([128, 10]) torch.Size([128])\n",
123 | "torch.Size([128, 10]) torch.Size([128])\n",
124 | "torch.Size([128, 10]) torch.Size([128])\n",
125 | "torch.Size([128, 10]) torch.Size([128])\n",
126 | "torch.Size([128, 10]) torch.Size([128])\n",
127 | "torch.Size([128, 10]) torch.Size([128])\n",
128 | "torch.Size([128, 10]) torch.Size([128])\n",
129 | "torch.Size([128, 10]) torch.Size([128])\n",
130 | "torch.Size([128, 10]) torch.Size([128])\n",
131 | "torch.Size([128, 10]) torch.Size([128])\n",
132 | "=torch.Size([128, 10]) torch.Size([128])\n",
133 | "torch.Size([128, 10]) torch.Size([128])\n",
134 | "torch.Size([128, 10]) torch.Size([128])\n",
135 | "torch.Size([128, 10]) torch.Size([128])\n",
136 | "torch.Size([128, 10]) torch.Size([128])\n",
137 | "torch.Size([128, 10]) torch.Size([128])\n",
138 | "torch.Size([128, 10]) torch.Size([128])\n",
139 | "torch.Size([128, 10]) torch.Size([128])\n",
140 | "torch.Size([128, 10]) torch.Size([128])\n",
141 | "torch.Size([128, 10]) torch.Size([128])\n",
142 | "=torch.Size([128, 10]) torch.Size([128])\n",
143 | "torch.Size([128, 10]) torch.Size([128])\n",
144 | "torch.Size([128, 10]) torch.Size([128])\n",
145 | "torch.Size([128, 10]) torch.Size([128])\n",
146 | "torch.Size([128, 10]) torch.Size([128])\n",
147 | "torch.Size([128, 10]) torch.Size([128])\n",
148 | "torch.Size([128, 10]) torch.Size([128])\n",
149 | "torch.Size([128, 10]) torch.Size([128])\n",
150 | "torch.Size([128, 10]) torch.Size([128])\n",
151 | "torch.Size([128, 10]) torch.Size([128])\n",
152 | "=torch.Size([128, 10]) torch.Size([128])\n",
153 | "torch.Size([128, 10]) torch.Size([128])\n",
154 | "torch.Size([128, 10]) torch.Size([128])\n",
155 | "torch.Size([128, 10]) torch.Size([128])\n",
156 | "torch.Size([128, 10]) torch.Size([128])\n",
157 | "torch.Size([128, 10]) torch.Size([128])\n",
158 | "torch.Size([128, 10]) torch.Size([128])\n",
159 | "torch.Size([128, 10]) torch.Size([128])\n",
160 | "torch.Size([128, 10]) torch.Size([128])\n",
161 | "torch.Size([128, 10]) torch.Size([128])\n",
162 | "=torch.Size([128, 10]) torch.Size([128])\n",
163 | "torch.Size([128, 10]) torch.Size([128])\n",
164 | "torch.Size([128, 10]) torch.Size([128])\n",
165 | "torch.Size([128, 10]) torch.Size([128])\n",
166 | "torch.Size([128, 10]) torch.Size([128])\n",
167 | "torch.Size([128, 10]) torch.Size([128])\n",
168 | "torch.Size([128, 10]) torch.Size([128])\n",
169 | "torch.Size([128, 10]) torch.Size([128])\n",
170 | "torch.Size([128, 10]) torch.Size([128])\n",
171 | "torch.Size([128, 10]) torch.Size([128])\n",
172 | "=torch.Size([128, 10]) torch.Size([128])\n",
173 | "torch.Size([128, 10]) torch.Size([128])\n",
174 | "torch.Size([128, 10]) torch.Size([128])\n",
175 | "torch.Size([128, 10]) torch.Size([128])\n",
176 | "torch.Size([128, 10]) torch.Size([128])\n",
177 | "torch.Size([128, 10]) torch.Size([128])\n",
178 | "torch.Size([128, 10]) torch.Size([128])\n",
179 | "torch.Size([128, 10]) torch.Size([128])\n",
180 | "torch.Size([128, 10]) torch.Size([128])\n",
181 | "torch.Size([128, 10]) torch.Size([128])\n",
182 | "=torch.Size([128, 10]) torch.Size([128])\n",
183 | "torch.Size([128, 10]) torch.Size([128])\n",
184 | "torch.Size([128, 10]) torch.Size([128])\n",
185 | "torch.Size([128, 10]) torch.Size([128])\n",
186 | "torch.Size([128, 10]) torch.Size([128])\n",
187 | "torch.Size([128, 10]) torch.Size([128])\n",
188 | "torch.Size([128, 10]) torch.Size([128])\n",
189 | "torch.Size([128, 10]) torch.Size([128])\n",
190 | "torch.Size([128, 10]) torch.Size([128])\n",
191 | "torch.Size([128, 10]) torch.Size([128])\n",
192 | "=torch.Size([128, 10]) torch.Size([128])\n",
193 | "torch.Size([128, 10]) torch.Size([128])\n",
194 | "torch.Size([128, 10]) torch.Size([128])\n",
195 | "torch.Size([128, 10]) torch.Size([128])\n",
196 | "torch.Size([128, 10]) torch.Size([128])\n",
197 | "torch.Size([128, 10]) torch.Size([128])\n",
198 | "torch.Size([128, 10]) torch.Size([128])\n",
199 | "torch.Size([128, 10]) torch.Size([128])\n",
200 | "torch.Size([128, 10]) torch.Size([128])\n",
201 | "torch.Size([128, 10]) torch.Size([128])\n",
202 | "=torch.Size([128, 10]) torch.Size([128])\n",
203 | "torch.Size([128, 10]) torch.Size([128])\n",
204 | "torch.Size([128, 10]) torch.Size([128])\n",
205 | "torch.Size([128, 10]) torch.Size([128])\n",
206 | "torch.Size([128, 10]) torch.Size([128])\n",
207 | "torch.Size([128, 10]) torch.Size([128])\n",
208 | "torch.Size([128, 10]) torch.Size([128])\n",
209 | "torch.Size([128, 10]) torch.Size([128])\n",
210 | "torch.Size([128, 10]) torch.Size([128])\n",
211 | "torch.Size([128, 10]) torch.Size([128])\n",
212 | "=torch.Size([128, 10]) torch.Size([128])\n",
213 | "torch.Size([128, 10]) torch.Size([128])\n",
214 | "torch.Size([128, 10]) torch.Size([128])\n",
215 | "torch.Size([128, 10]) torch.Size([128])\n",
216 | "torch.Size([128, 10]) torch.Size([128])\n",
217 | "torch.Size([128, 10]) torch.Size([128])\n",
218 | "torch.Size([128, 10]) torch.Size([128])\n",
219 | "torch.Size([128, 10]) torch.Size([128])\n",
220 | "torch.Size([128, 10]) torch.Size([128])\n",
221 | "torch.Size([128, 10]) torch.Size([128])\n",
222 | "=torch.Size([128, 10]) torch.Size([128])\n",
223 | "torch.Size([128, 10]) torch.Size([128])\n",
224 | "torch.Size([128, 10]) torch.Size([128])\n",
225 | "torch.Size([128, 10]) torch.Size([128])\n",
226 | "torch.Size([128, 10]) torch.Size([128])\n",
227 | "torch.Size([128, 10]) torch.Size([128])\n",
228 | "torch.Size([128, 10]) torch.Size([128])\n",
229 | "torch.Size([128, 10]) torch.Size([128])\n",
230 | "torch.Size([128, 10]) torch.Size([128])\n",
231 | "torch.Size([128, 10]) torch.Size([128])\n",
232 | "=torch.Size([128, 10]) torch.Size([128])\n",
233 | "torch.Size([128, 10]) torch.Size([128])\n",
234 | "torch.Size([128, 10]) torch.Size([128])\n",
235 | "torch.Size([128, 10]) torch.Size([128])\n",
236 | "torch.Size([128, 10]) torch.Size([128])\n",
237 | "torch.Size([128, 10]) torch.Size([128])\n",
238 | "torch.Size([128, 10]) torch.Size([128])\n",
239 | "torch.Size([128, 10]) torch.Size([128])\n",
240 | "torch.Size([128, 10]) torch.Size([128])\n",
241 | "torch.Size([128, 10]) torch.Size([128])\n",
242 | "=torch.Size([128, 10]) torch.Size([128])\n",
243 | "torch.Size([128, 10]) torch.Size([128])\n",
244 | "torch.Size([128, 10]) torch.Size([128])\n",
245 | "torch.Size([128, 10]) torch.Size([128])\n",
246 | "torch.Size([128, 10]) torch.Size([128])\n",
247 | "torch.Size([128, 10]) torch.Size([128])\n",
248 | "torch.Size([128, 10]) torch.Size([128])\n",
249 | "torch.Size([128, 10]) torch.Size([128])\n",
250 | "torch.Size([128, 10]) torch.Size([128])\n",
251 | "torch.Size([128, 10]) torch.Size([128])\n",
252 | "=torch.Size([128, 10]) torch.Size([128])\n",
253 | "torch.Size([128, 10]) torch.Size([128])\n",
254 | "torch.Size([128, 10]) torch.Size([128])\n",
255 | "torch.Size([128, 10]) torch.Size([128])\n",
256 | "torch.Size([128, 10]) torch.Size([128])\n",
257 | "torch.Size([128, 10]) torch.Size([128])\n",
258 | "torch.Size([128, 10]) torch.Size([128])\n",
259 | "torch.Size([128, 10]) torch.Size([128])\n",
260 | "torch.Size([128, 10]) torch.Size([128])\n",
261 | "torch.Size([128, 10]) torch.Size([128])\n",
262 | "=torch.Size([128, 10]) torch.Size([128])\n",
263 | "torch.Size([128, 10]) torch.Size([128])\n",
264 | "torch.Size([128, 10]) torch.Size([128])\n",
265 | "torch.Size([128, 10]) torch.Size([128])\n",
266 | "torch.Size([128, 10]) torch.Size([128])\n",
267 | "torch.Size([128, 10]) torch.Size([128])\n",
268 | "torch.Size([128, 10]) torch.Size([128])\n",
269 | "torch.Size([128, 10]) torch.Size([128])\n",
270 | "torch.Size([128, 10]) torch.Size([128])\n",
271 | "torch.Size([128, 10]) torch.Size([128])\n",
272 | "=torch.Size([128, 10]) torch.Size([128])\n",
273 | "torch.Size([128, 10]) torch.Size([128])\n",
274 | "torch.Size([128, 10]) torch.Size([128])\n",
275 | "torch.Size([128, 10]) torch.Size([128])\n",
276 | "torch.Size([128, 10]) torch.Size([128])\n",
277 | "torch.Size([128, 10]) torch.Size([128])\n",
278 | "torch.Size([128, 10]) torch.Size([128])\n",
279 | "torch.Size([128, 10]) torch.Size([128])\n",
280 | "torch.Size([128, 10]) torch.Size([128])\n",
281 | "torch.Size([128, 10]) torch.Size([128])\n",
282 | "=torch.Size([128, 10]) torch.Size([128])\n",
283 | "torch.Size([128, 10]) torch.Size([128])\n",
284 | "torch.Size([128, 10]) torch.Size([128])\n",
285 | "torch.Size([128, 10]) torch.Size([128])\n",
286 | "torch.Size([128, 10]) torch.Size([128])\n",
287 | "torch.Size([128, 10]) torch.Size([128])\n",
288 | "torch.Size([128, 10]) torch.Size([128])\n",
289 | "torch.Size([128, 10]) torch.Size([128])\n",
290 | "torch.Size([128, 10]) torch.Size([128])\n",
291 | "torch.Size([128, 10]) torch.Size([128])\n",
292 | "=torch.Size([128, 10]) torch.Size([128])\n",
293 | "torch.Size([128, 10]) torch.Size([128])\n",
294 | "torch.Size([128, 10]) torch.Size([128])\n",
295 | "torch.Size([128, 10]) torch.Size([128])\n",
296 | "torch.Size([128, 10]) torch.Size([128])\n",
297 | "torch.Size([128, 10]) torch.Size([128])\n",
298 | "torch.Size([128, 10]) torch.Size([128])\n"
299 | ]
300 | },
301 | {
302 | "name": "stdout",
303 | "output_type": "stream",
304 | "text": [
305 | "torch.Size([128, 10]) torch.Size([128])\n",
306 | "torch.Size([128, 10]) torch.Size([128])\n",
307 | "torch.Size([128, 10]) torch.Size([128])\n",
308 | "=torch.Size([128, 10]) torch.Size([128])\n",
309 | "torch.Size([128, 10]) torch.Size([128])\n",
310 | "torch.Size([128, 10]) torch.Size([128])\n",
311 | "torch.Size([128, 10]) torch.Size([128])\n",
312 | "torch.Size([128, 10]) torch.Size([128])\n",
313 | "torch.Size([128, 10]) torch.Size([128])\n",
314 | "torch.Size([128, 10]) torch.Size([128])\n",
315 | "torch.Size([128, 10]) torch.Size([128])\n",
316 | "torch.Size([128, 10]) torch.Size([128])\n",
317 | "torch.Size([128, 10]) torch.Size([128])\n",
318 | "=torch.Size([128, 10]) torch.Size([128])\n",
319 | "torch.Size([128, 10]) torch.Size([128])\n",
320 | "torch.Size([128, 10]) torch.Size([128])\n",
321 | "torch.Size([128, 10]) torch.Size([128])\n",
322 | "torch.Size([128, 10]) torch.Size([128])\n",
323 | "torch.Size([128, 10]) torch.Size([128])\n",
324 | "torch.Size([128, 10]) torch.Size([128])\n",
325 | "torch.Size([128, 10]) torch.Size([128])\n",
326 | "torch.Size([128, 10]) torch.Size([128])\n",
327 | "torch.Size([128, 10]) torch.Size([128])\n",
328 | "=torch.Size([128, 10]) torch.Size([128])\n",
329 | "torch.Size([128, 10]) torch.Size([128])\n",
330 | "torch.Size([128, 10]) torch.Size([128])\n",
331 | "torch.Size([128, 10]) torch.Size([128])\n",
332 | "torch.Size([128, 10]) torch.Size([128])\n",
333 | "torch.Size([128, 10]) torch.Size([128])\n",
334 | "torch.Size([128, 10]) torch.Size([128])\n",
335 | "torch.Size([128, 10]) torch.Size([128])\n",
336 | "torch.Size([128, 10]) torch.Size([128])\n",
337 | "torch.Size([128, 10]) torch.Size([128])\n",
338 | "=torch.Size([128, 10]) torch.Size([128])\n",
339 | "torch.Size([128, 10]) torch.Size([128])\n",
340 | "torch.Size([128, 10]) torch.Size([128])\n",
341 | "torch.Size([128, 10]) torch.Size([128])\n",
342 | "torch.Size([128, 10]) torch.Size([128])\n",
343 | "torch.Size([128, 10]) torch.Size([128])\n",
344 | "torch.Size([128, 10]) torch.Size([128])\n",
345 | "torch.Size([128, 10]) torch.Size([128])\n",
346 | "torch.Size([128, 10]) torch.Size([128])\n",
347 | "torch.Size([128, 10]) torch.Size([128])\n",
348 | "=torch.Size([128, 10]) torch.Size([128])\n",
349 | "torch.Size([128, 10]) torch.Size([128])\n",
350 | "torch.Size([128, 10]) torch.Size([128])\n",
351 | "torch.Size([128, 10]) torch.Size([128])\n",
352 | "torch.Size([128, 10]) torch.Size([128])\n",
353 | "torch.Size([128, 10]) torch.Size([128])\n",
354 | "torch.Size([128, 10]) torch.Size([128])\n",
355 | "torch.Size([128, 10]) torch.Size([128])\n",
356 | "torch.Size([128, 10]) torch.Size([128])\n",
357 | "torch.Size([128, 10]) torch.Size([128])\n",
358 | "=torch.Size([128, 10]) torch.Size([128])\n",
359 | "torch.Size([128, 10]) torch.Size([128])\n",
360 | "torch.Size([128, 10]) torch.Size([128])\n",
361 | "torch.Size([128, 10]) torch.Size([128])\n",
362 | "torch.Size([128, 10]) torch.Size([128])\n",
363 | "torch.Size([128, 10]) torch.Size([128])\n",
364 | "torch.Size([128, 10]) torch.Size([128])\n",
365 | "torch.Size([128, 10]) torch.Size([128])\n",
366 | "torch.Size([128, 10]) torch.Size([128])\n",
367 | "torch.Size([128, 10]) torch.Size([128])\n",
368 | "=torch.Size([128, 10]) torch.Size([128])\n",
369 | "torch.Size([128, 10]) torch.Size([128])\n"
370 | ]
371 | },
372 | {
373 | "ename": "KeyboardInterrupt",
374 | "evalue": "",
375 | "output_type": "error",
376 | "traceback": [
377 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
378 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
379 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mrun_all\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'mnist'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdecomp\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'tucker'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrate\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.05\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
380 | "\u001b[1;32m~\\Desktop\\workstation\\TNN\\common\\main.py\u001b[0m in \u001b[0;36mrun_all\u001b[1;34m(dataset, model, decomp, i, rate)\u001b[0m\n\u001b[0;32m 41\u001b[0m \u001b[0mnet\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbuild\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdecomp\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mrate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.9\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight_decay\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m5e-4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 43\u001b[1;33m \u001b[0mtrain_acc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_acc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnet\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrainloader\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtestloader\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 44\u001b[0m \u001b[0m_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minf_time\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtest\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnet\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtestloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 45\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
381 | "\u001b[1;32m~\\Desktop\\workstation\\TNN\\common\\train.py\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(i, model, trainloader, testloader, optimizer)\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 44\u001b[0m \u001b[0ms\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 45\u001b[1;33m \u001b[0mtrain_acc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrain_acc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrainloader\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 46\u001b[0m \u001b[0mtest_acc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtest\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest_acc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtestloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 47\u001b[0m \u001b[0mscheduler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
382 | "\u001b[1;32m~\\Desktop\\workstation\\TNN\\common\\train.py\u001b[0m in \u001b[0;36mtrain_step\u001b[1;34m(epoch, train_acc, model, trainloader, optimizer)\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[0mtotal\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'|'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m''\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 21\u001b[1;33m \u001b[1;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 22\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtargets\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 23\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
383 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 343\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 344\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 345\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 346\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 347\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
384 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 383\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 384\u001b[0m \u001b[0mindex\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 385\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 386\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 387\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
385 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 45\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
386 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36m\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 45\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
387 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\datasets\\mnist.py\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, index)\u001b[0m\n\u001b[0;32m 95\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 96\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 97\u001b[1;33m \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 98\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
388 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, img)\u001b[0m\n\u001b[0;32m 68\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 69\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 70\u001b[1;33m \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 71\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 72\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
389 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, pic)\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mConverted\u001b[0m \u001b[0mimage\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 100\u001b[0m \"\"\"\n\u001b[1;32m--> 101\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto_tensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpic\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 102\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 103\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
390 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\functional.py\u001b[0m in \u001b[0;36mto_tensor\u001b[1;34m(pic)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontiguous\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mByteTensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdiv\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m255\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
391 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
392 | ]
393 | }
394 | ],
395 | "source": [
396 | "run_all('mnist', model, decomp='tucker', i=10, rate=0.05)"
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "execution_count": null,
402 | "metadata": {},
403 | "outputs": [],
404 | "source": [
405 | "run_all('cifar10', model, decomp=None, i=10, rate=0.05)"
406 | ]
407 | },
408 | {
409 | "cell_type": "code",
410 | "execution_count": null,
411 | "metadata": {},
412 | "outputs": [],
413 | "source": [
414 | "x = torch.rand(32, 32)\n",
415 | "x_hat = x.reshape(4, 8, 8, 4)\n",
416 | "a = matrix_product_state(x, [1, 2, 1])\n",
417 | "for i in a:\n",
418 | " print(i.shape)"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": null,
424 | "metadata": {},
425 | "outputs": [],
426 | "source": [
427 | "trainloader, testloader = load_mnist()"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": null,
433 | "metadata": {},
434 | "outputs": [],
435 | "source": [
436 | "for i in trainloader:\n",
437 | " print(i[0].squeeze().shape, i[1].shape)\n",
438 | " break"
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": null,
444 | "metadata": {},
445 | "outputs": [],
446 | "source": []
447 | }
448 | ],
449 | "metadata": {
450 | "kernelspec": {
451 | "display_name": "Python 3",
452 | "language": "python",
453 | "name": "python3"
454 | },
455 | "language_info": {
456 | "codemirror_mode": {
457 | "name": "ipython",
458 | "version": 3
459 | },
460 | "file_extension": ".py",
461 | "mimetype": "text/x-python",
462 | "name": "python",
463 | "nbconvert_exporter": "python",
464 | "pygments_lexer": "ipython3",
465 | "version": "3.7.6"
466 | }
467 | },
468 | "nbformat": 4,
469 | "nbformat_minor": 4
470 | }
471 |
--------------------------------------------------------------------------------
/examples/multi_example.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[1]:
5 |
6 |
7 | import torch
8 | import torch_dct as dct
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | import torch.backends.cudnn as cudnn
13 | from torch.optim.lr_scheduler import StepLR
14 |
15 | import matplotlib.pyplot as plt
16 | plt.style.use(['science','no-latex', 'notebook'])
17 |
18 | import time
19 | import sys
20 | import PIL
21 | sys.path.append('../')
22 | from common import *
23 | from transform_based_network import *
24 |
25 |
26 | # In[4]:
27 |
28 |
29 | trainloader, testloader = load_cifar10()
30 | model = Conv_Transform_Net_CIFAR(100)
31 | print(model)
32 | optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
33 | train_acc, test_acc = train_transform(1, model, trainloader, testloader, optimizer)
34 |
35 | # In[ ]:
36 |
37 |
38 | plt.figure()
39 | plt.title('Convolutional Transform Net Accuracy on CIFAR10')
40 | plt.xlabel('Epoch')
41 | plt.ylabel('Accuracy')
42 | plt.plot(train_acc, label='Train accuracy')
43 | plt.plot(test_acc, label='Test accuracy')
44 | plt.legend()
45 | plt.show()
46 |
47 |
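48 | # Note: like train() in common/train.py, train_transform takes the epoch count
49 | # as its first argument, so this run trains for a single epoch and each curve
50 | # above has a single point; pass a larger value (e.g. 100) for a full curve.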
--------------------------------------------------------------------------------
/examples/multi_threading_net.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch_dct as dct\n",
11 | "import torch.nn as nn\n",
12 | "import torch.optim as optim\n",
13 | "import torch.nn.functional as F\n",
14 | "import torch.backends.cudnn as cudnn\n",
15 | "from torch.optim.lr_scheduler import StepLR\n",
16 | "\n",
17 | "import matplotlib.pyplot as plt\n",
18 | "\n",
19 | "from multiprocessing import Pool, Queue, Process, set_start_method\n",
20 | "import multiprocessing as mp_\n",
21 | "\n",
22 | "import time\n",
23 | "import pkbar\n",
24 | "import sys\n",
25 | "sys.path.append('../')\n",
26 | "from common import *\n",
27 | "from transform_based_network import *\n",
28 | "from joblib import Parallel, delayed"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "class T_Layer(nn.Module):\n",
38 | " def __init__(self, dct_w, dct_b):\n",
39 | " super(T_Layer, self).__init__()\n",
40 | " self.weights = nn.Parameter(dct_w)\n",
41 | " self.bias = nn.Parameter(dct_b)\n",
42 | " \n",
43 | " def forward(self, dct_x):\n",
44 | " x = torch.mm(self.weights, dct_x) + self.bias\n",
45 | " return x\n",
46 | "\n",
47 | " \n",
48 | "class Frontal_Slice(nn.Module):\n",
49 | " def __init__(self, dct_w, dct_b):\n",
50 | " super(Frontal_Slice, self).__init__()\n",
51 | " self.device = dct_w.device\n",
52 | " self.dct_linear = nn.Sequential(\n",
53 | " T_Layer(dct_w, dct_b),\n",
54 | " )\n",
55 | " #nn.ReLU(inplace=True),\n",
56 | " #self.linear1 = nn.Linear(28, 28)\n",
57 | " #nn.ReLU(inplace=True),\n",
58 | " #self.classifier = nn.Linear(28, 10)\n",
59 | " \n",
60 | " def forward(self, x):\n",
61 | " #x = torch.transpose(x, 0, 1).to(self.device)\n",
62 | " x = self.dct_linear(x)\n",
63 | " #x = self.linear1(x)\n",
64 | " #x = self.classifier(x)\n",
65 | " #x = torch.transpose(x, 0, 1)\n",
66 | " return x"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "def train_slice(i, model, x_i, y, outputs, optimizer):\n",
76 | " s = time.time()\n",
77 | " criterion = nn.CrossEntropyLoss()\n",
78 | " o = torch.stack(outputs)\n",
79 | " o[i, ...] = outputs_grad[i]\n",
80 | " o = torch_apply(dct.idct, o)\n",
81 | " o = scalar_tubal_func(o)\n",
82 | " o = torch.transpose(o, 0, 1)\n",
83 | " \n",
84 | " optimizer.zero_grad()\n",
85 | " loss = criterion(o, y) \n",
86 | " loss.backward()\n",
87 | " optimizer.step()\n",
88 | " e = time.time()\n",
89 | " # print(e - s)"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "device = 'cpu'\n",
99 | "batch_size = 100\n",
100 | "trainloader, testloader = load_mnist_multiprocess(batch_size)"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "shape = (28, 28, batch_size)\n",
110 | "models = []\n",
111 | "ops = []\n",
112 | "dct_w, dct_b = make_weights(shape, device=device)\n",
113 | "for i in range(shape[0]):\n",
114 | " w_i = dct_w[i, ...].clone()\n",
115 | " b_i = dct_b[i, ...].clone()\n",
116 | " \n",
117 | " w_i.requires_grad = True\n",
118 | " b_i.requires_grad = True\n",
119 | " \n",
120 | " model = Frontal_Slice(w_i, b_i)\n",
121 | " model.train()\n",
122 | " models.append(model.to(device))\n",
123 | " \n",
124 | " op = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
125 | " ops.append(op)"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "epochs = 10\n",
135 | "acc_list = []\n",
136 | "loss_list = []\n",
137 | "\n",
138 | "global outputs_grad\n",
139 | "for e in range(epochs):\n",
140 | " correct = 0\n",
141 | " total = 0\n",
142 | " losses = 0\n",
143 | " pbar = pkbar.Pbar(name='Epoch '+str(e), target=60000/batch_size)\n",
144 | " for batch_idx, (x, y) in enumerate(trainloader): \n",
145 | " dct_x = torch_shift(x)\n",
146 | " dct_x = torch_apply(dct.dct, dct_x)\n",
147 | "\n",
148 | " dct_x = dct_x.to(device)\n",
149 | " y = y.to(device) \n",
150 | " \n",
151 | " outputs_grad = []\n",
152 | " outputs = []\n",
153 | " \n",
154 | " for i in range(len(models)):\n",
155 | " out = models[i](dct_x[i, ...])\n",
156 | " outputs_grad.append(out)\n",
157 | " outputs.append(out.detach())\n",
158 | " \n",
159 | " #for i in range(len(models)):\n",
160 | " # train_slice(i, models[i], dct_x[i, ...], y, outputs, ops[i])\n",
161 | " \n",
162 | " Parallel(n_jobs=16, prefer=\"threads\", verbose=0)(\n",
163 | " delayed(train_slice)(i, models[i], dct_x[i, ...], y, outputs, ops[i]) \\\n",
164 | " for i in range(len(models))\n",
165 | " )\n",
166 | "\n",
167 | " res = torch.empty(shape[0], 10, shape[2])\n",
168 | " for i in range(len(models)):\n",
169 | " res[i, ...] = models[i](dct_x[i, ...])\n",
170 | " \n",
171 | " res = torch_apply(dct.idct, res).to(device)\n",
172 | " res = scalar_tubal_func(res)\n",
173 | " res = torch.transpose(res, 0, 1)\n",
174 | " criterion = nn.CrossEntropyLoss()\n",
175 | " total_loss = criterion(res, y)\n",
176 | " \n",
177 | " _, predicted = torch.max(res, 1)\n",
178 | " total += y.size(0)\n",
179 | " correct += predicted.eq(y).sum().item()\n",
180 | " losses += total_loss\n",
181 | " \n",
182 | " pbar.update(batch_idx)\n",
183 | " # print(total_loss)\n",
184 | " # print(predicted.eq(y).sum().item() / y.size(0))\n",
185 | " \n",
186 | " loss_list.append(losses / total)\n",
187 | " 3acc_list.append(correct / total)"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | " '''\n",
197 | " tmp = torch_mp.get_context('spawn')\n",
198 | " for model in models:\n",
199 | " model.share_memory()\n",
200 | " processes = []\n",
201 | "\n",
202 | " for i in range(len(models)):\n",
203 | " p = tmp.Process(target=train_slice, \n",
204 | " args=(i, models[i], dct_x[i, ...], y, outputs, ops[i]))\n",
205 | " p.start()\n",
206 | " processes.append(p)\n",
207 | " for p in processes: \n",
208 | " p.join()\n",
209 | " '''"
210 | ]
211 | }
212 | ],
213 | "metadata": {
214 | "kernelspec": {
215 | "display_name": "Python 3",
216 | "language": "python",
217 | "name": "python3"
218 | },
219 | "language_info": {
220 | "codemirror_mode": {
221 | "name": "ipython",
222 | "version": 3
223 | },
224 | "file_extension": ".py",
225 | "mimetype": "text/x-python",
226 | "name": "python",
227 | "nbconvert_exporter": "python",
228 | "pygments_lexer": "ipython3",
229 | "version": "3.8.3"
230 | }
231 | },
232 | "nbformat": 4,
233 | "nbformat_minor": 4
234 | }
235 |
--------------------------------------------------------------------------------
/examples/multi_threading_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_dct as dct
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import torch.backends.cudnn as cudnn
7 | from torch.optim.lr_scheduler import StepLR
8 |
9 | import matplotlib.pyplot as plt
10 |
11 | from multiprocessing import Pool, Queue, Process, set_start_method
12 | import multiprocessing as mp_
13 | from joblib import Parallel, delayed
14 |
15 | import time
16 | import pkbar
17 | import sys
18 | sys.path.append('../')
19 | from common import *
20 | from transform_based_network import *
21 |
22 |
23 | # Layer definition
24 | class T_Layer(nn.Module):
25 | def __init__(self, dct_w, dct_b):
26 | super(T_Layer, self).__init__()
27 | self.weights = nn.Parameter(dct_w)
28 | self.bias = nn.Parameter(dct_b)
29 |
30 | def forward(self, dct_x):
31 | x = torch.mm(self.weights, dct_x) + self.bias
32 | return x
33 |
34 |
35 | # Model definition
36 | # This model is going to be run in parallel
37 | class Frontal_Slice(nn.Module):
38 | def __init__(self, dct_w, dct_b):
39 | super(Frontal_Slice, self).__init__()
40 | self.device = dct_w.device
41 | self.dct_linear = nn.Sequential(
42 | T_Layer(dct_w, dct_b),
43 | )
44 |
45 | def forward(self, x):
46 | x = self.dct_linear(x)
47 | return x
48 |
49 |
50 | # This function conducts backprop for one Frontal_Slice model;
51 | # it is the unit of work that is run in parallel below
52 | def train_slice(i, model, x_i, y, outputs, optimizer):
53 | criterion = nn.CrossEntropyLoss()
54 | o = torch.stack(outputs)
55 |     o[i, ...] = outputs_grad[i]  # only slice i keeps its autograd graph; the rest are detached
56 | o = torch_apply(dct.idct, o)
57 | o = scalar_tubal_func(o)
58 | o = torch.transpose(o, 0, 1)
59 |
60 | optimizer.zero_grad()
61 | loss = criterion(o, y)
62 | loss.backward()
63 | optimizer.step()
64 |
65 |
66 | # Load data
67 | device = 'cpu'
68 | batch_size = 100
69 | trainloader, testloader = load_mnist_multiprocess(batch_size)
70 |
71 |
72 | # Create one Frontal_Slice model (and its optimizer) per frontal slice
73 | shape = (28, 28, batch_size)
74 | models = []
75 | ops = []
76 | dct_w, dct_b = make_weights(shape, device=device)
77 | for i in range(shape[0]):
78 | w_i = dct_w[i, ...].clone()
79 | b_i = dct_b[i, ...].clone()
80 |
81 | w_i.requires_grad = True
82 | b_i.requires_grad = True
83 |
84 | model = Frontal_Slice(w_i, b_i)
85 | model.train()
86 | models.append(model.to(device))
87 | op = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
88 | ops.append(op)
89 |
90 |
91 | epochs = 10
92 | acc_list = []
93 | loss_list = []
94 |
95 | global outputs_grad  # module-level list shared with train_slice; rebuilt every batch
96 | for e in range(epochs):
97 | correct = 0
98 | total = 0
99 | losses = 0
100 | pbar = pkbar.Pbar(name='Epoch '+str(e), target=60000/batch_size)
101 | for batch_idx, (x, y) in enumerate(trainloader):
102 | dct_x = torch_shift(x)
103 | dct_x = torch_apply(dct.dct, dct_x)
104 |
105 | dct_x = dct_x.to(device)
106 | y = y.to(device)
107 |
108 | outputs_grad = []
109 | outputs = []
110 | for i in range(len(models)):
111 | out = models[i](dct_x[i, ...])
112 | outputs_grad.append(out)
113 | outputs.append(out.detach())
114 |
115 |         # Dispatch one train_slice call per frontal slice across
116 |         # a pool of 16 worker threads (the parallel step)
117 | Parallel(n_jobs=16, prefer="threads", verbose=0)(
118 | delayed(train_slice)(i, models[i], dct_x[i, ...], y, outputs, ops[i]) \
119 | for i in range(len(models))
120 | )
121 |
122 | res = torch.empty(shape[0], 10, shape[2])
123 | for i in range(len(models)):
124 | res[i, ...] = models[i](dct_x[i, ...])
125 |
126 | res = torch_apply(dct.idct, res).to(device)
127 | res = scalar_tubal_func(res)
128 | res = torch.transpose(res, 0, 1)
129 | criterion = nn.CrossEntropyLoss()
130 | total_loss = criterion(res, y)
131 |
132 | _, predicted = torch.max(res, 1)
133 | total += y.size(0)
134 | correct += predicted.eq(y).sum().item()
135 | losses += total_loss
136 |
137 | pbar.update(batch_idx)
138 |
139 | loss_list.append(losses / total)
140 | acc_list.append(correct / total)
141 |
142 |
143 | '''
144 | tmp = torch_mp.get_context('spawn')
145 | for model in models:
146 | model.share_memory()
147 | processes = []
148 |
149 | for i in range(len(models)):
150 | p = tmp.Process(target=train_slice,
151 | args=(i, models[i], dct_x[i, ...], y, outputs, ops[i]))
152 | p.start()
153 | processes.append(p)
154 | for p in processes:
155 | p.join()
156 | '''
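157 |
158 | # Recap of the scheme above: after the DCT, the per-slice matrix products are
159 | # independent, so each Frontal_Slice model owns the weights of one of the 28
160 | # frontal slices. Inside train_slice, only slice i's output keeps its autograd
161 | # graph (the other slices are detached copies), so each threaded backward pass
162 | # updates just its own model's parameters, which is why the joblib thread pool
163 | # needs no locking.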
--------------------------------------------------------------------------------
/examples/multiprocess_net.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch_dct as dct\n",
11 | "import torch.nn as nn\n",
12 | "import torch.optim as optim\n",
13 | "import torch.nn.functional as F\n",
14 | "import torch.backends.cudnn as cudnn\n",
15 | "from torch.optim.lr_scheduler import StepLR\n",
16 | "\n",
17 | "import matplotlib.pyplot as plt\n",
18 | "plt.style.use(['science','no-latex', 'notebook'])\n",
19 | "\n",
20 | "from multiprocessing import Pool, Queue, Process, set_start_method\n",
21 | "import multiprocessing as mp_\n",
22 | "\n",
23 | "import time\n",
24 | "import pkbar\n",
25 | "import sys\n",
26 | "sys.path.append('../')\n",
27 | "from common import *\n",
28 | "from transform_based_network import *"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 45,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "class T_Layer(nn.Module):\n",
38 | " def __init__(self, dct_w, dct_b):\n",
39 | " super(T_Layer, self).__init__()\n",
40 | " w = torch.randn(dct_w.shape)\n",
41 | " b = torch.randn(dct_b.shape)\n",
42 | " self.weights = nn.Parameter(dct_w)\n",
43 | " self.bias = nn.Parameter(dct_b)\n",
44 | " \n",
45 | " def forward(self, dct_x):\n",
46 | " x = torch.mm(self.weights, dct_x)# + self.bias\n",
47 | " return x\n",
48 | "\n",
49 | " \n",
50 | "class Frontal_Slice(nn.Module):\n",
51 | " def __init__(self, dct_w, dct_b):\n",
52 | " super(Frontal_Slice, self).__init__()\n",
53 | " self.device = dct_w.device\n",
54 | " self.dct_linear = nn.Sequential(\n",
55 | " T_Layer(dct_w, dct_b),\n",
56 | " )\n",
57 | " #nn.ReLU(inplace=True),\n",
58 | " #self.linear1 = nn.Linear(28, 28)\n",
59 | " #nn.ReLU(inplace=True),\n",
60 | " #self.linear2 = nn.Linear(28, 28)\n",
61 | " #nn.ReLU(inplace=True),\n",
62 | " #self.classifier = nn.Linear(28, 10)\n",
63 | " \n",
64 | " def forward(self, x):\n",
65 | " #x = torch.transpose(x, 0, 1).to(self.device)\n",
66 | " x = self.dct_linear(x)\n",
67 | " #x = self.linear1(x)\n",
68 | " #x = self.linear2(x)\n",
69 | " #x = self.classifier(x)\n",
70 | " #x = torch.transpose(x, 0, 1)\n",
71 | " return x\n",
72 | " \n",
73 | " \n",
74 | "class Ensemble(nn.Module):\n",
75 | " def __init__(self, shape, device='cpu'):\n",
76 | " super(Ensemble, self).__init__()\n",
77 | " self.device = device \n",
78 | " self.models = nn.ModuleList([])\n",
79 | " dct_w, dct_b = make_weights(shape, device, scale=0.001)\n",
80 | " self.weights = nn.Parameter(dct_w)\n",
81 | " self.bias = nn.Parameter(dct_b)\n",
82 | " for i in range(shape[0]):\n",
83 | " model = Frontal_Slice(self.weights[i, ...], self.bias[i, ...])\n",
84 | " self.models.append(model.to(device))\n",
85 | " \n",
86 | " def forward(self, x):\n",
87 | " self.res = torch.empty(x.shape[0], 10, x.shape[2])\n",
88 | " dct_x = torch_apply(dct.dct, x).to(self.device)\n",
89 | " self.tmp = []\n",
90 | " for i in range(len(self.models)):\n",
91 | " self.tmp.append(self.models[i](dct_x[i, ...]))\n",
92 | " self.res[i, ...] = self.tmp[i]\n",
93 | " self.result = torch_apply(dct.idct, self.res)\n",
94 | " self.softmax = scalar_tubal_func(self.result)\n",
95 | " return torch.transpose(self.softmax, 0, 1)"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 46,
101 | "metadata": {
102 | "scrolled": true
103 | },
104 | "outputs": [],
105 | "source": [
106 | "def train_ensemble(x, y, i=50, device='cuda:0'):\n",
107 | " x = torch_shift(x).to(device)\n",
108 | " y = y.to(device)\n",
109 | " ensemble = Ensemble(x.shape, device).to(device)\n",
110 | " optimizer = optim.SGD(ensemble.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)\n",
111 | " criterion = nn.CrossEntropyLoss()\n",
112 | " pbar = pkbar.Pbar(name='progress', target=i)\n",
113 | " for j in range(i):\n",
114 | " outputs = ensemble(x)\n",
115 | " print(outputs.shape, y.shape)\n",
116 | " optimizer.zero_grad()\n",
117 | " loss = criterion(outputs.to(device), y)\n",
118 | " loss.backward()\n",
119 | " optimizer.step()\n",
120 | " pbar.update(j)\n",
121 | " \n",
122 | " print(loss.item())\n",
123 | " return ensemble\n",
124 | "\n",
125 | "## 16, 10, 10, 100 iterations\n",
126 | "# cpu, for loop: 4.1s\n",
127 | "# gpu, for loop: 5.5s"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": 47,
133 | "metadata": {},
134 | "outputs": [
135 | {
136 | "name": "stdout",
137 | "output_type": "stream",
138 | "text": [
139 | "progress\n",
140 | "torch.Size([16, 10]) torch.Size([16])\n",
141 | "1/2 [==============>...............] - 0.1storch.Size([16, 10]) torch.Size([16])\n",
142 | "2/2 [==============================] - 0.1s\n",
143 | "2.179849624633789\n"
144 | ]
145 | }
146 | ],
147 | "source": [
148 | "x0 = []\n",
149 | "y0 = []\n",
150 | "for i in range(100):\n",
151 | " x0.append(torch.randn(16, 29, 28))\n",
152 | " y0.append(torch.randint(10, (16,)))\n",
153 | "\n",
154 | "for i in range(1):\n",
155 | " train_ensemble(x0[i], y0[i], i=2, device='cpu')"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 48,
161 | "metadata": {},
162 | "outputs": [
163 | {
164 | "name": "stdout",
165 | "output_type": "stream",
166 | "text": [
167 | "==> Loading data..\n"
168 | ]
169 | }
170 | ],
171 | "source": [
172 | "batch_size = 10\n",
173 | "trainloader, testloader = load_mnist_multiprocess(batch_size)"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": 49,
179 | "metadata": {},
180 | "outputs": [
181 | {
182 | "name": "stdout",
183 | "output_type": "stream",
184 | "text": [
185 | "Epoch0\n",
186 | "tensor(2.2876, grad_fn=)\n",
187 | " 1/6000 [..............................] - 0.0stensor(2.2945, grad_fn=)\n",
188 | " 2/6000 [..............................] - 0.1stensor(2.3091, grad_fn=)\n",
189 | " 3/6000 [..............................] - 0.1stensor(2.3509, grad_fn=)\n",
190 | " 4/6000 [..............................] - 0.2stensor(2.3103, grad_fn=)\n",
191 | " 5/6000 [..............................] - 0.3stensor(2.3188, grad_fn=)\n",
192 | " 6/6000 [..............................] - 0.3stensor(2.2883, grad_fn=)\n",
193 | " 7/6000 [..............................] - 0.4stensor(2.3286, grad_fn=)\n",
194 | " 8/6000 [..............................] - 0.4stensor(2.2519, grad_fn=)\n",
195 | " 9/6000 [..............................] - 0.5stensor(2.3762, grad_fn=)\n",
196 | " 10/6000 [..............................] - 0.5stensor(2.2961, grad_fn=)\n",
197 | " 11/6000 [..............................] - 0.6stensor(2.3284, grad_fn=)\n",
198 | " 12/6000 [..............................] - 0.6stensor(2.2897, grad_fn=)\n",
199 | " 13/6000 [..............................] - 0.7stensor(2.3149, grad_fn=)\n",
200 | " 14/6000 [..............................] - 0.7stensor(2.3099, grad_fn=)\n",
201 | " 15/6000 [..............................] - 0.8stensor(2.3955, grad_fn=)\n",
202 | " 16/6000 [..............................] - 0.8stensor(2.3513, grad_fn=)\n",
203 | " 17/6000 [..............................] - 0.9stensor(2.3118, grad_fn=)\n",
204 | " 18/6000 [..............................] - 0.9stensor(2.2473, grad_fn=)\n",
205 | " 19/6000 [..............................] - 1.0stensor(2.2799, grad_fn=)\n",
206 | " 20/6000 [..............................] - 1.0stensor(2.3112, grad_fn=)\n",
207 | " 21/6000 [..............................] - 1.1stensor(2.2879, grad_fn=)\n",
208 | " 22/6000 [..............................] - 1.1stensor(2.3219, grad_fn=)\n",
209 | " 23/6000 [..............................] - 1.2stensor(2.2602, grad_fn=)\n",
210 | " 24/6000 [..............................] - 1.2stensor(2.3284, grad_fn=)\n",
211 | " 25/6000 [..............................] - 1.3stensor(2.3766, grad_fn=)\n",
212 | " 26/6000 [..............................] - 1.3stensor(2.3589, grad_fn=)\n",
213 | " 27/6000 [..............................] - 1.4stensor(2.2433, grad_fn=)\n",
214 | " 28/6000 [..............................] - 1.4stensor(2.2943, grad_fn=)\n",
215 | " 29/6000 [..............................] - 1.5stensor(2.2615, grad_fn=)\n",
216 | " 30/6000 [..............................] - 1.5stensor(2.2647, grad_fn=)\n",
217 | " 31/6000 [..............................] - 1.6stensor(2.2816, grad_fn=)\n",
218 | " 32/6000 [..............................] - 1.6stensor(2.2873, grad_fn=)\n",
219 | " 33/6000 [..............................] - 1.7stensor(2.3596, grad_fn=)\n",
220 | " 34/6000 [..............................] - 1.7stensor(2.3045, grad_fn=)\n",
221 | " 35/6000 [..............................] - 1.8stensor(2.3835, grad_fn=)\n",
222 | " 36/6000 [..............................] - 1.8stensor(2.3288, grad_fn=)\n",
223 | " 37/6000 [..............................] - 1.8stensor(2.2843, grad_fn=)\n",
224 | " 38/6000 [..............................] - 1.9stensor(2.3762, grad_fn=)\n",
225 | " 39/6000 [..............................] - 1.9stensor(2.2595, grad_fn=)\n",
226 | " 40/6000 [..............................] - 2.0stensor(2.3851, grad_fn=)\n",
227 | " 41/6000 [..............................] - 2.0stensor(2.3167, grad_fn=)\n",
228 | " 42/6000 [..............................] - 2.1stensor(2.3009, grad_fn=)\n",
229 | " 43/6000 [..............................] - 2.1stensor(2.2561, grad_fn=)\n",
230 | " 44/6000 [..............................] - 2.2stensor(2.3433, grad_fn=)\n",
231 | " 45/6000 [..............................] - 2.2stensor(2.2998, grad_fn=)\n",
232 | " 46/6000 [..............................] - 2.3stensor(2.3067, grad_fn=)\n",
233 | " 47/6000 [..............................] - 2.3stensor(2.3511, grad_fn=)\n",
234 | " 48/6000 [..............................] - 2.4stensor(2.2981, grad_fn=)\n",
235 | " 49/6000 [..............................] - 2.4stensor(2.3214, grad_fn=)\n",
236 | " 50/6000 [..............................] - 2.5stensor(2.3309, grad_fn=)\n",
237 | " 51/6000 [..............................] - 2.5stensor(2.3063, grad_fn=)\n",
238 | " 52/6000 [..............................] - 2.6stensor(2.3466, grad_fn=)\n",
239 | " 53/6000 [..............................] - 2.6stensor(2.2675, grad_fn=)\n",
240 | " 54/6000 [..............................] - 2.7stensor(2.3141, grad_fn=)\n",
241 | " 55/6000 [..............................] - 2.7stensor(2.3184, grad_fn=)\n",
242 | " 56/6000 [..............................] - 2.8stensor(2.3438, grad_fn=)\n",
243 | " 57/6000 [..............................] - 2.8stensor(2.3158, grad_fn=)\n",
244 | " 58/6000 [..............................] - 2.9stensor(2.5150, grad_fn=)\n",
245 | " 59/6000 [..............................] - 2.9stensor(2.2994, grad_fn=)\n",
246 | " 60/6000 [..............................] - 3.0stensor(2.3718, grad_fn=)\n",
247 | " 61/6000 [..............................] - 3.0stensor(2.3716, grad_fn=)\n",
248 | " 62/6000 [..............................] - 3.1stensor(2.3215, grad_fn=)\n",
249 | " 63/6000 [..............................] - 3.1stensor(2.3169, grad_fn=)\n",
250 | " 64/6000 [..............................] - 3.2stensor(2.3112, grad_fn=)\n",
251 | " 65/6000 [..............................] - 3.2stensor(2.3542, grad_fn=)\n",
252 | " 66/6000 [..............................] - 3.2stensor(2.3217, grad_fn=)\n",
253 | " 67/6000 [..............................] - 3.3stensor(2.2733, grad_fn=)\n",
254 | " 68/6000 [..............................] - 3.3stensor(2.3719, grad_fn=)\n",
255 | " 69/6000 [..............................] - 3.4stensor(2.3289, grad_fn=)\n",
256 | " 70/6000 [..............................] - 3.4stensor(2.2587, grad_fn=)\n",
257 | " 71/6000 [..............................] - 3.5stensor(2.4049, grad_fn=)\n",
258 | " 72/6000 [..............................] - 3.5stensor(2.3286, grad_fn=)\n",
259 | " 73/6000 [..............................] - 3.6stensor(2.4352, grad_fn=)\n",
260 | " 74/6000 [..............................] - 3.6stensor(2.2177, grad_fn=)\n",
261 | " 75/6000 [..............................] - 3.7stensor(2.3477, grad_fn=)\n",
262 | " 76/6000 [..............................] - 3.7stensor(2.3066, grad_fn=)\n",
263 | " 77/6000 [..............................] - 3.8stensor(2.3314, grad_fn=)\n",
264 | " 78/6000 [..............................] - 3.8stensor(2.2981, grad_fn=)\n",
265 | " 79/6000 [..............................] - 3.9stensor(2.3615, grad_fn=)\n",
266 | " 80/6000 [..............................] - 3.9stensor(2.4292, grad_fn=)\n",
267 | " 81/6000 [..............................] - 4.0stensor(2.3730, grad_fn=)\n",
268 | " 82/6000 [..............................] - 4.0stensor(2.3759, grad_fn=)\n",
269 | " 83/6000 [..............................] - 4.1stensor(2.3449, grad_fn=)\n",
270 | " 84/6000 [..............................] - 4.1stensor(2.3138, grad_fn=)\n",
271 | " 85/6000 [..............................] - 4.2stensor(2.2498, grad_fn=)\n",
272 | " 86/6000 [..............................] - 4.2stensor(2.2832, grad_fn=)\n",
273 | " 87/6000 [..............................] - 4.3stensor(2.2201, grad_fn=)\n",
274 | " 88/6000 [..............................] - 4.3stensor(2.3363, grad_fn=)\n"
275 | ]
276 | },
277 | {
278 | "name": "stdout",
279 | "output_type": "stream",
280 | "text": [
281 | " 89/6000 [..............................] - 4.4stensor(2.2559, grad_fn=)\n",
282 | " 90/6000 [..............................] - 4.4stensor(2.3373, grad_fn=)\n",
283 | " 91/6000 [..............................] - 4.5stensor(2.2490, grad_fn=)\n",
284 | " 92/6000 [..............................] - 4.5stensor(2.2938, grad_fn=)\n",
285 | " 93/6000 [..............................] - 4.6stensor(2.2821, grad_fn=)\n",
286 | " 94/6000 [..............................] - 4.6stensor(2.2545, grad_fn=)\n",
287 | " 95/6000 [..............................] - 4.7stensor(2.3990, grad_fn=)\n",
288 | " 96/6000 [..............................] - 4.7stensor(2.2276, grad_fn=)\n",
289 | " 97/6000 [..............................] - 4.7stensor(2.4132, grad_fn=)\n",
290 | " 98/6000 [..............................] - 4.8s"
291 | ]
292 | },
293 | {
294 | "ename": "KeyboardInterrupt",
295 | "evalue": "",
296 | "output_type": "error",
297 | "traceback": [
298 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
299 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
300 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 26\u001b[0m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 27\u001b[1;33m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 28\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 29\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
301 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[0;32m 193\u001b[0m \u001b[0mproducts\u001b[0m\u001b[1;33m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 194\u001b[0m \"\"\"\n\u001b[1;32m--> 195\u001b[1;33m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 196\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 197\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
302 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[0;32m 97\u001b[0m Variable._execution_engine.run_backward(\n\u001b[0;32m 98\u001b[0m \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 99\u001b[1;33m allow_unreachable=True) # allow_unreachable flag\n\u001b[0m\u001b[0;32m 100\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
303 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
304 | ]
305 | }
306 | ],
307 | "source": [
308 | "device = 'cpu'\n",
309 | "for epoch in range(10):\n",
310 | " pbar = pkbar.Pbar(name='Epoch'+str(epoch), target=60000/batch_size)\n",
311 | " for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
312 | " '''\n",
313 | " dct_x = torch_apply(dct.dct, x.squeeze())\n",
314 | " y_cat = to_categorical(y, 10) \n",
315 | "\n",
316 | " dct_y_cat = torch.randn(y_cat.shape[0], dct_x.shape[1], 10)\n",
317 | " for i in range(10):\n",
318 | " dct_y_cat[:, i, :] = y_cat\n",
319 | " dct_y_cat = torch_apply(dct.dct, dct_y_cat)\n",
320 | " dct_x.to(device)\n",
321 | " dct_y_cat.to(device)\n",
322 | " '''\n",
323 | " correct = 0\n",
324 | " train_loss = 0\n",
325 | " total = 0\n",
326 | " inputs = torch_shift(inputs).to(device)\n",
327 | " ensemble = Ensemble(inputs.shape, device).to(device)\n",
328 | " optimizer = optim.SGD(ensemble.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
329 | " criterion = nn.CrossEntropyLoss()\n",
330 | "\n",
331 | " outputs = ensemble(inputs) \n",
332 | " optimizer.zero_grad()\n",
333 | " loss = criterion(outputs.to(device), targets.to(device))\n",
334 | " loss.backward()\n",
335 | " optimizer.step()\n",
336 | " \n",
337 | " _, predicted = torch.max(outputs, 1)\n",
338 | " correct += predicted.eq(targets).sum().item()\n",
339 | " train_loss += loss.item()\n",
340 | " total += batch_size\n",
341 | " print(loss)\n",
342 | " \n",
343 | " pbar.update(batch_idx)\n",
344 | " print(correct/total, train_loss/total)\n",
345 | " \n",
346 | "\n",
347 | "'''\n",
348 | " models = []\n",
349 | " for i in range(16):\n",
350 | " dct_w, dct_b = make_weights(dct_x.shape, device=device)\n",
351 | " model = Frontal_Slice(dct_w[i, ...], dct_b[i, ...])\n",
352 | " models.append(model.to(device))\n",
353 | "\n",
354 | " for i in range(len(models)):\n",
355 | " train_slice(models[i], dct_x[i, ...], dct_y_cat[i, ...])\n",
356 | " print()\n",
357 | " pbar.update(batch_idx)\n",
358 | " \n",
359 | " tmp = torch_mp.get_context('spawn')\n",
360 | " for model in models:\n",
361 | " model.share_memory()\n",
362 | " processes = []\n",
363 | "\n",
364 | " for i in range(len(models)):\n",
365 | " p = tmp.Process(target=train_slice, \n",
366 | " args=(models[i], dct_x[i, ...], dct_y_cat[i, ...]))\n",
367 | " p.start()\n",
368 | " processes.append(p)\n",
369 | " for p in processes: \n",
370 | " p.join()\n",
371 | " '''"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": 25,
377 | "metadata": {},
378 | "outputs": [],
379 | "source": [
380 | "def train_slice(model, x_i, y_i):\n",
381 | " criterion = nn.MSELoss()\n",
382 | " optimizer = optim.SGD(model.parameters(), lr=0.5, momentum=0.9, weight_decay=5e-4)\n",
383 | " outputs = model(x_i)\n",
384 | " # print(outputs.shape, y_i.shape)\n",
385 | " optimizer.zero_grad()\n",
386 | " loss = criterion(outputs, y_i)\n",
387 | " loss.backward()\n",
388 | " optimizer.step()"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": 56,
394 | "metadata": {},
395 | "outputs": [],
396 | "source": [
397 | "for batch_idx, (x, y) in enumerate(trainloader): \n",
398 | " device = 'cpu'\n",
399 | " x = torch_shift(x)\n",
400 | " dct_x = torch_apply(dct.dct, x.squeeze())\n",
401 | " y_cat = to_categorical(y, 10) \n",
402 | "\n",
403 | " dct_y_cat = torch.randn(28, dct_x.shape[2], 10) #y_cat.shape[0]\n",
404 | " for i in range(28):\n",
405 | " dct_y_cat[i, :, :] = y_cat\n",
406 | " dct_y_cat = torch_apply(dct.dct, dct_y_cat)\n",
407 | " dct_x.to(device)\n",
408 | " dct_y_cat.to(device)\n",
409 | " \n",
410 | " models = []\n",
411 | " dct_w, dct_b = make_weights(dct_x.shape, device=device)\n",
412 | " for i in range(28):\n",
413 | " model = Frontal_Slice(dct_w[i, ...], dct_b[i, ...])\n",
414 | " models.append(model.to(device))\n",
415 | "\n",
416 | " for i in range(len(models)):\n",
417 | " train_slice(models[i], dct_x[i, ...], dct_y_cat[i, ...])"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": null,
423 | "metadata": {},
424 | "outputs": [],
425 | "source": []
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": 61,
430 | "metadata": {},
431 | "outputs": [],
432 | "source": [
433 | "y = torch.eye(10)"
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": 74,
439 | "metadata": {},
440 | "outputs": [],
441 | "source": [
442 | "dct_yy = torch.empty(28, 10, 10)\n",
443 | "for i in range(28):\n",
444 | " dct_yy[i, ...] = y * 1\n",
445 | "dct_yy = torch_apply(dct.dct, dct_yy)"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": 75,
451 | "metadata": {},
452 | "outputs": [
453 | {
454 | "data": {
455 | "text/plain": [
456 | "tensor([[12.6239, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085,\n",
457 | " 1.7085, 1.7085],\n",
458 | " [ 1.7085, 12.6239, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085,\n",
459 | " 1.7085, 1.7085],\n",
460 | " [ 1.7085, 1.7085, 12.6239, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085,\n",
461 | " 1.7085, 1.7085],\n",
462 | " [ 1.7085, 1.7085, 1.7085, 12.6239, 1.7085, 1.7085, 1.7085, 1.7085,\n",
463 | " 1.7085, 1.7085],\n",
464 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 12.6239, 1.7085, 1.7085, 1.7085,\n",
465 | " 1.7085, 1.7085],\n",
466 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 12.6239, 1.7085, 1.7085,\n",
467 | " 1.7085, 1.7085],\n",
468 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 12.6239, 1.7085,\n",
469 | " 1.7085, 1.7085],\n",
470 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 12.6239,\n",
471 | " 1.7085, 1.7085],\n",
472 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085,\n",
473 | " 12.6239, 1.7085],\n",
474 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085,\n",
475 | " 1.7085, 12.6239]])"
476 | ]
477 | },
478 | "execution_count": 75,
479 | "metadata": {},
480 | "output_type": "execute_result"
481 | }
482 | ],
483 | "source": [
484 | "result = torch_apply(dct.idct, dct_yy)\n",
485 | "softmax = scalar_tubal_func(result)\n",
486 | "torch.transpose(softmax, 0, 1)"
487 | ]
488 | },
489 | {
490 | "cell_type": "code",
491 | "execution_count": null,
492 | "metadata": {},
493 | "outputs": [],
494 | "source": []
495 | }
496 | ],
497 | "metadata": {
498 | "kernelspec": {
499 | "display_name": "Python 3",
500 | "language": "python",
501 | "name": "python3"
502 | },
503 | "language_info": {
504 | "codemirror_mode": {
505 | "name": "ipython",
506 | "version": 3
507 | },
508 | "file_extension": ".py",
509 | "mimetype": "text/x-python",
510 | "name": "python",
511 | "nbconvert_exporter": "python",
512 | "pygments_lexer": "ipython3",
513 | "version": "3.7.6"
514 | }
515 | },
516 | "nbformat": 4,
517 | "nbformat_minor": 4
518 | }
519 |
--------------------------------------------------------------------------------
/examples/parallel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim  # provides optim.SGD used below
5 |
6 | import sys
7 | sys.path.append('../')
8 | from common import *
9 | from transform_based_network import *
10 |
11 | from torch.multiprocessing import Pool, Queue, Process, set_start_method
12 | import torch.multiprocessing as torch_mp
13 | import multiprocessing as mp
14 |
15 | if __name__ == '__main__':
16 | tmp = torch_mp.get_context('spawn')
17 |
18 | trainloader, testloader = load_mnist()
19 | model = Transform_Net(100)
20 | optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
21 |
22 | model.share_memory()
23 |
24 | processes = []
25 | num_cores = torch_mp.cpu_count()
26 | for i in range(num_cores):
27 | # q = Queue()
28 | p = tmp.Process(target=train_transform, args=(1, model, trainloader, testloader, optimizer))
29 | p.start()
30 | # print(q.get())
31 | processes.append(p)
32 |
33 | for p in processes:
34 | p.join()
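
For reference, `parallel.py` follows the Hogwild pattern: `share_memory()` moves the parameters into shared memory so every spawned worker updates the same storage lock-free. A self-contained sketch of the same pattern with a toy model and synthetic data (all names here are illustrative, not part of this repo):

```
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

def worker(model):
    # each worker builds its own optimizer over the *shared* parameters
    opt = optim.SGD(model.parameters(), lr=0.01)
    data, target = torch.randn(32, 10), torch.randn(32, 1)
    for _ in range(10):
        opt.zero_grad()
        nn.functional.mse_loss(model(data), target).backward()
        opt.step()  # lock-free update of the shared parameters (Hogwild)

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    model = nn.Linear(10, 1)
    model.share_memory()  # parameters stay shared across 'spawn' workers
    procs = [ctx.Process(target=worker, args=(model,)) for _ in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```

One difference from `parallel.py`: building the optimizer inside each worker keeps its momentum buffers process-local; a single optimizer created in the parent and passed to every worker, as `parallel.py` does, has its state copied per process rather than shared.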
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from .lenet import *
2 | from .vgg import *
--------------------------------------------------------------------------------
/nets/lenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 |
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | from torchvision import models
11 |
12 | class LeNet(nn.Module):
13 | def __init__(self):
14 | super(LeNet, self).__init__()
15 | self.features = nn.Sequential(
16 | nn.Conv2d(1, 32, 3, 1),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(32, 64, 3, 1),
19 | nn.ReLU(inplace=True),
20 | nn.MaxPool2d(kernel_size=2, stride=2),
21 | )
22 | self.classifier = nn.Sequential(
23 | nn.Dropout(0.25),
24 | nn.Linear(9216, 128),
25 | nn.ReLU(inplace=True),
26 | nn.Dropout(0.5),
27 | nn.Linear(128, 10),
28 | )
29 |
30 | def forward(self, x):
31 | x = self.features(x)
32 | x = torch.flatten(x, 1)
33 | x = self.classifier(x)
34 | output = F.log_softmax(x, dim=1)
35 | return output
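
The 9216 in the classifier follows from the feature shapes: a 28x28 MNIST image passes two unpadded 3x3 convolutions (28 -> 26 -> 24) and one 2x2 max-pool (24 -> 12), leaving 64 channels of 12x12, and 64 * 12 * 12 = 9216. A quick check (a sketch, assuming the repo root is on `sys.path`):

```
import torch
from nets import LeNet  # repo root assumed on sys.path

net = LeNet()
x = torch.randn(1, 1, 28, 28)         # one MNIST-sized image
feats = net.features(x)
print(feats.shape)                    # torch.Size([1, 64, 12, 12])
print(torch.flatten(feats, 1).shape)  # torch.Size([1, 9216])
```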
--------------------------------------------------------------------------------
/nets/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 |
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | from torchvision import models
11 |
12 | cfg = {
13 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
14 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
15 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
16 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
17 | }
18 |
19 | class VGG(nn.Module):
20 | def __init__(self, vgg_name):
21 | super(VGG, self).__init__()
22 | self.features = self._make_layers(cfg[vgg_name])
23 | self.classifier = nn.Sequential(
24 | nn.Dropout(0.25),
25 | nn.Linear(512, 128),
26 | nn.ReLU(inplace=True),
27 | nn.Dropout(0.25),
28 | nn.Linear(128, 10),
29 | )
30 |
31 | def forward(self, x):
32 | out = self.features(x)
33 | out = out.view(out.size(0), -1)
34 | out = self.classifier(out)
35 | return out
36 |
37 | def _make_layers(self, cfg):
38 | layers = []
39 | in_channels = 3
40 | for x in cfg:
41 | if x == 'M':
42 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
43 | else:
44 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
45 | nn.BatchNorm2d(x),
46 | nn.ReLU(inplace=True)]
47 | in_channels = x
48 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
49 | return nn.Sequential(*layers)
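
Usage sketch: `_make_layers` consumes one row of the `cfg` table, so switching depth is just a different key. The five 'M' entries halve a 32x32 CIFAR input down to 1x1, which is why the classifier starts at 512 features:

```
import torch
from nets import VGG  # repo root assumed on sys.path

net = VGG('VGG11')
x = torch.randn(2, 3, 32, 32)  # a CIFAR-sized batch
print(net(x).shape)            # torch.Size([2, 10])
print(sum(p.numel() for p in net.parameters()))  # total parameter count
```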
--------------------------------------------------------------------------------
/notebooks/eval.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import cProfile\n",
10 | "import time\n",
11 | "import pstats\n",
12 | "from main import *"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 3,
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "name": "stdout",
22 | "output_type": "stream",
23 | "text": [
24 | "==> Loading data..\n",
25 | "Files already downloaded and verified\n",
26 | "Files already downloaded and verified\n",
27 | "==> Building model..\n",
28 | "VGG(\n",
29 | " (features): Sequential(\n",
30 | " (0): Sequential(\n",
31 | " (0): Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
32 | " (1): Conv2d(1, 22, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
33 | " (2): Conv2d(22, 64, kernel_size=(1, 1), stride=(1, 1))\n",
34 | " )\n",
35 | " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
36 | " (2): ReLU(inplace=True)\n",
37 | " (3): Sequential(\n",
38 | " (0): Conv2d(64, 22, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
39 | " (1): Conv2d(22, 22, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
40 | " (2): Conv2d(22, 64, kernel_size=(1, 1), stride=(1, 1))\n",
41 | " )\n",
42 | " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
43 | " (5): ReLU(inplace=True)\n",
44 | " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
45 | " (7): Sequential(\n",
46 | " (0): Conv2d(64, 22, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
47 | " (1): Conv2d(22, 43, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
48 | " (2): Conv2d(43, 128, kernel_size=(1, 1), stride=(1, 1))\n",
49 | " )\n",
50 | " (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
51 | " (9): ReLU(inplace=True)\n",
52 | " (10): Sequential(\n",
53 | " (0): Conv2d(128, 43, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
54 | " (1): Conv2d(43, 43, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
55 | " (2): Conv2d(43, 128, kernel_size=(1, 1), stride=(1, 1))\n",
56 | " )\n",
57 | " (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
58 | " (12): ReLU(inplace=True)\n",
59 | " (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
60 | " (14): Sequential(\n",
61 | " (0): Conv2d(128, 43, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
62 | " (1): Conv2d(43, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
63 | " (2): Conv2d(86, 256, kernel_size=(1, 1), stride=(1, 1))\n",
64 | " )\n",
65 | " (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
66 | " (16): ReLU(inplace=True)\n",
67 | " (17): Sequential(\n",
68 | " (0): Conv2d(256, 86, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
69 | " (1): Conv2d(86, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
70 | " (2): Conv2d(86, 256, kernel_size=(1, 1), stride=(1, 1))\n",
71 | " )\n",
72 | " (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
73 | " (19): ReLU(inplace=True)\n",
74 | " (20): Sequential(\n",
75 | " (0): Conv2d(256, 86, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
76 | " (1): Conv2d(86, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
77 | " (2): Conv2d(86, 256, kernel_size=(1, 1), stride=(1, 1))\n",
78 | " )\n",
79 | " (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
80 | " (22): ReLU(inplace=True)\n",
81 | " (23): Sequential(\n",
82 | " (0): Conv2d(256, 86, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
83 | " (1): Conv2d(86, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
84 | " (2): Conv2d(86, 256, kernel_size=(1, 1), stride=(1, 1))\n",
85 | " )\n",
86 | " (24): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
87 | " (25): ReLU(inplace=True)\n",
88 | " (26): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
89 | " (27): Sequential(\n",
90 | " (0): Conv2d(256, 86, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
91 | " (1): Conv2d(86, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
92 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n",
93 | " )\n",
94 | " (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
95 | " (29): ReLU(inplace=True)\n",
96 | " (30): Sequential(\n",
97 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
98 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
99 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n",
100 | " )\n",
101 | " (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
102 | " (32): ReLU(inplace=True)\n",
103 | " (33): Sequential(\n",
104 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
105 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
106 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n",
107 | " )\n",
108 | " (34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
109 | " (35): ReLU(inplace=True)\n",
110 | " (36): Sequential(\n",
111 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
112 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
113 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n",
114 | " )\n",
115 | " (37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
116 | " (38): ReLU(inplace=True)\n",
117 | " (39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
118 | " (40): Sequential(\n",
119 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
120 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
121 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n",
122 | " )\n",
123 | " (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
124 | " (42): ReLU(inplace=True)\n",
125 | " (43): Sequential(\n",
126 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
127 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
128 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n",
129 | " )\n",
130 | " (44): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
131 | " (45): ReLU(inplace=True)\n",
132 | " (46): Sequential(\n",
133 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
134 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
135 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n",
136 | " )\n",
137 | " (47): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
138 | " (48): ReLU(inplace=True)\n",
139 | " (49): Sequential(\n",
140 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
141 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
142 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n",
143 | " )\n",
144 | " (50): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
145 | " (51): ReLU(inplace=True)\n",
146 | " (52): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
147 | " (53): AvgPool2d(kernel_size=1, stride=1, padding=0)\n",
148 | " )\n",
149 | " (classifier): Linear(in_features=512, out_features=10, bias=True)\n",
150 | ")\n",
151 | "==> Done\n",
152 | "\n",
153 | "Epoch: 0\n",
154 | "|========================================| Accuracy: 10.832 % 5416 / 50000\n",
155 | "|==========| Accuracy: 12.12 % 1212 / 10000\n",
156 | "This epoch took 34.11177968978882 seconds\n",
157 | "Current learning rate: 0.05\n",
158 | "\n",
159 | "Epoch: 1\n",
160 | "|========================================| Accuracy: 11.602 % 5801 / 50000\n",
161 | "|==========| Accuracy: 11.53 % 1153 / 10000\n",
162 | "This epoch took 32.89801287651062 seconds\n",
163 | "Current learning rate: 0.05\n",
164 | "\n",
165 | "Epoch: 2\n",
166 | "|========================================| Accuracy: 11.694 % 5847 / 50000\n",
167 | "|==========| Accuracy: 11.77 % 1177 / 10000\n",
168 | "This epoch took 32.95981407165527 seconds\n",
169 | "Current learning rate: 0.05\n",
170 | "\n",
171 | "Epoch: 3\n",
172 | "|========================================| Accuracy: 12.656 % 6328 / 50000\n",
173 | "|==========| Accuracy: 13.45 % 1345 / 10000\n",
174 | "This epoch took 33.00275278091431 seconds\n",
175 | "Current learning rate: 0.05\n",
176 | "\n",
177 | "Epoch: 4\n",
178 | "|========================================| Accuracy: 17.582 % 8791 / 50000\n",
179 | "|==========| Accuracy: 17.5 % 1750 / 10000\n",
180 | "This epoch took 32.84714198112488 seconds\n",
181 | "Current learning rate: 0.04050000000000001\n",
182 | "\n",
183 | "Epoch: 5\n",
184 | "|========================================| Accuracy: 20.526 % 10263 / 50000\n",
185 | "|==========| Accuracy: 21.94 % 2194 / 10000\n",
186 | "This epoch took 32.7773003578186 seconds\n",
187 | "Current learning rate: 0.045000000000000005\n",
188 | "\n",
189 | "Epoch: 6\n",
190 | "|========================================| Accuracy: 23.64 % 11820 / 50000\n",
191 | "|==========| Accuracy: 27.2 % 2720 / 10000\n",
192 | "This epoch took 33.12140941619873 seconds\n",
193 | "Current learning rate: 0.045000000000000005\n",
194 | "\n",
195 | "Epoch: 7\n"
196 | ]
197 | },
198 | {
199 | "name": "stdout",
200 | "output_type": "stream",
201 | "text": [
202 | "|========================================| Accuracy: 25.932 % 12966 / 50000\n",
203 | "|==========| Accuracy: 26.37 % 2637 / 10000\n",
204 | "This epoch took 33.11143517494202 seconds\n",
205 | "Current learning rate: 0.045000000000000005\n",
206 | "\n",
207 | "Epoch: 8\n",
208 | "|========================================| Accuracy: 28.86 % 14430 / 50000\n",
209 | "|==========| Accuracy: 24.97 % 2497 / 10000\n",
210 | "This epoch took 35.06288194656372 seconds\n",
211 | "Current learning rate: 0.045000000000000005\n",
212 | "\n",
213 | "Epoch: 9\n",
214 | "|========================================| Accuracy: 31.124 % 15562 / 50000\n",
215 | "|==========| Accuracy: 30.32 % 3032 / 10000\n",
216 | "This epoch took 35.821218967437744 seconds\n",
217 | "Current learning rate: 0.03645000000000001\n",
218 | "\n",
219 | "Epoch: 10\n",
220 | "|========================================| Accuracy: 32.398 % 16199 / 50000\n",
221 | "|==========| Accuracy: 15.31 % 1531 / 10000\n",
222 | "This epoch took 37.601341009140015 seconds\n",
223 | "Current learning rate: 0.04050000000000001\n",
224 | "\n",
225 | "Epoch: 11\n",
226 | "|========================================| Accuracy: 32.892 % 16446 / 50000\n",
227 | "|==========| Accuracy: 28.75 % 2875 / 10000\n",
228 | "This epoch took 43.43238830566406 seconds\n",
229 | "Current learning rate: 0.04050000000000001\n",
230 | "\n",
231 | "Epoch: 12\n",
232 | "|========================================| Accuracy: 33.482 % 16741 / 50000\n",
233 | "|==========| Accuracy: 20.39 % 2039 / 10000\n",
234 | "This epoch took 42.75646662712097 seconds\n",
235 | "Current learning rate: 0.04050000000000001\n",
236 | "\n",
237 | "Epoch: 13\n",
238 | "|========================================| Accuracy: 33.716 % 16858 / 50000\n",
239 | "|==========| Accuracy: 19.55 % 1955 / 10000\n",
240 | "This epoch took 40.91100358963013 seconds\n",
241 | "Current learning rate: 0.04050000000000001\n",
242 | "\n",
243 | "Epoch: 14\n",
244 | "|========================================| Accuracy: 33.624 % 16812 / 50000\n",
245 | "|==========| Accuracy: 30.54 % 3054 / 10000\n",
246 | "This epoch took 40.79039931297302 seconds\n",
247 | "Current learning rate: 0.03280500000000001\n",
248 | "\n",
249 | "Epoch: 15\n",
250 | "|========================================| Accuracy: 33.68 % 16840 / 50000\n",
251 | "|==========| Accuracy: 31.88 % 3188 / 10000\n",
252 | "This epoch took 39.57789707183838 seconds\n",
253 | "Current learning rate: 0.03645000000000001\n",
254 | "\n",
255 | "Epoch: 16\n",
256 | "|========================================| Accuracy: 33.614 % 16807 / 50000\n",
257 | "|==========| Accuracy: 29.32 % 2932 / 10000\n",
258 | "This epoch took 41.83706545829773 seconds\n",
259 | "Current learning rate: 0.03645000000000001\n",
260 | "\n",
261 | "Epoch: 17\n",
262 | "|========================================| Accuracy: 36.988 % 18494 / 50000\n",
263 | "|==========| Accuracy: 39.15 % 3915 / 10000\n",
264 | "This epoch took 40.11329388618469 seconds\n",
265 | "Current learning rate: 0.03645000000000001\n",
266 | "\n",
267 | "Epoch: 18\n",
268 | "|========================================| Accuracy: 42.32 % 21160 / 50000\n",
269 | "|==========| Accuracy: 42.6 % 4260 / 10000\n",
270 | "This epoch took 43.18025541305542 seconds\n",
271 | "Current learning rate: 0.03645000000000001\n",
272 | "\n",
273 | "Epoch: 19\n",
274 | "|========================================| Accuracy: 44.63 % 22315 / 50000\n",
275 | "|==========| Accuracy: 36.87 % 3687 / 10000\n",
276 | "This epoch took 41.14051365852356 seconds\n",
277 | "Current learning rate: 0.02952450000000001\n",
278 | "\n",
279 | "Epoch: 20\n",
280 | "|========================================| Accuracy: 47.01 % 23505 / 50000\n",
281 | "|==========| Accuracy: 39.65 % 3965 / 10000\n",
282 | "This epoch took 42.76271724700928 seconds\n",
283 | "Current learning rate: 0.03280500000000001\n",
284 | "\n",
285 | "Epoch: 21\n",
286 | "|========================================| Accuracy: 47.842 % 23921 / 50000\n",
287 | "|==========| Accuracy: 42.14 % 4214 / 10000\n",
288 | "This epoch took 43.29025745391846 seconds\n",
289 | "Current learning rate: 0.03280500000000001\n",
290 | "\n",
291 | "Epoch: 22\n",
292 | "|========================================| Accuracy: 49.414 % 24707 / 50000\n",
293 | "|==========| Accuracy: 43.76 % 4376 / 10000\n",
294 | "This epoch took 45.87770104408264 seconds\n",
295 | "Current learning rate: 0.03280500000000001\n",
296 | "\n",
297 | "Epoch: 23\n",
298 | "|========================================| Accuracy: 50.198 % 25099 / 50000\n",
299 | "|==========| Accuracy: 50.9 % 5090 / 10000\n",
300 | "This epoch took 43.95064091682434 seconds\n",
301 | "Current learning rate: 0.03280500000000001\n",
302 | "\n",
303 | "Epoch: 24\n",
304 | "|========================================| Accuracy: 50.882 % 25441 / 50000\n",
305 | "|==========| Accuracy: 39.98 % 3998 / 10000\n",
306 | "This epoch took 44.10254693031311 seconds\n",
307 | "Current learning rate: 0.02657205000000001\n",
308 | "\n",
309 | "Epoch: 25\n",
310 | "|========================================| Accuracy: 52.896 % 26448 / 50000\n",
311 | "|==========| Accuracy: 40.97 % 4097 / 10000\n",
312 | "This epoch took 44.637322187423706 seconds\n",
313 | "Current learning rate: 0.02952450000000001\n",
314 | "\n",
315 | "Epoch: 26\n",
316 | "|========================================| Accuracy: 53.682 % 26841 / 50000\n",
317 | "|==========| Accuracy: 48.24 % 4824 / 10000\n",
318 | "This epoch took 43.39285159111023 seconds\n",
319 | "Current learning rate: 0.02952450000000001\n",
320 | "\n",
321 | "Epoch: 27\n",
322 | "|========================================| Accuracy: 54.718 % 27359 / 50000\n",
323 | "|==========| Accuracy: 53.49 % 5349 / 10000\n",
324 | "This epoch took 41.88004779815674 seconds\n",
325 | "Current learning rate: 0.02952450000000001\n",
326 | "\n",
327 | "Epoch: 28\n",
328 | "|========================================| Accuracy: 54.854 % 27427 / 50000\n",
329 | "|==========| Accuracy: 49.18 % 4918 / 10000\n",
330 | "This epoch took 45.3488028049469 seconds\n",
331 | "Current learning rate: 0.02952450000000001\n",
332 | "\n",
333 | "Epoch: 29\n",
334 | "|========================================| Accuracy: 55.854 % 27927 / 50000\n",
335 | "|==========| Accuracy: 45.03 % 4503 / 10000\n",
336 | "This epoch took 42.54808688163757 seconds\n",
337 | "Current learning rate: 0.02391484500000001\n",
338 | "\n",
339 | "Epoch: 30\n",
340 | "|========================================| Accuracy: 58.202 % 29101 / 50000\n",
341 | "|==========| Accuracy: 50.08 % 5008 / 10000\n",
342 | "This epoch took 46.533724308013916 seconds\n",
343 | "Current learning rate: 0.02657205000000001\n",
344 | "\n",
345 | "Epoch: 31\n",
346 | "|========================================| Accuracy: 59.262 % 29631 / 50000\n",
347 | "|==========| Accuracy: 50.52 % 5052 / 10000\n",
348 | "This epoch took 46.0679144859314 seconds\n",
349 | "Current learning rate: 0.02657205000000001\n",
350 | "\n",
351 | "Epoch: 32\n",
352 | "|========================================| Accuracy: 60.114 % 30057 / 50000\n",
353 | "|==========| Accuracy: 52.74 % 5274 / 10000\n",
354 | "This epoch took 44.88120102882385 seconds\n",
355 | "Current learning rate: 0.02657205000000001\n",
356 | "\n",
357 | "Epoch: 33\n",
358 | "|========================================| Accuracy: 60.54 % 30270 / 50000\n",
359 | "|==========| Accuracy: 55.82 % 5582 / 10000\n",
360 | "This epoch took 38.706753730773926 seconds\n",
361 | "Current learning rate: 0.02657205000000001\n",
362 | "\n",
363 | "Epoch: 34\n",
364 | "|========================================| Accuracy: 61.202 % 30601 / 50000\n",
365 | "|==========| Accuracy: 53.62 % 5362 / 10000\n",
366 | "This epoch took 40.96643614768982 seconds\n",
367 | "Current learning rate: 0.021523360500000012\n",
368 | "\n",
369 | "Epoch: 35\n",
370 | "|========================================| Accuracy: 63.02 % 31510 / 50000\n",
371 | "|==========| Accuracy: 51.62 % 5162 / 10000\n",
372 | "This epoch took 42.569278717041016 seconds\n",
373 | "Current learning rate: 0.02391484500000001\n",
374 | "\n",
375 | "Epoch: 36\n",
376 | "|========================================| Accuracy: 63.206 % 31603 / 50000\n",
377 | "|==========| Accuracy: 63.72 % 6372 / 10000\n",
378 | "This epoch took 42.79570960998535 seconds\n",
379 | "Current learning rate: 0.02391484500000001\n",
380 | "\n",
381 | "Epoch: 37\n",
382 | "|========================================| Accuracy: 63.456 % 31728 / 50000\n",
383 | "|==========| Accuracy: 52.97 % 5297 / 10000\n",
384 | "This epoch took 39.450865030288696 seconds\n",
385 | "Current learning rate: 0.02391484500000001\n",
386 | "\n",
387 | "Epoch: 38\n",
388 | "|========================================| Accuracy: 64.164 % 32082 / 50000\n",
389 | "|==========| Accuracy: 63.85 % 6385 / 10000\n",
390 | "This epoch took 40.91485786437988 seconds\n",
391 | "Current learning rate: 0.02391484500000001\n",
392 | "\n",
393 | "Epoch: 39\n",
394 | "|========================================| Accuracy: 64.876 % 32438 / 50000\n",
395 | "|==========| Accuracy: 62.66 % 6266 / 10000\n",
396 | "This epoch took 40.16771578788757 seconds\n",
397 | "Current learning rate: 0.01937102445000001\n",
398 | "\n",
399 | "Epoch: 40\n",
400 | "|========================================| Accuracy: 66.234 % 33117 / 50000\n",
401 | "|==========| Accuracy: 63.21 % 6321 / 10000\n",
402 | "This epoch took 40.66132354736328 seconds\n",
403 | "Current learning rate: 0.021523360500000012\n",
404 | "\n",
405 | "Epoch: 41\n",
406 | "|========================================| Accuracy: 66.352 % 33176 / 50000\n",
407 | "|==========| Accuracy: 62.34 % 6234 / 10000\n",
408 | "This epoch took 46.38554286956787 seconds\n",
409 | "Current learning rate: 0.021523360500000012\n",
410 | "\n",
411 | "Epoch: 42\n",
412 | "|========================================| Accuracy: 66.85 % 33425 / 50000\n",
413 | "|==========| Accuracy: 60.7 % 6070 / 10000\n",
414 | "This epoch took 42.69945788383484 seconds\n",
415 | "Current learning rate: 0.021523360500000012\n",
416 | "\n",
417 | "Epoch: 43\n",
418 | "|========================================| Accuracy: 67.504 % 33752 / 50000\n",
419 | "|==========| Accuracy: 59.91 % 5991 / 10000\n",
420 | "This epoch took 43.23206090927124 seconds\n",
421 | "Current learning rate: 0.021523360500000012\n",
422 | "\n",
423 | "Epoch: 44\n",
424 | "|========================================| Accuracy: 67.722 % 33861 / 50000\n"
425 | ]
426 | },
427 | {
428 | "name": "stdout",
429 | "output_type": "stream",
430 | "text": [
431 | "|==========| Accuracy: 58.57 % 5857 / 10000\n",
432 | "This epoch took 41.02703857421875 seconds\n",
433 | "Current learning rate: 0.01743392200500001\n",
434 | "\n",
435 | "Epoch: 45\n",
436 | "|========================================| Accuracy: 68.858 % 34429 / 50000\n",
437 | "|==========| Accuracy: 61.41 % 6141 / 10000\n",
438 | "This epoch took 41.44154691696167 seconds\n",
439 | "Current learning rate: 0.01937102445000001\n",
440 | "\n",
441 | "Epoch: 46\n",
442 | "|========================================| Accuracy: 69.02 % 34510 / 50000\n",
443 | "|==========| Accuracy: 63.16 % 6316 / 10000\n",
444 | "This epoch took 43.859628200531006 seconds\n",
445 | "Current learning rate: 0.01937102445000001\n",
446 | "\n",
447 | "Epoch: 47\n",
448 | "|========================================| Accuracy: 69.556 % 34778 / 50000\n",
449 | "|==========| Accuracy: 63.7 % 6370 / 10000\n",
450 | "This epoch took 41.11183762550354 seconds\n",
451 | "Current learning rate: 0.01937102445000001\n",
452 | "\n",
453 | "Epoch: 48\n",
454 | "|========================================| Accuracy: 69.896 % 34948 / 50000\n",
455 | "|==========| Accuracy: 60.98 % 6098 / 10000\n",
456 | "This epoch took 42.747894048690796 seconds\n",
457 | "Current learning rate: 0.01937102445000001\n",
458 | "\n",
459 | "Epoch: 49\n",
460 | "|========================================| Accuracy: 70.256 % 35128 / 50000\n",
461 | "|==========| Accuracy: 63.98 % 6398 / 10000\n",
462 | "This epoch took 41.37175130844116 seconds\n",
463 | "Current learning rate: 0.015690529804500006\n",
464 | "Best training accuracy overall: 63.98\n",
465 | "|==========| Accuracy: 63.98 % 6398 / 10000\n",
466 | "Testing accuracy: 63.98\n",
467 | "Mon Jun 8 21:33:18 2020 stats\n",
468 | "\n",
469 | " 510155235 function calls (507136221 primitive calls) in 2053.806 seconds\n",
470 | "\n",
471 | " Ordered by: internal time\n",
472 | " List reduced from 824 to 20 due to restriction <20>\n",
473 | "\n",
474 | " ncalls tottime percall cumtime percall filename:lineno(function)\n",
475 | " 49401 743.334 0.015 743.334 0.015 {method 'item' of 'torch._C._TensorBase' objects}\n",
476 | " 19550 222.799 0.011 222.799 0.011 {method 'run_backward' of 'torch._C._EngineBase' objects}\n",
477 | " 5747602 107.058 0.000 107.058 0.000 {method 'add_' of 'torch._C._TensorBase' objects}\n",
478 | " 1183200 102.777 0.000 102.777 0.000 {built-in method conv2d}\n",
479 | " 3010000 59.049 0.000 140.632 0.000 functional.py:192(normalize)\n",
480 | " 3010000 45.654 0.000 205.622 0.000 functional.py:43(to_tensor)\n",
481 | " 1915802 45.251 0.000 45.251 0.000 {method 'mul_' of 'torch._C._TensorBase' objects}\n",
482 | " 394400 41.219 0.000 41.219 0.000 {built-in method batch_norm}\n",
483 | " 3010000 40.894 0.000 40.894 0.000 {method 'tobytes' of 'numpy.ndarray' objects}\n",
484 | " 1915850 39.603 0.000 39.603 0.000 {method 'zero_' of 'torch._C._TensorBase' objects}\n",
485 | " 3010000 32.589 0.000 32.589 0.000 {method 'div' of 'torch._C._TensorBase' objects}\n",
486 | " 3010000 25.541 0.000 25.541 0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}\n",
487 | " 6020000 25.312 0.000 25.312 0.000 {built-in method as_tensor}\n",
488 | " 3010000 24.033 0.000 24.033 0.000 {method 'sub_' of 'torch._C._TensorBase' objects}\n",
489 | " 3010000 20.338 0.000 116.898 0.000 Image.py:2644(fromarray)\n",
490 | " 6020000 19.078 0.000 19.078 0.000 {method 'transpose' of 'torch._C._TensorBase' objects}\n",
491 | " 3034650 18.921 0.000 18.921 0.000 {method 'view' of 'torch._C._TensorBase' objects}\n",
492 | " 3010000 18.467 0.000 18.467 0.000 {method 'float' of 'torch._C._TensorBase' objects}\n",
493 | " 3010000 15.967 0.000 15.967 0.000 {method 'clone' of 'torch._C._TensorBase' objects}\n",
494 | " 19550 15.331 0.001 168.502 0.009 sgd.py:71(step)\n",
495 | "\n",
496 | "\n",
497 | "Wall time: 34min 13s\n"
498 | ]
499 | }
500 | ],
501 | "source": [
502 | "%%time \n",
503 | "\n",
504 | "cProfile.run('eval()', 'stats')\n",
505 | "p = pstats.Stats('stats')\n",
506 | "p.strip_dirs().sort_stats(1).print_stats(20)\n",
507 | "p.dump_stats('prof\\\\main.prof')"
508 | ]
509 | },
510 | {
511 | "cell_type": "code",
512 | "execution_count": null,
513 | "metadata": {},
514 | "outputs": [],
515 | "source": []
516 | }
517 | ],
518 | "metadata": {
519 | "kernelspec": {
520 | "display_name": "Python 3",
521 | "language": "python",
522 | "name": "python3"
523 | },
524 | "language_info": {
525 | "codemirror_mode": {
526 | "name": "ipython",
527 | "version": 3
528 | },
529 | "file_extension": ".py",
530 | "mimetype": "text/x-python",
531 | "name": "python",
532 | "nbconvert_exporter": "python",
533 | "pygments_lexer": "ipython3",
534 | "version": "3.7.6"
535 | }
536 | },
537 | "nbformat": 4,
538 | "nbformat_minor": 4
539 | }
540 |
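
The dominant line in the profile above is telling: roughly 743 of 2054 seconds go to `Tensor.item()`, which forces a CPU-GPU synchronization on every call, here once per batch for the loss and accuracy bookkeeping. A common mitigation (a sketch, not the repo's code) is to accumulate counters as tensors and synchronize once per epoch:

```
import torch

# toy stand-ins for one batch; in training these live on the GPU
outputs = torch.randn(128, 10)
targets = torch.randint(0, 10, (128,))

correct = torch.zeros((), dtype=torch.long)  # a device tensor in real use
total = 0
_, predicted = outputs.max(1)
correct += predicted.eq(targets).sum()  # stays a tensor: no sync per batch
total += targets.size(0)

# a single synchronization at the end of the epoch
print('Accuracy:', 100. * correct.item() / total)
```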
--------------------------------------------------------------------------------
/notebooks/main.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# main function for decomposition\n",
8 | "### Author: Yiming Fang"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "import torch\n",
18 | "import torch.nn as nn\n",
19 | "import torch.optim as optim\n",
20 | "import torch.nn.functional as F\n",
21 | "import torch.backends.cudnn as cudnn\n",
22 | "from torch.optim.lr_scheduler import StepLR\n",
23 | "\n",
24 | "import torchvision\n",
25 | "import torchvision.transforms as transforms\n",
26 | "from torchvision import models\n",
27 | "\n",
28 | "import tensorly as tl\n",
29 | "import tensorly\n",
30 | "from itertools import chain\n",
31 | "from tensorly.decomposition import parafac, partial_tucker\n",
32 | "\n",
33 | "import os\n",
34 | "import matplotlib.pyplot as plt\n",
35 | "import numpy as np\n",
36 | "import time\n",
37 | "\n",
38 | "from nets import *\n",
39 | "from decomp import *"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 2,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "# load data\n",
49 | "def load_mnist():\n",
50 | " print('==> Loading data..')\n",
51 | " transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n",
52 | " transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n",
53 | "\n",
54 | " trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)\n",
55 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)\n",
56 | "\n",
57 | " testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)\n",
58 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)\n",
59 | " return trainloader, testloader\n",
60 | "\n",
61 | "def load_cifar10():\n",
62 | " print('==> Loading data..')\n",
63 | " transform_train = transforms.Compose([\n",
64 | " transforms.RandomCrop(32, padding=4),\n",
65 | " transforms.RandomHorizontalFlip(),\n",
66 | " transforms.ToTensor(),\n",
67 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
68 | " ])\n",
69 | "\n",
70 | " transform_test = transforms.Compose([\n",
71 | " transforms.ToTensor(),\n",
72 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
73 | " ])\n",
74 | "\n",
75 | " trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n",
76 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)\n",
77 | "\n",
78 | " testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n",
79 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)\n",
80 | " \n",
81 | " return trainloader, testloader\n",
82 | "\n",
83 | "# ImageNet is no longer publically available\n",
84 | "def load_imagenet():\n",
85 | " print('==> Loading data..')\n",
86 | " transform_train = transforms.Compose([\n",
87 | " transforms.RandomResizedCrop(224),\n",
88 | " transforms.RandomHorizontalFlip(),\n",
89 | " transforms.ToTensor(),\n",
90 | " transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n",
91 | " ])\n",
92 | "\n",
93 | " transform_test = transforms.Compose([\n",
94 | " transforms.Resize(256),\n",
95 | " transforms.CenterCrop(224),\n",
96 | " transforms.ToTensor(),\n",
97 | " transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n",
98 | " ])\n",
99 | "\n",
100 | " trainset = torchvision.datasets.ImageNet(root='./data', train=True, download=True, transform=transform_train)\n",
101 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)\n",
102 | "\n",
103 | " testset = torchvision.datasets.ImageNet(root='./data', train=False, download=True, transform=transform_test)\n",
104 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)\n",
105 | " \n",
106 | " return trainloader, testloader\n",
107 | "\n",
108 | "def load_cifar100():\n",
109 | " print('==> Loading data..')\n",
110 | " transform_train = transforms.Compose([\n",
111 | " transforms.RandomCrop(32, padding=4),\n",
112 | " transforms.RandomHorizontalFlip(),\n",
113 | " transforms.ToTensor(),\n",
114 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
115 | " ])\n",
116 | "\n",
117 | " transform_test = transforms.Compose([\n",
118 | " transforms.ToTensor(),\n",
119 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
120 | " ])\n",
121 | "\n",
122 | " trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)\n",
123 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)\n",
124 | "\n",
125 | " testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)\n",
126 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)\n",
127 | " \n",
128 | " return trainloader, testloader"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 3,
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "# build model\n",
138 | "def build(model, decomp='cp'):\n",
139 | " print('==> Building model..')\n",
140 | " tl.set_backend('pytorch')\n",
141 | " full_net = model\n",
142 | " full_net = full_net.to(device)\n",
143 | " torch.save(full_net, 'models/model')\n",
144 | " if decomp:\n",
145 | " decompose(decomp)\n",
146 | " net = torch.load(\"models/model\").cuda()\n",
147 | " print(net)\n",
148 | " print('==> Done')\n",
149 | " return net\n",
150 | " \n",
151 | "# training\n",
152 | "def train(epoch, train_acc, model):\n",
153 | " print('\\nEpoch: ', epoch)\n",
154 | " model.train()\n",
155 | " criterion = nn.CrossEntropyLoss()\n",
156 | " train_loss = 0\n",
157 | " correct = 0\n",
158 | " total = 0\n",
159 | " print('|', end='')\n",
160 | " for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
161 | " inputs, targets = inputs.to(device), targets.to(device)\n",
162 | " optimizer.zero_grad()\n",
163 | " outputs = model(inputs)\n",
164 | " loss = criterion(outputs, targets)\n",
165 | " loss.backward()\n",
166 | " optimizer.step()\n",
167 | " train_loss += loss.item()\n",
168 | " _, predicted = outputs.max(1)\n",
169 | " total += targets.size(0)\n",
170 | " correct += predicted.eq(targets).sum().item()\n",
171 | " if batch_idx % 10 == 0:\n",
172 | " print('=', end='')\n",
173 | " print('|', 'Accuracy:', 100. * correct / total,'% ', correct, '/', total)\n",
174 | " train_acc.append(correct / total)\n",
175 | " return train_acc\n",
176 | "\n",
177 | "# testing\n",
178 | "def test(test_acc, model):\n",
179 | " model.eval()\n",
180 | " test_loss = 0\n",
181 | " correct = 0\n",
182 | " total = 0\n",
183 | " criterion = nn.CrossEntropyLoss()\n",
184 | " with torch.no_grad():\n",
185 | " print('|', end='')\n",
186 | " for batch_idx, (inputs, targets) in enumerate(testloader):\n",
187 | " inputs, targets = inputs.to(device), targets.to(device)\n",
188 | " outputs = model(inputs)\n",
189 | " loss = criterion(outputs, targets)\n",
190 | " test_loss += loss.item()\n",
191 | " _, predicted = outputs.max(1)\n",
192 | " total += targets.size(0)\n",
193 | " correct += predicted.eq(targets).sum().item()\n",
194 | " if batch_idx % 10 == 0:\n",
195 | " print('=', end='')\n",
196 | " acc = 100. * correct / total\n",
197 | " print('|', 'Accuracy:', acc, '% ', correct, '/', total)\n",
198 | " test_acc.append(correct / total) \n",
199 | " return test_acc\n",
200 | "\n",
201 | "# decompose\n",
202 | "def decompose(decomp):\n",
203 | " model = torch.load(\"models/model\").cuda()\n",
204 | " model.eval()\n",
205 | " model.cpu()\n",
206 | " for i, key in enumerate(model.features._modules.keys()):\n",
207 | " if i >= len(model.features._modules.keys()) - 2:\n",
208 | " break\n",
209 | " conv_layer = model.features._modules[key]\n",
210 | " if isinstance(conv_layer, torch.nn.modules.conv.Conv2d):\n",
211 | " rank = max(conv_layer.weight.data.numpy().shape) // 10\n",
212 | " if decomp == 'cp':\n",
213 | " model.features._modules[key] = cp_decomposition_conv_layer(conv_layer, rank)\n",
214 | " if decomp == 'tucker': \n",
215 | " ranks = [int(np.ceil(conv_layer.weight.data.numpy().shape[0] / 3)), \n",
216 | " int(np.ceil(conv_layer.weight.data.numpy().shape[1] / 3))]\n",
217 | " model.features._modules[key] = tucker_decomposition_conv_layer(conv_layer, ranks)\n",
218 | " if decomp == 'tt':\n",
219 | " model.features._modules[key] = tt_decomposition_conv_layer(conv_layer, rank)\n",
220 | " torch.save(model, 'models/model')\n",
221 | " return model\n",
222 | "\n",
223 | "# Run functions\n",
224 | "def run_train(i, model):\n",
225 | " train_acc = []\n",
226 | " test_acc = []\n",
227 | " for epoch in range(i):\n",
228 | " s = time.time()\n",
229 | " train_acc = train(epoch, train_acc, model)\n",
230 | " test_acc = test(test_acc, model)\n",
231 | " scheduler.step()\n",
232 | " e = time.time()\n",
233 | " print('This epoch took', e - s, 'seconds')\n",
234 | " print('Current learning rate: ', scheduler.get_lr()[0])\n",
235 | " print('Best training accuracy overall: ', max(test_acc))\n",
236 | " return train_acc, test_acc"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 4,
242 | "metadata": {},
243 | "outputs": [],
244 | "source": [
245 | "# main function\n",
246 | "def run_all(dataset, decomp=None, iterations=100, rate=0.05): \n",
247 | " global trainloader, testloader, device, optimizer, scheduler\n",
248 | " \n",
249 | " # choose an appropriate learning rate\n",
250 | " rate = rate\n",
251 | " \n",
252 | " # choose dataset from (MNIST, CIFAR10, ImageNet)\n",
253 | " if dataset == 'mnist':\n",
254 | " trainloader, testloader = load_mnist()\n",
255 | " model = Net()\n",
256 | " if dataset == 'cifar10':\n",
257 | " trainloader, testloader = load_cifar10()\n",
258 | " model = VGG('VGG19')\n",
259 | " if dataset == 'cifar100':\n",
260 | " trainloader, testloader = load_cifar100()\n",
261 | " model = VGG('VGG19')\n",
262 | " \n",
263 | " # check GPU availability\n",
264 | " device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
265 | " \n",
266 | " # choose decomposition algorithm from (CP, Tucker, TT)\n",
267 | " net = build(model, decomp)\n",
268 | " optimizer = optim.SGD(net.parameters(), lr=rate, momentum=0.9, weight_decay=5e-4)\n",
269 | " scheduler = StepLR(optimizer, step_size=5, gamma=0.9)\n",
270 | " train_acc, test_acc = run_train(iterations, net)\n",
271 | " \n",
272 | " if not decomp:\n",
273 | " decomp = 'full'\n",
274 | " \n",
275 | " filename = dataset + '_' + decomp\n",
276 | " torch.save(net, 'models/' + filename)\n",
277 | " np.save('curves/' + filename + '_train', train_acc)\n",
278 | " np.save('curves/' + filename + '_test', test_acc)"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": null,
284 | "metadata": {},
285 | "outputs": [
286 | {
287 | "name": "stdout",
288 | "output_type": "stream",
289 | "text": [
290 | "==> Loading data..\n",
291 | "==> Building model..\n",
292 | "Net(\n",
293 | " (features): Sequential(\n",
294 | " (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
295 | " (1): ReLU(inplace=True)\n",
296 | " (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
297 | " (3): ReLU(inplace=True)\n",
298 | " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
299 | " )\n",
300 | " (classifier): Sequential(\n",
301 | " (0): Dropout(p=0.25, inplace=False)\n",
302 | " (1): Linear(in_features=9216, out_features=128, bias=True)\n",
303 | " (2): ReLU(inplace=True)\n",
304 | " (3): Dropout(p=0.5, inplace=False)\n",
305 | " (4): Linear(in_features=128, out_features=10, bias=True)\n",
306 | " )\n",
307 | ")\n",
308 | "==> Done\n",
309 | "\n",
310 | "Epoch: 0\n",
311 | "|===============================================| Accuracy: 91.945 % 55167 / 60000\n",
312 | "|==========| Accuracy: 98.28 % 9828 / 10000\n",
313 | "This epoch took 13.191731214523315 seconds\n",
314 | "Current learning rate: 0.05\n",
315 | "\n",
316 | "Epoch: 1\n",
317 | "|===============================================| Accuracy: 97.02166666666666 % 58213 / 60000\n",
318 | "|==========| Accuracy: 98.61 % 9861 / 10000\n",
319 | "This epoch took 12.04975700378418 seconds\n",
320 | "Current learning rate: 0.05\n",
321 | "\n",
322 | "Epoch: 2\n",
323 | "|===============================================| Accuracy: 97.565 % 58539 / 60000\n",
324 | "|==========| Accuracy: 98.8 % 9880 / 10000\n",
325 | "This epoch took 12.05876111984253 seconds\n",
326 | "Current learning rate: 0.05\n",
327 | "\n",
328 | "Epoch: 3\n",
329 | "|===============================================| Accuracy: 97.915 % 58749 / 60000\n",
330 | "|==========| Accuracy: 98.9 % 9890 / 10000\n",
331 | "This epoch took 12.11062240600586 seconds\n",
332 | "Current learning rate: 0.05\n",
333 | "\n",
334 | "Epoch: 4\n",
335 | "|===============================================| Accuracy: 97.955 % 58773 / 60000\n",
336 | "|==========| Accuracy: 98.95 % 9895 / 10000\n",
337 | "This epoch took 12.165474891662598 seconds\n",
338 | "Current learning rate: 0.04050000000000001\n",
339 | "\n",
340 | "Epoch: 5\n",
341 | "|===============================================| Accuracy: 98.34833333333333 % 59009 / 60000\n",
342 | "|==========| Accuracy: 98.95 % 9895 / 10000\n",
343 | "This epoch took 12.157521963119507 seconds\n",
344 | "Current learning rate: 0.045000000000000005\n",
345 | "\n",
346 | "Epoch: 6\n",
347 | "|===============================================| Accuracy: 98.545 % 59127 / 60000\n",
348 | "|==========| Accuracy: 99.0 % 9900 / 10000\n",
349 | "This epoch took 12.109601020812988 seconds\n",
350 | "Current learning rate: 0.045000000000000005\n",
351 | "\n",
352 | "Epoch: 7\n",
353 | "|===============================================| Accuracy: 98.56 % 59136 / 60000\n",
354 | "|==========| Accuracy: 98.68 % 9868 / 10000\n",
355 | "This epoch took 12.11560845375061 seconds\n",
356 | "Current learning rate: 0.045000000000000005\n",
357 | "\n",
358 | "Epoch: 8\n",
359 | "|===============================================| Accuracy: 98.64333333333333 % 59186 / 60000\n",
360 | "|==========| Accuracy: 99.15 % 9915 / 10000\n",
361 | "This epoch took 12.045796155929565 seconds\n",
362 | "Current learning rate: 0.045000000000000005\n",
363 | "\n",
364 | "Epoch: 9\n",
365 | "|===============================================| Accuracy: 98.69166666666666 % 59215 / 60000\n",
366 | "|==========| Accuracy: 99.06 % 9906 / 10000\n",
367 | "This epoch took 11.952045917510986 seconds\n",
368 | "Current learning rate: 0.03645000000000001\n",
369 | "\n",
370 | "Epoch: 10\n",
371 | "|===============================================| Accuracy: 98.92666666666666 % 59356 / 60000\n",
372 | "|==========| Accuracy: 99.22 % 9922 / 10000\n",
373 | "This epoch took 12.087683916091919 seconds\n",
374 | "Current learning rate: 0.04050000000000001\n",
375 | "\n",
376 | "Epoch: 11\n",
377 | "|===============================================| Accuracy: 98.82333333333334 % 59294 / 60000\n",
378 | "|==========| Accuracy: 99.15 % 9915 / 10000\n",
379 | "This epoch took 11.951063871383667 seconds\n",
380 | "Current learning rate: 0.04050000000000001\n",
381 | "\n",
382 | "Epoch: 12\n",
383 | "|===============================================| Accuracy: 98.93333333333334 % 59360 / 60000\n",
384 | "|==========| Accuracy: 99.11 % 9911 / 10000\n",
385 | "This epoch took 11.971977710723877 seconds\n",
386 | "Current learning rate: 0.04050000000000001\n",
387 | "\n",
388 | "Epoch: 13\n",
389 | "|===============================================| Accuracy: 98.83833333333334 % 59303 / 60000\n",
390 | "|==========| Accuracy: 99.02 % 9902 / 10000\n",
391 | "This epoch took 11.958030462265015 seconds\n",
392 | "Current learning rate: 0.04050000000000001\n",
393 | "\n",
394 | "Epoch: 14\n",
395 | "|===============================================| Accuracy: 98.97 % 59382 / 60000\n",
396 | "|==========| Accuracy: 99.11 % 9911 / 10000\n",
397 | "This epoch took 11.924121141433716 seconds\n",
398 | "Current learning rate: 0.03280500000000001\n",
399 | "\n",
400 | "Epoch: 15\n",
401 | "|===============================================| Accuracy: 99.04 % 59424 / 60000\n",
402 | "|==========| Accuracy: 99.17 % 9917 / 10000\n",
403 | "This epoch took 11.97598123550415 seconds\n",
404 | "Current learning rate: 0.03645000000000001\n",
405 | "\n",
406 | "Epoch: 16\n",
407 | "|===============================================| Accuracy: 99.12333333333333 % 59474 / 60000\n",
408 | "|==========| Accuracy: 99.3 % 9930 / 10000\n",
409 | "This epoch took 11.95703387260437 seconds\n",
410 | "Current learning rate: 0.03645000000000001\n",
411 | "\n",
412 | "Epoch: 17\n",
413 | "|===============================================| Accuracy: 99.04333333333334 % 59426 / 60000\n",
414 | "|==========| Accuracy: 99.12 % 9912 / 10000\n",
415 | "This epoch took 11.95304274559021 seconds\n",
416 | "Current learning rate: 0.03645000000000001\n",
417 | "\n",
418 | "Epoch: 18\n",
419 | "|===============================================| Accuracy: 99.06166666666667 % 59437 / 60000\n",
420 | "|==========| Accuracy: 99.26 % 9926 / 10000\n",
421 | "This epoch took 12.056767225265503 seconds\n",
422 | "Current learning rate: 0.03645000000000001\n",
423 | "\n",
424 | "Epoch: 19\n",
425 | "|===============================================| Accuracy: 99.13833333333334 % 59483 / 60000\n",
426 | "|==========| Accuracy: 99.18 % 9918 / 10000\n",
427 | "This epoch took 12.055768728256226 seconds\n",
428 | "Current learning rate: 0.02952450000000001\n",
429 | "\n",
430 | "Epoch: 20\n",
431 | "|===============================================| Accuracy: 99.13166666666666 % 59479 / 60000\n",
432 | "|==========| Accuracy: 99.13 % 9913 / 10000\n",
433 | "This epoch took 12.040835857391357 seconds\n",
434 | "Current learning rate: 0.03280500000000001\n",
435 | "\n",
436 | "Epoch: 21\n",
437 | "|===============================================| Accuracy: 99.195 % 59517 / 60000\n",
438 | "|==========| Accuracy: 99.24 % 9924 / 10000\n",
439 | "This epoch took 12.066712379455566 seconds\n",
440 | "Current learning rate: 0.03280500000000001\n",
441 | "\n",
442 | "Epoch: 22\n",
443 | "|===============================================| Accuracy: 99.25666666666666 % 59554 / 60000\n",
444 | "|==========| Accuracy: 99.2 % 9920 / 10000\n",
445 | "This epoch took 12.057764053344727 seconds\n",
446 | "Current learning rate: 0.03280500000000001\n",
447 | "\n",
448 | "Epoch: 23\n",
449 | "|===============================================| Accuracy: 99.2 % 59520 / 60000\n",
450 | "|==========| Accuracy: 99.21 % 9921 / 10000\n",
451 | "This epoch took 12.059757709503174 seconds\n",
452 | "Current learning rate: 0.03280500000000001\n",
453 | "\n",
454 | "Epoch: 24\n",
455 | "|===============================================| Accuracy: 99.25666666666666 % 59554 / 60000\n",
456 | "|==========| Accuracy: 99.16 % 9916 / 10000\n",
457 | "This epoch took 12.063747882843018 seconds\n",
458 | "Current learning rate: 0.02657205000000001\n",
459 | "\n",
460 | "Epoch: 25\n",
461 | "|===============================================| Accuracy: 99.28166666666667 % 59569 / 60000\n",
462 | "|==========| Accuracy: 99.22 % 9922 / 10000\n",
463 | "This epoch took 12.0886812210083 seconds\n",
464 | "Current learning rate: 0.02952450000000001\n",
465 | "\n",
466 | "Epoch: 26\n",
467 | "|===============================================| Accuracy: 99.34333333333333 % 59606 / 60000\n",
468 | "|==========| Accuracy: 99.27 % 9927 / 10000\n",
469 | "This epoch took 12.026846408843994 seconds\n",
470 | "Current learning rate: 0.02952450000000001\n",
471 | "\n",
472 | "Epoch: 27\n",
473 | "|===============================================| Accuracy: 99.27166666666666 % 59563 / 60000\n",
474 | "|==========| Accuracy: 99.25 % 9925 / 10000\n",
475 | "This epoch took 12.094665050506592 seconds\n",
476 | "Current learning rate: 0.02952450000000001\n",
477 | "\n",
478 | "Epoch: 28\n",
479 | "|===============================================| Accuracy: 99.31666666666666 % 59590 / 60000\n",
480 | "|==========| Accuracy: 99.15 % 9915 / 10000\n",
481 | "This epoch took 12.090675592422485 seconds\n",
482 | "Current learning rate: 0.02952450000000001\n",
483 | "\n",
484 | "Epoch: 29\n",
485 | "|===============================================| Accuracy: 99.32166666666667 % 59593 / 60000\n",
486 | "|==========| Accuracy: 99.06 % 9906 / 10000\n",
487 | "This epoch took 12.038814544677734 seconds\n",
488 | "Current learning rate: 0.02391484500000001\n",
489 | "\n",
490 | "Epoch: 30\n",
491 | "|===============================================| Accuracy: 99.32833333333333 % 59597 / 60000\n",
492 | "|==========| Accuracy: 99.15 % 9915 / 10000\n",
493 | "This epoch took 12.056766271591187 seconds\n",
494 | "Current learning rate: 0.02657205000000001\n",
495 | "\n",
496 | "Epoch: 31\n",
497 | "|===============================================| Accuracy: 99.36166666666666 % 59617 / 60000\n",
498 | "|==========| Accuracy: 99.19 % 9919 / 10000\n",
499 | "This epoch took 12.067737340927124 seconds\n",
500 | "Current learning rate: 0.02657205000000001\n",
501 | "\n",
502 | "Epoch: 32\n",
503 | "|===============================================| Accuracy: 99.41 % 59646 / 60000\n",
504 | "|==========| Accuracy: 99.21 % 9921 / 10000\n",
505 | "This epoch took 12.074690818786621 seconds\n",
506 | "Current learning rate: 0.02657205000000001\n",
507 | "\n",
508 | "Epoch: 33\n"
509 | ]
510 | },
511 | {
512 | "name": "stdout",
513 | "output_type": "stream",
514 | "text": [
515 | "|====================="
516 | ]
517 | }
518 | ],
519 | "source": [
520 | "%%time\n",
521 | "run_all('mnist')"
522 | ]
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": null,
527 | "metadata": {},
528 | "outputs": [],
529 | "source": [
530 | "%%time\n",
531 | "run_all('mnist', 'cp', rate=0.01)"
532 | ]
533 | },
534 | {
535 | "cell_type": "code",
536 | "execution_count": null,
537 | "metadata": {},
538 | "outputs": [],
539 | "source": [
540 | "%%time\n",
541 | "run_all('mnist', 'tucker')"
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "execution_count": null,
547 | "metadata": {},
548 | "outputs": [],
549 | "source": [
550 | "%%time\n",
551 | "run_all('mnist', 'tt')"
552 | ]
553 | },
554 | {
555 | "cell_type": "code",
556 | "execution_count": null,
557 | "metadata": {},
558 | "outputs": [],
559 | "source": [
560 | "%%time\n",
561 | "run_all('cifar10', iterations=200)"
562 | ]
563 | },
564 | {
565 | "cell_type": "code",
566 | "execution_count": null,
567 | "metadata": {},
568 | "outputs": [],
569 | "source": [
570 | "%%time\n",
571 | "run_all('cifar10', 'tucker', iterations=200)"
572 | ]
573 | },
574 | {
575 | "cell_type": "code",
576 | "execution_count": null,
577 | "metadata": {},
578 | "outputs": [],
579 | "source": [
580 | "%%time\n",
581 | "run_all('cifar100', 'tt', iterations=200)"
582 | ]
583 | }
584 | ],
585 | "metadata": {
586 | "kernelspec": {
587 | "display_name": "Python 3",
588 | "language": "python",
589 | "name": "python3"
590 | },
591 | "language_info": {
592 | "codemirror_mode": {
593 | "name": "ipython",
594 | "version": 3
595 | },
596 | "file_extension": ".py",
597 | "mimetype": "text/x-python",
598 | "name": "python",
599 | "nbconvert_exporter": "python",
600 | "pygments_lexer": "ipython3",
601 | "version": "3.7.6"
602 | }
603 | },
604 | "nbformat": 4,
605 | "nbformat_minor": 4
606 | }
607 |
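
The rank heuristics in `decompose()` are worth making concrete: for CP and TT the rank is the largest weight dimension divided by 10, and for Tucker the output/input channel counts are each divided by 3 and rounded up, which is exactly where the 43- and 22-channel bottlenecks in the decomposed VGG printed in eval.ipynb come from. Worked out for a 3x3 conv with 64 input and 128 output channels (a sketch):

```
import numpy as np

shape = (128, 64, 3, 3)        # (out_ch, in_ch, kH, kW) of the conv weight
cp_tt_rank = max(shape) // 10  # 12, used for both 'cp' and 'tt'
tucker_ranks = [int(np.ceil(shape[0] / 3)),  # 43
                int(np.ceil(shape[1] / 3))]  # 22
print(cp_tt_rank, tucker_ranks)  # 12 [43, 22]
```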
--------------------------------------------------------------------------------
/transform_based_network/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import *
2 | from .trainer import *
3 | from .transform_layer import *
4 | from .transform_nets import *
5 | from .multiprocess import *
--------------------------------------------------------------------------------
/transform_based_network/multiprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch_dct as dct
5 |
6 | import time
7 | from torch.multiprocessing import Pool, Queue, Process, set_start_method
8 | import torch.multiprocessing as torch_mp
9 | import multiprocessing as mp_
10 |
11 | import sys
12 | sys.path.append('../')
13 | from common import *
14 | from transform_based_network import *
15 |
16 |
17 | def t_product_slice(A, B, C, i):
18 | C[i, ...] = torch.mm(A[i, ...], B[i, ...])
19 |
20 | def t_product_multiprocess(A, B):
21 | tmp = torch_mp.get_context('spawn')
22 |
23 | assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])
24 | dct_A = torch.transpose(dct.dct(torch.transpose(A, 0, 2)), 0, 2)
25 | dct_B = torch.transpose(dct.dct(torch.transpose(B, 0, 2)), 0, 2)
26 | dct_C = torch.zeros(A.shape[0], A.shape[1], B.shape[2])
27 |
28 |     # NOTE: with the 'spawn' context each worker receives a copy of these
29 |     # tensors, so writes to dct_C only propagate back if they are first put
30 |     # in shared memory: dct_A.share_memory_(); dct_B.share_memory_(); dct_C.share_memory_()
31 |
32 | processes = []
33 | # num_cores = torch_mp.cpu_count()
34 | for i in range(dct_C.shape[0]):
35 | p = tmp.Process(target=t_product_slice, args=(dct_A, dct_B, dct_C, i))
36 | p.start()
37 | processes.append(p)
38 | for p in processes:
39 | p.join()
40 |
41 | C = torch.transpose(dct.idct(torch.transpose(dct_C, 0, 2)), 0, 2)
42 | return C
43 |
44 | '''
45 | if __name__ == '__main__':
46 | A = torch.ones(17, 1000, 1000)
47 | B = torch.ones(17, 1000, 1000)
48 | t_product_multiprocess(A, B)
49 | '''
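
For reference, a single-process version of what `t_product_multiprocess` computes: transform along the tube (depth) dimension with the DCT, multiply matching frontal slices, then invert the transform. The package exports its own `t_product` helper; this standalone sketch just mirrors the math:

```
import torch
import torch_dct as dct

def t_product_reference(A, B):
    assert A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]
    # DCT along the tube (first) dimension, via transposes
    dct_A = torch.transpose(dct.dct(torch.transpose(A, 0, 2)), 0, 2)
    dct_B = torch.transpose(dct.dct(torch.transpose(B, 0, 2)), 0, 2)
    # slice-wise matrix products in the transform domain
    dct_C = torch.stack([dct_A[i] @ dct_B[i] for i in range(A.shape[0])])
    # back to the original domain
    return torch.transpose(dct.idct(torch.transpose(dct_C, 0, 2)), 0, 2)

C = t_product_reference(torch.randn(4, 5, 6), torch.randn(4, 6, 7))
print(C.shape)  # torch.Size([4, 5, 7])
```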
--------------------------------------------------------------------------------
/transform_based_network/new_transform.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch_dct as dct\n",
11 | "import torch.nn as nn\n",
12 | "import torch.optim as optim\n",
13 | "import torch.nn.functional as F\n",
14 | "import torch.backends.cudnn as cudnn\n",
15 | "import torch.nn.init as init\n",
16 | "from torch.optim.lr_scheduler import StepLR\n",
17 | "\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "import time\n",
20 | "import pkbar\n",
21 | "import math\n",
22 | "\n",
23 | "import sys\n",
24 | "sys.path.append('../')\n",
25 | "from common import *\n",
26 | "from transform_based_network import *"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 14,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "def t_product_in_network(A, B):\n",
36 | " device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
37 | " assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])\n",
38 | " dct_C = torch.zeros(A.shape[0], A.shape[1], B.shape[2])\n",
39 | " dct_A = torch_apply(dct.dct, A)\n",
40 | " for k in range(A.shape[0]):\n",
41 | " dct_C[k, ...] = torch.mm(dct_A[k, ...], B[k, ...])\n",
42 | " return dct_C #.to(device)"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 15,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "class tNN(nn.Module):\n",
52 | " def __init__(self):\n",
53 | " super(tNN, self).__init__()\n",
54 | " W, B = [], []\n",
55 | " self.num_layers = 10\n",
56 | " for i in range(self.num_layers):\n",
57 | " W.append(nn.Parameter(torch.Tensor(28, 28, 28)))\n",
58 | " B.append(nn.Parameter(torch.Tensor(28, 28, 1)))\n",
59 | " self.W = nn.ParameterList(W)\n",
60 | " self.B = nn.ParameterList(B)\n",
61 | " self.reset_parameters()\n",
62 | "\n",
63 | " def forward(self, x):\n",
64 | " for i in range(self.num_layers):\n",
65 | " x = torch.add(t_product(self.W[i], x), self.B[i])\n",
66 | " x = F.relu(x)\n",
67 | " return x\n",
68 | "\n",
69 | " def reset_parameters(self):\n",
70 | " for i in range(self.num_layers):\n",
71 | " init.kaiming_uniform_(self.W[i], a=math.sqrt(5))\n",
72 | " fan_in, _ = init._calculate_fan_in_and_fan_out(self.W[i])\n",
73 | " bound = 1 / math.sqrt(fan_in)\n",
74 | " init.uniform_(self.B[i], -bound, bound)"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 22,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "class new_tNN(nn.Module):\n",
84 | " def __init__(self):\n",
85 | " super(new_tNN, self).__init__()\n",
86 | " W, B = [], []\n",
87 | " self.num_layers = 4\n",
88 | " for i in range(self.num_layers):\n",
89 | " W.append(nn.Parameter(torch.Tensor(28, 28, 28)))\n",
90 | " B.append(nn.Parameter(torch.Tensor(28, 28, 1)))\n",
91 | " self.W = nn.ParameterList(W)\n",
92 | " self.B = nn.ParameterList(B)\n",
93 | " self.reset_parameters()\n",
94 | "\n",
95 | " def forward(self, x):\n",
96 | " x = torch_apply(dct.dct, x)\n",
97 | " for i in range(self.num_layers):\n",
98 | " x = torch.add(t_product_in_network(self.W[i], x), self.B[i])\n",
99 | " x = F.relu(x)\n",
100 | " x = torch_apply(dct.idct, x)\n",
101 | " return x\n",
102 | "\n",
103 | " def reset_parameters(self):\n",
104 | " for i in range(self.num_layers):\n",
105 | " init.kaiming_uniform_(self.W[i], a=math.sqrt(5))\n",
106 | " fan_in, _ = init._calculate_fan_in_and_fan_out(self.W[i])\n",
107 | " bound = 1 / math.sqrt(fan_in)\n",
108 | " init.uniform_(self.B[i], -bound, bound)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 28,
114 | "metadata": {},
115 | "outputs": [
116 | {
117 | "name": "stdout",
118 | "output_type": "stream",
119 | "text": [
120 | "==> Loading data..\n",
121 | "Epoch 1 training:\n",
122 | " 23/600 [>.............................] - 8.4s"
123 | ]
124 |    }
145 | ],
146 | "source": [
147 | "lr_rate = 0.001\n",
148 | "epochs_num = 20\n",
149 | "device = 'cpu' # 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
150 | "batch_size = 100\n",
151 | "train_loader, test_loader = load_mnist_multiprocess(batch_size)\n",
152 | "\n",
153 | "module = new_tNN()\n",
154 | "module = module.to(device)\n",
155 | "\n",
156 | "Loss_function = nn.CrossEntropyLoss()\n",
157 | "optimizer = torch.optim.SGD(module.parameters(), lr=lr_rate)\n",
158 | "\n",
159 | "test_loss_epoch = []\n",
160 | "test_acc_epoch = []\n",
161 | "train_loss_epoch = []\n",
162 | "train_acc_epoch = []\n",
163 | "time_list = []\n",
164 | "\n",
165 | "# begain train\n",
166 | "for epoch in range(epochs_num):\n",
167 | " since = time.time()\n",
168 | " running_loss = 0.0\n",
169 | " running_acc = 0.0\n",
170 | " module.train()\n",
171 | "\n",
172 | " pbar_train = pkbar.Pbar(name='Epoch '+str(epoch+1)+' training:', target=60000/batch_size)\n",
173 | " for i, data in enumerate(train_loader):\n",
174 | " img, label = data\n",
175 | " img = raw_img(img, batch_size, n=28)\n",
176 | " img = img.to(device)\n",
177 | " label = label.to(device)\n",
178 | "\n",
179 | " # forward\n",
180 | " out = module(img)\n",
181 | "\n",
182 | " # softmax function\n",
183 | " out = torch.transpose(scalar_tubal_func(out), 0, 1)\n",
184 | " loss = Loss_function(out, label)\n",
185 | " running_loss += loss.item()\n",
186 | " _, pred = torch.max(out, 1)\n",
187 | " running_acc += (pred == label).float().mean()\n",
188 | "\n",
189 | " # backward\n",
190 | " optimizer.zero_grad()\n",
191 | " loss.backward()\n",
192 | " optimizer.step()\n",
193 | " \n",
194 | " pbar_train.update(i)\n",
195 | "\n",
196 | " print('[{Epoch}/{Epochs_num}] Loss:{Running_loss} Acc:{Running_acc}'\n",
197 | " .format(Epoch=epoch + 1, Epochs_num=epochs_num, Running_loss=(running_loss / i),\n",
198 | " Running_acc=running_acc / i))\n",
199 | " train_loss_epoch.append(running_loss / i)\n",
200 | " train_acc_epoch.append((running_acc / i) * 100)\n",
201 | "\n",
202 | " module.eval()\n",
203 | " eval_loss = 0.0\n",
204 | " eval_acc = 0.0\n",
205 | "\n",
206 | " pbar_test = pkbar.Pbar(name='Epoch '+str(epoch+1)+' test', target=10000/batch_size)\n",
207 | " for i, data in enumerate(test_loader):\n",
208 | " img, label = data\n",
209 | " img = cifar_img_process(img)\n",
210 | " img = img.to(device)\n",
211 | " label = label.to(device)\n",
212 | "\n",
213 | " with torch.no_grad():\n",
214 | " out = module(img)\n",
215 | " out = torch.transpose(scalar_tubal_func(out), 0, 1)\n",
216 | " loss = Loss_function(out, label)\n",
217 | " eval_loss += loss.item()\n",
218 | " _, pred = torch.max(out, 1)\n",
219 | " eval_acc += (pred == label).float().mean()\n",
220 | "\n",
221 | " pbar_test.update(i)\n",
222 | "\n",
223 | " print('Test Loss: {Eval_loss}, Acc: {Eval_acc}'\n",
224 | " .format(Eval_loss=eval_loss / len(test_loader), \n",
225 | " Eval_acc=eval_acc / len(test_loader)))\n",
226 | " test_loss_epoch.append(eval_loss / len(test_loader))\n",
227 | " test_acc_epoch.append((eval_acc / len(test_loader)) * 100)\n",
228 | " time_list.append(time.time() - since)\n",
229 | "\n",
230 | " if np.isnan(eval_loss):\n",
231 | " print('invalid loss')\n",
232 | " break"
233 | ]
234 |   }
249 | ],
250 | "metadata": {
251 | "kernelspec": {
252 | "display_name": "Python 3",
253 | "language": "python",
254 | "name": "python3"
255 | }
256 | },
257 | "nbformat": 4,
258 | "nbformat_minor": 4
259 | }
260 |
--------------------------------------------------------------------------------
/transform_based_network/tNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_dct as dct
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import torch.backends.cudnn as cudnn
7 | import torch.nn.init as init
8 | from torch.optim.lr_scheduler import StepLR
9 |
10 | import matplotlib.pyplot as plt
11 | import time
12 | import pkbar
13 | import math
14 |
15 | import sys
16 | sys.path.append('../')
17 | from common import *
18 | from transform_based_network import *
19 |
20 |
21 | class tNN(nn.Module):
22 | def __init__(self):
23 | super(tNN, self).__init__()
24 | W, B = [], []
25 | self.num_layers = 4
26 | for i in range(self.num_layers):
27 | W.append(nn.Parameter(torch.Tensor(28, 28, 28)))
28 | B.append(nn.Parameter(torch.Tensor(28, 28, 1)))
29 | self.W = nn.ParameterList(W)
30 | self.B = nn.ParameterList(B)
31 | self.reset_parameters()
32 |
33 | def forward(self, x):
34 | for i in range(self.num_layers):
35 | x = torch.add(t_product(self.W[i], x), self.B[i])
36 | x = F.relu(x)
37 | return x
38 |
39 | def reset_parameters(self):
40 | for i in range(self.num_layers):
41 | init.kaiming_uniform_(self.W[i], a=math.sqrt(5))
42 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.W[i])
43 | bound = 1 / math.sqrt(fan_in)
44 | init.uniform_(self.B[i], -bound, bound)
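45 | 
46 | 
47 | # Minimal smoke test (a sketch): forward a batch of 100 MNIST-sized lateral
48 | # slices; t_product and the (rows, cols, batch) layout come from util.py.
49 | if __name__ == '__main__':
50 |     model = tNN()
51 |     x = torch.randn(28, 28, 100)
52 |     print(model(x).shape)  # expected: torch.Size([28, 28, 100])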
--------------------------------------------------------------------------------
/transform_based_network/trainer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | from common import *
4 | from transform_based_network import *
5 |
6 | import time
7 | import numpy as np
8 | import torch
9 | import torch_dct as dct
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import torch.nn.functional as F
13 | from torch.optim.lr_scheduler import StepLR
14 | 
15 | def train_step_transform(epoch, train_acc, model, trainloader, optimizer):
16 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
17 | model = model.to(device)
18 | model.train()
19 | train_loss = 0.0
20 | correct = 0
21 | total = 0
22 | criterion = nn.CrossEntropyLoss()
23 |
24 | print('\nEpoch: ', epoch)
25 | print('|', end='')
26 | for batch_idx, (inputs, labels) in enumerate(trainloader):
27 | inputs = inputs.to(device)
28 |         if inputs.shape[0] != 100:  # the transform nets expect a fixed batch size of 100
29 |             break
30 |
31 | labels = labels.to(device)
32 | outputs = model(inputs)
33 | outputs = torch.transpose(scalar_tubal_func(outputs), 0, 1)
34 |
35 | optimizer.zero_grad()
36 | loss = criterion(outputs, labels)
37 | # print(loss)
38 | if np.isnan(loss.item()):
39 | print('Training terminated due to instability')
40 | break
41 | loss.backward()
42 | optimizer.step()
43 | train_loss += loss.item()
44 | _, predicted = torch.max(outputs, 1)
45 | total += labels.size(0)
46 | correct += predicted.eq(labels).sum().item()
47 | if batch_idx % 10 == 0:
48 | print('=', end='')
49 | print('|', 'Accuracy:', correct / total, 'Loss:', train_loss / total)
50 | train_acc.append(correct / total)
51 | return train_acc
52 |
53 | def test(test_acc, model, testloader):
54 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
55 | model = model.to(device)
56 | model.eval()
57 | test_loss = 0
58 | correct = 0
59 | total = 0
60 | criterion = nn.CrossEntropyLoss()
61 | s = time.time()
62 | with torch.no_grad():
63 | print('|', end='')
64 | for batch_idx, (inputs, targets) in enumerate(testloader):
65 | inputs = inputs.to(device)
66 |             if inputs.shape[0] != 100:  # skip the last partial batch; fixed batch size of 100
67 |                 break
68 | targets = targets.to(device)
69 | outputs = model(inputs)
70 | outputs = torch.transpose(scalar_tubal_func(outputs), 0, 1)
71 | loss = criterion(outputs, targets)
72 | test_loss += loss.item()
73 | _, predicted = outputs.max(1)
74 | total += targets.size(0)
75 | correct += predicted.eq(targets).sum().item()
76 | if batch_idx % 10 == 0:
77 | print('=', end='')
78 | e = time.time()
79 | print('|', ' Test accuracy:', correct / total, 'Test loss:', test_loss / total)
80 | print('The inference time is', e - s, 'seconds')
81 | test_acc.append(correct / total)
82 | return test_acc, e - s
83 |
84 | def train_transform(i, model, trainloader, testloader, optimizer):
85 | train_acc, test_acc = [], []
86 | scheduler = StepLR(optimizer, step_size=1, gamma=0.95)
87 |
88 | for epoch in range(i):
89 | s = time.time()
90 | train_acc = train_step_transform(epoch, train_acc, model, trainloader, optimizer)
91 | e = time.time()
92 | test_acc, _ = test(test_acc, model, testloader)
93 | scheduler.step()
94 |
95 | print('This epoch took', e - s, 'seconds to train')
96 | print('Current learning rate:', scheduler.get_last_lr()[0])
97 |     print('Best test accuracy overall:', max(test_acc))
98 | return train_acc, test_acc
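99 | 
100 | 
101 | # Example wiring (a sketch, not a fixed API): assumes load_mnist_multiprocess
102 | # from common and Transform_Net from transform_nets.
103 | # model = Transform_Net(batch_size=100)
104 | # trainloader, testloader = load_mnist_multiprocess(100)
105 | # optimizer = optim.SGD(model.parameters(), lr=0.01)
106 | # train_acc, test_acc = train_transform(20, model, trainloader, testloader, optimizer)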
--------------------------------------------------------------------------------
/transform_based_network/transform_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_dct as dct
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import torch.backends.cudnn as cudnn
7 | from torch.optim.lr_scheduler import StepLR
8 | import sys
9 | sys.path.append('../')
10 | from common import *
11 | from transform_based_network import *
12 |
13 |
14 | class Transform_Layer(nn.Module):
15 | def __init__(self, n, size_in, m, size_out):
16 | super().__init__()
17 | self.size_in = size_in
18 | self.size_out = size_out
19 | weights = torch.randn(n, size_out, size_in) * 0.01
20 | bias = torch.randn(1, size_out, m)
21 | self.weights = nn.Parameter(weights, requires_grad=True)
22 | self.bias = nn.Parameter(bias, requires_grad=True)
23 |
24 |     def forward(self, x):
25 |         device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
26 |         # the bias term is currently disabled; add self.bias to the product to enable it
27 |         return t_product_v2(self.weights, x).to(device)
28 |
29 | class T_Layer(nn.Module):
30 | def __init__(self, dct_w, dct_b):
31 | super(T_Layer, self).__init__()
32 | self.weights = nn.Parameter(dct_w, requires_grad=True)
33 | self.bias = nn.Parameter(dct_b, requires_grad=True)
34 |
35 | def forward(self, dct_x):
36 | x = torch.mm(self.weights, dct_x) + self.bias
37 | return x
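38 | 
39 | 
40 | # Shape sketch for T_Layer (an illustration, not part of the API): weights of
41 | # shape (out, in) and bias (out, batch) act on one DCT-domain frontal slice.
42 | # layer = T_Layer(torch.randn(10, 28), torch.randn(10, 100))
43 | # y = layer(torch.randn(28, 100))  # -> (10, 100)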
--------------------------------------------------------------------------------
/transform_based_network/transform_nets.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | from common import *
4 | from transform_based_network import *
5 |
6 | import torch
7 | import torch_dct as dct
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | import torch.backends.cudnn as cudnn
12 | from torch.optim.lr_scheduler import StepLR
13 |
14 |
15 | class Transform_Net(nn.Module):
16 | def __init__(self, batch_size):
17 | super(Transform_Net, self).__init__()
18 | self.features = nn.Sequential(
19 | #Transform_Layer(28, 28, batch_size, 28),
20 | #nn.ReLU(inplace=True),
21 | #Transform_Layer(28, 28, batch_size, 28),
22 | #nn.ReLU(inplace=True),
23 | Transform_Layer(28, 28, batch_size, 10),
24 | )
25 |
26 |     def forward(self, x):
27 |         x = torch_shift(x)
28 |         x = self.features(x)
29 |         return x / 1e2  # fixed scaling factor
30 |
31 |
32 | class Conv_Transform_Net(nn.Module):
33 | def __init__(self, batch_size):
34 | super(Conv_Transform_Net, self).__init__()
35 | self.first = nn.Sequential(
36 | Transform_Layer(28, 28, batch_size, 28),
37 | nn.ReLU(inplace=True),
38 | )
39 | self.intermediate = nn.Sequential(
40 | nn.Conv2d(28, 28, kernel_size=3, padding=1),
41 | nn.Conv2d(28, 28, kernel_size=3, padding=1),
42 | nn.Conv2d(28, 28, kernel_size=1, padding=0),
43 | )
44 | self.last = Transform_Layer(28, 28, batch_size, 10)
45 |
46 | def forward(self, x):
47 | x = torch_shift(x)
48 | x = self.first(x)
49 |
50 | x = torch.transpose(x, 0, 2)
51 | x = torch.transpose(x, 1, 2)
52 |         x = x.reshape(-1, 28, 4, 7)  # batch first; avoids hardcoding batch_size
53 |
54 | x = self.intermediate(x)
55 |
56 |         x = x.reshape(-1, 28, 28)
57 | x = torch_shift(x)
58 | x = self.last(x)
59 |
60 | return x / 5e2
61 |
62 |
63 | class Conv_Transform_Net_CIFAR(nn.Module):
64 | def __init__(self, batch_size):
65 | super(Conv_Transform_Net_CIFAR, self).__init__()
66 | self.batch_size = batch_size
67 | self.first = nn.Sequential(
68 | Transform_Layer(96, 32, batch_size, 32),
69 | nn.ReLU(inplace=True),
70 | Transform_Layer(96, 32, batch_size, 32),
71 | nn.ReLU(inplace=True),
72 | )
73 |
74 | self.intermediate = nn.Sequential(
75 | nn.Conv2d(96, 96, kernel_size=3, padding=1),
76 | nn.ReLU(inplace=True),
77 | nn.Conv2d(96, 96, kernel_size=3, padding=1),
78 | nn.ReLU(inplace=True),
79 | nn.Conv2d(96, 96, kernel_size=3, padding=1),
80 | nn.ReLU(inplace=True),
81 | nn.Conv2d(96, 96, kernel_size=3, padding=1),
82 | nn.ReLU(inplace=True),
83 | nn.Conv2d(96, 96, kernel_size=3, padding=1),
84 | nn.ReLU(inplace=True),
85 | nn.Conv2d(96, 96, kernel_size=1, padding=0),
86 | nn.ReLU(inplace=True),
87 | )
88 | self.last = Transform_Layer(96, 32, batch_size, 10)
89 |
90 | def forward(self, x):
91 |         x = torch.reshape(x, (self.batch_size, 96, 32))
92 | x = torch_shift(x)
93 | x = self.first(x)
94 |
95 | x = torch.transpose(x, 0, 2)
96 | x = torch.transpose(x, 1, 2)
97 | x = x.reshape(self.batch_size, 96, 4, 8)
98 | x = self.intermediate(x)
99 |
100 | x = x.reshape(self.batch_size, 96, 32)
101 | x = torch_shift(x)
102 | x = self.last(x)
103 |
104 | return x
105 |
106 | class Frontal_Slice(nn.Module):
107 | def __init__(self, dct_w, dct_b):
108 | super(Frontal_Slice, self).__init__()
109 | self.device = dct_w.device
110 | self.dct_linear = T_Layer(dct_w, dct_b)
111 |
112 | def forward(self, dct_x):
113 | return self.dct_linear(dct_x.to(self.device))
114 |
115 |
116 | class Ensemble(nn.Module):
117 | def __init__(self, shape, device='cpu'):
118 | super(Ensemble, self).__init__()
119 | self.device = device
120 | self.models = []
121 | for i in range(shape[0]):
122 | dct_w, dct_b = make_weights(shape, device)
123 | model = Frontal_Slice(dct_w[i, ...], dct_b[i, ...])
124 | self.models.append(model.to(device))
125 |
126 | def forward(self, x):
127 | s = self.models[0].dct_linear.weights.shape
128 | result = torch.empty(x.shape[0], s[0], x.shape[2])
129 | dct_x = torch_apply(dct.dct, x).to(self.device)
130 |
131 | for i in range(len(self.models)):
132 | result[i, ...] = self.models[i](dct_x[i, ...])
133 |
134 | result = torch_apply(dct.idct, result)
135 | result = torch_shift(result)
136 | softmax = scalar_tubal_func(result)
137 | return torch.transpose(softmax, 0, 1)
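138 | 
139 | 
140 | # Usage sketch (illustrative): each Frontal_Slice owns one DCT frontal slice,
141 | # so the sub-models could be placed on different devices for model parallelism.
142 | # model = Ensemble((28, 28, 100), device='cpu')
143 | # scores = model(torch.randn(28, 28, 100))  # tubal-softmax class scores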
--------------------------------------------------------------------------------
/transform_based_network/util.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | from common import *
4 | from transform_based_network import *
5 |
6 | import torch
7 | import torch_dct as dct
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | import torch.backends.cudnn as cudnn
12 | from torch.optim.lr_scheduler import StepLR
13 | import numpy as np
14 | 
15 | def bcirc(A):
16 | l, m, n = A.shape
17 | bcirc_A = []
18 | for i in range(l):
19 | bcirc_A.append(torch.roll(A, shifts=i, dims=0))
20 | return torch.cat(bcirc_A, dim=2).reshape(l*m, l*n)
21 |
22 | def hankel(A):
23 | l, m, n = A.shape
24 | circ = torch.zeros(2 * l + 1, m, n)
25 |     # the middle slice circ[l] stays zero
26 | for i in range(l):
27 | k = circ.shape[0] - i - 1
28 | circ[i, ...] = A[i, ...]
29 | circ[k, ...] = A[i, ...]
30 | hankel_A = []
31 | for i in range(1, l + 1):
32 | hankel_A.append(circ[i : i + l, ...])
33 | return torch.cat(hankel_A, dim=2).reshape(l*m, l*n)
34 |
35 | def toeplitz(A):
36 | l, m, n = A.shape
37 | circ = torch.zeros(2 * l - 1, m, n)
38 | for i in range(l):
39 | circ[i + l - 1, ...] = A[i, ...]
40 | circ[l - i - 1, ...] = A[i, ...]
41 | toeplitz_A = []
42 | for i in range(l):
43 | toeplitz_A.append(circ[l - i - 1: 2 * l - i - 1, ...])
44 | return torch.cat(toeplitz_A, dim=2).reshape(l*m, l*n)
45 |
46 | def tph(A):
47 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
48 |     return (toeplitz(A) + hankel(A)).to(device)  # keep autograd; torch.Tensor(...) would detach
49 |
50 | def shift(X):
51 | A = X.clone()
52 | for i in range(1, A.shape[0]):
53 | k = A.shape[0] - i - 1
54 | A[k, ...] -= A[k + 1, ...]
55 | return A
56 |
57 | def t_product(A, B):
58 | assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])
59 | prod = torch.mm(tph(shift(A)), bcirc(B)[:, 0:B.shape[2]])
60 | return prod.reshape(A.shape[0], A.shape[1], B.shape[2])
61 |
62 | def dct_t_product(A, B):
63 | assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])
64 | dct_C = torch.zeros(A.shape[0], A.shape[1], B.shape[2])
65 | dct_A = torch_apply(dct.dct, A)
66 | dct_B = torch_apply(dct.dct, B)
67 | for k in range(A.shape[0]):
68 | dct_C[k, ...] = torch.mm(dct_A[k, ...], dct_B[k, ...])
69 | return torch_apply(dct.idct, dct_C)
70 |
71 | def t_product_fft(A, B):  # despite the name, this is the direct block-circulant product (no FFT)
72 | assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])
73 | prod = torch.mm(bcirc(A), bcirc(B)[:, 0:B.shape[2]])
74 | return prod.reshape(A.shape[0], A.shape[1], B.shape[2])
75 |
76 | def t_product_fft_v2(A, B):
77 |     assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])
78 |     # FFT along the tube dimension once, multiply frontal slices, then invert
79 |     fft_A, fft_B = np.fft.fft(A, axis=0), np.fft.fft(B, axis=0)
80 |     fft_C = np.stack([fft_A[k, ...] @ fft_B[k, ...] for k in range(A.shape[0])])
81 |     return np.real(np.fft.ifft(fft_C, axis=0))
82 |
83 | def scalar_tubal_func(output_tensor):  # tubal softmax: (l, m, n) -> (m, n)
84 | l, m, n = output_tensor.shape
85 | lateral_slices = [output_tensor[:, :, i].reshape(l, m, 1) for i in range(n)]
86 | h_slice = []
87 | for s in lateral_slices:
88 | h_slice.append(h_func_dct(s))
89 | pro_matrix = torch.stack(h_slice, dim=2)
90 | return pro_matrix.reshape(m, n)
91 |
92 | def h_func_dct(lateral_slice):
93 | l, m, n = lateral_slice.shape
94 | dct_slice = dct.dct(lateral_slice)
95 | tubes = [dct_slice[i, :, 0] for i in range(l)]
96 | h_tubes = []
97 | for tube in tubes:
98 | h_tubes.append(torch.exp(tube) / torch.sum(torch.exp(tube)))
99 | res_slice = torch.stack(h_tubes, dim=0).reshape(l, m, n)
100 | idct_a = dct.idct(res_slice)
101 | return torch.sum(idct_a, dim=0)
102 |
103 | def torch_apply(func, x):
104 | x = func(torch.transpose(x, 0, 2))
105 | return torch.transpose(x, 0, 2)
106 |
107 | def make_weights(shape, device='cpu', scale=0.01):
108 | w = torch.randn(shape[0], 10, shape[1]) * scale
109 | b = torch.randn(shape[0], 10, shape[2]) * scale
110 | dct_w = torch_apply(dct.dct, w).to(device)
111 | dct_b = torch_apply(dct.dct, b).to(device)
112 | return dct_w, dct_b
113 |
114 | def to_categorical(y, num_classes):
115 | categorical = torch.empty(len(y), num_classes)
116 | for i in range(len(y)):
117 | categorical[i, :] = torch.eye(num_classes, num_classes)[y[i]]
118 | return categorical
119 |
120 | def torch_shift(A):  # (batch, m, n) -> (m, n, batch)
121 | x = A.squeeze()
122 | x = torch.transpose(x, 0, 2)
123 | x = torch.transpose(x, 0, 1)
124 | return x
125 |
126 | def cifar_img_process(raw_img):
127 | k, l, m, n = raw_img.shape
128 | img_list = torch.split(raw_img, split_size_or_sections=1, dim=0)
129 |     slices = []
130 |     for img in img_list:
131 |         img = img.reshape(l, m, n)
132 |         frontal = torch.cat([img[i, :, :] for i in range(l)], dim=0)
133 |         single_img = torch.transpose(frontal.reshape(1, l * m, n), 0, 2)
134 |         slices.append(single_img)
135 |     ultra_img = torch.cat(slices, dim=2)
136 | return ultra_img
137 |
138 | def raw_img(img, batch_size, n):
139 | img_raw = img.reshape(batch_size, n * n)
140 | single_img = torch.split(img_raw, split_size_or_sections=1, dim=0)
141 | single_img_T = [torch.transpose(i.reshape(n, n, 1), 0, 1) for i in single_img]
142 | ultra_img = torch.cat(single_img_T, dim=2)
143 | return ultra_img
144 |
--------------------------------------------------------------------------------