├── regression ├── requirements.txt ├── configs │ ├── gp │ │ ├── bnp.yaml │ │ ├── cnp.yaml │ │ ├── np.yaml │ │ ├── banp.yaml │ │ ├── canp.yaml │ │ └── anp.yaml │ ├── celeba │ │ ├── bnp.yaml │ │ ├── cnp.yaml │ │ ├── np.yaml │ │ ├── banp.yaml │ │ ├── canp.yaml │ │ └── anp.yaml │ ├── emnist │ │ ├── bnp.yaml │ │ ├── cnp.yaml │ │ ├── np.yaml │ │ ├── banp.yaml │ │ ├── canp.yaml │ │ └── anp.yaml │ └── lotka_volterra │ │ ├── bnp.yaml │ │ ├── cnp.yaml │ │ ├── np.yaml │ │ ├── banp.yaml │ │ ├── canp.yaml │ │ └── anp.yaml ├── utils │ ├── paths.py │ ├── misc.py │ ├── sampling.py │ └── log.py ├── data │ ├── emnist.py │ ├── celeba.py │ ├── image.py │ ├── gp.py │ └── lotka_volterra.py ├── models │ ├── cnp.py │ ├── attention.py │ ├── canp.py │ ├── bnp.py │ ├── banp.py │ ├── np.py │ ├── anp.py │ └── modules.py ├── celeba.py ├── emnist.py ├── lotka_volterra.py └── gp.py ├── bnp_new.png ├── bayesian_optimization ├── requirements.txt ├── results │ ├── oracle_1.npy │ ├── oracle_2.npy │ ├── oracle_3.npy │ ├── oracle_4.npy │ ├── oracle_5.npy │ ├── oracle_6.npy │ ├── oracle_7.npy │ ├── oracle_8.npy │ ├── oracle_9.npy │ ├── oracle_10.npy │ ├── oracle_100.npy │ ├── oracle_11.npy │ ├── oracle_12.npy │ ├── oracle_13.npy │ ├── oracle_14.npy │ ├── oracle_15.npy │ ├── oracle_16.npy │ ├── oracle_17.npy │ ├── oracle_18.npy │ ├── oracle_19.npy │ ├── oracle_20.npy │ ├── oracle_21.npy │ ├── oracle_22.npy │ ├── oracle_23.npy │ ├── oracle_24.npy │ ├── oracle_25.npy │ ├── oracle_26.npy │ ├── oracle_27.npy │ ├── oracle_28.npy │ ├── oracle_29.npy │ ├── oracle_30.npy │ ├── oracle_31.npy │ ├── oracle_32.npy │ ├── oracle_33.npy │ ├── oracle_34.npy │ ├── oracle_35.npy │ ├── oracle_36.npy │ ├── oracle_37.npy │ ├── oracle_38.npy │ ├── oracle_39.npy │ ├── oracle_40.npy │ ├── oracle_41.npy │ ├── oracle_42.npy │ ├── oracle_43.npy │ ├── oracle_44.npy │ ├── oracle_45.npy │ ├── oracle_46.npy │ ├── oracle_47.npy │ ├── oracle_48.npy │ ├── oracle_49.npy │ ├── oracle_50.npy │ ├── oracle_51.npy │ ├── oracle_52.npy │ ├── 
oracle_53.npy │ ├── oracle_54.npy │ ├── oracle_55.npy │ ├── oracle_56.npy │ ├── oracle_57.npy │ ├── oracle_58.npy │ ├── oracle_59.npy │ ├── oracle_60.npy │ ├── oracle_61.npy │ ├── oracle_62.npy │ ├── oracle_63.npy │ ├── oracle_64.npy │ ├── oracle_65.npy │ ├── oracle_66.npy │ ├── oracle_67.npy │ ├── oracle_68.npy │ ├── oracle_69.npy │ ├── oracle_70.npy │ ├── oracle_71.npy │ ├── oracle_72.npy │ ├── oracle_73.npy │ ├── oracle_74.npy │ ├── oracle_75.npy │ ├── oracle_76.npy │ ├── oracle_77.npy │ ├── oracle_78.npy │ ├── oracle_79.npy │ ├── oracle_80.npy │ ├── oracle_81.npy │ ├── oracle_82.npy │ ├── oracle_83.npy │ ├── oracle_84.npy │ ├── oracle_85.npy │ ├── oracle_86.npy │ ├── oracle_87.npy │ ├── oracle_88.npy │ ├── oracle_89.npy │ ├── oracle_90.npy │ ├── oracle_91.npy │ ├── oracle_92.npy │ ├── oracle_93.npy │ ├── oracle_94.npy │ ├── oracle_95.npy │ ├── oracle_96.npy │ ├── oracle_97.npy │ ├── oracle_98.npy │ └── oracle_99.npy ├── configs │ └── gp │ │ ├── bnp.yaml │ │ ├── cnp.yaml │ │ ├── np.yaml │ │ ├── banp.yaml │ │ ├── canp.yaml │ │ └── anp.yaml ├── utils │ ├── paths.py │ ├── misc.py │ ├── sampling.py │ └── log.py ├── .gitignore ├── test_all.sh ├── models │ ├── attention.py │ ├── cnp.py │ ├── canp.py │ ├── bnp.py │ ├── banp.py │ ├── np.py │ ├── anp.py │ └── modules.py ├── data │ └── gp.py └── run_bo.py ├── LICENSE ├── README.md └── .gitignore /regression/requirements.txt: -------------------------------------------------------------------------------- 1 | attrdict 2 | pyyaml 3 | tqdm 4 | -------------------------------------------------------------------------------- /bnp_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bnp_new.png -------------------------------------------------------------------------------- /bayesian_optimization/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | attrdict 3 | torch==1.4.0 4 | 
torchvision==0.5.0 5 | bayeso 6 | -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_1.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_2.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_3.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_4.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_5.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_6.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_7.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_7.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_8.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_9.npy -------------------------------------------------------------------------------- /regression/configs/gp/bnp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_pre_depth: 4 5 | enc_post_depth: 2 6 | dec_depth: 3 7 | -------------------------------------------------------------------------------- /regression/configs/gp/cnp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_pre_depth: 4 5 | enc_post_depth: 2 6 | dec_depth: 3 7 | -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_10.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_100.npy 
-------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_11.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_11.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_12.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_12.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_13.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_13.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_14.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_14.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_15.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_15.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_16.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_16.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_17.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_17.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_18.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_18.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_19.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_19.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_20.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_20.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_21.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_21.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_22.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_22.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_23.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_23.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_24.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_24.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_25.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_25.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_26.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_26.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_27.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_27.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_28.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_28.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_29.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_29.npy 
-------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_30.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_30.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_31.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_31.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_32.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_32.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_33.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_33.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_34.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_34.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_35.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_35.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_36.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_36.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_37.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_37.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_38.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_38.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_39.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_39.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_40.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_40.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_41.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_41.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_42.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_42.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_43.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_43.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_44.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_44.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_45.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_45.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_46.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_46.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_47.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_47.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_48.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_48.npy 
-------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_49.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_49.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_50.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_50.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_51.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_51.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_52.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_52.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_53.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_53.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_54.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_54.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_55.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_55.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_56.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_56.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_57.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_57.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_58.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_58.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_59.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_59.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_60.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_60.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_61.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_61.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_62.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_62.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_63.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_63.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_64.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_64.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_65.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_65.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_66.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_66.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_67.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_67.npy 
-------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_68.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_68.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_69.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_69.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_70.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_70.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_71.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_71.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_72.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_72.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_73.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_73.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_74.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_74.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_75.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_75.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_76.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_76.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_77.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_77.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_78.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_78.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_79.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_79.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_80.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_80.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_81.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_81.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_82.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_82.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_83.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_83.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_84.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_84.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_85.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_85.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_86.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_86.npy 
-------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_87.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_87.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_88.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_88.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_89.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_89.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_90.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_90.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_91.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_91.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_92.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_92.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_93.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_93.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_94.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_94.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_95.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_95.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_96.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_96.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_97.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_97.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_98.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_98.npy -------------------------------------------------------------------------------- /bayesian_optimization/results/oracle_99.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_99.npy -------------------------------------------------------------------------------- /regression/configs/celeba/bnp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 3 3 | dim_hid: 128 4 | enc_pre_depth: 6 5 | enc_post_depth: 3 6 | dec_depth: 5 7 | -------------------------------------------------------------------------------- /regression/configs/celeba/cnp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 3 3 | dim_hid: 128 4 | enc_pre_depth: 6 5 | enc_post_depth: 3 6 | dec_depth: 5 7 | -------------------------------------------------------------------------------- /regression/configs/emnist/bnp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_pre_depth: 5 5 | enc_post_depth: 3 6 | dec_depth: 4 7 | -------------------------------------------------------------------------------- /regression/configs/emnist/cnp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_pre_depth: 5 5 | enc_post_depth: 3 6 | dec_depth: 4 7 | -------------------------------------------------------------------------------- /bayesian_optimization/configs/gp/bnp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_pre_depth: 4 5 | enc_post_depth: 2 6 | dec_depth: 3 7 | -------------------------------------------------------------------------------- /bayesian_optimization/configs/gp/cnp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_pre_depth: 4 5 | enc_post_depth: 2 6 | dec_depth: 3 7 | 
-------------------------------------------------------------------------------- /regression/configs/lotka_volterra/bnp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 2 3 | dim_hid: 128 4 | enc_pre_depth: 4 5 | enc_post_depth: 2 6 | dec_depth: 3 7 | -------------------------------------------------------------------------------- /regression/configs/lotka_volterra/cnp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 2 3 | dim_hid: 128 4 | enc_pre_depth: 4 5 | enc_post_depth: 2 6 | dec_depth: 3 7 | -------------------------------------------------------------------------------- /regression/configs/celeba/np.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 3 3 | dim_hid: 128 4 | dim_lat: 128 5 | enc_pre_depth: 6 6 | enc_post_depth: 3 7 | dec_depth: 5 8 | -------------------------------------------------------------------------------- /regression/configs/emnist/np.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 1 3 | dim_hid: 128 4 | dim_lat: 128 5 | enc_pre_depth: 5 6 | enc_post_depth: 3 7 | dec_depth: 4 8 | -------------------------------------------------------------------------------- /regression/configs/gp/np.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | dim_lat: 128 5 | enc_pre_depth: 4 6 | enc_post_depth: 2 7 | dec_depth: 3 8 | -------------------------------------------------------------------------------- /bayesian_optimization/configs/gp/np.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | dim_lat: 128 5 | enc_pre_depth: 4 6 | enc_post_depth: 2 7 | dec_depth: 3 8 | -------------------------------------------------------------------------------- 
/regression/configs/lotka_volterra/np.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 2 3 | dim_hid: 128 4 | dim_lat: 128 5 | enc_pre_depth: 4 6 | enc_post_depth: 2 7 | dec_depth: 3 8 | -------------------------------------------------------------------------------- /regression/configs/gp/banp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_v_depth: 4 5 | enc_qk_depth: 2 6 | enc_pre_depth: 4 7 | enc_post_depth: 2 8 | dec_depth: 3 9 | -------------------------------------------------------------------------------- /regression/configs/gp/canp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_v_depth: 4 5 | enc_qk_depth: 2 6 | enc_pre_depth: 4 7 | enc_post_depth: 2 8 | dec_depth: 3 9 | -------------------------------------------------------------------------------- /regression/configs/celeba/banp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 3 3 | dim_hid: 128 4 | enc_v_depth: 6 5 | enc_qk_depth: 3 6 | enc_pre_depth: 6 7 | enc_post_depth: 3 8 | dec_depth: 5 9 | -------------------------------------------------------------------------------- /regression/configs/celeba/canp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 3 3 | dim_hid: 128 4 | enc_v_depth: 6 5 | enc_qk_depth: 3 6 | enc_pre_depth: 6 7 | enc_post_depth: 3 8 | dec_depth: 5 9 | -------------------------------------------------------------------------------- /regression/configs/emnist/banp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_v_depth: 5 5 | enc_qk_depth: 3 6 | enc_pre_depth: 5 7 | enc_post_depth: 3 8 | dec_depth: 4 9 | 
-------------------------------------------------------------------------------- /regression/configs/emnist/canp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_v_depth: 5 5 | enc_qk_depth: 3 6 | enc_pre_depth: 5 7 | enc_post_depth: 3 8 | dec_depth: 4 9 | -------------------------------------------------------------------------------- /bayesian_optimization/utils/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT = '../../' 4 | 5 | datasets_path = os.path.join(ROOT, 'datasets') 6 | results_path = os.path.join(ROOT, 'ckpt_np') 7 | -------------------------------------------------------------------------------- /bayesian_optimization/configs/gp/banp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_v_depth: 4 5 | enc_qk_depth: 2 6 | enc_pre_depth: 4 7 | enc_post_depth: 2 8 | dec_depth: 3 9 | -------------------------------------------------------------------------------- /bayesian_optimization/configs/gp/canp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | enc_v_depth: 4 5 | enc_qk_depth: 2 6 | enc_pre_depth: 4 7 | enc_post_depth: 2 8 | dec_depth: 3 9 | -------------------------------------------------------------------------------- /regression/configs/lotka_volterra/banp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 2 3 | dim_hid: 128 4 | enc_v_depth: 4 5 | enc_qk_depth: 2 6 | enc_pre_depth: 4 7 | enc_post_depth: 2 8 | dec_depth: 3 9 | -------------------------------------------------------------------------------- /regression/configs/lotka_volterra/canp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 2 3 | dim_hid: 
128 4 | enc_v_depth: 4 5 | enc_qk_depth: 2 6 | enc_pre_depth: 4 7 | enc_post_depth: 2 8 | dec_depth: 3 9 | -------------------------------------------------------------------------------- /regression/configs/celeba/anp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 3 3 | dim_hid: 128 4 | dim_lat: 128 5 | enc_v_depth: 6 6 | enc_qk_depth: 3 7 | enc_pre_depth: 6 8 | enc_post_depth: 3 9 | dec_depth: 5 10 | -------------------------------------------------------------------------------- /regression/configs/emnist/anp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 2 2 | dim_y: 1 3 | dim_hid: 128 4 | dim_lat: 128 5 | enc_v_depth: 5 6 | enc_qk_depth: 3 7 | enc_pre_depth: 5 8 | enc_post_depth: 3 9 | dec_depth: 4 10 | -------------------------------------------------------------------------------- /regression/configs/gp/anp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | dim_lat: 128 5 | enc_v_depth: 4 6 | enc_qk_depth: 2 7 | enc_pre_depth: 4 8 | enc_post_depth: 2 9 | dec_depth: 3 10 | -------------------------------------------------------------------------------- /bayesian_optimization/configs/gp/anp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 1 3 | dim_hid: 128 4 | dim_lat: 128 5 | enc_v_depth: 4 6 | enc_qk_depth: 2 7 | enc_pre_depth: 4 8 | enc_post_depth: 2 9 | dec_depth: 3 10 | -------------------------------------------------------------------------------- /regression/configs/lotka_volterra/anp.yaml: -------------------------------------------------------------------------------- 1 | dim_x: 1 2 | dim_y: 2 3 | dim_hid: 128 4 | dim_lat: 128 5 | enc_v_depth: 4 6 | enc_qk_depth: 2 7 | enc_pre_depth: 4 8 | enc_post_depth: 2 9 | dec_depth: 3 10 | 
-------------------------------------------------------------------------------- /bayesian_optimization/.gitignore: -------------------------------------------------------------------------------- 1 | results/bo_rbf_oracle_* 2 | results/bo_rbf_noisy_oracle_* 3 | results/bo_matern_oracle_* 4 | results/bo_matern_noisy_oracle_* 5 | results/bo_periodic_oracle_* 6 | results/bo_periodic_noisy_oracle_* 7 | 8 | -------------------------------------------------------------------------------- /regression/utils/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT = '/nfs/lang/ext01/john' 4 | 5 | evalsets_path = os.path.join(ROOT, 'bnp', 'evalsets') 6 | datasets_path = os.path.join(ROOT, 'datasets') 7 | results_path = os.path.join(ROOT, 'bnp', 'results') 8 | -------------------------------------------------------------------------------- /bayesian_optimization/test_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python run_bo.py --mode bo --model anp 4 | python run_bo.py --mode bo --model banp 5 | python run_bo.py --mode bo --model bnp 6 | python run_bo.py --mode bo --model canp 7 | python run_bo.py --mode bo --model cnp 8 | python run_bo.py --mode bo --model np 9 | 10 | -------------------------------------------------------------------------------- /regression/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib.machinery import SourceFileLoader 3 | import math 4 | import torch 5 | 6 | def gen_load_func(parser, func): 7 | def load(args, cmdline): 8 | sub_args, cmdline = parser.parse_known_args(cmdline) 9 | for k, v in sub_args.__dict__.items(): 10 | args.__dict__[k] = v 11 | return func(**sub_args.__dict__), cmdline 12 | return load 13 | 14 | def load_module(filename): 15 | module_name = os.path.splitext(os.path.basename(filename))[0] 16 | return 
SourceFileLoader(module_name, filename).load_module() 17 | 18 | def logmeanexp(x, dim=0): 19 | return x.logsumexp(dim) - math.log(x.shape[dim]) 20 | 21 | def stack(x, num_samples=None, dim=0): 22 | return x if num_samples is None \ 23 | else torch.stack([x]*num_samples, dim=dim) 24 | -------------------------------------------------------------------------------- /bayesian_optimization/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib.machinery import SourceFileLoader 3 | import math 4 | import torch 5 | 6 | def gen_load_func(parser, func): 7 | def load(args, cmdline): 8 | sub_args, cmdline = parser.parse_known_args(cmdline) 9 | for k, v in sub_args.__dict__.items(): 10 | args.__dict__[k] = v 11 | return func(**sub_args.__dict__), cmdline 12 | return load 13 | 14 | def load_module(filename): 15 | module_name = os.path.splitext(os.path.basename(filename))[0] 16 | return SourceFileLoader(module_name, filename).load_module() 17 | 18 | def logmeanexp(x, dim=0): 19 | return x.logsumexp(dim) - math.log(x.shape[dim]) 20 | 21 | def stack(x, num_samples=None, dim=0): 22 | return x if num_samples is None \ 23 | else torch.stack([x]*num_samples, dim=dim) 24 | -------------------------------------------------------------------------------- /regression/data/emnist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torchvision.datasets as tvds 5 | 6 | from utils.paths import datasets_path 7 | from utils.misc import gen_load_func 8 | 9 | class EMNIST(tvds.EMNIST): 10 | def __init__(self, train=True, class_range=[0, 47], device='cpu', download=True): 11 | super().__init__(datasets_path, train=train, split='balanced', download=download) 12 | 13 | self.data = self.data.unsqueeze(1).float().div(255).transpose(-1, -2).to(device) 14 | self.targets = self.targets.to(device) 15 | 16 | idxs = [] 17 | for c in 
range(class_range[0], class_range[1]): 18 | idxs.append(torch.where(self.targets==c)[0]) 19 | idxs = torch.cat(idxs) 20 | 21 | self.data = self.data[idxs] 22 | self.targets = self.targets[idxs] 23 | 24 | def __getitem__(self, idx): 25 | return self.data[idx], self.targets[idx] 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Juho Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /regression/utils/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def gather(items, idxs): 4 | K = idxs.shape[0] 5 | idxs = idxs.to(items[0].device) 6 | gathered = [] 7 | for item in items: 8 | gathered.append(torch.gather( 9 | torch.stack([item]*K), -2, 10 | torch.stack([idxs]*item.shape[-1], -1)).squeeze(0)) 11 | return gathered[0] if len(gathered) == 1 else gathered 12 | 13 | def sample_subset(*items, r_N=None, num_samples=None): 14 | r_N = r_N or torch.rand(1).item() 15 | K = num_samples or 1 16 | N = items[0].shape[-2] 17 | Ns = min(max(1, int(r_N * N)), N-1) 18 | batch_shape = items[0].shape[:-2] 19 | idxs = torch.rand((K,)+batch_shape+(N,)).argsort(-1) 20 | return gather(items, idxs[...,:Ns]), gather(items, idxs[...,Ns:]) 21 | 22 | def sample_with_replacement(*items, num_samples=None, r_N=1.0, N_s=None): 23 | K = num_samples or 1 24 | N = items[0].shape[-2] 25 | N_s = N_s or max(1, int(r_N * N)) 26 | batch_shape = items[0].shape[:-2] 27 | idxs = torch.randint(N, size=(K,)+batch_shape+(N_s,)) 28 | return gather(items, idxs) 29 | 30 | def sample_mask(B, N, num_samples=None, min_num=3, prob=0.5): 31 | min_num = min(min_num, N) 32 | K = num_samples or 1 33 | fixed = torch.ones(K, B, min_num) 34 | if N - min_num > 0: 35 | rand = torch.bernoulli(prob*torch.ones(K, B, N-min_num)) 36 | mask = torch.cat([fixed, rand], -1) 37 | return mask.squeeze(0) 38 | else: 39 | return fixed.squeeze(0) 40 | -------------------------------------------------------------------------------- /bayesian_optimization/utils/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def gather(items, idxs): 4 | K = idxs.shape[0] 5 | idxs = idxs.to(items[0].device) 6 | gathered = [] 7 | for item in items: 8 | gathered.append(torch.gather( 9 | torch.stack([item]*K), -2, 10 | 
torch.stack([idxs]*item.shape[-1], -1)).squeeze(0)) 11 | return gathered[0] if len(gathered) == 1 else gathered 12 | 13 | def sample_subset(*items, r_N=None, num_samples=None): 14 | r_N = r_N or torch.rand(1).item() 15 | K = num_samples or 1 16 | N = items[0].shape[-2] 17 | Ns = max(1, int(r_N * N)) 18 | batch_shape = items[0].shape[:-2] 19 | idxs = torch.rand((K,)+batch_shape+(Ns,)).argsort(-1) 20 | return gather(items, idxs[...,:Ns]), gather(items, idxs[...,Ns:]) 21 | 22 | def sample_with_replacement(*items, num_samples=None, r_N=1.0, N_s=None): 23 | K = num_samples or 1 24 | N = items[0].shape[-2] 25 | N_s = N_s or max(1, int(r_N * N)) 26 | batch_shape = items[0].shape[:-2] 27 | idxs = torch.randint(N, size=(K,)+batch_shape+(N_s,)) 28 | return gather(items, idxs) 29 | 30 | def sample_mask(B, N, num_samples=None, min_num=3, prob=0.5): 31 | min_num = min(min_num, N) 32 | K = num_samples or 1 33 | fixed = torch.ones(K, B, min_num) 34 | if N - min_num > 0: 35 | rand = torch.bernoulli(prob*torch.ones(K, B, N-min_num)) 36 | mask = torch.cat([fixed, rand], -1) 37 | return mask.squeeze(0) 38 | else: 39 | return fixed.squeeze(0) 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bootstrapping Neural Processes 2 | The official repository for the paper [Bootstrapping Neural Processes](https://arxiv.org/abs/2008.02956) (NeurIPS 2020) by Juho Lee et al. 3 | 4 |

5 | 6 |

Unlike in the traditional statistical modeling for which a user typically hand-specifies a prior, Neural Processes (NPs) implicitly define a broad class of stochastic processes with neural networks. Given a data stream, NP learns a stochastic process that best describes the data. While this ``data-driven'' way of learning stochastic processes has proven to handle various types of data, NPs still rely on an assumption that uncertainty in stochastic processes is modeled by a single latent variable, which potentially limits the flexibility. To this end, we propose the Bootstrapping Neural Process (BNP), a novel extension of the NP family using the bootstrap. The bootstrap is a classical data-driven technique for estimating uncertainty, which allows BNP to learn the stochasticity in NPs without assuming a particular form. We demonstrate the efficacy of BNP on various types of data and its robustness in the presence of model-data mismatch.
self.clock = time.time() 17 | for key in keys: 18 | self.sum[key] = 0 19 | self.cnt[key] = 0 20 | 21 | def update(self, key, val): 22 | if isinstance(val, torch.Tensor): 23 | val = val.item() 24 | if self.sum.get(key, None) is None: 25 | self.sum[key] = val 26 | self.cnt[key] = 1 27 | else: 28 | self.sum[key] = self.sum[key] + val 29 | self.cnt[key] += 1 30 | 31 | def reset(self): 32 | for key in self.sum.keys(): 33 | self.sum[key] = 0 34 | self.cnt[key] = 0 35 | self.clock = time.time() 36 | 37 | def clear(self): 38 | self.sum = OrderedDict() 39 | self.cnt = OrderedDict() 40 | self.clock = time.time() 41 | 42 | def keys(self): 43 | return self.sum.keys() 44 | 45 | def get(self, key): 46 | assert(self.sum.get(key, None) is not None) 47 | return self.sum[key] / self.cnt[key] 48 | 49 | def info(self, show_et=True): 50 | line = '' 51 | for key in self.sum.keys(): 52 | val = self.sum[key] / self.cnt[key] 53 | if type(val) == float: 54 | line += f'{key} {val:.4f} ' 55 | else: 56 | line += f'{key} {val} '.format(key, val) 57 | if show_et: 58 | line += f'({time.time()-self.clock:.3f} secs)' 59 | return line 60 | -------------------------------------------------------------------------------- /bayesian_optimization/utils/log.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import logging 4 | from collections import OrderedDict 5 | 6 | def get_logger(filename, mode='a'): 7 | logging.basicConfig(level=logging.INFO, format='%(message)s') 8 | logger = logging.getLogger() 9 | logger.addHandler(logging.FileHandler(filename, mode=mode)) 10 | return logger 11 | 12 | class RunningAverage(object): 13 | def __init__(self, *keys): 14 | self.sum = OrderedDict() 15 | self.cnt = OrderedDict() 16 | self.clock = time.time() 17 | for key in keys: 18 | self.sum[key] = 0 19 | self.cnt[key] = 0 20 | 21 | def update(self, key, val): 22 | if isinstance(val, torch.Tensor): 23 | val = val.item() 24 | if self.sum.get(key, 
None) is None: 25 | self.sum[key] = val 26 | self.cnt[key] = 1 27 | else: 28 | self.sum[key] = self.sum[key] + val 29 | self.cnt[key] += 1 30 | 31 | def reset(self): 32 | for key in self.sum.keys(): 33 | self.sum[key] = 0 34 | self.cnt[key] = 0 35 | self.clock = time.time() 36 | 37 | def clear(self): 38 | self.sum = OrderedDict() 39 | self.cnt = OrderedDict() 40 | self.clock = time.time() 41 | 42 | def keys(self): 43 | return self.sum.keys() 44 | 45 | def get(self, key): 46 | assert(self.sum.get(key, None) is not None) 47 | return self.sum[key] / self.cnt[key] 48 | 49 | def info(self, show_et=True): 50 | line = '' 51 | for key in self.sum.keys(): 52 | val = self.sum[key] / self.cnt[key] 53 | if type(val) == float: 54 | line += f'{key} {val:.4f} ' 55 | else: 56 | line += f'{key} {val} '.format(key, val) 57 | if show_et: 58 | line += f'({time.time()-self.clock:.3f} secs)' 59 | return line 60 | -------------------------------------------------------------------------------- /regression/models/cnp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from attrdict import AttrDict 4 | 5 | from models.modules import PoolingEncoder, Decoder 6 | 7 | class CNP(nn.Module): 8 | def __init__(self, 9 | dim_x=1, 10 | dim_y=1, 11 | dim_hid=128, 12 | enc_pre_depth=4, 13 | enc_post_depth=2, 14 | dec_depth=3): 15 | 16 | super().__init__() 17 | 18 | self.enc1 = PoolingEncoder( 19 | dim_x=dim_x, 20 | dim_y=dim_y, 21 | dim_hid=dim_hid, 22 | pre_depth=enc_pre_depth, 23 | post_depth=enc_post_depth) 24 | 25 | self.enc2 = PoolingEncoder( 26 | dim_x=dim_x, 27 | dim_y=dim_y, 28 | dim_hid=dim_hid, 29 | pre_depth=enc_pre_depth, 30 | post_depth=enc_post_depth) 31 | 32 | self.dec = Decoder( 33 | dim_x=dim_x, 34 | dim_y=dim_y, 35 | dim_enc=2*dim_hid, 36 | dim_hid=dim_hid, 37 | depth=dec_depth) 38 | 39 | def predict(self, xc, yc, xt, num_samples=None): 40 | encoded = torch.cat([self.enc1(xc, yc), self.enc2(xc, yc)], -1) 
41 | encoded = torch.stack([encoded]*xt.shape[-2], -2) 42 | return self.dec(encoded, xt) 43 | 44 | def forward(self, batch, num_samples=None, reduce_ll=True): 45 | outs = AttrDict() 46 | py = self.predict(batch.xc, batch.yc, batch.x) 47 | ll = py.log_prob(batch.y).sum(-1) 48 | 49 | if self.training: 50 | outs.loss = -ll.mean() 51 | else: 52 | num_ctx = batch.xc.shape[-2] 53 | if reduce_ll: 54 | outs.ctx_ll = ll[...,:num_ctx].mean() 55 | outs.tar_ll = ll[...,num_ctx:].mean() 56 | else: 57 | outs.ctx_ll = ll[...,:num_ctx] 58 | outs.tar_ll = ll[...,num_ctx:] 59 | 60 | return outs 61 | -------------------------------------------------------------------------------- /regression/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class MultiHeadAttn(nn.Module): 7 | def __init__(self, dim_q, dim_k, dim_v, dim_out, num_heads=8): 8 | super().__init__() 9 | self.num_heads = num_heads 10 | self.dim_out = dim_out 11 | self.fc_q = nn.Linear(dim_q, dim_out, bias=False) 12 | self.fc_k = nn.Linear(dim_k, dim_out, bias=False) 13 | self.fc_v = nn.Linear(dim_v, dim_out, bias=False) 14 | self.fc_out = nn.Linear(dim_out, dim_out) 15 | self.ln1 = nn.LayerNorm(dim_out) 16 | self.ln2 = nn.LayerNorm(dim_out) 17 | 18 | def scatter(self, x): 19 | return torch.cat(x.chunk(self.num_heads, -1), -3) 20 | 21 | def gather(self, x): 22 | return torch.cat(x.chunk(self.num_heads, -3), -1) 23 | 24 | def attend(self, q, k, v, mask=None): 25 | q_, k_, v_ = [self.scatter(x) for x in [q, k, v]] 26 | A_logits = q_ @ k_.transpose(-2, -1) / math.sqrt(self.dim_out) 27 | if mask is not None: 28 | mask = mask.bool().to(q.device) 29 | mask = torch.stack([mask]*q.shape[-2], -2) 30 | mask = torch.cat([mask]*self.num_heads, -3) 31 | A = torch.softmax(A_logits.masked_fill(mask, -float('inf')), -1) 32 | A = A.masked_fill(torch.isnan(A), 0.0) 33 | else: 34 | A = 
torch.softmax(A_logits, -1) 35 | return self.gather(A @ v_) 36 | 37 | def forward(self, q, k, v, mask=None): 38 | q, k, v = self.fc_q(q), self.fc_k(k), self.fc_v(v) 39 | out = self.ln1(q + self.attend(q, k, v, mask=mask)) 40 | out = self.ln2(out + F.relu(self.fc_out(out))) 41 | return out 42 | 43 | class SelfAttn(MultiHeadAttn): 44 | def __init__(self, dim_in, dim_out, num_heads=8): 45 | super().__init__(dim_in, dim_in, dim_in, dim_out, num_heads) 46 | 47 | def forward(self, x, mask=None): 48 | return super().forward(x, x, x, mask=mask) 49 | -------------------------------------------------------------------------------- /bayesian_optimization/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class MultiHeadAttn(nn.Module): 7 | def __init__(self, dim_q, dim_k, dim_v, dim_out, num_heads=8): 8 | super().__init__() 9 | self.num_heads = num_heads 10 | self.dim_out = dim_out 11 | self.fc_q = nn.Linear(dim_q, dim_out, bias=False) 12 | self.fc_k = nn.Linear(dim_k, dim_out, bias=False) 13 | self.fc_v = nn.Linear(dim_v, dim_out, bias=False) 14 | self.fc_out = nn.Linear(dim_out, dim_out) 15 | self.ln1 = nn.LayerNorm(dim_out) 16 | self.ln2 = nn.LayerNorm(dim_out) 17 | 18 | def scatter(self, x): 19 | return torch.cat(x.chunk(self.num_heads, -1), -3) 20 | 21 | def gather(self, x): 22 | return torch.cat(x.chunk(self.num_heads, -3), -1) 23 | 24 | def attend(self, q, k, v, mask=None): 25 | q_, k_, v_ = [self.scatter(x) for x in [q, k, v]] 26 | A_logits = q_ @ k_.transpose(-2, -1) / math.sqrt(self.dim_out) 27 | if mask is not None: 28 | mask = mask.bool().to(q.device) 29 | mask = torch.stack([mask]*q.shape[-2], -2) 30 | mask = torch.cat([mask]*self.num_heads, -3) 31 | A = torch.softmax(A_logits.masked_fill(mask, -float('inf')), -1) 32 | A = A.masked_fill(torch.isnan(A), 0.0) 33 | else: 34 | A = torch.softmax(A_logits, -1) 35 | 
class CNP(nn.Module):
    """Conditional Neural Process.

    Two deterministic pooling encoders summarize the context set; their
    concatenated representation conditions a Gaussian decoder at the targets.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):
        super().__init__()

        # two parallel deterministic set encoders over (xc, yc)
        self.enc1 = PoolingEncoder(
                dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid,
                pre_depth=enc_pre_depth, post_depth=enc_post_depth)
        self.enc2 = PoolingEncoder(
                dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid,
                pre_depth=enc_pre_depth, post_depth=enc_post_depth)

        # decoder consumes both encodings concatenated, hence 2*dim_hid
        self.dec = Decoder(
                dim_x=dim_x, dim_y=dim_y,
                dim_enc=2*dim_hid, dim_hid=dim_hid, depth=dec_depth)

    def predict(self, xc, yc, xt, num_samples=None):
        """Return the predictive distribution at xt given context (xc, yc)."""
        rep = torch.cat([self.enc1(xc, yc), self.enc2(xc, yc)], -1)
        # tile the global context representation over every target point
        rep = torch.stack([rep] * xt.shape[-2], -2)
        return self.dec(rep, xt)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Training: negative mean log-likelihood loss.
        Eval: per-context / per-target log-likelihoods (optionally reduced)."""
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.x)
        ll = py.log_prob(batch.y).sum(-1)

        if self.training:
            outs.loss = -ll.mean()
            return outs

        # evaluation: batch.x holds [context; target] — split the likelihood
        num_ctx = batch.xc.shape[-2]
        ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
        if reduce_ll:
            ctx_ll, tar_ll = ctx_ll.mean(), tar_ll.mean()
        outs.ctx_ll = ctx_ll
        outs.tar_ll = tar_ll
        return outs
class CANP(nn.Module):
    """Conditional Attentive Neural Process (fully deterministic).

    Combines a cross-attentive per-target encoding with a self-attentive
    pooled encoding of the context, then decodes to a Gaussian predictive.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):
        super().__init__()

        # per-target representation via cross-attention over the context
        self.enc1 = CrossAttnEncoder(
                dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid,
                v_depth=enc_v_depth, qk_depth=enc_qk_depth)

        # global representation via self-attentive pooling
        self.enc2 = PoolingEncoder(
                dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid,
                self_attn=True,
                pre_depth=enc_pre_depth, post_depth=enc_post_depth)

        self.dec = Decoder(
                dim_x=dim_x, dim_y=dim_y,
                dim_enc=2*dim_hid, dim_hid=dim_hid, depth=dec_depth)

    def predict(self, xc, yc, xt, num_samples=None):
        """Predictive distribution at xt from the two context encodings."""
        local_rep = self.enc1(xc, yc, xt)
        global_rep = self.enc2(xc, yc)
        # broadcast the pooled vector over every target location
        global_rep = torch.stack([global_rep] * xt.shape[-2], -2)
        return self.dec(torch.cat([local_rep, global_rep], -1), xt)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Training: NLL loss. Eval: context/target log-likelihood split."""
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.x)
        ll = py.log_prob(batch.y).sum(-1)

        if self.training:
            outs.loss = -ll.mean()
            return outs

        num_ctx = batch.xc.shape[-2]
        ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
        if reduce_ll:
            ctx_ll, tar_ll = ctx_ll.mean(), tar_ll.mean()
        outs.ctx_ll = ctx_ll
        outs.tar_ll = tar_ll
        return outs
optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /regression/models/bnp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from attrdict import AttrDict 4 | 5 | from models.cnp import CNP 6 | from utils.misc import stack, logmeanexp 7 | from utils.sampling import sample_with_replacement as SWR, sample_subset 8 | 9 | class BNP(CNP): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.dec.add_ctx(2*kwargs['dim_hid']) 13 | 14 | def encode(self, xc, yc, xt, mask=None): 15 | encoded = torch.cat([ 16 | self.enc1(xc, yc, mask=mask), 17 | self.enc2(xc, yc, mask=mask)], -1) 18 | return stack(encoded, xt.shape[-2], -2) 19 | 20 | def predict(self, xc, yc, xt, num_samples=None, return_base=False): 21 | with torch.no_grad(): 22 | bxc, byc = SWR(xc, yc, num_samples=num_samples) 23 | sxc, syc = stack(xc, num_samples), stack(yc, num_samples) 24 | 25 | encoded = self.encode(bxc, byc, sxc) 26 | py_res = self.dec(encoded, sxc) 27 | 28 | mu, sigma = py_res.mean, py_res.scale 29 | res = SWR((syc - mu)/sigma).detach() 30 | res = (res - res.mean(-2, keepdim=True)) 31 | 32 | bxc = sxc 33 | byc = mu + sigma * res 34 | 35 | encoded_base = self.encode(xc, yc, xt) 36 | 37 | sxt = stack(xt, num_samples) 38 | encoded_bs = self.encode(bxc, byc, sxt) 39 | 40 | py = 
class BNP(CNP):
    """Bootstrapping Neural Process built on the CNP encoders/decoder.

    Extends CNP with a bootstrap context path: the decoder receives an
    extra conditioning vector (registered via ``dec.add_ctx``) computed
    from a resampled pseudo-context.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # decoder takes an extra 2*dim_hid conditioning vector (ctx path)
        self.dec.add_ctx(2*kwargs['dim_hid'])

    def encode(self, xc, yc, xt, mask=None):
        # concatenate both pooled encodings, then tile over target points
        encoded = torch.cat([
            self.enc1(xc, yc, mask=mask),
            self.enc2(xc, yc, mask=mask)], -1)
        return stack(encoded, xt.shape[-2], -2)

    def predict(self, xc, yc, xt, num_samples=None, return_base=False):
        # --- bootstrap stage: no gradients flow through the resampling ---
        with torch.no_grad():
            # resample the context with replacement, num_samples times
            bxc, byc = SWR(xc, yc, num_samples=num_samples)
            sxc, syc = stack(xc, num_samples), stack(yc, num_samples)

            # predict the *original* context points from the bootstrapped set
            encoded = self.encode(bxc, byc, sxc)
            py_res = self.dec(encoded, sxc)

            # residual-style bootstrap: resample standardized residuals with
            # replacement and re-center them across the point dimension
            mu, sigma = py_res.mean, py_res.scale
            res = SWR((syc - mu)/sigma).detach()
            res = (res - res.mean(-2, keepdim=True))

            # pseudo-context: original inputs with perturbed outputs
            bxc = sxc
            byc = mu + sigma * res

        # deterministic (base) encoding of the real context at the targets
        encoded_base = self.encode(xc, yc, xt)

        sxt = stack(xt, num_samples)
        # encoding of the bootstrapped pseudo-context, fed as decoder ctx
        encoded_bs = self.encode(bxc, byc, sxt)

        py = self.dec(stack(encoded_base, num_samples),
                sxt, ctx=encoded_bs)

        # during training (or on request) also return the base CNP predictive
        if self.training or return_base:
            py_base = self.dec(encoded_base, xt)
            return py_base, py
        else:
            return py

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()

        def compute_ll(py, y):
            ll = py.log_prob(y).sum(-1)
            # a 3-dim ll carries a leading sample dimension; average it out
            # in log-space when reduction is requested
            if ll.dim() == 3 and reduce_ll:
                ll = logmeanexp(ll)
            return ll

        if self.training:
            py_base, py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)

            # joint objective: base (CNP) likelihood + bootstrapped likelihood
            outs.ll_base = compute_ll(py_base, batch.y).mean()
            outs.ll = compute_ll(py, batch.y).mean()
            outs.loss = -outs.ll_base - outs.ll
        else:
            py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)
            ll = compute_ll(py, batch.y)
            # batch.x is [context; target] — first num_ctx points are context
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs
class BANP(CANP):
    """Bootstrapping Attentive Neural Process built on the CANP encoders.

    Same bootstrap scheme as BNP, but with CANP's cross-attentive
    deterministic encoder; the decoder gets an extra ctx input.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # decoder takes an extra 2*dim_hid conditioning vector (ctx path)
        self.dec.add_ctx(2*kwargs['dim_hid'])

    def encode(self, xc, yc, xt, mask=None):
        # cross-attentive per-target encoding + tiled pooled encoding
        theta1 = self.enc1(xc, yc, xt)
        theta2 = self.enc2(xc, yc)
        encoded = torch.cat([theta1,
            torch.stack([theta2]*xt.shape[-2], -2)], -1)
        return encoded

    def predict(self, xc, yc, xt, num_samples=None, return_base=False):
        # --- bootstrap stage: no gradients flow through the resampling ---
        with torch.no_grad():
            # resample the context with replacement, num_samples times
            bxc, byc = SWR(xc, yc, num_samples=num_samples)
            sxc, syc = stack(xc, num_samples), stack(yc, num_samples)

            # predict the *original* context points from the bootstrapped set
            encoded = self.encode(bxc, byc, sxc)
            py_res = self.dec(encoded, sxc)

            # residual-style bootstrap: resample standardized residuals with
            # replacement and re-center them across the point dimension
            mu, sigma = py_res.mean, py_res.scale
            res = SWR((syc - mu)/sigma).detach()
            res = (res - res.mean(-2, keepdim=True))

            # pseudo-context: original inputs with perturbed outputs
            bxc = sxc
            byc = mu + sigma * res

        # deterministic (base) encoding of the real context at the targets
        encoded_base = self.encode(xc, yc, xt)

        sxt = stack(xt, num_samples)
        # encoding of the bootstrapped pseudo-context, fed as decoder ctx
        encoded_bs = self.encode(bxc, byc, sxt)

        py = self.dec(stack(encoded_base, num_samples),
                sxt, ctx=encoded_bs)

        # during training (or on request) also return the base predictive
        if self.training or return_base:
            py_base = self.dec(encoded_base, xt)
            return py_base, py
        else:
            return py

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()

        def compute_ll(py, y):
            ll = py.log_prob(y).sum(-1)
            # a 3-dim ll carries a leading sample dimension; average it out
            # in log-space when reduction is requested
            if ll.dim() == 3 and reduce_ll:
                ll = logmeanexp(ll)
            return ll

        if self.training:
            py_base, py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)

            # joint objective: base likelihood + bootstrapped likelihood
            outs.ll_base = compute_ll(py_base, batch.y).mean()
            outs.ll = compute_ll(py, batch.y).mean()
            outs.loss = -outs.ll_base - outs.ll
        else:
            py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)
            ll = compute_ll(py, batch.y)
            # batch.x is [context; target] — first num_ctx points are context
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs
class CelebA(object):
    """CelebA 32x32 images preprocessed into train/eval tensors.

    Loads ``train.pt`` or ``eval.pt`` from ``${datasets_path}/celeba``
    (produced by the preprocessing script in this module) and rescales
    pixel values from [0, 255] to [0, 1].
    """

    def __init__(self, train=True):
        # each .pt file stores (images, identity labels)
        self.data, self.targets = torch.load(
                osp.join(datasets_path, 'celeba',
                    'train.pt' if train else 'eval.pt'))
        self.data = self.data.float() / 255.0
        # NOTE(review): the original `if train:` / `else:` branches both
        # re-assigned self.data/self.targets to themselves (dead code);
        # removed with no behavior change.

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # returns (image, identity label) for the given index
        return self.data[index], self.targets[index]
36 | 37 | import os 38 | import os.path as osp 39 | from PIL import Image 40 | from tqdm import tqdm 41 | import numpy as np 42 | import torch 43 | 44 | # load train/val/test split 45 | splitdict = {} 46 | with open(osp.join(datasets_path, 'celeba', 'list_eval_partition.txt'), 'r') as f: 47 | for line in f: 48 | fn, split = line.split() 49 | splitdict[fn] = int(split) 50 | 51 | # load identities 52 | iddict = {} 53 | with open(osp.join(datasets_path, 'celeba', 'identity_CelebA.txt'), 'r') as f: 54 | for line in f: 55 | fn, label = line.split() 56 | iddict[fn] = int(label) 57 | 58 | train_imgs = [] 59 | train_labels = [] 60 | eval_imgs = [] 61 | eval_labels = [] 62 | path = osp.join(datasets_path, 'celeba', 'img_align_celeba') 63 | imgfilenames = os.listdir(path) 64 | for fn in tqdm(imgfilenames): 65 | 66 | img = Image.open(osp.join(path, fn)).resize((32, 32)) 67 | if splitdict[fn] == 2: 68 | eval_imgs.append(torch.LongTensor(np.array(img).transpose(2, 0, 1))) 69 | eval_labels.append(iddict[fn]) 70 | else: 71 | train_imgs.append(torch.LongTensor(np.array(img).transpose(2, 0, 1))) 72 | train_labels.append(iddict[fn]) 73 | 74 | print(f'{len(train_imgs)} train, {len(eval_imgs)} eval') 75 | 76 | train_imgs = torch.stack(train_imgs) 77 | train_labels = torch.LongTensor(train_labels) 78 | torch.save([train_imgs, train_labels], osp.join(datasets_path, 'celeba', 'train.pt')) 79 | 80 | eval_imgs = torch.stack(eval_imgs) 81 | eval_labels = torch.LongTensor(eval_labels) 82 | torch.save([eval_imgs, eval_labels], osp.join(datasets_path, 'celeba', 'eval.pt')) 83 | -------------------------------------------------------------------------------- /regression/data/image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from attrdict import AttrDict 3 | from torch.utils.data import DataLoader 4 | from torch.distributions import StudentT, Normal 5 | 6 | def img_to_task(img, num_ctx=None, 7 | max_num_points=None, target_all=False, 
def coord_to_img(x, y, shape):
    """Scatter observed pixel values y at normalized coordinates x onto a
    fixed-color background canvas of shape (B, 3, H, W)."""
    x, y = x.cpu(), y.cpu()
    num_imgs = x.shape[0]
    C, H, W = shape

    # background color (one constant per RGB channel)
    canvas = torch.zeros(num_imgs, 3, H, W)
    for ch, val in enumerate((0.61, 0.55, 0.71)):
        canvas[:, ch, :, :] = val

    def to_pixel(coord, size):
        # map [-1, 1] coordinates back to integer pixel indices
        return ((coord + 1) * (size - 1) / 2).round().long()

    rows = to_pixel(x[..., 0], H)
    cols = to_pixel(x[..., 1], W)

    # paint observed values; grayscale inputs (C == 1) reuse channel 0
    for b in range(num_imgs):
        for ch in range(3):
            canvas[b, ch, rows[b], cols[b]] = y[b, :, min(ch, C - 1)]

    return canvas
class NP(nn.Module):
    """Neural Process: deterministic pooling path + latent path.

    Trains with the ELBO (single sample / no sample) or an importance-
    weighted bound when multiple latent samples are drawn.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):
        super().__init__()

        # deterministic path encoder
        self.denc = PoolingEncoder(
                dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid,
                pre_depth=enc_pre_depth, post_depth=enc_post_depth)

        # latent path encoder producing a Normal over z
        self.lenc = PoolingEncoder(
                dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid,
                dim_lat=dim_lat,
                pre_depth=enc_pre_depth, post_depth=enc_post_depth)

        self.dec = Decoder(
                dim_x=dim_x, dim_y=dim_y,
                dim_enc=dim_hid+dim_lat, dim_hid=dim_hid, depth=dec_depth)

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Predictive distribution at xt; samples z from the prior
        p(z|xc,yc) unless an explicit z is supplied."""
        theta = stack(self.denc(xc, yc), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        encoded = torch.cat([theta, z], -1)
        # tile the (deterministic, latent) code over the target points
        encoded = stack(encoded, xt.shape[-2], -2)
        return self.dec(encoded, stack(xt, num_samples))

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()
        if self.training:
            pz = self.lenc(batch.xc, batch.yc)
            qz = self.lenc(batch.x, batch.y)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc, batch.yc, batch.x,
                    z=z, num_samples=num_samples)

            # FIX: the original `if num_samples > 1:` raised a TypeError
            # when num_samples is None (the default); guard explicitly.
            if num_samples is not None and num_samples > 1:
                # importance-weighted bound over K samples
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B
                log_w = recon.sum(-1) + log_pz - log_qz

                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                # standard ELBO: reconstruction minus per-point KL
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y]*num_samples)
                if reduce_ll:
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            # batch.x is [context; target] — first num_ctx points are context
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]
        return outs
from models.modules import PoolingEncoder, Decoder 9 | 10 | class NP(nn.Module): 11 | def __init__(self, 12 | dim_x=1, 13 | dim_y=1, 14 | dim_hid=128, 15 | dim_lat=128, 16 | enc_pre_depth=4, 17 | enc_post_depth=2, 18 | dec_depth=3): 19 | 20 | super().__init__() 21 | 22 | self.denc = PoolingEncoder( 23 | dim_x=dim_x, 24 | dim_y=dim_y, 25 | dim_hid=dim_hid, 26 | pre_depth=enc_pre_depth, 27 | post_depth=enc_post_depth) 28 | 29 | self.lenc = PoolingEncoder( 30 | dim_x=dim_x, 31 | dim_y=dim_y, 32 | dim_hid=dim_hid, 33 | dim_lat=dim_lat, 34 | pre_depth=enc_pre_depth, 35 | post_depth=enc_post_depth) 36 | 37 | self.dec = Decoder( 38 | dim_x=dim_x, 39 | dim_y=dim_y, 40 | dim_enc=dim_hid+dim_lat, 41 | dim_hid=dim_hid, 42 | depth=dec_depth) 43 | 44 | def predict(self, xc, yc, xt, z=None, num_samples=None): 45 | theta = stack(self.denc(xc, yc), num_samples) 46 | if z is None: 47 | pz = self.lenc(xc, yc) 48 | z = pz.rsample() if num_samples is None \ 49 | else pz.rsample([num_samples]) 50 | encoded = torch.cat([theta, z], -1) 51 | encoded = stack(encoded, xt.shape[-2], -2) 52 | return self.dec(encoded, stack(xt, num_samples)) 53 | 54 | def forward(self, batch, num_samples=None, reduce_ll=True): 55 | outs = AttrDict() 56 | if self.training: 57 | pz = self.lenc(batch.xc, batch.yc) 58 | qz = self.lenc(batch.x, batch.y) 59 | z = qz.rsample() if num_samples is None else \ 60 | qz.rsample([num_samples]) 61 | py = self.predict(batch.xc, batch.yc, batch.x, 62 | z=z, num_samples=num_samples) 63 | 64 | if num_samples > 1: 65 | # K * B * N 66 | recon = py.log_prob(stack(batch.y, num_samples)).sum(-1) 67 | # K * B 68 | log_qz = qz.log_prob(z).sum(-1) 69 | log_pz = pz.log_prob(z).sum(-1) 70 | 71 | # K * B 72 | log_w = recon.sum(-1) + log_pz - log_qz 73 | 74 | outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2] 75 | else: 76 | outs.recon = py.log_prob(batch.y).sum(-1).mean() 77 | outs.kld = kl_divergence(qz, pz).sum(-1).mean() 78 | outs.loss = -outs.recon + outs.kld / batch.x.shape[-2] 
class ANP(nn.Module):
    """Attentive Neural Process: cross-attentive deterministic path plus
    a self-attentive latent path, trained with ELBO / IW bound."""

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):
        super().__init__()

        # deterministic path: cross-attention from targets to context
        self.denc = CrossAttnEncoder(
                dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid,
                v_depth=enc_v_depth, qk_depth=enc_qk_depth)

        # latent path: self-attentive pooled encoder producing q(z)/p(z)
        self.lenc = PoolingEncoder(
                dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid,
                dim_lat=dim_lat, self_attn=True,
                pre_depth=enc_pre_depth, post_depth=enc_post_depth)

        self.dec = Decoder(
                dim_x=dim_x, dim_y=dim_y,
                dim_enc=dim_hid+dim_lat, dim_hid=dim_hid, depth=dec_depth)

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Predictive distribution at xt; samples z from the prior
        p(z|xc,yc) unless an explicit z is supplied."""
        theta = stack(self.denc(xc, yc, xt), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        # tile the global latent over the target points, join with theta
        z = stack(z, xt.shape[-2], -2)
        encoded = torch.cat([theta, z], -1)
        return self.dec(encoded, stack(xt, num_samples))

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()
        if self.training:
            pz = self.lenc(batch.xc, batch.yc)
            qz = self.lenc(batch.x, batch.y)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc, batch.yc, batch.x,
                    z=z, num_samples=num_samples)

            # FIX: the original `if num_samples > 1:` raised a TypeError
            # when num_samples is None (the default); guard explicitly.
            if num_samples is not None and num_samples > 1:
                # importance-weighted bound over K samples
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B
                log_w = recon.sum(-1) + log_pz - log_qz

                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                # standard ELBO: reconstruction minus per-point KL
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y]*num_samples)
                if reduce_ll:
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            # batch.x is [context; target] — first num_ctx points are context
            num_ctx = batch.xc.shape[-2]

            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs
class ANP(nn.Module):
    """Attentive Neural Process.

    Combines a deterministic cross-attention path (`denc`) with a global
    latent path (`lenc`); the decoder consumes both to produce a Normal
    predictive distribution over target outputs.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        # Deterministic path: cross-attention from target inputs to context.
        self.denc = CrossAttnEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                v_depth=enc_v_depth,
                qk_depth=enc_qk_depth)

        # Latent path: permutation-invariant pooling encoder -> Normal latent.
        self.lenc = PoolingEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                dim_lat=dim_lat,
                self_attn=True,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        # Decoder sees the concatenation of both paths.
        self.dec = Decoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_enc=dim_hid+dim_lat,
                dim_hid=dim_hid,
                depth=dec_depth)

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Predictive distribution at `xt` given context (`xc`, `yc`).

        If `z` is not supplied, it is sampled from the prior conditioned
        on the context; num_samples=None means a single sample.
        """
        theta = stack(self.denc(xc, yc, xt), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        # Broadcast the global latent across all target points.
        z = stack(z, xt.shape[-2], -2)
        encoded = torch.cat([theta, z], -1)
        return self.dec(encoded, stack(xt, num_samples))

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Training loss (train mode) or log-likelihoods (eval mode).

        Returns an AttrDict with `loss` (plus `recon`/`kld` for the
        single-sample objective) when training, else `ctx_ll`/`tar_ll`.
        """
        outs = AttrDict()
        if self.training:
            pz = self.lenc(batch.xc, batch.yc)
            qz = self.lenc(batch.x, batch.y)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc, batch.yc, batch.x,
                    z=z, num_samples=num_samples)

            # Bug fix: original `if num_samples > 1:` raised TypeError
            # when num_samples was None (None > 1 is invalid in Python 3).
            if num_samples is not None and num_samples > 1:
                # Multi-sample (IWAE-style) objective.
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B
                log_w = recon.sum(-1) + log_pz - log_qz

                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                # Single-sample ELBO with analytic KL.
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]
        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y]*num_samples)
                if reduce_ll:
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            # batch.x is ordered [context; target].
            num_ctx = batch.xc.shape[-2]

            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs
class CrossAttnEncoder(nn.Module):
    """Cross-attention encoder: target inputs attend over context pairs.

    Queries come from target inputs `xt`, keys from context inputs `xc`,
    and values from (xc, yc) pairs (optionally passed through a
    self-attention layer first). If `dim_lat` is given, the output
    parameterizes a Normal distribution; otherwise it is a plain
    deterministic representation.
    """

    def __init__(self, dim_x=1, dim_y=1, dim_hid=128,
            dim_lat=None, self_attn=True,
            v_depth=4, qk_depth=2):
        super().__init__()
        self.use_lat = dim_lat is not None

        if not self_attn:
            self.net_v = build_mlp(dim_x+dim_y, dim_hid, dim_hid, v_depth)
        else:
            # Two MLP layers are replaced by the self-attention block,
            # hence v_depth-2 here.
            self.net_v = build_mlp(dim_x+dim_y, dim_hid, dim_hid, v_depth-2)
            self.self_attn = SelfAttn(dim_hid, dim_hid)

        # Shared embedding for both queries (xt) and keys (xc).
        self.net_qk = build_mlp(dim_x, dim_hid, dim_hid, qk_depth)

        # Output doubles to (mu, sigma) when a latent is requested.
        self.attn = MultiHeadAttn(dim_hid, dim_hid, dim_hid,
                2*dim_lat if self.use_lat else dim_hid)

    def forward(self, xc, yc, xt, mask=None):
        """Return per-target representations (or a Normal if use_lat).

        `mask` (optional) marks valid context points; its exact
        convention is defined by SelfAttn/MultiHeadAttn — TODO confirm.
        """
        q, k = self.net_qk(xt), self.net_qk(xc)
        v = self.net_v(torch.cat([xc, yc], -1))

        # self_attn attribute only exists when self_attn=True at init.
        if hasattr(self, 'self_attn'):
            v = self.self_attn(v, mask=mask)

        out = self.attn(q, k, v, mask=mask)
        if self.use_lat:
            mu, sigma = out.chunk(2, -1)
            # Scale bounded to (0.1, 1.0) for stability.
            sigma = 0.1 + 0.9 * torch.sigmoid(sigma)
            return Normal(mu, sigma)
        else:
            return out
def build_mlp(dim_in, dim_hid, dim_out, depth):
    """Build a fully-connected ReLU network with `depth` linear layers.

    Layout: Linear(dim_in, dim_hid) followed by depth-2 hidden
    Linear(dim_hid, dim_hid) layers (each ReLU-activated) and a final
    Linear(dim_hid, dim_out) with no activation.
    """
    layers = [nn.Linear(dim_in, dim_hid), nn.ReLU(True)]
    for _ in range(depth - 2):
        layers.extend([nn.Linear(dim_hid, dim_hid), nn.ReLU(True)])
    layers.append(nn.Linear(dim_hid, dim_out))
    return nn.Sequential(*layers)
class Decoder(nn.Module):
    """Decode (encoded representation, target inputs) into a Normal
    predictive distribution over outputs.

    An optional global context vector can be injected additively after
    the first projection via `add_ctx`.
    """

    def __init__(self, dim_x=1, dim_y=1,
            dim_enc=128, dim_hid=128, depth=3):
        super().__init__()
        # First projection from concatenated [encoded, x] to hidden width.
        self.fc = nn.Linear(dim_x+dim_enc, dim_hid)
        self.dim_hid = dim_hid

        # Remaining depth-1 layers; the last emits (mu, sigma) jointly.
        tail = [nn.ReLU(True)]
        for _ in range(depth - 2):
            tail.extend([nn.Linear(dim_hid, dim_hid), nn.ReLU(True)])
        tail.append(nn.Linear(dim_hid, 2*dim_y))
        self.mlp = nn.Sequential(*tail)

    def add_ctx(self, dim_ctx):
        """Register an extra (bias-free) linear path for a context vector."""
        self.dim_ctx = dim_ctx
        self.fc_ctx = nn.Linear(dim_ctx, self.dim_hid, bias=False)

    def forward(self, encoded, x, ctx=None):
        """Return Normal(mu, sigma) over outputs at inputs `x`."""
        hid = self.fc(torch.cat([encoded, x], -1))
        if ctx is not None:
            hid = hid + self.fc_ctx(ctx)
        mu, sigma = self.mlp(hid).chunk(2, -1)
        # Scale is softplus-transformed and bounded below by 0.1.
        sigma = 0.1 + 0.9 * F.softplus(sigma)
        return Normal(mu, sigma)
class RBFKernel(object):
    """Squared-exponential (RBF) kernel with a random length-scale and
    output scale drawn per batch element on every call."""

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    # x: batch_size * num_points * dim
    def __call__(self, x):
        """Return a batch_size * num_points * num_points covariance."""
        batch_size = x.shape[0]
        # Per-task hyperparameters, uniform in [0.1, max_*].
        length = 0.1 + (self.max_length - 0.1) \
                * torch.rand([batch_size, 1, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale - 0.1) \
                * torch.rand([batch_size, 1, 1], device=x.device)

        # Pairwise scaled differences: B * N * N * dim.
        diff = (x.unsqueeze(-2) - x.unsqueeze(-3)) / length

        # Squared-exponential covariance plus jitter on the diagonal.
        cov = scale.pow(2) * torch.exp(-0.5 * diff.pow(2).sum(-1))
        cov = cov + self.sigma_eps**2 * torch.eye(x.shape[-2]).to(x.device)
        return cov
class GPPriorSampler(object):
    """Draws function values from a zero-mean GP prior at given inputs."""

    def __init__(self, kernel, t_noise=None):
        self.kernel = kernel
        self.t_noise = t_noise

    def sample(self, bx, device='cuda:0'):
        """Sample GP function values at the locations `bx`.

        Args:
            bx: 1 * num_points * 1 tensor of input locations.
            device: device for the mean/noise tensors; the covariance
                follows the device of the kernel's output.

        Returns:
            by: 1 * num_points * 1 tensor of sampled function values.
        """
        # 1 * num_points * num_points
        cov = self.kernel(bx)
        # Bug fix: the original unconditionally called `mean.cuda()`,
        # which ignored the `device` argument and crashed on CPU-only
        # machines. The mean is created directly on `device` instead.
        mean = torch.zeros(1, bx.shape[1], device=device)

        by = MultivariateNormal(mean, cov).rsample().unsqueeze(-1)

        # Optionally corrupt with heavy-tailed Student-t noise.
        if self.t_noise is not None:
            by += self.t_noise * StudentT(2.1).rsample(by.shape).to(device)

        return by
class RBFKernel(object):
    """Squared-exponential (RBF) kernel; length-scale and output scale
    are re-drawn uniformly per batch element on every call."""

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        # sigma_eps: diagonal jitter; max_length/max_scale: upper bounds
        # for the uniform hyperparameter draws (lower bound is 0.1).
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    # x: batch_size * num_points * dim
    def __call__(self, x):
        """Return a batch_size * num_points * num_points covariance."""
        length = 0.1 + (self.max_length-0.1) \
                * torch.rand([x.shape[0], 1, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale-0.1) \
                * torch.rand([x.shape[0], 1, 1], device=x.device)

        # batch_size * num_points * num_points * dim
        dist = (x.unsqueeze(-2) - x.unsqueeze(-3))/length

        # batch_size * num_points * num_points
        # Jitter on the diagonal keeps the matrix positive definite.
        cov = scale.pow(2) * torch.exp(-0.5 * dist.pow(2).sum(-1)) \
                + self.sigma_eps**2 * torch.eye(x.shape[-2]).to(x.device)

        return cov
class PeriodicKernel(object):
    """Periodic kernel; period, length-scale and output scale are
    re-drawn uniformly per batch element on every call."""

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    # x: batch_size * num_points * dim
    def __call__(self, x):
        """Return a batch_size * num_points * num_points covariance."""
        batch_size, num_points = x.shape[0], x.shape[-2]
        # Per-task hyperparameters: period in [0.1, 0.5], others in
        # [0.1, max_*].
        p = 0.1 + 0.4*torch.rand([batch_size, 1, 1], device=x.device)
        length = 0.1 + (self.max_length - 0.1) \
                * torch.rand([batch_size, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale - 0.1) \
                * torch.rand([batch_size, 1, 1], device=x.device)

        diff = x.unsqueeze(-2) - x.unsqueeze(-3)
        sin_term = torch.sin(math.pi * diff.abs().sum(-1) / p) / length
        # Exp-sine-squared covariance plus diagonal jitter.
        cov = scale.pow(2) * torch.exp(-2 * sin_term.pow(2))
        cov = cov + self.sigma_eps**2 * torch.eye(num_points).to(x.device)
        return cov
    def __init__(self,
            X0=50,
            Y0=100,
            theta0=0.01,
            theta1=0.5,
            theta2=1.0,
            theta3=0.01):
        """Store Lotka-Volterra simulation parameters.

        Args:
            X0, Y0: mean initial population sizes for the two species
                (per-task initial counts are perturbed with unit
                Gaussian noise in the simulator).
            theta0..theta3: reaction-rate constants for the four events
                (predator birth, predator death, prey birth, prey death
                respectively — inferred from the rate expressions in
                `_simulate_task`; confirm against the model reference).
        """
        self.X0 = X0
        self.Y0 = Y0
        self.theta0 = theta0
        self.theta1 = theta1
        self.theta2 = theta2
        self.theta3 = theta3
def load_hare_lynx(num_batches, batch_size):
    """Build random context/target batches from the Hudson Bay
    hare-lynx time series, downloading the data file on first use.

    Args:
        num_batches: number of batches to generate.
        batch_size: tasks per batch (each an independent random
            permutation of the series).

    Returns:
        List of AttrDicts with keys x, y, xc, xt, yc, yt.
    """

    filename = osp.join(datasets_path, 'lotka_volterra', 'LynxHare.txt')
    if not osp.isfile(filename):
        # Bug fix: `datsets_path` was a typo (NameError), so the
        # download branch could never have succeeded.
        wget.download('http://people.whitman.edu/~hundledr/courses/M250F03/LynxHare.txt',
                out=osp.join(datasets_path, 'lotka_volterra'))

    tb = np.loadtxt(filename)
    times = torch.Tensor(tb[:,0]).unsqueeze(-1)
    # Stacks columns as [col 2, col 1]; presumably [lynx, hare] —
    # TODO confirm column order against the data file.
    pops = torch.stack([torch.Tensor(tb[:,2]), torch.Tensor(tb[:,1])], -1)

    batches = []
    N = pops.shape[-2]
    for _ in range(num_batches):
        batch = AttrDict()

        # Random context size with at least 15 points on each side.
        num_ctx = torch.randint(low=15, high=N-15, size=[1]).item()

        # Independent random permutation of time points per task.
        idxs = torch.rand(batch_size, N).argsort(-1)

        batch.x = torch.gather(
                torch.stack([times]*batch_size),
                -2, idxs.unsqueeze(-1))
        batch.y = torch.gather(torch.stack([pops]*batch_size),
                -2, torch.stack([idxs]*2, -1))
        batch.xc = batch.x[:,:num_ctx]
        batch.xt = batch.x[:,num_ctx:]
        batch.yc = batch.y[:,:num_ctx]
        batch.yt = batch.y[:,num_ctx:]

        batches.append(batch)

    return batches
def main():
    """Entry point for CelebA experiments: parse CLI args, build the
    requested model from its YAML config, and dispatch by --mode."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--mode',
            choices=['train', 'eval', 'plot', 'ensemble'],
            default='train')
    parser.add_argument('--expid', type=str, default='trial')
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--gpu', type=str, default='0')

    parser.add_argument('--max_num_points', type=int, default=200)

    parser.add_argument('--model', type=str, default='cnp')
    parser.add_argument('--train_batch_size', type=int, default=100)
    parser.add_argument('--train_num_samples', type=int, default=4)

    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--eval_freq', type=int, default=10)
    parser.add_argument('--save_freq', type=int, default=10)

    parser.add_argument('--eval_seed', type=int, default=42)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--eval_num_samples', type=int, default=50)
    parser.add_argument('--eval_logfile', type=str, default=None)

    parser.add_argument('--plot_seed', type=int, default=None)
    parser.add_argument('--plot_batch_size', type=int, default=16)
    parser.add_argument('--plot_num_samples', type=int, default=30)
    parser.add_argument('--plot_num_ctx', type=int, default=100)

    # OOD settings
    parser.add_argument('--t_noise', type=float, default=None)

    args = parser.parse_args()
    # GPU selection via environment variable; all .cuda() calls below
    # then target the chosen device.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Dynamically load models/<name>.py and instantiate class <NAME>
    # with the hyperparameters from configs/celeba/<name>.yaml.
    model_cls = getattr(load_module(f'models/{args.model}.py'), args.model.upper())
    with open(f'configs/celeba/{args.model}.yaml', 'r') as f:
        config = yaml.safe_load(f)
    model = model_cls(**config).cuda()

    args.root = osp.join(results_path, 'celeba', args.model, args.expid)

    if args.mode == 'train':
        train(args, model)
    elif args.mode == 'eval':
        eval(args, model)
    elif args.mode == 'plot':
        # NOTE(review): no `plot` function is visible in this file —
        # `--mode plot` would raise NameError unless it is defined
        # elsewhere; confirm.
        plot(args, model)
    elif args.mode == 'ensemble':
        ensemble(args, model)
def gen_evalset(args):
    """Generate and cache a fixed CelebA evaluation set.

    Seeds the RNG with args.eval_seed so the pixel-subset tasks are
    reproducible, converts every test image batch into a context/target
    task, then saves the list of batches under evalsets_path.
    """

    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    eval_ds = CelebA(train=False)
    eval_loader = torch.utils.data.DataLoader(eval_ds,
            batch_size=args.eval_batch_size,
            shuffle=False, num_workers=4)

    batches = []
    for x, _ in tqdm(eval_loader):
        batches.append(img_to_task(x,
            t_noise=args.t_noise,
            max_num_points=args.max_num_points))

    # Re-randomize global RNG state after the seeded block.
    # NOTE(review): time.time() is a float — recent PyTorch versions
    # require an int seed; confirm against the pinned torch version.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    path = osp.join(evalsets_path, 'celeba')
    if not osp.isdir(path):
        os.makedirs(path)

    # Filename encodes the Student-t noise level (or its absence).
    filename = 'no_noise.tar' if args.t_noise is None else \
            f'{args.t_noise}.tar'
    torch.save(batches, osp.join(path, filename))
def ensemble(args, model):
    """Evaluate an ensemble of 5 independently trained runs on CelebA.

    Loads checkpoints run1..run5, pools their per-sample log-likelihoods
    with logmeanexp, and logs averaged context/target LLs.
    """
    num_runs = 5
    models = []
    for i in range(num_runs):
        # Each ensemble member is a copy of `model` with run-specific
        # weights loaded from its checkpoint.
        model_ = deepcopy(model)
        ckpt = torch.load(osp.join(results_path, 'celeba', args.model, f'run{i+1}', 'ckpt.tar'))
        model_.load_state_dict(ckpt['model'])
        model_.cuda()
        model_.eval()
        models.append(model_)

    path = osp.join(evalsets_path, 'celeba')
    if not osp.isdir(path):
        os.makedirs(path)
    filename = 'no_noise.tar' if args.t_noise is None else \
            f'{args.t_noise}.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset(args)

    eval_batches = torch.load(osp.join(path, filename))

    ravg = RunningAverage()
    with torch.no_grad():
        for batch in tqdm(eval_batches):
            for key, val in batch.items():
                batch[key] = val.cuda()

            # Collect unreduced log-likelihoods from every member.
            ctx_ll = []
            tar_ll = []
            for model in models:
                outs = model(batch,
                        num_samples=args.eval_num_samples,
                        reduce_ll=False)
                ctx_ll.append(outs.ctx_ll)
                tar_ll.append(outs.tar_ll)

            # 2-D (B * N) outputs come from models without latent
            # sampling: stack to add a member axis. Otherwise the
            # per-sample axis already exists, so concatenate along it.
            if ctx_ll[0].dim() == 2:
                ctx_ll = torch.stack(ctx_ll)
                tar_ll = torch.stack(tar_ll)
            else:
                ctx_ll = torch.cat(ctx_ll)
                tar_ll = torch.cat(tar_ll)

            # Mixture-of-members likelihood via log-mean-exp pooling.
            ctx_ll = logmeanexp(ctx_ll).mean()
            tar_ll = logmeanexp(tar_ll).mean()

            ravg.update('ctx_ll', ctx_ll)
            ravg.update('tar_ll', tar_ll)

    # NOTE(review): time.time() is a float — recent PyTorch requires an
    # int seed; confirm against the pinned torch version.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    filename = f'ensemble'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.log'
    logger = get_logger(osp.join(results_path, 'celeba', args.model, filename), mode='w')
    logger.info(ravg.info())
def main():
    """Entry point for the EMNIST regression experiments.

    Parses command-line arguments, instantiates the model class named by
    ``--model`` (looked up as ``models/<name>.py`` -> class ``<NAME>``)
    from its YAML config, and dispatches to the requested mode.
    """
    parser = argparse.ArgumentParser()

    # Experiment control.
    parser.add_argument('--mode',
            choices=['train', 'eval', 'plot', 'ensemble'],
            default='train')
    parser.add_argument('--expid', type=str, default='trial')
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--gpu', type=str, default='0')

    # Task construction: pixels are subsampled into context/target sets.
    parser.add_argument('--max_num_points', type=int, default=200)
    parser.add_argument('--class_range', type=int, nargs='*', default=[0,10])

    # Training hyperparameters.
    parser.add_argument('--model', type=str, default='cnp')
    parser.add_argument('--train_batch_size', type=int, default=100)
    parser.add_argument('--train_num_samples', type=int, default=4)

    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--eval_freq', type=int, default=10)
    parser.add_argument('--save_freq', type=int, default=10)

    # Evaluation settings (eval set is cached under a fixed seed).
    parser.add_argument('--eval_seed', type=int, default=42)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--eval_num_samples', type=int, default=50)
    parser.add_argument('--eval_logfile', type=str, default=None)

    # Plotting settings.
    parser.add_argument('--plot_seed', type=int, default=None)
    parser.add_argument('--plot_batch_size', type=int, default=16)
    parser.add_argument('--plot_num_samples', type=int, default=30)
    parser.add_argument('--plot_num_ctx', type=int, default=100)

    # OOD settings
    parser.add_argument('--t_noise', type=float, default=None)

    args = parser.parse_args()
    # Restrict CUDA to the requested device; all .cuda() calls below land there.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # models/<name>.py must define a class whose name is <name> upper-cased.
    model_cls = getattr(load_module(f'models/{args.model}.py'), args.model.upper())
    with open(f'configs/emnist/{args.model}.yaml', 'r') as f:
        config = yaml.safe_load(f)
    model = model_cls(**config).cuda()

    args.root = osp.join(results_path, 'emnist', args.model, args.expid)

    if args.mode == 'train':
        train(args, model)
    elif args.mode == 'eval':
        eval(args, model)
    elif args.mode == 'plot':
        # NOTE(review): no plot() is defined in this file as far as visible;
        # selecting --mode plot would raise NameError — confirm.
        plot(args, model)
    elif args.mode == 'ensemble':
        ensemble(args, model)
def gen_evalset(args):
    """Generate and cache a fixed EMNIST evaluation set.

    Draws test images for the requested class range, converts each batch
    into a context/target regression task, and saves the list of tasks
    under ``evalsets_path/emnist`` so every evaluation run sees the same
    data.
    """
    # Deterministic section: fix the RNG so the cached tasks are reproducible.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    loader = torch.utils.data.DataLoader(
            EMNIST(train=False, class_range=args.class_range),
            batch_size=args.eval_batch_size,
            shuffle=False, num_workers=4)

    batches = [
            img_to_task(x,
                t_noise=args.t_noise,
                max_num_points=args.max_num_points)
            for x, _ in tqdm(loader)]

    # Re-randomize the global RNG state now that the fixed-seed work is done.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    path = osp.join(evalsets_path, 'emnist')
    if not osp.isdir(path):
        os.makedirs(path)

    # Filename encodes the class range and, when set, the OOD noise level.
    c1, c2 = args.class_range
    suffix = '' if args.t_noise is None else f'_{args.t_noise}'
    filename = f'{c1}-{c2}{suffix}.tar'

    torch.save(batches, osp.join(path, filename))
def eval(args, model):
    """Evaluate ``model`` on the cached EMNIST evaluation set.

    When called as a standalone mode (``args.mode == 'eval'``) this loads
    the checkpoint and writes results to a log file; when called from
    train() it only returns the summary line. Generates the evaluation
    set on first use.

    Returns the formatted result line (model id, class range, metrics).
    """
    if args.mode == 'eval':
        ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        if args.eval_logfile is None:
            # Default log name mirrors the eval-set name below.
            c1, c2 = args.class_range
            eval_logfile = f'eval_{c1}-{c2}'
            if args.t_noise is not None:
                eval_logfile += f'_{args.t_noise}'
            eval_logfile += '.log'
        else:
            eval_logfile = args.eval_logfile
        filename = osp.join(args.root, eval_logfile)
        logger = get_logger(filename, mode='w')
    else:
        logger = None

    # Locate (or lazily create) the cached evaluation set.
    path = osp.join(evalsets_path, 'emnist')
    c1, c2 = args.class_range
    filename = f'{c1}-{c2}'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset(args)

    eval_batches = torch.load(osp.join(path, filename))

    # Fixed seed so stochastic models (latent samples) evaluate reproducibly.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    ravg = RunningAverage()
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_batches):
            for key, val in batch.items():
                batch[key] = val.cuda()
            outs = model(batch, num_samples=args.eval_num_samples)
            for key, val in outs.items():
                ravg.update(key, val)

    # Undo the fixed seed for any code running after evaluation.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    c1, c2 = args.class_range
    line = f'{args.model}:{args.expid} {c1}-{c2} '
    if args.t_noise is not None:
        line += f'tn {args.t_noise} '
    line += ravg.info()

    if logger is not None:
        logger.info(line)

    return line
def get_str_file(path_, str_kernel, str_model, noise, seed=None):
    """Compose the result-file path for one BO experiment.

    The basename is ``bo_<kernel>[_noisy]_<model>[_<seed>].npy``:
    ``noisy`` is inserted whenever ``noise`` is not None, and the seed is
    appended only when given.
    """
    parts = ['bo', str_kernel]
    if noise is not None:
        parts.append('noisy')
    parts.append(str_model)
    if seed is not None:
        parts.append(str(seed))
    return osp.join(path_, '_'.join(parts) + '.npy')

def main():
    """Parse CLI arguments, build the model, and run 'oracle' or 'bo'."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--mode',
            choices=['oracle', 'bo'],
            default='bo')
    parser.add_argument('--expid', type=str, default='run1')
    parser.add_argument('--gpu', type=str, default='0')

    parser.add_argument('--model', type=str, default='cnp')

    # BO settings: posterior samples per acquisition, #initial context
    # points, objective kernel, and optional OOD observation noise.
    parser.add_argument('--bo_num_samples', type=int, default=200)
    parser.add_argument('--bo_num_init', type=int, default=1)
    parser.add_argument('--bo_kernel', type=str, default='periodic')
    parser.add_argument('--t_noise', type=float, default=None)

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # models/<name>.py must define a class named <NAME> (upper-cased).
    model_cls = getattr(load_module(f'models/{args.model}.py'), args.model.upper())
    with open(f'configs/gp/{args.model}.yaml', 'r') as f:
        config = yaml.safe_load(f)
    model = model_cls(**config).cuda()

    args.root = osp.join(results_path, 'gp', args.model, args.expid)

    # Dispatch table instead of an if/elif chain; choices above guarantee
    # the key exists.
    {'oracle': oracle, 'bo': bo}[args.mode](args, model)
os.path.exists(get_str_file('./results', args.bo_kernel, 'oracle', args.t_noise, seed=ind_seed)): 90 | dict_exp = np.load(get_str_file('./results', args.bo_kernel, 'oracle', args.t_noise, seed=ind_seed), allow_pickle=True) 91 | dict_exp = dict_exp[()] 92 | list_dict.append(dict_exp) 93 | 94 | print(dict_exp) 95 | print(dict_exp['global']) 96 | print(np.array2string(dict_exp['minima'], separator=',')) 97 | print(np.array2string(dict_exp['regrets'], separator=',')) 98 | 99 | continue 100 | 101 | sampler = GPPriorSampler(kernel, t_noise=args.t_noise) 102 | 103 | xp = torch.linspace(-2, 2, 1000).cuda() 104 | xp_ = xp.unsqueeze(0).unsqueeze(2) 105 | 106 | yp = sampler.sample(xp_) 107 | min_yp = yp.min() 108 | print(min_yp.cpu().numpy()) 109 | 110 | model.eval() 111 | 112 | batch = AttrDict() 113 | indices_permuted = torch.randperm(yp.shape[1]) 114 | 115 | batch.x = xp_[:, indices_permuted[:2*num_init], :] 116 | batch.y = yp[:, indices_permuted[:2*num_init], :] 117 | 118 | batch.xc = xp_[:, indices_permuted[:num_init], :] 119 | batch.yc = yp[:, indices_permuted[:num_init], :] 120 | 121 | batch.xt = xp_[:, indices_permuted[num_init:2*num_init], :] 122 | batch.yt = yp[:, indices_permuted[num_init:2*num_init], :] 123 | 124 | X_train = batch.xc.squeeze(0).cpu().numpy() 125 | Y_train = batch.yc.squeeze(0).cpu().numpy() 126 | X_test = xp_.squeeze(0).cpu().numpy() 127 | 128 | list_min = [] 129 | list_min.append(batch.yc.min().cpu().numpy()) 130 | 131 | for ind_iter in range(0, num_iter): 132 | print('ind_seed {} seed {} iter {}'.format(ind_seed, seed_, ind_iter + 1)) 133 | 134 | cov_X_X, inv_cov_X_X, hyps = bayesogp.get_optimized_kernel(X_train, Y_train, None, str_cov, is_fixed_noise=False, debug=False) 135 | 136 | prior_mu_train = bayesogp.get_prior_mu(None, X_train) 137 | prior_mu_test = bayesogp.get_prior_mu(None, X_test) 138 | cov_X_Xs = covariance.cov_main(str_cov, X_train, X_test, hyps, False) 139 | cov_Xs_Xs = covariance.cov_main(str_cov, X_test, X_test, hyps, True) 140 
def bo(args, model):
    """Run Bayesian optimization using a trained neural process as surrogate.

    For each of 100 seeded GP-prior objectives on [-2, 2], starts from
    ``bo_num_init`` context points and runs 50 iterations of
    Expected-Improvement minimization, querying the model's predictive
    distribution instead of a GP posterior. Saves the per-seed regret
    curves to one .npy file.
    """
    if args.mode == 'bo':
        ckpt = torch.load(os.path.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)

    # Kernel of the *objective* sampler (the surrogate is the model).
    if args.bo_kernel == 'rbf':
        kernel = RBFKernel()
    elif args.bo_kernel == 'matern':
        kernel = Matern52Kernel()
    elif args.bo_kernel == 'periodic':
        kernel = PeriodicKernel()
    else:
        raise ValueError(f'Invalid kernel {args.bo_kernel}')

    seed = 42
    str_cov = 'se'
    num_all = 100
    num_iter = 50
    num_init = args.bo_num_init

    list_dict = []

    for ind_seed in range(1, num_all + 1):
        # One deterministic objective per run index.
        seed_ = seed * ind_seed

        if seed_ is not None:
            torch.manual_seed(seed_)
            torch.cuda.manual_seed(seed_)

        obj_prior = GPPriorSampler(kernel, t_noise=args.t_noise)

        # Dense grid of 1000 candidate inputs; xp_ has shape (1, 1000, 1).
        xp = torch.linspace(-2, 2, 1000).cuda()
        xp_ = xp.unsqueeze(0).unsqueeze(2)

        yp = obj_prior.sample(xp_)
        min_yp = yp.min()  # global minimum on the grid, for regret
        print(min_yp.cpu().numpy())

        model.eval()

        batch = AttrDict()

        indices_permuted = torch.randperm(yp.shape[1])

        batch.x = xp_[:, indices_permuted[:2*num_init], :]
        batch.y = yp[:, indices_permuted[:2*num_init], :]

        batch.xc = xp_[:, indices_permuted[:num_init], :]
        batch.yc = yp[:, indices_permuted[:num_init], :]

        batch.xt = xp_[:, indices_permuted[num_init:2*num_init], :]
        batch.yt = yp[:, indices_permuted[num_init:2*num_init], :]

        X_train = batch.xc.squeeze(0).cpu().numpy()
        Y_train = batch.yc.squeeze(0).cpu().numpy()
        X_test = xp_.squeeze(0).cpu().numpy()

        # Best objective value found so far, tracked per iteration.
        list_min = []
        list_min.append(batch.yc.min().cpu().numpy())

        for ind_iter in range(0, num_iter):
            print('ind_seed {} seed {} iter {}'.format(ind_seed, seed_, ind_iter + 1))

            with torch.no_grad():
                outs = model(batch, num_samples=args.bo_num_samples)
                print('ctx_ll {:.4f} tar ll {:.4f}'.format(
                    outs.ctx_ll.item(), outs.tar_ll.item()))
                py = model.predict(batch.xc, batch.yc,
                        xp[None,:,None].repeat(1, 1, 1),
                        num_samples=args.bo_num_samples)
                mu, sigma = py.mean.squeeze(0), py.scale.squeeze(0)

            if mu.dim() == 4:
                # Latent-variable model: collapse the sample dimension into
                # a single Gaussian via the mixture moment identity
                # Var = E[s^2] + E[mu^2] - (E[mu])^2.
                print(mu.shape, sigma.shape)
                var = sigma.pow(2).mean(0) + mu.pow(2).mean(0) - mu.mean(0).pow(2)
                sigma = var.sqrt().squeeze(0)
                mu = mu.mean(0).squeeze(0)
                mu_ = mu.cpu().numpy()
                sigma_ = sigma.cpu().numpy()

                # NOTE(review): minimization convention — argmin of -EI below;
                # confirm sign against bayeso's acquisition.ei docs.
                acq_vals = -1.0 * acquisition.ei(np.ravel(mu_), np.ravel(sigma_), Y_train)
            else:
                mu_ = mu.cpu().numpy()
                sigma_ = sigma.cpu().numpy()

                acq_vals = -1.0 * acquisition.ei(np.ravel(mu_), np.ravel(sigma_), Y_train)

            # Query the grid point with the best acquisition value.
            ind_ = np.argmin(acq_vals)

            x_new = xp[ind_, None, None, None]
            y_new = yp[:, ind_, None, :]

            batch.x = torch.cat([batch.x, x_new], axis=1)
            batch.y = torch.cat([batch.y, y_new], axis=1)

            batch.xc = torch.cat([batch.xc, x_new], axis=1)
            batch.yc = torch.cat([batch.yc, y_new], axis=1)

            X_train = batch.xc.squeeze(0).cpu().numpy()
            Y_train = batch.yc.squeeze(0).cpu().numpy()

            min_cur = batch.yc.min()
            list_min.append(min_cur.cpu().numpy())

        print(min_yp.cpu().numpy())
        print(np.array2string(np.array(list_min), separator=','))
        print(np.array2string(np.array(list_min) - min_yp.cpu().numpy(), separator=','))

        dict_exp = {
            'seed': seed_,
            'global': min_yp.cpu().numpy(),
            'minima': np.array(list_min),
            'regrets': np.array(list_min) - min_yp.cpu().numpy(),
            'xc': X_train,
            'yc': Y_train,
            'model': args.model,
            'cov': str_cov,
        }

        list_dict.append(dict_exp)

    np.save(get_str_file('./figures/results', args.bo_kernel, args.model,
        args.t_noise), list_dict)
def standardize(batch):
    """Standardize a task batch in place using context-set statistics.

    Both x and y (and their context/target splits) are shifted and scaled
    by the mean/std of the *context* sets ``batch.xc`` / ``batch.yc``,
    computed over the points axis (dim -2). Returns the mutated batch.

    Fix: the zero-std guard was only applied to the x statistics; a
    constant y-channel (std == 0) was previously divided by 1e-5, blowing
    values up by 1e5. The same guard is now applied to both.
    """
    def _ctx_stats(ctx):
        # Mean/std over the points axis; zero stds (constant features) are
        # replaced by 1 so the feature is centered but not blown up.
        mu = ctx.mean(-2, keepdim=True)
        sigma = ctx.std(-2, keepdim=True)
        sigma[sigma == 0] = 1.0
        return mu, sigma

    with torch.no_grad():
        mu, sigma = _ctx_stats(batch.xc)
        batch.x = (batch.x - mu) / (sigma + 1e-5)
        batch.xc = (batch.xc - mu) / (sigma + 1e-5)
        batch.xt = (batch.xt - mu) / (sigma + 1e-5)

        mu, sigma = _ctx_stats(batch.yc)
        batch.y = (batch.y - mu) / (sigma + 1e-5)
        batch.yc = (batch.yc - mu) / (sigma + 1e-5)
        batch.yt = (batch.yt - mu) / (sigma + 1e-5)
    return batch
def train(args, model):
    """Train a neural-process model on the pre-generated Lotka-Volterra data.

    Steps through the saved training batches once (num_steps = number of
    batches), with cosine LR annealing, periodic logging/evaluation, and
    resumable checkpointing; finishes with a full evaluation pass.
    """
    if not osp.isdir(args.root):
        os.makedirs(args.root)

    # Record the exact arguments used for this run.
    with open(osp.join(args.root, 'args.yaml'), 'w') as f:
        yaml.dump(args.__dict__, f)

    train_data = torch.load(osp.join(datasets_path, 'lotka_volterra', 'train.tar'))
    eval_data = torch.load(osp.join(datasets_path, 'lotka_volterra', 'eval.tar'))
    num_steps = len(train_data)  # one optimizer step per saved batch
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_steps)

    if args.resume:
        ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        optimizer.load_state_dict(ckpt.optimizer)
        scheduler.load_state_dict(ckpt.scheduler)
        logfilename = ckpt.logfilename
        start_step = ckpt.step
    else:
        logfilename = osp.join(args.root,
                f'train_{time.strftime("%Y%m%d-%H%M")}.log')
        start_step = 1

    logger = get_logger(logfilename)
    ravg = RunningAverage()

    if not args.resume:
        logger.info('Total number of parameters: {}\n'.format(
            sum(p.numel() for p in model.parameters())))

    for step in range(start_step, num_steps+1):
        model.train()
        optimizer.zero_grad()

        # Normalize with context statistics before moving to GPU.
        batch = standardize(train_data[step-1])
        for key, val in batch.items():
            batch[key] = val.cuda()

        outs = model(batch, num_samples=args.train_num_samples)
        outs.loss.backward()
        optimizer.step()
        scheduler.step()

        for key, val in outs.items():
            ravg.update(key, val)

        if step % args.print_freq == 0:
            line = f'{args.model}:{args.expid} step {step} '
            line += f'lr {optimizer.param_groups[0]["lr"]:.3e} '
            line += ravg.info()
            logger.info(line)

            # NOTE(review): the dump loses indentation; eval and the
            # running-average reset are nested under print_freq here
            # (eval_freq is a multiple of print_freq) — confirm against
            # the sibling gp.py script.
            if step % args.eval_freq == 0:
                line = eval(args, model, eval_data=eval_data)
                logger.info(line + '\n')

            ravg.reset()

        if step % args.save_freq == 0 or step == num_steps:
            ckpt = AttrDict()
            ckpt.model = model.state_dict()
            ckpt.optimizer = optimizer.state_dict()
            ckpt.scheduler = scheduler.state_dict()
            ckpt.logfilename = logfilename
            ckpt.step = step + 1  # resume continues from the next step
            torch.save(ckpt, osp.join(args.root, 'ckpt.tar'))

    # Final full evaluation with checkpoint loading semantics.
    args.mode = 'eval'
    eval(args, model, eval_data=eval_data)
def eval(args, model, eval_data=None):
    """Evaluate on the Lotka-Volterra eval set (or real hare-lynx data).

    When run as a standalone mode, loads the checkpoint and writes a log
    file; when called from train(), reuses the provided ``eval_data`` and
    only returns the summary line.
    """
    if args.mode == 'eval':
        ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        if args.eval_logfile is None:
            if args.hare_lynx:
                eval_logfile = 'hare_lynx.log'
            else:
                eval_logfile = 'eval.log'
        else:
            eval_logfile = args.eval_logfile
        filename = osp.join(args.root, eval_logfile)
        logger = get_logger(filename, mode='w')
    else:
        logger = None

    # Fixed seed so stochastic models evaluate reproducibly.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    if eval_data is None:
        if args.hare_lynx:
            # Real-world OOD data: the classic hare/lynx population series.
            eval_data = load_hare_lynx(1000, 16)
        else:
            eval_data = torch.load(osp.join(datasets_path, 'lotka_volterra', 'eval.tar'))

    ravg = RunningAverage()
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_data):
            batch = standardize(batch)
            for key, val in batch.items():
                batch[key] = val.cuda()
            outs = model(batch, num_samples=args.eval_num_samples)
            for key, val in outs.items():
                ravg.update(key, val)

    # Undo the fixed seed for any code running afterwards.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    line = f'{args.model}:{args.expid} '
    line += ravg.info()

    if logger is not None:
        logger.info(line)

    return line
def plot(args, model):
    """Plot predictive distributions for one random Lotka-Volterra batch.

    Draws a 4x4 grid: predictive mean +/- one std for both output
    channels (predator/prey) over a dense input grid, with context points
    as black stars and target points as orchid crosses.
    """
    ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
    model.load_state_dict(ckpt.model)

    def tnp(x):
        # tensor -> flat numpy, for matplotlib
        return x.squeeze().cpu().data.numpy()

    if args.hare_lynx:
        eval_data = load_hare_lynx(1000, 16)
    else:
        eval_data = torch.load(osp.join(datasets_path, 'lotka_volterra', 'eval.tar'))
    bid = torch.randint(len(eval_data), [1]).item()  # random batch to plot
    batch = standardize(eval_data[bid])

    for k, v in batch.items():
        batch[k] = v.cuda()

    model.eval()
    outs = model(batch, num_samples=args.eval_num_samples)
    print(outs.tar_ll)

    fig, axes = plt.subplots(4, 4, figsize=(20, 20))
    # Dense per-task input grid covering each task's x range with margin.
    xp = []
    for b in range(batch.x.shape[0]):
        bx = batch.x[b]
        xp.append(torch.linspace(bx.min()-0.1, bx.max()+0.1, 200))
    xp = torch.stack(xp).unsqueeze(-1).cuda()

    model.eval()
    with torch.no_grad():
        py = model.predict(batch.xc, batch.yc, xp, num_samples=args.plot_num_samples)
        mu, sigma = py.mean, py.scale

    if mu.dim() > 3:
        # Latent-variable model: collapse the sample dimension using the
        # mixture moment identity Var = E[s^2] + E[mu^2] - (E[mu])^2.
        bmu = mu.mean(0)
        bvar = sigma.pow(2).mean(0) + mu.pow(2).mean(0) - mu.mean(0).pow(2)
        bsigma = bvar.sqrt()
    else:
        bmu = mu
        bsigma = sigma

    for i, ax in enumerate(axes.flatten()):
        ax.plot(tnp(xp[i]), tnp(bmu[i]), alpha=0.5)
        # Channel 0: predator population.
        upper = tnp(bmu[i][:,0] + bsigma[i][:,0])
        lower = tnp(bmu[i][:,0] - bsigma[i][:,0])
        ax.fill_between(tnp(xp[i]), lower, upper,
                alpha=0.2, linewidth=0.0, label='predator')

        # Channel 1: prey population.
        upper = tnp(bmu[i][:,1] + bsigma[i][:,1])
        lower = tnp(bmu[i][:,1] - bsigma[i][:,1])
        ax.fill_between(tnp(xp[i]), lower, upper,
                alpha=0.2, linewidth=0.0, label='prey')

        ax.scatter(tnp(batch.xc[i]), tnp(batch.yc[i][:,0]), color='k', marker='*')
        ax.scatter(tnp(batch.xc[i]), tnp(batch.yc[i][:,1]), color='k', marker='*')

        ax.scatter(tnp(batch.xt[i]), tnp(batch.yt[i][:,0]), color='orchid', marker='x')
        ax.scatter(tnp(batch.xt[i]), tnp(batch.yt[i][:,1]), color='orchid', marker='x')

    plt.tight_layout()
    plt.show()
parser.add_argument('--expid', type=str, default='trial') 30 | parser.add_argument('--resume', action='store_true', default=False) 31 | parser.add_argument('--gpu', type=str, default='0') 32 | 33 | parser.add_argument('--max_num_points', type=int, default=50) 34 | 35 | parser.add_argument('--model', type=str, default='cnp') 36 | parser.add_argument('--train_batch_size', type=int, default=100) 37 | parser.add_argument('--train_num_samples', type=int, default=4) 38 | 39 | parser.add_argument('--lr', type=float, default=5e-4) 40 | parser.add_argument('--num_steps', type=int, default=100000) 41 | parser.add_argument('--print_freq', type=int, default=200) 42 | parser.add_argument('--eval_freq', type=int, default=5000) 43 | parser.add_argument('--save_freq', type=int, default=1000) 44 | 45 | parser.add_argument('--eval_seed', type=int, default=42) 46 | parser.add_argument('--eval_num_batches', type=int, default=3000) 47 | parser.add_argument('--eval_batch_size', type=int, default=16) 48 | parser.add_argument('--eval_num_samples', type=int, default=50) 49 | parser.add_argument('--eval_logfile', type=str, default=None) 50 | 51 | parser.add_argument('--plot_seed', type=int, default=None) 52 | parser.add_argument('--plot_batch_size', type=int, default=16) 53 | parser.add_argument('--plot_num_samples', type=int, default=30) 54 | parser.add_argument('--plot_num_ctx', type=int, default=None) 55 | 56 | # OOD settings 57 | parser.add_argument('--eval_kernel', type=str, default='rbf') 58 | parser.add_argument('--t_noise', type=float, default=None) 59 | 60 | args = parser.parse_args() 61 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 62 | 63 | model_cls = getattr(load_module(f'models/{args.model}.py'), args.model.upper()) 64 | with open(f'configs/gp/{args.model}.yaml', 'r') as f: 65 | config = yaml.safe_load(f) 66 | model = model_cls(**config).cuda() 67 | 68 | args.root = osp.join(results_path, 'gp', args.model, args.expid) 69 | 70 | if args.mode == 'train': 71 | train(args, 
model) 72 | elif args.mode == 'eval': 73 | eval(args, model) 74 | elif args.mode == 'plot': 75 | plot(args, model) 76 | elif args.mode == 'ensemble': 77 | ensemble(args, model) 78 | 79 | def train(args, model): 80 | if not osp.isdir(args.root): 81 | os.makedirs(args.root) 82 | 83 | with open(osp.join(args.root, 'args.yaml'), 'w') as f: 84 | yaml.dump(args.__dict__, f) 85 | 86 | sampler = GPSampler(RBFKernel()) 87 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 88 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 89 | optimizer, T_max=args.num_steps) 90 | 91 | if args.resume: 92 | ckpt = torch.load(os.path.join(args.root, 'ckpt.tar')) 93 | model.load_state_dict(ckpt.model) 94 | optimizer.load_state_dict(ckpt.optimizer) 95 | scheduler.load_state_dict(ckpt.scheduler) 96 | logfilename = ckpt.logfilename 97 | start_step = ckpt.step 98 | else: 99 | logfilename = os.path.join(args.root, 100 | f'train_{time.strftime("%Y%m%d-%H%M")}.log') 101 | start_step = 1 102 | 103 | logger = get_logger(logfilename) 104 | ravg = RunningAverage() 105 | 106 | if not args.resume: 107 | logger.info('Total number of parameters: {}\n'.format( 108 | sum(p.numel() for p in model.parameters()))) 109 | 110 | for step in range(start_step, args.num_steps+1): 111 | model.train() 112 | optimizer.zero_grad() 113 | batch = sampler.sample( 114 | batch_size=args.train_batch_size, 115 | max_num_points=args.max_num_points, 116 | device='cuda') 117 | outs = model(batch, num_samples=args.train_num_samples) 118 | outs.loss.backward() 119 | optimizer.step() 120 | scheduler.step() 121 | 122 | for key, val in outs.items(): 123 | ravg.update(key, val) 124 | 125 | if step % args.print_freq == 0: 126 | line = f'{args.model}:{args.expid} step {step} ' 127 | line += f'lr {optimizer.param_groups[0]["lr"]:.3e} ' 128 | line += ravg.info() 129 | logger.info(line) 130 | 131 | if step % args.eval_freq == 0: 132 | line = eval(args, model) 133 | logger.info(line + '\n') 134 | 135 | ravg.reset() 136 | 
        # Periodically checkpoint (and always at the final step) so --resume
        # can continue from where this run stopped.
        if step % args.save_freq == 0 or step == args.num_steps:
            ckpt = AttrDict()
            ckpt.model = model.state_dict()
            ckpt.optimizer = optimizer.state_dict()
            ckpt.scheduler = scheduler.state_dict()
            ckpt.logfilename = logfilename
            # step+1: a resumed run starts at the step after the one just done.
            ckpt.step = step + 1
            torch.save(ckpt, os.path.join(args.root, 'ckpt.tar'))

    # Run one full evaluation once training finishes.
    args.mode = 'eval'
    eval(args, model)

def gen_evalset(args):
    """Generate and cache a fixed evaluation set of GP batches for the
    requested kernel (with optional t-distributed observation noise)."""
    if args.eval_kernel == 'rbf':
        kernel = RBFKernel()
    elif args.eval_kernel == 'matern':
        kernel = Matern52Kernel()
    elif args.eval_kernel == 'periodic':
        kernel = PeriodicKernel()
    else:
        raise ValueError(f'Invalid kernel {args.eval_kernel}')

    # Fixed seed so every model is evaluated on identical batches.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    sampler = GPSampler(kernel, t_noise=args.t_noise)
    batches = []
    for i in tqdm(range(args.eval_num_batches)):
        batches.append(sampler.sample(
            batch_size=args.eval_batch_size,
            max_num_points=args.max_num_points))

    # Re-randomize the global RNG so later sampling is not seed-locked.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    path = osp.join(evalsets_path, 'gp')
    if not osp.isdir(path):
        os.makedirs(path)

    # Cache file name encodes kernel (and noise level, if any).
    filename = f'{args.eval_kernel}'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.tar'

    torch.save(batches, osp.join(path, filename))

def eval(args, model):
    # Standalone mode (--mode eval): load the checkpoint and open a dedicated
    # log file. When called from train(), reuse the in-memory model and
    # return the summary line to the caller's logger instead.
    if args.mode == 'eval':
        ckpt = torch.load(os.path.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        if args.eval_logfile is None:
            eval_logfile = f'eval_{args.eval_kernel}'
            if args.t_noise is not None:
                eval_logfile += f'_tn_{args.t_noise}'
            eval_logfile += '.log'
        else:
            eval_logfile = args.eval_logfile
        filename = os.path.join(args.root, eval_logfile)
        logger = get_logger(filename, mode='w')
    else:
        logger = None

if args.eval_kernel == 'rbf': 200 | kernel = RBFKernel() 201 | elif args.eval_kernel == 'matern': 202 | kernel = Matern52Kernel() 203 | elif args.eval_kernel == 'periodic': 204 | kernel = PeriodicKernel() 205 | else: 206 | raise ValueError(f'Invalid kernel {args.eval_kernel}') 207 | 208 | path = osp.join(evalsets_path, 'gp') 209 | filename = f'{args.eval_kernel}' 210 | if args.t_noise is not None: 211 | filename += f'_{args.t_noise}' 212 | filename += '.tar' 213 | if not osp.isfile(osp.join(path, filename)): 214 | print('generating evaluation sets...') 215 | gen_evalset(args) 216 | 217 | eval_batches = torch.load(osp.join(path, filename)) 218 | 219 | torch.manual_seed(args.eval_seed) 220 | torch.cuda.manual_seed(args.eval_seed) 221 | 222 | ravg = RunningAverage() 223 | model.eval() 224 | with torch.no_grad(): 225 | for batch in tqdm(eval_batches): 226 | for key, val in batch.items(): 227 | batch[key] = val.cuda() 228 | outs = model(batch, num_samples=args.eval_num_samples) 229 | for key, val in outs.items(): 230 | ravg.update(key, val) 231 | 232 | torch.manual_seed(time.time()) 233 | torch.cuda.manual_seed(time.time()) 234 | 235 | line = f'{args.model}:{args.expid} {args.eval_kernel} ' 236 | if args.t_noise is not None: 237 | line += f'tn {args.t_noise} ' 238 | line += ravg.info() 239 | 240 | if logger is not None: 241 | logger.info(line) 242 | 243 | return line 244 | 245 | def plot(args, model): 246 | ckpt = torch.load(os.path.join(args.root, 'ckpt.tar')) 247 | model.load_state_dict(ckpt.model) 248 | 249 | def tnp(x): 250 | return x.squeeze().cpu().data.numpy() 251 | 252 | if args.plot_seed is not None: 253 | torch.manual_seed(args.plot_seed) 254 | torch.cuda.manual_seed(args.plot_seed) 255 | 256 | kernel = RBFKernel() if args.pp is None else PeriodicKernel(p=args.pp) 257 | sampler = GPSampler(kernel, t_noise=args.t_noise) 258 | 259 | xp = torch.linspace(-2, 2, 200).cuda() 260 | batch = sampler.sample( 261 | batch_size=args.plot_batch_size, 262 | 
        max_num_points=args.max_num_points,
        num_ctx=args.plot_num_ctx,
        device='cuda')

    model.eval()
    with torch.no_grad():
        # Report log-likelihoods on the sampled batch, then predict densely
        # on the xp grid (broadcast to every batch element).
        outs = model(batch, num_samples=args.eval_num_samples)
        print(f'ctx_ll {outs.ctx_ll.item():.4f}, tar_ll {outs.tar_ll.item():.4f}')

        py = model.predict(batch.xc, batch.yc,
                xp[None,:,None].repeat(args.plot_batch_size, 1, 1),
                num_samples=args.plot_num_samples)
        mu, sigma = py.mean.squeeze(0), py.scale.squeeze(0)

    # One subplot per batch element, up to 4 columns.
    if args.plot_batch_size > 1:
        nrows = max(args.plot_batch_size//4, 1)
        ncols = min(4, args.plot_batch_size)
        fig, axes = plt.subplots(nrows, ncols,
                figsize=(5*ncols, 5*nrows))
        axes = axes.flatten()
    else:
        fig = plt.figure(figsize=(5, 5))
        axes = [plt.gca()]

    # multi sample: mu/sigma carry a leading sample dimension when dim == 4
    # (presumably [num_samples, batch, points, 1] -- TODO confirm in models).
    if mu.dim() == 4:
        #var = sigma.pow(2).mean(0) + mu.pow(2).mean(0) - mu.mean(0).pow(2)
        #sigma = var.sqrt()
        #mu = mu.mean(0)

        for i, ax in enumerate(axes):
            #ax.plot(tnp(xp), tnp(mu[i]), color='steelblue', alpha=0.5)
            #ax.fill_between(tnp(xp), tnp(mu[i]-sigma[i]), tnp(mu[i]+sigma[i]),
            #    color='skyblue', alpha=0.2, linewidth=0.0)
            # Draw every posterior sample faintly instead of the mixture mean;
            # alpha scales down with the number of samples.
            for s in range(mu.shape[0]):
                ax.plot(tnp(xp), tnp(mu[s][i]), color='steelblue',
                        alpha=max(0.5/args.plot_num_samples, 0.1))
                ax.fill_between(tnp(xp), tnp(mu[s][i])-tnp(sigma[s][i]),
                        tnp(mu[s][i])+tnp(sigma[s][i]),
                        color='skyblue',
                        alpha=max(0.2/args.plot_num_samples, 0.02),
                        linewidth=0.0)
            # zorder puts the scatter points above all sampled curves.
            ax.scatter(tnp(batch.xc[i]), tnp(batch.yc[i]),
                    color='k', label='context', zorder=mu.shape[0]+1)
            ax.scatter(tnp(batch.xt[i]), tnp(batch.yt[i]),
                    color='orchid', label='target',
                    zorder=mu.shape[0]+1)
            ax.legend()
    else:
        # Single prediction per batch element: mean curve plus one-sigma band.
        for i, ax in enumerate(axes):
            ax.plot(tnp(xp), tnp(mu[i]), color='steelblue', alpha=0.5)
            ax.fill_between(tnp(xp), tnp(mu[i]-sigma[i]), tnp(mu[i]+sigma[i]),
color='skyblue', alpha=0.2, linewidth=0.0) 315 | ax.scatter(tnp(batch.xc[i]), tnp(batch.yc[i]), 316 | color='k', label='context') 317 | ax.scatter(tnp(batch.xt[i]), tnp(batch.yt[i]), 318 | color='orchid', label='target') 319 | ax.legend() 320 | 321 | plt.tight_layout() 322 | plt.show() 323 | 324 | def ensemble(args, model): 325 | num_runs = 5 326 | models = [] 327 | for i in range(num_runs): 328 | model_ = deepcopy(model) 329 | ckpt = torch.load(osp.join(results_path, 'gp', args.model, f'run{i+1}', 'ckpt.tar')) 330 | model_.load_state_dict(ckpt['model']) 331 | model_.cuda() 332 | model_.eval() 333 | models.append(model_) 334 | 335 | path = osp.join(evalsets_path, 'gp') 336 | filename = f'{args.eval_kernel}' 337 | if args.t_noise is not None: 338 | filename += f'_{args.t_noise}' 339 | filename += '.tar' 340 | if not osp.isfile(osp.join(path, filename)): 341 | print('generating evaluation sets...') 342 | gen_evalset(args) 343 | 344 | eval_batches = torch.load(osp.join(path, filename)) 345 | 346 | torch.manual_seed(args.eval_seed) 347 | torch.cuda.manual_seed(args.eval_seed) 348 | 349 | ravg = RunningAverage() 350 | with torch.no_grad(): 351 | for batch in tqdm(eval_batches): 352 | for key, val in batch.items(): 353 | batch[key] = val.cuda() 354 | 355 | ctx_ll = [] 356 | tar_ll = [] 357 | for model in models: 358 | outs = model(batch, 359 | num_samples=args.eval_num_samples, 360 | reduce_ll=False) 361 | ctx_ll.append(outs.ctx_ll) 362 | tar_ll.append(outs.tar_ll) 363 | 364 | if ctx_ll[0].dim() == 2: 365 | ctx_ll = torch.stack(ctx_ll) 366 | tar_ll = torch.stack(tar_ll) 367 | else: 368 | ctx_ll = torch.cat(ctx_ll) 369 | tar_ll = torch.cat(tar_ll) 370 | 371 | ctx_ll = logmeanexp(ctx_ll).mean() 372 | tar_ll = logmeanexp(tar_ll).mean() 373 | 374 | ravg.update('ctx_ll', ctx_ll) 375 | ravg.update('tar_ll', tar_ll) 376 | 377 | filename = f'ensemble_{args.eval_kernel}' 378 | if args.t_noise is not None: 379 | filename += f'_{args.t_noise}' 380 | filename += '.log' 381 | logger 
= get_logger(osp.join(results_path, 'gp', args.model, filename), mode='w') 382 | logger.info(ravg.info()) 383 | 384 | if __name__ == '__main__': 385 | main() 386 | --------------------------------------------------------------------------------