├── .editorconfig
├── .gitignore
├── DCGAN_pytorch
│   ├── README.md
│   ├── recourses
│   │   ├── 1_epoch.png
│   │   └── paper_figure.png
│   └── src
│       ├── model.py
│       └── train.py
├── LICENSE
├── README.md
├── cnn_architectures
│   ├── README.md
│   ├── resources
│   │   └── vgg_architectures.png
│   └── vgg16.py
├── neural_style_transfer
│   ├── README.md
│   ├── content_images
│   │   ├── alisha.jpg
│   │   ├── los_angeles.jpg
│   │   ├── los_angeles_night.jpg
│   │   └── snow.jpg
│   ├── generated_images
│   │   ├── alisha_1.png
│   │   ├── alisha_2.png
│   │   ├── gen_1000.png
│   │   ├── generated_1000.png
│   │   ├── generated_1500.png
│   │   ├── generated_200.png
│   │   ├── generated_2000.png
│   │   ├── generated_600.png
│   │   └── snow_1.png
│   ├── nst.py
│   ├── readme_results
│   │   ├── result_1.png
│   │   └── result_2.png
│   └── style_images
│       ├── dog_abstract_style.jpg
│       └── starry_night_full.jpg
├── object_detection
│   ├── README.md
│   ├── object_localization
│   │   ├── README.md
│   │   ├── pet_localization.ipynb
│   │   └── readme_resources
│   │       ├── result_1.png
│   │       ├── result_2.png
│   │       ├── result_3.png
│   │       └── result_4.png
│   └── yolo
│       └── model.py
├── poetry.lock
├── pyproject.toml
└── transfer_learning
    ├── README.md
    └── transfer_learn.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | indent_style = tab
5 | indent_size = 4
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 |
132 | # data
133 | transfer_learning/data
134 | DCGAN_pytorch/src/data
135 |
136 | # DS_Store
137 | .DS_Store
138 | transfer_learning/.DS_Store
--------------------------------------------------------------------------------
/DCGAN_pytorch/README.md:
--------------------------------------------------------------------------------
1 | # DCGAN_pytorch
2 | DCGAN implementation of the original [paper](https://arxiv.org/abs/1511.06434) in PyTorch. I tried to stay as close to the [paper](https://arxiv.org/abs/1511.06434) as possible.
3 |
4 |
5 | Most of the values used in the implementation are straight from the paper.
6 | The implemented structure of the network is the one depicted in the paper as well:
7 |
8 | 
9 |
10 | *Figure 1 in the original [paper](https://arxiv.org/abs/1511.06434)*
11 |
12 |
13 |
14 | ### Results on MNIST
15 | **After one Epoch:**
16 |
17 | 
18 |
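19 | ### Quick shape check
20 | A minimal sketch (run from inside `src/`, hyperparameter values taken from `src/train.py`) of what the two networks in `model.py` produce - not a training run, just the tensor shapes:
21 | 
22 | ```python
23 | import torch
24 | from model import Discriminator, Generator, init_weights
25 | 
26 | z_dim, img_channels, features = 100, 1, 64   # values used in src/train.py
27 | 
28 | gen = Generator(z_dim, img_channels, features)
29 | dis = Discriminator(img_channels, features)
30 | init_weights(gen)
31 | init_weights(dis)
32 | 
33 | z = torch.randn(16, z_dim, 1, 1)   # a batch of noise vectors
34 | fake = gen(z)                      # -> (16, 1, 64, 64) images in [-1, 1]
35 | score = dis(fake)                  # -> (16, 1, 1, 1) real/fake probabilities
36 | print(fake.shape, score.shape)
37 | ```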
--------------------------------------------------------------------------------
/DCGAN_pytorch/recourses/1_epoch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/DCGAN_pytorch/recourses/1_epoch.png
--------------------------------------------------------------------------------
/DCGAN_pytorch/recourses/paper_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/DCGAN_pytorch/recourses/paper_figure.png
--------------------------------------------------------------------------------
/DCGAN_pytorch/src/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 |
6 | # Discriminator (-> paper)
7 | class Discriminator(nn.Module):
8 | def __init__(self, img_channels, features_d):
9 | super(Discriminator, self).__init__()
10 | self.dis = nn.Sequential(
11 |
12 | nn.Conv2d(img_channels, features_d, kernel_size=4, stride=2, padding=1),
13 | nn.LeakyReLU(.2),
14 |
15 | nn.Conv2d(features_d, features_d*2, kernel_size=4, stride=2, padding=1, bias=False),
16 | nn.BatchNorm2d(features_d*2),
17 | nn.LeakyReLU(.2),
18 |
19 | nn.Conv2d(features_d*2, features_d*4, kernel_size=4, stride=2, padding=1, bias=False),
20 | nn.BatchNorm2d(features_d*4),
21 | nn.LeakyReLU(.2),
22 |
23 | nn.Conv2d(features_d*4, features_d*8, kernel_size=4, stride=2, padding=1, bias=False),
24 | nn.BatchNorm2d(features_d*8),
25 | nn.LeakyReLU(.2),
26 |
27 | nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0),
28 | nn.Sigmoid(),
29 | )
30 |
31 | def forward(self, x):
32 | return self.dis(x)
33 |
34 |
35 |
36 | # Generator (-> paper)
37 | class Generator(nn.Module):
38 | def __init__(self, z_dim, img_channels, features_g):
39 | super(Generator, self).__init__()
40 | self.gen = nn.Sequential(
41 |
42 | nn.ConvTranspose2d(z_dim, features_g*16, kernel_size=4, stride=1, padding=0),
43 | nn.BatchNorm2d(features_g*16),
44 | nn.ReLU(),
45 |
46 | nn.ConvTranspose2d(features_g*16, features_g*8, kernel_size=4, stride=2, padding=1),
47 | nn.BatchNorm2d(features_g*8),
48 | nn.ReLU(),
49 |
50 | nn.ConvTranspose2d(features_g*8, features_g*4, kernel_size=4, stride=2, padding=1),
51 | nn.BatchNorm2d(features_g*4),
52 | nn.ReLU(),
53 |
54 | nn.ConvTranspose2d(features_g*4, features_g*2, kernel_size=4, stride=2, padding=1),
55 | nn.BatchNorm2d(features_g*2),
56 | nn.ReLU(),
57 |
58 | nn.ConvTranspose2d(features_g*2, img_channels, kernel_size=4, stride=2, padding=1),
59 | nn.Tanh(),
60 | )
61 |
62 | def forward(self, x):
63 | return self.gen(x)
64 |
65 |
66 |
67 | def init_weights(model):
68 | for i in model.modules():
69 | if isinstance(i, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
70 | nn.init.normal_(i.weight.data, .0, .02)
71 |
72 |
--------------------------------------------------------------------------------
/DCGAN_pytorch/src/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | from torch.utils.data import DataLoader
7 |
8 | import matplotlib.pyplot as plt
9 |
10 | from model import Discriminator, Generator, init_weights
11 |
12 |
13 | # cuda if GPU is available, otherwise train on cpu
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | print(f"Using {device} for training!")
16 |
17 | # hyperparameters (-> paper)
18 | img_size = 64
19 | img_channels = 1 # 3 for rgb
20 | z_dim = 100 # noise dim
21 | feature_dis = 64
22 | feature_gen = 64
23 |
24 | EPOCHS = 1
25 | batch_size = 64
26 | learning_rate = 2e-4
27 |
28 |
29 | # Transforms
30 | transforms = transforms.Compose([
31 | transforms.Resize(img_size),
32 | transforms.ToTensor(),
33 | transforms.Normalize([.5], [.5]), # needs to be changed when using more than 1 channel!
34 | ])
35 |
36 | # download and load dataset - in this case I'll be using MNIST
37 | dataset = torchvision.datasets.MNIST(root='data/', train=True, download=True, transform=transforms)
38 | # create dataloader
39 | data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
40 |
41 | # init Generator and Discriminator
42 | gen = Generator(z_dim, img_channels, feature_gen).to(device)
43 | dis = Discriminator(img_channels, feature_dis).to(device)
44 |
45 | # init weights
46 | init_weights(gen)
47 | init_weights(dis)
48 |
49 | # optimizers (-> paper)
50 | gen_optim = optim.Adam(gen.parameters(), lr=learning_rate, betas=(.5, .999))
51 | dis_optim = optim.Adam(dis.parameters(), lr=learning_rate, betas=(.5, .999))
52 |
53 | # loss function
54 | loss = nn.BCELoss()
55 |
56 | fixed_z = torch.randn((32, z_dim, 1, 1)).to(device) # fixed noise for the progress grid printed after every epoch
57 | # training loop
58 | print('starting training loop...')
59 | for epoch in range(EPOCHS):
60 | for batch, (real_img, _) in enumerate(data_loader):
61 | real_img = real_img.to(device)
62 | z = torch.randn((batch_size, z_dim, 1, 1)).to(device) # noise
63 | fake_img = gen(z)
64 |
65 | # Discriminator training
66 | dis_real_img = dis(real_img).reshape(-1)
67 | loss_dis_real_img = loss(dis_real_img, torch.ones_like(dis_real_img))
68 |
69 | dis_fake_img = dis(fake_img).reshape(-1)
70 | loss_dis_fake_img = loss(dis_fake_img, torch.zeros_like(dis_fake_img))
71 |
72 | loss_dis = (loss_dis_real_img + loss_dis_fake_img)/2
73 | dis.zero_grad()
74 | loss_dis.backward(retain_graph=True) # retain the graph because fake_img is reused below for the generator update
75 | dis_optim.step()
76 |
77 |
78 | # Generator training
79 | out = dis(fake_img).reshape(-1)
80 | loss_gen = loss(out, torch.ones_like(out))
81 | gen.zero_grad()
82 | loss_gen.backward()
83 | gen_optim.step()
84 |
85 | print(f"{epoch}/{EPOCHS}")
86 |
87 | # print a grid after every epoch - tensorboard would be better but that is a pain in colab so...
88 | with torch.no_grad():
89 | fake_img = gen(fixed_z)
90 | img_grid_fake = torchvision.utils.make_grid(fake_img[:32], normalize=True).cpu()
91 | plt.imshow(img_grid_fake.permute(1, 2, 0))
92 | plt.show()
93 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Wilhelm Berghammer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **Machine Learning**
2 | ### *Currently on hold due to other, more practical projects*
3 | ---
4 |
 5 | This is where I learn machine learning 🤷 This repo doesn't cover one specific topic or project - I work in here whenever I want to learn or try something.
6 |
 7 | I'll try to keep it organised, but I don't want it to be too nested because that becomes a pain to navigate. Instead, I'll try to keep the table of contents up to date.
8 |
 9 | **If I use a Jupyter notebook I'm experimenting with something and the code will probably be a mess ... so don't take those too seriously.**
10 |
11 | ## Table of Contents
12 | * [Transfer Learning for Computer Vision](https://github.com/wilhelmberghammer/MachineLearning/tree/main/transfer_learning)
13 |
14 | ### Paper implementations
15 | * [DCGAN paper implementation in PyTorch](https://github.com/wilhelmberghammer/MachineLearning/tree/main/DCGAN_pytorch)
16 | * [Neural Style Transfer - implementation in PyTorch](https://github.com/wilhelmberghammer/MachineLearning/tree/main/neural_style_transfer)
17 |
18 | ### Model/Architecture implementations
19 | * [VGG16](https://github.com/wilhelmberghammer/MachineLearning/tree/main/cnn_architectures/vgg16.py)
20 |
21 | ### Working on / On Hold
22 | * [Object detection](https://github.com/wilhelmberghammer/MachineLearning/tree/main/object_detection)
23 |
24 |
25 | ### Coming up (things I want to learn in the near future)
26 | * [Weights & Biases](https://wandb.ai/site)
27 |
28 |
--------------------------------------------------------------------------------
/cnn_architectures/README.md:
--------------------------------------------------------------------------------
1 | # CNN architecture implementations in PyTorch
2 |
 3 | **This is just for fun** and to get a better feeling for VGG models when using the pretrained ones ... I didn't train this model, I just implemented the architecture (I also added the training hyperparameters from the paper).
 4 | I didn't implement the training loop either because I didn't intend to train the thing, plus the training loop is not the main part of the paper.
5 |
6 |
7 | ## VGG
8 | Simple PyTorch implementation of the paper [VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION](https://arxiv.org/pdf/1409.1556.pdf)
9 |
10 | I made VGG16 (D) because it's the most popular and the architecture can be modified.
11 |
12 | 
13 |
14 | *👆 figure from the paper*
15 |
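16 | As a quick sanity check (a sketch, not part of the original file - assuming it is run from the `cnn_architectures` directory), the 224×224 RGB input from the paper should come out as a 1000-way score vector:
17 | 
18 | ```python
19 | import torch
20 | from vgg16 import VGG16   # note: importing also runs the module-level example at the bottom of vgg16.py
21 | 
22 | model = VGG16(input_channels=3, n_classes=1000)
23 | x = torch.randn(1, 3, 224, 224)   # one 224x224 RGB image, as in the paper
24 | print(model(x).shape)             # expected: torch.Size([1, 1000])
25 | ```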
--------------------------------------------------------------------------------
/cnn_architectures/resources/vgg_architectures.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/cnn_architectures/resources/vgg_architectures.png
--------------------------------------------------------------------------------
/cnn_architectures/vgg16.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 |
6 | class VGG16(nn.Module):
7 | '''
 8 | VGG16 (D) architecture according to the paper
9 | '''
10 | def __init__(self, input_channels, n_classes):
11 | super(VGG16, self).__init__()
12 | self.conv_layers = nn.Sequential(
13 | # There are way better ways to do this (iterate over a list for example) ... this is so I have a better overview
14 | nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=1, padding=1),
15 | nn.ReLU(),
16 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
17 | nn.ReLU(),
18 |
19 | nn.MaxPool2d(kernel_size=2, stride=2),
20 |
21 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
22 | nn.ReLU(),
23 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
24 | nn.ReLU(),
25 |
26 | nn.MaxPool2d(kernel_size=2, stride=2),
27 |
28 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
29 | nn.ReLU(),
30 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
31 | nn.ReLU(),
32 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
33 | nn.ReLU(),
34 |
35 | nn.MaxPool2d(kernel_size=2, stride=2),
36 |
37 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
38 | nn.ReLU(),
39 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
40 | nn.ReLU(),
41 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
42 | nn.ReLU(),
43 |
44 | nn.MaxPool2d(kernel_size=2, stride=2),
45 |
46 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
47 | nn.ReLU(),
48 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
49 | nn.ReLU(),
50 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
51 | nn.ReLU(),
52 |
53 | nn.MaxPool2d(kernel_size=2, stride=2),
54 | )
55 |
56 | self.fc_layers = nn.Sequential(
57 | nn.Linear(512*7*7, 4096), # calculation in paper
58 | nn.ReLU(),
59 | nn.Dropout(.5), # paper: dropout regularisation for the first two fully-connected layers (dropout ratio set to 0.5)
60 | nn.Linear(4096, 4096),
61 | nn.ReLU(),
62 | nn.Dropout(.5),
63 | nn.Linear(4096, n_classes),
64 | )
65 |
66 |
67 | def forward(self, x):
68 | x = self.conv_layers(x)
69 | x = x.reshape(x.shape[0], -1)
70 | x = self.fc_layers(x)
71 | return x
72 |
73 |
74 | # the following might differ slightly from the paper
75 | # I didn't do the weight initialisation like in the paper - they trained the shallower model and initialized the larger ones with those weights
76 | model = VGG16(input_channels=3, n_classes=1000)
77 |
78 | batch_size = 256
79 | loss = torch.nn.CrossEntropyLoss()
80 | optim = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
81 |
82 |
83 | # training-loop ...
--------------------------------------------------------------------------------
/neural_style_transfer/README.md:
--------------------------------------------------------------------------------
1 | # Neural Style Transfer
2 |
3 | I tried to follow the paper [A Neural Algorithm of Artistic Style](https://arxiv.org/pdf/1508.06576v2.pdf) but I did some things differently.
4 |
 5 | For example, I didn't use their more complicated loss function; I simply used MSE (see the loss sketch at the bottom of this README).
 6 | 
 7 | I also didn't use noise as the starting image; I used the content image.
 8 | 
 9 | This might be the reason for some of the differences. I didn't notice any difference when changing the alpha and beta values. The biggest change happens when changing the learning rate: a higher learning rate leads to more significant changes to the original content.
10 |
11 |
12 |
13 | ### Setup
14 | You will need a virtualenv with `python>=3.8`.
15 |
16 | ```bash
17 | pip install poetry
18 | poetry install
19 | ```
20 |
21 | ### To run style transfer
22 | ```bash
23 | poetry run python neural_style_transfer/nst.py
24 | ```
25 |
26 | ## Result Example
27 | **Example 1:**
28 |
29 | 
30 |
31 | **Example 2:**
32 |
33 | 
34 |
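35 | ## Loss sketch
36 | For reference, here is a simplified sketch (just an illustration, not a separate file) of the per-layer losses that `transfer_style` in `nst.py` computes for each of the five selected VGG19 feature maps:
37 | 
38 | ```python
39 | import torch
40 | 
41 | def layer_losses(gen_feat, content_feat, style_feat):
42 |     # content term: plain MSE between generated and content features
43 |     content_loss = torch.mean((gen_feat - content_feat) ** 2)
44 | 
45 |     # style term: squared difference of the Gram matrices (batch size is 1)
46 |     b, c, h, w = gen_feat.size()
47 |     G = torch.mm(gen_feat.view(b * c, h * w), gen_feat.view(b * c, h * w).t())
48 |     A = torch.mm(style_feat.view(b * c, h * w), style_feat.view(b * c, h * w).t())
49 |     style_loss = torch.mean((1 / 5) * (G - A) ** 2)  # w_l = 1/5 for the 5 layers
50 | 
51 |     return content_loss, style_loss
52 | ```
53 | 
54 | The total loss is then `alpha * content_loss + beta * style_loss`, summed over the five layers.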
--------------------------------------------------------------------------------
/neural_style_transfer/content_images/alisha.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/content_images/alisha.jpg
--------------------------------------------------------------------------------
/neural_style_transfer/content_images/los_angeles.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/content_images/los_angeles.jpg
--------------------------------------------------------------------------------
/neural_style_transfer/content_images/los_angeles_night.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/content_images/los_angeles_night.jpg
--------------------------------------------------------------------------------
/neural_style_transfer/content_images/snow.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/content_images/snow.jpg
--------------------------------------------------------------------------------
/neural_style_transfer/generated_images/alisha_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/alisha_1.png
--------------------------------------------------------------------------------
/neural_style_transfer/generated_images/alisha_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/alisha_2.png
--------------------------------------------------------------------------------
/neural_style_transfer/generated_images/gen_1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/gen_1000.png
--------------------------------------------------------------------------------
/neural_style_transfer/generated_images/generated_1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/generated_1000.png
--------------------------------------------------------------------------------
/neural_style_transfer/generated_images/generated_1500.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/generated_1500.png
--------------------------------------------------------------------------------
/neural_style_transfer/generated_images/generated_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/generated_200.png
--------------------------------------------------------------------------------
/neural_style_transfer/generated_images/generated_2000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/generated_2000.png
--------------------------------------------------------------------------------
/neural_style_transfer/generated_images/generated_600.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/generated_600.png
--------------------------------------------------------------------------------
/neural_style_transfer/generated_images/snow_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/generated_images/snow_1.png
--------------------------------------------------------------------------------
/neural_style_transfer/nst.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torchvision.transforms as transforms
6 | import torchvision.models as models
7 | import torchvision.utils as utils
8 |
9 | import matplotlib.pyplot as plt
10 |
11 |
12 | class VGG_19(nn.Module):
13 | def __init__(self):
14 | super(VGG_19, self).__init__()
15 | # model used: VGG19 (like in the paper)
16 | # everything after the 28th layer is technically not needed
17 | self.model = models.vgg19(pretrained=True).features[:30]
18 |
19 | # better results when changing the MaxPool layers to AvgPool (-> paper)
20 | for i, _ in enumerate(self.model):
 21 | # Indices of the MaxPool layers -> replaced by AvgPool with same parameters
22 | if i in [4, 9, 18, 27]:
23 | self.model[i] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
24 |
25 | def forward(self, x):
26 | features = []
27 |
28 | for i, layer in enumerate(self.model):
29 | x = layer(x)
 30 | # indices of the conv layers after the (now Avg) pool layers
31 | if i in [0, 5, 10, 19, 28]:
32 | features.append(x)
33 | return features
34 |
35 |
36 | def load_img(path_to_image, img_size):
37 | transform = transforms.Compose([
38 | transforms.Resize((img_size, img_size)),
39 | transforms.ToTensor(),
40 | ])
41 | img = Image.open(path_to_image)
42 | img = transform(img).unsqueeze(0)
43 | return img
44 |
45 |
46 | def transfer_style(iterations, optimizer, alpha, beta, generated_image, content_image, style_image, show_images=False):
47 | for iter in range(iterations+1):
48 | generated_features = model(generated_image)
49 | content_features = model(content_image)
50 | style_features = model(style_image)
51 |
52 | content_loss = 0
53 | style_loss = 0
54 |
55 | for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
56 | batch_size, n_feature_maps, height, width = generated_feature.size()
57 |
 58 | # in the paper it is 1/2*((g - c)**2) ... but it is easier this way because I don't have to worry about dimensions ... and it works as well
59 | content_loss += (torch.mean((generated_feature - content_feature) ** 2))
60 |
 61 | # batch_size is one ... so it isn't needed. I still included it for better understanding.
62 | G = torch.mm((generated_feature.view(batch_size*n_feature_maps, height*width)), (generated_feature.view(batch_size*n_feature_maps, height*width)).t())
63 | A = torch.mm((style_feature.view(batch_size*n_feature_maps, height*width)), (style_feature.view(batch_size*n_feature_maps, height*width)).t())
64 |
65 | # different in paper!!
66 | E_l = ((G - A)**2)
67 | # w_l ... one divided by the number of active layers with a non-zero loss-weight -> directly from the paper (technically isn't needed)
68 | w_l = 1/5
69 | style_loss += torch.mean(w_l*E_l)
70 |
71 | # I found little difference when changing the alpha and beta values ... still kept it in for better understanding of paper
72 | total_loss = alpha * content_loss + beta * style_loss
73 | optimizer.zero_grad()
74 | total_loss.backward()
75 | optimizer.step()
76 |
77 |
78 | if iter % 100 == 0:
79 | print('-'*15)
80 | print(f'\n{iter} \nTotal Loss: {total_loss.item()} \n Content Loss: {content_loss} \t Style Loss: {style_loss}')
81 | print('-'*15)
82 |
83 | # show image
84 | if show_images == True:
85 | plt.figure(figsize=(10, 10))
86 | plt.imshow(generated_image.permute(0, 2, 3, 1)[0].cpu().detach().numpy())
87 | plt.show()
88 |
89 | return generated_image
90 |
91 | #if iter % 500 == 0:
92 | #utils.save_image(generated, f'./gen_{iter}.png')
93 |
94 |
95 |
96 | if __name__ == '__main__':
97 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
98 |
99 | content_img = load_img('./content_images/alisha.jpg', 512).to(device)
100 | style_img = load_img('./style_images/dog_abstract_style.jpg', 512).to(device)
101 |
102 | model = VGG_19().to(device)
103 | # freeze parameters
104 | for param in model.parameters():
105 | param.requires_grad = False
106 |
107 | # generated image (init) is the content image ... could also be noise
108 | # requires_grad because the network itself is frozen ... the thing we are changing is this
109 | generated_init = content_img.clone().requires_grad_(True)
110 |
111 |
112 | iterations = 200
113 | # the real difference is visible when changing the learning rate ... 1e-2 is rather high -> heavy changes to the content image
114 | lr = 1e-2
115 | # I found no real difference when changing these values...this is why I keep them at 1
116 | alpha = 1
117 | beta = 1
118 |
119 | optimizer = optim.Adam([generated_init], lr=lr)
120 |
121 | generated_image = transfer_style(iterations=iterations,
122 | optimizer=optimizer,
123 | alpha=alpha,
124 | beta=beta,
125 | generated_image=generated_init,
126 | content_image=content_img,
127 | style_image=style_img,
128 | show_images=False # only true in jupyter notebook
129 | )
130 |
131 | utils.save_image(generated_image, './gen.png')
--------------------------------------------------------------------------------
/neural_style_transfer/readme_results/result_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/readme_results/result_1.png
--------------------------------------------------------------------------------
/neural_style_transfer/readme_results/result_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/readme_results/result_2.png
--------------------------------------------------------------------------------
/neural_style_transfer/style_images/dog_abstract_style.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/style_images/dog_abstract_style.jpg
--------------------------------------------------------------------------------
/neural_style_transfer/style_images/starry_night_full.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/neural_style_transfer/style_images/starry_night_full.jpg
--------------------------------------------------------------------------------
/object_detection/README.md:
--------------------------------------------------------------------------------
1 | # Object Detection
2 |
3 | After studying and playing around with CNNs I'll continue with computer vision and start with object localization and object detection.
4 |
5 | ### Roadmap
6 | * 1st step - learn about Object Localization (I will probably delete this after implementing it better for object detection)
7 | * learn core concepts of object localization
8 | * **just to experiment - the code is a mess**
9 | * test on [The Oxford-IIIT Pet Dataset](https://www.kaggle.com/devdgohil/the-oxfordiiit-pet-dataset)
10 |
11 |
12 | * **2nd step - Object Detection** *On Hold*
13 | * **this is the main part**
14 | * study [OverFeat: Integrated Recognition, Localization and Detection using Conv Nets Paper](https://arxiv.org/abs/1312.6229)
15 | * [YOLO paper](https://arxiv.org/abs/1506.02640) implementation
16 |
17 |
18 | ### Resources
19 | * [Deeplearning.ai YouTube Playlist](https://www.youtube.com/watch?v=GSwYGkTfOKk&list=PLkDaE6sCZn6Gl29AoE31iwdVwSG-KnDzF&index=23)
20 | * [OverFeat: Integrated Recognition, Localization and Detection using Conv Nets Paper](https://arxiv.org/abs/1312.6229)
21 | * [original YOLO paper](https://arxiv.org/abs/1506.02640)
22 |
--------------------------------------------------------------------------------
/object_detection/object_localization/README.md:
--------------------------------------------------------------------------------
1 | # Object Localization
2 | **This is just to experiment with the concept of object localization!**
3 |
 4 | Don't take the code too seriously
5 |
6 | ## Results:
7 | 
8 |
9 | 
10 |
11 | 
12 |
13 | 
14 |
--------------------------------------------------------------------------------
/object_detection/object_localization/pet_localization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "pot_location.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "machine_shape": "hm"
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "GPU"
16 | },
17 | "cells": [
18 | {
19 | "source": [
20 | "# Per Localization\n",
21 | "Datset used: [The Oxford IIIT Pet Dataset](https://www.kaggle.com/devdgohil/the-oxfordiiit-pet-dataset)\n",
22 | "\n",
23 | "### Disclaimer:\n",
24 | "**This is just to implement the core concept of basic object localization and is not to be taken serious - this is just experimental and to learn it before going to object detection**\n"
25 | ],
26 | "cell_type": "markdown",
27 | "metadata": {}
28 | },
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "id": "9samIqgWbqHj"
33 | },
34 | "source": [
35 | "import zipfile\n",
36 | "import os\n",
37 | "import numpy as np\n",
38 | "import csv\n",
39 | "import cv2\n",
40 | "import glob\n",
41 | "import xml.etree.ElementTree as ET\n",
42 | "import matplotlib.pyplot as plt\n",
43 | "import random\n",
44 | "import tqdm\n",
45 | "\n",
46 | "import torch\n",
47 | "import torch.nn as nn\n",
48 | "import torch.optim as optim"
49 | ],
50 | "execution_count": 1,
51 | "outputs": []
52 | },
53 | {
54 | "cell_type": "code",
55 | "metadata": {
56 | "id": "rt1cIZKG3SNX"
57 | },
58 | "source": [
59 | "os.mkdir('data') \n",
60 | "\n",
61 | "with zipfile.ZipFile(\"drive/MyDrive/AI/Data/ImageData/oxford_iiit.zip\",\"r\") as zip_ref:\n",
62 | " zip_ref.extractall(\"./data\")"
63 | ],
64 | "execution_count": null,
65 | "outputs": []
66 | },
67 | {
68 | "cell_type": "code",
69 | "metadata": {
70 | "id": "IUxXMMkyX5cb"
71 | },
72 | "source": [
73 | "def pad_image(img, IMG_SIZE):\n",
74 | " image = cv2.copyMakeBorder(img, 0, IMG_SIZE-img.shape[0], 0, IMG_SIZE-img.shape[1], cv2.BORDER_CONSTANT)\n",
75 | " return image"
76 | ],
77 | "execution_count": 2,
78 | "outputs": []
79 | },
80 | {
81 | "cell_type": "code",
82 | "metadata": {
83 | "colab": {
84 | "base_uri": "https://localhost:8080/"
85 | },
86 | "id": "LuL3ITTV4BH8",
87 | "outputId": "01c64691-f1a3-416a-8333-1829d2609213"
88 | },
89 | "source": [
90 | "IMG_SIZE = 550\n",
91 | "XMLS = \"./data/annotations/annotations/xmls\"\n",
92 | "\n",
93 | "training_data = []\n",
94 | "xml_files = glob.glob(\"{}/*xml\".format(XMLS))\n",
95 | "for i, xml_file in enumerate(xml_files):\n",
96 | " tree = ET.parse(xml_file)\n",
97 | "\n",
98 | " path = os.path.join('./data/images/images', tree.findtext(\"./filename\"))\n",
99 | " img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)\n",
100 | "\n",
101 | " if img.shape[0] < IMG_SIZE and img.shape[1] < IMG_SIZE:\n",
102 | " xmin = int(tree.findtext(\"./object/bndbox/xmin\"))\n",
103 | " ymin = int(tree.findtext(\"./object/bndbox/ymin\"))\n",
104 | " xmax = int(tree.findtext(\"./object/bndbox/xmax\"))\n",
105 | " ymax = int(tree.findtext(\"./object/bndbox/ymax\"))\n",
106 | "\n",
107 | " image = pad_image(img, IMG_SIZE)\n",
108 | " training_data.append([np.array(image), [xmin/IMG_SIZE, ymin/IMG_SIZE, xmax/IMG_SIZE, ymax/IMG_SIZE]])\n",
109 | "\n",
110 | "print('training_data length: ', len(training_data))"
111 | ],
112 | "execution_count": 3,
113 | "outputs": [
114 | {
115 | "output_type": "stream",
116 | "text": [
117 | "training_data length: 3634\n"
118 | ],
119 | "name": "stdout"
120 | }
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "metadata": {
126 | "id": "_EAUPCZWuglh"
127 | },
128 | "source": [
129 | "X = torch.Tensor([i[0] for i in training_data]).view(-1, IMG_SIZE, IMG_SIZE)\n",
130 | "X = X/255.0\n",
131 | "y = torch.Tensor([i[1] for i in training_data])"
132 | ],
133 | "execution_count": 4,
134 | "outputs": []
135 | },
136 | {
137 | "cell_type": "code",
138 | "metadata": {
139 | "colab": {
140 | "base_uri": "https://localhost:8080/"
141 | },
142 | "id": "erULjdDlneMp",
143 | "outputId": "63b19af7-6015-45d2-b44e-8b3f49adf296"
144 | },
145 | "source": [
146 | "class Net(nn.Module):\n",
147 | " def __init__(self, in_channels, n_classes):\n",
148 | " super(Net, self).__init__()\n",
149 | " self.conv = nn.Sequential(\n",
150 | " nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=5),\n",
151 | " nn.ReLU(),\n",
152 | " nn.BatchNorm2d(8),\n",
153 | "\n",
154 | " nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5),\n",
155 | " nn.ReLU(),\n",
156 | " nn.BatchNorm2d(8),\n",
157 | "\n",
158 | " nn.MaxPool2d(2),\n",
159 | "\n",
160 | " nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5),\n",
161 | " nn.ReLU(),\n",
162 | " nn.BatchNorm2d(16),\n",
163 | "\n",
164 | " nn.MaxPool2d(2),\n",
165 | "\n",
166 | " nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5),\n",
167 | " nn.ReLU(),\n",
168 | " nn.BatchNorm2d(16),\n",
169 | "\n",
170 | " nn.MaxPool2d(2),\n",
171 | "\n",
172 | " )\n",
173 | " self.fc = nn.Sequential(\n",
174 | " nn.Linear(16*64*64, 4096),\n",
175 | " nn.ReLU(),\n",
176 | " nn.Dropout(.5),\n",
177 | " nn.Linear(4096, 1024),\n",
178 | " nn.ReLU(),\n",
179 | " nn.Dropout(.2),\n",
180 | " nn.Linear(1024, n_classes),\n",
181 | " nn.Sigmoid()\n",
182 | " )\n",
183 | " \n",
184 | " def forward(self, x):\n",
185 | " x = self.conv(x)\n",
186 | " x = x.view(-1, 16*64*64)\n",
187 | " x = self.fc(x)\n",
188 | " return x\n",
189 | "\n",
190 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
191 | "model = Net(1, 4).to(device)\n",
192 | "print(model)"
193 | ],
194 | "execution_count": 5,
195 | "outputs": [
196 | {
197 | "output_type": "stream",
198 | "text": [
199 | "Net(\n",
200 | " (conv): Sequential(\n",
201 | " (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))\n",
202 | " (1): ReLU()\n",
203 | " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
204 | " (3): Conv2d(8, 8, kernel_size=(5, 5), stride=(1, 1))\n",
205 | " (4): ReLU()\n",
206 | " (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
207 | " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
208 | " (7): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1))\n",
209 | " (8): ReLU()\n",
210 | " (9): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
211 | " (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
212 | " (11): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1))\n",
213 | " (12): ReLU()\n",
214 | " (13): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
215 | " (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
216 | " )\n",
217 | " (fc): Sequential(\n",
218 | " (0): Linear(in_features=65536, out_features=4096, bias=True)\n",
219 | " (1): ReLU()\n",
220 | " (2): Dropout(p=0.5, inplace=False)\n",
221 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n",
222 | " (4): ReLU()\n",
223 | " (5): Dropout(p=0.2, inplace=False)\n",
224 | " (6): Linear(in_features=1024, out_features=4, bias=True)\n",
225 | " (7): Sigmoid()\n",
226 | " )\n",
227 | ")\n"
228 | ],
229 | "name": "stdout"
230 | }
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "metadata": {
236 | "colab": {
237 | "base_uri": "https://localhost:8080/"
238 | },
239 | "id": "EuWFNtGGv9Dg",
240 | "outputId": "d18547ff-ecb9-423b-9ac1-8b70c05972f4"
241 | },
242 | "source": [
243 | "# to double check dimensions\n",
244 | "model.conv(X[:1].view(-1, 1, 550, 550).to(device)).size()"
245 | ],
246 | "execution_count": 15,
247 | "outputs": [
248 | {
249 | "output_type": "execute_result",
250 | "data": {
251 | "text/plain": [
252 | "torch.Size([1, 16, 64, 64])"
253 | ]
254 | },
255 | "metadata": {
256 | "tags": []
257 | },
258 | "execution_count": 15
259 | }
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "metadata": {
265 | "colab": {
266 | "base_uri": "https://localhost:8080/"
267 | },
268 | "id": "qg9ILtQ_1Lcd",
269 | "outputId": "d37d12a2-9e12-4005-c1c6-544d3a1f2615"
270 | },
271 | "source": [
272 | "BATCH_SIZE = 4\n",
273 | "EPOCHS = 10\n",
274 | "model.train()\n",
275 | "optimizer = optim.SGD(model.parameters(), lr=0.01)\n",
276 | "loss_function = nn.MSELoss()\n",
277 | "\n",
278 | "def train(model):\n",
279 | " for epoch in range(EPOCHS):\n",
280 | " for i in tqdm.tqdm(range(0, len(X), BATCH_SIZE)):\n",
281 | " batch_X = X[i:i+BATCH_SIZE].view(-1, 1, 550, 550).to(device)\n",
282 | " batch_y = y[i:i+BATCH_SIZE].to(device)\n",
283 | "\n",
284 | " model.zero_grad()\n",
285 | "\n",
286 | " outputs = model(batch_X)\n",
287 | " loss = loss_function(outputs, batch_y)\n",
288 | " loss.backward()\n",
289 | " optimizer.step() \n",
290 | "\n",
291 | " print(f\"\\nEpoch: {epoch}. Loss: {loss}\")\n",
292 | "\n",
293 | "\n",
294 | "train(model)"
295 | ],
296 | "execution_count": 7,
297 | "outputs": [
298 | {
299 | "output_type": "stream",
300 | "text": [
301 | "100%|██████████| 909/909 [00:30<00:00, 29.67it/s]\n",
302 | " 0%| | 4/909 [00:00<00:24, 37.05it/s]"
303 | ],
304 | "name": "stderr"
305 | },
306 | {
307 | "output_type": "stream",
308 | "text": [
309 | "\n",
310 | "Epoch: 0. Loss: 0.005115551874041557\n"
311 | ],
312 | "name": "stdout"
313 | },
314 | {
315 | "output_type": "stream",
316 | "text": [
317 | "100%|██████████| 909/909 [00:30<00:00, 29.71it/s]\n",
318 | " 0%| | 4/909 [00:00<00:24, 37.57it/s]"
319 | ],
320 | "name": "stderr"
321 | },
322 | {
323 | "output_type": "stream",
324 | "text": [
325 | "\n",
326 | "Epoch: 1. Loss: 0.00565384142100811\n"
327 | ],
328 | "name": "stdout"
329 | },
330 | {
331 | "output_type": "stream",
332 | "text": [
333 | "100%|██████████| 909/909 [00:30<00:00, 29.68it/s]\n",
334 | " 0%| | 4/909 [00:00<00:24, 37.51it/s]"
335 | ],
336 | "name": "stderr"
337 | },
338 | {
339 | "output_type": "stream",
340 | "text": [
341 | "\n",
342 | "Epoch: 2. Loss: 0.0017478576628491282\n"
343 | ],
344 | "name": "stdout"
345 | },
346 | {
347 | "output_type": "stream",
348 | "text": [
349 | "100%|██████████| 909/909 [00:30<00:00, 29.74it/s]\n",
350 | " 0%| | 4/909 [00:00<00:24, 37.59it/s]"
351 | ],
352 | "name": "stderr"
353 | },
354 | {
355 | "output_type": "stream",
356 | "text": [
357 | "\n",
358 | "Epoch: 3. Loss: 0.00245093647390604\n"
359 | ],
360 | "name": "stdout"
361 | },
362 | {
363 | "output_type": "stream",
364 | "text": [
365 | "100%|██████████| 909/909 [00:30<00:00, 29.72it/s]\n",
366 | " 0%| | 4/909 [00:00<00:24, 37.33it/s]"
367 | ],
368 | "name": "stderr"
369 | },
370 | {
371 | "output_type": "stream",
372 | "text": [
373 | "\n",
374 | "Epoch: 4. Loss: 0.001998061779886484\n"
375 | ],
376 | "name": "stdout"
377 | },
378 | {
379 | "output_type": "stream",
380 | "text": [
381 | "100%|██████████| 909/909 [00:30<00:00, 29.71it/s]\n",
382 | " 0%| | 4/909 [00:00<00:24, 37.64it/s]"
383 | ],
384 | "name": "stderr"
385 | },
386 | {
387 | "output_type": "stream",
388 | "text": [
389 | "\n",
390 | "Epoch: 5. Loss: 0.0021302467212080956\n"
391 | ],
392 | "name": "stdout"
393 | },
394 | {
395 | "output_type": "stream",
396 | "text": [
397 | "100%|██████████| 909/909 [00:30<00:00, 29.72it/s]\n",
398 | " 0%| | 4/909 [00:00<00:24, 37.52it/s]"
399 | ],
400 | "name": "stderr"
401 | },
402 | {
403 | "output_type": "stream",
404 | "text": [
405 | "\n",
406 | "Epoch: 6. Loss: 0.0025986104737967253\n"
407 | ],
408 | "name": "stdout"
409 | },
410 | {
411 | "output_type": "stream",
412 | "text": [
413 | "100%|██████████| 909/909 [00:30<00:00, 29.71it/s]\n",
414 | " 0%| | 4/909 [00:00<00:24, 37.28it/s]"
415 | ],
416 | "name": "stderr"
417 | },
418 | {
419 | "output_type": "stream",
420 | "text": [
421 | "\n",
422 | "Epoch: 7. Loss: 0.0031612785533070564\n"
423 | ],
424 | "name": "stdout"
425 | },
426 | {
427 | "output_type": "stream",
428 | "text": [
429 | "100%|██████████| 909/909 [00:30<00:00, 29.70it/s]\n",
430 | " 0%| | 4/909 [00:00<00:24, 37.63it/s]"
431 | ],
432 | "name": "stderr"
433 | },
434 | {
435 | "output_type": "stream",
436 | "text": [
437 | "\n",
438 | "Epoch: 8. Loss: 0.0011636980343610048\n"
439 | ],
440 | "name": "stdout"
441 | },
442 | {
443 | "output_type": "stream",
444 | "text": [
445 | "100%|██████████| 909/909 [00:30<00:00, 29.67it/s]"
446 | ],
447 | "name": "stderr"
448 | },
449 | {
450 | "output_type": "stream",
451 | "text": [
452 | "\n",
453 | "Epoch: 9. Loss: 0.0014444717671722174\n"
454 | ],
455 | "name": "stdout"
456 | },
457 | {
458 | "output_type": "stream",
459 | "text": [
460 | "\n"
461 | ],
462 | "name": "stderr"
463 | }
464 | ]
465 | },
466 | {
467 | "cell_type": "code",
468 | "metadata": {
469 | "id": "LTgBvy4CDLmE"
470 | },
471 | "source": [
472 | "e = 1\n",
473 | "\n",
474 | "label = model(X[e].view(-1, 1, 550, 550).to(device))\n",
475 | "\n",
476 | "xmin = label[0][0].item()\n",
477 | "ymin = label[0][1].item()\n",
478 | "xmax = label[0][2].item()\n",
479 | "ymax = label[0][3].item()\n",
480 | "\n",
481 | "img = training_data[e][0]\n",
482 | "bnd_img = cv2.rectangle(img, (int(xmin*IMG_SIZE), int(ymin*IMG_SIZE)), (int(xmax*IMG_SIZE), int(ymax*IMG_SIZE)), (255, 0, 0), 2)\n",
483 | "plt.imshow(bnd_img, cmap='gray')"
484 | ],
485 | "execution_count": null,
486 | "outputs": []
487 | }
488 | ]
489 | }
--------------------------------------------------------------------------------
/object_detection/object_localization/readme_resources/result_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/object_detection/object_localization/readme_resources/result_1.png
--------------------------------------------------------------------------------
/object_detection/object_localization/readme_resources/result_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/object_detection/object_localization/readme_resources/result_2.png
--------------------------------------------------------------------------------
/object_detection/object_localization/readme_resources/result_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/object_detection/object_localization/readme_resources/result_3.png
--------------------------------------------------------------------------------
/object_detection/object_localization/readme_resources/result_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilhelmberghammer/MachineLearning/87bb5bd3ecd03104c2b38b80b5837379f7af2926/object_detection/object_localization/readme_resources/result_4.png
--------------------------------------------------------------------------------
/object_detection/yolo/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # this is just an implementation of the CNN model in the paper (-> Figure 3: The Architecture)
5 | class YOLO(nn.Module):
6 | def __init__(self, in_channels, n_anchorboxes, n_classes):
7 | super(YOLO, self).__init__()
8 | self.conv_layers = nn.Sequential(
9 | nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False),
10 | nn.LeakyReLU(.1),
11 | nn.MaxPool2d(kernel_size=2, stride=2),
12 |
13 | nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, stride=1, padding=1, bias=False),
14 | nn.LeakyReLU(.1),
15 | nn.MaxPool2d(kernel_size=2, stride=2),
16 |
17 | nn.Conv2d(in_channels=192, out_channels=128, kernel_size=1, stride=1, padding=0, bias=False),
18 | nn.LeakyReLU(.1),
19 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
20 | nn.LeakyReLU(.1),
21 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False),
22 | nn.LeakyReLU(.1),
23 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False),
24 | nn.LeakyReLU(.1),
25 | nn.MaxPool2d(kernel_size=2, stride=2),
26 |
27 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False),
28 | nn.LeakyReLU(.1),
29 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False),
30 | nn.LeakyReLU(.1),
31 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False),
32 | nn.LeakyReLU(.1),
33 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False),
34 | nn.LeakyReLU(.1),
35 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False),
36 | nn.LeakyReLU(.1),
37 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False),
38 | nn.LeakyReLU(.1),
39 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False),
40 | nn.LeakyReLU(.1),
41 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False),
42 | nn.LeakyReLU(.1),
43 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0, bias=False),
44 | nn.LeakyReLU(.1),
45 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False),
46 | nn.LeakyReLU(.1),
47 | nn.MaxPool2d(kernel_size=2, stride=2),
48 |
49 | nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, bias=False),
50 | nn.LeakyReLU(.1),
51 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False),
52 | nn.LeakyReLU(.1),
53 | nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, bias=False),
54 | nn.LeakyReLU(.1),
55 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False),
56 | nn.LeakyReLU(.1),
57 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False),
58 | nn.LeakyReLU(.1),
59 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=2, padding=1, bias=False),
60 | nn.LeakyReLU(.1),
61 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False),
62 | nn.LeakyReLU(.1),
63 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False),
64 | nn.LeakyReLU(.1),
65 | )
66 |
67 | self.fc = nn.Sequential(
68 | nn.Flatten(),
69 | nn.Linear(1024*7*7, 4096),
70 | nn.Dropout(.5),
71 | nn.LeakyReLU(.1),
72 | nn.Linear(4096, 7*7*(n_anchorboxes*5 + n_classes))
73 | )
74 |
75 | def forward(self, x):
76 | x = self.conv_layers(x)
77 | x = self.fc(x)
78 | return x
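79 | 
80 | 
81 | # Quick check added for illustration (not in the original file): with the paper's
82 | # 448x448 RGB input, B=2 boxes and the 20 PASCAL VOC classes, the flattened output
83 | # should have 7*7*(2*5 + 20) = 1470 values per image.
84 | if __name__ == '__main__':
85 | 	model = YOLO(in_channels=3, n_anchorboxes=2, n_classes=20)
86 | 	x = torch.randn(1, 3, 448, 448)
87 | 	print(model(x).shape)   # expected: torch.Size([1, 1470])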
--------------------------------------------------------------------------------
/poetry.lock:
--------------------------------------------------------------------------------
1 | [[package]]
2 | name = "cycler"
3 | version = "0.10.0"
4 | description = "Composable style cycles"
5 | category = "main"
6 | optional = false
7 | python-versions = "*"
8 |
9 | [package.dependencies]
10 | six = "*"
11 |
12 | [[package]]
13 | name = "kiwisolver"
14 | version = "1.3.1"
15 | description = "A fast implementation of the Cassowary constraint solver"
16 | category = "main"
17 | optional = false
18 | python-versions = ">=3.6"
19 |
20 | [[package]]
21 | name = "matplotlib"
22 | version = "3.3.4"
23 | description = "Python plotting package"
24 | category = "main"
25 | optional = false
26 | python-versions = ">=3.6"
27 |
28 | [package.dependencies]
29 | cycler = ">=0.10"
30 | kiwisolver = ">=1.0.1"
31 | numpy = ">=1.15"
32 | pillow = ">=6.2.0"
33 | pyparsing = ">=2.0.3,<2.0.4 || >2.0.4,<2.1.2 || >2.1.2,<2.1.6 || >2.1.6"
34 | python-dateutil = ">=2.1"
35 |
36 | [[package]]
37 | name = "numpy"
38 | version = "1.20.1"
39 | description = "NumPy is the fundamental package for array computing with Python."
40 | category = "main"
41 | optional = false
42 | python-versions = ">=3.7"
43 |
44 | [[package]]
45 | name = "pillow"
46 | version = "8.1.0"
47 | description = "Python Imaging Library (Fork)"
48 | category = "main"
49 | optional = false
50 | python-versions = ">=3.6"
51 |
52 | [[package]]
53 | name = "pyparsing"
54 | version = "2.4.7"
55 | description = "Python parsing module"
56 | category = "main"
57 | optional = false
58 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
59 |
60 | [[package]]
61 | name = "python-dateutil"
62 | version = "2.8.1"
63 | description = "Extensions to the standard Python datetime module"
64 | category = "main"
65 | optional = false
66 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
67 |
68 | [package.dependencies]
69 | six = ">=1.5"
70 |
71 | [[package]]
72 | name = "six"
73 | version = "1.15.0"
74 | description = "Python 2 and 3 compatibility utilities"
75 | category = "main"
76 | optional = false
77 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
78 |
79 | [[package]]
80 | name = "torch"
81 | version = "1.7.1"
82 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
83 | category = "main"
84 | optional = false
85 | python-versions = ">=3.6.2"
86 |
87 | [package.dependencies]
88 | numpy = "*"
89 | typing-extensions = "*"
90 |
91 | [[package]]
92 | name = "torchvision"
93 | version = "0.8.2"
94 | description = "image and video datasets and models for torch deep learning"
95 | category = "main"
96 | optional = false
97 | python-versions = "*"
98 |
99 | [package.dependencies]
100 | numpy = "*"
101 | pillow = ">=4.1.1"
102 | torch = "1.7.1"
103 |
104 | [package.extras]
105 | scipy = ["scipy"]
106 |
107 | [[package]]
108 | name = "typing-extensions"
109 | version = "3.7.4.3"
110 | description = "Backported and Experimental Type Hints for Python 3.5+"
111 | category = "main"
112 | optional = false
113 | python-versions = "*"
114 |
115 | [metadata]
116 | lock-version = "1.1"
117 | python-versions = "^3.8"
118 | content-hash = "7f5f14d8e96047e0a4c039c8cdf67b2023797c8e2a39124797e9d1d08f1118c7"
119 |
120 | [metadata.files]
121 | cycler = [
122 | {file = "cycler-0.10.0-py2.py3-none-any.whl", hash = "sha256:1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d"},
123 | {file = "cycler-0.10.0.tar.gz", hash = "sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8"},
124 | ]
125 | kiwisolver = [
126 | {file = "kiwisolver-1.3.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:fd34fbbfbc40628200730bc1febe30631347103fc8d3d4fa012c21ab9c11eca9"},
127 | {file = "kiwisolver-1.3.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:d3155d828dec1d43283bd24d3d3e0d9c7c350cdfcc0bd06c0ad1209c1bbc36d0"},
128 | {file = "kiwisolver-1.3.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:5a7a7dbff17e66fac9142ae2ecafb719393aaee6a3768c9de2fd425c63b53e21"},
129 | {file = "kiwisolver-1.3.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:f8d6f8db88049a699817fd9178782867bf22283e3813064302ac59f61d95be05"},
130 | {file = "kiwisolver-1.3.1-cp36-cp36m-manylinux2014_ppc64le.whl", hash = "sha256:5f6ccd3dd0b9739edcf407514016108e2280769c73a85b9e59aa390046dbf08b"},
131 | {file = "kiwisolver-1.3.1-cp36-cp36m-win32.whl", hash = "sha256:225e2e18f271e0ed8157d7f4518ffbf99b9450fca398d561eb5c4a87d0986dd9"},
132 | {file = "kiwisolver-1.3.1-cp36-cp36m-win_amd64.whl", hash = "sha256:cf8b574c7b9aa060c62116d4181f3a1a4e821b2ec5cbfe3775809474113748d4"},
133 | {file = "kiwisolver-1.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:232c9e11fd7ac3a470d65cd67e4359eee155ec57e822e5220322d7b2ac84fbf0"},
134 | {file = "kiwisolver-1.3.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:b38694dcdac990a743aa654037ff1188c7a9801ac3ccc548d3341014bc5ca278"},
135 | {file = "kiwisolver-1.3.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ca3820eb7f7faf7f0aa88de0e54681bddcb46e485beb844fcecbcd1c8bd01689"},
136 | {file = "kiwisolver-1.3.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:c8fd0f1ae9d92b42854b2979024d7597685ce4ada367172ed7c09edf2cef9cb8"},
137 | {file = "kiwisolver-1.3.1-cp37-cp37m-manylinux2014_ppc64le.whl", hash = "sha256:1e1bc12fb773a7b2ffdeb8380609f4f8064777877b2225dec3da711b421fda31"},
138 | {file = "kiwisolver-1.3.1-cp37-cp37m-win32.whl", hash = "sha256:72c99e39d005b793fb7d3d4e660aed6b6281b502e8c1eaf8ee8346023c8e03bc"},
139 | {file = "kiwisolver-1.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:8be8d84b7d4f2ba4ffff3665bcd0211318aa632395a1a41553250484a871d454"},
140 | {file = "kiwisolver-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:31dfd2ac56edc0ff9ac295193eeaea1c0c923c0355bf948fbd99ed6018010b72"},
141 | {file = "kiwisolver-1.3.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:563c649cfdef27d081c84e72a03b48ea9408c16657500c312575ae9d9f7bc1c3"},
142 | {file = "kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:78751b33595f7f9511952e7e60ce858c6d64db2e062afb325985ddbd34b5c131"},
143 | {file = "kiwisolver-1.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a357fd4f15ee49b4a98b44ec23a34a95f1e00292a139d6015c11f55774ef10de"},
144 | {file = "kiwisolver-1.3.1-cp38-cp38-manylinux2014_ppc64le.whl", hash = "sha256:5989db3b3b34b76c09253deeaf7fbc2707616f130e166996606c284395da3f18"},
145 | {file = "kiwisolver-1.3.1-cp38-cp38-win32.whl", hash = "sha256:c08e95114951dc2090c4a630c2385bef681cacf12636fb0241accdc6b303fd81"},
146 | {file = "kiwisolver-1.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:44a62e24d9b01ba94ae7a4a6c3fb215dc4af1dde817e7498d901e229aaf50e4e"},
147 | {file = "kiwisolver-1.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:50af681a36b2a1dee1d3c169ade9fdc59207d3c31e522519181e12f1b3ba7000"},
148 | {file = "kiwisolver-1.3.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:a53d27d0c2a0ebd07e395e56a1fbdf75ffedc4a05943daf472af163413ce9598"},
149 | {file = "kiwisolver-1.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:834ee27348c4aefc20b479335fd422a2c69db55f7d9ab61721ac8cd83eb78882"},
150 | {file = "kiwisolver-1.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:5c3e6455341008a054cccee8c5d24481bcfe1acdbc9add30aa95798e95c65621"},
151 | {file = "kiwisolver-1.3.1-cp39-cp39-manylinux2014_ppc64le.whl", hash = "sha256:acef3d59d47dd85ecf909c359d0fd2c81ed33bdff70216d3956b463e12c38a54"},
152 | {file = "kiwisolver-1.3.1-cp39-cp39-win32.whl", hash = "sha256:c5518d51a0735b1e6cee1fdce66359f8d2b59c3ca85dc2b0813a8aa86818a030"},
153 | {file = "kiwisolver-1.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:b9edd0110a77fc321ab090aaa1cfcaba1d8499850a12848b81be2222eab648f6"},
154 | {file = "kiwisolver-1.3.1-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0cd53f403202159b44528498de18f9285b04482bab2a6fc3f5dd8dbb9352e30d"},
155 | {file = "kiwisolver-1.3.1-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:33449715e0101e4d34f64990352bce4095c8bf13bed1b390773fc0a7295967b3"},
156 | {file = "kiwisolver-1.3.1-pp36-pypy36_pp73-win32.whl", hash = "sha256:401a2e9afa8588589775fe34fc22d918ae839aaaf0c0e96441c0fdbce6d8ebe6"},
157 | {file = "kiwisolver-1.3.1.tar.gz", hash = "sha256:950a199911a8d94683a6b10321f9345d5a3a8433ec58b217ace979e18f16e248"},
158 | ]
159 | matplotlib = [
160 | {file = "matplotlib-3.3.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:672960dd114e342b7c610bf32fb99d14227f29919894388b41553217457ba7ef"},
161 | {file = "matplotlib-3.3.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:7c155437ae4fd366e2700e2716564d1787700687443de46bcb895fe0f84b761d"},
162 | {file = "matplotlib-3.3.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:a17f0a10604fac7627ec82820439e7db611722e80c408a726cd00d8c974c2fb3"},
163 | {file = "matplotlib-3.3.4-cp36-cp36m-win32.whl", hash = "sha256:215e2a30a2090221a9481db58b770ce56b8ef46f13224ae33afe221b14b24dc1"},
164 | {file = "matplotlib-3.3.4-cp36-cp36m-win_amd64.whl", hash = "sha256:348e6032f666ffd151b323342f9278b16b95d4a75dfacae84a11d2829a7816ae"},
165 | {file = "matplotlib-3.3.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:94bdd1d55c20e764d8aea9d471d2ae7a7b2c84445e0fa463f02e20f9730783e1"},
166 | {file = "matplotlib-3.3.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:a1acb72f095f1d58ecc2538ed1b8bca0b57df313b13db36ed34b8cdf1868e674"},
167 | {file = "matplotlib-3.3.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:46b1a60a04e6d884f0250d5cc8dc7bd21a9a96c584a7acdaab44698a44710bab"},
168 | {file = "matplotlib-3.3.4-cp37-cp37m-win32.whl", hash = "sha256:ed4a9e6dcacba56b17a0a9ac22ae2c72a35b7f0ef0693aa68574f0b2df607a89"},
169 | {file = "matplotlib-3.3.4-cp37-cp37m-win_amd64.whl", hash = "sha256:c24c05f645aef776e8b8931cb81e0f1632d229b42b6d216e30836e2e145a2b40"},
170 | {file = "matplotlib-3.3.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7310e353a4a35477c7f032409966920197d7df3e757c7624fd842f3eeb307d3d"},
171 | {file = "matplotlib-3.3.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:451cc89cb33d6652c509fc6b588dc51c41d7246afdcc29b8624e256b7663ed1f"},
172 | {file = "matplotlib-3.3.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:3d2eb9c1cc254d0ffa90bc96fde4b6005d09c2228f99dfd493a4219c1af99644"},
173 | {file = "matplotlib-3.3.4-cp38-cp38-win32.whl", hash = "sha256:e15fa23d844d54e7b3b7243afd53b7567ee71c721f592deb0727ee85e668f96a"},
174 | {file = "matplotlib-3.3.4-cp38-cp38-win_amd64.whl", hash = "sha256:1de0bb6cbfe460725f0e97b88daa8643bcf9571c18ba90bb8e41432aaeca91d6"},
175 | {file = "matplotlib-3.3.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f44149a0ef5b4991aaef12a93b8e8d66d6412e762745fea1faa61d98524e0ba9"},
176 | {file = "matplotlib-3.3.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:746a1df55749629e26af7f977ea426817ca9370ad1569436608dc48d1069b87c"},
177 | {file = "matplotlib-3.3.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:5f571b92a536206f7958f7cb2d367ff6c9a1fa8229dc35020006e4cdd1ca0acd"},
178 | {file = "matplotlib-3.3.4-cp39-cp39-win32.whl", hash = "sha256:9265ae0fb35e29f9b8cc86c2ab0a2e3dcddc4dd9de4b85bf26c0f63fe5c1c2ca"},
179 | {file = "matplotlib-3.3.4-cp39-cp39-win_amd64.whl", hash = "sha256:9a79e5dd7bb797aa611048f5b70588b23c5be05b63eefd8a0d152ac77c4243db"},
180 | {file = "matplotlib-3.3.4-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1e850163579a8936eede29fad41e202b25923a0a8d5ffd08ce50fc0a97dcdc93"},
181 | {file = "matplotlib-3.3.4-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:d738acfdfb65da34c91acbdb56abed46803db39af259b7f194dc96920360dbe4"},
182 | {file = "matplotlib-3.3.4-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa49571d8030ad0b9ac39708ee77bd2a22f87815e12bdee52ecaffece9313ed8"},
183 | {file = "matplotlib-3.3.4-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:cf3a7e54eff792f0815dbbe9b85df2f13d739289c93d346925554f71d484be78"},
184 | {file = "matplotlib-3.3.4.tar.gz", hash = "sha256:3e477db76c22929e4c6876c44f88d790aacdf3c3f8f3a90cb1975c0bf37825b0"},
185 | ]
186 | numpy = [
187 | {file = "numpy-1.20.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ae61f02b84a0211abb56462a3b6cd1e7ec39d466d3160eb4e1da8bf6717cdbeb"},
188 | {file = "numpy-1.20.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:65410c7f4398a0047eea5cca9b74009ea61178efd78d1be9847fac1d6716ec1e"},
189 | {file = "numpy-1.20.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:2d7e27442599104ee08f4faed56bb87c55f8b10a5494ac2ead5c98a4b289e61f"},
190 | {file = "numpy-1.20.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:4ed8e96dc146e12c1c5cdd6fb9fd0757f2ba66048bf94c5126b7efebd12d0090"},
191 | {file = "numpy-1.20.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:ecb5b74c702358cdc21268ff4c37f7466357871f53a30e6f84c686952bef16a9"},
192 | {file = "numpy-1.20.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b9410c0b6fed4a22554f072a86c361e417f0258838957b78bd063bde2c7f841f"},
193 | {file = "numpy-1.20.1-cp37-cp37m-win32.whl", hash = "sha256:3d3087e24e354c18fb35c454026af3ed8997cfd4997765266897c68d724e4845"},
194 | {file = "numpy-1.20.1-cp37-cp37m-win_amd64.whl", hash = "sha256:89f937b13b8dd17b0099c7c2e22066883c86ca1575a975f754babc8fbf8d69a9"},
195 | {file = "numpy-1.20.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a1d7995d1023335e67fb070b2fae6f5968f5be3802b15ad6d79d81ecaa014fe0"},
196 | {file = "numpy-1.20.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:60759ab15c94dd0e1ed88241fd4fa3312db4e91d2c8f5a2d4cf3863fad83d65b"},
197 | {file = "numpy-1.20.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:125a0e10ddd99a874fd357bfa1b636cd58deb78ba4a30b5ddb09f645c3512e04"},
198 | {file = "numpy-1.20.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:c26287dfc888cf1e65181f39ea75e11f42ffc4f4529e5bd19add57ad458996e2"},
199 | {file = "numpy-1.20.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:7199109fa46277be503393be9250b983f325880766f847885607d9b13848f257"},
200 | {file = "numpy-1.20.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:72251e43ac426ff98ea802a931922c79b8d7596480300eb9f1b1e45e0543571e"},
201 | {file = "numpy-1.20.1-cp38-cp38-win32.whl", hash = "sha256:c91ec9569facd4757ade0888371eced2ecf49e7982ce5634cc2cf4e7331a4b14"},
202 | {file = "numpy-1.20.1-cp38-cp38-win_amd64.whl", hash = "sha256:13adf545732bb23a796914fe5f891a12bd74cf3d2986eed7b7eba2941eea1590"},
203 | {file = "numpy-1.20.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:104f5e90b143dbf298361a99ac1af4cf59131218a045ebf4ee5990b83cff5fab"},
204 | {file = "numpy-1.20.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:89e5336f2bec0c726ac7e7cdae181b325a9c0ee24e604704ed830d241c5e47ff"},
205 | {file = "numpy-1.20.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:032be656d89bbf786d743fee11d01ef318b0781281241997558fa7950028dd29"},
206 | {file = "numpy-1.20.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:66b467adfcf628f66ea4ac6430ded0614f5cc06ba530d09571ea404789064adc"},
207 | {file = "numpy-1.20.1-cp39-cp39-win32.whl", hash = "sha256:12e4ba5c6420917571f1a5becc9338abbde71dd811ce40b37ba62dec7b39af6d"},
208 | {file = "numpy-1.20.1-cp39-cp39-win_amd64.whl", hash = "sha256:9c94cab5054bad82a70b2e77741271790304651d584e2cdfe2041488e753863b"},
209 | {file = "numpy-1.20.1-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:9eb551d122fadca7774b97db8a112b77231dcccda8e91a5bc99e79890797175e"},
210 | {file = "numpy-1.20.1.zip", hash = "sha256:3bc63486a870294683980d76ec1e3efc786295ae00128f9ea38e2c6e74d5a60a"},
211 | ]
212 | pillow = [
213 | {file = "Pillow-8.1.0-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:d355502dce85ade85a2511b40b4c61a128902f246504f7de29bbeec1ae27933a"},
214 | {file = "Pillow-8.1.0-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:93a473b53cc6e0b3ce6bf51b1b95b7b1e7e6084be3a07e40f79b42e83503fbf2"},
215 | {file = "Pillow-8.1.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:2353834b2c49b95e1313fb34edf18fca4d57446675d05298bb694bca4b194174"},
216 | {file = "Pillow-8.1.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:1d208e670abfeb41b6143537a681299ef86e92d2a3dac299d3cd6830d5c7bded"},
217 | {file = "Pillow-8.1.0-cp36-cp36m-win32.whl", hash = "sha256:dd9eef866c70d2cbbea1ae58134eaffda0d4bfea403025f4db6859724b18ab3d"},
218 | {file = "Pillow-8.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:b09e10ec453de97f9a23a5aa5e30b334195e8d2ddd1ce76cc32e52ba63c8b31d"},
219 | {file = "Pillow-8.1.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:b02a0b9f332086657852b1f7cb380f6a42403a6d9c42a4c34a561aa4530d5234"},
220 | {file = "Pillow-8.1.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:ca20739e303254287138234485579b28cb0d524401f83d5129b5ff9d606cb0a8"},
221 | {file = "Pillow-8.1.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:604815c55fd92e735f9738f65dabf4edc3e79f88541c221d292faec1904a4b17"},
222 | {file = "Pillow-8.1.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cf6e33d92b1526190a1de904df21663c46a456758c0424e4f947ae9aa6088bf7"},
223 | {file = "Pillow-8.1.0-cp37-cp37m-win32.whl", hash = "sha256:47c0d93ee9c8b181f353dbead6530b26980fe4f5485aa18be8f1fd3c3cbc685e"},
224 | {file = "Pillow-8.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:96d4dc103d1a0fa6d47c6c55a47de5f5dafd5ef0114fa10c85a1fd8e0216284b"},
225 | {file = "Pillow-8.1.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:7916cbc94f1c6b1301ac04510d0881b9e9feb20ae34094d3615a8a7c3db0dcc0"},
226 | {file = "Pillow-8.1.0-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3de6b2ee4f78c6b3d89d184ade5d8fa68af0848f9b6b6da2b9ab7943ec46971a"},
227 | {file = "Pillow-8.1.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cdbbe7dff4a677fb555a54f9bc0450f2a21a93c5ba2b44e09e54fcb72d2bd13d"},
228 | {file = "Pillow-8.1.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:f50e7a98b0453f39000619d845be8b06e611e56ee6e8186f7f60c3b1e2f0feae"},
229 | {file = "Pillow-8.1.0-cp38-cp38-win32.whl", hash = "sha256:cb192176b477d49b0a327b2a5a4979552b7a58cd42037034316b8018ac3ebb59"},
230 | {file = "Pillow-8.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:6c5275bd82711cd3dcd0af8ce0bb99113ae8911fc2952805f1d012de7d600a4c"},
231 | {file = "Pillow-8.1.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:165c88bc9d8dba670110c689e3cc5c71dbe4bfb984ffa7cbebf1fac9554071d6"},
232 | {file = "Pillow-8.1.0-cp39-cp39-manylinux1_i686.whl", hash = "sha256:5e2fe3bb2363b862671eba632537cd3a823847db4d98be95690b7e382f3d6378"},
233 | {file = "Pillow-8.1.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:7612520e5e1a371d77e1d1ca3a3ee6227eef00d0a9cddb4ef7ecb0b7396eddf7"},
234 | {file = "Pillow-8.1.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d673c4990acd016229a5c1c4ee8a9e6d8f481b27ade5fc3d95938697fa443ce0"},
235 | {file = "Pillow-8.1.0-cp39-cp39-win32.whl", hash = "sha256:dc577f4cfdda354db3ae37a572428a90ffdbe4e51eda7849bf442fb803f09c9b"},
236 | {file = "Pillow-8.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:22d070ca2e60c99929ef274cfced04294d2368193e935c5d6febfd8b601bf865"},
237 | {file = "Pillow-8.1.0-pp36-pypy36_pp73-macosx_10_10_x86_64.whl", hash = "sha256:a3d3e086474ef12ef13d42e5f9b7bbf09d39cf6bd4940f982263d6954b13f6a9"},
238 | {file = "Pillow-8.1.0-pp36-pypy36_pp73-manylinux2010_i686.whl", hash = "sha256:731ca5aabe9085160cf68b2dbef95fc1991015bc0a3a6ea46a371ab88f3d0913"},
239 | {file = "Pillow-8.1.0-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:bba80df38cfc17f490ec651c73bb37cd896bc2400cfba27d078c2135223c1206"},
240 | {file = "Pillow-8.1.0-pp37-pypy37_pp73-macosx_10_10_x86_64.whl", hash = "sha256:c3d911614b008e8a576b8e5303e3db29224b455d3d66d1b2848ba6ca83f9ece9"},
241 | {file = "Pillow-8.1.0-pp37-pypy37_pp73-manylinux2010_i686.whl", hash = "sha256:39725acf2d2e9c17356e6835dccebe7a697db55f25a09207e38b835d5e1bc032"},
242 | {file = "Pillow-8.1.0-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:81c3fa9a75d9f1afafdb916d5995633f319db09bd773cb56b8e39f1e98d90820"},
243 | {file = "Pillow-8.1.0-pp37-pypy37_pp73-win32.whl", hash = "sha256:b6f00ad5ebe846cc91763b1d0c6d30a8042e02b2316e27b05de04fa6ec831ec5"},
244 | {file = "Pillow-8.1.0.tar.gz", hash = "sha256:887668e792b7edbfb1d3c9d8b5d8c859269a0f0eba4dda562adb95500f60dbba"},
245 | ]
246 | pyparsing = [
247 | {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"},
248 | {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"},
249 | ]
250 | python-dateutil = [
251 | {file = "python-dateutil-2.8.1.tar.gz", hash = "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c"},
252 | {file = "python_dateutil-2.8.1-py2.py3-none-any.whl", hash = "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a"},
253 | ]
254 | six = [
255 | {file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"},
256 | {file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"},
257 | ]
258 | torch = [
259 | {file = "torch-1.7.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:422e64e98d0e100c360993819d0307e5d56e9517b26135808ad68984d577d75a"},
260 | {file = "torch-1.7.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f0aaf657145533824b15f2fd8fde8f8c67fe6c6281088ef588091f03fad90243"},
261 | {file = "torch-1.7.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:af464a6f4314a875035e0c4c2b07517599704b214634f4ed3ad2e748c5ef291f"},
262 | {file = "torch-1.7.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5d76c255a41484c1d41a9ff570b9c9f36cb85df9428aa15a58ae16ac7cfc2ea6"},
263 | {file = "torch-1.7.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d241c3f1c4d563e4ba86f84769c23e12606db167ee6f674eedff6d02901462e3"},
264 | {file = "torch-1.7.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:de84b4166e3f7335eb868b51d3bbd909ec33828af27290b4171bce832a55be3c"},
265 | {file = "torch-1.7.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:dd2fc6880c95e836960d86efbbc7f63d3287f2e1893c51d31f96dbfe02f0d73e"},
266 | {file = "torch-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:e000b94be3aa58ad7f61e7d07cf379ea9366cf6c6874e68bd58ad0bdc537b3a7"},
267 | {file = "torch-1.7.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:2e49cac969976be63117004ee00d0a3e3dd4ea662ad77383f671b8992825de1a"},
268 | {file = "torch-1.7.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a3793dcceb12b1e2281290cca1277c5ce86ddfd5bf044f654285a4d69057aea7"},
269 | {file = "torch-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:6652a767a0572ae0feb74ad128758e507afd3b8396b6e7f147e438ba8d4c6f63"},
270 | {file = "torch-1.7.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:38d67f4fb189a92a977b2c0a38e4f6dd413e0bf55aa6d40004696df7e40a71ff"},
271 | ]
272 | torchvision = [
273 | {file = "torchvision-0.8.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:86fae370d222f76ad57c57c3bee03f78b8db727743bfb4c1559a3d395159cea8"},
274 | {file = "torchvision-0.8.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:951239b5fcb911dbf78c1385d677f5f48c7a1b12859e3d3ec287562821b17cf2"},
275 | {file = "torchvision-0.8.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24db8f4c3d812a032273f68563ad5dbd724f5bfbed523d0c6dce8cede26bb153"},
276 | {file = "torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:b068f6bcbe91bdd34dda0a39e8a26392add45a3be82543f6dd523b76484fb56f"},
277 | {file = "torchvision-0.8.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:afb76a66b9b0693f758a881a2bf333ed97e3c0c3f15a413c4f49d8dd8bd21307"},
278 | {file = "torchvision-0.8.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd8817e9197fc60ebae37162a445db90bbf35591314a5767ad3d1490b5d65b0f"},
279 | {file = "torchvision-0.8.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1bd58acc3366ec02266aae56a7a752d43ef07de4a6ba420c4f907d0c9168bb8c"},
280 | {file = "torchvision-0.8.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:976750a49db2e23dc5a1ed0b5c31f7af51ed2702eee410ee09ef985c3a3e48cf"},
281 | ]
282 | typing-extensions = [
283 | {file = "typing_extensions-3.7.4.3-py2-none-any.whl", hash = "sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f"},
284 | {file = "typing_extensions-3.7.4.3-py3-none-any.whl", hash = "sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918"},
285 | {file = "typing_extensions-3.7.4.3.tar.gz", hash = "sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c"},
286 | ]
287 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "machinelearning"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["wilhelmberghammer"]
6 |
7 | [tool.poetry.dependencies]
8 | python = "^3.8"
9 | Pillow = "^8.1.0"
10 | torch = "^1.7.1"
11 | torchvision = "^0.8.2"
12 | matplotlib = "^3.3.4"
13 |
14 | [tool.poetry.dev-dependencies]
15 |
16 | [build-system]
17 | requires = ["poetry-core>=1.0.0"]
18 | build-backend = "poetry.core.masonry.api"
19 |
--------------------------------------------------------------------------------
/transfer_learning/README.md:
--------------------------------------------------------------------------------
1 | # Transfer Learning for computer vision
2 |
3 | I'm using the ants and bees dataset from the PyTorch [tutorial on transfer learning](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html). A pretrained ResNet-18 is frozen and only its final fully connected layer is retrained (the expected folder layout is sketched at the end of this README).
4 |
5 |
6 | ## Results
7 | ```python
8 | batch_size = 8
9 | EPOCHS = 10
10 |
11 | pretrained_model = models.resnet18(pretrained=True)
12 |
13 | for param in pretrained_model.parameters():
14 | param.requires_grad = False
15 |
16 | pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, 2)
17 | pretrained_model = pretrained_model.to(device)
18 |
19 | criterion = nn.CrossEntropyLoss()
20 | optimizer = optim.Adam(pretrained_model.fc.parameters(), lr=0.001)
21 |
22 | model = train(pretrained_model, EPOCHS, criterion, optimizer, data_loaders)
23 | ```
24 | ```
25 | cpu
26 |
27 | __________
28 | EPOCH 0/9
29 | ----------
30 | train Loss: 0.6054 Acc: 0.6721
31 | val Loss: 0.4359 Acc: 0.8039
32 | __________
33 | EPOCH 1/9
34 | ----------
35 | train Loss: 0.3995 Acc: 0.8648
36 | val Loss: 0.3478 Acc: 0.8431
37 | __________
38 | EPOCH 2/9
39 | ----------
40 | train Loss: 0.3758 Acc: 0.8279
41 | val Loss: 0.2956 Acc: 0.9085
42 | __________
43 | EPOCH 3/9
44 | ----------
45 | train Loss: 0.2948 Acc: 0.8770
46 | val Loss: 0.2717 Acc: 0.9216
47 | __________
48 | EPOCH 4/9
49 | ----------
50 | train Loss: 0.3318 Acc: 0.8607
51 | val Loss: 0.2493 Acc: 0.8954
52 | __________
53 | EPOCH 5/9
54 | ----------
55 | train Loss: 0.2259 Acc: 0.8975
56 | val Loss: 0.2445 Acc: 0.8954
57 | __________
58 | EPOCH 6/9
59 | ----------
60 | train Loss: 0.2454 Acc: 0.9057
61 | val Loss: 0.2424 Acc: 0.8954
62 | __________
63 | EPOCH 7/9
64 | ----------
65 | train Loss: 0.2141 Acc: 0.9303
66 | val Loss: 0.2759 Acc: 0.8627
67 | __________
68 | EPOCH 8/9
69 | ----------
70 | train Loss: 0.2456 Acc: 0.8934
71 | val Loss: 0.2233 Acc: 0.9216
72 | __________
73 | EPOCH 9/9
74 | ----------
75 | train Loss: 0.1821 Acc: 0.9385
76 | val Loss: 0.2317 Acc: 0.9085
77 | ```
78 |
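79 | ## Data layout
80 | `get_data` in `transfer_learn.py` builds one `ImageFolder` per phase, so images are expected in one sub-folder per class. Assuming the ants/bees download from the PyTorch tutorial, the layout under `./data/` would look roughly like this:
81 | ```
82 | data/
83 | ├── train/
84 | │   ├── ants/
85 | │   └── bees/
86 | └── val/
87 |     ├── ants/
88 |     └── bees/
89 | ```
90 |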
--------------------------------------------------------------------------------
/transfer_learning/transfer_learn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torchvision.transforms as transforms
6 | import torchvision.models as models
7 | import torchvision.datasets as datasets
8 |
9 |
10 | def get_data(data_dir, batch_size, num_workers=2):
11 | data_transforms = {
12 | 'train': transforms.Compose([
13 | transforms.Resize((224, 224)),
14 | transforms.RandomHorizontalFlip(),
15 | transforms.ToTensor(),
16 | # ImageNet mean/std, expected by the pretrained ResNet
17 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
18 | ]),
19 | 'val': transforms.Compose([
20 | transforms.Resize((224, 224)),
21 | transforms.ToTensor(),
22 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
23 | ]),
24 | }
25 |
26 | img_datasets = {
27 | phase: datasets.ImageFolder(os.path.join(data_dir, phase), data_transforms[phase]) for phase in ['train', 'val']
28 | }
29 |
30 | data_loaders = {
31 | phase: torch.utils.data.DataLoader(img_datasets[phase], batch_size=batch_size, shuffle=True, num_workers=num_workers) for phase in ['train', 'val']
32 | }
33 |
34 | # class names (['ants', 'bees'] here) are available via img_datasets['train'].classes if needed
35 |
36 | return data_loaders
37 |
38 |
39 | def train(model, EPOCHS, criterion, optimizer, data_loaders):  # relies on the module-level 'device' set in __main__
40 | for epoch in range(EPOCHS):
41 | print('_' * 10)
42 | print(f'EPOCH {epoch}/{EPOCHS - 1}')
43 | print('-' * 10)
44 |
45 | # each epoch runs a training phase followed by a validation phase (pattern from the PyTorch tutorial)
46 | for phase in ['train', 'val']:
47 | if phase == 'train':
48 | model.train()
49 | else:
50 | model.eval()
51 |
52 | running_loss = 0.0
53 | running_correct = 0
54 |
55 | for x, labels in data_loaders[phase]:
56 | x = x.to(device)
57 | labels = labels.to(device)
58 |
59 | optimizer.zero_grad()
60 |
61 | with torch.set_grad_enabled(phase == 'train'):
62 | outputs = model(x)
63 | loss = criterion(outputs, labels)
64 |
65 | # the max value itself is discarded (_); only the predicted class index (preds) is needed
66 | _, preds = torch.max(outputs, 1)
67 |
68 | if phase == 'train':
69 | loss.backward()
70 | optimizer.step()
71 |
72 | running_loss += loss.item() * x.size(0)
73 | running_correct += torch.sum(preds == labels)
74 |
75 | epoch_loss = running_loss / len(data_loaders[phase].dataset)
76 | epoch_acc = running_correct.double() / len(data_loaders[phase].dataset)
77 |
78 | print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
79 |
80 | return model
81 |
82 |
83 |
84 | if __name__ == '__main__':
85 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
86 | print(device, '\n')
87 |
88 | # hyperparameters
89 | batch_size = 8
90 | EPOCHS = 10
91 |
92 | data_loaders = get_data(data_dir='./data/', batch_size=batch_size)
93 |
94 | pretrained_model = models.resnet18(pretrained=True)
95 | for param in pretrained_model.parameters():
96 | param.requires_grad = False
97 |
98 | pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, 2)
99 | pretrained_model = pretrained_model.to(device)
100 |
101 | criterion = nn.CrossEntropyLoss()
102 |
103 | optimizer = optim.Adam(pretrained_model.fc.parameters(), lr=0.001)
104 |
105 | model = train(pretrained_model, EPOCHS, criterion, optimizer, data_loaders)
106 |
107 |
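108 | # Optional inference sketch: a minimal example, not called by the training run above.
109 | # Assumes a trained model (e.g. the return value of train()) and a path to a single image;
110 | # the class order follows ImageFolder's alphabetical sorting of the ants/bees sub-folders.
111 | def predict_image(model, img_path, device, class_names=('ants', 'bees')):
112 | 	from PIL import Image
113 |
114 | 	# same preprocessing as the 'val' transforms in get_data()
115 | 	preprocess = transforms.Compose([
116 | 		transforms.Resize((224, 224)),
117 | 		transforms.ToTensor(),
118 | 		transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
119 | 	])
120 |
121 | 	img = preprocess(Image.open(img_path).convert('RGB')).unsqueeze(0).to(device)
122 |
123 | 	model.eval()
124 | 	with torch.no_grad():
125 | 		pred = torch.argmax(model(img), dim=1).item()
126 |
127 | 	return class_names[pred]
128 |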
--------------------------------------------------------------------------------