├── .editorconfig ├── .gitignore ├── DCGAN_pytorch ├── README.md ├── recourses │ ├── 1_epoch.png │ └── paper_figure.png └── src │ ├── model.py │ └── train.py ├── LICENSE ├── README.md ├── cnn_architectures ├── README.md ├── resources │ └── vgg_architectures.png └── vgg16.py ├── neural_style_transfer ├── README.md ├── content_images │ ├── alisha.jpg │ ├── los_angeles.jpg │ ├── los_angeles_night.jpg │ └── snow.jpg ├── generated_images │ ├── alisha_1.png │ ├── alisha_2.png │ ├── gen_1000.png │ ├── generated_1000.png │ ├── generated_1500.png │ ├── generated_200.png │ ├── generated_2000.png │ ├── generated_600.png │ └── snow_1.png ├── nst.py ├── readme_results │ ├── result_1.png │ └── result_2.png └── style_images │ ├── dog_abstract_style.jpg │ └── starry_night_full.jpg ├── object_detection ├── README.md ├── object_localization │ ├── README.md │ ├── pet_localization.ipynb │ └── readme_resources │ │ ├── result_1.png │ │ ├── result_2.png │ │ ├── result_3.png │ │ └── result_4.png └── yolo │ └── model.py ├── poetry.lock ├── pyproject.toml └── transfer_learning ├── README.md └── transfer_learn.py /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = tab 5 | indent_size = 4 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | # data 133 | transfer_learning/data 134 | DCGAN_pytorch/src/data 135 | 136 | # DS_Store 137 | .DS_Store 138 | transfer_learning/.DS_Store -------------------------------------------------------------------------------- /DCGAN_pytorch/README.md: -------------------------------------------------------------------------------- 1 | # DCGAN_pytorch 2 | DCGAN implementation of the original [paper](https://arxiv.org/abs/1511.06434) in PyTorch. I tried to stay as close to the [paper](https://arxiv.org/abs/1511.06434) as possible. 3 | 4 | 5 | Most of the values used in the implementation are straight from the paper. 6 | The implemented structure of the network is the one depicted in the paper as well: 7 |
8 | ![structure](https://github.com/wilhelmberghammer/MachineLearning/blob/main/DCGAN_pytorch/recourses/paper_figure.png) 9 | 10 | *Figure 1 in the original [paper](https://arxiv.org/abs/1511.06434)* 11 | 12 | 13 | 14 | ### Results on MNIST 15 | **After one Epoch:** 16 |
17 | ![after one epoch](https://github.com/wilhelmberghammer/MachineLearning/blob/main/DCGAN_pytorch/recourses/1_epoch.png) 18 | -------------------------------------------------------------------------------- /DCGAN_pytorch/recourses/1_epoch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/DCGAN_pytorch/recourses/1_epoch.png -------------------------------------------------------------------------------- /DCGAN_pytorch/recourses/paper_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/DCGAN_pytorch/recourses/paper_figure.png -------------------------------------------------------------------------------- /DCGAN_pytorch/src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | # Discriminator (-> paper) 7 | class Discriminator(nn.Module): 8 | def __init__(self, img_channels, features_d): 9 | super(Discriminator, self).__init__() 10 | self.dis = nn.Sequential( 11 | 12 | nn.Conv2d(img_channels, features_d, kernel_size=4, stride=2, padding=1), 13 | nn.LeakyReLU(.2), 14 | 15 | nn.Conv2d(features_d, features_d*2, kernel_size=4, stride=2, padding=1, bias=False), 16 | nn.BatchNorm2d(features_d*2), 17 | nn.LeakyReLU(.2), 18 | 19 | nn.Conv2d(features_d*2, features_d*4, kernel_size=4, stride=2, padding=1, bias=False), 20 | nn.BatchNorm2d(features_d*4), 21 | nn.LeakyReLU(.2), 22 | 23 | nn.Conv2d(features_d*4, features_d*8, kernel_size=4, stride=2, padding=1, bias=False), 24 | nn.BatchNorm2d(features_d*8), 25 | nn.LeakyReLU(.2), 26 | 27 | nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), 28 | nn.Sigmoid(), 29 | ) 30 | 31 | def forward(self, x): 32 | return self.dis(x) 33 | 34 | 35 | 36 | # Generator (-> paper) 37 | class Generator(nn.Module): 38 | def __init__(self, z_dim, img_channels, features_g): 39 | super(Generator, self).__init__() 40 | self.gen = nn.Sequential( 41 | 42 | nn.ConvTranspose2d(z_dim, features_g*16, kernel_size=4, stride=1, padding=0), 43 | nn.BatchNorm2d(features_g*16), 44 | nn.ReLU(), 45 | 46 | nn.ConvTranspose2d(features_g*16, features_g*8, kernel_size=4, stride=2, padding=1), 47 | nn.BatchNorm2d(features_g*8), 48 | nn.ReLU(), 49 | 50 | nn.ConvTranspose2d(features_g*8, features_g*4, kernel_size=4, stride=2, padding=1), 51 | nn.BatchNorm2d(features_g*4), 52 | nn.ReLU(), 53 | 54 | nn.ConvTranspose2d(features_g*4, features_g*2, kernel_size=4, stride=2, padding=1), 55 | nn.BatchNorm2d(features_g*2), 56 | nn.ReLU(), 57 | 58 | nn.ConvTranspose2d(features_g*2, img_channels, kernel_size=4, stride=2, padding=1), 59 | nn.Tanh(), 60 | ) 61 | 62 | def forward(self, x): 63 | return self.gen(x) 64 | 65 | 66 | 67 | def init_weights(model): 68 | for i in model.modules(): 69 | if isinstance(i, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)): 70 | nn.init.normal_(i.weight.data, .0, .02) 71 | 72 | -------------------------------------------------------------------------------- /DCGAN_pytorch/src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from torch.utils.data import DataLoader 7 | 8 | import 
matplotlib.pyplot as plt 9 | 10 | from model import Discriminator, Generator, init_weights 11 | 12 | 13 | # cuda if GPU is available, otherwise train on cpu 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | print(f"Using {device} for training!") 16 | 17 | # hyperparameters (-> paper) 18 | img_size = 64 19 | img_channels = 1 # 3 for rgb 20 | z_dim = 100 # noise dim 21 | feature_dis = 64 22 | feature_gen = 64 23 | 24 | EPOCHS = 1 25 | batch_size = 64 26 | learning_rate = 2e-4 27 | 28 | 29 | # Transforms 30 | transforms = transforms.Compose([ 31 | transforms.Resize(img_size), 32 | transforms.ToTensor(), 33 | transforms.Normalize([.5], [.5]), # needs to be changed when using more than 1 channel!! 34 | ]) 35 | 36 | # download and load dataset - in this case I'll be using MNIST 37 | dataset = torchvision.datasets.MNIST(root='data/', train=True, download=True, transform=transforms) 38 | # create dataloader 39 | data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 40 | 41 | # init Generator and Discriminator 42 | gen = Generator(z_dim, img_channels, feature_gen).to(device) 43 | dis = Discriminator(img_channels, feature_dis).to(device) 44 | 45 | # init weights 46 | init_weights(gen) 47 | init_weights(dis) 48 | 49 | # optimizers (-> paper) 50 | gen_optim = optim.Adam(gen.parameters(), lr=learning_rate, betas=(.5, .999)) 51 | dis_optim = optim.Adam(dis.parameters(), lr=learning_rate, betas=(.5, .999)) 52 | 53 | # loss function 54 | loss = nn.BCELoss() 55 | 56 | fixed_z = torch.randn((batch_size, z_dim, 1, 1)).to(device) # fixed noise batch for the visualisation after each epoch (was missing; gen(fixed_z) is used below) 57 | # training loop 58 | print('starting training loop...') 59 | for epoch in range(EPOCHS): 60 | for batch, (real_img, _) in enumerate(data_loader): 61 | real_img = real_img.to(device) 62 | z = torch.randn((batch_size, z_dim, 1, 1)).to(device) # noise 63 | fake_img = gen(z) 64 | 65 | # Discriminator training 66 | dis_real_img = dis(real_img).reshape(-1) 67 | loss_dis_real_img = loss(dis_real_img, torch.ones_like(dis_real_img)) 68 | 69 | dis_fake_img = dis(fake_img).reshape(-1) 70 | loss_dis_fake_img = loss(dis_fake_img, torch.zeros_like(dis_fake_img)) 71 | 72 | loss_dis = (loss_dis_real_img + loss_dis_fake_img)/2 73 | dis.zero_grad() 74 | loss_dis.backward(retain_graph=True) # retain_graph because fake_img is reused for the generator update below 75 | dis_optim.step() 76 | 77 | 78 | # Generator training 79 | out = dis(fake_img).reshape(-1) 80 | loss_gen = loss(out, torch.ones_like(out)) 81 | gen.zero_grad() 82 | loss_gen.backward() 83 | gen_optim.step() 84 | 85 | print(f"{epoch}/{EPOCHS}") 86 | 87 | # plot a grid after every epoch - tensorboard would be better but that is a pain in Colab so...
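	# note: the grid below is built from the fixed noise batch defined above, so the same latent vectors can be followed from epoch to epoch
	# instead of plt.show() the grid could also be written to disk, e.g. with torchvision.utils.save_image(img_grid_fake, f'epoch_{epoch}.png') - a suggestion, not part of the original script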
88 | with torch.no_grad(): 89 | fake_img = gen(fixed_z) 90 | img_grid_fake = torchvision.utils.make_grid(fake_img[:32], normalize=True).cpu() 91 | plt.imshow(img_grid_fake.permute(1, 2, 0)) 92 | plt.show() 93 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Wilhelm Berghammer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **Machine Learning** 2 | ### *Currently on hold due to other, more practical projects* 3 | --- 4 | 5 | This is where I learn machine learning 🤷. This means that this repo doesn't cover one specific machine learning topic or project - I work in here when I want to learn or try something. 6 | 7 | I'll try to keep it organised, but I don't want it to be too nested because that becomes a pain to navigate. Instead, I'll try to keep the table of contents up to date. 8 | 9 | **If I use a Jupyter notebook, I'm experimenting with something and the code will probably be a mess ...
so don't take those too seriously.** 10 | 11 | ## Table of Contents 12 | * [Transfer Learning for Computer Vision](https://github.com/wilhelmberghammer/MachineLearning/tree/main/transfer_learning) 13 | 14 | ### Paper implementations 15 | * [DCGAN paper implementation in PyTorch](https://github.com/wilhelmberghammer/MachineLearning/tree/main/DCGAN_pytorch) 16 | * [Neural Style Transfer - implementation in PyTorch](https://github.com/wilhelmberghammer/MachineLearning/tree/main/neural_style_transfer) 17 | 18 | ### Model/Architecture implementations 19 | * [VGG16](https://github.com/wilhelmberghammer/MachineLearning/tree/main/cnn_architectures/vgg16.py) 20 | 21 | ### Working on / On Hold 22 | * [Object detection](https://github.com/wilhelmberghammer/MachineLearning/tree/main/object_detection) 23 | 24 | 25 | ### Coming up (things I want to learn in the near future) 26 | * [Weights & Biases](https://wandb.ai/site) 27 | 28 | -------------------------------------------------------------------------------- /cnn_architectures/README.md: -------------------------------------------------------------------------------- 1 | # CNN architecture implementations in PyTorch 2 | 3 | **This is just for fun** and to get a better feeling for VGG models when using the pretrained ones ... I didn't train this model, I just implemented the architecture (I also added the hyperparameters of the training from the paper). 4 | I didn't implement the training loop either because I didn't intend to train the thing, plus the training loop is not the main part of the paper. 5 | 6 | 7 | ## VGG 8 | Simple PyTorch implementation of the paper [VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION](https://arxiv.org/pdf/1409.1556.pdf) 9 | 10 | I made VGG16 (D) because it's the most popular and the architecture can be modified. 11 | 12 | ![architectures](https://github.com/wilhelmberghammer/MachineLearning/blob/main/cnn_architectures/resources/vgg_architectures.png?raw=true) 13 | 14 | *👆 figure from the paper* 15 | -------------------------------------------------------------------------------- /cnn_architectures/resources/vgg_architectures.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/cnn_architectures/resources/vgg_architectures.png -------------------------------------------------------------------------------- /cnn_architectures/vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | class VGG16(nn.Module): 7 | ''' 8 | VGG16 (D) architecture according to the paper 9 | ''' 10 | def __init__(self, input_channels, n_classes): 11 | super(VGG16, self).__init__() 12 | self.conv_layers = nn.Sequential( 13 | # There are way better ways to do this (iterate over a list for example) ...
this is so I have a better overview 14 | nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=1, padding=1), 15 | nn.ReLU(), 16 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 17 | nn.ReLU(), 18 | 19 | nn.MaxPool2d(kernel_size=2, stride=2), 20 | 21 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 22 | nn.ReLU(), 23 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 24 | nn.ReLU(), 25 | 26 | nn.MaxPool2d(kernel_size=2, stride=2), 27 | 28 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), 29 | nn.ReLU(), 30 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 31 | nn.ReLU(), 32 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 33 | nn.ReLU(), 34 | 35 | nn.MaxPool2d(kernel_size=2, stride=2), 36 | 37 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), 38 | nn.ReLU(), 39 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 40 | nn.ReLU(), 41 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 42 | nn.ReLU(), 43 | 44 | nn.MaxPool2d(kernel_size=2, stride=2), 45 | 46 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 47 | nn.ReLU(), 48 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 49 | nn.ReLU(), 50 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 51 | nn.ReLU(), 52 | 53 | nn.MaxPool2d(kernel_size=2, stride=2), 54 | ) 55 | 56 | self.fc_layers = nn.Sequential( 57 | nn.Linear(512*7*7, 4096), # calculation in paper 58 | nn.ReLU(), 59 | nn.Dropout(.5), # paper: dropout regularisation for the first two fully-connected layers (dropout ratio set to 0.5) 60 | nn.Linear(4096, 4096), 61 | nn.ReLU(), 62 | nn.Dropout(.5), 63 | nn.Linear(4096, n_classes), 64 | ) 65 | 66 | 67 | def forward(self, x): 68 | x = self.conv_layers(x) 69 | x = x.reshape(x.shape[0], -1) 70 | x = self.fc_layers(x) 71 | return x 72 | 73 | 74 | # the following might differ slightly from the paper 75 | # I didn't do the weight initialisation like in the paper - they trained the shallower model and initialized the larger ones with those weights 76 | model = VGG16(input_channels=3, n_classes=1000) 77 | 78 | batch_size = 256 79 | loss = torch.nn.CrossEntropyLoss() 80 | optim = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4) 81 | 82 | 83 | # training-loop ... -------------------------------------------------------------------------------- /neural_style_transfer/README.md: -------------------------------------------------------------------------------- 1 | # Neural Style Transfer 2 | 3 | I tried to follow the paper [A Neural Algorithm of Artistic Style](https://arxiv.org/pdf/1508.06576v2.pdf) but I did some things differently. 4 | 5 | For example, I didn't use their more complicated loss function; I simply used MSE. 6 | 7 | I also didn't use noise as input, I used the content image. 8 | 9 | This might be the reason for some differences. I didn't notice any differences when changing the alpha and beta values. The biggest change happens when changing the learning rate. A higher learning rate leads to more significant changes to the original content. 10 | 11 |
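Concretely, the simplified objective described above boils down to something like the following sketch (illustrative only, not the exact code in `nst.py`; `gen_feats`, `content_feats` and `style_feats` stand for lists of feature maps taken from a few VGG19 layers):

```python
import torch

def simplified_nst_losses(gen_feats, content_feats, style_feats):
	# content term: plain MSE between feature maps
	# style term: MSE between Gram matrices (instead of the paper's exact weighting)
	content_loss, style_loss = 0.0, 0.0
	for g, c, s in zip(gen_feats, content_feats, style_feats):
		b, ch, h, w = g.shape
		content_loss = content_loss + torch.mean((g - c) ** 2)
		G = g.view(b * ch, h * w) @ g.view(b * ch, h * w).t()   # Gram matrix of the generated image's features
		A = s.view(b * ch, h * w) @ s.view(b * ch, h * w).t()   # Gram matrix of the style image's features
		style_loss = style_loss + torch.mean((G - A) ** 2) / 5  # 1/5 ~ one weight per used layer
	return content_loss, style_loss

# tiny smoke test with random "feature maps" (both losses are 0 when all inputs are identical)
feats = [torch.rand(1, 8, 16, 16) for _ in range(5)]
print(simplified_nst_losses(feats, feats, feats))
```

The total loss is then `alpha * content_loss + beta * style_loss`, minimised with Adam directly on the pixels of the generated image (which starts out as a copy of the content image).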
12 | 13 | ### Setup 14 | You will need a virtualenv with `python>=3.8`. 15 | 16 | ```bash 17 | pip install poetry 18 | poetry install 19 | ``` 20 | 21 | ### To run style transfer 22 | ```bash 23 | poetry run python neural_style_transfer/nst.py 24 | ``` 25 | 26 | ## Result Example 27 | **Example 1:** 28 | 29 | ![content alisha](https://github.com/wilhelmberghammer/MachineLearning/blob/main/neural_style_transfer/readme_results/result_1.png) 30 | 31 | **Example 2:** 32 | 33 | ![style dog](https://github.com/wilhelmberghammer/MachineLearning/blob/main/neural_style_transfer/readme_results/result_2.png) 34 | -------------------------------------------------------------------------------- /neural_style_transfer/content_images/alisha.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/content_images/alisha.jpg -------------------------------------------------------------------------------- /neural_style_transfer/content_images/los_angeles.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/content_images/los_angeles.jpg -------------------------------------------------------------------------------- /neural_style_transfer/content_images/los_angeles_night.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/content_images/los_angeles_night.jpg -------------------------------------------------------------------------------- /neural_style_transfer/content_images/snow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/content_images/snow.jpg -------------------------------------------------------------------------------- /neural_style_transfer/generated_images/alisha_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/alisha_1.png -------------------------------------------------------------------------------- /neural_style_transfer/generated_images/alisha_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/alisha_2.png -------------------------------------------------------------------------------- /neural_style_transfer/generated_images/gen_1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/gen_1000.png -------------------------------------------------------------------------------- /neural_style_transfer/generated_images/generated_1000.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/generated_1000.png -------------------------------------------------------------------------------- /neural_style_transfer/generated_images/generated_1500.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/generated_1500.png -------------------------------------------------------------------------------- /neural_style_transfer/generated_images/generated_200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/generated_200.png -------------------------------------------------------------------------------- /neural_style_transfer/generated_images/generated_2000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/generated_2000.png -------------------------------------------------------------------------------- /neural_style_transfer/generated_images/generated_600.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/generated_600.png -------------------------------------------------------------------------------- /neural_style_transfer/generated_images/snow_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/snow_1.png -------------------------------------------------------------------------------- /neural_style_transfer/nst.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torchvision.transforms as transforms 6 | import torchvision.models as models 7 | import torchvision.utils as utils 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | class VGG_19(nn.Module): 13 | def __init__(self): 14 | super(VGG_19, self).__init__() 15 | # model used: VGG19 (like in the paper) 16 | # everything after the 28th layer is technically not needed 17 | self.model = models.vgg19(pretrained=True).features[:30] 18 | 19 | # better results when changing the MaxPool layers to AvgPool (-> paper) 20 | for i, _ in enumerate(self.model): 21 | # Indicies of the MaxPool layers -> replaced by AvgPool with same parameters 22 | if i in [4, 9, 18, 27]: 23 | self.model[i] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) 24 | 25 | def forward(self, x): 26 | features = [] 27 | 28 | for i, layer in enumerate(self.model): 29 | x = layer(x) 30 | # indicies of the conv layers after the now AvgPool layers 31 | if i in [0, 5, 10, 19, 28]: 32 | features.append(x) 33 | return features 34 | 35 | 36 | def load_img(path_to_image, img_size): 37 | transform = transforms.Compose([ 38 | transforms.Resize((img_size, img_size)), 39 | transforms.ToTensor(), 
40 | ]) 41 | img = Image.open(path_to_image) 42 | img = transform(img).unsqueeze(0) 43 | return img 44 | 45 | 46 | def transfer_style(iterations, optimizer, alpha, beta, generated_image, content_image, style_image, show_images=False): 47 | for iter in range(iterations+1): 48 | generated_features = model(generated_image) 49 | content_features = model(content_image) 50 | style_features = model(style_image) 51 | 52 | content_loss = 0 53 | style_loss = 0 54 | 55 | for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features): 56 | batch_size, n_feature_maps, height, width = generated_feature.size() 57 | 58 | # in paper it is 1/2*((g - c)**2) ... but it is easies this way because I don't have to worry about dimensions ... and it workes as well 59 | content_loss += (torch.mean((generated_feature - content_feature) ** 2)) 60 | 61 | # batch_size is one ... so it isn't needed. I still inclueded it for better understanding. 62 | G = torch.mm((generated_feature.view(batch_size*n_feature_maps, height*width)), (generated_feature.view(batch_size*n_feature_maps, height*width)).t()) 63 | A = torch.mm((style_feature.view(batch_size*n_feature_maps, height*width)), (style_feature.view(batch_size*n_feature_maps, height*width)).t()) 64 | 65 | # different in paper!! 66 | E_l = ((G - A)**2) 67 | # w_l ... one divided by the number of active layers with a non-zero loss-weight -> directly from the paper (technically isn't needed) 68 | w_l = 1/5 69 | style_loss += torch.mean(w_l*E_l) 70 | 71 | # I found little difference when changing the alpha and beta values ... still kept it in for better understanding of paper 72 | total_loss = alpha * content_loss + beta * style_loss 73 | optimizer.zero_grad() 74 | total_loss.backward() 75 | optimizer.step() 76 | 77 | 78 | if iter % 100 == 0: 79 | print('-'*15) 80 | print(f'\n{iter} \nTotal Loss: {total_loss.item()} \n Content Loss: {content_loss} \t Style Loss: {style_loss}') 81 | print('-'*15) 82 | 83 | # show image 84 | if show_images == True: 85 | plt.figure(figsize=(10, 10)) 86 | plt.imshow(generated_image.permute(0, 2, 3, 1)[0].cpu().detach().numpy()) 87 | plt.show() 88 | 89 | return generated_image 90 | 91 | #if iter % 500 == 0: 92 | #utils.save_image(generated, f'./gen_{iter}.png') 93 | 94 | 95 | 96 | if __name__ == '__main__': 97 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 98 | 99 | content_img = load_img('./content_images/alisha.jpg', 512).to(device) 100 | style_img = load_img('./style_images/dog_abstract_style.jpg', 512).to(device) 101 | 102 | model = VGG_19().to(device) 103 | # freeze parameters 104 | for param in model.parameters(): 105 | param.requires_grad = False 106 | 107 | # generated image (init) is the content image ... could also be noise 108 | # requires_grad because the network itself is frozen ... the thing we are changine is this 109 | generated_init = content_img.clone().requires_grad_(True) 110 | 111 | 112 | iterations = 200 113 | # the real difference is visibale whe changing the learning rate ... 
1e-2 is rather high -> heavy changes to content image 114 | lr = 1e-2 115 | # I found no real difference when changing these values...this is why I keep them at 1 116 | alpha = 1 117 | beta = 1 118 | 119 | optimizer = optim.Adam([generated_init], lr=lr) 120 | 121 | generated_image = transfer_style(iterations=iterations, 122 | optimizer=optimizer, 123 | alpha=alpha, 124 | beta=beta, 125 | generated_image=generated_init, 126 | content_image=content_img, 127 | style_image=style_img, 128 | show_images=False # only true in jupyter notebook 129 | ) 130 | 131 | utils.save_image(generated_image, f'./gen.png') -------------------------------------------------------------------------------- /neural_style_transfer/readme_results/result_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/readme_results/result_1.png -------------------------------------------------------------------------------- /neural_style_transfer/readme_results/result_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/readme_results/result_2.png -------------------------------------------------------------------------------- /neural_style_transfer/style_images/dog_abstract_style.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/style_images/dog_abstract_style.jpg -------------------------------------------------------------------------------- /neural_style_transfer/style_images/starry_night_full.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/style_images/starry_night_full.jpg -------------------------------------------------------------------------------- /object_detection/README.md: -------------------------------------------------------------------------------- 1 | # Object Detection 2 | 3 | After studying and playing around with CNNs I'll continue with computer vision and start with object localization and object detection. 
4 | 5 | ### Roadmap 6 | * 1st step - learn about Object Localization (I will probably delete this after implementing it better for object detection) 7 | * learn core concepts of object localization 8 | * **just to experiment - the code is a mess** 9 | * test on [The Oxford-IIIT Pet Dataset](https://www.kaggle.com/devdgohil/the-oxfordiiit-pet-dataset) 10 | 11 | 12 | * **2nd step - Object Detection** *On Hold* 13 | * **this is the main part** 14 | * study [OverFeat: Integrated Recognition, Localization and Detection using Conv Nets Paper](https://arxiv.org/abs/1312.6229) 15 | * [YOLO paper](https://arxiv.org/abs/1506.02640) implementation 16 | 17 | 18 | ### Resources 19 | * [Deeplearning.ai YouTube Playlist](https://www.youtube.com/watch?v=GSwYGkTfOKk&list=PLkDaE6sCZn6Gl29AoE31iwdVwSG-KnDzF&index=23) 20 | * [OverFeat: Integrated Recognition, Localization and Detection using Conv Nets Paper](https://arxiv.org/abs/1312.6229) 21 | * [origtnal YOLO paper](https://arxiv.org/abs/1506.02640) 22 | -------------------------------------------------------------------------------- /object_detection/object_localization/README.md: -------------------------------------------------------------------------------- 1 | # Object Localization 2 | **This is just to experiment with the concept of object localization!** 3 | 4 | Don't take the code too serious 5 | 6 | ## Results: 7 | ![reslut 1](https://github.com/wilhelmberghammer/MachineLearning/blob/main/object_detection/object_localization/readme_resources/result_1.png) 8 | 9 | ![reslut 2](https://github.com/wilhelmberghammer/MachineLearning/blob/main/object_detection/object_localization/readme_resources/result_2.png) 10 | 11 | ![reslut 3](https://github.com/wilhelmberghammer/MachineLearning/blob/main/object_detection/object_localization/readme_resources/result_3.png) 12 | 13 | ![reslut 4](https://github.com/wilhelmberghammer/MachineLearning/blob/main/object_detection/object_localization/readme_resources/result_4.png) 14 | -------------------------------------------------------------------------------- /object_detection/object_localization/pet_localization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pot_location.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "machine_shape": "hm" 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "source": [ 20 | "# Per Localization\n", 21 | "Datset used: [The Oxford IIIT Pet Dataset](https://www.kaggle.com/devdgohil/the-oxfordiiit-pet-dataset)\n", 22 | "\n", 23 | "### Disclaimer:\n", 24 | "**This is just to implement the core concept of basic object localization and is not to be taken serious - this is just experimental and to learn it before going to object detection**\n" 25 | ], 26 | "cell_type": "markdown", 27 | "metadata": {} 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "id": "9samIqgWbqHj" 33 | }, 34 | "source": [ 35 | "import zipfile\n", 36 | "import os\n", 37 | "import numpy as np\n", 38 | "import csv\n", 39 | "import cv2\n", 40 | "import glob\n", 41 | "import xml.etree.ElementTree as ET\n", 42 | "import matplotlib.pyplot as plt\n", 43 | "import random\n", 44 | "import tqdm\n", 45 | "\n", 46 | "import torch\n", 47 | "import torch.nn as nn\n", 48 | "import torch.optim as optim" 49 | ], 50 | "execution_count": 1, 51 | "outputs": [] 52 | }, 53 | { 
54 | "cell_type": "code", 55 | "metadata": { 56 | "id": "rt1cIZKG3SNX" 57 | }, 58 | "source": [ 59 | "os.mkdir('data') \n", 60 | "\n", 61 | "with zipfile.ZipFile(\"drive/MyDrive/AI/Data/ImageData/oxford_iiit.zip\",\"r\") as zip_ref:\n", 62 | " zip_ref.extractall(\"./data\")" 63 | ], 64 | "execution_count": null, 65 | "outputs": [] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "metadata": { 70 | "id": "IUxXMMkyX5cb" 71 | }, 72 | "source": [ 73 | "def pad_image(img, IMG_SIZE):\n", 74 | " image = cv2.copyMakeBorder(img, 0, IMG_SIZE-img.shape[0], 0, IMG_SIZE-img.shape[1], cv2.BORDER_CONSTANT)\n", 75 | " return image" 76 | ], 77 | "execution_count": 2, 78 | "outputs": [] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "metadata": { 83 | "colab": { 84 | "base_uri": "https://localhost:8080/" 85 | }, 86 | "id": "LuL3ITTV4BH8", 87 | "outputId": "01c64691-f1a3-416a-8333-1829d2609213" 88 | }, 89 | "source": [ 90 | "IMG_SIZE = 550\n", 91 | "XMLS = \"./data/annotations/annotations/xmls\"\n", 92 | "\n", 93 | "training_data = []\n", 94 | "xml_files = glob.glob(\"{}/*xml\".format(XMLS))\n", 95 | "for i, xml_file in enumerate(xml_files):\n", 96 | " tree = ET.parse(xml_file)\n", 97 | "\n", 98 | " path = os.path.join('./data/images/images', tree.findtext(\"./filename\"))\n", 99 | " img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)\n", 100 | "\n", 101 | " if img.shape[0] < IMG_SIZE and img.shape[1] < IMG_SIZE:\n", 102 | " xmin = int(tree.findtext(\"./object/bndbox/xmin\"))\n", 103 | " ymin = int(tree.findtext(\"./object/bndbox/ymin\"))\n", 104 | " xmax = int(tree.findtext(\"./object/bndbox/xmax\"))\n", 105 | " ymax = int(tree.findtext(\"./object/bndbox/ymax\"))\n", 106 | "\n", 107 | " image = pad_image(img, IMG_SIZE)\n", 108 | " training_data.append([np.array(image), [xmin/IMG_SIZE, ymin/IMG_SIZE, xmax/IMG_SIZE, ymax/IMG_SIZE]])\n", 109 | "\n", 110 | "print('training_data length: ', len(training_data))" 111 | ], 112 | "execution_count": 3, 113 | "outputs": [ 114 | { 115 | "output_type": "stream", 116 | "text": [ 117 | "training_data length: 3634\n" 118 | ], 119 | "name": "stdout" 120 | } 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "metadata": { 126 | "id": "_EAUPCZWuglh" 127 | }, 128 | "source": [ 129 | "X = torch.Tensor([i[0] for i in training_data]).view(-1, IMG_SIZE, IMG_SIZE)\n", 130 | "X = X/255.0\n", 131 | "y = torch.Tensor([i[1] for i in training_data])" 132 | ], 133 | "execution_count": 4, 134 | "outputs": [] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "metadata": { 139 | "colab": { 140 | "base_uri": "https://localhost:8080/" 141 | }, 142 | "id": "erULjdDlneMp", 143 | "outputId": "63b19af7-6015-45d2-b44e-8b3f49adf296" 144 | }, 145 | "source": [ 146 | "class Net(nn.Module):\n", 147 | " def __init__(self, in_channels, n_classes):\n", 148 | " super(Net, self).__init__()\n", 149 | " self.conv = nn.Sequential(\n", 150 | " nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=5),\n", 151 | " nn.ReLU(),\n", 152 | " nn.BatchNorm2d(8),\n", 153 | "\n", 154 | " nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5),\n", 155 | " nn.ReLU(),\n", 156 | " nn.BatchNorm2d(8),\n", 157 | "\n", 158 | " nn.MaxPool2d(2),\n", 159 | "\n", 160 | " nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5),\n", 161 | " nn.ReLU(),\n", 162 | " nn.BatchNorm2d(16),\n", 163 | "\n", 164 | " nn.MaxPool2d(2),\n", 165 | "\n", 166 | " nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5),\n", 167 | " nn.ReLU(),\n", 168 | " nn.BatchNorm2d(16),\n", 169 | "\n", 170 | " nn.MaxPool2d(2),\n", 171 | "\n", 172 | " )\n", 
173 | " self.fc = nn.Sequential(\n", 174 | " nn.Linear(16*64*64, 4096),\n", 175 | " nn.ReLU(),\n", 176 | " nn.Dropout(.5),\n", 177 | " nn.Linear(4096, 1024),\n", 178 | " nn.ReLU(),\n", 179 | " nn.Dropout(.2),\n", 180 | " nn.Linear(1024, n_classes),\n", 181 | " nn.Sigmoid()\n", 182 | " )\n", 183 | " \n", 184 | " def forward(self, x):\n", 185 | " x = self.conv(x)\n", 186 | " x = x.view(-1, 16*64*64)\n", 187 | " x = self.fc(x)\n", 188 | " return x\n", 189 | "\n", 190 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 191 | "model = Net(1, 4).to(device)\n", 192 | "print(model)" 193 | ], 194 | "execution_count": 5, 195 | "outputs": [ 196 | { 197 | "output_type": "stream", 198 | "text": [ 199 | "Net(\n", 200 | " (conv): Sequential(\n", 201 | " (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))\n", 202 | " (1): ReLU()\n", 203 | " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 204 | " (3): Conv2d(8, 8, kernel_size=(5, 5), stride=(1, 1))\n", 205 | " (4): ReLU()\n", 206 | " (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 207 | " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 208 | " (7): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1))\n", 209 | " (8): ReLU()\n", 210 | " (9): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 211 | " (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 212 | " (11): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1))\n", 213 | " (12): ReLU()\n", 214 | " (13): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 215 | " (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 216 | " )\n", 217 | " (fc): Sequential(\n", 218 | " (0): Linear(in_features=65536, out_features=4096, bias=True)\n", 219 | " (1): ReLU()\n", 220 | " (2): Dropout(p=0.5, inplace=False)\n", 221 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 222 | " (4): ReLU()\n", 223 | " (5): Dropout(p=0.2, inplace=False)\n", 224 | " (6): Linear(in_features=1024, out_features=4, bias=True)\n", 225 | " (7): Sigmoid()\n", 226 | " )\n", 227 | ")\n" 228 | ], 229 | "name": "stdout" 230 | } 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "metadata": { 236 | "colab": { 237 | "base_uri": "https://localhost:8080/" 238 | }, 239 | "id": "EuWFNtGGv9Dg", 240 | "outputId": "d18547ff-ecb9-423b-9ac1-8b70c05972f4" 241 | }, 242 | "source": [ 243 | "# to double check dimensions\n", 244 | "model.conv(X[:1].view(-1, 1, 550, 550).to(device)).size()" 245 | ], 246 | "execution_count": 15, 247 | "outputs": [ 248 | { 249 | "output_type": "execute_result", 250 | "data": { 251 | "text/plain": [ 252 | "torch.Size([1, 16, 64, 64])" 253 | ] 254 | }, 255 | "metadata": { 256 | "tags": [] 257 | }, 258 | "execution_count": 15 259 | } 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "metadata": { 265 | "colab": { 266 | "base_uri": "https://localhost:8080/" 267 | }, 268 | "id": "qg9ILtQ_1Lcd", 269 | "outputId": "d37d12a2-9e12-4005-c1c6-544d3a1f2615" 270 | }, 271 | "source": [ 272 | "BATCH_SIZE = 4\n", 273 | "EPOCHS = 10\n", 274 | "model.train()\n", 275 | "optimizer = optim.SGD(model.parameters(), lr=0.01)\n", 276 | "loss_function = nn.MSELoss()\n", 277 | "\n", 278 | "def train(model):\n", 279 | " for epoch in range(EPOCHS):\n", 280 | " for i in tqdm.tqdm(range(0, len(X), BATCH_SIZE)):\n", 281 | " batch_X = X[i:i+BATCH_SIZE].view(-1, 1, 550, 
550).to(device)\n", 282 | " batch_y = y[i:i+BATCH_SIZE].to(device)\n", 283 | "\n", 284 | " model.zero_grad()\n", 285 | "\n", 286 | " outputs = model(batch_X)\n", 287 | " loss = loss_function(outputs, batch_y)\n", 288 | " loss.backward()\n", 289 | " optimizer.step() \n", 290 | "\n", 291 | " print(f\"\\nEpoch: {epoch}. Loss: {loss}\")\n", 292 | "\n", 293 | "\n", 294 | "train(model)" 295 | ], 296 | "execution_count": 7, 297 | "outputs": [ 298 | { 299 | "output_type": "stream", 300 | "text": [ 301 | "100%|██████████| 909/909 [00:30<00:00, 29.67it/s]\n", 302 | " 0%| | 4/909 [00:00<00:24, 37.05it/s]" 303 | ], 304 | "name": "stderr" 305 | }, 306 | { 307 | "output_type": "stream", 308 | "text": [ 309 | "\n", 310 | "Epoch: 0. Loss: 0.005115551874041557\n" 311 | ], 312 | "name": "stdout" 313 | }, 314 | { 315 | "output_type": "stream", 316 | "text": [ 317 | "100%|██████████| 909/909 [00:30<00:00, 29.71it/s]\n", 318 | " 0%| | 4/909 [00:00<00:24, 37.57it/s]" 319 | ], 320 | "name": "stderr" 321 | }, 322 | { 323 | "output_type": "stream", 324 | "text": [ 325 | "\n", 326 | "Epoch: 1. Loss: 0.00565384142100811\n" 327 | ], 328 | "name": "stdout" 329 | }, 330 | { 331 | "output_type": "stream", 332 | "text": [ 333 | "100%|██████████| 909/909 [00:30<00:00, 29.68it/s]\n", 334 | " 0%| | 4/909 [00:00<00:24, 37.51it/s]" 335 | ], 336 | "name": "stderr" 337 | }, 338 | { 339 | "output_type": "stream", 340 | "text": [ 341 | "\n", 342 | "Epoch: 2. Loss: 0.0017478576628491282\n" 343 | ], 344 | "name": "stdout" 345 | }, 346 | { 347 | "output_type": "stream", 348 | "text": [ 349 | "100%|██████████| 909/909 [00:30<00:00, 29.74it/s]\n", 350 | " 0%| | 4/909 [00:00<00:24, 37.59it/s]" 351 | ], 352 | "name": "stderr" 353 | }, 354 | { 355 | "output_type": "stream", 356 | "text": [ 357 | "\n", 358 | "Epoch: 3. Loss: 0.00245093647390604\n" 359 | ], 360 | "name": "stdout" 361 | }, 362 | { 363 | "output_type": "stream", 364 | "text": [ 365 | "100%|██████████| 909/909 [00:30<00:00, 29.72it/s]\n", 366 | " 0%| | 4/909 [00:00<00:24, 37.33it/s]" 367 | ], 368 | "name": "stderr" 369 | }, 370 | { 371 | "output_type": "stream", 372 | "text": [ 373 | "\n", 374 | "Epoch: 4. Loss: 0.001998061779886484\n" 375 | ], 376 | "name": "stdout" 377 | }, 378 | { 379 | "output_type": "stream", 380 | "text": [ 381 | "100%|██████████| 909/909 [00:30<00:00, 29.71it/s]\n", 382 | " 0%| | 4/909 [00:00<00:24, 37.64it/s]" 383 | ], 384 | "name": "stderr" 385 | }, 386 | { 387 | "output_type": "stream", 388 | "text": [ 389 | "\n", 390 | "Epoch: 5. Loss: 0.0021302467212080956\n" 391 | ], 392 | "name": "stdout" 393 | }, 394 | { 395 | "output_type": "stream", 396 | "text": [ 397 | "100%|██████████| 909/909 [00:30<00:00, 29.72it/s]\n", 398 | " 0%| | 4/909 [00:00<00:24, 37.52it/s]" 399 | ], 400 | "name": "stderr" 401 | }, 402 | { 403 | "output_type": "stream", 404 | "text": [ 405 | "\n", 406 | "Epoch: 6. Loss: 0.0025986104737967253\n" 407 | ], 408 | "name": "stdout" 409 | }, 410 | { 411 | "output_type": "stream", 412 | "text": [ 413 | "100%|██████████| 909/909 [00:30<00:00, 29.71it/s]\n", 414 | " 0%| | 4/909 [00:00<00:24, 37.28it/s]" 415 | ], 416 | "name": "stderr" 417 | }, 418 | { 419 | "output_type": "stream", 420 | "text": [ 421 | "\n", 422 | "Epoch: 7. 
Loss: 0.0031612785533070564\n" 423 | ], 424 | "name": "stdout" 425 | }, 426 | { 427 | "output_type": "stream", 428 | "text": [ 429 | "100%|██████████| 909/909 [00:30<00:00, 29.70it/s]\n", 430 | " 0%| | 4/909 [00:00<00:24, 37.63it/s]" 431 | ], 432 | "name": "stderr" 433 | }, 434 | { 435 | "output_type": "stream", 436 | "text": [ 437 | "\n", 438 | "Epoch: 8. Loss: 0.0011636980343610048\n" 439 | ], 440 | "name": "stdout" 441 | }, 442 | { 443 | "output_type": "stream", 444 | "text": [ 445 | "100%|██████████| 909/909 [00:30<00:00, 29.67it/s]" 446 | ], 447 | "name": "stderr" 448 | }, 449 | { 450 | "output_type": "stream", 451 | "text": [ 452 | "\n", 453 | "Epoch: 9. Loss: 0.0014444717671722174\n" 454 | ], 455 | "name": "stdout" 456 | }, 457 | { 458 | "output_type": "stream", 459 | "text": [ 460 | "\n" 461 | ], 462 | "name": "stderr" 463 | } 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "metadata": { 469 | "id": "LTgBvy4CDLmE" 470 | }, 471 | "source": [ 472 | "e = 1\n", 473 | "\n", 474 | "label = model(X[e].view(-1, 1, 550, 550).to(device))\n", 475 | "\n", 476 | "xmin = label[0][0].item()\n", 477 | "ymin = label[0][1].item()\n", 478 | "xmax = label[0][2].item()\n", 479 | "ymax = label[0][3].item()\n", 480 | "\n", 481 | "img = training_data[e][0]\n", 482 | "bnd_img = cv2.rectangle(img, (int(xmin*IMG_SIZE), int(ymin*IMG_SIZE)), (int(xmax*IMG_SIZE), int(ymax*IMG_SIZE)), (255, 0, 0), 2)\n", 483 | "plt.imshow(bnd_img, cmap='gray')" 484 | ], 485 | "execution_count": null, 486 | "outputs": [] 487 | } 488 | ] 489 | } -------------------------------------------------------------------------------- /object_detection/object_localization/readme_resources/result_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/object_detection/object_localization/readme_resources/result_1.png -------------------------------------------------------------------------------- /object_detection/object_localization/readme_resources/result_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/object_detection/object_localization/readme_resources/result_2.png -------------------------------------------------------------------------------- /object_detection/object_localization/readme_resources/result_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/object_detection/object_localization/readme_resources/result_3.png -------------------------------------------------------------------------------- /object_detection/object_localization/readme_resources/result_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/object_detection/object_localization/readme_resources/result_4.png -------------------------------------------------------------------------------- /object_detection/yolo/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # this is just an implementation of the CNN model in the paper (-> Figure 3: The Architecture) 5 | class YOLO(nn.Module): 6 | def 
__init__(self, in_channels, n_anchorboxes, n_classes): 7 | super(YOLO, self).__init__() 8 | self.conv_layers = nn.Sequential( 9 | nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False), 10 | nn.LeakyReLU(.1), 11 | nn.MaxPool2d(kernel_size=2, stride=2), 12 | 13 | nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, stride=1, padding=1, bias=False), 14 | nn.LeakyReLU(.1), 15 | nn.MaxPool2d(kernel_size=2, stride=2), 16 | 17 | nn.Conv2d(in_channels=192, out_channels=128, kernel_size=1, stride=1, padding=0, bias=False), 18 | nn.LeakyReLU(.1), 19 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False), 20 | nn.LeakyReLU(.1), 21 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False), 22 | nn.LeakyReLU(.1), 23 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False), 24 | nn.LeakyReLU(.1), 25 | nn.MaxPool2d(kernel_size=2, stride=2), 26 | 27 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False), 28 | nn.LeakyReLU(.1), 29 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False), 30 | nn.LeakyReLU(.1), 31 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False), 32 | nn.LeakyReLU(.1), 33 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False), 34 | nn.LeakyReLU(.1), 35 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False), 36 | nn.LeakyReLU(.1), 37 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False), 38 | nn.LeakyReLU(.1), 39 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False), 40 | nn.LeakyReLU(.1), 41 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False), 42 | nn.LeakyReLU(.1), 43 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0, bias=False), 44 | nn.LeakyReLU(.1), 45 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False), 46 | nn.LeakyReLU(.1), 47 | nn.MaxPool2d(kernel_size=2, stride=2), 48 | 49 | nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, bias=False), 50 | nn.LeakyReLU(.1), 51 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False), 52 | nn.LeakyReLU(.1), 53 | nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, bias=False), 54 | nn.LeakyReLU(.1), 55 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False), 56 | nn.LeakyReLU(.1), 57 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False), 58 | nn.LeakyReLU(.1), 59 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=2, padding=1, bias=False), 60 | nn.LeakyReLU(.1), 61 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False), 62 | nn.LeakyReLU(.1), 63 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False), 64 | nn.LeakyReLU(.1), 65 | ) 66 | 67 | self.fc = nn.Sequential( 68 | nn.Flatten(), 69 | nn.Linear(1024*7*7, 4096), 70 | nn.Dropout(.5), 71 | nn.LeakyReLU(.1), 72 | nn.Linear(4096, 7*7*(n_anchorboxes*5 + n_classes)) 73 | ) 74 | 75 | def forward(self, x): 76 | x = self.conv_layers(x) 77 | x = self.fc(x) 78 | 
return x -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "cycler" 3 | version = "0.10.0" 4 | description = "Composable style cycles" 5 | category = "main" 6 | optional = false 7 | python-versions = "*" 8 | 9 | [package.dependencies] 10 | six = "*" 11 | 12 | [[package]] 13 | name = "kiwisolver" 14 | version = "1.3.1" 15 | description = "A fast implementation of the Cassowary constraint solver" 16 | category = "main" 17 | optional = false 18 | python-versions = ">=3.6" 19 | 20 | [[package]] 21 | name = "matplotlib" 22 | version = "3.3.4" 23 | description = "Python plotting package" 24 | category = "main" 25 | optional = false 26 | python-versions = ">=3.6" 27 | 28 | [package.dependencies] 29 | cycler = ">=0.10" 30 | kiwisolver = ">=1.0.1" 31 | numpy = ">=1.15" 32 | pillow = ">=6.2.0" 33 | pyparsing = ">=2.0.3,<2.0.4 || >2.0.4,<2.1.2 || >2.1.2,<2.1.6 || >2.1.6" 34 | python-dateutil = ">=2.1" 35 | 36 | [[package]] 37 | name = "numpy" 38 | version = "1.20.1" 39 | description = "NumPy is the fundamental package for array computing with Python." 40 | category = "main" 41 | optional = false 42 | python-versions = ">=3.7" 43 | 44 | [[package]] 45 | name = "pillow" 46 | version = "8.1.0" 47 | description = "Python Imaging Library (Fork)" 48 | category = "main" 49 | optional = false 50 | python-versions = ">=3.6" 51 | 52 | [[package]] 53 | name = "pyparsing" 54 | version = "2.4.7" 55 | description = "Python parsing module" 56 | category = "main" 57 | optional = false 58 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 59 | 60 | [[package]] 61 | name = "python-dateutil" 62 | version = "2.8.1" 63 | description = "Extensions to the standard Python datetime module" 64 | category = "main" 65 | optional = false 66 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" 67 | 68 | [package.dependencies] 69 | six = ">=1.5" 70 | 71 | [[package]] 72 | name = "six" 73 | version = "1.15.0" 74 | description = "Python 2 and 3 compatibility utilities" 75 | category = "main" 76 | optional = false 77 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" 78 | 79 | [[package]] 80 | name = "torch" 81 | version = "1.7.1" 82 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 83 | category = "main" 84 | optional = false 85 | python-versions = ">=3.6.2" 86 | 87 | [package.dependencies] 88 | numpy = "*" 89 | typing-extensions = "*" 90 | 91 | [[package]] 92 | name = "torchvision" 93 | version = "0.8.2" 94 | description = "image and video datasets and models for torch deep learning" 95 | category = "main" 96 | optional = false 97 | python-versions = "*" 98 | 99 | [package.dependencies] 100 | numpy = "*" 101 | pillow = ">=4.1.1" 102 | torch = "1.7.1" 103 | 104 | [package.extras] 105 | scipy = ["scipy"] 106 | 107 | [[package]] 108 | name = "typing-extensions" 109 | version = "3.7.4.3" 110 | description = "Backported and Experimental Type Hints for Python 3.5+" 111 | category = "main" 112 | optional = false 113 | python-versions = "*" 114 | 115 | [metadata] 116 | lock-version = "1.1" 117 | python-versions = "^3.8" 118 | content-hash = "7f5f14d8e96047e0a4c039c8cdf67b2023797c8e2a39124797e9d1d08f1118c7" 119 | 120 | [metadata.files] 121 | cycler = [ 122 | {file = "cycler-0.10.0-py2.py3-none-any.whl", hash = "sha256:1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d"}, 123 | {file = "cycler-0.10.0.tar.gz", hash = 
"sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8"}, 124 | ] 125 | kiwisolver = [ 126 | {file = "kiwisolver-1.3.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:fd34fbbfbc40628200730bc1febe30631347103fc8d3d4fa012c21ab9c11eca9"}, 127 | {file = "kiwisolver-1.3.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:d3155d828dec1d43283bd24d3d3e0d9c7c350cdfcc0bd06c0ad1209c1bbc36d0"}, 128 | {file = "kiwisolver-1.3.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:5a7a7dbff17e66fac9142ae2ecafb719393aaee6a3768c9de2fd425c63b53e21"}, 129 | {file = "kiwisolver-1.3.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:f8d6f8db88049a699817fd9178782867bf22283e3813064302ac59f61d95be05"}, 130 | {file = "kiwisolver-1.3.1-cp36-cp36m-manylinux2014_ppc64le.whl", hash = "sha256:5f6ccd3dd0b9739edcf407514016108e2280769c73a85b9e59aa390046dbf08b"}, 131 | {file = "kiwisolver-1.3.1-cp36-cp36m-win32.whl", hash = "sha256:225e2e18f271e0ed8157d7f4518ffbf99b9450fca398d561eb5c4a87d0986dd9"}, 132 | {file = "kiwisolver-1.3.1-cp36-cp36m-win_amd64.whl", hash = "sha256:cf8b574c7b9aa060c62116d4181f3a1a4e821b2ec5cbfe3775809474113748d4"}, 133 | {file = "kiwisolver-1.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:232c9e11fd7ac3a470d65cd67e4359eee155ec57e822e5220322d7b2ac84fbf0"}, 134 | {file = "kiwisolver-1.3.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:b38694dcdac990a743aa654037ff1188c7a9801ac3ccc548d3341014bc5ca278"}, 135 | {file = "kiwisolver-1.3.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ca3820eb7f7faf7f0aa88de0e54681bddcb46e485beb844fcecbcd1c8bd01689"}, 136 | {file = "kiwisolver-1.3.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:c8fd0f1ae9d92b42854b2979024d7597685ce4ada367172ed7c09edf2cef9cb8"}, 137 | {file = "kiwisolver-1.3.1-cp37-cp37m-manylinux2014_ppc64le.whl", hash = "sha256:1e1bc12fb773a7b2ffdeb8380609f4f8064777877b2225dec3da711b421fda31"}, 138 | {file = "kiwisolver-1.3.1-cp37-cp37m-win32.whl", hash = "sha256:72c99e39d005b793fb7d3d4e660aed6b6281b502e8c1eaf8ee8346023c8e03bc"}, 139 | {file = "kiwisolver-1.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:8be8d84b7d4f2ba4ffff3665bcd0211318aa632395a1a41553250484a871d454"}, 140 | {file = "kiwisolver-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:31dfd2ac56edc0ff9ac295193eeaea1c0c923c0355bf948fbd99ed6018010b72"}, 141 | {file = "kiwisolver-1.3.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:563c649cfdef27d081c84e72a03b48ea9408c16657500c312575ae9d9f7bc1c3"}, 142 | {file = "kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:78751b33595f7f9511952e7e60ce858c6d64db2e062afb325985ddbd34b5c131"}, 143 | {file = "kiwisolver-1.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a357fd4f15ee49b4a98b44ec23a34a95f1e00292a139d6015c11f55774ef10de"}, 144 | {file = "kiwisolver-1.3.1-cp38-cp38-manylinux2014_ppc64le.whl", hash = "sha256:5989db3b3b34b76c09253deeaf7fbc2707616f130e166996606c284395da3f18"}, 145 | {file = "kiwisolver-1.3.1-cp38-cp38-win32.whl", hash = "sha256:c08e95114951dc2090c4a630c2385bef681cacf12636fb0241accdc6b303fd81"}, 146 | {file = "kiwisolver-1.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:44a62e24d9b01ba94ae7a4a6c3fb215dc4af1dde817e7498d901e229aaf50e4e"}, 147 | {file = "kiwisolver-1.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:50af681a36b2a1dee1d3c169ade9fdc59207d3c31e522519181e12f1b3ba7000"}, 148 | {file = "kiwisolver-1.3.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:a53d27d0c2a0ebd07e395e56a1fbdf75ffedc4a05943daf472af163413ce9598"}, 149 | {file = 
"kiwisolver-1.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:834ee27348c4aefc20b479335fd422a2c69db55f7d9ab61721ac8cd83eb78882"}, 150 | {file = "kiwisolver-1.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:5c3e6455341008a054cccee8c5d24481bcfe1acdbc9add30aa95798e95c65621"}, 151 | {file = "kiwisolver-1.3.1-cp39-cp39-manylinux2014_ppc64le.whl", hash = "sha256:acef3d59d47dd85ecf909c359d0fd2c81ed33bdff70216d3956b463e12c38a54"}, 152 | {file = "kiwisolver-1.3.1-cp39-cp39-win32.whl", hash = "sha256:c5518d51a0735b1e6cee1fdce66359f8d2b59c3ca85dc2b0813a8aa86818a030"}, 153 | {file = "kiwisolver-1.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:b9edd0110a77fc321ab090aaa1cfcaba1d8499850a12848b81be2222eab648f6"}, 154 | {file = "kiwisolver-1.3.1-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0cd53f403202159b44528498de18f9285b04482bab2a6fc3f5dd8dbb9352e30d"}, 155 | {file = "kiwisolver-1.3.1-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:33449715e0101e4d34f64990352bce4095c8bf13bed1b390773fc0a7295967b3"}, 156 | {file = "kiwisolver-1.3.1-pp36-pypy36_pp73-win32.whl", hash = "sha256:401a2e9afa8588589775fe34fc22d918ae839aaaf0c0e96441c0fdbce6d8ebe6"}, 157 | {file = "kiwisolver-1.3.1.tar.gz", hash = "sha256:950a199911a8d94683a6b10321f9345d5a3a8433ec58b217ace979e18f16e248"}, 158 | ] 159 | matplotlib = [ 160 | {file = "matplotlib-3.3.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:672960dd114e342b7c610bf32fb99d14227f29919894388b41553217457ba7ef"}, 161 | {file = "matplotlib-3.3.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:7c155437ae4fd366e2700e2716564d1787700687443de46bcb895fe0f84b761d"}, 162 | {file = "matplotlib-3.3.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:a17f0a10604fac7627ec82820439e7db611722e80c408a726cd00d8c974c2fb3"}, 163 | {file = "matplotlib-3.3.4-cp36-cp36m-win32.whl", hash = "sha256:215e2a30a2090221a9481db58b770ce56b8ef46f13224ae33afe221b14b24dc1"}, 164 | {file = "matplotlib-3.3.4-cp36-cp36m-win_amd64.whl", hash = "sha256:348e6032f666ffd151b323342f9278b16b95d4a75dfacae84a11d2829a7816ae"}, 165 | {file = "matplotlib-3.3.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:94bdd1d55c20e764d8aea9d471d2ae7a7b2c84445e0fa463f02e20f9730783e1"}, 166 | {file = "matplotlib-3.3.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:a1acb72f095f1d58ecc2538ed1b8bca0b57df313b13db36ed34b8cdf1868e674"}, 167 | {file = "matplotlib-3.3.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:46b1a60a04e6d884f0250d5cc8dc7bd21a9a96c584a7acdaab44698a44710bab"}, 168 | {file = "matplotlib-3.3.4-cp37-cp37m-win32.whl", hash = "sha256:ed4a9e6dcacba56b17a0a9ac22ae2c72a35b7f0ef0693aa68574f0b2df607a89"}, 169 | {file = "matplotlib-3.3.4-cp37-cp37m-win_amd64.whl", hash = "sha256:c24c05f645aef776e8b8931cb81e0f1632d229b42b6d216e30836e2e145a2b40"}, 170 | {file = "matplotlib-3.3.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7310e353a4a35477c7f032409966920197d7df3e757c7624fd842f3eeb307d3d"}, 171 | {file = "matplotlib-3.3.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:451cc89cb33d6652c509fc6b588dc51c41d7246afdcc29b8624e256b7663ed1f"}, 172 | {file = "matplotlib-3.3.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:3d2eb9c1cc254d0ffa90bc96fde4b6005d09c2228f99dfd493a4219c1af99644"}, 173 | {file = "matplotlib-3.3.4-cp38-cp38-win32.whl", hash = "sha256:e15fa23d844d54e7b3b7243afd53b7567ee71c721f592deb0727ee85e668f96a"}, 174 | {file = "matplotlib-3.3.4-cp38-cp38-win_amd64.whl", hash = "sha256:1de0bb6cbfe460725f0e97b88daa8643bcf9571c18ba90bb8e41432aaeca91d6"}, 175 | {file = 
"matplotlib-3.3.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f44149a0ef5b4991aaef12a93b8e8d66d6412e762745fea1faa61d98524e0ba9"}, 176 | {file = "matplotlib-3.3.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:746a1df55749629e26af7f977ea426817ca9370ad1569436608dc48d1069b87c"}, 177 | {file = "matplotlib-3.3.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:5f571b92a536206f7958f7cb2d367ff6c9a1fa8229dc35020006e4cdd1ca0acd"}, 178 | {file = "matplotlib-3.3.4-cp39-cp39-win32.whl", hash = "sha256:9265ae0fb35e29f9b8cc86c2ab0a2e3dcddc4dd9de4b85bf26c0f63fe5c1c2ca"}, 179 | {file = "matplotlib-3.3.4-cp39-cp39-win_amd64.whl", hash = "sha256:9a79e5dd7bb797aa611048f5b70588b23c5be05b63eefd8a0d152ac77c4243db"}, 180 | {file = "matplotlib-3.3.4-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1e850163579a8936eede29fad41e202b25923a0a8d5ffd08ce50fc0a97dcdc93"}, 181 | {file = "matplotlib-3.3.4-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:d738acfdfb65da34c91acbdb56abed46803db39af259b7f194dc96920360dbe4"}, 182 | {file = "matplotlib-3.3.4-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa49571d8030ad0b9ac39708ee77bd2a22f87815e12bdee52ecaffece9313ed8"}, 183 | {file = "matplotlib-3.3.4-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:cf3a7e54eff792f0815dbbe9b85df2f13d739289c93d346925554f71d484be78"}, 184 | {file = "matplotlib-3.3.4.tar.gz", hash = "sha256:3e477db76c22929e4c6876c44f88d790aacdf3c3f8f3a90cb1975c0bf37825b0"}, 185 | ] 186 | numpy = [ 187 | {file = "numpy-1.20.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ae61f02b84a0211abb56462a3b6cd1e7ec39d466d3160eb4e1da8bf6717cdbeb"}, 188 | {file = "numpy-1.20.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:65410c7f4398a0047eea5cca9b74009ea61178efd78d1be9847fac1d6716ec1e"}, 189 | {file = "numpy-1.20.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:2d7e27442599104ee08f4faed56bb87c55f8b10a5494ac2ead5c98a4b289e61f"}, 190 | {file = "numpy-1.20.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:4ed8e96dc146e12c1c5cdd6fb9fd0757f2ba66048bf94c5126b7efebd12d0090"}, 191 | {file = "numpy-1.20.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:ecb5b74c702358cdc21268ff4c37f7466357871f53a30e6f84c686952bef16a9"}, 192 | {file = "numpy-1.20.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b9410c0b6fed4a22554f072a86c361e417f0258838957b78bd063bde2c7f841f"}, 193 | {file = "numpy-1.20.1-cp37-cp37m-win32.whl", hash = "sha256:3d3087e24e354c18fb35c454026af3ed8997cfd4997765266897c68d724e4845"}, 194 | {file = "numpy-1.20.1-cp37-cp37m-win_amd64.whl", hash = "sha256:89f937b13b8dd17b0099c7c2e22066883c86ca1575a975f754babc8fbf8d69a9"}, 195 | {file = "numpy-1.20.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a1d7995d1023335e67fb070b2fae6f5968f5be3802b15ad6d79d81ecaa014fe0"}, 196 | {file = "numpy-1.20.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:60759ab15c94dd0e1ed88241fd4fa3312db4e91d2c8f5a2d4cf3863fad83d65b"}, 197 | {file = "numpy-1.20.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:125a0e10ddd99a874fd357bfa1b636cd58deb78ba4a30b5ddb09f645c3512e04"}, 198 | {file = "numpy-1.20.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:c26287dfc888cf1e65181f39ea75e11f42ffc4f4529e5bd19add57ad458996e2"}, 199 | {file = "numpy-1.20.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:7199109fa46277be503393be9250b983f325880766f847885607d9b13848f257"}, 200 | {file = "numpy-1.20.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:72251e43ac426ff98ea802a931922c79b8d7596480300eb9f1b1e45e0543571e"}, 201 | {file = 
"numpy-1.20.1-cp38-cp38-win32.whl", hash = "sha256:c91ec9569facd4757ade0888371eced2ecf49e7982ce5634cc2cf4e7331a4b14"}, 202 | {file = "numpy-1.20.1-cp38-cp38-win_amd64.whl", hash = "sha256:13adf545732bb23a796914fe5f891a12bd74cf3d2986eed7b7eba2941eea1590"}, 203 | {file = "numpy-1.20.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:104f5e90b143dbf298361a99ac1af4cf59131218a045ebf4ee5990b83cff5fab"}, 204 | {file = "numpy-1.20.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:89e5336f2bec0c726ac7e7cdae181b325a9c0ee24e604704ed830d241c5e47ff"}, 205 | {file = "numpy-1.20.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:032be656d89bbf786d743fee11d01ef318b0781281241997558fa7950028dd29"}, 206 | {file = "numpy-1.20.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:66b467adfcf628f66ea4ac6430ded0614f5cc06ba530d09571ea404789064adc"}, 207 | {file = "numpy-1.20.1-cp39-cp39-win32.whl", hash = "sha256:12e4ba5c6420917571f1a5becc9338abbde71dd811ce40b37ba62dec7b39af6d"}, 208 | {file = "numpy-1.20.1-cp39-cp39-win_amd64.whl", hash = "sha256:9c94cab5054bad82a70b2e77741271790304651d584e2cdfe2041488e753863b"}, 209 | {file = "numpy-1.20.1-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:9eb551d122fadca7774b97db8a112b77231dcccda8e91a5bc99e79890797175e"}, 210 | {file = "numpy-1.20.1.zip", hash = "sha256:3bc63486a870294683980d76ec1e3efc786295ae00128f9ea38e2c6e74d5a60a"}, 211 | ] 212 | pillow = [ 213 | {file = "Pillow-8.1.0-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:d355502dce85ade85a2511b40b4c61a128902f246504f7de29bbeec1ae27933a"}, 214 | {file = "Pillow-8.1.0-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:93a473b53cc6e0b3ce6bf51b1b95b7b1e7e6084be3a07e40f79b42e83503fbf2"}, 215 | {file = "Pillow-8.1.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:2353834b2c49b95e1313fb34edf18fca4d57446675d05298bb694bca4b194174"}, 216 | {file = "Pillow-8.1.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:1d208e670abfeb41b6143537a681299ef86e92d2a3dac299d3cd6830d5c7bded"}, 217 | {file = "Pillow-8.1.0-cp36-cp36m-win32.whl", hash = "sha256:dd9eef866c70d2cbbea1ae58134eaffda0d4bfea403025f4db6859724b18ab3d"}, 218 | {file = "Pillow-8.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:b09e10ec453de97f9a23a5aa5e30b334195e8d2ddd1ce76cc32e52ba63c8b31d"}, 219 | {file = "Pillow-8.1.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:b02a0b9f332086657852b1f7cb380f6a42403a6d9c42a4c34a561aa4530d5234"}, 220 | {file = "Pillow-8.1.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:ca20739e303254287138234485579b28cb0d524401f83d5129b5ff9d606cb0a8"}, 221 | {file = "Pillow-8.1.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:604815c55fd92e735f9738f65dabf4edc3e79f88541c221d292faec1904a4b17"}, 222 | {file = "Pillow-8.1.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cf6e33d92b1526190a1de904df21663c46a456758c0424e4f947ae9aa6088bf7"}, 223 | {file = "Pillow-8.1.0-cp37-cp37m-win32.whl", hash = "sha256:47c0d93ee9c8b181f353dbead6530b26980fe4f5485aa18be8f1fd3c3cbc685e"}, 224 | {file = "Pillow-8.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:96d4dc103d1a0fa6d47c6c55a47de5f5dafd5ef0114fa10c85a1fd8e0216284b"}, 225 | {file = "Pillow-8.1.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:7916cbc94f1c6b1301ac04510d0881b9e9feb20ae34094d3615a8a7c3db0dcc0"}, 226 | {file = "Pillow-8.1.0-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3de6b2ee4f78c6b3d89d184ade5d8fa68af0848f9b6b6da2b9ab7943ec46971a"}, 227 | {file = "Pillow-8.1.0-cp38-cp38-manylinux1_x86_64.whl", hash = 
"sha256:cdbbe7dff4a677fb555a54f9bc0450f2a21a93c5ba2b44e09e54fcb72d2bd13d"}, 228 | {file = "Pillow-8.1.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:f50e7a98b0453f39000619d845be8b06e611e56ee6e8186f7f60c3b1e2f0feae"}, 229 | {file = "Pillow-8.1.0-cp38-cp38-win32.whl", hash = "sha256:cb192176b477d49b0a327b2a5a4979552b7a58cd42037034316b8018ac3ebb59"}, 230 | {file = "Pillow-8.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:6c5275bd82711cd3dcd0af8ce0bb99113ae8911fc2952805f1d012de7d600a4c"}, 231 | {file = "Pillow-8.1.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:165c88bc9d8dba670110c689e3cc5c71dbe4bfb984ffa7cbebf1fac9554071d6"}, 232 | {file = "Pillow-8.1.0-cp39-cp39-manylinux1_i686.whl", hash = "sha256:5e2fe3bb2363b862671eba632537cd3a823847db4d98be95690b7e382f3d6378"}, 233 | {file = "Pillow-8.1.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:7612520e5e1a371d77e1d1ca3a3ee6227eef00d0a9cddb4ef7ecb0b7396eddf7"}, 234 | {file = "Pillow-8.1.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d673c4990acd016229a5c1c4ee8a9e6d8f481b27ade5fc3d95938697fa443ce0"}, 235 | {file = "Pillow-8.1.0-cp39-cp39-win32.whl", hash = "sha256:dc577f4cfdda354db3ae37a572428a90ffdbe4e51eda7849bf442fb803f09c9b"}, 236 | {file = "Pillow-8.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:22d070ca2e60c99929ef274cfced04294d2368193e935c5d6febfd8b601bf865"}, 237 | {file = "Pillow-8.1.0-pp36-pypy36_pp73-macosx_10_10_x86_64.whl", hash = "sha256:a3d3e086474ef12ef13d42e5f9b7bbf09d39cf6bd4940f982263d6954b13f6a9"}, 238 | {file = "Pillow-8.1.0-pp36-pypy36_pp73-manylinux2010_i686.whl", hash = "sha256:731ca5aabe9085160cf68b2dbef95fc1991015bc0a3a6ea46a371ab88f3d0913"}, 239 | {file = "Pillow-8.1.0-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:bba80df38cfc17f490ec651c73bb37cd896bc2400cfba27d078c2135223c1206"}, 240 | {file = "Pillow-8.1.0-pp37-pypy37_pp73-macosx_10_10_x86_64.whl", hash = "sha256:c3d911614b008e8a576b8e5303e3db29224b455d3d66d1b2848ba6ca83f9ece9"}, 241 | {file = "Pillow-8.1.0-pp37-pypy37_pp73-manylinux2010_i686.whl", hash = "sha256:39725acf2d2e9c17356e6835dccebe7a697db55f25a09207e38b835d5e1bc032"}, 242 | {file = "Pillow-8.1.0-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:81c3fa9a75d9f1afafdb916d5995633f319db09bd773cb56b8e39f1e98d90820"}, 243 | {file = "Pillow-8.1.0-pp37-pypy37_pp73-win32.whl", hash = "sha256:b6f00ad5ebe846cc91763b1d0c6d30a8042e02b2316e27b05de04fa6ec831ec5"}, 244 | {file = "Pillow-8.1.0.tar.gz", hash = "sha256:887668e792b7edbfb1d3c9d8b5d8c859269a0f0eba4dda562adb95500f60dbba"}, 245 | ] 246 | pyparsing = [ 247 | {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, 248 | {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, 249 | ] 250 | python-dateutil = [ 251 | {file = "python-dateutil-2.8.1.tar.gz", hash = "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c"}, 252 | {file = "python_dateutil-2.8.1-py2.py3-none-any.whl", hash = "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a"}, 253 | ] 254 | six = [ 255 | {file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"}, 256 | {file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"}, 257 | ] 258 | torch = [ 259 | {file = "torch-1.7.1-cp36-cp36m-manylinux1_x86_64.whl", hash = 
"sha256:422e64e98d0e100c360993819d0307e5d56e9517b26135808ad68984d577d75a"}, 260 | {file = "torch-1.7.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f0aaf657145533824b15f2fd8fde8f8c67fe6c6281088ef588091f03fad90243"}, 261 | {file = "torch-1.7.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:af464a6f4314a875035e0c4c2b07517599704b214634f4ed3ad2e748c5ef291f"}, 262 | {file = "torch-1.7.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5d76c255a41484c1d41a9ff570b9c9f36cb85df9428aa15a58ae16ac7cfc2ea6"}, 263 | {file = "torch-1.7.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d241c3f1c4d563e4ba86f84769c23e12606db167ee6f674eedff6d02901462e3"}, 264 | {file = "torch-1.7.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:de84b4166e3f7335eb868b51d3bbd909ec33828af27290b4171bce832a55be3c"}, 265 | {file = "torch-1.7.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:dd2fc6880c95e836960d86efbbc7f63d3287f2e1893c51d31f96dbfe02f0d73e"}, 266 | {file = "torch-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:e000b94be3aa58ad7f61e7d07cf379ea9366cf6c6874e68bd58ad0bdc537b3a7"}, 267 | {file = "torch-1.7.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:2e49cac969976be63117004ee00d0a3e3dd4ea662ad77383f671b8992825de1a"}, 268 | {file = "torch-1.7.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a3793dcceb12b1e2281290cca1277c5ce86ddfd5bf044f654285a4d69057aea7"}, 269 | {file = "torch-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:6652a767a0572ae0feb74ad128758e507afd3b8396b6e7f147e438ba8d4c6f63"}, 270 | {file = "torch-1.7.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:38d67f4fb189a92a977b2c0a38e4f6dd413e0bf55aa6d40004696df7e40a71ff"}, 271 | ] 272 | torchvision = [ 273 | {file = "torchvision-0.8.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:86fae370d222f76ad57c57c3bee03f78b8db727743bfb4c1559a3d395159cea8"}, 274 | {file = "torchvision-0.8.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:951239b5fcb911dbf78c1385d677f5f48c7a1b12859e3d3ec287562821b17cf2"}, 275 | {file = "torchvision-0.8.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24db8f4c3d812a032273f68563ad5dbd724f5bfbed523d0c6dce8cede26bb153"}, 276 | {file = "torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:b068f6bcbe91bdd34dda0a39e8a26392add45a3be82543f6dd523b76484fb56f"}, 277 | {file = "torchvision-0.8.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:afb76a66b9b0693f758a881a2bf333ed97e3c0c3f15a413c4f49d8dd8bd21307"}, 278 | {file = "torchvision-0.8.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd8817e9197fc60ebae37162a445db90bbf35591314a5767ad3d1490b5d65b0f"}, 279 | {file = "torchvision-0.8.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1bd58acc3366ec02266aae56a7a752d43ef07de4a6ba420c4f907d0c9168bb8c"}, 280 | {file = "torchvision-0.8.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:976750a49db2e23dc5a1ed0b5c31f7af51ed2702eee410ee09ef985c3a3e48cf"}, 281 | ] 282 | typing-extensions = [ 283 | {file = "typing_extensions-3.7.4.3-py2-none-any.whl", hash = "sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f"}, 284 | {file = "typing_extensions-3.7.4.3-py3-none-any.whl", hash = "sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918"}, 285 | {file = "typing_extensions-3.7.4.3.tar.gz", hash = "sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c"}, 286 | ] 287 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = 
"machinelearning" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["wilhelmberghammer"] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.8" 9 | Pillow = "^8.1.0" 10 | torch = "^1.7.1" 11 | torchvision = "^0.8.2" 12 | matplotlib = "^3.3.4" 13 | 14 | [tool.poetry.dev-dependencies] 15 | 16 | [build-system] 17 | requires = ["poetry-core>=1.0.0"] 18 | build-backend = "poetry.core.masonry.api" 19 | -------------------------------------------------------------------------------- /transfer_learning/README.md: -------------------------------------------------------------------------------- 1 | # Transfer Learning for computer vision 2 | 3 | I'll be using the ants and bees dataset like in the pytorch [tutorial on transfer learning](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) 4 | 5 | 6 | ## Results 7 | ```python 8 | batch_size = 8 9 | EPOCHS = 10 10 | 11 | pretrained_model = models.resnet18(pretrained=True) 12 | 13 | for param in pretrained_model.parameters(): 14 | param.requires_grad = False 15 | 16 | pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, 2) 17 | pretrained_model = pretrained_model.to(device) 18 | 19 | criterion = nn.CrossEntropyLoss() 20 | optimizer = optim.Adam(pretrained_model.fc.parameters(), lr=0.001) 21 | 22 | model = train(pretrained_model, EPOCHS, criterion, optimizer, data_loaders) 23 | ``` 24 | ``` 25 | cpu 26 | 27 | __________ 28 | EPOCH 0/9 29 | ---------- 30 | train Loss: 0.6054 Acc: 0.6721 31 | val Loss: 0.4359 Acc: 0.8039 32 | __________ 33 | EPOCH 1/9 34 | ---------- 35 | train Loss: 0.3995 Acc: 0.8648 36 | val Loss: 0.3478 Acc: 0.8431 37 | __________ 38 | EPOCH 2/9 39 | ---------- 40 | train Loss: 0.3758 Acc: 0.8279 41 | val Loss: 0.2956 Acc: 0.9085 42 | __________ 43 | EPOCH 3/9 44 | ---------- 45 | train Loss: 0.2948 Acc: 0.8770 46 | val Loss: 0.2717 Acc: 0.9216 47 | __________ 48 | EPOCH 4/9 49 | ---------- 50 | train Loss: 0.3318 Acc: 0.8607 51 | val Loss: 0.2493 Acc: 0.8954 52 | __________ 53 | EPOCH 5/9 54 | ---------- 55 | train Loss: 0.2259 Acc: 0.8975 56 | val Loss: 0.2445 Acc: 0.8954 57 | __________ 58 | EPOCH 6/9 59 | ---------- 60 | train Loss: 0.2454 Acc: 0.9057 61 | val Loss: 0.2424 Acc: 0.8954 62 | __________ 63 | EPOCH 7/9 64 | ---------- 65 | train Loss: 0.2141 Acc: 0.9303 66 | val Loss: 0.2759 Acc: 0.8627 67 | __________ 68 | EPOCH 8/9 69 | ---------- 70 | train Loss: 0.2456 Acc: 0.8934 71 | val Loss: 0.2233 Acc: 0.9216 72 | __________ 73 | EPOCH 9/9 74 | ---------- 75 | train Loss: 0.1821 Acc: 0.9385 76 | val Loss: 0.2317 Acc: 0.9085 77 | ``` 78 | -------------------------------------------------------------------------------- /transfer_learning/transfer_learn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torchvision.transforms as transforms 6 | import torchvision.models as models 7 | import torchvision.datasets as datasets 8 | 9 | 10 | def get_data(data_dir, batch_size, num_workers=2): 11 | data_transforms = { 12 | 'train': transforms.Compose([ 13 | transforms.Resize((224, 224)), 14 | transforms.RandomHorizontalFlip(), 15 | transforms.ToTensor(), 16 | # because of the resnet model 17 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 18 | ]), 19 | 'val': transforms.Compose([ 20 | transforms.Resize((224, 224)), 21 | transforms.ToTensor(), 22 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 23 | ]), 24 | } 25 | 26 | 
img_datasets = { 27 | phase: datasets.ImageFolder(os.path.join(data_dir, phase), data_transforms[phase]) for phase in ['train', 'val'] 28 | } 29 | 30 | data_loaders = { 31 | phase: torch.utils.data.DataLoader(img_datasets[phase], batch_size=batch_size, shuffle=True, num_workers=num_workers) for phase in ['train', 'val'] 32 | } 33 | 34 | class_names = img_datasets['train'].classes 35 | 36 | return data_loaders 37 | 38 | 39 | def train(model, EPOCHS, criterion, optimizer, data_loaders): 40 | for epoch in range(EPOCHS): 41 | print('_' * 10) 42 | print(f'EPOCH {epoch}/{EPOCHS - 1}') 43 | print('-' * 10) 44 | 45 | # each epoch has a training and a val phase ... never did that before (from the pytorch tutorial) but makes sense 46 | for phase in ['train', 'val']: 47 | if phase == 'train': 48 | model.train() 49 | else: 50 | model.eval() 51 | 52 | running_loss = .0 53 | running_correct = 0 54 | 55 | for x, labels in data_loaders[phase]: 56 | x = x.to(device) 57 | labels = labels.to(device) 58 | 59 | optimizer.zero_grad() 60 | 61 | with torch.set_grad_enabled(phase == 'train'): 62 | outputs = model(x) 63 | loss = criterion(outputs, labels) 64 | 65 | # value is not important (_); index is important (preds) 66 | _, preds = torch.max(outputs, 1) 67 | 68 | if phase == 'train': 69 | loss.backward() 70 | optimizer.step() 71 | 72 | running_loss += loss.item()*x.size(0) 73 | running_correct += torch.sum(preds==labels.data) 74 | 75 | epoch_loss = running_loss / len(data_loaders[phase].dataset) 76 | epoch_acc = running_correct.double() / len(data_loaders[phase].dataset) 77 | 78 | print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}') 79 | 80 | return model 81 | 82 | 83 | 84 | if __name__ == '__main__': 85 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 86 | print(device, '\n') 87 | 88 | # hyperparameters 89 | batch_size = 8 90 | EPOCHS = 10 91 | 92 | data_loaders = get_data(data_dir='./data/', batch_size=batch_size) 93 | 94 | pretrained_model = models.resnet18(pretrained=True) 95 | for param in pretrained_model.parameters(): 96 | param.requires_grad = False 97 | 98 | pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, 2) 99 | pretrained_model = pretrained_model.to(device) 100 | 101 | criterion = nn.CrossEntropyLoss() 102 | 103 | optimizer = optim.Adam(pretrained_model.fc.parameters(), lr=0.001) 104 | 105 | model = train(pretrained_model, EPOCHS, criterion, optimizer, data_loaders) 106 | 107 | --------------------------------------------------------------------------------
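
A minimal inference sketch for the model produced by `transfer_learning/transfer_learn.py`, assuming the fine-tuned weights were saved after training (e.g. with `torch.save(model.state_dict(), "resnet18_ants_bees.pth")`) and that images live under `./data/` in the ImageFolder layout used by `get_data()`. The weights filename and the example image path are placeholders, not files that exist in this repository.

```python
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the same architecture used in transfer_learn.py and load the saved weights.
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("resnet18_ants_bees.pth", map_location=device))
model = model.to(device).eval()

# Same preprocessing as the 'val' transforms in get_data().
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Illustrative path; any RGB image works.
img = Image.open("./data/val/bees/example.jpg").convert("RGB")
x = preprocess(img).unsqueeze(0).to(device)  # add a batch dimension

with torch.no_grad():
    pred = model(x).argmax(dim=1).item()

# ImageFolder assigns labels alphabetically, so 0 -> 'ants', 1 -> 'bees'.
print(["ants", "bees"][pred])
```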