├── .gitignore ├── LICENSE ├── README.md ├── asset ├── CIFAR10_test.png ├── CIRAR10_train.png ├── MNIST_test.png ├── MNIST_train.png ├── cifar.png ├── conv_transform_Fashion.png ├── conv_transform_MNIST.png ├── convolutional_transform_CIFAR10.png ├── mnist.png ├── transform_based_CIFAR10.png └── transfrom_based_MNIST.png ├── common ├── __init__.py ├── builder.py ├── loader.py ├── main.py ├── test.py └── train.py ├── decomposition ├── __init__.py ├── conv_layer.py ├── decompose_all.py ├── fc_layer.py └── tensor_ring.py ├── examples ├── distributed.ipynb ├── ex1.ipynb ├── multi_example.py ├── multi_threading_net.ipynb ├── multi_threading_net.py ├── multiprocess_net.ipynb └── parallel.py ├── nets ├── __init__.py ├── lenet.py └── vgg.py ├── notebooks ├── cp_demo.ipynb ├── eval.ipynb └── main.ipynb └── transform_based_network ├── __init__.py ├── multiprocess.py ├── new_transform.ipynb ├── tNN.py ├── trainer.py ├── transform_layer.py ├── transform_nets.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.i 2 | *.ii 3 | *.gpu 4 | *.ptx 5 | *.cubin 6 | *.fatbin 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Xiao-Yang Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensor_Layer_for_Deep_Neural_Network_Compression 2 | Apply CP, Tucker, TT/TR, HT to compress neural networks. Train from scratch. 3 | 4 | ## Usage 5 | 6 | First, import required modules by 7 | ``` 8 | from common import * 9 | from decomposition import * 10 | from nets import * 11 | ``` 12 | 13 | Then, specify a neural network model. The user can choose from any model provided by the nets package, or can define a new architecture using pytorch. Examples: 14 | ``` 15 | model0 = LeNet() 16 | model1 = VGG('VGG16') 17 | ``` 18 | 19 | Finally, go through the training and testing process using the `run_all` function. 
20 | The function has a few parameters:
21 | * `dataset`: choose a dataset from mnist, cifar10, and cifar100
22 | * `model`: the neural network model defined in the last step
23 | * `decomp`: the method of decomposition; defaults to `None` (undecomposed)
24 | * `i`: number of iterations for training; defaults to 100
25 | * `rate`: learning rate; defaults to 0.05
26 | * `transform_based`: whether to use the transform-based network; defaults to False.
27 |   * To use the transform-based net, it is advised to run a small number of epochs, as the training process is much slower than for convolutional nets.
28 |   * The decomposition option for the transform-based net is currently under construction, so it is temporarily disabled.
29 | 
30 | The example below runs the CP-decomposed LeNet on the MNIST dataset for 150 iterations with a learning rate of 0.1:
31 | ```
32 | run_all('mnist', model0, decomp='cp', i=150, rate=0.1)
33 | ```
34 | This function creates three subdirectories:
35 | * `data`: stores the dataset
36 | * `models`: stores the trained network
37 | * `curves`: stores arrays of training and testing accuracy across iterations in `.npy` format
38 | 
39 | 
40 | ## Method
41 | I aim to decompose the neural network in both the convolutional portion and the fully connected portion, using popular tensor decomposition algorithms such as CP, Tucker, TT and HT. In doing so, I hope to speed up both the training and the inference process and reduce the number of parameters without significant sacrifices in accuracy.
42 | 
43 | ## CP
44 | CP decomposition works fine for classifying the MNIST dataset: it can compress the network without significant loss in accuracy compared to the original, uncompressed network. However, as noted in the paper by Lebedev et al., CP has problems dealing with larger networks, and it is often unstable. In my experiments, the decomposition process not only takes an intolerably long time, but also consumes a lot of RAM. For a convolutional layer larger than 512 x 512, the CP decomposition becomes infeasible in terms of memory. Moreover, the CP-decomposed network is highly sensitive to the learning rate, and requires a learning rate as small as 1e-5 for learning to take place.
45 | 
46 | ## Tucker
47 | Tucker decomposition is strictly superior to CP in almost every way. It has more success decomposing larger networks, and requires fewer resources in terms of runtime and memory. It is also more tolerant of larger learning rates, allowing the network to learn faster. The network decomposed with Tucker also learns faster, i.e., yields greater accuracy in fewer epochs (see the analysis of the performance graphs for details).
48 | 
49 | ## Tensor Train (TT)
50 | In my implementation of compression using tensor train, I picked out the two dimensions of the convolutional layer corresponding to the input/output channels, matricized the tensor, decomposed the result into a matrix product state, and reshaped the factors back into 4-dimensional tensors. This yields two decomposed convolutional layers for every convolutional layer in the original network; a minimal sketch of the procedure is given below. Experimentally, this method yields better results than Tucker, with a similar compression rate and speedup. In the two papers by Novikov et al., the authors propose transforming to a higher-order tensor before applying TT decomposition.
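As a minimal sketch of this reshaping trick (mirroring `tt_decomposition_conv_layer` in `decomposition/conv_layer.py`; the toy sizes and the `[1, r, 1]` rank-list format are assumptions about the installed tensorly version):
```
import torch
import tensorly as tl
from tensorly.decomposition import matrix_product_state

tl.set_backend('pytorch')

weight = torch.randn(64, 32, 3, 3)        # (out_channels, in_channels, k, k)
mat = tl.base.unfold(weight, 0)           # matricize along the output-channel mode: (64, 288)
first, last = matrix_product_state(mat, rank=[1, 16, 1])   # two MPS cores

# fold the two cores back into conv kernels
w_out = first.reshape(64, 16, 1, 1)       # a 1x1 conv: TT rank -> output channels
w_in = last.reshape(16, 32, 3, 3)         # a 3x3 conv: input channels -> TT rank
print(w_out.shape, w_in.shape)
```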
51 | 
52 | ## Tensor Ring (TR) UNDER CONSTRUCTION
53 | TR decomposition is highly similar to TT, differing only in an additional non-trivial mode on the first and last tensor cores. The way it is applied to neural networks is also similar, although researchers argue that TR has greater expressive power.
54 | 
55 | ## Hierarchical Tucker (HT)
56 | Not yet developed.
57 | 
58 | ## Transform-based networks
59 | This network is trained in the transform domain: the weights and the training data are passed into the network after applying a tensor product between them. The outputs are evaluated against the tubal version of the softmax objective function after an inverse transformation (idct) in the last layer. Backpropagation is handled by PyTorch's built-in autograd functions. A small sketch of the underlying tensor product is given below.
60 | 
61 | To see how the fully-connected transform-based network runs on MNIST, run the demo in new_transform.ipynb.
62 | 
63 | ## Experiments
64 | I tested the performance of the three compression methods against the uncompressed network on the MNIST and CIFAR10 datasets. I tried to keep all hyperparameters the same across tests, including rank, number of epochs, and learning rate. However, as CP is highly sensitive to the learning rate, I gave it a much smaller one.
65 | 
66 | 
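The tensor product mentioned in the transform-based section is, following Newman et al., a cosine-transform t-product: the tubes (mode-3 fibers) are taken into the DCT domain, the frontal slices are multiplied there, and the result is transformed back. A minimal NumPy sketch of the idea (the shapes and the use of `scipy.fft` are illustrative assumptions, not the repo's API):
```
import numpy as np
from scipy.fft import dct, idct

def t_product_dct(A, B):
    # A: (n1, n2, T), B: (n2, n3, T) -> C: (n1, n3, T)
    A_hat = dct(A, axis=-1, norm='ortho')            # into the transform domain
    B_hat = dct(B, axis=-1, norm='ortho')
    C_hat = np.einsum('ijt,jkt->ikt', A_hat, B_hat)  # slice-wise matrix products
    return idct(C_hat, axis=-1, norm='ortho')        # inverse transform (idct)

A = np.random.randn(4, 3, 8)
B = np.random.randn(3, 5, 8)
print(t_product_dct(A, B).shape)   # (4, 5, 8)
```
The repo's trainable version of this product lives in the transform_based_network package.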
67 | Figure 1. Training accuracy comparison on the MNIST dataset.
68 | 69 |
70 | Figure 2. Testing accuracy comparison on the MNIST dataset.
71 | 72 | From these performance graphs, we can see that even though the CP-decomposed network has higher training accuracy at the end, its testing accuracy is low, likely a result of overfitting due to its much smaller learning rate. The TT-decomposed network learns faster than Tucker and yields better results. In terms of run time, the four networks do not differ significantly. 73 | 74 |
75 | Figure 3. Training accuracy comparison on the CIFAR10 dataset.
76 | 77 |
78 | Figure 4. Testing accuracy comparison on the CIFAR10 dataset.
79 | 80 |
81 | Figure 5. Training and testing accuracy of the transform-based net on the MNIST dataset.
82 | 
83 | For the uncompressed network, the average time for each epoch is around 38 seconds, the average time for the Tucker-decomposed network is 26 seconds, and the average time for the TT-decomposed network is 27 seconds. In terms of accuracy, the TT-decomposed network outperforms Tucker in both training and testing, and is almost comparable to the original network before compression.
84 | 
85 | ## Profiling
86 | A typical training run gives the following cProfile output (top 20 functions by internal time):
87 | ```
88 | 510155235 function calls (507136221 primitive calls) in 2053.806 seconds
89 | 
90 | Ordered by: internal time
91 | List reduced from 824 to 20 due to restriction <20>
92 | 
93 | ncalls tottime percall cumtime percall filename:lineno(function)
94 | 49401 743.334 0.015 743.334 0.015 {method 'item' of 'torch._C._TensorBase' objects}
95 | 19550 222.799 0.011 222.799 0.011 {method 'run_backward' of 'torch._C._EngineBase' objects}
96 | 5747602 107.058 0.000 107.058 0.000 {method 'add_' of 'torch._C._TensorBase' objects}
97 | 1183200 102.777 0.000 102.777 0.000 {built-in method conv2d}
98 | 3010000 59.049 0.000 140.632 0.000 functional.py:192(normalize)
99 | 3010000 45.654 0.000 205.622 0.000 functional.py:43(to_tensor)
100 | 1915802 45.251 0.000 45.251 0.000 {method 'mul_' of 'torch._C._TensorBase' objects}
101 | 394400 41.219 0.000 41.219 0.000 {built-in method batch_norm}
102 | 3010000 40.894 0.000 40.894 0.000 {method 'tobytes' of 'numpy.ndarray' objects}
103 | 1915850 39.603 0.000 39.603 0.000 {method 'zero_' of 'torch._C._TensorBase' objects}
104 | 3010000 32.589 0.000 32.589 0.000 {method 'div' of 'torch._C._TensorBase' objects}
105 | 3010000 25.541 0.000 25.541 0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
106 | 6020000 25.312 0.000 25.312 0.000 {built-in method as_tensor}
107 | 3010000 24.033 0.000 24.033 0.000 {method 'sub_' of 'torch._C._TensorBase' objects}
108 | 3010000 20.338 0.000 116.898 0.000 Image.py:2644(fromarray)
109 | 6020000 19.078 0.000 19.078 0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
110 | 3034650 18.921 0.000 18.921 0.000 {method 'view' of 'torch._C._TensorBase' objects}
111 | 3010000 18.467 0.000 18.467 0.000 {method 'float' of 'torch._C._TensorBase' objects}
112 | 3010000 15.967 0.000 15.967 0.000 {method 'clone' of 'torch._C._TensorBase' objects}
113 | 19550 15.331 0.001 168.502 0.009 sgd.py:71(step)
114 | ```
115 | Most of the time is spent in `.item()` calls, which force host-device synchronization, and in the backward pass.
116 | ## References
117 | ### List of relevant papers:
118 | 
119 | * Lebedev, V., Ganin, Y., Rakhuba, M., Oseledets, I. and Lempitsky, V., 2015. Speeding-up convolutional neural networks using fine-tuned CP-decomposition. In 3rd International Conference on Learning Representations, ICLR 2015-Conference Track Proceedings.
120 |   * *Notes: applies CP to convlayers.*
121 | 
122 | * Kim, Y.D., Park, E., Yoo, S., Choi, T., Yang, L. and Shin, D., 2015. Compression of Deep Convolutional Neural Networks for Fast and Low Power Mobile Applications. arXiv, pp.arXiv-1511.
123 |   * *Notes: applies Tucker to convlayers*
124 | 
125 | * Garipov, T., Podoprikhin, D., Novikov, A. and Vetrov, D., 2016. Ultimate tensorization: compressing convolutional and FC layers alike. arXiv, pp.arXiv-1611.
126 |   * *Notes: applies TT to both conv and FC layers*
127 | 
128 | * Novikov, A., Podoprikhin, D., Osokin, A. and Vetrov, D.P., 2015. Tensorizing neural networks. In Advances in neural information processing systems (pp. 442-450).
129 |   * *Notes: applies TT to FC layers*
130 | 
131 | * Wang, W., Sun, Y., Eriksson, B., Wang, W. and Aggarwal, V., 2018. 
Wide compression: Tensor ring nets. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 9329-9338). 132 | * *Notes: applies TR to both conv and FC layers* 133 | 134 | * Cohen, N., Sharir, O., Levine, Y., Tamari, R., Yakira, D. and Shashua, A., 2017. Analysis and Design of Convolutional Networks via Hierarchical Tensor Decompositions. arXiv, pp.arXiv-1705. 135 | * *Notes: applies HT to convlayers* 136 | 137 | * Yang, Y., Krompass, D. and Tresp, V., 2017. Tensor-train recurrent neural networks for video classification. arXiv preprint arXiv:1707.01786. 138 | * *Notes: applies TT to sequential models* 139 | 140 | * Yin, M., Liao, S., Liu, X.Y., Wang, X. and Yuan, B., 2020. Compressing Recurrent Neural Networks Using Hierarchical Tucker Tensor Decomposition. arXiv, pp.arXiv-2005. 141 | * *Notes: applies HT to LSTMs* 142 | 143 | * Newman, Elizabeth, et al. "Stable tensor neural networks for rapid deep learning." arXiv preprint arXiv:1811.06569 (2018). 144 | * *Notes: transform-based tensor neural network* 145 | 146 | ### Related Github repos: 147 | 148 | https://github.com/jacobgil/pytorch-tensor-decompositions 149 | 150 | https://github.com/JeanKossaifi/tensorly-notebooks 151 | 152 | https://github.com/vadim-v-lebedev/cp-decomposition 153 | 154 | https://github.com/timgaripov/TensorNet-TF 155 | -------------------------------------------------------------------------------- /asset/CIFAR10_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/CIFAR10_test.png -------------------------------------------------------------------------------- /asset/CIRAR10_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/CIRAR10_train.png -------------------------------------------------------------------------------- /asset/MNIST_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/MNIST_test.png -------------------------------------------------------------------------------- /asset/MNIST_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/MNIST_train.png -------------------------------------------------------------------------------- /asset/cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/cifar.png -------------------------------------------------------------------------------- /asset/conv_transform_Fashion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/conv_transform_Fashion.png -------------------------------------------------------------------------------- /asset/conv_transform_MNIST.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/conv_transform_MNIST.png -------------------------------------------------------------------------------- /asset/convolutional_transform_CIFAR10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/convolutional_transform_CIFAR10.png -------------------------------------------------------------------------------- /asset/mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/mnist.png -------------------------------------------------------------------------------- /asset/transform_based_CIFAR10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/transform_based_CIFAR10.png -------------------------------------------------------------------------------- /asset/transfrom_based_MNIST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangletLiu/Tensor_Layer_for_Deep_Neural_Network_Compression/2fe88a989501e4c1f5e17d05873efe6906f45c55/asset/transfrom_based_MNIST.png -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import * 2 | from .test import * 3 | from .loader import * 4 | from .builder import * 5 | from .main import * -------------------------------------------------------------------------------- /common/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tensorly as tl 3 | import os 4 | import sys 5 | sys.path.append('..') 6 | from decomposition import * 7 | 8 | 9 | def build(model, decomp='cp'): 10 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 11 | print('==> Building model..') 12 | tl.set_backend('pytorch') 13 | full_net = model 14 | full_net = full_net.to(device) 15 | 16 | path = 'models/' 17 | if not os.path.exists(path): 18 | os.mkdir(path) 19 | torch.save(full_net, path + 'model') 20 | if decomp: 21 | decompose_conv(decomp) 22 | decompose_fc(decomp) 23 | if device == 'cuda:0': 24 | net = torch.load(path + "model").cuda() 25 | else: 26 | net = torch.load(path + "model") 27 | print(net) 28 | print('==> Done') 29 | return net -------------------------------------------------------------------------------- /common/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from torchvision import models 5 | import os 6 | 7 | 8 | def load_mnist(): 9 | print('==> Loading data..') 10 | transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 11 | transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 12 | 13 | trainset = 
torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train) 14 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0) 15 | 16 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test) 17 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0) 18 | return trainloader, testloader 19 | 20 | def load_fashion_mnist(): 21 | print('==> Loading data..') 22 | transform_train = transforms.Compose([transforms.ToTensor()]) 23 | transform_test = transforms.Compose([transforms.ToTensor()]) 24 | 25 | trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train) 26 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0) 27 | 28 | testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_test) 29 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0) 30 | return trainloader, testloader 31 | 32 | def load_cifar10(): 33 | print('==> Loading data..') 34 | transform_train = transforms.Compose([ 35 | transforms.RandomCrop(32, padding=4), 36 | transforms.RandomHorizontalFlip(), 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 39 | ]) 40 | 41 | transform_test = transforms.Compose([ 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 44 | ]) 45 | 46 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 47 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0) 48 | 49 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 50 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0) 51 | 52 | return trainloader, testloader 53 | 54 | def load_cifar100(): 55 | print('==> Loading data..') 56 | transform_train = transforms.Compose([ 57 | transforms.RandomCrop(32, padding=4), 58 | transforms.RandomHorizontalFlip(), 59 | transforms.ToTensor(), 60 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 61 | ]) 62 | 63 | transform_test = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 66 | ]) 67 | 68 | trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) 69 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0) 70 | 71 | testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test) 72 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0) 73 | 74 | return trainloader, testloader 75 | 76 | def load_mnist_multiprocess(override=0): 77 | print('==> Loading data..') 78 | if override: 79 | cpu_count = override 80 | else: 81 | cpu_count = os.cpu_count() 82 | transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 83 | transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 84 | 85 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train) 86 | 
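# NOTE: batch_size equals the CPU count below, presumably so each worker process handles one sample per batch in the multiprocess demo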
trainloader = torch.utils.data.DataLoader(trainset, batch_size=cpu_count, shuffle=True, num_workers=0)
87 | 
88 |     testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)
89 |     testloader = torch.utils.data.DataLoader(testset, batch_size=cpu_count, shuffle=False, num_workers=0)
90 |     return trainloader, testloader
-------------------------------------------------------------------------------- /common/main.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 | 
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | from torchvision import models
11 | 
12 | import tensorly as tl
13 | import tensorly
14 | from itertools import chain
15 | from tensorly.decomposition import parafac, partial_tucker
16 | 
17 | import os
18 | import matplotlib.pyplot as plt
19 | import numpy as np
20 | import time
21 | 
22 | import sys
23 | sys.path.append('..')
24 | from common import *
25 | from decomposition import *
26 | from nets import *
27 | from transform_based_network import *
28 | 
29 | 
30 | # main function
31 | def run_all(dataset, model, decomp=None, i=100, rate=0.05, transform_based=False):
32 | 
33 |     # choose a dataset from (MNIST, CIFAR10, CIFAR100)
34 |     if dataset == 'mnist':
35 |         trainloader, testloader = load_mnist()
36 |     if dataset == 'cifar10':
37 |         trainloader, testloader = load_cifar10()
38 |     if dataset == 'cifar100':
39 |         trainloader, testloader = load_cifar100()
40 | 
41 |     # choose a decomposition algorithm from (CP, Tucker, TT)
42 |     if not transform_based:
43 |         net = build(model, decomp)
44 |         optimizer = optim.SGD(net.parameters(), lr=rate, momentum=0.9, weight_decay=5e-4)
45 |         train_acc, test_acc = train(i, net, trainloader, testloader, optimizer)
46 |         _, inf_time = test([], net, testloader)
47 | 
48 |         if not decomp:
49 |             decomp = 'full'
50 |         filename = dataset + '_' + decomp
51 |     else:
52 |         net = Transform_Net()
53 |         optimizer = optim.SGD(net.parameters(), lr=rate, momentum=0.9, weight_decay=5e-4)
54 |         train_acc, test_acc = train_transform(25, net, trainloader, testloader, optimizer)  # train the same net the optimizer holds
55 |         filename = dataset + '_transform'  # results for this branch are saved under '<dataset>_transform'
56 | 
57 |     torch.save(net, 'models/' + filename)
58 |     path = 'curves/'
59 |     if not os.path.exists(path):
60 |         os.mkdir(path)
61 | 
62 |     np.save(path + filename + '_train', train_acc)
63 |     np.save(path + filename + '_test', test_acc)
-------------------------------------------------------------------------------- /common/test.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 | 
8 | import numpy as np
9 | import time
10 | 
11 | def test(test_acc, model, testloader):
12 |     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
13 |     model.eval()
14 |     test_loss = 0
15 |     correct = 0
16 |     total = 0
17 |     criterion = nn.CrossEntropyLoss()
18 |     s = time.time()
19 |     with torch.no_grad():
20 |         print('|', end='')
21 |         for batch_idx, (inputs, targets) in enumerate(testloader):
22 |             inputs, targets = inputs.to(device), targets.to(device)
23 |             outputs = model(inputs)
24 |             loss = criterion(outputs, targets)
25 |             test_loss += loss.item()
26 |             _, predicted = outputs.max(1)
27 |             total += targets.size(0)
28 |             correct += predicted.eq(targets).sum().item()
29 |             if batch_idx % 10 == 0:
30 |                 print('=', end='')
31 |     e = time.time()
32 |     acc = 100. * correct / total
33 |     print('|', 'Accuracy:', acc, '% ', correct, '/', total)
34 |     print('The inference time is', e - s, 'seconds')
35 |     test_acc.append(correct / total)
36 |     return test_acc, e - s
-------------------------------------------------------------------------------- /common/train.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 | 
8 | import numpy as np
9 | import time
10 | from .test import *
11 | 
12 | def train_step(epoch, train_acc, model, trainloader, optimizer):
13 |     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
14 |     print('\nEpoch: ', epoch)
15 |     model.train()
16 |     criterion = nn.CrossEntropyLoss()
17 |     train_loss = 0
18 |     correct = 0
19 |     total = 0
20 |     print('|', end='')
21 |     for batch_idx, (inputs, targets) in enumerate(trainloader):
22 |         inputs, targets = inputs.to(device), targets.to(device)
23 |         optimizer.zero_grad()
24 |         outputs = model(inputs)
25 |         # print(outputs.shape, targets.shape)  # optional shape trace; left disabled to keep the log readable
26 |         loss = criterion(outputs, targets)
27 |         loss.backward()
28 |         optimizer.step()
29 |         train_loss += loss.item()
30 |         _, predicted = outputs.max(1)
31 |         total += targets.size(0)
32 |         correct += predicted.eq(targets).sum().item()
33 |         if batch_idx % 10 == 0:
34 |             print('=', end='')
35 |     print('|', 'Accuracy:', 100. * correct / total,'% ', correct, '/', total)
36 |     train_acc.append(correct / total)
37 |     return train_acc
38 | 
39 | def train(i, model, trainloader, testloader, optimizer):
40 |     train_acc = []
41 |     test_acc = []
42 |     scheduler = StepLR(optimizer, step_size=5, gamma=0.9)
43 |     for epoch in range(i):
44 |         s = time.time()
45 |         train_acc = train_step(epoch, train_acc, model, trainloader, optimizer)
46 |         test_acc, _ = test(test_acc, model, testloader)
47 |         scheduler.step()
48 |         e = time.time()
49 |         print('This epoch took', e - s, 'seconds to train')
50 |         print('Current learning rate: ', scheduler.get_last_lr()[0])
51 |     print('Best testing accuracy overall: ', max(test_acc))
52 |     return train_acc, test_acc
-------------------------------------------------------------------------------- /decomposition/__init__.py: --------------------------------------------------------------------------------
1 | from .conv_layer import *
2 | from .fc_layer import *
3 | from .decompose_all import *
4 | from .tensor_ring import *
-------------------------------------------------------------------------------- /decomposition/conv_layer.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 | 
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | from torchvision import models
11 | 
12 | import tensorly as tl
13 | import tensorly
14 | from itertools import chain
15 | from tensorly.decomposition import parafac, partial_tucker, matrix_product_state
16 | 
17 | import os
18 | import matplotlib.pyplot as plt
19 | import numpy as np
20 | import time
21 | 
22 | 
23 | def cp_decomposition_conv_layer(layer, rank):
24 |     l, f, v, h = parafac(layer.weight.data, rank=rank)[1]
25 | 
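# in this tensorly version, parafac returns (weights, factors); the four factor matrices correspond to the output channels (l), input channels (f), and the vertical (v) / horizontal (h) spatial modes of the kernel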
factors = [l, f, v, h] 26 | #print([f.shape for f in factors]) 27 | 28 | pointwise_s_to_r_layer = torch.nn.Conv2d( 29 | in_channels=f.shape[0], 30 | out_channels=f.shape[1], 31 | kernel_size=1, 32 | stride=1, 33 | padding=0, 34 | dilation=layer.dilation, 35 | bias=False) 36 | 37 | depthwise_vertical_layer = torch.nn.Conv2d( 38 | in_channels=v.shape[1], 39 | out_channels=v.shape[1], 40 | kernel_size=(v.shape[0], 1), 41 | stride=1, padding=(layer.padding[0], 0), 42 | dilation=layer.dilation, 43 | groups=v.shape[1], 44 | bias=False) 45 | 46 | depthwise_horizontal_layer = torch.nn.Conv2d( 47 | in_channels=h.shape[1], 48 | out_channels=h.shape[1], 49 | kernel_size=(1, h.shape[0]), 50 | stride=layer.stride, 51 | padding=(0, layer.padding[0]), 52 | dilation=layer.dilation, 53 | groups=h.shape[1], 54 | bias=False) 55 | 56 | pointwise_r_to_t_layer = torch.nn.Conv2d( 57 | in_channels=l.shape[1], 58 | out_channels=l.shape[0], 59 | kernel_size=1, 60 | stride=1, 61 | padding=0, 62 | dilation=layer.dilation, 63 | bias=True) 64 | 65 | pointwise_r_to_t_layer.bias.data = layer.bias.data 66 | depthwise_horizontal_layer.weight.data = torch.transpose(h, 1, 0).unsqueeze(1).unsqueeze(1) 67 | depthwise_vertical_layer.weight.data = torch.transpose(v, 1, 0).unsqueeze(1).unsqueeze(-1) 68 | pointwise_s_to_r_layer.weight.data = torch.transpose(f, 1, 0).unsqueeze(-1).unsqueeze(-1) 69 | pointwise_r_to_t_layer.weight.data = l.unsqueeze(-1).unsqueeze(-1) 70 | 71 | new_layers = [pointwise_s_to_r_layer, depthwise_vertical_layer, 72 | depthwise_horizontal_layer, pointwise_r_to_t_layer] 73 | #for l in new_layers: 74 | # print(l.weight.data.shape) 75 | 76 | return nn.Sequential(*new_layers) 77 | 78 | def tucker_decomposition_conv_layer(layer, ranks): 79 | core, [last, first] = partial_tucker(layer.weight.data, modes=[0, 1], ranks=ranks, init='svd') 80 | #print(core.shape, last.shape, first.shape) 81 | 82 | # A pointwise convolution that reduces the channels from S to R3 83 | first_layer = torch.nn.Conv2d(in_channels=first.shape[0], 84 | out_channels=first.shape[1], kernel_size=1, 85 | stride=1, padding=0, dilation=layer.dilation, bias=False) 86 | 87 | # A regular 2D convolution layer with R3 input channels 88 | # and R3 output channels 89 | core_layer = torch.nn.Conv2d(in_channels=core.shape[1], 90 | out_channels=core.shape[0], kernel_size=layer.kernel_size, 91 | stride=layer.stride, padding=layer.padding, dilation=layer.dilation, bias=False) 92 | 93 | # A pointwise convolution that increases the channels from R4 to T 94 | last_layer = torch.nn.Conv2d(in_channels=last.shape[1], \ 95 | out_channels=last.shape[0], kernel_size=1, stride=1, 96 | padding=0, dilation=layer.dilation, bias=True) 97 | 98 | last_layer.bias.data = layer.bias.data 99 | 100 | first_layer.weight.data = torch.transpose(first, 1, 0).unsqueeze(-1).unsqueeze(-1) 101 | last_layer.weight.data = last.unsqueeze(-1).unsqueeze(-1) 102 | core_layer.weight.data = core 103 | 104 | new_layers = [first_layer, core_layer, last_layer] 105 | #for l in new_layers: 106 | # print(l.weight.data.shape) 107 | return nn.Sequential(*new_layers) 108 | 109 | def tt_decomposition_conv_layer(layer, ranks): 110 | data = layer.weight.data 111 | data2D = tl.base.unfold(data, 0) 112 | 113 | first, last = matrix_product_state(data2D, rank=ranks) 114 | factors = [first, last] 115 | #print([f.shape for f in factors]) 116 | 117 | first = first.reshape(data.shape[0], ranks, 1, 1) 118 | last = last.reshape(ranks, data.shape[1], layer.kernel_size[0], layer.kernel_size[1]) 119 | 120 | 
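# fold the two tensor-train cores back into convolutions: 'last' becomes a k x k conv taking the input channels to the TT rank, and 'first' becomes a 1x1 conv taking the rank to the output channels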
pointwise_s_to_r_layer = torch.nn.Conv2d( 121 | in_channels=last.shape[1], 122 | out_channels=last.shape[0], 123 | kernel_size=layer.kernel_size, 124 | stride=layer.stride, 125 | padding=layer.padding, 126 | dilation=layer.dilation, 127 | bias=False) 128 | 129 | pointwise_r_to_t_layer = torch.nn.Conv2d( 130 | in_channels=first.shape[1], 131 | out_channels=first.shape[0], 132 | kernel_size=1, 133 | stride=1, 134 | padding=0, 135 | dilation=layer.dilation, 136 | bias=True) 137 | 138 | pointwise_r_to_t_layer.bias.data = layer.bias.data 139 | pointwise_s_to_r_layer.weight.data = last 140 | pointwise_r_to_t_layer.weight.data = first 141 | 142 | new_layers = [pointwise_s_to_r_layer, pointwise_r_to_t_layer] 143 | #for l in new_layers: 144 | # print(l.weight.data.shape) 145 | 146 | return nn.Sequential(*new_layers) -------------------------------------------------------------------------------- /decomposition/decompose_all.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from torchvision import models 5 | import tensorly as tl 6 | import tensorly 7 | from itertools import chain 8 | 9 | from .conv_layer import * 10 | from .fc_layer import * 11 | 12 | # decompose all layers in a network 13 | def decompose_conv(decomp): 14 | model = torch.load("models/model").cuda() 15 | model.eval() 16 | model.cpu() 17 | for i, key in enumerate(model.features._modules.keys()): 18 | if i >= len(model.features._modules.keys()) - 2: 19 | break 20 | conv_layer = model.features._modules[key] 21 | if isinstance(conv_layer, torch.nn.modules.conv.Conv2d): 22 | rank = max(conv_layer.weight.data.numpy().shape) // 3 23 | if decomp == 'cp': 24 | model.features._modules[key] = cp_decomposition_conv_layer(conv_layer, rank) 25 | if decomp == 'tucker': 26 | ranks = [int(np.ceil(conv_layer.weight.data.numpy().shape[0] / 3)), 27 | int(np.ceil(conv_layer.weight.data.numpy().shape[1] / 3))] 28 | model.features._modules[key] = tucker_decomposition_conv_layer(conv_layer, ranks) 29 | if decomp == 'tt': 30 | model.features._modules[key] = tt_decomposition_conv_layer(conv_layer, rank) 31 | torch.save(model, 'models/model') 32 | 33 | def decompose_fc(decomp): 34 | model = torch.load("models/model").cuda() 35 | model.eval() 36 | model.cpu() 37 | for i, key in enumerate(model.classifier._modules.keys()): 38 | linear_layer = model.classifier._modules[key] 39 | if isinstance(linear_layer, torch.nn.modules.linear.Linear): 40 | rank = min(linear_layer.weight.data.numpy().shape) // 2 41 | if decomp == 'tucker': 42 | model.classifier._modules[key] = tucker_decomposition_fc_layer(linear_layer, rank) 43 | else: 44 | model.classifier._modules[key] = decomposition_fc_layer(linear_layer, rank) 45 | torch.save(model, 'models/model') 46 | return model -------------------------------------------------------------------------------- /decomposition/fc_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | from torch.optim.lr_scheduler import StepLR 7 | 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | from torchvision import models 11 | 12 | import tensorly as tl 13 | import tensorly 14 | from itertools import chain 15 | from tensorly.decomposition import parafac, tucker, matrix_product_state 16 | 17 | import os 18 | import matplotlib.pyplot 
as plt
19 | import numpy as np
20 | import time
21 | 
22 | 
23 | def decomposition_fc_layer(layer, rank):
24 |     l, r = matrix_product_state(layer.weight.data, rank=rank)
25 |     l, r = l.squeeze(), r.squeeze()
26 | 
27 |     right_layer = torch.nn.Linear(r.shape[1], r.shape[0])
28 |     left_layer = torch.nn.Linear(l.shape[1], l.shape[0])
29 | 
30 |     left_layer.bias.data = layer.bias.data
31 |     right_layer.bias.data.zero_()  # the original layer has a single bias; keep the intermediate one at zero
32 |     left_layer.weight.data = l
33 |     right_layer.weight.data = r
34 | 
35 |     new_layers = [right_layer, left_layer]
36 |     return nn.Sequential(*new_layers)
37 | 
38 | 
39 | def tucker_decomposition_fc_layer(layer, rank):
40 |     core, [l, r] = tucker(layer.weight.data, rank=rank)
41 | 
42 |     right_layer = torch.nn.Linear(r.shape[0], r.shape[1])
43 |     core_layer = torch.nn.Linear(core.shape[1], core.shape[0])
44 |     left_layer = torch.nn.Linear(l.shape[1], l.shape[0])
45 | 
46 |     left_layer.bias.data = layer.bias.data
47 |     right_layer.bias.data.zero_()
48 |     core_layer.bias.data.zero_()
49 |     left_layer.weight.data = l
50 |     core_layer.weight.data = core  # the Tucker core supplies the middle layer's weights
51 |     right_layer.weight.data = r.T
52 | 
53 |     new_layers = [right_layer, core_layer, left_layer]
54 |     return nn.Sequential(*new_layers)
-------------------------------------------------------------------------------- /decomposition/tensor_ring.py: --------------------------------------------------------------------------------
1 | import tensorly as tl
2 | import torch
3 | import numpy as np
4 | 
5 | 
6 | def tensor_ring(input_tensor, rank):
7 | 
8 |     '''Tensor ring (TR) decomposition via recursive SVD
9 | 
10 |     Decomposes input_tensor into a sequence of order-3 tensors (factors),
11 |     with the input rank of the first factor equal to the output rank of the
12 |     last factor. This code is adapted from the MPS decomposition in
13 |     tensorly's source code.
14 | 
15 |     Parameters
16 |     ----------
17 |     input_tensor : tensorly.tensor
18 |     rank : {int, int list}
19 |         maximum allowable rank of the factors
20 |         if int, then this is the same for all the factors
21 |         if int list, then rank[k] is the rank of the kth factor
22 | 
23 |     Returns
24 |     -------
25 |     factors : Tensor ring factors
26 |         order-3 tensors of the tensor ring decomposition
27 |     '''
28 | 
29 |     # Check user input for errors
30 |     tensor_size = input_tensor.shape
31 |     n_dim = len(tensor_size)
32 | 
33 |     if isinstance(rank, int):
34 |         rank = [rank] * n_dim
35 |     elif n_dim != len(rank):
36 |         message = 'Provided incorrect number of ranks. 
' 37 | raise(ValueError(message)) 38 | rank = list(rank) 39 | 40 | # Initialization 41 | unfolding = tl.unfold(input_tensor, 0) 42 | factors = [None] * n_dim 43 | U, S, V = tl.partial_svd(unfolding, rank[0]) 44 | r0 = int(np.sqrt(rank[0])) 45 | while rank[0] % r0: 46 | r0 -= 1; 47 | T0 = tl.reshape(U, (tensor_size[0], r0, rank[0] // r0)) 48 | factors[0] = torch.transpose(torch.tensor(T0), 0, 1) 49 | unfolding = tl.reshape(S, (-1, 1)) * V 50 | rank[1] = rank[0] // r0 51 | rank.append(r0) 52 | 53 | # Getting the MPS factors up to n_dim 54 | for k in range(1, n_dim): 55 | 56 | # Reshape the unfolding matrix of the remaining factors 57 | n_row = int(rank[k]*tensor_size[k]) 58 | unfolding = tl.reshape(unfolding, (n_row, -1)) 59 | 60 | # SVD of unfolding matrix 61 | (n_row, n_column) = unfolding.shape 62 | rank[k+1] = min(n_row, n_column, rank[k+1]) 63 | U, S, V = tl.partial_svd(unfolding, rank[k+1]) 64 | 65 | # Get kth MPS factor 66 | factors[k] = tl.reshape(U, (rank[k], tensor_size[k], rank[k+1])) 67 | 68 | # Get new unfolding matrix for the remaining factors 69 | unfolding = tl.reshape(S, (-1, 1)) * V 70 | 71 | return factors -------------------------------------------------------------------------------- /examples/ex1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append('..')\n", 11 | "from common import *\n", 12 | "from decomposition import *\n", 13 | "from nets import *" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "==> Building model..\n", 26 | "LeNet(\n", 27 | " (features): Sequential(\n", 28 | " (0): Sequential(\n", 29 | " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 30 | " (1): Conv2d(1, 11, kernel_size=(3, 3), stride=(1, 1), bias=False)\n", 31 | " (2): Conv2d(11, 32, kernel_size=(1, 1), stride=(1, 1))\n", 32 | " )\n", 33 | " (1): ReLU(inplace=True)\n", 34 | " (2): Sequential(\n", 35 | " (0): Conv2d(32, 11, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 36 | " (1): Conv2d(11, 22, kernel_size=(3, 3), stride=(1, 1), bias=False)\n", 37 | " (2): Conv2d(22, 64, kernel_size=(1, 1), stride=(1, 1))\n", 38 | " )\n", 39 | " (3): ReLU(inplace=True)\n", 40 | " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 41 | " )\n", 42 | " (classifier): Sequential(\n", 43 | " (0): Dropout(p=0.25, inplace=False)\n", 44 | " (1): Sequential(\n", 45 | " (0): Linear(in_features=9216, out_features=64, bias=True)\n", 46 | " (1): Linear(in_features=64, out_features=64, bias=True)\n", 47 | " (2): Linear(in_features=64, out_features=128, bias=True)\n", 48 | " )\n", 49 | " (2): ReLU(inplace=True)\n", 50 | " (3): Dropout(p=0.5, inplace=False)\n", 51 | " (4): Sequential(\n", 52 | " (0): Linear(in_features=128, out_features=5, bias=True)\n", 53 | " (1): Linear(in_features=5, out_features=5, bias=True)\n", 54 | " (2): Linear(in_features=5, out_features=10, bias=True)\n", 55 | " )\n", 56 | " )\n", 57 | ")\n", 58 | "==> Done\n" 59 | ] 60 | }, 61 | { 62 | "name": "stderr", 63 | "output_type": "stream", 64 | "text": [ 65 | "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorly\\decomposition\\_tucker.py:63: Warning: Given only one int for 'rank' instead of a list of 2 modes. 
Using this rank for all modes.\n", 66 | " warnings.warn(message, Warning)\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "model = LeNet()\n", 72 | "decomposed = build(model, decomp='tucker')" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "==> Loading data..\n", 85 | "==> Building model..\n", 86 | "LeNet(\n", 87 | " (features): Sequential(\n", 88 | " (0): Sequential(\n", 89 | " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 90 | " (1): Conv2d(1, 11, kernel_size=(3, 3), stride=(1, 1), bias=False)\n", 91 | " (2): Conv2d(11, 32, kernel_size=(1, 1), stride=(1, 1))\n", 92 | " )\n", 93 | " (1): ReLU(inplace=True)\n", 94 | " (2): Sequential(\n", 95 | " (0): Conv2d(32, 11, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 96 | " (1): Conv2d(11, 22, kernel_size=(3, 3), stride=(1, 1), bias=False)\n", 97 | " (2): Conv2d(22, 64, kernel_size=(1, 1), stride=(1, 1))\n", 98 | " )\n", 99 | " (3): ReLU(inplace=True)\n", 100 | " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 101 | " )\n", 102 | " (classifier): Sequential(\n", 103 | " (0): Dropout(p=0.25, inplace=False)\n", 104 | " (1): Sequential(\n", 105 | " (0): Linear(in_features=9216, out_features=64, bias=True)\n", 106 | " (1): Linear(in_features=64, out_features=64, bias=True)\n", 107 | " (2): Linear(in_features=64, out_features=128, bias=True)\n", 108 | " )\n", 109 | " (2): ReLU(inplace=True)\n", 110 | " (3): Dropout(p=0.5, inplace=False)\n", 111 | " (4): Sequential(\n", 112 | " (0): Linear(in_features=128, out_features=5, bias=True)\n", 113 | " (1): Linear(in_features=5, out_features=5, bias=True)\n", 114 | " (2): Linear(in_features=5, out_features=10, bias=True)\n", 115 | " )\n", 116 | " )\n", 117 | ")\n", 118 | "==> Done\n", 119 | "\n", 120 | "Epoch: 0\n", 121 | "|torch.Size([128, 10]) torch.Size([128])\n", 122 | "=torch.Size([128, 10]) torch.Size([128])\n", 123 | "torch.Size([128, 10]) torch.Size([128])\n", 124 | "torch.Size([128, 10]) torch.Size([128])\n", 125 | "torch.Size([128, 10]) torch.Size([128])\n", 126 | "torch.Size([128, 10]) torch.Size([128])\n", 127 | "torch.Size([128, 10]) torch.Size([128])\n", 128 | "torch.Size([128, 10]) torch.Size([128])\n", 129 | "torch.Size([128, 10]) torch.Size([128])\n", 130 | "torch.Size([128, 10]) torch.Size([128])\n", 131 | "torch.Size([128, 10]) torch.Size([128])\n", 132 | "=torch.Size([128, 10]) torch.Size([128])\n", 133 | "torch.Size([128, 10]) torch.Size([128])\n", 134 | "torch.Size([128, 10]) torch.Size([128])\n", 135 | "torch.Size([128, 10]) torch.Size([128])\n", 136 | "torch.Size([128, 10]) torch.Size([128])\n", 137 | "torch.Size([128, 10]) torch.Size([128])\n", 138 | "torch.Size([128, 10]) torch.Size([128])\n", 139 | "torch.Size([128, 10]) torch.Size([128])\n", 140 | "torch.Size([128, 10]) torch.Size([128])\n", 141 | "torch.Size([128, 10]) torch.Size([128])\n", 142 | "=torch.Size([128, 10]) torch.Size([128])\n", 143 | "torch.Size([128, 10]) torch.Size([128])\n", 144 | "torch.Size([128, 10]) torch.Size([128])\n", 145 | "torch.Size([128, 10]) torch.Size([128])\n", 146 | "torch.Size([128, 10]) torch.Size([128])\n", 147 | "torch.Size([128, 10]) torch.Size([128])\n", 148 | "torch.Size([128, 10]) torch.Size([128])\n", 149 | "torch.Size([128, 10]) torch.Size([128])\n", 150 | "torch.Size([128, 10]) torch.Size([128])\n", 151 | "torch.Size([128, 10]) torch.Size([128])\n", 152 | "=torch.Size([128, 10]) 
torch.Size([128])\n",
    [... repeated "torch.Size([128, 10]) torch.Size([128])" output lines trimmed ...]
291 | "torch.Size([128, 10]) 
torch.Size([128])\n", 292 | "=torch.Size([128, 10]) torch.Size([128])\n", 293 | "torch.Size([128, 10]) torch.Size([128])\n", 294 | "torch.Size([128, 10]) torch.Size([128])\n", 295 | "torch.Size([128, 10]) torch.Size([128])\n", 296 | "torch.Size([128, 10]) torch.Size([128])\n", 297 | "torch.Size([128, 10]) torch.Size([128])\n", 298 | "torch.Size([128, 10]) torch.Size([128])\n" 299 | ] 300 | }, 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "torch.Size([128, 10]) torch.Size([128])\n", 306 | "torch.Size([128, 10]) torch.Size([128])\n", 307 | "torch.Size([128, 10]) torch.Size([128])\n", 308 | "=torch.Size([128, 10]) torch.Size([128])\n", 309 | "torch.Size([128, 10]) torch.Size([128])\n", 310 | "torch.Size([128, 10]) torch.Size([128])\n", 311 | "torch.Size([128, 10]) torch.Size([128])\n", 312 | "torch.Size([128, 10]) torch.Size([128])\n", 313 | "torch.Size([128, 10]) torch.Size([128])\n", 314 | "torch.Size([128, 10]) torch.Size([128])\n", 315 | "torch.Size([128, 10]) torch.Size([128])\n", 316 | "torch.Size([128, 10]) torch.Size([128])\n", 317 | "torch.Size([128, 10]) torch.Size([128])\n", 318 | "=torch.Size([128, 10]) torch.Size([128])\n", 319 | "torch.Size([128, 10]) torch.Size([128])\n", 320 | "torch.Size([128, 10]) torch.Size([128])\n", 321 | "torch.Size([128, 10]) torch.Size([128])\n", 322 | "torch.Size([128, 10]) torch.Size([128])\n", 323 | "torch.Size([128, 10]) torch.Size([128])\n", 324 | "torch.Size([128, 10]) torch.Size([128])\n", 325 | "torch.Size([128, 10]) torch.Size([128])\n", 326 | "torch.Size([128, 10]) torch.Size([128])\n", 327 | "torch.Size([128, 10]) torch.Size([128])\n", 328 | "=torch.Size([128, 10]) torch.Size([128])\n", 329 | "torch.Size([128, 10]) torch.Size([128])\n", 330 | "torch.Size([128, 10]) torch.Size([128])\n", 331 | "torch.Size([128, 10]) torch.Size([128])\n", 332 | "torch.Size([128, 10]) torch.Size([128])\n", 333 | "torch.Size([128, 10]) torch.Size([128])\n", 334 | "torch.Size([128, 10]) torch.Size([128])\n", 335 | "torch.Size([128, 10]) torch.Size([128])\n", 336 | "torch.Size([128, 10]) torch.Size([128])\n", 337 | "torch.Size([128, 10]) torch.Size([128])\n", 338 | "=torch.Size([128, 10]) torch.Size([128])\n", 339 | "torch.Size([128, 10]) torch.Size([128])\n", 340 | "torch.Size([128, 10]) torch.Size([128])\n", 341 | "torch.Size([128, 10]) torch.Size([128])\n", 342 | "torch.Size([128, 10]) torch.Size([128])\n", 343 | "torch.Size([128, 10]) torch.Size([128])\n", 344 | "torch.Size([128, 10]) torch.Size([128])\n", 345 | "torch.Size([128, 10]) torch.Size([128])\n", 346 | "torch.Size([128, 10]) torch.Size([128])\n", 347 | "torch.Size([128, 10]) torch.Size([128])\n", 348 | "=torch.Size([128, 10]) torch.Size([128])\n", 349 | "torch.Size([128, 10]) torch.Size([128])\n", 350 | "torch.Size([128, 10]) torch.Size([128])\n", 351 | "torch.Size([128, 10]) torch.Size([128])\n", 352 | "torch.Size([128, 10]) torch.Size([128])\n", 353 | "torch.Size([128, 10]) torch.Size([128])\n", 354 | "torch.Size([128, 10]) torch.Size([128])\n", 355 | "torch.Size([128, 10]) torch.Size([128])\n", 356 | "torch.Size([128, 10]) torch.Size([128])\n", 357 | "torch.Size([128, 10]) torch.Size([128])\n", 358 | "=torch.Size([128, 10]) torch.Size([128])\n", 359 | "torch.Size([128, 10]) torch.Size([128])\n", 360 | "torch.Size([128, 10]) torch.Size([128])\n", 361 | "torch.Size([128, 10]) torch.Size([128])\n", 362 | "torch.Size([128, 10]) torch.Size([128])\n", 363 | "torch.Size([128, 10]) torch.Size([128])\n", 364 | "torch.Size([128, 10]) torch.Size([128])\n", 365 | 
"torch.Size([128, 10]) torch.Size([128])\n", 366 | "torch.Size([128, 10]) torch.Size([128])\n", 367 | "torch.Size([128, 10]) torch.Size([128])\n", 368 | "=torch.Size([128, 10]) torch.Size([128])\n", 369 | "torch.Size([128, 10]) torch.Size([128])\n" 370 | ] 371 | }, 372 | { 373 | "ename": "KeyboardInterrupt", 374 | "evalue": "", 375 | "output_type": "error", 376 | "traceback": [ 377 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 378 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 379 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mrun_all\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'mnist'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdecomp\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'tucker'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrate\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.05\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", 380 | "\u001b[1;32m~\\Desktop\\workstation\\TNN\\common\\main.py\u001b[0m in \u001b[0;36mrun_all\u001b[1;34m(dataset, model, decomp, i, rate)\u001b[0m\n\u001b[0;32m 41\u001b[0m \u001b[0mnet\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbuild\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdecomp\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mrate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.9\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight_decay\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m5e-4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 43\u001b[1;33m \u001b[0mtrain_acc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_acc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnet\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrainloader\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtestloader\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 44\u001b[0m \u001b[0m_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minf_time\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtest\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnet\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtestloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 45\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 381 | "\u001b[1;32m~\\Desktop\\workstation\\TNN\\common\\train.py\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(i, model, trainloader, testloader, optimizer)\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 
44\u001b[0m \u001b[0ms\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 45\u001b[1;33m \u001b[0mtrain_acc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrain_acc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrainloader\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 46\u001b[0m \u001b[0mtest_acc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtest\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest_acc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtestloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 47\u001b[0m \u001b[0mscheduler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 382 | "\u001b[1;32m~\\Desktop\\workstation\\TNN\\common\\train.py\u001b[0m in \u001b[0;36mtrain_step\u001b[1;34m(epoch, train_acc, model, trainloader, optimizer)\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[0mtotal\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'|'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m''\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 21\u001b[1;33m \u001b[1;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 22\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtargets\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 23\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 383 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 343\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 344\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 345\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 346\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 347\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 384 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 383\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 384\u001b[0m \u001b[0mindex\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 385\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 386\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 387\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 385 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 45\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m 
\u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 386 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36m\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 45\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 387 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\datasets\\mnist.py\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, index)\u001b[0m\n\u001b[0;32m 95\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 96\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 97\u001b[1;33m \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 98\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 388 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, img)\u001b[0m\n\u001b[0;32m 68\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 69\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 70\u001b[1;33m 
\u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 71\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 72\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 389 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, pic)\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mConverted\u001b[0m \u001b[0mimage\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 100\u001b[0m \"\"\"\n\u001b[1;32m--> 101\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto_tensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpic\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 102\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 103\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 390 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\functional.py\u001b[0m in \u001b[0;36mto_tensor\u001b[1;34m(pic)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontiguous\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mByteTensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdiv\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m255\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 391 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: " 392 | ] 393 | } 394 | ], 395 | "source": [ 396 | "run_all('mnist', model, decomp='tucker', i=10, rate=0.05)" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "run_all('cifar10', model, decomp=None, i=10, rate=0.05)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "x = torch.rand(32, 32)\n", 415 | "x_hat = x.reshape(4, 8, 8, 4)\n", 416 | "a = 
416 | "a = matrix_product_state(x, [1, 2, 1])\n",
417 | "for i in a:\n",
418 | "    print(i.shape)"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": null,
424 | "metadata": {},
425 | "outputs": [],
426 | "source": [
427 | "trainloader, testloader = load_mnist()"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": null,
433 | "metadata": {},
434 | "outputs": [],
435 | "source": [
436 | "for i in trainloader:\n",
437 | "    print(i[0].squeeze().shape, i[1].shape)\n",
438 | "    break"
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": null,
444 | "metadata": {},
445 | "outputs": [],
446 | "source": []
447 | }
448 | ],
449 | "metadata": {
450 | "kernelspec": {
451 | "display_name": "Python 3",
452 | "language": "python",
453 | "name": "python3"
454 | },
455 | "language_info": {
456 | "codemirror_mode": {
457 | "name": "ipython",
458 | "version": 3
459 | },
460 | "file_extension": ".py",
461 | "mimetype": "text/x-python",
462 | "name": "python",
463 | "nbconvert_exporter": "python",
464 | "pygments_lexer": "ipython3",
465 | "version": "3.7.6"
466 | }
467 | },
468 | "nbformat": 4,
469 | "nbformat_minor": 4
470 | }
471 | 
-------------------------------------------------------------------------------- /examples/multi_example.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | 
4 | # In[1]:
5 | 
6 | 
7 | import torch
8 | import torch_dct as dct
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | import torch.backends.cudnn as cudnn
13 | from torch.optim.lr_scheduler import StepLR
14 | 
15 | import matplotlib.pyplot as plt
16 | plt.style.use(['science', 'no-latex', 'notebook'])
17 | 
18 | import time
19 | import sys
20 | import PIL
21 | sys.path.append('../')
22 | from common import *
23 | from transform_based_network import *
24 | 
25 | 
26 | # In[4]:
27 | 
28 | 
29 | trainloader, testloader = load_cifar10()
30 | model = Conv_Transform_Net_CIFAR(100)
31 | print(model)
32 | optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
33 | train_acc, test_acc = train_transform(1, model, trainloader, testloader, optimizer)
34 | 
35 | # In[ ]:
36 | 
37 | 
38 | plt.figure()
39 | plt.title('Convolutional Transform Net Accuracy on CIFAR10')
40 | plt.xlabel('Epoch')
41 | plt.ylabel('Accuracy')
42 | plt.plot(train_acc, label='Train accuracy')
43 | plt.plot(test_acc, label='Test accuracy')
44 | plt.legend()
45 | plt.show()
46 | 
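47 | 
48 | # A possible way to persist the curves for later comparison, mirroring the
49 | # 'curves' convention that run_all uses elsewhere in this repo; the file names
50 | # below are illustrative additions, not part of the original script.
51 | import os
52 | import numpy as np
53 | os.makedirs('curves', exist_ok=True)
54 | np.save('curves/conv_transform_CIFAR10_train.npy', np.array(train_acc))
55 | np.save('curves/conv_transform_CIFAR10_test.npy', np.array(test_acc))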
-------------------------------------------------------------------------------- /examples/multi_threading_net.ipynb: --------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch_dct as dct\n",
11 | "import torch.nn as nn\n",
12 | "import torch.optim as optim\n",
13 | "import torch.nn.functional as F\n",
14 | "import torch.backends.cudnn as cudnn\n",
15 | "from torch.optim.lr_scheduler import StepLR\n",
16 | "\n",
17 | "import matplotlib.pyplot as plt\n",
18 | "\n",
19 | "from multiprocessing import Pool, Queue, Process, set_start_method\n",
20 | "import multiprocessing as mp_\n",
21 | "\n",
22 | "import time\n",
23 | "import pkbar\n",
24 | "import sys\n",
25 | "sys.path.append('../')\n",
26 | "from common import *\n",
27 | "from transform_based_network import *\n",
28 | "from joblib import Parallel, delayed"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "class T_Layer(nn.Module):\n",
38 | "    def __init__(self, dct_w, dct_b):\n",
39 | "        super(T_Layer, self).__init__()\n",
40 | "        self.weights = nn.Parameter(dct_w)\n",
41 | "        self.bias = nn.Parameter(dct_b)\n",
42 | "        \n",
43 | "    def forward(self, dct_x):\n",
44 | "        x = torch.mm(self.weights, dct_x) + self.bias\n",
45 | "        return x\n",
46 | "\n",
47 | "    \n",
48 | "class Frontal_Slice(nn.Module):\n",
49 | "    def __init__(self, dct_w, dct_b):\n",
50 | "        super(Frontal_Slice, self).__init__()\n",
51 | "        self.device = dct_w.device\n",
52 | "        self.dct_linear = nn.Sequential(\n",
53 | "            T_Layer(dct_w, dct_b),\n",
54 | "        )\n",
55 | "        #nn.ReLU(inplace=True),\n",
56 | "        #self.linear1 = nn.Linear(28, 28)\n",
57 | "        #nn.ReLU(inplace=True),\n",
58 | "        #self.classifier = nn.Linear(28, 10)\n",
59 | "        \n",
60 | "    def forward(self, x):\n",
61 | "        #x = torch.transpose(x, 0, 1).to(self.device)\n",
62 | "        x = self.dct_linear(x)\n",
63 | "        #x = self.linear1(x)\n",
64 | "        #x = self.classifier(x)\n",
65 | "        #x = torch.transpose(x, 0, 1)\n",
66 | "        return x"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "def train_slice(i, model, x_i, y, outputs, optimizer):\n",
76 | "    s = time.time()\n",
77 | "    criterion = nn.CrossEntropyLoss()\n",
78 | "    o = torch.stack(outputs)\n",
79 | "    o[i, ...] = outputs_grad[i]  # splice in slice i's differentiable output; all other slices stay detached\n",
80 | "    o = torch_apply(dct.idct, o)\n",
81 | "    o = scalar_tubal_func(o)\n",
82 | "    o = torch.transpose(o, 0, 1)\n",
83 | "    \n",
84 | "    optimizer.zero_grad()\n",
85 | "    loss = criterion(o, y) \n",
86 | "    loss.backward()\n",
87 | "    optimizer.step()\n",
88 | "    e = time.time()\n",
89 | "    # print(e - s)"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "device = 'cpu'\n",
99 | "batch_size = 100\n",
100 | "trainloader, testloader = load_mnist_multiprocess(batch_size)"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "shape = (28, 28, batch_size)\n",
110 | "models = []\n",
111 | "ops = []\n",
112 | "dct_w, dct_b = make_weights(shape, device=device)\n",
113 | "for i in range(shape[0]):\n",
114 | "    w_i = dct_w[i, ...].clone()\n",
115 | "    b_i = dct_b[i, ...].clone()\n",
116 | "    \n",
117 | "    w_i.requires_grad = True\n",
118 | "    b_i.requires_grad = True\n",
119 | "    \n",
120 | "    model = Frontal_Slice(w_i, b_i)\n",
121 | "    model.train()\n",
122 | "    models.append(model.to(device))\n",
123 | "    \n",
124 | "    op = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
125 | "    ops.append(op)"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "epochs = 10\n",
135 | "acc_list = []\n",
136 | "loss_list = []\n",
137 | "\n",
138 | "global outputs_grad  # read inside train_slice\n",
139 | "for e in range(epochs):\n",
140 | "    correct = 0\n",
141 | "    total = 0\n",
142 | "    losses = 0\n",
143 | "    pbar = pkbar.Pbar(name='Epoch '+str(e), target=60000/batch_size)\n",
144 | "    for batch_idx, (x, y) in enumerate(trainloader): \n",
145 | "        dct_x = torch_shift(x)\n",
146 | "        dct_x = torch_apply(dct.dct, dct_x)\n",
147 | "\n",
148 | "        dct_x = dct_x.to(device)\n",
149 | "        y = y.to(device) \n",
150 | "        \n",
151 | "        outputs_grad = []\n",
152 | "        outputs = []\n",
153 | "        \n",
154 | "        for i in range(len(models)):\n",
155 | "            out = models[i](dct_x[i, ...])\n",
156 | "            outputs_grad.append(out)\n",
157 | "            outputs.append(out.detach())\n",
158 | "        \n",
159 | "        #for i in range(len(models)):\n",
160 | "        #    train_slice(i, models[i], dct_x[i, ...], y, outputs, ops[i])\n",
161 | "        \n",
162 | "        Parallel(n_jobs=16, prefer=\"threads\", verbose=0)(  # one backprop job per frontal slice\n",
163 | "            delayed(train_slice)(i, models[i], dct_x[i, ...], y, outputs, ops[i]) \\\n",
164 | "            for i in range(len(models))\n",
165 | "        )\n",
166 | "\n",
167 | "        res = torch.empty(shape[0], 10, shape[2])\n",
168 | "        for i in range(len(models)):\n",
169 | "            res[i, ...] = models[i](dct_x[i, ...])\n",
170 | "        \n",
171 | "        res = torch_apply(dct.idct, res).to(device)\n",
172 | "        res = scalar_tubal_func(res)\n",
173 | "        res = torch.transpose(res, 0, 1)\n",
174 | "        criterion = nn.CrossEntropyLoss()\n",
175 | "        total_loss = criterion(res, y)\n",
176 | "        \n",
177 | "        _, predicted = torch.max(res, 1)\n",
178 | "        total += y.size(0)\n",
179 | "        correct += predicted.eq(y).sum().item()\n",
180 | "        losses += total_loss\n",
181 | "        \n",
182 | "        pbar.update(batch_idx)\n",
183 | "        # print(total_loss)\n",
184 | "        # print(predicted.eq(y).sum().item() / y.size(0))\n",
185 | "        \n",
186 | "    loss_list.append(losses / total)\n",
187 | "    acc_list.append(correct / total)"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "    '''\n",
197 | "    tmp = torch_mp.get_context('spawn')\n",
198 | "    for model in models:\n",
199 | "        model.share_memory()\n",
200 | "    processes = []\n",
201 | "\n",
202 | "    for i in range(len(models)):\n",
203 | "        p = tmp.Process(target=train_slice, \n",
204 | "                        args=(i, models[i], dct_x[i, ...], y, outputs, ops[i]))\n",
205 | "        p.start()\n",
206 | "        processes.append(p)\n",
207 | "    for p in processes: \n",
208 | "        p.join()\n",
209 | "    '''"
210 | ]
211 | }
212 | ],
213 | "metadata": {
214 | "kernelspec": {
215 | "display_name": "Python 3",
216 | "language": "python",
217 | "name": "python3"
218 | },
219 | "language_info": {
220 | "codemirror_mode": {
221 | "name": "ipython",
222 | "version": 3
223 | },
224 | "file_extension": ".py",
225 | "mimetype": "text/x-python",
226 | "name": "python",
227 | "nbconvert_exporter": "python",
228 | "pygments_lexer": "ipython3",
229 | "version": "3.8.3"
230 | }
231 | },
232 | "nbformat": 4,
233 | "nbformat_minor": 4
234 | }
235 | 
-------------------------------------------------------------------------------- /examples/multi_threading_net.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch_dct as dct
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import torch.backends.cudnn as cudnn
7 | from torch.optim.lr_scheduler import StepLR
8 | 
9 | import matplotlib.pyplot as plt
10 | 
11 | from multiprocessing import Pool, Queue, Process, set_start_method
12 | import multiprocessing as mp_
13 | from joblib import Parallel, delayed
14 | 
15 | import time
16 | import pkbar
17 | import sys
18 | sys.path.append('../')
19 | from common import *
20 | from transform_based_network import *
21 | 
22 | 
23 | # Layer definition
24 | class T_Layer(nn.Module):
25 |     def __init__(self, dct_w, dct_b):
26 |         super(T_Layer, self).__init__()
27 |         self.weights = nn.Parameter(dct_w)
28 |         self.bias = nn.Parameter(dct_b)
29 | 
30 |     def forward(self, dct_x):
31 |         x = 
torch.mm(self.weights, dct_x) + self.bias 32 | return x 33 | 34 | 35 | # Model definition 36 | # This model is going to be run in parallel 37 | class Frontal_Slice(nn.Module): 38 | def __init__(self, dct_w, dct_b): 39 | super(Frontal_Slice, self).__init__() 40 | self.device = dct_w.device 41 | self.dct_linear = nn.Sequential( 42 | T_Layer(dct_w, dct_b), 43 | ) 44 | 45 | def forward(self, x): 46 | x = self.dct_linear(x) 47 | return x 48 | 49 | 50 | # This function conducts backprop for a FrontalSlice model 51 | # This function is going to be run in parallel 52 | def train_slice(i, model, x_i, y, outputs, optimizer): 53 | criterion = nn.CrossEntropyLoss() 54 | o = torch.stack(outputs) 55 | o[i, ...] = outputs_grad[i] 56 | o = torch_apply(dct.idct, o) 57 | o = scalar_tubal_func(o) 58 | o = torch.transpose(o, 0, 1) 59 | 60 | optimizer.zero_grad() 61 | loss = criterion(o, y) 62 | loss.backward() 63 | optimizer.step() 64 | 65 | 66 | # Load data 67 | device = 'cpu' 68 | batch_size = 100 69 | trainloader, testloader = load_mnist_multiprocess(batch_size) 70 | 71 | 72 | # Create m FrontalSlice models and initialize 73 | shape = (28, 28, batch_size) 74 | models = [] 75 | ops = [] 76 | dct_w, dct_b = make_weights(shape, device=device) 77 | for i in range(shape[0]): 78 | w_i = dct_w[i, ...].clone() 79 | b_i = dct_b[i, ...].clone() 80 | 81 | w_i.requires_grad = True 82 | b_i.requires_grad = True 83 | 84 | model = Frontal_Slice(w_i, b_i) 85 | model.train() 86 | models.append(model.to(device)) 87 | op = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) 88 | ops.append(op) 89 | 90 | 91 | epochs = 10 92 | acc_list = [] 93 | loss_list = [] 94 | 95 | global outputs_grad 96 | for e in range(epochs): 97 | correct = 0 98 | total = 0 99 | losses = 0 100 | pbar = pkbar.Pbar(name='Epoch '+str(e), target=60000/batch_size) 101 | for batch_idx, (x, y) in enumerate(trainloader): 102 | dct_x = torch_shift(x) 103 | dct_x = torch_apply(dct.dct, dct_x) 104 | 105 | dct_x = dct_x.to(device) 106 | y = y.to(device) 107 | 108 | outputs_grad = [] 109 | outputs = [] 110 | for i in range(len(models)): 111 | out = models[i](dct_x[i, ...]) 112 | outputs_grad.append(out) 113 | outputs.append(out.detach()) 114 | 115 | # This line makes multiple calls to train_slice function 116 | # Parallelization 117 | Parallel(n_jobs=16, prefer="threads", verbose=0)( 118 | delayed(train_slice)(i, models[i], dct_x[i, ...], y, outputs, ops[i]) \ 119 | for i in range(len(models)) 120 | ) 121 | 122 | res = torch.empty(shape[0], 10, shape[2]) 123 | for i in range(len(models)): 124 | res[i, ...] 
= models[i](dct_x[i, ...])
125 | 
126 |         res = torch_apply(dct.idct, res).to(device)
127 |         res = scalar_tubal_func(res)
128 |         res = torch.transpose(res, 0, 1)
129 |         criterion = nn.CrossEntropyLoss()
130 |         total_loss = criterion(res, y)
131 | 
132 |         _, predicted = torch.max(res, 1)
133 |         total += y.size(0)
134 |         correct += predicted.eq(y).sum().item()
135 |         losses += total_loss
136 | 
137 |         pbar.update(batch_idx)
138 | 
139 |     loss_list.append(losses / total)
140 |     acc_list.append(correct / total)
141 | 
142 | 
143 | '''
144 | tmp = torch_mp.get_context('spawn')
145 | for model in models:
146 |     model.share_memory()
147 | processes = []
148 | 
149 | for i in range(len(models)):
150 |     p = tmp.Process(target=train_slice, 
151 |                     args=(i, models[i], dct_x[i, ...], y, outputs, ops[i]))
152 |     p.start()
153 |     processes.append(p)
154 | for p in processes: 
155 |     p.join()
156 | '''
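157 | 
158 | # The training loop above fills acc_list (and loss_list) but never reports them.
159 | # A minimal sketch for inspecting the per-epoch training accuracy, added here as
160 | # an illustration; it relies only on the matplotlib import at the top of this file.
161 | plt.figure()
162 | plt.title('Multi-threaded Transform Net Accuracy on MNIST')
163 | plt.xlabel('Epoch')
164 | plt.ylabel('Accuracy')
165 | plt.plot(acc_list, label='Train accuracy')
166 | plt.legend()
167 | plt.show()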
-------------------------------------------------------------------------------- /examples/multiprocess_net.ipynb: --------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch_dct as dct\n",
11 | "import torch.nn as nn\n",
12 | "import torch.optim as optim\n",
13 | "import torch.nn.functional as F\n",
14 | "import torch.backends.cudnn as cudnn\n",
15 | "from torch.optim.lr_scheduler import StepLR\n",
16 | "\n",
17 | "import matplotlib.pyplot as plt\n",
18 | "plt.style.use(['science', 'no-latex', 'notebook'])\n",
19 | "\n",
20 | "from multiprocessing import Pool, Queue, Process, set_start_method\n",
21 | "import multiprocessing as mp_\n",
22 | "\n",
23 | "import time\n",
24 | "import pkbar\n",
25 | "import sys\n",
26 | "sys.path.append('../')\n",
27 | "from common import *\n",
28 | "from transform_based_network import *"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 45,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "class T_Layer(nn.Module):\n",
38 | "    def __init__(self, dct_w, dct_b):\n",
39 | "        super(T_Layer, self).__init__()\n",
40 | "        w = torch.randn(dct_w.shape)\n",
41 | "        b = torch.randn(dct_b.shape)\n",
42 | "        self.weights = nn.Parameter(dct_w)\n",
43 | "        self.bias = nn.Parameter(dct_b)\n",
44 | "    \n",
45 | "    def forward(self, dct_x):\n",
46 | "        x = torch.mm(self.weights, dct_x)# + self.bias\n",
47 | "        return x\n",
48 | "\n",
49 | "    \n",
50 | "class Frontal_Slice(nn.Module):\n",
51 | "    def __init__(self, dct_w, dct_b):\n",
52 | "        super(Frontal_Slice, self).__init__()\n",
53 | "        self.device = dct_w.device\n",
54 | "        self.dct_linear = nn.Sequential(\n",
55 | "            T_Layer(dct_w, dct_b),\n",
56 | "        )\n",
57 | "        #nn.ReLU(inplace=True),\n",
58 | "        #self.linear1 = nn.Linear(28, 28)\n",
59 | "        #nn.ReLU(inplace=True),\n",
60 | "        #self.linear2 = nn.Linear(28, 28)\n",
61 | "        #nn.ReLU(inplace=True),\n",
62 | "        #self.classifier = nn.Linear(28, 10)\n",
63 | "    \n",
64 | "    def forward(self, x):\n",
65 | "        #x = torch.transpose(x, 0, 1).to(self.device)\n",
66 | "        x = self.dct_linear(x)\n",
67 | "        #x = self.linear1(x)\n",
68 | "        #x = self.linear2(x)\n",
69 | "        #x = self.classifier(x)\n",
70 | "        #x = torch.transpose(x, 0, 1)\n",
71 | "        return x\n",
72 | "    \n",
73 | "    \n",
74 | "class Ensemble(nn.Module):\n",
75 | "    def __init__(self, shape, device='cpu'):\n",
76 | "        super(Ensemble, self).__init__()\n",
77 | "        self.device = device \n",
78 | "        self.models = nn.ModuleList([])\n",
79 | "        dct_w, dct_b = make_weights(shape, device, scale=0.001)\n",
80 | "        self.weights = nn.Parameter(dct_w)\n",
81 | "        self.bias = nn.Parameter(dct_b)\n",
82 | "        for i in range(shape[0]):\n",
83 | "            model = Frontal_Slice(self.weights[i, ...], self.bias[i, ...])\n",
84 | "            self.models.append(model.to(device))\n",
85 | "    \n",
86 | "    def forward(self, x):\n",
87 | "        self.res = torch.empty(x.shape[0], 10, x.shape[2])\n",
88 | "        dct_x = torch_apply(dct.dct, x).to(self.device)\n",
89 | "        self.tmp = []\n",
90 | "        for i in range(len(self.models)):\n",
91 | "            self.tmp.append(self.models[i](dct_x[i, ...]))\n",
92 | "            self.res[i, ...] = self.tmp[i]\n",
93 | "        self.result = torch_apply(dct.idct, self.res)\n",
94 | "        self.softmax = scalar_tubal_func(self.result)\n",
95 | "        return torch.transpose(self.softmax, 0, 1)"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 46,
101 | "metadata": {
102 | "scrolled": true
103 | },
104 | "outputs": [],
105 | "source": [
106 | "def train_ensemble(x, y, i=50, device='cuda:0'):\n",
107 | "    x = torch_shift(x).to(device)\n",
108 | "    y = y.to(device)\n",
109 | "    ensemble = Ensemble(x.shape, device).to(device)\n",
110 | "    optimizer = optim.SGD(ensemble.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)\n",
111 | "    criterion = nn.CrossEntropyLoss()\n",
112 | "    pbar = pkbar.Pbar(name='progress', target=i)\n",
113 | "    for j in range(i):\n",
114 | "        outputs = ensemble(x)\n",
115 | "        print(outputs.shape, y.shape)\n",
116 | "        optimizer.zero_grad()\n",
117 | "        loss = criterion(outputs.to(device), y)\n",
118 | "        loss.backward()\n",
119 | "        optimizer.step()\n",
120 | "        pbar.update(j)\n",
121 | "    \n",
122 | "    print(loss.item())\n",
123 | "    return ensemble\n",
124 | "\n",
125 | "## 16, 10, 10, 100 iterations\n",
126 | "# cpu, for loop: 4.1s\n",
127 | "# gpu, for loop: 5.5s"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": 47,
133 | "metadata": {},
134 | "outputs": [
135 | {
136 | "name": "stdout",
137 | "output_type": "stream",
138 | "text": [
139 | "progress\n",
140 | "torch.Size([16, 10]) torch.Size([16])\n",
141 | "1/2 [==============>...............] - 0.1storch.Size([16, 10]) torch.Size([16])\n",
142 | "2/2 [==============================] - 0.1s\n",
143 | "2.179849624633789\n"
144 | ]
145 | }
146 | ],
147 | "source": [
148 | "x0 = []\n",
149 | "y0 = []\n",
150 | "for i in range(100):\n",
151 | "    x0.append(torch.randn(16, 29, 28))\n",
152 | "    y0.append(torch.randint(10, (16,)))\n",
153 | "\n",
154 | "for i in range(1):\n",
155 | "    train_ensemble(x0[i], y0[i], i=2, device='cpu')"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 48,
161 | "metadata": {},
162 | "outputs": [
163 | {
164 | "name": "stdout",
165 | "output_type": "stream",
166 | "text": [
167 | "==> Loading data..\n"
168 | ]
169 | }
170 | ],
171 | "source": [
172 | "batch_size = 10\n",
173 | "trainloader, testloader = load_mnist_multiprocess(batch_size)"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": 49,
179 | "metadata": {},
180 | "outputs": [
181 | {
182 | "name": "stdout",
183 | "output_type": "stream",
184 | "text": [
185 | "Epoch0\n",
186 | "tensor(2.2876, grad_fn=<NllLossBackward>)\n",
187 | "... (one progress-bar segment and one loss tensor are printed per batch; the loss stays roughly between 2.22 and 2.52 over the first 98 batches; repeated lines, including a second identical output block, are omitted) ...\n"
275 | ]
276 | },
293 | {
294 | "ename": "KeyboardInterrupt",
295 | "evalue": "",
296 | "output_type": "error",
297 | "traceback": [
298 | "---------------------------------------------------------------------------",
299 | "KeyboardInterrupt                        Traceback (most recent call last)",
300 | "<ipython-input> in <module>\n     25         optimizer.zero_grad()\n     26         loss = criterion(outputs.to(device), targets.to(device))\n---> 27         loss.backward()\n     28         optimizer.step()\n",
301 | "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\tensor.py in backward(self, gradient, retain_graph, create_graph)\n     194         \"\"\"\n--> 195         torch.autograd.backward(self, gradient, retain_graph, create_graph)\n     196 \n     197     def register_hook(self, hook):\n",
302 | "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\autograd\\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\n     97     Variable._execution_engine.run_backward(\n     98         tensors, grad_tensors, retain_graph, create_graph,\n---> 99         allow_unreachable=True)  # allow_unreachable flag\n",
303 | "KeyboardInterrupt: "
304 | ]
305 | }
306 | ],
307 | "source": [
308 | "device = 'cpu'\n",
309 | "for epoch in range(10):\n",
310 | "    pbar = pkbar.Pbar(name='Epoch'+str(epoch), target=60000/batch_size)\n",
311 | "    for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
312 | "        '''\n",
313 | "        dct_x = torch_apply(dct.dct, x.squeeze())\n",
314 | "        y_cat = to_categorical(y, 10) \n",
315 | "\n",
316 | "        dct_y_cat = torch.randn(y_cat.shape[0], dct_x.shape[1], 10)\n",
317 | "        for i in range(10):\n",
318 | "            dct_y_cat[:, i, :] = y_cat\n",
319 | "        dct_y_cat = torch_apply(dct.dct, dct_y_cat)\n",
320 | "        dct_x.to(device)\n",
321 | "        dct_y_cat.to(device)\n",
322 | "        '''\n",
323 | "        correct = 0\n",
324 | "        train_loss = 0\n",
325 | "        total = 0\n",
326 | "        inputs = torch_shift(inputs).to(device)\n",
327 | "        ensemble = Ensemble(inputs.shape, device).to(device)\n",
328 | "        optimizer = optim.SGD(ensemble.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
329 | "        criterion = nn.CrossEntropyLoss()\n",
330 | "\n",
331 | "        outputs = ensemble(inputs) \n",
332 | "        optimizer.zero_grad()\n",
333 | "        loss = criterion(outputs.to(device), targets.to(device))\n",
334 | "        loss.backward()\n",
335 | "        optimizer.step()\n",
336 | "        \n",
337 | "        _, predicted = torch.max(outputs, 1)\n",
338 | "        correct += predicted.eq(targets).sum().item()\n",
339 | "        train_loss += loss.item()\n",
340 | "        total += batch_size\n",
341 | "        print(loss)\n",
342 | "        \n",
343 | "        pbar.update(batch_idx)\n",
344 | "    print(correct/total, train_loss/total)\n",
345 | "    \n",
346 | "\n",
347 | "'''\n",
348 | "    models = []\n",
349 | "    for i in range(16):\n",
350 | "        dct_w, dct_b = make_weights(dct_x.shape, device=device)\n",
351 | "        model = Frontal_Slice(dct_w[i, ...], dct_b[i, ...])\n",
352 | "        models.append(model.to(device))\n",
353 | "\n",
354 | "    for i in range(len(models)):\n",
355 | "        train_slice(models[i], dct_x[i, ...], dct_y_cat[i, ...])\n",
356 | "    print()\n",
357 | "    pbar.update(batch_idx)\n",
358 | "    \n",
359 | "    tmp = torch_mp.get_context('spawn')\n",
360 | "    for model in models:\n",
361 | "        model.share_memory()\n",
362 | "    processes = []\n",
363 | "\n",
364 | "    for i in range(len(models)):\n",
365 | "        p = tmp.Process(target=train_slice, \n",
366 | "                        args=(models[i], dct_x[i, ...], dct_y_cat[i, ...]))\n",
367 | "        p.start()\n",
368 | "        processes.append(p)\n",
369 | "    for p in processes: \n",
370 | "        p.join()\n",
371 | "    '''"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | 
"execution_count": 25, 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "def train_slice(model, x_i, y_i):\n", 381 | " criterion = nn.MSELoss()\n", 382 | " optimizer = optim.SGD(model.parameters(), lr=0.5, momentum=0.9, weight_decay=5e-4)\n", 383 | " outputs = model(x_i)\n", 384 | " # print(outputs.shape, y_i.shape)\n", 385 | " optimizer.zero_grad()\n", 386 | " loss = criterion(outputs, y_i)\n", 387 | " loss.backward()\n", 388 | " optimizer.step()" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 56, 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "for batch_idx, (x, y) in enumerate(trainloader): \n", 398 | " device = 'cpu'\n", 399 | " x = torch_shift(x)\n", 400 | " dct_x = torch_apply(dct.dct, x.squeeze())\n", 401 | " y_cat = to_categorical(y, 10) \n", 402 | "\n", 403 | " dct_y_cat = torch.randn(28, dct_x.shape[2], 10) #y_cat.shape[0]\n", 404 | " for i in range(28):\n", 405 | " dct_y_cat[i, :, :] = y_cat\n", 406 | " dct_y_cat = torch_apply(dct.dct, dct_y_cat)\n", 407 | " dct_x.to(device)\n", 408 | " dct_y_cat.to(device)\n", 409 | " \n", 410 | " models = []\n", 411 | " dct_w, dct_b = make_weights(dct_x.shape, device=device)\n", 412 | " for i in range(28):\n", 413 | " model = Frontal_Slice(dct_w[i, ...], dct_b[i, ...])\n", 414 | " models.append(model.to(device))\n", 415 | "\n", 416 | " for i in range(len(models)):\n", 417 | " train_slice(models[i], dct_x[i, ...], dct_y_cat[i, ...])" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": null, 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 61, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [ 433 | "y = torch.eye(10)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 74, 439 | "metadata": {}, 440 | "outputs": [], 441 | "source": [ 442 | "dct_yy = torch.empty(28, 10, 10)\n", 443 | "for i in range(28):\n", 444 | " dct_yy[i, ...] 
= y * 1\n", 445 | "dct_yy = torch_apply(dct.dct, dct_yy)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 75, 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "data": { 455 | "text/plain": [ 456 | "tensor([[12.6239, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085,\n", 457 | " 1.7085, 1.7085],\n", 458 | " [ 1.7085, 12.6239, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085,\n", 459 | " 1.7085, 1.7085],\n", 460 | " [ 1.7085, 1.7085, 12.6239, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085,\n", 461 | " 1.7085, 1.7085],\n", 462 | " [ 1.7085, 1.7085, 1.7085, 12.6239, 1.7085, 1.7085, 1.7085, 1.7085,\n", 463 | " 1.7085, 1.7085],\n", 464 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 12.6239, 1.7085, 1.7085, 1.7085,\n", 465 | " 1.7085, 1.7085],\n", 466 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 12.6239, 1.7085, 1.7085,\n", 467 | " 1.7085, 1.7085],\n", 468 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 12.6239, 1.7085,\n", 469 | " 1.7085, 1.7085],\n", 470 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 12.6239,\n", 471 | " 1.7085, 1.7085],\n", 472 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085,\n", 473 | " 12.6239, 1.7085],\n", 474 | " [ 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085, 1.7085,\n", 475 | " 1.7085, 12.6239]])" 476 | ] 477 | }, 478 | "execution_count": 75, 479 | "metadata": {}, 480 | "output_type": "execute_result" 481 | } 482 | ], 483 | "source": [ 484 | "result = torch_apply(dct.idct, dct_yy)\n", 485 | "softmax = scalar_tubal_func(result)\n", 486 | "torch.transpose(softmax, 0, 1)" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": null, 492 | "metadata": {}, 493 | "outputs": [], 494 | "source": [] 495 | } 496 | ], 497 | "metadata": { 498 | "kernelspec": { 499 | "display_name": "Python 3", 500 | "language": "python", 501 | "name": "python3" 502 | }, 503 | "language_info": { 504 | "codemirror_mode": { 505 | "name": "ipython", 506 | "version": 3 507 | }, 508 | "file_extension": ".py", 509 | "mimetype": "text/x-python", 510 | "name": "python", 511 | "nbconvert_exporter": "python", 512 | "pygments_lexer": "ipython3", 513 | "version": "3.7.6" 514 | } 515 | }, 516 | "nbformat": 4, 517 | "nbformat_minor": 4 518 | } 519 | -------------------------------------------------------------------------------- /examples/parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch_dct as dct 5 | 6 | import sys 7 | sys.path.append('../') 8 | from common import * 9 | from transform_based_network import * 10 | 11 | from torch.multiprocessing import Pool, Queue, Process, set_start_method 12 | import torch.multiprocessing as torch_mp 13 | import multiprocessing as mp 14 | 15 | if __name__ == '__main__': 16 | tmp = torch_mp.get_context('spawn') 17 | 18 | trainloader, testloader = load_mnist() 19 | model = Transform_Net(100) 20 | optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) 21 | 22 | model.share_memory() 23 | 24 | processes = [] 25 | num_cores = torch_mp.cpu_count() 26 | for i in range(num_cores): 27 | # q = Queue() 28 | p = tmp.Process(target=train_transform, args=(1, model, trainloader, testloader, optimizer)) 29 | p.start() 30 | # print(q.get()) 31 | processes.append(p) 32 | 33 | for p in processes: 34 | p.join() -------------------------------------------------------------------------------- /nets/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .lenet import * 2 | from .vgg import * -------------------------------------------------------------------------------- /nets/lenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | from torch.optim.lr_scheduler import StepLR 7 | 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | from torchvision import models 11 | 12 | class LeNet(nn.Module): 13 | def __init__(self): 14 | super(LeNet, self).__init__() 15 | self.features = nn.Sequential( 16 | nn.Conv2d(1, 32, 3, 1), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(32, 64, 3, 1), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(kernel_size=2, stride=2), 21 | ) 22 | self.classifier = nn.Sequential( 23 | nn.Dropout(0.25), 24 | nn.Linear(9216, 128), 25 | nn.ReLU(inplace=True), 26 | nn.Dropout(0.5), 27 | nn.Linear(128, 10), 28 | ) 29 | 30 | def forward(self, x): 31 | x = self.features(x) 32 | x = torch.flatten(x, 1) 33 | x = self.classifier(x) 34 | output = F.log_softmax(x, dim=1) 35 | return output -------------------------------------------------------------------------------- /nets/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | from torch.optim.lr_scheduler import StepLR 7 | 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | from torchvision import models 11 | 12 | cfg = { 13 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 14 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 15 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 16 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 17 | } 18 | 19 | class VGG(nn.Module): 20 | def __init__(self, vgg_name): 21 | super(VGG, self).__init__() 22 | self.features = self._make_layers(cfg[vgg_name]) 23 | self.classifier = nn.Sequential( 24 | nn.Dropout(0.25), 25 | nn.Linear(512, 128), 26 | nn.ReLU(inplace=True), 27 | nn.Dropout(0.25), 28 | nn.Linear(128, 10), 29 | ) 30 | 31 | def forward(self, x): 32 | out = self.features(x) 33 | out = out.view(out.size(0), -1) 34 | out = self.classifier(out) 35 | return out 36 | 37 | def _make_layers(self, cfg): 38 | layers = [] 39 | in_channels = 3 40 | for x in cfg: 41 | if x == 'M': 42 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 43 | else: 44 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 45 | nn.BatchNorm2d(x), 46 | nn.ReLU(inplace=True)] 47 | in_channels = x 48 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 49 | return nn.Sequential(*layers) -------------------------------------------------------------------------------- /notebooks/eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import cProfile\n", 10 | "import time\n", 11 | "import pstats\n", 12 | "from main import *" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 3, 18 | "metadata": {}, 19 
| "outputs": [ 20 | { 21 | "name": "stdout", 22 | "output_type": "stream", 23 | "text": [ 24 | "==> Loading data..\n", 25 | "Files already downloaded and verified\n", 26 | "Files already downloaded and verified\n", 27 | "==> Building model..\n", 28 | "VGG(\n", 29 | " (features): Sequential(\n", 30 | " (0): Sequential(\n", 31 | " (0): Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 32 | " (1): Conv2d(1, 22, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 33 | " (2): Conv2d(22, 64, kernel_size=(1, 1), stride=(1, 1))\n", 34 | " )\n", 35 | " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 36 | " (2): ReLU(inplace=True)\n", 37 | " (3): Sequential(\n", 38 | " (0): Conv2d(64, 22, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 39 | " (1): Conv2d(22, 22, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 40 | " (2): Conv2d(22, 64, kernel_size=(1, 1), stride=(1, 1))\n", 41 | " )\n", 42 | " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 43 | " (5): ReLU(inplace=True)\n", 44 | " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 45 | " (7): Sequential(\n", 46 | " (0): Conv2d(64, 22, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 47 | " (1): Conv2d(22, 43, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 48 | " (2): Conv2d(43, 128, kernel_size=(1, 1), stride=(1, 1))\n", 49 | " )\n", 50 | " (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 51 | " (9): ReLU(inplace=True)\n", 52 | " (10): Sequential(\n", 53 | " (0): Conv2d(128, 43, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 54 | " (1): Conv2d(43, 43, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 55 | " (2): Conv2d(43, 128, kernel_size=(1, 1), stride=(1, 1))\n", 56 | " )\n", 57 | " (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 58 | " (12): ReLU(inplace=True)\n", 59 | " (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 60 | " (14): Sequential(\n", 61 | " (0): Conv2d(128, 43, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 62 | " (1): Conv2d(43, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 63 | " (2): Conv2d(86, 256, kernel_size=(1, 1), stride=(1, 1))\n", 64 | " )\n", 65 | " (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 66 | " (16): ReLU(inplace=True)\n", 67 | " (17): Sequential(\n", 68 | " (0): Conv2d(256, 86, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 69 | " (1): Conv2d(86, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 70 | " (2): Conv2d(86, 256, kernel_size=(1, 1), stride=(1, 1))\n", 71 | " )\n", 72 | " (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 73 | " (19): ReLU(inplace=True)\n", 74 | " (20): Sequential(\n", 75 | " (0): Conv2d(256, 86, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 76 | " (1): Conv2d(86, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 77 | " (2): Conv2d(86, 256, kernel_size=(1, 1), stride=(1, 1))\n", 78 | " )\n", 79 | " (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 80 | " (22): ReLU(inplace=True)\n", 81 | " (23): Sequential(\n", 82 | " (0): Conv2d(256, 86, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 83 | " (1): Conv2d(86, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 
1), bias=False)\n", 84 | " (2): Conv2d(86, 256, kernel_size=(1, 1), stride=(1, 1))\n", 85 | " )\n", 86 | " (24): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 87 | " (25): ReLU(inplace=True)\n", 88 | " (26): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 89 | " (27): Sequential(\n", 90 | " (0): Conv2d(256, 86, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 91 | " (1): Conv2d(86, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 92 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n", 93 | " )\n", 94 | " (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 95 | " (29): ReLU(inplace=True)\n", 96 | " (30): Sequential(\n", 97 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 98 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 99 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n", 100 | " )\n", 101 | " (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 102 | " (32): ReLU(inplace=True)\n", 103 | " (33): Sequential(\n", 104 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 105 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 106 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n", 107 | " )\n", 108 | " (34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 109 | " (35): ReLU(inplace=True)\n", 110 | " (36): Sequential(\n", 111 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 112 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 113 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n", 114 | " )\n", 115 | " (37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 116 | " (38): ReLU(inplace=True)\n", 117 | " (39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 118 | " (40): Sequential(\n", 119 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 120 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 121 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n", 122 | " )\n", 123 | " (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 124 | " (42): ReLU(inplace=True)\n", 125 | " (43): Sequential(\n", 126 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 127 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 128 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n", 129 | " )\n", 130 | " (44): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 131 | " (45): ReLU(inplace=True)\n", 132 | " (46): Sequential(\n", 133 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 134 | " (1): Conv2d(171, 171, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 135 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n", 136 | " )\n", 137 | " (47): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 138 | " (48): ReLU(inplace=True)\n", 139 | " (49): Sequential(\n", 140 | " (0): Conv2d(512, 171, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 141 | " (1): Conv2d(171, 171, kernel_size=(3, 
3), stride=(1, 1), padding=(1, 1), bias=False)\n", 142 | " (2): Conv2d(171, 512, kernel_size=(1, 1), stride=(1, 1))\n", 143 | " )\n", 144 | " (50): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 145 | " (51): ReLU(inplace=True)\n", 146 | " (52): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 147 | " (53): AvgPool2d(kernel_size=1, stride=1, padding=0)\n", 148 | " )\n", 149 | " (classifier): Linear(in_features=512, out_features=10, bias=True)\n", 150 | ")\n", 151 | "==> Done\n", 152 | "\n", 153 | "Epoch: 0\n", 154 | "|========================================| Accuracy: 10.832 % 5416 / 50000\n", 155 | "|==========| Accuracy: 12.12 % 1212 / 10000\n", 156 | "This epoch took 34.11177968978882 seconds\n", 157 | "Current learning rate: 0.05\n", 158 | "\n", 159 | "Epoch: 1\n", 160 | "|========================================| Accuracy: 11.602 % 5801 / 50000\n", 161 | "|==========| Accuracy: 11.53 % 1153 / 10000\n", 162 | "This epoch took 32.89801287651062 seconds\n", 163 | "Current learning rate: 0.05\n", 164 | "\n", 165 | "Epoch: 2\n", 166 | "|========================================| Accuracy: 11.694 % 5847 / 50000\n", 167 | "|==========| Accuracy: 11.77 % 1177 / 10000\n", 168 | "This epoch took 32.95981407165527 seconds\n", 169 | "Current learning rate: 0.05\n", 170 | "\n", 171 | "Epoch: 3\n", 172 | "|========================================| Accuracy: 12.656 % 6328 / 50000\n", 173 | "|==========| Accuracy: 13.45 % 1345 / 10000\n", 174 | "This epoch took 33.00275278091431 seconds\n", 175 | "Current learning rate: 0.05\n", 176 | "\n", 177 | "Epoch: 4\n", 178 | "|========================================| Accuracy: 17.582 % 8791 / 50000\n", 179 | "|==========| Accuracy: 17.5 % 1750 / 10000\n", 180 | "This epoch took 32.84714198112488 seconds\n", 181 | "Current learning rate: 0.04050000000000001\n", 182 | "\n", 183 | "Epoch: 5\n", 184 | "|========================================| Accuracy: 20.526 % 10263 / 50000\n", 185 | "|==========| Accuracy: 21.94 % 2194 / 10000\n", 186 | "This epoch took 32.7773003578186 seconds\n", 187 | "Current learning rate: 0.045000000000000005\n", 188 | "\n", 189 | "Epoch: 6\n", 190 | "|========================================| Accuracy: 23.64 % 11820 / 50000\n", 191 | "|==========| Accuracy: 27.2 % 2720 / 10000\n", 192 | "This epoch took 33.12140941619873 seconds\n", 193 | "Current learning rate: 0.045000000000000005\n", 194 | "\n", 195 | "Epoch: 7\n" 196 | ] 197 | }, 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "|========================================| Accuracy: 25.932 % 12966 / 50000\n", 203 | "|==========| Accuracy: 26.37 % 2637 / 10000\n", 204 | "This epoch took 33.11143517494202 seconds\n", 205 | "Current learning rate: 0.045000000000000005\n", 206 | "\n", 207 | "Epoch: 8\n", 208 | "|========================================| Accuracy: 28.86 % 14430 / 50000\n", 209 | "|==========| Accuracy: 24.97 % 2497 / 10000\n", 210 | "This epoch took 35.06288194656372 seconds\n", 211 | "Current learning rate: 0.045000000000000005\n", 212 | "\n", 213 | "Epoch: 9\n", 214 | "|========================================| Accuracy: 31.124 % 15562 / 50000\n", 215 | "|==========| Accuracy: 30.32 % 3032 / 10000\n", 216 | "This epoch took 35.821218967437744 seconds\n", 217 | "Current learning rate: 0.03645000000000001\n", 218 | "\n", 219 | "Epoch: 10\n", 220 | "|========================================| Accuracy: 32.398 % 16199 / 50000\n", 221 | "|==========| Accuracy: 15.31 % 1531 
/ 10000\n", 222 | "This epoch took 37.601341009140015 seconds\n", 223 | "Current learning rate: 0.04050000000000001\n", 224 | "\n", 225 | "Epoch: 11\n", 226 | "|========================================| Accuracy: 32.892 % 16446 / 50000\n", 227 | "|==========| Accuracy: 28.75 % 2875 / 10000\n", 228 | "This epoch took 43.43238830566406 seconds\n", 229 | "Current learning rate: 0.04050000000000001\n", 230 | "\n", 231 | "Epoch: 12\n", 232 | "|========================================| Accuracy: 33.482 % 16741 / 50000\n", 233 | "|==========| Accuracy: 20.39 % 2039 / 10000\n", 234 | "This epoch took 42.75646662712097 seconds\n", 235 | "Current learning rate: 0.04050000000000001\n", 236 | "\n", 237 | "Epoch: 13\n", 238 | "|========================================| Accuracy: 33.716 % 16858 / 50000\n", 239 | "|==========| Accuracy: 19.55 % 1955 / 10000\n", 240 | "This epoch took 40.91100358963013 seconds\n", 241 | "Current learning rate: 0.04050000000000001\n", 242 | "\n", 243 | "Epoch: 14\n", 244 | "|========================================| Accuracy: 33.624 % 16812 / 50000\n", 245 | "|==========| Accuracy: 30.54 % 3054 / 10000\n", 246 | "This epoch took 40.79039931297302 seconds\n", 247 | "Current learning rate: 0.03280500000000001\n", 248 | "\n", 249 | "Epoch: 15\n", 250 | "|========================================| Accuracy: 33.68 % 16840 / 50000\n", 251 | "|==========| Accuracy: 31.88 % 3188 / 10000\n", 252 | "This epoch took 39.57789707183838 seconds\n", 253 | "Current learning rate: 0.03645000000000001\n", 254 | "\n", 255 | "Epoch: 16\n", 256 | "|========================================| Accuracy: 33.614 % 16807 / 50000\n", 257 | "|==========| Accuracy: 29.32 % 2932 / 10000\n", 258 | "This epoch took 41.83706545829773 seconds\n", 259 | "Current learning rate: 0.03645000000000001\n", 260 | "\n", 261 | "Epoch: 17\n", 262 | "|========================================| Accuracy: 36.988 % 18494 / 50000\n", 263 | "|==========| Accuracy: 39.15 % 3915 / 10000\n", 264 | "This epoch took 40.11329388618469 seconds\n", 265 | "Current learning rate: 0.03645000000000001\n", 266 | "\n", 267 | "Epoch: 18\n", 268 | "|========================================| Accuracy: 42.32 % 21160 / 50000\n", 269 | "|==========| Accuracy: 42.6 % 4260 / 10000\n", 270 | "This epoch took 43.18025541305542 seconds\n", 271 | "Current learning rate: 0.03645000000000001\n", 272 | "\n", 273 | "Epoch: 19\n", 274 | "|========================================| Accuracy: 44.63 % 22315 / 50000\n", 275 | "|==========| Accuracy: 36.87 % 3687 / 10000\n", 276 | "This epoch took 41.14051365852356 seconds\n", 277 | "Current learning rate: 0.02952450000000001\n", 278 | "\n", 279 | "Epoch: 20\n", 280 | "|========================================| Accuracy: 47.01 % 23505 / 50000\n", 281 | "|==========| Accuracy: 39.65 % 3965 / 10000\n", 282 | "This epoch took 42.76271724700928 seconds\n", 283 | "Current learning rate: 0.03280500000000001\n", 284 | "\n", 285 | "Epoch: 21\n", 286 | "|========================================| Accuracy: 47.842 % 23921 / 50000\n", 287 | "|==========| Accuracy: 42.14 % 4214 / 10000\n", 288 | "This epoch took 43.29025745391846 seconds\n", 289 | "Current learning rate: 0.03280500000000001\n", 290 | "\n", 291 | "Epoch: 22\n", 292 | "|========================================| Accuracy: 49.414 % 24707 / 50000\n", 293 | "|==========| Accuracy: 43.76 % 4376 / 10000\n", 294 | "This epoch took 45.87770104408264 seconds\n", 295 | "Current learning rate: 0.03280500000000001\n", 296 | "\n", 297 | "Epoch: 23\n", 298 | 
"|========================================| Accuracy: 50.198 % 25099 / 50000\n", 299 | "|==========| Accuracy: 50.9 % 5090 / 10000\n", 300 | "This epoch took 43.95064091682434 seconds\n", 301 | "Current learning rate: 0.03280500000000001\n", 302 | "\n", 303 | "Epoch: 24\n", 304 | "|========================================| Accuracy: 50.882 % 25441 / 50000\n", 305 | "|==========| Accuracy: 39.98 % 3998 / 10000\n", 306 | "This epoch took 44.10254693031311 seconds\n", 307 | "Current learning rate: 0.02657205000000001\n", 308 | "\n", 309 | "Epoch: 25\n", 310 | "|========================================| Accuracy: 52.896 % 26448 / 50000\n", 311 | "|==========| Accuracy: 40.97 % 4097 / 10000\n", 312 | "This epoch took 44.637322187423706 seconds\n", 313 | "Current learning rate: 0.02952450000000001\n", 314 | "\n", 315 | "Epoch: 26\n", 316 | "|========================================| Accuracy: 53.682 % 26841 / 50000\n", 317 | "|==========| Accuracy: 48.24 % 4824 / 10000\n", 318 | "This epoch took 43.39285159111023 seconds\n", 319 | "Current learning rate: 0.02952450000000001\n", 320 | "\n", 321 | "Epoch: 27\n", 322 | "|========================================| Accuracy: 54.718 % 27359 / 50000\n", 323 | "|==========| Accuracy: 53.49 % 5349 / 10000\n", 324 | "This epoch took 41.88004779815674 seconds\n", 325 | "Current learning rate: 0.02952450000000001\n", 326 | "\n", 327 | "Epoch: 28\n", 328 | "|========================================| Accuracy: 54.854 % 27427 / 50000\n", 329 | "|==========| Accuracy: 49.18 % 4918 / 10000\n", 330 | "This epoch took 45.3488028049469 seconds\n", 331 | "Current learning rate: 0.02952450000000001\n", 332 | "\n", 333 | "Epoch: 29\n", 334 | "|========================================| Accuracy: 55.854 % 27927 / 50000\n", 335 | "|==========| Accuracy: 45.03 % 4503 / 10000\n", 336 | "This epoch took 42.54808688163757 seconds\n", 337 | "Current learning rate: 0.02391484500000001\n", 338 | "\n", 339 | "Epoch: 30\n", 340 | "|========================================| Accuracy: 58.202 % 29101 / 50000\n", 341 | "|==========| Accuracy: 50.08 % 5008 / 10000\n", 342 | "This epoch took 46.533724308013916 seconds\n", 343 | "Current learning rate: 0.02657205000000001\n", 344 | "\n", 345 | "Epoch: 31\n", 346 | "|========================================| Accuracy: 59.262 % 29631 / 50000\n", 347 | "|==========| Accuracy: 50.52 % 5052 / 10000\n", 348 | "This epoch took 46.0679144859314 seconds\n", 349 | "Current learning rate: 0.02657205000000001\n", 350 | "\n", 351 | "Epoch: 32\n", 352 | "|========================================| Accuracy: 60.114 % 30057 / 50000\n", 353 | "|==========| Accuracy: 52.74 % 5274 / 10000\n", 354 | "This epoch took 44.88120102882385 seconds\n", 355 | "Current learning rate: 0.02657205000000001\n", 356 | "\n", 357 | "Epoch: 33\n", 358 | "|========================================| Accuracy: 60.54 % 30270 / 50000\n", 359 | "|==========| Accuracy: 55.82 % 5582 / 10000\n", 360 | "This epoch took 38.706753730773926 seconds\n", 361 | "Current learning rate: 0.02657205000000001\n", 362 | "\n", 363 | "Epoch: 34\n", 364 | "|========================================| Accuracy: 61.202 % 30601 / 50000\n", 365 | "|==========| Accuracy: 53.62 % 5362 / 10000\n", 366 | "This epoch took 40.96643614768982 seconds\n", 367 | "Current learning rate: 0.021523360500000012\n", 368 | "\n", 369 | "Epoch: 35\n", 370 | "|========================================| Accuracy: 63.02 % 31510 / 50000\n", 371 | "|==========| Accuracy: 51.62 % 5162 / 10000\n", 372 | "This epoch took 
42.569278717041016 seconds\n", 373 | "Current learning rate: 0.02391484500000001\n", 374 | "\n", 375 | "Epoch: 36\n", 376 | "|========================================| Accuracy: 63.206 % 31603 / 50000\n", 377 | "|==========| Accuracy: 63.72 % 6372 / 10000\n", 378 | "This epoch took 42.79570960998535 seconds\n", 379 | "Current learning rate: 0.02391484500000001\n", 380 | "\n", 381 | "Epoch: 37\n", 382 | "|========================================| Accuracy: 63.456 % 31728 / 50000\n", 383 | "|==========| Accuracy: 52.97 % 5297 / 10000\n", 384 | "This epoch took 39.450865030288696 seconds\n", 385 | "Current learning rate: 0.02391484500000001\n", 386 | "\n", 387 | "Epoch: 38\n", 388 | "|========================================| Accuracy: 64.164 % 32082 / 50000\n", 389 | "|==========| Accuracy: 63.85 % 6385 / 10000\n", 390 | "This epoch took 40.91485786437988 seconds\n", 391 | "Current learning rate: 0.02391484500000001\n", 392 | "\n", 393 | "Epoch: 39\n", 394 | "|========================================| Accuracy: 64.876 % 32438 / 50000\n", 395 | "|==========| Accuracy: 62.66 % 6266 / 10000\n", 396 | "This epoch took 40.16771578788757 seconds\n", 397 | "Current learning rate: 0.01937102445000001\n", 398 | "\n", 399 | "Epoch: 40\n", 400 | "|========================================| Accuracy: 66.234 % 33117 / 50000\n", 401 | "|==========| Accuracy: 63.21 % 6321 / 10000\n", 402 | "This epoch took 40.66132354736328 seconds\n", 403 | "Current learning rate: 0.021523360500000012\n", 404 | "\n", 405 | "Epoch: 41\n", 406 | "|========================================| Accuracy: 66.352 % 33176 / 50000\n", 407 | "|==========| Accuracy: 62.34 % 6234 / 10000\n", 408 | "This epoch took 46.38554286956787 seconds\n", 409 | "Current learning rate: 0.021523360500000012\n", 410 | "\n", 411 | "Epoch: 42\n", 412 | "|========================================| Accuracy: 66.85 % 33425 / 50000\n", 413 | "|==========| Accuracy: 60.7 % 6070 / 10000\n", 414 | "This epoch took 42.69945788383484 seconds\n", 415 | "Current learning rate: 0.021523360500000012\n", 416 | "\n", 417 | "Epoch: 43\n", 418 | "|========================================| Accuracy: 67.504 % 33752 / 50000\n", 419 | "|==========| Accuracy: 59.91 % 5991 / 10000\n", 420 | "This epoch took 43.23206090927124 seconds\n", 421 | "Current learning rate: 0.021523360500000012\n", 422 | "\n", 423 | "Epoch: 44\n", 424 | "|========================================| Accuracy: 67.722 % 33861 / 50000\n" 425 | ] 426 | }, 427 | { 428 | "name": "stdout", 429 | "output_type": "stream", 430 | "text": [ 431 | "|==========| Accuracy: 58.57 % 5857 / 10000\n", 432 | "This epoch took 41.02703857421875 seconds\n", 433 | "Current learning rate: 0.01743392200500001\n", 434 | "\n", 435 | "Epoch: 45\n", 436 | "|========================================| Accuracy: 68.858 % 34429 / 50000\n", 437 | "|==========| Accuracy: 61.41 % 6141 / 10000\n", 438 | "This epoch took 41.44154691696167 seconds\n", 439 | "Current learning rate: 0.01937102445000001\n", 440 | "\n", 441 | "Epoch: 46\n", 442 | "|========================================| Accuracy: 69.02 % 34510 / 50000\n", 443 | "|==========| Accuracy: 63.16 % 6316 / 10000\n", 444 | "This epoch took 43.859628200531006 seconds\n", 445 | "Current learning rate: 0.01937102445000001\n", 446 | "\n", 447 | "Epoch: 47\n", 448 | "|========================================| Accuracy: 69.556 % 34778 / 50000\n", 449 | "|==========| Accuracy: 63.7 % 6370 / 10000\n", 450 | "This epoch took 41.11183762550354 seconds\n", 451 | "Current learning rate: 
0.01937102445000001\n", 452 | "\n", 453 | "Epoch: 48\n", 454 | "|========================================| Accuracy: 69.896 % 34948 / 50000\n", 455 | "|==========| Accuracy: 60.98 % 6098 / 10000\n", 456 | "This epoch took 42.747894048690796 seconds\n", 457 | "Current learning rate: 0.01937102445000001\n", 458 | "\n", 459 | "Epoch: 49\n", 460 | "|========================================| Accuracy: 70.256 % 35128 / 50000\n", 461 | "|==========| Accuracy: 63.98 % 6398 / 10000\n", 462 | "This epoch took 41.37175130844116 seconds\n", 463 | "Current learning rate: 0.015690529804500006\n", 464 | "Best training accuracy overall: 63.98\n", 465 | "|==========| Accuracy: 63.98 % 6398 / 10000\n", 466 | "Testing accuracy: 63.98\n", 467 | "Mon Jun 8 21:33:18 2020 stats\n", 468 | "\n", 469 | " 510155235 function calls (507136221 primitive calls) in 2053.806 seconds\n", 470 | "\n", 471 | " Ordered by: internal time\n", 472 | " List reduced from 824 to 20 due to restriction <20>\n", 473 | "\n", 474 | " ncalls tottime percall cumtime percall filename:lineno(function)\n", 475 | " 49401 743.334 0.015 743.334 0.015 {method 'item' of 'torch._C._TensorBase' objects}\n", 476 | " 19550 222.799 0.011 222.799 0.011 {method 'run_backward' of 'torch._C._EngineBase' objects}\n", 477 | " 5747602 107.058 0.000 107.058 0.000 {method 'add_' of 'torch._C._TensorBase' objects}\n", 478 | " 1183200 102.777 0.000 102.777 0.000 {built-in method conv2d}\n", 479 | " 3010000 59.049 0.000 140.632 0.000 functional.py:192(normalize)\n", 480 | " 3010000 45.654 0.000 205.622 0.000 functional.py:43(to_tensor)\n", 481 | " 1915802 45.251 0.000 45.251 0.000 {method 'mul_' of 'torch._C._TensorBase' objects}\n", 482 | " 394400 41.219 0.000 41.219 0.000 {built-in method batch_norm}\n", 483 | " 3010000 40.894 0.000 40.894 0.000 {method 'tobytes' of 'numpy.ndarray' objects}\n", 484 | " 1915850 39.603 0.000 39.603 0.000 {method 'zero_' of 'torch._C._TensorBase' objects}\n", 485 | " 3010000 32.589 0.000 32.589 0.000 {method 'div' of 'torch._C._TensorBase' objects}\n", 486 | " 3010000 25.541 0.000 25.541 0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}\n", 487 | " 6020000 25.312 0.000 25.312 0.000 {built-in method as_tensor}\n", 488 | " 3010000 24.033 0.000 24.033 0.000 {method 'sub_' of 'torch._C._TensorBase' objects}\n", 489 | " 3010000 20.338 0.000 116.898 0.000 Image.py:2644(fromarray)\n", 490 | " 6020000 19.078 0.000 19.078 0.000 {method 'transpose' of 'torch._C._TensorBase' objects}\n", 491 | " 3034650 18.921 0.000 18.921 0.000 {method 'view' of 'torch._C._TensorBase' objects}\n", 492 | " 3010000 18.467 0.000 18.467 0.000 {method 'float' of 'torch._C._TensorBase' objects}\n", 493 | " 3010000 15.967 0.000 15.967 0.000 {method 'clone' of 'torch._C._TensorBase' objects}\n", 494 | " 19550 15.331 0.001 168.502 0.009 sgd.py:71(step)\n", 495 | "\n", 496 | "\n", 497 | "Wall time: 34min 13s\n" 498 | ] 499 | } 500 | ], 501 | "source": [ 502 | "%%time \n", 503 | "\n", 504 | "cProfile.run('eval()', 'stats')\n", 505 | "p = pstats.Stats('stats')\n", 506 | "p.strip_dirs().sort_stats(1).print_stats(20)\n", 507 | "p.dump_stats('prof\\\\main.prof')" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "metadata": {}, 514 | "outputs": [], 515 | "source": [] 516 | } 517 | ], 518 | "metadata": { 519 | "kernelspec": { 520 | "display_name": "Python 3", 521 | "language": "python", 522 | "name": "python3" 523 | }, 524 | "language_info": { 525 | "codemirror_mode": { 526 | "name": "ipython", 527 | "version": 3 528 | }, 
529 | "file_extension": ".py", 530 | "mimetype": "text/x-python", 531 | "name": "python", 532 | "nbconvert_exporter": "python", 533 | "pygments_lexer": "ipython3", 534 | "version": "3.7.6" 535 | } 536 | }, 537 | "nbformat": 4, 538 | "nbformat_minor": 4 539 | } 540 | -------------------------------------------------------------------------------- /notebooks/main.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# main function for decomposition\n", 8 | "### Author: Yiming Fang" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import torch\n", 18 | "import torch.nn as nn\n", 19 | "import torch.optim as optim\n", 20 | "import torch.nn.functional as F\n", 21 | "import torch.backends.cudnn as cudnn\n", 22 | "from torch.optim.lr_scheduler import StepLR\n", 23 | "\n", 24 | "import torchvision\n", 25 | "import torchvision.transforms as transforms\n", 26 | "from torchvision import models\n", 27 | "\n", 28 | "import tensorly as tl\n", 29 | "import tensorly\n", 30 | "from itertools import chain\n", 31 | "from tensorly.decomposition import parafac, partial_tucker\n", 32 | "\n", 33 | "import os\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "import numpy as np\n", 36 | "import time\n", 37 | "\n", 38 | "from nets import *\n", 39 | "from decomp import *" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# load data\n", 49 | "def load_mnist():\n", 50 | " print('==> Loading data..')\n", 51 | " transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n", 52 | " transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n", 53 | "\n", 54 | " trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)\n", 55 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)\n", 56 | "\n", 57 | " testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)\n", 58 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)\n", 59 | " return trainloader, testloader\n", 60 | "\n", 61 | "def load_cifar10():\n", 62 | " print('==> Loading data..')\n", 63 | " transform_train = transforms.Compose([\n", 64 | " transforms.RandomCrop(32, padding=4),\n", 65 | " transforms.RandomHorizontalFlip(),\n", 66 | " transforms.ToTensor(),\n", 67 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", 68 | " ])\n", 69 | "\n", 70 | " transform_test = transforms.Compose([\n", 71 | " transforms.ToTensor(),\n", 72 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", 73 | " ])\n", 74 | "\n", 75 | " trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n", 76 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)\n", 77 | "\n", 78 | " testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n", 79 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)\n", 80 | " \n", 81 | " return trainloader, testloader\n", 82 | "\n", 83 | 
"# ImageNet is no longer publically available\n", 84 | "def load_imagenet():\n", 85 | " print('==> Loading data..')\n", 86 | " transform_train = transforms.Compose([\n", 87 | " transforms.RandomResizedCrop(224),\n", 88 | " transforms.RandomHorizontalFlip(),\n", 89 | " transforms.ToTensor(),\n", 90 | " transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n", 91 | " ])\n", 92 | "\n", 93 | " transform_test = transforms.Compose([\n", 94 | " transforms.Resize(256),\n", 95 | " transforms.CenterCrop(224),\n", 96 | " transforms.ToTensor(),\n", 97 | " transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n", 98 | " ])\n", 99 | "\n", 100 | " trainset = torchvision.datasets.ImageNet(root='./data', train=True, download=True, transform=transform_train)\n", 101 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)\n", 102 | "\n", 103 | " testset = torchvision.datasets.ImageNet(root='./data', train=False, download=True, transform=transform_test)\n", 104 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)\n", 105 | " \n", 106 | " return trainloader, testloader\n", 107 | "\n", 108 | "def load_cifar100():\n", 109 | " print('==> Loading data..')\n", 110 | " transform_train = transforms.Compose([\n", 111 | " transforms.RandomCrop(32, padding=4),\n", 112 | " transforms.RandomHorizontalFlip(),\n", 113 | " transforms.ToTensor(),\n", 114 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", 115 | " ])\n", 116 | "\n", 117 | " transform_test = transforms.Compose([\n", 118 | " transforms.ToTensor(),\n", 119 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", 120 | " ])\n", 121 | "\n", 122 | " trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)\n", 123 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)\n", 124 | "\n", 125 | " testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)\n", 126 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)\n", 127 | " \n", 128 | " return trainloader, testloader" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 3, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "# build model\n", 138 | "def build(model, decomp='cp'):\n", 139 | " print('==> Building model..')\n", 140 | " tl.set_backend('pytorch')\n", 141 | " full_net = model\n", 142 | " full_net = full_net.to(device)\n", 143 | " torch.save(full_net, 'models/model')\n", 144 | " if decomp:\n", 145 | " decompose(decomp)\n", 146 | " net = torch.load(\"models/model\").cuda()\n", 147 | " print(net)\n", 148 | " print('==> Done')\n", 149 | " return net\n", 150 | " \n", 151 | "# training\n", 152 | "def train(epoch, train_acc, model):\n", 153 | " print('\\nEpoch: ', epoch)\n", 154 | " model.train()\n", 155 | " criterion = nn.CrossEntropyLoss()\n", 156 | " train_loss = 0\n", 157 | " correct = 0\n", 158 | " total = 0\n", 159 | " print('|', end='')\n", 160 | " for batch_idx, (inputs, targets) in enumerate(trainloader):\n", 161 | " inputs, targets = inputs.to(device), targets.to(device)\n", 162 | " optimizer.zero_grad()\n", 163 | " outputs = model(inputs)\n", 164 | " loss = criterion(outputs, targets)\n", 165 | " loss.backward()\n", 166 | " optimizer.step()\n", 167 | " train_loss += loss.item()\n", 168 | " _, predicted 
131 | { 132 | "cell_type": "code", 133 | "execution_count": 3, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "# build model\n", 138 | "def build(model, decomp='cp'):\n", 139 | " print('==> Building model..')\n", 140 | " tl.set_backend('pytorch')\n", 141 | " full_net = model\n", 142 | " full_net = full_net.to(device)\n", 143 | " torch.save(full_net, 'models/model')\n", 144 | " if decomp:\n", 145 | " decompose(decomp)\n", 146 | " net = torch.load(\"models/model\").to(device)\n", 147 | " print(net)\n", 148 | " print('==> Done')\n", 149 | " return net\n", 150 | " \n", 151 | "# training\n", 152 | "def train(epoch, train_acc, model):\n", 153 | " print('\\nEpoch: ', epoch)\n", 154 | " model.train()\n", 155 | " criterion = nn.CrossEntropyLoss()\n", 156 | " train_loss = 0\n", 157 | " correct = 0\n", 158 | " total = 0\n", 159 | " print('|', end='')\n", 160 | " for batch_idx, (inputs, targets) in enumerate(trainloader):\n", 161 | " inputs, targets = inputs.to(device), targets.to(device)\n", 162 | " optimizer.zero_grad()\n", 163 | " outputs = model(inputs)\n", 164 | " loss = criterion(outputs, targets)\n", 165 | " loss.backward()\n", 166 | " optimizer.step()\n", 167 | " train_loss += loss.item()\n", 168 | " _, predicted = outputs.max(1)\n", 169 | " total += targets.size(0)\n", 170 | " correct += predicted.eq(targets).sum().item()\n", 171 | " if batch_idx % 10 == 0:\n", 172 | " print('=', end='')\n", 173 | " print('|', 'Accuracy:', 100. * correct / total,'% ', correct, '/', total)\n", 174 | " train_acc.append(correct / total)\n", 175 | " return train_acc\n", 176 | "\n", 177 | "# testing\n", 178 | "def test(test_acc, model):\n", 179 | " model.eval()\n", 180 | " test_loss = 0\n", 181 | " correct = 0\n", 182 | " total = 0\n", 183 | " criterion = nn.CrossEntropyLoss()\n", 184 | " with torch.no_grad():\n", 185 | " print('|', end='')\n", 186 | " for batch_idx, (inputs, targets) in enumerate(testloader):\n", 187 | " inputs, targets = inputs.to(device), targets.to(device)\n", 188 | " outputs = model(inputs)\n", 189 | " loss = criterion(outputs, targets)\n", 190 | " test_loss += loss.item()\n", 191 | " _, predicted = outputs.max(1)\n", 192 | " total += targets.size(0)\n", 193 | " correct += predicted.eq(targets).sum().item()\n", 194 | " if batch_idx % 10 == 0:\n", 195 | " print('=', end='')\n", 196 | " acc = 100. * correct / total\n", 197 | " print('|', 'Accuracy:', acc, '% ', correct, '/', total)\n", 198 | " test_acc.append(correct / total) \n", 199 | " return test_acc\n", 200 | "\n", 201 | "# decompose\n", 202 | "def decompose(decomp):\n", 203 | " model = torch.load(\"models/model\")\n", 204 | " model.eval()\n", 205 | " model.cpu()\n", 206 | " for i, key in enumerate(model.features._modules.keys()):\n", 207 | " if i >= len(model.features._modules.keys()) - 2:\n", 208 | " break\n", 209 | " conv_layer = model.features._modules[key]\n", 210 | " if isinstance(conv_layer, torch.nn.modules.conv.Conv2d):\n", 211 | " rank = max(conv_layer.weight.data.numpy().shape) // 10\n", 212 | " if decomp == 'cp':\n", 213 | " model.features._modules[key] = cp_decomposition_conv_layer(conv_layer, rank)\n", 214 | " if decomp == 'tucker': \n", 215 | " ranks = [int(np.ceil(conv_layer.weight.data.numpy().shape[0] / 3)), \n", 216 | " int(np.ceil(conv_layer.weight.data.numpy().shape[1] / 3))]\n", 217 | " model.features._modules[key] = tucker_decomposition_conv_layer(conv_layer, ranks)\n", 218 | " if decomp == 'tt':\n", 219 | " model.features._modules[key] = tt_decomposition_conv_layer(conv_layer, rank)\n", 220 | " torch.save(model, 'models/model')\n", 221 | " return model\n", 222 | "\n", 223 | "# Run functions\n", 224 | "def run_train(i, model):\n", 225 | " train_acc = []\n", 226 | " test_acc = []\n", 227 | " for epoch in range(i):\n", 228 | " s = time.time()\n", 229 | " train_acc = train(epoch, train_acc, model)\n", 230 | " test_acc = test(test_acc, model)\n", 231 | " scheduler.step()\n", 232 | " e = time.time()\n", 233 | " print('This epoch took', e - s, 'seconds')\n", 234 | " print('Current learning rate: ', scheduler.get_lr()[0])\n", 235 | " print('Best testing accuracy overall: ', max(test_acc))\n", 236 | " return train_acc, test_acc" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 4, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "# main function\n", 246 | "def run_all(dataset, decomp=None, iterations=100, rate=0.05): \n", 247 | " global trainloader, testloader, device, optimizer, scheduler\n", 248 | " \n", 249 | " # choose an appropriate learning rate\n", 250 | " rate = rate\n", 251 | " \n", 252 | " # choose dataset from (mnist, cifar10, cifar100)\n", 253 | " if dataset == 'mnist':\n", 254 | " trainloader, testloader = load_mnist()\n", 255 | " model = LeNet()\n", 256 | " if 
dataset == 'cifar10':\n", 257 | " trainloader, testloader = load_cifar10()\n", 258 | " model = VGG('VGG19')\n", 259 | " if dataset == 'cifar100':\n", 260 | " trainloader, testloader = load_cifar100()\n", 261 | " model = VGG('VGG19')\n", 262 | " \n", 263 | " # check GPU availability\n", 264 | " device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", 265 | " \n", 266 | " # choose decomposition algorithm from (CP, Tucker, TT)\n", 267 | " net = build(model, decomp)\n", 268 | " optimizer = optim.SGD(net.parameters(), lr=rate, momentum=0.9, weight_decay=5e-4)\n", 269 | " scheduler = StepLR(optimizer, step_size=5, gamma=0.9)\n", 270 | " train_acc, test_acc = run_train(iterations, net)\n", 271 | " \n", 272 | " if not decomp:\n", 273 | " decomp = 'full'\n", 274 | " \n", 275 | " filename = dataset + '_' + decomp\n", 276 | " torch.save(net, 'models/' + filename)\n", 277 | " np.save('curves/' + filename + '_train', train_acc)\n", 278 | " np.save('curves/' + filename + '_test', test_acc)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "==> Loading data..\n", 291 | "==> Building model..\n", 292 | "Net(\n", 293 | " (features): Sequential(\n", 294 | " (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n", 295 | " (1): ReLU(inplace=True)\n", 296 | " (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", 297 | " (3): ReLU(inplace=True)\n", 298 | " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 299 | " )\n", 300 | " (classifier): Sequential(\n", 301 | " (0): Dropout(p=0.25, inplace=False)\n", 302 | " (1): Linear(in_features=9216, out_features=128, bias=True)\n", 303 | " (2): ReLU(inplace=True)\n", 304 | " (3): Dropout(p=0.5, inplace=False)\n", 305 | " (4): Linear(in_features=128, out_features=10, bias=True)\n", 306 | " )\n", 307 | ")\n", 308 | "==> Done\n", 309 | "\n", 310 | "Epoch: 0\n", 311 | "|===============================================| Accuracy: 91.945 % 55167 / 60000\n", 312 | "|==========| Accuracy: 98.28 % 9828 / 10000\n", 313 | "This epoch took 13.191731214523315 seconds\n", 314 | "Current learning rate: 0.05\n", 315 | "\n", 316 | "Epoch: 1\n", 317 | "|===============================================| Accuracy: 97.02166666666666 % 58213 / 60000\n", 318 | "|==========| Accuracy: 98.61 % 9861 / 10000\n", 319 | "This epoch took 12.04975700378418 seconds\n", 320 | "Current learning rate: 0.05\n", 321 | "\n", 322 | "Epoch: 2\n", 323 | "|===============================================| Accuracy: 97.565 % 58539 / 60000\n", 324 | "|==========| Accuracy: 98.8 % 9880 / 10000\n", 325 | "This epoch took 12.05876111984253 seconds\n", 326 | "Current learning rate: 0.05\n", 327 | "\n", 328 | "Epoch: 3\n", 329 | "|===============================================| Accuracy: 97.915 % 58749 / 60000\n", 330 | "|==========| Accuracy: 98.9 % 9890 / 10000\n", 331 | "This epoch took 12.11062240600586 seconds\n", 332 | "Current learning rate: 0.05\n", 333 | "\n", 334 | "Epoch: 4\n", 335 | "|===============================================| Accuracy: 97.955 % 58773 / 60000\n", 336 | "|==========| Accuracy: 98.95 % 9895 / 10000\n", 337 | "This epoch took 12.165474891662598 seconds\n", 338 | "Current learning rate: 0.04050000000000001\n", 339 | "\n", 340 | "Epoch: 5\n", 341 | "|===============================================| Accuracy: 98.34833333333333 % 59009 / 60000\n", 342 | "|==========| Accuracy: 98.95 % 9895 / 
10000\n", 343 | "This epoch took 12.157521963119507 seconds\n", 344 | "Current learning rate: 0.045000000000000005\n", 345 | "\n", 346 | "Epoch: 6\n", 347 | "|===============================================| Accuracy: 98.545 % 59127 / 60000\n", 348 | "|==========| Accuracy: 99.0 % 9900 / 10000\n", 349 | "This epoch took 12.109601020812988 seconds\n", 350 | "Current learning rate: 0.045000000000000005\n", 351 | "\n", 352 | "Epoch: 7\n", 353 | "|===============================================| Accuracy: 98.56 % 59136 / 60000\n", 354 | "|==========| Accuracy: 98.68 % 9868 / 10000\n", 355 | "This epoch took 12.11560845375061 seconds\n", 356 | "Current learning rate: 0.045000000000000005\n", 357 | "\n", 358 | "Epoch: 8\n", 359 | "|===============================================| Accuracy: 98.64333333333333 % 59186 / 60000\n", 360 | "|==========| Accuracy: 99.15 % 9915 / 10000\n", 361 | "This epoch took 12.045796155929565 seconds\n", 362 | "Current learning rate: 0.045000000000000005\n", 363 | "\n", 364 | "Epoch: 9\n", 365 | "|===============================================| Accuracy: 98.69166666666666 % 59215 / 60000\n", 366 | "|==========| Accuracy: 99.06 % 9906 / 10000\n", 367 | "This epoch took 11.952045917510986 seconds\n", 368 | "Current learning rate: 0.03645000000000001\n", 369 | "\n", 370 | "Epoch: 10\n", 371 | "|===============================================| Accuracy: 98.92666666666666 % 59356 / 60000\n", 372 | "|==========| Accuracy: 99.22 % 9922 / 10000\n", 373 | "This epoch took 12.087683916091919 seconds\n", 374 | "Current learning rate: 0.04050000000000001\n", 375 | "\n", 376 | "Epoch: 11\n", 377 | "|===============================================| Accuracy: 98.82333333333334 % 59294 / 60000\n", 378 | "|==========| Accuracy: 99.15 % 9915 / 10000\n", 379 | "This epoch took 11.951063871383667 seconds\n", 380 | "Current learning rate: 0.04050000000000001\n", 381 | "\n", 382 | "Epoch: 12\n", 383 | "|===============================================| Accuracy: 98.93333333333334 % 59360 / 60000\n", 384 | "|==========| Accuracy: 99.11 % 9911 / 10000\n", 385 | "This epoch took 11.971977710723877 seconds\n", 386 | "Current learning rate: 0.04050000000000001\n", 387 | "\n", 388 | "Epoch: 13\n", 389 | "|===============================================| Accuracy: 98.83833333333334 % 59303 / 60000\n", 390 | "|==========| Accuracy: 99.02 % 9902 / 10000\n", 391 | "This epoch took 11.958030462265015 seconds\n", 392 | "Current learning rate: 0.04050000000000001\n", 393 | "\n", 394 | "Epoch: 14\n", 395 | "|===============================================| Accuracy: 98.97 % 59382 / 60000\n", 396 | "|==========| Accuracy: 99.11 % 9911 / 10000\n", 397 | "This epoch took 11.924121141433716 seconds\n", 398 | "Current learning rate: 0.03280500000000001\n", 399 | "\n", 400 | "Epoch: 15\n", 401 | "|===============================================| Accuracy: 99.04 % 59424 / 60000\n", 402 | "|==========| Accuracy: 99.17 % 9917 / 10000\n", 403 | "This epoch took 11.97598123550415 seconds\n", 404 | "Current learning rate: 0.03645000000000001\n", 405 | "\n", 406 | "Epoch: 16\n", 407 | "|===============================================| Accuracy: 99.12333333333333 % 59474 / 60000\n", 408 | "|==========| Accuracy: 99.3 % 9930 / 10000\n", 409 | "This epoch took 11.95703387260437 seconds\n", 410 | "Current learning rate: 0.03645000000000001\n", 411 | "\n", 412 | "Epoch: 17\n", 413 | "|===============================================| Accuracy: 99.04333333333334 % 59426 / 60000\n", 414 | "|==========| Accuracy: 99.12 % 
9912 / 10000\n", 415 | "This epoch took 11.95304274559021 seconds\n", 416 | "Current learning rate: 0.03645000000000001\n", 417 | "\n", 418 | "Epoch: 18\n", 419 | "|===============================================| Accuracy: 99.06166666666667 % 59437 / 60000\n", 420 | "|==========| Accuracy: 99.26 % 9926 / 10000\n", 421 | "This epoch took 12.056767225265503 seconds\n", 422 | "Current learning rate: 0.03645000000000001\n", 423 | "\n", 424 | "Epoch: 19\n", 425 | "|===============================================| Accuracy: 99.13833333333334 % 59483 / 60000\n", 426 | "|==========| Accuracy: 99.18 % 9918 / 10000\n", 427 | "This epoch took 12.055768728256226 seconds\n", 428 | "Current learning rate: 0.02952450000000001\n", 429 | "\n", 430 | "Epoch: 20\n", 431 | "|===============================================| Accuracy: 99.13166666666666 % 59479 / 60000\n", 432 | "|==========| Accuracy: 99.13 % 9913 / 10000\n", 433 | "This epoch took 12.040835857391357 seconds\n", 434 | "Current learning rate: 0.03280500000000001\n", 435 | "\n", 436 | "Epoch: 21\n", 437 | "|===============================================| Accuracy: 99.195 % 59517 / 60000\n", 438 | "|==========| Accuracy: 99.24 % 9924 / 10000\n", 439 | "This epoch took 12.066712379455566 seconds\n", 440 | "Current learning rate: 0.03280500000000001\n", 441 | "\n", 442 | "Epoch: 22\n", 443 | "|===============================================| Accuracy: 99.25666666666666 % 59554 / 60000\n", 444 | "|==========| Accuracy: 99.2 % 9920 / 10000\n", 445 | "This epoch took 12.057764053344727 seconds\n", 446 | "Current learning rate: 0.03280500000000001\n", 447 | "\n", 448 | "Epoch: 23\n", 449 | "|===============================================| Accuracy: 99.2 % 59520 / 60000\n", 450 | "|==========| Accuracy: 99.21 % 9921 / 10000\n", 451 | "This epoch took 12.059757709503174 seconds\n", 452 | "Current learning rate: 0.03280500000000001\n", 453 | "\n", 454 | "Epoch: 24\n", 455 | "|===============================================| Accuracy: 99.25666666666666 % 59554 / 60000\n", 456 | "|==========| Accuracy: 99.16 % 9916 / 10000\n", 457 | "This epoch took 12.063747882843018 seconds\n", 458 | "Current learning rate: 0.02657205000000001\n", 459 | "\n", 460 | "Epoch: 25\n", 461 | "|===============================================| Accuracy: 99.28166666666667 % 59569 / 60000\n", 462 | "|==========| Accuracy: 99.22 % 9922 / 10000\n", 463 | "This epoch took 12.0886812210083 seconds\n", 464 | "Current learning rate: 0.02952450000000001\n", 465 | "\n", 466 | "Epoch: 26\n", 467 | "|===============================================| Accuracy: 99.34333333333333 % 59606 / 60000\n", 468 | "|==========| Accuracy: 99.27 % 9927 / 10000\n", 469 | "This epoch took 12.026846408843994 seconds\n", 470 | "Current learning rate: 0.02952450000000001\n", 471 | "\n", 472 | "Epoch: 27\n", 473 | "|===============================================| Accuracy: 99.27166666666666 % 59563 / 60000\n", 474 | "|==========| Accuracy: 99.25 % 9925 / 10000\n", 475 | "This epoch took 12.094665050506592 seconds\n", 476 | "Current learning rate: 0.02952450000000001\n", 477 | "\n", 478 | "Epoch: 28\n", 479 | "|===============================================| Accuracy: 99.31666666666666 % 59590 / 60000\n", 480 | "|==========| Accuracy: 99.15 % 9915 / 10000\n", 481 | "This epoch took 12.090675592422485 seconds\n", 482 | "Current learning rate: 0.02952450000000001\n", 483 | "\n", 484 | "Epoch: 29\n", 485 | "|===============================================| Accuracy: 99.32166666666667 % 59593 / 60000\n", 486 | 
"|==========| Accuracy: 99.06 % 9906 / 10000\n", 487 | "This epoch took 12.038814544677734 seconds\n", 488 | "Current learning rate: 0.02391484500000001\n", 489 | "\n", 490 | "Epoch: 30\n", 491 | "|===============================================| Accuracy: 99.32833333333333 % 59597 / 60000\n", 492 | "|==========| Accuracy: 99.15 % 9915 / 10000\n", 493 | "This epoch took 12.056766271591187 seconds\n", 494 | "Current learning rate: 0.02657205000000001\n", 495 | "\n", 496 | "Epoch: 31\n", 497 | "|===============================================| Accuracy: 99.36166666666666 % 59617 / 60000\n", 498 | "|==========| Accuracy: 99.19 % 9919 / 10000\n", 499 | "This epoch took 12.067737340927124 seconds\n", 500 | "Current learning rate: 0.02657205000000001\n", 501 | "\n", 502 | "Epoch: 32\n", 503 | "|===============================================| Accuracy: 99.41 % 59646 / 60000\n", 504 | "|==========| Accuracy: 99.21 % 9921 / 10000\n", 505 | "This epoch took 12.074690818786621 seconds\n", 506 | "Current learning rate: 0.02657205000000001\n", 507 | "\n", 508 | "Epoch: 33\n" 509 | ] 510 | }, 511 | { 512 | "name": "stdout", 513 | "output_type": "stream", 514 | "text": [ 515 | "|=====================" 516 | ] 517 | } 518 | ], 519 | "source": [ 520 | "%%time\n", 521 | "run_all('mnist')" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": {}, 528 | "outputs": [], 529 | "source": [ 530 | "%%time\n", 531 | "run_all('mnist', 'cp', rate=0.01)" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": null, 537 | "metadata": {}, 538 | "outputs": [], 539 | "source": [ 540 | "%%time\n", 541 | "run_all('mnist', 'tucker')" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "metadata": {}, 548 | "outputs": [], 549 | "source": [ 550 | "%%time\n", 551 | "run_all('mnist', 'tt')" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": null, 557 | "metadata": {}, 558 | "outputs": [], 559 | "source": [ 560 | "%%time\n", 561 | "run_all('cifar10', iterations=200)" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": null, 567 | "metadata": {}, 568 | "outputs": [], 569 | "source": [ 570 | "%%time\n", 571 | "run_all('cifar10', 'tucker', iterations=200)" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": null, 577 | "metadata": {}, 578 | "outputs": [], 579 | "source": [ 580 | "%%time\n", 581 | "run_all('cifar100', 'tt', iterations=200)" 582 | ] 583 | } 584 | ], 585 | "metadata": { 586 | "kernelspec": { 587 | "display_name": "Python 3", 588 | "language": "python", 589 | "name": "python3" 590 | }, 591 | "language_info": { 592 | "codemirror_mode": { 593 | "name": "ipython", 594 | "version": 3 595 | }, 596 | "file_extension": ".py", 597 | "mimetype": "text/x-python", 598 | "name": "python", 599 | "nbconvert_exporter": "python", 600 | "pygments_lexer": "ipython3", 601 | "version": "3.7.6" 602 | } 603 | }, 604 | "nbformat": 4, 605 | "nbformat_minor": 4 606 | } 607 | -------------------------------------------------------------------------------- /transform_based_network/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | from .trainer import * 3 | from .transform_layer import * 4 | from .transform_nets import * 5 | from .multiprocess import * -------------------------------------------------------------------------------- /transform_based_network/multiprocess.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch_dct as dct 5 | 6 | import time 7 | from torch.multiprocessing import Pool, Queue, Process, set_start_method 8 | import torch.multiprocessing as torch_mp 9 | import multiprocessing as mp_ 10 | 11 | import sys 12 | sys.path.append('../') 13 | from common import * 14 | from transform_based_network import * 15 | 16 | 17 | def t_product_slice(A, B, C, i): 18 | C[i, ...] = torch.mm(A[i, ...], B[i, ...]) 19 | 20 | def t_product_multiprocess(A, B): 21 | tmp = torch_mp.get_context('spawn') 22 | 23 | assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]) 24 | dct_A = torch.transpose(dct.dct(torch.transpose(A, 0, 2)), 0, 2) 25 | dct_B = torch.transpose(dct.dct(torch.transpose(B, 0, 2)), 0, 2) 26 | dct_C = torch.zeros(A.shape[0], A.shape[1], B.shape[2]) 27 | 28 | #dct_A.share_memory_() 29 | #dct_B.share_memory_() 30 | #dct_C.share_memory_() 31 | 32 | processes = [] 33 | # num_cores = torch_mp.cpu_count() 34 | for i in range(dct_C.shape[0]): 35 | p = tmp.Process(target=t_product_slice, args=(dct_A, dct_B, dct_C, i)) 36 | p.start() 37 | processes.append(p) 38 | for p in processes: 39 | p.join() 40 | 41 | C = torch.transpose(dct.idct(torch.transpose(dct_C, 0, 2)), 0, 2) 42 | return C 43 | 44 | ''' 45 | if __name__ == '__main__': 46 | A = torch.ones(17, 1000, 1000) 47 | B = torch.ones(17, 1000, 1000) 48 | t_product_multiprocess(A, B) 49 | ''' -------------------------------------------------------------------------------- /transform_based_network/new_transform.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch_dct as dct\n", 11 | "import torch.nn as nn\n", 12 | "import torch.optim as optim\n", 13 | "import torch.nn.functional as F\n", 14 | "import torch.backends.cudnn as cudnn\n", 15 | "import torch.nn.init as init\n", 16 | "from torch.optim.lr_scheduler import StepLR\n", 17 | "\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "import time\n", 20 | "import pkbar\n", 21 | "import math\n", 22 | "\n", 23 | "import sys\n", 24 | "sys.path.append('../')\n", 25 | "from common import *\n", 26 | "from transform_based_network import *" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 14, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "def t_product_in_network(A, B):\n", 36 | " device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", 37 | " assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])\n", 38 | " dct_C = torch.zeros(A.shape[0], A.shape[1], B.shape[2])\n", 39 | " dct_A = torch_apply(dct.dct, A)\n", 40 | " for k in range(A.shape[0]):\n", 41 | " dct_C[k, ...] 
= torch.mm(dct_A[k, ...], B[k, ...])\n", 42 | " return dct_C #.to(device)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 15, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "class tNN(nn.Module):\n", 52 | " def __init__(self):\n", 53 | " super(tNN, self).__init__()\n", 54 | " W, B = [], []\n", 55 | " self.num_layers = 10\n", 56 | " for i in range(self.num_layers):\n", 57 | " W.append(nn.Parameter(torch.Tensor(28, 28, 28)))\n", 58 | " B.append(nn.Parameter(torch.Tensor(28, 28, 1)))\n", 59 | " self.W = nn.ParameterList(W)\n", 60 | " self.B = nn.ParameterList(B)\n", 61 | " self.reset_parameters()\n", 62 | "\n", 63 | " def forward(self, x):\n", 64 | " for i in range(self.num_layers):\n", 65 | " x = torch.add(t_product(self.W[i], x), self.B[i])\n", 66 | " x = F.relu(x)\n", 67 | " return x\n", 68 | "\n", 69 | " def reset_parameters(self):\n", 70 | " for i in range(self.num_layers):\n", 71 | " init.kaiming_uniform_(self.W[i], a=math.sqrt(5))\n", 72 | " fan_in, _ = init._calculate_fan_in_and_fan_out(self.W[i])\n", 73 | " bound = 1 / math.sqrt(fan_in)\n", 74 | " init.uniform_(self.B[i], -bound, bound)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 22, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "class new_tNN(nn.Module):\n", 84 | " def __init__(self):\n", 85 | " super(new_tNN, self).__init__()\n", 86 | " W, B = [], []\n", 87 | " self.num_layers = 4\n", 88 | " for i in range(self.num_layers):\n", 89 | " W.append(nn.Parameter(torch.Tensor(28, 28, 28)))\n", 90 | " B.append(nn.Parameter(torch.Tensor(28, 28, 1)))\n", 91 | " self.W = nn.ParameterList(W)\n", 92 | " self.B = nn.ParameterList(B)\n", 93 | " self.reset_parameters()\n", 94 | "\n", 95 | " def forward(self, x):\n", 96 | " x = torch_apply(dct.dct, x)\n", 97 | " for i in range(self.num_layers):\n", 98 | " x = torch.add(t_product_in_network(self.W[i], x), self.B[i])\n", 99 | " x = F.relu(x)\n", 100 | " x = torch_apply(dct.idct, x)\n", 101 | " return x\n", 102 | "\n", 103 | " def reset_parameters(self):\n", 104 | " for i in range(self.num_layers):\n", 105 | " init.kaiming_uniform_(self.W[i], a=math.sqrt(5))\n", 106 | " fan_in, _ = init._calculate_fan_in_and_fan_out(self.W[i])\n", 107 | " bound = 1 / math.sqrt(fan_in)\n", 108 | " init.uniform_(self.B[i], -bound, bound)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 28, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "==> Loading data..\n", 121 | "Epoch 1 training:\n", 122 | " 23/600 [>.............................] 
- 8.4s" 123 | ] 124 | }, 125 | { 126 | "ename": "KeyboardInterrupt", 127 | "evalue": "", 128 | "output_type": "error", 129 | "traceback": [ 130 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 131 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 132 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 26\u001b[0m \u001b[0mpbar_train\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpkbar\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mPbar\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'Epoch '\u001b[0m\u001b[1;33m+\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;34m' training:'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m60000\u001b[0m\u001b[1;33m/\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 27\u001b[1;33m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 28\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 29\u001b[0m \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mraw_img\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mn\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m28\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 133 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 361\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 362\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 363\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 364\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 365\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 134 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 401\u001b[0m \u001b[1;32mdef\u001b[0m 
\u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 402\u001b[0m \u001b[0mindex\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 403\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 404\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 405\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 135 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 45\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 136 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36m\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m \u001b[0mdata\u001b[0m 
\u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 45\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 137 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\datasets\\mnist.py\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, index)\u001b[0m\n\u001b[0;32m 95\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 96\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 97\u001b[1;33m \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 98\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 138 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, img)\u001b[0m\n\u001b[0;32m 59\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 60\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 61\u001b[1;33m \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 62\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 63\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 139 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, tensor)\u001b[0m\n\u001b[0;32m 210\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mNormalized\u001b[0m \u001b[0mTensor\u001b[0m \u001b[0mimage\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 211\u001b[0m \"\"\"\n\u001b[1;32m--> 212\u001b[1;33m \u001b[1;32mreturn\u001b[0m 
\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnormalize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstd\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minplace\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 213\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 214\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 140 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\functional.py\u001b[0m in \u001b[0;36mnormalize\u001b[1;34m(tensor, mean, std, inplace)\u001b[0m\n\u001b[0;32m 290\u001b[0m \u001b[0mmean\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_tensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmean\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 291\u001b[0m \u001b[0mstd\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_tensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstd\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 292\u001b[1;33m \u001b[1;32mif\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mstd\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0many\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 293\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'std evaluated to zero after conversion to {}, leading to division by zero.'\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 294\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mmean\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 141 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\tensor.py\u001b[0m in \u001b[0;36mwrapped\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 22\u001b[1;33m \u001b[1;32mreturn\u001b[0m 
\u001b[0mf\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 23\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 24\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mNotImplemented\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 142 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: " 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "lr_rate = 0.001\n", 148 | "epochs_num = 20\n", 149 | "device = 'cpu' # 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", 150 | "batch_size = 100\n", 151 | "train_loader, test_loader = load_mnist_multiprocess(batch_size)\n", 152 | "\n", 153 | "module = new_tNN()\n", 154 | "module = module.to(device)\n", 155 | "\n", 156 | "Loss_function = nn.CrossEntropyLoss()\n", 157 | "optimizer = torch.optim.SGD(module.parameters(), lr=lr_rate)\n", 158 | "\n", 159 | "test_loss_epoch = []\n", 160 | "test_acc_epoch = []\n", 161 | "train_loss_epoch = []\n", 162 | "train_acc_epoch = []\n", 163 | "time_list = []\n", 164 | "\n", 165 | "# begain train\n", 166 | "for epoch in range(epochs_num):\n", 167 | " since = time.time()\n", 168 | " running_loss = 0.0\n", 169 | " running_acc = 0.0\n", 170 | " module.train()\n", 171 | "\n", 172 | " pbar_train = pkbar.Pbar(name='Epoch '+str(epoch+1)+' training:', target=60000/batch_size)\n", 173 | " for i, data in enumerate(train_loader):\n", 174 | " img, label = data\n", 175 | " img = raw_img(img, batch_size, n=28)\n", 176 | " img = img.to(device)\n", 177 | " label = label.to(device)\n", 178 | "\n", 179 | " # forward\n", 180 | " out = module(img)\n", 181 | "\n", 182 | " # softmax function\n", 183 | " out = torch.transpose(scalar_tubal_func(out), 0, 1)\n", 184 | " loss = Loss_function(out, label)\n", 185 | " running_loss += loss.item()\n", 186 | " _, pred = torch.max(out, 1)\n", 187 | " running_acc += (pred == label).float().mean()\n", 188 | "\n", 189 | " # backward\n", 190 | " optimizer.zero_grad()\n", 191 | " loss.backward()\n", 192 | " optimizer.step()\n", 193 | " \n", 194 | " pbar_train.update(i)\n", 195 | "\n", 196 | " print('[{Epoch}/{Epochs_num}] Loss:{Running_loss} Acc:{Running_acc}'\n", 197 | " .format(Epoch=epoch + 1, Epochs_num=epochs_num, Running_loss=(running_loss / i),\n", 198 | " Running_acc=running_acc / i))\n", 199 | " train_loss_epoch.append(running_loss / i)\n", 200 | " train_acc_epoch.append((running_acc / i) * 100)\n", 201 | "\n", 202 | " module.eval()\n", 203 | " eval_loss = 0.0\n", 204 | " eval_acc = 0.0\n", 205 | "\n", 206 | " pbar_test = pkbar.Pbar(name='Epoch '+str(epoch+1)+' test', target=10000/batch_size)\n", 207 | " for i, data in enumerate(test_loader):\n", 208 | " img, label = data\n", 209 | " img = cifar_img_process(img)\n", 210 | " img = img.to(device)\n", 211 | " label = label.to(device)\n", 212 | "\n", 213 | " with torch.no_grad():\n", 214 | " out = module(img)\n", 215 | " out = torch.transpose(scalar_tubal_func(out), 0, 1)\n", 216 | " loss = Loss_function(out, label)\n", 217 | " eval_loss += loss.item()\n", 218 | " _, pred = torch.max(out, 1)\n", 219 | " eval_acc += (pred == label).float().mean()\n", 220 | "\n", 221 | " pbar_test.update(i)\n", 222 | "\n", 223 | " print('Test Loss: {Eval_loss}, Acc: {Eval_acc}'\n", 224 | " .format(Eval_loss=eval_loss / len(test_loader), \n", 225 | " Eval_acc=eval_acc / 
len(test_loader)))\n", 226 | " test_loss_epoch.append(eval_loss / len(test_loader))\n", 227 | " test_acc_epoch.append((eval_acc / len(test_loader)) * 100)\n", 228 | " time_list.append(time.time() - since)\n", 229 | "\n", 230 | " if np.isnan(eval_loss):\n", 231 | " print('invalid loss')\n", 232 | " break" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [] 248 | } 249 | ], 250 | "metadata": { 251 | "kernelspec": { 252 | "display_name": "Python 3", 253 | "language": "python", 254 | "name": "python3" 255 | } 256 | }, 257 | "nbformat": 4, 258 | "nbformat_minor": 4 259 | } 260 | -------------------------------------------------------------------------------- /transform_based_network/tNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_dct as dct 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torch.backends.cudnn as cudnn 7 | import torch.nn.init as init 8 | from torch.optim.lr_scheduler import StepLR 9 | 10 | import matplotlib.pyplot as plt 11 | import time 12 | import pkbar 13 | import math 14 | 15 | import sys 16 | sys.path.append('../') 17 | from common import * 18 | from transform_based_network import * 19 | 20 | 21 | class tNN(nn.Module): 22 | def __init__(self): 23 | super(tNN, self).__init__() 24 | W, B = [], [] 25 | self.num_layers = 4 26 | for i in range(self.num_layers): 27 | W.append(nn.Parameter(torch.Tensor(28, 28, 28))) 28 | B.append(nn.Parameter(torch.Tensor(28, 28, 1))) 29 | self.W = nn.ParameterList(W) 30 | self.B = nn.ParameterList(B) 31 | self.reset_parameters() 32 | 33 | def forward(self, x): 34 | for i in range(self.num_layers): 35 | x = torch.add(t_product(self.W[i], x), self.B[i]) 36 | x = F.relu(x) 37 | return x 38 | 39 | def reset_parameters(self): 40 | for i in range(self.num_layers): 41 | init.kaiming_uniform_(self.W[i], a=math.sqrt(5)) 42 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.W[i]) 43 | bound = 1 / math.sqrt(fan_in) 44 | init.uniform_(self.B[i], -bound, bound) -------------------------------------------------------------------------------- /transform_based_network/trainer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | from common import * 4 | from transform_based_network import * 5 | 6 | import torch 7 | import torch_dct as dct 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | import torch.backends.cudnn as cudnn 12 | from torch.optim.lr_scheduler import StepLR 13 | 14 | 15 | def train_step_transform(epoch, train_acc, model, trainloader, optimizer): 16 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 17 | model = model.to(device) 18 | model.train() 19 | train_loss = 0.0 20 | correct = 0 21 | total = 0 22 | criterion = nn.CrossEntropyLoss() 23 | 24 | print('\nEpoch: ', epoch) 25 | print('|', end='') 26 | for batch_idx, (inputs, labels) in enumerate(trainloader): 27 | inputs = inputs.to(device) 28 | if not inputs.shape[0] == 100: 29 | break 30 | 31 | labels = labels.to(device) 32 | outputs = model(inputs) 33 | outputs = torch.transpose(scalar_tubal_func(outputs), 0, 1) 34 | 35 | optimizer.zero_grad() 36 | loss = criterion(outputs, labels) 37 | # print(loss) 38 
|         if np.isnan(loss.item()):
39 |             print('Training terminated due to instability')
40 |             break
41 |         loss.backward()
42 |         optimizer.step()
43 |         train_loss += loss.item()
44 |         _, predicted = torch.max(outputs, 1)
45 |         total += labels.size(0)
46 |         correct += predicted.eq(labels).sum().item()
47 |         if batch_idx % 10 == 0:
48 |             print('=', end='')
49 |     print('|', 'Accuracy:', correct / total, 'Loss:', train_loss / total)
50 |     train_acc.append(correct / total)
51 |     return train_acc
52 | 
53 | def test(test_acc, model, testloader):
54 |     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
55 |     model = model.to(device)
56 |     model.eval()
57 |     test_loss = 0
58 |     correct = 0
59 |     total = 0
60 |     criterion = nn.CrossEntropyLoss()
61 |     s = time.time()
62 |     with torch.no_grad():
63 |         print('|', end='')
64 |         for batch_idx, (inputs, targets) in enumerate(testloader):
65 |             inputs = inputs.to(device)
66 |             if not inputs.shape[0] == 100:
67 |                 break  # drop the last incomplete batch; the layers assume a fixed batch of 100
68 |             targets = targets.to(device)
69 |             outputs = model(inputs)
70 |             outputs = torch.transpose(scalar_tubal_func(outputs), 0, 1)
71 |             loss = criterion(outputs, targets)
72 |             test_loss += loss.item()
73 |             _, predicted = outputs.max(1)
74 |             total += targets.size(0)
75 |             correct += predicted.eq(targets).sum().item()
76 |             if batch_idx % 10 == 0:
77 |                 print('=', end='')
78 |     e = time.time()
79 |     print('|', ' Test accuracy:', correct / total, 'Test loss:', test_loss / total)
80 |     print('The inference time is', e - s, 'seconds')
81 |     test_acc.append(correct / total)
82 |     return test_acc, e - s
83 | 
84 | def train_transform(i, model, trainloader, testloader, optimizer):
85 |     train_acc, test_acc = [], []
86 |     scheduler = StepLR(optimizer, step_size=1, gamma=0.95)
87 | 
88 |     for epoch in range(i):
89 |         s = time.time()
90 |         train_acc = train_step_transform(epoch, train_acc, model, trainloader, optimizer)
91 |         e = time.time()
92 |         test_acc, _ = test(test_acc, model, testloader)
93 |         scheduler.step()
94 | 
95 |         print('This epoch took', e - s, 'seconds to train')
96 |         print('Current learning rate:', scheduler.get_last_lr()[0])
97 |     print('Best test accuracy overall:', max(test_acc))
98 |     return train_acc, test_acc
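99 | 
100 | # Usage sketch, assuming the Transform_Net from transform_nets.py and the MNIST
101 | # loader used in the notebooks; the learning rate here is illustrative only.
102 | # model = Transform_Net(batch_size=100)
103 | # optimizer = optim.SGD(model.parameters(), lr=0.01)
104 | # trainloader, testloader = load_mnist_multiprocess(100)
105 | # train_acc, test_acc = train_transform(5, model, trainloader, testloader, optimizer)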
--------------------------------------------------------------------------------
/transform_based_network/transform_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_dct as dct
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import torch.backends.cudnn as cudnn
7 | from torch.optim.lr_scheduler import StepLR
8 | import sys
9 | sys.path.append('../')
10 | from common import *
11 | from transform_based_network import *
12 | 
13 | 
14 | class Transform_Layer(nn.Module):
15 |     def __init__(self, n, size_in, m, size_out):
16 |         super().__init__()
17 |         self.size_in = size_in
18 |         self.size_out = size_out
19 |         weights = torch.randn(n, size_out, size_in) * 0.01
20 |         bias = torch.randn(1, size_out, m)
21 |         self.weights = nn.Parameter(weights, requires_grad=True)
22 |         self.bias = nn.Parameter(bias, requires_grad=True)
23 | 
24 |     def forward(self, x):
25 |         device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
26 |         return t_product_v2(self.weights, x).to(device)  # NOTE: self.bias is unused here
27 |         # torch.add(t_product_v2(self.weights, x).to(device), self.bias)
28 | 
29 | class T_Layer(nn.Module):
30 |     def __init__(self, dct_w, dct_b):
31 |         super(T_Layer, self).__init__()
32 |         self.weights = nn.Parameter(dct_w, requires_grad=True)
33 |         self.bias = nn.Parameter(dct_b, requires_grad=True)
34 | 
35 |     def forward(self, dct_x):
36 |         x = torch.mm(self.weights, dct_x) + self.bias
37 |         return x
--------------------------------------------------------------------------------
/transform_based_network/transform_nets.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | from common import *
4 | from transform_based_network import *
5 | 
6 | import torch
7 | import torch_dct as dct
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | import torch.backends.cudnn as cudnn
12 | from torch.optim.lr_scheduler import StepLR
13 | 
14 | 
15 | class Transform_Net(nn.Module):
16 |     def __init__(self, batch_size):
17 |         super(Transform_Net, self).__init__()
18 |         self.features = nn.Sequential(
19 |             #Transform_Layer(28, 28, batch_size, 28),
20 |             #nn.ReLU(inplace=True),
21 |             #Transform_Layer(28, 28, batch_size, 28),
22 |             #nn.ReLU(inplace=True),
23 |             Transform_Layer(28, 28, batch_size, 10),
24 |         )
25 | 
26 |     def forward(self, x):
27 |         self.x = torch_shift(x)
28 |         self.x = self.features(self.x)
29 |         return self.x / 1e2  # empirical scaling to keep the tubal softmax stable
30 | 
31 | 
32 | class Conv_Transform_Net(nn.Module):
33 |     def __init__(self, batch_size):
34 |         super(Conv_Transform_Net, self).__init__()
35 |         self.first = nn.Sequential(
36 |             Transform_Layer(28, 28, batch_size, 28),
37 |             nn.ReLU(inplace=True),
38 |         )
39 |         self.intermediate = nn.Sequential(
40 |             nn.Conv2d(28, 28, kernel_size=3, padding=1),
41 |             nn.Conv2d(28, 28, kernel_size=3, padding=1),
42 |             nn.Conv2d(28, 28, kernel_size=1, padding=0),
43 |         )
44 |         self.last = Transform_Layer(28, 28, batch_size, 10)
45 | 
46 |     def forward(self, x):
47 |         x = torch_shift(x)
48 |         x = self.first(x)
49 | 
50 |         x = torch.transpose(x, 0, 2)
51 |         x = torch.transpose(x, 1, 2)
52 |         x = x.reshape(-1, 28, 4, 7)  # batch first; infer the batch dimension
53 | 
54 |         x = self.intermediate(x)
55 | 
56 |         x = x.reshape(-1, 28, 28)
57 |         x = torch_shift(x)
58 |         x = self.last(x)
59 | 
60 |         return x / 5e2  # empirical scaling, as above
61 | 
62 | 
63 | class Conv_Transform_Net_CIFAR(nn.Module):
64 |     def __init__(self, batch_size):
65 |         super(Conv_Transform_Net_CIFAR, self).__init__()
66 |         self.batch_size = batch_size
67 |         self.first = nn.Sequential(
68 |             Transform_Layer(96, 32, batch_size, 32),
69 |             nn.ReLU(inplace=True),
70 |             Transform_Layer(96, 32, batch_size, 32),
71 |             nn.ReLU(inplace=True),
72 |         )
73 | 
74 |         self.intermediate = nn.Sequential(
75 |             nn.Conv2d(96, 96, kernel_size=3, padding=1),
76 |             nn.ReLU(inplace=True),
77 |             nn.Conv2d(96, 96, kernel_size=3, padding=1),
78 |             nn.ReLU(inplace=True),
79 |             nn.Conv2d(96, 96, kernel_size=3, padding=1),
80 |             nn.ReLU(inplace=True),
81 |             nn.Conv2d(96, 96, kernel_size=3, padding=1),
82 |             nn.ReLU(inplace=True),
83 |             nn.Conv2d(96, 96, kernel_size=3, padding=1),
84 |             nn.ReLU(inplace=True),
85 |             nn.Conv2d(96, 96, kernel_size=1, padding=0),
86 |             nn.ReLU(inplace=True),
87 |         )
88 |         self.last = Transform_Layer(96, 32, batch_size, 10)
89 | 
90 |     def forward(self, x):
91 |         x = torch.reshape(x, (self.batch_size, 96, 32))
92 |         x = torch_shift(x)
93 |         x = self.first(x)
94 | 
95 |         x = torch.transpose(x, 0, 2)
96 |         x = torch.transpose(x, 1, 2)
97 |         x = x.reshape(self.batch_size, 96, 4, 8)
98 |         x = self.intermediate(x)
99 | 
100 |         x = x.reshape(self.batch_size, 96, 32)
101 |         x = torch_shift(x)
102 |         x = self.last(x)
103 | 
104 |         return x
105 | 
106 | class Frontal_Slice(nn.Module):
107 |     def __init__(self, dct_w, dct_b):
108 |         super(Frontal_Slice, self).__init__()
109 |         self.device = dct_w.device
110 |         self.dct_linear = T_Layer(dct_w, dct_b)
111 | 
112 |     def forward(self, dct_x):
113 |         return self.dct_linear(dct_x.to(self.device))
114 | 
115 | 
116 | class Ensemble(nn.Module):
117 |     def __init__(self, shape, device='cpu'):
118 |         super(Ensemble, self).__init__()
119 |         self.device = device
120 |         self.models = nn.ModuleList()  # ModuleList so the per-slice weights register as parameters
121 |         for i in range(shape[0]):
122 |             dct_w, dct_b = make_weights(shape, device)
123 |             model = Frontal_Slice(dct_w[i, ...], dct_b[i, ...])
124 |             self.models.append(model.to(device))
125 | 
126 |     def forward(self, x):
127 |         s = self.models[0].dct_linear.weights.shape
128 |         result = torch.empty(x.shape[0], s[0], x.shape[2], device=self.device)
129 |         dct_x = torch_apply(dct.dct, x).to(self.device)
130 | 
131 |         for i in range(len(self.models)):
132 |             result[i, ...] = self.models[i](dct_x[i, ...])
133 | 
134 |         result = torch_apply(dct.idct, result)
135 |         result = torch_shift(result)
136 |         softmax = scalar_tubal_func(result)
137 |         return torch.transpose(softmax, 0, 1)
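138 | 
139 | # Usage sketch (sizes illustrative; see make_weights and raw_img in util.py):
140 | # ensemble = Ensemble(shape=(28, 28, 100))   # one Frontal_Slice per frontal slice
141 | # x = raw_img(images, batch_size=100, n=28)  # images: a (100, 1, 28, 28) MNIST batch
142 | # outputs = ensemble(x)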
--------------------------------------------------------------------------------
/transform_based_network/util.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | from common import *
4 | from transform_based_network import *
5 | 
6 | import torch
7 | import torch_dct as dct
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | import torch.backends.cudnn as cudnn
12 | from torch.optim.lr_scheduler import StepLR
13 | import numpy as np
14 | 
15 | def bcirc(A):
16 |     l, m, n = A.shape
17 |     bcirc_A = []
18 |     for i in range(l):
19 |         bcirc_A.append(torch.roll(A, shifts=i, dims=0))
20 |     return torch.cat(bcirc_A, dim=2).reshape(l*m, l*n)
21 | 
22 | def hankel(A):
23 |     l, m, n = A.shape
24 |     circ = torch.zeros(2 * l + 1, m, n)
25 |     circ[l, ...] = torch.zeros(m, n)
26 |     for i in range(l):
27 |         k = circ.shape[0] - i - 1
28 |         circ[i, ...] = A[i, ...]
29 |         circ[k, ...] = A[i, ...]
30 |     hankel_A = []
31 |     for i in range(1, l + 1):
32 |         hankel_A.append(circ[i : i + l, ...])
33 |     return torch.cat(hankel_A, dim=2).reshape(l*m, l*n)
34 | 
35 | def toeplitz(A):
36 |     l, m, n = A.shape
37 |     circ = torch.zeros(2 * l - 1, m, n)
38 |     for i in range(l):
39 |         circ[i + l - 1, ...] = A[i, ...]
40 |         circ[l - i - 1, ...] = A[i, ...]
41 |     toeplitz_A = []
42 |     for i in range(l):
43 |         toeplitz_A.append(circ[l - i - 1: 2 * l - i - 1, ...])
44 |     return torch.cat(toeplitz_A, dim=2).reshape(l*m, l*n)
45 | 
46 | def tph(A):
47 |     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
48 |     return (toeplitz(A) + hankel(A)).to(device)  # no rewrapping, keeps autograd intact
49 | 
50 | def shift(X):
51 |     A = X.clone()
52 |     for i in range(1, A.shape[0]):
53 |         k = A.shape[0] - i - 1
54 |         A[k, ...] -= A[k + 1, ...]
55 |     return A
56 | 
57 | def t_product(A, B):
58 |     assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])
59 |     prod = torch.mm(tph(shift(A)), bcirc(B)[:, 0:B.shape[2]])
60 |     return prod.reshape(A.shape[0], A.shape[1], B.shape[2])
61 | 
62 | def dct_t_product(A, B):
63 |     assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])
64 |     dct_C = torch.zeros(A.shape[0], A.shape[1], B.shape[2])
65 |     dct_A = torch_apply(dct.dct, A)
66 |     dct_B = torch_apply(dct.dct, B)
67 |     for k in range(A.shape[0]):
68 |         dct_C[k, ...] = torch.mm(dct_A[k, ...], dct_B[k, ...])
69 |     return torch_apply(dct.idct, dct_C)
70 | 
71 | def t_product_fft(A, B):  # despite the name: a direct block-circulant product, no FFT
72 |     assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])
73 |     prod = torch.mm(bcirc(A), bcirc(B)[:, 0:B.shape[2]])
74 |     return prod.reshape(A.shape[0], A.shape[1], B.shape[2])
75 | 
76 | def t_product_fft_v2(A, B):
77 |     assert(A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1])
78 |     dct_C = np.zeros((A.shape[0], A.shape[1], B.shape[2]), dtype=complex)
79 |     for k in range(A.shape[0]):
80 |         dct_C[k, ...] = np.fft.fft(A, axis=0)[k, ...] @ np.fft.fft(B, axis=0)[k, ...]
81 |     return np.real(np.fft.ifft(dct_C, axis=0))
82 | 
83 | def scalar_tubal_func(output_tensor):
84 |     l, m, n = output_tensor.shape
85 |     lateral_slices = [output_tensor[:, :, i].reshape(l, m, 1) for i in range(n)]
86 |     h_slice = []
87 |     for s in lateral_slices:
88 |         h_slice.append(h_func_dct(s))
89 |     pro_matrix = torch.stack(h_slice, dim=2)
90 |     return pro_matrix.reshape(m, n)
91 | 
92 | def h_func_dct(lateral_slice):
93 |     l, m, n = lateral_slice.shape
94 |     dct_slice = dct.dct(lateral_slice)
95 |     tubes = [dct_slice[i, :, 0] for i in range(l)]
96 |     h_tubes = []
97 |     for tube in tubes:
98 |         h_tubes.append(torch.exp(tube) / torch.sum(torch.exp(tube)))
99 |     res_slice = torch.stack(h_tubes, dim=0).reshape(l, m, n)
100 |     idct_a = dct.idct(res_slice)
101 |     return torch.sum(idct_a, dim=0)
102 | 
103 | def torch_apply(func, x):
104 |     x = func(torch.transpose(x, 0, 2))
105 |     return torch.transpose(x, 0, 2)
106 | 
107 | def make_weights(shape, device='cpu', scale=0.01):
108 |     w = torch.randn(shape[0], 10, shape[1]) * scale
109 |     b = torch.randn(shape[0], 10, shape[2]) * scale
110 |     dct_w = torch_apply(dct.dct, w).to(device)
111 |     dct_b = torch_apply(dct.dct, b).to(device)
112 |     return dct_w, dct_b
113 | 
114 | def to_categorical(y, num_classes):
115 |     categorical = torch.empty(len(y), num_classes)
116 |     for i in range(len(y)):
117 |         categorical[i, :] = torch.eye(num_classes, num_classes)[y[i]]
118 |     return categorical
119 | 
120 | def torch_shift(A):
121 |     x = A.squeeze()
122 |     x = torch.transpose(x, 0, 2)
123 |     x = torch.transpose(x, 0, 1)
124 |     return x
125 | 
126 | def cifar_img_process(raw_img):
127 |     k, l, m, n = raw_img.shape
128 |     img_list = torch.split(raw_img, split_size_or_sections=1, dim=0)
129 |     slices = []
130 |     for img in img_list:
131 |         img = img.reshape(l, m, n)
132 |         frontal = torch.cat([img[i, :, :] for i in range(l)], dim=0)
133 |         single_img = torch.transpose(frontal.reshape(1, l * m, n), 0, 2)
134 |         slices.append(single_img)
135 |     ultra_img = torch.cat(slices, dim=2)
136 |     return ultra_img
137 | 
138 | def raw_img(img, batch_size, n):
139 |     img_raw = img.reshape(batch_size, n * n)
140 |     single_img = torch.split(img_raw, split_size_or_sections=1, dim=0)
141 |     single_img_T = [torch.transpose(i.reshape(n, n, 1), 0, 1) for i in single_img]
142 |     ultra_img = torch.cat(single_img_T, dim=2)
143 |     return ultra_img
144 | 
145 | 
146 | 
147 | 
148 | 
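149 | # Usage sketch (sizes illustrative): for A of shape (l, m, n) and B of shape
150 | # (l, n, p), each t-product variant above returns an (l, m, p) result, e.g.:
151 | # A = torch.randn(4, 5, 6); B = torch.randn(4, 6, 7)
152 | # C = t_product(A, B)   # C.shape == (4, 5, 7)
--------------------------------------------------------------------------------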