├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── toy-regression-features.csv
│   └── toy-regression-labels.csv
├── example.ipynb
├── requirements.txt
├── results
│   └── reduced_dataset.csv
├── src
│   ├── __init__.py
│   ├── data.py
│   └── ml.py
└── tests
    └── test.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | 163 |
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 apalladi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # feature-selection-adding-noise 2 | 3 | ## Introduction 4 | The purpose of this small library is to apply feature selection to your high-dimensional data. In order to do that, the following steps are applied: 5 | 1) the input features are automatically standardized 6 | 2) a column containing Gaussian noise (mean = 0, standard deviation = 1) is added 7 | 3) a model is trained 8 | 4) the feature importance is evaluated 9 | 5) all the features that are less important than the random one are excluded 10 | These steps are repeated iteratively until the algorithm converges. 11 | 12 | ## Initialize the repository 13 | Let us start by cloning the repository with the following command: 14 | ``` 15 | git clone git@github.com:apalladi/feature-selection-adding-noise.git 16 | ``` 17 | Then you need to install the dependencies. I suggest creating a virtual environment as follows: 18 | ``` 19 | python3 -m venv .env 20 | source .env/bin/activate 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | To check that everything works, you can run the unit tests: 25 | ``` 26 | python -m pytest tests/test.py 27 | ``` 28 | 29 | You are now ready to use the repository! 30 | 31 | ## Getting started 32 | The example you need is contained in [this notebook](example.ipynb). 33 | It imports a toy dataset for building a regression model. 34 | Then we import the function `get_relevant_features`: 35 | ``` 36 | from src.ml import get_relevant_features 37 | ``` 38 | This function takes as arguments: 39 | - `features` 40 | - `labels` 41 | - `model`, a scikit-learn model 42 | - `epochs`, the number of epochs (i.e. how many cycles of recursive feature selection you want to apply) 43 | - `patience`, the number of epochs without any improvement in the feature selection before the process stops (the idea is similar to early stopping in TensorFlow/Keras) 44 | - `splitting_type`, either `simple` (for a simple train/test split) or `kfold` (for 5-fold splitting). If you choose `kfold`, the feature importance is computed as the average feature importance over the train/test subsets. 45 | - `noise_type`, either `gaussian` for Gaussian noise or `random` for flat (uniform) random noise 46 | - `filename_output`, a string indicating where to save the reduced dataset; you can also choose `None` if you do not want to save it 47 | - `random_state`, the random seed used by the k-fold splitting 48 | 49 | The function `get_relevant_features` returns a DataFrame with a reduced dataset, i.e. a dataset that contains only the most important features, as in the examples below. 50 | 51 |
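For instance, here is a minimal usage sketch that mirrors the call in [this notebook](example.ipynb); the `Lasso` model and all the parameter values are illustrative choices taken from that example:
```
import pandas as pd
from sklearn.linear_model import Lasso
from src.ml import get_relevant_features

# Load the toy regression data shipped with the repository
X = pd.read_csv("data/toy-regression-features.csv")
y = pd.read_csv("data/toy-regression-labels.csv")

# Iteratively drop every feature that ranks below the injected noise column
X_reduced = get_relevant_features(
    X,
    y,
    model=Lasso(alpha=1),
    epochs=100,
    patience=10,
    splitting_type="kfold",
    noise_type="gaussian",
    filename_output="results/reduced_dataset.csv",
    random_state=42,
)
print(X_reduced.shape)
```
Conceptually, each epoch performs the five steps listed in the introduction. The sketch below only illustrates that idea under stated assumptions; it is not the library's actual code (which lives in `src/ml.py` and also handles k-fold averaging and flat noise). It assumes a model exposing a `coef_` attribute, and the function name is hypothetical:
```
import numpy as np

def one_selection_round(features, labels, model, rng):
    """Illustrative single round of noise-based feature selection."""
    # 1) standardize the input features
    X = (features - features.mean()) / features.std()
    # 2) add a column of Gaussian noise (mean = 0, standard deviation = 1)
    X["noise"] = rng.normal(0.0, 1.0, size=len(X))
    # 3) train the model
    model.fit(X, np.ravel(labels))
    # 4) evaluate the feature importance (here: absolute coefficients)
    importance = np.abs(model.coef_)
    # 5) keep only the features that are more important than the noise column
    threshold = importance[X.columns.get_loc("noise")]
    keep = [col for col, imp in zip(X.columns, importance)
            if col != "noise" and imp > threshold]
    return features[keep]
```
When `filename_output` is a path, as in the first snippet, the reduced dataset is also written to disk; `results/reduced_dataset.csv` is the file shipped in this repository's `results` folder.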
-------------------------------------------------------------------------------- /data/toy-regression-labels.csv: -------------------------------------------------------------------------------- 1 | labels 2 | 219.66421110537377 3 | 40.99812573852436 4 | 109.42385821856618 5 | -280.66750675702076 6 | 599.1785378315125 7 | 210.84346059574347 8 | 283.00445948218146 9 | -30.12438553925481 10 | -59.81206007271604 11 | -420.65737435194205 12 | -197.5390398130957 13 | 45.09427291240982 14 | -473.5937200217304 15 | 99.05970861007242 16 | -306.0419398927675 17 | -79.5975539814209 18 | -162.12424764307752 19 | 78.67863470667349 20 | 408.7254953097355 21 | 123.84731009652106 22 | -102.12843400269146 23 | -17.312519366997122 24 | 287.8622440339944 25 | 134.2825735118363 26 | -288.2538933803337 27 | -294.31645884698355 28 | 93.5402459089359 29 | -165.75915298963588 30 | 187.3625488233377 31 | 477.24364434704654 32 | -160.02519426852024 33 | -369.07047587223315 34 | 124.76128030649855 35 | -168.9598361723375 36 | 183.84951935491583 37 | 346.93821420244996 38 | -146.9363708507223 39 | -475.7711780150155 40 | 164.3657272962622 41 | -265.0171836664488 42 | -593.7346139201999 43 | -209.7491908447602 44 | -342.35365549545133 45 | -68.77410945324183 46 | 48.28451867992422 47 | -258.396355994631 48 | 39.1283861478949 49 | -420.94219859643 50 | 143.72124261476043 51 | 275.85123174279823 52 | 67.1920188813163 53 | 226.3593163855234 54 | -99.55328243673284 55 | -408.20911939582174 56 | -99.08091797152437 57 | -481.2436653525199 58 | -2.9292434308502493 59 | -50.57606534287761 60 | 175.3610172710227 61 | 220.4375232785938 62 | 202.3942278357318 63 | -113.57111937095235 64 | 232.3565587371829 65 | -162.54704358780953 66 | -57.130109705320024 67 | 441.03614555634294 68 | -67.02651224239445 69 |
-509.34070453705914 70 | -27.870055104524937 71 | -229.53662143572663 72 | -250.14024664163387 73 | -366.4372342171117 74 | 46.518804179761176 75 | 116.92043101176571 76 | 18.820544960998525 77 | 210.30034026958214 78 | 478.5477878660953 79 | 417.2844739466522 80 | 102.39941584413236 81 | -978.0120573994108 82 | -308.8451209548697 83 | -213.5210097425842 84 | 261.4645116037161 85 | 36.27066302849197 86 | -6.669884127949885 87 | 358.608687862422 88 | -226.98071759699877 89 | 74.74585810340724 90 | 8.740405469897468 91 | -57.65180347772909 92 | 345.68698072536876 93 | 323.44218805538713 94 | -238.43075811367788 95 | 344.57810996390094 96 | -353.15252408200615 97 | 311.8043731963006 98 | 301.9183965858471 99 | 352.5979167661178 100 | 110.00390089019643 101 | 246.01921125336307 102 | 677.599200798106 103 | 711.4843824206637 104 | 191.7277911019544 105 | 57.99259068049659 106 | -383.0990013737419 107 | 334.4620955849534 108 | -288.13816332058644 109 | 161.98906727450884 110 | -256.52604740772733 111 | 33.02783772330004 112 | -369.31457330854863 113 | 282.8980069994969 114 | -108.07987335065204 115 | 181.88867152135373 116 | -463.355389084535 117 | -57.16858021828898 118 | 302.345137705743 119 | -18.689778293238987 120 | -325.1782919210175 121 | -130.03126703705072 122 | -45.75588839018792 123 | 95.40621723970422 124 | -479.6015968105236 125 | 227.22243128108795 126 | -30.991358685988317 127 | -142.9779163301373 128 | -893.830818650887 129 | 131.59725557684942 130 | -208.08364950440514 131 | -50.11918151891862 132 | -129.82418090113322 133 | -204.7782679241685 134 | -365.2291784981545 135 | 958.0195550658315 136 | 91.03634732414339 137 | 668.3293179922772 138 | 137.33931696932729 139 | 156.45819823550076 140 | 46.78787104054108 141 | -73.0611991287546 142 | -16.079112537893373 143 | -332.9179458827927 144 | 12.608764116640717 145 | 265.29304542408283 146 | -139.21545955687984 147 | 110.2828069067367 148 | -44.607068712910745 149 | 16.492497282475313 150 | 317.5265009381611 151 | -401.0451689065083 152 | 12.067463582747962 153 | -83.09755944022235 154 | -11.42361671446308 155 | 355.4587444092366 156 | -278.8705245321509 157 | 259.0809422276131 158 | 131.9797904814068 159 | 458.8901234460045 160 | -201.37984822699846 161 | 77.89296354228058 162 | -371.78620712581846 163 | 303.19151387677914 164 | 295.70743716021695 165 | 244.792621765663 166 | -151.51871438114455 167 | 375.0709213146323 168 | 406.5079040014588 169 | 286.7800957483075 170 | -282.5277096305248 171 | -420.2630572029418 172 | -413.618662390845 173 | 216.3436814539506 174 | -683.6027897802443 175 | 92.93560643486715 176 | 44.70344795422818 177 | 343.9726303950285 178 | 66.1754662656133 179 | -29.590210811428733 180 | -3.368565792046752 181 | -315.1923255488523 182 | 412.37831319704844 183 | 414.65900678313886 184 | 30.541122590840843 185 | -238.63012467202824 186 | -96.84336054466398 187 | 300.1892887346487 188 | -334.54217813498695 189 | 27.99608296508022 190 | -61.77129204337399 191 | 14.942188777365573 192 | -525.7326455722985 193 | 112.7322750639299 194 | 364.84216999578814 195 | -295.65563357735243 196 | 120.74118061033646 197 | -313.03977941675026 198 | -145.45345902728027 199 | 11.797154900881168 200 | -61.810344123332534 201 | 74.26605490724862 202 | -309.00266961667126 203 | -530.0828723263916 204 | 418.57140590307944 205 | -41.74482714115291 206 | 121.41684815273871 207 | 405.2362104672736 208 | -171.9970231255059 209 | 88.2797601652037 210 | 512.6155407924663 211 | 269.54406471209035 212 | -124.36335410913836 213 | 
-324.12015696729463 214 | 246.2544716090739 215 | 434.7036044660102 216 | -223.69334812177925 217 | 463.5430305770866 218 | 359.6186097903139 219 | 396.58142510788787 220 | 738.2674707542574 221 | 180.06500962643116 222 | -288.23528134988373 223 | -266.16857094666966 224 | -577.184257189553 225 | -480.8624043453125 226 | 112.89469666560275 227 | -60.85487516021421 228 | 442.58742035798167 229 | 248.77945846601668 230 | -174.53520141553088 231 | 562.0106297094923 232 | 396.1942647182949 233 | 199.06894198875804 234 | 252.71478393513448 235 | -146.93929247269597 236 | 191.7323273679387 237 | -45.2958727095224 238 | -427.21687468532275 239 | 12.461349601654678 240 | 422.30717448682867 241 | 233.6751650938761 242 | 563.0275372062894 243 | 433.30058799142154 244 | 108.30182641189415 245 | -218.02224784442905 246 | -183.4040116108224 247 | -182.92255119924062 248 | 392.30028867140544 249 | 3.7131965707191625 250 | 497.99579980378456 251 | 212.01834201356186 252 | 715.4649248810988 253 | 519.4609797102954 254 | -26.989752202347162 255 | 169.2888596594962 256 | 18.92806454001243 257 | -537.2461912164638 258 | -326.86335003591523 259 | -286.22485161024184 260 | 39.01717391310153 261 | -299.96850412900557 262 | -174.72281041255383 263 | -329.74061064784985 264 | 87.28903183163771 265 | 337.6511516443617 266 | 62.49194262923407 267 | 143.18118441638228 268 | -73.79382938901495 269 | 535.2243122098414 270 | 175.81395911870072 271 | -193.80398326330058 272 | 146.98895958839728 273 | 670.1157889736027 274 | -335.4876878063802 275 | -198.371912719899 276 | 314.2386412193764 277 | 168.8034824012617 278 | 306.6286193162233 279 | -367.7065829404182 280 | 498.65778038228956 281 | 56.676911072786595 282 | 298.9089820684999 283 | 375.2550323657371 284 | 158.16757721904568 285 | 200.4595688879491 286 | 509.27615081515336 287 | 141.71944355481935 288 | -358.73494550303235 289 | -141.46508942485553 290 | 200.38965431527947 291 | 639.9028651997666 292 | 93.07027597756083 293 | -146.52030808194593 294 | 144.75082163123142 295 | -111.83320853763635 296 | -50.37950192580972 297 | 501.19660261519607 298 | -67.44465712983656 299 | 534.1826610568827 300 | -160.7425747770929 301 | 85.73377964948206 302 | -96.47360982691012 303 | -274.9768131633041 304 | -412.5949823402696 305 | 477.4320851906831 306 | 476.16694166567004 307 | 53.43186613335136 308 | 187.87089203687583 309 | -205.79927102995535 310 | -237.13510435758218 311 | 36.3208198702175 312 | 207.94503006838846 313 | -88.9914384510391 314 | 186.6920918619557 315 | 86.53458018538245 316 | 199.05376088951718 317 | 50.04826687442488 318 | -205.8351854633912 319 | 289.4862395862873 320 | 78.71216376430846 321 | -462.4158964137397 322 | -76.2541187285045 323 | -402.35782297263756 324 | -877.2417176434378 325 | 57.554397302157 326 | -373.5605273471517 327 | 35.245880980242866 328 | -296.7834630772414 329 | -178.4838476097142 330 | 111.24982505285713 331 | -96.79190759717235 332 | -477.4648890058636 333 | 445.67187652053417 334 | -79.15399783119454 335 | 116.88235533395297 336 | 140.65941231630532 337 | 318.7863278451861 338 | -142.36798735269005 339 | -452.34506451028284 340 | 14.033099980482014 341 | -478.32528351098404 342 | 62.32621657690383 343 | -85.67661762917075 344 | 23.88662281891239 345 | 83.5806484692816 346 | -119.7489839281694 347 | -13.711441029105597 348 | -142.0344353209897 349 | 195.61770088578427 350 | 83.3791706398313 351 | -207.13874933383445 352 | -405.67021956846685 353 | -232.55675706729517 354 | 369.80403435973915 355 | 36.919250572105284 356 | 
671.7929214211165 357 | -118.61369183668698 358 | 3.750548887927671 359 | -37.75545984466757 360 | -161.4348670843344 361 | -17.961521195049784 362 | 434.77540556488526 363 | -215.0279279255545 364 | 279.14136475142726 365 | -335.9299339463031 366 | 275.1685178819571 367 | -155.30757260332908 368 | -21.696019192629535 369 | -15.924744454492455 370 | 43.906253087008835 371 | 337.7686374470297 372 | -211.05341310481703 373 | -12.746175500056381 374 | 311.6254368113914 375 | 260.4555125568561 376 | 29.82725104023927 377 | 812.7399411215245 378 | 21.450446688103256 379 | 39.67367688152194 380 | -423.7973802453141 381 | -547.631683256067 382 | -102.93782119795702 383 | 74.76674346877166 384 | 107.52179751282722 385 | 280.4643452488732 386 | 174.61066521970474 387 | 101.23715037872829 388 | 396.05923966670605 389 | -377.947729330531 390 | -164.04101178001133 391 | 61.173919708742005 392 | 192.76588108646376 393 | -318.8179933745456 394 | -521.3846419750324 395 | 258.9820100030168 396 | 419.3231880074004 397 | -149.44226140107605 398 | 98.63114384731146 399 | -283.00958979274805 400 | 206.80574392772616 401 | -471.448371385884 402 | 130.07083889411336 403 | -501.04526864600626 404 | 89.24531288037647 405 | 228.05957571684934 406 | 318.30367800303617 407 | -434.4176084063522 408 | 774.5826919616825 409 | -0.9240439714436945 410 | -226.9520935877121 411 | 386.33513606410327 412 | 370.3836005652059 413 | 548.3799787687266 414 | 108.16433254033271 415 | -276.89949378114954 416 | -173.49339913684315 417 | 62.7674938216976 418 | -524.4276129904287 419 | 169.90198539054362 420 | 58.567795121879044 421 | -290.3747761985046 422 | 438.5833753750337 423 | -321.3578626667747 424 | -80.22699823707424 425 | -244.867673307717 426 | 209.4160679033282 427 | 242.13090604931872 428 | -307.18553169522727 429 | 288.7689437895989 430 | 236.37393252502866 431 | 313.063444212058 432 | 78.54134351510749 433 | -138.6433548546298 434 | 90.01870742580228 435 | 320.52346086529633 436 | -218.48563832111998 437 | 470.88640661939144 438 | 250.13352504170456 439 | -774.7651878072118 440 | -143.62815334681963 441 | 304.7855992694973 442 | -29.218330087607676 443 | -429.8289041752295 444 | -106.37838578950459 445 | -151.85466042085102 446 | -136.38359123956027 447 | 498.8875532079957 448 | -310.20813667277804 449 | -52.68642490176664 450 | 75.2253541889834 451 | -596.8809617691197 452 | -135.78130261309687 453 | -2.973639249572358 454 | -105.76612422435545 455 | 27.68036765794641 456 | -109.46412956248764 457 | 21.70842618529365 458 | -62.75111953580587 459 | 54.85449410719809 460 | 646.883158144461 461 | -132.5690334866119 462 | 275.5876887887282 463 | 470.40560918955055 464 | -73.56141628702832 465 | -495.2509411971815 466 | 437.79739193352054 467 | 313.7208983943566 468 | 89.79983035183005 469 | 69.22106362777522 470 | -149.181549095875 471 | 121.28585336420143 472 | -257.4156559512919 473 | 147.40779256195276 474 | -827.4059274712373 475 | -126.83611731992451 476 | -155.1572846515104 477 | 397.312269072818 478 | 341.8377369066052 479 | -573.1004574202724 480 | -379.46651936485296 481 | 232.37312209668877 482 | 64.35427719644893 483 | -403.03132612200744 484 | -695.0170848613226 485 | 112.47782185996614 486 | -345.074080739759 487 | -17.48180554945972 488 | 138.58459553286718 489 | 86.0430565982474 490 | -86.70623349810324 491 | 121.13398379191366 492 | -649.0922728040513 493 | 139.6156164280343 494 | -534.6862563642604 495 | 281.87317484690914 496 | 612.3306504710739 497 | -405.43560777977365 498 | 442.58039978248786 499 | 
-254.94346887899212 500 | -88.76336873598447 501 | -633.3370404134627 502 | -288.1141069995499 503 | 366.78973999043546 504 | -224.8536442421646 505 | 49.26671000659827 506 | 165.83040781552805 507 | 230.49469903002304 508 | 1057.9904144036054 509 | 473.39259752821386 510 | -49.86707813435416 511 | -200.6655577952709 512 | -18.614254099505132 513 | -118.9877427327001 514 | -101.85548024116953 515 | -3.3007141186634996 516 | 227.1415666490299 517 | -130.88124656998087 518 | 301.7640536750737 519 | 221.79947258234054 520 | 55.85295941144099 521 | -475.15121157051595 522 | -290.21029063523025 523 | 607.3300594336403 524 | -6.972351318487355 525 | -213.26044982185797 526 | 271.41930318809136 527 | -87.87163521051338 528 | 82.85724429895416 529 | 60.13044183532068 530 | 238.71488930311492 531 | -205.30782778551102 532 | -167.97249164001585 533 | -626.6680881535467 534 | -210.71794044174294 535 | 338.4268516743314 536 | 21.327754022949932 537 | 131.11375629474173 538 | 99.63642744845114 539 | -431.4696571353565 540 | 268.53069314040255 541 | 183.7155684889983 542 | 206.1136567255354 543 | 938.485055074153 544 | -370.3122696032017 545 | 302.2829100726916 546 | -345.980519117454 547 | 51.90634308481428 548 | 100.59413868158724 549 | 258.95425833506727 550 | -354.49223246498894 551 | -188.29842177527783 552 | 334.018648080925 553 | 82.57817449668167 554 | 105.40827565012256 555 | -642.6763674904915 556 | 80.88811418501643 557 | -537.7847275519619 558 | -44.994214473374534 559 | -668.5934888064476 560 | 225.93756565254031 561 | 727.7205122969831 562 | 69.73521987107483 563 | -0.6889414780357015 564 | 60.25016581711141 565 | 361.763437322291 566 | -193.73993874298833 567 | -75.06634974409894 568 | 87.67349755652162 569 | -70.7317911894568 570 | -158.9408099969209 571 | -53.949959598525496 572 | 289.97169234846314 573 | 513.8419618454099 574 | 456.7974490024387 575 | -387.32989714442715 576 | 11.541136573054459 577 | -516.2398183426416 578 | 22.20570426318264 579 | 44.93674703405627 580 | 111.74542701010068 581 | -166.65462533937958 582 | -423.0594820821815 583 | -67.79888144006907 584 | 95.99435634966315 585 | -153.93204277170054 586 | -441.61233639970544 587 | -377.97986068534146 588 | 43.88456551556874 589 | -159.09913081414723 590 | -601.3987380926017 591 | 159.0792015242133 592 | -64.48693133155123 593 | -376.56215882758215 594 | 702.7653676049299 595 | -92.69929665954832 596 | 204.1021997084409 597 | -8.097467943855179 598 | 48.60906336352542 599 | 294.90707401500975 600 | -242.9887235405698 601 | -57.499842323521875 602 | -456.36415593043733 603 | -7.932043052156564 604 | -26.227365280070046 605 | 368.6816217166653 606 | -256.03451275320384 607 | -94.58741293806388 608 | -30.03610893222566 609 | 178.4686951578877 610 | -271.2166766746184 611 | 370.0690174105201 612 | 72.05821771129402 613 | -171.4086748354379 614 | 119.67317757590877 615 | 47.86138741039744 616 | -265.26096397236466 617 | 425.8969952679584 618 | -123.18707188929788 619 | -167.2729715423001 620 | 253.92919370621652 621 | 6.371348453025277 622 | -106.02817326063514 623 | 741.5831023881881 624 | 238.74053090627243 625 | 366.96040429489227 626 | 393.27889904186316 627 | -162.37593584076217 628 | -194.5705991304323 629 | 453.1045664785996 630 | 92.57524932430822 631 | -115.34039171552371 632 | 263.6981736022322 633 | -411.37409886920096 634 | -381.9315239427792 635 | -87.82999213046611 636 | -40.75840233833431 637 | -574.9642829681824 638 | -203.27207694459548 639 | -122.741903041147 640 | -45.81988575674249 641 | 190.31576548103914 
642 | -83.47111790485431 643 | -594.0219764302144 644 | -331.555854994611 645 | -551.2374852333826 646 | 148.7805869062618 647 | -215.18544048959663 648 | -296.58525416818475 649 | 383.39717648406804 650 | -319.99519952221743 651 | -322.55548938317884 652 | 688.4763052174203 653 | -42.797139515328055 654 | -128.11759461291382 655 | -374.23022915977714 656 | 43.74615623829504 657 | -207.0158737196524 658 | -350.7611332215587 659 | -284.1165610418166 660 | -90.74801244701638 661 | 175.22957014566754 662 | -298.4087178186445 663 | -358.3700225301559 664 | -15.945301151668161 665 | -423.47767401950034 666 | 51.633991919049144 667 | -172.36802631710177 668 | -115.72692061473215 669 | -154.41682560554972 670 | -252.19844458800276 671 | 301.13990716116524 672 | -247.8283487837687 673 | 213.34261253129188 674 | 283.4762578581023 675 | 220.15047098936952 676 | -63.64095810380125 677 | -271.4068829605406 678 | 172.2692175343363 679 | -556.5743113894846 680 | -101.92502182653703 681 | -385.27886175543796 682 | -116.67586605766246 683 | -220.27319939629527 684 | -447.6788463490998 685 | 86.42500713688688 686 | 326.30401506756243 687 | 247.6125458429959 688 | -125.25381166234219 689 | 135.92522190244517 690 | 812.9778135202583 691 | 240.70471517707932 692 | -32.89302845321356 693 | -25.069754727092985 694 | -124.8494311761978 695 | 78.1042829869844 696 | -192.97338775699205 697 | 486.3149327304531 698 | -50.88187970179066 699 | -239.1838866235216 700 | -348.31392934990356 701 | 79.13059626041633 702 | -397.6528365157774 703 | 9.260427333251801 704 | -303.56384357784856 705 | -143.0683011881344 706 | 154.97431087417888 707 | 167.8703940779693 708 | -250.37861808592208 709 | -36.73229949348913 710 | 165.60650395211113 711 | 436.6006340377123 712 | 144.98908182197334 713 | 138.68962210567366 714 | 439.6913244022819 715 | -105.56865101487321 716 | -144.6139907739925 717 | 41.46050366051495 718 | -100.38103185724285 719 | -115.9344709778836 720 | -122.55243220564434 721 | 135.74693752080393 722 | 318.02752995801427 723 | 659.0212948798771 724 | -178.9965439883482 725 | 153.70731456575817 726 | 12.561994790850235 727 | 538.2284866710995 728 | -149.14274589228876 729 | 56.63563619964414 730 | 106.55202867811536 731 | -83.95084515175216 732 | 176.37471060505206 733 | -12.396560499889546 734 | 29.014189286129238 735 | -322.8971488544742 736 | 512.7493263315325 737 | 188.19605378069392 738 | -26.532341521799662 739 | -340.8873819200601 740 | 42.375938360829224 741 | -432.2287980731304 742 | -306.8147227365162 743 | 80.56253887254563 744 | 531.8022962281085 745 | 122.22210015068346 746 | -28.711876698403643 747 | -736.6318129813221 748 | -129.23225956774758 749 | -4.192889360752608 750 | 84.53862831722492 751 | 136.43955202170602 752 | 103.04309754489225 753 | 900.5480460628117 754 | 488.18126422960006 755 | 194.6835555204844 756 | -50.652877604316025 757 | 209.9990779052415 758 | -113.52168835069503 759 | -215.41697904140096 760 | 48.07543592712512 761 | 283.94707766190487 762 | -430.01140706045237 763 | -43.82034902582119 764 | 43.09281783061093 765 | -299.3197271738013 766 | 175.75587560089198 767 | 77.95424864352577 768 | 196.45854533770043 769 | 212.90041115696633 770 | -209.5943712051902 771 | 93.15746999857956 772 | 61.03128186054663 773 | -77.99932355249459 774 | 132.37960560861092 775 | 47.58263934272867 776 | -230.0332142697132 777 | -134.74106595396276 778 | 79.37675734884638 779 | 278.36548634963424 780 | 221.96462495616566 781 | -822.880410741632 782 | 304.24573270881484 783 | -157.30162263219063 784 
| 114.76518810323809 785 | 683.7096888263625 786 | 332.10317504702385 787 | -150.9332671102602 788 | 5.780666939958898 789 | 62.28137295155727 790 | -21.88742304950337 791 | -192.86481759405876 792 | 104.78803076331286 793 | -2.4886452846441074 794 | 95.87387009329376 795 | 106.9098263477729 796 | -326.0344916581623 797 | -459.3990198880376 798 | -320.9807797173338 799 | 41.44749510332275 800 | 299.9946902476721 801 | 568.0657424201406 802 | -52.80744052754126 803 | 166.30254539587904 804 | 513.3007141852776 805 | 490.06698675010614 806 | 339.26341073253946 807 | 121.70581006253593 808 | 129.56754891837846 809 | 476.9660515715946 810 | -764.1239805671025 811 | -234.2538479919876 812 | 272.68222245244743 813 | -65.60175064220076 814 | -307.2895651112007 815 | -263.1630745143601 816 | 426.2143861529066 817 | -4.864004048687123 818 | 465.2616236349868 819 | 239.09667158297088 820 | 359.5919087000877 821 | -62.051583979771976 822 | 300.6095963012117 823 | 493.78166093247654 824 | 276.9596787772388 825 | -51.958690491403075 826 | 11.622513648479057 827 | -372.2848578850104 828 | -60.25501178999095 829 | -270.2801583332541 830 | -124.03732005035349 831 | 207.46249274383024 832 | -77.92651595271144 833 | 199.05196933469293 834 | 23.68396605251496 835 | 406.93646409719247 836 | 333.43564416626447 837 | 379.9091676365338 838 | -291.45600463045366 839 | 214.31898278938803 840 | 108.54504251233661 841 | 392.9890931378893 842 | 389.37490872620447 843 | -47.13948884817606 844 | 236.2736616453898 845 | 536.2291186335767 846 | 313.35672188644133 847 | 157.9410042327246 848 | 415.0832675205393 849 | 237.29531787190953 850 | -70.10915410006021 851 | -29.286773927063404 852 | -345.1713423345644 853 | -235.1522562252115 854 | 15.712787711489767 855 | 271.34011205379045 856 | 126.09839785627477 857 | 76.53085613644575 858 | 330.60732285513814 859 | 34.440358762550986 860 | 660.739973216025 861 | -303.05501843825823 862 | 133.75158283314812 863 | -724.343066664628 864 | 183.09812144722144 865 | -3.302592053092212 866 | -214.9021562043016 867 | -277.4859874448581 868 | 457.8643085395454 869 | -127.90289527090574 870 | 124.26463199984457 871 | -14.066153487964044 872 | 606.1254619119906 873 | 752.9105249583428 874 | 20.998367218762038 875 | -288.3982756395808 876 | 129.1460922441576 877 | -295.13280290128586 878 | -219.95543227476344 879 | 456.8799114608556 880 | 621.2093520109033 881 | -83.93690799290073 882 | 206.05472525814582 883 | 96.93661335941695 884 | 102.34345019423195 885 | 568.3362243819324 886 | -472.8872467791076 887 | 89.17749470969224 888 | -23.907146376147715 889 | -474.1992606768101 890 | -42.71255474529988 891 | -60.75882464083398 892 | -131.2080420503905 893 | 488.0404769587037 894 | -54.580161168673776 895 | -100.79516212291948 896 | -256.70422868482774 897 | 363.95303473481823 898 | 251.05877228577782 899 | -427.60908165701653 900 | 271.985574552793 901 | 32.54576913176726 902 | 242.48784889913935 903 | 196.22713157388603 904 | 496.90183801658486 905 | -674.7436315228723 906 | -21.246938994175366 907 | 750.9383308111346 908 | 184.03873828418403 909 | -87.20430335899057 910 | -512.7163323894375 911 | 213.92223372220948 912 | -214.58027422152009 913 | 602.179427584113 914 | -140.43796059460806 915 | 119.13023473128969 916 | -173.56948532405613 917 | 267.5881448430044 918 | 503.2388043031035 919 | -23.11613398276637 920 | 67.1934141382712 921 | -398.0822267265129 922 | -86.00128294019206 923 | -323.2790498924765 924 | -455.19634577646275 925 | -345.47731457905314 926 | 180.32857789932658 927 | 
439.18198130609005 928 | -58.349789948524915 929 | -51.37114972775309 930 | 358.43440398144514 931 | 34.41218090391774 932 | -310.46791180264745 933 | 30.432302966235966 934 | -289.7008373298527 935 | 121.69478729348083 936 | -80.06079925455828 937 | 341.92481591386377 938 | -256.23416416202184 939 | -130.00793340771452 940 | -510.569989749574 941 | -73.93557152974061 942 | 629.7574088812376 943 | -297.84547292519795 944 | -341.32309331391923 945 | 703.0937906274942 946 | -348.1396269485223 947 | 148.6247895162148 948 | -534.7323875010395 949 | 348.1935402945218 950 | 416.12154662501456 951 | 185.55488326379268 952 | 532.7910188918678 953 | 600.4218386644078 954 | -581.5320404118182 955 | -146.25664261076827 956 | -79.41256718592786 957 | 36.02759064730649 958 | -166.02555364748002 959 | 217.1552919966481 960 | -290.73080080056593 961 | -143.63862526707655 962 | -210.1774268490937 963 | -476.94008056005026 964 | -488.6480697893357 965 | -98.15687158225435 966 | -382.8604131254918 967 | 850.4747444559168 968 | 338.46030491126146 969 | -59.79339507552459 970 | 157.35918250239692 971 | 200.40195612609668 972 | -735.6164027108637 973 | -34.50250712620171 974 | 406.77013960209877 975 | 57.725587096903666 976 | -187.4647698018989 977 | 76.94863072652515 978 | 40.71787537182026 979 | 201.7848818691717 980 | -110.50546079155595 981 | -341.0723169408731 982 | -274.0465840801938 983 | 659.8104426557172 984 | -126.5209140985389 985 | -455.3151630042686 986 | -652.1053535876599 987 | -207.10595716904226 988 | -263.2606998263593 989 | -79.68629052372415 990 | -165.27483731283775 991 | -23.465321928129832 992 | 401.1325288416542 993 | -226.63495432426132 994 | 199.6686463904946 995 | -328.02133654428997 996 | -53.60775725091852 997 | 192.5280891686172 998 | -121.51712662816664 999 | -198.5471746260622 1000 | -62.324907025231084 1001 | 396.3846847390514 1002 |
-------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "61895906-6628-4029-b022-24dd1dbef8e5", 6 | "metadata": {}, 7 | "source": [ 8 | "# Introduction\n", 9 | "\n", 10 | "In this notebook we show how to use this library, which performs feature selection by introducing a column with Gaussian noise (mean = 0, standard deviation = 1).\n", 11 | "\n" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "5f56d2c7-5638-4830-a855-be239fe600f5", 17 | "metadata": {}, 18 | "source": [ 19 | "# Import toy regression data\n", 20 | "\n", 21 | "We import a toy regression dataset.\n", 22 | "The feature matrix consists of 1000 samples and 300 features, while the output consists of a single variable." 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "id": "0c43ee6a-358c-4b29-8a71-f657dc7ef8dd", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "import numpy as np\n", 33 | "import pandas as pd\n", 34 | "\n", 35 | "X = pd.read_csv('data/toy-regression-features.csv')\n", 36 | "y = pd.read_csv('data/toy-regression-labels.csv')" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "a0784882-b8f1-45fc-9e06-7353f7e38e98", 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "(1000, 300)\n" 50 | ] 51 | }, 52 | { 53 | "data": {
\n", 56 | "\n", 69 | "\n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | "
col_0col_1col_2col_3col_4col_5col_6col_7col_8col_9...col_290col_291col_292col_293col_294col_295col_296col_297col_298col_299
00.537070-1.1021401.6140650.4467042.1500720.022795-0.0623250.8426011.3376460.423004...-0.923334-0.4818151.203645-1.458585-0.3820550.7784461.2816700.0835260.6900470.117204
1-0.2911040.8706620.9898580.3401810.462467-0.5821481.8887601.326881-1.654321-0.130696...1.4292730.845691-1.089257-0.9183550.3940182.608926-1.485463-0.907812-0.1736600.920506
2-0.6239140.645679-0.603598-0.382241-1.0388551.036846-0.4117460.3091380.3778601.115033...-0.853996-1.977211-0.3600700.457125-1.3728040.320784-0.961563-0.2034120.9202640.799161
3-0.007280-1.1592841.205723-0.869215-0.5714660.5401960.6566390.0416610.244310-0.860549...0.646255-0.762227-0.940969-0.889827-0.534136-0.649951-0.387092-1.0898140.0549350.955872
40.578238-0.756635-0.7686361.3398860.612525-0.431343-0.0582660.975151-1.9921180.179272...0.2636580.8377350.724682-2.493489-2.108600-1.646070-0.674911-0.344457-0.771689-0.691474
\n", 219 | "

5 rows × 300 columns

\n", 220 | "
" 221 | ], 222 | "text/plain": [ 223 | " col_0 col_1 col_2 col_3 col_4 col_5 col_6 \\\n", 224 | "0 0.537070 -1.102140 1.614065 0.446704 2.150072 0.022795 -0.062325 \n", 225 | "1 -0.291104 0.870662 0.989858 0.340181 0.462467 -0.582148 1.888760 \n", 226 | "2 -0.623914 0.645679 -0.603598 -0.382241 -1.038855 1.036846 -0.411746 \n", 227 | "3 -0.007280 -1.159284 1.205723 -0.869215 -0.571466 0.540196 0.656639 \n", 228 | "4 0.578238 -0.756635 -0.768636 1.339886 0.612525 -0.431343 -0.058266 \n", 229 | "\n", 230 | " col_7 col_8 col_9 ... col_290 col_291 col_292 col_293 \\\n", 231 | "0 0.842601 1.337646 0.423004 ... -0.923334 -0.481815 1.203645 -1.458585 \n", 232 | "1 1.326881 -1.654321 -0.130696 ... 1.429273 0.845691 -1.089257 -0.918355 \n", 233 | "2 0.309138 0.377860 1.115033 ... -0.853996 -1.977211 -0.360070 0.457125 \n", 234 | "3 0.041661 0.244310 -0.860549 ... 0.646255 -0.762227 -0.940969 -0.889827 \n", 235 | "4 0.975151 -1.992118 0.179272 ... 0.263658 0.837735 0.724682 -2.493489 \n", 236 | "\n", 237 | " col_294 col_295 col_296 col_297 col_298 col_299 \n", 238 | "0 -0.382055 0.778446 1.281670 0.083526 0.690047 0.117204 \n", 239 | "1 0.394018 2.608926 -1.485463 -0.907812 -0.173660 0.920506 \n", 240 | "2 -1.372804 0.320784 -0.961563 -0.203412 0.920264 0.799161 \n", 241 | "3 -0.534136 -0.649951 -0.387092 -1.089814 0.054935 0.955872 \n", 242 | "4 -2.108600 -1.646070 -0.674911 -0.344457 -0.771689 -0.691474 \n", 243 | "\n", 244 | "[5 rows x 300 columns]" 245 | ] 246 | }, 247 | "execution_count": 2, 248 | "metadata": {}, 249 | "output_type": "execute_result" 250 | } 251 | ], 252 | "source": [ 253 | "print(X.shape)\n", 254 | "X.head()" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 3, 260 | "id": "23443e5f-c1ec-422c-b894-b6da48b79c1c", 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "text/html": [ 266 | "
\n", 267 | "\n", 280 | "\n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | "
labels
0219.664211
140.998126
2109.423858
3-280.667507
4599.178538
\n", 310 | "
" 311 | ], 312 | "text/plain": [ 313 | " labels\n", 314 | "0 219.664211\n", 315 | "1 40.998126\n", 316 | "2 109.423858\n", 317 | "3 -280.667507\n", 318 | "4 599.178538" 319 | ] 320 | }, 321 | "execution_count": 3, 322 | "metadata": {}, 323 | "output_type": "execute_result" 324 | } 325 | ], 326 | "source": [ 327 | "y.head()" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "id": "584213bc-0455-443c-873f-f6ea190a4f83", 333 | "metadata": {}, 334 | "source": [ 335 | "# Use the library to select only the relevant features\n", 336 | "\n", 337 | "We use the library to select only the relevant features.\n", 338 | "A column containing gaussian noise (mean = 0, std. dev = 1) is created at each epoch. Then the feature importance is computed and all the features that are less important that the random one, are excluded.\n", 339 | "\n", 340 | "The process is repeated for the number of selected `epochs`. It is also possible to put an early stopping, by assigning to the parameter `patience` a value that is smaller than the number of epochs. If the number of selected features remains the same for a number of epochs equal to patience, the process stops. " 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 4, 346 | "id": "5381a35e-7993-49b5-b8f8-23faed331359", 347 | "metadata": {}, 348 | "outputs": [ 349 | { 350 | "name": "stdout", 351 | "output_type": "stream", 352 | "text": [ 353 | "=====================EPOCH 1 =====================\n", 354 | "Fitting the model with 301 features\n", 355 | "Train score 0.9252\n", 356 | "Test score 0.8687\n", 357 | "Fitting the model with 301 features\n", 358 | "Train score 0.9302\n", 359 | "Test score 0.8272\n", 360 | "Fitting the model with 301 features\n", 361 | "Train score 0.926\n", 362 | "Test score 0.864\n", 363 | "Fitting the model with 301 features\n", 364 | "Train score 0.9237\n", 365 | "Test score 0.8608\n", 366 | "Fitting the model with 301 features\n", 367 | "Train score 0.9238\n", 368 | "Test score 0.8639\n", 369 | "Selected 286 features out of 300\n", 370 | "=====================EPOCH 2 =====================\n", 371 | "Fitting the model with 287 features\n", 372 | "Train score 0.9267\n", 373 | "Test score 0.8527\n", 374 | "Fitting the model with 287 features\n", 375 | "Train score 0.9271\n", 376 | "Test score 0.8579\n", 377 | "Fitting the model with 287 features\n", 378 | "Train score 0.9294\n", 379 | "Test score 0.8502\n", 380 | "Fitting the model with 287 features\n", 381 | "Train score 0.9228\n", 382 | "Test score 0.8675\n", 383 | "Fitting the model with 287 features\n", 384 | "Train score 0.9245\n", 385 | "Test score 0.8594\n", 386 | "Selected 110 features out of 286\n", 387 | "=====================EPOCH 3 =====================\n", 388 | "Fitting the model with 111 features\n", 389 | "Train score 0.9145\n", 390 | "Test score 0.8962\n", 391 | "Fitting the model with 111 features\n", 392 | "Train score 0.9209\n", 393 | "Test score 0.8703\n", 394 | "Fitting the model with 111 features\n", 395 | "Train score 0.9161\n", 396 | "Test score 0.8956\n", 397 | "Fitting the model with 111 features\n", 398 | "Train score 0.9188\n", 399 | "Test score 0.8793\n", 400 | "Fitting the model with 111 features\n", 401 | "Train score 0.9135\n", 402 | "Test score 0.907\n", 403 | "Selected 110 features out of 110\n", 404 | "The feature selection did not improve in the last 1 epochs\n", 405 | "=====================EPOCH 4 =====================\n", 406 | "Fitting the model with 111 features\n", 407 | "Train score 0.9186\n", 408 | "Test score 
0.8801\n", 409 | "Fitting the model with 111 features\n", 410 | "Train score 0.9191\n", 411 | "Test score 0.879\n", 412 | "Fitting the model with 111 features\n", 413 | "Train score 0.9212\n", 414 | "Test score 0.8573\n", 415 | "Fitting the model with 111 features\n", 416 | "Train score 0.9139\n", 417 | "Test score 0.9048\n", 418 | "Fitting the model with 111 features\n", 419 | "Train score 0.9133\n", 420 | "Test score 0.9052\n", 421 | "Selected 110 features out of 110\n", 422 | "The feature selection did not improve in the last 2 epochs\n", 423 | "=====================EPOCH 5 =====================\n", 424 | "Fitting the model with 111 features\n", 425 | "Train score 0.9208\n", 426 | "Test score 0.8722\n", 427 | "Fitting the model with 111 features\n", 428 | "Train score 0.9189\n", 429 | "Test score 0.8778\n", 430 | "Fitting the model with 111 features\n", 431 | "Train score 0.9169\n", 432 | "Test score 0.8905\n", 433 | "Fitting the model with 111 features\n", 434 | "Train score 0.914\n", 435 | "Test score 0.9026\n", 436 | "Fitting the model with 111 features\n", 437 | "Train score 0.9144\n", 438 | "Test score 0.9004\n", 439 | "Selected 109 features out of 110\n", 440 | "=====================EPOCH 6 =====================\n", 441 | "Fitting the model with 110 features\n", 442 | "Train score 0.9156\n", 443 | "Test score 0.8955\n", 444 | "Fitting the model with 110 features\n", 445 | "Train score 0.9162\n", 446 | "Test score 0.8835\n", 447 | "Fitting the model with 110 features\n", 448 | "Train score 0.9187\n", 449 | "Test score 0.8874\n", 450 | "Fitting the model with 110 features\n", 451 | "Train score 0.9165\n", 452 | "Test score 0.891\n", 453 | "Fitting the model with 110 features\n", 454 | "Train score 0.919\n", 455 | "Test score 0.8763\n", 456 | "Selected 109 features out of 109\n", 457 | "The feature selection did not improve in the last 1 epochs\n", 458 | "=====================EPOCH 7 =====================\n", 459 | "Fitting the model with 110 features\n", 460 | "Train score 0.9167\n", 461 | "Test score 0.8909\n", 462 | "Fitting the model with 110 features\n", 463 | "Train score 0.9115\n", 464 | "Test score 0.9116\n", 465 | "Fitting the model with 110 features\n", 466 | "Train score 0.9187\n", 467 | "Test score 0.8824\n", 468 | "Fitting the model with 110 features\n", 469 | "Train score 0.9188\n", 470 | "Test score 0.8862\n", 471 | "Fitting the model with 110 features\n", 472 | "Train score 0.9173\n", 473 | "Test score 0.8895\n", 474 | "Selected 109 features out of 109\n", 475 | "The feature selection did not improve in the last 2 epochs\n", 476 | "=====================EPOCH 8 =====================\n", 477 | "Fitting the model with 110 features\n", 478 | "Train score 0.9147\n", 479 | "Test score 0.9004\n", 480 | "Fitting the model with 110 features\n", 481 | "Train score 0.9195\n", 482 | "Test score 0.8733\n", 483 | "Fitting the model with 110 features\n", 484 | "Train score 0.9159\n", 485 | "Test score 0.893\n", 486 | "Fitting the model with 110 features\n", 487 | "Train score 0.9187\n", 488 | "Test score 0.8793\n", 489 | "Fitting the model with 110 features\n", 490 | "Train score 0.9161\n", 491 | "Test score 0.8979\n", 492 | "Selected 109 features out of 109\n", 493 | "The feature selection did not improve in the last 3 epochs\n", 494 | "=====================EPOCH 9 =====================\n", 495 | "Fitting the model with 110 features\n", 496 | "Train score 0.9153\n", 497 | "Test score 0.8915\n", 498 | "Fitting the model with 110 features\n", 499 | "Train score 0.9122\n", 500 | "Test 
score 0.9122\n", 501 | "Fitting the model with 110 features\n", 502 | "Train score 0.9176\n", 503 | "Test score 0.8872\n", 504 | "Fitting the model with 110 features\n", 505 | "Train score 0.9196\n", 506 | "Test score 0.8751\n", 507 | "Fitting the model with 110 features\n", 508 | "Train score 0.9193\n", 509 | "Test score 0.8746\n", 510 | "Selected 109 features out of 109\n", 511 | "The feature selection did not improve in the last 4 epochs\n", 512 | "=====================EPOCH 10 =====================\n", 513 | "Fitting the model with 110 features\n", 514 | "Train score 0.9192\n", 515 | "Test score 0.8761\n", 516 | "Fitting the model with 110 features\n", 517 | "Train score 0.9195\n", 518 | "Test score 0.8733\n", 519 | "Fitting the model with 110 features\n", 520 | "Train score 0.9166\n", 521 | "Test score 0.8916\n", 522 | "Fitting the model with 110 features\n", 523 | "Train score 0.9141\n", 524 | "Test score 0.9037\n", 525 | "Fitting the model with 110 features\n", 526 | "Train score 0.915\n", 527 | "Test score 0.899\n", 528 | "Selected 109 features out of 109\n", 529 | "The feature selection did not improve in the last 5 epochs\n", 530 | "=====================EPOCH 11 =====================\n", 531 | "Fitting the model with 110 features\n", 532 | "Train score 0.9203\n", 533 | "Test score 0.8697\n", 534 | "Fitting the model with 110 features\n", 535 | "Train score 0.912\n", 536 | "Test score 0.9121\n", 537 | "Fitting the model with 110 features\n", 538 | "Train score 0.9114\n", 539 | "Test score 0.9132\n", 540 | "Fitting the model with 110 features\n", 541 | "Train score 0.9217\n", 542 | "Test score 0.8598\n", 543 | "Fitting the model with 110 features\n", 544 | "Train score 0.9184\n", 545 | "Test score 0.8842\n", 546 | "Selected 109 features out of 109\n", 547 | "The feature selection did not improve in the last 6 epochs\n", 548 | "=====================EPOCH 12 =====================\n", 549 | "Fitting the model with 110 features\n", 550 | "Train score 0.9154\n", 551 | "Test score 0.8939\n", 552 | "Fitting the model with 110 features\n", 553 | "Train score 0.9175\n", 554 | "Test score 0.878\n", 555 | "Fitting the model with 110 features\n", 556 | "Train score 0.9168\n", 557 | "Test score 0.8906\n", 558 | "Fitting the model with 110 features\n", 559 | "Train score 0.9191\n", 560 | "Test score 0.8852\n", 561 | "Fitting the model with 110 features\n", 562 | "Train score 0.9165\n", 563 | "Test score 0.892\n", 564 | "Selected 107 features out of 109\n", 565 | "=====================EPOCH 13 =====================\n", 566 | "Fitting the model with 108 features\n", 567 | "Train score 0.9183\n", 568 | "Test score 0.887\n", 569 | "Fitting the model with 108 features\n", 570 | "Train score 0.9184\n", 571 | "Test score 0.8826\n", 572 | "Fitting the model with 108 features\n", 573 | "Train score 0.9109\n", 574 | "Test score 0.9111\n", 575 | "Fitting the model with 108 features\n", 576 | "Train score 0.9199\n", 577 | "Test score 0.8789\n", 578 | "Fitting the model with 108 features\n", 579 | "Train score 0.9147\n", 580 | "Test score 0.8987\n", 581 | "Selected 106 features out of 107\n", 582 | "=====================EPOCH 14 =====================\n", 583 | "Fitting the model with 107 features\n", 584 | "Train score 0.9168\n", 585 | "Test score 0.8898\n", 586 | "Fitting the model with 107 features\n", 587 | "Train score 0.9141\n", 588 | "Test score 0.898\n", 589 | "Fitting the model with 107 features\n", 590 | "Train score 0.9157\n", 591 | "Test score 0.8899\n", 592 | "Fitting the model with 107 
features\n", 593 | "Train score 0.9177\n", 594 | "Test score 0.8791\n", 595 | "Fitting the model with 107 features\n", 596 | "Train score 0.9192\n", 597 | "Test score 0.8762\n", 598 | "Selected 106 features out of 106\n", 599 | "The feature selection did not improve in the last 1 epochs\n", 600 | "=====================EPOCH 15 =====================\n", 601 | "Fitting the model with 107 features\n", 602 | "Train score 0.9176\n", 603 | "Test score 0.8826\n", 604 | "Fitting the model with 107 features\n", 605 | "Train score 0.9157\n", 606 | "Test score 0.8898\n", 607 | "Fitting the model with 107 features\n", 608 | "Train score 0.9167\n", 609 | "Test score 0.8908\n", 610 | "Fitting the model with 107 features\n", 611 | "Train score 0.9187\n", 612 | "Test score 0.8814\n", 613 | "Fitting the model with 107 features\n", 614 | "Train score 0.9162\n", 615 | "Test score 0.8945\n", 616 | "Selected 95 features out of 106\n", 617 | "=====================EPOCH 16 =====================\n", 618 | "Fitting the model with 96 features\n", 619 | "Train score 0.9148\n", 620 | "Test score 0.8947\n", 621 | "Fitting the model with 96 features\n", 622 | "Train score 0.9152\n", 623 | "Test score 0.892\n", 624 | "Fitting the model with 96 features\n", 625 | "Train score 0.9117\n", 626 | "Test score 0.9098\n", 627 | "Fitting the model with 96 features\n", 628 | "Train score 0.9177\n", 629 | "Test score 0.8805\n", 630 | "Fitting the model with 96 features\n", 631 | "Train score 0.918\n", 632 | "Test score 0.8736\n", 633 | "Selected 95 features out of 95\n", 634 | "The feature selection did not improve in the last 1 epochs\n", 635 | "=====================EPOCH 17 =====================\n", 636 | "Fitting the model with 96 features\n", 637 | "Train score 0.9177\n", 638 | "Test score 0.8767\n", 639 | "Fitting the model with 96 features\n", 640 | "Train score 0.9158\n", 641 | "Test score 0.889\n", 642 | "Fitting the model with 96 features\n", 643 | "Train score 0.9178\n", 644 | "Test score 0.8789\n", 645 | "Fitting the model with 96 features\n", 646 | "Train score 0.9118\n", 647 | "Test score 0.9098\n", 648 | "Fitting the model with 96 features\n", 649 | "Train score 0.9154\n", 650 | "Test score 0.891\n", 651 | "Selected 95 features out of 95\n", 652 | "The feature selection did not improve in the last 2 epochs\n", 653 | "=====================EPOCH 18 =====================\n", 654 | "Fitting the model with 96 features\n", 655 | "Train score 0.913\n", 656 | "Test score 0.8988\n", 657 | "Fitting the model with 96 features\n", 658 | "Train score 0.9172\n", 659 | "Test score 0.8844\n", 660 | "Fitting the model with 96 features\n", 661 | "Train score 0.9147\n", 662 | "Test score 0.8959\n", 663 | "Fitting the model with 96 features\n", 664 | "Train score 0.9171\n", 665 | "Test score 0.8779\n", 666 | "Fitting the model with 96 features\n", 667 | "Train score 0.9152\n", 668 | "Test score 0.8911\n", 669 | "Selected 95 features out of 95\n", 670 | "The feature selection did not improve in the last 3 epochs\n", 671 | "=====================EPOCH 19 =====================\n", 672 | "Fitting the model with 96 features\n", 673 | "Train score 0.9194\n", 674 | "Test score 0.8763\n", 675 | "Fitting the model with 96 features\n", 676 | "Train score 0.9142\n", 677 | "Test score 0.8969\n", 678 | "Fitting the model with 96 features\n", 679 | "Train score 0.9127\n", 680 | "Test score 0.9045\n", 681 | "Fitting the model with 96 features\n", 682 | "Train score 0.9122\n", 683 | "Test score 0.899\n", 684 | "Fitting the model with 96 features\n", 685 
| "Train score 0.9171\n", 686 | "Test score 0.8837\n", 687 | "Selected 95 features out of 95\n", 688 | "The feature selection did not improve in the last 4 epochs\n", 689 | "=====================EPOCH 20 =====================\n", 690 | "Fitting the model with 96 features\n", 691 | "Train score 0.9164\n", 692 | "Test score 0.8848\n", 693 | "Fitting the model with 96 features\n", 694 | "Train score 0.9159\n", 695 | "Test score 0.8899\n", 696 | "Fitting the model with 96 features\n", 697 | "Train score 0.9171\n", 698 | "Test score 0.8875\n", 699 | "Fitting the model with 96 features\n", 700 | "Train score 0.9187\n", 701 | "Test score 0.8805\n", 702 | "Fitting the model with 96 features\n", 703 | "Train score 0.9116\n", 704 | "Test score 0.9067\n", 705 | "Selected 25 features out of 95\n", 706 | "=====================EPOCH 21 =====================\n", 707 | "Fitting the model with 26 features\n", 708 | "Train score 0.9049\n", 709 | "Test score 0.8632\n", 710 | "Fitting the model with 26 features\n", 711 | "Train score 0.8987\n", 712 | "Test score 0.8921\n", 713 | "Fitting the model with 26 features\n", 714 | "Train score 0.9031\n", 715 | "Test score 0.8715\n", 716 | "Fitting the model with 26 features\n", 717 | "Train score 0.8997\n", 718 | "Test score 0.8882\n", 719 | "Fitting the model with 26 features\n", 720 | "Train score 0.8887\n", 721 | "Test score 0.9273\n", 722 | "Selected 25 features out of 25\n", 723 | "The feature selection did not improve in the last 1 epochs\n", 724 | "=====================EPOCH 22 =====================\n", 725 | "Fitting the model with 26 features\n", 726 | "Train score 0.9028\n", 727 | "Test score 0.8746\n", 728 | "Fitting the model with 26 features\n", 729 | "Train score 0.8923\n", 730 | "Test score 0.9145\n", 731 | "Fitting the model with 26 features\n", 732 | "Train score 0.8998\n", 733 | "Test score 0.8894\n", 734 | "Fitting the model with 26 features\n", 735 | "Train score 0.8998\n", 736 | "Test score 0.8897\n", 737 | "Fitting the model with 26 features\n", 738 | "Train score 0.8999\n", 739 | "Test score 0.8887\n", 740 | "Selected 25 features out of 25\n", 741 | "The feature selection did not improve in the last 2 epochs\n", 742 | "=====================EPOCH 23 =====================\n", 743 | "Fitting the model with 26 features\n", 744 | "Train score 0.895\n", 745 | "Test score 0.91\n", 746 | "Fitting the model with 26 features\n", 747 | "Train score 0.8962\n", 748 | "Test score 0.9039\n", 749 | "Fitting the model with 26 features\n", 750 | "Train score 0.8997\n", 751 | "Test score 0.8881\n", 752 | "Fitting the model with 26 features\n", 753 | "Train score 0.9022\n", 754 | "Test score 0.88\n", 755 | "Fitting the model with 26 features\n", 756 | "Train score 0.9029\n", 757 | "Test score 0.8729\n", 758 | "Selected 25 features out of 25\n", 759 | "The feature selection did not improve in the last 3 epochs\n", 760 | "=====================EPOCH 24 =====================\n", 761 | "Fitting the model with 26 features\n", 762 | "Train score 0.9008\n", 763 | "Test score 0.8872\n", 764 | "Fitting the model with 26 features\n", 765 | "Train score 0.8986\n", 766 | "Test score 0.8936\n", 767 | "Fitting the model with 26 features\n", 768 | "Train score 0.9037\n", 769 | "Test score 0.8698\n", 770 | "Fitting the model with 26 features\n", 771 | "Train score 0.892\n", 772 | "Test score 0.9178\n", 773 | "Fitting the model with 26 features\n", 774 | "Train score 0.8999\n", 775 | "Test score 0.8872\n", 776 | "Selected 25 features out of 25\n", 777 | "The feature selection did 
not improve in the last 4 epochs\n", 778 | "=====================EPOCH 25 =====================\n", 779 | "Fitting the model with 26 features\n", 780 | "Train score 0.8958\n", 781 | "Test score 0.9047\n", 782 | "Fitting the model with 26 features\n", 783 | "Train score 0.9008\n", 784 | "Test score 0.8845\n", 785 | "Fitting the model with 26 features\n", 786 | "Train score 0.8935\n", 787 | "Test score 0.9164\n", 788 | "Fitting the model with 26 features\n", 789 | "Train score 0.9012\n", 790 | "Test score 0.8853\n", 791 | "Fitting the model with 26 features\n", 792 | "Train score 0.904\n", 793 | "Test score 0.8682\n", 794 | "Selected 25 features out of 25\n", 795 | "The feature selection did not improve in the last 5 epochs\n", 796 | "=====================EPOCH 26 =====================\n", 797 | "Fitting the model with 26 features\n", 798 | "Train score 0.9\n", 799 | "Test score 0.8883\n", 800 | "Fitting the model with 26 features\n", 801 | "Train score 0.9015\n", 802 | "Test score 0.8834\n", 803 | "Fitting the model with 26 features\n", 804 | "Train score 0.8981\n", 805 | "Test score 0.8957\n", 806 | "Fitting the model with 26 features\n", 807 | "Train score 0.8987\n", 808 | "Test score 0.8939\n", 809 | "Fitting the model with 26 features\n", 810 | "Train score 0.8986\n", 811 | "Test score 0.8947\n", 812 | "Selected 25 features out of 25\n", 813 | "The feature selection did not improve in the last 6 epochs\n", 814 | "=====================EPOCH 27 =====================\n", 815 | "Fitting the model with 26 features\n", 816 | "Train score 0.9007\n", 817 | "Test score 0.8812\n", 818 | "Fitting the model with 26 features\n", 819 | "Train score 0.8963\n", 820 | "Test score 0.9039\n", 821 | "Fitting the model with 26 features\n", 822 | "Train score 0.9007\n", 823 | "Test score 0.8857\n", 824 | "Fitting the model with 26 features\n", 825 | "Train score 0.8971\n", 826 | "Test score 0.8999\n", 827 | "Fitting the model with 26 features\n", 828 | "Train score 0.9013\n", 829 | "Test score 0.8816\n", 830 | "Selected 25 features out of 25\n", 831 | "The feature selection did not improve in the last 7 epochs\n", 832 | "=====================EPOCH 28 =====================\n", 833 | "Fitting the model with 26 features\n", 834 | "Train score 0.9005\n", 835 | "Test score 0.8873\n", 836 | "Fitting the model with 26 features\n", 837 | "Train score 0.8973\n", 838 | "Test score 0.8975\n", 839 | "Fitting the model with 26 features\n", 840 | "Train score 0.9051\n", 841 | "Test score 0.8591\n", 842 | "Fitting the model with 26 features\n", 843 | "Train score 0.901\n", 844 | "Test score 0.8853\n", 845 | "Fitting the model with 26 features\n", 846 | "Train score 0.8913\n", 847 | "Test score 0.9177\n", 848 | "Selected 25 features out of 25\n", 849 | "The feature selection did not improve in the last 8 epochs\n", 850 | "=====================EPOCH 29 =====================\n", 851 | "Fitting the model with 26 features\n", 852 | "Train score 0.9\n", 853 | "Test score 0.8884\n", 854 | "Fitting the model with 26 features\n", 855 | "Train score 0.8975\n", 856 | "Test score 0.9002\n", 857 | "Fitting the model with 26 features\n", 858 | "Train score 0.8985\n", 859 | "Test score 0.8945\n", 860 | "Fitting the model with 26 features\n", 861 | "Train score 0.8974\n", 862 | "Test score 0.899\n", 863 | "Fitting the model with 26 features\n", 864 | "Train score 0.9037\n", 865 | "Test score 0.8706\n", 866 | "Selected 25 features out of 25\n", 867 | "The feature selection did not improve in the last 9 epochs\n", 868 | 
"=====================EPOCH 30 =====================\n", 869 | "Fitting the model with 26 features\n", 870 | "Train score 0.8977\n", 871 | "Test score 0.9002\n", 872 | "Fitting the model with 26 features\n", 873 | "Train score 0.8962\n", 874 | "Test score 0.9008\n", 875 | "Fitting the model with 26 features\n", 876 | "Train score 0.9006\n", 877 | "Test score 0.8849\n", 878 | "Fitting the model with 26 features\n", 879 | "Train score 0.9029\n", 880 | "Test score 0.8715\n", 881 | "Fitting the model with 26 features\n", 882 | "Train score 0.8989\n", 883 | "Test score 0.8896\n", 884 | "Selected 25 features out of 25\n", 885 | "The feature selection did not improve in the last 10 epochs\n" 886 | ] 887 | } 888 | ], 889 | "source": [ 890 | "from src.ml import get_relevant_features\n", 891 | "from sklearn.linear_model import Lasso\n", 892 | "\n", 893 | "lasso_model = Lasso(alpha=1)\n", 894 | "\n", 895 | "X_reduced = get_relevant_features(X, y, \n", 896 | " model=lasso_model, \n", 897 | " epochs=100, \n", 898 | " patience=10, \n", 899 | " splitting_type='kfold',\n", 900 | " noise_type='gaussian',\n", 901 | " filename_output='results/reduced_dataset.csv',\n", 902 | " random_state=42)" 903 | ] 904 | }, 905 | { 906 | "cell_type": "markdown", 907 | "id": "4dd5b308-86bf-4e0f-a38b-0eb4cba2612f", 908 | "metadata": {}, 909 | "source": [ 910 | "# Inspect the new (reduced) dataset\n", 911 | "\n", 912 | "The new reduced dataset contains only a subset of features, namely the most relevant ones " 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "execution_count": 5, 918 | "id": "e7f4e2f4-eaab-4700-ab75-6d1356162917", 919 | "metadata": {}, 920 | "outputs": [ 921 | { 922 | "data": { 923 | "text/html": [ 924 | "
\n", 925 | "\n", 938 | "\n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | " \n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " \n", 1078 | " \n", 1079 | " \n", 1080 | " \n", 1081 | " \n", 1082 | " \n", 1083 | " \n", 1084 | " \n", 1085 | " \n", 1086 | " \n", 1087 | "
col_168col_277col_88col_283col_8col_270col_258col_187col_76col_171...col_119col_174col_250col_24col_25col_55col_274col_47col_45col_154
0-0.003067-0.003618-1.019952-2.0923261.3376460.2205741.6734970.2360820.656732-0.878309...-0.6456820.5770201.053236-0.3097411.1788540.8005310.6922071.6964631.498433-1.750620
10.0306760.2599510.9205480.685426-1.654321-0.041250-0.0511700.853362-0.367285-0.343505...2.0523820.653817-1.5454290.830601-0.570561-0.017992-0.8082670.063894-1.335906-1.292565
20.459295-1.351804-3.2778060.9333980.3778600.7587310.971993-0.7390100.200066-0.076279...-0.745177-0.2600102.176639-1.1452081.9571130.742541-1.399284-1.426165-0.8521441.935704
30.002783-0.413398-0.9367961.1533290.244310-0.7812650.7441981.1236660.328173-0.379941...-1.638474-0.077514-2.033497-1.0161650.245085-0.6300770.2213330.3567260.3895950.210160
40.0541073.0810690.842412-0.148948-1.9921180.5139120.0936662.8317121.2616110.741283...-0.557619-0.6768960.3517112.785978-0.719393-1.0689960.234328-0.817592-0.178059-0.086358
\n", 1088 | "

5 rows × 25 columns

\n", 1089 | "
" 1090 | ], 1091 | "text/plain": [ 1092 | " col_168 col_277 col_88 col_283 col_8 col_270 col_258 \\\n", 1093 | "0 -0.003067 -0.003618 -1.019952 -2.092326 1.337646 0.220574 1.673497 \n", 1094 | "1 0.030676 0.259951 0.920548 0.685426 -1.654321 -0.041250 -0.051170 \n", 1095 | "2 0.459295 -1.351804 -3.277806 0.933398 0.377860 0.758731 0.971993 \n", 1096 | "3 0.002783 -0.413398 -0.936796 1.153329 0.244310 -0.781265 0.744198 \n", 1097 | "4 0.054107 3.081069 0.842412 -0.148948 -1.992118 0.513912 0.093666 \n", 1098 | "\n", 1099 | " col_187 col_76 col_171 ... col_119 col_174 col_250 col_24 \\\n", 1100 | "0 0.236082 0.656732 -0.878309 ... -0.645682 0.577020 1.053236 -0.309741 \n", 1101 | "1 0.853362 -0.367285 -0.343505 ... 2.052382 0.653817 -1.545429 0.830601 \n", 1102 | "2 -0.739010 0.200066 -0.076279 ... -0.745177 -0.260010 2.176639 -1.145208 \n", 1103 | "3 1.123666 0.328173 -0.379941 ... -1.638474 -0.077514 -2.033497 -1.016165 \n", 1104 | "4 2.831712 1.261611 0.741283 ... -0.557619 -0.676896 0.351711 2.785978 \n", 1105 | "\n", 1106 | " col_25 col_55 col_274 col_47 col_45 col_154 \n", 1107 | "0 1.178854 0.800531 0.692207 1.696463 1.498433 -1.750620 \n", 1108 | "1 -0.570561 -0.017992 -0.808267 0.063894 -1.335906 -1.292565 \n", 1109 | "2 1.957113 0.742541 -1.399284 -1.426165 -0.852144 1.935704 \n", 1110 | "3 0.245085 -0.630077 0.221333 0.356726 0.389595 0.210160 \n", 1111 | "4 -0.719393 -1.068996 0.234328 -0.817592 -0.178059 -0.086358 \n", 1112 | "\n", 1113 | "[5 rows x 25 columns]" 1114 | ] 1115 | }, 1116 | "execution_count": 5, 1117 | "metadata": {}, 1118 | "output_type": "execute_result" 1119 | } 1120 | ], 1121 | "source": [ 1122 | "X_reduced.head()" 1123 | ] 1124 | }, 1125 | { 1126 | "cell_type": "code", 1127 | "execution_count": null, 1128 | "id": "debafd0f-0ced-426b-99e2-2b4c85beaeb8", 1129 | "metadata": {}, 1130 | "outputs": [], 1131 | "source": [] 1132 | } 1133 | ], 1134 | "metadata": { 1135 | "kernelspec": { 1136 | "display_name": "Python 3 (ipykernel)", 1137 | "language": "python", 1138 | "name": "python3" 1139 | }, 1140 | "language_info": { 1141 | "codemirror_mode": { 1142 | "name": "ipython", 1143 | "version": 3 1144 | }, 1145 | "file_extension": ".py", 1146 | "mimetype": "text/x-python", 1147 | "name": "python", 1148 | "nbconvert_exporter": "python", 1149 | "pygments_lexer": "ipython3", 1150 | "version": "3.9.7" 1151 | } 1152 | }, 1153 | "nbformat": 4, 1154 | "nbformat_minor": 5 1155 | } 1156 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.20.3 2 | pandas==1.3.4 3 | scipy==1.7.1 4 | scikit-learn==0.24.2 5 | jupyterlab==3.2.1 -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apalladi/feature-selection-adding-noise/88e1c418f9e456986914d01a90534074b90e98f8/src/__init__.py -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | """This module produces toy dataset to be used for regression problems""" 2 | 3 | import pandas as pd 4 | from sklearn.datasets import make_regression 5 | 6 | 7 | def create_regression_data(): 8 | """This function produces a toy dataset, to be used in a regression problem. 
9 | It takes no input and returns the features and the labels."""
10 |
11 |     features, labels = make_regression(
12 |         n_samples=1000, n_features=300, n_informative=20, n_targets=1, noise=100
13 |     )
14 |
15 |     col_names = ["col_" + str(i) for i in range(features.shape[1])]
16 |     features = pd.DataFrame(features, columns=col_names)
17 |     labels = pd.DataFrame(labels, columns=["labels"])
18 |     return features, labels
19 |
--------------------------------------------------------------------------------
/src/ml.py:
--------------------------------------------------------------------------------
1 | """This module contains the functions to perform feature selection
2 | by adding random noise"""
3 |
4 | from typing import Tuple, List, Optional
5 | import numpy as np
6 | import pandas as pd
7 | from sklearn.preprocessing import StandardScaler, MinMaxScaler
8 | from sklearn.model_selection import KFold, train_test_split
9 | from sklearn.base import BaseEstimator
10 |
11 |
12 | def train_evaluate_model(
13 |     x_train: pd.DataFrame,
14 |     y_train: pd.DataFrame,
15 |     x_test: pd.DataFrame,
16 |     y_test: pd.DataFrame,
17 |     model: BaseEstimator,
18 |     scaler_type: str,
19 |     verbose: bool,
20 | ) -> BaseEstimator:
21 |     """It trains and evaluates the machine learning model.
22 |
23 |     Parameters:
24 |     - x_train: training features
25 |     - y_train: training labels
26 |     - x_test: test features
27 |     - y_test: test labels
28 |     - model: an untrained scikit-learn machine learning model
29 |     - scaler_type: choose between "StandardScaler" and "MinMaxScaler"
30 |     - verbose: True or False to tune the level of verbosity
31 |
32 |     Return:
33 |     - the trained model
34 |     """
35 |
36 |     # scale data
37 |     if scaler_type == "StandardScaler":
38 |         scaler = StandardScaler()
39 |     elif scaler_type == "MinMaxScaler":
40 |         scaler = MinMaxScaler()
41 |     else:
42 |         raise ValueError(
43 |             "Allowed values for scaler_type are StandardScaler and MinMaxScaler"
44 |         )
45 |
46 |     x_train = scaler.fit_transform(x_train)
47 |     x_test = scaler.transform(x_test)
48 |
49 |     # fit model
50 |     if verbose:
51 |         print("Fitting the model with", x_train.shape[1], "features")
52 |     model.fit(x_train, y_train)
53 |     train_score = round(model.score(x_train, y_train), 4)
54 |     test_score = round(model.score(x_test, y_test), 4)
55 |
56 |     if verbose:
57 |         print("Train score", train_score)
58 |         print("Test score", test_score)
59 |
60 |     return model
61 |
62 |
63 | def get_feature_importances(
64 |     trained_model: BaseEstimator, column_names: List[str]
65 | ) -> pd.DataFrame:
66 |     """It computes the feature importance, given a trained model.
67 |
68 |     Parameters:
69 |     - trained_model: a trained scikit-learn ML model
70 |     - column_names: the names of the columns associated with the features
71 |
72 |     Return:
73 |     - a DataFrame containing the feature importance (not sorted) as a column and
74 |       the names of the features as index
75 |     """
76 |
77 |     # inspect coefficients
78 |     if hasattr(trained_model, "coef_"):
79 |         model_coefficients = trained_model.coef_
80 |     elif hasattr(trained_model, "feature_importances_"):
81 |         model_coefficients = trained_model.feature_importances_
82 |     else:
83 |         raise ValueError("Could not retrieve the feature importance")
84 |
85 |     df_coef = pd.DataFrame(model_coefficients, index=column_names)
86 |
87 |     return df_coef
88 |
89 |
90 | def compute_mean_coefficients(df_coefs: pd.DataFrame) -> pd.DataFrame:
91 |     """It computes the average coefficients, given a DataFrame with multiple columns.
92 |
93 |     Parameters:
94 |     - df_coefs: a DataFrame with the coefficients obtained in multiple trainings
95 |
96 |     Return:
97 |     - a DataFrame with one column, containing the absolute values of the average coefficients
98 |     """
99 |
100 |     if df_coefs.shape[1] > 1:
101 |         df_coef = pd.DataFrame(df_coefs.mean(axis=1), columns=["Feature importance"])
102 |     else:
103 |         # a single column of coefficients: keep it as it is
104 |         df_coef = pd.DataFrame(df_coefs.iloc[:, 0], index=df_coefs.index)
105 |         df_coef.columns = ["Feature importance"]
106 |
107 |     df_coef["Feature importance"] = np.abs(df_coef["Feature importance"])
108 |     df_coef["Feature name"] = df_coef.index
109 |     df_coef = df_coef.sort_values("Feature importance", ascending=False)
110 |     df_coef.reset_index(inplace=True, drop=True)
111 |
112 |     return df_coef
113 |
114 |
115 | def select_relevant_features(
116 |     df_coef: pd.DataFrame, features: pd.DataFrame, verbose: bool
117 | ) -> pd.DataFrame:
118 |     """It selects the relevant features, given the DataFrame with the feature importance
119 |     and the original features.
120 |     All the features ranked below the "random_feature" column are discarded.
121 |
122 |     Parameters:
123 |     - df_coef: the DataFrame with the feature importance
124 |     - features: the original features
125 |     - verbose: True or False to tune the level of verbosity
126 |
127 |     Return:
128 |     - the simplified dataset, with the relevant features
129 |     """
130 |
131 |     # select relevant features
132 |     index_threshold = np.array(
133 |         df_coef[df_coef["Feature name"] == "random_feature"].index
134 |     )[0]
135 |     relevant_features = df_coef.iloc[0:index_threshold]
136 |     relevant_features = relevant_features["Feature name"]
137 |
138 |     if verbose:
139 |         print(
140 |             "Selected", len(relevant_features), "features out of", features.shape[1] - 1
141 |         )
142 |
143 |     # return simplified dataset, containing only relevant features
144 |     simplified_dataset = features.loc[:, relevant_features]
145 |
146 |     return simplified_dataset
147 |
148 |
149 | def generate_kfold_data(
150 |     features: pd.DataFrame, labels: pd.DataFrame, random_state: int
151 | ) -> Tuple[List, List, List, List]:
152 |     """It splits the data into training and validation,
153 |     by using the KFold splitting method.
154 |
155 |     Parameters:
156 |     - features: the matrix with features, commonly called X
157 |     - labels: the vector with labels, commonly called y
158 |     - random_state: the random state used to shuffle the folds
159 |     Return:
160 |     - train and test data
161 |     """
162 |
163 |     x_trains = []
164 |     y_trains = []
165 |     x_tests = []
166 |     y_tests = []
167 |
168 |     k_fold = KFold(n_splits=5, random_state=random_state, shuffle=True)
169 |     k_fold.get_n_splits(features)
170 |     for train_index, test_index in k_fold.split(features):
171 |         # train data
172 |         x_trains.append(features.iloc[train_index, :])
173 |         y_trains.append(labels.iloc[train_index])
174 |         # test data
175 |         x_tests.append(features.iloc[test_index, :])
176 |         y_tests.append(labels.iloc[test_index])
177 |
178 |     return x_trains, y_trains, x_tests, y_tests
179 |
180 |
181 | def train_with_kfold_splitting(
182 |     features: pd.DataFrame,
183 |     labels: pd.DataFrame,
184 |     model: BaseEstimator,
185 |     scaler_type: str,
186 |     verbose: bool,
187 |     random_state: int,
188 | ) -> pd.DataFrame:
189 |     """It trains the model using the kfold splitting and returns
190 |     a DataFrame with the feature importance.
191 |
192 |     Parameters:
193 |     - features: the matrix with features, commonly called X
194 |     - labels: the vector with labels, commonly called y
195 |     - model: an untrained scikit-learn model
196 |     - scaler_type: choose between "StandardScaler" and "MinMaxScaler"
197 |     - verbose: True or False to tune the level of verbosity
198 |     - random_state: select the random state of the train/test splitting process
199 |
200 |     Return:
201 |     - a DataFrame with one column, containing the feature importances (or the coefficients)
202 |     """
203 |
204 |     # create train-test data
205 |     x_trains, y_trains, x_tests, y_tests = generate_kfold_data(
206 |         features, labels, random_state
207 |     )
208 |
209 |     for i in range(len(x_trains)):
210 |         trained_model = train_evaluate_model(
211 |             x_trains[i],
212 |             y_trains[i],
213 |             x_tests[i],
214 |             y_tests[i],
215 |             model,
216 |             scaler_type,
217 |             verbose,
218 |         )
219 |         if i == 0:
220 |             df_coefs = get_feature_importances(trained_model, x_trains[i].columns)
221 |             df_coefs.columns = ["cycle_" + str(i + 1)]
222 |         else:
223 |             df_coefs["cycle_" + str(i + 1)] = get_feature_importances(
224 |                 trained_model, x_trains[i].columns
225 |             )
226 |
227 |     df_coef = compute_mean_coefficients(df_coefs)
228 |     return df_coef
229 |
230 |
231 | def train_with_simple_splitting(
232 |     features: pd.DataFrame,
233 |     labels: pd.DataFrame,
234 |     model: BaseEstimator,
235 |     scaler_type: str,
236 |     verbose: bool,
237 |     random_state: int,
238 | ) -> pd.DataFrame:
239 |     """It trains the model using the train/test splitting and returns
240 |     a DataFrame with the feature importance.
241 |
242 |     Parameters:
243 |     - features: the matrix with features, commonly called X
244 |     - labels: the vector with labels, commonly called y
245 |     - model: an untrained scikit-learn model
246 |     - scaler_type: choose between "StandardScaler" and "MinMaxScaler"
247 |     - verbose: True or False to tune the level of verbosity
248 |     - random_state: select the random state of the train/test splitting process
249 |
250 |     Return:
251 |     - a DataFrame with one column, containing the feature importances (or the coefficients)
252 |     """
253 |
254 |     # create train-test data
255 |     x_train, x_test, y_train, y_test = train_test_split(
256 |         features, labels, test_size=0.2, random_state=random_state
257 |     )
258 |
259 |     trained_model = train_evaluate_model(
260 |         x_train, y_train, x_test, y_test, model, scaler_type, verbose
261 |     )
262 |     df_coefs = get_feature_importances(trained_model, x_train.columns)
263 |
264 |     df_coef = compute_mean_coefficients(df_coefs)
265 |
266 |     return df_coef
267 |
268 |
269 | def scan_features_pipeline(
270 |     features: pd.DataFrame,
271 |     labels: pd.DataFrame,
272 |     model: BaseEstimator,
273 |     splitting_type: str,
274 |     verbose: bool,
275 |     random_state: int,
276 |     noise_type: str,
277 | ) -> pd.DataFrame:
278 |     """This pipeline performs various operations:
279 |     - it trains and evaluates the model
280 |     - it generates the DataFrame with the feature importance
281 |     - it computes the simplified dataset, containing only the relevant features
282 |
283 |     Parameters:
284 |     - features: the matrix with features, commonly called X
285 |     - labels: the vector with labels, commonly called y
286 |     - model: an untrained scikit-learn model
287 |     - splitting_type: choose between "simple" (80% train, 20% test)
288 |       or "kfold" (5-fold splitting)
289 |     - verbose: True or False to tune the level of verbosity
290 |     - random_state: select the random state of the train/test splitting process
291 |     - noise_type: choose between "gaussian" noise or "random" (flat) noise
292 |
293 |     Return:
294 |     - the simplified dataset, containing only the most relevant features
295 |     """
296 |
297 |     # add noise
298 |     x_new = features.copy(deep=True)
299 |
300 |     if noise_type == "gaussian":
301 |         x_new["random_feature"] = np.random.normal(0, 1, size=len(x_new))
302 |         scaler_type = "StandardScaler"
303 |     elif noise_type == "random":
304 |         x_new["random_feature"] = np.random.rand(len(x_new))
305 |         scaler_type = "MinMaxScaler"
306 |     else:
307 |         raise ValueError("Allowed values for noise_type are gaussian and random")
308 |
309 |     if splitting_type == "kfold":
310 |         df_coef = train_with_kfold_splitting(
311 |             x_new, labels, model, scaler_type, verbose, random_state
312 |         )
313 |     elif splitting_type == "simple":
314 |         df_coef = train_with_simple_splitting(
315 |             x_new, labels, model, scaler_type, verbose, random_state
316 |         )
317 |     else:
318 |         raise ValueError("Choice not recognized. Possible choices are kfold or simple")
319 |
320 |     simplified_dataset = select_relevant_features(df_coef, x_new, verbose)
321 |
322 |     return simplified_dataset
323 |
324 |
325 | def get_relevant_features(
326 |     features: pd.DataFrame,
327 |     labels: pd.DataFrame,
328 |     model: BaseEstimator,
329 |     splitting_type: str,
330 |     epochs: int,
331 |     patience: int,
332 |     noise_type: str = "gaussian",
333 |     verbose: bool = True,
334 |     filename_output: Optional[str] = None,
335 |     random_state: int = 42,
336 | ) -> pd.DataFrame:
337 |     """This function performs multiple cycles to reduce the dimensionality of the dataset.
338 |
339 |     Parameters:
340 |     - features: the matrix with features, commonly called X
341 |     - labels: the vector with labels, commonly called y
342 |     - model: an untrained scikit-learn model
343 |     - splitting_type: choose between "simple" (80% train, 20% test)
344 |       or "kfold" (5-fold splitting)
345 |     - epochs: the number of epochs (or cycles)
346 |     - patience: the number of consecutive epochs without improvement to wait before
347 |       stopping the execution of the code
348 |     - noise_type: choose between "gaussian" noise or "random" (flat) noise
349 |     - verbose: True or False, to tune the level of verbosity
350 |     - filename_output: path where the simplified dataset is saved, if you want to export it; default is None
351 |     - random_state: select the random seed
352 |
353 |     Return:
354 |     - the dataset simplified after multiple epochs of feature selection
355 |     """
356 |
357 |     x_new = features.copy(deep=True)
358 |     counter_patience = 0
359 |     epoch = 0
360 |
361 |     np.random.seed(random_state)
362 |     random_states = np.random.randint(1, int(10 * epochs), size=epochs)
363 |
364 |     while (counter_patience < patience) and (epoch < epochs):
365 |         n_features_before = x_new.shape[1]
366 |         print("=====================EPOCH", epoch + 1, "=====================")
367 |         x_new = scan_features_pipeline(
368 |             x_new,
369 |             labels,
370 |             model,
371 |             splitting_type,
372 |             verbose,
373 |             random_states[epoch],
374 |             noise_type,
375 |         )
376 |         n_features_after = x_new.shape[1]
377 |
378 |         if n_features_before == n_features_after:
379 |             counter_patience += 1
380 |             print(
381 |                 "The feature selection did not improve in the last",
382 |                 counter_patience,
383 |                 "epochs",
384 |             )
385 |         else:
386 |             counter_patience = 0
387 |
388 |         epoch += 1
389 |
390 |     if filename_output is not None:
391 |         x_new.to_csv(filename_output, index=False)
392 |
393 |     return x_new
394 |
--------------------------------------------------------------------------------
/tests/test.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.linear_model import Lasso 4 | from sklearn.ensemble import RandomForestRegressor 5 | from src.data import create_regression_data 6 | from src.ml import ( 7 | train_evaluate_model, 8 | get_feature_importances, 9 | compute_mean_coefficients, 10 | select_relevant_features, 11 | generate_kfold_data, 12 | train_with_kfold_splitting, 13 | train_with_simple_splitting, 14 | scan_features_pipeline, 15 | get_relevant_features, 16 | ) 17 | 18 | 19 | def test_data_creation(): 20 | features, labels = create_regression_data() 21 | assert features.shape == (1000, 300), "Shape of features is wrong" 22 | assert labels.shape == (1000, 1), "Shape of labels is wrong" 23 | 24 | 25 | def test_train_evaluate(): 26 | x_train = np.random.rand(100, 10) 27 | y_train = np.random.rand(100, 1) 28 | x_test = np.random.rand(20, 10) 29 | y_test = np.random.rand(20, 1) 30 | model = Lasso() 31 | trained_model = train_evaluate_model( 32 | x_train, 33 | y_train, 34 | x_test, 35 | y_test, 36 | model, 37 | scaler_type="StandardScaler", 38 | verbose=False, 39 | ) 40 | assert ( 41 | len(trained_model.coef_) == x_train.shape[1] 42 | ), "The model is not trained properly" 43 | 44 | 45 | def test_get_features_importance(): 46 | x_train = np.random.rand(100, 10) 47 | y_train = np.random.rand(100) 48 | column_names = np.arange(0, x_train.shape[1]) 49 | 50 | lasso = Lasso() 51 | lasso.fit(x_train, y_train) 52 | df_coef = get_feature_importances(lasso, column_names) 53 | assert ( 54 | type(df_coef) == pd.DataFrame 55 | ), "The table with feature importance must be a DataFrame" 56 | assert ( 57 | len(df_coef) == x_train.shape[1] 58 | ), "The number of coefficients does not match the shape of the training data" 59 | 60 | rf = RandomForestRegressor() 61 | rf.fit(x_train, y_train) 62 | df_coef = get_feature_importances(rf, column_names) 63 | assert ( 64 | type(df_coef) == pd.DataFrame 65 | ), "The table with feature importance must be a DataFrame" 66 | assert ( 67 | len(df_coef) == x_train.shape[1] 68 | ), "The number of coefficients does not match the shape of the training data" 69 | 70 | 71 | def test_mean_coefficients_single_column(): 72 | feature_importance = np.random.randint(-100, 100, size=20) 73 | df = pd.DataFrame(feature_importance, index=np.arange(0, len(feature_importance))) 74 | vec = np.sort(np.abs(feature_importance))[::-1] 75 | df_sorted = compute_mean_coefficients(df) 76 | assert all( 77 | np.array(df_sorted["Feature importance"]) == vec 78 | ), "Feature importances are not sorted properly" 79 | 80 | 81 | def test_mean_coefficients_multiple_columns(): 82 | feature_importance = 2 * np.random.rand(100, 5) - 1 83 | df = pd.DataFrame(feature_importance, index=np.arange(0, len(feature_importance))) 84 | vec = np.sort(np.abs(df.mean(axis=1)))[::-1] 85 | df_sorted = compute_mean_coefficients(df) 86 | assert all( 87 | np.array(df_sorted["Feature importance"]) == vec 88 | ), "Feature importances are not sorted properly" 89 | 90 | 91 | def test_select_relevant_features(): 92 | df_coef = pd.DataFrame([5, 4, 3, 2, 1], columns=["Feature importance"]) 93 | df_coef["Feature name"] = ["col1", "col2", "random_feature", "col3", "col4"] 94 | features = pd.DataFrame( 95 | np.random.rand(10, 5), 96 | columns=["col1", "col2", "random_feature", "col3", "col4"], 97 | ) 98 | feature_selected = select_relevant_features(df_coef, features, verbose=True) 99 | assert all(feature_selected.columns == ["col1", 
"col2"]), "Wrong columns selected" 100 | 101 | 102 | def test_kfold_splitting(): 103 | features = pd.DataFrame(np.random.rand(100, 10)) 104 | labels = pd.DataFrame(np.random.rand(100)) 105 | x_trains, y_trains, x_tests, y_tests = generate_kfold_data( 106 | features, labels, random_state=42 107 | ) 108 | assert len(x_trains) == 5, "Length of train features is wrong" 109 | assert len(x_tests) == 5, "Length of test features is wrong" 110 | assert len(y_trains) == 5, "Length of train labels is wrong" 111 | assert len(y_tests) == 5, "Length of test labels is wrong" 112 | 113 | 114 | def test_train_kfold_splitting(): 115 | features = pd.DataFrame(np.random.rand(100, 10)) 116 | labels = pd.DataFrame(np.random.rand(100)) 117 | model = Lasso() 118 | df_coef = train_with_kfold_splitting( 119 | features, 120 | labels, 121 | model, 122 | scaler_type="StandardScaler", 123 | verbose=True, 124 | random_state=42, 125 | ) 126 | assert type(df_coef) == pd.DataFrame, "df_coef must be a Pandas DataFrame" 127 | assert ( 128 | len(df_coef) == features.shape[1] 129 | ), "The length of df_coef must match the number of features" 130 | 131 | 132 | def test_train_simple_splitting(): 133 | features = pd.DataFrame(np.random.rand(100, 10)) 134 | labels = pd.DataFrame(np.random.rand(100)) 135 | model = Lasso() 136 | df_coef = train_with_simple_splitting( 137 | features, 138 | labels, 139 | model, 140 | scaler_type="MinMaxScaler", 141 | verbose=True, 142 | random_state=42, 143 | ) 144 | assert type(df_coef) == pd.DataFrame, "df_coef must be a Pandas DataFrame" 145 | assert ( 146 | len(df_coef) == features.shape[1] 147 | ), "The length of df_coef must match the number of features" 148 | 149 | 150 | def test_scan_feature_pipeline(): 151 | features, labels = create_regression_data() 152 | model = Lasso() 153 | reduced_features = scan_features_pipeline( 154 | features, 155 | labels, 156 | model, 157 | splitting_type="simple", 158 | verbose=False, 159 | random_state=43, 160 | noise_type="gaussian", 161 | ) 162 | assert ( 163 | reduced_features.shape[1] < features.shape[1] 164 | ), "The pipeline did not reduce the number of features" 165 | 166 | reduced_features = scan_features_pipeline( 167 | features, 168 | labels, 169 | model, 170 | splitting_type="kfold", 171 | verbose=False, 172 | random_state=43, 173 | noise_type="random", 174 | ) 175 | assert ( 176 | reduced_features.shape[1] < features.shape[1] 177 | ), "The pipeline did not reduce the number of features" 178 | 179 | 180 | def test_get_relevant_features(): 181 | features, labels = create_regression_data() 182 | model = Lasso() 183 | 184 | x_new = get_relevant_features( 185 | features, 186 | labels, 187 | model, 188 | splitting_type="simple", 189 | epochs=10, 190 | patience=5, 191 | random_state=41, 192 | ) 193 | 194 | assert ( 195 | x_new.shape[1] < features.shape[1] 196 | ), "The pipeline did not reduce the number of features" 197 | 198 | x_new = get_relevant_features( 199 | features, 200 | labels, 201 | model, 202 | splitting_type="kfold", 203 | epochs=10, 204 | patience=5, 205 | random_state=41, 206 | ) 207 | 208 | assert ( 209 | x_new.shape[1] < features.shape[1] 210 | ), "The pipeline did not reduce the number of features" 211 | --------------------------------------------------------------------------------