├── .vscode
└── settings.json
├── Pipfile
├── Pipfile.lock
├── README.md
├── data
└── wines.csv
├── img
└── skorch2.jpg
├── main.py
└── src
├── __init__.py
├── data_loader.py
└── model.py
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.pythonPath": "/Users/fernandolopez/.local/share/virtualenvs/skorch-zxWYGwU5/bin/python"
3 | }
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | name = "pypi"
3 | url = "https://pypi.org/simple"
4 | verify_ssl = true
5 |
6 | [dev-packages]
7 |
8 | [packages]
9 | scikit-learn = "*"
10 | skorch = "*"
11 | torch = "*"
12 | pandas = "*"
13 |
14 | [requires]
15 | python_version = "3.8"
16 |
--------------------------------------------------------------------------------
/Pipfile.lock:
--------------------------------------------------------------------------------
1 | {
2 | "_meta": {
3 | "hash": {
4 | "sha256": "caba5c50b162d0d09265df9f323da40e4399d222f88c1294f47b43f97165d8f6"
5 | },
6 | "pipfile-spec": 6,
7 | "requires": {
8 | "python_version": "3.8"
9 | },
10 | "sources": [
11 | {
12 | "name": "pypi",
13 | "url": "https://pypi.org/simple",
14 | "verify_ssl": true
15 | }
16 | ]
17 | },
18 | "default": {
19 | "dataclasses": {
20 | "hashes": [
21 | "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f",
22 | "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"
23 | ],
24 | "version": "==0.6"
25 | },
26 | "future": {
27 | "hashes": [
28 | "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d"
29 | ],
30 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'",
31 | "version": "==0.18.2"
32 | },
33 | "joblib": {
34 | "hashes": [
35 | "sha256:698c311779f347cf6b7e6b8a39bb682277b8ee4aba8cf9507bc0cf4cd4737b72",
36 | "sha256:9e284edd6be6b71883a63c9b7f124738a3c16195513ad940eae7e3438de885d5"
37 | ],
38 | "markers": "python_version >= '3.6'",
39 | "version": "==0.17.0"
40 | },
41 | "numpy": {
42 | "hashes": [
43 | "sha256:08308c38e44cc926bdfce99498b21eec1f848d24c302519e64203a8da99a97db",
44 | "sha256:09c12096d843b90eafd01ea1b3307e78ddd47a55855ad402b157b6c4862197ce",
45 | "sha256:13d166f77d6dc02c0a73c1101dd87fdf01339febec1030bd810dcd53fff3b0f1",
46 | "sha256:141ec3a3300ab89c7f2b0775289954d193cc8edb621ea05f99db9cb181530512",
47 | "sha256:16c1b388cc31a9baa06d91a19366fb99ddbe1c7b205293ed072211ee5bac1ed2",
48 | "sha256:18bed2bcb39e3f758296584337966e68d2d5ba6aab7e038688ad53c8f889f757",
49 | "sha256:1aeef46a13e51931c0b1cf8ae1168b4a55ecd282e6688fdb0a948cc5a1d5afb9",
50 | "sha256:27d3f3b9e3406579a8af3a9f262f5339005dd25e0ecf3cf1559ff8a49ed5cbf2",
51 | "sha256:2a2740aa9733d2e5b2dfb33639d98a64c3b0f24765fed86b0fd2aec07f6a0a08",
52 | "sha256:4377e10b874e653fe96985c05feed2225c912e328c8a26541f7fc600fb9c637b",
53 | "sha256:448ebb1b3bf64c0267d6b09a7cba26b5ae61b6d2dbabff7c91b660c7eccf2bdb",
54 | "sha256:50e86c076611212ca62e5a59f518edafe0c0730f7d9195fec718da1a5c2bb1fc",
55 | "sha256:5734bdc0342aba9dfc6f04920988140fb41234db42381cf7ccba64169f9fe7ac",
56 | "sha256:64324f64f90a9e4ef732be0928be853eee378fd6a01be21a0a8469c4f2682c83",
57 | "sha256:6ae6c680f3ebf1cf7ad1d7748868b39d9f900836df774c453c11c5440bc15b36",
58 | "sha256:6d7593a705d662be5bfe24111af14763016765f43cb6923ed86223f965f52387",
59 | "sha256:8cac8790a6b1ddf88640a9267ee67b1aee7a57dfa2d2dd33999d080bc8ee3a0f",
60 | "sha256:8ece138c3a16db8c1ad38f52eb32be6086cc72f403150a79336eb2045723a1ad",
61 | "sha256:9eeb7d1d04b117ac0d38719915ae169aa6b61fca227b0b7d198d43728f0c879c",
62 | "sha256:a09f98011236a419ee3f49cedc9ef27d7a1651df07810ae430a6b06576e0b414",
63 | "sha256:a5d897c14513590a85774180be713f692df6fa8ecf6483e561a6d47309566f37",
64 | "sha256:ad6f2ff5b1989a4899bf89800a671d71b1612e5ff40866d1f4d8bcf48d4e5764",
65 | "sha256:c42c4b73121caf0ed6cd795512c9c09c52a7287b04d105d112068c1736d7c753",
66 | "sha256:cb1017eec5257e9ac6209ac172058c430e834d5d2bc21961dceeb79d111e5909",
67 | "sha256:d6c7bb82883680e168b55b49c70af29b84b84abb161cbac2800e8fcb6f2109b6",
68 | "sha256:e452dc66e08a4ce642a961f134814258a082832c78c90351b75c41ad16f79f63",
69 | "sha256:e5b6ed0f0b42317050c88022349d994fe72bfe35f5908617512cd8c8ef9da2a9",
70 | "sha256:e9b30d4bd69498fc0c3fe9db5f62fffbb06b8eb9321f92cc970f2969be5e3949",
71 | "sha256:ec149b90019852266fec2341ce1db513b843e496d5a8e8cdb5ced1923a92faab",
72 | "sha256:edb01671b3caae1ca00881686003d16c2209e07b7ef8b7639f1867852b948f7c",
73 | "sha256:f0d3929fe88ee1c155129ecd82f981b8856c5d97bcb0d5f23e9b4242e79d1de3",
74 | "sha256:f29454410db6ef8126c83bd3c968d143304633d45dc57b51252afbd79d700893",
75 | "sha256:fe45becb4c2f72a0907c1d0246ea6449fe7a9e2293bb0e11c4e9a32bb0930a15",
76 | "sha256:fedbd128668ead37f33917820b704784aff695e0019309ad446a6d0b065b57e4"
77 | ],
78 | "markers": "python_version >= '3.6'",
79 | "version": "==1.19.4"
80 | },
81 | "pandas": {
82 | "hashes": [
83 | "sha256:09e0503758ad61afe81c9069505f8cb8c1e36ea8cc1e6826a95823ef5b327daf",
84 | "sha256:0a11a6290ef3667575cbd4785a1b62d658c25a2fd70a5adedba32e156a8f1773",
85 | "sha256:0d9a38a59242a2f6298fff45d09768b78b6eb0c52af5919ea9e45965d7ba56d9",
86 | "sha256:112c5ba0f9ea0f60b2cc38c25f87ca1d5ca10f71efbee8e0f1bee9cf584ed5d5",
87 | "sha256:185cf8c8f38b169dbf7001e1a88c511f653fbb9dfa3e048f5e19c38049e991dc",
88 | "sha256:3aa8e10768c730cc1b610aca688f588831fa70b65a26cb549fbb9f35049a05e0",
89 | "sha256:41746d520f2b50409dffdba29a15c42caa7babae15616bcf80800d8cfcae3d3e",
90 | "sha256:43cea38cbcadb900829858884f49745eb1f42f92609d368cabcc674b03e90efc",
91 | "sha256:5378f58172bd63d8c16dd5d008d7dcdd55bf803fcdbe7da2dcb65dbbf322f05b",
92 | "sha256:54404abb1cd3f89d01f1fb5350607815326790efb4789be60508f458cdd5ccbf",
93 | "sha256:5dac3aeaac5feb1016e94bde851eb2012d1733a222b8afa788202b836c97dad5",
94 | "sha256:5fdb2a61e477ce58d3f1fdf2470ee142d9f0dde4969032edaf0b8f1a9dafeaa2",
95 | "sha256:6613c7815ee0b20222178ad32ec144061cb07e6a746970c9160af1ebe3ad43b4",
96 | "sha256:6d2b5b58e7df46b2c010ec78d7fb9ab20abf1d306d0614d3432e7478993fbdb0",
97 | "sha256:8a5d7e57b9df2c0a9a202840b2881bb1f7a648eba12dd2d919ac07a33a36a97f",
98 | "sha256:8b4c2055ebd6e497e5ecc06efa5b8aa76f59d15233356eb10dad22a03b757805",
99 | "sha256:a15653480e5b92ee376f8458197a58cca89a6e95d12cccb4c2d933df5cecc63f",
100 | "sha256:a7d2547b601ecc9a53fd41561de49a43d2231728ad65c7713d6b616cd02ddbed",
101 | "sha256:a979d0404b135c63954dea79e6246c45dd45371a88631cdbb4877d844e6de3b6",
102 | "sha256:b1f8111635700de7ac350b639e7e452b06fc541a328cf6193cf8fc638804bab8",
103 | "sha256:c5a3597880a7a29a31ebd39b73b2c824316ae63a05c3c8a5ce2aea3fc68afe35",
104 | "sha256:c681e8fcc47a767bf868341d8f0d76923733cbdcabd6ec3a3560695c69f14a1e",
105 | "sha256:cf135a08f306ebbcfea6da8bf775217613917be23e5074c69215b91e180caab4",
106 | "sha256:e2b8557fe6d0a18db4d61c028c6af61bfed44ef90e419ed6fadbdc079eba141e"
107 | ],
108 | "index": "pypi",
109 | "version": "==1.1.4"
110 | },
111 | "python-dateutil": {
112 | "hashes": [
113 | "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c",
114 | "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a"
115 | ],
116 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
117 | "version": "==2.8.1"
118 | },
119 | "pytz": {
120 | "hashes": [
121 | "sha256:3e6b7dd2d1e0a59084bcee14a17af60c5c562cdc16d828e8eba2e683d3a7e268",
122 | "sha256:5c55e189b682d420be27c6995ba6edce0c0a77dd67bfbe2ae6607134d5851ffd"
123 | ],
124 | "version": "==2020.4"
125 | },
126 | "scikit-learn": {
127 | "hashes": [
128 | "sha256:0a127cc70990d4c15b1019680bfedc7fec6c23d14d3719fdf9b64b22d37cdeca",
129 | "sha256:0d39748e7c9669ba648acf40fb3ce96b8a07b240db6888563a7cb76e05e0d9cc",
130 | "sha256:1b8a391de95f6285a2f9adffb7db0892718950954b7149a70c783dc848f104ea",
131 | "sha256:20766f515e6cd6f954554387dfae705d93c7b544ec0e6c6a5d8e006f6f7ef480",
132 | "sha256:2aa95c2f17d2f80534156215c87bee72b6aa314a7f8b8fe92a2d71f47280570d",
133 | "sha256:5ce7a8021c9defc2b75620571b350acc4a7d9763c25b7593621ef50f3bd019a2",
134 | "sha256:6c28a1d00aae7c3c9568f61aafeaad813f0f01c729bee4fd9479e2132b215c1d",
135 | "sha256:7671bbeddd7f4f9a6968f3b5442dac5f22bf1ba06709ef888cc9132ad354a9ab",
136 | "sha256:914ac2b45a058d3f1338d7736200f7f3b094857758895f8667be8a81ff443b5b",
137 | "sha256:98508723f44c61896a4e15894b2016762a55555fbf09365a0bb1870ecbd442de",
138 | "sha256:a64817b050efd50f9abcfd311870073e500ae11b299683a519fbb52d85e08d25",
139 | "sha256:cb3e76380312e1f86abd20340ab1d5b3cc46a26f6593d3c33c9ea3e4c7134028",
140 | "sha256:d0dcaa54263307075cb93d0bee3ceb02821093b1b3d25f66021987d305d01dce",
141 | "sha256:d9a1ce5f099f29c7c33181cc4386660e0ba891b21a60dc036bf369e3a3ee3aec",
142 | "sha256:da8e7c302003dd765d92a5616678e591f347460ac7b53e53d667be7dfe6d1b10",
143 | "sha256:daf276c465c38ef736a79bd79fc80a249f746bcbcae50c40945428f7ece074f8"
144 | ],
145 | "index": "pypi",
146 | "version": "==0.23.2"
147 | },
148 | "scipy": {
149 | "hashes": [
150 | "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce",
151 | "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42",
152 | "sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554",
153 | "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e",
154 | "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce",
155 | "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9",
156 | "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4",
157 | "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06",
158 | "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b",
159 | "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47",
160 | "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443",
161 | "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389",
162 | "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e",
163 | "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57",
164 | "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62",
165 | "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d",
166 | "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437",
167 | "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2",
168 | "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54",
169 | "sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474",
170 | "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938",
171 | "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340",
172 | "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660",
173 | "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012",
174 | "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c"
175 | ],
176 | "markers": "python_version >= '3.6'",
177 | "version": "==1.5.4"
178 | },
179 | "six": {
180 | "hashes": [
181 | "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259",
182 | "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"
183 | ],
184 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
185 | "version": "==1.15.0"
186 | },
187 | "skorch": {
188 | "hashes": [
189 | "sha256:12bb80276719cdbd114bc5042f4d0b395ce1ebe5dbc29aba5d4ea2f1792f9705",
190 | "sha256:26317da14837f372fdeb8fb4eee9199c2cc0b0db1056fc4ab69696402e17e135",
191 | "sha256:bdce9370153fd80c5c4ec499a639f55eef0620e45d4b15fbf7d7ff2a225a3d40"
192 | ],
193 | "index": "pypi",
194 | "version": "==0.9.0"
195 | },
196 | "tabulate": {
197 | "hashes": [
198 | "sha256:ac64cb76d53b1231d364babcd72abbb16855adac7de6665122f97b593f1eb2ba",
199 | "sha256:db2723a20d04bcda8522165c73eea7c300eda74e0ce852d9022e0159d7895007"
200 | ],
201 | "version": "==0.8.7"
202 | },
203 | "threadpoolctl": {
204 | "hashes": [
205 | "sha256:38b74ca20ff3bb42caca8b00055111d74159ee95c4370882bbff2b93d24da725",
206 | "sha256:ddc57c96a38beb63db45d6c159b5ab07b6bced12c45a1f07b2b92f272aebfa6b"
207 | ],
208 | "markers": "python_version >= '3.5'",
209 | "version": "==2.1.0"
210 | },
211 | "torch": {
212 | "hashes": [
213 | "sha256:11054f26eee5c3114d217201dba5b3a35f1745d11133c123c077c5981bc95997",
214 | "sha256:1520c48430dea38e5845b7b3defc9054edad45f1f245808aa268ade840bb2c2a",
215 | "sha256:6b0c9b56cb56afe3ecbac79351d21c6f7172dffc7b7daa8c365f660541baf1a5",
216 | "sha256:89cb8774243750bd3fd2b3b3d09bab6e3be68b1785ad48b8411f1eb4fc7acdba",
217 | "sha256:b8000e39600e101b2f19dbbab75de663a3b78e3979c3e1720b7136aae1c35ce2",
218 | "sha256:e8cc3b2c3937b7ae036a3b447a189af049bfc006bca054fc1d8ae78766ca3105"
219 | ],
220 | "index": "pypi",
221 | "version": "==1.7.0"
222 | },
223 | "tqdm": {
224 | "hashes": [
225 | "sha256:9ad44aaf0fc3697c06f6e05c7cf025dd66bc7bcb7613c66d85f4464c47ac8fad",
226 | "sha256:ef54779f1c09f346b2b5a8e5c61f96fbcb639929e640e59f8cf810794f406432"
227 | ],
228 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'",
229 | "version": "==4.51.0"
230 | },
231 | "typing-extensions": {
232 | "hashes": [
233 | "sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918",
234 | "sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c",
235 | "sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f"
236 | ],
237 | "version": "==3.7.4.3"
238 | }
239 | },
240 | "develop": {}
241 | }
242 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
9 | [![Medium][medium-shield]][medium-url]
10 | [![Twitter][twitter-shield]][twitter-url]
11 | [![Linkedin][linkedin-shield]][linkedin-url]
12 |
13 | # SKORCH: PyTorch Models Trained with a Scikit-Learn Wrapper
14 | This repository shows an example of the usability of SKORCH to train a PyTorch model making use of different capabilities of the scikit-learn framework.
15 |
16 | If you want to understand the details about how this model was created, take a look at this very clear and detailed explanation: SKORCH: PyTorch Models Trained with a Scikit-Learn Wrapper
17 |
18 |
19 | ## Table of Contents
20 |
21 | * [The model](#the-model)
22 | * [Files](#files)
23 | * [How to use](#how-to-use)
24 | * [Contributing](#contributing)
25 | * [Contact](#contact)
26 | * [License](#license)
27 |
28 |
29 | ## 1. The model
30 | The idea of this repository is to show how to use some of the SKorch functionalities to train a PyTorch model. In this case, a neural network was created to classify the wines dataset. In order to understand better what SKorch is, take a look at the following image:
31 |
32 |
33 |
34 |
35 |
36 |
37 | ## 2. Files
38 | * **data**: Here you will find the wines dataset
39 | * **src**: It contains two files, ``data_loader.py`` and ``model.py``. The file ``data_loader.py`` contains the function that loads and preprocesses the wines dataset. The file ``model.py`` contains the PyTorch model.
40 | * **main.py**: This file triggers the different cases explained in the Medium article.
41 |
42 |
43 |
44 | ## 3. How to use
45 | You just need to type
46 |
47 | ```SH
48 | python main.py
49 | ```
50 | However, I recommend working inside a virtual environment; in this case I am using pipenv. To install the dependencies listed in the ``Pipfile``, you just need to type:
51 |
52 | ```SH
53 | pipenv install
54 | ```
55 | and then
56 |
57 | ```SH
58 | pipenv shell
59 | ```
60 |
61 |
62 | ## 4. Contributing
63 | Feel free to fork the model and add your own suggestions.
64 |
65 | 1. Fork the Project
66 | 2. Create your Feature Branch (`git checkout -b feature/YourGreatFeature`)
67 | 3. Commit your Changes (`git commit -m 'Add some YourGreatFeature'`)
68 | 4. Push to the Branch (`git push origin feature/YourGreatFeature`)
69 | 5. Open a Pull Request
70 |
71 |
72 | ## 5. Contact
73 | If you have any questions, feel free to reach out to me at:
74 | * Twitter
75 | * Medium
76 | * Linkedin
77 | * Email: fer.neutron@gmail.com
78 |
79 |
80 | ## 6. License
81 | Distributed under the MIT License. See ``LICENSE.md`` for more information.
82 |
83 |
84 |
85 |
86 | [medium-shield]: https://img.shields.io/badge/medium-%2312100E.svg?&style=for-the-badge&logo=medium&logoColor=white
87 | [medium-url]: https://medium.com/@fer.neutron
88 | [twitter-shield]: https://img.shields.io/badge/twitter-%231DA1F2.svg?&style=for-the-badge&logo=twitter&logoColor=white
89 | [twitter-url]: https://twitter.com/Fernando_LpzV
90 | [linkedin-shield]: https://img.shields.io/badge/linkedin-%230077B5.svg?&style=for-the-badge&logo=linkedin&logoColor=white
91 | [linkedin-url]: https://www.linkedin.com/in/fernando-lopezvelasco/
--------------------------------------------------------------------------------
/data/wines.csv:
--------------------------------------------------------------------------------
1 | class,feat1,feat2,feat3,feat4,feat5,feat6,feat7,feat8,feat9,feat10,feat11,feat12,feat13
2 | 1,14.23,1.71,2.43,15.6,127,2.8,3.06,.28,2.29,5.64,1.04,3.92,1065
3 | 1,13.2,1.78,2.14,11.2,100,2.65,2.76,.26,1.28,4.38,1.05,3.4,1050
4 | 1,13.16,2.36,2.67,18.6,101,2.8,3.24,.3,2.81,5.68,1.03,3.17,1185
5 | 1,14.37,1.95,2.5,16.8,113,3.85,3.49,.24,2.18,7.8,.86,3.45,1480
6 | 1,13.24,2.59,2.87,21,118,2.8,2.69,.39,1.82,4.32,1.04,2.93,735
7 | 1,14.2,1.76,2.45,15.2,112,3.27,3.39,.34,1.97,6.75,1.05,2.85,1450
8 | 1,14.39,1.87,2.45,14.6,96,2.5,2.52,.3,1.98,5.25,1.02,3.58,1290
9 | 1,14.06,2.15,2.61,17.6,121,2.6,2.51,.31,1.25,5.05,1.06,3.58,1295
10 | 1,14.83,1.64,2.17,14,97,2.8,2.98,.29,1.98,5.2,1.08,2.85,1045
11 | 1,13.86,1.35,2.27,16,98,2.98,3.15,.22,1.85,7.22,1.01,3.55,1045
12 | 1,14.1,2.16,2.3,18,105,2.95,3.32,.22,2.38,5.75,1.25,3.17,1510
13 | 1,14.12,1.48,2.32,16.8,95,2.2,2.43,.26,1.57,5,1.17,2.82,1280
14 | 1,13.75,1.73,2.41,16,89,2.6,2.76,.29,1.81,5.6,1.15,2.9,1320
15 | 1,14.75,1.73,2.39,11.4,91,3.1,3.69,.43,2.81,5.4,1.25,2.73,1150
16 | 1,14.38,1.87,2.38,12,102,3.3,3.64,.29,2.96,7.5,1.2,3,1547
17 | 1,13.63,1.81,2.7,17.2,112,2.85,2.91,.3,1.46,7.3,1.28,2.88,1310
18 | 1,14.3,1.92,2.72,20,120,2.8,3.14,.33,1.97,6.2,1.07,2.65,1280
19 | 1,13.83,1.57,2.62,20,115,2.95,3.4,.4,1.72,6.6,1.13,2.57,1130
20 | 1,14.19,1.59,2.48,16.5,108,3.3,3.93,.32,1.86,8.7,1.23,2.82,1680
21 | 1,13.64,3.1,2.56,15.2,116,2.7,3.03,.17,1.66,5.1,.96,3.36,845
22 | 1,14.06,1.63,2.28,16,126,3,3.17,.24,2.1,5.65,1.09,3.71,780
23 | 1,12.93,3.8,2.65,18.6,102,2.41,2.41,.25,1.98,4.5,1.03,3.52,770
24 | 1,13.71,1.86,2.36,16.6,101,2.61,2.88,.27,1.69,3.8,1.11,4,1035
25 | 1,12.85,1.6,2.52,17.8,95,2.48,2.37,.26,1.46,3.93,1.09,3.63,1015
26 | 1,13.5,1.81,2.61,20,96,2.53,2.61,.28,1.66,3.52,1.12,3.82,845
27 | 1,13.05,2.05,3.22,25,124,2.63,2.68,.47,1.92,3.58,1.13,3.2,830
28 | 1,13.39,1.77,2.62,16.1,93,2.85,2.94,.34,1.45,4.8,.92,3.22,1195
29 | 1,13.3,1.72,2.14,17,94,2.4,2.19,.27,1.35,3.95,1.02,2.77,1285
30 | 1,13.87,1.9,2.8,19.4,107,2.95,2.97,.37,1.76,4.5,1.25,3.4,915
31 | 1,14.02,1.68,2.21,16,96,2.65,2.33,.26,1.98,4.7,1.04,3.59,1035
32 | 1,13.73,1.5,2.7,22.5,101,3,3.25,.29,2.38,5.7,1.19,2.71,1285
33 | 1,13.58,1.66,2.36,19.1,106,2.86,3.19,.22,1.95,6.9,1.09,2.88,1515
34 | 1,13.68,1.83,2.36,17.2,104,2.42,2.69,.42,1.97,3.84,1.23,2.87,990
35 | 1,13.76,1.53,2.7,19.5,132,2.95,2.74,.5,1.35,5.4,1.25,3,1235
36 | 1,13.51,1.8,2.65,19,110,2.35,2.53,.29,1.54,4.2,1.1,2.87,1095
37 | 1,13.48,1.81,2.41,20.5,100,2.7,2.98,.26,1.86,5.1,1.04,3.47,920
38 | 1,13.28,1.64,2.84,15.5,110,2.6,2.68,.34,1.36,4.6,1.09,2.78,880
39 | 1,13.05,1.65,2.55,18,98,2.45,2.43,.29,1.44,4.25,1.12,2.51,1105
40 | 1,13.07,1.5,2.1,15.5,98,2.4,2.64,.28,1.37,3.7,1.18,2.69,1020
41 | 1,14.22,3.99,2.51,13.2,128,3,3.04,.2,2.08,5.1,.89,3.53,760
42 | 1,13.56,1.71,2.31,16.2,117,3.15,3.29,.34,2.34,6.13,.95,3.38,795
43 | 1,13.41,3.84,2.12,18.8,90,2.45,2.68,.27,1.48,4.28,.91,3,1035
44 | 1,13.88,1.89,2.59,15,101,3.25,3.56,.17,1.7,5.43,.88,3.56,1095
45 | 1,13.24,3.98,2.29,17.5,103,2.64,2.63,.32,1.66,4.36,.82,3,680
46 | 1,13.05,1.77,2.1,17,107,3,3,.28,2.03,5.04,.88,3.35,885
47 | 1,14.21,4.04,2.44,18.9,111,2.85,2.65,.3,1.25,5.24,.87,3.33,1080
48 | 1,14.38,3.59,2.28,16,102,3.25,3.17,.27,2.19,4.9,1.04,3.44,1065
49 | 1,13.9,1.68,2.12,16,101,3.1,3.39,.21,2.14,6.1,.91,3.33,985
50 | 1,14.1,2.02,2.4,18.8,103,2.75,2.92,.32,2.38,6.2,1.07,2.75,1060
51 | 1,13.94,1.73,2.27,17.4,108,2.88,3.54,.32,2.08,8.90,1.12,3.1,1260
52 | 1,13.05,1.73,2.04,12.4,92,2.72,3.27,.17,2.91,7.2,1.12,2.91,1150
53 | 1,13.83,1.65,2.6,17.2,94,2.45,2.99,.22,2.29,5.6,1.24,3.37,1265
54 | 1,13.82,1.75,2.42,14,111,3.88,3.74,.32,1.87,7.05,1.01,3.26,1190
55 | 1,13.77,1.9,2.68,17.1,115,3,2.79,.39,1.68,6.3,1.13,2.93,1375
56 | 1,13.74,1.67,2.25,16.4,118,2.6,2.9,.21,1.62,5.85,.92,3.2,1060
57 | 1,13.56,1.73,2.46,20.5,116,2.96,2.78,.2,2.45,6.25,.98,3.03,1120
58 | 1,14.22,1.7,2.3,16.3,118,3.2,3,.26,2.03,6.38,.94,3.31,970
59 | 1,13.29,1.97,2.68,16.8,102,3,3.23,.31,1.66,6,1.07,2.84,1270
60 | 1,13.72,1.43,2.5,16.7,108,3.4,3.67,.19,2.04,6.8,.89,2.87,1285
61 | 2,12.37,.94,1.36,10.6,88,1.98,.57,.28,.42,1.95,1.05,1.82,520
62 | 2,12.33,1.1,2.28,16,101,2.05,1.09,.63,.41,3.27,1.25,1.67,680
63 | 2,12.64,1.36,2.02,16.8,100,2.02,1.41,.53,.62,5.75,.98,1.59,450
64 | 2,13.67,1.25,1.92,18,94,2.1,1.79,.32,.73,3.8,1.23,2.46,630
65 | 2,12.37,1.13,2.16,19,87,3.5,3.1,.19,1.87,4.45,1.22,2.87,420
66 | 2,12.17,1.45,2.53,19,104,1.89,1.75,.45,1.03,2.95,1.45,2.23,355
67 | 2,12.37,1.21,2.56,18.1,98,2.42,2.65,.37,2.08,4.6,1.19,2.3,678
68 | 2,13.11,1.01,1.7,15,78,2.98,3.18,.26,2.28,5.3,1.12,3.18,502
69 | 2,12.37,1.17,1.92,19.6,78,2.11,2,.27,1.04,4.68,1.12,3.48,510
70 | 2,13.34,.94,2.36,17,110,2.53,1.3,.55,.42,3.17,1.02,1.93,750
71 | 2,12.21,1.19,1.75,16.8,151,1.85,1.28,.14,2.5,2.85,1.28,3.07,718
72 | 2,12.29,1.61,2.21,20.4,103,1.1,1.02,.37,1.46,3.05,.906,1.82,870
73 | 2,13.86,1.51,2.67,25,86,2.95,2.86,.21,1.87,3.38,1.36,3.16,410
74 | 2,13.49,1.66,2.24,24,87,1.88,1.84,.27,1.03,3.74,.98,2.78,472
75 | 2,12.99,1.67,2.6,30,139,3.3,2.89,.21,1.96,3.35,1.31,3.5,985
76 | 2,11.96,1.09,2.3,21,101,3.38,2.14,.13,1.65,3.21,.99,3.13,886
77 | 2,11.66,1.88,1.92,16,97,1.61,1.57,.34,1.15,3.8,1.23,2.14,428
78 | 2,13.03,.9,1.71,16,86,1.95,2.03,.24,1.46,4.6,1.19,2.48,392
79 | 2,11.84,2.89,2.23,18,112,1.72,1.32,.43,.95,2.65,.96,2.52,500
80 | 2,12.33,.99,1.95,14.8,136,1.9,1.85,.35,2.76,3.4,1.06,2.31,750
81 | 2,12.7,3.87,2.4,23,101,2.83,2.55,.43,1.95,2.57,1.19,3.13,463
82 | 2,12,.92,2,19,86,2.42,2.26,.3,1.43,2.5,1.38,3.12,278
83 | 2,12.72,1.81,2.2,18.8,86,2.2,2.53,.26,1.77,3.9,1.16,3.14,714
84 | 2,12.08,1.13,2.51,24,78,2,1.58,.4,1.4,2.2,1.31,2.72,630
85 | 2,13.05,3.86,2.32,22.5,85,1.65,1.59,.61,1.62,4.8,.84,2.01,515
86 | 2,11.84,.89,2.58,18,94,2.2,2.21,.22,2.35,3.05,.79,3.08,520
87 | 2,12.67,.98,2.24,18,99,2.2,1.94,.3,1.46,2.62,1.23,3.16,450
88 | 2,12.16,1.61,2.31,22.8,90,1.78,1.69,.43,1.56,2.45,1.33,2.26,495
89 | 2,11.65,1.67,2.62,26,88,1.92,1.61,.4,1.34,2.6,1.36,3.21,562
90 | 2,11.64,2.06,2.46,21.6,84,1.95,1.69,.48,1.35,2.8,1,2.75,680
91 | 2,12.08,1.33,2.3,23.6,70,2.2,1.59,.42,1.38,1.74,1.07,3.21,625
92 | 2,12.08,1.83,2.32,18.5,81,1.6,1.5,.52,1.64,2.4,1.08,2.27,480
93 | 2,12,1.51,2.42,22,86,1.45,1.25,.5,1.63,3.6,1.05,2.65,450
94 | 2,12.69,1.53,2.26,20.7,80,1.38,1.46,.58,1.62,3.05,.96,2.06,495
95 | 2,12.29,2.83,2.22,18,88,2.45,2.25,.25,1.99,2.15,1.15,3.3,290
96 | 2,11.62,1.99,2.28,18,98,3.02,2.26,.17,1.35,3.25,1.16,2.96,345
97 | 2,12.47,1.52,2.2,19,162,2.5,2.27,.32,3.28,2.6,1.16,2.63,937
98 | 2,11.81,2.12,2.74,21.5,134,1.6,.99,.14,1.56,2.5,.95,2.26,625
99 | 2,12.29,1.41,1.98,16,85,2.55,2.5,.29,1.77,2.9,1.23,2.74,428
100 | 2,12.37,1.07,2.1,18.5,88,3.52,3.75,.24,1.95,4.5,1.04,2.77,660
101 | 2,12.29,3.17,2.21,18,88,2.85,2.99,.45,2.81,2.3,1.42,2.83,406
102 | 2,12.08,2.08,1.7,17.5,97,2.23,2.17,.26,1.4,3.3,1.27,2.96,710
103 | 2,12.6,1.34,1.9,18.5,88,1.45,1.36,.29,1.35,2.45,1.04,2.77,562
104 | 2,12.34,2.45,2.46,21,98,2.56,2.11,.34,1.31,2.8,.8,3.38,438
105 | 2,11.82,1.72,1.88,19.5,86,2.5,1.64,.37,1.42,2.06,.94,2.44,415
106 | 2,12.51,1.73,1.98,20.5,85,2.2,1.92,.32,1.48,2.94,1.04,3.57,672
107 | 2,12.42,2.55,2.27,22,90,1.68,1.84,.66,1.42,2.7,.86,3.3,315
108 | 2,12.25,1.73,2.12,19,80,1.65,2.03,.37,1.63,3.4,1,3.17,510
109 | 2,12.72,1.75,2.28,22.5,84,1.38,1.76,.48,1.63,3.3,.88,2.42,488
110 | 2,12.22,1.29,1.94,19,92,2.36,2.04,.39,2.08,2.7,.86,3.02,312
111 | 2,11.61,1.35,2.7,20,94,2.74,2.92,.29,2.49,2.65,.96,3.26,680
112 | 2,11.46,3.74,1.82,19.5,107,3.18,2.58,.24,3.58,2.9,.75,2.81,562
113 | 2,12.52,2.43,2.17,21,88,2.55,2.27,.26,1.22,2,.9,2.78,325
114 | 2,11.76,2.68,2.92,20,103,1.75,2.03,.6,1.05,3.8,1.23,2.5,607
115 | 2,11.41,.74,2.5,21,88,2.48,2.01,.42,1.44,3.08,1.1,2.31,434
116 | 2,12.08,1.39,2.5,22.5,84,2.56,2.29,.43,1.04,2.9,.93,3.19,385
117 | 2,11.03,1.51,2.2,21.5,85,2.46,2.17,.52,2.01,1.9,1.71,2.87,407
118 | 2,11.82,1.47,1.99,20.8,86,1.98,1.6,.3,1.53,1.95,.95,3.33,495
119 | 2,12.42,1.61,2.19,22.5,108,2,2.09,.34,1.61,2.06,1.06,2.96,345
120 | 2,12.77,3.43,1.98,16,80,1.63,1.25,.43,.83,3.4,.7,2.12,372
121 | 2,12,3.43,2,19,87,2,1.64,.37,1.87,1.28,.93,3.05,564
122 | 2,11.45,2.4,2.42,20,96,2.9,2.79,.32,1.83,3.25,.8,3.39,625
123 | 2,11.56,2.05,3.23,28.5,119,3.18,5.08,.47,1.87,6,.93,3.69,465
124 | 2,12.42,4.43,2.73,26.5,102,2.2,2.13,.43,1.71,2.08,.92,3.12,365
125 | 2,13.05,5.8,2.13,21.5,86,2.62,2.65,.3,2.01,2.6,.73,3.1,380
126 | 2,11.87,4.31,2.39,21,82,2.86,3.03,.21,2.91,2.8,.75,3.64,380
127 | 2,12.07,2.16,2.17,21,85,2.6,2.65,.37,1.35,2.76,.86,3.28,378
128 | 2,12.43,1.53,2.29,21.5,86,2.74,3.15,.39,1.77,3.94,.69,2.84,352
129 | 2,11.79,2.13,2.78,28.5,92,2.13,2.24,.58,1.76,3,.97,2.44,466
130 | 2,12.37,1.63,2.3,24.5,88,2.22,2.45,.4,1.9,2.12,.89,2.78,342
131 | 2,12.04,4.3,2.38,22,80,2.1,1.75,.42,1.35,2.6,.79,2.57,580
132 | 3,12.86,1.35,2.32,18,122,1.51,1.25,.21,.94,4.1,.76,1.29,630
133 | 3,12.88,2.99,2.4,20,104,1.3,1.22,.24,.83,5.4,.74,1.42,530
134 | 3,12.81,2.31,2.4,24,98,1.15,1.09,.27,.83,5.7,.66,1.36,560
135 | 3,12.7,3.55,2.36,21.5,106,1.7,1.2,.17,.84,5,.78,1.29,600
136 | 3,12.51,1.24,2.25,17.5,85,2,.58,.6,1.25,5.45,.75,1.51,650
137 | 3,12.6,2.46,2.2,18.5,94,1.62,.66,.63,.94,7.1,.73,1.58,695
138 | 3,12.25,4.72,2.54,21,89,1.38,.47,.53,.8,3.85,.75,1.27,720
139 | 3,12.53,5.51,2.64,25,96,1.79,.6,.63,1.1,5,.82,1.69,515
140 | 3,13.49,3.59,2.19,19.5,88,1.62,.48,.58,.88,5.7,.81,1.82,580
141 | 3,12.84,2.96,2.61,24,101,2.32,.6,.53,.81,4.92,.89,2.15,590
142 | 3,12.93,2.81,2.7,21,96,1.54,.5,.53,.75,4.6,.77,2.31,600
143 | 3,13.36,2.56,2.35,20,89,1.4,.5,.37,.64,5.6,.7,2.47,780
144 | 3,13.52,3.17,2.72,23.5,97,1.55,.52,.5,.55,4.35,.89,2.06,520
145 | 3,13.62,4.95,2.35,20,92,2,.8,.47,1.02,4.4,.91,2.05,550
146 | 3,12.25,3.88,2.2,18.5,112,1.38,.78,.29,1.14,8.21,.65,2,855
147 | 3,13.16,3.57,2.15,21,102,1.5,.55,.43,1.3,4,.6,1.68,830
148 | 3,13.88,5.04,2.23,20,80,.98,.34,.4,.68,4.9,.58,1.33,415
149 | 3,12.87,4.61,2.48,21.5,86,1.7,.65,.47,.86,7.65,.54,1.86,625
150 | 3,13.32,3.24,2.38,21.5,92,1.93,.76,.45,1.25,8.42,.55,1.62,650
151 | 3,13.08,3.9,2.36,21.5,113,1.41,1.39,.34,1.14,9.40,.57,1.33,550
152 | 3,13.5,3.12,2.62,24,123,1.4,1.57,.22,1.25,8.60,.59,1.3,500
153 | 3,12.79,2.67,2.48,22,112,1.48,1.36,.24,1.26,10.8,.48,1.47,480
154 | 3,13.11,1.9,2.75,25.5,116,2.2,1.28,.26,1.56,7.1,.61,1.33,425
155 | 3,13.23,3.3,2.28,18.5,98,1.8,.83,.61,1.87,10.52,.56,1.51,675
156 | 3,12.58,1.29,2.1,20,103,1.48,.58,.53,1.4,7.6,.58,1.55,640
157 | 3,13.17,5.19,2.32,22,93,1.74,.63,.61,1.55,7.9,.6,1.48,725
158 | 3,13.84,4.12,2.38,19.5,89,1.8,.83,.48,1.56,9.01,.57,1.64,480
159 | 3,12.45,3.03,2.64,27,97,1.9,.58,.63,1.14,7.5,.67,1.73,880
160 | 3,14.34,1.68,2.7,25,98,2.8,1.31,.53,2.7,13,.57,1.96,660
161 | 3,13.48,1.67,2.64,22.5,89,2.6,1.1,.52,2.29,11.75,.57,1.78,620
162 | 3,12.36,3.83,2.38,21,88,2.3,.92,.5,1.04,7.65,.56,1.58,520
163 | 3,13.69,3.26,2.54,20,107,1.83,.56,.5,.8,5.88,.96,1.82,680
164 | 3,12.85,3.27,2.58,22,106,1.65,.6,.6,.96,5.58,.87,2.11,570
165 | 3,12.96,3.45,2.35,18.5,106,1.39,.7,.4,.94,5.28,.68,1.75,675
166 | 3,13.78,2.76,2.3,22,90,1.35,.68,.41,1.03,9.58,.7,1.68,615
167 | 3,13.73,4.36,2.26,22.5,88,1.28,.47,.52,1.15,6.62,.78,1.75,520
168 | 3,13.45,3.7,2.6,23,111,1.7,.92,.43,1.46,10.68,.85,1.56,695
169 | 3,12.82,3.37,2.3,19.5,88,1.48,.66,.4,.97,10.26,.72,1.75,685
170 | 3,13.58,2.58,2.69,24.5,105,1.55,.84,.39,1.54,8.66,.74,1.8,750
171 | 3,13.4,4.6,2.86,25,112,1.98,.96,.27,1.11,8.5,.67,1.92,630
172 | 3,12.2,3.03,2.32,19,96,1.25,.49,.4,.73,5.5,.66,1.83,510
173 | 3,12.77,2.39,2.28,19.5,86,1.39,.51,.48,.64,9.899999,.57,1.63,470
174 | 3,14.16,2.51,2.48,20,91,1.68,.7,.44,1.24,9.7,.62,1.71,660
175 | 3,13.71,5.65,2.45,20.5,95,1.68,.61,.52,1.06,7.7,.64,1.74,740
176 | 3,13.4,3.91,2.48,23,102,1.8,.75,.43,1.41,7.3,.7,1.56,750
177 | 3,13.27,4.28,2.26,20,120,1.59,.69,.43,1.35,10.2,.59,1.56,835
178 | 3,13.17,2.59,2.37,20,120,1.65,.68,.53,1.46,9.3,.6,1.62,840
179 | 3,14.13,4.1,2.74,24.5,96,2.05,.76,.56,1.35,9.2,.61,1.6,560
180 |
--------------------------------------------------------------------------------
/img/skorch2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FernandoLpz/SKORCH-PyTorch-Wrapper/8e0aac8d347e567347eb94104b0ced5abd40a9f7/img/skorch2.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 |
3 | from sklearn.pipeline import Pipeline
4 | from sklearn.preprocessing import StandardScaler
5 | from sklearn.model_selection import GridSearchCV
6 |
7 | from skorch.callbacks import EpochScoring
8 | from skorch import NeuralNetClassifier
9 |
10 | from src import load_data
11 | from src import NeuralNet
12 |
class Run:
    """Collection of SKORCH training examples run on one dataset.

    Each public method builds a skorch ``NeuralNetClassifier`` around
    ``NeuralNet``, fits it on the data given at construction time and
    returns the fitted estimator so callers can inspect or reuse it.

    Args:
        x: Feature matrix passed unchanged to every ``fit`` call.
        y: Target labels passed unchanged to every ``fit`` call.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def _fixed_net(self, callbacks=None):
        """Build the classifier with the fixed hyperparameters shared by the simple examples."""
        return NeuralNetClassifier(
            NeuralNet,
            max_epochs=10,
            lr=0.01,
            batch_size=12,
            optimizer=optim.RMSprop,
            callbacks=callbacks,
        )

    @staticmethod
    def _scaled_pipeline(net):
        """Wrap ``net`` in a pipeline that standardizes the features first."""
        return Pipeline([('scale', StandardScaler()), ('nn', net)])

    def simple_training(self):
        """Train the neural network directly with fixed hyperparameters.

        Returns:
            The fitted ``NeuralNetClassifier``.
        """
        net = self._fixed_net()
        net.fit(self.x, self.y)
        return net

    def simple_pipeline_training(self):
        """Train the neural network inside a scikit-learn pipeline.

        The pipeline scales the features and then trains the network with
        fixed hyperparameters.

        Returns:
            The fitted ``Pipeline``.
        """
        pipeline = self._scaled_pipeline(self._fixed_net())
        pipeline.fit(self.x, self.y)
        return pipeline

    def simple_pipeline_training_with_callbacks(self):
        """Train inside a pipeline while scoring every epoch.

        Two ``EpochScoring`` callbacks are attached to the network so that
        "balanced accuracy" and "accuracy" are calculated during training.

        Returns:
            The fitted ``Pipeline``.
        """
        # Higher is better for both scores, hence lower_is_better=False
        balanced_accuracy = EpochScoring(scoring='balanced_accuracy', lower_is_better=False)
        accuracy = EpochScoring(scoring='accuracy', lower_is_better=False)

        pipeline = self._scaled_pipeline(
            self._fixed_net(callbacks=[balanced_accuracy, accuracy])
        )
        pipeline.fit(self.x, self.y)
        return pipeline

    def grid_search_pipeline_training(self):
        """Search for the best hyperparameters with a grid search.

        The scaling + training pipeline is wrapped by scikit-learn's
        ``GridSearchCV`` (3-fold, scored with balanced accuracy). The best
        score and parameter combination are printed once the search ends.

        Returns:
            The fitted ``GridSearchCV``. It is created with ``refit=False``,
            so it exposes the search results but no refitted estimator.
        """
        # The Neural Net is instantiated without hyperparameters; the grid provides them
        net = NeuralNetClassifier(NeuralNet, verbose=0, train_split=False)
        pipeline = self._scaled_pipeline(net)

        # The prefix "nn__" must be used when setting hyperparameters of the training phase
        # The prefix "nn__module__" must be used when setting hyperparameters of the Neural Net
        params = {
            'nn__max_epochs': [10, 20],
            'nn__lr': [0.1, 0.01],
            'nn__module__num_units': [5, 10],
            'nn__module__dropout': [0.1, 0.5],
            'nn__optimizer': [optim.Adam, optim.SGD, optim.RMSprop],
        }

        gs = GridSearchCV(pipeline, params, refit=False, cv=3, scoring='balanced_accuracy', verbose=1)
        gs.fit(self.x, self.y)

        # Surface the outcome instead of silently discarding the search results
        print(f"Best balanced accuracy: {gs.best_score_:.4f}")
        print(f"Best hyperparameters: {gs.best_params_}")
        return gs
84 |
if __name__ == "__main__":
    # Load the wine features and their class labels
    features, labels = load_data()

    runner = Run(features, labels)

    # Alternative examples; uncomment one to run it as well
    # runner.simple_training()
    # runner.simple_pipeline_training()
    # runner.simple_pipeline_training_with_callbacks()
    runner.grid_search_pipeline_training()
94 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_loader import load_data
2 | from .model import NeuralNet
--------------------------------------------------------------------------------
/src/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from sklearn.utils import shuffle
4 |
def load_data(path='data/wines.csv', random_state=None):
    """Load, shuffle and preprocess the wines dataset.

    Args:
        path: Location of the CSV file (anything ``pandas.read_csv``
            accepts). The file must have a ``class`` column; every other
            column is treated as a feature. The default is resolved
            relative to the current working directory.
        random_state: Seed forwarded to the shuffle. ``None`` (the
            default) gives a different row order on every call; pass an
            int for a reproducible order.

    Returns:
        A tuple ``(x, y)`` where ``x`` is a ``float32`` array holding one
        row of features per sample and ``y`` is a 1-D ``int64`` array of
        class labels, with the original labels 1, 2, 3 mapped to 0, 1, 2.
    """
    # Load csv dataset
    data = pd.read_csv(path)

    # Shuffle the rows; features and labels stay aligned because they are
    # split only after the shuffle
    data = data.sample(frac=1, random_state=random_state)

    # Original class labels are [1, 2, 3]; they are changed to [0, 1, 2]
    labels = data['class'].replace([1, 2, 3], [0, 1, 2])

    # Split x and y and fix datatypes. Taking the Series directly keeps y
    # one-dimensional even when the dataset holds a single row.
    x = data.drop(columns='class').to_numpy().astype(np.float32)
    y = labels.to_numpy().astype(np.int64)

    return x, y
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class NeuralNet(nn.Module):
    """Small feed-forward classifier for the wines dataset.

    The network maps 13 input features to probabilities over 3 classes
    through two hidden layers (``num_units`` and 10 units) with ReLU
    activations. Dropout is applied after the first hidden activation.

    Args:
        num_units: Number of units in the first hidden layer.
        dropout: Dropout probability applied after the first hidden layer.
    """

    def __init__(self, num_units=10, dropout=0.1):
        super().__init__()
        self.num_units = num_units
        self.linear_1 = nn.Linear(13, num_units)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(num_units, 10)
        self.linear_3 = nn.Linear(10, 3)

    def forward(self, x):
        """Return class probabilities for a batch of feature vectors.

        Args:
            x: Float tensor whose last dimension has size 13.

        Returns:
            Tensor with a last dimension of size 3 that sums to 1.
        """
        x = F.relu(self.linear_1(x))
        # Apply the configured dropout; without this call the ``dropout``
        # hyperparameter has no effect. It is the identity in eval mode.
        x = self.dropout(x)
        x = F.relu(self.linear_2(x))
        x = self.linear_3(x)
        # Softmax over the class dimension turns the scores into probabilities
        return F.softmax(x, dim=-1)
--------------------------------------------------------------------------------