├── .gitignore
├── LICENSE
├── README.md
├── data
├── toy-regression-features.csv
└── toy-regression-labels.csv
├── example.ipynb
├── requirements.txt
├── results
└── reduced_dataset.csv
├── src
├── __init__.py
├── data.py
└── ml.py
└── tests
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 |
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 apalladi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # feature-selection-adding-noise
2 |
3 | ## Introduction
4 | The purpose of this small library is to apply feature selection to your high-dimensional data. In order to do that, we apply the following steps:
5 | 1) the input features are automatically standardized
6 | 2) a column containing gaussian noise (average = 0, standard deviation = 1) is added
7 | 3) a model is trained
8 | 4) the feature importance is evaluated
9 | 5) all the features that are less important than the random one are excluded
10 | The previous steps are repeated iteratively, until the algorithm converges.
11 |
12 | ## Initialize the repository
13 | Let us start by cloning the repository, by using the following command:
14 | ```
15 | git clone git@github.com:apalladi/feature-selection-adding-noise.git
16 | ```
17 | Then you need to install the dependencies. I suggest creating a virtual environment, as follows:
18 | ```
19 | python3 -m venv .env
20 | source .env/bin/activate
21 | pip install -r requirements.txt
22 | ```
23 |
24 | To check if everything works, you can run the unit tests:
25 | ```
26 | python -m pytest tests/test.py
27 | ```
28 |
29 | You are now ready to use the repository!
30 |
31 | ## Getting started
32 | The example you need is contained in [this notebook](example.ipynb).
33 | A toy dataset for building a regression model is imported.
34 | Then we import the function `get_relevant_features`
35 | ```
36 | from src.ml import get_relevant_features
37 | ```
38 | This function takes as arguments:
39 | - `features`
40 | - `labels`
41 | - `model`, a scikit-learn model
42 | - `epochs`, the number of epochs (i.e. for how many cycles you want to apply recursively the feature selection)
43 | - `patience`, the number of epochs without any improvement in the feature selection before stopping the process (the idea is similar to the early stopping of Tensorflow/Keras)
44 | - `splitting_type`, it can be equal to `simple` (for simple train/test split) or `kfold` (for 5-fold splitting). If you choose `kfold`, the feature importance will be computed as the average feature importance for each train/test subset.
45 | - `noise_type`, it can be equal to `gaussian` for gaussian noise or `random` for flat random noise
46 | - `filename_output`, a string to indicate where to save the file. You can also choose `None` if you do not want to save it
47 | - `random_state`, sets the random seed that is used by the k-fold splitting
48 |
49 | The function `get_relevant_features` returns a DataFrame with a reduced dataset, i.e. a dataset that contains only the most important features.
50 |
51 |
--------------------------------------------------------------------------------
/data/toy-regression-labels.csv:
--------------------------------------------------------------------------------
1 | labels
2 | 219.66421110537377
3 | 40.99812573852436
4 | 109.42385821856618
5 | -280.66750675702076
6 | 599.1785378315125
7 | 210.84346059574347
8 | 283.00445948218146
9 | -30.12438553925481
10 | -59.81206007271604
11 | -420.65737435194205
12 | -197.5390398130957
13 | 45.09427291240982
14 | -473.5937200217304
15 | 99.05970861007242
16 | -306.0419398927675
17 | -79.5975539814209
18 | -162.12424764307752
19 | 78.67863470667349
20 | 408.7254953097355
21 | 123.84731009652106
22 | -102.12843400269146
23 | -17.312519366997122
24 | 287.8622440339944
25 | 134.2825735118363
26 | -288.2538933803337
27 | -294.31645884698355
28 | 93.5402459089359
29 | -165.75915298963588
30 | 187.3625488233377
31 | 477.24364434704654
32 | -160.02519426852024
33 | -369.07047587223315
34 | 124.76128030649855
35 | -168.9598361723375
36 | 183.84951935491583
37 | 346.93821420244996
38 | -146.9363708507223
39 | -475.7711780150155
40 | 164.3657272962622
41 | -265.0171836664488
42 | -593.7346139201999
43 | -209.7491908447602
44 | -342.35365549545133
45 | -68.77410945324183
46 | 48.28451867992422
47 | -258.396355994631
48 | 39.1283861478949
49 | -420.94219859643
50 | 143.72124261476043
51 | 275.85123174279823
52 | 67.1920188813163
53 | 226.3593163855234
54 | -99.55328243673284
55 | -408.20911939582174
56 | -99.08091797152437
57 | -481.2436653525199
58 | -2.9292434308502493
59 | -50.57606534287761
60 | 175.3610172710227
61 | 220.4375232785938
62 | 202.3942278357318
63 | -113.57111937095235
64 | 232.3565587371829
65 | -162.54704358780953
66 | -57.130109705320024
67 | 441.03614555634294
68 | -67.02651224239445
69 | -509.34070453705914
70 | -27.870055104524937
71 | -229.53662143572663
72 | -250.14024664163387
73 | -366.4372342171117
74 | 46.518804179761176
75 | 116.92043101176571
76 | 18.820544960998525
77 | 210.30034026958214
78 | 478.5477878660953
79 | 417.2844739466522
80 | 102.39941584413236
81 | -978.0120573994108
82 | -308.8451209548697
83 | -213.5210097425842
84 | 261.4645116037161
85 | 36.27066302849197
86 | -6.669884127949885
87 | 358.608687862422
88 | -226.98071759699877
89 | 74.74585810340724
90 | 8.740405469897468
91 | -57.65180347772909
92 | 345.68698072536876
93 | 323.44218805538713
94 | -238.43075811367788
95 | 344.57810996390094
96 | -353.15252408200615
97 | 311.8043731963006
98 | 301.9183965858471
99 | 352.5979167661178
100 | 110.00390089019643
101 | 246.01921125336307
102 | 677.599200798106
103 | 711.4843824206637
104 | 191.7277911019544
105 | 57.99259068049659
106 | -383.0990013737419
107 | 334.4620955849534
108 | -288.13816332058644
109 | 161.98906727450884
110 | -256.52604740772733
111 | 33.02783772330004
112 | -369.31457330854863
113 | 282.8980069994969
114 | -108.07987335065204
115 | 181.88867152135373
116 | -463.355389084535
117 | -57.16858021828898
118 | 302.345137705743
119 | -18.689778293238987
120 | -325.1782919210175
121 | -130.03126703705072
122 | -45.75588839018792
123 | 95.40621723970422
124 | -479.6015968105236
125 | 227.22243128108795
126 | -30.991358685988317
127 | -142.9779163301373
128 | -893.830818650887
129 | 131.59725557684942
130 | -208.08364950440514
131 | -50.11918151891862
132 | -129.82418090113322
133 | -204.7782679241685
134 | -365.2291784981545
135 | 958.0195550658315
136 | 91.03634732414339
137 | 668.3293179922772
138 | 137.33931696932729
139 | 156.45819823550076
140 | 46.78787104054108
141 | -73.0611991287546
142 | -16.079112537893373
143 | -332.9179458827927
144 | 12.608764116640717
145 | 265.29304542408283
146 | -139.21545955687984
147 | 110.2828069067367
148 | -44.607068712910745
149 | 16.492497282475313
150 | 317.5265009381611
151 | -401.0451689065083
152 | 12.067463582747962
153 | -83.09755944022235
154 | -11.42361671446308
155 | 355.4587444092366
156 | -278.8705245321509
157 | 259.0809422276131
158 | 131.9797904814068
159 | 458.8901234460045
160 | -201.37984822699846
161 | 77.89296354228058
162 | -371.78620712581846
163 | 303.19151387677914
164 | 295.70743716021695
165 | 244.792621765663
166 | -151.51871438114455
167 | 375.0709213146323
168 | 406.5079040014588
169 | 286.7800957483075
170 | -282.5277096305248
171 | -420.2630572029418
172 | -413.618662390845
173 | 216.3436814539506
174 | -683.6027897802443
175 | 92.93560643486715
176 | 44.70344795422818
177 | 343.9726303950285
178 | 66.1754662656133
179 | -29.590210811428733
180 | -3.368565792046752
181 | -315.1923255488523
182 | 412.37831319704844
183 | 414.65900678313886
184 | 30.541122590840843
185 | -238.63012467202824
186 | -96.84336054466398
187 | 300.1892887346487
188 | -334.54217813498695
189 | 27.99608296508022
190 | -61.77129204337399
191 | 14.942188777365573
192 | -525.7326455722985
193 | 112.7322750639299
194 | 364.84216999578814
195 | -295.65563357735243
196 | 120.74118061033646
197 | -313.03977941675026
198 | -145.45345902728027
199 | 11.797154900881168
200 | -61.810344123332534
201 | 74.26605490724862
202 | -309.00266961667126
203 | -530.0828723263916
204 | 418.57140590307944
205 | -41.74482714115291
206 | 121.41684815273871
207 | 405.2362104672736
208 | -171.9970231255059
209 | 88.2797601652037
210 | 512.6155407924663
211 | 269.54406471209035
212 | -124.36335410913836
213 | -324.12015696729463
214 | 246.2544716090739
215 | 434.7036044660102
216 | -223.69334812177925
217 | 463.5430305770866
218 | 359.6186097903139
219 | 396.58142510788787
220 | 738.2674707542574
221 | 180.06500962643116
222 | -288.23528134988373
223 | -266.16857094666966
224 | -577.184257189553
225 | -480.8624043453125
226 | 112.89469666560275
227 | -60.85487516021421
228 | 442.58742035798167
229 | 248.77945846601668
230 | -174.53520141553088
231 | 562.0106297094923
232 | 396.1942647182949
233 | 199.06894198875804
234 | 252.71478393513448
235 | -146.93929247269597
236 | 191.7323273679387
237 | -45.2958727095224
238 | -427.21687468532275
239 | 12.461349601654678
240 | 422.30717448682867
241 | 233.6751650938761
242 | 563.0275372062894
243 | 433.30058799142154
244 | 108.30182641189415
245 | -218.02224784442905
246 | -183.4040116108224
247 | -182.92255119924062
248 | 392.30028867140544
249 | 3.7131965707191625
250 | 497.99579980378456
251 | 212.01834201356186
252 | 715.4649248810988
253 | 519.4609797102954
254 | -26.989752202347162
255 | 169.2888596594962
256 | 18.92806454001243
257 | -537.2461912164638
258 | -326.86335003591523
259 | -286.22485161024184
260 | 39.01717391310153
261 | -299.96850412900557
262 | -174.72281041255383
263 | -329.74061064784985
264 | 87.28903183163771
265 | 337.6511516443617
266 | 62.49194262923407
267 | 143.18118441638228
268 | -73.79382938901495
269 | 535.2243122098414
270 | 175.81395911870072
271 | -193.80398326330058
272 | 146.98895958839728
273 | 670.1157889736027
274 | -335.4876878063802
275 | -198.371912719899
276 | 314.2386412193764
277 | 168.8034824012617
278 | 306.6286193162233
279 | -367.7065829404182
280 | 498.65778038228956
281 | 56.676911072786595
282 | 298.9089820684999
283 | 375.2550323657371
284 | 158.16757721904568
285 | 200.4595688879491
286 | 509.27615081515336
287 | 141.71944355481935
288 | -358.73494550303235
289 | -141.46508942485553
290 | 200.38965431527947
291 | 639.9028651997666
292 | 93.07027597756083
293 | -146.52030808194593
294 | 144.75082163123142
295 | -111.83320853763635
296 | -50.37950192580972
297 | 501.19660261519607
298 | -67.44465712983656
299 | 534.1826610568827
300 | -160.7425747770929
301 | 85.73377964948206
302 | -96.47360982691012
303 | -274.9768131633041
304 | -412.5949823402696
305 | 477.4320851906831
306 | 476.16694166567004
307 | 53.43186613335136
308 | 187.87089203687583
309 | -205.79927102995535
310 | -237.13510435758218
311 | 36.3208198702175
312 | 207.94503006838846
313 | -88.9914384510391
314 | 186.6920918619557
315 | 86.53458018538245
316 | 199.05376088951718
317 | 50.04826687442488
318 | -205.8351854633912
319 | 289.4862395862873
320 | 78.71216376430846
321 | -462.4158964137397
322 | -76.2541187285045
323 | -402.35782297263756
324 | -877.2417176434378
325 | 57.554397302157
326 | -373.5605273471517
327 | 35.245880980242866
328 | -296.7834630772414
329 | -178.4838476097142
330 | 111.24982505285713
331 | -96.79190759717235
332 | -477.4648890058636
333 | 445.67187652053417
334 | -79.15399783119454
335 | 116.88235533395297
336 | 140.65941231630532
337 | 318.7863278451861
338 | -142.36798735269005
339 | -452.34506451028284
340 | 14.033099980482014
341 | -478.32528351098404
342 | 62.32621657690383
343 | -85.67661762917075
344 | 23.88662281891239
345 | 83.5806484692816
346 | -119.7489839281694
347 | -13.711441029105597
348 | -142.0344353209897
349 | 195.61770088578427
350 | 83.3791706398313
351 | -207.13874933383445
352 | -405.67021956846685
353 | -232.55675706729517
354 | 369.80403435973915
355 | 36.919250572105284
356 | 671.7929214211165
357 | -118.61369183668698
358 | 3.750548887927671
359 | -37.75545984466757
360 | -161.4348670843344
361 | -17.961521195049784
362 | 434.77540556488526
363 | -215.0279279255545
364 | 279.14136475142726
365 | -335.9299339463031
366 | 275.1685178819571
367 | -155.30757260332908
368 | -21.696019192629535
369 | -15.924744454492455
370 | 43.906253087008835
371 | 337.7686374470297
372 | -211.05341310481703
373 | -12.746175500056381
374 | 311.6254368113914
375 | 260.4555125568561
376 | 29.82725104023927
377 | 812.7399411215245
378 | 21.450446688103256
379 | 39.67367688152194
380 | -423.7973802453141
381 | -547.631683256067
382 | -102.93782119795702
383 | 74.76674346877166
384 | 107.52179751282722
385 | 280.4643452488732
386 | 174.61066521970474
387 | 101.23715037872829
388 | 396.05923966670605
389 | -377.947729330531
390 | -164.04101178001133
391 | 61.173919708742005
392 | 192.76588108646376
393 | -318.8179933745456
394 | -521.3846419750324
395 | 258.9820100030168
396 | 419.3231880074004
397 | -149.44226140107605
398 | 98.63114384731146
399 | -283.00958979274805
400 | 206.80574392772616
401 | -471.448371385884
402 | 130.07083889411336
403 | -501.04526864600626
404 | 89.24531288037647
405 | 228.05957571684934
406 | 318.30367800303617
407 | -434.4176084063522
408 | 774.5826919616825
409 | -0.9240439714436945
410 | -226.9520935877121
411 | 386.33513606410327
412 | 370.3836005652059
413 | 548.3799787687266
414 | 108.16433254033271
415 | -276.89949378114954
416 | -173.49339913684315
417 | 62.7674938216976
418 | -524.4276129904287
419 | 169.90198539054362
420 | 58.567795121879044
421 | -290.3747761985046
422 | 438.5833753750337
423 | -321.3578626667747
424 | -80.22699823707424
425 | -244.867673307717
426 | 209.4160679033282
427 | 242.13090604931872
428 | -307.18553169522727
429 | 288.7689437895989
430 | 236.37393252502866
431 | 313.063444212058
432 | 78.54134351510749
433 | -138.6433548546298
434 | 90.01870742580228
435 | 320.52346086529633
436 | -218.48563832111998
437 | 470.88640661939144
438 | 250.13352504170456
439 | -774.7651878072118
440 | -143.62815334681963
441 | 304.7855992694973
442 | -29.218330087607676
443 | -429.8289041752295
444 | -106.37838578950459
445 | -151.85466042085102
446 | -136.38359123956027
447 | 498.8875532079957
448 | -310.20813667277804
449 | -52.68642490176664
450 | 75.2253541889834
451 | -596.8809617691197
452 | -135.78130261309687
453 | -2.973639249572358
454 | -105.76612422435545
455 | 27.68036765794641
456 | -109.46412956248764
457 | 21.70842618529365
458 | -62.75111953580587
459 | 54.85449410719809
460 | 646.883158144461
461 | -132.5690334866119
462 | 275.5876887887282
463 | 470.40560918955055
464 | -73.56141628702832
465 | -495.2509411971815
466 | 437.79739193352054
467 | 313.7208983943566
468 | 89.79983035183005
469 | 69.22106362777522
470 | -149.181549095875
471 | 121.28585336420143
472 | -257.4156559512919
473 | 147.40779256195276
474 | -827.4059274712373
475 | -126.83611731992451
476 | -155.1572846515104
477 | 397.312269072818
478 | 341.8377369066052
479 | -573.1004574202724
480 | -379.46651936485296
481 | 232.37312209668877
482 | 64.35427719644893
483 | -403.03132612200744
484 | -695.0170848613226
485 | 112.47782185996614
486 | -345.074080739759
487 | -17.48180554945972
488 | 138.58459553286718
489 | 86.0430565982474
490 | -86.70623349810324
491 | 121.13398379191366
492 | -649.0922728040513
493 | 139.6156164280343
494 | -534.6862563642604
495 | 281.87317484690914
496 | 612.3306504710739
497 | -405.43560777977365
498 | 442.58039978248786
499 | -254.94346887899212
500 | -88.76336873598447
501 | -633.3370404134627
502 | -288.1141069995499
503 | 366.78973999043546
504 | -224.8536442421646
505 | 49.26671000659827
506 | 165.83040781552805
507 | 230.49469903002304
508 | 1057.9904144036054
509 | 473.39259752821386
510 | -49.86707813435416
511 | -200.6655577952709
512 | -18.614254099505132
513 | -118.9877427327001
514 | -101.85548024116953
515 | -3.3007141186634996
516 | 227.1415666490299
517 | -130.88124656998087
518 | 301.7640536750737
519 | 221.79947258234054
520 | 55.85295941144099
521 | -475.15121157051595
522 | -290.21029063523025
523 | 607.3300594336403
524 | -6.972351318487355
525 | -213.26044982185797
526 | 271.41930318809136
527 | -87.87163521051338
528 | 82.85724429895416
529 | 60.13044183532068
530 | 238.71488930311492
531 | -205.30782778551102
532 | -167.97249164001585
533 | -626.6680881535467
534 | -210.71794044174294
535 | 338.4268516743314
536 | 21.327754022949932
537 | 131.11375629474173
538 | 99.63642744845114
539 | -431.4696571353565
540 | 268.53069314040255
541 | 183.7155684889983
542 | 206.1136567255354
543 | 938.485055074153
544 | -370.3122696032017
545 | 302.2829100726916
546 | -345.980519117454
547 | 51.90634308481428
548 | 100.59413868158724
549 | 258.95425833506727
550 | -354.49223246498894
551 | -188.29842177527783
552 | 334.018648080925
553 | 82.57817449668167
554 | 105.40827565012256
555 | -642.6763674904915
556 | 80.88811418501643
557 | -537.7847275519619
558 | -44.994214473374534
559 | -668.5934888064476
560 | 225.93756565254031
561 | 727.7205122969831
562 | 69.73521987107483
563 | -0.6889414780357015
564 | 60.25016581711141
565 | 361.763437322291
566 | -193.73993874298833
567 | -75.06634974409894
568 | 87.67349755652162
569 | -70.7317911894568
570 | -158.9408099969209
571 | -53.949959598525496
572 | 289.97169234846314
573 | 513.8419618454099
574 | 456.7974490024387
575 | -387.32989714442715
576 | 11.541136573054459
577 | -516.2398183426416
578 | 22.20570426318264
579 | 44.93674703405627
580 | 111.74542701010068
581 | -166.65462533937958
582 | -423.0594820821815
583 | -67.79888144006907
584 | 95.99435634966315
585 | -153.93204277170054
586 | -441.61233639970544
587 | -377.97986068534146
588 | 43.88456551556874
589 | -159.09913081414723
590 | -601.3987380926017
591 | 159.0792015242133
592 | -64.48693133155123
593 | -376.56215882758215
594 | 702.7653676049299
595 | -92.69929665954832
596 | 204.1021997084409
597 | -8.097467943855179
598 | 48.60906336352542
599 | 294.90707401500975
600 | -242.9887235405698
601 | -57.499842323521875
602 | -456.36415593043733
603 | -7.932043052156564
604 | -26.227365280070046
605 | 368.6816217166653
606 | -256.03451275320384
607 | -94.58741293806388
608 | -30.03610893222566
609 | 178.4686951578877
610 | -271.2166766746184
611 | 370.0690174105201
612 | 72.05821771129402
613 | -171.4086748354379
614 | 119.67317757590877
615 | 47.86138741039744
616 | -265.26096397236466
617 | 425.8969952679584
618 | -123.18707188929788
619 | -167.2729715423001
620 | 253.92919370621652
621 | 6.371348453025277
622 | -106.02817326063514
623 | 741.5831023881881
624 | 238.74053090627243
625 | 366.96040429489227
626 | 393.27889904186316
627 | -162.37593584076217
628 | -194.5705991304323
629 | 453.1045664785996
630 | 92.57524932430822
631 | -115.34039171552371
632 | 263.6981736022322
633 | -411.37409886920096
634 | -381.9315239427792
635 | -87.82999213046611
636 | -40.75840233833431
637 | -574.9642829681824
638 | -203.27207694459548
639 | -122.741903041147
640 | -45.81988575674249
641 | 190.31576548103914
642 | -83.47111790485431
643 | -594.0219764302144
644 | -331.555854994611
645 | -551.2374852333826
646 | 148.7805869062618
647 | -215.18544048959663
648 | -296.58525416818475
649 | 383.39717648406804
650 | -319.99519952221743
651 | -322.55548938317884
652 | 688.4763052174203
653 | -42.797139515328055
654 | -128.11759461291382
655 | -374.23022915977714
656 | 43.74615623829504
657 | -207.0158737196524
658 | -350.7611332215587
659 | -284.1165610418166
660 | -90.74801244701638
661 | 175.22957014566754
662 | -298.4087178186445
663 | -358.3700225301559
664 | -15.945301151668161
665 | -423.47767401950034
666 | 51.633991919049144
667 | -172.36802631710177
668 | -115.72692061473215
669 | -154.41682560554972
670 | -252.19844458800276
671 | 301.13990716116524
672 | -247.8283487837687
673 | 213.34261253129188
674 | 283.4762578581023
675 | 220.15047098936952
676 | -63.64095810380125
677 | -271.4068829605406
678 | 172.2692175343363
679 | -556.5743113894846
680 | -101.92502182653703
681 | -385.27886175543796
682 | -116.67586605766246
683 | -220.27319939629527
684 | -447.6788463490998
685 | 86.42500713688688
686 | 326.30401506756243
687 | 247.6125458429959
688 | -125.25381166234219
689 | 135.92522190244517
690 | 812.9778135202583
691 | 240.70471517707932
692 | -32.89302845321356
693 | -25.069754727092985
694 | -124.8494311761978
695 | 78.1042829869844
696 | -192.97338775699205
697 | 486.3149327304531
698 | -50.88187970179066
699 | -239.1838866235216
700 | -348.31392934990356
701 | 79.13059626041633
702 | -397.6528365157774
703 | 9.260427333251801
704 | -303.56384357784856
705 | -143.0683011881344
706 | 154.97431087417888
707 | 167.8703940779693
708 | -250.37861808592208
709 | -36.73229949348913
710 | 165.60650395211113
711 | 436.6006340377123
712 | 144.98908182197334
713 | 138.68962210567366
714 | 439.6913244022819
715 | -105.56865101487321
716 | -144.6139907739925
717 | 41.46050366051495
718 | -100.38103185724285
719 | -115.9344709778836
720 | -122.55243220564434
721 | 135.74693752080393
722 | 318.02752995801427
723 | 659.0212948798771
724 | -178.9965439883482
725 | 153.70731456575817
726 | 12.561994790850235
727 | 538.2284866710995
728 | -149.14274589228876
729 | 56.63563619964414
730 | 106.55202867811536
731 | -83.95084515175216
732 | 176.37471060505206
733 | -12.396560499889546
734 | 29.014189286129238
735 | -322.8971488544742
736 | 512.7493263315325
737 | 188.19605378069392
738 | -26.532341521799662
739 | -340.8873819200601
740 | 42.375938360829224
741 | -432.2287980731304
742 | -306.8147227365162
743 | 80.56253887254563
744 | 531.8022962281085
745 | 122.22210015068346
746 | -28.711876698403643
747 | -736.6318129813221
748 | -129.23225956774758
749 | -4.192889360752608
750 | 84.53862831722492
751 | 136.43955202170602
752 | 103.04309754489225
753 | 900.5480460628117
754 | 488.18126422960006
755 | 194.6835555204844
756 | -50.652877604316025
757 | 209.9990779052415
758 | -113.52168835069503
759 | -215.41697904140096
760 | 48.07543592712512
761 | 283.94707766190487
762 | -430.01140706045237
763 | -43.82034902582119
764 | 43.09281783061093
765 | -299.3197271738013
766 | 175.75587560089198
767 | 77.95424864352577
768 | 196.45854533770043
769 | 212.90041115696633
770 | -209.5943712051902
771 | 93.15746999857956
772 | 61.03128186054663
773 | -77.99932355249459
774 | 132.37960560861092
775 | 47.58263934272867
776 | -230.0332142697132
777 | -134.74106595396276
778 | 79.37675734884638
779 | 278.36548634963424
780 | 221.96462495616566
781 | -822.880410741632
782 | 304.24573270881484
783 | -157.30162263219063
784 | 114.76518810323809
785 | 683.7096888263625
786 | 332.10317504702385
787 | -150.9332671102602
788 | 5.780666939958898
789 | 62.28137295155727
790 | -21.88742304950337
791 | -192.86481759405876
792 | 104.78803076331286
793 | -2.4886452846441074
794 | 95.87387009329376
795 | 106.9098263477729
796 | -326.0344916581623
797 | -459.3990198880376
798 | -320.9807797173338
799 | 41.44749510332275
800 | 299.9946902476721
801 | 568.0657424201406
802 | -52.80744052754126
803 | 166.30254539587904
804 | 513.3007141852776
805 | 490.06698675010614
806 | 339.26341073253946
807 | 121.70581006253593
808 | 129.56754891837846
809 | 476.9660515715946
810 | -764.1239805671025
811 | -234.2538479919876
812 | 272.68222245244743
813 | -65.60175064220076
814 | -307.2895651112007
815 | -263.1630745143601
816 | 426.2143861529066
817 | -4.864004048687123
818 | 465.2616236349868
819 | 239.09667158297088
820 | 359.5919087000877
821 | -62.051583979771976
822 | 300.6095963012117
823 | 493.78166093247654
824 | 276.9596787772388
825 | -51.958690491403075
826 | 11.622513648479057
827 | -372.2848578850104
828 | -60.25501178999095
829 | -270.2801583332541
830 | -124.03732005035349
831 | 207.46249274383024
832 | -77.92651595271144
833 | 199.05196933469293
834 | 23.68396605251496
835 | 406.93646409719247
836 | 333.43564416626447
837 | 379.9091676365338
838 | -291.45600463045366
839 | 214.31898278938803
840 | 108.54504251233661
841 | 392.9890931378893
842 | 389.37490872620447
843 | -47.13948884817606
844 | 236.2736616453898
845 | 536.2291186335767
846 | 313.35672188644133
847 | 157.9410042327246
848 | 415.0832675205393
849 | 237.29531787190953
850 | -70.10915410006021
851 | -29.286773927063404
852 | -345.1713423345644
853 | -235.1522562252115
854 | 15.712787711489767
855 | 271.34011205379045
856 | 126.09839785627477
857 | 76.53085613644575
858 | 330.60732285513814
859 | 34.440358762550986
860 | 660.739973216025
861 | -303.05501843825823
862 | 133.75158283314812
863 | -724.343066664628
864 | 183.09812144722144
865 | -3.302592053092212
866 | -214.9021562043016
867 | -277.4859874448581
868 | 457.8643085395454
869 | -127.90289527090574
870 | 124.26463199984457
871 | -14.066153487964044
872 | 606.1254619119906
873 | 752.9105249583428
874 | 20.998367218762038
875 | -288.3982756395808
876 | 129.1460922441576
877 | -295.13280290128586
878 | -219.95543227476344
879 | 456.8799114608556
880 | 621.2093520109033
881 | -83.93690799290073
882 | 206.05472525814582
883 | 96.93661335941695
884 | 102.34345019423195
885 | 568.3362243819324
886 | -472.8872467791076
887 | 89.17749470969224
888 | -23.907146376147715
889 | -474.1992606768101
890 | -42.71255474529988
891 | -60.75882464083398
892 | -131.2080420503905
893 | 488.0404769587037
894 | -54.580161168673776
895 | -100.79516212291948
896 | -256.70422868482774
897 | 363.95303473481823
898 | 251.05877228577782
899 | -427.60908165701653
900 | 271.985574552793
901 | 32.54576913176726
902 | 242.48784889913935
903 | 196.22713157388603
904 | 496.90183801658486
905 | -674.7436315228723
906 | -21.246938994175366
907 | 750.9383308111346
908 | 184.03873828418403
909 | -87.20430335899057
910 | -512.7163323894375
911 | 213.92223372220948
912 | -214.58027422152009
913 | 602.179427584113
914 | -140.43796059460806
915 | 119.13023473128969
916 | -173.56948532405613
917 | 267.5881448430044
918 | 503.2388043031035
919 | -23.11613398276637
920 | 67.1934141382712
921 | -398.0822267265129
922 | -86.00128294019206
923 | -323.2790498924765
924 | -455.19634577646275
925 | -345.47731457905314
926 | 180.32857789932658
927 | 439.18198130609005
928 | -58.349789948524915
929 | -51.37114972775309
930 | 358.43440398144514
931 | 34.41218090391774
932 | -310.46791180264745
933 | 30.432302966235966
934 | -289.7008373298527
935 | 121.69478729348083
936 | -80.06079925455828
937 | 341.92481591386377
938 | -256.23416416202184
939 | -130.00793340771452
940 | -510.569989749574
941 | -73.93557152974061
942 | 629.7574088812376
943 | -297.84547292519795
944 | -341.32309331391923
945 | 703.0937906274942
946 | -348.1396269485223
947 | 148.6247895162148
948 | -534.7323875010395
949 | 348.1935402945218
950 | 416.12154662501456
951 | 185.55488326379268
952 | 532.7910188918678
953 | 600.4218386644078
954 | -581.5320404118182
955 | -146.25664261076827
956 | -79.41256718592786
957 | 36.02759064730649
958 | -166.02555364748002
959 | 217.1552919966481
960 | -290.73080080056593
961 | -143.63862526707655
962 | -210.1774268490937
963 | -476.94008056005026
964 | -488.6480697893357
965 | -98.15687158225435
966 | -382.8604131254918
967 | 850.4747444559168
968 | 338.46030491126146
969 | -59.79339507552459
970 | 157.35918250239692
971 | 200.40195612609668
972 | -735.6164027108637
973 | -34.50250712620171
974 | 406.77013960209877
975 | 57.725587096903666
976 | -187.4647698018989
977 | 76.94863072652515
978 | 40.71787537182026
979 | 201.7848818691717
980 | -110.50546079155595
981 | -341.0723169408731
982 | -274.0465840801938
983 | 659.8104426557172
984 | -126.5209140985389
985 | -455.3151630042686
986 | -652.1053535876599
987 | -207.10595716904226
988 | -263.2606998263593
989 | -79.68629052372415
990 | -165.27483731283775
991 | -23.465321928129832
992 | 401.1325288416542
993 | -226.63495432426132
994 | 199.6686463904946
995 | -328.02133654428997
996 | -53.60775725091852
997 | 192.5280891686172
998 | -121.51712662816664
999 | -198.5471746260622
1000 | -62.324907025231084
1001 | 396.3846847390514
1002 |
--------------------------------------------------------------------------------
/example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "61895906-6628-4029-b022-24dd1dbef8e5",
6 | "metadata": {},
7 | "source": [
8 | "# Introduction\n",
9 | "\n",
10 | "In this Notebook we show how to use this library, which performs feature selection by introducing a column with Gaussian noise (average = 0, standard deviation = 1).\n",
11 | "\n"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "5f56d2c7-5638-4830-a855-be239fe600f5",
17 | "metadata": {},
18 | "source": [
19 | "# Import toy regression data\n",
20 | "\n",
21 | "We import a toy regression dataset.\n",
22 | "The features consist of 1000 samples and 300 features, while the output consists of a single variable."
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "id": "0c43ee6a-358c-4b29-8a71-f657dc7ef8dd",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import numpy as np\n",
33 | "import pandas as pd\n",
34 | "\n",
35 | "X = pd.read_csv('data/toy-regression-features.csv')\n",
36 | "y = pd.read_csv('data/toy-regression-labels.csv')"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 2,
42 | "id": "a0784882-b8f1-45fc-9e06-7353f7e38e98",
43 | "metadata": {},
44 | "outputs": [
45 | {
46 | "name": "stdout",
47 | "output_type": "stream",
48 | "text": [
49 | "(1000, 300)\n"
50 | ]
51 | },
52 | {
53 | "data": {
54 | "text/html": [
55 | "
\n",
56 | "\n",
69 | "
\n",
70 | " \n",
71 | " \n",
72 | " | \n",
73 | " col_0 | \n",
74 | " col_1 | \n",
75 | " col_2 | \n",
76 | " col_3 | \n",
77 | " col_4 | \n",
78 | " col_5 | \n",
79 | " col_6 | \n",
80 | " col_7 | \n",
81 | " col_8 | \n",
82 | " col_9 | \n",
83 | " ... | \n",
84 | " col_290 | \n",
85 | " col_291 | \n",
86 | " col_292 | \n",
87 | " col_293 | \n",
88 | " col_294 | \n",
89 | " col_295 | \n",
90 | " col_296 | \n",
91 | " col_297 | \n",
92 | " col_298 | \n",
93 | " col_299 | \n",
94 | "
\n",
95 | " \n",
96 | " \n",
97 | " \n",
98 | " 0 | \n",
99 | " 0.537070 | \n",
100 | " -1.102140 | \n",
101 | " 1.614065 | \n",
102 | " 0.446704 | \n",
103 | " 2.150072 | \n",
104 | " 0.022795 | \n",
105 | " -0.062325 | \n",
106 | " 0.842601 | \n",
107 | " 1.337646 | \n",
108 | " 0.423004 | \n",
109 | " ... | \n",
110 | " -0.923334 | \n",
111 | " -0.481815 | \n",
112 | " 1.203645 | \n",
113 | " -1.458585 | \n",
114 | " -0.382055 | \n",
115 | " 0.778446 | \n",
116 | " 1.281670 | \n",
117 | " 0.083526 | \n",
118 | " 0.690047 | \n",
119 | " 0.117204 | \n",
120 | "
\n",
121 | " \n",
122 | " 1 | \n",
123 | " -0.291104 | \n",
124 | " 0.870662 | \n",
125 | " 0.989858 | \n",
126 | " 0.340181 | \n",
127 | " 0.462467 | \n",
128 | " -0.582148 | \n",
129 | " 1.888760 | \n",
130 | " 1.326881 | \n",
131 | " -1.654321 | \n",
132 | " -0.130696 | \n",
133 | " ... | \n",
134 | " 1.429273 | \n",
135 | " 0.845691 | \n",
136 | " -1.089257 | \n",
137 | " -0.918355 | \n",
138 | " 0.394018 | \n",
139 | " 2.608926 | \n",
140 | " -1.485463 | \n",
141 | " -0.907812 | \n",
142 | " -0.173660 | \n",
143 | " 0.920506 | \n",
144 | "
\n",
145 | " \n",
146 | " 2 | \n",
147 | " -0.623914 | \n",
148 | " 0.645679 | \n",
149 | " -0.603598 | \n",
150 | " -0.382241 | \n",
151 | " -1.038855 | \n",
152 | " 1.036846 | \n",
153 | " -0.411746 | \n",
154 | " 0.309138 | \n",
155 | " 0.377860 | \n",
156 | " 1.115033 | \n",
157 | " ... | \n",
158 | " -0.853996 | \n",
159 | " -1.977211 | \n",
160 | " -0.360070 | \n",
161 | " 0.457125 | \n",
162 | " -1.372804 | \n",
163 | " 0.320784 | \n",
164 | " -0.961563 | \n",
165 | " -0.203412 | \n",
166 | " 0.920264 | \n",
167 | " 0.799161 | \n",
168 | "
\n",
169 | " \n",
170 | " 3 | \n",
171 | " -0.007280 | \n",
172 | " -1.159284 | \n",
173 | " 1.205723 | \n",
174 | " -0.869215 | \n",
175 | " -0.571466 | \n",
176 | " 0.540196 | \n",
177 | " 0.656639 | \n",
178 | " 0.041661 | \n",
179 | " 0.244310 | \n",
180 | " -0.860549 | \n",
181 | " ... | \n",
182 | " 0.646255 | \n",
183 | " -0.762227 | \n",
184 | " -0.940969 | \n",
185 | " -0.889827 | \n",
186 | " -0.534136 | \n",
187 | " -0.649951 | \n",
188 | " -0.387092 | \n",
189 | " -1.089814 | \n",
190 | " 0.054935 | \n",
191 | " 0.955872 | \n",
192 | "
\n",
193 | " \n",
194 | " 4 | \n",
195 | " 0.578238 | \n",
196 | " -0.756635 | \n",
197 | " -0.768636 | \n",
198 | " 1.339886 | \n",
199 | " 0.612525 | \n",
200 | " -0.431343 | \n",
201 | " -0.058266 | \n",
202 | " 0.975151 | \n",
203 | " -1.992118 | \n",
204 | " 0.179272 | \n",
205 | " ... | \n",
206 | " 0.263658 | \n",
207 | " 0.837735 | \n",
208 | " 0.724682 | \n",
209 | " -2.493489 | \n",
210 | " -2.108600 | \n",
211 | " -1.646070 | \n",
212 | " -0.674911 | \n",
213 | " -0.344457 | \n",
214 | " -0.771689 | \n",
215 | " -0.691474 | \n",
216 | "
\n",
217 | " \n",
218 | "
\n",
219 | "
5 rows × 300 columns
\n",
220 | "
"
221 | ],
222 | "text/plain": [
223 | " col_0 col_1 col_2 col_3 col_4 col_5 col_6 \\\n",
224 | "0 0.537070 -1.102140 1.614065 0.446704 2.150072 0.022795 -0.062325 \n",
225 | "1 -0.291104 0.870662 0.989858 0.340181 0.462467 -0.582148 1.888760 \n",
226 | "2 -0.623914 0.645679 -0.603598 -0.382241 -1.038855 1.036846 -0.411746 \n",
227 | "3 -0.007280 -1.159284 1.205723 -0.869215 -0.571466 0.540196 0.656639 \n",
228 | "4 0.578238 -0.756635 -0.768636 1.339886 0.612525 -0.431343 -0.058266 \n",
229 | "\n",
230 | " col_7 col_8 col_9 ... col_290 col_291 col_292 col_293 \\\n",
231 | "0 0.842601 1.337646 0.423004 ... -0.923334 -0.481815 1.203645 -1.458585 \n",
232 | "1 1.326881 -1.654321 -0.130696 ... 1.429273 0.845691 -1.089257 -0.918355 \n",
233 | "2 0.309138 0.377860 1.115033 ... -0.853996 -1.977211 -0.360070 0.457125 \n",
234 | "3 0.041661 0.244310 -0.860549 ... 0.646255 -0.762227 -0.940969 -0.889827 \n",
235 | "4 0.975151 -1.992118 0.179272 ... 0.263658 0.837735 0.724682 -2.493489 \n",
236 | "\n",
237 | " col_294 col_295 col_296 col_297 col_298 col_299 \n",
238 | "0 -0.382055 0.778446 1.281670 0.083526 0.690047 0.117204 \n",
239 | "1 0.394018 2.608926 -1.485463 -0.907812 -0.173660 0.920506 \n",
240 | "2 -1.372804 0.320784 -0.961563 -0.203412 0.920264 0.799161 \n",
241 | "3 -0.534136 -0.649951 -0.387092 -1.089814 0.054935 0.955872 \n",
242 | "4 -2.108600 -1.646070 -0.674911 -0.344457 -0.771689 -0.691474 \n",
243 | "\n",
244 | "[5 rows x 300 columns]"
245 | ]
246 | },
247 | "execution_count": 2,
248 | "metadata": {},
249 | "output_type": "execute_result"
250 | }
251 | ],
252 | "source": [
253 | "print(X.shape)\n",
254 | "X.head()"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": 3,
260 | "id": "23443e5f-c1ec-422c-b894-b6da48b79c1c",
261 | "metadata": {},
262 | "outputs": [
263 | {
264 | "data": {
265 | "text/html": [
266 | "\n",
267 | "\n",
280 | "
\n",
281 | " \n",
282 | " \n",
283 | " | \n",
284 | " labels | \n",
285 | "
\n",
286 | " \n",
287 | " \n",
288 | " \n",
289 | " 0 | \n",
290 | " 219.664211 | \n",
291 | "
\n",
292 | " \n",
293 | " 1 | \n",
294 | " 40.998126 | \n",
295 | "
\n",
296 | " \n",
297 | " 2 | \n",
298 | " 109.423858 | \n",
299 | "
\n",
300 | " \n",
301 | " 3 | \n",
302 | " -280.667507 | \n",
303 | "
\n",
304 | " \n",
305 | " 4 | \n",
306 | " 599.178538 | \n",
307 | "
\n",
308 | " \n",
309 | "
\n",
310 | "
"
311 | ],
312 | "text/plain": [
313 | " labels\n",
314 | "0 219.664211\n",
315 | "1 40.998126\n",
316 | "2 109.423858\n",
317 | "3 -280.667507\n",
318 | "4 599.178538"
319 | ]
320 | },
321 | "execution_count": 3,
322 | "metadata": {},
323 | "output_type": "execute_result"
324 | }
325 | ],
326 | "source": [
327 | "y.head()"
328 | ]
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "id": "584213bc-0455-443c-873f-f6ea190a4f83",
333 | "metadata": {},
334 | "source": [
335 | "# Use the library to select only the relevant features\n",
336 | "\n",
337 | "We use the library to select only the relevant features.\n",
338 | "A column containing Gaussian noise (mean = 0, std. dev = 1) is created at each epoch. Then the feature importance is computed and all the features that are less important than the random one are excluded.\n",
339 | "\n",
340 | "The process is repeated for the number of selected `epochs`. It is also possible to put an early stopping, by assigning to the parameter `patience` a value that is smaller than the number of epochs. If the number of selected features remains the same for a number of epochs equal to patience, the process stops. "
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": 4,
346 | "id": "5381a35e-7993-49b5-b8f8-23faed331359",
347 | "metadata": {},
348 | "outputs": [
349 | {
350 | "name": "stdout",
351 | "output_type": "stream",
352 | "text": [
353 | "=====================EPOCH 1 =====================\n",
354 | "Fitting the model with 301 features\n",
355 | "Train score 0.9252\n",
356 | "Test score 0.8687\n",
357 | "Fitting the model with 301 features\n",
358 | "Train score 0.9302\n",
359 | "Test score 0.8272\n",
360 | "Fitting the model with 301 features\n",
361 | "Train score 0.926\n",
362 | "Test score 0.864\n",
363 | "Fitting the model with 301 features\n",
364 | "Train score 0.9237\n",
365 | "Test score 0.8608\n",
366 | "Fitting the model with 301 features\n",
367 | "Train score 0.9238\n",
368 | "Test score 0.8639\n",
369 | "Selected 286 features out of 300\n",
370 | "=====================EPOCH 2 =====================\n",
371 | "Fitting the model with 287 features\n",
372 | "Train score 0.9267\n",
373 | "Test score 0.8527\n",
374 | "Fitting the model with 287 features\n",
375 | "Train score 0.9271\n",
376 | "Test score 0.8579\n",
377 | "Fitting the model with 287 features\n",
378 | "Train score 0.9294\n",
379 | "Test score 0.8502\n",
380 | "Fitting the model with 287 features\n",
381 | "Train score 0.9228\n",
382 | "Test score 0.8675\n",
383 | "Fitting the model with 287 features\n",
384 | "Train score 0.9245\n",
385 | "Test score 0.8594\n",
386 | "Selected 110 features out of 286\n",
387 | "=====================EPOCH 3 =====================\n",
388 | "Fitting the model with 111 features\n",
389 | "Train score 0.9145\n",
390 | "Test score 0.8962\n",
391 | "Fitting the model with 111 features\n",
392 | "Train score 0.9209\n",
393 | "Test score 0.8703\n",
394 | "Fitting the model with 111 features\n",
395 | "Train score 0.9161\n",
396 | "Test score 0.8956\n",
397 | "Fitting the model with 111 features\n",
398 | "Train score 0.9188\n",
399 | "Test score 0.8793\n",
400 | "Fitting the model with 111 features\n",
401 | "Train score 0.9135\n",
402 | "Test score 0.907\n",
403 | "Selected 110 features out of 110\n",
404 | "The feature selection did not improve in the last 1 epochs\n",
405 | "=====================EPOCH 4 =====================\n",
406 | "Fitting the model with 111 features\n",
407 | "Train score 0.9186\n",
408 | "Test score 0.8801\n",
409 | "Fitting the model with 111 features\n",
410 | "Train score 0.9191\n",
411 | "Test score 0.879\n",
412 | "Fitting the model with 111 features\n",
413 | "Train score 0.9212\n",
414 | "Test score 0.8573\n",
415 | "Fitting the model with 111 features\n",
416 | "Train score 0.9139\n",
417 | "Test score 0.9048\n",
418 | "Fitting the model with 111 features\n",
419 | "Train score 0.9133\n",
420 | "Test score 0.9052\n",
421 | "Selected 110 features out of 110\n",
422 | "The feature selection did not improve in the last 2 epochs\n",
423 | "=====================EPOCH 5 =====================\n",
424 | "Fitting the model with 111 features\n",
425 | "Train score 0.9208\n",
426 | "Test score 0.8722\n",
427 | "Fitting the model with 111 features\n",
428 | "Train score 0.9189\n",
429 | "Test score 0.8778\n",
430 | "Fitting the model with 111 features\n",
431 | "Train score 0.9169\n",
432 | "Test score 0.8905\n",
433 | "Fitting the model with 111 features\n",
434 | "Train score 0.914\n",
435 | "Test score 0.9026\n",
436 | "Fitting the model with 111 features\n",
437 | "Train score 0.9144\n",
438 | "Test score 0.9004\n",
439 | "Selected 109 features out of 110\n",
440 | "=====================EPOCH 6 =====================\n",
441 | "Fitting the model with 110 features\n",
442 | "Train score 0.9156\n",
443 | "Test score 0.8955\n",
444 | "Fitting the model with 110 features\n",
445 | "Train score 0.9162\n",
446 | "Test score 0.8835\n",
447 | "Fitting the model with 110 features\n",
448 | "Train score 0.9187\n",
449 | "Test score 0.8874\n",
450 | "Fitting the model with 110 features\n",
451 | "Train score 0.9165\n",
452 | "Test score 0.891\n",
453 | "Fitting the model with 110 features\n",
454 | "Train score 0.919\n",
455 | "Test score 0.8763\n",
456 | "Selected 109 features out of 109\n",
457 | "The feature selection did not improve in the last 1 epochs\n",
458 | "=====================EPOCH 7 =====================\n",
459 | "Fitting the model with 110 features\n",
460 | "Train score 0.9167\n",
461 | "Test score 0.8909\n",
462 | "Fitting the model with 110 features\n",
463 | "Train score 0.9115\n",
464 | "Test score 0.9116\n",
465 | "Fitting the model with 110 features\n",
466 | "Train score 0.9187\n",
467 | "Test score 0.8824\n",
468 | "Fitting the model with 110 features\n",
469 | "Train score 0.9188\n",
470 | "Test score 0.8862\n",
471 | "Fitting the model with 110 features\n",
472 | "Train score 0.9173\n",
473 | "Test score 0.8895\n",
474 | "Selected 109 features out of 109\n",
475 | "The feature selection did not improve in the last 2 epochs\n",
476 | "=====================EPOCH 8 =====================\n",
477 | "Fitting the model with 110 features\n",
478 | "Train score 0.9147\n",
479 | "Test score 0.9004\n",
480 | "Fitting the model with 110 features\n",
481 | "Train score 0.9195\n",
482 | "Test score 0.8733\n",
483 | "Fitting the model with 110 features\n",
484 | "Train score 0.9159\n",
485 | "Test score 0.893\n",
486 | "Fitting the model with 110 features\n",
487 | "Train score 0.9187\n",
488 | "Test score 0.8793\n",
489 | "Fitting the model with 110 features\n",
490 | "Train score 0.9161\n",
491 | "Test score 0.8979\n",
492 | "Selected 109 features out of 109\n",
493 | "The feature selection did not improve in the last 3 epochs\n",
494 | "=====================EPOCH 9 =====================\n",
495 | "Fitting the model with 110 features\n",
496 | "Train score 0.9153\n",
497 | "Test score 0.8915\n",
498 | "Fitting the model with 110 features\n",
499 | "Train score 0.9122\n",
500 | "Test score 0.9122\n",
501 | "Fitting the model with 110 features\n",
502 | "Train score 0.9176\n",
503 | "Test score 0.8872\n",
504 | "Fitting the model with 110 features\n",
505 | "Train score 0.9196\n",
506 | "Test score 0.8751\n",
507 | "Fitting the model with 110 features\n",
508 | "Train score 0.9193\n",
509 | "Test score 0.8746\n",
510 | "Selected 109 features out of 109\n",
511 | "The feature selection did not improve in the last 4 epochs\n",
512 | "=====================EPOCH 10 =====================\n",
513 | "Fitting the model with 110 features\n",
514 | "Train score 0.9192\n",
515 | "Test score 0.8761\n",
516 | "Fitting the model with 110 features\n",
517 | "Train score 0.9195\n",
518 | "Test score 0.8733\n",
519 | "Fitting the model with 110 features\n",
520 | "Train score 0.9166\n",
521 | "Test score 0.8916\n",
522 | "Fitting the model with 110 features\n",
523 | "Train score 0.9141\n",
524 | "Test score 0.9037\n",
525 | "Fitting the model with 110 features\n",
526 | "Train score 0.915\n",
527 | "Test score 0.899\n",
528 | "Selected 109 features out of 109\n",
529 | "The feature selection did not improve in the last 5 epochs\n",
530 | "=====================EPOCH 11 =====================\n",
531 | "Fitting the model with 110 features\n",
532 | "Train score 0.9203\n",
533 | "Test score 0.8697\n",
534 | "Fitting the model with 110 features\n",
535 | "Train score 0.912\n",
536 | "Test score 0.9121\n",
537 | "Fitting the model with 110 features\n",
538 | "Train score 0.9114\n",
539 | "Test score 0.9132\n",
540 | "Fitting the model with 110 features\n",
541 | "Train score 0.9217\n",
542 | "Test score 0.8598\n",
543 | "Fitting the model with 110 features\n",
544 | "Train score 0.9184\n",
545 | "Test score 0.8842\n",
546 | "Selected 109 features out of 109\n",
547 | "The feature selection did not improve in the last 6 epochs\n",
548 | "=====================EPOCH 12 =====================\n",
549 | "Fitting the model with 110 features\n",
550 | "Train score 0.9154\n",
551 | "Test score 0.8939\n",
552 | "Fitting the model with 110 features\n",
553 | "Train score 0.9175\n",
554 | "Test score 0.878\n",
555 | "Fitting the model with 110 features\n",
556 | "Train score 0.9168\n",
557 | "Test score 0.8906\n",
558 | "Fitting the model with 110 features\n",
559 | "Train score 0.9191\n",
560 | "Test score 0.8852\n",
561 | "Fitting the model with 110 features\n",
562 | "Train score 0.9165\n",
563 | "Test score 0.892\n",
564 | "Selected 107 features out of 109\n",
565 | "=====================EPOCH 13 =====================\n",
566 | "Fitting the model with 108 features\n",
567 | "Train score 0.9183\n",
568 | "Test score 0.887\n",
569 | "Fitting the model with 108 features\n",
570 | "Train score 0.9184\n",
571 | "Test score 0.8826\n",
572 | "Fitting the model with 108 features\n",
573 | "Train score 0.9109\n",
574 | "Test score 0.9111\n",
575 | "Fitting the model with 108 features\n",
576 | "Train score 0.9199\n",
577 | "Test score 0.8789\n",
578 | "Fitting the model with 108 features\n",
579 | "Train score 0.9147\n",
580 | "Test score 0.8987\n",
581 | "Selected 106 features out of 107\n",
582 | "=====================EPOCH 14 =====================\n",
583 | "Fitting the model with 107 features\n",
584 | "Train score 0.9168\n",
585 | "Test score 0.8898\n",
586 | "Fitting the model with 107 features\n",
587 | "Train score 0.9141\n",
588 | "Test score 0.898\n",
589 | "Fitting the model with 107 features\n",
590 | "Train score 0.9157\n",
591 | "Test score 0.8899\n",
592 | "Fitting the model with 107 features\n",
593 | "Train score 0.9177\n",
594 | "Test score 0.8791\n",
595 | "Fitting the model with 107 features\n",
596 | "Train score 0.9192\n",
597 | "Test score 0.8762\n",
598 | "Selected 106 features out of 106\n",
599 | "The feature selection did not improve in the last 1 epochs\n",
600 | "=====================EPOCH 15 =====================\n",
601 | "Fitting the model with 107 features\n",
602 | "Train score 0.9176\n",
603 | "Test score 0.8826\n",
604 | "Fitting the model with 107 features\n",
605 | "Train score 0.9157\n",
606 | "Test score 0.8898\n",
607 | "Fitting the model with 107 features\n",
608 | "Train score 0.9167\n",
609 | "Test score 0.8908\n",
610 | "Fitting the model with 107 features\n",
611 | "Train score 0.9187\n",
612 | "Test score 0.8814\n",
613 | "Fitting the model with 107 features\n",
614 | "Train score 0.9162\n",
615 | "Test score 0.8945\n",
616 | "Selected 95 features out of 106\n",
617 | "=====================EPOCH 16 =====================\n",
618 | "Fitting the model with 96 features\n",
619 | "Train score 0.9148\n",
620 | "Test score 0.8947\n",
621 | "Fitting the model with 96 features\n",
622 | "Train score 0.9152\n",
623 | "Test score 0.892\n",
624 | "Fitting the model with 96 features\n",
625 | "Train score 0.9117\n",
626 | "Test score 0.9098\n",
627 | "Fitting the model with 96 features\n",
628 | "Train score 0.9177\n",
629 | "Test score 0.8805\n",
630 | "Fitting the model with 96 features\n",
631 | "Train score 0.918\n",
632 | "Test score 0.8736\n",
633 | "Selected 95 features out of 95\n",
634 | "The feature selection did not improve in the last 1 epochs\n",
635 | "=====================EPOCH 17 =====================\n",
636 | "Fitting the model with 96 features\n",
637 | "Train score 0.9177\n",
638 | "Test score 0.8767\n",
639 | "Fitting the model with 96 features\n",
640 | "Train score 0.9158\n",
641 | "Test score 0.889\n",
642 | "Fitting the model with 96 features\n",
643 | "Train score 0.9178\n",
644 | "Test score 0.8789\n",
645 | "Fitting the model with 96 features\n",
646 | "Train score 0.9118\n",
647 | "Test score 0.9098\n",
648 | "Fitting the model with 96 features\n",
649 | "Train score 0.9154\n",
650 | "Test score 0.891\n",
651 | "Selected 95 features out of 95\n",
652 | "The feature selection did not improve in the last 2 epochs\n",
653 | "=====================EPOCH 18 =====================\n",
654 | "Fitting the model with 96 features\n",
655 | "Train score 0.913\n",
656 | "Test score 0.8988\n",
657 | "Fitting the model with 96 features\n",
658 | "Train score 0.9172\n",
659 | "Test score 0.8844\n",
660 | "Fitting the model with 96 features\n",
661 | "Train score 0.9147\n",
662 | "Test score 0.8959\n",
663 | "Fitting the model with 96 features\n",
664 | "Train score 0.9171\n",
665 | "Test score 0.8779\n",
666 | "Fitting the model with 96 features\n",
667 | "Train score 0.9152\n",
668 | "Test score 0.8911\n",
669 | "Selected 95 features out of 95\n",
670 | "The feature selection did not improve in the last 3 epochs\n",
671 | "=====================EPOCH 19 =====================\n",
672 | "Fitting the model with 96 features\n",
673 | "Train score 0.9194\n",
674 | "Test score 0.8763\n",
675 | "Fitting the model with 96 features\n",
676 | "Train score 0.9142\n",
677 | "Test score 0.8969\n",
678 | "Fitting the model with 96 features\n",
679 | "Train score 0.9127\n",
680 | "Test score 0.9045\n",
681 | "Fitting the model with 96 features\n",
682 | "Train score 0.9122\n",
683 | "Test score 0.899\n",
684 | "Fitting the model with 96 features\n",
685 | "Train score 0.9171\n",
686 | "Test score 0.8837\n",
687 | "Selected 95 features out of 95\n",
688 | "The feature selection did not improve in the last 4 epochs\n",
689 | "=====================EPOCH 20 =====================\n",
690 | "Fitting the model with 96 features\n",
691 | "Train score 0.9164\n",
692 | "Test score 0.8848\n",
693 | "Fitting the model with 96 features\n",
694 | "Train score 0.9159\n",
695 | "Test score 0.8899\n",
696 | "Fitting the model with 96 features\n",
697 | "Train score 0.9171\n",
698 | "Test score 0.8875\n",
699 | "Fitting the model with 96 features\n",
700 | "Train score 0.9187\n",
701 | "Test score 0.8805\n",
702 | "Fitting the model with 96 features\n",
703 | "Train score 0.9116\n",
704 | "Test score 0.9067\n",
705 | "Selected 25 features out of 95\n",
706 | "=====================EPOCH 21 =====================\n",
707 | "Fitting the model with 26 features\n",
708 | "Train score 0.9049\n",
709 | "Test score 0.8632\n",
710 | "Fitting the model with 26 features\n",
711 | "Train score 0.8987\n",
712 | "Test score 0.8921\n",
713 | "Fitting the model with 26 features\n",
714 | "Train score 0.9031\n",
715 | "Test score 0.8715\n",
716 | "Fitting the model with 26 features\n",
717 | "Train score 0.8997\n",
718 | "Test score 0.8882\n",
719 | "Fitting the model with 26 features\n",
720 | "Train score 0.8887\n",
721 | "Test score 0.9273\n",
722 | "Selected 25 features out of 25\n",
723 | "The feature selection did not improve in the last 1 epochs\n",
724 | "=====================EPOCH 22 =====================\n",
725 | "Fitting the model with 26 features\n",
726 | "Train score 0.9028\n",
727 | "Test score 0.8746\n",
728 | "Fitting the model with 26 features\n",
729 | "Train score 0.8923\n",
730 | "Test score 0.9145\n",
731 | "Fitting the model with 26 features\n",
732 | "Train score 0.8998\n",
733 | "Test score 0.8894\n",
734 | "Fitting the model with 26 features\n",
735 | "Train score 0.8998\n",
736 | "Test score 0.8897\n",
737 | "Fitting the model with 26 features\n",
738 | "Train score 0.8999\n",
739 | "Test score 0.8887\n",
740 | "Selected 25 features out of 25\n",
741 | "The feature selection did not improve in the last 2 epochs\n",
742 | "=====================EPOCH 23 =====================\n",
743 | "Fitting the model with 26 features\n",
744 | "Train score 0.895\n",
745 | "Test score 0.91\n",
746 | "Fitting the model with 26 features\n",
747 | "Train score 0.8962\n",
748 | "Test score 0.9039\n",
749 | "Fitting the model with 26 features\n",
750 | "Train score 0.8997\n",
751 | "Test score 0.8881\n",
752 | "Fitting the model with 26 features\n",
753 | "Train score 0.9022\n",
754 | "Test score 0.88\n",
755 | "Fitting the model with 26 features\n",
756 | "Train score 0.9029\n",
757 | "Test score 0.8729\n",
758 | "Selected 25 features out of 25\n",
759 | "The feature selection did not improve in the last 3 epochs\n",
760 | "=====================EPOCH 24 =====================\n",
761 | "Fitting the model with 26 features\n",
762 | "Train score 0.9008\n",
763 | "Test score 0.8872\n",
764 | "Fitting the model with 26 features\n",
765 | "Train score 0.8986\n",
766 | "Test score 0.8936\n",
767 | "Fitting the model with 26 features\n",
768 | "Train score 0.9037\n",
769 | "Test score 0.8698\n",
770 | "Fitting the model with 26 features\n",
771 | "Train score 0.892\n",
772 | "Test score 0.9178\n",
773 | "Fitting the model with 26 features\n",
774 | "Train score 0.8999\n",
775 | "Test score 0.8872\n",
776 | "Selected 25 features out of 25\n",
777 | "The feature selection did not improve in the last 4 epochs\n",
778 | "=====================EPOCH 25 =====================\n",
779 | "Fitting the model with 26 features\n",
780 | "Train score 0.8958\n",
781 | "Test score 0.9047\n",
782 | "Fitting the model with 26 features\n",
783 | "Train score 0.9008\n",
784 | "Test score 0.8845\n",
785 | "Fitting the model with 26 features\n",
786 | "Train score 0.8935\n",
787 | "Test score 0.9164\n",
788 | "Fitting the model with 26 features\n",
789 | "Train score 0.9012\n",
790 | "Test score 0.8853\n",
791 | "Fitting the model with 26 features\n",
792 | "Train score 0.904\n",
793 | "Test score 0.8682\n",
794 | "Selected 25 features out of 25\n",
795 | "The feature selection did not improve in the last 5 epochs\n",
796 | "=====================EPOCH 26 =====================\n",
797 | "Fitting the model with 26 features\n",
798 | "Train score 0.9\n",
799 | "Test score 0.8883\n",
800 | "Fitting the model with 26 features\n",
801 | "Train score 0.9015\n",
802 | "Test score 0.8834\n",
803 | "Fitting the model with 26 features\n",
804 | "Train score 0.8981\n",
805 | "Test score 0.8957\n",
806 | "Fitting the model with 26 features\n",
807 | "Train score 0.8987\n",
808 | "Test score 0.8939\n",
809 | "Fitting the model with 26 features\n",
810 | "Train score 0.8986\n",
811 | "Test score 0.8947\n",
812 | "Selected 25 features out of 25\n",
813 | "The feature selection did not improve in the last 6 epochs\n",
814 | "=====================EPOCH 27 =====================\n",
815 | "Fitting the model with 26 features\n",
816 | "Train score 0.9007\n",
817 | "Test score 0.8812\n",
818 | "Fitting the model with 26 features\n",
819 | "Train score 0.8963\n",
820 | "Test score 0.9039\n",
821 | "Fitting the model with 26 features\n",
822 | "Train score 0.9007\n",
823 | "Test score 0.8857\n",
824 | "Fitting the model with 26 features\n",
825 | "Train score 0.8971\n",
826 | "Test score 0.8999\n",
827 | "Fitting the model with 26 features\n",
828 | "Train score 0.9013\n",
829 | "Test score 0.8816\n",
830 | "Selected 25 features out of 25\n",
831 | "The feature selection did not improve in the last 7 epochs\n",
832 | "=====================EPOCH 28 =====================\n",
833 | "Fitting the model with 26 features\n",
834 | "Train score 0.9005\n",
835 | "Test score 0.8873\n",
836 | "Fitting the model with 26 features\n",
837 | "Train score 0.8973\n",
838 | "Test score 0.8975\n",
839 | "Fitting the model with 26 features\n",
840 | "Train score 0.9051\n",
841 | "Test score 0.8591\n",
842 | "Fitting the model with 26 features\n",
843 | "Train score 0.901\n",
844 | "Test score 0.8853\n",
845 | "Fitting the model with 26 features\n",
846 | "Train score 0.8913\n",
847 | "Test score 0.9177\n",
848 | "Selected 25 features out of 25\n",
849 | "The feature selection did not improve in the last 8 epochs\n",
850 | "=====================EPOCH 29 =====================\n",
851 | "Fitting the model with 26 features\n",
852 | "Train score 0.9\n",
853 | "Test score 0.8884\n",
854 | "Fitting the model with 26 features\n",
855 | "Train score 0.8975\n",
856 | "Test score 0.9002\n",
857 | "Fitting the model with 26 features\n",
858 | "Train score 0.8985\n",
859 | "Test score 0.8945\n",
860 | "Fitting the model with 26 features\n",
861 | "Train score 0.8974\n",
862 | "Test score 0.899\n",
863 | "Fitting the model with 26 features\n",
864 | "Train score 0.9037\n",
865 | "Test score 0.8706\n",
866 | "Selected 25 features out of 25\n",
867 | "The feature selection did not improve in the last 9 epochs\n",
868 | "=====================EPOCH 30 =====================\n",
869 | "Fitting the model with 26 features\n",
870 | "Train score 0.8977\n",
871 | "Test score 0.9002\n",
872 | "Fitting the model with 26 features\n",
873 | "Train score 0.8962\n",
874 | "Test score 0.9008\n",
875 | "Fitting the model with 26 features\n",
876 | "Train score 0.9006\n",
877 | "Test score 0.8849\n",
878 | "Fitting the model with 26 features\n",
879 | "Train score 0.9029\n",
880 | "Test score 0.8715\n",
881 | "Fitting the model with 26 features\n",
882 | "Train score 0.8989\n",
883 | "Test score 0.8896\n",
884 | "Selected 25 features out of 25\n",
885 | "The feature selection did not improve in the last 10 epochs\n"
886 | ]
887 | }
888 | ],
889 | "source": [
890 | "from src.ml import get_relevant_features\n",
891 | "from sklearn.linear_model import Lasso\n",
892 | "\n",
893 | "lasso_model = Lasso(alpha=1)\n",
894 | "\n",
895 | "X_reduced = get_relevant_features(X, y, \n",
896 | " model=lasso_model, \n",
897 | " epochs=100, \n",
898 | " patience=10, \n",
899 | " splitting_type='kfold',\n",
900 | " noise_type='gaussian',\n",
901 | " filename_output='results/reduced_dataset.csv',\n",
902 | " random_state=42)"
903 | ]
904 | },
905 | {
906 | "cell_type": "markdown",
907 | "id": "4dd5b308-86bf-4e0f-a38b-0eb4cba2612f",
908 | "metadata": {},
909 | "source": [
910 | "# Inspect the new (reduced) dataset\n",
911 | "\n",
912 | "The new reduced dataset contains only a subset of features, namely the most relevant ones "
913 | ]
914 | },
915 | {
916 | "cell_type": "code",
917 | "execution_count": 5,
918 | "id": "e7f4e2f4-eaab-4700-ab75-6d1356162917",
919 | "metadata": {},
920 | "outputs": [
921 | {
922 | "data": {
923 | "text/html": [
924 | "\n",
925 | "\n",
938 | "
\n",
939 | " \n",
940 | " \n",
941 | " | \n",
942 | " col_168 | \n",
943 | " col_277 | \n",
944 | " col_88 | \n",
945 | " col_283 | \n",
946 | " col_8 | \n",
947 | " col_270 | \n",
948 | " col_258 | \n",
949 | " col_187 | \n",
950 | " col_76 | \n",
951 | " col_171 | \n",
952 | " ... | \n",
953 | " col_119 | \n",
954 | " col_174 | \n",
955 | " col_250 | \n",
956 | " col_24 | \n",
957 | " col_25 | \n",
958 | " col_55 | \n",
959 | " col_274 | \n",
960 | " col_47 | \n",
961 | " col_45 | \n",
962 | " col_154 | \n",
963 | "
\n",
964 | " \n",
965 | " \n",
966 | " \n",
967 | " 0 | \n",
968 | " -0.003067 | \n",
969 | " -0.003618 | \n",
970 | " -1.019952 | \n",
971 | " -2.092326 | \n",
972 | " 1.337646 | \n",
973 | " 0.220574 | \n",
974 | " 1.673497 | \n",
975 | " 0.236082 | \n",
976 | " 0.656732 | \n",
977 | " -0.878309 | \n",
978 | " ... | \n",
979 | " -0.645682 | \n",
980 | " 0.577020 | \n",
981 | " 1.053236 | \n",
982 | " -0.309741 | \n",
983 | " 1.178854 | \n",
984 | " 0.800531 | \n",
985 | " 0.692207 | \n",
986 | " 1.696463 | \n",
987 | " 1.498433 | \n",
988 | " -1.750620 | \n",
989 | "
\n",
990 | " \n",
991 | " 1 | \n",
992 | " 0.030676 | \n",
993 | " 0.259951 | \n",
994 | " 0.920548 | \n",
995 | " 0.685426 | \n",
996 | " -1.654321 | \n",
997 | " -0.041250 | \n",
998 | " -0.051170 | \n",
999 | " 0.853362 | \n",
1000 | " -0.367285 | \n",
1001 | " -0.343505 | \n",
1002 | " ... | \n",
1003 | " 2.052382 | \n",
1004 | " 0.653817 | \n",
1005 | " -1.545429 | \n",
1006 | " 0.830601 | \n",
1007 | " -0.570561 | \n",
1008 | " -0.017992 | \n",
1009 | " -0.808267 | \n",
1010 | " 0.063894 | \n",
1011 | " -1.335906 | \n",
1012 | " -1.292565 | \n",
1013 | "
\n",
1014 | " \n",
1015 | " 2 | \n",
1016 | " 0.459295 | \n",
1017 | " -1.351804 | \n",
1018 | " -3.277806 | \n",
1019 | " 0.933398 | \n",
1020 | " 0.377860 | \n",
1021 | " 0.758731 | \n",
1022 | " 0.971993 | \n",
1023 | " -0.739010 | \n",
1024 | " 0.200066 | \n",
1025 | " -0.076279 | \n",
1026 | " ... | \n",
1027 | " -0.745177 | \n",
1028 | " -0.260010 | \n",
1029 | " 2.176639 | \n",
1030 | " -1.145208 | \n",
1031 | " 1.957113 | \n",
1032 | " 0.742541 | \n",
1033 | " -1.399284 | \n",
1034 | " -1.426165 | \n",
1035 | " -0.852144 | \n",
1036 | " 1.935704 | \n",
1037 | "
\n",
1038 | " \n",
1039 | " 3 | \n",
1040 | " 0.002783 | \n",
1041 | " -0.413398 | \n",
1042 | " -0.936796 | \n",
1043 | " 1.153329 | \n",
1044 | " 0.244310 | \n",
1045 | " -0.781265 | \n",
1046 | " 0.744198 | \n",
1047 | " 1.123666 | \n",
1048 | " 0.328173 | \n",
1049 | " -0.379941 | \n",
1050 | " ... | \n",
1051 | " -1.638474 | \n",
1052 | " -0.077514 | \n",
1053 | " -2.033497 | \n",
1054 | " -1.016165 | \n",
1055 | " 0.245085 | \n",
1056 | " -0.630077 | \n",
1057 | " 0.221333 | \n",
1058 | " 0.356726 | \n",
1059 | " 0.389595 | \n",
1060 | " 0.210160 | \n",
1061 | "
\n",
1062 | " \n",
1063 | " 4 | \n",
1064 | " 0.054107 | \n",
1065 | " 3.081069 | \n",
1066 | " 0.842412 | \n",
1067 | " -0.148948 | \n",
1068 | " -1.992118 | \n",
1069 | " 0.513912 | \n",
1070 | " 0.093666 | \n",
1071 | " 2.831712 | \n",
1072 | " 1.261611 | \n",
1073 | " 0.741283 | \n",
1074 | " ... | \n",
1075 | " -0.557619 | \n",
1076 | " -0.676896 | \n",
1077 | " 0.351711 | \n",
1078 | " 2.785978 | \n",
1079 | " -0.719393 | \n",
1080 | " -1.068996 | \n",
1081 | " 0.234328 | \n",
1082 | " -0.817592 | \n",
1083 | " -0.178059 | \n",
1084 | " -0.086358 | \n",
1085 | "
\n",
1086 | " \n",
1087 | "
\n",
1088 | "
5 rows × 25 columns
\n",
1089 | "
"
1090 | ],
1091 | "text/plain": [
1092 | " col_168 col_277 col_88 col_283 col_8 col_270 col_258 \\\n",
1093 | "0 -0.003067 -0.003618 -1.019952 -2.092326 1.337646 0.220574 1.673497 \n",
1094 | "1 0.030676 0.259951 0.920548 0.685426 -1.654321 -0.041250 -0.051170 \n",
1095 | "2 0.459295 -1.351804 -3.277806 0.933398 0.377860 0.758731 0.971993 \n",
1096 | "3 0.002783 -0.413398 -0.936796 1.153329 0.244310 -0.781265 0.744198 \n",
1097 | "4 0.054107 3.081069 0.842412 -0.148948 -1.992118 0.513912 0.093666 \n",
1098 | "\n",
1099 | " col_187 col_76 col_171 ... col_119 col_174 col_250 col_24 \\\n",
1100 | "0 0.236082 0.656732 -0.878309 ... -0.645682 0.577020 1.053236 -0.309741 \n",
1101 | "1 0.853362 -0.367285 -0.343505 ... 2.052382 0.653817 -1.545429 0.830601 \n",
1102 | "2 -0.739010 0.200066 -0.076279 ... -0.745177 -0.260010 2.176639 -1.145208 \n",
1103 | "3 1.123666 0.328173 -0.379941 ... -1.638474 -0.077514 -2.033497 -1.016165 \n",
1104 | "4 2.831712 1.261611 0.741283 ... -0.557619 -0.676896 0.351711 2.785978 \n",
1105 | "\n",
1106 | " col_25 col_55 col_274 col_47 col_45 col_154 \n",
1107 | "0 1.178854 0.800531 0.692207 1.696463 1.498433 -1.750620 \n",
1108 | "1 -0.570561 -0.017992 -0.808267 0.063894 -1.335906 -1.292565 \n",
1109 | "2 1.957113 0.742541 -1.399284 -1.426165 -0.852144 1.935704 \n",
1110 | "3 0.245085 -0.630077 0.221333 0.356726 0.389595 0.210160 \n",
1111 | "4 -0.719393 -1.068996 0.234328 -0.817592 -0.178059 -0.086358 \n",
1112 | "\n",
1113 | "[5 rows x 25 columns]"
1114 | ]
1115 | },
1116 | "execution_count": 5,
1117 | "metadata": {},
1118 | "output_type": "execute_result"
1119 | }
1120 | ],
1121 | "source": [
1122 | "X_reduced.head()"
1123 | ]
1124 | },
1125 | {
1126 | "cell_type": "code",
1127 | "execution_count": null,
1128 | "id": "debafd0f-0ced-426b-99e2-2b4c85beaeb8",
1129 | "metadata": {},
1130 | "outputs": [],
1131 | "source": []
1132 | }
1133 | ],
1134 | "metadata": {
1135 | "kernelspec": {
1136 | "display_name": "Python 3 (ipykernel)",
1137 | "language": "python",
1138 | "name": "python3"
1139 | },
1140 | "language_info": {
1141 | "codemirror_mode": {
1142 | "name": "ipython",
1143 | "version": 3
1144 | },
1145 | "file_extension": ".py",
1146 | "mimetype": "text/x-python",
1147 | "name": "python",
1148 | "nbconvert_exporter": "python",
1149 | "pygments_lexer": "ipython3",
1150 | "version": "3.9.7"
1151 | }
1152 | },
1153 | "nbformat": 4,
1154 | "nbformat_minor": 5
1155 | }
1156 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.20.3
2 | pandas==1.3.4
3 | scipy==1.7.1
4 | scikit-learn==0.24.2
5 | jupyterlab==3.2.1
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apalladi/feature-selection-adding-noise/88e1c418f9e456986914d01a90534074b90e98f8/src/__init__.py
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | """This module produces toy dataset to be used for regression problems"""
2 |
3 | import pandas as pd
4 | from sklearn.datasets import make_regression
5 |
6 |
def create_regression_data(random_state=None):
    """Produce a toy dataset to be used in a regression problem.

    Parameters:
    - random_state: optional seed forwarded to make_regression, so the
      dataset can be reproduced exactly; default None keeps the previous
      (non-deterministic) behaviour

    Return:
    - features: DataFrame of shape (1000, 300), columns named col_0 ... col_299
    - labels: DataFrame of shape (1000, 1), single column named "labels"
    """

    features, labels = make_regression(
        n_samples=1000,
        n_features=300,
        n_informative=20,
        n_targets=1,
        noise=100,
        random_state=random_state,
    )

    col_names = ["col_" + str(i) for i in range(features.shape[1])]
    features = pd.DataFrame(features, columns=col_names)
    labels = pd.DataFrame(labels, columns=["labels"])
    return features, labels
19 |
--------------------------------------------------------------------------------
/src/ml.py:
--------------------------------------------------------------------------------
1 | """This module contains the function to perform the feature selection,
2 | by adding random noise"""
3 |
4 | from typing import Tuple, List, Optional
5 | import numpy as np
6 | import pandas as pd
7 | from sklearn.preprocessing import StandardScaler, MinMaxScaler
8 | from sklearn.model_selection import KFold, train_test_split
9 | from sklearn.base import BaseEstimator
10 |
11 |
def train_evaluate_model(
    x_train: pd.DataFrame,
    y_train: pd.DataFrame,
    x_test: pd.DataFrame,
    y_test: pd.DataFrame,
    model: BaseEstimator,
    scaler_type: str,
    verbose: bool,
) -> BaseEstimator:
    """It trains and evaluates the machine learning model.

    Parameters:
    - x_train: training features
    - y_train: training labels
    - x_test: test features
    - y_test: test labels
    - model: a scikit-learn machine learning (untrained) model
    - scaler_type: choose between "StandardScaler" or "MinMaxScaler"
    - verbose: True or False to tune the level of verbosity

    Return:
    - the trained model

    Raises:
    - ValueError: if scaler_type is not one of the supported scaler names
    """

    # scale data; the scaler is fit on the training set only, so no
    # information from the test set leaks into the transformation
    if scaler_type == "StandardScaler":
        scaler = StandardScaler()
    elif scaler_type == "MinMaxScaler":
        scaler = MinMaxScaler()
    else:
        raise ValueError(
            "Allowed values for scaler_type are StandardScaler and MinMaxScaler"
        )

    x_train = scaler.fit_transform(x_train)
    x_test = scaler.transform(x_test)

    # fit model and report the train/test scores (model.score semantics
    # depend on the estimator; for regressors this is R^2)
    if verbose:
        print("Fitting the model with", x_train.shape[1], "features")
    model.fit(x_train, y_train)
    train_score = round(model.score(x_train, y_train), 4)
    test_score = round(model.score(x_test, y_test), 4)

    if verbose:
        print("Train score", train_score)
        print("Test score", test_score)

    return model
61 |
62 |
def get_feature_importances(
    trained_model: BaseEstimator, column_names: List[str]
) -> pd.DataFrame:
    """It computes the features importance, given a trained model.

    Parameters:
    - trained_model: a scikit-learn ML trained model
    - column_names: the name of the columns associated to the features

    Return:
    - a DataFrame containing the feature importance (not sorted) as column and
      the name of the features as index

    Raises:
    - ValueError: if the model exposes neither coef_ nor feature_importances_
    """

    # linear models expose coef_, tree ensembles expose feature_importances_;
    # coef_ is checked first, matching linear estimators preferentially
    for importance_attr in ("coef_", "feature_importances_"):
        if hasattr(trained_model, importance_attr):
            return pd.DataFrame(
                getattr(trained_model, importance_attr), index=column_names
            )

    raise ValueError("Could not retrieve the feature importance")
88 |
89 |
def compute_mean_coefficients(df_coefs: pd.DataFrame) -> pd.DataFrame:
    """It computes the average coefficients, given a DataFrame with multiple columns.

    Parameters:
    - df_coefs: a DataFrame with coefficients obtained in multiple trainings
      (one column per training cycle, features on the index)

    Return:
    - a DataFrame sorted by decreasing importance, with two columns:
      "Feature importance" (absolute value of the average coefficient)
      and "Feature name" (taken from the input index)
    """

    if df_coefs.shape[1] > 1:
        # average across training cycles before taking the absolute value
        df_coef = pd.DataFrame(df_coefs.mean(axis=1), columns=["Feature importance"])
    else:
        # single training cycle: keep the only column as-is
        df_coef = pd.DataFrame(df_coefs.iloc[:, 0], index=df_coefs.index)
        df_coef.columns = ["Feature importance"]

    df_coef["Feature importance"] = np.abs(df_coef["Feature importance"])
    df_coef["Feature name"] = df_coef.index
    df_coef = df_coef.sort_values("Feature importance", ascending=False)
    df_coef.reset_index(inplace=True, drop=True)

    return df_coef
113 |
114 |
def select_relevant_features(
    df_coef: pd.DataFrame, features: pd.DataFrame, verbose: bool
) -> pd.DataFrame:
    """It computes the relevant features, given the DataFrame with feature importance
    and the original features.
    This is obtained by adding a feature with random noise.

    Parameters:
    - df_coef: the DataFrame with the feature importance, sorted by decreasing
      importance and containing a "random_feature" row used as threshold
    - features: the original features (including the random_feature column)
    - verbose: True or False to tune the level of verbosity

    Return:
    - the simplified dataset, with the relevant features
    """

    # every feature ranked above the injected noise column is kept;
    # everything at or below the noise is considered irrelevant
    threshold_position = df_coef.index[df_coef["Feature name"] == "random_feature"][0]
    relevant_names = df_coef["Feature name"].iloc[:threshold_position]

    if verbose:
        print(
            "Selected", len(relevant_names), "features out of", features.shape[1] - 1
        )

    # keep only the relevant columns of the original dataset
    return features.loc[:, relevant_names]
147 |
148 |
def generate_kfold_data(
    features: pd.DataFrame, labels: pd.DataFrame, random_state: int
) -> Tuple[List, List, List, List]:
    """It splits the data into training and validation,
    by using the KFold splitting method (5 shuffled folds).

    Parameters:
    - features: the matrix with features, commonly called X
    - labels: the vector with labels, commonly called y
    - random_state: seed controlling the shuffling of the folds

    Return:
    - four lists (train features, train labels, test features, test labels),
      with one entry per fold
    """

    x_trains = []
    y_trains = []
    x_tests = []
    y_tests = []

    k_fold = KFold(n_splits=5, random_state=random_state, shuffle=True)
    for train_index, test_index in k_fold.split(features):
        # train data
        x_trains.append(features.iloc[train_index, :])
        y_trains.append(labels.iloc[train_index])
        # test data
        x_tests.append(features.iloc[test_index, :])
        y_tests.append(labels.iloc[test_index])

    return x_trains, y_trains, x_tests, y_tests
179 |
180 |
def train_with_kfold_splitting(
    features: pd.DataFrame,
    labels: pd.DataFrame,
    model: BaseEstimator,
    scaler_type: str,
    verbose: bool,
    random_state: int,
) -> pd.DataFrame:
    """It trains the model using the kfold splitting and returns
    a DataFrame with the feature importance.

    Parameters:
    - features: the matrix with features, commonly called X
    - labels: the vector with labels, commonly called y
    - model: an untrained scikit-learn model
    - scaler_type: choose between "StandardScaler" or "MinMaxScaler"
    - verbose: True or False to tune the level of verbosity
    - random_state: select the random state of the train/test splitting process

    Return:
    - a DataFrame with one column per fold ("cycle_1" ... "cycle_5"),
      containing the features importance (or the coefficients)
    """

    # create train-test data
    x_trains, y_trains, x_tests, y_tests = generate_kfold_data(
        features, labels, random_state
    )

    # NOTE: the same model object is re-fitted on every fold; the coefficients
    # of each fold are collected before the next fit overwrites them
    for i, (x_tr, y_tr, x_te, y_te) in enumerate(
        zip(x_trains, y_trains, x_tests, y_tests)
    ):
        trained_model = train_evaluate_model(
            x_tr, y_tr, x_te, y_te, model, scaler_type, verbose
        )
        fold_importances = get_feature_importances(trained_model, x_tr.columns)
        if i == 0:
            df_coefs = fold_importances
            df_coefs.columns = ["cycle_" + str(i + 1)]
        else:
            df_coefs["cycle_" + str(i + 1)] = fold_importances

    df_coef = compute_mean_coefficients(df_coefs)
    return df_coef
229 |
230 |
def train_with_simple_splitting(
    features: pd.DataFrame,
    labels: pd.DataFrame,
    model: BaseEstimator,
    scaler_type: str,
    verbose: bool,
    random_state: int,
) -> pd.DataFrame:
    """It trains the model using the train/test splitting and returns
    a DataFrame with the feature importance.

    Parameters:
    - features: the matrix with features, commonly called X
    - labels: the vector with labels, commonly called y
    - model: an untrained scikit-learn model
    - scaler_type: choose between "StandardScaler" or "MinMaxScaler"
    - verbose: True or False to tune the level of verbosity
    - random_state: select the random state of the train/test splitting process

    Return:
    - a DataFrame with one column, containing the features importance (or the coefficients)
    """

    # create train-test data (single 80%/20% split)
    x_train, x_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.2, random_state=random_state
    )

    trained_model = train_evaluate_model(
        x_train, y_train, x_test, y_test, model, scaler_type, verbose
    )
    df_coefs = get_feature_importances(trained_model, x_train.columns)

    df_coef = compute_mean_coefficients(df_coefs)

    return df_coef
267 |
268 |
def scan_features_pipeline(
    features: pd.DataFrame,
    labels: pd.DataFrame,
    model: BaseEstimator,
    splitting_type: str,
    verbose: bool,
    random_state: int,
    noise_type: str,
) -> pd.DataFrame:
    """This pipeline performs various operations:
    - train and evaluate the model
    - generates the DataFrame with the feature importance
    - computes the simplified dataset, containing only the relevant features

    Parameters:
    - features: the matrix with features, commonly called X
    - labels: the vector with labels, commonly called y
    - model: an untrained scikit-learn model
    - splitting_type: choose between "simple" (80% train, 20% test)
      or "kfold" (5-fold splitting)
    - verbose: True or False to tune the level of verbosity
    - random_state: select the random state of the train/test splitting process
    - noise_type: choose between "gaussian" noise or "random" (flat) noise

    Return:
    - the simplified dataset, containing only the most relevant features
    """

    # work on a copy so the caller's DataFrame is never mutated
    x_new = features.copy(deep=True)

    # inject the synthetic noise column used as the relevance threshold;
    # the scaler is matched to the noise distribution
    if noise_type == "gaussian":
        noise_column = np.random.normal(0, 1, size=len(x_new))
        scaler_type = "StandardScaler"
    elif noise_type == "random":
        noise_column = np.random.rand(len(x_new))
        scaler_type = "MinMaxScaler"
    else:
        raise ValueError("Allowed values for noise_type are gaussian and random")
    x_new["random_feature"] = noise_column

    # dispatch to the requested training strategy
    trainers = {
        "kfold": train_with_kfold_splitting,
        "simple": train_with_simple_splitting,
    }
    if splitting_type not in trainers:
        raise ValueError("Choice not recognized. Possible choices are kfold or simple")
    df_coef = trainers[splitting_type](
        x_new, labels, model, scaler_type, verbose, random_state
    )

    return select_relevant_features(df_coef, x_new, verbose)
323 |
324 |
def get_relevant_features(
    features: pd.DataFrame,
    labels: pd.DataFrame,
    model: BaseEstimator,
    splitting_type: str,
    epochs: int,
    patience: int,
    noise_type: str = "gaussian",
    verbose: bool = True,
    filename_output: Optional[str] = None,
    random_state: int = 42,
) -> pd.DataFrame:
    """This function performs multiple cycles to reduce the dimension of the dataset.

    Parameters:
    - features: the matrix with features, commonly called X
    - labels: the vector with labels, commonly called y
    - model: an untrained scikit-learn model
    - splitting_type: choose between "simple" (80% train, 20% test)
      or "kfold" (5-fold splitting)
    - epochs: the number of epochs (or cycles)
    - patience: the number of cycles of non-improvement to wait before stopping
      the execution of the code
    - noise_type: choose between "gaussian" noise or "random" (flat) noise
    - verbose: True or False, to tune the level of verbosity
      (with the default True the progress messages are printed, as before)
    - filename_output: name of the simplified dataset if you want to export it, default is None
    - random_state: select the random seed

    Return:
    - the dataset simplified after multiple epochs of feature selection
    """

    x_new = features.copy(deep=True)
    counter_patience = 0
    epoch = 0

    # seed the global NumPy RNG so both the injected noise columns and the
    # per-epoch splitting seeds are reproducible across runs
    np.random.seed(random_state)
    random_states = np.random.randint(1, int(10 * epochs), size=epochs)

    while (counter_patience < patience) and (epoch < epochs):
        n_features_before = x_new.shape[1]
        if verbose:
            print("=====================EPOCH", epoch + 1, "=====================")
        x_new = scan_features_pipeline(
            x_new,
            labels,
            model,
            splitting_type,
            verbose,
            random_states[epoch],
            noise_type,
        )
        n_features_after = x_new.shape[1]

        # an epoch "improves" only when at least one feature was dropped
        if n_features_before == n_features_after:
            counter_patience += 1
            if verbose:
                print(
                    "The feature selection did not improve in the last",
                    counter_patience,
                    "epochs",
                )
        else:
            counter_patience = 0

        epoch += 1

    if filename_output is not None:
        x_new.to_csv(filename_output, index=False)

    return x_new
394 |
--------------------------------------------------------------------------------
/tests/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from sklearn.linear_model import Lasso
4 | from sklearn.ensemble import RandomForestRegressor
5 | from src.data import create_regression_data
6 | from src.ml import (
7 | train_evaluate_model,
8 | get_feature_importances,
9 | compute_mean_coefficients,
10 | select_relevant_features,
11 | generate_kfold_data,
12 | train_with_kfold_splitting,
13 | train_with_simple_splitting,
14 | scan_features_pipeline,
15 | get_relevant_features,
16 | )
17 |
18 |
def test_data_creation():
    """Check the dimensions of the generated toy regression dataset."""
    toy_features, toy_labels = create_regression_data()
    assert toy_features.shape == (1000, 300), "Shape of features is wrong"
    assert toy_labels.shape == (1000, 1), "Shape of labels is wrong"
23 |
24 |
def test_train_evaluate():
    """A Lasso trained through the helper must expose one coefficient per feature."""
    train_features = np.random.rand(100, 10)
    train_targets = np.random.rand(100, 1)
    test_features = np.random.rand(20, 10)
    test_targets = np.random.rand(20, 1)
    fitted = train_evaluate_model(
        train_features,
        train_targets,
        test_features,
        test_targets,
        Lasso(),
        scaler_type="StandardScaler",
        verbose=False,
    )
    assert (
        len(fitted.coef_) == train_features.shape[1]
    ), "The model is not trained properly"
43 |
44 |
def test_get_features_importance():
    """Both linear and tree-based models must yield a per-feature importance table."""
    x_train = np.random.rand(100, 10)
    y_train = np.random.rand(100)
    column_names = np.arange(0, x_train.shape[1])

    # exercise both the coef_ path (Lasso) and the
    # feature_importances_ path (RandomForestRegressor)
    for estimator in (Lasso(), RandomForestRegressor()):
        estimator.fit(x_train, y_train)
        df_coef = get_feature_importances(estimator, column_names)
        assert (
            type(df_coef) == pd.DataFrame
        ), "The table with feature importance must be a DataFrame"
        assert (
            len(df_coef) == x_train.shape[1]
        ), "The number of coefficients does not match the shape of the training data"
69 |
70 |
def test_mean_coefficients_single_column():
    """A one-column coefficient table must come back sorted by absolute value."""
    raw_values = np.random.randint(-100, 100, size=20)
    table = pd.DataFrame(raw_values, index=np.arange(0, len(raw_values)))
    expected = np.sort(np.abs(raw_values))[::-1]
    sorted_table = compute_mean_coefficients(table)
    assert all(
        np.array(sorted_table["Feature importance"]) == expected
    ), "Feature importances are not sorted properly"
79 |
80 |
def test_mean_coefficients_multiple_columns():
    """Multi-column coefficients must be averaged, then sorted by absolute value."""
    raw_values = 2 * np.random.rand(100, 5) - 1
    table = pd.DataFrame(raw_values, index=np.arange(0, len(raw_values)))
    expected = np.sort(np.abs(table.mean(axis=1)))[::-1]
    sorted_table = compute_mean_coefficients(table)
    assert all(
        np.array(sorted_table["Feature importance"]) == expected
    ), "Feature importances are not sorted properly"
89 |
90 |
def test_select_relevant_features():
    """Only the columns ranked above the random feature must survive."""
    importance_table = pd.DataFrame([5, 4, 3, 2, 1], columns=["Feature importance"])
    importance_table["Feature name"] = [
        "col1",
        "col2",
        "random_feature",
        "col3",
        "col4",
    ]
    data = pd.DataFrame(
        np.random.rand(10, 5),
        columns=["col1", "col2", "random_feature", "col3", "col4"],
    )
    kept = select_relevant_features(importance_table, data, verbose=True)
    assert all(kept.columns == ["col1", "col2"]), "Wrong columns selected"
100 |
101 |
def test_kfold_splitting():
    """Each of the four outputs of the k-fold splitting must contain 5 folds."""
    data = pd.DataFrame(np.random.rand(100, 10))
    targets = pd.DataFrame(np.random.rand(100))
    x_trains, y_trains, x_tests, y_tests = generate_kfold_data(
        data, targets, random_state=42
    )
    for folds, message in (
        (x_trains, "Length of train features is wrong"),
        (x_tests, "Length of test features is wrong"),
        (y_trains, "Length of train labels is wrong"),
        (y_tests, "Length of test labels is wrong"),
    ):
        assert len(folds) == 5, message
112 |
113 |
def test_train_kfold_splitting():
    """K-fold training must return one importance row per feature."""
    data = pd.DataFrame(np.random.rand(100, 10))
    targets = pd.DataFrame(np.random.rand(100))
    coef_table = train_with_kfold_splitting(
        data,
        targets,
        Lasso(),
        scaler_type="StandardScaler",
        verbose=True,
        random_state=42,
    )
    assert type(coef_table) == pd.DataFrame, "df_coef must be a Pandas DataFrame"
    assert (
        len(coef_table) == data.shape[1]
    ), "The length of df_coef must match the number of features"
130 |
131 |
def test_train_simple_splitting():
    """Simple-split training must return one importance row per feature."""
    data = pd.DataFrame(np.random.rand(100, 10))
    targets = pd.DataFrame(np.random.rand(100))
    coef_table = train_with_simple_splitting(
        data,
        targets,
        Lasso(),
        scaler_type="MinMaxScaler",
        verbose=True,
        random_state=42,
    )
    assert type(coef_table) == pd.DataFrame, "df_coef must be a Pandas DataFrame"
    assert (
        len(coef_table) == data.shape[1]
    ), "The length of df_coef must match the number of features"
148 |
149 |
def test_scan_feature_pipeline():
    """The pipeline must drop features for both splitting/noise combinations."""
    features, labels = create_regression_data()
    model = Lasso()
    # same two configurations as before: simple+gaussian, then kfold+random
    for splitting, noise in (("simple", "gaussian"), ("kfold", "random")):
        reduced_features = scan_features_pipeline(
            features,
            labels,
            model,
            splitting_type=splitting,
            verbose=False,
            random_state=43,
            noise_type=noise,
        )
        assert (
            reduced_features.shape[1] < features.shape[1]
        ), "The pipeline did not reduce the number of features"
178 |
179 |
def test_get_relevant_features():
    """Repeated feature selection must shrink the dataset for both splittings."""
    features, labels = create_regression_data()
    model = Lasso()

    for splitting in ("simple", "kfold"):
        x_new = get_relevant_features(
            features,
            labels,
            model,
            splitting_type=splitting,
            epochs=10,
            patience=5,
            random_state=41,
        )
        assert (
            x_new.shape[1] < features.shape[1]
        ), "The pipeline did not reduce the number of features"
211 |
--------------------------------------------------------------------------------