├── scripts ├── src │ ├── __init__.py │ ├── ICOT │ │ ├── data │ │ │ ├── X0.npy │ │ │ ├── X1.npy │ │ │ ├── X2.npy │ │ │ ├── X3.npy │ │ │ ├── X4.npy │ │ │ ├── Y0.npy │ │ │ ├── Y1.npy │ │ │ ├── Y2.npy │ │ │ ├── Y3.npy │ │ │ ├── Y4.npy │ │ │ └── ruspini.csv │ │ ├── result │ │ │ └── icot.png │ │ ├── src │ │ │ ├── runningICOT_example0.jl │ │ │ ├── runningICOT_example1.jl │ │ │ ├── runningICOT_example2.jl │ │ │ ├── runningICOT_example3.jl │ │ │ └── runningICOT_example4.jl │ │ └── README.md │ ├── leaf_model_cart.py │ └── cart_customerfeats_in_leafmod.py ├── data │ ├── P0.npy │ ├── P1.npy │ ├── P2.npy │ ├── P3.npy │ ├── P4.npy │ ├── P5.npy │ ├── P6.npy │ ├── P7.npy │ ├── P8.npy │ ├── P9.npy │ ├── PT0.npy │ ├── PT1.npy │ ├── PT2.npy │ ├── PT3.npy │ ├── PT4.npy │ ├── PT5.npy │ ├── PT6.npy │ ├── PT7.npy │ ├── PT8.npy │ ├── PT9.npy │ ├── PV0.npy │ ├── PV1.npy │ ├── PV2.npy │ ├── PV3.npy │ ├── PV4.npy │ ├── PV5.npy │ ├── PV6.npy │ ├── PV7.npy │ ├── PV8.npy │ ├── PV9.npy │ ├── X0.npy │ ├── X1.npy │ ├── X2.npy │ ├── X3.npy │ ├── X4.npy │ ├── X5.npy │ ├── X6.npy │ ├── X7.npy │ ├── X8.npy │ ├── X9.npy │ ├── XT0.npy │ ├── XT1.npy │ ├── XT2.npy │ ├── XT3.npy │ ├── XT4.npy │ ├── XT5.npy │ ├── XT6.npy │ ├── XT7.npy │ ├── XT8.npy │ ├── XT9.npy │ ├── XV0.npy │ ├── XV1.npy │ ├── XV2.npy │ ├── XV3.npy │ ├── XV4.npy │ ├── XV5.npy │ ├── XV6.npy │ ├── XV7.npy │ ├── XV8.npy │ ├── XV9.npy │ ├── X_0.npy │ ├── X_1.npy │ ├── X_2.npy │ ├── X_3.npy │ ├── X_4.npy │ ├── X_5.npy │ ├── X_6.npy │ ├── X_7.npy │ ├── X_8.npy │ ├── X_9.npy │ ├── Y0.npy │ ├── Y1.npy │ ├── Y2.npy │ ├── Y3.npy │ ├── Y4.npy │ ├── Y5.npy │ ├── Y6.npy │ ├── Y7.npy │ ├── Y8.npy │ ├── Y9.npy │ ├── YT0.npy │ ├── YT1.npy │ ├── YT2.npy │ ├── YT3.npy │ ├── YT4.npy │ ├── YT5.npy │ ├── YT6.npy │ ├── YT7.npy │ ├── YT8.npy │ ├── YT9.npy │ ├── YV0.npy │ ├── YV1.npy │ ├── YV2.npy │ ├── YV3.npy │ ├── YV4.npy │ ├── YV5.npy │ ├── YV6.npy │ ├── YV7.npy │ ├── YV8.npy │ ├── YV9.npy │ ├── P_long_0.npy │ ├── P_long_1.npy │ ├── P_long_2.npy │ ├── P_long_3.npy │ ├── P_long_4.npy │ ├── P_long_5.npy │ ├── P_long_6.npy │ ├── P_long_7.npy │ ├── P_long_8.npy │ ├── P_long_9.npy │ ├── X_long_0.npy │ ├── X_long_1.npy │ ├── X_long_2.npy │ ├── X_long_3.npy │ ├── X_long_4.npy │ ├── X_long_5.npy │ ├── X_long_6.npy │ ├── X_long_7.npy │ ├── X_long_8.npy │ ├── X_long_9.npy │ ├── Y_long_0.npy │ ├── Y_long_1.npy │ ├── Y_long_2.npy │ ├── Y_long_3.npy │ ├── Y_long_4.npy │ ├── Y_long_5.npy │ ├── Y_long_6.npy │ ├── Y_long_7.npy │ ├── Y_long_8.npy │ ├── Y_long_9.npy │ ├── PT_long_0.npy │ ├── PT_long_1.npy │ ├── PT_long_2.npy │ ├── PT_long_3.npy │ ├── PT_long_4.npy │ ├── PT_long_5.npy │ ├── PT_long_6.npy │ ├── PT_long_7.npy │ ├── PT_long_8.npy │ ├── PT_long_9.npy │ ├── PV_long_0.npy │ ├── PV_long_1.npy │ ├── PV_long_2.npy │ ├── PV_long_3.npy │ ├── PV_long_4.npy │ ├── PV_long_5.npy │ ├── PV_long_6.npy │ ├── PV_long_7.npy │ ├── PV_long_8.npy │ ├── PV_long_9.npy │ ├── XT_long_0.npy │ ├── XT_long_1.npy │ ├── XT_long_2.npy │ ├── XT_long_3.npy │ ├── XT_long_4.npy │ ├── XT_long_5.npy │ ├── XT_long_6.npy │ ├── XT_long_7.npy │ ├── XT_long_8.npy │ ├── XT_long_9.npy │ ├── XV_long_0.npy │ ├── XV_long_1.npy │ ├── XV_long_2.npy │ ├── XV_long_3.npy │ ├── XV_long_4.npy │ ├── XV_long_5.npy │ ├── XV_long_6.npy │ ├── XV_long_7.npy │ ├── XV_long_8.npy │ ├── XV_long_9.npy │ ├── YT_long_0.npy │ ├── YT_long_1.npy │ ├── YT_long_2.npy │ ├── YT_long_3.npy │ ├── YT_long_4.npy │ ├── YT_long_5.npy │ ├── YT_long_6.npy │ ├── YT_long_7.npy │ ├── YT_long_8.npy │ ├── YT_long_9.npy │ ├── YV_long_0.npy │ ├── YV_long_1.npy │ ├── YV_long_2.npy │ ├── YV_long_3.npy │ ├── YV_long_4.npy │ ├── YV_long_5.npy │ ├── YV_long_6.npy │ ├── YV_long_7.npy │ ├── YV_long_8.npy │ └── YV_long_9.npy ├── outputs │ ├── cmt.npy │ ├── cmt+.npy │ ├── mnldt.npy │ ├── mnlkm.npy │ ├── mnldt+.npy │ ├── mnlicot+.npy │ ├── mnlicot.npy │ ├── mnlint.npy │ ├── mnlkm+.npy │ ├── tastenet.npy │ ├── plot_is_mse.png │ ├── plot_os_mse.png │ ├── plot_os_nll.png │ ├── cmt_refactor.py.npy │ ├── plot_is_nll.nll.png │ ├── cmt+_refactor.py.npy │ ├── varying_cmt_d15_m50_mod0.npy │ ├── varying_mnlkm_k30_5_10_mod0.npy │ └── fitted cmts │ │ └── CMTtree3.txt ├── plots.py ├── README.md ├── prepare_data.py ├── mnlkm+.py ├── mnlkm.py ├── cmt+.py ├── cmt.py ├── mnldt+.py ├── mnldt.py ├── mnlicot+.py ├── mnlicot.py └── mnlint.py ├── LICENSE ├── GenMNL.py ├── irt_example.py ├── README.md ├── cmt_example.py ├── leaf_model_template.py ├── leaf_model_mnl.py └── leaf_model_isoreg.py /scripts/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/data/P0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P0.npy -------------------------------------------------------------------------------- /scripts/data/P1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P1.npy -------------------------------------------------------------------------------- /scripts/data/P2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P2.npy -------------------------------------------------------------------------------- /scripts/data/P3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P3.npy -------------------------------------------------------------------------------- /scripts/data/P4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P4.npy -------------------------------------------------------------------------------- /scripts/data/P5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P5.npy -------------------------------------------------------------------------------- /scripts/data/P6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P6.npy -------------------------------------------------------------------------------- /scripts/data/P7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P7.npy -------------------------------------------------------------------------------- /scripts/data/P8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P8.npy -------------------------------------------------------------------------------- /scripts/data/P9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P9.npy -------------------------------------------------------------------------------- /scripts/data/PT0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT0.npy -------------------------------------------------------------------------------- /scripts/data/PT1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT1.npy -------------------------------------------------------------------------------- /scripts/data/PT2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT2.npy -------------------------------------------------------------------------------- /scripts/data/PT3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT3.npy -------------------------------------------------------------------------------- /scripts/data/PT4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT4.npy -------------------------------------------------------------------------------- /scripts/data/PT5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT5.npy -------------------------------------------------------------------------------- /scripts/data/PT6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT6.npy -------------------------------------------------------------------------------- /scripts/data/PT7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT7.npy -------------------------------------------------------------------------------- /scripts/data/PT8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT8.npy -------------------------------------------------------------------------------- /scripts/data/PT9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT9.npy -------------------------------------------------------------------------------- /scripts/data/PV0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV0.npy -------------------------------------------------------------------------------- /scripts/data/PV1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV1.npy -------------------------------------------------------------------------------- /scripts/data/PV2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV2.npy -------------------------------------------------------------------------------- /scripts/data/PV3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV3.npy -------------------------------------------------------------------------------- /scripts/data/PV4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV4.npy -------------------------------------------------------------------------------- /scripts/data/PV5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV5.npy -------------------------------------------------------------------------------- /scripts/data/PV6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV6.npy -------------------------------------------------------------------------------- /scripts/data/PV7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV7.npy -------------------------------------------------------------------------------- /scripts/data/PV8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV8.npy -------------------------------------------------------------------------------- /scripts/data/PV9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV9.npy -------------------------------------------------------------------------------- /scripts/data/X0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X0.npy -------------------------------------------------------------------------------- /scripts/data/X1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X1.npy -------------------------------------------------------------------------------- /scripts/data/X2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X2.npy -------------------------------------------------------------------------------- /scripts/data/X3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X3.npy -------------------------------------------------------------------------------- /scripts/data/X4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X4.npy -------------------------------------------------------------------------------- /scripts/data/X5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X5.npy -------------------------------------------------------------------------------- /scripts/data/X6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X6.npy -------------------------------------------------------------------------------- /scripts/data/X7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X7.npy -------------------------------------------------------------------------------- /scripts/data/X8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X8.npy -------------------------------------------------------------------------------- /scripts/data/X9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X9.npy -------------------------------------------------------------------------------- /scripts/data/XT0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT0.npy -------------------------------------------------------------------------------- /scripts/data/XT1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT1.npy -------------------------------------------------------------------------------- /scripts/data/XT2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT2.npy -------------------------------------------------------------------------------- /scripts/data/XT3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT3.npy -------------------------------------------------------------------------------- /scripts/data/XT4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT4.npy -------------------------------------------------------------------------------- /scripts/data/XT5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT5.npy -------------------------------------------------------------------------------- /scripts/data/XT6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT6.npy -------------------------------------------------------------------------------- /scripts/data/XT7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT7.npy -------------------------------------------------------------------------------- /scripts/data/XT8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT8.npy -------------------------------------------------------------------------------- /scripts/data/XT9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT9.npy -------------------------------------------------------------------------------- /scripts/data/XV0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV0.npy -------------------------------------------------------------------------------- /scripts/data/XV1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV1.npy -------------------------------------------------------------------------------- /scripts/data/XV2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV2.npy -------------------------------------------------------------------------------- /scripts/data/XV3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV3.npy -------------------------------------------------------------------------------- /scripts/data/XV4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV4.npy -------------------------------------------------------------------------------- /scripts/data/XV5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV5.npy -------------------------------------------------------------------------------- /scripts/data/XV6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV6.npy -------------------------------------------------------------------------------- /scripts/data/XV7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV7.npy -------------------------------------------------------------------------------- /scripts/data/XV8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV8.npy -------------------------------------------------------------------------------- /scripts/data/XV9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV9.npy -------------------------------------------------------------------------------- /scripts/data/X_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_0.npy -------------------------------------------------------------------------------- /scripts/data/X_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_1.npy -------------------------------------------------------------------------------- /scripts/data/X_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_2.npy -------------------------------------------------------------------------------- /scripts/data/X_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_3.npy -------------------------------------------------------------------------------- /scripts/data/X_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_4.npy -------------------------------------------------------------------------------- /scripts/data/X_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_5.npy -------------------------------------------------------------------------------- /scripts/data/X_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_6.npy -------------------------------------------------------------------------------- /scripts/data/X_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_7.npy -------------------------------------------------------------------------------- /scripts/data/X_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_8.npy -------------------------------------------------------------------------------- /scripts/data/X_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_9.npy -------------------------------------------------------------------------------- /scripts/data/Y0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y0.npy -------------------------------------------------------------------------------- /scripts/data/Y1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y1.npy -------------------------------------------------------------------------------- /scripts/data/Y2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y2.npy -------------------------------------------------------------------------------- /scripts/data/Y3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y3.npy -------------------------------------------------------------------------------- /scripts/data/Y4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y4.npy -------------------------------------------------------------------------------- /scripts/data/Y5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y5.npy -------------------------------------------------------------------------------- /scripts/data/Y6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y6.npy -------------------------------------------------------------------------------- /scripts/data/Y7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y7.npy -------------------------------------------------------------------------------- /scripts/data/Y8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y8.npy -------------------------------------------------------------------------------- /scripts/data/Y9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y9.npy -------------------------------------------------------------------------------- /scripts/data/YT0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT0.npy -------------------------------------------------------------------------------- /scripts/data/YT1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT1.npy -------------------------------------------------------------------------------- /scripts/data/YT2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT2.npy -------------------------------------------------------------------------------- /scripts/data/YT3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT3.npy -------------------------------------------------------------------------------- /scripts/data/YT4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT4.npy -------------------------------------------------------------------------------- /scripts/data/YT5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT5.npy -------------------------------------------------------------------------------- /scripts/data/YT6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT6.npy -------------------------------------------------------------------------------- /scripts/data/YT7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT7.npy -------------------------------------------------------------------------------- /scripts/data/YT8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT8.npy -------------------------------------------------------------------------------- /scripts/data/YT9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT9.npy -------------------------------------------------------------------------------- /scripts/data/YV0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV0.npy -------------------------------------------------------------------------------- /scripts/data/YV1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV1.npy -------------------------------------------------------------------------------- /scripts/data/YV2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV2.npy -------------------------------------------------------------------------------- /scripts/data/YV3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV3.npy -------------------------------------------------------------------------------- /scripts/data/YV4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV4.npy -------------------------------------------------------------------------------- /scripts/data/YV5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV5.npy -------------------------------------------------------------------------------- /scripts/data/YV6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV6.npy -------------------------------------------------------------------------------- /scripts/data/YV7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV7.npy -------------------------------------------------------------------------------- /scripts/data/YV8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV8.npy -------------------------------------------------------------------------------- /scripts/data/YV9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV9.npy -------------------------------------------------------------------------------- /scripts/outputs/cmt.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/cmt.npy -------------------------------------------------------------------------------- /scripts/data/P_long_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P_long_0.npy -------------------------------------------------------------------------------- /scripts/data/P_long_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P_long_1.npy -------------------------------------------------------------------------------- /scripts/data/P_long_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P_long_2.npy -------------------------------------------------------------------------------- /scripts/data/P_long_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P_long_3.npy -------------------------------------------------------------------------------- /scripts/data/P_long_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P_long_4.npy -------------------------------------------------------------------------------- /scripts/data/P_long_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P_long_5.npy -------------------------------------------------------------------------------- /scripts/data/P_long_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P_long_6.npy -------------------------------------------------------------------------------- /scripts/data/P_long_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P_long_7.npy -------------------------------------------------------------------------------- /scripts/data/P_long_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P_long_8.npy -------------------------------------------------------------------------------- /scripts/data/P_long_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/P_long_9.npy -------------------------------------------------------------------------------- /scripts/data/X_long_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_long_0.npy -------------------------------------------------------------------------------- /scripts/data/X_long_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_long_1.npy -------------------------------------------------------------------------------- /scripts/data/X_long_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_long_2.npy -------------------------------------------------------------------------------- /scripts/data/X_long_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_long_3.npy -------------------------------------------------------------------------------- /scripts/data/X_long_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_long_4.npy -------------------------------------------------------------------------------- /scripts/data/X_long_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_long_5.npy -------------------------------------------------------------------------------- /scripts/data/X_long_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_long_6.npy -------------------------------------------------------------------------------- /scripts/data/X_long_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_long_7.npy -------------------------------------------------------------------------------- /scripts/data/X_long_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_long_8.npy -------------------------------------------------------------------------------- /scripts/data/X_long_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/X_long_9.npy -------------------------------------------------------------------------------- /scripts/data/Y_long_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y_long_0.npy -------------------------------------------------------------------------------- /scripts/data/Y_long_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y_long_1.npy -------------------------------------------------------------------------------- /scripts/data/Y_long_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y_long_2.npy -------------------------------------------------------------------------------- /scripts/data/Y_long_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y_long_3.npy -------------------------------------------------------------------------------- /scripts/data/Y_long_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y_long_4.npy -------------------------------------------------------------------------------- /scripts/data/Y_long_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y_long_5.npy -------------------------------------------------------------------------------- /scripts/data/Y_long_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y_long_6.npy -------------------------------------------------------------------------------- /scripts/data/Y_long_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y_long_7.npy -------------------------------------------------------------------------------- /scripts/data/Y_long_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y_long_8.npy -------------------------------------------------------------------------------- /scripts/data/Y_long_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/Y_long_9.npy -------------------------------------------------------------------------------- /scripts/outputs/cmt+.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/cmt+.npy -------------------------------------------------------------------------------- /scripts/outputs/mnldt.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/mnldt.npy -------------------------------------------------------------------------------- /scripts/outputs/mnlkm.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/mnlkm.npy -------------------------------------------------------------------------------- /scripts/data/PT_long_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT_long_0.npy -------------------------------------------------------------------------------- /scripts/data/PT_long_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT_long_1.npy -------------------------------------------------------------------------------- /scripts/data/PT_long_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT_long_2.npy -------------------------------------------------------------------------------- /scripts/data/PT_long_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT_long_3.npy -------------------------------------------------------------------------------- /scripts/data/PT_long_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT_long_4.npy -------------------------------------------------------------------------------- /scripts/data/PT_long_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT_long_5.npy -------------------------------------------------------------------------------- /scripts/data/PT_long_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT_long_6.npy -------------------------------------------------------------------------------- /scripts/data/PT_long_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT_long_7.npy -------------------------------------------------------------------------------- /scripts/data/PT_long_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT_long_8.npy -------------------------------------------------------------------------------- /scripts/data/PT_long_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PT_long_9.npy -------------------------------------------------------------------------------- /scripts/data/PV_long_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV_long_0.npy -------------------------------------------------------------------------------- /scripts/data/PV_long_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV_long_1.npy -------------------------------------------------------------------------------- /scripts/data/PV_long_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV_long_2.npy -------------------------------------------------------------------------------- /scripts/data/PV_long_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV_long_3.npy -------------------------------------------------------------------------------- /scripts/data/PV_long_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV_long_4.npy -------------------------------------------------------------------------------- /scripts/data/PV_long_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV_long_5.npy -------------------------------------------------------------------------------- /scripts/data/PV_long_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV_long_6.npy -------------------------------------------------------------------------------- /scripts/data/PV_long_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV_long_7.npy -------------------------------------------------------------------------------- /scripts/data/PV_long_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV_long_8.npy -------------------------------------------------------------------------------- /scripts/data/PV_long_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/PV_long_9.npy -------------------------------------------------------------------------------- /scripts/data/XT_long_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT_long_0.npy -------------------------------------------------------------------------------- /scripts/data/XT_long_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT_long_1.npy -------------------------------------------------------------------------------- /scripts/data/XT_long_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT_long_2.npy -------------------------------------------------------------------------------- /scripts/data/XT_long_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT_long_3.npy -------------------------------------------------------------------------------- /scripts/data/XT_long_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT_long_4.npy -------------------------------------------------------------------------------- /scripts/data/XT_long_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT_long_5.npy -------------------------------------------------------------------------------- /scripts/data/XT_long_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT_long_6.npy -------------------------------------------------------------------------------- /scripts/data/XT_long_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT_long_7.npy -------------------------------------------------------------------------------- /scripts/data/XT_long_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT_long_8.npy -------------------------------------------------------------------------------- /scripts/data/XT_long_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XT_long_9.npy -------------------------------------------------------------------------------- /scripts/data/XV_long_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV_long_0.npy -------------------------------------------------------------------------------- /scripts/data/XV_long_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV_long_1.npy -------------------------------------------------------------------------------- /scripts/data/XV_long_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV_long_2.npy -------------------------------------------------------------------------------- /scripts/data/XV_long_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV_long_3.npy -------------------------------------------------------------------------------- /scripts/data/XV_long_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV_long_4.npy -------------------------------------------------------------------------------- /scripts/data/XV_long_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV_long_5.npy -------------------------------------------------------------------------------- /scripts/data/XV_long_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV_long_6.npy -------------------------------------------------------------------------------- /scripts/data/XV_long_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV_long_7.npy -------------------------------------------------------------------------------- /scripts/data/XV_long_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV_long_8.npy -------------------------------------------------------------------------------- /scripts/data/XV_long_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/XV_long_9.npy -------------------------------------------------------------------------------- /scripts/data/YT_long_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT_long_0.npy -------------------------------------------------------------------------------- /scripts/data/YT_long_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT_long_1.npy -------------------------------------------------------------------------------- /scripts/data/YT_long_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT_long_2.npy -------------------------------------------------------------------------------- /scripts/data/YT_long_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT_long_3.npy -------------------------------------------------------------------------------- /scripts/data/YT_long_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT_long_4.npy -------------------------------------------------------------------------------- /scripts/data/YT_long_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT_long_5.npy -------------------------------------------------------------------------------- /scripts/data/YT_long_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT_long_6.npy -------------------------------------------------------------------------------- /scripts/data/YT_long_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT_long_7.npy -------------------------------------------------------------------------------- /scripts/data/YT_long_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT_long_8.npy -------------------------------------------------------------------------------- /scripts/data/YT_long_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YT_long_9.npy -------------------------------------------------------------------------------- /scripts/data/YV_long_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV_long_0.npy -------------------------------------------------------------------------------- /scripts/data/YV_long_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV_long_1.npy -------------------------------------------------------------------------------- /scripts/data/YV_long_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV_long_2.npy -------------------------------------------------------------------------------- /scripts/data/YV_long_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV_long_3.npy -------------------------------------------------------------------------------- /scripts/data/YV_long_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV_long_4.npy -------------------------------------------------------------------------------- /scripts/data/YV_long_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV_long_5.npy -------------------------------------------------------------------------------- /scripts/data/YV_long_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV_long_6.npy -------------------------------------------------------------------------------- /scripts/data/YV_long_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV_long_7.npy -------------------------------------------------------------------------------- /scripts/data/YV_long_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV_long_8.npy -------------------------------------------------------------------------------- /scripts/data/YV_long_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/data/YV_long_9.npy -------------------------------------------------------------------------------- /scripts/outputs/mnldt+.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/mnldt+.npy -------------------------------------------------------------------------------- /scripts/outputs/mnlicot+.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/mnlicot+.npy -------------------------------------------------------------------------------- /scripts/outputs/mnlicot.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/mnlicot.npy -------------------------------------------------------------------------------- /scripts/outputs/mnlint.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/mnlint.npy -------------------------------------------------------------------------------- /scripts/outputs/mnlkm+.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/mnlkm+.npy -------------------------------------------------------------------------------- /scripts/outputs/tastenet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/tastenet.npy -------------------------------------------------------------------------------- /scripts/src/ICOT/data/X0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/data/X0.npy -------------------------------------------------------------------------------- /scripts/src/ICOT/data/X1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/data/X1.npy -------------------------------------------------------------------------------- /scripts/src/ICOT/data/X2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/data/X2.npy -------------------------------------------------------------------------------- /scripts/src/ICOT/data/X3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/data/X3.npy -------------------------------------------------------------------------------- /scripts/src/ICOT/data/X4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/data/X4.npy -------------------------------------------------------------------------------- /scripts/src/ICOT/data/Y0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/data/Y0.npy -------------------------------------------------------------------------------- /scripts/src/ICOT/data/Y1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/data/Y1.npy -------------------------------------------------------------------------------- /scripts/src/ICOT/data/Y2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/data/Y2.npy -------------------------------------------------------------------------------- /scripts/src/ICOT/data/Y3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/data/Y3.npy -------------------------------------------------------------------------------- /scripts/src/ICOT/data/Y4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/data/Y4.npy -------------------------------------------------------------------------------- /scripts/outputs/plot_is_mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/plot_is_mse.png -------------------------------------------------------------------------------- /scripts/outputs/plot_os_mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/plot_os_mse.png -------------------------------------------------------------------------------- /scripts/outputs/plot_os_nll.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/plot_os_nll.png -------------------------------------------------------------------------------- /scripts/src/ICOT/result/icot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/src/ICOT/result/icot.png -------------------------------------------------------------------------------- /scripts/outputs/cmt_refactor.py.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/cmt_refactor.py.npy -------------------------------------------------------------------------------- /scripts/outputs/plot_is_nll.nll.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/plot_is_nll.nll.png -------------------------------------------------------------------------------- /scripts/outputs/cmt+_refactor.py.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/cmt+_refactor.py.npy -------------------------------------------------------------------------------- /scripts/outputs/varying_cmt_d15_m50_mod0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/varying_cmt_d15_m50_mod0.npy -------------------------------------------------------------------------------- /scripts/outputs/varying_mnlkm_k30_5_10_mod0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtm2130/MST/HEAD/scripts/outputs/varying_mnlkm_k30_5_10_mod0.npy -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ryan McNellis Ali Aouad 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/src/ICOT/data/ruspini.csv: -------------------------------------------------------------------------------- 1 | "x1","x2","x3" 2 | 4.0,53.0,4.0 3 | 5.0,63.0,4.0 4 | 10.0,59.0,4.0 5 | 9.0,77.0,4.0 6 | 13.0,49.0,4.0 7 | 13.0,69.0,4.0 8 | 12.0,88.0,4.0 9 | 15.0,75.0,4.0 10 | 18.0,61.0,4.0 11 | 19.0,65.0,4.0 12 | 22.0,74.0,4.0 13 | 27.0,72.0,4.0 14 | 28.0,76.0,4.0 15 | 24.0,58.0,4.0 16 | 27.0,55.0,4.0 17 | 28.0,60.0,4.0 18 | 30.0,52.0,4.0 19 | 31.0,60.0,4.0 20 | 32.0,61.0,4.0 21 | 36.0,72.0,4.0 22 | 28.0,147.0,1.0 23 | 32.0,149.0,1.0 24 | 35.0,153.0,1.0 25 | 33.0,154.0,1.0 26 | 38.0,151.0,1.0 27 | 41.0,150.0,1.0 28 | 38.0,145.0,1.0 29 | 38.0,143.0,1.0 30 | 32.0,143.0,1.0 31 | 34.0,141.0,1.0 32 | 44.0,156.0,1.0 33 | 44.0,149.0,1.0 34 | 44.0,143.0,1.0 35 | 46.0,142.0,1.0 36 | 47.0,149.0,1.0 37 | 49.0,152.0,1.0 38 | 50.0,142.0,1.0 39 | 53.0,144.0,1.0 40 | 52.0,152.0,1.0 41 | 55.0,155.0,1.0 42 | 54.0,124.0,1.0 43 | 60.0,136.0,1.0 44 | 63.0,139.0,1.0 45 | 86.0,132.0,3.0 46 | 85.0,115.0,3.0 47 | 85.0,96.0,3.0 48 | 78.0,94.0,3.0 49 | 74.0,96.0,3.0 50 | 97.0,122.0,3.0 51 | 98.0,116.0,3.0 52 | 98.0,124.0,3.0 53 | 99.0,119.0,3.0 54 | 99.0,128.0,3.0 55 | 101.0,115.0,3.0 56 | 108.0,111.0,3.0 57 | 110.0,111.0,3.0 58 | 108.0,116.0,3.0 59 | 111.0,126.0,3.0 60 | 115.0,117.0,3.0 61 | 117.0,115.0,3.0 62 | 70.0,4.0,2.0 63 | 77.0,12.0,2.0 64 | 83.0,21.0,2.0 65 | 61.0,15.0,2.0 66 | 69.0,15.0,2.0 67 | 78.0,16.0,2.0 68 | 66.0,18.0,2.0 69 | 58.0,13.0,2.0 70 | 64.0,20.0,2.0 71 | 69.0,21.0,2.0 72 | 66.0,23.0,2.0 73 | 61.0,25.0,2.0 74 | 76.0,27.0,2.0 75 | 72.0,31.0,2.0 76 | 64.0,30.0,2.0 77 | -------------------------------------------------------------------------------- /scripts/plots.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Jan 24 16:50:46 2021 5 | 6 | Script to plot the pareto curves 7 | 8 | """ 9 | 10 | import pandas as pd 11 | import numpy as np 12 | import seaborn as sns 13 | 14 | 15 | MST = np.load("outputs/varying_cmt_d15_m50_mod0.npy") 16 | cMNL = np.load("outputs/varying_mnlkm_k30_5_10_mod0.npy") 17 | 18 | dfMST = pd.DataFrame({'# Segments':MST.mean(axis = 0)[0,4,:], 19 | 'In-sample NLL':-MST.mean(axis = 0)[0,2,:], 20 | 'Out-of-sample NLL':-MST.mean(axis = 0)[0,0,:], 21 | 'In-sample MSE':MST.mean(axis = 0)[0,3,:], 22 | 'Out-of-sample MSE': MST.mean(axis = 0)[0,1,:]}) 23 | 24 | dfcMNL = pd.DataFrame({'# Segments':cMNL.mean(axis = 0)[0,4,:], 25 | 'In-sample NLL':cMNL.mean(axis = 0)[0,2,:], 26 | 'Out-of-sample NLL':-cMNL.mean(axis = 0)[0,0,:], 27 | 'In-sample MSE':cMNL.mean(axis = 0)[0,3,:], 28 | 'Out-of-sample MSE': cMNL.mean(axis = 0)[0,1,:]}) 29 | dfcMNL['Model'] = "MNLKM" 30 | dfMST['Model'] = "CMT" 31 | 32 | dfcMNL = dfcMNL[dfcMNL["# Segments"]<=110] 33 | 34 | df = pd.concat([dfMST,dfcMNL],axis = 0) 35 | 36 | sns.relplot(x="# Segments", y="Out-of-sample MSE", hue="Model", data=df) 37 | sns.relplot(x="# Segments", y="In-sample MSE", hue="Model", data=df) 38 | sns.relplot(x="# Segments", y="Out-of-sample NLL", hue="Model", data=df) 39 | sns.relplot(x="# Segments", y="In-sample NLL", hue="Model", data=df) 40 | 41 | -------------------------------------------------------------------------------- /GenMNL.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | #Class for generating random MNL models. Helper class for cmt_example.py. 3 | class GenMNL(object): 4 | 5 | ''' 6 | Arguments specifying MNL model type 7 | n_items: number of products 8 | num_features: number of product features (integer), INCLUDING the binary availability feature 9 | model_type : whether the model has alternative varying coefficients (0) or not (1) 10 | (default is 0 meaning each alternative has a separate coeff) 11 | is_bias : whether the utility function has an intercept (default is True) 12 | ''' 13 | def __init__(self, n_items, num_features, model_type, is_bias): 14 | 15 | self.n_items = n_items 16 | self.num_features = num_features 17 | self.model_type = model_type 18 | self.is_bias = is_bias 19 | #generate MNL's coefficients randomly 20 | self.Beta = np.random.uniform(low=-1, high=1, size=(n_items,num_features)) 21 | 22 | ''' 23 | Get choice probabilities from product features P 24 | ''' 25 | def get_choice_probs(self, P): 26 | n = P.shape[0] 27 | n_items = self.n_items 28 | num_features = self.num_features 29 | model_type = self.model_type 30 | is_bias = self.is_bias 31 | Beta = self.Beta 32 | 33 | U_exp = np.zeros((n,n_items)) 34 | 35 | for k in range(n_items): 36 | if is_bias == True: 37 | U_exp[:,k] = Beta[k,0] 38 | else: 39 | U_exp[:,k] = 0.0 40 | for l in range(num_features-1): 41 | if model_type == 0: 42 | U_exp[:,k] = U_exp[:,k] + Beta[k,l+1]*P[:,(n_items*(l+1)+k)] 43 | else: 44 | U_exp[:,k] = U_exp[:,k] + Beta[0,l+1]*P[:,(n_items*(l+1)+k)] 45 | 46 | scale = 5 #dictates the level of noise in the choice probabilities. epsilon~Gumbel(0,1/scale) 47 | Y_prob = np.zeros((n,n_items)) 48 | denom = sum([np.exp(scale*U_exp[:,k])*P[:,k] for k in range(n_items)]) 49 | for k in range(n_items): 50 | Y_prob[:,k] = np.where(P[:,k] == 1, np.exp(scale*U_exp[:,k])/denom, 0) 51 | 52 | return Y_prob 53 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | This directory provides the scripts to run CMTs and benchmark methods on the Swissmetro dataset. 2 | 3 | ## Dependencies 4 | 5 | All scripts except `tastenet.py` and `mnlint.py` were run with the following dependencies: 6 | 7 | python 2.7.17, joblib 0.13.2, numpy 1.16.6, pandas 0.22.0, rpy2 2.8.6, seaborn 0.9.0, scikit-learn 0.20.3, tensorflow 2.1.0 8 | 9 | Note that other R dependencies are specified along with the virtual environment of the source code. The scripts `tastenet.py` and `mnlint.py` were run with the following dependencies: 10 | 11 | python 3.8.13, pandas 1.4.2, numpy 1.19.2, tensorflow 2.4.0, tensorboard 2.4.1, tf-nightly 2.5.0, tf-estimator-nightly 2.4.0, keras-preprocessing 1.1.2, keras-tuner 1.1.0 12 | 13 | ## Description of scripts 14 | 15 | `cmt.py` and `cmt+.py`: estimation of the CMT on the 10 data splits for a maximum depth of 14. 16 | 17 | `mnlkm.py` and `mnlkm+.py`: estimation of the MNLKM benchmark on the 10 data splits with a search over the number of clusters 5, 10, ... 295. 18 | 19 | `mnldt.py` and `mnldt+.py`: estimation of the MNLDT benchmark on the 10 data splits for a maximum depth of 14. 20 | 21 | `mnlicot.py` and `mnlicot+.py`: estimation of the MNLICOT benchmark on the 10 data splits. The ICOT tree is hardcoded. Code that generates the ICOT tree is provided in the folder `src/ICOT`. The same tree is produced across 5 distinct splits of the data. 22 | 23 | `tastenet.py`: estimation of the TasteNet benchmark on the 10 data splits. Code is implemented using the tensorflow library. Input data is used in long format. Note: random seed was not saved. 24 | 25 | `mnlint.py`: estimation of the MNLINT benchmark on the 10 data splits. Code is implemented using the tensorflow library. Input data is in used in long format. Note: random seed was not saved. 26 | 27 | `plots.py`: visualisation of the Pareto curve showing the predictive performance as a function of the number of segments for MNLKM and CMT. 28 | 29 | `prepare_data.py`: code to generate the 10 random splits 75%-12.5%-12.5%, both in long/wide formats. Note: random seed was not saved. 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/prepare_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Dec 16 19:19:32 2020 5 | 6 | """ 7 | import pandas as pd 8 | import numpy as np 9 | 10 | path = 'data/' 11 | df = pd.read_csv(path+'swissmetro.dat',sep='\t') 12 | 13 | df["CAR_HE"] = 0 14 | c_features = ["GROUP", "PURPOSE", "FIRST", "TICKET", "WHO", "LUGGAGE", "AGE", 15 | "MALE", "INCOME", "GA", "ORIGIN", "DEST"] 16 | p_features = ["TRAIN_AV", "SM_AV", "CAR_AV", 17 | "TRAIN_TT", "SM_TT", "CAR_TT", 18 | "TRAIN_CO", "SM_CO", "CAR_CO", 19 | "TRAIN_HE", "SM_HE", "CAR_HE"] 20 | target = "CHOICE" 21 | df = df[df[target] > 0] 22 | df.loc[:,target] = df.loc[:,target]-1 23 | df = df.reset_index() 24 | df = df.sample(frac=1).reset_index(drop=True) 25 | already_dummies = ["FIRST", "MALE", "GA"] 26 | 27 | df2 = df.copy() 28 | 29 | c_features2 = c_features[:] 30 | 31 | for c in c_features2: 32 | if c not in already_dummies: 33 | dummies = pd.get_dummies(df[c], prefix = c) 34 | df2 = pd.concat((df2,dummies), axis = 1) 35 | df2 = df2.drop(c, axis = 1) 36 | c_features2.remove(c) 37 | c_features2 = c_features2 + dummies.columns.tolist() 38 | 39 | def prepare_data(df,train_indices,test_indices,validation_indices,c_features): 40 | ''' 41 | Prepares a partition of the data into train, validation and test sets 42 | 43 | ''' 44 | 45 | Y = df.loc[train_indices,target].values 46 | X = df.loc[train_indices,c_features].values 47 | P = df.loc[train_indices,p_features].values 48 | YT = df.loc[test_indices,target].values 49 | XT = df.loc[test_indices,c_features].values 50 | PT = df.loc[test_indices,p_features].values 51 | YV = df.loc[validation_indices,target].values 52 | XV = df.loc[validation_indices,c_features].values 53 | PV = df.loc[validation_indices,p_features].values 54 | 55 | return X,P,Y,XV,PV,YV,XT,PT,YT 56 | 57 | 58 | n_valid = int(df.shape[0]*0.125); 59 | n_test = int(df.shape[0]*0.125); 60 | 61 | for i in range(10): 62 | n = df.shape[0] 63 | selected_indices = np.random.choice(range(n), size= n_valid + n_test, replace=False) 64 | test_indices = np.random.choice(selected_indices, size = n_test, replace=False) 65 | validation_indices = np.setdiff1d(selected_indices,test_indices) 66 | train_indices = np.setdiff1d(np.arange(n),selected_indices) 67 | 68 | X,P,Y,XV,PV,YV,XT,PT,YT = \ 69 | prepare_data(df2, train_indices, test_indices,validation_indices,c_features2) 70 | 71 | np.save(path+'X_long_{}.npy'.format(i),X) 72 | np.save(path+'P_long_{}.npy'.format(i),P) 73 | np.save(path+'Y_long_{}.npy'.format(i),Y) 74 | 75 | np.save(path+'XV_long_{}.npy'.format(i),XV) 76 | np.save(path+'PV_long_{}.npy'.format(i),PV) 77 | np.save(path+'YV_long_{}.npy'.format(i),YV) 78 | 79 | np.save(path+'XT_long_{}.npy'.format(i),XT) 80 | np.save(path+'PT_long_{}.npy'.format(i),PT) 81 | np.save(path+'YT_long_{}.npy'.format(i),YT) 82 | 83 | X,P,Y,XV,PV,YV,XT,PT,YT = \ 84 | prepare_data(df, train_indices, test_indices,validation_indices,c_features) 85 | 86 | np.save(path+'X{}.npy'.format(i),X) 87 | np.save(path+'P{}.npy'.format(i),P) 88 | np.save(path+'Y{}.npy'.format(i),Y) 89 | 90 | np.save(path+'XV{}.npy'.format(i),XV) 91 | np.save(path+'PV{}.npy'.format(i),PV) 92 | np.save(path+'YV{}.npy'.format(i),YV) 93 | 94 | np.save(path+'XT{}.npy'.format(i),XT) 95 | np.save(path+'PT{}.npy'.format(i),PT) 96 | np.save(path+'YT{}.npy'.format(i),YT) 97 | -------------------------------------------------------------------------------- /scripts/mnlkm+.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import pandas as pd 4 | from src.mst import MST 5 | from src.responsemodel_kmeans import response_model_kmeans_fit_and_predict 6 | from src.cart_customerfeats_in_leafmod import append_customer_features_to_product_features 7 | 8 | 9 | 10 | c_features = ["GROUP", "PURPOSE", "FIRST", "TICKET", "WHO", "LUGGAGE", "AGE", 11 | "MALE", "INCOME", "GA", "ORIGIN", "DEST"] 12 | p_features = ["TRAIN_AV", "SM_AV", "CAR_AV", 13 | "TRAIN_TT", "SM_TT", "CAR_TT", 14 | "TRAIN_CO", "SM_CO", "CAR_CO", 15 | "TRAIN_HE", "SM_HE", "CAR_HE"] 16 | 17 | num_features = 4 18 | model_type = 0 19 | is_bias = True 20 | is_continuous = [False for k in range(len(c_features))] 21 | is_continuous[6] = True 22 | is_continuous[8] = False 23 | 24 | scores_np = np.zeros((10,2,5,30)) 25 | 26 | for i in range(10): 27 | num_features = 4 28 | X = np.load('data/X'+str(i)+'.npy') 29 | P = np.load('data/P'+str(i)+'.npy') 30 | Y = np.load('data/Y'+str(i)+'.npy') 31 | 32 | XV = np.load('data/XV'+str(i)+'.npy') 33 | PV = np.load('data/PV'+str(i)+'.npy') 34 | YV = np.load('data/YV'+str(i)+'.npy') 35 | 36 | XT = np.load('data/XT'+str(i)+'.npy') 37 | PT = np.load('data/PT'+str(i)+'.npy') 38 | YT = np.load('data/YT'+str(i)+'.npy') 39 | 40 | P = P.astype(float) 41 | PV = PV.astype(float) 42 | PT = PT.astype(float) 43 | 44 | P[:,-1] = 0.001*np.random.rand(P.shape[0]) 45 | PV[:,-1] = 0.001*np.random.rand(PV.shape[0]) 46 | PT[:,-1] = 0.001*np.random.rand(PT.shape[0]) 47 | 48 | P, PV, PT, num_features = append_customer_features_to_product_features(X, XV, XT, 49 | P, PV, PT, 50 | feats_continuous=is_continuous, 51 | model_type=model_type, num_features=num_features) 52 | print(P.shape,PV.shape,PT.shape,X.shape) 53 | ############################################################################# 54 | #(2) Run MNLKM (k-means with MNL response model) 55 | for n_clusters in range(1): 56 | k_seq = range(5,300,10) 57 | # k_seq = [11] 58 | # for n_clusters in range(30): 59 | # k_seq = [5+10*n_clusters] 60 | #fit MNLKM and output test set predictions. See code responsemodel_kmeans.py for more details 61 | loss,mse,Y_predT,best_k = response_model_kmeans_fit_and_predict(k_seq, 62 | X,P,Y, 63 | XV,PV,YV, 64 | XT,PT, 65 | feats_continuous=is_continuous, normalize_feats=True, 66 | n_init=10, verbose=True, tuning_loss_function="mse", 67 | num_features = num_features, is_bias = is_bias, model_type = model_type, 68 | mode = "mnl", batch_size = 100, epochs = 100, steps = 5000, 69 | method = 'kmeans', 70 | leaf_mod_thresh=1000000000000) 71 | 72 | YT_flat = np.zeros((YT.shape[0],3)) 73 | YT_flat[np.arange(YT.shape[0]),YT] = 1 74 | 75 | s_Y = Y_predT.shape[0] 76 | scores_np[i,0,0,n_clusters] = np.mean(np.log(np.maximum(0.01,Y_predT[np.arange(s_Y),YT]))) 77 | scores_np[i,0,1,n_clusters] = np.mean(np.sum(np.power(Y_predT-YT_flat,2),axis = 1)) 78 | scores_np[i,0,2,n_clusters] = loss 79 | scores_np[i,0,3,n_clusters] = mse 80 | scores_np[i,0,4,n_clusters] = best_k 81 | 82 | -------------------------------------------------------------------------------- /scripts/mnlkm.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import pandas as pd 4 | from src.mst import MST 5 | from src.responsemodel_kmeans import response_model_kmeans_fit_and_predict 6 | from src.cart_customerfeats_in_leafmod import append_customer_features_to_product_features 7 | 8 | 9 | 10 | c_features = ["GROUP", "PURPOSE", "FIRST", "TICKET", "WHO", "LUGGAGE", "AGE", 11 | "MALE", "INCOME", "GA", "ORIGIN", "DEST"] 12 | p_features = ["TRAIN_AV", "SM_AV", "CAR_AV", 13 | "TRAIN_TT", "SM_TT", "CAR_TT", 14 | "TRAIN_CO", "SM_CO", "CAR_CO", 15 | "TRAIN_HE", "SM_HE", "CAR_HE"] 16 | 17 | num_features = 4 18 | model_type = 0 19 | is_bias = True 20 | is_continuous = [False for k in range(len(c_features))] 21 | is_continuous[6] = True 22 | is_continuous[8] = False 23 | 24 | scores_np = np.zeros((10,2,5,30)) 25 | 26 | for i in range(10): 27 | num_features = 4 28 | X = np.load('data/X'+str(i)+'.npy') 29 | P = np.load('data/P'+str(i)+'.npy') 30 | Y = np.load('data/Y'+str(i)+'.npy') 31 | 32 | XV = np.load('data/XV'+str(i)+'.npy') 33 | PV = np.load('data/PV'+str(i)+'.npy') 34 | YV = np.load('data/YV'+str(i)+'.npy') 35 | 36 | XT = np.load('data/XT'+str(i)+'.npy') 37 | PT = np.load('data/PT'+str(i)+'.npy') 38 | YT = np.load('data/YT'+str(i)+'.npy') 39 | 40 | P = P.astype(float) 41 | PV = PV.astype(float) 42 | PT = PT.astype(float) 43 | 44 | P[:,-1] = 0.001*np.random.rand(P.shape[0]) 45 | PV[:,-1] = 0.001*np.random.rand(PV.shape[0]) 46 | PT[:,-1] = 0.001*np.random.rand(PT.shape[0]) 47 | 48 | # P, PV, PT, num_features = append_customer_features_to_product_features(X, XV, XT, 49 | # P, PV, PT, 50 | # feats_continuous=is_continuous, 51 | # model_type=model_type, num_features=num_features) 52 | # print(P.shape,PV.shape,PT.shape,X.shape) 53 | ############################################################################# 54 | #(2) Run MNLKM (k-means with MNL response model) 55 | for n_clusters in range(1): 56 | k_seq = range(5,300,10) 57 | # k_seq = [11] 58 | # for n_clusters in range(30): 59 | # k_seq = [5+10*n_clusters] 60 | #fit MNLKM and output test set predictions. See code responsemodel_kmeans.py for more details 61 | loss,mse,Y_predT,best_k = response_model_kmeans_fit_and_predict(k_seq, 62 | X,P,Y, 63 | XV,PV,YV, 64 | XT,PT, 65 | feats_continuous=is_continuous, normalize_feats=True, 66 | n_init=10, verbose=True, tuning_loss_function="mse", 67 | num_features = num_features, is_bias = is_bias, model_type = model_type, 68 | mode = "mnl", batch_size = 100, epochs = 100, steps = 5000, 69 | method = 'kmeans', 70 | leaf_mod_thresh=1000000000000) 71 | 72 | YT_flat = np.zeros((YT.shape[0],3)) 73 | YT_flat[np.arange(YT.shape[0]),YT] = 1 74 | 75 | s_Y = Y_predT.shape[0] 76 | scores_np[i,0,0,n_clusters] = np.mean(np.log(np.maximum(0.01,Y_predT[np.arange(s_Y),YT]))) 77 | scores_np[i,0,1,n_clusters] = np.mean(np.sum(np.power(Y_predT-YT_flat,2),axis = 1)) 78 | scores_np[i,0,2,n_clusters] = loss 79 | scores_np[i,0,3,n_clusters] = mse 80 | scores_np[i,0,4,n_clusters] = best_k 81 | 82 | -------------------------------------------------------------------------------- /irt_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | IRT EXAMPLE: SIMULATED LOGISTIC MODEL TREE DATASET 5 | 6 | Dataset is generated using a logistic regression model tree of depth 2 (see the Depth2Tree function). 7 | This example runs our IRT tree-building algorithm on the dataset to test whether it accurately models the data. 8 | Note that as an additional challenge to the IRT, the leaf nodes of the data-generating tree use logistic regression models 9 | (not isotonic regression models). 10 | 11 | NOTE: for this to run properly, include the following import statement in mst.py: "from leaf_model_isoreg import *" 12 | """ 13 | 14 | import numpy as np 15 | import pandas as pd 16 | 17 | from mst import MST 18 | 19 | np.set_printoptions(suppress=True) #suppress scientific notation 20 | np.random.seed(0) 21 | 22 | """ 23 | Given auction features X and bids P, returns vector of probabilities 24 | corresponding to whether or not the bid will win the auction. 25 | Dataset X: 3 covariates: 26 | X1: Binary in {0,1} 27 | X2: Continuous, takes values in [0,1] 28 | X3: Ordinal, takes values {0.0,0.2,0.4,0.6,0.8,1.0} 29 | 30 | This depth-2 logistic model tree is used to simulate the auction data 31 | """ 32 | def Depth2Tree(X,P): 33 | 34 | num_obs = X.shape[0] 35 | 36 | probs = np.ones([num_obs]) 37 | for i in range(0,num_obs): 38 | x = X.iloc[i,:] 39 | p = P[i] 40 | 41 | if (x['X3'] <= 0.6): 42 | if x['X1'] == 0: 43 | a = 20.0; #steepness of logistic curve 44 | p_thresh = 35.0; #price at which logistic curve is centered (i.e., L(p) = 0.5) 45 | b = -a/p_thresh; 46 | else: 47 | a = 20.0; #steepness of logistic curve 48 | p_thresh = 55.0; #price at which logistic curve is centered (i.e., L(p) = 0.5) 49 | b = -a/p_thresh; 50 | 51 | else: 52 | if x['X2'] <= 0.6: 53 | a = 20.0; #steepness of logistic curve 54 | p_thresh = 65.0; #price at which logistic curve is centered (i.e., L(p) = 0.5) 55 | b = -a/p_thresh; 56 | else: 57 | a = 20.0; #steepness of logistic curve 58 | p_thresh = 95.0; #price at which logistic curve is centered (i.e., L(p) = 0.5) 59 | b = -a/p_thresh; 60 | 61 | probs[i] = 1.0/(1.0+np.exp(-a-b*p)); 62 | 63 | return(probs); 64 | 65 | #SIMULATED DATA PARAMETERS 66 | n_train = 10000; 67 | n_valid = 2000; 68 | n_test = 5000; 69 | p_min = 10; 70 | p_max = 90; 71 | X1range = [0,1]; 72 | X3range = [0.0,0.2,0.4,0.6,0.8,1.0]; 73 | 74 | #generates data from logistic regression model tree of depth 2 75 | def generate_data(n): 76 | #auction features 77 | X1 = np.random.choice(X1range, size=n_train, replace=True) 78 | X2 = np.random.uniform(low=0.0,high=1.0,size=n_train) 79 | X3 = np.random.choice(X3range, size=n_train, replace=True) 80 | X = pd.DataFrame({'X1': X1,'X2': X2,'X3': X3}) 81 | #bids 82 | P = np.random.uniform(low = p_min, high = p_max, size=n_train); 83 | #outcomes: auction win indicators 84 | Y_prob = Depth2Tree(X,P); 85 | Y = np.random.binomial(1,Y_prob, size=n_train); 86 | 87 | return X,P,Y,Y_prob 88 | 89 | #GENERATE TRAINING DATA 90 | X,P,Y,Y_prob = generate_data(n_train) 91 | 92 | #FIT IRT ALGORITHM 93 | my_tree = MST(max_depth = 5, min_weights_per_node = 20) 94 | my_tree.fit(X,P,Y,verbose=False,feats_continuous=[False,True,True],increasing=False); #verbose specifies whether fitting procedure should print progress 95 | #ABOVE: increasing specifies whether fit isotonic regression models should be monotonically increasing or decreasing. 96 | #note ground truth has decreasing logistic curves. 97 | #my_tree.traverse() #prints out the unpruned tree 98 | 99 | #GENERATE VALIDATION SET DATA 100 | X,P,Y,Y_prob = generate_data(n_valid) 101 | 102 | #PRUNE DECISION TREE USING VALIDATION SET 103 | my_tree.prune(X, P, Y, verbose=False) #verbose specifies whether pruning procedure should print progress 104 | my_tree.traverse() #prints out the pruned tree, compare it against depth-2 tree used to generate the data 105 | 106 | #GENERATE TESTING DATA 107 | X,P,Y,Y_prob = generate_data(n_test) 108 | 109 | #USE TREE TO PREDICT TEST-SET PROBABILITIES AND MEASURE ERROR 110 | Ypred = my_tree.predict(X,P) 111 | print(np.mean(abs(Y_prob-Ypred))) 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /scripts/src/ICOT/src/runningICOT_example0.jl: -------------------------------------------------------------------------------- 1 | using DataFrames, MLDataUtils 2 | using Clustering, Distances 3 | using CSV 4 | using Random 5 | using Logging 6 | using NPZ 7 | 8 | # Set up Logging - we recommend to use this command to avoid package warnings during the model training process. 9 | logger = Logging.SimpleLogger(stderr,Logging.Warn); 10 | global_logger(logger); 11 | 12 | #### Set parameters for the learners 13 | cr = :dunnindex 14 | method = "ICOT_local" 15 | warm_start = :none; 16 | geom_search = true 17 | threshold = 0.99 18 | seed = 1 19 | gridsearch = false 20 | num_tree_restarts = 100 21 | complexity_c = 0.0 22 | min_bucket = 10 23 | maxdepth = 10 24 | 25 | ###### Step 1: Prepare the data 26 | # Read the data - recommend the use of the (deprecated) readtable() command to avoid potential version conflicts with the CSV package. 27 | # data = readtable("../data/ruspini.csv"); 28 | 29 | data = npzread("../data/X0.npy") 30 | # data = data[1:200,:] 31 | true_labels = npzread("../data/Y0.npy") 32 | # true_labels = true_labels[1:200,:] 33 | 34 | # Convert the dataset to a matrix 35 | data = convert(Matrix{Float64}, data); 36 | true_labels = convert(Matrix{Float64},reshape(true_labels,:,1)); 37 | 38 | data = DataFrame(hcat(data,true_labels)); 39 | data_array = convert(Matrix{Float64}, data); 40 | 41 | # Get the number of observations and features 42 | n, p = size(data_array) 43 | data_t = data_array'; 44 | 45 | ##### Step 2: Fit K-means clustering on the dataset to generate a warm-start for ICOT 46 | #Fix the seed 47 | Random.seed!(seed); 48 | 49 | # The ruspini dataset has pre-defined clusters, which we will use to select the cluster count (K) for the K-means algorithm. 50 | # In an unsupervised setting (with no prior-known K), the number of clusters for K means can be selected using the elbow method. 51 | K = length(unique(data_array[:,end])) 52 | 53 | # Run k-means and save the assignments 54 | kmeans_result = kmeans(data_t, K); 55 | assignment = kmeans_result.assignments; 56 | 57 | data_full = DataFrame(hcat(data, assignment, makeunique=true)); 58 | names!(data_full, vcat([Symbol(string("x",k)) for k in range(1,p-1)],[:true_labels, :kmean_assign])); 59 | 60 | # Prepare data for ICOT: features are stored in the matrix X, and the warm-start labels are stored in y 61 | X = data_full[:,1:(p-1)]; y = data_full[:,:true_labels]; 62 | 63 | # ##### Step 3a. Before running ICOT, start by testing the IAI license 64 | # lnr_oct = ICOT.IAI.OptimalTreeClassifier(localsearch = false, max_depth = maxdepth, 65 | # minbucket = min_bucket, 66 | # criterion = :misclassification 67 | # ) 68 | # grid = ICOT.IAI.GridSearch(lnr_oct) 69 | # ICOT.IAI.fit!(grid, X, y) 70 | # ICOT.IAI.showinbrowser(grid.lnr) 71 | 72 | ##### Step 3b. Run ICOT 73 | 74 | # Run ICOT with no warm-start: 75 | warm_start= :none 76 | lnr_ws_none = ICOT.InterpretableCluster(ls_num_tree_restarts = num_tree_restarts, ls_random_seed = seed, cp = complexity_c, max_depth = maxdepth, 77 | minbucket = min_bucket, criterion = cr, ls_warmstart_criterion = cr, kmeans_warmstart = warm_start, 78 | geom_search = geom_search, geom_threshold = threshold); 79 | run_time_icot_ls_none = @elapsed ICOT.fit!(lnr_ws_none, X, y); 80 | 81 | ICOT.showinbrowser(lnr_ws_none) 82 | 83 | score_ws_none = ICOT.score(lnr_ws_none, X, y, criterion=:dunnindex); 84 | score_al_ws_none = ICOT.score(lnr_ws_none, X, y, criterion=:silhouette); 85 | 86 | # # Run ICOT with an OCT warm-start: fit an OCT as a supervised learning problem with labels "y" and use this as the warm-start 87 | # warm_start= :oct 88 | # lnr_ws_oct = ICOT.InterpretableCluster(ls_num_tree_restarts = num_tree_restarts, ls_random_seed = seed, cp = complexity_c, max_depth = maxdepth, 89 | # minbucket = min_bucket, criterion = cr, ls_warmstart_criterion = cr, kmeans_warmstart = warm_start, 90 | # geom_search = geom_search, geom_threshold = threshold); 91 | # run_time_icot_ls_oct = @elapsed ICOT.fit!(lnr_ws_oct, X, y); 92 | # 93 | # score_ws_oct = ICOT.score(lnr_ws_oct, X, y, criterion=:dunnindex); 94 | # score_al_ws_oct = ICOT.score(lnr_ws_oct, X, y, criterion=:silhouette); 95 | # ICOT.showinbrowser(lnr_ws_oct) 96 | 97 | ##### Comments 98 | # Note that in this example, the OCT tree and ICOT (OCT warm-start, or no warm-start) result in the same solution. 99 | # This is not generally the case, but can occur when the data is easily separated (as in the ruspini dataset) 100 | 101 | # The score printed in the browser view is negative to reflect the formulation as a minimization problem. 102 | # The positive score returned by the ICOT.score() functions reflect the "correct" interpretation of the score, 103 | # in which we seek to maximize the criterion (with a maximum value of 1). 104 | 105 | # For larger datasets, we recommend setting warm_start = :oct and threshold = 0.99 to improve the solve time. 106 | -------------------------------------------------------------------------------- /scripts/src/ICOT/src/runningICOT_example1.jl: -------------------------------------------------------------------------------- 1 | using DataFrames, MLDataUtils 2 | using Clustering, Distances 3 | using CSV 4 | using Random 5 | using Logging 6 | using NPZ 7 | 8 | # Set up Logging - we recommend to use this command to avoid package warnings during the model training process. 9 | logger = Logging.SimpleLogger(stderr,Logging.Warn); 10 | global_logger(logger); 11 | 12 | #### Set parameters for the learners 13 | cr = :dunnindex 14 | method = "ICOT_local" 15 | warm_start = :none; 16 | geom_search = true 17 | threshold = 0.99 18 | seed = 1 19 | gridsearch = false 20 | num_tree_restarts = 100 21 | complexity_c = 0.0 22 | min_bucket = 10 23 | maxdepth = 10 24 | 25 | ###### Step 1: Prepare the data 26 | # Read the data - recommend the use of the (deprecated) readtable() command to avoid potential version conflicts with the CSV package. 27 | # data = readtable("../data/ruspini.csv"); 28 | 29 | data = npzread("../data/X1.npy") 30 | # data = data[1:200,:] 31 | true_labels = npzread("../data/Y1.npy") 32 | # true_labels = true_labels[1:200,:] 33 | 34 | # Convert the dataset to a matrix 35 | data = convert(Matrix{Float64}, data); 36 | true_labels = convert(Matrix{Float64},reshape(true_labels,:,1)); 37 | 38 | data = DataFrame(hcat(data,true_labels)); 39 | data_array = convert(Matrix{Float64}, data); 40 | 41 | # Get the number of observations and features 42 | n, p = size(data_array) 43 | data_t = data_array'; 44 | 45 | ##### Step 2: Fit K-means clustering on the dataset to generate a warm-start for ICOT 46 | #Fix the seed 47 | Random.seed!(seed); 48 | 49 | # The ruspini dataset has pre-defined clusters, which we will use to select the cluster count (K) for the K-means algorithm. 50 | # In an unsupervised setting (with no prior-known K), the number of clusters for K means can be selected using the elbow method. 51 | K = length(unique(data_array[:,end])) 52 | 53 | # Run k-means and save the assignments 54 | kmeans_result = kmeans(data_t, K); 55 | assignment = kmeans_result.assignments; 56 | 57 | data_full = DataFrame(hcat(data, assignment, makeunique=true)); 58 | names!(data_full, vcat([Symbol(string("x",k)) for k in range(1,p-1)],[:true_labels, :kmean_assign])); 59 | 60 | # Prepare data for ICOT: features are stored in the matrix X, and the warm-start labels are stored in y 61 | X = data_full[:,1:(p-1)]; y = data_full[:,:true_labels]; 62 | 63 | # ##### Step 3a. Before running ICOT, start by testing the IAI license 64 | # lnr_oct = ICOT.IAI.OptimalTreeClassifier(localsearch = false, max_depth = maxdepth, 65 | # minbucket = min_bucket, 66 | # criterion = :misclassification 67 | # ) 68 | # grid = ICOT.IAI.GridSearch(lnr_oct) 69 | # ICOT.IAI.fit!(grid, X, y) 70 | # ICOT.IAI.showinbrowser(grid.lnr) 71 | 72 | ##### Step 3b. Run ICOT 73 | 74 | # Run ICOT with no warm-start: 75 | warm_start= :none 76 | lnr_ws_none = ICOT.InterpretableCluster(ls_num_tree_restarts = num_tree_restarts, ls_random_seed = seed, cp = complexity_c, max_depth = maxdepth, 77 | minbucket = min_bucket, criterion = cr, ls_warmstart_criterion = cr, kmeans_warmstart = warm_start, 78 | geom_search = geom_search, geom_threshold = threshold); 79 | run_time_icot_ls_none = @elapsed ICOT.fit!(lnr_ws_none, X, y); 80 | 81 | ICOT.showinbrowser(lnr_ws_none) 82 | 83 | score_ws_none = ICOT.score(lnr_ws_none, X, y, criterion=:dunnindex); 84 | score_al_ws_none = ICOT.score(lnr_ws_none, X, y, criterion=:silhouette); 85 | 86 | # # Run ICOT with an OCT warm-start: fit an OCT as a supervised learning problem with labels "y" and use this as the warm-start 87 | # warm_start= :oct 88 | # lnr_ws_oct = ICOT.InterpretableCluster(ls_num_tree_restarts = num_tree_restarts, ls_random_seed = seed, cp = complexity_c, max_depth = maxdepth, 89 | # minbucket = min_bucket, criterion = cr, ls_warmstart_criterion = cr, kmeans_warmstart = warm_start, 90 | # geom_search = geom_search, geom_threshold = threshold); 91 | # run_time_icot_ls_oct = @elapsed ICOT.fit!(lnr_ws_oct, X, y); 92 | # 93 | # score_ws_oct = ICOT.score(lnr_ws_oct, X, y, criterion=:dunnindex); 94 | # score_al_ws_oct = ICOT.score(lnr_ws_oct, X, y, criterion=:silhouette); 95 | # ICOT.showinbrowser(lnr_ws_oct) 96 | 97 | ##### Comments 98 | # Note that in this example, the OCT tree and ICOT (OCT warm-start, or no warm-start) result in the same solution. 99 | # This is not generally the case, but can occur when the data is easily separated (as in the ruspini dataset) 100 | 101 | # The score printed in the browser view is negative to reflect the formulation as a minimization problem. 102 | # The positive score returned by the ICOT.score() functions reflect the "correct" interpretation of the score, 103 | # in which we seek to maximize the criterion (with a maximum value of 1). 104 | 105 | # For larger datasets, we recommend setting warm_start = :oct and threshold = 0.99 to improve the solve time. 106 | -------------------------------------------------------------------------------- /scripts/src/ICOT/src/runningICOT_example2.jl: -------------------------------------------------------------------------------- 1 | using DataFrames, MLDataUtils 2 | using Clustering, Distances 3 | using CSV 4 | using Random 5 | using Logging 6 | using NPZ 7 | 8 | # Set up Logging - we recommend to use this command to avoid package warnings during the model training process. 9 | logger = Logging.SimpleLogger(stderr,Logging.Warn); 10 | global_logger(logger); 11 | 12 | #### Set parameters for the learners 13 | cr = :dunnindex 14 | method = "ICOT_local" 15 | warm_start = :none; 16 | geom_search = true 17 | threshold = 0.99 18 | seed = 1 19 | gridsearch = false 20 | num_tree_restarts = 100 21 | complexity_c = 0.0 22 | min_bucket = 10 23 | maxdepth = 10 24 | 25 | ###### Step 1: Prepare the data 26 | # Read the data - recommend the use of the (deprecated) readtable() command to avoid potential version conflicts with the CSV package. 27 | # data = readtable("../data/ruspini.csv"); 28 | 29 | data = npzread("../data/X2.npy") 30 | # data = data[1:200,:] 31 | true_labels = npzread("../data/Y2.npy") 32 | # true_labels = true_labels[1:200,:] 33 | 34 | # Convert the dataset to a matrix 35 | data = convert(Matrix{Float64}, data); 36 | true_labels = convert(Matrix{Float64},reshape(true_labels,:,1)); 37 | 38 | data = DataFrame(hcat(data,true_labels)); 39 | data_array = convert(Matrix{Float64}, data); 40 | 41 | # Get the number of observations and features 42 | n, p = size(data_array) 43 | data_t = data_array'; 44 | 45 | ##### Step 2: Fit K-means clustering on the dataset to generate a warm-start for ICOT 46 | #Fix the seed 47 | Random.seed!(seed); 48 | 49 | # The ruspini dataset has pre-defined clusters, which we will use to select the cluster count (K) for the K-means algorithm. 50 | # In an unsupervised setting (with no prior-known K), the number of clusters for K means can be selected using the elbow method. 51 | K = length(unique(data_array[:,end])) 52 | 53 | # Run k-means and save the assignments 54 | kmeans_result = kmeans(data_t, K); 55 | assignment = kmeans_result.assignments; 56 | 57 | data_full = DataFrame(hcat(data, assignment, makeunique=true)); 58 | names!(data_full, vcat([Symbol(string("x",k)) for k in range(1,p-1)],[:true_labels, :kmean_assign])); 59 | 60 | # Prepare data for ICOT: features are stored in the matrix X, and the warm-start labels are stored in y 61 | X = data_full[:,1:(p-1)]; y = data_full[:,:true_labels]; 62 | 63 | # ##### Step 3a. Before running ICOT, start by testing the IAI license 64 | # lnr_oct = ICOT.IAI.OptimalTreeClassifier(localsearch = false, max_depth = maxdepth, 65 | # minbucket = min_bucket, 66 | # criterion = :misclassification 67 | # ) 68 | # grid = ICOT.IAI.GridSearch(lnr_oct) 69 | # ICOT.IAI.fit!(grid, X, y) 70 | # ICOT.IAI.showinbrowser(grid.lnr) 71 | 72 | ##### Step 3b. Run ICOT 73 | 74 | # Run ICOT with no warm-start: 75 | warm_start= :none 76 | lnr_ws_none = ICOT.InterpretableCluster(ls_num_tree_restarts = num_tree_restarts, ls_random_seed = seed, cp = complexity_c, max_depth = maxdepth, 77 | minbucket = min_bucket, criterion = cr, ls_warmstart_criterion = cr, kmeans_warmstart = warm_start, 78 | geom_search = geom_search, geom_threshold = threshold); 79 | run_time_icot_ls_none = @elapsed ICOT.fit!(lnr_ws_none, X, y); 80 | 81 | ICOT.showinbrowser(lnr_ws_none) 82 | 83 | score_ws_none = ICOT.score(lnr_ws_none, X, y, criterion=:dunnindex); 84 | score_al_ws_none = ICOT.score(lnr_ws_none, X, y, criterion=:silhouette); 85 | 86 | # # Run ICOT with an OCT warm-start: fit an OCT as a supervised learning problem with labels "y" and use this as the warm-start 87 | # warm_start= :oct 88 | # lnr_ws_oct = ICOT.InterpretableCluster(ls_num_tree_restarts = num_tree_restarts, ls_random_seed = seed, cp = complexity_c, max_depth = maxdepth, 89 | # minbucket = min_bucket, criterion = cr, ls_warmstart_criterion = cr, kmeans_warmstart = warm_start, 90 | # geom_search = geom_search, geom_threshold = threshold); 91 | # run_time_icot_ls_oct = @elapsed ICOT.fit!(lnr_ws_oct, X, y); 92 | # 93 | # score_ws_oct = ICOT.score(lnr_ws_oct, X, y, criterion=:dunnindex); 94 | # score_al_ws_oct = ICOT.score(lnr_ws_oct, X, y, criterion=:silhouette); 95 | # ICOT.showinbrowser(lnr_ws_oct) 96 | 97 | ##### Comments 98 | # Note that in this example, the OCT tree and ICOT (OCT warm-start, or no warm-start) result in the same solution. 99 | # This is not generally the case, but can occur when the data is easily separated (as in the ruspini dataset) 100 | 101 | # The score printed in the browser view is negative to reflect the formulation as a minimization problem. 102 | # The positive score returned by the ICOT.score() functions reflect the "correct" interpretation of the score, 103 | # in which we seek to maximize the criterion (with a maximum value of 1). 104 | 105 | # For larger datasets, we recommend setting warm_start = :oct and threshold = 0.99 to improve the solve time. 106 | -------------------------------------------------------------------------------- /scripts/src/ICOT/src/runningICOT_example3.jl: -------------------------------------------------------------------------------- 1 | using DataFrames, MLDataUtils 2 | using Clustering, Distances 3 | using CSV 4 | using Random 5 | using Logging 6 | using NPZ 7 | 8 | # Set up Logging - we recommend to use this command to avoid package warnings during the model training process. 9 | logger = Logging.SimpleLogger(stderr,Logging.Warn); 10 | global_logger(logger); 11 | 12 | #### Set parameters for the learners 13 | cr = :dunnindex 14 | method = "ICOT_local" 15 | warm_start = :none; 16 | geom_search = true 17 | threshold = 0.99 18 | seed = 1 19 | gridsearch = false 20 | num_tree_restarts = 100 21 | complexity_c = 0.0 22 | min_bucket = 10 23 | maxdepth = 10 24 | 25 | ###### Step 1: Prepare the data 26 | # Read the data - recommend the use of the (deprecated) readtable() command to avoid potential version conflicts with the CSV package. 27 | # data = readtable("../data/ruspini.csv"); 28 | 29 | data = npzread("../data/X3.npy") 30 | # data = data[1:200,:] 31 | true_labels = npzread("../data/Y3.npy") 32 | # true_labels = true_labels[1:200,:] 33 | 34 | # Convert the dataset to a matrix 35 | data = convert(Matrix{Float64}, data); 36 | true_labels = convert(Matrix{Float64},reshape(true_labels,:,1)); 37 | 38 | data = DataFrame(hcat(data,true_labels)); 39 | data_array = convert(Matrix{Float64}, data); 40 | 41 | # Get the number of observations and features 42 | n, p = size(data_array) 43 | data_t = data_array'; 44 | 45 | ##### Step 2: Fit K-means clustering on the dataset to generate a warm-start for ICOT 46 | #Fix the seed 47 | Random.seed!(seed); 48 | 49 | # The ruspini dataset has pre-defined clusters, which we will use to select the cluster count (K) for the K-means algorithm. 50 | # In an unsupervised setting (with no prior-known K), the number of clusters for K means can be selected using the elbow method. 51 | K = length(unique(data_array[:,end])) 52 | 53 | # Run k-means and save the assignments 54 | kmeans_result = kmeans(data_t, K); 55 | assignment = kmeans_result.assignments; 56 | 57 | data_full = DataFrame(hcat(data, assignment, makeunique=true)); 58 | names!(data_full, vcat([Symbol(string("x",k)) for k in range(1,p-1)],[:true_labels, :kmean_assign])); 59 | 60 | # Prepare data for ICOT: features are stored in the matrix X, and the warm-start labels are stored in y 61 | X = data_full[:,1:(p-1)]; y = data_full[:,:true_labels]; 62 | 63 | # ##### Step 3a. Before running ICOT, start by testing the IAI license 64 | # lnr_oct = ICOT.IAI.OptimalTreeClassifier(localsearch = false, max_depth = maxdepth, 65 | # minbucket = min_bucket, 66 | # criterion = :misclassification 67 | # ) 68 | # grid = ICOT.IAI.GridSearch(lnr_oct) 69 | # ICOT.IAI.fit!(grid, X, y) 70 | # ICOT.IAI.showinbrowser(grid.lnr) 71 | 72 | ##### Step 3b. Run ICOT 73 | 74 | # Run ICOT with no warm-start: 75 | warm_start= :none 76 | lnr_ws_none = ICOT.InterpretableCluster(ls_num_tree_restarts = num_tree_restarts, ls_random_seed = seed, cp = complexity_c, max_depth = maxdepth, 77 | minbucket = min_bucket, criterion = cr, ls_warmstart_criterion = cr, kmeans_warmstart = warm_start, 78 | geom_search = geom_search, geom_threshold = threshold); 79 | run_time_icot_ls_none = @elapsed ICOT.fit!(lnr_ws_none, X, y); 80 | 81 | ICOT.showinbrowser(lnr_ws_none) 82 | 83 | score_ws_none = ICOT.score(lnr_ws_none, X, y, criterion=:dunnindex); 84 | score_al_ws_none = ICOT.score(lnr_ws_none, X, y, criterion=:silhouette); 85 | 86 | # # Run ICOT with an OCT warm-start: fit an OCT as a supervised learning problem with labels "y" and use this as the warm-start 87 | # warm_start= :oct 88 | # lnr_ws_oct = ICOT.InterpretableCluster(ls_num_tree_restarts = num_tree_restarts, ls_random_seed = seed, cp = complexity_c, max_depth = maxdepth, 89 | # minbucket = min_bucket, criterion = cr, ls_warmstart_criterion = cr, kmeans_warmstart = warm_start, 90 | # geom_search = geom_search, geom_threshold = threshold); 91 | # run_time_icot_ls_oct = @elapsed ICOT.fit!(lnr_ws_oct, X, y); 92 | # 93 | # score_ws_oct = ICOT.score(lnr_ws_oct, X, y, criterion=:dunnindex); 94 | # score_al_ws_oct = ICOT.score(lnr_ws_oct, X, y, criterion=:silhouette); 95 | # ICOT.showinbrowser(lnr_ws_oct) 96 | 97 | ##### Comments 98 | # Note that in this example, the OCT tree and ICOT (OCT warm-start, or no warm-start) result in the same solution. 99 | # This is not generally the case, but can occur when the data is easily separated (as in the ruspini dataset) 100 | 101 | # The score printed in the browser view is negative to reflect the formulation as a minimization problem. 102 | # The positive score returned by the ICOT.score() functions reflect the "correct" interpretation of the score, 103 | # in which we seek to maximize the criterion (with a maximum value of 1). 104 | 105 | # For larger datasets, we recommend setting warm_start = :oct and threshold = 0.99 to improve the solve time. 106 | -------------------------------------------------------------------------------- /scripts/src/ICOT/src/runningICOT_example4.jl: -------------------------------------------------------------------------------- 1 | using DataFrames, MLDataUtils 2 | using Clustering, Distances 3 | using CSV 4 | using Random 5 | using Logging 6 | using NPZ 7 | 8 | # Set up Logging - we recommend to use this command to avoid package warnings during the model training process. 9 | logger = Logging.SimpleLogger(stderr,Logging.Warn); 10 | global_logger(logger); 11 | 12 | #### Set parameters for the learners 13 | cr = :dunnindex 14 | method = "ICOT_local" 15 | warm_start = :none; 16 | geom_search = true 17 | threshold = 0.99 18 | seed = 1 19 | gridsearch = false 20 | num_tree_restarts = 100 21 | complexity_c = 0.0 22 | min_bucket = 10 23 | maxdepth = 10 24 | 25 | ###### Step 1: Prepare the data 26 | # Read the data - recommend the use of the (deprecated) readtable() command to avoid potential version conflicts with the CSV package. 27 | # data = readtable("../data/ruspini.csv"); 28 | 29 | data = npzread("../data/X4.npy") 30 | # data = data[1:200,:] 31 | true_labels = npzread("../data/Y4.npy") 32 | # true_labels = true_labels[1:200,:] 33 | 34 | # Convert the dataset to a matrix 35 | data = convert(Matrix{Float64}, data); 36 | true_labels = convert(Matrix{Float64},reshape(true_labels,:,1)); 37 | 38 | data = DataFrame(hcat(data,true_labels)); 39 | data_array = convert(Matrix{Float64}, data); 40 | 41 | # Get the number of observations and features 42 | n, p = size(data_array) 43 | data_t = data_array'; 44 | 45 | ##### Step 2: Fit K-means clustering on the dataset to generate a warm-start for ICOT 46 | #Fix the seed 47 | Random.seed!(seed); 48 | 49 | # The ruspini dataset has pre-defined clusters, which we will use to select the cluster count (K) for the K-means algorithm. 50 | # In an unsupervised setting (with no prior-known K), the number of clusters for K means can be selected using the elbow method. 51 | K = length(unique(data_array[:,end])) 52 | 53 | # Run k-means and save the assignments 54 | kmeans_result = kmeans(data_t, K); 55 | assignment = kmeans_result.assignments; 56 | 57 | data_full = DataFrame(hcat(data, assignment, makeunique=true)); 58 | names!(data_full, vcat([Symbol(string("x",k)) for k in range(1,p-1)],[:true_labels, :kmean_assign])); 59 | 60 | # Prepare data for ICOT: features are stored in the matrix X, and the warm-start labels are stored in y 61 | X = data_full[:,1:(p-1)]; y = data_full[:,:true_labels]; 62 | 63 | # ##### Step 3a. Before running ICOT, start by testing the IAI license 64 | # lnr_oct = ICOT.IAI.OptimalTreeClassifier(localsearch = false, max_depth = maxdepth, 65 | # minbucket = min_bucket, 66 | # criterion = :misclassification 67 | # ) 68 | # grid = ICOT.IAI.GridSearch(lnr_oct) 69 | # ICOT.IAI.fit!(grid, X, y) 70 | # ICOT.IAI.showinbrowser(grid.lnr) 71 | 72 | ##### Step 3b. Run ICOT 73 | 74 | # Run ICOT with no warm-start: 75 | warm_start= :none 76 | lnr_ws_none = ICOT.InterpretableCluster(ls_num_tree_restarts = num_tree_restarts, ls_random_seed = seed, cp = complexity_c, max_depth = maxdepth, 77 | minbucket = min_bucket, criterion = cr, ls_warmstart_criterion = cr, kmeans_warmstart = warm_start, 78 | geom_search = geom_search, geom_threshold = threshold); 79 | run_time_icot_ls_none = @elapsed ICOT.fit!(lnr_ws_none, X, y); 80 | 81 | ICOT.showinbrowser(lnr_ws_none) 82 | 83 | score_ws_none = ICOT.score(lnr_ws_none, X, y, criterion=:dunnindex); 84 | score_al_ws_none = ICOT.score(lnr_ws_none, X, y, criterion=:silhouette); 85 | 86 | # # Run ICOT with an OCT warm-start: fit an OCT as a supervised learning problem with labels "y" and use this as the warm-start 87 | # warm_start= :oct 88 | # lnr_ws_oct = ICOT.InterpretableCluster(ls_num_tree_restarts = num_tree_restarts, ls_random_seed = seed, cp = complexity_c, max_depth = maxdepth, 89 | # minbucket = min_bucket, criterion = cr, ls_warmstart_criterion = cr, kmeans_warmstart = warm_start, 90 | # geom_search = geom_search, geom_threshold = threshold); 91 | # run_time_icot_ls_oct = @elapsed ICOT.fit!(lnr_ws_oct, X, y); 92 | # 93 | # score_ws_oct = ICOT.score(lnr_ws_oct, X, y, criterion=:dunnindex); 94 | # score_al_ws_oct = ICOT.score(lnr_ws_oct, X, y, criterion=:silhouette); 95 | # ICOT.showinbrowser(lnr_ws_oct) 96 | 97 | ##### Comments 98 | # Note that in this example, the OCT tree and ICOT (OCT warm-start, or no warm-start) result in the same solution. 99 | # This is not generally the case, but can occur when the data is easily separated (as in the ruspini dataset) 100 | 101 | # The score printed in the browser view is negative to reflect the formulation as a minimization problem. 102 | # The positive score returned by the ICOT.score() functions reflect the "correct" interpretation of the score, 103 | # in which we seek to maximize the criterion (with a maximum value of 1). 104 | 105 | # For larger datasets, we recommend setting warm_start = :oct and threshold = 0.99 to improve the solve time. 106 | -------------------------------------------------------------------------------- /scripts/src/ICOT/README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | This is the documentation repository for the clustering algorithm of the paper "Interpretable Clustering: An Optimization Approach" by Dimitris Bertsimas, Agni Orfanoudaki, and Holly Wiberg. The purpose of this method, ICOT, is to generate interpretable tree-based clustering models. 3 | 4 | # Academic License and Installation 5 | 6 | This code runs in Julia version 1.1.0, which can be downloaded at the following links: 7 | * Linux: https://julialang-s3.julialang.org/bin/linux/x64/1.1/julia-1.1.0-linux-x86_64.tar.gz 8 | * Mac: https://julialang-s3.julialang.org/bin/mac/x64/1.1/julia-1.1.0-mac64.dmg 9 | 10 | *Note: version 1.1.0 is required for compatibility with the package.* 11 | 12 | The ICOT software package uses tools from the [Interpretable AI](https://www.interpretable.ai/) suite and thus it requires an academic license. 13 | 14 | You can download the system image the following links: 15 | * [Linux](https://iai-system-images.s3.amazonaws.com/icot/linux/julia1.1.0/v1.0/sys-linux-julia1.1.0-iai0.1.0-878.zip) 16 | * [Mac](https://iai-system-images.s3.amazonaws.com/icot/macos/julia1.1.0/v1.0/sys-macos-julia1.1.0-iai0.1.0-878.zip) 17 | 18 | You can find detailed installation guidelines for the system image [here](https://docs.interpretable.ai/stable/installation/). 19 | 20 | Once you have completed the installation you will be presented with a machine ID. You can request an academic license by emailing  with your academic institution address and the subject line "Request for ICOT License". Please include the machine ID in the email, and Interpretable AI will generate a license file for your machine. 21 | 22 | # Algorithm Guidelines 23 | 24 | The main command to run the algorithm on a dataset `X` is `ICOT.fit!(learner, X, y);` where the `y` can refer to some data partition that is associated with the dataset. The `learner` is defined as an `ICOT.InterpretableCluster()` object with the following parameters: 25 | * `criterion`: defines the internal validation criterion used to train the ICOT algorithm. The algorithm accepts to options `:dunnindex` ([Dunn 1974](https://www.tandfonline.com/doi/abs/10.1080/01969727408546059)) and `:silhouette` ([Rousseeuw 1987](https://www.sciencedirect.com/science/article/pii/0377042787901257)). 26 | * `ls_warmstart_criterion`: defines the internal validation criterion used to create the initial solution of the warmstart. The same options are offered with the `criterion` parameter. 27 | * `kmeans_warmstart`: provides a warmstart solution to initialize the algorithm. Details are provided in Section 3.3.2 of the [paper](https://link.springer.com/article/10.1007/s10994-020-05896-2). It can take as input `:none`, `:greedy`, and `:oct`. The OCT option uses user-selected labels (i.e. from K-means) to fit an Optimal Classification Tree as a supervised learning problem to provide a warm-start to the algorithm. The greedy option fits a CART tree to these labels. 28 | * `geom_search`: is a boolean parameter that controls where the algorithm will enable the geometric component of the feature space search. See details in Section 3.3.1 of the [paper](https://link.springer.com/article/10.1007/s10994-020-05896-2). 29 | * `geom_threshold`: refers to the percentile of gaps that will be considered by the geometric search for each feature. For example: 0.99. 30 | * `minbucket`: controls the minimum number of points that must be present in every leaf node of the fitted tree. 31 | * `max_depth`: accepts a non-negative Integer to control the maximum depth of the fitted tree. This parameter must always be explicitly set or tuned. We recommend tuning this parameter using the grid search process described in the guide to parameter tuning. 32 | * `ls_random_seed`: is an integer controlling the randomized state of the algorithm. We recommend to set the seed to ensure reproducability of results. 33 | * `ls_num_tree_restarts`: is an integer specifying the number of random restarts to use in the local search algorithm. Must be positive and defaults to 100. The performance of the tree typically increases as this value is increased, but with quickly diminishing returns. The computational cost of training increases linearly with this value. 34 | * `cp`: the complexity parameter that determines the tradeoff between the accuracy and complexity of the tree to control overfitting, as commonly seen in supervised learning problems. The internal validation criteria used for this unsupervised algorithm naturally limit the tree complexity, we recommend to set the value to 0.0. 35 | 36 | You can visualize your model on a browser using the `ICOT.showinbrowser()` command. 37 | 38 | You can evaluate the score on a trained ICOT model using the `score_al_ws_oct = ICOT.score(learner, X, y, criterion);` command. 39 | 40 | We have added an example for the ruspini dataset in the `src` folder called `runningICOT_example.jl`. 41 | 42 | # Citing ICOT 43 | If you use ICOT in your research, we kindly ask that you reference the original [paper](https://link.springer.com/article/10.1007/s10994-020-05896-2) that first introduced the algorithm: 44 | 45 | ``` 46 | @article{bertsimas2020interpretable, 47 | title={Interpretable clustering: an optimization approach}, 48 | author={Bertsimas, Dimitris and Orfanoudaki, Agni and Wiberg, Holly}, 49 | journal={Machine Learning}, 50 | pages={1--50}, 51 | year={2020}, 52 | publisher={Springer} 53 | } 54 | ``` 55 | 56 | -------------------------------------------------------------------------------- /scripts/cmt+.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from src.mst import MST 4 | from src.cart_customerfeats_in_leafmod import append_customer_features_to_product_features 5 | import pickle 6 | import sys 7 | 8 | c_features = ["GROUP", "PURPOSE", "FIRST", "TICKET", "WHO", "LUGGAGE", "AGE", 9 | "MALE", "INCOME", "GA", "ORIGIN", "DEST"] 10 | p_features = ["TRAIN_AV", "SM_AV", "CAR_AV", 11 | "TRAIN_TT", "SM_TT", "CAR_TT", 12 | "TRAIN_CO", "SM_CO", "CAR_CO", 13 | "TRAIN_HE", "SM_HE", "CAR_HE"] 14 | 15 | num_features = 4 16 | model_type = 0 17 | is_bias = True 18 | 19 | #scores per data set, pre prune post prune, MSE + LL (test,train) 20 | scores_np = np.zeros((10,2,4)) 21 | 22 | scores_np = np.zeros((10,2,6,15)) 23 | 24 | for i in range(10): 25 | num_features = 4 26 | X = np.load('data/X'+str(i)+'.npy') 27 | P = np.load('data/P'+str(i)+'.npy') 28 | Y = np.load('data/Y'+str(i)+'.npy') 29 | 30 | XV = np.load('data/XV'+str(i)+'.npy') 31 | PV = np.load('data/PV'+str(i)+'.npy') 32 | YV = np.load('data/YV'+str(i)+'.npy') 33 | 34 | XT = np.load('data/XT'+str(i)+'.npy') 35 | PT = np.load('data/PT'+str(i)+'.npy') 36 | YT = np.load('data/YT'+str(i)+'.npy') 37 | 38 | P = P.astype(float) 39 | PV = PV.astype(float) 40 | PT = PT.astype(float) 41 | P[:,-1] = 0.001*np.random.rand(P.shape[0]) 42 | PV[:,-1] = 0.001*np.random.rand(PV.shape[0]) 43 | PT[:,-1] = 0.001*np.random.rand(PT.shape[0]) 44 | 45 | 46 | is_continuous = [False for k in range(len(c_features))] 47 | is_continuous[6] = True 48 | is_continuous[8] = False 49 | 50 | P, PV, PT, num_features = append_customer_features_to_product_features(X, XV, XT, 51 | P, PV, PT, 52 | feats_continuous=is_continuous, 53 | model_type=model_type, num_features=num_features) 54 | for depth in [14]: 55 | # for depth in range(15): 56 | #APPLY TREE ALGORITHM. TRAIN TO DEPTH 1 57 | my_tree = MST(max_depth = depth, min_weights_per_node = 50, only_singleton_splits = True, quant_discret = 0.05) 58 | my_tree.fit(X,P,Y,verbose=True, 59 | feats_continuous= is_continuous, 60 | refit_leaves=True,only_singleton_splits = True, 61 | num_features = num_features, is_bias = is_bias, model_type = model_type, 62 | mode = "mnl", batch_size = 100, epochs = 50, steps = 6000, 63 | leaf_mod_thresh=10000000,loglik_proba_cap=0.01); 64 | #ABOVE: leaf_mod_thresh controls whether when fitting a leaf node we apply Newton's method or stochastic gradient descent. 65 | # If the number of training observations in a leaf node <= leaf_mod_thresh, then newton's method 66 | # is applied; otherwise, stochastic gradient descent is applied. 67 | 68 | 69 | YT_flat = np.zeros((YT.shape[0],3)) 70 | YT_flat[np.arange(YT.shape[0]),YT] = 1 71 | Y_flat = np.zeros((Y.shape[0],3)) 72 | Y_flat[np.arange(Y.shape[0]),Y] = 1 73 | 74 | Y_pred = my_tree.predict(XT,PT) 75 | s_Y = Y_pred.shape[0] 76 | scores_np[i,0,0,depth] = np.mean(np.log(np.maximum(0.01,Y_pred[np.arange(s_Y),YT]))) 77 | scores_np[i,0,1,depth] = np.mean(np.sum(np.power(Y_pred-YT_flat,2),axis = 1)) 78 | 79 | Y_pred = my_tree.predict(X,P) 80 | s_Y = Y_pred.shape[0] 81 | scores_np[i,0,2,depth] = np.mean(np.log(np.maximum(0.01,Y_pred[np.arange(s_Y),Y]))) 82 | scores_np[i,0,3,depth] = np.mean(np.sum(np.power(Y_pred-Y_flat,2),axis = 1)) 83 | 84 | scores_np[i,0,4,depth] = sum(map(lambda x: x.is_leaf,my_tree.tree)) 85 | scores_np[i,0,5,depth] = sum(map(lambda x: x.alpha_thresh < np.inf,my_tree.tree)) 86 | 87 | # my_tree.traverse(verbose=True) 88 | 89 | # Post-pruning metrics 90 | my_tree.prune(XV, PV, YV, verbose=False) 91 | 92 | # my_tree.traverse(verbose=True) 93 | 94 | Y_pred = my_tree.predict(XT,PT) 95 | s_Y = Y_pred.shape[0] 96 | scores_np[i,1,0,depth] = np.mean(np.log(np.maximum(0.01,Y_pred[np.arange(s_Y),YT]))) 97 | scores_np[i,1,1,depth] = np.mean(np.sum(np.power(Y_pred-YT_flat,2),axis = 1)) 98 | 99 | Y_pred = my_tree.predict(X,P) 100 | s_Y = Y_pred.shape[0] 101 | scores_np[i,1,2,depth] = np.mean(np.log(np.maximum(0.01,Y_pred[np.arange(s_Y),Y]))) 102 | scores_np[i,1,3,depth] = np.mean(np.sum(np.power(Y_pred-Y_flat,2),axis = 1)) 103 | 104 | scores_np[i,1,4,depth] = sum(map(lambda x: x.is_leaf,my_tree.tree)) 105 | scores_np[i,1,5,depth] = sum(map(lambda x: x.alpha_thresh < np.inf,my_tree.tree)) 106 | 107 | 108 | # f = open('CMTtree+'+str(i)+'.p', 'w') 109 | # pickle.dump(my_tree, f) 110 | # f.close() 111 | # 112 | # original_stdout = sys.stdout 113 | # with open('CMTtree+'+str(i)+'.txt', 'w') as f: 114 | # sys.stdout = f; 115 | # my_tree.traverse(verbose=True); 116 | # sys.stdout = original_stdout; -------------------------------------------------------------------------------- /scripts/cmt.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from src.mst import MST 4 | from src.cart_customerfeats_in_leafmod import append_customer_features_to_product_features 5 | 6 | import sys 7 | import pickle 8 | 9 | c_features = ["GROUP", "PURPOSE", "FIRST", "TICKET", "WHO", "LUGGAGE", "AGE", 10 | "MALE", "INCOME", "GA", "ORIGIN", "DEST"] 11 | p_features = ["TRAIN_AV", "SM_AV", "CAR_AV", 12 | "TRAIN_TT", "SM_TT", "CAR_TT", 13 | "TRAIN_CO", "SM_CO", "CAR_CO", 14 | "TRAIN_HE", "SM_HE", "CAR_HE"] 15 | 16 | num_features = 4 17 | model_type = 0 18 | is_bias = True 19 | 20 | #scores per data set, pre prune post prune, MSE + LL (test,train) 21 | scores_np = np.zeros((10,2,4)) 22 | 23 | scores_np = np.zeros((10,2,6,15)) 24 | 25 | for i in range(10): 26 | #for i in [9]: 27 | num_features = 4 28 | X = np.load('data/X'+str(i)+'.npy') 29 | P = np.load('data/P'+str(i)+'.npy') 30 | Y = np.load('data/Y'+str(i)+'.npy') 31 | 32 | XV = np.load('data/XV'+str(i)+'.npy') 33 | PV = np.load('data/PV'+str(i)+'.npy') 34 | YV = np.load('data/YV'+str(i)+'.npy') 35 | 36 | XT = np.load('data/XT'+str(i)+'.npy') 37 | PT = np.load('data/PT'+str(i)+'.npy') 38 | YT = np.load('data/YT'+str(i)+'.npy') 39 | 40 | P = P.astype(float) 41 | PV = PV.astype(float) 42 | PT = PT.astype(float) 43 | #very small perturbation because car does not have variation in headway 44 | P[:,-1] = 0.001*np.random.rand(P.shape[0]) 45 | PV[:,-1] = 0.001*np.random.rand(PV.shape[0]) 46 | PT[:,-1] = 0.001*np.random.rand(PT.shape[0]) 47 | 48 | 49 | is_continuous = [False for k in range(len(c_features))] 50 | is_continuous[6] = True 51 | is_continuous[8] = False 52 | # 53 | # P, PV, PT, num_features = append_customer_features_to_product_features(X, XV, XT, 54 | # P, PV, PT, 55 | # feats_continuous=is_continuous, 56 | # model_type=model_type, num_features=num_features) 57 | for depth in [14]: 58 | # for depth in [3,4]: 59 | # for depth in range(15): 60 | #APPLY TREE ALGORITHM. TRAIN TO DEPTH 1 61 | my_tree = MST(max_depth = depth, min_weights_per_node = 50, only_singleton_splits = True, quant_discret = 0.05) 62 | my_tree.fit(X,P,Y,verbose=True, 63 | feats_continuous= is_continuous, 64 | refit_leaves=True,only_singleton_splits = True, 65 | num_features = num_features, is_bias = is_bias, model_type = model_type, 66 | mode = "mnl", batch_size = 100, epochs = 50, steps = 6000, 67 | leaf_mod_thresh=10000000,loglik_proba_cap=0.01); 68 | #ABOVE: leaf_mod_thresh controls whether when fitting a leaf node we apply Newton's method or stochastic gradient descent. 69 | # If the number of training observations in a leaf node <= leaf_mod_thresh, then newton's method 70 | # is applied; otherwise, stochastic gradient descent is applied. 71 | 72 | 73 | YT_flat = np.zeros((YT.shape[0],3)) 74 | YT_flat[np.arange(YT.shape[0]),YT] = 1 75 | YV_flat = np.zeros((YV.shape[0],3)) 76 | YV_flat[np.arange(YV.shape[0]),YV] = 1 77 | Y_flat = np.zeros((Y.shape[0],3)) 78 | Y_flat[np.arange(Y.shape[0]),Y] = 1 79 | 80 | Y_pred = my_tree.predict(XT,PT) 81 | s_Y = Y_pred.shape[0] 82 | scores_np[i,0,0,depth] = np.mean(np.log(np.maximum(0.01,Y_pred[np.arange(s_Y),YT]))) 83 | scores_np[i,0,1,depth] = np.mean(np.sum(np.power(Y_pred-YT_flat,2),axis = 1)) 84 | 85 | Y_pred = my_tree.predict(X,P) 86 | s_Y = Y_pred.shape[0] 87 | scores_np[i,0,2,depth] = np.mean(np.log(np.maximum(0.01,Y_pred[np.arange(s_Y),Y]))) 88 | scores_np[i,0,3,depth] = np.mean(np.sum(np.power(Y_pred-Y_flat,2),axis = 1)) 89 | 90 | scores_np[i,0,4,depth] = sum(map(lambda x: x.is_leaf,my_tree.tree)) 91 | scores_np[i,0,5,depth] = sum(map(lambda x: x.alpha_thresh < np.inf,my_tree.tree)) 92 | 93 | # my_tree.traverse(verbose=True) 94 | 95 | # Post-pruning metrics 96 | my_tree.prune(XV, PV, YV, verbose=False, one_SE_rule = False) 97 | 98 | # my_tree.traverse(verbose=True) 99 | 100 | Y_pred = my_tree.predict(XT,PT) 101 | s_Y = Y_pred.shape[0] 102 | scores_np[i,1,0,depth] = np.mean(np.log(np.maximum(0.01,Y_pred[np.arange(s_Y),YT]))) 103 | scores_np[i,1,1,depth] = np.mean(np.sum(np.power(Y_pred-YT_flat,2),axis = 1)) 104 | 105 | Y_pred = my_tree.predict(X,P) 106 | s_Y = Y_pred.shape[0] 107 | scores_np[i,1,2,depth] = np.mean(np.log(np.maximum(0.01,Y_pred[np.arange(s_Y),Y]))) 108 | scores_np[i,1,3,depth] = np.mean(np.sum(np.power(Y_pred-Y_flat,2),axis = 1)) 109 | 110 | scores_np[i,1,4,depth] = sum(map(lambda x: x.is_leaf,my_tree.tree)) 111 | scores_np[i,1,5,depth] = sum(map(lambda x: x.alpha_thresh < np.inf,my_tree.tree)) 112 | 113 | # f = open('CMTtree'+str(i)+'.p', 'w') 114 | # pickle.dump(my_tree, f) 115 | # f.close() 116 | # 117 | # original_stdout = sys.stdout 118 | # with open('CMTtree'+str(i)+'.txt', 'w') as f: 119 | # sys.stdout = f; 120 | # my_tree.traverse(verbose=True); 121 | # sys.stdout = original_stdout; 122 | 123 | 124 | -------------------------------------------------------------------------------- /scripts/mnldt+.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This example shows how to apply the MST CMT tree algorithm to a simple synthetic dataset 3 | The ground truth of the synethetic dataset is a CMT of depth 1 4 | 5 | We will train the CMT on this dataset and observe whether it recovers the true CMT used to generate the data 6 | 7 | n = number of historical customers (i.e., training observations) 8 | X is a n x 3 matrix containing customers' contextual information 9 | 3 contexts: 10 | X0: binary {0,1} 11 | X1: multi-categorical {0, 1, 2, 3} 12 | X2: ordinal in {0,0.2,0.4,0.6,0.8,1} 13 | 14 | There are 5 products total. Each customer sees a random assortment of 3 of these 5 products and chooses his favorite product in the assortment. 15 | P[:,:5] encodes the offered assortment: P[i,j] = 1 iff item j was offered to customer i 16 | There are two other product features (besides the assortment indicators) which can be interpreted as price and quality rating. 17 | Product prices are stored in P[:,5:10]. P[i,(j+5)] = price of item j offered to customer i 18 | Quality ratings are stored in P[:,10:15]. P[i,(j+10)] = quality rating of item j offered to customer i 19 | 20 | Y is an n-dim vector, where Y[i] in {0,1,2,3,4} encodes customer i's choice among the products 21 | 22 | The CMT used to generate the data consists of a single split (x2 <= 0.6), with MNLs in each leaf with randomly-generated coefs 23 | 24 | NOTE: for this to run properly, include the following import statement in mst.py: "from leaf_model_mnl import *" 25 | ''' 26 | 27 | #from mst import MST 28 | #from GenMNL import GenMNL 29 | #from responsemodel_kmeans import response_model_kmeans_fit_and_predict 30 | from src.cart_customerfeats_in_leafmod import mnl_cart_fit_prune_and_predict, append_customer_features_to_product_features 31 | import numpy as np 32 | import pandas as pd 33 | 34 | np.set_printoptions(suppress=True) #suppress scientific notation 35 | np.random.seed(0) 36 | 37 | 38 | c_features = ["GROUP", "PURPOSE", "FIRST", "TICKET", "WHO", "LUGGAGE", "AGE", 39 | "MALE", "INCOME", "GA", "ORIGIN", "DEST"] 40 | p_features = ["TRAIN_AV", "SM_AV", "CAR_AV", 41 | "TRAIN_TT", "SM_TT", "CAR_TT", 42 | "TRAIN_CO", "SM_CO", "CAR_CO", 43 | "TRAIN_HE", "SM_HE", "CAR_HE"] 44 | 45 | num_features = 4 46 | model_type = 0 47 | is_bias = True 48 | is_continuous = [False for k in range(len(c_features))] 49 | is_continuous[6] = True 50 | is_continuous[8] = False 51 | 52 | scores_np = np.zeros((10,2,5,15)) 53 | 54 | for i in range(10): 55 | 56 | X = np.load('data/X'+str(i)+'.npy') 57 | P = np.load('data/P'+str(i)+'.npy') 58 | Y = np.load('data/Y'+str(i)+'.npy') 59 | 60 | XV = np.load('data/XV'+str(i)+'.npy') 61 | PV = np.load('data/PV'+str(i)+'.npy') 62 | YV = np.load('data/YV'+str(i)+'.npy') 63 | 64 | XT = np.load('data/XT'+str(i)+'.npy') 65 | PT = np.load('data/PT'+str(i)+'.npy') 66 | YT = np.load('data/YT'+str(i)+'.npy') 67 | 68 | P = P.astype(float) 69 | PV = PV.astype(float) 70 | PT = PT.astype(float) 71 | 72 | P[:,-1] = 0.001*np.random.rand(P.shape[0]) 73 | PV[:,-1] = 0.001*np.random.rand(PV.shape[0]) 74 | PT[:,-1] = 0.001*np.random.rand(PT.shape[0]) 75 | 76 | ############################################################################# 77 | #(3) Run MNL-CART (CART with MNL response model refit in each leaf) 78 | print("Running MNL-CART") 79 | 80 | for d in [14]: 81 | # for d in range(3,5): 82 | #In this benchmark, we fit MNLs in each leaf using *both* the customer features and product features 83 | #Therefore, we use this function to add the customer features (X) to the MNL features matrix (P) 84 | #NOTE: this function handles binarization of customer features (X) internally prior to appending to product feature matrix P 85 | #----Specifically, the function will binarize all customer features in X satisfying feats_continuous = False 86 | #NOTE: if model_type = 1 (alternative-general coefs), then this function still encodes Pnew in such a way that the customer features have alt-specific coefs 87 | # Pnew, PVnew, PTnew, num_features_new = P,PV,PT,num_features 88 | Pnew, PVnew, PTnew, num_features_new = append_customer_features_to_product_features(X, XV, XT, 89 | P, PV, PT, 90 | feats_continuous=is_continuous, 91 | model_type=model_type, num_features=num_features) 92 | #fit MNL-CART and output test set predictions. See code cart_customerfeats_in_leafmod.py for more details 93 | Y_predT,my_tree = mnl_cart_fit_prune_and_predict(X,Pnew,Y, 94 | XV,PVnew,YV, 95 | XT,PTnew, 96 | feats_continuous=is_continuous, 97 | verbose=True, 98 | one_SE_rule=True, 99 | max_depth=d, min_weights_per_node=50, quant_discret=0.05, 100 | run_in_parallel=False,num_workers=None, 101 | num_features = num_features_new, is_bias = is_bias, model_type = model_type, 102 | mode = "mnl", batch_size = 100, epochs = 100, steps = 5000, 103 | leaf_mod_thresh=1000000000000) 104 | 105 | 106 | YT_flat = np.zeros((YT.shape[0],3)) 107 | YT_flat[np.arange(YT.shape[0]),YT] = 1 108 | 109 | s_Y = Y_predT.shape[0] 110 | scores_np[i,0,0,d] = np.mean(np.log(np.maximum(0.01,Y_predT[np.arange(s_Y),YT]))) 111 | scores_np[i,0,1,d] = np.mean(np.sum(np.power(Y_predT-YT_flat,2),axis = 1)) 112 | 113 | -------------------------------------------------------------------------------- /scripts/mnldt.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This example shows how to apply the MST CMT tree algorithm to a simple synthetic dataset 3 | The ground truth of the synethetic dataset is a CMT of depth 1 4 | 5 | We will train the CMT on this dataset and observe whether it recovers the true CMT used to generate the data 6 | 7 | n = number of historical customers (i.e., training observations) 8 | X is a n x 3 matrix containing customers' contextual information 9 | 3 contexts: 10 | X0: binary {0,1} 11 | X1: multi-categorical {0, 1, 2, 3} 12 | X2: ordinal in {0,0.2,0.4,0.6,0.8,1} 13 | 14 | There are 5 products total. Each customer sees a random assortment of 3 of these 5 products and chooses his favorite product in the assortment. 15 | P[:,:5] encodes the offered assortment: P[i,j] = 1 iff item j was offered to customer i 16 | There are two other product features (besides the assortment indicators) which can be interpreted as price and quality rating. 17 | Product prices are stored in P[:,5:10]. P[i,(j+5)] = price of item j offered to customer i 18 | Quality ratings are stored in P[:,10:15]. P[i,(j+10)] = quality rating of item j offered to customer i 19 | 20 | Y is an n-dim vector, where Y[i] in {0,1,2,3,4} encodes customer i's choice among the products 21 | 22 | The CMT used to generate the data consists of a single split (x2 <= 0.6), with MNLs in each leaf with randomly-generated coefs 23 | 24 | NOTE: for this to run properly, include the following import statement in mst.py: "from leaf_model_mnl import *" 25 | ''' 26 | 27 | #from mst import MST 28 | #from GenMNL import GenMNL 29 | #from responsemodel_kmeans import response_model_kmeans_fit_and_predict 30 | from src.cart_customerfeats_in_leafmod import mnl_cart_fit_prune_and_predict, append_customer_features_to_product_features 31 | import numpy as np 32 | import pandas as pd 33 | 34 | np.set_printoptions(suppress=True) #suppress scientific notation 35 | np.random.seed(0) 36 | 37 | 38 | c_features = ["GROUP", "PURPOSE", "FIRST", "TICKET", "WHO", "LUGGAGE", "AGE", 39 | "MALE", "INCOME", "GA", "ORIGIN", "DEST"] 40 | p_features = ["TRAIN_AV", "SM_AV", "CAR_AV", 41 | "TRAIN_TT", "SM_TT", "CAR_TT", 42 | "TRAIN_CO", "SM_CO", "CAR_CO", 43 | "TRAIN_HE", "SM_HE", "CAR_HE"] 44 | 45 | num_features = 4 46 | model_type = 0 47 | is_bias = True 48 | is_continuous = [False for k in range(len(c_features))] 49 | is_continuous[6] = True 50 | is_continuous[8] = False 51 | 52 | scores_np = np.zeros((10,2,5,15)) 53 | 54 | for i in range(10): 55 | 56 | X = np.load('data/X'+str(i)+'.npy') 57 | P = np.load('data/P'+str(i)+'.npy') 58 | Y = np.load('data/Y'+str(i)+'.npy') 59 | 60 | XV = np.load('data/XV'+str(i)+'.npy') 61 | PV = np.load('data/PV'+str(i)+'.npy') 62 | YV = np.load('data/YV'+str(i)+'.npy') 63 | 64 | XT = np.load('data/XT'+str(i)+'.npy') 65 | PT = np.load('data/PT'+str(i)+'.npy') 66 | YT = np.load('data/YT'+str(i)+'.npy') 67 | 68 | P = P.astype(float) 69 | PV = PV.astype(float) 70 | PT = PT.astype(float) 71 | 72 | P[:,-1] = 0.001*np.random.rand(P.shape[0]) 73 | PV[:,-1] = 0.001*np.random.rand(PV.shape[0]) 74 | PT[:,-1] = 0.001*np.random.rand(PT.shape[0]) 75 | 76 | ############################################################################# 77 | #(3) Run MNL-CART (CART with MNL response model refit in each leaf) 78 | print("Running MNL-CART") 79 | 80 | for d in [14]: 81 | # for d in range(3,5): 82 | #In this benchmark, we fit MNLs in each leaf using *both* the customer features and product features 83 | #Therefore, we use this function to add the customer features (X) to the MNL features matrix (P) 84 | #NOTE: this function handles binarization of customer features (X) internally prior to appending to product feature matrix P 85 | #----Specifically, the function will binarize all customer features in X satisfying feats_continuous = False 86 | #NOTE: if model_type = 1 (alternative-general coefs), then this function still encodes Pnew in such a way that the customer features have alt-specific coefs 87 | Pnew, PVnew, PTnew, num_features_new = P,PV,PT,num_features 88 | # Pnew, PVnew, PTnew, num_features_new = append_customer_features_to_product_features(X, XV, XT, 89 | # P, PV, PT, 90 | # feats_continuous=is_continuous, 91 | # model_type=model_type, num_features=num_features) 92 | #fit MNL-CART and output test set predictions. See code cart_customerfeats_in_leafmod.py for more details 93 | Y_predT,my_tree = mnl_cart_fit_prune_and_predict(X,Pnew,Y, 94 | XV,PVnew,YV, 95 | XT,PTnew, 96 | feats_continuous=is_continuous, 97 | verbose=True, 98 | one_SE_rule=True, 99 | max_depth=d, min_weights_per_node=50, quant_discret=0.05, 100 | run_in_parallel=False,num_workers=None, 101 | num_features = num_features_new, is_bias = is_bias, model_type = model_type, 102 | mode = "mnl", batch_size = 100, epochs = 100, steps = 5000, 103 | leaf_mod_thresh=1000000000000) 104 | 105 | 106 | YT_flat = np.zeros((YT.shape[0],3)) 107 | YT_flat[np.arange(YT.shape[0]),YT] = 1 108 | 109 | s_Y = Y_predT.shape[0] 110 | scores_np[i,0,0,d] = np.mean(np.log(np.maximum(0.01,Y_predT[np.arange(s_Y),YT]))) 111 | scores_np[i,0,1,d] = np.mean(np.sum(np.power(Y_predT-YT_flat,2),axis = 1)) 112 | 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MST 2 | 3 | This is a Python code base for training "Market Segmentation Trees" (MSTs) (formerly known as "Model Trees for Personalization" (MTPs)). MSTs provide a general framework for jointly performing market segmentation and response modeling. The folder `/scripts` contains the code relative to the case study on Swiss Metro data. Currently, this repo only supports Python 2.7 and does not support Python 3. 4 | 5 | Link to paper: [https://arxiv.org/abs/1906.01174](https://arxiv.org/abs/1906.01174) 6 | 7 | "mst.py" contains the Python class "MST" used for training MSTs. This class supports the following methods: 8 | * `__init__()`: initializes the MST 9 | * `fit()`: trains the MST on data: contexts X, decisions P (labeled as A in this code base), responses Y 10 | * `traverse()`: prints out the learned MST 11 | * `prune()`: prunes the tree on a held-out validation set to prevent overfitting 12 | * `predict()`: predict response distribution given new contexts X and decisions P 13 | 14 | Users can build their own response models for the MST class by filling in the template "leaf_model_template.py". 15 | 16 | Two examples of MSTs are provided here: 17 | (1) "Isotonic Regression Model Trees" (IRTs): Here, the response models are isotonic regression models. The "irt_exmaple.py" file provides an example of running IRTs on a synthetic dataset. 18 | (2) "Choice Model Trees" (CMTs): Here, the response models are MNL choice models. The "cmt_exmaple.py" file provides an example of running CMTs on a synthetic dataset. 19 | 20 | For significantly faster CMT fitting, users can combine this repo with the MNL fitting code found at https://github.com/rtm2130/CMT-R. Please see the CMT-R repo for further instructions. Note that the CMT-R repo is under a GPLv2 license, which is more restrictive on terms of use than this repo's MIT license. The case study on the Swiss Metro data from our paper was run using the code from the CMT-R repo. 21 | 22 | ## Package Installation 23 | 24 | Here we provide guidance on installing the MST package dependencies excluding the files located in the CMT-R repo (https://github.com/rtm2130/CMT-R). For the complete installation instructions including the CMT-R files, see the README.md file of the CMT-R repo. 25 | 26 | ### Prerequisites 27 | 28 | First, clone the MST repo. This can be done through opening a command prompt / terminal and typing: `git clone https://github.com/rtm2130/MST.git` (if this command does not work then install git). 29 | 30 | Install the conda command-line tool. This can be accomplished through installing miniforge, miniconda, or anaconda. We advise users to consult the license terms of use for these tools because as of 2023-02-26 miniconda and anaconda are not free for commercial use. 31 | 32 | Open a command prompt / terminal and execute the following steps: 33 | 1. Update conda: `conda update -n base -c defaults conda` 34 | 2. Install the conda-forge channel into conda: `conda config --add channels conda-forge` 35 | 36 | ### Installing Package Dependencies 37 | 38 | In this step, we will be creating a new conda virtual environment called `mstenv` which will contain Python 2.7.15 and the package dependencies. Open a command prompt / terminal and execute the steps below. 39 | 40 | 1. Build a new MST virtual environment which will be named mstenv with the recommended Python version: `conda create --name mstenv python=2.7.15` 41 | 2. Activate the newly-created MST virtual environment: `conda activate mstenv`. All subsequent steps should be followed within the activated virtual environment. 42 | 3. Install the pandas, scikit-learn, and joblib packages. Execute the following: `conda install pandas`, `conda install scikit-learn`, `conda install -c anaconda joblib` 43 | 4. Install tensorflow ensuring compatibility with python 2.7. The following worked for us: `pip install --upgrade tensorflow` 44 | 5. Deactivate the environment: `conda deactivate`. Going forward, users should activate their MST virtual environment prior to working with the code in this repo via `conda activate mstenv`. 45 | 46 | ## Running the Package Demos / Testing Installation 47 | 48 | To test the package installation or demo the package, users can take the following steps: 49 | 1. Open command prompt / terminal and navigate into the MST directory 50 | 2. Activate the MST virtual environment: `conda activate mstenv` 51 | 3. We will first demo MST's implementation of Choice Model Trees. Open mst.py. At the top of the file under "Import proper leaf model here:" , ensure that only one leaf model is being imported which should read `from leaf_model_mnl import *`. In command prompt / terminal, execute command `python cmt_example.py` which will run the MST on a synthetic choice modeling dataset. At the end of execution, the test set error will be outputted which should be under 0.05. 52 | 5. We will next demo MST's implementation of Isotonic Regression Trees. Open mst.py. At the top of the file under "Import proper leaf model here:" , ensure that only one leaf model is being imported which should read `from leaf_model_isoreg import *`. In command prompt / terminal, execute command `python irt_example.py` which will run the MST on a synthetic ad auction dataset. At the end of execution, the test set error will be outputted which should be under 0.05. 53 | 54 | ## Running MSTs on the Swiss Metro dataset 55 | To run MSTs on the Swiss Metro dataset used by our paper, please take the following steps: 56 | 1. Copy the files leaf_model_mnl_tensorflow.py and mst.py from this repo to the /scripts/src directory 57 | 2. Copy the files newmnlogit.R, leaf_model_mnl.py, and leaf_model_mnl_rmnlogit.py from the https://github.com/rtm2130/CMT-R repo to the /scripts/src directory 58 | 3. Create and activate a virtual environment following the steps from the https://github.com/rtm2130/CMT-R repo 59 | 4. Open the /scripts/src/newmnlogit.R file and at the top of the file, change `ro.r.source("newmnlogit.R")` to `ro.r.source("src/newmnlogit.R")` 60 | 5. In /scripts/src/leaf_model_mnl_rmnlogit.py, within the implementation of the `error(self,A,Y)` function, change `log_probas = -np.log(Ypred[(np.arange(Y.shape[0]),Y)])` to `log_probas = -np.log(np.maximum(Ypred[(np.arange(Y.shape[0]),Y)],0.01))` 61 | 6. Open /scripts/src/mst.py and ensure that at the top of the file the correct leaf model is being imported (`from leaf_model_mnl import *`) 62 | 7. Within the activated virtual environment, execute the scripts for running the Swiss Metro dataset located within the scripts/ directory. 63 | -------------------------------------------------------------------------------- /cmt_example.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This example shows how to apply the MST CMT tree algorithm to a simple synthetic dataset 3 | The ground truth of the synethetic dataset is a CMT of depth 1 4 | 5 | We will train the CMT on this dataset and observe whether it recovers the true CMT used to generate the data 6 | 7 | n = number of historical customers (i.e., training observations) 8 | X is a n x 3 matrix containing customers' contextual information 9 | 3 contexts: 10 | X0: binary {0,1} 11 | X1: binary {0,1} 12 | X2: ordinal in {0,0.2,0.4,0.6,0.8,1} 13 | 14 | There are 5 products total. Each customer sees a random assortment of 3 of these 5 products and chooses his favorite product in the assortment. 15 | P[:,:5] encodes the offered assortment: P[i,j] = 1 iff item j was offered to customer i 16 | There are two other product features (besides the assortment indicators) which can be interpreted as price and quality rating. 17 | Product prices are stored in P[:,5:10]. P[i,(j+5)] = price of item j offered to customer i 18 | Quality ratings are stored in P[:,10:15]. P[i,(j+10)] = quality rating of item j offered to customer i 19 | 20 | Y is an n-dim vector, where Y[i] in {0,1,2,3,4} encodes customer i's choice among the products 21 | 22 | The CMT used to generate the data consists of a single split (x2 <= 0.6), with MNLs in each leaf with randomly-generated coefs 23 | 24 | NOTE: for this to run properly, include the following import statement in mst.py: "from leaf_model_mnl import *" 25 | ''' 26 | 27 | from mst import MST 28 | from GenMNL import GenMNL 29 | import numpy as np 30 | import pandas as pd 31 | 32 | np.set_printoptions(suppress=True) #suppress scientific notation 33 | np.random.seed(0) 34 | 35 | ''' 36 | Generates responses Y, probability distribution of responses Y_prob given contexts X and assortments P 37 | 38 | The CMT used to generate the data consists of a single split (x2 <= 0.6), with MNLs in each leaf with randomly-generated coefs 39 | 40 | Arguments specifying MNL model type 41 | n_items: number of products (integer) 42 | num_features: number of product features (integer), INCLUDING the binary availability feature 43 | model_type: whether the model has alternative varying coefficients (0) or not (1). Type integer (0/1) 44 | (default is 0 meaning each alternative has a separate coeff) 45 | is_bias: whether the utility function has an intercept (default is True). Type boolean (True/False). 46 | ''' 47 | def get_choices(X, P, n_items, num_features, model_type, is_bias): 48 | n = X.shape[0] 49 | 50 | left_inds = np.where(X[:,2] <= 0.6)[0] 51 | right_inds = np.where(X[:,2] > 0.6)[0] 52 | 53 | left_mnl = GenMNL(n_items, num_features, model_type, is_bias) 54 | right_mnl = GenMNL(n_items, num_features, model_type, is_bias) 55 | 56 | Y_prob = np.zeros((n,n_items)) 57 | Y_prob[left_inds] = left_mnl.get_choice_probs(P[left_inds]) 58 | Y_prob[right_inds] = right_mnl.get_choice_probs(P[right_inds]) 59 | 60 | Y = np.zeros(n,dtype=int) 61 | for i in range(n): 62 | Y[i] = np.where(np.random.multinomial(1, Y_prob[i,:]))[0][0] 63 | 64 | return Y, Y_prob 65 | 66 | ''' 67 | Generates contexts X, assortments P, responses Y, response probability distributions Y_prob 68 | 69 | Arguments: 70 | n: number of training observations (integer) 71 | n_items: number of products (integer) 72 | num_features: number of product features (integer), INCLUDING the binary availability feature 73 | model_type: whether the model has alternative varying coefficients (0) or not (1). Type integer (0/1) 74 | (default is 0 meaning each alternative has a separate coeff) 75 | is_bias: whether the utility function has an intercept (default is True). Type boolean (True/False). 76 | ''' 77 | def generate_data(n, n_items, assortment_size, num_features, model_type, is_bias): 78 | 79 | #GENERATE CUSTOMER FEATURES 80 | #X1 is in {0,1} (binary) 81 | X1 = np.random.choice([0,1], size=n, replace=True).reshape((n,1)) 82 | #X2 is in {0,1} (binary) 83 | X2 = np.random.choice([0,1], size=n, replace=True).reshape((n,1)) 84 | #X3 is in {0,0.2,0.4,0.6,0.8,1} (ordinal) 85 | X3 = np.random.choice([0,0.2, 0.4, 0.6, 0.8, 1], size=n, replace=True).reshape((n,1)) 86 | X = np.concatenate((X1,X2,X3),axis = 1) 87 | 88 | #GENERATE ASSORTMENT AND PRODUCT FEATURES 89 | P = np.zeros((n,num_features*n_items)) 90 | for i in range(n): 91 | #generate assortment features (5 products, choose 3 to offer to each customer) 92 | assortment_items = np.random.choice(range(n_items), size=assortment_size, replace=False) 93 | P[i,assortment_items] = 1 94 | #generate price and quality rating features 95 | P[i,n_items:] = np.random.uniform(size=(num_features-1)*n_items) 96 | 97 | #GENERATE OUTCOMES (response probability distributions Y_prob, observed responses Y) 98 | Y, Y_prob = get_choices(X, P, n_items, num_features, model_type, is_bias) 99 | 100 | return X,P,Y,Y_prob 101 | 102 | 103 | #SIMULATED DATA PARAMETERS 104 | n_train = 5000; 105 | n_valid = 2500; 106 | n_test = 2000; 107 | 108 | n_items = 5 109 | assortment_size = 3 110 | num_features = 3 111 | model_type = 0 112 | is_bias = True 113 | 114 | #GENERATE DATA 115 | X,P,Y,Y_prob = generate_data(n_train+n_valid+n_test, n_items, assortment_size, num_features, model_type, is_bias) 116 | XV,PV,YV,Y_probV = X[n_train:(n_train+n_valid)],P[n_train:(n_train+n_valid)],Y[n_train:(n_train+n_valid)],Y_prob[n_train:(n_train+n_valid)] #valid set 117 | XT,PT,YT,Y_probT = X[(n_train+n_valid):],P[(n_train+n_valid):],Y[(n_train+n_valid):],Y_prob[(n_train+n_valid):] #test set 118 | X,P,Y,Y_prob = X[:n_train],P[:n_train],Y[:n_train],Y_prob[:n_train] #training set 119 | 120 | #APPLY TREE ALGORITHM. TRAIN TO DEPTH 1 121 | my_tree = MST(max_depth = 2, min_weights_per_node = 100, quant_discret = 0.05) 122 | my_tree.fit(X,P,Y,verbose=False, 123 | feats_continuous=[False, False, True], 124 | refit_leaves=True, 125 | num_features = num_features, is_bias = is_bias, model_type = model_type, 126 | mode = "mnl", batch_size = 100, epochs = 100, steps = 5000, 127 | leaf_mod_thresh=10000000); 128 | #ABOVE: leaf_mod_thresh controls whether when fitting a leaf node we apply Newton's method or stochastic gradient descent. 129 | # If the number of training observations in a leaf node <= leaf_mod_thresh, then newton's method 130 | # is applied; otherwise, stochastic gradient descent is applied. 131 | 132 | #PRINT OUT THE UNPRUNED TREE. OBSERVE THAT THE FIRST SPLIT IS CORRECT, BUT THERE ARE UNNECESSARY SPLITS AFTER THAT 133 | my_tree.traverse(verbose=True) 134 | #PRUNE THE TREE 135 | my_tree.prune(XV, PV, YV, verbose=False) 136 | #PRINT OUT THE PRUNED TREE. OBSERVE THAT THE UNNECESSARY SPLITS HAVE BEEN PRUNED FROM THE TREE 137 | my_tree.traverse(verbose=True) 138 | #OBSERVE TEST SET MEAN-SQUARED-ERROR 139 | Y_pred_pruned = my_tree.predict(XT,PT) 140 | score_pruned = np.sqrt(np.mean(np.power(Y_pred_pruned-Y_probT,2))) 141 | print(score_pruned) 142 | 143 | 144 | -------------------------------------------------------------------------------- /leaf_model_template.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ''' 4 | MST depends on the classes and functions below. 5 | These classes/methods are used to define the leaf model object in each leaf node, 6 | as well as helper functions for certain operations in the tree fitting procedure. 7 | 8 | One can feel free to edit the code below to accommodate any leaf node model. 9 | The leaf node model is fit on data (A,Y). (A is are the decisions "P" in the paper). 10 | Make sure to add an import statement to mst.py importing this leaf model class. 11 | 12 | Summary of methods and functions to specify: 13 | Methods as a part of class LeafModel: fit(), predict(), to_string(), error(), error_pruning() 14 | Other helper functions: get_sub(), are_Ys_diverse() 15 | 16 | ''' 17 | 18 | ''' 19 | LeafModel: the model used in each leaf. 20 | Has five methods: fit, predict, to_string, error, error_pruning 21 | ''' 22 | class LeafModel(object): 23 | 24 | #Any additional args passed to MST's init() function are directly passed here 25 | def __init__(self,*args,**kwargs): 26 | return 27 | 28 | ''' 29 | This function trains the leaf model on the data (A,Y,weights). 30 | 31 | A and Y can take any form (lists, matrices, vectors, etc.). For our applications, I recommend making Y 32 | the response data (e.g., choices) and A alternative-specific data (e.g., prices, choice sets) 33 | 34 | weights: a numpy array of case weights. Is 1-dimensional, with weights[i] yielding 35 | weight of observation/customer i. If you know you will not be using case weights 36 | in your particular application, you can ignore this input entirely. 37 | 38 | refit: boolean which equals True iff leaf model is being refit after tree splits have been decided. Since 39 | the tree split evaluation process requires fitting a large number of leaf models, one might wish to 40 | fit the leaf models on only a subset of data or for less training iterations when refit=False. Practitioners 41 | can feel free to ignore this parameter in their leaf model design. 42 | 43 | Returns 0 or 1. 44 | 0: No errors occurred when fitting leaf node model 45 | 1: An error occurred when fitting the leaf node model (probably due to insufficient data) 46 | If fit returns 1, then the tree will not consider the split that led to this leaf node model 47 | 48 | fit_init is a LeafModel object which represents a previously-trained leaf node model. 49 | If specified, fit_init is used for initialization when training this current LeafModel object. 50 | Useful for faster computation when fit_init's coefficients are close to the optimal solution of the new data. 51 | 52 | For those interested in defining their own leaf node functions: 53 | (1) It is not required to use the fit_init argument in your code 54 | (2) All edge cases must be handled in code below (ex: arguments 55 | consist of a single entry, weights are all zero, Y has one unique choice, etc.). 56 | In these cases, either hard-code a model that works with these edge-cases (e.g., 57 | if all Ys = 1, predict 1 with probability one), or have the fit function return 1 (error) 58 | (3) Store the fitted model (or its coefficients) as an attribute to the self object. You can name the attribute 59 | anything you want (i.e., it does not have to be self.model_obj below), 60 | as long as its consistent with your predict_prob() and to_string() methods 61 | 62 | Any additional args passed to MST's fit() function are directly passed here 63 | ''' 64 | def fit(self, A, Y, weights, fit_init=None, refit=False, *args,**kwargs): 65 | return 66 | 67 | 68 | ''' 69 | This function applies model from fit() to predict response data given new data A. 70 | Returns a numpy vector/matrix of response probabilities (one list entry per observation, i.e. l[i] yields prediction for ith obs.). 71 | Note: make sure to call fit() first before this method. 72 | 73 | Any additional args passed to MST's predict() function are directly passed here 74 | ''' 75 | def predict(self, A, *args,**kwargs): 76 | return 77 | 78 | ''' 79 | This function outputs the errors for each observation in pair (A,Y). 80 | Used in training when comparing different tree splits. 81 | Ex: log-likelihood between observed data Y and predict(A) 82 | Any error metric can be used, so long as: 83 | (1) lower error = "better" fit 84 | (2) error >= 0, where error = 0 means no error 85 | (3) error of fit on a group of data points = sum(errors of each data point) 86 | 87 | How to pass additional arguments to this function: simply pass these arguments to the init()/fit() functions and store them 88 | in the self object. 89 | ''' 90 | def error(self,A,Y): 91 | return 92 | 93 | ''' 94 | This function outputs the errors for each observation in pair (A,Y). 95 | Used in pruning to determine the best tree subset. 96 | Ex: mean-squared-error between observed data Y and predict(A) 97 | Any error metric can be used, so long as: 98 | (1) lower error = "better" fit 99 | (2) error >= 0, where error = 0 means no error 100 | (3) error of fit on a group of data points = sum(errors of each data point) 101 | 102 | How to pass additional arguments to this function: simply pass these arguments to the init()/fit() functions and store them 103 | in the self object. 104 | ''' 105 | def error_pruning(self,A,Y): 106 | return 107 | 108 | ''' 109 | This function returns the string representation of the fitted model 110 | Used in traverse() method, which traverses the tree and prints out all terminal node models 111 | 112 | Any additional args passed to MST's traverse() function are directly passed here 113 | ''' 114 | def to_string(self,*leafargs,**leafkwargs): 115 | return 116 | 117 | 118 | ''' 119 | Given decision vars A, response data Y, and observation indices data_inds, 120 | extract those observations of A and Y corresponding to data_inds 121 | 122 | If only decision vars A is given, returns A. 123 | If only response data Y is given, returns Y. 124 | 125 | If is_boolvec=True, data_inds takes the form of a boolean vector which indicates 126 | the elements we wish to extract. Otherwise, data_inds takes the form of the indices 127 | themselves (i.e., ints). 128 | 129 | Used to partition the data in the tree-fitting procedure 130 | ''' 131 | def get_sub(data_inds,A=None,Y=None,is_boolvec=False): 132 | if A is None: 133 | return Y[data_inds] 134 | if Y is None: 135 | return A[data_inds] 136 | else: 137 | return A[data_inds],Y[data_inds] 138 | 139 | ''' 140 | This function takes as input response data Y and outputs a boolean corresponding 141 | to whether all of the responses in Y are the same. 142 | 143 | It is used as a test for whether we should make a node a leaf. If are_Ys_diverse(Y)=False, 144 | then the node will become a leaf. Otherwise, if the node passes the other tests (doesn't exceed 145 | max depth, etc), we will consider splitting on the node. If you do not want to specify any 146 | termination criterion, simply "return True" 147 | ''' 148 | def are_Ys_diverse(Y): 149 | return (len(np.unique(Y)) > 1) 150 | -------------------------------------------------------------------------------- /leaf_model_mnl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from leaf_model_mnl_tensorflow import * 3 | 4 | ''' 5 | MST depends on the classes and functions below. 6 | These classes/methods are used to define the leaf model object in each leaf node, 7 | as well as helper functions for certain operations in the tree fitting procedure. 8 | 9 | One can feel free to edit the code below to accommodate any leaf node model. 10 | The leaf node model is fit on data (A,Y). (A is are the decisions "P" in the paper). 11 | Make sure to add an import statement to mst.py importing this leaf model class. 12 | 13 | Summary of methods and functions to specify: 14 | Methods as a part of class LeafModel: fit(), predict(), to_string(), error(), error_pruning() 15 | Other helper functions: get_sub(), are_Ys_diverse() 16 | 17 | ''' 18 | 19 | ''' 20 | LeafModel: the model used in each leaf. 21 | Has five methods: fit, predict, to_string, error, error_pruning 22 | ''' 23 | class LeafModel(object): 24 | 25 | #Any additional args passed to MST's init() function are directly passed here 26 | def __init__(self,*args,**kwargs): 27 | self.mnl = None 28 | return 29 | 30 | ''' 31 | This function trains the leaf model on the data (A,Y,weights). 32 | 33 | A and Y can take any form (lists, matrices, vectors, etc.). For our applications, I recommend making Y 34 | the response data (e.g., choices) and A alternative-specific data (e.g., prices, choice sets) 35 | 36 | weights: a numpy array of case weights. Is 1-dimensional, with weights[i] yielding 37 | weight of observation/customer i. If you know you will not be using case weights 38 | in your particular application, you can ignore this input entirely. 39 | 40 | refit: boolean which equals True iff leaf model is being refit after tree splits have been decided. Since 41 | the tree split evaluation process requires fitting a large number of leaf models, one might wish to 42 | fit the leaf models on only a subset of data or for less training iterations when refit=False. Practitioners 43 | can feel free to ignore this parameter in their leaf model design. 44 | 45 | Returns 0 or 1. 46 | 0: No errors occurred when fitting leaf node model 47 | 1: An error occurred when fitting the leaf node model (probably due to insufficient data) 48 | If fit returns 1, then the tree will not consider the split that led to this leaf node model 49 | 50 | fit_init is a LeafModel object which represents a previously-trained leaf node model. 51 | If specified, fit_init is used for initialization when training this current LeafModel object. 52 | Useful for faster computation when fit_init's coefficients are close to the optimal solution of the new data. 53 | 54 | For those interested in defining their own leaf node functions: 55 | (1) It is not required to use the fit_init argument in your code 56 | (2) All edge cases must be handled in code below (ex: arguments 57 | consist of a single entry, weights are all zero, Y has one unique choice, etc.). 58 | In these cases, either hard-code a model that works with these edge-cases (e.g., 59 | if all Ys = 1, predict 1 with probability one), or have the fit function return 1 (error) 60 | (3) Store the fitted model (or its coefficients) as an attribute to the self object. You can name the attribute 61 | anything you want (i.e., it does not have to be self.model_obj below), 62 | as long as its consistent with your predict_prob() and to_string() methods 63 | 64 | Any additional args passed to MST's fit() function are directly passed here 65 | ''' 66 | def fit(self, A, Y, weights, fit_init=None, leaf_mod_thresh=1000000, **kwargs): 67 | #note: the leaf_mod_thresh argument is only used by the alternative implementation 68 | #of this function found here: https://github.com/rtm2130/CMT-R/blob/main/leaf_model_mnl.py 69 | self.leaf_mod_type = "tensorflow" 70 | if fit_init is not None and fit_init.leaf_mod_type == "tensorflow": 71 | fit_init = fit_init.mnl 72 | else: 73 | fit_init = None 74 | 75 | if self.mnl is None: 76 | self.mnl = LeafModelTensorflow() 77 | 78 | error = self.mnl.fit(A, Y, weights, fit_init=fit_init, **kwargs) 79 | return(error) 80 | 81 | ''' 82 | This function applies model from fit() to predict response data given new data A. 83 | Returns a numpy vector/matrix of response probabilities (one list entry per observation, i.e. l[i] yields prediction for ith obs.). 84 | Note: make sure to call fit() first before this method. 85 | 86 | Any additional args passed to MST's predict() function are directly passed here 87 | ''' 88 | def predict(self, *args,**kwargs): 89 | return(self.mnl.predict(*args,**kwargs)) 90 | 91 | 92 | ''' 93 | This function outputs the errors for each observation in pair (A,Y). 94 | Used in training when comparing different tree splits. 95 | Ex: log-likelihood between observed data Y and predict(A) 96 | Any error metric can be used, so long as: 97 | (1) lower error = "better" fit 98 | (2) error >= 0, where error = 0 means no error 99 | (3) error of fit on a group of data points = sum(errors of each data point) 100 | 101 | How to pass additional arguments to this function: simply pass these arguments to the init()/fit() functions and store them 102 | in the self object. 103 | ''' 104 | def error(self,A,Y): 105 | return(self.mnl.error(A,Y)) 106 | 107 | ''' 108 | This function outputs the errors for each observation in pair (A,Y). 109 | Used in pruning to determine the best tree subset. 110 | Ex: mean-squared-error between observed data Y and predict(A) 111 | Any error metric can be used, so long as: 112 | (1) lower error = "better" fit 113 | (2) error >= 0, where error = 0 means no error 114 | (3) error of fit on a group of data points = sum(errors of each data point) 115 | 116 | How to pass additional arguments to this function: simply pass these arguments to the init()/fit() functions and store them 117 | in the self object. 118 | ''' 119 | def error_pruning(self,A,Y): 120 | return(self.mnl.error_pruning(A,Y)) 121 | 122 | ''' 123 | This function returns the string representation of the fitted model 124 | Used in traverse() method, which traverses the tree and prints out all terminal node models 125 | 126 | Any additional args passed to MST's traverse() function are directly passed here 127 | ''' 128 | def to_string(self,*leafargs,**leafkwargs): 129 | return(self.mnl.to_string(*leafargs,**leafkwargs)) 130 | 131 | 132 | ''' 133 | Given decision vars A, response data Y, and observation indices data_inds, 134 | extract those observations of A and Y corresponding to data_inds 135 | 136 | If only decision vars A is given, returns A. 137 | If only response data Y is given, returns Y. 138 | 139 | If is_boolvec=True, data_inds takes the form of a boolean vector which indicates 140 | the elements we wish to extract. Otherwise, data_inds takes the form of the indices 141 | themselves (i.e., ints). 142 | 143 | Used to partition the data in the tree-fitting procedure 144 | ''' 145 | def get_sub(data_inds,A=None,Y=None,is_boolvec=False): 146 | if A is None: 147 | return Y[data_inds] 148 | if Y is None: 149 | return A[data_inds] 150 | else: 151 | return A[data_inds],Y[data_inds] 152 | 153 | ''' 154 | This function takes as input response data Y and outputs a boolean corresponding 155 | to whether all of the responses in Y are the same. 156 | 157 | It is used as a test for whether we should make a node a leaf. If are_Ys_diverse(Y)=False, 158 | then the node will become a leaf. Otherwise, if the node passes the other tests (doesn't exceed 159 | max depth, etc), we will consider splitting on the node. If you do not want to specify any 160 | termination criterion, simply "return True" 161 | ''' 162 | def are_Ys_diverse(Y): 163 | return (len(np.unique(Y)) > 1) 164 | 165 | -------------------------------------------------------------------------------- /leaf_model_isoreg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.isotonic import IsotonicRegression 4 | 5 | ''' 6 | MST depends on the classes and functions below. 7 | These classes/methods are used to define the leaf model object in each leaf node, 8 | as well as helper functions for certain operations in the tree fitting procedure. 9 | 10 | One can feel free to edit the code below to accommodate any leaf node model. 11 | The leaf node model used here is an isotonic regression model fit 12 | on data (A,Y). A is a vector of bids, and Y is a vector of 0s/1s corresponding 13 | to auction win/losses (A is are the decisions "P" in the paper). 14 | Make sure to add an import statement to mst.py importing this leaf model class. 15 | 16 | Summary of methods and functions to specify: 17 | Methods as a part of class LeafModel: fit(), predict(), to_string(), error(), error_pruning() 18 | Other helper functions: get_sub(), are_Ys_diverse() 19 | ''' 20 | 21 | ''' 22 | LeafModel: the model used in each leaf. 23 | Has five methods: fit, predict, to_string, error, error_pruning 24 | ''' 25 | class LeafModel(object): 26 | 27 | #Any additional args passed to MST's init() function are directly passed here 28 | def __init__(self,*args,**kwargs): 29 | return 30 | 31 | ''' 32 | This function trains the leaf model on the data (A,Y,weights). 33 | 34 | A and Y can take any form (lists, matrices, vectors, etc.). For our applications, I recommend making Y 35 | the response data (e.g., choices) and A alternative-specific data (e.g., prices, choice sets) 36 | 37 | weights: a numpy array of case weights. Is 1-dimensional, with weights[i] yielding 38 | weight of observation/customer i. If you know you will not be using case weights 39 | in your particular application, you can ignore this input entirely. 40 | 41 | refit: boolean which equals True iff leaf model is being refit after tree splits have been decided. Since 42 | the tree split evaluation process requires fitting a large number of leaf models, one might wish to 43 | fit the leaf models on only a subset of data or for less training iterations when refit=False. Practitioners 44 | can feel free to ignore this parameter in their leaf model design. 45 | 46 | Returns 0 or 1. 47 | 0: No errors occurred when fitting leaf node model 48 | 1: An error occurred when fitting the leaf node model (probably due to insufficient data) 49 | If fit returns 1, then the tree will not consider the split that led to this leaf node model 50 | 51 | fit_init is a LeafModel object which represents a previously-trained leaf node model. 52 | If specified, fit_init is used for initialization when training this current LeafModel object. 53 | Useful for faster computation when fit_init's coefficients are close to the optimal solution of the new data. 54 | 55 | For those interested in defining their own leaf node functions: 56 | (1) It is not required to use the fit_init argument in your code 57 | (2) All edge cases must be handled in code below (ex: arguments 58 | consist of a single entry, weights are all zero, Y has one unique choice, etc.). 59 | In these cases, either hard-code a model that works with these edge-cases (e.g., 60 | if all Ys = 1, predict 1 with probability one), or have the fit function return 1 (error) 61 | (3) Store the fitted model (or its coefficients) as an attribute to the self object. You can name the attribute 62 | anything you want (i.e., it does not have to be self.model_obj below), 63 | as long as its consistent with your predict_prob() and to_string() methods 64 | 65 | Any additional args passed to MST's fit() function are directly passed here 66 | ''' 67 | def fit(self, A, Y, weights, fit_init=None, refit=False, increasing=True): 68 | #fit isotonic regression model. 69 | model = IsotonicRegression(increasing=increasing,out_of_bounds="clip",y_min=0.0,y_max=1.0) 70 | model.fit(X=A,y=Y,sample_weight=weights) 71 | self.model_obj = model 72 | return(0) 73 | 74 | 75 | ''' 76 | This function applies model from fit() to predict response data given new data A. 77 | Returns a numpy vector/matrix of response probabilities (one list entry per observation, i.e. l[i] yields prediction for ith obs.). 78 | Note: make sure to call fit() first before this method. 79 | 80 | Any additional args passed to MST's predict() function are directly passed here 81 | ''' 82 | def predict(self, A, *args,**kwargs): 83 | return self.model_obj.predict(A) 84 | 85 | ''' 86 | This function outputs the errors for each observation in pair (A,Y). 87 | Used in training when comparing different tree splits. 88 | Ex: log-likelihood between observed data Y and predict(A) 89 | Any error metric can be used, so long as: 90 | (1) lower error = "better" fit 91 | (2) error >= 0, where error = 0 means no error 92 | (3) error of fit on a group of data points = sum(errors of each data point) 93 | 94 | How to pass additional arguments to this function: simply pass these arguments to the init()/fit() functions and store them 95 | in the self object. 96 | ''' 97 | def error(self,A,Y): 98 | #Here I define the error metric to be weighted mean-square error (brier score) 99 | Ypred = self.predict(A) 100 | errors = (Y-Ypred)**2.0 101 | return errors 102 | 103 | ''' 104 | This function outputs the errors for each observation in pair (A,Y). 105 | Used in pruning to determine the best tree subset. 106 | Ex: mean-squared-error between observed data Y and predict(A) 107 | Any error metric can be used, so long as: 108 | (1) lower error = "better" fit 109 | (2) error >= 0, where error = 0 means no error 110 | (3) error of fit on a group of data points = sum(errors of each data point) 111 | 112 | How to pass additional arguments to this function: simply pass these arguments to the init()/fit() functions and store them 113 | in the self object. 114 | ''' 115 | def error_pruning(self,A,Y): 116 | #Here I define the error metric to be weighted mean-square error (brier score) 117 | Ypred = self.predict(A) 118 | errors = (Y-Ypred)**2.0 119 | return errors 120 | 121 | ''' 122 | This function returns the string representation of the fitted model 123 | Used in traverse() method, which traverses the tree and prints out all terminal node models 124 | 125 | Any additional args passed to MST's traverse() function are directly passed here 126 | ''' 127 | def to_string(self,*leafargs,**leafkwargs): 128 | return "Isotonic Regression Model" 129 | 130 | 131 | ''' 132 | Given decision vars A, response data Y, and observation indices data_inds, 133 | extract those observations of A and Y corresponding to data_inds 134 | 135 | If only decision vars A is given, returns A. 136 | If only response data Y is given, returns Y. 137 | 138 | If is_boolvec=True, data_inds takes the form of a boolean vector which indicates 139 | the elements we wish to extract. Otherwise, data_inds takes the form of the indices 140 | themselves (i.e., ints). 141 | 142 | Used to partition the data in the tree-fitting procedure 143 | ''' 144 | def get_sub(data_inds,A=None,Y=None,is_boolvec=False): 145 | if A is None: 146 | return Y[data_inds] 147 | if Y is None: 148 | return A[data_inds] 149 | else: 150 | return A[data_inds],Y[data_inds] 151 | 152 | ''' 153 | This function takes as input response data Y and outputs a boolean corresponding 154 | to whether all of the responses in Y are the same. 155 | 156 | It is used as a test for whether we should make a node a leaf. If are_Ys_diverse(Y)=False, 157 | then the node will become a leaf. Otherwise, if the node passes the other tests (doesn't exceed 158 | max depth, etc), we will consider splitting on the node. If you do not want to specify any 159 | termination criterion, simply "return True" 160 | ''' 161 | def are_Ys_diverse(Y): 162 | return (len(np.unique(Y)) > 1) 163 | -------------------------------------------------------------------------------- /scripts/src/leaf_model_cart.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ''' 4 | MST depends on the classes and functions below. 5 | These classes/methods are used to define the leaf model object in each leaf node, 6 | as well as helper functions for certain operations in the tree fitting procedure. 7 | 8 | One can feel free to edit the code below to accommodate any leaf node model. 9 | The leaf node model used here is an isotonic regression model fit 10 | on data (A,Y). A is a vector of bids, and Y is a vector of 0s/1s corresponding 11 | to auction win/losses (A is are the decisions "P" in the paper). 12 | Make sure to add an import statement to mst.py importing this leaf model class. 13 | 14 | Summary of methods and functions to specify: 15 | Methods as a part of class LeafModel: fit(), predict(), to_string(), error(), error_pruning() 16 | Other helper functions: get_sub(), are_Ys_diverse() 17 | ''' 18 | 19 | ''' 20 | LeafModel: the model used in each leaf. 21 | Has five methods: fit, predict, to_string, error, error_pruning 22 | ''' 23 | class LeafModel(object): 24 | 25 | #Any additional args passed to MST's init() function are directly passed here 26 | def __init__(self,*args,**kwargs): 27 | return 28 | 29 | ''' 30 | This function trains the leaf model on the data (A,Y,weights). 31 | 32 | A and Y can take any form (lists, matrices, vectors, etc.). For our applications, I recommend making Y 33 | the response data (e.g., choices) and A alternative-specific data (e.g., prices, choice sets) 34 | 35 | weights: a numpy array of case weights. Is 1-dimensional, with weights[i] yielding 36 | weight of observation/customer i. If you know you will not be using case weights 37 | in your particular application, you can ignore this input entirely. 38 | 39 | refit: boolean which equals True iff leaf model is being refit after tree splits have been decided. Since 40 | the tree split evaluation process requires fitting a large number of leaf models, one might wish to 41 | fit the leaf models on only a subset of data or for less training iterations when refit=False. Practitioners 42 | can feel free to ignore this parameter in their leaf model design. 43 | 44 | Returns 0 or 1. 45 | 0: No errors occurred when fitting leaf node model 46 | 1: An error occurred when fitting the leaf node model (probably due to insufficient data) 47 | If fit returns 1, then the tree will not consider the split that led to this leaf node model 48 | 49 | fit_init is a LeafModel object which represents a previously-trained leaf node model. 50 | If specified, fit_init is used for initialization when training this current LeafModel object. 51 | Useful for faster computation when fit_init's coefficients are close to the optimal solution of the new data. 52 | 53 | For those interested in defining their own leaf node functions: 54 | (1) It is not required to use the fit_init argument in your code 55 | (2) All edge cases must be handled in code below (ex: arguments 56 | consist of a single entry, weights are all zero, Y has one unique choice, etc.). 57 | In these cases, either hard-code a model that works with these edge-cases (e.g., 58 | if all Ys = 1, predict 1 with probability one), or have the fit function return 1 (error) 59 | (3) Store the fitted model (or its coefficients) as an attribute to the self object. You can name the attribute 60 | anything you want (i.e., it does not have to be self.model_obj below), 61 | as long as its consistent with your predict_prob() and to_string() methods 62 | 63 | Any additional args passed to MST's fit() function are directly passed here 64 | ''' 65 | def fit(self, A, Y, weights, fit_init=None, num_features=2, **kwargs): 66 | n_obs = len(Y) 67 | n_items = int(A.shape[1]/num_features) 68 | self.n_items = n_items 69 | Y_bool = np.zeros((n_obs, n_items)) 70 | Y_bool[range(n_obs),Y] = 1 71 | self.Ypred = np.mean(Y_bool, axis=0) 72 | return(0) 73 | 74 | 75 | ''' 76 | This function applies model from fit() to predict response data given new data A. 77 | Returns a numpy vector/matrix of response probabilities (one list entry per observation, i.e. l[i] yields prediction for ith obs.). 78 | Note: make sure to call fit() first before this method. 79 | 80 | Any additional args passed to MST's predict() function are directly passed here 81 | ''' 82 | def predict(self, A, *args,**kwargs): 83 | return np.tile(self.Ypred, (A.shape[0], 1)) 84 | 85 | ''' 86 | This function outputs the errors for each observation in pair (A,Y). 87 | Used in training when comparing different tree splits. 88 | Ex: log-likelihood between observed data Y and predict(A) 89 | Any error metric can be used, so long as: 90 | (1) lower error = "better" fit 91 | (2) error >= 0, where error = 0 means no error 92 | (3) error of fit on a group of data points = sum(errors of each data point) 93 | 94 | How to pass additional arguments to this function: simply pass these arguments to the init()/fit() functions and store them 95 | in the self object. 96 | ''' 97 | def error(self,A,Y): 98 | Ypred = self.predict(A) 99 | # log_probas = -np.log(np.maximum(Ypred[(np.arange(Y.shape[0]),Y)],0.001)) 100 | log_probas = -np.log(Ypred[(np.arange(Y.shape[0]),Y)]) 101 | return(log_probas) 102 | 103 | ''' 104 | This function outputs the errors for each observation in pair (A,Y). 105 | Used in pruning to determine the best tree subset. 106 | Ex: mean-squared-error between observed data Y and predict(A) 107 | Any error metric can be used, so long as: 108 | (1) lower error = "better" fit 109 | (2) error >= 0, where error = 0 means no error 110 | (3) error of fit on a group of data points = sum(errors of each data point) 111 | 112 | How to pass additional arguments to this function: simply pass these arguments to the init()/fit() functions and store them 113 | in the self object. 114 | ''' 115 | def error_pruning(self,A,Y): 116 | #Here I define the error metric to be weighted mean-square error (brier score) 117 | Ypred = self.predict(A) 118 | Z = np.zeros(Ypred.shape) 119 | # errors = (1-Ypred[(np.arange(Y.shape[0]),Y)])**2.0 120 | Z[(np.arange(Y.shape[0]),Y)] = 1.0 121 | errors = np.sum((Z-Ypred)**2.0,axis = 1) 122 | return errors 123 | 124 | ''' 125 | This function returns the string representation of the fitted model 126 | Used in traverse() method, which traverses the tree and prints out all terminal node models 127 | 128 | Any additional args passed to MST's traverse() function are directly passed here 129 | ''' 130 | def to_string(self,*leafargs,**leafkwargs): 131 | return "Y reg pred: " + str(self.Ypred) 132 | 133 | 134 | ''' 135 | Given decision vars A, response data Y, and observation indices data_inds, 136 | extract those observations of A and Y corresponding to data_inds 137 | 138 | If only decision vars A is given, returns A. 139 | If only response data Y is given, returns Y. 140 | 141 | If is_boolvec=True, data_inds takes the form of a boolean vector which indicates 142 | the elements we wish to extract. Otherwise, data_inds takes the form of the indices 143 | themselves (i.e., ints). 144 | 145 | Used to partition the data in the tree-fitting procedure 146 | ''' 147 | def get_sub(data_inds,A=None,Y=None,is_boolvec=False): 148 | if A is None: 149 | return Y[data_inds] 150 | if Y is None: 151 | return A[data_inds] 152 | else: 153 | return A[data_inds],Y[data_inds] 154 | 155 | ''' 156 | This function takes as input response data Y and outputs a boolean corresponding 157 | to whether all of the responses in Y are the same. 158 | 159 | It is used as a test for whether we should make a node a leaf. If are_Ys_diverse(Y)=False, 160 | then the node will become a leaf. Otherwise, if the node passes the other tests (doesn't exceed 161 | max depth, etc), we will consider splitting on the node. If you do not want to specify any 162 | termination criterion, simply "return True" 163 | ''' 164 | def are_Ys_diverse(Y): 165 | return (len(np.unique(Y)) > 1) 166 | -------------------------------------------------------------------------------- /scripts/mnlicot+.py: -------------------------------------------------------------------------------- 1 | 2 | #from GenMNL import GenMNL 3 | from mst import MST 4 | import numpy as np 5 | from cart_customerfeats_in_leafmod import append_customer_features_to_product_features 6 | # import pandas as pd 7 | 8 | 9 | #np.set_printoptions(suppress=True) #suppress scientific notation 10 | #np.random.seed(0) 11 | # df = pd.read_csv('swissmetro.dat',sep='\t') 12 | # df["CAR_HE"] = 0 13 | 14 | c_features = ["GROUP", "PURPOSE", "FIRST", "TICKET", "WHO", "LUGGAGE", "AGE", 15 | "MALE", "INCOME", "GA", "ORIGIN", "DEST"] 16 | p_features = ["TRAIN_AV", "SM_AV", "CAR_AV", 17 | "TRAIN_TT", "SM_TT", "CAR_TT", 18 | "TRAIN_CO", "SM_CO", "CAR_CO", 19 | "TRAIN_HE", "SM_HE", "CAR_HE"] 20 | # target = "CHOICE" 21 | # df = df[df[target] > 0] 22 | # df.loc[:,target] = df.loc[:,target]-1 23 | # df = df.reset_index() 24 | 25 | # df = df.sample(frac=1).reset_index(drop=True) 26 | 27 | 28 | # def prepare_data(df, k1, k2): 29 | # ''' 30 | # Prepares a partition of the data into train, validation and test sets 31 | 32 | # Args: 33 | # k1 -> number of observations in the test set 34 | # k2 -> number of observations in the validation set 35 | # ''' 36 | # n = df.shape[0] 37 | # selected_indices = np.random.choice(range(n), size=k1 + k2, replace=False) 38 | # test_indices = np.random.choice(selected_indices, size=k1, replace=False) 39 | # validation_indices = np.setdiff1d(selected_indices,test_indices) 40 | # train_indices = np.setdiff1d(np.arange(n),selected_indices) 41 | # Y = df.loc[train_indices,target].values 42 | # X = df.loc[train_indices,c_features].values 43 | # P = df.loc[train_indices,p_features].values 44 | # YT = df.loc[test_indices,target].values 45 | # XT = df.loc[test_indices,c_features].values 46 | # PT = df.loc[test_indices,p_features].values 47 | # YV = df.loc[validation_indices,target].values 48 | # XV = df.loc[validation_indices,c_features].values 49 | # PV = df.loc[validation_indices,p_features].values 50 | 51 | # return X,P,Y,XV,PV,YV,XT,PT,YT,train_indices,validation_indices,test_indices 52 | 53 | 54 | 55 | 56 | # #SIMULATED DATA PARAMETERS 57 | # #n_train = 5000; 58 | # n_valid = int(df.shape[0]/10); 59 | # n_test = int(df.shape[0]/10); 60 | 61 | # X,P,Y,XV,PV,YV,XT,PT,YT,train_indices,validation_indices,test_indices = prepare_data(df, n_test, n_valid) 62 | 63 | # #np.save('prepared_data/X.npy',X) 64 | # #np.save('prepared_data/P.npy',P) 65 | # #np.save('prepared_data/Y.npy',Y) 66 | # # 67 | # #np.save('prepared_data/XV.npy',XV) 68 | # #np.save('prepared_data/PV.npy',PV) 69 | # #np.save('prepared_data/YV.npy',YV) 70 | # # 71 | # #np.save('prepared_data/XT.npy',XT) 72 | # #np.save('prepared_data/PT.npy',PT) 73 | # #np.save('prepared_data/YT.npy',YT) 74 | 75 | 76 | # X = np.load('prepared_data/X.npy') 77 | # P = np.load('prepared_data/P.npy') 78 | # Y = np.load('prepared_data/Y.npy') 79 | 80 | # XV = np.load('prepared_data/XV.npy') 81 | # PV = np.load('prepared_data/PV.npy') 82 | # YV = np.load('prepared_data/YV.npy') 83 | 84 | # XT = np.load('prepared_data/XT.npy') 85 | # PT = np.load('prepared_data/PT.npy') 86 | # YT = np.load('prepared_data/YT.npy') 87 | 88 | num_features = 4 89 | model_type = 0 90 | is_bias = True 91 | 92 | #scores per data set, pre prune post prune, MSE + LL (test,train) 93 | scores_np = np.zeros((10,2,4)) 94 | 95 | scores_np = np.zeros((10,2,6,15)) 96 | 97 | def return_leaf(X): 98 | labels = np.zeros(X.shape[0]) 99 | labels = 1*((X[:,9]>0.5) &(X[:,7]>0.5)& (X[:,2]>0.5)) 100 | labels += 2*((X[:,9]>0.5) &(X[:,7]>0.5)&(X[:,2]<0.5)) 101 | labels += 3*((X[:,9]>0.5) &(X[:,7]<0.5)) 102 | labels += 4*((X[:,9]<0.5) &(X[:,2]>0.5)&(X[:,7]>0.5)&(X[:,0]>2.5)) 103 | labels += 5*((X[:,9]<0.5) &(X[:,2]>0.5)&(X[:,7]>0.5)&(X[:,0]<2.5)) 104 | labels += 6*((X[:,9]<0.5) &(X[:,2]>0.5)&(X[:,7]<0.5)&(X[:,0]>2.5)) 105 | labels += 7*((X[:,9]<0.5) &(X[:,2]>0.5)&(X[:,7]<0.5)&(X[:,0]<2.5)) 106 | labels += 8*((X[:,9]<0.5) &(X[:,2]<0.5)&(X[:,7]>0.5)&(X[:,0]>2.5)) 107 | labels += 9*((X[:,9]<0.5) &(X[:,2]<0.5)&(X[:,7]>0.5)&(X[:,0]<2.5)) 108 | labels += 10*((X[:,9]<0.5) &(X[:,2]<0.5)&(X[:,7]<0.5)&(X[:,0]>2.5)) 109 | labels += 11*((X[:,9]<0.5) &(X[:,2]<0.5)&(X[:,7]<0.5)&(X[:,0]<2.5)) 110 | return(labels) 111 | 112 | for i in range(10): 113 | #for i in [1]: 114 | num_features = 4 115 | X = np.load('prepared_data2/X'+str(i)+'.npy') 116 | P = np.load('prepared_data2/P'+str(i)+'.npy') 117 | Y = np.load('prepared_data2/Y'+str(i)+'.npy') 118 | 119 | XV = np.load('prepared_data2/XV'+str(i)+'.npy') 120 | PV = np.load('prepared_data2/PV'+str(i)+'.npy') 121 | YV = np.load('prepared_data2/YV'+str(i)+'.npy') 122 | 123 | XT = np.load('prepared_data2/XT'+str(i)+'.npy') 124 | PT = np.load('prepared_data2/PT'+str(i)+'.npy') 125 | YT = np.load('prepared_data2/YT'+str(i)+'.npy') 126 | 127 | 128 | P = P.astype(float) 129 | PV = PV.astype(float) 130 | PT = PT.astype(float) 131 | 132 | P[:,-1] = 0.001*np.random.rand(P.shape[0]) 133 | PV[:,-1] = 0.001*np.random.rand(PV.shape[0]) 134 | PT[:,-1] = 0.001*np.random.rand(PT.shape[0]) 135 | 136 | is_continuous = [False for k in range(len(c_features))] 137 | is_continuous[6] = True 138 | is_continuous[8] = False 139 | P, PV, PT, num_features = append_customer_features_to_product_features(X, XV, XT, 140 | P, PV, PT, 141 | feats_continuous=is_continuous, 142 | model_type=model_type, num_features=num_features) 143 | # for depth in [4]: 144 | # for depth in range(15): 145 | labels = return_leaf(X) 146 | labelsT = return_leaf(XT) 147 | Y_pred = np.zeros((YT.shape[0],3)) 148 | Y_pred2 = np.zeros((Y.shape[0],3)) 149 | for l in range(1,12): 150 | #APPLY TREE ALGORITHM. TRAIN TO DEPTH 1 151 | my_tree = MST(max_depth = 0, min_weights_per_node = 20, only_singleton_splits = True, quant_discret = 0.05) 152 | # my_tree = MST(max_depth = 12 min_weights_per_node = 20, quant_discret = 0.05) 153 | my_tree.fit(X[labels == l],P[labels == l],Y[labels == l],verbose=True, 154 | feats_continuous= is_continuous, 155 | refit_leaves=True,only_singleton_splits = True, 156 | num_features = num_features, is_bias = is_bias, model_type = model_type, 157 | mode = "mnl", batch_size = 100, epochs = 50, steps = 6000, 158 | leaf_mod_thresh=10000000); 159 | Y_pred[labelsT == l] = my_tree.predict(XT[labelsT == l],PT[labelsT == l]) 160 | Y_pred2[labels == l] = my_tree.predict(X[labels == l],P[labels == l]) 161 | #ABOVE: leaf_mod_thresh controls whether when fitting a leaf node we apply Newton's method or stochastic gradient descent. 162 | # If the number of training observations in a leaf node <= leaf_mod_thresh, then newton's method 163 | # is applied; otherwise, stochastic gradient descent is applied. 164 | 165 | ## PRINT OUT THE UNPRUNED TREE. OBSERVE THAT THE FIRST SPLIT IS CORRECT, BUT THERE ARE UNNECESSARY SPLITS AFTER THAT 166 | #my_tree.traverse(verbose=True) 167 | ##PRUNE THE TREE 168 | #my_tree.prune(XV, PV, YV, verbose=False) 169 | # #PRINT OUT THE PRUNED TREE. OBSERVE THAT THE UNNECESSARY SPLITS HAVE BEEN PRUNED FROM THE TREE 170 | #my_tree.traverse(verbose=True) 171 | #print(my_tree._error(XT,PT,YT, use_pruning_error = True)) 172 | #print(my_tree._error(XT,PT,YT, use_pruning_error = False)) 173 | 174 | YT_flat = np.zeros((YT.shape[0],3)) 175 | YT_flat[np.arange(YT.shape[0]),YT] = 1 176 | Y_flat = np.zeros((Y.shape[0],3)) 177 | Y_flat[np.arange(Y.shape[0]),Y] = 1 178 | 179 | 180 | s_Y = Y_pred.shape[0] 181 | scores_np[i,0,0,0] = np.mean(np.log(np.maximum(0.01,Y_pred[np.arange(s_Y),YT]))) 182 | scores_np[i,0,1,0] = np.mean(np.sum(np.power(Y_pred-YT_flat,2),axis = 1)) 183 | 184 | s_Y = Y_pred2.shape[0] 185 | scores_np[i,0,2,0] = np.mean(np.log(np.maximum(0.01,Y_pred2[np.arange(s_Y),Y]))) 186 | scores_np[i,0,3,0] = np.mean(np.sum(np.power(Y_pred2-Y_flat,2),axis = 1)) 187 | 188 | 189 | 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /scripts/mnlicot.py: -------------------------------------------------------------------------------- 1 | 2 | #from GenMNL import GenMNL 3 | from mst import MST 4 | import numpy as np 5 | from cart_customerfeats_in_leafmod import append_customer_features_to_product_features 6 | # import pandas as pd 7 | 8 | 9 | #np.set_printoptions(suppress=True) #suppress scientific notation 10 | #np.random.seed(0) 11 | # df = pd.read_csv('swissmetro.dat',sep='\t') 12 | # df["CAR_HE"] = 0 13 | 14 | c_features = ["GROUP", "PURPOSE", "FIRST", "TICKET", "WHO", "LUGGAGE", "AGE", 15 | "MALE", "INCOME", "GA", "ORIGIN", "DEST"] 16 | p_features = ["TRAIN_AV", "SM_AV", "CAR_AV", 17 | "TRAIN_TT", "SM_TT", "CAR_TT", 18 | "TRAIN_CO", "SM_CO", "CAR_CO", 19 | "TRAIN_HE", "SM_HE", "CAR_HE"] 20 | # target = "CHOICE" 21 | # df = df[df[target] > 0] 22 | # df.loc[:,target] = df.loc[:,target]-1 23 | # df = df.reset_index() 24 | 25 | # df = df.sample(frac=1).reset_index(drop=True) 26 | 27 | 28 | # def prepare_data(df, k1, k2): 29 | # ''' 30 | # Prepares a partition of the data into train, validation and test sets 31 | 32 | # Args: 33 | # k1 -> number of observations in the test set 34 | # k2 -> number of observations in the validation set 35 | # ''' 36 | # n = df.shape[0] 37 | # selected_indices = np.random.choice(range(n), size=k1 + k2, replace=False) 38 | # test_indices = np.random.choice(selected_indices, size=k1, replace=False) 39 | # validation_indices = np.setdiff1d(selected_indices,test_indices) 40 | # train_indices = np.setdiff1d(np.arange(n),selected_indices) 41 | # Y = df.loc[train_indices,target].values 42 | # X = df.loc[train_indices,c_features].values 43 | # P = df.loc[train_indices,p_features].values 44 | # YT = df.loc[test_indices,target].values 45 | # XT = df.loc[test_indices,c_features].values 46 | # PT = df.loc[test_indices,p_features].values 47 | # YV = df.loc[validation_indices,target].values 48 | # XV = df.loc[validation_indices,c_features].values 49 | # PV = df.loc[validation_indices,p_features].values 50 | 51 | # return X,P,Y,XV,PV,YV,XT,PT,YT,train_indices,validation_indices,test_indices 52 | 53 | 54 | 55 | 56 | # #SIMULATED DATA PARAMETERS 57 | # #n_train = 5000; 58 | # n_valid = int(df.shape[0]/10); 59 | # n_test = int(df.shape[0]/10); 60 | 61 | # X,P,Y,XV,PV,YV,XT,PT,YT,train_indices,validation_indices,test_indices = prepare_data(df, n_test, n_valid) 62 | 63 | # #np.save('prepared_data/X.npy',X) 64 | # #np.save('prepared_data/P.npy',P) 65 | # #np.save('prepared_data/Y.npy',Y) 66 | # # 67 | # #np.save('prepared_data/XV.npy',XV) 68 | # #np.save('prepared_data/PV.npy',PV) 69 | # #np.save('prepared_data/YV.npy',YV) 70 | # # 71 | # #np.save('prepared_data/XT.npy',XT) 72 | # #np.save('prepared_data/PT.npy',PT) 73 | # #np.save('prepared_data/YT.npy',YT) 74 | 75 | 76 | # X = np.load('prepared_data/X.npy') 77 | # P = np.load('prepared_data/P.npy') 78 | # Y = np.load('prepared_data/Y.npy') 79 | 80 | # XV = np.load('prepared_data/XV.npy') 81 | # PV = np.load('prepared_data/PV.npy') 82 | # YV = np.load('prepared_data/YV.npy') 83 | 84 | # XT = np.load('prepared_data/XT.npy') 85 | # PT = np.load('prepared_data/PT.npy') 86 | # YT = np.load('prepared_data/YT.npy') 87 | 88 | num_features = 4 89 | model_type = 0 90 | is_bias = True 91 | 92 | #scores per data set, pre prune post prune, MSE + LL (test,train) 93 | scores_np = np.zeros((10,2,4)) 94 | 95 | scores_np = np.zeros((10,2,6,15)) 96 | 97 | def return_leaf(X): 98 | labels = np.zeros(X.shape[0]) 99 | labels = 1*((X[:,9]>0.5) &(X[:,7]>0.5)& (X[:,2]>0.5)) 100 | labels += 2*((X[:,9]>0.5) &(X[:,7]>0.5)&(X[:,2]<0.5)) 101 | labels += 3*((X[:,9]>0.5) &(X[:,7]<0.5)) 102 | labels += 4*((X[:,9]<0.5) &(X[:,2]>0.5)&(X[:,7]>0.5)&(X[:,0]>2.5)) 103 | labels += 5*((X[:,9]<0.5) &(X[:,2]>0.5)&(X[:,7]>0.5)&(X[:,0]<2.5)) 104 | labels += 6*((X[:,9]<0.5) &(X[:,2]>0.5)&(X[:,7]<0.5)&(X[:,0]>2.5)) 105 | labels += 7*((X[:,9]<0.5) &(X[:,2]>0.5)&(X[:,7]<0.5)&(X[:,0]<2.5)) 106 | labels += 8*((X[:,9]<0.5) &(X[:,2]<0.5)&(X[:,7]>0.5)&(X[:,0]>2.5)) 107 | labels += 9*((X[:,9]<0.5) &(X[:,2]<0.5)&(X[:,7]>0.5)&(X[:,0]<2.5)) 108 | labels += 10*((X[:,9]<0.5) &(X[:,2]<0.5)&(X[:,7]<0.5)&(X[:,0]>2.5)) 109 | labels += 11*((X[:,9]<0.5) &(X[:,2]<0.5)&(X[:,7]<0.5)&(X[:,0]<2.5)) 110 | return(labels) 111 | 112 | for i in range(10): 113 | #for i in [1]: 114 | num_features = 4 115 | X = np.load('prepared_data2/X'+str(i)+'.npy') 116 | P = np.load('prepared_data2/P'+str(i)+'.npy') 117 | Y = np.load('prepared_data2/Y'+str(i)+'.npy') 118 | 119 | XV = np.load('prepared_data2/XV'+str(i)+'.npy') 120 | PV = np.load('prepared_data2/PV'+str(i)+'.npy') 121 | YV = np.load('prepared_data2/YV'+str(i)+'.npy') 122 | 123 | XT = np.load('prepared_data2/XT'+str(i)+'.npy') 124 | PT = np.load('prepared_data2/PT'+str(i)+'.npy') 125 | YT = np.load('prepared_data2/YT'+str(i)+'.npy') 126 | 127 | 128 | P = P.astype(float) 129 | PV = PV.astype(float) 130 | PT = PT.astype(float) 131 | 132 | P[:,-1] = 0.001*np.random.rand(P.shape[0]) 133 | PV[:,-1] = 0.001*np.random.rand(PV.shape[0]) 134 | PT[:,-1] = 0.001*np.random.rand(PT.shape[0]) 135 | 136 | is_continuous = [False for k in range(len(c_features))] 137 | is_continuous[6] = True 138 | is_continuous[8] = False 139 | P, PV, PT, num_features = append_customer_features_to_product_features(X, XV, XT, 140 | P, PV, PT, 141 | feats_continuous=is_continuous, 142 | model_type=model_type, num_features=num_features) 143 | # for depth in [4]: 144 | # for depth in range(15): 145 | labels = return_leaf(X) 146 | labelsT = return_leaf(XT) 147 | Y_pred = np.zeros((YT.shape[0],3)) 148 | Y_pred2 = np.zeros((Y.shape[0],3)) 149 | for l in range(1,12): 150 | #APPLY TREE ALGORITHM. TRAIN TO DEPTH 1 151 | my_tree = MST(max_depth = 0, min_weights_per_node = 20, only_singleton_splits = True, quant_discret = 0.05) 152 | # my_tree = MST(max_depth = 12 min_weights_per_node = 20, quant_discret = 0.05) 153 | my_tree.fit(X[labels == l],P[labels == l],Y[labels == l],verbose=True, 154 | feats_continuous= is_continuous, 155 | refit_leaves=True,only_singleton_splits = True, 156 | num_features = num_features, is_bias = is_bias, model_type = model_type, 157 | mode = "mnl", batch_size = 100, epochs = 50, steps = 6000, 158 | leaf_mod_thresh=10000000); 159 | Y_pred[labelsT == l] = my_tree.predict(XT[labelsT == l],PT[labelsT == l]) 160 | Y_pred2[labels == l] = my_tree.predict(X[labels == l],P[labels == l]) 161 | #ABOVE: leaf_mod_thresh controls whether when fitting a leaf node we apply Newton's method or stochastic gradient descent. 162 | # If the number of training observations in a leaf node <= leaf_mod_thresh, then newton's method 163 | # is applied; otherwise, stochastic gradient descent is applied. 164 | 165 | ## PRINT OUT THE UNPRUNED TREE. OBSERVE THAT THE FIRST SPLIT IS CORRECT, BUT THERE ARE UNNECESSARY SPLITS AFTER THAT 166 | #my_tree.traverse(verbose=True) 167 | ##PRUNE THE TREE 168 | #my_tree.prune(XV, PV, YV, verbose=False) 169 | # #PRINT OUT THE PRUNED TREE. OBSERVE THAT THE UNNECESSARY SPLITS HAVE BEEN PRUNED FROM THE TREE 170 | #my_tree.traverse(verbose=True) 171 | #print(my_tree._error(XT,PT,YT, use_pruning_error = True)) 172 | #print(my_tree._error(XT,PT,YT, use_pruning_error = False)) 173 | 174 | YT_flat = np.zeros((YT.shape[0],3)) 175 | YT_flat[np.arange(YT.shape[0]),YT] = 1 176 | Y_flat = np.zeros((Y.shape[0],3)) 177 | Y_flat[np.arange(Y.shape[0]),Y] = 1 178 | 179 | 180 | s_Y = Y_pred.shape[0] 181 | scores_np[i,0,0,0] = np.mean(np.log(np.maximum(0.01,Y_pred[np.arange(s_Y),YT]))) 182 | scores_np[i,0,1,0] = np.mean(np.sum(np.power(Y_pred-YT_flat,2),axis = 1)) 183 | 184 | s_Y = Y_pred2.shape[0] 185 | scores_np[i,0,2,0] = np.mean(np.log(np.maximum(0.01,Y_pred2[np.arange(s_Y),Y]))) 186 | scores_np[i,0,3,0] = np.mean(np.sum(np.power(Y_pred2-Y_flat,2),axis = 1)) 187 | 188 | 189 | 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /scripts/src/cart_customerfeats_in_leafmod.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | MNL + CART IMPLEMENTATION 4 | """ 5 | #import math 6 | import numpy as np 7 | import pandas as pd 8 | import copy 9 | from cart_with_mnl_leaf_refitting import CARTWithMNLLeafRefitting 10 | 11 | 12 | ''' 13 | This function fits and prunes a CART tree, and then as a postprocessing step fits an MNL response model in each segment. 14 | This function outputs predicted choice probabilities on the test set. 15 | Input arguments: 16 | Xtrain, Xval, Xtest: The individual-specific feature data for the training, validation, and test sets. Can either be a pandas data frame or numpy array, with: 17 | (a) rows of X = observations/customers 18 | (b) columns of X = features about the observation/customer 19 | Atrain, Aval, Atest: the decision variables/alternative-specific features used in the response models for the training, validation, and test sets. 20 | A can take any form -- it is directly passed to the functions in leaf_model.py 21 | Ytrain, Yval: the responses/outcomes/choices used in the response models. 22 | Y must be a 1-D array with length = number of observations 23 | weights_train (optional), weights_val (optional): an optional numpy array of case weights for the training and validation sets. Is 1-dimensional, with weights[i] yielding weight of observation/customer i 24 | feats_continuous (optional): If False, all X features are treated as categorical. If True, all features are treated as continuous. 25 | feats_continuous can also be a boolean vector of dimension = num_features specifying how to treat each feature. 26 | All features satisfying feats_continuous == False will be binarized when fitting the leaf models. 27 | verbose (optional): if verbose=True, prints out progress in training MST 28 | one_SE_rule (default True): do we use 1SE rule when pruning the tree? 29 | MST standard input arguments: 30 | max_depth: the maximum depth of the pre-pruned tree (default = Inf: no depth limit) 31 | min_weight_per_node: the mininum number of observations (with respect to cumulative weight) per node 32 | min_depth: the minimum depth of the pre-pruned tree (default: set equal to max_depth) 33 | min_diff: if depth > min_depth, stop splitting if improvement in fit does not exceed min_diff 34 | quant_discret: continuous variable split points are chosen from quantiles of the variable corresponding to quant_discret,2*quant_discret,3*quant_discret, etc.. 35 | run_in_parallel: if set to True, enables parallel computing among num_workers threads. If num_workers is not 36 | specified, uses the number of cpu cores available. 37 | 38 | Any additional keyword arguments are passed to the leaf_model fit() function 39 | For leaf_model_mnl, you should pass the following: 40 | n_features: integer (default is 2) 41 | mode : "mnl" or "exponomial" (default is "mnl") 42 | batch_size : size of the stochastic batch (default is 50,) 43 | model_type : whether the model has alternative varying coefficients or not (default is 0 meaning each alternative has a separate coeff) 44 | num_features : number of features under consideration (default is 2) 45 | epochs : number of epochs for the estimation (default is 10) 46 | is_bias : whether the utility function has an intercept (default is True) 47 | ''' 48 | def mnl_cart_fit_prune_and_predict(Xtrain,Atrain,Ytrain, 49 | Xval,Aval,Yval, 50 | Xtest,Atest, 51 | weights_train=None, weights_val=None, 52 | feats_continuous=True, 53 | verbose=True, 54 | one_SE_rule=True, 55 | max_depth=float("inf"),min_weights_per_node=15, 56 | min_depth=None,min_diff=0, 57 | quant_discret=0.01, 58 | run_in_parallel=False,num_workers=None, 59 | only_singleton_splits=True, 60 | *leafargs_fit,**leafkwargs_fit): 61 | 62 | my_tree = CARTWithMNLLeafRefitting(max_depth=max_depth,min_weights_per_node=min_weights_per_node, 63 | min_depth=min_depth,min_diff=min_diff, 64 | quant_discret = quant_discret, 65 | run_in_parallel=run_in_parallel,num_workers=num_workers, 66 | only_singleton_splits=only_singleton_splits) 67 | 68 | my_tree.fit(Xtrain, Atrain, Ytrain, weights=weights_train, feats_continuous=feats_continuous, verbose=verbose, *leafargs_fit,**leafkwargs_fit) 69 | my_tree.prune(Xval, Aval, Yval, weights_val=weights_val, one_SE_rule=one_SE_rule, verbose=verbose) 70 | #my_tree.traverse(verbose=True) 71 | my_tree.refit_leafmods_with_mnl(Xtrain, Atrain, Ytrain, weights_new=weights_train, verbose=verbose, *leafargs_fit,**leafkwargs_fit) 72 | #my_tree.traverse(verbose=True) 73 | 74 | Ypred_test = my_tree.predict(Xtest, Atest) 75 | 76 | return(Ypred_test,my_tree) 77 | 78 | ''' 79 | This function takes customer feature matrix X and product feature matrix A as input. 80 | Outputs a new product feature matrix A_new with customer features included 81 | NOTE: A_new is encoded such that each customer feature will have an alternative-specific coefficient regardless of model_type specification 82 | ''' 83 | def append_customer_features_to_product_features(X_train, X_valid, X_test, 84 | A_train, A_valid, A_test, 85 | feats_continuous=True, 86 | model_type=0, num_features=2): 87 | 88 | X_train_bin, X_valid_bin, X_test_bin = binarize_and_normalize_contexts(X_train, X_valid, X_test, 89 | feats_continuous=feats_continuous, 90 | normalize=True, 91 | drop_first=True) 92 | n_items = int(A_train.shape[1]/num_features) 93 | A_train_with_custfeats = _append_customer_features_to_product_features(A_train, X_train_bin, n_items, model_type) 94 | A_valid_with_custfeats = _append_customer_features_to_product_features(A_valid, X_valid_bin, n_items, model_type) 95 | A_test_with_custfeats = _append_customer_features_to_product_features(A_test, X_test_bin, n_items, model_type) 96 | num_features_with_custfeats = int(A_train_with_custfeats.shape[1]/n_items) 97 | return(A_train_with_custfeats, A_valid_with_custfeats, A_test_with_custfeats, num_features_with_custfeats) 98 | 99 | ''' 100 | This function takes customer feature matrix X and product feature matrix A as input. 101 | Outputs a new product feature matrix A_new with customer features included 102 | NOTE: A_new is encoded such that each customer feature will have an alternative-specific coefficient regardless of model_type specification 103 | ''' 104 | def _append_customer_features_to_product_features(A, X_bin, n_items, model_type): 105 | if model_type == 0: 106 | #MNL model is encoded as having alternative-specific coefs 107 | A_cust = np.repeat(X_bin, n_items, axis=1) 108 | else: 109 | #MNL model is not encoded as having alternative-specific coefs, so need to encode data in a special way to ensure 110 | #customer feats each have alternative-specific coefs 111 | A_cust = np.repeat(X_bin, n_items*(n_items-1), axis=1) 112 | num_obs = X_bin.shape[0] 113 | n_cust_feats = X_bin.shape[1] 114 | zeroing_mat = np.zeros((num_obs, n_items*(n_items-1))) 115 | zeroing_mat[:,np.arange(n_items-1) + n_items*np.arange(n_items-1)] = 1 116 | zeroing_mat = np.tile(zeroing_mat, n_cust_feats) 117 | A_cust = A_cust * zeroing_mat 118 | return np.concatenate((A, A_cust), axis=1) 119 | 120 | ''' 121 | This function performs the following operations on contextual feature matrices X_train, X_valid, X_test: 122 | (1) Binarizes any features which are categorical, i.e. where feats_continuous==False 123 | (2) If normalize==True, normalizes any features which are numerical, i.e. where feats_continuous==True. 124 | (Normalize a feature means to transform it to have mean 0 and variance 1) 125 | This function outputs normalized feature matrices X_train_new, X_valid_new, X_test_new w/ binary categorical features FIRST, then continuous 126 | 127 | Input parameters: 128 | X_train, X_valid, X_test: contextual feature matrices for the training, validation, and test sets which have the following dimensions: 129 | (a) rows of X = observations/customers 130 | (b) columns of X = features about the observation/customer 131 | feats_continuous: If False, all feature are treated as categorical. If True, all feature are treated as continuous. 132 | feats_continuous can also be a boolean vector of dimension = num_features specifying how to treat each feature 133 | normalize: Whether to normalize the numerical features 134 | drop_first: Whether to get k-1 dummies out of k categorical levels by removing the first level 135 | ''' 136 | 137 | def binarize_and_normalize_contexts(X_train, X_valid, X_test, feats_continuous=True, normalize=True, drop_first=False): 138 | X_train = copy.deepcopy(X_train) 139 | X_valid = copy.deepcopy(X_valid) 140 | X_test = copy.deepcopy(X_test) 141 | 142 | n_train = X_train.shape[0] 143 | n_valid = X_valid.shape[0] 144 | 145 | all_continuous = np.all(feats_continuous) 146 | all_categorical = np.all(np.logical_not(feats_continuous)) 147 | 148 | X = np.concatenate([X_train, X_valid, X_test], axis=0) 149 | 150 | if not all_categorical: 151 | X_continuous = X[:,np.where(feats_continuous)[0]].astype("float") 152 | if normalize==True: 153 | X_continuous_std = np.std(X_continuous,axis=0) 154 | X_continuous_std[X_continuous_std == 0.0] = 1.0 155 | X_continuous = (X_continuous - np.mean(X_continuous, axis=0)) / X_continuous_std 156 | 157 | if not all_continuous: 158 | X_categorical = X[:,np.where(np.logical_not(feats_continuous))[0]] 159 | X_categorical = pd.DataFrame(X_categorical) 160 | X_categorical_bin = pd.get_dummies(X_categorical,columns=X_categorical.columns,drop_first=drop_first).values 161 | 162 | if all_continuous: 163 | X_new = X_continuous 164 | #feats_continuous_new = [True]*X_continuous.shape[1] 165 | elif all_categorical: 166 | X_new = X_categorical_bin 167 | #feats_continuous_new = [False]*X_categorical_bin.shape[1] 168 | else: 169 | X_new = np.concatenate([X_categorical_bin, X_continuous], axis=1) 170 | #feats_continuous_new = [False]*X_categorical_bin.shape[1] + [True]*X_continuous.shape[1] 171 | 172 | X_train_new = X_new[:n_train,:] 173 | X_valid_new = X_new[n_train:(n_train+n_valid),:] 174 | X_test_new = X_new[(n_train+n_valid):,:] 175 | 176 | return (X_train_new, X_valid_new, X_test_new) 177 | #return (X_train_new, X_valid_new, X_test_new, feats_continuous_new) -------------------------------------------------------------------------------- /scripts/outputs/fitted cmts/CMTtree3.txt: -------------------------------------------------------------------------------- 1 | Node 0: Depth 0 2 | Parent Node: NA 3 | Non-terminal Node splitting on (categorical) var V0 4 | Splitting Question: Is V0 in: 5 | [2.] 6 | Child Nodes: 1 (True), 2 (False) 7 | 8 | 9 | Node 1: Depth 1 10 | Parent Node: 0 11 | Non-terminal node splitting on (numeric) var V6 12 | Splitting Question: Is V6 <= 4? 13 | Child Nodes: 3 (True), 4 (False) 14 | 15 | 16 | Node 2: Depth 1 17 | Parent Node: 0 18 | Non-terminal Node splitting on (categorical) var V2 19 | Splitting Question: Is V2 in: 20 | [0.] 21 | Child Nodes: 5 (True), 6 (False) 22 | 23 | 24 | Node 3: Depth 2 25 | Parent Node: 1 26 | Non-terminal Node splitting on (categorical) var V2 27 | Splitting Question: Is V2 in: 28 | [0.] 29 | Child Nodes: 7 (True), 8 (False) 30 | 31 | 32 | Node 4: Depth 2 33 | Parent Node: 1 34 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 35 | -1.519036e-04 -1.231850e-04 -2.714268e-03 4.359725e-04 -1.336364e-03 36 | F1:2 F2:2 F3:2 F1:3 F2:3 37 | -8.974949e-03 1.096445e-05 -9.648585e-03 -3.266413e-02 6.889427e-03 38 | F3:3 Avl 39 | -5.804593e-08 1.000000e+04 40 | 41 | 42 | 43 | Node 5: Depth 2 44 | Parent Node: 2 45 | Non-terminal Node splitting on (categorical) var V4 46 | Splitting Question: Is V4 in: 47 | [3.] 48 | Child Nodes: 11 (True), 12 (False) 49 | 50 | 51 | Node 6: Depth 2 52 | Parent Node: 2 53 | Non-terminal Node splitting on (categorical) var V5 54 | Splitting Question: Is V5 in: 55 | [0.] 56 | Child Nodes: 13 (True), 14 (False) 57 | 58 | 59 | Node 7: Depth 3 60 | Parent Node: 3 61 | Non-terminal Node splitting on (categorical) var V3 62 | Splitting Question: Is V3 in: 63 | [7.] 64 | Child Nodes: 15 (True), 16 (False) 65 | 66 | 67 | Node 8: Depth 3 68 | Parent Node: 3 69 | Non-terminal Node splitting on (categorical) var V3 70 | Splitting Question: Is V3 in: 71 | [3.] 72 | Child Nodes: 17 (True), 18 (False) 73 | 74 | 75 | Node 11: Depth 3 76 | Parent Node: 5 77 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 78 | 6.301111e-01 5.108592e-01 -1.241984e-03 -3.330738e-02 -9.688157e-03 79 | F1:2 F2:2 F3:2 F1:3 F2:3 80 | 3.057576e-03 -2.487229e-02 1.786548e-03 -1.159451e-02 -7.571369e-04 81 | F3:3 Avl 82 | 2.260231e-04 1.000000e+04 83 | 84 | 85 | 86 | Node 12: Depth 3 87 | Parent Node: 5 88 | Non-terminal Node splitting on (categorical) var V5 89 | Splitting Question: Is V5 in: 90 | [0.] 91 | Child Nodes: 25 (True), 26 (False) 92 | 93 | 94 | Node 13: Depth 3 95 | Parent Node: 6 96 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 97 | 1.453189e+00 1.820922e+00 -1.522427e-02 -3.034687e-02 -5.275996e-03 98 | F1:2 F2:2 F3:2 F1:3 F2:3 99 | -2.359661e-02 -1.855512e-02 6.095158e-03 -2.747502e-02 -1.482392e-02 100 | F3:3 Avl 101 | 9.440796e-04 1.000000e+04 102 | 103 | 104 | 105 | Node 14: Depth 3 106 | Parent Node: 6 107 | Non-terminal Node splitting on (categorical) var V3 108 | Splitting Question: Is V3 in: 109 | [1.] 110 | Child Nodes: 29 (True), 30 (False) 111 | 112 | 113 | Node 15: Depth 4 114 | Parent Node: 7 115 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 116 | 2.976704e-03 -5.653588e-03 -2.252575e-02 -4.180009e-03 -2.004230e-02 117 | F1:2 F2:2 F3:2 F1:3 F2:3 118 | -5.898804e-02 -3.288003e-03 -1.780213e-02 -1.078849e-01 -3.012621e-01 119 | F3:3 Avl 120 | 1.911478e-06 1.000000e+04 121 | 122 | 123 | 124 | Node 16: Depth 4 125 | Parent Node: 7 126 | Non-terminal Node splitting on (categorical) var V1 127 | Splitting Question: Is V1 in: 128 | [3.] 129 | Child Nodes: 33 (True), 34 (False) 130 | 131 | 132 | Node 17: Depth 4 133 | Parent Node: 8 134 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 135 | 4.234020e-02 4.448950e-01 3.177469e-02 -6.524086e-02 -2.426384e-02 136 | F1:2 F2:2 F3:2 F1:3 F2:3 137 | 5.233479e-02 -4.725316e-02 -4.700310e-02 -4.367803e-02 -1.979192e-02 138 | F3:3 Avl 139 | 2.009180e-04 1.000000e+04 140 | 141 | 142 | 143 | Node 18: Depth 4 144 | Parent Node: 8 145 | Non-terminal Node splitting on (categorical) var V9 146 | Splitting Question: Is V9 in: 147 | [0.] 148 | Child Nodes: 35 (True), 36 (False) 149 | 150 | 151 | Node 25: Depth 4 152 | Parent Node: 12 153 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 154 | -5.569334e-01 -1.309272e+00 -2.620728e-02 -5.842753e-02 -1.450680e-02 155 | F1:2 F2:2 F3:2 F1:3 F2:3 156 | -2.678389e-02 -4.121851e-02 -2.158244e-02 -2.963383e-02 -2.416643e-02 157 | F3:3 Avl 158 | -6.282445e-04 1.000000e+04 159 | 160 | 161 | 162 | Node 26: Depth 4 163 | Parent Node: 12 164 | Non-terminal Node splitting on (categorical) var V1 165 | Splitting Question: Is V1 in: 166 | [4.] 167 | Child Nodes: 39 (True), 40 (False) 168 | 169 | 170 | Node 29: Depth 4 171 | Parent Node: 14 172 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 173 | 1.241475e+00 1.660932e-01 -2.783546e-02 -6.744886e-03 -2.055840e-02 174 | F1:2 F2:2 F3:2 F1:3 F2:3 175 | -1.111469e-02 -3.582079e-02 -8.859119e-03 -2.214270e-02 -1.053350e-02 176 | F3:3 Avl 177 | 1.763387e-04 1.000000e+04 178 | 179 | 180 | 181 | Node 30: Depth 4 182 | Parent Node: 14 183 | Non-terminal node splitting on (numeric) var V6 184 | Splitting Question: Is V6 <= 3? 185 | Child Nodes: 45 (True), 46 (False) 186 | 187 | 188 | Node 33: Depth 5 189 | Parent Node: 16 190 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 191 | 9.559482e-01 -6.941771e-01 -4.768309e-03 -6.818659e-04 -2.853700e-03 192 | F1:2 F2:2 F3:2 F1:3 F2:3 193 | -1.024020e-02 -7.624882e-04 -7.506058e-03 -3.132388e-03 -1.421191e-02 194 | F3:3 Avl 195 | -4.046611e-04 1.000000e+04 196 | 197 | 198 | 199 | Node 34: Depth 5 200 | Parent Node: 16 201 | Non-terminal Node splitting on (categorical) var V9 202 | Splitting Question: Is V9 in: 203 | [0.] 204 | Child Nodes: 49 (True), 50 (False) 205 | 206 | 207 | Node 35: Depth 5 208 | Parent Node: 18 209 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 210 | -6.677119e-01 -2.141141e+00 -3.054340e-02 -6.097421e-03 -1.369695e-02 211 | F1:2 F2:2 F3:2 F1:3 F2:3 212 | -2.807416e-02 -3.409508e-03 -1.839882e-02 -1.676398e-02 -2.883941e-02 213 | F3:3 Avl 214 | -1.101091e-03 1.000000e+04 215 | 216 | 217 | 218 | Node 36: Depth 5 219 | Parent Node: 18 220 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 221 | 1.207664e-04 -2.463328e-04 -3.687042e-03 -7.686849e-04 -5.189160e-03 222 | F1:2 F2:2 F3:2 F1:3 F2:3 223 | -1.173075e-02 -4.016654e-04 -1.053529e-02 -2.802661e-02 -7.338101e-02 224 | F3:3 Avl 225 | -7.670853e-08 1.000000e+04 226 | 227 | 228 | 229 | Node 39: Depth 5 230 | Parent Node: 26 231 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 232 | -3.699108e-01 9.816259e-01 -3.411850e-03 -8.114622e-02 -1.859799e-03 233 | F1:2 F2:2 F3:2 F1:3 F2:3 234 | -2.981370e-03 -2.781521e-02 -9.834112e-03 -1.667581e-02 -1.531339e-02 235 | F3:3 Avl 236 | 4.543011e-04 1.000000e+04 237 | 238 | 239 | 240 | Node 40: Depth 5 241 | Parent Node: 26 242 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 243 | -1.328243e+00 -8.729430e-01 -2.959597e-02 -2.968696e-02 -1.516995e-02 244 | F1:2 F2:2 F3:2 F1:3 F2:3 245 | -2.220844e-02 -2.426447e-02 7.730334e-03 -2.822578e-02 -8.394870e-03 246 | F3:3 Avl 247 | -3.691845e-04 1.000000e+04 248 | 249 | 250 | 251 | Node 45: Depth 5 252 | Parent Node: 30 253 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 254 | 3.171062e+00 3.611031e+00 -2.844232e-02 -8.274448e-03 2.059831e-02 255 | F1:2 F2:2 F3:2 F1:3 F2:3 256 | -1.549930e-02 -1.710031e-02 -2.053694e-02 -2.265076e-02 -1.457003e-02 257 | F3:3 Avl 258 | 1.841704e-03 1.000000e+04 259 | 260 | 261 | 262 | Node 46: Depth 5 263 | Parent Node: 30 264 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 265 | 5.157320e-01 4.953947e+00 6.172003e-04 -7.773274e-02 4.833166e-02 266 | F1:2 F2:2 F3:2 F1:3 F2:3 267 | -2.983807e-02 -1.731310e-02 -3.386956e-04 -3.237565e-02 -3.419757e-02 268 | F3:3 Avl 269 | 1.385805e-03 1.000000e+04 270 | 271 | 272 | 273 | Node 49: Depth 6 274 | Parent Node: 34 275 | Non-terminal Node splitting on (categorical) var V3 276 | Splitting Question: Is V3 in: 277 | [3.] 278 | Child Nodes: 77 (True), 78 (False) 279 | 280 | 281 | Node 50: Depth 6 282 | Parent Node: 34 283 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 284 | 5.658087e-04 -1.033379e-03 -1.051035e-03 -1.593862e-03 -7.809525e-03 285 | F1:2 F2:2 F3:2 F1:3 F2:3 286 | -7.315117e-03 -1.138563e-03 1.792917e-02 -1.153893e-01 -1.571843e-02 287 | F3:3 Avl 288 | -4.900901e-07 1.000000e+04 289 | 290 | 291 | 292 | Node 77: Depth 7 293 | Parent Node: 49 294 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 295 | 3.318676e+00 -2.427520e+01 -3.998864e-02 -2.169391e-01 -1.975203e-04 296 | F1:2 F2:2 F3:2 F1:3 F2:3 297 | -4.047354e-02 -2.040290e-01 -1.054064e-01 -7.124155e-02 8.225551e-02 298 | F3:3 Avl 299 | -1.409556e-02 1.000000e+04 300 | 301 | 302 | 303 | Node 78: Depth 7 304 | Parent Node: 49 305 | Non-terminal node splitting on (numeric) var V6 306 | Splitting Question: Is V6 <= 2? 307 | Child Nodes: 109 (True), 110 (False) 308 | 309 | 310 | Node 109: Depth 8 311 | Parent Node: 78 312 | Non-terminal Node splitting on (categorical) var V3 313 | Splitting Question: Is V3 in: 314 | [8.] 315 | Child Nodes: 139 (True), 140 (False) 316 | 317 | 318 | Node 110: Depth 8 319 | Parent Node: 78 320 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 321 | -9.834426e-01 -4.449200e-01 -6.255845e-02 -3.485669e-02 -1.632535e-02 322 | F1:2 F2:2 F3:2 F1:3 F2:3 323 | -5.811344e-02 -6.283422e-02 -1.132461e-02 -5.787489e-02 -1.124851e-01 324 | F3:3 Avl 325 | -3.470079e-05 1.000000e+04 326 | 327 | 328 | 329 | Node 139: Depth 9 330 | Parent Node: 109 331 | Leaf: Model params: (Intercept):2 F1:1 F2:1 F3:1 F1:2 332 | 2.115995e+00 2.726341e-02 -2.331976e-02 -4.190929e-03 2.671481e-02 333 | F2:2 F3:2 Avl 334 | -1.744978e-02 -3.489205e-02 1.000000e+04 335 | 336 | 337 | 338 | Node 140: Depth 9 339 | Parent Node: 109 340 | Leaf: Model params: (Intercept):2 (Intercept):3 F1:1 F2:1 F3:1 341 | 3.058408e+00 -1.188781e+00 -9.693176e-03 -3.618815e-02 -1.129509e-02 342 | F1:2 F2:2 F3:2 F1:3 F2:3 343 | -2.658443e-02 -2.455232e-02 -3.647763e-02 -3.052819e-04 -5.811319e-02 344 | F3:3 Avl 345 | -7.634294e-04 1.000000e+04 346 | 347 | 348 | 349 | Max depth:9 350 | Num. Nodes: 37.0 351 | Num. Terminal Nodes: 19.0 352 | -------------------------------------------------------------------------------- /scripts/mnlint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import tensorflow as tf 6 | from tensorflow.keras.layers import Dense 7 | from tensorflow.keras import Model,backend 8 | from functools import reduce 9 | import pandas as pd 10 | import numpy as np 11 | import tensorboard 12 | import datetime 13 | 14 | 15 | # specifies the model class 16 | class MyModel(Model): 17 | def __init__(self,params ): 18 | super(MyModel, self).__init__() 19 | self.params = params 20 | regularizer = tf.keras.regularizers.L2(0.) 21 | initializer = tf.keras.initializers.TruncatedNormal(mean=0.0001, stddev=0.0002) 22 | #coefficients of NN for utility function 23 | self.u1 = [Dense(params['hidden_utility'], activation='elu',kernel_initializer=initializer, 24 | kernel_regularizer=regularizer) 25 | for l in range(params['depth_utility'])] 26 | #last layer for utility function NN (no activation) 27 | self.u2 = Dense(1,kernel_regularizer=regularizer, kernel_initializer=initializer) 28 | 29 | def call(self, x): 30 | # computes the utility 31 | list_x1 = [tf.reshape(self.u2(reduce(lambda z,y: y(z),self.u1,x[:,:,i])), 32 | [-1,1,1]) 33 | for i in range(total)] 34 | x1 = tf.stack(list_x1, axis = 1)[:,:,0,0] 35 | 36 | 37 | x = (1.0/1.03)*(0.01+tf.nn.softmax(x1)) 38 | return(x) 39 | 40 | 41 | # # Selects the parameter configuration 42 | params_list= [ 43 | # simple_mnl_regression 44 | { 45 | 'depth_utility': 0, # number of hidden layers in generating the utility function 46 | 'hidden_utility':0, # width of hidden layers generating the utility function 47 | 'name': 'regression' 48 | } 49 | ] 50 | 51 | # size of mini-batch in gradient computation 52 | batch_size = 32 53 | total = 3 54 | # m = 4 55 | p_features = 4 56 | c_features = 72 57 | scores_np = np.zeros((10,len(params_list),5)) 58 | for iD in range(10): 59 | x_train = np.load(f'data/X_long_{iD}.npy') 60 | p_train = np.load(f'data/P_long_{iD}.npy') 61 | p_train = p_train.reshape((p_train.shape[0],p_features,total),order = 'C') 62 | y_train= np.load(f'data/Y_long_{iD}.npy') 63 | 64 | x_val = np.load(f'data/XV_long_{iD}.npy') 65 | p_val = np.load(f'data/PV_long_{iD}.npy') 66 | p_val = p_val.reshape((p_val.shape[0],p_features,total),order = 'C') 67 | y_val = np.load(f'data/YV_long_{iD}.npy') 68 | 69 | x_test = np.load(f'data/XT_long_{iD}.npy') 70 | p_test = np.load(f'data/PT_long_{iD}.npy') 71 | p_test = p_test.reshape((p_test.shape[0],p_features,total),order = 'C') 72 | y_test = np.load(f'data/YT_long_{iD}.npy') 73 | 74 | X1 = np.concatenate( [x_train.reshape((x_train.shape[0],c_features,1)) for u in range(total)],axis = 2) 75 | x_train= np.concatenate([p_train,X1], axis = 1) 76 | 77 | X1 = np.concatenate( [x_val.reshape((x_val.shape[0],c_features,1)) for u in range(total)],axis = 2) 78 | x_val= np.concatenate([p_val,X1], axis = 1) 79 | 80 | X1 = np.concatenate( [x_test.reshape((x_test.shape[0],c_features,1)) for u in range(total)],axis = 2) 81 | x_test= np.concatenate([p_test,X1], axis = 1) 82 | 83 | x_train = x_train.astype(float) 84 | x_val = x_val.astype(float) 85 | x_test = x_test.astype(float) 86 | 87 | # create the expanded version of the data with all interactions of product and customer features 88 | X1 = np.zeros((x_train.shape[0],p_features*(1+c_features),total)) 89 | X2 = np.zeros((x_test.shape[0],p_features*(1+c_features),total)) 90 | X3 = np.zeros((x_val.shape[0],p_features*(1+c_features),total)) 91 | X1[:,:p_features,:] = x_train[:,:p_features,:] 92 | X2[:,:p_features,:] = x_test[:,:p_features,:] 93 | X3[:,:p_features,:] = x_val[:,:p_features,:] 94 | 95 | for zeta in range(p_features): 96 | for j in range(c_features): 97 | X1[:,p_features + zeta*j,:] = np.multiply(x_train[:,zeta,:],x_train[:,p_features + j,:]) 98 | X2[:,p_features + zeta*j,:] = np.multiply(x_test[:,zeta,:],x_test[:,p_features + j,:]) 99 | X3[:,p_features + zeta*j,:] = np.multiply(x_val[:,zeta,:],x_val[:,p_features + j,:]) 100 | x_train = X1 101 | x_test = X2 102 | x_val = X3 103 | 104 | print("Shape",x_train.shape,y_train.shape,x_test.shape,y_test.shape) 105 | 106 | # creates the data batch generation object 107 | train_ds = tf.data.Dataset.from_tensor_slices( 108 | (x_train, y_train)).shuffle(10000).batch(batch_size) 109 | test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size) 110 | val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size) 111 | 112 | i_p = -1 113 | for p in params_list: 114 | i_p+=1 115 | loss_val = 10000 116 | test_loss_f = [] 117 | loss_train = 0 118 | for r in [0.001,0.0001]: 119 | print("\n",p['name'],r) 120 | # Create an instance of the model 121 | model = MyModel(p) 122 | 123 | # designs the estimation process 124 | loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False) 125 | 126 | optimizer = tf.keras.optimizers.Adam(learning_rate=r) 127 | 128 | # picks the metrics 129 | train_loss = tf.keras.metrics.Mean(name='train_loss') 130 | train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') 131 | train_mse =tf.keras.metrics.MeanSquaredError('train_mse') 132 | 133 | test_loss = tf.keras.metrics.Mean(name='test_loss') 134 | test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') 135 | test_mse =tf.keras.metrics.MeanSquaredError('test_mse') 136 | 137 | val_loss = tf.keras.metrics.Mean(name='val_loss') 138 | val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='val_accuracy') 139 | val_mse =tf.keras.metrics.MeanSquaredError('val_mse') 140 | 141 | # computes the gradient function 142 | @tf.function 143 | def train_step(images, labels): 144 | with tf.GradientTape() as tape: 145 | predictions = model(images, training=True) 146 | loss = loss_object(labels, predictions) 147 | regularization_loss=tf.add_n(model.losses) 148 | loss2 = tf.math.add(loss,regularization_loss) 149 | 150 | gradients = tape.gradient(loss2, model.trainable_variables) 151 | optimizer.apply_gradients(zip(gradients, model.trainable_variables)) 152 | 153 | train_loss(loss) 154 | train_accuracy(labels, predictions) 155 | print() 156 | train_mse(tf.one_hot(labels,3), predictions,sample_weight = 1.0/batch_size) 157 | 158 | 159 | # computes the test function 160 | @tf.function 161 | def test_step(images, labels): 162 | predictions = model(images, training=False) 163 | t_loss = loss_object(labels, predictions) 164 | 165 | test_loss(t_loss) 166 | test_accuracy(labels, predictions) 167 | test_mse(tf.one_hot(labels,3), predictions,sample_weight = 1.0/batch_size) 168 | 169 | # computes the validation function 170 | @tf.function 171 | def val_step(images, labels): 172 | predictions = model(images, training=False) 173 | v_loss = loss_object(labels, predictions) 174 | 175 | val_loss(v_loss) 176 | val_accuracy(labels, predictions) 177 | val_mse(tf.one_hot(labels,3), predictions,sample_weight = 1.0/batch_size) 178 | 179 | # Training loop 180 | EPOCHS = 1001 181 | current_time = p['name'] +str(r)+datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 182 | train_log_dir = 'logs/gradient_tape/' + current_time + '/train' 183 | test_log_dir = 'logs/gradient_tape/' + current_time + '/test' 184 | train_summary_writer = tf.summary.create_file_writer(train_log_dir) 185 | test_summary_writer = tf.summary.create_file_writer(test_log_dir) 186 | 187 | for epoch in range(EPOCHS): 188 | try: 189 | # Reset the metrics at the start of the next epoch 190 | train_loss.reset_states() 191 | train_accuracy.reset_states() 192 | 193 | del train_ds 194 | del test_ds 195 | train_ds = tf.data.Dataset.from_tensor_slices( 196 | (x_train, y_train)).shuffle(1000+epoch).batch(batch_size) 197 | test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(1000+epoch).batch(batch_size) 198 | 199 | for images, labels in train_ds: 200 | train_step(images, labels) 201 | with train_summary_writer.as_default(): 202 | tf.summary.scalar('loss', train_loss.result(), step=epoch) 203 | tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch) 204 | tf.summary.scalar('mse', train_mse.result(), step=epoch) 205 | 206 | if (epoch % 10) == 0: 207 | test_loss.reset_states() 208 | test_accuracy.reset_states() 209 | test_mse.reset_states() 210 | val_loss.reset_states() 211 | val_accuracy.reset_states() 212 | val_mse.reset_states() 213 | 214 | for val_images, val_labels in val_ds: 215 | val_step(val_images, val_labels) 216 | with test_summary_writer.as_default(): 217 | tf.summary.scalar('loss', val_loss.result(), step=epoch) 218 | tf.summary.scalar('accuracy', val_accuracy.result(), step=epoch) 219 | tf.summary.scalar('mse', val_mse.result(), step=epoch) 220 | 221 | for test_images, test_labels in test_ds: 222 | test_step(test_images, test_labels) 223 | 224 | except: 225 | print("Error") 226 | raise 227 | if (epoch % 200) == 0: 228 | print( 229 | f'Epoch {epoch }, ' 230 | f'Loss: {train_loss.result()}, ' 231 | f'Accuracy: {train_accuracy.result() * 100}, ' 232 | f'Validation Loss: {val_loss.result()}, ' 233 | f'Test Loss: {test_loss.result()}, ' 234 | f'Test Accuracy: {test_accuracy.result() * 100},' 235 | f'Test Mse: {test_mse.result() * 100},' 236 | f'Number of parameters: {np.sum([np.prod(v.get_shape().as_list()) for v in model.trainable_variables])}' 237 | ) 238 | 239 | if (epoch % 10) == 0 and (float(val_loss.result()) < loss_val): 240 | test_loss_f = [float(test_loss.result()),float(test_accuracy.result()),float(test_mse.result())*3] 241 | loss_val = float(val_loss.result()) 242 | loss_train = float(train_loss.result()) 243 | print("Test loss final:",p['name'],r, test_loss_f) 244 | print("Train/validation loss final:",p['name'],r,loss_train,loss_val) 245 | 246 | scores_np[iD,i_p,0] = test_loss_f[0] 247 | scores_np[iD,i_p,1] = test_loss_f[1] 248 | scores_np[iD,i_p,2] = test_loss_f[2] 249 | scores_np[iD,i_p,3] = loss_val 250 | scores_np[iD,i_p,4] = loss_train 251 | 252 | 253 | 254 | 255 | --------------------------------------------------------------------------------