├── .gitignore ├── README.md ├── argument.py ├── configs ├── synthetic │ ├── config_test.yaml │ └── config_train.yaml └── weather │ ├── config_test.yaml │ └── config_train.yaml ├── data ├── generate_synthetic_data.py ├── synthetic │ ├── sine,tanh,sigmoid,gaussian_N1000_n200.png │ └── sine,tanh,sigmoid,gaussian_N1000_n200.pth └── weather │ └── weather_266cities_12tasks_258days.pth ├── dataset ├── __init__.py ├── synthetic.py ├── utils.py └── weather.py ├── main.py ├── model ├── __init__.py ├── attention.py ├── mlp.py ├── module.py ├── mtnp.py └── utils.py ├── requirments.txt ├── test.py └── train ├── __init__.py ├── loss.py ├── schedulers.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *.ipynb_checkpoints* 3 | *label_masks 4 | *point_permutations 5 | experiments 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Task Neural Processes 2 | 3 | This repository contains a PyTorch implementation of [Multi-Task Neural Processes](https://arxiv.org/abs/2110.14953). 
4 | 5 | 6 | ## Basic Usage 7 | 8 | ### Training with missing rate 0.5 9 | ``` 10 | python main.py --model [mtp/stp/jtp/mtp_s] --data [synthetic/weather] --gamma_train 0.5 11 | ``` 12 | 13 | ### Testing with missing rate 0.5 and context size 10 14 | ``` 15 | python test.py --eval_name [mtp/stp/jtp/mtp_s] --data [synthetic/weather] --gamma 0.5 --cs 10 16 | ``` 17 | 18 | ## Citation 19 | If you find this work useful, please consider citing: 20 | ```bib 21 | @inproceedings{kim2021multi, 22 | title={Multi-Task Processes}, 23 | author={Kim, Donggyun and Cho, Seongwoong and Lee, Wonkwang and Hong, Seunghoon}, 24 | booktitle={International Conference on Learning Representations}, 25 | year={2021} 26 | } 27 | ``` 28 | -------------------------------------------------------------------------------- /argument.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(v): 5 | if v == 'True' or v == 'true': return True 6 | elif v == 'False' or v == 'false': return False 7 | else: raise argparse.ArgumentTypeError('Boolean value expected.') 8 | 9 | 10 | DATASETS = ['synthetic', 'weather'] 11 | MODELS = ['stp', 'jtp', 'mtp', 'mtp_s'] 12 | LR_SCHEDULES = ['constant', 'sqroot', 'cos', 'poly'] 13 | BETA_SCHEDULES = ['constant', 'linear_warmup'] 14 | 15 | # argument parser 16 | parser = argparse.ArgumentParser() 17 | 18 | # basic arguments 19 | parser.add_argument('--data', type=str, default='synthetic', choices=DATASETS) 20 | parser.add_argument('--model', type=str, default='mtp_s', choices=MODELS) 21 | parser.add_argument('--seed', type=int, default=0) 22 | parser.add_argument('--name', type=str, default='') 23 | parser.add_argument('--log_root', type=str, default='experiments') 24 | parser.add_argument('--imputer_path', type=str, default='') 25 | parser.add_argument('--name_postfix', '-ptf', type=str, default='') 26 | parser.add_argument('--debug_mode', '-debug', default=False, action='store_true') 27 | 28 | # 
model-specific arguments 29 | parser.add_argument('--dim_hidden', type=int, default=-1) 30 | parser.add_argument('--pma', type=str2bool, default=True) 31 | 32 | # training arguments 33 | parser.add_argument('--n_steps', type=int, default=-1) 34 | parser.add_argument('--global_batch_size', type=int, default=-1) 35 | parser.add_argument('--lr', type=float, default=-1.) 36 | parser.add_argument('--lr_schedule', '-lrs', type=str, default='', choices=LR_SCHEDULES) 37 | parser.add_argument('--beta_T_schedule', '-bts', type=str, default='', choices=BETA_SCHEDULES) 38 | parser.add_argument('--beta_G_schedule', '-bgs', type=str, default='', choices=BETA_SCHEDULES) 39 | parser.add_argument('--gamma_train', '-gtr', type=float, default=-1) 40 | parser.add_argument('--gamma_valid', '-gvl', type=float, default=-1) 41 | parser.add_argument('--cs_range_train', '-csr', nargs='+', default=[]) 42 | parser.add_argument('--ts_train', '-ts', type=int, default=-1) 43 | 44 | args = parser.parse_args() -------------------------------------------------------------------------------- /configs/synthetic/config_test.yaml: -------------------------------------------------------------------------------- 1 | ### data configs 2 | 3 | data: 'synthetic' 4 | data_path: 'data/synthetic/sine,tanh,sigmoid,gaussian_N1000_n200.pth' 5 | tasks: ['sine', 'tanh', 'sigmoid', 'gaussian'] # list of task names 6 | task_types: {'sine': 'continuous', 'tanh': 'continuous', 'sigmoid': 'continuous', 'gaussian': 'continuous'} # whether each task is continuous or discrete 7 | 8 | dim_x: 1 # input dimension 9 | dim_ys: {'sine': 1, 'tanh': 1, 'sigmoid': 1, 'gaussian': 1} # output dimensions or channels 10 | 11 | split_ratio: [0.8, 0.1, 0.1] 12 | num_workers: 4 13 | 14 | colors: {'sine': 'r', 'tanh': 'g', 'sigmoid': 'b', 'gaussian': 'c'} 15 | 16 | global_batch_size: 4 # number of datasets (multi-task functions) in a batch 17 | 18 | ### validation configs 19 | 20 | ts_valid: 200 21 | 22 | ### test configs 23 | 24 | ts_test: 
1000 25 | ns_G: 5 # number of global sampling (for JTP, MTP) 26 | ns_T: 5 # number of per-task samplings (for STP, MTP) 27 | 28 | imputer_path: 'experiments/runs_synthetic/stp/checkpoints/best_error.pth' # imputer checkpoint path (for JTP) 29 | 30 | ### checkpointing configs 31 | 32 | eval_dir: 'runs_synthetic' # directory where the models to evaluate are stored -------------------------------------------------------------------------------- /configs/synthetic/config_train.yaml: -------------------------------------------------------------------------------- 1 | ### data configs 2 | 3 | data: 'synthetic' 4 | data_path: 'data/synthetic/sine,tanh,sigmoid,gaussian_N1000_n200.pth' 5 | tasks: ['sine', 'tanh', 'sigmoid', 'gaussian'] # list of task names 6 | task_types: {'sine': 'continuous', 'tanh': 'continuous', 'sigmoid': 'continuous', 'gaussian': 'continuous'} # whether each task is continuous or discrete 7 | 8 | dim_x: 1 # input dimension 9 | dim_ys: {'sine': 1, 'tanh': 1, 'sigmoid': 1, 'gaussian': 1} # output dimensions or channels 10 | 11 | split_ratio: [0.8, 0.1, 0.1] 12 | num_workers: 4 13 | 14 | colors: {'sine': 'r', 'tanh': 'g', 'sigmoid': 'b', 'gaussian': 'c'} 15 | 16 | 17 | ### training configs 18 | 19 | n_steps: 300000 # total training steps 20 | global_batch_size: 24 # number of datasets (multi-task functions) in a batch 21 | 22 | lr: 0.00025 # learning rate 23 | lr_schedule: 'sqroot' 24 | lr_warmup: 1000 25 | 26 | beta_G: 1 # beta coefficient for global kld 27 | beta_G_schedule: 'linear_warmup' 28 | beta_G_warmup: 10000 29 | 30 | beta_T: 1 # beta coefficient for per-task klds 31 | beta_T_schedule: 'linear_warmup' 32 | beta_T_warmup: 10000 33 | 34 | gamma_train: 0.5 # missing rate 35 | cs_range_train: [5, 20] # context size, null means default range (len(tasks), ts // 2) 36 | ts_train: 200 # target size 37 | 38 | 39 | ### validation configs 40 | 41 | cs_valid: 10 42 | ts_valid: 200 43 | ns_G: 5 # number of global sampling 44 | ns_T: 5 # number of per-task 
samplings 45 | 46 | gamma_valid: 0.5 # missing rate 47 | imputer_path: 'experiments/runs_synthetic/stp/checkpoints/best_error.pth' # imputer checkpoint path (for JTP) 48 | 49 | 50 | ### model configs 51 | 52 | dim_hidden: 128 # width of the networks, serves as a basic unit in all layers except the input & output heads (and also the latent dimensions) 53 | module_sizes: [3, 3, 2, 5] # depth of the networks: (element-wise encoder, intra-task attention, inter-task attention, element-wise decoder) 54 | pma: True # whether to use PMA pooling rather than average pooling 55 | 56 | attn_config: 57 | act_fn: 'gelu' 58 | ln: True # layernorm in attentions and mlps 59 | dr: 0.1 # dropout in mlps 60 | n_heads: 4 # number of attention heads 61 | epsilon: 0.1 # minimum standard deviation for Normal latent variables 62 | 63 | 64 | ### logging configs 65 | 66 | log_iter: 100 # interval between tqdm and tensorboard logging of training metrics 67 | val_iter: 5000 # interval between validation and tensorboard logging of validation metrics 68 | save_iter: 5000 # interval between checkpointing 69 | log_dir: 'runs_synthetic' # directory for saving checkpoints and logs 70 | -------------------------------------------------------------------------------- /configs/weather/config_test.yaml: -------------------------------------------------------------------------------- 1 | ### data configs 2 | 3 | data: 'weather' 4 | data_path: 'data/weather/weather_266cities_12tasks_258days.pth' 5 | tasks: ['tMin_Global', 'tMax_Global', 'humidity_Global', 'precip_Global', 'cloud_Global', 'dew_Global'] # list of task names 6 | task_types: {'tMin_Global': 'continuous', 'tMax_Global': 'continuous', 'humidity_Global': 'continuous', 'precip_Global': 'continuous', 7 | 'cloud_Global': 'continuous', 'wind_Global': 'continuous', 'dew_Global': 'continuous'} # whether each task is continuous or discrete 8 | 9 | dim_x: 1 # input dimension 10 | dim_ys: {'tMin_Global': 1, 'tMax_Global': 1, 'humidity_Global': 1, 
'precip_Global': 1, 'cloud_Global': 1, 11 | 'dew_Global': 1, 'wind_Global': 1, 'pressure_Global': 1, 'ozone_Global': 1, 'uv_Global': 1} # output dimensions or channels 12 | 13 | split_ratio: [200, 30, 36] 14 | num_workers: 4 15 | 16 | colors: {'tMin_Global': 'r', 'tMax_Global': 'g', 'humidity_Global': 'b', 'precip_Global': 'c', 'cloud_Global': 'm', 17 | 'dew_Global': 'y', 'wind_Global': 'tab:purple', 'pressure_Global': 'tab:orange', 'ozone_Global': 'tab:brown', 'uv_Global': 'tab:pink'} 18 | 19 | global_batch_size: 16 # number of datasets (multi-task functions) in a batch 20 | 21 | ### validation configs 22 | 23 | ts_valid: 258 24 | 25 | ### test configs 26 | 27 | ts_test: 258 28 | ns_G: 5 # number of global sampling (for JTP, MTP) 29 | ns_T: 5 # number of per-task samplings (for STP, MTP) 30 | 31 | imputer_path: 'experiments/runs_weather/stp/checkpoints/best_error.pth' # imputer checkpoint path (for JTP) 32 | 33 | ### checkpointing configs 34 | 35 | eval_dir: 'runs_weather' # directory where the models to evaluate are stored -------------------------------------------------------------------------------- /configs/weather/config_train.yaml: -------------------------------------------------------------------------------- 1 | ### data configs 2 | 3 | data: 'weather' 4 | data_path: 'data/weather/weather_266cities_12tasks_258days.pth' 5 | tasks: ['tMin_Global', 'tMax_Global', 'humidity_Global', 'precip_Global', 'cloud_Global', 'dew_Global'] # list of task names 6 | task_types: {'tMin_Global': 'continuous', 'tMax_Global': 'continuous', 'humidity_Global': 'continuous', 'precip_Global': 'continuous', 7 | 'cloud_Global': 'continuous', 'wind_Global': 'continuous', 'dew_Global': 'continuous'} # whether each task is continuous or discrete 8 | 9 | dim_x: 1 # input dimension 10 | dim_ys: {'tMin_Global': 1, 'tMax_Global': 1, 'humidity_Global': 1, 'precip_Global': 1, 'cloud_Global': 1, 11 | 'dew_Global': 1, 'wind_Global': 1, 'pressure_Global': 1, 'ozone_Global': 1, 'uv_Global': 1} 
# output dimensions or channels 12 | 13 | split_ratio: [200, 30, 36] 14 | num_workers: 4 15 | 16 | colors: {'tMin_Global': 'r', 'tMax_Global': 'g', 'humidity_Global': 'b', 'precip_Global': 'c', 'cloud_Global': 'm', 17 | 'dew_Global': 'y', 'wind_Global': 'tab:purple', 'pressure_Global': 'tab:orange', 'ozone_Global': 'tab:brown', 'uv_Global': 'tab:pink'} 18 | 19 | 20 | ### training configs 21 | 22 | n_steps: 50000 # total training steps 23 | global_batch_size: 16 # number of datasets (multi-task functions) in a batch 24 | 25 | lr: 0.0001 # learning rate 26 | lr_schedule: 'sqroot' 27 | lr_warmup: 1000 28 | 29 | beta_G: 1 # beta coefficient for global kld 30 | beta_G_schedule: 'linear_warmup' 31 | beta_G_warmup: 10000 32 | 33 | beta_T: 1 # beta coefficient for per-task klds 34 | beta_T_schedule: 'linear_warmup' 35 | beta_T_warmup: 10000 36 | 37 | gamma_train: 0.5 # missing rate 38 | cs_range_train: [10, 30] # context size, null means default range (len(tasks), ts // 2) 39 | ts_train: 200 # target size 40 | 41 | 42 | ### validation configs 43 | 44 | cs_valid: 20 45 | ts_valid: 258 46 | ns_G: 5 # number of global sampling 47 | ns_T: 5 # number of per-task samplings 48 | 49 | gamma_valid: 0.5 # missing rate 50 | imputer_path: 'experiments/runs_weather/stp/checkpoints/best_error.pth' # imputer checkpoint path (for JTP) 51 | 52 | 53 | ### model configs 54 | 55 | dim_hidden: 64 # width of the networks, serves as a basic unit in all layers except the input & output heads (and also the latent dimensions) 56 | module_sizes: [3, 3, 2, 5] # depth of the networks: (element-wise encoder, intra-task attention, inter-task attention, element-wise decoder) 57 | pma: True # whether to use PMA pooling rather than average pooling 58 | 59 | attn_config: 60 | act_fn: 'gelu' 61 | ln: True # layernorm in attentions and mlps 62 | dr: 0.1 # dropout in mlps 63 | n_heads: 4 # number of attention heads 64 | epsilon: 0.1 # minimum standard deviation for Normal latent variables 65 | 66 | 67 | ### 
logging configs 68 | 69 | log_iter: 100 # interval between tqdm and tensorboard logging of training metrics 70 | val_iter: 1000 # interval between validation and tensorboard logging of validation metrics 71 | save_iter: 1000 # interval between checkpointing 72 | log_dir: 'runs_weather' # directory for saving checkpoints and logs 73 | -------------------------------------------------------------------------------- /data/generate_synthetic_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import torch 4 | 5 | 6 | tasks = ['sine', 'tanh', 'sigmoid', 'gaussian'] 7 | activations = { 8 | 'sine': lambda x: torch.sin(x), 9 | 'tanh': lambda x: torch.tanh(x), 10 | 'sigmoid': lambda x: torch.sigmoid(x), 11 | 'gaussian': lambda x: torch.exp(-x.pow(2)) 12 | } 13 | colors = { 14 | 'sine': 'r', 15 | 'tanh': 'g', 16 | 'sigmoid': 'b', 17 | 'gaussian': 'c' 18 | } 19 | 20 | def generate_data(n_datasets, n_examples, task_noise=False, independent=False): 21 | meta_info = {} 22 | X = [] 23 | Y = {task: [] for task in tasks} 24 | for dataset in range(n_datasets): 25 | meta_info[dataset] = {} 26 | 27 | x = 5*(torch.rand(n_examples, 1)*2 - 1) # -5 to +5 28 | X.append(x) 29 | 30 | if not independent: 31 | a = torch.exp(torch.rand(1, 1) - 0.5) # e^-0.5 to e^0.5 32 | w = torch.exp(torch.rand(1, 1) - 0.5) # e^-0.5 to e^0.5 33 | b = 4*torch.rand(1, 1) - 2 # -2 to 2 34 | c = 4*torch.rand(1, 1) - 2 # -2 to 2 35 | 36 | if not task_noise: 37 | meta_info[dataset]['a'] = a 38 | meta_info[dataset]['w'] = w 39 | meta_info[dataset]['b'] = b 40 | meta_info[dataset]['c'] = c 41 | 42 | for task in tasks: 43 | if task_noise: 44 | a_ = a * torch.exp(torch.randn(1, 1) * 0.1) 45 | w_ = w * torch.exp(torch.randn(1, 1) * 0.1) 46 | b_ = b + torch.randn(1, 1) * 0.2 47 | c_ = c + torch.randn(1, 1) * 0.2 48 | meta_info[dataset][task] = {'a': a_, 'w': w_, 'b': b_, 'c': c_} 49 | elif independent: 50 | a_ = torch.exp(torch.rand(1, 1) - 0.5) # 
e^-0.5 to e^0.5 51 | w_ = torch.exp(torch.rand(1, 1) - 0.5) # e^-0.5 to e^0.5 52 | b_ = 4*torch.rand(1, 1) - 2 # -2 to 2 53 | c_ = 4*torch.rand(1, 1) - 2 # -2 to 2 54 | meta_info[dataset][task] = {'a': a_, 'w': w_, 'b': b_, 'c': c_} 55 | else: 56 | a_, w_, b_, c_ = a, w, b, c 57 | 58 | y = a_ * activations[task](w_ * x + b_) + c_ 59 | Y[task].append(y) 60 | 61 | for dataset in range(n_datasets): 62 | ids = torch.randperm(len(X[dataset])) 63 | X[dataset] = X[dataset][ids] 64 | for task in tasks: 65 | Y[task][dataset] = Y[task][dataset][ids] 66 | 67 | X = torch.stack(X) 68 | Y = {task: torch.stack(Y[task]) for task in tasks} 69 | 70 | return X, Y, meta_info 71 | 72 | 73 | if __name__ == '__main__': 74 | import argparse 75 | import matplotlib.pyplot as plt 76 | 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('--n_datasets', type=int, default=1000) 79 | parser.add_argument('--n_examples', type=int, default=200) 80 | parser.add_argument('--task_noise', '-tn', default=False, action='store_true') 81 | parser.add_argument('--independent', '-ind', default=False, action='store_true') 82 | args = parser.parse_args() 83 | 84 | X, Y, meta_info = generate_data(args.n_datasets, args.n_examples, args.task_noise, args.independent) 85 | 86 | if args.task_noise: 87 | data_dir = 'synthetic_tn' 88 | elif args.independent: 89 | data_dir = 'synthetic_ind' 90 | else: 91 | data_dir = 'synthetic' 92 | name = ','.join(tasks) + f'_N{args.n_datasets}_n{args.n_examples}' 93 | os.makedirs(data_dir, exist_ok=True) 94 | torch.save((X, Y, meta_info), os.path.join(data_dir, f'{name}.pth')) 95 | 96 | plt.figure(figsize=(40, 12)) 97 | for dataset in range(10): 98 | plt.subplot(2, 5, dataset+1) 99 | x = torch.linspace(-5, 5, args.n_examples).unsqueeze(1) 100 | 101 | for task in Y: 102 | if args.task_noise or args.independent: 103 | a = meta_info[dataset][task]['a'] 104 | w = meta_info[dataset][task]['w'] 105 | b = meta_info[dataset][task]['b'] 106 | c = meta_info[dataset][task]['c'] 
107 | else: 108 | a = meta_info[dataset]['a'] 109 | w = meta_info[dataset]['w'] 110 | b = meta_info[dataset]['b'] 111 | c = meta_info[dataset]['c'] 112 | 113 | y = a * activations[task](w * x + b) + c 114 | plt.plot(x, y, color=colors[task]) 115 | 116 | plt.savefig(os.path.join(data_dir, f'{name}.png')) 117 | -------------------------------------------------------------------------------- /data/synthetic/sine,tanh,sigmoid,gaussian_N1000_n200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/multi_task_neural_processes/a76216528249622ae1a1bf42c0092aaae5bde344/data/synthetic/sine,tanh,sigmoid,gaussian_N1000_n200.png -------------------------------------------------------------------------------- /data/synthetic/sine,tanh,sigmoid,gaussian_N1000_n200.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/multi_task_neural_processes/a76216528249622ae1a1bf42c0092aaae5bde344/data/synthetic/sine,tanh,sigmoid,gaussian_N1000_n200.pth -------------------------------------------------------------------------------- /data/weather/weather_266cities_12tasks_258days.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/multi_task_neural_processes/a76216528249622ae1a1bf42c0092aaae5bde344/data/weather/weather_266cities_12tasks_258days.pth -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from .utils import get_data_iterator, TrainCollator, to_device 4 | from .synthetic import SyntheticTrainDataset, SyntheticTNTestDataset 5 | from .weather import WeatherTrainDataset, WeatherTestDataset 6 | 7 | 8 | def load_data(config, device, split='trainval'): 9 | ''' 10 | Load train & valid or 
test data and return the iterator & loader. 11 | ''' 12 | if config.data == 'synthetic': 13 | TrainDataset, TestDataset = SyntheticTrainDataset, SyntheticTNTestDataset 14 | elif config.data == 'weather': 15 | TrainDataset, TestDataset = WeatherTrainDataset, WeatherTestDataset 16 | else: 17 | raise NotImplementedError 18 | 19 | 20 | # load train iterator 21 | if split == 'train' or split == 'trainval': 22 | train_data = TrainDataset(config.data_path, 'train', config.tasks, config.split_ratio, config.ts_train) 23 | train_collator = TrainCollator(config.ts_train, config.cs_range_train, config.gamma_train, config.tasks) 24 | train_loader = DataLoader(train_data, batch_size=config.global_batch_size, 25 | shuffle=True, pin_memory=(device.type == 'cuda'), 26 | drop_last=True, num_workers=config.num_workers, collate_fn=train_collator) 27 | train_iterator = get_data_iterator(train_loader, device) 28 | 29 | # load valid loader 30 | if split == 'valid' or split == 'trainval': 31 | valid_data = TestDataset(config.data_path, 'valid', config.tasks, config.split_ratio, 32 | config.ts_valid, config.cs_valid, config.gamma_valid, config.seed) 33 | valid_loader = DataLoader(valid_data, batch_size=config.global_batch_size, 34 | shuffle=False, pin_memory=(device.type == 'cuda'), 35 | drop_last=False, num_workers=config.num_workers) 36 | 37 | # load test loader 38 | if split == 'test': 39 | test_data = TestDataset(config.data_path, 'test', config.tasks, config.split_ratio, 40 | config.ts_test, config.cs_test, config.gamma_test, config.seed) 41 | test_loader = DataLoader(test_data, batch_size=config.global_batch_size, 42 | shuffle=False, pin_memory=(device.type == 'cuda'), 43 | drop_last=False, num_workers=config.num_workers) 44 | 45 | # return 46 | if split == 'trainval': 47 | return train_iterator, valid_loader 48 | elif split == 'train': 49 | return train_iterator 50 | elif split == 'valid': 51 | return valid_loader 52 | elif split == 'test': 53 | return test_loader 54 | else: 55 | 
raise NotImplementedError 56 | -------------------------------------------------------------------------------- /dataset/synthetic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from torch.distributions import Bernoulli 6 | 7 | from .utils import mask_labels 8 | 9 | 10 | activations = { 11 | 'sine': lambda x: torch.sin(x), 12 | 'tanh': lambda x: torch.tanh(x), 13 | 'sigmoid': lambda x: torch.sigmoid(x), 14 | 'gaussian': lambda x: torch.exp(-x.pow(2)) 15 | } 16 | 17 | 18 | class SyntheticDataset(Dataset): 19 | def __init__(self, data_path, split, tasks, split_ratio, target_size): 20 | assert len(split_ratio) == 3 21 | self.tasks = tasks 22 | # parse data root 23 | if os.path.isdir(data_path): 24 | self.root = data_path 25 | else: 26 | self.root = os.path.split(data_path)[0] 27 | 28 | self.n_functions = 1000 29 | self.n_points = 200 30 | self.target_size = target_size 31 | assert self.target_size <= self.n_points or split == 'test' 32 | 33 | # split function indices 34 | function_idxs = list(range(self.n_functions)) 35 | if type(split_ratio[0]) is float: 36 | cut1 = int(len(function_idxs)*split_ratio[0]) 37 | cut2 = cut1 + int(len(function_idxs)*split_ratio[1]) 38 | cut3 = cut2 + int(len(function_idxs)*split_ratio[2]) 39 | else: 40 | cut1 = split_ratio[0] 41 | cut2 = cut1 + split_ratio[1] 42 | cut3 = cut2 + split_ratio[2] 43 | 44 | # split functions 45 | if split == 'train': 46 | self.function_idxs = function_idxs[:cut1] 47 | elif split == 'valid': 48 | self.function_idxs = function_idxs[cut1:cut2] 49 | elif split == 'test': 50 | self.function_idxs = function_idxs[cut2:cut3] 51 | else: 52 | raise NotImplementedError 53 | 54 | # load data 55 | X, Y, meta_info = torch.load(data_path) 56 | self.X = X[self.function_idxs] 57 | self.Y = {task: Y[task][self.function_idxs] for task in tasks} 58 | self.meta_info = meta_info 59 | 60 | def __len__(self): 61 | 
return len(self.function_idxs) 62 | 63 | 64 | class SyntheticTrainDataset(SyntheticDataset): 65 | def __getitem__(self, idx): 66 | ''' 67 | Returns complete target. 68 | ''' 69 | # generate point permutation 70 | pp = torch.randperm(self.n_points)[:self.target_size] 71 | # pp = torch.randperm(self.n_points) 72 | 73 | # get permuted data 74 | X_D = self.X[idx][pp].clone() 75 | Y_D = {task: self.Y[task][idx][pp].clone() for task in self.tasks} 76 | 77 | return X_D, Y_D 78 | 79 | 80 | class SyntheticTestDataset(SyntheticDataset): 81 | def __init__(self, data_path, split, tasks, split_ratio, target_size, 82 | context_size, gamma, seed): 83 | ''' 84 | Load or generate random objects. 85 | ''' 86 | super().__init__(data_path, split, tasks, split_ratio, target_size) 87 | self.tasks = tasks 88 | self.context_size = context_size 89 | 90 | # generate or load point permutations of size (n_functions, n_points) 91 | os.makedirs(os.path.join(self.root, 'point_permutations'), exist_ok=True) 92 | pp_path = os.path.join(self.root, 'point_permutations', 'pp_seed{}_{}.pth'.format(seed, split)) 93 | if os.path.exists(pp_path): 94 | self.pp = torch.load(pp_path) 95 | else: 96 | self.pp = [torch.randperm(self.n_points) for _ in range(len(self.function_idxs))] 97 | torch.save(self.pp, pp_path) 98 | 99 | # generate label mask of size (n_functions, n_points, n_tasks) with missing rate gamma 100 | self.gamma = gamma 101 | if self.gamma > 0: 102 | os.makedirs(os.path.join(self.root, 'label_masks'), exist_ok=True) 103 | mask_path = os.path.join(self.root, 'label_masks', 104 | f'mask_T{len(self.tasks)}_gamma{gamma}_seed{seed}_{split}.pth') 105 | if os.path.exists(mask_path): 106 | self.mask = torch.load(mask_path) 107 | else: 108 | self.mask = Bernoulli(torch.tensor(gamma)).sample((len(self.function_idxs), 109 | self.n_points, len(self.tasks))).bool() 110 | self.mask[:, :len(self.tasks)] *= ~torch.eye(len(self.tasks)).bool() # guarantee at least one label per task 111 | torch.save(self.mask, 
mask_path) 112 | else: 113 | self.mask = None 114 | 115 | # generate target data on a uniform grid 116 | self.generate_target_data(target_size, tasks) 117 | 118 | def generate_target_data(self, target_size, tasks): 119 | self.X_D = torch.stack([torch.linspace(-5, 5, target_size).unsqueeze(1) for _ in self.function_idxs]) 120 | self.Y_D = { 121 | task: torch.stack([ 122 | self.meta_info[function_idx]['a'] * \ 123 | activations[task](self.meta_info[function_idx]['w']*self.X_D[i] + self.meta_info[function_idx]['b']) + \ 124 | self.meta_info[function_idx]['c'] 125 | for i, function_idx in enumerate(self.function_idxs) 126 | ]) 127 | for task in tasks 128 | } 129 | 130 | def __getitem__(self, idx): 131 | ''' 132 | Returns incomplete context, complete target, and complete context labels. 133 | ''' 134 | # pick context 135 | context_idxs = self.pp[idx][:self.context_size] 136 | 137 | X_C = self.X[idx][context_idxs].clone() 138 | Y_C = {task: self.Y[task][idx][context_idxs].clone() for task in self.tasks} 139 | 140 | # predict all points as target 141 | X_D = self.X_D[idx].clone() 142 | Y_D = {task: self.Y_D[task][idx].clone() for task in self.tasks} 143 | 144 | # random drop with mask 145 | Y_C_comp = {task: Y_C[task].clone() for task in Y_C} # copy unmasked context labels 146 | if self.gamma > 0: 147 | mask_labels(Y_C, self.mask[idx, :self.context_size], self.tasks) 148 | 149 | # gt parameters 150 | gt_params = self.meta_info[self.function_idxs[idx]] 151 | 152 | return X_C, Y_C, X_D, Y_D, Y_C_comp, gt_params 153 | 154 | 155 | class SyntheticTNTestDataset(SyntheticTestDataset): 156 | def generate_target_data(self, target_size, tasks): 157 | self.X_D = torch.stack([torch.linspace(-5, 5, target_size).unsqueeze(1) for _ in self.function_idxs]) 158 | self.Y_D = { 159 | task: torch.stack([ 160 | self.meta_info[function_idx][task]['a'] * \ 161 | activations[task](self.meta_info[function_idx][task]['w']*self.X_D[i] + \ 162 | self.meta_info[function_idx][task]['b']) + \ 163 | 
self.meta_info[function_idx][task]['c'] 164 | for i, function_idx in enumerate(self.function_idxs) 165 | ]) 166 | for task in tasks 167 | } -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.distributions import Bernoulli 5 | 6 | 7 | class TrainCollator: 8 | def __init__(self, target_size, context_size_range, gamma, tasks): 9 | self.target_size = target_size 10 | self.context_size_range = context_size_range 11 | self.gamma = gamma 12 | self.tasks = tasks 13 | 14 | def __call__(self, batch): 15 | X_D, Y_D = zip(*batch) 16 | X_D = torch.stack(X_D) 17 | Y_D = {task: torch.stack([Y_D_i[task] for Y_D_i in Y_D]) for task in Y_D[0]} 18 | batch_size = len(X_D) 19 | 20 | # sample context size 21 | if self.context_size_range is None: 22 | context_size = len(Y_D) + np.random.choice(self.target_size//2 - len(Y_D)) 23 | else: 24 | if self.context_size_range[0] == self.context_size_range[1]: 25 | context_size = self.context_size_range[0] 26 | else: 27 | context_size = self.context_size_range[0] + \ 28 | np.random.choice(self.context_size_range[1] - self.context_size_range[0]) 29 | 30 | # sample context 31 | context_idxs = torch.randperm(self.target_size)[:context_size] 32 | X_C = X_D[:, context_idxs].clone() 33 | Y_C = {task: Y_D[task][:, context_idxs].clone() for task in Y_D} 34 | 35 | # simuate incompleteness 36 | if self.gamma > 0: 37 | # generate mask 38 | mask = Bernoulli(torch.tensor(self.gamma)).sample((batch_size, context_size, len(self.tasks))).bool() 39 | mask[:, :len(self.tasks)] *= ~torch.eye(len(self.tasks)).bool() # guarantee at least one label per task 40 | 41 | # mask labels 42 | mask_labels(Y_C, mask, self.tasks) 43 | 44 | return X_C, Y_C, X_D, Y_D 45 | 46 | 47 | def mask_labels(Y, mask, tasks): 48 | assert len(tasks) == mask.shape[-1] 49 | for t_idx, task in enumerate(tasks): 
50 | # fill the masked region with -1 (int) or nan (float) 51 | assert task in Y 52 | if Y[task].dtype == torch.int64: 53 | masked_tensor = -torch.ones_like(Y[task]) 54 | else: 55 | masked_tensor = float('nan')*Y[task] 56 | 57 | Y[task] = torch.where(mask[..., t_idx].unsqueeze(-1).expand_as(Y[task]), 58 | masked_tensor, 59 | Y[task]) 60 | 61 | 62 | def to_device(data, device): 63 | ''' 64 | Load data with arbitrary structure on device. 65 | ''' 66 | def to_device_wrapper(data): 67 | if isinstance(data, torch.Tensor): 68 | return data.to(device) 69 | elif isinstance(data, tuple): 70 | return tuple(map(to_device_wrapper, data)) 71 | elif isinstance(data, list): 72 | return list(map(to_device_wrapper, data)) 73 | elif isinstance(data, dict): 74 | return {key: to_device_wrapper(data[key]) for key in data} 75 | else: 76 | raise NotImplementedError 77 | 78 | return to_device_wrapper(data) 79 | 80 | 81 | def get_data_iterator(data_loader, device): 82 | ''' 83 | Iterator wrapper for dataloader 84 | ''' 85 | def get_batch(): 86 | while True: 87 | for batch in data_loader: 88 | yield to_device(batch, device) 89 | return get_batch() -------------------------------------------------------------------------------- /dataset/weather.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from torch.distributions import Bernoulli 6 | 7 | from .utils import mask_labels 8 | 9 | 10 | class WeatherDataset(Dataset): 11 | def __init__(self, data_path, split, tasks, split_ratio, target_size): 12 | ''' 13 | Train dataset samples (X_D, Y_D). 14 | ''' 15 | super().__init__() 16 | 17 | self.tasks = tasks 18 | 19 | # parse data root 20 | if os.path.isdir(data_path): 21 | self.root = data_path 22 | else: 23 | self.root = os.path.split(data_path)[0] 24 | 25 | data = torch.load(data_path) 26 | 27 | # extract number of datasets and input points (timestamps). 
28 | keys = sorted(list(data.keys())) 29 | self.n_functions = len(keys) 30 | self.n_points = data[keys[0]][tasks[0]].shape[1] 31 | self.target_size = target_size 32 | assert self.target_size <= self.n_points 33 | self.split = split 34 | 35 | # split function indices 36 | function_idxs = list(range(self.n_functions)) 37 | if type(split_ratio[0]) is float: 38 | cut1 = int(len(function_idxs)*split_ratio[0]) 39 | cut2 = cut1 + int(len(function_idxs)*split_ratio[1]) 40 | cut3 = cut2 + int(len(function_idxs)*split_ratio[2]) 41 | else: 42 | cut1 = split_ratio[0] 43 | cut2 = cut1 + split_ratio[1] 44 | cut3 = cut2 + split_ratio[2] 45 | 46 | # split functions 47 | if split == 'train': 48 | self.function_idxs = function_idxs[:cut1] 49 | elif split == 'valid': 50 | self.function_idxs = function_idxs[cut1:cut2] 51 | elif split == 'test': 52 | self.function_idxs = function_idxs[cut2:cut3] 53 | else: 54 | raise NotImplementedError 55 | 56 | # construct input and output tensors. 57 | self.X = torch.linspace(0, 1, self.n_points).unsqueeze(1) 58 | self.Y = {} 59 | for task in tasks: 60 | self.Y[task] = torch.stack([torch.from_numpy(data[keys[idx]][task]).transpose(0, 1).float() 61 | for idx in self.function_idxs]) 62 | 63 | def __len__(self): 64 | return len(self.function_idxs) 65 | 66 | 67 | class WeatherTrainDataset(WeatherDataset): 68 | def __getitem__(self, idx): 69 | ''' 70 | Returns complete target. 71 | ''' 72 | # generate point permutation 73 | pp = torch.randperm(self.n_points)[:self.target_size] 74 | # pp = torch.randperm(self.n_points) 75 | 76 | # get permuted data 77 | X_D = self.X[pp].clone() 78 | Y_D = {task: self.Y[task][idx][pp].clone() for task in self.tasks} 79 | 80 | return X_D, Y_D 81 | 82 | 83 | class WeatherTestDataset(WeatherDataset): 84 | def __init__(self, data_path, split, tasks, split_ratio, target_size, 85 | context_size, gamma, seed): 86 | ''' 87 | Load or generate random objects. 
88 | ''' 89 | super().__init__(data_path, split, tasks, split_ratio, target_size) 90 | self.tasks = tasks 91 | self.context_size = context_size 92 | 93 | # generate or load point permutations of size (n_functions, n_points) 94 | os.makedirs(os.path.join(self.root, 'point_permutations'), exist_ok=True) 95 | pp_path = os.path.join(self.root, 'point_permutations', 'pp_seed{}_{}.pth'.format(seed, split)) 96 | if os.path.exists(pp_path): 97 | self.pp = torch.load(pp_path) 98 | else: 99 | self.pp = [torch.randperm(self.n_points) for _ in range(len(self.function_idxs))] 100 | torch.save(self.pp, pp_path) 101 | 102 | # generate label mask of size (n_functions, n_points, n_tasks) with missing rate gamma 103 | self.gamma = gamma 104 | if self.gamma > 0: 105 | os.makedirs(os.path.join(self.root, 'label_masks'), exist_ok=True) 106 | mask_path = os.path.join(self.root, 'label_masks', 107 | f'mask_T{len(self.tasks)}_gamma{gamma}_seed{seed}_{split}.pth') 108 | if os.path.exists(mask_path): 109 | self.mask = torch.load(mask_path) 110 | else: 111 | self.mask = Bernoulli(torch.tensor(gamma)).sample((len(self.function_idxs), 112 | self.n_points, len(self.tasks))).bool() 113 | self.mask[:, :len(self.tasks)] *= ~torch.eye(len(self.tasks)).bool() # guarantee at least one label per task 114 | torch.save(self.mask, mask_path) 115 | else: 116 | self.mask = None 117 | 118 | def __getitem__(self, idx): 119 | ''' 120 | Returns incomplete context, complete target, and complete context labels. 
121 | ''' 122 | # pick context 123 | context_idxs = self.pp[idx][:self.context_size] 124 | 125 | X_C = self.X[context_idxs].clone() 126 | Y_C = {task: self.Y[task][idx][context_idxs].clone() for task in self.tasks} 127 | 128 | # predict all points as target 129 | X_D = self.X.clone() 130 | Y_D = {task: self.Y[task][idx].clone() for task in self.tasks} 131 | 132 | # random drop with mask 133 | Y_C_comp = {task: Y_C[task].clone() for task in Y_C} # copy unmasked context labels 134 | if self.gamma > 0: 135 | mask_labels(Y_C, self.mask[idx, :self.context_size], self.tasks) 136 | 137 | return X_C, Y_C, X_D, Y_D, Y_C_comp 138 | 139 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import copy 4 | import yaml 5 | from easydict import EasyDict 6 | 7 | import torch 8 | 9 | from dataset import load_data 10 | from model import get_model, get_imputer 11 | from train import train_step, evaluate, configure_experiment, get_schedulers, Saver 12 | 13 | 14 | # ENVIRONMENTAL SETTINGSs 15 | # to prevent over-threading 16 | torch.set_num_threads(1) 17 | 18 | # parse arguments 19 | from argument import args 20 | 21 | # load config 22 | with open(os.path.join('configs', args.data, 'config_train.yaml'), 'r') as f: 23 | config = EasyDict(yaml.safe_load(f)) 24 | 25 | # configure settings, logging and checkpointing paths 26 | logger, save_dir, log_keys = configure_experiment(config, args) 27 | config_copy = copy.deepcopy(config) 28 | 29 | # set device 30 | device = torch.device('cuda') 31 | 32 | # load train and valid data 33 | train_iterator, valid_loader = load_data(config, device, split='trainval') 34 | 35 | # model, optimizer, and schedulers 36 | model = get_model(config, device) 37 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 38 | lr_scheduler, beta_G_scheduler, beta_T_scheduler = get_schedulers(optimizer, 
config) 39 | 40 | # load pretrained model as an imputer if needed 41 | imputer, config_imputer = get_imputer(config, device) 42 | 43 | # checkpoint saver 44 | saver = Saver(model, save_dir, config_copy) 45 | 46 | # MAIN LOOP 47 | pbar = tqdm.tqdm(total=config.n_steps, initial=0, 48 | bar_format="{desc:<5}{percentage:3.0f}%|{bar:10}{r_bar}") 49 | while logger.global_step < config.n_steps: 50 | # train step 51 | train_data = next(train_iterator) 52 | train_step(model, optimizer, config, logger, *train_data) 53 | 54 | # schedulers step 55 | lr_scheduler.step() 56 | if config.model in ['mtp', 'jtp', 'mtp_s']: 57 | beta_G_scheduler.step() 58 | if config.model in ['mtp', 'stp', 'mtp_s']: 59 | beta_T_scheduler.step() 60 | 61 | # logging 62 | if logger.global_step % config.log_iter == 0: 63 | logger.log_values(log_keys, pbar, 'train', logger.global_step) 64 | logger.reset(log_keys) 65 | logger.writer.add_scalar('train/lr', lr_scheduler.lr, logger.global_step) 66 | if config.model in ['mtp', 'jtp', 'mtp_s']: 67 | logger.writer.add_scalar('train/beta_G', config.beta_G, logger.global_step) 68 | if config.model in ['mtp', 'stp', 'mtp_s']: 69 | logger.writer.add_scalar('train/beta_T', config.beta_T, logger.global_step) 70 | 71 | # evaluate and visualize 72 | if logger.global_step % config.val_iter == 0: 73 | valid_nlls, valid_errors = evaluate(model, valid_loader, device, config, logger, imputer, config_imputer, tag='valid') 74 | saver.save_best(model, valid_nlls, valid_errors, logger.global_step) 75 | 76 | # save model 77 | if logger.global_step % config.save_iter == 0: 78 | # save current model 79 | saver.save(model, valid_nlls, valid_errors, logger.global_step, f'step_{logger.global_step:06d}.pth') 80 | 81 | 82 | pbar.update(1) 83 | 84 | # Save Model and Terminate. 
85 | saver.save(model, valid_nlls, valid_errors, logger.global_step, 'last.pth') 86 | 87 | pbar.close() 88 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from .mtnp import MTP, STP, JTP, SharedMTP 5 | 6 | 7 | def get_model(config, device): 8 | if config.model == 'mtp': 9 | return MTP(config).to(device) 10 | elif config.model == 'stp': 11 | return STP(config).to(device) 12 | elif config.model == 'jtp': 13 | return JTP(config).to(device) 14 | elif config.model == 'mtp_s': 15 | return SharedMTP(config).to(device) 16 | else: 17 | raise NotImplementedError 18 | 19 | 20 | def get_imputer(config, device): 21 | if config.model == 'jtp' and config.gamma_valid > 0: 22 | assert os.path.exists(config.imputer_path) 23 | ckpt_imputer = torch.load(config.imputer_path) 24 | params_imputer = ckpt_imputer['model'] 25 | config_imputer = ckpt_imputer['config'] 26 | imputer = get_model(config_imputer, device) 27 | imputer.load_state_dict_(params_imputer) 28 | else: 29 | imputer = config_imputer = None 30 | 31 | return imputer, config_imputer -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | from .utils import masked_forward 6 | 7 | 8 | class Attention(nn.Module): 9 | def __init__(self, dim, num_heads=4, act_fn=nn.ReLU, ln=False, dr=0.1): 10 | super().__init__() 11 | self.dim = dim 12 | self.num_heads = num_heads 13 | self.dim_split = dim // num_heads 14 | self.fc_q = nn.Linear(dim, dim, bias=False) 15 | self.fc_k = nn.Linear(dim, dim, bias=False) 16 | self.fc_v = nn.Linear(dim, dim, bias=False) 17 | self.fc_o = nn.Linear(dim, dim, bias=False) 18 | 19 | self.activation = act_fn() 20 | self.attn_dropout = 
nn.Dropout(dr) 21 | self.residual_dropout1 = nn.Dropout(dr) 22 | self.residual_dropout2 = nn.Dropout(dr) 23 | if ln: 24 | self.ln1 = nn.LayerNorm(dim) 25 | self.ln2 = nn.LayerNorm(dim) 26 | 27 | def forward(self, Q, K, V=None, mask_Q=None, mask_K=None, get_attn=False): 28 | if V is None: V = K 29 | 30 | if mask_Q is not None: 31 | Q = Q.clone().masked_fill(mask_Q.unsqueeze(-1), 0) 32 | else: 33 | mask_Q = torch.zeros(*Q.size()[:2], device=Q.device) 34 | 35 | if mask_K is not None: 36 | K = K.clone().masked_fill(mask_K.unsqueeze(-1), 0) 37 | V = V.clone().masked_fill(mask_K.unsqueeze(-1), 0) 38 | else: 39 | mask_K = torch.zeros(*K.size()[:2], device=K.device) 40 | 41 | Q = self.fc_q(Q) 42 | K = self.fc_k(K) 43 | V = self.fc_v(V) 44 | 45 | Q_ = torch.cat(Q.split(self.dim_split, 2), 0) 46 | K_ = torch.cat(K.split(self.dim_split, 2), 0) 47 | V_ = torch.cat(V.split(self.dim_split, 2), 0) 48 | 49 | mask = ~((1 - mask_Q.unsqueeze(-1).float()).bmm((1 - mask_K.unsqueeze(-1).float()).transpose(1, 2)).bool().repeat(self.num_heads, 1, 1)) 50 | 51 | A = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim) 52 | A = A.masked_fill(mask, -1e38) 53 | A = torch.softmax(A, 2) 54 | A = A.masked_fill(mask, 0) 55 | 56 | A = self.attn_dropout(A) 57 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2) 58 | 59 | O = Q + self.residual_dropout1(O) 60 | O = O if getattr(self, 'ln1', None) is None else masked_forward(self.ln1, O, mask_Q.bool(), self.dim) 61 | O = O + self.residual_dropout2(self.activation(masked_forward(self.fc_o, O, mask_Q.bool(), self.dim))) 62 | O = O if getattr(self, 'ln2', None) is None else masked_forward(self.ln2, O, mask_Q.bool(), self.dim) 63 | 64 | if get_attn: 65 | return O, A 66 | else: 67 | return O 68 | 69 | 70 | class PMA(nn.Module): 71 | def __init__(self, dim, num_heads, num_seeds, act_fn='relu', ln=False, dr=0.1): 72 | super().__init__() 73 | act_fn = nn.GELU if act_fn == 'gelu' else nn.ReLU 74 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 75 | 
nn.init.xavier_uniform_(self.S) 76 | self.attn = Attention(dim, num_heads, act_fn=act_fn, ln=ln, dr=dr) 77 | 78 | def forward(self, X, mask=None): 79 | return self.attn(self.S.repeat(X.size(0), 1, 1), X, mask_K=mask) 80 | 81 | 82 | class Mean(nn.Module): 83 | def __init__(self, dim): 84 | super().__init__() 85 | self.dim = dim 86 | 87 | def forward(self, x): 88 | return x.mean(dim=self.dim, keepdim=True) 89 | 90 | 91 | class SelfAttention(nn.Module): 92 | def __init__(self, dim, n_layers, n_heads, act_fn='relu', ln=False, dr=0.1): 93 | super().__init__() 94 | act_fn = nn.GELU if act_fn == 'gelu' else nn.ReLU 95 | 96 | self.attentions = nn.ModuleList([Attention(dim, n_heads, act_fn=act_fn, ln=ln, dr=dr) 97 | for _ in range(n_layers)]) 98 | 99 | def forward(self, Q, mask=None, **kwargs): 100 | for attention in self.attentions: 101 | Q = attention(Q, Q, Q, mask_Q=mask, mask_K=mask, **kwargs) 102 | 103 | return Q 104 | 105 | 106 | class CrossAttention(nn.Module): 107 | def __init__(self, dim_q, dim_k, dim_v, n_layers, n_heads, act_fn='relu', ln=False, dr=0.1): 108 | super().__init__() 109 | act_fn = nn.GELU if act_fn == 'gelu' else nn.ReLU 110 | 111 | self.query_proj = nn.Linear(dim_q, dim_v) 112 | self.key_proj = nn.Linear(dim_k, dim_v) 113 | self.attentions = nn.ModuleList([Attention(dim_v, n_heads, act_fn=act_fn, ln=ln, dr=dr) 114 | for _ in range(n_layers)]) 115 | 116 | def forward(self, Q, K, V, **kwargs): 117 | Q = self.query_proj(Q) 118 | K = self.key_proj(K) 119 | for attention in self.attentions: 120 | Q = attention(Q, K, V, **kwargs) 121 | 122 | return Q -------------------------------------------------------------------------------- /model/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FFB(nn.Module): 6 | def __init__(self, dim_in, dim_out, act_fn, ln): 7 | super().__init__() 8 | self.layers = nn.Sequential( 9 | nn.Linear(dim_in, dim_out), 10 | 
nn.LayerNorm(dim_out) if ln else nn.Identity(), 11 | act_fn(), 12 | ) 13 | 14 | def forward(self, x): 15 | return self.layers(x) 16 | 17 | 18 | class MLP(nn.Module): 19 | def __init__(self, dim_in, dim_out, dim_hidden, n_layers, act_fn='relu', ln=False): 20 | super().__init__() 21 | assert n_layers >= 1 22 | act_fn = nn.GELU if act_fn == 'gelu' else nn.ReLU 23 | 24 | self.dim_in = dim_in 25 | self.dim_hidden = dim_hidden 26 | self.dim_out = dim_out 27 | 28 | layers = [] 29 | for l_idx in range(n_layers): 30 | di = dim_in if l_idx == 0 else dim_hidden 31 | do = dim_out if l_idx == n_layers - 1 else dim_hidden 32 | layers.append(FFB(di, do, act_fn, ln)) 33 | 34 | self.layers = nn.Sequential(*layers) 35 | 36 | def forward(self, x): 37 | x = self.layers(x) 38 | 39 | return x 40 | 41 | 42 | class LatentMLP(nn.Module): 43 | def __init__(self, dim_in, dim_out, dim_hidden, n_layers=2, act_fn='relu', ln=False, 44 | epsilon=0.1, sigma=True, sigma_act=torch.sigmoid): 45 | super().__init__() 46 | 47 | self.epsilon = epsilon 48 | self.sigma = sigma 49 | 50 | assert n_layers >= 1 51 | if n_layers >= 2: 52 | self.mlp = MLP(dim_in, dim_hidden, dim_hidden, n_layers-1, act_fn, ln) 53 | else: 54 | self.mlp = None 55 | 56 | self.hidden_to_mu = nn.Linear(dim_hidden, dim_out) 57 | if self.sigma: 58 | self.hidden_to_log_sigma = nn.Linear(dim_hidden, dim_out) 59 | self.sigma_act = sigma_act 60 | 61 | def forward(self, x): 62 | hidden = self.mlp(x) if self.mlp is not None else x 63 | 64 | mu = self.hidden_to_mu(hidden) 65 | if self.sigma: 66 | log_sigma = self.hidden_to_log_sigma(hidden) 67 | sigma = self.epsilon + (1 - self.epsilon)*self.sigma_act(log_sigma) 68 | 69 | return mu, sigma 70 | else: 71 | return mu -------------------------------------------------------------------------------- /model/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model.mlp 
import MLP, LatentMLP 6 | from model.attention import SelfAttention, CrossAttention, PMA 7 | from model.utils import masked_forward 8 | 9 | 10 | __all__ = ['SetEncoder', 'GlobalEncoder', 'TaskEncoder', 'ConditionalSetEncoder', 'MultiTaskAttention', 'MTPDecoder', 11 | 'SharedSetEncoder', 'SharedConditionalSetEncoder', 'SharedMTPDecoder'] 12 | 13 | 14 | class SetEncoder(nn.Module): 15 | def __init__(self, dim_x, dim_y, dim_hidden, mlp_layers, attn_layers, attn_config): 16 | super().__init__() 17 | self.dim_hidden = dim_hidden 18 | 19 | self.mlp = MLP(dim_x + dim_y, dim_hidden, dim_hidden, mlp_layers, act_fn=attn_config.act_fn, ln=attn_config.ln) 20 | self.task_embedding = nn.Parameter(torch.randn(dim_hidden), requires_grad=True) 21 | self.attention = SelfAttention(dim_hidden, attn_layers, attn_config.n_heads, 22 | act_fn=attn_config.act_fn, ln=attn_config.ln, dr=attn_config.dr) 23 | self.pool = PMA(dim_hidden, attn_config.n_heads, 1, act_fn=attn_config.act_fn, ln=attn_config.ln, dr=attn_config.dr) 24 | 25 | def forward(self, C): 26 | # nan mask 27 | mask = C[..., -1].isnan() 28 | 29 | # project (x, y) to s 30 | s = masked_forward(self.mlp, C, mask, self.dim_hidden) # (B, n, h) 31 | 32 | # add task embedding e^t 33 | s = s + self.task_embedding.unsqueeze(0).unsqueeze(1) 34 | 35 | # intra-task attention 36 | s = self.attention(s, mask=mask) # (B, n, h) 37 | 38 | # intra-task aggregation 39 | s = self.pool(s).squeeze(1) # (B, h) 40 | 41 | return s 42 | 43 | 44 | class SharedSetEncoder(nn.Module): 45 | def __init__(self, n_tasks, dim_x, dim_y, dim_hidden, mlp_layers, attn_layers, attn_config): 46 | super().__init__() 47 | self.dim_hidden = dim_hidden 48 | 49 | self.mlp = MLP(dim_x + dim_y, dim_hidden, dim_hidden, mlp_layers, act_fn=attn_config.act_fn, ln=attn_config.ln) 50 | self.task_embedding = nn.Parameter(torch.randn(n_tasks, dim_hidden), requires_grad=True) 51 | self.attention = SelfAttention(dim_hidden, attn_layers, attn_config.n_heads, 52 | 
act_fn=attn_config.act_fn, ln=attn_config.ln, dr=attn_config.dr) 53 | self.pool = PMA(dim_hidden, attn_config.n_heads, 1, act_fn=attn_config.act_fn, ln=attn_config.ln, dr=attn_config.dr) 54 | 55 | def forward(self, C): 56 | # nan mask 57 | mask = C[..., -1].isnan() 58 | 59 | # project (x, y) to s 60 | s = masked_forward(self.mlp, C, mask, self.dim_hidden) # (B, T, n, h) 61 | 62 | # add task embedding e^t 63 | s = s + self.task_embedding.unsqueeze(0).unsqueeze(2) 64 | 65 | # intra-task attention 66 | B, T = s.size()[:2] 67 | s = s.view(-1, *s.size()[2:]) 68 | mask = mask.view(-1, *mask.size()[2:]) 69 | s = self.attention(s, mask=mask) # (B*T, n, h) 70 | 71 | # intra-task aggregation 72 | s = self.pool(s).view(B, T, s.size(-1)) # (B, T, h) 73 | 74 | return s 75 | 76 | 77 | class GlobalEncoder(nn.Module): 78 | def __init__(self, dim_hidden, attn_layers, attn_config): 79 | super().__init__() 80 | self.attention = SelfAttention(dim_hidden, attn_layers, attn_config.n_heads, 81 | act_fn=attn_config.act_fn, ln=attn_config.ln, dr=attn_config.dr) 82 | self.pool = PMA(dim_hidden, attn_config.n_heads, 1, act_fn=attn_config.act_fn, ln=attn_config.ln, dr=attn_config.dr) 83 | 84 | self.global_amortizer = LatentMLP(dim_hidden, dim_hidden, dim_hidden, 2, attn_config.act_fn, attn_config.ln) 85 | 86 | def forward(self, s): 87 | # inter-task attention 88 | s = self.attention(s) # (B, T, h) 89 | 90 | # inter-task aggregation 91 | s = self.pool(s).squeeze(1) # (B, h) 92 | 93 | # global latent distribution 94 | q_G = self.global_amortizer(s) 95 | 96 | return q_G 97 | 98 | 99 | class TaskEncoder(nn.Module): 100 | def __init__(self, dim_hidden, attn_config, hierarchical=True): 101 | super().__init__() 102 | self.hierarchical = hierarchical 103 | self.task_amortizer = LatentMLP(dim_hidden*(1 + int(hierarchical)), dim_hidden, dim_hidden, 104 | 2, attn_config.act_fn, attn_config.ln) 105 | 106 | def forward(self, s, z=None): 107 | # hierarchical conditioning 108 | if self.hierarchical: 109 | 
assert z is not None 110 | s = torch.cat((s, z), -1) 111 | 112 | # task latent distribution 113 | q_T = self.task_amortizer(s) 114 | 115 | return q_T 116 | 117 | 118 | class ConditionalSetEncoder(nn.Module): 119 | def __init__(self, dim_x, dim_y, dim_hidden, mlp_layers, attn_layers, attn_config): 120 | super().__init__() 121 | self.dim_hidden = dim_hidden 122 | 123 | self.mlp = MLP(dim_x + dim_y, dim_hidden, dim_hidden, mlp_layers, act_fn=attn_config.act_fn, ln=attn_config.ln) 124 | self.task_embedding = nn.Parameter(torch.randn(dim_hidden), requires_grad=True) 125 | self.attention = CrossAttention(dim_x, dim_x, dim_hidden, attn_layers, attn_config.n_heads, 126 | attn_config.act_fn, attn_config.ln, attn_config.dr) 127 | 128 | def forward(self, C, X_C, X_D): 129 | # nan mask 130 | mask = C[..., -1].isnan() 131 | 132 | # project (x, y) to s 133 | d = masked_forward(self.mlp, C, mask, self.dim_hidden) # (B, n, h) 134 | 135 | # add task embedding e^t 136 | d = d + self.task_embedding.unsqueeze(0).unsqueeze(1) 137 | 138 | # intra-task attention 139 | u = self.attention(X_D, X_C, d, mask_K=mask) 140 | 141 | return u 142 | 143 | 144 | class SharedConditionalSetEncoder(nn.Module): 145 | def __init__(self, n_tasks, dim_x, dim_y, dim_hidden, mlp_layers, attn_layers, attn_config): 146 | super().__init__() 147 | self.dim_hidden = dim_hidden 148 | 149 | self.mlp = MLP(dim_x + dim_y, dim_hidden, dim_hidden, mlp_layers, act_fn=attn_config.act_fn, ln=attn_config.ln) 150 | self.task_embedding = nn.Parameter(torch.randn(n_tasks, dim_hidden), requires_grad=True) 151 | self.attention = CrossAttention(dim_x, dim_x, dim_hidden, attn_layers, attn_config.n_heads, 152 | attn_config.act_fn, attn_config.ln, attn_config.dr) 153 | 154 | def forward(self, C, X_C, X_D): 155 | # nan mask 156 | mask = C[..., -1].isnan() 157 | 158 | # project (x, y) to s 159 | d = masked_forward(self.mlp, C, mask, self.dim_hidden) # (B, T, n, h) 160 | 161 | # add task embedding e^t 162 | d = d + 
self.task_embedding.unsqueeze(0).unsqueeze(2) 163 | 164 | # intra-task attention 165 | B, T = d.size()[:2] 166 | d = d.view(B*T, *d.size()[2:]) 167 | mask = mask.view(B*T, *mask.size()[2:]) 168 | X_C = X_C.unsqueeze(1).repeat(1, T, 1, 1).view(B*T, *X_C.size()[1:]) 169 | X_D = X_D.unsqueeze(1).repeat(1, T, 1, 1).view(B*T, *X_D.size()[1:]) 170 | u = self.attention(X_D, X_C, d, mask_K=mask) 171 | u = u.view(B, T, *u.size()[1:]) 172 | 173 | return u 174 | 175 | 176 | class MultiTaskAttention(nn.Module): 177 | def __init__(self, dim_hidden, n_layers, n_heads, act_fn='relu', ln=False, dr=0.1): 178 | super().__init__() 179 | act_fn = nn.GELU if act_fn == 'gelu' else nn.ReLU 180 | 181 | self.attention = SelfAttention(dim_hidden, n_layers, n_heads, act_fn=act_fn, ln=ln, dr=dr) 182 | 183 | def forward(self, Q): 184 | bs, nb, ts, _ = Q.size() 185 | Q_ = Q.transpose(1, 2).reshape(bs*ts, nb, -1) 186 | Q_ = self.attention(Q_) # (bs*ts, nb, dim_hidden) or (bs*ts, 1, dim_hidden) 187 | Q = Q_.reshape(bs, ts, *Q_.size()[1:]).transpose(1, 2) 188 | return Q 189 | 190 | 191 | class MTPDecoder(nn.Module): 192 | def __init__(self, dim_x, dim_y, dim_hidden, n_layers, attn_config, sigma): 193 | super().__init__() 194 | self.input_projection = nn.Linear(dim_x, dim_hidden) 195 | self.task_embedding = nn.Parameter(torch.randn(dim_hidden), requires_grad=True) 196 | self.output_amortizer = LatentMLP(dim_hidden*3, dim_y, dim_hidden, n_layers, 197 | attn_config.act_fn, attn_config.ln, sigma=sigma, sigma_act=F.softplus) 198 | 199 | def forward(self, X, v, r): 200 | # project x to w 201 | w = self.input_projection(X) # (B, n, h) or (B, ns, n, h) 202 | 203 | if self.training: 204 | # add task embedding e^t 205 | w = w + self.task_embedding.unsqueeze(0).unsqueeze(1) 206 | else: 207 | # add task embedding e^t 208 | w = w + self.task_embedding.unsqueeze(0).unsqueeze(1).unsqueeze(2) 209 | 210 | # concat w, v, r 211 | v = v.unsqueeze(-2).repeat(*([1]*(len(w.size())-2)), w.size(-2), 1) 212 | 213 | 
decoder_input = torch.cat((w, v, r), -1) 214 | 215 | # output distribution 216 | p_Y = self.output_amortizer(decoder_input) 217 | 218 | return p_Y 219 | 220 | 221 | class SharedMTPDecoder(nn.Module): 222 | def __init__(self, n_tasks, dim_x, dim_y, dim_hidden, n_layers, attn_config, sigma): 223 | super().__init__() 224 | self.input_projection = nn.Linear(dim_x, dim_hidden) 225 | self.task_embedding = nn.Parameter(torch.randn(n_tasks, dim_hidden), requires_grad=True) 226 | self.output_amortizer = LatentMLP(dim_hidden*3, dim_y, dim_hidden, n_layers, 227 | attn_config.act_fn, attn_config.ln, sigma=sigma, sigma_act=F.softplus) 228 | 229 | def forward(self, X, v, r): 230 | # project x to w 231 | w = self.input_projection(X) # (B, n, h) or (B, ns, n, h) 232 | w = w.unsqueeze(1).repeat(1, v.size(1), *([1]*(len(w.size())-1))) # (B, T, n, h) or (B, T, ns, n, h) 233 | 234 | if self.training: 235 | # add task embedding e^t 236 | w = w + self.task_embedding.unsqueeze(0).unsqueeze(2) 237 | else: 238 | # add task embedding e^t 239 | w = w + self.task_embedding.unsqueeze(0).unsqueeze(2).unsqueeze(3) 240 | 241 | # concat w, v, r 242 | v = v.unsqueeze(-2).repeat(*([1]*(len(w.size())-2)), w.size(-2), 1) 243 | 244 | decoder_input = torch.cat((w, v, r), -1) 245 | 246 | # output distribution 247 | p_Y = self.output_amortizer(decoder_input) 248 | 249 | return p_Y -------------------------------------------------------------------------------- /model/mtnp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Normal 4 | 5 | from model.module import * 6 | 7 | 8 | class MTP(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | self.tasks = config.tasks 12 | 13 | # latent encoding path 14 | self.set_encoder = nn.ModuleList([SetEncoder(config.dim_x, config.dim_ys[task], config.dim_hidden, 15 | config.module_sizes[0], config.module_sizes[1], config.attn_config) 16 | 
for task in self.tasks]) 17 | self.global_encoder = GlobalEncoder(config.dim_hidden, config.module_sizes[2], config.attn_config) 18 | self.task_encoder = nn.ModuleList([TaskEncoder(config.dim_hidden, config.attn_config, hierarchical=True) 19 | for task in self.tasks]) 20 | 21 | # deterministic encoding path 22 | self.conditional_set_encoder = nn.ModuleList([ConditionalSetEncoder(config.dim_x, config.dim_ys[task], config.dim_hidden, 23 | config.module_sizes[0], config.module_sizes[1], config.attn_config) 24 | for task in self.tasks]) 25 | self.deterministic_encoder = MultiTaskAttention(config.dim_hidden, config.module_sizes[2], config.attn_config.n_heads, 26 | config.attn_config.act_fn, config.attn_config.ln, config.attn_config.dr) 27 | 28 | # decoding path 29 | self.decoder = nn.ModuleList([MTPDecoder(config.dim_x, config.dim_ys[task], config.dim_hidden, config.module_sizes[3], 30 | config.attn_config, sigma=(config.task_types[task] == 'continuous')) 31 | for task in self.tasks]) 32 | 33 | def state_dict_(self): 34 | return self.state_dict() 35 | 36 | def load_state_dict_(self, state_dict): 37 | self.load_state_dict(state_dict) 38 | 39 | def encode_global(self, X, Y): 40 | s = {} 41 | # per-task inference of latent path 42 | for t_idx, task in enumerate(self.tasks): 43 | D_t = torch.cat((X, Y[task]), -1) 44 | s[task] = self.set_encoder[t_idx](D_t) 45 | 46 | # global latent in across-task inference of latent path 47 | s_G = torch.stack([s[task] for task in s], 1) 48 | q_G = self.global_encoder(s_G) 49 | return q_G, s 50 | 51 | def encode_task(self, s, z): 52 | # task-specific latent in across-task inference of latent path 53 | q_T = {} 54 | for t_idx, task in enumerate(self.tasks): 55 | s_t = s[task] 56 | if not self.training: 57 | s_t = s_t.unsqueeze(1).repeat(1, z.size(1), 1) 58 | q_T[task] = self.task_encoder[t_idx](s_t, z) 59 | 60 | return q_T 61 | 62 | def encode_deterministic(self, X_C, Y_C, X_D): 63 | U_C = {} 64 | # cross-attention layers in across-task 
inference of deterministic path 65 | for t_idx, task in enumerate(self.tasks): 66 | C_t = torch.cat((X_C, Y_C[task]), -1) 67 | U_C[task] = self.conditional_set_encoder[t_idx](C_t, X_C, X_D) 68 | 69 | # self-attention layers in across-task inference of deterministic path 70 | U_C = torch.stack([U_C[task] for task in self.tasks], 1) 71 | r = self.deterministic_encoder(U_C) 72 | return r 73 | 74 | def decode(self, X, v, r): 75 | if not self.training: 76 | X = X.unsqueeze(1).repeat(1, v.size(2), 1, 1) 77 | r = r.unsqueeze(2).repeat(1, 1, v.size(2), 1, 1) 78 | 79 | p_Y = {} 80 | for t_idx, task in enumerate(self.tasks): 81 | p_Y[task] = self.decoder[t_idx](X, v[:, t_idx], r[:, t_idx]) 82 | return p_Y 83 | 84 | def forward(self, X_C, Y_C, X_D, Y_D=None, MAP=False, ns_G=5, ns_T=5): 85 | if self.training: 86 | assert Y_D is not None 87 | 88 | q_C_G, s_C = self.encode_global(X_C, Y_C) 89 | q_D_G, s_D = self.encode_global(X_D, Y_D) 90 | z = Normal(*q_D_G).rsample() 91 | 92 | q_C_T = self.encode_task(s_C, z) 93 | q_D_T = self.encode_task(s_D, z) 94 | v = torch.stack([Normal(*q_D_T[task]).rsample() for task in self.tasks], 1) 95 | 96 | r = self.encode_deterministic(X_C, Y_C, X_D) 97 | 98 | p_Y = self.decode(X_D, v, r) 99 | 100 | return p_Y, q_D_G, q_C_G, q_D_T, q_C_T 101 | else: 102 | q_C_G, s_C = self.encode_global(X_C, Y_C) 103 | if MAP: 104 | z = q_C_G[0].unsqueeze(1) 105 | else: 106 | z = Normal(*q_C_G).sample((ns_G,)).transpose(0, 1) 107 | 108 | q_C_T = self.encode_task(s_C, z) 109 | if MAP: 110 | v = torch.stack([q_C_T[task][0] for task in q_C_T], 1) 111 | else: 112 | v = torch.stack([Normal(*q_C_T[task]).sample((ns_T,)).transpose(0, 1).reshape(z.size(0), ns_G*ns_T, -1) 113 | for task in q_C_T], 1) 114 | 115 | r = self.encode_deterministic(X_C, Y_C, X_D) 116 | 117 | p_Y = self.decode(X_D, v, r) 118 | 119 | return p_Y 120 | 121 | 122 | class STP(nn.Module): 123 | def __init__(self, config): 124 | super().__init__() 125 | self.tasks = config.tasks 126 | 127 | # latent 
encoding path 128 | self.set_encoder = nn.ModuleList([SetEncoder(config.dim_x, config.dim_ys[task], config.dim_hidden, 129 | config.module_sizes[0], config.module_sizes[1], config.attn_config) 130 | for task in self.tasks]) 131 | self.task_encoder = nn.ModuleList([TaskEncoder(config.dim_hidden, config.attn_config, hierarchical=False) 132 | for task in self.tasks]) 133 | 134 | # deterministic encoding path 135 | self.conditional_set_encoder = nn.ModuleList([ConditionalSetEncoder(config.dim_x, config.dim_ys[task], config.dim_hidden, 136 | config.module_sizes[0], config.module_sizes[1], config.attn_config) 137 | for task in self.tasks]) 138 | 139 | # decoding path 140 | self.decoder = nn.ModuleList([MTPDecoder(config.dim_x, config.dim_ys[task], config.dim_hidden, config.module_sizes[3], 141 | config.attn_config, sigma=(config.task_types[task] == 'continuous')) 142 | for task in self.tasks]) 143 | 144 | def state_dict_(self): 145 | state_dict = {task: {} for task in self.tasks} 146 | for name, child in self.named_children(): 147 | for t_idx, task in enumerate(self.tasks): 148 | state_dict[task][name] = child[t_idx].state_dict() 149 | 150 | return state_dict 151 | 152 | def state_dict_task(self, task): 153 | state_dict = {} 154 | for name, child in self.named_children(): 155 | t_idx = self.tasks.index(task) 156 | state_dict[name] = child[t_idx].state_dict() 157 | 158 | return state_dict 159 | 160 | def load_state_dict_(self, state_dict): 161 | for name, child in self.named_children(): 162 | for t_idx, task in enumerate(self.tasks): 163 | child[t_idx].load_state_dict(state_dict[task][name]) 164 | 165 | def encode_task(self, X, Y): 166 | # task-specific latent in across-task inference of latent path 167 | q_T = {} 168 | for t_idx, task in enumerate(self.tasks): 169 | D_t = torch.cat((X, Y[task]), -1) 170 | s_t = self.set_encoder[t_idx](D_t) 171 | q_T[task] = self.task_encoder[t_idx](s_t) 172 | 173 | return q_T 174 | 175 | def encode_deterministic(self, X_C, Y_C, X_D): 176 
| U_C = {} 177 | # cross-attention layers in across-task inference of deterministic path 178 | for t_idx, task in enumerate(self.tasks): 179 | C_t = torch.cat((X_C, Y_C[task]), -1) 180 | U_C[task] = self.conditional_set_encoder[t_idx](C_t, X_C, X_D) 181 | 182 | # self-attention layers in across-task inference of deterministic path 183 | r = torch.stack([U_C[task] for task in self.tasks], 1) 184 | return r 185 | 186 | def decode(self, X, v, r): 187 | if not self.training: 188 | X = X.unsqueeze(1).repeat(1, v.size(2), 1, 1) 189 | r = r.unsqueeze(2).repeat(1, 1, v.size(2), 1, 1) 190 | 191 | p_Y = {} 192 | for t_idx, task in enumerate(self.tasks): 193 | p_Y[task] = self.decoder[t_idx](X, v[:, t_idx], r[:, t_idx]) 194 | return p_Y 195 | 196 | def forward(self, X_C, Y_C, X_D, Y_D=None, MAP=False, ns_G=5, ns_T=5): 197 | if self.training: 198 | assert Y_D is not None 199 | q_C_T = self.encode_task(X_C, Y_C) 200 | q_D_T = self.encode_task(X_D, Y_D) 201 | v = torch.stack([Normal(*q_D_T[task]).rsample() for task in self.tasks], 1) 202 | 203 | r = self.encode_deterministic(X_C, Y_C, X_D) 204 | 205 | p_Y = self.decode(X_D, v, r) 206 | 207 | return p_Y, None, None, q_D_T, q_C_T 208 | else: 209 | q_C_T = self.encode_task(X_C, Y_C) 210 | if MAP: 211 | v = torch.stack([q_C_T[task][0].unsqueeze(1) for task in q_C_T], 1) 212 | else: 213 | v = torch.stack([Normal(*q_C_T[task]).sample((ns_T,)).transpose(0, 1) 214 | for task in q_C_T], 1) 215 | 216 | r = self.encode_deterministic(X_C, Y_C, X_D) 217 | 218 | p_Y = self.decode(X_D, v, r) 219 | 220 | return p_Y 221 | 222 | 223 | class JTP(nn.Module): 224 | def __init__(self, config): 225 | super().__init__() 226 | self.tasks = config.tasks 227 | self.dim_ys = config.dim_ys 228 | self.task_types = config.task_types 229 | 230 | dim_y = sum([self.dim_ys[task] for task in self.tasks]) 231 | task_type = 'continuous' if sum([int(config.task_types[task] == 'continuous') for task in self.tasks]) > 0 else 'discrete' 232 | 233 | # latent encoding 
path 234 | self.set_encoder = SetEncoder(config.dim_x, dim_y, config.dim_hidden, 235 | config.module_sizes[0], config.module_sizes[1], config.attn_config) 236 | self.global_encoder = TaskEncoder(config.dim_hidden, config.attn_config, hierarchical=False) 237 | 238 | # deterministic encoding path 239 | self.conditional_set_encoder = ConditionalSetEncoder(config.dim_x, dim_y, config.dim_hidden, 240 | config.module_sizes[0], config.module_sizes[1], config.attn_config) 241 | 242 | # decoding path 243 | self.decoder = MTPDecoder(config.dim_x, dim_y, config.dim_hidden, config.module_sizes[3], 244 | config.attn_config, sigma=(task_type == 'continuous')) 245 | 246 | def state_dict_(self): 247 | return self.state_dict() 248 | 249 | def load_state_dict_(self, state_dict): 250 | self.load_state_dict(state_dict) 251 | 252 | def encode_global(self, X, Y): 253 | # global inference of latent path 254 | D = torch.cat((X, Y), -1) 255 | s = self.set_encoder(D) 256 | 257 | # global latent in across-task inference of latent path 258 | q_G = self.global_encoder(s) 259 | return q_G 260 | 261 | def encode_deterministic(self, X_C, Y_C, X_D): 262 | # cross-attention layers in across-task inference of deterministic path 263 | C = torch.cat((X_C, Y_C), -1) 264 | r = self.conditional_set_encoder(C, X_C, X_D) 265 | return r 266 | 267 | def decode(self, X, z, r): 268 | if not self.training: 269 | X = X.unsqueeze(1).repeat(1, z.size(1), 1, 1) 270 | r = r.unsqueeze(1).repeat(1, z.size(1), 1, 1) 271 | 272 | p_Y = self.decoder(X, z, r) 273 | return p_Y 274 | 275 | def gather_outputs(self, Y): 276 | Y = torch.cat([Y[task] for task in self.tasks], -1) 277 | return Y 278 | 279 | def ungather_dists(self, p_Y): 280 | p_Y_ = {} 281 | offset = 0 282 | for task in self.tasks: 283 | if self.task_types[task] == 'continuous': 284 | p_Y_[task] = (p_Y[0][..., offset:offset+self.dim_ys[task]], 285 | p_Y[1][..., offset:offset+self.dim_ys[task]]) 286 | else: 287 | p_Y_[task] = p_Y[0][..., 
offset:offset+self.dim_ys[task]] 288 | 289 | offset += self.dim_ys[task] 290 | 291 | return p_Y_ 292 | 293 | def forward(self, X_C, Y_C, X_D, Y_D=None, MAP=False, ns_G=5, ns_T=5): 294 | if self.training: 295 | assert Y_D is not None 296 | 297 | Y_C = self.gather_outputs(Y_C) 298 | Y_D = self.gather_outputs(Y_D) 299 | 300 | q_C_G = self.encode_global(X_C, Y_C) 301 | q_D_G = self.encode_global(X_D, Y_D) 302 | z = Normal(*q_D_G).rsample() 303 | 304 | r = self.encode_deterministic(X_C, Y_C, X_D) 305 | 306 | p_Y = self.decode(X_D, z, r) 307 | p_Y = self.ungather_dists(p_Y) 308 | 309 | return p_Y, q_D_G, q_C_G, None, None 310 | else: 311 | Y_C = self.gather_outputs(Y_C) 312 | 313 | q_C_G = self.encode_global(X_C, Y_C) 314 | if MAP: 315 | z = q_C_G[0].unsqueeze(1) 316 | else: 317 | z = Normal(*q_C_G).sample((ns_G,)).transpose(0, 1) 318 | 319 | r = self.encode_deterministic(X_C, Y_C, X_D) 320 | 321 | p_Y = self.decode(X_D, z, r) 322 | p_Y = self.ungather_dists(p_Y) 323 | 324 | return p_Y 325 | 326 | 327 | class SharedMTP(nn.Module): 328 | def __init__(self, config): 329 | super().__init__() 330 | assert len(set(config.dim_ys.values())) == 1 331 | assert len(set(config.task_types.values())) == 1 332 | dim_y = config.dim_ys[config.tasks[0]] 333 | task_type = config.task_types[config.tasks[0]] 334 | 335 | self.tasks = config.tasks 336 | self.task_type = task_type 337 | 338 | # latent encoding path 339 | self.set_encoder = SharedSetEncoder(len(self.tasks), config.dim_x, dim_y, config.dim_hidden, 340 | config.module_sizes[0], config.module_sizes[1], config.attn_config) 341 | self.global_encoder = GlobalEncoder(config.dim_hidden, config.module_sizes[2], config.attn_config) 342 | self.task_encoder = nn.ModuleList([TaskEncoder(config.dim_hidden, config.attn_config, hierarchical=True) 343 | for task in self.tasks]) 344 | 345 | # deterministic encoding path 346 | self.conditional_set_encoder = SharedConditionalSetEncoder(len(self.tasks), config.dim_x, dim_y, config.dim_hidden, 347 | 
config.module_sizes[0], config.module_sizes[1], config.attn_config) 348 | self.deterministic_encoder = MultiTaskAttention(config.dim_hidden, config.module_sizes[2], config.attn_config.n_heads, 349 | config.attn_config.act_fn, config.attn_config.ln, config.attn_config.dr) 350 | 351 | # decoding path 352 | self.decoder = SharedMTPDecoder(len(self.tasks), config.dim_x, dim_y, config.dim_hidden, config.module_sizes[3], 353 | config.attn_config, sigma=(task_type == 'continuous')) 354 | 355 | def state_dict_(self): 356 | return self.state_dict() 357 | 358 | def load_state_dict_(self, state_dict): 359 | self.load_state_dict(state_dict) 360 | 361 | def encode_global(self, X, Y): 362 | # per-task inference of latent path 363 | D = torch.stack([torch.cat((X, Y[task]), -1) for task in Y], 1) 364 | s = self.set_encoder(D) 365 | 366 | # global latent in across-task inference of latent path 367 | q_G = self.global_encoder(s) 368 | return q_G, s 369 | 370 | def encode_task(self, s, z): 371 | # task-specific latent in across-task inference of latent path 372 | q_T = {} 373 | for t_idx, task in enumerate(self.tasks): 374 | s_t = s[:, t_idx] 375 | if not self.training: 376 | s_t = s_t.unsqueeze(1).repeat(1, z.size(1), 1) 377 | q_T[task] = self.task_encoder[t_idx](s_t, z) 378 | 379 | return q_T 380 | 381 | def encode_deterministic(self, X_C, Y_C, X_D): 382 | U_C = {} 383 | # cross-attention layers in across-task inference of deterministic path 384 | 385 | C = torch.stack([torch.cat((X_C, Y_C[task]), -1) for task in Y_C], 1) 386 | U_C = self.conditional_set_encoder(C, X_C, X_D) 387 | 388 | # self-attention layers in across-task inference of deterministic path 389 | r = self.deterministic_encoder(U_C) 390 | return r 391 | 392 | def decode(self, X, v, r): 393 | if not self.training: 394 | X = X.unsqueeze(1).repeat(1, v.size(2), 1, 1) 395 | r = r.unsqueeze(2).repeat(1, 1, v.size(2), 1, 1) 396 | 397 | p_Y = self.decoder(X, v, r) 398 | 399 | return p_Y 400 | 401 | def ungather_dists(self, 
p_Y): 402 | p_Y_ = {} 403 | for t_idx, task in enumerate(self.tasks): 404 | if self.task_type == 'continuous': 405 | p_Y_[task] = (p_Y[0][:, t_idx], 406 | p_Y[1][:, t_idx]) 407 | else: 408 | p_Y_[task] = p_Y[0][:, t_idx] 409 | 410 | return p_Y_ 411 | 412 | def forward(self, X_C, Y_C, X_D, Y_D=None, MAP=False, ns_G=5, ns_T=5): 413 | if self.training: 414 | assert Y_D is not None 415 | 416 | q_C_G, s_C = self.encode_global(X_C, Y_C) 417 | q_D_G, s_D = self.encode_global(X_D, Y_D) 418 | z = Normal(*q_D_G).rsample() 419 | 420 | q_C_T = self.encode_task(s_C, z) 421 | q_D_T = self.encode_task(s_D, z) 422 | v = torch.stack([Normal(*q_D_T[task]).rsample() for task in self.tasks], 1) 423 | 424 | r = self.encode_deterministic(X_C, Y_C, X_D) 425 | 426 | p_Y = self.decode(X_D, v, r) 427 | p_Y = self.ungather_dists(p_Y) 428 | 429 | return p_Y, q_D_G, q_C_G, q_D_T, q_C_T 430 | else: 431 | q_C_G, s_C = self.encode_global(X_C, Y_C) 432 | if MAP: 433 | z = q_C_G[0].unsqueeze(1) 434 | else: 435 | z = Normal(*q_C_G).sample((ns_G,)).transpose(0, 1) 436 | 437 | q_C_T = self.encode_task(s_C, z) 438 | if MAP: 439 | v = torch.stack([q_C_T[task][0] for task in q_C_T], 1) 440 | else: 441 | v = torch.stack([Normal(*q_C_T[task]).sample((ns_T,)).transpose(0, 1).reshape(z.size(0), ns_G*ns_T, -1) 442 | for task in q_C_T], 1) 443 | 444 | r = self.encode_deterministic(X_C, Y_C, X_D) 445 | 446 | p_Y = self.decode(X_D, v, r) 447 | p_Y = self.ungather_dists(p_Y) 448 | 449 | return p_Y -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def masked_forward(module, x, mask, out_dim, **kwargs): 5 | assert x.size()[:-1] == mask.size() 6 | out = torch.zeros(*mask.size(), out_dim).to(x.device) 7 | out[~mask] = module(x[~mask], **kwargs) 8 | 9 | return out -------------------------------------------------------------------------------- /requirments.txt: 
-------------------------------------------------------------------------------- 1 | torch==1.10.0+cu113 2 | torchvision==0.11.1+cu113 3 | numpy 4 | matplotlib 5 | easydict 6 | tqdm 7 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import tqdm 4 | import yaml 5 | from easydict import EasyDict 6 | 7 | import torch 8 | 9 | from dataset import load_data, to_device 10 | from model import get_model 11 | from train import evaluate 12 | 13 | 14 | # ENVIRONMENTAL SETTINGS 15 | # to prevent over-threading 16 | torch.set_num_threads(1) 17 | 18 | DATASETS = ['synthetic', 'weather'] 19 | CHECKPOINTS = ['best_nll', 'best_error', 'last'] 20 | SPLITS = ['test', 'valid'] 21 | 22 | # arguments 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--data', type=str, default='synthetic', choices=DATASETS) 25 | parser.add_argument('--eval_root', type=str, default='experiments') 26 | parser.add_argument('--eval_dir', type=str, default='') 27 | parser.add_argument('--eval_name', type=str, default='') 28 | parser.add_argument('--eval_ckpt', type=str, default='best_error', choices=CHECKPOINTS) 29 | parser.add_argument('--split', type=str, default='test', choices=SPLITS) 30 | parser.add_argument('--device', type=str, default='0') 31 | parser.add_argument('--reset', default=False, action='store_true') 32 | parser.add_argument('--verbose', '-v', default=False, action='store_true') 33 | parser.add_argument('--use_valid_imputer', '-uvi', default=False, action='store_true') 34 | parser.add_argument('--use_homogeneous_imputer', '-uhi', default=False, action='store_true') 35 | 36 | parser.add_argument('--cs', type=int, default=10) 37 | parser.add_argument('--gamma', type=float, default=0.) 
38 | parser.add_argument('--seed', type=int, default=0) 39 | parser.add_argument('--global_batch_size', type=int, default=16) 40 | 41 | args = parser.parse_args() 42 | 43 | 44 | # load test config 45 | with open(os.path.join('configs', args.data, 'config_test.yaml')) as f: 46 | config_test = EasyDict(yaml.safe_load(f)) 47 | config_test[f'cs_{args.split}'] = args.cs 48 | config_test[f'gamma_{args.split}'] = args.gamma 49 | config_test.seed = args.seed 50 | if args.eval_dir != '': 51 | config_test.eval_dir = args.eval_dir 52 | 53 | # set device and evaluation directory 54 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device 55 | device = torch.device('cuda') 56 | config_test.eval_dir = os.path.join(args.eval_root, config_test.eval_dir) 57 | if args.eval_name == '': 58 | eval_list = os.listdir(config_test.eval_dir) 59 | else: 60 | eval_list = [args.eval_name] 61 | 62 | # load test dataloader 63 | test_loader = load_data(config_test, device, split=args.split) 64 | 65 | # test models in eval_list 66 | for exp_name in eval_list: 67 | # skip if checkpoint not exists or still running 68 | eval_path = os.path.join(config_test.eval_dir, exp_name, 'checkpoints', f'{args.eval_ckpt}.pth') 69 | # last_path = os.path.join(config_test.eval_dir, exp_name, 'checkpoints', 'last.pth') 70 | last_path = eval_path 71 | if not (os.path.exists(eval_path) and os.path.exists(last_path)): 72 | if args.verbose: 73 | print(f'checkpoint of {exp_name} does not exist or still running - skip...') 74 | continue 75 | 76 | # skip if already tested 77 | result_dir = os.path.join(config_test.eval_dir, exp_name, 'results') 78 | os.makedirs(result_dir, exist_ok=True) 79 | result_path = os.path.join(result_dir, f'result_cs{args.cs}_gamma{args.gamma}_seed{args.seed}_{args.split}_from{args.eval_ckpt}.pth') 80 | if os.path.exists(result_path) and not args.reset: 81 | if args.verbose: 82 | print(f'result of {exp_name} already exists - skip...') 83 | continue 84 | 85 | # load model and config 86 | ckpt = 
torch.load(eval_path, map_location=device) 87 | config = ckpt['config'] 88 | params = ckpt['model'] 89 | 90 | model = get_model(config, device) 91 | model.load_state_dict_(params) 92 | 93 | # load imputer 94 | if config.model == 'jtp' and config_test[f'gamma_{args.split}'] > 0: 95 | if args.use_valid_imputer: 96 | imputer_path = config.imputer_path 97 | elif args.use_homogeneous_imputer: 98 | imputer_path = eval_path.replace('jtp', 'stp') 99 | else: 100 | imputer_path = config_test.imputer_path 101 | 102 | assert os.path.exists(imputer_path) 103 | ckpt_imputer = torch.load(imputer_path) 104 | config_imputer = ckpt_imputer['config'] 105 | params_imputer = ckpt_imputer['model'] 106 | 107 | imputer = get_model(config_imputer, device) 108 | imputer.load_state_dict_(params_imputer) 109 | else: 110 | imputer = config_imputer = None 111 | 112 | if args.verbose: 113 | print('evaluating {} with test seed {} and gamma {} on {} data'.format(exp_name, args.seed, args.gamma, args.split)) 114 | 115 | # evaluate and save results 116 | nlls, errors = evaluate(model, test_loader, device, config_test, imputer=imputer, config_imputer=config_imputer) 117 | if args.verbose: 118 | print(f'nll: {nlls}\nmse:{errors}') 119 | torch.save({'nlls': nlls, 'errors': errors, 'global_step': ckpt['global_step']}, result_path) 120 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import train_step, evaluate 2 | from .utils import configure_experiment, get_schedulers, Saver -------------------------------------------------------------------------------- /train/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.distributions import kl_divergence, Normal 4 | 5 | 6 | def compute_elbo(Y_D, p_Y, q_D_G, q_C_G, q_D_T, q_C_T, config, logger=None): 7 | 
''' 8 | Compute (prior-approximated) elbo objective for NP-based models. 9 | ''' 10 | log_prob = 0 11 | for task in p_Y: 12 | if config.task_types[task] == 'continuous': 13 | log_prob_ = Normal(p_Y[task][0], p_Y[task][1]).log_prob(Y_D[task]).mean(0).sum() 14 | else: 15 | log_prob_ = -F.cross_entropy(p_Y[task].transpose(1, 2), torch.argmax(Y_D[task], -1), reduction='none').mean(0).sum() 16 | log_prob += log_prob_ 17 | if logger is not None: 18 | logger.add_value(f'nll_{task}', -log_prob_.item()) 19 | if logger is not None: 20 | logger.add_value('nll_normalized', -log_prob.item() / len(config.tasks) / Y_D[task].size(1)) 21 | 22 | kld_G = 0 23 | if q_D_G is not None: 24 | kld_G = kl_divergence(Normal(*q_D_G), Normal(*q_C_G)).mean(0).sum() 25 | if logger is not None: 26 | logger.add_value('kld_G', kld_G.item()) 27 | 28 | kld_T = 0 29 | if q_D_T is not None: 30 | for task in q_D_T: 31 | kld_T_ = kl_divergence(Normal(*q_D_T[task]), Normal(*q_C_T[task])).mean(0).sum() 32 | kld_T += kld_T_ 33 | if logger is not None: 34 | logger.add_value(f'kld_{task}', kld_T_.item()) 35 | if logger is not None: 36 | logger.add_value('kld_T_normalized', kld_T.item() / len(config.tasks) / Y_D[task].size(1)) 37 | 38 | elbo = log_prob - (config.beta_G*kld_G + config.beta_T*kld_T) 39 | 40 | return elbo 41 | 42 | 43 | def compute_normalized_nll(Y_D, p_Y, task_types): 44 | nll = {} 45 | for task in Y_D: 46 | if task_types[task] == 'continuous': 47 | nll[task] = -Normal(p_Y[task][0], p_Y[task][1]).log_prob(Y_D[task]).mean() 48 | else: 49 | nll[task] = F.cross_entropy(p_Y[task].transpose(1, 2), torch.argmax(Y_D[task], -1)) 50 | return nll 51 | 52 | 53 | def compute_error(Y_D, Y_D_pred, task_types, scales=None): 54 | ''' 55 | Compute (normalized) MSE 56 | ''' 57 | error = {} 58 | for task in Y_D: 59 | if isinstance(scales, dict): 60 | scale = scales[task].pow(2) 61 | elif scales is not None: 62 | scale = scales.pow(2) 63 | else: 64 | scale = torch.ones(Y_D_pred[task].size(0), 1, 1, 
device=Y_D_pred[task].device) 65 | 66 | mse = ((Y_D_pred[task] - Y_D[task]).pow(2) / scale).sum(-1).mean() 67 | error[task] = mse.cpu() 68 | 69 | return error 70 | -------------------------------------------------------------------------------- /train/schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class LRScheduler(object): 5 | ''' 6 | Custom learning rate scheduler for pytorch optimizer. 7 | Assumes 1 <= self.iter <= 1 + num_iters. 8 | ''' 9 | def __init__(self, optimizer, mode, base_lr, num_iters, warmup_iters=1000, 10 | from_iter=0, decay_degree=0.9, decay_steps=5000): 11 | self.optimizer = optimizer 12 | self.mode = mode 13 | self.base_lr = base_lr 14 | self.lr = base_lr 15 | self.iter = from_iter 16 | self.N = num_iters + 1 17 | self.warmup_iters = warmup_iters 18 | self.decay_degree = decay_degree 19 | self.decay_steps = decay_steps 20 | 21 | def step(self): 22 | self.iter += 1 23 | if self.mode == 'cos': 24 | self.lr = 0.5 * self.base_lr * (1 + math.cos(1.0 * self.iter / self.N * math.pi)) 25 | elif self.mode == 'poly': 26 | if self.iter < self.N: 27 | self.lr = self.base_lr * pow((1 - 1.0 * self.iter / self.N), self.decay_degree) 28 | elif self.mode == 'step': 29 | self.lr = self.base_lr * (0.1**(self.decay_steps // self.iter)) 30 | elif self.mode == 'constant': 31 | self.lr = self.base_lr 32 | elif self.mode == 'sqroot': 33 | self.lr = self.base_lr * self.warmup_iters**0.5 * min(self.iter * self.warmup_iters**-1.5, self.iter**-0.5) 34 | else: 35 | raise NotImplemented 36 | 37 | # warm up lr schedule 38 | if self.warmup_iters > 0 and self.iter < self.warmup_iters and self.mode != 'sqroot': 39 | self.lr = self.base_lr * 1.0 * self.iter / self.warmup_iters 40 | assert self.lr >= 0 41 | self._adjust_learning_rate(self.optimizer, self.lr) 42 | 43 | def _adjust_learning_rate(self, optimizer, lr): 44 | if len(optimizer.param_groups) == 1: 45 | optimizer.param_groups[0]['lr'] = lr 46 | 
else: 47 | # enlarge the lr at the head 48 | optimizer.param_groups[0]['lr'] = lr 49 | for i in range(1, len(optimizer.param_groups)): 50 | optimizer.param_groups[i]['lr'] = lr * 10 51 | 52 | def reset(self): 53 | self.lr = self.base_lr 54 | self.iter = 0 55 | self._adjust_learning_rate(self.optimizer, self.lr) 56 | 57 | 58 | class HPScheduler: 59 | ''' 60 | Custom hyper-parameter scheduler for any nonzero coefficient wrapped by dictionary. 61 | ''' 62 | def __init__(self, coef_dict, coef_key, mode, base_coef, n_steps, warmup_steps=10000): 63 | self.coef_dict = coef_dict 64 | self.coef_key = coef_key 65 | self.mode = mode 66 | self.base_coef = base_coef 67 | self.warmup_steps = warmup_steps 68 | self.n_steps = n_steps + 1 69 | self.iter = 0 70 | 71 | def step(self): 72 | self.iter += 1 73 | if self.mode == 'constant': 74 | self.coef_dict[self.coef_key] = self.base_coef 75 | elif self.mode == 'linear_warmup': 76 | self.coef_dict[self.coef_key] = min(1, (self.iter / self.warmup_steps)) * self.base_coef 77 | # elif self.mode == 'linear': 78 | # return (self.iter / self.n_steps) * (self.beta_last - self.beta_init) + self.beta_init 79 | # elif self.mode == 'inverse-linear': 80 | # return (1 - self.iter / self.n_steps) * (self.beta_last - self.beta_init) + self.beta_init 81 | # elif self.mode == 'cyclic': 82 | # cycles = 10 83 | # period = (self.n_steps - 1) // cycles 84 | # iter_p = (self.iter - 1) % period 85 | # return float(iter_p) / max(1, float(period - 1)) * (self.beta_last - self.beta_init) + self.beta_init -------------------------------------------------------------------------------- /train/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .loss import compute_elbo, compute_error, compute_normalized_nll 5 | from .utils import plot_curves, broadcast_squeeze, broadcast_index, broadcast_mean 6 | from dataset import to_device 7 | 8 | 9 | def 
train_step(model, optimizer, config, logger, *train_data): 10 | ''' 11 | Perform a training step. 12 | ''' 13 | # forward 14 | X_C, Y_C, X_D, Y_D = train_data 15 | p_Y, q_D_G, q_C_G, q_D_T, q_C_T = model(X_C, Y_C, X_D, Y_D) 16 | 17 | loss = -compute_elbo(Y_D, p_Y, q_D_G, q_C_G, q_D_T, q_C_T, config, logger) 18 | 19 | # backward 20 | optimizer.zero_grad() 21 | loss.backward() 22 | optimizer.step() 23 | 24 | # update global step 25 | logger.global_step += 1 26 | 27 | 28 | @torch.no_grad() 29 | def inference_map(model, *test_data): 30 | ''' 31 | Calculate map estimation (or mode for categorical) with K global latents and L per-task latents. 32 | ''' 33 | X_C, Y_C, X_D = test_data 34 | 35 | model.eval() 36 | p_Ys = model(X_C, Y_C, X_D, MAP=True) 37 | model.train() 38 | 39 | return broadcast_squeeze(p_Ys, 1) 40 | 41 | 42 | @torch.no_grad() 43 | def inference_pmean(model, *test_data, task_types, ns_G=1, ns_T=1, get_pmeans=False): 44 | ''' 45 | Calculate posterior predictive mean (or mode for categorical) with K global latents and L per-task latents. 46 | ''' 47 | X_C, Y_C, X_D = test_data 48 | 49 | model.eval() 50 | p_Ys = model(X_C, Y_C, X_D, MAP=False, ns_G=ns_G, ns_T=ns_T) 51 | model.train() 52 | 53 | Y_D_pmeans = broadcast_index(p_Ys, 0) 54 | 55 | Y_D_pred = broadcast_mean(Y_D_pmeans, 1) 56 | 57 | for task in Y_D_pmeans: 58 | if task_types[task] == 'discrete': 59 | Y_D_pmeans[task] = torch.argmax(Y_D_pmeans[task], -1) 60 | Y_D_pred[task] = torch.argmax(Y_D_pred[task], -1) 61 | 62 | if get_pmeans: 63 | return Y_D_pred, Y_D_pmeans 64 | else: 65 | return Y_D_pred 66 | 67 | 68 | def evaluate(model, test_loader, device, config, logger=None, 69 | imputer=None, config_imputer=None, tag='valid'): 70 | ''' 71 | Calculate error of model based on the posterior predictive mean. 
72 | ''' 73 | errors = {task: 0 for task in config.tasks} 74 | nlls = {task: 0 for task in config.tasks} 75 | 76 | n_datasets = 0 77 | Y_C_comp = scales = None 78 | for b_idx, test_data in enumerate(test_loader): 79 | if config.data == 'synthetic': 80 | X_C, Y_C, X_D, Y_D, Y_C_comp, gt_params = to_device(test_data, device) 81 | scales = {task: gt_params[task]['a'] for task in gt_params} 82 | elif config.data == 'weather': 83 | X_C, Y_C, X_D, Y_D, Y_C_comp = to_device(test_data, device) 84 | scales = None 85 | else: 86 | raise NotImplementedError 87 | 88 | # impute if imputer is given 89 | if imputer is not None: 90 | Y_C_input = Y_C_imp = inference_pmean(imputer, X_C, Y_C, X_C, task_types=config.task_types, ns_G=config_imputer.ns_G, ns_T=config_imputer.ns_T) 91 | for task in Y_C_input: 92 | if config.task_types[task] == 'discrete': 93 | Y_C_input[task] = F.one_hot(Y_C_input[task], config.dim_ys[task]).float() 94 | else: 95 | Y_C_input = Y_C 96 | Y_C_imp = None 97 | 98 | # MAP inference 99 | Y_D_pred_map = inference_map(model, X_C, Y_C_input, X_D) 100 | # plot single batch 101 | if logger is not None and b_idx == 0: 102 | plot_curves(logger, config.tasks, X_C, Y_C, X_D, Y_D, Y_C_comp, Y_D_pred_map, Y_C_imp, 103 | n_subplots=min(10, X_C.size(0)), pred_type='map', colors=config.colors) 104 | 105 | # posterior predictive inference 106 | Y_D_pred, Y_D_pmeans = inference_pmean(model, X_C, Y_C_input, X_D, task_types=config.task_types, ns_G=config.ns_G, ns_T=config.ns_T, get_pmeans=True) 107 | # plot single batch 108 | if logger is not None and b_idx == 0: 109 | plot_curves(logger, config.tasks, X_C, Y_C, X_D, Y_D, Y_C_comp, Y_D_pmeans, Y_C_imp, 110 | n_subplots=min(10, X_C.size(0)), pred_type='pmeans', colors=config.colors) 111 | 112 | # compute errors 113 | nlls_ = compute_normalized_nll(Y_D, Y_D_pred_map, config.task_types) 114 | errors_ = compute_error(Y_D, Y_D_pred, config.task_types, scales) 115 | 116 | # batch denormalization 117 | for task in config.tasks: 118 | 
nlls[task] += (nlls_[task]*X_C.size(0)) 119 | errors[task] += (errors_[task]*X_C.size(0)) 120 | n_datasets += X_C.size(0) 121 | 122 | # batch renormalization 123 | for task in config.tasks: 124 | nlls[task] /= n_datasets 125 | errors[task] /= n_datasets 126 | 127 | if logger is not None: 128 | for task in config.tasks: 129 | logger.writer.add_scalar(f'{tag}/nll_{task}', nlls[task].item(), 130 | global_step=logger.global_step) 131 | logger.writer.add_scalar(f'{tag}/error_{task}', errors[task].item(), 132 | global_step=logger.global_step) 133 | logger.writer.flush() 134 | 135 | return nlls, errors 136 | -------------------------------------------------------------------------------- /train/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import random 5 | import io 6 | import matplotlib.pyplot as plt 7 | import PIL.Image 8 | import numpy as np 9 | 10 | import torch 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.distributions import Normal 13 | from torchvision.transforms import ToTensor 14 | 15 | from .schedulers import LRScheduler, HPScheduler 16 | 17 | 18 | def configure_experiment(config, args): 19 | # update config with arguments 20 | config.model = args.model 21 | config.seed = args.seed 22 | config.name_postfix = args.name_postfix 23 | config.pma = args.pma 24 | 25 | # parse arguments 26 | if args.n_steps > 0: config.n_steps = args.n_steps 27 | if args.lr > 0: config.lr = args.lr 28 | if args.global_batch_size > 0: config.global_batch_size = args.global_batch_size 29 | if args.dim_hidden > 0: config.dim_hidden = args.dim_hidden 30 | 31 | if args.lr_schedule != '': config.lr_schedule = args.lr_schedule 32 | if args.beta_T_schedule != '': config.beta_T_schedule = args.beta_T_schedule 33 | if args.beta_G_schedule != '': config.beta_G_schedule = args.beta_G_schedule 34 | if args.gamma_train >= 0: config.gamma_train = args.gamma_train 35 | if 
args.gamma_valid >= 0: config.gamma_valid = args.gamma_valid 36 | if len(args.cs_range_train) > 0: 37 | assert len(args.cs_range_train) == 2 38 | config.cs_range_train = (int(args.cs_range_train[0]), int(args.cs_range_train[1])) 39 | if args.ts_train > 0: 40 | config.ts_train = args.ts_train 41 | 42 | # configure training missing rate 43 | if config.model == 'jtp': 44 | config.gamma_train = 0. 45 | 46 | # set seeds 47 | torch.backends.cudnn.deterministic = True 48 | torch.manual_seed(config.seed) 49 | torch.cuda.manual_seed(config.seed) 50 | random.seed(config.seed) 51 | np.random.seed(config.seed) 52 | 53 | # for debugging 54 | if args.debug_mode: 55 | config.n_steps = 3 56 | config.log_iter = 1 57 | config.val_iter = 1 58 | config.save_iter = 1 59 | config.imputer_path = config.imputer_path.replace(config.log_dir, config.log_dir + '_debugging') 60 | config.log_dir += '_debugging' 61 | 62 | # set directories 63 | if args.log_root != '': 64 | config.log_root = args.log_root 65 | if args.name != '': 66 | exp_name = args.name 67 | else: 68 | exp_name = config.model + config.name_postfix 69 | if args.imputer_path != '': 70 | config.imputer_path = args.imputer_path 71 | 72 | os.makedirs(config.log_root, exist_ok=True) 73 | os.makedirs(os.path.join(config.log_root, config.log_dir), exist_ok=True) 74 | os.makedirs(os.path.join(config.log_root, config.log_dir, exp_name), exist_ok=True) 75 | log_dir = os.path.join(config.log_root, config.log_dir, exp_name, 'logs') 76 | save_dir = os.path.join(config.log_root, config.log_dir, exp_name, 'checkpoints') 77 | if os.path.exists(save_dir): 78 | if args.debug_mode: 79 | shutil.rmtree(save_dir) 80 | else: 81 | while True: 82 | print('redundant experiment name! remove existing checkpoints? 
(y/n)') 83 | inp = input() 84 | if inp == 'y': 85 | shutil.rmtree(save_dir) 86 | break 87 | elif inp == 'n': 88 | print('quit') 89 | sys.exit() 90 | else: 91 | print('invalid input') 92 | os.makedirs(save_dir) 93 | 94 | # tensorboard logger 95 | logger = Logger(log_dir, config.tasks) 96 | log_keys = ['nll_normalized'] + [f'nll_{task}' for task in config.tasks] 97 | if config.model in ['mtp', 'mtp_s', 'jtp', 'mtp_s']: 98 | log_keys.append('kld_G') 99 | if config.model in ['mtp', 'stp', 'mtp_s']: 100 | log_keys += ['kld_T_normalized'] + [f'kld_{task}' for task in config.tasks] 101 | for log_key in log_keys: 102 | logger.register_key(log_key) 103 | 104 | return logger, save_dir, log_keys 105 | 106 | 107 | def get_schedulers(optimizer, config): 108 | lr_scheduler = LRScheduler(optimizer, config.lr_schedule, config.lr, config.n_steps, config.lr_warmup) 109 | beta_G_scheduler = beta_T_scheduler = None 110 | if config.model in ['mtp', 'jtp', 'mtp_s']: 111 | beta_G_scheduler = HPScheduler(config, 'beta_G', config.beta_G_schedule, config.beta_G, config.n_steps, config.beta_G_warmup) 112 | if config.model in ['mtp', 'stp', 'mtp_s']: 113 | beta_T_scheduler = HPScheduler(config, 'beta_T', config.beta_T_schedule, config.beta_G, config.n_steps, config.beta_T_warmup) 114 | 115 | return lr_scheduler, beta_G_scheduler, beta_T_scheduler 116 | 117 | 118 | def plot_curves(logger, tasks, X_C, Y_C, X_D, Y_D, Y_C_comp, Y_D_pred, Y_C_imp=None, pred_type='map', size=3, markersize=5, n_subplots=10, n_row=5, colors=None): 119 | toten = ToTensor() 120 | plt.rc('xtick', labelsize=3*size) 121 | plt.rc('ytick', labelsize=3*size) 122 | 123 | if colors is None: 124 | colors = {task: 'k' for task in Y_D} 125 | 126 | n_row = min(n_row, n_subplots) 127 | n_subplots = (n_subplots // n_row) * n_row 128 | 129 | for task in tasks: 130 | plt.figure(figsize=(size*n_row*4/3, size*(n_subplots // n_row))) 131 | for idx_sub in range(n_subplots): 132 | plt.subplot(n_subplots // n_row, n_row, idx_sub+1) 133 | 
134 | index_D = ~Y_D[task][idx_sub].isnan() 135 | x_d = X_D[idx_sub, index_D].cpu() 136 | if len(x_d.size()) > 1: 137 | x_d = x_d.squeeze(-1) 138 | p_D = torch.argsort(x_d) 139 | 140 | # plot target 141 | y_d = Y_D[task][idx_sub, index_D].cpu() 142 | if len(y_d.size()) > 1: 143 | y_d = y_d.squeeze(-1) 144 | plt.plot(x_d[p_D], y_d[p_D], color='k', alpha=0.5) 145 | 146 | 147 | # pick observable context indices and sort them 148 | index_C = ~Y_C[task][idx_sub].isnan() 149 | x_c = X_C[idx_sub, index_C].cpu() 150 | if len(x_c.size()) > 1: 151 | x_c = x_c.squeeze(-1) 152 | p_C = torch.argsort(x_c) 153 | 154 | # plot context 155 | if Y_C_comp is not None: 156 | y_c = Y_C_comp[task][idx_sub, index_C].cpu() 157 | else: 158 | y_c = Y_C[task][idx_sub, index_C].cpu() 159 | if len(y_c.size()) > 1: 160 | y_c = y_c.squeeze(-1) 161 | plt.scatter(x_c[p_C], y_c[p_C], color=colors[task], s=markersize*size) 162 | 163 | 164 | # plot imputed value 165 | x_s = X_C[idx_sub, ~index_C].squeeze(-1).cpu() 166 | if Y_C_imp is not None: 167 | y_i = Y_C_imp[task][idx_sub, ~index_C].squeeze(-1).cpu() 168 | plt.scatter(x_s, y_i, color=colors[task], s=markersize*size, marker='^') 169 | 170 | # plot source 171 | if Y_C_comp is not None: 172 | y_s = Y_C_comp[task][idx_sub, ~index_C].squeeze(-1).cpu() 173 | plt.scatter(x_s, y_s, color=colors[task], s=1.5*markersize*size, marker='x') 174 | 175 | 176 | # plot predictions (either MAP or predictive means) 177 | if pred_type == 'pmeans': 178 | samples = Y_D_pred[task][idx_sub, :, index_D].cpu() 179 | for y_ps in samples: 180 | plt.plot(x_d[p_D], y_ps[p_D], color=colors[task], alpha=0.1) 181 | 182 | y_pm = samples.mean(0) 183 | error = (y_pm - y_d).pow(2).mean().item() 184 | plt.plot(x_d[p_D], y_pm[p_D], color=colors[task], label=f'{error:.3f}') 185 | 186 | elif pred_type == 'map': 187 | mu = Y_D_pred[task][0][idx_sub, index_D].cpu() 188 | sigma = Y_D_pred[task][1][idx_sub, index_D].cpu() 189 | nll = -Normal(mu, sigma).log_prob(y_d).mean().cpu().item() 190 
| 191 | plt.plot(x_d[p_D], mu[p_D], color=colors[task], label=f'{nll:.3f}') 192 | plt.fill_between(x_d[p_D], mu[p_D] - sigma[p_D], mu[p_D] + sigma[p_D], color=colors[task], alpha=0.2) 193 | 194 | plt.legend() 195 | 196 | plt.tight_layout() 197 | 198 | # plt figure to io buffer 199 | buf = io.BytesIO() 200 | plt.savefig(buf, format='jpeg') 201 | buf.seek(0) 202 | 203 | # io buffer to tensor 204 | vis = PIL.Image.open(buf) 205 | vis = toten(vis) 206 | 207 | # log tensor and close figure 208 | logger.writer.add_image(f'valid_samples_{pred_type}_{task}', vis, global_step=logger.global_step) 209 | plt.close() 210 | 211 | logger.writer.flush() 212 | 213 | 214 | class Logger(): 215 | def __init__(self, log_dir, tasks, reset=True): 216 | if os.path.exists(log_dir) and reset: 217 | shutil.rmtree(log_dir) 218 | os.makedirs(log_dir, exist_ok=True) 219 | self.writer = SummaryWriter(log_dir) 220 | self.global_step = 0 221 | 222 | self.logs = {} 223 | self.logs_saved = {} 224 | self.iters = {} 225 | 226 | def register_key(self, key): 227 | self.logs[key] = 0 228 | self.logs_saved[key] = 0 229 | self.iters[key] = 0 230 | 231 | def add_value(self, key, value): 232 | self.logs[key] += value 233 | self.iters[key] += 1 234 | 235 | def get_value(self, key): 236 | if self.iters[key] == 0: 237 | return self.logs_saved[key] 238 | else: 239 | return self.logs[key] / self.iters[key] 240 | 241 | def reset(self, keys): 242 | for key in keys: 243 | self.logs_saved[key] = self.get_value(key) 244 | self.logs[key] = 0 245 | self.iters[key] = 0 246 | 247 | def log_values(self, keys, pbar=None, tag='train', global_step=0): 248 | if pbar is not None: 249 | desc = 'step {:05d}'.format(global_step) 250 | 251 | if 'nll_normalized' in keys: 252 | desc += ', {}: {:.3f}'.format('nll_norm', self.get_value('nll_normalized')) 253 | if 'kld_T_normalized' in keys: 254 | desc += ', {}: {:.3f}'.format('kld_T_norm', self.get_value('kld_T_normalized')) 255 | if 'kld_G' in keys: 256 | desc += ', {}: 
{:.3f}'.format('kld_G', self.get_value('kld_G')) 257 | pbar.set_description(desc) 258 | 259 | for key in filter(lambda x: x not in ['nll_normalized', 'kld_T_normalized'], keys): 260 | self.writer.add_scalar('{}/{}'.format(tag, key), self.get_value(key), global_step=global_step) 261 | 262 | for key in filter(lambda x: x in ['nll_normalized', 'kld_T_normalized', 'kld_G'], keys): 263 | self.writer.add_scalar('{}_summary/{}'.format(tag, key), self.get_value(key), global_step=global_step) 264 | 265 | 266 | class Saver: 267 | def __init__(self, model, save_dir, config): 268 | self.save_dir = save_dir 269 | self.config = config 270 | self.tasks = config.tasks 271 | self.model_type = config.model 272 | if self.model_type == 'stp': 273 | self.best_nll_state_dict = model.state_dict_() 274 | self.best_error_state_dict = model.state_dict_() 275 | 276 | self.best_nll = float('inf') 277 | self.best_nlls = {task: float('inf') for task in config.tasks} 278 | self.best_error = float('inf') 279 | self.best_errors = {task: float('inf') for task in config.tasks} 280 | 281 | def save(self, model, valid_nlls, valid_errors, global_step, save_name): 282 | torch.save({'model': model.state_dict_(), 'config': self.config, 283 | 'nlls': valid_nlls, 'errors': valid_errors, 'global_step': global_step}, 284 | os.path.join(self.save_dir, save_name)) 285 | 286 | 287 | def save_best(self, model, valid_nlls, valid_errors, global_step): 288 | valid_nll = sum([valid_nlls[task] for task in self.tasks]) 289 | valid_error = sum([valid_errors[task] for task in self.tasks]) 290 | 291 | # save best model 292 | if self.model_type == 'stp': 293 | update_nll = False 294 | update_error = False 295 | for task in self.best_nlls: 296 | if valid_nlls[task] < self.best_nlls[task]: 297 | self.best_nlls[task] = valid_nlls[task] 298 | self.best_nll_state_dict[task] = model.state_dict_task(task) 299 | update_nll = True 300 | 301 | if valid_errors[task] < self.best_errors[task]: 302 | self.best_errors[task] = 
valid_errors[task] 303 | self.best_error_state_dict[task] = model.state_dict_task(task) 304 | update_error = True 305 | 306 | if update_nll: 307 | torch.save({'model': self.best_nll_state_dict, 'config': self.config, 308 | 'nlls': valid_nlls, 'errors': valid_errors, 'global_step': global_step}, 309 | os.path.join(self.save_dir, 'best_nll.pth')) 310 | if update_error: 311 | torch.save({'model': self.best_error_state_dict, 'config': self.config, 312 | 'nlls': valid_nlls, 'errors': valid_errors, 'global_step': global_step}, 313 | os.path.join(self.save_dir, 'best_error.pth')) 314 | else: 315 | if valid_nll < self.best_nll: 316 | self.best_nll = valid_nll 317 | torch.save({'model': model.state_dict_(), 'config': self.config, 318 | 'nlls': valid_nlls, 'errors': valid_errors, 'global_step': global_step}, 319 | os.path.join(self.save_dir, 'best_nll.pth')) 320 | if valid_error < self.best_error: 321 | self.best_error = valid_error 322 | torch.save({'model': model.state_dict_(), 'config': self.config, 323 | 'nlls': valid_nlls, 'errors': valid_errors, 'global_step': global_step}, 324 | os.path.join(self.save_dir, 'best_error.pth')) 325 | 326 | 327 | 328 | def broadcast_squeeze(data, dim): 329 | def squeeze_wrapper(data): 330 | if isinstance(data, torch.Tensor): 331 | return data.squeeze(dim) 332 | elif isinstance(data, tuple): 333 | return tuple(map(squeeze_wrapper, data)) 334 | elif isinstance(data, list): 335 | return list(map(squeeze_wrapper, data)) 336 | elif isinstance(data, dict): 337 | return {key: squeeze_wrapper(data[key]) for key in data} 338 | else: 339 | raise NotImplementedError 340 | 341 | return squeeze_wrapper(data) 342 | 343 | 344 | def broadcast_index(data, idx): 345 | def index_wrapper(data): 346 | if isinstance(data, torch.Tensor): 347 | return data 348 | elif isinstance(data, tuple): 349 | return data[idx] 350 | elif isinstance(data, list): 351 | return data[idx] 352 | elif isinstance(data, dict): 353 | return {key: index_wrapper(data[key]) for key in 
data} 354 | else: 355 | raise NotImplementedError 356 | 357 | return index_wrapper(data) 358 | 359 | 360 | def broadcast_mean(data, dim): 361 | def mean_wrapper(data): 362 | if isinstance(data, torch.Tensor): 363 | return data.mean(dim) 364 | elif isinstance(data, tuple): 365 | return tuple(map(mean_wrapper, data)) 366 | elif isinstance(data, list): 367 | return list(map(mean_wrapper, data)) 368 | elif isinstance(data, dict): 369 | return {key: mean_wrapper(data[key]) for key in data} 370 | else: 371 | raise NotImplementedError 372 | 373 | return mean_wrapper(data) --------------------------------------------------------------------------------