├── CITransNet
│   ├── base.py
│   ├── clippedAdam.py
│   ├── main_2x5_c.py
│   ├── main_2x5_d.py
│   ├── main_3x10_c.py
│   ├── main_3x10_d.py
│   ├── main_5x10_c.py
│   ├── main_5x10_d.py
│   ├── network.py
│   ├── trainer.py
│   └── utilities.py
├── LICENSE
├── README.md
└── data_gen
    ├── data_gen_continous.py
    └── data_gen_discrete.py

/CITransNet/base.py:
--------------------------------------------------------------------------------
import os
import sys
import numpy as np
import random
import logging
import torch
import json
import shutil
import argparse

def str2bool(v):
    if v.lower() in ['yes', 'true', 't', 'y', '1']:
        return True
    elif v.lower() in ['no', 'false', 'f', 'n', '0']:
        return False
    else:
        raise argparse.ArgumentTypeError('Unsupported value encountered.')

parser = argparse.ArgumentParser()
parser.add_argument('--save_dir', default=None)
parser.add_argument('--checkpoint_final', action='store_true')
parser.add_argument('--no_checkpoint', action='store_true')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--log_tag', type=str, default=None)
parser.add_argument('--test', action='store_true')
parser.add_argument('--restart', action='store_true')


class BaseTrainer(object):
    def __init__(self, args):
        self.args = args
        self.set_save_dir(args)
        self.set_logger(args)
        self.set_seed(args)

    def set_save_dir(self, args):
        if args.save_dir:
            if not args.test and os.path.exists(args.save_dir):
                if args.restart:
                    shutil.rmtree(args.save_dir)
                else:
                    print('save path exists, continue training? [y/n]')
                    s = input()
                    if s == 'n':
                        shutil.rmtree(args.save_dir)
            elif args.test and not os.path.exists(args.save_dir):
                print(f'no checkpoint at {args.save_dir}!')
                exit()
            if not os.path.exists(args.save_dir):
                os.makedirs(args.save_dir)
    def set_logger(self, args):
        '''
        Write logs to checkpoint and console
        '''
        log_file = None
        filemode = 'a'
        if args.save_dir:
            log_file = 'test' if args.test else 'train'
            if args.test:
                filemode = 'w'
            if args.log_tag:
                log_file += f'_{args.log_tag}'
            log_file = os.path.join(args.save_dir, log_file + '.log')

        logging.basicConfig(
            format='%(asctime)s %(levelname)-8s %(message)s',
            level=logging.INFO,
            datefmt='%Y-%m-%d %H:%M:%S',
            filename=log_file,
            filemode=filemode
        )
        if log_file is not None:
            console = logging.StreamHandler()
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
            console.setFormatter(formatter)
            logging.getLogger('').addHandler(console)
            logging.info(f'save log on {log_file}')

    def save_model(self, args, save_variable_dict, tag=None):
        if args.save_dir:
            logging.info(f'save to {args.save_dir}')
            argparse_dict = vars(args)
            with open(os.path.join(args.save_dir, 'config.json'), 'w') as fjson:
                json.dump(argparse_dict, fjson)
            fname = 'checkpoint' if tag is None else f'checkpoint-{tag}'
            torch.save(save_variable_dict, os.path.join(args.save_dir, fname))

    def set_seed(self, args):
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.random.manual_seed(args.seed)
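
# Each main_*.py entry point extends this shared parser; str2bool lets the
# boolean flags be set explicitly on the command line, e.g.
#   python main_5x10_d.py --data_parallel True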
--------------------------------------------------------------------------------
/CITransNet/clippedAdam.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import collections
import torch.optim as optim

class Adam(optim.Optimizer):
    r"""Implements Adam algorithm.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Adam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, restricted=False, min=0, max=1, closure=None):
        """Performs a single optimization step.

        Arguments:
            restricted (bool, optional): if True, clamp every updated parameter
                into [min, max] after the Adam update (a projected step).
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if group['weight_decay'] != 0:
                    # grad.add_(group['weight_decay'], p.data)
                    grad.add_(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                # exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                # p.data.addcdiv_(-step_size, exp_avg, denom)
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                if restricted:
                    p.data.clamp_(min, max)

        return loss
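
# Usage sketch (illustrative, not part of the original training code): with
# restricted=True the update becomes a projected gradient step, which is how
# utilities.misreportOptimization keeps misreports inside [v_min, v_max].
#
#   x = torch.rand(4, requires_grad=True)
#   opt = Adam([x], lr=1e-2)
#   (-x.sum()).backward()
#   opt.step(restricted=True, min=0, max=1)  # Adam update, then clamp into [0, 1]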
--------------------------------------------------------------------------------
/CITransNet/main_2x5_c.py:
--------------------------------------------------------------------------------
import logging
from time import time
import datetime
from trainer import Trainer

if __name__ == '__main__':
    from base import parser, str2bool
    parser.add_argument('--data_dir', type=str, default='../data_multi/10d_2x5')
    parser.add_argument('--training_set', type=str, default='training_100000')
    parser.add_argument('--test_set', type=str, default='test_5000')
    parser.add_argument('--n_bidder_type', type=int, default=None)
    parser.add_argument('--n_item_type', type=int, default=None)
    parser.add_argument('--d_emb', type=int, default=10)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=64)

    parser.add_argument('--n_bidder', type=int, default=2)
    parser.add_argument('--n_item', type=int, default=5)
    parser.add_argument('--r_train', type=int, default=25, help='Number of steps in the inner maximization loop')
    parser.add_argument('--r_test', type=int, default=200, help='Number of steps in the inner maximization loop when testing')
    parser.add_argument('--gamma', type=float, default=1e-3, help='The learning rate for the inner maximization loop')
    parser.add_argument('--n_misreport_init', type=int, default=100)
    parser.add_argument('--n_misreport_init_train', type=int, default=1)

    parser.add_argument('--train_size', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--n_epoch', type=int, default=80)
    parser.add_argument('--learning_rate', type=float, default=1e-3)

    parser.add_argument('--test_size', type=int, default=None)  # used to slice the dataset, so an int
    parser.add_argument('--batch_test', type=int, default=50)
    parser.add_argument('--eval_freq', type=int, default=30)
    parser.add_argument('--save_freq', type=int, default=1)

    parser.add_argument('--lamb', type=float, default=5)
    parser.add_argument('--lamb_update_freq', type=int, default=6)
    parser.add_argument('--rho', type=float, default=1)
    parser.add_argument('--delta_rho', type=float, default=5)
    parser.add_argument('--v_min', type=float, default=0)
    parser.add_argument('--v_max', type=float, default=1)
    parser.add_argument('--data_parallel', type=str2bool, default=False)
    parser.add_argument('--continuous_context', type=str2bool, default=True)
    parser.add_argument('--cond_prob', type=str2bool, default=False)
    parser.add_argument('--data_parallel_test', action='store_true')
    parser.add_argument('--test_ckpt_tag', type=str, default=None)

    t0 = time()
    args = parser.parse_args()
    trainer = Trainer(args)
    if args.test:
        trainer.test(args, load=True)
    else:
        trainer.train(args)

    time_used = time() - t0
    logging.info(f'Time Cost={datetime.timedelta(seconds=time_used)}')
--------------------------------------------------------------------------------
/CITransNet/main_2x5_d.py:
--------------------------------------------------------------------------------
import logging
from time import time
import datetime
from trainer import Trainer

if __name__ == '__main__':
    from base import parser, str2bool
    parser.add_argument('--data_dir', type=str, default='../data_multi/10t10t_2x5/')
    parser.add_argument('--training_set', type=str, default='training_100000')
    parser.add_argument('--test_set', type=str, default='test_5000')
    parser.add_argument('--n_bidder_type', type=int, default=10)
    parser.add_argument('--n_item_type', type=int, default=10)
    parser.add_argument('--d_emb', type=int, default=16)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=64)

    parser.add_argument('--n_bidder', type=int, default=2)
    parser.add_argument('--n_item', type=int, default=5)
    parser.add_argument('--r_train', type=int, default=25, help='Number of steps in the inner maximization loop')
    parser.add_argument('--r_test', type=int, default=200, help='Number of steps in the inner maximization loop when testing')
    parser.add_argument('--gamma', type=float, default=1e-3, help='The learning rate for the inner maximization loop')
    parser.add_argument('--n_misreport_init', type=int, default=100)
    parser.add_argument('--n_misreport_init_train', type=int, default=1)

    parser.add_argument('--train_size', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--n_epoch', type=int, default=80)
    parser.add_argument('--learning_rate', type=float, default=1e-3)

    parser.add_argument('--test_size', type=int, default=None)
    parser.add_argument('--batch_test', type=int, default=50)
    parser.add_argument('--eval_freq', type=int, default=30)
    parser.add_argument('--save_freq', type=int, default=1)

    parser.add_argument('--lamb', type=float, default=5)
    parser.add_argument('--lamb_update_freq', type=int, default=2)
    parser.add_argument('--rho', type=float, default=1)
    parser.add_argument('--delta_rho', type=float, default=5)
    parser.add_argument('--v_min', type=float, default=0)
    parser.add_argument('--v_max', type=float, default=1)
    parser.add_argument('--data_parallel', type=str2bool, default=False)
    parser.add_argument('--continuous_context', type=str2bool, default=False)
    parser.add_argument('--cond_prob', type=str2bool, default=False)
    parser.add_argument('--data_parallel_test', action='store_true')
    parser.add_argument('--test_ckpt_tag', type=str, default=None)

    t0 = time()
    args = parser.parse_args()
    trainer = Trainer(args)
    if args.test:
        trainer.test(args, load=True)
    else:
        trainer.train(args)

    time_used = time() - t0
    logging.info(f'Time Cost={datetime.timedelta(seconds=time_used)}')
--------------------------------------------------------------------------------
/CITransNet/main_3x10_c.py:
--------------------------------------------------------------------------------
import logging
from time import time
import datetime
from trainer import Trainer

if __name__ == '__main__':
    from base import parser, str2bool
    parser.add_argument('--data_dir', type=str, default='../data_multi/10d_3x10')
    parser.add_argument('--training_set', type=str, default='training_100000')
    parser.add_argument('--test_set', type=str, default='test_5000')
    parser.add_argument('--n_bidder_type', type=int, default=None)
    parser.add_argument('--n_item_type', type=int, default=None)
    parser.add_argument('--d_emb', type=int, default=10)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=64)

    parser.add_argument('--n_bidder', type=int, default=3)
    parser.add_argument('--n_item', type=int, default=10)
    parser.add_argument('--r_train', type=int, default=25, help='Number of steps in the inner maximization loop')
    parser.add_argument('--r_test', type=int, default=200, help='Number of steps in the inner maximization loop when testing')
    parser.add_argument('--gamma', type=float, default=1e-3, help='The learning rate for the inner maximization loop')
    parser.add_argument('--n_misreport_init', type=int, default=100)
    parser.add_argument('--n_misreport_init_train', type=int, default=1)

    parser.add_argument('--train_size', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--n_epoch', type=int, default=80)
    parser.add_argument('--learning_rate', type=float, default=1e-3)

    parser.add_argument('--test_size', type=int, default=None)
    parser.add_argument('--batch_test', type=int, default=25)
    parser.add_argument('--eval_freq', type=int, default=30)
    parser.add_argument('--save_freq', type=int, default=1)

    parser.add_argument('--lamb', type=float, default=5)
    parser.add_argument('--lamb_update_freq', type=int, default=2)
    parser.add_argument('--rho', type=float, default=1)
    parser.add_argument('--delta_rho', type=float, default=5)
    parser.add_argument('--v_min', type=float, default=0)
    parser.add_argument('--v_max', type=float, default=1)
    parser.add_argument('--data_parallel', type=str2bool, default=False)
    parser.add_argument('--continuous_context', type=str2bool, default=True)
    parser.add_argument('--cond_prob', type=str2bool, default=False)
    parser.add_argument('--data_parallel_test', action='store_true')
    parser.add_argument('--test_ckpt_tag', type=str, default=None)

    t0 = time()
    args = parser.parse_args()
    trainer = Trainer(args)
    if args.test:
        trainer.test(args, load=True)
    else:
        trainer.train(args)

    time_used = time() - t0
    logging.info(f'Time Cost={datetime.timedelta(seconds=time_used)}')
--------------------------------------------------------------------------------
/CITransNet/main_3x10_d.py:
--------------------------------------------------------------------------------
import logging
from time import time
import datetime
from trainer import Trainer

if __name__ == '__main__':
    from base import parser, str2bool
    parser.add_argument('--data_dir', type=str, default='../data_multi/10t10t_3x10/')
    parser.add_argument('--training_set', type=str, default='training_200000')
    parser.add_argument('--test_set', type=str, default='test_5000')
    parser.add_argument('--n_bidder_type', type=int, default=10)
    parser.add_argument('--n_item_type', type=int, default=10)
    parser.add_argument('--d_emb', type=int, default=16)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=64)

    parser.add_argument('--n_bidder', type=int, default=3)
    parser.add_argument('--n_item', type=int, default=10)
    parser.add_argument('--r_train', type=int, default=25, help='Number of steps in the inner maximization loop')
    parser.add_argument('--r_test', type=int, default=200, help='Number of steps in the inner maximization loop when testing')
    parser.add_argument('--gamma', type=float, default=1e-3, help='The learning rate for the inner maximization loop')
    parser.add_argument('--n_misreport_init', type=int, default=100)
    parser.add_argument('--n_misreport_init_train', type=int, default=1)

    parser.add_argument('--train_size', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--n_epoch', type=int, default=80)
    parser.add_argument('--learning_rate', type=float, default=3e-4)

    parser.add_argument('--test_size', type=int, default=None)
    parser.add_argument('--batch_test', type=int, default=25)
    parser.add_argument('--eval_freq', type=int, default=20)
    parser.add_argument('--save_freq', type=int, default=1)

    parser.add_argument('--lamb', type=float, default=5)
    parser.add_argument('--lamb_update_freq', type=int, default=4)
    parser.add_argument('--rho', type=float, default=1)
    parser.add_argument('--delta_rho', type=float, default=5)
    parser.add_argument('--v_min', type=float, default=0)
    parser.add_argument('--v_max', type=float, default=1)
    parser.add_argument('--data_parallel', type=str2bool, default=False)
    parser.add_argument('--continuous_context', type=str2bool, default=False)
    parser.add_argument('--cond_prob', type=str2bool, default=False)
    parser.add_argument('--data_parallel_test', action='store_true')
    parser.add_argument('--test_ckpt_tag', type=str, default=None)

    t0 = time()
    args = parser.parse_args()
    trainer = Trainer(args)
    if args.test:
        trainer.test(args, load=True)
    else:
        trainer.train(args)

    time_used = time() - t0
    logging.info(f'Time Cost={datetime.timedelta(seconds=time_used)}')
--------------------------------------------------------------------------------
/CITransNet/main_5x10_c.py:
--------------------------------------------------------------------------------
import logging
from time import time
import datetime
from trainer import Trainer

if __name__ == '__main__':
    from base import parser, str2bool
    parser.add_argument('--data_dir', type=str, default='../data_multi/10d_5x10')
    parser.add_argument('--training_set', type=str, default='training_100000')
    parser.add_argument('--test_set', type=str, default='test_5000')
    parser.add_argument('--n_bidder_type', type=int, default=None)
    parser.add_argument('--n_item_type', type=int, default=None)
    parser.add_argument('--d_emb', type=int, default=10)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=64)

    parser.add_argument('--n_bidder', type=int, default=5)
    parser.add_argument('--n_item', type=int, default=10)
    parser.add_argument('--r_train', type=int, default=25, help='Number of steps in the inner maximization loop')
    parser.add_argument('--r_test', type=int, default=200, help='Number of steps in the inner maximization loop when testing')
    parser.add_argument('--gamma', type=float, default=1e-3, help='The learning rate for the inner maximization loop')
    parser.add_argument('--n_misreport_init', type=int, default=100)
    parser.add_argument('--n_misreport_init_train', type=int, default=1)

    parser.add_argument('--train_size', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--n_epoch', type=int, default=30)
    parser.add_argument('--learning_rate', type=float, default=1e-3)

    parser.add_argument('--test_size', type=int, default=None)
    parser.add_argument('--batch_test', type=int, default=12)
    parser.add_argument('--eval_freq', type=int, default=30)
    parser.add_argument('--save_freq', type=int, default=1)

    parser.add_argument('--lamb', type=float, default=5)
    parser.add_argument('--lamb_update_freq', type=int, default=4)
    parser.add_argument('--rho', type=float, default=1)
    parser.add_argument('--delta_rho', type=float, default=5)
    parser.add_argument('--v_min', type=float, default=0)
    parser.add_argument('--v_max', type=float, default=1)
    parser.add_argument('--data_parallel', type=str2bool, default=False)
    parser.add_argument('--continuous_context', type=str2bool, default=True)
    parser.add_argument('--cond_prob', type=str2bool, default=False)
    parser.add_argument('--data_parallel_test', action='store_true')
    parser.add_argument('--test_ckpt_tag', type=str, default=None)

    t0 = time()
    args = parser.parse_args()
    trainer = Trainer(args)
    if args.test:
        trainer.test(args, load=True)
    else:
        trainer.train(args)

    time_used = time() - t0
    logging.info(f'Time Cost={datetime.timedelta(seconds=time_used)}')
--------------------------------------------------------------------------------
/CITransNet/main_5x10_d.py:
--------------------------------------------------------------------------------
import logging
from time import time
import datetime
from trainer import Trainer

if __name__ == '__main__':
    from base import parser, str2bool
    parser.add_argument('--data_dir', type=str, default='../data_multi/10t10t_5x10/')
    parser.add_argument('--training_set', type=str, default='training_200000')
    parser.add_argument('--test_set', type=str, default='test_5000')
    parser.add_argument('--n_bidder_type', type=int, default=10)
    parser.add_argument('--n_item_type', type=int, default=10)
    parser.add_argument('--d_emb', type=int, default=16)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=64)

    parser.add_argument('--n_bidder', type=int, default=5)
    parser.add_argument('--n_item', type=int, default=10)
    parser.add_argument('--r_train', type=int, default=25, help='Number of steps in the inner maximization loop')
    parser.add_argument('--r_test', type=int, default=200, help='Number of steps in the inner maximization loop when testing')
    parser.add_argument('--gamma', type=float, default=1e-3, help='The learning rate for the inner maximization loop')
    parser.add_argument('--n_misreport_init', type=int, default=100)
    parser.add_argument('--n_misreport_init_train', type=int, default=1)

    parser.add_argument('--train_size', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--n_epoch', type=int, default=60)
    parser.add_argument('--learning_rate', type=float, default=1e-3)

    parser.add_argument('--test_size', type=int, default=None)
    parser.add_argument('--batch_test', type=int, default=30)
    parser.add_argument('--eval_freq', type=int, default=30)
    parser.add_argument('--save_freq', type=int, default=1)

    parser.add_argument('--lamb', type=float, default=5)
    parser.add_argument('--lamb_update_freq', type=int, default=2)
    parser.add_argument('--rho', type=float, default=1)
    parser.add_argument('--delta_rho', type=float, default=5)
    parser.add_argument('--v_min', type=float, default=0)
    parser.add_argument('--v_max', type=float, default=1)
    parser.add_argument('--data_parallel', type=str2bool, default=True)
    parser.add_argument('--continuous_context', type=str2bool, default=False)
    parser.add_argument('--cond_prob', type=str2bool, default=False)
    parser.add_argument('--data_parallel_test', action='store_true')
    parser.add_argument('--test_ckpt_tag', type=str, default=None)

    t0 = time()
    args = parser.parse_args()
    trainer = Trainer(args)
    if args.test:
        trainer.test(args, load=True)
    else:
        trainer.train(args)

    time_used = time() - t0
    logging.info(f'Time Cost={datetime.timedelta(seconds=time_used)}')
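
# Unlike the other entry points, --data_parallel defaults to True here; the
# README recommends running this 5x10 discrete setting (Setting F) with 2 GPUs.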
--------------------------------------------------------------------------------
/CITransNet/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class Transformer2DNet(nn.Module):
    def __init__(self, d_input, d_output, n_layer, n_head, d_hidden):
        super(Transformer2DNet, self).__init__()
        self.d_input = d_input
        self.d_output = d_output
        self.n_layer = n_layer

        d_in = d_input
        self.row_transformer = nn.ModuleList()
        self.col_transformer = nn.ModuleList()
        self.fc = nn.ModuleList()
        for i in range(n_layer):
            d_out = d_hidden if i != n_layer - 1 else d_output
            self.row_transformer.append(nn.TransformerEncoderLayer(d_in, n_head, d_hidden, batch_first=True, dropout=0))
            self.col_transformer.append(nn.TransformerEncoderLayer(d_in, n_head, d_hidden, batch_first=True, dropout=0))
            self.fc.append(nn.Sequential(
                nn.Linear(d_in + 2 * d_hidden, d_hidden),
                nn.ReLU(),
                nn.Linear(d_hidden, d_out)
            ))
            d_in = d_hidden

    def forward(self, input):
        bs, n_bidder, n_item, d = input.shape
        x = input
        for i in range(self.n_layer):
            d = x.shape[-1]  # feature dim of the current layer's input
            row_x = x.view(-1, n_item, d)
            row = self.row_transformer[i](row_x)
            row = row.view(bs, n_bidder, n_item, -1)

            col_x = x.permute(0, 2, 1, 3).reshape(-1, n_bidder, d)
            col = self.col_transformer[i](col_x)
            col = col.view(bs, n_item, n_bidder, -1).permute(0, 2, 1, 3)

            glo = x.view(bs, n_bidder*n_item, -1).mean(1, keepdim=True)
            glo = glo.unsqueeze(1)  # (bs, 1, 1, -1)
            glo = glo.repeat(1, n_bidder, n_item, 1)

            x = torch.cat([row, col, glo], dim=-1)
            x = self.fc[i](x)
        return x

class TransformerMechanism(nn.Module):
    def __init__(self, n_bidder_type, n_item_type, d_emb, n_layer, n_head, d_hidden, v_min=0, v_max=1,
                 continuous_context=False, cond_prob=False):
        super(TransformerMechanism, self).__init__()
        self.d_emb = d_emb
        self.v_min = v_min
        self.v_max = v_max
        self.continuous_context = continuous_context
        self.pre_net = nn.Sequential(
            nn.Linear(d_emb*2+1, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden-1)
        )
        d_input = d_hidden
        self.n_layer, self.n_head, self.d_hidden = n_layer, n_head, d_hidden
        self.mechanism = Transformer2DNet(d_input, 3, self.n_layer, n_head, d_hidden)
        if not continuous_context:
            self.bidder_embeddings = nn.Embedding(n_bidder_type, d_emb)
            self.item_embeddings = nn.Embedding(n_item_type, d_emb)
        self.cond_prob = cond_prob

    def forward(self, batch_data):
        raw_bid, bidder_context, item_context = batch_data
        bid = (raw_bid - self.v_min) / (self.v_max - self.v_min)
        bs, n_bidder, n_item = bid.shape
        x1 = bid.unsqueeze(-1)  # (bs, n, m, 1)

        if self.continuous_context:
            bidder_emb = bidder_context
            item_emb = item_context
        else:
            bidder_emb = self.bidder_embeddings(bidder_context.view(-1, n_bidder))
            item_emb = self.item_embeddings(item_context.view(-1, n_item))

        x2 = bidder_emb.unsqueeze(2).repeat(1, 1, n_item, 1)
        x3 = item_emb.unsqueeze(1).repeat(1, n_bidder, 1, 1)

        x = torch.cat([x1, x2, x3], dim=-1)
        x = self.pre_net(x)
        x = torch.cat([x1, x], dim=-1)

        mechanism = self.mechanism(x)
        allocation, allocation_prob, payment = \
            mechanism[:, :, :, 0], mechanism[:, :, :, 1], mechanism[:, :, :, 2]  # (bs, n, m)

        allocation = F.softmax(allocation, dim=1)
        if self.cond_prob:
            allocation_prob = allocation_prob.mean(-2, keepdim=True)
            allocation_prob = torch.sigmoid(allocation_prob)
            allocation = allocation * allocation_prob

        payment = payment.mean(-1)
        payment = torch.sigmoid(payment)
        payment = (raw_bid * allocation).sum(-1) * payment

        return allocation, payment
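
# Shape sketch (illustrative): with 10-dimensional continuous contexts, as in
# main_2x5_c.py, the mechanism maps bids plus contexts to an allocation matrix
# and a payment vector.
#
#   net = TransformerMechanism(None, None, d_emb=10, n_layer=3, n_head=4,
#                              d_hidden=64, continuous_context=True)
#   bid = torch.rand(8, 2, 5)  # (batch, n_bidder, n_item)
#   alloc, pay = net((bid, torch.rand(8, 2, 10), torch.rand(8, 5, 10)))
#   # alloc: (8, 2, 5), softmax-normalized over bidders; pay: (8, 2)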
--------------------------------------------------------------------------------
/CITransNet/trainer.py:
--------------------------------------------------------------------------------
import os
import sys
import numpy as np
from statistics import mean
import random
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import json
import shutil
from time import time

from base import BaseTrainer
from network import TransformerMechanism
from utilities import misreportOptimization
from utilities import loss as loss_function


class Trainer(BaseTrainer):
    def __init__(self, args):
        super(Trainer, self).__init__(args)
        self.set_data(args)
        self.set_model(args)
        self.start_epoch = 0
        self.rho = args.rho
        self.lamb = args.lamb * torch.ones(args.n_bidder).to(args.device)
        self.n_iter = 0

    def set_data(self, args):
        def load_data(dir, size):
            data = [np.load(os.path.join(dir, 'trueValuations.npy')).astype(np.float32),
                    np.load(os.path.join(dir, 'Agent_names_idxs.npy')),
                    np.load(os.path.join(dir, 'Object_names_idxs.npy'))]
            if args.continuous_context:
                for i in [1, 2]:
                    data[i] = data[i].astype(np.float32)
            else:
                for i in [1, 2]:
                    data[i] = data[i].astype(np.int64)
            if size is not None:
                data = [x[:size] for x in data]  # truncate to the requested size
            return tuple(data)

        if not args.test:
            self.train_dir = os.path.join(args.data_dir, args.training_set)
            self.train_size = args.train_size
            self.train_data = load_data(self.train_dir, self.train_size)
            self.train_size = len(self.train_data[0])

            self.misreports = np.random.uniform(args.v_min, args.v_max,
                                                size=(self.train_size, args.n_misreport_init_train,
                                                      args.n_bidder, args.n_item))

        self.test_dir = os.path.join(args.data_dir, args.test_set)
        self.test_size = args.test_size
        self.test_data = load_data(self.test_dir, self.test_size)
        self.test_size = len(self.test_data[0])

    def set_model(self, args):
        self.mechanism = TransformerMechanism(args.n_bidder_type, args.n_item_type, args.d_emb,
                                              args.n_layer, args.n_head, args.d_hidden,
                                              args.v_min, args.v_max, args.continuous_context,
                                              args.cond_prob).to(args.device)
        if args.data_parallel:
            self.mechanism = nn.DataParallel(self.mechanism)
        self.optimizer = torch.optim.Adam(self.mechanism.parameters(), lr=args.learning_rate)

    def save(self, epoch, tag=None):
        save_variable_dict = {
            'mechanism': self.mechanism.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'start_epoch': epoch+1,
            'rho': self.rho,
            'lamb': self.lamb,
            'n_iter': self.n_iter,
            'misreports': self.misreports,
        }
        self.save_model(self.args, save_variable_dict, tag=tag)

    def load(self, tag=None):
        if not self.args.save_dir:
            return
        fname = 'checkpoint' if tag is None else f'checkpoint-{tag}'
        ckpt_path = os.path.join(self.args.save_dir, fname)
        if os.path.exists(ckpt_path):
            logging.info(f'load checkpoint from {ckpt_path}')
            ckpt = torch.load(ckpt_path)
            self.mechanism.load_state_dict(ckpt['mechanism'])
            self.optimizer.load_state_dict(ckpt['optimizer'])
            self.start_epoch = ckpt['start_epoch']
            self.rho = ckpt['rho']
            self.lamb = ckpt['lamb']
            self.n_iter = ckpt['n_iter']
            if 'misreports' in ckpt:
                self.misreports = ckpt['misreports']

    def train(self, args):
        self.load()
        for epoch in range(self.start_epoch, args.n_epoch):
            profit_sum = 0
            regret_sum = 0
            regret_max = 0
            loss_sum = 0
            for i in tqdm(range(0, self.train_size, args.batch_size)):
                self.n_iter += 1
                batch_indices = np.random.choice(self.train_size, args.batch_size)
                self.misreports = misreportOptimization(self.mechanism, batch_indices, self.train_data, self.misreports,
                                                        args.r_train, args.gamma, args.v_min, args.v_max)
                loss, regret_mean_bidder, regret_max_batch, profit = \
                    loss_function(self.mechanism, self.lamb, self.rho, batch_indices, self.train_data, self.misreports)
                loss_sum += loss.item() * len(batch_indices)
                regret_sum += regret_mean_bidder.mean().item() * len(batch_indices)
                regret_max = max(regret_max, regret_max_batch.item())
                profit_sum += profit.item() * len(batch_indices)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                if self.n_iter % args.lamb_update_freq == 0:
                    self.lamb += self.rho * regret_mean_bidder.detach()

            if (epoch + 1) % 2 == 0:
                self.rho += args.delta_rho
            logging.info(f"Train: epoch={epoch + 1}, loss={loss_sum/self.train_size:.5}, "
                         f"profit={(profit_sum)/self.train_size:.5}, "
                         f"regret={(regret_sum)/self.train_size:.5}, regret_max={regret_max:.5}")
            logging.info(f"Train: rho={self.rho}, lamb={self.lamb.mean().item()}")
            if (epoch + 1) % args.save_freq == 0:
                self.save(epoch)
            if (epoch + 1) % args.eval_freq == 0:
                self.test(args, valid=True)
                self.save(epoch, tag=epoch)

        logging.info('Final test')
        self.save(args.n_epoch-1)
        self.test(args)
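
    # Training note (summary of the loop above): lamb and rho implement the
    # augmented-Lagrangian outer updates of the regret-constrained objective.
    # Every lamb_update_freq minibatches, lamb_i += rho * (mean regret of
    # bidder i), and every 2 epochs rho += delta_rho, so regret violations are
    # penalized increasingly heavily as training proceeds.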
    def test(self, args, valid=False, load=False):
        if load:
            self.load(args.test_ckpt_tag)
        if valid == False and args.data_parallel_test:
            self.mechanism = nn.DataParallel(self.mechanism)
        if len(self.lamb) != args.n_bidder:
            self.lamb = args.lamb * torch.ones(args.n_bidder).to(args.device)
        if valid:
            data_size = args.batch_test * 10
            indices = np.random.choice(self.test_size, data_size)
            data = tuple([x[indices] for x in self.test_data])
        else:
            data_size = self.test_size
            data = self.test_data
        misreports = np.random.uniform(args.v_min, args.v_max, size=(data_size, args.n_misreport_init, args.n_bidder, args.n_item))
        indices = np.arange(data_size)
        profit_sum = 0.0
        regret_sum = 0.0
        regret_max = 0.0
        loss_sum = 0.0
        n_iter = 0.0
        for i in tqdm(range(0, data_size, args.batch_test)):
            batch_indices = indices[i:i+args.batch_test]
            n_iter += len(batch_indices)
            misreports = misreportOptimization(self.mechanism, batch_indices, data, misreports,
                                               args.r_test, args.gamma, args.v_min, args.v_max)
            with torch.no_grad():
                loss, regret_mean_bidder, regret_max_batch, profit = \
                    loss_function(self.mechanism, self.lamb, self.rho, batch_indices, data, misreports)

            loss_sum += loss.item() * len(batch_indices)
            regret_sum += regret_mean_bidder.mean().item() * len(batch_indices)
            regret_max = max(regret_max, regret_max_batch.item())
            profit_sum += profit.item() * len(batch_indices)

            if valid == False:
                logging.info(f"profit={(profit_sum)/n_iter:.5}, regret={(regret_sum)/n_iter:.5}")

        logging.info(f"Test: loss={loss_sum/data_size:.5}, profit={(profit_sum)/data_size:.5}, "
                     f"regret={(regret_sum)/data_size:.5}, regret_max={regret_max:.5}")

--------------------------------------------------------------------------------
/CITransNet/utilities.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import collections
import torch.optim as optim
from clippedAdam import Adam
import matplotlib.pyplot as plt

device = 'cuda' if torch.cuda.is_available() else 'cpu'


class View(nn.Module):
    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, input):
        return input.view(*self.shape)

def utility(valuations, allocation, pay):
    """ Given input valuations, allocation, and payments, computes utility
    Input params:
        valuations : [num_batches, num_agents, num_items]
        allocation : [num_batches, num_agents, num_items]
        pay        : [num_batches, num_agents]
    Output params:
        utility: [num_batches, num_agents]
    """
    return (torch.sum(valuations*allocation, dim=-1) - pay)

def revenue(pay):
    """ Given payment (pay), computes revenue
    Input params:
        pay: [num_batches, num_agents]
    Output params:
        revenue: scalar
    """
    return torch.mean(torch.sum(pay, dim=-1))
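
# Worked example (illustrative): one batch, two bidders, one item.
#   valuations = [[0.8], [0.4]], allocation = [[1.0], [0.0]], pay = [0.5, 0.0]
#   utility -> [0.8*1.0 - 0.5, 0.4*0.0 - 0.0] = [0.3, 0.0]
#   revenue -> mean over the batch of (0.5 + 0.0) = 0.5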

def misreportUtility(mechanism, batch_data, batchMisreports):
    """ This function takes the valuation and misreport batches
        and returns a tensor containing all the misreported utilities
    """
    batchTrueValuations = batch_data[0]
    batch_bidder_context = batch_data[1]  # (bs, n_bidder, d)
    batch_item_context = batch_data[2]  # (bs, n_item, d)
    nAgent = batchTrueValuations.shape[-2]
    nObjects = batchTrueValuations.shape[-1]
    batchSize = batchTrueValuations.shape[0]
    nbrInitializations = batchMisreports.shape[1]

    V = batchTrueValuations.unsqueeze(1)
    V = V.repeat(1, nbrInitializations, 1, 1)
    V = V.unsqueeze(0)
    V = V.repeat(nAgent, 1, 1, 1, 1)  # (n_bidder, bs, n_init, n_bidder, n_item)

    M = batchMisreports.unsqueeze(0)
    M = M.repeat(nAgent, 1, 1, 1, 1)

    mask1 = np.zeros((nAgent, nAgent, nObjects))
    mask1[np.arange(nAgent), np.arange(nAgent), :] = 1.0
    mask2 = np.ones((nAgent, nAgent, nObjects))
    mask2 = mask2 - mask1

    mask1 = (torch.tensor(mask1).float()).to(device)
    mask2 = (torch.tensor(mask2).float()).to(device)

    V = V.permute(1, 2, 0, 3, 4)  # (bs, n_init, n_bidder, n_bidder, n_item)
    M = M.permute(1, 2, 0, 3, 4)

    tensor = M*mask1 + V*mask2

    tensor = tensor.permute(2, 0, 1, 3, 4)  # (n_bidder, bs, n_init, n_bidder, n_item)
    bidder_context = batch_bidder_context.view(1, batchSize, 1, nAgent, -1).repeat(nAgent, 1, nbrInitializations, 1, 1)
    item_context = batch_item_context.view(1, batchSize, 1, nObjects, -1).repeat(nAgent, 1, nbrInitializations, 1, 1)

    V = V.permute(2, 0, 1, 3, 4)
    M = M.permute(2, 0, 1, 3, 4)

    tensor = View(-1, nAgent, nObjects)(tensor)
    tensor = tensor.float()
    bidder_context = bidder_context.view(tensor.shape[0], nAgent, -1)
    item_context = item_context.view(tensor.shape[0], nObjects, -1)

    allocation, payment = mechanism((tensor, bidder_context, item_context))

    allocation = View(nAgent, batchSize, nbrInitializations, nAgent, nObjects)(allocation)
    payment = View(nAgent, batchSize, nbrInitializations, nAgent)(payment)

    advUtilities = torch.sum(allocation*V, dim=-1) - payment

    advUtility = advUtilities[np.arange(nAgent), :, :, np.arange(nAgent)]

    return advUtility.permute(1, 2, 0)


def misreportOptimization(mechanism, batch, data, misreports, R, gamma, minimum=0, maximum=1):
    """ This function takes the valuation and misreport batches
        and R, the number of optimization steps, and modifies the misreport array
    """
    localMisreports = misreports[:]
    batchMisreports = torch.tensor(misreports[batch]).to(device)
    batchTrueValuations = torch.tensor(data[0][batch]).to(device)
    batch_bidder_type = torch.tensor(data[1][batch]).to(device)
    batch_item_type = torch.tensor(data[2][batch]).to(device)
    batch_data = (batchTrueValuations, batch_bidder_type, batch_item_type)
    batchMisreports.requires_grad = True

    opt = Adam([batchMisreports], lr=gamma)

    for k in range(R):
        advU = misreportUtility(mechanism, batch_data, batchMisreports)
        loss = -1 * torch.sum(advU).to(device)
        loss.backward()
        opt.step(restricted=True, min=minimum, max=maximum)
        opt.zero_grad()

    mechanism.zero_grad()

    localMisreports[batch, :, :, :] = batchMisreports.cpu().detach().numpy()
    return localMisreports
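
# Note on the masks in misreportUtility above: for agent i, mask1 selects row i
# of the valuation profile (replaced by the misreport) while mask2 keeps every
# other bidder's true valuations, i.e. the mechanism is evaluated on the profile
# (v_1, ..., m_i, ..., v_n). Stacking all n such profiles lets a single forward
# pass produce every bidder's deviation utility, which regret() below maximizes
# over the misreport initializations.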

def trueUtility(mechanism, batch_data, allocation=None, payment=None):
    """ This function takes the valuation batches
        and returns a tensor containing the utilities
    """
    if allocation is None or payment is None:
        allocation, payment = mechanism(batch_data)
    batchTrueValuations = batch_data[0]
    return utility(batchTrueValuations, allocation, payment)


def regret(mechanism, batch_data, batchMisreports, allocation, payment):
    """ This function takes the valuation and misreport batches
        and returns a tensor containing the regrets for each bidder and each batch
    """
    missReportUtilityAll = misreportUtility(mechanism, batch_data, batchMisreports)
    misReportUtilityMax = torch.max(missReportUtilityAll, dim=1)[0]
    return misReportUtilityMax - trueUtility(mechanism, batch_data, allocation, payment)


def loss(mechanism, lamb, rho, batch, data, misreports):
    """
    This function takes a batch, which is a numpy array of indices, and computes:
        the loss function                                                 : loss
        the average regret per agent, a tensor of size [nAgent]           : rMean
        the maximum regret over all batches and agents, a tensor of size [1] : rMax
        the average payment, a tensor of size [1]                         : -paymentLoss
    """
    batchMisreports = torch.tensor(misreports[batch]).to(device)
    batchTrueValuations = torch.tensor(data[0][batch]).to(device)
    batch_bidder_type = torch.tensor(data[1][batch]).to(device)
    batch_item_type = torch.tensor(data[2][batch]).to(device)
    batch_data = (batchTrueValuations, batch_bidder_type, batch_item_type)
    allocation, payment = mechanism(batch_data)

    paymentLoss = -torch.sum(payment) / batch.shape[0]

    r = F.relu(regret(mechanism, batch_data, batchMisreports, allocation, payment))
    rMean = torch.mean(r, dim=0).to(device)

    rMax = torch.max(r).to(device)

    lagrangianLoss = torch.sum(rMean * lamb)

    lagLoss = (rho / 2) * torch.sum(torch.pow(rMean, 2))

    loss = paymentLoss + lagrangianLoss + lagLoss

    return loss, rMean, rMax, -paymentLoss

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Zhijian Duan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Introduction

This is the PyTorch implementation of our paper: *A Context-Integrated Transformer-Based Neural Network for Auction Design* () in *ICML 2022*.

## Requirements

* Python >= 3.6
* PyTorch 1.10.0
* Argparse
* Logging
* Tqdm
* Scipy

## Usage

### Generate the data

```bash
cd data_gen
# For Settings G, H, I
python data_gen_continous.py

# For Settings D, E, F
python data_gen_discrete.py
```

### Train CITransNet

```bash
cd CITransNet
# For Setting G
python main_2x5_c.py

# For Setting H
python main_3x10_c.py

# For Setting I
python main_5x10_c.py

# For Setting D
python main_2x5_d.py

# For Setting E
python main_3x10_d.py

# For Setting F, it is recommended to train it with 2 GPUs.
CUDA_VISIBLE_DEVICES=0,1 python main_5x10_d.py --data_parallel True
```
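
### Test CITransNet

To evaluate a saved checkpoint, pass `--test` together with the save directory used during training (a sketch; the placeholder below stands for whatever `--save_dir` you trained with):

```bash
cd CITransNet
python main_2x5_c.py --test --save_dir <save_dir_used_in_training>
```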

## Acknowledgement

Our code is built upon the implementation of .

--------------------------------------------------------------------------------
/data_gen/data_gen_continous.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch
from scipy.stats import truncnorm
from tqdm import tqdm

def gen(n, m, d=10, phase='train', n_data=200000):
    dir = f'../data_multi/{d}d_{n}x{m}/{phase}_{n_data}'
    if not os.path.exists(dir):
        os.makedirs(dir)
    raw_valuation = torch.rand(n_data, n, m)
    tau = -1 + 2 * torch.rand(n_data, n, d)
    omega = -1 + 2 * torch.rand(n_data, m, d)

    valuation = raw_valuation * torch.sigmoid(tau @ omega.permute(0, 2, 1))
    np.save(os.path.join(dir, 'trueValuations'), valuation.numpy())
    np.save(os.path.join(dir, 'Agent_names_idxs'), tau.numpy())
    np.save(os.path.join(dir, 'Object_names_idxs'), omega.numpy())

if __name__ == '__main__':
    seed = 0
    np.random.seed(seed)
    torch.manual_seed(seed)  # the valuations here are drawn with torch, so seed torch as well
    n, m = 2, 5
    gen(n, m, phase='test', n_data=5000)
    gen(n, m, phase='training', n_data=int(1e5))
    n, m = 3, 10
    gen(n, m, phase='test', n_data=5000)
    gen(n, m, phase='training', n_data=int(1e5))
    n, m = 5, 10
    gen(n, m, phase='test', n_data=5000)
    gen(n, m, phase='training', n_data=int(1e5))
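
# The generated valuations follow the continuous-context model of Settings G-I:
# v_ij = u_ij * sigmoid(<tau_i, omega_j>), with u_ij ~ U[0, 1] and contexts
# tau_i, omega_j ~ U[-1, 1]^10 saved alongside the valuations.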
--------------------------------------------------------------------------------
/data_gen/data_gen_discrete.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from scipy.stats import truncnorm
from tqdm import tqdm

def gen(n, m, n_type=10, m_type=10, phase='train', n_data=200000):
    dir = f'../data_multi/{n_type}t{m_type}t_{n}x{m}/{phase}_{n_data}'
    if not os.path.exists(dir):
        os.makedirs(dir)
    raw_data = np.zeros((n_data, n, n_type, m, m_type))
    ind = [(i, j) for i in range(n_type) for j in range(m_type)]
    for i, j in tqdm(ind):
        # truncated normal on [0, 1] whose mean depends on the (bidder type, item type) pair
        myclip_a, myclip_b, my_mean, my_std = 0, 1, (1 + (i+j+2) % 10) / 11, 0.05
        a, b = (myclip_a - my_mean) / my_std, (myclip_b - my_mean) / my_std
        x = truncnorm.rvs(a, b, loc=my_mean, scale=my_std, size=(n_data, n, m))
        raw_data[:, :, i, :, j] = x
    tau = np.random.randint(0, n_type, size=(n_data, n))
    b_data = raw_data.reshape(n_data*n, n_type, m, m_type)
    b_data = b_data[np.arange(n_data*n), tau.reshape(-1)]
    b_data = b_data.reshape(n_data, n, m, m_type)

    omega = np.random.randint(0, m_type, size=(n_data, m))
    data = b_data.transpose(0, 2, 3, 1)  # n_data, m, m_type, n
    data = data.reshape(n_data*m, m_type, n)
    data = data[np.arange(n_data*m), omega.reshape(-1)]
    data = data.reshape(n_data, m, n).transpose(0, 2, 1)

    valuation = data
    np.save(os.path.join(dir, 'trueValuations'), valuation)
    np.save(os.path.join(dir, 'Agent_names_idxs'), tau)
    np.save(os.path.join(dir, 'Object_names_idxs'), omega)

if __name__ == '__main__':
    seed = 0
    np.random.seed(seed)
    gen(2, 5, phase='test', n_data=5000)
    gen(2, 5, phase='training', n_data=int(1e5))
    gen(3, 10, phase='test', n_data=5000)
    gen(3, 10, phase='training', n_data=int(2e5))
    gen(5, 10, phase='test', n_data=5000)
    gen(5, 10, phase='training', n_data=int(2e5))
--------------------------------------------------------------------------------