├── .gitattributes
├── .gitignore
├── README.md
├── config
│   ├── hybrid-fiq.yaml
│   ├── hybrid-piq.yaml
│   └── hybrid-viq.yaml
├── data
│   └── ABIDE
│       └── Rest
│           ├── data.h5
│           ├── test.split
│           └── train.split
├── environment.yml
├── main.py
├── scripts
│   ├── download_and_preprocess_ABIDE.py
│   └── subject_ID.txt
└── src
    ├── __init__.py
    ├── cpm.py
    ├── dataset.py
    ├── hybrid
    │   ├── __init__.py
    │   └── model.py
    └── model.py
/.gitattributes: -------------------------------------------------------------------------------- 1 | *.h5 filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | outputs 2 | data/ABCD* 3 | 4 | *.DS_Store 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | *.png 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 166 | #.idea/ 167 | .vscode 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🔥 [ICML 2024] Learning High-Order Relationships of Brain Regions 2 | 3 | This is the official implementation of the ICML 2024 paper [Learning High-Order Relationships of Brain Regions](https://arxiv.org/abs/2312.02203). 4 | 5 | 6 | ## Installation 7 | Make sure you have [git-lfs](https://git-lfs.com/) installed in order to [clone](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) the preprocessed dataset. Alternatively, you can download it from [Google Drive](https://drive.google.com/drive/folders/1SvhOlPAIHVX4AYy-hU9Ik7-lKX7u1Ti2?usp=sharing). 8 | After cloning the repo, please check the provided `environment.yml` to install the [conda environment](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-from-an-environment-yml-file). 9 | ``` 10 | conda env create -f environment.yml 11 | ``` 12 | 13 | ### Dataset 14 | Besides directly downloading the preprocessed data as described above, you can use `scripts/download_and_preprocess_ABIDE.py` to download and preprocess the ABIDE dataset yourself. 15 | ``` 16 | python scripts/download_and_preprocess_ABIDE.py /path/to/your/output.h5 17 | ``` 18 | 19 | ## Usage 20 | ### If you want to integrate HyBRiD into your own project 21 | Copy the folder `src/hybrid` to your local storage, and use it by `from hybrid import HyBRiD`. Check the docstring and type hints in the file for more details. A minimal usage sketch is also provided at the end of this README. 22 | 23 | **Note:** 24 | The repo is implemented in `python3.10` and uses the new typing convention (e.g. `list[int]` instead of `List[int]`), so it is not backward compatible with older Python versions. However, adapting it to a lower version is straightforward. 25 | 26 | ### If you want to run our experiment 27 | Make sure you follow the guidance in the **Installation** section and run the following command: 28 | ```shell 29 | python main.py -c config/hybrid-piq.yaml 30 | ``` 31 | This will train the model and report the metrics on the ABIDE PIQ task.
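For integrating HyBRiD into your own project, the following minimal sketch (not part of the repo) may help: the constructor arguments are copied from the `init_args` in the provided `config/hybrid-*.yaml` files, while the dummy input shape and the way the output is consumed are assumptions for illustration only. The docstring and type hints in `src/hybrid/model.py` remain the authoritative reference.

```python
# Minimal integration sketch. Constructor arguments mirror config/hybrid-piq.yaml;
# the input tensor shape (batch, n_nodes, timepoints) and the output handling are
# assumptions -- check the docstring/type hints in src/hybrid/model.py.
import torch

from hybrid import HyBRiD  # after copying src/hybrid into your project

model = HyBRiD(n_nodes=200, n_hypers=32, hidden_size=256, dropout=0.1)

x = torch.randn(4, 200, 100)  # hypothetical batch of parcellated BOLD time series
output = model(x)
print(type(output))
```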
32 | -------------------------------------------------------------------------------- /config/hybrid-fiq.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 43 2 | trainer: 3 | logger: 4 | class_path: lightning.pytorch.loggers.WandbLogger 5 | init_args: 6 | project: "hybrid" 7 | save_dir: "outputs" 8 | max_epochs: 250 9 | check_val_every_n_epoch: 10 10 | accelerator: "auto" 11 | gradient_clip_val: 0.5 12 | deterministic: false 13 | 14 | callbacks: 15 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 16 | init_args: 17 | save_last: true 18 | data: 19 | class_path: src.dataset.BrainDataModule 20 | init_args: 21 | dataset_keys: 22 | - ABIDE-Rest 23 | y_key: "fiq" 24 | batch_size: 4 25 | num_workers: 8 26 | model: 27 | class_path: src.model.RegressionModule 28 | init_args: 29 | learning_rate: 0.0001 30 | weight_decay: 0.01 31 | beta: 0.2 32 | model: 33 | class_path: src.hybrid.model.HyBRiD 34 | init_args: 35 | n_nodes: 200 36 | n_hypers: 32 37 | hidden_size: 256 38 | dropout: 0.1 39 | -------------------------------------------------------------------------------- /config/hybrid-piq.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 43 2 | trainer: 3 | logger: 4 | class_path: lightning.pytorch.loggers.WandbLogger 5 | init_args: 6 | project: "hybrid" 7 | save_dir: "outputs" 8 | max_epochs: 250 9 | check_val_every_n_epoch: 10 10 | accelerator: "auto" 11 | gradient_clip_val: 0.5 12 | deterministic: false 13 | 14 | callbacks: 15 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 16 | init_args: 17 | save_last: true 18 | data: 19 | class_path: src.dataset.BrainDataModule 20 | init_args: 21 | dataset_keys: 22 | - ABIDE-Rest 23 | y_key: "piq" 24 | batch_size: 8 25 | num_workers: 8 26 | model: 27 | class_path: src.model.RegressionModule 28 | init_args: 29 | learning_rate: 0.0001 30 | weight_decay: 0.01 31 | beta: 0.2 32 | model: 33 | class_path: src.hybrid.model.HyBRiD 34 | init_args: 35 | n_nodes: 200 36 | n_hypers: 32 37 | hidden_size: 256 38 | dropout: 0.1 39 | -------------------------------------------------------------------------------- /config/hybrid-viq.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 43 2 | trainer: 3 | logger: 4 | class_path: lightning.pytorch.loggers.WandbLogger 5 | init_args: 6 | project: "hybrid" 7 | save_dir: "outputs" 8 | max_epochs: 100 9 | check_val_every_n_epoch: 10 10 | accelerator: "auto" 11 | gradient_clip_val: 0.5 12 | deterministic: false 13 | 14 | callbacks: 15 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 16 | init_args: 17 | save_last: true 18 | data: 19 | class_path: src.dataset.BrainDataModule 20 | init_args: 21 | dataset_keys: 22 | - ABIDE-Rest 23 | y_key: "piq" 24 | batch_size: 4 25 | num_workers: 8 26 | model: 27 | class_path: src.model.RegressionModule 28 | init_args: 29 | learning_rate: 0.0001 30 | weight_decay: 0.01 31 | beta: 0.2 32 | model: 33 | class_path: src.hybrid.model.HyBRiD 34 | init_args: 35 | n_nodes: 200 36 | n_hypers: 32 37 | hidden_size: 256 38 | dropout: 0.1 39 | -------------------------------------------------------------------------------- /data/ABIDE/Rest/data.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:49b15d9938988fe9dc9b10934a2ade04f771c3a4114624a45431ad23143dc00f 3 | size 655944960 4 | 
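The LFS pointer above resolves to the preprocessed HDF5 file. Based on how `preprocess()` in `scripts/download_and_preprocess_ABIDE.py` writes it (one group per subject ID containing `pearson`, `x`, `fiq`, `viq`, and `piq`), a small inspection sketch looks like this; the shapes in the comments follow the CC200 atlas used by the script.

```python
# Sketch for inspecting data/ABIDE/Rest/data.h5; layout taken from
# preprocess() in scripts/download_and_preprocess_ABIDE.py.
import h5py

with h5py.File("data/ABIDE/Rest/data.h5", "r") as f:
    subject = next(iter(f))        # group names are subject IDs
    grp = f[subject]
    pearson = grp["pearson"][()]   # (200, 200) Pearson connectivity matrix
    x = grp["x"][()]               # (200, timepoints) region-wise BOLD time series
    piq = float(grp["piq"][()])    # performance IQ target (fiq/viq also stored)
    print(subject, pearson.shape, x.shape, piq)
```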
-------------------------------------------------------------------------------- /data/ABIDE/Rest/test.split: -------------------------------------------------------------------------------- 1 | 50003 2 | 50310 3 | 50324 4 | 51111 5 | 51074 6 | 50556 7 | 50786 8 | 51151 9 | 50726 10 | 51140 11 | 51480 12 | 50608 13 | 50331 14 | 50040 15 | 50160 16 | 51478 17 | 51257 18 | 50797 19 | 50164 20 | 51215 21 | 50023 22 | 51153 23 | 50659 24 | 50118 25 | 50755 26 | 50322 27 | 50340 28 | 50667 29 | 51155 30 | 50124 31 | 50131 32 | 51191 33 | 50663 34 | 51562 35 | 50058 36 | 50213 37 | 50998 38 | 51044 39 | 51216 40 | 50802 41 | 50283 42 | 50781 43 | 51129 44 | 51210 45 | 50553 46 | 51607 47 | 50687 48 | 51128 49 | 50710 50 | 50109 51 | 50204 52 | 50274 53 | 51066 54 | 50496 55 | 50610 56 | 50785 57 | 50426 58 | 50007 59 | 50728 60 | 51234 61 | 51055 62 | 50368 63 | 50262 64 | 50954 65 | 50964 66 | 50725 67 | 51199 68 | 51231 69 | 50168 70 | 51010 71 | 50555 72 | 51295 73 | 51212 74 | 51103 75 | 51296 76 | 51564 77 | 51326 78 | 50211 79 | 50789 80 | 50144 81 | 50374 82 | 50406 83 | 50390 84 | 51134 85 | 50339 86 | 50440 87 | 51320 88 | 51006 89 | 50353 90 | 50405 91 | 50235 92 | 50993 93 | 50753 94 | 50285 95 | 51152 96 | 50488 97 | 50757 98 | 50252 99 | 50698 100 | 50327 101 | 51576 102 | 50561 103 | 51058 104 | 50500 105 | 50319 106 | 50494 107 | 50959 108 | 50605 109 | 50527 110 | 50609 111 | 50338 112 | 51175 113 | 51075 114 | 51359 115 | 50198 116 | 51123 117 | 51293 118 | 50260 119 | 50038 120 | 51263 121 | 50523 122 | 51581 123 | 51164 124 | 50563 125 | 51307 126 | 51266 127 | 51116 128 | 51018 129 | 50271 130 | 50385 131 | 50560 132 | 51457 133 | 50438 134 | 50059 135 | 51052 136 | 51223 137 | 50102 138 | 50955 139 | 50362 140 | 51170 141 | 50650 142 | 51211 143 | 50746 144 | 51094 145 | 51330 146 | 51493 147 | 50519 148 | 50705 149 | 50106 150 | 51173 151 | 51201 152 | 51090 153 | 50266 154 | 50149 155 | 50288 156 | 50214 157 | 50497 158 | 51107 159 | 51061 160 | 51323 161 | 50995 162 | 50615 163 | 51189 164 | 51477 165 | 51146 166 | 50779 167 | 51357 168 | 50125 169 | 50056 170 | 50643 171 | 50706 172 | 50035 173 | 51203 174 | 50822 175 | 51325 176 | 50814 177 | 50237 178 | 50655 179 | 51183 180 | 50206 181 | 50276 182 | 51345 183 | 51492 184 | 50509 185 | 50359 186 | 51112 187 | 51091 188 | 50012 189 | 50194 190 | 50294 191 | 51573 192 | 50036 193 | 50415 194 | 50402 195 | 50961 196 | 51316 197 | 50532 198 | 50807 199 | 50242 200 | 51559 201 | 51458 202 | 51355 203 | 51214 204 | 50685 205 | 50477 206 | 50515 207 | 50349 208 | -------------------------------------------------------------------------------- /data/ABIDE/Rest/train.split: -------------------------------------------------------------------------------- 1 | 50699 2 | 50694 3 | 51221 4 | 50268 5 | 50006 6 | 50041 7 | 50815 8 | 51073 9 | 50295 10 | 50653 11 | 51482 12 | 51139 13 | 50145 14 | 50727 15 | 51029 16 | 51098 17 | 50163 18 | 50453 19 | 50216 20 | 50651 21 | 51258 22 | 50159 23 | 50013 24 | 50234 25 | 51100 26 | 51163 27 | 50756 28 | 50660 29 | 50162 30 | 50642 31 | 50364 32 | 51275 33 | 50485 34 | 50199 35 | 50604 36 | 50489 37 | 51197 38 | 50360 39 | 50263 40 | 51011 41 | 51101 42 | 50169 43 | 51265 44 | 50156 45 | 50046 46 | 50212 47 | 50019 48 | 50264 49 | 50689 50 | 51039 51 | 51346 52 | 50284 53 | 51461 54 | 51034 55 | 51237 56 | 51340 57 | 50250 58 | 50308 59 | 51076 60 | 50825 61 | 51057 62 | 51568 63 | 50030 64 | 51305 65 | 51067 66 | 50414 67 | 50774 68 | 50315 69 | 50350 70 | 51229 71 | 50386 72 | 50104 73 | 
51124 74 | 50791 75 | 51370 76 | 51069 77 | 50411 78 | 50480 79 | 51491 80 | 51138 81 | 50487 82 | 51565 83 | 51065 84 | 50623 85 | 51121 86 | 50306 87 | 50166 88 | 50387 89 | 51142 90 | 51161 91 | 50290 92 | 50463 93 | 50750 94 | 50818 95 | 51117 96 | 50383 97 | 50265 98 | 50342 99 | 51225 100 | 50372 101 | 50992 102 | 51277 103 | 50251 104 | 50428 105 | 51114 106 | 51150 107 | 50745 108 | 50622 109 | 51027 110 | 50481 111 | 50708 112 | 51042 113 | 51020 114 | 50688 115 | 50334 116 | 51238 117 | 50325 118 | 51557 119 | 51240 120 | 51070 121 | 50376 122 | 50965 123 | 50015 124 | 51474 125 | 50612 126 | 50191 127 | 51304 128 | 50188 129 | 51135 130 | 50991 131 | 51300 132 | 50055 133 | 50570 134 | 51560 135 | 51198 136 | 51030 137 | 50744 138 | 51228 139 | 51016 140 | 50960 141 | 50620 142 | 50987 143 | 51299 144 | 50049 145 | 50377 146 | 51224 147 | 50686 148 | 51051 149 | 51082 150 | 50133 151 | 50272 152 | 50996 153 | 50289 154 | 50120 155 | 51485 156 | 51019 157 | 50812 158 | 51308 159 | 50298 160 | 51089 161 | 50412 162 | 51279 163 | 50418 164 | 51466 165 | 50182 166 | 51171 167 | 50148 168 | 51472 169 | 50983 170 | 50749 171 | 50135 172 | 50603 173 | 51239 174 | 50370 175 | 51241 176 | 50397 177 | 50754 178 | 51205 179 | 50490 180 | 50114 181 | 51301 182 | 50425 183 | 51071 184 | 51566 185 | 50103 186 | 50314 187 | 51193 188 | 51319 189 | 50217 190 | 50788 191 | 50355 192 | 51167 193 | 51118 194 | 50690 195 | 50524 196 | 50646 197 | 50736 198 | 51001 199 | 51180 200 | 50737 201 | 50281 202 | 51322 203 | 51579 204 | 50367 205 | 51317 206 | 51577 207 | 50656 208 | 50042 209 | 50571 210 | 50011 211 | 50443 212 | 50738 213 | 50060 214 | 51343 215 | 50486 216 | 50029 217 | 51113 218 | 51147 219 | 50033 220 | 50126 221 | 50518 222 | 51486 223 | 51021 224 | 50245 225 | 50379 226 | 50800 227 | 50793 228 | 50320 229 | 50107 230 | 50130 231 | 51102 232 | 50772 233 | 50128 234 | 50115 235 | 50968 236 | 50783 237 | 51318 238 | 51219 239 | 50291 240 | 51047 241 | 51174 242 | 51169 243 | 50183 244 | 50236 245 | 51165 246 | 51032 247 | 51264 248 | 50207 249 | 50665 250 | 51462 251 | 51080 252 | 50530 253 | 51190 254 | 51186 255 | 51096 256 | 50780 257 | 51053 258 | 50491 259 | 51578 260 | 51278 261 | 51012 262 | 50492 263 | 50657 264 | 51083 265 | 50205 266 | 51254 267 | 50439 268 | 50601 269 | 51563 270 | 51364 271 | 51084 272 | 51110 273 | 51351 274 | 50644 275 | 51267 276 | 51353 277 | 51213 278 | 50127 279 | 50711 280 | 51227 281 | 50408 282 | 50171 283 | 50146 284 | 51217 285 | 50318 286 | 50239 287 | 50821 288 | 51063 289 | 51255 290 | 51054 291 | 50196 292 | 50280 293 | 50348 294 | 50986 295 | 50967 296 | 50647 297 | 50695 298 | 50241 299 | 50399 300 | 50652 301 | 50747 302 | 51025 303 | 50170 304 | 51314 305 | 50658 306 | 50142 307 | 50202 308 | 50111 309 | 51148 310 | 50192 311 | 50005 312 | 50748 313 | 50696 314 | 51104 315 | 51556 316 | 50782 317 | 50529 318 | 50361 319 | 50014 320 | 50053 321 | 51347 322 | 51294 323 | 50503 324 | 50564 325 | 50044 326 | 50483 327 | 50121 328 | 50034 329 | 51465 330 | 50693 331 | 51460 332 | 50215 333 | 51206 334 | 50388 335 | 50313 336 | 51208 337 | 51481 338 | 51002 339 | 50979 340 | 50978 341 | 50470 342 | 50321 343 | 50105 344 | 50707 345 | 51166 346 | 51003 347 | 50261 348 | 50554 349 | 51574 350 | 50253 351 | 50803 352 | 51184 353 | 51109 354 | 50790 355 | 50273 356 | 50363 357 | 50299 358 | 50150 359 | 50819 360 | 51207 361 | 50022 362 | 51298 363 | 51078 364 | 50625 365 | 51160 366 | 50255 367 | 50025 368 | 50407 369 | 51327 370 | 50110 371 | 
51188 372 | 50436 373 | 51040 374 | 51230 375 | 51575 376 | 51280 377 | 50468 378 | 51226 379 | 51580 380 | 51324 381 | 51483 382 | 50433 383 | 50823 384 | 50152 385 | 50624 386 | 50446 387 | 50043 388 | 50502 389 | 50010 390 | 51085 391 | 51149 392 | 50982 393 | 51008 394 | 51187 395 | 51250 396 | 50445 397 | 51572 398 | 50573 399 | 50032 400 | 50333 401 | 51023 402 | 50568 403 | 50956 404 | 51038 405 | 50972 406 | 51268 407 | 51360 408 | 51204 409 | 51172 410 | 50045 411 | 50382 412 | 50193 413 | 50371 414 | 50441 415 | 50050 416 | 50004 417 | 50365 418 | 51253 419 | 50017 420 | 50435 421 | 50243 422 | 51262 423 | 50990 424 | 50292 425 | 51363 426 | 51122 427 | 51273 428 | 51136 429 | 50989 430 | 51050 431 | 50621 432 | 51356 433 | 50505 434 | 51479 435 | 51028 436 | 50702 437 | 50559 438 | 51584 439 | 50048 440 | 51024 441 | 51292 442 | 50380 443 | 51046 444 | 50358 445 | 51373 446 | 51281 447 | 50419 448 | 51291 449 | 50422 450 | 50158 451 | 50973 452 | 50332 453 | 50981 454 | 50787 455 | 51009 456 | 50421 457 | 51282 458 | 51086 459 | 51469 460 | 51007 461 | 50525 462 | 50499 463 | 51488 464 | 50526 465 | 50552 466 | 51261 467 | 50562 468 | 51272 469 | 50627 470 | 50197 471 | 50247 472 | 50962 473 | 50574 474 | 50203 475 | 50008 476 | 50300 477 | 50799 478 | 51274 479 | 51195 480 | 50316 481 | 50161 482 | 50469 483 | 50801 484 | 50326 485 | 51178 486 | 50997 487 | 51154 488 | 50775 489 | 51309 490 | 50257 491 | 50208 492 | 50335 493 | 50329 494 | 50668 495 | 50692 496 | 50984 497 | 51459 498 | 50112 499 | 50700 500 | 50566 501 | 50336 502 | 51334 503 | 50703 504 | 50352 505 | 51192 506 | 50732 507 | 50970 508 | 50575 509 | 50275 510 | 51202 511 | 50613 512 | 51570 513 | 50776 514 | 51095 515 | 51236 516 | 51276 517 | 50778 518 | 50507 519 | 50184 520 | 51349 521 | 51220 522 | 51196 523 | 50132 524 | 51141 525 | 51312 526 | 50391 527 | 50201 528 | 50369 529 | 50467 530 | 51159 531 | 50511 532 | 50031 533 | 50617 534 | 50691 535 | 51072 536 | 51015 537 | 51045 538 | 50577 539 | 50248 540 | 50493 541 | 50427 542 | 51335 543 | 50424 544 | 50293 545 | 50504 546 | 50444 547 | 50434 548 | 50286 549 | 50051 550 | 50210 551 | 51087 552 | 51348 553 | 51017 554 | 51487 555 | 50287 556 | 50448 557 | 51344 558 | 51473 559 | 50666 560 | 50664 561 | 50498 562 | 50551 563 | 51260 564 | 51338 565 | 50356 566 | 50682 567 | 50410 568 | 50189 569 | 51463 570 | 51168 571 | 51097 572 | 51060 573 | 51064 574 | 50123 575 | 51088 576 | 51068 577 | 51252 578 | 50269 579 | 50565 580 | 51014 581 | 50730 582 | 50735 583 | 50773 584 | 51365 585 | 51099 586 | 50134 587 | 50259 588 | 51311 589 | 50824 590 | 51013 591 | 51126 592 | 50516 593 | 50246 594 | 50437 595 | 50722 596 | 51181 597 | 50301 598 | 50232 599 | 50249 600 | 50796 601 | 50449 602 | 51248 603 | 50020 604 | 51606 605 | 50826 606 | 51333 607 | 50167 608 | 51468 609 | 50282 610 | 50375 611 | 51341 612 | 50611 613 | 50195 614 | 51177 615 | 50741 616 | 51036 617 | 50619 618 | 50794 619 | 50704 620 | 50337 621 | 50297 622 | 50521 623 | 51313 624 | 50404 625 | 50628 626 | 51131 627 | 51059 628 | 50816 629 | 50143 630 | 50209 631 | 50416 632 | 51476 633 | 50136 634 | 50558 635 | 51350 636 | 50302 637 | 50366 638 | 50792 639 | 50009 640 | 51361 641 | 50190 642 | 50777 643 | 51000 644 | 50966 645 | 50724 646 | 50743 647 | 50233 648 | 50957 649 | 51329 650 | 50482 651 | 51271 652 | 51470 653 | 51471 654 | 51569 655 | 51331 656 | 51048 657 | 51194 658 | 51558 659 | 51041 660 | 50305 661 | 50654 662 | 51035 663 | 51582 664 | 51079 665 | 50442 666 | 50147 667 | 
51269 668 | 50119 669 | 50277 670 | 51162 671 | 50157 672 | 50304 673 | 51339 674 | 50117 675 | 50795 676 | 50648 677 | 51182 678 | 51328 679 | 50413 680 | 50296 681 | 50804 682 | 50977 683 | 51369 684 | 51336 685 | 51137 686 | 51342 687 | 51456 688 | 50312 689 | 50417 690 | 50602 691 | 50343 692 | 51132 693 | 50113 694 | 51567 695 | 51179 696 | 50501 697 | 50016 698 | 50798 699 | 51302 700 | 51256 701 | 51033 702 | 50618 703 | 50279 704 | 50731 705 | 51106 706 | 51583 707 | 51081 708 | 50733 709 | 51130 710 | 51585 711 | 50381 712 | 50026 713 | 51218 714 | 51484 715 | 51156 716 | 50557 717 | 50346 718 | 50531 719 | 50024 720 | 51315 721 | 50052 722 | 50976 723 | 50567 724 | 50200 725 | 51467 726 | 50661 727 | 50027 728 | 51306 729 | 50697 730 | 50739 731 | 50742 732 | 50528 733 | 50122 734 | 50185 735 | 50344 736 | 50514 737 | 50614 738 | 51056 739 | 50606 740 | 51362 741 | 50952 742 | 50057 743 | 50116 744 | 51062 745 | 50129 746 | 50607 747 | 50403 748 | 50751 749 | 50569 750 | 50466 751 | 51464 752 | 50303 753 | 50817 754 | 50994 755 | 51026 756 | 50447 757 | 50317 758 | 50683 759 | 50510 760 | 50752 761 | 50330 762 | 50153 763 | 50649 764 | 50520 765 | 50958 766 | 50341 767 | 50616 768 | 50576 769 | 50373 770 | 50969 771 | 50985 772 | 51354 773 | 50740 774 | 50307 775 | 50267 776 | 50270 777 | 51297 778 | 51127 779 | 50988 780 | 50572 781 | 50028 782 | 50709 783 | 50187 784 | 51222 785 | 51321 786 | 51561 787 | 50999 788 | 51185 789 | 51077 790 | 50354 791 | 50278 792 | 51490 793 | 51133 794 | 50186 795 | 51251 796 | 51332 797 | 50054 798 | 51249 799 | 50311 800 | 51093 801 | 50669 802 | 50701 803 | 50626 804 | 50254 805 | 51049 806 | 51352 807 | 50347 808 | 51358 809 | 51209 810 | 51489 811 | 51571 812 | 50345 813 | 50357 814 | 50578 815 | 50240 816 | 50351 817 | 50974 818 | 51105 819 | 50037 820 | 50820 821 | 50645 822 | 50455 823 | 51303 824 | 50723 825 | 51235 826 | 50047 827 | 50784 828 | 50039 829 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: hybrid 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - bioconda 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=2_gnu 11 | - aom=3.9.1=hac33072_0 12 | - blas=1.0=mkl 13 | - brotli-python=1.1.0=py310hc6cd4ac_1 14 | - bzip2=1.0.8=hd590300_5 15 | - ca-certificates=2024.7.4=hbcca054_0 16 | - cairo=1.18.0=hbb29018_2 17 | - certifi=2024.7.4=pyhd8ed1ab_0 18 | - cffi=1.16.0=py310h2fee648_0 19 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 20 | - cuda-cudart=12.1.105=0 21 | - cuda-cupti=12.1.105=0 22 | - cuda-libraries=12.1.0=0 23 | - cuda-nvrtc=12.1.105=0 24 | - cuda-nvtx=12.1.105=0 25 | - cuda-opencl=12.5.39=0 26 | - cuda-runtime=12.1.0=0 27 | - cuda-version=12.5=3 28 | - dav1d=1.2.1=hd590300_0 29 | - expat=2.6.2=h59595ed_0 30 | - ffmpeg=7.0.1=gpl_h9be9148_104 31 | - filelock=3.15.4=pyhd8ed1ab_0 32 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 33 | - font-ttf-inconsolata=3.000=h77eed37_0 34 | - font-ttf-source-code-pro=2.038=h77eed37_0 35 | - font-ttf-ubuntu=0.83=h77eed37_2 36 | - fontconfig=2.14.2=h14ed4e7_0 37 | - fonts-conda-ecosystem=1=0 38 | - fonts-conda-forge=1=0 39 | - freetype=2.12.1=h267a509_2 40 | - fribidi=1.0.10=h36c2ea0_0 41 | - gettext=0.22.5=h59595ed_2 42 | - gettext-tools=0.22.5=h59595ed_2 43 | - gmp=6.3.0=hac33072_2 44 | - gmpy2=2.1.5=py310hc7909c9_1 45 | - gnutls=3.7.9=hb077bed_0 46 | - graphite2=1.3.13=h59595ed_1003 47 | - 
h2=4.1.0=pyhd8ed1ab_0 48 | - harfbuzz=9.0.0=hfac3d4d_0 49 | - hpack=4.0.0=pyh9f0ad1d_0 50 | - hyperframe=6.0.1=pyhd8ed1ab_0 51 | - icu=73.2=h59595ed_0 52 | - idna=3.7=pyhd8ed1ab_0 53 | - intel-openmp=2022.0.1=h06a4308_3633 54 | - jinja2=3.1.4=pyhd8ed1ab_0 55 | - lame=3.100=h166bdaf_1003 56 | - lcms2=2.16=hb7c19ff_0 57 | - ld_impl_linux-64=2.40=hf3520f5_7 58 | - lerc=4.0.0=h27087fc_0 59 | - libabseil=20240116.2=cxx17_h59595ed_0 60 | - libasprintf=0.22.5=h661eb56_2 61 | - libasprintf-devel=0.22.5=h661eb56_2 62 | - libass=0.17.1=h39113c1_2 63 | - libblas=3.9.0=16_linux64_mkl 64 | - libcblas=3.9.0=16_linux64_mkl 65 | - libcublas=12.1.0.26=0 66 | - libcufft=11.0.2.4=0 67 | - libcufile=1.10.1.7=0 68 | - libcurand=10.3.6.82=0 69 | - libcusolver=11.4.4.55=0 70 | - libcusparse=12.0.2.55=0 71 | - libdeflate=1.20=hd590300_0 72 | - libdrm=2.4.122=h4ab18f5_0 73 | - libexpat=2.6.2=h59595ed_0 74 | - libffi=3.4.2=h7f98852_5 75 | - libgcc-ng=14.1.0=h77fa898_0 76 | - libgettextpo=0.22.5=h59595ed_2 77 | - libgettextpo-devel=0.22.5=h59595ed_2 78 | - libglib=2.80.3=h8a4344b_1 79 | - libgomp=14.1.0=h77fa898_0 80 | - libhwloc=2.11.1=default_hecaa2ac_1000 81 | - libiconv=1.17=hd590300_2 82 | - libidn2=2.3.7=hd590300_0 83 | - libjpeg-turbo=3.0.0=hd590300_1 84 | - liblapack=3.9.0=16_linux64_mkl 85 | - libnpp=12.0.2.50=0 86 | - libnsl=2.0.1=hd590300_0 87 | - libnvjitlink=12.1.105=0 88 | - libnvjpeg=12.1.1.14=0 89 | - libopenvino=2024.2.0=h2da1b83_1 90 | - libopenvino-auto-batch-plugin=2024.2.0=hb045406_1 91 | - libopenvino-auto-plugin=2024.2.0=hb045406_1 92 | - libopenvino-hetero-plugin=2024.2.0=h5c03a75_1 93 | - libopenvino-intel-cpu-plugin=2024.2.0=h2da1b83_1 94 | - libopenvino-intel-gpu-plugin=2024.2.0=h2da1b83_1 95 | - libopenvino-intel-npu-plugin=2024.2.0=he02047a_1 96 | - libopenvino-ir-frontend=2024.2.0=h5c03a75_1 97 | - libopenvino-onnx-frontend=2024.2.0=h07e8aee_1 98 | - libopenvino-paddle-frontend=2024.2.0=h07e8aee_1 99 | - libopenvino-pytorch-frontend=2024.2.0=he02047a_1 100 | - libopenvino-tensorflow-frontend=2024.2.0=h39126c6_1 101 | - libopenvino-tensorflow-lite-frontend=2024.2.0=he02047a_1 102 | - libopus=1.3.1=h7f98852_1 103 | - libpciaccess=0.18=hd590300_0 104 | - libpng=1.6.43=h2797004_0 105 | - libprotobuf=4.25.3=h08a7969_0 106 | - libsqlite=3.46.0=hde9e2c9_0 107 | - libstdcxx-ng=14.1.0=hc0a3c3a_0 108 | - libtasn1=4.19.0=h166bdaf_0 109 | - libtiff=4.6.0=h1dd3fc0_3 110 | - libunistring=0.9.10=h7f98852_0 111 | - libuuid=2.38.1=h0b41bf4_0 112 | - libva=2.22.0=hb711507_0 113 | - libvpx=1.14.1=hac33072_0 114 | - libwebp-base=1.4.0=hd590300_0 115 | - libxcb=1.16=hd590300_0 116 | - libxcrypt=4.4.36=hd590300_1 117 | - libxml2=2.12.7=h4c95cb1_3 118 | - libzlib=1.3.1=h4ab18f5_1 119 | - llvm-openmp=15.0.7=h0cdce71_0 120 | - markupsafe=2.1.5=py310h2372a71_0 121 | - mkl=2022.1.0=hc2b9512_224 122 | - mpc=1.3.1=hfe3b2da_0 123 | - mpfr=4.2.1=h9458935_1 124 | - mpmath=1.3.0=pyhd8ed1ab_0 125 | - ncurses=6.5=h59595ed_0 126 | - nettle=3.9.1=h7ab15ed_0 127 | - networkx=3.3=pyhd8ed1ab_1 128 | - ocl-icd=2.3.2=hd590300_1 129 | - openh264=2.4.1=h59595ed_0 130 | - openjpeg=2.5.2=h488ebb8_0 131 | - openssl=3.3.1=h4ab18f5_1 132 | - p11-kit=0.24.1=hc5aa10d_0 133 | - pcre2=10.44=h0f59acf_0 134 | - pillow=10.4.0=py310hebfe307_0 135 | - pip=24.0=pyhd8ed1ab_0 136 | - pixman=0.43.2=h59595ed_0 137 | - pthread-stubs=0.4=h36c2ea0_1001 138 | - pugixml=1.14=h59595ed_0 139 | - pycparser=2.22=pyhd8ed1ab_0 140 | - pysocks=1.7.1=pyha2e5f31_6 141 | - python=3.10.14=hd12c33a_0_cpython 142 | - python_abi=3.10=4_cp310 143 | - 
pytorch=2.1.2=py3.10_cuda12.1_cudnn8.9.2_0 144 | - pytorch-cuda=12.1=ha16c6d3_5 145 | - pytorch-mutex=1.0=cuda 146 | - pyyaml=6.0.1=py310h2372a71_1 147 | - readline=8.2=h8228510_1 148 | - requests=2.32.3=pyhd8ed1ab_0 149 | - setuptools=70.3.0=pyhd8ed1ab_0 150 | - snappy=1.2.1=ha2e4443_0 151 | - svt-av1=2.1.2=hac33072_0 152 | - sympy=1.12.1=pypyh2585a3b_103 153 | - tbb=2021.12.0=h434a139_3 154 | - tk=8.6.13=noxft_h4845f30_101 155 | - torchaudio=2.1.2=py310_cu121 156 | - torchtriton=2.1.0=py310 157 | - torchvision=0.16.2=py310_cu121 158 | - typing_extensions=4.12.2=pyha770c72_0 159 | - urllib3=2.2.2=pyhd8ed1ab_1 160 | - wayland=1.23.0=h5291e77_0 161 | - wayland-protocols=1.36=hd8ed1ab_0 162 | - wheel=0.43.0=pyhd8ed1ab_1 163 | - x264=1!164.3095=h166bdaf_2 164 | - x265=3.5=h924138e_3 165 | - xorg-fixesproto=5.0=h7f98852_1002 166 | - xorg-kbproto=1.0.7=h7f98852_1002 167 | - xorg-libice=1.1.1=hd590300_0 168 | - xorg-libsm=1.2.4=h7391055_0 169 | - xorg-libx11=1.8.9=hb711507_1 170 | - xorg-libxau=1.0.11=hd590300_0 171 | - xorg-libxdmcp=1.1.3=h7f98852_0 172 | - xorg-libxext=1.3.4=h0b41bf4_2 173 | - xorg-libxfixes=5.0.3=h7f98852_1004 174 | - xorg-libxrender=0.9.11=hd590300_0 175 | - xorg-renderproto=0.11.1=h7f98852_1002 176 | - xorg-xextproto=7.3.0=h0b41bf4_1003 177 | - xorg-xproto=7.0.31=h7f98852_1007 178 | - xz=5.2.6=h166bdaf_0 179 | - yaml=0.2.5=h7f98852_2 180 | - zlib=1.3.1=h4ab18f5_1 181 | - zstandard=0.22.0=py310hab88d88_1 182 | - zstd=1.5.6=ha6fb4c9_0 183 | - pip: 184 | - absl-py==2.1.0 185 | - aiohttp==3.9.5 186 | - aiosignal==1.3.1 187 | - asttokens==2.4.1 188 | - async-timeout==4.0.3 189 | - attrs==23.2.0 190 | - autoflake==2.3.1 191 | - black==24.4.2 192 | - click==8.1.7 193 | - contourpy==1.2.1 194 | - cycler==0.12.1 195 | - decorator==5.1.1 196 | - docker-pycreds==0.4.0 197 | - docstring-parser==0.16 198 | - exceptiongroup==1.2.2 199 | - executing==2.0.1 200 | - fonttools==4.53.1 201 | - frozenlist==1.4.1 202 | - fsspec==2024.6.1 203 | - gitdb==4.0.11 204 | - gitpython==3.1.43 205 | - grpcio==1.64.1 206 | - h5py==3.11.0 207 | - importlib-resources==6.4.0 208 | - ipdb==0.13.13 209 | - ipython==8.26.0 210 | - isort==5.13.2 211 | - jedi==0.19.1 212 | - joblib==1.4.2 213 | - jsonargparse==4.31.0 214 | - kiwisolver==1.4.5 215 | - lightning==2.3.3 216 | - lightning-bolts==0.7.0 217 | - lightning-utilities==0.11.3.post0 218 | - markdown==3.6 219 | - matplotlib==3.9.1 220 | - matplotlib-inline==0.1.7 221 | - multidict==6.0.5 222 | - mypy-extensions==1.0.0 223 | - nibabel==5.2.1 224 | - numpy==1.25.2 225 | - packaging==24.1 226 | - pandas==2.2.2 227 | - parso==0.8.4 228 | - pathspec==0.12.1 229 | - pexpect==4.9.0 230 | - platformdirs==4.2.2 231 | - prompt-toolkit==3.0.47 232 | - protobuf==4.25.3 233 | - psutil==6.0.0 234 | - ptyprocess==0.7.0 235 | - pure-eval==0.2.2 236 | - pycodestyle==2.12.0 237 | - pyflakes==3.2.0 238 | - pygments==2.18.0 239 | - pyparsing==3.1.2 240 | - python-dateutil==2.9.0.post0 241 | - pytz==2024.1 242 | - scikit-learn==1.5.1 243 | - scipy==1.9.3 244 | - sentry-sdk==2.9.0 245 | - setproctitle==1.3.3 246 | - six==1.16.0 247 | - smmap==5.0.1 248 | - stack-data==0.6.3 249 | - tensorboard==2.17.0 250 | - tensorboard-data-server==0.7.2 251 | - threadpoolctl==3.5.0 252 | - tomli==2.0.1 253 | - torchmetrics==1.4.0.post0 254 | - tqdm==4.66.4 255 | - traitlets==5.14.3 256 | - typeshed-client==2.5.1 257 | - tzdata==2024.1 258 | - wandb==0.17.4 259 | - wcwidth==0.2.13 260 | - werkzeug==3.0.3 261 | - yarl==1.9.4 262 | 
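Given the pinned `cu121` PyTorch builds above, a typical setup and launch might look like the sketch below; `hybrid` is the `name` field of this `environment.yml`, and because the configs log through `WandbLogger`, you either log in to Weights & Biases first or run offline.

```shell
# Assumed workflow; adjust to your own GPU/cluster setup.
conda env create -f environment.yml
conda activate hybrid                      # environment name from environment.yml
wandb login                                # or: export WANDB_MODE=offline
python main.py -c config/hybrid-piq.yaml   # train and evaluate on the ABIDE PIQ task
```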
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from lightning.fabric.utilities.cloud_io import get_filesystem 4 | from lightning.pytorch.cli import LightningCLI, SaveConfigCallback 5 | from lightning.pytorch.loggers import Logger 6 | 7 | 8 | class SaveConfigCallback(SaveConfigCallback): 9 | def setup(self, trainer, pl_module, stage): 10 | if self.already_saved: 11 | return 12 | 13 | if self.save_to_log_dir: 14 | assert trainer.log_dir is not None 15 | log_dir = os.path.join( 16 | trainer.log_dir, trainer.logger.name, trainer.logger.version 17 | ) # this broadcasts the directory 18 | if trainer.is_global_zero and not os.path.exists(log_dir): 19 | os.makedirs(log_dir) 20 | config_path = os.path.join(log_dir, self.config_filename) 21 | fs = get_filesystem(log_dir) 22 | 23 | if not self.overwrite: 24 | # check if the file exists on rank 0 25 | file_exists = ( 26 | fs.isfile(config_path) if trainer.is_global_zero else False 27 | ) 28 | # broadcast whether to fail to all ranks 29 | file_exists = trainer.strategy.broadcast(file_exists) 30 | if file_exists: 31 | raise RuntimeError( 32 | f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" 33 | " results of a previous run. You can delete the previous config file," 34 | " set `LightningCLI(save_config_callback=None)` to disable config saving," 35 | ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.' 36 | ) 37 | 38 | if trainer.is_global_zero: 39 | fs.makedirs(log_dir, exist_ok=True) 40 | self.parser.save( 41 | self.config, 42 | config_path, 43 | skip_none=False, 44 | overwrite=self.overwrite, 45 | multifile=self.multifile, 46 | ) 47 | 48 | if trainer.is_global_zero: 49 | self.save_config(trainer, pl_module, stage) 50 | self.already_saved = True 51 | 52 | # broadcast so that all ranks are in sync on future calls to .setup() 53 | self.already_saved = trainer.strategy.broadcast(self.already_saved) 54 | 55 | def save_config(self, trainer, pl_module, stage: str) -> None: 56 | if isinstance(trainer.logger, Logger): 57 | config = self.parser.dump( 58 | self.config, skip_none=False 59 | ) # Required for proper reproducibility 60 | trainer.logger.log_hyperparams({"config": config}) 61 | 62 | 63 | class Main(LightningCLI): 64 | def add_arguments_to_parser(self, parser): 65 | parser.link_arguments( 66 | "data.entropy", "model.init_args.node_entropy", apply_on="instantiate" 67 | ) 68 | 69 | 70 | def main(): 71 | cli = Main(save_config_callback=SaveConfigCallback, run=False) 72 | cli.trainer.fit(cli.model, datamodule=cli.datamodule) 73 | 74 | cli.trainer.test(cli.model, datamodule=cli.datamodule) 75 | 76 | 77 | if __name__ == "__main__": 78 | main() 79 | -------------------------------------------------------------------------------- /scripts/download_and_preprocess_ABIDE.py: -------------------------------------------------------------------------------- 1 | # This script is modified based on https://github.com/xxlya/BrainGNN_Pytorch 2 | 3 | import os 4 | import h5py 5 | import argparse 6 | from nilearn import datasets 7 | import shutil 8 | import glob 9 | import csv 10 | import numpy as np 11 | import scipy.io as sio 12 | from nilearn import connectome 13 | from sklearn.compose import ColumnTransformer 14 | from sklearn.preprocessing import OrdinalEncoder 15 | 16 | # Input data variables 17 | code_folder = os.getcwd() 18 | 
root_folder = "./data/" 19 | data_folder = os.path.join(root_folder, "ABIDE_pcp/cpac/filt_noglobal/") 20 | if not os.path.exists(data_folder): 21 | os.makedirs(data_folder) 22 | phenotype = os.path.join(root_folder, "ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv") 23 | shutil.copyfile( 24 | os.path.join(os.path.dirname(__file__), "subject_ID.txt"), 25 | os.path.join(data_folder, "subject_IDs.txt"), 26 | ) 27 | 28 | pipeline = "cpac" 29 | atlas = "cc200" 30 | 31 | 32 | def download(): 33 | # Files to fetch 34 | files = ["rois_" + atlas] 35 | filemapping = {"func_preproc": "func_preproc.nii.gz", files[0]: files[0] + ".1D"} 36 | 37 | # Download database files 38 | datasets.fetch_abide_pcp( 39 | data_dir=root_folder, 40 | pipeline=pipeline, 41 | band_pass_filtering=True, 42 | global_signal_regression=False, 43 | derivatives=files, 44 | quality_checked=False, 45 | ) 46 | 47 | subject_IDs = get_ids() # changed path to data path 48 | subject_IDs = subject_IDs.tolist() 49 | 50 | # Create a folder for each subject 51 | for s, fname in zip(subject_IDs, fetch_filenames(subject_IDs, files[0], atlas)): 52 | subject_folder = os.path.join(data_folder, s) 53 | if not os.path.exists(subject_folder): 54 | os.mkdir(subject_folder) 55 | 56 | # Get the base filename for each subject 57 | base = fname.split(files[0])[0] 58 | 59 | # Move each subject file to the subject folder 60 | for fl in files: 61 | if not os.path.exists(os.path.join(subject_folder, base + filemapping[fl])): 62 | shutil.move(base + filemapping[fl], subject_folder) 63 | 64 | time_series = get_timeseries(subject_IDs, atlas) 65 | 66 | # Compute and save connectivity matrices 67 | subject_connectivity(time_series, subject_IDs, atlas, "correlation") 68 | subject_connectivity(time_series, subject_IDs, atlas, "partial correlation") 69 | 70 | 71 | def preprocess(args): 72 | # Get subject IDs and IQ scores 73 | subject_IDs = get_ids() 74 | fiqscores = get_subject_score(subject_IDs, score="FIQ") 75 | viqscores = get_subject_score(subject_IDs, score="VIQ") 76 | piqscores = get_subject_score(subject_IDs, score="PIQ") 77 | # Compute feature vectors (vectorised connectivity networks) 78 | fea_corr = get_networks( 79 | subject_IDs, iter_no="", kind="correlation", atlas_name=atlas 80 | ) # (1035, 200, 200) 81 | # get the time series of the subjects 82 | time_series = get_timeseries( 83 | subject_IDs, atlas_name=atlas, silence=True 84 | ) # (1035,x, 200) 85 | # we need to flip the time series one by one to match the shape of the correlation matrix #(1035, 200, x) 86 | 87 | # prepare the data for h5 file 88 | if not os.path.exists(os.path.dirname(args.output_path)): 89 | os.makedirs(os.path.dirname(args.output_path)) 90 | file = h5py.File(args.output_path, "w") 91 | for i, subject in enumerate(subject_IDs): 92 | # create a group for each subject 93 | group = file.create_group(subject) 94 | # add the feature vector 95 | group.create_dataset("pearson", data=fea_corr[i]) 96 | # add the time series 97 | group.create_dataset("x", data=time_series[i].T) 98 | # add the IQ scores 99 | group.create_dataset("fiq", data=fiqscores[subject]) 100 | group.create_dataset("viq", data=viqscores[subject]) 101 | group.create_dataset("piq", data=piqscores[subject]) 102 | 103 | file.close() 104 | 105 | 106 | def fetch_filenames(subject_IDs, file_type, atlas): 107 | """ 108 | subject_list : list of short subject IDs in string format 109 | file_type : must be one of the available file types 110 | filemapping : resulting file name format 111 | returns: 112 | filenames : list of 
filetypes (same length as subject_list) 113 | """ 114 | 115 | filemapping = { 116 | "func_preproc": "_func_preproc.nii.gz", 117 | "rois_" + atlas: "_rois_" + atlas + ".1D", 118 | } 119 | # The list to be filled 120 | filenames = [] 121 | 122 | # Fill list with requested file paths 123 | for i in range(len(subject_IDs)): 124 | path = os.path.join(data_folder, "*" + subject_IDs[i] + filemapping[file_type]) 125 | try: 126 | filenames.append(glob.glob(path)[0]) 127 | except IndexError: 128 | filenames.append("N/A") 129 | return filenames 130 | 131 | 132 | # Get timeseries arrays for list of subjects 133 | def get_timeseries(subject_list, atlas_name, silence=False): 134 | """ 135 | subject_list : list of short subject IDs in string format 136 | atlas_name : the atlas based on which the timeseries are generated e.g. aal, cc200 137 | returns: 138 | time_series : list of timeseries arrays, each of shape (timepoints x regions) 139 | """ 140 | 141 | timeseries = [] 142 | for i in range(len(subject_list)): 143 | subject_folder = os.path.join(data_folder, subject_list[i]) 144 | ro_file = [ 145 | f 146 | for f in os.listdir(subject_folder) 147 | if f.endswith("_rois_" + atlas_name + ".1D") 148 | ] 149 | fl = os.path.join(subject_folder, ro_file[0]) 150 | if silence != True: 151 | print("Reading timeseries file %s" % fl) 152 | timeseries.append(np.loadtxt(fl, skiprows=0)) 153 | 154 | return timeseries 155 | 156 | 157 | # compute connectivity matrices 158 | def subject_connectivity( 159 | timeseries, 160 | subjects, 161 | atlas_name, 162 | kind, 163 | iter_no="", 164 | seed=1234, 165 | n_subjects="", 166 | save=True, 167 | save_path=data_folder, 168 | ): 169 | """ 170 | timeseries : timeseries table for subject (timepoints x regions) 171 | subjects : subject IDs 172 | atlas_name : name of the parcellation atlas used 173 | kind : the kind of connectivity to be used, e.g. 
lasso, partial correlation, correlation 174 | iter_no : tangent connectivity iteration number for cross validation evaluation 175 | save : save the connectivity matrix to a file 176 | save_path : specify path to save the matrix if different from subject folder 177 | returns: 178 | connectivity : connectivity matrix (regions x regions) 179 | """ 180 | 181 | if kind in ["TPE", "TE", "correlation", "partial correlation"]: 182 | if kind not in ["TPE", "TE"]: 183 | conn_measure = connectome.ConnectivityMeasure(kind=kind) 184 | connectivity = conn_measure.fit_transform(timeseries) 185 | else: 186 | if kind == "TPE": 187 | conn_measure = connectome.ConnectivityMeasure(kind="correlation") 188 | conn_mat = conn_measure.fit_transform(timeseries) 189 | conn_measure = connectome.ConnectivityMeasure(kind="tangent") 190 | connectivity_fit = conn_measure.fit(conn_mat) 191 | connectivity = connectivity_fit.transform(conn_mat) 192 | else: 193 | conn_measure = connectome.ConnectivityMeasure(kind="tangent") 194 | connectivity_fit = conn_measure.fit(timeseries) 195 | connectivity = connectivity_fit.transform(timeseries) 196 | 197 | if save: 198 | if kind not in ["TPE", "TE"]: 199 | for i, subj_id in enumerate(subjects): 200 | subject_file = os.path.join( 201 | save_path, 202 | subj_id, 203 | subj_id + "_" + atlas_name + "_" + kind.replace(" ", "_") + ".mat", 204 | ) 205 | sio.savemat(subject_file, {"connectivity": connectivity[i]}) 206 | return connectivity 207 | else: 208 | for i, subj_id in enumerate(subjects): 209 | subject_file = os.path.join( 210 | save_path, 211 | subj_id, 212 | subj_id 213 | + "_" 214 | + atlas_name 215 | + "_" 216 | + kind.replace(" ", "_") 217 | + "_" 218 | + str(iter_no) 219 | + "_" 220 | + str(seed) 221 | + "_" 222 | + validation_ext 223 | + str(n_subjects) 224 | + ".mat", 225 | ) 226 | sio.savemat(subject_file, {"connectivity": connectivity[i]}) 227 | return connectivity_fit 228 | 229 | 230 | # Get the list of subject IDs 231 | 232 | 233 | def get_ids(num_subjects=None): 234 | """ 235 | return: 236 | subject_IDs : list of all subject IDs 237 | """ 238 | 239 | subject_IDs = np.genfromtxt(os.path.join(data_folder, "subject_IDs.txt"), dtype=str) 240 | 241 | if num_subjects is not None: 242 | subject_IDs = subject_IDs[:num_subjects] 243 | 244 | return subject_IDs 245 | 246 | 247 | # Get phenotype values for a list of subjects 248 | def get_subject_score(subject_list, score): 249 | scores_dict = {} 250 | 251 | with open(phenotype) as csv_file: 252 | reader = csv.DictReader(csv_file) 253 | for row in reader: 254 | if row["SUB_ID"] in subject_list: 255 | if score == "HANDEDNESS_CATEGORY": 256 | if (row[score].strip() == "-9999") or (row[score].strip() == ""): 257 | scores_dict[row["SUB_ID"]] = "R" 258 | elif row[score] == "Mixed": 259 | scores_dict[row["SUB_ID"]] = "Ambi" 260 | elif row[score] == "L->R": 261 | scores_dict[row["SUB_ID"]] = "Ambi" 262 | else: 263 | scores_dict[row["SUB_ID"]] = row[score] 264 | elif score == "FIQ" or score == "PIQ" or score == "VIQ": 265 | if (row[score].strip() == "-9999") or (row[score].strip() == ""): 266 | scores_dict[row["SUB_ID"]] = 100 267 | else: 268 | scores_dict[row["SUB_ID"]] = float(row[score]) 269 | 270 | else: 271 | scores_dict[row["SUB_ID"]] = row[score] 272 | 273 | return scores_dict 274 | 275 | 276 | # preprocess phenotypes. 
Categorical -> ordinal representation 277 | def preprocess_phenotypes(pheno_ft, params): 278 | if params["model"] == "MIDA": 279 | ct = ColumnTransformer( 280 | [("ordinal", OrdinalEncoder(), [0, 1, 2])], remainder="passthrough" 281 | ) 282 | else: 283 | ct = ColumnTransformer( 284 | [("ordinal", OrdinalEncoder(), [0, 1, 2, 3])], remainder="passthrough" 285 | ) 286 | 287 | pheno_ft = ct.fit_transform(pheno_ft) 288 | pheno_ft = pheno_ft.astype("float32") 289 | 290 | return pheno_ft 291 | 292 | 293 | # create phenotype feature vector to concatenate with fmri feature vectors 294 | def phenotype_ft_vector(pheno_ft, num_subjects, params): 295 | gender = pheno_ft[:, 0] 296 | if params["model"] == "MIDA": 297 | eye = pheno_ft[:, 0] 298 | hand = pheno_ft[:, 2] 299 | age = pheno_ft[:, 3] 300 | fiq = pheno_ft[:, 4] 301 | else: 302 | eye = pheno_ft[:, 2] 303 | hand = pheno_ft[:, 3] 304 | age = pheno_ft[:, 4] 305 | fiq = pheno_ft[:, 5] 306 | 307 | phenotype_ft = np.zeros((num_subjects, 4)) 308 | phenotype_ft_eye = np.zeros((num_subjects, 2)) 309 | phenotype_ft_hand = np.zeros((num_subjects, 3)) 310 | 311 | for i in range(num_subjects): 312 | phenotype_ft[i, int(gender[i])] = 1 313 | phenotype_ft[i, -2] = age[i] 314 | phenotype_ft[i, -1] = fiq[i] 315 | phenotype_ft_eye[i, int(eye[i])] = 1 316 | phenotype_ft_hand[i, int(hand[i])] = 1 317 | 318 | if params["model"] == "MIDA": 319 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand], axis=1) 320 | else: 321 | phenotype_ft = np.concatenate( 322 | [phenotype_ft, phenotype_ft_hand, phenotype_ft_eye], axis=1 323 | ) 324 | 325 | return phenotype_ft 326 | 327 | 328 | # Load precomputed fMRI connectivity networks 329 | def get_networks( 330 | subject_list, 331 | kind, 332 | iter_no="", 333 | seed=1234, 334 | n_subjects="", 335 | atlas_name="aal", 336 | variable="connectivity", 337 | ): 338 | """ 339 | subject_list : list of subject IDs 340 | kind : the kind of connectivity to be used, e.g. 
lasso, partial correlation, correlation 341 | atlas_name : name of the parcellation atlas used 342 | variable : variable name in the .mat file that has been used to save the precomputed networks 343 | return: 344 | matrix : feature matrix of connectivity networks (num_subjects x network_size) 345 | """ 346 | 347 | all_networks = [] 348 | for subject in subject_list: 349 | if len(kind.split()) == 2: 350 | kind = "_".join(kind.split()) 351 | fl = os.path.join( 352 | data_folder, 353 | subject, 354 | subject + "_" + atlas_name + "_" + kind.replace(" ", "_") + ".mat", 355 | ) 356 | 357 | matrix = sio.loadmat(fl)[variable] 358 | all_networks.append(matrix) 359 | 360 | if kind in ["TE", "TPE"]: 361 | norm_networks = [mat for mat in all_networks] 362 | else: 363 | norm_networks = [np.arctanh(mat) for mat in all_networks] 364 | 365 | networks = np.stack(norm_networks) 366 | 367 | return networks 368 | 369 | 370 | if __name__ == "__main__": 371 | parser = argparse.ArgumentParser() 372 | parser.add_argument( 373 | "output_path", 374 | type=str, 375 | help="Path of the h5 file to save the preprocessed data", 376 | ) 377 | args = parser.parse_args() 378 | print("Downloading ABIDE raw data") 379 | download() 380 | print("Preprocessing ABIDE data") 381 | preprocess(args) 382 | -------------------------------------------------------------------------------- /scripts/subject_ID.txt: -------------------------------------------------------------------------------- 1 | 50128 2 | 51203 3 | 50325 4 | 50117 5 | 50573 6 | 50741 7 | 50779 8 | 51009 9 | 50746 10 | 50574 11 | 50110 12 | 50322 13 | 51036 14 | 51204 15 | 50119 16 | 50126 17 | 50314 18 | 51490 19 | 50784 20 | 51464 21 | 51000 22 | 51038 23 | 50748 24 | 51235 25 | 51007 26 | 51463 27 | 50783 28 | 50777 29 | 50313 30 | 50121 31 | 51053 32 | 51261 33 | 50723 34 | 50511 35 | 51295 36 | 50347 37 | 50982 38 | 50976 39 | 51098 40 | 51292 41 | 50340 42 | 50516 43 | 50724 44 | 51266 45 | 51054 46 | 50186 47 | 50529 48 | 50985 49 | 50520 50 | 50376 51 | 50978 52 | 50144 53 | 51096 54 | 50382 55 | 51250 56 | 51062 57 | 50349 58 | 51065 59 | 50385 60 | 51257 61 | 50143 62 | 51091 63 | 50371 64 | 50527 65 | 51268 66 | 50188 67 | 50518 68 | 50749 69 | 51039 70 | 50776 71 | 50120 72 | 50312 73 | 51006 74 | 51234 75 | 50782 76 | 51462 77 | 50118 78 | 51465 79 | 50785 80 | 51001 81 | 50315 82 | 50127 83 | 51491 84 | 51008 85 | 50778 86 | 51205 87 | 50575 88 | 50747 89 | 50111 90 | 50129 91 | 50116 92 | 50324 93 | 50740 94 | 50572 95 | 51030 96 | 51202 97 | 50370 98 | 50142 99 | 51090 100 | 50526 101 | 51256 102 | 51064 103 | 50519 104 | 50189 105 | 51269 106 | 51063 107 | 50383 108 | 51251 109 | 50521 110 | 50145 111 | 51097 112 | 50979 113 | 50377 114 | 50348 115 | 51055 116 | 50187 117 | 51267 118 | 51293 119 | 50341 120 | 50725 121 | 51258 122 | 50984 123 | 50528 124 | 50970 125 | 50510 126 | 50722 127 | 51294 128 | 50346 129 | 51260 130 | 51052 131 | 51099 132 | 50977 133 | 50379 134 | 50983 135 | 50039 136 | 50496 137 | 51312 138 | 50234 139 | 50006 140 | 50650 141 | 50802 142 | 50668 143 | 51118 144 | 50657 145 | 50233 146 | 51127 147 | 51315 148 | 50491 149 | 50008 150 | 50498 151 | 50037 152 | 50205 153 | 50661 154 | 51581 155 | 50453 156 | 50695 157 | 51575 158 | 51111 159 | 51323 160 | 51129 161 | 50659 162 | 51324 163 | 51116 164 | 51572 165 | 50692 166 | 50666 167 | 50202 168 | 50030 169 | 51142 170 | 51370 171 | 50269 172 | 51189 173 | 50251 174 | 50407 175 | 50438 176 | 51348 177 | 50603 178 | 50267 179 | 51187 180 | 50055 181 | 51341 182 | 50293 183 | 
51173 184 | 51174 185 | 51346 186 | 50294 187 | 51180 188 | 50052 189 | 50260 190 | 50604 191 | 50436 192 | 50658 193 | 51128 194 | 50667 195 | 50455 196 | 50031 197 | 50203 198 | 51117 199 | 51325 200 | 50693 201 | 51573 202 | 50499 203 | 50009 204 | 51574 205 | 50694 206 | 51322 207 | 51110 208 | 50204 209 | 50036 210 | 51580 211 | 50660 212 | 50803 213 | 50669 214 | 51314 215 | 51126 216 | 50490 217 | 50656 218 | 50232 219 | 50038 220 | 50804 221 | 50007 222 | 50235 223 | 50651 224 | 50463 225 | 50497 226 | 51121 227 | 51313 228 | 50261 229 | 51181 230 | 50053 231 | 50437 232 | 50605 233 | 51347 234 | 50295 235 | 51175 236 | 50408 237 | 51172 238 | 51340 239 | 50292 240 | 50602 241 | 51186 242 | 50054 243 | 50266 244 | 50259 245 | 50250 246 | 50406 247 | 51349 248 | 50439 249 | 50257 250 | 51188 251 | 50268 252 | 51195 253 | 50047 254 | 50275 255 | 50611 256 | 51161 257 | 51353 258 | 50281 259 | 51159 260 | 51354 261 | 50286 262 | 51166 263 | 50424 264 | 50616 265 | 50272 266 | 51192 267 | 50040 268 | 50049 269 | 51362 270 | 51150 271 | 50412 272 | 50620 273 | 50618 274 | 50288 275 | 51168 276 | 50627 277 | 50415 278 | 50243 279 | 51365 280 | 50441 281 | 50217 282 | 50025 283 | 50819 284 | 51331 285 | 51103 286 | 51567 287 | 50687 288 | 50826 289 | 51558 290 | 51560 291 | 51104 292 | 51336 293 | 50022 294 | 50210 295 | 50446 296 | 51309 297 | 50821 298 | 51132 299 | 51300 300 | 51556 301 | 50642 302 | 50470 303 | 50014 304 | 51569 305 | 50689 306 | 50817 307 | 50013 308 | 50477 309 | 50645 310 | 50483 311 | 51307 312 | 51135 313 | 50448 314 | 51338 315 | 51169 316 | 50289 317 | 50619 318 | 51364 319 | 51156 320 | 50414 321 | 50626 322 | 50242 323 | 50048 324 | 50245 325 | 50621 326 | 50413 327 | 51151 328 | 51363 329 | 50628 330 | 50617 331 | 50425 332 | 51193 333 | 50041 334 | 50273 335 | 51167 336 | 51355 337 | 50287 338 | 51352 339 | 50280 340 | 51160 341 | 50274 342 | 51194 343 | 50046 344 | 50422 345 | 50610 346 | 50482 347 | 51134 348 | 51306 349 | 50012 350 | 50644 351 | 51339 352 | 50449 353 | 50643 354 | 50015 355 | 51301 356 | 51133 357 | 50485 358 | 51557 359 | 50816 360 | 50688 361 | 51568 362 | 50211 363 | 50023 364 | 50447 365 | 51561 366 | 51105 367 | 50820 368 | 51308 369 | 51102 370 | 51330 371 | 50686 372 | 51566 373 | 50440 374 | 50818 375 | 50024 376 | 50216 377 | 51559 378 | 50169 379 | 50955 380 | 50156 381 | 51084 382 | 50364 383 | 50700 384 | 50532 385 | 51070 386 | 50390 387 | 51048 388 | 50952 389 | 50738 390 | 50397 391 | 51077 392 | 50999 393 | 50707 394 | 50363 395 | 51083 396 | 50990 397 | 50158 398 | 50964 399 | 51273 400 | 51041 401 | 50193 402 | 50355 403 | 50167 404 | 50503 405 | 50731 406 | 50709 407 | 50399 408 | 51079 409 | 50997 410 | 50736 411 | 50504 412 | 50160 413 | 51280 414 | 50352 415 | 51046 416 | 50194 417 | 51274 418 | 51482 419 | 50306 420 | 50134 421 | 51220 422 | 51012 423 | 51476 424 | 50796 425 | 50339 426 | 50791 427 | 51471 428 | 51015 429 | 51227 430 | 50133 431 | 50301 432 | 50557 433 | 51485 434 | 51218 435 | 50568 436 | 51023 437 | 51211 438 | 50753 439 | 50561 440 | 50105 441 | 50337 442 | 51478 443 | 50798 444 | 50308 445 | 50330 446 | 50102 447 | 50566 448 | 50754 449 | 51216 450 | 51024 451 | 50559 452 | 51229 453 | 50996 454 | 51078 455 | 50962 456 | 50708 457 | 51275 458 | 51047 459 | 50195 460 | 50505 461 | 50737 462 | 51281 463 | 50353 464 | 50161 465 | 50965 466 | 50159 467 | 50991 468 | 50166 469 | 50354 470 | 50730 471 | 50502 472 | 51040 473 | 50192 474 | 51272 475 | 50739 476 | 51049 477 | 50706 478 | 50150 479 | 
51082 480 | 50362 481 | 50998 482 | 51076 483 | 50954 484 | 50168 485 | 50391 486 | 51071 487 | 50365 488 | 50157 489 | 51085 490 | 50701 491 | 51025 492 | 51217 493 | 50103 494 | 50331 495 | 50755 496 | 50567 497 | 51228 498 | 50558 499 | 50560 500 | 50752 501 | 50336 502 | 50104 503 | 51210 504 | 50799 505 | 51479 506 | 50300 507 | 50132 508 | 50556 509 | 51484 510 | 51470 511 | 50790 512 | 51226 513 | 51014 514 | 50569 515 | 51219 516 | 51013 517 | 51221 518 | 50797 519 | 51477 520 | 50551 521 | 51483 522 | 50135 523 | 50307 524 | 50338 525 | 50171 526 | 50343 527 | 51291 528 | 50727 529 | 50515 530 | 50185 531 | 51057 532 | 51265 533 | 50972 534 | 50388 535 | 50986 536 | 51068 537 | 51262 538 | 50182 539 | 51050 540 | 51606 541 | 50344 542 | 51296 543 | 50981 544 | 50149 545 | 51254 546 | 50386 547 | 50988 548 | 51066 549 | 50372 550 | 50524 551 | 51059 552 | 50711 553 | 50523 554 | 51095 555 | 50147 556 | 50375 557 | 51061 558 | 51253 559 | 50381 560 | 51298 561 | 51238 562 | 50577 563 | 50745 564 | 50321 565 | 50113 566 | 51207 567 | 51035 568 | 51469 569 | 50789 570 | 50319 571 | 51456 572 | 51032 573 | 50114 574 | 50326 575 | 50742 576 | 50570 577 | 51209 578 | 51236 579 | 50780 580 | 51460 581 | 50774 582 | 50122 583 | 50310 584 | 51458 585 | 50317 586 | 50125 587 | 51493 588 | 50773 589 | 51467 590 | 50787 591 | 51231 592 | 51003 593 | 51252 594 | 50380 595 | 51060 596 | 50710 597 | 50374 598 | 51094 599 | 50146 600 | 51299 601 | 51093 602 | 50373 603 | 50525 604 | 51067 605 | 50989 606 | 51255 607 | 50387 608 | 50728 609 | 51058 610 | 50345 611 | 51297 612 | 50183 613 | 51051 614 | 51263 615 | 51607 616 | 50148 617 | 50974 618 | 51264 619 | 50184 620 | 51056 621 | 50342 622 | 50170 623 | 50514 624 | 50726 625 | 51069 626 | 50987 627 | 50973 628 | 51459 629 | 50329 630 | 50786 631 | 51466 632 | 51002 633 | 51230 634 | 50124 635 | 50316 636 | 50772 637 | 51492 638 | 50578 639 | 51208 640 | 50775 641 | 50311 642 | 50123 643 | 51237 644 | 51461 645 | 50781 646 | 50318 647 | 50788 648 | 51468 649 | 50327 650 | 50115 651 | 50571 652 | 50743 653 | 51457 654 | 51201 655 | 51033 656 | 51239 657 | 51034 658 | 51206 659 | 50744 660 | 50576 661 | 50112 662 | 50320 663 | 50060 664 | 50252 665 | 50404 666 | 51146 667 | 50609 668 | 50299 669 | 51179 670 | 51373 671 | 51141 672 | 50403 673 | 50255 674 | 50058 675 | 50297 676 | 51345 677 | 51177 678 | 50263 679 | 50051 680 | 51183 681 | 50435 682 | 50607 683 | 51148 684 | 50056 685 | 51184 686 | 50264 687 | 51170 688 | 50290 689 | 51342 690 | 50801 691 | 51329 692 | 50466 693 | 50654 694 | 51316 695 | 51124 696 | 50492 697 | 51578 698 | 50698 699 | 50208 700 | 51123 701 | 51311 702 | 50005 703 | 50237 704 | 50653 705 | 51318 706 | 50468 707 | 51327 708 | 50691 709 | 51571 710 | 50665 711 | 51585 712 | 50033 713 | 50201 714 | 50239 715 | 50206 716 | 50034 717 | 51582 718 | 51576 719 | 50696 720 | 51320 721 | 51112 722 | 50291 723 | 51343 724 | 51171 725 | 50433 726 | 50601 727 | 50265 728 | 50057 729 | 51185 730 | 50050 731 | 51182 732 | 50262 733 | 50606 734 | 50434 735 | 50296 736 | 51344 737 | 51149 738 | 50402 739 | 50254 740 | 51140 741 | 50059 742 | 51147 743 | 50253 744 | 50405 745 | 51178 746 | 50298 747 | 50608 748 | 50697 749 | 51577 750 | 51113 751 | 51321 752 | 50035 753 | 50207 754 | 50663 755 | 51583 756 | 50469 757 | 51319 758 | 51584 759 | 50664 760 | 50200 761 | 50032 762 | 51326 763 | 51114 764 | 51570 765 | 50690 766 | 50807 767 | 50209 768 | 50699 769 | 51579 770 | 50236 771 | 50004 772 | 50652 773 | 50494 774 | 51122 775 | 
51328 776 | 50800 777 | 51317 778 | 50493 779 | 50655 780 | 50467 781 | 50003 782 | 51563 783 | 50683 784 | 51335 785 | 51107 786 | 50213 787 | 50445 788 | 51138 789 | 50648 790 | 50822 791 | 50442 792 | 50026 793 | 50214 794 | 51100 795 | 51332 796 | 51564 797 | 50019 798 | 50825 799 | 50489 800 | 50010 801 | 50646 802 | 50480 803 | 51136 804 | 51304 805 | 51109 806 | 51303 807 | 51131 808 | 50487 809 | 50017 810 | 50028 811 | 50814 812 | 50418 813 | 51165 814 | 50285 815 | 51357 816 | 50615 817 | 50427 818 | 50043 819 | 51191 820 | 50271 821 | 50249 822 | 50276 823 | 50044 824 | 51196 825 | 50612 826 | 50282 827 | 51350 828 | 51162 829 | 51359 830 | 50416 831 | 50624 832 | 50240 833 | 51154 834 | 50278 835 | 51198 836 | 51153 837 | 51361 838 | 50247 839 | 50623 840 | 50411 841 | 50016 842 | 51130 843 | 51302 844 | 50486 845 | 50815 846 | 50029 847 | 50481 848 | 51305 849 | 51137 850 | 50011 851 | 50647 852 | 50812 853 | 51333 854 | 51101 855 | 51565 856 | 50685 857 | 50443 858 | 50215 859 | 50027 860 | 50488 861 | 50824 862 | 50020 863 | 50212 864 | 50444 865 | 50682 866 | 51562 867 | 51106 868 | 51334 869 | 50649 870 | 50823 871 | 51139 872 | 51199 873 | 50279 874 | 50246 875 | 50410 876 | 50622 877 | 51360 878 | 51152 879 | 51358 880 | 50428 881 | 51155 882 | 50625 883 | 50417 884 | 50241 885 | 50248 886 | 51163 887 | 50283 888 | 51351 889 | 50045 890 | 51197 891 | 50277 892 | 50613 893 | 50421 894 | 50419 895 | 51369 896 | 50426 897 | 50614 898 | 50270 899 | 50042 900 | 51190 901 | 50284 902 | 51356 903 | 51164 904 | 51472 905 | 50792 906 | 51224 907 | 51016 908 | 50302 909 | 50130 910 | 51486 911 | 50554 912 | 51029 913 | 51481 914 | 50553 915 | 50305 916 | 51011 917 | 51223 918 | 50795 919 | 50333 920 | 50757 921 | 50565 922 | 51027 923 | 51215 924 | 51488 925 | 51018 926 | 51212 927 | 51020 928 | 50562 929 | 50750 930 | 50334 931 | 50106 932 | 51279 933 | 50199 934 | 50509 935 | 51074 936 | 50704 937 | 51080 938 | 50152 939 | 50360 940 | 50956 941 | 50358 942 | 50367 943 | 51087 944 | 50969 945 | 50531 946 | 50703 947 | 51241 948 | 51073 949 | 50960 950 | 50994 951 | 51248 952 | 50507 953 | 50735 954 | 50351 955 | 50163 956 | 51277 957 | 50197 958 | 51045 959 | 50993 960 | 50369 961 | 51089 962 | 50967 963 | 50190 964 | 51042 965 | 50164 966 | 50958 967 | 50356 968 | 50732 969 | 50500 970 | 50751 971 | 50563 972 | 50107 973 | 50335 974 | 51021 975 | 51213 976 | 51214 977 | 51026 978 | 50332 979 | 50564 980 | 50756 981 | 51019 982 | 51489 983 | 51222 984 | 51010 985 | 51474 986 | 50794 987 | 51480 988 | 50552 989 | 50304 990 | 50136 991 | 50109 992 | 50131 993 | 50303 994 | 51487 995 | 50555 996 | 50793 997 | 51473 998 | 51017 999 | 51225 1000 | 51028 1001 | 50966 1002 | 51088 1003 | 50368 1004 | 50992 1005 | 50357 1006 | 50959 1007 | 50501 1008 | 50733 1009 | 51271 1010 | 50191 1011 | 51249 1012 | 50995 1013 | 50961 1014 | 50196 1015 | 51044 1016 | 51276 1017 | 50162 1018 | 50350 1019 | 51282 1020 | 50359 1021 | 50957 1022 | 51072 1023 | 51240 1024 | 50968 1025 | 51086 1026 | 50366 1027 | 50702 1028 | 50530 1029 | 50198 1030 | 51278 1031 | 50705 1032 | 50361 1033 | 51081 1034 | 50153 1035 | 51075 1036 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/HyBRiD/973d937a38449409264218e380ec6593ee30568c/src/__init__.py 
-------------------------------------------------------------------------------- /src/cpm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code modified from the original MATLAB implementation of CPM. 3 | https://github.com/YaleMRRC/CPM 4 | """ 5 | 6 | import numpy as np 7 | import scipy.io 8 | from scipy.stats import pearsonr 9 | 10 | 11 | def corr(X, y): 12 | """ 13 | Inputs: 14 | X - [feature_dim x n_subjects] 15 | y - [n_subjects x 1] 16 | Outputs: 17 | r_mat - [feature_dim] 18 | p_mat - [feature_dim] 19 | """ 20 | r_mat = np.zeros(X.shape[0]) 21 | p_mat = np.zeros(X.shape[0]) 22 | for i, x in enumerate(X): 23 | r_val, p_val = scipy.stats.pearsonr(x, y) 24 | r_mat[i] = r_val 25 | p_mat[i] = p_val 26 | return r_mat, p_mat 27 | 28 | 29 | def cpm(edge_weights, labels, traintestid): 30 | """ 31 | This function evaluates the edge weights and returns (and also prints) the R value computed by CPM. 32 | """ 33 | train_id = traintestid["train_id"] 34 | test_id = traintestid["test_id"] 35 | all_mats = edge_weights 36 | all_behav = labels 37 | 38 | final_test_mats = all_mats[:, test_id] 39 | final_test_behave = all_behav[test_id] 40 | 41 | all_mats = all_mats[:, train_id] 42 | all_behav = all_behav[train_id] 43 | 44 | # threshold for feature selection 45 | thresh = 0.05 46 | 47 | # correlate all edges with behavior 48 | r_mat, p_mat = corr(all_mats, all_behav) 49 | 50 | final_mask = np.zeros(all_mats.shape[0]) 51 | final_edges = np.where((r_mat > 0) & (p_mat < thresh))[0] 52 | final_mask[final_edges] = 1 53 | 54 | # first get the train linear model 55 | train_sum = np.sum(all_mats * final_mask[:, np.newaxis], axis=0) 56 | 57 | fit = np.polyfit(train_sum, all_behav, 1) 58 | 59 | # evaluate on test 60 | test_sum = np.sum(final_test_mats * final_mask[:, np.newaxis], axis=0) 61 | test_behav = fit[0] * test_sum + fit[1] 62 | 63 | test_R, test_P = pearsonr(test_behav, final_test_behave) 64 | 65 | print("R: ", test_R) 66 | return test_R, test_P 67 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from functools import cached_property 4 | 5 | import h5py 6 | import numpy as np 7 | import torch 8 | from lightning import LightningDataModule 9 | from scipy.stats import entropy as stats_entropy 10 | from sklearn.model_selection import train_test_split 11 | from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset 12 | from tqdm import tqdm 13 | 14 | 15 | class MultiSourceDataset(Dataset): 16 | def __init__(self, datasets: dict[str, Dataset]): 17 | self.datasets: dict = datasets 18 | 19 | def __len__(self): 20 | return sum([len(d) for d in self.datasets.values()]) 21 | 22 | def __getitem__(self, idx: tuple[str, int]): 23 | task, idx = idx 24 | dataset = self.datasets[task] 25 | return dataset[idx] 26 | 27 | 28 | class MultiSourceBatchSampler(torch.utils.data.Sampler): 29 | def __init__( 30 | self, 31 | data_source: Dataset, 32 | batch_size: int, 33 | shuffle: bool = False, 34 | drop_last: bool = False, 35 | ): 36 | self.data_source = data_source 37 | self.batch_size = batch_size 38 | self.shuffle = shuffle 39 | self.drop_last = drop_last 40 | 41 | def __iter__(self): 42 | def chunks(lst, n): 43 | """Yield successive n-sized chunks from lst.""" 44 | for i in range(0, len(lst), n): 45 | yield lst[i : i + n] 46 | 47 | all_indices: list[list[tuple[str, int]]] = [] 48 | for task,
dataset in self.data_source.datasets.items(): 49 | indices: list[int] = list(range(len(dataset))) 50 | if self.shuffle: 51 | np.random.shuffle(indices) 52 | for chunk in chunks(indices, self.batch_size): 53 | all_indices.append([(task, c) for c in chunk]) 54 | if self.shuffle: 55 | np.random.shuffle(all_indices) 56 | 57 | yield from all_indices 58 | 59 | def __len__(self): 60 | if self.drop_last: 61 | return len(self.data_source) // self.batch_size # type: ignore[arg-type] 62 | else: 63 | return (len(self.data_source) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type] 64 | 65 | 66 | class BrainDataset(Dataset): 67 | def __init__( 68 | self, 69 | file: h5py.File, 70 | dataset_name: str, 71 | task_name: str, 72 | x_key: str = "x", 73 | y_key: str = "y", 74 | ): 75 | self.file = file 76 | self.dataset_name = dataset_name 77 | self.task_name = task_name 78 | self.in_memory = True 79 | 80 | self.x_key = x_key 81 | self.y_key = y_key 82 | 83 | assert hasattr(self, "dataset_name") and hasattr(self, "task_name") 84 | 85 | self.dataset = self._read_data() 86 | 87 | def __len__(self): 88 | return len(self.dataset) 89 | 90 | def process(self, data): 91 | x = data[self.x_key][()] 92 | x = torch.from_numpy(x).float() 93 | y = torch.tensor(data[self.y_key][()]).float() 94 | return x, (y - 60) / (121 - 60) 95 | 96 | def __getitem__(self, idx): 97 | if self.in_memory: 98 | x, y = self.dataset[idx] 99 | else: 100 | idx = self.idx_to_filename[idx] 101 | x, y = self.process(self.dataset[idx]) 102 | 103 | x = torch.corrcoef(x) 104 | x = torch.nan_to_num(x, 0) 105 | 106 | meta = { 107 | "subject": idx, 108 | "dataset_name": self.dataset_name, 109 | "task_name": self.task_name, 110 | } 111 | return_data = {"x": x, "y": y, "meta": meta} 112 | return return_data 113 | 114 | def _read_data(self): 115 | data = self.file 116 | 117 | file_names = data.keys() 118 | self.idx_to_filename = file_names 119 | if self.in_memory: 120 | data = [self.process(self.file[fn]) for fn in tqdm(file_names)] 121 | 122 | print(f"# instances of {self.dataset_name} {self.task_name}: ", len(file_names)) 123 | 124 | return data 125 | 126 | 127 | class BrainDataModule(LightningDataModule): 128 | def __init__( 129 | self, dataset_keys: list[str], y_key: str, batch_size: int, num_workers: int 130 | ) -> None: 131 | super().__init__() 132 | self.y_key = y_key 133 | 134 | if os.environ.get("WANDB_MODE", "") == "disabled": 135 | num_workers = 0 136 | 137 | self.batch_size = batch_size 138 | self.num_workers = num_workers 139 | 140 | datasets = [] 141 | for dataset_key in dataset_keys: 142 | dataset_name, task_name = dataset_key.split("-") 143 | file_path = os.path.join("data", dataset_name, task_name, "data.h5") 144 | file = h5py.File(file_path, "r") 145 | datasets.append( 146 | BrainDataset( 147 | file=file, 148 | y_key=self.y_key, 149 | dataset_name=dataset_name, 150 | task_name=task_name, 151 | ) 152 | ) 153 | 154 | self.datasets = datasets 155 | 156 | def setup(self, stage: str): 157 | if stage != "fit": 158 | return 159 | 160 | datasets = {"train": defaultdict(list), "val": defaultdict(list), "test": defaultdict(list)} 161 | for dataset in self.datasets: 162 | dataset_name, task_name = dataset.dataset_name, dataset.task_name 163 | train_ids, test_ids = self.get_train_test_split(dataset_name, task_name) 164 | 165 | values = [] 166 | for idx in train_ids: 167 | value = dataset[idx]["y"] 168 | values.append(value) 169 | values = np.array(values) 170 | max_ = values.max() 171 | min_ = values.min() 172 | values = np.digitize( 173 
| values, bins=np.linspace(min_, max_, 6)[1:], right=True 174 | ) 175 | if False: 176 | """ 177 | NOTE: Performance is highly unstable across random seeds because the dataset is small. We therefore select the optimal hyperparameters on the validation set and re-train the model on the combined training and validation sets. 178 | """ 179 | train_ids, val_ids, _, _ = train_test_split( 180 | train_ids, values, test_size=1 / 8, stratify=values 181 | ) 182 | datasets["train"][task_name].append(Subset(dataset, train_ids)) 183 | 184 | datasets = dict(datasets) 185 | for split in datasets.keys(): 186 | for task, datasets_of_task in datasets[split].items(): 187 | datasets[split][task] = ConcatDataset(datasets_of_task) 188 | self.train = MultiSourceDataset(datasets["train"]) 189 | 190 | def train_dataloader(self): 191 | batch_sampler = MultiSourceBatchSampler( 192 | self.train, batch_size=self.batch_size, shuffle=True 193 | ) 194 | return DataLoader( 195 | self.train, batch_sampler=batch_sampler, num_workers=self.num_workers 196 | ) 197 | 198 | def test_dataloader(self): 199 | return [ 200 | DataLoader( 201 | dataset, 202 | batch_size=self.batch_size, 203 | shuffle=False, 204 | num_workers=self.num_workers, 205 | ) 206 | for dataset in self.datasets 207 | ] 208 | 209 | def get_train_test_split(self, dataset_name, task_name): 210 | base_path = os.path.join("data", dataset_name, task_name) 211 | subjects = list(h5py.File(os.path.join(base_path, "data.h5")).keys()) 212 | 213 | def get_indices_of_subjects(path): 214 | subject_ids = ( 215 | open(os.path.join(base_path, path), "r").read().rstrip("\n").split("\n") 216 | ) 217 | indices = [subjects.index(sid) for sid in subject_ids] 218 | return indices 219 | 220 | train_ids, test_ids = map( 221 | get_indices_of_subjects, ("train.split", "test.split") 222 | ) 223 | return train_ids, test_ids 224 | 225 | @cached_property 226 | def entropy(self) -> dict[str, np.ndarray]: 227 | # collect all x in datasets 228 | all_data = defaultdict(list) 229 | for dataset in self.datasets: 230 | for sample in dataset: 231 | task = sample["meta"]["task_name"] 232 | x = sample["x"] 233 | all_data[task].extend(list(x)) 234 | 235 | # compute entropy for each fMRI task 236 | entropy_all_task = dict() 237 | for task, x in all_data.items(): 238 | x = torch.stack(x, dim=0) 239 | x = x.transpose(0, 1) 240 | x = list(x.numpy()) 241 | 242 | bins = np.linspace(-1.0, 1.0, 100) 243 | entropies = [] 244 | # compute entropy for each node 245 | for x_node in x: 246 | hist = np.histogram(x_node, bins=bins, density=True)[0] 247 | entropy = stats_entropy(hist) 248 | entropies.append(entropy) 249 | 250 | entropy_all_task[task] = np.asarray(entropies) 251 | 252 | return entropy_all_task 253 | -------------------------------------------------------------------------------- /src/hybrid/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import HyBRiD 2 | -------------------------------------------------------------------------------- /src/hybrid/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import Tensor, nn 4 | 5 | 6 | class Masker(nn.Module): 7 | def __init__( 8 | self, 9 | n_heads: int, 10 | n_nodes: int, 11 | ) -> None: 12 | super().__init__() 13 | 14 | self.mask = nn.Parameter(torch.Tensor(n_heads, n_nodes, 2)) 15 | nn.init.xavier_normal_(self.mask) 16 | 17 | def forward(self) ->
tuple[Tensor, Tensor]: 18 | """ 19 | Outputs: 20 | mask - [n_heads (n_hypers), n_nodes], node selection 21 | mask_logits - [n_heads (n_hypers), n_nodes], logits of node selection probability 22 | """ 23 | mask_logits = self.mask 24 | mask_logits = torch.log(mask_logits.softmax(-1)) 25 | 26 | mask = F.gumbel_softmax(mask_logits, tau=1, hard=True)[..., 1] 27 | 28 | return mask, mask_logits[..., 1] 29 | 30 | 31 | class HyBRiDConstructor(nn.Module): 32 | def __init__(self, n_hypers: int, n_nodes: int, dropout: float = 0.1) -> None: 33 | super().__init__() 34 | 35 | self.dropout = nn.Dropout(dropout) 36 | self.mask = Masker(n_hypers, n_nodes) 37 | 38 | def forward(self, x: Tensor) -> tuple[Tensor, tuple[Tensor, Tensor]]: 39 | """ 40 | Inputs: 41 | x - [batch_size, n_nodes, feature_dim], note that feature_dim = n_nodes 42 | Outputs: 43 | h - [batch_size, n_hypers, feature_dim] 44 | mask - [n_hypers, n_nodes], node selection 45 | mask_logits - [n_hypers, n_nodes], logits of node selection probability 46 | """ 47 | bs, n_nodes, dim = x.size() 48 | mask, mask_logits = self.mask() 49 | 50 | x = x[:, None, :, :] * mask[None, :, :, None] 51 | h = x.sum(-2) / (1e-7 + mask.sum(-1)[None, :, None]) 52 | 53 | h = self.dropout(h) 54 | 55 | return h, (mask, mask_logits) 56 | 57 | 58 | class HyBRiDWeighter(nn.Module): 59 | def __init__( 60 | self, d_model: int, hidden_size: int, n_hypers: int, layer_norm_eps=1e-5 61 | ) -> None: 62 | super().__init__() 63 | self.dim_reduction = nn.Sequential( 64 | nn.LayerNorm(d_model, layer_norm_eps), 65 | nn.Linear(d_model, hidden_size), 66 | nn.GELU(), 67 | nn.Linear(hidden_size, d_model), 68 | nn.LayerNorm(d_model, layer_norm_eps), 69 | nn.Linear(d_model, 1), 70 | ) 71 | self.last = nn.Linear(n_hypers, 1) 72 | 73 | def forward(self, h: Tensor) -> tuple[Tensor, Tensor]: 74 | """ 75 | Inputs: 76 | h - [batch_size, n_hypers, feature_dim] 77 | Outputs: 78 | preds - [batch_size], predicted targets (e.g. IQ) 79 | last - [batch_size, n_hypers], used as weights of hyperedges 80 | """ 81 | bs = h.size(0) 82 | h = self.dim_reduction(h) 83 | last = h.reshape((bs, -1)) 84 | preds = self.last(last).squeeze() 85 | 86 | return preds, last 87 | 88 | 89 | class HyBRiD(nn.Module): 90 | def __init__(self, n_hypers: int, hidden_size: int, n_nodes: int, dropout: float) -> None: 91 | super().__init__() 92 | self.constructor = HyBRiDConstructor(n_hypers=n_hypers, n_nodes=n_nodes, dropout=dropout) 93 | self.weighter = HyBRiDWeighter( 94 | d_model=n_nodes, hidden_size=hidden_size, n_hypers=n_hypers 95 | ) 96 | 97 | def forward(self, x: Tensor) -> dict[str, Tensor]: 98 | """ 99 | Inputs: 100 | x - [batch_size, n_nodes, feature_dim], note that feature_dim = n_nodes 101 | Outputs: 102 | preds - [batch_size], predicted targets (e.g. 
IQ) 103 | last - [batch_size, n_hypers], used as weights of hyperedges 104 | mask - [n_hypers, n_nodes], node selection 105 | mask_logits - [n_hypers, n_nodes], logits of node selection probability 106 | """ 107 | 108 | h, (mask, mask_logits) = self.constructor(x) 109 | preds, last = self.weighter(h) 110 | 111 | return { 112 | "preds": preds, 113 | "last": last.detach(), 114 | "mask": mask, 115 | "mask_logits": mask_logits, 116 | } 117 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | import torch 5 | from lightning import LightningModule 6 | from torch import nn 7 | from torchmetrics import MeanSquaredError, Metric 8 | 9 | from .cpm import cpm 10 | 11 | 12 | class CorrMetric(Metric): 13 | def __init__(self): 14 | super().__init__() 15 | self.add_state("preds", default=[]) 16 | self.add_state("target", default=[]) 17 | 18 | def update(self, preds: torch.Tensor, target: torch.Tensor): 19 | self.preds.append(preds.detach()) 20 | self.target.append(target.detach()) 21 | 22 | def compute(self): 23 | if len(self.preds) == 0: 24 | return torch.tensor(0) 25 | preds = torch.cat(self.preds) 26 | target = torch.cat(self.target) 27 | assert preds.shape == target.shape 28 | 29 | vx = preds - torch.mean(preds) 30 | vy = target - torch.mean(target) 31 | corr = torch.sum(vx * vy) / (1e-7 + torch.norm(vx) * torch.norm(vy)) 32 | 33 | return corr 34 | 35 | 36 | class RegressionModule(LightningModule): 37 | def __init__( 38 | self, 39 | model: nn.Module, 40 | learning_rate: float, 41 | weight_decay: float = 0.01, 42 | beta: float = 0.0, 43 | node_entropy: dict[str, np.ndarray] = None, 44 | ) -> None: 45 | super().__init__() 46 | self.save_hyperparameters() 47 | 48 | self.model = model 49 | self.learning_rate = learning_rate 50 | self.weight_decay = weight_decay 51 | self.beta = beta 52 | 53 | for key, value in node_entropy.items(): 54 | self.register_buffer("entropy_" + key, torch.from_numpy(value)) 55 | 56 | self.criterion = nn.MSELoss() 57 | self.train_metric = CorrMetric() 58 | self.val_metric = CorrMetric() 59 | self.mse = MeanSquaredError() 60 | 61 | def configure_optimizers(self): 62 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 63 | 64 | def lr_scaler(epoch): 65 | warmup_epoch = 100 66 | if epoch < warmup_epoch: 67 | # warm up lr 68 | lr_scale = epoch / warmup_epoch 69 | else: 70 | lr_scale = 1.0 71 | 72 | return lr_scale 73 | 74 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_scaler) 75 | 76 | return [optimizer], [scheduler] 77 | 78 | def forward(self, batch): 79 | x = batch["x"] 80 | meta = self.get_meta_from_batch(batch) 81 | task_key = meta["task_name"] 82 | 83 | outputs = self.model(x) 84 | 85 | y_hat = outputs["preds"] 86 | last = outputs["last"] 87 | mask = outputs.get("mask", None) 88 | mask_logits = outputs.get("mask_logits", None) 89 | assert y_hat.dim() == 1 90 | 91 | y = batch["y"] 92 | assert y.dim() == 1 93 | 94 | max_term = -self.criterion(y_hat, y) 95 | 96 | density = mask.mean() 97 | node_entropy = getattr(self, "entropy_" + task_key) 98 | min_term = (torch.sigmoid(mask_logits.squeeze()) * node_entropy).mean() 99 | 100 | loss = -max_term + self.beta * min_term 101 | 102 | return (loss, max_term, min_term), (y_hat, y, last), density 103 | 104 | def training_step(self, batch, batch_idx): 105 | (loss, 
max_term, min_term), (y_hat, y, _), density = self.forward(batch) 106 | self.train_metric(y_hat, y) 107 | self.log("train/loss", loss) 108 | self.log("train/density", density) 109 | self.log("train/max_term", max_term) 110 | self.log("train/min_term", min_term) 111 | 112 | return loss 113 | 114 | def on_train_epoch_end(self): 115 | self.log("train/corr", self.train_metric) 116 | 117 | def validation_step(self, batch, batch_idx): 118 | (loss, max_term, min_term), (y_hat, y, _), density = self.forward(batch) 119 | 120 | meta = self.get_meta_from_batch(batch) 121 | self.val_metric(y_hat, y) 122 | self.log("val/loss", loss) 123 | self.log("val/density", density) 124 | self.log("val/max_term", max_term) 125 | self.log("val/min_term", min_term) 126 | 127 | self.mse(y_hat, y) 128 | self.log("val/mse", self.mse) 129 | 130 | def on_validation_epoch_end(self): 131 | self.log("val/corr", self.val_metric) 132 | 133 | def on_test_start(self): 134 | self.weights = defaultdict(list) 135 | self.labels = defaultdict(list) 136 | 137 | def get_meta_from_batch(self, batch): 138 | meta = batch["meta"] 139 | if isinstance(meta, list): 140 | meta = meta[0] 141 | elif isinstance(meta, dict): 142 | meta = {k: v[0] for k, v in meta.items()} 143 | return meta 144 | 145 | def test_step(self, batch, batch_idx): 146 | x = batch["x"] 147 | bs = x.size(0) 148 | seq_len = x.size(-1) 149 | 150 | _, (y_hat, y, weights), _ = self(batch) 151 | 152 | meta = self.get_meta_from_batch(batch) 153 | task: str = meta["task_name"] 154 | dataset_name: str = meta["dataset_name"] 155 | self.weights[dataset_name + "-" + task].append(weights.cpu().numpy()) 156 | self.labels[dataset_name + "-" + task].append(batch["y"].cpu().numpy()) 157 | 158 | def on_test_end(self) -> None: 159 | for key in self.weights.keys(): 160 | weights = self.weights[key] 161 | labels = self.labels[key] 162 | weights = np.concatenate(weights) 163 | weights = (weights - weights.mean()) / (1e-7 + weights.std()) 164 | labels = np.concatenate(labels) 165 | weights = np.transpose(weights) 166 | X = np.concatenate([weights, -weights], axis=0) 167 | Y = labels[..., None] 168 | dataset, task = key.split("-") 169 | 170 | train_ids, test_ids = self.trainer.datamodule.get_train_test_split( 171 | dataset, task 172 | ) 173 | train_test_split = { 174 | "train_id": np.asarray(train_ids), 175 | "test_id": np.asarray(test_ids), 176 | } 177 | R, P = cpm(X, Y, train_test_split) 178 | print("CPM_R:", R.item()) 179 | --------------------------------------------------------------------------------
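To tie the listings above together, the sketch below instantiates the `HyBRiD` module from `src/hybrid` and runs a forward pass on a dummy batch shaped like the correlation matrices produced by `BrainDataset`. The hyperparameter values, batch size, and node count here are illustrative assumptions, not the configurations in `config/*.yaml`, and the snippet assumes it is run from the repository root.
```python
# Hypothetical end-to-end sketch (not part of the repo). Hyperparameter values
# are placeholders, not the settings shipped in config/*.yaml.
import torch

from src.hybrid import HyBRiD

n_nodes = 200  # number of brain regions; depends on the atlas used in preprocessing
model = HyBRiD(n_hypers=32, hidden_size=64, n_nodes=n_nodes, dropout=0.1)

# BrainDataset yields torch.corrcoef(x) per subject, i.e. an n_nodes x n_nodes
# correlation matrix, so feature_dim == n_nodes.
x = torch.randn(8, n_nodes, n_nodes)

out = model(x)
print(out["preds"].shape)        # torch.Size([8])        predicted target (e.g. IQ)
print(out["last"].shape)         # torch.Size([8, 32])    per-subject hyperedge weights
print(out["mask"].shape)         # torch.Size([32, 200])  node selection per hyperedge
print(out["mask_logits"].shape)  # torch.Size([32, 200])  node selection logits

# At test time, RegressionModule collects out["last"] across subjects,
# standardizes it, and passes it to cpm() to report the final CPM R value.
```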