├── .vscode └── settings.json ├── Pipfile ├── Pipfile.lock ├── README.md ├── data └── wines.csv ├── img └── skorch2.jpg ├── main.py └── src ├── __init__.py ├── data_loader.py └── model.py /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/Users/fernandolopez/.local/share/virtualenvs/skorch-zxWYGwU5/bin/python" 3 | } -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | scikit-learn = "*" 10 | skorch = "*" 11 | torch = "*" 12 | pandas = "*" 13 | 14 | [requires] 15 | python_version = "3.8" 16 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "caba5c50b162d0d09265df9f323da40e4399d222f88c1294f47b43f97165d8f6" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.8" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "dataclasses": { 20 | "hashes": [ 21 | "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f", 22 | "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84" 23 | ], 24 | "version": "==0.6" 25 | }, 26 | "future": { 27 | "hashes": [ 28 | "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d" 29 | ], 30 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", 31 | "version": "==0.18.2" 32 | }, 33 | "joblib": { 34 | "hashes": [ 35 | "sha256:698c311779f347cf6b7e6b8a39bb682277b8ee4aba8cf9507bc0cf4cd4737b72", 36 | "sha256:9e284edd6be6b71883a63c9b7f124738a3c16195513ad940eae7e3438de885d5" 37 | ], 38 | "markers": "python_version >= '3.6'", 39 | "version": "==0.17.0" 40 | }, 41 | "numpy": { 42 | "hashes": [ 43 | "sha256:08308c38e44cc926bdfce99498b21eec1f848d24c302519e64203a8da99a97db", 44 | "sha256:09c12096d843b90eafd01ea1b3307e78ddd47a55855ad402b157b6c4862197ce", 45 | "sha256:13d166f77d6dc02c0a73c1101dd87fdf01339febec1030bd810dcd53fff3b0f1", 46 | "sha256:141ec3a3300ab89c7f2b0775289954d193cc8edb621ea05f99db9cb181530512", 47 | "sha256:16c1b388cc31a9baa06d91a19366fb99ddbe1c7b205293ed072211ee5bac1ed2", 48 | "sha256:18bed2bcb39e3f758296584337966e68d2d5ba6aab7e038688ad53c8f889f757", 49 | "sha256:1aeef46a13e51931c0b1cf8ae1168b4a55ecd282e6688fdb0a948cc5a1d5afb9", 50 | "sha256:27d3f3b9e3406579a8af3a9f262f5339005dd25e0ecf3cf1559ff8a49ed5cbf2", 51 | "sha256:2a2740aa9733d2e5b2dfb33639d98a64c3b0f24765fed86b0fd2aec07f6a0a08", 52 | "sha256:4377e10b874e653fe96985c05feed2225c912e328c8a26541f7fc600fb9c637b", 53 | "sha256:448ebb1b3bf64c0267d6b09a7cba26b5ae61b6d2dbabff7c91b660c7eccf2bdb", 54 | "sha256:50e86c076611212ca62e5a59f518edafe0c0730f7d9195fec718da1a5c2bb1fc", 55 | "sha256:5734bdc0342aba9dfc6f04920988140fb41234db42381cf7ccba64169f9fe7ac", 56 | "sha256:64324f64f90a9e4ef732be0928be853eee378fd6a01be21a0a8469c4f2682c83", 57 | "sha256:6ae6c680f3ebf1cf7ad1d7748868b39d9f900836df774c453c11c5440bc15b36", 58 | "sha256:6d7593a705d662be5bfe24111af14763016765f43cb6923ed86223f965f52387", 59 | "sha256:8cac8790a6b1ddf88640a9267ee67b1aee7a57dfa2d2dd33999d080bc8ee3a0f", 60 | 
"sha256:8ece138c3a16db8c1ad38f52eb32be6086cc72f403150a79336eb2045723a1ad", 61 | "sha256:9eeb7d1d04b117ac0d38719915ae169aa6b61fca227b0b7d198d43728f0c879c", 62 | "sha256:a09f98011236a419ee3f49cedc9ef27d7a1651df07810ae430a6b06576e0b414", 63 | "sha256:a5d897c14513590a85774180be713f692df6fa8ecf6483e561a6d47309566f37", 64 | "sha256:ad6f2ff5b1989a4899bf89800a671d71b1612e5ff40866d1f4d8bcf48d4e5764", 65 | "sha256:c42c4b73121caf0ed6cd795512c9c09c52a7287b04d105d112068c1736d7c753", 66 | "sha256:cb1017eec5257e9ac6209ac172058c430e834d5d2bc21961dceeb79d111e5909", 67 | "sha256:d6c7bb82883680e168b55b49c70af29b84b84abb161cbac2800e8fcb6f2109b6", 68 | "sha256:e452dc66e08a4ce642a961f134814258a082832c78c90351b75c41ad16f79f63", 69 | "sha256:e5b6ed0f0b42317050c88022349d994fe72bfe35f5908617512cd8c8ef9da2a9", 70 | "sha256:e9b30d4bd69498fc0c3fe9db5f62fffbb06b8eb9321f92cc970f2969be5e3949", 71 | "sha256:ec149b90019852266fec2341ce1db513b843e496d5a8e8cdb5ced1923a92faab", 72 | "sha256:edb01671b3caae1ca00881686003d16c2209e07b7ef8b7639f1867852b948f7c", 73 | "sha256:f0d3929fe88ee1c155129ecd82f981b8856c5d97bcb0d5f23e9b4242e79d1de3", 74 | "sha256:f29454410db6ef8126c83bd3c968d143304633d45dc57b51252afbd79d700893", 75 | "sha256:fe45becb4c2f72a0907c1d0246ea6449fe7a9e2293bb0e11c4e9a32bb0930a15", 76 | "sha256:fedbd128668ead37f33917820b704784aff695e0019309ad446a6d0b065b57e4" 77 | ], 78 | "markers": "python_version >= '3.6'", 79 | "version": "==1.19.4" 80 | }, 81 | "pandas": { 82 | "hashes": [ 83 | "sha256:09e0503758ad61afe81c9069505f8cb8c1e36ea8cc1e6826a95823ef5b327daf", 84 | "sha256:0a11a6290ef3667575cbd4785a1b62d658c25a2fd70a5adedba32e156a8f1773", 85 | "sha256:0d9a38a59242a2f6298fff45d09768b78b6eb0c52af5919ea9e45965d7ba56d9", 86 | "sha256:112c5ba0f9ea0f60b2cc38c25f87ca1d5ca10f71efbee8e0f1bee9cf584ed5d5", 87 | "sha256:185cf8c8f38b169dbf7001e1a88c511f653fbb9dfa3e048f5e19c38049e991dc", 88 | "sha256:3aa8e10768c730cc1b610aca688f588831fa70b65a26cb549fbb9f35049a05e0", 89 | "sha256:41746d520f2b50409dffdba29a15c42caa7babae15616bcf80800d8cfcae3d3e", 90 | "sha256:43cea38cbcadb900829858884f49745eb1f42f92609d368cabcc674b03e90efc", 91 | "sha256:5378f58172bd63d8c16dd5d008d7dcdd55bf803fcdbe7da2dcb65dbbf322f05b", 92 | "sha256:54404abb1cd3f89d01f1fb5350607815326790efb4789be60508f458cdd5ccbf", 93 | "sha256:5dac3aeaac5feb1016e94bde851eb2012d1733a222b8afa788202b836c97dad5", 94 | "sha256:5fdb2a61e477ce58d3f1fdf2470ee142d9f0dde4969032edaf0b8f1a9dafeaa2", 95 | "sha256:6613c7815ee0b20222178ad32ec144061cb07e6a746970c9160af1ebe3ad43b4", 96 | "sha256:6d2b5b58e7df46b2c010ec78d7fb9ab20abf1d306d0614d3432e7478993fbdb0", 97 | "sha256:8a5d7e57b9df2c0a9a202840b2881bb1f7a648eba12dd2d919ac07a33a36a97f", 98 | "sha256:8b4c2055ebd6e497e5ecc06efa5b8aa76f59d15233356eb10dad22a03b757805", 99 | "sha256:a15653480e5b92ee376f8458197a58cca89a6e95d12cccb4c2d933df5cecc63f", 100 | "sha256:a7d2547b601ecc9a53fd41561de49a43d2231728ad65c7713d6b616cd02ddbed", 101 | "sha256:a979d0404b135c63954dea79e6246c45dd45371a88631cdbb4877d844e6de3b6", 102 | "sha256:b1f8111635700de7ac350b639e7e452b06fc541a328cf6193cf8fc638804bab8", 103 | "sha256:c5a3597880a7a29a31ebd39b73b2c824316ae63a05c3c8a5ce2aea3fc68afe35", 104 | "sha256:c681e8fcc47a767bf868341d8f0d76923733cbdcabd6ec3a3560695c69f14a1e", 105 | "sha256:cf135a08f306ebbcfea6da8bf775217613917be23e5074c69215b91e180caab4", 106 | "sha256:e2b8557fe6d0a18db4d61c028c6af61bfed44ef90e419ed6fadbdc079eba141e" 107 | ], 108 | "index": "pypi", 109 | "version": "==1.1.4" 110 | }, 111 | "python-dateutil": { 112 | "hashes": [ 113 | 
"sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", 114 | "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a" 115 | ], 116 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 117 | "version": "==2.8.1" 118 | }, 119 | "pytz": { 120 | "hashes": [ 121 | "sha256:3e6b7dd2d1e0a59084bcee14a17af60c5c562cdc16d828e8eba2e683d3a7e268", 122 | "sha256:5c55e189b682d420be27c6995ba6edce0c0a77dd67bfbe2ae6607134d5851ffd" 123 | ], 124 | "version": "==2020.4" 125 | }, 126 | "scikit-learn": { 127 | "hashes": [ 128 | "sha256:0a127cc70990d4c15b1019680bfedc7fec6c23d14d3719fdf9b64b22d37cdeca", 129 | "sha256:0d39748e7c9669ba648acf40fb3ce96b8a07b240db6888563a7cb76e05e0d9cc", 130 | "sha256:1b8a391de95f6285a2f9adffb7db0892718950954b7149a70c783dc848f104ea", 131 | "sha256:20766f515e6cd6f954554387dfae705d93c7b544ec0e6c6a5d8e006f6f7ef480", 132 | "sha256:2aa95c2f17d2f80534156215c87bee72b6aa314a7f8b8fe92a2d71f47280570d", 133 | "sha256:5ce7a8021c9defc2b75620571b350acc4a7d9763c25b7593621ef50f3bd019a2", 134 | "sha256:6c28a1d00aae7c3c9568f61aafeaad813f0f01c729bee4fd9479e2132b215c1d", 135 | "sha256:7671bbeddd7f4f9a6968f3b5442dac5f22bf1ba06709ef888cc9132ad354a9ab", 136 | "sha256:914ac2b45a058d3f1338d7736200f7f3b094857758895f8667be8a81ff443b5b", 137 | "sha256:98508723f44c61896a4e15894b2016762a55555fbf09365a0bb1870ecbd442de", 138 | "sha256:a64817b050efd50f9abcfd311870073e500ae11b299683a519fbb52d85e08d25", 139 | "sha256:cb3e76380312e1f86abd20340ab1d5b3cc46a26f6593d3c33c9ea3e4c7134028", 140 | "sha256:d0dcaa54263307075cb93d0bee3ceb02821093b1b3d25f66021987d305d01dce", 141 | "sha256:d9a1ce5f099f29c7c33181cc4386660e0ba891b21a60dc036bf369e3a3ee3aec", 142 | "sha256:da8e7c302003dd765d92a5616678e591f347460ac7b53e53d667be7dfe6d1b10", 143 | "sha256:daf276c465c38ef736a79bd79fc80a249f746bcbcae50c40945428f7ece074f8" 144 | ], 145 | "index": "pypi", 146 | "version": "==0.23.2" 147 | }, 148 | "scipy": { 149 | "hashes": [ 150 | "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce", 151 | "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42", 152 | "sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554", 153 | "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e", 154 | "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce", 155 | "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9", 156 | "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4", 157 | "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06", 158 | "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b", 159 | "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47", 160 | "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443", 161 | "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389", 162 | "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e", 163 | "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57", 164 | "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62", 165 | "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d", 166 | "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437", 167 | "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2", 168 | "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54", 169 | 
"sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474", 170 | "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938", 171 | "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340", 172 | "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660", 173 | "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012", 174 | "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c" 175 | ], 176 | "markers": "python_version >= '3.6'", 177 | "version": "==1.5.4" 178 | }, 179 | "six": { 180 | "hashes": [ 181 | "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259", 182 | "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced" 183 | ], 184 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 185 | "version": "==1.15.0" 186 | }, 187 | "skorch": { 188 | "hashes": [ 189 | "sha256:12bb80276719cdbd114bc5042f4d0b395ce1ebe5dbc29aba5d4ea2f1792f9705", 190 | "sha256:26317da14837f372fdeb8fb4eee9199c2cc0b0db1056fc4ab69696402e17e135", 191 | "sha256:bdce9370153fd80c5c4ec499a639f55eef0620e45d4b15fbf7d7ff2a225a3d40" 192 | ], 193 | "index": "pypi", 194 | "version": "==0.9.0" 195 | }, 196 | "tabulate": { 197 | "hashes": [ 198 | "sha256:ac64cb76d53b1231d364babcd72abbb16855adac7de6665122f97b593f1eb2ba", 199 | "sha256:db2723a20d04bcda8522165c73eea7c300eda74e0ce852d9022e0159d7895007" 200 | ], 201 | "version": "==0.8.7" 202 | }, 203 | "threadpoolctl": { 204 | "hashes": [ 205 | "sha256:38b74ca20ff3bb42caca8b00055111d74159ee95c4370882bbff2b93d24da725", 206 | "sha256:ddc57c96a38beb63db45d6c159b5ab07b6bced12c45a1f07b2b92f272aebfa6b" 207 | ], 208 | "markers": "python_version >= '3.5'", 209 | "version": "==2.1.0" 210 | }, 211 | "torch": { 212 | "hashes": [ 213 | "sha256:11054f26eee5c3114d217201dba5b3a35f1745d11133c123c077c5981bc95997", 214 | "sha256:1520c48430dea38e5845b7b3defc9054edad45f1f245808aa268ade840bb2c2a", 215 | "sha256:6b0c9b56cb56afe3ecbac79351d21c6f7172dffc7b7daa8c365f660541baf1a5", 216 | "sha256:89cb8774243750bd3fd2b3b3d09bab6e3be68b1785ad48b8411f1eb4fc7acdba", 217 | "sha256:b8000e39600e101b2f19dbbab75de663a3b78e3979c3e1720b7136aae1c35ce2", 218 | "sha256:e8cc3b2c3937b7ae036a3b447a189af049bfc006bca054fc1d8ae78766ca3105" 219 | ], 220 | "index": "pypi", 221 | "version": "==1.7.0" 222 | }, 223 | "tqdm": { 224 | "hashes": [ 225 | "sha256:9ad44aaf0fc3697c06f6e05c7cf025dd66bc7bcb7613c66d85f4464c47ac8fad", 226 | "sha256:ef54779f1c09f346b2b5a8e5c61f96fbcb639929e640e59f8cf810794f406432" 227 | ], 228 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", 229 | "version": "==4.51.0" 230 | }, 231 | "typing-extensions": { 232 | "hashes": [ 233 | "sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918", 234 | "sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c", 235 | "sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f" 236 | ], 237 | "version": "==3.7.4.3" 238 | } 239 | }, 240 | "develop": {} 241 | } 242 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 9 | [![Medium][medium-shield]][medium-url] 10 | [![Twitter][twitter-shield]][twitter-url] 11 | [![Linkedin][linkedin-shield]][linkedin-url] 12 | 13 | # SKORCH: PyTorch Models Trained with a Scikit-Learn Wrapper 14 | This repository shows an example of the usability of SKORCH to train a 
PyTorch model while making use of different capabilities of the scikit-learn framework. 15 | 16 | If you want to understand the details about how this model was created, take a look at this detailed explanation: SKORCH: PyTorch Models Trained with a Scikit-Learn Wrapper 17 | 18 | 19 | ## Table of Contents 20 | 21 | * [The model](#the-model) 22 | * [Files](#files) 23 | * [How to use](#how-to-use) 24 | * [Contributing](#contributing) 25 | * [Contact](#contact) 26 | * [License](#license) 27 | 28 | 29 | ## 1. The model 30 | The idea of this repository is to show how to use some of the SKORCH functionalities to train a PyTorch model. In this case, a neural network was created to classify the wines dataset. To better understand what SKORCH is, take a look at the following image: 31 | 32 |

![SKORCH](img/skorch2.jpg) 33 | 34 |

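In short, SKORCH wraps a PyTorch ``nn.Module`` so that it behaves like a scikit-learn estimator, exposing the familiar ``fit``/``predict`` API and interoperating with pipelines and grid search. As a minimal sketch, mirroring what ``main.py`` does with this repo's ``load_data`` and ``NeuralNet``:

```python
from skorch import NeuralNetClassifier

from src import load_data, NeuralNet

# Wines features (float32) and labels (int64) from src/data_loader.py
x, y = load_data()

# The wrapper turns the PyTorch module into a scikit-learn estimator
net = NeuralNetClassifier(NeuralNet, max_epochs=10, lr=0.01)
net.fit(x, y)
predictions = net.predict(x)
```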
35 | 36 | 37 | ## 2. Files 38 | * **data**: Here you will find the wines dataset 39 | * **src**: It contains two files, ``data_loader.py`` and ``model.py``. The file ``data_loader.py`` contains the functions to load and preprocess the wines dataset. The file ``model.py`` contains the PyTorch model. 40 | * **main.py**: This file triggers the different cases explained in the Medium article. 41 | 42 | 43 | 44 | ## 3. How to use 45 | You just need to type 46 | 47 | ```SH 48 | python main.py 49 | ``` 50 | However, I recommend working within a virtual environment; in this case I am using pipenv. To install the dependencies listed in the ``Pipfile``, just type: 51 | 52 | ```SH 53 | pipenv install 54 | ``` 55 | and then 56 | 57 | ```SH 58 | pipenv shell 59 | ``` 60 | 61 | 62 | ## 4. Contributing 63 | Feel free to fork the model and add your own suggestions. 64 | 65 | 1. Fork the Project 66 | 2. Create your Feature Branch (`git checkout -b feature/YourGreatFeature`) 67 | 3. Commit your Changes (`git commit -m 'Add some YourGreatFeature'`) 68 | 4. Push to the Branch (`git push origin feature/YourGreatFeature`) 69 | 5. Open a Pull Request 70 | 71 | 72 | ## 5. Contact 73 | If you have any questions, feel free to reach out: 74 | * Twitter 75 | * Medium 76 | * Linkedin 77 | * Email: fer.neutron@gmail.com 78 | 79 | 80 | ## 6. License 81 | Distributed under the MIT License. See ``LICENSE.md`` for more information. 82 | 83 | 84 | 85 | 86 | [medium-shield]: https://img.shields.io/badge/medium-%2312100E.svg?&style=for-the-badge&logo=medium&logoColor=white 87 | [medium-url]: https://medium.com/@fer.neutron 88 | [twitter-shield]: https://img.shields.io/badge/twitter-%231DA1F2.svg?&style=for-the-badge&logo=twitter&logoColor=white 89 | [twitter-url]: https://twitter.com/Fernando_LpzV 90 | [linkedin-shield]: https://img.shields.io/badge/linkedin-%230077B5.svg?&style=for-the-badge&logo=linkedin&logoColor=white 91 | [linkedin-url]: https://www.linkedin.com/in/fernando-lopezvelasco/ -------------------------------------------------------------------------------- /data/wines.csv: -------------------------------------------------------------------------------- 1 | class,feat1,feat2,feat3,feat4,feat5,feat6,feat7,feat8,feat9,feat10,feat11,feat12,feat13 2 | 1,14.23,1.71,2.43,15.6,127,2.8,3.06,.28,2.29,5.64,1.04,3.92,1065 3 | 1,13.2,1.78,2.14,11.2,100,2.65,2.76,.26,1.28,4.38,1.05,3.4,1050 4 | 1,13.16,2.36,2.67,18.6,101,2.8,3.24,.3,2.81,5.68,1.03,3.17,1185 5 | 1,14.37,1.95,2.5,16.8,113,3.85,3.49,.24,2.18,7.8,.86,3.45,1480 6 | 1,13.24,2.59,2.87,21,118,2.8,2.69,.39,1.82,4.32,1.04,2.93,735 7 | 1,14.2,1.76,2.45,15.2,112,3.27,3.39,.34,1.97,6.75,1.05,2.85,1450 8 | 1,14.39,1.87,2.45,14.6,96,2.5,2.52,.3,1.98,5.25,1.02,3.58,1290 9 | 1,14.06,2.15,2.61,17.6,121,2.6,2.51,.31,1.25,5.05,1.06,3.58,1295 10 | 1,14.83,1.64,2.17,14,97,2.8,2.98,.29,1.98,5.2,1.08,2.85,1045 11 | 1,13.86,1.35,2.27,16,98,2.98,3.15,.22,1.85,7.22,1.01,3.55,1045 12 | 1,14.1,2.16,2.3,18,105,2.95,3.32,.22,2.38,5.75,1.25,3.17,1510 13 | 1,14.12,1.48,2.32,16.8,95,2.2,2.43,.26,1.57,5,1.17,2.82,1280 14 | 1,13.75,1.73,2.41,16,89,2.6,2.76,.29,1.81,5.6,1.15,2.9,1320 15 | 1,14.75,1.73,2.39,11.4,91,3.1,3.69,.43,2.81,5.4,1.25,2.73,1150 16 | 1,14.38,1.87,2.38,12,102,3.3,3.64,.29,2.96,7.5,1.2,3,1547 17 | 1,13.63,1.81,2.7,17.2,112,2.85,2.91,.3,1.46,7.3,1.28,2.88,1310 18 | 1,14.3,1.92,2.72,20,120,2.8,3.14,.33,1.97,6.2,1.07,2.65,1280 19 | 1,13.83,1.57,2.62,20,115,2.95,3.4,.4,1.72,6.6,1.13,2.57,1130 20 |
1,14.19,1.59,2.48,16.5,108,3.3,3.93,.32,1.86,8.7,1.23,2.82,1680 21 | 1,13.64,3.1,2.56,15.2,116,2.7,3.03,.17,1.66,5.1,.96,3.36,845 22 | 1,14.06,1.63,2.28,16,126,3,3.17,.24,2.1,5.65,1.09,3.71,780 23 | 1,12.93,3.8,2.65,18.6,102,2.41,2.41,.25,1.98,4.5,1.03,3.52,770 24 | 1,13.71,1.86,2.36,16.6,101,2.61,2.88,.27,1.69,3.8,1.11,4,1035 25 | 1,12.85,1.6,2.52,17.8,95,2.48,2.37,.26,1.46,3.93,1.09,3.63,1015 26 | 1,13.5,1.81,2.61,20,96,2.53,2.61,.28,1.66,3.52,1.12,3.82,845 27 | 1,13.05,2.05,3.22,25,124,2.63,2.68,.47,1.92,3.58,1.13,3.2,830 28 | 1,13.39,1.77,2.62,16.1,93,2.85,2.94,.34,1.45,4.8,.92,3.22,1195 29 | 1,13.3,1.72,2.14,17,94,2.4,2.19,.27,1.35,3.95,1.02,2.77,1285 30 | 1,13.87,1.9,2.8,19.4,107,2.95,2.97,.37,1.76,4.5,1.25,3.4,915 31 | 1,14.02,1.68,2.21,16,96,2.65,2.33,.26,1.98,4.7,1.04,3.59,1035 32 | 1,13.73,1.5,2.7,22.5,101,3,3.25,.29,2.38,5.7,1.19,2.71,1285 33 | 1,13.58,1.66,2.36,19.1,106,2.86,3.19,.22,1.95,6.9,1.09,2.88,1515 34 | 1,13.68,1.83,2.36,17.2,104,2.42,2.69,.42,1.97,3.84,1.23,2.87,990 35 | 1,13.76,1.53,2.7,19.5,132,2.95,2.74,.5,1.35,5.4,1.25,3,1235 36 | 1,13.51,1.8,2.65,19,110,2.35,2.53,.29,1.54,4.2,1.1,2.87,1095 37 | 1,13.48,1.81,2.41,20.5,100,2.7,2.98,.26,1.86,5.1,1.04,3.47,920 38 | 1,13.28,1.64,2.84,15.5,110,2.6,2.68,.34,1.36,4.6,1.09,2.78,880 39 | 1,13.05,1.65,2.55,18,98,2.45,2.43,.29,1.44,4.25,1.12,2.51,1105 40 | 1,13.07,1.5,2.1,15.5,98,2.4,2.64,.28,1.37,3.7,1.18,2.69,1020 41 | 1,14.22,3.99,2.51,13.2,128,3,3.04,.2,2.08,5.1,.89,3.53,760 42 | 1,13.56,1.71,2.31,16.2,117,3.15,3.29,.34,2.34,6.13,.95,3.38,795 43 | 1,13.41,3.84,2.12,18.8,90,2.45,2.68,.27,1.48,4.28,.91,3,1035 44 | 1,13.88,1.89,2.59,15,101,3.25,3.56,.17,1.7,5.43,.88,3.56,1095 45 | 1,13.24,3.98,2.29,17.5,103,2.64,2.63,.32,1.66,4.36,.82,3,680 46 | 1,13.05,1.77,2.1,17,107,3,3,.28,2.03,5.04,.88,3.35,885 47 | 1,14.21,4.04,2.44,18.9,111,2.85,2.65,.3,1.25,5.24,.87,3.33,1080 48 | 1,14.38,3.59,2.28,16,102,3.25,3.17,.27,2.19,4.9,1.04,3.44,1065 49 | 1,13.9,1.68,2.12,16,101,3.1,3.39,.21,2.14,6.1,.91,3.33,985 50 | 1,14.1,2.02,2.4,18.8,103,2.75,2.92,.32,2.38,6.2,1.07,2.75,1060 51 | 1,13.94,1.73,2.27,17.4,108,2.88,3.54,.32,2.08,8.90,1.12,3.1,1260 52 | 1,13.05,1.73,2.04,12.4,92,2.72,3.27,.17,2.91,7.2,1.12,2.91,1150 53 | 1,13.83,1.65,2.6,17.2,94,2.45,2.99,.22,2.29,5.6,1.24,3.37,1265 54 | 1,13.82,1.75,2.42,14,111,3.88,3.74,.32,1.87,7.05,1.01,3.26,1190 55 | 1,13.77,1.9,2.68,17.1,115,3,2.79,.39,1.68,6.3,1.13,2.93,1375 56 | 1,13.74,1.67,2.25,16.4,118,2.6,2.9,.21,1.62,5.85,.92,3.2,1060 57 | 1,13.56,1.73,2.46,20.5,116,2.96,2.78,.2,2.45,6.25,.98,3.03,1120 58 | 1,14.22,1.7,2.3,16.3,118,3.2,3,.26,2.03,6.38,.94,3.31,970 59 | 1,13.29,1.97,2.68,16.8,102,3,3.23,.31,1.66,6,1.07,2.84,1270 60 | 1,13.72,1.43,2.5,16.7,108,3.4,3.67,.19,2.04,6.8,.89,2.87,1285 61 | 2,12.37,.94,1.36,10.6,88,1.98,.57,.28,.42,1.95,1.05,1.82,520 62 | 2,12.33,1.1,2.28,16,101,2.05,1.09,.63,.41,3.27,1.25,1.67,680 63 | 2,12.64,1.36,2.02,16.8,100,2.02,1.41,.53,.62,5.75,.98,1.59,450 64 | 2,13.67,1.25,1.92,18,94,2.1,1.79,.32,.73,3.8,1.23,2.46,630 65 | 2,12.37,1.13,2.16,19,87,3.5,3.1,.19,1.87,4.45,1.22,2.87,420 66 | 2,12.17,1.45,2.53,19,104,1.89,1.75,.45,1.03,2.95,1.45,2.23,355 67 | 2,12.37,1.21,2.56,18.1,98,2.42,2.65,.37,2.08,4.6,1.19,2.3,678 68 | 2,13.11,1.01,1.7,15,78,2.98,3.18,.26,2.28,5.3,1.12,3.18,502 69 | 2,12.37,1.17,1.92,19.6,78,2.11,2,.27,1.04,4.68,1.12,3.48,510 70 | 2,13.34,.94,2.36,17,110,2.53,1.3,.55,.42,3.17,1.02,1.93,750 71 | 2,12.21,1.19,1.75,16.8,151,1.85,1.28,.14,2.5,2.85,1.28,3.07,718 72 | 2,12.29,1.61,2.21,20.4,103,1.1,1.02,.37,1.46,3.05,.906,1.82,870 73 | 
2,13.86,1.51,2.67,25,86,2.95,2.86,.21,1.87,3.38,1.36,3.16,410 74 | 2,13.49,1.66,2.24,24,87,1.88,1.84,.27,1.03,3.74,.98,2.78,472 75 | 2,12.99,1.67,2.6,30,139,3.3,2.89,.21,1.96,3.35,1.31,3.5,985 76 | 2,11.96,1.09,2.3,21,101,3.38,2.14,.13,1.65,3.21,.99,3.13,886 77 | 2,11.66,1.88,1.92,16,97,1.61,1.57,.34,1.15,3.8,1.23,2.14,428 78 | 2,13.03,.9,1.71,16,86,1.95,2.03,.24,1.46,4.6,1.19,2.48,392 79 | 2,11.84,2.89,2.23,18,112,1.72,1.32,.43,.95,2.65,.96,2.52,500 80 | 2,12.33,.99,1.95,14.8,136,1.9,1.85,.35,2.76,3.4,1.06,2.31,750 81 | 2,12.7,3.87,2.4,23,101,2.83,2.55,.43,1.95,2.57,1.19,3.13,463 82 | 2,12,.92,2,19,86,2.42,2.26,.3,1.43,2.5,1.38,3.12,278 83 | 2,12.72,1.81,2.2,18.8,86,2.2,2.53,.26,1.77,3.9,1.16,3.14,714 84 | 2,12.08,1.13,2.51,24,78,2,1.58,.4,1.4,2.2,1.31,2.72,630 85 | 2,13.05,3.86,2.32,22.5,85,1.65,1.59,.61,1.62,4.8,.84,2.01,515 86 | 2,11.84,.89,2.58,18,94,2.2,2.21,.22,2.35,3.05,.79,3.08,520 87 | 2,12.67,.98,2.24,18,99,2.2,1.94,.3,1.46,2.62,1.23,3.16,450 88 | 2,12.16,1.61,2.31,22.8,90,1.78,1.69,.43,1.56,2.45,1.33,2.26,495 89 | 2,11.65,1.67,2.62,26,88,1.92,1.61,.4,1.34,2.6,1.36,3.21,562 90 | 2,11.64,2.06,2.46,21.6,84,1.95,1.69,.48,1.35,2.8,1,2.75,680 91 | 2,12.08,1.33,2.3,23.6,70,2.2,1.59,.42,1.38,1.74,1.07,3.21,625 92 | 2,12.08,1.83,2.32,18.5,81,1.6,1.5,.52,1.64,2.4,1.08,2.27,480 93 | 2,12,1.51,2.42,22,86,1.45,1.25,.5,1.63,3.6,1.05,2.65,450 94 | 2,12.69,1.53,2.26,20.7,80,1.38,1.46,.58,1.62,3.05,.96,2.06,495 95 | 2,12.29,2.83,2.22,18,88,2.45,2.25,.25,1.99,2.15,1.15,3.3,290 96 | 2,11.62,1.99,2.28,18,98,3.02,2.26,.17,1.35,3.25,1.16,2.96,345 97 | 2,12.47,1.52,2.2,19,162,2.5,2.27,.32,3.28,2.6,1.16,2.63,937 98 | 2,11.81,2.12,2.74,21.5,134,1.6,.99,.14,1.56,2.5,.95,2.26,625 99 | 2,12.29,1.41,1.98,16,85,2.55,2.5,.29,1.77,2.9,1.23,2.74,428 100 | 2,12.37,1.07,2.1,18.5,88,3.52,3.75,.24,1.95,4.5,1.04,2.77,660 101 | 2,12.29,3.17,2.21,18,88,2.85,2.99,.45,2.81,2.3,1.42,2.83,406 102 | 2,12.08,2.08,1.7,17.5,97,2.23,2.17,.26,1.4,3.3,1.27,2.96,710 103 | 2,12.6,1.34,1.9,18.5,88,1.45,1.36,.29,1.35,2.45,1.04,2.77,562 104 | 2,12.34,2.45,2.46,21,98,2.56,2.11,.34,1.31,2.8,.8,3.38,438 105 | 2,11.82,1.72,1.88,19.5,86,2.5,1.64,.37,1.42,2.06,.94,2.44,415 106 | 2,12.51,1.73,1.98,20.5,85,2.2,1.92,.32,1.48,2.94,1.04,3.57,672 107 | 2,12.42,2.55,2.27,22,90,1.68,1.84,.66,1.42,2.7,.86,3.3,315 108 | 2,12.25,1.73,2.12,19,80,1.65,2.03,.37,1.63,3.4,1,3.17,510 109 | 2,12.72,1.75,2.28,22.5,84,1.38,1.76,.48,1.63,3.3,.88,2.42,488 110 | 2,12.22,1.29,1.94,19,92,2.36,2.04,.39,2.08,2.7,.86,3.02,312 111 | 2,11.61,1.35,2.7,20,94,2.74,2.92,.29,2.49,2.65,.96,3.26,680 112 | 2,11.46,3.74,1.82,19.5,107,3.18,2.58,.24,3.58,2.9,.75,2.81,562 113 | 2,12.52,2.43,2.17,21,88,2.55,2.27,.26,1.22,2,.9,2.78,325 114 | 2,11.76,2.68,2.92,20,103,1.75,2.03,.6,1.05,3.8,1.23,2.5,607 115 | 2,11.41,.74,2.5,21,88,2.48,2.01,.42,1.44,3.08,1.1,2.31,434 116 | 2,12.08,1.39,2.5,22.5,84,2.56,2.29,.43,1.04,2.9,.93,3.19,385 117 | 2,11.03,1.51,2.2,21.5,85,2.46,2.17,.52,2.01,1.9,1.71,2.87,407 118 | 2,11.82,1.47,1.99,20.8,86,1.98,1.6,.3,1.53,1.95,.95,3.33,495 119 | 2,12.42,1.61,2.19,22.5,108,2,2.09,.34,1.61,2.06,1.06,2.96,345 120 | 2,12.77,3.43,1.98,16,80,1.63,1.25,.43,.83,3.4,.7,2.12,372 121 | 2,12,3.43,2,19,87,2,1.64,.37,1.87,1.28,.93,3.05,564 122 | 2,11.45,2.4,2.42,20,96,2.9,2.79,.32,1.83,3.25,.8,3.39,625 123 | 2,11.56,2.05,3.23,28.5,119,3.18,5.08,.47,1.87,6,.93,3.69,465 124 | 2,12.42,4.43,2.73,26.5,102,2.2,2.13,.43,1.71,2.08,.92,3.12,365 125 | 2,13.05,5.8,2.13,21.5,86,2.62,2.65,.3,2.01,2.6,.73,3.1,380 126 | 2,11.87,4.31,2.39,21,82,2.86,3.03,.21,2.91,2.8,.75,3.64,380 127 | 
2,12.07,2.16,2.17,21,85,2.6,2.65,.37,1.35,2.76,.86,3.28,378 128 | 2,12.43,1.53,2.29,21.5,86,2.74,3.15,.39,1.77,3.94,.69,2.84,352 129 | 2,11.79,2.13,2.78,28.5,92,2.13,2.24,.58,1.76,3,.97,2.44,466 130 | 2,12.37,1.63,2.3,24.5,88,2.22,2.45,.4,1.9,2.12,.89,2.78,342 131 | 2,12.04,4.3,2.38,22,80,2.1,1.75,.42,1.35,2.6,.79,2.57,580 132 | 3,12.86,1.35,2.32,18,122,1.51,1.25,.21,.94,4.1,.76,1.29,630 133 | 3,12.88,2.99,2.4,20,104,1.3,1.22,.24,.83,5.4,.74,1.42,530 134 | 3,12.81,2.31,2.4,24,98,1.15,1.09,.27,.83,5.7,.66,1.36,560 135 | 3,12.7,3.55,2.36,21.5,106,1.7,1.2,.17,.84,5,.78,1.29,600 136 | 3,12.51,1.24,2.25,17.5,85,2,.58,.6,1.25,5.45,.75,1.51,650 137 | 3,12.6,2.46,2.2,18.5,94,1.62,.66,.63,.94,7.1,.73,1.58,695 138 | 3,12.25,4.72,2.54,21,89,1.38,.47,.53,.8,3.85,.75,1.27,720 139 | 3,12.53,5.51,2.64,25,96,1.79,.6,.63,1.1,5,.82,1.69,515 140 | 3,13.49,3.59,2.19,19.5,88,1.62,.48,.58,.88,5.7,.81,1.82,580 141 | 3,12.84,2.96,2.61,24,101,2.32,.6,.53,.81,4.92,.89,2.15,590 142 | 3,12.93,2.81,2.7,21,96,1.54,.5,.53,.75,4.6,.77,2.31,600 143 | 3,13.36,2.56,2.35,20,89,1.4,.5,.37,.64,5.6,.7,2.47,780 144 | 3,13.52,3.17,2.72,23.5,97,1.55,.52,.5,.55,4.35,.89,2.06,520 145 | 3,13.62,4.95,2.35,20,92,2,.8,.47,1.02,4.4,.91,2.05,550 146 | 3,12.25,3.88,2.2,18.5,112,1.38,.78,.29,1.14,8.21,.65,2,855 147 | 3,13.16,3.57,2.15,21,102,1.5,.55,.43,1.3,4,.6,1.68,830 148 | 3,13.88,5.04,2.23,20,80,.98,.34,.4,.68,4.9,.58,1.33,415 149 | 3,12.87,4.61,2.48,21.5,86,1.7,.65,.47,.86,7.65,.54,1.86,625 150 | 3,13.32,3.24,2.38,21.5,92,1.93,.76,.45,1.25,8.42,.55,1.62,650 151 | 3,13.08,3.9,2.36,21.5,113,1.41,1.39,.34,1.14,9.40,.57,1.33,550 152 | 3,13.5,3.12,2.62,24,123,1.4,1.57,.22,1.25,8.60,.59,1.3,500 153 | 3,12.79,2.67,2.48,22,112,1.48,1.36,.24,1.26,10.8,.48,1.47,480 154 | 3,13.11,1.9,2.75,25.5,116,2.2,1.28,.26,1.56,7.1,.61,1.33,425 155 | 3,13.23,3.3,2.28,18.5,98,1.8,.83,.61,1.87,10.52,.56,1.51,675 156 | 3,12.58,1.29,2.1,20,103,1.48,.58,.53,1.4,7.6,.58,1.55,640 157 | 3,13.17,5.19,2.32,22,93,1.74,.63,.61,1.55,7.9,.6,1.48,725 158 | 3,13.84,4.12,2.38,19.5,89,1.8,.83,.48,1.56,9.01,.57,1.64,480 159 | 3,12.45,3.03,2.64,27,97,1.9,.58,.63,1.14,7.5,.67,1.73,880 160 | 3,14.34,1.68,2.7,25,98,2.8,1.31,.53,2.7,13,.57,1.96,660 161 | 3,13.48,1.67,2.64,22.5,89,2.6,1.1,.52,2.29,11.75,.57,1.78,620 162 | 3,12.36,3.83,2.38,21,88,2.3,.92,.5,1.04,7.65,.56,1.58,520 163 | 3,13.69,3.26,2.54,20,107,1.83,.56,.5,.8,5.88,.96,1.82,680 164 | 3,12.85,3.27,2.58,22,106,1.65,.6,.6,.96,5.58,.87,2.11,570 165 | 3,12.96,3.45,2.35,18.5,106,1.39,.7,.4,.94,5.28,.68,1.75,675 166 | 3,13.78,2.76,2.3,22,90,1.35,.68,.41,1.03,9.58,.7,1.68,615 167 | 3,13.73,4.36,2.26,22.5,88,1.28,.47,.52,1.15,6.62,.78,1.75,520 168 | 3,13.45,3.7,2.6,23,111,1.7,.92,.43,1.46,10.68,.85,1.56,695 169 | 3,12.82,3.37,2.3,19.5,88,1.48,.66,.4,.97,10.26,.72,1.75,685 170 | 3,13.58,2.58,2.69,24.5,105,1.55,.84,.39,1.54,8.66,.74,1.8,750 171 | 3,13.4,4.6,2.86,25,112,1.98,.96,.27,1.11,8.5,.67,1.92,630 172 | 3,12.2,3.03,2.32,19,96,1.25,.49,.4,.73,5.5,.66,1.83,510 173 | 3,12.77,2.39,2.28,19.5,86,1.39,.51,.48,.64,9.899999,.57,1.63,470 174 | 3,14.16,2.51,2.48,20,91,1.68,.7,.44,1.24,9.7,.62,1.71,660 175 | 3,13.71,5.65,2.45,20.5,95,1.68,.61,.52,1.06,7.7,.64,1.74,740 176 | 3,13.4,3.91,2.48,23,102,1.8,.75,.43,1.41,7.3,.7,1.56,750 177 | 3,13.27,4.28,2.26,20,120,1.59,.69,.43,1.35,10.2,.59,1.56,835 178 | 3,13.17,2.59,2.37,20,120,1.65,.68,.53,1.46,9.3,.6,1.62,840 179 | 3,14.13,4.1,2.74,24.5,96,2.05,.76,.56,1.35,9.2,.61,1.6,560 180 | -------------------------------------------------------------------------------- /img/skorch2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FernandoLpz/SKORCH-PyTorch-Wrapper/8e0aac8d347e567347eb94104b0ced5abd40a9f7/img/skorch2.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | 3 | from sklearn.pipeline import Pipeline 4 | from sklearn.preprocessing import StandardScaler 5 | from sklearn.model_selection import GridSearchCV 6 | 7 | from skorch.callbacks import EpochScoring 8 | from skorch import NeuralNetClassifier 9 | 10 | from src import load_data 11 | from src import NeuralNet 12 | 13 | class Run: 14 | def __init__(self, x, y): 15 | self.x = x 16 | self.y = y 17 | 18 | def simple_training(self): 19 | # Trains the Neural Network with fixed hyperparameters 20 | 21 | # The Neural Net is initialized with fixed hyperparameters 22 | nn = NeuralNetClassifier(NeuralNet, max_epochs=10, lr=0.01, batch_size=12, optimizer=optim.RMSprop) 23 | # Training 24 | nn.fit(self.x, self.y) 25 | 26 | 27 | def simple_pipeline_training(self): 28 | # Trains the Neural Network within a scikit-learn pipeline 29 | # The pipeline is composed of feature scaling and NN training 30 | # The hyperparameters are fixed values 31 | 32 | # The Neural Net is initialized with fixed hyperparameters 33 | nn = NeuralNetClassifier(NeuralNet, max_epochs=10, lr=0.01, batch_size=12, optimizer=optim.RMSprop) 34 | # The pipeline is instantiated; it wraps the scaling and training phases 35 | pipeline = Pipeline([('scale', StandardScaler()), ('nn', nn)]) 36 | # Pipeline execution 37 | pipeline.fit(self.x, self.y) 38 | 39 | 40 | 41 | def simple_pipeline_training_with_callbacks(self): 42 | # Trains the Neural Network within a scikit-learn pipeline 43 | # The pipeline is composed of feature scaling and NN training 44 | # Callbacks are added in order to calculate the "balanced accuracy" and "accuracy" during the training phase 45 | 46 | # The EpochScoring callbacks are initialized 47 | balanced_accuracy = EpochScoring(scoring='balanced_accuracy', lower_is_better=False) 48 | accuracy = EpochScoring(scoring='accuracy', lower_is_better=False) 49 | 50 | # The Neural Net is initialized with fixed hyperparameters 51 | nn = NeuralNetClassifier(NeuralNet, max_epochs=10, lr=0.01, batch_size=12, optimizer=optim.RMSprop, callbacks=[balanced_accuracy, accuracy]) 52 | # The pipeline is instantiated; it wraps the scaling and training phases 53 | pipeline = Pipeline([('scale', StandardScaler()), ('nn', nn)]) 54 | # Pipeline execution 55 | pipeline.fit(self.x, self.y) 56 | 57 | 58 | 59 | def grid_search_pipeline_training(self): 60 | # Through a grid search, the optimal hyperparameters are found 61 | # A pipeline is used in order to scale and train the neural net 62 | # The grid search module from scikit-learn wraps the pipeline 63 | 64 | # The Neural Net is instantiated; no hyperparameters are provided 65 | nn = NeuralNetClassifier(NeuralNet, verbose=0, train_split=False) 66 | # The pipeline is instantiated; it wraps the scaling and training phases 67 | pipeline = Pipeline([('scale', StandardScaler()), ('nn', nn)]) 68 | 69 | # The parameters for the grid search are defined 70 | # The prefix "nn__" must be used when setting hyperparameters for the training phase 71 | # The prefix "nn__module__" must be used when setting hyperparameters for the Neural Net 72 | params = { 73 | 'nn__max_epochs':[10, 20], 74 | 'nn__lr': [0.1, 0.01],
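# skorch routes the 'nn__module__*' entries below to NeuralNet.__init__ (the num_units and dropout arguments defined in src/model.py)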
75 | 'nn__module__num_units': [5, 10], 76 | 'nn__module__dropout': [0.1, 0.5], 77 | 'nn__optimizer': [optim.Adam, optim.SGD, optim.RMSprop]} 78 | 79 | # The grid search module is instantiated 80 | gs = GridSearchCV(pipeline, params, refit=False, cv=3, scoring='balanced_accuracy', verbose=1) 81 | # Run the grid search 82 | gs.fit(self.x, self.y) 83 | 84 | 85 | if __name__ == "__main__": 86 | x, y = load_data() 87 | 88 | run = Run(x, y) 89 | 90 | # run.simple_training() 91 | # run.simple_pipeline_training() 92 | # run.simple_pipeline_training_with_callbacks() 93 | run.grid_search_pipeline_training() 94 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_loader import load_data 2 | from .model import NeuralNet -------------------------------------------------------------------------------- /src/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.utils import shuffle 4 | 5 | def load_data(): 6 | # Load the CSV dataset 7 | data = pd.read_csv('data/wines.csv') 8 | 9 | # Shuffle the data 10 | data = shuffle(data) 11 | 12 | # Fix class labels 13 | # The original class labels are [1, 2, 3]; they must be changed to [0, 1, 2] 14 | data['class'] = data['class'].replace([1, 2, 3], [0, 1, 2]) 15 | 16 | 17 | # Split into the feature matrix x and the label vector y 18 | x = data[[feature for feature in data.columns if feature != 'class']].values 19 | y = np.squeeze(data[['class']].values) 20 | 21 | # Fix datatypes (the classifier expects float32 features and int64 labels) 22 | x = x.astype(np.float32) 23 | y = y.astype(np.int64) 24 | 25 | return x, y -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class NeuralNet(nn.Module): 6 | def __init__(self, num_units=10, dropout=0.1): 7 | super(NeuralNet, self).__init__() 8 | self.num_units = num_units 9 | self.linear_1 = nn.Linear(13, num_units) 10 | self.dropout = nn.Dropout(dropout) 11 | self.linear_2 = nn.Linear(num_units, 10) 12 | self.linear_3 = nn.Linear(10, 3) 13 | 14 | def forward(self, x): 15 | 16 | x = self.linear_1(x) 17 | x = F.relu(x) 18 | # Apply dropout (the layer was defined in __init__ but never used in forward, so the tuned dropout rate had no effect) 19 | x = self.dropout(x) 20 | x = self.linear_2(x) 21 | x = F.relu(x) 22 | x = self.linear_3(x) 23 | # Output class probabilities for the 3 wine classes 24 | x = F.softmax(x, dim=-1) 25 | 26 | return x --------------------------------------------------------------------------------
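A note on the grid search in ``main.py``: since ``GridSearchCV`` is created with ``refit=False``, no final model is retrained after the search, but with a single scoring metric the search results remain available on the fitted object. A small sketch (not part of the repo) of how they could be inspected after ``gs.fit(...)``:

```python
# Hypothetical follow-up, assuming `gs` is the fitted GridSearchCV from main.py
print(gs.best_score_)   # best mean cross-validated balanced accuracy
print(gs.best_params_)  # winning combination of nn__* and nn__module__* values

# The full table of results lives in gs.cv_results_, e.g. as a DataFrame
import pandas as pd
results = pd.DataFrame(gs.cv_results_)
print(results[['params', 'mean_test_score', 'rank_test_score']])
```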