├── README.md ├── data ├── gas │ └── insert-data-here ├── rdj │ ├── GE.data │ ├── INTEL.data │ ├── MS.data │ └── raw.data └── stocks │ ├── fb │ └── close.vals │ └── goog │ └── close.vals ├── dirac_phi.py ├── gen_data.py ├── gen_scripts ├── clayton.py ├── frank.py └── joe.py ├── main.py ├── phi_listing.py ├── train.py └── train_scripts ├── boston └── train.py ├── clayton.py ├── clayton_noisy.py ├── frank.py ├── gas ├── train.py ├── train_with_clayton.py ├── train_with_frank.py └── train_with_gumbel.py ├── joe.py ├── rdj ├── train.py ├── train_with_clayton.py ├── train_with_frank.py └── train_with_gumbel.py └── stocks ├── train.py ├── train_with_clayton.py ├── train_with_frank.py └── train_with_gumbel.py /README.md: -------------------------------------------------------------------------------- 1 | This folder contains a minimal working example (MWE) for ACNet. 2 | 3 | We first describe the core code for ACNet. Then, we describe how to run experiments 4 | for one (sub)-experiment for each of the 3 subsections in Section 5 and should be 5 | sufficient for most purposes. 6 | 7 | Code for other datasets are described in the final section. 8 | 9 | Plotting of graphs was done by separately sampling and evaluating pdfs in ACNet, 10 | followed by calling pgfgen. These components are not include in this MWE. 11 | 12 | Files 13 | ===== 14 | The main files are `main.py` and `dirac_phi.py`. 15 | 16 | `main.py` 17 | This contains the code performing the generation of the computational 18 | graph, Newton's root finding method, bisection method (for root finding in the 19 | case of conditional sampling), as well as the implementation for the 20 | gradients of inverses. 21 | 22 | `dirac_phi.py` 23 | This contains the code defining the specific structure of phi_NN. 24 | 25 | Dependencies: 26 | pytorch 27 | numpy 28 | matplotlib 29 | scipy 30 | optional: sacred, used for monitoring and keeping track of experiments 31 | optional: scikit-learn, used only to download the boston housing dataset. 32 | 33 | Other files are: 34 | `phi_listing.py` 35 | Contains definitions for other phi's for the commonly used copula. Used to generate 36 | synthetic data. 37 | 38 | The rest of the files contain boilerplate code and helper functions. 39 | 40 | Instructions 41 | ============ 42 | We asume that dependencies are installed correctly. 43 | 44 | Experiment 1: Synthetic Data 45 | ---------------------------- 46 | To generate synthetic data drawn from the Clayton distribution, first run: 47 | 48 | >> python -m gen_scripts.clayton -F gen_clayton 49 | 50 | This will generate synthetic data in the pickle file `claytonX.p`. 51 | If this is your first run, X=1. 52 | 53 | Next, navigate to train_scripts/clayton.py. In the cfg() function, 54 | modify the data_name accordingly. If this is your first run, the 55 | default should be alright. Then, run: 56 | 57 | >> python -m train_scripts.clayton -F learn_clayton 58 | 59 | The experiment should run, albeit with slightly different values as us due to slight differents 60 | in randomness. The network should converge after around 10k (training) iterations. 61 | 62 | After every fixed interval, samples will be drawn from the learned network and plotted 63 | in the `/sample_figs` folder. 64 | 65 | Experiment 2: Real World Data 66 | ----------------------------- 67 | The boston housing dataset will automatically be downloaded using scikit learn. Simply run 68 | 69 | >> python -m train_scripts.boston.train -F boston_housing 70 | 71 | to run the experiment. Sampled points will be found in `/sample_figs` in the appropriate 72 | folder. 73 | 74 | Note that since the dataset is fairly small, results may vary significantly between runs. You may change 75 | the train/test split in `train_scripts/boston/train.py` directly to get varying results. In our experiments 76 | we used 5 different seeds and took averages. 77 | 78 | Typically, convergence occurs at around 10k epochs. Note that because the dataset is so small, in 79 | some settings, test loss will be *better* than training loss. This is *not* a bug. 80 | 81 | Experiment 3: Noisy Data: Synthetic Data 82 | ---------------------------------------- 83 | Generate the data from the Clayton copula as shown in the first section. Simply run 84 | 85 | >> python -m train_scripts.clayton_noisy -F learn_clayton_noisy 86 | 87 | As before, samples from the learned distribution will be periodically saved in `/sample_figs`. 88 | 89 | Note that the training loss being reported are the log *probabilities* and not the log *likelihoods*. 90 | However, the test loss is based on log *likelihoods*, in order to facilitate comparison with the non-noisy 91 | case (Experiment 1). 92 | 93 | In order to change the magnitude of noise, modify the variable `width_noise` in `train_scripts/clayton_noisy.py`. 94 | This coressponds to \lambda in the paper. 95 | 96 | Other Experiments 97 | ================= 98 | 99 | Synthetic Data 100 | -------------- 101 | 102 | The same steps may be followed (replace `clayton` by `frank` and `joe` where appropriate). 103 | 104 | INTC_MSFT dataset 105 | ----------------- 106 | 107 | The INTC-MSFT dataset may be obtained [here](https://rdrr.io/cran/copula/man/rdj.html). The data 108 | was analyzed in Chapter 5 of McNeil, Frey and Embrechts (2005). 109 | 110 | For convenience, processed data is included in the folder `/data/rdj`. Training using ACNet 111 | may be done by running: 112 | 113 | >> python -m train_scripts.rdj.train -F learn_rdj 114 | 115 | The script `train_scrips/rdj/train.py` contains the same network structure used as before. The 116 | seeds used are also included there to aid with reproduction. You can adjust the proportion of 117 | randomly generated `noise` datapoints in the configuration. 118 | 119 | To train other parametric copula using the dataset, run 120 | 121 | >> python -m train_scripts.rdj.train_with_frank -F learn_rdj_with_frank 122 | 123 | Replace `frank` with `gumbel` and `clayton` as appropriate. 124 | 125 | GOOG-FB dataset 126 | --------------- 127 | 128 | The data was obtained from Yahoo Finance and is included in the `/data` folder. 129 | 130 | Training is done in the same manner as before: 131 | 132 | >> python -m train_scripts.stocks.train -F learn_stocks 133 | 134 | and training using other copula is done by 135 | 136 | >> python -m train_scripts.stocks.train_with_frank -F learn_stocks_with_frank 137 | 138 | where `frank` may be replaced by `clayton` or `gumbel`. 139 | 140 | GAS dataset 141 | ----------- 142 | 143 | The gas dataset may be downloaded [here](https://archive.ics.uci.edu/ml/datasets/gas+sensor+array+drift+dataset). 144 | 145 | You should download the data and organize it such that the files `batchX.dat` lies in `/data/gas`. 146 | 147 | Training is done in the same way; note that here we learn when d=3:: 148 | 149 | >> python -m train_scripts.gas.train -F learn_gas 150 | 151 | and 152 | 153 | >> python -m train_scripts.gas.train_with_frank -F learn_gas_with_frank 154 | 155 | -------------------------------------------------------------------------------- /data/gas/insert-data-here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lingchunkai/ACNet/bc4e8c87728574ce549676aa91ff2c734751345b/data/gas/insert-data-here -------------------------------------------------------------------------------- /data/rdj/GE.data: -------------------------------------------------------------------------------- 1 | 0.001692413 2 | -0.011949427 3 | 0.001718221 4 | 0.006835194 5 | -0.009627754 6 | -0.032119624 7 | 0.001776161 8 | -0.003555482 9 | 0 10 | 0.008856504 11 | 0.007035998 12 | 0.015635459 13 | 0.023855212 14 | -0.0033756 15 | -0.013609154 16 | 0.010222121 17 | 0.013471677 18 | 0.009983958 19 | -0.008311669 20 | 0.008311669 21 | 0.016418914 22 | 0.003254849 23 | 0.001615548 24 | 0.012885956 25 | 0.012722018 26 | -0.009524361 27 | 0.014248901 28 | -0.001574931 29 | 0.003147385 30 | 0.001569985 31 | -0.012621695 32 | -0.009569973 33 | -0.006428619 34 | -0.006470214 35 | 0.011293589 36 | 0.022226926 37 | 0.001569985 38 | -0.015796338 39 | -0.009600549 40 | -0.021124222 41 | -0.008243154 42 | 0.018047662 43 | 0.012927526 44 | 0.004807984 45 | -0.004807984 46 | 0 47 | -0.040950982 48 | 0.023137206 49 | -0.019803585 50 | -0.005008698 51 | 0.006675413 52 | 0.0082841 53 | 0.027666546 54 | -0.009677965 55 | 0.001620829 56 | 0.001618206 57 | 0.012836664 58 | 0.007949692 59 | 0 60 | -0.001584891 61 | -0.007954552 62 | -0.004807984 63 | 0.019079393 64 | 0.001574918 65 | 0.009390171 66 | 0.001557757 67 | -0.009375486 68 | -0.015829019 69 | -0.022577451 70 | -0.013139782 71 | 0.016399911 72 | 0.017727571 73 | -0.006415795 74 | 0.006415795 75 | 0.012694007 76 | -0.001579895 77 | 0.004732216 78 | -0.001574918 79 | -0.004731967 80 | -0.012742195 81 | -0.001605231 82 | -0.003210351 83 | -0.004846784 84 | 0.008057136 85 | -0.004823337 86 | -0.006478088 87 | -0.0097885 88 | -0.011535602 89 | 0.008256713 90 | 0.013067389 91 | 0.009693613 92 | 0.009600549 93 | 0.004769799 94 | -0.001587406 95 | 0.01418147 96 | 0.017067497 97 | 0.037743388 98 | -0.01341999 99 | 0.019326307 100 | -0.002945192 101 | 0.004418132 102 | -0.026828544 103 | -0.00605845 104 | 0.012087747 105 | -0.006029297 106 | 0.001510736 107 | 0.011997143 108 | 0.005941409 109 | 0.021976391 110 | -0.004353968 111 | 0.001455796 112 | -0.016117093 113 | 0.007357518 114 | 0.005845862 115 | 0 116 | 0.005811886 117 | -0.007271933 118 | 0 119 | 0 120 | 0.015930423 121 | 0.004304785 122 | 0.004279378 123 | -0.010016195 124 | 0.007167637 125 | -0.008608756 126 | 0.005752061 127 | 0.002856696 128 | -0.002856696 129 | -0.018802205 130 | -0.020652994 131 | 0.008902534 132 | 0.008823977 133 | -0.017726511 134 | 0.001490481 135 | -0.031746355 136 | -0.006161087 137 | 0.022916457 138 | 0.013500504 139 | -0.021081121 140 | -0.01533584 141 | -0.018722032 142 | 0.006277157 143 | 0.009346253 144 | 0.009259709 145 | -0.017040903 146 | 0.001562614 147 | 0.026175255 148 | 0.013581989 149 | 0.032450227 150 | -0.002902368 151 | 0.00725077 152 | -0.008708678 153 | -0.010258114 154 | -0.013340335 155 | 0.014813264 156 | 0 157 | 0.004398666 158 | -0.007346697 159 | 0.001475102 160 | 0.001472929 161 | 0 162 | -0.011837364 163 | 0.007412052 164 | 0.013210497 165 | -0.001460036 166 | 0.010168714 167 | -0.001447367 168 | -0.018994526 169 | -0.019354974 170 | 0.011953904 171 | 0.001486051 172 | -0.011943377 173 | 0.017858396 174 | 0.01027318 175 | 0.004366633 176 | 0 177 | 0.011563677 178 | 0.01568051 179 | 0 180 | -0.01568051 181 | 0.002871895 182 | 0.007136914 183 | 0.014122814 184 | -0.007036476 185 | 0.015416765 186 | 0.017921012 187 | -0.012368414 188 | 0.012375083 189 | -0.00548412 190 | -0.004126184 191 | 0.013697104 192 | 0 193 | 0.017534202 194 | 0.0066614 195 | -0.008000329 196 | -0.005367144 197 | -0.009467354 198 | 0.025489599 199 | 0.006599833 200 | 0.002623941 201 | -0.002623941 202 | 0.001309623 203 | 0 204 | 0.006554382 205 | -0.001307444 206 | 0.009110089 207 | -0.003896887 208 | -0.002606387 209 | -0.009163376 210 | 0.010467419 211 | 0 212 | 0.007782436 213 | -0.001293927 214 | 0.007732297 215 | 0.014021408 216 | 0.028698271 217 | 0.011012683 218 | -0.006100522 219 | -0.00122578 220 | -0.001227284 221 | 0.017034629 222 | 0.013181218 223 | -0.00716874 224 | -0.008429989 225 | -0.003631335 226 | -0.002432196 227 | -0.011012683 228 | -0.006168286 229 | 0.026862696 230 | -0.01700803 231 | 0.004888176 232 | 0.014528644 233 | -0.007237937 234 | -0.039508722 235 | -0.00347398 236 | -0.013035954 237 | 0.002559718 238 | 0.006364579 239 | -0.003817616 240 | -0.011528497 241 | -0.002582841 242 | 0 243 | -0.016940232 244 | 0.025944889 245 | 0.016509933 246 | 0.034656141 247 | -0.004876243 248 | -0.012304725 249 | 0.009854666 250 | 0.012178883 251 | -0.007284753 252 | -0.008574505 253 | -0.027432003 254 | -0.012721143 255 | 0.015248529 256 | -0.001262895 257 | 0.019999263 258 | -0.013706804 259 | 0.014944708 260 | 0.002471221 261 | 0.018324293 262 | 0.013228739 263 | -0.016864461 264 | -0.013461259 265 | 0.020719532 266 | 0.01318703 267 | 0.004748812 268 | 0.017614124 269 | -0.030732407 270 | -0.021846198 271 | -0.014833862 272 | 0.002489663 273 | 0.029372837 274 | 0.019120975 275 | -0.020323103 276 | 0.016763485 277 | -0.013149675 278 | -0.008460527 279 | 0.00966419 280 | 0.009571686 281 | -0.010775349 282 | 0.007194507 283 | 0.011875082 284 | 0.012902888 285 | -0.008187471 286 | 0.010517732 287 | 0 288 | -0.016409715 289 | -0.005932709 290 | 0.014168364 291 | -0.005878149 292 | 0.002357705 293 | -0.01064792 294 | -0.021630268 295 | 0.004846665 296 | -0.019535236 297 | 0.021952733 298 | 0.009606132 299 | 0.007143074 300 | 0.004731938 301 | -0.007109233 302 | -0.003570747 303 | -0.010801164 304 | -0.006048811 305 | 0.013260585 306 | -0.013260585 307 | 0.004846689 308 | -0.008491237 309 | 0.001218306 310 | 0.029956197 311 | -0.008300004 312 | -0.010769474 313 | -0.025593219 314 | -0.019955372 315 | -0.006316291 316 | -0.00891287 317 | 0.006378863 318 | 0.025091976 319 | -0.003721465 320 | 0.007435158 321 | 0.006153015 322 | -0.003690872 323 | -0.036367644 324 | 0.021478611 325 | 0.02834428 326 | 0.024013368 327 | -0.01433754 328 | 0.008389498 329 | -0.00358939 330 | 0.038763877 331 | -0.012753159 332 | -0.004682298 333 | -0.001174008 334 | 0.007017773 335 | 0.030987172 336 | 0.003381433 337 | -0.011323609 338 | 0.030286427 339 | 0.029401298 340 | 0.002140191 341 | -0.00859893 342 | 0.003231956 343 | 0.003226782 344 | 0.027512229 345 | 0.014510092 346 | 0.002052209 347 | 0.018312319 348 | -0.024491712 349 | 0.008232401 350 | 0.024293333 351 | -0.012072898 352 | -0.012220435 353 | 0.006128885 354 | -0.008181893 355 | 0.002053008 356 | 0.002048802 357 | -0.012346146 358 | 0.004130664 359 | 0 360 | 0.004113672 361 | -0.010318616 362 | 0.032658465 363 | 0.009988719 364 | 0.005946653 365 | 0.005911499 366 | 0.023302614 367 | 0.01523723 368 | 0.005655205 369 | 0 370 | -0.003765046 371 | 0.013118378 372 | 0.00742289 373 | -0.031926871 374 | 0.012324744 375 | -0.014236567 376 | -0.003829962 377 | 0.002872678 378 | -0.004795521 379 | 0.022815278 380 | 0.01585532 381 | 0.026473314 382 | 0.007182972 383 | -0.005381289 384 | -0.015409866 385 | 0.015409866 386 | 0.016053443 387 | 0.005295817 388 | 0.016588129 389 | 0.024798075 390 | -0.011041406 391 | -0.028584309 392 | -0.007056305 393 | 0.017546151 394 | 0.004337795 395 | -0.004337795 396 | 0.006069346 397 | -0.01042604 398 | -0.008770746 399 | -0.002648883 400 | -0.008871927 401 | -0.014361747 402 | -0.016409096 403 | -0.003686325 404 | 0.022802081 405 | -0.01544744 406 | -0.017553568 407 | 0.008355168 408 | -0.021485803 409 | 0.008460563 410 | 0 411 | -0.051886053 412 | 0.038688915 413 | 0.007563191 414 | 0.022347807 415 | -0.028961886 416 | -0.01624459 417 | -0.011628313 418 | -0.013736289 419 | -0.007938337 420 | -0.003990485 421 | 0.001000334 422 | 0.056325143 423 | 0.012197452 424 | -0.006547603 425 | -0.004709827 426 | 0.010327934 427 | -0.002807394 428 | -0.012248027 429 | -0.014318051 430 | 0.019045538 431 | 0.001886593 432 | 0.03242525 433 | 0.016275595 434 | 0.006259269 435 | 0.000891613 436 | 0.001776504 437 | -0.017029864 438 | -0.007261192 439 | -0.025830756 440 | 0.021271731 441 | 0.010919312 442 | -0.014588729 443 | 0.007320968 444 | 0.01357907 445 | -0.006311309 446 | 0.011692598 447 | 0.029961614 448 | -0.02105377 449 | 0.00176863 450 | -0.010676474 451 | -0.007178574 452 | 0 453 | -0.009052954 454 | 0.008150898 455 | 0.000902057 456 | 0.01075298 457 | 0.007102069 458 | -0.016953805 459 | -0.030155818 460 | -0.021566362 461 | -0.062581115 462 | 0.071078196 463 | -0.024741595 464 | -0.025369305 465 | 0.021509798 466 | 0.03795435 467 | 0.019362487 468 | 0.009995004 469 | -0.005439836 470 | -0.014651081 471 | -0.011132059 472 | 0.00558152 473 | -0.024411749 474 | 0.03914378 475 | 0.013620487 476 | 0.026692459 477 | -0.023093492 478 | 0.016043439 479 | 0.017530684 480 | 0.011232116 481 | -0.015584995 482 | 0.015584995 483 | 0.009405263 484 | 0.005936432 485 | 0.001692027 486 | -0.005080549 487 | -0.008523973 488 | -0.006874407 489 | 0.012850471 490 | 0.001697945 491 | 0.014339791 492 | -0.011792061 493 | -0.017096145 494 | -0.003452895 495 | 0.032341101 496 | 0.010000287 497 | -0.004987643 498 | -0.013421921 499 | -0.013604522 500 | 0.011912496 501 | -0.038813375 502 | 0 503 | -0.003522865 504 | 0.022690389 505 | 0.026396847 506 | -0.013535405 507 | 0.008480558 508 | 0.01758332 509 | -0.013369859 510 | 0.008378104 511 | -0.00921609 512 | -0.024717475 513 | 0.027237526 514 | 0.003351581 515 | -0.00419226 516 | -0.022948684 517 | 0.009413279 518 | 0.029376579 519 | 0 520 | -0.004975205 521 | -0.015075878 522 | 0.00840217 523 | 0.020705661 524 | 0 525 | 0.011411552 526 | 0.004850518 527 | 0.01361608 528 | -0.012809287 529 | -0.008906756 530 | -0.002440053 531 | 0.012954517 532 | -0.008076391 533 | 0.00566189 534 | 0.012022405 535 | 0.000797155 536 | -0.008796011 537 | 0.003206442 538 | 0.004792413 539 | -0.007998855 540 | 0.004807767 541 | -0.006416816 542 | -0.012139179 543 | 0.00487416 544 | -0.001623403 545 | 0.009693269 546 | -0.015392038 547 | 0.008938993 548 | -0.012208359 549 | 0.000823338 550 | 0.024253635 551 | 0.003188529 552 | 0.003178394 553 | 0.003172187 554 | 0.0023687 555 | -0.011112646 556 | 0.013479593 557 | 0.007841976 558 | -0.003913301 559 | 0.004694889 560 | 0.019328813 561 | -0.014655863 562 | 0.026831448 563 | -0.007591886 564 | 0.005318145 565 | 0.012805742 566 | 0.021477593 567 | 0.010201901 568 | 0.00289884 569 | 0.015068646 570 | -0.004282707 571 | -0.006456891 572 | 0.003592696 573 | -0.01154661 574 | 0.004344736 575 | 0.003609178 576 | -0.002163943 577 | -0.000722356 578 | -0.01088774 579 | 0.01879811 580 | -0.012977984 581 | -0.000723007 582 | -0.010220352 583 | -0.008843184 584 | -0.006685908 585 | -0.017286759 586 | -0.016047565 587 | 0.014532567 588 | 0.034330849 589 | -0.00220527 590 | -0.007378991 591 | -0.012672079 592 | -0.003004068 593 | -0.014400397 594 | 0.015901725 595 | 0 596 | 0.003003225 597 | 0.014129006 598 | -0.005923597 599 | -0.017991665 600 | -0.001515253 601 | 0.004535205 602 | 0.033375652 603 | 0.002912173 604 | -0.010970937 605 | -0.026071942 606 | 0.011259107 607 | 0.007433845 608 | -0.011921555 609 | -0.006016342 610 | -0.001506993 611 | -0.017527537 612 | 0.025800833 613 | 0.015606239 614 | -0.003693568 615 | 0.004431367 616 | 0.005145827 617 | -0.012541178 618 | 0.012541178 619 | -0.016263103 620 | 0.016263103 621 | 0.026765541 622 | 0.004270078 623 | -0.010716594 624 | -0.010109677 625 | 0.022957954 626 | 0.018966399 627 | -0.000693312 628 | 0.003475313 629 | 0.008981708 630 | 0 631 | 0.003432363 632 | -0.002747632 633 | 0.010261602 634 | 0.00813018 635 | 0.007394495 636 | 0.008670212 637 | -0.003325056 638 | -0.002670637 639 | 0.015905037 640 | -0.010574119 641 | 0.017779704 642 | 0.009745161 643 | -0.012358498 644 | -0.027198682 645 | -0.016955984 646 | -0.015161237 647 | 0.018574846 648 | 0.003405313 649 | -0.012992268 650 | -0.014560111 651 | 0.020048665 652 | -0.020747807 653 | 0.000699142 654 | -0.046453294 655 | 0.01524718 656 | 0.023495405 657 | -0.007770626 658 | -0.003551847 659 | -0.005711424 660 | 0.023349246 661 | -0.016925748 662 | -0.004991708 663 | 0.021217338 664 | 0.015966556 665 | 0.005495671 666 | -0.008944601 667 | -0.009025329 668 | 0.001394871 669 | 0.006246197 670 | 0.001384261 671 | -0.04379376 672 | -0.007973415 673 | -0.070865366 674 | 0.033042629 675 | -0.028366051 676 | -0.033204845 677 | -0.024411463 678 | 0.076106936 679 | -0.020828085 680 | -0.039748556 681 | 0.027990154 682 | -0.001579828 683 | -0.013518546 684 | 0.022952651 685 | -0.020551642 686 | -0.004803877 687 | 0.017501243 688 | 0.007854277 689 | 0.051098643 690 | -0.02789557 691 | 0.015929481 692 | 0.006002769 693 | -0.008260828 694 | -0.040790086 695 | -0.051581571 696 | -0.002486502 697 | -0.032017739 698 | 0.010221611 699 | 0.016808566 700 | -0.046917826 701 | 0.013017537 702 | 0.021323515 703 | 0.027465446 704 | 0.019512892 705 | 0.054828676 706 | 0.021866055 707 | -0.005234276 708 | -0.013585148 709 | 0.019561828 710 | 0.030816966 711 | 0.000722873 712 | -0.007973387 713 | -0.00437639 714 | -0.006598912 715 | 0.021831453 716 | 0.00789112 717 | -0.00143076 718 | 0.012793396 719 | 0.00281974 720 | 0.022975297 721 | -0.004829837 722 | -0.01814546 723 | 0.003514387 724 | -0.021277509 725 | -0.006470787 726 | 0.017877411 727 | 0.025193847 728 | -0.004849913 729 | 0.019256773 730 | 0.008817193 731 | 0.00135063 732 | 0.020688686 733 | -0.003970894 734 | -0.011339987 735 | -0.009431875 736 | -0.021212502 737 | 0.008264538 738 | -0.002061418 739 | -0.032124677 740 | 0.023152591 741 | 0.007598803 742 | -0.004137949 743 | -0.004849913 744 | -0.017514217 745 | 0.007041169 746 | -0.025587883 747 | 0.070195144 748 | -0.012154126 749 | 0.024822518 750 | 0.026810504 751 | 0.015365251 752 | 0.006963758 753 | 0.023073099 754 | -0.004322502 755 | 0.00740289 756 | 0.015252333 757 | -0.009122859 758 | -0.003058545 759 | -0.014192883 760 | 0.02090997 761 | 0.019290307 762 | -0.016859682 763 | -0.005477984 764 | -0.026602759 765 | -0.019627625 766 | -0.012223271 767 | 0.00516569 768 | 0.037286614 769 | 0.00680015 770 | -0.010529182 771 | -0.010641227 772 | -0.014579778 773 | 0.026466573 774 | 0.014815086 775 | -0.004915017 776 | 0.018909278 777 | 0.013801121 778 | -0.025960327 779 | -0.001835055 780 | 0.008542834 781 | -0.027096242 782 | -0.021450889 783 | -0.004474797 784 | -0.016145528 785 | 0.019981848 786 | 0.02083913 787 | -0.024034519 788 | 0.0139852 789 | 0.003151082 790 | 0.013130814 791 | -0.002489588 792 | 0.035479718 793 | -0.009663876 794 | -0.019606477 795 | -0.006209365 796 | -0.000623367 797 | 0.006213225 798 | -0.016231085 799 | -0.005678499 800 | 0.01756833 801 | 0.031225219 802 | 0.010791424 803 | 0.011266485 804 | -0.002363188 805 | 0.012336348 806 | 0.002914118 807 | 0.025289533 808 | -0.001136638 809 | -0.014306811 810 | 0.014306811 811 | 0.018019348 812 | -0.014616067 813 | -0.035152337 814 | 0.001759408 815 | 0.016259507 816 | -0.00635607 817 | 0.038105559 818 | 0.011649218 819 | -0.024003769 820 | 0.009000734 821 | 0.026520772 822 | -0.011517737 823 | 0.010425746 824 | -0.005472555 825 | -0.01493107 826 | 0.037720886 827 | 0.00374614 828 | -0.024892114 829 | -0.013793731 830 | -0.011171442 831 | -0.06020405 832 | 0.042634497 833 | 0.038140646 834 | -0.002755127 835 | -0.003316744 836 | 0.005523603 837 | -0.008851384 838 | -0.008927694 839 | -0.035372442 840 | -0.021127486 841 | 0.002371588 842 | -0.006532542 843 | 0.03798519 844 | -0.006904483 845 | 0.015467196 846 | -0.010860785 847 | 0.019914796 848 | -0.006783504 849 | -0.007970828 850 | -0.031361322 851 | 0.002945019 852 | -0.019003849 853 | 0.035339265 854 | -0.014573705 855 | -0.021963363 856 | -0.006022424 857 | -0.010929038 858 | 0.007299281 859 | -0.023918597 860 | 0.009881679 861 | 0.008568946 862 | -0.004887973 863 | 0.007929755 864 | 0.0192554 865 | -0.004180825 866 | -0.019947666 867 | 0.003047349 868 | -0.016568327 869 | 0.003395917 870 | 0.023468599 871 | 0.00420554 872 | 0.02488274 873 | 0.015099912 874 | -0.008102815 875 | -0.018766908 876 | 0.002955465 877 | -0.000590971 878 | 0.008822392 879 | -0.021901227 880 | 0.016616563 881 | 0.032429835 882 | 0.029751833 883 | -0.013363179 884 | 0.010594323 885 | 0.011035078 886 | 0.034503651 887 | -0.007979533 888 | 0.001066061 889 | -0.015053074 890 | 0.001622485 891 | 0.002162739 892 | 0.014996627 893 | 0.011103626 894 | 0.00941915 895 | -0.020522775 896 | -0.005329599 897 | -0.016707793 898 | 0.001627771 899 | -0.010363173 900 | 0.019544493 901 | -0.019544493 902 | -0.017697665 903 | -0.027150789 904 | -0.010374655 905 | 0 906 | -0.010483418 907 | 0.026003922 908 | -0.023665708 909 | -0.018273785 910 | -0.00178478 911 | 0.011256894 912 | -0.009472114 913 | 0.025258633 914 | 0.018961902 915 | 0.01468759 916 | 0.002242705 917 | -0.006176834 918 | 0.016199113 919 | 0.032700215 920 | 0.000536481 921 | 0.022780134 922 | -0.003147939 923 | -0.020168677 924 | -0.020038875 925 | -0.017104195 926 | 0.012167315 927 | -0.006618849 928 | 0.030253451 929 | 0.028840514 930 | 0.008828456 931 | -0.009350662 932 | -0.004707045 933 | -0.005259313 934 | -0.015404829 935 | -0.006444666 936 | 0.010719018 937 | 0.023184024 938 | 0.01652751 939 | -0.022797075 940 | -0.00210033 941 | -0.019088481 942 | 0.010650525 943 | 0.016807851 944 | -0.013632368 945 | -0.016503183 946 | 0.018087151 947 | -0.010066635 948 | 0.022638694 949 | 0.007259684 950 | 0.022484966 951 | -0.014761736 952 | 0.023314759 953 | -0.013112343 954 | -0.014314216 955 | -0.015045114 956 | 0.004692304 957 | -0.037098976 958 | 0.016068751 959 | 0.003712943 960 | 0.029726851 961 | 0.013272987 962 | 0.019087104 963 | -0.00298951 964 | 0.00298951 965 | 0.022139725 966 | 0.034906821 967 | 0.018623858 968 | -0.046256801 969 | -0.002902732 970 | 0.018244028 971 | 0.003797094 972 | 0.014117763 973 | 0.001868917 974 | -0.004676569 975 | 0.004676569 976 | 0 977 | 0.010208054 978 | 0.000923677 979 | 0.029988329 980 | 0.011128792 981 | -0.012920161 982 | -0.012182501 983 | 0.017995474 984 | -0.018449869 985 | -0.005464446 986 | -0.009172794 987 | -0.018606519 988 | -0.022792798 989 | 0.033069879 990 | 0.001391934 991 | 0.011534418 992 | 0.00719845 993 | 0.01995228 994 | 0.008999746 995 | 0.01612914 996 | 0.026632227 997 | 0.009284054 998 | 0.00711408 999 | -0.037819808 1000 | 0.022693614 1001 | 0.028382939 1002 | 0.009113544 1003 | 0.029235562 1004 | 0.000398212 1005 | -0.002378095 1006 | 0.012618956 1007 | -0.014206853 1008 | -0.004381226 1009 | -0.010030375 1010 | -0.0016155 1011 | -0.031174586 1012 | -0.040821486 1013 | -0.001739013 1014 | 0.013281617 1015 | 0.037989708 1016 | -0.000411452 1017 | 0.001650836 1018 | 0.003295525 1019 | 0.01144658 1020 | -0.018048527 1021 | -0.020067331 1022 | 0.004845463 1023 | -0.018878078 1024 | -0.012499071 1025 | -0.042521416 1026 | 0.002711222 1027 | 0.020986994 1028 | 0.002208788 1029 | -0.056225914 1030 | 0 1031 | 0.014815655 1032 | -0.014348972 1033 | 0.037963938 1034 | 0.016470605 1035 | -0.036417379 1036 | 0.004113479 1037 | -0.022130643 1038 | 0.010203292 1039 | -0.012536606 1040 | 0.004661196 1041 | 0.022989983 1042 | -0.013500571 1043 | -0.034926176 1044 | -0.045885128 1045 | 0.035332034 1046 | 0.006728266 1047 | 0.003824828 1048 | -0.037923364 1049 | 0.025923937 1050 | 0.02244001 1051 | -0.007582113 1052 | 0.029069486 1053 | 0.030042846 1054 | -0.013996849 1055 | -0.05611487 1056 | 0.002881803 1057 | -0.010124035 1058 | 0.020619384 1059 | -0.017233304 1060 | -0.019009572 1061 | 0.050380517 1062 | 0.039904947 1063 | 0.006275869 1064 | 0.007566887 1065 | 0.065646628 1066 | 0.002488558 1067 | 0.058721926 1068 | -0.005875563 1069 | -0.00709771 1070 | -0.012343508 1071 | 0.043892898 1072 | -0.02641982 1073 | -0.01987978 1074 | 0.033955206 1075 | -0.044452366 1076 | -0.004473881 1077 | 0.022969661 1078 | 0.012275097 1079 | 0.003928954 1080 | 0.013626315 1081 | -0.030626188 1082 | -0.0406898 1083 | -0.032069077 1084 | 0.041986367 1085 | 0.029175066 1086 | -0.006409532 1087 | 0.019108577 1088 | 0.022228087 1089 | 0.024006279 1090 | -0.016705375 1091 | -0.010778716 1092 | -0.026668502 1093 | 0.013424107 1094 | 0.010531465 1095 | -0.031534162 1096 | -0.013304742 1097 | 0.025641402 1098 | -0.004360375 1099 | -0.005976241 1100 | -0.029200701 1101 | 0.006154778 1102 | 0.025441097 1103 | 0.032943651 1104 | 0.004618876 1105 | -0.01158744 1106 | -0.009367607 1107 | -0.02381028 1108 | -0.036815419 1109 | 0 1110 | 0.014888402 1111 | 0 1112 | -0.023675756 1113 | 0.033481604 1114 | 0.026475259 1115 | -0.00476184 1116 | 0.007134284 1117 | -0.022769739 1118 | -0.007299194 1119 | 0.003657247 1120 | -0.009779873 1121 | -0.019853487 1122 | 0 1123 | 0.025976114 1124 | -0.003668677 1125 | 0.017011077 1126 | -0.014563134 1127 | -0.017264814 1128 | 0.014816872 1129 | -0.031116912 1130 | -0.014004874 1131 | 0.022814348 1132 | 0.001253349 1133 | -0.013863675 1134 | 0.025991637 1135 | -0.015890686 1136 | 0.063282491 1137 | -0.0190479 1138 | -0.040469831 1139 | 0.004993682 1140 | 0.022168018 1141 | 0.021689175 1142 | -0.003583001 1143 | 0.028303287 1144 | -0.023530092 1145 | -0.019231023 1146 | 0.041598615 1147 | -0.027140787 1148 | 0.009523716 1149 | 0.029191068 1150 | -0.003459095 1151 | -0.002312101 1152 | -0.008133762 1153 | -0.027204994 1154 | 0.007168364 1155 | -0.030214242 1156 | 0.014616336 1157 | 0.020348427 1158 | -0.014319804 1159 | 0.016686628 1160 | 0.009411672 1161 | -0.014150935 1162 | 0.021152011 1163 | 0.039891572 1164 | 0.018816842 1165 | -0.007705807 1166 | 0.007705807 1167 | 0.004376297 1168 | -0.007672029 1169 | -0.002202607 1170 | -0.008859264 1171 | 0.006651795 1172 | -0.005539181 1173 | 0.029559153 1174 | 0.017112404 1175 | 0.005287266 1176 | 0.012578568 1177 | -0.002085471 1178 | -0.040473406 1179 | 0.01937594 1180 | -0.002134436 1181 | -0.011820653 1182 | 0.020331195 1183 | 0 1184 | 0.01472136 1185 | -0.003135532 1186 | -0.010526224 1187 | 0 1188 | -0.001059604 1189 | -0.038881104 1190 | 0.013129058 1191 | -0.008733525 1192 | -0.006600566 1193 | -0.006644423 1194 | 0.018711486 1195 | 0.013002801 1196 | -0.001077863 1197 | 0.024482761 1198 | -0.007388625 1199 | -0.018008163 1200 | 0.011822874 1201 | 0.009568642 1202 | -0.003179959 1203 | 0.014754052 1204 | -0.005244351 1205 | -0.015898384 1206 | -0.007507457 1207 | -0.025068289 1208 | -0.038249862 1209 | 0.04484939 1210 | 0.01307333 1211 | -0.037490585 1212 | -0.002249933 1213 | 0.002249933 1214 | -0.06498859 1215 | -0.046634691 1216 | 0.070332815 1217 | -0.008231285 1218 | -0.01546684 1219 | 0.002395439 1220 | 0.032943481 1221 | 0.014933793 1222 | -0.006863815 1223 | -0.013874383 1224 | -0.007008149 1225 | 0.022030088 1226 | 0.007994383 1227 | -0.006848142 1228 | 0 1229 | -0.012681496 1230 | -0.046297704 1231 | 0.025194803 1232 | -0.004749128 1233 | 0.002376419 1234 | -0.014353751 1235 | -0.03556417 1236 | 0.013638684 1237 | -0.044060134 1238 | 0.016592122 1239 | -0.005074579 1240 | 0.013897161 1241 | -0.002512805 1242 | -0.002519135 1243 | 0.028591856 1244 | 0.012179609 1245 | 0.047290834 1246 | -0.00347055 1247 | -0.008143203 1248 | 0.031054047 1249 | 0.002262661 1250 | -0.046251691 1251 | 0.003544348 1252 | -0.029923398 1253 | -0.032102574 1254 | 0.023559916 1255 | -0.01730539 1256 | -0.055107943 1257 | 0.009181272 1258 | 0.020672446 1259 | 0.008912428 1260 | -0.019829954 1261 | 0.005175111 1262 | -0.010375033 1263 | -------------------------------------------------------------------------------- /data/stocks/fb/close.vals: -------------------------------------------------------------------------------- 1 | 80.150002 2 | 79.190002 3 | 80.290001 4 | 80.440002 5 | 82.440002 6 | 82.050003 7 | 82.139999 8 | 80.669998 9 | 80.669998 10 | 82.160004 11 | 81.830002 12 | 81.529999 13 | 80.709999 14 | 81.059998 15 | 81.790001 16 | 82.910004 17 | 82.510002 18 | 84.739998 19 | 87.879997 20 | 88.860001 21 | 87.980003 22 | 88.010002 23 | 85.800003 24 | 85.769997 25 | 86.910004 26 | 87.290001 27 | 87.550003 28 | 87.220001 29 | 85.650002 30 | 85.879997 31 | 87.949997 32 | 90.099998 33 | 89.68 34 | 89.760002 35 | 90.849998 36 | 94.970001 37 | 97.910004 38 | 98.389999 39 | 97.040001 40 | 95.440002 41 | 96.949997 42 | 94.169998 43 | 95.290001 44 | 96.989998 45 | 95.209999 46 | 94.010002 47 | 94.139999 48 | 94.059998 49 | 96.440002 50 | 95.120003 51 | 94.300003 52 | 94.150002 53 | 93.620003 54 | 94.190002 55 | 93.43 56 | 94.419998 57 | 93.93 58 | 95.169998 59 | 95.309998 60 | 90.559998 61 | 86.059998 62 | 82.089996 63 | 83 64 | 87.190002 65 | 89.730003 66 | 91.010002 67 | 89.43 68 | 87.230003 69 | 89.889999 70 | 88.150002 71 | 88.260002 72 | 89.529999 73 | 90.440002 74 | 91.980003 75 | 92.050003 76 | 92.309998 77 | 92.900002 78 | 93.449997 79 | 94.339996 80 | 94.400002 81 | 95.550003 82 | 92.959999 83 | 93.970001 84 | 94.410004 85 | 92.769997 86 | 89.209999 87 | 86.669998 88 | 89.900002 89 | 90.949997 90 | 92.07 91 | 94.010002 92 | 92.800003 93 | 92.400002 94 | 92.470001 95 | 93.239998 96 | 94.260002 97 | 94.120003 98 | 94.07 99 | 95.959999 100 | 97.540001 101 | 98.470001 102 | 97 103 | 97.110001 104 | 99.669998 105 | 102.190002 106 | 103.769997 107 | 103.699997 108 | 104.199997 109 | 104.879997 110 | 101.970001 111 | 103.309998 112 | 102.580002 113 | 103.940002 114 | 108.760002 115 | 107.099998 116 | 106.489998 117 | 107.910004 118 | 109.010002 119 | 108.019997 120 | 103.949997 121 | 104.040001 122 | 105.129997 123 | 107.769997 124 | 106.260002 125 | 107.32 126 | 106.949997 127 | 105.739998 128 | 105.410004 129 | 105.449997 130 | 104.239998 131 | 107.120003 132 | 106.07 133 | 104.379997 134 | 106.18 135 | 105.610001 136 | 106.489998 137 | 104.599998 138 | 105.419998 139 | 102.120003 140 | 104.660004 141 | 104.550003 142 | 106.790001 143 | 106.220001 144 | 104.040001 145 | 104.769997 146 | 105.510002 147 | 104.629997 148 | 105.019997 149 | 105.93 150 | 107.260002 151 | 106.220001 152 | 104.660004 153 | 102.220001 154 | 102.730003 155 | 102.970001 156 | 97.919998 157 | 97.330002 158 | 97.510002 159 | 99.370003 160 | 95.440002 161 | 98.370003 162 | 94.970001 163 | 95.260002 164 | 94.349998 165 | 94.160004 166 | 97.940002 167 | 97.010002 168 | 97.339996 169 | 94.449997 170 | 109.110001 171 | 112.209999 172 | 115.089996 173 | 114.610001 174 | 112.690002 175 | 110.489998 176 | 104.07 177 | 99.75 178 | 99.540001 179 | 101 180 | 101.910004 181 | 102.010002 182 | 101.610001 183 | 105.199997 184 | 103.470001 185 | 104.57 186 | 107.160004 187 | 105.459999 188 | 106.879997 189 | 108.07 190 | 107.919998 191 | 106.919998 192 | 109.82 193 | 109.949997 194 | 109.580002 195 | 108.389999 196 | 105.730003 197 | 105.93 198 | 107.510002 199 | 107.32 200 | 109.410004 201 | 109.889999 202 | 110.669998 203 | 112.18 204 | 111.019997 205 | 111.449997 206 | 111.849998 207 | 112.25 208 | 112.540001 209 | 113.050003 210 | 113.690002 211 | 116.139999 212 | 114.699997 213 | 114.099998 214 | 116.059998 215 | 112.550003 216 | 112.220001 217 | 113.709999 218 | 113.639999 219 | 110.629997 220 | 108.989998 221 | 110.610001 222 | 110.510002 223 | 110.839996 224 | 109.639999 225 | 110.449997 226 | 112.290001 227 | 112.419998 228 | 113.440002 229 | 110.559998 230 | 110.099998 231 | 108.760002 232 | 108.889999 233 | 116.730003 234 | 117.580002 235 | 118.57 236 | 117.43 237 | 118.059998 238 | 117.809998 239 | 119.489998 240 | 119.239998 241 | 120.5 242 | 119.519997 243 | 120.279999 244 | 119.809998 245 | 118.669998 246 | 117.349998 247 | 117.650002 248 | 116.809998 249 | 117.349998 250 | 115.970001 251 | 117.699997 252 | 117.889999 253 | 119.470001 254 | 119.379997 255 | 118.809998 256 | 118.779999 257 | 118.93 258 | 118.470001 259 | 118.790001 260 | 117.760002 261 | 118.389999 262 | 118.559998 263 | 116.620003 264 | 113.949997 265 | 114.940002 266 | 114.599998 267 | 114.389999 268 | 113.019997 269 | 113.370003 270 | 114.379997 271 | 113.910004 272 | 115.080002 273 | 112.080002 274 | 108.970001 275 | 112.699997 276 | 114.160004 277 | 114.279999 278 | 114.190002 279 | 114.199997 280 | 116.699997 281 | 115.849998 282 | 117.239998 283 | 117.870003 284 | 117.93 285 | 116.779999 286 | 117.290001 287 | 116.860001 288 | 119.370003 289 | 120.610001 290 | 121.919998 291 | 120.610001 292 | 121 293 | 121.629997 294 | 121.220001 295 | 123.339996 296 | 125 297 | 123.940002 298 | 124.309998 299 | 123.089996 300 | 122.510002 301 | 124.360001 302 | 125.150002 303 | 125.260002 304 | 125.059998 305 | 124.879997 306 | 124.900002 307 | 124.879997 308 | 123.900002 309 | 123.300003 310 | 124.370003 311 | 123.910004 312 | 123.559998 313 | 124.150002 314 | 124.370003 315 | 123.480003 316 | 123.889999 317 | 124.959999 318 | 126.540001 319 | 125.839996 320 | 126.120003 321 | 126.169998 322 | 126.510002 323 | 129.729996 324 | 131.050003 325 | 130.270004 326 | 127.099998 327 | 128.690002 328 | 127.209999 329 | 127.769997 330 | 128.350006 331 | 129.070007 332 | 128.649994 333 | 128.639999 334 | 129.940002 335 | 130.080002 336 | 127.959999 337 | 127.309998 338 | 128.690002 339 | 129.229996 340 | 128.089996 341 | 128.270004 342 | 128.770004 343 | 128.190002 344 | 128.470001 345 | 128.740005 346 | 128.990005 347 | 130.240005 348 | 128.880005 349 | 129.050003 350 | 127.82 351 | 127.879997 352 | 127.540001 353 | 128.570007 354 | 130.110001 355 | 130 356 | 132.070007 357 | 133.279999 358 | 132.289993 359 | 131.039993 360 | 129.690002 361 | 131.289993 362 | 130.990005 363 | 129.5 364 | 127.169998 365 | 120 366 | 120.75 367 | 122.150002 368 | 124.220001 369 | 123.18 370 | 120.800003 371 | 119.019997 372 | 115.080002 373 | 117.199997 374 | 116.339996 375 | 117.790001 376 | 117.019997 377 | 121.769997 378 | 121.470001 379 | 120.839996 380 | 120.379997 381 | 120.410004 382 | 120.870003 383 | 118.419998 384 | 115.099998 385 | 115.400002 386 | 117.43 387 | 117.309998 388 | 117.949997 389 | 118.910004 390 | 119.68 391 | 117.769997 392 | 120.309998 393 | 120.209999 394 | 120.57 395 | 119.870003 396 | 119.239998 397 | 119.089996 398 | 119.040001 399 | 117.400002 400 | 117.269997 401 | 118.010002 402 | 116.919998 403 | 116.349998 404 | 115.050003 405 | 116.860001 406 | 118.690002 407 | 120.669998 408 | 123.410004 409 | 124.900002 410 | 124.349998 411 | 126.089996 412 | 126.620003 413 | 128.339996 414 | 127.870003 415 | 127.919998 416 | 127.550003 417 | 127.040001 418 | 128.929993 419 | 129.369995 420 | 131.479996 421 | 132.779999 422 | 132.179993 423 | 130.979996 424 | 130.320007 425 | 133.229996 426 | 130.839996 427 | 130.979996 428 | 132.059998 429 | 131.839996 430 | 134.199997 431 | 134.139999 432 | 134.190002 433 | 134.050003 434 | 133.850006 435 | 133.440002 436 | 133.839996 437 | 133.529999 438 | 133.720001 439 | 136.119995 440 | 135.360001 441 | 135.440002 442 | 136.410004 443 | 135.539993 444 | 137.419998 445 | 136.759995 446 | 137.169998 447 | 137.419998 448 | 137.300003 449 | 137.720001 450 | 138.240005 451 | 138.789993 452 | 139.600006 453 | 139.320007 454 | 139.720001 455 | 139.990005 456 | 139.839996 457 | 139.940002 458 | 138.509995 459 | 139.589996 460 | 139.529999 461 | 140.339996 462 | 140.320007 463 | 141.759995 464 | 142.649994 465 | 142.410004 466 | 142.050003 467 | 142.279999 468 | 141.729996 469 | 141.850006 470 | 141.169998 471 | 140.779999 472 | 141.039993 473 | 139.919998 474 | 139.580002 475 | 139.389999 476 | 141.419998 477 | 140.960007 478 | 142.270004 479 | 143.800003 480 | 143.679993 481 | 145.470001 482 | 146.490005 483 | 146.559998 484 | 147.699997 485 | 150.25 486 | 152.460007 487 | 152.779999 488 | 151.800003 489 | 150.850006 490 | 150.240005 491 | 151.059998 492 | 150.479996 493 | 150.289993 494 | 150.039993 495 | 150.330002 496 | 150.190002 497 | 149.779999 498 | 144.850006 499 | 147.660004 500 | 148.059998 501 | 148.240005 502 | 148.070007 503 | 150.039993 504 | 151.960007 505 | 152.130005 506 | 152.380005 507 | 151.460007 508 | 151.529999 509 | 153.610001 510 | 153.630005 511 | 152.809998 512 | 153.119995 513 | 154.710007 514 | 149.600006 515 | 148.440002 516 | 150.679993 517 | 150.25 518 | 149.800003 519 | 150.639999 520 | 152.869995 521 | 152.25 522 | 153.910004 523 | 153.399994 524 | 155.070007 525 | 153.589996 526 | 150.580002 527 | 153.240005 528 | 151.039993 529 | 150.979996 530 | 148.429993 531 | 150.339996 532 | 148.820007 533 | 151.440002 534 | 153.5 535 | 155.270004 536 | 158.899994 537 | 159.259995 538 | 159.970001 539 | 159.729996 540 | 162.860001 541 | 164.139999 542 | 164.529999 543 | 164.429993 544 | 166 545 | 165.279999 546 | 165.610001 547 | 170.440002 548 | 172.449997 549 | 169.25 550 | 169.860001 551 | 169.300003 552 | 168.589996 553 | 169.619995 554 | 171.979996 555 | 171.229996 556 | 171.179993 557 | 167.399994 558 | 168.080002 559 | 170.75 560 | 171 561 | 170 562 | 166.910004 563 | 167.410004 564 | 167.779999 565 | 169.639999 566 | 168.710007 567 | 167.740005 568 | 166.320007 569 | 167.240005 570 | 168.050003 571 | 169.919998 572 | 171.970001 573 | 172.020004 574 | 170.720001 575 | 172.089996 576 | 173.210007 577 | 170.949997 578 | 173.509995 579 | 172.960007 580 | 173.050003 581 | 170.960007 582 | 171.639999 583 | 170.009995 584 | 172.520004 585 | 172.169998 586 | 171.110001 587 | 170.539993 588 | 162.869995 589 | 164.210007 590 | 167.679993 591 | 168.729996 592 | 170.869995 593 | 169.470001 594 | 169.960007 595 | 168.419998 596 | 171.240005 597 | 172.229996 598 | 172.5 599 | 171.589996 600 | 172.740005 601 | 172.550003 602 | 173.740005 603 | 174.520004 604 | 176.110001 605 | 176.029999 606 | 174.559998 607 | 174.979996 608 | 171.270004 609 | 171.800003 610 | 170.600006 611 | 170.630005 612 | 177.880005 613 | 179.869995 614 | 180.059998 615 | 182.660004 616 | 178.919998 617 | 178.919998 618 | 180.169998 619 | 180.25 620 | 179.559998 621 | 179.300003 622 | 178.460007 623 | 178.770004 624 | 178.070007 625 | 177.949997 626 | 179.589996 627 | 179 628 | 178.740005 629 | 181.860001 630 | 180.869995 631 | 182.779999 632 | 183.029999 633 | 182.419998 634 | 175.130005 635 | 177.179993 636 | 175.100006 637 | 171.470001 638 | 172.830002 639 | 176.059998 640 | 180.139999 641 | 179 642 | 179.039993 643 | 176.960007 644 | 178.300003 645 | 178.389999 646 | 180.179993 647 | 180.820007 648 | 179.509995 649 | 177.889999 650 | 177.449997 651 | 177.199997 652 | 175.990005 653 | 177.619995 654 | 177.919998 655 | 176.460007 656 | 181.419998 657 | 184.669998 658 | 184.330002 659 | 186.850006 660 | 188.279999 661 | 187.869995 662 | 187.839996 663 | 187.770004 664 | 179.369995 665 | 178.389999 666 | 177.600006 667 | 179.800003 668 | 181.289993 669 | 185.369995 670 | 189.350006 671 | 186.550003 672 | 187.479996 673 | 190 674 | 185.979996 675 | 187.119995 676 | 186.889999 677 | 193.089996 678 | 190.279999 679 | 181.259995 680 | 185.309998 681 | 180.179993 682 | 171.580002 683 | 176.110001 684 | 176.410004 685 | 173.149994 686 | 179.520004 687 | 179.960007 688 | 177.360001 689 | 176.009995 690 | 177.910004 691 | 178.990005 692 | 183.289993 693 | 184.929993 694 | 181.460007 695 | 178.320007 696 | 175.940002 697 | 176.619995 698 | 180.399994 699 | 179.779999 700 | 183.710007 701 | 182.339996 702 | 185.229996 703 | 184.759995 704 | 181.880005 705 | 184.190002 706 | 183.860001 707 | 185.089996 708 | 172.559998 709 | 168.149994 710 | 169.389999 711 | 164.889999 712 | 159.389999 713 | 160.059998 714 | 152.220001 715 | 153.029999 716 | 159.789993 717 | 155.389999 718 | 156.110001 719 | 155.100006 720 | 159.339996 721 | 157.199997 722 | 157.929993 723 | 165.039993 724 | 166.320007 725 | 163.869995 726 | 164.520004 727 | 164.830002 728 | 168.660004 729 | 166.360001 730 | 168.100006 731 | 166.279999 732 | 165.839996 733 | 159.690002 734 | 159.690002 735 | 174.160004 736 | 173.589996 737 | 172 738 | 173.860001 739 | 176.070007 740 | 174.020004 741 | 176.610001 742 | 177.970001 743 | 178.919998 744 | 182.660004 745 | 185.529999 746 | 186.990005 747 | 186.639999 748 | 184.320007 749 | 183.199997 750 | 183.759995 751 | 182.679993 752 | 184.490005 753 | 183.800003 754 | 186.899994 755 | 185.929993 756 | 184.919998 757 | 185.740005 758 | 187.669998 759 | 191.779999 760 | 193.990005 761 | 193.279999 762 | 192.940002 763 | 191.339996 764 | 188.179993 765 | 189.100006 766 | 191.539993 767 | 192.399994 768 | 192.410004 769 | 196.809998 770 | 195.850006 771 | 198.309998 772 | 197.490005 773 | 202 774 | 201.5 775 | 201.740005 776 | 196.350006 777 | 199 778 | 195.839996 779 | 196.229996 780 | 194.320007 781 | 197.360001 782 | 192.729996 783 | 198.449997 784 | 203.229996 785 | 204.740005 786 | 203.539993 787 | 202.539993 788 | 206.919998 789 | 207.320007 790 | 207.229996 791 | 209.990005 792 | 209.360001 793 | 208.089996 794 | 209.940002 795 | 210.910004 796 | 214.669998 797 | 217.5 798 | 176.259995 799 | 174.889999 800 | 171.059998 801 | 172.580002 802 | 171.649994 803 | 176.369995 804 | 177.779999 805 | 185.690002 806 | 183.809998 807 | 185.179993 808 | 183.089996 809 | 180.259995 810 | 180.050003 811 | 181.110001 812 | 179.529999 813 | 174.699997 814 | 173.800003 815 | 172.5 816 | 172.619995 817 | 173.639999 818 | 172.899994 819 | 174.649994 820 | 177.460007 821 | 176.259995 822 | 175.899994 823 | 177.639999 824 | 175.729996 825 | 171.160004 826 | 167.179993 827 | 162.529999 828 | 163.039993 829 | 164.179993 830 | 165.940002 831 | 162 832 | 161.360001 833 | 162.320007 834 | 160.580002 835 | 160.300003 836 | 163.059998 837 | 166.020004 838 | 162.929993 839 | 165.410004 840 | 164.910004 841 | 166.949997 842 | 168.839996 843 | 164.460007 844 | 162.440002 845 | 159.330002 846 | 162.429993 847 | 158.850006 848 | 157.330002 849 | 157.25 850 | 157.899994 851 | 151.380005 852 | 153.350006 853 | 153.740005 854 | 153.520004 855 | 158.779999 856 | 159.419998 857 | 154.919998 858 | 154.050003 859 | 154.779999 860 | 154.389999 861 | 146.039993 862 | 150.949997 863 | 145.369995 864 | 142.089996 865 | 146.220001 866 | 151.789993 867 | 151.75 868 | 150.350006 869 | 148.679993 870 | 149.940002 871 | 151.529999 872 | 147.869995 873 | 144.960007 874 | 141.550003 875 | 142.160004 876 | 144.220001 877 | 143.850006 878 | 139.529999 879 | 131.550003 880 | 132.429993 881 | 134.820007 882 | 131.729996 883 | 136.380005 884 | 135 885 | 136.759995 886 | 138.679993 887 | 140.610001 888 | 141.089996 889 | 137.929993 890 | 139.630005 891 | 137.419998 892 | 141.850006 893 | 142.080002 894 | 144.5 895 | 145.009995 896 | 144.059998 897 | 140.190002 898 | 143.660004 899 | 133.240005 900 | 133.399994 901 | 124.949997 902 | 124.059998 903 | 134.179993 904 | 134.520004 905 | 133.199997 906 | 131.089996 907 | 135.679993 908 | 131.740005 909 | 137.949997 910 | 138.050003 911 | 142.529999 912 | 144.229996 913 | 144.199997 914 | 143.800003 915 | 145.389999 916 | 148.949997 917 | 147.539993 918 | 148.300003 919 | 150.039993 920 | 147.570007 921 | 144.300003 922 | 145.830002 923 | 149.009995 924 | 147.470001 925 | 144.190002 926 | 150.419998 927 | 166.690002 928 | 165.710007 929 | 169.25 930 | 171.160004 931 | 170.490005 932 | 166.380005 933 | 167.330002 934 | 165.789993 935 | 165.039993 936 | 164.070007 937 | 163.949997 938 | 162.5 939 | 162.289993 940 | 162.559998 941 | 160.039993 942 | 161.889999 943 | 164.619995 944 | 164.130005 945 | 162.809998 946 | 161.449997 947 | 162.279999 948 | 167.369995 949 | 171.259995 950 | 172.509995 951 | 169.130005 952 | 169.600006 953 | 172.070007 954 | 171.919998 955 | 173.369995 956 | 170.169998 957 | 165.979996 958 | 160.470001 959 | 161.570007 960 | 165.440002 961 | 166.080002 962 | 164.339996 963 | 166.289993 964 | 167.679993 965 | 165.869995 966 | 165.550003 967 | 166.690002 968 | 168.699997 969 | 174.199997 970 | 173.539993 971 | 176.020004 972 | 175.720001 973 | 174.929993 974 | 177.580002 975 | 177.820007 976 | 177.509995 977 | 179.100006 978 | 179.649994 979 | 178.869995 980 | 178.779999 981 | 178.279999 982 | 181.440002 983 | 183.779999 984 | 182.580002 985 | 193.259995 986 | 191.490005 987 | 194.779999 988 | 193.399994 989 | 193.029999 990 | 192.529999 991 | 195.470001 992 | 193.880005 993 | 189.770004 994 | 189.539993 995 | 188.649994 996 | 188.339996 997 | 181.539993 998 | 180.729996 999 | 186.270004 1000 | 186.990005 1001 | 185.300003 1002 | 182.720001 1003 | 184.820007 1004 | 185.320007 1005 | 180.869995 1006 | 181.059998 1007 | 184.309998 1008 | 182.190002 1009 | 183.009995 1010 | 177.470001 1011 | 164.149994 1012 | 167.5 1013 | 168.169998 1014 | 168.330002 1015 | 173.350006 1016 | 174.820007 1017 | 178.100006 1018 | 175.039993 1019 | 177.470001 1020 | 181.330002 1021 | 189.009995 1022 | 188.470001 1023 | 187.479996 1024 | 189.529999 1025 | 191.139999 1026 | 192.600006 1027 | 188.839996 1028 | 187.660004 1029 | 189.5 1030 | 193 1031 | 193 1032 | 195 1033 | 197.199997 1034 | 196.399994 1035 | 195.759995 1036 | 199.210007 1037 | 202.729996 1038 | 201.229996 1039 | 204.869995 1040 | 203.910004 1041 | 203.839996 1042 | 201.800003 1043 | 200.779999 1044 | 198.360001 1045 | 202.320007 1046 | 202.360001 1047 | 204.660004 1048 | 200.710007 1049 | 199.75 1050 | 195.940002 1051 | 197.039993 1052 | 194.229996 1053 | 192.729996 1054 | 189.020004 1055 | 181.729996 1056 | 184.509995 1057 | 185.149994 1058 | 190.160004 1059 | 187.850006 1060 | 185.369995 1061 | 188.449997 1062 | 179.710007 1063 | 182.589996 1064 | 183.699997 1065 | 186.169998 1066 | 183.809998 1067 | 183.550003 1068 | 182.039993 1069 | 177.75 1070 | 180.360001 1071 | 181.300003 1072 | 181.759995 1073 | 185.570007 1074 | 185.669998 1075 | 182.389999 1076 | 187.139999 1077 | 190.899994 1078 | 187.490005 1079 | 188.759995 1080 | 186.169998 1081 | 188.490005 1082 | 187.470001 1083 | 187.190002 1084 | 186.220001 1085 | 188.080002 1086 | 188.139999 1087 | 190.139999 1088 | 189.929993 1089 | 186.820007 1090 | 181.279999 1091 | 182.800003 1092 | 180.110001 1093 | 177.100006 1094 | 178.080002 1095 | 175.809998 1096 | 174.600006 1097 | 179.380005 1098 | 180.449997 1099 | 179.679993 1100 | 177.75 1101 | 179.850006 1102 | 180.029999 1103 | 184.190002 1104 | 183.279999 1105 | 188.889999 1106 | 189.550003 1107 | 190.389999 1108 | 185.850006 1109 | 189.759995 1110 | 182.339996 1111 | 186.149994 1112 | 186.380005 1113 | 187.889999 1114 | 189.399994 1115 | 189.309998 1116 | 188.25 1117 | 191.649994 1118 | 193.619995 1119 | 194.720001 1120 | 194.320007 1121 | 191.550003 1122 | 190.419998 1123 | 190.839996 1124 | 189.610001 1125 | 194.470001 1126 | 193.190002 1127 | 193.149994 1128 | 195.100006 1129 | 197.399994 1130 | 199.320007 1131 | 197.509995 1132 | 197.929993 1133 | 198.820007 1134 | 199.789993 1135 | 198.970001 1136 | 202 1137 | 201.639999 1138 | 199.699997 1139 | 198.820007 1140 | 198.710007 1141 | 199.360001 1142 | 201.050003 1143 | 201.339996 1144 | 200.869995 1145 | 202.259995 1146 | 196.75 1147 | 194.110001 1148 | 197.919998 1149 | 198.389999 1150 | 202.5 1151 | 206.059998 1152 | 206.300003 1153 | 206.179993 1154 | 205.119995 1155 | 207.789993 1156 | 208.100006 1157 | 204.410004 1158 | 205.25 1159 | 209.779999 1160 | 208.669998 1161 | 212.600006 1162 | 213.059998 1163 | 215.220001 1164 | 218.300003 1165 | 218.059998 1166 | 221.910004 1167 | 219.059998 1168 | 221.149994 1169 | 221.770004 1170 | 222.139999 1171 | 221.440002 1172 | 221.320007 1173 | 219.759995 1174 | 217.940002 1175 | 214.869995 1176 | 217.789993 1177 | 223.229996 1178 | 209.529999 1179 | 201.910004 1180 | 204.190002 1181 | 209.830002 1182 | 210.110001 1183 | 210.850006 1184 | 212.330002 1185 | 213.059998 1186 | 207.190002 1187 | 210.759995 1188 | 213.139999 1189 | 214.179993 1190 | 217.800003 1191 | 217.490005 1192 | 214.580002 1193 | 210.179993 1194 | 200.720001 1195 | 196.770004 1196 | 197.199997 1197 | 189.75 1198 | 192.470001 1199 | 196.440002 1200 | 185.889999 1201 | 191.759995 1202 | 185.169998 1203 | 181.089996 1204 | 169.5 1205 | 178.190002 1206 | 170.240005 1207 | 154.470001 1208 | 170.279999 1209 | 146.009995 1210 | 149.419998 1211 | 146.960007 1212 | 153.130005 1213 | 149.729996 1214 | 148.100006 1215 | 160.979996 1216 | 156.210007 1217 | 163.339996 1218 | 156.789993 1219 | 165.949997 1220 | 166.800003 1221 | 159.600006 1222 | 158.190002 1223 | 154.179993 1224 | 165.550003 1225 | 168.830002 1226 | 174.279999 1227 | 175.190002 1228 | 174.789993 1229 | 178.169998 1230 | 176.970001 1231 | 176.25 1232 | 179.240005 1233 | 178.240005 1234 | 170.800003 1235 | 182.279999 1236 | 185.130005 1237 | 190.070007 1238 | 187.5 1239 | 182.910004 1240 | 194.190002 1241 | 204.710007 1242 | 202.270004 1243 | 205.259995 1244 | 207.070007 1245 | 208.470001 1246 | 211.259995 1247 | 212.350006 1248 | 213.179993 1249 | 210.100006 1250 | 205.100006 1251 | 206.809998 1252 | 210.880005 1253 | 213.190002 1254 | 216.880005 1255 | 229.970001 1256 | 231.389999 1257 | 234.910004 1258 | 232.199997 1259 | 229.139999 1260 | -------------------------------------------------------------------------------- /data/stocks/goog/close.vals: -------------------------------------------------------------------------------- 1 | 539.780029 2 | 532.109985 3 | 533.98999 4 | 539.179993 5 | 540.309998 6 | 536.700012 7 | 533.330017 8 | 526.830017 9 | 526.690002 10 | 536.690002 11 | 534.609985 12 | 532.330017 13 | 527.200012 14 | 528.150024 15 | 529.26001 16 | 536.72998 17 | 536.690002 18 | 538.190002 19 | 540.47998 20 | 537.840027 21 | 535.22998 22 | 531.690002 23 | 521.52002 24 | 520.51001 25 | 521.840027 26 | 523.400024 27 | 522.859985 28 | 525.02002 29 | 516.830017 30 | 520.679993 31 | 530.130005 32 | 546.549988 33 | 561.099976 34 | 560.219971 35 | 579.849976 36 | 672.929993 37 | 663.02002 38 | 662.299988 39 | 662.099976 40 | 644.280029 41 | 623.559998 42 | 627.26001 43 | 628 44 | 631.929993 45 | 632.590027 46 | 625.609985 47 | 631.210022 48 | 629.25 49 | 643.780029 50 | 642.679993 51 | 635.299988 52 | 633.72998 53 | 660.780029 54 | 659.559998 55 | 656.450012 56 | 657.119995 57 | 660.869995 58 | 656.130005 59 | 660.900024 60 | 646.830017 61 | 612.47998 62 | 589.609985 63 | 582.059998 64 | 628.619995 65 | 637.609985 66 | 630.380005 67 | 618.25 68 | 597.789978 69 | 614.340027 70 | 606.25 71 | 600.700012 72 | 614.659973 73 | 612.719971 74 | 621.349976 75 | 625.77002 76 | 623.23999 77 | 635.140015 78 | 635.97998 79 | 642.900024 80 | 629.25 81 | 635.440002 82 | 622.690002 83 | 622.359985 84 | 625.799988 85 | 611.969971 86 | 594.890015 87 | 594.969971 88 | 608.419983 89 | 611.289978 90 | 626.909973 91 | 641.469971 92 | 645.440002 93 | 642.359985 94 | 639.159973 95 | 643.609985 96 | 646.669983 97 | 652.299988 98 | 651.159973 99 | 661.73999 100 | 662.200012 101 | 666.099976 102 | 650.280029 103 | 642.609985 104 | 651.789978 105 | 702 106 | 712.780029 107 | 708.48999 108 | 712.950012 109 | 716.919983 110 | 710.809998 111 | 721.109985 112 | 722.159973 113 | 728.109985 114 | 731.25 115 | 733.76001 116 | 724.890015 117 | 728.320007 118 | 735.400024 119 | 731.22998 120 | 717 121 | 728.960022 122 | 725.299988 123 | 740 124 | 738.409973 125 | 756.599976 126 | 755.97998 127 | 748.280029 128 | 748.150024 129 | 750.26001 130 | 742.599976 131 | 767.039978 132 | 762.380005 133 | 752.539978 134 | 766.809998 135 | 763.25 136 | 762.369995 137 | 751.609985 138 | 749.460022 139 | 738.869995 140 | 747.77002 141 | 743.400024 142 | 758.090027 143 | 749.429993 144 | 739.309998 145 | 747.77002 146 | 750 147 | 750.309998 148 | 748.400024 149 | 762.51001 150 | 776.599976 151 | 771 152 | 758.880005 153 | 741.840027 154 | 742.580017 155 | 743.619995 156 | 726.390015 157 | 714.469971 158 | 716.030029 159 | 726.070007 160 | 700.559998 161 | 714.719971 162 | 694.450012 163 | 701.789978 164 | 698.450012 165 | 706.590027 166 | 725.25 167 | 711.669983 168 | 713.039978 169 | 699.98999 170 | 730.960022 171 | 742.950012 172 | 752 173 | 764.650024 174 | 726.950012 175 | 708.01001 176 | 683.570007 177 | 682.73999 178 | 678.109985 179 | 684.119995 180 | 683.109985 181 | 682.400024 182 | 691 183 | 708.400024 184 | 697.349976 185 | 700.909973 186 | 706.460022 187 | 695.849976 188 | 699.559998 189 | 705.75 190 | 705.070007 191 | 697.77002 192 | 718.809998 193 | 718.849976 194 | 712.419983 195 | 710.890015 196 | 695.159973 197 | 693.969971 198 | 705.23999 199 | 712.820007 200 | 726.820007 201 | 730.48999 202 | 728.330017 203 | 736.090027 204 | 737.780029 205 | 737.599976 206 | 742.090027 207 | 740.75 208 | 738.059998 209 | 735.299988 210 | 733.530029 211 | 744.77002 212 | 750.530029 213 | 744.950012 214 | 749.909973 215 | 745.289978 216 | 737.799988 217 | 745.690002 218 | 740.280029 219 | 739.150024 220 | 736.099976 221 | 743.090027 222 | 751.719971 223 | 753.200012 224 | 759 225 | 766.609985 226 | 753.929993 227 | 752.669983 228 | 759.140015 229 | 718.77002 230 | 723.150024 231 | 708.140015 232 | 705.840027 233 | 691.02002 234 | 693.01001 235 | 698.210022 236 | 692.359985 237 | 695.700012 238 | 701.429993 239 | 711.119995 240 | 712.900024 241 | 723.179993 242 | 715.289978 243 | 713.309998 244 | 710.830017 245 | 716.48999 246 | 706.22998 247 | 706.630005 248 | 700.320007 249 | 709.73999 250 | 704.23999 251 | 720.090027 252 | 725.27002 253 | 724.119995 254 | 732.659973 255 | 735.719971 256 | 734.150024 257 | 730.400024 258 | 722.340027 259 | 716.549988 260 | 716.650024 261 | 728.280029 262 | 728.580017 263 | 719.409973 264 | 718.359985 265 | 718.27002 266 | 718.919983 267 | 710.359985 268 | 691.719971 269 | 693.710022 270 | 695.940002 271 | 697.460022 272 | 701.869995 273 | 675.219971 274 | 668.26001 275 | 680.039978 276 | 684.109985 277 | 692.099976 278 | 699.210022 279 | 694.950012 280 | 697.77002 281 | 695.359985 282 | 705.630005 283 | 715.090027 284 | 720.640015 285 | 716.97998 286 | 720.950012 287 | 719.849976 288 | 733.780029 289 | 736.960022 290 | 741.190002 291 | 738.630005 292 | 742.73999 293 | 739.77002 294 | 738.419983 295 | 741.77002 296 | 745.909973 297 | 768.789978 298 | 772.880005 299 | 771.070007 300 | 773.179993 301 | 771.609985 302 | 782.219971 303 | 781.76001 304 | 784.26001 305 | 784.679993 306 | 784.849976 307 | 783.219971 308 | 782.440002 309 | 777.140015 310 | 779.909973 311 | 777.5 312 | 775.419983 313 | 772.150024 314 | 772.080017 315 | 769.640015 316 | 769.409973 317 | 769.539978 318 | 772.150024 319 | 769.090027 320 | 767.049988 321 | 768.780029 322 | 771.460022 323 | 780.080017 324 | 780.349976 325 | 775.320007 326 | 759.659973 327 | 769.02002 328 | 759.690002 329 | 762.48999 330 | 771.76001 331 | 768.880005 332 | 765.700012 333 | 771.409973 334 | 776.219971 335 | 787.210022 336 | 786.900024 337 | 774.210022 338 | 783.01001 339 | 781.559998 340 | 775.01001 341 | 777.289978 342 | 772.559998 343 | 776.429993 344 | 776.469971 345 | 776.859985 346 | 775.080017 347 | 785.940002 348 | 783.070007 349 | 786.140015 350 | 778.190002 351 | 778.530029 352 | 779.960022 353 | 795.26001 354 | 801.5 355 | 796.969971 356 | 799.369995 357 | 813.109985 358 | 807.669983 359 | 799.070007 360 | 795.349976 361 | 795.369995 362 | 784.539978 363 | 783.609985 364 | 768.700012 365 | 762.130005 366 | 762.02002 367 | 782.52002 368 | 790.51001 369 | 785.309998 370 | 762.559998 371 | 754.02002 372 | 736.080017 373 | 758.48999 374 | 764.47998 375 | 771.22998 376 | 760.539978 377 | 769.200012 378 | 768.27002 379 | 760.98999 380 | 761.679993 381 | 768.23999 382 | 770.840027 383 | 758.039978 384 | 747.919983 385 | 750.5 386 | 762.52002 387 | 759.109985 388 | 771.190002 389 | 776.419983 390 | 789.289978 391 | 789.27002 392 | 796.099976 393 | 797.070007 394 | 797.849976 395 | 790.799988 396 | 794.200012 397 | 796.419983 398 | 794.559998 399 | 791.26001 400 | 789.909973 401 | 791.549988 402 | 785.049988 403 | 782.789978 404 | 771.820007 405 | 786.140015 406 | 786.900024 407 | 794.02002 408 | 806.150024 409 | 806.650024 410 | 804.789978 411 | 807.909973 412 | 806.359985 413 | 807.880005 414 | 804.609985 415 | 806.070007 416 | 802.174988 417 | 805.02002 418 | 819.309998 419 | 823.869995 420 | 835.669983 421 | 832.150024 422 | 823.309998 423 | 802.320007 424 | 796.789978 425 | 795.695007 426 | 798.530029 427 | 801.48999 428 | 801.340027 429 | 806.969971 430 | 808.380005 431 | 809.559998 432 | 813.669983 433 | 819.23999 434 | 820.450012 435 | 818.97998 436 | 824.159973 437 | 828.070007 438 | 831.659973 439 | 830.76001 440 | 831.330017 441 | 828.640015 442 | 829.280029 443 | 823.210022 444 | 835.23999 445 | 830.630005 446 | 829.080017 447 | 827.780029 448 | 831.909973 449 | 835.369995 450 | 838.679993 451 | 843.25 452 | 845.539978 453 | 845.619995 454 | 847.200012 455 | 848.780029 456 | 852.119995 457 | 848.400024 458 | 830.460022 459 | 829.590027 460 | 817.580017 461 | 814.429993 462 | 819.51001 463 | 820.919983 464 | 831.409973 465 | 831.5 466 | 829.559998 467 | 838.549988 468 | 834.570007 469 | 831.409973 470 | 827.880005 471 | 824.669983 472 | 824.72998 473 | 823.349976 474 | 824.320007 475 | 823.559998 476 | 837.169983 477 | 836.820007 478 | 838.210022 479 | 841.650024 480 | 843.190002 481 | 862.76001 482 | 872.299988 483 | 871.72998 484 | 874.25 485 | 905.960022 486 | 912.570007 487 | 916.440002 488 | 927.039978 489 | 931.659973 490 | 927.130005 491 | 934.299988 492 | 932.169983 493 | 928.780029 494 | 930.599976 495 | 932.219971 496 | 937.080017 497 | 943 498 | 919.619995 499 | 930.23999 500 | 934.01001 501 | 941.859985 502 | 948.820007 503 | 954.960022 504 | 969.539978 505 | 971.469971 506 | 975.880005 507 | 964.859985 508 | 966.950012 509 | 975.599976 510 | 983.679993 511 | 976.570007 512 | 980.940002 513 | 983.409973 514 | 949.830017 515 | 942.900024 516 | 953.400024 517 | 950.76001 518 | 942.309998 519 | 939.780029 520 | 957.369995 521 | 950.630005 522 | 959.450012 523 | 957.090027 524 | 965.590027 525 | 952.27002 526 | 927.330017 527 | 940.48999 528 | 917.789978 529 | 908.72998 530 | 898.700012 531 | 911.710022 532 | 906.690002 533 | 918.590027 534 | 928.799988 535 | 930.090027 536 | 943.830017 537 | 947.159973 538 | 955.98999 539 | 953.419983 540 | 965.400024 541 | 970.890015 542 | 968.150024 543 | 972.919983 544 | 980.340027 545 | 950.700012 546 | 947.799988 547 | 934.090027 548 | 941.530029 549 | 930.5 550 | 930.830017 551 | 930.390015 552 | 923.650024 553 | 927.960022 554 | 929.359985 555 | 926.789978 556 | 922.900024 557 | 907.23999 558 | 914.390015 559 | 922.669983 560 | 922.219971 561 | 926.960022 562 | 910.97998 563 | 910.669983 564 | 906.659973 565 | 924.690002 566 | 927 567 | 921.280029 568 | 915.890015 569 | 913.809998 570 | 921.289978 571 | 929.570007 572 | 939.330017 573 | 937.340027 574 | 928.450012 575 | 927.809998 576 | 935.950012 577 | 926.5 578 | 929.080017 579 | 932.070007 580 | 935.090027 581 | 925.109985 582 | 920.289978 583 | 915 584 | 921.809998 585 | 931.580017 586 | 932.450012 587 | 928.530029 588 | 920.969971 589 | 924.859985 590 | 944.48999 591 | 949.5 592 | 959.109985 593 | 953.27002 594 | 957.789978 595 | 951.679993 596 | 969.960022 597 | 978.890015 598 | 977 599 | 972.599976 600 | 989.25 601 | 987.830017 602 | 989.679993 603 | 992 604 | 992.179993 605 | 992.809998 606 | 984.450012 607 | 988.200012 608 | 968.450012 609 | 970.539978 610 | 973.330017 611 | 972.559998 612 | 1019.27002 613 | 1017.109985 614 | 1016.640015 615 | 1025.5 616 | 1025.579956 617 | 1032.47998 618 | 1025.900024 619 | 1033.329956 620 | 1039.849976 621 | 1031.26001 622 | 1028.069946 623 | 1025.75 624 | 1026 625 | 1020.909973 626 | 1032.5 627 | 1019.090027 628 | 1018.380005 629 | 1034.48999 630 | 1035.959961 631 | 1040.609985 632 | 1054.209961 633 | 1047.410034 634 | 1021.659973 635 | 1021.409973 636 | 1010.169983 637 | 998.679993 638 | 1005.150024 639 | 1018.380005 640 | 1030.930054 641 | 1037.050049 642 | 1041.099976 643 | 1040.47998 644 | 1040.609985 645 | 1049.150024 646 | 1064.189941 647 | 1077.140015 648 | 1070.680054 649 | 1064.949951 650 | 1063.630005 651 | 1060.119995 652 | 1056.73999 653 | 1049.369995 654 | 1048.140015 655 | 1046.400024 656 | 1065 657 | 1082.47998 658 | 1086.400024 659 | 1102.22998 660 | 1106.939941 661 | 1106.26001 662 | 1102.609985 663 | 1105.52002 664 | 1122.26001 665 | 1121.76001 666 | 1131.97998 667 | 1129.790039 668 | 1137.51001 669 | 1155.810059 670 | 1169.969971 671 | 1164.23999 672 | 1170.369995 673 | 1175.839966 674 | 1175.579956 675 | 1163.689941 676 | 1169.939941 677 | 1167.699951 678 | 1111.900024 679 | 1055.800049 680 | 1080.599976 681 | 1048.579956 682 | 1001.52002 683 | 1037.780029 684 | 1051.939941 685 | 1052.099976 686 | 1069.699951 687 | 1089.52002 688 | 1094.800049 689 | 1102.459961 690 | 1111.339966 691 | 1106.630005 692 | 1126.790039 693 | 1143.75 694 | 1118.290039 695 | 1104.72998 696 | 1069.52002 697 | 1078.920044 698 | 1090.930054 699 | 1095.060059 700 | 1109.640015 701 | 1126 702 | 1160.040039 703 | 1164.5 704 | 1138.170044 705 | 1149.48999 706 | 1149.579956 707 | 1135.72998 708 | 1099.819946 709 | 1097.709961 710 | 1090.880005 711 | 1049.079956 712 | 1021.570007 713 | 1053.209961 714 | 1005.099976 715 | 1004.559998 716 | 1031.790039 717 | 1006.469971 718 | 1013.409973 719 | 1025.140015 720 | 1027.810059 721 | 1007.039978 722 | 1015.450012 723 | 1031.640015 724 | 1019.969971 725 | 1032.51001 726 | 1029.27002 727 | 1037.97998 728 | 1074.160034 729 | 1072.079956 730 | 1087.699951 731 | 1072.959961 732 | 1067.449951 733 | 1019.97998 734 | 1021.179993 735 | 1040.040039 736 | 1030.050049 737 | 1017.330017 738 | 1037.310059 739 | 1024.380005 740 | 1023.719971 741 | 1048.209961 742 | 1054.790039 743 | 1053.910034 744 | 1082.76001 745 | 1097.569946 746 | 1098.26001 747 | 1100.199951 748 | 1079.22998 749 | 1081.77002 750 | 1078.589966 751 | 1066.359985 752 | 1079.579956 753 | 1069.72998 754 | 1079.689941 755 | 1079.23999 756 | 1075.660034 757 | 1060.319946 758 | 1067.800049 759 | 1084.98999 760 | 1119.5 761 | 1139.290039 762 | 1139.660034 763 | 1136.880005 764 | 1123.859985 765 | 1120.869995 766 | 1129.98999 767 | 1139.319946 768 | 1134.790039 769 | 1152.119995 770 | 1152.26001 771 | 1173.459961 772 | 1168.060059 773 | 1169.839966 774 | 1157.660034 775 | 1155.47998 776 | 1124.810059 777 | 1118.459961 778 | 1103.97998 779 | 1114.219971 780 | 1115.650024 781 | 1127.459961 782 | 1102.890015 783 | 1124.27002 784 | 1140.170044 785 | 1154.050049 786 | 1152.839966 787 | 1153.900024 788 | 1183.47998 789 | 1188.819946 790 | 1183.859985 791 | 1198.800049 792 | 1195.880005 793 | 1186.959961 794 | 1184.910034 795 | 1205.5 796 | 1248.079956 797 | 1263.699951 798 | 1268.329956 799 | 1238.5 800 | 1219.73999 801 | 1217.26001 802 | 1220.01001 803 | 1226.150024 804 | 1223.709961 805 | 1224.77002 806 | 1242.219971 807 | 1245.609985 808 | 1249.099976 809 | 1237.609985 810 | 1235.01001 811 | 1242.099976 812 | 1214.380005 813 | 1206.48999 814 | 1200.959961 815 | 1207.77002 816 | 1201.619995 817 | 1207.329956 818 | 1205.380005 819 | 1220.650024 820 | 1241.819946 821 | 1231.150024 822 | 1249.300049 823 | 1239.119995 824 | 1218.189941 825 | 1197 826 | 1186.47998 827 | 1171.439941 828 | 1164.829956 829 | 1164.640015 830 | 1177.359985 831 | 1162.819946 832 | 1175.329956 833 | 1172.530029 834 | 1156.050049 835 | 1161.219971 836 | 1171.089966 837 | 1186.869995 838 | 1166.089966 839 | 1173.369995 840 | 1184.650024 841 | 1180.48999 842 | 1194.640015 843 | 1193.469971 844 | 1195.310059 845 | 1200.109985 846 | 1202.949951 847 | 1168.189941 848 | 1157.349976 849 | 1148.969971 850 | 1138.819946 851 | 1081.219971 852 | 1079.319946 853 | 1110.079956 854 | 1092.25 855 | 1121.280029 856 | 1115.689941 857 | 1087.969971 858 | 1096.459961 859 | 1101.160034 860 | 1103.689941 861 | 1050.709961 862 | 1095.569946 863 | 1071.469971 864 | 1020.080017 865 | 1036.209961 866 | 1076.77002 867 | 1070 868 | 1057.790039 869 | 1040.089966 870 | 1055.810059 871 | 1093.390015 872 | 1082.400024 873 | 1066.150024 874 | 1038.630005 875 | 1036.050049 876 | 1043.660034 877 | 1064.709961 878 | 1061.48999 879 | 1020 880 | 1025.76001 881 | 1037.609985 882 | 1023.880005 883 | 1048.619995 884 | 1044.410034 885 | 1086.22998 886 | 1088.300049 887 | 1094.430054 888 | 1106.430054 889 | 1050.819946 890 | 1068.72998 891 | 1036.579956 892 | 1039.550049 893 | 1051.75 894 | 1063.680054 895 | 1061.900024 896 | 1042.099976 897 | 1016.530029 898 | 1028.709961 899 | 1023.01001 900 | 1009.409973 901 | 979.539978 902 | 976.219971 903 | 1039.459961 904 | 1043.880005 905 | 1037.079956 906 | 1035.609985 907 | 1045.849976 908 | 1016.059998 909 | 1070.709961 910 | 1068.390015 911 | 1076.280029 912 | 1074.660034 913 | 1070.329956 914 | 1057.189941 915 | 1044.689941 916 | 1077.150024 917 | 1080.969971 918 | 1089.900024 919 | 1098.26001 920 | 1070.52002 921 | 1075.569946 922 | 1073.900024 923 | 1090.98999 924 | 1070.079956 925 | 1060.619995 926 | 1089.060059 927 | 1116.369995 928 | 1110.75 929 | 1132.800049 930 | 1145.98999 931 | 1115.22998 932 | 1098.709961 933 | 1095.060059 934 | 1095.01001 935 | 1121.369995 936 | 1120.160034 937 | 1121.670044 938 | 1113.650024 939 | 1118.560059 940 | 1113.800049 941 | 1096.969971 942 | 1110.369995 943 | 1109.400024 944 | 1115.130005 945 | 1116.050049 946 | 1119.920044 947 | 1140.98999 948 | 1147.800049 949 | 1162.030029 950 | 1157.859985 951 | 1143.300049 952 | 1142.319946 953 | 1175.76001 954 | 1193.199951 955 | 1193.319946 956 | 1185.550049 957 | 1184.459961 958 | 1184.26001 959 | 1198.849976 960 | 1223.969971 961 | 1231.540039 962 | 1205.5 963 | 1193 964 | 1184.619995 965 | 1173.02002 966 | 1168.48999 967 | 1173.310059 968 | 1194.430054 969 | 1200.48999 970 | 1205.920044 971 | 1215 972 | 1207.150024 973 | 1203.839966 974 | 1197.25 975 | 1202.160034 976 | 1204.619995 977 | 1217.869995 978 | 1221.099976 979 | 1227.130005 980 | 1236.339966 981 | 1236.369995 982 | 1248.839966 983 | 1264.550049 984 | 1256 985 | 1263.449951 986 | 1272.180054 987 | 1287.579956 988 | 1188.47998 989 | 1168.079956 990 | 1162.609985 991 | 1185.400024 992 | 1189.390015 993 | 1174.099976 994 | 1166.27002 995 | 1162.380005 996 | 1164.27002 997 | 1132.030029 998 | 1120.439941 999 | 1164.209961 1000 | 1178.97998 1001 | 1162.300049 1002 | 1138.849976 1003 | 1149.630005 1004 | 1151.420044 1005 | 1140.77002 1006 | 1133.469971 1007 | 1134.150024 1008 | 1116.459961 1009 | 1117.949951 1010 | 1103.630005 1011 | 1036.22998 1012 | 1053.050049 1013 | 1042.219971 1014 | 1044.339966 1015 | 1066.040039 1016 | 1080.380005 1017 | 1078.719971 1018 | 1077.030029 1019 | 1088.77002 1020 | 1085.349976 1021 | 1092.5 1022 | 1103.599976 1023 | 1102.329956 1024 | 1111.420044 1025 | 1121.880005 1026 | 1115.52002 1027 | 1086.349976 1028 | 1079.800049 1029 | 1076.01001 1030 | 1080.910034 1031 | 1097.949951 1032 | 1111.25 1033 | 1121.579956 1034 | 1131.589966 1035 | 1116.349976 1036 | 1124.829956 1037 | 1140.47998 1038 | 1144.209961 1039 | 1144.900024 1040 | 1150.339966 1041 | 1153.579956 1042 | 1146.349976 1043 | 1146.329956 1044 | 1130.099976 1045 | 1138.069946 1046 | 1146.209961 1047 | 1137.810059 1048 | 1132.119995 1049 | 1250.410034 1050 | 1239.410034 1051 | 1225.140015 1052 | 1216.680054 1053 | 1209.01001 1054 | 1193.98999 1055 | 1152.319946 1056 | 1169.949951 1057 | 1173.98999 1058 | 1204.800049 1059 | 1188.01001 1060 | 1174.709961 1061 | 1197.27002 1062 | 1164.290039 1063 | 1167.26001 1064 | 1177.599976 1065 | 1198.449951 1066 | 1182.689941 1067 | 1191.25 1068 | 1189.530029 1069 | 1151.290039 1070 | 1168.890015 1071 | 1167.839966 1072 | 1171.02002 1073 | 1192.849976 1074 | 1188.099976 1075 | 1168.390015 1076 | 1181.410034 1077 | 1211.380005 1078 | 1204.930054 1079 | 1204.410034 1080 | 1206 1081 | 1220.170044 1082 | 1234.25 1083 | 1239.560059 1084 | 1231.300049 1085 | 1229.150024 1086 | 1232.410034 1087 | 1238.709961 1088 | 1229.930054 1089 | 1234.030029 1090 | 1218.76001 1091 | 1246.52002 1092 | 1241.390015 1093 | 1225.089966 1094 | 1219 1095 | 1205.099976 1096 | 1176.630005 1097 | 1187.829956 1098 | 1209 1099 | 1207.680054 1100 | 1189.130005 1101 | 1202.310059 1102 | 1208.670044 1103 | 1215.449951 1104 | 1217.140015 1105 | 1243.01001 1106 | 1243.640015 1107 | 1253.069946 1108 | 1245.48999 1109 | 1246.150024 1110 | 1242.800049 1111 | 1259.130005 1112 | 1260.98999 1113 | 1265.130005 1114 | 1290 1115 | 1262.619995 1116 | 1261.290039 1117 | 1260.109985 1118 | 1273.73999 1119 | 1291.369995 1120 | 1292.030029 1121 | 1291.800049 1122 | 1308.859985 1123 | 1311.369995 1124 | 1299.189941 1125 | 1298.800049 1126 | 1298 1127 | 1311.459961 1128 | 1334.869995 1129 | 1320.699951 1130 | 1315.459961 1131 | 1303.050049 1132 | 1301.349976 1133 | 1295.339966 1134 | 1306.689941 1135 | 1313.550049 1136 | 1312.98999 1137 | 1304.959961 1138 | 1289.920044 1139 | 1295.280029 1140 | 1320.540039 1141 | 1328.130005 1142 | 1340.619995 1143 | 1343.560059 1144 | 1344.660034 1145 | 1345.02002 1146 | 1350.27002 1147 | 1347.829956 1148 | 1361.170044 1149 | 1355.119995 1150 | 1352.619995 1151 | 1356.040039 1152 | 1349.589966 1153 | 1348.839966 1154 | 1343.560059 1155 | 1360.400024 1156 | 1351.890015 1157 | 1336.140015 1158 | 1337.02002 1159 | 1367.369995 1160 | 1360.660034 1161 | 1394.209961 1162 | 1393.339966 1163 | 1404.319946 1164 | 1419.829956 1165 | 1429.72998 1166 | 1439.22998 1167 | 1430.880005 1168 | 1439.199951 1169 | 1451.699951 1170 | 1480.390015 1171 | 1484.400024 1172 | 1485.949951 1173 | 1486.650024 1174 | 1466.709961 1175 | 1433.900024 1176 | 1452.560059 1177 | 1458.630005 1178 | 1455.839966 1179 | 1434.22998 1180 | 1485.939941 1181 | 1447.069946 1182 | 1448.22998 1183 | 1476.22998 1184 | 1479.22998 1185 | 1508.680054 1186 | 1508.790039 1187 | 1518.27002 1188 | 1514.660034 1189 | 1520.73999 1190 | 1519.670044 1191 | 1526.689941 1192 | 1518.150024 1193 | 1485.109985 1194 | 1421.589966 1195 | 1388.449951 1196 | 1393.180054 1197 | 1318.089966 1198 | 1339.329956 1199 | 1389.109985 1200 | 1341.390015 1201 | 1386.52002 1202 | 1319.040039 1203 | 1298.410034 1204 | 1215.560059 1205 | 1280.390015 1206 | 1215.410034 1207 | 1114.910034 1208 | 1219.72998 1209 | 1084.329956 1210 | 1119.800049 1211 | 1096.800049 1212 | 1115.290039 1213 | 1072.319946 1214 | 1056.619995 1215 | 1134.459961 1216 | 1102.48999 1217 | 1161.75 1218 | 1110.709961 1219 | 1146.819946 1220 | 1162.810059 1221 | 1105.619995 1222 | 1120.839966 1223 | 1097.880005 1224 | 1186.920044 1225 | 1186.51001 1226 | 1210.280029 1227 | 1211.449951 1228 | 1217.560059 1229 | 1269.22998 1230 | 1262.469971 1231 | 1263.469971 1232 | 1283.25 1233 | 1266.609985 1234 | 1216.339966 1235 | 1263.209961 1236 | 1276.310059 1237 | 1279.310059 1238 | 1275.880005 1239 | 1233.670044 1240 | 1341.47998 1241 | 1348.660034 1242 | 1320.609985 1243 | 1326.800049 1244 | 1351.109985 1245 | 1347.300049 1246 | 1372.560059 1247 | 1388.369995 1248 | 1403.26001 1249 | 1375.73999 1250 | 1349.329956 1251 | 1356.130005 1252 | 1373.189941 1253 | 1383.939941 1254 | 1373.484985 1255 | 1406.719971 1256 | 1402.800049 1257 | 1410.420044 1258 | 1417.02002 1259 | 1417.839966 1260 | -------------------------------------------------------------------------------- /dirac_phi.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains the generator for ACNet (in the class DiracPhi). 3 | This is named as such since the mixing variable is a convex combination of dirac delta functions. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class DiracPhi(nn.Module): 11 | ''' 12 | TODO: streamline 3 cases in forward pass. 13 | ''' 14 | 15 | def __init__(self, depth, widths, lc_w_range, shift_w_range): 16 | super(DiracPhi, self).__init__() 17 | 18 | # Depth is the number of hidden layers. 19 | self.depth = depth 20 | self.widths = widths 21 | self.lc_w_range = lc_w_range 22 | self.shift_w_range = shift_w_range 23 | 24 | assert self.depth == len(self.widths) 25 | 26 | self.shift_raw_, self.lc_raw_ = self.init_w() 27 | print(self.shift_raw_, self.lc_raw_) 28 | self.shift_raw = nn.ParameterList( 29 | [nn.Parameter(x) for x in self.shift_raw_]) 30 | self.lc_raw = nn.ParameterList([nn.Parameter(x) for x in self.lc_raw_]) 31 | 32 | def init_w(self): 33 | sizes = self.get_sizes_w_() 34 | shift_sizes, lc_sizes = sizes[:self.depth], sizes[self.depth:] 35 | shift_tensors, lc_tensors = [], [] 36 | 37 | for shift_size in shift_sizes: 38 | w = torch.zeros(shift_size) 39 | torch.nn.init.uniform_(w, *self.shift_w_range) 40 | shift_tensors.append(w) 41 | 42 | for lc_size in lc_sizes: 43 | w = torch.zeros(lc_size) 44 | torch.nn.init.uniform_(w, *self.lc_w_range) 45 | lc_tensors.append(w) 46 | 47 | return shift_tensors, lc_tensors 48 | 49 | def get_sizes_w_(self): 50 | depth, widths = self.depth, self.widths 51 | lc_sizes, shift_sizes = [], [] 52 | 53 | # Shift weights 54 | prev_width = 1 55 | for pos in range(depth): 56 | width = widths[pos] 57 | shift_sizes.append((width,)) 58 | prev_width = width 59 | 60 | # Linear combination weights 61 | for pos in range(depth): 62 | width = widths[pos] 63 | if pos < depth-1: 64 | next_width = widths[pos+1] 65 | else: 66 | next_width = 1 67 | lc_sizes.append((next_width, width)) 68 | 69 | return shift_sizes + lc_sizes 70 | 71 | def forward(self, t_raw): 72 | s_raw, lc_raw = self.shift_raw, self.lc_raw 73 | depth = self.depth 74 | num_queries = t_raw.numel() 75 | t = t_raw.flatten() 76 | 77 | # State[i] has a dimension of N x num_inputs (to current layer) 78 | # Initial state has a dimension of N x 1. 79 | initial_state = torch.ones((num_queries, 1)) 80 | states = [initial_state] 81 | 82 | # Positive function. 83 | def pf(x): return torch.exp(x) 84 | 85 | for ell in range(depth+1): 86 | F_prev = states[-1] 87 | if ell == 0: 88 | # In the first layer, there is only a shift, since convex combinations 89 | # are meaningless. 90 | n_outputs, n_inputs = s_raw[ell].size()[0], 1 91 | s = pf(s_raw[ell]) 92 | s = s[None, :].expand(num_queries, -1) 93 | 94 | Fp = F_prev[:, None].expand(-1, 1, n_outputs).squeeze(dim=1) 95 | t_2d = t[:, None].expand(-1, n_outputs) 96 | 97 | next_state = Fp * torch.exp(-t_2d * s) 98 | states.append(next_state) 99 | 100 | elif ell == depth: 101 | # In the last layer, we only perform convex combinations. 102 | n_outputs, n_inputs = lc_raw[ell-1].size() 103 | lc = torch.softmax(lc_raw[ell-1], dim=1) 104 | lc = lc[None, :, :].expand(num_queries, -1, -1) 105 | Fp = F_prev[:, None, :].expand(-1, n_outputs, -1) 106 | next_state = (Fp * lc).sum(dim=2) 107 | states.append(next_state) 108 | 109 | else: 110 | # Main case. 111 | n_outputs, n_inputs = lc_raw[ell-1].size() 112 | s = pf(s_raw[ell]) 113 | s = s[None, :].expand(num_queries, -1) 114 | lc = torch.softmax(lc_raw[ell-1], dim=1) 115 | lc = lc[None, :, :].expand(num_queries, -1, -1) 116 | Fp = F_prev[:, None, :].expand(-1, n_outputs, -1) 117 | t_2d = t[:, None].expand(-1, n_outputs) 118 | 119 | next_state = (Fp * lc).sum(dim=2) * torch.exp(-t_2d * s) 120 | states.append(next_state) 121 | 122 | output = states[-1] 123 | assert (output >= 0.).all() and ( 124 | output <= 1.+1e-10).all(), "t %s, output %s" % (t, output, ) 125 | 126 | return output.reshape(t_raw.size()) 127 | -------------------------------------------------------------------------------- /gen_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from main import Copula, sample 5 | 6 | 7 | def gen_data(phi, 8 | ndims, 9 | N, 10 | seed): 11 | torch.set_default_tensor_type(torch.DoubleTensor) 12 | net = Copula(phi) 13 | 14 | s = sample(net, ndims, N, seed=seed) 15 | log_ll = -torch.log(net(s, 'pdf')) 16 | 17 | print('mean log_ll:', torch.mean(log_ll)) 18 | 19 | plot_samples(s) 20 | return s, log_ll 21 | 22 | 23 | def plot_samples(s): 24 | s_np = s.detach().numpy() 25 | assert s_np.ndim == 2, 'Can only plot 2d array of samples.' 26 | 27 | import matplotlib.pyplot as plt 28 | 29 | s_np = s.detach().numpy() 30 | plt.scatter(s_np[:, 0], s_np[:, 1]) 31 | plt.show() 32 | -------------------------------------------------------------------------------- /gen_scripts/clayton.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | from phi_listing import ClaytonPhi 4 | from sacred import Experiment 5 | from gen_data import gen_data, plot_samples 6 | from sacred.observers import FileStorageObserver 7 | ex = Experiment('LOG_DATA_clayton') 8 | 9 | torch.set_default_tensor_type(torch.DoubleTensor) 10 | 11 | 12 | @ex.config 13 | def cfg(): 14 | Phi = ClaytonPhi 15 | phi_name = 'ClaytonPhi' 16 | theta = torch.tensor(5.) 17 | N = 10000 18 | ndims = 2 19 | seed = 142857 20 | 21 | 22 | @ex.capture 23 | def get_info(_run): 24 | return _run._id 25 | 26 | 27 | @ex.automain 28 | def run(Phi, ndims, theta, N, seed): 29 | 30 | phi = Phi(theta) 31 | id = get_info() 32 | s = phi.sample(ndims, N) 33 | log_ll = torch.log(phi.pdf(s)) 34 | print('avg_log_likelihood', torch.mean(log_ll)) 35 | 36 | import matplotlib.pyplot as plt 37 | plt.scatter(s.detach().numpy()[:, 0], s.detach().numpy()[:, 1]) 38 | plt.show() 39 | 40 | d = {'samples': s, 'log_ll': log_ll} 41 | pickle.dump(d, open('./data/clayton%s.p' % id, 'wb')) 42 | ex.add_artifact('./data/clayton%s.p' % id) 43 | 44 | 45 | if __name__ == '__main__': 46 | print('Sample usage: python -m gen_scripts.clayton -F gen_clayton') 47 | -------------------------------------------------------------------------------- /gen_scripts/frank.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | from phi_listing import FrankPhi 4 | from sacred import Experiment 5 | from gen_data import gen_data, plot_samples 6 | from sacred.observers import FileStorageObserver 7 | ex = Experiment('LOG_DATA_frank') 8 | 9 | torch.set_default_tensor_type(torch.DoubleTensor) 10 | 11 | 12 | @ex.config 13 | def cfg(): 14 | Phi = FrankPhi 15 | phi_name = 'FrankPhi' 16 | theta = 15. 17 | N = 10000 18 | ndims = 2 19 | seed = 142857 20 | 21 | 22 | @ex.capture 23 | def get_info(_run): 24 | return _run._id 25 | 26 | 27 | @ex.automain 28 | def run(Phi, ndims, theta, N, seed): 29 | 30 | phi = Phi(torch.tensor(theta)) 31 | id = get_info() 32 | s, log_ll = gen_data(phi, ndims, N, seed) 33 | print('avg_log_likelihood', torch.mean(log_ll)) 34 | 35 | import matplotlib.pyplot as plt 36 | plt.scatter(s.detach().numpy()[:, 0], s.detach().numpy()[:, 1]) 37 | plt.show() 38 | 39 | d = {'samples': s, 'log_ll': log_ll} 40 | pickle.dump(d, open('./data/frank%s.p' % id, 'wb')) 41 | ex.add_artifact('./data/frank%s.p' % id) 42 | 43 | 44 | if __name__ == '__main__': 45 | print('Sample usage: python -m gen_scripts.frank -F gen_frank') 46 | -------------------------------------------------------------------------------- /gen_scripts/joe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | from phi_listing import JoePhi 4 | from sacred import Experiment 5 | from gen_data import gen_data, plot_samples 6 | from sacred.observers import FileStorageObserver 7 | ex = Experiment('LOG_DATA_joe') 8 | 9 | torch.set_default_tensor_type(torch.DoubleTensor) 10 | 11 | 12 | @ex.config 13 | def cfg(): 14 | Phi = JoePhi 15 | phi_name = 'JoePhi' 16 | theta = torch.tensor(3.) 17 | N = 10000 18 | ndims = 2 19 | seed = 142857 20 | 21 | 22 | @ex.capture 23 | def get_info(_run): 24 | return _run._id 25 | 26 | 27 | @ex.automain 28 | def run(Phi, ndims, theta, N, seed): 29 | 30 | phi = Phi(theta) 31 | id = get_info() 32 | s = phi.sample(ndims, N) 33 | log_ll = torch.log(phi.pdf(s)) 34 | print('avg_log_likelihood', torch.mean(log_ll)) 35 | 36 | import matplotlib.pyplot as plt 37 | plt.scatter(s.detach().numpy()[:, 0], s.detach().numpy()[:, 1]) 38 | plt.show() 39 | 40 | d = {'samples': s, 'log_ll': log_ll} 41 | pickle.dump(d, open('./data/joe%s.p' % id, 'wb')) 42 | ex.add_artifact('./data/joe%s.p' % id) 43 | 44 | 45 | if __name__ == '__main__': 46 | print('Sample usage: python -m gen_scripts.joe -F gen_joe') 47 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | from torch.autograd import Function 5 | 6 | 7 | class PhiInv(nn.Module): 8 | def __init__(self, phi): 9 | super(PhiInv, self).__init__() 10 | self.phi = phi 11 | 12 | def forward(self, y, t0=None, max_iter=400, tol=1e-10): 13 | with torch.no_grad(): 14 | """ 15 | # We will only run newton's method on entries which do not have 16 | # a manual inverse defined (via .inverse) 17 | inverse = self.phi.inverse(y) 18 | assert inverse.shape == y.shape 19 | no_inverse_indices = torch.isnan(inverse) 20 | # print(no_inverse_indices) 21 | # print(y[no_inverse_indices].shape) 22 | t_ = newton_root( 23 | self.phi, y[no_inverse_indices], max_iter=max_iter, tol=tol, 24 | t0=torch.ones_like(y[no_inverse_indices])*1e-10) 25 | 26 | inverse[no_inverse_indices] = t_ 27 | t = inverse 28 | """ 29 | t = newton_root(self.phi, y, max_iter=max_iter, tol=tol) 30 | 31 | topt = t.clone().detach().requires_grad_(True) 32 | f_topt = self.phi(topt) 33 | return self.FastInverse.apply(y, topt, f_topt, self.phi) 34 | 35 | class FastInverse(Function): 36 | ''' 37 | Fast inverse function. To avoid running the optimization 38 | procedure (e.g., Newton's) repeatedly, we pass in the value 39 | of the inverse (obtained from the forward pass) manually. 40 | 41 | In the backward pass, we provide gradients w.r.t (i) `y`, and 42 | (ii) `w`, which are any parameters contained in PhiInv.phi. The 43 | latter is implicitly given by furnishing derivatives w.r.t. f_topt, 44 | i.e., the function evaluated (with the current `w`) on topt. Note 45 | that this should contain *values* approximately equal to y, but 46 | will have the necessary computational graph built up, but detached 47 | from y. 48 | ''' 49 | @staticmethod 50 | def forward(ctx, y, topt, f_topt, phi): 51 | ctx.save_for_backward(y, topt, f_topt) 52 | ctx.phi = phi 53 | return topt 54 | 55 | @staticmethod 56 | def backward(ctx, grad): 57 | y, topt, f_topt = ctx.saved_tensors 58 | phi = ctx.phi 59 | 60 | with torch.enable_grad(): 61 | # Call FakeInverse once again, in order to allow for higher 62 | # order derivatives to be taken. 63 | z = PhiInv.FastInverse.apply(y, topt, f_topt, phi) 64 | 65 | # Find phi'(z), i.e., take derivatives of phi(z) w.r.t z. 66 | f = phi(z) 67 | dev_z = torch.autograd.grad(f.sum(), z, create_graph=True)[0] 68 | 69 | # To understand why this works, refer to the derivations for 70 | # inverses. Note that when taking derivatives w.r.t. `w`, we 71 | # make use of autodiffs automatic application of the chain rule. 72 | # This automatically finds the derivative d/dw[phi(z)] which 73 | # when multiplied by the 3rd returned value gives the derivative 74 | # w.r.t. `w` contained by phi. 75 | return grad/dev_z, None, -grad/dev_z, None 76 | 77 | 78 | def newton_root(phi, y, t0=None, max_iter=200, tol=1e-10, guarded=False): 79 | ''' 80 | Solve 81 | f(t) = y 82 | using the Newton's root finding method. 83 | 84 | Parameters 85 | ---------- 86 | f: Function which takes in a Tensor t of shape `s` and outputs 87 | the pointwise evaluation f(t). 88 | y: Tensor of shape `s`. 89 | t0: Tensor of shape `s` indicating the initial guess for the root. 90 | max_iter: Positive integer containing the max. number of iterations. 91 | tol: Termination criterion for the absolute difference |f(t) - y|. 92 | By default, this is set to 1e-14, 93 | beyond which instability could occur when using pytorch `DoubleTensor`. 94 | guarded: Whether we use guarded Newton's root finding method. 95 | By default False: too slow and is not necessary most of the time. 96 | 97 | Returns: 98 | Tensor `t*` of size `s` such that f(t*) ~= y 99 | ''' 100 | if t0 is None: 101 | t = torch.zeros_like(y) 102 | else: 103 | t = t0.clone().detach() 104 | 105 | s = y.size() 106 | for it in range(max_iter): 107 | 108 | with torch.enable_grad(): 109 | f_t = phi(t.requires_grad_(True)) 110 | fp_t = torch.autograd.grad(f_t.sum(), t)[0] 111 | assert not torch.any(torch.isnan(fp_t)) 112 | 113 | assert f_t.size() == s 114 | assert fp_t.size() == s 115 | 116 | g_t = f_t - y 117 | 118 | # Terminate algorithm when all errors are sufficiently small. 119 | if (torch.abs(g_t) < tol).all(): 120 | break 121 | 122 | if not guarded: 123 | t = t - g_t / fp_t 124 | else: 125 | step_size = torch.ones_like(t) 126 | for num_guarded_steps in range(1000): 127 | t_candidate = t - step_size * g_t / fp_t 128 | f_t_candidate = phi(t_candidate.requires_grad_(True)) 129 | g_candidate = f_t_candidate - y 130 | overstepped_indices = torch.abs(g_candidate) > torch.abs(g_t) 131 | if not overstepped_indices.any(): 132 | t = t_candidate 133 | print(num_guarded_steps) 134 | break 135 | else: 136 | step_size[overstepped_indices] /= 2. 137 | 138 | assert torch.abs(g_t).max( 139 | ) < tol, "t=%s, f(t)-y=%s, y=%s, iter=%s, max dev:%s" % (t, g_t, y, it, g_t.max()) 140 | assert t.size() == s 141 | return t 142 | 143 | 144 | def bisection_root(phi, y, lb=None, ub=None, increasing=True, max_iter=100, tol=1e-10): 145 | ''' 146 | Solve 147 | f(t) = y 148 | using the bisection method. 149 | 150 | Parameters 151 | ---------- 152 | f: Function which takes in a Tensor t of shape `s` and outputs 153 | the pointwise evaluation f(t). 154 | y: Tensor of shape `s`. 155 | lb, ub: lower and upper bounds for t. 156 | increasing: True if f is increasing, False if decreasing. 157 | max_iter: Positive integer containing the max. number of iterations. 158 | tol: Termination criterion for the difference in upper and lower bounds. 159 | By default, this is set to 1e-10, 160 | beyond which instability could occur when using pytorch `DoubleTensor`. 161 | 162 | Returns: 163 | Tensor `t*` of size `s` such that f(t*) ~= y 164 | ''' 165 | if lb is None: 166 | lb = torch.zeros_like(y) 167 | if ub is None: 168 | ub = torch.ones_like(y) 169 | 170 | assert lb.size() == y.size() 171 | assert ub.size() == y.size() 172 | assert torch.all(lb < ub) 173 | 174 | f_ub = phi(ub) 175 | f_lb = phi(lb) 176 | assert torch.all( 177 | f_ub >= f_lb) or not increasing, 'Need f to be monotonically non-decreasing.' 178 | assert torch.all( 179 | f_lb >= f_ub) or increasing, 'Need f to be monotonically non-increasing.' 180 | 181 | assert (torch.all( 182 | f_ub >= y) and torch.all(f_lb <= y)) or not increasing, 'y must lie within lower and upper bound. max min y=%s, %s. ub, lb=%s %s' % (y.max(), y.min(), ub, lb) 183 | assert (torch.all( 184 | f_ub <= y) and torch.all(f_lb >= y)) or increasing, 'y must lie within lower and upper bound. y=%s, %s. ub, lb=%s %s' % (y.max(), y.min(), ub, lb) 185 | 186 | for it in range(max_iter): 187 | t = (lb + ub)/2 188 | f_t = phi(t) 189 | 190 | if increasing: 191 | too_low, too_high = f_t < y, f_t >= y 192 | lb[too_low] = t[too_low] 193 | ub[too_high] = t[too_high] 194 | else: 195 | too_low, too_high = f_t > y, f_t <= y 196 | lb[too_low] = t[too_low] 197 | ub[too_high] = t[too_high] 198 | 199 | assert torch.all(ub - lb > 0.), "lb: %s, ub: %s" % (lb, ub) 200 | 201 | assert torch.all(ub - lb <= tol) 202 | return t 203 | 204 | 205 | def bisection_default_increasing(phi, y): 206 | ''' 207 | Wrapper for performing bisection method when f is increasing. 208 | ''' 209 | return bisection_root(phi, y, increasing=True) 210 | 211 | 212 | def bisection_default_decreasing(phi, y): 213 | ''' 214 | Wrapper for performing bisection method when f is decreasing. 215 | ''' 216 | return bisection_root(phi, y, increasing=False) 217 | 218 | 219 | class MixExpPhi(nn.Module): 220 | ''' 221 | Sample net for phi involving the sum of 2 negative exponentials. 222 | phi(t) = m1 * exp(-w1 * t) + m2 * exp(-w2 * t) 223 | 224 | Network Parameters 225 | ================== 226 | mix: Tensor of size 2 such that such that (m1, m2) = softmax(mix) 227 | slope: Tensor of size 2 such that exp(m1) = w1, exp(m2) = w2 228 | 229 | Note that this implies 230 | i) m1, m2 > 0 and m1 + m2 = 1.0 231 | ii) w1, w2 > 0 232 | ''' 233 | 234 | def __init__(self, init_w=None): 235 | import numpy as np 236 | super(MixExpPhi, self).__init__() 237 | 238 | if init_w is None: 239 | self.mix = nn.Parameter(torch.tensor( 240 | [np.log(0.2), np.log(0.8)], requires_grad=True)) 241 | self.slope = nn.Parameter( 242 | torch.log(torch.tensor([1e1, 1e6], requires_grad=True))) 243 | else: 244 | assert len(init_w) == 2 245 | assert init_w[0].numel() == init_w[1].numel() 246 | self.mix = nn.Parameter(init_w[0]) 247 | self.slope = nn.Parameter(init_w[1]) 248 | 249 | def forward(self, t): 250 | s = t.size() 251 | t_ = t.flatten() 252 | nquery, nmix = t.numel(), self.mix.numel() 253 | 254 | mix_ = torch.nn.functional.softmax(self.mix) 255 | exps = torch.exp(-t_[:, None].expand(nquery, nmix) * 256 | torch.exp(self.slope)[None, :].expand(nquery, nmix)) 257 | 258 | ret = torch.sum(mix_ * exps, dim=1) 259 | return ret.reshape(s) 260 | 261 | 262 | class MixExpPhi2FixedSlope(nn.Module): 263 | def __init__(self, init_w=None): 264 | super(MixExpPhi2FixedSlope, self).__init__() 265 | 266 | self.mix = nn.Parameter(torch.tensor( 267 | [np.log(0.25)], requires_grad=True)) 268 | self.slope = torch.tensor([1e1, 1e6], requires_grad=True) 269 | 270 | def forward(self, t): 271 | z = 1./(1+torch.exp(-self.mix[0])) 272 | return z * torch.exp(-t * self.slope[0]) + (1-z) * torch.exp(-t * self.slope[1]) 273 | 274 | 275 | class Copula(nn.Module): 276 | def __init__(self, phi): 277 | super(Copula, self).__init__() 278 | self.phi = phi 279 | self.phi_inv = PhiInv(phi) 280 | 281 | def forward(self, y, mode='cdf', others=None, tol=1e-10): 282 | if not y.requires_grad: 283 | y = y.requires_grad_(True) 284 | ndims = y.size()[1] 285 | inverses = self.phi_inv(y, tol=tol) 286 | cdf = self.phi(inverses.sum(dim=1)) 287 | 288 | if mode == 'cdf': 289 | return cdf 290 | 291 | if mode == 'pdf': 292 | cur = cdf 293 | for dim in range(ndims): 294 | # TODO: Only take gradients with respect to one dimension of y at at time 295 | cur = torch.autograd.grad( 296 | cur.sum(), y, create_graph=True)[0][:, dim] 297 | return cur 298 | elif mode == 'cond_cdf': 299 | target_dims = others['cond_dims'] 300 | 301 | # Numerator 302 | cur = cdf 303 | for dim in target_dims: 304 | # TODO: Only take gradients with respect to one dimension of y at a time 305 | cur = torch.autograd.grad( 306 | cur.sum(), y, create_graph=True, retain_graph=True)[0][:, dim] 307 | numerator = cur 308 | 309 | # Denominator 310 | trunc_cdf = self.phi(inverses[:, target_dims]) 311 | cur = trunc_cdf 312 | for dim in range(len(target_dims)): 313 | cur = torch.autograd.grad( 314 | cur.sum(), y, create_graph=True)[0][:, dim] 315 | 316 | denominator = cur 317 | return numerator/denominator 318 | 319 | 320 | def sample(net, ndims, N, seed=142857): 321 | """ 322 | Note: this does *not* use the efficient method described in the paper. 323 | Instead, we will use the naive method, i.e., conditioning on each 324 | variable in turn and then applying the inverse CDF method on the resultant conditional 325 | CDF. 326 | 327 | This method will work on all generators (even those defined by ACNet), and is 328 | the simplest method assuming no knowledge of the mixing variable M is known. 329 | """ 330 | # Store old seed and set new seed 331 | old_rng_state = torch.random.get_rng_state() 332 | torch.manual_seed(seed) 333 | 334 | U = torch.rand(N, ndims) 335 | 336 | for dim in range(1, ndims): 337 | print('Sampling from dim: %s' % dim) 338 | y = U[:, dim].detach().clone() 339 | 340 | def cond_cdf_func(u): 341 | U_ = U.clone().detach() 342 | U_[:, dim] = u 343 | cond_cdf = net(U_[:, :(dim+1)], "cond_cdf", 344 | others={'cond_dims': list(range(dim))}) 345 | return cond_cdf 346 | 347 | # Call inverse using the conditional cdf `M` as the function. 348 | # Note that the weight parameter is set to None since `M` is not parameterized, 349 | # i.e., hardcoded as the conditional cdf itself. 350 | U[:, dim] = bisection_default_increasing(cond_cdf_func, y).detach() 351 | 352 | # Revert to old random state. 353 | torch.random.set_rng_state(old_rng_state) 354 | 355 | return U 356 | 357 | #################################################################################### 358 | # Tests 359 | #################################################################################### 360 | 361 | 362 | def test_grad_of_phi(): 363 | phi_net = MixExpPhi() 364 | phi_inv = PhiInv(phi_net) 365 | query = torch.tensor( 366 | [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [1., 1., 1.]]).requires_grad_(True) 367 | 368 | gradcheck(phi_net, (query), eps=1e-9) 369 | gradgradcheck(phi_net, (query,), eps=1e-9) 370 | 371 | 372 | def test_grad_y_of_inverse(): 373 | phi_net = MixExpPhi() 374 | phi_inv = PhiInv(phi_net) 375 | query = torch.tensor( 376 | [[0.1, 0.2], [0.2, 0.3], [0.25, 0.7]]).requires_grad_(True) 377 | 378 | gradcheck(phi_inv, (query, ), eps=1e-10) 379 | gradgradcheck(phi_inv, (query, ), eps=1e-10) 380 | 381 | 382 | def test_grad_w_of_inverse(): 383 | phi_net = MixExpPhi2FixedSlope() 384 | phi_inv = PhiInv(phi_net) 385 | 386 | eps = 1e-8 387 | new_phi_inv = copy.deepcopy(phi_inv) 388 | 389 | # Jitter weights in new_phi. 390 | new_phi_inv.phi.mix.data = phi_inv.phi.mix.data + eps 391 | 392 | query = torch.tensor( 393 | [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.99, 0.99, 0.99]]).requires_grad_(True) 394 | old_value = phi_inv(query).sum() 395 | old_value.backward() 396 | anal_grad = phi_inv.phi.mix.grad 397 | new_value = new_phi_inv(query).sum() 398 | num_grad = (new_value-old_value)/eps 399 | 400 | print('gradient of weights (anal)', anal_grad) 401 | print('gradient of weights (num)', num_grad) 402 | 403 | 404 | def test_grad_y_of_pdf(): 405 | phi_net = MixExpPhi() 406 | query = torch.tensor( 407 | [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.99, 0.99, 0.99]]).requires_grad_(True) 408 | cop = Copula(phi_net) 409 | def f(y): return cop(y, mode='pdf') 410 | gradcheck(f, (query, ), eps=1e-8) 411 | # This fails sometimes if rtol is too low..? 412 | gradgradcheck(f, (query, ), eps=1e-8, atol=1e-6, rtol=1e-2) 413 | 414 | 415 | def plot_pdf_and_cdf_over_grid(): 416 | phi_net = MixExpPhi() 417 | cop = Copula(phi_net) 418 | 419 | n = 500 420 | x1 = np.linspace(0.001, 1, n) 421 | x2 = np.linspace(0.001, 1, n) 422 | xv1, xv2 = np.meshgrid(x1, x2) 423 | xv1_tensor = torch.tensor(xv1.flatten()) 424 | xv2_tensor = torch.tensor(xv2.flatten()) 425 | query = torch.stack((xv1_tensor, xv2_tensor) 426 | ).double().t().requires_grad_(True) 427 | cdf = cop(query, mode='cdf') 428 | pdf = cop(query, mode='pdf') 429 | 430 | assert abs(pdf.mean().detach().numpy().sum() - 431 | 1) < 1e-6, 'Mean of pdf over grid should be 1' 432 | assert abs(cdf[-1].detach().numpy().sum() - 433 | 1) < 1e-6, 'CDF at (1..1) should be should be 1' 434 | 435 | 436 | def plot_cond_cdf(): 437 | phi_net = MixExpPhi() 438 | cop = Copula(phi_net) 439 | 440 | n = 500 441 | xv2 = np.linspace(0.001, 1, n) 442 | xv2_tensor = torch.tensor(xv2.flatten()) 443 | xv1_tensor = 0.9 * torch.ones_like(xv2_tensor) 444 | x = torch.stack([xv1_tensor, xv2_tensor], dim=1).requires_grad_(True) 445 | cond_cdf = cop(x, mode="cond_cdf", others={'cond_dims': [0]}) 446 | 447 | plt.figure() 448 | plt.plot(cond_cdf.detach().numpy()) 449 | plt.title('Conditional CDF') 450 | plt.draw() 451 | plt.pause(0.01) 452 | 453 | 454 | def plot_samples(): 455 | phi_net = MixExpPhi() 456 | cop = Copula(phi_net) 457 | 458 | s = sample(cop, 2, 2000, seed=142857) 459 | s_np = s.detach().numpy() 460 | 461 | plt.figure() 462 | plt.scatter(s_np[:, 0], s_np[:, 1]) 463 | plt.title('Sampled points from Copula') 464 | plt.draw() 465 | plt.pause(0.01) 466 | 467 | 468 | def plot_loss_surface(): 469 | phi_net = MixExpPhi2FixedSlope() 470 | cop = Copula(phi_net) 471 | 472 | s = sample(cop, 2, 2000, seed=142857) 473 | s_np = s.detach().numpy() 474 | 475 | l = [] 476 | x = np.linspace(-1e-2, 1e-2, 1000) 477 | for SS in x: 478 | new_cop = copy.deepcopy(cop) 479 | new_cop.phi.mix.data = cop.phi.mix.data + SS 480 | 481 | loss = -torch.log(new_cop(s, mode='pdf')).sum() 482 | l.append(loss.detach().numpy().sum()) 483 | 484 | plt.figure() 485 | plt.plot(x, l) 486 | plt.title('Loss surface') 487 | plt.draw() 488 | plt.pause(0.01) 489 | 490 | 491 | def test_training(test_grad_w=False): 492 | gen_phi_net = MixExpPhi() 493 | gen_phi_inv = PhiInv(gen_phi_net) 494 | gen_cop = Copula(gen_phi_net) 495 | 496 | s = sample(gen_cop, 2, 2000, seed=142857) 497 | s_np = s.detach().numpy() 498 | 499 | ideal_loss = -torch.log(gen_cop(s, mode='pdf')).sum() 500 | 501 | train_cop = copy.deepcopy(gen_cop) 502 | train_cop.phi.mix.data *= 1.5 503 | train_cop.phi.slope.data *= 1.5 504 | print('Initial loss', ideal_loss) 505 | optimizer = optim.Adam(train_cop.parameters(), lr=1e-3) 506 | 507 | def numerical_grad(cop): 508 | # Take gradients w.r.t to the first mixing parameter 509 | print('Analytic gradients:', cop.phi.mix.grad[0]) 510 | 511 | old_cop, new_cop = copy.deepcopy(cop), copy.deepcopy(cop) 512 | # First order approximation of gradient of weights 513 | eps = 1e-6 514 | new_cop.phi.mix.data[0] = cop.phi.mix.data[0] + eps 515 | x2 = -torch.log(new_cop(s, mode='pdf')).sum() 516 | x1 = -torch.log(cop(s, mode='pdf')).sum() 517 | 518 | first_order_approximate = (x2-x1)/eps 519 | print('First order approx.:', first_order_approximate) 520 | 521 | for iter in range(100000): 522 | optimizer.zero_grad() 523 | loss = -torch.log(train_cop(s, mode='pdf')).sum() 524 | loss.backward() 525 | print('iter', iter, ':', loss, 'ideal loss:', ideal_loss) 526 | if test_grad_w: 527 | numerical_grad(train_cop) 528 | optimizer.step() 529 | 530 | 531 | if __name__ == '__main__': 532 | import torch.optim as optim 533 | from torch.autograd import gradgradcheck, gradcheck 534 | import numpy as np 535 | import logging as log 536 | import matplotlib.pyplot as plt 537 | import copy 538 | 539 | torch.set_default_tensor_type(torch.DoubleTensor) 540 | 541 | test_grad_of_phi() 542 | test_grad_y_of_inverse() 543 | test_grad_w_of_inverse() 544 | test_grad_y_of_pdf() 545 | 546 | plot_pdf_and_cdf_over_grid() 547 | plot_cond_cdf() 548 | plot_samples() 549 | """ Uncomment for rudimentary training. 550 | Note: very slow and unrealistic. 551 | plot_loss_surface() 552 | test_training() 553 | """ 554 | -------------------------------------------------------------------------------- /phi_listing.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Contains the standard copula. 3 | For some copula, we use more efficient methods based on 4 | Chapter 2 of 5 | Matthias, Scherer, and Mai Jan-frederik. Simulating copulas: stochastic models, sampling algorithms, and applications. Vol. 4. World Scientific, 2012. 6 | ''' 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.parameter import Parameter 11 | import logging 12 | 13 | 14 | class PhiListing(nn.Module): 15 | def __init__(self): 16 | super(PhiListing, self).__init__() 17 | 18 | def inverse(self, t): 19 | ''' 20 | Return a tensor of nan's by default if no manual inverse is defined. 21 | ''' 22 | return torch.zeros_like(t) / torch.zeros_like(t) 23 | 24 | def sample(self, ndims, n): 25 | shape = (n, ndims) 26 | ms = self.sample_M(n)[:, None].expand(-1, ndims) 27 | e = torch.distributions.exponential.Exponential(torch.ones(shape)) 28 | E = e.sample() 29 | return self.forward(E/ms) 30 | 31 | 32 | class FrankPhi(nn.Module): 33 | def __init__(self, theta): 34 | super(FrankPhi, self).__init__() 35 | 36 | self.theta = nn.Parameter(theta) 37 | 38 | def forward(self, t): 39 | theta = self.theta 40 | ret = -1/theta * torch.log(torch.exp(-t)*(torch.exp(-theta)-1)+1) 41 | return ret 42 | 43 | def pdf(self, X): 44 | return None 45 | 46 | def cdf(self, X): 47 | return -1./self.theta * \ 48 | torch.log( 49 | 1 + (torch.exp(-self.theta * X[:, 0]) - 1) * ( 50 | torch.exp(-self.theta * X[:, 1]) - 1) / (torch.exp(-self.theta) - 1) 51 | ) 52 | 53 | 54 | class ClaytonPhi(PhiListing): 55 | def __init__(self, theta): 56 | super(ClaytonPhi, self).__init__() 57 | 58 | self.theta = nn.Parameter(theta) 59 | 60 | def forward(self, t): 61 | theta = self.theta 62 | ret = (1+t)**(-1/theta) 63 | return ret 64 | 65 | def inverse(self, t): 66 | ret = torch.zeros_like(t) / torch.zeros_like(t) 67 | ret[torch.abs(t - 1.) < self.eps_snap_zero] = 1.0 68 | 69 | return ret 70 | 71 | def sample_M(self, n): 72 | alpha = 1./self.theta 73 | 74 | m = torch.distributions.gamma.Gamma(1./self.theta, 1.0) 75 | 76 | return m.sample((n,)) 77 | 78 | def pdf(self, X): 79 | """ 80 | [From Wolfram] 81 | d/dx((d(x^(-z) + y^(-z) - 1)^(-1/z))/(dy)) = (-1/z - 1) z (-x^(-z - 1)) y^(-z - 1) (x^(-z) + y^(-z) - 1)^(-1/z - 2) 82 | """ 83 | assert X.shape[1] == 2 84 | 85 | Z = X[:, 0]**(-self.theta) + X[:, 1]**(-self.theta) - 1. 86 | ret = torch.zeros_like(Z) 87 | ret[Z > 0] = (-1/self.theta-1.) * self.theta * -X[Z > 0, 0] ** (-self.theta-1) * X[Z > 0, 1] ** ( 88 | -self.theta-1) * (X[Z > 0, 0] ** (-self.theta) + X[Z > 0, 1] ** (-self.theta) - 1) ** (-1./self.theta-2) 89 | 90 | return ret 91 | 92 | def cdf(self, X): 93 | assert X.shape[1] == 2 94 | 95 | return (torch.max(X[:, 0]**(-self.theta) + X[:, 1] 96 | ** (-self.theta)-1, torch.zeros(X.shape[0])))**(-1./self.theta) 97 | 98 | 99 | class JoePhi(PhiListing): 100 | """ 101 | The Joe Generator has a derivative that goes to infinity at t = 0. Hence we need 102 | to be careful when t is close to 0! 103 | """ 104 | 105 | def __init__(self, theta): 106 | super(JoePhi, self).__init__() 107 | 108 | self.eps = 0 109 | self.eps_snap_zero = 1e-15 110 | self.theta = nn.Parameter(theta) 111 | 112 | def forward(self, t): 113 | eps = self.eps 114 | if torch.any(t < eps): 115 | """ 116 | logging.warning('''some entry in t is too small, < %s. May encounter numerical errors if taking gradients. 117 | Smallest t= % s. Will be adding eps= % s to inputs for stability.''' % (eps, torch.min(t), eps)) 118 | """ 119 | t_ = t + eps 120 | else: 121 | t_ = t + eps 122 | theta = self.theta 123 | ret = 1-(1-torch.exp(-t_))**(1/theta) + 1e-7 124 | return ret 125 | 126 | def inverse(self, t): 127 | ret = torch.zeros_like(t) / torch.zeros_like(t) 128 | ret[torch.abs(t - 1.) < self.eps_snap_zero] = 1.0 129 | 130 | return ret 131 | 132 | def sample_M(self, n): 133 | alpha = 1./self.theta 134 | U = torch.rand(n) 135 | 136 | ret = torch.ones_like(U) 137 | 138 | ginv_u = self.Ginv(U) 139 | cond = self.F(torch.floor(ginv_u)) 140 | 141 | cut_indices = U <= alpha 142 | z = cond < U 143 | j = cond >= U 144 | 145 | ret[z] = torch.ceil(ginv_u[z]) 146 | ret[j] = torch.floor(ginv_u[j]) 147 | ret[cut_indices] = 1. 148 | 149 | return ret 150 | 151 | def Ginv(self, y): 152 | alpha = 1/self.theta 153 | 154 | return torch.exp(-self.theta * (torch.log(1.-y) + torch.lgamma(1.-alpha))) 155 | 156 | def gamma(self, x): 157 | return torch.exp(torch.lgamma(x)) 158 | 159 | def lbeta(self, x, y): 160 | return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x+y) 161 | 162 | def F(self, n): 163 | alpha = 1/self.theta 164 | return 1. - 1. / (n * torch.exp(self.lbeta(n, 1.-alpha))) 165 | 166 | def pdf(self, X): 167 | assert X.shape[1] == 2 168 | 169 | X_ = -X+1.0 170 | X_1 = X_[:, 0] 171 | X_2 = X_[:, 1] 172 | 173 | bleh = -X_1 ** (self.theta-1) * X_2 ** (self.theta-1) * \ 174 | ((X_1**self.theta) - (X_1**self.theta - 1) * X_2**self.theta)**(1./self.theta-2) * \ 175 | ((X_1**self.theta-1) * (X_2**self.theta-1) - self.theta) 176 | 177 | return bleh 178 | 179 | def cdf(self, X): 180 | assert X.shape[1] == 2 181 | 182 | X_ = -X+1.0 183 | X_1 = X_[:, 0] 184 | X_2 = X_[:, 1] 185 | 186 | return 1.0 - (X_1**self.theta + X_2**self.theta - (X_1**self.theta)*(X_2**self.theta))**(1./self.theta) 187 | 188 | 189 | class GumbelPhi(PhiListing): 190 | def __init__(self, theta): 191 | super(GumbelPhi, self).__init__() 192 | 193 | self.theta = nn.Parameter(theta) 194 | 195 | def forward(self, t): 196 | offsetx = 1e-15 197 | offsety = 1e-15 198 | theta = self.theta 199 | ret = torch.exp(-((t+offsetx) ** (1/theta))) + offsety 200 | return ret 201 | 202 | def pdf(self, X): 203 | assert X.shape[1] == 2 204 | 205 | u_ = (-torch.log(X[:, 0]))**(self.theta) 206 | v_ = (-torch.log(X[:, 1]))**(self.theta) 207 | 208 | return torch.exp(-(u_+v_)) ** (1/self.theta) 209 | 210 | 211 | class IGPhi(nn.Module): 212 | def __init__(self, theta): 213 | super(IGPhi, self).__init__() 214 | 215 | self.theta = nn.Parameter(theta) 216 | 217 | def forward(self, t): 218 | theta = self.theta 219 | ret = torch.exp((1-torch.sqrt(1+2*theta**2*t))/theta) 220 | return ret 221 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Helper files for training. 3 | ''' 4 | 5 | from torch.autograd import Function, gradcheck 6 | from torch.utils.data import DataLoader, Dataset 7 | import matplotlib.pyplot as plt 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import numpy as np 12 | import pickle 13 | import os 14 | from main import sample 15 | 16 | 17 | def load_data(path, num_train, num_test): 18 | ''' 19 | Loads dataset from `path` split into Pytorch train and test of 20 | given sizes. Train set is taken from the front while 21 | test set is taken from behind. 22 | 23 | :param path: path to .p file containing data. 24 | ''' 25 | f = open(path, 'rb') 26 | all_data = pickle.load(f)['samples'] 27 | 28 | ndata_all = all_data.size()[0] 29 | assert num_train+num_test <= ndata_all 30 | 31 | train_data = all_data[:num_train] 32 | test_data = all_data[(ndata_all-num_test):] 33 | 34 | return train_data, test_data 35 | 36 | 37 | def load_log_ll(path, num_train, num_test): 38 | f = open(path, 'rb') 39 | all_log_ll = pickle.load(f)['log_ll'] 40 | 41 | ndata_all = all_log_ll.numel() 42 | assert num_train+num_test <= ndata_all 43 | 44 | train_log_ll = all_log_ll[:num_train] 45 | test_log_ll = all_log_ll[(ndata_all-num_test):] 46 | 47 | return train_log_ll, test_log_ll 48 | 49 | 50 | def get_optim(name, net, args): 51 | if name == 'SGD': 52 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 53 | elif name == 'Adam': 54 | # TODO: add in more. Note: we do not use this in the paper. 55 | optimizer = optim.Adam(net.parameters(), args['lr']) 56 | elif name == 'RMSprop': 57 | # TODO: add in more. Note: we do not use this in the paper. 58 | optimizer = optim.RMSprop(net.parameters(), args['lr']) 59 | 60 | return optimizer 61 | 62 | 63 | def expt(train_data, val_data, 64 | net, 65 | optim_name, 66 | optim_args, 67 | identifier, 68 | num_epochs=1000, 69 | batch_size=100, 70 | chkpt_freq=50, 71 | ): 72 | 73 | os.mkdir('./checkpoints/%s' % identifier) 74 | os.mkdir('./sample_figs/%s' % identifier) 75 | 76 | train_loader = DataLoader( 77 | train_data, batch_size=batch_size, shuffle=True) 78 | 79 | # IMPORTANT: for this experiment, we did *not* perform hyperparameter tuning. 80 | # Hence, the `validation loss' here is essentially `test` loss. 81 | val_loader = DataLoader( 82 | val_data, batch_size=1000000, shuffle=True) 83 | 84 | optimizer = get_optim(optim_name, net, optim_args) 85 | 86 | train_loss_per_epoch = [] 87 | 88 | for epoch in range(num_epochs): 89 | loss_per_minibatch = [] 90 | for i, data in enumerate(train_loader, 0): 91 | optimizer.zero_grad() 92 | 93 | d = torch.tensor(data, requires_grad=True) 94 | p = net(d, mode='pdf') 95 | 96 | logloss = -torch.sum(torch.log(p)) 97 | reg_loss = logloss 98 | reg_loss.backward() 99 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 100 | 101 | loss_per_minibatch.append(scalar_loss) 102 | optimizer.step() 103 | 104 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 105 | print('Training loss at epoch %s: %s' % 106 | (epoch, train_loss_per_epoch[-1])) 107 | 108 | if epoch % chkpt_freq == 0: 109 | print('Checkpointing') 110 | torch.save({ 111 | 'epoch': epoch, 112 | 'model_state_dict': net.state_dict(), 113 | 'optimizer_state_dict': optimizer.state_dict(), 114 | 'loss': logloss, 115 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 116 | 117 | """ 118 | if args.dims == 2: 119 | print('Scatter sampling') 120 | samples = sample(net, 2, 1000) 121 | plt.scatter(samples[:, 0], samples[:, 1]) 122 | plt.savefig('./sample_figs/%s/epoch%s.png' % 123 | (identifier, epoch)) 124 | plt.clf() 125 | else: 126 | print('Not doign scatter plot, dims > 2') 127 | """ 128 | 129 | print('Evaluating validation/test loss.') 130 | for j, val_data in enumerate(val_loader, 0): 131 | net.zero_grad() 132 | val_p = net(val_data, mode='pdf') 133 | val_loss = -torch.mean(torch.log(val_p)) 134 | print('Average validation/test loss %s' % val_loss) 135 | 136 | 137 | def make_ranged_data(data, width, seed): 138 | ''' 139 | Transforms each coordinate (x, y) to 140 | the range ([x-e1, x+e2], [y-e3, y+e4]) 141 | where e1, e2, e3, e4 are drawn uniformly from [0, width]. 142 | The final ranges are snapped to [0, 1]. 143 | ''' 144 | 145 | assert data.shape[1] == 2 146 | 147 | np.random.seed(seed) 148 | 149 | epsilon_lower = np.random.random_sample((data.shape)) * width 150 | epsilon_upper = np.random.random_sample((data.shape)) * width 151 | 152 | epsilon_lower = torch.from_numpy(epsilon_lower) 153 | epsilon_upper = torch.from_numpy(epsilon_upper) 154 | 155 | bounds_lower = torch.max(torch.zeros_like(data), data - epsilon_lower) 156 | bounds_upper = torch.min(torch.ones_like(data), data + epsilon_upper) 157 | 158 | return bounds_lower, bounds_upper 159 | 160 | 161 | def expt_cdf_noisy(train_data, val_data, 162 | net, 163 | optim_name, 164 | optim_args, 165 | identifier, 166 | width, 167 | seed, 168 | num_epochs=1000, 169 | batch_size=100, 170 | chkpt_freq=50, 171 | ): 172 | ''' 173 | Add in uncertainty in all points 174 | ''' 175 | 176 | os.mkdir('./checkpoints/%s' % identifier) 177 | os.mkdir('./sample_figs/%s' % identifier) 178 | 179 | train_bounds_lower, train_bounds_upper = make_ranged_data( 180 | train_data, width, seed) 181 | val_bounds_lower, val_bounds_upper = make_ranged_data( 182 | val_data, width, seed) 183 | 184 | train_bounds = torch.cat( 185 | [train_bounds_lower, train_bounds_upper], dim=1) 186 | val_bounds = torch.cat( 187 | [val_bounds_lower, val_bounds_upper], dim=1) 188 | 189 | train_loader = DataLoader( 190 | train_bounds, batch_size=batch_size, shuffle=True) 191 | val_loader = DataLoader( 192 | val_data, batch_size=1000000, shuffle=True) 193 | 194 | optimizer = get_optim(optim_name, net, optim_args) 195 | 196 | train_loss_per_epoch = [] 197 | 198 | for epoch in range(num_epochs): 199 | loss_per_minibatch = [] 200 | for i, data in enumerate(train_loader, 0): 201 | optimizer.zero_grad() 202 | 203 | d = torch.tensor(data, requires_grad=True) 204 | dsize = d.shape[0] 205 | 206 | big = data[:, 2:] 207 | small = data[:, 0:2] 208 | cross1 = torch.cat( 209 | [data[:, 0:1], data[:, 3:4]], dim=1) 210 | cross2 = torch.cat( 211 | [data[:, 2:3], data[:, 1:2]], dim=1) 212 | 213 | joint = torch.cat([big, small, cross1, cross2], dim=0) 214 | P_raw = net(torch.tensor(joint, requires_grad=True), mode='cdf') 215 | P_big = P_raw[:dsize] 216 | P_small = P_raw[dsize:(2*dsize)] 217 | P_cross1 = P_raw[(2*dsize):(3*dsize)] 218 | P_cross2 = P_raw[(3*dsize):(4*dsize)] 219 | P = P_big + P_small - P_cross1 - P_cross2 220 | 221 | logloss = -torch.sum(torch.log(P)) 222 | reg_loss = logloss 223 | reg_loss.backward() 224 | scalar_loss = (reg_loss/P.numel()).detach().numpy().item() 225 | 226 | loss_per_minibatch.append(scalar_loss) 227 | optimizer.step() 228 | 229 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 230 | print('Training loss at epoch %s: %s' % 231 | (epoch, train_loss_per_epoch[-1])) 232 | 233 | if epoch % chkpt_freq == 0: 234 | print('Checkpointing') 235 | torch.save({ 236 | 'epoch': epoch, 237 | 'model_state_dict': net.state_dict(), 238 | 'optimizer_state_dict': optimizer.state_dict(), 239 | 'loss': logloss, 240 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 241 | 242 | """ 243 | if args.dims == 2: 244 | print('Scatter sampling') 245 | samples = sample(net, 2, 1000) 246 | plt.scatter(samples[:, 0], samples[:, 1]) 247 | plt.savefig('./sample_figs/%s/epoch%s.png' % 248 | (identifier, epoch)) 249 | plt.clf() 250 | else: 251 | print('Not doign scatter plot, dims > 2') 252 | """ 253 | 254 | print('Evaluating validation loss') 255 | for j, val_data in enumerate(val_loader, 0): 256 | net.zero_grad() 257 | val_p = net(val_data, mode='pdf') 258 | val_loss = -torch.mean(torch.log(val_p)) 259 | print('Average validation loss %s' % val_loss) 260 | -------------------------------------------------------------------------------- /train_scripts/boston/train.py: -------------------------------------------------------------------------------- 1 | from main import Copula 2 | from torch.autograd import Function, gradcheck 3 | from torch.utils.data import DataLoader, Dataset 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import numpy as np 9 | import pickle 10 | import os 11 | from main import sample 12 | from dirac_phi import DiracPhi 13 | 14 | from sklearn.datasets import load_boston 15 | from sklearn.model_selection import train_test_split 16 | import scipy 17 | 18 | from sacred import Experiment 19 | 20 | identifier = 'boston_housing' 21 | ex = Experiment('boston_housing') 22 | 23 | torch.set_default_tensor_type(torch.DoubleTensor) 24 | 25 | 26 | def add_train_random_noise(data, num_adds): 27 | new_data = np.random.rand(num_adds, data.shape[1]) 28 | return np.concatenate((data, new_data), axis=0) 29 | 30 | 31 | X, y = load_boston(return_X_y=True) 32 | X_train, X_test, y_train, y_test = train_test_split( 33 | X, y, shuffle=True, random_state=142857) 34 | X_train = np.concatenate((X_train, y_train[:, None]), axis=1) 35 | X_test = np.concatenate((X_test, y_test[:, None]), axis=1) 36 | 37 | nfeats = X_test.shape[1] 38 | 39 | # Normalize data based on ordinal rankings. 40 | for z in [X_train, X_test]: 41 | ndata = z.shape[0] 42 | gap = 1./(ndata+1) 43 | for i in range(nfeats): 44 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 45 | 46 | # Potentially inject noise into data: comment if you do not want noise. 47 | X_train = add_train_random_noise(X_train, int(X_train.shape[0]*0.01)) 48 | 49 | """ 50 | FEATURE DESCRIPTIONS. We are interested in features 0 and -1 (price) 51 | 52 | 0. crim 53 | per capita crime rate by town. 54 | 55 | 1. zn 56 | proportion of residential land zoned for lots over 25,000 sq.ft. 57 | 58 | 2. indus 59 | proportion of non-retail business acres per town. 60 | 61 | 3. chas 62 | Charles River dummy variable (= 1 if tract bounds river; 0 otherwise). 63 | 64 | 4. nox 65 | nitrogen oxides concentration (parts per 10 million). 66 | 67 | 5. rm 68 | average number of rooms per dwelling. 69 | 70 | 6. age 71 | proportion of owner-occupied units built prior to 1940. 72 | 73 | 7. dis 74 | weighted mean of distances to five Boston employment centres. 75 | 76 | 8. rad 77 | index of accessibility to radial highways. 78 | 79 | 9. tax 80 | full-value property-tax rate per \$10,000. 81 | 82 | 10. ptratio 83 | pupil-teacher ratio by town. 84 | 85 | 11. black 86 | 1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town. 87 | 88 | 12. lstat 89 | lower status of the population (percent). 90 | 91 | 13 [y] . medv 92 | median value of owner-occupied homes in \$1000s. 93 | 94 | """ 95 | 96 | 97 | def get_optim(name, net, args): 98 | if name == 'SGD': 99 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 100 | elif name == 'Adam': 101 | # TODO: add in more. 102 | optimizer = optim.Adam(net.parameters(), args['lr']) 103 | elif name == 'RMSprop': 104 | # TODO: add in more. 105 | optimizer = optim.RMSprop(net.parameters(), args['lr']) 106 | 107 | return optimizer 108 | 109 | 110 | @ex.config 111 | def cfg(): 112 | x_index = 0 113 | y_index = 13 114 | 115 | # Flip around the line y = 0.5 to make negative correlations positive. 116 | x_flip, y_flip = True, False 117 | 118 | optim_name = 'SGD' 119 | optim_args = \ 120 | { 121 | 'lr': 1e-5, 122 | 'momentum': 0.9 123 | } 124 | num_epochs = 10000000 125 | batch_size = 200 126 | chkpt_freq = 500 127 | 128 | Phi = DiracPhi 129 | phi_name = 'DiracPhi' 130 | 131 | # Initial parameters. 132 | depth = 2 133 | widths = [10, 10] 134 | lc_w_range = (0, 1.0) 135 | shift_w_range = (0., 2.0) 136 | 137 | 138 | @ex.capture 139 | def get_info(_run): 140 | return _run._id 141 | 142 | 143 | def expt(train_data, val_data, 144 | net, 145 | optim_name, 146 | optim_args, 147 | identifier, 148 | num_epochs=1000, 149 | batch_size=100, 150 | chkpt_freq=50, 151 | ): 152 | 153 | os.mkdir('./checkpoints/%s' % identifier) 154 | os.mkdir('./sample_figs/%s' % identifier) 155 | 156 | train_loader = DataLoader( 157 | train_data, batch_size=batch_size, shuffle=True) 158 | val_loader = DataLoader( 159 | val_data, batch_size=1000000, shuffle=True) 160 | 161 | # IMPORTANT: for this experiment, we did *not* perform hyperparameter tuning. 162 | # Hence, the `validation loss' here is essentially `test` loss. 163 | optimizer = get_optim(optim_name, net, optim_args) 164 | 165 | train_loss_per_epoch = [] 166 | 167 | for epoch in range(num_epochs): 168 | loss_per_minibatch = [] 169 | for i, data in enumerate(train_loader, 0): 170 | optimizer.zero_grad() 171 | 172 | d = torch.tensor(data, requires_grad=True) 173 | p = net(d, mode='pdf') 174 | 175 | logloss = -torch.sum(torch.log(p)) 176 | reg_loss = logloss 177 | reg_loss.backward() 178 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 179 | 180 | loss_per_minibatch.append(scalar_loss) 181 | optimizer.step() 182 | 183 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 184 | print('Training loss at epoch %s: %s' % 185 | (epoch, train_loss_per_epoch[-1])) 186 | 187 | if epoch % chkpt_freq == 0: 188 | print('Checkpointing') 189 | torch.save({ 190 | 'epoch': epoch, 191 | 'model_state_dict': net.state_dict(), 192 | 'optimizer_state_dict': optimizer.state_dict(), 193 | 'loss': logloss, 194 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 195 | 196 | print('Evaluating validation loss') 197 | for j, val_data in enumerate(val_loader, 0): 198 | net.zero_grad() 199 | val_p = net(val_data, mode='pdf') 200 | val_loss = -torch.mean(torch.log(val_p)) 201 | print('Average validation loss %s' % val_loss) 202 | 203 | 204 | @ex.automain 205 | def run(x_index, y_index, 206 | x_flip, y_flip, 207 | Phi, 208 | depth, widths, lc_w_range, shift_w_range, 209 | optim_name, optim_args, 210 | num_epochs, batch_size, chkpt_freq): 211 | id = get_info() 212 | identifier_id = '%s%s' % (identifier, id) 213 | 214 | train_data = X_train[:, [x_index, y_index]] 215 | test_data = X_test[:, [x_index, y_index]] 216 | 217 | if x_flip: 218 | train_data[:, 0] = 1-train_data[:, 0] 219 | test_data[:, 0] = 1-test_data[:, 0] 220 | 221 | if y_flip: 222 | train_data[:, 1] = 1-train_data[:, 1] 223 | test_data[:, 1] = 1-test_data[:, 1] 224 | 225 | phi = Phi(depth, widths, lc_w_range, shift_w_range) 226 | net = Copula(phi) 227 | expt(train_data, test_data, net, optim_name, 228 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 229 | 230 | 231 | if __name__ == '__main__': 232 | print('Sample usage: python -m train_scripts.boston.train -F boston_housing') 233 | -------------------------------------------------------------------------------- /train_scripts/clayton.py: -------------------------------------------------------------------------------- 1 | from main import Copula 2 | from dirac_phi import DiracPhi 3 | import pickle 4 | from sacred import Experiment 5 | from train import load_data, load_log_ll, expt 6 | import torch 7 | from sacred.observers import FileStorageObserver 8 | 9 | identifier = 'learn_clayton' 10 | ex = Experiment('LOG_learn_clayton') 11 | 12 | torch.set_default_tensor_type(torch.DoubleTensor) 13 | 14 | 15 | @ex.config 16 | def cfg(): 17 | data_name = './data/clayton1.p' 18 | num_train, num_test = 2000, 1000 19 | optim_name = 'SGD' 20 | optim_args = \ 21 | { 22 | 'lr': 1e-5, 23 | 'momentum': 0.9 24 | } 25 | num_epochs = 10000000 26 | batch_size = 200 27 | chkpt_freq = 50 28 | 29 | Phi = DiracPhi 30 | phi_name = 'DiracPhi' 31 | 32 | # Initial parameters. 33 | depth = 2 34 | widths = [10, 10] 35 | lc_w_range = (0, 1.0) 36 | shift_w_range = (0., 2.0) 37 | 38 | 39 | @ex.capture 40 | def get_info(_run): 41 | return _run._id 42 | 43 | 44 | @ex.automain 45 | def run(data_name, num_train, num_test, Phi, 46 | depth, widths, lc_w_range, shift_w_range, 47 | optim_name, optim_args, 48 | num_epochs, batch_size, chkpt_freq): 49 | id = get_info() 50 | identifier_id = '%s%s' % (identifier, id) 51 | train_data, test_data = load_data(data_name, num_train, num_test) 52 | train_ll, test_ll = load_log_ll(data_name, num_train, num_test) 53 | 54 | print('Computing ground truth manually because tagged log likelihood is wrong') 55 | from phi_listing import ClaytonPhi 56 | gt_phi = ClaytonPhi(torch.tensor(5.)) 57 | cop = Copula(gt_phi) 58 | train_ll = -torch.log(cop(train_data, mode='pdf')) 59 | test_ll = -torch.log(cop(test_data, mode='pdf')) 60 | 61 | print('train_ll', torch.mean(train_ll)) 62 | print('test_ll', torch.mean(test_ll)) 63 | 64 | print('Train ideal ll:', torch.mean(train_ll)) 65 | print('Test ideal ll:', torch.mean(test_ll)) 66 | 67 | phi = Phi(depth, widths, lc_w_range, shift_w_range) 68 | net = Copula(phi) 69 | expt(train_data, test_data, net, optim_name, 70 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 71 | 72 | 73 | if __name__ == '__main__': 74 | print('Sample usage: python -m train_scripts.clayton -F learn_clayton') 75 | -------------------------------------------------------------------------------- /train_scripts/clayton_noisy.py: -------------------------------------------------------------------------------- 1 | from main import Copula 2 | from dirac_phi import DiracPhi 3 | import pickle 4 | from sacred import Experiment 5 | from train import load_data, load_log_ll, expt_cdf_noisy 6 | import torch 7 | from sacred.observers import FileStorageObserver 8 | 9 | identifier = 'learn_clayton_noisy' 10 | ex = Experiment('LOG_learn_clayton_noisy') 11 | 12 | torch.set_default_tensor_type(torch.DoubleTensor) 13 | 14 | 15 | @ex.config 16 | def cfg(): 17 | data_name = './data/clayton1.p' 18 | 19 | num_train, num_test = 2000, 1000 20 | optim_name = 'SGD' 21 | optim_args = \ 22 | { 23 | 'lr': 1e-5, 24 | 'momentum': 0.9 25 | } 26 | num_epochs = 10000000 27 | batch_size = 200 28 | chkpt_freq = 50 29 | 30 | seed_noise = 142857 31 | width_noise = 0.5 32 | 33 | Phi = DiracPhi 34 | phi_name = 'DiracPhi' 35 | 36 | # Initial parameters. 37 | depth = 2 38 | widths = [10, 10] 39 | lc_w_range = (0, 1.0) 40 | shift_w_range = (0., 2.0) 41 | 42 | 43 | @ex.capture 44 | def get_info(_run): 45 | return _run._id 46 | 47 | 48 | @ex.automain 49 | def run(data_name, num_train, num_test, Phi, 50 | depth, widths, lc_w_range, shift_w_range, 51 | optim_name, optim_args, 52 | width_noise, seed_noise, 53 | num_epochs, batch_size, chkpt_freq): 54 | id = get_info() 55 | identifier_id = '%s%s' % (identifier, id) 56 | train_data, test_data = load_data(data_name, num_train, num_test) 57 | train_ll, test_ll = load_log_ll(data_name, num_train, num_test) 58 | 59 | print('Train ideal ll:', torch.mean(train_ll)) 60 | print('Test ideal ll:', torch.mean(test_ll)) 61 | 62 | phi = Phi(depth, widths, lc_w_range, shift_w_range) 63 | net = Copula(phi) 64 | expt_cdf_noisy(train_data, test_data, net, optim_name, 65 | optim_args, identifier_id, 66 | width_noise, seed_noise, 67 | num_epochs, batch_size, chkpt_freq) 68 | 69 | 70 | if __name__ == '__main__': 71 | print('Sample usage: python -m train_scripts.clayton_noisy -F learn_clayton_noisy') 72 | -------------------------------------------------------------------------------- /train_scripts/frank.py: -------------------------------------------------------------------------------- 1 | from main import Copula 2 | from phi_listing import FrankPhi 3 | from dirac_phi import DiracPhi 4 | import pickle 5 | from sacred import Experiment 6 | from train import load_data, load_log_ll, expt 7 | import torch 8 | from sacred.observers import FileStorageObserver 9 | 10 | identifier = 'learn_frank' 11 | ex = Experiment('LOG_learn_frank') 12 | 13 | torch.set_default_tensor_type(torch.DoubleTensor) 14 | 15 | 16 | @ex.config 17 | def cfg(): 18 | data_name = './data/frank1.p' 19 | num_train, num_test = 2000, 1000 20 | optim_name = 'SGD' 21 | optim_args = \ 22 | { 23 | 'lr': 1e-5, 24 | 'momentum': 0.9 25 | } 26 | num_epochs = 10000000 27 | batch_size = 200 28 | chkpt_freq = 50 29 | 30 | Phi = DiracPhi 31 | phi_name = 'DiracPhi' 32 | 33 | # Initial parameters. 34 | depth = 2 35 | widths = [10, 10] 36 | lc_w_range = (0, 1.0) 37 | shift_w_range = (0., 2.0) 38 | 39 | 40 | @ex.capture 41 | def get_info(_run): 42 | return _run._id 43 | 44 | 45 | @ex.automain 46 | def run(data_name, num_train, num_test, Phi, 47 | depth, widths, lc_w_range, shift_w_range, 48 | optim_name, optim_args, 49 | num_epochs, batch_size, chkpt_freq): 50 | id = get_info() 51 | identifier_id = '%s%s' % (identifier, id) 52 | train_data, test_data = load_data(data_name, num_train, num_test) 53 | train_ll, test_ll = load_log_ll(data_name, num_train, num_test) 54 | 55 | print('Train ideal ll:', torch.mean(train_ll)) 56 | print('Test ideal ll:', torch.mean(test_ll)) 57 | 58 | phi = Phi(depth, widths, lc_w_range, shift_w_range) 59 | net = Copula(phi) 60 | expt(train_data, test_data, net, optim_name, 61 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 62 | 63 | 64 | if __name__ == '__main__': 65 | print('Sample usage: python -m train_scripts.frank -F learn_frank') 66 | -------------------------------------------------------------------------------- /train_scripts/gas/train.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | import scipy 3 | from sklearn.model_selection import train_test_split 4 | from dirac_phi import DiracPhi 5 | from main import sample 6 | import os 7 | import numpy as np 8 | import torch.optim as optim 9 | import torch 10 | import matplotlib.pyplot as plt 11 | from torch.utils.data import DataLoader 12 | from main import Copula 13 | 14 | month1_path = "data/gas/batch1.dat" 15 | # Used if we want to compare between months. 16 | month2_path = "data/gas/batch2.dat" 17 | 18 | 19 | # Extract data which compares between features. 20 | def data_test_between_features(feature_ids, sensor_id, month_data): 21 | features = [] 22 | for feature_id in feature_ids: 23 | 24 | full_feature_id = feature_id + sensor_id * 8 25 | if feature_id in [5, 6, 7]: 26 | features.append(-month_data[:, full_feature_id]) 27 | else: 28 | features.append(month_data[:, full_feature_id]) 29 | 30 | return features 31 | 32 | 33 | # Extract data which compares between sensor ids. 34 | def data_test_between_sensors(feature_id, sensor_ids, month_data): 35 | assert len(sensor_ids) == 2 36 | 37 | id1 = feature_id[0] + sensor_ids[0] * 8 38 | id2 = feature_id[1] + sensor_ids[1] * 8 39 | return month_data[:, id1], month_data[:, id2] 40 | 41 | 42 | # Extract data which compares between different months. 43 | def data_test_between_months(feature_id, sensor_id, months_data): 44 | assert len(months_data) == 2 45 | 46 | id = feature_id + sensor_id * 8 47 | return months_data[0][:, id], months_data[1][:, id] 48 | 49 | 50 | def read_batch(filepath): 51 | def format_feature(x): 52 | return float(x.decode('UTF-8').split(':')[1]) 53 | 54 | d = [(i, format_feature) for i in range(1, 129)] 55 | z = np.genfromtxt(filepath, 56 | delimiter=" ", 57 | usecols=list(range(1, 129)), 58 | converters=dict(d), 59 | ) 60 | return z 61 | 62 | 63 | month1 = read_batch(month1_path) 64 | month2 = read_batch(month2_path) 65 | 66 | identifier = 'gas_2012' 67 | ex = Experiment('gas_2012') 68 | 69 | torch.set_default_tensor_type(torch.DoubleTensor) 70 | 71 | data = data_test_between_features((0, 4, 7), 0, month1) 72 | data = data_test_between_features((0, 4, 7), 2, month1) 73 | d1 = data[0] 74 | d2 = data[1] 75 | d3 = data[2] 76 | 77 | X = np.concatenate([d1[:, None], d2[:, None], d3[:, None]], axis=1) 78 | 79 | # plt.scatter(X[:, 0], X[:, 1]) 80 | # plt.show() 81 | 82 | 83 | def add_train_random_noise(data, num_adds): 84 | new_data = np.random.rand(num_adds, data.shape[1]) 85 | print(data.shape) 86 | print(new_data.shape) 87 | return np.concatenate((data, new_data), axis=0) 88 | 89 | 90 | X_train, X_test, _, _ = train_test_split( 91 | X, X, shuffle=True, random_state=142857) 92 | # X_train, X_test, _, _ = train_test_split( 93 | # X, X, shuffle=True, random_state=714285) 94 | # X_train, X_test, _, _ = train_test_split( 95 | # X, X, shuffle=True, random_state=571428) 96 | # X_train, X_test, _, _ = train_test_split( 97 | # X, X, shuffle=True, random_state=857142) 98 | # X_train, X_test, _, _ = train_test_split( 99 | # X, X, shuffle=True, random_state=285714) 100 | 101 | nfeats = X_test.shape[1] 102 | 103 | # Normalize data. 104 | for z in [X_train, X_test]: 105 | ndata = z.shape[0] 106 | gap = 1./(ndata+1) 107 | for i in range(nfeats): 108 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 109 | 110 | 111 | def get_optim(name, net, args): 112 | if name == 'SGD': 113 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 114 | else: 115 | assert False 116 | 117 | return optimizer 118 | 119 | 120 | @ex.config 121 | def cfg(): 122 | x_index = 0 123 | y_index = 1 124 | 125 | x_flip, y_flip = False, False 126 | 127 | optim_name = 'SGD' 128 | optim_args = \ 129 | { 130 | 'lr': 1e-5, 131 | 'momentum': 0.9 132 | } 133 | num_epochs = 10000000 134 | batch_size = 200 135 | chkpt_freq = 500 136 | 137 | Phi = DiracPhi 138 | phi_name = 'DiracPhi' 139 | 140 | # Initial parameters. 141 | depth = 2 142 | widths = [10, 10] 143 | lc_w_range = (0, 1.0) 144 | shift_w_range = (0., 2.0) 145 | 146 | frac_rand = 0.01 147 | 148 | 149 | @ex.capture 150 | def get_info(_run): 151 | return _run._id 152 | 153 | 154 | def expt(train_data, val_data, 155 | net, 156 | optim_name, 157 | optim_args, 158 | identifier, 159 | num_epochs=1000, 160 | batch_size=100, 161 | chkpt_freq=50, 162 | ): 163 | 164 | os.mkdir('./checkpoints/%s' % identifier) 165 | os.mkdir('./sample_figs/%s' % identifier) 166 | 167 | train_loader = DataLoader( 168 | train_data, batch_size=batch_size, shuffle=True) 169 | val_loader = DataLoader( 170 | val_data, batch_size=1000000, shuffle=True) 171 | 172 | optimizer = get_optim(optim_name, net, optim_args) 173 | 174 | train_loss_per_epoch = [] 175 | 176 | for epoch in range(num_epochs): 177 | loss_per_minibatch = [] 178 | for i, data in enumerate(train_loader, 0): 179 | optimizer.zero_grad() 180 | 181 | d = torch.tensor(data, requires_grad=True) 182 | p = net(d, mode='pdf') 183 | 184 | logloss = -torch.sum(torch.log(p)) 185 | reg_loss = logloss 186 | reg_loss.backward() 187 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 188 | 189 | loss_per_minibatch.append(scalar_loss) 190 | optimizer.step() 191 | 192 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 193 | print('Training loss at epoch %s: %s' % 194 | (epoch, train_loss_per_epoch[-1])) 195 | 196 | if epoch % chkpt_freq == 0: 197 | print('Checkpointing') 198 | torch.save({ 199 | 'epoch': epoch, 200 | 'model_state_dict': net.state_dict(), 201 | 'optimizer_state_dict': optimizer.state_dict(), 202 | 'loss': logloss, 203 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 204 | 205 | # if args.dims == 2: 206 | if True: 207 | # if False: 208 | print('Scatter sampling') 209 | samples = sample(net, 2, 1000) 210 | plt.scatter(samples[:, 0], samples[:, 1]) 211 | plt.savefig('./sample_figs/%s/epoch%s.png' % 212 | (identifier, epoch)) 213 | plt.clf() 214 | else: 215 | print('Not doign scatter plot, dims > 2') 216 | 217 | print('Evaluating validation loss') 218 | for j, val_data in enumerate(val_loader, 0): 219 | net.zero_grad() 220 | val_p = net(val_data, mode='pdf') 221 | val_loss = -torch.mean(torch.log(val_p)) 222 | print('Average validation loss %s' % val_loss) 223 | 224 | 225 | @ex.automain 226 | def run(Phi, 227 | depth, widths, lc_w_range, shift_w_range, 228 | optim_name, optim_args, 229 | num_epochs, batch_size, chkpt_freq, 230 | frac_rand): 231 | id = get_info() 232 | identifier_id = '%s%s' % (identifier, id) 233 | 234 | train_data = X_train 235 | train_data = add_train_random_noise(train_data, int(X_train.shape[0]*frac_rand)) 236 | test_data = X_test 237 | 238 | phi = Phi(depth, widths, lc_w_range, shift_w_range) 239 | net = Copula(phi) 240 | expt(train_data, test_data, net, optim_name, 241 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 242 | 243 | 244 | if __name__ == '__main__': 245 | print('Sample usage: python -m train_scripts.gas.train -F learn_gas') 246 | -------------------------------------------------------------------------------- /train_scripts/gas/train_with_clayton.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | import scipy 3 | from sklearn.model_selection import train_test_split 4 | import os 5 | import numpy as np 6 | import torch.optim as optim 7 | import torch 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import DataLoader 10 | from main import Copula 11 | 12 | month1_path = "data/gas/batch1.dat" 13 | # Used if we want to compare between months. 14 | month2_path = "data/gas/batch2.dat" 15 | 16 | 17 | # Extract data which compares between features. 18 | def data_test_between_features(feature_ids, sensor_id, month_data): 19 | # assert len(feature_ids) == 2 20 | 21 | num_features = len(feature_ids) 22 | 23 | features = [] 24 | for feature_id in feature_ids: 25 | 26 | full_feature_id = feature_id + sensor_id * 8 27 | if feature_id in [5, 6, 7]: 28 | features.append(-month_data[:, full_feature_id]) 29 | else: 30 | features.append(month_data[:, full_feature_id]) 31 | 32 | return features 33 | 34 | 35 | # Extract data which compares between sensor ids. 36 | def data_test_between_sensors(feature_id, sensor_ids, month_data): 37 | assert len(sensor_ids) == 2 38 | 39 | id1 = feature_id[0] + sensor_ids[0] * 8 40 | id2 = feature_id[1] + sensor_ids[1] * 8 41 | return month_data[:, id1], month_data[:, id2] 42 | 43 | 44 | # Extract data which compares between different months. 45 | def data_test_between_months(feature_id, sensor_id, months_data): 46 | assert len(months_data) == 2 47 | 48 | id = feature_id + sensor_id * 8 49 | return months_data[0][:, id], months_data[1][:, id] 50 | 51 | 52 | def read_batch(filepath): 53 | def format_feature(x): 54 | return float(x.decode('UTF-8').split(':')[1]) 55 | 56 | d = [(i, format_feature) for i in range(1, 129)] 57 | z = np.genfromtxt(filepath, 58 | delimiter=" ", 59 | usecols=list(range(1, 129)), 60 | converters=dict(d), 61 | ) 62 | return z 63 | 64 | 65 | month1 = read_batch(month1_path) 66 | month2 = read_batch(month2_path) 67 | 68 | identifier = 'gas_2012_clayton' 69 | ex = Experiment('gas_2012_clayton') 70 | 71 | torch.set_default_tensor_type(torch.DoubleTensor) 72 | 73 | data = data_test_between_features((0, 4, 7), 0, month1) 74 | data = data_test_between_features((0, 4, 7), 2, month1) 75 | d1 = data[0] 76 | d2 = data[1] 77 | d3 = data[2] 78 | 79 | X = np.concatenate([d1[:, None], d2[:, None], d3[:, None]], axis=1) 80 | 81 | # plt.scatter(X[:, 0], X[:, 1]) 82 | # plt.show() 83 | 84 | 85 | def add_train_random_noise(data, num_adds): 86 | new_data = np.random.rand(num_adds, data.shape[1]) 87 | print(data.shape) 88 | print(new_data.shape) 89 | return np.concatenate((data, new_data), axis=0) 90 | 91 | 92 | X_train, X_test, _, _ = train_test_split( 93 | X, X, shuffle=True, random_state=142857) 94 | # X_train, X_test, _, _ = train_test_split( 95 | # X, X, shuffle=True, random_state=714285) 96 | # X_train, X_test, _, _ = train_test_split( 97 | # X, X, shuffle=True, random_state=571428) 98 | # X_train, X_test, _, _ = train_test_split( 99 | # X, X, shuffle=True, random_state=857142) 100 | # X_train, X_test, _, _ = train_test_split( 101 | # X, X, shuffle=True, random_state=285714) 102 | 103 | nfeats = X_test.shape[1] 104 | 105 | # Normalize data. 106 | for z in [X_train, X_test]: 107 | ndata = z.shape[0] 108 | gap = 1./(ndata+1) 109 | for i in range(nfeats): 110 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 111 | 112 | 113 | # plt.scatter(X_train[:, 0], X_train[:, 1]) 114 | # plt.show() 115 | 116 | 117 | def get_optim(name, net, args): 118 | if name == 'SGD': 119 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 120 | else: 121 | assert False 122 | 123 | return optimizer 124 | 125 | 126 | @ex.config 127 | def cfg(): 128 | x_index = 0 129 | y_index = 1 130 | 131 | x_flip, y_flip = False, False 132 | 133 | optim_name = 'SGD' 134 | optim_args = \ 135 | { 136 | 'lr': 1e-5, 137 | 'momentum': 0.9 138 | } 139 | num_epochs = 10000000 140 | batch_size = 200 141 | chkpt_freq = 500 142 | 143 | from phi_listing import ClaytonPhi 144 | Phi = ClaytonPhi 145 | phi_name = 'ClaytonPhi' 146 | 147 | # Initial parameters. 148 | initial_theta = 5. 149 | 150 | frac_rand = 0.01 151 | 152 | 153 | @ex.capture 154 | def get_info(_run): 155 | return _run._id 156 | 157 | 158 | def expt(train_data, val_data, 159 | net, 160 | optim_name, 161 | optim_args, 162 | identifier, 163 | num_epochs=1000, 164 | batch_size=100, 165 | chkpt_freq=50, 166 | ): 167 | 168 | os.mkdir('./checkpoints/%s' % identifier) 169 | os.mkdir('./sample_figs/%s' % identifier) 170 | 171 | train_loader = DataLoader( 172 | train_data, batch_size=batch_size, shuffle=True) 173 | val_loader = DataLoader( 174 | val_data, batch_size=1000000, shuffle=True) 175 | 176 | optimizer = get_optim(optim_name, net, optim_args) 177 | 178 | train_loss_per_epoch = [] 179 | 180 | for epoch in range(num_epochs): 181 | print(net.phi.theta) 182 | loss_per_minibatch = [] 183 | for i, data in enumerate(train_loader, 0): 184 | optimizer.zero_grad() 185 | 186 | d = torch.tensor(data, requires_grad=True) 187 | p = net(d, mode='pdf') 188 | 189 | logloss = -torch.sum(torch.log(p)) 190 | reg_loss = logloss 191 | reg_loss.backward() 192 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 193 | 194 | loss_per_minibatch.append(scalar_loss) 195 | optimizer.step() 196 | 197 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 198 | print('Training loss at epoch %s: %s' % 199 | (epoch, train_loss_per_epoch[-1])) 200 | 201 | if epoch % chkpt_freq == 0: 202 | print('Checkpointing') 203 | torch.save({ 204 | 'epoch': epoch, 205 | 'model_state_dict': net.state_dict(), 206 | 'optimizer_state_dict': optimizer.state_dict(), 207 | 'loss': logloss, 208 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 209 | 210 | print('Evaluating validation loss') 211 | for j, val_data in enumerate(val_loader, 0): 212 | net.zero_grad() 213 | val_p = net(val_data, mode='pdf') 214 | val_loss = -torch.mean(torch.log(val_p)) 215 | print('Average validation loss %s' % val_loss) 216 | 217 | 218 | @ex.automain 219 | def run(Phi, 220 | initial_theta, 221 | optim_name, optim_args, 222 | num_epochs, batch_size, chkpt_freq, 223 | frac_rand): 224 | id = get_info() 225 | identifier_id = '%s%s' % (identifier, id) 226 | 227 | train_data = X_train 228 | train_data = add_train_random_noise(train_data, int(X_train.shape[0]*frac_rand)) 229 | test_data = X_test 230 | 231 | phi = Phi(torch.tensor(initial_theta)) 232 | net = Copula(phi) 233 | expt(train_data, test_data, net, optim_name, 234 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 235 | 236 | 237 | if __name__ == '__main__': 238 | print('Sample usage: python -m train_scripts.gas.train_with_clayton -F learn_gas_with_clayton') 239 | -------------------------------------------------------------------------------- /train_scripts/gas/train_with_frank.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | import scipy 3 | from sklearn.model_selection import train_test_split 4 | import os 5 | import numpy as np 6 | import torch.optim as optim 7 | import torch 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import DataLoader 10 | from main import Copula 11 | 12 | month1_path = "data/gas/batch1.dat" 13 | # Used if we want to compare between months. 14 | month2_path = "data/gas/batch2.dat" 15 | 16 | 17 | # Extract data which compares between features. 18 | def data_test_between_features(feature_ids, sensor_id, month_data): 19 | # assert len(feature_ids) == 2 20 | 21 | num_features = len(feature_ids) 22 | 23 | features = [] 24 | for feature_id in feature_ids: 25 | 26 | full_feature_id = feature_id + sensor_id * 8 27 | if feature_id in [5, 6, 7]: 28 | features.append(-month_data[:, full_feature_id]) 29 | else: 30 | features.append(month_data[:, full_feature_id]) 31 | 32 | return features 33 | 34 | 35 | # Extract data which compares between sensor ids. 36 | def data_test_between_sensors(feature_id, sensor_ids, month_data): 37 | assert len(sensor_ids) == 2 38 | 39 | id1 = feature_id[0] + sensor_ids[0] * 8 40 | id2 = feature_id[1] + sensor_ids[1] * 8 41 | return month_data[:, id1], month_data[:, id2] 42 | 43 | 44 | # Extract data which compares between different months. 45 | def data_test_between_months(feature_id, sensor_id, months_data): 46 | assert len(months_data) == 2 47 | 48 | id = feature_id + sensor_id * 8 49 | return months_data[0][:, id], months_data[1][:, id] 50 | 51 | 52 | def read_batch(filepath): 53 | def format_feature(x): 54 | return float(x.decode('UTF-8').split(':')[1]) 55 | 56 | d = [(i, format_feature) for i in range(1, 129)] 57 | z = np.genfromtxt(filepath, 58 | delimiter=" ", 59 | usecols=list(range(1, 129)), 60 | converters=dict(d), 61 | ) 62 | return z 63 | 64 | 65 | month1 = read_batch(month1_path) 66 | month2 = read_batch(month2_path) 67 | 68 | identifier = 'gas_2012_frank' 69 | ex = Experiment('gas_2012_frank') 70 | 71 | torch.set_default_tensor_type(torch.DoubleTensor) 72 | 73 | data = data_test_between_features((0, 4, 7), 0, month1) 74 | data = data_test_between_features((0, 4, 7), 2, month1) 75 | d1 = data[0] 76 | d2 = data[1] 77 | d3 = data[2] 78 | 79 | X = np.concatenate([d1[:, None], d2[:, None], d3[:, None]], axis=1) 80 | 81 | # plt.scatter(X[:, 0], X[:, 1]) 82 | # plt.show() 83 | 84 | 85 | def add_train_random_noise(data, num_adds): 86 | new_data = np.random.rand(num_adds, data.shape[1]) 87 | print(data.shape) 88 | print(new_data.shape) 89 | return np.concatenate((data, new_data), axis=0) 90 | 91 | 92 | X_train, X_test, _, _ = train_test_split( 93 | X, X, shuffle=True, random_state=142857) 94 | # X_train, X_test, _, _ = train_test_split( 95 | # X, X, shuffle=True, random_state=714285) 96 | # X_train, X_test, _, _ = train_test_split( 97 | # X, X, shuffle=True, random_state=571428) 98 | # X_train, X_test, _, _ = train_test_split( 99 | # X, X, shuffle=True, random_state=857142) 100 | # X_train, X_test, _, _ = train_test_split( 101 | # X, X, shuffle=True, random_state=285714) 102 | 103 | nfeats = X_test.shape[1] 104 | 105 | # Normalize data. 106 | for z in [X_train, X_test]: 107 | ndata = z.shape[0] 108 | gap = 1./(ndata+1) 109 | for i in range(nfeats): 110 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 111 | 112 | 113 | # plt.scatter(X_train[:, 0], X_train[:, 1]) 114 | # plt.show() 115 | 116 | 117 | def get_optim(name, net, args): 118 | if name == 'SGD': 119 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 120 | else: 121 | assert False 122 | 123 | return optimizer 124 | 125 | 126 | @ex.config 127 | def cfg(): 128 | x_index = 0 129 | y_index = 1 130 | 131 | x_flip, y_flip = False, False 132 | 133 | optim_name = 'SGD' 134 | optim_args = \ 135 | { 136 | 'lr': 1e-5, 137 | 'momentum': 0.9 138 | } 139 | num_epochs = 10000000 140 | batch_size = 200 141 | chkpt_freq = 500 142 | 143 | from phi_listing import FrankPhi 144 | Phi = FrankPhi 145 | phi_name = 'FrankPhi' 146 | 147 | # Initial parameters. 148 | initial_theta = 5. 149 | 150 | frac_rand = 0.01 151 | 152 | 153 | @ex.capture 154 | def get_info(_run): 155 | return _run._id 156 | 157 | 158 | def expt(train_data, val_data, 159 | net, 160 | optim_name, 161 | optim_args, 162 | identifier, 163 | num_epochs=1000, 164 | batch_size=100, 165 | chkpt_freq=50, 166 | ): 167 | 168 | os.mkdir('./checkpoints/%s' % identifier) 169 | os.mkdir('./sample_figs/%s' % identifier) 170 | 171 | train_loader = DataLoader( 172 | train_data, batch_size=batch_size, shuffle=True) 173 | val_loader = DataLoader( 174 | val_data, batch_size=1000000, shuffle=True) 175 | 176 | optimizer = get_optim(optim_name, net, optim_args) 177 | 178 | train_loss_per_epoch = [] 179 | 180 | for epoch in range(num_epochs): 181 | print(net.phi.theta) 182 | loss_per_minibatch = [] 183 | for i, data in enumerate(train_loader, 0): 184 | optimizer.zero_grad() 185 | 186 | d = torch.tensor(data, requires_grad=True) 187 | p = net(d, mode='pdf') 188 | 189 | logloss = -torch.sum(torch.log(p)) 190 | reg_loss = logloss 191 | reg_loss.backward() 192 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 193 | 194 | loss_per_minibatch.append(scalar_loss) 195 | optimizer.step() 196 | 197 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 198 | print('Training loss at epoch %s: %s' % 199 | (epoch, train_loss_per_epoch[-1])) 200 | 201 | if epoch % chkpt_freq == 0: 202 | print('Checkpointing') 203 | torch.save({ 204 | 'epoch': epoch, 205 | 'model_state_dict': net.state_dict(), 206 | 'optimizer_state_dict': optimizer.state_dict(), 207 | 'loss': logloss, 208 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 209 | 210 | print('Evaluating validation loss') 211 | for j, val_data in enumerate(val_loader, 0): 212 | net.zero_grad() 213 | val_p = net(val_data, mode='pdf') 214 | val_loss = -torch.mean(torch.log(val_p)) 215 | print('Average validation loss %s' % val_loss) 216 | 217 | 218 | @ex.automain 219 | def run(Phi, 220 | initial_theta, 221 | optim_name, optim_args, 222 | num_epochs, batch_size, chkpt_freq, 223 | frac_rand): 224 | id = get_info() 225 | identifier_id = '%s%s' % (identifier, id) 226 | 227 | train_data = X_train 228 | train_data = add_train_random_noise(train_data, int(X_train.shape[0]*frac_rand)) 229 | test_data = X_test 230 | 231 | phi = Phi(torch.tensor(initial_theta)) 232 | net = Copula(phi) 233 | expt(train_data, test_data, net, optim_name, 234 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 235 | 236 | 237 | if __name__ == '__main__': 238 | print('Sample usage: python -m train_scripts.gas.train_with_frank -F learn_gas_with_frank') 239 | -------------------------------------------------------------------------------- /train_scripts/gas/train_with_gumbel.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | import scipy 3 | from sklearn.model_selection import train_test_split 4 | import os 5 | import numpy as np 6 | import torch.optim as optim 7 | import torch 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import DataLoader 10 | from main import Copula 11 | 12 | month1_path = "data/gas/batch1.dat" 13 | # Used if we want to compare between months. 14 | month2_path = "data/gas/batch2.dat" 15 | 16 | 17 | # Extract data which compares between features. 18 | def data_test_between_features(feature_ids, sensor_id, month_data): 19 | # assert len(feature_ids) == 2 20 | 21 | num_features = len(feature_ids) 22 | 23 | features = [] 24 | for feature_id in feature_ids: 25 | 26 | full_feature_id = feature_id + sensor_id * 8 27 | if feature_id in [5, 6, 7]: 28 | features.append(-month_data[:, full_feature_id]) 29 | else: 30 | features.append(month_data[:, full_feature_id]) 31 | 32 | return features 33 | 34 | 35 | # Extract data which compares between sensor ids. 36 | def data_test_between_sensors(feature_id, sensor_ids, month_data): 37 | assert len(sensor_ids) == 2 38 | 39 | id1 = feature_id[0] + sensor_ids[0] * 8 40 | id2 = feature_id[1] + sensor_ids[1] * 8 41 | return month_data[:, id1], month_data[:, id2] 42 | 43 | 44 | # Extract data which compares between different months. 45 | def data_test_between_months(feature_id, sensor_id, months_data): 46 | assert len(months_data) == 2 47 | 48 | id = feature_id + sensor_id * 8 49 | return months_data[0][:, id], months_data[1][:, id] 50 | 51 | 52 | def read_batch(filepath): 53 | def format_feature(x): 54 | return float(x.decode('UTF-8').split(':')[1]) 55 | 56 | d = [(i, format_feature) for i in range(1, 129)] 57 | z = np.genfromtxt(filepath, 58 | delimiter=" ", 59 | usecols=list(range(1, 129)), 60 | converters=dict(d), 61 | ) 62 | return z 63 | 64 | 65 | month1 = read_batch(month1_path) 66 | month2 = read_batch(month2_path) 67 | 68 | identifier = 'gas_2012_gumbel' 69 | ex = Experiment('gas_2012_gumbel') 70 | 71 | torch.set_default_tensor_type(torch.DoubleTensor) 72 | 73 | data = data_test_between_features((0, 4, 7), 0, month1) 74 | data = data_test_between_features((0, 4, 7), 2, month1) 75 | d1 = data[0] 76 | d2 = data[1] 77 | d3 = data[2] 78 | 79 | X = np.concatenate([d1[:, None], d2[:, None], d3[:, None]], axis=1) 80 | 81 | # plt.scatter(X[:, 0], X[:, 1]) 82 | # plt.show() 83 | 84 | 85 | def add_train_random_noise(data, num_adds): 86 | new_data = np.random.rand(num_adds, data.shape[1]) 87 | print(data.shape) 88 | print(new_data.shape) 89 | return np.concatenate((data, new_data), axis=0) 90 | 91 | 92 | X_train, X_test, _, _ = train_test_split( 93 | X, X, shuffle=True, random_state=142857) 94 | # X_train, X_test, _, _ = train_test_split( 95 | # X, X, shuffle=True, random_state=714285) 96 | # X_train, X_test, _, _ = train_test_split( 97 | # X, X, shuffle=True, random_state=571428) 98 | # X_train, X_test, _, _ = train_test_split( 99 | # X, X, shuffle=True, random_state=857142) 100 | # X_train, X_test, _, _ = train_test_split( 101 | # X, X, shuffle=True, random_state=285714) 102 | 103 | nfeats = X_test.shape[1] 104 | 105 | # Normalize data. 106 | for z in [X_train, X_test]: 107 | ndata = z.shape[0] 108 | gap = 1./(ndata+1) 109 | for i in range(nfeats): 110 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 111 | 112 | 113 | # plt.scatter(X_train[:, 0], X_train[:, 1]) 114 | # plt.show() 115 | 116 | 117 | def get_optim(name, net, args): 118 | if name == 'SGD': 119 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 120 | else: 121 | assert False 122 | 123 | return optimizer 124 | 125 | 126 | @ex.config 127 | def cfg(): 128 | x_index = 0 129 | y_index = 1 130 | 131 | x_flip, y_flip = False, False 132 | 133 | optim_name = 'SGD' 134 | optim_args = \ 135 | { 136 | 'lr': 1e-5, 137 | 'momentum': 0.9 138 | } 139 | num_epochs = 10000000 140 | batch_size = 200 141 | chkpt_freq = 500 142 | 143 | from phi_listing import GumbelPhi 144 | Phi = GumbelPhi 145 | phi_name = 'GumbelPhi' 146 | 147 | # Initial parameters. 148 | initial_theta = 5. 149 | 150 | frac_rand = 0.01 151 | 152 | 153 | @ex.capture 154 | def get_info(_run): 155 | return _run._id 156 | 157 | 158 | def expt(train_data, val_data, 159 | net, 160 | optim_name, 161 | optim_args, 162 | identifier, 163 | num_epochs=1000, 164 | batch_size=100, 165 | chkpt_freq=50, 166 | ): 167 | 168 | os.mkdir('./checkpoints/%s' % identifier) 169 | os.mkdir('./sample_figs/%s' % identifier) 170 | 171 | train_loader = DataLoader( 172 | train_data, batch_size=batch_size, shuffle=True) 173 | val_loader = DataLoader( 174 | val_data, batch_size=1000000, shuffle=True) 175 | 176 | optimizer = get_optim(optim_name, net, optim_args) 177 | 178 | train_loss_per_epoch = [] 179 | 180 | for epoch in range(num_epochs): 181 | print(net.phi.theta) 182 | loss_per_minibatch = [] 183 | for i, data in enumerate(train_loader, 0): 184 | optimizer.zero_grad() 185 | 186 | d = torch.tensor(data, requires_grad=True) 187 | p = net(d, mode='pdf') 188 | 189 | logloss = -torch.sum(torch.log(p)) 190 | reg_loss = logloss 191 | reg_loss.backward() 192 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 193 | 194 | loss_per_minibatch.append(scalar_loss) 195 | optimizer.step() 196 | 197 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 198 | print('Training loss at epoch %s: %s' % 199 | (epoch, train_loss_per_epoch[-1])) 200 | 201 | if epoch % chkpt_freq == 0: 202 | print('Checkpointing') 203 | torch.save({ 204 | 'epoch': epoch, 205 | 'model_state_dict': net.state_dict(), 206 | 'optimizer_state_dict': optimizer.state_dict(), 207 | 'loss': logloss, 208 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 209 | 210 | print('Evaluating validation loss') 211 | for j, val_data in enumerate(val_loader, 0): 212 | net.zero_grad() 213 | val_p = net(val_data, mode='pdf') 214 | val_loss = -torch.mean(torch.log(val_p)) 215 | print('Average validation loss %s' % val_loss) 216 | 217 | 218 | @ex.automain 219 | def run(Phi, 220 | initial_theta, 221 | optim_name, optim_args, 222 | num_epochs, batch_size, chkpt_freq, 223 | frac_rand): 224 | id = get_info() 225 | identifier_id = '%s%s' % (identifier, id) 226 | 227 | train_data = X_train 228 | train_data = add_train_random_noise(train_data, int(X_train.shape[0]*frac_rand)) 229 | test_data = X_test 230 | 231 | phi = Phi(torch.tensor(initial_theta)) 232 | net = Copula(phi) 233 | expt(train_data, test_data, net, optim_name, 234 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 235 | 236 | 237 | if __name__ == '__main__': 238 | print('Sample usage: python -m train_scripts.gas.train_with_gumbel -F learn_gas_with_gumbel') 239 | -------------------------------------------------------------------------------- /train_scripts/joe.py: -------------------------------------------------------------------------------- 1 | from main import Copula 2 | from dirac_phi import DiracPhi 3 | import pickle 4 | from sacred import Experiment 5 | from train import load_data, load_log_ll, expt 6 | import torch 7 | from sacred.observers import FileStorageObserver 8 | 9 | identifier = 'data_joe_net_dirac' 10 | ex = Experiment('LOG_data_joe_net_dirac') 11 | 12 | torch.set_default_tensor_type(torch.DoubleTensor) 13 | 14 | 15 | @ex.config 16 | def cfg(): 17 | data_name = './data/joe1.p' 18 | num_train, num_test = 2000, 1000 19 | optim_name = 'SGD' 20 | optim_args = \ 21 | { 22 | 'lr': 1e-5, 23 | 'momentum': 0.9 24 | } 25 | num_epochs = 10000000 26 | batch_size = 200 27 | chkpt_freq = 50 28 | 29 | Phi = DiracPhi 30 | phi_name = 'DiracPhi' 31 | 32 | # Initial parameters. 33 | depth = 2 34 | widths = [10, 10] 35 | lc_w_range = (0, 1.0) 36 | shift_w_range = (0., 2.0) 37 | 38 | 39 | @ex.capture 40 | def get_info(_run): 41 | return _run._id 42 | 43 | 44 | @ex.automain 45 | def run(data_name, num_train, num_test, Phi, 46 | depth, widths, lc_w_range, shift_w_range, 47 | optim_name, optim_args, 48 | num_epochs, batch_size, chkpt_freq): 49 | id = get_info() 50 | identifier_id = '%s%s' % (identifier, id) 51 | train_data, test_data = load_data(data_name, num_train, num_test) 52 | train_ll, test_ll = load_log_ll(data_name, num_train, num_test) 53 | 54 | print('Train ideal ll:', torch.mean(train_ll)) 55 | print('Test ideal ll:', torch.mean(test_ll)) 56 | 57 | phi = Phi(depth, widths, lc_w_range, shift_w_range) 58 | net = Copula(phi) 59 | expt(train_data, test_data, net, optim_name, 60 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 61 | 62 | 63 | if __name__ == '__main__': 64 | print('Sample usage: python -m train_scripts.frank -F learn_joe') 65 | -------------------------------------------------------------------------------- /train_scripts/rdj/train.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | import scipy 3 | from sklearn.model_selection import train_test_split 4 | from dirac_phi import DiracPhi 5 | from main import sample 6 | import os 7 | import numpy as np 8 | import torch.optim as optim 9 | import torch 10 | import matplotlib.pyplot as plt 11 | from torch.utils.data import DataLoader 12 | from main import Copula 13 | 14 | 15 | intel_f = open('data/rdj/INTEL.data', 'r') 16 | intel = np.array(list(map(float, intel_f.readlines()))) 17 | 18 | ms_f = open('data/rdj/MS.data', 'r') 19 | ms = np.array(list(map(float, ms_f.readlines()))) 20 | 21 | ge_f = open('data/rdj/GE.data', 'r') 22 | ge = np.array(list(map(float, ge_f.readlines()))) 23 | 24 | identifier = 'rdj_stocks' 25 | ex = Experiment('rdj_stocks') 26 | 27 | torch.set_default_tensor_type(torch.DoubleTensor) 28 | 29 | X = np.concatenate((intel[:, None], ms[:, None]), axis=1) 30 | 31 | 32 | def add_train_random_noise(data, num_adds): 33 | new_data = np.random.rand(num_adds, data.shape[1]) 34 | print(data.shape) 35 | print(new_data.shape) 36 | return np.concatenate((data, new_data), axis=0) 37 | 38 | 39 | X_train, X_test, _, _ = train_test_split( 40 | X, X, shuffle=True, random_state=142857) 41 | # X_train, X_test, _, _ = train_test_split( 42 | # X, X, shuffle=True, random_state=714285) 43 | # X_train, X_test, _, _ = train_test_split( 44 | # X, X, shuffle=True, random_state=571428) 45 | # X_train, X_test, _, _ = train_test_split( 46 | # X, X, shuffle=True, random_state=857142) 47 | # X_train, X_test, _, _ = train_test_split( 48 | # X, X, shuffle=True, random_state=285714) 49 | 50 | nfeats = X_test.shape[1] 51 | 52 | # Normalize data. 53 | for z in [X_train, X_test]: 54 | ndata = z.shape[0] 55 | gap = 1./(ndata+1) 56 | for i in range(nfeats): 57 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 58 | 59 | 60 | def get_optim(name, net, args): 61 | if name == 'SGD': 62 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 63 | else: 64 | assert False 65 | 66 | return optimizer 67 | 68 | 69 | @ex.config 70 | def cfg(): 71 | x_index = 0 72 | y_index = 1 73 | 74 | x_flip, y_flip = False, False 75 | 76 | optim_name = 'SGD' 77 | optim_args = \ 78 | { 79 | 'lr': 1e-5, 80 | 'momentum': 0.9 81 | } 82 | num_epochs = 10000000 83 | batch_size = 200 84 | chkpt_freq = 500 85 | 86 | Phi = DiracPhi 87 | phi_name = 'DiracPhi' 88 | 89 | # Initial parameters. 90 | depth = 2 91 | widths = [10, 10] 92 | lc_w_range = (0, 1.0) 93 | shift_w_range = (0., 2.0) 94 | 95 | # Fraction of random data added 96 | frac_rand = 0.01 97 | 98 | 99 | @ex.capture 100 | def get_info(_run): 101 | return _run._id 102 | 103 | 104 | def expt(train_data, val_data, 105 | net, 106 | optim_name, 107 | optim_args, 108 | identifier, 109 | num_epochs=1000, 110 | batch_size=100, 111 | chkpt_freq=50, 112 | ): 113 | 114 | os.mkdir('./checkpoints/%s' % identifier) 115 | os.mkdir('./sample_figs/%s' % identifier) 116 | 117 | train_loader = DataLoader( 118 | train_data, batch_size=batch_size, shuffle=True) 119 | val_loader = DataLoader( 120 | val_data, batch_size=1000000, shuffle=True) 121 | 122 | optimizer = get_optim(optim_name, net, optim_args) 123 | 124 | train_loss_per_epoch = [] 125 | 126 | for epoch in range(num_epochs): 127 | loss_per_minibatch = [] 128 | for i, data in enumerate(train_loader, 0): 129 | optimizer.zero_grad() 130 | 131 | d = torch.tensor(data, requires_grad=True) 132 | p = net(d, mode='pdf') 133 | 134 | logloss = -torch.sum(torch.log(p)) 135 | reg_loss = logloss 136 | reg_loss.backward() 137 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 138 | 139 | loss_per_minibatch.append(scalar_loss) 140 | optimizer.step() 141 | 142 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 143 | print('Training loss at epoch %s: %s' % 144 | (epoch, train_loss_per_epoch[-1])) 145 | 146 | if epoch % chkpt_freq == 0: 147 | print('Checkpointing') 148 | torch.save({ 149 | 'epoch': epoch, 150 | 'model_state_dict': net.state_dict(), 151 | 'optimizer_state_dict': optimizer.state_dict(), 152 | 'loss': logloss, 153 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 154 | 155 | if True: 156 | print('Scatter sampling') 157 | samples = sample(net, 2, 1000) 158 | plt.scatter(samples[:, 0], samples[:, 1]) 159 | plt.savefig('./sample_figs/%s/epoch%s.png' % 160 | (identifier, epoch)) 161 | plt.clf() 162 | 163 | print('Evaluating validation loss') 164 | for j, val_data in enumerate(val_loader, 0): 165 | net.zero_grad() 166 | val_p = net(val_data, mode='pdf') 167 | val_loss = -torch.mean(torch.log(val_p)) 168 | print('Average validation loss %s' % val_loss) 169 | 170 | 171 | @ex.automain 172 | def run(x_index, y_index, 173 | x_flip, y_flip, 174 | Phi, 175 | depth, widths, lc_w_range, shift_w_range, 176 | optim_name, optim_args, 177 | num_epochs, batch_size, chkpt_freq, 178 | frac_rand): 179 | id = get_info() 180 | identifier_id = '%s%s' % (identifier, id) 181 | 182 | train_data = X_train[:, [x_index, y_index]] 183 | train_data = add_train_random_noise(train_data, 184 | int(train_data.shape[0]*frac_rand)) 185 | test_data = X_test[:, [x_index, y_index]] 186 | 187 | if x_flip: 188 | train_data[:, 0] = 1-train_data[:, 0] 189 | test_data[:, 0] = 1-test_data[:, 0] 190 | 191 | if y_flip: 192 | train_data[:, 1] = 1-train_data[:, 1] 193 | test_data[:, 1] = 1-test_data[:, 1] 194 | 195 | phi = Phi(depth, widths, lc_w_range, shift_w_range) 196 | net = Copula(phi) 197 | expt(train_data, test_data, net, optim_name, 198 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 199 | 200 | 201 | if __name__ == '__main__': 202 | print('Sample usage: python -m train_scripts.rdj.train -F learn_rdj') 203 | -------------------------------------------------------------------------------- /train_scripts/rdj/train_with_clayton.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | from main import Copula 3 | from torch.utils.data import DataLoader 4 | import torch 5 | import torch.optim as optim 6 | import numpy as np 7 | import os 8 | from main import sample 9 | from sklearn.model_selection import train_test_split 10 | import scipy 11 | 12 | 13 | intel_f = open('data/rdj/INTEL.data', 'r') 14 | intel = np.array(list(map(float, intel_f.readlines()))) 15 | 16 | ms_f = open('data/rdj/MS.data', 'r') 17 | ms = np.array(list(map(float, ms_f.readlines()))) 18 | 19 | ge_f = open('data/rdj/GE.data', 'r') 20 | ge = np.array(list(map(float, ge_f.readlines()))) 21 | 22 | identifier = 'rdj_stocks_clayton' 23 | ex = Experiment('rdj_stocks_clayton') 24 | 25 | torch.set_default_tensor_type(torch.DoubleTensor) 26 | 27 | X = np.concatenate((intel[:, None], ms[:, None]), axis=1) 28 | 29 | 30 | def add_train_random_noise(data, num_adds): 31 | new_data = np.random.rand(num_adds, data.shape[1]) 32 | print(data.shape) 33 | print(new_data.shape) 34 | return np.concatenate((data, new_data), axis=0) 35 | 36 | 37 | X_train, X_test, _, _ = train_test_split( 38 | X, X, shuffle=True, random_state=142857) 39 | # X_train, X_test, _, _ = train_test_split( 40 | # X, X, shuffle=True, random_state=714285) 41 | # X_train, X_test, _, _ = train_test_split( 42 | # X, X, shuffle=True, random_state=571428) 43 | # X_train, X_test, _, _ = train_test_split( 44 | # X, X, shuffle=True, random_state=857142) 45 | # X_train, X_test, _, _ = train_test_split( 46 | # X, X, shuffle=True, random_state=285714) 47 | 48 | nfeats = X_test.shape[1] 49 | 50 | # Normalize data. 51 | for z in [X_train, X_test]: 52 | ndata = z.shape[0] 53 | gap = 1./(ndata+1) 54 | for i in range(nfeats): 55 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 56 | 57 | def get_optim(name, net, args): 58 | if name == 'SGD': 59 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 60 | else: 61 | assert False 62 | 63 | return optimizer 64 | 65 | 66 | @ex.config 67 | def cfg(): 68 | x_index = 0 69 | y_index = 1 70 | 71 | x_flip, y_flip = False, False 72 | 73 | optim_name = 'SGD' 74 | optim_args = \ 75 | { 76 | 'lr': 1e-5, 77 | 'momentum': 0.9 78 | } 79 | num_epochs = 10000000 80 | batch_size = 200 81 | chkpt_freq = 50 82 | 83 | from phi_listing import ClaytonPhi 84 | Phi = ClaytonPhi 85 | phi_name = 'ClaytonPhi' 86 | 87 | # Initial parameters. 88 | initial_theta = 1.1 89 | 90 | # Fraction of random data added 91 | frac_rand = 0.01 92 | 93 | @ex.capture 94 | def get_info(_run): 95 | return _run._id 96 | 97 | 98 | def expt(train_data, val_data, 99 | net, 100 | optim_name, 101 | optim_args, 102 | identifier, 103 | num_epochs=1000, 104 | batch_size=100, 105 | chkpt_freq=50, 106 | ): 107 | 108 | os.mkdir('./checkpoints/%s' % identifier) 109 | os.mkdir('./sample_figs/%s' % identifier) 110 | 111 | train_loader = DataLoader( 112 | train_data, batch_size=batch_size, shuffle=True) 113 | val_loader = DataLoader( 114 | val_data, batch_size=1000000, shuffle=True) 115 | 116 | optimizer = get_optim(optim_name, net, optim_args) 117 | 118 | train_loss_per_epoch = [] 119 | 120 | for epoch in range(num_epochs): 121 | loss_per_minibatch = [] 122 | for i, data in enumerate(train_loader, 0): 123 | optimizer.zero_grad() 124 | 125 | d = torch.tensor(data, requires_grad=True) 126 | p = net(d, mode='pdf') 127 | 128 | logloss = -torch.sum(torch.log(p)) 129 | reg_loss = logloss 130 | reg_loss.backward() 131 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 132 | 133 | loss_per_minibatch.append(scalar_loss) 134 | optimizer.step() 135 | 136 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 137 | print('Training loss at epoch %s: %s' % 138 | (epoch, train_loss_per_epoch[-1])) 139 | 140 | if epoch % chkpt_freq == 0: 141 | print('Checkpointing') 142 | torch.save({ 143 | 'epoch': epoch, 144 | 'model_state_dict': net.state_dict(), 145 | 'optimizer_state_dict': optimizer.state_dict(), 146 | 'loss': logloss, 147 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 148 | 149 | print('Evaluating validation loss') 150 | for j, val_data in enumerate(val_loader, 0): 151 | net.zero_grad() 152 | val_p = net(val_data, mode='pdf') 153 | val_loss = -torch.mean(torch.log(val_p)) 154 | print('Average validation loss %s' % val_loss) 155 | 156 | 157 | @ex.automain 158 | def run(x_index, y_index, 159 | x_flip, y_flip, 160 | Phi, 161 | initial_theta, 162 | optim_name, optim_args, 163 | num_epochs, batch_size, chkpt_freq, 164 | frac_rand): 165 | id = get_info() 166 | identifier_id = '%s%s' % (identifier, id) 167 | 168 | train_data = X_train[:, [x_index, y_index]] 169 | train_data = add_train_random_noise(train_data, 170 | int(train_data.shape[0]*frac_rand)) 171 | test_data = X_test[:, [x_index, y_index]] 172 | 173 | if x_flip: 174 | train_data[:, 0] = 1-train_data[:, 0] 175 | test_data[:, 0] = 1-test_data[:, 0] 176 | 177 | if y_flip: 178 | train_data[:, 1] = 1-train_data[:, 1] 179 | test_data[:, 1] = 1-test_data[:, 1] 180 | 181 | phi = Phi(torch.tensor(initial_theta)) 182 | net = Copula(phi) 183 | expt(train_data, test_data, net, optim_name, 184 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 185 | 186 | 187 | if __name__ == '__main__': 188 | print('Sample usage: python -m train_scripts.rdj.train_with_clayton -F learn_rdj_with_clayton') 189 | -------------------------------------------------------------------------------- /train_scripts/rdj/train_with_frank.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | from main import Copula 3 | from torch.utils.data import DataLoader 4 | import torch 5 | import torch.optim as optim 6 | import numpy as np 7 | import os 8 | from main import sample 9 | from sklearn.model_selection import train_test_split 10 | import scipy 11 | 12 | 13 | intel_f = open('data/rdj/INTEL.data', 'r') 14 | intel = np.array(list(map(float, intel_f.readlines()))) 15 | 16 | ms_f = open('data/rdj/MS.data', 'r') 17 | ms = np.array(list(map(float, ms_f.readlines()))) 18 | 19 | ge_f = open('data/rdj/GE.data', 'r') 20 | ge = np.array(list(map(float, ge_f.readlines()))) 21 | 22 | identifier = 'rdj_stocks_frank' 23 | ex = Experiment('rdj_stocks_frank') 24 | 25 | torch.set_default_tensor_type(torch.DoubleTensor) 26 | 27 | X = np.concatenate((intel[:, None], ms[:, None]), axis=1) 28 | 29 | 30 | def add_train_random_noise(data, num_adds): 31 | new_data = np.random.rand(num_adds, data.shape[1]) 32 | print(data.shape) 33 | print(new_data.shape) 34 | return np.concatenate((data, new_data), axis=0) 35 | 36 | 37 | X_train, X_test, _, _ = train_test_split( 38 | X, X, shuffle=True, random_state=142857) 39 | # X_train, X_test, _, _ = train_test_split( 40 | # X, X, shuffle=True, random_state=714285) 41 | # X_train, X_test, _, _ = train_test_split( 42 | # X, X, shuffle=True, random_state=571428) 43 | # X_train, X_test, _, _ = train_test_split( 44 | # X, X, shuffle=True, random_state=857142) 45 | # X_train, X_test, _, _ = train_test_split( 46 | # X, X, shuffle=True, random_state=285714) 47 | 48 | nfeats = X_test.shape[1] 49 | 50 | # Normalize data. 51 | for z in [X_train, X_test]: 52 | ndata = z.shape[0] 53 | gap = 1./(ndata+1) 54 | for i in range(nfeats): 55 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 56 | 57 | def get_optim(name, net, args): 58 | if name == 'SGD': 59 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 60 | else: 61 | assert False 62 | 63 | return optimizer 64 | 65 | 66 | @ex.config 67 | def cfg(): 68 | x_index = 0 69 | y_index = 1 70 | 71 | x_flip, y_flip = False, False 72 | 73 | optim_name = 'SGD' 74 | optim_args = \ 75 | { 76 | 'lr': 1e-5, 77 | 'momentum': 0.9 78 | } 79 | num_epochs = 10000000 80 | batch_size = 200 81 | chkpt_freq = 50 82 | 83 | from phi_listing import FrankPhi 84 | Phi = FrankPhi 85 | phi_name = 'FrankPhi' 86 | 87 | # Initial parameters. 88 | initial_theta = 1.1 89 | 90 | # Fraction of random data added 91 | frac_rand = 0.01 92 | 93 | @ex.capture 94 | def get_info(_run): 95 | return _run._id 96 | 97 | 98 | def expt(train_data, val_data, 99 | net, 100 | optim_name, 101 | optim_args, 102 | identifier, 103 | num_epochs=1000, 104 | batch_size=100, 105 | chkpt_freq=50, 106 | ): 107 | 108 | os.mkdir('./checkpoints/%s' % identifier) 109 | os.mkdir('./sample_figs/%s' % identifier) 110 | 111 | train_loader = DataLoader( 112 | train_data, batch_size=batch_size, shuffle=True) 113 | val_loader = DataLoader( 114 | val_data, batch_size=1000000, shuffle=True) 115 | 116 | optimizer = get_optim(optim_name, net, optim_args) 117 | 118 | train_loss_per_epoch = [] 119 | 120 | for epoch in range(num_epochs): 121 | loss_per_minibatch = [] 122 | for i, data in enumerate(train_loader, 0): 123 | optimizer.zero_grad() 124 | 125 | d = torch.tensor(data, requires_grad=True) 126 | p = net(d, mode='pdf') 127 | 128 | logloss = -torch.sum(torch.log(p)) 129 | reg_loss = logloss 130 | reg_loss.backward() 131 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 132 | 133 | loss_per_minibatch.append(scalar_loss) 134 | optimizer.step() 135 | 136 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 137 | print('Training loss at epoch %s: %s' % 138 | (epoch, train_loss_per_epoch[-1])) 139 | 140 | if epoch % chkpt_freq == 0: 141 | print('Checkpointing') 142 | torch.save({ 143 | 'epoch': epoch, 144 | 'model_state_dict': net.state_dict(), 145 | 'optimizer_state_dict': optimizer.state_dict(), 146 | 'loss': logloss, 147 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 148 | 149 | print('Evaluating validation loss') 150 | for j, val_data in enumerate(val_loader, 0): 151 | net.zero_grad() 152 | val_p = net(val_data, mode='pdf') 153 | val_loss = -torch.mean(torch.log(val_p)) 154 | print('Average validation loss %s' % val_loss) 155 | 156 | 157 | @ex.automain 158 | def run(x_index, y_index, 159 | x_flip, y_flip, 160 | Phi, 161 | initial_theta, 162 | optim_name, optim_args, 163 | num_epochs, batch_size, chkpt_freq, 164 | frac_rand): 165 | id = get_info() 166 | identifier_id = '%s%s' % (identifier, id) 167 | 168 | train_data = X_train[:, [x_index, y_index]] 169 | train_data = add_train_random_noise(train_data, 170 | int(train_data.shape[0]*frac_rand)) 171 | test_data = X_test[:, [x_index, y_index]] 172 | 173 | if x_flip: 174 | train_data[:, 0] = 1-train_data[:, 0] 175 | test_data[:, 0] = 1-test_data[:, 0] 176 | 177 | if y_flip: 178 | train_data[:, 1] = 1-train_data[:, 1] 179 | test_data[:, 1] = 1-test_data[:, 1] 180 | 181 | phi = Phi(torch.tensor(initial_theta)) 182 | net = Copula(phi) 183 | expt(train_data, test_data, net, optim_name, 184 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 185 | 186 | 187 | if __name__ == '__main__': 188 | print('Sample usage: python -m train_scripts.rdj.train_with_frank -F learn_rdj_with_frank') 189 | -------------------------------------------------------------------------------- /train_scripts/rdj/train_with_gumbel.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | from main import Copula 3 | from torch.utils.data import DataLoader 4 | import torch 5 | import torch.optim as optim 6 | import numpy as np 7 | import os 8 | from main import sample 9 | from sklearn.model_selection import train_test_split 10 | import scipy 11 | 12 | 13 | intel_f = open('data/rdj/INTEL.data', 'r') 14 | intel = np.array(list(map(float, intel_f.readlines()))) 15 | 16 | ms_f = open('data/rdj/MS.data', 'r') 17 | ms = np.array(list(map(float, ms_f.readlines()))) 18 | 19 | ge_f = open('data/rdj/GE.data', 'r') 20 | ge = np.array(list(map(float, ge_f.readlines()))) 21 | 22 | identifier = 'rdj_stocks_gumbel' 23 | ex = Experiment('rdj_stocks_gumbel') 24 | 25 | torch.set_default_tensor_type(torch.DoubleTensor) 26 | 27 | X = np.concatenate((intel[:, None], ms[:, None]), axis=1) 28 | 29 | 30 | def add_train_random_noise(data, num_adds): 31 | new_data = np.random.rand(num_adds, data.shape[1]) 32 | print(data.shape) 33 | print(new_data.shape) 34 | return np.concatenate((data, new_data), axis=0) 35 | 36 | 37 | X_train, X_test, _, _ = train_test_split( 38 | X, X, shuffle=True, random_state=142857) 39 | # X_train, X_test, _, _ = train_test_split( 40 | # X, X, shuffle=True, random_state=714285) 41 | # X_train, X_test, _, _ = train_test_split( 42 | # X, X, shuffle=True, random_state=571428) 43 | # X_train, X_test, _, _ = train_test_split( 44 | # X, X, shuffle=True, random_state=857142) 45 | # X_train, X_test, _, _ = train_test_split( 46 | # X, X, shuffle=True, random_state=285714) 47 | 48 | nfeats = X_test.shape[1] 49 | 50 | # Normalize data. 51 | for z in [X_train, X_test]: 52 | ndata = z.shape[0] 53 | gap = 1./(ndata+1) 54 | for i in range(nfeats): 55 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 56 | 57 | def get_optim(name, net, args): 58 | if name == 'SGD': 59 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 60 | else: 61 | assert False 62 | 63 | return optimizer 64 | 65 | 66 | @ex.config 67 | def cfg(): 68 | x_index = 0 69 | y_index = 1 70 | 71 | x_flip, y_flip = False, False 72 | 73 | optim_name = 'SGD' 74 | optim_args = \ 75 | { 76 | 'lr': 1e-5, 77 | 'momentum': 0.9 78 | } 79 | num_epochs = 10000000 80 | batch_size = 200 81 | chkpt_freq = 50 82 | 83 | from phi_listing import GumbelPhi 84 | Phi = GumbelPhi 85 | phi_name = 'GumbelPhi' 86 | 87 | # Initial parameters. 88 | initial_theta = 1.1 89 | 90 | # Fraction of random data added 91 | frac_rand = 0.01 92 | 93 | @ex.capture 94 | def get_info(_run): 95 | return _run._id 96 | 97 | 98 | def expt(train_data, val_data, 99 | net, 100 | optim_name, 101 | optim_args, 102 | identifier, 103 | num_epochs=1000, 104 | batch_size=100, 105 | chkpt_freq=50, 106 | ): 107 | 108 | os.mkdir('./checkpoints/%s' % identifier) 109 | os.mkdir('./sample_figs/%s' % identifier) 110 | 111 | train_loader = DataLoader( 112 | train_data, batch_size=batch_size, shuffle=True) 113 | val_loader = DataLoader( 114 | val_data, batch_size=1000000, shuffle=True) 115 | 116 | optimizer = get_optim(optim_name, net, optim_args) 117 | 118 | train_loss_per_epoch = [] 119 | 120 | for epoch in range(num_epochs): 121 | loss_per_minibatch = [] 122 | for i, data in enumerate(train_loader, 0): 123 | optimizer.zero_grad() 124 | 125 | d = torch.tensor(data, requires_grad=True) 126 | p = net(d, mode='pdf') 127 | 128 | logloss = -torch.sum(torch.log(p)) 129 | reg_loss = logloss 130 | reg_loss.backward() 131 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 132 | 133 | loss_per_minibatch.append(scalar_loss) 134 | optimizer.step() 135 | 136 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 137 | print('Training loss at epoch %s: %s' % 138 | (epoch, train_loss_per_epoch[-1])) 139 | 140 | if epoch % chkpt_freq == 0: 141 | print('Checkpointing') 142 | torch.save({ 143 | 'epoch': epoch, 144 | 'model_state_dict': net.state_dict(), 145 | 'optimizer_state_dict': optimizer.state_dict(), 146 | 'loss': logloss, 147 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 148 | 149 | print('Evaluating validation loss') 150 | for j, val_data in enumerate(val_loader, 0): 151 | net.zero_grad() 152 | val_p = net(val_data, mode='pdf') 153 | val_loss = -torch.mean(torch.log(val_p)) 154 | print('Average validation loss %s' % val_loss) 155 | 156 | 157 | @ex.automain 158 | def run(x_index, y_index, 159 | x_flip, y_flip, 160 | Phi, 161 | initial_theta, 162 | optim_name, optim_args, 163 | num_epochs, batch_size, chkpt_freq, 164 | frac_rand): 165 | id = get_info() 166 | identifier_id = '%s%s' % (identifier, id) 167 | 168 | train_data = X_train[:, [x_index, y_index]] 169 | train_data = add_train_random_noise(train_data, 170 | int(train_data.shape[0]*frac_rand)) 171 | test_data = X_test[:, [x_index, y_index]] 172 | 173 | if x_flip: 174 | train_data[:, 0] = 1-train_data[:, 0] 175 | test_data[:, 0] = 1-test_data[:, 0] 176 | 177 | if y_flip: 178 | train_data[:, 1] = 1-train_data[:, 1] 179 | test_data[:, 1] = 1-test_data[:, 1] 180 | 181 | phi = Phi(torch.tensor(initial_theta)) 182 | net = Copula(phi) 183 | expt(train_data, test_data, net, optim_name, 184 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 185 | 186 | 187 | if __name__ == '__main__': 188 | print('Sample usage: python -m train_scripts.rdj.train_with_gumbel -F learn_rdj_with_gumbel') 189 | -------------------------------------------------------------------------------- /train_scripts/stocks/train.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | import scipy 3 | from sklearn.model_selection import train_test_split 4 | from dirac_phi import DiracPhi 5 | from main import sample 6 | import os 7 | import numpy as np 8 | import torch.optim as optim 9 | import torch 10 | import matplotlib.pyplot as plt 11 | from torch.utils.data import DataLoader 12 | from main import Copula 13 | 14 | goog_f = open('data/stocks/goog/close.vals', 'r') 15 | goog = np.array(list(map(float, goog_f.readlines()))) 16 | 17 | fb_f = open('data/stocks/fb/close.vals', 'r') 18 | fb = np.array(list(map(float, fb_f.readlines()))) 19 | 20 | 21 | identifier = 'goog_fb_stocks' 22 | ex = Experiment('goog_fb_stocks') 23 | 24 | torch.set_default_tensor_type(torch.DoubleTensor) 25 | 26 | 27 | def add_train_random_noise(data, num_adds): 28 | new_data = np.random.rand(num_adds, data.shape[1]) 29 | print(data.shape) 30 | print(new_data.shape) 31 | return np.concatenate((data, new_data), axis=0) 32 | 33 | 34 | X = np.concatenate((goog[:, None], fb[:, None]), axis=1) 35 | X_train, X_test, _, _ = train_test_split( 36 | X, X, shuffle=True, random_state=142857) 37 | # X_train, X_test, _, _ = train_test_split( 38 | # X, X, shuffle=True, random_state=714285) 39 | # X_train, X_test, _, _ = train_test_split( 40 | # X, X, shuffle=True, random_state=571428) 41 | # X_train, X_test, _, _ = train_test_split( 42 | # X, X, shuffle=True, random_state=857142) 43 | # X_train, X_test, _, _ = train_test_split( 44 | # X, X, shuffle=True, random_state=285714) 45 | 46 | 47 | nfeats = X_test.shape[1] 48 | 49 | # Normalize data. 50 | for z in [X_train, X_test]: 51 | ndata = z.shape[0] 52 | gap = 1./(ndata+1) 53 | for i in range(nfeats): 54 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 55 | 56 | w = X.copy() 57 | ndata = w.shape[0] 58 | gap = 1./(ndata+1) 59 | for i in range(nfeats): 60 | w[:, i] = scipy.stats.rankdata(w[:, i], 'ordinal')*gap 61 | # np.savetxt('stocks_transformed.points', w) 62 | 63 | 64 | def get_optim(name, net, args): 65 | if name == 'SGD': 66 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 67 | else: 68 | assert False 69 | 70 | return optimizer 71 | 72 | 73 | @ex.config 74 | def cfg(): 75 | x_index = 0 76 | y_index = 1 77 | 78 | x_flip, y_flip = False, False 79 | 80 | optim_name = 'SGD' 81 | optim_args = \ 82 | { 83 | 'lr': 1e-5, 84 | 'momentum': 0.9 85 | } 86 | num_epochs = 10000000 87 | batch_size = 200 88 | chkpt_freq = 500 89 | 90 | Phi = DiracPhi 91 | phi_name = 'DiracPhi' 92 | 93 | # Initial parameters. 94 | depth = 2 95 | widths = [10, 10] 96 | lc_w_range = (0, 1.0) 97 | shift_w_range = (0., 2.0) 98 | 99 | # Fraction of random noise 100 | frac_rand = 0.01 101 | 102 | 103 | @ex.capture 104 | def get_info(_run): 105 | return _run._id 106 | 107 | 108 | def expt(train_data, val_data, 109 | net, 110 | optim_name, 111 | optim_args, 112 | identifier, 113 | num_epochs=1000, 114 | batch_size=100, 115 | chkpt_freq=50, 116 | ): 117 | 118 | os.mkdir('./checkpoints/%s' % identifier) 119 | os.mkdir('./sample_figs/%s' % identifier) 120 | 121 | train_loader = DataLoader( 122 | train_data, batch_size=batch_size, shuffle=True) 123 | val_loader = DataLoader( 124 | val_data, batch_size=1000000, shuffle=True) 125 | 126 | optimizer = get_optim(optim_name, net, optim_args) 127 | 128 | train_loss_per_epoch = [] 129 | 130 | for epoch in range(num_epochs): 131 | loss_per_minibatch = [] 132 | for i, data in enumerate(train_loader, 0): 133 | optimizer.zero_grad() 134 | 135 | d = torch.tensor(data, requires_grad=True) 136 | p = net(d, mode='pdf') 137 | 138 | logloss = -torch.sum(torch.log(p)) 139 | reg_loss = logloss 140 | reg_loss.backward() 141 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 142 | 143 | loss_per_minibatch.append(scalar_loss) 144 | optimizer.step() 145 | 146 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 147 | print('Training loss at epoch %s: %s' % 148 | (epoch, train_loss_per_epoch[-1])) 149 | 150 | if epoch % chkpt_freq == 0: 151 | print('Checkpointing') 152 | torch.save({ 153 | 'epoch': epoch, 154 | 'model_state_dict': net.state_dict(), 155 | 'optimizer_state_dict': optimizer.state_dict(), 156 | 'loss': logloss, 157 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 158 | 159 | if True: 160 | print('Scatter sampling') 161 | samples = sample(net, 2, 1000) 162 | plt.scatter(samples[:, 0], samples[:, 1]) 163 | plt.savefig('./sample_figs/%s/epoch%s.png' % 164 | (identifier, epoch)) 165 | plt.clf() 166 | 167 | 168 | print('Evaluating validation loss') 169 | for j, val_data in enumerate(val_loader, 0): 170 | net.zero_grad() 171 | val_p = net(val_data, mode='pdf') 172 | val_loss = -torch.mean(torch.log(val_p)) 173 | print('Average validation loss %s' % val_loss) 174 | 175 | 176 | @ex.automain 177 | def run(x_index, y_index, 178 | x_flip, y_flip, 179 | Phi, 180 | depth, widths, lc_w_range, shift_w_range, 181 | optim_name, optim_args, 182 | num_epochs, batch_size, chkpt_freq, 183 | frac_rand): 184 | id = get_info() 185 | identifier_id = '%s%s' % (identifier, id) 186 | 187 | train_data = X_train[:, [x_index, y_index]] 188 | train_data = add_train_random_noise(train_data, int(train_data.shape[0]*frac_rand)) 189 | test_data = X_test[:, [x_index, y_index]] 190 | 191 | if x_flip: 192 | train_data[:, 0] = 1-train_data[:, 0] 193 | test_data[:, 0] = 1-test_data[:, 0] 194 | 195 | if y_flip: 196 | train_data[:, 1] = 1-train_data[:, 1] 197 | test_data[:, 1] = 1-test_data[:, 1] 198 | 199 | phi = Phi(depth, widths, lc_w_range, shift_w_range) 200 | net = Copula(phi) 201 | expt(train_data, test_data, net, optim_name, 202 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 203 | 204 | 205 | if __name__ == '__main__': 206 | print('Sample usage: python -m train_scripts.stocks.train -F learn_stocks') 207 | -------------------------------------------------------------------------------- /train_scripts/stocks/train_with_clayton.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | import scipy 3 | from sklearn.model_selection import train_test_split 4 | import os 5 | import numpy as np 6 | import torch.optim as optim 7 | import torch 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import DataLoader 10 | from main import Copula 11 | 12 | goog_f = open('data/stocks/goog/close.vals', 'r') 13 | goog = np.array(list(map(float, goog_f.readlines()))) 14 | 15 | fb_f = open('data/stocks/fb/close.vals', 'r') 16 | fb = np.array(list(map(float, fb_f.readlines()))) 17 | 18 | 19 | identifier = 'goog_fb_stocks_clayton' 20 | ex = Experiment('goog_fb_stocks_clayton') 21 | 22 | torch.set_default_tensor_type(torch.DoubleTensor) 23 | 24 | 25 | def add_train_random_noise(data, num_adds): 26 | new_data = np.random.rand(num_adds, data.shape[1]) 27 | print(data.shape) 28 | print(new_data.shape) 29 | return np.concatenate((data, new_data), axis=0) 30 | 31 | 32 | X = np.concatenate((goog[:, None], fb[:, None]), axis=1) 33 | X_train, X_test, _, _ = train_test_split( 34 | X, X, shuffle=True, random_state=142857) 35 | # X_train, X_test, _, _ = train_test_split( 36 | # X, X, shuffle=True, random_state=714285) 37 | # X_train, X_test, _, _ = train_test_split( 38 | # X, X, shuffle=True, random_state=571428) 39 | # X_train, X_test, _, _ = train_test_split( 40 | # X, X, shuffle=True, random_state=857142) 41 | # X_train, X_test, _, _ = train_test_split( 42 | # X, X, shuffle=True, random_state=285714) 43 | 44 | nfeats = X_test.shape[1] 45 | 46 | # Normalize data. 47 | for z in [X_train, X_test]: 48 | ndata = z.shape[0] 49 | gap = 1./(ndata+1) 50 | for i in range(nfeats): 51 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 52 | 53 | w = X.copy() 54 | ndata = w.shape[0] 55 | gap = 1./(ndata+1) 56 | for i in range(nfeats): 57 | w[:, i] = scipy.stats.rankdata(w[:, i], 'ordinal')*gap 58 | # np.savetxt('stocks_transformed.points', w) 59 | 60 | 61 | def get_optim(name, net, args): 62 | if name == 'SGD': 63 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 64 | else: 65 | assert False 66 | 67 | return optimizer 68 | 69 | 70 | @ex.config 71 | def cfg(): 72 | x_index = 0 73 | y_index = 1 74 | 75 | x_flip, y_flip = False, False 76 | 77 | optim_name = 'SGD' 78 | optim_args = \ 79 | { 80 | 'lr': 1e-5, 81 | 'momentum': 0.9 82 | } 83 | num_epochs = 10000000 84 | # batch_size = 1000 85 | batch_size = 200 86 | chkpt_freq = 500 87 | 88 | from phi_listing import ClaytonPhi 89 | Phi = ClaytonPhi 90 | phi_name = 'ClaytonPhi' 91 | 92 | # Initial parameters. 93 | initial_theta = 5.0 94 | 95 | frac_rand = 0.01 96 | 97 | @ex.capture 98 | def get_info(_run): 99 | return _run._id 100 | 101 | 102 | def expt(train_data, val_data, 103 | net, 104 | optim_name, 105 | optim_args, 106 | identifier, 107 | num_epochs=1000, 108 | batch_size=100, 109 | chkpt_freq=50, 110 | ): 111 | 112 | os.mkdir('./checkpoints/%s' % identifier) 113 | os.mkdir('./sample_figs/%s' % identifier) 114 | # os.mkdir('./psi_figs/%s' % identifier) 115 | 116 | train_loader = DataLoader( 117 | train_data, batch_size=batch_size, shuffle=True) 118 | val_loader = DataLoader( 119 | val_data, batch_size=1000000, shuffle=True) 120 | 121 | optimizer = get_optim(optim_name, net, optim_args) 122 | 123 | train_loss_per_epoch = [] 124 | 125 | for epoch in range(num_epochs): 126 | loss_per_minibatch = [] 127 | for i, data in enumerate(train_loader, 0): 128 | optimizer.zero_grad() 129 | 130 | d = torch.tensor(data, requires_grad=True) 131 | p = net(d, mode='pdf') 132 | 133 | logloss = -torch.sum(torch.log(p)) 134 | reg_loss = logloss 135 | reg_loss.backward() 136 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 137 | 138 | loss_per_minibatch.append(scalar_loss) 139 | optimizer.step() 140 | 141 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 142 | print('Training loss at epoch %s: %s' % 143 | (epoch, train_loss_per_epoch[-1])) 144 | 145 | if epoch % chkpt_freq == 0: 146 | print('Checkpointing') 147 | torch.save({ 148 | 'epoch': epoch, 149 | 'model_state_dict': net.state_dict(), 150 | 'optimizer_state_dict': optimizer.state_dict(), 151 | 'loss': logloss, 152 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 153 | 154 | print('Evaluating validation loss') 155 | for j, val_data in enumerate(val_loader, 0): 156 | net.zero_grad() 157 | val_p = net(val_data, mode='pdf') 158 | val_loss = -torch.mean(torch.log(val_p)) 159 | print('Average validation loss %s' % val_loss) 160 | 161 | 162 | @ex.automain 163 | def run(x_index, y_index, 164 | x_flip, y_flip, 165 | Phi, 166 | initial_theta, 167 | optim_name, optim_args, 168 | num_epochs, batch_size, chkpt_freq, 169 | frac_rand): 170 | id = get_info() 171 | identifier_id = '%s%s' % (identifier, id) 172 | 173 | train_data = X_train[:, [x_index, y_index]] 174 | train_data = add_train_random_noise(train_data, int(train_data.shape[0]*frac_rand)) 175 | test_data = X_test[:, [x_index, y_index]] 176 | 177 | if x_flip: 178 | train_data[:, 0] = 1-train_data[:, 0] 179 | test_data[:, 0] = 1-test_data[:, 0] 180 | 181 | if y_flip: 182 | train_data[:, 1] = 1-train_data[:, 1] 183 | test_data[:, 1] = 1-test_data[:, 1] 184 | 185 | phi = Phi(torch.tensor(initial_theta)) 186 | net = Copula(phi) 187 | expt(train_data, test_data, net, optim_name, 188 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 189 | 190 | 191 | if __name__ == '__main__': 192 | print('Sample usage: python -m train_scripts.stocks.train_with_clayton -F learn_stocks_with_clayton') 193 | -------------------------------------------------------------------------------- /train_scripts/stocks/train_with_frank.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | import scipy 3 | from sklearn.model_selection import train_test_split 4 | import os 5 | import numpy as np 6 | import torch.optim as optim 7 | import torch 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import DataLoader 10 | from main import Copula 11 | 12 | goog_f = open('data/stocks/goog/close.vals', 'r') 13 | goog = np.array(list(map(float, goog_f.readlines()))) 14 | 15 | fb_f = open('data/stocks/fb/close.vals', 'r') 16 | fb = np.array(list(map(float, fb_f.readlines()))) 17 | 18 | 19 | identifier = 'goog_fb_stocks_frank' 20 | ex = Experiment('goog_fb_stocks_frank') 21 | 22 | torch.set_default_tensor_type(torch.DoubleTensor) 23 | 24 | 25 | def add_train_random_noise(data, num_adds): 26 | new_data = np.random.rand(num_adds, data.shape[1]) 27 | print(data.shape) 28 | print(new_data.shape) 29 | return np.concatenate((data, new_data), axis=0) 30 | 31 | 32 | X = np.concatenate((goog[:, None], fb[:, None]), axis=1) 33 | X_train, X_test, _, _ = train_test_split( 34 | X, X, shuffle=True, random_state=142857) 35 | # X_train, X_test, _, _ = train_test_split( 36 | # X, X, shuffle=True, random_state=714285) 37 | # X_train, X_test, _, _ = train_test_split( 38 | # X, X, shuffle=True, random_state=571428) 39 | # X_train, X_test, _, _ = train_test_split( 40 | # X, X, shuffle=True, random_state=857142) 41 | # X_train, X_test, _, _ = train_test_split( 42 | # X, X, shuffle=True, random_state=285714) 43 | 44 | nfeats = X_test.shape[1] 45 | 46 | # Normalize data. 47 | for z in [X_train, X_test]: 48 | ndata = z.shape[0] 49 | gap = 1./(ndata+1) 50 | for i in range(nfeats): 51 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 52 | 53 | w = X.copy() 54 | ndata = w.shape[0] 55 | gap = 1./(ndata+1) 56 | for i in range(nfeats): 57 | w[:, i] = scipy.stats.rankdata(w[:, i], 'ordinal')*gap 58 | # np.savetxt('stocks_transformed.points', w) 59 | 60 | 61 | def get_optim(name, net, args): 62 | if name == 'SGD': 63 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 64 | else: 65 | assert False 66 | 67 | return optimizer 68 | 69 | 70 | @ex.config 71 | def cfg(): 72 | x_index = 0 73 | y_index = 1 74 | 75 | x_flip, y_flip = False, False 76 | 77 | optim_name = 'SGD' 78 | optim_args = \ 79 | { 80 | 'lr': 1e-5, 81 | 'momentum': 0.9 82 | } 83 | num_epochs = 10000000 84 | # batch_size = 1000 85 | batch_size = 200 86 | chkpt_freq = 500 87 | 88 | from phi_listing import FrankPhi 89 | Phi = FrankPhi 90 | phi_name = 'FrankPhi' 91 | 92 | # Initial parameters. 93 | initial_theta = 5.0 94 | 95 | frac_rand = 0.01 96 | 97 | @ex.capture 98 | def get_info(_run): 99 | return _run._id 100 | 101 | 102 | def expt(train_data, val_data, 103 | net, 104 | optim_name, 105 | optim_args, 106 | identifier, 107 | num_epochs=1000, 108 | batch_size=100, 109 | chkpt_freq=50, 110 | ): 111 | 112 | os.mkdir('./checkpoints/%s' % identifier) 113 | os.mkdir('./sample_figs/%s' % identifier) 114 | # os.mkdir('./psi_figs/%s' % identifier) 115 | 116 | train_loader = DataLoader( 117 | train_data, batch_size=batch_size, shuffle=True) 118 | val_loader = DataLoader( 119 | val_data, batch_size=1000000, shuffle=True) 120 | 121 | optimizer = get_optim(optim_name, net, optim_args) 122 | 123 | train_loss_per_epoch = [] 124 | 125 | for epoch in range(num_epochs): 126 | loss_per_minibatch = [] 127 | for i, data in enumerate(train_loader, 0): 128 | optimizer.zero_grad() 129 | 130 | d = torch.tensor(data, requires_grad=True) 131 | p = net(d, mode='pdf') 132 | 133 | logloss = -torch.sum(torch.log(p)) 134 | reg_loss = logloss 135 | reg_loss.backward() 136 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 137 | 138 | loss_per_minibatch.append(scalar_loss) 139 | optimizer.step() 140 | 141 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 142 | print('Training loss at epoch %s: %s' % 143 | (epoch, train_loss_per_epoch[-1])) 144 | 145 | if epoch % chkpt_freq == 0: 146 | print('Checkpointing') 147 | torch.save({ 148 | 'epoch': epoch, 149 | 'model_state_dict': net.state_dict(), 150 | 'optimizer_state_dict': optimizer.state_dict(), 151 | 'loss': logloss, 152 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 153 | 154 | print('Evaluating validation loss') 155 | for j, val_data in enumerate(val_loader, 0): 156 | net.zero_grad() 157 | val_p = net(val_data, mode='pdf') 158 | val_loss = -torch.mean(torch.log(val_p)) 159 | print('Average validation loss %s' % val_loss) 160 | 161 | 162 | @ex.automain 163 | def run(x_index, y_index, 164 | x_flip, y_flip, 165 | Phi, 166 | initial_theta, 167 | optim_name, optim_args, 168 | num_epochs, batch_size, chkpt_freq, 169 | frac_rand): 170 | id = get_info() 171 | identifier_id = '%s%s' % (identifier, id) 172 | 173 | train_data = X_train[:, [x_index, y_index]] 174 | train_data = add_train_random_noise(train_data, int(train_data.shape[0]*frac_rand)) 175 | test_data = X_test[:, [x_index, y_index]] 176 | 177 | if x_flip: 178 | train_data[:, 0] = 1-train_data[:, 0] 179 | test_data[:, 0] = 1-test_data[:, 0] 180 | 181 | if y_flip: 182 | train_data[:, 1] = 1-train_data[:, 1] 183 | test_data[:, 1] = 1-test_data[:, 1] 184 | 185 | phi = Phi(torch.tensor(initial_theta)) 186 | net = Copula(phi) 187 | expt(train_data, test_data, net, optim_name, 188 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 189 | 190 | 191 | if __name__ == '__main__': 192 | print('Sample usage: python -m train_scripts.stocks.train_with_frank -F learn_stocks_with_frank') 193 | -------------------------------------------------------------------------------- /train_scripts/stocks/train_with_gumbel.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | import scipy 3 | from sklearn.model_selection import train_test_split 4 | import os 5 | import numpy as np 6 | import torch.optim as optim 7 | import torch 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import DataLoader 10 | from main import Copula 11 | 12 | goog_f = open('data/stocks/goog/close.vals', 'r') 13 | goog = np.array(list(map(float, goog_f.readlines()))) 14 | 15 | fb_f = open('data/stocks/fb/close.vals', 'r') 16 | fb = np.array(list(map(float, fb_f.readlines()))) 17 | 18 | 19 | identifier = 'goog_fb_stocks_gumbel' 20 | ex = Experiment('goog_fb_stocks_gumbel') 21 | 22 | torch.set_default_tensor_type(torch.DoubleTensor) 23 | 24 | 25 | def add_train_random_noise(data, num_adds): 26 | new_data = np.random.rand(num_adds, data.shape[1]) 27 | print(data.shape) 28 | print(new_data.shape) 29 | return np.concatenate((data, new_data), axis=0) 30 | 31 | 32 | X = np.concatenate((goog[:, None], fb[:, None]), axis=1) 33 | X_train, X_test, _, _ = train_test_split( 34 | X, X, shuffle=True, random_state=142857) 35 | # X_train, X_test, _, _ = train_test_split( 36 | # X, X, shuffle=True, random_state=714285) 37 | # X_train, X_test, _, _ = train_test_split( 38 | # X, X, shuffle=True, random_state=571428) 39 | # X_train, X_test, _, _ = train_test_split( 40 | # X, X, shuffle=True, random_state=857142) 41 | # X_train, X_test, _, _ = train_test_split( 42 | # X, X, shuffle=True, random_state=285714) 43 | 44 | nfeats = X_test.shape[1] 45 | 46 | # Normalize data. 47 | for z in [X_train, X_test]: 48 | ndata = z.shape[0] 49 | gap = 1./(ndata+1) 50 | for i in range(nfeats): 51 | z[:, i] = scipy.stats.rankdata(z[:, i], 'ordinal')*gap 52 | 53 | w = X.copy() 54 | ndata = w.shape[0] 55 | gap = 1./(ndata+1) 56 | for i in range(nfeats): 57 | w[:, i] = scipy.stats.rankdata(w[:, i], 'ordinal')*gap 58 | # np.savetxt('stocks_transformed.points', w) 59 | 60 | 61 | def get_optim(name, net, args): 62 | if name == 'SGD': 63 | optimizer = optim.SGD(net.parameters(), args['lr'], args['momentum']) 64 | else: 65 | assert False 66 | 67 | return optimizer 68 | 69 | 70 | @ex.config 71 | def cfg(): 72 | x_index = 0 73 | y_index = 1 74 | 75 | x_flip, y_flip = False, False 76 | 77 | optim_name = 'SGD' 78 | optim_args = \ 79 | { 80 | 'lr': 1e-5, 81 | 'momentum': 0.9 82 | } 83 | num_epochs = 10000000 84 | # batch_size = 1000 85 | batch_size = 200 86 | chkpt_freq = 500 87 | 88 | from phi_listing import GumbelPhi 89 | Phi = GumbelPhi 90 | phi_name = 'GumbelPhi' 91 | 92 | # Initial parameters. 93 | initial_theta = 5.0 94 | 95 | frac_rand = 0.01 96 | 97 | @ex.capture 98 | def get_info(_run): 99 | return _run._id 100 | 101 | 102 | def expt(train_data, val_data, 103 | net, 104 | optim_name, 105 | optim_args, 106 | identifier, 107 | num_epochs=1000, 108 | batch_size=100, 109 | chkpt_freq=50, 110 | ): 111 | 112 | os.mkdir('./checkpoints/%s' % identifier) 113 | os.mkdir('./sample_figs/%s' % identifier) 114 | # os.mkdir('./psi_figs/%s' % identifier) 115 | 116 | train_loader = DataLoader( 117 | train_data, batch_size=batch_size, shuffle=True) 118 | val_loader = DataLoader( 119 | val_data, batch_size=1000000, shuffle=True) 120 | 121 | optimizer = get_optim(optim_name, net, optim_args) 122 | 123 | train_loss_per_epoch = [] 124 | 125 | for epoch in range(num_epochs): 126 | loss_per_minibatch = [] 127 | for i, data in enumerate(train_loader, 0): 128 | optimizer.zero_grad() 129 | 130 | d = torch.tensor(data, requires_grad=True) 131 | p = net(d, mode='pdf') 132 | 133 | logloss = -torch.sum(torch.log(p)) 134 | reg_loss = logloss 135 | reg_loss.backward() 136 | scalar_loss = (reg_loss/p.numel()).detach().numpy().item() 137 | 138 | loss_per_minibatch.append(scalar_loss) 139 | optimizer.step() 140 | 141 | train_loss_per_epoch.append(np.mean(loss_per_minibatch)) 142 | print('Training loss at epoch %s: %s' % 143 | (epoch, train_loss_per_epoch[-1])) 144 | 145 | if epoch % chkpt_freq == 0: 146 | print('Checkpointing') 147 | torch.save({ 148 | 'epoch': epoch, 149 | 'model_state_dict': net.state_dict(), 150 | 'optimizer_state_dict': optimizer.state_dict(), 151 | 'loss': logloss, 152 | }, './checkpoints/%s/epoch%s' % (identifier, epoch)) 153 | 154 | print('Evaluating validation loss') 155 | for j, val_data in enumerate(val_loader, 0): 156 | net.zero_grad() 157 | val_p = net(val_data, mode='pdf') 158 | val_loss = -torch.mean(torch.log(val_p)) 159 | print('Average validation loss %s' % val_loss) 160 | 161 | 162 | @ex.automain 163 | def run(x_index, y_index, 164 | x_flip, y_flip, 165 | Phi, 166 | initial_theta, 167 | optim_name, optim_args, 168 | num_epochs, batch_size, chkpt_freq, 169 | frac_rand): 170 | id = get_info() 171 | identifier_id = '%s%s' % (identifier, id) 172 | 173 | train_data = X_train[:, [x_index, y_index]] 174 | train_data = add_train_random_noise(train_data, int(train_data.shape[0]*frac_rand)) 175 | test_data = X_test[:, [x_index, y_index]] 176 | 177 | if x_flip: 178 | train_data[:, 0] = 1-train_data[:, 0] 179 | test_data[:, 0] = 1-test_data[:, 0] 180 | 181 | if y_flip: 182 | train_data[:, 1] = 1-train_data[:, 1] 183 | test_data[:, 1] = 1-test_data[:, 1] 184 | 185 | phi = Phi(torch.tensor(initial_theta)) 186 | net = Copula(phi) 187 | expt(train_data, test_data, net, optim_name, 188 | optim_args, identifier_id, num_epochs, batch_size, chkpt_freq) 189 | 190 | 191 | if __name__ == '__main__': 192 | print('Sample usage: python -m train_scripts.stocks.train_with_gumbel -F learn_stocks_with_gumbel') 193 | --------------------------------------------------------------------------------