├── regression
├── requirements.txt
├── configs
│ ├── gp
│ │ ├── bnp.yaml
│ │ ├── cnp.yaml
│ │ ├── np.yaml
│ │ ├── banp.yaml
│ │ ├── canp.yaml
│ │ └── anp.yaml
│ ├── celeba
│ │ ├── bnp.yaml
│ │ ├── cnp.yaml
│ │ ├── np.yaml
│ │ ├── banp.yaml
│ │ ├── canp.yaml
│ │ └── anp.yaml
│ ├── emnist
│ │ ├── bnp.yaml
│ │ ├── cnp.yaml
│ │ ├── np.yaml
│ │ ├── banp.yaml
│ │ ├── canp.yaml
│ │ └── anp.yaml
│ └── lotka_volterra
│   ├── bnp.yaml
│   ├── cnp.yaml
│   ├── np.yaml
│   ├── banp.yaml
│   ├── canp.yaml
│   └── anp.yaml
├── utils
│ ├── paths.py
│ ├── misc.py
│ ├── sampling.py
│ └── log.py
├── data
│ ├── emnist.py
│ ├── celeba.py
│ ├── image.py
│ ├── gp.py
│ └── lotka_volterra.py
├── models
│ ├── cnp.py
│ ├── attention.py
│ ├── canp.py
│ ├── bnp.py
│ ├── banp.py
│ ├── np.py
│ ├── anp.py
│ └── modules.py
├── celeba.py
├── emnist.py
├── lotka_volterra.py
└── gp.py
├── bnp_new.png
├── bayesian_optimization
├── requirements.txt
├── results
│ ├── oracle_1.npy
│ ├── oracle_2.npy
│ ├── oracle_3.npy
│ ├── oracle_4.npy
│ ├── oracle_5.npy
│ ├── oracle_6.npy
│ ├── oracle_7.npy
│ ├── oracle_8.npy
│ ├── oracle_9.npy
│ ├── oracle_10.npy
│ ├── oracle_100.npy
│ ├── oracle_11.npy
│ ├── oracle_12.npy
│ ├── oracle_13.npy
│ ├── oracle_14.npy
│ ├── oracle_15.npy
│ ├── oracle_16.npy
│ ├── oracle_17.npy
│ ├── oracle_18.npy
│ ├── oracle_19.npy
│ ├── oracle_20.npy
│ ├── oracle_21.npy
│ ├── oracle_22.npy
│ ├── oracle_23.npy
│ ├── oracle_24.npy
│ ├── oracle_25.npy
│ ├── oracle_26.npy
│ ├── oracle_27.npy
│ ├── oracle_28.npy
│ ├── oracle_29.npy
│ ├── oracle_30.npy
│ ├── oracle_31.npy
│ ├── oracle_32.npy
│ ├── oracle_33.npy
│ ├── oracle_34.npy
│ ├── oracle_35.npy
│ ├── oracle_36.npy
│ ├── oracle_37.npy
│ ├── oracle_38.npy
│ ├── oracle_39.npy
│ ├── oracle_40.npy
│ ├── oracle_41.npy
│ ├── oracle_42.npy
│ ├── oracle_43.npy
│ ├── oracle_44.npy
│ ├── oracle_45.npy
│ ├── oracle_46.npy
│ ├── oracle_47.npy
│ ├── oracle_48.npy
│ ├── oracle_49.npy
│ ├── oracle_50.npy
│ ├── oracle_51.npy
│ ├── oracle_52.npy
│ ├── oracle_53.npy
│ ├── oracle_54.npy
│ ├── oracle_55.npy
│ ├── oracle_56.npy
│ ├── oracle_57.npy
│ ├── oracle_58.npy
│ ├── oracle_59.npy
│ ├── oracle_60.npy
│ ├── oracle_61.npy
│ ├── oracle_62.npy
│ ├── oracle_63.npy
│ ├── oracle_64.npy
│ ├── oracle_65.npy
│ ├── oracle_66.npy
│ ├── oracle_67.npy
│ ├── oracle_68.npy
│ ├── oracle_69.npy
│ ├── oracle_70.npy
│ ├── oracle_71.npy
│ ├── oracle_72.npy
│ ├── oracle_73.npy
│ ├── oracle_74.npy
│ ├── oracle_75.npy
│ ├── oracle_76.npy
│ ├── oracle_77.npy
│ ├── oracle_78.npy
│ ├── oracle_79.npy
│ ├── oracle_80.npy
│ ├── oracle_81.npy
│ ├── oracle_82.npy
│ ├── oracle_83.npy
│ ├── oracle_84.npy
│ ├── oracle_85.npy
│ ├── oracle_86.npy
│ ├── oracle_87.npy
│ ├── oracle_88.npy
│ ├── oracle_89.npy
│ ├── oracle_90.npy
│ ├── oracle_91.npy
│ ├── oracle_92.npy
│ ├── oracle_93.npy
│ ├── oracle_94.npy
│ ├── oracle_95.npy
│ ├── oracle_96.npy
│ ├── oracle_97.npy
│ ├── oracle_98.npy
│ └── oracle_99.npy
├── configs
│ └── gp
│   ├── bnp.yaml
│   ├── cnp.yaml
│   ├── np.yaml
│   ├── banp.yaml
│   ├── canp.yaml
│   └── anp.yaml
├── utils
│ ├── paths.py
│ ├── misc.py
│ ├── sampling.py
│ └── log.py
├── .gitignore
├── test_all.sh
├── models
│ ├── attention.py
│ ├── cnp.py
│ ├── canp.py
│ ├── bnp.py
│ ├── banp.py
│ ├── np.py
│ ├── anp.py
│ └── modules.py
├── data
│ └── gp.py
└── run_bo.py
├── LICENSE
├── README.md
└── .gitignore
/regression/requirements.txt:
--------------------------------------------------------------------------------
1 | attrdict
2 | pyyaml
3 | tqdm
4 |
--------------------------------------------------------------------------------
/bnp_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bnp_new.png
--------------------------------------------------------------------------------
/bayesian_optimization/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | attrdict
3 | torch==1.4.0
4 | torchvision==0.5.0
5 | bayeso
6 |
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_1.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_2.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_3.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_4.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_5.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_6.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_6.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_7.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_7.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_8.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_8.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_9.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_9.npy
--------------------------------------------------------------------------------
/regression/configs/gp/bnp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_pre_depth: 4
5 | enc_post_depth: 2
6 | dec_depth: 3
7 |
--------------------------------------------------------------------------------
/regression/configs/gp/cnp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_pre_depth: 4
5 | enc_post_depth: 2
6 | dec_depth: 3
7 |
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_10.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_100.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_100.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_11.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_11.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_12.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_12.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_13.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_14.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_14.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_15.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_15.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_16.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_16.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_17.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_17.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_18.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_18.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_19.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_19.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_20.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_20.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_21.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_21.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_22.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_22.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_23.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_23.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_24.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_24.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_25.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_25.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_26.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_26.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_27.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_27.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_28.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_28.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_29.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_29.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_30.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_30.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_31.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_31.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_32.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_32.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_33.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_33.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_34.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_34.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_35.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_35.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_36.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_36.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_37.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_37.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_38.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_38.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_39.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_39.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_40.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_40.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_41.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_41.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_42.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_42.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_43.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_43.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_44.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_44.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_45.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_45.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_46.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_46.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_47.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_47.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_48.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_48.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_49.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_49.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_50.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_50.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_51.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_51.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_52.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_52.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_53.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_53.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_54.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_54.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_55.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_55.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_56.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_56.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_57.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_57.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_58.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_58.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_59.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_59.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_60.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_60.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_61.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_61.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_62.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_62.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_63.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_63.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_64.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_64.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_65.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_65.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_66.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_66.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_67.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_67.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_68.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_68.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_69.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_69.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_70.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_70.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_71.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_71.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_72.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_72.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_73.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_73.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_74.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_74.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_75.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_75.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_76.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_76.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_77.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_77.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_78.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_78.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_79.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_79.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_80.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_80.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_81.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_81.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_82.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_82.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_83.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_83.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_84.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_84.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_85.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_85.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_86.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_86.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_87.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_87.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_88.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_88.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_89.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_89.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_90.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_90.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_91.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_91.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_92.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_92.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_93.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_93.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_94.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_94.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_95.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_95.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_96.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_96.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_97.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_97.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_98.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_98.npy
--------------------------------------------------------------------------------
/bayesian_optimization/results/oracle_99.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juho-lee/bnp/HEAD/bayesian_optimization/results/oracle_99.npy
--------------------------------------------------------------------------------
/regression/configs/celeba/bnp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 3
3 | dim_hid: 128
4 | enc_pre_depth: 6
5 | enc_post_depth: 3
6 | dec_depth: 5
7 |
--------------------------------------------------------------------------------
/regression/configs/celeba/cnp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 3
3 | dim_hid: 128
4 | enc_pre_depth: 6
5 | enc_post_depth: 3
6 | dec_depth: 5
7 |
--------------------------------------------------------------------------------
/regression/configs/emnist/bnp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_pre_depth: 5
5 | enc_post_depth: 3
6 | dec_depth: 4
7 |
--------------------------------------------------------------------------------
/regression/configs/emnist/cnp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_pre_depth: 5
5 | enc_post_depth: 3
6 | dec_depth: 4
7 |
--------------------------------------------------------------------------------
/bayesian_optimization/configs/gp/bnp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_pre_depth: 4
5 | enc_post_depth: 2
6 | dec_depth: 3
7 |
--------------------------------------------------------------------------------
/bayesian_optimization/configs/gp/cnp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_pre_depth: 4
5 | enc_post_depth: 2
6 | dec_depth: 3
7 |
--------------------------------------------------------------------------------
/regression/configs/lotka_volterra/bnp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 2
3 | dim_hid: 128
4 | enc_pre_depth: 4
5 | enc_post_depth: 2
6 | dec_depth: 3
7 |
--------------------------------------------------------------------------------
/regression/configs/lotka_volterra/cnp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 2
3 | dim_hid: 128
4 | enc_pre_depth: 4
5 | enc_post_depth: 2
6 | dec_depth: 3
7 |
--------------------------------------------------------------------------------
/regression/configs/celeba/np.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 3
3 | dim_hid: 128
4 | dim_lat: 128
5 | enc_pre_depth: 6
6 | enc_post_depth: 3
7 | dec_depth: 5
8 |
--------------------------------------------------------------------------------
/regression/configs/emnist/np.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 1
3 | dim_hid: 128
4 | dim_lat: 128
5 | enc_pre_depth: 5
6 | enc_post_depth: 3
7 | dec_depth: 4
8 |
--------------------------------------------------------------------------------
/regression/configs/gp/np.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | dim_lat: 128
5 | enc_pre_depth: 4
6 | enc_post_depth: 2
7 | dec_depth: 3
8 |
--------------------------------------------------------------------------------
/bayesian_optimization/configs/gp/np.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | dim_lat: 128
5 | enc_pre_depth: 4
6 | enc_post_depth: 2
7 | dec_depth: 3
8 |
--------------------------------------------------------------------------------
/regression/configs/lotka_volterra/np.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 2
3 | dim_hid: 128
4 | dim_lat: 128
5 | enc_pre_depth: 4
6 | enc_post_depth: 2
7 | dec_depth: 3
8 |
--------------------------------------------------------------------------------
/regression/configs/gp/banp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_v_depth: 4
5 | enc_qk_depth: 2
6 | enc_pre_depth: 4
7 | enc_post_depth: 2
8 | dec_depth: 3
9 |
--------------------------------------------------------------------------------
/regression/configs/gp/canp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_v_depth: 4
5 | enc_qk_depth: 2
6 | enc_pre_depth: 4
7 | enc_post_depth: 2
8 | dec_depth: 3
9 |
--------------------------------------------------------------------------------
/regression/configs/celeba/banp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 3
3 | dim_hid: 128
4 | enc_v_depth: 6
5 | enc_qk_depth: 3
6 | enc_pre_depth: 6
7 | enc_post_depth: 3
8 | dec_depth: 5
9 |
--------------------------------------------------------------------------------
/regression/configs/celeba/canp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 3
3 | dim_hid: 128
4 | enc_v_depth: 6
5 | enc_qk_depth: 3
6 | enc_pre_depth: 6
7 | enc_post_depth: 3
8 | dec_depth: 5
9 |
--------------------------------------------------------------------------------
/regression/configs/emnist/banp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_v_depth: 5
5 | enc_qk_depth: 3
6 | enc_pre_depth: 5
7 | enc_post_depth: 3
8 | dec_depth: 4
9 |
--------------------------------------------------------------------------------
/regression/configs/emnist/canp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_v_depth: 5
5 | enc_qk_depth: 3
6 | enc_pre_depth: 5
7 | enc_post_depth: 3
8 | dec_depth: 4
9 |
--------------------------------------------------------------------------------
/bayesian_optimization/utils/paths.py:
--------------------------------------------------------------------------------
import os

# Root for datasets/checkpoints, relative to the directory scripts run from.
ROOT = '../../'

datasets_path = os.path.join(ROOT, 'datasets')
# NOTE(review): checkpoints are read from 'ckpt_np' here, a different tree
# than the regression package writes to — confirm they are kept in sync.
results_path = os.path.join(ROOT, 'ckpt_np')
7 |
--------------------------------------------------------------------------------
/bayesian_optimization/configs/gp/banp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_v_depth: 4
5 | enc_qk_depth: 2
6 | enc_pre_depth: 4
7 | enc_post_depth: 2
8 | dec_depth: 3
9 |
--------------------------------------------------------------------------------
/bayesian_optimization/configs/gp/canp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | enc_v_depth: 4
5 | enc_qk_depth: 2
6 | enc_pre_depth: 4
7 | enc_post_depth: 2
8 | dec_depth: 3
9 |
--------------------------------------------------------------------------------
/regression/configs/lotka_volterra/banp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 2
3 | dim_hid: 128
4 | enc_v_depth: 4
5 | enc_qk_depth: 2
6 | enc_pre_depth: 4
7 | enc_post_depth: 2
8 | dec_depth: 3
9 |
--------------------------------------------------------------------------------
/regression/configs/lotka_volterra/canp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 2
3 | dim_hid: 128
4 | enc_v_depth: 4
5 | enc_qk_depth: 2
6 | enc_pre_depth: 4
7 | enc_post_depth: 2
8 | dec_depth: 3
9 |
--------------------------------------------------------------------------------
/regression/configs/celeba/anp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 3
3 | dim_hid: 128
4 | dim_lat: 128
5 | enc_v_depth: 6
6 | enc_qk_depth: 3
7 | enc_pre_depth: 6
8 | enc_post_depth: 3
9 | dec_depth: 5
10 |
--------------------------------------------------------------------------------
/regression/configs/emnist/anp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 2
2 | dim_y: 1
3 | dim_hid: 128
4 | dim_lat: 128
5 | enc_v_depth: 5
6 | enc_qk_depth: 3
7 | enc_pre_depth: 5
8 | enc_post_depth: 3
9 | dec_depth: 4
10 |
--------------------------------------------------------------------------------
/regression/configs/gp/anp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | dim_lat: 128
5 | enc_v_depth: 4
6 | enc_qk_depth: 2
7 | enc_pre_depth: 4
8 | enc_post_depth: 2
9 | dec_depth: 3
10 |
--------------------------------------------------------------------------------
/bayesian_optimization/configs/gp/anp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 1
3 | dim_hid: 128
4 | dim_lat: 128
5 | enc_v_depth: 4
6 | enc_qk_depth: 2
7 | enc_pre_depth: 4
8 | enc_post_depth: 2
9 | dec_depth: 3
10 |
--------------------------------------------------------------------------------
/regression/configs/lotka_volterra/anp.yaml:
--------------------------------------------------------------------------------
1 | dim_x: 1
2 | dim_y: 2
3 | dim_hid: 128
4 | dim_lat: 128
5 | enc_v_depth: 4
6 | enc_qk_depth: 2
7 | enc_pre_depth: 4
8 | enc_post_depth: 2
9 | dec_depth: 3
10 |
--------------------------------------------------------------------------------
/bayesian_optimization/.gitignore:
--------------------------------------------------------------------------------
1 | results/bo_rbf_oracle_*
2 | results/bo_rbf_noisy_oracle_*
3 | results/bo_matern_oracle_*
4 | results/bo_matern_noisy_oracle_*
5 | results/bo_periodic_oracle_*
6 | results/bo_periodic_noisy_oracle_*
7 |
8 |
--------------------------------------------------------------------------------
/regression/utils/paths.py:
--------------------------------------------------------------------------------
import os

# NOTE(review): hard-coded, machine-specific NFS path — must be adjusted
# per environment before running the regression experiments.
ROOT = '/nfs/lang/ext01/john'

evalsets_path = os.path.join(ROOT, 'bnp', 'evalsets')
datasets_path = os.path.join(ROOT, 'datasets')
results_path = os.path.join(ROOT, 'bnp', 'results')
8 |
--------------------------------------------------------------------------------
/bayesian_optimization/test_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python run_bo.py --mode bo --model anp
4 | python run_bo.py --mode bo --model banp
5 | python run_bo.py --mode bo --model bnp
6 | python run_bo.py --mode bo --model canp
7 | python run_bo.py --mode bo --model cnp
8 | python run_bo.py --mode bo --model np
9 |
10 |
--------------------------------------------------------------------------------
/regression/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib.machinery import SourceFileLoader
3 | import math
4 | import torch
5 |
def gen_load_func(parser, func):
    """Build a loader that parses the options *parser* knows from a command
    line, mirrors them onto the shared *args* namespace, and returns
    ``func(**parsed)`` together with the unconsumed command-line tokens."""
    def load(args, cmdline):
        parsed, remaining = parser.parse_known_args(cmdline)
        # Mirror every parsed option onto the shared namespace.
        args.__dict__.update(vars(parsed))
        return func(**vars(parsed)), remaining
    return load
13 |
def load_module(filename):
    """Import the Python source file *filename* and return the module object.

    Fix: ``SourceFileLoader.load_module`` has been deprecated since Python 3.4
    and was removed in Python 3.12; use the spec-based loading API instead.
    The interface is unchanged: the module is named after the file's basename
    and returned to the caller.
    """
    import importlib.util
    module_name = os.path.splitext(os.path.basename(filename))[0]
    spec = importlib.util.spec_from_file_location(module_name, filename)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
17 |
def logmeanexp(x, dim=0):
    """Numerically stable ``log(mean(exp(x)))`` along dimension *dim*."""
    count = x.shape[dim]
    return torch.logsumexp(x, dim) - math.log(count)
20 |
def stack(x, num_samples=None, dim=0):
    """Repeat *x* ``num_samples`` times along a new dimension *dim*.

    When *num_samples* is None the input is returned untouched, so callers
    can uniformly handle the sampled and non-sampled cases.
    """
    if num_samples is None:
        return x
    return torch.stack([x for _ in range(num_samples)], dim=dim)
24 |
--------------------------------------------------------------------------------
/bayesian_optimization/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib.machinery import SourceFileLoader
3 | import math
4 | import torch
5 |
def gen_load_func(parser, func):
    # Wrap *func* so it can be built from a command line: parse the options
    # *parser* knows, mirror them onto the shared *args* namespace, and
    # return func(**parsed) plus the unconsumed tokens.
    def load(args, cmdline):
        sub_args, cmdline = parser.parse_known_args(cmdline)
        # Copy every parsed option onto the shared args namespace.
        for k, v in sub_args.__dict__.items():
            args.__dict__[k] = v
        return func(**sub_args.__dict__), cmdline
    return load
13 |
def load_module(filename):
    """Import the Python source file *filename* and return the module object.

    Fix: ``SourceFileLoader.load_module`` has been deprecated since Python 3.4
    and was removed in Python 3.12; use the spec-based loading API instead.
    The interface is unchanged: the module is named after the file's basename
    and returned to the caller.
    """
    import importlib.util
    module_name = os.path.splitext(os.path.basename(filename))[0]
    spec = importlib.util.spec_from_file_location(module_name, filename)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
17 |
def logmeanexp(x, dim=0):
    # log(mean(exp(x))) along *dim*, computed stably via logsumexp.
    return x.logsumexp(dim) - math.log(x.shape[dim])
20 |
def stack(x, num_samples=None, dim=0):
    # Tile *x* num_samples times along a new dimension; identity when None.
    return x if num_samples is None \
            else torch.stack([x]*num_samples, dim=dim)
24 |
--------------------------------------------------------------------------------
/regression/data/emnist.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import torchvision.datasets as tvds
5 |
6 | from utils.paths import datasets_path
7 | from utils.misc import gen_load_func
8 |
class EMNIST(tvds.EMNIST):
    # Balanced-split EMNIST wrapper that preloads all images as float
    # tensors on *device* and keeps only labels in the half-open range
    # [class_range[0], class_range[1]).
    def __init__(self, train=True, class_range=[0, 47], device='cpu', download=True):
        super().__init__(datasets_path, train=train, split='balanced', download=download)

        # (N, 1, 28, 28) floats in [0, 1]; the transpose undoes EMNIST's
        # rotated storage orientation — presumably to match display
        # convention, TODO confirm visually.
        self.data = self.data.unsqueeze(1).float().div(255).transpose(-1, -2).to(device)
        self.targets = self.targets.to(device)

        # Collect the indices of every sample whose label is in range.
        idxs = []
        for c in range(class_range[0], class_range[1]):
            idxs.append(torch.where(self.targets==c)[0])
        idxs = torch.cat(idxs)

        self.data = self.data[idxs]
        self.targets = self.targets[idxs]

    def __getitem__(self, idx):
        # Return the preprocessed tensor pair directly (bypasses the
        # torchvision transform pipeline on purpose).
        return self.data[idx], self.targets[idx]
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Juho Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/regression/utils/sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
def gather(items, idxs):
    # Index every tensor in *items* along its point dimension (-2) with the
    # same index tensor. idxs has shape (K,) + batch_shape + (M,); each item
    # is assumed to be batch_shape + (N, dim) — TODO confirm with callers.
    K = idxs.shape[0]
    idxs = idxs.to(items[0].device)
    gathered = []
    for item in items:
        # Tile the item over the K samples and the indices over the feature
        # dim so torch.gather can select M points per batch element; the
        # squeeze drops the sample dim when K == 1.
        gathered.append(torch.gather(
            torch.stack([item]*K), -2,
            torch.stack([idxs]*item.shape[-1], -1)).squeeze(0))
    # Unwrap the list when a single tensor was passed in.
    return gathered[0] if len(gathered) == 1 else gathered
12 |
def sample_subset(*items, r_N=None, num_samples=None):
    """Randomly split each tensor in *items* along the point dimension into
    a context part (roughly a fraction r_N of the points, at least 1 and at
    most N-1) and a target part holding the remaining points."""
    ratio = r_N if r_N else torch.rand(1).item()
    K = num_samples if num_samples else 1
    total = items[0].shape[-2]
    # Clamp so both sides of the split are non-empty.
    n_ctx = min(max(1, int(ratio * total)), total - 1)
    batch_shape = items[0].shape[:-2]
    # Random permutation of all point indices, per sample and batch element.
    order = torch.rand((K,) + batch_shape + (total,)).argsort(-1)
    ctx = gather(items, order[..., :n_ctx])
    tar = gather(items, order[..., n_ctx:])
    return ctx, tar
21 |
def sample_with_replacement(*items, num_samples=None, r_N=1.0, N_s=None):
    # Bootstrap resampling: draw N_s point indices uniformly with
    # replacement (default N_s = r_N * N) and gather them from every item.
    K = num_samples or 1
    N = items[0].shape[-2]
    N_s = N_s or max(1, int(r_N * N))
    batch_shape = items[0].shape[:-2]
    idxs = torch.randint(N, size=(K,)+batch_shape+(N_s,))
    return gather(items, idxs)
29 |
def sample_mask(B, N, num_samples=None, min_num=3, prob=0.5):
    """Sample a binary mask of shape (B, N) — or (num_samples, B, N) when
    num_samples is given — whose first min_num entries are always 1 and
    whose remaining entries are independent Bernoulli(prob) draws."""
    guaranteed = min(min_num, N)
    K = num_samples if num_samples else 1
    always_on = torch.ones(K, B, guaranteed)
    remaining = N - guaranteed
    if remaining <= 0:
        return always_on.squeeze(0)
    coin_flips = torch.bernoulli(prob * torch.ones(K, B, remaining))
    return torch.cat([always_on, coin_flips], -1).squeeze(0)
40 |
--------------------------------------------------------------------------------
/bayesian_optimization/utils/sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
def gather(items, idxs):
    # Index every tensor in *items* along its point dimension (-2) with the
    # same index tensor. idxs has shape (K,) + batch_shape + (M,); each item
    # is assumed to be batch_shape + (N, dim) — TODO confirm with callers.
    K = idxs.shape[0]
    idxs = idxs.to(items[0].device)
    gathered = []
    for item in items:
        # Tile the item over the K samples and the indices over the feature
        # dim so torch.gather can select M points per batch element; the
        # squeeze drops the sample dim when K == 1.
        gathered.append(torch.gather(
            torch.stack([item]*K), -2,
            torch.stack([idxs]*item.shape[-1], -1)).squeeze(0))
    # Unwrap the list when a single tensor was passed in.
    return gathered[0] if len(gathered) == 1 else gathered
12 |
def sample_subset(*items, r_N=None, num_samples=None):
    """Randomly split each tensor in *items* along the point dimension (-2)
    into a context set of Ns points and a target set of the rest.

    Fix: the random permutation must be drawn over all N points, not just
    Ns — the original drew ``torch.rand(... (Ns,))``, which made the target
    slice ``idxs[..., Ns:]`` always empty and restricted context indices to
    0..Ns-1. Ns is also clamped to N-1 so the target set is never empty.
    This matches the regression package's implementation of the same helper.
    """
    r_N = r_N or torch.rand(1).item()
    K = num_samples or 1
    N = items[0].shape[-2]
    # Keep at least one point on each side of the split.
    Ns = min(max(1, int(r_N * N)), N - 1)
    batch_shape = items[0].shape[:-2]
    # Random permutation of all N point indices per sample/batch element.
    idxs = torch.rand((K,) + batch_shape + (N,)).argsort(-1)
    return gather(items, idxs[..., :Ns]), gather(items, idxs[..., Ns:])
21 |
def sample_with_replacement(*items, num_samples=None, r_N=1.0, N_s=None):
    # Bootstrap resampling: draw N_s point indices uniformly with
    # replacement (default N_s = r_N * N) and gather them from every item.
    K = num_samples or 1
    N = items[0].shape[-2]
    N_s = N_s or max(1, int(r_N * N))
    batch_shape = items[0].shape[:-2]
    idxs = torch.randint(N, size=(K,)+batch_shape+(N_s,))
    return gather(items, idxs)
29 |
def sample_mask(B, N, num_samples=None, min_num=3, prob=0.5):
    # Binary mask of shape (B, N) (or (num_samples, B, N)): the first
    # min_num entries are always 1, the rest are Bernoulli(prob) draws.
    min_num = min(min_num, N)
    K = num_samples or 1
    fixed = torch.ones(K, B, min_num)
    if N - min_num > 0:
        rand = torch.bernoulli(prob*torch.ones(K, B, N-min_num))
        mask = torch.cat([fixed, rand], -1)
        return mask.squeeze(0)
    else:
        # Every position is guaranteed on.
        return fixed.squeeze(0)
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bootstrapping Neural Processes
2 | The official repository for the paper [Bootstrapping Neural Processes](https://arxiv.org/abs/2008.02956) (NeurIPS 2020) by Juho Lee et al.
3 |
4 |
5 |
6 |
7 |
8 | ## Abstract
9 | Unlike in the traditional statistical modeling, for which a user typically hand-specifies a prior, Neural Processes (NPs) implicitly define a broad class of stochastic processes with neural networks. Given a data stream, an NP learns a stochastic process that best describes the data. While this "data-driven" way of learning stochastic processes has proven to handle various types of data, NPs still rely on an assumption that uncertainty in stochastic processes is modeled by a single latent variable, which potentially limits the flexibility. To this end, we propose the Bootstrapping Neural Process (BNP), a novel extension of the NP family using the bootstrap. The bootstrap is a classical data-driven technique for estimating uncertainty, which allows BNP to learn the stochasticity in NPs without assuming a particular form. We demonstrate the efficacy of BNP on various types of data and its robustness in the presence of model-data mismatch.
10 |
11 | ## Citation
12 | If you find this useful in your research, please consider citing our paper:
13 | ```
14 | @misc{lee2020bootstrapping,
15 | title={Bootstrapping Neural Processes},
16 | author={Juho Lee and Yoonho Lee and Jungtaek Kim and Eunho Yang and Sung Ju Hwang and Yee Whye Teh},
17 | year={2020},
18 | journal={arXiv preprint arXiv:2008.02956},
19 | }
20 | ```
--------------------------------------------------------------------------------
/regression/utils/log.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import logging
4 | from collections import OrderedDict
5 |
6 | def get_logger(filename, mode='a'):
7 | logging.basicConfig(level=logging.INFO, format='%(message)s')
8 | logger = logging.getLogger()
9 | logger.addHandler(logging.FileHandler(filename, mode=mode))
10 | return logger
11 |
class RunningAverage(object):
    """Accumulate per-key running means (e.g. losses) plus a wall clock."""

    def __init__(self, *keys):
        self.sum = OrderedDict()
        self.cnt = OrderedDict()
        self.clock = time.time()
        for key in keys:
            self.sum[key] = 0
            self.cnt[key] = 0

    def update(self, key, val):
        """Add one observation of *key*; tensors are reduced via .item()."""
        if isinstance(val, torch.Tensor):
            val = val.item()
        if self.sum.get(key, None) is None:
            # First observation of a previously unseen key.
            self.sum[key] = val
            self.cnt[key] = 1
        else:
            self.sum[key] = self.sum[key] + val
            self.cnt[key] += 1

    def reset(self):
        """Zero the accumulators for the known keys and restart the clock."""
        for key in self.sum.keys():
            self.sum[key] = 0
            self.cnt[key] = 0
        self.clock = time.time()

    def clear(self):
        """Drop all keys entirely and restart the clock."""
        self.sum = OrderedDict()
        self.cnt = OrderedDict()
        self.clock = time.time()

    def keys(self):
        return self.sum.keys()

    def get(self, key):
        """Return the running mean for *key* (must have been updated)."""
        assert self.sum.get(key, None) is not None
        return self.sum[key] / self.cnt[key]

    def info(self, show_et=True):
        """Render 'key mean' pairs (floats to 4 decimals) and elapsed time."""
        line = ''
        for key in self.sum.keys():
            val = self.sum[key] / self.cnt[key]
            if isinstance(val, float):
                line += f'{key} {val:.4f} '
            else:
                # Fix: the original chained `.format(key, val)` onto the
                # already-rendered f-string — a no-op at best, and it breaks
                # whenever the rendered text contains `{` or `}`.
                line += f'{key} {val} '
        if show_et:
            line += f'({time.time()-self.clock:.3f} secs)'
        return line
60 |
--------------------------------------------------------------------------------
/bayesian_optimization/utils/log.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import logging
4 | from collections import OrderedDict
5 |
def get_logger(filename, mode='a'):
    # Configure the root logger to print messages and append them to
    # *filename*.  NOTE(review): every call adds another FileHandler to the
    # global root logger, so repeated calls duplicate output.
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(filename, mode=mode))
    return logger
11 |
class RunningAverage(object):
    """Accumulate per-key running means (e.g. losses) plus a wall clock."""

    def __init__(self, *keys):
        self.sum = OrderedDict()
        self.cnt = OrderedDict()
        self.clock = time.time()
        for key in keys:
            self.sum[key] = 0
            self.cnt[key] = 0

    def update(self, key, val):
        """Add one observation of *key*; tensors are reduced via .item()."""
        if isinstance(val, torch.Tensor):
            val = val.item()
        if self.sum.get(key, None) is None:
            # First observation of a previously unseen key.
            self.sum[key] = val
            self.cnt[key] = 1
        else:
            self.sum[key] = self.sum[key] + val
            self.cnt[key] += 1

    def reset(self):
        """Zero the accumulators for the known keys and restart the clock."""
        for key in self.sum.keys():
            self.sum[key] = 0
            self.cnt[key] = 0
        self.clock = time.time()

    def clear(self):
        """Drop all keys entirely and restart the clock."""
        self.sum = OrderedDict()
        self.cnt = OrderedDict()
        self.clock = time.time()

    def keys(self):
        return self.sum.keys()

    def get(self, key):
        """Return the running mean for *key* (must have been updated)."""
        assert self.sum.get(key, None) is not None
        return self.sum[key] / self.cnt[key]

    def info(self, show_et=True):
        """Render 'key mean' pairs (floats to 4 decimals) and elapsed time."""
        line = ''
        for key in self.sum.keys():
            val = self.sum[key] / self.cnt[key]
            if isinstance(val, float):
                line += f'{key} {val:.4f} '
            else:
                # Fix: the original chained `.format(key, val)` onto the
                # already-rendered f-string — a no-op at best, and it breaks
                # whenever the rendered text contains `{` or `}`.
                line += f'{key} {val} '
        if show_et:
            line += f'({time.time()-self.clock:.3f} secs)'
        return line
60 |
--------------------------------------------------------------------------------
/regression/models/cnp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attrdict import AttrDict
4 |
5 | from models.modules import PoolingEncoder, Decoder
6 |
class CNP(nn.Module):
    """Conditional Neural Process: two mean-pooled set encoders whose
    concatenated context representation conditions a factorized decoder."""

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        # Two independent pooling encoders; their outputs are concatenated
        # into a 2*dim_hid context representation.
        self.enc1 = PoolingEncoder(
                dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid,
                pre_depth=enc_pre_depth, post_depth=enc_post_depth)

        self.enc2 = PoolingEncoder(
                dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid,
                pre_depth=enc_pre_depth, post_depth=enc_post_depth)

        self.dec = Decoder(
                dim_x=dim_x, dim_y=dim_y, dim_enc=2*dim_hid,
                dim_hid=dim_hid, depth=dec_depth)

    def predict(self, xc, yc, xt, num_samples=None):
        """Encode the context (xc, yc) and decode a predictive distribution
        at the target inputs xt."""
        rep = torch.cat([self.enc1(xc, yc), self.enc2(xc, yc)], -1)
        # Broadcast the single context representation to every target point.
        rep = torch.stack([rep]*xt.shape[-2], -2)
        return self.dec(rep, xt)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Training: negative mean log-likelihood over all points.
        Evaluation: context and target log-likelihoods, optionally reduced."""
        outs = AttrDict()
        pred = self.predict(batch.xc, batch.yc, batch.x)
        ll = pred.log_prob(batch.y).sum(-1)

        if self.training:
            outs.loss = -ll.mean()
            return outs

        num_ctx = batch.xc.shape[-2]
        ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
        if reduce_ll:
            ctx_ll, tar_ll = ctx_ll.mean(), tar_ll.mean()
        outs.ctx_ll = ctx_ll
        outs.tar_ll = tar_ll
        return outs
--------------------------------------------------------------------------------
/regression/models/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | class MultiHeadAttn(nn.Module):
7 | def __init__(self, dim_q, dim_k, dim_v, dim_out, num_heads=8):
8 | super().__init__()
9 | self.num_heads = num_heads
10 | self.dim_out = dim_out
11 | self.fc_q = nn.Linear(dim_q, dim_out, bias=False)
12 | self.fc_k = nn.Linear(dim_k, dim_out, bias=False)
13 | self.fc_v = nn.Linear(dim_v, dim_out, bias=False)
14 | self.fc_out = nn.Linear(dim_out, dim_out)
15 | self.ln1 = nn.LayerNorm(dim_out)
16 | self.ln2 = nn.LayerNorm(dim_out)
17 |
18 | def scatter(self, x):
19 | return torch.cat(x.chunk(self.num_heads, -1), -3)
20 |
21 | def gather(self, x):
22 | return torch.cat(x.chunk(self.num_heads, -3), -1)
23 |
24 | def attend(self, q, k, v, mask=None):
25 | q_, k_, v_ = [self.scatter(x) for x in [q, k, v]]
26 | A_logits = q_ @ k_.transpose(-2, -1) / math.sqrt(self.dim_out)
27 | if mask is not None:
28 | mask = mask.bool().to(q.device)
29 | mask = torch.stack([mask]*q.shape[-2], -2)
30 | mask = torch.cat([mask]*self.num_heads, -3)
31 | A = torch.softmax(A_logits.masked_fill(mask, -float('inf')), -1)
32 | A = A.masked_fill(torch.isnan(A), 0.0)
33 | else:
34 | A = torch.softmax(A_logits, -1)
35 | return self.gather(A @ v_)
36 |
37 | def forward(self, q, k, v, mask=None):
38 | q, k, v = self.fc_q(q), self.fc_k(k), self.fc_v(v)
39 | out = self.ln1(q + self.attend(q, k, v, mask=mask))
40 | out = self.ln2(out + F.relu(self.fc_out(out)))
41 | return out
42 |
class SelfAttn(MultiHeadAttn):
    # Self-attention: queries, keys and values are all the same input.
    def __init__(self, dim_in, dim_out, num_heads=8):
        super().__init__(dim_in, dim_in, dim_in, dim_out, num_heads)

    def forward(self, x, mask=None):
        return super().forward(x, x, x, mask=mask)
49 |
--------------------------------------------------------------------------------
/bayesian_optimization/models/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
class MultiHeadAttn(nn.Module):
    # Multi-head attention block with post-norm residual connections.
    # Note: logits are scaled by sqrt(dim_out), not the per-head dimension.
    def __init__(self, dim_q, dim_k, dim_v, dim_out, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.fc_q = nn.Linear(dim_q, dim_out, bias=False)
        self.fc_k = nn.Linear(dim_k, dim_out, bias=False)
        self.fc_v = nn.Linear(dim_v, dim_out, bias=False)
        self.fc_out = nn.Linear(dim_out, dim_out)
        self.ln1 = nn.LayerNorm(dim_out)
        self.ln2 = nn.LayerNorm(dim_out)

    def scatter(self, x):
        # Split the feature dim into num_heads chunks stacked along dim -3.
        return torch.cat(x.chunk(self.num_heads, -1), -3)

    def gather(self, x):
        # Inverse of scatter: merge the head chunks back into the feature dim.
        return torch.cat(x.chunk(self.num_heads, -3), -1)

    def attend(self, q, k, v, mask=None):
        q_, k_, v_ = [self.scatter(x) for x in [q, k, v]]
        A_logits = q_ @ k_.transpose(-2, -1) / math.sqrt(self.dim_out)
        if mask is not None:
            # Broadcast the key mask over queries and heads before masking.
            # Entries that are True/1 are hidden from attention — TODO
            # confirm the convention against callers.
            mask = mask.bool().to(q.device)
            mask = torch.stack([mask]*q.shape[-2], -2)
            mask = torch.cat([mask]*self.num_heads, -3)
            A = torch.softmax(A_logits.masked_fill(mask, -float('inf')), -1)
            # Fully-masked rows softmax to NaN; replace them with zeros.
            A = A.masked_fill(torch.isnan(A), 0.0)
        else:
            A = torch.softmax(A_logits, -1)
        return self.gather(A @ v_)

    def forward(self, q, k, v, mask=None):
        q, k, v = self.fc_q(q), self.fc_k(k), self.fc_v(v)
        out = self.ln1(q + self.attend(q, k, v, mask=mask))
        out = self.ln2(out + F.relu(self.fc_out(out)))
        return out
42 |
class SelfAttn(MultiHeadAttn):
    """Multi-head self-attention: queries, keys and values all come from x."""

    def __init__(self, dim_in, dim_out, num_heads=8):
        # Self-attention is cross-attention with q = k = v input dims.
        super().__init__(dim_in, dim_in, dim_in, dim_out, num_heads)

    def forward(self, x, mask=None):
        """Attend over x using x itself as query, key and value."""
        return super().forward(x, x, x, mask=mask)
49 |
--------------------------------------------------------------------------------
/bayesian_optimization/models/cnp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attrdict import AttrDict
4 |
5 | from models.modules import PoolingEncoder, Decoder
6 |
class CNP(nn.Module):
    """Conditional Neural Process.

    Two deterministic pooling encoders summarize the context set; their
    concatenated representations are tiled over the targets and fed to a
    decoder that outputs a predictive distribution.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        enc_kwargs = dict(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)
        # Two independent encoders; their outputs get concatenated.
        self.enc1 = PoolingEncoder(**enc_kwargs)
        self.enc2 = PoolingEncoder(**enc_kwargs)

        self.dec = Decoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_enc=2*dim_hid,
                dim_hid=dim_hid,
                depth=dec_depth)

    def predict(self, xc, yc, xt, num_samples=None):
        """Predict p(y | xt) conditioned on the context (xc, yc)."""
        rep = torch.cat([self.enc1(xc, yc), self.enc2(xc, yc)], -1)
        # Tile the global context representation across all target points.
        rep = torch.stack([rep] * xt.shape[-2], -2)
        return self.dec(rep, xt)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.x)
        ll = py.log_prob(batch.y).sum(-1)

        if self.training:
            outs.loss = -ll.mean()
            return outs

        # Evaluation: split the log-likelihood into context/target parts.
        num_ctx = batch.xc.shape[-2]
        ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
        if reduce_ll:
            outs.ctx_ll = ctx_ll.mean()
            outs.tar_ll = tar_ll.mean()
        else:
            outs.ctx_ll = ctx_ll
            outs.tar_ll = tar_ll
        return outs
61 |
--------------------------------------------------------------------------------
/regression/models/canp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attrdict import AttrDict
4 |
5 | from models.modules import CrossAttnEncoder, Decoder, PoolingEncoder
6 |
class CANP(nn.Module):
    """Conditional Attentive Neural Process.

    A cross-attention encoder yields a per-target representation and a
    self-attentive pooling encoder yields a global one; both condition
    the decoder.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        # Per-target deterministic path.
        self.enc1 = CrossAttnEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                v_depth=enc_v_depth,
                qk_depth=enc_qk_depth)

        # Global deterministic path with self-attention over the context.
        self.enc2 = PoolingEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                self_attn=True,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        self.dec = Decoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_enc=2*dim_hid,
                dim_hid=dim_hid,
                depth=dec_depth)

    def predict(self, xc, yc, xt, num_samples=None):
        """Predict p(y | xt) conditioned on the context (xc, yc)."""
        per_target = self.enc1(xc, yc, xt)
        global_rep = self.enc2(xc, yc)
        # Tile the global representation over targets before concatenating.
        global_rep = torch.stack([global_rep] * xt.shape[-2], -2)
        return self.dec(torch.cat([per_target, global_rep], -1), xt)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.x)
        ll = py.log_prob(batch.y).sum(-1)

        if self.training:
            outs.loss = -ll.mean()
            return outs

        # Evaluation: split the log-likelihood into context/target parts.
        num_ctx = batch.xc.shape[-2]
        ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
        if reduce_ll:
            outs.ctx_ll = ctx_ll.mean()
            outs.tar_ll = tar_ll.mean()
        else:
            outs.ctx_ll = ctx_ll
            outs.tar_ll = tar_ll
        return outs
66 |
--------------------------------------------------------------------------------
/bayesian_optimization/models/canp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attrdict import AttrDict
4 |
5 | from models.modules import CrossAttnEncoder, Decoder, PoolingEncoder
6 |
class CANP(nn.Module):
    """Conditional Attentive Neural Process.

    A cross-attention encoder yields a per-target representation and a
    self-attentive pooling encoder yields a global one; both condition
    the decoder.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        # Per-target deterministic path.
        self.enc1 = CrossAttnEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                v_depth=enc_v_depth,
                qk_depth=enc_qk_depth)

        # Global deterministic path with self-attention over the context.
        self.enc2 = PoolingEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                self_attn=True,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        self.dec = Decoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_enc=2*dim_hid,
                dim_hid=dim_hid,
                depth=dec_depth)

    def predict(self, xc, yc, xt, num_samples=None):
        """Predict p(y | xt) conditioned on the context (xc, yc)."""
        per_target = self.enc1(xc, yc, xt)
        global_rep = self.enc2(xc, yc)
        # Tile the global representation over targets before concatenating.
        global_rep = torch.stack([global_rep] * xt.shape[-2], -2)
        return self.dec(torch.cat([per_target, global_rep], -1), xt)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.x)
        ll = py.log_prob(batch.y).sum(-1)

        if self.training:
            outs.loss = -ll.mean()
            return outs

        # Evaluation: split the log-likelihood into context/target parts.
        num_ctx = batch.xc.shape[-2]
        ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
        if reduce_ll:
            outs.ctx_ll = ctx_ll.mean()
            outs.tar_ll = tar_ll.mean()
        else:
            outs.ctx_ll = ctx_ll
            outs.tar_ll = tar_ll
        return outs
66 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/regression/models/bnp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attrdict import AttrDict
4 |
5 | from models.cnp import CNP
6 | from utils.misc import stack, logmeanexp
7 | from utils.sampling import sample_with_replacement as SWR, sample_subset
8 |
class BNP(CNP):
    """Bootstrapping Neural Process built on top of CNP.

    Adds a residual-bootstrap path: the context is resampled with
    replacement, decoded, and residual-based pseudo-contexts are fed to
    the decoder as extra conditioning alongside the base encoding.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fix: the original read kwargs['dim_hid'] directly, which raised
        # KeyError whenever dim_hid was omitted (CNP defaults it to 128)
        # or passed positionally (it is CNP's 3rd positional parameter).
        dim_hid = kwargs.get('dim_hid', args[2] if len(args) > 2 else 128)
        self.dec.add_ctx(2*dim_hid)

    def encode(self, xc, yc, xt, mask=None):
        """Encode the context and tile the result over the target axis."""
        encoded = torch.cat([
            self.enc1(xc, yc, mask=mask),
            self.enc2(xc, yc, mask=mask)], -1)
        return stack(encoded, xt.shape[-2], -2)

    def predict(self, xc, yc, xt, num_samples=None, return_base=False):
        """Predict p(y | xt); returns (base, bootstrapped) during training."""
        with torch.no_grad():
            # Bootstrap the context by sampling with replacement.
            bxc, byc = SWR(xc, yc, num_samples=num_samples)
            sxc, syc = stack(xc, num_samples), stack(yc, num_samples)

            encoded = self.encode(bxc, byc, sxc)
            py_res = self.dec(encoded, sxc)

            # Resample standardized residuals and center them, then build
            # pseudo-context targets from the bootstrap predictions.
            mu, sigma = py_res.mean, py_res.scale
            res = SWR((syc - mu)/sigma).detach()
            res = (res - res.mean(-2, keepdim=True))

            bxc = sxc
            byc = mu + sigma * res

        encoded_base = self.encode(xc, yc, xt)

        sxt = stack(xt, num_samples)
        encoded_bs = self.encode(bxc, byc, sxt)

        # Decode with the bootstrapped encoding as additional context.
        py = self.dec(stack(encoded_base, num_samples),
                sxt, ctx=encoded_bs)

        if self.training or return_base:
            py_base = self.dec(encoded_base, xt)
            return py_base, py
        else:
            return py

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()

        def compute_ll(py, y):
            ll = py.log_prob(y).sum(-1)
            if ll.dim() == 3 and reduce_ll:
                # Average over the sample dimension in log space.
                ll = logmeanexp(ll)
            return ll

        if self.training:
            py_base, py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)

            outs.ll_base = compute_ll(py_base, batch.y).mean()
            outs.ll = compute_ll(py, batch.y).mean()
            outs.loss = -outs.ll_base - outs.ll
        else:
            py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)
            ll = compute_ll(py, batch.y)
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs
78 |
--------------------------------------------------------------------------------
/bayesian_optimization/models/bnp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attrdict import AttrDict
4 |
5 | from models.cnp import CNP
6 | from utils.misc import stack, logmeanexp
7 | from utils.sampling import sample_with_replacement as SWR, sample_subset
8 |
class BNP(CNP):
    """Bootstrapping Neural Process built on top of CNP.

    Adds a residual-bootstrap path: the context is resampled with
    replacement, decoded, and residual-based pseudo-contexts are fed to
    the decoder as extra conditioning alongside the base encoding.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fix: the original read kwargs['dim_hid'] directly, which raised
        # KeyError whenever dim_hid was omitted (CNP defaults it to 128)
        # or passed positionally (it is CNP's 3rd positional parameter).
        dim_hid = kwargs.get('dim_hid', args[2] if len(args) > 2 else 128)
        self.dec.add_ctx(2*dim_hid)

    def encode(self, xc, yc, xt, mask=None):
        """Encode the context and tile the result over the target axis."""
        encoded = torch.cat([
            self.enc1(xc, yc, mask=mask),
            self.enc2(xc, yc, mask=mask)], -1)
        return stack(encoded, xt.shape[-2], -2)

    def predict(self, xc, yc, xt, num_samples=None, return_base=False):
        """Predict p(y | xt); returns (base, bootstrapped) during training."""
        with torch.no_grad():
            # Bootstrap the context by sampling with replacement.
            bxc, byc = SWR(xc, yc, num_samples=num_samples)
            sxc, syc = stack(xc, num_samples), stack(yc, num_samples)

            encoded = self.encode(bxc, byc, sxc)
            py_res = self.dec(encoded, sxc)

            # Resample standardized residuals and center them, then build
            # pseudo-context targets from the bootstrap predictions.
            mu, sigma = py_res.mean, py_res.scale
            res = SWR((syc - mu)/sigma).detach()
            res = (res - res.mean(-2, keepdim=True))

            bxc = sxc
            byc = mu + sigma * res

        encoded_base = self.encode(xc, yc, xt)

        sxt = stack(xt, num_samples)
        encoded_bs = self.encode(bxc, byc, sxt)

        # Decode with the bootstrapped encoding as additional context.
        py = self.dec(stack(encoded_base, num_samples),
                sxt, ctx=encoded_bs)

        if self.training or return_base:
            py_base = self.dec(encoded_base, xt)
            return py_base, py
        else:
            return py

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()

        def compute_ll(py, y):
            ll = py.log_prob(y).sum(-1)
            if ll.dim() == 3 and reduce_ll:
                # Average over the sample dimension in log space.
                ll = logmeanexp(ll)
            return ll

        if self.training:
            py_base, py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)

            outs.ll_base = compute_ll(py_base, batch.y).mean()
            outs.ll = compute_ll(py, batch.y).mean()
            outs.loss = -outs.ll_base - outs.ll
        else:
            py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)
            ll = compute_ll(py, batch.y)
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs
78 |
--------------------------------------------------------------------------------
/regression/models/banp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attrdict import AttrDict
4 |
5 | from models.canp import CANP
6 | from utils.misc import stack, logmeanexp
7 | from utils.sampling import sample_with_replacement as SWR, sample_subset
8 |
class BANP(CANP):
    """Bootstrapping Attentive Neural Process built on top of CANP.

    Same residual-bootstrap scheme as BNP, but with CANP's attentive
    encoders producing the representations.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fix: the original read kwargs['dim_hid'] directly, which raised
        # KeyError whenever dim_hid was omitted (CANP defaults it to 128)
        # or passed positionally (it is CANP's 3rd positional parameter).
        dim_hid = kwargs.get('dim_hid', args[2] if len(args) > 2 else 128)
        self.dec.add_ctx(2*dim_hid)

    def encode(self, xc, yc, xt, mask=None):
        """Concatenate per-target and tiled global context encodings."""
        theta1 = self.enc1(xc, yc, xt)
        theta2 = self.enc2(xc, yc)
        encoded = torch.cat([theta1,
            torch.stack([theta2]*xt.shape[-2], -2)], -1)
        return encoded

    def predict(self, xc, yc, xt, num_samples=None, return_base=False):
        """Predict p(y | xt); returns (base, bootstrapped) during training."""
        with torch.no_grad():
            # Bootstrap the context by sampling with replacement.
            bxc, byc = SWR(xc, yc, num_samples=num_samples)
            sxc, syc = stack(xc, num_samples), stack(yc, num_samples)

            encoded = self.encode(bxc, byc, sxc)
            py_res = self.dec(encoded, sxc)

            # Resample standardized residuals and center them, then build
            # pseudo-context targets from the bootstrap predictions.
            mu, sigma = py_res.mean, py_res.scale
            res = SWR((syc - mu)/sigma).detach()
            res = (res - res.mean(-2, keepdim=True))

            bxc = sxc
            byc = mu + sigma * res

        encoded_base = self.encode(xc, yc, xt)

        sxt = stack(xt, num_samples)
        encoded_bs = self.encode(bxc, byc, sxt)

        # Decode with the bootstrapped encoding as additional context.
        py = self.dec(stack(encoded_base, num_samples),
                sxt, ctx=encoded_bs)

        if self.training or return_base:
            py_base = self.dec(encoded_base, xt)
            return py_base, py
        else:
            return py

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()

        def compute_ll(py, y):
            ll = py.log_prob(y).sum(-1)
            if ll.dim() == 3 and reduce_ll:
                # Average over the sample dimension in log space.
                ll = logmeanexp(ll)
            return ll

        if self.training:
            py_base, py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)

            outs.ll_base = compute_ll(py_base, batch.y).mean()
            outs.ll = compute_ll(py, batch.y).mean()
            outs.loss = -outs.ll_base - outs.ll
        else:
            py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)
            ll = compute_ll(py, batch.y)
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs
79 |
--------------------------------------------------------------------------------
/bayesian_optimization/models/banp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attrdict import AttrDict
4 |
5 | from models.canp import CANP
6 | from utils.misc import stack, logmeanexp
7 | from utils.sampling import sample_with_replacement as SWR, sample_subset
8 |
class BANP(CANP):
    """Bootstrapping Attentive Neural Process built on top of CANP.

    Same residual-bootstrap scheme as BNP, but with CANP's attentive
    encoders producing the representations.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fix: the original read kwargs['dim_hid'] directly, which raised
        # KeyError whenever dim_hid was omitted (CANP defaults it to 128)
        # or passed positionally (it is CANP's 3rd positional parameter).
        dim_hid = kwargs.get('dim_hid', args[2] if len(args) > 2 else 128)
        self.dec.add_ctx(2*dim_hid)

    def encode(self, xc, yc, xt, mask=None):
        """Concatenate per-target and tiled global context encodings."""
        theta1 = self.enc1(xc, yc, xt)
        theta2 = self.enc2(xc, yc)
        encoded = torch.cat([theta1,
            torch.stack([theta2]*xt.shape[-2], -2)], -1)
        return encoded

    def predict(self, xc, yc, xt, num_samples=None, return_base=False):
        """Predict p(y | xt); returns (base, bootstrapped) during training."""
        with torch.no_grad():
            # Bootstrap the context by sampling with replacement.
            bxc, byc = SWR(xc, yc, num_samples=num_samples)
            sxc, syc = stack(xc, num_samples), stack(yc, num_samples)

            encoded = self.encode(bxc, byc, sxc)
            py_res = self.dec(encoded, sxc)

            # Resample standardized residuals and center them, then build
            # pseudo-context targets from the bootstrap predictions.
            mu, sigma = py_res.mean, py_res.scale
            res = SWR((syc - mu)/sigma).detach()
            res = (res - res.mean(-2, keepdim=True))

            bxc = sxc
            byc = mu + sigma * res

        encoded_base = self.encode(xc, yc, xt)

        sxt = stack(xt, num_samples)
        encoded_bs = self.encode(bxc, byc, sxt)

        # Decode with the bootstrapped encoding as additional context.
        py = self.dec(stack(encoded_base, num_samples),
                sxt, ctx=encoded_bs)

        if self.training or return_base:
            py_base = self.dec(encoded_base, xt)
            return py_base, py
        else:
            return py

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()

        def compute_ll(py, y):
            ll = py.log_prob(y).sum(-1)
            if ll.dim() == 3 and reduce_ll:
                # Average over the sample dimension in log space.
                ll = logmeanexp(ll)
            return ll

        if self.training:
            py_base, py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)

            outs.ll_base = compute_ll(py_base, batch.y).mean()
            outs.ll = compute_ll(py, batch.y).mean()
            outs.loss = -outs.ll_base - outs.ll
        else:
            py = self.predict(batch.xc, batch.yc, batch.x,
                    num_samples=num_samples)
            ll = compute_ll(py, batch.y)
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs
79 |
--------------------------------------------------------------------------------
/regression/data/celeba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os.path as osp
3 | import argparse
4 |
5 | from utils.paths import datasets_path
6 | from utils.misc import gen_load_func
7 |
class CelebA(object):
    """CelebA 32x32 images with identity labels, loaded from the
    preprocessed train.pt / eval.pt tensor files."""

    def __init__(self, train=True):
        # map_location='cpu' keeps loading robust even if the tensors
        # were saved from a CUDA device.
        self.data, self.targets = torch.load(
                osp.join(datasets_path, 'celeba',
                    'train.pt' if train else 'eval.pt'),
                map_location='cpu')
        # Pixels are stored as integer values; normalize to [0, 1].
        self.data = self.data.float() / 255.0
        # Fix: removed a dead if/else that assigned self.data and
        # self.targets to themselves in both branches.

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]
25 |
if __name__ == '__main__':

    # preprocess
    # before proceeding, download img_celeba.7z from
    # https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg
    # ,download list_eval_partitions.txt from
    # https://drive.google.com/drive/folders/0B7EVK8r0v71pdjI3dmwtNm5jRkE
    # and download identity_CelebA.txt from
    # https://drive.google.com/drive/folders/0B7EVK8r0v71pOC0wOVZlQnFfaGs
    # and place them in ${datasets_path}/celeba folder.

    import os
    import os.path as osp
    from PIL import Image
    from tqdm import tqdm
    import numpy as np
    import torch

    # load train/val/test split
    # maps image filename -> partition id from the official split file
    splitdict = {}
    with open(osp.join(datasets_path, 'celeba', 'list_eval_partition.txt'), 'r') as f:
        for line in f:
            fn, split = line.split()
            splitdict[fn] = int(split)

    # load identities
    # maps image filename -> integer identity label
    iddict = {}
    with open(osp.join(datasets_path, 'celeba', 'identity_CelebA.txt'), 'r') as f:
        for line in f:
            fn, label = line.split()
            iddict[fn] = int(label)

    train_imgs = []
    train_labels = []
    eval_imgs = []
    eval_labels = []
    path = osp.join(datasets_path, 'celeba', 'img_align_celeba')
    imgfilenames = os.listdir(path)
    for fn in tqdm(imgfilenames):

        # downsample to 32x32; transpose to channels-first (C, H, W)
        img = Image.open(osp.join(path, fn)).resize((32, 32))
        # NOTE(review): partition id 2 is routed to eval and everything
        # else (presumably train/val) to train -- confirm this matches
        # the official list_eval_partition semantics.
        if splitdict[fn] == 2:
            eval_imgs.append(torch.LongTensor(np.array(img).transpose(2, 0, 1)))
            eval_labels.append(iddict[fn])
        else:
            train_imgs.append(torch.LongTensor(np.array(img).transpose(2, 0, 1)))
            train_labels.append(iddict[fn])

    print(f'{len(train_imgs)} train, {len(eval_imgs)} eval')

    # save both splits as [images, labels] tensor pairs
    train_imgs = torch.stack(train_imgs)
    train_labels = torch.LongTensor(train_labels)
    torch.save([train_imgs, train_labels], osp.join(datasets_path, 'celeba', 'train.pt'))

    eval_imgs = torch.stack(eval_imgs)
    eval_labels = torch.LongTensor(eval_labels)
    torch.save([eval_imgs, eval_labels], osp.join(datasets_path, 'celeba', 'eval.pt'))
83 |
--------------------------------------------------------------------------------
/regression/data/image.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from attrdict import AttrDict
3 | from torch.utils.data import DataLoader
4 | from torch.distributions import StudentT, Normal
5 |
def img_to_task(img, num_ctx=None,
        max_num_points=None, target_all=False, t_noise=None, device=None):
    """Convert a batch of images into a 2D-regression task.

    Random pixels become (x, y) pairs: x is the pixel coordinate scaled
    to [-1, 1], y is the (optionally noise-augmented) pixel value shifted
    to be roughly centered at 0. The first num_ctx points form the
    context set, the rest the targets.

    Args:
        img: (B, C, H, W) image tensor.
        num_ctx: number of context points; random in [3, max-3) if None.
        max_num_points: cap on total points (defaults to H*W).
        target_all: if True, targets are all remaining points.
        t_noise: Student-t noise scale; -1 draws a random per-pixel scale.
        device: device for the returned task (defaults to img's device).
    """
    B, C, H, W = img.shape
    num_pixels = H*W
    img = img.view(B, C, -1)

    if t_noise is not None:
        if t_noise == -1:
            t_noise = 0.09 * torch.rand(img.shape)
        # Fix: the original used `img += ...`, an in-place add through a
        # view that corrupted the caller's image tensor.
        img = img + t_noise * StudentT(2.1).rsample(img.shape)

    device = img.device if device is None else device

    batch = AttrDict()
    max_num_points = max_num_points or num_pixels
    num_ctx = num_ctx or \
            torch.randint(low=3, high=max_num_points-3, size=[1]).item()
    num_tar = max_num_points - num_ctx if target_all else \
            torch.randint(low=3, high=max_num_points-num_ctx, size=[1]).item()
    num_points = num_ctx + num_tar
    # Draw a random subset of pixel indices per batch element.
    idxs = torch.rand(B, num_pixels).argsort(-1)[...,:num_points].to(img.device)
    x1, x2 = idxs//W, idxs%W
    batch.x = torch.stack([
        2*x1.float()/(H-1) - 1,
        2*x2.float()/(W-1) - 1], -1).to(device)
    batch.y = (torch.gather(img, -1, idxs.unsqueeze(-2).repeat(1, C, 1))\
            .transpose(-2, -1) - 0.5).to(device)

    batch.xc = batch.x[:,:num_ctx]
    batch.xt = batch.x[:,num_ctx:]
    batch.yc = batch.y[:,:num_ctx]
    batch.yt = batch.y[:,num_ctx:]

    return batch
41 |
def coord_to_img(x, y, shape):
    """Render scattered (coordinate, value) points onto a fixed-color canvas.

    Coordinates in [-1, 1] are mapped back to pixel indices; grayscale
    values (C == 1) are broadcast across all three RGB channels.
    """
    x, y = x.cpu(), y.cpu()
    B = x.shape[0]
    C, H, W = shape

    # Canvas pre-filled with the background color (R, G, B).
    canvas = torch.zeros(B, 3, H, W)
    for c, bg in enumerate((0.61, 0.55, 0.71)):
        canvas[:, c, :, :] = bg

    # Map [-1, 1] coordinates back to integer row/column indices.
    rows = ((x[..., 0] + 1) * (H - 1) / 2).round().long()
    cols = ((x[..., 1] + 1) * (W - 1) / 2).round().long()
    for b in range(B):
        for c in range(3):
            canvas[b, c, rows[b], cols[b]] = y[b, :, min(c, C - 1)]

    return canvas
61 |
def task_to_img(xc, yc, xt, yt, shape):
    """Render a task and its completion as images.

    Returns (task_img, completed_img): the first shows only the context
    pixels on a colored background, the second additionally paints the
    target values.
    """
    xc, yc = xc.cpu(), yc.cpu()
    xt, yt = xt.cpu(), yt.cpu()

    B = xc.shape[0]
    C, H, W = shape

    def to_pixels(pts):
        # Map [-1, 1] coordinates back to integer row/column indices.
        r = ((pts[..., 0] + 1) * (H - 1) / 2).round().long()
        c = ((pts[..., 1] + 1) * (W - 1) / 2).round().long()
        return r, c

    xc1, xc2 = to_pixels(xc)
    xt1, xt2 = to_pixels(xt)

    # Background color marks unobserved pixels.
    task_img = torch.zeros(B, 3, H, W).to(xc.device)
    task_img[:, 2, :, :] = 1.0
    task_img[:, 1, :, :] = 0.4
    for b in range(B):
        for c in range(3):
            task_img[b, c, xc1[b], xc2[b]] = yc[b, :, min(c, C - 1)] + 0.5
    task_img = task_img.clamp(0, 1)

    # Completion starts from the task image and adds target pixels.
    completed_img = task_img.clone()
    for b in range(B):
        for c in range(3):
            completed_img[b, c, xt1[b], xt2[b]] = yt[b, :, min(c, C - 1)] + 0.5
    completed_img = completed_img.clamp(0, 1)

    return task_img, completed_img
--------------------------------------------------------------------------------
/regression/models/np.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.distributions import kl_divergence
4 | from attrdict import AttrDict
5 |
6 | from utils.misc import stack, logmeanexp
7 | from utils.sampling import sample_subset
8 | from models.modules import PoolingEncoder, Decoder
9 |
class NP(nn.Module):
    """Neural Process.

    A deterministic pooling encoder and a latent pooling encoder
    summarize the context; the decoder conditions on the deterministic
    representation concatenated with a sampled latent z.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        # Deterministic path.
        self.denc = PoolingEncoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_hid=dim_hid,
            pre_depth=enc_pre_depth,
            post_depth=enc_post_depth)

        # Latent path: dim_lat makes the encoder output a distribution.
        self.lenc = PoolingEncoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_hid=dim_hid,
            dim_lat=dim_lat,
            pre_depth=enc_pre_depth,
            post_depth=enc_post_depth)

        self.dec = Decoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_enc=dim_hid+dim_lat,
            dim_hid=dim_hid,
            depth=dec_depth)

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Decode p(y | xt); samples z from the prior when not given."""
        theta = stack(self.denc(xc, yc), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        encoded = torch.cat([theta, z], -1)
        encoded = stack(encoded, xt.shape[-2], -2)
        return self.dec(encoded, stack(xt, num_samples))

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()
        if self.training:
            pz = self.lenc(batch.xc, batch.yc)
            qz = self.lenc(batch.x, batch.y)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc, batch.yc, batch.x,
                    z=z, num_samples=num_samples)

            # Fix: the original tested `num_samples > 1` directly, which
            # raised TypeError for the default num_samples=None.
            if num_samples is not None and num_samples > 1:
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B: importance weights for the multi-sample bound.
                log_w = recon.sum(-1) + log_pz - log_qz

                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                # Single-sample ELBO: reconstruction minus scaled KL.
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y]*num_samples)
                if reduce_ll:
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]
        return outs
98 |
--------------------------------------------------------------------------------
/bayesian_optimization/models/np.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.distributions import kl_divergence
4 | from attrdict import AttrDict
5 |
6 | from utils.misc import stack, logmeanexp
7 | from utils.sampling import sample_subset
8 | from models.modules import PoolingEncoder, Decoder
9 |
class NP(nn.Module):
    """Neural Process.

    A deterministic pooling encoder and a latent pooling encoder
    summarize the context; the decoder conditions on the deterministic
    representation concatenated with a sampled latent z.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        # Deterministic path.
        self.denc = PoolingEncoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_hid=dim_hid,
            pre_depth=enc_pre_depth,
            post_depth=enc_post_depth)

        # Latent path: dim_lat makes the encoder output a distribution.
        self.lenc = PoolingEncoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_hid=dim_hid,
            dim_lat=dim_lat,
            pre_depth=enc_pre_depth,
            post_depth=enc_post_depth)

        self.dec = Decoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_enc=dim_hid+dim_lat,
            dim_hid=dim_hid,
            depth=dec_depth)

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Decode p(y | xt); samples z from the prior when not given."""
        theta = stack(self.denc(xc, yc), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        encoded = torch.cat([theta, z], -1)
        encoded = stack(encoded, xt.shape[-2], -2)
        return self.dec(encoded, stack(xt, num_samples))

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()
        if self.training:
            pz = self.lenc(batch.xc, batch.yc)
            qz = self.lenc(batch.x, batch.y)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc, batch.yc, batch.x,
                    z=z, num_samples=num_samples)

            # Fix: the original tested `num_samples > 1` directly, which
            # raised TypeError for the default num_samples=None.
            if num_samples is not None and num_samples > 1:
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B: importance weights for the multi-sample bound.
                log_w = recon.sum(-1) + log_pz - log_qz

                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                # Single-sample ELBO: reconstruction minus scaled KL.
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y]*num_samples)
                if reduce_ll:
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]
        return outs
98 |
--------------------------------------------------------------------------------
/regression/models/anp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.distributions import kl_divergence
4 | from attrdict import AttrDict
5 |
6 | from utils.misc import stack, logmeanexp
7 | from utils.sampling import sample_subset
8 |
9 | from models.modules import CrossAttnEncoder, PoolingEncoder, Decoder
10 |
class ANP(nn.Module):
    """Attentive Neural Process.

    A cross-attention encoder gives per-target deterministic features; a
    self-attentive latent pooling encoder gives a global latent z. The
    decoder conditions on both.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        # Deterministic per-target path.
        self.denc = CrossAttnEncoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_hid=dim_hid,
            v_depth=enc_v_depth,
            qk_depth=enc_qk_depth)

        # Latent path: dim_lat makes the encoder output a distribution.
        self.lenc = PoolingEncoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_hid=dim_hid,
            dim_lat=dim_lat,
            self_attn=True,
            pre_depth=enc_pre_depth,
            post_depth=enc_post_depth)

        self.dec = Decoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_enc=dim_hid+dim_lat,
            dim_hid=dim_hid,
            depth=dec_depth)

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Decode p(y | xt); samples z from the prior when not given."""
        theta = stack(self.denc(xc, yc, xt), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        # Tile the global latent over the target axis before decoding.
        z = stack(z, xt.shape[-2], -2)
        encoded = torch.cat([theta, z], -1)
        return self.dec(encoded, stack(xt, num_samples))

    def forward(self, batch, num_samples=None, reduce_ll=True):
        outs = AttrDict()
        if self.training:
            pz = self.lenc(batch.xc, batch.yc)
            qz = self.lenc(batch.x, batch.y)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc, batch.yc, batch.x,
                    z=z, num_samples=num_samples)

            # Fix: the original tested `num_samples > 1` directly, which
            # raised TypeError for the default num_samples=None.
            if num_samples is not None and num_samples > 1:
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B: importance weights for the multi-sample bound.
                log_w = recon.sum(-1) + log_pz - log_qz

                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                # Single-sample ELBO: reconstruction minus scaled KL.
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y]*num_samples)
                if reduce_ll:
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            num_ctx = batch.xc.shape[-2]

            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs
104 |
--------------------------------------------------------------------------------
/bayesian_optimization/models/anp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.distributions import kl_divergence
4 | from attrdict import AttrDict
5 |
6 | from utils.misc import stack, logmeanexp
7 | from utils.sampling import sample_subset
8 |
9 | from models.modules import CrossAttnEncoder, PoolingEncoder, Decoder
10 |
class ANP(nn.Module):
    """Attentive Neural Process.

    Combines a deterministic cross-attention path (``denc``) with a global
    latent variable (``lenc``); the decoder conditions on their
    concatenation to produce a Normal predictive distribution per target.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        # deterministic path: target points attend over context points
        self.denc = CrossAttnEncoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_hid=dim_hid,
            v_depth=enc_v_depth,
            qk_depth=enc_qk_depth)

        # latent path: permutation-invariant pooling -> Normal over z
        self.lenc = PoolingEncoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_hid=dim_hid,
            dim_lat=dim_lat,
            self_attn=True,
            pre_depth=enc_pre_depth,
            post_depth=enc_post_depth)

        self.dec = Decoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_enc=dim_hid+dim_lat,
            dim_hid=dim_hid,
            depth=dec_depth)

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Return the predictive distribution over y at targets ``xt``.

        If ``z`` is not supplied it is sampled from the context-conditioned
        prior; ``num_samples`` (None = one unstacked sample) controls how
        many latent samples are drawn and stacked along a leading axis.
        """
        theta = stack(self.denc(xc, yc, xt), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        # broadcast the global latent across all target points
        z = stack(z, xt.shape[-2], -2)
        encoded = torch.cat([theta, z], -1)
        return self.dec(encoded, stack(xt, num_samples))

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Compute training loss or evaluation log-likelihoods for ``batch``."""
        outs = AttrDict()
        if self.training:
            pz = self.lenc(batch.xc, batch.yc)
            qz = self.lenc(batch.x, batch.y)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc, batch.yc, batch.x,
                    z=z, num_samples=num_samples)

            # fix: guard None before the numeric comparison; plain
            # `num_samples > 1` raised TypeError when num_samples is None
            if num_samples is not None and num_samples > 1:
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B importance weights (multi-sample bound)
                log_w = recon.sum(-1) + log_pz - log_qz

                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                # single-sample ELBO
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y]*num_samples)
                if reduce_ll:
                    # average the per-sample likelihoods in log space
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            num_ctx = batch.xc.shape[-2]

            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs
104 |
--------------------------------------------------------------------------------
/regression/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Normal
5 |
6 | from models.attention import MultiHeadAttn, SelfAttn
7 |
8 | __all__ = ['PoolingEncoder', 'CrossAttnEncoder', 'Decoder']
9 |
def build_mlp(dim_in, dim_hid, dim_out, depth):
    """Build an MLP with `depth` Linear layers, ReLU between them,
    and no activation after the final (output) layer."""
    layers = [nn.Linear(dim_in, dim_hid), nn.ReLU(True)]
    for _ in range(depth - 2):
        layers += [nn.Linear(dim_hid, dim_hid), nn.ReLU(True)]
    layers += [nn.Linear(dim_hid, dim_out)]
    return nn.Sequential(*layers)
17 |
class PoolingEncoder(nn.Module):
    """Permutation-invariant set encoder.

    Applies a per-point MLP (optionally followed by self-attention),
    mean-pools over points, then a post-MLP. When ``dim_lat`` is given,
    returns a Normal over a latent vector; otherwise a plain tensor.
    """

    def __init__(self, dim_x=1, dim_y=1,
            dim_hid=128, dim_lat=None, self_attn=False,
            pre_depth=4, post_depth=2):
        super().__init__()

        self.use_lat = dim_lat is not None

        if self_attn:
            # shallower MLP, with self-attention mixing the points
            self.net_pre = nn.Sequential(
                    build_mlp(dim_x+dim_y, dim_hid, dim_hid, pre_depth-2),
                    nn.ReLU(True),
                    SelfAttn(dim_hid, dim_hid))
        else:
            self.net_pre = build_mlp(dim_x+dim_y, dim_hid, dim_hid, pre_depth)

        dim_out = 2*dim_lat if self.use_lat else dim_hid
        self.net_post = build_mlp(dim_hid, dim_hid, dim_out, post_depth)

    def forward(self, xc, yc, mask=None):
        hid = self.net_pre(torch.cat([xc, yc], -1))
        if mask is None:
            pooled = hid.mean(-2)
        else:
            # masked mean over the point dimension; small epsilon guards
            # against an all-zero mask
            mask = mask.to(xc.device)
            denom = mask.sum(-1, keepdim=True).detach() + 1e-5
            pooled = (hid * mask.unsqueeze(-1)).sum(-2) / denom
        if not self.use_lat:
            return self.net_post(pooled)
        mu, sigma = self.net_post(pooled).chunk(2, -1)
        # squash sigma into (0.1, 1.0) for numerical stability
        sigma = 0.1 + 0.9 * torch.sigmoid(sigma)
        return Normal(mu, sigma)
51 |
class CrossAttnEncoder(nn.Module):
    """Per-target encoder using multi-head cross-attention.

    Targets provide the queries, contexts the keys; values come from an
    MLP over (x, y) context pairs, optionally refined by self-attention.
    Returns a Normal when ``dim_lat`` is given, otherwise a tensor.
    """

    def __init__(self, dim_x=1, dim_y=1, dim_hid=128,
            dim_lat=None, self_attn=True,
            v_depth=4, qk_depth=2):
        super().__init__()
        self.use_lat = dim_lat is not None

        if self_attn:
            # shallower value MLP followed by self-attention over contexts
            self.net_v = build_mlp(dim_x+dim_y, dim_hid, dim_hid, v_depth-2)
            self.self_attn = SelfAttn(dim_hid, dim_hid)
        else:
            self.net_v = build_mlp(dim_x+dim_y, dim_hid, dim_hid, v_depth)

        # shared MLP computes both queries (from xt) and keys (from xc)
        self.net_qk = build_mlp(dim_x, dim_hid, dim_hid, qk_depth)

        dim_out = 2*dim_lat if self.use_lat else dim_hid
        self.attn = MultiHeadAttn(dim_hid, dim_hid, dim_hid, dim_out)

    def forward(self, xc, yc, xt, mask=None):
        q = self.net_qk(xt)
        k = self.net_qk(xc)
        v = self.net_v(torch.cat([xc, yc], -1))

        # the self_attn submodule only exists when built with self_attn=True
        if hasattr(self, 'self_attn'):
            v = self.self_attn(v, mask=mask)

        out = self.attn(q, k, v, mask=mask)
        if not self.use_lat:
            return out
        mu, sigma = out.chunk(2, -1)
        # squash sigma into (0.1, 1.0) for numerical stability
        sigma = 0.1 + 0.9 * torch.sigmoid(sigma)
        return Normal(mu, sigma)
84 |
class Decoder(nn.Module):
    """Map an encoded representation plus target x to a Normal over y.

    ``add_ctx`` optionally registers a bias-free projection whose output
    is added into the first hidden layer at forward time.
    """

    def __init__(self, dim_x=1, dim_y=1,
            dim_enc=128, dim_hid=128, depth=3):
        super().__init__()
        self.fc = nn.Linear(dim_x+dim_enc, dim_hid)
        self.dim_hid = dim_hid

        layers = [nn.ReLU(True)]
        for _ in range(depth-2):
            layers += [nn.Linear(dim_hid, dim_hid), nn.ReLU(True)]
        layers.append(nn.Linear(dim_hid, 2*dim_y))
        self.mlp = nn.Sequential(*layers)

    def add_ctx(self, dim_ctx):
        """Register an extra context projection of width ``dim_ctx``."""
        self.dim_ctx = dim_ctx
        self.fc_ctx = nn.Linear(dim_ctx, self.dim_hid, bias=False)

    def forward(self, encoded, x, ctx=None):
        hid = self.fc(torch.cat([encoded, x], -1))
        if ctx is not None:
            hid = hid + self.fc_ctx(ctx)
        mu, sigma = self.mlp(hid).chunk(2, -1)
        # softplus keeps sigma positive; the 0.1 floor prevents collapse
        sigma = 0.1 + 0.9 * F.softplus(sigma)
        return Normal(mu, sigma)
112 |
--------------------------------------------------------------------------------
/bayesian_optimization/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Normal
5 |
6 | from models.attention import MultiHeadAttn, SelfAttn
7 |
8 | __all__ = ['PoolingEncoder', 'CrossAttnEncoder', 'Decoder']
9 |
def build_mlp(dim_in, dim_hid, dim_out, depth):
    """Build an MLP with `depth` Linear layers, ReLU between them,
    and no activation after the final (output) layer."""
    layers = [nn.Linear(dim_in, dim_hid), nn.ReLU(True)]
    for _ in range(depth - 2):
        layers += [nn.Linear(dim_hid, dim_hid), nn.ReLU(True)]
    layers += [nn.Linear(dim_hid, dim_out)]
    return nn.Sequential(*layers)
17 |
class PoolingEncoder(nn.Module):
    """Permutation-invariant set encoder.

    Applies a per-point MLP (optionally followed by self-attention),
    mean-pools over points, then a post-MLP. When ``dim_lat`` is given,
    returns a Normal over a latent vector; otherwise a plain tensor.
    """

    def __init__(self, dim_x=1, dim_y=1,
            dim_hid=128, dim_lat=None, self_attn=False,
            pre_depth=4, post_depth=2):
        super().__init__()

        self.use_lat = dim_lat is not None

        if self_attn:
            # shallower MLP, with self-attention mixing the points
            self.net_pre = nn.Sequential(
                    build_mlp(dim_x+dim_y, dim_hid, dim_hid, pre_depth-2),
                    nn.ReLU(True),
                    SelfAttn(dim_hid, dim_hid))
        else:
            self.net_pre = build_mlp(dim_x+dim_y, dim_hid, dim_hid, pre_depth)

        dim_out = 2*dim_lat if self.use_lat else dim_hid
        self.net_post = build_mlp(dim_hid, dim_hid, dim_out, post_depth)

    def forward(self, xc, yc, mask=None):
        hid = self.net_pre(torch.cat([xc, yc], -1))
        if mask is None:
            pooled = hid.mean(-2)
        else:
            # masked mean over the point dimension; small epsilon guards
            # against an all-zero mask
            mask = mask.to(xc.device)
            denom = mask.sum(-1, keepdim=True).detach() + 1e-5
            pooled = (hid * mask.unsqueeze(-1)).sum(-2) / denom
        if not self.use_lat:
            return self.net_post(pooled)
        mu, sigma = self.net_post(pooled).chunk(2, -1)
        # squash sigma into (0.1, 1.0) for numerical stability
        sigma = 0.1 + 0.9 * torch.sigmoid(sigma)
        return Normal(mu, sigma)
51 |
class CrossAttnEncoder(nn.Module):
    """Per-target encoder using multi-head cross-attention.

    Targets provide the queries, contexts the keys; values come from an
    MLP over (x, y) context pairs, optionally refined by self-attention.
    Returns a Normal when ``dim_lat`` is given, otherwise a tensor.
    """

    def __init__(self, dim_x=1, dim_y=1, dim_hid=128,
            dim_lat=None, self_attn=True,
            v_depth=4, qk_depth=2):
        super().__init__()
        self.use_lat = dim_lat is not None

        if self_attn:
            # shallower value MLP followed by self-attention over contexts
            self.net_v = build_mlp(dim_x+dim_y, dim_hid, dim_hid, v_depth-2)
            self.self_attn = SelfAttn(dim_hid, dim_hid)
        else:
            self.net_v = build_mlp(dim_x+dim_y, dim_hid, dim_hid, v_depth)

        # shared MLP computes both queries (from xt) and keys (from xc)
        self.net_qk = build_mlp(dim_x, dim_hid, dim_hid, qk_depth)

        dim_out = 2*dim_lat if self.use_lat else dim_hid
        self.attn = MultiHeadAttn(dim_hid, dim_hid, dim_hid, dim_out)

    def forward(self, xc, yc, xt, mask=None):
        q = self.net_qk(xt)
        k = self.net_qk(xc)
        v = self.net_v(torch.cat([xc, yc], -1))

        # the self_attn submodule only exists when built with self_attn=True
        if hasattr(self, 'self_attn'):
            v = self.self_attn(v, mask=mask)

        out = self.attn(q, k, v, mask=mask)
        if not self.use_lat:
            return out
        mu, sigma = out.chunk(2, -1)
        # squash sigma into (0.1, 1.0) for numerical stability
        sigma = 0.1 + 0.9 * torch.sigmoid(sigma)
        return Normal(mu, sigma)
84 |
class Decoder(nn.Module):
    """Map an encoded representation plus target x to a Normal over y.

    ``add_ctx`` optionally registers a bias-free projection whose output
    is added into the first hidden layer at forward time.
    """

    def __init__(self, dim_x=1, dim_y=1,
            dim_enc=128, dim_hid=128, depth=3):
        super().__init__()
        self.fc = nn.Linear(dim_x+dim_enc, dim_hid)
        self.dim_hid = dim_hid

        layers = [nn.ReLU(True)]
        for _ in range(depth-2):
            layers += [nn.Linear(dim_hid, dim_hid), nn.ReLU(True)]
        layers.append(nn.Linear(dim_hid, 2*dim_y))
        self.mlp = nn.Sequential(*layers)

    def add_ctx(self, dim_ctx):
        """Register an extra context projection of width ``dim_ctx``."""
        self.dim_ctx = dim_ctx
        self.fc_ctx = nn.Linear(dim_ctx, self.dim_hid, bias=False)

    def forward(self, encoded, x, ctx=None):
        hid = self.fc(torch.cat([encoded, x], -1))
        if ctx is not None:
            hid = hid + self.fc_ctx(ctx)
        mu, sigma = self.mlp(hid).chunk(2, -1)
        # softplus keeps sigma positive; the 0.1 floor prevents collapse
        sigma = 0.1 + 0.9 * F.softplus(sigma)
        return Normal(mu, sigma)
112 |
--------------------------------------------------------------------------------
/regression/data/gp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import MultivariateNormal, StudentT
5 | from attrdict import AttrDict
6 | import math
7 |
8 | __all__ = ['GPSampler', 'RBFKernel', 'PeriodicKernel', 'Matern52Kernel']
9 |
class GPSampler(object):
    """Sample synthetic regression tasks from a GP prior with the given
    kernel, optionally corrupted by Student-t noise."""

    def __init__(self, kernel, t_noise=None):
        # t_noise: None (no noise), -1 (random per-point noise scale),
        # or a fixed positive scale for the Student-t corruption
        self.kernel = kernel
        self.t_noise = t_noise

    def sample(self,
            batch_size=16,
            num_ctx=None,
            max_num_points=50,
            x_range=(-2, 2),
            device='cpu'):
        """Return an AttrDict with fields x/y and context/target splits
        xc/xt/yc/yt; context sizes are random unless num_ctx is given."""

        batch = AttrDict()
        # fix: test against None rather than truthiness so an explicit
        # num_ctx=0 is honored instead of being silently re-randomized
        if num_ctx is None:
            num_ctx = torch.randint(low=3, high=max_num_points-3, size=[1]).item()
        num_tar = torch.randint(low=3, high=max_num_points-num_ctx, size=[1]).item()

        num_points = num_ctx + num_tar
        batch.x = x_range[0] + (x_range[1] - x_range[0]) \
                * torch.rand([batch_size, num_points, 1], device=device)
        batch.xc = batch.x[:,:num_ctx]
        batch.xt = batch.x[:,num_ctx:]

        # batch_size * num_points * num_points
        cov = self.kernel(batch.x)
        mean = torch.zeros(batch_size, num_points, device=device)
        batch.y = MultivariateNormal(mean, cov).rsample().unsqueeze(-1)
        batch.yc = batch.y[:,:num_ctx]
        batch.yt = batch.y[:,num_ctx:]

        if self.t_noise is not None:
            if self.t_noise == -1:
                # heteroscedastic: per-point noise scale drawn from [0, 0.15)
                t_noise = 0.15 * torch.rand(batch.y.shape).to(device)
            else:
                t_noise = self.t_noise
            batch.y += t_noise * StudentT(2.1).rsample(batch.y.shape).to(device)
        return batch
46 |
class RBFKernel(object):
    """Squared-exponential kernel with per-task random length-scale and
    output scale, plus diagonal jitter sigma_eps**2."""

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    # x: batch_size * num_points * dim
    def __call__(self, x):
        B = x.shape[0]
        # one random length-scale / output scale per task in [0.1, max)
        length = 0.1 + (self.max_length-0.1) \
                * torch.rand([B, 1, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale-0.1) \
                * torch.rand([B, 1, 1], device=x.device)

        # pairwise scaled differences: B * N * N * dim
        diff = (x.unsqueeze(-2) - x.unsqueeze(-3)) / length

        jitter = self.sigma_eps**2 * torch.eye(x.shape[-2]).to(x.device)
        # B * N * N covariance
        return scale.pow(2) * torch.exp(-0.5 * diff.pow(2).sum(-1)) + jitter
68 |
class Matern52Kernel(object):
    """Matern-5/2 kernel with per-task random length-scale and output
    scale, plus diagonal jitter sigma_eps**2."""

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    # x: batch_size * num_points * dim
    def __call__(self, x):
        B, N = x.shape[0], x.shape[-2]
        # one random length-scale / output scale per task in [0.1, max)
        length = 0.1 + (self.max_length-0.1) \
                * torch.rand([B, 1, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale-0.1) \
                * torch.rand([B, 1, 1], device=x.device)

        # scaled pairwise Euclidean distances: B * N * N
        r = torch.norm((x.unsqueeze(-2) - x.unsqueeze(-3)) / length, dim=-1)

        sqrt5 = math.sqrt(5.0)
        poly = 1 + sqrt5*r + 5.0*r.pow(2)/3.0
        jitter = self.sigma_eps**2 * torch.eye(N).to(x.device)
        return scale.pow(2) * poly * torch.exp(-sqrt5 * r) + jitter
90 |
class PeriodicKernel(object):
    """Exp-sine-squared (periodic) kernel with per-task random period,
    length-scale and output scale, plus diagonal jitter sigma_eps**2."""

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    # x: batch_size * num_points * dim
    def __call__(self, x):
        B, N = x.shape[0], x.shape[-2]
        # per-task random period in [0.1, 0.5)
        p = 0.1 + 0.4*torch.rand([B, 1, 1], device=x.device)
        length = 0.1 + (self.max_length-0.1) \
                * torch.rand([B, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale-0.1) \
                * torch.rand([B, 1, 1], device=x.device)

        dist = x.unsqueeze(-2) - x.unsqueeze(-3)
        sin_term = torch.sin(math.pi * dist.abs().sum(-1) / p) / length
        jitter = self.sigma_eps**2 * torch.eye(N).to(x.device)
        return scale.pow(2) * torch.exp(-2 * sin_term.pow(2)) + jitter
112 |
--------------------------------------------------------------------------------
/bayesian_optimization/data/gp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import MultivariateNormal, StudentT
5 | from attrdict import AttrDict
6 | import math
7 |
8 | __all__ = ['GPPriorSampler', 'GPSampler', 'RBFKernel', 'PeriodicKernel', 'Matern52Kernel']
9 |
class GPPriorSampler(object):
    """Sample a single function from the GP prior at given input locations."""

    def __init__(self, kernel, t_noise=None):
        # t_noise: None or a fixed scale for Student-t corruption
        self.kernel = kernel
        self.t_noise = t_noise

    def sample(self,
            bx,
            device='cuda:0'):
        """Draw y-values for inputs bx (1 * num_points * 1)."""

        # 1 * num_points * num_points
        cov = self.kernel(bx)
        # fix: honor the `device` argument; the old unconditional
        # `mean.cuda()` ignored it and crashed on CPU-only machines
        mean = torch.zeros(1, bx.shape[1], device=device)

        by = MultivariateNormal(mean, cov).rsample().unsqueeze(-1)

        if self.t_noise is not None:
            by += self.t_noise * StudentT(2.1).rsample(by.shape).to(device)

        return by
31 |
class GPSampler(object):
    """Sample synthetic regression tasks from a GP prior with the given
    kernel, optionally corrupted by Student-t noise."""

    def __init__(self, kernel, t_noise=None):
        # t_noise: None (no noise) or a fixed positive Student-t scale
        self.kernel = kernel
        self.t_noise = t_noise

    def sample(self,
            batch_size=16,
            num_ctx=None,
            max_num_points=50,
            x_range=(-2, 2),
            device='cpu'):
        """Return an AttrDict with fields x/y and context/target splits
        xc/xt/yc/yt; context sizes are random unless num_ctx is given."""

        batch = AttrDict()
        # fix: test against None rather than truthiness so an explicit
        # num_ctx=0 is honored instead of being silently re-randomized
        if num_ctx is None:
            num_ctx = torch.randint(low=3, high=max_num_points-3, size=[1]).item()
        num_tar = torch.randint(low=3, high=max_num_points-num_ctx, size=[1]).item()

        num_points = num_ctx + num_tar
        batch.x = x_range[0] + (x_range[1] - x_range[0]) \
                * torch.rand([batch_size, num_points, 1], device=device)
        batch.xc = batch.x[:,:num_ctx]
        batch.xt = batch.x[:,num_ctx:]

        # batch_size * num_points * num_points
        cov = self.kernel(batch.x)
        mean = torch.zeros(batch_size, num_points, device=device)
        batch.y = MultivariateNormal(mean, cov).rsample().unsqueeze(-1)
        batch.yc = batch.y[:,:num_ctx]
        batch.yt = batch.y[:,num_ctx:]

        if self.t_noise is not None:
            batch.y += self.t_noise * StudentT(2.1).rsample(batch.y.shape).to(device)
        return batch
64 |
class RBFKernel(object):
    """Squared-exponential kernel with per-task random length-scale and
    output scale, plus diagonal jitter sigma_eps**2."""

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    # x: batch_size * num_points * dim
    def __call__(self, x):
        B = x.shape[0]
        # one random length-scale / output scale per task in [0.1, max)
        length = 0.1 + (self.max_length-0.1) \
                * torch.rand([B, 1, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale-0.1) \
                * torch.rand([B, 1, 1], device=x.device)

        # pairwise scaled differences: B * N * N * dim
        diff = (x.unsqueeze(-2) - x.unsqueeze(-3)) / length

        jitter = self.sigma_eps**2 * torch.eye(x.shape[-2]).to(x.device)
        # B * N * N covariance
        return scale.pow(2) * torch.exp(-0.5 * diff.pow(2).sum(-1)) + jitter
86 |
class Matern52Kernel(object):
    """Matern-5/2 kernel with per-task random length-scale and output
    scale, plus diagonal jitter sigma_eps**2."""

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    # x: batch_size * num_points * dim
    def __call__(self, x):
        B, N = x.shape[0], x.shape[-2]
        # one random length-scale / output scale per task in [0.1, max)
        length = 0.1 + (self.max_length-0.1) \
                * torch.rand([B, 1, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale-0.1) \
                * torch.rand([B, 1, 1], device=x.device)

        # scaled pairwise Euclidean distances: B * N * N
        r = torch.norm((x.unsqueeze(-2) - x.unsqueeze(-3)) / length, dim=-1)

        sqrt5 = math.sqrt(5.0)
        poly = 1 + sqrt5*r + 5.0*r.pow(2)/3.0
        jitter = self.sigma_eps**2 * torch.eye(N).to(x.device)
        return scale.pow(2) * poly * torch.exp(-sqrt5 * r) + jitter
108 |
class PeriodicKernel(object):
    """Exp-sine-squared (periodic) kernel with per-task random period,
    length-scale and output scale, plus diagonal jitter sigma_eps**2."""

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    # x: batch_size * num_points * dim
    def __call__(self, x):
        B, N = x.shape[0], x.shape[-2]
        # per-task random period in [0.1, 0.5)
        p = 0.1 + 0.4*torch.rand([B, 1, 1], device=x.device)
        length = 0.1 + (self.max_length-0.1) \
                * torch.rand([B, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale-0.1) \
                * torch.rand([B, 1, 1], device=x.device)

        dist = x.unsqueeze(-2) - x.unsqueeze(-3)
        sin_term = torch.sin(math.pi * dist.abs().sum(-1) / p) / length
        jitter = self.sigma_eps**2 * torch.eye(N).to(x.device)
        return scale.pow(2) * torch.exp(-2 * sin_term.pow(2)) + jitter
130 |
--------------------------------------------------------------------------------
/regression/data/lotka_volterra.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import numpy.random as npr
4 | import numba as nb
5 | from tqdm import tqdm
6 | from attrdict import AttrDict
7 | #import pandas as pd
8 | import wget
9 |
10 | import os.path as osp
11 | from utils.paths import datasets_path
12 |
@nb.njit(nb.i4(nb.f8[:]))
def catrnd(prob):
    """Draw an index from the categorical distribution given by `prob`.

    Falls through to the final index when floating-point rounding makes
    the cumulative sum end slightly below the drawn uniform.
    """
    cprob = prob.cumsum()
    u = npr.rand()
    for i in range(len(cprob)):
        if u < cprob[i]:
            return i
    return i
21 |
@nb.njit(nb.types.Tuple((nb.f8[:,:,:], nb.f8[:,:,:], nb.i4)) \
        (nb.i4, nb.i4, nb.i4, \
        nb.f8, nb.f8, nb.f8, nb.f8, nb.f8, nb.f8))
def _simulate_task(batch_size, num_steps, max_num_points,
        X0, Y0, theta0, theta1, theta2, theta3):
    """Simulate a batch of stochastic Lotka-Volterra trajectories with
    Gillespie's algorithm and subsample each into a regression task.

    Returns (x, y, num_ctx): event times x, population pairs y, and the
    number of context points (the first num_ctx points of each task).
    """

    time = np.zeros((batch_size, num_steps, 1))
    pop = np.zeros((batch_size, num_steps, 2))
    # per-trajectory effective length; shortened on extinction below
    length = num_steps*np.ones((batch_size))

    for b in range(batch_size):
        # jittered integer initial populations, clamped to be >= 1
        pop[b,0,0] = max(int(X0 + npr.randn()), 1)
        pop[b,0,1] = max(int(Y0 + npr.randn()), 1)
        for i in range(1, num_steps):
            X, Y = pop[b,i-1,0], pop[b,i-1,1]
            # rates of the four reaction channels
            rates = np.array([
                theta0*X*Y,
                theta1*X,
                theta2*Y,
                theta3*X*Y])
            total_rate = rates.sum()

            # exponential waiting time until the next reaction
            time[b,i,0] = time[b,i-1,0] + npr.exponential(scale=1./total_rate)

            pop[b,i,0] = pop[b,i-1,0]
            pop[b,i,1] = pop[b,i-1,1]
            # pick which reaction fires, proportional to its rate
            a = catrnd(rates/total_rate)
            if a == 0:
                pop[b,i,0] += 1
            elif a == 1:
                pop[b,i,0] -= 1
            elif a == 2:
                pop[b,i,1] += 1
            else:
                pop[b,i,1] -= 1

            # stop the trajectory once either population goes extinct
            if pop[b,i,0] == 0 or pop[b,i,1] == 0:
                length[b] = i+1
                break

    # resample the split until it fits the shortest trajectory
    num_ctx = npr.randint(15, max_num_points-15)
    num_tar = npr.randint(15, max_num_points-num_ctx)
    num_points = num_ctx + num_tar
    min_length = length.min()
    while num_points > min_length:
        num_ctx = npr.randint(15, max_num_points-15)
        num_tar = npr.randint(15, max_num_points-num_ctx)
        num_points = num_ctx + num_tar

    # subsample num_points events uniformly (without replacement) per task
    x = np.zeros((batch_size, num_points, 1))
    y = np.zeros((batch_size, num_points, 2))
    for b in range(batch_size):
        idxs = np.arange(int(length[b]))
        npr.shuffle(idxs)
        for j in range(num_points):
            x[b,j,0] = time[b,idxs[j],0]
            y[b,j,0] = pop[b,idxs[j],0]
            y[b,j,1] = pop[b,idxs[j],1]

    return x, y, num_ctx
82 |
class LotkaVolterraSimulator(object):
    """Generate batches of predator-prey regression tasks from the
    stochastic Lotka-Volterra model (see _simulate_task)."""

    def __init__(self,
            X0=50,
            Y0=100,
            theta0=0.01,
            theta1=0.5,
            theta2=1.0,
            theta3=0.01):

        # initial populations and the four reaction-rate parameters
        self.X0 = X0
        self.Y0 = Y0
        self.theta0 = theta0
        self.theta1 = theta1
        self.theta2 = theta2
        self.theta3 = theta3

    def simulate_tasks(self,
            num_batches,
            batch_size,
            num_steps=20000,
            max_num_points=100):
        """Simulate `num_batches` batches and split each into
        context (first num_ctx points) and target parts."""

        batches = []
        for _ in tqdm(range(num_batches)):
            x, y, num_ctx = _simulate_task(
                batch_size, num_steps, max_num_points,
                self.X0, self.Y0, self.theta0, self.theta1, self.theta2, self.theta3)
            batch = AttrDict()
            batch.x = torch.Tensor(x)
            batch.y = torch.Tensor(y)
            batch.xc, batch.xt = batch.x[:,:num_ctx], batch.x[:,num_ctx:]
            batch.yc, batch.yt = batch.y[:,:num_ctx], batch.y[:,num_ctx:]
            batches.append(batch)

        return batches
121 |
def load_hare_lynx(num_batches, batch_size):
    """Build evaluation batches from the real hare-lynx time series.

    Downloads LynxHare.txt into the datasets directory on first use.
    Each batch shuffles the series independently per task and splits it
    into a random context (15..N-15 points) and the remaining targets.
    """

    filename = osp.join(datasets_path, 'lotka_volterra', 'LynxHare.txt')
    if not osp.isfile(filename):
        # fix: 'datsets_path' was a typo (NameError on first download)
        wget.download('http://people.whitman.edu/~hundledr/courses/M250F03/LynxHare.txt',
                out=osp.join(datasets_path, 'lotka_volterra'))

    tb = np.loadtxt(filename)
    times = torch.Tensor(tb[:,0]).unsqueeze(-1)
    # columns are reordered: file column 2 first, then column 1
    pops = torch.stack([torch.Tensor(tb[:,2]), torch.Tensor(tb[:,1])], -1)

    batches = []
    N = pops.shape[-2]
    for _ in range(num_batches):
        batch = AttrDict()

        # random context size; the remaining N - num_ctx points are targets
        num_ctx = torch.randint(low=15, high=N-15, size=[1]).item()

        # independent random permutation of the N time steps per task
        idxs = torch.rand(batch_size, N).argsort(-1)

        batch.x = torch.gather(
                torch.stack([times]*batch_size),
                -2, idxs.unsqueeze(-1))
        batch.y = torch.gather(torch.stack([pops]*batch_size),
                -2, torch.stack([idxs]*2, -1))
        batch.xc = batch.x[:,:num_ctx]
        batch.xt = batch.x[:,num_ctx:]
        batch.yc = batch.y[:,:num_ctx]
        batch.yt = batch.y[:,num_ctx:]

        batches.append(batch)

    return batches
161 |
if __name__ == '__main__':
    # Script entry point: simulate Lotka-Volterra task batches, save them
    # under the datasets directory, and plot a few example trajectories.
    import argparse
    import os
    from utils.paths import datasets_path
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_batches', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--filename', type=str, default='batch')
    parser.add_argument('--X0', type=float, default=50)
    parser.add_argument('--Y0', type=float, default=100)
    parser.add_argument('--theta0', type=float, default=0.01)
    parser.add_argument('--theta1', type=float, default=0.5)
    parser.add_argument('--theta2', type=float, default=1.0)
    parser.add_argument('--theta3', type=float, default=0.01)
    parser.add_argument('--num_steps', type=int, default=20000)
    args = parser.parse_args()

    sim = LotkaVolterraSimulator(X0=args.X0, Y0=args.Y0,
            theta0=args.theta0, theta1=args.theta1,
            theta2=args.theta2, theta3=args.theta3)

    batches = sim.simulate_tasks(args.num_batches, args.batch_size,
            num_steps=args.num_steps)

    # save all batches as a single .tar file under the datasets directory
    root = os.path.join(datasets_path, 'lotka_volterra')
    if not os.path.isdir(root):
        os.makedirs(root)

    torch.save(batches, os.path.join(root, f'{args.filename}.tar'))

    # quick visual sanity check: both populations of four example tasks
    fig, axes = plt.subplots(1, 4, figsize=(16,4))
    for i, ax in enumerate(axes.flatten()):
        ax.scatter(batches[0].x[i,:,0], batches[0].y[i,:,0])
        ax.scatter(batches[0].x[i,:,0], batches[0].y[i,:,1])
    plt.show()
199 |
--------------------------------------------------------------------------------
/regression/celeba.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import argparse
5 | import yaml
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import math
11 | import time
12 | import matplotlib.pyplot as plt
13 | from attrdict import AttrDict
14 | from tqdm import tqdm
15 | from copy import deepcopy
16 |
17 | from data.image import img_to_task, task_to_img
18 | from data.celeba import CelebA
19 |
20 | from utils.misc import load_module, logmeanexp
21 | from utils.paths import results_path, evalsets_path
22 | from utils.log import get_logger, RunningAverage
23 |
def main():
    """Parse command-line arguments, build the requested model from its
    YAML config, and dispatch to the selected mode."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--mode',
            choices=['train', 'eval', 'plot', 'ensemble'],
            default='train')
    parser.add_argument('--expid', type=str, default='trial')
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--gpu', type=str, default='0')

    parser.add_argument('--max_num_points', type=int, default=200)

    parser.add_argument('--model', type=str, default='cnp')
    parser.add_argument('--train_batch_size', type=int, default=100)
    parser.add_argument('--train_num_samples', type=int, default=4)

    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--eval_freq', type=int, default=10)
    parser.add_argument('--save_freq', type=int, default=10)

    parser.add_argument('--eval_seed', type=int, default=42)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--eval_num_samples', type=int, default=50)
    parser.add_argument('--eval_logfile', type=str, default=None)

    parser.add_argument('--plot_seed', type=int, default=None)
    parser.add_argument('--plot_batch_size', type=int, default=16)
    parser.add_argument('--plot_num_samples', type=int, default=30)
    parser.add_argument('--plot_num_ctx', type=int, default=100)

    # OOD settings
    parser.add_argument('--t_noise', type=float, default=None)

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # the model class lives in models/<name>.py under the upper-cased
    # name (e.g. models/cnp.py -> CNP); its kwargs come from the YAML config
    model_cls = getattr(load_module(f'models/{args.model}.py'), args.model.upper())
    with open(f'configs/celeba/{args.model}.yaml', 'r') as f:
        config = yaml.safe_load(f)
    model = model_cls(**config).cuda()

    # all checkpoints and logs for this run live under args.root
    args.root = osp.join(results_path, 'celeba', args.model, args.expid)

    if args.mode == 'train':
        train(args, model)
    elif args.mode == 'eval':
        eval(args, model)
    elif args.mode == 'plot':
        plot(args, model)
    elif args.mode == 'ensemble':
        ensemble(args, model)
76 |
def train(args, model):
    """Train `model` on CelebA pixel-completion tasks.

    Checkpoints and logs go under args.root; a periodic evaluation runs
    every args.eval_freq epochs. When training finishes, args.mode is
    switched to 'eval' and a final full evaluation is run.
    """
    if not osp.isdir(args.root):
        os.makedirs(args.root)

    # record the exact run configuration alongside the checkpoints
    with open(osp.join(args.root, 'args.yaml'), 'w') as f:
        yaml.dump(args.__dict__, f)

    train_ds = CelebA(train=True)
    # NOTE(review): eval_ds is constructed but its loader below is
    # commented out; evaluation instead goes through eval()/gen_evalset
    eval_ds = CelebA(train=False)
    train_loader = torch.utils.data.DataLoader(train_ds,
        batch_size=args.train_batch_size,
        shuffle=True, num_workers=4)
    #eval_loader = torch.utils.data.DataLoader(eval_ds,
    #    batch_size=args.eval_batch_size,
    #    shuffle=False, num_workers=4)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # cosine decay over the full training horizon (steps, not epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(train_loader)*args.num_epochs)

    if args.resume:
        # restore model/optimizer/scheduler state and keep appending to
        # the original log file
        ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        optimizer.load_state_dict(ckpt.optimizer)
        scheduler.load_state_dict(ckpt.scheduler)
        logfilename = ckpt.logfilename
        start_epoch = ckpt.epoch
    else:
        logfilename = osp.join(args.root, 'train_{}.log'.format(
            time.strftime('%Y%m%d-%H%M')))
        start_epoch = 1

    logger = get_logger(logfilename)
    ravg = RunningAverage()

    if not args.resume:
        logger.info('Total number of parameters: {}\n'.format(
            sum(p.numel() for p in model.parameters())))

    for epoch in range(start_epoch, args.num_epochs+1):
        model.train()
        for (x, _) in tqdm(train_loader):
            # convert each image batch into a random context/target task
            batch = img_to_task(x,
                max_num_points=args.max_num_points,
                device='cuda')
            optimizer.zero_grad()
            outs = model(batch, num_samples=args.train_num_samples)
            outs.loss.backward()
            optimizer.step()
            scheduler.step()

            for key, val in outs.items():
                ravg.update(key, val)

        line = f'{args.model}:{args.expid} epoch {epoch} '
        line += f'lr {optimizer.param_groups[0]["lr"]:.3e} '
        line += ravg.info()
        logger.info(line)

        # periodic evaluation; args.mode is still 'train' here, so eval()
        # does not reload the checkpoint
        if epoch % args.eval_freq == 0:
            logger.info(eval(args, model) + '\n')

        ravg.reset()

        if epoch % args.save_freq == 0 or epoch == args.num_epochs:
            ckpt = AttrDict()
            ckpt.model = model.state_dict()
            ckpt.optimizer = optimizer.state_dict()
            ckpt.scheduler = scheduler.state_dict()
            ckpt.logfilename = logfilename
            # epoch to resume FROM, hence epoch + 1
            ckpt.epoch = epoch + 1
            torch.save(ckpt, osp.join(args.root, 'ckpt.tar'))

    # final evaluation in full 'eval' mode (reloads the saved checkpoint)
    args.mode = 'eval'
    eval(args, model)
152 |
def gen_evalset(args):
    """Build and cache a fixed CelebA evaluation set.

    Seeds the RNG with ``args.eval_seed`` so the cached tasks are
    reproducible, converts every test image into a context/target
    regression task, then re-seeds from the clock and writes the task
    list under ``evalsets_path/celeba``.
    """
    # Deterministic task sampling for the cached evaluation set.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    loader = torch.utils.data.DataLoader(
            CelebA(train=False),
            batch_size=args.eval_batch_size,
            shuffle=False, num_workers=4)

    batches = [
            img_to_task(x,
                t_noise=args.t_noise,
                max_num_points=args.max_num_points)
            for x, _ in tqdm(loader)]

    # Restore clock-based seeding for whatever runs next.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    path = osp.join(evalsets_path, 'celeba')
    if not osp.isdir(path):
        os.makedirs(path)

    if args.t_noise is None:
        filename = 'no_noise.tar'
    else:
        filename = f'{args.t_noise}.tar'
    torch.save(batches, osp.join(path, filename))
179 |
def eval(args, model):
    """Evaluate `model` on the cached CelebA evaluation set.

    In 'eval' mode the checkpoint under args.root is loaded first and
    the summary is also written to a log file; when called during
    training, the current weights are used and only the summary line is
    returned.
    """
    if args.mode == 'eval':
        ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        if args.eval_logfile is not None:
            eval_logfile = args.eval_logfile
        else:
            suffix = '' if args.t_noise is None else f'_{args.t_noise}'
            eval_logfile = f'eval{suffix}.log'
        logger = get_logger(osp.join(args.root, eval_logfile), mode='w')
    else:
        logger = None

    # Generate the cached evaluation set on first use.
    path = osp.join(evalsets_path, 'celeba')
    if not osp.isdir(path):
        os.makedirs(path)
    if args.t_noise is None:
        filename = 'no_noise.tar'
    else:
        filename = f'{args.t_noise}.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset(args)

    eval_batches = torch.load(osp.join(path, filename))

    # Fixed seed so sampling-based models evaluate reproducibly.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    ravg = RunningAverage()
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_batches):
            for key, val in batch.items():
                batch[key] = val.cuda()
            outs = model(batch, num_samples=args.eval_num_samples)
            for key, val in outs.items():
                ravg.update(key, val)

    # Hand control back to a clock-based seed.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    line = f'{args.model}:{args.expid} '
    if args.t_noise is not None:
        line += f'tn {args.t_noise} '
    line += ravg.info()

    if logger is not None:
        logger.info(line)

    return line
232 |
def ensemble(args, model):
    """Evaluate a deep ensemble of 5 independently trained models.

    Loads checkpoints run1..run5 for args.model, averages per-point
    likelihoods across members with logmeanexp (a uniform mixture of the
    member predictives), and writes the summary to an 'ensemble*.log'
    file under the model's results directory.
    """
    num_runs = 5
    models = []
    for i in range(num_runs):
        # Each member is a deep copy of the template `model`.
        model_ = deepcopy(model)
        ckpt = torch.load(osp.join(results_path, 'celeba', args.model, f'run{i+1}', 'ckpt.tar'))
        model_.load_state_dict(ckpt['model'])
        model_.cuda()
        model_.eval()
        models.append(model_)

    path = osp.join(evalsets_path, 'celeba')
    if not osp.isdir(path):
        os.makedirs(path)
    filename = 'no_noise.tar' if args.t_noise is None else \
            f'{args.t_noise}.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset(args)

    eval_batches = torch.load(osp.join(path, filename))

    ravg = RunningAverage()
    with torch.no_grad():
        for batch in tqdm(eval_batches):
            for key, val in batch.items():
                batch[key] = val.cuda()

            ctx_ll = []
            tar_ll = []
            # Iterate as `model_` (not `model`) to avoid shadowing the
            # parameter — consistent with lotka_volterra.py's ensemble().
            for model_ in models:
                outs = model_(batch,
                        num_samples=args.eval_num_samples,
                        reduce_ll=False)
                ctx_ll.append(outs.ctx_ll)
                tar_ll.append(outs.tar_ll)

            # 2-D outputs (deterministic models) gain a member axis via
            # stack; higher-rank outputs already carry a sample axis and
            # are concatenated along it.
            if ctx_ll[0].dim() == 2:
                ctx_ll = torch.stack(ctx_ll)
                tar_ll = torch.stack(tar_ll)
            else:
                ctx_ll = torch.cat(ctx_ll)
                tar_ll = torch.cat(tar_ll)

            # Mixture likelihood: average member probabilities in log space.
            ctx_ll = logmeanexp(ctx_ll).mean()
            tar_ll = logmeanexp(tar_ll).mean()

            ravg.update('ctx_ll', ctx_ll)
            ravg.update('tar_ll', tar_ll)

    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    filename = 'ensemble'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.log'
    logger = get_logger(osp.join(results_path, 'celeba', args.model, filename), mode='w')
    logger.info(ravg.info())
292 |
# Allow running this file directly as a script.
if __name__ == '__main__':
    main()
295 |
--------------------------------------------------------------------------------
/regression/emnist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import argparse
5 | import yaml
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import math
11 | import time
12 | import matplotlib.pyplot as plt
13 | from attrdict import AttrDict
14 | from tqdm import tqdm
15 | from copy import deepcopy
16 |
17 | from data.image import img_to_task, task_to_img
18 | from data.emnist import EMNIST
19 |
20 | from utils.misc import load_module, logmeanexp
21 | from utils.paths import results_path, evalsets_path
22 | from utils.log import get_logger, RunningAverage
23 |
def main():
    """CLI entry point: build the model, then run the requested mode."""
    parser = argparse.ArgumentParser()

    # Run control
    parser.add_argument('--mode',
            choices=['train', 'eval', 'plot', 'ensemble'],
            default='train')
    parser.add_argument('--expid', type=str, default='trial')
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--gpu', type=str, default='0')

    # Task sampling
    parser.add_argument('--max_num_points', type=int, default=200)
    parser.add_argument('--class_range', type=int, nargs='*', default=[0,10])

    # Model / training
    parser.add_argument('--model', type=str, default='cnp')
    parser.add_argument('--train_batch_size', type=int, default=100)
    parser.add_argument('--train_num_samples', type=int, default=4)

    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--eval_freq', type=int, default=10)
    parser.add_argument('--save_freq', type=int, default=10)

    # Evaluation
    parser.add_argument('--eval_seed', type=int, default=42)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--eval_num_samples', type=int, default=50)
    parser.add_argument('--eval_logfile', type=str, default=None)

    # Plotting
    parser.add_argument('--plot_seed', type=int, default=None)
    parser.add_argument('--plot_batch_size', type=int, default=16)
    parser.add_argument('--plot_num_samples', type=int, default=30)
    parser.add_argument('--plot_num_ctx', type=int, default=100)

    # OOD settings
    parser.add_argument('--t_noise', type=float, default=None)

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Instantiate the class named by --model from its yaml config.
    module = load_module(f'models/{args.model}.py')
    model_cls = getattr(module, args.model.upper())
    with open(f'configs/emnist/{args.model}.yaml', 'r') as f:
        config = yaml.safe_load(f)
    model = model_cls(**config).cuda()

    args.root = osp.join(results_path, 'emnist', args.model, args.expid)

    mode = args.mode
    if mode == 'train':
        train(args, model)
    elif mode == 'eval':
        eval(args, model)
    elif mode == 'plot':
        plot(args, model)
    elif mode == 'ensemble':
        ensemble(args, model)
77 |
def train(args, model):
    """Train `model` on EMNIST regression tasks.

    Saves the run arguments, trains for args.num_epochs epochs with Adam
    and cosine LR decay (stepped once per batch), periodically evaluates
    and checkpoints, and finishes with a full 'eval'-mode evaluation.
    Supports resuming from args.root/ckpt.tar when args.resume is set.
    """
    if not osp.isdir(args.root):
        os.makedirs(args.root)

    # Record the exact arguments used for this run.
    with open(osp.join(args.root, 'args.yaml'), 'w') as f:
        yaml.dump(args.__dict__, f)

    train_ds = EMNIST(train=True, class_range=args.class_range)
    # NOTE(review): eval_ds is unused here; eval() builds its own set.
    eval_ds = EMNIST(train=False, class_range=args.class_range)
    train_loader = torch.utils.data.DataLoader(train_ds,
            batch_size=args.train_batch_size,
            shuffle=True, num_workers=4)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # T_max counts optimizer steps: one per batch over all epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=len(train_loader)*args.num_epochs)

    if args.resume:
        ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        optimizer.load_state_dict(ckpt.optimizer)
        scheduler.load_state_dict(ckpt.scheduler)
        logfilename = ckpt.logfilename
        start_epoch = ckpt.epoch
    else:
        logfilename = osp.join(args.root, 'train_{}.log'.format(
            time.strftime('%Y%m%d-%H%M')))
        start_epoch = 1

    logger = get_logger(logfilename)
    ravg = RunningAverage()

    if not args.resume:
        logger.info('Total number of parameters: {}\n'.format(
            sum(p.numel() for p in model.parameters())))

    for epoch in range(start_epoch, args.num_epochs+1):
        model.train()
        for (x, _) in tqdm(train_loader):
            # Convert an image batch into a context/target regression task.
            batch = img_to_task(x,
                    max_num_points=args.max_num_points,
                    device='cuda')
            optimizer.zero_grad()
            outs = model(batch, num_samples=args.train_num_samples)
            outs.loss.backward()
            optimizer.step()
            scheduler.step()

            for key, val in outs.items():
                ravg.update(key, val)

        line = f'{args.model}:{args.expid} epoch {epoch} '
        line += f'lr {optimizer.param_groups[0]["lr"]:.3e} '
        line += ravg.info()
        logger.info(line)

        if epoch % args.eval_freq == 0:
            logger.info(eval(args, model) + '\n')

        ravg.reset()

        if epoch % args.save_freq == 0 or epoch == args.num_epochs:
            ckpt = AttrDict()
            ckpt.model = model.state_dict()
            ckpt.optimizer = optimizer.state_dict()
            ckpt.scheduler = scheduler.state_dict()
            ckpt.logfilename = logfilename
            # Save the *next* epoch so resume continues where we left off.
            ckpt.epoch = epoch + 1
            torch.save(ckpt, osp.join(args.root, 'ckpt.tar'))

    # Final evaluation loads the saved checkpoint ('eval' mode).
    args.mode = 'eval'
    eval(args, model)
150 |
def gen_evalset(args):
    """Build and cache a fixed EMNIST evaluation set for args.class_range."""
    # Deterministic task sampling for the cached evaluation set.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    loader = torch.utils.data.DataLoader(
            EMNIST(train=False, class_range=args.class_range),
            batch_size=args.eval_batch_size,
            shuffle=False, num_workers=4)

    batches = [
            img_to_task(x,
                t_noise=args.t_noise,
                max_num_points=args.max_num_points)
            for x, _ in tqdm(loader)]

    # Restore clock-based seeding for whatever runs next.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    path = osp.join(evalsets_path, 'emnist')
    if not osp.isdir(path):
        os.makedirs(path)

    # Filename encodes the class range and optional noise level.
    c1, c2 = args.class_range
    parts = [f'{c1}-{c2}']
    if args.t_noise is not None:
        parts.append(f'{args.t_noise}')
    filename = '_'.join(parts) + '.tar'

    torch.save(batches, osp.join(path, filename))
181 |
def eval(args, model):
    """Evaluate `model` on the cached EMNIST evaluation set.

    In 'eval' mode the checkpoint under args.root is loaded and results
    are also written to a log file; during training only the summary
    line is returned.
    """
    c1, c2 = args.class_range

    if args.mode == 'eval':
        ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        if args.eval_logfile is not None:
            eval_logfile = args.eval_logfile
        else:
            eval_logfile = f'eval_{c1}-{c2}'
            if args.t_noise is not None:
                eval_logfile += f'_{args.t_noise}'
            eval_logfile += '.log'
        logger = get_logger(osp.join(args.root, eval_logfile), mode='w')
    else:
        logger = None

    # Generate the cached evaluation set on first use.
    path = osp.join(evalsets_path, 'emnist')
    filename = f'{c1}-{c2}'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset(args)

    eval_batches = torch.load(osp.join(path, filename))

    # Fixed seed so sampling-based models evaluate reproducibly.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    ravg = RunningAverage()
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_batches):
            for key, val in batch.items():
                batch[key] = val.cuda()
            outs = model(batch, num_samples=args.eval_num_samples)
            for key, val in outs.items():
                ravg.update(key, val)

    # Hand control back to a clock-based seed.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    line = f'{args.model}:{args.expid} {c1}-{c2} '
    if args.t_noise is not None:
        line += f'tn {args.t_noise} '
    line += ravg.info()

    if logger is not None:
        logger.info(line)

    return line
237 |
def ensemble(args, model):
    """Evaluate a deep ensemble of 5 independently trained EMNIST models.

    Loads checkpoints run1..run5 for args.model, averages per-point
    likelihoods across members with logmeanexp (a uniform mixture of the
    member predictives), and writes the summary to an 'ensemble*.log'
    file under the model's results directory.
    """
    num_runs = 5
    models = []
    for i in range(num_runs):
        # Each member is a deep copy of the template `model`.
        model_ = deepcopy(model)
        ckpt = torch.load(osp.join(results_path, 'emnist', args.model, f'run{i+1}', 'ckpt.tar'))
        model_.load_state_dict(ckpt['model'])
        model_.cuda()
        model_.eval()
        models.append(model_)

    path = osp.join(evalsets_path, 'emnist')
    c1, c2 = args.class_range
    filename = f'{c1}-{c2}'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset(args)

    eval_batches = torch.load(osp.join(path, filename))

    ravg = RunningAverage()
    with torch.no_grad():
        for batch in tqdm(eval_batches):
            for key, val in batch.items():
                batch[key] = val.cuda()

            ctx_ll = []
            tar_ll = []
            # Iterate as `model_` (not `model`) to avoid shadowing the
            # parameter — consistent with lotka_volterra.py's ensemble().
            for model_ in models:
                outs = model_(batch,
                        num_samples=args.eval_num_samples,
                        reduce_ll=False)
                ctx_ll.append(outs.ctx_ll)
                tar_ll.append(outs.tar_ll)

            # 2-D outputs (deterministic models) gain a member axis via
            # stack; higher-rank outputs already carry a sample axis and
            # are concatenated along it.
            if ctx_ll[0].dim() == 2:
                ctx_ll = torch.stack(ctx_ll)
                tar_ll = torch.stack(tar_ll)
            else:
                ctx_ll = torch.cat(ctx_ll)
                tar_ll = torch.cat(tar_ll)

            # Mixture likelihood: average member probabilities in log space.
            ctx_ll = logmeanexp(ctx_ll).mean()
            tar_ll = logmeanexp(tar_ll).mean()

            ravg.update('ctx_ll', ctx_ll)
            ravg.update('tar_ll', tar_ll)

    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    filename = f'ensemble_{c1}-{c2}'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.log'
    logger = get_logger(osp.join(results_path, 'emnist', args.model, filename), mode='w')
    logger.info(ravg.info())
298 |
# Allow running this file directly as a script.
if __name__ == '__main__':
    main()
301 |
--------------------------------------------------------------------------------
/bayesian_optimization/run_bo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from attrdict import AttrDict
4 |
5 | import numpy as np
6 | import os.path as osp
7 | import yaml
8 |
9 | import torch
10 | from data.gp import *
11 |
12 | import bayeso
13 | import bayeso.gp as bayesogp
14 | from bayeso import covariance
15 | from bayeso import acquisition
16 |
17 | from utils.paths import results_path
18 | from utils.misc import load_module
19 |
def get_str_file(path_, str_kernel, str_model, noise, seed=None):
    """Build the .npy result-file path for a BO experiment.

    The name encodes the kernel, an optional 'noisy' tag (when `noise`
    is not None), the model name, and an optional seed:
    ``bo_<kernel>[_noisy]_<model>[_<seed>].npy``
    """
    if noise is None:
        stem = f'bo_{str_kernel}_{str_model}'
    else:
        stem = f'bo_{str_kernel}_noisy_{str_model}'

    suffix = '.npy' if seed is None else f'_{seed}.npy'

    return osp.join(path_, stem + suffix)
32 |
def main():
    """CLI entry point: load the model, then run GP-oracle BO or model BO."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--mode',
            choices=['oracle', 'bo'],
            default='bo')
    parser.add_argument('--expid', type=str, default='run1')
    parser.add_argument('--gpu', type=str, default='0')

    parser.add_argument('--model', type=str, default='cnp')

    # BO settings
    parser.add_argument('--bo_num_samples', type=int, default=200)
    parser.add_argument('--bo_num_init', type=int, default=1)
    parser.add_argument('--bo_kernel', type=str, default='periodic')
    parser.add_argument('--t_noise', type=float, default=None)

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Instantiate the class named by --model from its GP config.
    module = load_module(f'models/{args.model}.py')
    model_cls = getattr(module, args.model.upper())
    with open(f'configs/gp/{args.model}.yaml', 'r') as f:
        config = yaml.safe_load(f)
    model = model_cls(**config).cuda()

    args.root = osp.join(results_path, 'gp', args.model, args.expid)

    if args.mode == 'oracle':
        oracle(args, model)
    elif args.mode == 'bo':
        bo(args, model)
63 |
def oracle(args, model):
    """Run oracle Bayesian optimization with a freshly fit GP surrogate.

    For each of 100 seeds, samples a 1D objective on a dense grid over
    [-2, 2] from a GP prior with the chosen kernel, then minimizes it for
    50 iterations using a bayeso GP surrogate with EI acquisition.
    Per-seed results are cached as .npy files and reused when present;
    a summary of all seeds is saved at the end.

    Note: `model` is unused here — the oracle surrogate is a GP.
    """
    seed = 42
    num_all = 100       # number of independent BO runs (seeds)
    num_iter = 50       # BO iterations per run
    num_init = args.bo_num_init
    str_cov = 'se'      # bayeso covariance for the GP surrogate

    list_dict = []

    if args.bo_kernel == 'rbf':
        kernel = RBFKernel()
    elif args.bo_kernel == 'matern':
        kernel = Matern52Kernel()
    elif args.bo_kernel == 'periodic':
        kernel = PeriodicKernel()
    else:
        raise ValueError(f'Invalid kernel {args.bo_kernel}')

    for ind_seed in range(1, num_all + 1):
        seed_ = seed * ind_seed

        if seed_ is not None:   # always true; kept for parity with bo()
            torch.manual_seed(seed_)
            torch.cuda.manual_seed(seed_)

        # Reuse a cached per-seed result if it already exists.
        if os.path.exists(get_str_file('./results', args.bo_kernel, 'oracle', args.t_noise, seed=ind_seed)):
            dict_exp = np.load(get_str_file('./results', args.bo_kernel, 'oracle', args.t_noise, seed=ind_seed), allow_pickle=True)
            dict_exp = dict_exp[()]   # unwrap the 0-d object array back to a dict
            list_dict.append(dict_exp)

            print(dict_exp)
            print(dict_exp['global'])
            print(np.array2string(dict_exp['minima'], separator=','))
            print(np.array2string(dict_exp['regrets'], separator=','))

            continue

        sampler = GPPriorSampler(kernel, t_noise=args.t_noise)

        # Dense grid on [-2, 2]; the objective is only defined on these points.
        xp = torch.linspace(-2, 2, 1000).cuda()
        xp_ = xp.unsqueeze(0).unsqueeze(2)   # shape [1, 1000, 1]

        yp = sampler.sample(xp_)
        min_yp = yp.min()   # global grid minimum; regret baseline
        print(min_yp.cpu().numpy())

        model.eval()

        # Random initial context/target split of the grid points.
        batch = AttrDict()
        indices_permuted = torch.randperm(yp.shape[1])

        batch.x = xp_[:, indices_permuted[:2*num_init], :]
        batch.y = yp[:, indices_permuted[:2*num_init], :]

        batch.xc = xp_[:, indices_permuted[:num_init], :]
        batch.yc = yp[:, indices_permuted[:num_init], :]

        batch.xt = xp_[:, indices_permuted[num_init:2*num_init], :]
        batch.yt = yp[:, indices_permuted[num_init:2*num_init], :]

        X_train = batch.xc.squeeze(0).cpu().numpy()
        Y_train = batch.yc.squeeze(0).cpu().numpy()
        X_test = xp_.squeeze(0).cpu().numpy()

        list_min = []
        list_min.append(batch.yc.min().cpu().numpy())

        for ind_iter in range(0, num_iter):
            print('ind_seed {} seed {} iter {}'.format(ind_seed, seed_, ind_iter + 1))

            # Fit GP hyperparameters on the current observations.
            cov_X_X, inv_cov_X_X, hyps = bayesogp.get_optimized_kernel(X_train, Y_train, None, str_cov, is_fixed_noise=False, debug=False)

            # GP posterior mean and covariance over the full grid.
            prior_mu_train = bayesogp.get_prior_mu(None, X_train)
            prior_mu_test = bayesogp.get_prior_mu(None, X_test)
            cov_X_Xs = covariance.cov_main(str_cov, X_train, X_test, hyps, False)
            cov_Xs_Xs = covariance.cov_main(str_cov, X_test, X_test, hyps, True)
            cov_Xs_Xs = (cov_Xs_Xs + cov_Xs_Xs.T) / 2.0   # enforce symmetry

            mu_ = np.dot(np.dot(cov_X_Xs.T, inv_cov_X_X), Y_train - prior_mu_train) + prior_mu_test
            Sigma_ = cov_Xs_Xs - np.dot(np.dot(cov_X_Xs.T, inv_cov_X_X), cov_X_Xs)
            # Clamp tiny negative variances from numerical error before sqrt.
            sigma_ = np.expand_dims(np.sqrt(np.maximum(np.diag(Sigma_), 0.0)), axis=1)

            # Maximize EI by minimizing its negation over the grid.
            acq_vals = -1.0 * acquisition.ei(np.ravel(mu_), np.ravel(sigma_), Y_train)
            ind_ = np.argmin(acq_vals)

            # Query the selected grid point and append it to the context.
            x_new = xp[ind_, None, None, None]
            y_new = yp[:, ind_, None, :]

            batch.x = torch.cat([batch.x, x_new], axis=1)
            batch.y = torch.cat([batch.y, y_new], axis=1)

            batch.xc = torch.cat([batch.xc, x_new], axis=1)
            batch.yc = torch.cat([batch.yc, y_new], axis=1)

            X_train = batch.xc.squeeze(0).cpu().numpy()
            Y_train = batch.yc.squeeze(0).cpu().numpy()

            min_cur = batch.yc.min()
            list_min.append(min_cur.cpu().numpy())

        print(min_yp.cpu().numpy())
        print(np.array2string(np.array(list_min), separator=','))
        print(np.array2string(np.array(list_min) - min_yp.cpu().numpy(), separator=','))

        dict_exp = {
            'seed': seed_,
            'str_cov': str_cov,
            'global': min_yp.cpu().numpy(),
            'minima': np.array(list_min),
            'regrets': np.array(list_min) - min_yp.cpu().numpy(),
            'xc': X_train,
            'yc': Y_train,
            'model': 'oracle',
        }

        # Cache this seed's result and accumulate it for the summary file.
        np.save(get_str_file('./results', args.bo_kernel, 'oracle', args.t_noise, seed=ind_seed), dict_exp)
        list_dict.append(dict_exp)

    np.save(get_str_file('./figures/results', args.bo_kernel, 'oracle', args.t_noise), list_dict)
183 |
def bo(args, model):
    """Run Bayesian optimization using `model` (a neural process) as surrogate.

    Mirrors oracle(): for each of 100 seeds a GP-prior objective is
    sampled on a grid over [-2, 2]; at each of 50 iterations the model
    predicts mean/std on the grid, EI selects the next query point, and
    the observation is appended to the context. Results of all seeds are
    saved in one summary .npy file.
    """
    if args.mode == 'bo':
        # Load the trained checkpoint for this run.
        ckpt = torch.load(os.path.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)

    if args.bo_kernel == 'rbf':
        kernel = RBFKernel()
    elif args.bo_kernel == 'matern':
        kernel = Matern52Kernel()
    elif args.bo_kernel == 'periodic':
        kernel = PeriodicKernel()
    else:
        raise ValueError(f'Invalid kernel {args.bo_kernel}')

    seed = 42
    str_cov = 'se'      # recorded in the result dict only
    num_all = 100       # number of independent BO runs (seeds)
    num_iter = 50       # BO iterations per run
    num_init = args.bo_num_init

    list_dict = []

    for ind_seed in range(1, num_all + 1):
        seed_ = seed * ind_seed

        if seed_ is not None:   # always true; kept from the original structure
            torch.manual_seed(seed_)
            torch.cuda.manual_seed(seed_)

        obj_prior = GPPriorSampler(kernel, t_noise=args.t_noise)

        # Dense grid on [-2, 2]; the objective is only defined on these points.
        xp = torch.linspace(-2, 2, 1000).cuda()
        xp_ = xp.unsqueeze(0).unsqueeze(2)   # shape [1, 1000, 1]

        yp = obj_prior.sample(xp_)
        min_yp = yp.min()   # global grid minimum; regret baseline
        print(min_yp.cpu().numpy())

        model.eval()

        batch = AttrDict()

        # Random initial context/target split of the grid points.
        indices_permuted = torch.randperm(yp.shape[1])

        batch.x = xp_[:, indices_permuted[:2*num_init], :]
        batch.y = yp[:, indices_permuted[:2*num_init], :]

        batch.xc = xp_[:, indices_permuted[:num_init], :]
        batch.yc = yp[:, indices_permuted[:num_init], :]

        batch.xt = xp_[:, indices_permuted[num_init:2*num_init], :]
        batch.yt = yp[:, indices_permuted[num_init:2*num_init], :]

        X_train = batch.xc.squeeze(0).cpu().numpy()
        Y_train = batch.yc.squeeze(0).cpu().numpy()
        X_test = xp_.squeeze(0).cpu().numpy()

        list_min = []
        list_min.append(batch.yc.min().cpu().numpy())

        for ind_iter in range(0, num_iter):
            print('ind_seed {} seed {} iter {}'.format(ind_seed, seed_, ind_iter + 1))

            with torch.no_grad():
                outs = model(batch, num_samples=args.bo_num_samples)
                print('ctx_ll {:.4f} tar ll {:.4f}'.format(
                    outs.ctx_ll.item(), outs.tar_ll.item()))
                # Predictive distribution over the whole grid.
                py = model.predict(batch.xc, batch.yc,
                        xp[None,:,None].repeat(1, 1, 1),
                        num_samples=args.bo_num_samples)
            mu, sigma = py.mean.squeeze(0), py.scale.squeeze(0)

            if mu.dim() == 4:
                # Sampling-based model: collapse the sample dimension into a
                # single mixture mean/std (law of total variance).
                print(mu.shape, sigma.shape)
                var = sigma.pow(2).mean(0) + mu.pow(2).mean(0) - mu.mean(0).pow(2)
                sigma = var.sqrt().squeeze(0)
                mu = mu.mean(0).squeeze(0)
                mu_ = mu.cpu().numpy()
                sigma_ = sigma.cpu().numpy()

                acq_vals = -1.0 * acquisition.ei(np.ravel(mu_), np.ravel(sigma_), Y_train)

                # acq_vals = []

                # for ind_mu in range(0, mu.shape[0]):
                #     acq_vals_ = -1.0 * acquisition.ei(np.ravel(mu[ind_mu].cpu().numpy()), np.ravel(sigma[ind_mu].cpu().numpy()), Y_train)
                #     acq_vals.append(acq_vals_)

                # acq_vals = np.mean(acq_vals, axis=0)
            else:
                # Deterministic model: use predicted mean/std directly.
                mu_ = mu.cpu().numpy()
                sigma_ = sigma.cpu().numpy()

                acq_vals = -1.0 * acquisition.ei(np.ravel(mu_), np.ravel(sigma_), Y_train)

                # var = sigma.pow(2).mean(0) + mu.pow(2).mean(0) - mu.mean(0).pow(2)
                # sigma = var.sqrt().squeeze(0)
                # mu = mu.mean(0).squeeze(0)

            # Maximize EI by minimizing its negation over the grid.
            ind_ = np.argmin(acq_vals)

            # Query the selected grid point and append it to the context.
            x_new = xp[ind_, None, None, None]
            y_new = yp[:, ind_, None, :]

            batch.x = torch.cat([batch.x, x_new], axis=1)
            batch.y = torch.cat([batch.y, y_new], axis=1)

            batch.xc = torch.cat([batch.xc, x_new], axis=1)
            batch.yc = torch.cat([batch.yc, y_new], axis=1)

            X_train = batch.xc.squeeze(0).cpu().numpy()
            Y_train = batch.yc.squeeze(0).cpu().numpy()

            min_cur = batch.yc.min()
            list_min.append(min_cur.cpu().numpy())

        print(min_yp.cpu().numpy())
        print(np.array2string(np.array(list_min), separator=','))
        print(np.array2string(np.array(list_min) - min_yp.cpu().numpy(), separator=','))

        dict_exp = {
            'seed': seed_,
            'global': min_yp.cpu().numpy(),
            'minima': np.array(list_min),
            'regrets': np.array(list_min) - min_yp.cpu().numpy(),
            'xc': X_train,
            'yc': Y_train,
            'model': args.model,
            'cov': str_cov,
        }

        list_dict.append(dict_exp)

    np.save(get_str_file('./figures/results', args.bo_kernel, args.model, args.t_noise), list_dict)
318 |
# Allow running this file directly as a script.
if __name__ == '__main__':
    main()
321 |
--------------------------------------------------------------------------------
/regression/lotka_volterra.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import argparse
5 | import yaml
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import math
11 | import time
12 | import matplotlib.pyplot as plt
13 | from attrdict import AttrDict
14 | from tqdm import tqdm
15 | from copy import deepcopy
16 |
17 | from utils.misc import load_module, logmeanexp
18 | from utils.paths import results_path, datasets_path, evalsets_path
19 | from utils.log import get_logger, RunningAverage
20 | from data.lotka_volterra import load_hare_lynx
21 |
def standardize(batch):
    """Standardize a task batch using its context statistics.

    x-values are normalized with the mean/std of batch.xc and y-values
    with the mean/std of batch.yc (computed over the point dimension,
    per feature). Features with exactly zero std get std 1.0 so their
    values pass through (up to the 1e-5 stabilizer) instead of being
    divided by ~1e-5 and blowing up. Modifies `batch` in place and
    returns it.
    """
    with torch.no_grad():
        mu, sigma = batch.xc.mean(-2, keepdim=True), batch.xc.std(-2, keepdim=True)
        sigma[sigma == 0] = 1.0
        batch.x = (batch.x - mu) / (sigma + 1e-5)
        batch.xc = (batch.xc - mu) / (sigma + 1e-5)
        batch.xt = (batch.xt - mu) / (sigma + 1e-5)

        mu, sigma = batch.yc.mean(-2, keepdim=True), batch.yc.std(-2, keepdim=True)
        # Guard constant y-features too — the original guarded only the
        # x-statistics, letting a zero y-std explode the normalized values.
        sigma[sigma == 0] = 1.0
        batch.y = (batch.y - mu) / (sigma + 1e-5)
        batch.yc = (batch.yc - mu) / (sigma + 1e-5)
        batch.yt = (batch.yt - mu) / (sigma + 1e-5)
    return batch
35 |
def main():
    """CLI entry point for Lotka-Volterra experiments."""
    parser = argparse.ArgumentParser()

    # Run control
    parser.add_argument('--mode',
            choices=['train', 'eval', 'plot', 'ensemble'],
            default='train')
    parser.add_argument('--expid', type=str, default='trial')
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--gpu', type=str, default='0')

    parser.add_argument('--max_num_points', type=int, default=50)

    # Model / training
    parser.add_argument('--model', type=str, default='cnp')
    parser.add_argument('--train_batch_size', type=int, default=100)
    parser.add_argument('--train_num_samples', type=int, default=4)

    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--print_freq', type=int, default=200)
    parser.add_argument('--eval_freq', type=int, default=5000)
    parser.add_argument('--save_freq', type=int, default=1000)

    # Evaluation
    parser.add_argument('--eval_seed', type=int, default=42)
    parser.add_argument('--hare_lynx', action='store_true')
    parser.add_argument('--eval_num_samples', type=int, default=50)
    parser.add_argument('--eval_logfile', type=str, default=None)

    # Plotting
    parser.add_argument('--plot_seed', type=int, default=None)
    parser.add_argument('--plot_batch_size', type=int, default=16)
    parser.add_argument('--plot_num_samples', type=int, default=30)

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Instantiate the class named by --model from its yaml config.
    module = load_module(f'models/{args.model}.py')
    model_cls = getattr(module, args.model.upper())
    with open(f'configs/lotka_volterra/{args.model}.yaml', 'r') as f:
        config = yaml.safe_load(f)
    model = model_cls(**config).cuda()

    args.root = osp.join(results_path, 'lotka_volterra', args.model, args.expid)

    mode = args.mode
    if mode == 'train':
        train(args, model)
    elif mode == 'eval':
        eval(args, model)
    elif mode == 'plot':
        plot(args, model)
    elif mode == 'ensemble':
        ensemble(args, model)
83 |
def train(args, model):
    """Train `model` on pre-generated Lotka-Volterra simulation batches.

    Performs one optimizer step per stored batch (num_steps equals
    len(train_data)), with cosine LR decay over all steps. Each batch is
    standardized with its context statistics before the forward pass.
    Periodically logs, evaluates, and checkpoints; supports resuming via
    args.resume.
    """
    if not osp.isdir(args.root):
        os.makedirs(args.root)

    # Record the exact arguments used for this run.
    with open(osp.join(args.root, 'args.yaml'), 'w') as f:
        yaml.dump(args.__dict__, f)

    train_data = torch.load(osp.join(datasets_path, 'lotka_volterra', 'train.tar'))
    eval_data = torch.load(osp.join(datasets_path, 'lotka_volterra', 'eval.tar'))
    num_steps = len(train_data)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # T_max counts optimizer steps: one per stored batch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=num_steps)

    if args.resume:
        ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        optimizer.load_state_dict(ckpt.optimizer)
        scheduler.load_state_dict(ckpt.scheduler)
        logfilename = ckpt.logfilename
        start_step = ckpt.step
    else:
        logfilename = osp.join(args.root,
                f'train_{time.strftime("%Y%m%d-%H%M")}.log')
        start_step = 1

    logger = get_logger(logfilename)
    ravg = RunningAverage()

    if not args.resume:
        logger.info('Total number of parameters: {}\n'.format(
            sum(p.numel() for p in model.parameters())))

    for step in range(start_step, num_steps+1):
        model.train()
        optimizer.zero_grad()

        # Steps are 1-indexed; batches are stored 0-indexed.
        batch = standardize(train_data[step-1])
        for key, val in batch.items():
            batch[key] = val.cuda()

        outs = model(batch, num_samples=args.train_num_samples)
        outs.loss.backward()
        optimizer.step()
        scheduler.step()

        for key, val in outs.items():
            ravg.update(key, val)

        if step % args.print_freq == 0:
            line = f'{args.model}:{args.expid} step {step} '
            line += f'lr {optimizer.param_groups[0]["lr"]:.3e} '
            line += ravg.info()
            logger.info(line)

            # eval_freq is a multiple of print_freq, so nesting is safe.
            if step % args.eval_freq == 0:
                line = eval(args, model, eval_data=eval_data)
                logger.info(line + '\n')

            ravg.reset()

        if step % args.save_freq == 0 or step == num_steps:
            ckpt = AttrDict()
            ckpt.model = model.state_dict()
            ckpt.optimizer = optimizer.state_dict()
            ckpt.scheduler = scheduler.state_dict()
            ckpt.logfilename = logfilename
            # Save the *next* step so resume continues where we left off.
            ckpt.step = step + 1
            torch.save(ckpt, osp.join(args.root, 'ckpt.tar'))

    # Final evaluation loads the saved checkpoint ('eval' mode).
    args.mode = 'eval'
    eval(args, model, eval_data=eval_data)
156 |
def eval(args, model, eval_data=None):
    """Evaluate `model` on Lotka-Volterra (or hare-lynx) tasks.

    When called in 'eval' mode, loads the checkpoint from args.root and
    logs the summary to a file; during training, `eval_data` is passed
    in and only the summary line is returned.
    """
    logger = None
    if args.mode == 'eval':
        ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        if args.eval_logfile is not None:
            eval_logfile = args.eval_logfile
        elif args.hare_lynx:
            eval_logfile = 'hare_lynx.log'
        else:
            eval_logfile = 'eval.log'
        logger = get_logger(osp.join(args.root, eval_logfile), mode='w')

    # Deterministic evaluation.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    if eval_data is None:
        if args.hare_lynx:
            eval_data = load_hare_lynx(1000, 16)
        else:
            eval_data = torch.load(osp.join(datasets_path, 'lotka_volterra', 'eval.tar'))

    ravg = RunningAverage()
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_data):
            batch = standardize(batch)
            for key, val in batch.items():
                batch[key] = val.cuda()
            outs = model(batch, num_samples=args.eval_num_samples)
            for key, val in outs.items():
                ravg.update(key, val)

    # Hand control back to a clock-based seed.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    line = f'{args.model}:{args.expid} ' + ravg.info()

    if logger is not None:
        logger.info(line)

    return line
203 |
def ensemble(args, model):
    """Evaluate a five-member deep ensemble on the Lotka-Volterra eval data.

    `model` serves only as a template: one deep copy per run loads that run's
    checkpoint. Per-batch log-likelihoods from all members are combined with
    logmeanexp and the averages are written to an ensemble log file.
    """
    num_runs = 5
    members = []
    for run_idx in range(1, num_runs + 1):
        member = deepcopy(model)
        ckpt = torch.load(osp.join(results_path, 'lotka_volterra', args.model,
            f'run{run_idx}', 'ckpt.tar'))
        member.load_state_dict(ckpt['model'])
        member.cuda()
        member.eval()
        members.append(member)

    # Deterministic evaluation.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    if args.hare_lynx:
        eval_data = load_hare_lynx(1000, 16)
    else:
        eval_data = torch.load(osp.join(datasets_path, 'lotka_volterra', 'eval.tar'))

    ravg = RunningAverage()
    with torch.no_grad():
        for batch in tqdm(eval_data):
            batch = standardize(batch)
            for key in batch:
                batch[key] = batch[key].cuda()

            ctx_lls, tar_lls = [], []
            for member in members:
                outs = member(batch,
                        num_samples=args.eval_num_samples,
                        reduce_ll=False)
                ctx_lls.append(outs.ctx_ll)
                tar_lls.append(outs.tar_ll)

            # 2-D per-member lls get a fresh ensemble axis; otherwise the
            # existing sample axis is extended.
            combine = torch.stack if ctx_lls[0].dim() == 2 else torch.cat
            ctx_ll = logmeanexp(combine(ctx_lls)).mean()
            tar_ll = logmeanexp(combine(tar_lls)).mean()

            ravg.update('ctx_ll', ctx_ll)
            ravg.update('tar_ll', tar_ll)

    # Back to nondeterministic seeding.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    filename = 'ensemble'
    if args.hare_lynx:
        filename += '_hare_lynx'
    filename += '.log'
    logger = get_logger(osp.join(results_path, 'lotka_volterra', args.model, filename), mode='w')
    logger.info(ravg.info())
261 |
def plot(args, model):
    """Plot model predictions on one randomly chosen evaluation batch.

    Loads the checkpoint from args.root, picks a random batch, and draws a
    4x4 grid of tasks: predictive mean +/- one std for both output channels,
    with context (black stars) and target (orchid crosses) points overlaid.
    NOTE(review): the 4x4 grid assumes the batch holds at least 16 tasks —
    confirm against the eval-set batch size.
    """
    ckpt = torch.load(osp.join(args.root, 'ckpt.tar'))
    model.load_state_dict(ckpt.model)

    def tnp(x):
        # Tensor -> squeezed numpy array on CPU, for matplotlib.
        return x.squeeze().cpu().data.numpy()

    if args.hare_lynx:
        eval_data = load_hare_lynx(1000, 16)
    else:
        eval_data = torch.load(osp.join(datasets_path, 'lotka_volterra', 'eval.tar'))
    # Pick one batch uniformly at random.
    bid = torch.randint(len(eval_data), [1]).item()
    batch = standardize(eval_data[bid])

    for k, v in batch.items():
        batch[k] = v.cuda()

    model.eval()
    # Runs outside torch.no_grad(); gradients are computed but unused here.
    outs = model(batch, num_samples=args.eval_num_samples)
    print(outs.tar_ll)

    fig, axes = plt.subplots(4, 4, figsize=(20, 20))
    # Build a per-task dense input grid spanning each task's observed x range.
    xp = []
    for b in range(batch.x.shape[0]):
        bx = batch.x[b]
        xp.append(torch.linspace(bx.min()-0.1, bx.max()+0.1, 200))
    xp = torch.stack(xp).unsqueeze(-1).cuda()

    model.eval()
    with torch.no_grad():
        py = model.predict(batch.xc, batch.yc, xp, num_samples=args.plot_num_samples)
        mu, sigma = py.mean, py.scale

    if mu.dim() > 3:
        # Multi-sample predictions: collapse the sample axis into a mixture
        # mean and a mixture std (law of total variance).
        bmu = mu.mean(0)
        bvar = sigma.pow(2).mean(0) + mu.pow(2).mean(0) - mu.mean(0).pow(2)
        bsigma = bvar.sqrt()
    else:
        bmu = mu
        bsigma = sigma

    for i, ax in enumerate(axes.flatten()):
        ax.plot(tnp(xp[i]), tnp(bmu[i]), alpha=0.5)
        # Channel 0 band — presumably the predator population; confirm
        # against the data loader's channel ordering.
        upper = tnp(bmu[i][:,0] + bsigma[i][:,0])
        lower = tnp(bmu[i][:,0] - bsigma[i][:,0])
        ax.fill_between(tnp(xp[i]), lower, upper,
                alpha=0.2, linewidth=0.0, label='predator')

        # Channel 1 band — presumably the prey population.
        upper = tnp(bmu[i][:,1] + bsigma[i][:,1])
        lower = tnp(bmu[i][:,1] - bsigma[i][:,1])
        ax.fill_between(tnp(xp[i]), lower, upper,
                alpha=0.2, linewidth=0.0, label='prey')

        ax.scatter(tnp(batch.xc[i]), tnp(batch.yc[i][:,0]), color='k', marker='*')
        ax.scatter(tnp(batch.xc[i]), tnp(batch.yc[i][:,1]), color='k', marker='*')

        ax.scatter(tnp(batch.xt[i]), tnp(batch.yt[i][:,0]), color='orchid', marker='x')
        ax.scatter(tnp(batch.xt[i]), tnp(batch.yt[i][:,1]), color='orchid', marker='x')

    plt.tight_layout()
    plt.show()
323 |
# Script entry point: main() (defined earlier in this file) parses the CLI
# arguments and dispatches to train/eval/plot/ensemble.
if __name__ == '__main__':
    main()
326 |
--------------------------------------------------------------------------------
/regression/gp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import argparse
5 | import yaml
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import math
11 | import time
12 | import matplotlib.pyplot as plt
13 | from attrdict import AttrDict
14 | from tqdm import tqdm
15 | from copy import deepcopy
16 |
17 | from data.gp import *
18 |
19 | from utils.misc import load_module, logmeanexp
20 | from utils.paths import results_path, evalsets_path
21 | from utils.log import get_logger, RunningAverage
22 |
def main():
    """Parse CLI arguments, build the configured model on GPU, and dispatch
    to the requested mode (train / eval / plot / ensemble)."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--mode',
            choices=['train', 'eval', 'plot', 'ensemble'],
            default='train')
    parser.add_argument('--expid', type=str, default='trial')
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--gpu', type=str, default='0')

    parser.add_argument('--max_num_points', type=int, default=50)

    parser.add_argument('--model', type=str, default='cnp')
    parser.add_argument('--train_batch_size', type=int, default=100)
    parser.add_argument('--train_num_samples', type=int, default=4)

    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--num_steps', type=int, default=100000)
    parser.add_argument('--print_freq', type=int, default=200)
    parser.add_argument('--eval_freq', type=int, default=5000)
    parser.add_argument('--save_freq', type=int, default=1000)

    parser.add_argument('--eval_seed', type=int, default=42)
    parser.add_argument('--eval_num_batches', type=int, default=3000)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--eval_num_samples', type=int, default=50)
    parser.add_argument('--eval_logfile', type=str, default=None)

    parser.add_argument('--plot_seed', type=int, default=None)
    parser.add_argument('--plot_batch_size', type=int, default=16)
    parser.add_argument('--plot_num_samples', type=int, default=30)
    parser.add_argument('--plot_num_ctx', type=int, default=None)
    # Period for the PeriodicKernel used by plot(); None selects RBFKernel.
    # Fix: plot() reads args.pp, but this argument was never registered,
    # so '--mode plot' crashed with AttributeError.
    parser.add_argument('--pp', type=float, default=None)

    # OOD settings
    parser.add_argument('--eval_kernel', type=str, default='rbf')
    parser.add_argument('--t_noise', type=float, default=None)

    args = parser.parse_args()
    # Pin the visible CUDA device before any CUDA context is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Load models/<model>.py and fetch the class named after the model
    # (e.g. 'cnp' -> CNP), configured from the matching YAML file.
    model_cls = getattr(load_module(f'models/{args.model}.py'), args.model.upper())
    with open(f'configs/gp/{args.model}.yaml', 'r') as f:
        config = yaml.safe_load(f)
    model = model_cls(**config).cuda()

    args.root = osp.join(results_path, 'gp', args.model, args.expid)

    if args.mode == 'train':
        train(args, model)
    elif args.mode == 'eval':
        eval(args, model)
    elif args.mode == 'plot':
        plot(args, model)
    elif args.mode == 'ensemble':
        ensemble(args, model)
78 |
def train(args, model):
    """Train `model` on freshly sampled RBF-GP batches.

    Creates args.root, dumps the arguments to args.yaml, logs training
    progress, runs eval() every args.eval_freq steps, checkpoints every
    args.save_freq steps, and finishes with one full evaluation pass (after
    switching args.mode to 'eval' so eval() loads the checkpoint and logs).
    Supports resuming from args.root/ckpt.tar via --resume.
    """
    if not osp.isdir(args.root):
        os.makedirs(args.root)

    with open(osp.join(args.root, 'args.yaml'), 'w') as f:
        yaml.dump(args.__dict__, f)

    sampler = GPSampler(RBFKernel())
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.num_steps)

    if args.resume:
        # Restore model/optimizer/scheduler and continue appending to the
        # original log file from the saved step.
        ckpt = torch.load(os.path.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        optimizer.load_state_dict(ckpt.optimizer)
        scheduler.load_state_dict(ckpt.scheduler)
        logfilename = ckpt.logfilename
        start_step = ckpt.step
    else:
        logfilename = os.path.join(args.root,
            f'train_{time.strftime("%Y%m%d-%H%M")}.log')
        start_step = 1

    logger = get_logger(logfilename)
    ravg = RunningAverage()

    if not args.resume:
        logger.info('Total number of parameters: {}\n'.format(
            sum(p.numel() for p in model.parameters())))

    for step in range(start_step, args.num_steps+1):
        model.train()
        optimizer.zero_grad()
        # Each step trains on a brand-new synthetic batch (no fixed dataset).
        batch = sampler.sample(
            batch_size=args.train_batch_size,
            max_num_points=args.max_num_points,
            device='cuda')
        outs = model(batch, num_samples=args.train_num_samples)
        outs.loss.backward()
        # Order matters: scheduler.step() after optimizer.step().
        optimizer.step()
        scheduler.step()

        for key, val in outs.items():
            ravg.update(key, val)

        if step % args.print_freq == 0:
            line = f'{args.model}:{args.expid} step {step} '
            line += f'lr {optimizer.param_groups[0]["lr"]:.3e} '
            line += ravg.info()
            logger.info(line)

        if step % args.eval_freq == 0:
            # args.mode is still 'train', so eval() returns the summary
            # without loading a checkpoint or writing its own log.
            line = eval(args, model)
            logger.info(line + '\n')

            # Reset running averages after each evaluation window.
            ravg.reset()

        if step % args.save_freq == 0 or step == args.num_steps:
            ckpt = AttrDict()
            ckpt.model = model.state_dict()
            ckpt.optimizer = optimizer.state_dict()
            ckpt.scheduler = scheduler.state_dict()
            ckpt.logfilename = logfilename
            # Saved step is the NEXT step, so resuming continues after it.
            ckpt.step = step + 1
            torch.save(ckpt, os.path.join(args.root, 'ckpt.tar'))

    # Final full evaluation with checkpoint loading and log-file output.
    args.mode = 'eval'
    eval(args, model)
148 |
def gen_evalset(args):
    """Sample and cache the evaluation batches for args.eval_kernel.

    Seeds the RNG with args.eval_seed so the cached set is reproducible,
    then saves the batches under evalsets_path/gp/<kernel>[_<t_noise>].tar.
    """
    kernel_factories = {
        'rbf': RBFKernel,
        'matern': Matern52Kernel,
        'periodic': PeriodicKernel,
    }
    if args.eval_kernel not in kernel_factories:
        raise ValueError(f'Invalid kernel {args.eval_kernel}')
    kernel = kernel_factories[args.eval_kernel]()

    # Deterministic sampling for a reproducible evaluation set.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    sampler = GPSampler(kernel, t_noise=args.t_noise)
    batches = [
        sampler.sample(
            batch_size=args.eval_batch_size,
            max_num_points=args.max_num_points)
        for _ in tqdm(range(args.eval_num_batches))
    ]

    # Hand the RNG back to nondeterministic seeding.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    path = osp.join(evalsets_path, 'gp')
    os.makedirs(path, exist_ok=True)

    filename = f'{args.eval_kernel}'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.tar'

    torch.save(batches, osp.join(path, filename))
182 |
def eval(args, model):
    """Evaluate `model` on the cached GP eval set for args.eval_kernel.

    In standalone 'eval' mode, loads the checkpoint from args.root and writes
    the summary to a log file; when called from train() (args.mode != 'eval')
    nothing is loaded or logged here. Generates the eval set on first use.

    Returns:
        str: '<model>:<expid> <kernel> [tn <t_noise>] <metrics>'.
    """
    if args.mode == 'eval':
        ckpt = torch.load(os.path.join(args.root, 'ckpt.tar'))
        model.load_state_dict(ckpt.model)
        if args.eval_logfile is None:
            eval_logfile = f'eval_{args.eval_kernel}'
            if args.t_noise is not None:
                eval_logfile += f'_tn_{args.t_noise}'
            eval_logfile += '.log'
        else:
            eval_logfile = args.eval_logfile
        filename = os.path.join(args.root, eval_logfile)
        logger = get_logger(filename, mode='w')
    else:
        logger = None

    # Validate the kernel name early. (The previous code also instantiated
    # the kernel object here, but it was never used — gen_evalset() builds
    # its own — so the dead construction was removed.)
    if args.eval_kernel not in ('rbf', 'matern', 'periodic'):
        raise ValueError(f'Invalid kernel {args.eval_kernel}')

    path = osp.join(evalsets_path, 'gp')
    filename = f'{args.eval_kernel}'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset(args)

    eval_batches = torch.load(osp.join(path, filename))

    # Fixed seed so sample-based likelihood estimates are comparable.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    ravg = RunningAverage()
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_batches):
            for key, val in batch.items():
                batch[key] = val.cuda()
            outs = model(batch, num_samples=args.eval_num_samples)
            for key, val in outs.items():
                ravg.update(key, val)

    # Restore nondeterministic seeding.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    line = f'{args.model}:{args.expid} {args.eval_kernel} '
    if args.t_noise is not None:
        line += f'tn {args.t_noise} '
    line += ravg.info()

    if logger is not None:
        logger.info(line)

    return line
244 |
def plot(args, model):
    """Sample a fresh GP batch and plot predictive means/intervals per task.

    Loads the checkpoint from args.root, draws args.plot_batch_size tasks,
    and plots mean +/- one std on a dense grid over [-2, 2], with context
    (black) and target (orchid) points overlaid. Multi-sample models get one
    translucent curve per latent sample.
    """
    ckpt = torch.load(os.path.join(args.root, 'ckpt.tar'))
    model.load_state_dict(ckpt.model)

    def tnp(x):
        # Tensor -> squeezed numpy array on CPU, for matplotlib.
        return x.squeeze().cpu().data.numpy()

    if args.plot_seed is not None:
        torch.manual_seed(args.plot_seed)
        torch.cuda.manual_seed(args.plot_seed)

    # Fix: args.pp is not guaranteed to be registered on the CLI, so read it
    # defensively; None (the default) selects the RBF kernel, otherwise a
    # periodic kernel with period pp is used.
    pp = getattr(args, 'pp', None)
    kernel = RBFKernel() if pp is None else PeriodicKernel(p=pp)
    sampler = GPSampler(kernel, t_noise=args.t_noise)

    xp = torch.linspace(-2, 2, 200).cuda()
    batch = sampler.sample(
        batch_size=args.plot_batch_size,
        max_num_points=args.max_num_points,
        num_ctx=args.plot_num_ctx,
        device='cuda')

    model.eval()
    with torch.no_grad():
        outs = model(batch, num_samples=args.eval_num_samples)
        print(f'ctx_ll {outs.ctx_ll.item():.4f}, tar_ll {outs.tar_ll.item():.4f}')

        # Predict on the shared grid, broadcast to every task in the batch.
        py = model.predict(batch.xc, batch.yc,
                xp[None,:,None].repeat(args.plot_batch_size, 1, 1),
                num_samples=args.plot_num_samples)
        mu, sigma = py.mean.squeeze(0), py.scale.squeeze(0)

    if args.plot_batch_size > 1:
        nrows = max(args.plot_batch_size//4, 1)
        ncols = min(4, args.plot_batch_size)
        fig, axes = plt.subplots(nrows, ncols,
                figsize=(5*ncols, 5*nrows))
        axes = axes.flatten()
    else:
        fig = plt.figure(figsize=(5, 5))
        axes = [plt.gca()]

    # Multi-sample predictions keep a leading sample axis (dim 4):
    # draw each sample as a faint curve/band instead of averaging.
    if mu.dim() == 4:
        for i, ax in enumerate(axes):
            for s in range(mu.shape[0]):
                ax.plot(tnp(xp), tnp(mu[s][i]), color='steelblue',
                        alpha=max(0.5/args.plot_num_samples, 0.1))
                ax.fill_between(tnp(xp), tnp(mu[s][i])-tnp(sigma[s][i]),
                        tnp(mu[s][i])+tnp(sigma[s][i]),
                        color='skyblue',
                        alpha=max(0.2/args.plot_num_samples, 0.02),
                        linewidth=0.0)
            # zorder keeps the data points above all sample curves.
            ax.scatter(tnp(batch.xc[i]), tnp(batch.yc[i]),
                    color='k', label='context', zorder=mu.shape[0]+1)
            ax.scatter(tnp(batch.xt[i]), tnp(batch.yt[i]),
                    color='orchid', label='target',
                    zorder=mu.shape[0]+1)
            ax.legend()
    else:
        for i, ax in enumerate(axes):
            ax.plot(tnp(xp), tnp(mu[i]), color='steelblue', alpha=0.5)
            ax.fill_between(tnp(xp), tnp(mu[i]-sigma[i]), tnp(mu[i]+sigma[i]),
                    color='skyblue', alpha=0.2, linewidth=0.0)
            ax.scatter(tnp(batch.xc[i]), tnp(batch.yc[i]),
                    color='k', label='context')
            ax.scatter(tnp(batch.xt[i]), tnp(batch.yt[i]),
                    color='orchid', label='target')
            ax.legend()

    plt.tight_layout()
    plt.show()
323 |
def ensemble(args, model):
    """Evaluate a five-member deep ensemble (run1..run5) on the cached GP
    evaluation set and log ensemble log-likelihoods.

    `model` is used only as a template: it is deep-copied per run and each
    copy loads that run's checkpoint. Per-batch lls from all members are
    combined with logmeanexp before averaging.
    """
    num_runs = 5
    models = []
    for i in range(num_runs):
        model_ = deepcopy(model)
        ckpt = torch.load(osp.join(results_path, 'gp', args.model, f'run{i+1}', 'ckpt.tar'))
        model_.load_state_dict(ckpt['model'])
        model_.cuda()
        model_.eval()
        models.append(model_)

    path = osp.join(evalsets_path, 'gp')
    filename = f'{args.eval_kernel}'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset(args)

    eval_batches = torch.load(osp.join(path, filename))

    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    ravg = RunningAverage()
    with torch.no_grad():
        for batch in tqdm(eval_batches):
            for key, val in batch.items():
                batch[key] = val.cuda()

            ctx_ll = []
            tar_ll = []
            # Fix: the loop variable was named `model`, shadowing the
            # function parameter (the lotka_volterra ensemble uses `model_`).
            for model_ in models:
                outs = model_(batch,
                        num_samples=args.eval_num_samples,
                        reduce_ll=False)
                ctx_ll.append(outs.ctx_ll)
                tar_ll.append(outs.tar_ll)

            # 2-D per-member lls get a new ensemble axis via stack; otherwise
            # the existing sample axis is extended via cat (dim check kept
            # from the original — presumably deterministic vs. latent models).
            if ctx_ll[0].dim() == 2:
                ctx_ll = torch.stack(ctx_ll)
                tar_ll = torch.stack(tar_ll)
            else:
                ctx_ll = torch.cat(ctx_ll)
                tar_ll = torch.cat(tar_ll)

            ctx_ll = logmeanexp(ctx_ll).mean()
            tar_ll = logmeanexp(tar_ll).mean()

            ravg.update('ctx_ll', ctx_ll)
            ravg.update('tar_ll', tar_ll)

    # Restore nondeterministic seeding after evaluation — consistent with
    # eval() above and the lotka_volterra ensemble (previously missing here).
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    filename = f'ensemble_{args.eval_kernel}'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.log'
    logger = get_logger(osp.join(results_path, 'gp', args.model, filename), mode='w')
    logger.info(ravg.info())
383 |
# Script entry point: main() parses the CLI arguments and dispatches on --mode.
if __name__ == '__main__':
    main()
386 |
--------------------------------------------------------------------------------