├── .readthedocs.yaml ├── LICENSE ├── README.md ├── data ├── butterfly │ ├── w1.csv │ └── w2.csv └── omics │ ├── clusters.txt │ ├── omics1.txt │ ├── omics2.txt │ └── omics3.txt ├── docs ├── Makefile ├── make.bat ├── requirements-docs.txt └── source │ ├── conf.py │ ├── faq.rst │ ├── index.rst │ ├── installation.rst │ ├── integrao.rst │ ├── references.rst │ ├── tutorial_butterfly.nblink │ ├── tutorial_cancer.nblink │ └── tutorial_classify.nblink ├── figures └── integrAO_overview.png ├── integrao ├── IntegrAO_supervised.py ├── IntegrAO_unsupervised.py ├── __init__.py ├── dataset.py ├── integrater.py ├── main.py ├── supervised_train.py ├── unsupervised_train.py └── util.py ├── pyproject.toml ├── requirement.txt ├── tests └── __init__.py └── tutorials ├── cancer_omics_classification.ipynb ├── simulated_butterfly.ipynb ├── simulated_cancer_omics.ipynb ├── supervised_integration_feature_importance.ipynb └── unsupervised_integration_feature_importance.ipynb /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version, and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.10" 12 | 13 | # Build documentation in the "docs/" directory with Sphinx 14 | sphinx: 15 | configuration: docs/source/conf.py 16 | 17 | # Optionally, but recommended, 18 | # declare the Python requirements required to build your documentation 19 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 20 | python: 21 | install: 22 | - requirements: docs/requirements-docs.txt 23 | 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 WangLab @ U of T 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IntegrAO: Integrate Any Omics 2 | This is the official codebase for **Integrate Any Omics: Towards genome-wide data integration for patient stratification**. 
3 | 4 | [![Preprint](https://img.shields.io/badge/preprint-available-brightgreen)](https://arxiv.org/abs/2401.07937)   5 | [![Documentation](https://img.shields.io/badge/docs-available-brightgreen)](https://integrao.readthedocs.io/en/latest/)   6 | [![PyPI version](https://badge.fury.io/py/integrao.svg)](https://pypi.org/project/integrao/)   7 | [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/bowang-lab/IntegrAO/blob/main/LICENSE) 8 | 9 | **Updates**: 10 | 11 | **[2025.03.02]** 🔥🔥🔥 We added functionality for extracting **feature importance** from the unsupervised and supervised IntegrAO models! Feel free to check it out here: [Unsupervised integration feature importance](https://github.com/bowang-lab/IntegrAO/blob/main/tutorials/unsupervised_integration_feature_importance.ipynb) and [Supervised integration feature importance](https://github.com/bowang-lab/IntegrAO/blob/main/tutorials/supervised_integration_feature_importance.ipynb). Suggestions are welcome! 12 | 13 | **[2025.01.23]** 🥳 IntegrAO is published in [Nature Machine Intelligence](https://www.nature.com/articles/s42256-024-00942-3)! 14 | 15 | **[2024.01.15]** 🥳 IntegrAO [Preprint](https://arxiv.org/abs/2401.07937) available! 16 | 17 | ## 🔨 Hardware requirements 18 | The `IntegrAO` package requires only a standard computer with enough RAM to support its in-memory operations. 19 | 20 | 21 | ## 🔨 Installation 22 | IntegrAO works with Python >= 3.7. Please make sure the correct version of Python is installed before proceeding. 23 | 24 | 1. Create a virtual environment: `conda create -n integrAO python=3.10 -y` and `conda activate integrAO` 25 | 2. Install [Pytorch 2.1.0](https://pytorch.org/get-started/locally/) 26 | 3. IntegrAO is available on PyPI. To install IntegrAO, run the following command: `pip install integrao` 27 | 28 | For development, clone this repo with the following commands: 29 | 30 | ```bash 31 | $ git clone this-repo-url 32 | $ cd IntegrAO 33 | $ pip install -r requirement.txt 34 | ``` 35 | 36 | 37 | ## 🧬 Introduction 38 | High-throughput omics profiling advancements have greatly enhanced cancer patient stratification. However, incomplete data in multi-omics integration presents a significant challenge, as traditional methods like sample exclusion or imputation often compromise biological diversity and dependencies. Furthermore, the critical task of accurately classifying new patients with partial omics data into existing subtypes is commonly overlooked. We introduce IntegrAO, an unsupervised framework integrating incomplete multi-omics and classifying new biological samples. IntegrAO first combines partially overlapping patient graphs from diverse omics sources and utilizes graph neural networks to produce unified patient embeddings. 39 | 
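To make the workflow concrete, here is a minimal usage sketch of the `integrao_integrater` class from this repository. The file paths, separator, and parameter values are illustrative assumptions only; see the tutorials below for complete, runnable examples.

```python
import pandas as pd
from integrao.integrater import integrao_integrater

# Hypothetical inputs: two partially overlapping omics tables with
# rows = samples (indexed by sample ID) and columns = features.
omics1 = pd.read_csv("data/omics/omics1.txt", sep="\t", index_col=0)
omics2 = pd.read_csv("data/omics/omics2.txt", sep="\t", index_col=0)

integrater = integrao_integrater(
    [omics1, omics2],
    dataset_name="demo",
    modalities_name_list=["omics1", "omics2"],
    neighbor_size=20,  # defaults to n_samples / 6 when omitted
)

# Step 1: build per-omics affinity graphs and fuse them across the
# samples that the modalities share.
fused_networks = integrater.network_diffusion()

# Step 2: align the modalities with graph neural networks to obtain
# one embedding per sample and a final fused similarity network.
embeddings, final_network, models = integrater.unsupervised_alignment()
```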
40 | An overview of IntegrAO can be seen below. 41 | 42 | ![integrAO](https://github.com/bowang-lab/IntegrAO/blob/main/figures/integrAO_overview.png) 43 | 44 | ## 📖 Tutorial 45 | 46 | We offer the following tutorials for demonstration: 47 | 48 | * **NEW**: [Unsupervised integration feature importance](https://github.com/bowang-lab/IntegrAO/blob/main/tutorials/unsupervised_integration_feature_importance.ipynb) 49 | * **NEW**: [Supervised integration feature importance](https://github.com/bowang-lab/IntegrAO/blob/main/tutorials/supervised_integration_feature_importance.ipynb) 50 | * [Integrate simulated butterfly datasets](https://github.com/bowang-lab/IntegrAO/blob/main/tutorials/simulated_butterfly.ipynb) 51 | * [Integrate simulated cancer omics datasets](https://github.com/bowang-lab/IntegrAO/blob/main/tutorials/simulated_cancer_omics.ipynb) 52 | * [Classify new samples with incomplete omics datasets](https://github.com/bowang-lab/IntegrAO/blob/main/tutorials/cancer_omics_classification.ipynb) 53 | 54 | ## Citing IntegrAO 55 | ```bibtex 56 | @article{ma2025moving, 57 | title={Moving towards genome-wide data integration for patient stratification with Integrate Any Omics}, 58 | author={Ma, Shihao and Zeng, Andy GX and Haibe-Kains, Benjamin and Goldenberg, Anna and Dick, John E and Wang, Bo}, 59 | journal={Nature Machine Intelligence}, 60 | volume={7}, 61 | number={1}, 62 | pages={29--42}, 63 | year={2025}, 64 | publisher={Nature Publishing Group} 65 | } 66 | ``` 67 | -------------------------------------------------------------------------------- /data/omics/clusters.txt: -------------------------------------------------------------------------------- 1 | "subjects" "cluster.id" 2 | "1" "subject1" 6 3 | "2" "subject2" 7 4 | "3" "subject3" 9 5 | "4" "subject4" 6 6 | "5" "subject5" 4 7 | "6" "subject6" 11 8 | "7" "subject7" 12 9 | "8" "subject8" 12 10 | "9" "subject9" 12 11 | "10" "subject10" 2 12 | "11" "subject11" 5 13 | "12" "subject12" 6 14 | "13" "subject13" 2 15 | "14" "subject14" 5 16 | "15" "subject15" 11 17 | "16" "subject16" 11 18 | "17" "subject17" 4 19 | "18" "subject18" 1 20 | "19" "subject19" 9 21 | "20" "subject20" 3 22 | "21" "subject21" 8 23 | "22" "subject22" 8 24 | "23" "subject23" 12 25 | "24" "subject24" 1 26 | "25" "subject25" 11 27 | "26" "subject26" 11 28 | "27" "subject27" 2 29 | "28" "subject28" 12 30 | "29" "subject29" 12 31 | "30" "subject30" 2 32 | "31" "subject31" 9 33 | "32" "subject32" 11 34 | "33" "subject33" 5 35 | "34" "subject34" 10 36 | "35" "subject35" 7 37 | "36" "subject36" 5 38 | "37" "subject37" 1 39 | "38" "subject38" 12 40 | "39" "subject39" 12 41 | "40" "subject40" 11 42 | "41" "subject41" 12 43 | "42" "subject42" 12 44 | "43" "subject43" 5 45 | "44" "subject44" 7 46 | "45" "subject45" 14 47 | "46" "subject46" 12 48 | "47" "subject47" 11 49 | "48" "subject48" 11 50 | "49" "subject49" 7 51 | "50" "subject50" 12 52 | "51" "subject51" 13 53 | "52" "subject52" 1 54 | "53" "subject53" 15 55 | "54" "subject54" 12 56 | "55" "subject55" 13 57 | "56" "subject56" 1 58 | "57" "subject57" 5 59 | "58" "subject58" 3 60 | "59" "subject59" 1 61 | "60" "subject60" 4 62 | "61" "subject61" 12 63 | "62" "subject62" 5 64 | "63" "subject63" 15 65 | "64" "subject64" 9 66 | "65" "subject65" 12 67 | "66" "subject66" 7 68 | "67" "subject67" 1 69 | "68" "subject68" 14 70 | "69" "subject69" 11 71 | "70" "subject70" 2 72 | "71" "subject71" 12 73 | "72" "subject72" 15 74 | "73" "subject73" 14 75 | "74" "subject74" 6 76 | "75" "subject75" 2 77 | "76" "subject76" 12 78 | "77" "subject77" 2 79 | "78" "subject78" 7 80 | 
"79" "subject79" 1 81 | "80" "subject80" 12 82 | "81" "subject81" 5 83 | "82" "subject82" 4 84 | "83" "subject83" 11 85 | "84" "subject84" 13 86 | "85" "subject85" 4 87 | "86" "subject86" 13 88 | "87" "subject87" 5 89 | "88" "subject88" 7 90 | "89" "subject89" 11 91 | "90" "subject90" 7 92 | "91" "subject91" 7 93 | "92" "subject92" 1 94 | "93" "subject93" 12 95 | "94" "subject94" 1 96 | "95" "subject95" 2 97 | "96" "subject96" 9 98 | "97" "subject97" 15 99 | "98" "subject98" 7 100 | "99" "subject99" 15 101 | "100" "subject100" 12 102 | "101" "subject101" 12 103 | "102" "subject102" 7 104 | "103" "subject103" 13 105 | "104" "subject104" 3 106 | "105" "subject105" 7 107 | "106" "subject106" 7 108 | "107" "subject107" 5 109 | "108" "subject108" 8 110 | "109" "subject109" 5 111 | "110" "subject110" 12 112 | "111" "subject111" 15 113 | "112" "subject112" 1 114 | "113" "subject113" 9 115 | "114" "subject114" 12 116 | "115" "subject115" 9 117 | "116" "subject116" 2 118 | "117" "subject117" 12 119 | "118" "subject118" 1 120 | "119" "subject119" 12 121 | "120" "subject120" 12 122 | "121" "subject121" 6 123 | "122" "subject122" 9 124 | "123" "subject123" 8 125 | "124" "subject124" 5 126 | "125" "subject125" 1 127 | "126" "subject126" 11 128 | "127" "subject127" 11 129 | "128" "subject128" 5 130 | "129" "subject129" 9 131 | "130" "subject130" 11 132 | "131" "subject131" 11 133 | "132" "subject132" 6 134 | "133" "subject133" 1 135 | "134" "subject134" 12 136 | "135" "subject135" 8 137 | "136" "subject136" 1 138 | "137" "subject137" 1 139 | "138" "subject138" 2 140 | "139" "subject139" 1 141 | "140" "subject140" 7 142 | "141" "subject141" 3 143 | "142" "subject142" 7 144 | "143" "subject143" 7 145 | "144" "subject144" 12 146 | "145" "subject145" 3 147 | "146" "subject146" 11 148 | "147" "subject147" 13 149 | "148" "subject148" 1 150 | "149" "subject149" 12 151 | "150" "subject150" 12 152 | "151" "subject151" 11 153 | "152" "subject152" 11 154 | "153" "subject153" 2 155 | "154" "subject154" 12 156 | "155" "subject155" 9 157 | "156" "subject156" 12 158 | "157" "subject157" 1 159 | "158" "subject158" 5 160 | "159" "subject159" 7 161 | "160" "subject160" 7 162 | "161" "subject161" 15 163 | "162" "subject162" 13 164 | "163" "subject163" 11 165 | "164" "subject164" 12 166 | "165" "subject165" 1 167 | "166" "subject166" 1 168 | "167" "subject167" 6 169 | "168" "subject168" 5 170 | "169" "subject169" 10 171 | "170" "subject170" 4 172 | "171" "subject171" 1 173 | "172" "subject172" 1 174 | "173" "subject173" 12 175 | "174" "subject174" 14 176 | "175" "subject175" 12 177 | "176" "subject176" 11 178 | "177" "subject177" 9 179 | "178" "subject178" 12 180 | "179" "subject179" 11 181 | "180" "subject180" 15 182 | "181" "subject181" 5 183 | "182" "subject182" 11 184 | "183" "subject183" 4 185 | "184" "subject184" 3 186 | "185" "subject185" 12 187 | "186" "subject186" 4 188 | "187" "subject187" 15 189 | "188" "subject188" 11 190 | "189" "subject189" 3 191 | "190" "subject190" 14 192 | "191" "subject191" 11 193 | "192" "subject192" 12 194 | "193" "subject193" 10 195 | "194" "subject194" 10 196 | "195" "subject195" 7 197 | "196" "subject196" 8 198 | "197" "subject197" 7 199 | "198" "subject198" 3 200 | "199" "subject199" 3 201 | "200" "subject200" 12 202 | "201" "subject201" 3 203 | "202" "subject202" 8 204 | "203" "subject203" 2 205 | "204" "subject204" 12 206 | "205" "subject205" 2 207 | "206" "subject206" 12 208 | "207" "subject207" 1 209 | "208" "subject208" 3 210 | "209" "subject209" 11 211 | "210" "subject210" 7 
212 | "211" "subject211" 5 213 | "212" "subject212" 11 214 | "213" "subject213" 11 215 | "214" "subject214" 12 216 | "215" "subject215" 11 217 | "216" "subject216" 12 218 | "217" "subject217" 1 219 | "218" "subject218" 7 220 | "219" "subject219" 2 221 | "220" "subject220" 6 222 | "221" "subject221" 12 223 | "222" "subject222" 9 224 | "223" "subject223" 11 225 | "224" "subject224" 13 226 | "225" "subject225" 7 227 | "226" "subject226" 7 228 | "227" "subject227" 4 229 | "228" "subject228" 7 230 | "229" "subject229" 7 231 | "230" "subject230" 2 232 | "231" "subject231" 15 233 | "232" "subject232" 7 234 | "233" "subject233" 7 235 | "234" "subject234" 12 236 | "235" "subject235" 7 237 | "236" "subject236" 7 238 | "237" "subject237" 12 239 | "238" "subject238" 4 240 | "239" "subject239" 5 241 | "240" "subject240" 13 242 | "241" "subject241" 11 243 | "242" "subject242" 12 244 | "243" "subject243" 5 245 | "244" "subject244" 9 246 | "245" "subject245" 8 247 | "246" "subject246" 13 248 | "247" "subject247" 15 249 | "248" "subject248" 2 250 | "249" "subject249" 7 251 | "250" "subject250" 12 252 | "251" "subject251" 15 253 | "252" "subject252" 5 254 | "253" "subject253" 4 255 | "254" "subject254" 1 256 | "255" "subject255" 12 257 | "256" "subject256" 11 258 | "257" "subject257" 5 259 | "258" "subject258" 1 260 | "259" "subject259" 2 261 | "260" "subject260" 12 262 | "261" "subject261" 8 263 | "262" "subject262" 5 264 | "263" "subject263" 2 265 | "264" "subject264" 11 266 | "265" "subject265" 5 267 | "266" "subject266" 11 268 | "267" "subject267" 12 269 | "268" "subject268" 9 270 | "269" "subject269" 1 271 | "270" "subject270" 3 272 | "271" "subject271" 15 273 | "272" "subject272" 5 274 | "273" "subject273" 7 275 | "274" "subject274" 9 276 | "275" "subject275" 13 277 | "276" "subject276" 15 278 | "277" "subject277" 9 279 | "278" "subject278" 12 280 | "279" "subject279" 15 281 | "280" "subject280" 12 282 | "281" "subject281" 9 283 | "282" "subject282" 13 284 | "283" "subject283" 12 285 | "284" "subject284" 2 286 | "285" "subject285" 1 287 | "286" "subject286" 5 288 | "287" "subject287" 2 289 | "288" "subject288" 9 290 | "289" "subject289" 10 291 | "290" "subject290" 5 292 | "291" "subject291" 7 293 | "292" "subject292" 5 294 | "293" "subject293" 5 295 | "294" "subject294" 15 296 | "295" "subject295" 4 297 | "296" "subject296" 12 298 | "297" "subject297" 9 299 | "298" "subject298" 2 300 | "299" "subject299" 7 301 | "300" "subject300" 8 302 | "301" "subject301" 3 303 | "302" "subject302" 5 304 | "303" "subject303" 12 305 | "304" "subject304" 15 306 | "305" "subject305" 5 307 | "306" "subject306" 5 308 | "307" "subject307" 11 309 | "308" "subject308" 1 310 | "309" "subject309" 8 311 | "310" "subject310" 12 312 | "311" "subject311" 7 313 | "312" "subject312" 2 314 | "313" "subject313" 2 315 | "314" "subject314" 2 316 | "315" "subject315" 11 317 | "316" "subject316" 11 318 | "317" "subject317" 12 319 | "318" "subject318" 12 320 | "319" "subject319" 12 321 | "320" "subject320" 5 322 | "321" "subject321" 7 323 | "322" "subject322" 12 324 | "323" "subject323" 11 325 | "324" "subject324" 12 326 | "325" "subject325" 12 327 | "326" "subject326" 6 328 | "327" "subject327" 2 329 | "328" "subject328" 1 330 | "329" "subject329" 7 331 | "330" "subject330" 15 332 | "331" "subject331" 1 333 | "332" "subject332" 5 334 | "333" "subject333" 5 335 | "334" "subject334" 7 336 | "335" "subject335" 13 337 | "336" "subject336" 15 338 | "337" "subject337" 2 339 | "338" "subject338" 5 340 | "339" "subject339" 14 341 | "340" 
"subject340" 12 342 | "341" "subject341" 6 343 | "342" "subject342" 15 344 | "343" "subject343" 11 345 | "344" "subject344" 1 346 | "345" "subject345" 5 347 | "346" "subject346" 5 348 | "347" "subject347" 5 349 | "348" "subject348" 1 350 | "349" "subject349" 5 351 | "350" "subject350" 3 352 | "351" "subject351" 15 353 | "352" "subject352" 5 354 | "353" "subject353" 5 355 | "354" "subject354" 1 356 | "355" "subject355" 7 357 | "356" "subject356" 7 358 | "357" "subject357" 3 359 | "358" "subject358" 1 360 | "359" "subject359" 5 361 | "360" "subject360" 5 362 | "361" "subject361" 13 363 | "362" "subject362" 12 364 | "363" "subject363" 12 365 | "364" "subject364" 2 366 | "365" "subject365" 10 367 | "366" "subject366" 12 368 | "367" "subject367" 1 369 | "368" "subject368" 12 370 | "369" "subject369" 8 371 | "370" "subject370" 12 372 | "371" "subject371" 9 373 | "372" "subject372" 2 374 | "373" "subject373" 7 375 | "374" "subject374" 11 376 | "375" "subject375" 12 377 | "376" "subject376" 3 378 | "377" "subject377" 1 379 | "378" "subject378" 15 380 | "379" "subject379" 11 381 | "380" "subject380" 12 382 | "381" "subject381" 8 383 | "382" "subject382" 12 384 | "383" "subject383" 12 385 | "384" "subject384" 12 386 | "385" "subject385" 5 387 | "386" "subject386" 12 388 | "387" "subject387" 6 389 | "388" "subject388" 14 390 | "389" "subject389" 5 391 | "390" "subject390" 8 392 | "391" "subject391" 2 393 | "392" "subject392" 4 394 | "393" "subject393" 12 395 | "394" "subject394" 1 396 | "395" "subject395" 15 397 | "396" "subject396" 12 398 | "397" "subject397" 2 399 | "398" "subject398" 12 400 | "399" "subject399" 1 401 | "400" "subject400" 5 402 | "401" "subject401" 12 403 | "402" "subject402" 11 404 | "403" "subject403" 1 405 | "404" "subject404" 2 406 | "405" "subject405" 9 407 | "406" "subject406" 4 408 | "407" "subject407" 12 409 | "408" "subject408" 1 410 | "409" "subject409" 12 411 | "410" "subject410" 9 412 | "411" "subject411" 12 413 | "412" "subject412" 3 414 | "413" "subject413" 6 415 | "414" "subject414" 9 416 | "415" "subject415" 11 417 | "416" "subject416" 15 418 | "417" "subject417" 12 419 | "418" "subject418" 12 420 | "419" "subject419" 11 421 | "420" "subject420" 7 422 | "421" "subject421" 2 423 | "422" "subject422" 5 424 | "423" "subject423" 3 425 | "424" "subject424" 5 426 | "425" "subject425" 2 427 | "426" "subject426" 2 428 | "427" "subject427" 2 429 | "428" "subject428" 11 430 | "429" "subject429" 7 431 | "430" "subject430" 1 432 | "431" "subject431" 12 433 | "432" "subject432" 5 434 | "433" "subject433" 11 435 | "434" "subject434" 14 436 | "435" "subject435" 1 437 | "436" "subject436" 12 438 | "437" "subject437" 12 439 | "438" "subject438" 7 440 | "439" "subject439" 12 441 | "440" "subject440" 2 442 | "441" "subject441" 5 443 | "442" "subject442" 5 444 | "443" "subject443" 7 445 | "444" "subject444" 5 446 | "445" "subject445" 8 447 | "446" "subject446" 5 448 | "447" "subject447" 3 449 | "448" "subject448" 12 450 | "449" "subject449" 2 451 | "450" "subject450" 2 452 | "451" "subject451" 11 453 | "452" "subject452" 1 454 | "453" "subject453" 12 455 | "454" "subject454" 14 456 | "455" "subject455" 1 457 | "456" "subject456" 1 458 | "457" "subject457" 12 459 | "458" "subject458" 3 460 | "459" "subject459" 12 461 | "460" "subject460" 7 462 | "461" "subject461" 12 463 | "462" "subject462" 5 464 | "463" "subject463" 10 465 | "464" "subject464" 7 466 | "465" "subject465" 10 467 | "466" "subject466" 12 468 | "467" "subject467" 2 469 | "468" "subject468" 11 470 | "469" "subject469" 5 471 
| "470" "subject470" 11 472 | "471" "subject471" 12 473 | "472" "subject472" 6 474 | "473" "subject473" 15 475 | "474" "subject474" 1 476 | "475" "subject475" 11 477 | "476" "subject476" 5 478 | "477" "subject477" 7 479 | "478" "subject478" 12 480 | "479" "subject479" 10 481 | "480" "subject480" 10 482 | "481" "subject481" 5 483 | "482" "subject482" 5 484 | "483" "subject483" 7 485 | "484" "subject484" 13 486 | "485" "subject485" 12 487 | "486" "subject486" 12 488 | "487" "subject487" 2 489 | "488" "subject488" 5 490 | "489" "subject489" 12 491 | "490" "subject490" 6 492 | "491" "subject491" 7 493 | "492" "subject492" 6 494 | "493" "subject493" 12 495 | "494" "subject494" 5 496 | "495" "subject495" 9 497 | "496" "subject496" 1 498 | "497" "subject497" 14 499 | "498" "subject498" 4 500 | "499" "subject499" 1 501 | "500" "subject500" 9 502 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements-docs.txt: -------------------------------------------------------------------------------- 1 | 2 | # The following packages are for .readthedocs 3 | sphinx 4 | sphinx_rtd_theme 5 | nbsphinx 6 | nbsphinx-link 7 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | 3 | # -- Project information 4 | 5 | project = 'IntegrAO' 6 | copyright = '2024, Wanglab, Shihao Ma' 7 | author = 'Shihao Ma' 8 | 9 | release = '0.1' 10 | version = '0.1.0' 11 | 12 | # -- General configuration 13 | 14 | extensions = [ 15 | 'sphinx.ext.duration', 16 | 'sphinx.ext.doctest', 17 | 'sphinx.ext.autodoc', 18 | 'sphinx.ext.autosummary', 19 | 'sphinx.ext.intersphinx', 20 | 'sphinx_rtd_theme', 21 | 'nbsphinx', 22 | 'nbsphinx_link', 23 | ] 24 | 25 | intersphinx_mapping = { 26 | 'python': ('https://docs.python.org/3/', None), 27 | 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), 28 | } 29 | intersphinx_disabled_domains = ['std'] 30 | 31 | templates_path = ['_templates'] 32 | 33 | # -- Options for HTML output 34 | 35 | html_theme = 'sphinx_rtd_theme' # To use this theme, make sure 'sphinx' and 'sphinx_rtd_theme' are in requirements.txt AND sphinx_rtd_theme is listed in extensions above (https://sphinx-rtd-theme.readthedocs.io/en/stable/installing.html) 36 | 37 | # -- Options for EPUB output 38 | epub_show_urls = 'footnote' 39 | -------------------------------------------------------------------------------- /docs/source/faq.rst: -------------------------------------------------------------------------------- 1 | FAQ 2 | ------------ 3 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. IntegrAO documentation master file, 2 | You can adapt this file completely to your liking, but it should at least 3 | contain the root `toctree` directive. 4 | 5 | Welcome to IntegrAO's documentation 6 | =================================== 7 | 8 | 9 | **IntegrAO** (**Integr**\ate **A**\ny **O**\mics) is an unsupervised, GNN-based framework for integrating incomplete multi-omics data (see the 10 | `paper <https://www.nature.com/articles/s42256-024-00942-3>`_ for details and citations). 11 | 12 | .. image:: https://img.shields.io/badge/preprint-available-brightgreen.svg?style=flat 13 | :target: https://arxiv.org/abs/2401.07937 14 | :alt: Preprint link 15 | 16 | .. image:: https://badge.fury.io/py/integrao.svg 17 | :target: https://badge.fury.io/py/integrao 18 | :alt: PyPI version 19 | 20 | .. image:: https://img.shields.io/badge/license-MIT-blue.svg 21 | :target: https://github.com/bowang-lab/IntegrAO/blob/main/LICENSE 22 | :alt: License 23 | 24 | 25 | Introduction 26 | ------------ 27 | High-throughput omics profiling advancements have greatly enhanced cancer patient stratification. However, incomplete data in multi-omics integration presents a significant challenge, as traditional methods like sample exclusion or imputation often compromise biological diversity and dependencies. Furthermore, the critical task of accurately classifying new patients with partial omics data into existing subtypes is commonly overlooked. We introduce IntegrAO, an unsupervised framework integrating incomplete multi-omics and classifying new biological samples. IntegrAO first combines partially overlapping patient graphs from diverse omics sources and utilizes graph neural networks to produce unified patient embeddings. 28 | 29 | Overview 30 | -------- 31 | .. image:: https://github.com/bowang-lab/IntegrAO/raw/main/figures/integrAO_overview.png 32 | 33 | 34 | .. toctree:: 35 | :maxdepth: 2 36 | :caption: Getting Started 37 | 38 | installation 39 | 40 | .. toctree:: 41 | :maxdepth: 1 42 | :caption: Tutorials 43 | 44 | tutorial_butterfly 45 | tutorial_cancer 46 | tutorial_classify 47 | 48 | .. 
toctree:: 49 | :maxdepth: 2 50 | :caption: API 51 | 52 | integrao 53 | 54 | .. toctree:: 55 | :maxdepth: 2 56 | :caption: References: 57 | 58 | faq 59 | references 60 | 61 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ------------ 3 | IntegrAO works with Python >= 3.7. Please make sure the correct version of Python is installed before proceeding. 4 | 5 | 1. Create a virtual environment 6 | :: 7 | conda create -n integrAO python=3.10 -y 8 | conda activate integrAO 9 | 10 | 2. Install `Pytorch <https://pytorch.org/get-started/locally/>`_ 2.1.0 11 | :: 12 | pip install torch torchvision torchaudio 13 | 14 | 3. IntegrAO is available on PyPI. To install IntegrAO, run the following command 15 | :: 16 | pip install integrao 17 | 18 | For development, clone this repo with the following commands:: 19 | 20 | git clone https://github.com/bowang-lab/IntegrAO.git 21 | cd IntegrAO 22 | pip install -r requirement.txt 23 | -------------------------------------------------------------------------------- /docs/source/integrao.rst: -------------------------------------------------------------------------------- 1 | IntegrAO package 2 | ================ 3 | 4 | integrao.IntegrAO\_supervised 5 | ----------------------------- 6 | 7 | .. automodule:: integrao.IntegrAO_supervised 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | integrao.IntegrAO\_unsupervised 13 | ------------------------------- 14 | 15 | .. automodule:: integrao.IntegrAO_unsupervised 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | integrao.dataset 21 | ----------------- 22 | 23 | .. automodule:: integrao.dataset 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | integrao.integrater 29 | ----------------------- 30 | 31 | .. automodule:: integrao.integrater 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | integrao.main 37 | -------------------- 38 | 39 | .. automodule:: integrao.main 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | integrao.supervised\_train 45 | --------------------------- 46 | 47 | .. automodule:: integrao.supervised_train 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | integrao.unsupervised\_train 53 | ----------------------------- 54 | 55 | .. automodule:: integrao.unsupervised_train 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | integrao.util 61 | -------------------- 62 | 63 | .. 
automodule:: integrao.util 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | -------------------------------------------------------------------------------- /docs/source/references.rst: -------------------------------------------------------------------------------- 1 | References 2 | ------------ 3 | 4 | -------------------------------------------------------------------------------- /docs/source/tutorial_butterfly.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../tutorials/simulated_butterfly.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /docs/source/tutorial_cancer.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../tutorials/simulated_cancer_omics.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /docs/source/tutorial_classify.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../tutorials/cancer_omics_classification.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /figures/integrAO_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/IntegrAO/5be3f4a46a1a4739ddf737008e1164783f48c37a/figures/integrAO_overview.png -------------------------------------------------------------------------------- /integrao/IntegrAO_supervised.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | import torch 6 | 7 | from torch_geometric.nn import GraphSAGE 8 | 9 | 10 | class IntegrAO(nn.Module): 11 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, pred_n_layer=2, pred_act="softplus", num_classes=None): 12 | super(IntegrAO, self).__init__() 13 | self.in_channels = in_channels # this is an array since we have multiple domains 14 | self.hidden_channels = hidden_channels 15 | self.output_dim = out_channels 16 | self.num_layers = num_layers 17 | self.num_classes = num_classes 18 | self.pred_n_layer = pred_n_layer 19 | self.pred_act = pred_act 20 | 21 | num = len(in_channels) 22 | feature = [] 23 | for i in range(num): 24 | model_sage = GraphSAGE( 25 | in_channels=self.in_channels[i], 26 | hidden_channels=self.hidden_channels, 27 | num_layers=self.num_layers, 28 | out_channels=self.output_dim, 29 | project=False,) 30 | 31 | feature.append(model_sage) 32 | 33 | self.feature = nn.ModuleList(feature) 34 | 35 | self.feature_show = nn.Sequential( 36 | nn.Linear(self.output_dim, self.output_dim), 37 | nn.BatchNorm1d(self.output_dim), 38 | nn.LeakyReLU(0.1, True), 39 | nn.Linear(self.output_dim, self.output_dim), 40 | ) 41 | 42 | self.pred_head = nn.Sequential( 43 | nn.Linear(self.output_dim, self.output_dim // 2), 44 | nn.BatchNorm1d(self.output_dim // 2), 45 | nn.LeakyReLU(0.1, True), 46 | # nn.Softplus(), 47 | nn.Linear(self.output_dim // 2, self.num_classes), 48 | ) 49 | 50 | print(self) 51 | 52 | 53 | def get_sample_ids_for_domain(self, domain): 54 | return self.sample_ids[domain] 55 | 56 | 57 | def forward(self, x_dict, edge_index_dict, domain_sample_ids): 58 | z_all = {} 59 | z_sample_dict = {} 60 | for domain in x_dict.keys(): 61 | z = self.feature[domain](x_dict[domain], edge_index_dict[domain]) 62 | z = 
self.feature_show(z) 63 | z_all[domain] = z 64 | 65 | # Let's assume that your samples have unique identifiers and you 66 | # can extract these identifiers for each domain 67 | sample_ids = domain_sample_ids[domain] 68 | 69 | # Go through each sample and its corresponding vector 70 | for sample_id, vector in zip(sample_ids, z): 71 | # If the sample's vectors haven't been recorded, create a new list 72 | if sample_id not in z_sample_dict: 73 | z_sample_dict[sample_id] = [] 74 | 75 | # Append the new vector to the list of vectors 76 | z_sample_dict[sample_id].append(vector) 77 | 78 | # Now, average the vectors for each sample 79 | z_avg = {} 80 | for sample_id, vectors in z_sample_dict.items(): 81 | # Stack all vectors along a new dimension and calculate the mean 82 | z_avg[sample_id] = torch.stack(vectors).mean(dim=0) 83 | 84 | sorted_list = z_avg.items() # sorted(z_avg.items()) 85 | z_avg_list = [z_avg for _, z_avg in sorted_list] 86 | z_id_list = [z_id for z_id, _ in sorted_list] 87 | z_avg_tensor = torch.stack(z_avg_list) 88 | 89 | output = self.pred_head(z_avg_tensor) 90 | return z_all, z_avg_tensor, output, z_id_list 91 | 92 | def load_my_state_dict(self, state_dict): 93 | own_state = self.state_dict() 94 | for name, param in state_dict.items(): 95 | if name not in own_state: 96 | continue 97 | if isinstance(param, nn.parameter.Parameter): 98 | # backwards compatibility for serialized parameters 99 | param = param.data 100 | own_state[name].copy_(param) 101 | -------------------------------------------------------------------------------- /integrao/IntegrAO_unsupervised.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | from torch_geometric.nn import GraphSAGE 7 | 8 | 9 | class IntegrAO(nn.Module): 10 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2): 11 | super(IntegrAO, self).__init__() 12 | self.in_channels = in_channels # this is an array since we have multiple domains 13 | self.hidden_channels = hidden_channels 14 | self.output_dim = out_channels 15 | self.num_layers = num_layers 16 | 17 | num = len(in_channels) 18 | feature = [] 19 | 20 | for i in range(num): 21 | model_sage = GraphSAGE( 22 | in_channels=self.in_channels[i], 23 | hidden_channels=self.hidden_channels, 24 | num_layers=self.num_layers, 25 | out_channels=self.output_dim, 26 | project=False, 27 | ) 28 | 29 | feature.append(model_sage) 30 | 31 | self.feature = nn.ModuleList(feature) 32 | 33 | self.feature_show = nn.Sequential( 34 | nn.Linear(self.output_dim, self.output_dim), 35 | nn.BatchNorm1d(self.output_dim), 36 | nn.LeakyReLU(0.1, True), 37 | nn.Linear(self.output_dim, self.output_dim), 38 | ) 39 | 40 | def forward(self, x_dict, edge_index_dict): 41 | z_all = {} 42 | for domain in x_dict.keys(): 43 | z = self.feature[domain](x_dict[domain], edge_index_dict[domain]) 44 | z = self.feature_show(z) 45 | z_all[domain] = z 46 | 47 | return z_all 48 | -------------------------------------------------------------------------------- /integrao/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/IntegrAO/5be3f4a46a1a4739ddf737008e1164783f48c37a/integrao/__init__.py -------------------------------------------------------------------------------- /integrao/dataset.py: -------------------------------------------------------------------------------- 1 
| from torch_geometric.data import InMemoryDataset, Data 2 | from snf.compute import _find_dominate_set 3 | from sklearn.utils.validation import ( 4 | check_array, 5 | check_symmetric, 6 | check_consistent_length, 7 | 8 | ) 9 | import networkx as nx 10 | import numpy as np 11 | import torch 12 | 13 | # custom dataset 14 | class GraphDataset(InMemoryDataset): 15 | 16 | def __init__(self, neighbor_size, feature, network, transform=None): 17 | super(GraphDataset, self).__init__('.', transform, None, None) 18 | 19 | neighbor_size = min(int(neighbor_size), network.shape[0]) 20 | 21 | # preprocess the input into a pyg graph 22 | network = _find_dominate_set(network, K=neighbor_size) 23 | network = check_symmetric(network, raise_warning=False) 24 | network[network > 0.0] = 1.0 25 | G = nx.from_numpy_array(network) 26 | 27 | # create edge index from the adjacency matrix 28 | adj = nx.to_scipy_sparse_array(G).tocoo() 29 | row = torch.from_numpy(adj.row.astype(np.int64)).to(torch.long) 30 | col = torch.from_numpy(adj.col.astype(np.int64)).to(torch.long) 31 | edge_index = torch.stack([row, col], dim=0) 32 | 33 | data = Data(edge_index=edge_index) 34 | data.num_nodes = G.number_of_nodes() 35 | 36 | # embedding 37 | data.x = torch.from_numpy(feature).type(torch.float32) 38 | 39 | self.data, self.slices = self.collate([data]) 40 | 41 | def _download(self): 42 | return 43 | 44 | def _process(self): 45 | return 46 | 47 | def __repr__(self): 48 | return '{}()'.format(self.__class__.__name__) 49 | 50 | 51 | 52 | # custom dataset 53 | class GraphDataset_weight(InMemoryDataset): 54 | 55 | def __init__(self, neighbor_size, feature, network, transform=None): 56 | super(GraphDataset_weight, self).__init__('.', transform, None, None) 57 | 58 | # preprocess the input into a pyg graph 59 | network = _find_dominate_set(network, K=neighbor_size) 60 | network = check_symmetric(network, raise_warning=False) 61 | 62 | # Create a binary mask to extract non-zero values for edge weights 63 | mask = (network > 0.0).astype(float) 64 | G = nx.from_numpy_array(mask) 65 | 66 | # create edge index from the adjacency matrix 67 | adj = nx.to_scipy_sparse_array(G).tocoo() 68 | row = torch.from_numpy(adj.row.astype(np.int64)).to(torch.long) 69 | col = torch.from_numpy(adj.col.astype(np.int64)).to(torch.long) 70 | edge_index = torch.stack([row, col], dim=0) 71 | 72 | # Extracting edge weights from the original network using the mask 73 | edge_weights = network[adj.row, adj.col] 74 | 75 | data = Data(edge_index=edge_index, edge_attr=torch.from_numpy(edge_weights).type(torch.float32)) 76 | data.num_nodes = G.number_of_nodes() 77 | 78 | # embedding 79 | data.x = torch.from_numpy(feature).type(torch.float32) 80 | 81 | self.data, self.slices = self.collate([data]) 82 | 83 | def _download(self): 84 | return 85 | 86 | def _process(self): 87 | return 88 | 89 | def __repr__(self): 90 | return '{}()'.format(self.__class__.__name__) 91 | -------------------------------------------------------------------------------- /integrao/integrater.py: -------------------------------------------------------------------------------- 1 | from integrao.unsupervised_train import tsne_p_deep 2 | from integrao.supervised_train import tsne_p_deep_classification 3 | 4 | from integrao.main import dist2, integrao_fuse, _stable_normalized 5 | from integrao.util import data_indexing 6 | 7 | import snf 8 | import pandas as pd 9 | import numpy as np 10 | import os 11 | 12 | import torch 13 | import torch_geometric.transforms as T 14 | import torch.nn.functional as F 15 | from integrao.dataset import 
GraphDataset 16 | 17 | 18 | class integrao_integrater(object): 19 | def __init__( 20 | self, 21 | datasets, 22 | dataset_name=None, 23 | modalities_name_list=None, 24 | neighbor_size=None, 25 | embedding_dims=50, 26 | fusing_iteration=20, 27 | normalization_factor=1.0, 28 | alighment_epochs=1000, 29 | beta=1.0, 30 | mu=0.5, 31 | random_state=42, 32 | ): 33 | self.datasets = datasets 34 | self.dataset_name = dataset_name 35 | self.modalities_name_list = modalities_name_list 36 | self.embedding_dims = embedding_dims 37 | self.fusing_iteration = fusing_iteration 38 | self.normalization_factor = normalization_factor 39 | self.alighment_epochs = alighment_epochs 40 | self.beta = beta 41 | self.mu = mu 42 | self.random_state = random_state 43 | 44 | # data indexing 45 | ( 46 | self.dicts_common, 47 | self.dicts_commonIndex, 48 | self.dict_sampleToIndexs, 49 | self.dicts_unique, 50 | self.original_order, 51 | self.dict_original_order, 52 | ) = data_indexing(self.datasets) 53 | 54 | # set neighbor size 55 | if neighbor_size is None: 56 | self.neighbor_size = int(datasets[0].shape[0] / 6) 57 | else: 58 | self.neighbor_size = neighbor_size 59 | print("Neighbor size:", self.neighbor_size) 60 | 61 | def network_diffusion(self): 62 | S_dfs = [] 63 | for i in range(0, len(self.datasets)): 64 | view = self.datasets[i] 65 | dist_mat = dist2(view.values, view.values) 66 | S_mat = snf.compute.affinity_matrix( 67 | dist_mat, K=self.neighbor_size, mu=self.mu 68 | ) 69 | 70 | S_df = pd.DataFrame( 71 | data=S_mat, index=self.original_order[i], columns=self.original_order[i] 72 | ) 73 | 74 | S_dfs.append(S_df) 75 | 76 | self.fused_networks = integrao_fuse( 77 | S_dfs.copy(), 78 | dicts_common=self.dicts_common, 79 | dicts_unique=self.dicts_unique, 80 | original_order=self.original_order, 81 | neighbor_size=self.neighbor_size, 82 | fusing_iteration=self.fusing_iteration, 83 | normalization_factor=self.normalization_factor, 84 | ) 85 | return self.fused_networks 86 | 87 | def unsupervised_alignment(self): 88 | # turn pandas dataframe into np array 89 | datasets_val = [x.values for x in self.datasets] 90 | fused_networks_val = [x.values for x in self.fused_networks] 91 | 92 | S_final, self.models = tsne_p_deep( 93 | self.dicts_commonIndex, 94 | self.dict_sampleToIndexs, 95 | datasets_val, 96 | P=fused_networks_val, 97 | neighbor_size=self.neighbor_size, 98 | embedding_dims=self.embedding_dims, 99 | alighment_epochs=self.alighment_epochs, 100 | ) 101 | 102 | self.final_embeds = pd.DataFrame( 103 | data=S_final, index=self.dict_sampleToIndexs.keys() 104 | ) 105 | self.final_embeds.sort_index(inplace=True) 106 | 107 | # calculate the final similarity graph 108 | dist_final = dist2(self.final_embeds.values, self.final_embeds.values) 109 | Wall_final = snf.compute.affinity_matrix( 110 | dist_final, K=self.neighbor_size, mu=self.mu 111 | ) 112 | 113 | Wall_final = _stable_normalized(Wall_final) 114 | 115 | return self.final_embeds, Wall_final, self.models 116 | 117 | def classification_finetuning(self, clf_labels, model_path, finetune_epochs=1000): 118 | # turn pandas dataframe into np array 119 | datasets_val = [x.values for x in self.datasets] 120 | fused_networks_val = [x.values for x in self.fused_networks] 121 | 122 | # reorder clf_labels to match self.dict_sampleToIndexs.keys() 123 | clf_labels = clf_labels.loc[self.dict_sampleToIndexs.keys()] 124 | 125 | S_final, self.models, preds = tsne_p_deep_classification( 126 | self.dicts_commonIndex, 127 | self.dict_sampleToIndexs, 128 | 
self.dict_original_order, 129 | datasets_val, 130 | clf_labels, 131 | P=fused_networks_val, 132 | model_path=model_path, 133 | neighbor_size=self.neighbor_size, 134 | embedding_dims=self.embedding_dims, 135 | alighment_epochs=finetune_epochs, 136 | num_classes=len(np.unique(clf_labels)), 137 | ) 138 | 139 | self.final_embeds = pd.DataFrame( 140 | data=S_final, index=self.dict_sampleToIndexs.keys() 141 | ) 142 | self.final_embeds.sort_index(inplace=True) 143 | 144 | # calculate the final similarity graph 145 | dist_final = dist2(self.final_embeds.values, self.final_embeds.values) 146 | Wall_final = snf.compute.affinity_matrix( 147 | dist_final, K=self.neighbor_size, mu=self.mu 148 | ) 149 | 150 | Wall_final = _stable_normalized(Wall_final) 151 | 152 | return self.final_embeds, Wall_final, self.models, preds 153 | 154 | 155 | class integrao_predictor(object): 156 | def __init__( 157 | self, 158 | datasets, 159 | dataset_name=None, 160 | modalities_name_list=None, 161 | neighbor_size=None, 162 | embedding_dims=50, 163 | hidden_channels=128, 164 | fusing_iteration=20, 165 | normalization_factor=1.0, 166 | alighment_epochs=1000, 167 | beta=1.0, 168 | mu=0.5, 169 | num_classes=None, 170 | ): 171 | self.datasets = datasets 172 | self.dataset_name = dataset_name 173 | self.modalities_name_list = modalities_name_list 174 | self.embedding_dims = embedding_dims 175 | self.hidden_channels = hidden_channels 176 | self.fusing_iteration = fusing_iteration 177 | self.normalization_factor = normalization_factor 178 | self.alighment_epochs = alighment_epochs 179 | self.beta = beta 180 | self.mu = mu 181 | self.num_classes = num_classes 182 | 183 | # data indexing 184 | ( 185 | self.dicts_common, 186 | self.dicts_commonIndex, 187 | self.dict_sampleToIndexs, 188 | self.dicts_unique, 189 | self.original_order, 190 | self.dict_original_order, 191 | ) = data_indexing(self.datasets) 192 | 193 | # set neighbor size 194 | if neighbor_size is None: 195 | self.neighbor_size = int(datasets[0].shape[0] / 6) 196 | else: 197 | self.neighbor_size = neighbor_size 198 | print("Neighbor size:", self.neighbor_size) 199 | 200 | self.feature_dims = [] 201 | for i in range(len(self.datasets)): 202 | self.feature_dims.append(np.shape(self.datasets[i])[1]) 203 | 204 | if num_classes is not None: 205 | self.num_classes = num_classes 206 | 207 | 208 | def network_diffusion(self): 209 | S_dfs = [] 210 | for i in range(0, len(self.datasets)): 211 | view = self.datasets[i] 212 | dist_mat = dist2(view.values, view.values) 213 | S_mat = snf.compute.affinity_matrix( 214 | dist_mat, K=self.neighbor_size, mu=self.mu 215 | ) 216 | 217 | S_df = pd.DataFrame( 218 | data=S_mat, index=self.original_order[i], columns=self.original_order[i] 219 | ) 220 | 221 | S_dfs.append(S_df) 222 | 223 | self.fused_networks = integrao_fuse( 224 | S_dfs.copy(), 225 | dicts_common=self.dicts_common, 226 | dicts_unique=self.dicts_unique, 227 | original_order=self.original_order, 228 | neighbor_size=self.neighbor_size, 229 | fusing_iteration=self.fusing_iteration, 230 | normalization_factor=self.normalization_factor, 231 | ) 232 | return self.fused_networks 233 | 234 | 235 | def _load_pre_trained_weights(self, model, model_path, device): 236 | try: 237 | state_dict = torch.load(model_path, map_location=device) 238 | model.load_state_dict(state_dict) 239 | print("Loaded pre-trained model with success.") 240 | except FileNotFoundError: 241 | print("Pre-trained weights not found. 
Training from scratch.") 242 | 243 | return model 244 | 245 | def inference_unsupervised(self, model_path, new_datasets, modalities_names): 246 | # loop through the new datasets and create a GraphDataset for each modality 247 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 248 | 249 | from integrao.IntegrAO_unsupervised import IntegrAO 250 | model = IntegrAO(self.feature_dims, self.hidden_channels, self.embedding_dims).to(device) 251 | model = self._load_pre_trained_weights(model, model_path, device) 252 | 253 | x_dict = {} 254 | edge_index_dict = {} 255 | for i, modal in enumerate(new_datasets): 256 | # find the index of the modal in the self.modalities_name_list 257 | model_name = modalities_names[i] 258 | modal_index = self.modalities_name_list.index(model_name) 259 | 260 | dataset = GraphDataset( 261 | self.neighbor_size, 262 | modal.values, 263 | self.fused_networks[modal_index].values, 264 | transform=T.ToDevice(device), 265 | ) 266 | modal_dg = dataset[0] 267 | 268 | x_dict[modal_index] = modal_dg.x 269 | edge_index_dict[modal_index] = modal_dg.edge_index 270 | 271 | # Now to do the inference 272 | # --------------------------------------------------------- 273 | embeddings = model(x_dict, edge_index_dict) 274 | for key in embeddings: 275 | embeddings[key] = embeddings[key].detach().cpu().numpy() 276 | 277 | final_embedding = np.array([]).reshape(0, self.embedding_dims) 278 | for key in self.dict_sampleToIndexs: 279 | sample_embedding = np.zeros((1, self.embedding_dims)) 280 | 281 | for (dataset, index) in self.dict_sampleToIndexs[key]: 282 | sample_embedding += embeddings[dataset][index] 283 | sample_embedding /= len(self.dict_sampleToIndexs[key]) 284 | 285 | final_embedding = np.concatenate((final_embedding, sample_embedding), axis=0) 286 | 287 | # Now format the final embeddings 288 | # --------------------------------------------------------- 289 | final_embedding_df = pd.DataFrame( 290 | data=final_embedding, index=self.dict_sampleToIndexs.keys() 291 | ) 292 | final_embedding_df.sort_index(inplace=True) 293 | 294 | # calculate the final similarity graph 295 | dist_final = dist2(final_embedding_df.values, final_embedding_df.values) 296 | Wall_final = snf.compute.affinity_matrix( 297 | dist_final, K=self.neighbor_size, mu=self.mu 298 | ) 299 | 300 | Wall_final = _stable_normalized(Wall_final) 301 | 302 | return final_embedding_df, Wall_final 303 | 304 | def interpret_unsupervised(self, model_path, result_dir, new_datasets, modalities_names): 305 | # loop through the new datasets and create a GraphDataset for each modality 306 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 307 | 308 | from integrao.IntegrAO_unsupervised import IntegrAO 309 | model = IntegrAO(self.feature_dims, self.hidden_channels, self.embedding_dims).to(device) 310 | model = self._load_pre_trained_weights(model, model_path, device) 311 | 312 | 313 | # explain the model 314 | from captum.attr import IntegratedGradients 315 | 316 | # It takes as input the variable node features for one domain, 317 | # while the remaining features and edge indices remain fixed. 
318 | def custom_forward(x, static_x_dict, edge_index_dict, domain): 319 | x_dict = static_x_dict.copy() 320 | x_dict[domain] = x 321 | 322 | out_dict = model(x_dict, edge_index_dict) 323 | 324 | return out_dict[domain].sum(dim=1) # IG requires a scalar output, so we sum over the embedding dimensions 325 | 326 | # prepare the data 327 | x_dict = {} 328 | edge_index_dict = {} 329 | for i, modal in enumerate(new_datasets): 330 | model_name = modalities_names[i] 331 | modal_index = self.modalities_name_list.index(model_name) 332 | 333 | dataset = GraphDataset( 334 | self.neighbor_size, 335 | modal.values, 336 | self.fused_networks[modal_index].values, 337 | transform=T.ToDevice(device), 338 | ) 339 | modal_dg = dataset[0] 340 | 341 | x_dict[modal_index] = modal_dg.x 342 | edge_index_dict[modal_index] = modal_dg.edge_index 343 | 344 | # Loop over each domain (modality) 345 | # --------------------------------------------------------- 346 | feat_importances = {} 347 | for domain in x_dict: 348 | x_input = x_dict[domain] # The variable input for the current domain. 349 | static_x = {k: x_dict[k] for k in x_dict} 350 | 351 | ig = IntegratedGradients(custom_forward) 352 | 353 | attributions, delta = ig.attribute( 354 | inputs=x_input, 355 | additional_forward_args=(static_x, edge_index_dict, domain), 356 | return_convergence_delta=True 357 | ) 358 | 359 | if domain not in feat_importances: 360 | feat_importances[domain] = [] 361 | feat_importances[domain].append(attributions.detach().cpu().numpy()) 362 | 363 | 364 | df_list = [] 365 | for domain in feat_importances: 366 | 367 | # Concatenate along the first axis (nodes). 368 | feat_importances[domain] = np.concatenate(feat_importances[domain], axis=0) 369 | num_feats = feat_importances[domain].shape[1] 370 | # Create a DataFrame; here columns are named feat_0, feat_1, etc. 
371 | df = pd.DataFrame(feat_importances[domain], columns=[f'feat_{i}' for i in range(num_feats)]) 372 | df_list.append(df) 373 | 374 | # save the feature importance 375 | csv_path = os.path.join(result_dir, f'{modalities_names[domain]}_feat_importance.csv') 376 | df.to_csv(csv_path, index=False) 377 | print(df.shape) 378 | 379 | print(f"Saved feature importances for domain {modalities_names[domain]} to {csv_path}") 380 | 381 | return df_list 382 | 383 | 384 | def inference_supervised(self, model_path, new_datasets, modalities_names): 385 | # loop through the new datasets and create a GraphDataset for each modality 386 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 387 | 388 | from integrao.IntegrAO_supervised import IntegrAO 389 | model = IntegrAO(self.feature_dims, self.hidden_channels, self.embedding_dims, num_classes=self.num_classes).to(device) 390 | model = self._load_pre_trained_weights(model, model_path, device) 391 | 392 | x_dict = {} 393 | edge_index_dict = {} 394 | for i, modal in enumerate(new_datasets): 395 | # find the index of the modal in the self.modalities_name_list 396 | model_name = modalities_names[i] 397 | modal_index = self.modalities_name_list.index(model_name) 398 | 399 | dataset = GraphDataset( 400 | self.neighbor_size, 401 | modal.values, 402 | self.fused_networks[modal_index].values, 403 | transform=T.ToDevice(device), 404 | ) 405 | modal_dg = dataset[0] 406 | 407 | x_dict[modal_index] = modal_dg.x 408 | edge_index_dict[modal_index] = modal_dg.edge_index 409 | 410 | # Now to do the inference 411 | final_embeddings, _, preds, id_list = model( 412 | x_dict, edge_index_dict, self.dict_original_order 413 | ) 414 | 415 | preds = F.softmax(preds, dim=1) 416 | preds = preds.detach().cpu().numpy() 417 | preds = np.argmax(preds, axis=1) 418 | 419 | return preds 420 | 421 | 422 | def interpret_supervised(self, model_path, result_dir, new_datasets, modalities_names): 423 | # loop through the new datasets and create a GraphDataset for each modality 424 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 425 | 426 | from integrao.IntegrAO_supervised import IntegrAO 427 | model = IntegrAO(self.feature_dims, self.hidden_channels, self.embedding_dims, num_classes=self.num_classes).to(device) 428 | model = self._load_pre_trained_weights(model, model_path, device) 429 | 430 | # explain the model 431 | from captum.attr import IntegratedGradients 432 | 433 | # It takes variable node features (x) for a given domain, 434 | # while keeping the rest of the inputs (static_x_dict, edge_index_dict, and domain_sample_ids) fixed. 435 | def custom_forward(x, static_x_dict, edge_index_dict, domain, domain_sample_ids): 436 | x_dict = static_x_dict.copy() 437 | x_dict[domain] = x 438 | 439 | _, _, output, _ = model(x_dict, edge_index_dict, domain_sample_ids) 440 | 441 | # Aggregate the output per sample to a scalar. 442 | # Here we sum over the class dimension (dim=1); adjust this if you need a different reduction, for example attribution for a single class. 443 | return output.sum(dim=1) 444 | 445 | 446 | # Prepare the data dictionaries for node features and edge indices. 
447 | x_dict = {} 448 | edge_index_dict = {} 449 | for i, modal in enumerate(new_datasets): 450 | # find the index of the modal in the self.modalities_name_list 451 | model_name = modalities_names[i] 452 | modal_index = self.modalities_name_list.index(model_name) 453 | 454 | dataset = GraphDataset( 455 | self.neighbor_size, 456 | modal.values, 457 | self.fused_networks[modal_index].values, 458 | transform=T.ToDevice(device), 459 | ) 460 | modal_dg = dataset[0] 461 | 462 | x_dict[modal_index] = modal_dg.x 463 | edge_index_dict[modal_index] = modal_dg.edge_index 464 | 465 | # Compute feature importances using IntegratedGradients. 466 | feat_importances = {} 467 | for domain in x_dict: 468 | x_input = x_dict[domain] 469 | static_x = {k: x_dict[k] for k in x_dict} 470 | 471 | ig = IntegratedGradients(custom_forward) 472 | 473 | attributions, delta = ig.attribute( 474 | inputs=x_input, 475 | additional_forward_args=(static_x, edge_index_dict, domain, self.dict_original_order), 476 | return_convergence_delta=True 477 | ) 478 | 479 | if domain not in feat_importances: 480 | feat_importances[domain] = [] 481 | feat_importances[domain].append(attributions.detach().cpu().numpy()) 482 | 483 | 484 | df_list = [] 485 | for domain in feat_importances: 486 | 487 | # Concatenate along the first axis (nodes). 488 | feat_importances[domain] = np.concatenate(feat_importances[domain], axis=0) 489 | num_feats = feat_importances[domain].shape[1] 490 | # Create a DataFrame; here columns are named feat_0, feat_1, etc. 491 | df = pd.DataFrame(feat_importances[domain], columns=[f'feat_{i}' for i in range(num_feats)]) 492 | df_list.append(df) 493 | 494 | # save the feature importance 495 | csv_path = os.path.join(result_dir, f'{modalities_names[domain]}_feat_importance.csv') 496 | df.to_csv(csv_path, index=False) 497 | print(df.shape) 498 | 499 | print(f"Saved feature importances for domain {modalities_names[domain]} to {csv_path}") 500 | 501 | return df_list -------------------------------------------------------------------------------- /integrao/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import time 4 | from snf.compute import _flatten, _B0_normalized, _find_dominate_set 5 | from sklearn.utils.validation import ( 6 | check_array, 7 | check_symmetric, 8 | check_consistent_length, 9 | ) 10 | 11 | def dist2(X, C): 12 | """ 13 | Description: Computes the squared Euclidean distances between all pairs of the given data points 14 | 15 | Usage: dist2(X, C) 16 | X: A data matrix where each row is a different data point 17 | C: A data matrix where each row is a different data point. If this matrix is the same as X, 18 | pairwise distances for all data points in X are computed. 19 | 20 | Return: Returns an N x M matrix where N is the number of rows in X and M is the number of rows in C. 21 | 22 | Author: Dr. 
48 | 
49 | def _find_dominate_set_relative(W, K=20):
50 |     """
51 |     Retains `K` strongest edges for each sample in `W`
52 |     Parameters
53 |     ----------
54 |     W : (N, N) array_like
55 |         Input data
56 |     K : (0, N) int, optional
57 |         Number of neighbors to retain. Default: 20
58 |     Returns
59 |     -------
60 |     Wk : (N, N) np.ndarray
61 |         Thresholded version of `W`
62 |     """
63 | 
64 |     # let's not modify W in place
65 |     Wk = W.copy()
66 | 
67 |     # determine the percentile cutoff that keeps only `K` edges for each sample
68 |     # and remove everything below it
69 |     cutoff = 100 - (100 * (K / len(W)))
70 |     Wk[Wk < np.percentile(Wk, cutoff, axis=1, keepdims=True)] = 0
71 | 
72 |     # normalize by strength of remaining edges
73 |     Wk = Wk / np.nansum(Wk, axis=1, keepdims=True)
74 | 
75 |     Ws = Wk + np.transpose(Wk)
76 | 
77 |     return Ws
78 | 
79 | 
80 | 
81 | def _stable_normalized(W):
82 |     """
83 |     Row-normalizes `W` so each row sums to 1, pinning the diagonal at 0.5
84 | 
85 |     Parameters
86 |     ----------
87 |     W : (N, N) array_like
88 |         Similarity array from SNF
89 | 
90 |     Returns
91 |     -------
92 |     W : (N, N) np.ndarray
93 |         Stable-normalized similarity array
94 |     """
95 | 
96 |     # normalize rows (excluding the diagonal) and symmetrize `W`
97 |     rowSum = np.sum(W, 1) - np.diag(W)
98 |     rowSum[rowSum == 0] = 1
99 | 
100 |     W = W / (2 * rowSum)
101 |     np.fill_diagonal(W, 0.5)
102 |     W = check_symmetric(W, raise_warning=False)
103 | 
104 |     return W
105 | 
106 | 
107 | def _stable_normalized_pd(W):
108 |     """
109 |     Row-normalizes pandas dataframe `W` so each row sums to 1, pinning the diagonal at 0.5
110 | 
111 |     Parameters
112 |     ----------
113 |     W : (N, N) array_like
114 |         Similarity array from SNF
115 | 
116 |     Returns
117 |     -------
118 |     W : (N, N) pd.DataFrame
119 |         Stable-normalized similarity array
120 |     """
121 | 
122 |     # normalize rows (excluding the diagonal) and symmetrize `W`
123 |     rowSum = np.sum(W, 1) - np.diag(W)
124 |     rowSum[rowSum == 0] = 1
125 | 
126 |     W = W / (2 * rowSum)
127 | 
128 |     W_np = W.values
129 |     np.fill_diagonal(W_np, 0.5)
130 |     W = pd.DataFrame(W_np, index=W.index, columns=W.columns)
131 | 
132 |     W = check_symmetric(W, raise_warning=False)
133 | 
134 |     return W
135 | 
136 | 
137 | def _scaling_normalized_pd(W, ratio):
138 |     """
139 |     Row-normalizes pandas dataframe `W`, scaling the off-diagonal mass by `ratio`
140 | 
141 |     Parameters
142 |     ----------
143 |     W : (N, N) array_like
144 |         Similarity array from SNF; `ratio` gives the fraction of off-diagonal mass to keep (the diagonal becomes 1 - 0.5 * ratio)
145 | 
146 |     Returns
147 |     -------
148 |     W : (N, N) pd.DataFrame
149 |         Scaling-normalized similarity array
150 |     """
151 | 
152 |     # scale the off-diagonal rows and symmetrize `W`
153 |     rowSum = np.sum(W, 1) - np.diag(W)
154 |     rowSum[rowSum == 0] = 1
155 | 
156 |     W = (W / rowSum) * 0.5 * ratio
157 | 
158 |     W_np = W.values
159 |     np.fill_diagonal(W_np, 1-0.5*ratio)
160 |     W = pd.DataFrame(W_np, index=W.index, columns=W.columns)
161 | 
162 |     W = check_symmetric(W, raise_warning=False)
163 | 
164 |     return W
165 | 
166 | 
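# Note on the normalization family above (illustrative): after
# _stable_normalized(W) every row sums to 1 with the diagonal pinned at 0.5,
# i.e. each sample keeps half of its probability mass on itself and spreads
# the rest over its neighbours. _scaling_normalized_pd generalizes this with
# `ratio`: a cropped view that covers only part of the cohort contributes
# proportionally less off-diagonal mass (diagonal = 1 - 0.5 * ratio).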
167 | def handle_zeros_in_scale(scale, copy=True):
168 |     """Makes sure that whenever scale is zero, we handle it correctly.
169 |     This happens in most scalers when we have constant features.
170 |     Adapted from sklearn.preprocessing.data"""
171 | 
172 |     # if we are fitting on 1D arrays, scale might be a scalar
173 |     if np.isscalar(scale):
174 |         if scale == 0.0:
175 |             scale = 1.0
176 |         return scale
177 |     elif isinstance(scale, np.ndarray):
178 |         if copy:
179 |             # New array to avoid side-effects
180 |             scale = scale.copy()
181 |         scale[scale == 0.0] = 1.0
182 |         return scale
183 | 
184 | 
185 | def integrao_fuse(aff, dicts_common, dicts_unique, original_order, neighbor_size=20, fusing_iteration=20, normalization_factor=1.0):
186 |     """
187 |     Performs Patient Graph Fusion on `aff` matrices
188 | 
189 |     Parameters
190 |     ----------
191 |     aff : (N, N) pandas dataframe
192 |         Input similarity arrays; all arrays should be square but need not be of equal size.
193 | 
194 |     dicts_common: dictionaries, required
195 |         Dictionaries for getting common samples from different views
196 |         Example: dicts_common[(0, 1)] == dicts_common[(1, 0)], meaning the common patients between views 1 & 2
197 | 
198 |     dicts_unique: dictionaries, required
199 |         Dictionaries for getting unique samples for different views
200 |         Example: dicts_unique[(0, 1)], meaning the unique samples from view 1 that are not in view 2
201 |         dicts_unique[(1, 0)], meaning the unique samples from view 2 that are not in view 1
202 | 
203 |     original_order: lists, required
204 |         The original order of each view
205 | 
206 |     neighbor_size : (0, N) int, optional
207 |         Number of neighbors to retain when thresholding each network. Default: 20
208 | 
209 |     fusing_iteration : int, optional
210 |         Number of iterations to perform information swapping. Default: 20
211 | 
212 |     normalization_factor : (0, 1) float, optional
213 |         Hyperparameter normalization factor for scaling. Default: 1.0
214 | 
215 |     Returns
216 |     -------
217 |     W : list of (N, N) pandas dataframes
218 |         Fused similarity networks of the input arrays
219 |     """
220 | 
221 |     print("Start applying diffusion!")
222 | 
223 |     start_time = time.time()
224 | 
225 |     newW = [0] * len(aff)
226 | 
227 |     # First, normalize the individual networks to avoid scale problems (works on pandas dataframes)
228 |     for n, mat in enumerate(aff):
229 | 
230 |         # normalize affinity matrix based on strength of edges
231 |         # mat = mat / np.nansum(mat, axis=1, keepdims=True)
232 |         aff[n] = _stable_normalized_pd(mat)
233 |         # aff[n] = check_symmetric(mat, raise_warning=False)
234 | 
235 |         # apply KNN threshold to normalized affinity matrix
236 |         # We need to crop the intersecting samples from newW matrices
237 |         neighbor_size = min(int(neighbor_size), mat.shape[0])
238 |         newW[n] = _find_dominate_set(aff[n], neighbor_size)
239 | 
240 |     # If there is only one view, return it
241 |     if len(aff) == 1:
242 |         print("Only one view, return it directly")
243 |         return newW
244 | 
245 |     for iteration in range(fusing_iteration):
246 | 
247 |         # Make zeroed copies of the aff matrices for this iteration;
248 |         # aff[n] will be rebuilt as the average of all the diffused matrices
249 |         # (copy aff[n], then set it to 0)
250 |         aff_next = []
251 |         for k in range(len(aff)):
252 |             aff_temp = aff[k].copy()
253 |             for col in aff_temp.columns:
254 |                 aff_temp[col].values[:] = 0
255 |             aff_next.append(aff_temp)
256 | 
257 |         for n, mat in enumerate(aff):
258 |             # thresholded network of the current view
259 |             nzW = newW[n]  # NOTE: this is a reference to newW[n], not a copy
260 | 
261 |             for j, mat_tofuse in enumerate(aff):
262 |                 if n == j:
263 |                     continue
264 | 
265 |                 # reorder mat_tofuse so that the common samples come first
266 |                 mat_tofuse = mat_tofuse.reindex(
267 |                     (sorted(dicts_common[(j, n)]) + sorted(dicts_unique[(j, n)])),
268 |                     axis=1,
269 |                 )
270 |                 mat_tofuse = mat_tofuse.reindex(
271 |                     (sorted(dicts_common[(j, n)]) + sorted(dicts_unique[(j, n)])),
272 |                     axis=0,
273 |                 )
274 | 
275 |                 # Next, crop mat_tofuse down to the common samples
276 |                 num_common = len(dicts_common[(n, j)])
277 |                 to_drop_mat = mat_tofuse.columns[
278 |                     num_common : mat_tofuse.shape[1]
279 |                 ].values.tolist()
280 |                 mat_tofuse_crop = mat_tofuse.drop(to_drop_mat, axis=1)
281 |                 mat_tofuse_crop = mat_tofuse_crop.drop(to_drop_mat, axis=0)
282 | 
283 |                 # Next, add the similarities from the view being fused onto an identity matrix over the current view
284 |                 nzW_identity = pd.DataFrame(
285 |                     data=np.identity(nzW.shape[0]),
286 |                     index=original_order[n],
287 |                     columns=original_order[n],
288 |                 )
289 | 
290 | 
291 |                 mat_tofuse_union = nzW_identity + mat_tofuse_crop
292 |                 mat_tofuse_union.fillna(0.0, inplace=True)
293 |                 mat_tofuse_union = _scaling_normalized_pd(mat_tofuse_union, ratio=mat_tofuse_crop.shape[0]/nzW_identity.shape[0])
294 |                 mat_tofuse_union = check_symmetric(mat_tofuse_union, raise_warning=False)
295 |                 mat_tofuse_union = mat_tofuse_union.reindex(original_order[n], axis=1)
296 |                 mat_tofuse_union = mat_tofuse_union.reindex(original_order[n], axis=0)
297 | 
298 |                 # Now we are ready to do the diffusion
299 |                 nzW_T = np.transpose(nzW)
300 |                 aff0_temp = nzW.dot(
301 |                     mat_tofuse_union.dot(nzW_T)
302 |                 )  # use .dot() so pandas aligns on labels (np.matmul would not)
303 | 
304 | 
305 |                 #################################################
306 |                 # Experimental: an exponential weighting mechanism; it did not help empirically, so it stays disabled
307 |                 # num_com = mat_tofuse_crop.shape[0] / aff[n].shape[0]
308 |                 # alpha = pow(2, num_com) - 1
309 |                 # aff0_temp = alpha * aff0_temp + (1-alpha) * aff[n]
310 | 
311 |                 #aff0_temp = _B0_normalized(aff0_temp, alpha=normalization_factor)
312 |                 aff0_temp = _stable_normalized_pd(aff0_temp)
313 |                 # aff0_temp = check_symmetric(aff0_temp, raise_warning=False)
314 | 
315 |                 aff_next[n] = np.add(aff0_temp, aff_next[n])
316 | 
317 |             aff_next[n] = np.divide(aff_next[n], len(aff) - 1)
318 |             # aff_next[n] = _stable_normalized_pd(aff_next[n])
319 | 
320 |         # put the values from aff_next back into aff
321 |         for k in range(len(aff)):
322 |             aff[k] = aff_next[k]
323 | 
324 |     for n, mat in enumerate(aff):
325 |         aff[n] = _stable_normalized_pd(mat)
326 |         # aff[n] = check_symmetric(mat, raise_warning=False)
327 | 
328 |     end_time = time.time()
329 |     print("Diffusion ends! Times: {}s".format(end_time - start_time))
330 |     return aff
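331 | 
332 | 
333 | if __name__ == "__main__":
334 |     # Minimal synthetic sketch (an illustration, assuming snfpy and the rest of
335 |     # the package's dependencies are installed): fuse two partially overlapping
336 |     # views with integrao_fuse. The simple RBF affinities below stand in for
337 |     # the package's usual affinity construction; `rbf_affinity` is an ad hoc helper.
338 |     from integrao.util import data_indexing
339 | 
340 |     rng = np.random.default_rng(0)
341 |     samples = ["s%d" % i for i in range(60)]
342 |     view1 = pd.DataFrame(rng.normal(size=(50, 30)), index=samples[:50])
343 |     view2 = pd.DataFrame(rng.normal(size=(50, 20)), index=samples[10:])
344 | 
345 |     dicts_common, _, _, dicts_unique, original_order, _ = data_indexing([view1, view2])
346 | 
347 |     def rbf_affinity(df):
348 |         # squared distances -> RBF kernel, kept as a labeled DataFrame
349 |         d = dist2(df.values, df.values)
350 |         w = np.exp(-d / (d.mean() + 1e-12))
351 |         return pd.DataFrame(w, index=df.index, columns=df.index)
352 | 
353 |     fused = integrao_fuse(
354 |         [rbf_affinity(view1), rbf_affinity(view2)],
355 |         dicts_common=dicts_common,
356 |         dicts_unique=dicts_unique,
357 |         original_order=original_order,
358 |         neighbor_size=10,
359 |         fusing_iteration=5,
360 |     )
361 |     print([w.shape for w in fused])  # one fused network per view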
--------------------------------------------------------------------------------
/integrao/supervised_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.autograd as autograd
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 | 
8 | cudnn.benchmark = True
9 | 
10 | import numpy as np
11 | import pandas as pd
12 | import networkx as nx
13 | import time
14 | import os
15 | from snf.compute import _find_dominate_set
16 | 
17 | from integrao.IntegrAO_supervised import IntegrAO
18 | from integrao.dataset import GraphDataset
19 | import torch_geometric.transforms as T
20 | 
21 | def tsne_loss(P, activations):
22 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23 |     n = activations.size(0)
24 |     alpha = 1
25 |     eps = 1e-12
26 |     sum_act = torch.sum(torch.pow(activations, 2), 1)
27 |     Q = (
28 |         sum_act
29 |         + sum_act.view([-1, 1])
30 |         - 2 * torch.matmul(activations, torch.transpose(activations, 0, 1))
31 |     )
32 |     Q = Q / alpha
33 |     Q = torch.pow(1 + Q, -(alpha + 1) / 2)
34 |     Q = Q * autograd.Variable(1 - torch.eye(n), requires_grad=False).to(device)
35 |     Q = Q / torch.sum(Q)
36 |     Q = torch.clamp(Q, min=eps)
37 |     C = torch.log((P + eps) / (Q + eps))
38 |     C = torch.sum(P * C)
39 |     return C
40 | 
41 | 
42 | def adjust_learning_rate(optimizer, epoch):
43 |     """Sets the learning rate to the initial LR decayed by 0.1 every 100 epochs (floored at 1e-3)"""
44 |     lr = 0.1 * (0.1 ** (epoch // 100))
45 |     lr = max(lr, 1e-3)
46 |     for param_group in optimizer.param_groups:
47 |         param_group["lr"] = lr
48 | 
49 | 
50 | def init_model(net, device, restore):
51 |     if restore is not None and os.path.exists(restore):
52 |         net.load_state_dict(torch.load(restore))
53 |         net.restored = True
54 |         print("Restore model from: {}".format(os.path.abspath(restore)))
55 | 
56 |     else:
57 |         pass
58 | 
59 |     if torch.cuda.is_available():
60 |         cudnn.benchmark = True
61 |         net.to(device)
62 | 
63 |     return net
64 | 
65 | 
66 | def P_preprocess(P):
67 |     # Make sure P-values are set properly
68 |     np.fill_diagonal(P, 0)  # set diagonal to zero
69 |     P = P + np.transpose(P)  # symmetrize P-values
70 |     P = P / np.sum(P)  # make sure P-values sum to one
71 |     # P = P * 4.0  # early exaggeration
72 |     P = np.maximum(P, 1e-12)
73 |     return P
74 | 
75 | def _load_pre_trained_weights(model, model_path, device):
76 |     try:
77 |         state_dict = torch.load(
78 |             os.path.join(model_path, "model.pth"), map_location=device
79 |         )
80 |         # model.load_state_dict(state_dict)
81 |         model.load_my_state_dict(state_dict)
82 |         print("Loaded pre-trained model with success.")
83 |     except FileNotFoundError:
84 |         print("Pre-trained weights not found. Training from scratch.")
85 | 
86 |     return model
87 | 
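# For reference (illustrative): tsne_loss above implements the t-SNE objective
# KL(P || Q) = sum_ij P_ij * log(P_ij / Q_ij), where
# Q_ij ~ (1 + ||y_i - y_j||^2) ** -1 is a Student-t kernel on the embeddings
# (alpha = 1), with the diagonal masked out and Q renormalized to sum to 1.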
88 | def tsne_p_deep_classification(dicts_commonIndex, dict_sampleToIndexs, dict_original_order, data, clf_labels, model_path=None, P=np.array([]), neighbor_size=20, embedding_dims=50, alighment_epochs=1000, num_classes=2):
89 |     """
90 |     Runs t-SNE on the dataset in the NxN matrix P to extract embedding vectors
91 |     to `embedding_dims` dimensions.
92 |     """
93 | 
94 |     # Check inputs
95 |     if isinstance(embedding_dims, float):
96 |         print("Error: embedding_dims should be an integer.")
97 |         return -1
98 |     if round(embedding_dims) != embedding_dims:
99 |         print("Error: number of dimensions should be an integer.")
100 |         return -1
101 | 
102 |     print("Starting supervised finetuning!")
103 |     start_time = time.time()
104 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
105 | 
106 |     hidden_channels = 128  # TODO: move to a yaml config file
107 |     dataset_num = len(P)
108 |     feature_dims = []
109 |     transform = T.Compose([
110 |         T.ToDevice(device),
111 |     ])
112 | 
113 |     # clf_labels is a dataframe
114 |     labels = torch.from_numpy(clf_labels.values.flatten()).long().to(device)
115 | 
116 |     x_dict = {}
117 |     edge_index_dict = {}
118 |     for i in range(dataset_num):
119 |         # preprocess the inputs into PyG graph format
120 |         dataset = GraphDataset(neighbor_size, data[i], P[i], transform=transform)
121 |         x_dict[i] = dataset[0].x
122 |         edge_index_dict[i] = dataset[0].edge_index
123 | 
124 |         feature_dims.append(np.shape(data[i])[1])
125 |         print("Dataset {}:".format(i), np.shape(data[i]))
126 | 
127 |         # preprocess similarity matrix for t-sne kl loss
128 |         P[i] = P_preprocess(P[i])
129 |         P[i] = torch.from_numpy(P[i]).float().to(device)
130 | 
131 | 
132 |     net = IntegrAO(feature_dims, hidden_channels, embedding_dims, num_classes=num_classes).to(device)  # should load pre-trained model
133 | 
134 |     if model_path is not None:
135 |         Project_GNN = _load_pre_trained_weights(net, model_path, device)
136 |     else:
137 |         Project_GNN = init_model(net, device, restore=None)
138 |     Project_GNN.train()
139 | 
140 |     optimizer = torch.optim.Adam(Project_GNN.parameters(), lr=1e-1)
141 |     c_mse = nn.MSELoss()
142 |     c_cn = nn.CrossEntropyLoss()
143 | 
144 |     for epoch in range(alighment_epochs):
145 |         adjust_learning_rate(optimizer, epoch)
146 | 
147 |         loss = 0
148 |         embeddings = []
149 | 
150 |         kl_loss = np.array(0)
151 |         kl_loss = torch.from_numpy(kl_loss).to(device).float()
152 | 
153 |         # KL loss for each network
154 |         embeddings, _, pred, _ = Project_GNN(x_dict, edge_index_dict, dict_original_order)
155 |         embeddings = list(embeddings.values())
156 | 
157 |         for i, X_embedding in enumerate(embeddings):
158 |             kl_loss += tsne_loss(P[i], X_embedding)
159 | 
160 |         # pairwise alignment loss between each pair of networks
161 |         alignment_loss = np.array(0)
162 |         alignment_loss = torch.from_numpy(alignment_loss).to(device).float()
163 | 
164 |         for i in range(dataset_num - 1):
165 |             for j in range(i + 1, dataset_num):
166 |                 low_dim_set1 = embeddings[i][dicts_commonIndex[(i, j)]]
167 |                 low_dim_set2 = embeddings[j][dicts_commonIndex[(j, i)]]
168 |                 alignment_loss += c_mse(low_dim_set1, low_dim_set2)
169 | 
170 |         loss += kl_loss + alignment_loss
171 | 
172 |         # for the classification task, the model predicts on the averaged embeddings; add the classification loss
173 |         clf_loss = c_cn(pred, labels)
174 |         loss += clf_loss
175 | 
176 |         optimizer.zero_grad()
177 |         loss.backward()
178 |         optimizer.step()
179 | 
180 |         if (epoch) % 100 == 0:
181 |             print(
182 |                 "epoch {}: loss {}, kl_loss:{:4f}, align_loss:{:4f}, clf_loss:{:4f}".format(
183 |                     epoch, loss.data.item(), kl_loss.data.item(), alignment_loss.data.item(), clf_loss.data.item()
184 |                 )
185 |             )
186 |         # if epoch == 100:
187 |         #     for i in range(dataset_num):
188 |         #         P[i] = P[i] / 4.0
189 | 
190 |     # get the final embeddings and predictions for all samples
191 |     embeddings, X_embedding_avg, preds, _ = Project_GNN(x_dict, edge_index_dict, dict_original_order)
192 |     preds = preds.detach().cpu().numpy()
193 | 
194 |     # the averaged per-sample embeddings
195 |     final_embeddings = X_embedding_avg.detach().cpu().numpy()
196 | 
197 |     end_time = time.time()
198 |     print("Manifold alignment ends! Times: {}s".format(end_time - start_time))
199 | 
200 |     return final_embeddings, Project_GNN, preds
201 | 
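202 | 
203 | 
204 | if __name__ == "__main__":
205 |     # Tiny illustrative check (not a library entry point) of the P_preprocess
206 |     # contract: zero diagonal, symmetry, and probabilities summing to one.
207 |     demo = np.array([[0.0, 2.0, 1.0],
208 |                      [0.5, 0.0, 1.0],
209 |                      [1.0, 1.0, 0.0]])
210 |     P_demo = P_preprocess(demo.copy())
211 |     assert np.allclose(P_demo, P_demo.T)
212 |     assert np.isclose(P_demo.sum(), 1.0)
213 |     print(P_demo)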
--------------------------------------------------------------------------------
/integrao/unsupervised_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.autograd as autograd
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 | 
8 | cudnn.benchmark = True
9 | 
10 | from snf.compute import _find_dominate_set
11 | import numpy as np
12 | import networkx as nx
13 | import time
14 | import os
15 | 
16 | from integrao.IntegrAO_unsupervised import IntegrAO
17 | from integrao.dataset import GraphDataset
18 | import torch_geometric.transforms as T
19 | 
20 | 
21 | def tsne_loss(P, activations):
22 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23 |     n = activations.size(0)
24 |     alpha = 1
25 |     eps = 1e-12
26 |     sum_act = torch.sum(torch.pow(activations, 2), 1)
27 |     Q = (
28 |         sum_act
29 |         + sum_act.view([-1, 1])
30 |         - 2 * torch.matmul(activations, torch.transpose(activations, 0, 1))
31 |     )
32 |     Q = Q / alpha
33 |     Q = torch.pow(1 + Q, -(alpha + 1) / 2)
34 |     Q = Q * autograd.Variable(1 - torch.eye(n), requires_grad=False).to(device)
35 |     Q = Q / torch.sum(Q)
36 |     Q = torch.clamp(Q, min=eps)
37 |     C = torch.log((P + eps) / (Q + eps))
38 |     C = torch.sum(P * C)
39 |     return C
40 | 
41 | 
42 | def adjust_learning_rate(optimizer, epoch):
43 |     """Sets the learning rate to the initial LR decayed by 0.1 every 100 epochs (floored at 1e-3)"""
44 |     lr = 0.1 * (0.1 ** (epoch // 100))
45 |     lr = max(lr, 1e-3)
46 |     for param_group in optimizer.param_groups:
47 |         param_group["lr"] = lr
48 | 
49 | 
50 | def init_model(net, device, restore):
51 |     if restore is not None and os.path.exists(restore):
52 |         net.load_state_dict(torch.load(restore))
53 |         net.restored = True
54 |         print("Restore model from: {}".format(os.path.abspath(restore)))
55 | 
56 |     else:
57 |         pass
58 | 
59 |     if torch.cuda.is_available():
60 |         cudnn.benchmark = True
61 |         net.to(device)
62 | 
63 |     return net
64 | 
65 | 
66 | def P_preprocess(P):
67 |     # Make sure P-values are set properly
68 |     np.fill_diagonal(P, 0)  # set diagonal to zero
69 |     P = P + np.transpose(P)  # symmetrize P-values
70 |     P = P / np.sum(P)  # make sure P-values sum to one
71 |     P = P * 4.0  # early exaggeration
72 |     P = np.maximum(P, 1e-12)
73 |     return P
74 | 
75 | 
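# Early-exaggeration note (illustrative): unlike the supervised variant,
# P_preprocess here scales P by 4.0, and tsne_p_deep below divides it back out
# at epoch 100; inflating P early strengthens the attraction between similar
# samples so that clusters separate before fine convergence.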
76 | def tsne_p_deep(dicts_commonIndex, dict_sampleToIndexs, data, P=np.array([]), neighbor_size=20, embedding_dims=50, alighment_epochs=1000):
77 |     """
78 |     Runs t-SNE on the dataset in the NxN matrix P to extract embedding vectors
79 |     to `embedding_dims` dimensions.
80 |     """
81 | 
82 |     # Check inputs
83 |     if isinstance(embedding_dims, float):
84 |         print("Error: embedding_dims should be an integer.")
85 |         return -1
86 |     if round(embedding_dims) != embedding_dims:
87 |         print("Error: number of dimensions should be an integer.")
88 |         return -1
89 | 
90 |     print("Starting unsupervised embedding extraction!")
91 |     start_time = time.time()
92 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
93 | 
94 |     hidden_channels = 128  # TODO: move to a yaml config file
95 |     dataset_num = len(P)
96 |     feature_dims = []
97 |     transform = T.Compose([
98 |         T.ToDevice(device),
99 |     ])
100 | 
101 |     x_dict = {}
102 |     edge_index_dict = {}
103 |     for i in range(dataset_num):
104 |         # preprocess the inputs into PyG graph format
105 |         dataset = GraphDataset(neighbor_size, data[i], P[i], transform=transform)
106 |         x_dict[i] = dataset[0].x
107 |         edge_index_dict[i] = dataset[0].edge_index
108 | 
109 |         feature_dims.append(np.shape(data[i])[1])
110 |         print("Dataset {}:".format(i), np.shape(data[i]))
111 | 
112 |         # preprocess similarity matrix for t-sne loss
113 |         P[i] = P_preprocess(P[i])
114 |         P[i] = torch.from_numpy(P[i]).float().to(device)
115 | 
116 |     net = IntegrAO(feature_dims, hidden_channels, embedding_dims)
117 |     Project_GNN = init_model(net, device, restore=None)
118 |     Project_GNN.train()
119 | 
120 |     optimizer = torch.optim.Adam(Project_GNN.parameters(), lr=1e-1)
121 |     c_mse = nn.MSELoss()
122 | 
123 | 
124 |     for epoch in range(alighment_epochs):
125 |         adjust_learning_rate(optimizer, epoch)
126 | 
127 |         loss = 0
128 |         embeddings = []
129 | 
130 |         kl_loss = np.array(0)
131 |         kl_loss = torch.from_numpy(kl_loss).to(device).float()
132 | 
133 |         embeddings = Project_GNN(x_dict, edge_index_dict)
134 |         embeddings = list(embeddings.values())
135 |         for i, X_embedding in enumerate(embeddings):
136 |             kl_loss += tsne_loss(P[i], X_embedding)
137 | 
138 |         # pairwise alignment loss between each pair of networks
139 |         alignment_loss = np.array(0)
140 |         alignment_loss = torch.from_numpy(alignment_loss).to(device).float()
141 | 
142 |         for i in range(dataset_num - 1):
143 |             for j in range(i + 1, dataset_num):
144 |                 low_dim_set1 = embeddings[i][dicts_commonIndex[(i, j)]]
145 |                 low_dim_set2 = embeddings[j][dicts_commonIndex[(j, i)]]
146 |                 alignment_loss += c_mse(low_dim_set1, low_dim_set2)
147 | 
148 |         loss += kl_loss + alignment_loss
149 | 
150 |         optimizer.zero_grad()
151 |         loss.backward()
152 |         optimizer.step()
153 | 
154 |         if (epoch) % 100 == 0:
155 |             print(
156 |                 "epoch {}: loss {}, align_loss:{:4f}".format(
157 |                     epoch, loss.data.item(), alignment_loss.data.item()
158 |                 )
159 |             )
160 |         if epoch == 100:  # end of early exaggeration: scale P back down (see P_preprocess)
161 |             for i in range(dataset_num):
162 |                 P[i] = P[i] / 4.0
163 | 
164 |     # get the final embeddings for all samples
165 |     embeddings = Project_GNN(x_dict, edge_index_dict)
166 |     for i in range(dataset_num):
167 |         embeddings[i] = embeddings[i].detach().cpu().numpy()
168 | 
169 |     # compute the average embedding for each sample across the views it appears in
170 |     final_embedding = np.array([]).reshape(0, embedding_dims)
171 |     for key in dict_sampleToIndexs:
172 |         sample_embedding = np.zeros((1, embedding_dims))
173 | 
174 |         for (dataset, index) in dict_sampleToIndexs[key]:
175 |             sample_embedding += embeddings[dataset][index]
176 |         sample_embedding /= len(dict_sampleToIndexs[key])
177 | 
178 |         final_embedding = np.concatenate((final_embedding, sample_embedding), axis=0)
179 | 
180 |     end_time = time.time()
181 |     print("Manifold alignment ends! Times: {}s".format(end_time - start_time))
182 | 
183 |     return final_embedding, Project_GNN
184 | 
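185 | 
186 | 
187 | if __name__ == "__main__":
188 |     # Minimal synthetic sketch (an illustration, assuming the package's
189 |     # dependencies are installed): two random views with a 40-sample overlap.
190 |     # The RBF affinities stand in for the fused networks that
191 |     # integrao.main.integrao_fuse would normally provide; `rbf_affinity`
192 |     # is an ad hoc helper, not part of the library.
193 |     import pandas as pd
194 |     from integrao.util import data_indexing
195 | 
196 |     rng = np.random.default_rng(0)
197 |     samples = ["s%d" % i for i in range(60)]
198 |     view1 = pd.DataFrame(rng.normal(size=(50, 30)), index=samples[:50])
199 |     view2 = pd.DataFrame(rng.normal(size=(50, 20)), index=samples[10:])
200 | 
201 |     _, dicts_commonIndex, dict_sampleToIndexs, _, _, _ = data_indexing([view1, view2])
202 | 
203 |     def rbf_affinity(x):
204 |         # squared pairwise distances -> RBF kernel
205 |         d = np.square(x[:, None, :] - x[None, :, :]).sum(-1)
206 |         return np.exp(-d / (d.mean() + 1e-12))
207 | 
208 |     emb, model = tsne_p_deep(
209 |         dicts_commonIndex,
210 |         dict_sampleToIndexs,
211 |         [view1.values, view2.values],
212 |         P=[rbf_affinity(view1.values), rbf_affinity(view2.values)],
213 |         neighbor_size=10,
214 |         embedding_dims=8,
215 |         alighment_epochs=200,
216 |     )
217 |     print(emb.shape)  # (60, 8): one averaged embedding per unique sample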
--------------------------------------------------------------------------------
/integrao/util.py:
--------------------------------------------------------------------------------
1 | """
2 | The functions in this file are used to get the common and unique samples between any two views.
3 | 
4 | Author: Shihao Ma at WangLab, Dec. 14, 2020
5 | """
6 | 
7 | import numpy as np
8 | from collections import defaultdict
9 | 
10 | def data_indexing(matrices):
11 |     """
12 |     Performs data indexing on input expression matrices
13 | 
14 |     Parameters
15 |     ----------
16 |     matrices : (M, N) array_like
17 |         Input expression matrices, with samples in rows and genes/features in columns.
18 | 
19 |     Returns
20 |     -------
21 |     dict_commonSample: dictionaries giving the common samples between each pair of views
22 |     dict_commonSampleIndex / dict_sampleToIndexs: row indices of those samples within each view
23 |     dict_uniqueSample: dictionaries giving the samples unique to one view of each pair
24 |     original_order / dict_original_order: the original order of samples for each view
25 |     """
26 | 
27 |     if len(matrices) < 1:
28 |         print("Input nothing, return nothing")
29 |         return None
30 | 
31 |     print("Start indexing input expression matrices!")
32 | 
33 |     original_order = [0] * (len(matrices))
34 |     dict_original_order = {}
35 |     dict_commonSample = {}
36 |     dict_uniqueSample = {}
37 |     dict_commonSampleIndex = {}
38 |     dict_sampleToIndexs = defaultdict(list)
39 | 
40 |     for i in range(0, len(matrices)):
41 |         original_order[i] = list(matrices[i].index)
42 |         dict_original_order[i] = original_order[i]
43 |         for sample in original_order[i]:
44 |             dict_sampleToIndexs[sample].append(
45 |                 (i, np.argwhere(matrices[i].index == sample).squeeze().tolist())
46 |             )
47 | 
48 |     for i in range(0, len(original_order)):
49 |         for j in range(i + 1, len(original_order)):
50 |             commonList = list(set(original_order[i]).intersection(original_order[j]))
51 |             print("Common sample between view{} and view{}: {}".format(i, j, len(commonList)))
52 |             dict_commonSample.update(
53 |                 dict_commonSample.fromkeys([(i, j), (j, i)], commonList)
54 |             )
55 |             dict_commonSampleIndex[(i, j)] = [
56 |                 np.argwhere(matrices[i].index == x).squeeze().tolist()
57 |                 for x in commonList
58 |             ]
59 |             dict_commonSampleIndex[(j, i)] = [
60 |                 np.argwhere(matrices[j].index == x).squeeze().tolist()
61 |                 for x in commonList
62 |             ]
63 | 
64 |             dict_uniqueSample[(i, j)] = list(
65 |                 set(original_order[i]).symmetric_difference(commonList)
66 |             )
67 |             dict_uniqueSample[(j, i)] = list(
68 |                 set(original_order[j]).symmetric_difference(commonList)
69 |             )
70 | 
71 |     return (
72 |         dict_commonSample,
73 |         dict_commonSampleIndex,
74 |         dict_sampleToIndexs,
75 |         dict_uniqueSample,
76 |         original_order,
77 |         dict_original_order,
78 |     )
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 | 
5 | [project]
6 | name = "integrao"
7 | version = "0.1.3"
8 | dependencies = [
9 |     "snfpy",
10 |     "networkx==3.0",
11 |     "numpy>=1.24.1",
12 |     "pandas>=1.3.5",
13 |     "scikit-learn==1.3.2",
14 |     "scipy==1.11.4",
15 |     "umap-learn>=0.5.5",
16 |     "torch-geometric>=2.3.0"
17 | ]
18 | description = "The Python implementation of IntegrAO."
19 | authors = [ 20 | { name="Shihao Ma", email="rex.ma@mail.utoronto.ca" }, 21 | ] 22 | readme = "README.md" 23 | requires-python = ">=3.9" 24 | classifiers = [ 25 | "Programming Language :: Python :: 3", 26 | "License :: OSI Approved :: MIT License", 27 | "Operating System :: OS Independent", 28 | ] 29 | 30 | [project.urls] 31 | Homepage = "https://github.com/bowang-lab/IntegrAO" 32 | Issues = "https://github.com/bowang-lab/IntegrAO/issues" 33 | 34 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | snfpy 2 | matplotlib==3.5.3 3 | networkx==3.0 4 | numpy==1.24.1 5 | pandas==1.3.5 6 | scikit-learn==1.3.2 7 | scipy==1.11.4 8 | seaborn==0.12.2 9 | skunk==1.2.0 10 | typing_extensions==4.4.0 11 | umap-learn==0.5.5 -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/IntegrAO/5be3f4a46a1a4739ddf737008e1164783f48c37a/tests/__init__.py -------------------------------------------------------------------------------- /tutorials/cancer_omics_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# New patient classification with incomplete omics profiles" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Import packages and IntegrAO code" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stderr", 24 | "output_type": "stream", 25 | "text": [ 26 | "/home/jma/anaconda3/envs/integrAO/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 27 | " from .autonotebook import tqdm as notebook_tqdm\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "import numpy as np\n", 33 | "import pandas as pd\n", 34 | "import snf\n", 35 | "from sklearn.cluster import spectral_clustering\n", 36 | "from sklearn.metrics import v_measure_score\n", 37 | "import matplotlib.pyplot as plt\n", 38 | "\n", 39 | "import sys\n", 40 | "import os\n", 41 | "import argparse\n", 42 | "import torch\n", 43 | "\n", 44 | "import umap\n", 45 | "from sklearn.model_selection import train_test_split" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# Add the parent directory of \"integrao\" to the Python path\n", 55 | "module_path = os.path.abspath(os.path.join('../'))\n", 56 | "if module_path not in sys.path:\n", 57 | " sys.path.append(module_path)\n", 58 | " \n", 59 | "from integrao.dataset import GraphDataset\n", 60 | "from integrao.main import dist2\n", 61 | "from integrao.integrater import integrao_integrater, integrao_predictor" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Set hyperparameters" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "True" 80 | ] 81 | }, 82 | "execution_count": 3, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "torch.cuda.is_available()" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# Hyperparameters\n", 98 | "neighbor_size = 20\n", 99 | "embedding_dims = 64\n", 100 | "fusing_iteration = 30\n", 101 | "normalization_factor = 1.0\n", 102 | "alighment_epochs = 1000\n", 103 | "beta = 1.0\n", 104 | "mu = 0.5\n", 105 | "\n", 106 | "\n", 107 | "dataset_name = 'cancer_omics_prediction'\n", 108 | "cluster_number = 15" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 5, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "# create result dir\n", 118 | "result_dir = os.path.join(\n", 119 | " module_path, \"results/{}\".format(dataset_name)\n", 120 | ")\n", 121 | "if not os.path.exists(result_dir):\n", 122 | " os.makedirs(result_dir)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "## Read data" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 6, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "(500, 367)\n", 142 | "(500, 131)\n", 143 | "(500, 160)\n", 144 | "(500, 2)\n", 145 | "finish loading data!\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "testdata_dir = os.path.join(module_path, \"data/omics/\")\n", 151 | "\n", 152 | "methyl_ = os.path.join(testdata_dir, \"omics1.txt\")\n", 153 | "expr_ = os.path.join(testdata_dir, \"omics2.txt\")\n", 154 | "protein_ = os.path.join(testdata_dir, \"omics3.txt\")\n", 155 | "truelabel = os.path.join(testdata_dir, \"clusters.txt\")\n", 156 | "\n", 157 | "\n", 158 | "methyl = pd.read_csv(methyl_, index_col=0, delimiter=\"\\t\")\n", 159 | "expr = pd.read_csv(expr_, index_col=0, delimiter=\"\\t\")\n", 160 | "protein = pd.read_csv(protein_, index_col=0, delimiter=\"\\t\")\n", 161 | "truelabel = pd.read_csv(truelabel, index_col=0, delimiter=\"\\t\")\n", 
162 |     "\n",
163 |     "methyl = np.transpose(methyl)\n",
164 |     "expr = np.transpose(expr)\n",
165 |     "protein = np.transpose(protein)\n",
166 |     "print(methyl.shape)\n",
167 |     "print(expr.shape)\n",
168 |     "print(protein.shape)\n",
169 |     "print(truelabel.shape)\n",
170 |     "print(\"finish loading data!\")"
171 |    ]
172 |   },
173 |   {
174 |    "cell_type": "markdown",
175 |    "metadata": {},
176 |    "source": [
177 |     "## Randomly stratified-subsample 80%/20% of the samples to simulate the scenario of an incomplete omics dataset\n"
178 |    ]
179 |   },
180 |   {
181 |    "cell_type": "code",
182 |    "execution_count": 7,
183 |    "metadata": {},
184 |    "outputs": [
185 |     {
186 |      "data": {
187 |       "text/html": [
\n", 189 | "\n", 202 | "\n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | "
subjectscluster.id
1subject16
2subject27
3subject39
4subject46
5subject54
.........
496subject4961
497subject49714
498subject4984
499subject4991
500subject5009
\n", 268 | "

500 rows × 2 columns

\n", 269 | "
" 270 | ], 271 | "text/plain": [ 272 | " subjects cluster.id\n", 273 | "1 subject1 6\n", 274 | "2 subject2 7\n", 275 | "3 subject3 9\n", 276 | "4 subject4 6\n", 277 | "5 subject5 4\n", 278 | ".. ... ...\n", 279 | "496 subject496 1\n", 280 | "497 subject497 14\n", 281 | "498 subject498 4\n", 282 | "499 subject499 1\n", 283 | "500 subject500 9\n", 284 | "\n", 285 | "[500 rows x 2 columns]" 286 | ] 287 | }, 288 | "execution_count": 7, 289 | "metadata": {}, 290 | "output_type": "execute_result" 291 | } 292 | ], 293 | "source": [ 294 | "truelabel" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 8, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "common_patient = methyl.index\n", 304 | "y = truelabel['cluster.id'].tolist()\n", 305 | "\n", 306 | "X_train, X_test, y_train, y_test = train_test_split(common_patient, y, stratify=y, test_size=0.2)\n", 307 | "\n", 308 | "# get the reference and query data\n", 309 | "methyl_ref = methyl.loc[X_train]\n", 310 | "expr_ref = expr.loc[X_train]\n", 311 | "protein_ref = protein.loc[X_train]\n", 312 | "\n", 313 | "methyl_query = methyl.loc[X_test]\n", 314 | "expr_query = expr.loc[X_test]\n", 315 | "protein_query = protein.loc[X_test]" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": {}, 321 | "source": [ 322 | "## Now let's intergrate the reference data " 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 9, 328 | "metadata": {}, 329 | "outputs": [ 330 | { 331 | "name": "stdout", 332 | "output_type": "stream", 333 | "text": [ 334 | "Start indexing input expression matrices!\n", 335 | "Common sample between view0 and view1: 400\n", 336 | "Common sample between view0 and view2: 400\n", 337 | "Common sample between view1 and view2: 400\n", 338 | "Neighbor size: 20\n", 339 | "Start applying diffusion!\n", 340 | "Diffusion ends! Times: 4.388705015182495s\n", 341 | "Starting unsupervised exmbedding extraction!\n", 342 | "Dataset 0: (400, 367)\n", 343 | "Dataset 1: (400, 131)\n", 344 | "Dataset 2: (400, 160)\n", 345 | "epoch 0: loss 27.789127349853516, align_loss:0.744084\n", 346 | "epoch 100: loss 19.32496452331543, align_loss:0.101542\n", 347 | "epoch 200: loss 0.7247133255004883, align_loss:0.061813\n", 348 | "epoch 300: loss 0.7239326238632202, align_loss:0.061128\n", 349 | "epoch 400: loss 0.7230430245399475, align_loss:0.060395\n", 350 | "epoch 500: loss 0.7220683693885803, align_loss:0.059626\n", 351 | "epoch 600: loss 0.7210268974304199, align_loss:0.058864\n", 352 | "epoch 700: loss 0.7199146747589111, align_loss:0.058144\n", 353 | "epoch 800: loss 0.718734085559845, align_loss:0.057389\n", 354 | "epoch 900: loss 0.7174885272979736, align_loss:0.056649\n", 355 | "Manifold alignment ends! 
Times: 7.765907526016235s\n"
356 |      ]
357 |     }
358 |    ],
359 |    "source": [
360 |     "# Initialize integrater\n",
361 |     "integrater = integrao_integrater(\n",
362 |     "    [methyl_ref, expr_ref, protein_ref],\n",
363 |     "    dataset_name,\n",
364 |     "    modalities_name_list=[\"methyl\", \"expr\", \"protein\"], # used for naming the incomplete modalities during new sample inference\n",
365 |     "    neighbor_size=neighbor_size,\n",
366 |     "    embedding_dims=embedding_dims,\n",
367 |     "    fusing_iteration=fusing_iteration,\n",
368 |     "    normalization_factor=normalization_factor,\n",
369 |     "    alighment_epochs=alighment_epochs,\n",
370 |     "    beta=beta,\n",
371 |     "    mu=mu,\n",
372 |     ")\n",
373 |     "# data indexing\n",
374 |     "fused_networks = integrater.network_diffusion()\n",
375 |     "embeds_final, S_final, model = integrater.unsupervised_alignment()\n",
376 |     "\n",
377 |     "# save the model for fine-tuning\n",
378 |     "torch.save(model.state_dict(), os.path.join(result_dir, \"model.pth\"))"
379 |    ]
380 |   },
381 |   {
382 |    "cell_type": "code",
383 |    "execution_count": 10,
384 |    "metadata": {},
385 |    "outputs": [
386 |     {
387 |      "name": "stdout",
388 |      "output_type": "stream",
389 |      "text": [
390 |       "IntegrAO for clustering reference 400 samples NMI score: 1.0\n"
391 |      ]
392 |     }
393 |    ],
394 |    "source": [
395 |     "labels = spectral_clustering(S_final, n_clusters=cluster_number)\n",
396 |     "\n",
397 |     "# select from truelabel based on the 'subjects' column in embeds_final\n",
398 |     "truelabel_filtered = truelabel[truelabel['subjects'].isin(embeds_final.index)]\n",
399 |     "truelabel_filtered = truelabel_filtered.sort_values('subjects')['cluster.id'].tolist()\n",
400 |     "\n",
401 |     "score_all = v_measure_score(truelabel_filtered, labels)\n",
402 |     "print(\"IntegrAO for clustering reference 400 samples NMI score: \", score_all)"
403 |    ]
404 |   },
405 |   {
406 |    "cell_type": "markdown",
407 |    "metadata": {},
408 |    "source": [
409 |     "## Now perform fine-tuning using the ground-truth labels"
410 |    ]
411 |   },
412 |   {
413 |    "cell_type": "code",
414 |    "execution_count": 11,
415 |    "metadata": {},
416 |    "outputs": [
417 |     {
418 |      "data": {
419 |       "text/html": [
420 |        "
\n", 421 | "\n", 434 | "\n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | "
cluster.id
subjects
subject15
subject26
subject38
subject45
subject53
......
subject4960
subject49713
subject4983
subject4990
subject5008
\n", 492 | "

400 rows × 1 columns

\n", 493 | "
" 494 | ], 495 | "text/plain": [ 496 | " cluster.id\n", 497 | "subjects \n", 498 | "subject1 5\n", 499 | "subject2 6\n", 500 | "subject3 8\n", 501 | "subject4 5\n", 502 | "subject5 3\n", 503 | "... ...\n", 504 | "subject496 0\n", 505 | "subject497 13\n", 506 | "subject498 3\n", 507 | "subject499 0\n", 508 | "subject500 8\n", 509 | "\n", 510 | "[400 rows x 1 columns]" 511 | ] 512 | }, 513 | "execution_count": 11, 514 | "metadata": {}, 515 | "output_type": "execute_result" 516 | } 517 | ], 518 | "source": [ 519 | "truelabel_sub = truelabel[truelabel['subjects'].isin(embeds_final.index)]\n", 520 | "truelabel_sub = truelabel_sub.set_index('subjects')\n", 521 | "\n", 522 | "# minus 1 for the cluster id to avoid CUDA error\n", 523 | "truelabel_sub['cluster.id'] = truelabel_sub['cluster.id'] - 1\n", 524 | "truelabel_sub" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 13, 530 | "metadata": {}, 531 | "outputs": [ 532 | { 533 | "name": "stdout", 534 | "output_type": "stream", 535 | "text": [ 536 | "Starting supervised fineting!\n", 537 | "Dataset 0: (400, 367)\n", 538 | "Dataset 1: (400, 131)\n", 539 | "Dataset 2: (400, 160)\n", 540 | "IntegrAO(\n", 541 | " (feature): ModuleList(\n", 542 | " (0): GraphSAGE(367, 64, num_layers=2)\n", 543 | " (1): GraphSAGE(131, 64, num_layers=2)\n", 544 | " (2): GraphSAGE(160, 64, num_layers=2)\n", 545 | " )\n", 546 | " (feature_show): Sequential(\n", 547 | " (0): Linear(in_features=64, out_features=64, bias=True)\n", 548 | " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 549 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 550 | " (3): Linear(in_features=64, out_features=64, bias=True)\n", 551 | " )\n", 552 | " (pred_head): Sequential(\n", 553 | " (0): Linear(in_features=64, out_features=32, bias=True)\n", 554 | " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 555 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 556 | " (3): Linear(in_features=32, out_features=15, bias=True)\n", 557 | " )\n", 558 | ")\n", 559 | "Loaded pre-trained model with success.\n", 560 | "epoch 0: loss 3.224135637283325, kl_loss:0.669090, align_loss:0.052747, clf_loss:2.502299\n", 561 | "epoch 100: loss 0.6574736833572388, kl_loss:0.615961, align_loss:0.041423, clf_loss:0.000090\n", 562 | "epoch 200: loss 0.6475169658660889, kl_loss:0.607740, align_loss:0.039688, clf_loss:0.000089\n", 563 | "epoch 300: loss 0.6462971568107605, kl_loss:0.606705, align_loss:0.039503, clf_loss:0.000089\n", 564 | "epoch 400: loss 0.6448504328727722, kl_loss:0.605449, align_loss:0.039312, clf_loss:0.000089\n", 565 | "epoch 500: loss 0.6431947350502014, kl_loss:0.604019, align_loss:0.039087, clf_loss:0.000089\n", 566 | "epoch 600: loss 0.6413335204124451, kl_loss:0.602399, align_loss:0.038845, clf_loss:0.000090\n", 567 | "epoch 700: loss 0.639265775680542, kl_loss:0.600594, align_loss:0.038581, clf_loss:0.000091\n", 568 | "Manifold alignment ends! 
Times: 21.10444712638855s\n" 569 | ] 570 | } 571 | ], 572 | "source": [ 573 | "embeds_final, S_final, model, preds = integrater.classification_finetuning(truelabel_sub, result_dir, finetune_epochs=800)" 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 14, 579 | "metadata": {}, 580 | "outputs": [], 581 | "source": [ 582 | "torch.save(model.state_dict(), os.path.join(result_dir, \"model_integrao_supervised.pth\"))" 583 | ] 584 | }, 585 | { 586 | "cell_type": "markdown", 587 | "metadata": {}, 588 | "source": [ 589 | "## Now to perform inference on query data" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": 15, 595 | "metadata": {}, 596 | "outputs": [ 597 | { 598 | "name": "stdout", 599 | "output_type": "stream", 600 | "text": [ 601 | "Start indexing input expression matrices!\n", 602 | "Common sample between view0 and view1: 500\n", 603 | "Common sample between view0 and view2: 500\n", 604 | "Common sample between view1 and view2: 500\n", 605 | "Neighbor size: 20\n", 606 | "Start applying diffusion!\n", 607 | "Diffusion ends! Times: 5.840329647064209s\n" 608 | ] 609 | } 610 | ], 611 | "source": [ 612 | "# Network fusion for the whole graph\n", 613 | "predictor = integrao_predictor(\n", 614 | " [methyl, expr, protein],\n", 615 | " dataset_name,\n", 616 | " modalities_name_list=[\"methyl\", \"expr\", \"protein\"], \n", 617 | " neighbor_size=neighbor_size,\n", 618 | " embedding_dims=embedding_dims,\n", 619 | " fusing_iteration=fusing_iteration,\n", 620 | " normalization_factor=normalization_factor,\n", 621 | " alighment_epochs=alighment_epochs,\n", 622 | " beta=beta,\n", 623 | " mu=mu,\n", 624 | " num_classes=cluster_number,\n", 625 | ")\n", 626 | "# data indexing\n", 627 | "fused_networks = predictor.network_diffusion()" 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "execution_count": 16, 633 | "metadata": {}, 634 | "outputs": [], 635 | "source": [ 636 | "from sklearn.metrics import accuracy_score, f1_score\n", 637 | "\n", 638 | "# helper function to get the metrics on test set\n", 639 | "def get_metrics(preds, preds_index, X_test, y_test):\n", 640 | "\n", 641 | " pred_df = pd.DataFrame(data=preds, index=preds_index)\n", 642 | " pred_df_test = pred_df.loc[X_test]\n", 643 | "\n", 644 | " # add 1 back to the cluster id\n", 645 | " pred_df_test = pred_df_test + 1\n", 646 | "\n", 647 | " f1_micro = f1_score(y_test, pred_df_test, average='micro')\n", 648 | " f1_weighted = f1_score(y_test, pred_df_test, average='weighted')\n", 649 | " acc = accuracy_score(y_test, pred_df_test)\n", 650 | "\n", 651 | " return f1_micro, f1_weighted, acc\n" 652 | ] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "metadata": {}, 657 | "source": [ 658 | "## Using one modalities" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": 17, 664 | "metadata": {}, 665 | "outputs": [], 666 | "source": [ 667 | "model_path = os.path.join(result_dir, \"model_integrao_supervised.pth\")" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 18, 673 | "metadata": {}, 674 | "outputs": [ 675 | { 676 | "name": "stdout", 677 | "output_type": "stream", 678 | "text": [ 679 | "IntegrAO(\n", 680 | " (feature): ModuleList(\n", 681 | " (0): GraphSAGE(367, 64, num_layers=2)\n", 682 | " (1): GraphSAGE(131, 64, num_layers=2)\n", 683 | " (2): GraphSAGE(160, 64, num_layers=2)\n", 684 | " )\n", 685 | " (feature_show): Sequential(\n", 686 | " (0): Linear(in_features=64, out_features=64, bias=True)\n", 687 | " (1): BatchNorm1d(64, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 688 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 689 | " (3): Linear(in_features=64, out_features=64, bias=True)\n", 690 | " )\n", 691 | " (pred_head): Sequential(\n", 692 | " (0): Linear(in_features=64, out_features=32, bias=True)\n", 693 | " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 694 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 695 | " (3): Linear(in_features=32, out_features=15, bias=True)\n", 696 | " )\n", 697 | ")\n", 698 | "Loaded pre-trained model with success.\n", 699 | "methyl f1_micro: 1.0\n", 700 | "methyl f1_weight: 1.0\n", 701 | "methyl acc: 1.0\n" 702 | ] 703 | } 704 | ], 705 | "source": [ 706 | "# for methyl\n", 707 | "preds = predictor.inference_supervised(model_path, new_datasets=[methyl], modalities_names=[\"methyl\"])\n", 708 | "\n", 709 | "f1_micro, f1_weight, acc = get_metrics(preds, methyl.index, X_test, y_test)\n", 710 | "\n", 711 | "print(\"methyl f1_micro: \", f1_micro)\n", 712 | "print(\"methyl f1_weight: \", f1_weight)\n", 713 | "print(\"methyl acc: \", acc)" 714 | ] 715 | }, 716 | { 717 | "cell_type": "code", 718 | "execution_count": 20, 719 | "metadata": {}, 720 | "outputs": [ 721 | { 722 | "name": "stdout", 723 | "output_type": "stream", 724 | "text": [ 725 | "IntegrAO(\n", 726 | " (feature): ModuleList(\n", 727 | " (0): GraphSAGE(367, 64, num_layers=2)\n", 728 | " (1): GraphSAGE(131, 64, num_layers=2)\n", 729 | " (2): GraphSAGE(160, 64, num_layers=2)\n", 730 | " )\n", 731 | " (feature_show): Sequential(\n", 732 | " (0): Linear(in_features=64, out_features=64, bias=True)\n", 733 | " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 734 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 735 | " (3): Linear(in_features=64, out_features=64, bias=True)\n", 736 | " )\n", 737 | " (pred_head): Sequential(\n", 738 | " (0): Linear(in_features=64, out_features=32, bias=True)\n", 739 | " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 740 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 741 | " (3): Linear(in_features=32, out_features=15, bias=True)\n", 742 | " )\n", 743 | ")\n", 744 | "Loaded pre-trained model with success.\n", 745 | "expr f1_micro: 0.99\n", 746 | "expr f1_weight: 0.989047619047619\n", 747 | "expr acc: 0.99\n" 748 | ] 749 | } 750 | ], 751 | "source": [ 752 | "# for expr\n", 753 | "preds = predictor.inference_supervised(model_path, new_datasets=[expr], modalities_names=[\"expr\"])\n", 754 | "\n", 755 | "f1_micro, f1_weight, acc = get_metrics(preds, expr.index, X_test, y_test)\n", 756 | "\n", 757 | "print(\"expr f1_micro: \", f1_micro)\n", 758 | "print(\"expr f1_weight: \", f1_weight)\n", 759 | "print(\"expr acc: \", acc)" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": 21, 765 | "metadata": {}, 766 | "outputs": [ 767 | { 768 | "name": "stdout", 769 | "output_type": "stream", 770 | "text": [ 771 | "IntegrAO(\n", 772 | " (feature): ModuleList(\n", 773 | " (0): GraphSAGE(367, 64, num_layers=2)\n", 774 | " (1): GraphSAGE(131, 64, num_layers=2)\n", 775 | " (2): GraphSAGE(160, 64, num_layers=2)\n", 776 | " )\n", 777 | " (feature_show): Sequential(\n", 778 | " (0): Linear(in_features=64, out_features=64, bias=True)\n", 779 | " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 780 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 781 | " (3): 
Linear(in_features=64, out_features=64, bias=True)\n", 782 | " )\n", 783 | " (pred_head): Sequential(\n", 784 | " (0): Linear(in_features=64, out_features=32, bias=True)\n", 785 | " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 786 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 787 | " (3): Linear(in_features=32, out_features=15, bias=True)\n", 788 | " )\n", 789 | ")\n", 790 | "Loaded pre-trained model with success.\n", 791 | "protein f1_micro: 1.0\n", 792 | "protein f1_weight: 1.0\n", 793 | "protein acc: 1.0\n" 794 | ] 795 | } 796 | ], 797 | "source": [ 798 | "# for protein\n", 799 | "preds = predictor.inference_supervised(model_path, new_datasets=[protein], modalities_names=[\"protein\"])\n", 800 | "\n", 801 | "f1_micro, f1_weight, acc = get_metrics(preds, protein.index, X_test, y_test)\n", 802 | "\n", 803 | "\n", 804 | "print(\"protein f1_micro: \", f1_micro)\n", 805 | "print(\"protein f1_weight: \", f1_weight)\n", 806 | "print(\"protein acc: \", acc)" 807 | ] 808 | }, 809 | { 810 | "cell_type": "markdown", 811 | "metadata": {}, 812 | "source": [ 813 | "## Two modalities" 814 | ] 815 | }, 816 | { 817 | "cell_type": "code", 818 | "execution_count": 22, 819 | "metadata": {}, 820 | "outputs": [ 821 | { 822 | "name": "stdout", 823 | "output_type": "stream", 824 | "text": [ 825 | "IntegrAO(\n", 826 | " (feature): ModuleList(\n", 827 | " (0): GraphSAGE(367, 64, num_layers=2)\n", 828 | " (1): GraphSAGE(131, 64, num_layers=2)\n", 829 | " (2): GraphSAGE(160, 64, num_layers=2)\n", 830 | " )\n", 831 | " (feature_show): Sequential(\n", 832 | " (0): Linear(in_features=64, out_features=64, bias=True)\n", 833 | " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 834 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 835 | " (3): Linear(in_features=64, out_features=64, bias=True)\n", 836 | " )\n", 837 | " (pred_head): Sequential(\n", 838 | " (0): Linear(in_features=64, out_features=32, bias=True)\n", 839 | " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 840 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 841 | " (3): Linear(in_features=32, out_features=15, bias=True)\n", 842 | " )\n", 843 | ")\n", 844 | "Loaded pre-trained model with success.\n", 845 | "methyl+expr f1_micro: 1.0\n", 846 | "methyl+expr f1_weight: 1.0\n", 847 | "methyl+expr acc: 1.0\n" 848 | ] 849 | } 850 | ], 851 | "source": [ 852 | "# methyl and expr\n", 853 | "preds = predictor.inference_supervised(model_path, new_datasets=[methyl, expr], modalities_names=[\"methyl\", \"expr\"])\n", 854 | "\n", 855 | "f1_micro, f1_weight, acc = get_metrics(preds, methyl.index, X_test, y_test)\n", 856 | "\n", 857 | "print(\"methyl+expr f1_micro: \", f1_micro)\n", 858 | "print(\"methyl+expr f1_weight: \", f1_weight)\n", 859 | "print(\"methyl+expr acc: \", acc)" 860 | ] 861 | }, 862 | { 863 | "cell_type": "code", 864 | "execution_count": 23, 865 | "metadata": {}, 866 | "outputs": [ 867 | { 868 | "name": "stdout", 869 | "output_type": "stream", 870 | "text": [ 871 | "IntegrAO(\n", 872 | " (feature): ModuleList(\n", 873 | " (0): GraphSAGE(367, 64, num_layers=2)\n", 874 | " (1): GraphSAGE(131, 64, num_layers=2)\n", 875 | " (2): GraphSAGE(160, 64, num_layers=2)\n", 876 | " )\n", 877 | " (feature_show): Sequential(\n", 878 | " (0): Linear(in_features=64, out_features=64, bias=True)\n", 879 | " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 880 | " (2): 
LeakyReLU(negative_slope=0.1, inplace=True)\n", 881 | " (3): Linear(in_features=64, out_features=64, bias=True)\n", 882 | " )\n", 883 | " (pred_head): Sequential(\n", 884 | " (0): Linear(in_features=64, out_features=32, bias=True)\n", 885 | " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 886 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 887 | " (3): Linear(in_features=32, out_features=15, bias=True)\n", 888 | " )\n", 889 | ")\n", 890 | "Loaded pre-trained model with success.\n", 891 | "methyl+protein f1_micro: 1.0\n", 892 | "methyl+protein f1_weight: 1.0\n", 893 | "methyl+protein acc: 1.0\n" 894 | ] 895 | } 896 | ], 897 | "source": [ 898 | "# methyl and protein\n", 899 | "preds = predictor.inference_supervised(model_path, new_datasets=[methyl, protein], modalities_names=[\"methyl\", \"protein\"])\n", 900 | "\n", 901 | "f1_micro, f1_weight, acc = get_metrics(preds, methyl.index, X_test, y_test)\n", 902 | "\n", 903 | "print(\"methyl+protein f1_micro: \", f1_micro)\n", 904 | "print(\"methyl+protein f1_weight: \", f1_weight)\n", 905 | "print(\"methyl+protein acc: \", acc)" 906 | ] 907 | }, 908 | { 909 | "cell_type": "code", 910 | "execution_count": 24, 911 | "metadata": {}, 912 | "outputs": [ 913 | { 914 | "name": "stdout", 915 | "output_type": "stream", 916 | "text": [ 917 | "IntegrAO(\n", 918 | " (feature): ModuleList(\n", 919 | " (0): GraphSAGE(367, 64, num_layers=2)\n", 920 | " (1): GraphSAGE(131, 64, num_layers=2)\n", 921 | " (2): GraphSAGE(160, 64, num_layers=2)\n", 922 | " )\n", 923 | " (feature_show): Sequential(\n", 924 | " (0): Linear(in_features=64, out_features=64, bias=True)\n", 925 | " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 926 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 927 | " (3): Linear(in_features=64, out_features=64, bias=True)\n", 928 | " )\n", 929 | " (pred_head): Sequential(\n", 930 | " (0): Linear(in_features=64, out_features=32, bias=True)\n", 931 | " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 932 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 933 | " (3): Linear(in_features=32, out_features=15, bias=True)\n", 934 | " )\n", 935 | ")\n", 936 | "Loaded pre-trained model with success.\n", 937 | "expr+protein f1_micro: 1.0\n", 938 | "expr+protein f1_weight: 1.0\n", 939 | "expr+protein acc: 1.0\n" 940 | ] 941 | } 942 | ], 943 | "source": [ 944 | "# expr and protein\n", 945 | "preds = predictor.inference_supervised(model_path, new_datasets=[expr, protein], modalities_names=[\"expr\", \"protein\"])\n", 946 | "\n", 947 | "f1_micro, f1_weight, acc = get_metrics(preds, expr.index, X_test, y_test)\n", 948 | "\n", 949 | "print(\"expr+protein f1_micro: \", f1_micro)\n", 950 | "print(\"expr+protein f1_weight: \", f1_weight)\n", 951 | "print(\"expr+protein acc: \", acc)" 952 | ] 953 | }, 954 | { 955 | "cell_type": "markdown", 956 | "metadata": {}, 957 | "source": [ 958 | "## Three modalities" 959 | ] 960 | }, 961 | { 962 | "cell_type": "code", 963 | "execution_count": 25, 964 | "metadata": {}, 965 | "outputs": [ 966 | { 967 | "name": "stdout", 968 | "output_type": "stream", 969 | "text": [ 970 | "IntegrAO(\n", 971 | " (feature): ModuleList(\n", 972 | " (0): GraphSAGE(367, 64, num_layers=2)\n", 973 | " (1): GraphSAGE(131, 64, num_layers=2)\n", 974 | " (2): GraphSAGE(160, 64, num_layers=2)\n", 975 | " )\n", 976 | " (feature_show): Sequential(\n", 977 | " (0): Linear(in_features=64, out_features=64, bias=True)\n", 
978 | " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 979 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 980 | " (3): Linear(in_features=64, out_features=64, bias=True)\n", 981 | " )\n", 982 | " (pred_head): Sequential(\n", 983 | " (0): Linear(in_features=64, out_features=32, bias=True)\n", 984 | " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 985 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 986 | " (3): Linear(in_features=32, out_features=15, bias=True)\n", 987 | " )\n", 988 | ")\n", 989 | "Loaded pre-trained model with success.\n", 990 | "methyl+expr+protein f1_micro: 1.0\n", 991 | "methyl+expr+protein f1_weight: 1.0\n", 992 | "methyl+expr+protein acc: 1.0\n" 993 | ] 994 | } 995 | ], 996 | "source": [ 997 | "# methyl, expr and protein\n", 998 | "preds = predictor.inference_supervised(model_path, new_datasets=[methyl, expr, protein], modalities_names=[\"methyl\", \"expr\", \"protein\"])\n", 999 | "\n", 1000 | "f1_micro, f1_weight, acc = get_metrics(preds, methyl.index, X_test, y_test)\n", 1001 | "\n", 1002 | "print(\"methyl+expr+protein f1_micro: \", f1_micro)\n", 1003 | "print(\"methyl+expr+protein f1_weight: \", f1_weight)\n", 1004 | "print(\"methyl+expr+protein acc: \", acc)" 1005 | ] 1006 | } 1007 | ], 1008 | "metadata": { 1009 | "kernelspec": { 1010 | "display_name": "integrAO", 1011 | "language": "python", 1012 | "name": "python3" 1013 | }, 1014 | "language_info": { 1015 | "codemirror_mode": { 1016 | "name": "ipython", 1017 | "version": 3 1018 | }, 1019 | "file_extension": ".py", 1020 | "mimetype": "text/x-python", 1021 | "name": "python", 1022 | "nbconvert_exporter": "python", 1023 | "pygments_lexer": "ipython3", 1024 | "version": "3.10.16" 1025 | } 1026 | }, 1027 | "nbformat": 4, 1028 | "nbformat_minor": 2 1029 | } 1030 | -------------------------------------------------------------------------------- /tutorials/supervised_integration_feature_importance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## New patient classification with incomplete omics profiles; load the saved model and perform feature importance extraction" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Import packages and IntegrAO code" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stderr", 24 | "output_type": "stream", 25 | "text": [ 26 | "/home/jma/anaconda3/envs/integrAO/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 27 | " from .autonotebook import tqdm as notebook_tqdm\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "import numpy as np\n", 33 | "import pandas as pd\n", 34 | "import snf\n", 35 | "from sklearn.cluster import spectral_clustering\n", 36 | "from sklearn.metrics import v_measure_score\n", 37 | "import matplotlib.pyplot as plt\n", 38 | "\n", 39 | "import sys\n", 40 | "import os\n", 41 | "import argparse\n", 42 | "import torch\n", 43 | "\n", 44 | "import umap\n", 45 | "from sklearn.model_selection import train_test_split" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# Add the parent directory of \"integrao\" to the Python path\n", 55 | "module_path = os.path.abspath(os.path.join('../'))\n", 56 | "if module_path not in sys.path:\n", 57 | " sys.path.append(module_path)\n", 58 | " \n", 59 | "from integrao.dataset import GraphDataset\n", 60 | "from integrao.main import dist2\n", 61 | "from integrao.integrater import integrao_integrater, integrao_predictor" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Set hyperparameters" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "True" 80 | ] 81 | }, 82 | "execution_count": 3, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "torch.cuda.is_available()" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# Hyperparameters\n", 98 | "neighbor_size = 20\n", 99 | "embedding_dims = 64\n", 100 | "fusing_iteration = 30\n", 101 | "normalization_factor = 1.0\n", 102 | "alighment_epochs = 1000\n", 103 | "beta = 1.0\n", 104 | "mu = 0.5\n", 105 | "\n", 106 | "\n", 107 | "dataset_name = 'supervised_integration_feature_importance'\n", 108 | "cluster_number = 15" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 5, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "# create result dir\n", 118 | "result_dir = os.path.join(\n", 119 | " module_path, \"results/{}\".format(dataset_name)\n", 120 | ")\n", 121 | "if not os.path.exists(result_dir):\n", 122 | " os.makedirs(result_dir)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "## Read data" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 6, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "(500, 367)\n", 142 | "(500, 131)\n", 143 | "(500, 160)\n", 144 | "(500, 2)\n", 145 | "finish loading data!\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "testdata_dir = os.path.join(module_path, \"data/omics/\")\n", 151 | "\n", 152 | "methyl_ = os.path.join(testdata_dir, \"omics1.txt\")\n", 153 | "expr_ = os.path.join(testdata_dir, \"omics2.txt\")\n", 154 | "protein_ = os.path.join(testdata_dir, \"omics3.txt\")\n", 155 | "truelabel = os.path.join(testdata_dir, \"clusters.txt\")\n", 156 | "\n", 157 | "\n", 158 | "methyl = pd.read_csv(methyl_, index_col=0, delimiter=\"\\t\")\n", 159 | "expr = pd.read_csv(expr_, index_col=0, delimiter=\"\\t\")\n", 160 | "protein = pd.read_csv(protein_, index_col=0, delimiter=\"\\t\")\n", 161 | "truelabel = pd.read_csv(truelabel, index_col=0, 
delimiter=\"\\t\")\n", 162 | "\n", 163 | "methyl = np.transpose(methyl)\n", 164 | "expr = np.transpose(expr)\n", 165 | "protein = np.transpose(protein)\n", 166 | "print(methyl.shape)\n", 167 | "print(expr.shape)\n", 168 | "print(protein.shape)\n", 169 | "print(truelabel.shape)\n", 170 | "print(\"finish loading data!\")" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "## Random stratified-subsample 80%-20% samples to simulate the senario of incomplete omics dataset\n" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 7, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "data": { 187 | "text/html": [ 188 | "
\n", 189 | "\n", 202 | "\n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | "
subjectscluster.id
1subject16
2subject27
3subject39
4subject46
5subject54
.........
496subject4961
497subject49714
498subject4984
499subject4991
500subject5009
\n", 268 | "

500 rows × 2 columns

\n", 269 | "
" 270 | ], 271 | "text/plain": [ 272 | " subjects cluster.id\n", 273 | "1 subject1 6\n", 274 | "2 subject2 7\n", 275 | "3 subject3 9\n", 276 | "4 subject4 6\n", 277 | "5 subject5 4\n", 278 | ".. ... ...\n", 279 | "496 subject496 1\n", 280 | "497 subject497 14\n", 281 | "498 subject498 4\n", 282 | "499 subject499 1\n", 283 | "500 subject500 9\n", 284 | "\n", 285 | "[500 rows x 2 columns]" 286 | ] 287 | }, 288 | "execution_count": 7, 289 | "metadata": {}, 290 | "output_type": "execute_result" 291 | } 292 | ], 293 | "source": [ 294 | "truelabel" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 8, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "common_patient = methyl.index\n", 304 | "y = truelabel['cluster.id'].tolist()\n", 305 | "\n", 306 | "X_train, X_test, y_train, y_test = train_test_split(common_patient, y, stratify=y, test_size=0.2)\n", 307 | "\n", 308 | "# get the reference and query data\n", 309 | "methyl_ref = methyl.loc[X_train]\n", 310 | "expr_ref = expr.loc[X_train]\n", 311 | "protein_ref = protein.loc[X_train]\n", 312 | "\n", 313 | "methyl_query = methyl.loc[X_test]\n", 314 | "expr_query = expr.loc[X_test]\n", 315 | "protein_query = protein.loc[X_test]" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": {}, 321 | "source": [ 322 | "## Now let's intergrate the reference data " 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 9, 328 | "metadata": {}, 329 | "outputs": [ 330 | { 331 | "name": "stdout", 332 | "output_type": "stream", 333 | "text": [ 334 | "Start indexing input expression matrices!\n", 335 | "Common sample between view0 and view1: 400\n", 336 | "Common sample between view0 and view2: 400\n", 337 | "Common sample between view1 and view2: 400\n", 338 | "Neighbor size: 20\n", 339 | "Start applying diffusion!\n", 340 | "Diffusion ends! Times: 4.6185383796691895s\n", 341 | "Starting unsupervised exmbedding extraction!\n", 342 | "Dataset 0: (400, 367)\n", 343 | "Dataset 1: (400, 131)\n", 344 | "Dataset 2: (400, 160)\n", 345 | "epoch 0: loss 27.530149459838867, align_loss:0.731598\n", 346 | "epoch 100: loss 19.294527053833008, align_loss:0.099107\n", 347 | "epoch 200: loss 0.7154172658920288, align_loss:0.059697\n", 348 | "epoch 300: loss 0.7146754264831543, align_loss:0.059069\n", 349 | "epoch 400: loss 0.7138354182243347, align_loss:0.058444\n", 350 | "epoch 500: loss 0.7129197120666504, align_loss:0.057739\n", 351 | "epoch 600: loss 0.7119321823120117, align_loss:0.057074\n", 352 | "epoch 700: loss 0.7108953595161438, align_loss:0.056396\n", 353 | "epoch 800: loss 0.7097985744476318, align_loss:0.055732\n", 354 | "epoch 900: loss 0.7086648344993591, align_loss:0.055095\n", 355 | "Manifold alignment ends! 
Times: 7.758870363235474s\n" 356 | ] 357 | } 358 | ], 359 | "source": [ 360 | "# Initialize integrater\n", 361 | "integrater = integrao_integrater(\n", 362 | " [methyl_ref, expr_ref, protein_ref],\n", 363 | " dataset_name,\n", 364 | " modalities_name_list=[\"methyl\", \"expr\", \"protein\"], # used for naming the incomplete modalities during new sample inference\n", 365 | " neighbor_size=neighbor_size,\n", 366 | " embedding_dims=embedding_dims,\n", 367 | " fusing_iteration=fusing_iteration,\n", 368 | " normalization_factor=normalization_factor,\n", 369 | " alighment_epochs=alighment_epochs,\n", 370 | " beta=beta,\n", 371 | " mu=mu,\n", 372 | ")\n", 373 | "# data indexing\n", 374 | "fused_networks = integrater.network_diffusion()\n", 375 | "embeds_final, S_final, model = integrater.unsupervised_alignment()\n", 376 | "\n", 377 | "# save the model for fine-tuning\n", 378 | "torch.save(model.state_dict(), os.path.join(result_dir, \"model.pth\"))" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 10, 384 | "metadata": {}, 385 | "outputs": [ 386 | { 387 | "name": "stdout", 388 | "output_type": "stream", 389 | "text": [ 390 | "IntegrAO for clustering reference 400 samples NMI score: 1.0\n" 391 | ] 392 | } 393 | ], 394 | "source": [ 395 | "labels = spectral_clustering(S_final, n_clusters=cluster_number)\n", 396 | "\n", 397 | "# select from truelabel based on the 'subjects' column in embeds_final\n", 398 | "truelabel_filtered = truelabel[truelabel['subjects'].isin(embeds_final.index)]\n", 399 | "truelabel_filtered = truelabel_filtered.sort_values('subjects')['cluster.id'].tolist()\n", 400 | "\n", 401 | "score_all = v_measure_score(truelabel_filtered, labels)\n", 402 | "print(\"IntegrAO for clustering reference 400 samples NMI score: \", score_all)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "metadata": {}, 408 | "source": [ 409 | "## Now to perform fine-tuning using the ground-truth labels" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 11, 415 | "metadata": {}, 416 | "outputs": [ 417 | { 418 | "data": { 419 | "text/html": [ 420 | "
\n", 421 | "\n", 434 | "\n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | "
cluster.id
subjects
subject15
subject26
subject45
subject53
subject610
......
subject4958
subject4960
subject4983
subject4990
subject5008
\n", 492 | "

400 rows × 1 columns

\n", 493 | "
" 494 | ], 495 | "text/plain": [ 496 | " cluster.id\n", 497 | "subjects \n", 498 | "subject1 5\n", 499 | "subject2 6\n", 500 | "subject4 5\n", 501 | "subject5 3\n", 502 | "subject6 10\n", 503 | "... ...\n", 504 | "subject495 8\n", 505 | "subject496 0\n", 506 | "subject498 3\n", 507 | "subject499 0\n", 508 | "subject500 8\n", 509 | "\n", 510 | "[400 rows x 1 columns]" 511 | ] 512 | }, 513 | "execution_count": 11, 514 | "metadata": {}, 515 | "output_type": "execute_result" 516 | } 517 | ], 518 | "source": [ 519 | "truelabel_sub = truelabel[truelabel['subjects'].isin(embeds_final.index)]\n", 520 | "truelabel_sub = truelabel_sub.set_index('subjects')\n", 521 | "\n", 522 | "# minus 1 for the cluster id to avoid CUDA error\n", 523 | "truelabel_sub['cluster.id'] = truelabel_sub['cluster.id'] - 1\n", 524 | "truelabel_sub" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 12, 530 | "metadata": {}, 531 | "outputs": [ 532 | { 533 | "name": "stdout", 534 | "output_type": "stream", 535 | "text": [ 536 | "Starting supervised fineting!\n", 537 | "Dataset 0: (400, 367)\n", 538 | "Dataset 1: (400, 131)\n", 539 | "Dataset 2: (400, 160)\n", 540 | "IntegrAO(\n", 541 | " (feature): ModuleList(\n", 542 | " (0): GraphSAGE(367, 64, num_layers=2)\n", 543 | " (1): GraphSAGE(131, 64, num_layers=2)\n", 544 | " (2): GraphSAGE(160, 64, num_layers=2)\n", 545 | " )\n", 546 | " (feature_show): Sequential(\n", 547 | " (0): Linear(in_features=64, out_features=64, bias=True)\n", 548 | " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 549 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 550 | " (3): Linear(in_features=64, out_features=64, bias=True)\n", 551 | " )\n", 552 | " (pred_head): Sequential(\n", 553 | " (0): Linear(in_features=64, out_features=32, bias=True)\n", 554 | " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 555 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 556 | " (3): Linear(in_features=32, out_features=15, bias=True)\n", 557 | " )\n", 558 | ")\n", 559 | "Loaded pre-trained model with success.\n", 560 | "epoch 0: loss 3.6358721256256104, kl_loss:0.660988, align_loss:0.051447, clf_loss:2.923437\n", 561 | "epoch 100: loss 0.6453948020935059, kl_loss:0.604947, align_loss:0.040350, clf_loss:0.000097\n", 562 | "epoch 200: loss 0.6355220675468445, kl_loss:0.595312, align_loss:0.040107, clf_loss:0.000103\n", 563 | "epoch 300: loss 0.6342864036560059, kl_loss:0.594162, align_loss:0.040022, clf_loss:0.000103\n", 564 | "epoch 400: loss 0.632818341255188, kl_loss:0.592784, align_loss:0.039932, clf_loss:0.000103\n", 565 | "epoch 500: loss 0.6311599016189575, kl_loss:0.591240, align_loss:0.039818, clf_loss:0.000102\n", 566 | "epoch 600: loss 0.6293436288833618, kl_loss:0.589557, align_loss:0.039685, clf_loss:0.000102\n", 567 | "epoch 700: loss 0.6273597478866577, kl_loss:0.587750, align_loss:0.039508, clf_loss:0.000102\n", 568 | "Manifold alignment ends! 
Times: 20.068203926086426s\n" 569 | ] 570 | } 571 | ], 572 | "source": [ 573 | "embeds_final, S_final, model, preds = integrater.classification_finetuning(truelabel_sub, result_dir, finetune_epochs=800)" 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 13, 579 | "metadata": {}, 580 | "outputs": [], 581 | "source": [ 582 | "torch.save(model.state_dict(), os.path.join(result_dir, \"model_integrao_supervised.pth\"))" 583 | ] 584 | }, 585 | { 586 | "cell_type": "markdown", 587 | "metadata": {}, 588 | "source": [ 589 | "## Now to perform inference on query data" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": 14, 595 | "metadata": {}, 596 | "outputs": [ 597 | { 598 | "name": "stdout", 599 | "output_type": "stream", 600 | "text": [ 601 | "Start indexing input expression matrices!\n", 602 | "Common sample between view0 and view1: 500\n", 603 | "Common sample between view0 and view2: 500\n", 604 | "Common sample between view1 and view2: 500\n", 605 | "Neighbor size: 20\n", 606 | "Start applying diffusion!\n", 607 | "Diffusion ends! Times: 5.997296571731567s\n" 608 | ] 609 | } 610 | ], 611 | "source": [ 612 | "# Network fusion for the whole graph\n", 613 | "predictor = integrao_predictor(\n", 614 | " [methyl, expr, protein],\n", 615 | " dataset_name,\n", 616 | " modalities_name_list=[\"methyl\", \"expr\", \"protein\"], \n", 617 | " neighbor_size=neighbor_size,\n", 618 | " embedding_dims=embedding_dims,\n", 619 | " fusing_iteration=fusing_iteration,\n", 620 | " normalization_factor=normalization_factor,\n", 621 | " alighment_epochs=alighment_epochs,\n", 622 | " beta=beta,\n", 623 | " mu=mu,\n", 624 | " num_classes=cluster_number,\n", 625 | ")\n", 626 | "# data indexing\n", 627 | "fused_networks = predictor.network_diffusion()" 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "execution_count": 15, 633 | "metadata": {}, 634 | "outputs": [], 635 | "source": [ 636 | "from sklearn.metrics import accuracy_score, f1_score\n", 637 | "\n", 638 | "# helper function to get the metrics on test set\n", 639 | "def get_metrics(preds, preds_index, X_test, y_test):\n", 640 | "\n", 641 | " pred_df = pd.DataFrame(data=preds, index=preds_index)\n", 642 | " pred_df_test = pred_df.loc[X_test]\n", 643 | "\n", 644 | " # add 1 back to the cluster id\n", 645 | " pred_df_test = pred_df_test + 1\n", 646 | "\n", 647 | " f1_micro = f1_score(y_test, pred_df_test, average='micro')\n", 648 | " f1_weighted = f1_score(y_test, pred_df_test, average='weighted')\n", 649 | " acc = accuracy_score(y_test, pred_df_test)\n", 650 | "\n", 651 | " return f1_micro, f1_weighted, acc\n" 652 | ] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "metadata": {}, 657 | "source": [ 658 | "## Classification prediction using three modalities" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": 16, 664 | "metadata": {}, 665 | "outputs": [], 666 | "source": [ 667 | "model_path = os.path.join(result_dir, \"model_integrao_supervised.pth\")" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 17, 673 | "metadata": {}, 674 | "outputs": [ 675 | { 676 | "name": "stdout", 677 | "output_type": "stream", 678 | "text": [ 679 | "IntegrAO(\n", 680 | " (feature): ModuleList(\n", 681 | " (0): GraphSAGE(367, 64, num_layers=2)\n", 682 | " (1): GraphSAGE(131, 64, num_layers=2)\n", 683 | " (2): GraphSAGE(160, 64, num_layers=2)\n", 684 | " )\n", 685 | " (feature_show): Sequential(\n", 686 | " (0): Linear(in_features=64, out_features=64, bias=True)\n", 687 
| " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 688 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 689 | " (3): Linear(in_features=64, out_features=64, bias=True)\n", 690 | " )\n", 691 | " (pred_head): Sequential(\n", 692 | " (0): Linear(in_features=64, out_features=32, bias=True)\n", 693 | " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 694 | " (2): LeakyReLU(negative_slope=0.1, inplace=True)\n", 695 | " (3): Linear(in_features=32, out_features=15, bias=True)\n", 696 | " )\n", 697 | ")\n", 698 | "Loaded pre-trained model with success.\n", 699 | "methyl+expr+protein f1_micro: 1.0\n", 700 | "methyl+expr+protein f1_weight: 1.0\n", 701 | "methyl+expr+protein acc: 1.0\n" 702 | ] 703 | } 704 | ], 705 | "source": [ 706 | "# methyl, expr and protein\n", 707 | "preds = predictor.inference_supervised(model_path, new_datasets=[methyl, expr, protein], modalities_names=[\"methyl\", \"expr\", \"protein\"])\n", 708 | "\n", 709 | "f1_micro, f1_weight, acc = get_metrics(preds, methyl.index, X_test, y_test)\n", 710 | "\n", 711 | "print(\"methyl+expr+protein f1_micro: \", f1_micro)\n", 712 | "print(\"methyl+expr+protein f1_weight: \", f1_weight)\n", 713 | "print(\"methyl+expr+protein acc: \", acc)" 714 | ] 715 | }, 716 | { 717 | "cell_type": "markdown", 718 | "metadata": {}, 719 | "source": [ 720 | "## Now extract the feature importance for the supervised classification; the extracted feature importance will be saved in the result dir" 721 | ] 722 | }, 723 | { 724 | "cell_type": "markdown", 725 | "metadata": {}, 726 | "source": [ 727 | "### If you want the interpret the feature importance toward a specifit task, you can modify the custom_forward in the predictor.interpret_supervised" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": null, 733 | "metadata": {}, 734 | "outputs": [], 735 | "source": [ 736 | "df_list = predictor.interpret_supervised(model_path=model_path, result_dir=result_dir, new_datasets=[methyl, expr, protein], modalities_names=[\"methyl\", \"expr\", \"protein\"])" 737 | ] 738 | } 739 | ], 740 | "metadata": { 741 | "kernelspec": { 742 | "display_name": "integrAO", 743 | "language": "python", 744 | "name": "python3" 745 | }, 746 | "language_info": { 747 | "codemirror_mode": { 748 | "name": "ipython", 749 | "version": 3 750 | }, 751 | "file_extension": ".py", 752 | "mimetype": "text/x-python", 753 | "name": "python", 754 | "nbconvert_exporter": "python", 755 | "pygments_lexer": "ipython3", 756 | "version": "3.10.16" 757 | } 758 | }, 759 | "nbformat": 4, 760 | "nbformat_minor": 2 761 | } 762 | -------------------------------------------------------------------------------- /tutorials/unsupervised_integration_feature_importance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Integrate simulated Cancer Omics dataset; load the saved model and perform feature importance extraction" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "/home/jma/anaconda3/envs/integrAO/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 20 | " from .autonotebook import tqdm as notebook_tqdm\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "# Import packages and IntegrAO code\n", 26 | "import numpy as np\n", 27 | "import pandas as pd\n", 28 | "import snf\n", 29 | "from sklearn.cluster import spectral_clustering\n", 30 | "from sklearn.metrics import v_measure_score\n", 31 | "import matplotlib.pyplot as plt\n", 32 | "\n", 33 | "import sys\n", 34 | "import os\n", 35 | "import argparse\n", 36 | "import torch\n", 37 | "\n", 38 | "import umap\n", 39 | "from sklearn.model_selection import train_test_split\n", 40 | "\n", 41 | "# Add the parent directory of \"integrao\" to the Python path\n", 42 | "module_path = os.path.abspath(os.path.join('../'))\n", 43 | "if module_path not in sys.path:\n", 44 | " sys.path.append(module_path)\n", 45 | " \n", 46 | "from integrao.dataset import GraphDataset\n", 47 | "from integrao.main import dist2\n", 48 | "from integrao.integrater import integrao_integrater, integrao_predictor" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# Set Hyperparameters\n", 58 | "neighbor_size = 20\n", 59 | "embedding_dims = 64\n", 60 | "fusing_iteration = 30\n", 61 | "normalization_factor = 1.0\n", 62 | "alighment_epochs = 1000\n", 63 | "beta = 1.0\n", 64 | "mu = 0.5\n", 65 | "\n", 66 | "\n", 67 | "dataset_name = 'unsupervised_integration_feature_importance'\n", 68 | "cluster_number = 15" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# create result dir\n", 78 | "result_dir = os.path.join(\n", 79 | " module_path, \"results/{}\".format(dataset_name)\n", 80 | ")\n", 81 | "if not os.path.exists(result_dir):\n", 82 | " os.makedirs(result_dir)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "## Read data" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 4, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "(500, 367)\n", 102 | "(500, 131)\n", 103 | "(500, 160)\n", 104 | "(500, 2)\n", 105 | "finish loading data!\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "testdata_dir = os.path.join(module_path, \"data/omics/\")\n", 111 | "\n", 112 | "methyl_ = os.path.join(testdata_dir, \"omics1.txt\")\n", 113 | "expr_ = os.path.join(testdata_dir, \"omics2.txt\")\n", 114 | "protein_ = os.path.join(testdata_dir, \"omics3.txt\")\n", 115 | "truelabel = os.path.join(testdata_dir, \"clusters.txt\")\n", 116 | "\n", 117 | "\n", 118 | "methyl = pd.read_csv(methyl_, index_col=0, delimiter=\"\\t\")\n", 119 | "expr = pd.read_csv(expr_, index_col=0, delimiter=\"\\t\")\n", 120 | "protein = pd.read_csv(protein_, index_col=0, delimiter=\"\\t\")\n", 121 | "truelabel = pd.read_csv(truelabel, index_col=0, delimiter=\"\\t\")\n", 122 | "\n", 123 | "methyl = np.transpose(methyl)\n", 124 | "expr = np.transpose(expr)\n", 125 | "protein = np.transpose(protein)\n", 126 | "print(methyl.shape)\n", 127 | "print(expr.shape)\n", 128 | "print(protein.shape)\n", 129 | "print(truelabel.shape)\n", 130 | "print(\"finish loading data!\")" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "## Random sub-sample the omics dataset to create an incomplete dataset" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | 
"execution_count": 5, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "ratio = 0.7\n", 147 | "\n", 148 | "full_indices = range(len(methyl))\n", 149 | "unique_indices, common_indices = train_test_split(full_indices, test_size=ratio)\n", 150 | "\n", 151 | "w1w2_indices, w3_indices = train_test_split(unique_indices, test_size=0.33)\n", 152 | "w1_indices, w2_indices = train_test_split(w1w2_indices, test_size=0.5)\n", 153 | "\n", 154 | "w1_full_indices = common_indices + w1_indices\n", 155 | "w2_full_indices = common_indices + w2_indices\n", 156 | "w3_full_indices = common_indices + w3_indices\n", 157 | "\n", 158 | "methyl_temp = methyl.iloc[w1_full_indices]\n", 159 | "expr_temp = expr.iloc[w2_full_indices]\n", 160 | "protein_temp = protein.iloc[w3_full_indices]\n" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "## IntegrAO integration" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 6, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "Start indexing input expression matrices!\n", 180 | "Common sample between view0 and view1: 350\n", 181 | "Common sample between view0 and view2: 350\n", 182 | "Common sample between view1 and view2: 350\n", 183 | "Neighbor size: 20\n", 184 | "Start applying diffusion!\n", 185 | "Diffusion ends! Times: 4.647532224655151s\n", 186 | "Starting unsupervised exmbedding extraction!\n", 187 | "Dataset 0: (400, 367)\n", 188 | "Dataset 1: (400, 131)\n", 189 | "Dataset 2: (400, 160)\n", 190 | "epoch 0: loss 30.287778854370117, align_loss:0.747223\n", 191 | "epoch 100: loss 20.88467025756836, align_loss:0.178473\n", 192 | "epoch 200: loss 1.1458323001861572, align_loss:0.092696\n", 193 | "epoch 300: loss 1.144501805305481, align_loss:0.091768\n", 194 | "epoch 400: loss 1.1429835557937622, align_loss:0.090755\n", 195 | "epoch 500: loss 1.1412932872772217, align_loss:0.089755\n", 196 | "epoch 600: loss 1.1394864320755005, align_loss:0.088606\n", 197 | "epoch 700: loss 1.1375226974487305, align_loss:0.087426\n", 198 | "epoch 800: loss 1.135468602180481, align_loss:0.086274\n", 199 | "epoch 900: loss 1.1333705186843872, align_loss:0.085182\n", 200 | "Manifold alignment ends! 
Times: 7.6544740200042725s\n", 201 | "IntegrAO for clustering union 500 samples NMI score: 0.9794302488169736\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "# Initialize integrater\n", 207 | "integrater = integrao_integrater(\n", 208 | " [methyl_temp, expr_temp, protein_temp],\n", 209 | " dataset_name,\n", 210 | " neighbor_size=neighbor_size,\n", 211 | " embedding_dims=embedding_dims,\n", 212 | " fusing_iteration=fusing_iteration,\n", 213 | " normalization_factor=normalization_factor,\n", 214 | " alighment_epochs=alighment_epochs,\n", 215 | " beta=beta,\n", 216 | " mu=mu,\n", 217 | ")\n", 218 | "# data indexing\n", 219 | "fused_networks = integrater.network_diffusion()\n", 220 | "embeds_final, S_final, model = integrater.unsupervised_alignment()\n", 221 | "\n", 222 | "labels = spectral_clustering(S_final, n_clusters=cluster_number)\n", 223 | "\n", 224 | "true_labels = truelabel.sort_values('subjects')['cluster.id'].tolist()\n", 225 | "\n", 226 | "score_all = v_measure_score(true_labels, labels)\n", 227 | "print(\"IntegrAO for clustering union 500 samples NMI score: \", score_all)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 7, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "# save model\n", 237 | "torch.save(model.state_dict(), os.path.join(result_dir, \"model_integrao_unsupervised.pth\"))" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "## Now load the saved model and perform embedding extraction using the trained model" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 8, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "name": "stdout", 254 | "output_type": "stream", 255 | "text": [ 256 | "Start indexing input expression matrices!\n", 257 | "Common sample between view0 and view1: 500\n", 258 | "Common sample between view0 and view2: 500\n", 259 | "Common sample between view1 and view2: 500\n", 260 | "Neighbor size: 20\n", 261 | "Start applying diffusion!\n", 262 | "Diffusion ends! 
Times: 5.997241497039795s\n" 263 | ] 264 | } 265 | ], 266 | "source": [ 267 | "# Network fusion for the whole graph; make sure to use the integrao_predictor with the same hyperparameters\n", 268 | "predictor = integrao_predictor(\n", 269 | " [methyl, expr, protein],\n", 270 | " dataset_name,\n", 271 | " modalities_name_list=[\"methyl\", \"expr\", \"protein\"], \n", 272 | " neighbor_size=neighbor_size,\n", 273 | " embedding_dims=embedding_dims,\n", 274 | " fusing_iteration=fusing_iteration,\n", 275 | " normalization_factor=normalization_factor,\n", 276 | " alighment_epochs=alighment_epochs,\n", 277 | " beta=beta,\n", 278 | " mu=mu,\n", 279 | ")\n", 280 | "# data indexing\n", 281 | "fused_networks = predictor.network_diffusion()" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 9, 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "name": "stdout", 291 | "output_type": "stream", 292 | "text": [ 293 | "Loaded pre-trained model with success.\n", 294 | "IntegrAO for clustering union 500 samples NMI score: 1.0000000000000002\n" 295 | ] 296 | } 297 | ], 298 | "source": [ 299 | "# load model and inference for obtaining the patient embeddings\n", 300 | "model_path = os.path.join(result_dir, \"model_integrao_unsupervised.pth\")\n", 301 | "final_embedding_df, S_final = predictor.inference_unsupervised(model_path, new_datasets=[methyl, expr, protein], modalities_names=[\"methyl\", \"expr\", \"protein\"])\n", 302 | "\n", 303 | "labels = spectral_clustering(S_final, n_clusters=cluster_number)\n", 304 | "\n", 305 | "true_labels = truelabel.sort_values('subjects')['cluster.id'].tolist()\n", 306 | "\n", 307 | "score_all = v_measure_score(true_labels, labels)\n", 308 | "print(\"IntegrAO for clustering union 500 samples NMI score: \", score_all)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "## Now extract the feature importance for the unsupervised integration; the extracted feature importance will be saved in the result dir" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "df_list = predictor.interpret_unsupervised(model_path=model_path, result_dir=result_dir, new_datasets=[methyl, expr, protein], modalities_names=[\"methyl\", \"expr\", \"protein\"])" 325 | ] 326 | } 327 | ], 328 | "metadata": { 329 | "kernelspec": { 330 | "display_name": "integrAO", 331 | "language": "python", 332 | "name": "python3" 333 | }, 334 | "language_info": { 335 | "codemirror_mode": { 336 | "name": "ipython", 337 | "version": 3 338 | }, 339 | "file_extension": ".py", 340 | "mimetype": "text/x-python", 341 | "name": "python", 342 | "nbconvert_exporter": "python", 343 | "pygments_lexer": "ipython3", 344 | "version": "3.10.16" 345 | } 346 | }, 347 | "nbformat": 4, 348 | "nbformat_minor": 2 349 | } 350 | --------------------------------------------------------------------------------
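A note on the `- 1` / `+ 1` label shifts in the supervised tutorial above: `clusters.txt` stores cluster ids 1-15, while PyTorch's `CrossEntropyLoss` expects class indices in `[0, num_classes - 1]`, and an out-of-range target is what triggers the opaque CUDA device-side assert the notebook comment alludes to. Below is a minimal sketch of that round trip in plain PyTorch, independent of IntegrAO:

```python
import torch
import torch.nn.functional as F

num_classes = 15
raw_ids = torch.tensor([1, 7, 15])       # cluster ids as stored in clusters.txt
targets = raw_ids - 1                    # shift to 0..14 before fine-tuning

logits = torch.randn(3, num_classes)     # stand-in for the pred_head outputs
loss = F.cross_entropy(logits, targets)  # in-range targets, so no CUDA assert

preds = logits.argmax(dim=1) + 1         # shift back to 1..15, as get_metrics does
print(loss.item(), preds.tolist())
```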
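The sample counts logged by the unsupervised tutorial likewise follow from its subsampling arithmetic: with 500 samples and `ratio = 0.7`, 350 samples stay common to all three modalities and the remaining 150 are split into three unique sets of 50, so each modality ends up with 400 samples. That is why the log reports `Common sample between view0 and view1: 350` yet `Dataset 0: (400, 367)`. A standalone recomputation of those numbers (the fixed `random_state` is added here only for reproducibility; the tutorial leaves it unset):

```python
from sklearn.model_selection import train_test_split

full_indices = list(range(500))

# 70% of samples are shared by all three modalities, 30% are unique
unique_idx, common_idx = train_test_split(full_indices, test_size=0.7, random_state=0)

# the unique samples are split into roughly equal thirds, one per modality
w1w2_idx, w3_idx = train_test_split(unique_idx, test_size=0.33, random_state=0)
w1_idx, w2_idx = train_test_split(w1w2_idx, test_size=0.5, random_state=0)

print(len(common_idx))                        # 350 shared samples
print(len(w1_idx), len(w2_idx), len(w3_idx))  # 50 50 50 unique samples
print(len(common_idx) + len(w1_idx))          # 400 samples in modality 1
```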
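Both tutorials end with `predictor.interpret_supervised` / `predictor.interpret_unsupervised`, and the supervised notebook notes that the `custom_forward` inside `interpret_supervised` can be modified to target a specific task. IntegrAO's interpretation internals are not shown in this listing, so the following is only a hypothetical sketch of the usual shape of such a custom forward, written against Captum's `IntegratedGradients`; whether IntegrAO itself uses Captum is an assumption, and `toy_model` and `target_class` are illustrative stand-ins, not IntegrAO API:

```python
# Hypothetical sketch: class-targeted attribution via a custom forward.
# Requires the captum package; toy_model stands in for a trained network.
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

toy_model = nn.Sequential(nn.Linear(160, 64), nn.ReLU(), nn.Linear(64, 15))
toy_model.eval()

def custom_forward(x, target_class=0):
    # Returning a single class logit makes the attribution explain that
    # class alone; swap in whatever scalar output matches your task.
    return toy_model(x)[:, target_class]

ig = IntegratedGradients(custom_forward)
inputs = torch.randn(8, 160)         # e.g. one batch of protein-level features
attributions = ig.attribute(inputs)  # (8, 160) per-feature importance scores
print(attributions.shape)
```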