├── .gitignore ├── README.md ├── part1-get-data-and-non-graph-modeling-prep.ipynb ├── part2-simple-non-graph-model-and-pca.ipynb ├── part3-prepare-papers-for-import.ipynb ├── part4-prepare-authors-and-inst-for-import.ipynb ├── part5-admin-import.md ├── part6-analysis-in-neo4j-gds.ipynb ├── part7-graph-feature-engineering-in-gds.ipynb └── part8-graph-feature-model.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | scratch/ 3 | .idea/ 4 | .ipynb_checkpoints/ 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GDS Webinar Demo: Graph Data Science for Really Big Data 2 | 3 | This repo contains demo code from the February 2022 GDS webinar, "Graph Data Science for Really Big Data". The exact steps here may vary slightly from what you saw in the webinar (for example, most of the commands have been placed in notebooks), but the overall workflow should be the same. 4 | 5 | The purpose of this demo is to explore engineering graph features with Neo4j and the [Graph Data Science (GDS) Library](https://neo4j.com/docs/graph-data-science/current/) on a large dataset to see whether they improve accuracy on a classification problem. 6 | 7 | The graph used here is the [MAG240M OGB Large-Scale-Challenge Graph](https://ogb.stanford.edu/docs/lsc/mag240m/). It is a heterogeneous academic paper graph that contains around 240 million nodes and 1.7 billion relationships. 8 | 9 | ## Demo Outline and Notebook Parts 10 | 11 | This demo walks through multiple steps: running a reference model without graph features, formatting and importing the data into Neo4j, analyzing the graph and engineering graph features with GDS, and exporting the data to re-run the model with those graph features. 12 | 13 | The demo is split into 8 parts, 7 of which are Jupyter notebooks.
The file names should indicate what each part covers: 14 | 15 | - Parts 1 and 2 focus on understanding the data and running a classification model with the available (non-graph) features before leveraging Neo4j, GDS, or the graph 16 | - `part1-get-data-and-non-graph-modeling-prep.ipynb` 17 | - `part2-simple-non-graph-model-and-pca.ipynb` 18 | 19 | - Parts 3-5 focus on pre-formatting the data and importing it into the graph 20 | - `part3-prepare-papers-for-import.ipynb` 21 | - `part4-prepare-authors-and-inst-for-import.ipynb` 22 | - `part5-admin-import.md` 23 | 24 | - Parts 6 and 7 focus on work in Neo4j and GDS. Part 6 mostly inspects the graph and demos native projections and the Weakly Connected Components (WCC) algorithm. Part 7 generates and exports the graph features (FastRP node embeddings) 25 | - `part6-analysis-in-neo4j-gds.ipynb` 26 | - `part7-graph-feature-engineering-in-gds.ipynb` 27 | 28 | - Finally, Part 8 re-runs the classification model with the graph features (FastRP node embeddings). In this rough, exploratory first pass, classification accuracy improves by roughly 9 percentage points. 29 | - `part8-graph-feature-model.ipynb` 30 | 31 | 32 | 33 | ## Prerequisites & Environment for Running the Demo 34 | 35 | ### Software Versions 36 | - Neo4j = Enterprise Edition 4.4.3 37 | - GDS = Enterprise Edition 1.8.3 38 | - APOC = 4.4.0.3 39 | - Python = 3.9.7 40 | 41 | Important Note: Enterprise (as opposed to Community) Editions were used for both the Neo4j database and the GDS library in this demo. GDS Enterprise, in particular, provides high concurrency and optimized in-memory compression, which are not available in Community Edition and are key to performance at this scale. 42 | 43 | ### Instance 44 | This demo was run on a single AWS EC2 x1.16xlarge instance (64 vCPUs, 976 GB memory).
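The following back-of-the-envelope sketch gives a sense of the scale involved. It estimates how much memory the raw paper features alone would take if loaded fully into RAM. The paper count (121,751,666) and the float16 storage type are assumptions taken from the OGB MAG240M documentation and are not stated in this repo. The 768-dimensional encoding size matches the `paper_encoding_0` ... `paper_encoding_767` columns used in Part 1.

```python
# Back-of-the-envelope size of the MAG240M paper feature matrix.
# Assumption: NUM_PAPERS and the float16 dtype come from the OGB MAG240M docs,
# not from this repo; FEATURE_DIM matches the 768 paper_encoding_* columns.
NUM_PAPERS = 121_751_666  # paper nodes in MAG240M (assumed from OGB docs)
FEATURE_DIM = 768         # RoBERTa encoding size per paper
FLOAT16_BYTES = 2         # bytes per float16 value


def feature_matrix_bytes(rows: int, dim: int, itemsize: int) -> int:
    """Return the size in bytes of a dense rows x dim matrix of fixed-size items."""
    return rows * dim * itemsize


n_bytes = feature_matrix_bytes(NUM_PAPERS, FEATURE_DIM, FLOAT16_BYTES)
print(f"Paper features: {n_bytes / 1e9:.1f} GB ({n_bytes / 2**30:.1f} GiB)")
```

This prints about 187 GB (174 GiB) for the paper features alone, before counting any graph structure. Part 1 notes that the OGB dataset object exposes these features through numpy memmap instead of loading them eagerly, and this figure shows why.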
45 | 46 | ### Neo4j Configuration 47 | I tweaked a few settings, but the ones below are the most critical. You can update them in the Neo4j configuration file (`neo4j.conf`): 48 | 49 | - `dbms.memory.heap.max_size=760G` 50 | - `gds.export.location=/data/neo-export` # or set this to whichever directory you want Neo4j data exports written to 51 | 52 | Depending on your environment and specific needs, you may need to tune these and other settings, such as the initial heap size and the page cache size. For more details on optimizing Neo4j configuration for data science and analytics at scale, I recommend the [Graph Data Science Configuration Guide](https://neo4j.com/whitepapers/graph-data-science-configuration-guide/). 53 | 54 | 55 | ## Future Experimentation & Improvements 56 | 57 | This demo was just a rough first pass to explore what is possible, and there are many ways to improve upon this analysis. Here are a few areas to experiment with: 58 | 59 | 1. Improved tuning of FastRP node embeddings 60 | 2. Inclusion of more graph features 61 | 3. Streamlined data formatting and ETL 62 | 4. Better-tuned and/or more sophisticated classification models and frameworks 63 | 5.
Exploration of semi-supervised, transductive approaches to label the rest of the papers, such as the Label Propagation Algorithm (LPA) or K-Nearest Neighbors (KNN) -------------------------------------------------------------------------------- /part1-get-data-and-non-graph-modeling-prep.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "12fe1bee", 6 | "metadata": {}, 7 | "source": [ 8 | "# Part 1: Prepare Data For Non-Graph (\"Flat\") Modeling" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "50a66a9b", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from ogb.lsc import MAG240MDataset\n", 19 | "import numpy as np\n", 20 | "import os\n", 21 | "import pandas as pd" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "ec956515", 27 | "metadata": {}, 28 | "source": [ 29 | "## Notebook Setup" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "70aa26f6", 35 | "metadata": {}, 36 | "source": [ 37 | "Root directory for data storage. It will be used in the following parts as well."
38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "id": "266ad13e", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "ROOT_DATA_DIR = \"/data\"" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "id": "cc481218", 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "Directory /data already exists\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "if not os.path.exists(ROOT_DATA_DIR):\n", 66 | " os.mkdir(ROOT_DATA_DIR)\n", 67 | " print(f'Created new directory: {ROOT_DATA_DIR}')\n", 68 | "else:\n", 69 | " print(f'Directory {ROOT_DATA_DIR} already exists')" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "74bdc7db", 75 | "metadata": {}, 76 | "source": [ 77 | "### Get the Dataset Object\n", 78 | "The dataset object handles downloading and provides easy access to the data and its features. It leverages [numpy memmap](https://numpy.org/doc/stable/reference/generated/numpy.memmap.html) functionality to reference large pieces of the dataset on disk, so it does not need to load all the features into memory at once. For more information, please see the [OGB MAG240M Page](https://ogb.stanford.edu/kddcup2021/mag240m/).\n", 79 | "\n", 80 | "__Note: This command takes a while on the *first* run (several hours to a day)__ because the source data needs to be downloaded from OGB. Subsequent runs should be near instantaneous.\n" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 4, 86 | "id": "81db468a", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "dataset = MAG240MDataset(root = ROOT_DATA_DIR)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "id": "75a27300", 96 | "metadata": {}, 97 | "source": [ 98 | "## Examine Data Splitting and Labels\n", 99 | "\n", 100 | "Only a fraction of the papers (the arXiv papers) are labeled.
An `idx_split` object is provided with indexes mapping the labeled papers to the training, validation, and test sets. As we will see below, the test sets have their labels hidden, since they were held out for the original OGB-LSC competition. More information on the data and labeling process can be found on the [OGB MAG240M Page](https://ogb.stanford.edu/kddcup2021/mag240m/)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 5, 106 | "id": "100a55a5", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "#get the indexes for arXiv paper data splits\n", 111 | "split_dict = dataset.get_idx_split()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 6, 117 | "id": "4080b496", 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "------------------\n", 125 | "train index size = 1112392\n", 126 | "------------------\n", 127 | "valid index size = 138949\n", 128 | "------------------\n", 129 | "test-whole index size = 146818\n", 130 | "------------------\n", 131 | "test-dev index size = 88092\n", 132 | "------------------\n", 133 | "test-challenge index size = 58726\n" 134 | ] 135 | } 136 | ], 137 | "source": [ 138 | "#get the size of each set\n", 139 | "for i in split_dict.keys():\n", 140 | " print('------------------')\n", 141 | " print(f'{i} index size = {len(split_dict[i])}')" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 7, 147 | "id": "7466446c", 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "Paper labels for the \"train\" set:\n", 155 | "--------\n", 156 | "Sample values = [17. 29. 38. 5. 1.]\n", 157 | "Number non-missing = 1112392\n", 158 | "============================\n", 159 | "\n", 160 | "Paper labels for the \"valid\" set:\n", 161 | "--------\n", 162 | "Sample values = [140. 129. 33. 59.
24.]\n", 163 | "Number non-missing = 138949\n", 164 | "============================\n", 165 | "\n", 166 | "Paper labels for the \"test-whole\" set:\n", 167 | "--------\n", 168 | "Sample values = [-1. -1. -1. -1. -1.]\n", 169 | "Number non-missing = 0\n", 170 | "============================\n", 171 | "\n", 172 | "Paper labels for the \"test-dev\" set:\n", 173 | "--------\n", 174 | "Sample values = [-1. -1. -1. -1. -1.]\n", 175 | "Number non-missing = 0\n", 176 | "============================\n", 177 | "\n", 178 | "Paper labels for the \"test-challenge\" set:\n", 179 | "--------\n", 180 | "Sample values = [-1. -1. -1. -1. -1.]\n", 181 | "Number non-missing = 0\n", 182 | "============================\n", 183 | "\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "# Note that we only have known labels in the train and valid sets.\n", 189 | "# A value of -1 implies a hidden label\n", 190 | "for i in split_dict.keys():\n", 191 | " paper_labels = dataset.paper_label[split_dict[i]]\n", 192 | " print(f'Paper labels for the \"{i}\" set:')\n", 193 | " print('--------')\n", 194 | " print(f'Sample values = {paper_labels[:5]}')\n", 195 | " print(f'Number non-missing = {(paper_labels > -1).sum()}')\n", 196 | " print('============================\\n')" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "id": "439a60ae", 202 | "metadata": {}, 203 | "source": [ 204 | "## Building a DataFrame for Supervised Model Testing\n", 205 | "\n", 206 | "We will use the 'train' and 'valid' sets for pre-graph supervised model analysis\n", 207 | "since they are the only ones with labels." 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 8, 213 | "id": "30a40151", 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "#get the training set\n", 218 | "feat_cols = [f'paper_encoding_{i}' for i in range(768)]\n", 219 | "paper_df_train = pd.DataFrame(dataset.paper_feat[split_dict['train']], columns = feat_cols)\n", 220 |
"paper_df_train['split_segment'] = 'TRAIN'\n", 221 | "paper_df_train['paper_subject'] = dataset.paper_label[split_dict['train']]\n", 222 | "paper_df_train['paper_year'] = dataset.paper_year[split_dict['train']]" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 9, 228 | "id": "9c0f485e", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "#get the validation set\n", 233 | "paper_df_validate = pd.DataFrame(dataset.paper_feat[split_dict['valid']], columns = feat_cols)\n", 234 | "paper_df_validate['split_segment'] = 'VALIDATE'\n", 235 | "paper_df_validate['paper_subject'] = dataset.paper_label[split_dict['valid']]\n", 236 | "paper_df_validate['paper_year'] = dataset.paper_year[split_dict['valid']]" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 10, 242 | "id": "49a7873b", 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "data": { 247 | "text/html": [ 248 | "
\n", "
" 558 | ], 559 | "text/plain": [ 560 | " paper_encoding_0 paper_encoding_1 paper_encoding_2 \\\n", 561 | "0 0.438477 0.211060 0.393311 \n", 562 | "1 0.468994 -0.202637 0.023331 \n", 563 | "2 0.047485 -0.398682 -0.420410 \n", 564 | "3 -0.395508 -0.464355 -0.336670 \n", 565 | "4 0.103210 -0.125122 0.039490 \n", 566 | "... ... ... ... \n", 567 | "138944 -0.254883 -0.069885 -0.821289 \n", 568 | "138945 0.667480 -0.046448 0.194214 \n", 569 | "138946 0.660645 -0.515137 -0.776367 \n", 570 | "138947 0.427246 -0.276855 -0.203857 \n", 571 | "138948 0.123779 -0.914551 -0.064758 \n", 572 | "\n", 573 | " paper_encoding_3 paper_encoding_4 paper_encoding_5 \\\n", 574 | "0 0.055969 -0.078003 -0.017807 \n", 575 | "1 0.535645 0.496582 0.024368 \n", 576 | "2 0.882324 -0.114685 0.607910 \n", 577 | "3 -0.156616 -0.396240 -0.449951 \n", 578 | "4 0.651855 0.279053 0.020828 \n", 579 | "... ... ... ... \n", 580 | "138944 1.201172 -0.639160 -0.368164 \n", 581 | "138945 0.251953 0.003784 0.495361 \n", 582 | "138946 0.222412 -1.073242 0.049652 \n", 583 | "138947 0.391113 -0.368896 -0.091003 \n", 584 | "138948 0.073181 -0.243286 -0.610840 \n", 585 | "\n", 586 | " paper_encoding_6 paper_encoding_7 paper_encoding_8 \\\n", 587 | "0 0.553223 -0.319824 0.394043 \n", 588 | "1 0.239990 0.539551 0.460449 \n", 589 | "2 0.151001 0.124695 -0.012108 \n", 590 | "3 -0.033630 0.393066 0.552246 \n", 591 | "4 0.325439 -0.004528 0.264404 \n", 592 | "... ... ... ... \n", 593 | "138944 0.802246 -0.076355 0.324219 \n", 594 | "138945 0.756348 -0.065125 -0.071777 \n", 595 | "138946 0.335205 0.281982 1.385742 \n", 596 | "138947 1.030273 0.415039 0.506836 \n", 597 | "138948 0.522461 -0.022171 -0.621094 \n", 598 | "\n", 599 | " paper_encoding_9 ... paper_encoding_761 paper_encoding_762 \\\n", 600 | "0 0.502930 ... -0.052490 1.092773 \n", 601 | "1 0.078491 ... -0.132812 1.125977 \n", 602 | "2 -0.005211 ... -0.130127 -0.121155 \n", 603 | "3 -0.076782 ... 0.149780 1.133789 \n", 604 | "4 0.178101 ... 
0.056824 0.499023 \n", 605 | "... ... ... ... ... \n", 606 | "138944 0.030365 ... 0.553223 0.697754 \n", 607 | "138945 0.123657 ... -0.436279 1.187500 \n", 608 | "138946 0.360840 ... -0.331787 -0.043549 \n", 609 | "138947 0.121399 ... -0.198730 1.151367 \n", 610 | "138948 0.040161 ... -0.159180 0.350342 \n", 611 | "\n", 612 | " paper_encoding_763 paper_encoding_764 paper_encoding_765 \\\n", 613 | "0 0.157227 -1.467773 -1.590820 \n", 614 | "1 0.368164 -0.191406 -0.378418 \n", 615 | "2 0.790527 -0.147827 -0.451904 \n", 616 | "3 0.386230 0.066162 0.760742 \n", 617 | "4 0.038788 0.906250 -0.623047 \n", 618 | "... ... ... ... \n", 619 | "138944 0.369629 1.799805 -0.534180 \n", 620 | "138945 0.360596 -1.391602 -0.752930 \n", 621 | "138946 0.609863 0.025223 0.232422 \n", 622 | "138947 0.054382 -0.266113 -0.600098 \n", 623 | "138948 -0.007122 -0.407715 -0.976562 \n", 624 | "\n", 625 | " paper_encoding_766 paper_encoding_767 split_segment paper_subject \\\n", 626 | "0 0.328613 0.332275 TRAIN 17.0 \n", 627 | "1 0.031616 -0.311523 TRAIN 29.0 \n", 628 | "2 0.516602 -0.135986 TRAIN 38.0 \n", 629 | "3 0.355469 -0.658691 TRAIN 5.0 \n", 630 | "4 -0.119080 0.394043 TRAIN 1.0 \n", 631 | "... ... ... ... ... \n", 632 | "138944 -0.112244 -0.230713 VALIDATE 51.0 \n", 633 | "138945 -0.068970 0.195923 VALIDATE 12.0 \n", 634 | "138946 0.211304 0.060333 VALIDATE 18.0 \n", 635 | "138947 0.258057 0.411377 VALIDATE 72.0 \n", 636 | "138948 0.545410 0.017975 VALIDATE 142.0 \n", 637 | "\n", 638 | " paper_year \n", 639 | "0 2014 \n", 640 | "1 2014 \n", 641 | "2 2015 \n", 642 | "3 2005 \n", 643 | "4 2013 \n", 644 | "... ... 
\n", 645 | "138944 2019 \n", 646 | "138945 2019 \n", 647 | "138946 2019 \n", 648 | "138947 2019 \n", 649 | "138948 2019 \n", 650 | "\n", 651 | "[1251341 rows x 771 columns]" 652 | ] 653 | }, 654 | "execution_count": 10, 655 | "metadata": {}, 656 | "output_type": "execute_result" 657 | } 658 | ], 659 | "source": [ 660 | "#join\n", 661 | "paper_df = pd.concat([paper_df_train, paper_df_validate])\n", 662 | "paper_df" 663 | ] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "execution_count": 11, 668 | "id": "f6b3f1af", 669 | "metadata": {}, 670 | "outputs": [], 671 | "source": [ 672 | "#write to Parquet so we do not need to repeat this process...keep the index\n", 673 | "paper_df.to_parquet(ROOT_DATA_DIR + \"/ogb-labeled-papers.parquet\", engine='fastparquet', index=True)" 674 | ] 675 | } 676 | ], 677 | "metadata": { 678 | "kernelspec": { 679 | "display_name": "Python 3 (ipykernel)", 680 | "language": "python", 681 | "name": "python3" 682 | }, 683 | "language_info": { 684 | "codemirror_mode": { 685 | "name": "ipython", 686 | "version": 3 687 | }, 688 | "file_extension": ".py", 689 | "mimetype": "text/x-python", 690 | "name": "python", 691 | "nbconvert_exporter": "python", 692 | "pygments_lexer": "ipython3", 693 | "version": "3.9.7" 694 | } 695 | }, 696 | "nbformat": 4, 697 | "nbformat_minor": 5 698 | } 699 | -------------------------------------------------------------------------------- /part2-simple-non-graph-model-and-pca.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "750e7f38", 6 | "metadata": {}, 7 | "source": [ 8 | "# Part 2: \"Flat\" Model - Logistic Regression with RoBERTa Encodings and PCA" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "6664d358", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd\n", 19 | "from sklearn.linear_model import LogisticRegression" 20 | ] 21 | }, 22 | { 
23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "id": "70cbb3fa", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "ROOT_DATA_DIR = \"/data\"" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "id": "6ea8f478", 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "data": { 40 | "text/html": [ 41 | "
\n", "
" 375 | ], 376 | "text/plain": [ 377 | " paper_encoding_0 paper_encoding_1 paper_encoding_2 \\\n", 378 | "index \n", 379 | "0 0.438477 0.211060 0.393311 \n", 380 | "1 0.468994 -0.202637 0.023331 \n", 381 | "2 0.047485 -0.398682 -0.420410 \n", 382 | "3 -0.395508 -0.464355 -0.336670 \n", 383 | "4 0.103210 -0.125122 0.039490 \n", 384 | "... ... ... ... \n", 385 | "138944 -0.254883 -0.069885 -0.821289 \n", 386 | "138945 0.667480 -0.046448 0.194214 \n", 387 | "138946 0.660645 -0.515137 -0.776367 \n", 388 | "138947 0.427246 -0.276855 -0.203857 \n", 389 | "138948 0.123779 -0.914551 -0.064758 \n", 390 | "\n", 391 | " paper_encoding_3 paper_encoding_4 paper_encoding_5 \\\n", 392 | "index \n", 393 | "0 0.055969 -0.078003 -0.017807 \n", 394 | "1 0.535645 0.496582 0.024368 \n", 395 | "2 0.882324 -0.114685 0.607910 \n", 396 | "3 -0.156616 -0.396240 -0.449951 \n", 397 | "4 0.651855 0.279053 0.020828 \n", 398 | "... ... ... ... \n", 399 | "138944 1.201172 -0.639160 -0.368164 \n", 400 | "138945 0.251953 0.003784 0.495361 \n", 401 | "138946 0.222412 -1.073242 0.049652 \n", 402 | "138947 0.391113 -0.368896 -0.091003 \n", 403 | "138948 0.073181 -0.243286 -0.610840 \n", 404 | "\n", 405 | " paper_encoding_6 paper_encoding_7 paper_encoding_8 \\\n", 406 | "index \n", 407 | "0 0.553223 -0.319824 0.394043 \n", 408 | "1 0.239990 0.539551 0.460449 \n", 409 | "2 0.151001 0.124695 -0.012108 \n", 410 | "3 -0.033630 0.393066 0.552246 \n", 411 | "4 0.325439 -0.004528 0.264404 \n", 412 | "... ... ... ... \n", 413 | "138944 0.802246 -0.076355 0.324219 \n", 414 | "138945 0.756348 -0.065125 -0.071777 \n", 415 | "138946 0.335205 0.281982 1.385742 \n", 416 | "138947 1.030273 0.415039 0.506836 \n", 417 | "138948 0.522461 -0.022171 -0.621094 \n", 418 | "\n", 419 | " paper_encoding_9 ... paper_encoding_761 paper_encoding_762 \\\n", 420 | "index ... \n", 421 | "0 0.502930 ... -0.052490 1.092773 \n", 422 | "1 0.078491 ... -0.132812 1.125977 \n", 423 | "2 -0.005211 ... 
-0.130127 -0.121155 \n", 424 | "3 -0.076782 ... 0.149780 1.133789 \n", 425 | "4 0.178101 ... 0.056824 0.499023 \n", 426 | "... ... ... ... ... \n", 427 | "138944 0.030365 ... 0.553223 0.697754 \n", 428 | "138945 0.123657 ... -0.436279 1.187500 \n", 429 | "138946 0.360840 ... -0.331787 -0.043549 \n", 430 | "138947 0.121399 ... -0.198730 1.151367 \n", 431 | "138948 0.040161 ... -0.159180 0.350342 \n", 432 | "\n", 433 | " paper_encoding_763 paper_encoding_764 paper_encoding_765 \\\n", 434 | "index \n", 435 | "0 0.157227 -1.467773 -1.590820 \n", 436 | "1 0.368164 -0.191406 -0.378418 \n", 437 | "2 0.790527 -0.147827 -0.451904 \n", 438 | "3 0.386230 0.066162 0.760742 \n", 439 | "4 0.038788 0.906250 -0.623047 \n", 440 | "... ... ... ... \n", 441 | "138944 0.369629 1.799805 -0.534180 \n", 442 | "138945 0.360596 -1.391602 -0.752930 \n", 443 | "138946 0.609863 0.025223 0.232422 \n", 444 | "138947 0.054382 -0.266113 -0.600098 \n", 445 | "138948 -0.007122 -0.407715 -0.976562 \n", 446 | "\n", 447 | " paper_encoding_766 paper_encoding_767 split_segment paper_subject \\\n", 448 | "index \n", 449 | "0 0.328613 0.332275 TRAIN 17.0 \n", 450 | "1 0.031616 -0.311523 TRAIN 29.0 \n", 451 | "2 0.516602 -0.135986 TRAIN 38.0 \n", 452 | "3 0.355469 -0.658691 TRAIN 5.0 \n", 453 | "4 -0.119080 0.394043 TRAIN 1.0 \n", 454 | "... ... ... ... ... \n", 455 | "138944 -0.112244 -0.230713 VALIDATE 51.0 \n", 456 | "138945 -0.068970 0.195923 VALIDATE 12.0 \n", 457 | "138946 0.211304 0.060333 VALIDATE 18.0 \n", 458 | "138947 0.258057 0.411377 VALIDATE 72.0 \n", 459 | "138948 0.545410 0.017975 VALIDATE 142.0 \n", 460 | "\n", 461 | " paper_year \n", 462 | "index \n", 463 | "0 2014 \n", 464 | "1 2014 \n", 465 | "2 2015 \n", 466 | "3 2005 \n", 467 | "4 2013 \n", 468 | "... ... 
\n", 469 | "138944 2019 \n", 470 | "138945 2019 \n", 471 | "138946 2019 \n", 472 | "138947 2019 \n", 473 | "138948 2019 \n", 474 | "\n", 475 | "[1251341 rows x 771 columns]" 476 | ] 477 | }, 478 | "execution_count": 3, 479 | "metadata": {}, 480 | "output_type": "execute_result" 481 | } 482 | ], 483 | "source": [ 484 | "papers_df = pd.read_parquet(ROOT_DATA_DIR + \"/ogb-labeled-papers.parquet\", engine='fastparquet')\n", 485 | "papers_df" 486 | ] 487 | }, 488 | { 489 | "cell_type": "markdown", 490 | "id": "9fdbbff1", 491 | "metadata": {}, 492 | "source": [ 493 | "## Data Split and Subject Label Stats" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 4, 499 | "id": "c7b27e10", 500 | "metadata": {}, 501 | "outputs": [ 502 | { 503 | "data": { 504 | "text/html": [ 505 | "
\n", 506 | "\n", 519 | "\n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | "
\n", 541 | "
" 542 | ], 543 | "text/plain": [ 544 | " paper_encoding_0\n", 545 | "split_segment \n", 546 | "TRAIN 1112392\n", 547 | "VALIDATE 138949" 548 | ] 549 | }, 550 | "execution_count": 4, 551 | "metadata": {}, 552 | "output_type": "execute_result" 553 | } 554 | ], 555 | "source": [ 556 | "papers_df[['split_segment', 'paper_encoding_0']].groupby('split_segment').count()" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 5, 562 | "id": "3e94ffa3", 563 | "metadata": {}, 564 | "outputs": [ 565 | { 566 | "data": { 567 | "text/html": [ 568 | "
\n", 569 | "\n", 582 | "\n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | "
\n", 640 | "
\n", 641 | "
" 642 | ], 643 | "text/plain": [ 644 | " paper_encoding_0\n", 645 | "paper_subject \n", 646 | "0.0 28041\n", 647 | "1.0 2856\n", 648 | "2.0 3907\n", 649 | "3.0 1530\n", 650 | "4.0 1910\n", 651 | "... ...\n", 652 | "148.0 865\n", 653 | "149.0 815\n", 654 | "150.0 837\n", 655 | "151.0 22696\n", 656 | "152.0 1139\n", 657 | "\n", 658 | "[153 rows x 1 columns]" 659 | ] 660 | }, 661 | "execution_count": 5, 662 | "metadata": {}, 663 | "output_type": "execute_result" 664 | } 665 | ], 666 | "source": [ 667 | "papers_df[['paper_subject', 'paper_encoding_0']].groupby('paper_subject').count()" 668 | ] 669 | }, 670 | { 671 | "cell_type": "markdown", 672 | "id": "88beb827", 673 | "metadata": {}, 674 | "source": [ 675 | "## Logistic Regression Using Entire 768 Dimensional Encoding\n", 676 | "\n", 677 | "\n", 678 | "As a first pass we will try to fit this model with simple logistic regression using just the 768 dimensional RoBERTa encoding vectors as features. \n", 679 | "\n", 680 | "__Note: this model fitting step can take a while (several hours) to complete__\n", 681 | "\n", 682 | "We will get convergence warnings when running the below model model. I tried various different parameters to try and avoid this in sklearn but could not seem to do so. In a more rigorous setting I would recommend looking deeper into tuning parameters, different model types, different machine learning libraries/frameworks, etc. But for purposes of this demo we are just trying to get an initial rough benchmark. In the following sections we will apply a very simple solution of dimensionality reduction with Principal Components Analysis (PCA) to see the effect on results. 
" 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "execution_count": 6, 688 | "id": "aa8a82dc", 689 | "metadata": {}, 690 | "outputs": [], 691 | "source": [ 692 | "papers_df = papers_df.astype({'paper_subject':'int32'})" 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "execution_count": 7, 698 | "id": "4863519d", 699 | "metadata": {}, 700 | "outputs": [], 701 | "source": [ 702 | "X = papers_df[['paper_encoding_' + str(x) for x in range(768)]]\n", 703 | "y = papers_df.paper_subject" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": 8, 709 | "id": "486dfa78", 710 | "metadata": {}, 711 | "outputs": [], 712 | "source": [ 713 | "X_train = X[papers_df.split_segment == \"TRAIN\"]\n", 714 | "X_validate = X[papers_df.split_segment == \"VALIDATE\"]\n", 715 | "y_train = y[papers_df.split_segment == \"TRAIN\"]\n", 716 | "y_validate = y[papers_df.split_segment == \"VALIDATE\"]" 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": 9, 722 | "id": "e3d95251", 723 | "metadata": {}, 724 | "outputs": [], 725 | "source": [ 726 | "model = LogisticRegression(multi_class='ovr', solver='saga', n_jobs=60, max_iter=200)" 727 | ] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": 10, 732 | "id": "91cba657", 733 | "metadata": {}, 734 | "outputs": [ 735 | { 736 | "name": "stderr", 737 | "output_type": "stream", 738 | "text": [ 739 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 740 | " warnings.warn(\n", 741 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 742 | " warnings.warn(\n", 743 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means 
the coef_ did not converge\n", 744 | " warnings.warn(\n", 745 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 746 | " warnings.warn(\n", 747 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 748 | " warnings.warn(\n", 749 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 750 | " warnings.warn(\n", 751 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 752 | " warnings.warn(\n", 753 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 754 | " warnings.warn(\n", 755 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 756 | " warnings.warn(\n", 757 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 758 | " warnings.warn(\n", 759 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 760 | " warnings.warn(\n", 761 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 762 | " 
warnings.warn(\n", 763 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 764 | " warnings.warn(\n", 765 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 766 | " warnings.warn(\n", 767 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 768 | " warnings.warn(\n", 769 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 770 | " warnings.warn(\n", 771 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 772 | " warnings.warn(\n", 773 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 774 | " warnings.warn(\n", 775 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 776 | " warnings.warn(\n", 777 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 778 | " warnings.warn(\n", 779 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 780 | " warnings.warn(\n", 781 | 
"/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 782 | " warnings.warn(\n", 783 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 784 | " warnings.warn(\n", 785 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 786 | " warnings.warn(\n", 787 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 788 | " warnings.warn(\n", 789 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 790 | " warnings.warn(\n", 791 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 792 | " warnings.warn(\n", 793 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 794 | " warnings.warn(\n", 795 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 796 | " warnings.warn(\n", 797 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 798 | " warnings.warn(\n", 799 | 
"/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 800 | " warnings.warn(\n", 801 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 802 | " warnings.warn(\n", 803 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 804 | " warnings.warn(\n", 805 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 806 | " warnings.warn(\n", 807 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/sklearn/linear_model/_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n", 808 | " warnings.warn(\n" 809 | ] 810 | }, 811 | { 812 | "data": { 813 | "text/plain": [ 814 | "LogisticRegression(max_iter=200, multi_class='ovr', n_jobs=60, solver='saga')" 815 | ] 816 | }, 817 | "execution_count": 10, 818 | "metadata": {}, 819 | "output_type": "execute_result" 820 | } 821 | ], 822 | "source": [ 823 | "#Note: This can take a while (several hours)\n", 824 | "model.fit(X_train, y_train)" 825 | ] 826 | }, 827 | { 828 | "cell_type": "code", 829 | "execution_count": 11, 830 | "id": "a826e1b8", 831 | "metadata": {}, 832 | "outputs": [ 833 | { 834 | "name": "stdout", 835 | "output_type": "stream", 836 | "text": [ 837 | "Accuracy of logistic regression classifier on VALIDATE set: 0.49\n" 838 | ] 839 | } 840 | ], 841 | "source": [ 842 | "print('Accuracy of logistic regression classifier on VALIDATE set: {:.2f}'\\\n", 843 | " .format(model.score(X_validate, y_validate)))" 844 | ] 845 | }, 846 | { 847 | "cell_type": "markdown", 
848 | "id": "67621ceb", 849 | "metadata": {}, 850 | "source": [ 851 | "## Reducing Dimensionality with Principal Components Analysis (PCA)" 852 | ] 853 | }, 854 | { 855 | "cell_type": "code", 856 | "execution_count": 12, 857 | "id": "1cea5185", 858 | "metadata": {}, 859 | "outputs": [ 860 | { 861 | "data": { 862 | "text/plain": [ 863 | "PCA()" 864 | ] 865 | }, 866 | "execution_count": 12, 867 | "metadata": {}, 868 | "output_type": "execute_result" 869 | } 870 | ], 871 | "source": [ 872 | "from sklearn.decomposition import PCA\n", 873 | "pca = PCA()\n", 874 | "pca.fit(X_train)" 875 | ] 876 | }, 877 | { 878 | "cell_type": "code", 879 | "execution_count": 13, 880 | "id": "0bcbf2b2", 881 | "metadata": {}, 882 | "outputs": [ 883 | { 884 | "data": { 885 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY8AAAEWCAYAAACe8xtsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAl20lEQVR4nO3de5wcVZ338c83CUlIuCdZRIEMV9moqDAgNxVFEVDJwy4I7LCAopFVvCzeQFweYM3u4u6CogTMgoAwchFveRBECIqIChmQOwQCBAKyEgIkIQFy+z1/1Omk05npqZpMdffMfN+vV7266lRV16/n9ptT59Q5igjMzMyKGNbsAMzMbOBx8jAzs8KcPMzMrDAnDzMzK8zJw8zMCnPyMDOzwpw8zFqcpOMl/b7ZcZhVc/KwIUfSfpL+IGmhpBcl3S5pjybHdIak5ZJekfRyim/vPrzPbyV9sowYzao5ediQImkT4Drgu8AWwJuAM4HXC77PiP6PjqsjYiNgAvB74KeSVMJ1zNabk4cNNTsDRMSVEbEyIl6NiF9HxH2VAyR9StLDkhZLekjSbql8rqSvSboPWCJphKS9Ui3hZUn3Stq/6n02lXSxpOckPSvpm5KG9xZgRCwHLgPeAIyr3S9pH0mzUs1plqR9UvlU4N3A91IN5nvr84Uyq8fJw4aaR4GVki6TdLCkzat3SjoCOAM4FtgEOBRYUHXI0cCHgc2ALYFfAt8kq8V8GfiJpAnp2EuBFcCOwDuBA4FebylJGgUcD8yLiBdq9m2RrnkeWWI5B/ilpHERcRpwG3BSRGwUESf1/uUw6xsnDxtSImIRsB8QwP8A8yXNkLRlOuSTwLciYlZk5kTEU1VvcV5EzIuIV4FjgOsj4vqIWBURNwFdwCHp/Q4BvhgRSyLieeBc4Kg64X1M0svAPGB34LBujvkw8FhEXB4RKyLiSuAR4KN9+4qY9U0Z923NWlpEPEz2nz2SdgGuAL5NVqvYBni8zunzqtYnAkdIqv7DvQHwm7RvA+C5qmaLYTXn17omIo7pJfw3Ak/VlD1F1nZj1jBOHjakRcQjki4FPp2K5gE71Dulan0ecHlEfKr2IElbkTXCj4+IFf0ULsBfyBJTtW2BX3UTn1lpfNvKhhRJu0j6kqSt0/Y2ZDWOP6VDLgK+LGl3ZXaUVPvHuuIK4KOSPiRpuKTRkvaXtHVEPAf8GvhvSZtIGiZpB0nvXc+PcD
2ws6R/SA32RwKTyHqQAfwV2H49r2HWKycPG2oWA+8C7pC0hCxpPAB8CSAifgxMBX6Ujv05WWP4OiJiHjAZ+Down6wm8hXW/F4dC4wEHgJeAq4Ftlqf4CNiAfCRFO8C4KvAR6oa1r8DHC7pJUnnrc+1zOqRJ4MyM7OiXPMwM7PCnDzMzKwwJw8zMyvMycPMzAobNM95jB8/Ptra2podhpnZgHLXXXe9EBETej9ybYMmebS1tdHV1dXsMMzMBhRJtSMW5OLbVmZmVpiTh5mZFebkYWZmhTl5mJlZYU4eZmZWmJNHZye0tcGwYdlrZ2ezIzIza3mDpqtun3R2wpQpsHRptv3UU9k2QEdH8+IyM2txQ7vmcdppaxJHxdKlWbmZmfVoaCePp58uVm5mZsBQTx7bblus3MzMgKGePKZOhQ02WLtszJis3MzMejS0k0dHBxx55JrtiRNh+nQ3lpuZ9WJoJw+APffMXk86CebOdeIwM8vByWNY+hKsWtXcOMzMBhAnDycPM7PCnDyk7NXJw8wsNycP1zzMzApz8qgkj4jmxmFmNoCUmjwkHSRptqQ5kk7pZv8oSVen/XdIakvlG0i6TNL9kh6WdGppQbrmYWZWWGnJQ9Jw4HzgYGAScLSkSTWHnQC8FBE7AucCZ6fyI4BREfE2YHfg05XE0u+cPMzMCiuz5rEnMCcinoiIZcBVwOSaYyYDl6X1a4EDJAkIYKykEcCGwDJgUSlROnmYmRVWZvJ4EzCvavuZVNbtMRGxAlgIjCNLJEuA54Cngf+KiBdrLyBpiqQuSV3z58/vW5ROHmZmhbVqg/mewErgjcB2wJckbV97UERMj4j2iGifMGFC367k5GFmVliZyeNZYJuq7a1TWbfHpFtUmwILgH8AfhURyyPieeB2oL2UKJ08zMwKKzN5zAJ2krSdpJHAUcCMmmNmAMel9cOBWyIiyG5VvR9A0lhgL+CRUqJ08jAzK6y05JHaME4CbgQeBq6JiAclnSXp0HTYxcA4SXOAk4FKd97zgY0kPUiWhC6JiPtKCdTJw8yssFLnMI+I64Hra8pOr1p/jaxbbu15r3RXXgonDzOzwlq1wbxxnDzMzArrseYh6X6y5y26FRG7lhJRozl5mJkVVu+21UfS62fT6+XpdXDNluTkYWZWWI/JIyKeApD0wYh4Z9WuUyTdzZrG7YHNycPMrLA8bR6StG/Vxj45zxsYnDzMzArL09vqBOAHkjZN2y8DnygtokZz8jAzK6zX5BERdwFvrySPiFhYelSN5JkEzcwK6/X2k6QtJV0MXBURCyVNknRCA2JrDNc8zMwKy9N2cSnZU+JvTNuPAl8sKZ7G80yCZmaF5Uke4yPiGmAVrB52ZGWpUTWSax5mZoXlSR5LJI0jPTAoaS+yeTcGBycPM7PC8vS2Opls9NsdJN0OTCAbAXdwcPIwMyssT2+ruyW9F3gzIGB2RCwvPbJGcfIwMyss76i6ewJt6fjdJBERPywtqkZy8jAzK6zX5CHpcmAH4B7WNJQH4ORhZjZE5al5tAOT0gx/g4+Th5lZYXl6Wz0AvKHsQJrGycPMrLA8NY/xwEOS7gRerxRGxKE9nzKAOHmYmRWWJ3mcUXYQTeXkYWZWWJ6uurc2IpCmcfIwMyus3jS0v4+I/SQtZu3paAVERGxSenSN4ORhZlZYvZkE90uvGzcunCZw8jAzKyzvQ4JI+htgdGU7Ip4uJaJGc/IwMyssz3weh0p6DHgSuBWYC9xQclyN48mgzMwKy/Ocx78CewGPRsR2wAHAn0qNqpFc8zAzKyxP8lgeEQuAYZKGRcRvyJ46Hxw8GZSZWWF52jxelrQR8DugU9LzwJJyw2og1zzMzArLU/OYDLwK/DPwK+Bx4KNlBtVQTh5mZoXleUiwupZxWYmxNIeTh5lZYfUeEuz24UD8kKCZ2ZBX7yHBwf1wYIWTh5lZYbkeEpS0G7AfWc3j9xHx51KjaiQnDzOzwvI8JHg6WVvHOL
Lh2S+V9I2yA2sYJw8zs8Ly1Dw6gLdHxGsAkv6DbErab5YYV+M4eZiZFZanq+5fqBrTChgFPFtOOE1w7bXZ68KF0NYGnZ1NDcfMbCDIU/NYCDwo6SayNo8PAndKOg8gIj5fYnzl6uyEz1eF/9RTMGVKtt7R0ZyYzMwGAEUvw3JIOq7e/ohoiWc/2tvbo6urq9hJbW1Zwqg1cSLMndsfYZmZtTRJd0VE4SGn8tQ8boiI52su9uaImJ0jqIOA7wDDgYsi4j9q9o8CfgjsDiwAjoyIuWnfrsD3gU2AVcAelXaXfvN0D6PK91RuZmZAvjaP2yR9rLIh6UvAz3o7SdJw4HzgYGAScLSkSTWHnQC8FBE7AucCZ6dzRwBXACdGxFuA/YHlOWItZttti5WbmRmQL3nsD/yjpB9L+h2wM7BnjvP2BOZExBMRsQy4imycrGqTWTPkybXAAZIEHAjcFxH3AkTEgohYmeOaxUydCmPGrF02ZkxWbmZmPeo1eUTEc2QDIu4NtAGXRcQrOd77TcC8qu1nUlm3x0TECrLG+XFkCSok3SjpbklfzXG94jo64MIL12xPnAjTp7ux3MysF722eUi6may77luBbYCLJf0uIr5cclz7AXsAS4GZqVFnZk1sU4ApANv29VbTMcfAscdm608+uWZmQTMz61Ge21bfi4hjI+LliLgf2IeshtCbZ8mSTcXWrPt8yOpjUjvHpmQN588Av4uIFyJiKXA9sFvtBSJiekS0R0T7hAkTcoTUjepk4QmhzMxy6TF5SNoFICJ+nnpFkbZXADfleO9ZwE6StpM0EjgKmFFzzAyg0hX4cOCWyPoO3wi8TdKYlFTeCzyU8zMV59kEzcwKqVfz+FHV+h9r9k3r7Y1TkjmJLBE8DFwTEQ9KOkvSoemwi4FxkuYAJwOnpHNfAs4hS0D3AHdHxC97/zh95CFKzMwKqdfmoR7Wu9vuVkRcT3bLqbrs9Kr114Ajejj3CrLuuuVz8jAzK6RezSN6WO9ue2Bz8jAzK6RezWPrNH6VqtZJ27Vdbgc2Jw8zs0LqJY+vVK3XDhpVcBCpFufkYWZWSL1paFtiwMOGcPIwMyskz3Meg5+Th5lZIU4e4ORhZlaQkwc4eZiZFdRr8pC0s6SZkh5I27tK+kb5oTWQk4eZWSF5ah7/A5xKmk8jIu4jG2pk8HDyMDMrJE/yGBMRd9aUrSgjmKZx8jAzKyRP8nhB0g6kp8olHQ48V2pUjebkYWZWSJ45zD8LTAd2kfQs8CRwTKlRNVplWHYnDzOzXHpNHhHxBPABSWOBYRGxuPywGsw1DzOzQvL0tvo3SZtFxJKIWCxpc0nfbERwDePkYWZWSJ42j4Mj4uXKRppr45DSImoGTwZlZlZInuQxvHomQUkbAqPqHD/wuOZhZlZIngbzTmCmpEvS9seBwTVoopOHmVkheRrMz5Z0H3BAKvrXiLix3LAazMnDzKyQPDUPIuIG4IaSY2keJw8zs0Ly9Lb6O0mPSVooaZGkxZIWNSK4hnHyMDMrJE/N41vARyPi4bKDaRonDzOzQvL0tvrroE4c4ORhZlZQnppHl6SrgZ8Dr1cKI+KnZQXVcE4eZmaF5EkemwBLgQOrygJw8jAzG6LydNX9eCMCaSonDzOzQnpNHpJGAycAbwFGV8oj4hMlxtVYTh5mZoXkaTC/HHgD8CHgVmBrYHCNrOvkYWZWSJ7ksWNE/AuwJCIuAz4MvKvcsBrMycPMrJA8yWN5en1Z0luBTYG/KS+kJvBkUGZmheTpbTVd0ubAvwAzgI2A00uNqtFc8zAzKyRPb6uL0uqtwPblhtMkTh5mZoX0mDwkHRMRV0g6ubv9EXFOeWE1mJOHmVkh9WoeY9Prxo0IpKk8k6CZWSE9Jo+I+L6k4cCiiDi3gTE1nmseZmaF1O1tFRErgaMbFEvzOHmYmRWSp7fV7ZK+B1wNLKkURsTdpUXVaE4eZmaF5Eke70ivZ1WVBfD+fo+mWZw8zMwKyd
NV932NCKRpOjvhppuy9U99Cl55BTo6mhuTmVmLy/OEOZI+LOmrkk6vLDnPO0jSbElzJJ3Szf5Rkq5O+++Q1Fazf1tJr0j6cq5PU1RnJ0yZAq++mm3Pn59td3aWcjkzs8EizxzmFwJHAp8DBBwBTMxx3nDgfOBgYBJwtKRJNYedALwUETsC5wJn1+w/B7iht2v12WmnwdKla5ctXZqVm5lZj/LUPPaJiGPJ/sifCewN7JzjvD2BORHxREQsA64CJtccMxm4LK1fCxwgZQNNSfo/wJPAgzmu1TdPP12s3MzMgHzJI93TYamkN5INlLhVjvPeBMyr2n4mlXV7TESsABYC4yRtBHwNOLPeBSRNkdQlqWv+/Pk5Qqqx7bbFys3MDMiXPK6TtBnwn8DdwFzgRyXGBHAGcG5EvFLvoIiYHhHtEdE+YcKE4leZOhXGjFm7bMyYrNzMzHpUb2yr68mSROWP+E8kXQeMjoiFOd77WWCbqu2tU1l3xzwjaQTZcO8LyOYLOVzSt4DNgFWSXouI7+X7WDlVelVNmbKm7WPDDfv1EmZmg1G9msf3ySZ+ekLSNZIOAyJn4gCYBewkaTtJI4GjyIZ0rzYDOC6tHw7cEpl3R0RbRLQB3wb+rd8TR7UVK9asL1jgHldmZr3oMXlExC8i4migDfgJcCzwtKRLJH2wtzdObRgnATcCDwPXRMSDks6SdGg67GKyNo45wMnAOt15S3faabBs2dpl7nFlZlaXosBIspJ2JesdtWtEDC8tqj5ob2+Prq6u4icOG9b9aLqSnzg3s0FP0l0R0V70vDzPeWwp6XOSbgd+TlaT2K14iC3KPa7MzArrMXlI+pSkW8h6WO0EfCUito+IUyLi3oZFWLapU2H06LXL3OPKzKyuemNb7Q38OzAzIgbv/ZuOjuyhwK9/PdueODFLHB7fysysR/UazD8RETcN6sRRcdhh2eub3wxz5zpxmJn1ItfAiIPeBhtkr7W9rszMrFv12jy2a2QgTTVyZPbq5GFmlku9mse1AJJmNiiW5nHyMDMrpF6D+TBJXwd2lnRy7c6IOKe8sBqscttq+fLmxmFmNkDUq3kcBawkSzAbd7MMHq55mJkV0mPNIyJmA2dLui8iypuQqRU4eZiZFZKnt9UfJJ1TmTdD0n9L2rT0yBqpcttqxYruhyoxM7O15EkePwAWAx9LyyLgkjKDajgJRqRKmNs9zMx6Va/BvGKHiPj7qu0zJd1TUjzNM3JkVvNYtmzNbSwzM+tWrmloJe1X2ZC0L2umph0cOjvh1fSR/vZvPZeHmVkv8tQ8TgR+WNXO8RJrJnAa+Do7s8mfKm0dzzyTbYOHKTEz60Hu+TwkbQIQEYtKjaiP+jyfR1sbPPXUuuUTJ2bjXJmZDWJ9nc8jT80DaN2ksd6efrpYuZmZeWDEHid92mKLxsZhZjaAOHlMnbrmOY9qixe74dzMrAe52jwk7QO0UXWbKyJ+WF5YxfW5zQNg/HhYsGDdcrd7mNkgV1qbh6TLgR2Ae8jGugIIoKWSx3p58cXuy93uYWbWrTwN5u3ApMjbLWsg2nbb7ntc9dQeYmY2xOVp83gAeEPZgTTV1KkwatTaZWPGZOVmZraOPMljPPCQpBslzagsZQfWUB0d8OEPr9kePhyOO84PCZqZ9SDPbaszyg6i6To74frr12yvXAmXXQb77usEYmbWjV5rHhFxK/AIayaBejiVDR6nnQavvbZ22dKlWbmZma2j1+Qh6WPAncARZEOy3yHp8LIDayg/ZW5mVkie21anAXtExPMAkiYANwPXlhlYQ7m3lZlZIXkazIdVEkeyIOd5A8fUqVnvqmrubWVm1qM8SeBXqafV8ZKOB34JXN/LOQNLR0fWu6pi2DD3tjIzqyNPg/lXgOnArmmZHhFfKzuwhurszHpXVaxalW17bCszs27lns+j1a3X2FY9zekxbhy88MJ6xWVm1sr6OrZVjzUPSb9Pr4slLapaFksaXHN79NSras
EC1z7MzLrRY/KIiP3S68YRsUnVsnFEbNK4EBugXq8qP+thZraOPM95XJ6nbECr16vKz3qYma0jT2+rt1RvSBoB7F5OOE3S0QFjx3a/r7YLr5mZ1W3zOFXSYmDX6vYO4K/ALxoWYaOMHt19+ZIlbvcwM6tRr83j34FNgR/WtHeMi4hT87y5pIMkzZY0R9Ip3ewfJenqtP8OSW2p/IOS7pJ0f3p9fx8/X349TQgFbvcwM6tR97ZVRKwC9ujLG0saDpwPHAxMAo6WNKnmsBOAlyJiR+Bc4OxU/gLw0Yh4G3AcUH4bS71Gc7d7mJmtJU+bx92S+pJA9gTmRMQTEbEMuAqYXHPMZKDydN61wAGSFBF/joi/pPIHgQ0l1czW1M+mTgWp+30e48rMbC15kse7gD9KelzSfelW0n05znsTMK9q+5lU1u0xEbECWAiMqznm74G7I+L12gtImiKpS1LX/Pnzc4RUR0cHvL+Hu2OHHLJ+721mNsjkGVX3Q6VH0QNJbyG7lXVgd/sjYjrZ0Cm0t7ev/6Pyc+Z0X3794BrKy8xsfeUZ2+opYDPgo2nZLJX15llgm6rtrVNZt8ekLsCbko3ai6StgZ8Bx0bE4zmut/66G6KkXrmZ2RCV5yHBLwCdwN+k5QpJn8vx3rOAnSRtJ2kkcBRQO/f5DLIGcYDDgVsiIiRtRjZ67ykRcXuuT9Ifhg/veZ+765qZrdbrwIipfWPviFiStscCf4yIXXt9c+kQ4NvAcOAHETFV0llAV0TMkDSarCfVO4EXgaMi4glJ3wBOBR6rersDa+YVWct6DYy4JuCe902cCHPnrt/7m5m1mL4OjJgnedxPNpPga2l7NDArdaNtGf2SPHoaXbdikIxAbGZW0e+j6la5hGze8jMknQn8Cbi46IUGhHpjXNW7pWVmNsTkaTA/B/g42W2lF4CPR8S3S46rOerNHLhyZePiMDNrcUXmIlfN6+DUUw3DNQ8zs9Xy9LY6newp8M2B8cAlqUF7cOqphuGah5nZanlqHh1kDeZnRMT/BfYC/rHcsJpo4sSe933mM42Lw8ysheVJHn8BqscrH8W6D/sNHvUazS+8sHFxmJm1sDzJYyHwoKRLJV0CPAC8LOk8SeeVG14T1Gs0j/DDgmZm5Bvb6mdpqfhtOaG0kOHDe27j+PSn6ycYM7MhoNfkERGXpeFFdk5FsyNieblhNdmUKXDBBd3vW7KksbGYmbWgPL2t9icbJuR8YBrwqKT3lBtWk02bVn+/b12Z2RCXp83jv8nGlXpvRLyHbIj2c8sNqwUMq/Ol+fSnGxeHmVkLypM8NoiI2ZWNiHgU2KC8kFpEvQSxZIlrH2Y2pOVJHndJukjS/mn5H2A9RyAcAHq7dXXaaY2Jw8ysBeVJHicCDwGfT8tDwD+VGVTLGFc7I24VTxBlZkNY3d5WkoYD90bELsA5jQmphXznO3DMMc2Owsys5dSteUTESmC2pG0bFE9r6e15Dg9XYmZDVJ6HBDcne8L8TmD1Qw4RcWhpUQ0UF1wA++7rhwbNbMjJkzz+pfQoWtm4cbBgQc/7v/AFJw8zG3J6TB5putkTgR2B+4GLI2JFowJrGb21e9RLLGZmg1S9No/LgHayxHEw2cOCQ09HB2y0Uf1j3PZhZkNMveQxKSKOiYjvA4cD725QTK2nt6HYL7jADw2a2ZBSL3msHvxwSN6uqpanTeMLXyg/DjOzFlEvebxd0qK0LAZ2raxLWtSoAFtGvRkGIWv7cO3DzIaIHpNHRAyPiE3SsnFEjKha36SRQbaEejMMVnziE+XHYWbWAvIMT2KQ3br6p15GZVm2DDbYwDUQMxv0nDyKmDat/lDtACtWZF173QPLzAYxJ4+i8s7lccEF8IEPlBuLmVmTOHkUNW0aHHBAvmNnznQCMbNBycmjL26+GUaPznfszJluBzGzQcfJo68uuij/sZV2ECcRMxsknDz6qqMDrrgChg/Pf04liUhuUDezAc3JY310dGQJIW
8bSLULLnASMbMBy8mjP9x8c+/PgPSkkkQk2Hhj39YyswHByaO/TJvW9wRS8cora25rVZYNN3RCMbOW4+TRn6ZNK94O0pvXXls3oQwb5ttdZtZUTh79rdIO0t9JpFrE2re76i2uuZhZCZw8ylKdREaObF4c3dVcWn1xzcqs5ZWaPCQdJGm2pDmSTulm/yhJV6f9d0hqq9p3aiqfLelDZcZZqo4OeP31rLbQl15ZQ1GRmpUXL16ypcH/dJWWPCQNB84nm8J2EnC0pEk1h50AvBQROwLnAmencycBRwFvAQ4CpqX3G9huvtlJxMzKUfmnq0EJpMyax57AnIh4IiKWAVcBk2uOmUw2VzrAtcABkpTKr4qI1yPiSWBOer/BoZJEIta/h5aZWbXp0xtymTKTx5uAeVXbz6Sybo9JU90uBMblPBdJUyR1SeqaP39+P4beQNOmrUkklaXZ7SRmNnCtXNmQywzoBvOImB4R7RHRPmHChGaH03+q20mqE8rYsc2OzMxaXVm9PGuUmTyeBbap2t46lXV7jKQRwKbAgpznDi0dHdlDhLW1lO4W11zMhq4pUxpymTKTxyxgJ0nbSRpJ1gA+o+aYGcBxaf1w4JaIiFR+VOqNtR2wE3BnibEOLt3VXFp9cc3KbP1IWRvqtGkNudyIst44IlZIOgm4ERgO/CAiHpR0FtAVETOAi4HLJc0BXiRLMKTjrgEeAlYAn42IxtzIs+bo6MgWMxsQlP2jP/C1t7dHV1dXs8MwMxtQJN0VEe1FzxvQDeZmZtYcTh5mZlaYk4eZmRXm5GFmZoUNmgZzSfOBp/p4+njghX4Mp785vr5r5djA8a2PVo4NBk58EyOi8FPWgyZ5rA9JXX3pbdAojq/vWjk2cHzro5Vjg8Efn29bmZlZYU4eZmZWmJNHpjFjGPed4+u7Vo4NHN/6aOXYYJDH5zYPMzMrzDUPMzMrzMnDzMwKG/LJQ9JBkmZLmiPplCbF8ANJz0t6oKpsC0k3SXosvW6eyiXpvBTvfZJ2Kzm2bST9RtJDkh6U9IUWi2+0pDsl3ZviOzOVbyfpjhTH1WlaANIw/1en8jsktZUZX7rmcEl/lnRdC8Y2V9L9ku6R1JXKWuJ7m665maRrJT0i6WFJe7dCfJLenL5mlWWRpC+2QmxVMf5z+p14QNKV6Xel/372ImLILmRDxT8ObA+MBO4FJjUhjvcAuwEPVJV9CzglrZ8CnJ3WDwFuAATsBdxRcmxbAbul9Y2BR4FJLRSfgI3S+gbAHem61wBHpfILgX9K658BLkzrRwFXN+D7ezLwI+C6tN1Ksc0FxteUtcT3Nl3zMuCTaX0ksFkrxZeuOxz4X2Biq8RGNm33k8CGVT9zx/fnz17pX9hWXoC9gRurtk8FTm1SLG2snTxmA1ul9a2A2Wn9+8DR3R3XoDh/AXywFeMDxgB3A+8ie3J2RO33mWx+mb3T+oh0nEqMaWtgJvB+4Lr0x6MlYkvXmcu6yaMlvrdkM4s+Wfs1aJX4qq5zIHB7K8VGljzmAVukn6XrgA/158/eUL9tVfkCVzyTylrBlhHxXFr/X2DLtN60mFNV9p1k/923THzpttA9wPPATWS1yZcjYkU3MayOL+1fCIwrMbxvA18FVqXtcS0UG0AAv5Z0l6TK/KWt8r3dDpgPXJJu+10kaWwLxVdxFHBlWm+J2CLiWeC/gKeB58h+lu6iH3/2hnryGBAi+3egqX2qJW0E/AT4YkQsqt7X7PgiYmVEvIPsv/w9gV2aFUs1SR8Bno+Iu5odSx37RcRuwMHAZyW9p3pnk7+3I8hu514QEe8ElpDdClqt2T97qc3gUODHtfuaGVtqa5lMloDfCIwFDurPawz15PEssE3V9taprBX8VdJWAOn1+VTe8JglbUCWODoj4qetFl9FRLwM/IasOr6ZpMo0y9UxrI4v7d8UWFBSSPsCh0qaC1xFduvqOy0SG7D6P1Qi4nngZ2TJt1W+t88Az0TEHWn7WrJk0i
rxQZZ0746Iv6btVontA8CTETE/IpYDPyX7eey3n72hnjxmATulHggjyaqfM5ocU8UM4Li0fhxZW0Ol/NjUe2MvYGFVNbnfSRLZXPMPR8Q5LRjfBEmbpfUNydpjHiZLIof3EF8l7sOBW9J/iP0uIk6NiK0joo3sZ+uWiOhohdgAJI2VtHFlneze/QO0yPc2Iv4XmCfpzanoAOChVokvOZo1t6wqMbRCbE8De0kak36HK1+7/vvZK7sxqdUXsl4Qj5LdJz+tSTFcSXZfcjnZf1snkN1vnAk8BtwMbJGOFXB+ivd+oL3k2PYjq3rfB9yTlkNaKL5dgT+n+B4ATk/l2wN3AnPIbimMSuWj0/actH/7Bn2P92dNb6uWiC3FcW9aHqz8/LfK9zZd8x1AV/r+/hzYvFXiI7sVtADYtKqsJWJL1zwTeCT9XlwOjOrPnz0PT2JmZoUN9dtWZmbWB04eZmZWmJOHmZkV5uRhZmaFOXmYmVlhTh7WsiStTCOWPiDpx5LG9HDcH/r4/u2SzluP+F7pofwNkq6S9Hga9uN6STv39TqtQNL+kvZpdhzWOpw8rJW9GhHviIi3AsuAE6t3Vp6UjYg+/VGLiK6I+Pz6h7lWTCJ7Uvu3EbFDROxONuDmlvXPbHn7A04etpqThw0UtwE7pv+Ab5M0g+yJ2dU1gLTvt1oz/0Nn+mOOpD0k/UHZvB93Sto4HV+ZY+MMSZdL+qOyuRg+lco3kjRT0t3K5r2Y3Euc7wOWR8SFlYKIuDcibktPF/9nqkndL+nIqrhvlfQLSU9I+g9JHSnO+yXtkI67VNKFkrokPaps7KzKnCaXpGP/LOl9qfx4ST+V9Kv0mb5ViUnSgemz3p1qdRul8rmSzqz6vLsoGxDzROCfU03w3ev5vbRBYETvh5g1V6phHAz8KhXtBrw1Ip7s5vB3Am8B/gLcDuwr6U7gauDIiJglaRPg1W7O3ZVsroWxwJ8l/ZJsbKLDImKRpPHAnyTNiJ6frn0r2eil3fk7siem3w6MB2ZJ+l3a93bgb4EXgSeAiyJiT2WTb30O+GI6ro1s/KkdgN9I2hH4LNk4fG+TtAvZKLmV22TvSF+T14HZkr6bPvs3gA9ExBJJXyObc+SsdM4LEbGbpM8AX46IT0q6EHglIv6rh89mQ4yTh7WyDZUNtQ5ZzeNislsnd/aQOEj7ngFI57aRDS/9XETMAog0KnCqlFT7RUS8Crwq6Tdkf6R/CfybstFmV5ENXb0l2XDbRe0HXBkRK8kG0LsV2ANYBMyKNNaRpMeBX6dz7ierzVRcExGrgMckPUE2gvB+wHfTZ3tE0lNAJXnMjIiF6X0fIpuwaDOyCb1uT1+DkcAfq65RGfzyLrKEZ7YOJw9rZa9GNtT6aumP3ZI657xetb6SYj/jtbWJADqACcDuEbFc2Qi5o+u8x4OsGXiuiOq4V1Vtr2Ltz9BdjHnft/L1EHBTRBzdyzlFv342hLjNw4aC2cBWkvYASO0d3f1RnJzaD8aRNRDPIhua+vmUON5H9p97PbcAo7RmYiUk7ZraCW4DjlQ2edUEsumH7yz4WY6QNCy1g2yfPtttZEmOdLtq21Tekz+R3c7bMZ0zNkdvsMVk0xCbAU4eNgRExDLgSOC7ku4lm22wu9rDfWRDVv8J+NeI+AvQCbRLuh84lmyU0nrXCuAw4APKuuo+CPw72W2un6Vr3EuWZL4a2bDjRTxNlnBuAE6MiNeAacCwFOPVwPER8XpPbxAR88nms75S0n1kt6x6m0Dr/wGHucHcKjyqrhlZbytavEFY0qVkw7pf2+xYzFzzMDOzwlzzMDOzwlzzMDOzwpw8zMysMCcPMzMrzMnDzMwKc/IwM7PC/j/3O5XOVIi+agAAAABJRU5ErkJggg==\n", 886 | "text/plain": [ 887 | "
" 888 | ] 889 | }, 890 | "metadata": { 891 | "needs_background": "light" 892 | }, 893 | "output_type": "display_data" 894 | } 895 | ], 896 | "source": [ 897 | "import numpy as np\n", 898 | "import matplotlib\n", 899 | "import matplotlib.pyplot as plt\n", 900 | "\n", 901 | "PC_values = np.arange(pca.n_components_) + 1\n", 902 | "plt.plot(PC_values, pca.explained_variance_ratio_, 'ro-', linewidth=2)\n", 903 | "plt.title('Scree Plot')\n", 904 | "plt.xlabel('Principal Component')\n", 905 | "plt.ylabel('Proportion of Variance Explained')\n", 906 | "plt.show()" 907 | ] 908 | }, 909 | { 910 | "cell_type": "markdown", 911 | "id": "51fb6a41", 912 | "metadata": {}, 913 | "source": [ 914 | "Note that almost 93% of the variance is explained by the first 128 principal components" 915 | ] 916 | }, 917 | { 918 | "cell_type": "code", 919 | "execution_count": 14, 920 | "id": "c7005d80", 921 | "metadata": {}, 922 | "outputs": [ 923 | { 924 | "data": { 925 | "text/plain": [ 926 | "0.9270989465518304" 927 | ] 928 | }, 929 | "execution_count": 14, 930 | "metadata": {}, 931 | "output_type": "execute_result" 932 | } 933 | ], 934 | "source": [ 935 | "sum(pca.explained_variance_ratio_[:128])" 936 | ] 937 | }, 938 | { 939 | "cell_type": "markdown", 940 | "id": "5ade6a55", 941 | "metadata": {}, 942 | "source": [ 943 | "## Logistic Regression with 128 Principal Components" 944 | ] 945 | }, 946 | { 947 | "cell_type": "code", 948 | "execution_count": 15, 949 | "id": "09d55936", 950 | "metadata": {}, 951 | "outputs": [ 952 | { 953 | "data": { 954 | "text/plain": [ 955 | "array([[-6.19892061e-01, 7.57744837e+00, 2.38994151e-01, ...,\n", 956 | " 1.45125136e-01, 4.76339161e-01, -3.43906790e-01],\n", 957 | " [ 1.02952302e+00, 4.56361294e+00, 1.81073594e+00, ...,\n", 958 | " -3.10825527e-01, -1.21716559e-01, -1.96754932e-03],\n", 959 | " [ 1.57442674e-01, 4.48937798e+00, -2.28051734e+00, ...,\n", 960 | " 1.97405860e-01, 4.04985249e-01, -6.72422767e-01],\n", 961 | " ...,\n", 962 | " [-7.67455101e+00, 
-1.38447428e+00, 3.48542595e+00, ...,\n", 963 | " -1.10648334e-01, -3.42456937e-01, -4.07753110e-01],\n", 964 | " [-7.29859233e-01, -1.26016960e-02, 1.33440959e+00, ...,\n", 965 | " -3.93246353e-01, -2.03127354e-01, -8.12449753e-02],\n", 966 | " [ 2.64888406e+00, 2.00736737e+00, -2.70441365e+00, ...,\n", 967 | " 1.83935657e-01, -3.53981048e-01, -3.53198677e-01]], dtype=float32)" 968 | ] 969 | }, 970 | "execution_count": 15, 971 | "metadata": {}, 972 | "output_type": "execute_result" 973 | } 974 | ], 975 | "source": [ 976 | "pca128 = PCA(n_components=128)\n", 977 | "pca128.fit(X_train)\n", 978 | "X_train_reduced = pca128.transform(X_train)\n", 979 | "X_train_reduced" 980 | ] 981 | }, 982 | { 983 | "cell_type": "code", 984 | "execution_count": 16, 985 | "id": "25a7041b", 986 | "metadata": {}, 987 | "outputs": [], 988 | "source": [ 989 | "model_128 = LogisticRegression(multi_class='ovr', solver='saga', n_jobs=60)" 990 | ] 991 | }, 992 | { 993 | "cell_type": "code", 994 | "execution_count": 17, 995 | "id": "6ffa7a40", 996 | "metadata": {}, 997 | "outputs": [ 998 | { 999 | "data": { 1000 | "text/plain": [ 1001 | "LogisticRegression(multi_class='ovr', n_jobs=60, solver='saga')" 1002 | ] 1003 | }, 1004 | "execution_count": 17, 1005 | "metadata": {}, 1006 | "output_type": "execute_result" 1007 | } 1008 | ], 1009 | "source": [ 1010 | "model_128.fit(X_train_reduced, y_train)" 1011 | ] 1012 | }, 1013 | { 1014 | "cell_type": "code", 1015 | "execution_count": 18, 1016 | "id": "6e3734fe", 1017 | "metadata": {}, 1018 | "outputs": [], 1019 | "source": [ 1020 | "X_validate_reduced = pca128.transform(X_validate)" 1021 | ] 1022 | }, 1023 | { 1024 | "cell_type": "code", 1025 | "execution_count": 21, 1026 | "id": "a6617d85", 1027 | "metadata": {}, 1028 | "outputs": [ 1029 | { 1030 | "name": "stdout", 1031 | "output_type": "stream", 1032 | "text": [ 1033 | "Accuracy of logistic regression classifier on VALIDATE set: 0.43\n" 1034 | ] 1035 | } 1036 | ], 1037 | "source": [ 1038 | 
"print('Accuracy of logistic regression classifier on VALIDATE set: {:.2f}'\\\n", 1039 | " .format(model_128.score(X_validate_reduced, y_validate)))" 1040 | ] 1041 | }, 1042 | { 1043 | "cell_type": "markdown", 1044 | "id": "11b2ca0d", 1045 | "metadata": {}, 1046 | "source": [ 1047 | "__Accuracy goes down (0.43 on VALIDATE), though we seem to have avoided the convergence warnings__" 1048 | ] 1049 | }, 1050 | { 1051 | "cell_type": "markdown", 1052 | "id": "c406d1e1", 1053 | "metadata": {}, 1054 | "source": [ 1055 | "### Save PCA Object\n", 1056 | "We will use this in Part 3 for Neo4j import." 1057 | ] 1058 | }, 1059 | { 1060 | "cell_type": "code", 1061 | "execution_count": 22, 1062 | "id": "346f767c", 1063 | "metadata": {}, 1064 | "outputs": [ 1065 | { 1066 | "data": { 1067 | "text/plain": [ 1068 | "['/data/paper-feat-pca128.joblib']" 1069 | ] 1070 | }, 1071 | "execution_count": 22, 1072 | "metadata": {}, 1073 | "output_type": "execute_result" 1074 | } 1075 | ], 1076 | "source": [ 1077 | "from joblib import dump\n", 1078 | "dump(pca128, ROOT_DATA_DIR + '/paper-feat-pca128.joblib')" 1079 | ] 1080 | }, 1081 | { 1082 | "cell_type": "markdown", 1083 | "id": "bcb06747", 1084 | "metadata": {}, 1085 | "source": [ 1086 | "## (Optional) Checking with 384 Principal Components\n", 1087 | "Just for the sake of experimentation, let's increase the dimensionality to 384 (50% of the original 768-dimension vector) and see what we get. 
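The notebook compares the 128- and 384-component cutoffs by summing slices of `explained_variance_ratio_`. The inverse question, how many components are needed to reach a target share of variance, can be answered from the running sum of the same ratios. Below is a minimal, dependency-free sketch; the function name `components_for_variance` is hypothetical and not part of the notebook. (scikit-learn's `PCA` can also choose the count itself when `n_components` is a float between 0 and 1 and `svd_solver='full'`.)

```python
from bisect import bisect_left
from itertools import accumulate


def components_for_variance(explained_variance_ratio, target):
    """Return the smallest k whose first k variance ratios sum to at least `target`.

    Ratios are assumed non-negative (as PCA's `explained_variance_ratio_` is),
    so the running sum is non-decreasing and can be binary-searched.
    """
    cumulative = list(accumulate(explained_variance_ratio))
    # Index of the first running sum that is >= target.
    idx = bisect_left(cumulative, target)
    if idx == len(cumulative):
        raise ValueError("target exceeds the total explained variance")
    return idx + 1


# Toy ratios for illustration only; these are not the MAG240M values.
ratios = [0.5, 0.25, 0.125, 0.0625, 0.0625]
print(components_for_variance(ratios, 0.7))    # 2, since 0.5 + 0.25 = 0.75
print(components_for_variance(ratios, 0.875))  # 3
```

With `pca.explained_variance_ratio_` and a target of 0.93, the result would have to be above 128, because the first 128 components explain about 0.927.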
" 1088 | ] 1089 | }, 1090 | { 1091 | "cell_type": "code", 1092 | "execution_count": 23, 1093 | "id": "8209bcf2", 1094 | "metadata": {}, 1095 | "outputs": [ 1096 | { 1097 | "data": { 1098 | "text/plain": [ 1099 | "0.9850090609327055" 1100 | ] 1101 | }, 1102 | "execution_count": 23, 1103 | "metadata": {}, 1104 | "output_type": "execute_result" 1105 | } 1106 | ], 1107 | "source": [ 1108 | "sum(pca.explained_variance_ratio_[:384])" 1109 | ] 1110 | }, 1111 | { 1112 | "cell_type": "code", 1113 | "execution_count": 24, 1114 | "id": "3e67dcb3", 1115 | "metadata": {}, 1116 | "outputs": [ 1117 | { 1118 | "data": { 1119 | "text/plain": [ 1120 | "array([[-6.19884431e-01, 7.57744980e+00, 2.38999844e-01, ...,\n", 1121 | " -2.00414509e-02, 1.23605132e-02, 5.13292551e-02],\n", 1122 | " [ 1.02952719e+00, 4.56361008e+00, 1.81073332e+00, ...,\n", 1123 | " -1.46002211e-02, 1.11485720e-02, -6.50432706e-03],\n", 1124 | " [ 1.57447502e-01, 4.48937893e+00, -2.28051710e+00, ...,\n", 1125 | " -1.08808234e-01, 5.55179715e-02, 1.12585865e-01],\n", 1126 | " ...,\n", 1127 | " [-7.67455006e+00, -1.38446569e+00, 3.48541903e+00, ...,\n", 1128 | " 7.78694451e-02, 4.36903536e-02, -1.50039345e-01],\n", 1129 | " [-7.29859591e-01, -1.26056150e-02, 1.33440888e+00, ...,\n", 1130 | " -1.61436945e-02, -1.19892642e-01, -2.25727051e-01],\n", 1131 | " [ 2.64888358e+00, 2.00736547e+00, -2.70441127e+00, ...,\n", 1132 | " -6.14562258e-02, 7.74844512e-02, -3.82103026e-05]], dtype=float32)" 1133 | ] 1134 | }, 1135 | "execution_count": 24, 1136 | "metadata": {}, 1137 | "output_type": "execute_result" 1138 | } 1139 | ], 1140 | "source": [ 1141 | "pca384 = PCA(n_components=384)\n", 1142 | "pca384.fit(X_train)\n", 1143 | "X_train_reduced_384 = pca384.transform(X_train)\n", 1144 | "X_train_reduced_384" 1145 | ] 1146 | }, 1147 | { 1148 | "cell_type": "code", 1149 | "execution_count": 25, 1150 | "id": "1489fee5", 1151 | "metadata": {}, 1152 | "outputs": [], 1153 | "source": [ 1154 | "model_384 = 
LogisticRegression(multi_class='ovr', solver='saga', n_jobs=60)" 1155 | ] 1156 | }, 1157 | { 1158 | "cell_type": "code", 1159 | "execution_count": 26, 1160 | "id": "67c70089", 1161 | "metadata": {}, 1162 | "outputs": [ 1163 | { 1164 | "data": { 1165 | "text/plain": [ 1166 | "LogisticRegression(multi_class='ovr', n_jobs=60, solver='saga')" 1167 | ] 1168 | }, 1169 | "execution_count": 26, 1170 | "metadata": {}, 1171 | "output_type": "execute_result" 1172 | } 1173 | ], 1174 | "source": [ 1175 | "model_384.fit(X_train_reduced_384, y_train)" 1176 | ] 1177 | }, 1178 | { 1179 | "cell_type": "code", 1180 | "execution_count": 27, 1181 | "id": "177b2e14", 1182 | "metadata": {}, 1183 | "outputs": [], 1184 | "source": [ 1185 | "X_validate_reduced_384 = pca384.transform(X_validate)" 1186 | ] 1187 | }, 1188 | { 1189 | "cell_type": "code", 1190 | "execution_count": 29, 1191 | "id": "c58f21af", 1192 | "metadata": {}, 1193 | "outputs": [ 1194 | { 1195 | "name": "stdout", 1196 | "output_type": "stream", 1197 | "text": [ 1198 | "Accuracy of logistic regression classifier on VALIDATE set: 0.47\n" 1199 | ] 1200 | } 1201 | ], 1202 | "source": [ 1203 | "print('Accuracy of logistic regression classifier on VALIDATE set: {:.2f}'\\\n", 1204 | " .format(model_384.score(X_validate_reduced_384, y_validate)))" 1205 | ] 1206 | }, 1207 | { 1208 | "cell_type": "markdown", 1209 | "id": "3ef1bbe8", 1210 | "metadata": {}, 1211 | "source": [ 1212 | "__Very close to the original accuracy but still lower__" 1213 | ] 1214 | }, 1215 | { 1216 | "cell_type": "code", 1217 | "execution_count": null, 1218 | "id": "f46c7c7a", 1219 | "metadata": {}, 1220 | "outputs": [], 1221 | "source": [] 1222 | } 1223 | ], 1224 | "metadata": { 1225 | "kernelspec": { 1226 | "display_name": "Python 3 (ipykernel)", 1227 | "language": "python", 1228 | "name": "python3" 1229 | }, 1230 | "language_info": { 1231 | "codemirror_mode": { 1232 | "name": "ipython", 1233 | "version": 3 1234 | }, 1235 | "file_extension": ".py", 1236 | 
"mimetype": "text/x-python", 1237 | "name": "python", 1238 | "nbconvert_exporter": "python", 1239 | "pygments_lexer": "ipython3", 1240 | "version": "3.9.7" 1241 | } 1242 | }, 1243 | "nbformat": 4, 1244 | "nbformat_minor": 5 1245 | } 1246 | -------------------------------------------------------------------------------- /part3-prepare-papers-for-import.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b6d6c4f1", 6 | "metadata": {}, 7 | "source": [ 8 | "# Part 3: Formatting Papers for Neo4j Admin Import\n", 9 | "\n", 10 | "This notebook formats the paper nodes and cites relationships into CSVs for admin import. \n", 11 | "\n", 12 | "__Note: The runtime for this notebook depends greatly on the environment in which it is run. It takes a few hours to complete on my 64-core, 976 GB memory instance.__\n", 13 | "\n", 14 | "### Chunking Methodology\n", 15 | "This notebook splits the papers into chunks to avoid out-of-memory errors when formatting data. The `chunk_size` variable determines the number of papers brought into memory at once and can be adjusted as needed; in particular, lower it if the kernel shuts down from running out of memory. \n", 16 | "\n", 17 | "### Reducing Dimensionality with PCA\n", 18 | "The 128-component PCA object saved in Part 2 is used here to reduce the encoding vectors from 768 to 128 dimensions. This is a trade-off: it gives up a small amount of the variance in the original encodings (about 7%, since the first 128 components explain roughly 93%) in exchange for shorter vectors that require fewer resources in later steps. It may be worth exploring higher dimensionality in the future."
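The chunk loop later in this notebook advances a start index by `chunk_size` until it passes `dataset.num_papers`. Slicing past the end of a NumPy array simply truncates, so the last chunk is still processed correctly, but the printed progress overshoots 100% (114.99 in the recorded output). Below is a minimal sketch of the same iteration with the final chunk clamped to the total; `chunk_ranges` and `progress_pct` are hypothetical helpers, not part of the notebook.

```python
def chunk_ranges(total_n, chunk_size):
    """Yield (start, stop) index pairs that cover range(total_n) in chunk_size steps."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, total_n, chunk_size):
        # Clamp the last chunk so stop never exceeds total_n.
        yield start, min(start + chunk_size, total_n)


def progress_pct(stop, total_n):
    """Percentage of rows processed once the chunk ending at `stop` is done."""
    return round(100 * stop / total_n, 2)


# Toy example: 45 rows in chunks of 20.
for count, (start, stop) in enumerate(chunk_ranges(45, 20)):
    print(f"chunk {count}: rows {start} to {stop}, {progress_pct(stop, 45)}% done")
# chunk 0: rows 0 to 20, 44.44% done
# chunk 1: rows 20 to 40, 88.89% done
# chunk 2: rows 40 to 45, 100.0% done
```

The `(start, stop)` pairs can be passed straight to `load_and_pre_format_paper_data(dataset, start, stop)` in place of `done_n` and `to_n`.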
19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "50a66a9b", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "from ogb.lsc import MAG240MDataset\n", 29 | "import numpy as np\n", 30 | "import os.path as osp\n", 31 | "import pandas as pd\n", 32 | "import dask.dataframe as dd\n", 33 | "from joblib import load\n", 34 | "\n", 35 | "ROOT_DATA_DIR = '/data'\n", 36 | "pca_model_file = f'{ROOT_DATA_DIR}/paper-feat-pca128.joblib'\n", 37 | "chunk_size = 20_000_000" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "30235362", 43 | "metadata": {}, 44 | "source": [ 45 | "## Prepare Paper Nodes" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "id": "eff9814e", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "def load_and_pre_format_paper_data(dataset, from_ind, to_inx):\n", 56 | " feat_in_memory = dataset.paper_feat[from_ind:to_inx]\n", 57 | " feat_cols = [f'paper_encoding_{i}' for i in range(768)]\n", 58 | " paper_df = pd.DataFrame(feat_in_memory, columns = feat_cols)\n", 59 | " \n", 60 | " paper_df['ogb_index'] = paper_df.index + from_ind\n", 61 | " \n", 62 | " paper_df['paper_subject'] = dataset.all_paper_label[from_ind:to_inx] \n", 63 | " paper_df['paper_subject'] = paper_df['paper_subject'].fillna(-2)\n", 64 | " \n", 65 | " paper_df['paper_year'] = dataset.all_paper_year[from_ind:to_inx] \n", 66 | " \n", 67 | " split_dict = dataset.get_idx_split()\n", 68 | " paper_df[\"split_segment\"] = 'REMAINDER'\n", 69 | " paper_df.loc[paper_df.ogb_index.isin(split_dict['train']), 'split_segment'] = 'TRAIN'\n", 70 | " paper_df.loc[paper_df.ogb_index.isin(split_dict['valid']), 'split_segment'] = 'VALIDATE'\n", 71 | " paper_df.loc[paper_df.ogb_index.isin(split_dict['test-dev']), 'split_segment'] = 'TEST_DEV'\n", 72 | " paper_df.loc[paper_df.ogb_index.isin(split_dict['test-challenge']), 'split_segment'] = 'TEST_CHALLENGE'\n", 73 | " \n", 74 | "
paper_df['subject_status'] = \"ERROR\"\n", 75 | " paper_df.loc[paper_df.paper_subject > -1,'subject_status'] = \"KNOWN\"\n", 76 | " paper_df.loc[paper_df.paper_subject == -1,'subject_status'] = \"HIDDEN\"\n", 77 | " paper_df.loc[paper_df.paper_subject == -2,'subject_status'] = \"UNKNOWN\"\n", 78 | " \n", 79 | " return paper_df" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 3, 85 | "id": "295562db", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "def reduce_paper_data(paper_df, pca_model_object = pca_model_file):\n", 90 | " feat_cols = [f'paper_encoding_{i}' for i in range(768)]\n", 91 | " feat_128_cols = ['paper_128_encoding_' + str(x) for x in range(128)]\n", 92 | " pca128 = load(pca_model_object)\n", 93 | " res_df = pd.DataFrame(pca128.transform(paper_df[feat_cols]), columns = feat_128_cols)\n", 94 | " res_df = pd.concat([paper_df[[\"ogb_index\", \"split_segment\", \"subject_status\", \"paper_year\", \n", 95 | " \"paper_subject\"]], res_df], axis=1)\n", 96 | " return res_df" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 4, 102 | "id": "8836134e", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "#Future Note: Change split_segment -> splitSegment and subject_status -> subjectStatus \n", 107 | "## for consistent naming and to work with later parts, namely export. 
\n", 108 | "def post_format_paper_data(reduced_paper_df, npartitions=2000):\n", 109 | " feat_128_cols = ['paper_128_encoding_' + str(x) for x in range(128)]\n", 110 | " reduced_paper_df.rename(\n", 111 | " columns={\"ogb_index\":\"ogbIndex:ID\", \"split_segment\":\"split_segment:string\", \n", 112 | " \"subject_status\":\"subject_status:string\", \"paper_year\":\"year:int\", \n", 113 | " \"paper_subject\":\"subject:int\"}, inplace=True)\n", 114 | " paper_ddf = dd.from_pandas(reduced_paper_df, npartitions=npartitions)\n", 115 | " paper_ddf = paper_ddf.astype({'subject:int':'int32'})\n", 116 | " paper_ddf[\"encoding:float[]\"] = \\\n", 117 | " paper_ddf.apply(lambda x:\";\".join(['%0.5f' % i for i in x[feat_128_cols]]), axis=1,meta=(\"str\"))\n", 118 | " paper_ddf = paper_ddf.drop(columns=feat_128_cols)\n", 119 | " paper_ddf.compute(scheduler=\"processes\")\n", 120 | " return paper_ddf" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "id": "bf3760dd", 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "pre-formatting for chunk 0 to 20000000...\n", 134 | "pca reduction...\n", 135 | "post-formatting...\n", 136 | "writing chunk to files /data/demo-load/paper-c0-*.csv\n", 137 | "finished 16.43 of the data, iterated count to 1\n", 138 | "========================================\n", 139 | "========================================\n", 140 | "pre-formatting for chunk 20000000 to 40000000...\n", 141 | "pca reduction...\n", 142 | "post-formatting...\n", 143 | "writing chunk to files /data/demo-load/paper-c1-*.csv\n", 144 | "finished 32.85 of the data, iterated count to 2\n", 145 | "========================================\n", 146 | "========================================\n", 147 | "pre-formatting for chunk 40000000 to 60000000...\n", 148 | "pca reduction...\n", 149 | "post-formatting...\n", 150 | "writing chunk to files /data/demo-load/paper-c2-*.csv\n", 151 | 
"finished 49.28 of the data, iterated count to 3\n", 152 | "========================================\n", 153 | "========================================\n", 154 | "pre-formatting for chunk 60000000 to 80000000...\n", 155 | "pca reduction...\n", 156 | "post-formatting...\n", 157 | "writing chunk to files /data/demo-load/paper-c3-*.csv\n", 158 | "finished 65.71 of the data, iterated count to 4\n", 159 | "========================================\n", 160 | "========================================\n", 161 | "pre-formatting for chunk 80000000 to 100000000...\n", 162 | "pca reduction...\n", 163 | "post-formatting...\n", 164 | "writing chunk to files /data/demo-load/paper-c4-*.csv\n", 165 | "finished 82.13 of the data, iterated count to 5\n", 166 | "========================================\n", 167 | "========================================\n", 168 | "pre-formatting for chunk 100000000 to 120000000...\n", 169 | "pca reduction...\n", 170 | "post-formatting...\n", 171 | "writing chunk to files /data/demo-load/paper-c5-*.csv\n", 172 | "finished 98.56 of the data, iterated count to 6\n", 173 | "========================================\n", 174 | "========================================\n", 175 | "pre-formatting for chunk 120000000 to 140000000...\n", 176 | "pca reduction...\n", 177 | "post-formatting...\n", 178 | "writing chunk to files /data/demo-load/paper-c6-*.csv\n", 179 | "finished 114.99 of the data, iterated count to 7\n", 180 | "========================================\n", 181 | "========================================\n" 182 | ] 183 | } 184 | ], 185 | "source": [ 186 | "dataset = MAG240MDataset(root = ROOT_DATA_DIR)\n", 187 | "total_n = dataset.num_papers\n", 188 | "done_n = 0\n", 189 | "count = 0\n", 190 | "\n", 191 | "while done_n < total_n:\n", 192 | " to_n = done_n + chunk_size\n", 193 | " print(\"pre-formatting for chunk \" + str(done_n) + \" to \" + str(to_n) + \"...\")\n", 194 | " paper_df = load_and_pre_format_paper_data(dataset, done_n, to_n)\n", 195 | " 
print(\"pca reduction...\")\n", 196 | " reduced_paper_df = reduce_paper_data(paper_df)\n", 197 | " print(\"post-formatting...\")\n", 198 | " paper_ddf = post_format_paper_data(reduced_paper_df)\n", 199 | " print(\"writing chunk to files \" + f'{ROOT_DATA_DIR}/demo-load/paper-c{count}-*.csv')\n", 200 | " paper_ddf = paper_ddf.repartition(npartitions=100)\n", 201 | " paper_ddf.to_csv(f'{ROOT_DATA_DIR}/demo-load/paper-c{count}-*.csv', \n", 202 | " header_first_partition_only=True, index=False, compute_kwargs={'scheduler': 'processes'})\n", 203 | " count += 1\n", 204 | " done_n = to_n\n", 205 | " print(f\"finished {round(100*done_n/total_n,2)} of the data, iterated count to {count}\")\n", 206 | " print(\"========================================\")\n", 207 | " print(\"========================================\")" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "id": "39c9ea6a", 213 | "metadata": {}, 214 | "source": [ 215 | "## Prepare Cite Relationships" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 6, 221 | "id": "0e5bd432", 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "cites_edge_ddf = dd.from_pandas(pd.DataFrame(dataset.edge_index('paper', 'paper').T, \n", 226 | " columns = [\":START_ID\",\":END_ID\"]), npartitions=100)\n" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 7, 232 | "id": "2c159abf", 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "name": "stderr", 237 | "output_type": "stream", 238 | "text": [ 239 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/dask/dataframe/io/csv.py:916: FutureWarning: The 'scheduler' keyword argument for `to_csv()` is deprecated andwill be removed in a future version. Please use the `compute_kwargs` argument instead. 
For example, df.to_csv(..., compute_kwargs={scheduler: processes})\n", 240 | " warn(\n" 241 | ] 242 | }, 243 | { 244 | "data": { 245 | "text/plain": [ 246 | "['/data/demo-load/cited-00.csv',\n", 247 | " '/data/demo-load/cited-01.csv',\n", 248 | " '/data/demo-load/cited-02.csv',\n", 249 | " '/data/demo-load/cited-03.csv',\n", 250 | " '/data/demo-load/cited-04.csv',\n", 251 | " '/data/demo-load/cited-05.csv',\n", 252 | " '/data/demo-load/cited-06.csv',\n", 253 | " '/data/demo-load/cited-07.csv',\n", 254 | " '/data/demo-load/cited-08.csv',\n", 255 | " '/data/demo-load/cited-09.csv',\n", 256 | " '/data/demo-load/cited-10.csv',\n", 257 | " '/data/demo-load/cited-11.csv',\n", 258 | " '/data/demo-load/cited-12.csv',\n", 259 | " '/data/demo-load/cited-13.csv',\n", 260 | " '/data/demo-load/cited-14.csv',\n", 261 | " '/data/demo-load/cited-15.csv',\n", 262 | " '/data/demo-load/cited-16.csv',\n", 263 | " '/data/demo-load/cited-17.csv',\n", 264 | " '/data/demo-load/cited-18.csv',\n", 265 | " '/data/demo-load/cited-19.csv',\n", 266 | " '/data/demo-load/cited-20.csv',\n", 267 | " '/data/demo-load/cited-21.csv',\n", 268 | " '/data/demo-load/cited-22.csv',\n", 269 | " '/data/demo-load/cited-23.csv',\n", 270 | " '/data/demo-load/cited-24.csv',\n", 271 | " '/data/demo-load/cited-25.csv',\n", 272 | " '/data/demo-load/cited-26.csv',\n", 273 | " '/data/demo-load/cited-27.csv',\n", 274 | " '/data/demo-load/cited-28.csv',\n", 275 | " '/data/demo-load/cited-29.csv',\n", 276 | " '/data/demo-load/cited-30.csv',\n", 277 | " '/data/demo-load/cited-31.csv',\n", 278 | " '/data/demo-load/cited-32.csv',\n", 279 | " '/data/demo-load/cited-33.csv',\n", 280 | " '/data/demo-load/cited-34.csv',\n", 281 | " '/data/demo-load/cited-35.csv',\n", 282 | " '/data/demo-load/cited-36.csv',\n", 283 | " '/data/demo-load/cited-37.csv',\n", 284 | " '/data/demo-load/cited-38.csv',\n", 285 | " '/data/demo-load/cited-39.csv',\n", 286 | " '/data/demo-load/cited-40.csv',\n", 287 | " '/data/demo-load/cited-41.csv',\n", 
288 | " '/data/demo-load/cited-42.csv',\n", 289 | " '/data/demo-load/cited-43.csv',\n", 290 | " '/data/demo-load/cited-44.csv',\n", 291 | " '/data/demo-load/cited-45.csv',\n", 292 | " '/data/demo-load/cited-46.csv',\n", 293 | " '/data/demo-load/cited-47.csv',\n", 294 | " '/data/demo-load/cited-48.csv',\n", 295 | " '/data/demo-load/cited-49.csv',\n", 296 | " '/data/demo-load/cited-50.csv',\n", 297 | " '/data/demo-load/cited-51.csv',\n", 298 | " '/data/demo-load/cited-52.csv',\n", 299 | " '/data/demo-load/cited-53.csv',\n", 300 | " '/data/demo-load/cited-54.csv',\n", 301 | " '/data/demo-load/cited-55.csv',\n", 302 | " '/data/demo-load/cited-56.csv',\n", 303 | " '/data/demo-load/cited-57.csv',\n", 304 | " '/data/demo-load/cited-58.csv',\n", 305 | " '/data/demo-load/cited-59.csv',\n", 306 | " '/data/demo-load/cited-60.csv',\n", 307 | " '/data/demo-load/cited-61.csv',\n", 308 | " '/data/demo-load/cited-62.csv',\n", 309 | " '/data/demo-load/cited-63.csv',\n", 310 | " '/data/demo-load/cited-64.csv',\n", 311 | " '/data/demo-load/cited-65.csv',\n", 312 | " '/data/demo-load/cited-66.csv',\n", 313 | " '/data/demo-load/cited-67.csv',\n", 314 | " '/data/demo-load/cited-68.csv',\n", 315 | " '/data/demo-load/cited-69.csv',\n", 316 | " '/data/demo-load/cited-70.csv',\n", 317 | " '/data/demo-load/cited-71.csv',\n", 318 | " '/data/demo-load/cited-72.csv',\n", 319 | " '/data/demo-load/cited-73.csv',\n", 320 | " '/data/demo-load/cited-74.csv',\n", 321 | " '/data/demo-load/cited-75.csv',\n", 322 | " '/data/demo-load/cited-76.csv',\n", 323 | " '/data/demo-load/cited-77.csv',\n", 324 | " '/data/demo-load/cited-78.csv',\n", 325 | " '/data/demo-load/cited-79.csv',\n", 326 | " '/data/demo-load/cited-80.csv',\n", 327 | " '/data/demo-load/cited-81.csv',\n", 328 | " '/data/demo-load/cited-82.csv',\n", 329 | " '/data/demo-load/cited-83.csv',\n", 330 | " '/data/demo-load/cited-84.csv',\n", 331 | " '/data/demo-load/cited-85.csv',\n", 332 | " '/data/demo-load/cited-86.csv',\n", 333 | " 
'/data/demo-load/cited-87.csv',\n", 334 | " '/data/demo-load/cited-88.csv',\n", 335 | " '/data/demo-load/cited-89.csv',\n", 336 | " '/data/demo-load/cited-90.csv',\n", 337 | " '/data/demo-load/cited-91.csv',\n", 338 | " '/data/demo-load/cited-92.csv',\n", 339 | " '/data/demo-load/cited-93.csv',\n", 340 | " '/data/demo-load/cited-94.csv',\n", 341 | " '/data/demo-load/cited-95.csv',\n", 342 | " '/data/demo-load/cited-96.csv',\n", 343 | " '/data/demo-load/cited-97.csv',\n", 344 | " '/data/demo-load/cited-98.csv',\n", 345 | " '/data/demo-load/cited-99.csv']" 346 | ] 347 | }, 348 | "execution_count": 7, 349 | "metadata": {}, 350 | "output_type": "execute_result" 351 | } 352 | ], 353 | "source": [ 354 | "cites_edge_ddf.to_csv(f'{ROOT_DATA_DIR}/demo-load/cited-*.csv', \n", 355 | " header_first_partition_only=True, index=False, scheduler=\"processes\")" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "id": "a2d46072", 361 | "metadata": {}, 362 | "source": [ 363 | "### Removing Extra Headers for Papers\n", 364 | "\n", 365 | "Unfortunately, a small manual step is needed here because I didn't get the headers quite right when writing the paper CSVs. For admin import, only one paper CSV should have a header (the first file from the first chunk). However, the chunking logic above writes a header to the first file of every chunk. As a result, we must remove the header from the first file of each chunk except the initial chunk. \n", 366 | "\n", 367 | "This is easy to do in a terminal: go to the directory containing the CSVs and run a `sed` command like the ones below on the first file of each chunk except `paper-c0-00.csv`. For a chunk size of 20 million, the commands look like this. 
\n", 368 | "\n", 369 | "```bash\n", 370 | " sed -i '1d' paper-c1-00.csv \n", 371 | " sed -i '1d' paper-c2-00.csv \n", 372 | " sed -i '1d' paper-c3-00.csv \n", 373 | " sed -i '1d' paper-c4-00.csv \n", 374 | " sed -i '1d' paper-c5-00.csv \n", 375 | " sed -i '1d' paper-c6-00.csv \n", 376 | "```" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "id": "a2f3f9b9", 383 | "metadata": {}, 384 | "outputs": [], 385 | "source": [] 386 | } 387 | ], 388 | "metadata": { 389 | "kernelspec": { 390 | "display_name": "Python 3 (ipykernel)", 391 | "language": "python", 392 | "name": "python3" 393 | }, 394 | "language_info": { 395 | "codemirror_mode": { 396 | "name": "ipython", 397 | "version": 3 398 | }, 399 | "file_extension": ".py", 400 | "mimetype": "text/x-python", 401 | "name": "python", 402 | "nbconvert_exporter": "python", 403 | "pygments_lexer": "ipython3", 404 | "version": "3.9.7" 405 | } 406 | }, 407 | "nbformat": 4, 408 | "nbformat_minor": 5 409 | } 410 | -------------------------------------------------------------------------------- /part4-prepare-authors-and-inst-for-import.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6576eac0", 6 | "metadata": {}, 7 | "source": [ 8 | "# Part 4: Prepare Author and Institution Data for Admin Import\n", 9 | "\n", 10 | "This notebook formats the author and institution nodes, along with the writes and affiliated_with relationships, into CSVs for admin import.\n", 11 | "\n", 12 | "Unlike the paper nodes in Part 3, these elements have no properties beyond their IDs, so this notebook runs much faster and does not require chunking, at least on my machine. 
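Since these nodes carry no properties, a node file only needs the distinct IDs that appear in the relationship data. Below is a minimal, stdlib-only sketch of that idea. It assumes author IDs are taken from the `:START_ID` column of the writes edges after the offset is applied and are written under an `ogbIndex:ID` header. The notebook itself does this with dask/pandas dataframes, and the helper names here are hypothetical.

```python
import csv
import io


def unique_ids_in_order(ids):
    """Return the distinct IDs in first-seen order (similar to pandas drop_duplicates)."""
    seen = set()
    ordered = []
    for node_id in ids:
        if node_id not in seen:
            seen.add(node_id)
            ordered.append(node_id)
    return ordered


def write_node_csv(ids, handle):
    """Write a one-column node file whose header names the ID column for admin import."""
    writer = csv.writer(handle, lineterminator="\n")
    writer.writerow(["ogbIndex:ID"])
    for node_id in ids:
        writer.writerow([node_id])


# Toy (author, paper) edges with the author offset already applied.
edges = [
    (10000000000, 17776550),
    (10000000000, 22232787),
    (10000000001, 22359844),
]
authors = unique_ids_in_order(start for start, _ in edges)

buf = io.StringIO()
write_node_csv(authors, buf)
print(buf.getvalue())
# ogbIndex:ID
# 10000000000
# 10000000001
```

Keeping first-seen order is not required by admin import. It just makes the node file line up with the order of the edge files, which helps when spot-checking.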
" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "50a66a9b", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from ogb.lsc import MAG240MDataset\n", 23 | "import numpy as np\n", 24 | "import pandas as pd\n", 25 | "import dask.dataframe as dd" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "id": "f172527c", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "ROOT_DATA_DIR = '/data'\n", 36 | "dataset = MAG240MDataset(root = ROOT_DATA_DIR)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "id": "7da068c4", 42 | "metadata": {}, 43 | "source": [ 44 | "### ID Offsets\n", 45 | "\n", 46 | "To create a universal long ID for nodes, we offset the OGB index values for authors and institutions. This avoids ID collisions across node labels. \n", 47 | "\n", 48 | "Note that Neo4j admin import also supports partitioning unique IDs by label (ID spaces): https://neo4j.com/docs/operations-manual/current/tools/neo4j-admin/neo4j-admin-import/#import-tool-id-spaces. Offsetting IDs this way to get a single global node index is simply my personal preference. 
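Below is a minimal sketch of how the offset scheme maps each label's OGB index into one global ID space and back. It assumes the offsets defined in the next cell (10B for authors, 20B for institutions), that papers keep their raw OGB index as in Part 3, and that every per-label index stays below 10B. `to_global_id` and `from_global_id` are hypothetical helpers, not part of the notebook.

```python
AUTHOR_ID_OFFSET = 10_000_000_000                          # 10B
INSTITUTION_ID_OFFSET = AUTHOR_ID_OFFSET + 10_000_000_000  # 20B

# Papers keep their raw OGB index (offset 0), as in the Part 3 paper CSVs.
OFFSETS = {"Paper": 0, "Author": AUTHOR_ID_OFFSET, "Institution": INSTITUTION_ID_OFFSET}
BLOCK = 10_000_000_000  # width of each label's ID block


def to_global_id(label, ogb_index):
    """Map a per-label OGB index to a globally unique node ID."""
    if not 0 <= ogb_index < BLOCK:
        raise ValueError("index does not fit in the label's ID block")
    return OFFSETS[label] + ogb_index


def from_global_id(global_id):
    """Recover (label, ogb_index) from a global node ID."""
    # Check the largest offsets first so each ID falls into exactly one block.
    for label, offset in sorted(OFFSETS.items(), key=lambda kv: kv[1], reverse=True):
        if global_id >= offset:
            index = global_id - offset
            if index >= BLOCK:
                raise ValueError("ID is outside every label's block")
            return label, index
    raise ValueError("global IDs are never negative")


print(to_global_id("Author", 0))       # 10000000000
print(from_global_id(10_122_383_111))  # ('Author', 122383111)
```

Using ID spaces would make the offsets unnecessary. The trade-off is that the stored `ogbIndex` values would then repeat across labels instead of forming one global index.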
" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "id": "6719c818", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "AUTHOR_ID_OFFSET = 10_000_000_000 #10B\n", 59 | "INSTITUTION_ID_OFFSET = AUTHOR_ID_OFFSET + 10_000_000_000 #10B" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "41435577", 65 | "metadata": {}, 66 | "source": [ 67 | "## Prepare Author Data" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "id": "d2f57601", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "#get edge index into dask df\n", 78 | "writes_ddf = dd.from_pandas(pd.DataFrame(dataset.edge_index('author', 'paper').T,\n", 79 | " columns = [\":START_ID\",\":END_ID\"]), npartitions=100)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 5, 85 | "id": "ee044d60", 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "data": { 90 | "text/html": [ 91 | "
" 173 | ], 174 | "text/plain": [ 175 | " :START_ID :END_ID\n", 176 | "0 10000000000 17776550\n", 177 | "1 10000000000 22232787\n", 178 | "2 10000000000 22359844\n", 179 | "3 10000000000 34644458\n", 180 | "4 10000000000 59079951\n", 181 | "... ... ...\n", 182 | "386022715 10122383107 83933134\n", 183 | "386022716 10122383108 99252845\n", 184 | "386022717 10122383109 99252845\n", 185 | "386022718 10122383110 99252845\n", 186 | "386022719 10122383111 99252845\n", 187 | "\n", 188 | "[386022720 rows x 2 columns]" 189 | ] 190 | }, 191 | "execution_count": 5, 192 | "metadata": {}, 193 | "output_type": "execute_result" 194 | } 195 | ], 196 | "source": [ 197 | "# offset author index (start id)\n", 198 | "writes_ddf[':START_ID'] = writes_ddf[':START_ID']+ AUTHOR_ID_OFFSET\n", 199 | "writes_ddf.compute()" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 6, 205 | "id": "82e17e36", 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "name": "stderr", 210 | "output_type": "stream", 211 | "text": [ 212 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/dask/dataframe/io/csv.py:916: FutureWarning: The 'scheduler' keyword argument for `to_csv()` is deprecated andwill be removed in a future version. Please use the `compute_kwargs` argument instead. 
For example, df.to_csv(..., compute_kwargs={scheduler: processes})\n", 213 | " warn(\n" 214 | ] 215 | }, 216 | { 217 | "data": { 218 | "text/plain": [ 219 | "['/data/demo-load/writes-00.csv',\n", 220 | " '/data/demo-load/writes-01.csv',\n", 221 | " '/data/demo-load/writes-02.csv',\n", 222 | " '/data/demo-load/writes-03.csv',\n", 223 | " '/data/demo-load/writes-04.csv',\n", 224 | " '/data/demo-load/writes-05.csv',\n", 225 | " '/data/demo-load/writes-06.csv',\n", 226 | " '/data/demo-load/writes-07.csv',\n", 227 | " '/data/demo-load/writes-08.csv',\n", 228 | " '/data/demo-load/writes-09.csv',\n", 229 | " '/data/demo-load/writes-10.csv',\n", 230 | " '/data/demo-load/writes-11.csv',\n", 231 | " '/data/demo-load/writes-12.csv',\n", 232 | " '/data/demo-load/writes-13.csv',\n", 233 | " '/data/demo-load/writes-14.csv',\n", 234 | " '/data/demo-load/writes-15.csv',\n", 235 | " '/data/demo-load/writes-16.csv',\n", 236 | " '/data/demo-load/writes-17.csv',\n", 237 | " '/data/demo-load/writes-18.csv',\n", 238 | " '/data/demo-load/writes-19.csv',\n", 239 | " '/data/demo-load/writes-20.csv',\n", 240 | " '/data/demo-load/writes-21.csv',\n", 241 | " '/data/demo-load/writes-22.csv',\n", 242 | " '/data/demo-load/writes-23.csv',\n", 243 | " '/data/demo-load/writes-24.csv',\n", 244 | " '/data/demo-load/writes-25.csv',\n", 245 | " '/data/demo-load/writes-26.csv',\n", 246 | " '/data/demo-load/writes-27.csv',\n", 247 | " '/data/demo-load/writes-28.csv',\n", 248 | " '/data/demo-load/writes-29.csv',\n", 249 | " '/data/demo-load/writes-30.csv',\n", 250 | " '/data/demo-load/writes-31.csv',\n", 251 | " '/data/demo-load/writes-32.csv',\n", 252 | " '/data/demo-load/writes-33.csv',\n", 253 | " '/data/demo-load/writes-34.csv',\n", 254 | " '/data/demo-load/writes-35.csv',\n", 255 | " '/data/demo-load/writes-36.csv',\n", 256 | " '/data/demo-load/writes-37.csv',\n", 257 | " '/data/demo-load/writes-38.csv',\n", 258 | " '/data/demo-load/writes-39.csv',\n", 259 | " '/data/demo-load/writes-40.csv',\n", 260 
| " '/data/demo-load/writes-41.csv',\n", 261 | " '/data/demo-load/writes-42.csv',\n", 262 | " '/data/demo-load/writes-43.csv',\n", 263 | " '/data/demo-load/writes-44.csv',\n", 264 | " '/data/demo-load/writes-45.csv',\n", 265 | " '/data/demo-load/writes-46.csv',\n", 266 | " '/data/demo-load/writes-47.csv',\n", 267 | " '/data/demo-load/writes-48.csv',\n", 268 | " '/data/demo-load/writes-49.csv',\n", 269 | " '/data/demo-load/writes-50.csv',\n", 270 | " '/data/demo-load/writes-51.csv',\n", 271 | " '/data/demo-load/writes-52.csv',\n", 272 | " '/data/demo-load/writes-53.csv',\n", 273 | " '/data/demo-load/writes-54.csv',\n", 274 | " '/data/demo-load/writes-55.csv',\n", 275 | " '/data/demo-load/writes-56.csv',\n", 276 | " '/data/demo-load/writes-57.csv',\n", 277 | " '/data/demo-load/writes-58.csv',\n", 278 | " '/data/demo-load/writes-59.csv',\n", 279 | " '/data/demo-load/writes-60.csv',\n", 280 | " '/data/demo-load/writes-61.csv',\n", 281 | " '/data/demo-load/writes-62.csv',\n", 282 | " '/data/demo-load/writes-63.csv',\n", 283 | " '/data/demo-load/writes-64.csv',\n", 284 | " '/data/demo-load/writes-65.csv',\n", 285 | " '/data/demo-load/writes-66.csv',\n", 286 | " '/data/demo-load/writes-67.csv',\n", 287 | " '/data/demo-load/writes-68.csv',\n", 288 | " '/data/demo-load/writes-69.csv',\n", 289 | " '/data/demo-load/writes-70.csv',\n", 290 | " '/data/demo-load/writes-71.csv',\n", 291 | " '/data/demo-load/writes-72.csv',\n", 292 | " '/data/demo-load/writes-73.csv',\n", 293 | " '/data/demo-load/writes-74.csv',\n", 294 | " '/data/demo-load/writes-75.csv',\n", 295 | " '/data/demo-load/writes-76.csv',\n", 296 | " '/data/demo-load/writes-77.csv',\n", 297 | " '/data/demo-load/writes-78.csv',\n", 298 | " '/data/demo-load/writes-79.csv',\n", 299 | " '/data/demo-load/writes-80.csv',\n", 300 | " '/data/demo-load/writes-81.csv',\n", 301 | " '/data/demo-load/writes-82.csv',\n", 302 | " '/data/demo-load/writes-83.csv',\n", 303 | " '/data/demo-load/writes-84.csv',\n", 304 | " 
'/data/demo-load/writes-85.csv',\n", 305 | " '/data/demo-load/writes-86.csv',\n", 306 | " '/data/demo-load/writes-87.csv',\n", 307 | " '/data/demo-load/writes-88.csv',\n", 308 | " '/data/demo-load/writes-89.csv',\n", 309 | " '/data/demo-load/writes-90.csv',\n", 310 | " '/data/demo-load/writes-91.csv',\n", 311 | " '/data/demo-load/writes-92.csv',\n", 312 | " '/data/demo-load/writes-93.csv',\n", 313 | " '/data/demo-load/writes-94.csv',\n", 314 | " '/data/demo-load/writes-95.csv',\n", 315 | " '/data/demo-load/writes-96.csv',\n", 316 | " '/data/demo-load/writes-97.csv',\n", 317 | " '/data/demo-load/writes-98.csv',\n", 318 | " '/data/demo-load/writes-99.csv']" 319 | ] 320 | }, 321 | "execution_count": 6, 322 | "metadata": {}, 323 | "output_type": "execute_result" 324 | } 325 | ], 326 | "source": [ 327 | "# write edge indexes (to test area)\n", 328 | "writes_ddf.to_csv(f'{ROOT_DATA_DIR}/demo-load/writes-*.csv', \n", 329 | " header_first_partition_only=True, index=False, scheduler=\"processes\")" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 7, 335 | "id": "4787299d", 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "data": { 340 | "text/html": [ 341 | "
\n", 410 | "
" 411 | ], 412 | "text/plain": [ 413 | " ogbIndex:ID\n", 414 | "0 10000000000\n", 415 | "13 10000000001\n", 416 | "27 10000000002\n", 417 | "190 10000000003\n", 418 | "212 10000000004\n", 419 | "... ...\n", 420 | "386022715 10122383107\n", 421 | "386022716 10122383108\n", 422 | "386022717 10122383109\n", 423 | "386022718 10122383110\n", 424 | "386022719 10122383111\n", 425 | "\n", 426 | "[122383112 rows x 1 columns]" 427 | ] 428 | }, 429 | "execution_count": 7, 430 | "metadata": {}, 431 | "output_type": "execute_result" 432 | } 433 | ], 434 | "source": [ 435 | "# get nodes out and deduped\n", 436 | "authors_ddf = writes_ddf.drop([\":END_ID\"], axis=1).drop_duplicates(subset=[\":START_ID\"]).repartition(npartitions=100)\\\n", 437 | " .rename(columns={\":START_ID\":\"ogbIndex:ID\"})\n", 438 | "authors_ddf.compute()" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 8, 444 | "id": "35189b24", 445 | "metadata": {}, 446 | "outputs": [ 447 | { 448 | "data": { 449 | "text/plain": [ 450 | "['/data/demo-load/authors-00.csv',\n", 451 | " '/data/demo-load/authors-01.csv',\n", 452 | " '/data/demo-load/authors-02.csv',\n", 453 | " '/data/demo-load/authors-03.csv',\n", 454 | " '/data/demo-load/authors-04.csv',\n", 455 | " '/data/demo-load/authors-05.csv',\n", 456 | " '/data/demo-load/authors-06.csv',\n", 457 | " '/data/demo-load/authors-07.csv',\n", 458 | " '/data/demo-load/authors-08.csv',\n", 459 | " '/data/demo-load/authors-09.csv',\n", 460 | " '/data/demo-load/authors-10.csv',\n", 461 | " '/data/demo-load/authors-11.csv',\n", 462 | " '/data/demo-load/authors-12.csv',\n", 463 | " '/data/demo-load/authors-13.csv',\n", 464 | " '/data/demo-load/authors-14.csv',\n", 465 | " '/data/demo-load/authors-15.csv',\n", 466 | " '/data/demo-load/authors-16.csv',\n", 467 | " '/data/demo-load/authors-17.csv',\n", 468 | " '/data/demo-load/authors-18.csv',\n", 469 | " '/data/demo-load/authors-19.csv',\n", 470 | " '/data/demo-load/authors-20.csv',\n", 471 | " 
'/data/demo-load/authors-21.csv',\n", 472 | " '/data/demo-load/authors-22.csv',\n", 473 | " '/data/demo-load/authors-23.csv',\n", 474 | " '/data/demo-load/authors-24.csv',\n", 475 | " '/data/demo-load/authors-25.csv',\n", 476 | " '/data/demo-load/authors-26.csv',\n", 477 | " '/data/demo-load/authors-27.csv',\n", 478 | " '/data/demo-load/authors-28.csv',\n", 479 | " '/data/demo-load/authors-29.csv',\n", 480 | " '/data/demo-load/authors-30.csv',\n", 481 | " '/data/demo-load/authors-31.csv',\n", 482 | " '/data/demo-load/authors-32.csv',\n", 483 | " '/data/demo-load/authors-33.csv',\n", 484 | " '/data/demo-load/authors-34.csv',\n", 485 | " '/data/demo-load/authors-35.csv',\n", 486 | " '/data/demo-load/authors-36.csv',\n", 487 | " '/data/demo-load/authors-37.csv',\n", 488 | " '/data/demo-load/authors-38.csv',\n", 489 | " '/data/demo-load/authors-39.csv',\n", 490 | " '/data/demo-load/authors-40.csv',\n", 491 | " '/data/demo-load/authors-41.csv',\n", 492 | " '/data/demo-load/authors-42.csv',\n", 493 | " '/data/demo-load/authors-43.csv',\n", 494 | " '/data/demo-load/authors-44.csv',\n", 495 | " '/data/demo-load/authors-45.csv',\n", 496 | " '/data/demo-load/authors-46.csv',\n", 497 | " '/data/demo-load/authors-47.csv',\n", 498 | " '/data/demo-load/authors-48.csv',\n", 499 | " '/data/demo-load/authors-49.csv',\n", 500 | " '/data/demo-load/authors-50.csv',\n", 501 | " '/data/demo-load/authors-51.csv',\n", 502 | " '/data/demo-load/authors-52.csv',\n", 503 | " '/data/demo-load/authors-53.csv',\n", 504 | " '/data/demo-load/authors-54.csv',\n", 505 | " '/data/demo-load/authors-55.csv',\n", 506 | " '/data/demo-load/authors-56.csv',\n", 507 | " '/data/demo-load/authors-57.csv',\n", 508 | " '/data/demo-load/authors-58.csv',\n", 509 | " '/data/demo-load/authors-59.csv',\n", 510 | " '/data/demo-load/authors-60.csv',\n", 511 | " '/data/demo-load/authors-61.csv',\n", 512 | " '/data/demo-load/authors-62.csv',\n", 513 | " '/data/demo-load/authors-63.csv',\n", 514 | " 
'/data/demo-load/authors-64.csv',\n", 515 | " '/data/demo-load/authors-65.csv',\n", 516 | " '/data/demo-load/authors-66.csv',\n", 517 | " '/data/demo-load/authors-67.csv',\n", 518 | " '/data/demo-load/authors-68.csv',\n", 519 | " '/data/demo-load/authors-69.csv',\n", 520 | " '/data/demo-load/authors-70.csv',\n", 521 | " '/data/demo-load/authors-71.csv',\n", 522 | " '/data/demo-load/authors-72.csv',\n", 523 | " '/data/demo-load/authors-73.csv',\n", 524 | " '/data/demo-load/authors-74.csv',\n", 525 | " '/data/demo-load/authors-75.csv',\n", 526 | " '/data/demo-load/authors-76.csv',\n", 527 | " '/data/demo-load/authors-77.csv',\n", 528 | " '/data/demo-load/authors-78.csv',\n", 529 | " '/data/demo-load/authors-79.csv',\n", 530 | " '/data/demo-load/authors-80.csv',\n", 531 | " '/data/demo-load/authors-81.csv',\n", 532 | " '/data/demo-load/authors-82.csv',\n", 533 | " '/data/demo-load/authors-83.csv',\n", 534 | " '/data/demo-load/authors-84.csv',\n", 535 | " '/data/demo-load/authors-85.csv',\n", 536 | " '/data/demo-load/authors-86.csv',\n", 537 | " '/data/demo-load/authors-87.csv',\n", 538 | " '/data/demo-load/authors-88.csv',\n", 539 | " '/data/demo-load/authors-89.csv',\n", 540 | " '/data/demo-load/authors-90.csv',\n", 541 | " '/data/demo-load/authors-91.csv',\n", 542 | " '/data/demo-load/authors-92.csv',\n", 543 | " '/data/demo-load/authors-93.csv',\n", 544 | " '/data/demo-load/authors-94.csv',\n", 545 | " '/data/demo-load/authors-95.csv',\n", 546 | " '/data/demo-load/authors-96.csv',\n", 547 | " '/data/demo-load/authors-97.csv',\n", 548 | " '/data/demo-load/authors-98.csv',\n", 549 | " '/data/demo-load/authors-99.csv']" 550 | ] 551 | }, 552 | "execution_count": 8, 553 | "metadata": {}, 554 | "output_type": "execute_result" 555 | } 556 | ], 557 | "source": [ 558 | "# write nodes\n", 559 | "authors_ddf.to_csv(f'{ROOT_DATA_DIR}/demo-load/authors-*.csv', \n", 560 | " header_first_partition_only=True, index=False, scheduler=\"processes\")" 561 | ] 562 | }, 563 | { 564 | 
"cell_type": "markdown", 565 | "id": "f617c4ad", 566 | "metadata": {}, 567 | "source": [ 568 | "## Prepare Institution Data" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": 9, 574 | "id": "a328081a", 575 | "metadata": {}, 576 | "outputs": [], 577 | "source": [ 578 | "#get edge index into dask df\n", 579 | "affiliated_with_ddf = dd.from_pandas(pd.DataFrame(dataset.edge_index('author', 'institution').T, \n", 580 | " columns = [\":START_ID\",\":END_ID\"]), npartitions=100)" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 10, 586 | "id": "6548f11c", 587 | "metadata": {}, 588 | "outputs": [ 589 | { 590 | "data": { 591 | "text/html": [ 592 | "
\n", 673 | "
" 674 | ], 675 | "text/plain": [ 676 | " :START_ID :END_ID\n", 677 | "0 10000000002 20000000000\n", 678 | "1 10000000003 20000000000\n", 679 | "2 10000000004 20000000245\n", 680 | "3 10000000009 20000000001\n", 681 | "4 10000000010 20000000002\n", 682 | "... ... ...\n", 683 | "44592581 10122383075 20000000641\n", 684 | "44592582 10122383082 20000000720\n", 685 | "44592583 10122383086 20000000577\n", 686 | "44592584 10122383094 20000004054\n", 687 | "44592585 10122383104 20000008955\n", 688 | "\n", 689 | "[44592586 rows x 2 columns]" 690 | ] 691 | }, 692 | "execution_count": 10, 693 | "metadata": {}, 694 | "output_type": "execute_result" 695 | } 696 | ], 697 | "source": [ 698 | "# offset author and institution index (start id)\n", 699 | "affiliated_with_ddf[':START_ID'] = affiliated_with_ddf[':START_ID'] + AUTHOR_ID_OFFSET\n", 700 | "affiliated_with_ddf[':END_ID'] = affiliated_with_ddf[':END_ID'] + INSTITUTION_ID_OFFSET\n", 701 | "affiliated_with_ddf.compute()" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": 11, 707 | "id": "85d5078d", 708 | "metadata": {}, 709 | "outputs": [ 710 | { 711 | "data": { 712 | "text/plain": [ 713 | "['/data/demo-load/affiliated_with-00.csv',\n", 714 | " '/data/demo-load/affiliated_with-01.csv',\n", 715 | " '/data/demo-load/affiliated_with-02.csv',\n", 716 | " '/data/demo-load/affiliated_with-03.csv',\n", 717 | " '/data/demo-load/affiliated_with-04.csv',\n", 718 | " '/data/demo-load/affiliated_with-05.csv',\n", 719 | " '/data/demo-load/affiliated_with-06.csv',\n", 720 | " '/data/demo-load/affiliated_with-07.csv',\n", 721 | " '/data/demo-load/affiliated_with-08.csv',\n", 722 | " '/data/demo-load/affiliated_with-09.csv',\n", 723 | " '/data/demo-load/affiliated_with-10.csv',\n", 724 | " '/data/demo-load/affiliated_with-11.csv',\n", 725 | " '/data/demo-load/affiliated_with-12.csv',\n", 726 | " '/data/demo-load/affiliated_with-13.csv',\n", 727 | " '/data/demo-load/affiliated_with-14.csv',\n", 728 | " 
'/data/demo-load/affiliated_with-15.csv',\n", 729 | " '/data/demo-load/affiliated_with-16.csv',\n", 730 | " '/data/demo-load/affiliated_with-17.csv',\n", 731 | " '/data/demo-load/affiliated_with-18.csv',\n", 732 | " '/data/demo-load/affiliated_with-19.csv',\n", 733 | " '/data/demo-load/affiliated_with-20.csv',\n", 734 | " '/data/demo-load/affiliated_with-21.csv',\n", 735 | " '/data/demo-load/affiliated_with-22.csv',\n", 736 | " '/data/demo-load/affiliated_with-23.csv',\n", 737 | " '/data/demo-load/affiliated_with-24.csv',\n", 738 | " '/data/demo-load/affiliated_with-25.csv',\n", 739 | " '/data/demo-load/affiliated_with-26.csv',\n", 740 | " '/data/demo-load/affiliated_with-27.csv',\n", 741 | " '/data/demo-load/affiliated_with-28.csv',\n", 742 | " '/data/demo-load/affiliated_with-29.csv',\n", 743 | " '/data/demo-load/affiliated_with-30.csv',\n", 744 | " '/data/demo-load/affiliated_with-31.csv',\n", 745 | " '/data/demo-load/affiliated_with-32.csv',\n", 746 | " '/data/demo-load/affiliated_with-33.csv',\n", 747 | " '/data/demo-load/affiliated_with-34.csv',\n", 748 | " '/data/demo-load/affiliated_with-35.csv',\n", 749 | " '/data/demo-load/affiliated_with-36.csv',\n", 750 | " '/data/demo-load/affiliated_with-37.csv',\n", 751 | " '/data/demo-load/affiliated_with-38.csv',\n", 752 | " '/data/demo-load/affiliated_with-39.csv',\n", 753 | " '/data/demo-load/affiliated_with-40.csv',\n", 754 | " '/data/demo-load/affiliated_with-41.csv',\n", 755 | " '/data/demo-load/affiliated_with-42.csv',\n", 756 | " '/data/demo-load/affiliated_with-43.csv',\n", 757 | " '/data/demo-load/affiliated_with-44.csv',\n", 758 | " '/data/demo-load/affiliated_with-45.csv',\n", 759 | " '/data/demo-load/affiliated_with-46.csv',\n", 760 | " '/data/demo-load/affiliated_with-47.csv',\n", 761 | " '/data/demo-load/affiliated_with-48.csv',\n", 762 | " '/data/demo-load/affiliated_with-49.csv',\n", 763 | " '/data/demo-load/affiliated_with-50.csv',\n", 764 | " '/data/demo-load/affiliated_with-51.csv',\n", 765 | " 
'/data/demo-load/affiliated_with-52.csv',\n", 766 | " '/data/demo-load/affiliated_with-53.csv',\n", 767 | " '/data/demo-load/affiliated_with-54.csv',\n", 768 | " '/data/demo-load/affiliated_with-55.csv',\n", 769 | " '/data/demo-load/affiliated_with-56.csv',\n", 770 | " '/data/demo-load/affiliated_with-57.csv',\n", 771 | " '/data/demo-load/affiliated_with-58.csv',\n", 772 | " '/data/demo-load/affiliated_with-59.csv',\n", 773 | " '/data/demo-load/affiliated_with-60.csv',\n", 774 | " '/data/demo-load/affiliated_with-61.csv',\n", 775 | " '/data/demo-load/affiliated_with-62.csv',\n", 776 | " '/data/demo-load/affiliated_with-63.csv',\n", 777 | " '/data/demo-load/affiliated_with-64.csv',\n", 778 | " '/data/demo-load/affiliated_with-65.csv',\n", 779 | " '/data/demo-load/affiliated_with-66.csv',\n", 780 | " '/data/demo-load/affiliated_with-67.csv',\n", 781 | " '/data/demo-load/affiliated_with-68.csv',\n", 782 | " '/data/demo-load/affiliated_with-69.csv',\n", 783 | " '/data/demo-load/affiliated_with-70.csv',\n", 784 | " '/data/demo-load/affiliated_with-71.csv',\n", 785 | " '/data/demo-load/affiliated_with-72.csv',\n", 786 | " '/data/demo-load/affiliated_with-73.csv',\n", 787 | " '/data/demo-load/affiliated_with-74.csv',\n", 788 | " '/data/demo-load/affiliated_with-75.csv',\n", 789 | " '/data/demo-load/affiliated_with-76.csv',\n", 790 | " '/data/demo-load/affiliated_with-77.csv',\n", 791 | " '/data/demo-load/affiliated_with-78.csv',\n", 792 | " '/data/demo-load/affiliated_with-79.csv',\n", 793 | " '/data/demo-load/affiliated_with-80.csv',\n", 794 | " '/data/demo-load/affiliated_with-81.csv',\n", 795 | " '/data/demo-load/affiliated_with-82.csv',\n", 796 | " '/data/demo-load/affiliated_with-83.csv',\n", 797 | " '/data/demo-load/affiliated_with-84.csv',\n", 798 | " '/data/demo-load/affiliated_with-85.csv',\n", 799 | " '/data/demo-load/affiliated_with-86.csv',\n", 800 | " '/data/demo-load/affiliated_with-87.csv',\n", 801 | " '/data/demo-load/affiliated_with-88.csv',\n", 802 | " 
'/data/demo-load/affiliated_with-89.csv',\n", 803 | " '/data/demo-load/affiliated_with-90.csv',\n", 804 | " '/data/demo-load/affiliated_with-91.csv',\n", 805 | " '/data/demo-load/affiliated_with-92.csv',\n", 806 | " '/data/demo-load/affiliated_with-93.csv',\n", 807 | " '/data/demo-load/affiliated_with-94.csv',\n", 808 | " '/data/demo-load/affiliated_with-95.csv',\n", 809 | " '/data/demo-load/affiliated_with-96.csv',\n", 810 | " '/data/demo-load/affiliated_with-97.csv',\n", 811 | " '/data/demo-load/affiliated_with-98.csv',\n", 812 | " '/data/demo-load/affiliated_with-99.csv']" 813 | ] 814 | }, 815 | "execution_count": 11, 816 | "metadata": {}, 817 | "output_type": "execute_result" 818 | } 819 | ], 820 | "source": [ 821 | "# write edge indexes (to test area)\n", 822 | "affiliated_with_ddf.to_csv(f'{ROOT_DATA_DIR}/demo-load/affiliated_with-*.csv', \n", 823 | " header_first_partition_only=True, index=False, scheduler=\"processes\")" 824 | ] 825 | }, 826 | { 827 | "cell_type": "code", 828 | "execution_count": 12, 829 | "id": "a74c5974", 830 | "metadata": {}, 831 | "outputs": [ 832 | { 833 | "data": { 834 | "text/html": [ 835 | "
\n", 904 | "
" 905 | ], 906 | "text/plain": [ 907 | " ogbIndex:ID\n", 908 | "0 20000000000\n", 909 | "2 20000000245\n", 910 | "3 20000000001\n", 911 | "4 20000000002\n", 912 | "7 20000000557\n", 913 | "... ...\n", 914 | "44405762 20000025715\n", 915 | "44493495 20000025716\n", 916 | "44502203 20000025717\n", 917 | "44508070 20000025718\n", 918 | "44576465 20000025720\n", 919 | "\n", 920 | "[25721 rows x 1 columns]" 921 | ] 922 | }, 923 | "execution_count": 12, 924 | "metadata": {}, 925 | "output_type": "execute_result" 926 | } 927 | ], 928 | "source": [ 929 | "# get nodes out and deduped\n", 930 | "institutions_ddf = affiliated_with_ddf.drop([\":START_ID\"], axis=1)\\\n", 931 | " .drop_duplicates(subset=[\":END_ID\"]).rename(columns={\":END_ID\":\"ogbIndex:ID\"})\n", 932 | "institutions_ddf.compute()" 933 | ] 934 | }, 935 | { 936 | "cell_type": "code", 937 | "execution_count": 13, 938 | "id": "5ca26e43", 939 | "metadata": {}, 940 | "outputs": [ 941 | { 942 | "data": { 943 | "text/plain": [ 944 | "['/data/demo-load/institution-0.csv']" 945 | ] 946 | }, 947 | "execution_count": 13, 948 | "metadata": {}, 949 | "output_type": "execute_result" 950 | } 951 | ], 952 | "source": [ 953 | "# write nodes\n", 954 | "institutions_ddf.to_csv(f'{ROOT_DATA_DIR}/demo-load/institution-*.csv', \n", 955 | " header_first_partition_only=True, index=False, scheduler=\"processes\")" 956 | ] 957 | }, 958 | { 959 | "cell_type": "code", 960 | "execution_count": null, 961 | "id": "acf1cbb6", 962 | "metadata": {}, 963 | "outputs": [], 964 | "source": [] 965 | } 966 | ], 967 | "metadata": { 968 | "kernelspec": { 969 | "display_name": "Python 3 (ipykernel)", 970 | "language": "python", 971 | "name": "python3" 972 | }, 973 | "language_info": { 974 | "codemirror_mode": { 975 | "name": "ipython", 976 | "version": 3 977 | }, 978 | "file_extension": ".py", 979 | "mimetype": "text/x-python", 980 | "name": "python", 981 | "nbconvert_exporter": "python", 982 | "pygments_lexer": "ipython3", 983 | "version": "3.9.7" 
984 | } 985 | }, 986 | "nbformat": 4, 987 | "nbformat_minor": 5 988 | } 989 | -------------------------------------------------------------------------------- /part5-admin-import.md: -------------------------------------------------------------------------------- 1 | # Part 5: Neo4j Admin Import 2 | 3 | 4 | This section covers running the admin import command. Admin import is a command-line tool, so no Python is needed here; we run it from the terminal. 5 | 6 | For our example, it can be run from the shell on the local Neo4j instance like so: 7 | 8 | ```bash 9 | bin/neo4j-admin import --database=ogblsc --id-type=INTEGER \ 10 | --nodes=Paper=/data/demo-load/paper-c\\d-\\d+.csv \ 11 | --nodes=Author=/data/demo-load/authors-\\d+.csv \ 12 | --nodes=Institution=/data/demo-load/institution-\\d+.csv \ 13 | --relationships=CITES=/data/demo-load/cited-\\d+.csv \ 14 | --relationships=WRITES=/data/demo-load/writes-\\d+.csv \ 15 | --relationships=AFFILIATED_WITH=/data/demo-load/affiliated_with-\\d+.csv 16 | ``` 17 | 18 | The command above creates a new database called 'ogblsc', then loads and links all the nodes and relationships from the CSV files. The `\\d+` in the file names is a regex that matches one or more digits, so admin import picks up all the CSVs we created in parts 3 and 4, e.g. `authors-00.csv`, `authors-01.csv`, and so on. 19 | 20 | If you changed the root directory in the previous notebooks or moved the files to a different location, make sure the paths above match. Also note: if you decrease the chunk size from 20 million in part 3, you may need to tweak the regex in the Paper line above to accommodate two digits after the 'c'.
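Before launching a long import, the file-name patterns can be sanity-checked offline. The Python sketch below only approximates what `neo4j-admin` does: it assumes each pattern is matched against the whole file name (as `re.fullmatch` does) after the shell has turned `\\d` into `\d`. The author file names follow the `authors-00.csv` to `authors-99.csv` output from part 4; the `paper-c<chunk>-<part>.csv` names and the `matching` helper are hypothetical stand-ins for illustration.

```python
import re

# Patterns as neo4j-admin would see them once the shell turns "\\d" into "\d".
# Assumption: a pattern must match the whole file name, like re.fullmatch.
AUTHOR_PATTERN = r"authors-\d+.csv"
PAPER_PATTERN = r"paper-c\d-\d+.csv"        # exactly one digit after the 'c'
PAPER_PATTERN_WIDE = r"paper-c\d+-\d+.csv"  # one or more digits after the 'c'


def matching(pattern, names):
    """Hypothetical helper: return the file names the regex matches in full."""
    return [name for name in names if re.fullmatch(pattern, name)]


# Dask wrote 100 author partitions named authors-00.csv to authors-99.csv.
author_files = [f"authors-{i:02d}.csv" for i in range(100)]

# Hypothetical paper files: 12 chunks (c0 to c11) with two partitions each.
paper_files = [f"paper-c{c}-{p:02d}.csv" for c in range(12) for p in range(2)]

author_hits = matching(AUTHOR_PATTERN, author_files)
paper_hits = matching(PAPER_PATTERN, paper_files)
paper_hits_wide = matching(PAPER_PATTERN_WIDE, paper_files)

print(len(author_hits), len(paper_hits), len(paper_hits_wide))  # 100 20 24
```

With 12 hypothetical chunks, the one-digit pattern silently skips `c10` and `c11`, which is the case the chunk-size note warns about; using `\d+` after the 'c' covers both one- and two-digit chunk numbers.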
21 | 22 | Detailed Documentation for Admin Import can be found at https://neo4j.com/docs/operations-manual/current/tutorial/neo4j-admin-import/ -------------------------------------------------------------------------------- /part7-graph-feature-engineering-in-gds.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b921f6a6", 6 | "metadata": {}, 7 | "source": [ 8 | "# Part 7: Feature Engineering in Neo4j and GDS\n", 9 | "\n", 10 | "This notebook covers:\n", 11 | "\n", 12 | "1. Native Graph Projection with Properties\n", 13 | "2. Generating FastRP Features\n", 14 | "3. Subgraph Projection and Data Export" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "e0d3a4b2", 20 | "metadata": {}, 21 | "source": [ 22 | "## Connection Setup and Helper Functions" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "id": "a079f865", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "from neo4j import GraphDatabase\n", 33 | "HOST = 'neo4j://localhost:7687'\n", 34 | "USERNAME = 'neo4j'\n", 35 | "DATABASE = 'ogblsc'\n", 36 | "PASSWORD = 'neo'\n", 37 | "\n", 38 | "def run(driver, query, params=None):\n", 39 | " with driver.session(database=DATABASE) as session:\n", 40 | " if params is not None:\n", 41 | " return [r for r in session.run(query, params)]\n", 42 | " else:\n", 43 | " return [r for r in session.run(query)]\n", 44 | "\n", 45 | "def clear_graph(driver, graph_name):\n", 46 | " if run(driver, f\"CALL gds.graph.exists('{graph_name}') YIELD exists RETURN exists\")[0].get(\"exists\"):\n", 47 | " run(driver, f\"CALL gds.graph.drop('{graph_name}')\")\n", 48 | "\n", 49 | "def clear_all_graphs(driver):\n", 50 | " graphs = run(driver, 'CALL gds.graph.list() YIELD graphName RETURN collect(graphName) as graphs')[0].get('graphs')\n", 51 | " for g in graphs:\n", 52 | " run(driver, f\"CALL gds.graph.drop('{g}')\")" 53 | ] 54 | }, 55 | 
{ 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "id": "68ec8667", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "driver = GraphDatabase.driver(HOST, auth=(USERNAME, PASSWORD))" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "69ca54e2", 68 | "metadata": {}, 69 | "source": [ 70 | "## Native Graph Projection with Properties\n", 71 | "\n", 72 | "We will project just the Paper nodes and CITES relationships for purposes of this demo." 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "id": "69254d36", 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/plain": [ 84 | "[]" 85 | ] 86 | }, 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "run(driver, '''\n", 94 | " CALL gds.graph.create('proj-features',\n", 95 | " {Paper:{properties: ['subject', 'encoding']}},\n", 96 | " {CITES:{orientation:'UNDIRECTED'}},\n", 97 | " {readConcurrency: 60}\n", 98 | " ) YIELD nodeCount, relationshipCount, createMillis\n", 99 | "''')" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "c3e3832b", 105 | "metadata": {}, 106 | "source": [ 107 | "## Generating FastRP Features\n", 108 | "\n", 109 | "Fast Random Projection, or FastRP for short, is a node embedding algorithm. Node embedding algorithms compute low-dimensional vector representations of nodes in a graph. These vectors, also called embeddings, can be used as features for machine learning models among other tasks such as visualization and EDA.\n", 110 | "\n", 111 | "FastRP leverages the concept of sparse projections to significantly scale the computation of embeddings on larger graphs. 
More information can be found in [our documentation](https://neo4j.com/docs/graph-data-science/current/algorithms/fastrp/).\n", 112 | "\n", 113 | "In our example below we use a `propertyRatio` of 0.5 (50%), which means half of the embedding dimensions are initialized from a random linear combination that uses the RoBERTa encoding components as weights. In other words, we use both the graph structure and the NLP encodings to generate (hopefully predictive) node features. " 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 4, 119 | "id": "de1bc605", 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124 | "text/plain": [ 125 | "[]" 126 | ] 127 | }, 128 | "execution_count": 4, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "run(driver, '''\n", 135 | " CALL gds.fastRP.mutate('proj-features',\n", 136 | " {\n", 137 | " embeddingDimension: 256,\n", 138 | " randomSeed: 7474,\n", 139 | " propertyRatio: 0.5,\n", 140 | " featureProperties: ['encoding'],\n", 141 | " mutateProperty: 'embedding',\n", 142 | " concurrency: 60\n", 143 | " }\n", 144 | " ) YIELD nodePropertiesWritten, createMillis, computeMillis, mutateMillis\n", 145 | "''')" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "dfb0a1e6", 151 | "metadata": {}, 152 | "source": [ 153 | "## Subgraph Projection and Data Export\n", 154 | "\n", 155 | "To test predicting subject labels with the new (FastRP) graph features, we only need to export the fraction of papers with known labels. We can use a subgraph projection to filter down to these papers. We can then export the subgraph to CSV."
156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 5, 161 | "id": "657e0cef", 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "data": { 166 | "text/plain": [ 167 | "[]" 168 | ] 169 | }, 170 | "execution_count": 5, 171 | "metadata": {}, 172 | "output_type": "execute_result" 173 | } 174 | ], 175 | "source": [ 176 | "# subgraph projection\n", 177 | "run(driver, '''\n", 178 | " CALL gds.beta.graph.create.subgraph(\n", 179 | " 'proj-features-labeled',\n", 180 | " 'proj-features',\n", 181 | " 'n.subject > -1',\n", 182 | " '*',\n", 183 | " {concurrency: 60}\n", 184 | " ) YIELD nodeCount, createMillis\n", 185 | "''')" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 7, 191 | "id": "c709cbd3", 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "data": { 196 | "text/plain": [ 197 | "[]" 198 | ] 199 | }, 200 | "execution_count": 7, 201 | "metadata": {}, 202 | "output_type": "execute_result" 203 | } 204 | ], 205 | "source": [ 206 | "# csv export\n", 207 | "run(driver, '''\n", 208 | " CALL gds.beta.graph.export.csv('proj-features-labeled', {\n", 209 | " exportName: 'proj-features-labeled',\n", 210 | " additionalNodeProperties: ['ogbIndex', 'split_segment', 'subject_status', 'year'],\n", 211 | " writeConcurrency: 16\n", 212 | " }) YIELD writeMillis\n", 213 | "''')" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 8, 219 | "id": "e54b64f9", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "#remove graph projections to clean up\n", 224 | "clear_graph(driver, 'proj-features-labeled')\n", 225 | "clear_graph(driver, 'proj-features')" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "id": "4850496b", 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [] 235 | } 236 | ], 237 | "metadata": { 238 | "kernelspec": { 239 | "display_name": "Python 3 (ipykernel)", 240 | "language": "python", 241 | "name": "python3" 242 | }, 243 | 
"language_info": { 244 | "codemirror_mode": { 245 | "name": "ipython", 246 | "version": 3 247 | }, 248 | "file_extension": ".py", 249 | "mimetype": "text/x-python", 250 | "name": "python", 251 | "nbconvert_exporter": "python", 252 | "pygments_lexer": "ipython3", 253 | "version": "3.9.7" 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 5 258 | } 259 | -------------------------------------------------------------------------------- /part8-graph-feature-model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ffb4e37f", 6 | "metadata": {}, 7 | "source": [ 8 | "# Part 8: \"Graph\" Model - Logistic Regression with FastRP Embeddings" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "6664d358", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd\n", 19 | "import numpy as np\n", 20 | "from sklearn.linear_model import LogisticRegression\n", 21 | "import glob" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "6c2b5937", 27 | "metadata": {}, 28 | "source": [ 29 | "## Data Load and Formatting" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "2104a91b", 36 | "metadata": { 37 | "scrolled": false 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "path = '/data/neo-export/export/proj-features-labeled/' \n", 42 | "all_files = glob.glob(path + \"nodes_Paper_[0-9]*.csv\")\n", 43 | "\n", 44 | "li = []\n", 45 | "\n", 46 | "for filename in all_files:\n", 47 | " df = pd.read_csv(filename, header=None, \n", 48 | " names = [\"nodeId\", \"embedding\",\"encoding\",\"ogbIndex\", \"split_segment\",\"subject\",\n", 49 | " \"subject_status\",\"year\"])\n", 50 | " li.append(df)\n", 51 | "\n", 52 | "\n", 53 | "papers_df = pd.concat(li, axis=0, ignore_index=True)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": "370acd2c", 59 | "metadata": {}, 60 | 
"source": [ 61 | "#### Below breakouts of split segment and subject labels should match those from part 2" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "id": "5a6b05c5", 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/html": [ 73 | "
\n", 109 | "
" 110 | ], 111 | "text/plain": [ 112 | " ogbIndex\n", 113 | "split_segment \n", 114 | "TRAIN 1112392\n", 115 | "VALIDATE 138949" 116 | ] 117 | }, 118 | "execution_count": 3, 119 | "metadata": {}, 120 | "output_type": "execute_result" 121 | } 122 | ], 123 | "source": [ 124 | "papers_df[['split_segment', 'ogbIndex']].groupby('split_segment').count()" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 4, 130 | "id": "ce172791", 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "data": { 135 | "text/html": [ 136 | "
" 210 | ], 211 | "text/plain": [ 212 | " ogbIndex\n", 213 | "subject \n", 214 | "0 28041\n", 215 | "1 2856\n", 216 | "2 3907\n", 217 | "3 1530\n", 218 | "4 1910\n", 219 | "... ...\n", 220 | "148 865\n", 221 | "149 815\n", 222 | "150 837\n", 223 | "151 22696\n", 224 | "152 1139\n", 225 | "\n", 226 | "[153 rows x 1 columns]" 227 | ] 228 | }, 229 | "execution_count": 4, 230 | "metadata": {}, 231 | "output_type": "execute_result" 232 | } 233 | ], 234 | "source": [ 235 | "papers_df[['subject', 'ogbIndex']].groupby('subject').count()" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 5, 241 | "id": "5334c395", 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "name": "stderr", 246 | "output_type": "stream", 247 | "text": [ 248 | "/home/ubuntu/.conda/envs/graph2/lib/python3.9/site-packages/pandas/core/frame.py:3641: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", 249 | " self[k1] = value[k2]\n" 250 | ] 251 | }, 252 | { 253 | "data": { 254 | "text/html": [ 255 | "
" 565 | ], 566 | "text/plain": [ 567 | " nodeId embedding \\\n", 568 | "0 205580338 0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.... \n", 569 | "1 205580373 0.0013625429;-9.808112E-4;4.393044E-4;-0.00290... \n", 570 | "2 205580582 -0.0027561677;0.0019617418;-0.0030122166;0.001... \n", 571 | "3 205580641 5.627003E-4;-6.950726E-4;-0.0023442905;5.60485... \n", 572 | "4 205581193 -0.001374663;0.003907469;0.0010971766;-0.00113... \n", 573 | "... ... ... \n", 574 | "1251336 189297276 0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.... \n", 575 | "1251337 189297335 -2.1663477E-4;0.0010571239;0.001026029;-0.0010... \n", 576 | "1251338 189297681 0.002675809;0.0013950182;2.3017207E-4;-0.00165... \n", 577 | "1251339 189297829 0.0020234333;6.1786047E-4;-7.3964836E-4;-9.056... \n", 578 | "1251340 189297865 -5.5444025E-4;-0.0016479483;-0.0016374378;8.50... \n", 579 | "\n", 580 | " encoding ogbIndex \\\n", 581 | "0 -4.32037;-0.81883;2.67975;2.31663;-3.12459;-2.... 83179065 \n", 582 | "1 -7.31162;1.27658;1.8888;-0.87184;2.52357;-0.53... 83179100 \n", 583 | "2 1.62784;-2.5646;-0.07218;-1.67851;3.02409;-1.0... 83181475 \n", 584 | "3 2.75564;-3.55;4.72421;-0.70463;-0.09119;-2.252... 83181534 \n", 585 | "4 -2.8812;1.11153;3.17897;0.28759;1.80633;0.6327... 83179374 \n", 586 | "... ... ... \n", 587 | "1251336 4.24787;-5.7E-4;-0.42477;-4.18265;2.68804;2.38... 66895631 \n", 588 | "1251337 -0.26447;3.84805;-0.54991;-3.11971;-2.58163;-0... 66895690 \n", 589 | "1251338 1.02664;-1.2927;-2.50304;-3.02091;0.26892;-1.0... 66897904 \n", 590 | "1251339 1.27171;7.01788;0.80883;5.77526;-1.82221;2.676... 66898052 \n", 591 | "1251340 2.84921;-2.64254;1.72713;2.03434;2.50258;5.435... 
66898088 \n", 592 | "\n", 593 | " split_segment subject subject_status year embedding_0 embedding_1 \\\n", 594 | "0 TRAIN 98 KNOWN 2018 0.000000 0.000000 \n", 595 | "1 TRAIN 106 KNOWN 2011 0.001363 -0.000981 \n", 596 | "2 TRAIN 56 KNOWN 2012 -0.002756 0.001962 \n", 597 | "3 TRAIN 128 KNOWN 2010 0.000563 -0.000695 \n", 598 | "4 TRAIN 18 KNOWN 2012 -0.001375 0.003907 \n", 599 | "... ... ... ... ... ... ... \n", 600 | "1251336 TRAIN 90 KNOWN 2014 0.000000 0.000000 \n", 601 | "1251337 TRAIN 75 KNOWN 2018 -0.000217 0.001057 \n", 602 | "1251338 TRAIN 34 KNOWN 2011 0.002676 0.001395 \n", 603 | "1251339 TRAIN 4 KNOWN 2012 0.002023 0.000618 \n", 604 | "1251340 TRAIN 14 KNOWN 2005 -0.000554 -0.001648 \n", 605 | "\n", 606 | " ... embedding_246 embedding_247 embedding_248 embedding_249 \\\n", 607 | "0 ... 0.000000 0.000000 0.000000 0.000000 \n", 608 | "1 ... -0.075629 0.062345 0.263585 -0.277317 \n", 609 | "2 ... -0.177606 -0.167020 -0.138167 0.091243 \n", 610 | "3 ... -0.004462 0.146282 -0.412124 -0.086960 \n", 611 | "4 ... -0.271622 -0.020034 0.248432 0.019380 \n", 612 | "... ... ... ... ... ... \n", 613 | "1251336 ... 0.000000 0.000000 0.000000 0.000000 \n", 614 | "1251337 ... 0.002889 0.013994 -0.170945 0.327328 \n", 615 | "1251338 ... -0.238014 -0.050306 -0.025272 0.053312 \n", 616 | "1251339 ... 0.007818 -0.085743 -0.008575 0.036522 \n", 617 | "1251340 ... 0.176077 0.276622 0.095574 -0.223936 \n", 618 | "\n", 619 | " embedding_250 embedding_251 embedding_252 embedding_253 \\\n", 620 | "0 0.000000 0.000000 0.000000 0.000000 \n", 621 | "1 -0.305459 -0.129912 0.197675 -0.090285 \n", 622 | "2 0.195253 -0.106206 0.181296 -0.052775 \n", 623 | "3 0.144308 -0.066715 -0.148467 -0.114543 \n", 624 | "4 -0.187105 -0.047686 0.106186 -0.070420 \n", 625 | "... ... ... ... ... 
\n", 626 | "1251336 0.000000 0.000000 0.000000 0.000000 \n", 627 | "1251337 0.041509 0.034720 0.121717 -0.018826 \n", 628 | "1251338 0.107086 0.043762 0.017888 -0.369173 \n", 629 | "1251339 0.025726 0.165496 0.146139 -0.049635 \n", 630 | "1251340 -0.172824 -0.044063 -0.142727 0.013934 \n", 631 | "\n", 632 | " embedding_254 embedding_255 \n", 633 | "0 0.000000 0.000000 \n", 634 | "1 0.204399 0.027731 \n", 635 | "2 -0.119006 -0.206935 \n", 636 | "3 -0.368114 -0.107802 \n", 637 | "4 0.294178 0.014652 \n", 638 | "... ... ... \n", 639 | "1251336 0.000000 0.000000 \n", 640 | "1251337 0.230952 0.070828 \n", 641 | "1251338 -0.216644 -0.192462 \n", 642 | "1251339 -0.050294 -0.053110 \n", 643 | "1251340 0.037360 0.039559 \n", 644 | "\n", 645 | "[1251341 rows x 264 columns]" 646 | ] 647 | }, 648 | "execution_count": 5, 649 | "metadata": {}, 650 | "output_type": "execute_result" 651 | } 652 | ], 653 | "source": [ 654 | "# Expand embedding vectors...sorry for the ugly performance warning :( - if you see a better way please recommend!\n", 655 | "def string_to_float(x):\n", 656 | " return np.array(x.split(';')).astype(float)\n", 657 | "papers_df[[f'embedding_{x}' for x in range(256)]] = papers_df.apply(\n", 658 | " lambda row: string_to_float(row.embedding), axis = 1, result_type ='expand')\n", 659 | "papers_df" 660 | ] 661 | }, 662 | { 663 | "cell_type": "code", 664 | "execution_count": 6, 665 | "id": "056cfd87", 666 | "metadata": {}, 667 | "outputs": [], 668 | "source": [ 669 | "X = papers_df[[f'embedding_{x}' for x in range(256)]]\n", 670 | "y = papers_df.subject" 671 | ] 672 | }, 673 | { 674 | "cell_type": "markdown", 675 | "id": "201956ab", 676 | "metadata": {}, 677 | "source": [ 678 | "## Logistic Regression with FastRP Features" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "execution_count": 7, 684 | "id": "486dfa78", 685 | "metadata": {}, 686 | "outputs": [], 687 | "source": [ 688 | "X_train = X[papers_df.split_segment == \"TRAIN\"]\n", 689 | "X_validate = 
X[papers_df.split_segment == \"VALIDATE\"]\n", 690 | "y_train = y[papers_df.split_segment == \"TRAIN\"]\n", 691 | "y_validate = y[papers_df.split_segment == \"VALIDATE\"]" 692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "execution_count": 8, 697 | "id": "25a7041b", 698 | "metadata": {}, 699 | "outputs": [], 700 | "source": [ 701 | "model = LogisticRegression(multi_class='ovr', solver='saga', n_jobs=60)" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": 9, 707 | "id": "6ffa7a40", 708 | "metadata": {}, 709 | "outputs": [ 710 | { 711 | "data": { 712 | "text/plain": [ 713 | "LogisticRegression(multi_class='ovr', n_jobs=60, solver='saga')" 714 | ] 715 | }, 716 | "execution_count": 9, 717 | "metadata": {}, 718 | "output_type": "execute_result" 719 | } 720 | ], 721 | "source": [ 722 | "model.fit(X_train, y_train)" 723 | ] 724 | }, 725 | { 726 | "cell_type": "code", 727 | "execution_count": 10, 728 | "id": "a6617d85", 729 | "metadata": {}, 730 | "outputs": [ 731 | { 732 | "name": "stdout", 733 | "output_type": "stream", 734 | "text": [ 735 | "Accuracy of logistic regression classifier on VALIDATE set: 0.58\n" 736 | ] 737 | } 738 | ], 739 | "source": [ 740 | "print('Accuracy of logistic regression classifier on VALIDATE set: {:.2f}'\\\n", 741 | " .format(model.score(X_validate, y_validate)))" 742 | ] 743 | }, 744 | { 745 | "cell_type": "markdown", 746 | "id": "d82f4bc2", 747 | "metadata": {}, 748 | "source": [ 749 | "#### Note: We were able to increase classification accuracy by about 9 percentage points over Part 2 by substituting the FastRP graph features" 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": null, 755 | "id": "2835554d", 756 | "metadata": {}, 757 | "outputs": [], 758 | "source": [] 759 | } 760 | ], 761 | "metadata": { 762 | "kernelspec": { 763 | "display_name": "Python 3 (ipykernel)", 764 | "language": "python", 765 | "name": "python3" 766 | }, 767 | "language_info": { 768 | "codemirror_mode": { 769 | 
"name": "ipython", 770 | "version": 3 771 | }, 772 | "file_extension": ".py", 773 | "mimetype": "text/x-python", 774 | "name": "python", 775 | "nbconvert_exporter": "python", 776 | "pygments_lexer": "ipython3", 777 | "version": "3.9.7" 778 | } 779 | }, 780 | "nbformat": 4, 781 | "nbformat_minor": 5 782 | } 783 | --------------------------------------------------------------------------------
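The embedding-expansion cell in Part 8 asks for a better way to avoid its `PerformanceWarning`. According to the warning's traceback (`self[k1] = value[k2]`), the warning comes from assigning 256 new columns through `papers_df[[...]] = ...`, which pandas inserts one column at a time. One possible alternative is sketched below. It assumes that every `embedding` value is a `;`-delimited string of floats, as exported in Part 7. The sketch splits the whole column at once with `Series.str.split(expand=True)`, casts the result to float, and attaches the new columns with a single `pd.concat`. The helper name `expand_embedding_column` is hypothetical and is not part of this repo.

```python
import pandas as pd


def expand_embedding_column(df: pd.DataFrame,
                            column: str = "embedding",
                            prefix: str = "embedding_") -> pd.DataFrame:
    """Expand a ';'-delimited vector column into one float column per component.

    Hypothetical helper (not part of this repo). The new columns are built as
    one frame and joined with a single pd.concat, rather than assigned via
    df[[...]] = ..., which inserts them one by one and fragments the frame.
    """
    expanded = (
        df[column]
        .str.split(";", expand=True)  # one column per vector component
        .astype(float)                # parses strings such as '-9.808112E-4'
        .add_prefix(prefix)           # integer columns 0..n-1 -> embedding_0..embedding_{n-1}
    )
    # str.split keeps df's index, so the concat lines rows up correctly
    return pd.concat([df, expanded], axis=1)
```

In the Part 8 notebook, `papers_df = expand_embedding_column(papers_df)` would replace the `string_to_float` cell, and the existing `X = papers_df[[f'embedding_{x}' for x in range(256)]]` cell would work unchanged. With roughly 1.25 million rows, the intermediate frame of strings created by `str.split` is large, so check that there is enough memory headroom before running it on the full export.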