├── dataset
│   └── .gitkeep
├── run.sh
├── sampling.py
├── main.py
├── .gitignore
├── README.md
├── arguments.py
├── eval.py
├── data.py
├── baselines.py
└── models.py
--------------------------------------------------------------------------------
/dataset/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash

python main.py --model=LiveRec \
               --dataset="dataset/" \
               --fr_ctx \
               --fr_rep \
               --model_to="liverec" \
               --device="cuda" \
               --caching
--------------------------------------------------------------------------------
/sampling.py:
--------------------------------------------------------------------------------
import random
import torch

def sample_av(p,t,args):
    # availability sampling: draw a negative from the items live at time t
    av = args.ts[t]
    while True:
        ridx = random.randint(0,len(av)-1)
        ri = av[ridx]
        if p!=ri:
            return ri

def sample_uni(p,t,args):
    # uniform sampling over the full item set
    while True:
        ri = random.randint(0,args.N-1)
        if p!=ri:
            return ri

def sample_negs(data,args):
    pos,xts = data[:,:,5],data[:,:,6]
    neg = torch.zeros_like(pos)

    ci = torch.nonzero(pos, as_tuple=False)
    ps = pos[ci[:,0],ci[:,1]].tolist()
    ts = xts[ci[:,0],ci[:,1]].tolist()

    for i in range(ci.shape[0]):
        p = ps[i]; t = ts[i]

        if args.uniform:
            neg[ci[i,0],ci[i,1]] = sample_uni(p,t,args)
        else:
            neg[ci[i,0],ci[i,1]] = sample_av(p,t,args)

    return neg

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import sys,os,time
import pandas as pd
from tqdm import tqdm

from arguments import *
from eval import *
from data import *
from models import *

import torch
import torch.optim as optim

args = arg_parse()
print_args(args)
args.device = torch.device(args.device)

MPATH,MODEL = get_model_type(args)

data_fu = load_data(args)
train_loader, val_loader, test_loader = get_dataloaders(data_fu,args)

# baselines models
if args.model in ['POP','REP']:
    data_tr = data_fu[data_fu.stop<args.pivot_1]
    model = MODEL(args, data_tr)
    scores = compute_recall(model, test_loader, args)
    print_scores(scores)
    save_scores(scores,args)
    sys.exit()

model = MODEL(args).to(args.device)

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)

best_val = 0.0
best_max = args.early_stop
best_cnt = best_max

# NOTE: the middle of this script was garbled in this copy; the training
# loop below is a minimal reconstruction around the surviving
# early-stopping logic and may differ from the original in its details.
for epoch in range(args.num_epochs):
    model.train()
    loss_all, loss_cnt = 0.0, 0
    for data in tqdm(train_loader):
        data = data.to(args.device)
        optimizer.zero_grad()
        loss = model.train_step(data)
        loss.backward()
        optimizer.step()
        loss_all += loss.item()
        loss_cnt += data.shape[0]
    print('Epoch: {:03d}, Loss: {:.5f}'.format(epoch, loss_all/loss_cnt))

    # validate and early-stop on hit@10 over all interactions
    scores = compute_recall(model, val_loader, args)
    print_scores(scores)
    hall = scores['all']['h10']

    if hall>best_val:
        best_val = hall
        torch.save(model.state_dict(), MPATH)
        best_cnt = best_max
    else:
        best_cnt -= 1
        if best_cnt == 0:
            break

model = MODEL(args).to(args.device)
model.load_state_dict(torch.load(MPATH))

scores = compute_recall(model, test_loader, args)
print("Final score")
print("="*11)
print('Epoch: {:03d}, Loss: {:.5f}'.format(epoch, loss_all/loss_cnt))
print_scores(scores)
save_scores(scores,args)

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

*.un~
logs.txt
run_custom.sh
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LiveRec

This repository contains the code of LiveRec, from the paper
**Recommendation on Live-Streaming Platforms: Dynamic Availability and Repeat Consumption**
by Jérémie Rappaz, Julian McAuley and Karl Aberer, accepted as a full paper at RecSys 2021.

## Abstract

Live-streaming platforms broadcast user-generated video in real-time. Recommendation on these platforms shares similarities with traditional settings, such as a large volume of heterogeneous content and highly skewed interaction distributions. However, several challenges must be overcome to adapt recommendation algorithms to live-streaming platforms: first, content availability is dynamic, which restricts users to choosing from only a subset of items at any given time; during training and inference we must carefully handle this factor in order to properly account for such signals, where 'non-interactions' reflect availability as much as implicit preference.
Streamers are also fundamentally different from 'items' in traditional settings: repeat consumption of specific channels plays a significant role, though the content itself is fundamentally ephemeral.

In this work, we study recommendation in this setting of a dynamically evolving set of available items. We propose LiveRec, a self-attentive model that personalizes item ranking based on both historical interactions and current availability. We also show that carefully modelling repeat consumption plays a significant role in model performance. To validate our approach, and to inspire further research on this setting, we release a dataset containing 475M user interactions on Twitch over a 43-day period. We evaluate our approach on a recommendation task and show our method to outperform various strong baselines in ranking the currently available content.

## Datasets

Two datasets are provided in our [Google Drive](https://drive.google.com/drive/folders/1BD8m7a8m7onaifZay05yYjaLxyVV40si?usp=sharing). The file `full_a.csv.gz` contains the full dataset while `100k.csv` is a subset of 100k users for benchmark purposes.

|                    | Twitch 100k | Twitch full |
|--------------------|-------------|-------------|
| #Users             | 100k        | 15.5M       |
| #Items (streamers) | 162.6k      | 465k        |
| #Interactions      | 3M          | 124M        |
| #Timesteps (10min) | 6148        | 6148        |

Our datasets have been collected from Twitch. We took a full snapshot of all available streams every 10 minutes over a period of 43 days. For each stream, we retrieved all logged-in users from the chat. All usernames have been anonymized. Start and stop times are provided as integers and represent 10-minute crawling rounds. A minimal loading example is given at the end of this README.

#### Fields description

* `user_id`: user identifier (anonymized).
* `stream id`: stream identifier, which can be used to retrieve a single broadcast segment (not used in our study).
* `streamer name`: name of the channel.
* `start time`: first crawling round at which the user was seen in the chat.
* `stop time`: last crawling round at which the user was seen in the chat.

## Credits
If you find any of this useful in your own research, please cite

```
@inproceedings{rappaz2021recommendation,
  title={Recommendation on Live-Streaming Platforms: Dynamic Availability and Repeat Consumption},
  author={Rappaz, J{\'e}r{\'e}mie and McAuley, Julian and Aberer, Karl},
  booktitle={Fifteenth ACM Conference on Recommender Systems},
  pages={390--399},
  year={2021}
}
```

For more information, please contact Jérémie at `rappaz [dot] jeremie [at] gmail [dot] com`.
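#### Example: loading the data

A minimal sketch of loading the 100k subset with pandas, assuming the CSV has no header row and the five columns listed above (this mirrors how `data.py` reads the file; adjust the path to wherever you downloaded it):

```python
import pandas as pd

# columns of 100k.csv, in order (the file has no header row)
cols = ["user", "stream", "streamer", "start", "stop"]
df = pd.read_csv("100k.csv", header=None, names=cols)

# start/stop are 10-minute crawling rounds, so a rough watch duration in minutes:
df["watch_minutes"] = (df["stop"] - df["start"]) * 10
print(df.head())
```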
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
import argparse
from prettytable import PrettyTable

def print_args(parse_args):
    x = PrettyTable()
    x.field_names = ["Arg.", "Value"]
    for arg in vars(parse_args):
        x.add_row([arg, getattr(parse_args, arg)])
    print(x)

def arg_parse():
    parser = argparse.ArgumentParser(description='LiveRec - Twitch')

    parser.add_argument('--seed', dest='seed', type=int,
                help='Random seed')

    parser.add_argument('--batch_size', dest='batch_size', type=int,
                help='Batch size - only active if torch is used')
    parser.add_argument('--seq_len', dest='seq_len', type=int,
                help='Max size of the sequence to consider')

    parser.add_argument('--num_heads', dest='num_heads', type=int,
                help='Number of heads to use for multi-head attention')
    parser.add_argument('--num_heads_ctx', dest='num_heads_ctx', type=int,
                help='Number of heads to use for multi-head attention (ctx module)')


    parser.add_argument('--dataset', dest='dataset',
                help='Input dataset.')

    parser.add_argument('--model', dest='model', type=str,
                help='Type of the model')

    parser.add_argument('--model_from', dest='mfrom', type=str,
                help='Name of the model to load')
    parser.add_argument('--model_to', dest='mto', type=str,
                help='Name of the model to save')
    parser.add_argument('--cache_dir', dest='cache_dir', type=str,
                help='Path to save the cached preprocessed dataset')

    parser.add_argument('--model_path', dest='model_path', type=str,
                help='Path to save the model')
    parser.add_argument('--early_stop', dest='early_stop', type=int,
                help='Number of evaluations without improvement before stopping')
    parser.add_argument('--ev_sample', dest='ev_sample', type=int,
                help='Number of samples for the final evaluation')
    parser.add_argument('--device', dest='device', type=str,
                help='Pytorch device')

    parser.add_argument('--lr', dest='lr', type=float,
                help='Learning rate.')
    parser.add_argument('--mask_prob', dest='mask_prob', type=float,
                help='BERT mask prob.')
    parser.add_argument('--l2', dest='l2', type=float,
                help='Strength of L2 regularization')
    parser.add_argument('--dim', dest='K', type=int,
                help='Number of latent factors')

    parser.add_argument('--num_iters', dest='num_iter', type=int,
                help='Number of training iterations')
    parser.add_argument('--num_epochs', dest='num_epochs', type=int,
                help='Number of training epochs')
    parser.add_argument('--num_att', dest='num_att', type=int,
                help='Num attention modules for seq encoding')
    parser.add_argument('--num_att_ctx', dest='num_att_ctx', type=int,
                help='Num attention modules for ctx module')

    parser.add_argument('--topk_att', dest='topk_att', type=int,
                help='Items to send to attentive output')

    parser.add_argument('--fr_ctx', dest='fr_ctx', nargs='?',
                const=True, default=False,
                help='Enable the availability-context attention module')
    parser.add_argument('--fr_rep', dest='fr_rep', nargs='?',
                const=True, default=False,
                help='Enable repeat-consumption (time-interval) embeddings')
    parser.add_argument('--uniform', dest='uniform', nargs='?',
                const=True, default=False,
                help='Sample negatives uniformly instead of from available items')
    parser.add_argument('--debug', dest='debug', nargs='?',
                const=True, default=False,
                help='Debug mode: only use a small subset of users')
    parser.add_argument('--caching', dest='caching', nargs='?',
                const=True, default=False,
                help='Cache the preprocessed dataset to cache_dir')

    parser.set_defaults(
        seed=42,
        dataset="/mnt/localdata/rappaz/twitch/data/v3/100k/",
        lr=0.0005,
        l2=0.1,
        mask_prob=0.5,
        batch_size=100,
        num_att=2,
        num_att_ctx=2,
        num_heads=4,
        num_heads_ctx=4,
        num_iter=200,
        seq_len=16,
        topk_att=64,
        early_stop=15,
        K=64,
        num_epochs=150,
        model="LiveRec",
        model_path="/mnt/datastore/rappaz/twitch/models",
        mto="liverec",
        device="cuda",
        cache_dir="dataset/"
    )

    args = parser.parse_args()
    return args
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import random
from tqdm import tqdm
import torch
from sampling import *
from data import *
import torch.nn.functional as F
import numpy as np

def save_scores(scores, args):
    with open("logs.txt", 'a') as fout:
        fout.write('{};{};{};{};{:.5f};{:.5f};{};{}\n'.format(
            args.model,
            args.K,
            args.fr_ctx,
            args.fr_rep,
            args.lr,
            args.l2,
            args.seq_len,
            args.topk_att,
        ))
        for k in ['all','new','rep']:
            fout.write('{};{:.5f};{:.5f};{:.5f};{:.5f};{:.5f};{:.5f}\n'.format(
                k,
                scores[k]['h01'],
                scores[k]['h05'],
                scores[k]['h10'],
                scores[k]['ndcg01'],
                scores[k]['ndcg05'],
                scores[k]['ndcg10'],
            ))
        if args.model=="BERT": fout.write("mask_prob: %.2f\n" % (args.mask_prob))
        fout.write('\n')


def print_scores(scores):
    for k in ['all','new','rep']:
        print('{}: h@1: {:.5f} h@5: {:.5f} h@10: {:.5f} ndcg@1: {:.5f} ndcg@5: {:.5f} ndcg@10: {:.5f}'.format(
            k,
            scores[k]['h01'],
            scores[k]['h05'],
            scores[k]['h10'],
            scores[k]['ndcg01'],
            scores[k]['ndcg05'],
            scores[k]['ndcg10'],
        ))
    print("ratio: ", scores['ratio'])

def metrics(a):
    a = np.array(a)
    tot = float(len(a))

    return {
        'h01': (a<1).sum()/tot,
        'h05': (a<5).sum()/tot,
        'h10': (a<10).sum()/tot,
        'ndcg01': np.sum([1 / np.log2(rank + 2) for rank in a[a<1]])/tot,
        'ndcg05': np.sum([1 / np.log2(rank + 2) for rank in a[a<5]])/tot,
        'ndcg10': np.sum([1 / np.log2(rank + 2) for rank in a[a<10]])/tot,
    }

def compute_recall(model, _loader, args, maxit=100000):

    store = {'rrep': [],'rnew': [],'rall': [], 'ratio': []}

    model.eval()
    with torch.no_grad():
        for i,data in tqdm(enumerate(_loader)):
            data = data.to(args.device)
            store = model.compute_rank(data,store,k=10)
            if i>maxit: break

    return {
        'rep': metrics(store['rrep']),
        'new': metrics(store['rnew']),
        'all': metrics(store['rall']),
        'ratio': np.mean(store['ratio']),
    }

# NOTE: the module-level compute_rank below appears to be an unused legacy
# copy of LiveRec.compute_rank (models.py): it references `self` and a
# `convert_batch` helper that are not defined in this module, so it cannot
# be called as-is.
def compute_rank(data,store,k=10):
    inputs,pos,_ = convert_batch(data,self.args,sample_neg=False)

    feats = self(inputs,data)

    xtsy = torch.zeros_like(pos)
    xtsy[data.x_s_batch,data.x_s[:,3]] = data.xts

    if self.args.fr_ctx:
        ctx,batch_inds = self.get_ctx_att(data,inputs,feats)

    if self.args.fr_ctx==False and self.args.fr_rep==True:
        rep_enc = self.rep_emb(self.get_av_rep(inputs,data))

    # identify repeated interactions in the batch
    mask = torch.ones_like(pos[:,-1]).type(torch.bool)
    for b in range(pos.shape[0]):
        avt = pos[b,:-1]
        avt = avt[avt!=0]
        mask[b] = pos[b,-1] in avt
        store['ratio'] += [float(pos[b,-1] in avt)]

    for b in range(inputs.shape[0]):
        step = xtsy[b,-1].item()
        av = torch.LongTensor(self.args.ts[step]).to(self.args.device)
        av_embs = self.item_embedding(av)

        if self.args.fr_ctx==False and self.args.fr_rep:
            # get rep
            reps = inputs[b,inputs[b,:]!=0].unsqueeze(1)==av
            a = (step-xtsy[b,inputs[b,:]!=0]).unsqueeze(1).repeat(1,len(av)) * reps
            if torch.any(torch.any(reps,1)):
                a = a[torch.any(reps,1),:]
                a[a==0]=99999
                a = a.min(0).values*torch.any(reps,0)
                sm = torch.bucketize(a, self.boundaries)+1
                sm = sm*torch.any(reps,0)
                sm = self.rep_emb(sm)
                av_embs += sm

        if self.args.fr_ctx:
            ctx_expand = torch.zeros(self.args.av_tens.shape[1],self.args.K,device=self.args.device)
            ctx_expand[batch_inds[b,-1,:],:] = ctx[b,-1,:,:]
            scores = (feats[b,-1,:] * ctx_expand).sum(-1)
            scores = scores[:len(av)]
        else:
            scores = (feats[b,-1,:] * av_embs).sum(-1)

        iseq = pos[b,-1] == av
        idx = torch.where(iseq)[0]
        rank = torch.where(torch.argsort(scores, descending=True)==idx)[0].item()

        if mask[b]: # rep
            store['rrep'] += [rank]
        else:
            store['rnew'] += [rank]
        store['rall'] += [rank]

    return store


--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import os
import torch
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

from sampling import *

import torch.utils.data as data
from torch.utils.data import DataLoader

def load_data(args):
    INFILE = os.path.join(args.dataset,'100k.csv')
    # raw columns: user,stream,streamer_id,start,stop
    cols = ["user","stream","streamer","start","stop"]
    data_fu = pd.read_csv(INFILE, header=None, names=cols)

    # Add one for padding
    data_fu.user = pd.factorize(data_fu.user)[0]+1
    data_fu['streamer_raw'] = data_fu.streamer
    data_fu.streamer = pd.factorize(data_fu.streamer)[0]+1
    print("Num users: ", data_fu.user.nunique())
    print("Num streamers: ", data_fu.streamer.nunique())
    print("Num interactions: ", len(data_fu))
    print("Estimated watch time: ", (data_fu['stop']-data_fu['start']).sum() * 5 / 60.0)

    args.M = data_fu.user.max()+1 # users
    args.N = data_fu.streamer.max()+2 # items

    data_temp = data_fu.drop_duplicates(subset=['streamer','streamer_raw'])
    umap = dict(zip(data_temp.streamer_raw.tolist(),data_temp.streamer.tolist()))

    # Splitting and caching
    max_step = max(data_fu.start.max(),data_fu.stop.max())
    print("Num timesteps: ", max_step)
    args.max_step = max_step
    args.pivot_1 = max_step-500
    args.pivot_2 = max_step-250

    print("caching availability")
    ts = {}
    max_avail = 0
    for s in range(max_step+1):
        all_av = data_fu[(data_fu.start<=s) & (data_fu.stop>s)].streamer.unique().tolist()
        ts[s] = all_av
        max_avail = max(max_avail,len(ts[s]))
    args.max_avail = max_avail
    args.ts = ts
    print("max_avail: ", max_avail)

    # Compute availability matrix of size (num_timesteps x max_available)
    max_av = max([len(v) for k,v in args.ts.items()])
    max_step = max([k for k,v in args.ts.items()])+1
    av_tens = torch.zeros(max_step,max_av).type(torch.long)
    for k,v in args.ts.items():
        av_tens[k,:len(v)] = torch.LongTensor(v)
    args.av_tens = av_tens.to(args.device)
    return data_fu

def get_dataloaders(data_fu, args):
    if args.debug:
        mu = 1000
    else:
        mu = int(10e9)

    cache_tr = os.path.join(args.cache_dir,"100k_tr.p")
    cache_te = os.path.join(args.cache_dir,"100k_te.p")
    cache_va = os.path.join(args.cache_dir,"100k_val.p")

    if args.caching and all(list(map(os.path.isfile,[cache_tr,cache_te,cache_va]))):
        datalist_tr = pickle.load(open(cache_tr, "rb"))
        datalist_va = pickle.load(open(cache_va, "rb"))
        datalist_te = pickle.load(open(cache_te, "rb"))
    else:
        datalist_tr = get_sequences(data_fu,0,args.pivot_1,args,mu)
        datalist_va = get_sequences(data_fu,args.pivot_1,args.pivot_2,args,mu)
        datalist_te = get_sequences(data_fu,args.pivot_2,args.max_step,args,mu)

        if args.caching:
            pickle.dump(datalist_te, open(cache_te, "wb"))
            pickle.dump(datalist_tr, open(cache_tr, "wb"))
            pickle.dump(datalist_va, open(cache_va, "wb"))

    train_loader = DataLoader(datalist_tr,batch_size=args.batch_size,
                collate_fn=lambda x: custom_collate(x,args))
    val_loader = DataLoader(datalist_va,batch_size=args.batch_size,
                collate_fn=lambda x: custom_collate(x,args))
    test_loader = DataLoader(datalist_te,batch_size=args.batch_size,
                collate_fn=lambda x: custom_collate(x,args))

    return train_loader, val_loader, test_loader


def custom_collate(batch,args):
    # returns a [batch x seq x feats] tensor
    # feats: [padded_positions,positions,inputs_ts,items,users,targets,targets_ts]

    bs = len(batch)
    feat_len = len(batch[0])
    batch_seq = torch.zeros(bs,args.seq_len, feat_len, dtype=torch.long)
    for ib,b in enumerate(batch):
        for ifeat,feat in enumerate(b):
            batch_seq[ib,b[0],ifeat] = feat
    return batch_seq

class SequenceDataset(data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn_padd(batch, device='cpu'):
    '''
    Pads a batch of variable-length sequences (unused helper).

    note: tensors are built manually here since the ToTensor transform
    assumes it takes in images rather than arbitrary tensors.
    '''
    ## get sequence lengths
    lengths = torch.tensor([ t.shape[0] for t in batch ]).to(device)
    ## pad
    batch = [ torch.Tensor(t).to(device) for t in batch ]
    batch = torch.nn.utils.rnn.pad_sequence(batch)
    ## compute mask
    mask = (batch != 0).to(device)
    return batch, lengths, mask

def get_sequences(_data, _p1, _p2, args, max_u=int(10e9)):
    data_list = []

    _data = _data[_data.stop<_p2].copy()

    grouped = _data.groupby('user')
    for user_id, group in tqdm(grouped):
        group = group.sort_values('start')
        group = group.tail(args.seq_len+1)
        if len(group)<2: continue

        group = group.reset_index(drop=True)

        # Get last interaction
        last_el = group.tail(1)
        yt = last_el.start.values[0]
        group.drop(last_el.index,inplace=True)

        # avoid including train in test/validation
        if yt < _p1 or yt >= _p2: continue

        padlen = args.seq_len - len(group)

        # sequence input features
        positions = torch.LongTensor(group.index.values)
        inputs_ts = torch.LongTensor(group.start.values)
        items = torch.LongTensor(group['streamer'].values)
        users = torch.LongTensor(group.user.values)
        bpad = torch.LongTensor(group.index.values + padlen)

        # sequence output features
        targets = torch.LongTensor(items[1:].tolist() + [last_el.streamer.values[0]])
        targets_ts = torch.LongTensor(inputs_ts[1:].tolist() + [last_el.start.values[0]])

        data_list.append([bpad,positions,inputs_ts,items,users,targets,targets_ts])

        # stop if user limit is reached
        if len(data_list)>max_u: break

    return SequenceDataset(data_list)


--------------------------------------------------------------------------------
/baselines.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from collections import defaultdict,Counter
from sampling import *

class POP():
    def __init__(self, args, data_tr):
        self.args = args
        self.cnt = defaultdict(int,dict(Counter(data_tr['streamer'].tolist())))
    def eval(self): pass
    def compute_rank(self,data,store,k=10):
        inputs = data[:,:,3] # inputs
        pos = data[:,:,5] # targets
        xtsy = data[:,:,6] # targets ts

        mask = torch.ones_like(pos[:,-1]).type(torch.bool)
        for b in range(pos.shape[0]):
            avt = pos[b,:-1]
            avt = avt[avt!=0]
            mask[b] = pos[b,-1] in avt
            store['ratio'] += [float(pos[b,-1] in avt)]

        for b in range(inputs.shape[0]):
            step = xtsy[b,-1].item()
            av = self.args.ts[step]

            scores = np.array([self.cnt[a] for a in av])
            iseq = pos[b,-1] == torch.LongTensor(av).to(self.args.device)
            idx = torch.where(iseq)
            idx = int(idx[0].item())
            rank = int(np.where(scores.argsort()[::-1]==idx)[0][0])

            if mask[b]: # rep
                store['rrep'] += [rank]
            else:
                store['rnew'] += [rank]
            store['rall'] += [rank]
        return store

class REP():
    def __init__(self, args, data_tr):
        self.args = args
    def eval(self): pass
    def compute_rank(self,data,store,k=10):
        inputs = data[:,:,3] # inputs
        pos = data[:,:,5] # targets
        xtsy = data[:,:,6] # targets ts

        mask = torch.ones_like(pos[:,-1]).type(torch.bool)
        for b in range(pos.shape[0]):
            avt = pos[b,:-1]
            avt = avt[avt!=0]
            mask[b] = pos[b,-1] in avt
            store['ratio'] += [float(pos[b,-1] in avt)]
        for b in range(inputs.shape[0]):
            step = xtsy[b,-1].item()
            av = self.args.ts[step]

            cnt = defaultdict(int,dict(Counter(inputs[b,:-1].tolist())))
            scores = np.array([cnt[a] for a in av])
            iseq = pos[b,-1] == torch.LongTensor(av).to(self.args.device)
            idx = torch.where(iseq)
            idx = int(idx[0].item())
            rank = int(np.where(scores.argsort()[::-1]==idx)[0][0])

            if mask[b]: # rep
                store['rrep'] += [rank]
            else:
                store['rnew'] += [rank]
            store['rall'] += [rank]
        return store

class MF(torch.nn.Module):
    def __init__(self, args):
        super(MF, self).__init__()
        self.args = args
        self.item_num = args.N
        self.mf_avg = False

        self.item_bias = torch.nn.Embedding(args.N+1, 1, padding_idx=0)
        self.item_embedding = torch.nn.Embedding(args.N+1, args.K, padding_idx=0)
        self.user_embedding = torch.nn.Embedding(args.M+1, args.K, padding_idx=0)

    def forward(self,users,items):
        ui = self.user_embedding(users)
        ii = self.item_embedding(items)
        ib = self.item_bias(items).squeeze()
        return (ui * ii).sum(-1) + ib

    def train_step(self, data, use_ctx=False): # for training
        bs = data.shape[0]
        inputs = data[:,:,3]
        pos = data[:,:,5]
        users = data[:,:,4]
        neg = sample_negs(data,self.args).to(self.args.device)

        pos_logits = self(users,pos)
        neg_logits = self(users,neg)

        loss = - (pos_logits - neg_logits).sigmoid().log()
        loss = loss[inputs!=0].sum()

        return loss

    def compute_rank(self,data,store,k=10):
        inputs = data[:,:,3]
        pos = data[:,:,5]
        users = data[:,:,4]
        xtsy = data[:,:,6] # targets ts

        mask = torch.ones_like(pos[:,-1]).type(torch.bool)
        for b in range(pos.shape[0]):
            avt = pos[b,:-1]
            avt = avt[avt!=0]
            mask[b] = pos[b,-1] in avt
            store['ratio'] += [float(pos[b,-1] in avt)]

        for b in range(inputs.shape[0]):
            step = xtsy[b,-1].item()
            av = torch.LongTensor(self.args.ts[step]).to(self.args.device)
            av_embs = self.item_embedding(av)

            if self.mf_avg:
                inp = inputs[b,:]
                mean_vec = self.item_embedding(inp[inp!=0])
                scores = (mean_vec.mean(0).unsqueeze(0) * av_embs).sum(-1)
                scores += self.item_bias(av).squeeze()
            else:
                u_vec = self.user_embedding(users[b,-1])
                scores = (u_vec.unsqueeze(0) * av_embs).sum(-1)
                scores += self.item_bias(av).squeeze()

            iseq = pos[b,-1] == av
            idx = torch.where(iseq)
            idx = idx[0]
            rank = torch.where(torch.argsort(scores, descending=True)==idx)[0].item()

            if mask[b]: # rep
                store['rrep'] += [rank]
            else:
                store['rnew'] += [rank]
            store['rall'] += [rank]
        return store

class FPMC(torch.nn.Module):
    def __init__(self, args):
        super(FPMC, self).__init__()
        self.args = args
        self.item_num = args.N

        self.user_embs = torch.nn.Embedding(args.M+1, args.K, padding_idx=0)
        self.item_embs = torch.nn.Embedding(args.N+1, args.K, padding_idx=0)
        self.prev_embs = torch.nn.Embedding(args.N+1, args.K, padding_idx=0)
        self.next_embs = torch.nn.Embedding(args.N+1, args.K, padding_idx=0)
        self.item_bias = torch.nn.Embedding(args.N+1, 1, padding_idx=0)

    def forward(self,users,prev,items):
        ui = self.user_embs(users)
        ii = self.item_embs(items)

        ip = self.prev_embs(prev)
        ic = self.next_embs(items)
        ib = self.item_bias(items).squeeze()
        return (ui * ii).sum(-1) + (ip * ic).sum(-1) + ib

    def train_step(self, data, use_ctx=False): # for training
        bs = data.shape[0]
        inputs = data[:,:,3]
        pos = data[:,:,5]
        users = data[:,:,4]
        neg = sample_negs(data,self.args).to(self.args.device)

        pos_logits = self(users,inputs,pos)
        neg_logits = self(users,inputs,neg)

        loss = - (pos_logits[inputs!=0] - neg_logits[inputs!=0]).sigmoid().log()
        return loss.sum()

    def compute_rank(self,data,store,k=10):
        inputs = data[:,:,3]
        pos = data[:,:,5]
        users = data[:,:,4]
        xtsy = data[:,:,6] # targets ts

        mask = torch.ones_like(pos[:,-1]).type(torch.bool)
        for b in range(pos.shape[0]):
            avt = pos[b,:-1]
            avt = avt[avt!=0]
            mask[b] = pos[b,-1] in avt
            store['ratio'] += [float(pos[b,-1] in avt)]

        for b in range(inputs.shape[0]):
            step = xtsy[b,-1].item()
            av = torch.LongTensor(self.args.ts[step]).to(self.args.device)

            ui = self.user_embs(users[b,-1])
            pi = self.prev_embs(inputs[b,-1])

            scores = (ui.unsqueeze(0) * self.item_embs(av)).sum(-1)
            scores += (pi.unsqueeze(0) * self.next_embs(av)).sum(-1)
            scores += self.item_bias(av).squeeze()

            iseq = pos[b,-1] == av
            idx = torch.where(iseq)
            idx = idx[0]
            rank = torch.where(torch.argsort(scores, descending=True)==idx)[0].item()

            if mask[b]: # rep
                store['rrep'] += [rank]
            else:
                store['rnew'] += [rank]
            store['rall'] += [rank]
        return store

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import os
from data import *
from baselines import *

import numpy as np

import torch
import torch.nn as nn

def get_model_type(args):
    # Model and data
    mto = str(args.mto)
    mto += '_' + str(args.K)
    mto += '_' + str(args.l2)
    mto += '_' + str(args.topk_att)
    mto += '_' + str(args.num_att_ctx)
    mto += '_' + str(args.seq_len)
    mto += "_rep" if args.fr_rep else ""
    mto += "_ctx" if args.fr_ctx else ""
    mto += '.pt'

    if args.model == "POP":
        MODEL = POP
    elif args.model == "REP":
        MODEL = REP
    elif args.model == "MF":
        MODEL = MF
    elif args.model == "FPMC":
        MODEL = FPMC
    elif args.model == "LiveRec":
        MODEL = LiveRec
    else:
        raise ValueError("Unknown model: %s" % args.model)

    return os.path.join(args.model_path,mto),MODEL

class PointWiseFeedForward(nn.Module):
    def __init__(self, hidden_units, dropout_rate):

        super(PointWiseFeedForward, self).__init__()

        self.conv1 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        # Conv1D requires (N, C, Length), hence the transposes
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2)
        outputs += inputs
        return outputs

class Attention(nn.Module):
    def __init__(self, args, num_att, num_heads, causality=False):
        super(Attention, self).__init__()
        self.args = args
        self.causality = causality

        self.attention_layernorms = nn.ModuleList()
        self.attention_layers = nn.ModuleList()
        self.forward_layernorms = nn.ModuleList()
        self.forward_layers = nn.ModuleList()
        self.last_layernorm = nn.LayerNorm(args.K, eps=1e-8)

        for _ in range(num_att):
            new_attn_layernorm = nn.LayerNorm(args.K, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            new_attn_layer = nn.MultiheadAttention(args.K,
                                                   num_heads,
                                                   0.2)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = nn.LayerNorm(args.K, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(args.K, 0.2)
            self.forward_layers.append(new_fwd_layer)

    def forward(self, seqs, timeline_mask=None):
        if self.causality:
            tl = seqs.shape[1] # time dim len for enforcing causality
            attention_mask = ~torch.tril(torch.ones((tl, tl),
                                                    dtype=torch.bool,
                                                    device=self.args.device))
        else: attention_mask = None

        if timeline_mask is not None:
            seqs *= ~timeline_mask.unsqueeze(-1)

        for i in range(len(self.attention_layers)):
            seqs = torch.transpose(seqs, 0, 1)
            Q = self.attention_layernorms[i](seqs)
            mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs,
                                                      attn_mask=attention_mask)
            seqs = Q + mha_outputs
            seqs = torch.transpose(seqs, 0, 1)

            seqs = self.forward_layernorms[i](seqs)
            seqs = self.forward_layers[i](seqs)
            if timeline_mask is not None:
                seqs *= ~timeline_mask.unsqueeze(-1)

        return self.last_layernorm(seqs)


class LiveRec(nn.Module):
    def __init__(self, args):
        super(LiveRec, self).__init__()
        self.args = args

        self.item_embedding = nn.Embedding(args.N+1, args.K, padding_idx=0)
        self.pos_emb = nn.Embedding(args.seq_len, args.K)
        self.emb_dropout = nn.Dropout(p=0.2)

        # Sequence encoding attention
        self.att = Attention(args,
                             args.num_att,
                             args.num_heads,
                             causality=True)

        # Availability attention
        self.att_ctx = Attention(args,
                                 args.num_att_ctx,
                                 args.num_heads_ctx,
                                 causality=False)

        # Time interval embedding
        # 24h cycles, except for the first one set to 12h
        self.boundaries = torch.LongTensor([0]+list(range(77,3000+144, 144))).to(args.device)
        self.rep_emb = nn.Embedding(len(self.boundaries)+2, args.K, padding_idx=0)

    def forward(self, log_seqs):
        seqs = self.item_embedding(log_seqs)
        seqs *= self.item_embedding.embedding_dim ** 0.5

        positions = np.tile(np.array(range(log_seqs.shape[1])), [log_seqs.shape[0], 1])
        seqs += self.pos_emb(torch.LongTensor(positions).to(self.args.device))

        seqs = self.emb_dropout(seqs)

        timeline_mask = (log_seqs == 0).to(self.args.device)

        feats = self.att(seqs, timeline_mask)

        return feats

    def predict(self,feats,inputs,items,ctx,data):
        if ctx is not None: i_embs = ctx
        else: i_embs = self.item_embedding(items)

        return (feats * i_embs).sum(dim=-1)

    def compute_rank(self,data,store,k=10):
        inputs = data[:,:,3] # inputs
        pos = data[:,:,5] # targets
        xtsy = data[:,:,6] # targets ts

        feats = self(inputs)

        # availability context attention
        if self.args.fr_ctx:
            ctx,batch_inds = self.get_ctx_att(data,feats)

        # identify repeated interactions in the batch
        mask = torch.ones_like(pos[:,-1]).type(torch.bool)
        for b in range(pos.shape[0]):
            avt = pos[b,:-1]
            avt = avt[avt!=0]
            mask[b] = pos[b,-1] in avt
            store['ratio'] += [float(pos[b,-1] in avt)]

        for b in range(inputs.shape[0]):
            step = xtsy[b,-1].item()
            av = torch.LongTensor(self.args.ts[step]).to(self.args.device)
            av_embs = self.item_embedding(av)

            if self.args.fr_ctx:
                ctx_expand = torch.zeros(self.args.av_tens.shape[1],self.args.K,device=self.args.device)
                ctx_expand[batch_inds[b,-1,:],:] = ctx[b,-1,:,:]
                scores = (feats[b,-1,:] * ctx_expand).sum(-1)
                scores = scores[:len(av)]
            else:
                scores = (feats[b,-1,:] * av_embs).sum(-1)

            iseq = pos[b,-1] == av
            idx = torch.where(iseq)[0]
            rank = torch.where(torch.argsort(scores, descending=True)==idx)[0].item()

            if mask[b]: # rep
                store['rrep'] += [rank]
            else:
                store['rnew'] += [rank]
            store['rall'] += [rank]

        return store

    def get_ctx_att(self,data,feats,neg=None):
        if not self.args.fr_ctx: return None

        inputs,pos,xtsy = data[:,:,3],data[:,:,5],data[:,:,6]

        # unbatch indices
        ci = torch.nonzero(inputs, as_tuple=False)
        flat_xtsy = xtsy[ci[:,0],ci[:,1]]

        av = self.args.av_tens[flat_xtsy,:]
        av_embs = self.item_embedding(av)

        # repeat consumption: time interval embeddings
        if self.args.fr_rep:
            av_rep_batch = self.get_av_rep(data)
            av_rep_flat = av_rep_batch[ci[:,0],ci[:,1]]
            rep_enc = self.rep_emb(av_rep_flat)
            av_embs += rep_enc

        flat_feats = feats[ci[:,0],ci[:,1],:]
        flat_feats = flat_feats.unsqueeze(1).expand(flat_feats.shape[0],
                                                    self.args.av_tens.shape[-1],
                                                    flat_feats.shape[1])

        scores = (av_embs * flat_feats).sum(-1)
        inds = scores.topk(self.args.topk_att,dim=1).indices

        # embed selected items
        seqs = torch.gather(av_embs, 1, inds.unsqueeze(2) \
                    .expand(-1,-1,self.args.K))

        seqs = self.att_ctx(seqs)

        def expand_att(items):
            av_pos = torch.where(av==items[ci[:,0],ci[:,1]].unsqueeze(1))[1]
            is_in = torch.any(inds == av_pos.unsqueeze(1),1)

            att_feats = torch.zeros(av.shape[0],self.args.K).to(self.args.device)
            att_feats[is_in,:] = seqs[is_in,torch.where(av_pos.unsqueeze(1) == inds)[1],:]

            out = torch.zeros(inputs.shape[0],inputs.shape[1],self.args.K).to(self.args.device)
            out[ci[:,0],ci[:,1],:] = att_feats
            return out

        # training
        if pos is not None and neg is not None:
            return expand_att(pos),expand_att(neg)
        # testing
        else:
            out = torch.zeros(inputs.shape[0],inputs.shape[1],seqs.shape[1],self.args.K).to(self.args.device)
            out[ci[:,0],ci[:,1],:] = seqs
            batch_inds = torch.zeros(inputs.shape[0],inputs.shape[1],inds.shape[1],dtype=torch.long).to(self.args.device)
            batch_inds[ci[:,0],ci[:,1],:] = inds
            return out,batch_inds

    def train_step(self, data, use_ctx=False): # for training
        inputs,pos = data[:,:,3],data[:,:,5]
        neg = sample_negs(data,self.args).to(self.args.device)

        feats = self(inputs)

        ctx_pos,ctx_neg = None,None
        if self.args.fr_ctx:
            ctx_pos,ctx_neg = self.get_ctx_att(data,feats,neg)

        pos_logits = self.predict(feats,inputs,pos,ctx_pos,data)
        neg_logits = self.predict(feats,inputs,neg,ctx_neg,data)
        loss = (-torch.log(pos_logits[inputs!=0].sigmoid()+1e-24)
                -torch.log(1-neg_logits[inputs!=0].sigmoid()+1e-24)).sum()

        return loss

    def get_av_rep(self,data):
        bs = data.shape[0]
        inputs = data[:,:,3] # inputs
        xtsb = data[:,:,2] # inputs ts
        xtsy = data[:,:,6] # targets ts

        av_batch = self.args.av_tens[xtsy.view(-1),:]
        av_batch = av_batch.view(xtsy.shape[0],xtsy.shape[1],-1)
        av_batch *= (xtsy!=0).unsqueeze(2) # masking pad inputs
        av_batch = av_batch.to(self.args.device)

        mask_caus = 1-torch.tril(torch.ones(self.args.seq_len,self.args.seq_len),diagonal=-1)
        mask_caus = mask_caus.unsqueeze(0).unsqueeze(3)
        mask_caus = mask_caus.expand(bs,-1,-1,self.args.av_tens.shape[-1])
        mask_caus = mask_caus.type(torch.bool).to(self.args.device)

        tile = torch.arange(self.args.seq_len).unsqueeze(0).repeat(bs,1).to(self.args.device)

        bm = (inputs.unsqueeze(2).unsqueeze(3)==av_batch.unsqueeze(1).expand(-1,self.args.seq_len,-1,-1))
        bm &= mask_caus

        # **WARNING** this is a hacky way to get the last non-zero element in the sequence.
        # It works with pytorch 1.8.1 but might break in the future.
        sm = bm.type(torch.int).argmax(1)
        sm = torch.any(bm,1) * sm

        sm = (torch.gather(xtsy, 1, tile).unsqueeze(2) -
              torch.gather(xtsb.unsqueeze(2).expand(-1,-1,self.args.av_tens.shape[-1]), 1, sm))
        sm = torch.bucketize(sm, self.boundaries)+1
        sm = torch.any(bm,1) * sm

        sm *= av_batch!=0
        sm *= inputs.unsqueeze(2)!=0
        return sm
--------------------------------------------------------------------------------
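A side note on the time-interval embedding used by `rep_emb` above: gaps between a user's previous interaction with a streamer and the current timestep, measured in 10-minute crawling rounds, are mapped to roughly 24-hour buckets via `torch.bucketize`, with the first boundary at ~12h (77 rounds). A minimal standalone sketch, with gap values chosen purely for illustration:

```python
import torch

# 24h cycles (144 rounds of 10 min), except the first boundary at ~12h
boundaries = torch.LongTensor([0] + list(range(77, 3000 + 144, 144)))

# example gaps in 10-minute rounds: 30 min, ~13h, 2 days, ~20 days
gaps = torch.LongTensor([3, 80, 288, 2900])

# +1 shifts all buckets up so that index 0 stays reserved for padding,
# matching rep_emb = nn.Embedding(len(boundaries)+2, K, padding_idx=0)
buckets = torch.bucketize(gaps, boundaries) + 1
print(buckets)  # tensor([ 2,  3,  4, 22])
```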